Graph Neural Networks
Graph Neural Networks (GNNs) extend deep learning to graph-structured data, enabling powerful representations for molecules, social networks, knowledge graphs, and recommendation systems. This lesson covers message passing frameworks, different GNN architectures, and practical implementations using PyTorch Geometric.
Core Concepts
Graph Fundamentals
A graph G = (V, E) consists of vertices (nodes) V and edges E. Attributes can be associated with both nodes and edges.
Graph Types:
- Directed vs. undirected
- Homogeneous (single node/edge type) vs. heterogeneous
- Static vs. dynamic
- Attributed vs. unattributed
Message Passing Framework
GNNs operate on the principle of message passing: nodes update their representations by aggregating information from neighbors.
General Message Passing Update:
h_i^(k+1) = UPDATE(h_i^(k), AGGREGATE({h_j^(k) : j ∈ N(i)}))
Where:
- h_i^(k): node representation at layer k
- N(i): neighbors of node i
- AGGREGATE: permutation-invariant function (e.g., sum, mean, max)
- UPDATE: learnable transformation (e.g., MLP)
Graph Convolutional Networks (GCN)
GCN is one of the most popular GNN architectures, applying spectral convolution concepts to graphs.
GCN Layer:
H^(k+1) = σ(D̂^(-1/2) Â D̂^(-1/2) H^(k) W^(k))
Where:
- Â = A + I (adjacency matrix with self-loops)
- D̂: degree matrix of Â
- W^(k): learnable weight matrix
- σ: activation function
Intuition: Each node aggregates normalized features from itself and neighbors, then applies transformation.
GraphSAGE (SAmple and aggreGatE)
GraphSAGE addresses scalability by sampling neighborhoods and aggregating features:
Algorithm:
- Sample a small fixed-size neighborhood for each node
- Aggregate features from sampled neighbors
- Update node representation using MLP
- Enables inductive learning on unseen nodes
Aggregation Functions:
- Mean: Average neighbor features
- LSTM: Sequential aggregation with learnable weights
- Pooling: Max/average pooling over neighbor features
Practical Implementation
Graph Convolutional Network in PyTorch Geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
    """Stacked Graph Convolutional Network (Kipf & Welling) for node-level logits.

    Builds `num_layers` GCNConv layers: input -> hidden -> ... -> output.
    ReLU + dropout are applied after every layer except the last, which
    returns raw logits.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        # (source_dim, target_dim) pairs for each stacked convolution.
        dims = (
            [(in_channels, hidden_channels)]
            + [(hidden_channels, hidden_channels)] * (num_layers - 2)
            + [(hidden_channels, out_channels)]
        )
        self.convs = nn.ModuleList(GCNConv(src, dst) for src, dst in dims)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        """Return per-node class logits of shape [num_nodes, out_channels]."""
        *body, head = self.convs
        for conv in body:
            x = self.dropout(F.relu(conv(x, edge_index)))
        return head(x, edge_index)
# Load Cora dataset
# Planetoid downloads and caches the dataset under `root` on first use.
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]  # Cora is a single-graph dataset: one Data object
# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN(
    in_channels=data.num_node_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2
).to(device)
# Standard Cora hyperparameters: lr 0.01 with L2 weight decay 5e-4.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()
# Training loop
def train():
    """Run one full-batch training step over the whole graph.

    Uses the module-level `model`, `data`, `optimizer`, `criterion`, `device`.

    Returns:
        The scalar training loss for this step (Python float).
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x.to(device), data.edge_index.to(device))
    # Keep the mask and labels on the same device as the logits; indexing a
    # CUDA tensor with a CPU boolean mask raises a device-mismatch error.
    mask = data.train_mask.to(device)
    loss = criterion(out[mask], data.y.to(device)[mask])
    loss.backward()
    optimizer.step()
    return loss.item()
