Advanced

Dijkstra & Advanced Graph Algorithms

Weighted graphs model real-world costs — network latency, API call pricing, model inference time. Dijkstra's algorithm finds the cheapest path, minimum spanning trees connect all nodes at minimum cost, and union-find efficiently tracks connected components. These algorithms power routing in microservice architectures, cost optimization in multi-model AI systems, and clustering in graph-based ML.

Problem 1: Dijkstra's Shortest Path

🎯
Problem: Given a weighted directed graph with n nodes and non-negative edge weights, find the shortest-path distance from a source node to every other node.
ML Context: In multi-model serving architectures, different inference paths have different latencies and costs. Dijkstra finds the optimal routing: e.g., route through a cached embedding lookup (cheap) instead of a full model inference (expensive) when possible.
import heapq
from collections import defaultdict

def dijkstra(n: int, edges: list, source: int) -> list:
    """Single-source shortest distances using Dijkstra and a binary heap.

    Args:
        n: number of nodes, labeled 0..n-1.
        edges: directed (u, v, w) triples; every w must be non-negative.
        source: node the distances are measured from.

    Returns:
        A list whose i-th entry is the cheapest cost from source to i,
        or float('inf') when i cannot be reached.

    Time: O((V + E) log V), Space: O(V + E).
    """
    adjacency = defaultdict(list)
    for origin, dest, weight in edges:
        adjacency[origin].append((dest, weight))

    best = [float('inf') for _ in range(n)]
    best[source] = 0
    frontier = [(0, source)]  # heap of (cost so far, node)

    while frontier:
        cost, node = heapq.heappop(frontier)
        if cost > best[node]:
            # Lazy deletion: a cheaper entry for this node was already settled.
            continue
        for nxt, weight in adjacency[node]:
            candidate = best[node] + weight
            if candidate < best[nxt]:
                best[nxt] = candidate
                heapq.heappush(frontier, (candidate, nxt))

    return best

def dijkstra_with_path(n: int, edges: list, source: int, target: int) -> tuple:
    """Cheapest source->target route, found with Dijkstra plus parent links.

    The search stops as soon as target is popped from the heap, because its
    distance is final at that point (weights are non-negative).

    Returns:
        (distance, path), where path lists nodes from source to target.
        Returns (float('inf'), []) when target is unreachable.
    """
    adjacency = defaultdict(list)
    for origin, dest, weight in edges:
        adjacency[origin].append((dest, weight))

    INF = float('inf')
    best = [INF] * n
    best[source] = 0
    prev = [-1] * n  # predecessor on the best known route; -1 = none
    frontier = [(0, source)]

    while frontier:
        cost, node = heapq.heappop(frontier)
        if cost > best[node]:
            continue  # stale heap entry
        if node == target:
            break
        for nxt, weight in adjacency[node]:
            candidate = best[node] + weight
            if candidate < best[nxt]:
                best[nxt] = candidate
                prev[nxt] = node
                heapq.heappush(frontier, (candidate, nxt))

    if best[target] == INF:
        return INF, []

    # Walk predecessor links back from target, then flip into forward order.
    route = []
    step = target
    while step != -1:
        route.append(step)
        step = prev[step]
    route.reverse()
    return best[target], route

# Test: 4 nodes, directed weighted edges (u, v, w).
# The route 0->2->1 (1 + 2 = 3) beats the direct 0->1 edge (4), and node 3 is reached via 1.
edges = [(0,1,4), (0,2,1), (2,1,2), (1,3,1), (2,3,5)]
print(dijkstra(4, edges, 0))                    # [0, 3, 1, 4]
print(dijkstra_with_path(4, edges, 0, 3))       # (4, [0, 2, 1, 3])

Problem 2: Network Delay Time

