Image Classification
Image classification assigns a single label to an entire image. It is the most fundamental computer vision task and the foundation for many other CV applications.
The Classification Pipeline
1. Input image: Resize to a fixed dimension (e.g., 224x224) and normalize pixel values.
2. Feature extraction: A CNN or Vision Transformer extracts hierarchical features from the image.
3. Classification head: Fully connected layers map features to class probabilities via softmax.
4. Output: The predicted class label and its confidence score.
CNN Architecture Evolution
| Model | Year | Params | Key Innovation |
|---|---|---|---|
| LeNet | 1998 | 60K | First successful CNN for digit recognition (MNIST) |
| AlexNet | 2012 | 60M | Deep CNN with ReLU, dropout, GPU training. Won ImageNet. |
| VGG | 2014 | 138M | Very deep (16-19 layers) with uniform 3x3 convolutions |
| ResNet | 2015 | 25M | Skip connections enable training 100+ layer networks |
| EfficientNet | 2019 | 5-66M | Compound scaling of depth, width, and resolution |
| ViT | 2020 | 86M+ | Transformer architecture applied to image patches |
Transfer Learning with Pretrained Models
Instead of training from scratch, use a model pretrained on ImageNet (1.2M images, 1000 classes) as a starting point. This is the most practical approach for most real-world applications.
```python
import torch
import torchvision.models as models
import torch.nn as nn

# Load pretrained ResNet-50
model = models.resnet50(weights="IMAGENET1K_V2")

# Replace classification head for custom classes
num_classes = 10
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Freeze backbone (optional - for feature extraction)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

# Training setup
optimizer = torch.optim.Adam(model.fc.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
```
Data Augmentation
Augmentation creates variations of training images to improve model generalization and prevent overfitting:
```python
from torchvision import transforms

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
```
Vision Transformer (ViT)
ViT applies the transformer architecture to images by dividing them into fixed-size patches (e.g., 16x16), linearly embedding each patch, and processing the sequence of patch embeddings with a standard transformer encoder.
```python
from transformers import ViTForImageClassification, ViTImageProcessor
from PIL import Image

# Load pretrained ViT
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Classify an image
image = Image.open("cat.jpg")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
predicted_class = outputs.logits.argmax(-1).item()
print(f"Predicted: {model.config.id2label[predicted_class]}")
```
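The patch arithmetic behind ViT is easy to verify by hand: a 224x224 image cut into 16x16 patches yields (224/16)^2 = 196 patches, each flattening to a 16*16*3 = 768-dimensional vector before linear embedding. A sketch using plain tensor operations (`unfold` here is one way to extract patches; actual ViT implementations typically use a strided convolution):

```python
import torch

image_size, patch_size, channels = 224, 16, 3
num_patches = (image_size // patch_size) ** 2   # 196
patch_dim = channels * patch_size * patch_size  # 768

# Split an image tensor into non-overlapping 16x16 patches
img = torch.rand(1, channels, image_size, image_size)
patches = img.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.contiguous().view(1, channels, -1, patch_size, patch_size)
patches = patches.permute(0, 2, 1, 3, 4).reshape(1, num_patches, patch_dim)

# Linear embedding to the transformer's hidden size (768 for ViT-Base)
embed = torch.nn.Linear(patch_dim, 768)
tokens = embed(patches)  # sequence of 196 patch embeddings
```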
Fine-Tuning Best Practices
- Start frozen: First train only the classification head with a higher learning rate, then unfreeze and fine-tune the backbone with a lower learning rate.
- Use a learning rate scheduler: Cosine annealing or reduce-on-plateau helps find optimal weights.
- Early stopping: Monitor validation loss and stop training when it stops improving to prevent overfitting.
- Model selection: For small datasets, prefer smaller models (ResNet-18, EfficientNet-B0), which are less prone to overfitting. For larger datasets, try bigger models (ResNet-50, ViT).