3 approaches to linear-memory Transformers

Transformers are a very popular architecture for processing sequential data, notably text and (our interest) proteins. They learn more complex patterns as models grow and training data increases, as demonstrated by models like GPT-4 and ESM-2. Transformers work by updating each token with an attention output computed as a weighted sum over all other tokens. In standard implementations this requires the product of a query and a key matrix, which costs O(N²d) compute and, problematically, O(N²) memory for a sequence of length N and an embedding size of d. To speed up Transformers, and to analyze longer sequences, several variants have been proposed that require only O(N) memory. Broadly, these can be divided into sparse methods, softmax approximators, and memory-efficient Transformers.
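
To make the cost concrete, here is a minimal sketch of standard scaled dot-product attention for a single head. The shapes and names are illustrative, but it shows the (N, N) intermediate that dominates memory for long sequences.

```python
# Minimal sketch of standard scaled dot-product attention (single head).
# Shapes and names are illustrative, not any particular library's API.
import torch
import torch.nn.functional as F

def standard_attention(Q, K, V):
    """Q, K, V: (N, d) tensors for a single head."""
    d = Q.shape[-1]
    # A has shape (N, N): this intermediate matrix is what blows up memory
    # for long sequences.
    A = Q @ K.T / d**0.5
    weights = F.softmax(A, dim=-1)
    return weights @ V  # (N, d)

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = standard_attention(Q, K, V)  # materializes a 1024 x 1024 attention matrix
```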

Sparse methods

Sparse attention strategies (adapted from https://blog.research.google/2020/10/rethinking-attention-with-performers.html)

The first, and simplest, way to reduce the memory requirement is to limit the number of other tokens each token attends to. This is done implicitly in architectures like CNNs and explicitly in many GNNs, which compute the k-nearest neighbours. The main limitation of sparse models is obvious: they don't connect all the nodes, so information may not be passed as efficiently. There are ways to mitigate this, such as creating global nodes that are connected to all other nodes or adding a hidden memory during sequential decoding. Sparse methods are promising, but there are other methods which ensure linear memory with fully-connected graphs.
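
As a sketch of the sparse idea, here is a simple local (sliding-window) variant in PyTorch, where each token attends only to its 2w + 1 nearest neighbours. The window size, the clamped edge handling, and the gather-based implementation are illustrative choices, not any particular paper's method.

```python
# Hedged sketch of local (sliding-window) sparse attention.
import torch
import torch.nn.functional as F

def local_attention(Q, K, V, w=8):
    """Each token attends only to the 2*w + 1 tokens around it.
    Q, K, V: (N, d). Memory scales as O(N * w * d) instead of O(N^2)."""
    N, d = Q.shape
    # Neighbour indices for each query, crudely clamped at the sequence edges.
    offsets = torch.arange(-w, w + 1)
    idx = (torch.arange(N)[:, None] + offsets[None, :]).clamp(0, N - 1)  # (N, 2w+1)
    K_local = K[idx]  # (N, 2w+1, d)
    V_local = V[idx]  # (N, 2w+1, d)
    scores = torch.einsum("nd,nkd->nk", Q, K_local) / d**0.5
    weights = F.softmax(scores, dim=-1)
    return torch.einsum("nk,nkd->nd", weights, V_local)

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = local_attention(Q, K, V)  # never builds an N x N matrix
```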

Softmax approximators

Reordering matrix multiplications reduces complexity (adapted from https://blog.research.google/2020/10/rethinking-attention-with-performers.html)

There is a neat feature of matrix multiplication: the computational cost of multiplying three matrices depends on the order in which you do the multiplications. If we don't need to compute the softmax of the attention matrix, A = QKᵀ, then we can compute (QKᵀ)V as Q(KᵀV). This changes the computational complexity from O(N²d) to O(Nd²). Similarly, because the intermediate attention matrix A is never explicitly computed, the memory scales with d² instead of N². The drawback is subtle: modelling the "softmax kernel" as the inner product of two mapped embeddings, ⟨f(x), f(y)⟩ = e^⟨x, y⟩, is very challenging; in general, it requires an infinite-dimensional embedding. Reasonable approximations exist in some cases, and methods like the Performer report good results on some tasks, but the missing softmax can lead to issues with training stability or accuracy.
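
Below is a hedged sketch of this reordering in PyTorch. It uses the elu(x) + 1 feature map popularized by the "Transformers are RNNs" line of work as a stand-in for f(·); the Performer instead uses random features designed to approximate the softmax kernel.

```python
# Hedged sketch of linearized attention: replace softmax(QK^T) with a kernel
# feature map f(.) so the multiplication can be reordered as f(Q) (f(K)^T V).
import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, eps=1e-6):
    """Q, K, V: (N, d). The only attention-sized intermediate is (d, d)."""
    Qp = F.elu(Q) + 1              # f(Q), all entries positive
    Kp = F.elu(K) + 1              # f(K)
    KV = Kp.T @ V                  # (d, d) instead of (N, N)
    Z = Qp @ Kp.sum(dim=0, keepdim=True).T + eps  # (N, 1) normalizer
    return (Qp @ KV) / Z           # (N, d)

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
out = linear_attention(Q, K, V)  # O(N d^2) compute, no N x N matrix
```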

Memory-efficient Transformers

Computing a block of A “on the fly” is faster than retrieving it from memory (adapted from https://arxiv.org/abs/2205.14135)

The final approach relies on the fact that the softmax can be computed exactly in a block-wise fashion, so the intermediate matrix A never has to be stored completely in memory. Each of these intermediate blocks can be processed by a different GPU thread block, which makes the computation fast. Interestingly, FlashAttention achieves faster attention in linear memory because the authors find that it is slower to retrieve a precomputed block of A from memory than to recompute it from blocks of Q and K. Because of this, most standard Transformer implementations now use FlashAttention to compute attention quickly in linear memory. If you are using a Transformer implementation from a standard library like PyTorch, you are probably using FlashAttention without even thinking about it.
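
For intuition, here is a pedagogical PyTorch sketch of block-wise ("online") softmax attention, the core trick behind FlashAttention-style kernels. The block size and variable names are illustrative, and the real FlashAttention is a fused CUDA kernel rather than a Python loop.

```python
# Pedagogical sketch of block-wise (online) softmax attention.
import torch

def blockwise_attention(Q, K, V, block=128):
    """Q, K, V: (N, d). Processes K/V in blocks so no (N, N) matrix is stored."""
    N, d = Q.shape
    scale = d ** -0.5
    out = torch.zeros_like(Q)                    # running weighted sum
    row_max = torch.full((N, 1), float("-inf"))  # running max for numerical stability
    row_sum = torch.zeros(N, 1)                  # running softmax denominator
    for start in range(0, N, block):
        Kb, Vb = K[start:start + block], V[start:start + block]
        S = (Q @ Kb.T) * scale                   # only an (N, block) slice of A
        new_max = torch.maximum(row_max, S.max(dim=-1, keepdim=True).values)
        # Rescale what has been accumulated so far, then fold in this block.
        correction = torch.exp(row_max - new_max)
        P = torch.exp(S - new_max)
        out = out * correction + P @ Vb
        row_sum = row_sum * correction + P.sum(dim=-1, keepdim=True)
        row_max = new_max
    return out / row_sum

N, d = 1024, 64
Q, K, V = (torch.randn(N, d) for _ in range(3))
assert torch.allclose(blockwise_attention(Q, K, V),
                      torch.softmax(Q @ K.T / d**0.5, dim=-1) @ V, atol=1e-4)
```

In recent PyTorch releases, torch.nn.functional.scaled_dot_product_attention will dispatch to a FlashAttention-style kernel when the hardware and inputs allow, so you get this behaviour without writing any of the above yourself.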
