Advanced
Dijkstra & Advanced Graph Algorithms
Weighted graphs model real-world costs — network latency, API call pricing, model inference time. Dijkstra's algorithm finds the cheapest path, minimum spanning trees connect all nodes at minimum cost, and union-find efficiently tracks connected components. These algorithms power routing in microservice architectures, cost optimization in multi-model AI systems, and clustering in graph-based ML.
Problem 1: Dijkstra's Shortest Path
Problem: Given a weighted directed graph with n nodes and non-negative edge weights, find the shortest-path distance from a source node to every other node (and, optionally, the actual path to a chosen target).
ML Context: In multi-model serving architectures, different inference paths have different latencies and costs. Dijkstra finds the cheapest routing: e.g., route through a cached embedding lookup (cheap) instead of a full model inference (expensive) when possible.
import heapq
from collections import defaultdict, deque
def dijkstra(n: int, edges: list, source: int) -> list:
    """Return the shortest distance from ``source`` to every node.

    Uses a binary-heap priority queue with lazy deletion: outdated heap
    entries are skipped when popped instead of being removed eagerly.
    Edge weights must be non-negative. Nodes that cannot be reached keep
    the value ``float('inf')``.

    Time: O((V + E) log V), Space: O(V + E).
    """
    adjacency = defaultdict(list)
    for start, end, weight in edges:
        adjacency[start].append((end, weight))

    best = [float('inf')] * n
    best[source] = 0
    frontier = [(0, source)]  # (distance so far, node)

    while frontier:
        cost, node = heapq.heappop(frontier)
        if cost > best[node]:
            continue  # outdated entry: a shorter route was already recorded
        for nxt, weight in adjacency[node]:
            candidate = best[node] + weight
            if candidate < best[nxt]:
                best[nxt] = candidate
                heapq.heappush(frontier, (candidate, nxt))
    return best
def dijkstra_with_path(n: int, edges: list, source: int, target: int) -> tuple:
    """Shortest distance and route from ``source`` to ``target``.

    Runs Dijkstra (non-negative weights only). It records each node's
    predecessor on its best-known route and stops as soon as ``target``
    is popped from the heap, because its distance is final at that point.

    Returns:
        ``(distance, path)``. ``path`` lists nodes from ``source`` to
        ``target`` inclusive. If ``target`` is unreachable, returns
        ``(float('inf'), [])``.
    """
    adjacency = defaultdict(list)
    for start, end, weight in edges:
        adjacency[start].append((end, weight))

    best = [float('inf')] * n
    best[source] = 0
    prev = [-1] * n  # predecessor on the best route; -1 means "none"
    frontier = [(0, source)]

    while frontier:
        cost, node = heapq.heappop(frontier)
        if cost > best[node]:
            continue  # outdated heap entry
        if node == target:
            break  # target settled; no need to explore further
        for nxt, weight in adjacency[node]:
            candidate = best[node] + weight
            if candidate < best[nxt]:
                best[nxt] = candidate
                prev[nxt] = node
                heapq.heappush(frontier, (candidate, nxt))

    if best[target] == float('inf'):
        return float('inf'), []

    # Walk predecessor links back from the target, then flip the order.
    route = []
    cursor = target
    while cursor != -1:
        route.append(cursor)
        cursor = prev[cursor]
    route.reverse()
    return best[target], route
# Demo: shortest paths on a small 4-node directed graph.
edges = [(0, 1, 4), (0, 2, 1), (2, 1, 2), (1, 3, 1), (2, 3, 5)]
print(dijkstra(4, edges, 0))  # [0, 3, 1, 4]
print(dijkstra_with_path(4, edges, 0, 3))  # (4, [0, 2, 1, 3]) i.e. 0 -> 2 -> 1 -> 3
Problem 2: Network Delay Time
Problem: You have n network nodes and a list of travel times as directed edges (u, v, time). Given a source node k, find the time it takes for all nodes to receive a signal. Return -1 if not all nodes can be reached.
ML Context: This models latency propagation in distributed inference systems. When a request hits a model ensemble, the total response time is the maximum time for any model to respond (the longest shortest path from the orchestrator to each model server).
def network_delay_time(times: list, n: int, k: int) -> int:
    """Time until a signal sent from node ``k`` has reached every node.

    Runs Dijkstra from ``k`` over the directed, weighted edges in
    ``times`` (``(source, target, delay)`` triples). It returns the
    largest shortest-path arrival time. If the number of distinct nodes
    reached (including ``k``) is not exactly ``n``, it returns -1.

    Time: O((V + E) log V), Space: O(V + E).
    """
    neighbors = defaultdict(list)
    for src, dst, delay in times:
        neighbors[src].append((dst, delay))

    inf = float('inf')
    arrival = {k: 0}  # best-known arrival time for every node reached so far
    pq = [(0, k)]
    while pq:
        t, node = heapq.heappop(pq)
        if t > arrival.get(node, inf):
            continue  # stale entry; a faster arrival was already recorded
        for nxt, delay in neighbors[node]:
            reach = t + delay
            if reach < arrival.get(nxt, inf):
                arrival[nxt] = reach
                heapq.heappush(pq, (reach, nxt))

    return max(arrival.values()) if len(arrival) == n else -1
# Demo: the signal starts at node 2.
times = [(2, 1, 1), (2, 3, 1), (3, 4, 1)]
print(network_delay_time(times, 4, 2))  # 2
# Arrival times: node 1 at t=1 (2->1), node 3 at t=1 (2->3),
# node 4 at t=2 (2->3->4). Every node has the signal by t=2.
Problem 3: Cheapest Flights Within K Stops
Problem: Given n cities and flights [from, to, price], find the cheapest price from source to destination using at most k stops (intermediate cities, i.e. at most k + 1 flights). Return -1 if no such route exists.
ML Context: This models cost-constrained inference routing. You want the cheapest way to process a query through a chain of models, but with a latency constraint (max k hops). For example: route through at most 2 intermediate services to keep p99 latency under 200ms.
def find_cheapest_price(n: int, flights: list, src: int, dst: int, k: int) -> int:
    """Cheapest fare from ``src`` to ``dst`` using at most ``k`` stops.

    Bounded Bellman-Ford. Each round extends routes by one flight, and
    at most ``k`` stops means at most ``k + 1`` flights, so ``k + 1``
    rounds are run. Each round reads only the previous round's fares, so
    one round never chains several flights together.

    Returns -1 when no qualifying route exists.
    Time: O(k * E), Space: O(V).
    """
    inf = float('inf')
    cost = [inf] * n
    cost[src] = 0
    for _round in range(k + 1):
        nxt = list(cost)  # snapshot: relaxations this round must not cascade
        for frm, to, fare in flights:
            if cost[frm] == inf:
                continue  # departure city not reachable yet
            nxt[to] = min(nxt[to], cost[frm] + fare)
        cost = nxt
    return -1 if cost[dst] == inf else cost[dst]
def find_cheapest_bfs(n: int, flights: list, src: int, dst: int, k: int) -> int:
    """Cheapest fare from ``src`` to ``dst`` with at most ``k`` stops, via BFS.

    Explores routes level by level, where a level is the number of flights
    taken. A queue entry is only expanded if it has used at most ``k``
    flights, so the longest route considered has ``k + 1`` flights
    (``k`` intermediate stops).

    ``best[v]`` holds the cheapest fare seen for ``v`` at any level so far.
    A new entry for ``v`` is enqueued only when it beats ``best[v]``. The
    FIFO queue processes levels in increasing order, so a pruned entry is
    always matched or beaten by one that uses no more flights.

    Returns -1 when no qualifying route exists.
    Complexity: every queued entry relaxes all out-edges of its node. A
    node can be queued more than once per level, so the worst case can
    exceed the O(k * E) of the Bellman-Ford variant. Memory is O(V + E)
    for the graph and ``best``, plus the queue.
    """
    graph = defaultdict(list)
    for u, v, price in flights:
        graph[u].append((v, price))

    best = [float('inf')] * n
    best[src] = 0
    queue = deque([(src, 0, 0)])  # (node, cost, flights taken so far)
    while queue:
        node, cost, stops = queue.popleft()
        if stops > k:
            continue  # already used k + 1 flights; cannot fly further
        # Deliberately NOT skipping entries with cost > best[node]: a pricier
        # entry may have used fewer flights and still reach dst within budget.
        for neighbor, price in graph[node]:
            new_cost = cost + price
            if new_cost < best[neighbor]:
                best[neighbor] = new_cost
                queue.append((neighbor, new_cost, stops + 1))
    return best[dst] if best[dst] != float('inf') else -1
# Demo: two candidate routes from city 0 to city 2.
flights = [(0, 1, 100), (1, 2, 100), (0, 2, 500)]
print(find_cheapest_price(3, flights, 0, 2, 1))  # 200 (0 -> 1 -> 2, one stop)
print(find_cheapest_price(3, flights, 0, 2, 0))  # 500 (direct 0 -> 2, no stops)
Problem 4: Minimum Spanning Tree
Problem: Given n nodes and weighted undirected edges, find the minimum spanning tree (MST) — the subset of edges that connects all nodes with minimum total weight.
ML Context: MST is used in hierarchical clustering (single-linkage clusters are obtained by cutting the heaviest edges of the MST), feature selection (minimum redundancy graphs), and network design for distributed training (connecting GPU nodes with minimum communication cost).
def kruskal_mst(n: int, edges: list) -> tuple:
    """Minimum spanning tree by Kruskal's algorithm.

    Edges are scanned in non-decreasing weight order (ties keep input
    order). An edge is kept whenever it joins two different components,
    which a union-find with path halving and union by rank tracks. The
    scan stops once ``n - 1`` edges are chosen. For a disconnected graph
    every edge is scanned and the result is a minimum spanning forest.

    Returns:
        ``(total_weight, mst_edges)``, where ``mst_edges`` are the
        ``(u, v, w)`` tuples in the order they were accepted.

    Time: O(E log E), Space: O(V + E) (including the sorted copy of edges).
    """
    root_of = list(range(n))
    height = [0] * n

    def root(node):
        while root_of[node] != node:
            root_of[node] = root_of[root_of[node]]  # path halving
            node = root_of[node]
        return node

    def merge(a, b):
        ra, rb = root(a), root(b)
        if ra == rb:
            return False  # same component: this edge would form a cycle
        if height[ra] < height[rb]:
            ra, rb = rb, ra
        root_of[rb] = ra
        if height[ra] == height[rb]:
            height[ra] += 1
        return True

    chosen = []
    weight_sum = 0
    for u, v, w in sorted(edges, key=lambda edge: edge[2]):
        if not merge(u, v):
            continue
        chosen.append((u, v, w))
        weight_sum += w
        if len(chosen) == n - 1:
            break  # a spanning tree on n nodes has exactly n - 1 edges
    return weight_sum, chosen
def prim_mst(n: int, edges: list) -> int:
    """Total weight of a minimum spanning tree (or forest) via Prim's algorithm.

    Builds an undirected adjacency list and grows a tree from node 0. It
    repeatedly pops the cheapest heap entry (lazy binary heap; entries for
    already-visited nodes are skipped) that leads to an unvisited node.
    If nodes in ``0 .. n-1`` remain unvisited when a tree stops growing,
    the graph is disconnected. Growth then restarts from the lowest
    unvisited node, so the result is the weight of a minimum spanning
    forest. This matches the total reported by ``kruskal_mst``.

    Returns:
        The sum of the chosen edge weights. Only the total is returned,
        not the edges.

    Time: O(E log V), Space: O(V + E).
    """
    graph = defaultdict(list)
    for u, v, w in edges:
        graph[u].append((w, v))
        graph[v].append((w, u))

    visited = set()
    total_weight = 0
    for start in range(n):
        if start in visited:
            continue  # already covered by an earlier tree
        heap = [(0, start)]  # (edge weight into node, node); the root costs 0
        while heap and len(visited) < n:
            weight, u = heapq.heappop(heap)
            if u in visited:
                continue  # stale entry: u was reached by a cheaper edge
            visited.add(u)
            total_weight += weight
            for w, v in graph[u]:
                if v not in visited:
                    heapq.heappush(heap, (w, v))
    return total_weight
# Demo: both MST algorithms on the same 4-node undirected graph.
edges = [(0, 1, 4), (0, 2, 1), (1, 2, 2), (1, 3, 5), (2, 3, 3)]
print(kruskal_mst(4, edges))  # (6, [(0, 2, 1), (1, 2, 2), (2, 3, 3)])
print(prim_mst(4, edges))  # 6
Problem 5: Union-Find (Disjoint Set Union)
Problem: Implement a Union-Find data structure that supports: (1) union(x, y) — merge the sets containing x and y, (2) find(x) — return the representative of x's set, (3) connected(x, y) — check if x and y are in the same set.
ML Context: Union-Find powers DBSCAN clustering (merge density-connected points), image segmentation (merge similar adjacent pixels), and duplicate detection in data pipelines (merge records that match on any of several keys).
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. n-1``.

    Combines path halving in ``find`` with union by rank. This gives
    nearly constant amortized time per operation (inverse Ackermann).

    Attributes:
        parent: parent pointer per element; a root points to itself.
        rank: union-by-rank counter; only a root's value is consulted.
        count: current number of disjoint sets.
    """

    def __init__(self, n: int):
        self.parent = list(range(n))
        self.rank = [0] * n
        self.count = n  # every element starts in its own singleton set

    def find(self, x: int) -> int:
        """Return the root of ``x``'s set, halving the path on the way up."""
        parent = self.parent
        while parent[x] != x:
            grandparent = parent[parent[x]]
            parent[x] = grandparent  # point x at its grandparent
            x = grandparent
        return x

    def union(self, x: int, y: int) -> bool:
        """Join the sets of ``x`` and ``y``; return False if already joined."""
        root_x = self.find(x)
        root_y = self.find(y)
        if root_x == root_y:
            return False
        # Hang the lower-ranked root beneath the higher-ranked one.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1
        self.count -= 1
        return True

    def connected(self, x: int, y: int) -> bool:
        """Return True when ``x`` and ``y`` belong to the same set."""
        return self.find(x) == self.find(y)
# Application: Redundant Connection
# Given a tree with one extra edge (creating exactly one cycle),
# find the edge that can be removed to make it a tree again.
def find_redundant_connection(edges: list) -> list:
    """Return the first edge, in input order, that closes a cycle.

    Edges are added one at a time to a ``UnionFind``. The first edge
    whose endpoints are already connected is returned as ``[u, v]``.
    For the classic input (a tree on nodes ``1..n`` plus one extra edge)
    this is an edge whose removal leaves a tree. Returns ``[]`` if no
    edge closes a cycle, including for an empty edge list.

    The union-find is sized from the largest node label, not from
    ``len(edges)``. Inputs whose labels exceed the edge count (e.g. a
    forest plus one extra edge) are therefore handled too.

    Time: O(E * alpha(V)), Space: O(max node label).
    """
    if not edges:
        return []
    max_label = max(max(u, v) for u, v in edges)
    # Index directly by label; slot 0 simply goes unused for 1-indexed input.
    uf = UnionFind(max_label + 1)
    for u, v in edges:
        if not uf.union(u, v):
            return [u, v]  # endpoints already connected: this edge closes a cycle
    return []
# Demo: five elements, merged step by step.
uf = UnionFind(5)
uf.union(0, 1)
uf.union(2, 3)
print(uf.connected(0, 1))  # True
print(uf.connected(0, 2))  # False
uf.union(1, 3)
print(uf.connected(0, 2))  # True: {0, 1} and {2, 3} are joined through 1-3
print(uf.count)  # 2 components: {0, 1, 2, 3} and {4}

# Demo: redundant connection in the triangle 1-2-3.
edges = [(1, 2), (1, 3), (2, 3)]
print(find_redundant_connection(edges))  # [2, 3]
Algorithm Selection Guide
| Algorithm | Time | Use When |
|---|---|---|
| BFS | O(V + E) | Shortest path in unweighted graph |
| Dijkstra | O((V+E) log V) | Shortest path with non-negative weights |
| Bellman-Ford | O(V * E) | Shortest path with negative weights, or k-stop constraint |
| DAG shortest path | O(V + E) | Shortest path when graph is a DAG |
| Kruskal (MST) | O(E log E) | Minimum spanning tree, sparse graphs |
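| Prim (MST) | O(E log V) with a binary heap; O(V^2) with an adjacency matrix | Minimum spanning tree; the O(V^2) array-based version suits dense graphs |
| Union-Find | O(alpha(V)) amortized per operation | Dynamic connectivity, cycle detection, MST |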
Interview tip: When you see "shortest path", immediately ask: Are edges weighted? (Yes = Dijkstra/Bellman-Ford, No = BFS). Are there negative weights? (Yes = Bellman-Ford, No = Dijkstra). Is the graph a DAG? (Yes = topological sort approach, O(V+E)). Is there a constraint on number of edges? (Yes = modified Bellman-Ford).
Lilly Tech Systems