Fast Inference and Optimization
Production serving requires fast, efficient inference. This lesson covers KV-cache optimization, speculative decoding, continuous batching, and systems such as vLLM and TensorRT-LLM for high-throughput serving.
Core Concepts
KV-Cache
Cache key-value pairs from previous tokens to avoid recomputation:
Token 1: Compute Q, K, V, output
Token 2: Load K1, V1 from cache, compute Q2, K2, V2
...
Memory-compute tradeoff: the cache avoids recomputing K and V for the entire prefix at every step, but it grows linearly with sequence length and batch size.
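The memory cost is easy to estimate: each generated token stores one K and one V vector per layer and per KV head. The sketch below computes this for a Llama-2-7B-style configuration (32 layers, 32 KV heads, head dimension 128, fp16); the config values are assumptions for illustration.

def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_elem=2):
    # Factor of 2 accounts for storing both K and V at every layer
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch_size * bytes_per_elem

# Assumed Llama-2-7B-style config: 32 layers, 32 KV heads, head_dim 128, fp16
print(kv_cache_bytes(32, 32, 128, seq_len=4096) / 1e9)  # ~2.1 GB for one 4096-token sequence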
Speculative Decoding
Use fast, small model to draft tokens, then verify with large model:
Small model: Draft 4 tokens fast
Large model: Verify 4 tokens, accept/reject
Typical speedup: 2-3x. With a proper rejection-sampling verification step the output distribution matches the large model exactly, so quality is unchanged; the speedup depends on how often draft tokens are accepted.
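A quick way to reason about the speedup: assuming each draft token is accepted independently with probability α, the expected number of tokens produced per large-model forward pass with draft length γ is (1 - α^(γ+1)) / (1 - α) (Leviathan et al., 2023). The numbers below are illustrative only.

def expected_tokens_per_verify(alpha, gamma):
    # Expected accepted + corrected tokens per large-model forward pass,
    # assuming independent per-token acceptance with probability alpha
    return (1 - alpha ** (gamma + 1)) / (1 - alpha)

print(expected_tokens_per_verify(alpha=0.8, gamma=4))  # ~3.36 tokens per verification step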
Continuous Batching
Don’t wait for sequence to finish; add new sequences when slot available:
Seq 1: Gen token 1, 2, 3, 4, ... (still generating)
Seq 2: (new) Gen token 1, 2, 3, ...
Seq 3: (done) Start new
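A rough utilization comparison shows why this matters. With static batching every slot is reserved until the longest sequence in the batch finishes; with continuous batching a freed slot is refilled immediately. The request lengths below are hypothetical.

# Hypothetical decode lengths (in tokens) for 4 requests in one static batch
lengths = [12, 40, 65, 200]

# Static batching: every slot is held until the longest sequence finishes
reserved_steps = max(lengths) * len(lengths)   # 800 slot-steps reserved
useful_steps = sum(lengths)                    # 317 slot-steps doing useful work
print(f"static batching utilization: {useful_steps / reserved_steps:.0%}")  # ~40%

# Continuous batching refills freed slots from the queue, so utilization
# approaches 100% as long as there are pending requests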
Practical Implementation
KV-Cache Implementation
import torch

def generate_with_cache(model, tokenizer, prompt, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    past_key_values = None
    with torch.no_grad():
        for _ in range(max_length - input_ids.shape[1]):
            if past_key_values is None:
                # First step: run the full prompt to populate the cache
                model_inputs = input_ids
            else:
                # Later steps: only pass the newest token (the rest is cached)
                model_inputs = input_ids[:, -1:]
            outputs = model(
                input_ids=model_inputs,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )
            next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            past_key_values = outputs.past_key_values
            if next_token.item() == tokenizer.eos_token_id:
                break
    return tokenizer.decode(input_ids[0])
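A minimal usage sketch, assuming a Hugging Face causal LM; gpt2 is used here only because it is small enough to run anywhere:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Any Hugging Face causal LM works; gpt2 is a small stand-in for illustration
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

print(generate_with_cache(model, tokenizer, "KV-caching speeds up decoding because", max_length=60))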
vLLM for High-Throughput Serving
from vllm import LLM, SamplingParams
# Initialize with continuous batching
llm = LLM(
model="meta-llama/Llama-2-7b",
tensor_parallel_size=4, # Multi-GPU
gpu_memory_utilization=0.9,
)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
max_tokens=200,
)
# Batch inference
prompts = [
"What is machine learning?",
"Explain quantum computing",
"Write a poem about AI",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
print(output.outputs[0].text)
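For the API endpoint required in the exercise below, vLLM also ships an OpenAI-compatible HTTP server (started with `python -m vllm.entrypoints.openai.api_server --model ...`; exact flags vary by vLLM version). A minimal client sketch, with the URL, port, and model name assumed for illustration:

import requests

# Query a locally running vLLM OpenAI-compatible server; host, port, and
# model name here are assumptions for illustration
resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "meta-llama/Llama-2-7b-hf",
        "prompt": "What is machine learning?",
        "max_tokens": 200,
        "temperature": 0.8,
    },
)
print(resp.json()["choices"][0]["text"])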
Speculative Decoding
import torch

class SpeculativeDecoder:
    def __init__(self, draft_model, verify_model):
        self.draft = draft_model    # Small, fast
        self.verify = verify_model  # Large, accurate
        self.accept_count = 0
        self.reject_count = 0

    @torch.no_grad()
    def generate(self, input_ids, max_tokens=100, draft_tokens=4):
        # Simplified greedy speculative decoding: the draft model proposes a few
        # tokens, the large model scores prompt + draft in one forward pass, and
        # drafts are accepted as long as they match the large model's greedy choice.
        # (Full speculative decoding uses rejection sampling to preserve the
        # sampled distribution; greedy verification is the easiest special case.)
        target_len = input_ids.shape[1] + max_tokens
        while input_ids.shape[1] < target_len:
            prompt_len = input_ids.shape[1]
            draft_output = self.draft.generate(
                input_ids, max_new_tokens=draft_tokens, do_sample=False
            )
            draft_ids = draft_output[:, prompt_len:]              # [1, <=draft_tokens]

            # One large-model pass over the prompt plus all drafted tokens
            verify_logits = self.verify(draft_output).logits      # [1, seq, vocab]
            # The logit at position i predicts token i+1, so these positions
            # predict each drafted token plus one token after the last draft
            preds = verify_logits[:, prompt_len - 1:, :].argmax(dim=-1)

            accepted = 0
            for i in range(draft_ids.shape[1]):
                if draft_ids[0, i] == preds[0, i]:
                    accepted += 1
                    self.accept_count += 1
                else:
                    self.reject_count += 1
                    break

            # Keep accepted draft tokens, then append the large model's own token
            # at the first disagreement (or after the last draft if all accepted)
            new_tokens = torch.cat(
                [draft_ids[:, :accepted], preds[:, accepted:accepted + 1]], dim=1
            )
            input_ids = torch.cat([input_ids, new_tokens], dim=1)
        return input_ids
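A usage sketch under the assumption that the draft and verifier share a tokenizer; gpt2 and gpt2-large are stand-ins, and a real deployment would pair a much larger verifier with a roughly 10x smaller draft model:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical pairing for illustration: gpt2 drafts, gpt2-large verifies
tokenizer = AutoTokenizer.from_pretrained("gpt2")
draft = AutoModelForCausalLM.from_pretrained("gpt2").eval()
verify = AutoModelForCausalLM.from_pretrained("gpt2-large").eval()

decoder = SpeculativeDecoder(draft, verify)
input_ids = tokenizer.encode("Speculative decoding works by", return_tensors="pt")
output_ids = decoder.generate(input_ids, max_tokens=40, draft_tokens=4)
print(tokenizer.decode(output_ids[0]))
print(f"accepted={decoder.accept_count}, rejected={decoder.reject_count}")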
Advanced Techniques
Continuous Batching
class ContinuousBatchScheduler:
    def __init__(self, max_batch_size=32):
        self.queue = []               # Pending (prompt_id, prompt) requests
        self.active_sequences = {}    # Sequences currently being decoded
        self.max_batch_size = max_batch_size

    def add_request(self, prompt_id, prompt):
        self.queue.append((prompt_id, prompt))

    def get_batch(self):
        # Keep already-running sequences in the batch, then fill any free
        # slots with new requests from the queue (continuous batching)
        batch = [(seq_id, seq["prompt"]) for seq_id, seq in self.active_sequences.items()]
        while self.queue and len(batch) < self.max_batch_size:
            prompt_id, prompt = self.queue.pop(0)
            self.active_sequences[prompt_id] = {"prompt": prompt, "tokens_generated": 0}
            batch.append((prompt_id, prompt))
        return batch

    def step(self, seq_id, finished=False):
        # Called after each decode step; a finished sequence frees its slot
        if finished:
            self.active_sequences.pop(seq_id, None)
        else:
            self.active_sequences[seq_id]["tokens_generated"] += 1
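A small driver loop illustrating the slot reuse; the prompts and finish condition are made up for the example:

scheduler = ContinuousBatchScheduler(max_batch_size=2)
for i, prompt in enumerate(["What is ML?", "Explain KV-cache", "Write a haiku"]):
    scheduler.add_request(i, prompt)

# Simulated decode loop: request 0 finishes after 3 steps, freeing a slot
# that request 2 immediately fills on the next get_batch() call
for step in range(4):
    batch = scheduler.get_batch()
    print(f"step {step}: serving {[seq_id for seq_id, _ in batch]}")
    for seq_id, _ in batch:
        scheduler.step(seq_id, finished=(seq_id == 0 and step == 2))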
Production Considerations
Latency Optimization
import time
import numpy as np
def benchmark_inference(model, prompt, num_runs=10):
# Warmup
for _ in range(3):
_ = model.generate(prompt)
# Measure
latencies = []
for _ in range(num_runs):
start = time.time()
output = model.generate(prompt)
latencies.append(time.time() - start)
return {
"mean": np.mean(latencies),
"p50": np.percentile(latencies, 50),
"p99": np.percentile(latencies, 99),
}
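Latency is only half the picture; the exercise below also asks for throughput. A rough offline throughput measurement for the vLLM setup above might look like the sketch below (real serving throughput should be measured against the HTTP endpoint with a load-testing tool):

def benchmark_throughput(llm, sampling_params, prompts, num_runs=3):
    # Rough requests/second for vLLM offline batch generation; takes the best
    # of several runs to reduce noise from warmup and scheduling jitter
    best = 0.0
    for _ in range(num_runs):
        start = time.time()
        llm.generate(prompts, sampling_params)
        best = max(best, len(prompts) / (time.time() - start))
    return best

print(f"{benchmark_throughput(llm, sampling_params, prompts * 50):.1f} req/s")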
Key Takeaway
Fast inference systems combine KV-caching, speculative decoding, and continuous batching to achieve order-of-magnitude throughput improvements over naive, one-request-at-a-time serving. vLLM and TGI (Text Generation Inference) exemplify production-ready systems for serving LLMs efficiently.
Practical Exercise
Task: Deploy an LLM serving system with vLLM that achieves 1000+ req/s.
Requirements:
- Set up vLLM with multi-GPU tensor parallelism
- Implement continuous batching
- Add speculative decoding for speedup
- Create API endpoint
- Load test to measure throughput
Evaluation:
- Throughput > 1000 requests/second
- P99 latency < 200ms
- Memory efficiency
- Throughput improvement over an unoptimized baseline