Graph Neural Networks Deep Dive
Learn to apply neural networks to graph-structured data — from social networks and molecules to fraud detection and recommendation systems.
What Are Graphs?
A graph is a data structure consisting of nodes (entities) connected by edges (relationships). Graphs are everywhere:
Graph G = (V, E)
V = set of nodes (vertices)
E = set of edges (connections between nodes)
Node features: each node has a feature vector x_i
Example: a user's profile (age, location, interests)
Edge features: each edge can have attributes
Example: friendship strength, transaction amount
Adjacency Matrix A (n x n):
A[i][j] = 1 if there's an edge between node i and node j (symmetric for undirected graphs)
A[i][j] = 0 otherwise
Example (3-node triangle):
A = [[0, 1, 1], Node 0 connects to 1 and 2
[1, 0, 1], Node 1 connects to 0 and 2
[1, 1, 0]] Node 2 connects to 0 and 1
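As a quick sanity check, the triangle above can be built directly in NumPy (a minimal sketch; the variable names are illustrative):

```python
import numpy as np

# Adjacency matrix for the 3-node triangle (undirected, no self-loops)
A = np.array([
    [0, 1, 1],  # node 0 connects to 1 and 2
    [1, 0, 1],  # node 1 connects to 0 and 2
    [1, 1, 0],  # node 2 connects to 0 and 1
])

# An undirected graph's adjacency matrix is symmetric
assert (A == A.T).all()

# Row sums give each node's degree
degrees = A.sum(axis=1)
print(degrees)  # [2 2 2] — every node in a triangle has degree 2
```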
Social Networks
Nodes = users. Edges = friendships/follows. Predict user interests, detect communities, recommend connections.
Molecules
Nodes = atoms. Edges = chemical bonds. Predict molecular properties, drug efficacy, toxicity.
Financial Networks
Nodes = accounts. Edges = transactions. Detect fraud rings, money laundering patterns.
Knowledge Graphs
Nodes = entities. Edges = relationships. Power search engines, question answering, recommendations.
Why Standard Neural Networks Fail on Graphs
Standard neural networks (MLPs, CNNs, RNNs) assume data has a fixed structure:
- MLPs: Expect fixed-size feature vectors. Graphs have variable numbers of neighbors.
- CNNs: Expect grid-structured data (images). Graphs have irregular topology.
- RNNs: Expect sequential data. Graphs have no inherent ordering.
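A small NumPy sketch of the core problem: relabeling a graph's nodes changes the flattened input an MLP would see, even though the graph itself is unchanged (the path graph and the particular permutation here are just illustrative choices):

```python
import numpy as np

# The same 3-node path graph (0-1-2) under two node orderings
A = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float)
perm = [2, 0, 1]
A_perm = A[np.ix_(perm, perm)]   # same graph, nodes relabeled

# An MLP on the flattened matrix sees two different inputs...
print((A.ravel() == A_perm.ravel()).all())   # False

# ...but a permutation-invariant summary (sorted degree sequence) agrees
print(sorted(A.sum(axis=1)) == sorted(A_perm.sum(axis=1)))  # True
```

This is why GNNs are built from permutation-invariant aggregations (sum, mean, max) rather than operations that depend on node order.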
Message Passing Framework
Almost all GNNs follow the message passing paradigm. Each node updates its representation by aggregating information from its neighbors:
Message Passing (for each layer):
1. MESSAGE: Each node creates a "message" for its neighbors
m_{i→j} = MSG(h_i, h_j, e_{ij})
(based on sender features, receiver features, edge features)
2. AGGREGATE: Each node collects messages from ALL neighbors
M_i = AGG({m_{j→i} : j in Neighbors(i)})
AGG can be: sum, mean, max, or attention-weighted
3. UPDATE: Each node updates its own representation
h_i^{new} = UPDATE(h_i, M_i)
(combine old features with aggregated neighbor info)
After K layers of message passing, each node's representation
captures information from its K-hop neighborhood.
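The three steps above can be sketched as a single NumPy function (sum aggregation; `message_passing_layer` and the toy inputs are illustrative names, not from any library):

```python
import numpy as np

def message_passing_layer(A, H, W):
    """One round of message passing with sum aggregation.

    A: (n, n) adjacency matrix
    H: (n, d) node features
    W: (d, d_out) learnable weight matrix
    """
    # 1. MESSAGE: here each node's message is simply its feature vector
    # 2. AGGREGATE: summing neighbor messages is the matrix product A @ H
    M = A @ H
    # 3. UPDATE: combine own features with aggregated messages, then transform
    return np.maximum(0, (H + M) @ W)  # ReLU activation

# Triangle graph, 2-dim features, identity weights for illustration
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
H = np.array([[1., 0.], [0., 1.], [1., 1.]])
W = np.eye(2)
print(message_passing_layer(A, H, W))
```

Stacking K such layers gives each node a view of its K-hop neighborhood, as described above.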
GCN (Graph Convolutional Networks)
The foundational GNN architecture (Kipf & Welling, 2017). It performs a normalized neighborhood aggregation:
GCN Layer:
H^{l+1} = sigma(D^{-1/2} * A_hat * D^{-1/2} * H^l * W^l)
Where:
A_hat = A + I (adjacency matrix + self-loops)
D = degree matrix of A_hat
H^l = node features at layer l
W^l = learnable weight matrix
sigma = activation (ReLU)
Simplified interpretation:
For each node i:
h_i^{new} = ReLU(W * MEAN(h_j for j in {i} ∪ Neighbors(i)))
→ Roughly: average your own features with your neighbors' (the exact GCN normalization is degree-weighted, not a plain mean), then transform
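The full formula can be written out in a few lines of NumPy (a from-scratch sketch, not the torch_geometric implementation; `gcn_layer` is an illustrative name):

```python
import numpy as np

def gcn_layer(A, H, W):
    """One GCN layer: sigma(D^{-1/2} A_hat D^{-1/2} H W)."""
    n = A.shape[0]
    A_hat = A + np.eye(n)                            # add self-loops
    D_inv_sqrt = np.diag(A_hat.sum(axis=1) ** -0.5)  # D^{-1/2}
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt         # symmetric normalization
    return np.maximum(0, A_norm @ H @ W)             # ReLU

# Triangle graph: with self-loops every node has degree 3,
# so A_norm averages each node with its two neighbors
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
H = np.array([[1., 0.], [0., 1.], [1., 1.]])
out = gcn_layer(A, H, np.eye(2))
print(out)  # every entry is 2/3: each node averages all three feature vectors
```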
GAT (Graph Attention Networks)
Instead of treating all neighbors equally, GAT learns attention weights — how important each neighbor is:
GAT Layer:
1. Compute attention coefficients:
e_{ij} = LeakyReLU(a^T * [W*h_i || W*h_j])
alpha_{ij} = softmax_j(e_{ij}) (normalize over all neighbors)
2. Aggregate with attention weights:
h_i^{new} = sigma(SUM[alpha_{ij} * W * h_j])
Multi-head attention:
Concatenate K independent attention heads for richer representations.
Advantage over GCN:
- Learns WHICH neighbors are important (not equal weighting)
- More expressive, especially when neighbors vary in relevance (e.g. heterogeneous graphs)
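A single-head sketch of the attention computation in NumPy (`gat_attention` and the random toy inputs are illustrative assumptions; real implementations batch this over all edges at once):

```python
import numpy as np

def gat_attention(h_i, h_j_all, W, a):
    """Attention weights of node i over its neighbors (single head).

    h_i:     (d,) features of the receiving node
    h_j_all: (k, d) features of its k neighbors
    W:       (d, d_out) shared linear transform
    a:       (2 * d_out,) attention vector
    """
    Wh_i = h_i @ W
    Wh_j = h_j_all @ W
    # e_ij = LeakyReLU(a^T [W*h_i || W*h_j]) for each neighbor j
    concat = np.hstack([np.tile(Wh_i, (len(h_j_all), 1)), Wh_j])
    e = concat @ a
    e = np.where(e > 0, e, 0.2 * e)        # LeakyReLU (slope 0.2)
    alpha = np.exp(e - e.max())            # numerically stable softmax
    return alpha / alpha.sum()

rng = np.random.default_rng(0)
h_i = rng.normal(size=4)
h_j = rng.normal(size=(3, 4))              # three neighbors
W = rng.normal(size=(4, 4))
a = rng.normal(size=8)
alpha = gat_attention(h_i, h_j, W, a)
print(alpha.sum())                         # weights sum to 1
```

The aggregation step then takes the alpha-weighted sum of the transformed neighbor features, exactly as in the formula above.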
GraphSAGE
Graph SAmple and aggreGatE — designed for large graphs and inductive learning (generalizing to unseen nodes):
GraphSAGE:
1. SAMPLE: Randomly sample a fixed number of neighbors
(instead of using ALL neighbors — scales to large graphs)
2. AGGREGATE: Apply an aggregation function
Options: mean, LSTM, max-pooling
3. UPDATE: Concatenate own features with aggregated neighbors
h_i^{new} = sigma(W * CONCAT(h_i, AGG({h_j : j in sampled_neighbors})))
Key advantage:
- Inductive: can generate embeddings for NEW, unseen nodes
- GCN/GAT, as typically trained on a fixed graph, are transductive (new nodes require retraining)
- Scales to graphs with millions of nodes
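The sample/aggregate/update steps can be sketched in NumPy (`sage_layer`, the toy graph, and the fixed sample size are illustrative assumptions, not the reference implementation):

```python
import numpy as np

def sage_layer(H, neighbors, W, rng, num_samples=2):
    """One GraphSAGE layer with mean aggregation over sampled neighbors.

    H:         (n, d) node features
    neighbors: dict mapping node id -> list of neighbor ids
    W:         (2 * d, d_out) weight applied to [own || aggregated]
    """
    out = []
    for i in range(len(H)):
        nbrs = neighbors[i]
        # SAMPLE: fixed-size neighbor sample (capped at the actual degree)
        sampled = rng.choice(nbrs, size=min(num_samples, len(nbrs)),
                             replace=False)
        # AGGREGATE: mean of sampled neighbor features
        agg = H[sampled].mean(axis=0)
        # UPDATE: concatenate own features with the aggregate, transform, ReLU
        out.append(np.maximum(0, np.concatenate([H[i], agg]) @ W))
    return np.stack(out)

rng = np.random.default_rng(0)
H = rng.normal(size=(4, 3))
neighbors = {0: [1, 2], 1: [0, 3], 2: [0], 3: [1]}
W = rng.normal(size=(6, 3))
H_new = sage_layer(H, neighbors, W, rng)
print(H_new.shape)  # (4, 3)
```

Because the layer only needs a node's features and a neighbor sample, it can embed nodes that were never seen during training — the inductive property noted above.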
GNN Architecture Comparison
| Architecture | Aggregation | Attention | Scalability | Best For |
|---|---|---|---|---|
| GCN | Normalized mean | No (equal weights) | Medium | Homogeneous graphs, semi-supervised learning |
| GAT | Attention-weighted | Yes (learned) | Medium | Heterogeneous importance, citation networks |
| GraphSAGE | Sample + aggregate | Optional | High | Large graphs, inductive tasks, production |
Applications
| Domain | Graph Structure | Task | Impact |
|---|---|---|---|
| Social Networks | Users + friendships | Link prediction, community detection | Friend suggestions (Facebook, LinkedIn) |
| Drug Discovery | Atoms + bonds | Molecular property prediction | Dramatically faster candidate screening |
| Recommendation | Users + items + interactions | Predict user-item interactions | Pinterest, Uber Eats recommendations |
| Fraud Detection | Accounts + transactions | Anomalous subgraph detection | Detect fraud rings in banking |
| Traffic/Maps | Intersections + roads | Travel time prediction | Google Maps ETA predictions |
| Protein Folding | Amino acids + contacts | 3D structure prediction | AlphaFold's breakthrough |
Comparison with CNNs and RNNs
| Aspect | CNNs | RNNs | GNNs |
|---|---|---|---|
| Data Type | Grid (images) | Sequences (text, time) | Graphs (arbitrary topology) |
| Structure | Fixed grid | Linear chain | Arbitrary connections |
| Key Operation | Convolution (local filter) | Recurrence (sequential) | Message passing (neighbors) |
| Invariance | Translation (via pooling) | None (order-dependent) | Node permutation (by design) |
| Scalability | O(pixels) | O(sequence length) | O(nodes + edges) |
PyTorch Geometric Code Example
```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.loader import DataLoader

# ============================================
# Example 1: Node Classification (Cora Dataset)
# ============================================
# Cora: citation network, 2708 papers, 7 classes
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

print(f"Nodes: {data.num_nodes}")
print(f"Edges: {data.num_edges}")
print(f"Features per node: {data.num_node_features}")
print(f"Classes: {dataset.num_classes}")

class GCN_NodeClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Train
model = GCN_NodeClassifier(
    dataset.num_node_features, 64, dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluate
model.eval()
pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
accuracy = int(correct) / int(data.test_mask.sum())
print(f"\nGCN Node Classification Accuracy: {accuracy:.4f}")

# ============================================
# Example 2: Graph Classification (MUTAG Dataset)
# ============================================
# MUTAG: 188 molecules, classify mutagenicity
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
print(f"\nMUTAG: {len(dataset)} graphs, {dataset.num_classes} classes")

# Split into train/test
torch.manual_seed(42)
perm = torch.randperm(len(dataset))
train_dataset = dataset[perm[:150]]
test_dataset = dataset[perm[150:]]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

class GAT_GraphClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=4)
        self.conv2 = GATConv(hidden_channels * 4, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # graph-level readout
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)

model = GAT_GraphClassifier(
    dataset.num_node_features, 32, dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Training loop
for epoch in range(100):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()

# Evaluate
model.eval()
correct = 0
total = 0
for batch in test_loader:
    pred = model(batch).argmax(dim=1)
    correct += (pred == batch.y).sum().item()
    total += batch.y.size(0)
print(f"GAT Graph Classification Accuracy: {correct/total:.4f}")
```
Lilly Tech Systems