Understanding positional encoding in Transformers

Transformers are a very popular architecture in machine learning. While they were first introduced in natural language processing, they have been applied to many fields such as protein folding and design.
Transformers were first introduced in the excellent paper Attention is all you need by Vaswani et al. The paper describes the key elements, including multi-head attention, and how they come together to create a sequence-to-sequence model for language translation. The key advance in Attention is all you need is the replacement of all recurrent layers with pure attention + fully connected blocks. Attention is very efficient to compute and allows for fast comparisons over long distances within a sequence.
One issue, however, is that attention does not natively include a notion of position within a sequence. This means that all tokens could be scrambled and would produce the same result. To overcome this, one can explicitly add a positional encoding to each token. Ideally, such a positional encoding should reflect the relative distance between tokens when computing the query/key comparison, such that closer tokens are attended to more than tokens further away. In Attention is all you need, Vaswani et al. propose the slightly mysterious sinusoidal positional encodings which are simply added to the token embeddings:

Sinusoidal positional encoding equations from Vaswani et al.
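
Written out (with pos the token position, i the embedding dimension index, and d_model the embedding size), the encodings from the paper are:

```latex
PE_{(pos,\,2i)}   = \sin\!\big(pos / 10000^{2i/d_{\text{model}}}\big)
PE_{(pos,\,2i+1)} = \cos\!\big(pos / 10000^{2i/d_{\text{model}}}\big)
```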

When I first saw these, I had absolutely no intuition for why they would allow the model to attend to relative distances. It seemed more intuitive to me to simply append a direct measure of position such as the index in the sequence.

Why the naive approach fails

To understand why sinusoidal positional encoding makes sense, we have to remember how attention is actually calculated. To determine the attention weight of a given target token based on a given source token, we compute the dot product of the query mapping of the source with the key mapping of the target. In the simplest case, we can imagine that the query and key mappings are the identity, in which case we are just taking the dot product of the embeddings of the two tokens. If we were to encode position by appending the index, the positional contribution to the dot product would be the product of the two indices, which is monotonically increasing in the target's index, i.e., tokens near the end of a sequence would receive the most attention regardless of where the query token is.
We could similarly imagine another encoding, such as one-hot encodings of position. These would contribute 1 to the dot product whenever the target and source have the same index and 0 otherwise. Ideally, what we want is an encoding which gives a high value for nearby tokens and smoothly decays as we look at tokens which are further away.
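
To make the failure concrete, here is a small numpy sketch of both naive options (the embeddings are random and purely illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, seq_len = 8, 50

# Toy token embeddings (made up purely for illustration).
emb = rng.normal(size=(seq_len, d_model))

# Naive option 1: append the raw index to each embedding.
idx = np.arange(seq_len, dtype=float)[:, None]
emb_with_index = np.concatenate([emb, idx], axis=1)

# With identity query/key maps, attention scores are plain dot products.
query_pos = 5
scores = emb_with_index @ emb_with_index[query_pos]
# The positional part contributes query_pos * target_pos, which grows with
# the target index no matter where the query sits.
print(scores[:5], scores[-5:])

# Naive option 2: append a one-hot position vector instead.
one_hot = np.eye(seq_len)
emb_one_hot = np.concatenate([emb, one_hot], axis=1)
scores = emb_one_hot @ emb_one_hot[query_pos]
# The positional part is 1 only when target == query and 0 otherwise:
# there is no notion of "nearby" versus "far away".
```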

Sinusoidal encodings

Imagine a series of pendulums all in a line, swinging at different frequencies, where the leftmost is swinging the slowest, and each consecutive pendulum is swinging faster. Now imagine we take a picture of where each pendulum is and then another picture a short while later. The pendulums on the left will have moved little, whereas the ones on the right may have moved a lot. If we were to compute the dot product of their positions, the ones on the left will be very well aligned and so will contribute positively. Conversely, the ones on the right will be out of phase and so will contribute essentially noise, centered around 0. As we increase the timestep between pictures, more of the pendulums will be out of phase, and so the whole value will converge to 0. This is exactly what sinusoidal positional encoding is modelling: the positional encoding for each token is the location of each pendulum at time t, where t is the position of the token. The figure below shows the encoded values at different positions, as well as the value of the dot product between the encoding at position 1000 and neighbouring tokens. Importantly, the value is high for nearby tokens and smoothly falls off as the distance increases.

Positional encoding values (left) and the value of the dot product of position 1000 with neighbouring positions (right). Figure adapted from TensorFlow’s Transformers tutorial.
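
If you want to reproduce the right-hand panel yourself, a minimal numpy implementation of the encoding (d_model = 512 and max_len = 2048 here are just illustrative choices) might look like this:

```python
import numpy as np

def sinusoidal_encoding(max_len: int, d_model: int) -> np.ndarray:
    """PE[pos, 2i] = sin(pos / 10000^(2i/d_model)), PE[pos, 2i+1] = cos(...)."""
    pos = np.arange(max_len)[:, None]                  # (max_len, 1)
    i = np.arange(d_model // 2)[None, :]               # (1, d_model/2)
    angles = pos / np.power(10000.0, 2 * i / d_model)  # (max_len, d_model/2)
    pe = np.zeros((max_len, d_model))
    pe[:, 0::2] = np.sin(angles)
    pe[:, 1::2] = np.cos(angles)
    return pe

pe = sinusoidal_encoding(max_len=2048, d_model=512)

# Dot product of the encoding at position 1000 with every other position:
# high for neighbours, smoothly decaying as the distance grows.
similarity = pe @ pe[1000]
print(similarity[995:1006])
```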

Why are positional encodings added (not appended)?

Another question I had about positional encodings is why they are added to the token embeddings as opposed to appended to the end. Why doesn’t that “overwrite” some of the information stored at those positions? There is some good discussion about this topic in this GitHub issue. Basically, the positional encoding only ends up adding meaningful information to the first part of your embedding vector. The model can likely learn to reserve this part for the positional encoding and encode all other information about the token in the remainder.

Some limitations

As we have seen, sinusoidal positional encoding allows the model to learn to pay more attention to nearby tokens by adding a value which falls off as relative distance increases. This is only one way to add positional information and, while effective, it has some drawbacks.

Additive encodings

The first limitation is that the positional encodings discussed thus far are additive. To calculate attention, we take the dot product of the query and key mappings of our source and target tokens. When positional encoding is just included as part of the token embedding, we are essentially adding the relatedness of the two tokens to the positional encoding term. This means that highly related words will be given high attention scores regardless of where they are, which can be problematic, especially over long sequences. The RoFormer introduced by Su et al. addresses this by making the positional information multiplicative, which makes it easier to ignore distant tokens. Rotary positional encodings have become quite popular and are used in ESM2.
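
To give a flavour of the rotary idea, here is a simplified numpy sketch (my own toy version, not the RoFormer or ESM2 implementation, which work per attention head and differ in how dimensions are paired): each (even, odd) pair of query/key dimensions is rotated by an angle proportional to its position, so the resulting dot product depends only on the relative offset between positions.

```python
import numpy as np

def rotate_pairs(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive (even, odd) pairs of x by position-dependent angles.

    Toy sketch of rotary positional embeddings: the pair (x[2i], x[2i+1])
    is rotated by pos * theta_i, with theta_i = base^(-2i/d).
    """
    d = x.shape[-1]
    i = np.arange(d // 2)
    theta = pos / np.power(base, 2 * i / d)
    cos, sin = np.cos(theta), np.sin(theta)
    x_even, x_odd = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x_even * cos - x_odd * sin
    out[1::2] = x_even * sin + x_odd * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=64), rng.normal(size=64)

# Position enters multiplicatively through the rotation rather than as an
# added vector, and the score depends only on the offset between positions:
print(np.dot(rotate_pairs(q, 10), rotate_pairs(k, 7)))
print(np.dot(rotate_pairs(q, 110), rotate_pairs(k, 107)))  # same offset, same value
```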

Position = closeness?

The second (and more fundamental) drawback of positional encoding is that it explicitly assumes that closer tokens are more important. In natural language, this is often a reasonable assumption, as it relates to the natural organization of information in sentences and paragraphs. Unfortunately, other data modalities such as protein sequences have complex and non-convex positional relationships related to where the tokens lie in 3D space. Learning an appropriate positional encoding in these contexts requires learning the underlying 3D positions, which for proteins is the famous protein structure prediction problem. As it turns out, AlphaFold2 makes use of transformers to predict protein structure. So how does AlphaFold2 learn to model 3D positions using information from the tokens and linear positions? One of the ways is by pairing cysteines. As we can see below, in the first few layers, the model is only looking at neighbouring amino acids or distant cysteines which could form disulfide bonds to stabilize the structure. Downstream layers eventually learn to model attention more coherently, closely matching the amino acid contact map!

What does AlphaFold2 pay attention to? Attention for the first two layers (a and b) looks at neighbouring amino acids and paired cysteines. Attention in the later layers begins to resemble the amino acid contact map (d). Figure reproduced from Jumper et al.
