Advanced

Datasets & DataLoaders

Five challenges covering the data pipeline engineering that every production DL system requires. Data loading is often the bottleneck — getting it right determines whether training takes hours or days.

Challenge 1: Custom Dataset

Implement a custom Dataset class for a multi-modal dataset (text + images + labels) with lazy loading.

import torch
from torch.utils.data import Dataset, DataLoader
import os

class MultiModalDataset(Dataset):
    """
    Challenge: write a Dataset for paired text-image samples.
    - Lazy loading: keep only lightweight metadata in memory at __init__
    - __getitem__ returns (image_tensor, token_ids, label)
    - Degrade gracefully when a file is missing
    - Accept an optional transform callable
    """
    # YOUR SOLUTION HERE
    ...

# ---- SOLUTION ----
class MultiModalDataset(Dataset):
    """Paired text-image dataset with lazy per-item loading.

    Only lightweight metadata (paths, raw text, labels) is held in memory;
    pixel data is loaded (here: simulated) inside ``__getitem__``, so
    ``__init__`` stays cheap regardless of dataset size.

    Args:
        data_dir: Root directory containing images and metadata.
        split: Split name ('train', 'val', ...); selects the metadata file
            in a real implementation.
        transform: Optional callable applied to the image tensor.
        max_seq_len: Texts are truncated/padded to this many tokens.
    """

    def __init__(self, data_dir, split='train', transform=None, max_seq_len=128):
        self.data_dir = data_dir
        self.split = split
        self.transform = transform
        self.max_seq_len = max_seq_len

        # In production this metadata would be parsed from
        # os.path.join(data_dir, f'{split}_metadata.csv'); here we build
        # synthetic records for demonstration.
        self.samples = [
            {
                'image_path': f'images/{i:06d}.jpg',
                'text': f'Sample text for item {i}',
                'label': i % 10,
            }
            for i in range(1000)
        ]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image_tensor, token_ids, label) for one sample."""
        sample = self.samples[idx]

        # Lazy-load the image.  In production:
        #   image = Image.open(os.path.join(self.data_dir, sample['image_path']))
        # wrapped in try/except to handle missing or corrupt files.
        image = torch.randn(3, 224, 224)  # placeholder pixels

        if self.transform is not None:
            image = self.transform(image)

        # Character-level tokenization for the demo; use a real tokenizer
        # in production.
        token_ids = torch.tensor(
            [ord(c) % 256 for c in sample['text'][:self.max_seq_len]],
            dtype=torch.long,
        )
        # Right-pad with zeros so every item has shape (max_seq_len,).
        if len(token_ids) < self.max_seq_len:
            padding = torch.zeros(self.max_seq_len - len(token_ids), dtype=torch.long)
            token_ids = torch.cat([token_ids, padding])

        label = torch.tensor(sample['label'], dtype=torch.long)

        return image, token_ids, label

# Test
# Smoke-test the dataset: per-item shapes first, then batching via DataLoader.
dataset = MultiModalDataset('/tmp/data')
print(f"Dataset size: {len(dataset)}")

image, tokens, label = dataset[0]
print(f"Image shape:  {image.shape}")    # (3, 224, 224)
print(f"Token shape:  {tokens.shape}")   # (128,)
print(f"Label:        {label.item()}")   # 0

# num_workers=0 keeps loading in the main process — simplest for a demo;
# raise it for real training jobs.
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)
batch = next(iter(loader))
print(f"Batch images: {batch[0].shape}") # (8, 3, 224, 224)
print(f"Batch tokens: {batch[1].shape}") # (8, 128)
print(f"Batch labels: {batch[2].shape}") # (8,)

Challenge 2: Data Augmentation Pipeline

Build a composable data augmentation pipeline using pure PyTorch operations (no torchvision dependency).

import torch
import torch.nn.functional as F

class RandomCrop:
    """Crop a CHW tensor image to a fixed size at a random location."""

    def __init__(self, size):
        if isinstance(size, tuple):
            self.size = size
        else:
            self.size = (size, size)

    def __call__(self, img):
        target_h, target_w = self.size
        _, height, width = img.shape

        # Zero-pad on the right/bottom when the image is smaller than the
        # crop window, so a valid crop position always exists.
        if height < target_h or width < target_w:
            pad_right = max(0, target_w - width)
            pad_bottom = max(0, target_h - height)
            img = F.pad(img, (0, pad_right, 0, pad_bottom))
            _, height, width = img.shape

        top = torch.randint(0, height - target_h + 1, (1,)).item()
        left = torch.randint(0, width - target_w + 1, (1,)).item()
        return img[:, top:top + target_h, left:left + target_w]

class RandomHorizontalFlip:
    """Flip a tensor image along its width axis with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        # One uniform draw per call; flip the last (width) dim on success.
        do_flip = torch.rand(1).item() < self.p
        return img.flip(-1) if do_flip else img

