# Comparison Sorts: QuickSort, MergeSort, HeapSort
These three algorithms form the foundation of sorting in computer science. Every ML engineer must be able to implement them from scratch, explain their trade-offs, and know when to use each one. This lesson provides complete Python implementations with line-by-line analysis.
## 1. QuickSort
QuickSort is the fastest comparison sort in practice for random data. It works by selecting a pivot element, partitioning the array so that elements smaller than the pivot go left and larger ones go right, then recursing on both halves.
| Property | Value |
|---|---|
| Average Time | O(n log n) |
| Worst Time | O(n²) — when pivot is always min or max |
| Space | O(log n) for recursion stack |
| Stable? | No — swaps can change relative order of equal elements |
| In-place? | Yes |
```python
import random

def quicksort(arr: list, low: int = 0, high: int | None = None) -> None:
    """In-place QuickSort with randomized pivot.

    Time: O(n log n) average, O(n^2) worst case
    Space: O(log n) recursion stack
    """
    if high is None:
        high = len(arr) - 1
    if low < high:
        pivot_idx = partition(arr, low, high)
        quicksort(arr, low, pivot_idx - 1)
        quicksort(arr, pivot_idx + 1, high)

def partition(arr: list, low: int, high: int) -> int:
    """Lomuto partition with randomized pivot selection.

    Randomization prevents O(n^2) on sorted/reverse-sorted input.
    """
    # Randomize pivot to avoid worst-case on sorted data
    rand_idx = random.randint(low, high)
    arr[rand_idx], arr[high] = arr[high], arr[rand_idx]
    pivot = arr[high]
    i = low - 1  # i tracks the boundary of elements <= pivot
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    return i + 1

# Example usage
scores = [0.95, 0.23, 0.87, 0.41, 0.99, 0.15, 0.73]
quicksort(scores)
print(scores)  # [0.15, 0.23, 0.41, 0.73, 0.87, 0.95, 0.99]
```
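To see exactly what one partitioning pass does, here is a deterministic trace with the pivot fixed as the last element (no randomization, so the result is reproducible). `partition_last` is an illustrative name for this fixed-pivot variant, not part of the implementation above:

```python
def partition_last(arr: list, low: int, high: int) -> int:
    """Lomuto partition using arr[high] as the pivot (no randomization)."""
    pivot = arr[high]
    i = low - 1
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    arr[i + 1], arr[high] = arr[high], arr[i + 1]
    return i + 1

data = [3, 7, 1, 5, 2]  # pivot is 2 (last element)
p = partition_last(data, 0, len(data) - 1)
print(p, data)  # 1 [1, 2, 3, 5, 7]
```

After one pass the pivot (2) sits at index `p` in its final sorted position: everything left of `p` is `<=` the pivot, everything right is greater. QuickSort then recurses on those two sides independently.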
## 2. MergeSort
MergeSort guarantees O(n log n) in all cases and is stable, meaning equal elements maintain their original order. This makes it the default choice when stability matters — for example, when sorting ML predictions that have the same confidence score.
| Property | Value |
|---|---|
| Time | O(n log n) in all cases |
| Space | O(n) — needs auxiliary array for merging |
| Stable? | Yes — preserves relative order of equal elements |
| In-place? | No |
```python
def merge_sort(arr: list) -> list:
    """Stable MergeSort returning a new sorted list.

    Time: O(n log n) guaranteed
    Space: O(n) for auxiliary arrays
    """
    if len(arr) <= 1:
        return arr
    mid = len(arr) // 2
    left = merge_sort(arr[:mid])
    right = merge_sort(arr[mid:])
    return merge(left, right)

def merge(left: list, right: list) -> list:
    """Merge two sorted lists into one sorted list.

    Uses <= for stability: equal elements from 'left' come first.
    """
    result = []
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:  # <= ensures stability
            result.append(left[i])
            i += 1
        else:
            result.append(right[j])
            j += 1
    result.extend(left[i:])
    result.extend(right[j:])
    return result
```
```python
# Example: stable sorting preserves original order for equal scores
predictions = [
    ("cat", 0.91), ("dog", 0.91), ("bird", 0.95), ("fish", 0.87)
]
# Note: merge_sort(predictions) would compare whole tuples, i.e. sort
# by name first. To sort by score, use a key function; Python's
# built-in sorted() is stable, just like MergeSort:
sorted_by_score = sorted(predictions, key=lambda x: x[1])
# [("fish", 0.87), ("cat", 0.91), ("dog", 0.91), ("bird", 0.95)]
# "cat" comes before "dog" because they have equal scores and
# "cat" appeared first in the original list (stability)
```
### Counting Inversions with MergeSort
A classic interview problem: count the number of inversions in an array (pairs where i < j but arr[i] > arr[j]). This is useful in ML for measuring rank correlation (Kendall tau distance).
```python
def count_inversions(arr: list) -> tuple:
    """Count inversions using a modified MergeSort.

    An inversion is a pair (i, j) where i < j but arr[i] > arr[j].
    Used in ML to measure ranking disagreement (Kendall tau).
    Time: O(n log n)
    Space: O(n)
    Returns: (sorted_array, inversion_count)
    """
    if len(arr) <= 1:
        return arr, 0
    mid = len(arr) // 2
    left, left_inv = count_inversions(arr[:mid])
    right, right_inv = count_inversions(arr[mid:])
    merged = []
    inversions = left_inv + right_inv
    i = j = 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            # All remaining elements in left are > right[j],
            # so each forms an inversion with right[j]
            inversions += len(left) - i
            merged.append(right[j])
            j += 1
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, inversions

# Example: compare two rankings
model_a_ranking = [3, 1, 2, 5, 4]  # model A's ordering
_, inv_count = count_inversions(model_a_ranking)
print(f"Inversions from sorted: {inv_count}")  # 3
```
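The inversion count maps directly onto Kendall tau distance: n items form n(n-1)/2 pairs, and dividing the inversion count by that total gives a disagreement score in [0, 1]. A minimal sketch, using a simple O(n²) pair count for clarity (`kendall_tau_distance` is an illustrative name, not a standard API; use the MergeSort-based counter above for large n):

```python
def kendall_tau_distance(ranking: list) -> float:
    """Normalized inversion count: fraction of pairs out of order
    relative to the fully sorted ranking. 0.0 = sorted, 1.0 = reversed.
    O(n^2) brute-force pair count, for illustration only.
    """
    n = len(ranking)
    total_pairs = n * (n - 1) // 2
    inversions = sum(
        1
        for i in range(n)
        for j in range(i + 1, n)
        if ranking[i] > ranking[j]
    )
    return inversions / total_pairs if total_pairs else 0.0

print(kendall_tau_distance([3, 1, 2, 5, 4]))  # 3 inversions / 10 pairs = 0.3
print(kendall_tau_distance([1, 2, 3]))        # 0.0 (fully agrees)
print(kendall_tau_distance([3, 2, 1]))        # 1.0 (fully reversed)
```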
## 3. HeapSort
HeapSort uses a binary heap to sort in O(n log n) with O(1) extra space. It is not stable but guarantees worst-case performance. More importantly, the heap data structure is the foundation of top-K problems.
| Property | Value |
|---|---|
| Time | O(n log n) in all cases |
| Space | O(1) — in-place |
| Stable? | No |
| In-place? | Yes |
```python
def heapsort(arr: list) -> None:
    """In-place HeapSort using a max-heap.

    Time: O(n log n) guaranteed
    Space: O(1)
    Steps:
    1. Build max-heap from array: O(n)
    2. Repeatedly extract max and place at end: O(n log n)
    """
    n = len(arr)
    # Step 1: Build max-heap (bottom-up)
    # Start from last non-leaf node and heapify down
    for i in range(n // 2 - 1, -1, -1):
        heapify_down(arr, n, i)
    # Step 2: Extract elements one by one
    for i in range(n - 1, 0, -1):
        # Move current root (max) to end
        arr[0], arr[i] = arr[i], arr[0]
        # Heapify the reduced heap
        heapify_down(arr, i, 0)

def heapify_down(arr: list, heap_size: int, root: int) -> None:
    """Maintain max-heap property by sifting down from root.

    Compares root with its children and swaps with the largest.
    Continues until the heap property is restored.
    """
    largest = root
    left = 2 * root + 1
    right = 2 * root + 2
    if left < heap_size and arr[left] > arr[largest]:
        largest = left
    if right < heap_size and arr[right] > arr[largest]:
        largest = right
    if largest != root:
        arr[root], arr[largest] = arr[largest], arr[root]
        heapify_down(arr, heap_size, largest)

# Example
data = [12, 11, 13, 5, 6, 7]
heapsort(data)
print(data)  # [5, 6, 7, 11, 12, 13]
```
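The same extract-and-sift mechanics are available in the standard library: `heapq` implements a min-heap, so repeatedly popping yields ascending order. A short sketch (not in-place, unlike the max-heap version above, so it trades O(n) extra space for brevity):

```python
import heapq

def heapsort_with_heapq(arr: list) -> list:
    """Sort by heapifying a copy (O(n)) and popping n times (O(n log n)).

    heapq is a min-heap, so successive pops come out in ascending order.
    """
    heap = list(arr)     # copy so the input is left untouched
    heapq.heapify(heap)  # O(n) bottom-up heap construction
    return [heapq.heappop(heap) for _ in range(len(heap))]

print(heapsort_with_heapq([12, 11, 13, 5, 6, 7]))  # [5, 6, 7, 11, 12, 13]
```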
### Partial Sort with Heap (Top-K Preview)
One of the most important applications: extracting the top K elements without fully sorting. This is O(n log k) instead of O(n log n).
```python
import heapq

def top_k_scores(scores: list, k: int) -> list:
    """Find top K scores using a min-heap of size K.

    Time: O(n log k) - much better than O(n log n) when k << n
    Space: O(k)
    Why a min-heap? We maintain the K largest seen so far.
    If a new element is larger than the heap's minimum,
    we replace it. The min-heap efficiently gives us
    the smallest of our K largest elements.
    """
    if k >= len(scores):
        return sorted(scores, reverse=True)
    # Build a min-heap of size k
    min_heap = scores[:k]
    heapq.heapify(min_heap)  # O(k)
    # Process remaining elements
    for score in scores[k:]:
        if score > min_heap[0]:  # Larger than current k-th largest
            heapq.heapreplace(min_heap, score)  # O(log k)
    # Return sorted in descending order
    return sorted(min_heap, reverse=True)

# ML example: top 5 predictions from 1 million candidates
import random
all_scores = [random.random() for _ in range(1_000_000)]
top_5 = top_k_scores(all_scores, 5)
print(f"Top 5: {top_5}")  # 5 highest scores, sorted descending
```
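For interviews it is worth knowing that the standard library already ships this pattern: `heapq.nlargest` (and its mirror `nsmallest`) use the same bounded-heap idea and return results in sorted order:

```python
import heapq
import random

random.seed(42)  # seeded so the example is reproducible
scores = [random.random() for _ in range(100_000)]

top_5 = heapq.nlargest(5, scores)  # same bounded-heap idea, one call
# Equivalent to sorting everything and slicing, but without the full sort:
assert top_5 == sorted(scores, reverse=True)[:5]
print(top_5)
```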
## When to Use Which?
| Scenario | Best Choice | Why |
|---|---|---|
| General purpose, random data | QuickSort | Fastest in practice due to cache locality, O(log n) space |
| Need stability (equal elements keep order) | MergeSort | Stable with guaranteed O(n log n); Python's built-in sort uses Timsort, which is also merge-based and stable |
| Memory constrained, no extra space | HeapSort | O(1) extra space, guaranteed O(n log n) |
| Nearly sorted data | Timsort (Python built-in) | O(n) on nearly sorted data due to run detection |
| Top-K from large dataset | Heap (partial sort) | O(n log k) << O(n log n) when k << n |
| Linked list data | MergeSort | Does not need random access; QuickSort is slow on linked lists |
| External sort (data on disk) | MergeSort | Sequential access pattern works well with disk I/O |
| Counting inversions / rank correlation | MergeSort | Natural extension counts inversions during merge step |
- QuickSort is fastest in practice but has O(n²) worst case; use a randomized pivot to mitigate
- MergeSort is a stable O(n log n) sort and is ideal for linked lists and external sorting
- HeapSort uses O(1) extra space but is slower in practice due to poor cache performance
- For ML ranking: use a stable sort when equal-score ordering matters, and heap-based partial sort for top-K retrieval
- Python's built-in `sorted()` uses Timsort, which is already optimal for most cases; implement from scratch only when asked in interviews
Lilly Tech Systems