Batch Processing and Scaling Prompts
You’ve built pipelines for single inputs. Now learn how to process thousands of items through LLM pipelines while managing costs, rate limits, and infrastructure constraints. This lesson covers batch processing patterns that real production systems use.
The Challenge of Scaling
A single prompt call takes seconds. But process 10,000 customer reviews, and naive approaches fail:
- Sequential processing: Would take hours
- Naive parallelization: Hit rate limits immediately
- No caching: Reprocess identical inputs
- No monitoring: Can’t detect quality drift or failures
This lesson teaches you to scale intelligently.
Batch Processing Fundamentals
Break large datasets into manageable batches:
from typing import Iterator, List, Optional
def batch_iterator(items: List, batch_size: int = 100) -> Iterator[List]:
    """Yield consecutive slices of *items*, each at most *batch_size* long.

    The final batch may be shorter when ``len(items)`` is not an exact
    multiple of ``batch_size``. A ``batch_size`` of 0 raises ``ValueError``
    (zero step), matching ``range()`` semantics.
    """
    total = len(items)
    for offset in range(0, total, batch_size):
        yield items[offset:offset + batch_size]
# Process 10,000 items in batches of 100
items = list(range(10000))
for batch in batch_iterator(items, batch_size=100):
    # Each batch is a list slice of up to 100 items; the last one may be shorter.
    print(f"Processing batch of {len(batch)} items")
    # Call LLM for each item in batch
Async and Parallel Execution
Process multiple items concurrently without hitting rate limits:
import asyncio
import openai
from datetime import datetime, timedelta
class AsyncPromptBatcher:
    """
    Process prompts asynchronously while honoring a requests-per-minute cap.

    A sliding one-minute window of request timestamps is kept in
    ``request_times``. A slot is reserved *before* each call is issued, so
    that many tasks launched concurrently via ``asyncio.gather`` cannot all
    pass the check at the same instant and overshoot the limit (which is
    what happens when the timestamp is only recorded after the response
    arrives).
    """

    def __init__(self, requests_per_minute: int = 60):
        self.rpm_limit = requests_per_minute
        self.request_times = []  # Timestamps of requests in the last minute

    async def _wait_for_rate_limit(self):
        """Block until a request slot is free, then reserve it.

        Runs in a loop because several tasks may wake from the same sleep
        and only some of them may claim the freed capacity; each waker must
        re-check the window before reserving.
        """
        while True:
            now = datetime.now()
            # Keep only requests from the last minute.
            self.request_times = [
                t for t in self.request_times
                if now - t < timedelta(minutes=1)
            ]
            if len(self.request_times) < self.rpm_limit:
                # Reserve the slot immediately so concurrent tasks see it.
                self.request_times.append(now)
                return
            # Wait until the oldest request leaves the 1-minute window.
            sleep_time = (self.request_times[0] + timedelta(minutes=1) - now).total_seconds()
            await asyncio.sleep(max(sleep_time, 0.01))

    async def call_llm(self, prompt: str, model: str = "gpt-4") -> Optional[str]:
        """Make a single async LLM call with rate limiting.

        Returns the model's reply text, or None when the call fails
        (the error is logged, not raised, so one bad item does not
        abort a whole batch).
        """
        await self._wait_for_rate_limit()
        try:
            # This would use actual async OpenAI client in production
            return await asyncio.to_thread(self._sync_llm_call, prompt, model)
        except Exception as e:
            print(f"Error calling LLM: {e}")
            return None

    def _sync_llm_call(self, prompt: str, model: str) -> str:
        """Synchronous wrapper for OpenAI call."""
        response = openai.ChatCompletion.create(
            model=model,
            messages=[{"role": "user", "content": prompt}]
        )
        return response.choices[0].message.content

    async def process_batch(self, inputs: list, model: str = "gpt-4") -> list:
        """Process a batch of inputs concurrently; output order matches input order."""
        tasks = [
            self.call_llm(input_text, model)
            for input_text in inputs
        ]
        return await asyncio.gather(*tasks)
