Performance Optimization
Pipeline performance is the difference between a pipeline that runs in minutes and one that takes hours. These 5 challenges cover the optimization techniques that senior data engineers use daily and interviewers test to separate mid-level from senior candidates.
Challenge 1: Chunked Processing
import sys
from typing import Callable, Generator, List, Dict, Any
class ChunkedProcessor:
    """Process large datasets in memory-efficient chunks.

    Handles the key challenge of chunked processing: correctly aggregating
    results when records belonging to the same group arrive in different
    chunks.
    """

    def __init__(self, chunk_size=1000):
        self.chunk_size = chunk_size
        # Observability counters; peak_memory_records tracks the largest
        # single chunk ever held in memory at once.
        self.stats = {'chunks_processed': 0, 'records_processed': 0,
                      'peak_memory_records': 0}

    def chunk_generator(self, data) -> Generator:
        """Yield data in chunks of chunk_size, updating stats as we go."""
        offset = 0
        while offset < len(data):
            chunk = data[offset:offset + self.chunk_size]
            offset += self.chunk_size
            self.stats['chunks_processed'] += 1
            self.stats['records_processed'] += len(chunk)
            if len(chunk) > self.stats['peak_memory_records']:
                self.stats['peak_memory_records'] = len(chunk)
            yield chunk

    def process_with_aggregation(self, data, group_key, transform_fn, agg_fn):
        """Process data in chunks with correct cross-chunk aggregation.

        This is the hard part: if records with the same group_key appear
        in different chunks, their partial results must be merged.

        Args:
            data: Full dataset (list of dicts)
            group_key: Field to group by
            transform_fn: Function applied to each record
            agg_fn: Function that merges two partial aggregates

        Returns:
            Dict mapping each group key value to its final aggregate.
        """
        partial_results = {}  # group_key_value -> running partial aggregate
        for chunk in self.chunk_generator(data):
            # Transform and group within this chunk in a single pass.
            chunk_groups = {}
            for raw in chunk:
                record = transform_fn(raw)
                chunk_groups.setdefault(record[group_key], []).append(record)
            # Fold the chunk-local aggregates into the running totals.
            for key, group_records in chunk_groups.items():
                chunk_agg = self._reduce_group(group_records)
                prior = partial_results.get(key)
                partial_results[key] = (chunk_agg if prior is None
                                        else agg_fn(prior, chunk_agg))
        return partial_results

    def _reduce_group(self, records):
        """Compute count/sum/min/max over the 'amount' field of a group."""
        amounts = [r.get('amount', 0) for r in records]
        agg = {'count': len(records), 'sum': sum(amounts)}
        if amounts:
            agg['min'] = min(amounts)
            agg['max'] = max(amounts)
        else:
            agg['min'] = agg['max'] = 0
        return agg
# Merge function for combining partial aggregates
def merge_aggregates(a, b):
    """Merge two partial aggregation results (count/sum/min/max dicts)."""
    smaller = a['min'] if a['min'] <= b['min'] else b['min']
    larger = a['max'] if a['max'] >= b['max'] else b['max']
    return {
        'count': a['count'] + b['count'],
        'sum': a['sum'] + b['sum'],
        'min': smaller,
        'max': larger,
    }
# Generate test data (simulating a large dataset)
import random
random.seed(42)
data = []
for i in range(50):  # In production, this would be millions
    # randint before choice: call order keeps the seeded stream reproducible
    data.append({'user_id': f'user_{i % 10}',
                 'amount': random.randint(10, 500),
                 'category': random.choice(['A', 'B', 'C'])})
# Process in chunks of 8 (small for demo)
processor = ChunkedProcessor(chunk_size=8)
results = processor.process_with_aggregation(
    data,
    group_key='user_id',
    transform_fn=lambda r: r,  # No transformation in this example
    agg_fn=merge_aggregates,
)
print(f"Processing stats: {processor.stats}")
print(f"\nAggregated results ({len(results)} groups):")
for user in sorted(results):
    agg = results[user]
    avg = agg['sum'] / agg['count']
    print(f" {user}: count={agg['count']}, sum={agg['sum']}, "
          f"avg={avg:.0f}, range=[{agg['min']}, {agg['max']}]")
