
Implementing Transformers

A hands-on guide to implementing a transformer, part of the Transformer Architecture Deep Dive.

Building a Transformer from Scratch

Implementing a transformer from scratch is one of the most effective ways to understand the architecture deeply. In this lesson, we will build a complete, working decoder-only transformer in PyTorch, covering every component from token embedding to text generation. The implementation follows the modern LLaMA-style recipe: pre-norm RMSNorm, a SwiGLU feed-forward network, and rotary position embeddings (the rope_theta setting in the config).

Model Configuration

from dataclasses import dataclass

@dataclass
class TransformerConfig:
    vocab_size: int = 32000
    d_model: int = 512
    num_heads: int = 8
    num_layers: int = 6
    d_ff: int = 1376  # about 2/3 * 4 * d_model (~1365), rounded up to a multiple of 32 for SwiGLU
    max_seq_len: int = 2048
    dropout: float = 0.1
    rope_theta: float = 10000.0
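The default d_ff of 1376 is consistent with taking 2/3 of the classic 4 * d_model width (about 1365 for d_model = 512) and rounding up to a multiple of 32. That rounding rule is our inference from the number, not something the lesson states. A minimal sketch of the calculation, using a hypothetical helper swiglu_hidden_dim, is below; LLaMA's reference code applies the same idea with a multiple of 256, which gives 11008 for its 4096-wide 7B model.

```python
def swiglu_hidden_dim(d_model: int, multiple_of: int = 32) -> int:
    """Hypothetical helper: SwiGLU hidden width from d_model."""
    # Start from the classic 4 * d_model FFN width and keep 2/3 of it, so the
    # three SwiGLU matrices cost about as much as two classic FFN matrices.
    hidden = int(2 * (4 * d_model) / 3)
    # Round up to the next multiple of `multiple_of`.
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)


print(swiglu_hidden_dim(512))                    # 1376
print(swiglu_hidden_dim(512, multiple_of=256))   # 1536
print(swiglu_hidden_dim(4096, multiple_of=256))  # 11008
```

Rounding to a multiple of a power of two adds a few parameters but tends to keep matrix shapes friendly to GPU kernels.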

RMS Normalization

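Modern transformers such as LLaMA use RMSNorm instead of LayerNorm. RMSNorm skips LayerNorm's mean-subtraction step and bias and rescales each vector only by its root mean square, which makes it slightly cheaper while performing comparably in practice: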

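import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-channel gain

    def forward(self, x):
        # Divide by the root mean square over the last dimension (no mean subtraction)
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight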
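To see what dropping the mean-centering step means in practice, the sketch below compares a functional RMSNorm without the learned gain (rms_norm is a hypothetical helper) against torch.nn.functional.layer_norm with no affine parameters. On arbitrary inputs the two differ, but once each row has zero mean, LayerNorm reduces to RMSNorm; every RMSNorm output row also ends up with a root mean square of about 1.

```python
import torch
import torch.nn.functional as F


def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Scale by the root mean square only: no mean subtraction, no bias, no gain.
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)


torch.manual_seed(0)
x = torch.randn(4, 512)

# On arbitrary inputs the two normalizations differ...
y_rms = rms_norm(x)
y_ln = F.layer_norm(x, (512,), eps=1e-6)

# ...but after mean-centering each row, LayerNorm and RMSNorm coincide.
x_centered = x - x.mean(dim=-1, keepdim=True)
same = torch.allclose(
    rms_norm(x_centered), F.layer_norm(x_centered, (512,), eps=1e-6), atol=1e-5
)

# Each RMSNorm output row has unit root mean square (up to eps).
row_rms = y_rms.pow(2).mean(dim=-1).sqrt()
print(same, row_rms)
```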

SwiGLU Feed-Forward Network

SwiGLU (Shazeer, 2020) replaces the standard ReLU FFN with a gated linear unit that uses the SiLU (Swish) activation. It uses three projection matrices instead of two, but shrinks the hidden dimension to about 2/3 of the usual 4 * d_model so that the parameter count stays roughly the same:

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
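        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection (passed through SiLU)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection back to d_model
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # value projection, multiplied by the gate

    def forward(self, x):
        # w2(SiLU(w1(x)) * w3(x)): element-wise product of the gate and value branches
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))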
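A quick parameter count makes the "roughly the same parameter count" claim concrete. This sketch assumes bias-free projections, as in the class above, and d_model = 512; it compares a classic two-matrix ReLU FFN of width 4 * d_model with three SwiGLU-shaped matrices of width 1376 (count is a hypothetical helper).

```python
import torch.nn as nn

d_model, d_ff_relu, d_ff_swiglu = 512, 4 * 512, 1376

# Classic FFN: d_model -> 4 * d_model -> d_model, two matrices
relu_ffn = nn.Sequential(
    nn.Linear(d_model, d_ff_relu, bias=False),
    nn.ReLU(),
    nn.Linear(d_ff_relu, d_model, bias=False),
)

# SwiGLU: three matrices at the reduced width (w1 and w3 up, w2 down)
swiglu_mats = nn.ModuleList([
    nn.Linear(d_model, d_ff_swiglu, bias=False),  # w1 (gate)
    nn.Linear(d_model, d_ff_swiglu, bias=False),  # w3 (value)
    nn.Linear(d_ff_swiglu, d_model, bias=False),  # w2 (down projection)
])


def count(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


print(count(relu_ffn))     # 2097152
print(count(swiglu_mats))  # 2113536
```

Both land near 8 * d_model^2, because 3 * (2/3 * 4 * d_model) equals 2 * (4 * d_model); the small surplus comes from rounding 1365 up to 1376.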
💡
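SwiGLU vs ReLU: In Shazeer's (2020) experiments, SwiGLU and related GLU variants reached lower perplexity than ReLU and GELU feed-forward layers at a comparable parameter count, and SwiGLU has since been adopted by models such as PaLM and LLaMA. The gating mechanism (the element-wise product of SiLU(w1(x)) and w3(x)) lets the network selectively pass information, loosely analogous to LSTM gates.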

Complete Transformer Block

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_norm = RMSNorm(config.d_model)
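        # MultiHeadAttention is not defined in this section; it is assumed to be
        # causal self-attention taking (x, mask), e.g. from an earlier lesson
        self.attn = MultiHeadAttention(
            config.d_model, config.num_heads
        )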
        self.ff_norm = RMSNorm(config.d_model)
        self.ff = SwiGLU(config.d_model, config.d_ff)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        # Pre-norm architecture
        h = x + self.dropout(self.attn(self.attn_norm(x), mask))
        out = h + self.dropout(self.ff(self.ff_norm(h)))
        return out

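TransformerBlock relies on a MultiHeadAttention module that this section does not define (presumably it comes from an earlier lesson), and nothing shown here consumes config.rope_theta. The sketch below is one way to fill that gap under stated assumptions: it matches the MultiHeadAttention(d_model, num_heads) constructor and attn(x, mask) call used above, treats the boolean mask as True = attend, applies rotary position embeddings (RoPE) to queries and keys by rotating adjacent channel pairs, and hands the softmax attention to torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0 or later). The rope_theta argument and the helpers rope_cache and apply_rope are hypothetical names.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def rope_cache(seq_len, head_dim, theta=10000.0, device=None):
    # One rotation frequency per channel pair: theta^(-2i / head_dim)
    inv_freq = 1.0 / (theta ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))
    positions = torch.arange(seq_len, device=device).float()
    angles = torch.outer(positions, inv_freq)  # (seq_len, head_dim // 2)
    return angles.cos(), angles.sin()


def apply_rope(x, cos, sin):
    # x: (batch, heads, seq_len, head_dim); rotate each (even, odd) channel pair
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    rot_even = x_even * cos - x_odd * sin
    rot_odd = x_even * sin + x_odd * cos
    # Re-interleave the rotated pairs back into head_dim channels
    return torch.stack((rot_even, rot_odd), dim=-1).flatten(-2)


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, rope_theta=10000.0):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.rope_theta = rope_theta
        self.wq = nn.Linear(d_model, d_model, bias=False)
        self.wk = nn.Linear(d_model, d_model, bias=False)
        self.wv = nn.Linear(d_model, d_model, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)  # output projection

    def forward(self, x, mask=None):
        b, t, _ = x.shape
        # Project, then split into heads: (batch, heads, seq_len, head_dim)
        q, k, v = (
            proj(x).view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
            for proj in (self.wq, self.wk, self.wv)
        )
        cos, sin = rope_cache(t, self.head_dim, self.rope_theta, x.device)
        cos, sin = cos.to(q.dtype), sin.to(q.dtype)
        # Position information enters only through the rotation of queries and keys
        q, k = apply_rope(q, cos, sin), apply_rope(k, cos, sin)
        # Boolean attn_mask: True = this query may attend to this key
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        # Merge heads back to (batch, seq_len, d_model)
        out = out.transpose(1, 2).reshape(b, t, -1)
        return self.wo(out)
```

Because each query and key is rotated by an angle proportional to its position, their dot product depends only on the relative offset, which is why RoPE needs no learned position table. A block built on this sketch could pass config.rope_theta as the third constructor argument.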
Full Model Assembly

class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])
        self.norm = RMSNorm(config.d_model)
        self.output = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the (vocab_size, d_model) embedding matrix
        self.output.weight = self.tok_emb.weight

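    def forward(self, tokens):
        # tokens: (batch, seq_len) integer token IDs
        seq_len = tokens.size(1)
        assert seq_len <= self.config.max_seq_len, "sequence longer than max_seq_len"
        x = self.tok_emb(tokens)
        # Causal mask on the same device as the input: True = attend, False = masked
        mask = torch.tril(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=tokens.device)
        )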

        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits
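Two details of the model can be checked in isolation. The sketch below (TiedHead and n_params are hypothetical names) shows that tying the output projection to the embedding removes vocab_size * d_model parameters, because Module.parameters() yields a shared Parameter only once. It also confirms that the complement of a strict upper triangle, ~torch.triu(..., diagonal=1), is the same causal mask as torch.tril.

```python
import torch
import torch.nn as nn

vocab_size, d_model = 32000, 512


class TiedHead(nn.Module):
    def __init__(self, tie: bool):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.output = nn.Linear(d_model, vocab_size, bias=False)
        if tie:
            # Both weights have shape (vocab_size, d_model), so one Parameter can serve both
            self.output.weight = self.tok_emb.weight


def n_params(module: nn.Module) -> int:
    # parameters() deduplicates shared Parameters by default
    return sum(p.numel() for p in module.parameters())


untied, tied = TiedHead(tie=False), TiedHead(tie=True)
print(n_params(untied) - n_params(tied))  # 16384000 = vocab_size * d_model

# Causal mask: row i may attend to columns 0..i (True = attend)
seq_len = 4
mask = ~torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print(torch.equal(mask, torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))))  # True
```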

Text Generation

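@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    model.eval()  # disable dropout while sampling
    device = next(model.parameters()).device
    tokens = tokenizer.encode(prompt, return_tensors="pt").to(device)
    for _ in range(max_new_tokens):
        # Crop to the context window; no KV cache, so the full context is recomputed each step
        logits = model(tokens[:, -model.config.max_seq_len:])
        # Only the last position's logits predict the next token
        next_logits = logits[:, -1, :] / temperature
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # shape (1, 1)
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(tokens[0], skip_special_tokens=True)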
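Temperature rescales the logits before the softmax: values below 1 sharpen the distribution toward the most likely token, while values above 1 flatten it. The sketch below uses made-up logits for a three-token vocabulary to show the effect, then samples one token the same way generate() does.

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.0])  # made-up next-token logits

for temperature in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / temperature, dim=-1)
    print(temperature, [round(p, 3) for p in probs.tolist()])
# Top-token probability: about 0.867 at T=0.5, 0.665 at T=1.0, 0.506 at T=2.0

# Sampling as in generate(): one draw per row, result has shape (batch, 1)
torch.manual_seed(0)
probs = torch.softmax(logits / 0.8, dim=-1).unsqueeze(0)  # (1, vocab)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token.shape)  # torch.Size([1, 1])
```

As the temperature approaches 0, sampling approaches greedy (argmax) decoding; dividing by exactly 0 is undefined, so a temperature of 0 has to be handled as a separate argmax case.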

Training Loop Essentials

Key training considerations for transformers:

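  • Learning rate schedule: linearly warm up the learning rate, then decay it along a cosine curve. Warmup helps avoid the instability that large updates cause early in training.
  • Gradient clipping: clip the global gradient norm (a maximum of 1.0 is a common choice) so that occasional gradient spikes do not destabilize training.
  • Mixed precision: train in bfloat16, which can substantially speed up training and reduce activation memory on hardware with bf16 support, usually with minimal quality loss. Because bfloat16 keeps float32's exponent range, it generally does not need the loss scaling that float16 requires.
  • Gradient accumulation: accumulate gradients over several micro-batches before each optimizer step to simulate a larger batch size on limited GPU memory.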
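The four practices above fit into a single optimizer step. This is a minimal sketch, not a prescribed recipe: lr_at and train_step are hypothetical helpers, the learning-rate constants are placeholder values, and model is assumed to map (batch, seq_len) token IDs to (batch, seq_len, vocab_size) logits, as the Transformer above does. Each element of micro_batches is a token tensor; gradients from all of them are accumulated before one clipped update computed under bf16 autocast.

```python
import math

import torch
import torch.nn.functional as F


def lr_at(step, max_lr=3e-4, min_lr=3e-5, warmup_steps=100, total_steps=1000):
    """Linear warmup to max_lr, then cosine decay to min_lr (placeholder defaults)."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def train_step(model, optimizer, micro_batches, step):
    """One optimizer step over several micro-batches (gradient accumulation)."""
    for group in optimizer.param_groups:
        group["lr"] = lr_at(step)
    optimizer.zero_grad(set_to_none=True)
    n = len(micro_batches)
    total = 0.0
    for tokens in micro_batches:
        # bf16 autocast: eligible ops such as matmuls run in bfloat16; parameters stay float32
        with torch.autocast(device_type=tokens.device.type, dtype=torch.bfloat16):
            logits = model(tokens[:, :-1])  # predict token t+1 from tokens up to t
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)).float(), tokens[:, 1:].reshape(-1)
            )
        # Dividing by n makes the accumulated gradient the mean over micro-batches
        (loss / n).backward()
        total += loss.item() / n
    # Rescale gradients so their global L2 norm is at most 1.0
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    return total
```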
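
Weight initialization matters: poor initialization can cause training to diverge. A common scheme (used in GPT-2) scales the initial weights of the layers that write into the residual stream (the attention output projection and the FFN down-projection) by 1/sqrt(2 * num_layers), since each block adds two residual branches. This keeps the residual stream from growing too large in deep networks.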
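A sketch of that scaled initialization follows, assuming GPT-2's convention of a normal distribution with std 0.02 for all other weights. The function init_weights is hypothetical, and it identifies residual output projections by module name: ff.w2 matches the SwiGLU class above, while attn.wo assumes an attention module whose output projection is named wo, which this section does not define.

```python
import math

import torch.nn as nn


def init_weights(model: nn.Module, num_layers: int, std: float = 0.02) -> None:
    """Normal(0, std) init; residual output projections get std / sqrt(2 * num_layers)."""
    # Each block adds two residual branches (attention and FFN), hence 2 * num_layers
    residual_std = std / math.sqrt(2 * num_layers)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            # Hypothetical naming convention for layers that write into the residual stream
            writes_residual = name.endswith("attn.wo") or name.endswith("ff.w2")
            nn.init.normal_(module.weight, mean=0.0, std=residual_std if writes_residual else std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)
```

For the Transformer above, this would be called once after construction, for example init_weights(model, config.num_layers).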

This completes the Transformer Architecture Deep Dive. You now understand every component of the transformer, from attention mechanisms to modern variants and practical implementation. Apply this knowledge to evaluate, customize, and build transformer-based systems for your specific use cases.