# Usage
async def process_batch_async():
    """Demo driver: push three sample reviews through the rate-limited batcher."""
    sample_reviews = [
        "This product is great!",
        "Not worth the money",
        "It's okay, nothing special"
    ]
    batcher = AsyncPromptBatcher(requests_per_minute=60)
    return await batcher.process_batch(sample_reviews)
# Run async code
# results = asyncio.run(process_batch_async())
Advanced Rate Limiting and Throttling
Different APIs have different rate limits. Use sophisticated throttling:
from enum import Enum
from dataclasses import dataclass
import time
from collections import deque
class TokenLimitType(Enum):
    """The kinds of API quota a rate limiter may enforce."""
    RPM = "requests_per_minute"
    TPM = "tokens_per_minute"
@dataclass
class RateLimit:
    """Quota configuration for one API endpoint."""
    requests_per_minute: int = 60
    tokens_per_minute: int = 90000
    max_concurrent_requests: int = 10


class SmartRateLimiter:
    """
    Rate limiter that enforces requests/minute, tokens/minute, and a
    concurrency ceiling at the same time.

    Intended usage: call ``wait_for_capacity()`` before a request, wrap the
    call in ``with limiter:`` to count it against the concurrency ceiling,
    then call ``record_request()`` with the actual token usage.
    """

    def __init__(self, limit: RateLimit):
        self.limit = limit
        self.request_queue = deque()  # (timestamp, 1) per request in window
        self.token_queue = deque()    # (timestamp, tokens) per request
        self.concurrent_count = 0

    def _evict_expired(self, now: float) -> None:
        """Drop queue entries older than the 60-second sliding window."""
        for window in (self.request_queue, self.token_queue):
            while window and (now - window[0][0]) > 60:
                window.popleft()

    def can_request(self, estimated_tokens: int) -> bool:
        """Return True when all three limits allow another request."""
        self._evict_expired(time.time())
        # RPM limit
        if len(self.request_queue) >= self.limit.requests_per_minute:
            return False
        # TPM limit: recent usage plus this request must fit the budget.
        spent = sum(entry[1] for entry in self.token_queue)
        if spent + estimated_tokens > self.limit.tokens_per_minute:
            return False
        # Concurrency ceiling
        return self.concurrent_count < self.limit.max_concurrent_requests

    def wait_for_capacity(self, estimated_tokens: int):
        """Poll until can_request() allows the call."""
        while not self.can_request(estimated_tokens):
            time.sleep(0.1)

    def record_request(self, tokens_used: int):
        """Log a completed request and its token consumption."""
        stamp = time.time()
        self.request_queue.append((stamp, 1))
        self.token_queue.append((stamp, tokens_used))

    def __enter__(self):
        """Count an in-flight request toward the concurrency ceiling."""
        self.concurrent_count += 1
        return self

    def __exit__(self, *args):
        self.concurrent_count -= 1
# Usage
limiter = SmartRateLimiter(RateLimit(
    requests_per_minute=60,
    tokens_per_minute=90000
))
# Before making a request: block until RPM, TPM, and concurrency allow it
estimated_tokens = 150  # Based on prompt + expected response
limiter.wait_for_capacity(estimated_tokens)
# Make request (the with-block counts it toward max_concurrent_requests)
with limiter:
    response = openai.ChatCompletion.create(...)
    actual_tokens = response["usage"]["total_tokens"]
    # Record the real token usage so the TPM window reflects actuals
    limiter.record_request(actual_tokens)
