Top-K Problems
Top-K problems are the most frequently asked sorting-related questions in ML interviews because they directly map to ranking and retrieval systems. Every recommendation engine, search system, and content feed uses some form of top-K selection. This lesson covers 5 essential problems with multiple solution approaches.
Problem 1: Kth Largest Element
Problem: Find the kth largest element in an unsorted array. This is the foundation of all top-K problems.
ML Application: Finding the k-th highest confidence prediction to set a cutoff threshold.
Approach A: Min-Heap (O(n log k))
import heapq
def kth_largest_heap(nums: list, k: int) -> int:
    """Find the kth largest element using a min-heap of size k.

    Maintains a min-heap of the k largest elements seen so far;
    the heap's root is therefore always the kth largest overall.

    Args:
        nums: Unsorted list of comparable values (not modified).
        k: 1-based rank from the top (1 = largest).

    Returns:
        The kth largest element.

    Raises:
        ValueError: If k is not in the range 1..len(nums).

    Time: O(n log k); Space: O(k).
    """
    if not 1 <= k <= len(nums):
        # Previously k > len(nums) silently returned the minimum element.
        raise ValueError(f"k must be between 1 and {len(nums)}, got {k}")
    # Seed the heap with the first k elements (copies; nums is untouched)
    min_heap = nums[:k]
    heapq.heapify(min_heap)  # O(k)
    for num in nums[k:]:
        if num > min_heap[0]:
            # num belongs in the top-k: evict the current smallest.
            heapq.heapreplace(min_heap, num)  # O(log k)
    return min_heap[0]  # Root = kth largest
# Example usage: the 2nd largest of the six values is 5
data = [3, 2, 1, 5, 6, 4]
print(kth_largest_heap(data, 2))  # 5 (2nd largest)
Approach B: Quickselect (O(n) average)
import random
def kth_largest_quickselect(nums: list, k: int) -> int:
    """Find the kth largest element using Quickselect (Hoare's algorithm).

    Note: partially reorders `nums` in place; pass a copy if the
    original order matters.

    Args:
        nums: Unsorted list of comparable values (modified in place).
        k: 1-based rank from the top (1 = largest).

    Returns:
        The kth largest element.

    Raises:
        ValueError: If k is not in the range 1..len(nums).

    Time: O(n) average, O(n^2) worst case; Space: O(1) - in-place.
    This is the optimal average-case algorithm; used internally by
    numpy.partition().
    """
    if not 1 <= k <= len(nums):
        # Previously k > len(nums) produced a negative target index.
        raise ValueError(f"k must be between 1 and {len(nums)}, got {k}")
    # kth largest = (n-k)th smallest (0-indexed)
    target = len(nums) - k
    return quickselect(nums, 0, len(nums) - 1, target)


def quickselect(arr: list, lo: int, hi: int, k: int) -> int:
    """Find the k-th smallest element of arr[lo..hi] (0-indexed).

    Like QuickSort, but descends into only ONE side of each partition.
    Implemented iteratively so worst-case inputs cannot exhaust the
    recursion limit (the recursive form can reach O(n) call depth).
    """
    while lo < hi:
        # Randomized pivot defends against adversarial / sorted inputs.
        pivot_idx = random.randint(lo, hi)
        arr[pivot_idx], arr[hi] = arr[hi], arr[pivot_idx]
        pivot = arr[hi]
        # Lomuto partition: everything < pivot ends up left of `store`.
        store = lo
        for i in range(lo, hi):
            if arr[i] < pivot:
                arr[store], arr[i] = arr[i], arr[store]
                store += 1
        arr[store], arr[hi] = arr[hi], arr[store]
        if store == k:
            return arr[store]
        if store < k:
            lo = store + 1  # answer lies right of the pivot
        else:
            hi = store - 1  # answer lies left of the pivot
    return arr[lo]
# Example usage (copy first: quickselect partially reorders its input)
sample = [3, 2, 3, 1, 2, 4, 5, 5, 6]
print(kth_largest_quickselect(sample.copy(), 4))  # 4
numpy.partition() uses introselect (quickselect with a fallback to median-of-medians to guarantee O(n) worst case).