🎯
Problem: You have n network nodes labeled 1 to n and a list of travel times as directed edges (u, v, time). Given a source node k, find the time it takes for all nodes to receive a signal. Return -1 if not all nodes can be reached.
ML Context: This models latency propagation in distributed inference systems. When a request hits a model ensemble, the total response time is the time taken by the slowest model to respond, i.e. the longest of the shortest paths from the orchestrator to the model servers.
def network_delay_time(times: list, n: int, k: int) -> int:
    """Time until every node has received a signal broadcast from node k.

    Runs Dijkstra from k and returns the largest shortest-path arrival time.
    Arrival times live in a dict, and reachability is judged by comparing
    the number of reached nodes with n. This presumes nodes are labeled
    1..n, as in the problem statement.

    Returns:
        The maximum arrival time, or -1 if some node is never reached.

    Time: O((V + E) log V), Space: O(V + E).
    """
    adjacency = defaultdict(list)
    for origin, dest, delay in times:
        adjacency[origin].append((dest, delay))

    inf = float('inf')
    arrival = {k: 0}
    pending = [(0, k)]  # heap of (arrival time, node)

    while pending:
        elapsed, node = heapq.heappop(pending)
        if elapsed > arrival.get(node, inf):
            continue  # superseded by an earlier arrival
        for nxt, delay in adjacency[node]:
            candidate = elapsed + delay
            if candidate < arrival.get(nxt, inf):
                arrival[nxt] = candidate
                heapq.heappush(pending, (candidate, nxt))

    return max(arrival.values()) if len(arrival) == n else -1

# Test: broadcast from node 2 over unit-delay directed edges.
times = [(2,1,1), (2,3,1), (3,4,1)]
print(network_delay_time(times, 4, 2))  # 2
# Nodes 1 and 3 hear the signal at t=1 (directly from node 2);
# node 4 hears it at t=2 via 2 -> 3 -> 4, so the answer is 2.

Problem 3: Cheapest Flights Within K Stops

🎯
Problem: Given n cities and flights [from, to, price], find the cheapest price from source to destination using at most k stops. Return -1 if no such route exists.
ML Context: This models cost-constrained inference routing. You want the cheapest way to process a query through a chain of models, but with a latency constraint (max k hops). For example: route through at most 2 intermediate services to keep p99 latency under 200ms.
def find_cheapest_price(n: int, flights: list, src: int, dst: int, k: int) -> int:
    """Cheapest src->dst fare using at most k stops (that is, k + 1 flights).

    Bounded Bellman-Ford: every round relaxes all flights against a snapshot
    of the previous round's costs. As a result, after round r each cost uses
    at most r flights.

    Returns:
        The cheapest fare, or -1 if dst cannot be reached within the limit.

    Time: O(k * E), Space: O(V).
    """
    INF = float('inf')
    cost = [INF] * n
    cost[src] = 0

    for _ in range(k + 1):
        # Relax from `cost` (the last round) and write into a fresh copy, so a
        # single round can never chain two flights together.
        nxt = list(cost)
        for origin, dest, fare in flights:
            if cost[origin] == INF:
                continue
            nxt[dest] = min(nxt[dest], cost[origin] + fare)
        cost = nxt

    answer = cost[dst]
    return -1 if answer == INF else answer

def find_cheapest_bfs(n: int, flights: list, src: int, dst: int, k: int) -> int:
    """Cheapest src->dst fare with at most k stops, using BFS over stop levels.

    The FIFO queue handles partial routes in nondecreasing order of flights
    taken. A route is enqueued only if it beats the best cost seen so far for
    its city. Any route it fails to beat was found with no more flights, so
    extending the pruned route can never do better.

    Args:
        n: number of cities, labeled 0..n-1.
        flights: directed (from, to, price) triples.
        src: departure city.
        dst: destination city.
        k: maximum number of intermediate stops (k + 1 flights).

    Returns:
        The cheapest total price, or -1 if dst is unreachable within k stops.

    Time: O(k * E), Space: O(V * k) worst case.
    """
    # deque is not imported at file level (only defaultdict is), so import it
    # here; without this, every call raised NameError.
    from collections import deque

    graph = defaultdict(list)
    for u, v, price in flights:
        graph[u].append((v, price))

    best = [float('inf')] * n
    best[src] = 0
    queue = deque([(src, 0, 0)])  # (node, cost, flights taken so far)

    while queue:
        node, cost, stops = queue.popleft()
        if stops > k:
            continue  # already used k + 1 flights; cannot extend further
        for neighbor, price in graph[node]:
            new_cost = cost + price
            if new_cost < best[neighbor]:
                best[neighbor] = new_cost
                queue.append((neighbor, new_cost, stops + 1))

    return best[dst] if best[dst] != float('inf') else -1

