Intermediate
API Security for AI Services
API Security for AI Services
Protecting AI Through APIs
When you expose AI capabilities through APIs, you create new attack surfaces. Legitimate users interact through your APIs — and so do attackers. This lesson covers securing AI service APIs.
Core API Security for AI
Authentication & Authorization
class APIAuthentication:
    """Validate API keys and check per-key permissions.

    Key metadata, rate-limit configuration, and permission sets are kept
    in plain dicts keyed by the API key string.
    """

    def __init__(self):
        self.api_keys = {}          # api_key -> user_metadata dict
        self.rate_limits = {}       # api_key -> limit_config
        self.user_permissions = {}  # api_key -> set of permitted actions

    def authenticate_request(self, api_key):
        """Verify *api_key* is known, not revoked, and not expired.

        Returns the key's metadata dict on success.
        Raises AuthenticationError on any failure.
        """
        if api_key not in self.api_keys:
            raise AuthenticationError("Invalid API key")
        metadata = self.api_keys[api_key]
        # Fix: use .get() so keys registered without 'revoked'/'expires_at'
        # fields are treated as active instead of raising KeyError.
        if metadata.get('revoked', False):
            raise AuthenticationError("API key revoked")
        expires_at = metadata.get('expires_at')
        # NOTE(review): naive datetime comparison — assumes expires_at was
        # produced by datetime.now() in the same timezone; confirm callers.
        if expires_at is not None and expires_at < datetime.now():
            raise AuthenticationError("API key expired")
        return metadata

    def authorize_action(self, api_key, action):
        """Check the key holder may perform *action*.

        Returns True when permitted; raises AuthorizationError otherwise.
        Unknown keys have an empty permission set, so they are denied.
        """
        permissions = self.user_permissions.get(api_key, set())
        if action not in permissions:
            raise AuthorizationError(f"Not authorized for {action}")
        return True
Input Validation & Size Limits
class APIInputValidation:
def validate_request(self, request):
"""Validate API request."""
# Size limits
MAX_INPUT_SIZE = 10000 # characters
MAX_REQUEST_SIZE = 1 * 1024 * 1024 # 1 MB
if len(request.get('prompt', '')) > MAX_INPUT_SIZE:
raise ValueError(f"Input exceeds max size {MAX_INPUT_SIZE}")
if len(str(request)) > MAX_REQUEST_SIZE:
raise ValueError(f"Request exceeds max size {MAX_REQUEST_SIZE}")
# Required fields
required_fields = ['prompt']
for field in required_fields:
if field not in request:
raise ValueError(f"Missing required field: {field}")
# Type validation
if not isinstance(request['prompt'], str):
raise ValueError("Prompt must be string")
# Injection detection
if self.is_suspicious(request['prompt']):
return {'valid': False, 'reason': 'Suspicious input detected'}
return {'valid': True}
def is_suspicious(self, text):
"""Detect suspicious patterns."""
patterns = [
r'ignore.*instruction',
r'system.*prompt',
r'new instruction',
]
return any(re.search(p, text, re.IGNORECASE) for p in patterns)
Rate Limiting
from collections import defaultdict
import time
class APIRateLimiter:
    """Sliding-window rate limiter keyed by API key."""

    def __init__(self):
        # api_key -> list of request timestamps (seconds since epoch)
        self.requests = defaultdict(list)
        # api_key -> {'requests': max count, 'window': seconds}; default 100/hour
        self.limits = defaultdict(lambda: {'requests': 100, 'window': 3600})

    def check_rate_limit(self, api_key):
        """Return {'allowed': True}, or a denial dict with reset metadata."""
        now = time.time()
        config = self.limits[api_key]
        window = config['window']
        max_requests = config['requests']

        # Drop timestamps that have aged out of the window.
        recent = [ts for ts in self.requests[api_key] if now - ts < window]
        self.requests[api_key] = recent

        if len(recent) >= max_requests:
            # Window frees up when the oldest surviving request expires.
            reset_time = recent[0] + window
            return {
                'allowed': False,
                'reset_at': reset_time,
                'retry_after': int(reset_time - now),
            }

        # Within quota: record this request and admit it.
        recent.append(now)
        return {'allowed': True}