Problem 2: Top K Frequent Elements
Problem: Given an array of integers, return the k most frequent elements.
ML Application: Finding the most common tokens in a corpus, most frequent user actions, or most popular items for cold-start recommendations.
from collections import Counter
import heapq
def top_k_frequent(nums: list, k: int) -> list:
    """Return the k most frequent elements of nums. LeetCode 347.

    Counts every element once, then lets heapq.nlargest keep a
    size-k heap keyed on each element's frequency.

    Time: O(n log k) - counting O(n) + heap selection O(n log k)
    Space: O(n) for the frequency table

    Three approaches, ranked by efficiency:
      1. Heap: O(n log k) - best when k << n
      2. Bucket sort: O(n) - best overall
      3. Full sort: O(n log n) - simple but slow
    """
    freq = Counter(nums)
    # nlargest over the distinct elements, ordered by their counts
    return heapq.nlargest(k, freq, key=freq.get)
def top_k_frequent_bucket(nums: list, k: int) -> list:
    """O(n) top-k-frequent via bucket sort on frequencies.

    Key insight: an element's frequency can never exceed n, so an
    array indexed by frequency (value = elements with that count)
    replaces any comparison sort.
    """
    counts = Counter(nums)
    n = len(nums)
    # buckets[f] holds every distinct element that occurs exactly f times
    buckets = [[] for _ in range(n + 1)]
    for value, freq in counts.items():
        buckets[freq].append(value)
    # Sweep from the highest possible frequency down, gathering winners
    top = []
    for freq in reversed(range(1, n + 1)):
        top.extend(buckets[freq])
        if len(top) >= k:
            return top[:k]
    return top
# Example usage: both approaches agree on the two most frequent values
sample_nums = [1, 1, 1, 2, 2, 3]
print(top_k_frequent(sample_nums, 2))         # [1, 2]
print(top_k_frequent_bucket(sample_nums, 2))  # [1, 2]

# ML example: most common words in a corpus
corpus = ["the", "cat", "sat", "on", "the", "mat", "the",
          "cat", "ate", "the", "rat", "cat"]
top_words = top_k_frequent(corpus, 3)
print(f"Top 3 words: {top_words}")  # ['the', 'cat', 'sat'] or similar
Problem 3: K Closest Points to Origin
Problem: Given an array of points, find the K closest points to the origin (0, 0).
ML Application: K-nearest neighbors in 2D/3D feature spaces, finding closest embeddings, approximate nearest neighbor retrieval.
import heapq
import math
def k_closest_points(points: list, k: int) -> list:
    """Return the K points closest to the origin. LeetCode 973.

    Keeps a size-k MAX-heap (distances negated, since heapq is a
    min-heap): the root is always the farthest of the k closest
    points found so far, so any nearer point replaces it.

    Time: O(n log k); Space: O(k).
    """
    heap = []  # entries: (-squared_distance, x, y)
    for x, y in points:
        neg_dist = -(x * x + y * y)  # negate so heapq behaves as a max-heap
        if len(heap) < k:
            heapq.heappush(heap, (neg_dist, x, y))
        elif neg_dist > heap[0][0]:
            # Strictly closer than our current farthest kept point
            heapq.heapreplace(heap, (neg_dist, x, y))
    return [[px, py] for _, px, py in heap]
# Example
points = [[1, 3], [-2, 2], [5, 8], [0, 1]]
print(k_closest_points(points, 2)) # [[0, 1], [-2, 2]]
# ML example: K nearest neighbors in feature space
def knn_classify(train_points, train_labels, query, k):
    """Classify `query` by majority vote among its k nearest training points.

    Time: O(n log k); Space: O(k).
    """
    qx, qy = query[0], query[1]
    nearest = []  # (-squared_distance, index) max-heap of the k closest
    for idx, (px, py) in enumerate(train_points):
        neg_dist = -((px - qx) ** 2 + (py - qy) ** 2)
        if len(nearest) < k:
            heapq.heappush(nearest, (neg_dist, idx))
        elif neg_dist > nearest[0][0]:
            heapq.heapreplace(nearest, (neg_dist, idx))
    # Majority vote over the labels of the k retained neighbors
    votes = Counter(train_labels[idx] for _, idx in nearest)
    return votes.most_common(1)[0][0]
