The Transformer Architecture
The Transformer architecture revolutionized deep learning by replacing recurrence entirely with self-attention mechanisms. This lesson covers the mathematical foundations of self-attention, multi-head attention, positional encoding, and the encoder-decoder structure that underlies modern language models and beyond.
Core Concepts
Self-Attention Mechanism
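Self-attention computes relationships between all pairs of positions in a sequence with a few matrix multiplications. Every position is therefore processed in parallel, and any two positions are connected directly no matter how far apart they are, which helps with long-range dependencies.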
Scaled Dot-Product Attention:
Attention(Q, K, V) = softmax(QK^T / √d_k)V
Where:
- Q (Query): Learned projection asking “what am I looking for?”
- K (Key): Learned projection answering “what can I offer?”
- V (Value): Learned projection carrying the content that is averaged into the output
- d_k: Dimension of queries and keys (used for scaling)
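Why Scaling? If the components of a query q and a key k are independent with mean 0 and variance 1, their dot product q·k has variance d_k, so raw scores grow with the key dimension. Large scores push the softmax into saturated regions where one weight is close to 1 and the gradients are tiny. Dividing by √d_k keeps the score variance near 1.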
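The variance argument can be checked empirically. This is a small sketch that assumes random query and key vectors with unit-variance components (an illustrative assumption, not trained projections); dot_product_variance is a hypothetical helper name.

```python
import torch

def dot_product_variance(d_k, n_samples=10_000, seed=0):
    """Empirical variance of q·k and of q·k / sqrt(d_k) for random unit-variance vectors."""
    g = torch.Generator().manual_seed(seed)
    q = torch.randn(n_samples, d_k, generator=g)
    k = torch.randn(n_samples, d_k, generator=g)
    raw = (q * k).sum(dim=-1)          # unscaled dot products
    scaled = raw / d_k ** 0.5          # scaled as in the attention formula
    return raw.var().item(), scaled.var().item()

for d_k in (4, 64, 512):
    raw_var, scaled_var = dot_product_variance(d_k)
    print(f"d_k={d_k:4d}  var(q·k)={raw_var:7.1f}  var(q·k/√d_k)={scaled_var:.2f}")
```

The unscaled variance grows roughly linearly with d_k, while the scaled variance stays close to 1 for every dimension.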
Step-by-step:
- Compute attention scores: S = QK^T / √d_k
- Apply softmax to normalize: A = softmax(S)
- Weight values: Output = AV
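The three steps map directly onto tensor operations. The sketch below uses toy sizes chosen for illustration and compares the result with PyTorch's fused `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later), which computes the same formula.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_k, d_v = 4, 8, 8                      # toy sizes chosen for illustration
Q = torch.randn(n, d_k)
K = torch.randn(n, d_k)
V = torch.randn(n, d_v)

S = Q @ K.T / math.sqrt(d_k)               # 1. scores, shape [n, n]
A = torch.softmax(S, dim=-1)               # 2. each row is a distribution over keys
out = A @ V                                # 3. weighted sum of values, shape [n, d_v]

print(A.sum(dim=-1))                       # every row sums to 1
# Fused implementation of the same formula (default scale is 1 / sqrt(d_k))
ref = F.scaled_dot_product_attention(Q.unsqueeze(0), K.unsqueeze(0), V.unsqueeze(0))[0]
print(torch.allclose(out, ref, atol=1e-5))
```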
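Computational Complexity: O(n²d) time, where n is sequence length and d is embedding dimension, plus O(n²) memory per head if the attention matrix is stored explicitly. Both terms are quadratic in sequence length, which is what makes long contexts expensive.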
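A quick calculation shows how fast the n × n term grows. This sketch counts only the attention-weight matrices of one layer, assuming float32 (4 bytes per element), 8 heads and a naive implementation that stores the full matrix; attention_matrix_bytes is a hypothetical helper.

```python
def attention_matrix_bytes(seq_len, num_heads=1, batch_size=1, bytes_per_elem=4):
    """Memory for the n x n attention matrices alone (float32 by default)."""
    return batch_size * num_heads * seq_len * seq_len * bytes_per_elem

for n in (512, 2048, 8192):
    mib = attention_matrix_bytes(n, num_heads=8) / 2**20
    print(f"n={n:5d}: {mib:6.0f} MiB per layer per sequence")
```

Each 4× increase in sequence length multiplies this memory by 16: 8 MiB at n=512 becomes 2048 MiB at n=8192.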
Multi-Head Attention
Instead of a single attention function, compute h attention heads in parallel, each with its own learned projections:
MultiHead(Q,K,V) = Concat(head_1, ..., head_h)W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
Advantages:
- Different heads attend to different representation subspaces
- One head might focus on syntax, another on semantics
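- Richer representations at similar cost: because each head works in d/h dimensions, the total computation is close to that of single-head attention with full dimensionality
Example: With d=512 and h=8, each head works in a 64-dimensional subspace (d_k = d/h = 512/8 = 64)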
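The split into heads is just a reshape of one d × d projection. This sketch assumes a batch of 2 sequences of length 10 with random inputs; it shows the per-head shapes for d=512, h=8 and that the projection parameter count does not depend on h.

```python
import torch
import torch.nn as nn

batch, seq_len, d_model, h = 2, 10, 512, 8
d_k = d_model // h                                # 64 dimensions per head

x = torch.randn(batch, seq_len, d_model)
W_q = nn.Linear(d_model, d_model)                 # one projection covering all heads
q = W_q(x)                                        # [2, 10, 512]
q_heads = q.view(batch, seq_len, h, d_k).transpose(1, 2)
print(q_heads.shape)                              # torch.Size([2, 8, 10, 64])

# Merging the heads back is the inverse reshape
merged = q_heads.transpose(1, 2).reshape(batch, seq_len, d_model)
print(torch.equal(merged, q))                     # True

# Parameters of W_q, W_k, W_v, W_o (weights + biases): the same for any h dividing d_model
params = 4 * (d_model * d_model + d_model)
print(params)                                     # 1050624
```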
Positional Encoding
Self-attention has no built-in notion of order: permuting the input tokens simply permutes the outputs. Position information must therefore be injected, typically by adding positional encodings to the token embeddings:
PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))
Properties:
- Periodic patterns at different frequencies
- Each position has unique encoding
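- For any fixed offset k, PE(pos+k) is a linear function of PE(pos), which may make relative positions easier to learn
- Defined for every position, so it can be applied to sequences longer than those seen in training (although good generalization to such lengths is not guaranteed)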
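The "linear function of PE(pos)" property is concrete: within each (sin, cos) pair with frequency ω_i = 1/10000^(2i/d), shifting the position by k rotates the pair by the angle k·ω_i, and that rotation does not depend on pos. The sketch below checks this numerically for d=16 (an arbitrary choice); sinusoidal_pe and shift_by are hypothetical helper names.

```python
import math
import torch

def sinusoidal_pe(num_positions, d_model):
    """PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d))."""
    pos = torch.arange(num_positions, dtype=torch.float64).unsqueeze(1)
    freqs = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float64) * (-math.log(10000.0) / d_model))
    pe = torch.zeros(num_positions, d_model, dtype=torch.float64)
    pe[:, 0::2] = torch.sin(pos * freqs)
    pe[:, 1::2] = torch.cos(pos * freqs)
    return pe, freqs

def shift_by(pe_row, freqs, k):
    """Map PE(pos) to PE(pos + k) with a fixed 2x2 rotation per frequency (independent of pos)."""
    s, c = pe_row[0::2], pe_row[1::2]              # sin(w*pos), cos(w*pos)
    ck, sk = torch.cos(k * freqs), torch.sin(k * freqs)
    out = torch.empty_like(pe_row)
    out[0::2] = s * ck + c * sk                    # sin(w*(pos+k))
    out[1::2] = c * ck - s * sk                    # cos(w*(pos+k))
    return out

pe, freqs = sinusoidal_pe(100, 16)
print(torch.allclose(shift_by(pe[10], freqs, 5), pe[15]))   # True
print(torch.allclose(shift_by(pe[40], freqs, 5), pe[45]))   # True: same map for every pos
```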
Encoder-Decoder Architecture
Encoder Stack:
- Input embeddings + positional encoding
- N identical layers with:
  - Multi-head self-attention
  - Feed-forward network (2 linear layers with ReLU)
  - Residual connections and layer normalization
- Output: context-aware representations
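PyTorch ships this encoder structure as `nn.TransformerEncoderLayer` and `nn.TransformerEncoder` (post-norm and a ReLU feed-forward by default). The sketch below uses them only as a shape check, assuming the base sizes of the original paper (d=512, h=8, feed-forward width 2048, N=6) and a random tensor in place of real embeddings plus positional encodings.

```python
import torch
import torch.nn as nn

# batch_first=True means tensors are laid out as [batch, seq, d_model]
layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, dim_feedforward=2048,
                                   dropout=0.1, activation="relu", batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=6)

x = torch.randn(2, 10, 512)     # stands in for embeddings + positional encoding
out = encoder(x)
print(out.shape)                # torch.Size([2, 10, 512]): one vector per input position
```

Because every sublayer maps d-dimensional vectors to d-dimensional vectors, layers can be stacked freely and the residual additions are always shape-compatible.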
Decoder Stack:
- Output embeddings + positional encoding
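- N identical layers with:
  - Masked multi-head self-attention (a position can't attend to later positions)
  - Encoder-decoder (cross) attention (queries from the decoder, keys and values from the encoder output)
  - Feed-forward network
  - Residual connections and layer normalization
- Output: a linear projection to vocabulary logits; a softmax turns them into next-token probabilities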
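Masking: During training the decoder receives the whole target sequence at once (teacher forcing). The causal mask stops position i from attending to positions after i, so the model cannot look at the very tokens it is supposed to predict.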
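The "no looking ahead" property can be checked directly. This simplified sketch assumes a single head with no learned projections (Q = K = V = x) for brevity; causal_self_attention is a hypothetical helper. It changes the last token and confirms that the outputs at all earlier positions are unaffected.

```python
import math
import torch

def causal_self_attention(x):
    """Single-head self-attention (Q = K = V = x, no projections) with a causal mask."""
    n, d = x.shape
    scores = x @ x.T / math.sqrt(d)
    mask = torch.tril(torch.ones(n, n, dtype=torch.bool))    # True = may attend
    scores = scores.masked_fill(~mask, float("-inf"))        # hide future positions
    weights = torch.softmax(scores, dim=-1)                  # exp(-inf) = 0 weight
    return weights @ x, weights

torch.manual_seed(0)
x = torch.randn(5, 8)
out, w = causal_self_attention(x)
print(w[0])        # position 0 can only attend to itself

# Changing the last token must not change outputs at earlier positions
x_changed = x.clone()
x_changed[-1] = torch.randn(8)
out_changed, _ = causal_self_attention(x_changed)
print(torch.allclose(out[:-1], out_changed[:-1]))   # True
```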
Practical Implementation
Building Transformer from Scratch
```python
import math

import torch
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal encodings to a [batch, seq, d_model] input (d_model must be even)."""

    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i/d_model), computed in log space
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions
        # Buffer: saved with the model and moved by .to(device), but not trained
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]


class ScaledDotProductAttention(nn.Module):
    def forward(self, Q, K, V, mask=None):
        # Q, K, V: [..., seq, d_k]; mask broadcasts to [..., seq_q, seq_k], 0 = blocked
        d_k = Q.shape[-1]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attention = torch.softmax(scores, dim=-1)
        return torch.matmul(attention, V), attention


class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.attention = ScaledDotProductAttention()

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.shape[0]
        # Linear projections, then split into heads: [batch, heads, seq, d_k]
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        # Apply attention to all heads in parallel
        attn_output, attn_weights = self.attention(Q, K, V, mask)
        self.attn_weights = attn_weights.detach()  # kept for inspection / visualization
        # Concatenate heads: [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)
        # Final linear transformation
        return self.W_o(attn_output)


class FeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


class TransformerLayer(nn.Module):
    """Encoder layer with post-norm residual blocks, as in the original Transformer."""

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual
        attn_output = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))
        # Feed-forward with residual
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))
        return x


class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_length=5000):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        # x: [batch, src_len] token ids
        # Embeddings are scaled by sqrt(d_model) before adding positions, as in the original paper
        x = self.pos_encoding(self.embedding(x) * math.sqrt(self.d_model))
        for layer in self.layers:
            x = layer(x, mask)
        return x


class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.encoder_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)  # stateless, so one module serves all three sublayers

    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        # Masked self-attention: each target position sees only earlier positions
        attn = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(attn))
        # Cross-attention: queries from the decoder, keys/values from the encoder output
        attn = self.encoder_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + self.dropout(attn))
        # Position-wise feed-forward
        ff = self.feed_forward(x)
        x = self.norm3(x + self.dropout(ff))
        return x


class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_length=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        self.d_model = d_model
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        # x: [batch, tgt_len] token ids; returns logits [batch, tgt_len, vocab_size]
        x = self.pos_encoding(self.embedding(x) * math.sqrt(self.d_model))
        for layer in self.layers:
            x = layer(x, encoder_output, tgt_mask, src_mask)
        return self.fc_out(x)
```
Causal Masking for Autoregressive Decoding
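```python
def create_causal_mask(seq_length, device):
    """Lower-triangular mask: position i may attend to positions 0..i (1 = allowed, 0 = blocked)."""
    mask = torch.tril(torch.ones(seq_length, seq_length, device=device))
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]; broadcasts over batch and heads

# In the decoder forward pass (e.g. training with teacher forcing);
# tgt is [batch, tgt_len] token ids, encoder_output and decoder as defined above
tgt_mask = create_causal_mask(tgt.size(1), tgt.device)
decoder_output = decoder(tgt, encoder_output, tgt_mask)
```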
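The src_mask argument accepted by the encoder and decoder is not built anywhere in this lesson. One common choice, assumed here, is a padding mask that hides padded source positions; for the target, the padding mask is usually multiplied with the causal mask. create_padding_mask and the pad id 0 are hypothetical, and create_causal_mask is repeated so the block runs on its own. Both masks follow the convention of ScaledDotProductAttention (0 = blocked).

```python
import torch

def create_padding_mask(tokens, pad_idx):
    """[batch, seq] token ids -> [batch, 1, 1, seq] mask; 1 = real token, 0 = padding."""
    return (tokens != pad_idx).unsqueeze(1).unsqueeze(2).int()

def create_causal_mask(seq_length, device):
    """Lower-triangular [1, 1, seq, seq] mask; 1 = allowed to attend."""
    mask = torch.tril(torch.ones(seq_length, seq_length, device=device))
    return mask.unsqueeze(0).unsqueeze(0)

PAD = 0                                   # hypothetical padding id
src = torch.tensor([[5, 7, 9, PAD],
                    [3, 4, PAD, PAD]])
tgt = torch.tensor([[1, 6, 8],
                    [1, 2, PAD]])

# [2, 1, 1, 4]: broadcasts over heads and query positions
src_mask = create_padding_mask(src, PAD)
# [2, 1, 3, 3]: a key is visible only if it is not padding AND not in the future
tgt_mask = create_padding_mask(tgt, PAD) * create_causal_mask(tgt.size(1), tgt.device)
print(src_mask.shape, tgt_mask.shape)
```

A query row whose keys are all masked would make the softmax return NaN; this example avoids that because the first target token is never padding.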
Advanced Techniques
Relative Position Representations
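Absolute positional encodings (sinusoidal or learned) represent the distance between two tokens only implicitly. Relative position representations instead feed the offset between query position i and key position j directly into attention, for example as a learned per-head bias added to the scores before the softmax: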
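```python
class RelativePositionalBias(nn.Module):
    """Learned per-head bias indexed by the clipped relative distance i - j."""
    def __init__(self, num_heads, max_distance):
        super().__init__()
        # Clipped distances span [-max_distance, max_distance]: 2 * max_distance + 1 entries
        self.relative_attention_bias = nn.Embedding(2 * max_distance + 1, num_heads)
        self.max_distance = max_distance

    def forward(self, seq_length):
        positions = torch.arange(seq_length, device=self.relative_attention_bias.weight.device)
        relative_positions = positions.unsqueeze(1) - positions           # [seq, seq], entry i - j
        relative_positions = torch.clamp(relative_positions, -self.max_distance, self.max_distance)
        relative_positions = relative_positions + self.max_distance      # shift to [0, 2 * max_distance]
        bias = self.relative_attention_bias(relative_positions)           # [seq, seq, num_heads]
        return bias.permute(2, 0, 1).unsqueeze(0)  # [1, num_heads, seq, seq], add to attention scores
```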
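A worked example of the index computation, assuming toy sizes seq_length=5, max_distance=2 and num_heads=4. The clipped, shifted offsets run from 0 to 2 · max_distance, so the bias table needs 2 · max_distance + 1 rows, and the [seq, seq, heads] lookup has to be permuted to [1, heads, seq, seq] before it can be added to scores of shape [batch, heads, seq, seq].

```python
import torch

seq_length, max_distance, num_heads = 5, 2, 4     # toy sizes for illustration

rel = torch.arange(seq_length).unsqueeze(1) - torch.arange(seq_length)   # rel[i, j] = i - j
idx = torch.clamp(rel, -max_distance, max_distance) + max_distance
print(idx.tolist())
# [[2, 1, 0, 0, 0],
#  [3, 2, 1, 0, 0],
#  [4, 3, 2, 1, 0],
#  [4, 4, 3, 2, 1],
#  [4, 4, 4, 3, 2]]
print(idx.max().item() + 1)                       # 5 == 2 * max_distance + 1 rows needed

bias_table = torch.nn.Embedding(2 * max_distance + 1, num_heads)
bias = bias_table(idx).permute(2, 0, 1).unsqueeze(0)    # [1, heads, seq, seq]
scores = torch.randn(3, num_heads, seq_length, seq_length)
scores = scores + bias                                  # broadcasts over the batch
```

All offsets of 2 or more in either direction share one bias value per head, which is what the clipping buys: a fixed-size table regardless of sequence length.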
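Flash Attention (Efficient Implementation)
FlashAttention computes exact attention in tiles that fit in fast on-chip GPU memory, so the full n × n score matrix is never written to main GPU memory. Memory use then grows linearly rather than quadratically with sequence length, and the result matches standard attention up to floating-point rounding.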
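```python
try:
    from flash_attn import flash_attn_func  # optional "flash-attn" package, CUDA GPUs only
except ImportError:
    flash_attn_func = None
    print("Flash Attention not available, falling back to PyTorch attention")
class FlashAttention(nn.Module):
    def forward(self, Q, K, V, causal=False):
        # Q, K, V: [batch, heads, seq, d_k]; flash_attn_func expects [batch, seq, heads, d_k] in fp16/bf16
        if flash_attn_func is not None and Q.is_cuda and Q.dtype in (torch.float16, torch.bfloat16):
            out = flash_attn_func(Q.transpose(1, 2), K.transpose(1, 2), V.transpose(1, 2),
                                  dropout_p=0.0, causal=causal)
            return out.transpose(1, 2)
        # Fallback: PyTorch >= 2.0 fused attention (may dispatch to a FlashAttention kernel itself)
        return torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=causal)
```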
Layer Normalization Variants
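```python
class RMSNorm(nn.Module):
    """Root-mean-square normalization: like LayerNorm, but without mean subtraction or bias."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))  # learned per-feature gain
        self.eps = eps

    def forward(self, x):
        # x / sqrt(mean(x^2) + eps) over the last dimension, then scale by the learned gain
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
```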
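To make the difference from LayerNorm concrete, the sketch below applies the same RMSNorm computation (written as a standalone function so the block runs on its own) and `torch.nn.functional.layer_norm` to inputs with a nonzero mean. The input distribution is an arbitrary choice for illustration.

```python
import torch

def rms_norm(x, weight, eps=1e-6):
    """Same computation as RMSNorm.forward: divide by the RMS over the last dim, then scale."""
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

torch.manual_seed(0)
x = torch.randn(4, 16) * 3 + 1          # arbitrary scale and nonzero offset
w = torch.ones(16)

y_rms = rms_norm(x, w)
y_ln = torch.nn.functional.layer_norm(x, (16,))

print(y_rms.pow(2).mean(-1))            # ~1 per row: unit root-mean-square
print(y_rms.mean(-1))                   # generally not 0: RMSNorm does not re-center
print(y_ln.mean(-1))                    # ~0: LayerNorm subtracts the mean
```

Skipping the mean makes RMSNorm slightly cheaper; whether re-centering matters for a given model is an empirical question.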
Production Considerations
Efficient Inference
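```python
@torch.no_grad()
def generate(encoder_output, decoder, start_token, max_length, device, src_mask=None):
    """Greedy autoregressive decoding. No KV cache: the whole prefix is re-decoded at every step."""
    batch_size = encoder_output.shape[0]
    tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)
    for _ in range(max_length - 1):
        t = tgt.size(1)
        # Causal mask so each step uses the same attention pattern as in training
        tgt_mask = torch.tril(torch.ones(t, t, device=device)).view(1, 1, t, t)
        logits = decoder(tgt, encoder_output, tgt_mask, src_mask)     # [batch, t, vocab]
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)    # greedy pick at last position
        tgt = torch.cat([tgt, next_token], dim=1)
    return tgt
```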
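The generate loop above re-decodes the entire prefix at every step, so step t repeats work for all earlier positions. KV caching avoids this by storing the keys and values already computed for earlier tokens and computing query, key and value only for the newest token (for cross-attention, the keys and values of the encoder output can likewise be computed once). The sketch below shows the idea for a single causal self-attention head; the random weights W_q, W_k, W_v and the dict-based cache are illustrative assumptions, not part of the model above. It checks that incremental decoding reproduces full recomputation.

```python
import math
import torch

torch.manual_seed(0)
d = 16                                             # toy model width (assumption)
W_q, W_k, W_v = (torch.randn(d, d) / math.sqrt(d) for _ in range(3))

def full_causal_attention(x):
    """Recompute attention for the whole prefix (what generate() does each step)."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    t = x.shape[0]
    scores = q @ k.T / math.sqrt(d)
    mask = torch.tril(torch.ones(t, t, dtype=torch.bool))
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

def cached_step(x_new, cache):
    """Attend from the newest token only, reusing keys/values of earlier tokens."""
    q = x_new @ W_q                                       # [1, d]
    cache["k"] = torch.cat([cache["k"], x_new @ W_k])     # append this step's key
    cache["v"] = torch.cat([cache["v"], x_new @ W_v])     # append this step's value
    scores = q @ cache["k"].T / math.sqrt(d)              # newest token may see all cached positions
    return torch.softmax(scores, dim=-1) @ cache["v"]

xs = torch.randn(6, d)                                     # stands in for 6 decoded token embeddings
cache = {"k": torch.empty(0, d), "v": torch.empty(0, d)}
incremental = torch.cat([cached_step(xs[t:t + 1], cache) for t in range(6)])
print(torch.allclose(incremental, full_causal_attention(xs), atol=1e-5))   # True
```

No mask is needed in the cached step: the cache only ever contains the current and earlier tokens, so causality holds by construction.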
Attention Visualization
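```python
def visualize_attention(attention_weights, tokens, batch_idx=0, head=0):
    """Heatmap of one head. attention_weights: [batch, heads, seq_q, seq_k], e.g. the second value
    returned by ScaledDotProductAttention. For self-attention, rows and columns share `tokens`."""
    import matplotlib.pyplot as plt
    import seaborn as sns
    plt.figure(figsize=(8, 6))
    sns.heatmap(attention_weights[batch_idx, head].detach().cpu().numpy(),
                xticklabels=tokens,   # keys (attended-to positions)
                yticklabels=tokens,   # queries
                cmap='viridis')
    plt.xlabel('Key position')
    plt.ylabel('Query position')
    plt.title(f'Attention Weights (head {head})')
    plt.show()
```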
Key Takeaway
The Transformer architecture’s self-attention mechanism enables parallel processing of entire sequences while capturing long-range dependencies. Understanding scaled dot-product attention, multi-head projection, positional encoding, and the encoder-decoder structure is fundamental to building and customizing modern neural architectures.
Practical Exercise
Task: Implement a sequence-to-sequence Transformer for machine translation (English to French).
Requirements:
- Build encoder-decoder Transformer from scratch
- Implement causal masking for decoder
- Train on a subset of the WMT14 English-French dataset
- Implement beam search for decoding
- Evaluate with BLEU score
Evaluation:
- Achieve BLEU score > 15 on test set
- Visualize attention patterns to verify it learns meaningful alignments
- Compare greedy decoding vs. beam search
- Profile inference time and memory usage
- Analyze failure cases and attention visualizations