Intermediate

API Security for AI Services

Lesson 4 of 4 | Estimated Time: 50 min

Protecting AI Through APIs

Every AI capability you expose through an API is also an attack surface: the same endpoints your users call are the ones attackers probe. This lesson covers how to secure APIs that serve AI models.

Core API Security for AI

Authentication & Authorization

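from datetime import datetime


class AuthenticationError(Exception):
    """Raised when an API key is unknown, revoked, or expired."""


class AuthorizationError(Exception):
    """Raised when a valid key lacks permission for an action."""


class APIAuthentication:
    def __init__(self):
        # api_key -> metadata, e.g. {'revoked': False, 'expires_at': datetime(...)}
        self.api_keys = {}
        self.user_permissions = {}  # api_key -> set of allowed actions

    def authenticate_request(self, api_key):
        """Verify the API key exists, is not revoked, and has not expired."""

        metadata = self.api_keys.get(api_key)
        if metadata is None:
            raise AuthenticationError("Invalid API key")

        # Fail closed: a key with missing fields is treated as revoked/expired
        if metadata.get('revoked', True):
            raise AuthenticationError("API key revoked")

        expires_at = metadata.get('expires_at')
        if expires_at is None or expires_at < datetime.now():
            raise AuthenticationError("API key expired")

        return metadata

    def authorize_action(self, api_key, action):
        """Check that the key's permissions include the requested action."""

        permissions = self.user_permissions.get(api_key, set())
        if action not in permissions:
            raise AuthorizationError(f"Not authorized for {action}")

        return True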
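The class above keys its dictionaries by the raw API key, so anyone who can read that store can impersonate every user. A common hardening step is to store only a hash of each key and show the plaintext to the user once at creation. The sketch below assumes this approach; `hash_api_key`, `HashedKeyStore`, and the 90-day default lifetime are hypothetical names and values. A fast hash like SHA-256 is acceptable here because the keys are long random strings; user-chosen passwords would need a slow password hash instead.

```python
import hashlib
import secrets
from datetime import datetime, timedelta


def hash_api_key(api_key: str) -> str:
    """Hex SHA-256 digest of a key; only this digest is stored server-side."""
    return hashlib.sha256(api_key.encode("utf-8")).hexdigest()


class HashedKeyStore:
    """Hypothetical key store that never keeps plaintext API keys."""

    def __init__(self):
        self.keys = {}  # sha256(api_key) -> metadata

    def issue_key(self, user_id: str, ttl_days: int = 90) -> str:
        """Create a random key with an expiry; the plaintext is returned once."""
        api_key = secrets.token_urlsafe(32)  # 32 random bytes, ~43 URL-safe chars
        self.keys[hash_api_key(api_key)] = {
            "user_id": user_id,
            "revoked": False,
            "expires_at": datetime.now() + timedelta(days=ttl_days),
        }
        return api_key

    def lookup(self, api_key: str):
        """Return metadata for a presented key, or None if it is unknown."""
        return self.keys.get(hash_api_key(api_key))

    def revoke(self, api_key: str) -> None:
        """Mark a key as revoked without deleting its audit history."""
        metadata = self.lookup(api_key)
        if metadata is not None:
            metadata["revoked"] = True
```

The metadata uses the same `revoked` and `expires_at` fields that `authenticate_request` checks, and the built-in expiry forces the regular rotation recommended in the checklist.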

Input Validation & Size Limits

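import json
import re


class APIInputValidation:
    MAX_INPUT_SIZE = 10_000             # characters in the prompt
    MAX_REQUEST_SIZE = 1 * 1024 * 1024  # 1 MB of serialized request body

    # Crude heuristics only: attackers can rephrase around any fixed list,
    # so treat a match as one signal rather than a complete defense.
    SUSPICIOUS_PATTERNS = [
        r'ignore.*instruction',
        r'system.*prompt',
        r'new instruction',
    ]

    def validate_request(self, request):
        """Validate an API request; return {'valid': bool, 'reason': str}."""

        # Required fields and types first, so later checks can rely on them
        if not isinstance(request, dict) or 'prompt' not in request:
            return {'valid': False, 'reason': 'Missing required field: prompt'}

        prompt = request['prompt']
        if not isinstance(prompt, str):
            return {'valid': False, 'reason': 'Prompt must be a string'}

        # Size limits (ideally also enforced by the HTTP server before parsing)
        if len(prompt) > self.MAX_INPUT_SIZE:
            return {'valid': False, 'reason': f'Input exceeds max size {self.MAX_INPUT_SIZE}'}

        request_bytes = len(json.dumps(request, default=str).encode('utf-8'))
        if request_bytes > self.MAX_REQUEST_SIZE:
            return {'valid': False, 'reason': f'Request exceeds max size {self.MAX_REQUEST_SIZE}'}

        # Injection detection
        if self.is_suspicious(prompt):
            return {'valid': False, 'reason': 'Suspicious input detected'}

        return {'valid': True}

    def is_suspicious(self, text):
        """Detect known prompt-injection phrasings (case-insensitive)."""
        return any(re.search(p, text, re.IGNORECASE) for p in self.SUSPICIOUS_PATTERNS)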
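Pattern lists like the one above are easy to slip past with lookalike characters: fullwidth letters, zero-width spaces inside a keyword, or line breaks between words (the regex `.` does not match a newline). A minimal sketch, assuming you normalize text before matching, folds these variants back to plain text first. `normalize_for_matching` is a hypothetical helper; `unicodedata.normalize` and `unicodedata.category` are from the Python standard library.

```python
import re
import unicodedata

SUSPICIOUS_PATTERNS = [r"ignore.*instruction", r"system.*prompt", r"new instruction"]


def normalize_for_matching(text: str) -> str:
    # NFKC folds compatibility forms, e.g. fullwidth "Ｉｇｎｏｒｅ" -> "Ignore"
    text = unicodedata.normalize("NFKC", text)
    # Drop invisible format characters (category Cf), e.g. zero-width spaces
    text = "".join(ch for ch in text if unicodedata.category(ch) != "Cf")
    # Collapse runs of whitespace (including newlines) into single spaces
    return re.sub(r"\s+", " ", text)


def is_suspicious(text: str) -> bool:
    """Match injection patterns against the normalized text."""
    normalized = normalize_for_matching(text)
    return any(re.search(p, normalized, re.IGNORECASE) for p in SUSPICIOUS_PATTERNS)
```

Normalization only closes the cheapest evasions; a paraphrased injection still passes, which is why the lesson layers output validation and monitoring on top.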

Rate Limiting

from collections import defaultdict
import time

class APIRateLimiter:
    def __init__(self):
        self.requests = defaultdict(list)  # api_key -> [timestamps]
        self.limits = defaultdict(lambda: {'requests': 100, 'window': 3600})  # Per hour

    def check_rate_limit(self, api_key):
        """Check if request is within rate limit."""

        now = time.time()
        limit_config = self.limits[api_key]

        # Remove old timestamps outside window
        self.requests[api_key] = [
            ts for ts in self.requests[api_key]
            if now - ts < limit_config['window']
        ]

        # Check count
        if len(self.requests[api_key]) >= limit_config['requests']:
            reset_time = self.requests[api_key][0] + limit_config['window']
            return {
                'allowed': False,
                'reset_at': reset_time,
                'retry_after': int(reset_time - now) + 1  # round up so clients never retry early
            }

        # Record this request
        self.requests[api_key].append(now)

        return {'allowed': True}
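The limiter above is a sliding-window log: it stores one timestamp per request, so memory per key grows with the limit. A token bucket is a common alternative (a different technique, not a refinement of the code above) that keeps two numbers per key and allows short bursts up to a fixed capacity. In this sketch, `TokenBucket` and its parameters are hypothetical, and the injectable `clock` exists only to make it testable.

```python
import time


class TokenBucket:
    """Token-bucket limiter: bursts up to `capacity`, refilled at `rate` tokens/second."""

    def __init__(self, capacity: float, rate: float, clock=time.monotonic):
        self.capacity = capacity
        self.rate = rate
        self.clock = clock
        self.tokens = capacity  # start full so the first burst is allowed
        self.last = clock()

    def allow(self, cost: float = 1.0) -> bool:
        now = self.clock()
        # Refill in proportion to elapsed time, never above capacity
        self.tokens = min(self.capacity, self.tokens + (now - self.last) * self.rate)
        self.last = now
        if self.tokens >= cost:
            self.tokens -= cost
            return True
        return False
```

The `cost` argument lets an AI service charge expensive requests (for example, long prompts) more tokens than cheap ones. In a real service you would keep one bucket per API key, such as in a `dict` keyed by key ID.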

Output Validation & Sanitization

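import logging
import re

logger = logging.getLogger(__name__)


class APIOutputValidation:
    # One pattern table drives both detection and redaction, so anything
    # that is detected is also redacted.
    SENSITIVE_PATTERNS = {
        'SSN': re.compile(r'\b\d{3}-\d{2}-\d{4}\b'),
        'API_KEY': re.compile(r'api[_\s]?key[:\s]+\w{32,}', re.IGNORECASE),
        'PASSWORD': re.compile(r'password[:\s]+\S+', re.IGNORECASE),
    }

    def validate_response(self, response):
        """Validate and sanitize a model response before returning it."""

        # Check for sensitive data leakage
        if self.contains_sensitive_data(response):
            response = self.redact_sensitive_data(response)
            # Log only the redacted text, never the leaked values themselves
            self.log_security_event('sensitive_data_in_response', response[:100])

        # Check for harmful content; don't return it
        if self.contains_harmful_content(response):
            self.log_security_event('harmful_content_detected', response[:100])
            return "I cannot provide that response"

        # Validate format
        if not self.is_valid_format(response):
            return "Response formatting error"

        return response

    def contains_sensitive_data(self, text):
        """Detect whether the response leaks sensitive data."""
        return any(p.search(text) for p in self.SENSITIVE_PATTERNS.values())

    def redact_sensitive_data(self, text):
        """Replace each sensitive match with a typed placeholder."""
        for name, pattern in self.SENSITIVE_PATTERNS.items():
            text = pattern.sub(f'[REDACTED_{name}]', text)
        return text

    # Placeholders: replace with your moderation classifier and schema checks
    def contains_harmful_content(self, text):
        return False

    def is_valid_format(self, text):
        return isinstance(text, str)

    def log_security_event(self, event_type, snippet):
        logger.warning("security_event=%s snippet=%r", event_type, snippet)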
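Regex redaction catches data with a fixed shape, but the most sensitive text a model can leak is often its own system prompt. One technique, swapped in here rather than taken from the class above, is a canary: embed a random marker in the system prompt and block any response that contains it. The function names and the `CANARY-` prefix in this sketch are hypothetical.

```python
import secrets


def make_canary() -> str:
    # Random marker that should never appear in legitimate output
    return f"CANARY-{secrets.token_hex(8)}"


def build_system_prompt(instructions: str, canary: str) -> str:
    """Append the marker to the hidden instructions sent to the model."""
    return f"{instructions}\n(internal marker: {canary})"


def leaks_system_prompt(response: str, canary: str) -> bool:
    """True if the model echoed the hidden marker, i.e. leaked its prompt."""
    return canary in response
```

A canary only catches verbatim echoes; a model that paraphrases its instructions will not trip it, so treat it as one more signal for `validate_response` rather than a guarantee.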

Abuse Detection

import time
from collections import Counter, defaultdict


class AbuseDetection:
    def __init__(self, window_seconds=3600, abuse_threshold=5):
        self.patterns = defaultdict(list)  # api_key -> [(pattern, timestamp)]
        self.window_seconds = window_seconds
        self.abuse_threshold = abuse_threshold  # more than this many flagged requests = abuse

    def detect_abuse(self, api_key, request):
        """Detect whether a key keeps sending the same kind of suspicious request."""

        # Record the pattern together with its timestamp
        now = time.time()
        self.patterns[api_key].append((self.extract_pattern(request), now))

        # Keep only patterns inside the window (last hour by default)
        self.patterns[api_key] = [
            (p, t) for p, t in self.patterns[api_key]
            if now - t < self.window_seconds
        ]

        # Count repeated suspicious patterns; ordinary traffic never counts
        pattern_counts = Counter(
            p for p, _ in self.patterns[api_key] if p != 'normal'
        )
        max_count = max(pattern_counts.values(), default=0)

        if max_count > self.abuse_threshold:
            return {
                'abuse_detected': True,
                'action': 'block_api_key',
                'reason': f'Repeated pattern detected {max_count} times'
            }

        return {'abuse_detected': False}

    def extract_pattern(self, request):
        """Classify a request into a coarse pattern label."""

        prompt = request.get('prompt', '')

        # Simple keyword heuristic; also matches benign text like "ignore typos"
        if 'ignore' in prompt.lower():
            return 'injection_attempt'

        if len(prompt) > 5000:
            return 'excessive_input'

        return 'normal'
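`detect_abuse` returns the action `block_api_key` but leaves the blocking policy open. Because keyword heuristics produce false positives, one option is a temporary block whose duration doubles with each offense: a wrongly flagged user loses minutes, while a persistent attacker is locked out for up to a day. `EscalatingBlocklist` and its defaults (15 minutes, 24-hour cap) are hypothetical choices in this sketch.

```python
import time


class EscalatingBlocklist:
    """Hypothetical helper: temporary blocks that double with each offense."""

    def __init__(self, base_seconds=900, max_seconds=86400, clock=time.time):
        self.base_seconds = base_seconds
        self.max_seconds = max_seconds
        self.clock = clock
        self.offenses = {}       # api_key -> number of blocks so far
        self.blocked_until = {}  # api_key -> timestamp when the block ends

    def block(self, api_key):
        """Block a key; return the block duration in seconds."""
        count = self.offenses.get(api_key, 0)
        duration = min(self.base_seconds * 2 ** count, self.max_seconds)
        self.offenses[api_key] = count + 1
        self.blocked_until[api_key] = self.clock() + duration
        return duration

    def is_blocked(self, api_key):
        return self.clock() < self.blocked_until.get(api_key, 0)
```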

Cost Management

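from collections import defaultdict
from datetime import date


class APICostManagement:
    def __init__(self):
        # Spend is bucketed per period so the limits actually reset
        self.daily_costs = defaultdict(float)    # (api_key, 'YYYY-MM-DD') -> USD
        self.monthly_costs = defaultdict(float)  # (api_key, 'YYYY-MM') -> USD
        self.cost_limits = defaultdict(lambda: {'daily': 100, 'monthly': 1000})  # USD
        self.cost_metrics = {
            'input_token': 0.000001,   # $1 per 1M input tokens
            'output_token': 0.000002,  # $2 per 1M output tokens
            'function_call': 0.01,     # per call
        }

    def estimate_cost(self, request):
        """Estimate the worst-case cost of a request before running it."""

        # Rough heuristic: ~1.3 tokens per whitespace-separated word.
        # Use the model's real tokenizer in production.
        prompt_tokens = len(request.get('prompt', '').split()) * 1.3

        # Output usually dominates cost; 'max_tokens' is an assumed request
        # field, and the API should cap it so this estimate is bounded.
        max_output_tokens = request.get('max_tokens', 0)

        return (prompt_tokens * self.cost_metrics['input_token']
                + max_output_tokens * self.cost_metrics['output_token'])

    def check_cost_limit(self, api_key, estimated_cost, today=None):
        """Check whether this request would exceed the daily or monthly limit."""

        today = today or date.today()
        limits = self.cost_limits[api_key]
        spent_today = self.daily_costs.get((api_key, today.isoformat()), 0.0)
        spent_month = self.monthly_costs.get((api_key, today.strftime('%Y-%m')), 0.0)

        if spent_today + estimated_cost > limits['daily']:
            return {'allowed': False, 'reason': 'Daily limit exceeded',
                    'current': spent_today, 'limit': limits['daily']}

        if spent_month + estimated_cost > limits['monthly']:
            return {'allowed': False, 'reason': 'Monthly limit exceeded',
                    'current': spent_month, 'limit': limits['monthly']}

        return {'allowed': True}

    def charge_user(self, api_key, cost, today=None):
        """Record the actual cost of a completed request."""

        today = today or date.today()
        self.daily_costs[(api_key, today.isoformat())] += cost
        self.monthly_costs[(api_key, today.strftime('%Y-%m'))] += cost

        # Log for billing
        self.log_usage(api_key, cost)

    def log_usage(self, api_key, cost):
        """Log usage for billing."""

        # In production, write to billing database
        pass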
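Checking the limit and charging afterward leaves a gap: many concurrent requests can each pass `check_cost_limit` before any of them is charged, which is exactly what an attacker trying to run up a bill would exploit. A common fix is to reserve the estimated cost under a lock before calling the model, then settle with the actual cost afterward. `BudgetReservation` is a hypothetical single-key helper for illustration.

```python
import threading


class BudgetReservation:
    """Hypothetical helper: reserve estimated cost up front, settle with the actual cost."""

    def __init__(self, limit: float):
        self.limit = limit
        self.spent = 0.0     # settled cost of finished requests
        self.reserved = 0.0  # estimates held for in-flight requests
        self._lock = threading.Lock()

    def reserve(self, estimate: float) -> bool:
        """Atomically check the limit and hold the estimate if it fits."""
        with self._lock:
            if self.spent + self.reserved + estimate > self.limit:
                return False
            self.reserved += estimate
            return True

    def settle(self, estimate: float, actual: float) -> None:
        """Release the reservation and record what the request really cost."""
        with self._lock:
            self.reserved -= estimate
            self.spent += actual
```

With a budget of 1.0, a first request reserving 0.6 succeeds and a concurrent second one is refused; once the first settles at an actual cost of 0.25, there is room again.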

Monitoring & Logging

import hashlib
from collections import deque
from datetime import datetime


class APIMonitoring:
    def __init__(self, max_logs=10_000):
        self.access_logs = deque(maxlen=max_logs)  # oldest entries drop off
        self.alerts = []

    @staticmethod
    def key_fingerprint(api_key):
        """Stable, non-reversible ID for a key, so logs never hold the key itself."""
        return hashlib.sha256(api_key.encode()).hexdigest()[:12]

    def log_api_request(self, api_key, request, response, duration_ms):
        """Log request metadata (sizes and timing, not contents)."""

        self.access_logs.append({
            'timestamp': datetime.now(),
            'key_id': self.key_fingerprint(api_key),
            'request_size': len(str(request)),
            'response_size': len(str(response)),
            'duration_ms': duration_ms,
            'status': 'success' if response else 'error',
        })

    def detect_anomalies(self, api_key):
        """Alert if the key's latest request is far slower than its recent average."""

        key_id = self.key_fingerprint(api_key)
        # This key's entries among the last 100 log entries overall
        recent = [log for log in list(self.access_logs)[-100:] if log['key_id'] == key_id]

        # Need at least one earlier request as a baseline
        if len(recent) < 2:
            return

        *history, latest = recent
        avg_duration = sum(r['duration_ms'] for r in history) / len(history)

        # If 10x slower than this key's normal, alert
        if latest['duration_ms'] > avg_duration * 10:
            self.alert('Unusual request duration',
                       {'key_id': key_id, 'duration_ms': latest['duration_ms']})

    def alert(self, message, details):
        """Record an alert on suspicious activity."""

        self.alerts.append({
            'timestamp': datetime.now(),
            'message': message,
            'details': details,
        })

        # In production, send real alerts (email, Slack, paging, etc.)
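An in-memory list is enough for the anomaly check but not for an audit trail. One common approach is to emit one JSON object per request through Python's standard `logging` module so a log pipeline can index the fields. In this sketch, `JsonFormatter`, `make_audit_logger`, the `api.audit` logger name, and the field names are illustrative choices; `extra=` and `Formatter.formatTime` are standard `logging` features.

```python
import json
import logging


class JsonFormatter(logging.Formatter):
    """Emit one JSON object per line so log pipelines can parse audit records."""

    AUDIT_FIELDS = ("key_id", "duration_ms", "status")

    def format(self, record: logging.LogRecord) -> str:
        entry = {
            "time": self.formatTime(record),
            "level": record.levelname,
            "event": record.getMessage(),
        }
        # Values passed via `extra=` become attributes on the record
        for field in self.AUDIT_FIELDS:
            if hasattr(record, field):
                entry[field] = getattr(record, field)
        return json.dumps(entry)


def make_audit_logger(stream) -> logging.Logger:
    """Build a logger that writes JSON lines to `stream`."""
    logger = logging.getLogger("api.audit")
    handler = logging.StreamHandler(stream)
    handler.setFormatter(JsonFormatter())
    logger.handlers = [handler]
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep audit records out of the root logger
    return logger
```

Because the formatter copies only whitelisted fields, a raw API key or prompt accidentally passed in `extra=` does not reach the audit log.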

Security Checklist for AI APIs

  • All requests authenticated with valid API key
  • Input validated for size, format, and dangerous patterns
  • Output filtered for sensitive data before returning
  • Rate limiting slows brute-force attempts and caps per-key request volume
  • Per-key cost limits stop attackers from running up your compute bill
  • Abuse detection flags keys that repeat attack patterns
  • Request/response logging for audit trail
  • Monitoring detects anomalies
  • Error messages don’t leak sensitive information
  • HTTPS required; no unencrypted communication
  • API keys rotated regularly
  • Deprecated endpoints removed
  • Public documentation doesn’t reveal internal security controls such as filter rules or thresholds
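The checklist item on error messages usually comes down to one pattern: log the full exception server-side and return only a generic message plus a reference ID the user can quote to support. A minimal sketch using the standard `uuid` and `logging` modules; `safe_error_response` and the response shape are hypothetical.

```python
import logging
import uuid

logger = logging.getLogger("api")


def safe_error_response(exc: Exception) -> dict:
    """Log full details server-side; return only a generic message and a reference ID."""
    error_id = uuid.uuid4().hex
    # The traceback and exception message stay in server logs only
    logger.error("error_id=%s", error_id, exc_info=exc)
    return {"error": "Internal error", "error_id": error_id}
```

Stack traces, file paths, and upstream error text (which can include credentials or prompt fragments) never reach the client, while the `error_id` still lets operators find the matching log entry.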

Key Takeaway

API security for AI requires layered defenses: authentication, input and output validation, rate limiting, cost management, and monitoring. Each layer addresses a different threat: authentication blocks unauthorized access, input validation reduces (but cannot eliminate) prompt injection, output validation catches data leaks, rate limiting curbs abuse, cost management caps expensive exploitation, and monitoring detects attacks that slip past the other layers.
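One way to see how these layers fit together is as an ordered pipeline: cheap checks (authentication, rate limiting, input validation, cost limits) run first, so a rejected request never reaches the expensive model call, and output filtering runs last. The sketch below assumes each layer is wrapped as a function returning `None` to pass or a `(status, message)` tuple to reject; `handle_request` and that convention are hypothetical.

```python
from typing import Callable, List, Optional, Tuple

# A check returns None to let the request through, or (status, message) to reject it.
Check = Callable[[str, dict], Optional[Tuple[int, str]]]


def handle_request(
    api_key: str,
    request: dict,
    checks: List[Check],
    call_model: Callable[[dict], str],
    filter_output: Callable[[str], str],
) -> Tuple[int, str]:
    """Run defensive layers in order; the first rejection short-circuits."""
    for check in checks:
        rejection = check(api_key, request)
        if rejection is not None:
            return rejection  # the model is never called for rejected requests
    return 200, filter_output(call_model(request))
```

Order matters: authenticating first means rate limits and cost limits are tracked per verified key rather than per claimed key, and validating before the model call means malformed or oversized input costs nothing.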

Exercise: Secure an AI API

  1. Design authentication with API keys and permissions
  2. Implement input validation for your API
  3. Add rate limiting to prevent abuse
  4. Build output filtering for sensitive data
  5. Set up cost tracking and limits
  6. Create monitoring to detect abuse
  7. Test defenses against common API attacks

Next Module: AI Supply Chain Security, which covers securing the supply chain of models and dependencies.