Complexity: O(n) total time, O(chunk_size + g) memory where g is the total number of unique groups — the chunk buffer is bounded, but the running partial aggregates persist across chunks. The merge step is O(g_c) per chunk, where g_c is the number of unique groups within that chunk. This is the core pattern behind MapReduce and Spark's reduceByKey.
Challenge 2: Parallel Pipeline Execution
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError
import time
from collections import defaultdict
class ParallelPipelineExecutor:
    """Execute pipeline stages in parallel where dependencies allow.

    Uses a DAG (directed acyclic graph) to determine which stages can
    run concurrently. Stages that fail or time out are recorded with a
    None result so the pipeline keeps making progress and the returned
    log stays consistent.
    """

    def __init__(self, max_workers=4, default_timeout=30):
        """
        Args:
            max_workers: Thread-pool size; caps concurrent stages.
            default_timeout: Seconds to wait for each wave of parallel
                stages before marking the stragglers as timed out.
        """
        self.max_workers = max_workers
        self.default_timeout = default_timeout
        self.stages = {}        # name -> callable(dep_results) -> result
        self.dependencies = {}  # name -> [dependency_names]
        self.results = {}       # name -> stage output (None on failure)
        self.execution_log = []

    def add_stage(self, name, func, depends_on=None):
        """Register a pipeline stage with optional dependencies.

        Returns self so registrations can be chained fluently.
        """
        self.stages[name] = func
        self.dependencies[name] = depends_on or []
        return self

    def _get_ready_stages(self, completed):
        """Find stages whose dependencies are all completed."""
        ready = []
        for name in self.stages:
            if name in completed:
                continue
            if all(dep in completed for dep in self.dependencies[name]):
                ready.append(name)
        return ready

    def execute(self, initial_data=None):
        """Execute all stages respecting dependencies, maximizing parallelism.

        Returns:
            Dict with 'results' (per-stage outputs), 'log' (per-stage
            status entries), and 'total_time_ms'.

        Raises:
            RuntimeError: if no stage can make progress (dependency cycle
                or a dependency on an unregistered stage).
        """
        completed = set()
        self.results = {}
        self.execution_log = []
        start_time = time.time()
        if initial_data is not None:
            self.results['__input__'] = initial_data
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            while len(completed) < len(self.stages):
                ready = self._get_ready_stages(completed)
                if not ready:
                    # No runnable stage but work remains: the DAG has a
                    # cycle or depends on an unregistered stage.
                    remaining = set(self.stages.keys()) - completed
                    raise RuntimeError(
                        f"Deadlock detected. Remaining stages: {remaining}"
                    )
                # Submit all ready stages in parallel.
                futures = {}
                for stage_name in ready:
                    # Gather dependency outputs as this stage's inputs.
                    dep_results = {
                        dep: self.results[dep]
                        for dep in self.dependencies[stage_name]
                    }
                    # Root stages receive the pipeline input, if any.
                    if not dep_results and initial_data is not None:
                        dep_results = {'__input__': initial_data}
                    stage_start = time.time()
                    future = executor.submit(self.stages[stage_name], dep_results)
                    futures[future] = (stage_name, stage_start)
                # Wait for this wave to finish. BUGFIX: the original let
                # as_completed's TimeoutError escape execute(), leaving
                # results/log inconsistent; now stragglers are recorded as
                # timed out and the pipeline continues.
                try:
                    for future in as_completed(futures, timeout=self.default_timeout):
                        stage_name, stage_start = futures[future]
                        self._record_outcome(future, stage_name, stage_start, completed)
                except TimeoutError:
                    for future, (stage_name, _) in futures.items():
                        if stage_name in completed:
                            continue
                        future.cancel()
                        completed.add(stage_name)
                        self.results[stage_name] = None
                        self.execution_log.append({
                            'stage': stage_name,
                            'status': 'timeout',
                            'error': f'exceeded {self.default_timeout}s'
                        })
        total_time = time.time() - start_time
        return {
            'results': self.results,
            'log': self.execution_log,
            'total_time_ms': round(total_time * 1000, 1)
        }

    def _record_outcome(self, future, stage_name, stage_start, completed):
        """Store one finished stage's result (or error) and log it."""
        try:
            result = future.result()
            duration = time.time() - stage_start
            self.results[stage_name] = result
            completed.add(stage_name)
            self.execution_log.append({
                'stage': stage_name,
                'status': 'success',
                'duration_ms': round(duration * 1000, 1)
            })
        except Exception as e:
            # A failed stage does not abort the pipeline: its dependents
            # see None and may fail themselves, which is also recorded.
            completed.add(stage_name)
            self.results[stage_name] = None
            self.execution_log.append({
                'stage': stage_name,
                'status': 'error',
                'error': str(e)
            })
