Advanced
Datasets & DataLoaders
Five challenges covering the data pipeline engineering that every production DL system requires. Data loading is often the bottleneck — getting it right determines whether training takes hours or days.
Challenge 1: Custom Dataset
Implement a custom Dataset class for a multi-modal dataset (text + images + labels) with lazy loading.
import torch
from torch.utils.data import Dataset, DataLoader
import os
class MultiModalDataset(Dataset):
    """
    Challenge: Implement a Dataset for paired text-image data.
    - Lazy loading: don't load all data into memory at init
    - Return (image_tensor, token_ids, label)
    - Handle missing files gracefully
    - Support optional transforms
    """

    # YOUR SOLUTION HERE
    ...
# ---- SOLUTION ----
class MultiModalDataset(Dataset):
    """Paired text-image dataset with lazy per-item loading.

    Only lightweight metadata (paths, raw text, labels) is kept in memory;
    the image tensor and token ids are materialized in ``__getitem__``.

    Args:
        data_dir: Root directory for the split metadata and image files.
        split: Split name; production code would read ``{split}_metadata.csv``.
        transform: Optional callable applied to the image tensor.
        max_seq_len: Fixed token length; text is truncated/zero-padded to it.
        num_samples: Number of synthetic demo samples to generate
            (generalizes the previously hard-coded 1000).
    """

    def __init__(self, data_dir, split='train', transform=None,
                 max_seq_len=128, num_samples=1000):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.max_seq_len = max_seq_len
        # Metadata is lightweight: just paths, text, and labels.
        # In production this would be parsed from
        # os.path.join(data_dir, f'{split}_metadata.csv'); here we synthesize
        # it. (The original also built that path into an unused local.)
        self.samples = [
            {
                'image_path': f'images/{i:06d}.jpg',
                'text': f'Sample text for item {i}',
                'label': i % 10,
            }
            for i in range(num_samples)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, token_ids, label)`` for one sample, loaded lazily."""
        sample = self.samples[idx]
        # Lazy image load (placeholder tensor stands in for disk I/O).
        # In production: Image.open(os.path.join(self.data_dir, sample['image_path']))
        image = torch.randn(3, 224, 224)  # placeholder
        if self.transform:
            image = self.transform(image)
        # Character-level tokenization for the demo; truncate to max_seq_len.
        # In production: use a real tokenizer.
        token_ids = torch.tensor(
            [ord(c) % 256 for c in sample['text'][:self.max_seq_len]],
            dtype=torch.long,
        )
        # Right-pad with zeros so every sample has a fixed-length token tensor.
        if len(token_ids) < self.max_seq_len:
            padding = torch.zeros(self.max_seq_len - len(token_ids), dtype=torch.long)
            token_ids = torch.cat([token_ids, padding])
        label = torch.tensor(sample['label'], dtype=torch.long)
        return image, token_ids, label
# Smoke test: single-sample access, then batched loading.
dataset = MultiModalDataset('/tmp/data')
print(f"Dataset size: {len(dataset)}")

image, tokens, label = dataset[0]
print(f"Image shape: {image.shape}")   # (3, 224, 224)
print(f"Token shape: {tokens.shape}")  # (128,)
print(f"Label: {label.item()}")        # 0

loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(f"Batch images: {batch[0].shape}")  # (8, 3, 224, 224)
print(f"Batch tokens: {batch[1].shape}")  # (8, 128)
print(f"Batch labels: {batch[2].shape}")  # (8,)
Challenge 2: Data Augmentation Pipeline
Build a composable data augmentation pipeline using pure PyTorch operations (no torchvision dependency).
import torch
import torch.nn.functional as F
class RandomCrop:
    """Crop a (C, H, W) tensor image at a uniformly random location."""

    def __init__(self, size):
        # Accept either an int (square crop) or an explicit (h, w) tuple.
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)

    def __call__(self, img):
        th, tw = self.size
        _, h, w = img.shape
        if h < th or w < tw:
            # Zero-pad on the right/bottom when the image is too small to crop.
            pad_right = max(0, tw - w)
            pad_bottom = max(0, th - h)
            img = F.pad(img, (0, pad_right, 0, pad_bottom))
            _, h, w = img.shape
        top = int(torch.randint(0, h - th + 1, (1,)))
        left = int(torch.randint(0, w - tw + 1, (1,)))
        return img[:, top:top + th, left:left + tw]
class RandomHorizontalFlip:
    """Mirror an image left-right with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        # Flip the width dimension (last axis) on a coin toss.
        flip_now = torch.rand(1).item() < self.p
        return img.flip(-1) if flip_now else img
