Intermediate

Graph Neural Networks

Lesson 3 of 4 · Estimated time: 50 min


Graph Neural Networks (GNNs) extend deep learning to graph-structured data, enabling powerful representations for molecules, social networks, knowledge graphs, and recommendation systems. This lesson covers message passing frameworks, different GNN architectures, and practical implementations using PyTorch Geometric.

Core Concepts

Graph Fundamentals

A graph G = (V, E) consists of vertices (nodes) V and edges E. Attributes can be associated with both nodes and edges.

Graph Types:

  • Directed vs. undirected
  • Homogeneous (single node/edge type) vs. heterogeneous
  • Static vs. dynamic
  • Attributed vs. unattributed
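Concretely, a small attributed graph can be written down as two tensors in the edge-list (COO) convention that PyTorch Geometric uses. The graph below is a hypothetical 4-node path; note that each undirected edge is stored as two directed edges:

```python
import torch

# Undirected path graph 0 - 1 - 2 - 3, stored in COO format:
# each undirected edge appears once per direction.
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3],   # source nodes
    [1, 0, 2, 1, 3, 2],   # target nodes
])

# Node attributes: one 2-dimensional feature vector per node.
x = torch.tensor([[1.0, 0.0],
                  [0.0, 1.0],
                  [1.0, 1.0],
                  [0.0, 0.0]])

print(edge_index.shape)  # torch.Size([2, 6])
print(x.shape)           # torch.Size([4, 2])
```

The same two tensors are exactly what `torch_geometric.data.Data(x=x, edge_index=edge_index)` wraps.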

Message Passing Framework

GNNs operate on the principle of message passing: nodes update their representations by aggregating information from neighbors.

General Message Passing Update:

h_i^(k+1) = UPDATE(h_i^(k), AGGREGATE({h_j^(k) : j ∈ N(i)}))

Where:

  • h_i^(k): node representation at layer k
  • N(i): neighbors of node i
  • AGGREGATE: symmetric function (e.g., sum, mean, max)
  • UPDATE: learnable transformation (e.g., MLP)
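One round of the update above can be sketched in plain PyTorch, using mean as the AGGREGATE and a single linear layer over the concatenation [h_i || m_i] as a (hypothetical) UPDATE:

```python
import torch

torch.manual_seed(0)

num_nodes, dim = 4, 8
h = torch.randn(num_nodes, dim)                  # h^(k) for all nodes
edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],   # source -> target edges
                           [1, 0, 2, 1, 3, 2]])

# AGGREGATE: mean of neighbor features, via scatter-style index_add_.
src, dst = edge_index
agg = torch.zeros(num_nodes, dim).index_add_(0, dst, h[src])
deg = torch.zeros(num_nodes).index_add_(0, dst, torch.ones(dst.numel()))
agg = agg / deg.clamp(min=1).unsqueeze(-1)

# UPDATE: learnable transformation of the node state and its messages.
update = torch.nn.Linear(2 * dim, dim)
h_next = torch.relu(update(torch.cat([h, agg], dim=-1)))   # h^(k+1)
print(h_next.shape)  # torch.Size([4, 8])
```

Every architecture in this lesson is a special case of this loop; they differ only in how AGGREGATE and UPDATE are chosen.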

Graph Convolutional Networks (GCN)

GCN is one of the most popular GNN architectures, applying spectral convolution concepts to graphs.

GCN Layer:

H^(k+1) = σ(D̂^(-1/2) Â D̂^(-1/2) H^(k) W^(k))

Where:

  • Â = A + I (adjacency matrix with self-loops)
  • D̂: degree matrix of Â
  • W^(k): learnable weight matrix
  • σ: activation function

Intuition: Each node aggregates normalized features from itself and neighbors, then applies transformation.
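The propagation rule can be computed directly with dense tensors on a toy graph. This is only an illustration of the formula (real implementations use sparse message passing, not dense matrix products):

```python
import torch

# Adjacency matrix of the 4-node path graph 0 - 1 - 2 - 3.
A = torch.tensor([[0., 1., 0., 0.],
                  [1., 0., 1., 0.],
                  [0., 1., 0., 1.],
                  [0., 0., 1., 0.]])

A_hat = A + torch.eye(4)                    # Â = A + I (self-loops)
deg = A_hat.sum(dim=1)                      # degrees of Â
D_inv_sqrt = torch.diag(deg.pow(-0.5))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt    # D̂^(-1/2) Â D̂^(-1/2)

H = torch.randn(4, 3)                       # node features H^(k)
W = torch.randn(3, 2)                       # weights W^(k)
H_next = torch.relu(A_norm @ H @ W)         # one GCN layer
print(H_next.shape)  # torch.Size([4, 2])
```

Because Â is symmetric, so is the normalized operator; each row mixes a node's own features with its neighbors', downweighted by degree.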

GraphSAGE (SAmple and aggreGatE)

GraphSAGE addresses scalability by sampling neighborhoods and aggregating features:

Algorithm:

  1. Sample a fixed-size neighborhood for each node
  2. Aggregate features from the sampled neighbors
  3. Update the node representation with a learnable transformation

Because the learned parameters are shared across nodes rather than tied to one fixed graph, GraphSAGE is inductive: it can embed nodes unseen during training.

Aggregation Functions:

  • Mean: Average neighbor features
  • LSTM: Sequential aggregation with learnable weights
  • Pooling: Max/average pooling over neighbor features
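The sample-then-aggregate step can be sketched as below, using a hypothetical adjacency list and a fan-out of 2 with mean aggregation (library implementations sample per layer and batch this, but the idea is the same):

```python
import torch

torch.manual_seed(0)

# Hypothetical adjacency list for a tiny 4-node graph.
neighbors = {0: [1, 2, 3], 1: [0], 2: [0, 3], 3: [0, 2]}
fanout = 2                     # at most 2 sampled neighbors per node

h = torch.randn(4, 8)          # current node features

def sample_and_mean(node):
    nbrs = neighbors[node]
    if len(nbrs) > fanout:     # subsample large neighborhoods
        idx = torch.randperm(len(nbrs))[:fanout]
        nbrs = [nbrs[i] for i in idx]
    return h[torch.tensor(nbrs)].mean(dim=0)

agg = torch.stack([sample_and_mean(v) for v in range(4)])
print(agg.shape)  # torch.Size([4, 8])
```

Capping the fan-out is what keeps the per-node cost bounded regardless of graph size.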

Practical Implementation

Graph Convolutional Network in PyTorch Geometric

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv

class GCN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.num_layers = num_layers

        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))

        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)

        x = self.convs[-1](x, edge_index)
        return x

# Load Cora dataset
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

# Initialize model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
data = data.to(device)  # move features, labels, and masks to the device once
model = GCN(
    in_channels=data.num_node_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

# Training loop
def train():
    model.train()
    optimizer.zero_grad()

    out = model(data.x.to(device), data.edge_index.to(device))
    loss = criterion(out[data.train_mask], data.y[data.train_mask].to(device))

    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    model.eval()
    out = model(data.x.to(device), data.edge_index.to(device))

    pred = out.argmax(dim=1)
    train_correct = (pred[data.train_mask] == data.y[data.train_mask].to(device)).sum().item()
    train_acc = train_correct / data.train_mask.sum().item()

    val_correct = (pred[data.val_mask] == data.y[data.val_mask].to(device)).sum().item()
    val_acc = val_correct / data.val_mask.sum().item()

    test_correct = (pred[data.test_mask] == data.y[data.test_mask].to(device)).sum().item()
    test_acc = test_correct / data.test_mask.sum().item()

    return train_acc, val_acc, test_acc

for epoch in range(1, 201):
    loss = train()
    train_acc, val_acc, test_acc = test()
    if epoch % 20 == 0:
        print(f'Epoch {epoch:3d}, Loss: {loss:.4f}, Train: {train_acc:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

GraphSAGE Implementation

from torch_geometric.nn import SAGEConv

class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))

        self.dropout = nn.Dropout(0.5)

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = self.dropout(x)
        x = self.convs[-1](x, edge_index)
        return x

# Mini-batch training with neighbor sampling
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    data,
    num_neighbors=[25, 10],
    batch_size=1024,
    input_nodes=data.train_mask,
)

model = GraphSAGE(
    in_channels=data.num_node_features,
    hidden_channels=64,
    out_channels=dataset.num_classes,
    num_layers=2
).to(device)

# Re-create the optimizer so it tracks the new model's parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

for epoch in range(1, 21):
    model.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        out = model(batch.x, batch.edge_index)
        loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])

        loss.backward()
        optimizer.step()
        total_loss += loss.item() * batch.batch_size

    # Normalize by the number of training nodes, not the total node count
    print(f'Epoch {epoch:2d}, Loss: {total_loss / int(data.train_mask.sum()):.4f}')

Attention-based GNN (GAT)

from torch_geometric.nn import GATConv

class GAT(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=num_heads, dropout=0.6)
        self.conv2 = GATConv(hidden_channels * num_heads, out_channels, heads=1, dropout=0.6)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GAT(
    in_channels=data.num_node_features,
    hidden_channels=8,
    out_channels=dataset.num_classes,
    num_heads=8
).to(device)

Advanced Techniques

Custom Message Passing Layers

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import scatter  # avoids the separate torch_scatter dependency

class CustomGNNLayer(MessagePassing):
    def __init__(self, in_channels, out_channels, aggr='add'):
        super().__init__(aggr=aggr)
        self.lin = nn.Linear(in_channels, out_channels)
        self.lin_neighbor = nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x shape: [num_nodes, in_channels]
        # edge_index shape: [2, num_edges]

        # Propagate messages
        out = self.propagate(edge_index, x=x)
        return out

    def message(self, x_j, x_i):
        # x_j: features of source nodes
        # x_i: features of target nodes
        return self.lin_neighbor(x_j)

    def aggregate(self, inputs, index):
        # Custom aggregation
        return scatter(inputs, index, dim=self.node_dim, reduce=self.aggr)

    def update(self, inputs, x):
        # Update with node features
        return self.lin(x) + inputs

Graph Pooling

from torch_geometric.nn import global_mean_pool, global_max_pool, SAGPooling

class PoolingGNN(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.pool1 = SAGPooling(hidden_channels, ratio=0.8)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.pool2 = SAGPooling(hidden_channels, ratio=0.8)

        self.fc = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool1(x, edge_index, batch=batch)

        x = F.relu(self.conv2(x, edge_index))
        x, edge_index, _, batch, _, _ = self.pool2(x, edge_index, batch=batch)

        x = global_mean_pool(x, batch)
        return self.fc(x)

Heterogeneous Graph Neural Networks

from torch_geometric.nn import HeteroConv, SAGEConv

class HeteroGNN(nn.Module):
    def __init__(self, hidden_channels, out_channels, metadata):
        super().__init__()

        # GCNConv assumes a single node type; SAGEConv with lazy input
        # sizes (-1, -1) handles bipartite edge types between different
        # source and destination node types.
        self.conv = HeteroConv({
            edge_type: SAGEConv((-1, -1), hidden_channels)
            for edge_type in metadata[1]
        })

        self.lin = nn.Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        x_dict = self.conv(x_dict, edge_index_dict)
        x_dict = {key: F.relu(x) for key, x in x_dict.items()}
        return {key: self.lin(x) for key, x in x_dict.items()}

Production Considerations

Scalability

Challenges:

  • Large graphs don’t fit in GPU memory
  • Full-batch training becomes infeasible
  • Neighborhood explosion in deep GNNs: the receptive field grows exponentially with depth

Solutions:

# Subgraph sampling: partition the graph, then train one cluster batch at a time
from torch_geometric.loader import ClusterData, ClusterLoader

cluster_data = ClusterData(data, num_parts=100)   # METIS partitioning
cluster_loader = ClusterLoader(
    cluster_data,
    batch_size=20,
    shuffle=True
)

# Layer-wise neighbor sampling (older API; NeighborLoader above is
# preferred in recent PyTorch Geometric versions)
from torch_geometric.loader import NeighborSampler

sampler = NeighborSampler(
    data.edge_index,
    sizes=[25, 10],        # fan-out per layer
    batch_size=1024,
    num_workers=4
)

Evaluation on Graphs

from sklearn.metrics import accuracy_score, roc_auc_score

@torch.no_grad()
def evaluate_node_classification(model, data, mask):
    model.eval()
    logits = model(data.x, data.edge_index)
    preds = logits[mask].argmax(dim=1)
    acc = accuracy_score(data.y[mask].cpu(), preds.cpu())
    return acc

# Link prediction
from torch_geometric.utils import negative_sampling

def evaluate_link_prediction(model, data):
    model.eval()
    z = model.encode(data.x, data.edge_index)

    pos_edge_index = data.test_pos_edge_index
    neg_edge_index = negative_sampling(
        data.edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=pos_edge_index.size(1)
    )

    pos_scores = torch.sigmoid((z[pos_edge_index[0]] * z[pos_edge_index[1]]).sum(dim=1))
    neg_scores = torch.sigmoid((z[neg_edge_index[0]] * z[neg_edge_index[1]]).sum(dim=1))

    labels = torch.cat([torch.ones_like(pos_scores), torch.zeros_like(neg_scores)])
    scores = torch.cat([pos_scores, neg_scores])

    auc = roc_auc_score(labels.cpu().numpy(), scores.detach().cpu().numpy())
    return auc

Model Export and Serving

# Convert to ONNX for inference. Export on CPU with representative dummy
# inputs; scatter-based message passing needs a recent opset, and the
# exported graph should be verified before deployment.
model = model.cpu().eval()

dummy_x = torch.randn(10, data.num_node_features)
dummy_edge_index = torch.randint(0, 10, (2, 20))

torch.onnx.export(
    model,
    (dummy_x, dummy_edge_index),
    'gnn_model.onnx',
    input_names=['x', 'edge_index'],
    output_names=['output'],
    opset_version=16
)

Key Takeaway

Graph Neural Networks unlock the power of structured relationships in data. By learning to propagate information through graph neighborhoods, GNNs excel at node classification, link prediction, and graph-level tasks. Master message passing, sampling strategies, and graph pooling to build scalable models for complex relational data.

Practical Exercise

Task: Build a GNN-based link prediction system for a social network or citation network.

Requirements:

  1. Load a graph dataset (Cora, Citeseer, or custom)
  2. Implement both GCN and GraphSAGE architectures
  3. Create train/val/test split on edges (80/10/10)
  4. Train with link prediction objective:
    • Positive pairs: existing edges
    • Negative pairs: sampled non-edges
  5. Use dot product similarity: score = z_i · z_j
  6. Implement evaluation using AUC-ROC and AP metrics
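As a starting point, the scoring and loss from steps 4 and 5 might look like the sketch below. The embeddings and edge pairs here are random placeholders that you would replace with your GNN encoder's output and your actual edge splits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
z = torch.randn(100, 16)   # placeholder: node embeddings from a GNN encoder

def score(edge_index):
    # Dot-product similarity: score_ij = z_i . z_j
    return (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)

pos_edges = torch.randint(0, 100, (2, 50))   # placeholder: existing edges
neg_edges = torch.randint(0, 100, (2, 50))   # placeholder: sampled non-edges

logits = torch.cat([score(pos_edges), score(neg_edges)])
labels = torch.cat([torch.ones(50), torch.zeros(50)])
loss = F.binary_cross_entropy_with_logits(logits, labels)
```

Backpropagating this loss through the encoder trains the embeddings; AUC-ROC and AP are then computed on the held-out test edges with the same scoring function.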

Evaluation:

  • Compare GCN vs GraphSAGE on test AUC
  • Analyze computational efficiency (training time)
  • Test inductive setting: encode unseen nodes
  • Visualize learned embeddings with t-SNE
  • Implement hard negative sampling for better training