Self-Attention Explained
Motivation
Transformers are everywhere these days, and self-attention is the core mechanism of the Transformer model. It's a bit confusing at first, but if you build it from the ground up with some mathematical intuitions, it's not so bad!
The Problem
We have a sequence of vectors (representing words, pixels, or anything else), and we want to process it into a new sequence (of the same dimensions) in a way that considers the meaning of the whole sequence, in order to complete a downstream task (whether that's classifying the sentiment of the sentence, translating it to another language, or predicting the next word). This means considering complex interdependencies: for example, in the sentence "Alice wrote a note and gave it to Bob," it is only by considering context that you can figure out that "it" refers to the note.
The old way of doing this was with RNNs, which processed the words one at a time, using the hidden state from the previous word to inform the current one. This is a good idea, but it has some problems: it's slow, and it doesn't scale well to long sequences, because of vanishing/exploding gradients and difficulty capturing long-term dependencies. It did, however, successfully make the processing of each word aware of the other words, since each previous word modified the hidden state, saving information that could be used to process the next word.
The Idea of Attention
Desiderata:
- Process all the tokens in parallel, so we don't have to wait for each one before providing the next one. We're using GPUs!
- Allow the tokens to interact with each other.
Self-attention is the thing that results from wanting both of these. It provides a way to process all the tokens in the sequence in parallel, while still allowing them to interact with each other in the computation, so that dependencies can be captured. It's basically performing a weighted average at each position, weighted by "relevance" (more on that later). Another way to think of what's going on is that information is being "copied" from one position to another, in proportion to how relevant the source position is to the target.
Dumb Idea: Average the Whole Sequence
This does capture meaning from all the parts! But it "flattens" everything together, which is too crude for what we want. We have to keep the words separate, especially if we're trying to do things like predict the next word or translate the sentence (order and distinct words matter a lot!).
Better Idea: Position-Wise Weighted Average
Imagine we have an oracle that can tell us which other words in the sentence are relevant to each word. We could use those relevance scores as weights and take a weighted average at each position. This is close, but it's not quite what we want: we don't have such an oracle, and we'd rather learn which words are relevant to each word (and hence the weights) than specify them by hand.
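To make this concrete, here is a minimal sketch (in PyTorch, which the later examples also use). The relevance weights here are simply made up by hand, standing in for the oracle; in real self-attention they will be computed from the data and learned.

import torch

# A toy sequence of 3 tokens, each represented by a 4-dimensional vector.
X = torch.randn(3, 4)

# Hypothetical "oracle" relevance weights: row i says how relevant each token
# is to token i. Each row sums to 1.
W = torch.tensor([[0.8, 0.1, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.3, 0.3, 0.4]])

# Position-wise weighted average: output i mixes all tokens, weighted by
# their relevance to token i. The result has the same shape as the input.
output = W @ X   # shape (3, 4)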
Mathematical Intuition: Non-Local Means
Unlike "local mean" filters, which take the mean value of a group of pixels surrounding a target pixel to smooth the image, non-local means filtering takes a mean of all pixels in the image, weighted by how similar these pixels are to the target pixel. This results in much greater post-filtering clarity, and less loss of detail in the image compared with local mean algorithms. If compared with other well-known denoising techniques, non-local means adds "method noise" (i.e. error in the denoising process) which looks more like white noise, which is desirable because it is typically less disturbing in the denoised product. (Source: Wikipedia)
Mathematical Intuition: Kernel Regression
Kernel regression estimates the conditional mean at a point $x$ as a weighted average of the observed targets: $\hat{y}(x) = \sum_i \frac{K(x, x_i)}{\sum_j K(x, x_j)} \, y_i$, where the kernel $K$ measures how similar the query point $x$ is to each observed input $x_i$.
Attention outputs a value at position $i$ in exactly the same spirit: $\mathrm{output}_i = \sum_j \mathrm{softmax}_j(x_i \cdot x_j)\, x_j$, a weighted average of the sequence elements, with the weights given by a normalized similarity between position $i$ and every other position.
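With a Gaussian kernel, the normalized kernel weights are literally a softmax over (negative, scaled) squared distances, which makes the parallel easy to see in code. The function name and bandwidth h here are illustrative.

import torch
import torch.nn.functional as F

def kernel_regression(x, xs, ys, h=1.0):
    # Nadaraya-Watson estimate of y at query point x: a weighted average of the
    # observed ys, weighted by how close each observed x_i is to x.
    scores = -(x - xs) ** 2 / (2 * h ** 2)   # Gaussian kernel, as unnormalized log-weights
    weights = F.softmax(scores, dim=0)       # normalize, just like attention weights
    return (weights * ys).sum()              # weighted average, just like the attention output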
Dot-Product Attention
We don't have an oracle. What can we do instead?
Dot Product as a Measure of Similarity
Basically, when two vectors point in similar directions, their dot product is large and positive. When they are orthogonal, their dot product is zero. When they point in opposite directions, their dot product is negative. This makes it a convenient measure of similarity: it's cheap to compute and easy to interpret.
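A quick numerical check (the vectors are chosen by hand for illustration):

import torch

a = torch.tensor([1.0, 2.0, 0.0])
b = torch.tensor([2.0, 4.0, 0.0])    # points the same way as a
c = torch.tensor([-2.0, 1.0, 0.0])   # orthogonal to a
d = torch.tensor([-1.0, -2.0, 0.0])  # points opposite to a

print(a @ b)   # tensor(10.)  -> large and positive: similar
print(a @ c)   # tensor(0.)   -> zero: orthogonal
print(a @ d)   # tensor(-5.)  -> negative: opposite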
The Attention Matrix
The attention matrix is a matrix of size $n \times n$, where $n$ is the length of the sequence: its entry $(i, j)$ is the dot product $x_i \cdot x_j$ between the vectors at positions $i$ and $j$. Applying a softmax to each row turns these raw similarity scores into weights that are positive and sum to 1.
The Whole Thing
import torch.nn.functional as F

def DotProductAttention(X):
    # X: (seq_len, d_model), one vector per token
    attn_matrix = X @ X.T                          # pairwise dot products, (seq_len, seq_len)
    attn_weights = F.softmax(attn_matrix, dim=1)   # normalize each row to sum to 1
    attn_output = attn_weights @ X                 # weighted average of the token vectors
    return attn_output
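As a quick sanity check (the sequence length and dimension here are arbitrary):

import torch

X = torch.randn(5, 16)                 # a toy sequence: 5 tokens, 16-dimensional vectors
print(DotProductAttention(X).shape)    # torch.Size([5, 16]) -- same shape as the input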
Why This Isn't Quite Enough
There's no reason to think that the vectors for "note" and "it" would happen to be similar, though. We need to learn what counts as relevant, not just rely on raw dot products between the inputs.
Queries, Keys, and Values
One thing you might think is missing from the previous exposition is the opportunity for learning. As laid out above, self-attention is a deterministic (though complicated!) function of the input sequence $X$; there are no parameters to train.
The introduction of queries, keys, and values adds parameters to the attention layer, allowing it to be more flexible and to learn as the neural network trains. To understand the motivation for queries, keys, and values, recognize that in a self-attention layer, each input element $x_i$ plays three distinct roles:
- It acts as the query: the $x_i$ that all elements in the sequence are compared to in order to assess their similarity, producing the attention weights $\mathrm{softmax}_j(x_i \cdot x_j)$.
- It acts as a key: one of the elements $x_j$ whose similarity to a query $x_k$ is computed, giving the attention weight between itself and $x_k$.
- It acts as the value that gets averaged, i.e. the $x_j$ in the weighted sum $\sum_j \mathrm{softmax}_j(x_i \cdot x_j)\, x_j$.
It makes sense in the self-attention setting to allow these three roles to be played by three different vectors. This introduces more flexibility into the model and, as described above, allows the introduction of trainable parameters. To do this, we modify self-attention by using three different linear projections of each input element: a query $q_i = W^Q x_i$, a key $k_i = W^K x_i$, and a value $v_i = W^V x_i$, where $W^Q$, $W^K$, and $W^V$ are learned weight matrices.
The self-attention computation is the same as before, but with these projections replacing the raw inputs in their respective roles: the attention weights are $\mathrm{softmax}_j(q_i \cdot k_j)$, and the output at position $i$ is the weighted average $\sum_j \mathrm{softmax}_j(q_i \cdot k_j)\, v_j$.
The whole process is summarized below, mostly zoomed in on the self-attention computation for the first element of the input sequence (here, $x_1$).
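Here is a minimal sketch of a single query-key-value self-attention layer in PyTorch. The class and variable names are illustrative, and the common scaling of the scores by the square root of the key dimension is omitted to match the exposition above.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        # The learned projections W^Q, W^K, W^V: these are what training adjusts.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X):                     # X: (seq_len, d_model)
        Q, K, V = self.W_q(X), self.W_k(X), self.W_v(X)
        scores = Q @ K.T                      # (seq_len, seq_len) query-key similarities
        weights = F.softmax(scores, dim=-1)   # each row sums to 1
        return weights @ V                    # weighted average of the values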
Multi-Head Attention (So, More Heads?)
The fancier query-key-value attention outlined above is great! But we might still have complaints: with only one set of weights used for all tokens and all sequences, we have a bit of a one-size-fits-all solution. And it's not easy to attend to multiple positions at the same time. (Generally with softmax, the largest value dominates, hence the "max.") To introduce even more flexibility, and to allow more "representational subspaces" in the self-attention layer, we repeat the attention computation multiple times, each time with a different set of ($W^Q$, $W^K$, $W^V$) projection matrices. Each of these repetitions is called a "head."
Afterwards, it is customary to project the stack of outputs back to the original input dimension: the heads' outputs are concatenated and then passed through a linear projection (i.e. a fully-connected or dense layer). We often want to apply many self-attention layers one after another, and we don't want the output size to grow with each layer.
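A sketch of the multi-head version, looping over the heads for clarity (real implementations fold all heads into a few batched matrix multiplications; the names and the choice of head dimension here are illustrative):

import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        d_head = d_model // n_heads
        # One (W_q, W_k, W_v) triple per head, each projecting into a smaller subspace.
        self.heads = nn.ModuleList([
            nn.ModuleDict({
                "W_q": nn.Linear(d_model, d_head, bias=False),
                "W_k": nn.Linear(d_model, d_head, bias=False),
                "W_v": nn.Linear(d_model, d_head, bias=False),
            })
            for _ in range(n_heads)
        ])
        # Concatenating the heads gives n_heads * d_head features; project back to d_model.
        self.W_o = nn.Linear(n_heads * d_head, d_model, bias=False)

    def forward(self, X):                                # X: (seq_len, d_model)
        outputs = []
        for head in self.heads:
            Q, K, V = head["W_q"](X), head["W_k"](X), head["W_v"](X)
            weights = F.softmax(Q @ K.T, dim=-1)         # this head's attention pattern
            outputs.append(weights @ V)                  # (seq_len, d_head)
        return self.W_o(torch.cat(outputs, dim=-1))      # back to (seq_len, d_model)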
Attention compares all pairs of words/tokens (in fact, it is permutation-invariant), so it is much better than recurrent or convolutional architectures at capturing long-range dependencies. Because many computations can run in parallel, it also takes better advantage of GPUs.
Masked Attention
In some cases, we want to mask out certain elements of the sequence, so that they are not considered when computing the attention weights. For example, in language modeling, we want to mask out the future tokens in the sequence, so that the model can’t cheat by looking ahead. In machine translation, we want to mask out the padding tokens in the sequence, so that the model doesn’t pay attention to them.
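A minimal sketch of a causal (look-ahead) mask applied to the plain dot-product attention from earlier: the masked scores are set to negative infinity before the softmax, so their weights come out as exactly zero. A padding mask works the same way, just with a different pattern of masked positions.

import torch
import torch.nn.functional as F

def causal_dot_product_attention(X):
    # X: (seq_len, d_model). Position i may only attend to positions j <= i.
    seq_len = X.shape[0]
    scores = X @ X.T                                     # pairwise dot products
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))   # hide future positions
    weights = F.softmax(scores, dim=-1)                  # future weights are exactly 0
    return weights @ X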
Beyond Self-Attention
Self-attention is the star of the show now, but attention actually generalizes to cases where the queries, keys, and values do not all come from the same sequence. That's how attention originated, and it continues to be used this way (often called cross-attention) in encoder-decoder models.
The Origin of Attention (RNN)
Before the advent of transformers, neural machine translation (translating from one language to another using neural networks) was based on recurrent models, which processed the input sentence one word at a time using shared weights, generating a new hidden state at each step. The final hidden state was then used as a summary of the entire sentence, and the decoder had to produce the whole translation from that single fixed-size vector, which becomes a serious bottleneck for long sentences.
Attention was initially popularized as a solution to this problem: instead of just using the final hidden state to represent the sentence, Bahdanau et al. proposed using all of the hidden states (one for each input word) to represent the sentence. As the sentence is “decoded” (i.e. at each step when a new translated word is output), the hidden states are averaged, weighted by their relevance to the current step of decoding. In this way, the decoder can “pay attention” to different hidden states (corresponding to different encoding steps), depending on how relevant they are to the current decoding step. The weights, i.e. the relevance of one vector to another, can be determined (most simply) by a dot product. (Vectors that point in similar directions tend to have more positive dot products, and vectors that are opposite or orthogonal have smaller or negative dot products.) The raw scores of the dot product are then normalized by the softmax function.
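A sketch of a single decoding step, using the simple dot-product scoring described above (Bahdanau et al. actually used a small learned scoring network, but the flow is the same; the function and argument names are illustrative):

import torch
import torch.nn.functional as F

def attention_context(decoder_state, encoder_states):
    # decoder_state: (d,)            -- the decoder's hidden state at this step
    # encoder_states: (src_len, d)   -- one hidden state per input word
    scores = encoder_states @ decoder_state    # relevance of each input word, (src_len,)
    weights = F.softmax(scores, dim=0)         # "how much to pay attention" to each word
    return weights @ encoder_states            # context vector fed to the decoder, (d,)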