Output Validation & Sanitization
class APIOutputValidation:
    """Filter model responses for sensitive data and harmful content.

    Depends on externally-provided helpers: contains_harmful_content,
    is_valid_format, and log_security_event (not defined here).
    """

    # Single table shared by detection and redaction so the two can
    # never drift apart again. Each entry: name -> (pattern, replacement).
    _SENSITIVE_PATTERNS = {
        'ssn': (re.compile(r'\d{3}-\d{2}-\d{4}'), '[REDACTED_SSN]'),
        'api_key': (re.compile(r'api[_\s]?key[:\s]+\w{32,}'), '[REDACTED_API_KEY]'),
        'password': (re.compile(r'password[:\s]+\S+'), '[REDACTED_PASSWORD]'),
    }

    def validate_response(self, response):
        """Validate and sanitize *response* before returning it to the caller."""
        # Sensitive-data leakage: log, then redact in place.
        if self.contains_sensitive_data(response):
            self.log_security_event('sensitive_data_in_response', response[:100])
            response = self.redact_sensitive_data(response)
        # Harmful content is never returned, even redacted.
        if self.contains_harmful_content(response):
            self.log_security_event('harmful_content_detected', response[:100])
            return "I cannot provide that response"
        if not self.is_valid_format(response):
            return "Response formatting error"
        return response

    def contains_sensitive_data(self, text):
        """Return True when *text* matches any sensitive-data pattern."""
        return any(p.search(text) for p, _ in self._SENSITIVE_PATTERNS.values())

    def redact_sensitive_data(self, text):
        """Replace every sensitive match with its redaction placeholder.

        Fix: the original redacted SSNs and API keys but left passwords
        intact even though contains_sensitive_data flagged them.
        """
        for pattern, replacement in self._SENSITIVE_PATTERNS.values():
            text = pattern.sub(replacement, text)
        return text
Abuse Detection
class AbuseDetection:
    """Flag API keys that repeat the same request pattern too often."""

    def __init__(self):
        # api_key -> list of (pattern, timestamp) tuples
        self.patterns = defaultdict(list)
        # More than this many identical patterns within an hour = abuse.
        self.abuse_threshold = 5

    def detect_abuse(self, api_key, request):
        """Record *request* and report whether the key looks abusive.

        Returns {'abuse_detected': False} or a dict with the recommended
        action and a human-readable reason.
        """
        now = time.time()
        # Fix: the original appended a bare pattern string, which crashed
        # the (p, t) tuple-unpacking below on the next call; store the
        # (pattern, timestamp) tuple the pruning step expects.
        self.patterns[api_key].append((self.extract_pattern(request), now))

        # Keep only patterns observed within the last hour.
        self.patterns[api_key] = [
            (p, t) for p, t in self.patterns[api_key]
            if now - t < 3600
        ]

        # Count occurrences of each pattern in the window.
        pattern_counts = defaultdict(int)
        for pattern, _ in self.patterns[api_key]:
            pattern_counts[pattern] += 1
        max_count = max(pattern_counts.values()) if pattern_counts else 0

        if max_count > self.abuse_threshold:
            return {
                'abuse_detected': True,
                'action': 'block_api_key',
                'reason': f'Repeated pattern detected {max_count} times'
            }
        return {'abuse_detected': False}

    def extract_pattern(self, request):
        """Map a request to a coarse pattern label for repetition counting."""
        prompt = request.get('prompt', '')
        if 'ignore' in prompt.lower():
            return 'injection_attempt'
        if len(prompt) > 5000:
            return 'excessive_input'
        return 'normal'
