Scaling AI Applications
Scaling AI applications requires moving beyond single-instance deployments to handle thousands of concurrent users and requests. This lesson covers async processing, message queues, caching strategies, and horizontal scaling patterns that keep AI systems responsive under load.
Async Processing Fundamentals
Traditional synchronous request-response patterns become bottlenecks when processing latency-sensitive LLM calls. Async processing decouples request acceptance from result computation.
import asyncio
from typing import Optional
import anthropic
class AsyncLLMProcessor:
    """Run Anthropic model calls without blocking the event loop."""

    def __init__(self, model: str = "claude-3-5-sonnet-20241022"):
        self.model = model
        self.client = anthropic.AsyncAnthropic()

    async def process_single(self, prompt: str) -> str:
        """Send one prompt to the model and return the text of its reply."""
        response = await self.client.messages.create(
            model=self.model,
            max_tokens=1024,
            messages=[{"role": "user", "content": prompt}],
        )
        # The reply arrives as a list of content parts; the text lives in
        # the first one.
        return response.content[0].text

    async def process_batch(self, prompts: list[str]) -> list[str]:
        """Fan out all prompts at once; results come back in input order."""
        return await asyncio.gather(
            *(self.process_single(p) for p in prompts)
        )
# Usage
async def main():
    """Demo: answer three questions concurrently, then print Q/A pairs."""
    processor = AsyncLLMProcessor()
    questions = ["What is AI?", "Explain ML", "Define NLP"]
    answers = await processor.process_batch(questions)
    for question, answer in zip(questions, answers):
        print(f"Q: {question}\nA: {answer}\n")

# asyncio.run(main())
Concurrency Patterns
Different workloads benefit from different concurrency approaches:
import asyncio
from concurrent.futures import ThreadPoolExecutor
import time
class ConcurrencyManager:
    """Manage different concurrency patterns."""

    async def rate_limited_processing(
        self,
        items: list[str],
        rate_limit: int = 10,
    ) -> list[str]:
        """Process items concurrently with at most `rate_limit` in flight."""
        gate = asyncio.Semaphore(rate_limit)

        async def worker(entry):
            async with gate:
                # Simulate LLM processing latency.
                await asyncio.sleep(0.1)
                return f"processed: {entry}"

        # gather preserves input order regardless of completion order.
        return await asyncio.gather(*(worker(entry) for entry in items))

    def cpu_bound_with_threads(self, data: list) -> list:
        """Run the CPU-heavy step across a small thread pool."""
        with ThreadPoolExecutor(max_workers=4) as pool:
            return list(pool.map(self._compute_intensive, data))

    def _compute_intensive(self, item):
        """Simulate CPU-intensive work (sleep, then square the item)."""
        time.sleep(0.1)
        return item ** 2
Message Queues and Task Distribution
Message queues decouple request producers from processors, so the two sides can be scaled out independently. RabbitMQ and Apache Kafka are industry standards.
import json
from typing import Any, Callable
from dataclasses import dataclass
from abc import ABC, abstractmethod
@dataclass
class QueuedTask:
    """Represents a queued task."""
    task_id: str  # unique id; echoed back by enqueue() and used for ack()
    task_type: str  # routing key used to select a registered handler
    payload: dict  # handler-specific arguments
    priority: int = 0  # NOTE(review): presumably higher = more urgent; broker-dependent, confirm
    retry_count: int = 0  # attempts made so far; incremented on handler failure
    max_retries: int = 3  # give up (log and drop) once retry_count reaches this
class MessageQueue(ABC):
    """Abstract message queue interface.

    Concrete backends (e.g. RabbitMQQueue) implement enqueue/dequeue/ack;
    TaskProcessor drives its consume loop against this interface.
    """

    @abstractmethod
    async def enqueue(self, task: QueuedTask) -> str:
        """Add task to queue; returns the task's id."""
        pass

    @abstractmethod
    async def dequeue(self) -> Optional[QueuedTask]:
        """Get next task from queue, or None when nothing is waiting."""
        pass

    @abstractmethod
    async def ack(self, task_id: str) -> bool:
        """Acknowledge task completion; returns True on success."""
        pass
