Multi-Head Attention
A comprehensive guide to multi-head attention, part of a deep dive into the transformer architecture.
Why Multiple Attention Heads
A single attention function computes one set of attention weights, which means it can only capture one type of relationship between tokens at a time. Multi-head attention addresses this limitation by running multiple attention operations in parallel, each with different learned projections. This allows the model to simultaneously attend to information from different representation subspaces at different positions.
For example, one head might learn to attend to syntactic relationships (subject-verb agreement), another to semantic relationships (word similarity), and another to positional patterns (attending to nearby tokens). The original Transformer paper used 8 attention heads, and modern large models commonly use anywhere from 32 to 128.
Multi-Head Attention Architecture
Multi-head attention splits the model dimension into h heads, each operating on a lower-dimensional subspace:
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.embed_dim = embed_dim
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.W_o = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        batch_size, seq_len, _ = x.shape
        # Project and reshape to (batch, heads, seq_len, head_dim)
        Q = self.W_q(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.W_k(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.W_v(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(1) == 0, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        attn_output = torch.matmul(weights, V)
        # Concatenate heads and project back to the model dimension
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.embed_dim
        )
        return self.W_o(attn_output)
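To make the reshaping concrete, here is a minimal shape walkthrough using assumed illustrative sizes (batch=2, seq_len=5, embed_dim=512, num_heads=8, so head_dim=64):

```python
import torch

# Assumed illustrative sizes: batch=2, seq_len=5, embed_dim=512,
# num_heads=8, so head_dim = 512 // 8 = 64.
x = torch.randn(2, 5, 512)
num_heads, head_dim = 8, 64

# (batch, seq_len, embed_dim) -> (batch, num_heads, seq_len, head_dim)
q = x.view(2, 5, num_heads, head_dim).transpose(1, 2)
print(q.shape)       # torch.Size([2, 8, 5, 64])

# Scores are computed per head: (batch, num_heads, seq_len, seq_len)
scores = q @ q.transpose(-2, -1) / head_dim ** 0.5
print(scores.shape)  # torch.Size([2, 8, 5, 5])

# Concatenating heads reverses the split: back to (batch, seq_len, embed_dim)
out = (torch.softmax(scores, dim=-1) @ q).transpose(1, 2).contiguous().view(2, 5, 512)
print(out.shape)     # torch.Size([2, 5, 512])
```

Note that the split and concatenation are pure reshapes: no parameters are involved, so the per-head subspaces are determined entirely by the learned projections W_q, W_k, and W_v.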
What Different Heads Learn
Research by Clark et al. (2019) analyzed attention heads in BERT and found that different heads specialize in different linguistic phenomena:
- Positional heads — Attend to the next or previous token (bigram patterns)
- Separator heads — Attend to special tokens like [SEP] or [CLS]
- Syntactic heads — Mirror dependency parse trees (e.g., attending from a verb to its subject)
- Coreference heads — Link pronouns to their antecedents
- Rare token heads — Attend strongly to infrequent words that carry high information
Head Pruning
Not all heads are equally important. Michel et al. (2019) showed that many heads can be removed after training with minimal impact on performance. This has practical implications for model compression and inference optimization.
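One simple way to probe head importance is to zero out individual heads before the output projection and measure the effect on a validation metric. The sketch below illustrates the gating idea with random tensors and assumed sizes; it is not the exact procedure from the paper:

```python
import torch

# Hypothetical per-head gating. Assumed sizes: batch=2, num_heads=8,
# seq_len=5, head_dim=64; attn_output stands in for per-head outputs.
attn_output = torch.randn(2, 8, 5, 64)
gate = torch.ones(8)
gate[3] = 0.0  # "prune" head 3

# Broadcasting the gate over (batch, heads, seq_len, head_dim)
gated = attn_output * gate.view(1, 8, 1, 1)

# Head 3 now contributes nothing to the concatenated representation.
print(gated[:, 3].abs().sum())  # tensor(0.)
```

In practice, if validation performance barely changes with a head gated off, that head is a candidate for permanent removal.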
The Output Projection
After concatenating the outputs from all heads, a final linear projection (W_o) maps the concatenated representation back to the model dimension. This projection is crucial because it allows the model to learn how to combine information from different heads.
Grouped Query Attention (GQA)
Modern models like LLaMA 2 and Mistral use Grouped Query Attention, where multiple query heads share a single key-value head. This reduces the memory required to cache key-value pairs during inference (the KV cache) without significantly impacting quality:
# Standard MHA: 32 query heads, 32 KV heads
# GQA: 32 query heads, 8 KV heads (4 query heads per KV group)
# MQA: 32 query heads, 1 KV head (all queries share one KV)
class GroupedQueryAttention(nn.Module):
    def __init__(self, embed_dim, num_q_heads, num_kv_heads):
        super().__init__()
        assert num_q_heads % num_kv_heads == 0
        self.num_q_heads = num_q_heads                 # e.g., 32
        self.num_kv_heads = num_kv_heads               # e.g., 8
        self.group_size = num_q_heads // num_kv_heads  # e.g., 4
        self.head_dim = embed_dim // num_q_heads
        self.W_q = nn.Linear(embed_dim, num_q_heads * self.head_dim)
        self.W_k = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
        self.W_v = nn.Linear(embed_dim, num_kv_heads * self.head_dim)
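The class above omits the attention computation itself. One common way to implement it, sketched below as an assumption rather than any particular model's code, is to expand each key-value head across its query group with `repeat_interleave` and then proceed exactly as in standard attention:

```python
import torch

def gqa_attention(Q, K, V, group_size):
    # Q: (batch, num_q_heads, seq_len, head_dim)
    # K, V: (batch, num_kv_heads, seq_len, head_dim)
    # Each KV head serves group_size query heads, so copy it
    # group_size times along the head dimension.
    K = K.repeat_interleave(group_size, dim=1)
    V = V.repeat_interleave(group_size, dim=1)
    scores = Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5
    return torch.softmax(scores, dim=-1) @ V

# Assumed sizes: 32 query heads, 8 KV heads, so group_size = 4.
Q = torch.randn(2, 32, 5, 64)
K = torch.randn(2, 8, 5, 64)
V = torch.randn(2, 8, 5, 64)
out = gqa_attention(Q, K, V, group_size=4)
print(out.shape)  # torch.Size([2, 32, 5, 64])
```

The memory saving comes from the KV cache: only the 8 KV heads need to be stored per layer, while the expansion to 32 happens transiently at compute time.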
Multi-Head Cross-Attention
In encoder-decoder transformers, the decoder uses cross-attention to attend to the encoder's output. Here, queries come from the decoder, while keys and values come from the encoder. This enables the decoder to focus on relevant parts of the input when generating each output token.
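The query/key-value split can be sketched with illustrative shapes (a decoder sequence of assumed length 7 attending to an encoder sequence of assumed length 12):

```python
import torch

# Queries come from the decoder; keys and values from the encoder.
# Assumed sizes: batch=2, heads=8, head_dim=64,
# decoder length 7, encoder length 12.
q_dec = torch.randn(2, 8, 7, 64)   # decoder queries
k_enc = torch.randn(2, 8, 12, 64)  # encoder keys
v_enc = torch.randn(2, 8, 12, 64)  # encoder values

scores = q_dec @ k_enc.transpose(-2, -1) / 64 ** 0.5  # (2, 8, 7, 12)
weights = torch.softmax(scores, dim=-1)  # each decoder position distributes
                                         # attention over the 12 input positions
out = weights @ v_enc
print(out.shape)  # torch.Size([2, 8, 7, 64])
```

Note that the output has the decoder's sequence length: cross-attention produces one vector per decoder position, summarizing the relevant encoder states.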
Practical Considerations
- Number of heads — Must evenly divide the model dimension. Common choices: 8, 12, 16, 32, 64
- Head dimension — Typically 64 or 128. Smaller head dimensions with more heads can be better than fewer heads with larger dimensions
- KV cache — During autoregressive generation, cache key and value tensors to avoid recomputation
- Attention dropout — Apply dropout to attention weights during training to prevent overfitting
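The KV cache point can be sketched as follows; this is a toy illustration with assumed sizes, not a production cache:

```python
import torch

# Toy KV cache for autoregressive decoding. Assumed shapes:
# batch=1, heads=8, head_dim=64. At each step we append the new
# token's K/V instead of recomputing them for the whole prefix.
k_cache = torch.empty(1, 8, 0, 64)
v_cache = torch.empty(1, 8, 0, 64)

for step in range(3):
    k_new = torch.randn(1, 8, 1, 64)  # K for the newly generated token
    v_new = torch.randn(1, 8, 1, 64)  # V for the newly generated token
    k_cache = torch.cat([k_cache, k_new], dim=2)
    v_cache = torch.cat([v_cache, v_new], dim=2)
    # The new token's query then attends over k_cache / v_cache.

print(k_cache.shape)  # torch.Size([1, 8, 3, 64])
```

This is why GQA matters at inference time: the cache grows linearly with sequence length and with the number of KV heads, so fewer KV heads directly shrink it.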
The next lesson covers positional encoding — how transformers encode sequence order without any recurrence.
Lilly Tech Systems