# Test: Pipeline with parallel stages (3 worker threads available)
pipeline = ParallelPipelineExecutor(max_workers=3)
# Stage functions: each receives a dict mapping dependency name -> its result
def extract_users(inputs):
    """Extract stage: return the user rows (I/O latency simulated)."""
    rows = [{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}]
    time.sleep(0.1)  # Simulate I/O
    return rows
def extract_orders(inputs):
    """Extract stage: return the order rows (I/O latency simulated)."""
    rows = [{'user_id': 1, 'amount': 100}, {'user_id': 2, 'amount': 200}]
    time.sleep(0.1)  # Simulate I/O
    return rows
def extract_products(inputs):
    """Extract stage: return the product rows (I/O latency simulated)."""
    rows = [{'id': 'p1', 'name': 'Widget', 'price': 50}]
    time.sleep(0.1)  # Simulate I/O
    return rows
def transform_join(inputs):
    """Join users and orders - depends on both extract stages.

    Unknown user_ids get user_name 'Unknown' rather than being dropped.
    """
    user_lookup = {}
    for user in inputs['extract_users']:
        user_lookup[user['id']] = user
    joined = []
    for order in inputs['extract_orders']:
        enriched = dict(order)
        owner = user_lookup.get(order['user_id'], {})
        enriched['user_name'] = owner.get('name', 'Unknown')
        joined.append(enriched)
    return joined
def load_results(inputs):
    """Final load - depends on the join and the product extract."""
    return {
        'loaded_records': len(inputs['transform_join']),
        'product_count': len(inputs['extract_products']),
    }
# Build the DAG: three independent extracts feed a join, which feeds the load
pipeline.add_stage('extract_users', extract_users)
pipeline.add_stage('extract_orders', extract_orders)
pipeline.add_stage('extract_products', extract_products)
pipeline.add_stage('transform_join', transform_join,
                   depends_on=['extract_users', 'extract_orders'])
pipeline.add_stage('load_results', load_results,
                   depends_on=['transform_join', 'extract_products'])
result = pipeline.execute()
print(f"Total time: {result['total_time_ms']}ms")
print(f"\nExecution log:")
for entry in result['log']:
    status = entry['status']
    duration = entry.get('duration_ms', 'N/A')
    print(f" {entry['stage']}: {status} ({duration}ms)")
