Advanced

Training & Optimization

12 interview questions on the practical techniques that separate working models from failing ones. These questions test whether you have actually trained models in production, not just read about them.

Q1: What is learning rate warmup and why is it important for Transformers?

A

Learning rate warmup starts training with a very small learning rate and linearly increases it to the target learning rate over a fixed number of steps (typically 1-10% of total training steps).

Why it helps Transformers: Early in training, Adam's second-moment estimate (v_t) is built from only a handful of gradients, so the per-parameter step size lr / (sqrt(v_t) + eps) is noisy and can be far too large for some parameters. The resulting unstable updates can make training diverge or push the model into a region it is slow to recover from. Warmup keeps the steps small until the moment estimates have stabilized.

Common schedule: Linear warmup (for example, 2000 steps), then cosine decay to 0 or to a small fraction of the peak rate. Warmup followed by a decay schedule is the default in most Transformer training recipes.

import math

import torch
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    """Linear warmup, then cosine decay to zero: a common LLM schedule"""
    def lr_lambda(current_step):
        # Warmup: scale the base LR linearly from 0 to 1
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        # Decay: progress runs from 0 to 1 and is clamped so the LR
        # does not rise again if training continues past total_steps
        decay_steps = max(1, total_steps - warmup_steps)
        progress = min(1.0, (current_step - warmup_steps) / decay_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return LambdaLR(optimizer, lr_lambda)

# Usage
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps=2000, total_steps=100000)

for step, batch in enumerate(dataloader):
    loss = model(batch)
    loss.backward()
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()
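
To make the shape of the schedule concrete, the sketch below evaluates the same warmup-plus-cosine multiplier in plain Python at a few steps. The helper name lr_multiplier and the peak rate of 1e-4 are illustrative choices that mirror the usage above, not part of any library.

```python
import math

def lr_multiplier(step, warmup_steps=2000, total_steps=100_000):
    """Multiplier applied to the peak learning rate at a given step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

peak_lr = 1e-4  # assumed peak rate, matching the usage example above
for step in (0, 1000, 2000, 51_000, 100_000):
    print(f"step {step:>6}: lr = {peak_lr * lr_multiplier(step):.2e}")
```

The multiplier is 0 at step 0, reaches 1 at the end of warmup, passes through 0.5 halfway through the decay phase, and returns to 0 at the final step.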

Q2: Explain mixed precision training. How does it work and what are the pitfalls?

A

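Mixed precision runs the forward and backward passes in FP16 (or BF16) while keeping an FP32 copy of the weights for the optimizer update. Half-precision tensors take half the memory of FP32, and matrix multiplications run much faster on GPUs with Tensor Cores: end-to-end training speedups are typically in the 2-4x range, and peak matmul throughput gains can reach roughly 8x on some hardware.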

How it works:

  1. Maintain FP32 master weights
  2. Cast to FP16 for forward and backward passes
  3. Scale the loss before backward to prevent gradient underflow (loss scaling)
  4. Unscale gradients and update FP32 master weights

Pitfalls: FP16 has a small dynamic range (6e-8 to 65504). Very small gradients can underflow to zero. Loss scaling multiplies the loss by a large factor (e.g., 1024) before backward, then divides gradients by the same factor. If gradients overflow (become inf), the step is skipped and the scale factor is reduced.
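
The underflow and overflow behaviour can be reproduced without a GPU. This is a minimal sketch that uses the half-precision format code of Python's struct module to round values to FP16; the helper to_fp16, the gradient value 1e-8, and the scale 1024 are illustrative assumptions.

```python
import struct

def to_fp16(x):
    """Round a Python float to the nearest FP16 value (inf on overflow)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except (OverflowError, struct.error):
        return float("inf") if x > 0 else float("-inf")

tiny_grad = 1e-8      # hypothetical gradient value
loss_scale = 1024.0   # hypothetical loss scale

unscaled = to_fp16(tiny_grad)              # too small for FP16: becomes 0.0
scaled = to_fp16(tiny_grad * loss_scale)   # representable after scaling
recovered = scaled / loss_scale            # unscale in higher precision
overflowed = to_fp16(70000.0)              # above the FP16 maximum of 65504

print(unscaled, recovered, overflowed)
```

Scaling moves the small gradient into the representable range, and dividing by the same factor afterwards recovers a value close to the original.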

BF16 vs FP16: BF16 has the same exponent range as FP32 (preventing overflow/underflow) but less precision. It does not need loss scaling. Preferred when available (A100, H100 GPUs).
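
The range-versus-precision trade-off can be sketched by emulating BF16 as the top 16 bits of an FP32 encoding. The helper to_bf16 is a hypothetical illustration that rounds to nearest-even and ignores NaN and infinity edge cases.

```python
import struct

def to_bf16(x):
    """Round a float to BF16 by keeping the top 16 bits of its FP32 encoding."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)   # round to nearest, ties to even
    bits &= 0xFFFF0000                    # drop the low 16 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

small = to_bf16(1e-8)      # stays non-zero: same exponent range as FP32
large = to_bf16(70000.0)   # stays finite, but lands on a coarse grid
near_one = to_bf16(1.001)  # rounds to 1.0: only 7 mantissa bits remain

print(small, large, near_one)
```

Values that underflow or overflow in FP16 survive in BF16, at the cost of coarser rounding.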

import torch
from torch.cuda.amp import autocast, GradScaler

# Mixed precision training with PyTorch
scaler = GradScaler()

for batch in dataloader:
    optimizer.zero_grad()

    # Forward pass in FP16
    with autocast(dtype=torch.float16):
        outputs = model(batch['input'])
        loss = criterion(outputs, batch['target'])

    # Backward pass: scaler handles loss scaling
    scaler.scale(loss).backward()

    # Unscale gradients, clip, then step
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)
    scaler.update()

# BF16 (simpler - no scaler needed)
for batch in dataloader:
    optimizer.zero_grad()
    with autocast(dtype=torch.bfloat16):
        outputs = model(batch['input'])
        loss = criterion(outputs, batch['target'])
    loss.backward()
    optimizer.step()
Q3: What is gradient accumulation? When and how do you use it?

A

Problem: You want to train with batch_size=256 but your GPU can only fit batch_size=32. Large batch sizes often give better gradient estimates and more stable training.

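Solution: Accumulate gradients over multiple mini-batches before updating weights. Process 8 mini-batches of 32, accumulate the gradients, then do one optimizer step. For a loss averaged over examples this is mathematically equivalent to one batch of 256, although layers that use batch statistics (e.g., BatchNorm) still see only 32 examples at a time.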

Key detail: Divide the loss by the accumulation steps (or average over them) to keep the gradient magnitude correct.

accumulation_steps = 8
optimizer.zero_grad()

for i, batch in enumerate(dataloader):
    outputs = model(batch['input'])
    loss = criterion(outputs, batch['target'])
    loss = loss / accumulation_steps  # Normalize loss
    loss.backward()                   # Gradients accumulate

    if (i + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

# Effective batch size = micro_batch * accumulation_steps * num_gpus
# 32 * 8 * 4 GPUs = 1024 effective batch size
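
The equivalence between accumulated and full-batch gradients can be checked numerically. The sketch below uses a one-parameter linear model with a mean-squared-error loss; the data, the weight value, and the helper mse_grad are made up for illustration.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5
micro_batch, accumulation_steps = 2, 4

# Gradient computed on the full batch of 8 examples
full_grad = mse_grad(w, xs, ys)

# Same gradient, built from 4 micro-batches of 2 examples each
accumulated = 0.0
for i in range(accumulation_steps):
    lo, hi = i * micro_batch, (i + 1) * micro_batch
    # Dividing by accumulation_steps mirrors `loss = loss / accumulation_steps`
    accumulated += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

print(full_grad, accumulated)
```

Without the division the accumulated gradient would be accumulation_steps times too large, which acts like an unintended increase in learning rate.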

Q4: What data augmentation techniques are used for images vs. text?

A

Image augmentation:

  • Geometric: Random crop, horizontal flip, rotation, affine transforms
  • Color: Brightness, contrast, saturation, hue jitter
  • Advanced: Mixup (blend two images and their labels), CutMix (paste a patch from one image onto another), RandAugment (randomly apply N augmentations from a pool), AutoAugment (learned augmentation policies)
  • Regularization: Cutout/Random Erasing (zero out a random rectangle)

Text augmentation:

  • Token-level: Synonym replacement, random insertion, random swap, random deletion
  • Sentence-level: Back-translation (translate to another language and back)
  • For LLMs: Paraphrasing with a different model, instruction rephrasing
  • For embedding models: Dropout as augmentation (pass same text twice with different dropout masks)

import torchvision.transforms as T

# Standard ImageNet augmentation pipeline
train_transform = T.Compose([
    T.RandomResizedCrop(224, scale=(0.08, 1.0)),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    T.RandAugment(num_ops=2, magnitude=9),  # 2 random augmentations
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    T.RandomErasing(p=0.25),  # Cutout-style
])

# Mixup implementation
def mixup(x, y, alpha=0.2):
    lam = torch.distributions.Beta(alpha, alpha).sample()
    idx = torch.randperm(x.size(0))
    mixed_x = lam * x + (1 - lam) * x[idx]
    return mixed_x, y, y[idx], lam

# Usage in training loop
mixed_x, y_a, y_b, lam = mixup(images, labels)
outputs = model(mixed_x)
loss = lam * criterion(outputs, y_a) + (1 - lam) * criterion(outputs, y_b)
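
The token-level text augmentations listed above are simple enough to sketch in plain Python. The function names, probabilities, and example sentence here are illustrative assumptions; synonym replacement is omitted because it needs a thesaurus resource.

```python
import random

def random_deletion(tokens, p=0.1, rng=random):
    """Drop each token with probability p, always keeping at least one."""
    kept = [t for t in tokens if rng.random() > p]
    return kept if kept else [rng.choice(tokens)]

def random_swap(tokens, n_swaps=1, rng=random):
    """Swap two randomly chosen positions, n_swaps times."""
    tokens = list(tokens)
    if len(tokens) < 2:
        return tokens
    for _ in range(n_swaps):
        i, j = rng.sample(range(len(tokens)), 2)
        tokens[i], tokens[j] = tokens[j], tokens[i]
    return tokens

rng = random.Random(0)  # seeded generator for reproducible augmentations
sentence = "the quick brown fox jumps over the lazy dog".split()
print(random_deletion(sentence, p=0.2, rng=rng))
print(random_swap(sentence, n_swaps=2, rng=rng))
```

Both operations can change the meaning of short texts, so it is worth spot-checking augmented samples before training on them.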

Q5: What is knowledge distillation? When and how do you use it?

A

Knowledge distillation trains a small "student" model to mimic a large "teacher" model. The student learns from the teacher's softened output probabilities (computed from its logits) rather than just the hard labels.

Why soft labels are better: A teacher's output like [0.7, 0.2, 0.1] for [cat, dog, tiger] encodes that this image is a cat that looks more like a dog than a tiger. Hard labels [1, 0, 0] lose this information. The student learns richer representations from these "dark knowledge" signals.

Loss function: L = alpha * T^2 * KL(softmax(teacher_logits/T) || softmax(student_logits/T)) + (1-alpha) * CE(student_logits, hard_labels). Temperature T (typically 2-20) softens both distributions so the teacher's small probabilities on non-target classes carry more signal. The T^2 factor keeps the soft-loss gradients on the same scale as the hard-loss gradients.
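
The effect of the temperature is easy to see on a single example. This sketch applies a temperature-scaled softmax to made-up teacher logits; the logit values and the helper name are assumptions for illustration.

```python
import math

def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, computed in a numerically stable way."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

teacher_logits = [5.0, 2.0, 0.5]  # hypothetical logits for [cat, dog, tiger]
for T in (1.0, 4.0, 10.0):
    print(T, [round(p, 3) for p in softmax(teacher_logits, T)])
```

As T grows, probability mass moves from the top class to the others, so the relative ranking of the non-target classes becomes a larger part of the training signal.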

Use cases: Deploy a small model (DistilBERT is 40% smaller, 60% faster, retains 97% of BERT's performance), on-device inference, reduce serving costs.

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels,
                      temperature=4.0, alpha=0.7):
    """Knowledge distillation loss"""
    # Soft target loss (KL divergence between soft distributions)
    soft_loss = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)  # Scale by T^2 to match gradient magnitude

    # Hard target loss (standard cross-entropy)
    hard_loss = F.cross_entropy(student_logits, labels)

    return alpha * soft_loss + (1 - alpha) * hard_loss

# Training loop
teacher.eval()
student.train()
for batch in dataloader:
    with torch.no_grad():
        teacher_logits = teacher(batch['input'])

    student_logits = student(batch['input'])
    loss = distillation_loss(student_logits, teacher_logits, batch['labels'])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Q6: What is model pruning? Compare structured vs. unstructured pruning.

A

Pruning removes unnecessary weights/connections from a trained model to make it smaller and faster.

Unstructured pruning: Remove individual weights (set to zero) based on magnitude. Can achieve high sparsity (90%+) with minimal accuracy loss. Problem: sparse matrices need specialized hardware/libraries to be actually faster. On standard GPUs, a 90% sparse model is often not much faster than the dense model.
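
A minimal sketch of magnitude-based unstructured pruning on a small weight matrix, in plain Python. The function magnitude_prune and the example weights are hypothetical; ties at the threshold are all pruned, so the achieved sparsity can slightly exceed the target.

```python
def magnitude_prune(weights, sparsity):
    """Zero out the smallest-magnitude fraction `sparsity` of the weights."""
    flat = sorted(abs(w) for row in weights for w in row)
    n_prune = int(len(flat) * sparsity)
    if n_prune == 0:
        return [row[:] for row in weights]
    threshold = flat[n_prune - 1]
    return [[0.0 if abs(w) <= threshold else w for w in row] for row in weights]

W = [[0.9, -0.05, 0.3],
     [0.01, -0.7, 0.02],
     [-0.2, 0.04, 0.6]]
pruned = magnitude_prune(W, sparsity=0.5)
for row in pruned:
    print(row)
```

The matrix keeps its shape and the zeros are scattered, which is why the result is no faster on dense hardware unless a sparse kernel is used.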

Structured pruning: Remove entire neurons, channels, or attention heads. The resulting model is a standard dense model (just smaller), so it runs faster on all hardware without special support. Typically achieves lower sparsity (30-50%) before accuracy degrades.
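
For contrast, a sketch of structured pruning that removes whole output neurons (rows of a weight matrix) ranked by L2 norm. The helper prune_rows and the example matrix are illustrative assumptions; in a real network the matching input columns of the next layer would have to be removed as well.

```python
import math

def prune_rows(weights, keep):
    """Keep the `keep` rows (output neurons) with the largest L2 norm."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in weights]
    ranked = sorted(range(len(weights)), key=lambda i: norms[i], reverse=True)
    kept = sorted(ranked[:keep])  # preserve the original row order
    return [weights[i][:] for i in kept], kept

W = [[0.9, -0.8, 0.7],     # strong neuron
     [0.01, 0.02, -0.01],  # near-dead neuron
     [0.5, 0.4, -0.6],
     [0.03, -0.02, 0.01]]  # near-dead neuron
smaller_W, kept_rows = prune_rows(W, keep=2)
print(kept_rows, smaller_W)
```

The output is a smaller dense matrix, so any standard matmul kernel benefits from the reduction.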

The lottery ticket hypothesis: Dense networks contain sparse subnetworks that, when trained in isolation from the start (with the same initialization), achieve comparable performance. This suggests that much of the network is redundant.

Q7: Explain quantization. What is the difference between PTQ and QAT?

A

Quantization converts model weights and/or activations from high precision (FP32/FP16) to lower precision (INT8/INT4). This reduces model size and can increase inference speed on hardware with integer math support.

Post-Training Quantization (PTQ): Quantize a pre-trained model without retraining. Calibrate the quantization ranges using a small representative dataset. Fast but can degrade accuracy, especially for INT4.
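
The calibrate-then-quantize idea can be sketched with symmetric INT8 quantization in plain Python. The function names are hypothetical, and for simplicity the calibration values here are the weights themselves; activation ranges would instead be measured on a representative dataset.

```python
def calibrate_scale(calibration_values, num_bits=8):
    """Symmetric scale so the largest calibration magnitude maps to qmax."""
    qmax = 2 ** (num_bits - 1) - 1  # 127 for INT8
    max_abs = max(abs(v) for v in calibration_values)
    return max_abs / qmax if max_abs > 0 else 1.0

def quantize(values, scale, num_bits=8):
    """Map floats to integers in [-qmax, qmax]."""
    qmax = 2 ** (num_bits - 1) - 1
    return [max(-qmax, min(qmax, round(v / scale))) for v in values]

def dequantize(q_values, scale):
    """Map integers back to approximate float values."""
    return [q * scale for q in q_values]

weights = [0.42, -1.27, 0.05, 0.9, -0.33]
scale = calibrate_scale(weights)
q = quantize(weights, scale)
restored = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
print(q, max_err)
```

The rounding error per value is bounded by half the scale, which is why outliers that inflate the scale hurt accuracy, and more so at lower bit widths.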

Quantization-Aware Training (QAT): Simulate quantization during training using "fake quantization" nodes. The model learns to be robust to quantization noise. Better accuracy than PTQ but requires retraining.

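Modern LLM quantization: Methods such as GPTQ and AWQ, and file formats such as GGUF (used by llama.cpp), enable 4-bit quantization of large language models with small quality loss. The weights of a 70B-parameter model shrink from 140 GB (FP16) to ~35 GB (INT4), which brings the model within reach of a single 48 GB GPU or a pair of 24 GB consumer GPUs.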
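
The memory figures follow from simple arithmetic on the weights alone. The sketch below ignores activations, KV cache, and per-group quantization metadata, and uses 1 GB = 1e9 bytes; the helper name is illustrative.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

num_params = 70e9  # 70B-parameter model
for name, bits in (("FP32", 32), ("FP16", 16), ("INT8", 8), ("INT4", 4)):
    print(f"{name}: {weight_memory_gb(num_params, bits):.0f} GB")
```

Real quantized checkpoints are somewhat larger than this estimate because scales and zero points are stored alongside the low-bit weights.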

Q8: What is LoRA and why is it used for fine-tuning large models?

A

Low-Rank Adaptation (LoRA) adds small trainable matrices to selected layers (commonly the attention projections) while keeping the original weights frozen. Instead of updating a d x d weight matrix W, it learns a low-rank update: W' = W + (alpha/r) * A*B, where A is d x r and B is r x d, with r << d (typically r=8-64) and alpha a fixed scaling constant.

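Why it works: The weight updates needed for fine-tuning appear to have low intrinsic rank, so they can be approximated by a low-rank matrix with little loss. This is an empirical observation, motivated by earlier findings that fine-tuning large pre-trained models has low intrinsic dimensionality.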

Benefits:

  • Memory: Only store/update r*(d_in + d_out) parameters instead of d_in * d_out. For a 7B model, LoRA might train 10M parameters instead of 7B (700x reduction).
  • Speed: Cheaper training because gradients and optimizer states are kept only for the adapter parameters. The forward pass still runs through the full frozen model.
  • Serving: Can merge A*B into W at inference time, so there is zero inference overhead.
  • Multi-task: Can swap different LoRA adapters for different tasks without loading separate full models.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """LoRA adapter for a linear layer"""
    def __init__(self, original_layer, rank=8, alpha=16):
        super().__init__()
        self.original = original_layer
        self.original.weight.requires_grad = False  # Freeze original
        if self.original.bias is not None:
            self.original.bias.requires_grad = False

        d_in = original_layer.in_features
        d_out = original_layer.out_features

        # Low-rank matrices
        self.lora_A = nn.Parameter(torch.randn(d_in, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, d_out))
        self.scaling = alpha / rank

    def forward(self, x):
        # Original output + low-rank update
        return self.original(x) + (x @ self.lora_A @ self.lora_B) * self.scaling

# Apply LoRA to a model
def add_lora(model, rank=8, target_modules=('q_proj', 'v_proj')):
    # Freeze every base parameter; only the LoRA matrices created below train
    for param in model.parameters():
        param.requires_grad = False

    # Snapshot the module list first so modules can be replaced while looping
    modules = dict(model.named_modules())
    for name, module in list(modules.items()):
        if any(t in name for t in target_modules) and isinstance(module, nn.Linear):
            parent_name, _, child_name = name.rpartition('.')
            setattr(modules[parent_name], child_name, LoRALinear(module, rank=rank))
    return model
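
Two of the benefits above, the parameter savings and the zero-overhead merge, can be checked with small numbers. This is a plain-Python sketch; the 4096 projection size, the tiny matrices, and the helper names are illustrative assumptions, and W is laid out as (d_in, d_out) to match the x @ A @ B convention used in the text.

```python
def lora_param_count(d_in, d_out, rank):
    """Trainable parameters for one LoRA-adapted weight matrix."""
    return rank * (d_in + d_out)

d_in = d_out = 4096  # hypothetical projection size
rank = 8
full = d_in * d_out
lora = lora_param_count(d_in, d_out, rank)
print(f"full: {full:,}  lora: {lora:,}  ratio: {full / lora:.0f}x")

def matmul(a, b):
    """Naive matrix product of two lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

W = [[1.0, 2.0], [3.0, 4.0]]  # frozen weight, shape (d_in, d_out)
A = [[0.5], [-1.0]]           # shape (d_in, r) with r = 1
B = [[2.0, 0.25]]             # shape (r, d_out)
s = 2.0                       # scaling factor alpha / r
x = [[1.0, -2.0]]

# Path 1: merge the adapter into the weight, then do one matmul
AB = matmul(A, B)
W_merged = [[W[i][j] + s * AB[i][j] for j in range(2)] for i in range(2)]
merged_out = matmul(x, W_merged)

# Path 2: keep the adapter separate, as during training
xW = matmul(x, W)
xAB = matmul(matmul(x, A), B)
adapter_out = [[xW[0][j] + s * xAB[0][j] for j in range(2)]]

print(merged_out, adapter_out)
```

Both paths give the same output, so the merged weight can be served as an ordinary dense layer. Because nn.Linear stores its weight as (out_features, in_features), merging the LoRALinear class above would add the transpose of A @ B, scaled by alpha / rank.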

Q9: What is early stopping? How do you implement it properly?

A

Early stopping monitors a validation metric during training and stops when the metric stops improving for a specified number of epochs (patience). It prevents overfitting by stopping training at the point of best generalization.

Key decisions:

  • Which metric: Use the metric that matters for your task (F1 for imbalanced classification, BLEU for translation), not just validation loss
  • Patience: Too small (2-3) risks stopping during a temporary plateau. Too large (20+) wastes compute. Typically 5-10 epochs.
  • Restore best weights: Always save and restore the checkpoint with the best validation metric, not the weights at the stopped epoch

import torch

class EarlyStopping:
    def __init__(self, patience=7, min_delta=0.001, mode='min'):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.should_stop = False
        self.mode = mode

    def __call__(self, score, model, path='checkpoint.pt'):
        is_better = (
            self.best_score is None or
            (self.mode == 'min' and score < self.best_score - self.min_delta) or
            (self.mode == 'max' and score > self.best_score + self.min_delta)
        )

        if is_better:
            self.best_score = score
            self.counter = 0
            torch.save(model.state_dict(), path)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

# Usage
early_stopping = EarlyStopping(patience=5, mode='min')
for epoch in range(100):
    train_loss = train_one_epoch(model, train_loader)
    val_loss = evaluate(model, val_loader)

    early_stopping(val_loss, model)
    if early_stopping.should_stop:
        print(f"Early stopping at epoch {epoch}")
        break

# Restore the best checkpoint whether or not training stopped early
model.load_state_dict(torch.load('checkpoint.pt'))

Q10: How does distributed data parallel (DDP) training work? What about model parallelism?

A

Data parallelism (DDP): Each GPU has a copy of the full model. The data batch is split across GPUs. Each GPU computes gradients on its shard. Gradients are averaged across GPUs (all-reduce), then each GPU updates its model identically. Simple, efficient, the default for most training.
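
The compute-locally-then-average step can be simulated in plain Python without any real communication. This toy sketch uses a one-parameter model and two simulated workers; all_reduce_mean is a stand-in written for this example, not a library call.

```python
def local_grad(w, shard):
    """Mean gradient of (w*x - y)^2 over one worker's shard."""
    return sum(2 * (w * x - y) * x for x, y in shard) / len(shard)

def all_reduce_mean(values):
    """Stand-in for an all-reduce: every worker receives the average."""
    mean = sum(values) / len(values)
    return [mean] * len(values)

data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
num_workers, lr = 2, 0.01
shards = [data[i::num_workers] for i in range(num_workers)]
replicas = [0.0] * num_workers  # every worker starts from identical weights

# Each worker computes a gradient on its own shard
grads = [local_grad(w, shard) for w, shard in zip(replicas, shards)]
# Gradients are averaged, then every replica applies the same update
synced = all_reduce_mean(grads)
replicas = [w - lr * g for w, g in zip(replicas, synced)]

print(replicas)
```

With equal shard sizes the averaged gradient equals the gradient on the whole batch, and the replicas stay identical because they apply the same update.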

Model parallelism: The model is too large for one GPU. Split the model across GPUs.

  • Tensor parallelism: Split individual layers (e.g., split a large matrix multiplication across GPUs). Low latency but requires high-bandwidth interconnect.
  • Pipeline parallelism: Different layers on different GPUs. Micro-batching keeps all GPUs busy. Higher latency but works with lower bandwidth.
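  • FSDP (Fully Sharded Data Parallel): Strictly a memory-efficient form of data parallelism rather than model parallelism. It shards model parameters, gradients, and optimizer states across GPUs, so each GPU stores only 1/N of the model state and gathers parameters on demand for forward/backward. Used for training models whose state does not fit in a single GPU's memory.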
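
The tensor-parallel bullet above can be illustrated with a column-wise split of one matrix multiplication. This is a plain-Python toy with two simulated devices; the helper names and the matrix values are made up for the example.

```python
def matmul(a, b):
    """Naive matrix product of two lists of lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def split_columns(matrix, parts):
    """Split a matrix into `parts` equal column blocks."""
    step = len(matrix[0]) // parts
    return [[row[p * step:(p + 1) * step] for row in matrix]
            for p in range(parts)]

x = [[1.0, 2.0]]
W = [[1.0, 0.0, 2.0, -1.0],
     [0.5, 1.0, -2.0, 3.0]]

shards = split_columns(W, parts=2)              # one shard per simulated GPU
partials = [matmul(x, shard) for shard in shards]
gathered = [sum((p[0] for p in partials), [])]  # all-gather: concatenate outputs

print(gathered, matmul(x, W))
```

Each device needs the full input but only its slice of the weights, and the gather step is the communication that makes a fast interconnect important.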

Q11: What is label smoothing and why is it used?

A

Label smoothing replaces hard targets [1, 0, 0] with soft targets [0.933, 0.033, 0.033] (for epsilon=0.1 and 3 classes). The smoothed label is: y_smooth = (1 - epsilon) * y_hard + epsilon / num_classes.
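
Applying the formula directly gives the smoothed targets for any number of classes. This is a minimal sketch with a hypothetical helper name; it spreads epsilon uniformly over all num_classes classes, as the formula does, whereas variants that spread it over only the non-target classes divide by num_classes - 1 instead.

```python
def smooth_labels(one_hot, epsilon=0.1):
    """y_smooth = (1 - epsilon) * y_hard + epsilon / num_classes."""
    num_classes = len(one_hot)
    return [(1 - epsilon) * y + epsilon / num_classes for y in one_hot]

smoothed = smooth_labels([1, 0, 0], epsilon=0.1)
print([round(v, 3) for v in smoothed])
```

The smoothed targets still sum to 1, so they can be used directly with a cross-entropy loss that accepts probability targets.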

Why it helps:

  • Prevents the model from becoming overconfident (outputting probabilities very close to 0 or 1)
  • Acts as a regularizer, reducing overfitting
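  • Pulls examples of the same class into tighter clusters in representation space. The trade-off is that it tends to erase information about how the non-target classes relate to each other, which can make a label-smoothed model a weaker distillation teacher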
  • Particularly effective for Transformers and large models

Typical value: epsilon = 0.1 is the most common choice. It was used in the original Transformer and appears in many image classification training recipes, including Inception-v3 and later ResNet recipes.

Q12: You are training a model and the loss is NaN. How do you debug this?

A

This is a very common practical interview question. Here is a systematic debugging approach:

  1. Check the learning rate: Too high is the most common cause. Try reducing by 10x. If using warmup, check that it is implemented correctly.
  2. Check for division by zero: Look for divisions, log() of zero, sqrt() of negative values in your loss function or model. Add a small epsilon (e.g., 1e-8) to denominators, and use a larger value under FP16, where 1e-8 rounds to zero.
  3. Check input data: Look for NaN/Inf in your input data. Add assertions: assert not torch.isnan(x).any()
  4. Check gradient norms: If gradients explode before loss goes NaN, add gradient clipping (max_norm=1.0).
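  5. Mixed precision issues: FP16 overflow (values > 65504). Switch to BF16 or lower the loss scale; a larger scale makes overflow more likely, not less. GradScaler lowers the scale automatically when it sees inf gradients.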
  6. Weight initialization: Bad initialization can cause immediate NaN. Try reducing initialization scale.
  7. Use anomaly detection: torch.autograd.set_detect_anomaly(True) reports the forward operation whose backward pass produced NaN. It slows training noticeably, so enable it only while debugging.

# Debugging NaN loss
torch.autograd.set_detect_anomaly(True)  # Identifies problematic operation

# Add NaN checks
def check_nan(tensor, name):
    if torch.isnan(tensor).any():
        raise ValueError(f"NaN detected in {name}")

# Monitor gradient norms (run after loss.backward())
total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        total_norm += p.grad.detach().norm(2).item() ** 2
total_norm = total_norm ** 0.5
print(f"Gradient norm: {total_norm:.4f}")
# If this is very large (>100) or growing rapidly, gradients are exploding
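
Gradient clipping by global norm, the fix suggested in step 4, rescales all gradients by one shared factor. The plain-Python sketch below shows the arithmetic on made-up gradient values; the helper is illustrative and is not the PyTorch implementation.

```python
import math

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients so their combined L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for tensor in grads for g in tensor))
    scale = min(1.0, max_norm / (total_norm + eps))
    return [[g * scale for g in tensor] for tensor in grads], total_norm

grads = [[30.0, -40.0], [0.0, 120.0]]  # hypothetical exploding gradients
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for t in clipped for g in t))
print(norm_before, norm_after)
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its length changes. Gradients already below the threshold pass through unchanged.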

Key Takeaways

  • Learning rate warmup prevents early training instability; cosine decay is the standard schedule
  • Mixed precision (BF16 preferred, FP16 with loss scaling) gives 2-4x speedup with minimal code changes
  • Gradient accumulation simulates large batch sizes on limited GPU memory
  • Knowledge distillation transfers knowledge from large to small models via soft labels
  • LoRA enables efficient fine-tuning with hundreds of times fewer trainable parameters (about 700x in the 7B example) and zero inference overhead once merged
  • When loss is NaN: check learning rate first, then data, then gradients, then precision