Graph Neural Networks Deep Dive

Learn to apply neural networks to graph-structured data — from social networks and molecules to fraud detection and recommendation systems.

What Are Graphs?

A graph is a data structure consisting of nodes (entities) connected by edges (relationships). Graphs are everywhere:

Graph G = (V, E)
  V = set of nodes (vertices)
  E = set of edges (connections between nodes)

Node features: each node has a feature vector x_i
  Example: a user's profile (age, location, interests)

Edge features: each edge can have attributes
  Example: friendship strength, transaction amount

Adjacency Matrix A (n x n):
  A[i][j] = 1 if there's an edge between node i and node j
  A[i][j] = 0 otherwise

Example (3-node triangle):
  A = [[0, 1, 1],    Node 0 connects to 1 and 2
       [1, 0, 1],    Node 1 connects to 0 and 2
       [1, 1, 0]]    Node 2 connects to 0 and 1
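The triangle example maps directly to code. A minimal NumPy sketch (variable names are illustrative, not from any library):

```python
import numpy as np

# 3-node triangle: each node connects to the other two
A = np.array([[0, 1, 1],
              [1, 0, 1],
              [1, 1, 0]])

degrees = A.sum(axis=1)                # node degrees: [2, 2, 2]
neighbors_of_0 = np.flatnonzero(A[0])  # neighbors of node 0: [1, 2]
```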
Social Networks

Nodes = users. Edges = friendships/follows. Predict user interests, detect communities, recommend connections.

Molecules

Nodes = atoms. Edges = chemical bonds. Predict molecular properties, drug efficacy, toxicity.

Financial Networks

Nodes = accounts. Edges = transactions. Detect fraud rings, money laundering patterns.

Knowledge Graphs

Nodes = entities. Edges = relationships. Power search engines, question answering, recommendations.

Why Standard Neural Networks Fail on Graphs

Standard neural networks (MLPs, CNNs, RNNs) assume data has a fixed structure:

  • MLPs: Expect fixed-size feature vectors. Graphs have variable numbers of neighbors.
  • CNNs: Expect grid-structured data (images). Graphs have irregular topology.
  • RNNs: Expect sequential data. Graphs have no inherent ordering.
💡 The key challenge: graphs have no canonical node ordering, so relabeling the nodes should not change the result, yet standard NNs depend on input order. GNNs solve this by aggregating information from a node's neighbors with order-independent functions (sum, mean, max): node-level outputs are permutation equivariant, and graph-level outputs are permutation invariant.
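A tiny check of this order-independence: aggregating neighbor features with sum or mean gives identical results for any neighbor ordering (a minimal NumPy illustration, values arbitrary):

```python
import numpy as np

# Feature vectors of a node's three neighbors
neighbor_feats = np.array([[1.0, 0.0],
                           [0.0, 2.0],
                           [3.0, 1.0]])

# List the same neighbors in a different order...
shuffled = neighbor_feats[[2, 0, 1]]

# ...and the aggregated message is unchanged
assert np.allclose(neighbor_feats.sum(axis=0), shuffled.sum(axis=0))
assert np.allclose(neighbor_feats.mean(axis=0), shuffled.mean(axis=0))
```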

Message Passing Framework

Almost all GNNs follow the message passing paradigm. Each node updates its representation by aggregating information from its neighbors:

Message Passing (for each layer):

1. MESSAGE: Each node creates a "message" for its neighbors
   m_{i→j} = MSG(h_i, h_j, e_{ij})
   (based on sender features, receiver features, edge features)

2. AGGREGATE: Each node collects messages from ALL neighbors
   M_i = AGG({m_{j→i} : j in Neighbors(i)})
   AGG can be: sum, mean, max, or attention-weighted

3. UPDATE: Each node updates its own representation
   h_i^{new} = UPDATE(h_i, M_i)
   (combine old features with aggregated neighbor info)

After K layers of message passing, each node's representation
captures information from its K-hop neighborhood.
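The three steps above can be sketched in a few lines of NumPy. This is an illustrative toy (sum aggregation, a simple additive update), not a real GNN library implementation:

```python
import numpy as np

def message_passing_layer(A, H, W):
    """One round of message passing (toy version).

    MESSAGE/AGGREGATE: sum neighbor features via A @ H.
    UPDATE: add the aggregate to each node's own features,
    then apply a linear transform and ReLU.
    """
    M = A @ H                          # aggregated neighbor messages
    return np.maximum(0, (H + M) @ W)  # update + transform + ReLU

# Triangle graph, 2-dim features, identity weights for clarity
A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
H = np.eye(3, 2)
H1 = message_passing_layer(A, H, np.eye(2))
# After one round, every node has mixed in both neighbors' features
```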

GCN (Graph Convolutional Networks)

The foundational GNN architecture (Kipf & Welling, 2017). It performs a normalized neighborhood aggregation:

GCN Layer:
  H^{l+1} = sigma(D^{-1/2} * A_hat * D^{-1/2} * H^l * W^l)

Where:
  A_hat = A + I   (adjacency matrix + self-loops)
  D = degree matrix of A_hat
  H^l = node features at layer l
  W^l = learnable weight matrix
  sigma = activation (ReLU)

Simplified interpretation:
  For each node i:
    h_i^{new} ≈ ReLU(W * MEAN(h_j for j in {i} ∪ Neighbors(i)))
  → Roughly: average your own features with your neighbors', then transform.
    (Exact form: a symmetric degree-normalized sum, weighting each neighbor
     by 1/sqrt(d_i * d_j); it equals the mean when all degrees are equal.)
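The propagation rule can be written directly in NumPy (a sketch assuming a small dense adjacency matrix; real libraries such as PyTorch Geometric use sparse operations):

```python
import numpy as np

def gcn_layer(A, H, W):
    """GCN propagation: ReLU(D^{-1/2} A_hat D^{-1/2} H W)."""
    A_hat = A + np.eye(A.shape[0])            # add self-loops
    d = A_hat.sum(axis=1)                     # degrees of A_hat
    D_inv_sqrt = np.diag(1.0 / np.sqrt(d))
    A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt  # symmetric normalization
    return np.maximum(0, A_norm @ H @ W)      # transform + ReLU

A = np.array([[0, 1, 1], [1, 0, 1], [1, 1, 0]], dtype=float)
rng = np.random.default_rng(0)
H = rng.standard_normal((3, 4))               # node features
W = rng.standard_normal((4, 2))               # learnable weights
out = gcn_layer(A, H, W)                      # shape (3, 2)
```

For this triangle, every node has degree 3 after adding self-loops, so each row of the normalized adjacency sums to 1 and the layer reduces exactly to mean-then-transform.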

GAT (Graph Attention Networks)

Instead of treating all neighbors equally, GAT learns attention weights — how important each neighbor is:

GAT Layer:
1. Compute attention coefficients:
   e_{ij} = LeakyReLU(a^T * [W*h_i || W*h_j])
   alpha_{ij} = softmax_j(e_{ij})  (normalize over all neighbors)

2. Aggregate with attention weights:
   h_i^{new} = sigma(SUM[alpha_{ij} * W * h_j])

Multi-head attention:
   Concatenate K independent attention heads for richer representations.

Advantage over GCN:
   - Learns WHICH neighbors are important (not equal weighting)
   - More expressive, especially for heterogeneous graphs
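A single-head GAT aggregation for one node, written out in NumPy to make the attention computation concrete (function and variable names are made up for illustration; LeakyReLU negative slope 0.2 as in the paper):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def gat_attention(h_i, neighbor_hs, W, a):
    """Single-head GAT aggregation for one node (illustrative sketch)."""
    Wh_i = W @ h_i
    scores = []
    for h_j in neighbor_hs:
        z = np.concatenate([Wh_i, W @ h_j])  # [W*h_i || W*h_j]
        s = a @ z                            # a^T [...]
        scores.append(max(0.2 * s, s))       # LeakyReLU (slope 0.2 < 1)
    alpha = softmax(np.array(scores))        # normalize over neighbors
    h_new = sum(w * (W @ h_j) for w, h_j in zip(alpha, neighbor_hs))
    return alpha, h_new

h_i = np.array([1.0, 0.0])
neighbors = [np.array([0.0, 1.0]), np.array([1.0, 1.0])]
alpha, h_new = gat_attention(h_i, neighbors, W=np.eye(2), a=np.ones(4))
# alpha sums to 1; the second neighbor gets the larger weight here
```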

GraphSAGE

Graph SAmple and aggreGatE — designed for large graphs and inductive learning (generalizing to unseen nodes):

GraphSAGE:
1. SAMPLE: Randomly sample a fixed number of neighbors
   (instead of using ALL neighbors — scales to large graphs)

2. AGGREGATE: Apply an aggregation function
   Options: mean, LSTM, max-pooling

3. UPDATE: Concatenate own features with aggregated neighbors
   h_i^{new} = sigma(W * CONCAT(h_i, AGG({h_j : j in sampled_neighbors})))

Key advantages:
   - Inductive: can generate embeddings for NEW, unseen nodes
   - The original GCN formulation is transductive (must retrain for new
     nodes); GraphSAGE generalizes without retraining
   - Neighbor sampling scales it to graphs with millions of nodes
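The sample-aggregate-update loop can be sketched as follows (mean aggregator; adjacency stored as a dict of neighbor lists; all names are illustrative, not the PyG API):

```python
import numpy as np

def sage_layer(node_feats, adj_list, W, num_samples=2, rng=None):
    """One GraphSAGE layer with mean aggregation over sampled neighbors.

    Illustrative sketch: node_feats is (n, d), adj_list maps node id
    to a list of neighbor ids, W is (d_out, 2 * d).
    """
    if rng is None:
        rng = np.random.default_rng(0)
    out = []
    for i in range(len(node_feats)):
        nbrs = adj_list[i]
        k = min(num_samples, len(nbrs))
        sampled = rng.choice(nbrs, size=k, replace=False)  # 1. SAMPLE
        agg = node_feats[sampled].mean(axis=0)             # 2. AGGREGATE
        h = np.concatenate([node_feats[i], agg])           # 3. CONCAT self + agg
        out.append(np.maximum(0, W @ h))                   # transform + ReLU
    return np.stack(out)

rng = np.random.default_rng(1)
feats = rng.standard_normal((4, 3))
adj = {0: [1, 2, 3], 1: [0, 2], 2: [0, 1, 3], 3: [0, 2]}
W = rng.standard_normal((5, 6))
new_feats = sage_layer(feats, adj, W)  # one layer: (4, 3) -> (4, 5)
```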

GNN Architecture Comparison

Architecture | Aggregation | Attention | Scalability | Best For
--- | --- | --- | --- | ---
GCN | Normalized mean | No (equal weights) | Medium | Homogeneous graphs, semi-supervised learning
GAT | Attention-weighted | Yes (learned) | Medium | Heterogeneous importance, citation networks
GraphSAGE | Sample + aggregate | Optional | High | Large graphs, inductive tasks, production

Applications

Domain | Graph Structure | Task | Impact
--- | --- | --- | ---
Social Networks | Users + friendships | Link prediction, community detection | Friend suggestions (Facebook, LinkedIn)
Drug Discovery | Atoms + bonds | Molecular property prediction | 10x faster drug screening
Recommendation | Users + items + interactions | Predict user-item interactions | Pinterest, Uber Eats recommendations
Fraud Detection | Accounts + transactions | Anomalous subgraph detection | Detect fraud rings in banking
Traffic/Maps | Intersections + roads | Travel time prediction | Google Maps ETA predictions
Protein Folding | Amino acids + contacts | 3D structure prediction | AlphaFold's breakthrough

Comparison with CNNs and RNNs

Aspect | CNNs | RNNs | GNNs
--- | --- | --- | ---
Data Type | Grid (images) | Sequences (text, time) | Graphs (arbitrary topology)
Structure | Fixed grid | Linear chain | Arbitrary connections
Key Operation | Convolution (local filter) | Recurrence (sequential) | Message passing (neighbors)
Symmetry | Translation invariance (via pooling) | Order-dependent | Permutation invariance (by design)
Scalability | O(pixels) | O(sequence length) | O(nodes + edges)

PyTorch Geometric Code Example

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.datasets import Planetoid, TUDataset
from torch_geometric.loader import DataLoader

# ============================================
# Example 1: Node Classification (Cora Dataset)
# ============================================
# Cora: citation network, 2708 papers, 7 classes
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

print(f"Nodes: {data.num_nodes}")
print(f"Edges: {data.num_edges}")
print(f"Features per node: {data.num_node_features}")
print(f"Classes: {dataset.num_classes}")

class GCN_NodeClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Train
model = GCN_NodeClassifier(
    dataset.num_node_features, 64, dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Evaluate (disable gradient tracking for inference)
model.eval()
with torch.no_grad():
    pred = model(data).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
accuracy = int(correct) / int(data.test_mask.sum())
print(f"\nGCN Node Classification Accuracy: {accuracy:.4f}")

# ============================================
# Example 2: Graph Classification (MUTAG Dataset)
# ============================================
# MUTAG: 188 molecules, classify mutagenicity
dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
print(f"\nMUTAG: {len(dataset)} graphs, {dataset.num_classes} classes")

# Split into train/test
torch.manual_seed(42)
perm = torch.randperm(len(dataset))
train_dataset = dataset[perm[:150]]
test_dataset = dataset[perm[150:]]

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

class GAT_GraphClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=4)
        self.conv2 = GATConv(hidden_channels * 4, hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)  # graph-level readout
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)

model = GAT_GraphClassifier(
    dataset.num_node_features, 32, dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Training loop
for epoch in range(100):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        out = model(batch)
        loss = F.nll_loss(out, batch.y)
        loss.backward()
        optimizer.step()

# Evaluate (disable gradient tracking for inference)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        pred = model(batch).argmax(dim=1)
        correct += (pred == batch.y).sum().item()
        total += batch.y.size(0)

print(f"GAT Graph Classification Accuracy: {correct/total:.4f}")
When NOT to use GNNs: If your data isn't naturally graph-structured, don't force it into a graph. Tabular data works better with gradient boosting. Images work better with CNNs. Text works better with Transformers. GNNs shine when relationships between entities are a core part of the problem.