Implementing Transformers
A hands-on guide to implementing a transformer, part of the Transformer Architecture Deep Dive.
Building a Transformer from Scratch
Implementing a transformer from scratch is one of the most effective ways to deeply understand the architecture. In this lesson, we will build a complete, working decoder-only transformer using PyTorch, covering every component from token embedding to text generation. This implementation follows the modern LLaMA-style architecture.
Model Configuration
```python
from dataclasses import dataclass

@dataclass
class TransformerConfig:
    vocab_size: int = 32000
    d_model: int = 512
    num_heads: int = 8
    num_layers: int = 6
    d_ff: int = 1376        # ~2/3 * 4 * d_model, rounded up, for SwiGLU
    max_seq_len: int = 2048
    dropout: float = 0.1
    rope_theta: float = 10000.0
```
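The `d_ff` value follows the LLaMA convention: take 2/3 of the standard 4x expansion (so SwiGLU's three matrices cost about the same as a two-matrix FFN), then round up to a hardware-friendly multiple. A short sketch of that calculation (the function name and the multiple of 32 are assumptions chosen to reproduce the value in the config; LLaMA exposes this as a configurable `multiple_of`):

```python
def swiglu_hidden_dim(d_model: int, multiple_of: int = 32) -> int:
    # 2/3 of the 4x expansion keeps the parameter count comparable
    # to a two-matrix FFN, since SwiGLU uses three matrices.
    hidden = int(2 * (4 * d_model) / 3)
    # Round up to the nearest multiple for efficient GPU kernels.
    return multiple_of * ((hidden + multiple_of - 1) // multiple_of)

print(swiglu_hidden_dim(512))  # 1376
```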
RMS Normalization
Modern transformers use RMSNorm instead of LayerNorm because it is faster (no mean computation) and equally effective:
```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale by the root-mean-square over the last dimension; no mean subtraction
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight
```
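A quick sanity check: with the default unit weight, the output of RMSNorm should itself have an RMS of (approximately) 1 along the last dimension, regardless of the input's scale:

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        return x / rms * self.weight

x = torch.randn(2, 4, 512) * 5.0   # arbitrary input scale
y = RMSNorm(512)(x)
out_rms = y.pow(2).mean(dim=-1).sqrt()
print(torch.allclose(out_rms, torch.ones_like(out_rms), atol=1e-3))  # True
```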
SwiGLU Feed-Forward Network
SwiGLU (Shazeer, 2020) replaces the standard ReLU FFN with a gated linear unit using the SiLU (Swish) activation. It uses three projection matrices instead of two but with a reduced hidden dimension to keep the parameter count similar:
```python
class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2 = nn.Linear(d_ff, d_model, bias=False)  # down projection
        self.w3 = nn.Linear(d_model, d_ff, bias=False)  # up projection

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))
```
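The "similar parameter count" claim can be verified directly: three matrices at the reduced hidden size 1376 come within about 1% of a standard two-matrix ReLU FFN with a 4x expansion:

```python
import torch
import torch.nn as nn

class SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))

d_model = 512
swiglu_params = sum(p.numel() for p in SwiGLU(d_model, 1376).parameters())
relu_ffn_params = 2 * d_model * (4 * d_model)   # two matrices, 4x expansion
print(swiglu_params, relu_ffn_params)  # 2113536 2097152
```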
Complete Transformer Block
```python
class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attn_norm = RMSNorm(config.d_model)
        self.attn = MultiHeadAttention(config.d_model, config.num_heads)
        self.ff_norm = RMSNorm(config.d_model)
        self.ff = SwiGLU(config.d_model, config.d_ff)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        # Pre-norm architecture: normalize before each sublayer, then add the residual
        h = x + self.dropout(self.attn(self.attn_norm(x), mask))
        out = h + self.dropout(self.ff(self.ff_norm(h)))
        return out
```
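The block assumes a `MultiHeadAttention` module covered earlier in this series. For completeness, here is one minimal sketch built on PyTorch's `scaled_dot_product_attention` (for a boolean `attn_mask`, `True` means "attend", matching the mask convention used by the model's forward pass). RoPE, which would consume `config.rope_theta`, is omitted here for brevity:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # (B, T, C) -> (B, num_heads, T, head_dim)
        q = q.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
        # Boolean attn_mask: True = attend, False = mask out
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)

x = torch.randn(1, 8, 512)
attn = MultiHeadAttention(512, 8)
print(attn(x).shape)  # torch.Size([1, 8, 512])
```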
Full Model Assembly
```python
class Transformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.num_layers)
        ])
        self.norm = RMSNorm(config.d_model)
        self.output = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: share parameters between embedding and output projection
        self.output.weight = self.tok_emb.weight

    def forward(self, tokens):
        x = self.tok_emb(tokens)
        # Create the causal mask on the same device as the input tokens
        seq_len = tokens.size(1)
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=tokens.device), diagonal=1
        ).bool()
        mask = ~mask  # True = attend, False = mask
        for layer in self.layers:
            x = layer(x, mask=mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits
```
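The causal mask logic can be checked in isolation: after inverting the upper-triangular matrix, each position attends only to itself and earlier positions, so row `i` contains exactly `i + 1` `True` entries:

```python
import torch

seq_len = 4
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
mask = ~mask  # True = attend, False = mask
print(mask)
# tensor([[ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False],
#         [ True,  True,  True,  True]])
```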
Text Generation
```python
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    model.eval()  # disable dropout during generation
    tokens = tokenizer.encode(prompt, return_tensors="pt")
    for _ in range(max_new_tokens):
        # Crop the context so it never exceeds the model's maximum length
        context = tokens[:, -model.config.max_seq_len:]
        logits = model(context)
        next_logits = logits[:, -1, :] / temperature
        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == tokenizer.eos_token_id:
            break
    return tokenizer.decode(tokens[0])
```
Training Loop Essentials
Key training considerations for transformers:
- Learning rate schedule — use linear warmup followed by cosine decay; warmup prevents instability early in training.
- Gradient clipping — clip the global gradient norm to 1.0 to prevent loss spikes.
- Mixed precision — train in bfloat16 for roughly a 2x speedup with minimal quality loss.
- Gradient accumulation — accumulate gradients over several micro-batches to simulate larger batch sizes on limited GPU memory.
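The first, second, and fourth points can be sketched together in a compact loop (the `Linear` stand-in model, loss, and step counts are placeholders for illustration; the warmup/decay function and `accum_steps` value are assumptions, not prescriptions):

```python
import math
import torch

def lr_lambda(step, warmup_steps=100, total_steps=1000, min_ratio=0.1):
    # Linear warmup to the peak LR, then cosine decay down to min_ratio of peak
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    cosine = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    return min_ratio + (1.0 - min_ratio) * cosine

model = torch.nn.Linear(8, 8)  # stand-in for the transformer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

accum_steps = 4
for step in range(8):
    loss = model(torch.randn(2, 8)).pow(2).mean() / accum_steps
    loss.backward()                       # gradients accumulate across micro-batches
    if (step + 1) % accum_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
```

Dividing the loss by `accum_steps` keeps the accumulated gradient equal to what a single large batch would have produced.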
This completes the Transformer Architecture Deep Dive. You now understand every component of the transformer, from attention mechanisms to modern variants and practical implementation. Apply this knowledge to evaluate, customize, and build transformer-based systems for your specific use cases.