GPT and Decoder Models
GPT introduced autoregressive language modeling at scale, proving that transformer decoders could learn powerful representations through next-token prediction alone. This lesson covers GPT’s causal attention mechanism, autoregressive generation, temperature/top-p sampling strategies, and the differences from BERT’s bidirectional approach.
Core Concepts
Autoregressive Language Modeling
Unlike BERT’s bidirectional masked prediction, GPT uses autoregressive (left-to-right) prediction:
- Model learns to predict next token given previous tokens
- Training objective: minimize the negative log-likelihood of token t+1 given tokens 1..t
- Enables generation by sampling iteratively
Loss function:
L = -log P(x_1) - log P(x_2 | x_1) - log P(x_3 | x_1, x_2) - ...
Each position can only attend to previous positions (causal masking).
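The loss above is ordinary cross-entropy with the labels shifted by one position: position t's logits are scored against token t+1. A minimal sketch with random tensors (toy sizes, no real model):

```python
import torch
import torch.nn.functional as F

# Toy shapes: batch 1, sequence length 5, vocabulary size 10 (illustrative)
logits = torch.randn(1, 5, 10)          # one next-token distribution per position
tokens = torch.randint(0, 10, (1, 5))   # input token ids

# Shift by one: drop the last logit and the first token before scoring
shift_logits = logits[:, :-1, :]        # predictions for positions 1..4
shift_labels = tokens[:, 1:]            # targets: tokens at positions 2..5

loss = F.cross_entropy(
    shift_logits.reshape(-1, 10),       # flatten to [batch*seq, vocab]
    shift_labels.reshape(-1)            # flatten to [batch*seq]
)
print(loss)  # mean of -log P(x_{t+1} | x_1..x_t) over the sequence
```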
Causal Attention Masking
Decoder must never attend to future tokens. Implemented by setting attention scores for future positions to -∞ before the softmax:
if position_j > position_i:
    attention_score[i, j] = -∞
This ensures the attention weight matrix is lower triangular: position i attends only to positions j ≤ i.
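The masking rule can be sketched directly with torch.triu on toy scores (no real model involved):

```python
import torch

seq_len = 5
scores = torch.randn(seq_len, seq_len)  # raw attention scores for one head

# Strictly-upper-triangular entries (j > i) are the "future" positions
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal_mask, float('-inf'))

# After softmax, each row sums to 1 and future positions get zero weight
weights = torch.softmax(scores, dim=-1)
print(weights)
```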
Generation Strategies
Greedy Decoding:
x_t = argmax(P(x_t | x_1..x_{t-1}))
Fast but often produces repetitive, low-quality text.
Beam Search:
Keep top-k sequences at each step, expand by one token.
Use length normalization: score = log P(sequence) / length
Better quality but slower than greedy.
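A toy comparison shows why the normalization matters; the per-token log-probabilities below are made up for illustration:

```python
# Two hypothetical finished beams with made-up per-token log-probabilities
beams = {
    "the cat sat":        [-0.5, -0.8, -1.0],
    "the cat sat down .": [-0.5, -0.8, -1.0, -0.3, -0.2],
}

for text, logps in beams.items():
    raw = sum(logps)               # raw log-prob always penalizes longer sequences
    normalized = raw / len(logps)  # average log-prob per token
    print(f"{text!r}: raw={raw:.2f}, normalized={normalized:.2f}")
# The raw score prefers the short beam; length normalization flips the ranking.
```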
Top-p (Nucleus) Sampling:
1. Sort tokens by probability
2. Select minimum set of top tokens with cumulative probability ≥ p
3. Renormalize and sample from this distribution
Balances quality and diversity.
Temperature Scaling:
P'(x_t) = softmax(logits / temperature)
- temperature < 1: sharper distribution, more confident
- temperature = 1: original distribution
- temperature > 1: softer distribution, more diversity
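The effect of temperature is easy to see on a fixed set of made-up logits:

```python
import torch

logits = torch.tensor([2.0, 1.0, 0.5, 0.1])  # made-up next-token logits

for t in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits / t, dim=-1)
    print(f"T={t}: top-token prob = {probs[0].item():.3f}")
# Lower T concentrates mass on the top token; higher T flattens the distribution.
```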
Practical Implementation
GPT-Style Decoder Model
import torch
import torch.nn as nn
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Using pre-trained GPT-2
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
# Forward pass
input_ids = tokenizer.encode("Once upon a time", return_tensors='pt')
outputs = model(input_ids)
logits = outputs.logits # [batch_size, seq_len, vocab_size]
# Next token prediction
next_token_logits = logits[:, -1, :]
next_token = torch.argmax(next_token_logits, dim=-1)
next_token_text = tokenizer.decode([next_token.item()])
Greedy Decoding
def greedy_generate(model, tokenizer, prompt, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    model.eval()
    with torch.no_grad():
        for _ in range(max_length - len(input_ids[0])):
            outputs = model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Stop if EOS token generated
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0])
generated = greedy_generate(model, tokenizer, "Once upon a time")
print(generated)
Beam Search
def beam_search_generate(model, tokenizer, prompt, num_beams=5, max_length=50):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # model.generate builds its own beam scorer internally when num_beams > 1
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_length=max_length,
            num_beams=num_beams,
            length_penalty=1.0,
            early_stopping=True,
            do_sample=False
        )
    return tokenizer.decode(outputs[0])
generated = beam_search_generate(model, tokenizer, "Once upon a time", num_beams=5)
Top-p and Top-k Sampling
def top_k_top_p_sampling(logits, top_k=10, top_p=0.9, temperature=1.0):
    """Filter 1-D logits with top-k and top-p (nucleus) filtering, then sample"""
    # Apply temperature
    logits = logits / temperature
    # Top-k filtering: drop everything below the k-th largest logit
    if top_k > 0:
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = float('-inf')
    # Top-p filtering: drop tokens outside the smallest set with cumulative prob >= p
    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = float('-inf')
    # Sample from the filtered distribution
    probabilities = torch.softmax(logits, dim=-1)
    next_token = torch.multinomial(probabilities, num_samples=1)
    return next_token
def sample_generate(model, tokenizer, prompt, max_length=50, top_k=50, top_p=0.95, temperature=1.0):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    model.eval()
    with torch.no_grad():
        for _ in range(max_length - len(input_ids[0])):
            outputs = model(input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            # Apply top-k/top-p sampling to the last position's logits
            next_token = top_k_top_p_sampling(
                next_token_logits[0],
                top_k=top_k,
                top_p=top_p,
                temperature=temperature
            )
            input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0])
# Lower temperature: focused, near-greedy output
greedy_style = sample_generate(model, tokenizer, "Once upon", temperature=0.7)
# Higher temperature: more diverse, creative output
creative_style = sample_generate(model, tokenizer, "Once upon", temperature=1.5, top_p=0.9)
Fine-tuning GPT-2 on Custom Data
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
# Load custom text data
dataset = load_dataset('text', data_files='custom_corpus.txt')
# GPT-2 has no pad token by default; reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token
# Tokenize
def tokenize_function(examples):
    return tokenizer(examples['text'], truncation=True, max_length=512)
tokenized_dataset = dataset.map(tokenize_function, batched=True)
# Data collator for causal LM (mlm=False disables masked-LM labels)
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# Training arguments
training_args = TrainingArguments(
    output_dir='./gpt2-finetuned',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    save_steps=10_000,
    save_total_limit=2,
    logging_steps=500,
)
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    data_collator=data_collator,
)
trainer.train()
# Save
model.save_pretrained('./gpt2-finetuned')
tokenizer.save_pretrained('./gpt2-finetuned')
Advanced Techniques
Conditional Text Generation
def conditional_generate(model, tokenizer, prompt, condition_keyword, max_length=50, boost=5.0):
    """Bias generation toward a keyword by boosting its logit"""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    # Note: only the keyword's first subword id is boosted; a multi-token
    # keyword would need all of its ids handled
    condition_id = tokenizer.encode(f" {condition_keyword}", add_special_tokens=False)[0]
    generated = input_ids.clone()
    with torch.no_grad():
        while generated.shape[1] < max_length:
            outputs = model(generated)
            logits = outputs.logits[:, -1, :]
            # Add a bonus to the condition token rather than forbidding all others,
            # so generation stays fluent while favoring the keyword
            logits[0, condition_id] += boost
            next_token = torch.argmax(logits, dim=-1).unsqueeze(-1)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(generated[0])
Prompt Engineering with Few-Shot
def few_shot_generate(model, tokenizer, examples, new_input, max_length=100):
    """Few-shot generation with examples in the prompt"""
    prompt = "Examples:\n"
    for input_text, output_text in examples:
        prompt += f"Input: {input_text}\nOutput: {output_text}\n\n"
    prompt += f"Input: {new_input}\nOutput:"
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    outputs = model.generate(input_ids, max_length=max_length, num_beams=5)
    return tokenizer.decode(outputs[0])
# Usage
examples = [
("Translate to French: hello", "bonjour"),
("Translate to French: goodbye", "au revoir"),
]
result = few_shot_generate(model, tokenizer, examples, "Translate to French: thank you")
KV-Cache for Efficient Generation
def generate_with_cache(model, tokenizer, prompt, max_length=50):
    """Generate using the KV-cache to avoid recomputing attention over the prefix"""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    past_key_values = None
    next_input = input_ids  # first step must process the full prompt
    with torch.no_grad():
        for _ in range(max_length - len(input_ids[0])):
            outputs = model(
                next_input,
                past_key_values=past_key_values,
                use_cache=True
            )
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            past_key_values = outputs.past_key_values
            next_input = next_token  # subsequent steps feed only the new token
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0])
Production Considerations
Efficient Generation with vLLM
# Using vLLM for batch generation
from vllm import LLM, SamplingParams
llm = LLM(model="gpt2")
sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=100
)
prompts = [
    "Once upon a time",
    "In a galaxy far away",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
Streaming Generation
def stream_generate(model, tokenizer, prompt, max_length=50):
    """Stream tokens as they're generated"""
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    model.eval()
    with torch.no_grad():
        for _ in range(max_length - len(input_ids[0])):
            outputs = model(input_ids)
            next_token = torch.argmax(outputs.logits[:, -1, :], dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            # Yield the decoded token for streaming
            yield tokenizer.decode([next_token.item()])
            if next_token.item() == tokenizer.eos_token_id:
                break
# Usage in web app
for token in stream_generate(model, tokenizer, "Once upon"):
    print(token, end='', flush=True)
Key Takeaway
GPT’s autoregressive generation paradigm enabled scaling language models to billions of parameters with next-token prediction as the sole objective. Understanding causal masking, sampling strategies, and generation techniques is essential for building effective generative models and applications.
Practical Exercise
Task: Build a chatbot using fine-tuned GPT-2 with conversation history context.
Requirements:
- Fine-tune GPT-2 on dialogue dataset (ConvAI2 or similar)
- Implement conversation history management
- Use top-p sampling for natural responses
- Implement multi-turn context (last 3-5 turns)
- Add response quality filters (length, profanity)
Evaluation:
- Evaluate on dialogue quality (BLEU, METEOR)
- Test coherence across conversation turns
- Measure response diversity
- A/B test sampling parameters (temperature, top_p)
- Analyze failure modes (repetition, off-topic)