Caching for Performance
Avoid reprocessing identical inputs:
import hashlib
import json
import sqlite3
from typing import Optional
class PromptCache:
    """
    SQLite-backed cache of LLM responses keyed by (prompt, model).

    Avoids paying for the same completion twice when identical prompts
    recur in a batch. ``hit_count`` tracks how often each entry was reused.
    """

    def __init__(self, db_path: str = "prompt_cache.db"):
        self.db_path = db_path
        self._init_db()

    def _init_db(self):
        """Create the cache table on first use (idempotent)."""
        with sqlite3.connect(self.db_path) as conn:
            conn.execute("""
                CREATE TABLE IF NOT EXISTS cache (
                    prompt_hash TEXT PRIMARY KEY,
                    prompt TEXT,
                    model TEXT,
                    response TEXT,
                    created_at TIMESTAMP,
                    hit_count INTEGER DEFAULT 1
                )
            """)
            conn.commit()

    @staticmethod
    def _hash_prompt(prompt: str, model: str) -> str:
        """Return a stable SHA-256 key for the (prompt, model) pair."""
        key = f"{prompt}:{model}"
        return hashlib.sha256(key.encode()).hexdigest()

    def get(self, prompt: str, model: str) -> Optional[str]:
        """Return the cached response for (prompt, model), or None on a miss.

        The hit counter is bumped with a single atomic
        ``hit_count = hit_count + 1`` UPDATE instead of the previous
        read-modify-write, so concurrent cache users cannot lose counts.
        """
        hash_key = self._hash_prompt(prompt, model)
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute(
                "SELECT response FROM cache WHERE prompt_hash = ?",
                (hash_key,)
            )
            row = cursor.fetchone()
            if row is None:
                return None
            # Atomic increment, safe under concurrent access.
            conn.execute(
                "UPDATE cache SET hit_count = hit_count + 1 WHERE prompt_hash = ?",
                (hash_key,)
            )
            conn.commit()
            return row[0]

    def set(self, prompt: str, model: str, response: str):
        """Cache a response; an existing entry for the same key is kept as-is."""
        hash_key = self._hash_prompt(prompt, model)
        with sqlite3.connect(self.db_path) as conn:
            # OR IGNORE preserves the first-seen response and its hit_count.
            conn.execute("""
                INSERT OR IGNORE INTO cache
                (prompt_hash, prompt, model, response, created_at)
                VALUES (?, ?, ?, ?, datetime('now'))
            """, (hash_key, prompt, model, response))
            conn.commit()

    def stats(self) -> dict:
        """Return cache size and reuse statistics."""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.execute("""
                SELECT
                    COUNT(*) as total,
                    SUM(hit_count) as total_hits,
                    AVG(hit_count) as avg_hits
                FROM cache
            """)
            row = cursor.fetchone()
            # SUM/AVG are NULL on an empty table; report zeros instead.
            return {
                "cached_items": row[0],
                "total_hits": row[1] or 0,
                "avg_hits": row[2] or 0
            }
# Usage
cache = PromptCache()


def classify_with_cache(text: str, model: str = "gpt-4") -> str:
    """Sentiment-classify *text*, reusing any previously cached LLM response."""
    prompt = f"Classify this sentiment: {text}"
    # Serve from cache when this exact prompt/model pair was seen before.
    hit = cache.get(prompt, model)
    if hit:
        print("Cache hit!")
        return hit
    # Miss: ask the model, then remember the answer for next time.
    completion = openai.ChatCompletion.create(
        model=model,
        messages=[{"role": "user", "content": prompt}]
    )
    answer = completion.choices[0].message.content
    cache.set(prompt, model, answer)
    return answer