@torch.no_grad()
def test():
    """Evaluate classification accuracy on the train/val/test splits.

    Returns:
        (train_acc, val_acc, test_acc) as Python floats.
    """
    model.eval()
    out = model(data.x.to(device), data.edge_index.to(device))
    pred = out.argmax(dim=1)
    y = data.y.to(device)
    accs = []
    # Identical accuracy computation for each split; only the mask differs.
    for mask in (data.train_mask, data.val_mask, data.test_mask):
        mask = mask.to(device)  # keep the mask on the logits' device
        correct = (pred[mask] == y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return tuple(accs)
# Full-batch training: one optimizer step per epoch over the entire graph.
for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    # Report every 20 epochs to keep the log readable.
    if epoch % 20 == 0:
        print(f'Epoch {epoch:3d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')
GraphSAGE Implementation
from torch_geometric.nn import SAGEConv
class GraphSAGE(nn.Module):
    """Stacked GraphSAGE network producing per-node logits.

    All layers but the last are followed by ReLU and dropout; the final
    SAGEConv emits raw class logits.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        layers = [SAGEConv(in_channels, hidden_channels)]
        layers.extend(
            SAGEConv(hidden_channels, hidden_channels)
            for _ in range(num_layers - 2)
        )
        layers.append(SAGEConv(hidden_channels, out_channels))
        self.convs = nn.ModuleList(layers)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        """Return per-node class logits of shape [num_nodes, out_channels]."""
        last = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if idx != last:
                x = self.dropout(F.relu(x))
        return x
# Mini-batch training with neighbor sampling
from torch_geometric.loader import NeighborLoader
# Sample up to 25 first-hop and 10 second-hop neighbors per seed node,
# matching the 2-layer model depth.
train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],
    batch_size=1024,
    input_nodes=data.train_mask,  # seed mini-batches only on training nodes
    # NOTE(review): shuffle is left at its default (False) here; training
    # loops usually want shuffle=True -- confirm intent.
)
# NOTE: `model` is rebound here, so the Adam optimizer created earlier for the
# GCN must be recreated too -- otherwise optimizer.step() would keep updating
# the *old* model's parameters and this GraphSAGE would never train.
model = GraphSAGE(
    in_channels=data.num_node_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Average the loss over the number of *training* nodes; len(data.train_mask)
# is the total node count of the graph, which understates the loss.
num_train_nodes = int(data.train_mask.sum())
for epoch in range(1, 21):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        # Only the first `batch_size` rows are seed nodes; the rest are
        # sampled neighbors present only to provide message-passing context.
        loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.batch_size
    print(f'Epoch {epoch:2d}, Loss: {total_loss / num_train_nodes:.4f}')
Attention-based GNN (GAT)
from torch_geometric.nn import GATConv
class GAT(nn.Module):
    """Two-layer Graph Attention Network in the style of Velickovic et al.

    Layer 1 runs `num_heads` attention heads whose outputs are concatenated
    (hidden_channels per head); layer 2 uses a single head to produce logits.
    Dropout (p=0.6) is applied to the inputs of both layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads, dropout=0.6)
        # Concatenated heads widen the features to hidden_channels * num_heads.
        self.conv2 = GATConv(hidden_channels * num_heads, out_channels, heads=1, dropout=0.6)

    def forward(self, x, edge_index):
        """Return per-node class logits."""
        h = F.dropout(x, p=0.6, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        h = F.dropout(h, p=0.6, training=self.training)
        return self.conv2(h, edge_index)
# NOTE(review): this rebinds `model`; any optimizer created for an earlier
# model still points at the old parameters and would need to be recreated
# before training this GAT.
model = GAT(
    in_channels=data.num_node_features,
    hidden_channels=8,  # width per head; effective hidden size = 8 heads * 8 = 64
    out_channels=dataset.num_classes,
    num_heads=8
).to(device)
Advanced Techniques
Custom Message Passing Layers
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter
class CustomGNNLayer(MessagePassing):
    """Custom message-passing layer: out_i = W_root * x_i + aggr_j(W_nbr * x_j).

    NOTE: PyTorch Geometric binds the arguments of `message`/`aggregate`/
    `update` *by name* (the `_i`/`_j` suffixes and the kwargs handed to
    `propagate`), so these signatures must not be renamed.
    """
    def __init__(self, in_channels, out_channels, aggr='add'):
        super().__init__(aggr=aggr)
        # Separate linear maps: one for the node itself, one for its neighbors.
        self.lin = nn.Linear(in_channels, out_channels)
        self.lin_neighbor = nn.Linear(in_channels, out_channels)
    def forward(self, x, edge_index):
        # x shape: [num_nodes, in_channels]
        # edge_index shape: [2, num_edges]
        # Propagate messages; `x=x` makes x_i / x_j available downstream.
        out = self.propagate(edge_index, x=x)
        return out
    def message(self, x_j, x_i):
        # x_j: features of source nodes
        # x_i: features of target nodes
        # x_i is unused but kept in the signature (PyG injects it by name).
        return self.lin_neighbor(x_j)
    def aggregate(self, inputs, index):
        # Custom aggregation
        # Reduce per-edge messages into per-node sums/means/maxes, according
        # to the `aggr` mode chosen at construction time.
        return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)
    def update(self, inputs, x):
        # Update with node features
        # Combine aggregated neighbor messages with the node's own projection.
        return self.lin(x) + inputs
