Introduction
What’s Wrong with Seq2Seq Models?
- The Seq2seq model originates from language modeling.
- Normally has an encoder-decoder architecture:
- Encoder processes the input sequence and compresses it into a fixed-length context vector, where this representation is expected to be a good summary of the meaning of the whole source sequence.
- Decoder is initialized with the context vector to emit the transformed output.
- Both encoder and decoder are typically RNNs (LSTMs/GRUs).
- Transforms an input sequence (source) to a new one (target), where both sequences can be of arbitrary lengths:
- Limitation: incapable of remembering long sentences because of the fixed-length context vector.
- The context vector struggles to retain information from long sentences, often causing early parts of the sequence to be forgotten by the time processing completes. ⇒ The attention mechanism was born to resolve this problem.
Attention Cues in Biology
- When inspecting a visual scene, our optic nerve receives information on the order of $10^8$ bits per second, far exceeding what our brain can fully process. Fortunately, our ancestors had learned from experience (or, data in our case) that not all sensory inputs are created equal. ⇒ Throughout human history, the capability of directing attention to only a fraction of information of interest has enabled our brain to allocate resources more smartly to survive, to grow, and to socialize, such as detecting predators, prey, and mates.
- To explain how our attention is deployed in the visual world, a two-component framework has emerged. Idea dates back to William James, who is considered the “father of American psychology” [James, 2007].
- Nonvolitional cue
- Volitional cue
- While all the paper products are printed in black and white, the coffee cup is red. ⇒ The coffee cup is intrinsically salient and conspicuous in this visual environment.⇒ It automatically and involuntarily draws attention.⇒ You bring the fovea (the center of the macula where visual acuity is highest) onto the coffee.
⇒ In this framework, subjects selectively direct the spotlight of attention using both the nonvolitional cue and the volitional cue.
⇒ The nonvolitional cue is based on the saliency and conspicuity of objects in the environment: using the nonvolitional cue based on saliency (red cup, non-paper), attention is involuntarily directed to the coffee.
⇒ After drinking coffee, you become caffeinated and want to read a book. You turn your head, refocus your eyes, and look at the book. In this task-dependent case you select the book under cognitive and volitional control: using the volitional cue (wanting to read a book) that is task-dependent, attention is directed to the book under volitional control.
⇒ Using the volitional cue based on variable selection criteria, this form of attention is more deliberate. It is also more powerful with the subject's voluntary effort.
Queries, Keys, and Values
- Inspired by the nonvolitional and volitional attention cues that explain the attentional deployment.
- Consider the case where only nonvolitional cues are available. To bias selection over sensory inputs, we can simply use:
- Parameterized fully-connected layer.
- Non-parameterized max or average pooling.
- In the context of attention mechanisms, we refer to volitional cues as queries.
- Given any query, attention mechanisms bias selection over sensory inputs via attention pooling.
- These sensory inputs are called values in the context of attention mechanisms.
- Every value is paired with a key, which can be thought of as the nonvolitional cue of that sensory input.
⇒ What sets attention mechanisms apart from FC or pooling layers is the inclusion of the volitional cues (queries).⇒ Design attention pooling such that the given query (volitional cue) can interact with keys (nonvolitional cues), which guides bias selection over values (sensory inputs):
Nadaraya-Watson Kernel Regression
- Example of ML algorithm with attention mechanisms.
- We generate a dataset according to the following non-linear function with a noise term $\epsilon$ (zero mean, standard deviation 0.5):
$$y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon.$$
Given this dataset, how do we learn $f$ to predict the output $\hat{y} = f(x)$ for any new input $x$?
Average Pooling
- Begin with the “dumbest” estimator for this regression problem. ⇒ Use average pooling to average over all the training outputs:
$$f(x) = \frac{1}{n}\sum_{i=1}^n y_i.$$
⇒ As we can see, this estimator is not so smart, as average pooling omits the inputs $x_i$.
Nonparametric Attention Pooling
- A better idea was proposed by Nadaraya and Watson to weight the outputs $y_i$ according to their input locations:
$$f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)}\, y_i,$$
where $K$ is a kernel.⇒ We can rewrite it in a more generalized form of attention pooling:
$$f(x) = \sum_{i=1}^n \alpha(x, x_i)\, y_i,$$
where $x$ is the query, $(x_i, y_i)$ is the key-value pair, and $\alpha(x, x_i)$ is an attention weight that is assigned to the corresponding value $y_i$.
- Consider a Gaussian kernel defined as
$$K(u) = \frac{1}{\sqrt{2\pi}} \exp\!\left(-\frac{u^2}{2}\right).$$
⇒ Plugging the Gaussian kernel into the last two equations gives us
$$f(x) = \sum_{i=1}^n \mathrm{softmax}\!\left(-\frac{1}{2}(x - x_i)^2\right) y_i.$$
- A key $x_i$ that is closer to the given query $x$ will get more attention via a larger attention weight assigned to the key's corresponding value $y_i$.
⇒ The predicted line is smooth and closer to the ground truth than that produced by average pooling.
- Now let us take a look at the attention weights. Here testing inputs are queries while training inputs are keys. Since both inputs are sorted, we can see that the closer the query-key pair is, the higher the attention weight is in the attention pooling.
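To make the nonparametric pooling concrete, here is a minimal NumPy sketch (the function name and toy data are illustrative, not from the original text):

```python
import numpy as np

def nadaraya_watson(x_query, x_train, y_train):
    """Nonparametric attention pooling with a Gaussian kernel.

    x_query: (m,) queries, x_train: (n,) keys, y_train: (n,) values.
    Returns (m,) predictions: softmax(-(x - x_i)^2 / 2) weighted sum of y_i.
    """
    # (m, n) matrix of differences between each query and each key
    diffs = x_query[:, None] - x_train[None, :]
    scores = -0.5 * diffs ** 2
    # Softmax over the key dimension gives the attention weights
    weights = np.exp(scores - scores.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ y_train

# Toy data following the noisy target y = 2 sin(x) + x^0.8 + noise
rng = np.random.default_rng(0)
x_train = np.sort(rng.uniform(0, 5, 50))
y_train = 2 * np.sin(x_train) + x_train ** 0.8 + rng.normal(0, 0.5, 50)
x_test = np.arange(0, 5, 0.1)
y_hat = nadaraya_watson(x_test, x_train, y_train)
```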
Parametric Attention Pooling
- We can easily integrate learnable parameters into attention pooling.
⇒ The distance between the query $x$ and the key $x_i$ can be multiplied by a learnable parameter $w$:
$$f(x) = \sum_{i=1}^n \mathrm{softmax}\!\left(-\frac{1}{2}\big((x - x_i)w\big)^2\right) y_i.$$
⇒ After training the parametric attention model, we can plot its prediction.
- Compared with nonparametric attention pooling, the region with large attention weights becomes sharper in the learnable, parametric setting.
Attention Pooling and Attention Scoring Functions
- In the previous section, we obtained a probability distribution over values that are paired with keys: the output of the attention pooling is simply a weighted sum of the values based on the attention weights.
- At a high level, we can use the above algorithm to instantiate the framework of attention mechanisms.
- NB: attention weights are a probability distribution, weighted sum is a weighted average.
- More formally, suppose that we have:
- Query $\mathbf{q} \in \mathbb{R}^q$.
- $m$ key-value pairs $(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)$, where any $\mathbf{k}_i \in \mathbb{R}^k$ and any $\mathbf{v}_i \in \mathbb{R}^v$.
⇒ Attention pooling ($f$ before) is instantiated as a weighted sum of the values:
$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i \in \mathbb{R}^v,$$
where the attention weight (scalar) for the query $\mathbf{q}$ and key $\mathbf{k}_i$ is computed by the softmax operation of an attention scoring function $a$ that maps two vectors to a scalar:
$$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.$$
⇒ Denoting an attention scoring function by $a$, the following figure illustrates how the output of attention pooling can be computed as a weighted sum of values:
Computing the output of attention pooling as a weighted average of values.
- Different choices of the attention scoring function lead to different behaviors of attention pooling.
Additive Attention
- When queries and keys are vectors of different lengths, we can use additive attention as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is
$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R},$$
where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^h$ are learnable parameters.
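A minimal PyTorch sketch of this scoring function (class and dimension names are illustrative assumptions):

```python
import torch
from torch import nn

class AdditiveAttention(nn.Module):
    """a(q, k) = w_v^T tanh(W_q q + W_k k), followed by softmax over the keys."""

    def __init__(self, query_dim, key_dim, hidden_dim):
        super().__init__()
        self.W_q = nn.Linear(query_dim, hidden_dim, bias=False)
        self.W_k = nn.Linear(key_dim, hidden_dim, bias=False)
        self.w_v = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, queries, keys, values):
        # queries: (batch, n_q, query_dim); keys: (batch, n_kv, key_dim); values: (batch, n_kv, value_dim)
        features = self.W_q(queries).unsqueeze(2) + self.W_k(keys).unsqueeze(1)
        scores = self.w_v(torch.tanh(features)).squeeze(-1)   # (batch, n_q, n_kv)
        weights = torch.softmax(scores, dim=-1)               # attention weights
        return torch.bmm(weights, values)                      # (batch, n_q, value_dim)
```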
Scaled Dot-Product Attention
- More computationally efficient but requires both the query and the key to have the same vector length $d$.
- Assume that all elements of the query and the key are i.i.d. variables with zero mean and unit variance.⇒ Their dot product then has mean zero and variance $d$.⇒ To ensure that the variance of the dot product still remains one regardless of vector length, the scaled dot-product attention scoring function is
$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}.$$
- In practice, we often think in mini-batches for efficiency.⇒ The scaled dot-product attention of queries $\mathbf{Q} \in \mathbb{R}^{n \times d}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d}$, and values $\mathbf{V} \in \mathbb{R}^{m \times v}$ is
$$\mathrm{softmax}\!\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v}.$$
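A minimal batched PyTorch sketch of scaled dot-product attention (function and argument names are assumptions for illustration):

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V, mask=None):
    """softmax(Q K^T / sqrt(d)) V for Q: (b, n, d), K: (b, m, d), V: (b, m, v)."""
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)     # (b, n, m)
    if mask is not None:
        # Positions where mask == False are excluded from the softmax
        scores = scores.masked_fill(~mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)
    return weights @ V                                   # (b, n, v)
```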
Masked Softmax Operation
- In some cases, not all the values should be fed into attention pooling.
- E.g.: for efficient minibatch processing, some text sequences are padded with special tokens that do not carry meaning. To get an attention pooling over only meaningful tokens as values, we can specify a valid sequence length (in number of tokens) to filter out those beyond this specified range when computing softmax. In this way, we can implement such a masked softmax operation, where any value beyond the valid length is masked as zero.
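A hedged sketch of such a masked softmax over valid lengths (names and shapes are illustrative):

```python
import torch

def masked_softmax(scores, valid_lens):
    """Softmax over the last axis, masking out positions beyond each valid length.

    scores: (batch, n_queries, n_keys); valid_lens: (batch,) number of real (non-pad) keys.
    """
    batch, n_q, n_k = scores.shape
    # True for real tokens, False for padding positions beyond the valid length
    mask = torch.arange(n_k, device=scores.device)[None, None, :] < valid_lens[:, None, None]
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1)

scores = torch.randn(2, 1, 4)
print(masked_softmax(scores, torch.tensor([2, 3])))  # padded keys receive zero attention weight
```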
Basic Attention Mechanisms
Bahdanau Attention
- As we saw in the intro, the attention mechanism was born to help memorize long source sentences in neural machine translation (NMT).
⇒ Rather than building a single context vector out of the encoder's last hidden state, the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights of these shortcut connections are customizable for each output element.⇒ When predicting a token, not all the input tokens are relevant.⇒ The model aligns only to parts of the input sequence that are relevant to the current prediction.⇒ This is achieved by treating the context variable $\mathbf{c}_{t'}$ at decoding time step $t'$ as an output of attention pooling:
$$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t,$$
where
- $T$ is the number of tokens in the input sequence.
- The decoder hidden state $\mathbf{s}_{t'-1}$ at time step $t'-1$ is the query.
- The encoder hidden states $\mathbf{h}_t$ are both the keys and values.
- The attention weight $\alpha$ is computed using the additive attention scoring function.
- Matrix of alignment scores is a nice byproduct to explicitly show the correlation between source and target words:
Alignment matrix of "L'accord sur l'Espace économique européen a été signé en août 1992" (French) and its English translation "The agreement on the European Economic Area was signed in August 1992".
Multi-Head Attention
- Given the same set of queries, keys, and values, we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence.
- Usually, understanding the role of a word in a sentence requires understanding how it is related to different parts of the sentence. For example, in some languages, subjects define verb inflection (e.g., gender agreement), verbs define the case of their objects, and many more. In other words, each word is part of many relations.
⇒ It may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.⇒ Let the model focus on different things from different representation subspaces at different positions.⇒ Instead of performing a single attention pooling, queries, keys, and values can be transformed with $h$ independently learned linear projections. Then these $h$ projected queries, keys, and values are fed into attention pooling in parallel. In the end, the $h$ attention pooling outputs (heads) are concatenated and transformed with another learned linear projection to produce the final output.⇒ Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as
$$\mathbf{h}_i = f(\mathbf{W}_i^{(q)} \mathbf{q}, \mathbf{W}_i^{(k)} \mathbf{k}, \mathbf{W}_i^{(v)} \mathbf{v}) \in \mathbb{R}^{p_v},$$
where
- $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters.
- $f$ is attention pooling (e.g., additive attention, scaled dot-product attention, etc.).
⇒ The multi-head attention output is a linear transformation via $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$ of the concatenation of the $h$ heads:
$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}.$$
Multi-head attention, where multiple heads are concatenated then linearly transformed.
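A compact PyTorch sketch of multi-head scaled dot-product attention along these lines (the class name and the choice of equal head dimensions are assumptions, not from the source):

```python
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """Project Q, K, V into num_heads subspaces, attend in parallel, concatenate, project."""

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.W_o = nn.Linear(d_model, d_model, bias=False)

    def _split(self, x):
        # (batch, seq, d_model) -> (batch, heads, seq, d_model // heads)
        b, n, d = x.shape
        return x.view(b, n, self.num_heads, d // self.num_heads).transpose(1, 2)

    def forward(self, queries, keys, values):
        q = self._split(self.W_q(queries))
        k = self._split(self.W_k(keys))
        v = self._split(self.W_v(values))
        d_head = q.shape[-1]
        weights = torch.softmax(q @ k.transpose(-2, -1) / d_head ** 0.5, dim=-1)
        heads = weights @ v                                    # (batch, heads, n_q, d_head)
        b, h, n, dh = heads.shape
        concat = heads.transpose(1, 2).reshape(b, n, h * dh)   # concatenate the heads
        return self.W_o(concat)                                # final linear projection
```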
Self-Attention
- We often use CNNs or RNNs to encode a sequence.
⇒ Idea: with attention mechanisms in mind, try to feed a sequence of tokens into attention pooling so that the same set of tokens acts as queries, keys, and values.⇒ Since queries, keys, and values come from the same place, this performs self-attention.⇒ Given a sequence of input tokens $\mathbf{x}_1, \ldots, \mathbf{x}_n$, where any $\mathbf{x}_i \in \mathbb{R}^d$, its self-attention outputs a sequence of the same length $\mathbf{y}_1, \ldots, \mathbf{y}_n$, where
$$\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d,$$
according to the definition of attention pooling $f$.
- E.g., $f$ can be multi-head attention.
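Using the hypothetical MultiHeadAttention sketch from the previous subsection, self-attention is simply the case where the same sequence supplies queries, keys, and values:

```python
import torch

# Reuses the illustrative MultiHeadAttention class sketched above
X = torch.randn(2, 10, 64)               # (batch, num_tokens, d_model)
self_attn = MultiHeadAttention(d_model=64, num_heads=8)
Y = self_attn(X, X, X)                   # queries = keys = values = X
print(Y.shape)                           # torch.Size([2, 10, 64]) -- same length as the input
```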
CNNs/RNNs vs. Self-Attention

- Let’s compare architectures for mapping a sequence of $n$ tokens to another sequence of equal length, where each input or output token is represented by a $d$-dimensional vector.
- CNN (consider a convolutional layer whose kernel size is $k$):
- Computational complexity of the convolutional layer is $\mathcal{O}(knd^2)$.
- Since CNNs are hierarchical, there are $\mathcal{O}(1)$ sequential operations and the maximum path length is $\mathcal{O}(n/k)$.
- RNN:
- When updating the hidden state, multiplication of the $d \times d$ weight matrix and the $d$-dimensional hidden state has a computational complexity of $\mathcal{O}(d^2)$. Since the sequence length is $n$, the computational complexity of the recurrent layer is $\mathcal{O}(nd^2)$.
- According to the figure above, there are $\mathcal{O}(n)$ sequential operations that cannot be parallelized and the maximum path length is also $\mathcal{O}(n)$.
- Self-attention:
- Queries, keys, and values are all $n \times d$ matrices. Consider the scaled dot-product attention, where an $n \times d$ matrix is multiplied by a $d \times n$ matrix, then the output $n \times n$ matrix is multiplied by an $n \times d$ matrix. As a result, self-attention has an $\mathcal{O}(n^2 d)$ computational complexity.
- As we can see in the figure, each token is directly connected to any other token via self-attention. Therefore, computation can be parallel with $\mathcal{O}(1)$ sequential operations and the maximum path length is also $\mathcal{O}(1)$.
Self-attention enjoys parallel computation and has the shortest maximum path length.
- However, the quadratic computational complexity with respect to the sequence length $n$ makes self-attention prohibitively slow for very long sequences.
Positional Encoding
- Self-attention ditches sequential operations in favor of parallel computation. ⇒ To use the sequence order information, we can inject absolute or relative positional information.⇒ Add positional encoding to input representations where these encodings can be learned or fixed.
- Fixed positional encoding can be based on sine and cosine functions [Vaswani et al., 2017].
⇒ Therefore it is called sinusoidal positional encoding.
⇒ Suppose the input representation $\mathbf{X} \in \mathbb{R}^{n \times d}$ contains the $d$-dimensional embeddings for the $n$ tokens of a sequence.⇒ The positional encoding outputs $\mathbf{X} + \mathbf{P}$ using a positional embedding matrix $\mathbf{P} \in \mathbb{R}^{n \times d}$ of the same shape, whose element on the $i$-th row and the $(2j)$-th or the $(2j+1)$-th column is
$$p_{i,2j} = \sin\!\left(\frac{i}{10000^{2j/d}}\right), \qquad p_{i,2j+1} = \cos\!\left(\frac{i}{10000^{2j/d}}\right).$$
- In the positional embedding matrix $\mathbf{P}$, rows correspond to positions within a sequence and columns represent different positional encoding dimensions.
- E.g.: from the graph below, we can see that the 6th and the 7th columns of the positional embedding matrix have a higher frequency than the 8th and the 9th columns. The offset between the 6th and the 7th (same for the 8th and the 9th) columns is due to the alternation of sine and cosine functions.
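A minimal sketch of this sinusoidal encoding (the function name and the even d_model assumption are illustrative):

```python
import torch

def sinusoidal_positional_encoding(num_positions, d_model):
    """P[i, 2j] = sin(i / 10000^(2j/d)), P[i, 2j+1] = cos(i / 10000^(2j/d)).

    Assumes d_model is even.
    """
    position = torch.arange(num_positions, dtype=torch.float32).unsqueeze(1)              # (n, 1)
    div = torch.pow(10000.0, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)  # 10000^(2j/d)
    P = torch.zeros(num_positions, d_model)
    P[:, 0::2] = torch.sin(position / div)
    P[:, 1::2] = torch.cos(position / div)
    return P

X = torch.randn(1, 60, 32)                        # token embeddings
X = X + sinusoidal_positional_encoding(60, 32)    # inject absolute position info
```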
Absolute Positional Information
- Let’s see how the monotonically decreasing frequency along the encoding dimension relates to absolute positional information.
- Since the outputs are float numbers, such continuous representations are more space-efficient than binary representations.
⇒ Let’s print out the binary representations of $0, 1, \ldots, 7$:⇒ The lowest bit, 2nd-lowest bit, and 3rd-lowest bit alternate on every number, every 2 numbers, and every 4 numbers, respectively.⇒ In binary representations, a higher bit has a lower frequency than a lower bit.⇒ Similarly, the positional encoding decreases frequencies along the encoding dimension by using trigonometric functions:
Relative Positional Information
- Besides capturing absolute positional information, the above positional encoding also allows a model to easily learn to attend by relative positions.
- This is because for any fixed position offset $\delta$, the positional encoding at position $i + \delta$ can be represented by a linear projection of that at position $i$:
$$\begin{bmatrix} \cos(\delta\omega_j) & \sin(\delta\omega_j) \\ -\sin(\delta\omega_j) & \cos(\delta\omega_j) \end{bmatrix} \begin{bmatrix} p_{i,2j} \\ p_{i,2j+1} \end{bmatrix} = \begin{bmatrix} p_{i+\delta,2j} \\ p_{i+\delta,2j+1} \end{bmatrix},$$
where $\omega_j = 1/10000^{2j/d}$ and the $2 \times 2$ projection matrix does not depend on any position index $i$.⇒ Any pair of $(p_{i,2j}, p_{i,2j+1})$ can be linearly projected to $(p_{i+\delta,2j}, p_{i+\delta,2j+1})$ for any fixed offset $\delta$.
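This follows from the angle-addition identities; a quick check using the encoding defined above, with $\omega_j = 1/10000^{2j/d}$:
$$\begin{aligned}
\cos(\delta\omega_j)\,p_{i,2j} + \sin(\delta\omega_j)\,p_{i,2j+1}
  &= \cos(\delta\omega_j)\sin(i\omega_j) + \sin(\delta\omega_j)\cos(i\omega_j)
   = \sin\big((i+\delta)\omega_j\big) = p_{i+\delta,\,2j},\\
-\sin(\delta\omega_j)\,p_{i,2j} + \cos(\delta\omega_j)\,p_{i,2j+1}
  &= -\sin(\delta\omega_j)\sin(i\omega_j) + \cos(\delta\omega_j)\cos(i\omega_j)
   = \cos\big((i+\delta)\omega_j\big) = p_{i+\delta,\,2j+1}.
\end{aligned}$$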
“Vanilla” Transformer
- We have compared CNNs, RNNs, and self-attention.⇒ Self-attention enjoys both parallel computation and the shortest maximum path length.⇒ It’s appealing to design deep architectures by using self-attention.⇒ Transformer model is solely based on attention mechanisms.
Model
- Transformer is composed of an encoder and a decoder:
- Input (source) and output (target) sequence embeddings are added with sinusoidal positional encoding before being fed into the encoder and the decoder that stack modules based on self-attention.
- Encoder is a stack of multiple identical layers, where each layer has 2 sublayers.
- The first is a multi-head self-attention pooling.
- The second is a position-wise feed-forward network.
- By position-wise (or point-wise), it means that it applies the same linear transformation (with the same weights) to each element in the sequence.
- This can also be viewed as a convolutional layer with filter size of 1.
- Inspired by the ResNet design, a residual connection is employed around both sublayers.
- The addition from the residual connection is followed by layer normalization [Ba et al., 2016].
⇒ The encoder outputs a $d$-dimensional vector representation for each position of the input sequence.⇒ Generates an attention-based representation with the capability to locate a specific piece of information from a large context.
- The decoder is similar to the encoder, except that the decoder contains two multi-head attention sublayers instead of one in each of its identical repeating layers.
- The first multi-head attention sublayer is masked to prevent positions from attending to the future. Thus, each position in the decoder is allowed to attend only to positions in the decoder up to and including that position.
- This masked attention preserves the auto-regressive property, ensuring that the prediction only depends on those output tokens that have been generated.
- The second multi-head attention sublayer is called encoder-decoder attention, where queries are from outputs of previous decoder layer, and keys and values are from transformer encoder outputs.
⇒ Retrieves information from the encoded representation.
The architecture of the vanilla Transformer model.
Positionwise Feed-Forward Network
- Position-wise FFNN transforms representation at all the sequence positions using the same MLP.
- Since the same MLP transforms at all the positions, when the inputs at all these positions are the same, their outputs are also identical.⇒ This is why it is called position-wise.
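A minimal sketch illustrating the position-wise property (class and dimension names are illustrative):

```python
import torch
from torch import nn

class PositionWiseFFN(nn.Module):
    """The same two-layer MLP applied independently at every sequence position."""

    def __init__(self, d_model, d_hidden):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_model),
        )

    def forward(self, x):
        # x: (batch, seq_len, d_model); nn.Linear acts on the last dim,
        # so every position is transformed by the same weights.
        return self.net(x)

ffn = PositionWiseFFN(d_model=8, d_hidden=16)
x = torch.ones(1, 3, 8)                       # identical inputs at all 3 positions...
print(ffn(x)[0, 0].allclose(ffn(x)[0, 1]))    # ...give identical outputs: True
```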
Transformers for Vision
- Transformer architecture was initially proposed for Seq2Seq learning, with a focus on machine translation. ⇒ Later on, it emerged as the model of choice in various natural language processing tasks.⇒ In the field of computer vision the dominant architecture has remained the CNN.⇒ Researchers started to wonder if it’s possible to do better by adapting Transformer to image data.⇒ Transformers have also become a game-changer in computer vision.
- Vision Transformers (ViTs) extract patches from images and feed them into a Transformer encoder to obtain a global representation, which will finally be transformed into the output label:
- Consider an input image with height $h$, width $w$, $c$ channels, and patch height and width both equal to $p$. ⇒ The image is split into a sequence of $m = hw/p^2$ patches (each patch is flattened to a vector of length $cp^2$).⇒ Image patches can be treated similarly to tokens in text sequences by Transformer encoders.
The vision Transformer architecture. In this example, an image is split into nine patches. A special “<cls>” token and the nine flattened image patches are transformed via patch embedding and $n$ Transformer encoder blocks into ten representations, respectively. The “<cls>” representation is further transformed into the output label.
Patch Embedding
- Splitting an image into patches and linearly projecting these flattened patches can be simplified as a single convolution operation, where both the kernel size and the stride size are set to the patch size:
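A hedged sketch of such a patch embedding as a strided convolution (the default sizes follow the common ViT-Base setup and are assumptions here):

```python
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """Split an image into patches and linearly project them, implemented as a
    single Conv2d whose kernel size and stride both equal the patch size."""

    def __init__(self, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # x: (batch, channels, height, width)
        x = self.proj(x)                     # (batch, embed_dim, h/p, w/p)
        return x.flatten(2).transpose(1, 2)  # (batch, num_patches, embed_dim)

patches = PatchEmbedding()(torch.randn(1, 3, 224, 224))
print(patches.shape)   # torch.Size([1, 196, 768])
```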
ViT’s MLP
- The MLP of the vision Transformer encoder is slightly different from the position-wise FFN of the original Transformer encoder.
- The activation function uses GELU, which is a smoother version of ReLU.
- Dropout is applied to the output of each fully connected layer in the MLP for regularization.
ViT’s “Add & Norm” Layer
- Normalization is applied right before the multi-head attention or MLP (not after as in vanilla Transformer).
Summary
- The input image is fed into `PatchEmbedding`, whose output is concatenated with the `<cls>` token embedding.⇒ It is summed with learnable positional embeddings before dropout.⇒ Then the output is fed into the Transformer encoder that stacks `num_blks` instances of the `ViTBlock` class.⇒ Finally, the representation of the `<cls>` token is projected by the network head.
- For smaller datasets like ImageNet (1.2M images), the vision Transformer does not outperform ResNet.
- Transformers lack useful principles in convolution, such as translation invariance and locality.
⇒ However, the picture changes when training larger models on larger datasets (e.g., 300M images).⇒ Vision Transformers outperform ResNets by a large margin in image classification.⇒ Superiority of Transformers in scalability.
- It was shown to be effective with data-efficient training strategies of DeiT (Touvron et al., 2021).
- Quadratic complexity of self-attention makes the architecture less suitable for higher-resolution images. ⇒ Swin Transformers addressed the quadratic computational complexity with respect to image size.
Transformer Optimizations
- As we have seen, the Transformer architecture is heavily bottlenecked by the self-attention mechanism, which has quadratic time and memory complexity in the sequence length.⇒ Various works try to improve the Transformer architecture and the attention operation.
Improved Attention Span
- Goal: make the context that can be used in self-attention longer, more efficient, and more flexible.
Longer Attention Span: Transformer-XL
“XL” means “extra long”
- The vanilla Transformer has a fixed and limited attention span.⇒ The model can only attend to other elements in the same segment during each update step and no information can flow across separated fixed-length segments.⇒ This context segmentation causes several issues:
- The model cannot capture very long-term dependencies.
- It is hard to predict the first few tokens in each segment given no or thin context.
- The evaluation is expensive: whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapping tokens.
⇒ Transformer-XL solves the context segmentation problem with two main modifications:
- Reusing hidden states between segments.
- Adopting a new positional encoding that is suitable for reused states.
Hidden State Reuse
- The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments:
- Let's label the hidden state of the $n$-th layer for the $(\tau+1)$-th segment in the model as $\mathbf{h}_{\tau+1}^{(n)} \in \mathbb{R}^{L \times d}$.⇒ In addition to the hidden state of the last layer for the same segment, $\mathbf{h}_{\tau+1}^{(n-1)}$, it also depends on the hidden state of the same layer for the previous segment, $\mathbf{h}_{\tau}^{(n)}$.⇒ By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments:
$$\begin{aligned}
\tilde{\mathbf{h}}_{\tau+1}^{(n-1)} &= \left[\text{stop-gradient}\big(\mathbf{h}_{\tau}^{(n-1)}\big) \circ \mathbf{h}_{\tau+1}^{(n-1)}\right] \\
\mathbf{Q}_{\tau+1}^{(n)} &= \mathbf{h}_{\tau+1}^{(n-1)} \mathbf{W}^q \\
\mathbf{K}_{\tau+1}^{(n)} &= \tilde{\mathbf{h}}_{\tau+1}^{(n-1)} \mathbf{W}^k \\
\mathbf{V}_{\tau+1}^{(n)} &= \tilde{\mathbf{h}}_{\tau+1}^{(n-1)} \mathbf{W}^v \\
\mathbf{h}_{\tau+1}^{(n)} &= \text{transformer-layer}\big(\mathbf{Q}_{\tau+1}^{(n)}, \mathbf{K}_{\tau+1}^{(n)}, \mathbf{V}_{\tau+1}^{(n)}\big)
\end{aligned}$$
- Note the difference between $\mathbf{h}_{\tau+1}^{(n-1)}$ and the extended hidden state $\tilde{\mathbf{h}}_{\tau+1}^{(n-1)}$.
- NB: both key and value rely on the extended hidden state, while the query only consumes the hidden state at the current step.
- The concatenation operation $[\cdot \circ \cdot]$ is along the sequence length dimension.
A comparison between the training phase of the vanilla Transformer & Transformer-XL with a segment length of 4. (Image source: left part of Figure 2 in Dai et al., 2019)
Relative Positional Encoding
- In order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding based on reparametrization of dot-product of keys and queries.
- Q: Why?
A: If using the same approach as the vanilla Transformer and encoding the absolute position, the previous and current segments would be assigned the same encoding, which is undesired.
⇒ To keep the positional information flowing coherently across segments, Transformer-XL encodes the relative position instead, as it can be sufficient to know the position offset $i - j$ between one key vector $\mathbf{k}_{\tau,j}$ and its query $\mathbf{q}_{\tau,i}$ for making good predictions.
⇒ If omitting the scalar $1/\sqrt{d_k}$ and the normalizing term in softmax but including positional encodings, we can write the attention score between query $\mathbf{q}_i$ at position $i$ and key $\mathbf{k}_j$ at position $j$ as:
$$a_{ij} = \mathbf{q}_i \mathbf{k}_j^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q \big((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k\big)^\top = \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top$$
⇒ Transformer-XL reparameterizes the above four terms as follows:
$$a^{\text{rel}}_{ij} = \underbrace{\mathbf{x}_i\mathbf{W}^q {\mathbf{W}_E^k}^\top\mathbf{x}_j^\top}_{\text{content-based addressing}} + \underbrace{\mathbf{x}_i\mathbf{W}^q {\mathbf{W}_R^k}^\top\mathbf{r}_{i-j}^\top}_{\text{content-dependent positional bias}} + \underbrace{\mathbf{u}\,{\mathbf{W}_E^k}^\top\mathbf{x}_j^\top}_{\text{global content bias}} + \underbrace{\mathbf{v}\,{\mathbf{W}_R^k}^\top\mathbf{r}_{i-j}^\top}_{\text{global positional bias}}$$
- Replaces $\mathbf{p}_j$ with relative positional encoding $\mathbf{r}_{i-j} \in \mathbb{R}^d$.
- Replaces $\mathbf{p}_i\mathbf{W}^q$ with 2 trainable parameters $\mathbf{u}$ (for content) and $\mathbf{v}$ (for location) in 2 different terms.
- Splits $\mathbf{W}^k$ into two matrices, $\mathbf{W}^k_E$ for content information and $\mathbf{W}^k_R$ for location information.
Adaptive Attention Span
- One key advantage of the Transformer is the capability of capturing long-term dependencies. ⇒ Depending on the context, the model may prefer to attend further back at some times than at others; or one attention head may have a different attention pattern from the others.⇒ If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support a longer maximum context size in the model.⇒ This is the motivation for Adaptive Attention Span: a self-attention mechanism that seeks an optimal attention span.
- The authors hypothesized that different attention heads might assign scores differently within the same context window:
Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019) ⇒ The optimal span should be trained separately per head.
- Given the $i$-th token, we need to compute the attention weights between this token and other keys at positions $j \in S_i$, where $S_i$ defines the $i$-th token's context window:
$$e_{ij} = \mathbf{q}_i \mathbf{k}_j^\top, \qquad a_{ij} = \mathrm{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r \in S_i} \exp(e_{ir})}, \qquad \mathbf{y}_i = \sum_{r \in S_i} a_{ir}\, \mathbf{v}_r.$$
- A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a $[0, 1]$ value. $m_z$ is parameterized by $z$ and $z$ is to be learned:
$$m_z(x) = \mathrm{clip}\!\left(\frac{1}{R}(R + z - x),\ 0,\ 1\right),$$
where $R$ is a hyper-parameter which defines the softness of $m_z$.
The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.)
⇒ The soft mask function is applied to the softmax elements in the attention weights:
$$a_{ij} = \frac{m_z(i - j)\exp(e_{ij})}{\sum_{r \in S_i} m_z(i - r)\exp(e_{ir})}.$$
- $m_z$ is differentiable so it is trained jointly with other parts of the model (a small sketch of $m_z$ appears at the end of this subsection).
- Parameters $z^{(i)}$, $i = 1, \ldots, h$, are learned separately per head.
- Moreover, the loss function has an extra L1 penalty on $\sum_{i=1}^h z^{(i)}$.
- In the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans.
- Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.
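For illustration, a tiny sketch of the soft mask $m_z$ described above (the value of R here is an arbitrary choice, not from the paper):

```python
import torch

def soft_mask(distances, z, R=32.0):
    """m_z(x) = clip((R + z - x) / R, 0, 1): equal to 1 for recent positions and
    decaying linearly to 0 for positions further back than the learned span z."""
    return torch.clamp((R + z - distances) / R, min=0.0, max=1.0)

distances = torch.arange(0, 200, dtype=torch.float32)   # query-key distance i - j
print(soft_mask(distances, z=torch.tensor(100.0))[::50]) # 1.0 ... fading to 0.0
```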
Sparse Attention Matrix Factorization: Sparse Transformers
- Sparse Transformer introduced factorized self-attention, through sparse matrix factorization.⇒ Makes it possible to train dense attention networks with hundreds of layers on sequences of length up to 16,384, which would be infeasible on modern hardware otherwise.
- Sparse Transformer introduced a set of attention connectivity patterns $\mathcal{S} = \{S_1, \ldots, S_n\}$, where each $S_i$ records a set of key positions that the $i$-th query vector attends to:
$$\operatorname{Attend}(\mathbf{X}, \mathcal{S}) = \Big(a(\mathbf{x}_i, S_i)\Big)_{i \in \{1, \ldots, L\}},$$
where
$$a(\mathbf{x}_i, S_i) = \mathrm{softmax}\!\left(\frac{(\mathbf{x}_i \mathbf{W}^q)\big(\mathbf{x}_j \mathbf{W}^k\big)^\top_{j \in S_i}}{\sqrt{d_k}}\right)\big(\mathbf{x}_j \mathbf{W}^v\big)_{j \in S_i}.$$
- NB: the size of $S_i$ is not fixed, while $a(\mathbf{x}_i, S_i)$ is always of size $d_v$, and thus $\operatorname{Attend}(\mathbf{X}, \mathcal{S}) \in \mathbb{R}^{L \times d_v}$.
- In auto-regressive models, one attention span is defined as $S_i = \{j : j < i\}$, as it allows each token to attend to all the positions in the past. In factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair of $(i, j)$ where $j \le i$, there is a path connecting $i$ back to $j$ and $i$ can attend to $j$ either directly or indirectly.⇒ Precisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A_i^{(m)} \subset S_i$, $m = 1, \ldots, p$.
- Sparse Transformer proposed two types of factorized attention:
- Strided attention with stride $\ell \sim \sqrt{n}$:
$$A_i^{(1)} = \{t, t+1, \ldots, i\} \text{ with } t = \max(0,\, i - \ell), \qquad A_i^{(2)} = \{j : (i - j) \bmod \ell = 0\}.$$
- Works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous $\ell$ pixels in the raster scanning order (which naturally cover the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).
- Fixed attention:
$$A_i^{(1)} = \{j : \lfloor j/\ell \rfloor = \lfloor i/\ell \rfloor\}, \qquad A_i^{(2)} = \{j : j \bmod \ell \in \{\ell - c, \ldots, \ell - 1\}\}.$$
- A small set of tokens summarize previous locations and propagate that information to all future locations.
- Here $c$ is a hyperparameter (if $c = 1$, it restricts the representation whereas many depend on a few positions; the paper chose $c \in \{8, 16, 32\}$ for $\ell \in \{128, 256\}$).
The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few extra annotations.)
- There are three ways to use sparse factorized attention patterns in Transformer architecture:
- One attention type per residual block, and then interleave them: $\operatorname{attention}(\mathbf{X}) = \operatorname{Attend}(\mathbf{X}, A^{(n \bmod p)})\, \mathbf{W}^o$, where $n$ is the index of the current residual block.
- Set up a single head which attends to locations that all the factorized heads attend to: $\operatorname{attention}(\mathbf{X}) = \operatorname{Attend}(\mathbf{X}, \cup_{m=1}^p A^{(m)})\, \mathbf{W}^o$.
- Use a multi-head attention mechanism, but different from the vanilla Transformer, each head might adopt a pattern presented above, 1 or 2.
Locality-Sensitive Hashing: Reformer
- Reformer proposed two main changes:
- Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $\mathcal{O}(L^2)$ to $\mathcal{O}(L \log L)$.
- Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of $N$ times (i.e., proportional to the number of layers).
The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020).
KV Caching
- Applicable only to auto-regressive inference, not training.
- In self-attention, for each token in the input sequence, the model computes key and value vectors.⇒ During auto-regressive generation, this computation would be repeated unnecessarily for previously seen tokens.⇒ KV caching solves this by storing the computed vectors for past tokens and reusing them in future steps, avoiding redundant computations:
- Keys (K) and values (V) for past tokens are computed once and cached.
- Only the new token's K and V are computed and appended to the cache.
- Attention operates over the cached K and V, avoiding recomputation.
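A minimal single-head sketch of a KV cache during decoding (function and variable names are illustrative, not a real library API):

```python
import torch

def generate_step(new_token_hidden, W_q, W_k, W_v, kv_cache):
    """One auto-regressive decoding step with a KV cache.

    new_token_hidden: (batch, 1, d) hidden state of the newly generated token.
    kv_cache: dict holding the K and V of all previously seen tokens.
    """
    # Compute K and V only for the new token and append them to the cache.
    k_new, v_new = new_token_hidden @ W_k, new_token_hidden @ W_v
    kv_cache["K"] = torch.cat([kv_cache["K"], k_new], dim=1) if "K" in kv_cache else k_new
    kv_cache["V"] = torch.cat([kv_cache["V"], v_new], dim=1) if "V" in kv_cache else v_new

    # Attention for the new query runs over the cached keys/values (no recomputation).
    q = new_token_hidden @ W_q
    scores = q @ kv_cache["K"].transpose(-2, -1) / kv_cache["K"].shape[-1] ** 0.5
    weights = torch.softmax(scores, dim=-1)
    return weights @ kv_cache["V"]

d = 16
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
cache = {}
for _ in range(5):  # five decoding steps, each reusing the cached K/V
    out = generate_step(torch.randn(1, 1, d), W_q, W_k, W_v, cache)
```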
Flash Attention
- Attention algorithm used to scale transformer models more efficiently, enabling faster training and inference.
- Proposed to optimize the attention computations by writing custom CUDA kernels.⇒ Makes them much faster and more memory efficient.
- Standard attention mechanism relies on High Bandwidth Memory (HBM) to write/read keys/queries/values.
- NB: global memory of the GPU is confusingly called the "High Bandwidth Memory" here…
- HBM offers large capacity and high bandwidth but has higher latency compared to on-chip memory like SRAM / Shared Memory.
- Model loads this data from HBM into GPU on-chip Shared Memory.
- Performs a step of the attention mechanism.
- Writes the result back to HBM.
- Repeats this process for each attention step.
⇒ In the standard implementation, the cost of writing/reading keys/queries/values from HBM is high:
The basic implementation of the attention mechanism involves a lot of transfer between memory and workers. It requires materializing the S and P matrices in HBM, which means that the results need to be sent to HBM and then back to SRAM for the next computations.
⇒ Flash Attention makes efficient use of the various GPU memories to avoid relying too much on the slowest one (global memory).
- FlashAttention optimizes this workflow by loading keys, queries, and values into Shared Memory once, fusing the operations of the attention mechanism (such as softmax and matrix multiplication), and writing the result back to HBM:
- The key element is to compute the S matrix in small pieces which can fit in the smaller shared memory of the SM.→ But we can do even better and avoid materializing the very large S matrix altogether, in favor of keeping only the necessary statistics for computing the normalization factor of the softmax.→ We can compute part of O directly in one computation in SRAM rather than moving intermediate results back and forth.→ Not only do we make use of the shared memory, but we also release the memory bottleneck resulting from materializing one of the largest activation matrices in the model (at long context length): the attention matrix.
- By avoiding materializing the S matrix, we reduce the memory burden of attention.
- We also remove a large part of the naive impact of the $O(S^2)$ memory cost of attention.
Left: FlashAttention uses tiling to prevent materialization of the large attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the K and V matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of the Q matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. Right: Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large attention matrix to HBM, resulting in a large speedup on the attention computation.
⇒ Significantly reduces memory access overhead and improves performance.
⇒ All variants of linear attention and sub-quadratic approaches to approximate attention, developed shortly after the invention of the Transformer architecture, have been mostly put aside in favor of this exact and fast Flash Attention implementation and mechanism.
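For intuition only, here is a toy NumPy sketch of the online-softmax/tiling idea behind FlashAttention for a single query; the real implementation is a fused CUDA kernel, and all names here are illustrative:

```python
import numpy as np

def streaming_attention(q, K, V, block=64):
    """Numerically-stable attention for one query, computed block by block without
    ever materializing the full score vector (the tiling idea behind FlashAttention,
    shown here in plain NumPy rather than as a fused GPU kernel)."""
    d = q.shape[-1]
    m = -np.inf          # running max of the scores (for a stable softmax)
    l = 0.0              # running softmax normalizer
    o = np.zeros(d)      # running un-normalized output
    for start in range(0, K.shape[0], block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        s = k_blk @ q / np.sqrt(d)          # scores for this block only
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)           # rescale previously accumulated statistics
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        o = o * scale + p @ v_blk
        m = m_new
    return o / l

q = np.random.randn(32); K = np.random.randn(1000, 32); V = np.random.randn(1000, 32)
reference = (lambda s: np.exp(s - s.max()) / np.exp(s - s.max()).sum())(K @ q / np.sqrt(32)) @ V
assert np.allclose(streaming_attention(q, K, V), reference)  # matches full attention
```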
- Following Flash-attention 1, two successive improved versions have been released by the same lab: Flash-attention 2 and 3.
- In comparison to Flash-attention 1, the improvements in Flash-attention 2 and 3 are less about the general attention mechanism than about tailoring its low-level implementation more specifically to the GPU: (1) reducing the number of non-matmul operations as much as possible, (2) carefully partitioning the workload among warps and thread blocks (for Flash Attention 2), and (3) carefully optimizing for FP8 and Tensor Core support on the latest Hopper (H100) architecture (for Flash Attention 3).