Setting Up pytest for ML
Configuring pytest for machine learning projects. Part of the Unit Testing for ML Pipelines course at AI School by Lilly Tech Systems.
Project Structure for Testable ML Code
Before writing tests, you need a project structure that separates concerns and makes code testable. The key principle is to extract logic from notebooks into importable Python modules. Notebooks are for exploration; production code belongs in proper Python packages with tests.
# Recommended ML project structure
ml-project/
├── src/
│   ├── __init__.py
│   ├── data/
│   │   ├── __init__.py
│   │   ├── loader.py        # Data loading functions
│   │   ├── cleaner.py       # Data cleaning functions
│   │   └── validator.py     # Data validation functions
│   ├── features/
│   │   ├── __init__.py
│   │   └── engineering.py   # Feature engineering
│   ├── models/
│   │   ├── __init__.py
│   │   ├── trainer.py       # Model training
│   │   └── evaluator.py     # Model evaluation
│   └── utils/
│       ├── __init__.py
│       └── metrics.py       # Custom metrics
├── tests/
│   ├── __init__.py
│   ├── conftest.py          # Shared fixtures
│   ├── test_data_loader.py
│   ├── test_data_cleaner.py
│   ├── test_features.py
│   ├── test_trainer.py
│   └── test_metrics.py
├── notebooks/               # Exploration only
├── pyproject.toml
└── pytest.ini
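For the tests to import from src/, the package must be installable, typically in editable mode with pip install -e . from the project root. A minimal sketch of the relevant pyproject.toml sections, assuming a setuptools build backend (the package name and version here are placeholders):

```toml
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"

[project]
name = "ml-project"
version = "0.1.0"

[tool.setuptools.packages.find]
where = ["."]
include = ["src*"]
```

With this in place, test modules can write `from src.data.loader import ...` without path manipulation in conftest.py.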
Configuring pytest
Create a pytest configuration file that sets up sensible defaults for ML testing:
# pytest.ini
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts = -v --tb=short --strict-markers
markers =
    slow: marks tests as slow (deselect with '-m "not slow"')
    gpu: marks tests that require GPU
    integration: marks integration tests
    data: marks data validation tests
filterwarnings =
    ignore::DeprecationWarning
    ignore::FutureWarning
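Because --strict-markers is set, any marker used in a test must be one of those declared above, which catches typos like @pytest.mark.slwo at collection time. A sketch of how the custom markers are applied (the test names here are illustrative, not part of the project above):

```python
# Illustrative use of the declared markers
import pytest

@pytest.mark.slow
def test_full_training_pipeline():
    # Long-running end-to-end test; skipped with: pytest -m "not slow"
    pass

@pytest.mark.gpu
def test_cuda_inference():
    # Only meaningful on a GPU machine; selected with: pytest -m gpu
    pass
```
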
Writing Fixtures for ML Tests
Fixtures in pytest provide reusable test data and resources. For ML projects, common fixtures include sample datasets, trained models, and configuration objects:
# tests/conftest.py
import pytest
import pandas as pd
import numpy as np
from sklearn.ensemble import RandomForestClassifier


@pytest.fixture
def sample_dataframe():
    """Create a small representative DataFrame for testing."""
    np.random.seed(42)
    return pd.DataFrame({
        'user_id': range(100),
        'age': np.random.randint(18, 80, 100),
        'income': np.random.uniform(20000, 200000, 100),
        'purchases': np.random.randint(0, 50, 100),
        'target': np.random.randint(0, 2, 100)
    })


@pytest.fixture
def sample_features_and_labels(sample_dataframe):
    """Extract features and labels from sample data."""
    X = sample_dataframe[['age', 'income', 'purchases']].values
    y = sample_dataframe['target'].values
    return X, y


@pytest.fixture(scope="session")
def trained_classifier():
    """Train a classifier once for the entire test session."""
    from sklearn.datasets import make_classification
    X, y = make_classification(n_samples=500, n_features=10, random_state=42)
    model = RandomForestClassifier(n_estimators=50, random_state=42)
    model.fit(X, y)
    return model


@pytest.fixture
def tmp_model_path(tmp_path):
    """Provide a temporary path for model serialization tests."""
    return tmp_path / "test_model.pkl"
Running Tests Efficiently
ML test suites can be slow due to data loading and model training. Use markers and selective execution to keep feedback loops fast:
- pytest -m "not slow" — Skip slow tests during development
- pytest tests/test_features.py — Run only feature engineering tests
- pytest -x — Stop at the first failure for quick debugging
- pytest --lf — Re-run only tests that failed last time
- pytest -n auto — Run tests in parallel using pytest-xdist
Essential pytest Plugins for ML
Several pytest plugins are particularly useful for ML testing:
- pytest-xdist — Parallel test execution for faster feedback
- pytest-cov — Code coverage reporting to find untested code
- pytest-timeout — Prevent hung tests from blocking the pipeline
- pytest-benchmark — Performance benchmarking for inference speed tests
- pytest-randomly — Randomize test order to detect test interdependencies