BERT and Encoder Models
BERT revolutionized NLP by introducing bidirectional pre-training on large unlabeled corpora. This lesson covers BERT’s masked language modeling objective, fine-tuning strategies for downstream tasks, and the sentence-transformer extensions for semantic similarity.
Core Concepts
BERT: Bidirectional Encoder Representations from Transformers
BERT is a pre-trained transformer encoder designed to build representations that use both the left and right context of each token.
Key Innovation: Unlike traditional language models that predict left-to-right, BERT uses masked language modeling:
- Randomly mask 15% of input tokens
- Train model to predict masked tokens using surrounding context
- Forces bidirectional understanding
Architecture:
- 12 transformer layers (BERT-base) or 24 (BERT-large)
- Hidden size: 768 (base) or 1024 (large)
- Attention heads: 12 (base) or 16 (large)
- Intermediate feed-forward dimension: 3072 (base) or 4096 (large)
- Pre-trained on English Wikipedia + BookCorpus (~3.3B words)
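These hyperparameters can be read straight from the published checkpoint configuration; a quick check with Hugging Face's AutoConfig confirms the base-model numbers:
from transformers import AutoConfig
# Inspect the architecture hyperparameters of the base checkpoint
config = AutoConfig.from_pretrained('bert-base-uncased')
print(config.num_hidden_layers)    # 12
print(config.hidden_size)          # 768
print(config.num_attention_heads)  # 12
print(config.intermediate_size)    # 3072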
Masked Language Modeling (MLM) Objective
Input: "The [MASK] sat on the [MASK]"
Target: Predict "cat" and "mat"
Training procedure:
- Select 15% of tokens at random
- For each selected token:
  - 80% of the time: replace it with [MASK]
  - 10% of the time: replace it with a random token
  - 10% of the time: keep the original token (so the model cannot rely on [MASK] marking every prediction target; [MASK] never appears at fine-tuning time)
Why this works: the model is forced to learn contextual representations of every token rather than memorize surface patterns
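As a rough illustration of the 80/10/10 rule, here is a minimal masking sketch (it ignores special-token handling for brevity; Hugging Face's DataCollatorForLanguageModeling, used later in this lesson, implements the full logic):
import torch

def mask_tokens(input_ids, tokenizer, mlm_probability=0.15):
    labels = input_ids.clone()
    # Choose 15% of positions as prediction targets
    masked = torch.rand(input_ids.shape) < mlm_probability
    labels[~masked] = -100  # only masked positions contribute to the loss
    # 80% of targets -> [MASK]
    replace_mask = masked & (torch.rand(input_ids.shape) < 0.8)
    input_ids[replace_mask] = tokenizer.mask_token_id
    # 10% of targets -> a random vocabulary token (half of the remaining 20%)
    random_mask = masked & ~replace_mask & (torch.rand(input_ids.shape) < 0.5)
    input_ids[random_mask] = torch.randint(len(tokenizer), input_ids.shape)[random_mask]
    # The remaining 10% keep their original token
    return input_ids, labels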
Next Sentence Prediction (NSP)
BERT also uses a secondary objective during pre-training:
Input: [CLS] Sentence A [SEP] Sentence B [SEP]
Task: Predict if B follows A in corpus (binary classification)
Useful for tasks requiring sentence relationships (entailment, paraphrase detection)
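The sentence-pair layout above is exactly what the tokenizer produces when given two text segments; a short sketch using the bert-base-uncased tokenizer (the same one loaded in the implementation section below):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# Passing two segments produces the [CLS] A [SEP] B [SEP] layout used for NSP
encoded = tokenizer("The man went to the store.", "He bought a gallon of milk.")
print(tokenizer.decode(encoded['input_ids']))
# token_type_ids is 0 for segment A tokens and 1 for segment B tokens
print(encoded['token_type_ids'])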
Fine-tuning for Downstream Tasks
Text Classification:
[CLS] Text input [SEP] → [CLS] token → Linear → softmax → classes
Named Entity Recognition:
[CLS] Token1 Token2 ... [SEP] → Each token → Linear → BIO tags
Question Answering:
[CLS] Question [SEP] Context [SEP] → Span selection → Answer
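For extractive QA, the span-selection head predicts a start and an end position over the context. A minimal sketch with the Hugging Face pipeline and a publicly available SQuAD-fine-tuned BERT checkpoint:
from transformers import pipeline

# Extractive question answering with a BERT checkpoint fine-tuned on SQuAD
qa = pipeline('question-answering', model='bert-large-uncased-whole-word-masking-finetuned-squad')
result = qa(question="What does the fox jump over?",
            context="The quick brown fox jumps over the lazy dog.")
print(result['answer'], result['score'])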
Practical Implementation
Using Pre-trained BERT with Hugging Face
from transformers import AutoTokenizer, AutoModel, BertForSequenceClassification
import torch
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained('bert-base-uncased')
# Tokenize input
text = "The quick brown fox jumps over the lazy dog"
inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True)
# inputs = {'input_ids': [...], 'attention_mask': [...], 'token_type_ids': [...]}
# Get embeddings
with torch.no_grad():
    outputs = model(**inputs)
# outputs.last_hidden_state: [batch_size, seq_len, 768]
# outputs.pooler_output: [batch_size, 768] (the [CLS] token passed through a tanh pooling layer)
# Use the pooled [CLS] representation for classification
cls_embedding = outputs.pooler_output  # [batch_size, 768]
Text Classification with BERT
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
# Load dataset
dataset = load_dataset('glue', 'sst2')
# Tokenize
def tokenize_function(examples):
    return tokenizer(examples['sentence'], padding='max_length', truncation=True, max_length=128)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Create model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
# Define metrics
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1': f1_score(labels, predictions, average='weighted'),
    }
# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    compute_metrics=compute_metrics,
)
# Train
trainer.train()
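After training, the same Trainer can report validation metrics, and the fine-tuned model can classify new text directly:
# Evaluate on the validation split
metrics = trainer.evaluate()
print(metrics)  # eval_loss plus the accuracy/f1 defined in compute_metrics

# Classify a new sentence with the fine-tuned model
inputs = tokenizer("A thoroughly enjoyable film.", return_tensors='pt').to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.argmax(dim=-1).item())  # SST-2: 0 = negative, 1 = positive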
Named Entity Recognition with BERT
from transformers import BertForTokenClassification  # off-the-shelf token-classification head; a custom head is built below for illustration
# Dataset: tokens with BIO tags
# BIO scheme: B-PER (Beginning of entity), I-PER (Inside entity), O (Outside any entity)
class BERTForNER(torch.nn.Module):
    def __init__(self, num_labels):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(768, num_labels)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # [batch_size, seq_len, 768]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        return logits

model = BERTForNER(num_labels=9)  # B-PER, I-PER, B-ORG, I-ORG, B-LOC, I-LOC, B-MISC, I-MISC, O
# Training loop (assumes a DataLoader `train_loader` yielding input_ids, attention_mask, and labels)
def train_ner():
    optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
    # ignore_index=-100 assumes padding/special-token positions are labeled -100 so they are skipped in the loss
    loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)
    model.train()
    for epoch in range(3):
        total_loss = 0
        for batch in train_loader:
            input_ids = batch['input_ids']
            attention_mask = batch['attention_mask']
            labels = batch['labels']
            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            # Flatten for loss computation
            loss = loss_fn(
                logits.view(-1, 9),
                labels.view(-1)
            )
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f'Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}')
# Inference
@torch.no_grad()
def predict_ner(text):
    inputs = tokenizer(text, return_tensors='pt')
    # The custom forward only accepts input_ids and attention_mask,
    # so pass them explicitly rather than unpacking every tokenizer output
    logits = model(inputs['input_ids'], inputs['attention_mask'])
    predictions = logits.argmax(dim=-1)
    return predictions[0]
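The raw class indices become readable once they are mapped back to tag names and paired with the word pieces; a minimal sketch, assuming an id2label mapping for the nine tags (the index order shown is hypothetical and must match your label encoding):
# Hypothetical index-to-tag mapping; use whatever order your label encoding defines
id2label = {0: 'O', 1: 'B-PER', 2: 'I-PER', 3: 'B-ORG', 4: 'I-ORG',
            5: 'B-LOC', 6: 'I-LOC', 7: 'B-MISC', 8: 'I-MISC'}

text = "Angela Merkel visited Paris."
inputs = tokenizer(text, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
pred_ids = predict_ner(text).tolist()
# Pair each word piece (including [CLS]/[SEP]) with its predicted tag
for token, pred in zip(tokens, pred_ids):
    print(token, id2label[pred])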
Sentence-BERT (SBERT) for Semantic Similarity
from sentence_transformers import SentenceTransformer, util
# Pre-trained sentence transformer (fine-tuned for semantic similarity)
model = SentenceTransformer('all-MiniLM-L6-v2') # 384-dim embeddings
# Encode sentences
sentences = [
    "A man is eating food.",
    "A man eats bread.",
    "The cat is under the table.",
]
embeddings = model.encode(sentences)
# Compute similarity
similarity = util.pytorch_cos_sim(embeddings, embeddings)
print(similarity)
# Example output (illustrative values; the first two sentences score high, the third low):
# tensor([[1.0000, 0.9450, 0.1234],
#         [0.9450, 1.0000, 0.1567],
#         [0.1234, 0.1567, 1.0000]])
# Semantic search
query = "A man eating food"
query_embedding = model.encode(query)
# Find most similar
hits = util.semantic_search(query_embedding, embeddings, top_k=2)
for hit in hits[0]:
    print(f"Score: {hit['score']:.4f}, Text: {sentences[hit['corpus_id']]}")
Advanced Techniques
Multi-task Fine-tuning
class MultitaskBERT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained('bert-base-uncased')
        # Task-specific heads
        self.classification_head = torch.nn.Linear(768, 2)  # Binary classification
        self.ner_head = torch.nn.Linear(768, 9)             # NER with 9 tags
        self.regression_head = torch.nn.Linear(768, 1)      # Regression

    def forward(self, input_ids, attention_mask=None, task=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        if task == 'classification':
            return self.classification_head(outputs.pooler_output)
        elif task == 'ner':
            return self.ner_head(outputs.last_hidden_state)
        elif task == 'regression':
            return self.regression_head(outputs.pooler_output)
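Usage then routes each batch to the appropriate head by passing the task name, for example:
# Route an input through the classification head
model = MultitaskBERT()
inputs = tokenizer("An example sentence.", return_tensors='pt')
logits = model(inputs['input_ids'], inputs['attention_mask'], task='classification')
print(logits.shape)  # torch.Size([1, 2])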
Adapter Modules for Parameter Efficiency
from adapters import AutoAdapterModel
# Load model with adapter support
model = AutoAdapterModel.from_pretrained('bert-base-uncased')
# Add task-specific adapters
model.add_adapter('classification_adapter', config='pfeiffer')
model.add_adapter('ner_adapter', config='pfeiffer')
# Fine-tune only the adapter weights; BERT's own parameters stay frozen
model.train_adapter('classification_adapter')
# Train as usual: only a small fraction of the parameters is updated
Domain-Specific Fine-tuning
# Continue pre-training on a domain corpus with the MLM objective
from transformers import AutoModelForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
# Dataset of domain-specific text (one passage per line)
domain_dataset = load_dataset('text', data_files='domain_corpus.txt')
# Tokenize the raw text
def tokenize_corpus(examples):
    return tokenizer(examples['text'], truncation=True, max_length=128)
tokenized_domain = domain_dataset.map(tokenize_corpus, batched=True, remove_columns=['text'])
# MLM data collator applies the 80/10/10 masking dynamically at training time
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=0.15)
# Fine-tune on the domain (the model needs an MLM head, not the bare encoder)
model = AutoModelForMaskedLM.from_pretrained('bert-base-uncased')
trainer = Trainer(
    model=model,
    args=TrainingArguments('domain-bert', num_train_epochs=1),
    train_dataset=tokenized_domain['train'],
    data_collator=data_collator,
)
trainer.train()
Production Considerations
Inference Optimization
# Dynamic quantization for faster CPU inference
from transformers import BertForSequenceClassification
import torch.quantization
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
model.eval()
model_q = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # quantize all Linear layers to int8
    dtype=torch.qint8
)
# Inference with the quantized model (same tokenized inputs as before)
outputs = model_q(**inputs)
Batch Processing
def batch_encode(texts, batch_size=32):
    all_embeddings = []
    model.eval()
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            batch_texts = texts[i:i+batch_size]
            inputs = tokenizer(batch_texts, padding=True, truncation=True, return_tensors='pt')
            outputs = model(**inputs)
            all_embeddings.append(outputs.pooler_output.cpu())
    return torch.cat(all_embeddings, dim=0)
API Deployment
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
model = AutoModel.from_pretrained('bert-base-uncased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
class TextInput(BaseModel):
    text: str

@app.post('/embed')
async def embed(input: TextInput):
    inputs = tokenizer(input.text, return_tensors='pt', truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)
    embedding = outputs.pooler_output[0].tolist()  # flatten the [1, 768] tensor to a plain list
    return {'embedding': embedding}
# Run: uvicorn app:app --reload
Key Takeaway
BERT’s masked language modeling and bidirectional context revolutionized transfer learning in NLP. Fine-tuning pre-trained BERT on task-specific data achieves strong results with relatively little labeled data, making it an essential tool for text classification, NER, semantic similarity, and countless other applications.
Practical Exercise
Task: Build a semantic search engine over a document corpus using BERT and sentence-transformers.
Requirements:
- Load corpus (Wikipedia snippets or custom documents)
- Embed all documents using sentence-transformer
- Build search system with cosine similarity ranking
- Fine-tune sentence-transformer on domain data
- Evaluate with mean average precision
Evaluation:
- Measure search quality with MAP metric
- Compare pre-trained vs. fine-tuned embeddings
- Test with various query types
- Optimize for latency (index with Faiss or Milvus)
- Benchmark memory usage and inference speed