print(f"\nFinal result: {result['results']['load_results']}")
Complexity: Total time is the critical path through the DAG (longest chain of dependent stages), not the sum of all stages. With w workers and a balanced DAG, speedup approaches min(w, max_parallelism).
Challenge 3: Memory-Efficient Operations
import heapq
import hashlib
import math
class StreamingTopK:
    """Memory-efficient top-K tracking using a min-heap.

    Maintains only K elements in memory regardless of stream size.

    BUGFIX: the original stored (key, item) tuples in the heap, so equal
    keys made heapq compare the items themselves — a TypeError for
    un-orderable items such as dicts. A monotonically increasing insertion
    counter now breaks ties so items are never compared directly.
    """

    def __init__(self, k, key_func=None):
        """
        Args:
            k: Number of top elements to retain.
            key_func: Optional function extracting the sort key from an
                item; defaults to the identity function.
        """
        self.k = k
        self.key_func = key_func or (lambda x: x)
        self.heap = []  # Min-heap of (key, insertion_index, item) tuples
        self._counter = 0  # Tie-breaker; never lets heapq compare items

    def add(self, item):
        """Add item to top-K tracker. O(log K) per item."""
        key = self.key_func(item)
        entry = (key, self._counter, item)
        self._counter += 1
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, entry)
        elif key > self.heap[0][0]:
            # New item beats the smallest of the current top K.
            heapq.heapreplace(self.heap, entry)

    def get_top_k(self):
        """Return top K items sorted by key descending."""
        ordered = sorted(self.heap, key=lambda e: (e[0], e[1]), reverse=True)
        return [item for _, _, item in ordered]
class BloomFilter:
    """Probabilistic set membership for memory-efficient distinct counting.

    Uses roughly 10 bits per element for a 1% false-positive rate.
    Lookups may report false positives but never false negatives.
    """

    def __init__(self, expected_items, false_positive_rate=0.01):
        self.size = self._optimal_size(expected_items, false_positive_rate)
        self.num_hashes = self._optimal_hashes(self.size, expected_items)
        self.bit_array = [0] * self.size
        self.count = 0  # Estimated number of distinct items added

    def _optimal_size(self, n, p):
        # Textbook optimum: m = -n * ln(p) / (ln 2)^2
        return int(-n * math.log(p) / (math.log(2) ** 2))

    def _optimal_hashes(self, m, n):
        # Textbook optimum: k = (m / n) * ln 2, with at least one hash
        return max(1, int(m / n * math.log(2)))

    def _hash(self, item, seed):
        digest = hashlib.md5(f"{item}:{seed}".encode()).hexdigest()
        return int(digest, 16) % self.size

    def add(self, item):
        """Add item; bump the distinct estimate if it looked new."""
        seen_before = self.might_contain(item)
        for seed in range(self.num_hashes):
            self.bit_array[self._hash(item, seed)] = 1
        if not seen_before:
            self.count += 1

    def might_contain(self, item):
        """Check if item might be in the set (may have false positives)."""
        return all(self.bit_array[self._hash(item, seed)]
                   for seed in range(self.num_hashes))
class ExternalMergeSort:
    """Simulate external merge sort for data larger than memory.

    Splits data into sorted runs that fit in memory, then k-way merges
    them with a heap.
    """

    def __init__(self, memory_limit_records):
        self.memory_limit = memory_limit_records
        self.stats = {'chunks_created': 0, 'merge_passes': 0}

    def sort(self, data, key_func=None):
        """Sort data using the external merge sort pattern.

        In production the sorted runs would be spilled to disk/S3; here
        the pattern is simulated entirely in memory for demonstration.
        """
        if key_func is None:
            key_func = lambda x: x
        # Phase 1: build sorted runs no larger than the memory limit.
        runs = []
        start = 0
        while start < len(data):
            runs.append(sorted(data[start:start + self.memory_limit],
                               key=key_func))
            self.stats['chunks_created'] += 1
            start += self.memory_limit
        # Phase 2: single k-way merge of all runs via a heap.
        self.stats['merge_passes'] += 1
        return list(heapq.merge(*runs, key=key_func))
# Test all three memory-efficient structures
import random
random.seed(42)
# 1. Streaming Top-K
print("=== Streaming Top-K ===")
data = []
for i in range(100):
    # one randint per record, matching the seeded stream
    data.append({'name': f'item_{i}', 'score': random.randint(1, 1000)})
topk = StreamingTopK(k=5, key_func=lambda x: x['score'])
for record in data:
    topk.add(record)
print(f"Top 5 from {len(data)} items (only 5 in memory):")
for winner in topk.get_top_k():
    print(f" {winner['name']}: score={winner['score']}")