class RabbitMQQueue(MessageQueue):
    """RabbitMQ-based queue implementation."""

    def __init__(self, connection_url: str):
        self.connection_url = connection_url
        # In production: import pika; self.connection = pika.BlockingConnection(...)

    async def enqueue(self, task: QueuedTask) -> str:
        """Publish the task as JSON; returns its task_id."""
        envelope = {
            "task_id": task.task_id,
            "type": task.task_type,
            "payload": task.payload,
            "priority": task.priority,
            "retry_count": task.retry_count,
        }
        message = json.dumps(envelope)
        # channel.basic_publish(exchange='tasks',
        #     routing_key=f'task.{task.task_type}', body=message)
        return task.task_id

    async def dequeue(self) -> Optional[QueuedTask]:
        """Pull the next task from the broker; None when the queue is empty."""
        # method, properties, body = channel.basic_get('task_queue')
        # if method:
        #     return QueuedTask(**json.loads(body))
        return None

    async def ack(self, task_id: str) -> bool:
        """Confirm successful processing back to the broker."""
        # channel.basic_ack(delivery_tag=method.delivery_tag)
        return True
class TaskProcessor:
    """Dispatch queued tasks to per-type handlers with retry logic."""

    def __init__(self, queue: "MessageQueue"):
        self.queue = queue
        # Maps task_type -> async handler called with the task's payload dict.
        self.handlers: dict[str, Callable] = {}

    def register_handler(self, task_type: str, handler: Callable):
        """Register an async handler for a task type (last registration wins)."""
        self.handlers[task_type] = handler

    async def process_queue(self):
        """Continuously poll the queue and dispatch tasks.

        A task whose handler raises is re-enqueued with an incremented
        retry count until max_retries is exhausted, then logged and dropped.
        Every delivered task is acked exactly once so at-least-once brokers
        do not redeliver it.
        """
        while True:
            task = await self.queue.dequeue()
            if task is None:
                # Idle backoff so an empty queue does not spin the CPU.
                await asyncio.sleep(1)
                continue
            try:
                handler = self.handlers.get(task.task_type)
                if handler:
                    await handler(task.payload)
                # Ack on success; unknown task types are acked too so they
                # do not loop forever.
                await self.queue.ack(task.task_id)
            except Exception as exc:
                if task.retry_count < task.max_retries:
                    task.retry_count += 1
                    await self.queue.enqueue(task)
                else:
                    # Exhausted retries: include the cause in the log
                    # (the original swallowed the exception detail).
                    print(
                        f"Task {task.task_id} failed after "
                        f"{task.max_retries} retries: {exc!r}"
                    )
                # Bug fix: ack the original delivery in both failure paths,
                # otherwise the broker would redeliver it alongside the
                # re-enqueued retry copy, duplicating work.
                await self.queue.ack(task.task_id)
Caching Strategies
Caching reduces latency and backend load by storing computed results. Redis provides fast, distributed caching.
from typing import Optional, TypeVar, Callable
import hashlib
import json
from datetime import datetime, timedelta
T = TypeVar('T')
class CacheLayer:
    """Multi-level caching strategy.

    L1 is an in-process dict; the Redis L2 layer is sketched in comments.
    """

    def __init__(self, ttl_seconds: int = 3600):
        self.ttl_seconds = ttl_seconds  # default TTL when get_or_compute gets none
        self.memory_cache: dict = {}  # L1 cache: key -> {"value", "expires_at"}
        # self.redis_client = redis.Redis(host='localhost') # L2 cache

    def _hash_key(self, *args, **kwargs) -> str:
        """Generate a deterministic cache key from function args."""
        key_data = json.dumps({
            "args": str(args),
            "kwargs": str(kwargs)
        }, sort_keys=True)
        # md5 is acceptable here: we need key uniqueness, not crypto strength.
        return hashlib.md5(key_data.encode()).hexdigest()

    async def get_or_compute(
        self,
        key: str,
        compute_fn: Callable,
        ttl: Optional[int] = None
    ) -> "T":
        """Get cached value or compute and cache it.

        compute_fn must be an *async* zero-argument callable; it is awaited
        only on a cache miss. (The original annotation claimed a sync
        Callable[[], T], but the result is awaited below.)
        """
        ttl = ttl or self.ttl_seconds
        # Check memory cache (L1). Bug fix: evict stale entries instead of
        # leaving them behind, so the dict does not grow without bound.
        cached = self.memory_cache.get(key)
        if cached is not None:
            if cached["expires_at"] > datetime.now():
                return cached["value"]
            del self.memory_cache[key]
        # Check Redis (L2)
        # redis_value = self.redis_client.get(key)
        # if redis_value:
        #     return json.loads(redis_value)
        # Compute new value
        value = await compute_fn()
        # Store in L1 (and L2 once Redis is wired up).
        self.memory_cache[key] = {
            "value": value,
            "expires_at": datetime.now() + timedelta(seconds=ttl)
        }
        # self.redis_client.setex(key, ttl, json.dumps(value))
        return value

    def invalidate(self, pattern: str):
        """Delete L1 entries whose key contains `pattern` (substring match)."""
        keys_to_delete = [k for k in self.memory_cache if pattern in k]
        for key in keys_to_delete:
            del self.memory_cache[key]