# Process many items with automatic caching
# (the repeated "Great product!" is served from the cache on its second pass)
texts = ["Great product!", "Terrible experience", "Great product!", ...]
results = [classify_with_cache(text) for text in texts]
print(cache.stats())
# Output: cached_items=2, total_hits=3, avg_hits=1.5
Batch Job Monitoring
Track progress and detect issues in long-running batch jobs:
import json
from datetime import datetime
from dataclasses import dataclass, asdict
from enum import Enum
class JobStatus(Enum):
    """Lifecycle states of a batch job."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
@dataclass
class BatchJobMetrics:
    """Counters for one batch job.

    ``processed_items`` counts successes only and ``failed_items`` counts
    failures (see BatchJobMonitor.record_success / record_failure), so the
    two together sum to the number of attempted items.
    """
    total_items: int
    processed_items: int
    failed_items: int
    cached_hits: int
    tokens_used: int
    elapsed_seconds: float

    @property
    def success_rate(self) -> float:
        """Fraction of all items processed successfully; 0.0 for empty jobs."""
        if self.total_items == 0:
            # Guard against ZeroDivisionError, mirroring cache_hit_rate.
            return 0.0
        # processed_items already excludes failures, so failures must not be
        # subtracted again (998 ok / 2 failed of 1000 => 99.8%, not 99.6%).
        return self.processed_items / self.total_items

    @property
    def cache_hit_rate(self) -> float:
        """Fraction of successful items that were served from cache."""
        if self.processed_items == 0:
            return 0
        return self.cached_hits / self.processed_items
class BatchJobMonitor:
    """Monitor a batch job's progress, metrics, and errors."""

    def __init__(self, job_id: str, total_items: int):
        self.job_id = job_id
        self.status = JobStatus.PENDING
        self.metrics = BatchJobMetrics(
            total_items=total_items,
            processed_items=0,
            failed_items=0,
            cached_hits=0,
            tokens_used=0,
            elapsed_seconds=0
        )
        self.start_time = None  # Set by start()
        self.errors = []        # One dict per failure, in order of occurrence

    def start(self):
        """Mark job as started and begin timing."""
        self.status = JobStatus.RUNNING
        self.start_time = datetime.now()

    def record_success(self, tokens_used: int = 0, cached: bool = False):
        """Record successful processing of an item (cached or not)."""
        self.metrics.processed_items += 1
        self.metrics.tokens_used += tokens_used
        if cached:
            self.metrics.cached_hits += 1

    def record_failure(self, error: str):
        """Record failed processing of an item along with when it failed."""
        self.metrics.failed_items += 1
        self.errors.append({
            "error": error,
            "timestamp": datetime.now().isoformat()
        })

    def finish(self):
        """Mark job as completed and record the elapsed wall time.

        The job finishes COMPLETED even when individual items failed;
        per-item failures are surfaced through the metrics and error list.
        (The previous if/else here assigned COMPLETED on both branches.)
        """
        self.status = JobStatus.COMPLETED
        if self.start_time:
            elapsed = (datetime.now() - self.start_time).total_seconds()
            self.metrics.elapsed_seconds = elapsed

    def get_report(self) -> dict:
        """Generate a JSON-serializable summary of the job."""
        return {
            "job_id": self.job_id,
            "status": self.status.value,
            "metrics": asdict(self.metrics),
            "success_rate": f"{self.metrics.success_rate:.1%}",
            "cache_hit_rate": f"{self.metrics.cache_hit_rate:.1%}",
            "errors": self.errors[:5]  # First 5 errors
        }
# Usage
# (items and process_item are supplied by the surrounding application)
monitor = BatchJobMonitor("batch_123", total_items=1000)
monitor.start()
for item in items:
    try:
        result = process_item(item)
        monitor.record_success(tokens_used=150)
    except Exception as e:
        # A single bad item is logged, not fatal to the whole job.
        monitor.record_failure(str(e))
