Advanced

Scaling AI Applications

Lesson 3 of 4 · Estimated time: 50 min

Scaling AI Applications

Scaling AI applications requires moving beyond single-instance deployments to handle thousands of concurrent users and requests. This lesson covers async processing, message queues, caching strategies, and horizontal scaling patterns that keep AI systems responsive under load.

Async Processing Fundamentals

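Traditional synchronous request-response patterns become bottlenecks with LLM calls, which can take seconds each: a worker that blocks on one call cannot serve anyone else in the meantime. Async processing decouples request acceptance from result computation. Many calls can then be in flight at once while the process waits on the network.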

import asyncio
from typing import Optional
import anthropic

class AsyncLLMProcessor:
    """Process LLM requests asynchronously with the Anthropic async client.

    Args:
        model: Model identifier passed to ``messages.create``.
        max_tokens: Upper bound on generated tokens per request.
    """

    def __init__(
        self,
        model: str = "claude-3-5-sonnet-20241022",
        max_tokens: int = 1024,
    ):
        self.model = model
        self.max_tokens = max_tokens
        self.client = anthropic.AsyncAnthropic()

    async def process_single(self, prompt: str) -> str:
        """Send one prompt and return the text of the reply.

        The reply's text blocks are concatenated. An empty string is
        returned when the reply has no text blocks, rather than raising
        ``IndexError``/``AttributeError``.
        """
        message = await self.client.messages.create(
            model=self.model,
            max_tokens=self.max_tokens,
            messages=[{"role": "user", "content": prompt}],
        )
        # ``content`` is a list of blocks that may be empty. Only blocks of
        # type "text" are guaranteed to carry ``.text``.
        return "".join(
            block.text
            for block in message.content
            if getattr(block, "type", None) == "text"
        )

    async def process_batch(
        self,
        prompts: list[str],
        max_concurrency: Optional[int] = None,
    ) -> list[str]:
        """Process multiple prompts concurrently, preserving input order.

        Args:
            prompts: Prompts to send.
            max_concurrency: If given, at most this many requests are in
                flight at once, which keeps large batches under provider
                rate limits. ``None`` (default) sends every prompt at once.

        Returns:
            One reply per prompt, in the same order as ``prompts``.

        Raises:
            ValueError: If ``max_concurrency`` is less than 1.
        """
        if max_concurrency is None:
            return await asyncio.gather(*(self.process_single(p) for p in prompts))
        if max_concurrency < 1:
            raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")

        semaphore = asyncio.Semaphore(max_concurrency)

        async def bounded(prompt: str) -> str:
            async with semaphore:
                return await self.process_single(prompt)

        return await asyncio.gather(*(bounded(p) for p in prompts))

# Usage
async def main():
    """Send three demo prompts as one batch and print each question/answer pair."""
    llm = AsyncLLMProcessor()
    questions = ["What is AI?", "Explain ML", "Define NLP"]
    answers = await llm.process_batch(questions)
    pairs = zip(questions, answers)
    for question, answer in pairs:
        print(f"Q: {question}\nA: {answer}\n")

# Uncomment to run the demo:
# asyncio.run(main())

Concurrency Patterns

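Different workloads benefit from different concurrency approaches:
- I/O-bound work such as API calls suits asyncio, with a semaphore capping how many calls run at once.
- Blocking calls that release the GIL (network I/O, `time.sleep`, many C extensions) suit a thread pool.
- Pure-Python CPU-bound work needs a process pool to use multiple cores.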

import asyncio
from concurrent.futures import ThreadPoolExecutor
import time

class ConcurrencyManager:
    """Demonstrate concurrency patterns for async and blocking workloads."""

    async def rate_limited_processing(
        self,
        items: list[str],
        rate_limit: int = 10
    ) -> list[str]:
        """Process items concurrently with at most ``rate_limit`` in flight.

        The semaphore caps *concurrency* (simultaneous calls), not calls per
        second. A true requests-per-second limit needs a token bucket or
        similar on top.

        Args:
            items: Items to process.
            rate_limit: Maximum number of items processed at the same time.

        Returns:
            One result per item, in input order.

        Raises:
            ValueError: If ``rate_limit`` is less than 1. A zero-permit
                semaphore would make every task wait forever.
        """
        if rate_limit < 1:
            raise ValueError(f"rate_limit must be >= 1, got {rate_limit}")
        semaphore = asyncio.Semaphore(rate_limit)

        async def limited_process(item):
            async with semaphore:
                # Simulate LLM processing
                await asyncio.sleep(0.1)
                return f"processed: {item}"

        return await asyncio.gather(*(limited_process(item) for item in items))

    def cpu_bound_with_threads(self, data: list, max_workers: int = 4) -> list:
        """Map ``_compute_intensive`` over ``data`` using a thread pool.

        Threads only overlap work that releases the GIL: blocking I/O,
        ``time.sleep`` (as in the simulated workload), and many C extensions.
        For pure-Python CPU-bound work, use
        ``concurrent.futures.ProcessPoolExecutor`` to get real parallelism.

        Args:
            data: Items to process.
            max_workers: Thread pool size (default 4).

        Returns:
            Results in the same order as ``data``.
        """
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            return list(executor.map(self._compute_intensive, data))

    def _compute_intensive(self, item):
        """Simulate expensive work: sleep 0.1 s, then return ``item ** 2``."""
        time.sleep(0.1)
        return item ** 2

Message Queues and Task Distribution

Message queues decouple request producers from the workers that process them, so each side can scale independently; under load you simply add workers. RabbitMQ and Apache Kafka are widely used options.

import json
from typing import Any, Callable
from dataclasses import dataclass
from abc import ABC, abstractmethod

@dataclass
class QueuedTask:
    """A unit of work passed through a MessageQueue.

    Field order defines the generated ``__init__`` signature, so new fields
    (with defaults) may be appended but existing fields must not be reordered.
    """
    task_id: str  # identifies the task, e.g. when acknowledging it
    task_type: str  # selects the handler registered with TaskProcessor
    payload: dict  # arguments handed to that handler unchanged
    priority: int = 0  # serialized with the task; whether it affects ordering depends on the backend
    retry_count: int = 0  # retries already attempted (incremented by TaskProcessor)
    max_retries: int = 3  # retry limit before the task is given up on

class MessageQueue(ABC):
    """Interface that every queue backend implements.

    Concrete backends supply the three async operations below.
    """

    @abstractmethod
    async def enqueue(self, task: QueuedTask) -> str:
        """Place ``task`` on the queue and return an identifier for it."""

    @abstractmethod
    async def dequeue(self) -> Optional[QueuedTask]:
        """Take the next task off the queue, or return ``None`` when there is none."""

    @abstractmethod
    async def ack(self, task_id: str) -> bool:
        """Mark the task identified by ``task_id`` as completed and report success."""

class RabbitMQQueue(MessageQueue):
    """RabbitMQ-based queue implementation (broker calls are sketched in comments).

    Messages are JSON objects whose keys are exactly the ``QueuedTask`` field
    names. A consumer can therefore rebuild a task with
    ``QueuedTask(**json.loads(body))``.
    """

    def __init__(self, connection_url: str):
        self.connection_url = connection_url
        # In production: import pika; self.connection = pika.BlockingConnection(...)
        # NOTE(review): a blocking client stalls the event loop these async
        # methods run on; an asyncio-native AMQP client fits better.

    @staticmethod
    def _serialize(task: QueuedTask) -> str:
        """Encode every ``QueuedTask`` field as JSON, keyed by field name."""
        return json.dumps({
            "task_id": task.task_id,
            "task_type": task.task_type,
            "payload": task.payload,
            "priority": task.priority,
            "retry_count": task.retry_count,
            "max_retries": task.max_retries,
        })

    async def enqueue(self, task: QueuedTask) -> str:
        """Publish ``task`` and return its ``task_id``."""
        message = self._serialize(task)
        # channel.basic_publish(exchange='tasks',
        #     routing_key=f'task.{task.task_type}', body=message)
        return task.task_id

    async def dequeue(self) -> Optional[QueuedTask]:
        """Fetch the next task, or ``None`` when the queue is empty."""
        # method, properties, body = channel.basic_get('task_queue')
        # if method:
        #     # Remember method.delivery_tag under the task_id for ack().
        #     return QueuedTask(**json.loads(body))
        return None

    async def ack(self, task_id: str) -> bool:
        """Acknowledge successful processing of ``task_id``."""
        # RabbitMQ acknowledges by delivery tag, not by task id: look up the
        # tag recorded by dequeue() for task_id, then
        # channel.basic_ack(delivery_tag=delivery_tag)
        return True

class TaskProcessor:
    """Dispatch queued tasks to registered handlers, retrying on failure."""

    def __init__(self, queue: MessageQueue, poll_interval: float = 1.0):
        """
        Args:
            queue: Source of tasks; also used to re-enqueue retries.
            poll_interval: Seconds to sleep when the queue is empty.
        """
        self.queue = queue
        self.poll_interval = poll_interval
        self.handlers: dict[str, Callable] = {}

    def register_handler(self, task_type: str, handler: Callable):
        """Register an async ``handler(payload)`` for ``task_type``.

        A later registration for the same type replaces the earlier one.
        """
        self.handlers[task_type] = handler

    async def process_queue(self):
        """Continuously process the queue.

        Every dequeued task is acknowledged exactly once, in one of three cases:
        - its handler succeeded;
        - a retry copy was re-enqueued;
        - it was given up on.

        Acking on the failure paths matters: otherwise a broker would keep
        redelivering the original message alongside the retry copy.

        Handler errors are contained. Errors raised by the queue itself
        (dequeue/enqueue/ack) propagate to the caller.
        """
        while True:
            task = await self.queue.dequeue()
            if task is None:
                await asyncio.sleep(self.poll_interval)
                continue

            handler = self.handlers.get(task.task_type)
            if handler is None:
                # Retrying cannot help when no handler exists; report and drop.
                print(f"Task {task.task_id}: no handler for type {task.task_type!r}; dropping")
                await self.queue.ack(task.task_id)
                continue

            try:
                await handler(task.payload)
            except Exception as exc:  # worker boundary: one bad task must not stop the loop
                if task.retry_count < task.max_retries:
                    task.retry_count += 1
                    # Re-enqueue before acking, so a failed enqueue leaves the
                    # original message unacked and thus recoverable.
                    await self.queue.enqueue(task)
                else:
                    print(f"Task {task.task_id} failed after {task.max_retries} retries: {exc!r}")
            # Kept outside the try, so an ack failure is not mistaken for a
            # handler failure (which would re-run a successful task).
            await self.queue.ack(task.task_id)

Caching Strategies

Caching reduces latency and backend load by storing computed results. Redis provides fast, distributed caching.

from typing import Optional, TypeVar, Callable
import hashlib
import json
from datetime import datetime, timedelta

T = TypeVar('T')

class CacheLayer:
    """Two-level cache: an in-process dict (L1) in front of Redis (L2).

    Only the in-memory level is active; the Redis calls are sketched in
    comments. Entries expire ``ttl_seconds`` after they are stored.
    """

    def __init__(self, ttl_seconds: int = 3600):
        self.ttl_seconds = ttl_seconds
        self.memory_cache: dict = {}  # L1 cache: key -> {"value", "expires_at"}
        # self.redis_client = redis.Redis(host='localhost')  # L2 cache

    def _hash_key(self, *args, **kwargs) -> str:
        """Derive a cache key from call arguments.

        Keyword arguments are sorted by name, so ``f(a=1, b=2)`` and
        ``f(b=2, a=1)`` share a key. Arguments are keyed by their ``str()``,
        so distinct objects with equal string forms also share a key.
        """
        key_data = json.dumps({
            "args": str(args),
            "kwargs": str(dict(sorted(kwargs.items())))
        }, sort_keys=True)
        # MD5 serves only as a compact fingerprint here, not for security.
        return hashlib.md5(key_data.encode()).hexdigest()

    async def get_or_compute(
        self,
        key: str,
        compute_fn: Callable[[], T],
        ttl: Optional[int] = None
    ) -> T:
        """Return the cached value for ``key``, computing and storing it on a miss.

        Concurrent misses for the same key each call ``compute_fn``; there is
        no locking.

        Args:
            key: Cache key (for example from ``_hash_key``).
            compute_fn: Zero-argument callable producing the value. It may be
                a plain function or a coroutine function; an awaitable result
                is awaited.
            ttl: Lifetime in seconds of a newly stored value. ``None`` means
                ``ttl_seconds``. ``0`` is honored and expires the entry at once.

        Returns:
            The cached or freshly computed value.
        """
        import inspect

        if ttl is None:
            ttl = self.ttl_seconds

        # Check memory cache (L1). Drop expired entries so they do not pile up.
        cached = self.memory_cache.get(key)
        if cached is not None:
            if cached["expires_at"] > datetime.now():
                return cached["value"]
            del self.memory_cache[key]

        # Check Redis (L2)
        # redis_value = self.redis_client.get(key)
        # if redis_value:
        #     return json.loads(redis_value)

        # Compute new value; accept both sync and async producers.
        value = compute_fn()
        if inspect.isawaitable(value):
            value = await value

        # Store in both caches
        self.memory_cache[key] = {
            "value": value,
            "expires_at": datetime.now() + timedelta(seconds=ttl)
        }
        # self.redis_client.setex(key, ttl, json.dumps(value))

        return value

    def invalidate(self, pattern: str):
        """Remove every in-memory entry whose key contains ``pattern``.

        Keys produced by ``_hash_key`` are hex digests, so substring matching
        is mainly useful for caller-chosen keys. To drop one hashed entry,
        pass its full digest.
        """
        for key in [k for k in self.memory_cache if pattern in k]:
            del self.memory_cache[key]

class SmartCache:
    """Generation cache with a hook for semantic (similar-prompt) reuse.

    Exact repeats of a prompt are served from the wrapped ``CacheLayer``.
    ``_find_similar_cached`` is the extension point for near-duplicate
    matching; it currently finds nothing.
    """

    def __init__(self, cache_layer: CacheLayer, similarity_threshold: float = 0.85):
        """
        Args:
            cache_layer: Backing cache used for exact-prompt hits.
            similarity_threshold: Minimum similarity for a near-duplicate
                hit. Intended for ``_find_similar_cached``; the current stub
                does not read it.
        """
        self.cache = cache_layer
        self.similarity_threshold = similarity_threshold

    async def cached_generation(
        self,
        prompt: str,
        generate_fn: Callable[[str], str]
    ) -> str:
        """Return a cached generation for ``prompt``, or generate and cache one.

        Args:
            prompt: Input prompt.
            generate_fn: Async function; ``await generate_fn(prompt)`` yields
                the text. It is called only on a cache miss.

        Returns:
            The cached or newly generated text.
        """
        similar_cached = self._find_similar_cached(prompt)
        if similar_cached is not None:
            return similar_cached

        async def _generate() -> str:
            return await generate_fn(prompt)

        # Keys come from the CacheLayer's own helper. get_or_compute serves
        # exact repeats without calling generate_fn again.
        return await self.cache.get_or_compute(
            self.cache._hash_key(prompt),
            _generate
        )

    def _find_similar_cached(self, prompt: str) -> Optional[str]:
        """Return a cached result for a prompt similar to ``prompt``, if any.

        Stub: always returns ``None``.
        """
        # In production: use embeddings and cosine similarity
        return None

Horizontal Scaling

Horizontal scaling adds more instances to handle increased load. Load balancers distribute traffic across instances.

from typing import List
from dataclasses import dataclass
from enum import Enum

class InstanceStatus(Enum):
    """Health state reported for a serving instance.

    LoadBalancer.select_instance routes only to HEALTHY instances.
    """
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    UNHEALTHY = "unhealthy"

@dataclass
class InstanceMetrics:
    """Latest health/load snapshot reported by one serving instance."""
    instance_id: str  # key LoadBalancer returns from select_instance and matches in remove_instance
    status: InstanceStatus  # only HEALTHY instances are routed to
    latency_ms: float  # "least_loaded" routing picks the instance with the lowest value
    error_rate: float  # AutoScaler compares the mean to 0.05 / 0.01, so presumably a 0-1 fraction; confirm
    throughput: float  # used by "weighted" routing; units not shown here (requests/s?); confirm
    memory_usage: float  # not read by LoadBalancer or AutoScaler
    cpu_usage: float  # AutoScaler compares the mean to cpu_threshold (default 0.8), so presumably 0-1

class LoadBalancer:
    """Distribute requests across registered instances.

    ``routing_policy`` values:
      * ``"least_loaded"`` (default): the healthy instance with the lowest
        ``latency_ms``, using latency as the load signal.
      * ``"round_robin"``: cycle through the healthy instances in turn.
      * ``"weighted"``: a random healthy instance, weighted by successful
        throughput.
    Any other value falls back to the first healthy instance.
    """

    def __init__(self):
        self.instances: List[InstanceMetrics] = []
        self.routing_policy = "least_loaded"
        self._rr_counter = 0  # position in the round-robin rotation

    def add_instance(self, metrics: InstanceMetrics):
        """Register new instance."""
        self.instances.append(metrics)

    def remove_instance(self, instance_id: str):
        """Deregister every instance whose id equals ``instance_id``."""
        self.instances = [
            i for i in self.instances if i.instance_id != instance_id
        ]

    def select_instance(self) -> Optional[str]:
        """Return the id of the instance to route to, or ``None`` if none is healthy."""
        healthy = [i for i in self.instances
                   if i.status == InstanceStatus.HEALTHY]
        if not healthy:
            return None

        if self.routing_policy == "least_loaded":
            return min(healthy, key=lambda i: i.latency_ms).instance_id
        if self.routing_policy == "round_robin":
            # Modulo keeps the rotation valid when the healthy set changes size.
            chosen = healthy[self._rr_counter % len(healthy)]
            self._rr_counter += 1
            return chosen.instance_id
        if self.routing_policy == "weighted":
            return self._weighted_selection(healthy)

        return healthy[0].instance_id

    def _weighted_selection(self, instances: List[InstanceMetrics]) -> str:
        """Pick an instance at random, weighted by successful throughput.

        Each weight is ``throughput * (1 - error_rate)``, floored at 0. This
        keeps the units consistent, unlike subtracting a rate from a
        throughput. If every weight is 0, the first instance is returned.
        """
        import random

        weights = [max(i.throughput * (1 - i.error_rate), 0.0) for i in instances]
        if sum(weights) <= 0:
            return instances[0].instance_id
        return random.choices(instances, weights=weights, k=1)[0].instance_id

class AutoScaler:
    """Recommend instance-count changes from averaged instance metrics.

    Scale up by ``scale_up_step`` when:
      * mean CPU exceeds ``cpu_threshold``, or
      * mean error rate exceeds ``error_threshold``.

    Scale down by ``scale_down_step`` when both hold:
      * mean CPU is below ``scale_down_cpu``, and
      * mean error rate is below ``scale_down_error_rate``.

    The target is always clamped to ``[min_instances, max_instances]``.
    Defaults reproduce the original fixed thresholds
    (0.8 / 0.05 up, 0.2 / 0.01 down, +2 / -1 instances).
    """

    def __init__(
        self,
        min_instances: int = 2,
        max_instances: int = 10,
        cpu_threshold: float = 0.8,
        error_threshold: float = 0.05,
        scale_down_cpu: float = 0.2,
        scale_down_error_rate: float = 0.01,
        scale_up_step: int = 2,
        scale_down_step: int = 1,
    ):
        """
        Raises:
            ValueError: If ``min_instances`` exceeds ``max_instances``.
        """
        if min_instances > max_instances:
            raise ValueError(
                f"min_instances ({min_instances}) > max_instances ({max_instances})"
            )
        self.min_instances = min_instances
        self.max_instances = max_instances
        self.cpu_threshold = cpu_threshold
        self.error_threshold = error_threshold
        self.scale_down_cpu = scale_down_cpu
        self.scale_down_error_rate = scale_down_error_rate
        self.scale_up_step = scale_up_step
        self.scale_down_step = scale_down_step
        self.current_instances = min_instances

    def evaluate_scaling(self, metrics: List[InstanceMetrics]) -> int:
        """Update ``current_instances`` from ``metrics`` and return the change.

        Args:
            metrics: Latest snapshot per instance.

        Returns:
            Instances added (positive), removed (negative), or 0. An empty
            ``metrics`` list carries no signal and returns 0 instead of
            dividing by zero.
        """
        if not metrics:
            return 0

        count = len(metrics)
        avg_cpu = sum(m.cpu_usage for m in metrics) / count
        avg_error_rate = sum(m.error_rate for m in metrics) / count

        if avg_cpu > self.cpu_threshold or avg_error_rate > self.error_threshold:
            target = min(self.current_instances + self.scale_up_step, self.max_instances)
        elif avg_cpu < self.scale_down_cpu and avg_error_rate < self.scale_down_error_rate:
            target = max(self.current_instances - self.scale_down_step, self.min_instances)
        else:
            target = self.current_instances

        delta = target - self.current_instances
        self.current_instances = target
        return delta

Database Scaling

Efficient database access is critical for scalability. Connection pooling avoids the cost of opening a new connection per request, while indexes and batched operations keep queries from becoming the bottleneck.

from typing import Any, List
import asyncpg

class ConnectionPool:
    """Wrapper around an asyncpg connection pool (the asyncpg calls are sketched in docstrings).

    Args:
        dsn: Database connection string handed to asyncpg.
        min_size: Minimum pool size passed to ``asyncpg.create_pool``.
        max_size: Maximum pool size passed to ``asyncpg.create_pool``.
    """

    def __init__(self, dsn: str, min_size: int = 10, max_size: int = 20):
        self.dsn, self.min_size, self.max_size = dsn, min_size, max_size
        self.pool = None  # stays None until initialize() builds a real pool

    async def initialize(self):
        """Create the pool. A no-op until the asyncpg call is enabled.

        Intended implementation::

            self.pool = await asyncpg.create_pool(
                self.dsn, min_size=self.min_size, max_size=self.max_size
            )
        """

    async def execute(self, query: str, *args) -> Any:
        """Run ``query`` on a pooled connection. Currently returns ``None``.

        Intended implementation::

            async with self.pool.acquire() as connection:
                return await connection.fetch(query, *args)
        """

class QueryOptimizer:
    """Helpers for generating index DDL and batching operations."""

    @staticmethod
    def add_indexes(fields: List[str]) -> str:
        """Return one ``CREATE INDEX IF NOT EXISTS`` statement per field on table ``data``.

        Field names are interpolated directly into SQL, so each must be a
        plain identifier (as accepted by ``str.isidentifier``). This keeps a
        malicious or malformed name from injecting SQL.

        Args:
            fields: Column names to index. Any iterable is accepted.

        Returns:
            The statements joined by newlines ("" for no fields).

        Raises:
            ValueError: If a field is not a plain identifier string.
        """
        # Materialize once: the iterable is walked twice (validate, then build).
        fields = list(fields)
        for field in fields:
            if not isinstance(field, str) or not field.isidentifier():
                raise ValueError(f"invalid column name: {field!r}")
        return "\n".join(
            f"CREATE INDEX IF NOT EXISTS idx_{field} ON data({field});"
            for field in fields
        )

    @staticmethod
    def batch_operations(operations: List[str], batch_size: int = 1000):
        """Yield consecutive slices of ``operations``, each at most ``batch_size`` long.

        Raises:
            ValueError: If ``batch_size`` is less than 1. Without this check,
                a negative size would silently yield nothing. Because this is
                a generator, the error surfaces on first iteration.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        for start in range(0, len(operations), batch_size):
            yield operations[start:start + batch_size]

Key Takeaway

Scaling requires async processing, message queues for decoupling, intelligent caching, horizontal instance scaling with load balancing, and optimized database access. Each layer must scale independently.

Exercises

  1. Implement async batch processing for 100+ concurrent requests
  2. Set up a message queue system with task retry logic
  3. Build a multi-level cache (memory + Redis) with TTL management
  4. Create a load balancer with a least-loaded routing policy
  5. Implement auto-scaling based on CPU/error rate metrics
  6. Optimize database queries and add connection pooling
  7. Measure and profile your scaling bottlenecks