Introduction
What’s Wrong with Seq2Seq Models?
- The Seq2seq model originates from language modeling.
- Normally has an encoder-decoder architecture:
- Encoder processes the input sequence and compresses it into a fixed-length context vector, where this representation is expected to be a good summary of the meaning of the whole source sequence.
- Decoder is initialized with the context vector to emit the transformed output.
- Both encoder and decoder are typically RNNs (LSTMs/GRUs).
- Transforms an input sequence (source) into a new one (target), where both sequences can be of arbitrary lengths.
- Limitation: incapable of remembering long sentences because of the fixed-length context vector.
- The context vector struggles to retain information from long sentences, often causing early parts of the sequence to be forgotten by the time processing completes.
⇒ The attention mechanism was born to resolve this problem.
Attention Cues in Biology
- When inspecting a visual scene, our optic nerve receives information on the order of $10^8$ bits per second, far exceeding what our brain can fully process. Fortunately, our ancestors had learned from experience (or, data in our case) that not all sensory inputs are created equal.
⇒ Throughout human history, the capability of directing attention to only a fraction of the information of interest has enabled our brain to allocate resources more smartly to survive, to grow, and to socialize, such as detecting predators, prey, and mates.
- To explain how our attention is deployed in the visual world, a two-component framework has emerged. The idea dates back to William James, who is considered the “father of American psychology” [James, 2007].
- Nonvolitional cue
- Volitional cue
- While all the paper products are printed in black and white, the coffee cup is red.
⇒ The coffee cup is intrinsically salient and conspicuous in this visual environment.
⇒ It automatically and involuntarily draws attention.
⇒ You bring the fovea (the center of the macula where visual acuity is highest) onto the coffee.
⇒ In this framework, subjects selectively direct the spotlight of attention using both kinds of cues.
⇒ The nonvolitional cue is based on the saliency and conspicuity of objects in the environment: using the nonvolitional cue based on saliency (red cup, non-paper), attention is involuntarily directed to the coffee.
⇒ After drinking coffee, you become caffeinated and want to read a book. You turn your head, refocus your eyes, and look at the book. In this task-dependent case you select the book under cognitive and volitional control: using the volitional cue (wanting to read a book), attention is directed to the book under volitional control.
⇒ Using the volitional cue based on variable selection criteria, this form of attention is more deliberate. It is also more powerful with the subject's voluntary effort.
Queries, Keys, and Values
- Attention mechanisms are inspired by the nonvolitional and volitional attention cues that explain attentional deployment.
- Consider the case where only nonvolitional cues are available. To bias selection over sensory inputs, we can simply use:
- Parameterized fully-connected layer.
- Non-parameterized max or average pooling.
- In the context of attention mechanisms, we refer to volitional cues as queries.
- Given any query, attention mechanisms bias selection over sensory inputs via attention pooling.
- These sensory inputs are called values in the context of attention mechanisms.
- Every value is paired with a key, which can be thought of as the nonvolitional cue of that sensory input.
⇒ What sets attention mechanisms apart from FC or pooling layers is the inclusion of the volitional cues.
⇒ Design attention pooling such that the given query (volitional cue) can interact with keys (nonvolitional cues), which guides bias selection over values (sensory inputs):
Nadaraya-Watson Kernel Regression
- Nadaraya-Watson kernel regression is a simple example of a machine learning algorithm with an attention mechanism.
- We generate a dataset $\{(x_i, y_i)\}_{i=1}^n$ according to the following non-linear function with noise:
$$y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon,$$
where $\epsilon$ is Gaussian noise with zero mean and standard deviation 0.5. Given this dataset, how do we learn $f$ to predict the output $\hat{y} = f(x)$ for any new input $x$?
Average Pooling
- Begin with the “dumbest” estimator for this regression problem.
⇒ Use average pooling to average over all the training outputs:
$$f(x) = \frac{1}{n}\sum_{i=1}^n y_i.$$
⇒ As we can see, this estimator is not so smart: average pooling omits the inputs $x_i$ entirely.
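A minimal PyTorch sketch of the toy dataset and the average-pooling baseline (the original code block did not load, so tensor names and the training-set size here are illustrative):

```python
import torch

# Toy training data: y = 2*sin(x) + x^0.8 + Gaussian noise (the formula above)
n_train = 50
x_train, _ = torch.sort(torch.rand(n_train) * 5)          # sorted inputs in [0, 5]
y_train = 2 * torch.sin(x_train) + x_train**0.8 + torch.normal(0.0, 0.5, (n_train,))

x_test = torch.arange(0, 5, 0.1)                           # queries

# Average pooling: every prediction is just the mean of the training outputs,
# so the inputs x_i are ignored entirely.
y_hat = y_train.mean().repeat(len(x_test))
```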
Nonparametric Attention Pooling
- A better idea was proposed by Nadaraya and Watson: weight the outputs $y_i$ according to their input locations:
$$f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)}\, y_i,$$
where $K$ is a kernel.
⇒ We can rewrite it in a more generalized form of attention pooling:
$$f(x) = \sum_{i=1}^n \alpha(x, x_i)\, y_i,$$
where $x$ is the query, $(x_i, y_i)$ is a key-value pair, and $\alpha(x, x_i)$ is an attention weight.
  - The attention weight $\alpha(x, x_i)$ is assigned to the corresponding value $y_i$.
- Consider a Gaussian kernel defined as
$$K(u) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{u^2}{2}\right).$$
⇒ Plugging the Gaussian kernel into the last two equations gives us
$$f(x) = \sum_{i=1}^n \operatorname{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i.$$
- A key $x_i$ that is closer to the given query $x$ will get more attention via a larger attention weight assigned to the key's corresponding value $y_i$.
- Now let us take a look at the attention weights. Here testing inputs are queries while training inputs are keys. Since both inputs are sorted, we can see that the closer the query-key pair is, the higher the attention weight is in the attention pooling.
⇒ The predicted line is smooth and closer to the ground truth than that produced by average pooling (a code sketch follows below).
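A sketch of nonparametric attention pooling with the Gaussian kernel, reusing `x_train`, `y_train`, and `x_test` from the previous snippet:

```python
# Nonparametric Nadaraya-Watson attention pooling with a Gaussian kernel.
# x_repeat: (n_test, n_train) -- each test query repeated for every training key.
x_repeat = x_test.repeat_interleave(n_train).reshape(-1, n_train)
# attention_weights[i, j] = softmax_j( -0.5 * (x_test[i] - x_train[j])^2 )
attention_weights = torch.softmax(-0.5 * (x_repeat - x_train)**2, dim=1)
# Each prediction is a weighted average of the training outputs (values).
y_hat = torch.matmul(attention_weights, y_train)
```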
Parametric Attention Pooling
- We can easily integrate learnable parameters into attention pooling.
⇒ The distance between the query $x$ and the key $x_i$ can be multiplied by a learnable parameter $w$:
$$f(x) = \sum_{i=1}^n \operatorname{softmax}\left(-\frac{1}{2}\big((x - x_i)w\big)^2\right) y_i.$$
⇒ After training the parametric attention model, we can plot its prediction (a code sketch follows below):
- Compared with nonparametric attention pooling, the region with large attention weights becomes sharper in the learnable, parametric setting.
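A sketch of the parametric version with a learnable bandwidth $w$ (the class layout and the keys/values shapes are assumptions, not the original code that failed to load):

```python
import torch
from torch import nn

class NWKernelRegression(nn.Module):
    """Nadaraya-Watson kernel regression with a learnable bandwidth parameter w."""
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.rand(1))

    def forward(self, queries, keys, values):
        # queries: (n_queries,); keys/values: (n_queries, n_kv), one row of
        # training keys/values per query.
        queries = queries.unsqueeze(-1).expand(-1, keys.shape[1])
        attention_weights = torch.softmax(-0.5 * ((queries - keys) * self.w)**2, dim=1)
        # Weighted average of the values per query.
        return torch.bmm(attention_weights.unsqueeze(1), values.unsqueeze(-1)).reshape(-1)
```

The parameter `w` would then be trained with, e.g., a squared loss and SGD on the training pairs.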
Attention Pooling and Attention Scoring Functions
- In the previous section, we obtained a probability distribution over values that are paired with keys: the output of the attention pooling is simply a weighted sum of the values based on the attention weights.
- At a high level, we can use the above algorithm to instantiate the framework of attention mechanisms.
- NB: attention weights are a probability distribution, weighted sum is a weighted average.
- More formally, suppose that we have:
- A query $\mathbf{q} \in \mathbb{R}^q$.
- Key-value pairs $(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)$, where any $\mathbf{k}_i \in \mathbb{R}^k$ and any $\mathbf{v}_i \in \mathbb{R}^v$.
⇒ Attention pooling ($f$ before) is instantiated as a weighted sum of the values:
$$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i)\, \mathbf{v}_i \in \mathbb{R}^v,$$
where the attention weight (scalar) for the query $\mathbf{q}$ and key $\mathbf{k}_i$ is computed by the softmax operation of an attention scoring function $a$ that maps two vectors to a scalar:
$$\alpha(\mathbf{q}, \mathbf{k}_i) = \operatorname{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}.$$
⇒ Denoting an attention scoring function by $a$, the following figure illustrates how the output of attention pooling can be computed as a weighted sum of values: Computing the output of attention pooling as a weighted average of values.
- Different choices of the attention scoring function $a$ lead to different behaviors of attention pooling.
Additive Attention
- When queries and keys are vectors of different lengths, we can use additive attention as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is
$$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R},$$
where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^h$ are learnable parameters.
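A minimal additive-attention module assuming PyTorch and omitting masked softmax for brevity (names follow the equation above; this is a sketch, not the original code that failed to load):

```python
import torch
from torch import nn

class AdditiveAttention(nn.Module):
    """Additive attention: a(q, k) = w_v^T tanh(W_q q + W_k k)."""
    def __init__(self, query_size, key_size, num_hiddens, dropout=0.1):
        super().__init__()
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        # queries: (batch, n_q, query_size); keys: (batch, n_kv, key_size)
        q, k = self.W_q(queries), self.W_k(keys)
        # Broadcast to (batch, n_q, n_kv, num_hiddens) and score each (query, key) pair.
        features = torch.tanh(q.unsqueeze(2) + k.unsqueeze(1))
        scores = self.w_v(features).squeeze(-1)            # (batch, n_q, n_kv)
        attention_weights = torch.softmax(scores, dim=-1)
        return torch.bmm(self.dropout(attention_weights), values)
```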
Scaled Dot-Product Attention
- More computationally efficient, but requires the query and the key to be vectors of the same length $d$.
- Assume that all elements of the query and the key are i.i.d. variables with zero mean and unit variance. Then the dot product $\mathbf{q}^\top \mathbf{k} = \sum_{i=1}^d q_i k_i$ has zero mean and variance $d$.
⇒ To ensure that the variance of the dot product still remains one regardless of vector length, the scaled dot-product attention scoring function is
$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}.$$
- In practice, we often think in mini-batches for efficiency.
⇒ Scaled dot-product attention of queries $\mathbf{Q} \in \mathbb{R}^{n \times d}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d}$, and values $\mathbf{V} \in \mathbb{R}^{m \times v}$ is
$$\operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v}.$$
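A minimal scaled dot-product attention module under the same assumptions (the original code block did not load):

```python
import math
import torch
from torch import nn

class DotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values):
        # queries: (batch, n_q, d); keys: (batch, n_kv, d); values: (batch, n_kv, v)
        d = queries.shape[-1]
        scores = torch.bmm(queries, keys.transpose(1, 2)) / math.sqrt(d)
        attention_weights = torch.softmax(scores, dim=-1)
        return torch.bmm(self.dropout(attention_weights), values)
```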
Masked Softmax Operation
- In some cases, not all the values should be fed into attention pooling.
- E.g.: for efficient minibatch processing, some text sequences are padded with special tokens that do not carry meaning. To get an attention pooling over only meaningful tokens as values, we can specify a valid sequence length (in number of tokens) to filter out those beyond this specified range when computing softmax. In this way, we can implement such a masked softmax operation, where any value beyond the valid length is masked as zero.
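A sketch of such a masked softmax, assuming scores of shape (batch, queries, keys) and one valid length per sequence:

```python
import torch

def masked_softmax(scores, valid_lens):
    """Softmax over the last axis, masking key positions beyond each valid length.

    scores: (batch, n_queries, n_keys); valid_lens: (batch,) number of
    meaningful keys per sequence. Masked positions get weight ~0.
    """
    batch, n_q, n_k = scores.shape
    mask = torch.arange(n_k, device=scores.device)[None, None, :] < valid_lens[:, None, None]
    scores = scores.masked_fill(~mask, -1e6)   # large negative => ~0 after softmax
    return torch.softmax(scores, dim=-1)

# Example: the second sequence only has 2 valid keys out of 4.
weights = masked_softmax(torch.rand(2, 3, 4), torch.tensor([4, 2]))
```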
Basic Attention Mechanisms
Bahdanau Attention
- As noted in the introduction, the attention mechanism was born to help memorize long source sentences in neural machine translation (NMT).
⇒ Rather than building a single context vector out of the encoder's last hidden state, the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights of these shortcut connections are customizable for each output element.
⇒ When predicting a token, not all the input tokens are relevant.
⇒ The model aligns only to parts of the input sequence that are relevant to the current prediction.
⇒ This is achieved by treating the context variable at the decoding time step $t'$ as an output of attention pooling:
$$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t,$$
where
  - $T$ is the number of tokens in the input sequence.
  - The decoder hidden state $\mathbf{s}_{t'-1}$ at time step $t'-1$ is the query.
  - The encoder hidden states $\mathbf{h}_t$ are both the keys and values.
  - The attention weight $\alpha$ is computed using the additive attention scoring function.
- Matrix of alignment scores is a nice byproduct to explicitly show the correlation between source and target words:
Alignment matrix of "L'accord sur l'Espace économique européen a été signé en août 1992" (French) and its English translation "The agreement on the European Economic Area was signed in August 1992".
Multi-Head Attention
- Given the same set of queries, keys, and values, we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence.
- Usually, understanding the role of a word in a sentence requires understanding how it is related to different parts of the sentence. For example, in some languages, subjects define verb inflection (e.g., gender agreement), verbs define the case of their objects, and many more. In other words, each word is part of many relations.
⇒ It may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.
⇒ Let the model focus on different things from different representation subspaces at different positions.
⇒ Instead of performing a single attention pooling, queries, keys, and values can be transformed with $h$ independently learned linear projections. Then these $h$ projected queries, keys, and values are fed into attention pooling in parallel. In the end, $h$ attention pooling outputs (heads) are concatenated and transformed with another learned linear projection to produce the final output.
⇒ Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as
$$\mathbf{h}_i = f(\mathbf{W}_i^{(q)} \mathbf{q}, \mathbf{W}_i^{(k)} \mathbf{k}, \mathbf{W}_i^{(v)} \mathbf{v}) \in \mathbb{R}^{p_v},$$
where
  - $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters.
  - $f$ is attention pooling (e.g. additive attention, scaled dot-product attention, etc.).
⇒ The multi-head attention output is a linear transformation via $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$ of the concatenation of the $h$ heads:
$$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}.$$
Multi-head attention, where multiple heads are concatenated then linearly transformed.
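A compact multi-head attention sketch that reuses the DotProductAttention sketch above; it assumes $p_q = p_k = p_v = $ num_hiddens / num_heads and no masking:

```python
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention (no masking)."""
    def __init__(self, num_hiddens, num_heads, dropout=0.1):
        super().__init__()
        assert num_hiddens % num_heads == 0
        self.num_heads = num_heads
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.attention = DotProductAttention(dropout)   # from the sketch above

    def split_heads(self, x):
        # (batch, seq, num_hiddens) -> (batch * heads, seq, num_hiddens / heads)
        b, s, d = x.shape
        x = x.reshape(b, s, self.num_heads, d // self.num_heads)
        return x.permute(0, 2, 1, 3).reshape(b * self.num_heads, s, -1)

    def forward(self, queries, keys, values):
        q = self.split_heads(self.W_q(queries))
        k = self.split_heads(self.W_k(keys))
        v = self.split_heads(self.W_v(values))
        out = self.attention(q, k, v)                    # (batch*heads, n_q, d/heads)
        b_h, n_q, _ = out.shape
        b = b_h // self.num_heads
        # Undo the head split: concatenate the heads, then project with W_o.
        out = out.reshape(b, self.num_heads, n_q, -1).permute(0, 2, 1, 3).reshape(b, n_q, -1)
        return self.W_o(out)
```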
Self-Attention
- We often use CNNs or RNNs to encode a sequence.
⇒ Idea: with attention mechanisms in mind, feed a sequence of tokens into attention pooling so that the same set of tokens act as queries, keys, and values.
⇒ Since queries, keys, and values come from the same place, this is called self-attention.
⇒ Given a sequence of input tokens $\mathbf{x}_1, \ldots, \mathbf{x}_n$, where any $\mathbf{x}_i \in \mathbb{R}^d$, its self-attention outputs a sequence of the same length $\mathbf{y}_1, \ldots, \mathbf{y}_n$, where
$$\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d,$$
with $f$ the attention pooling defined above.
  - E.g. $f$ can be multi-head attention.
CNNs/RNNs vs. Self-Attention

- Let's compare architectures for mapping a sequence of $n$ tokens to another sequence of equal length, where each input or output token is represented by a $d$-dimensional vector.
- CNN (consider a convolutional layer whose kernel size is $k$):
  - Computational complexity of the convolutional layer is $\mathcal{O}(knd^2)$.
  - Since CNNs are hierarchical, there are $\mathcal{O}(1)$ sequential operations and the maximum path length is $\mathcal{O}(n/k)$.
- RNN:
  - When updating the hidden state, multiplication of the $d \times d$ weight matrix and the $d$-dimensional hidden state has a computational complexity of $\mathcal{O}(d^2)$. Since the sequence length is $n$, the computational complexity of the recurrent layer is $\mathcal{O}(nd^2)$.
  - There are $\mathcal{O}(n)$ sequential operations that cannot be parallelized and the maximum path length is also $\mathcal{O}(n)$.
- Self-attention:
  - Queries, keys, and values are all $n \times d$ matrices. Consider the scaled dot-product attention, where an $n \times d$ matrix is multiplied by a $d \times n$ matrix, then the output $n \times n$ matrix is multiplied by an $n \times d$ matrix. As a result, self-attention has $\mathcal{O}(n^2 d)$ computational complexity.
  - Each token is directly connected to any other token via self-attention. Therefore, computation can be parallel with $\mathcal{O}(1)$ sequential operations and the maximum path length is also $\mathcal{O}(1)$.
Summary of the comparison:

| Architecture | Computational complexity | Sequential operations | Maximum path length |
| --- | --- | --- | --- |
| CNN | $\mathcal{O}(knd^2)$ | $\mathcal{O}(1)$ | $\mathcal{O}(n/k)$ |
| RNN | $\mathcal{O}(nd^2)$ | $\mathcal{O}(n)$ | $\mathcal{O}(n)$ |
| Self-attention | $\mathcal{O}(n^2d)$ | $\mathcal{O}(1)$ | $\mathcal{O}(1)$ |
Self-attention enjoys parallel computation and has the shortest maximum path length.
- However, the quadratic computational complexity with respect to the sequence length makes self-attention prohibitively slow for very long sequences.
Positional Encoding
- Self-attention ditches sequential operations in favor of parallel computation.
⇒ To use the sequence order information, we can inject absolute or relative positional information.
⇒ Add positional encoding to input representations; these encodings can be learned or fixed.
- Fixed positional encoding can be based on sine and cosine functions [Vaswani et al., 2017].
⇒ Hence it is called sinusoidal positional encoding.
⇒ Suppose the input representation $\mathbf{X} \in \mathbb{R}^{n \times d}$ contains $d$-dimensional embeddings for $n$ tokens of a sequence.
⇒ Positional encoding outputs $\mathbf{X} + \mathbf{P}$ using a positional embedding matrix $\mathbf{P} \in \mathbb{R}^{n \times d}$ of the same shape, whose element on the $i$-th row and the $(2j)$-th or the $(2j+1)$-th column is
$$p_{i, 2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \qquad p_{i, 2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right).$$
- In the positional embedding matrix $\mathbf{P}$, rows correspond to positions within a sequence and columns represent different positional encoding dimensions.
- E.g.: from the graph below, we can see that some lower-indexed columns of the positional embedding matrix have a higher frequency than higher-indexed columns. The offset between neighboring columns is due to the alternation of sine and cosine functions.
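A sketch of the sinusoidal positional encoding (it assumes an even embedding dimension; max_len is an illustrative cap, not from the original code):

```python
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding added to the input embeddings X."""
    def __init__(self, num_hiddens, dropout=0.1, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # P[0, i, 2j] = sin(i / 10000^(2j/d)), P[0, i, 2j+1] = cos(i / 10000^(2j/d))
        self.P = torch.zeros((1, max_len, num_hiddens))
        pos = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        div = torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(pos / div)
        self.P[:, :, 1::2] = torch.cos(pos / div)

    def forward(self, X):
        # X: (batch, seq_len, num_hiddens)
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)
```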
Absolute Positional Information
- Let's see how the monotonically decreasing frequency along the encoding dimension relates to absolute positional information.
- Since the outputs are float numbers, such continuous representations are more space-efficient than binary representations.
⇒ Let's print out the binary representations of $0, 1, \ldots, 7$ (see the sketch below).
⇒ The lowest bit, 2nd-lowest bit, and 3rd-lowest bit alternate on every number, every 2 numbers, and every 4 numbers, respectively.
⇒ In binary representations, a higher bit has a lower frequency than a lower bit.
⇒ Similarly, positional encoding decreases frequencies along the encoding dimension by using trigonometric functions.
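The print-out can be reproduced with a few lines of Python:

```python
# The lowest bit flips every number, the next bit every 2 numbers,
# the next every 4 numbers, and so on.
for i in range(8):
    print(f'{i} in binary is {i:>03b}')
```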
Relative Positional Information
- Besides capturing absolute positional information, the above positional encoding also allows a model to easily learn to attend by relative positions.
- This is because for any fixed position offset $\delta$, the positional encoding at position $i + \delta$ can be represented by a linear projection of that at position $i$:
$$\begin{bmatrix} \cos(\delta\omega_j) & \sin(\delta\omega_j) \\ -\sin(\delta\omega_j) & \cos(\delta\omega_j) \end{bmatrix} \begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \end{bmatrix} = \begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \end{bmatrix},$$
where $\omega_j = 1/10000^{2j/d}$ and the $2 \times 2$ projection matrix does not depend on any position index $i$.
⇒ Any pair $(p_{i, 2j}, p_{i, 2j+1})$ can be linearly projected to $(p_{i+\delta, 2j}, p_{i+\delta, 2j+1})$ for any fixed offset $\delta$.
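A small numeric check of this rotation identity for one pair of columns (the concrete d, j, i, and delta values here are arbitrary):

```python
import math
import torch

# Verify that the encoding at position i+delta is a fixed linear (rotation)
# transform of the encoding at position i, independent of i.
d, j, i, delta = 32, 3, 10, 4
omega = 1.0 / (10000 ** (2 * j / d))
pe_i = torch.tensor([math.sin(i * omega), math.cos(i * omega)])
pe_i_delta = torch.tensor([math.sin((i + delta) * omega), math.cos((i + delta) * omega)])
rot = torch.tensor([[math.cos(delta * omega), math.sin(delta * omega)],
                    [-math.sin(delta * omega), math.cos(delta * omega)]])
assert torch.allclose(rot @ pe_i, pe_i_delta, atol=1e-6)
```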
“Vanilla” Transformer
- We have compared CNNs, RNNs, and self-attention.
⇒ Self-attention enjoys both parallel computation and the shortest maximum path length.
⇒ It's appealing to design deep architectures by using self-attention.
⇒ The Transformer model is solely based on attention mechanisms.
Model
- Transformer is composed of an encoder and a decoder:
- Input (source) and output (target) sequence embeddings are added with sinusoidal positional encoding before being fed into the encoder and the decoder that stack modules based on self-attention.
- Encoder is a stack of multiple identical layers, where each layer has 2 sublayers.
- The first is a multi-head self-attention pooling.
- The second is a position-wise feed-forward network.
- By position-wise (or point-wise), it means that it applies the same linear transformation (with the same weights) to each element in the sequence.
- This can also be viewed as a convolutional layer with filter size of 1.
- Inspired by the ResNet design, a residual connection is employed around both sublayers.
- The addition from the residual connection is followed by layer normalization [Ba et al., 2016].
⇒ The encoder outputs a $d$-dimensional vector representation for each position of the input sequence.
⇒ It generates an attention-based representation with the capability to locate a specific piece of information from a large context.
- The decoder is similar to the encoder, except that the decoder contains two multi-head attention sublayers instead of one in each identical repeating layer.
- The first multi-head attention sublayer is masked to prevent positions from attending to the future. Thus, each position in the decoder is allowed to attend only to positions in the decoder up to and including that position.
- This masked attention preserves the auto-regressive property, ensuring that the prediction only depends on those output tokens that have been generated.
- The second multi-head attention sublayer is called encoder-decoder attention, where queries come from the outputs of the previous decoder layer, and keys and values come from the Transformer encoder outputs.
⇒ It retrieves information from the encoded representation.
The architecture of the vanilla Transformer model.
Positionwise Feed-Forward Network
- Position-wise FFNN transforms representation at all the sequence positions using the same MLP.
- Since the same MLP transforms at all the positions, when the inputs at all these positions are the same, their outputs are also identical.
⇒ This is why it is called position-wise (see the sketch below).
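A sketch of a position-wise FFN and of the "identical inputs give identical outputs" observation (layer sizes are illustrative):

```python
import torch
from torch import nn

class PositionWiseFFN(nn.Module):
    """Applies the same two-layer MLP independently at every sequence position."""
    def __init__(self, ffn_num_hiddens, ffn_num_outputs):
        super().__init__()
        self.dense1 = nn.LazyLinear(ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.LazyLinear(ffn_num_outputs)

    def forward(self, X):
        # X: (batch, seq_len, d) -- nn.Linear acts on the last dim, i.e. per position.
        return self.dense2(self.relu(self.dense1(X)))

# Identical inputs at every position produce identical outputs at every position.
ffn = PositionWiseFFN(4, 8)
ffn.eval()
out = ffn(torch.ones((2, 3, 4)))   # all 3 positions share the same input vector
print(out[0])                      # the rows (positions) are identical
```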
Transformers for Vision
- The Transformer architecture was initially proposed for Seq2Seq learning, with a focus on machine translation.
⇒ Later on, it emerged as the model of choice in various natural language processing tasks.
⇒ In the field of computer vision, the dominant architecture had remained the CNN.
⇒ Researchers started to wonder if it's possible to do better by adapting Transformers to image data.
⇒ Transformers have also become a game-changer in computer vision.
- Vision Transformers (ViTs) extract patches from images and feed them into a Transformer encoder to obtain a global representation, which will finally be transformed into the output label:
- Consider an input image with height $h$, width $w$, $c$ channels, and patch height and width both equal to $p$.
⇒ The image is split into a sequence of $hw/p^2$ patches (each patch is flattened to a vector of length $cp^2$).
⇒ Image patches can be treated similarly to tokens in text sequences by Transformer encoders.

The vision Transformer architecture. In this example, an image is split into nine patches. A special “<cls>” token and the nine flattened image patches are transformed via patch embedding and n Transformer encoder blocks into ten representations, respectively. The “<cls>” representation is further transformed into the output label.
Patch Embedding
- Splitting an image into patches and linearly projecting these flattened patches can be simplified as a single convolution operation, where both the kernel size and the stride are set to the patch size (see the sketch below).
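A sketch of patch embedding as a single strided convolution (the default sizes are illustrative):

```python
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """Splits an image into patches and linearly projects each flattened patch,
    implemented as one convolution with kernel size = stride = patch size."""
    def __init__(self, img_size=96, patch_size=16, num_hiddens=512):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.conv = nn.LazyConv2d(num_hiddens, kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        # (batch, c, h, w) -> (batch, num_hiddens, h/p, w/p) -> (batch, num_patches, num_hiddens)
        return self.conv(X).flatten(2).transpose(1, 2)

tokens = PatchEmbedding()(torch.rand(4, 3, 96, 96))   # -> shape (4, 36, 512)
```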
ViT’s MLP
- The MLP of vision Transformer encoder is slightly different from the position-wise FFN of original Transformer encoder.
- The activation function uses GELU, which is a smoother version of ReLU.
- Dropout is applied to the output of each fully connected layer in the MLP for regularization.
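A sketch of the ViT MLP described above (hidden sizes and the dropout rate are illustrative):

```python
import torch
from torch import nn

class ViTMLP(nn.Module):
    """ViT's MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""
    def __init__(self, mlp_num_hiddens, mlp_num_outputs, dropout=0.5):
        super().__init__()
        self.dense1 = nn.LazyLinear(mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.LazyLinear(mlp_num_outputs)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(self.dense1(x)))))
```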
ViT’s “Add & Norm” Layer
- Normalization is applied right before the multi-head attention or MLP (not after, as in the vanilla Transformer); see the sketch below.
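A sketch of a pre-norm ViT encoder block; it reuses the ViTMLP sketch above and substitutes PyTorch's nn.MultiheadAttention for the multi-head attention discussed earlier:

```python
import torch
from torch import nn

class ViTBlock(nn.Module):
    """Pre-norm Transformer encoder block: LayerNorm is applied *before* the
    multi-head attention / MLP, with residual connections around both."""
    def __init__(self, num_hiddens, norm_shape, mlp_num_hiddens, num_heads, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(norm_shape)
        self.attention = nn.MultiheadAttention(num_hiddens, num_heads, dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(norm_shape)
        self.mlp = ViTMLP(mlp_num_hiddens, num_hiddens, dropout)   # from the sketch above

    def forward(self, X):
        Y = self.ln1(X)                                            # norm first ...
        X = X + self.attention(Y, Y, Y, need_weights=False)[0]     # ... then residual add
        return X + self.mlp(self.ln2(X))
```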
Summary
- The input image is fed into PatchEmbedding, whose output is concatenated with the <cls> token embedding.
⇒ The result is summed with learnable positional embeddings before dropout.
⇒ Then the output is fed into the Transformer encoder that stacks num_blks instances of the ViTBlock class.
⇒ Finally, the representation of the <cls> token is projected by the network head (see the sketch below).
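Putting the pieces together, a sketch of the full ViT forward pass, reusing the PatchEmbedding and ViTBlock sketches above (hyperparameter names are illustrative):

```python
import torch
from torch import nn

class ViT(nn.Module):
    """Vision Transformer: patch embedding + <cls> token + positional embeddings,
    a stack of ViTBlocks, and a head on the <cls> representation."""
    def __init__(self, img_size, patch_size, num_hiddens, mlp_num_hiddens,
                 num_heads, num_blks, emb_dropout, blk_dropout, num_classes=10):
        super().__init__()
        self.patch_embedding = PatchEmbedding(img_size, patch_size, num_hiddens)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        num_steps = self.patch_embedding.num_patches + 1            # +1 for <cls>
        self.pos_embedding = nn.Parameter(torch.randn(1, num_steps, num_hiddens))
        self.dropout = nn.Dropout(emb_dropout)
        self.blks = nn.Sequential(*[
            ViTBlock(num_hiddens, num_hiddens, mlp_num_hiddens, num_heads, blk_dropout)
            for _ in range(num_blks)])
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens), nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        X = self.patch_embedding(X)                                 # (batch, num_patches, d)
        X = torch.cat((self.cls_token.expand(X.shape[0], -1, -1), X), dim=1)
        X = self.dropout(X + self.pos_embedding)                    # learnable positions
        X = self.blks(X)
        return self.head(X[:, 0])                                   # project the <cls> token
```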
- For small datasets like ImageNet (1.2M images), vision Transformer does not outperform the ResNet.
- Transformers lack useful principles in convolution, such as translation invariance and locality.
⇒ However, the picture changes when training larger models on larger datasets (e.g., 300M images).
⇒ Vision Transformers then outperform ResNets by a large margin in image classification.
⇒ This demonstrates the superiority of Transformers in scalability.
- ViT was also shown to be effective with the data-efficient training strategies of DeiT (Touvron et al., 2021).
- Quadratic complexity of self-attention makes the architecture less suitable for higher-resolution images. ⇒ Swin Transformers addressed the quadratic computational complexity with respect to image size.
Transformer Optimizations
- As we have seen, the Transformer architecture is heavily bottlenecked by the self-attention mechanism, which has quadratic time and memory complexity.
⇒ Many works therefore try to improve the Transformer architecture and the attention operation in various ways.
KV Caching
- Applicable only to auto-regressive inference, not training.
- In self-attention, for each token in the input sequence, the model computes key and value vectors.
⇒ During generation, this computation is repeated unnecessarily for previously seen tokens.
⇒ KV caching solves this by storing the computed vectors for past tokens and reusing them in future steps:
  - Keys (K) and Values (V) for past tokens are computed once and cached.
  - Only the new token's K and V are computed and appended to the cache.
  - Attention operates over the cached K and V, avoiding recomputation.
⇒ This avoids redundant computations (see the sketch below).
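A toy single-head sketch of one decoding step with a KV cache (the projection matrices, shapes, and function name are illustrative; no batching or multi-head logic):

```python
import torch

def attend_with_kv_cache(W_q, W_k, W_v, new_token, cache):
    """One auto-regressive decoding step with a KV cache.

    new_token: (1, d) hidden state of the newly generated token.
    cache: dict with 'k' and 'v' tensors of shape (t, d) for the t past tokens.
    """
    q = new_token @ W_q                                   # only the new query is needed
    k_new, v_new = new_token @ W_k, new_token @ W_v
    # Append the new token's K/V to the cache instead of recomputing past ones.
    cache['k'] = torch.cat([cache['k'], k_new], dim=0)
    cache['v'] = torch.cat([cache['v'], v_new], dim=0)
    scores = (q @ cache['k'].T) / cache['k'].shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ cache['v'], cache

d = 16
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))
cache = {'k': torch.zeros(0, d), 'v': torch.zeros(0, d)}
for step in range(5):                                     # generate 5 tokens
    out, cache = attend_with_kv_cache(W_q, W_k, W_v, torch.randn(1, d), cache)
```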
Flash Attention
- Attention algorithm used to scale transformer models more efficiently, enabling faster training and inference.
- Standard attention mechanism relies on High Bandwidth Memory (HBM) to write/read keys/queries/values.
- While HBM offers large capacity and high bandwidth, it has higher latency compared to on-chip memory like SRAM or Shared Memory.
⇒ In the standard implementation, the cost of writing/reading keys/queries/values from HBM is high:
  - The model loads this data from HBM into GPU on-chip shared memory (SRAM).
  - It performs a step of the attention mechanism.
  - It writes the result back to HBM.
  - It repeats this process for each attention step.
- FlashAttention optimizes this workflow by loading keys, queries, and values into Shared Memory once, fusing the operations of the attention mechanism (such as softmax and matrix multiplication), and writing the result back to HBM:
Left: FlashAttention uses tiling to prevent materialization of the large $N \times N$ attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the $\mathbf{K}$ and $\mathbf{V}$ matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of the $\mathbf{Q}$ matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. Right: Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large $N \times N$ attention matrix to HBM, resulting in a 7.6× speedup on the attention computation.
⇒ This significantly reduces memory access overhead and improves performance.
Improved Attention Span
- Goal: make the context that can be used in self-attention longer, more efficient, and more flexible.
Longer Attention Span: Transformer-XL
“XL” means “extra long”
- The vanilla Transformer has a fixed and limited attention span.
⇒ The model can only attend to other elements in the same segment during each update step, and no information can flow across separated fixed-length segments.
⇒ This context segmentation causes several issues:
  - The model cannot capture very long-term dependencies.
  - It is hard to predict the first few tokens in each segment given no or thin context.
  - The evaluation is expensive: whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapping tokens.
⇒ Transformer-XL solves the context segmentation problem with two main modifications:
  - Reusing hidden states between segments.
  - Adopting a new positional encoding that is suitable for reused states.
Hidden State Reuse
- The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments:
  - Let's label the hidden state of the $n$-th layer for the $(\tau+1)$-th segment in the model as $\mathbf{h}_{\tau+1}^{(n)}$.
⇒ In addition to the hidden state of the last layer for the same segment $\mathbf{h}_{\tau+1}^{(n-1)}$, it also depends on the hidden state of the same layer for the previous segment $\mathbf{h}_{\tau}^{(n-1)}$.
⇒ By incorporating information from the previous hidden states, the model extends the attention span much further into the past, over multiple segments:
$$\begin{aligned}
\tilde{\mathbf{h}}_{\tau+1}^{(n-1)} &= \big[\text{stop-gradient}(\mathbf{h}_{\tau}^{(n-1)}) \circ \mathbf{h}_{\tau+1}^{(n-1)}\big] \\
\mathbf{Q}_{\tau+1}^{(n)} &= \mathbf{h}_{\tau+1}^{(n-1)} \mathbf{W}^q \\
\mathbf{K}_{\tau+1}^{(n)} &= \tilde{\mathbf{h}}_{\tau+1}^{(n-1)} \mathbf{W}^k \\
\mathbf{V}_{\tau+1}^{(n)} &= \tilde{\mathbf{h}}_{\tau+1}^{(n-1)} \mathbf{W}^v \\
\mathbf{h}_{\tau+1}^{(n)} &= \text{transformer-layer}\big(\mathbf{Q}_{\tau+1}^{(n)}, \mathbf{K}_{\tau+1}^{(n)}, \mathbf{V}_{\tau+1}^{(n)}\big)
\end{aligned}$$
  - Note the difference between the extended hidden state $\tilde{\mathbf{h}}_{\tau+1}^{(n-1)}$ and $\mathbf{h}_{\tau+1}^{(n-1)}$.
  - NB: both key and value rely on the extended hidden state, while the query only consumes the hidden state at the current step.
  - The concatenation operation $[\cdot \circ \cdot]$ is along the sequence length dimension.

A comparison between the training phase of the vanilla Transformer & Transformer-XL with a segment length of 4. (Image source: left part of Figure 2 in Dai et al., 2019).
Relative Positional Encoding
- In order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding based on reparametrization of dot-product of keys and queries.
- Q: Why?
A: If using the same approach as the vanilla Transformer and encoding the absolute position, the previous and current segments would be assigned the same encoding, which is undesired.
⇒ To keep the positional information flowing coherently across segments, Transformer-XL encodes the relative position instead, as it can be sufficient to know the position offset $i - j$ between one key vector $\mathbf{k}_{\tau, j}$ and its query $\mathbf{q}_{\tau, i}$ for making good predictions.
⇒ If omitting the scalar $1/\sqrt{d_k}$ and the normalizing term in softmax but including positional encodings, we can write the attention score between query at position $i$ and key at position $j$ as:
$$a_{ij} = \mathbf{q}_i \mathbf{k}_j^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q \big((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k\big)^\top = \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top$$
⇒ Transformer-XL reparameterizes the above four terms as follows:
$$a_{ij}^{\text{rel}} = \underbrace{\mathbf{x}_i\mathbf{W}^q {\mathbf{W}_E^k}^\top\mathbf{x}_j^\top}_{\text{content-based addressing}} + \underbrace{\mathbf{x}_i\mathbf{W}^q {\mathbf{W}_R^k}^\top\mathbf{r}_{i-j}^\top}_{\text{content-dependent positional bias}} + \underbrace{\mathbf{u}\,{\mathbf{W}_E^k}^\top\mathbf{x}_j^\top}_{\text{global content bias}} + \underbrace{\mathbf{v}\,{\mathbf{W}_R^k}^\top\mathbf{r}_{i-j}^\top}_{\text{global positional bias}}$$
  - Replaces $\mathbf{p}_j$ with a relative positional encoding $\mathbf{r}_{i-j} \in \mathbb{R}^d$.
  - Replaces $\mathbf{p}_i\mathbf{W}^q$ with 2 trainable parameters $\mathbf{u}$ (for content) and $\mathbf{v}$ (for location) in 2 different terms.
  - Splits $\mathbf{W}^k$ into two matrices, $\mathbf{W}_E^k$ for content information and $\mathbf{W}_R^k$ for location information.
Adaptive Attention Span
- One key advantage of the Transformer is the capability of capturing long-term dependencies.
⇒ Depending on the context, the model may prefer to attend further back at some times than at others; or one attention head may have a different attention pattern from another.
⇒ If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost, supporting a longer maximum context size in the model.
⇒ This is the motivation for Adaptive Attention Span, a self-attention mechanism that seeks an optimal attention span.
- The authors hypothesized that different attention heads might assign scores differently within the same context window:
Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019)
⇒ The optimal span should be trained separately per head.
- Given the $i$-th token, we need to compute the attention weights between this token and other keys at positions $j \in [i-s, i-1]$, where $s$ defines the $i$-th token's context window:
$$e_{ij} = \mathbf{q}_i \mathbf{k}_j^\top, \qquad a_{ij} = \operatorname{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r=i-s}^{i-1} \exp(e_{ir})}, \qquad \mathbf{y}_i = \sum_{r=i-s}^{i-1} a_{ir} \mathbf{v}_r.$$
- A soft mask function $m_z$ is added to control for an effective, adjustable attention span; it maps the distance between query and key into a $[0, 1]$ value. $m_z$ is parameterized by $z \in [0, s]$, and $z$ is to be learned:
$$m_z(x) = \operatorname{clamp}\left(\frac{1}{R}(R + z - x),\, 0,\, 1\right),$$
where $R$ is a hyper-parameter which defines the softness of $m_z$.

The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.)

⇒ The soft mask function is applied to the softmax elements in the attention weights:
$$a_{ij} = \frac{m_z(i-j)\exp(s_{ij})}{\sum_{r=i-s}^{i-1} m_z(i-r)\exp(s_{ir})}.$$
  - $m_z$ is differentiable, so it is trained jointly with other parts of the model.
  - The span parameters $z$ are learned separately per head.
  - Moreover, the loss function has an extra L1 penalty on $\sum_{i=1}^h z_i$.
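A small sketch of the soft mask applied to the attention scores of one query (the R, s, and z values here are illustrative):

```python
import torch

def soft_mask(x, z, R=32.0):
    """Soft masking function m_z(x) = clamp((R + z - x) / R, 0, 1);
    x is the query-key distance, z the learnable span."""
    return torch.clamp((R + z - x) / R, min=0.0, max=1.0)

# Distances 1..s from the current token; a learnable z shrinks the effective window.
s, z = 128, torch.tensor(40.0, requires_grad=True)        # z would be learned per head
dist = torch.arange(1, s + 1, dtype=torch.float32)
scores = torch.randn(s)                                   # raw attention scores s_ij
masked = soft_mask(dist, z) * torch.exp(scores)
weights = masked / masked.sum()                           # soft-masked softmax weights
```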
- In the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans.
- Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.
Sparse Attention Matrix Factorization: Sparse Transformers
- The Sparse Transformer introduced factorized self-attention through sparse matrix factorization.
⇒ This makes it possible to train dense attention networks with hundreds of layers on sequences of length up to 16,384, which would otherwise be infeasible on modern hardware.
- Define $\mathcal{S} = \{S_1, \ldots, S_n\}$, a set of attention connectivity patterns, where each $S_i$ records the set of key positions that the $i$-th query vector attends to:
- NB: the size of $S_i$ is not fixed, while $a(\mathbf{x}_i, S_i)$ is always of size $d_v$, and thus $\operatorname{Attend}(\mathbf{X}, \mathcal{S}) \in \mathbb{R}^{n \times d_v}$.
$$\operatorname{Attend}(\mathbf{X}, \mathcal{S}) = \Big(a(\mathbf{x}_i, S_i)\Big)_{i \in \{1, \ldots, n\}},$$
where
$$a(\mathbf{x}_i, S_i) = \operatorname{softmax}\left(\frac{(\mathbf{x}_i \mathbf{W}^q)\big((\mathbf{x}_j \mathbf{W}^k)_{j \in S_i}\big)^\top}{\sqrt{d_k}}\right)(\mathbf{x}_j \mathbf{W}^v)_{j \in S_i}.$$
- In auto-regressive models, one attention span is defined as $S_i = \{j : j \leq i\}$, as it allows each token to attend to all the positions in the past. In factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair $(i, j)$ with $j \leq i$ there is a path connecting $i$ back to $j$, and $i$ can attend to $j$ either directly or indirectly.
⇒ Precisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A_i^{(m)} \subset S_i$, $m = 1, \ldots, p$.
- The Sparse Transformer proposed two types of factorized attention (a code sketch of both connectivity patterns follows after the figure):
- Strided attention with stride $\ell \sim \sqrt{n}$:
  - Works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous $\ell$ pixels in the raster scanning order (naturally covering the entire width of the image), and then those pixels attend to others in the same column (defined by the other attention connectivity subset):
$$A_i^{(1)} = \{t, t+1, \ldots, i\}, \text{ where } t = \max(0, i - \ell), \qquad A_i^{(2)} = \{j : (i - j) \bmod \ell = 0\}.$$
- Fixed attention:
- A small set of tokens summarize previous locations and propagate that information to all future locations.
$$A_i^{(1)} = \{j : \lfloor j/\ell \rfloor = \lfloor i/\ell \rfloor\}, \qquad A_i^{(2)} = \{j : j \bmod \ell \in \{\ell - c, \ldots, \ell - 1\}\},$$
where $c$ is a hyperparameter (if $c = 1$, it restricts the representation, whereas many positions depend on a few positions; the paper chose $c \in \{8, 16, 32\}$ for $\ell \in \{128, 256\}$).

The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few extra annotations.)
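A small sketch that enumerates the two connectivity patterns for one query position, restricted to past positions as in the auto-regressive setting (the concrete numbers are illustrative):

```python
import math

def strided_pattern(i, stride):
    """Strided attention: A_i^(1) is the last `stride` positions,
    A_i^(2) is every `stride`-th position back in time."""
    a1 = set(range(max(0, i - stride), i + 1))
    a2 = {j for j in range(i + 1) if (i - j) % stride == 0}
    return a1, a2

def fixed_pattern(i, stride, c):
    """Fixed attention: A_i^(1) is the current block of length `stride`,
    A_i^(2) is the last c "summary" positions of every block."""
    a1 = {j for j in range(i + 1) if j // stride == i // stride}
    a2 = {j for j in range(i + 1) if j % stride in range(stride - c, stride)}
    return a1, a2

n = 64
stride = int(math.sqrt(n))
print(strided_pattern(20, stride))
print(fixed_pattern(20, stride, c=2))
```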
- There are three ways to use sparse factorized attention patterns in Transformer architecture:
- One attention type per residual block, then interleave them: $\operatorname{attention}(\mathbf{X}) = \operatorname{Attend}(\mathbf{X}, A^{(n \bmod p)}) \mathbf{W}^o$, where $n$ is the index of the current residual block.
- Set up a single head which attends to locations that all the factorized heads attend to: $\operatorname{attention}(\mathbf{X}) = \operatorname{Attend}\big(\mathbf{X}, \bigcup_{m=1}^p A^{(m)}\big) \mathbf{W}^o$.
- Use a multi-head attention mechanism, but, unlike in the vanilla Transformer, each head may adopt one of the patterns presented above (1 or 2).
Locality-Sensitive Hashing: Reformer
- Reformer proposed two main changes:
- Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $\mathcal{O}(L^2)$ to $\mathcal{O}(L \log L)$.
- Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of $N$ times (i.e. proportional to the number of layers).
The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020).