# Test: with k=1 the two-flight route 0->1->2 (200) is allowed;
# with k=0 only the direct flight 0->2 (500) qualifies.
flights = [(0,1,100), (1,2,100), (0,2,500)]
print(find_cheapest_price(3, flights, 0, 2, 1))  # 200 (0->1->2)
print(find_cheapest_price(3, flights, 0, 2, 0))  # 500 (direct 0->2)

Problem 4: Minimum Spanning Tree

🎯
Problem: Given n nodes and weighted undirected edges, find the minimum spanning tree (MST) — the subset of edges that connects all nodes with minimum total weight.
ML Context: MST is used in hierarchical clustering (single-linkage clustering is equivalent to MST), feature selection (minimum redundancy graphs), and network design for distributed training (connecting GPU nodes with minimum communication cost).
def kruskal_mst(n: int, edges: list) -> tuple:
    """Minimum spanning tree via Kruskal's greedy edge selection.

    Edges are scanned from cheapest to most expensive. An edge is kept only
    when its endpoints are still in different components; a local union-find
    (union by rank, path halving) tracks the components. Scanning stops early
    once n - 1 edges are kept. On a disconnected graph every merging edge is
    still kept, so the result is a minimum spanning forest.

    Args:
        n: number of nodes, labeled 0..n-1.
        edges: undirected (u, v, w) weighted edges.

    Returns:
        (total_weight, mst_edges), where mst_edges lists the kept (u, v, w)
        tuples in the order they were chosen.

    Time: O(E log E), Space: O(V).
    """
    root_of = list(range(n))
    height = [0] * n

    def find(node):
        # Path halving: hop each visited node up to its grandparent.
        while root_of[node] != node:
            root_of[node] = root_of[root_of[node]]
            node = root_of[node]
        return node

    def union(a, b):
        ra = find(a)
        rb = find(b)
        if ra == rb:
            return False  # already connected: this edge would close a cycle
        if height[ra] < height[rb]:
            ra, rb = rb, ra
        root_of[rb] = ra
        if height[ra] == height[rb]:
            height[ra] += 1
        return True

    chosen = []
    total = 0
    for edge in sorted(edges, key=lambda e: e[2]):
        u, v, w = edge
        if not union(u, v):
            continue
        chosen.append((u, v, w))
        total += w
        if len(chosen) == n - 1:
            break

    return total, chosen

def prim_mst(n: int, edges: list) -> int:
    """Total MST weight via Prim's algorithm, grown outward from node 0.

    Repeatedly takes the cheapest heap edge leading to an unvisited node.
    Only the component containing node 0 is explored. On a disconnected
    graph the result is therefore the MST weight of node 0's component, not
    of the whole graph.

    Args:
        n: number of nodes, labeled 0..n-1.
        edges: undirected (u, v, w) weighted edges.

    Returns:
        The total weight of the spanning tree grown from node 0. The return
        value is an int weight (the old `-> tuple` annotation was wrong);
        use kruskal_mst if you also need the edge list.

    Time: O(E log V), Space: O(V + E).
    """
    graph = defaultdict(list)
    for u, v, w in edges:
        # Store (weight, neighbor) so heap entries order by weight first.
        graph[u].append((w, v))
        graph[v].append((w, u))

    visited = set()
    heap = [(0, 0)]  # (weight of edge used to reach node, node); start at node 0
    total_weight = 0

    while heap and len(visited) < n:
        weight, u = heapq.heappop(heap)
        if u in visited:
            continue  # a cheaper edge already brought u into the tree
        visited.add(u)
        total_weight += weight
        for w, v in graph[u]:
            if v not in visited:
                heapq.heappush(heap, (w, v))

    return total_weight

# Test: both algorithms pick edges (0,2,1), (1,2,2), (2,3,3) for a total weight of 6;
# the heavier edges (0,1,4) and (1,3,5) would each close a cycle.
edges = [(0,1,4), (0,2,1), (1,2,2), (1,3,5), (2,3,3)]
print(kruskal_mst(4, edges))  # (6, [(0,2,1), (1,2,2), (2,3,3)])
print(prim_mst(4, edges))     # 6

