Advanced

Fast Inference and Optimization

Lesson 2 of 4 Estimated Time 60 min

Production serving requires fast, efficient inference. This lesson covers KV-cache optimization, speculative decoding, continuous batching, and high-throughput serving systems such as vLLM and TensorRT-LLM.

Core Concepts

KV-Cache

Cache the attention keys and values of previous tokens so they are not recomputed at every step:

Token 1: Compute Q1, K1, V1 and the output; store K1, V1 in the cache
Token 2: Load K1, V1 from the cache; compute only Q2, K2, V2
...

Memory-latency tradeoff: the cache grows linearly with sequence length and batch size, so it costs more memory but makes each generation step faster.
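
To put numbers on the memory side of this tradeoff, the sketch below estimates cache size from the model's shape. The function name and parameters are illustrative, and the example figures (32 layers, 32 key-value heads, head dimension 128, 16-bit values) are assumed to match a Llama-2-7B-style model; check your model's config before relying on them.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_value=2):
    """Estimate KV-cache size: one K and one V vector per layer, head, and token."""
    # The factor 2 accounts for storing both keys and values.
    per_token = 2 * num_layers * num_kv_heads * head_dim * bytes_per_value
    return per_token * seq_len * batch_size


# Assumed Llama-2-7B-style shape with fp16 values
per_token = kv_cache_bytes(32, 32, 128, seq_len=1)
full_context = kv_cache_bytes(32, 32, 128, seq_len=4096)
print(f"{per_token / 2**20:.2f} MiB per token")           # 0.50 MiB per token
print(f"{full_context / 2**30:.2f} GiB for 4096 tokens")  # 2.00 GiB for 4096 tokens
```

Under these assumptions a single 4096-token sequence needs about 2 GiB of cache, which is why cache memory, not compute, often limits how many sequences can be served at once.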

Speculative Decoding

Use a small, fast model to draft tokens, then verify them with the large model:

Small model: Draft 4 tokens fast
Large model: Verify 4 tokens, accept/reject
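Expected speedup: roughly 2-3x when most drafted tokens are accepted; output quality is set by the large model, which has the final say on every token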
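
How much this helps depends on how often the large model agrees with the draft. As a rough sketch, assume each drafted token is accepted independently with the same probability (a simplification; real acceptance rates vary by position and prompt). The expected number of tokens produced per large-model pass is then a geometric sum. The function name below is illustrative.

```python
def expected_tokens_per_step(acceptance_rate, draft_tokens):
    """Expected tokens per large-model pass under an i.i.d. acceptance assumption.

    Each pass yields the accepted draft prefix plus one token from the
    large model (a correction, or a bonus token if every draft was accepted).
    """
    a, k = acceptance_rate, draft_tokens
    if a >= 1.0:
        return k + 1.0
    return (1 - a ** (k + 1)) / (1 - a)


for rate in (0.6, 0.8, 0.9):
    print(rate, round(expected_tokens_per_step(rate, draft_tokens=4), 2))
```

The wall-clock speedup is lower than this token count, because the draft model's own forward passes also take time.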

Continuous Batching

Don’t wait for a whole batch to finish; add a new sequence as soon as a slot frees up:

Seq 1: Gen token 1, 2, 3, 4, ... (still generating)
Seq 2: (new) Gen token 1, 2, 3, ...
Seq 3: (done) Start new
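
A toy simulation makes the benefit visible. It counts decoding steps for a list of output lengths under two policies: static batching, where a batch runs until its longest sequence finishes, and continuous batching, where a freed slot is refilled at the next step. Both functions are illustrative sketches that assume one token per sequence per step and positive lengths; they ignore prefill cost and memory limits.

```python
from collections import deque


def static_batching_steps(lengths, batch_size):
    """Each batch occupies the model until its longest sequence is done."""
    return sum(max(lengths[i:i + batch_size])
               for i in range(0, len(lengths), batch_size))


def continuous_batching_steps(lengths, batch_size):
    """Finished sequences leave immediately and waiting ones take their slot."""
    waiting = deque(lengths)
    active = []  # tokens still to generate for each running sequence
    steps = 0
    while waiting or active:
        # Refill free slots before the next decoding step
        while waiting and len(active) < batch_size:
            active.append(waiting.popleft())
        steps += 1
        # Every active sequence emits one token; drop those that just finished
        active = [remaining - 1 for remaining in active if remaining > 1]
    return steps


lengths = [2, 8, 2, 8]
print(static_batching_steps(lengths, batch_size=2))      # 16
print(continuous_batching_steps(lengths, batch_size=2))  # 12
```

The gap widens as output lengths become more uneven, since static batching leaves slots idle while it waits for the longest sequence.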

Practical Implementation

KV-Cache Implementation

import torch

def generate_with_cache(model, tokenizer, prompt, max_length=100):
    input_ids = tokenizer.encode(prompt, return_tensors='pt')
    past_key_values = None

    with torch.no_grad():
        for _ in range(max_length - input_ids.shape[1]):
            # First step: feed the whole prompt to fill the cache.
            # Later steps: feed only the newest token (the rest is in the cache).
            step_input = input_ids if past_key_values is None else input_ids[:, -1:]
            outputs = model(
                input_ids=step_input,
                past_key_values=past_key_values,
                use_cache=True,
                return_dict=True,
            )

            # Greedy decoding: take the most likely next token
            next_token = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(-1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            past_key_values = outputs.past_key_values

            # Batch size is 1 here, so .item() gives a plain int to compare
            if next_token.item() == tokenizer.eos_token_id:
                break

    return tokenizer.decode(input_ids[0])
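
To see what the cache saves, it helps to count how many token positions pass through the model during generation. The sketch below is a back-of-the-envelope count, not a benchmark; the function name is illustrative, and it ignores the fact that attention over a longer cache is itself somewhat more expensive.

```python
def tokens_processed(prompt_len, new_tokens, use_cache):
    """Count token positions fed through the model during generation."""
    if new_tokens <= 0:
        return 0
    if use_cache:
        # The prompt is processed once, then one token per later step
        return prompt_len + (new_tokens - 1)
    # Without a cache, step i re-processes the prompt plus i generated tokens
    return sum(prompt_len + i for i in range(new_tokens))


print(tokens_processed(20, 100, use_cache=False))  # 6950
print(tokens_processed(20, 100, use_cache=True))   # 119
```

Without the cache the work grows quadratically with output length; with it, the work grows linearly.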

vLLM for High-Throughput Serving

from vllm import LLM, SamplingParams

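# vLLM schedules requests with continuous batching by default
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # Hugging Face-format checkpoint
    tensor_parallel_size=4,  # Shard the model across 4 GPUs
    gpu_memory_utilization=0.9,  # Fraction of each GPU's memory vLLM may use
)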

sampling_params = SamplingParams(
    temperature=0.8,
    top_p=0.95,
    max_tokens=200,
)

# Batch inference
prompts = [
    "What is machine learning?",
    "Explain quantum computing",
    "Write a poem about AI",
]

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    print(output.outputs[0].text)

Speculative Decoding

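import torch

class SpeculativeDecoder:
    """Greedy speculative decoding for batch size 1: keep drafted tokens
    for as long as they match the large model's own greedy choice."""

    def __init__(self, draft_model, verify_model):
        self.draft = draft_model  # Small, fast
        self.verify = verify_model  # Large, accurate
        self.accept_count = 0
        self.reject_count = 0

    @torch.no_grad()
    def generate(self, input_ids, max_tokens=100, draft_tokens=4):
        target_len = input_ids.shape[1] + max_tokens
        while input_ids.shape[1] < target_len:
            prefix_len = input_ids.shape[1]

            # Draft with small model (greedy, so results are reproducible)
            draft_output = self.draft.generate(
                input_ids, max_new_tokens=draft_tokens, do_sample=False
            )
            draft_ids = draft_output[:, prefix_len:]
            n_draft = draft_ids.shape[1]

            # Verify with large model: one forward pass over prefix + draft.
            # logits[:, i] predicts the token at position i + 1.
            verify_logits = self.verify(draft_output).logits
            verify_ids = verify_logits[:, prefix_len - 1:, :].argmax(dim=-1)

            # Accept the longest draft prefix the large model agrees with
            matches = (verify_ids[:, :n_draft] == draft_ids)[0].long()
            n_accept = int(matches.cumprod(dim=0).sum())
            self.accept_count += n_accept
            self.reject_count += n_draft - n_accept

            # Append accepted tokens plus one token from the large model:
            # its correction at the first mismatch, or a bonus token if all matched
            input_ids = torch.cat(
                [input_ids, draft_ids[:, :n_accept], verify_ids[:, n_accept:n_accept + 1]],
                dim=1,
            )

        return input_ids[:, :target_len]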
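
When tokens are sampled rather than chosen greedily, the accept/reject step is usually described as follows: a drafted token x is kept with probability min(1, p(x)/q(x)), where q is the draft model's distribution and p is the large model's; on rejection, a replacement is drawn from the normalized positive part of p - q. The sketch below applies that rule to a single token, using plain Python lists as distributions. The function name `accept_or_resample` and its arguments are illustrative, not part of any library.

```python
import random


def accept_or_resample(token, draft_probs, target_probs, rng=random):
    """Apply the speculative-sampling rule to one drafted token.

    Assumes `token` was sampled from draft_probs, so draft_probs[token] > 0.
    Returns (token_to_emit, was_accepted).
    """
    q = draft_probs[token]
    p = target_probs[token]
    if rng.random() < min(1.0, p / q):
        return token, True

    # Rejected: resample from the part of p that q under-covers
    residual = [max(0.0, pt - qt) for pt, qt in zip(target_probs, draft_probs)]
    new_token = rng.choices(range(len(residual)), weights=residual)[0]
    return new_token, False


rng = random.Random(0)
# Identical distributions: the draft is always accepted
print(accept_or_resample(1, [0.5, 0.5], [0.5, 0.5], rng))  # (1, True)
# The large model gives token 1 zero probability: always rejected and replaced
print(accept_or_resample(1, [0.5, 0.5], [1.0, 0.0], rng))  # (0, False)
```

With this rule the emitted tokens follow the large model's distribution, so the draft model affects speed rather than output quality.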

Advanced Techniques

Continuous Batching

from collections import deque

class ContinuousBatchScheduler:
    def __init__(self, max_batch_size=32):
        self.queue = deque()  # Requests waiting for a slot
        self.active_sequences = {}  # Requests currently generating
        self.max_batch_size = max_batch_size

    def add_request(self, prompt_id, prompt):
        self.queue.append((prompt_id, prompt))

    def record_token(self, prompt_id):
        # Call once per generated token to track progress
        self.active_sequences[prompt_id]["tokens_generated"] += 1

    def finish(self, prompt_id):
        # Free the slot as soon as a sequence is done
        self.active_sequences.pop(prompt_id, None)

    def get_batch(self):
        # Admit queued requests into any free slots
        while self.queue and len(self.active_sequences) < self.max_batch_size:
            prompt_id, prompt = self.queue.popleft()
            self.active_sequences[prompt_id] = {"prompt": prompt, "tokens_generated": 0}

        # Every active (unfinished) sequence takes part in the next decoding step
        return [
            (seq_id, seq_data["prompt"])
            for seq_id, seq_data in self.active_sequences.items()
        ]

Production Considerations

Latency Optimization

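import time

import numpy as np

def benchmark_inference(model, prompt, num_runs=10):
    # Warmup (the first calls include one-time setup costs)
    for _ in range(3):
        _ = model.generate(prompt)

    # Measure
    latencies = []
    for _ in range(num_runs):
        start = time.perf_counter()  # Monotonic, high-resolution timer
        _ = model.generate(prompt)
        latencies.append(time.perf_counter() - start)

    # With few runs, p99 is close to the maximum; raise num_runs for a stable tail
    return {
        "mean": np.mean(latencies),
        "p50": np.percentile(latencies, 50),
        "p99": np.percentile(latencies, 99),
    }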

Key Takeaway

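Fast inference systems combine KV-caching, speculative decoding, and continuous batching; together these can yield 10-100x throughput improvements over naive one-request-at-a-time serving, depending on model, hardware, and workload. vLLM and TGI (Text Generation Inference) are examples of production-ready systems for serving LLMs efficiently.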

Practical Exercise

Task: Deploy an LLM serving system with vLLM that achieves 1000+ req/s.

Requirements:

  1. Set up vLLM with multi-GPU tensor parallelism
  2. Implement continuous batching
  3. Add speculative decoding for speedup
  4. Create an API endpoint
  5. Load test to measure throughput
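
For the load-testing step, one possible starting point is a small asyncio harness that caps the number of in-flight requests and reports throughput and P99 latency. This is a sketch under stated assumptions: `load_test` and `fake_request` are hypothetical names, and `fake_request` only sleeps; replace it with a real asynchronous call to your endpoint.

```python
import asyncio
import time


async def load_test(send_request, num_requests, concurrency):
    """Call send_request num_requests times with at most `concurrency` in flight."""
    semaphore = asyncio.Semaphore(concurrency)
    latencies = []

    async def one_call(i):
        async with semaphore:
            start = time.perf_counter()
            await send_request(i)
            latencies.append(time.perf_counter() - start)

    wall_start = time.perf_counter()
    await asyncio.gather(*(one_call(i) for i in range(num_requests)))
    wall = time.perf_counter() - wall_start

    latencies.sort()
    p99_index = min(len(latencies) - 1, int(0.99 * len(latencies)))
    return {
        "throughput_rps": num_requests / wall,
        "p99_s": latencies[p99_index],
    }


async def fake_request(i):
    # Stand-in for an HTTP call to the serving endpoint (hypothetical)
    await asyncio.sleep(0.01)


if __name__ == "__main__":
    print(asyncio.run(load_test(fake_request, num_requests=200, concurrency=50)))
```

Latency here is measured from the moment a request acquires a slot, so it excludes time spent waiting in the client-side queue; start the timer before the semaphore if you want end-to-end latency instead.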

Evaluation:

  • Throughput > 1000 requests/second
  • P99 latency < 200ms
  • Memory efficiency
  • Throughput improvement over a baseline
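
The throughput and latency targets together imply how many requests the system must hold in flight. By Little's law, the average number of requests in the system equals the arrival rate times the average time each request spends there. The helper below is an illustrative sketch; it uses mean latency, so plugging in the P99 bound gives a conservative upper estimate.

```python
def required_concurrency(throughput_rps, mean_latency_s):
    """Little's law: average requests in flight = arrival rate * time in system."""
    return throughput_rps * mean_latency_s


# 1000 requests/second at up to 200 ms each
print(required_concurrency(1000, 0.2))
```

At these targets the server must sustain roughly 200 concurrent requests, which ties the exercise back to batch size and KV-cache memory.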