Beginner

Setting Up pytest for ML

Configuring pytest for machine learning projects. Part of the Unit Testing for ML Pipelines course at AI School by Lilly Tech Systems.

Project Structure for Testable ML Code

Before writing tests, you need a project structure that separates concerns and makes code testable. The key principle is to extract logic from notebooks into importable Python modules. Notebooks are for exploration; production code belongs in proper Python packages with tests.

# Recommended ML project structure
ml-project/
├── src/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── loader.py          # Data loading functions
│   │   ├── cleaner.py         # Data cleaning functions
│   │   └── validator.py       # Data validation functions
│   ├── features/
│   │   ├── __init__.py
│   │   └── engineering.py     # Feature engineering
│   ├── models/
│   │   ├── __init__.py
│   │   ├── trainer.py         # Model training
│   │   └── evaluator.py       # Model evaluation
│   └── utils/
│       ├── __init__.py
│       └── metrics.py         # Custom metrics
├── tests/
│   ├── __init__.py
│   ├── conftest.py            # Shared fixtures
│   ├── test_data_loader.py
│   ├── test_data_cleaner.py
│   ├── test_features.py
│   ├── test_trainer.py
│   └── test_metrics.py
├── notebooks/                  # Exploration only
├── pyproject.toml
└── pytest.ini
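To make the extraction principle concrete, here is a minimal sketch of logic pulled out of a notebook into an importable module, paired with its test. The function name and module path are hypothetical, chosen to match the layout above:

```python
# src/data/cleaner.py (hypothetical module following the layout above)
import pandas as pd

def drop_incomplete_rows(df: pd.DataFrame, required: list[str]) -> pd.DataFrame:
    """Return a copy of df with rows missing any required column removed."""
    return df.dropna(subset=required).reset_index(drop=True)

# tests/test_data_cleaner.py
def test_drop_incomplete_rows():
    df = pd.DataFrame({
        "age": [25, None, 40],
        "income": [50000, 60000, None],
    })
    result = drop_incomplete_rows(df, required=["age", "income"])
    assert len(result) == 1          # only the fully populated row survives
    assert result.loc[0, "age"] == 25
```

Because the function lives in `src/` rather than a notebook cell, the test can import and exercise it directly.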

Configuring pytest

Create a pytest configuration file that sets up sensible defaults for ML testing:

# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short --strict-markers
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')
    gpu: marks tests that require GPU
    integration: marks integration tests
    data: marks data validation tests
filterwarnings =
    ignore::DeprecationWarning
    ignore::FutureWarning
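With --strict-markers enabled, any marker not declared in the `markers` section is an error, which catches typos like `@pytest.mark.slwo`. A test file using the declared markers might look like this (the test names and bodies are illustrative):

```python
# tests/test_trainer.py (illustrative use of the markers declared above)
import pytest

@pytest.mark.slow
def test_full_training_run():
    # Expensive end-to-end training; skipped with: pytest -m "not slow"
    ...

@pytest.mark.gpu
def test_mixed_precision_inference():
    # Requires CUDA; selected only on GPU runners with: pytest -m gpu
    ...
```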

Writing Fixtures for ML Tests

Fixtures in pytest provide reusable test data and resources. For ML projects, common fixtures include sample datasets, trained models, and configuration objects:

# tests/conftest.py
import pytest
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier

@pytest.fixture
def sample_dataframe():
    # Create a small representative DataFrame for testing.
    np.random.seed(42)
    return pd.DataFrame({
        'user_id': range(100),
        'age': np.random.randint(18, 80, 100),
        'income': np.random.uniform(20000, 200000, 100),
        'purchases': np.random.randint(0, 50, 100),
        'target': np.random.randint(0, 2, 100)
    })

@pytest.fixture
def sample_features_and_labels(sample_dataframe):
    # Extract features and labels from sample data.
    X = sample_dataframe[['age', 'income', 'purchases']].values
    y = sample_dataframe['target'].values
    return X, y

@pytest.fixture(scope="session")
def trained_classifier():
    # Train a classifier once for the entire test session.
    from sklearn.datasets import make_classification
    X, y = make_classification(n_samples=500, n_features=10, random_state=42)
    model = RandomForestClassifier(n_estimators=50, random_state=42)
    model.fit(X, y)
    return model

@pytest.fixture
def tmp_model_path(tmp_path):
    # Provide a temporary path for model serialization tests.
    return tmp_path / "test_model.pkl"

💡 Performance tip: Use session-scoped fixtures for expensive operations like training models. This trains the model once and reuses it across all tests in the session, dramatically reducing test execution time.

Running Tests Efficiently

ML test suites can be slow due to data loading and model training. Use markers and selective execution to keep feedback loops fast:

  • pytest -m "not slow" — Skip slow tests during development
  • pytest tests/test_features.py — Run only feature engineering tests
  • pytest -x — Stop at the first failure for quick debugging
  • pytest --lf — Re-run only tests that failed last time
  • pytest -n auto — Run tests in parallel using pytest-xdist

Essential pytest Plugins for ML

Several pytest plugins are particularly useful for ML testing:

  1. pytest-xdist — Parallel test execution for faster feedback
  2. pytest-cov — Code coverage reporting to find untested code
  3. pytest-timeout — Prevent hung tests from blocking the pipeline
  4. pytest-benchmark — Performance benchmarking for inference speed tests
  5. pytest-randomly — Randomize test order to detect test interdependencies

Common mistake: Do not import heavy ML libraries (TensorFlow, PyTorch) at the top of test files that do not need them. Use lazy imports inside test functions or fixtures. This keeps fast tests fast and prevents unnecessary GPU initialization.
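A lazy import keeps the heavy dependency out of test collection entirely; pytest's built-in `importorskip` also skips the test gracefully when the library is not installed. A sketch (the torch usage is illustrative):

```python
# tests/test_torch_model.py
import pytest

@pytest.mark.gpu
def test_tensor_sum():
    # torch is imported only when this test actually runs; the test is
    # skipped automatically on machines where torch is not installed
    torch = pytest.importorskip("torch")
    x = torch.ones(3)
    assert x.sum().item() == 3.0
```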