class SmartCache:
    """Cache wrapper that can short-circuit on semantically similar prompts."""

    def __init__(self, cache_layer: "CacheLayer"):
        self.cache = cache_layer
        self.similarity_threshold = 0.85  # cosine-similarity cutoff (not used yet)

    async def cached_generation(
        self,
        prompt: str,
        generate_fn: Callable
    ) -> str:
        """Return a generation for `prompt`, computing it only on a cache miss.

        generate_fn must be an async callable taking the prompt string.
        """
        # Simple string matching for similar prompts
        similar_cached = self._find_similar_cached(prompt)
        if similar_cached:
            return similar_cached
        # Bug fix: _hash_key lives on CacheLayer, not on this class — the
        # original `self._hash_key(prompt)` raised AttributeError.
        key = self.cache._hash_key(prompt)

        async def _produce() -> str:
            return await generate_fn(prompt)

        # Bug fix: get_or_compute awaits its callable, so pass an async
        # closure (the original passed a plain lambda, which cannot be
        # awaited). This also avoids generating before consulting the cache.
        return await self.cache.get_or_compute(key, _produce)

    def _find_similar_cached(self, prompt: str) -> Optional[str]:
        """Find similar cached results."""
        # In production: use embeddings and cosine similarity
        return None
Horizontal Scaling
Horizontal scaling adds more instances to handle increased load. Load balancers distribute traffic across instances.
from typing import List
from dataclasses import dataclass
from enum import Enum
class InstanceStatus(Enum):
    """Health states reported for a serving instance."""
    HEALTHY = "healthy"  # eligible to receive routed traffic
    DEGRADED = "degraded"  # NOTE(review): currently excluded from routing — only HEALTHY is selected
    UNHEALTHY = "unhealthy"  # excluded from routing
@dataclass
class InstanceMetrics:
    # Point-in-time health/load snapshot for one serving instance.
    instance_id: str  # unique id returned by LoadBalancer.select_instance
    status: InstanceStatus  # only HEALTHY instances receive traffic
    latency_ms: float  # load signal for the "least_loaded" routing policy
    error_rate: float  # failure fraction; feeds weighted routing and autoscaling
    throughput: float  # request volume; feeds weighted routing
    memory_usage: float  # NOTE(review): not consulted by LoadBalancer/AutoScaler here
    cpu_usage: float  # averaged by AutoScaler against its cpu_threshold
class LoadBalancer:
    """Distribute requests across instances."""

    def __init__(self):
        self.instances: List["InstanceMetrics"] = []
        self.routing_policy = "least_loaded"
        self._rr_index = 0  # rotation cursor for round-robin routing

    def add_instance(self, metrics: "InstanceMetrics"):
        """Register new instance."""
        self.instances.append(metrics)

    def remove_instance(self, instance_id: str):
        """Deregister unhealthy instance."""
        self.instances = [
            i for i in self.instances if i.instance_id != instance_id
        ]

    def select_instance(self) -> Optional[str]:
        """Select the best instance for a request; None if none are healthy."""
        healthy = [i for i in self.instances
                   if i.status == InstanceStatus.HEALTHY]
        if not healthy:
            return None
        if self.routing_policy == "least_loaded":
            return min(healthy, key=lambda i: i.latency_ms).instance_id
        elif self.routing_policy == "round_robin":
            # Bug fix: rotate through the healthy instances; the original
            # always returned healthy[0], which is not round-robin.
            chosen = healthy[self._rr_index % len(healthy)]
            self._rr_index += 1
            return chosen.instance_id
        elif self.routing_policy == "weighted":
            return self._weighted_selection(healthy)
        return healthy[0].instance_id

    def _weighted_selection(self, instances: List["InstanceMetrics"]) -> str:
        """Select the instance with the best (throughput - error_rate) score."""
        total_score = sum(i.throughput - i.error_rate for i in instances)
        if total_score <= 0:
            return instances[0].instance_id
        # Bug fix: the original computed normalized scores but then always
        # returned the first entry; pick the highest-scoring instance instead.
        best = max(instances, key=lambda i: i.throughput - i.error_rate)
        return best.instance_id