class ColorJitter:
    """Randomly perturb brightness and contrast of a [0, 1] tensor image."""

    def __init__(self, brightness=0.2, contrast=0.2):
        self.brightness = brightness
        self.contrast = contrast

    @staticmethod
    def _rand_factor(strength):
        # Uniform factor in [1 - strength, 1 + strength].
        return 1.0 + (torch.rand(1).item() * 2 - 1) * strength

    def __call__(self, img):
        # Brightness: scale every pixel by a random factor.
        img = img * self._rand_factor(self.brightness)

        # Contrast: scale each pixel's deviation from the
        # (post-brightness) global mean.
        mean = img.mean()
        img = (img - mean) * self._rand_factor(self.contrast) + mean

        return img.clamp(0, 1)

class Normalize:
    """Channel-wise standardization: (img - mean) / std."""

    def __init__(self, mean, std):
        # Reshape to (C, 1, 1) so the stats broadcast over H and W.
        self.mean = torch.tensor(mean).view(-1, 1, 1)
        self.std = torch.tensor(std).view(-1, 1, 1)

    def __call__(self, img):
        centered = img - self.mean
        return centered / self.std

class Compose:
    """Apply a sequence of transforms left to right."""

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result

# ---- SOLUTION: Build augmentation pipeline ----
# Training pipeline: stochastic crop/flip/jitter, then channel
# normalization.  NOTE(review): the mean/std constants look like the
# common ImageNet statistics — presumably chosen to match a pretrained
# backbone; confirm against the model being trained.
train_transform = Compose([
    RandomCrop(224),
    RandomHorizontalFlip(p=0.5),
    ColorJitter(brightness=0.2, contrast=0.2),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation: deterministic — normalization only, no random augmentation.
val_transform = Compose([
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Test
img = torch.rand(3, 256, 256)  # random image
augmented = train_transform(img)
print(f"Original:  {img.shape}")        # (3, 256, 256)
print(f"Augmented: {augmented.shape}")   # (3, 224, 224)

Challenge 3: Custom Collate Functions

Implement collate functions for variable-length sequences — essential for NLP tasks.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

def variable_length_collate(batch):
    """Collate (sequence_tensor, label) pairs of varying lengths.

    Expected to pad every sequence to the longest one in the batch,
    build an attention mask (1 = real token, 0 = padding), and return
    (padded_sequences, attention_mask, labels, lengths).
    """
    # YOUR SOLUTION HERE
    ...

# ---- SOLUTION ----
def variable_length_collate(batch):
    """Pad variable-length sequences and build an attention mask.

    Args:
        batch: List of (sequence_tensor, label) pairs.

    Returns:
        Tuple (padded_sequences, attention_mask, labels, lengths) where
        the mask has 1 for real tokens and 0 for padding positions.
    """
    sequences, raw_labels = zip(*batch)
    labels = torch.stack(list(raw_labels))

    # Original lengths — needed for the mask and by callers that pack
    # sequences downstream.
    lengths = torch.tensor([seq.size(0) for seq in sequences])

    # (batch, max_len_in_batch), zero-filled beyond each true length.
    padded = pad_sequence(list(sequences), batch_first=True, padding_value=0)

    # Positions < length are real tokens (1); the rest are padding (0).
    positions = torch.arange(padded.size(1))
    attention_mask = (positions.unsqueeze(0) < lengths.unsqueeze(1)).long()

    return padded, attention_mask, labels, lengths


class VariableLengthDataset(Dataset):
    """Synthetic dataset of random integer sequences with random lengths."""

    def __init__(self, num_samples=100, vocab_size=1000, num_classes=5):
        self.data = []
        for _ in range(num_samples):
            # Lengths vary in [5, 50); token ids in [1, vocab_size)
            # so 0 stays free for padding.
            length = torch.randint(5, 50, (1,)).item()
            tokens = torch.randint(1, vocab_size, (length,))
            target = torch.randint(0, num_classes, (1,)).squeeze()
            self.data.append((tokens, target))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Test
# Each batch is padded only to its own max length (dynamic padding),
# which wastes less compute than padding to a global maximum.
dataset = VariableLengthDataset()
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=variable_length_collate
)

padded, mask, labels, lengths = next(iter(loader))
print(f"Padded:    {padded.shape}")    # (8, max_len_in_batch)
print(f"Mask:      {mask.shape}")      # (8, max_len_in_batch)
print(f"Labels:    {labels.shape}")    # (8,)
print(f"Lengths:   {lengths}")         # actual lengths per sample

Challenge 4: Distributed Sampling

Implement a sampler that handles class-balanced sampling and distributed training.

import torch
from torch.utils.data import Sampler, Dataset
import math

class BalancedClassSampler(Sampler):
    """
    Challenge: write a sampler that:
    - represents every class equally in each epoch
    - oversamples minority classes / undersamples majority classes
    - accepts a random seed for reproducible ordering
    """
    # YOUR SOLUTION HERE
    ...

# ---- SOLUTION ----
class BalancedClassSampler(Sampler):
    """Class-balanced sampler: draws ``samples_per_class`` indices per class.

    Minority classes are oversampled by repeating their indices; majority
    classes are implicitly undersampled by truncation to
    ``samples_per_class``.

    The original implementation reseeded with the same fixed seed on every
    ``__iter__``, so every epoch produced the identical order.  Following
    the ``DistributedSampler`` convention, ``set_epoch`` mixes the epoch
    into the seed; without calling it, behavior is unchanged.

    Args:
        labels: Per-sample integer class labels (sequence or tensor).
        samples_per_class: Indices drawn per class each epoch; defaults to
            the size of the largest class.
        seed: Base seed, combined with the epoch for the shuffle.
    """

    def __init__(self, labels, samples_per_class=None, seed=42):
        self.labels = labels if isinstance(labels, torch.Tensor) else torch.tensor(labels)
        self.classes = torch.unique(self.labels)
        self.num_classes = len(self.classes)
        self.seed = seed
        self.epoch = 0  # advance via set_epoch() for a different shuffle

        # Group sample indices by class label.
        self.class_indices = {}
        for c in self.classes:
            self.class_indices[c.item()] = torch.where(self.labels == c)[0].tolist()

        if samples_per_class is None:
            # Balance up to the largest class (pure oversampling).
            self.samples_per_class = max(len(v) for v in self.class_indices.values())
        else:
            self.samples_per_class = samples_per_class

    def set_epoch(self, epoch):
        """Set the epoch so successive epochs use different (reproducible) shuffles."""
        self.epoch = epoch

    def __iter__(self):
        # Seeded generator: reproducible across runs, varies with epoch.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)

        indices = []
        for c in self.classes:
            c_indices = self.class_indices[c.item()]
            # Repeat the class's indices enough times, then truncate —
            # oversamples small classes, undersamples large ones.
            num_needed = self.samples_per_class
            repeated = (c_indices * math.ceil(num_needed / len(c_indices)))[:num_needed]
            indices.extend(repeated)

        # Shuffle so classes are interleaved rather than emitted in blocks.
        perm = torch.randperm(len(indices), generator=g)
        return iter([indices[i] for i in perm])

    def __len__(self):
        return self.samples_per_class * self.num_classes

# Test
# Simulate imbalanced dataset: class 0 has 100 samples, class 1 has 10
labels = [0] * 100 + [1] * 10
sampler = BalancedClassSampler(labels, samples_per_class=50)
print(f"Sampler length: {len(sampler)}")  # 100 (50 * 2 classes)

# Verify balance: class 1's 10 indices must be repeated to reach 50 draws.
sampled = list(sampler)
sampled_labels = [labels[i] for i in sampled]
from collections import Counter
print(f"Class distribution: {Counter(sampled_labels)}")  # ~50 each

Challenge 5: Streaming Dataset

Implement a dataset that streams data from disk without loading everything into memory — essential for datasets that do not fit in RAM.

import torch
from torch.utils.data import IterableDataset, DataLoader
import math

class StreamingDataset(IterableDataset):
    """
    Challenge: build a streaming dataset that:
    - reads data in chunks from disk (simulated here)
    - shuffles within a bounded buffer
    - partitions work correctly across DataLoader workers
    - handles epoch boundaries cleanly
    """
    # YOUR SOLUTION HERE
    ...

# ---- SOLUTION ----
class StreamingDataset(IterableDataset):
    """Stream samples from a source without materializing them in memory.

    Items flow: source -> worker partition -> bounded shuffle buffer.

    The original implementation stored ``seed`` but never used it — the
    shuffle drew from the global RNG, so the documented reproducibility
    was broken.  Fixed by seeding a dedicated ``torch.Generator``, offset
    by worker id so workers draw independent random streams.

    Args:
        data_source: Iterable of dicts with 'features' and 'label' keys.
        buffer_size: Shuffle-buffer capacity (larger = better mixing).
        seed: Seed for the buffer shuffle.
    """

    def __init__(self, data_source, buffer_size=1000, seed=42):
        self.data_source = data_source  # iterable/generator of samples
        self.buffer_size = buffer_size
        self.seed = seed

    def _shuffle_buffer(self, iterator, generator):
        """Buffered shuffle: hold up to buffer_size items, emit one at random."""
        buffer = []
        for item in iterator:
            buffer.append(item)
            if len(buffer) >= self.buffer_size:
                idx = torch.randint(0, len(buffer), (1,), generator=generator).item()
                yield buffer[idx]
                # Swap-remove: O(1) deletion from the buffer.
                buffer[idx] = buffer[-1]
                buffer.pop()

        # End of stream: drain what's left, still in random order.
        while buffer:
            idx = torch.randint(0, len(buffer), (1,), generator=generator).item()
            yield buffer[idx]
            buffer[idx] = buffer[-1]
            buffer.pop()

    def _get_worker_slice(self, full_iter):
        """Give each DataLoader worker a disjoint strided slice of the stream."""
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this worker gets everything.
            yield from full_iter
        else:
            # Multi-process: worker k takes items k, k+N, k+2N, ...
            worker_id = worker_info.id
            num_workers = worker_info.num_workers
            for i, item in enumerate(full_iter):
                if i % num_workers == worker_id:
                    yield item

    def __iter__(self):
        # Per-worker seeded generator so each worker's shuffle is both
        # reproducible and decorrelated from the others.
        worker_info = torch.utils.data.get_worker_info()
        worker_id = 0 if worker_info is None else worker_info.id
        g = torch.Generator()
        g.manual_seed(self.seed + worker_id)

        def data_generator():
            # In production this would read records from files/database.
            for item in self.data_source:
                x = torch.tensor(item['features'], dtype=torch.float32)
                y = torch.tensor(item['label'], dtype=torch.long)
                yield x, y

        # Apply worker partitioning first, then the shuffle buffer.
        worker_iter = self._get_worker_slice(data_generator())
        yield from self._shuffle_buffer(worker_iter, g)


# Test
# Simulate a data source (in production: read from files/database)
data_source = [
    {'features': [float(i)] * 10, 'label': i % 5}
    for i in range(10000)
]

dataset = StreamingDataset(data_source, buffer_size=500)
loader = DataLoader(dataset, batch_size=32, num_workers=0)

batch_x, batch_y = next(iter(loader))
print(f"Batch X shape: {batch_x.shape}")  # (32, 10)
print(f"Batch Y shape: {batch_y.shape}")  # (32,)

# Count total samples — a fresh DataLoader restarts the stream, so each
# sample should appear exactly once per full pass.
total = sum(len(y) for _, y in DataLoader(dataset, batch_size=64, num_workers=0))
print(f"Total samples streamed: {total}")  # 10000
💡
Interview tip: The three things interviewers check about your data pipeline: (1) Does your Dataset's __getitem__ handle edge cases (missing files, corrupt data)? (2) Do you set num_workers > 0 and understand the implications for multi-process data loading? (3) Can you write a custom collate function for variable-length inputs? These show you have shipped real training jobs.