Benjamin Anderson

Self-Attention Explained

Posted Nov 18, 2022 by Benjamin Anderson

Motivation

Transformers are really important. Self-attention is the core mechanism of the Transformer model. It's a bit confusing at first, but once you build it from the ground up with some mathematical intuitions, it's not so bad!

The Problem

We have a sequence of vectors (representing words, pixels, or anything else), and we want to process it into a new sequence (of the same dimensions) in a way that considers the meaning of the whole sequence. This means considering complex interdependencies: for example, in the sentence "Alice wrote a note and gave it to Bob," it is only by considering context that you can figure out that "it" refers to the note. We need this kind of context-aware representation in order to complete a downstream task (whether that's classifying the sentiment of the sentence, translating it to another language, or predicting the next word).

The old way of doing this was the RNN, which processes the words one at a time and uses the hidden state from the previous word to inform the current word. This is a good idea, but it has some problems: it's slow, because the steps can't run in parallel, and it doesn't scale well to long sequences, because of vanishing or exploding gradients and difficulty capturing long-term dependencies. However, it did successfully make the processing of each word aware of the other words, since each previous word modifies the hidden state, saving information that can be used to process the next word.

The Idea of Attention

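Desiderata:

  1. Process all the tokens in the sequence in parallel, rather than one at a time like an RNN.
  2. Still let the tokens interact with each other, so that dependencies between them can be captured.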

Self-attention is the thing that results from wanting both of these. It provides a way to process all the tokens in the sequence in parallel, but also allow them to interact with each other in the computation, so that dependencies can be captured. It's basically performing a weighted average at each position, weighted by "relevance" (more on that later). Another way to think of what's going on is that information is being "copied" from one position $s_i$ to another $s_j$. For example, above, we'd want to copy "note" to "it" so that the model can understand that "it" refers to "note". Each sequence token is modified to be a combination of itself, plus the other things in the sequence. This makes sense because when we are modeling a sequence, the meaning of a part often depends on the meaning of the whole or the other parts.

Dumb Idea: Average the Whole Sequence

This does capture meaning from all the parts! But it sort of "flattens" things, which is too crude for what we want. We have to keep the words separate, especially if we're trying to do things like predict the next word or translate the sentence (order and distinct words matter a lot!)

Better Idea: Position-Wise Weighted Average

Imagine we have an oracle that can tell us what other words in the sentence are relevant to each word. We can use this to weight the words in the sequence, and then take a weighted average at each position. This is a good idea, but it's not quite what we want. We want the model to learn which words are relevant to each word (that is, to learn the weights), rather than having to specify them manually.
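Here is a small sketch contrasting the two ideas so far. The toy vectors and the hand-written `oracle_weights` matrix are made up for illustration; they stand in for whatever a real oracle would say.

```python
import numpy as np

# Toy sequence: 4 tokens, each a 3-dimensional vector (values are made up).
X = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 0.0],
])

# Dumb idea: one average for the whole sequence. Positions are lost.
flat = X.mean(axis=0)            # shape (3,)

# Oracle idea: row i holds the hand-specified relevance of every token
# to token i. Each row is non-negative and sums to 1.
oracle_weights = np.array([
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.0, 0.0, 1.0, 0.0],
    [0.4, 0.4, 0.0, 0.2],
])

# Position-wise weighted average: one output vector per input position.
Y = oracle_weights @ X           # shape (4, 3), same shape as X
```

The output `Y` keeps one vector per position, which is exactly what the plain average throws away. The rest of the post is about replacing `oracle_weights` with something computed from the sequence itself.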

Mathematical Intuition: Non-Local Means

Unlike "local mean" filters, which take the mean value of a group of pixels surrounding a target pixel to smooth the image, non-local means filtering takes a mean of all pixels in the image, weighted by how similar these pixels are to the target pixel. This results in much greater post-filtering clarity, and less loss of detail in the image compared with local mean algorithms. If compared with other well-known denoising techniques, non-local means adds "method noise" (i.e. error in the denoising process) which looks more like white noise, which is desirable because it is typically less disturbing in the denoised product. (Source: Wikipedia)
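To make the analogy concrete, here is a rough 1-D version of the idea. Real non-local means works on 2-D image patches; the function name, the `patch_radius` and `h` parameters, and the Gaussian weighting are assumptions chosen to keep the sketch short.

```python
import numpy as np

def non_local_means_1d(signal, patch_radius=1, h=0.5):
    """Each output is a mean over ALL positions, weighted by how similar
    each position's patch is to the target position's patch."""
    n = len(signal)
    padded = np.pad(signal, patch_radius, mode="edge")
    # patches[i] is the small window centered on position i
    patches = np.stack([padded[i:i + 2 * patch_radius + 1] for i in range(n)])
    # squared distance between every pair of patches, shape (n, n)
    diff = patches[:, None, :] - patches[None, :, :]
    dist2 = (diff ** 2).sum(axis=-1)
    # Gaussian similarity weights, normalized so each row sums to 1
    weights = np.exp(-dist2 / (h ** 2))
    weights /= weights.sum(axis=1, keepdims=True)
    return weights @ signal

rng = np.random.default_rng(0)
clean = np.concatenate([np.zeros(20), np.ones(20)])   # a step edge
noisy = clean + 0.1 * rng.standard_normal(clean.shape)
denoised = non_local_means_1d(noisy)
```

Notice the shape of the computation: an n-by-n matrix of similarity weights, normalized per row, multiplied by the signal. Attention has the same structure.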

Mathematical Intuition: Kernel Regression

Kernel regression estimates the conditional mean at a point x by taking a weighted average of the values of the training data, where the weights are determined by a kernel function k that measures the similarity between the training data and the point x.
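One standard form of this is the Nadaraya-Watson estimator, $\hat{m}(x) = \frac{\sum_i k(x, x_i)\, y_i}{\sum_i k(x, x_i)}$. A minimal sketch follows, assuming a Gaussian kernel; the function name, the `bandwidth` parameter, and the toy data are illustrative choices.

```python
import numpy as np

def kernel_regression(x_query, x_train, y_train, bandwidth=0.5):
    """Nadaraya-Watson estimate with a Gaussian kernel (one common choice)."""
    # k(x, x_i): similarity between the query point and each training input
    k = np.exp(-0.5 * ((x_query - x_train) / bandwidth) ** 2)
    weights = k / k.sum()            # normalize so the weights sum to 1
    return weights @ y_train         # weighted average of the training targets

x_train = np.array([0.0, 1.0, 2.0, 3.0])
y_train = np.array([0.0, 1.0, 4.0, 9.0])
estimate = kernel_regression(1.0, x_train, y_train)
```

Training points close to the query dominate the average, and far-away points contribute almost nothing.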

In the same way, attention outputs a value at position $t$ that is a weighted average of all the items in the sequence, weighted by their similarity (or, more abstractly, their relevance) to $s_t$.

Dot-Product Attention

We don't have an oracle. What can we do instead?

Dot Product as a Measure of Similarity

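Basically, when two vectors point in similar directions, their dot product is large and positive. When they are orthogonal, their dot product is zero. When they point in opposite directions, their dot product is negative. (The dot product also grows with the lengths of the vectors, so it is not a pure measure of direction.) This is a good measure of similarity, because it's easy to compute, and it's easy to interpret.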
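A quick check with made-up 2-D vectors shows the three cases:

```python
import numpy as np

a = np.array([1.0, 2.0])
similar = np.array([2.0, 4.0])      # same direction as a
orthogonal = np.array([-2.0, 1.0])  # at a right angle to a
opposite = np.array([-1.0, -2.0])   # opposite direction to a

print(a @ similar)     # 10.0
print(a @ orthogonal)  # 0.0
print(a @ opposite)    # -5.0
```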

The Attention Matrix

The attention matrix is a matrix of size $n \times n$, where $n$ is the length of the sequence. Each element $a_{ij}$ is the similarity (dot product) between the $i$th and $j$th elements of the sequence. Before normalization this matrix is symmetric, because the dot product of $i$ with $j$ is the same as the dot product of $j$ with $i$. We then normalize each row with softmax, so that the values in a row are between 0 and 1 and sum to 1. (After this step the matrix is generally no longer symmetric, because each row is normalized separately.) Finally, row $i$ of the normalized matrix gives the weights used to average the sequence, producing the output at position $i$.
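Here is a small numerical check of those properties, using a made-up 3-token sequence. The `softmax` helper is written out by hand so the sketch only needs NumPy.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

X = np.array([[1.0, 0.0],
              [0.0, 2.0],
              [1.0, 1.0]])          # n = 3 tokens, d = 2 (values are made up)

scores = X @ X.T                    # (3, 3); scores[i, j] = x_i . x_j
A = softmax(scores, axis=1)         # normalize each row separately
Y = A @ X                           # one weighted average per position

print(np.allclose(scores, scores.T))   # True: raw dot products are symmetric
print(np.allclose(A.sum(axis=1), 1.0)) # True: each row of weights sums to 1
print(np.allclose(A, A.T))             # False: row-wise softmax breaks the symmetry
```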

The Whole Thing

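$$\mathrm{SelfAttn}(X) = \mathrm{Softmax}(X X^\top)\, X$$

Here $X$ is the $n \times d$ matrix whose rows are the sequence elements, and the softmax is applied to each row.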

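    import torch.nn.functional as F

    def DotProductAttention(X):
        # X: (n, d) tensor, one row per token
        attn_matrix = X @ X.T                         # (n, n) pairwise dot products
        attn_weights = F.softmax(attn_matrix, dim=1)  # each row sums to 1
        attn_output = attn_weights @ X                # (n, d) weighted averages
        return attn_output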

Why This Isn't Quite Enough

There is no reason to think that "note" and "it" would have similar vectors, so a raw dot product would not connect them. We need the model to learn what counts as relevant, not just use the dot product of the embeddings.

Queries, Keys, and Values

One thing you might think is missing from the previous exposition is the opportunity for learning. As laid out above, self-attention is a deterministic (though complicated!) function of the input sequence S, with no trainable parameters. Sure, the gradients can propagate through one or more attention layers back to the original (trainable) embeddings, but the attention layers themselves are not parameterized, so they can’t specialize or learn anything to solve the specific task at hand. Using just a dot product to measure similarity is a crude approach as well—the network can’t learn which things should be similar, it is hard coded into the embeddings.

The introduction of queries, keys, and values introduces parameters into the attention layer, allowing it to be more flexible and learn as the neural network trains. To understand the motivation for queries, keys, and values, recognize that in a self-attention layer, each input element $s_i$ plays several roles:

  1. It acts as the query $x$ that all elements in the sequence are compared to in order to assess their similarity and compute the attention weights in $\mathrm{Attn}(S, x)$.
  2. It acts as a key, as an element of $S$, when it is compared to the query $x$ to compute the attention weight between itself and $x$.
  3. It acts as the value for averaging, i.e. the $s_i$ in $\sum_i w_i s_i$.

It makes sense in the self-attention setting to allow these three to be different vectors. This introduces more flexibility into the model, and as described above, allows the introduction of trainable parameters. To do this, we modify self-attention by using three different linear projections of $s_i$ for these three different roles. The weights $(W_Q, W_K, W_V)$ of these projections are learned as the model is trained.

$$q_i = W_Q s_i, \qquad k_i = W_K s_i, \qquad v_i = W_V s_i$$

The self-attention computation is the same as before, but with these values replacing $s_i$ for each of its roles:

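$$\mathrm{Attn}(S, s_i) = \sum_{j=1}^{|S|} w_j v_j$$

$$w_j = \mathrm{Softmax}_j\left(k_j^\top q_i\right)$$

Here $j$ runs over all positions in the sequence, and the softmax normalizes the scores $k_j^\top q_i$ over $j$ so that the weights sum to 1.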
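Putting the projections and the weighted sum together, a minimal NumPy sketch might look like this. The function name, the shapes, and the random weights are assumptions for illustration (in a real model the weights are learned). Tokens are stored as rows, so $q_i = W_Q s_i$ becomes `S @ W_Q.T`. The original Transformer also divides the scores by $\sqrt{d_k}$, which is left out here to match the formulas above.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def qkv_self_attention(S, W_Q, W_K, W_V):
    """S: (n, d), one row per token. W_Q, W_K, W_V: (d_k, d) projections."""
    Q = S @ W_Q.T                      # row i is q_i = W_Q s_i
    K = S @ W_K.T                      # row j is k_j = W_K s_j
    V = S @ W_V.T                      # row j is v_j = W_V s_j
    scores = Q @ K.T                   # scores[i, j] = q_i . k_j
    weights = softmax(scores, axis=1)  # softmax over j for each query i
    return weights @ V                 # row i = sum_j w_j v_j

rng = np.random.default_rng(0)
n, d, d_k = 5, 8, 4                    # hypothetical sizes
S = rng.standard_normal((n, d))
W_Q, W_K, W_V = (rng.standard_normal((d_k, d)) for _ in range(3))
out = qkv_self_attention(S, W_Q, W_K, W_V)   # shape (n, d_k)
```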

The whole process is summarized below, mostly zoomed in on the self-attention computation for the first element of the input sequence (here, $z_1$), mapping it all the way to the first element of the output sequence (here, $o_1$).

Multi-Head Attention (So, More Heads?)

The fancier query-key-value attention outlined above is great! But we might still have complaints: with only one set of weights used for all tokens and all sequences, we have a bit of a one-size-fits-all solution. And it’s not easy to attend to multiple positions at the same time. (Generally with softmax, the largest value dominates, hence the “max.”) To introduce even more flexibility, and allow more “representational subspaces” in the self-attention layer, we repeat the attention computation multiple times, each time with a different set of weights $(W_Q, W_K, W_V)$. (In practice, these can all happen in parallel.) Each repetition of the attention computation is called an attention head. This means that instead of just one output, a self-attention layer with $h$ heads produces a stack of $h$ different outputs. Since each set of weights is different, each output in the stack is different, and can attend to a different set of interrelationships among the tokens in the sequence.

Afterwards, it is customary to project the stack of outputs back to the original input dimension, simply by concatenating them, followed by a linear projection (i.e. fully-connected or dense layer). We often want to apply many self-attention layers one after the other, and don’t want the output size to grow exponentially.
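The two steps above (run several heads, then concatenate and project) can be sketched as follows. The names `attention_head`, `multi_head_attention`, and `W_O`, plus all sizes and random weights, are assumptions for illustration. Real implementations usually batch the heads into one tensor operation instead of looping.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention_head(S, W_Q, W_K, W_V):
    """One query-key-value attention head. S: (n, d); weights: (d_head, d)."""
    Q, K, V = S @ W_Q.T, S @ W_K.T, S @ W_V.T
    return softmax(Q @ K.T, axis=1) @ V          # (n, d_head)

def multi_head_attention(S, heads, W_O):
    """heads: list of (W_Q, W_K, W_V) tuples. W_O: (d, h * d_head)."""
    outputs = [attention_head(S, *weights) for weights in heads]
    stacked = np.concatenate(outputs, axis=1)    # (n, h * d_head)
    return stacked @ W_O.T                       # project back to (n, d)

rng = np.random.default_rng(0)
n, d, h, d_head = 5, 8, 2, 4                     # hypothetical sizes
S = rng.standard_normal((n, d))
heads = [tuple(rng.standard_normal((d_head, d)) for _ in range(3))
         for _ in range(h)]
W_O = rng.standard_normal((d, h * d_head))
out = multi_head_attention(S, heads, W_O)        # same shape as S
```

Because the output has the same shape as the input, layers like this can be stacked without the width growing.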

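Attention compares all pairs of words/tokens directly, regardless of how far apart they are, so it is much better than recurrent or convolutional architectures at capturing long-range dependencies. (In fact, without positional information it is permutation-equivariant: shuffling the input tokens just shuffles the outputs in the same way.) Because many computations can run in parallel, it also takes better advantage of GPUs.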
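The claim about permutations is easy to check numerically. More precisely, shuffling the input rows shuffles the output rows in the same way. This sketch uses the plain dot-product attention from earlier, with random made-up inputs; the helper names are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(X):
    weights = softmax(X @ X.T, axis=1)
    return weights @ X

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
perm = rng.permutation(5)

out_then_shuffle = dot_product_attention(X)[perm]
shuffle_then_out = dot_product_attention(X[perm])
print(np.allclose(out_then_shuffle, shuffle_then_out))  # True
```

This is why Transformers need some extra signal about position: attention on its own cannot tell word order apart.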

Masked Attention

In some cases, we want to mask out certain elements of the sequence, so that they are not considered when computing the attention weights. For example, in language modeling, we want to mask out the future tokens in the sequence, so that the model can’t cheat by looking ahead. In machine translation, we want to mask out the padding tokens in the sequence, so that the model doesn’t pay attention to them.
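One common way to do this is to set the masked scores to negative infinity before the softmax, so they receive exactly zero weight. The sketch below uses plain dot-product attention; the function name, the boolean mask convention, and the choice of which token counts as padding are assumptions for illustration.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def masked_attention(X, mask):
    """mask[i, j] is True where position i may attend to position j.
    Every row of mask needs at least one True entry."""
    scores = X @ X.T
    scores = np.where(mask, scores, -np.inf)  # exp(-inf) = 0, so zero weight
    weights = softmax(scores, axis=1)
    return weights @ X, weights

rng = np.random.default_rng(0)
n = 4
X = rng.standard_normal((n, 3))

# Causal mask for language modeling: position i sees only positions j <= i.
causal_mask = np.tril(np.ones((n, n), dtype=bool))
causal_out, causal_weights = masked_attention(X, causal_mask)

# Padding mask: suppose the last token is padding, so nothing attends to it.
is_padding = np.array([False, False, False, True])
padding_mask = np.broadcast_to(~is_padding, (n, n))
padded_out, padded_weights = masked_attention(X, padding_mask)
```

With the causal mask, the first position can only attend to itself, so its output is just its own vector.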

Beyond Self-Attention

Self-attention is the main thing now, but it actually generalizes to cases where the queries, keys, and values do not all come from the same sequence. That's how attention originated, and it continues to be used this way in encoder-decoder models.

The Origin of Attention (RNN)

Before the advent of transformers, neural machine translation (translating from one language to another using neural networks) was based on recurrent models, which processed the input sentence one word at a time using shared weights, generating a new hidden state $h_t$ for each position $t$ in the input sentence (based on $h_{t-1}$ and the $t$-th word in the sentence). At the end of the sequence, the final hidden state would be used as the representation of the whole sentence. This was limiting, because all the information from the whole sentence had to be compressed into that one state. It left models unable to deal with longer sentences.

Attention was initially popularized as a solution to this problem: instead of just using the final hidden state to represent the sentence, Bahdanau et al. proposed using all of the hidden states (one for each input word) to represent the sentence. As the sentence is “decoded” (i.e. at each step when a new translated word is output), the hidden states are averaged, weighted by their relevance to the current step of decoding. In this way, the decoder can “pay attention” to different hidden states (corresponding to different encoding steps), depending on how relevant they are to the current decoding step. The weights, i.e. the relevance of one vector to another, can be determined (most simply) by a dot product. (Vectors that point in similar directions tend to have more positive dot products, and vectors that are opposite or orthogonal have smaller or negative dot products.) The raw scores of the dot product are then normalized by the softmax function.
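A single decoding step of this kind of attention can be sketched as below, using the simple dot-product scoring described above. (Bahdanau et al. scored relevance with a small learned network rather than a raw dot product.) The function name and the toy hidden states are made up for illustration.

```python
import numpy as np

def softmax(x):
    x = x - x.max()                  # subtract the max for stability
    e = np.exp(x)
    return e / e.sum()

def attend(decoder_state, encoder_states):
    """decoder_state: (d,). encoder_states: (T, d), one hidden state per input word."""
    scores = encoder_states @ decoder_state   # one dot product per input position
    weights = softmax(scores)                 # relevance of each input position
    context = weights @ encoder_states        # weighted average of hidden states
    return context, weights

# Three made-up encoder hidden states and one decoder state.
encoder_states = np.array([[10.0, 0.0],
                           [0.0, 10.0],
                           [-10.0, 0.0]])
decoder_state = np.array([1.0, 0.0])
context, weights = attend(decoder_state, encoder_states)
```

Here the decoder state points in the same direction as the first encoder state, so nearly all of the weight lands on the first input position. Note that the query comes from the decoder while the keys and values come from the encoder, which is the generalization beyond self-attention.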

Encoder-Decoder Attention in Transformers

References