Advanced

Top-K Problems

Top-K problems are the most frequently asked sorting-related questions in ML interviews because they directly map to ranking and retrieval systems. Every recommendation engine, search system, and content feed uses some form of top-K selection. This lesson covers 5 essential problems with multiple solution approaches.

Problem 1: Kth Largest Element

Problem: Find the kth largest element in an unsorted array. This is the foundation of all top-K problems.

ML Application: Finding the k-th highest confidence prediction to set a cutoff threshold.

Approach A: Min-Heap (O(n log k))

import heapq

def kth_largest_heap(nums: list, k: int) -> int:
    """Find the kth largest element using a min-heap of size k.

    Time:  O(n log k)
    Space: O(k)

    Maintain a min-heap of the k largest elements seen so far.
    After the scan, the heap's root is the kth largest.

    Args:
        nums: Values to search; the list itself is not modified.
        k: 1-based rank from the top (k=1 is the maximum). Duplicates
            count separately, so the 2nd largest of [5, 5, 1] is 5.

    Returns:
        The kth largest value.

    Raises:
        ValueError: If k is not in the range 1..len(nums). Without this
            check, k > len(nums) would silently return the minimum and
            k <= 0 would fail with an IndexError on an empty heap.
    """
    if not 1 <= k <= len(nums):
        raise ValueError(f"k must be between 1 and {len(nums)}, got {k}")

    # Seed the heap with the first k values; heapify is O(k).
    min_heap = nums[:k]
    heapq.heapify(min_heap)

    for num in nums[k:]:
        # Only a value larger than the current kth largest can enter the top k.
        if num > min_heap[0]:
            heapq.heapreplace(min_heap, num)  # pop root + push in one O(log k) step

    return min_heap[0]  # Root = kth largest

# Example: the 2nd largest of [3, 2, 1, 5, 6, 4]
nums = [3, 2, 1, 5, 6, 4]
print(kth_largest_heap(nums, k=2))  # 5

Approach B: Quickselect (O(n) average)

import random

def kth_largest_quickselect(nums: list, k: int) -> int:
    """Find kth largest using Quickselect (Hoare's selection algorithm).

    Time:  O(n) average, O(n^2) worst case (unlikely with a random pivot)
    Space: O(1) extra - in-place and iterative

    NOTE: reorders ``nums`` in place; pass a copy if the caller needs
    the original order preserved.

    This is the optimal average-case algorithm. numpy.partition()
    uses introselect, a quickselect variant with a worst-case fallback.

    Args:
        nums: Values to search (reordered in place).
        k: 1-based rank from the top (k=1 is the maximum).

    Returns:
        The kth largest value.

    Raises:
        ValueError: If k is not in the range 1..len(nums).
    """
    if not 1 <= k <= len(nums):
        raise ValueError(f"k must be between 1 and {len(nums)}, got {k}")
    # kth largest = (n-k)th smallest (0-indexed)
    target = len(nums) - k
    return quickselect(nums, 0, len(nums) - 1, target)


def quickselect(arr: list, lo: int, hi: int, k: int) -> int:
    """Find the k-th smallest element (0-indexed) of arr[lo..hi].

    Like QuickSort, but only continues into the ONE part that contains
    index k. It is written as a loop with a three-way partition, so
    inputs with many duplicate values stay linear on average and no
    input can exceed Python's recursion limit.

    Args:
        arr: List that is partially reordered in place.
        lo: Inclusive lower bound of the slice to search.
        hi: Inclusive upper bound of the slice to search.
        k: Absolute index (lo <= k <= hi) of the wanted order statistic.

    Returns:
        The value that would sit at arr[k] if arr[lo..hi] were sorted.

    Raises:
        ValueError: If k is outside [lo, hi].
    """
    if not lo <= k <= hi:
        raise ValueError(f"k={k} is outside the range [{lo}, {hi}]")

    while lo < hi:
        pivot = arr[random.randint(lo, hi)]  # randomized pivot

        # Dutch-national-flag partition of arr[lo..hi] into
        #   arr[lo:lt] < pivot, arr[lt:gt+1] == pivot, arr[gt+1:hi+1] > pivot
        lt, i, gt = lo, lo, hi
        while i <= gt:
            if arr[i] < pivot:
                arr[lt], arr[i] = arr[i], arr[lt]
                lt += 1
                i += 1
            elif arr[i] > pivot:
                arr[i], arr[gt] = arr[gt], arr[i]
                gt -= 1  # don't advance i: the swapped-in value is unexamined
            else:
                i += 1

        if k < lt:
            hi = lt - 1
        elif k > gt:
            lo = gt + 1
        else:
            return pivot  # k lands in the run of values equal to the pivot

    return arr[lo]

# Example (pass a copy: quickselect reorders its input in place)
nums = [3, 2, 3, 1, 2, 4, 5, 5, 6]
print(kth_largest_quickselect(list(nums), k=4))  # 4
💡
Interview tip: Know both approaches. The heap is O(n log k) time with O(k) space — better when k is small or the data is streaming. Quickselect is O(n) on average, with O(1) extra space when written iteratively (a recursive version uses O(log n) stack on average) — better for in-memory data, though it reorders its input. Mention that numpy.partition() uses introselect (quickselect with a fallback to median-of-medians for O(n) worst case).

Problem 2: Top K Frequent Elements

Problem: Given an array of integers, return the k most frequent elements.

ML Application: Finding the most common tokens in a corpus, most frequent user actions, or most popular items for cold-start recommendations.

from collections import Counter
import heapq

def top_k_frequent(nums: list, k: int) -> list:
    """Return the k most frequent elements of nums. LeetCode 347.

    Time:  O(n log k) - O(n) to count, then a size-k heap selection
    Space: O(n) for the frequency table

    Three approaches ranked by efficiency:
    1. Heap: O(n log k) - best when k << n
    2. Bucket sort: O(n) - best overall
    3. Full sort: O(n log n) - simple but slow

    Results are ordered from most to least frequent; elements with
    equal counts keep their first-seen order.
    """
    frequencies = Counter(nums)
    # nlargest maintains a size-k min-heap keyed on each element's count.
    return heapq.nlargest(k, frequencies, key=frequencies.__getitem__)

def top_k_frequent_bucket(nums: list, k: int) -> list:
    """O(n) solution using bucket sort on frequencies.

    Key insight: frequency is bounded by n, so we can use
    an array where index = frequency, value = list of elements.

    Returns:
        Up to k elements ordered from most to least frequent; elements
        with equal frequency appear in first-seen order. Returns [] when
        k <= 0. Without that guard the ``len(result) == k`` check below
        would never fire, and every distinct element would be returned.
    """
    if k <= 0:
        return []

    count = Counter(nums)
    n = len(nums)

    # Bucket: index = frequency, value = elements with that frequency
    buckets = [[] for _ in range(n + 1)]
    for num, freq in count.items():
        buckets[freq].append(num)

    # Collect top k, walking from the highest-frequency bucket down
    result = []
    for freq in range(n, 0, -1):
        for num in buckets[freq]:
            result.append(num)
            if len(result) == k:
                return result

    return result  # fewer than k distinct elements exist

# Example: 1 appears three times, 2 twice, 3 once
nums = [1, 1, 1, 2, 2, 3]
print(top_k_frequent(nums, k=2))         # [1, 2]
print(top_k_frequent_bucket(nums, k=2))  # [1, 2]

# ML example: the three most common words in a tiny corpus
corpus = [
    "the", "cat", "sat", "on", "the", "mat", "the",
    "cat", "ate", "the", "rat", "cat",
]
top_words = top_k_frequent(corpus, k=3)
print(f"Top 3 words: {top_words}")  # ['the', 'cat', 'sat'] ('sat' wins the count-1 tie by appearing first)

Problem 3: K Closest Points to Origin

Problem: Given an array of points, find the K closest points to the origin (0, 0).

ML Application: K-nearest neighbors in 2D/3D feature spaces, finding closest embeddings, approximate nearest neighbor retrieval.

import heapq
import math

def k_closest_points(points: list, k: int) -> list:
    """Find the K closest points to the origin. LeetCode 973.

    Time:  O(n log k) using a max-heap of size k (+ O(k log k) final sort)
    Space: O(k)

    Use a MAX-heap (negate distances) of size k.
    The root is the farthest of our k closest points.
    If a new point is closer, replace the root.

    Args:
        points: Sequence of [x, y] pairs.
        k: Number of points to return; k <= 0 yields [].

    Returns:
        Up to k points as [x, y] lists, ordered closest first (equal
        distances: larger x first, then larger y).
    """
    if k <= 0:
        return []  # the root comparison below needs a non-empty heap

    # Max-heap trick: negate squared distance (no sqrt needed for ordering)
    max_heap = []

    for x, y in points:
        dist = -(x * x + y * y)  # Negate for max-heap
        if len(max_heap) < k:
            heapq.heappush(max_heap, (dist, x, y))
        elif dist > max_heap[0][0]:  # Closer than farthest in heap
            heapq.heapreplace(max_heap, (dist, x, y))

    # A heap's list order is arbitrary; sort so results are closest-first.
    return [[x, y] for _, x, y in sorted(max_heap, reverse=True)]

# Example: squared distances to the origin are 10, 8, 89 and 1
points = [[1, 3], [-2, 2], [5, 8], [0, 1]]
print(k_closest_points(points, k=2))  # the two closest: [0, 1] and [-2, 2]

# ML example: K nearest neighbors in feature space
def knn_classify(train_points, train_labels, query, k):
    """Simple KNN classifier using the K closest training points.

    Time:  O(n log k) heap selection (plus O(d) per point for d dimensions)
    Space: O(k)

    Points may have any number of dimensions, provided every training
    point matches ``query``. Squared Euclidean distance is used because
    it ranks neighbors the same way as true distance, without a sqrt.

    Args:
        train_points: Sequence of coordinate sequences, one per example.
        train_labels: Labels aligned index-for-index with train_points.
        query: Coordinates of the point to classify.
        k: Number of neighbors that vote (must be positive).

    Returns:
        The majority label among the k nearest neighbors. Vote ties go to
        the label Counter.most_common sees first (heap list order).

    Raises:
        ValueError: If k <= 0, or a training point's dimension differs
            from the query's.
    """
    from collections import Counter  # this snippet's imports don't include it

    if k <= 0:
        raise ValueError(f"k must be positive, got {k}")

    max_heap = []

    for i, point in enumerate(train_points):
        if len(point) != len(query):
            raise ValueError(
                f"point {i} has {len(point)} dims, query has {len(query)}"
            )
        # Negated squared distance: the heap root is the farthest kept neighbor.
        dist = -sum((a - b) ** 2 for a, b in zip(point, query))
        if len(max_heap) < k:
            heapq.heappush(max_heap, (dist, i))
        elif dist > max_heap[0][0]:
            heapq.heapreplace(max_heap, (dist, i))

    # Vote among k nearest neighbors
    neighbor_labels = [train_labels[i] for _, i in max_heap]
    return Counter(neighbor_labels).most_common(1)[0][0]

# Example KNN: two well-separated 2-D clusters labelled 0 and 1
train_X = [[1, 1], [2, 2], [3, 3], [6, 6], [7, 7], [8, 8]]
train_y = [0, 0, 0, 1, 1, 1]
query = [4, 4]
print(f"KNN prediction: {knn_classify(train_X, train_y, query, k=3)}")  # 0: neighbors [3, 3], [2, 2], [6, 6] vote 0, 0, 1

Problem 4: Merge K Sorted Lists

Problem: Merge k sorted lists into one sorted list.

ML Application: Merging ranked results from multiple retrieval shards (distributed search), merging sorted feature streams in data pipelines.

import heapq

def merge_k_sorted(lists: list) -> list:
    """Merge k individually sorted lists into one sorted list. LeetCode 23.

    Time:  O(N log k) where N = total elements, k = number of lists
    Space: O(k) for the heap

    The heap holds at most one (value, list_index, position) entry per
    input list: the smallest element of that list not yet emitted. The
    root is always the next output value. It is replaced in place by its
    successor from the same list, or popped once that list runs out.
    Equal values come out in input-list order (list_index breaks ties).
    """
    heap = [(seq[0], src, 0) for src, seq in enumerate(lists) if seq]
    heapq.heapify(heap)

    merged = []
    while heap:
        value, src, pos = heap[0]
        merged.append(value)

        pos += 1
        if pos < len(lists[src]):
            # Swap the emitted root for the next element of the same list.
            heapq.heapreplace(heap, (lists[src][pos], src, pos))
        else:
            heapq.heappop(heap)

    return merged

# Example: three sorted lists, including duplicate values across lists
lists = [[1, 4, 5], [1, 3, 4], [2, 6]]
print(merge_k_sorted(lists=lists))  # [1, 1, 2, 3, 4, 4, 5, 6]

# ML example: merge search results from k shards.
# Each shard's (doc_id, score) list is already sorted by score, descending.
shard_results = [
    [("doc_a", 0.95), ("doc_b", 0.82), ("doc_c", 0.71)],
    [("doc_d", 0.91), ("doc_e", 0.85), ("doc_f", 0.60)],
    [("doc_g", 0.93), ("doc_h", 0.78), ("doc_i", 0.65)],
]

def merge_ranked_results(shards: list, top_k: int) -> list:
    """Return the global top-K (doc, score) pairs across search shards.

    Every shard lists its hits by score, highest first. A heap holds
    the best not-yet-emitted hit from each shard. Scores are negated so
    that Python's min-heap pops the highest score. Equal scores fall back
    to the lower shard index, then the earlier position.

    Time:  O(top_k * log(num_shards))
    Space: O(num_shards)
    """
    frontier = []  # entries: (-score, shard_idx, position, doc)
    for shard_idx, shard in enumerate(shards):
        if not shard:
            continue
        first_doc, first_score = shard[0]
        heapq.heappush(frontier, (-first_score, shard_idx, 0, first_doc))

    merged = []
    while frontier and len(merged) < top_k:
        neg_score, shard_idx, position, doc = heapq.heappop(frontier)
        merged.append((doc, -neg_score))

        # Refill from the shard we just consumed, if it has more hits.
        following = position + 1
        source = shards[shard_idx]
        if following < len(source):
            next_doc, next_score = source[following]
            heapq.heappush(frontier, (-next_score, shard_idx, following, next_doc))

    return merged

top_5 = merge_ranked_results(shard_results, top_k=5)
print("Global top 5:")
# Expected: doc_a 0.95, doc_g 0.93, doc_d 0.91, doc_e 0.85, doc_b 0.82
for doc, score in top_5:
    print(f"  {doc}: {score:.2f}")

Problem 5: Find Median from Data Stream

Problem: Design a data structure that supports adding numbers and finding the median in O(log n) per operation.

ML Application: Computing running median for real-time feature normalization, anomaly detection thresholds, and streaming statistics.

import heapq

class MedianFinder:
    """Find median from data stream. LeetCode 295.

    Uses two heaps:
    - max_heap: stores the smaller half (negated for max-heap in Python)
    - min_heap: stores the larger half

    Invariants after every add_num call:
    - every value in the smaller half is <= every value in the larger half
    - len(max_heap) == len(min_heap) or len(max_heap) == len(min_heap) + 1

    add_num:     O(log n)
    find_median: O(1)
    Space:       O(n)
    """

    def __init__(self):
        self.max_heap = []  # Smaller half (negated values)
        self.min_heap = []  # Larger half

    def add_num(self, num: int) -> None:
        """Add a number to the data structure."""
        # Always add to max_heap first
        heapq.heappush(self.max_heap, -num)

        # Ensure max_heap's max <= min_heap's min. Only the new value can
        # break this, and if it does, it is now max_heap's root.
        if self.min_heap and (-self.max_heap[0] > self.min_heap[0]):
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)

        # Balance sizes: max_heap can have at most 1 more element
        if len(self.max_heap) > len(self.min_heap) + 1:
            val = -heapq.heappop(self.max_heap)
            heapq.heappush(self.min_heap, val)
        elif len(self.min_heap) > len(self.max_heap):
            val = heapq.heappop(self.min_heap)
            heapq.heappush(self.max_heap, -val)

    def find_median(self) -> float:
        """Return the median of all elements added so far, as a float.

        Raises:
            IndexError: If no numbers have been added yet.
        """
        if not self.max_heap:
            raise IndexError("find_median() called before any add_num()")
        if len(self.max_heap) > len(self.min_heap):
            # Odd count: the middle element is the max of the smaller half.
            # float() keeps the return type consistent with the even case.
            return float(-self.max_heap[0])
        return (-self.max_heap[0] + self.min_heap[0]) / 2.0

# Example: the running median after each insertion
mf = MedianFinder()
mf.add_num(1)
print(mf.find_median())  # median of [1] is 1
mf.add_num(2)
print(mf.find_median())  # median of [1, 2] is 1.5
mf.add_num(3)
print(mf.find_median())  # median of [1, 2, 3] is 2

# ML example: streaming feature normalization
class StreamingNormalizer:
    """Normalize features using running median (robust to outliers).

    Unlike mean-based normalization, median is robust to outliers.
    Used in real-time ML systems processing sensor data, financial
    data, or any stream with potential anomalies.

    Each value is scored as (value - median) / (1.4826 * MAD). MAD is
    the median absolute deviation of all values seen so far, measured
    from the current running median.

    NOTE: only the median is maintained incrementally. The MAD step
    re-sorts every stored value, so each call costs O(n log n).
    """

    def __init__(self):
        self.median_finder = MedianFinder()
        self.values = []  # full history, needed to recompute the MAD

    def add_and_normalize(self, value: float) -> float:
        """Add value and return its median/MAD-normalized score.

        Returns 0.0 whenever the MAD is zero, because the scale is
        undefined then. This happens after the first value, or whenever
        more than half of the values equal the median.
        """
        self.median_finder.add_num(value)
        self.values.append(value)
        median = self.median_finder.find_median()

        # Median absolute deviation for scale. With an even count, the
        # median is the mean of the two middle deviations, not the upper one.
        deviations = sorted(abs(v - median) for v in self.values)
        mid = len(deviations) // 2
        if len(deviations) % 2:
            mad = deviations[mid]
        else:
            mad = (deviations[mid - 1] + deviations[mid]) / 2
        if mad == 0:
            return 0.0
        # 1.4826 makes the MAD a consistent estimator of the standard
        # deviation for normally distributed data.
        return (value - median) / (1.4826 * mad)

normalizer = StreamingNormalizer()
stream = [10, 12, 11, 100, 13, 11, 12]  # 100 is an outlier
for val in stream:
    # Normalize first, then read back the median that was used.
    normalized = normalizer.add_and_normalize(value=val)
    median = normalizer.median_finder.find_median()
    print(
        f"  Value: {val:3d}, Median: {median:.1f}, "
        f"Normalized: {normalized:+.2f}"
    )

Complexity Comparison

| Problem | Heap Approach | Alternative | Best For |
|---|---|---|---|
| Kth Largest | O(n log k) | Quickselect O(n) avg | Heap for streaming, QS for batch |
| Top K Frequent | O(n log k) | Bucket sort O(n) | Bucket sort when n is known (frequencies are bounded by n) |
| K Closest Points | O(n log k) | Quickselect O(n) avg | Heap for streaming KNN |
| Merge K Sorted | O(N log k) | Divide & conquer O(N log k) | Heap for streaming merge |
| Median of Stream | O(log n) per add | Sorted list O(n) per add | Two heaps is optimal |
💡
  • Top-K problems are the most common sorting questions in ML interviews because they map directly to ranking and retrieval
  • Min-heap of size K is the go-to approach: O(n log k) time, O(k) space, works with streaming data
  • Quickselect gives O(n) average for batch data but does not work for streams
  • Merge K sorted lists is how distributed search systems combine results from multiple shards
  • Median of stream using two heaps is essential for real-time feature normalization and anomaly detection
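  • Python's heapq.nlargest(k, iterable) and heapq.nsmallest(k, iterable) implement the bounded-heap approach for you (O(n log k)); per the docs, sorted() is more efficient when k is close to n, and min()/max() when k == 1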