Advanced

NLP & Transformers on GPU

Explore how Transformer models like BERT and GPT leverage NVIDIA GPUs for training and inference. This lesson covers GPU-optimized attention mechanisms, Triton Inference Server for model serving, and techniques for deploying large language models efficiently.

Transformer Architecture on GPU

Transformers rely heavily on matrix multiplications in their attention mechanism, making them ideal for GPU acceleration. Understanding how attention maps to GPU hardware is essential for the certification.

# Transformer Attention on GPU

transformer_gpu = {
    "Self-Attention": {
        "operation": "softmax(Q @ K^T / sqrt(d_k)) @ V",
        "gpu_mapping": "Q@K^T is a batched matrix multiply -> Tensor Cores",
        "memory": "Attention matrix is O(n^2) where n = sequence length",
        "bottleneck": "Memory, not compute, for long sequences"
    },
    "Why GPUs Excel at Transformers": [
        "Attention = batched matrix multiplications -> perfect for Tensor Cores",
        "FFN layers = large matrix multiplications -> also run on Tensor Cores",
        "Layer normalization = per-token reduction plus element-wise math -> memory-bound but easily parallel",
        "Embedding lookups = gather operations -> served from high-bandwidth GPU global memory"
    ],
    "Key Models": {
        "BERT": {
            "type": "Encoder-only",
            "params": "110M (base), 340M (large)",
            "use_cases": "Classification, NER, question answering",
            "gpu_memory": "~1.5GB (base), ~4GB (large) for inference"
        },
        "GPT-2/3": {
            "type": "Decoder-only (autoregressive)",
            "params": "117M to 175B",
            "use_cases": "Text generation, few-shot learning",
            "gpu_memory": "Varies: 117M fits on any GPU; 175B (~350GB of FP16 weights) needs multiple GPUs for inference and multi-node for training"
        },
        "T5": {
            "type": "Encoder-decoder",
            "params": "60M to 11B",
            "use_cases": "Translation, summarization, text-to-text",
            "gpu_memory": "60M-770M fit on single GPU"
        }
    }
}
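To make the attention formula above concrete, here is a minimal single-head sketch in plain Python. It is an illustration only: the helper names (`softmax`, `matmul`, `attention`) and the toy inputs are assumptions for this example, and a real model would run the same math as batched GPU matrix multiplies.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def attention(q, k, v):
    """Compute softmax(Q @ K^T / sqrt(d_k)) @ V for a single head."""
    d_k = len(q[0])
    k_t = [list(col) for col in zip(*k)]            # K^T, shape d_k x n
    scores = matmul(q, k_t)                          # n x n score matrix
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    weights = [softmax(row) for row in scaled]       # each row sums to 1
    return matmul(weights, v), weights

# Toy inputs: 3 tokens, head dimension 2
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
out, weights = attention(q, k, v)
print(out)
```

Note that `weights` is an n x n matrix, which is the O(n^2) memory cost the lesson refers to.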

Fine-Tuning BERT on GPU

# Fine-tuning BERT for Text Classification on GPU

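from transformers import BertTokenizer, BertForSequenceClassification
import torch
# torch.amp is the current home of autocast/GradScaler (PyTorch 2.3+);
# on older versions, import both from torch.cuda.amp instead.
from torch.amp import autocast, GradScaler

device = torch.device("cuda")

# Load pretrained BERT with a 2-class classification head
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=2
).to(device)

# Mixed precision for faster training on Tensor Cores
scaler = GradScaler("cuda")  # scales the loss so small FP16 gradients do not underflow
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# Training loop with mixed precision
# train_loader is assumed to be a DataLoader yielding tokenized batches
# with 'input_ids', 'attention_mask', and 'labels' tensors.
model.train()
for epoch in range(3):
    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()

        with autocast(device_type="cuda", dtype=torch.float16):  # FP16 forward pass on Tensor Cores
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

        scaler.scale(loss).backward()  # backward pass on the scaled loss
        scaler.step(optimizer)         # unscales gradients; skips the step on inf/NaN
        scaler.update()                # adjusts the scale factor for the next iteration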

# Key fine-tuning tips for certification:
bert_tips = {
    "Learning rate": "2e-5 to 5e-5 (much lower than training from scratch)",
    "Epochs": "2-4 epochs is typical (BERT overfits quickly on small datasets)",
    "Batch size": "16-32 (limited by GPU memory due to sequence length)",
    "Max sequence length": "128-512 tokens (512 is BERT's limit; attention memory scales quadratically)",
    "Weight decay": "0.01 (AdamW optimizer)",
    "Warmup": "10% of total steps with linear warmup"
}

GPU-Optimized Attention: Flash Attention

Flash Attention is a GPU-aware algorithm that computes exact attention without materializing the full O(n^2) attention matrix, dramatically reducing memory usage and increasing speed.

# Flash Attention - Key Concept for Certification

flash_attention = {
    "Problem": {
        "standard_attention": "Computes full n x n attention matrix in HBM (global memory)",
        "memory": "O(n^2) for sequence length n -> ~4GB per attention head in FP32 (2GB in FP16) at seq_len=32K",
        "speed": "Memory-bound: reads/writes to slow HBM dominate"
    },
    "Solution": {
        "flash_attention": "Tiles computation to fit in fast SRAM (shared memory)",
        "memory": "O(n) instead of O(n^2)",
        "speed": "2-4x faster by reducing HBM reads/writes",
        "key_insight": "Online (blockwise) softmax lets one fused kernel compute attention tile by tile"
    },
    "Usage": {
        "PyTorch_native": "torch.nn.functional.scaled_dot_product_attention() (PyTorch 2.0+)",
        "library": "pip install flash-attn",
        "huggingface": "AutoModel.from_pretrained(..., attn_implementation='flash_attention_2')"
    }
}
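The quadratic memory cost of standard attention can be checked with simple arithmetic. The sketch below assumes one n x n score matrix per head and per batch element; the function name `attention_matrix_bytes` is made up for this example.

```python
def attention_matrix_bytes(seq_len, bytes_per_element=4, num_heads=1, batch_size=1):
    """Bytes needed to store the full n x n attention score matrix."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_element

GIB = 1024 ** 3
for n in (1024, 8192, 32768):
    fp32 = attention_matrix_bytes(n, bytes_per_element=4) / GIB
    fp16 = attention_matrix_bytes(n, bytes_per_element=2) / GIB
    print(f"seq_len={n:>6}: {fp32:8.4f} GiB (FP32), {fp16:8.4f} GiB (FP16) per head")
```

Doubling the sequence length quadruples the matrix, and the total scales further with the number of heads and the batch size.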

# Using Flash Attention in PyTorch 2.0+
import torch.nn.functional as F

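# Q, K, V are assumed to be CUDA tensors of shape (batch, heads, seq_len, head_dim)

# Standard attention (materializes full attention matrix)
# attn_weights = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
# output = attn_weights @ V

# Fused attention: dispatches to a FlashAttention kernel when the hardware,
# dtype (FP16/BF16), and shapes support it; otherwise falls back to another backend
output = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
# With is_causal=True this matches standard attention plus a causal mask (up to
# floating-point rounding), using O(n) extra memory on the FlashAttention path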
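The tiling idea behind Flash Attention depends on being able to compute softmax block by block. The following plain-Python sketch shows that "online softmax" trick for a single query row with scalar values. It is a simplified illustration under those assumptions, not the actual CUDA kernel, and the function names are hypothetical.

```python
import math

def attention_row_full(scores, values):
    """Reference: softmax over all scores at once, then a weighted sum."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return sum(w * v for w, v in zip(weights, values)) / total

def attention_row_tiled(scores, values, block_size=2):
    """Same result, but only one block of scores is live at a time."""
    running_max = float("-inf")   # largest score seen so far
    running_sum = 0.0             # softmax denominator, relative to running_max
    running_out = 0.0             # unnormalized output, relative to running_max
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        # Rescale earlier partial results to the new maximum
        correction = math.exp(running_max - new_max)
        blk_w = [math.exp(s - new_max) for s in s_blk]
        running_sum = running_sum * correction + sum(blk_w)
        running_out = running_out * correction + sum(w * v for w, v in zip(blk_w, v_blk))
        running_max = new_max
    return running_out / running_sum

scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0]
print(attention_row_full(scores, values))
print(attention_row_tiled(scores, values, block_size=2))
```

Because only running statistics are kept between blocks, the full row of attention weights never has to be stored, which is where the O(n) memory behavior comes from.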

Triton Inference Server

NVIDIA Triton Inference Server is a production-grade model serving platform that supports multiple frameworks and optimizes GPU utilization for inference workloads.

# Triton Inference Server - Key Concepts

triton_server = {
    "What": "Production model serving platform from NVIDIA",
    "Supports": [
        "TensorRT engines",
        "ONNX Runtime",
        "PyTorch (TorchScript)",
        "TensorFlow (SavedModel)",
        "Python backends (custom logic)",
        "Multiple models simultaneously"
    ],
    "Key Features": {
        "Dynamic batching": "Combines requests into batches for GPU efficiency",
        "Model ensemble": "Chain multiple models in a pipeline",
        "Concurrent execution": "Run multiple models on same GPU",
        "Model versioning": "Serve multiple versions side by side; version_policy selects which are loaded (useful for A/B tests)",
        "Metrics": "Prometheus metrics for monitoring latency/throughput"
    }
}

# Triton model repository structure
"""
model_repository/
  bert_classifier/
    config.pbtxt           # Model configuration
    1/                     # Version 1
      model.onnx           # Model file
    2/                     # Version 2
      model.onnx
"""

# config.pbtxt example
"""
name: "bert_classifier"
platform: "onnxruntime_onnx"
max_batch_size: 32
input [
  {
    name: "input_ids"
    data_type: TYPE_INT64
    dims: [512]
  },
  {
    name: "attention_mask"
    data_type: TYPE_INT64
    dims: [512]
  }
]
output [
  {
    name: "logits"
    data_type: TYPE_FP32
    dims: [2]
  }
]
dynamic_batching {
  preferred_batch_size: [8, 16, 32]
  max_queue_delay_microseconds: 100
}
"""

# Launch Triton Server
"""
docker run --gpus all -p 8000:8000 -p 8001:8001 -p 8002:8002 \
  -v /path/to/model_repository:/models \
  nvcr.io/nvidia/tritonserver:24.01-py3 \
  tritonserver --model-repository=/models
"""

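# Client request (the config declares two inputs, so both must be sent)
import tritonclient.http as httpclient
client = httpclient.InferenceServerClient(url="localhost:8000")
# input_ids_numpy and attention_mask_numpy: int64 arrays of shape [1, 512] from the tokenizer
inputs = [
    httpclient.InferInput("input_ids", [1, 512], "INT64"),
    httpclient.InferInput("attention_mask", [1, 512], "INT64"),
]
inputs[0].set_data_from_numpy(input_ids_numpy)
inputs[1].set_data_from_numpy(attention_mask_numpy)
result = client.infer("bert_classifier", inputs)
logits = result.as_numpy("logits")  # shape [1, 2]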
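To build intuition for the `dynamic_batching` settings in the config above, here is a toy simulation of how a queue delay and a maximum batch size group incoming requests. It is a deliberately simplified model, not Triton's actual scheduler: it ignores `preferred_batch_size`, and the function name and arrival times are invented for illustration.

```python
def simulate_dynamic_batching(arrival_times_us, max_batch_size=32, max_queue_delay_us=100):
    """Group request arrival times (in microseconds) into batches.

    Simplified rule: a batch is dispatched when it is full, or when a request
    arrives after the first queued request has waited max_queue_delay_us.
    """
    batches = []
    current = []
    deadline = None
    for t in sorted(arrival_times_us):
        if current and t > deadline:
            batches.append(current)      # delay expired: dispatch what we have
            current = []
        if not current:
            deadline = t + max_queue_delay_us
        current.append(t)
        if len(current) == max_batch_size:
            batches.append(current)      # batch is full: dispatch immediately
            current = []
    if current:
        batches.append(current)
    return batches

arrivals = [0, 10, 20, 250, 260, 900]
print(simulate_dynamic_batching(arrivals, max_batch_size=4, max_queue_delay_us=100))
print(simulate_dynamic_batching(arrivals, max_batch_size=2, max_queue_delay_us=100))
```

A longer delay tends to produce larger batches and better GPU utilization, at the cost of added latency for the first request in each batch.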

LLM Optimization Techniques

Quantization

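INT8/INT4 quantization stores weights (and optionally activations) in fewer bits, which shrinks the model and usually speeds up inference at some cost in accuracy. Post-training quantization (PTQ) converts an already trained model and typically uses a small calibration dataset to choose value ranges. Quantization-aware training (QAT) simulates low precision during training, which generally preserves accuracy better.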
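As a rough illustration of what INT8 quantization does to a weight tensor, the sketch below applies symmetric per-tensor quantization in plain Python. The scheme (one scale per tensor, range -127 to 127) and the sample weights are assumptions chosen for simplicity; production toolchains such as TensorRT use more elaborate calibration.

```python
def quantize_int8(weights):
    """Symmetric per-tensor quantization: map floats to integers in [-127, 127]."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale

def dequantize(q, scale):
    """Recover approximate float values from the integers and the scale."""
    return [x * scale for x in q]

weights = [0.42, -1.27, 0.003, 0.9, -0.55]   # illustrative values
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
print(q, scale, max_error)
```

Each weight now takes 1 byte instead of 4 (FP32) or 2 (FP16), and the rounding error per weight is bounded by half the scale.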

KV Cache

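During autoregressive generation, the KV cache stores the key and value tensors of previous tokens so they are not recomputed at every step. This trades memory for speed: the cache grows linearly with sequence length and batch size. PagedAttention (vLLM) manages the cache in fixed-size blocks, much like pages in virtual memory.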
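The memory side of that trade-off can be estimated with a short calculation. The sketch below assumes a standard decoder that caches one key and one value vector per layer, per KV head, per token; the model shape used (32 layers, 32 KV heads, head dimension 128, FP16) is an illustrative assumption, not a figure from this lesson.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_element=2):
    """Bytes for the KV cache: a K and a V tensor per layer (hence the factor 2)."""
    return (2 * num_layers * num_kv_heads * head_dim
            * seq_len * batch_size * bytes_per_element)

MIB = 1024 ** 2
GIB = 1024 ** 3

# Illustrative 7B-class shape (assumed): 32 layers, 32 KV heads, head_dim 128, FP16
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
full_context = kv_cache_bytes(32, 32, 128, seq_len=4096)
print(f"per token:   {per_token / MIB:.2f} MiB")
print(f"4096 tokens: {full_context / GIB:.2f} GiB")
```

Multiplying by the batch size shows why the cache, rather than the weights, often limits how many requests a server can handle concurrently.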

Tensor Parallelism

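Split individual layers across multiple GPUs. Each GPU holds a slice of the weight matrix and computes its portion of the matrix multiply, and the partial results are combined with a collective operation such as all-gather or all-reduce. Essential for models too large for a single GPU (e.g., 70B+ parameters).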
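A column-wise split of one linear layer can be sketched in plain Python, with list slices standing in for GPUs. This is a conceptual illustration only: the helper names are hypothetical, and real frameworks perform the final concatenation with an all-gather across devices.

```python
def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_columns(w, num_shards):
    """Give each simulated GPU an equal slice of the weight matrix columns."""
    cols = len(w[0])
    assert cols % num_shards == 0, "columns must divide evenly across shards"
    width = cols // num_shards
    return [[row[i * width:(i + 1) * width] for row in w] for i in range(num_shards)]

def column_parallel_matmul(x, w, num_shards):
    """Each shard computes x @ w_shard; results are concatenated column-wise."""
    partials = [matmul(x, shard) for shard in split_columns(w, num_shards)]
    # Stand-in for an all-gather: join each row's partial outputs
    return [sum((p[r] for p in partials), []) for r in range(len(x))]

x = [[1, 2], [3, 4]]                 # activations: 2 tokens, hidden size 2
w = [[1, 0, 2, 1], [0, 1, 1, 3]]     # weights: 2 x 4
print(matmul(x, w))
print(column_parallel_matmul(x, w, num_shards=2))
```

Each shard stores only half of `w`, so per-GPU weight memory drops in proportion to the number of shards.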

Speculative Decoding

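Use a small draft model to generate several candidate tokens, then verify them with the large model in a single parallel forward pass. Accepted tokens are kept and the first rejected token is replaced by the large model's choice. Reported speedups are around 2-3x for autoregressive generation, and with a correct acceptance rule the output matches what the large model alone would produce.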
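The draft-then-verify loop can be sketched with toy deterministic "models". Everything here is an assumption for illustration: `target_next` and `draft_next` are made-up functions standing in for real networks, verification uses greedy matching (sampling-based variants use a probabilistic acceptance rule), and the bonus token that real implementations gain when every draft token is accepted is omitted.

```python
def target_next(ctx):
    """Hypothetical large model: greedy next token for a context."""
    return (3 * sum(ctx) + len(ctx)) % 10

def draft_next(ctx):
    """Hypothetical small model: agrees with the target except at every 4th position."""
    guess = target_next(ctx)
    return (guess + 1) % 10 if len(ctx) % 4 == 0 else guess

def greedy_decode(prompt, num_tokens):
    """Baseline: one target-model call per generated token."""
    tokens = list(prompt)
    for _ in range(num_tokens):
        tokens.append(target_next(tokens))
    return tokens

def speculative_decode(prompt, num_tokens, k=4):
    tokens = list(prompt)
    goal = len(prompt) + num_tokens
    target_passes = 0
    while len(tokens) < goal:
        # 1. Draft model proposes up to k tokens autoregressively (cheap)
        draft = []
        for _ in range(min(k, goal - len(tokens))):
            draft.append(draft_next(tokens + draft))
        # 2. Target model checks all drafted positions; on a GPU this is one batched pass
        target_passes += 1
        accepted = []
        for i, tok in enumerate(draft):
            expected = target_next(tokens + draft[:i])
            if tok == expected:
                accepted.append(tok)
            else:
                accepted.append(expected)   # replace the first mismatch and stop
                break
        tokens.extend(accepted)
    return tokens, target_passes

prompt = [1, 2]
spec_tokens, passes = speculative_decode(prompt, num_tokens=12, k=4)
print(spec_tokens == greedy_decode(prompt, 12), passes)
```

The speedup comes from needing fewer large-model passes than generated tokens; how many fewer depends on how often the draft model agrees with the target.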

Practice Questions

Test your understanding:
  1. Why are Transformer models well-suited for GPU acceleration? What operations dominate the computation?
  2. What is Flash Attention and how does it reduce memory from O(n^2) to O(n)?
  3. Explain the role of Triton Inference Server. What is dynamic batching and why is it important?
  4. What is mixed precision training and why is it especially useful for fine-tuning BERT?
  5. Compare INT8 quantization and FP16 for inference. What are the trade-offs?
  6. What is KV cache in autoregressive generation? Why does it grow with sequence length?
  7. You need to deploy a 13B parameter model. It does not fit on a single 24GB GPU. What strategies can you use?

Key Takeaways

  • Transformers map naturally to GPUs because attention is a series of batched matrix multiplications
  • Flash Attention reduces attention memory from O(n^2) to O(n) by tiling computation to fit in SRAM
  • Fine-tune BERT with low learning rates (2e-5), mixed precision, and 2-4 epochs
  • Triton Inference Server provides production-grade model serving with dynamic batching and multi-framework support
  • Quantization (INT8/INT4) and tensor parallelism are essential for deploying large language models
  • KV cache and speculative decoding are key optimizations for autoregressive generation