# 2. Bloom Filter for approximate distinct
print("\n=== Bloom Filter Distinct ===")
items = [f"user_{random.randint(1, 50)}" for _ in range(200)]
bloom = BloomFilter(expected_items=100)
for name in items:
    bloom.add(name)
actual_distinct = len(set(items))
print(f"Actual distinct: {actual_distinct}")
print(f"Bloom estimate: {bloom.count}")
print(f"Bloom filter size: {bloom.size} bits ({bloom.size / 8:.0f} bytes)")
print(f"vs storing all strings: ~{actual_distinct * 20} bytes")
# 3. External Merge Sort
print("\n=== External Merge Sort ===")
data = [random.randint(1, 10000) for _ in range(50)]
sorter = ExternalMergeSort(memory_limit_records=10)
sorted_data = sorter.sort(data)
print(f"Sorted {len(data)} records using {sorter.stats['chunks_created']} chunks")
print(f"First 10: {sorted_data[:10]}")
print(f"Correctly sorted: {sorted_data == sorted(data)}")
Complexity: Top-K: O(n log K) time, O(K) space. Bloom Filter: O(k) per operation where k is hash count, O(m) space where m is bit array size. External Merge Sort: O(n log n) time, O(memory_limit) space per pass.
Challenge 4: Caching for Pipeline Performance
import time
from collections import OrderedDict
class PipelineCache:
    """Multi-level cache with LRU eviction and TTL expiry.

    Designed for data pipelines where the same lookups (e.g., dimension
    table joins, feature store reads) happen repeatedly.
    """

    def __init__(self, max_size=1000, ttl_seconds=300):
        self.max_size = max_size
        self.ttl = ttl_seconds
        self.cache = OrderedDict()  # key -> (value, expiry_time)
        self.stats = {
            'hits': 0, 'misses': 0, 'evictions': 0,
            'expired': 0, 'sets': 0
        }

    def get(self, key):
        """Get a value from cache. Returns (value, hit) tuple."""
        entry = self.cache.get(key)
        if entry is not None:
            value, expires_at = entry
            if time.time() < expires_at:
                self.cache.move_to_end(key)  # mark most recently used
                self.stats['hits'] += 1
                return value, True
            # Stale entry: drop it and fall through to a miss.
            del self.cache[key]
            self.stats['expired'] += 1
        self.stats['misses'] += 1
        return None, False

    def set(self, key, value, ttl=None):
        """Set a value in cache with optional custom TTL."""
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)  # evict least recently used
            self.stats['evictions'] += 1
        self.cache[key] = (value, time.time() + (ttl or self.ttl))
        self.stats['sets'] += 1

    def batch_get(self, keys):
        """Batch lookup - returns (found_dict, missing_keys)."""
        found, missing = {}, []
        for key in keys:
            value, hit = self.get(key)
            if hit:
                found[key] = value
            else:
                missing.append(key)
        return found, missing

    def warm(self, data_dict, ttl=None):
        """Pre-populate cache with known values."""
        for key in data_dict:
            self.set(key, data_dict[key], ttl)

    @property
    def hit_rate(self):
        # Fraction of get() calls that were served from cache.
        attempts = self.stats['hits'] + self.stats['misses']
        if attempts == 0:
            return 0
        return self.stats['hits'] / attempts

    def get_stats(self):
        snapshot = dict(self.stats)
        snapshot['hit_rate'] = f"{self.hit_rate:.1%}"
        snapshot['size'] = len(self.cache)
        return snapshot
def cached_lookup(cache, keys, fetch_fn):
    """Pattern: check cache first, fetch missing, update cache.

    This is the standard read-through caching pattern used in pipelines.
    """
    found, missing = cache.batch_get(keys)
    if not missing:
        return found
    # Fetch only the cache misses from the source (e.g., database, API).
    for key, value in fetch_fn(missing).items():
        cache.set(key, value)
        found[key] = value
    return found
