Advanced

Binary Search Patterns

Binary search reduces O(n) linear scans to O(log n) by halving the search space at each step. This lesson covers 6 problems that demonstrate the key binary search patterns you need for interviews.
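To make the halving concrete, the sketch below counts how many probes a standard binary search needs in its worst case. The helper name count_probes is hypothetical and exists only for this illustration; it searches for a value larger than every element, which always sends the search into the larger remaining half.

```python
def count_probes(n):
    """Count loop iterations when searching n sorted items for a missing, too-large target."""
    nums = range(n)      # a range is indexable and already sorted, so no list is allocated
    target = n           # larger than every element, so the search always moves right
    left, right = 0, n - 1
    probes = 0

    while left <= right:
        mid = left + (right - left) // 2
        probes += 1
        if nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1

    return probes

for n in (10, 1_000, 1_000_000):
    print(n, count_probes(n))
```

A million elements need only 20 probes, compared with up to a million comparisons for a linear scan.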

💡
Why this matters for AI/ML: Binary search is used in hyperparameter tuning (finding optimal learning rates on a log scale), threshold optimization for classification (finding the best probability cutoff), percentile computation in large datasets, and bisection methods in numerical optimization.

Problem 1: Classic Binary Search

📝
Problem: Given a sorted array of integers nums and a target, return the index of the target if found, or -1 if not found. Your solution must be O(log n).

Linear Search: O(n) Time

def search_linear(nums, target):
    """Scan every element - O(n), not efficient for sorted data."""
    for i, num in enumerate(nums):
        if num == target:
            return i
    return -1

Binary Search: O(log n) Time, O(1) Space

def binary_search(nums, target):
    """
    Classic binary search on sorted array.

    Invariant: target is always in [left, right] if it exists.
    - If nums[mid] == target: found it
    - If nums[mid] < target: target is in right half
    - If nums[mid] > target: target is in left half
    """
    left, right = 0, len(nums) - 1

    while left <= right:
        mid = left + (right - left) // 2  # Avoid overflow

        if nums[mid] == target:
            return mid
        elif nums[mid] < target:
            left = mid + 1
        else:
            right = mid - 1

    return -1

# Test
print(binary_search([-1, 0, 3, 5, 9, 12], 9))   # 4
print(binary_search([-1, 0, 3, 5, 9, 12], 2))    # -1
print(binary_search([5], 5))                       # 0

Common bug: Computing mid as (left + right) // 2 can overflow in languages with fixed-width integers (for example Java or C++) when left + right exceeds the largest representable int. Python integers have arbitrary precision, so the sum cannot overflow here, but left + (right - left) // 2 is a good habit that carries over to those languages and shows the interviewer you think about edge cases.
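Python cannot reproduce the overflow natively, so the sketch below simulates it. It assumes a signed 32-bit integer with wraparound (as in Java) and uses ctypes.c_int32 to wrap the sum; the function names are hypothetical and only serve the illustration.

```python
import ctypes

def midpoint_naive_int32(left, right):
    """(left + right) // 2 with the sum wrapped to a signed 32-bit integer."""
    wrapped_sum = ctypes.c_int32(left + right).value  # simulates fixed-width wraparound
    return wrapped_sum // 2

def midpoint_safe(left, right):
    """left + (right - left) // 2 never forms a value larger than right."""
    return left + (right - left) // 2

# Both indices fit in a signed 32-bit int (max 2_147_483_647), but their sum does not.
left, right = 2_000_000_000, 2_100_000_000
print(midpoint_naive_int32(left, right))  # a negative number: the sum wrapped around
print(midpoint_safe(left, right))         # 2050000000
```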

Problem 2: Search in Rotated Sorted Array

📝
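Problem: Given a sorted array of distinct integers that has been rotated at some pivot (e.g., [4,5,6,7,0,1,2]), search for a target value. Return its index or -1 if not found. You must achieve O(log n) time. (The values must be distinct: with duplicates, comparing nums[left] with nums[mid] no longer tells you which half is sorted.)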

Brute Force: O(n) Time

def search_rotated_brute(nums, target):
    """Linear scan - doesn't leverage the sorted property."""
    for i, num in enumerate(nums):
        if num == target:
            return i
    return -1

Optimal: O(log n) Time, O(1) Space

def search_rotated(nums, target):
    """
    Modified binary search for rotated sorted array.

    Key insight: At least one half of the array is always sorted.
    We can determine which half by comparing nums[left] with nums[mid].

    If left half is sorted (nums[left] <= nums[mid]):
      - Check if target is in [nums[left], nums[mid]]
      - If yes, search left half; otherwise search right half

    If right half is sorted:
      - Check if target is in [nums[mid], nums[right]]
      - If yes, search right half; otherwise search left half
    """
    left, right = 0, len(nums) - 1

    while left <= right:
        mid = left + (right - left) // 2

        if nums[mid] == target:
            return mid

        # Left half is sorted
        if nums[left] <= nums[mid]:
            if nums[left] <= target < nums[mid]:
                right = mid - 1  # Target in sorted left half
            else:
                left = mid + 1   # Target in right half
        # Right half is sorted
        else:
            if nums[mid] < target <= nums[right]:
                left = mid + 1   # Target in sorted right half
            else:
                right = mid - 1  # Target in left half

    return -1

# Test
print(search_rotated([4, 5, 6, 7, 0, 1, 2], 0))  # 4
print(search_rotated([4, 5, 6, 7, 0, 1, 2], 3))  # -1
print(search_rotated([1], 0))                       # -1
print(search_rotated([1], 1))                       # 0

AI/ML context: Rotated sorted arrays appear in circular buffers for streaming data in online learning. The pattern of identifying which half is “normal” and narrowing the search is also used in searching time-series data that wraps around (e.g., seasonal data).
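A closely related task is locating the rotation point itself, for example the offset of the oldest entry in a circular buffer. The sketch below uses the same "which side is in order" comparison. It is an illustration under the assumption of a non-empty array of distinct values, and find_rotation_index is a hypothetical name.

```python
def find_rotation_index(nums):
    """Index of the smallest element in a rotated sorted array of distinct values."""
    left, right = 0, len(nums) - 1

    while left < right:
        mid = left + (right - left) // 2
        if nums[mid] > nums[right]:
            left = mid + 1   # The drop is to the right of mid, so the minimum is too
        else:
            right = mid      # mid..right is in order; the minimum is at mid or to its left

    return left

print(find_rotation_index([4, 5, 6, 7, 0, 1, 2]))  # 4
print(find_rotation_index([1, 2, 3]))              # 0 (not rotated)
```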

Problem 3: Find Peak Element

📝
Problem: A peak element is strictly greater than its neighbors. Given an array where nums[i] != nums[i+1], find a peak element and return its index. Positions just outside the array are treated as -infinity, so the first or last element can be a peak. The array may contain multiple peaks; return any one. You must achieve O(log n) time.

Brute Force: O(n) Time

def find_peak_brute(nums):
    """Linear scan for peak."""
    for i in range(len(nums)):
        is_peak = True
        if i > 0 and nums[i] <= nums[i - 1]:
            is_peak = False
        if i < len(nums) - 1 and nums[i] <= nums[i + 1]:
            is_peak = False
        if is_peak:
            return i
    return len(nums) - 1

Optimal (Binary Search): O(log n) Time, O(1) Space

def find_peak_element(nums):
    """
    Binary search for peak.

    Key insight: If nums[mid] < nums[mid + 1], there MUST be
    a peak to the right (the array goes up, and must come down
    eventually or end at the rightmost element which would be a peak).

    Similarly, if nums[mid] < nums[mid - 1], peak is to the left.

    We are guaranteed to find a peak because:
    - Positions just outside the array (index -1 and index n in the
      mathematical sense, not Python's negative indexing) count as -infinity
    - The array must go up then down somewhere
    """
    left, right = 0, len(nums) - 1

    while left < right:
        mid = left + (right - left) // 2

        if nums[mid] < nums[mid + 1]:
            # Rising slope -> peak is to the right
            left = mid + 1
        else:
            # Falling slope -> peak is at mid or to the left
            right = mid

    return left  # left == right == peak index

# Test
print(find_peak_element([1, 2, 3, 1]))      # 2
print(find_peak_element([1, 2, 1, 3, 5, 6, 4]))  # 1 or 5
print(find_peak_element([1]))                 # 0

AI/ML context: Finding peaks is central to gradient-free optimization methods like golden section search. In signal processing for audio ML, peak detection identifies dominant frequencies. Loss landscape analysis also uses peak/valley finding to understand model convergence.
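For the continuous case mentioned above, a minimal golden section search might look like the sketch below. It assumes f is unimodal on [lo, hi] (rises, then falls), and the function name, the tolerance parameter, and the test function are illustrative choices rather than part of the lesson.

```python
import math

def golden_section_max(f, lo, hi, tol=1e-8):
    """Approximate the maximizer of a unimodal function f on [lo, hi]."""
    inv_phi = (math.sqrt(5) - 1) / 2          # about 0.618
    a, b = lo, hi
    c = b - inv_phi * (b - a)                 # left interior point
    d = a + inv_phi * (b - a)                 # right interior point
    fc, fd = f(c), f(d)

    while b - a > tol:
        if fc < fd:
            # Rising between c and d, so the maximum lies in [c, b]
            a = c
            c, fc = d, fd                     # old d becomes the new left point
            d = a + inv_phi * (b - a)
            fd = f(d)
        else:
            # Falling between c and d, so the maximum lies in [a, d]
            b = d
            d, fd = c, fc                     # old c becomes the new right point
            c = b - inv_phi * (b - a)
            fc = f(c)

    return (a + b) / 2

best_x = golden_section_max(lambda x: -(x - 2) ** 2, 0.0, 5.0)
print(round(best_x, 4))  # 2.0
```

Like the discrete version, each step discards the side that cannot contain the peak; the interval shrinks by a constant factor of about 0.618 instead of one half, and only one new function evaluation is needed per step.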

Problem 4: Search a 2D Matrix

📝
Problem: Given an m x n matrix where each row is sorted and the first element of each row is greater than the last element of the previous row, determine if a target value exists in the matrix.

Brute Force: O(m * n) Time

def search_matrix_brute(matrix, target):
    """Check every element."""
    for row in matrix:
        for val in row:
            if val == target:
                return True
    return False

Optimal (Treat as 1D Array): O(log(m*n)) Time, O(1) Space

def search_matrix(matrix, target):
    """
    Treat the 2D matrix as a flattened sorted 1D array.

    For a matrix with m rows and n columns:
    - 1D index 'mid' maps to matrix[mid // n][mid % n]
    - Total elements: m * n

    This works because each row's first element > previous row's last element,
    so the flattened array is fully sorted.
    """
    if not matrix or not matrix[0]:
        return False

    m, n = len(matrix), len(matrix[0])
    left, right = 0, m * n - 1

    while left <= right:
        mid = left + (right - left) // 2
        # Convert 1D index to 2D coordinates
        row = mid // n
        col = mid % n
        val = matrix[row][col]

        if val == target:
            return True
        elif val < target:
            left = mid + 1
        else:
            right = mid - 1

    return False

# Test
matrix = [
    [1, 3, 5, 7],
    [10, 11, 16, 20],
    [23, 30, 34, 60]
]
print(search_matrix(matrix, 3))    # True
print(search_matrix(matrix, 13))   # False
print(search_matrix(matrix, 60))   # True

AI/ML context: 2D matrix search is directly applicable to searching sorted lookup tables for quantized model weights, finding entries in embedding tables sorted by frequency, and searching pre-computed distance matrices in nearest neighbor algorithms.
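The index mapping used in search_matrix can be checked directly. The short sketch below flattens the example matrix through the same mid // n and mid % n arithmetic and confirms that the result is sorted; divmod returns both coordinates in one call.

```python
matrix = [
    [1, 3, 5, 7],
    [10, 11, 16, 20],
    [23, 30, 34, 60],
]
n = len(matrix[0])  # number of columns

# Visit every 1D index and read the matrix through the 1D -> 2D mapping
flat = [matrix[k // n][k % n] for k in range(len(matrix) * n)]
print(flat)                 # [1, 3, 5, 7, 10, 11, 16, 20, 23, 30, 34, 60]
print(flat == sorted(flat)) # True

row, col = divmod(6, n)     # same as (6 // n, 6 % n)
print(row, col, matrix[row][col])  # 1 2 16
```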

Problem 5: Median of Two Sorted Arrays

📝
Problem: Given two sorted arrays nums1 and nums2, return the median of the two sorted arrays. Your solution must be O(log(min(m, n))).

Brute Force (Merge): O(m + n) Time, O(m + n) Space

def find_median_brute(nums1, nums2):
    """Merge both arrays, then find middle element(s)."""
    merged = []
    i, j = 0, 0

    while i < len(nums1) and j < len(nums2):
        if nums1[i] <= nums2[j]:
            merged.append(nums1[i])
            i += 1
        else:
            merged.append(nums2[j])
            j += 1

    merged.extend(nums1[i:])
    merged.extend(nums2[j:])

    n = len(merged)
    if n % 2 == 1:
        return merged[n // 2]
    return (merged[n // 2 - 1] + merged[n // 2]) / 2

print(find_median_brute([1, 3], [2]))        # 2 (odd total: the middle element itself)
print(find_median_brute([1, 2], [3, 4]))     # 2.5

Optimal (Binary Search): O(log(min(m,n))) Time, O(1) Space

def find_median_sorted_arrays(nums1, nums2):
    """
    Binary search on the smaller array to find the correct partition.

    We partition both arrays such that:
    - Left partition has exactly half the total elements
    - All elements in left partition <= all elements in right partition

    Binary search on nums1's partition point. For each candidate:
    - Check if max(left1) <= min(right2) AND max(left2) <= min(right1)
    - If yes: found the correct partition -> compute median
    - If max(left1) > min(right2): partition too far right, move left
    - Otherwise: partition too far left, move right
    """
    # Always binary search on the smaller array
    if len(nums1) > len(nums2):
        nums1, nums2 = nums2, nums1

    m, n = len(nums1), len(nums2)
    left, right = 0, m
    half = (m + n + 1) // 2

    while left <= right:
        i = (left + right) // 2   # Partition point in nums1
        j = half - i               # Partition point in nums2

        # Edge values (use -inf/+inf for boundaries)
        left1 = nums1[i - 1] if i > 0 else float('-inf')
        right1 = nums1[i] if i < m else float('inf')
        left2 = nums2[j - 1] if j > 0 else float('-inf')
        right2 = nums2[j] if j < n else float('inf')

        if left1 <= right2 and left2 <= right1:
            # Correct partition found
            if (m + n) % 2 == 1:
                return max(left1, left2)
            return (max(left1, left2) + min(right1, right2)) / 2
        elif left1 > right2:
            right = i - 1  # Too many elements from nums1
        else:
            left = i + 1   # Too few elements from nums1

# Test
print(find_median_sorted_arrays([1, 3], [2]))         # 2 (odd total: returned as-is)
print(find_median_sorted_arrays([1, 2], [3, 4]))      # 2.5
print(find_median_sorted_arrays([0, 0], [0, 0]))      # 0.0
print(find_median_sorted_arrays([], [1]))              # 1
print(find_median_sorted_arrays([2], []))              # 2

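AI/ML context: Computing medians efficiently across distributed datasets matters in federated learning and data parallelism. When training data is split across multiple machines, gathering everything on one node to sort it is often impractical. The partition idea used here (choose a split, then check that everything on the left is no larger than everything on the right) suggests how a global median can be located by exchanging a few boundary values instead of the full data.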
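The partition condition at the heart of the optimal solution can be inspected on its own. The sketch below is an illustration only: partition_is_valid is a hypothetical helper that assumes both inputs are sorted and that the first array is no longer than the second, and it reports which split points satisfy the condition.

```python
import statistics

def partition_is_valid(a, b, i):
    """True if taking the first i items of a (and the rest of the left half from b)
    puts every left-side value at or below every right-side value.
    Assumes a and b are sorted and len(a) <= len(b)."""
    half = (len(a) + len(b) + 1) // 2
    j = half - i  # how many items the left half takes from b

    left_a = a[i - 1] if i > 0 else float("-inf")
    right_a = a[i] if i < len(a) else float("inf")
    left_b = b[j - 1] if j > 0 else float("-inf")
    right_b = b[j] if j < len(b) else float("inf")

    return left_a <= right_b and left_b <= right_a

a, b = [2], [1, 3]
for i in range(len(a) + 1):
    print(i, partition_is_valid(a, b, i))
# 0 False  (left half would be [1, 3], but 3 > 2)
# 1 True   (left half is [2] and [1]; the median is max(2, 1) = 2)

print(statistics.median(a + b))  # 2, the same value from the standard library
```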

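Problem 6: Find Maximum Threshold (Binary Search on Answer)

📝
Problem: Given an array of scores and a target pass rate, find the highest threshold such that at least target_percent of scores are at or above the threshold. Scores range from 0 to 100, and the threshold is an integer in the same range. (The lowest such threshold would always be 0, so the interesting boundary is the highest one.)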

Brute Force: O(n * R) Time where R is the range

def find_threshold_brute(scores, target_percent):
    """Try every possible threshold from 0 to 100."""
    for threshold in range(101):
        pass_count = sum(1 for s in scores if s >= threshold)
        pass_rate = pass_count / len(scores) * 100
        if pass_rate < target_percent:
            return threshold - 1
    return 100

Optimal (Binary Search on Answer): O(n log R) Time, O(1) Space

def find_threshold(scores, target_percent):
    """
    Binary search on the answer (threshold value).

    This is the "binary search on answer" pattern:
    - Define a condition function: does this threshold achieve
      the target pass rate?
    - Binary search for the boundary where the condition flips.

    The condition is monotonic: as threshold increases,
    pass rate decreases. We want the highest threshold
    that still meets the target pass rate.
    """
    def passes_target(threshold):
        """Check if this threshold achieves target pass rate."""
        pass_count = sum(1 for s in scores if s >= threshold)
        return (pass_count / len(scores) * 100) >= target_percent

    left, right = 0, 100
    result = 0

    while left <= right:
        mid = left + (right - left) // 2

        if passes_target(mid):
            result = mid       # This threshold works, try higher
            left = mid + 1
        else:
            right = mid - 1    # Too high, try lower

    return result

# Test
scores = [85, 92, 78, 95, 88, 70, 65, 90, 82, 77]

# Find threshold where at least 60% of students pass
print(find_threshold(scores, 60))  # 82 (six of ten scores are >= 82)

# Find threshold where at least 80% pass
print(find_threshold(scores, 80))  # 77 (eight of ten scores are >= 77)

# Find threshold where 100% pass
print(find_threshold(scores, 100))  # 65
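
For this particular problem the answer can also be read off a sorted copy of the scores, which makes a handy cross-check for the binary search. The sketch below is a different technique (sorting, O(n log n)), not binary search on the answer. It assumes target_percent is an integer with 0 < target_percent <= 100, and find_threshold_by_sorting is a hypothetical name.

```python
def find_threshold_by_sorting(scores, target_percent):
    """Highest threshold that at least target_percent of scores meet.
    Assumes target_percent is an integer with 0 < target_percent <= 100."""
    n = len(scores)
    # Number of scores that must pass: ceil(target_percent * n / 100), integers only
    k = -(-target_percent * n // 100)
    # At least k scores are >= t exactly when t <= the k-th largest score
    return sorted(scores, reverse=True)[k - 1]

scores = [85, 92, 78, 95, 88, 70, 65, 90, 82, 77]
print(find_threshold_by_sorting(scores, 60))   # 82
print(find_threshold_by_sorting(scores, 80))   # 77
print(find_threshold_by_sorting(scores, 100))  # 65
```

Binary search on the answer remains the more general tool: it only needs a yes/no check per candidate, so it still works when no closed-form shortcut like this exists.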

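AI/ML context: This “binary search on answer” pattern carries over to classification thresholds. Given a trained model's predicted probabilities, recall never increases as the threshold rises, so you can binary search for the highest probability threshold that still meets a target recall. Precision and F1 are not guaranteed to change monotonically with the threshold, so for those metrics binary search is only a heuristic and a scan over candidate thresholds is safer. The same pattern can be used in hyperparameter tuning (for example, regularization strength) whenever the quantity being checked changes monotonically.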
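As a minimal sketch of the classification case, the code below binary searches a continuous threshold for a target recall, which is safe because recall can only stay the same or fall as the threshold rises. The function names, the toy data, and the fixed iteration count are assumptions made for illustration, and the labels are assumed to contain at least one positive example.

```python
def recall_at(threshold, probs, labels):
    """Recall when predicting positive for every probability >= threshold."""
    true_pos = sum(1 for p, y in zip(probs, labels) if y == 1 and p >= threshold)
    return true_pos / sum(labels)  # assumes at least one positive label

def highest_threshold_for_recall(probs, labels, target_recall, iters=50):
    """Largest threshold in [0, 1] (up to float precision) whose recall meets the target."""
    lo, hi = 0.0, 1.0   # threshold 0.0 predicts everything positive, so recall is 1.0
    for _ in range(iters):
        mid = (lo + hi) / 2
        if recall_at(mid, probs, labels) >= target_recall:
            lo = mid    # Still meets the target, try a higher threshold
        else:
            hi = mid    # Recall dropped too far, go lower
    return lo

probs = [0.9, 0.8, 0.65, 0.4, 0.3, 0.2]
labels = [1, 1, 0, 1, 0, 1]

# Three of the four positives must be caught, so the threshold can rise to about 0.4
threshold = highest_threshold_for_recall(probs, labels, target_recall=0.75)
print(round(threshold, 6))  # 0.4
```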

Summary: Complexity Comparison

Problem               | Brute Force | Optimal           | Key Pattern
----------------------|-------------|-------------------|------------------------
Classic Binary Search | O(n)        | O(log n)          | Standard binary search
Search Rotated Array  | O(n)        | O(log n)          | Identify sorted half
Find Peak             | O(n)        | O(log n)          | Follow rising slope
Search 2D Matrix      | O(m*n)      | O(log(m*n))       | Flatten to 1D index
Median of Two Arrays  | O(m+n)      | O(log(min(m,n)))  | Binary partition
Threshold Finding     | O(n*R)      | O(n log R)        | Binary search on answer

💡
Pattern recognition: Binary search applies whenever you have a monotonic condition (as input increases, the condition flips from true to false or vice versa). The three main patterns are: (1) search in sorted data, (2) search in rotated/modified sorted data, (3) binary search on the answer space.
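
The monotonic-condition view can be captured in one reusable template. The sketch below assumes an integer search range and a predicate that is False up to some point and True from there on; first_true is a hypothetical helper name, not a standard library function.

```python
def first_true(lo, hi, pred):
    """Smallest x in [lo, hi] with pred(x) True; returns hi + 1 if none is.
    Assumes pred is monotonic: False for small x, then True from some point on."""
    while lo <= hi:
        mid = lo + (hi - lo) // 2
        if pred(mid):
            hi = mid - 1   # mid works, but an earlier x might too
        else:
            lo = mid + 1   # mid fails, so the boundary is to the right
    return lo

# Pattern 1: lower bound in sorted data (first index whose value is >= 3)
nums = [1, 3, 3, 5]
print(first_true(0, len(nums) - 1, lambda i: nums[i] >= 3))  # 1

# Pattern 3: search on the answer space (first x whose square exceeds 30)
print(first_true(0, 30, lambda x: x * x > 30))               # 6
```

When the condition instead flips from True to False, as in the threshold problem, search for the first x where it is False and subtract one.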