Multi-Head Attention

Examining a module consisting of several attention layers running in parallel.

m0nads
Becoming Human: Artificial Intelligence Magazine


Image by Javier Miranda

This post refers to the Transformer network architecture (introduced in the paper Attention Is All You Need). The Transformer model represents a successful attempt to overcome older architectures based on recurrent and convolutional networks. We will only deal with a small Transformer component, the Multi-Head Attention module.

The Transformer architecture

The first application of the Transformer model was machine translation. Very briefly, the Transformer has an encoder-decoder structure (just like other translation models) and uses stacked attention layers instead of recurrent or convolutional ones. In the picture below (taken from the original paper) we can see that the Transformer consists of N stacked layers built around attention (the original paper sets N = 6): the encoder (left) processes the input sequence into an encoding, which the decoder (right) uses to generate the output words (the translation). We will not delve into the details of the whole architecture; instead, we examine the encoder (left), and the Multi-Head Attention module in particular. Below, the Transformer architecture.

Transformer architecture

Embedding and Positional Encoding

Input words are transformed by an embedding layer into vectors of length dim (this value is set to 512 in many applications, including the original paper). These embeddings then need to be "labeled" somehow to indicate the position of each word in the sequence. Positional Encoding provides exactly this: it adds information about the position of a word in the input sentence using trigonometric (sine and cosine) functions (check here for some good insights and explanations on how it works). The following picture shows the input for the Multi-Head Attention module, that is, the sum of the input embedding and the positional encoding. In this example, the input consists of 64 words, each represented by 512 values.
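A minimal sketch of how such an input could be built, assuming the sinusoidal encoding of the original paper, PE(pos, 2i) = sin(pos / 10000^(2i/dim)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/dim)). The helper name sinusoidal_encoding, the vocabulary size and the random token ids are hypothetical, chosen only to reproduce the 64 × 512 shape of the picture.

```python
import math
import torch
from torch import nn

def sinusoidal_encoding(n_positions: int, dim: int) -> torch.Tensor:
    """Hypothetical helper: sinusoidal positional encoding of shape (n_positions, dim)."""
    position = torch.arange(n_positions, dtype=torch.float32).unsqueeze(1)  # (n, 1)
    # 1 / 10000^(2i / dim) for each even column index 2i
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))
    pe = torch.zeros(n_positions, dim)
    pe[:, 0::2] = torch.sin(position * div_term)  # even columns use sine
    pe[:, 1::2] = torch.cos(position * div_term)  # odd columns use cosine
    return pe

dim, n_words, vocab_size = 512, 64, 1000             # vocab_size is an arbitrary assumption
embedding = nn.Embedding(vocab_size, dim)
token_ids = torch.randint(0, vocab_size, (n_words,))  # a made-up sentence of 64 token ids

# Input for the Multi-Head Attention module: embedding + positional encoding
x = embedding(token_ids) + sinusoidal_encoding(n_words, dim)
print(x.shape)  # torch.Size([64, 512])
```

The encoding is simply added to the embedding, so the input keeps its (64, 512) shape while each row now also carries information about its position.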

Input for Multi-Head Attention

Multi-Head Attention module for the encoder

We refer to this PyTorch implementation, which uses the well-regarded Einops library. It is intended for users of the ViT (Vision Transformer) model but, since ViT is based on the Transformer architecture, almost all of the code concerns the Multi-Head Attention and Transformer classes.

At the beginning, Multi-Head Attention takes compound inputs (embedding + positional encoding), and the same input is fed three times: once for the queries, once for the keys and once for the values. Each of these three inputs undergoes a linear transformation, and this is repeated for each head (heads, the number of heads, is 8 by default). The nn.Linear layers are, in essence, linear transformations of the kind Ax + b (without the bias b in our case, since bias = False). nn.Linear acts on the last dimension of a tensor: if our input tensor has shape (64, 512) and we apply nn.Linear(512, 1536), the output tensor has shape (64, 1536). In the implementation we refer to, 1536 = 3 × 8 × 64, that is, query, key and value for 8 heads of size 64 each. Below, the Multi-Head Attention mechanics.
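The shape arithmetic can be checked directly. This sketch assumes the default values heads = 8 and dim_head = 64 and uses a random tensor in place of real embeddings; it also shows that a bias-free nn.Linear is just a matrix product with the transposed weight.

```python
import torch
from torch import nn

dim, heads, dim_head = 512, 8, 64
inner_dim = heads * dim_head                          # 512

# One Linear layer (no bias) producing queries, keys and values for all heads at once
to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)   # equivalent to nn.Linear(512, 1536, bias=False)

x = torch.randn(64, dim)                              # stand-in for 64 positionally encoded words
qkv = to_qkv(x)
print(qkv.shape)                                      # torch.Size([64, 1536])

# nn.Linear computes x @ W.T (+ b); without bias it is a plain matrix product
print(torch.allclose(qkv, x @ to_qkv.weight.T, atol=1e-4))

# Split into Q, K, V: each (64, 512), i.e. 8 heads x 64 values per word
q, k, v = qkv.chunk(3, dim=-1)
print(q.shape, k.shape, v.shape)
```

Computing all three projections with a single layer and then chunking is equivalent to using three separate (512, 512) layers; it just performs one larger matrix multiplication.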

The Transformer uses Multi-Head Attention in three different ways; we will focus on the encoder layer behavior (essentially a self-attention mechanism). The Multi-Head Attention module takes three identical inputs (the positionally encoded word embeddings in the first layer, the output of the previous encoder layer otherwise). Through three trainable matrices (Linear layers), three vectors (query, key and value) are generated for each word in the source sentence. Word after word, these vectors fill the matrices Q, K and V. Think of the key-value system as a sort of dictionary:

{key1: “cats”, key2: “chase”, key3: “mice”}.
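To make the dictionary analogy a little more concrete, here is a toy sketch contrasting an exact Python dictionary lookup with the "soft" lookup that attention performs: instead of returning the single value whose key matches, the query is compared with every key and all values are blended according to a softmax of the similarities. The 2-dimensional key, value and query vectors are made up purely for illustration.

```python
import torch

# Hard lookup: a query either matches a key exactly or retrieves nothing
table = {"key1": "cats", "key2": "chase", "key3": "mice"}
print(table["key1"])  # cats

# Soft lookup (toy numbers): keys and values are vectors, one row per word
keys = torch.tensor([[1.0, 0.0],     # "cats"
                     [0.0, 1.0],     # "chase"
                     [0.9, 0.1]])    # "mice"
values = torch.tensor([[10.0, 0.0],
                       [0.0, 10.0],
                       [5.0, 5.0]])
query = torch.tensor([1.0, 0.0])

scores = keys @ query                    # dot product of the query with every key
weights = torch.softmax(scores, dim=-1)  # non-negative weights that sum to 1
result = weights @ values                # weighted blend of all the values
print(weights, result)
```

The query here is most similar to the keys of "cats" and "mice", so their values dominate the result, while "chase" still contributes a little.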

A key is a vector representation. Each query vector is compared to all the keys (every word in the source sentence is compared to every word in the same sentence). A query should be similar to the keys of words that have some kind of link, connection or affinity with the query word itself. This similarity is expressed by the dot products of the rows of Q with the columns of Kᐪ, i.e. by the entries of the matrix QKᐪ (a division by s, the square root of dim_head, keeps the magnitude of the products from growing too large).
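For reference, the original paper writes this operation (scaled dot-product attention) and its multi-head extension as follows, where d_k corresponds to dim_head (so s = √d_k), h to heads, and W_i^Q, W_i^K, W_i^V, W^O are learned projection matrices:

```latex
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right) V

\mathrm{head}_i = \mathrm{Attention}(QW_i^Q,\; KW_i^K,\; VW_i^V), \qquad i = 1, \dots, h

\mathrm{MultiHead}(Q, K, V) = \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
```

In the encoder's self-attention, Q, K and V on the right-hand side are all the same input matrix, which is why the module receives three identical inputs.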

Example. Suppose that the source sentence consists of 10 words. For each word there is a positionally encoded embedding row. The 10 × dim positionally encoded embeddings are fed three times as input and, through multiplication by the nn.Linear matrices, produce three vectors (query, key and value) for each word, so there are 30 resulting vectors in total (per head). Take, for example, the query vector obtained from the first word and compute its dot product with each of the ten keys: the result is a vector with 10 components. The largest components should correspond to words in the sentence that are somehow linked to the query. Doing the same with the remaining 9 query vectors fills a 10 × 10 matrix QKᐪ (repeat all of this for every head to get the whole picture). The self-attention mechanism is depicted below.
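The 10-word example can be reproduced with random tensors. This is a minimal sketch, assuming dim = 512 and 8 heads of size 64 as in the implementation discussed here, random inputs in place of real embeddings, and three separate Linear layers (equivalent to splitting a single to_qkv layer in three).

```python
import torch
from torch import nn

torch.manual_seed(0)
n_words, dim, heads, dim_head = 10, 512, 8, 64

x = torch.randn(n_words, dim)  # stand-in for 10 positionally encoded embeddings
to_q = nn.Linear(dim, heads * dim_head, bias=False)
to_k = nn.Linear(dim, heads * dim_head, bias=False)
to_v = nn.Linear(dim, heads * dim_head, bias=False)

# (10, 512) -> (heads, 10, dim_head): one query/key/value vector per word and per head
q = to_q(x).view(n_words, heads, dim_head).transpose(0, 1)
k = to_k(x).view(n_words, heads, dim_head).transpose(0, 1)
v = to_v(x).view(n_words, heads, dim_head).transpose(0, 1)

# Row i of each 10 x 10 matrix holds the dot products of query i with all ten keys
dots = q @ k.transpose(-1, -2) / dim_head ** 0.5  # (heads, 10, 10), scaled by 1/s
attn = dots.softmax(dim=-1)                       # each row now sums to 1
out = attn @ v                                    # (heads, 10, dim_head)

print(dots.shape)                  # torch.Size([8, 10, 10])
print(attn[0, 0].argmax().item())  # word that the first query attends to most in head 0
```

With trained weights instead of random ones, the largest entries of each row would point to the words most related to the query word.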

Self-attention mechanism

The following picture shows almost the same mechanism as above: each row of Q forms a dot product with each column of Kᐪ. Instead of bars, the magnitude of the products is represented by squares (larger products correspond to larger squares).

How the attention matrix is formed

Below, the PyTorch code for the Multi-Head Attention class. Note that, in this code, the actual attention output (as defined in the original paper) is out (containing the components of all heads), that is, the matrix product of attn (the softmax-normalized weights) and v.

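```python
import torch
from torch import nn
from einops import rearrange

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is skipped only when one head already has size dim
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5       # 1 / sqrt(dim_head)
        self.attend = nn.Softmax(dim = -1)  # softmax over the keys
        # A single Linear layer computes Q, K and V for all heads at once
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        # x: (batch, n, dim) -> three tensors of shape (batch, n, inner_dim)
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # Split the last axis into heads: (batch, heads, n, dim_head)
        q, k, v = map(lambda t: rearrange(
            t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        # Scaled dot products between every query and every key: (batch, heads, n, n)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        # Weighted sum of the values: (batch, heads, n, dim_head)
        out = torch.matmul(attn, v)
        # Concatenate the heads back: (batch, n, inner_dim)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
```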
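If Einops is unfamiliar, the two rearrange patterns in forward can be read as plain reshapes. The sketch below is an equivalent written with view, transpose and reshape only (it does not use Einops and is not part of the referenced repository); batch size 2, 64 tokens, 8 heads and dim_head = 64 are arbitrary choices.

```python
import torch

b, n, h, d = 2, 64, 8, 64
t = torch.randn(b, n, h * d)  # e.g. the query chunk produced by to_qkv

# 'b n (h d) -> b h n d': split the last axis into heads, then move heads forward
split = t.view(b, n, h, d).transpose(1, 2)           # (b, h, n, d)

# 'b h n d -> b n (h d)': move heads back and concatenate them along the last axis
merged = split.transpose(1, 2).reshape(b, n, h * d)  # (b, n, h * d)

print(split.shape)             # torch.Size([2, 8, 64, 64])
print(torch.equal(merged, t))  # True: the two patterns undo each other
```

Head i simply corresponds to columns i·d to (i+1)·d of the projected tensor, so splitting into heads costs no extra parameters.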

A concrete example

BERT is a model for pre-training language representations: a general-purpose language understanding model is trained on a large unlabeled text corpus (for example, Wikipedia) and then employed for a wide range of tasks. Structurally, BERT is a stack of encoder modules from the Transformer architecture. The picture below shows the base version of the BERT architecture (12 encoder modules, hidden size = 768, 12 attention heads). BERT base has 12 attention heads per layer, 144 in total.
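The numbers above fit together as follows: a hidden size of 768 split across 12 heads gives 768 / 12 = 64 dimensions per head, the same dim_head as the default in the Attention class above. The sketch also counts the projection parameters of one attention layer, using PyTorch's nn.MultiheadAttention as a stand-in for a BERT self-attention block (an assumption: it has the same Q, K, V and output projections, each 768 × 768 with bias, but it is not BERT's own code).

```python
from torch import nn

hidden_size, heads_per_layer, layers = 768, 12, 12

dim_head = hidden_size // heads_per_layer  # dimensions per head
total_heads = heads_per_layer * layers     # heads in the whole BERT base stack
print(dim_head, total_heads)               # 64 144

# Stand-in for one self-attention block: Q, K, V and output projections with bias
mha = nn.MultiheadAttention(embed_dim=hidden_size, num_heads=heads_per_layer)
n_params = sum(p.numel() for p in mha.parameters())
print(n_params)                            # 4 * 768 * 768 + 4 * 768 = 2362368
```

Note that the parameter count does not depend on the number of heads: splitting 768 dimensions into 12 heads of 64 only changes how the projected vectors are grouped.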

The picture below (from the article Revealing the Dark Secrets of BERT by O. Kovaleva et al.) shows the attention patterns of two of the 144 heads; both patterns show an ability to capture semantic relations. In practice, bars or squares are not a convenient way to express magnitude, so color shades (a heatmap) are a good choice.


Two of the 144 attention heads (each heatmap is obtained by averaging the maps of all individual input examples) reach values of 0.201 and 0.209, which exceed the 99th percentile of the distribution of values across all heads. Both heads assign a high attention weight to “he” while processing “agitated” in the sentence “He was becoming agitated”.

For further attention visualizations, check this page.

Originally posted on m0nads.


Useful links

PyTorch code using Einops notation.

Transformer Architecture: The Positional Encoding.

Attention Is All You Need
A. Vaswani et al.
arXiv:1706.03762v5 [cs.CL], 2017.

An Image is Worth 16×16 Words: Transformer for Image Recognition at Scale
A. Dosovitskiy et al.
arXiv:2010.11929 [cs.CV], 2021.

Revealing the Dark Secrets of BERT
O. Kovaleva, A. Romanov, A. Rogers, A. Rumshisky
arXiv:1908.08593 [cs.CL], 2019.

The Annotated Transformer: Attention Visualization.
