Intermediate

The Transformer Architecture

Lesson 1 of 4 Estimated Time 60 min

The Transformer architecture revolutionized deep learning by replacing recurrence entirely with self-attention mechanisms. This lesson covers the mathematical foundations of self-attention, multi-head attention, positional encoding, and the encoder-decoder structure that underlies modern language models and models in other domains.

Core Concepts

Self-Attention Mechanism

Self-attention relates every position in a sequence to every other position in a single operation. This allows all positions to be processed in parallel and lets the model connect distant tokens directly.

Scaled Dot-Product Attention:

Attention(Q, K, V) = softmax(QK^T / √d_k)V

Where:

  • Q (Query): Learned projection asking “what am I looking for?”
  • K (Key): Learned projection answering “what can I offer?”
  • V (Value): Information to be attended to
  • d_k: Dimension of keys (used for scaling)

Why Scaling? Without the √d_k scaling, dot products grow large in magnitude as d_k increases, pushing the softmax into saturated regions where its gradients become extremely small.
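
The usual argument for the scaling factor can be written as a short variance calculation. It assumes the components of a query q and a key k are independent with zero mean and unit variance, which is an idealization rather than a property of trained models:

```latex
\begin{aligned}
q \cdot k &= \sum_{i=1}^{d_k} q_i k_i, \\
\mathbb{E}[q \cdot k] &= \sum_{i=1}^{d_k} \mathbb{E}[q_i]\,\mathbb{E}[k_i] = 0, \\
\operatorname{Var}(q \cdot k) &= \sum_{i=1}^{d_k} \operatorname{Var}(q_i k_i) = d_k, \\
\operatorname{Var}\!\left(\frac{q \cdot k}{\sqrt{d_k}}\right) &= 1 .
\end{aligned}
```

Under these assumptions, dividing by √d_k keeps the spread of the scores independent of the key dimension.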

Step-by-step:

  1. Compute attention scores: S = QK^T / √d_k
  2. Apply softmax to normalize: A = softmax(S)
  3. Weight values: Output = AV
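
The three steps above can be traced by hand on a tiny input. The following is a minimal pure-Python sketch, not an efficient implementation; the helper names (softmax, matmul, attention) and the 2x2 example matrices are chosen here for illustration only.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def attention(Q, K, V):
    d_k = len(Q[0])
    K_T = [list(col) for col in zip(*K)]
    # Step 1: scaled scores S = QK^T / sqrt(d_k)
    S = [[s / math.sqrt(d_k) for s in row] for row in matmul(Q, K_T)]
    # Step 2: row-wise softmax, so each query's weights sum to 1
    A = [softmax(row) for row in S]
    # Step 3: each output row is a weighted average of the value rows
    return matmul(A, V), A

# Illustrative inputs: two positions, d_k = 2
Q = [[1.0, 0.0], [0.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[1.0, 2.0], [3.0, 4.0]]
output, weights = attention(Q, K, V)
```

Because each query matches its own key most strongly here, each output row is pulled toward the value row at the same position.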

Computational Complexity: O(n²d), where n is the sequence length and d is the embedding dimension. The cost is quadratic in sequence length, which becomes the bottleneck for long inputs.
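
To make the quadratic growth concrete, the sketch below estimates the memory taken by the attention score matrices of a single layer. The defaults (8 heads, batch size 1, 4-byte float32 entries) are assumptions picked for illustration, and the function names are hypothetical.

```python
def attention_score_entries(seq_len, num_heads=1, batch_size=1):
    """Number of entries in the n x n score matrices of one attention layer."""
    return batch_size * num_heads * seq_len * seq_len

def score_memory_mib(seq_len, num_heads=8, batch_size=1, bytes_per_entry=4):
    """Memory for the score matrices in MiB, assuming float32 entries by default."""
    entries = attention_score_entries(seq_len, num_heads, batch_size)
    return entries * bytes_per_entry / 2**20

for n in (512, 1024, 2048, 4096):
    # Doubling the sequence length quadruples the memory
    print(n, score_memory_mib(n))
```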

Multi-Head Attention

Instead of a single attention function, compute h attention heads in parallel, each with its own learned projections:

MultiHead(Q,K,V) = Concat(head_1, ..., head_h)W^O

where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)

Advantages:

  • Different heads attend to different representation subspaces
  • One head might focus on syntax, another on semantics
  • Enables richer representations at roughly the cost of single-head attention, because each head works in a reduced dimension d/h

Example: With d=512 and h=8, each head operates on a 64-dimensional subspace (512 / 8 = 64).
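
The arithmetic behind this example, and the reason the head count does not change the size of the Q, K, V projections, can be checked with a few lines. This is a sketch that counts weights only (biases ignored); the function names are illustrative.

```python
def head_dim(d_model, num_heads):
    """Per-head dimension d_k = d_model / num_heads."""
    if d_model % num_heads != 0:
        raise ValueError("d_model must be divisible by num_heads")
    return d_model // num_heads

def projection_weights(d_model, num_heads):
    """Weight count of the Q, K, V projections summed over all heads (biases ignored)."""
    d_k = head_dim(d_model, num_heads)
    per_head = 3 * d_model * d_k  # one d_model x d_k matrix each for Q, K, V
    return num_heads * per_head

d_k = head_dim(512, 8)
multi = projection_weights(512, 8)   # eight 64-dimensional heads
single = projection_weights(512, 1)  # one 512-dimensional head
```

Both configurations use the same number of projection weights, since h heads of size d/h together span d dimensions.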

Positional Encoding

Self-attention by itself is order-agnostic: permuting the input tokens simply permutes the outputs. Position information must therefore be injected, which is done by adding positional encodings to the token embeddings:

PE(pos, 2i) = sin(pos / 10000^(2i/d))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d))

Properties:

  • Periodic patterns at different frequencies
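  • Each position has a unique encoding
  • Lets the model learn to attend by relative position, since PE(pos + k) is a linear function of PE(pos) for any fixed offset k
  • May extrapolate to sequences longer than those seen in training, although this is not guaranteed in practice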
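
The relative-position property can be verified numerically: each (sin, cos) pair at position pos + k is a fixed rotation of the pair at position pos. The sketch below assumes an even d_model; the names positional_encoding and shift are illustrative.

```python
import math

def positional_encoding(pos, d_model):
    """Sinusoidal encoding for one position (d_model assumed even)."""
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        pe[i] = math.sin(angle)
        pe[i + 1] = math.cos(angle)
    return pe

def shift(pe, k, d_model):
    """Apply the rotation that maps PE(pos) to PE(pos + k)."""
    out = [0.0] * d_model
    for i in range(0, d_model, 2):
        w = 1.0 / (10000 ** (i / d_model))  # frequency of this sin/cos pair
        s, c = pe[i], pe[i + 1]
        out[i] = s * math.cos(k * w) + c * math.sin(k * w)      # sin(a + kw)
        out[i + 1] = c * math.cos(k * w) - s * math.sin(k * w)  # cos(a + kw)
    return out

d_model = 8
direct = positional_encoding(7, d_model)
rotated = shift(positional_encoding(4, d_model), 3, d_model)
max_error = max(abs(a - b) for a, b in zip(direct, rotated))
```

The rotation depends only on the offset k, not on the absolute position, which is what makes relative offsets easy for a linear layer to express.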

Encoder-Decoder Architecture

Encoder Stack:

  • Input embeddings + positional encoding
  • N identical layers with:
    • Multi-head self-attention
    • Feed-forward network (2 linear layers with ReLU)
    • Residual connections and layer normalization
  • Output: context-aware representations

Decoder Stack:

  • Output embeddings + positional encoding
  • N identical layers with:
    • Masked multi-head self-attention (can’t attend to future)
    • Encoder-decoder attention (attends to encoder output)
    • Feed-forward network
    • Residual connections and layer normalization
  • Output: a prediction for the next token at each position (a final linear layer maps to vocabulary logits)

Masking: Prevents each decoder position from attending to later positions, so the model cannot look ahead at the target tokens it is being trained to predict.
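
The effect of the mask on the attention weights can be seen on a small example. This pure-Python sketch uses made-up scores for three positions; setting a masked score to negative infinity makes its softmax weight exactly zero.

```python
import math

def causal_mask(n):
    """Lower-triangular mask: entry [i][j] is 1 when position i may attend to j."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def masked_softmax(scores, mask):
    """Row-wise softmax where masked-out entries receive zero weight."""
    weights = []
    for score_row, mask_row in zip(scores, mask):
        masked = [s if m else float("-inf") for s, m in zip(score_row, mask_row)]
        peak = max(masked)
        exps = [math.exp(v - peak) for v in masked]  # exp(-inf) evaluates to 0.0
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights

# Illustrative raw scores for a 3-token sequence
scores = [[0.5, 2.0, 1.0], [1.0, 0.0, 3.0], [0.2, 0.1, 0.3]]
weights = masked_softmax(scores, causal_mask(3))
```

The first position can only attend to itself, while the last position attends to all three.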

Practical Implementation

Building Transformer from Scratch

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                            (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class ScaledDotProductAttention(nn.Module):
    def forward(self, Q, K, V, mask=None):
        d_k = Q.shape[-1]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention = torch.softmax(scores, dim=-1)
        return torch.matmul(attention, V), attention

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention()

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.shape[0]

        # Linear transformations
        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        attn_output, _ = self.attention(Q, K, V, mask)

        # Concatenate heads
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.d_model)

        # Final linear transformation
        return self.W_o(attn_output)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))

class TransformerLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Self-attention with residual
        attn_output = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_output))

        # Feed-forward with residual
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout2(ff_output))

        return x

class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_length=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)

        for layer in self.layers:
            x = layer(x, mask)

        return x

class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_seq_length=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_length)
        self.d_model = d_model

        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])

        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        x = self.embedding(x)
        x = self.pos_encoding(x)

        for layer in self.layers:
            x = layer(x, encoder_output, tgt_mask, src_mask)

        return self.fc_out(x)

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model, num_heads)
        self.encoder_attention = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model, d_ff)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, encoder_output, tgt_mask=None, src_mask=None):
        # Masked self-attention
        attn = self.self_attention(x, x, x, tgt_mask)
        x = self.norm1(x + attn)

        # Cross-attention to encoder
        attn = self.encoder_attention(x, encoder_output, encoder_output, src_mask)
        x = self.norm2(x + attn)

        # Feed-forward
        ff = self.feed_forward(x)
        x = self.norm3(x + ff)

        return x

Causal Masking for Autoregressive Decoding

def create_causal_mask(seq_length, device):
    """Create lower triangular matrix for causal masking"""
    mask = torch.tril(torch.ones(seq_length, seq_length, device=device))
    return mask.unsqueeze(0).unsqueeze(0)  # [1, 1, seq_len, seq_len]

# During decoding (assumes decoder, tgt, encoder_output, and device already exist)
tgt_mask = create_causal_mask(tgt.size(1), device)
decoder_output = decoder(tgt, encoder_output, tgt_mask)

Advanced Techniques

Relative Position Representations

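Sinusoidal encodings represent absolute positions, leaving relative distances to be inferred indirectly. Relative position representations instead learn a per-head bias, indexed by the clipped offset between query and key positions, that can be added to the attention scores: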

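class RelativePositionalBias(nn.Module):
    def __init__(self, num_heads, max_distance):
        super().__init__()
        # Clipped offsets span [-max_distance, max_distance]: 2 * max_distance + 1 values
        self.relative_attention_bias = nn.Embedding(2 * max_distance + 1, num_heads)
        self.max_distance = max_distance

    def forward(self, seq_length):
        # Build positions on the same device as the embedding table
        positions = torch.arange(seq_length, device=self.relative_attention_bias.weight.device)
        relative_positions = positions.unsqueeze(1) - positions
        relative_positions = torch.clamp(relative_positions, -self.max_distance, self.max_distance)
        relative_positions = relative_positions + self.max_distance  # shift to [0, 2 * max_distance]
        # Shape [seq_len, seq_len, num_heads]; permute to [num_heads, seq_len, seq_len] before adding to scores
        return self.relative_attention_bias(relative_positions)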
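
The index arithmetic in the module can be checked without PyTorch. The sketch below reproduces the clamp-and-shift step for a short sequence; relative_position_indices is a hypothetical helper written for this illustration. It shows that clipped offsets in [-max_distance, max_distance] map to 2 * max_distance + 1 distinct table rows.

```python
def relative_position_indices(seq_length, max_distance):
    """Map each (query i, key j) offset to a row index in the bias table."""
    indices = []
    for i in range(seq_length):
        row = []
        for j in range(seq_length):
            # Clip the offset i - j, then shift it to be non-negative
            offset = max(-max_distance, min(max_distance, i - j))
            row.append(offset + max_distance)
        indices.append(row)
    return indices

indices = relative_position_indices(seq_length=5, max_distance=2)
table_rows_needed = max(max(row) for row in indices) + 1
```

All offsets beyond max_distance share one row, so the table size is independent of the sequence length.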

Flash Attention (Efficient Implementation)

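try:
    from flash_attn import flash_attn_func

    class FlashAttention(nn.Module):
        def forward(self, Q, K, V):
            # flash_attn_func expects fp16/bf16 CUDA tensors shaped
            # [batch, seq_len, num_heads, head_dim] (heads are not moved to dim 1)
            return flash_attn_func(Q, K, V, dropout_p=0.0)
except ImportError:
    print("Flash Attention not available, use standard implementation")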

Layer Normalization Variants

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
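
RMSNorm rescales each vector by its root mean square without subtracting the mean, which is the main difference from standard layer normalization. The same computation can be sketched on a plain list; the input values below are arbitrary.

```python
import math

def rms_norm(x, weight=None, eps=1e-6):
    """Scale x to unit root-mean-square, then apply a per-element gain."""
    if weight is None:
        weight = [1.0] * len(x)
    mean_square = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_square + eps)
    return [v * scale * w for v, w in zip(x, weight)]

x = [3.0, -4.0, 0.0, 5.0]
y = rms_norm(x)
rms_after = math.sqrt(sum(v * v for v in y) / len(y))
```

The output has a root mean square of approximately 1, but its mean is not forced to zero.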

Production Considerations

Efficient Inference

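@torch.no_grad()
def generate(encoder_output, decoder, start_token, max_length, device):
    """Greedy autoregressive generation.

    No key/value cache is used: the full prefix is re-decoded at every step.
    """
    batch_size = encoder_output.shape[0]
    tgt = torch.full((batch_size, 1), start_token, dtype=torch.long, device=device)

    for _ in range(max_length - 1):
        # Apply the same causal mask used in training so prefix positions cannot see later tokens
        length = tgt.size(1)
        tgt_mask = torch.tril(torch.ones(length, length, device=device)).unsqueeze(0).unsqueeze(0)
        decoder_output = decoder(tgt, encoder_output, tgt_mask)
        next_token = decoder_output[:, -1, :].argmax(dim=-1, keepdim=True)
        tgt = torch.cat([tgt, next_token], dim=1)

    return tgt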
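
The loop above passes the entire prefix through the decoder at every step. A rough count of how many token positions are processed, with and without a key/value cache, shows why caching matters for long outputs. This is a back-of-the-envelope sketch that ignores the per-position attention cost; the function name is illustrative.

```python
def decoder_positions_processed(max_length, use_cache):
    """Count token positions passed through the decoder during generation."""
    total = 0
    for current_length in range(1, max_length):
        # Without a cache the whole prefix is re-decoded; with one, only the newest token
        total += 1 if use_cache else current_length
    return total

without_cache = decoder_positions_processed(128, use_cache=False)
with_cache = decoder_positions_processed(128, use_cache=True)
```

The uncached count grows quadratically with the output length, while the cached count grows linearly.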

Attention Visualization

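def visualize_attention(attention_weights, tokens):
    """Plot head 0 of batch item 0; attention_weights is [batch, heads, seq_len, seq_len]."""
    import matplotlib.pyplot as plt
    import seaborn as sns

    plt.figure(figsize=(8, 6))
    sns.heatmap(attention_weights[0, 0].detach().cpu().numpy(),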
                xticklabels=tokens,
                yticklabels=tokens,
                cmap='viridis')
    plt.title('Attention Weights')
    plt.show()

Key Takeaway

The Transformer architecture’s self-attention mechanism enables parallel processing of entire sequences while capturing long-range dependencies. Understanding scaled dot-product attention, multi-head projection, positional encoding, and the encoder-decoder structure is fundamental to building and customizing modern neural architectures.

Practical Exercise

Task: Implement a sequence-to-sequence Transformer for machine translation (English to French).

Requirements:

  1. Build encoder-decoder Transformer from scratch
  2. Implement causal masking for decoder
  3. Train on subset of WMT14 dataset
  4. Implement beam search for decoding
  5. Evaluate with BLEU score

Evaluation:

  • Achieve BLEU score > 15 on test set
  • Visualize attention patterns to verify it learns meaningful alignments
  • Compare greedy decoding vs. beam search
  • Profile inference time and memory usage
  • Analyze failure cases and attention visualizations