# Example KNN: six 2-D points; first three labeled 0, last three labeled 1
train_X = [[1, 1], [2, 2], [3, 3], [6, 6], [7, 7], [8, 8]]
train_y = [0, 0, 0, 1, 1, 1]
query = [4, 4]
# (4, 4) sits nearer the 0-cluster, so a k=3 vote should predict 0
print(f"KNN prediction: {knn_classify(train_X, train_y, query, 3)}")  # 0 or 1
Problem 4: Merge K Sorted Lists
Problem: Merge k sorted lists into one sorted list.
ML Application: Merging ranked results from multiple retrieval shards (distributed search), merging sorted feature streams in data pipelines.
import heapq
def merge_k_sorted(lists: list) -> list:
    """Merge k sorted lists into a single sorted list. LeetCode 23.

    Keeps a min-heap with at most one entry per list: pop the
    global minimum, then feed in the next element from the same
    source list. Entries are (value, list_index, element_index).

    Time: O(N log k), N = total elements, k = number of lists.
    Space: O(k) for the heap.
    """
    # Seed with each non-empty list's head
    heap = [(seq[0], i, 0) for i, seq in enumerate(lists) if seq]
    heapq.heapify(heap)
    merged = []
    while heap:
        value, which, pos = heapq.heappop(heap)
        merged.append(value)
        # Advance within the list the minimum came from
        pos += 1
        source = lists[which]
        if pos < len(source):
            heapq.heappush(heap, (source[pos], which, pos))
    return merged
# Example usage: three pre-sorted lists merged into one
sorted_lists = [[1, 4, 5], [1, 3, 4], [2, 6]]
print(merge_k_sorted(sorted_lists))  # [1, 1, 2, 3, 4, 4, 5, 6]
# ML example: merging ranked search results from k shards.
# Each inner list is one shard's (doc_id, score) hits, sorted by score desc.
shard_results = [
    [("doc_a", 0.95), ("doc_b", 0.82), ("doc_c", 0.71)],
    [("doc_d", 0.91), ("doc_e", 0.85), ("doc_f", 0.60)],
    [("doc_g", 0.93), ("doc_h", 0.78), ("doc_i", 0.65)],
]
def merge_ranked_results(shards: list, top_k: int) -> list:
    """Merge descending-ranked (doc, score) lists from several shards.

    Each shard's results must already be sorted by score (descending).
    Scores are negated so Python's min-heap acts as a max-heap; the
    global top-K is popped off one document at a time.

    Time: O(top_k * log(num_shards)); Space: O(num_shards).
    """
    heap = []  # entries: (-score, shard_idx, pos, doc)
    for idx, shard in enumerate(shards):
        if shard:
            head_doc, head_score = shard[0]
            heapq.heappush(heap, (-head_score, idx, 0, head_doc))
    merged = []
    while heap and len(merged) < top_k:
        neg_score, idx, pos, doc = heapq.heappop(heap)
        merged.append((doc, -neg_score))
        # Pull the next-best document from the shard we just consumed
        pos += 1
        if pos < len(shards[idx]):
            nxt_doc, nxt_score = shards[idx][pos]
            heapq.heappush(heap, (-nxt_score, idx, pos, nxt_doc))
    return merged
# Merge all three shards and keep the five highest-scoring documents
top_5 = merge_ranked_results(shard_results, 5)
print("Global top 5:")
for doc, score in top_5:
    print(f" {doc}: {score:.2f}")
