# Binary Search Variants
Binary search is deceptively simple in concept but tricky in implementation; off-by-one errors are the most common source of bugs. This lesson covers six essential binary search patterns with robust implementations that you can adapt to any variant.
## Problem 1: First and Last Occurrence
Problem: Given a sorted array and a target, find the first and last index where the target appears. Return [-1, -1] if not found.
ML Application: Finding the range of documents with the same relevance score in a sorted retrieval list.
```python
def find_first(arr: list, target: int) -> int:
    """Find the first occurrence of target in a sorted array.

    Time: O(log n)
    Space: O(1)

    Key insight: when we find target, we do NOT return immediately.
    Instead, we save the position and keep searching LEFT for an
    earlier occurrence.
    """
    lo, hi = 0, len(arr) - 1
    result = -1
    while lo <= hi:
        mid = lo + (hi - lo) // 2  # Prevents integer overflow
        if arr[mid] == target:
            result = mid  # Save this position
            hi = mid - 1  # Keep searching LEFT
        elif arr[mid] < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return result


def find_last(arr: list, target: int) -> int:
    """Find the last occurrence of target in a sorted array.

    Time: O(log n)
    Space: O(1)

    Same as find_first, but when we find target, we search RIGHT.
    """
    lo, hi = 0, len(arr) - 1
    result = -1
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        if arr[mid] == target:
            result = mid  # Save this position
            lo = mid + 1  # Keep searching RIGHT
        elif arr[mid] < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return result


def search_range(arr: list, target: int) -> list:
    """Find the first and last position of target. LeetCode 34.

    Time: O(log n) - two binary searches
    Space: O(1)
    """
    return [find_first(arr, target), find_last(arr, target)]


# Example
scores = [1, 2, 3, 3, 3, 3, 4, 5, 6]
print(search_range(scores, 3))  # [2, 5]
print(search_range(scores, 7))  # [-1, -1]
```
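If the standard library is allowed, the same range query can be expressed with the `bisect` module, which implements exactly these two searches; a minimal sketch:

```python
from bisect import bisect_left, bisect_right


def search_range_bisect(arr: list, target: int) -> list:
    """Same result as search_range, via the standard library.
    bisect_left gives the first index where target could be inserted;
    bisect_right gives one past the last occurrence."""
    lo = bisect_left(arr, target)
    if lo == len(arr) or arr[lo] != target:
        return [-1, -1]
    return [lo, bisect_right(arr, target) - 1]


scores = [1, 2, 3, 3, 3, 3, 4, 5, 6]
print(search_range_bisect(scores, 3))  # [2, 5]
print(search_range_bisect(scores, 7))  # [-1, -1]
```

In interviews you are usually expected to write the loop by hand, but `bisect` is the right tool in production code.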
## Problem 2: Search Insert Position
Problem: Given a sorted array and target, return the index where the target is found. If not, return where it would be inserted to keep the array sorted.
ML Application: Inserting a new data point into a sorted feature array for online learning.
```python
def search_insert(arr: list, target: int) -> int:
    """Find the insert position in a sorted array. LeetCode 35.

    Time: O(log n)
    Space: O(1)

    Equivalent to bisect.bisect_left(arr, target).
    After the loop, lo is the first index where arr[lo] >= target.
    """
    lo, hi = 0, len(arr)  # Note: hi = len(arr), not len(arr) - 1
    while lo < hi:        # Note: strict <, not <=
        mid = lo + (hi - lo) // 2
        if arr[mid] < target:
            lo = mid + 1
        else:
            hi = mid
    return lo


# Example
arr = [1, 3, 5, 6]
print(search_insert(arr, 5))  # 2 (found at index 2)
print(search_insert(arr, 2))  # 1 (insert between 1 and 3)
print(search_insert(arr, 7))  # 4 (insert at end)
print(search_insert(arr, 0))  # 0 (insert at beginning)
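For the online-learning use case mentioned above, the natural follow-up is to actually insert while keeping the array sorted. `bisect.insort` combines the position search with the insert (the shift itself is still O(n)); a short sketch:

```python
import bisect

# Maintain a sorted feature array as new points arrive
features = [1, 3, 5, 6]
bisect.insort(features, 2)  # binary-searches the position, then inserts
print(features)  # [1, 2, 3, 5, 6]
bisect.insort(features, 4)
print(features)  # [1, 2, 3, 4, 5, 6]
```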
## Problem 3: Find Peak Element
Problem: Find any peak element in an array (element greater than its neighbors). The array may have multiple peaks.
ML Application: Finding local optima in a loss landscape or hyperparameter search space.
```python
def find_peak_element(nums: list) -> int:
    """Find a peak element index. LeetCode 162.

    Time: O(log n)
    Space: O(1)

    Key insight: if nums[mid] < nums[mid + 1], there MUST be a peak
    to the right (the sequence is rising and must eventually fall or
    reach the end, which counts as a peak). Similarly, if
    nums[mid] < nums[mid - 1], there must be a peak to the left.
    Assumes no two adjacent elements are equal, per LeetCode 162.
    """
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if nums[mid] < nums[mid + 1]:
            lo = mid + 1  # Rising slope: peak must be to the right
        else:
            hi = mid      # Falling slope: peak is at mid or to the left
    return lo  # lo == hi, both point to a peak


# Example
nums = [1, 2, 3, 1]
print(find_peak_element(nums))  # 2 (index of value 3)
nums = [1, 2, 1, 3, 5, 6, 4]
print(find_peak_element(nums))  # 5 (index of value 6; index 1 is also a valid peak)

# ML application: find the optimal learning rate
loss_at_lr = [0.9, 0.7, 0.5, 0.3, 0.2, 0.25, 0.4, 0.6]
# Negate to find the minimum (peak of negated = trough of original)
neg_loss = [-x for x in loss_at_lr]
best_idx = find_peak_element(neg_loss)
print(f"Best learning rate index: {best_idx}")  # 4
```
## Problem 4: Search in Rotated Sorted Array
Problem: Search for a target in a sorted array that has been rotated at some pivot. All elements are unique.
ML Application: Searching in circular buffers used for sliding-window feature computation in streaming ML.
```python
def search_rotated(nums: list, target: int) -> int:
    """Search in a rotated sorted array. LeetCode 33.

    Time: O(log n)
    Space: O(1)

    Key insight: after splitting at mid, one half is ALWAYS sorted.
    Check whether target falls within the sorted half. If yes, search
    there; if not, search the other half.
    """
    lo, hi = 0, len(nums) - 1
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        if nums[mid] == target:
            return mid
        if nums[lo] <= nums[mid]:  # Left half is sorted
            if nums[lo] <= target < nums[mid]:
                hi = mid - 1  # Target in sorted left half
            else:
                lo = mid + 1  # Target in right half
        else:                      # Right half is sorted
            if nums[mid] < target <= nums[hi]:
                lo = mid + 1  # Target in sorted right half
            else:
                hi = mid - 1  # Target in left half
    return -1


# Example
nums = [4, 5, 6, 7, 0, 1, 2]
print(search_rotated(nums, 0))  # 4
print(search_rotated(nums, 3))  # -1

# Circular buffer example
buffer = [80, 90, 95, 10, 20, 30, 50, 60, 70]
print(search_rotated(buffer, 30))  # 5
```
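A closely related variant (LeetCode 153, not required by this problem but built on the same "one half is sorted" reasoning) locates the rotation pivot itself, i.e. the index of the minimum element. A sketch, assuming unique elements:

```python
def find_rotation_pivot(nums: list) -> int:
    """Index of the smallest element in a rotated sorted array.
    If nums[mid] > nums[hi], the minimum lies strictly to the right;
    otherwise the minimum is at mid or to the left."""
    lo, hi = 0, len(nums) - 1
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if nums[mid] > nums[hi]:
            lo = mid + 1
        else:
            hi = mid
    return lo


nums = [4, 5, 6, 7, 0, 1, 2]
print(find_rotation_pivot(nums))  # 4 (index of value 0)
print(find_rotation_pivot([1, 2, 3]))  # 0 (not rotated)
```

Once you know the pivot, you can also solve Problem 4 by running a plain binary search on whichever of the two sorted halves could contain the target.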
## Problem 5: Search a 2D Matrix
Problem: Search for a value in an m x n matrix where each row is sorted and the first element of each row is greater than the last element of the previous row.
ML Application: Searching in feature matrices, confusion matrices, or lookup tables for quantized embeddings.
```python
def search_matrix(matrix: list, target: int) -> bool:
    """Search in a sorted 2D matrix. LeetCode 74.

    Time: O(log(m * n))
    Space: O(1)

    Treat the 2D matrix as a virtual 1D sorted array.
    Index i in the 1D array maps to matrix[i // cols][i % cols].
    """
    if not matrix or not matrix[0]:
        return False
    rows, cols = len(matrix), len(matrix[0])
    lo, hi = 0, rows * cols - 1
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        val = matrix[mid // cols][mid % cols]  # Convert 1D index to 2D coordinates
        if val == target:
            return True
        elif val < target:
            lo = mid + 1
        else:
            hi = mid - 1
    return False


# Example
matrix = [
    [1, 3, 5, 7],
    [10, 11, 16, 20],
    [23, 30, 34, 60],
]
print(search_matrix(matrix, 3))   # True
print(search_matrix(matrix, 13))  # False
```
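An alternative decomposition worth knowing: binary search the first column to pick the row, then binary search within that row. Note that extracting the column costs O(m), so the virtual-1D version above is asymptotically tighter; this sketch is included only to show the two-step structure:

```python
from bisect import bisect_left, bisect_right


def search_matrix_two_step(matrix: list, target: int) -> bool:
    """Row-then-column search for a LeetCode-74-style matrix."""
    if not matrix or not matrix[0]:
        return False
    first_col = [row[0] for row in matrix]  # O(m) extraction
    r = bisect_right(first_col, target) - 1  # last row whose first element <= target
    if r < 0:
        return False
    row = matrix[r]
    i = bisect_left(row, target)
    return i < len(row) and row[i] == target


matrix = [[1, 3, 5, 7], [10, 11, 16, 20], [23, 30, 34, 60]]
print(search_matrix_two_step(matrix, 3))   # True
print(search_matrix_two_step(matrix, 13))  # False
```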
### Search in Row-and-Column Sorted Matrix
```python
def search_matrix_ii(matrix: list, target: int) -> bool:
    """Search a matrix where rows AND columns are sorted. LeetCode 240.

    Time: O(m + n)
    Space: O(1)

    Start from the top-right corner:
    - If the current value > target, move left (eliminate a column)
    - If the current value < target, move down (eliminate a row)
    """
    if not matrix or not matrix[0]:
        return False
    rows, cols = len(matrix), len(matrix[0])
    row, col = 0, cols - 1  # Start top-right
    while row < rows and col >= 0:
        if matrix[row][col] == target:
            return True
        elif matrix[row][col] > target:
            col -= 1  # Too big, go left
        else:
            row += 1  # Too small, go down
    return False


# Example: feature correlation matrix lookup
correlation = [
    [0.1, 0.3, 0.5, 0.7],
    [0.2, 0.4, 0.6, 0.8],
    [0.3, 0.5, 0.7, 0.9],
    [0.4, 0.6, 0.8, 1.0],
]
# Find whether correlation 0.6 exists
print(search_matrix_ii(correlation, 0.6))  # True
```
## Problem 6: Threshold Finding for ML Calibration
Problem: Given sorted prediction scores and labels, find the classification threshold that achieves a target recall (e.g., 95%).
ML Application: This is a real production ML task. When deploying a classifier, you need to set the decision threshold to meet business requirements.
```python
def find_threshold_for_recall(
    scores: list,
    labels: list,
    target_recall: float,
) -> float:
    """Find a classification threshold for a target recall using binary search.

    Time: O(n) recall check x 50 fixed iterations = O(n)
    Space: O(1)

    In practice this is much faster than a fine grid search over thresholds.
    """
    total_positives = sum(labels)
    if total_positives == 0:
        return 0.0

    def recall_at_threshold(threshold: float) -> float:
        """Fraction of positives with score >= threshold."""
        true_positives = sum(
            1 for score, label in zip(scores, labels)
            if score >= threshold and label == 1
        )
        return true_positives / total_positives

    # Binary search over threshold values in [0, 1].
    # Recall is non-increasing in the threshold, so the search is valid.
    lo, hi = 0.0, 1.0
    best_threshold = 0.0
    for _ in range(50):  # 50 halvings gives ~1e-15 precision
        mid = (lo + hi) / 2
        if recall_at_threshold(mid) >= target_recall:
            best_threshold = mid
            lo = mid  # Recall is met; try a higher (stricter) threshold
        else:
            hi = mid  # Recall too low; lower the threshold
    return best_threshold


# Example: find the threshold for 95% recall
import random

random.seed(42)
n = 1000
scores = [random.random() for _ in range(n)]
# Labels: higher scores are more likely to be positive
labels = [1 if s > 0.3 and random.random() > 0.2 else 0 for s in scores]

threshold = find_threshold_for_recall(scores, labels, target_recall=0.95)
print(f"Threshold for 95% recall: {threshold:.4f}")

# Verify
tp = sum(1 for s, l in zip(scores, labels) if s >= threshold and l == 1)
fn = sum(1 for s, l in zip(scores, labels) if s < threshold and l == 1)
actual_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
print(f"Actual recall: {actual_recall:.4f}")
```
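Because recall (with the `score >= threshold` convention used above) only changes at the scores of positive examples, there is also an exact alternative to the iterative float search: sort the positive scores descending and index directly. A sketch:

```python
import math


def exact_threshold_for_recall(scores: list, labels: list,
                               target_recall: float) -> float:
    """Exact threshold: recall only steps at positive-example scores,
    so sort those scores descending and return the k-th one, where
    k = ceil(target_recall * total_positives)."""
    pos_scores = sorted(
        (s for s, l in zip(scores, labels) if l == 1), reverse=True
    )
    if not pos_scores:
        return 0.0
    k = math.ceil(target_recall * len(pos_scores))
    return pos_scores[k - 1]  # score of the k-th highest positive


scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4]
labels = [1, 1, 0, 1, 0, 1]
print(exact_threshold_for_recall(scores, labels, 0.75))  # 0.6
```

This is O(n log n) for the sort and gives a threshold that sits exactly on a score boundary, which is often preferable for reproducibility; the binary-search version generalizes more easily to metrics that are expensive to invert.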
## Binary Search Patterns Cheat Sheet
| Pattern | lo/hi Init | Loop Condition | Update Rule |
|---|---|---|---|
| Exact match | 0, n-1 | lo <= hi | lo=mid+1 or hi=mid-1 |
| First occurrence | 0, n-1 | lo <= hi | Found: save, hi=mid-1 |
| Last occurrence | 0, n-1 | lo <= hi | Found: save, lo=mid+1 |
| Insert position | 0, n | lo < hi | hi=mid (not mid-1) |
| Peak element | 0, n-1 | lo < hi | Compare mid vs mid+1 |
| Float search | 0.0, 1.0 | fixed iterations | mid=(lo+hi)/2 |
- Always use `mid = lo + (hi - lo) // 2` to avoid integer overflow (matters in C++/Java, a good habit in Python)
- The difference between `lo <= hi` and `lo < hi` determines whether the loop processes single-element ranges
- For float binary search (thresholds, calibration), use a fixed number of iterations (50 is more than enough) instead of an epsilon check
- Binary search on the answer is a powerful pattern: "find the minimum X such that condition(X) is true"
- Every binary search variant is about defining what "left half" and "right half" mean for your specific problem
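The "binary search on the answer" bullet can be made concrete with a tiny generic helper, assuming `condition` is monotone over the integer range (all False, then all True):

```python
def min_satisfying(lo: int, hi: int, condition) -> int:
    """Smallest x in [lo, hi] with condition(x) true, assuming
    condition is monotone: False ... False True ... True."""
    while lo < hi:
        mid = lo + (hi - lo) // 2
        if condition(mid):
            hi = mid      # mid works; a smaller answer may still exist
        else:
            lo = mid + 1  # mid fails; answer must be larger
    return lo


# Example: ceiling of the square root = min x such that x*x >= n
n = 30
print(min_satisfying(0, n, lambda x: x * x >= n))  # 6
```

Many hard-looking problems (minimum ship capacity, minimum eating speed, smallest feasible batch size) reduce to writing the right `condition` and calling exactly this loop.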
Lilly Tech Systems