Advanced

Performance Optimization

Pipeline performance is the difference between a pipeline that runs in minutes and one that takes hours. These 5 challenges cover the optimization techniques that senior data engineers use daily and interviewers test to separate mid-level from senior candidates.

Challenge 1: Chunked Processing

📝
Problem: Implement a chunked processing framework that reads large datasets in configurable chunks, applies transformations to each chunk, aggregates results across chunks, and handles chunk-boundary edge cases (e.g., a group split across two chunks).
import sys
from typing import Callable, Generator, List, Dict, Any

class ChunkedProcessor:
    """Process large datasets in memory-efficient chunks.

    Handles the key challenge of chunked processing: correctly aggregating
    results when groups may span multiple chunks. Only one chunk of raw
    records is materialized at a time. One small partial aggregate per
    distinct group is kept for the whole run.
    """

    def __init__(self, chunk_size=1000):
        """Create a processor that reads *chunk_size* records per chunk.

        Raises:
            ValueError: if chunk_size is smaller than 1 (no progress possible).
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")
        self.chunk_size = chunk_size
        # Cumulative counters; they are not reset between calls.
        self.stats = {'chunks_processed': 0, 'records_processed': 0,
                      'peak_memory_records': 0}

    def chunk_generator(self, data) -> Generator:
        """Yield lists of up to chunk_size records from any iterable.

        Lists work, and so do lazy sources (generators, file readers, DB
        cursors). The input never needs to support len() or slicing, so the
        full dataset is not required to fit in memory.
        """
        from itertools import islice

        iterator = iter(data)
        while True:
            chunk = list(islice(iterator, self.chunk_size))
            if not chunk:
                return
            self.stats['chunks_processed'] += 1
            self.stats['records_processed'] += len(chunk)
            self.stats['peak_memory_records'] = max(
                self.stats['peak_memory_records'], len(chunk)
            )
            yield chunk

    def process_with_aggregation(self, data, group_key, transform_fn, agg_fn,
                                 value_key='amount'):
        """Process data in chunks with correct cross-chunk aggregation.

        Records sharing a group_key value may land in different chunks. Each
        chunk is reduced per group, and that partial result is merged into the
        running aggregate with agg_fn.

        Args:
            data: Any iterable of dict records (list, generator, ...).
            group_key: Field to group by (read after transform_fn).
            transform_fn: Function applied to each record.
            agg_fn: Function that merges two partial aggregates.
            value_key: Numeric field summarized per group (default 'amount';
                missing values count as 0).

        Returns:
            dict mapping group value -> {'count', 'sum', 'min', 'max'} as
            produced by _reduce_group and merged by agg_fn.
        """
        partial_results = {}  # group_key_value -> partial aggregate

        for chunk in self.chunk_generator(data):
            # Group the transformed records of this chunk only.
            chunk_groups = {}
            for record in map(transform_fn, chunk):
                chunk_groups.setdefault(record[group_key], []).append(record)

            # Merge this chunk's per-group results into the running aggregates.
            for key, group_records in chunk_groups.items():
                chunk_agg = self._reduce_group(group_records, value_key)
                if key in partial_results:
                    partial_results[key] = agg_fn(partial_results[key], chunk_agg)
                else:
                    partial_results[key] = chunk_agg

        return partial_results

    def _reduce_group(self, records, value_key='amount'):
        """Compute count/sum/min/max of *value_key* for one group's records.

        Records lacking value_key contribute 0. An empty group yields 0 for
        min and max.
        """
        values = [r.get(value_key, 0) for r in records]
        return {
            'count': len(records),
            'sum': sum(values),
            'min': min(values) if values else 0,
            'max': max(values) if values else 0,
        }

# Merge function for combining partial aggregates
def merge_aggregates(a, b):
    """Combine two partial aggregates of the same group into one.

    Counts and sums add up; min and max keep the more extreme of the two
    sides. Every step is associative and commutative, so chunk order does
    not matter.
    """
    combined = {}
    combined['count'] = a['count'] + b['count']
    combined['sum'] = a['sum'] + b['sum']
    combined['min'] = min(a['min'], b['min'])
    combined['max'] = max(a['max'], b['max'])
    return combined

# Generate test data (simulating a large dataset)
import random
random.seed(42)
data = list(
    {
        'user_id': f'user_{i % 10}',
        'amount': random.randint(10, 500),
        'category': random.choice(['A', 'B', 'C']),
    }
    for i in range(50)  # In production, this would be millions
)

# Chunks of 8 (small for demo) so the same user_id straddles chunk boundaries
processor = ChunkedProcessor(chunk_size=8)
results = processor.process_with_aggregation(
    data,
    group_key='user_id',
    transform_fn=lambda r: r,  # identity: no per-record transformation here
    agg_fn=merge_aggregates,
)

print(f"Processing stats: {processor.stats}")
print(f"\nAggregated results ({len(results)} groups):")
for user in sorted(results):
    agg = results[user]
    avg = agg['sum'] / agg['count']
    print(f"  {user}: count={agg['count']}, sum={agg['sum']}, "
          f"avg={avg:.0f}, range=[{agg['min']}, {agg['max']}]")

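Complexity: O(n) total time. Memory is O(chunk_size + G), where G is the number of distinct groups. Only one chunk of raw records is held at a time, but a partial aggregate is kept for every group seen so far. The merge step is O(g) per chunk, where g is the number of unique groups in that chunk. This is the core pattern behind MapReduce and Spark's reduceByKey.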

Challenge 2: Parallel Pipeline Execution

📝
Problem: Implement a parallel pipeline executor that runs independent pipeline stages concurrently using a thread pool. Support dependency declaration between stages, automatic parallelization of independent stages, and timeout handling.
from concurrent.futures import ThreadPoolExecutor, as_completed, TimeoutError
import time
from collections import defaultdict

class ParallelPipelineExecutor:
    """Execute pipeline stages in parallel where dependencies allow.

    Stages form a DAG (directed acyclic graph). A stage is submitted to the
    thread pool as soon as every one of its dependencies has finished, rather
    than in lock-step "waves". A slow stage therefore never delays work that
    does not depend on it.

    Each stage function receives a dict mapping dependency name -> that
    dependency's result. A root stage (no dependencies) receives
    ``{'__input__': initial_data}`` when ``initial_data`` is given, otherwise
    ``{}``.
    """

    def __init__(self, max_workers=4, default_timeout=30):
        self.max_workers = max_workers
        # Per-stage limit in seconds, counted from submission; None = no limit.
        self.default_timeout = default_timeout
        self.stages = {}          # name -> function
        self.dependencies = {}    # name -> [dependency_names]
        self.results = {}
        self.execution_log = []

    def add_stage(self, name, func, depends_on=None):
        """Register a pipeline stage with optional dependencies; returns self for chaining."""
        self.stages[name] = func
        # Copy so later mutation of the caller's list cannot change the DAG.
        self.dependencies[name] = list(depends_on or [])
        return self

    def _get_ready_stages(self, completed):
        """Return stages not in *completed* whose dependencies are all in it."""
        return [
            name for name in self.stages
            if name not in completed
            and all(dep in completed for dep in self.dependencies[name])
        ]

    @staticmethod
    def _run_timed(func, inputs):
        """Run one stage in a worker thread and time only its own execution.

        Returns (result, exception, seconds). Timing here rather than at
        submission excludes time spent queued behind busy workers.
        """
        started = time.time()
        try:
            result = func(inputs)
        except Exception as exc:  # reported in the execution log, not raised
            return None, exc, time.time() - started
        return result, None, time.time() - started

    def execute(self, initial_data=None):
        """Execute all stages respecting dependencies and maximizing parallelism.

        A stage that raises is logged with status 'error' and result None.
        Its dependents are not run: they are logged as 'skipped' with result
        None, transitively.

        Returns:
            dict with 'results' (stage -> output, plus '__input__' when
            initial_data is given), 'log' (one entry per stage) and
            'total_time_ms'.

        Raises:
            RuntimeError: on a dependency cycle or a dependency on an
                unregistered stage.
            TimeoutError: (the concurrent.futures one) if a stage runs longer
                than default_timeout. Python cannot kill threads, so the hung
                worker keeps running in the background, but this call returns
                immediately instead of blocking on it.
        """
        from concurrent.futures import FIRST_COMPLETED, wait

        self.results = {}
        self.execution_log = []
        start_time = time.time()

        if initial_data is not None:
            self.results['__input__'] = initial_data

        unknown = {dep for deps in self.dependencies.values()
                   for dep in deps if dep not in self.stages}
        if unknown:
            raise RuntimeError(f"Stages depend on unregistered stage(s): {unknown}")

        finished = set()   # succeeded, errored, or skipped
        failed = set()     # errored or skipped: dependents must not run
        running = {}       # future -> (stage_name, submit_time)
        timed_out = False
        executor = ThreadPoolExecutor(max_workers=self.max_workers)
        try:
            while len(finished) < len(self.stages):
                in_flight = {name for name, _ in running.values()}
                skipped_any = False
                for stage_name in self._get_ready_stages(finished):
                    if stage_name in in_flight:
                        continue
                    deps = self.dependencies[stage_name]
                    broken = [dep for dep in deps if dep in failed]
                    if broken:
                        # Upstream produced no usable output; running this stage
                        # would only fail on a None input.
                        finished.add(stage_name)
                        failed.add(stage_name)
                        self.results[stage_name] = None
                        self.execution_log.append({
                            'stage': stage_name,
                            'status': 'skipped',
                            'error': f"upstream stage(s) failed: {broken}",
                        })
                        skipped_any = True
                        continue

                    # Gather dependency outputs as inputs
                    dep_results = {dep: self.results[dep] for dep in deps}
                    if not dep_results and initial_data is not None:
                        dep_results = {'__input__': initial_data}
                    future = executor.submit(
                        self._run_timed, self.stages[stage_name], dep_results
                    )
                    running[future] = (stage_name, time.time())

                if not running:
                    if skipped_any:
                        continue  # skips may have unblocked further skips
                    remaining = set(self.stages.keys()) - finished
                    raise RuntimeError(
                        f"Deadlock detected. Remaining stages: {remaining}"
                    )

                # Wait for the next completion, but no longer than the earliest
                # per-stage deadline among the running stages.
                wait_timeout = None
                if self.default_timeout is not None:
                    oldest = min(submitted for _, submitted in running.values())
                    wait_timeout = max(0.0, oldest + self.default_timeout - time.time())
                done, _ = wait(running, timeout=wait_timeout,
                               return_when=FIRST_COMPLETED)
                if not done:
                    timed_out = True
                    for future in running:
                        future.cancel()  # drops stages still queued, if any
                    names = [name for name, _ in running.values()]
                    raise TimeoutError(
                        f"Timed out after {self.default_timeout}s waiting for "
                        f"stage(s): {names}"
                    )

                for future in done:
                    stage_name, _ = running.pop(future)
                    result, error, seconds = future.result()
                    finished.add(stage_name)
                    if error is None:
                        self.results[stage_name] = result
                        self.execution_log.append({
                            'stage': stage_name,
                            'status': 'success',
                            'duration_ms': round(seconds * 1000, 1),
                        })
                    else:
                        failed.add(stage_name)
                        self.results[stage_name] = None
                        self.execution_log.append({
                            'stage': stage_name,
                            'status': 'error',
                            'error': str(error),
                            'duration_ms': round(seconds * 1000, 1),
                        })
        finally:
            # After a timeout, don't block on hung worker threads.
            executor.shutdown(wait=not timed_out)

        total_time = time.time() - start_time
        return {
            'results': self.results,
            'log': self.execution_log,
            'total_time_ms': round(total_time * 1000, 1)
        }

# Test: executor for the example DAG below, with up to three stages in flight
pipeline = ParallelPipelineExecutor(3)  # max_workers=3

# Stage functions
def extract_users(inputs):
    """Pretend to read the users source; *inputs* is ignored."""
    time.sleep(0.1)  # stand-in for source I/O latency
    users = [
        {'id': 1, 'name': 'Alice'},
        {'id': 2, 'name': 'Bob'},
    ]
    return users

def extract_orders(inputs):
    """Pretend to read the orders source; *inputs* is ignored."""
    time.sleep(0.1)  # stand-in for source I/O latency
    orders = [
        {'user_id': 1, 'amount': 100},
        {'user_id': 2, 'amount': 200},
    ]
    return orders

def extract_products(inputs):
    """Pretend to read the product catalog; *inputs* is ignored."""
    time.sleep(0.1)  # stand-in for source I/O latency
    catalog = [{'id': 'p1', 'name': 'Widget', 'price': 50}]
    return catalog

def transform_join(inputs):
    """Attach each order's user name, looked up by the order's user_id.

    Needs the outputs of extract_users and extract_orders. Orders with no
    matching user (or a user without a name) get 'Unknown'.
    """
    user_by_id = {}
    for user in inputs['extract_users']:
        user_by_id[user['id']] = user

    joined = []
    for order in inputs['extract_orders']:
        owner = user_by_id.get(order['user_id'], {})
        joined.append({**order, 'user_name': owner.get('name', 'Unknown')})
    return joined

def load_results(inputs):
    """Final load step: report how many joined rows and products arrived.

    Needs the outputs of both transform_join and extract_products.
    """
    joined_rows = inputs['transform_join']
    catalog = inputs['extract_products']
    summary = {}
    summary['loaded_records'] = len(joined_rows)
    summary['product_count'] = len(catalog)
    return summary

# Build DAG: three independent extracts; the join needs users + orders,
# and the final load needs the join + products.
pipeline.add_stage('extract_users', extract_users)
pipeline.add_stage('extract_orders', extract_orders)
pipeline.add_stage('extract_products', extract_products)
pipeline.add_stage('transform_join', transform_join,
                   depends_on=['extract_users', 'extract_orders'])
pipeline.add_stage('load_results', load_results,
                   depends_on=['transform_join', 'extract_products'])

result = pipeline.execute()
print(f"Total time: {result['total_time_ms']}ms")
print("\nExecution log:")
for log_entry in result['log']:
    shown_ms = log_entry.get('duration_ms', 'N/A')
    print(f"  {log_entry['stage']}: {log_entry['status']} ({shown_ms}ms)")
final_output = result['results']['load_results']
print(f"\nFinal result: {final_output}")

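Complexity: Ideally, total time is bounded by the critical path through the DAG (the longest chain of dependent stages), not the sum of all stages. A scheduler that waits for every stage in a "wave" to finish before launching the next can exceed that bound, because one slow stage then delays unrelated downstream work. With w workers, the speedup over sequential execution is at most min(w, max_parallelism), where max_parallelism is the largest number of stages that can run at the same time.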

Challenge 3: Memory-Efficient Operations

📝
Problem: Implement memory-efficient versions of common data operations: streaming sort (external merge sort for data larger than memory), streaming distinct (using probabilistic data structures), and streaming top-K using a heap. Compare memory usage with naive approaches.
import heapq
import hashlib
import math

class StreamingTopK:
    """Memory-efficient top-K using a min-heap.

    Maintains only K elements in memory regardless of stream size.

    Heap entries are ``(key, -seq, item)``, where seq is a unique insertion
    counter. The counter breaks key ties, so items themselves are never
    compared, and dicts or other unorderable items are safe. Among equal keys,
    the earliest-added item is retained and listed first.
    """

    def __init__(self, k, key_func=None):
        self.k = k
        self.key_func = key_func or (lambda x: x)
        self.heap = []  # Min-heap of (key, -seq, item) tuples
        self._seq = 0   # Insertion counter, used only as a tie-breaker

    def add(self, item):
        """Add item to top-K tracker. O(log K) per item; no-op when k <= 0."""
        if self.k <= 0:
            return  # nothing can be retained (and heap[0] would not exist)
        key = self.key_func(item)
        entry = (key, -self._seq, item)
        self._seq += 1
        if len(self.heap) < self.k:
            heapq.heappush(self.heap, entry)
        elif key > self.heap[0][0]:
            # Strictly greater: an item tying the current minimum does not displace it.
            heapq.heapreplace(self.heap, entry)

    def get_top_k(self):
        """Return top K items sorted by key descending (ties: earliest added first)."""
        return [item for _, _, item in sorted(self.heap, reverse=True)]


class BloomFilter:
    """Probabilistic set membership for memory-efficient distinct counting.

    Uses ~10 bits per element for 1% false positive rate.

    ``count`` is an approximate distinct count. A new item whose bits are all
    already set (a false positive) is not counted, and a repeat is never
    counted twice, so ``count`` can undercount but never overcount.
    """

    def __init__(self, expected_items, false_positive_rate=0.01):
        """Size the filter for *expected_items* at *false_positive_rate*.

        Raises:
            ValueError: if false_positive_rate is not strictly between 0 and 1.
        """
        if not 0 < false_positive_rate < 1:
            raise ValueError(
                f"false_positive_rate must be in (0, 1), got {false_positive_rate!r}"
            )
        # Clamp so a zero/negative estimate still yields a usable filter
        # instead of a division or modulo by zero.
        expected_items = max(1, expected_items)
        self.size = self._optimal_size(expected_items, false_positive_rate)
        self.num_hashes = self._optimal_hashes(self.size, expected_items)
        # One byte per bit position: far smaller than a list of int references,
        # and still indexable/assignable as 0/1 ints.
        self.bit_array = bytearray(self.size)
        self.count = 0

    def _optimal_size(self, n, p):
        """Bit count m = -n ln(p) / (ln 2)^2, at least 1."""
        return max(1, int(-n * math.log(p) / (math.log(2) ** 2)))

    def _optimal_hashes(self, m, n):
        """Hash count k = (m / n) ln 2, at least 1."""
        return max(1, int(m / n * math.log(2)))

    def _hash(self, item, seed):
        """Bit position for *item* under hash function number *seed*."""
        h = hashlib.md5(f"{item}:{seed}".encode()).hexdigest()
        return int(h, 16) % self.size

    def _positions(self, item):
        """All num_hashes bit positions for *item* (seeds 0..num_hashes-1)."""
        return [self._hash(item, seed) for seed in range(self.num_hashes)]

    def add(self, item):
        """Add item to bloom filter, counting it if it was not already present."""
        positions = self._positions(item)  # hash once; reuse for check and set
        was_new = not all(self.bit_array[pos] for pos in positions)
        for pos in positions:
            self.bit_array[pos] = 1
        if was_new:
            self.count += 1

    def might_contain(self, item):
        """Check if item might be in the set (may have false positives)."""
        return all(self.bit_array[pos] for pos in self._positions(item))


class ExternalMergeSort:
    """Simulate external merge sort for data larger than memory.

    Splits the input into sorted runs of at most ``memory_limit_records``
    items, then k-way merges the runs with a heap. Here the runs stay in
    memory; in production each run would be written to disk/S3 and streamed
    back during the merge.
    """

    def __init__(self, memory_limit_records):
        """Create a sorter whose runs hold at most *memory_limit_records* items.

        Raises:
            ValueError: if memory_limit_records is smaller than 1.
        """
        if memory_limit_records < 1:
            raise ValueError(
                f"memory_limit_records must be >= 1, got {memory_limit_records!r}"
            )
        self.memory_limit = memory_limit_records
        # Cumulative across sort() calls.
        self.stats = {'chunks_created': 0, 'merge_passes': 0}

    def sort(self, data, key_func=None, reverse=False):
        """Sort *data* (any iterable) using the external merge sort pattern.

        Args:
            data: Iterable of records; it is consumed once, chunk by chunk, and
                need not support len() or slicing.
            key_func: Sort key (default: identity).
            reverse: Sort descending when True.

        Returns:
            A new sorted list.
        """
        from itertools import islice

        key_func = key_func or (lambda x: x)

        # Phase 1: create sorted runs of at most memory_limit records each
        stream = iter(data)
        chunks = []
        while True:
            chunk = list(islice(stream, self.memory_limit))
            if not chunk:
                break
            chunk.sort(key=key_func, reverse=reverse)
            chunks.append(chunk)
            self.stats['chunks_created'] += 1

        # Phase 2: single k-way merge using a heap
        self.stats['merge_passes'] += 1
        return list(heapq.merge(*chunks, key=key_func, reverse=reverse))

# Test all three
import random
random.seed(42)

# 1. Streaming Top-K
print("=== Streaming Top-K ===")
data = list({'name': f'item_{n}', 'score': random.randint(1, 1000)} for n in range(100))

topk = StreamingTopK(k=5, key_func=lambda rec: rec['score'])
for rec in data:
    topk.add(rec)

print(f"Top 5 from {len(data)} items (only 5 in memory):")
for rec in topk.get_top_k():
    print(f"  {rec['name']}: score={rec['score']}")

# 2. Bloom Filter for approximate distinct
print("\n=== Bloom Filter Distinct ===")
items = list(f"user_{random.randint(1, 50)}" for _ in range(200))
bloom = BloomFilter(expected_items=100)
for user_key in items:
    bloom.add(user_key)

actual_distinct = len(set(items))
print(f"Actual distinct: {actual_distinct}")
print(f"Bloom estimate: {bloom.count}")
print(f"Bloom filter size: {bloom.size} bits ({bloom.size / 8:.0f} bytes)")
print(f"vs storing all strings: ~{actual_distinct * 20} bytes")

# 3. External Merge Sort (runs of 10 records)
print("\n=== External Merge Sort ===")
data = list(random.randint(1, 10000) for _ in range(50))
sorter = ExternalMergeSort(memory_limit_records=10)
sorted_data = sorter.sort(data)
print(f"Sorted {len(data)} records using {sorter.stats['chunks_created']} chunks")
print(f"First 10: {sorted_data[:10]}")
print(f"Correctly sorted: {sorted_data == sorted(data)}")

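Complexity: Top-K: O(n log K) time, O(K) space. Bloom Filter: O(k) per operation, where k is the number of hash functions, and O(m) space, where m is the bit-array size. External Merge Sort: O(n log n) time. Creating each sorted run needs O(memory_limit) memory, and the k-way merge keeps only one buffered record per run. Both bounds assume the runs are spilled to disk; this in-memory simulation holds all n records.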

Challenge 4: Caching for Pipeline Performance

📝
Problem: Implement a multi-level caching system for a data pipeline that includes: an LRU in-memory cache, a TTL (time-to-live) policy, cache warming, and statistics tracking (hit rate, eviction count). Support both individual record and batch lookups.
import time
from collections import OrderedDict

class PipelineCache:
    """In-memory cache with LRU eviction and per-entry TTL expiry.

    Designed for data pipelines where the same lookups (e.g., dimension
    table joins, feature store reads) happen repeatedly. This class is a
    single in-process level. It does not itself consult any slower tier;
    callers do that on a miss (see the read-through pattern).
    """

    def __init__(self, max_size=1000, ttl_seconds=300):
        # max_size <= 0 disables storage: set() becomes a no-op.
        self.max_size = max_size
        self.ttl = ttl_seconds
        # key -> (value, expiry_time); iteration order is LRU first, MRU last.
        self.cache = OrderedDict()
        self.stats = {
            'hits': 0, 'misses': 0, 'evictions': 0,
            'expired': 0, 'sets': 0
        }

    def get(self, key):
        """Get a value from cache. Returns (value, hit) tuple.

        Expired entries are removed lazily here and counted as a miss.
        """
        if key in self.cache:
            value, expiry = self.cache[key]
            # Monotonic clock: TTLs must not shift when the wall clock is adjusted.
            if time.monotonic() < expiry:
                self.cache.move_to_end(key)  # mark as most recently used
                self.stats['hits'] += 1
                return value, True
            del self.cache[key]
            self.stats['expired'] += 1

        self.stats['misses'] += 1
        return None, False

    def set(self, key, value, ttl=None):
        """Set a value in cache with optional custom TTL.

        ``ttl=None`` uses the cache default. ``ttl=0`` is honored and means
        the entry is already expired on the next lookup.
        """
        if self.max_size <= 0:
            return  # caching disabled; evicting from an empty dict would fail
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)  # Evict LRU
            self.stats['evictions'] += 1

        lifetime = self.ttl if ttl is None else ttl
        self.cache[key] = (value, time.monotonic() + lifetime)
        self.stats['sets'] += 1

    def batch_get(self, keys):
        """Batch lookup - returns (found_dict, missing_keys)."""
        found = {}
        missing = []
        for key in keys:
            value, hit = self.get(key)
            if hit:
                found[key] = value
            else:
                missing.append(key)
        return found, missing

    def warm(self, data_dict, ttl=None):
        """Pre-populate cache with known values."""
        for key, value in data_dict.items():
            self.set(key, value, ttl)

    @property
    def hit_rate(self):
        """Fraction of lookups that were hits (0 when there were no lookups)."""
        total = self.stats['hits'] + self.stats['misses']
        return self.stats['hits'] / total if total > 0 else 0

    def get_stats(self):
        """Counters plus formatted hit rate and current entry count."""
        return {**self.stats, 'hit_rate': f"{self.hit_rate:.1%}",
                'size': len(self.cache)}


def cached_lookup(cache, keys, fetch_fn):
    """Read-through lookup: serve cached keys, fetch the rest, backfill the cache.

    Returns a dict of key -> value for every key that was cached or that
    ``fetch_fn`` returned. Keys the source does not know are simply absent.
    """
    found, missing = cache.batch_get(keys)
    if not missing:
        return found

    # Only cache misses reach the source (e.g., database, API).
    for key, value in fetch_fn(missing).items():
        cache.set(key, value)
        found[key] = value
    return found

# Test: a deliberately small cache (5 entries, 10 s TTL) so the
# event stream below forces LRU evictions
cache = PipelineCache(ttl_seconds=10, max_size=5)

# Simulated dimension table (user_id -> user details) standing in for a slow source
dimension_table = dict(
    u1={'name': 'Alice', 'tier': 'gold'},
    u2={'name': 'Bob', 'tier': 'silver'},
    u3={'name': 'Charlie', 'tier': 'bronze'},
    u4={'name': 'Diana', 'tier': 'gold'},
    u5={'name': 'Eve', 'tier': 'silver'},
    u6={'name': 'Frank', 'tier': 'bronze'},
)

def fetch_from_db(keys):
    """Stand-in for a database round trip: return rows for the keys that exist.

    Keys absent from dimension_table are left out of the result.
    """
    rows = {}
    for key in keys:
        if key in dimension_table:
            rows[key] = dimension_table[key]
    return rows

# Process a stream of events that need dimension enrichment.
# With max_size=5 and an empty cache: u1 is fetched once and then hit twice;
# u6 arrives with the cache full and evicts the least recently used entry (u2);
# u1 was used recently, so it is still cached and hits again; u2 misses, is
# re-fetched, and evicts u3.
events = [
    {'user_id': user_id, 'action': action}
    for user_id, action in [
        ('u1', 'click'),
        ('u2', 'view'),
        ('u1', 'buy'),
        ('u3', 'click'),
        ('u1', 'view'),
        ('u4', 'buy'),
        ('u5', 'click'),
        ('u6', 'view'),
        ('u1', 'click'),
        ('u2', 'buy'),
    ]
]

for event in events:
    user_id = event['user_id']
    user_info = cached_lookup(cache, [user_id], fetch_from_db).get(user_id, {})
    print(f"  {user_id} ({user_info.get('tier', '?')}): {event['action']}")

print(f"\nCache stats: {cache.get_stats()}")

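Complexity: O(1) amortized for get/set operations. Batch get is O(b), where b is the batch size. LRU eviction caps the cache at max_size entries regardless of stream length. Expired entries are removed lazily (when they are next looked up or evicted), so they can occupy slots until then.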

Challenge 5: Batch vs Streaming Trade-offs

📝
Problem: Implement both a batch and a streaming version of the same aggregation pipeline. Compare them on latency, throughput, memory usage, and correctness. Implement a hybrid "micro-batch" approach that balances both.
import time
from collections import defaultdict

class BatchAggregator:
    """Batch processing: collect all data, then aggregate.

    Pros: Simple, efficient for full-dataset operations, exact results.
    Cons: High latency (must wait for all data), high memory (all in memory).
    """

    def process(self, records, group_key, value_key):
        """Group every record's value_key by group_key and summarize each group.

        Returns a dict with per-group count/sum/avg/min/max under 'results',
        the elapsed time in 'latency_ms', the number of records held in
        'memory_records', and 'mode' = 'batch'.
        """
        started = time.time()

        values_by_group = defaultdict(list)
        for rec in records:
            values_by_group[rec[group_key]].append(rec[value_key])

        summary = {}
        for group, vals in values_by_group.items():
            total = sum(vals)
            n = len(vals)
            summary[group] = {
                'count': n,
                'sum': total,
                'avg': total / n,
                'min': min(vals),
                'max': max(vals),
            }

        elapsed_ms = round((time.time() - started) * 1000, 2)
        return {
            'results': summary,
            'latency_ms': elapsed_ms,
            'memory_records': len(records),  # the whole input is held at once
            'mode': 'batch',
        }


class StreamingAggregator:
    """Streaming processing: update aggregates incrementally per event.

    Pros: Low latency (results available immediately), constant memory.
    Cons: Some operations are harder (median, percentile), state management.
    """

    @staticmethod
    def _empty_state():
        """Fresh per-group accumulator; inf/-inf so the first value wins min/max."""
        return {'count': 0, 'sum': 0, 'min': float('inf'), 'max': float('-inf')}

    def __init__(self):
        self.state = defaultdict(self._empty_state)
        self.events_processed = 0

    def process_event(self, record, group_key, value_key):
        """Fold one record into its group's running count/sum/min/max."""
        # Read both fields before touching state, so a bad record leaves no trace.
        group = record[group_key]
        amount = record[value_key]
        acc = self.state[group]

        acc['count'] += 1
        acc['sum'] += amount
        acc['min'] = min(acc['min'], amount)
        acc['max'] = max(acc['max'], amount)
        self.events_processed += 1

    def get_results(self):
        """Snapshot of every group's aggregates plus its average (0 if empty)."""
        return {
            group: {**acc, 'avg': acc['sum'] / acc['count'] if acc['count'] > 0 else 0}
            for group, acc in self.state.items()
        }


class MicroBatchAggregator:
    """Micro-batch: collect small batches, process periodically.

    Balances latency and throughput. Used by Spark Structured Streaming.

    Pros: Better throughput than per-event, lower latency than full batch.
    Cons: Tuning batch interval is critical.

    Events are reduced to (group, value) pairs when they arrive. A malformed
    record therefore fails immediately in add_event and is never buffered, so
    a flush cannot stop halfway and later re-count records it already applied.
    """

    def __init__(self, batch_interval_records=10):
        self.batch_interval = batch_interval_records
        # Pending (group, value) pairs not yet folded into state.
        self.buffer = []
        self.state = defaultdict(lambda: {
            'count': 0, 'sum': 0, 'min': float('inf'), 'max': float('-inf')
        })
        self.batches_processed = 0

    def add_event(self, record, group_key, value_key):
        """Buffer one event; flush once batch_interval events are pending.

        Raises:
            KeyError: if record lacks group_key or value_key. Nothing is
                buffered in that case.
        """
        self.buffer.append((record[group_key], record[value_key]))
        if len(self.buffer) >= self.batch_interval:
            self._flush_batch(group_key, value_key)

    def _flush_batch(self, group_key=None, value_key=None):
        """Fold all buffered pairs into state and clear the buffer.

        group_key/value_key are accepted for backward compatibility but are
        unused, because fields were already extracted in add_event.
        """
        for key, value in self.buffer:
            s = self.state[key]
            s['count'] += 1
            s['sum'] += value
            s['min'] = min(s['min'], value)
            s['max'] = max(s['max'], value)

        self.buffer = []
        self.batches_processed += 1

    def finalize(self, group_key=None, value_key=None):
        """Flush any partially filled batch; call once the stream ends."""
        if self.buffer:
            self._flush_batch(group_key, value_key)

    def get_results(self):
        """Aggregates for flushed events only; buffered events are not included."""
        results = {}
        for key, s in self.state.items():
            results[key] = {
                **s,
                'avg': s['sum'] / s['count'] if s['count'] > 0 else 0
            }
        return results


# Compare all three approaches on the same seeded random input
import random
random.seed(42)

records = list(
    {'user': f'user_{random.randint(1, 5)}', 'amount': random.randint(10, 500)}
    for _ in range(100)
)

# 1. Batch: a single pass over the complete list
batch = BatchAggregator()
batch_result = batch.process(records, 'user', 'amount')

# 2. Streaming: one state update per event
stream = StreamingAggregator()
stream_start = time.time()
for rec in records:
    stream.process_event(rec, 'user', 'amount')
stream_latency = round((time.time() - stream_start) * 1000, 2)
stream_results = stream.get_results()

# 3. Micro-batch: buffer 10 events, then fold them in
micro = MicroBatchAggregator(batch_interval_records=10)
micro_start = time.time()
for rec in records:
    micro.add_event(rec, 'user', 'amount')
micro.finalize('user', 'amount')
micro_latency = round((time.time() - micro_start) * 1000, 2)
micro_results = micro.get_results()

# Comparison table
print("=== Comparison: Batch vs Streaming vs Micro-Batch ===\n")
print(f"{'Metric':<25} {'Batch':<15} {'Streaming':<15} {'Micro-Batch':<15}")
print('-' * 70)
print(f"{'Latency (ms)':<25} {batch_result['latency_ms']:<15} {stream_latency:<15} {micro_latency:<15}")
print(f"{'Memory (records)':<25} {batch_result['memory_records']:<15} {'1 (constant)':<15} {micro.batch_interval:<15}")
print(f"{'Batches':<25} {'1':<15} {'N/A':<15} {micro.batches_processed:<15}")

print("\nResults match: ", end="")
# Cross-check per-user count and sum across all three approaches
for user, b in batch_result['results'].items():
    s, m = stream_results[user], micro_results[user]
    assert b['count'] == s['count'] == m['count']
    assert b['sum'] == s['sum'] == m['sum']
print("All three approaches produce identical results.")

Complexity: All three are O(n) in time. Batch uses O(n) memory, streaming uses O(g) where g is unique groups, micro-batch uses O(b + g) where b is batch size. The trade-off is latency vs throughput vs simplicity.

💡
Interview tip: When asked about batch vs streaming, structure your answer around four axes: latency requirements, throughput needs, correctness guarantees, and operational complexity. Many production systems combine the two. The Lambda architecture runs a batch layer and a streaming layer side by side. The Kappa architecture uses a single streaming pipeline and reprocesses history by replaying the event log. Mention that Spark Structured Streaming uses micro-batching by default, while Flink is a native streaming engine that processes events one at a time.