Text Generation and Summarization
Text summarization and generation require models that understand context and produce coherent, fluent text. This lesson covers encoder-decoder architectures (T5, BART), abstractive and extractive summarization, controlled generation, and evaluation metrics like ROUGE and BLEU.
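ROUGE-1, for instance, measures unigram overlap between a candidate summary and a reference. A minimal pure-Python sketch (omitting the stemming that the rouge_score library adds) makes the precision/recall structure of the metric concrete:

```python
from collections import Counter

def rouge1_f1(reference, candidate):
    """Unigram-overlap ROUGE-1 F1 between two texts (no stemming)."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    # Clipped overlap: each reference word can be matched at most
    # as many times as it occurs in the reference
    overlap = sum(min(ref_counts[w], cand_counts[w]) for w in cand_counts)
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand_counts.values())  # overlap / candidate length
    recall = overlap / sum(ref_counts.values())      # overlap / reference length
    return 2 * precision * recall / (precision + recall)
```

A terse candidate can score perfect precision but low recall, which is why ROUGE reports the F1 of the two.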
Core Concepts
Summarization Approaches
Extractive Summarization: Select key sentences from source text
- Advantages: Preserves factual accuracy, computationally efficient
- Disadvantages: Lacks coherence, no paraphrasing
- Methods: TF-IDF, TextRank, neural ranking
Abstractive Summarization: Generate summary with new words
- Advantages: More fluent, natural, concise
- Disadvantages: Risk of hallucination, requires more training data
- Methods: Seq2Seq, BART, T5
Hybrid Approaches: Extract then abstract, or hierarchical attention
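The extract-then-abstract idea composes the two approaches: an extractive pass selects salient sentences (which also shrinks long inputs to fit a model's context window), and an abstractive pass rewrites them fluently. A minimal sketch, where `extract_fn` and `abstract_fn` are hypothetical placeholders for any extractive summarizer (e.g. TextRank) and abstractive model (e.g. BART):

```python
def hybrid_summarize(document, extract_fn, abstract_fn, num_sentences=5):
    """Extract-then-abstract: select salient sentences, then rewrite them.

    extract_fn(document, k) -> list of k salient sentences
    abstract_fn(text) -> abstractive summary string
    """
    # Stage 1: extractive pass picks the most salient content
    salient = extract_fn(document, num_sentences)
    # Stage 2: abstractive pass paraphrases the extracted sentences
    return abstract_fn(' '.join(salient))
```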
Encoder-Decoder Architecture for Generation
Input: "Document text..."
Encoder: Processes document, creates context representation
Decoder: Generates summary token-by-token, attends to encoder
Output: "Summary of key points..."
Key components:
- Cross-attention: Decoder attends to encoder outputs
- Teacher forcing: During training, feed ground-truth previous tokens
- Beam search: During inference, explore multiple hypotheses
T5 and BART Models
T5 (Text-to-Text Transfer Transformer):
- Treats all tasks as text-to-text
- Prefix-based: "summarize: document text"
- Unified architecture for multiple tasks
- Pre-trained on variety of text-to-text objectives
BART (Bidirectional and Auto-Regressive Transformers):
- Combines a BERT-style bidirectional encoder with a GPT-style autoregressive decoder
- Pre-trained by corrupting text (span masking, sentence shuffling) and reconstructing the original
- Particularly effective for summarization
- Can handle both classification and generation tasks
Practical Implementation
Abstractive Summarization with BART
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
from datasets import load_dataset
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from rouge_score import rouge_scorer

# Download NLTK sentence-tokenizer data
nltk.download('punkt')

# Load dataset
dataset = load_dataset('cnn_dailymail', '3.0.0')

# Model and tokenizer
model_name = 'facebook/bart-large-cnn'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Preprocessing
max_input_length = 1024
max_target_length = 128

def preprocess_function(examples):
    # Leave padding to the data collator: padding labels with pad_token_id
    # here would make the padding count toward the loss
    model_inputs = tokenizer(
        examples['article'],
        max_length=max_input_length,
        truncation=True,
    )
    labels = tokenizer(
        text_target=examples['highlights'],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)

# Data collator: pads dynamically and replaces label padding with -100
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# ROUGE metrics; predictions are generated token ids because
# predict_with_generate=True is set below
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Labels use -100 for padding; restore pad tokens before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)
    rouge_scores = {'rouge1': [], 'rougeL': []}
    for pred, label in zip(decoded_preds, decoded_labels):
        scores = scorer.score(label, pred)
        rouge_scores['rouge1'].append(scores['rouge1'].fmeasure)
        rouge_scores['rougeL'].append(scores['rougeL'].fmeasure)
    return {
        'rouge1': np.mean(rouge_scores['rouge1']),
        'rougeL': np.mean(rouge_scores['rougeL']),
    }

# Training: Seq2SeqTrainingArguments/Seq2SeqTrainer are needed so evaluation
# runs generate() instead of passing raw logits to compute_metrics
training_args = Seq2SeqTrainingArguments(
    output_dir='./summarization_model',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_steps=100,
    evaluation_strategy='epoch',
    save_strategy='epoch',
    load_best_model_at_end=True,
    metric_for_best_model='rouge1',
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
    compute_metrics=compute_metrics,
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()
Extractive Summarization with TextRank
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

nltk.download('punkt')

class TextRankSummarizer:
    def __init__(self, num_sentences=3):
        self.num_sentences = num_sentences

    def summarize(self, text):
        # Tokenize into sentences
        sentences = sent_tokenize(text)
        if len(sentences) <= self.num_sentences:
            return ' '.join(sentences)

        # Vectorize sentences and build the similarity graph
        tfidf = TfidfVectorizer().fit_transform(sentences)
        similarity_matrix = cosine_similarity(tfidf)
        np.fill_diagonal(similarity_matrix, 0.0)  # no self-loops

        # Normalize columns so the iteration behaves like a random walk
        col_sums = similarity_matrix.sum(axis=0)
        col_sums[col_sums == 0] = 1.0
        transition = similarity_matrix / col_sums

        # Power iteration with damping (PageRank-style)
        n = len(sentences)
        scores = np.ones(n) / n
        for _ in range(100):
            scores_new = 0.15 / n + 0.85 * transition.dot(scores)
            if np.linalg.norm(scores - scores_new) < 1e-3:
                scores = scores_new
                break
            scores = scores_new

        # Select top-k sentences, restoring document order
        top_indices = np.sort(np.argsort(scores)[-self.num_sentences:])
        return ' '.join(sentences[i] for i in top_indices)

# Usage
summarizer = TextRankSummarizer(num_sentences=3)
text = "Long document text here..."
summary = summarizer.summarize(text)
Controlled Generation with Prefixes
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class ControlledGenerator:
    def __init__(self, model_name='facebook/bart-large'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def generate_with_control(self, text, task='summarize', length='short'):
        """Generate with a task prefix and length control."""
        # Task-specific prefixes. Note that BART, unlike T5, is not
        # pre-trained with prefixes, so these only take effect after
        # fine-tuning on prefixed data
        prefixes = {
            'summarize': 'summarize: ',
            'question': 'question: ',
            'paraphrase': 'paraphrase: ',
            'expand': 'expand: ',
        }
        # Length control via generation constraints
        length_params = {
            'short': {'min_length': 20, 'max_length': 50},
            'medium': {'min_length': 50, 'max_length': 150},
            'long': {'min_length': 150, 'max_length': 300},
        }
        prefix = prefixes.get(task, '')
        length_config = length_params.get(length, length_params['medium'])

        input_text = prefix + text
        inputs = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=1024)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                num_beams=5,
                early_stopping=True,
                **length_config,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Usage
generator = ControlledGenerator()
summary = generator.generate_with_control('Long text...', task='summarize', length='short')
T5 for Multiple Tasks
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

class T5MultiTask:
    def __init__(self, model_name='t5-base'):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.model.eval()

    def summarize(self, text, max_length=128):
        input_text = f"summarize: {text}"
        return self._generate(input_text, max_length)

    def translate(self, text, language, max_length=128):
        # T5 was pre-trained with prefixes like "translate English to German: "
        input_text = f"translate English to {language}: {text}"
        return self._generate(input_text, max_length)

    def paraphrase(self, text, max_length=128):
        # Not part of T5 pre-training; requires fine-tuning on this prefix
        input_text = f"paraphrase: {text}"
        return self._generate(input_text, max_length)

    def question_generation(self, text, answer, max_length=128):
        # Not part of T5 pre-training; requires fine-tuning on this prefix
        input_text = f"question: {text} answer: {answer}"
        return self._generate(input_text, max_length)

    def _generate(self, input_text, max_length):
        inputs = self.tokenizer(input_text, return_tensors='pt', truncation=True, max_length=512)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=max_length,
                num_beams=4,
                early_stopping=True,
            )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

# Usage
t5 = T5MultiTask()
summary = t5.summarize('Long document text...')
translation = t5.translate('Hello world', 'French')
Advanced Techniques
Reinforcement Learning for Summarization
import torch
from torch.optim import Adam
from torch.nn import functional as F
from rouge_score import rouge_scorer

class RLSummarizer:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device
        self.optimizer = Adam(model.parameters(), lr=1e-5)

    def compute_reward(self, reference, candidate):
        """Compute ROUGE-based reward."""
        scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
        scores = scorer.score(reference, candidate)
        return scores['rougeL'].fmeasure

    def train_step(self, text, reference):
        """Train with a REINFORCE-style objective."""
        inputs = self.tokenizer(text, return_tensors='pt', truncation=True)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        # Sample multiple candidate summaries
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_length=128,
                do_sample=True,
                top_p=0.95,
                num_return_sequences=4,
            )

        # Compute and normalize ROUGE rewards
        candidates = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        rewards = [self.compute_reward(reference, cand) for cand in candidates]
        rewards = torch.tensor(rewards, device=self.device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

        # Reinforce the log-probability of the top-k samples
        top_k = 2
        top_indices = torch.argsort(rewards, descending=True)[:top_k]
        loss = torch.tensor(0.0, device=self.device)
        for idx in top_indices:
            input_ids = inputs['input_ids']
            output_seq = outputs[idx]
            # Teacher-force the sampled sequence (batch dim required)
            # to get per-token logits
            logits = self.model(
                input_ids=input_ids,
                decoder_input_ids=output_seq[:-1].unsqueeze(0),
            ).logits
            log_probs = F.log_softmax(logits, dim=-1)
            selected_log_probs = log_probs[0, range(len(output_seq) - 1), output_seq[1:]]
            loss = loss - rewards[idx] * selected_log_probs.sum()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()
Consistency and Factuality Checking
import nltk
from nltk.tokenize import sent_tokenize
from transformers import pipeline

class FactualityChecker:
    def __init__(self):
        # Score (premise, hypothesis) pairs with the NLI model directly;
        # the zero-shot pipeline is not suitable here because it treats
        # the second argument as candidate labels, not as a hypothesis
        self.nli = pipeline(
            'text-classification',
            model='facebook/bart-large-mnli',
            top_k=None,
        )

    def check_factuality(self, document, summary):
        """Return the probability that the summary is entailed by the document."""
        results = self.nli({'text': document, 'text_pair': summary}, truncation=True)
        scores = {r['label'].lower(): r['score'] for r in results}
        return scores.get('entailment', 0.0)

    def detect_hallucinations(self, document, summary, threshold=0.5):
        """Detect claims in the summary not supported by the document."""
        sentences = sent_tokenize(summary)
        hallucinated = []
        for sent in sentences:
            score = self.check_factuality(document, sent)
            if score < threshold:
                hallucinated.append({'text': sent, 'score': score})
        return hallucinated
Production Considerations
Batch Processing for Efficiency
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def batch_summarize(texts, batch_size=16):
    """Summarize multiple documents efficiently."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn').to(device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn')
    summaries = []
    for i in range(0, len(texts), batch_size):
        batch_texts = texts[i:i + batch_size]
        inputs = tokenizer(batch_texts, return_tensors='pt', padding=True,
                           truncation=True, max_length=1024)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            summary_ids = model.generate(
                **inputs,
                max_length=150,
                num_beams=4,
            )
        summaries.extend(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
    return summaries
Latency Optimization
# Use half precision (and, in practice, smaller distilled checkpoints)
# for faster inference
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

def fast_summarize(text):
    # In production, load the model once at startup rather than per call
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = AutoModelForSeq2SeqLM.from_pretrained('facebook/bart-large-cnn').to(device)
    if device.type == 'cuda':
        model.half()  # FP16 is only reliably supported on GPU
    tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large-cnn')
    inputs = tokenizer(text, return_tensors='pt', truncation=True, max_length=1024)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        summary_ids = model.generate(**inputs, max_length=150, num_beams=2)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)
Key Takeaway
Text generation and summarization require models that understand context, maintain coherence, and produce factually consistent output. Master encoder-decoder architectures, evaluation metrics, and techniques for controlling output to build reliable generation systems.
Practical Exercise
Task: Build a multi-document summarization system with fact-checking.
Requirements:
- Fine-tune BART on multi-document dataset
- Implement extractive + abstractive hybrid
- Add factuality checker to filter hallucinations
- Support multiple summary lengths
- Evaluate with ROUGE and human judgment
Evaluation:
- ROUGE-1, ROUGE-L > 0.35
- Human eval: readability and factuality
- Compare with lead-3 baseline
- Test on out-of-domain documents
- Profile inference time and memory
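The lead-3 baseline mentioned above simply takes the first three sentences of each article; it is surprisingly strong on news data because journalists front-load the key information. A dependency-free sketch, using a naive regex sentence splitter in place of a proper tokenizer:

```python
import re

def lead_n_summary(text, n=3):
    """Lead-n baseline: the first n sentences of the document.

    The regex split on whitespace after ./!/? is a rough stand-in for a
    real sentence tokenizer such as nltk's sent_tokenize.
    """
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    return ' '.join(sentences[:n])
```

Any learned system should beat this baseline on ROUGE before it is worth deploying.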