Problem 5: Union-Find (Disjoint Set Union)

🎯
Problem: Implement a Union-Find data structure that supports: (1) union(x, y) — merge the sets containing x and y, (2) find(x) — return the representative of x's set, (3) connected(x, y) — check if x and y are in the same set.
ML Context: Union-Find powers DBSCAN clustering (merge density-connected points), image segmentation (merge similar adjacent pixels), and duplicate detection in data pipelines (merge records that match on any of several keys).
class UnionFind:
    """Disjoint-set forest with path halving and union by rank.

    Attributes:
        parent: parent[i] is the parent of i; a root is its own parent.
        rank: rank[i] is an upper bound on the height of the tree rooted at i.
        count: number of disjoint sets currently tracked.

    Amortized cost per operation is nearly O(1) (inverse Ackermann).
    """
    def __init__(self, n: int):
        self.parent = [i for i in range(n)]
        self.rank = [0 for _ in range(n)]
        self.count = n  # every element starts in its own set

    def find(self, x: int) -> int:
        """Return the root of x's tree, halving the path along the way."""
        parent = self.parent
        node = x
        while parent[node] != node:
            grand = parent[parent[node]]
            parent[node] = grand  # skip one level: point at the grandparent
            node = grand
        return node

    def union(self, x: int, y: int) -> bool:
        """Join the sets holding x and y; return False if they were already one."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Attach the lower-rank root beneath the higher-rank one.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.count -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        """True when x and y share a root."""
        root_x = self.find(x)
        root_y = self.find(y)
        return root_x == root_y

# Application: Redundant Connection
# Given a tree with one extra edge (creating exactly one cycle),
# find the edge that can be removed to make it a tree again.
def find_redundant_connection(edges: list) -> list:
    """Return the first edge, in input order, whose endpoints are already connected.

    Nodes are taken to be labeled 1..n, where n == len(edges) (a tree on n
    nodes plus one extra edge). Returns [] if no edge closes a cycle.
    Time: O(E * alpha(V)), Space: O(V).
    """
    dsu = UnionFind(len(edges) + 1)  # index 0 unused: labels are 1-based

    for a, b in edges:
        if dsu.union(a, b):
            continue
        return [a, b]  # a and b were already joined, so this edge closes a cycle

    return []

# Test: after the first two unions the sets are {0,1}, {2,3}, {4}.
uf = UnionFind(5)
uf.union(0, 1)
uf.union(2, 3)
print(uf.connected(0, 1))  # True
print(uf.connected(0, 2))  # False
uf.union(1, 3)
print(uf.connected(0, 2))  # True (0-1-3-2 are now connected)
print(uf.count)             # 2 (one component: {0,1,2,3}, one: {4})

# (1,2) and (1,3) already connect 2 and 3, so (2,3) closes the cycle.
edges = [(1,2), (1,3), (2,3)]
print(find_redundant_connection(edges))  # [2, 3]

Algorithm Selection Guide

| Algorithm | Time | Use When |
|---|---|---|
| BFS | O(V + E) | Shortest path in unweighted graph |
| Dijkstra | O((V + E) log V) | Shortest path with non-negative weights |
| Bellman-Ford | O(V * E) | Shortest path with negative weights, or k-stop constraint (O(k * E)) |
| DAG shortest path | O(V + E) | Shortest path when graph is a DAG |
| Kruskal (MST) | O(E log E) | Minimum spanning tree, sparse graphs |
| Prim (MST) | O(E log V) | Minimum spanning tree, dense graphs |
| Union-Find | O(α(V)) amortized | Dynamic connectivity, cycle detection, MST |
💡
Interview tip: When you see "shortest path", immediately ask: Are edges weighted? (Yes = Dijkstra/Bellman-Ford, No = BFS). Are there negative weights? (Yes = Bellman-Ford, No = Dijkstra). Is the graph a DAG? (Yes = topological sort approach, O(V+E)). Is there a constraint on number of edges? (Yes = modified Bellman-Ford).