A review of different Transformers-based algorithms and methods

Published February 15, 2023

In this blog post I review some of the research on Transformer models. The article is written at a high level; for more technical details, please refer to the references listed at the end.

Recap: Transformers

First, I briefly explain what Transformers are and which tasks they typically aim to solve. Then I go through their main architectural components and fundamental architectures. This brief recap is largely based on the wonderful formal presentation of Transformers by Phuong & Hutter (2022) [1].

Problems that Transformers aim to solve
  • Transformers model sequence data. They are most commonly used for sequence modeling (e.g., language modeling), sequence-to-sequence prediction (e.g., text-to-speech), and classification problems (e.g., sentiment analysis, image classification).
Tokenization
  • A token is the atomic unit of a sequence. It is a single element of the vocabulary from which the sequences a Transformer operates on are composed.
  • Tokenization can be done at different levels, e.g., character-level or subword-level.
  • Tokenization means transforming a sequence, at one of the levels above, into a sequence of integers that index elements of the vocabulary.
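
The following minimal character-level tokenizer in Python makes tokenization concrete. The function names (`build_vocab`, `encode`, `decode`) are hypothetical and do not belong to any library. Real models usually rely on subword tokenizers such as byte-pair encoding, and this sketch does not implement those.

```python
def build_vocab(corpus):
    """Map every distinct character in the corpus to an integer id."""
    chars = sorted(set("".join(corpus)))
    return {ch: i for i, ch in enumerate(chars)}

def encode(text, vocab):
    """Turn a string into the sequence of integer ids of its characters."""
    return [vocab[ch] for ch in text]

def decode(ids, vocab):
    """Invert encode() using the reverse mapping id -> character."""
    inverse = {i: ch for ch, i in vocab.items()}
    return "".join(inverse[i] for i in ids)

vocab = build_vocab(["hello world"])  # 8 distinct characters, including the space
ids = encode("hello", vocab)
print(ids)  # [3, 2, 4, 4, 5]
```
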
Architectural components specific to Transformers
  • token embedding: a function that learns to map vocabulary elements into d-dimensional real-valued vectors
  • positional embedding: a function that maps the position of a token in the sequence to a d-dimensional vector. The positional embedding is typically added to the token embedding, and the sum forms the initial token representation.
    • Positional embeddings can encode the absolute position of a token (e.g., sinusoidal, rotation matrix) or its position relative to the other context tokens. The embedding can be shared across layers or learned independently for each (encoder/decoder) layer.
    • Without positional information, self-attention is order-agnostic (permutation-equivariant). The model then cannot distinguish between two sequences that contain exactly the same tokens in a different order.
    • If the positional embedding is learned, it is only defined up to a fixed maximum length, so inputs cannot be longer than that. If it is not learned (e.g., a sinusoidal positional embedding), it can also be computed for inputs longer than those seen at training time.
  • attention: roughly speaking, this is the mechanism by which the Transformer model predicts the probability of a token given its context. It learns which tokens in that context to pay more attention to, based on the similarity between the token to be predicted and all others.
    • Each token is mapped to a query vector, a key vector and a value vector. The query of the token being processed is compared with the keys of all in-context tokens, and their values are combined according to these similarities. A good analogy is provided by [2]: querying a database for approximate answers. The input is a query, and the database has (key, value) pairs. Querying the database means retrieving the values whose keys are most similar to the query. Attention is a soft version of this lookup. Rather than returning a single value, it returns an average of all values, weighted by a softmax over the query-key similarities, so the most relevant keys receive the most attention.
  • masking: a mechanism that restricts the set of tokens each token may attend to. When no masking is applied, every token can attend to all other tokens in the sequence, so we speak of bidirectional encoding. When tokens may attend only to past tokens and not to future ones, we speak of unidirectional encoding, and the future tokens are masked out.
    • Masking is task-dependent. For some tasks, e.g., autoregressive generation, peeking into the future of the sequence is simply not permitted.
  • cross-attention: this mechanism is typically applied in Encoder-Decoder architectures (cf. next section), e.g., for machine translation, where we have a pair of source- and target-language sequences. The encoder applies self-attention to the source-language sequence. Its output is then used in the decoder for cross-attention, where the queries come from the decoder's representation of the target-language sequence and the keys and values come from the encoded source-language sequence.
  • multi-head attention: attention, as described above, is typically applied several times independently and in parallel. Each application works on a different learned, lower-dimensional projection of the token representations of the whole sequence. The outputs are then concatenated and linearly projected. This is achieved with multiple attention heads, each with its own set of randomly initialized parameters. In this way the heads can capture different kinds of variability in the sequence.
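
The components above can be sketched in plain Python without external libraries. This is an illustrative sketch under simplifying assumptions. It uses the standard sinusoidal positional encoding and scaled dot-product attention with a $1/\sqrt{d_k}$ scaling. The projection matrices are random rather than learned, and layer normalization, residual connections and feed-forward layers are left out. All function names are hypothetical.

```python
import math
import random

def sinusoidal_positions(seq_len, d):
    """Fixed sinusoidal positional encodings (d assumed even): sin/cos pairs at decreasing frequencies."""
    pe = []
    for pos in range(seq_len):
        row = []
        for i in range(0, d, 2):
            angle = pos / (10000 ** (i / d))
            row += [math.sin(angle), math.cos(angle)]
        pe.append(row)
    return pe

def matmul(a, b):
    """Matrix product of two lists of lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def softmax(row):
    top = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - top) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(Q, K, V, causal=False):
    """Scaled dot-product attention softmax(Q K^T / sqrt(d_k)) V.
    With causal=True, query i may only attend to keys j <= i (future tokens masked out)."""
    d_k = len(K[0])
    scores = [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]  # Q K^T
    weights = []
    for i, row in enumerate(scores):
        masked = [s / math.sqrt(d_k) if (not causal or j <= i) else -math.inf
                  for j, s in enumerate(row)]
        weights.append(softmax(masked))
    return matmul(weights, V)

def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 1.0 / math.sqrt(rows)) for _ in range(cols)] for _ in range(rows)]

def multi_head_attention(X, n_heads, rng, causal=False):
    """Each head gets its own randomly initialised d x (d / n_heads) projections for
    queries, keys and values, applied to the whole sequence; the head outputs are
    concatenated and mixed by an output projection."""
    d = len(X[0])
    d_h = d // n_heads
    heads = []
    for _ in range(n_heads):
        W_q, W_k, W_v = (random_matrix(d, d_h, rng) for _ in range(3))
        heads.append(attention(matmul(X, W_q), matmul(X, W_k), matmul(X, W_v), causal))
    concat = [sum((head[t] for head in heads), []) for t in range(len(X))]
    return matmul(concat, random_matrix(n_heads * d_h, d, rng))

rng = random.Random(0)
tokens = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(5)]  # 5 token embeddings, d = 8
X = [[t + p for t, p in zip(tok, pos)]                                # add positional embeddings
     for tok, pos in zip(tokens, sinusoidal_positions(5, 8))]
Y = multi_head_attention(X, n_heads=2, rng=rng, causal=True)
print(len(Y), len(Y[0]))  # 5 8
```

If you drop the positional embeddings and permute the rows of the input to `attention`, the rows of its output are permuted in the same way. This is the order-agnosticism mentioned above.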
Transformer architectures
  • Encoder-Decoder

    • Typically used for sequence-to-sequence modeling, e.g., machine translation. In this example we have a source-language sequence and a target-language sequence. First, the encoder encodes the source-language sequence with bidirectional self-attention, allowing each token to attend to every other token in that sequence. Then the decoder encodes the target-language sequence with unidirectional self-attention, using causal masking so that each target token only attends to earlier target tokens. Finally, the decoder additionally applies (unmasked) cross-attention from the target-language tokens to the encoded source-language sequence to produce, in this case, a translation from the source to the target language.
  • Encoder-Only (e.g., BERT)

    • BERT is an instantiation of an encoder-only Transformer. It was developed for learning robust text representations that can easily be applied to new downstream tasks without modifying the backbone architecture. During training, some tokens are replaced by a special mask token with some probability, and the model has to learn to reconstruct the original tokens from the context around them.
  • Decoder-Only (e.g., GPT-2)

    • GPT-2 is an instantiation of this Transformer architecture. It is an autoregressive model, i.e., it predicts the next token from a partial sequence, so it uses unidirectional attention: for each query token, all tokens that come after it are masked out.
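
The two pre-training objectives can be contrasted with a small sketch of how training examples are built. The masking below is a simplification. The original BERT recipe selects about 15% of the tokens and replaces most, but not all, of them with the mask token; some are left unchanged or replaced by a random token. `MASK_ID` and the function names are hypothetical.

```python
import random

MASK_ID = 0  # hypothetical id of the special [MASK] token

def mask_tokens(ids, rng, p=0.15):
    """Simplified BERT-style corruption: each token is replaced by MASK_ID with
    probability p; labels keep the original id at masked positions and None elsewhere."""
    inputs, labels = [], []
    for tok in ids:
        if rng.random() < p:
            inputs.append(MASK_ID)
            labels.append(tok)   # the model must reconstruct this token
        else:
            inputs.append(tok)
            labels.append(None)  # no loss at unmasked positions
    return inputs, labels

def next_token_pairs(ids):
    """GPT-style (decoder-only) training pairs: predict ids[t] from the prefix ids[:t]."""
    return [(ids[:t], ids[t]) for t in range(1, len(ids))]

rng = random.Random(42)
seq = [5, 8, 13, 21, 34]
inputs, labels = mask_tokens(seq, rng, p=0.5)
print(inputs, labels)
print(next_token_pairs([5, 8, 13]))  # [([5], 8), ([5, 8], 13)]
```

The encoder sees the whole corrupted sequence at once. The decoder-only model only ever conditions on a prefix, which is why it needs the causal mask.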

Various Transformer-based architectures and methods

This section is based on Lil'Log's blog post [3]. I also read through many of the papers referenced in that post.

Reducing the quadratic dependency on context length

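  • The quadratic cost affects both training and inference. Since the context length usable at inference typically matches the one used during training, it limits how much context the model can use.
  • Self-attention over a context of length $L$ with hidden dimension $D$ costs $O(L^2 D)$ time and $O(L^2)$ memory
  • Overcoming attention that is limited to a single segment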
    • Transformer-XL
      • reuses the hidden states of the previous segment as a context memory, so that attention can capture dependencies across segments (inter-segment)
    • Compressive Transformer
      • instead of discarding the oldest memories, compresses them into a shorter compressed memory, extending the effective context
      • uses two auxiliary losses to train the compression
        • Attention-reconstruction loss: how well the attention over the original memories can be reconstructed from the attention over the compressed memories
        • Auto-encoding loss: how well the original memories can be reconstructed from the compressed ones
    • Non-Differentiable External Memory
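      • in addition to the LM's next-token distribution, a datastore of (key, value) = (LM embedding of the context, next token) pairs is built. At prediction time, a kNN search over the keys yields a second next-token distribution, which is interpolated with the LM's (kNN-LM)
      • the external memory can be huge, since it is not part of the trained parameters
      • nearest-neighbour retrieval uses fast (approximate) search libraries such as FAISS
      • building the index is costly but happens only once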
    • SPALM
      • kNN-LM + Transformer-XL (balancing short-term memory and long-term memory)
    • Distance-Enhanced Attention Scores
      • DA-Transformer
        • keeps the complexity of the vanilla Transformer
        • uses different relative position weights for each head
        • learnable weights on the relative distance between tokens regulate the attention weights
          • for example, if the weight for a long distance is large, then the attention score for that (q, k) pair is amplified
        • since it relies on relative distances, it can extrapolate to sequences longer than those seen during training
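
The kNN-LM idea can be sketched as follows, assuming the standard kNN-LM formulation. The kNN distribution is a softmax over negative distances to the k nearest keys, aggregated per next token, and it is mixed with the LM distribution by a fixed weight `lam`. Brute-force search replaces FAISS here. The toy embeddings, tokens and probabilities are made up for illustration, and the function names are hypothetical.

```python
import math
from collections import defaultdict

def knn_distribution(query, keys, values, k):
    """p_kNN(y | x): softmax over negative squared L2 distances to the k nearest
    datastore keys, with the probability mass summed per next token."""
    # Brute-force search stands in for an (approximate) index such as FAISS.
    nearest = sorted(
        (sum((a - b) ** 2 for a, b in zip(query, key)), token)
        for key, token in zip(keys, values)
    )[:k]
    weights = [math.exp(-dist) for dist, _ in nearest]
    total = sum(weights)
    p = defaultdict(float)
    for w, (_, token) in zip(weights, nearest):
        p[token] += w / total
    return dict(p)

def interpolate(p_lm, p_knn, lam):
    """p(y | x) = lam * p_kNN(y | x) + (1 - lam) * p_LM(y | x)."""
    vocab = set(p_lm) | set(p_knn)
    return {y: lam * p_knn.get(y, 0.0) + (1 - lam) * p_lm.get(y, 0.0) for y in vocab}

# Datastore of (LM embedding of the context, next token) pairs; toy values.
keys = [[0.0, 0.0], [0.1, 0.0], [5.0, 5.0]]
values = ["cat", "cat", "dog"]
p_knn = knn_distribution([0.0, 0.1], keys, values, k=2)
p = interpolate({"cat": 0.2, "dog": 0.8}, p_knn, lam=0.25)
print(round(p["cat"], 3), round(p["dog"], 3))  # 0.4 0.6
```
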

Recurrent modules

  • Universal Transformers
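    • the hidden state is updated recurrently, following equations (4) and (5) of the paper: $A^t = \mathrm{LayerNorm}\big((H^{t-1} + P^t) + \mathrm{MultiHeadSelfAttention}(H^{t-1} + P^t)\big)$ and $H^t = \mathrm{LayerNorm}\big(A^t + \mathrm{Transition}(A^t)\big)$, where $P^t$ is a position and timestep embedding
    • i.e., the hidden state is computed from self-attention on the previous hidden state, followed by a transition function (either a feed-forward network or a convolutional layer, depending on the problem) applied to the attention-layer output
    • the number of recurrent steps is determined dynamically for each position using Adaptive Computation Time (ACT) [4]
    • hyperparameters: a threshold for the halting probability and a maximum number of timesteps
    • the halting probability is computed from the hidden state by a dense layer with a sigmoid activation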
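
The halting mechanism can be sketched for a single position as below. As an assumption, this follows Graves-style ACT in simplified form. `step_fn` stands in for one Universal Transformer step and `halt_fn` for the sigmoid halting layer. The output is the halting-weighted average of the intermediate states, and the leftover probability mass (the remainder) goes to the last step so the weights sum to one. The function names and the toy step and halting functions are hypothetical.

```python
def act(step_fn, halt_fn, h, threshold=0.99, max_steps=8):
    """Simplified Adaptive Computation Time for one position.
    step_fn: h -> next hidden state (stands in for one recurrent Transformer step)
    halt_fn: h -> halting probability in (0, 1) (a sigmoid dense layer in the UT)
    Returns the halting-weighted average of the visited states and the number of steps."""
    cumulative = 0.0
    output = [0.0] * len(h)
    for n in range(1, max_steps + 1):
        h = step_fn(h)
        p = halt_fn(h)
        halt = cumulative + p >= threshold or n == max_steps
        weight = 1.0 - cumulative if halt else p  # the last step gets the remainder
        output = [o + weight * x for o, x in zip(output, h)]
        cumulative += weight
        if halt:
            return output, n

# Toy run: each step adds 1 to the state, halting probability is a constant 0.3.
output, steps = act(lambda h: [x + 1.0 for x in h], lambda h: 0.3, [0.0])
print(steps)  # 4
```

With a constant halting probability of 0.3, the cumulative probability crosses the 0.99 threshold on the fourth step, so the loop stops there. `max_steps` caps the computation if the threshold is never reached.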

Adaptive Modeling

  • the quadratic time complexity is particularly problematic for character-level modeling, where the context spans thousands of timesteps
    • possibly other long-sequence problems, such as protein folding, are affected as well
  • Adaptive Attention Span
    • some evidence suggests that attention heads in the lower layers of the network attend to short contexts, while heads in the upper layers attend to longer contexts
    • unlike the standard Transformer, where every head shares the same attention span S, the objective here is to learn the attention span independently for each head
    • the attention span range is pre-determined by a minimum range (32) and a maximum range (S = 4096)
    • Dynamic adaptive span: the span depends on the current input, i.e., $z_t = S\,\sigma(v^\top x_t + b)$
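
The adaptive span is usually implemented with a soft mask that decays linearly with distance, which keeps the span $z$ differentiable. The sketch below assumes the ramp form $m_z(x) = \min(\max((R + z - x)/R, 0), 1)$, with $R$ a softness hyperparameter. Reading the minimum range of 32 above as $R = 32$ is my assumption: with $z = 0$ the mask still covers the first $R$ positions. The function names are hypothetical.

```python
import math

def soft_span_mask(distance, z, R=32):
    """m_z(x) = clamp((R + z - x) / R, 0, 1): 1 inside the span z, then a linear ramp of length R."""
    return min(max((R + z - distance) / R, 0.0), 1.0)

def span_attention_weights(scores, z, R=32):
    """Attention over past keys; scores[x] is the raw score of the key at distance x.
    a_x = m_z(x) exp(s_x) / sum_x' m_z(x') exp(s_x')."""
    top = max(scores)  # subtract the max for numerical stability
    masked = [soft_span_mask(x, z, R) * math.exp(s - top) for x, s in enumerate(scores)]
    total = sum(masked)
    return [m / total for m in masked]

def dynamic_span(x_t, v, b, S=4096):
    """Dynamic variant: the span is predicted from the input, z_t = S * sigmoid(v . x_t + b)."""
    return S / (1.0 + math.exp(-(sum(vi * xi for vi, xi in zip(v, x_t)) + b)))

weights = span_attention_weights([0.0] * 100, z=10)  # equal scores, span z = 10
print(sum(w > 0 for w in weights))  # 42 keys (distances 0..41) get non-zero weight
```
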

Depth-Adaptive Transformer

  • Performs a different amount of computation depending on how difficult it is to make a prediction for a given input sequence
  • The model was evaluated on large-scale machine translation (source language to target language)
  • Output classifiers, either parameterized independently or sharing weights, are attached to the output of each decoder block
    • Training multiple output classifiers:
      • The authors experimented with the following two approaches for training the classifiers to output the next token given the hidden state of their block. The corresponding loss is called the decoder loss.
      • Aligned training
        • All classifiers are trained simultaneously. This assumes that all hidden states needed by each classifier at its decoder depth are available. (The paper does not explain how this simultaneous computation is performed. I can only speculate that a forward pass of the whole input sequence up to each decoder depth is performed, after which the backward passes for all classifiers can be run simultaneously.)
        • At test time, however, hidden states may be missing because the model exited early at a previous timestep. In that case, the missing hidden states are copied from lower layers.
      • Mixed training
        • To avoid the mismatch between training and testing introduced by aligned training, the authors experimented with a training procedure that treats both phases the same way. They propose to sample M sets of exits, with one exit per timestep of the input sequence, i.e., at each timestep the model exits at a random decoder depth. To ensure that all required hidden states are available, states missing at the current depth are copied from the decoder layers below.
    • Adaptive Depth Estimation
      • The depth estimation itself, i.e., deciding where to stop and output the next token, is learned via a second loss called the exit loss.
      • Sequence-specific depth: in this approach the model learns to exit at the same depth for all tokens in the sequence.
        • The exit loss here is defined as the cross-entropy between a time-independent parametric distribution $q$, which takes the encoder output as input, and an oracle distribution $q^{*}$, also time-independent, which takes the encoder output and the target sequence $y$ as input. Both are categorical distributions over the depth $n$ at which to exit.
          • The oracle distribution, i.e., the target for the parametric distribution, was chosen by the authors to be one of:
            • Sequence likelihood-based: the block for which the likelihood over the entire sequence is maximal is selected.
              • Optionally, one can add a regularization term on the depth $n$ to favor maximum likelihood at lower depths.
            • Correctness-based: the block with the largest number of correctly predicted tokens is selected. As with the oracle above, a regularization term on the depth can be used to prefer lower blocks.
      • Token-specific depth: this approach selects a different exit block for each token in the sequence. In one method, a distribution over all blocks is predicted from the first decoder hidden state, and the most likely exit is selected. In a second method, a halting probability $\chi_t^n$ is computed from the activations of each block $n$. For every block $n < N$, the probability of exiting there is the probability of stopping at block $n$ multiplied by the probability of not having stopped at any lower block: $q_t(n) = \chi_t^n \prod_{n' < n} (1 - \chi_t^{n'})$. The final block has no decoder block following it, so its probability of stopping is one, and it only receives the remaining mass $q_t(N) = \prod_{n' < N} (1 - \chi_t^{n'})$.
        • Both methods need oracles. An oracle can be likelihood-based, i.e., at each timestep it favors the block whose exit classifier assigns the highest likelihood to the correct token.
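
The exit distribution of the second token-specific method, together with a likelihood-based oracle, can be sketched as follows. The helper names and the toy probabilities are hypothetical. Treating the oracle as one-hot, so that the exit loss reduces to $-\log q_t(n^*)$, is a simplifying assumption.

```python
import math

def exit_distribution(halt_probs):
    """Token-specific exit distribution over N decoder blocks.
    halt_probs holds chi_1 .. chi_{N-1}, the probabilities of stopping at blocks 1..N-1.
    q(n) = chi_n * prod_{n'<n} (1 - chi_n') for n < N and q(N) = prod_{n'<N} (1 - chi_n')."""
    q, not_stopped = [], 1.0
    for chi in halt_probs:
        q.append(chi * not_stopped)
        not_stopped *= 1.0 - chi
    q.append(not_stopped)  # the last block always stops, so it gets the remaining mass
    return q

def likelihood_oracle(correct_token_likelihoods):
    """Index of the block whose exit classifier gives the correct token the highest likelihood."""
    return max(range(len(correct_token_likelihoods)), key=correct_token_likelihoods.__getitem__)

q = exit_distribution([0.5, 0.5])              # N = 3 blocks (0-based indices below)
target = likelihood_oracle([0.1, 0.7, 0.6])    # index 1, i.e., the second block
exit_loss = -math.log(q[target])               # cross-entropy with a one-hot oracle
print(q, target)  # [0.5, 0.25, 0.25] 1
```
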

Flash Attention

  • Computes standard attention exactly (no approximation), but processes it in blocks and keeps a few extra statistics, namely the running maximum and the softmax normalizer. This minimizes the number of reads/writes to/from the GPU's high-bandwidth memory (HBM). Because the full $L \times L$ attention matrix is never materialized, the extra memory required goes from quadratic to linear in the sequence length.
  • The authors also provide a FlashAttention implementation for block-sparse attention (where some of the pairwise attention values are zero) that is more efficient than FlashAttention by a factor proportional to the sparsity.
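
The trick that lets FlashAttention compute exact attention without storing the full score matrix is an online (streaming) softmax. The sketch below shows it for a single query in plain Python. It only illustrates the arithmetic and does not model the GPU memory hierarchy, the tiling over queries, or the backward pass of the actual implementation. The function names are hypothetical.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def standard_attention(q, K, V):
    """Reference: materialise all scores for the query, then softmax and average the values."""
    scores = [dot(q, k) / math.sqrt(len(q)) for k in K]
    top = max(scores)
    w = [math.exp(s - top) for s in scores]
    total = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, V)) / total for j in range(len(V[0]))]

def tiled_attention(q, K, V, block=2):
    """Same result, visiting keys/values one block at a time and keeping only running
    statistics: the max score m, the softmax normaliser l and an unnormalised output acc."""
    m, l = -math.inf, 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block):
        K_b, V_b = K[start:start + block], V[start:start + block]
        s = [dot(q, k) / math.sqrt(len(q)) for k in K_b]
        m_new = max(m, max(s))
        scale = math.exp(m - m_new)  # rescale earlier statistics to the new maximum
        p = [math.exp(x - m_new) for x in s]
        l = l * scale + sum(p)
        acc = [a * scale + sum(pi * v[j] for pi, v in zip(p, V_b)) for j, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

rng = random.Random(1)
q = [rng.gauss(0, 1) for _ in range(4)]
K = [[rng.gauss(0, 1) for _ in range(4)] for _ in range(7)]
V = [[rng.gauss(0, 1) for _ in range(3)] for _ in range(7)]
exact, tiled = standard_attention(q, K, V), tiled_attention(q, K, V, block=3)
print(max(abs(a - b) for a, b in zip(exact, tiled)))  # tiny: floating-point rounding only
```
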

  1. Phuong, M., & Hutter, M. (2022). Formal algorithms for transformers. arXiv preprint arXiv:2207.09238.

  2. Zhang, A., Lipton, Z. C., Li, M., & Smola, A. J. (2021). Dive into deep learning. arXiv preprint arXiv:2106.11342.

  3. Weng, L. (2020, April). The transformer family. Lil'Log. https://lilianweng.github.io/posts/2020-04-07-the-transformer-family/

  4. Graves, A. (2016). Adaptive computation time for recurrent neural networks. arXiv preprint arXiv:1603.08983.