Advanced

Mocking External Services

Isolating ML tests from external dependencies. Part of the Unit Testing for ML Pipelines course at AI School by Lilly Tech Systems.

Why Mocking Is Essential for ML Tests

ML pipelines often depend on external services: cloud storage for data, model registries for artifacts, feature stores for online features, and APIs for predictions. Unit tests that call real external services are slow, flaky, and expensive. Mocking replaces these dependencies with controlled substitutes, making your tests fast, reliable, and independent of external infrastructure.

Python Mocking Fundamentals

Python's unittest.mock library provides the tools you need. The key concepts are Mock objects, patch decorators, and side_effect for custom behavior:

import io

import pandas as pd
from unittest.mock import Mock, patch, MagicMock
import pytest

# Example: A function that loads data from S3
def load_training_data(s3_client, bucket, key):
    # Load training data from S3.
    response = s3_client.get_object(Bucket=bucket, Key=key)
    body = response['Body'].read()
    return pd.read_parquet(io.BytesIO(body))

def test_load_training_data():
    # Create a mock S3 client
    mock_s3 = Mock()

    # Create sample parquet data
    sample_df = pd.DataFrame({'feature1': [1, 2, 3], 'target': [0, 1, 0]})
    buffer = io.BytesIO()
    sample_df.to_parquet(buffer)
    buffer.seek(0)

    # Configure mock to return sample data
    mock_s3.get_object.return_value = {
        'Body': Mock(read=Mock(return_value=buffer.read()))
    }

    result = load_training_data(mock_s3, 'my-bucket', 'data/train.parquet')
    assert len(result) == 3
    assert 'feature1' in result.columns
    mock_s3.get_object.assert_called_once_with(
        Bucket='my-bucket', Key='data/train.parquet'
    )
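The side_effect attribute mentioned earlier complements return_value: it makes a mock raise an exception (or vary its behavior per call) instead of returning a fixed value. A minimal sketch of testing the loader's error path, repeating load_training_data from above so it runs standalone (the missing-key scenario is illustrative):

```python
import io
from unittest.mock import Mock

import pandas as pd
import pytest

def load_training_data(s3_client, bucket, key):
    # Same loader as above, repeated so this sketch runs standalone
    response = s3_client.get_object(Bucket=bucket, Key=key)
    body = response['Body'].read()
    return pd.read_parquet(io.BytesIO(body))

def test_load_training_data_missing_key():
    # side_effect makes the mock raise instead of returning a value,
    # so the caller's error path is exercised without touching S3
    mock_s3 = Mock()
    mock_s3.get_object.side_effect = KeyError('data/missing.parquet')

    with pytest.raises(KeyError):
        load_training_data(mock_s3, 'my-bucket', 'data/missing.parquet')
```

Because the mock raises before read_parquet is ever reached, this test runs in microseconds and needs no S3 credentials.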

Mocking ML Model Registries

When your code interacts with model registries like MLflow or SageMaker Model Registry, mock the registry client to test registration and retrieval logic:

def register_model(mlflow_client, model, model_name, metrics):
    # Register a model if it beats the current production model.
    try:
        current = mlflow_client.get_latest_versions(model_name, stages=["Production"])
        current_f1 = float(current[0].tags.get('f1_score', 0))
    except Exception:
        current_f1 = 0

    if metrics['f1_score'] > current_f1:
        mlflow_client.log_model(model, model_name)
        return True
    return False

def test_register_model_when_better():
    mock_client = Mock()
    mock_version = Mock()
    mock_version.tags = {'f1_score': '0.80'}
    mock_client.get_latest_versions.return_value = [mock_version]

    result = register_model(
        mock_client, Mock(), 'my_model', {'f1_score': 0.85}
    )
    assert result is True
    mock_client.log_model.assert_called_once()

def test_skip_registration_when_worse():
    mock_client = Mock()
    mock_version = Mock()
    mock_version.tags = {'f1_score': '0.90'}
    mock_client.get_latest_versions.return_value = [mock_version]

    result = register_model(
        mock_client, Mock(), 'my_model', {'f1_score': 0.85}
    )
    assert result is False
    mock_client.log_model.assert_not_called()
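A third case worth covering is the cold-start path: when no Production version exists, get_latest_versions raises, current_f1 falls back to 0, and any model should register. side_effect on the mock simulates this. A sketch, repeating register_model from above so it runs standalone:

```python
from unittest.mock import Mock

def register_model(mlflow_client, model, model_name, metrics):
    # Same function as above, repeated so this sketch runs standalone
    try:
        current = mlflow_client.get_latest_versions(model_name, stages=["Production"])
        current_f1 = float(current[0].tags.get('f1_score', 0))
    except Exception:
        current_f1 = 0

    if metrics['f1_score'] > current_f1:
        mlflow_client.log_model(model, model_name)
        return True
    return False

def test_register_model_when_no_production_model():
    # side_effect raises, simulating a registry with no Production version;
    # the function should fall back to current_f1 = 0 and register
    mock_client = Mock()
    mock_client.get_latest_versions.side_effect = Exception("no versions found")

    result = register_model(mock_client, Mock(), 'my_model', {'f1_score': 0.70})
    assert result is True
    mock_client.log_model.assert_called_once()
```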

💡 Design tip: If you find mocking is extremely complex, it often means your code is too tightly coupled to external services. Refactor to use dependency injection: pass clients as parameters rather than creating them inside functions. This makes testing straightforward.
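The refactor the tip describes can be sketched as follows. ModelRegistry and latest_f1 are hypothetical names for illustration, not part of the examples above; the point is that the client is injected, so no patch() is needed:

```python
from unittest.mock import Mock

class ModelRegistry:
    # Hypothetical wrapper: the real client is injected, never created here,
    # so unit tests can hand in a Mock instead of patching module globals
    def __init__(self, client):
        self.client = client

    def latest_f1(self, model_name):
        versions = self.client.get_latest_versions(model_name, stages=["Production"])
        return float(versions[0].tags.get('f1_score', 0))

def test_latest_f1_with_injected_mock():
    mock_client = Mock()
    version = Mock()
    version.tags = {'f1_score': '0.91'}
    mock_client.get_latest_versions.return_value = [version]

    registry = ModelRegistry(mock_client)  # no patch() needed
    assert registry.latest_f1('my_model') == 0.91
```

Compare this with the @patch examples below: injection keeps the test independent of where the client module is imported, which is a common source of broken patches.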

Mocking API Calls for Inference Testing

When testing code that calls ML prediction APIs, mock the HTTP layer to avoid network calls:

import requests
from unittest.mock import patch

def get_prediction(api_url, features):
    # Call an ML prediction API.
    response = requests.post(api_url, json={'features': features})
    response.raise_for_status()
    return response.json()['prediction']

@patch('requests.post')
def test_get_prediction_success(mock_post):
    # raise_for_status is auto-created as a no-op attribute on the Mock,
    # so it needs no explicit configuration
    mock_post.return_value = Mock(
        status_code=200,
        json=Mock(return_value={'prediction': 0.87})
    )

    result = get_prediction('http://model-api/predict', [1.0, 2.0, 3.0])
    assert result == 0.87

@patch('requests.post')
def test_get_prediction_api_error(mock_post):
    mock_post.return_value.raise_for_status.side_effect = requests.HTTPError("500")

    with pytest.raises(requests.HTTPError):
        get_prediction('http://model-api/predict', [1.0, 2.0, 3.0])
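Putting side_effect on the call itself, rather than on raise_for_status, simulates a network-level failure where no response object ever exists. A sketch, repeating get_prediction from above so it runs standalone:

```python
import requests
from unittest.mock import patch

import pytest

def get_prediction(api_url, features):
    # Same function as above, repeated so this sketch runs standalone
    response = requests.post(api_url, json={'features': features})
    response.raise_for_status()
    return response.json()['prediction']

@patch('requests.post')
def test_get_prediction_timeout(mock_post):
    # side_effect on the mocked call raises before any response exists,
    # modeling a connection timeout rather than an HTTP error status
    mock_post.side_effect = requests.Timeout("connection timed out")

    with pytest.raises(requests.Timeout):
        get_prediction('http://model-api/predict', [1.0, 2.0, 3.0])
```

This distinction matters in practice: HTTP errors and transport errors often deserve different handling (retry vs. fail fast), so test both paths.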

When NOT to Mock

Not everything should be mocked. Integration tests should use real services (or at least realistic test doubles). Over-mocking can hide real integration issues. Use mocks for unit tests and real services for integration tests. If you mock everything, you only prove that your mocks work, not that your code works with real services.
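One common way to keep both kinds of tests in one suite is a pytest marker that skips integration tests unless the environment opts in. The RUN_INTEGRATION variable and marker name below are a convention sketch, not a built-in pytest feature:

```python
import os

import pytest

# Convention sketch: mark tests that need real services, then skip them
# unless the environment opts in (e.g. CI with credentials configured)
requires_services = pytest.mark.skipif(
    os.environ.get('RUN_INTEGRATION') != '1',
    reason='set RUN_INTEGRATION=1 to run against real services',
)

@requires_services
def test_real_s3_roundtrip():
    # Would upload and re-download a small parquet file against real S3
    ...
```

Unit tests with mocks then run on every commit, while the marked integration tests run on a schedule or before release, where their slowness and flakiness cost less.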

Anti-pattern: Mocking the thing you are testing. If you mock a function and then assert that the mock was called, you have not tested anything meaningful. Mock the dependencies, not the system under test.