class ColorJitter:
    """Randomly perturb brightness and contrast, then clamp to [0, 1]."""

    def __init__(self, brightness=0.2, contrast=0.2):
        self.brightness = brightness
        self.contrast = contrast

    @staticmethod
    def _rand_factor(strength):
        # Uniform multiplicative factor in [1 - strength, 1 + strength].
        return 1.0 + (torch.rand(1).item() * 2 - 1) * strength

    def __call__(self, img):
        # Brightness: scale every pixel by one random factor.
        img = img * self._rand_factor(self.brightness)
        # Contrast: scale the deviation from the global mean.
        mean = img.mean()
        img = (img - mean) * self._rand_factor(self.contrast) + mean
        return img.clamp(0, 1)
class Normalize:
    """Channel-wise standardization: ``(img - mean) / std``."""

    def __init__(self, mean, std):
        # Reshape stats to (C, 1, 1) so they broadcast over H and W.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, img):
        return img.sub(self.mean).div(self.std)
class Compose:
    """Apply a sequence of transforms left to right."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        out = img
        for transform in self.transforms:
            out = transform(out)
        return out
# ---- SOLUTION: Build augmentation pipeline ----
# Training gets stochastic augmentation; validation only normalizes.
train_transform = Compose([
    RandomCrop(224),                            # random spatial crop
    RandomHorizontalFlip(p=0.5),                # mirror half the time
    ColorJitter(brightness=0.2, contrast=0.2),  # photometric jitter
    Normalize(mean=[0.485, 0.456, 0.406],       # ImageNet channel stats
              std=[0.229, 0.224, 0.225]),
])

val_transform = Compose([
    Normalize(mean=[0.485, 0.456, 0.406],
              std=[0.229, 0.224, 0.225]),
])

# Test
img = torch.rand(3, 256, 256)  # random image
augmented = train_transform(img)
print(f"Original: {img.shape}")  # (3, 256, 256)
print(f"Augmented: {augmented.shape}")  # (3, 224, 224)
Challenge 3: Custom Collate Functions
Implement collate functions for variable-length sequences — essential for NLP tasks.
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
def variable_length_collate(batch):
    """
    Challenge: Collate function for variable-length sequences.
    Each sample is (sequence_tensor, label).
    - Pad sequences to the max length in the batch
    - Create attention mask (1 for real tokens, 0 for padding)
    - Return (padded_sequences, attention_mask, labels, lengths)
    """
    # YOUR SOLUTION HERE
    ...
# ---- SOLUTION ----
def variable_length_collate(batch):
    """Collate (sequence, label) pairs of varying lengths into one batch.

    Returns:
        padded: (B, L_max) tensor, zero-padded on the right.
        attention_mask: (B, L_max) long tensor, 1 on real tokens, 0 on pad.
        labels: (B,) tensor of stacked labels.
        lengths: (B,) tensor of original sequence lengths.
    """
    sequences, raw_labels = zip(*batch)
    labels = torch.stack(list(raw_labels))
    lengths = torch.tensor([seq.size(0) for seq in sequences])
    # Right-pad every sequence to the longest one in this batch.
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    # A position is a real token exactly when its index < sequence length.
    positions = torch.arange(padded.size(1))
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()
    return padded, attention_mask, labels, lengths
class VariableLengthDataset(Dataset):
    """Synthetic dataset of random-length token sequences with class labels."""

    def __init__(self, num_samples=100, vocab_size=1000, num_classes=5):
        self.data = []
        for _ in range(num_samples):
            # Length drawn from [5, 50); token ids from [1, vocab_size).
            seq_len = int(torch.randint(5, 50, (1,)))
            seq = torch.randint(1, vocab_size, (seq_len,))
            label = torch.randint(0, num_classes, (1,)).squeeze()
            self.data.append((seq, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
# Test: batches pad dynamically to the longest sequence in each batch.
dataset = VariableLengthDataset()
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=variable_length_collate,
)

padded, mask, labels, lengths = next(iter(loader))
print(f"Padded: {padded.shape}")  # (8, max_len_in_batch)
print(f"Mask: {mask.shape}")  # (8, max_len_in_batch)
print(f"Labels: {labels.shape}")  # (8,)
print(f"Lengths: {lengths}")  # actual lengths per sample
Challenge 4: Distributed Sampling
Implement a sampler that handles class-balanced sampling and distributed training.
import torch
from torch.utils.data import Sampler, Dataset
import math
class BalancedClassSampler(Sampler):
    """
    Challenge: Implement a sampler that:
    - Ensures each class is equally represented in every epoch
    - Oversamples minority classes / undersamples majority classes
    - Supports setting a random seed for reproducibility
    """

    # YOUR SOLUTION HERE
    ...
# ---- SOLUTION ----
class BalancedClassSampler(Sampler):
    """Yield indices so every class contributes equally each epoch.

    Minority classes are oversampled (their indices are tiled) and larger
    classes are truncated, so each class yields exactly
    ``samples_per_class`` indices per epoch.

    Args:
        labels: Per-sample class labels (sequence or 1-D tensor).
        samples_per_class: Indices drawn per class per epoch; defaults to
            the size of the largest class (pure oversampling).
        seed: Base seed for the shuffle. Combined with the epoch set via
            :meth:`set_epoch`, so replayed runs are reproducible while
            successive epochs can still differ (the original reseeded with
            the same value every epoch, making every epoch identical).
    """

    def __init__(self, labels, samples_per_class=None, seed=42):
        self.labels = torch.tensor(labels) if not isinstance(labels, torch.Tensor) else labels
        self.classes = torch.unique(self.labels)
        self.num_classes = len(self.classes)
        self.seed = seed
        self.epoch = 0  # advanced via set_epoch() for a fresh per-epoch order
        # Group sample indices by class label.
        self.class_indices = {}
        for c in self.classes:
            self.class_indices[c.item()] = torch.where(self.labels == c)[0].tolist()
        # Determine samples per class.
        if samples_per_class is None:
            # Default: match the largest class so smaller ones are oversampled.
            self.samples_per_class = max(len(v) for v in self.class_indices.values())
        else:
            self.samples_per_class = samples_per_class

    def set_epoch(self, epoch):
        """Set the epoch (DistributedSampler convention) to vary the order."""
        self.epoch = epoch

    def __iter__(self):
        g = torch.Generator()
        # Deterministic per (seed, epoch); epoch 0 reproduces the seed alone.
        g.manual_seed(self.seed + self.epoch)
        indices = []
        for c in self.classes:
            c_indices = self.class_indices[c.item()]
            # Tile this class's indices until we have enough, then truncate
            # (oversampling when short, undersampling when long).
            num_needed = self.samples_per_class
            repeated = (c_indices * math.ceil(num_needed / len(c_indices)))[:num_needed]
            indices.extend(repeated)
        # One global shuffle so classes are interleaved, not blocked.
        perm = torch.randperm(len(indices), generator=g)
        return iter([indices[i] for i in perm])

    def __len__(self):
        return self.samples_per_class * self.num_classes
# Test
# Simulate imbalanced dataset: class 0 has 100 samples, class 1 has 10
from collections import Counter

labels = [0] * 100 + [1] * 10
sampler = BalancedClassSampler(labels, samples_per_class=50)
print(f"Sampler length: {len(sampler)}")  # 100 (50 * 2 classes)

# Verify balance
sampled = list(sampler)
sampled_labels = [labels[i] for i in sampled]
print(f"Class distribution: {Counter(sampled_labels)}")  # ~50 each
Challenge 5: Streaming Dataset
Implement a dataset that streams data from disk without loading everything into memory — essential for datasets that do not fit in RAM.
import torch
from torch.utils.data import IterableDataset, DataLoader
import math
class StreamingDataset(IterableDataset):
    """
    Challenge: Implement a streaming dataset that:
    - Reads data in chunks from disk (simulated here)
    - Supports shuffling within a buffer
    - Works correctly with multiple DataLoader workers
    - Handles epoch boundaries properly
    """

    # YOUR SOLUTION HERE
    ...
# ---- SOLUTION ----
class StreamingDataset(IterableDataset):
    """Stream (features, label) pairs without materializing the dataset.

    Samples flow through a bounded shuffle buffer, so memory use is
    O(buffer_size) regardless of dataset size. With multiple DataLoader
    workers, each worker streams a disjoint strided slice of the source.

    Args:
        data_source: Iterable of dicts with 'features' and 'label' keys.
        buffer_size: Shuffle-buffer capacity; larger means better mixing.
        seed: Seeds the shuffle RNG so each epoch's order is reproducible.
            (Bug fix: the original stored the seed but drew shuffle indices
            from the global RNG, so the seed had no effect.)
    """

    def __init__(self, data_source, buffer_size=1000, seed=42):
        self.data_source = data_source  # iterable/generator of samples
        self.buffer_size = buffer_size
        self.seed = seed

    def _shuffle_buffer(self, iterator, generator=None):
        """Reservoir-style shuffle: fill a bounded buffer, emit random picks."""
        if generator is None:
            generator = torch.Generator()
            generator.manual_seed(self.seed)
        buffer = []

        def pop_random():
            # O(1) removal: swap the chosen slot with the tail, then pop.
            idx = int(torch.randint(0, len(buffer), (1,), generator=generator))
            item = buffer[idx]
            buffer[idx] = buffer[-1]
            buffer.pop()
            return item

        for item in iterator:
            buffer.append(item)
            if len(buffer) >= self.buffer_size:
                yield pop_random()
        # Drain remaining items at the end of the stream.
        while buffer:
            yield pop_random()

    def _get_worker_slice(self, full_iter):
        """Partition data across DataLoader workers (strided split)."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this process sees everything.
            yield from full_iter
        else:
            # Worker k takes items k, k + N, k + 2N, ... so slices are disjoint.
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
            for i, item in enumerate(full_iter):
                if i % num_workers == worker_id:
                    yield item

    def __iter__(self):
        # Dedicated generator so shuffling honors self.seed deterministically.
        generator = torch.Generator()
        generator.manual_seed(self.seed)

        def data_generator():
            # In production this would read records from files/a database.
            for item in self.data_source:
                x = torch.tensor(item['features'], dtype=torch.float32)
                y = torch.tensor(item['label'], dtype=torch.long)
                yield x, y

        # Partition across workers first, then shuffle within this worker.
        worker_iter = self._get_worker_slice(data_generator())
        yield from self._shuffle_buffer(worker_iter, generator)
# Test
# Simulate a data source (in production: read from files/database)
data_source = [
    {'features': [float(i)] * 10, 'label': i % 5}
    for i in range(10000)
]

dataset = StreamingDataset(data_source, buffer_size=500)
loader = DataLoader(dataset, batch_size=32, num_workers=0)

batch_x, batch_y = next(iter(loader))
print(f"Batch X shape: {batch_x.shape}")  # (32, 10)
print(f"Batch Y shape: {batch_y.shape}")  # (32,)

# Count total samples
total = 0
for _, y in DataLoader(dataset, batch_size=64, num_workers=0):
    total += len(y)
print(f"Total samples streamed: {total}")  # 10000
Interview tip: The three things interviewers check about your data pipeline: (1) Does your Dataset's
__getitem__ handle edge cases (missing files, corrupt data)? (2) Do you set num_workers > 0 and understand the implications for multi-process data loading? (3) Can you write a custom collate function for variable-length inputs? These show you have shipped real training jobs.
Lilly Tech Systems