# Test: enrich an event stream against a 5-entry cache with 10s TTL
cache = PipelineCache(max_size=5, ttl_seconds=10)
# Simulate a dimension table lookup (e.g., user_id -> user details)
dimension_table = {
    'u1': {'name': 'Alice', 'tier': 'gold'},
    'u2': {'name': 'Bob', 'tier': 'silver'},
    'u3': {'name': 'Charlie', 'tier': 'bronze'},
    'u4': {'name': 'Diana', 'tier': 'gold'},
    'u5': {'name': 'Eve', 'tier': 'silver'},
    'u6': {'name': 'Frank', 'tier': 'bronze'},
}
def fetch_from_db(keys):
    """Simulate database lookup."""
    found = {}
    for k in keys:
        if k in dimension_table:
            found[k] = dimension_table[k]
    return found
# Process a stream of events that need dimension enrichment
events = [
    {'user_id': 'u1', 'action': 'click'},
    {'user_id': 'u2', 'action': 'view'},
    {'user_id': 'u1', 'action': 'buy'},  # Cache hit
    {'user_id': 'u3', 'action': 'click'},
    {'user_id': 'u1', 'action': 'view'},  # Cache hit
    {'user_id': 'u4', 'action': 'buy'},
    {'user_id': 'u5', 'action': 'click'},
    {'user_id': 'u6', 'action': 'view'},  # Causes eviction
    {'user_id': 'u1', 'action': 'click'},  # May be evicted
    {'user_id': 'u2', 'action': 'buy'},  # May be evicted
]
for event in events:
    uid = event['user_id']
    user_info = cached_lookup(cache, [uid], fetch_from_db).get(uid, {})
    tier = user_info.get('tier', '?')
    print(f" {uid} ({tier}): {event['action']}")
print(f"\nCache stats: {cache.get_stats()}")
Complexity: O(1) amortized for get/set operations. Batch get is O(b) where b is batch size. The LRU eviction policy ensures bounded memory usage regardless of stream length.
Challenge 5: Batch vs Streaming Trade-offs
import time
from collections import defaultdict
class BatchAggregator:
    """Batch processing: collect all data, then aggregate.

    Pros: Simple, efficient for full-dataset operations, exact results.
    Cons: High latency (must wait for all data), high memory (all in memory).
    """

    def process(self, records, group_key, value_key):
        """Group records and compute count/sum/avg/min/max per group."""
        started = time.time()
        groups = defaultdict(list)
        for record in records:
            groups[record[group_key]].append(record[value_key])
        results = {}
        for key, values in groups.items():
            total = sum(values)
            results[key] = {
                'count': len(values),
                'sum': total,
                'avg': total / len(values),
                'min': min(values),
                'max': max(values),
            }
        return {
            'results': results,
            'latency_ms': round((time.time() - started) * 1000, 2),
            'memory_records': len(records),  # entire dataset held in memory
            'mode': 'batch'
        }
class StreamingAggregator:
    """Streaming processing: update aggregates incrementally per event.

    Pros: Low latency (results available immediately), constant memory.
    Cons: Some operations are harder (median, percentile), state management.
    """

    def __init__(self):
        # Per-group running aggregate; min/max start at their identities.
        self.state = defaultdict(lambda: {
            'count': 0, 'sum': 0, 'min': float('inf'), 'max': float('-inf')
        })
        self.events_processed = 0

    def process_event(self, record, group_key, value_key):
        """Fold one event into its group's running aggregate. O(1)."""
        value = record[value_key]
        agg = self.state[record[group_key]]
        agg['count'] += 1
        agg['sum'] += value
        if value < agg['min']:
            agg['min'] = value
        if value > agg['max']:
            agg['max'] = value
        self.events_processed += 1

    def get_results(self):
        """Snapshot of all group aggregates with derived averages."""
        snapshot = {}
        for key, agg in self.state.items():
            avg = agg['sum'] / agg['count'] if agg['count'] > 0 else 0
            snapshot[key] = {**agg, 'avg': avg}
        return snapshot