class AutoScaler:
    """Automatically scale instances based on demand."""

    def __init__(
        self,
        min_instances: int = 2,
        max_instances: int = 10,
        cpu_threshold: float = 0.8
    ):
        self.min_instances = min_instances
        self.max_instances = max_instances
        self.cpu_threshold = cpu_threshold  # avg CPU fraction that triggers scale-up
        self.current_instances = min_instances

    def evaluate_scaling(self, metrics: List["InstanceMetrics"]) -> int:
        """Return the instance-count delta to apply (0 when no change).

        Scales up by 2 when average CPU exceeds cpu_threshold or average
        error rate exceeds 5%; scales down by 1 when both are comfortably
        low. The count stays within [min_instances, max_instances].
        """
        # Bug fix: an empty metrics list used to raise ZeroDivisionError.
        if not metrics:
            return 0
        avg_cpu = sum(m.cpu_usage for m in metrics) / len(metrics)
        avg_error_rate = sum(m.error_rate for m in metrics) / len(metrics)
        if avg_cpu > self.cpu_threshold or avg_error_rate > 0.05:
            # Scale up aggressively (two at a time) to get ahead of load.
            new_instances = min(self.current_instances + 2, self.max_instances)
        elif avg_cpu < 0.2 and avg_error_rate < 0.01:
            # Scale down conservatively, one instance at a time.
            new_instances = max(self.current_instances - 1, self.min_instances)
        else:
            new_instances = self.current_instances
        delta = new_instances - self.current_instances
        if delta != 0:
            self.current_instances = new_instances
            return delta
        return 0
Database Scaling
Efficient database access is critical for scalability. Connection pooling and query optimization prevent bottlenecks.
from typing import Any, List
import asyncpg
class ConnectionPool:
    """Manage database connections efficiently."""

    def __init__(self, dsn: str, min_size: int = 10, max_size: int = 20):
        self.dsn = dsn
        self.min_size = min_size  # connections kept open when idle
        self.max_size = max_size  # upper bound under load
        self.pool = None  # populated by initialize()

    async def initialize(self):
        """Create the connection pool (real call sketched in comments)."""
        # self.pool = await asyncpg.create_pool(
        #     self.dsn,
        #     min_size=self.min_size,
        #     max_size=self.max_size
        # )
        pass

    async def execute(self, query: str, *args) -> Any:
        """Run `query` on a connection borrowed from the pool."""
        # async with self.pool.acquire() as connection:
        #     return await connection.fetch(query, *args)
        pass
class QueryOptimizer:
    """Optimize queries for performance."""

    @staticmethod
    def add_indexes(fields: List[str]) -> str:
        """Generate CREATE INDEX statements for the given columns.

        Raises ValueError for non-identifier field names, since the values
        are interpolated directly into SQL text (injection guard).
        """
        statements = []
        for field in fields:
            # Security: only plain identifiers may be spliced into the SQL.
            if not field.isidentifier():
                raise ValueError(f"invalid field name: {field!r}")
            statements.append(
                f"CREATE INDEX IF NOT EXISTS idx_{field} ON data({field});"
            )
        return "\n".join(statements)

    @staticmethod
    def batch_operations(operations: List[str], batch_size: int = 1000):
        """Yield successive batches of at most `batch_size` operations."""
        for start in range(0, len(operations), batch_size):
            yield operations[start:start + batch_size]
Key Takeaway
Scaling requires async processing, message queues for decoupling, intelligent caching, horizontal instance scaling with load balancing, and optimized database access. Each layer must scale independently.
Exercises
- Implement async batch processing for 100+ concurrent requests
- Set up a message queue system with task retry logic
- Build a multi-level cache (memory + Redis) with TTL management
- Create a load balancer with least-loaded routing policy
- Implement auto-scaling based on CPU/error rate metrics
- Optimize database queries and add connection pooling
- Measure and profile your scaling bottlenecks