monitor.finish()
report = monitor.get_report()
print(json.dumps(report, indent=2))
# Output:
# {
# "job_id": "batch_123",
# "status": "completed",
# "metrics": {
# "total_items": 1000,
# "processed_items": 998,
# "failed_items": 2,
# "cached_hits": 120,
# "tokens_used": 147150,
# "elapsed_seconds": 2400.5
# },
# "success_rate": "99.8%",
# "cache_hit_rate": "12.0%"
# }
Quality Drift Detection
Monitor if output quality degrades over time:
import statistics
class QualityMonitor:
    """
    Track a rolling window of quality scores and flag degradation.

    record_score() returns None until a full window of scores has
    accumulated, then "quality_alert" on a sharp drop between consecutive
    windows, "quality_critical" when the recent average falls below
    min_score, and "ok" otherwise.
    """

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        self.scores = []  # Holds at most the last two windows of scores
        self.thresholds = {
            "min_score": 0.6,
            "quality_drop_threshold": 0.1
        }

    def record_score(self, score: float):
        """Record a quality score (0-1) for an output and check for drift."""
        self.scores.append(score)
        # Only the two most recent windows are ever read below; trimming
        # keeps memory bounded on long-running jobs (the old version grew
        # the list without limit).
        if len(self.scores) > 2 * self.window_size:
            del self.scores[:-2 * self.window_size]
        # Alert if quality drops
        if len(self.scores) >= self.window_size:
            recent_avg = statistics.mean(self.scores[-self.window_size:])
            if len(self.scores) > self.window_size:
                previous_avg = statistics.mean(
                    self.scores[-2 * self.window_size:-self.window_size]
                )
                quality_change = previous_avg - recent_avg
                if quality_change > self.thresholds["quality_drop_threshold"]:
                    print(f"WARNING: Quality dropped by {quality_change:.1%}")
                    print(f"Previous avg: {previous_avg:.2f}, Recent: {recent_avg:.2f}")
                    return "quality_alert"
            if recent_avg < self.thresholds["min_score"]:
                print(f"CRITICAL: Quality below threshold ({recent_avg:.2f})")
                return "quality_critical"
            return "ok"
# Usage in batch job
# (items, process_item, and evaluate_quality are supplied by the application;
# evaluate_quality is expected to return a 0-1 score -- TODO confirm)
quality_monitor = QualityMonitor(window_size=50)
for i, item in enumerate(items):
    result = process_item(item)
    quality_score = evaluate_quality(result)
    alert = quality_monitor.record_score(quality_score)
    if alert == "quality_critical":
        # Pause job, investigate
        break
Key Takeaway: Production batch systems require rate limiting, caching, async execution, and monitoring. Together, these patterns enable you to process thousands of items efficiently while maintaining quality and controlling costs.
Exercise: Build an Async Batch Processor for Support Tickets
Create a system that processes 1,000 support tickets asynchronously:
- Classify each ticket by sentiment and category
- Respect API rate limits (60 RPM, 90K TPM)
- Cache identical classifications
- Monitor job progress and quality
- Generate a final report
Requirements:
- Use AsyncPromptBatcher for async execution
- Implement SmartRateLimiter for rate limits
- Use PromptCache for caching
- Monitor with BatchJobMonitor
- Detect quality drift with QualityMonitor
- Generate JSON report with statistics
Starter code:
async def batch_process_tickets(tickets: list) -> dict:
    """Process 1000+ tickets with async batching.

    Exercise starter: wires up the batcher, cache, and monitors; the TODO
    steps below are left for the reader to implement.
    """
    batcher = AsyncPromptBatcher(requests_per_minute=60)
    cache = PromptCache()
    monitor = BatchJobMonitor("ticket_batch", len(tickets))
    quality_mon = QualityMonitor()
    monitor.start()
    # TODO: Process tickets asynchronously
    # TODO: Track progress
    # TODO: Monitor quality
    # TODO: Generate final report
    monitor.finish()
    return monitor.get_report()

# Load 1000 test tickets
# (load_tickets is provided by the exercise harness -- TODO confirm its signature)
tickets = load_tickets("support_tickets.json")
# Run batch job
result = asyncio.run(batch_process_tickets(tickets))
print(result)
Extension challenges:
- Implement priority queuing (urgent tickets first)
- Add dead-letter queue for permanently failed items
- Create a retry strategy for transient failures
- Build a dashboard showing real-time progress
- Implement exponential backoff for rate limit errors
By mastering batch processing, you can reliably handle large-scale LLM workloads that would be impossible to process sequentially.