Graph Pooling
from torch_geometric.nn import global_mean_pool, global_max_pool, SAGPooling
class PoolingGNN(nn.Module):
    """Hierarchical graph classifier: two GCN + SAGPool stages, then readout.

    Each pooling stage keeps the top 80% highest-scoring nodes, coarsening
    the graph before the final mean readout and linear classifier.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.pool1 = SAGPooling(hidden_channels, ratio=0.8)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.pool2 = SAGPooling(hidden_channels, ratio=0.8)
        self.fc = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return per-graph logits of shape [num_graphs, out_channels]."""
        # Stage 1: convolve, then drop the lowest-scoring 20% of nodes.
        h = F.relu(self.conv1(x, edge_index))
        h, edge_index, _, batch, _, _ = self.pool1(h, edge_index, batch=batch)
        # Stage 2: repeat on the coarsened graph.
        h = F.relu(self.conv2(h, edge_index))
        h, edge_index, _, batch, _, _ = self.pool2(h, edge_index, batch=batch)
        # Graph-level readout: mean over each graph's surviving nodes.
        return self.fc(global_mean_pool(h, batch))
Heterogeneous Graph Neural Networks
from torch_geometric.nn import HeteroConv, GCNConv
class HeteroGNN(nn.Module):
    """Heterogeneous GNN: one GCNConv per edge type plus a shared output head.

    Args:
        in_channels_dict: node-type -> input feature dimension.
        hidden_channels: hidden width shared by all relations.
        out_channels: output dimension of the final projection.
        metadata: (node_types, edge_types) as returned by HeteroData.metadata().

    NOTE(review): GCNConv assumes source and destination nodes share one
    feature space; for bipartite edge types SAGEConv((-1, -1), ...) is the
    usual choice -- confirm against the dataset's metadata.
    """
    def __init__(self, in_channels_dict, hidden_channels, out_channels, metadata):
        super().__init__()
        # metadata[1] lists edge types as (src_type, relation, dst_type);
        # each relation gets its own convolution sized to the source features.
        self.conv = HeteroConv({
            edge_type: GCNConv(in_channels_dict[edge_type[0]], hidden_channels)
            for edge_type in metadata[1]
        })
        self.lin = nn.Linear(hidden_channels, out_channels)
    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv(x_dict, edge_index_dict)
        # Bug fix: self.lin was created but never applied, so the model
        # returned hidden_channels-dim features instead of out_channels.
        return {key: self.lin(F.relu(x)) for key, x in x_dict.items()}