class MicroBatchAggregator:
    """Micro-batch: collect small batches, process periodically.

    Balances latency and throughput. Used by Spark Structured Streaming.
    Pros: Better throughput than per-event, lower latency than full batch.
    Cons: Tuning batch interval is critical.
    """

    def __init__(self, batch_interval_records=10):
        self.batch_interval = batch_interval_records
        self.buffer = []  # events awaiting the next flush
        self.state = defaultdict(lambda: {
            'count': 0, 'sum': 0, 'min': float('inf'), 'max': float('-inf')
        })
        self.batches_processed = 0

    def add_event(self, record, group_key, value_key):
        """Buffer one event; flush when the batch interval is reached."""
        self.buffer.append(record)
        if len(self.buffer) >= self.batch_interval:
            self._flush_batch(group_key, value_key)

    def _flush_batch(self, group_key, value_key):
        """Fold the buffered micro-batch into the running state."""
        for record in self.buffer:
            value = record[value_key]
            agg = self.state[record[group_key]]
            agg['count'] += 1
            agg['sum'] += value
            agg['min'] = value if value < agg['min'] else agg['min']
            agg['max'] = value if value > agg['max'] else agg['max']
        self.buffer = []
        self.batches_processed += 1

    def finalize(self, group_key, value_key):
        """Flush any partial batch still sitting in the buffer."""
        if self.buffer:
            self._flush_batch(group_key, value_key)

    def get_results(self):
        """Snapshot of all group aggregates with derived averages."""
        snapshot = {}
        for key, agg in self.state.items():
            avg = agg['sum'] / agg['count'] if agg['count'] > 0 else 0
            snapshot[key] = {**agg, 'avg': avg}
        return snapshot
# Compare all three approaches on the same synthetic dataset
import random
random.seed(42)
records = []
for _ in range(100):
    # user drawn first, then amount: matches the seeded stream order
    uid = random.randint(1, 5)
    records.append({'user': f'user_{uid}', 'amount': random.randint(10, 500)})
# 1. Batch
batch = BatchAggregator()
batch_result = batch.process(records, 'user', 'amount')
# 2. Streaming
stream = StreamingAggregator()
stream_start = time.time()
for record in records:
    stream.process_event(record, 'user', 'amount')
stream_latency = round((time.time() - stream_start) * 1000, 2)
stream_results = stream.get_results()
# 3. Micro-batch
micro = MicroBatchAggregator(batch_interval_records=10)
micro_start = time.time()
for record in records:
    micro.add_event(record, 'user', 'amount')
micro.finalize('user', 'amount')
micro_latency = round((time.time() - micro_start) * 1000, 2)
micro_results = micro.get_results()
# Comparison table
print("=== Comparison: Batch vs Streaming vs Micro-Batch ===\n")
print(f"{'Metric':<25} {'Batch':<15} {'Streaming':<15} {'Micro-Batch':<15}")
print('-' * 70)
print(f"{'Latency (ms)':<25} {batch_result['latency_ms']:<15} {stream_latency:<15} {micro_latency:<15}")
print(f"{'Memory (records)':<25} {batch_result['memory_records']:<15} {'1 (constant)':<15} {micro.batch_interval:<15}")
print(f"{'Batches':<25} {'1':<15} {'N/A':<15} {micro.batches_processed:<15}")
print(f"\nResults match: ", end="")
# Verify all three produce the same counts and sums per user
for user in batch_result['results']:
    b_agg = batch_result['results'][user]
    s_agg = stream_results[user]
    m_agg = micro_results[user]
    assert b_agg['count'] == s_agg['count'] == m_agg['count']
    assert b_agg['sum'] == s_agg['sum'] == m_agg['sum']
print("All three approaches produce identical results.")
Complexity: All three are O(n) in time. Batch uses O(n) memory, streaming uses O(g) where g is unique groups, micro-batch uses O(b + g) where b is batch size. The trade-off is latency vs throughput vs simplicity.
Lilly Tech Systems