Intermediate

Image Classification

Image classification assigns a single label to an entire image. It is among the most fundamental computer vision tasks and the foundation for many other CV applications.

The Classification Pipeline

  1. Input Image

    Resize to a fixed dimension (e.g., 224x224). Normalize pixel values.

  2. Feature Extraction

    A CNN or Vision Transformer extracts hierarchical features from the image.

  3. Classification Head

    Fully connected layers map features to class probabilities via softmax.

  4. Output

    The predicted class label and confidence score.

CNN Architecture Evolution

Model        | Year | Params | Key Innovation
LeNet        | 1998 | 60K    | First successful CNN for digit recognition (MNIST)
AlexNet      | 2012 | 60M    | Deep CNN with ReLU, dropout, GPU training. Won ImageNet.
VGG          | 2014 | 138M   | Very deep (16-19 layers) with uniform 3x3 convolutions
ResNet       | 2015 | 25M    | Skip connections enable training 100+ layer networks
EfficientNet | 2019 | 5-66M  | Compound scaling of depth, width, and resolution
ViT          | 2020 | 86M+   | Transformer architecture applied to image patches
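ResNet's key idea, the skip connection, is easy to see in code. Here is a minimal sketch of a basic residual block (channel count and layer sizes are illustrative):

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Minimal ResNet-style basic block: output = ReLU(F(x) + x)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + x)  # skip connection: identity path for gradients

block = ResidualBlock(64)
y = block(torch.randn(1, 64, 56, 56))
```

Because the identity path carries gradients through unchanged, stacking many such blocks remains trainable, which is what makes 100+ layer networks feasible.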

Transfer Learning with Pretrained Models

Instead of training from scratch, use a model pretrained on ImageNet (1.2M images, 1000 classes) as a starting point. This is the most practical approach for most real-world applications.

Python - PyTorch Transfer Learning
import torch
import torchvision.models as models
import torch.nn as nn

# Load pretrained ResNet-50
model = models.resnet50(weights="IMAGENET1K_V2")

# Replace classification head for custom classes
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Optional: freeze the backbone for feature extraction
# (freeze all parameters first, then unfreeze only the new head)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

# Training setup
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
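A minimal training-step sketch to go with the setup above. To keep it self-contained, the linear model and random tensors below are stand-ins for the frozen backbone's features and a real DataLoader:

```python
import torch
import torch.nn as nn

# Stand-ins: a linear head over 2048-dim features (like ResNet-50's fc)
model = nn.Linear(2048, 10)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

features = torch.randn(8, 2048)          # fake backbone features
labels = torch.randint(0, 10, (8,))      # fake class labels

for epoch in range(3):
    optimizer.zero_grad()                # clear accumulated gradients
    loss = criterion(model(features), labels)
    loss.backward()                      # backpropagate
    optimizer.step()                     # update head weights
```

In a real run, each epoch iterates over a DataLoader of (image, label) batches instead of a single fixed tensor.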

Data Augmentation

Augmentation creates randomized, label-preserving variations of training images to improve model generalization and prevent overfitting:

Python - PyTorch Augmentation
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

Vision Transformer (ViT)

ViT applies the transformer architecture to images by dividing them into fixed-size patches (e.g., 16x16), linearly embedding each patch, and processing the sequence of patch embeddings with a standard transformer encoder.
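The patch embedding step can be sketched directly with tensor operations. Dimensions follow ViT-Base: 16x16 patches from a 224x224 image give 14 x 14 = 196 patches, each embedded into 768 dimensions:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 224, 224)                    # one RGB image
# Split into non-overlapping 16x16 patches
patches = x.unfold(2, 16, 16).unfold(3, 16, 16)    # (1, 3, 14, 14, 16, 16)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 196, 3 * 16 * 16)

embed = nn.Linear(3 * 16 * 16, 768)                # linear patch embedding
tokens = embed(patches)                            # (1, 196, 768)
# `tokens` (plus a class token and position embeddings) is the sequence
# the transformer encoder consumes.
```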

Python - ViT with Hugging Face
import torch
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image

# Load pretrained ViT
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Classify an image (inference only, so disable gradient tracking)
image = Image.open("cat.jpg")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

predicted_class = outputs.logits.argmax(-1).item()
print(f"Predicted: {model.config.id2label[predicted_class]}")

Fine-Tuning Best Practices

  • Start frozen: First train only the classification head with a higher learning rate, then unfreeze and fine-tune the backbone with a lower learning rate.
  • Use a learning rate scheduler: Cosine annealing or reduce-on-plateau typically yields better final accuracy than a fixed learning rate.
  • Early stopping: Monitor validation loss and stop training when it stops improving to prevent overfitting.
  • Model selection: For small datasets, use smaller models (ResNet-18, EfficientNet-B0). For larger datasets, try bigger models (ResNet-50, ViT).
Key takeaway: Image classification has evolved from simple CNNs to sophisticated Vision Transformers. Transfer learning is the practical standard: start with a pretrained model and fine-tune it on your data. Data augmentation is essential for building robust classifiers.