Production Considerations
Scalability
Challenges:
- Large graphs don’t fit in memory
- Full-batch training becomes infeasible
- Neighborhood explosion in deep GNNs (receptive fields grow exponentially with depth)
Solutions:
# Mini-batch sampling
# Cluster-GCN style training: partition the graph and train on a handful of
# partitions per step, so the full graph never has to fit in device memory.
from torch_geometric.loader import ClusterLoader
# NOTE(review): ClusterLoader normally wraps ClusterData(data, num_parts=...)
# rather than the raw Data object -- confirm against the installed PyG version.
cluster_loader = ClusterLoader(
    data,
    num_parts=10,
    batch_size=20,
    shuffle=True
)
# Layer-wise sampling
from torch_geometric.loader import NeighborSampler
# NOTE(review): `edge_index` and `num_nodes` are not defined in this snippet;
# presumably they come from the surrounding script (e.g. data.edge_index).
sampler = NeighborSampler(
    edge_index,
    num_nodes=num_nodes,
    batch_size=1024,
    num_workers=4
)
Evaluation on Graphs
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
@torch.no_grad()
def evaluate_node_classification(model, data, mask):
    """Compute classification accuracy for the nodes selected by `mask`.

    Args:
        model: GNN taking (x, edge_index) and returning per-node logits.
        data: graph object exposing `.x`, `.edge_index`, `.y`.
        mask: boolean node mask selecting the evaluation split.

    Returns:
        Accuracy as a Python float in [0, 1].
    """
    model.eval()
    logits = model(data.x, data.edge_index)
    preds = logits[mask].argmax(dim=1)
    # Pure-torch accuracy: equivalent to sklearn's accuracy_score but avoids
    # the device -> CPU round-trip and the extra dependency.
    return (preds == data.y[mask]).float().mean().item()
# Link prediction
from torch_geometric.utils import negative_sampling
@torch.no_grad()  # consistency with the node evaluator; skips autograd bookkeeping
def evaluate_link_prediction(model, data):
    """Evaluate link prediction via AUC-ROC on held-out positive test edges.

    An edge (i, j) is scored as sigmoid(z_i . z_j). Negative edges are
    re-sampled on every call, so the reported AUC varies slightly between runs.

    Returns:
        AUC-ROC over positive test edges and an equal number of sampled
        negatives (float).
    """
    model.eval()
    z = model.encode(data.x, data.edge_index)
    pos_edge_index = data.test_pos_edge_index
    # Sample as many non-edges as there are positive test edges.
    neg_edge_index = negative_sampling(
        data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge_index.size(1)
    )
    pos_scores = torch.sigmoid((z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=1))
    neg_scores = torch.sigmoid((z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=1))
    labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
    scores = torch.cat([pos_scores, neg_scores])
    # .detach() is unnecessary under no_grad; tensors carry no graph here.
    auc = roc_auc_score(labels.cpu().numpy(), scores.cpu().numpy())
    return auc
Model Export and Serving
# Convert to ONNX for inference
import onnx  # noqa: F401 -- fail fast here if the onnx package is missing
# Export on CPU: the dummy inputs below are CPU tensors, so a model still on
# CUDA would make torch.onnx.export fail with a device mismatch.
model_cpu = model.cpu().eval()
dummy_x = torch.randn(10, data.num_node_features)
dummy_edge_index = torch.randint(0, 10, (2, 20))
torch.onnx.export(
    model_cpu,
    (dummy_x, dummy_edge_index),
    'gnn_model.onnx',
    input_names=['x', 'edge_index'],
    output_names=['output'],
    # Without dynamic axes the graph would be pinned to exactly
    # 10 nodes / 20 edges at inference time.
    dynamic_axes={'x': {0: 'num_nodes'}, 'edge_index': {1: 'num_edges'}},
    opset_version=11
)
Key Takeaway
Graph Neural Networks unlock the power of structured relationships in data. By learning to propagate information through graph neighborhoods, GNNs excel at node classification, link prediction, and graph-level tasks. Master message passing, sampling strategies, and graph pooling to build scalable models for complex relational data.
Practical Exercise
Task: Build a GNN-based link prediction system for a social network or citation network.
Requirements:
- Load a graph dataset (Cora, Citeseer, or custom)
- Implement both GCN and GraphSAGE architectures
- Create train/val/test split on edges (80/10/10)
- Train with link prediction objective:
- Positive pairs: existing edges
- Negative pairs: sampled non-edges
- Use dot product similarity: score = z_i · z_j
- Implement evaluation using AUC-ROC and AP metrics
Evaluation:
- Compare GCN vs GraphSAGE on test AUC
- Analyze computational efficiency (training time)
- Test inductive setting: encode unseen nodes
- Visualize learned embeddings with t-SNE
- Implement hard negative sampling for better training