Intermediate

Self-Attention Mechanism

A comprehensive guide to the self-attention mechanism at the heart of the transformer architecture.

Understanding Self-Attention

Self-attention is the mechanism that allows each token in a sequence to compute a weighted representation of all other tokens. Unlike cross-attention (where queries come from one sequence and keys/values from another), in self-attention, the queries, keys, and values all come from the same sequence. This enables the model to learn relationships between different positions within a single input.

Consider the sentence "The cat sat on the mat because it was tired." For a human, it is obvious that "it" refers to "the cat." Self-attention enables the model to learn this kind of relationship by computing how strongly each token should attend to every other token.

Queries, Keys, and Values

The self-attention mechanism is built on three learned linear projections of the input:

  • Query (Q) — Represents "what am I looking for?" Each token generates a query vector that describes what information it needs from other tokens.
  • Key (K) — Represents "what do I contain?" Each token generates a key vector that describes what information it offers to other tokens.
  • Value (V) — Represents "what information do I provide?" When a query matches a key, the corresponding value is what gets passed along.

Think of it like a library search. The query is your search term, the keys are the book titles/descriptions, and the values are the actual book contents. The attention score between a query and a key determines how much of that book's content you read.

import torch
import torch.nn as nn

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.embed_dim = embed_dim
        # Learned linear projections
        self.W_q = nn.Linear(embed_dim, embed_dim)
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)
        self.scale = embed_dim ** 0.5

    def forward(self, x, mask=None):
        # x shape: (batch, seq_len, embed_dim)
        Q = self.W_q(x)  # (batch, seq_len, embed_dim)
        K = self.W_k(x)  # (batch, seq_len, embed_dim)
        V = self.W_v(x)  # (batch, seq_len, embed_dim)

        # Compute attention scores
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # scores shape: (batch, seq_len, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, V)
        return output, weights
💡 Why scale by sqrt(d_k)? When the dimension d_k is large, the dot products between queries and keys can grow very large in magnitude. This pushes the softmax function into regions where gradients are extremely small, making learning very slow. Dividing by sqrt(d_k) keeps the variance of the dot products at approximately 1, regardless of dimension.
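This effect is easy to verify empirically. The sketch below uses random unit-variance vectors as a stand-in for real learned queries and keys, and compares the variance of the dot products with and without the sqrt(d_k) scaling:

```python
import torch

torch.manual_seed(0)
d_k = 512
# Random unit-variance vectors standing in for learned Q and K rows
q = torch.randn(10000, d_k)
k = torch.randn(10000, d_k)

raw = (q * k).sum(dim=-1)        # unscaled dot products
scaled = raw / (d_k ** 0.5)      # scaled as in the attention formula

print(raw.var().item())     # close to d_k (about 512)
print(scaled.var().item())  # close to 1
```

Since each of the d_k element-wise products has variance 1, their sum has variance d_k; dividing by sqrt(d_k) divides the variance by d_k, bringing it back to roughly 1.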

The Attention Matrix

The attention matrix (after softmax) is a square matrix of size (seq_len x seq_len) where each row sums to 1. Entry (i, j) represents how much token i attends to token j. This matrix is incredibly informative for understanding what the model has learned.
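A quick sanity check makes these properties concrete. The sketch below computes dot-product attention directly on raw random embeddings (skipping the learned projections for simplicity) and confirms the shape and the row-sum property:

```python
import torch

torch.manual_seed(0)
seq_len, d = 5, 8
x = torch.randn(1, seq_len, d)

# Dot-product attention on raw embeddings (no learned projections),
# just to inspect the attention matrix itself
scores = x @ x.transpose(-2, -1) / d ** 0.5
weights = torch.softmax(scores, dim=-1)

print(weights.shape)        # (1, seq_len, seq_len), i.e. (1, 5, 5)
print(weights.sum(dim=-1))  # every row sums to 1
```

Entry `weights[0, i, j]` is how much token i attends to token j; because softmax is applied along the last dimension, each row is a probability distribution over the tokens.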

Attention Patterns

Research has identified common patterns in learned attention matrices:

  • Diagonal pattern — Tokens attend primarily to themselves (local self-reference)
  • Vertical stripes — All tokens attend to a specific important token (like the CLS token or punctuation)
  • Block diagonal — Tokens attend to nearby tokens (local context)
  • Coreference pattern — Pronouns attend strongly to their referent nouns
  • Syntactic pattern — Tokens attend to syntactically related tokens (subject-verb, modifier-noun)

Computational Complexity

Self-attention computes pairwise interactions between all tokens, resulting in O(n^2) time and space complexity where n is the sequence length:

  • Memory — The attention matrix requires storing n^2 values per head per layer
  • Computation — Two matrix multiplications, each costing O(n^2 * d): computing QK^T multiplies (n x d) by (d x n), and applying the weights to V multiplies (n x n) by (n x d)
  • Practical limits — Standard transformers struggle with sequences longer than 2048-8192 tokens
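The quadratic growth is easy to see with concrete numbers. The helper below (a hypothetical name, using the 32-head, float32 configuration discussed later in this section) computes the attention-matrix memory for one layer at a few sequence lengths:

```python
# Rough memory for one layer's attention matrices, assuming float32
# (4 bytes per element) and a hypothetical 32-head configuration
def attn_matrix_bytes(seq_len: int, num_heads: int = 32, bytes_per_elem: int = 4) -> int:
    return seq_len * seq_len * num_heads * bytes_per_elem

for n in (512, 2048, 8192):
    gib = attn_matrix_bytes(n) / 2**30
    print(f"seq_len={n:>5}: {gib:.3f} GiB per layer")
```

Doubling the sequence length quadruples the memory: 2048 tokens need 0.5 GiB per layer under these assumptions, while 8192 tokens need 8 GiB.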

Efficient Attention Alternatives

Several approaches reduce the quadratic complexity:

  1. Sparse attention (BigBird, Longformer) — Only compute attention for a subset of token pairs
  2. Linear attention (Performer, Linear Transformer) — Approximate softmax attention with linear-time kernel methods
  3. Flash Attention — Hardware-aware algorithm that reduces memory I/O, not asymptotic complexity
  4. Sliding window (Mistral) — Each token attends only to a fixed window of nearby tokens
# Flash Attention usage with PyTorch 2.0+
import torch.nn.functional as F

# Fused attention: PyTorch automatically dispatches to Flash Attention
# when the hardware and dtypes support it. Note that attn_mask and
# is_causal=True are mutually exclusive -- pass one or the other, not both.
attn_output = F.scaled_dot_product_attention(
    query, key, value,
    is_causal=True  # for decoder-style autoregressive models
)

Memory bottleneck: For a sequence of length 4096 with 32 attention heads, storing the attention matrices alone requires 4096 x 4096 x 32 x 4 bytes = 2 GB per layer. With 32 layers, that is 64 GB just for attention maps. Flash Attention addresses this by never materializing the full attention matrix.
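Of the alternatives listed above, sliding-window attention is the simplest to express as a mask. The sketch below builds a causal sliding-window mask with an illustrative window of 3 (production models use much larger windows):

```python
import torch

seq_len, window = 8, 3
i = torch.arange(seq_len).unsqueeze(1)  # query positions (rows)
j = torch.arange(seq_len).unsqueeze(0)  # key positions (columns)

# Token i may attend to token j only if j is in the causal window
# (i - window, i]: no future tokens, and at most `window` past positions
mask = (j <= i) & (j > i - window)
print(mask.int())
```

Each row of the mask has at most `window` ones, so the number of attended pairs grows as O(n * window) instead of O(n^2).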

Masked Self-Attention

In autoregressive models (like GPT), each token should only attend to previous tokens, not future ones. This is enforced with a causal mask, which sets the strictly upper-triangular entries of the score matrix to negative infinity before the softmax, preventing information flow from future positions.
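A minimal sketch of building and applying such a mask, using the same mask == 0 convention as the SelfAttention class above:

```python
import torch

seq_len = 5
# Lower-triangular mask: 1 where attention is allowed (current and past
# positions), 0 for future positions
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

scores = torch.randn(seq_len, seq_len)
scores = scores.masked_fill(causal_mask == 0, float('-inf'))
weights = torch.softmax(scores, dim=-1)

# Row i now has nonzero weight only on positions 0..i;
# in particular, row 0 puts all its weight on token 0
print(weights)
```

The -inf entries become exactly zero after the softmax, so each row remains a valid probability distribution over the visible (non-future) tokens.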

In the next lesson, we explore multi-head attention, which allows the model to jointly attend to information from different representation subspaces at different positions.