
Attention Mechanisms and Transformers

Introduction

What’s Wrong with Seq2Seq Models?

  • The Seq2seq model originates from language modeling.
    • Normally has an encoder-decoder architecture:
      • Encoder processes the input sequence and compresses it into a fixed-length context vector, where this representation is expected to be a good summary of the meaning of the whole source sequence.
      • Decoder is initialized with the context vector to emit the transformed output.
      • Both encoder and decoder are typically RNNs (LSTMs/GRUs).
    • Transforms an input sequence (source) into a new one (target), where both sequences can be of arbitrary lengths.
    • Limitation: incapable of remembering long sentences because of the fixed-length context vector.
      • The context vector struggles to retain information from long sentences, often causing early parts of the sequence to be forgotten by the time processing completes.
          ⇒ The attention mechanism was born to resolve this problem.

Attention Cues in Biology

  • When inspecting a visual scene, our optic nerve receives information on the order of $10^8$ bits per second, far exceeding what our brain can fully process. Fortunately, our ancestors had learned from experience (or, data in our case) that not all sensory inputs are created equal.
      ⇒ Throughout human history, the capability of directing attention to only a fraction of information of interest has enabled our brain to allocate resources more smartly to survive, to grow, and to socialize, such as detecting predators, preys, and mates.
  • To explain how our attention is deployed in the visual world, a two-component framework has emerged. Idea dates back to William James, who is considered the “father of American psychology” [James, 2007].
      ⇒ In this framework, subjects selectively direct the spotlight of attention using
      1. Nonvolitional cue
      2. Volitional cue
      Nonvolitional cue is based on the saliency and conspicuity of objects in the environment:
      Using the nonvolitional cue based on saliency (red cup, non-paper), attention is involuntarily directed to the coffee.
    • While all the paper products are printed in black and white, the coffee cup is red.
        ⇒ This coffee cup is intrinsically salient and conspicuous in this visual environment.
        ⇒ Automatically and involuntarily draws attention.
        ⇒ You bring the fovea (the center of the macula where visual acuity is highest) onto the coffee.
    • ⇒ After drinking coffee, you become caffeinated and want to read a book. You turn your head, refocus your eyes, and look at the book. In this task-dependent case you select the book under cognitive and volitional control:
      Using the volitional cue (want to read a book) that is task-dependent, attention is directed to the book under volitional control.
      ⇒ Using the volitional cue based on variable selection criteria, this form of attention is more deliberate. It is also more powerful with the subject’s voluntary effort.

Queries, Keys, and Values

  • Inspired by the nonvolitional and volitional attention cues that explain the attentional deployment.
  • Consider the case where only nonvolitional cues are available. To bias selection over sensory inputs, we can simply use:
    • Parameterized fully-connected layer.
    • Non-parameterized max or average pooling.
    • ⇒ What sets attention mechanisms apart from FC or pooling layers is the inclusion of volitional cues.
    • In the context of attention mechanisms, we refer to volitional cues as queries.
    • Given any query, attention mechanisms bias selection over sensory inputs via attention pooling.
      • These sensory inputs are called values in the context of attention mechanisms.
      • Every value is paired with a key, which can be thought of as the nonvolitional cue of that sensory input.
    • Design attention pooling such that the given query (volitional cue) can interact with keys (nonvolitional cues), which guides bias selection over values (sensory inputs).

Nadaraya-Watson Kernel Regression

  • Example of an ML algorithm with an attention mechanism.
  • We generate a dataset $\{(x_1, y_1), \ldots, (x_n, y_n)\}$ according to the following non-linear function with additive noise $\epsilon$:
      $$y_i = 2\sin(x_i) + x_i^{0.8} + \epsilon$$
      Given this dataset, how do we learn $f$ to predict the output $\hat{y} = f(x)$ for any new input $x$?

Average Pooling

  • Begin with the “dumbest” estimator for this regression problem.
      ⇒ Use average pooling to average over all the training outputs:
      $$f(x) = \frac{1}{n}\sum_{i=1}^n y_i$$
      ⇒ As we can see, this estimator is not so smart, since average pooling omits the inputs $x_i$.
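A minimal NumPy sketch of this baseline (the dataset follows the non-linear function above; array names such as `x_train` are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)

# Synthetic training data: y = 2*sin(x) + x**0.8 + Gaussian noise (illustrative setup)
n = 50
x_train = np.sort(rng.uniform(0, 5, n))
y_train = 2 * np.sin(x_train) + x_train ** 0.8 + rng.normal(0.0, 0.5, n)

# Test inputs where we want predictions
x_test = np.arange(0, 5, 0.1)

# Average pooling: predict the mean of all training outputs, ignoring the inputs x entirely
y_hat = np.full_like(x_test, y_train.mean())
```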

Nonparametric Attention Pooling

  • Better idea was proposed by Nadaraya and Watson: weight the outputs $y_i$ according to their input locations $x_i$:
      $$f(x) = \sum_{i=1}^n \frac{K(x - x_i)}{\sum_{j=1}^n K(x - x_j)}\, y_i$$
      where $K$ is a kernel.
      ⇒ We can rewrite it in a more generalized form of attention pooling:
      $$f(x) = \sum_{i=1}^n \alpha(x, x_i)\, y_i$$
      where $x$ is the query, $(x_i, y_i)$ is the key-value pair, and $\alpha(x, x_i)$ is an attention weight.
    • The attention weight $\alpha(x, x_i)$ is assigned to the corresponding value $y_i$.
    • Consider a Gaussian kernel defined as
        $$K(u) = \frac{1}{\sqrt{2\pi}} \exp\left(-\frac{u^2}{2}\right)$$
        ⇒ Plugging the Gaussian kernel into the last two equations gives us
        $$f(x) = \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}(x - x_i)^2\right) y_i$$
    • Key $x_i$ that is closer to the given query $x$ will get more attention via a larger attention weight assigned to the key’s corresponding value $y_i$.
    • ⇒ Predicted line is smooth and closer to the ground truth than that produced by average pooling.
    • Now let us take a look at the attention weights. Here testing inputs are queries while training inputs are keys. Since both inputs are sorted, we can see that the closer the query-key pair is, the higher the attention weight is in the attention pooling.
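A minimal sketch of the nonparametric pooling with the Gaussian kernel (the illustrative dataset is regenerated here so the snippet is self-contained):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def nw_attention_pool(queries, keys, values):
    """Nadaraya-Watson estimate: softmax(-(x - x_i)^2 / 2) weights applied to y_i."""
    diff = queries[:, None] - keys[None, :]           # (num_queries, num_keys)
    attn_weights = softmax(-0.5 * diff ** 2, axis=1)
    return attn_weights @ values, attn_weights

# Illustrative data, as in the previous sketch
rng = np.random.default_rng(0)
x_train = np.sort(rng.uniform(0, 5, 50))
y_train = 2 * np.sin(x_train) + x_train ** 0.8 + rng.normal(0.0, 0.5, 50)
x_test = np.arange(0, 5, 0.1)

y_hat, attn = nw_attention_pool(x_test, x_train, y_train)
```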

Parametric Attention Pooling

  • We can easily integrate learnable parameters into attention pooling.
      ⇒ Distance between the query $x$ and the key $x_i$ can be multiplied by a learnable parameter $w$:
      $$f(x) = \sum_{i=1}^n \mathrm{softmax}\left(-\frac{1}{2}\big((x - x_i)\, w\big)^2\right) y_i$$
      ⇒ After training the parametric attention model, we can plot its prediction.
    • Comparing with nonparametric attention pooling, the region with large attention weights becomes sharper in the learnable and parametric setting.
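A short PyTorch sketch of the parametric version, fitting the scalar $w$ by squared loss; as a simplifying assumption it trains on the keys themselves rather than a leave-one-out setup:

```python
import torch
from torch import nn

class NWKernelRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.rand(1))

    def forward(self, queries, keys, values):
        # score[i, j] = -((q_i - k_j) * w)^2 / 2, then softmax over the keys
        diff = queries.unsqueeze(1) - keys.unsqueeze(0)
        attn_weights = torch.softmax(-0.5 * (diff * self.w) ** 2, dim=1)
        return attn_weights @ values

# Illustrative data
x_train = torch.sort(torch.rand(50) * 5).values
y_train = 2 * torch.sin(x_train) + x_train ** 0.8 + 0.5 * torch.randn(50)

net = NWKernelRegression()
opt = torch.optim.SGD(net.parameters(), lr=0.5)
for _ in range(5):
    opt.zero_grad()
    loss = ((net(x_train, x_train, y_train) - y_train) ** 2).mean()
    loss.backward()
    opt.step()
```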

Attention Pooling and Attention Scoring Functions

  • In the previous section, we obtained a probability distribution over values that are paired with keys: the output of the attention pooling is simply a weighted sum of the values based on the attention weights.
  • At a high level, we can use the above algorithm to instantiate the framework of attention mechanisms.
      ⇒ Denoting an attention scoring function by $a$, the following figure illustrates how the output of attention pooling can be computed as a weighted sum of values:
      Computing the output of attention pooling as a weighted average of values
    • NB: attention weights are a probability distribution, weighted sum is a weighted average.
    • More formally, suppose that we have:
      • Query $\mathbf{q} \in \mathbb{R}^q$.
      • Key-value pairs $(\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)$, where any $\mathbf{k}_i \in \mathbb{R}^k$ and any $\mathbf{v}_i \in \mathbb{R}^v$.
      • ⇒ Attention pooling ($f$ before) is instantiated as a weighted sum of the values:
        $$f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i)\,\mathbf{v}_i \in \mathbb{R}^v$$
        where the attention weight (scalar) for the query $\mathbf{q}$ and key $\mathbf{k}_i$ is computed by the softmax operation of an attention scoring function $a$ that maps two vectors to a scalar:
        $$\alpha(\mathbf{q}, \mathbf{k}_i) = \mathrm{softmax}(a(\mathbf{q}, \mathbf{k}_i)) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))} \in \mathbb{R}$$
  • Different choices of the attention scoring function $a$ lead to different behaviors of attention pooling.

Additive Attention

  • When queries and keys are vectors of different lengths, we can use additive attention as the scoring function. Given a query $\mathbf{q} \in \mathbb{R}^q$ and a key $\mathbf{k} \in \mathbb{R}^k$, the additive attention scoring function is
      $$a(\mathbf{q}, \mathbf{k}) = \mathbf{w}_v^\top \tanh(\mathbf{W}_q \mathbf{q} + \mathbf{W}_k \mathbf{k}) \in \mathbb{R}$$
      where $\mathbf{W}_q \in \mathbb{R}^{h \times q}$, $\mathbf{W}_k \in \mathbb{R}^{h \times k}$, and $\mathbf{w}_v \in \mathbb{R}^{h}$ are learnable parameters.
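A minimal PyTorch sketch of additive attention (dropout and masking omitted; dimension names are illustrative):

```python
import torch
from torch import nn

class AdditiveAttention(nn.Module):
    """Score a(q, k) = w_v^T tanh(W_q q + W_k k), then softmax over the keys."""
    def __init__(self, query_dim, key_dim, num_hiddens):
        super().__init__()
        self.W_q = nn.Linear(query_dim, num_hiddens, bias=False)
        self.W_k = nn.Linear(key_dim, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)

    def forward(self, queries, keys, values):
        # queries: (batch, n_q, query_dim); keys/values: (batch, n_kv, ...)
        q = self.W_q(queries).unsqueeze(2)                 # (batch, n_q, 1, h)
        k = self.W_k(keys).unsqueeze(1)                    # (batch, 1, n_kv, h)
        scores = self.w_v(torch.tanh(q + k)).squeeze(-1)   # (batch, n_q, n_kv)
        attn_weights = torch.softmax(scores, dim=-1)
        return attn_weights @ values                       # (batch, n_q, value_dim)

# Usage: 2 queries of dim 20 attending over 10 key-value pairs
attn = AdditiveAttention(query_dim=20, key_dim=2, num_hiddens=8)
out = attn(torch.rand(1, 2, 20), torch.rand(1, 10, 2), torch.rand(1, 10, 4))
```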

Scaled Dot-Product Attention

  • More computationally efficient, but requires both the query and the key to have the same vector length $d$.
  • Assume that all elements of the query and the key are i.i.d. variables with zero mean and unit variance.
      ⇒ Then $E[\mathbf{q}^\top \mathbf{k}] = 0$ and $\mathrm{Var}(\mathbf{q}^\top \mathbf{k}) = d$.
      ⇒ To ensure that the variance of the dot product still remains one regardless of vector length, the scaled dot-product attention scoring function is
      $$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q}^\top \mathbf{k}}{\sqrt{d}}$$
  • In practice, we often think in mini-batches for efficiency.
      ⇒ Scaled dot-product attention of queries $\mathbf{Q} \in \mathbb{R}^{n \times d}$, keys $\mathbf{K} \in \mathbb{R}^{m \times d}$, and values $\mathbf{V} \in \mathbb{R}^{m \times v}$ is
      $$\mathrm{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d}}\right)\mathbf{V} \in \mathbb{R}^{n \times v}$$
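A minimal sketch of the batched form (masking and dropout omitted). Recent PyTorch versions also ship a fused `torch.nn.functional.scaled_dot_product_attention` that computes the same quantity more efficiently:

```python
import math
import torch

def scaled_dot_product_attention(Q, K, V):
    # Q: (batch, n, d); K: (batch, m, d); V: (batch, m, v)
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)   # (batch, n, m)
    attn_weights = torch.softmax(scores, dim=-1)
    return attn_weights @ V                           # (batch, n, v)

out = scaled_dot_product_attention(torch.rand(2, 4, 8),
                                   torch.rand(2, 6, 8),
                                   torch.rand(2, 6, 16))
print(out.shape)   # torch.Size([2, 4, 16])
```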

Masked Softmax Operation

  • In some cases, not all the values should be fed into attention pooling.
    • E.g.: for efficient minibatch processing, some text sequences are padded with special tokens that do not carry meaning. To get an attention pooling over only meaningful tokens as values, we can specify a valid sequence length (in number of tokens) to filter out those beyond this specified range when computing softmax. In this way, we can implement such a masked softmax operation, where any value beyond the valid length is masked as zero.
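A minimal sketch of such a masked softmax: positions beyond each sequence's valid length receive a very negative score, so their attention weights become (numerically) zero. The function name and shapes are illustrative:

```python
import torch

def masked_softmax(scores, valid_lens):
    # scores: (batch, n_queries, n_keys); valid_lens: (batch,) number of valid keys
    batch, n_q, n_k = scores.shape
    mask = torch.arange(n_k)[None, None, :] >= valid_lens[:, None, None]
    scores = scores.masked_fill(mask, -1e6)
    return torch.softmax(scores, dim=-1)

w = masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))
# Weights for keys beyond the valid length are ~0, and each row still sums to 1.
```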

Basic Attention Mechanisms

Bahdanau Attention

  • As we saw in the intro, the attention mechanism was born to help memorize long source sentences in neural machine translation (NMT).
      ⇒ Rather than building a single context vector out of the encoder’s last hidden state, the secret sauce invented by attention is to create shortcuts between the context vector and the entire source input. The weights of these shortcut connections are customizable for each output element.
      When predicting a token, not all the input tokens are relevant.
      Model aligns only to parts of the input sequence that are relevant to the current prediction.
      ⇒ This is achieved by treating the context variable at decoding time step $t'$ as an output of attention pooling:
      $$\mathbf{c}_{t'} = \sum_{t=1}^{T} \alpha(\mathbf{s}_{t'-1}, \mathbf{h}_t)\, \mathbf{h}_t$$
      where
    • $T$ is the number of tokens in the input sequence.
    • Decoder hidden state $\mathbf{s}_{t'-1}$ at time step $t' - 1$ is the query.
    • Encoder hidden states $\mathbf{h}_t$ are both the keys and values.
    • Attention weight $\alpha$ is computed using the additive attention scoring function.
  • Matrix of alignment scores is a nice byproduct to explicitly show the correlation between source and target words:
      Alignment matrix of "L'accord sur l'Espace économique européen a été signé en août 1992" (French) and its English translation "The agreement on the European Economic Area was signed in August 1992".

Multi-Head Attention

  • Given the same set of queries, keys, and values, we may want our model to combine knowledge from different behaviors of the same attention mechanism, such as capturing dependencies of various ranges (e.g., shorter-range vs. longer-range) within a sequence.
      It may be beneficial to allow our attention mechanism to jointly use different representation subspaces of queries, keys, and values.
    • Usually, understanding the role of a word in a sentence requires understanding how it is related to different parts of the sentence. For example, in some languages, subjects define verb inflection (e.g., gender agreement), verbs define the case of their objects, and many more. In other words, each word is part of many relations.
    • ⇒ Let model focus on different things from different representation subspaces at different positions.
      ⇒ Instead of performing a single attention pooling, queries, keys, and values can be transformed with $h$ independently learned linear projections. Then these $h$ projected queries, keys, and values are fed into attention pooling in parallel. In the end, $h$ attention pooling outputs (heads) are concatenated and transformed with another learned linear projection to produce the final output.
      ⇒ Given a query $\mathbf{q} \in \mathbb{R}^{d_q}$, a key $\mathbf{k} \in \mathbb{R}^{d_k}$, and a value $\mathbf{v} \in \mathbb{R}^{d_v}$, each attention head $\mathbf{h}_i$ ($i = 1, \ldots, h$) is computed as
      $$\mathbf{h}_i = f(\mathbf{W}_i^{(q)}\mathbf{q},\, \mathbf{W}_i^{(k)}\mathbf{k},\, \mathbf{W}_i^{(v)}\mathbf{v}) \in \mathbb{R}^{p_v}$$
      where
    • $\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$, $\mathbf{W}_i^{(k)} \in \mathbb{R}^{p_k \times d_k}$, and $\mathbf{W}_i^{(v)} \in \mathbb{R}^{p_v \times d_v}$ are learnable parameters.
    • $f$ is attention pooling (e.g. additive attention, scaled dot-product attention, etc.).
    • ⇒ Multi-head attention output is a linear transformation via $\mathbf{W}_o \in \mathbb{R}^{p_o \times h p_v}$ of the concatenation of $h$ heads:
      $$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}$$
      Multi-head attention, where multiple heads are concatenated then linearly transformed
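A compact PyTorch sketch of multi-head attention with scaled dot-product pooling; for brevity it assumes $p_q = p_k = p_v = p_o/h$ and fuses the per-head projections into single linear layers:

```python
import math
import torch
from torch import nn

class MultiHeadAttention(nn.Module):
    def __init__(self, num_hiddens, num_heads):
        super().__init__()
        assert num_hiddens % num_heads == 0
        self.num_heads = num_heads
        self.W_q = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_k = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_v = nn.Linear(num_hiddens, num_hiddens, bias=False)
        self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=False)

    def _split(self, x):
        # (batch, seq, hiddens) -> (batch, heads, seq, hiddens / heads)
        b, n, d = x.shape
        return x.view(b, n, self.num_heads, d // self.num_heads).transpose(1, 2)

    def forward(self, queries, keys, values):
        q = self._split(self.W_q(queries))
        k = self._split(self.W_k(keys))
        v = self._split(self.W_v(values))
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        attn = torch.softmax(scores, dim=-1)
        out = attn @ v                                    # (batch, heads, n_q, d / heads)
        out = out.transpose(1, 2).reshape(queries.shape)  # concatenate the heads
        return self.W_o(out)

mha = MultiHeadAttention(num_hiddens=32, num_heads=4)
y = mha(torch.rand(2, 5, 32), torch.rand(2, 7, 32), torch.rand(2, 7, 32))
```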

Self-Attention

  • We often use CNNs or RNNs to encode a sequence.
      ⇒ Idea: with attention mechanisms in mind, feed a sequence of tokens into attention pooling so that the same set of tokens acts as queries, keys, and values.
      Since queries, keys, and values come from the same place, this is called self-attention.
      ⇒ Given a sequence of input tokens $\mathbf{x}_1, \ldots, \mathbf{x}_n$, where any $\mathbf{x}_i \in \mathbb{R}^d$, its self-attention outputs a sequence of the same length $\mathbf{y}_1, \ldots, \mathbf{y}_n$, where
      $$\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d$$
    • E.g. $f$ can be multi-head attention.
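In code, self-attention is simply attention where the same tensor is passed as queries, keys, and values; here with PyTorch's built-in multi-head attention module:

```python
import torch
from torch import nn

X = torch.rand(2, 5, 32)        # batch of 2 sequences, 5 tokens each, embedding size 32
self_attn = nn.MultiheadAttention(embed_dim=32, num_heads=4, batch_first=True)
Y, attn_weights = self_attn(X, X, X)   # queries = keys = values = X
print(Y.shape)                  # torch.Size([2, 5, 32]), same length as the input
```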

CNNs/RNNs vs. Self-Attention

Comparing CNN (padding tokens are omitted), RNN, and self-attention architectures
  • Let’s compare architectures for mapping a sequence of $n$ tokens to another sequence of equal length, where each input or output token is represented by a $d$-dimensional vector.
    • CNN (consider a convolutional layer whose kernel size is $k$):
      • Computational complexity of the convolutional layer is $\mathcal{O}(knd^2)$.
      • Since CNNs are hierarchical, there are $\mathcal{O}(1)$ sequential operations and the maximum path length is $\mathcal{O}(n/k)$.
    • RNN:
      • When updating the hidden state, multiplication of the $d \times d$ weight matrix and the $d$-dimensional hidden state has a computational complexity of $\mathcal{O}(d^2)$. Since the sequence length is $n$, the computational complexity of the recurrent layer is $\mathcal{O}(nd^2)$.
      • According to the figure above, there are $\mathcal{O}(n)$ sequential operations that cannot be parallelized and the maximum path length is also $\mathcal{O}(n)$.
    • Self-attention:
      • Queries, keys, and values are all $n \times d$ matrices. Consider the scaled dot-product attention, where an $n \times d$ matrix is multiplied by a $d \times n$ matrix, then the output $n \times n$ matrix is multiplied by an $n \times d$ matrix. As a result, the self-attention has an $\mathcal{O}(n^2 d)$ computational complexity.
      • As we can see in the figure, each token is directly connected to any other token via self-attention. Therefore, computation can be parallel with $\mathcal{O}(1)$ sequential operations and the maximum path length is also $\mathcal{O}(1)$.
⇒ In summary: CNN has $\mathcal{O}(knd^2)$ complexity, $\mathcal{O}(1)$ sequential operations, $\mathcal{O}(n/k)$ maximum path length; RNN has $\mathcal{O}(nd^2)$, $\mathcal{O}(n)$, $\mathcal{O}(n)$; self-attention has $\mathcal{O}(n^2 d)$, $\mathcal{O}(1)$, $\mathcal{O}(1)$.
Self-attention enjoys parallel computation and has the shortest maximum path length.
  • However, the quadratic computational complexity with respect to sequence length makes self-attention prohibitively slow for very long sequences.

Positional Encoding

  • Self-attention ditches sequential operations in favor of parallel computation.
      ⇒ To use the sequence order information, we can inject absolute or relative positional information.
      ⇒ Add positional encoding to input representations where these encodings can be learned or fixed.
  • Fixed positional encoding can be based on sine and cosine functions [Vaswani et al., 2017].
      ⇒ It is therefore called sinusoidal positional encoding.
      ⇒ Input representation $\mathbf{X} \in \mathbb{R}^{n \times d}$ contains $d$-dimensional embeddings for $n$ tokens of a sequence.
      ⇒ Positional encoding outputs $\mathbf{X} + \mathbf{P}$ using a positional embedding matrix $\mathbf{P} \in \mathbb{R}^{n \times d}$ of the same shape, whose element on the $i$-th row and the $(2j)$-th or the $(2j+1)$-th column is
      $$p_{i, 2j} = \sin\left(\frac{i}{10000^{2j/d}}\right), \qquad p_{i, 2j+1} = \cos\left(\frac{i}{10000^{2j/d}}\right)$$
    • In the positional embedding matrix $\mathbf{P}$, rows correspond to positions within a sequence and columns represent different positional encoding dimensions.
      • E.g.: plotting a few columns of the positional embedding matrix shows that the 6th and the 7th columns have a higher frequency than the 8th and the 9th columns. The offset between the 6th and the 7th (same for the 8th and the 9th) columns is due to the alternation of sine and cosine functions.
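A short PyTorch sketch of this sinusoidal encoding (dropout omitted; `max_len` is an illustrative cap on the sequence length):

```python
import torch
from torch import nn

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to token embeddings."""
    def __init__(self, num_hiddens, max_len=1000):
        super().__init__()
        self.P = torch.zeros(1, max_len, num_hiddens)
        pos = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1)
        div = torch.pow(10000, torch.arange(0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(pos / div)   # even columns: sine
        self.P[:, :, 1::2] = torch.cos(pos / div)   # odd columns: cosine

    def forward(self, X):
        # X: (batch, seq_len, num_hiddens)
        return X + self.P[:, :X.shape[1], :].to(X.device)

pe = PositionalEncoding(num_hiddens=32)
out = pe(torch.zeros(2, 10, 32))
```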
Absolute Positional Information
  • Let’s see how the monotonically decreasing frequency along the encoding dimension relates to absolute positional information.
      ⇒ Let’s print out the binary representations of $0, 1, \ldots, 7$:
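A tiny sketch that prints these binary representations (the range 0–7 is assumed):

```python
for i in range(8):
    print(f'{i} in binary is {i:>03b}')
```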
      ⇒ Lowest bit, 2nd-lowest bit, and 3rd-lowest bit alternate on every number, every 2 numbers, and every 4 numbers.
      ⇒ In binary representations, a higher bit has a lower frequency than a lower bit.
      ⇒ Similarly, positional encoding decreases frequencies along the encoding dimension by using trigonometric functions.
    • Since the outputs are float numbers, such continuous representations are more space-efficient than binary representations.
Relative Positional Information
  • Besides capturing absolute positional information, the above positional encoding also allows a model to easily learn to attend by relative positions.
    • This is because for any fixed position offset $\delta$, the positional encoding at position $i + \delta$ can be represented by a linear projection of that at position $i$:
        $$\begin{bmatrix} \cos(\delta\omega_j) & \sin(\delta\omega_j) \\ -\sin(\delta\omega_j) & \cos(\delta\omega_j) \end{bmatrix} \begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \end{bmatrix} = \begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \end{bmatrix}$$
        where $\omega_j = 1/10000^{2j/d}$, and the $2 \times 2$ projection matrix does not depend on any position index $i$.
        ⇒ Any pair of $(p_{i, 2j}, p_{i, 2j+1})$ can be linearly projected to $(p_{i+\delta, 2j}, p_{i+\delta, 2j+1})$ for any fixed offset $\delta$.

“Vanilla” Transformer

  • We have compared CNNs, RNNs, and self-attention.
      ⇒ Self-attention enjoys both parallel computation and the shortest maximum path length.
      ⇒ It’s appealing to design deep architectures by using self-attention.
      Transformer model is solely based on attention mechanisms.

Model

  • Transformer is composed of an encoder and a decoder:
      The architecture of the vanilla Transformer model.
    • Input (source) and output (target) sequence embeddings are added with sinusoidal positional encoding before being fed into the encoder and the decoder that stack modules based on self-attention.
    • Encoder is a stack of multiple identical layers, where each layer has 2 sublayers.
      • The first is a multi-head self-attention pooling.
      • The second is a position-wise feed-forward network.
        • Position-wise (or point-wise) means that the same transformation (with the same weights) is applied to each position in the sequence.
        • This can also be viewed as a convolutional layer with a filter size of 1.
      • Inspired by the ResNet design, a residual connection is employed around both sublayers.
      • The addition from the residual connection is followed by layer normalization [Ba et al., 2016].
      • ⇒ Encoder outputs a $d$-dimensional vector representation for each position of the input sequence.
        ⇒ Generates an attention-based representation with capability to locate a specific piece of information from a large context.
    • Decoder is similar to the encoder, except that the decoder contains two multi-head attention sublayers instead of one in each of the identical repeating layers.
      • The first multi-head attention sublayer is masked to prevent positions from attending to the future. Thus, each position in the decoder is only allowed to attend to positions up to and including that position.
        • This masked attention preserves the auto-regressive property, ensuring that the prediction only depends on those output tokens that have been generated.
      • The second multi-head attention sublayer is called encoder-decoder attention, where queries are from outputs of previous decoder layer, and keys and values are from transformer encoder outputs.
      • ⇒ Retrieves information from the encoded representation.

Positionwise Feed-Forward Network

  • Position-wise FFN transforms the representation at all the sequence positions using the same MLP.
      ⇒ This is why it is called position-wise.
    • Since the same MLP transforms at all the positions, when the inputs at all these positions are the same, their outputs are also identical, as the sketch below illustrates.
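A minimal sketch of a position-wise FFN; `nn.Linear` acts on the last dimension, so the same weights are applied at every position:

```python
import torch
from torch import nn

class PositionWiseFFN(nn.Module):
    """Two dense layers applied identically at every sequence position."""
    def __init__(self, num_hiddens, ffn_num_hiddens):
        super().__init__()
        self.dense1 = nn.Linear(num_hiddens, ffn_num_hiddens)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(ffn_num_hiddens, num_hiddens)

    def forward(self, X):
        # X: (batch, seq_len, num_hiddens)
        return self.dense2(self.relu(self.dense1(X)))

ffn = PositionWiseFFN(num_hiddens=8, ffn_num_hiddens=16)
X = torch.ones(1, 3, 8)    # identical inputs at all 3 positions
print(ffn(X)[0])           # the 3 output rows are identical as well
```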

Transformers for Vision

  • Transformer architecture was initially proposed for Seq2Seq learning, with a focus on machine translation.
      ⇒ Later on, it emerged as the model of choice in various natural language processing tasks.
      ⇒ In the field of computer vision the dominant architecture has remained the CNN.
      ⇒ Researchers started to wonder if it’s possible to do better by adapting Transformer to image data.
      ⇒ Transformers have also become a game-changer in computer vision.
  • Vision Transformers (ViTs) extract patches from images and feed them into a Transformer encoder to obtain a global representation, which will finally be transformed into the output label:
      The vision Transformer architecture. In this example, an image is split into nine patches. A special “<cls>” token and the nine flattened image patches are transformed via patch embedding and n Transformer encoder blocks into ten representations, respectively. The “<cls>” representation is further transformed into the output label.
    • Consider an input image with height $h$, width $w$, $c$ channels, and patch height and width both equal to $p$.
        ⇒ The image is split into a sequence of $m = hw/p^2$ patches (each patch is flattened to a vector of length $cp^2$).
        Image patches can be treated similarly to tokens in text sequences by Transformer encoders.

Patch Embedding

  • Splitting an image into patches and linearly projecting these flattened patches can be simplified as a single convolution operation, where both the kernel size and the stride size are set to the patch size:
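A minimal sketch of such a patch embedding via a single convolution (patch size and hidden size are illustrative):

```python
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """Split an image into patches and linearly project them, via one conv layer."""
    def __init__(self, patch_size=16, in_channels=3, num_hiddens=512):
        super().__init__()
        # Kernel size == stride == patch size: each output "pixel" is one patch embedding.
        self.conv = nn.Conv2d(in_channels, num_hiddens,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, X):
        # X: (batch, channels, height, width)
        # -> (batch, num_hiddens, h/p, w/p) -> (batch, num_patches, num_hiddens)
        return self.conv(X).flatten(2).transpose(1, 2)

emb = PatchEmbedding(patch_size=16, in_channels=3, num_hiddens=512)
print(emb(torch.rand(2, 3, 96, 96)).shape)   # torch.Size([2, 36, 512])
```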

ViT’s MLP

  • The MLP of vision Transformer encoder is slightly different from the position-wise FFN of original Transformer encoder.
    • Activation function uses GELU, which is a smoother version of ReLU.
    • Dropout is applied to the output of each fully connected layer in the MLP for regularization.
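A minimal sketch of this MLP with GELU and dropout (hyperparameter names are illustrative):

```python
import torch
from torch import nn

class ViTMLP(nn.Module):
    """MLP of a ViT encoder block: GELU activation, dropout after each dense layer."""
    def __init__(self, num_hiddens, mlp_num_hiddens, dropout=0.1):
        super().__init__()
        self.dense1 = nn.Linear(num_hiddens, mlp_num_hiddens)
        self.gelu = nn.GELU()
        self.dropout1 = nn.Dropout(dropout)
        self.dense2 = nn.Linear(mlp_num_hiddens, num_hiddens)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout2(self.dense2(self.dropout1(self.gelu(self.dense1(x)))))
```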

ViT’s “Add & Norm” Layer

  • Normalization is applied right before the multi-head attention or MLP (not after as in vanilla Transformer).
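A minimal sketch of a pre-normalization ViT encoder block built from PyTorch primitives (an illustrative stand-in for the `ViTBlock` mentioned in the summary below, not the exact class from any library):

```python
import torch
from torch import nn

class ViTBlock(nn.Module):
    """Encoder block with pre-normalization: LayerNorm before attention and before the MLP."""
    def __init__(self, num_hiddens, num_heads, mlp_num_hiddens, dropout=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(num_hiddens)
        self.attention = nn.MultiheadAttention(num_hiddens, num_heads,
                                               dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(num_hiddens)
        self.mlp = nn.Sequential(
            nn.Linear(num_hiddens, mlp_num_hiddens), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(mlp_num_hiddens, num_hiddens), nn.Dropout(dropout))

    def forward(self, x):
        h = self.ln1(x)
        x = x + self.attention(h, h, h, need_weights=False)[0]  # pre-norm + residual
        return x + self.mlp(self.ln2(x))                        # pre-norm + residual

blk = ViTBlock(num_hiddens=512, num_heads=8, mlp_num_hiddens=2048)
print(blk(torch.rand(2, 37, 512)).shape)   # torch.Size([2, 37, 512])
```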

Summary

  • Input image is fed into PatchEmbedding whose output is concatenated with <cls> token embedding.
      ⇒ It’s summed with learnable positional embeddings before dropout.
      ⇒ Then output is fed into Transformer encoder that stacks num_blks instances of ViTBlock class.
      ⇒ Finally, the representation of the <cls> token is projected by the network head.
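A compact sketch assembling these pieces into a ViT classifier; it reuses the illustrative `PatchEmbedding` and `ViTBlock` sketches above, and all sizes are illustrative:

```python
import torch
from torch import nn

class ViT(nn.Module):
    """Vision Transformer sketch built from the PatchEmbedding and ViTBlock sketches above."""
    def __init__(self, img_size=96, patch_size=16, num_hiddens=512, num_heads=8,
                 mlp_num_hiddens=2048, num_blks=2, num_classes=10, dropout=0.1):
        super().__init__()
        self.patch_embedding = PatchEmbedding(patch_size, 3, num_hiddens)
        num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, num_hiddens))
        # Learnable positional embeddings for <cls> + all patches
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, num_hiddens))
        self.dropout = nn.Dropout(dropout)
        self.blks = nn.Sequential(*[ViTBlock(num_hiddens, num_heads, mlp_num_hiddens, dropout)
                                    for _ in range(num_blks)])
        self.head = nn.Sequential(nn.LayerNorm(num_hiddens),
                                  nn.Linear(num_hiddens, num_classes))

    def forward(self, X):
        x = self.patch_embedding(X)                   # (batch, num_patches, num_hiddens)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls, x), dim=1)                # prepend the <cls> token
        x = self.dropout(x + self.pos_embedding)      # add learnable positional embeddings
        x = self.blks(x)                              # stack of num_blks encoder blocks
        return self.head(x[:, 0])                     # classify from the <cls> representation

vit = ViT()
print(vit(torch.rand(2, 3, 96, 96)).shape)   # torch.Size([2, 10])
```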
  • For small datasets like ImageNet (1.2M images), vision Transformer does not outperform the ResNet.
    • Transformers lack useful principles in convolution, such as translation invariance and locality.
    • ⇒ However, the picture changes when training larger models on larger datasets (e.g., 300M images).
      ⇒ Vision Transformers outperform ResNets by a large margin in image classification.
      ⇒ Superiority of Transformers in scalability.
  • Quadratic complexity of self-attention makes the architecture less suitable for higher-resolution images.
      ⇒ Swin Transformers addressed the quadratic computational complexity with respect to image size.
      ⇒ Reinstated convolution-like priors, extending the applicability of Transformers to a range of computer vision tasks beyond image classification with state-of-the-art results (Liu et al., 2021).

Transformer Optimizations

  • As we have seen, the Transformer architecture is heavily bottlenecked by the self-attention mechanism, which has quadratic time and memory complexity.
      ⇒ Various lines of work therefore try to improve the Transformer architecture and the attention operation.

KV Caching

  • Applicable only to auto-regressive inference, not training.
  • In self-attention, for each token in the input sequence, the model computes key and value vectors.
      ⇒ During generation, this computation is repeated unnecessarily for previously seen tokens.
      ⇒ KV caching solves this by storing the computed vectors for past tokens and reusing them in future steps.
      ⇒ Avoids redundant computations.
    • Keys (K) and Values (V) for past tokens are computed once and cached.
    • Only the new token's K and V are computed and appended to the cache.
    • Attention operates over the cached K and V, avoiding recomputation.
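A minimal single-head sketch of this caching pattern (an illustrative class, not a production implementation; real systems cache per layer and per head):

```python
import math
import torch
from torch import nn

class CachedSelfAttention(nn.Module):
    """Single-head causal self-attention with a KV cache for auto-regressive decoding."""
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.cache_k = None
        self.cache_v = None

    def forward(self, x_new):
        # x_new: (batch, 1, d_model), only the newly generated token
        q = self.W_q(x_new)
        k_new, v_new = self.W_k(x_new), self.W_v(x_new)
        # Append the new token's K/V to the cache instead of recomputing past ones
        self.cache_k = k_new if self.cache_k is None else torch.cat([self.cache_k, k_new], dim=1)
        self.cache_v = v_new if self.cache_v is None else torch.cat([self.cache_v, v_new], dim=1)
        scores = q @ self.cache_k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        return torch.softmax(scores, dim=-1) @ self.cache_v

attn = CachedSelfAttention(d_model=16)
for _ in range(5):                       # simulate generating 5 tokens one at a time
    out = attn(torch.rand(1, 1, 16))     # past K/V are reused from the cache
```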

Flash Attention

  • Attention algorithm used to scale transformer models more efficiently, enabling faster training and inference.
  • Standard attention mechanism relies on High Bandwidth Memory (HBM) to write/read keys/queries/values.
    • While HBM offers large capacity and high bandwidth, it has higher latency compared to on-chip memory like SRAM or Shared Memory.
    • ⇒ In the standard implementation, the cost of writing/reading keys/queries/values from HBM is high:
      1. Model loads this data from HBM into GPU on-chip Shared Memory.
      2. Performs a step of the attention mechanism.
      3. Writes the result back to HBM.
      4. Repeats this process for each attention step.
  • FlashAttention optimizes this workflow by loading keys, queries, and values into Shared Memory once, fusing the operations of the attention mechanism (such as softmax and matrix multiplication), and writing the result back to HBM:
      Left: FlashAttention uses tiling to prevent materialization of the large $N \times N$ attention matrix (dotted box) on (relatively) slow GPU HBM. In the outer loop (red arrows), FlashAttention loops through blocks of the $\mathbf{K}$ and $\mathbf{V}$ matrices and loads them to fast on-chip SRAM. In each block, FlashAttention loops over blocks of the $\mathbf{Q}$ matrix (blue arrows), loading them to SRAM, and writing the output of the attention computation back to HBM. Right: Speedup over the PyTorch implementation of attention on GPT-2. FlashAttention does not read and write the large $N \times N$ attention matrix to HBM, resulting in a 7.6× speedup on the attention computation.
      ⇒ Significantly reduces memory access overhead and improves performance.

Improved Attention Span

  • Goal: make the context that can be used in self-attention longer, more efficient, and more flexible.

Longer Attention Span: Transformer-XL

“XL” means “extra long”
  • The vanilla Transformer has a fixed and limited attention span.
      ⇒ The model can only attend to other elements in the same segments during each update step and no information can flow across separated fixed-length segments.
      ⇒ This context segmentation causes several issues:
    • The model cannot capture very long term dependencies.
    • It is hard to predict the first few tokens in each segment given no or thin context.
    • The evaluation is expensive: whenever the segment is shifted to the right by one, the new segment is re-processed from scratch, although there are a lot of overlapped tokens.
    • Transformer-XL solves the context segmentation problem with two main modifications:
      1. Reusing hidden states between segments.
      2. Adopting a new positional encoding that is suitable for reused states.
Hidden State Reuse
  • The recurrent connection between segments is introduced into the model by continuously using the hidden states from the previous segments:
      A comparison between the training phase of the vanilla Transformer & Transformer-XL with a segment length 4. (Image source: left part of Figure 2 in Dai et al., 2019).
    • Let's label the hidden state of the $n$-th layer for the $(\tau + 1)$-th segment in the model as $\mathbf{h}_{\tau+1}^{(n)}$.
        In addition to the hidden state of the last layer for the same segment, $\mathbf{h}_{\tau+1}^{(n-1)}$, it also depends on the hidden state of the same layer for the previous segment, $\mathbf{h}_{\tau}^{(n)}$.
        ⇒ By incorporating information from the previous hidden states, the model extends the attention span much longer in the past, over multiple segments:
        $$\tilde{\mathbf{h}}_{\tau+1}^{(n-1)} = \big[\text{stop-gradient}(\mathbf{h}_{\tau}^{(n-1)}) \circ \mathbf{h}_{\tau+1}^{(n-1)}\big]$$
        $$\mathbf{h}_{\tau+1}^{(n)} = \text{transformer-layer}\big(\mathbf{Q} = \mathbf{h}_{\tau+1}^{(n-1)}\mathbf{W}^q,\; \mathbf{K} = \tilde{\mathbf{h}}_{\tau+1}^{(n-1)}\mathbf{W}^k,\; \mathbf{V} = \tilde{\mathbf{h}}_{\tau+1}^{(n-1)}\mathbf{W}^v\big)$$
      • Note the difference between the plain hidden state $\mathbf{h}_{\tau+1}^{(n-1)}$ and the extended hidden state $\tilde{\mathbf{h}}_{\tau+1}^{(n-1)}$.
      • NB: both key and value rely on the extended hidden state, while the query only consumes the hidden state at the current step.
      • The concatenation operation $[\,\cdot \circ \cdot\,]$ is along the sequence length dimension.
Relative Positional Encoding
  • In order to work with this new form of attention span, Transformer-XL proposed a new type of positional encoding based on a reparametrization of the dot product of keys and queries.
    • Q: Why?
        A: If we used the same approach as the vanilla Transformer and encoded the absolute position, the previous and current segments would be assigned the same encoding, which is undesired.
    • ⇒ To keep the positional information flowing coherently across segments, Transformer-XL encodes the relative position instead, since knowing the position offset $i - j$ between one key vector $\mathbf{k}_{\tau, j}$ and its query $\mathbf{q}_{\tau, i}$ is sufficient for making good predictions.
      ⇒ If we omit the scalar $\frac{1}{\sqrt{d_k}}$ and the normalizing term in softmax but include positional encodings, we can write the attention score between the query at position $i$ and the key at position $j$ as:
      $$a_{ij} = \mathbf{q}_i \mathbf{k}_j^\top = (\mathbf{x}_i + \mathbf{p}_i)\mathbf{W}^q \big((\mathbf{x}_j + \mathbf{p}_j)\mathbf{W}^k\big)^\top = \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{x}_j^\top + \mathbf{p}_i\mathbf{W}^q {\mathbf{W}^k}^\top\mathbf{p}_j^\top$$
      ⇒ Transformer-XL reparameterizes the above four terms as follows:
      $$a_{ij}^{\text{rel}} = \mathbf{x}_i\mathbf{W}^q {\mathbf{W}_E^k}^\top\mathbf{x}_j^\top + \mathbf{x}_i\mathbf{W}^q {\mathbf{W}_R^k}^\top\mathbf{r}_{i-j}^\top + \mathbf{u}\,{\mathbf{W}_E^k}^\top\mathbf{x}_j^\top + \mathbf{v}\,{\mathbf{W}_R^k}^\top\mathbf{r}_{i-j}^\top$$
    • Replaces $\mathbf{p}_j$ with the relative positional encoding $\mathbf{r}_{i-j}$.
    • Replaces $\mathbf{p}_i\mathbf{W}^q$ with two trainable parameters, $\mathbf{u}$ (for content) and $\mathbf{v}$ (for location), in two different terms.
    • Splits $\mathbf{W}^k$ into two matrices, $\mathbf{W}_E^k$ for content information and $\mathbf{W}_R^k$ for location information.

Adaptive Attention Span

  • One key advantage of Transformer is the capability of capturing long-term dependencies.
      ⇒ Depending on the context, the model may prefer to attend further back at some times than at others; or one attention head may have a different attention pattern from another.
      ⇒ If the attention span could adapt its length flexibly and only attend further back when needed, it would help reduce both computation and memory cost to support longer maximum context size in the model.
      ⇒ This is the motivation for Adaptive Attention Span—a self-attention mechanism that seeks an optimal attention span.
  • The authors hypothesized that different attention heads might assign scores differently within the same context window:
      Two attention heads in the same model, A & B, assign attention differently within the same context window. Head A attends more to the recent tokens, while head B looks further back into the past uniformly. (Image source: Sukhbaatar, et al. 2019)
      ⇒ Optimal span should be trained separately per head.
  • Given the $i$-th token, we need to compute the attention weights between this token and other keys at positions $j \in [i - s, i)$, where the span $s$ defines the $i$-th token's context window:
      $$e_{ij} = \mathbf{q}_i \mathbf{k}_j^\top, \qquad a_{ij} = \mathrm{softmax}(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{r=i-s}^{i-1}\exp(e_{ir})}, \qquad \mathbf{y}_i = \sum_{r=i-s}^{i-1} a_{ir}\,\mathbf{v}_r$$
  • A soft mask function $m_z$ is added to control for an effective adjustable attention span, which maps the distance between query and key into a $[0, 1]$ value. $m_z$ is parameterized by $z \in [0, s]$, and $z$ is to be learned:
      $$m_z(x) = \mathrm{clip}\left(\frac{1}{R}(R + z - x),\ 0,\ 1\right)$$
      where $R$ is a hyper-parameter which defines the softness of $m_z$.
      The soft masking function used in the adaptive attention span. (Image source: Sukhbaatar, et al. 2019.)
      ⇒ The soft mask function is applied to the softmax elements in the attention weights:
      $$a_{ij} = \frac{m_z(i - j)\exp(e_{ij})}{\sum_{r=i-s}^{i-1} m_z(i - r)\exp(e_{ir})}$$
    • $m_z$ is differentiable, so it is trained jointly with other parts of the model.
    • Parameters $z_i$, $i = 1, \ldots, h$, are learned separately per head.
    • Moreover, the loss function has an extra L1 penalty on $\sum_{i=1}^{h} z_i$.
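A tiny sketch of the soft mask $m_z$ above (the values of `z` and `R` are illustrative):

```python
import torch

def soft_mask(distance, z, R=32):
    """m_z(x) = clip((R + z - x) / R, 0, 1): 1 inside the span, linearly decaying over R."""
    return torch.clamp((R + z - distance) / R, 0.0, 1.0)

d = torch.arange(0, 200).float()      # query-key distances
m = soft_mask(d, z=torch.tensor(100.0))
# Weights beyond distance z + R are exactly zero, so those keys can be skipped.
```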
  • In the experiments of Transformer with adaptive attention span, Sukhbaatar, et al. (2019) found a general tendency that lower layers do not require very long attention spans, while a few attention heads in higher layers may use exceptionally long spans.
  • Adaptive attention span also helps greatly reduce the number of FLOPS, especially in a big model with many attention layers and a large context length.

Sparse Attention Matrix Factorization: Sparse Transformers

  • Sparse Transformer introduced factorized self-attention, through sparse matrix factorization.
      ⇒ Makes it possible to train dense attention networks with hundreds of layers on sequences of length up to 16,384, which would be infeasible on modern hardware otherwise.
  • Define a set of attention connectivity patterns $\mathcal{S} = \{S_1, \ldots, S_n\}$, where each $S_i$ records the set of key positions that the $i$-th query vector attends to:
      $$\mathrm{Attend}(\mathbf{X}, \mathcal{S}) = \Big(a(\mathbf{x}_i, S_i)\Big)_{i \in \{1, \ldots, n\}}$$
      where
      $$a(\mathbf{x}_i, S_i) = \mathrm{softmax}\left(\frac{(\mathbf{x}_i\mathbf{W}^q)\big((\mathbf{x}_j\mathbf{W}^k)_{j \in S_i}\big)^\top}{\sqrt{d_k}}\right)(\mathbf{x}_j\mathbf{W}^v)_{j \in S_i}$$
    • NB: the size of $S_i$ is not fixed, while $a(\mathbf{x}_i, S_i)$ is always of size $d_v$, and thus $\mathrm{Attend}(\mathbf{X}, \mathcal{S}) \in \mathbb{R}^{n \times d_v}$.
  • In auto-regressive models, one attention span is defined as $S_i = \{j : j \le i\}$, as it allows each token to attend to all the positions in the past. In factorized self-attention, the set $S_i$ is decomposed into a tree of dependencies, such that for every pair of $(i, j)$ where $j \le i$, there is a path connecting $i$ back to $j$, and $i$ can attend to $j$ either directly or indirectly.
      ⇒ Precisely, the set $S_i$ is divided into $p$ non-overlapping subsets, where the $m$-th subset is denoted as $A_i^{(m)} \subset S_i$, $m = 1, \ldots, p$.
  • Sparse Transformer proposed two types of factorized attention (a small sketch at the end of this section illustrates both patterns):
      The top row illustrates the attention connectivity patterns in (a) Transformer, (b) Sparse Transformer with strided attention, and (c) Sparse Transformer with fixed attention. The bottom row contains corresponding self-attention connectivity matrices. Note that the top and bottom rows are not in the same scale. (Image source: Child et al., 2019 + a few of extra annotations.)
      1. Strided attention with stride $\ell \sim \sqrt{n}$:
          $$A_i^{(1)} = \{t, t+1, \ldots, i\} \ \text{with}\ t = \max(0, i - \ell), \qquad A_i^{(2)} = \{j : (i - j) \bmod \ell = 0\}$$
          • Works well with image data as the structure is aligned with strides. In the image case, each pixel would attend to all the previous $\ell$ pixels in the raster scanning order (naturally covering the entire width of the image) and then those pixels attend to others in the same column (defined by another attention connectivity subset).
      2. Fixed attention:
          $$A_i^{(1)} = \{j : \lfloor j/\ell \rfloor = \lfloor i/\ell \rfloor\}, \qquad A_i^{(2)} = \{j : j \bmod \ell \in \{\ell - c, \ldots, \ell - 1\}\}$$
          where $c$ is a hyperparameter (if $c = 1$, it restricts the representation, whereas many positions depend on a few positions; the paper chose $c \in \{8, 16, 32\}$ for $\ell \in \{128, 256\}$).
          • A small set of tokens summarizes previous locations and propagates that information to all future locations.
  • There are three ways to use sparse factorized attention patterns in the Transformer architecture:
      1. One attention type per residual block, and interleave them: $\mathrm{attn}(\mathbf{X}) = \mathrm{Attend}(\mathbf{X}, A^{(n \bmod p)})\,\mathbf{W}^o$, where $n$ is the index of the current residual block.
      2. Set up a single head which attends to locations that all the factorized heads attend to: $\mathrm{attn}(\mathbf{X}) = \mathrm{Attend}\big(\mathbf{X}, \cup_{m=1}^{p} A^{(m)}\big)\,\mathbf{W}^o$.
      3. Use a multi-head attention mechanism, but different from the vanilla Transformer, each head might adopt a pattern presented above, 1 or 2.
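To make the strided and fixed patterns concrete, here is a small sketch that enumerates the connectivity subsets $A_i^{(1)}$ and $A_i^{(2)}$ for a single query position (the stride and $c$ values are illustrative):

```python
def strided_pattern(i, stride):
    # A1: the previous `stride` positions up to i; A2: every stride-th position before i
    a1 = set(range(max(0, i - stride), i + 1))
    a2 = {j for j in range(i + 1) if (i - j) % stride == 0}
    return a1, a2

def fixed_pattern(i, stride, c):
    # A1: positions in the same block of width `stride`; A2: the last c positions of each block
    a1 = {j for j in range(i + 1) if j // stride == i // stride}
    a2 = {j for j in range(i + 1) if j % stride in range(stride - c, stride)}
    return a1, a2

print(strided_pattern(i=20, stride=4))
print(fixed_pattern(i=20, stride=4, c=1))
```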

Locality-Sensitive Hashing: Reformer

  • Reformer proposed two main changes:
      1. Replace the dot-product attention with locality-sensitive hashing (LSH) attention, reducing the complexity from $\mathcal{O}(L^2)$ to $\mathcal{O}(L \log L)$.
      2. Replace the standard residual blocks with reversible residual layers, which allows storing activations only once during training instead of $N$ times (i.e. proportional to the number of layers).
      The LSH attention consists of 4 steps: bucketing, sorting, chunking, and attention computation. (Image source: left part of Figure 1 in Kitaev, et al. 2020).

Sources