Cost Management
class APICostManagement:
    """Track per-key spend and enforce cost limits."""

    def __init__(self):
        self.user_costs = defaultdict(float)  # api_key -> accumulated cost ($)
        self.cost_limits = defaultdict(lambda: {'daily': 100, 'monthly': 1000})
        # Unit prices per billable event.
        self.cost_metrics = {
            'input_token': 0.000001,   # $1 per 1M tokens
            'output_token': 0.000002,
            'function_call': 0.01,
        }

    def estimate_cost(self, request):
        """Rough input-side cost estimate (~1.3 tokens per whitespace word)."""
        prompt_tokens = len(request.get('prompt', '').split()) * 1.3
        return prompt_tokens * self.cost_metrics['input_token']

    def check_cost_limit(self, api_key, estimated_cost):
        """Check whether adding *estimated_cost* would exceed a limit.

        Fix: the original configured a monthly limit but only ever enforced
        the daily one; both are checked now (daily first, preserving the
        original denial message for default limits).
        NOTE(review): user_costs is a single running total that is never
        reset per day/month — period rollover must happen elsewhere.
        """
        limits = self.cost_limits[api_key]
        current_cost = self.user_costs[api_key]
        projected = current_cost + estimated_cost
        if projected > limits['daily']:
            return {
                'allowed': False,
                'reason': 'Daily limit exceeded',
                'current': current_cost,
                'limit': limits['daily']
            }
        if projected > limits['monthly']:
            return {
                'allowed': False,
                'reason': 'Monthly limit exceeded',
                'current': current_cost,
                'limit': limits['monthly']
            }
        return {'allowed': True}

    def charge_user(self, api_key, cost):
        """Accumulate *cost* against the key and record it for billing."""
        self.user_costs[api_key] += cost
        self.log_usage(api_key, cost)

    def log_usage(self, api_key, cost):
        """Record usage for billing (stub; production writes to a database)."""
        pass
Monitoring & Logging
class APIMonitoring:
    """Collect access logs and raise alerts on anomalous usage."""

    def __init__(self):
        self.access_logs = []  # per-request log dicts, append-only
        self.alerts = []       # alert dicts recorded by alert()

    def log_api_request(self, api_key, request, response, duration):
        """Append one access-log entry; only a key prefix is ever stored."""
        entry = {
            'timestamp': datetime.now(),
            'api_key': api_key[:4] + '***',  # never log the full key
            'request_size': len(str(request)),
            'response_size': len(str(response)),
            'duration_ms': duration,
            'status': 'success' if response else 'error',
        }
        self.access_logs.append(entry)

    def detect_anomalies(self, api_key):
        """Alert when the latest request is far slower than this key's norm.

        NOTE(review): matching is by 4-character key prefix, so distinct
        keys sharing a prefix are aggregated together.
        """
        prefix = api_key[:4]
        recent = [
            entry for entry in self.access_logs[-100:]
            if entry['api_key'].startswith(prefix)
        ]
        if not recent:
            return
        avg_duration = sum(entry['duration_ms'] for entry in recent) / len(recent)
        latest_duration = recent[-1]['duration_ms']
        # 10x slower than the rolling average counts as anomalous.
        if latest_duration > avg_duration * 10:
            self.alert('Unusual request duration',
                       {'api_key': api_key, 'duration': latest_duration})

    def alert(self, message, details):
        """Record an alert (production would notify: email, Slack, etc.)."""
        self.alerts.append({
            'timestamp': datetime.now(),
            'message': message,
            'details': details,
        })
Security Checklist for AI APIs
- All requests authenticated with valid API key
- Input validated for size, format, dangerous patterns
- Output filtered for sensitive data before returning
- Rate limiting prevents brute force and abuse
- Cost tracking prevents expensive attacks
- Abuse detection identifies attackers
- Request/response logging for audit trail
- Monitoring detects anomalies
- Error messages don’t leak sensitive information
- HTTPS required; no unencrypted communication
- API keys rotated regularly
- Deprecated endpoints removed
- Documentation doesn’t reveal security details
Key Takeaway
Key Takeaway: API security for AI requires authentication, input/output validation, rate limiting, cost management, and monitoring. Each layer prevents different attacks: auth prevents unauthorized access, input validation prevents injection, rate limiting prevents abuse, cost management prevents expensive exploitation, and monitoring detects attacks.
Exercise: Secure an AI API
- Design authentication with API keys and permissions
- Implement input validation for your API
- Add rate limiting to prevent abuse
- Build output filtering for sensitive data
- Set up cost tracking and limits
- Create monitoring to detect abuse
- Test defenses against common API attacks
Next Module: AI Supply Chain Security—securing the supply chain of models and dependencies.