Problem 5: Find Median from Data Stream
Problem: Design a data structure that supports adding numbers and finding the median in O(log n) per operation.
ML Application: Computing running median for real-time feature normalization, anomaly detection thresholds, and streaming statistics.
import heapq
class MedianFinder:
    """Running median over a stream of numbers. LeetCode 295.

    Classic two-heap technique:
      - max_heap holds the smaller half (values negated, because
        heapq only provides a min-heap)
      - min_heap holds the larger half

    Invariant: len(max_heap) - len(min_heap) is either 0 or 1, and
    every element of the smaller half <= every element of the larger.

    add_num: O(log n); find_median: O(1); Space: O(n).
    """

    def __init__(self):
        self.max_heap = []  # smaller half, stored negated
        self.min_heap = []  # larger half, stored as-is

    def add_num(self, num: int) -> None:
        """Insert a value while preserving both invariants."""
        # New values always enter via the smaller half first
        heapq.heappush(self.max_heap, -num)
        # Restore ordering: smaller half's max must not exceed larger half's min
        if self.min_heap and -self.max_heap[0] > self.min_heap[0]:
            heapq.heappush(self.min_heap, -heapq.heappop(self.max_heap))
        # Restore sizes: max_heap may lead by at most one element
        if len(self.max_heap) > len(self.min_heap) + 1:
            heapq.heappush(self.min_heap, -heapq.heappop(self.max_heap))
        elif len(self.min_heap) > len(self.max_heap):
            heapq.heappush(self.max_heap, -heapq.heappop(self.min_heap))

    def find_median(self) -> float:
        """Return the median of everything added so far, in O(1)."""
        if len(self.max_heap) > len(self.min_heap):
            # Odd count: the extra element in max_heap is the median
            return -self.max_heap[0]
        # Even count: average of the two middle elements
        return (-self.max_heap[0] + self.min_heap[0]) / 2.0
# Example usage: the median updates after every insertion
mf = MedianFinder()
for value in (1, 2, 3):
    mf.add_num(value)
    print(mf.find_median())  # 1.0, then 1.5, then 2.0
# ML example: streaming feature normalization
class StreamingNormalizer:
    """Normalize streamed features by running median and MAD.

    Median-based centering and MAD-based scaling are robust to
    outliers, unlike mean/std normalization. Used in real-time ML
    systems processing sensor data, financial data, or any stream
    with potential anomalies.

    NOTE(review): the MAD pass sorts every value seen so far, so each
    call is O(n log n) — fine for a teaching example, not for long streams.
    """

    def __init__(self):
        self.median_finder = MedianFinder()  # streaming median, O(log n) insert
        self.values = []  # full history, needed for the MAD computation

    def add_and_normalize(self, value: float) -> float:
        """Record `value` and return it normalized by running median/MAD."""
        self.median_finder.add_num(value)
        self.values.append(value)
        median = self.median_finder.find_median()
        # Median absolute deviation provides a robust scale estimate
        deviations = sorted(abs(v - median) for v in self.values)
        mad = deviations[len(deviations) // 2]
        if mad == 0:
            # Degenerate scale (e.g. first sample): report zero deviation
            return 0.0
        # 1.4826 makes MAD a consistent estimator of sigma for normal data
        return (value - median) / (1.4826 * mad)
# Demo: the 100 outlier barely moves the running median, so normalization stays stable
normalizer = StreamingNormalizer()
stream = [10, 12, 11, 100, 13, 11, 12]  # 100 is an outlier
for val in stream:
    normalized = normalizer.add_and_normalize(val)
    median = normalizer.median_finder.find_median()
    print(f" Value: {val:3d}, Median: {median:.1f}, Normalized: {normalized:+.2f}")
Complexity Comparison
| Problem | Heap Approach | Alternative | Best For |
|---|---|---|---|
| Kth Largest | O(n log k) | Quickselect O(n) avg | Heap for streaming, QS for batch |
| Top K Frequent | O(n log k) | Bucket sort O(n) | Bucket sort when n is known |
| K Closest Points | O(n log k) | Quickselect O(n) avg | Heap for streaming KNN |
| Merge K Sorted | O(N log k) | Divide & conquer O(N log k) | Heap for streaming merge |
| Median of Stream | O(log n) per add | Sorted list O(n) per add | Two heaps is optimal |
- Top-K problems are the most common sorting questions in ML interviews because they map directly to ranking and retrieval
- Min-heap of size K is the go-to approach: O(n log k) time, O(k) space, works with streaming data
- Quickselect gives O(n) average for batch data but does not work for streams
- Merge K sorted lists is how distributed search systems combine results from multiple shards
- Median of stream using two heaps is essential for real-time feature normalization and anomaly detection
- Python's `heapq.nlargest(k, iterable)` and `heapq.nsmallest(k, iterable)` implement optimal selection algorithms internally
Lilly Tech Systems