Securing the AI Data Pipeline
Securing the AI Data Pipeline
End-to-End Data Protection
Data security in AI systems requires protection at every stage: ingestion, processing, storage, and retrieval. This lesson covers securing the complete data pipeline.
Data Pipeline Architecture
┌─────────────────┐
│ Data Sources │
│ (APIs, DBs, │
│ Files, Streams)│
└────────┬────────┘
│
▼
┌─────────────────┐ ┌──────────────────┐
│ Data Ingestion │──────▶│ Access Control │
│ (Validation, │ │ (Auth, Audit) │
│ Filtering) │ └──────────────────┘
└────────┬────────┘
│
▼
┌─────────────────┐ ┌──────────────────┐
│ Processing │──────▶│ Encryption │
│ (Transform, │ │ (in transit) │
│ Deduplicate) │ └──────────────────┘
└────────┬────────┘
│
▼
┌─────────────────┐ ┌──────────────────┐
│ Storage │──────▶│ Encryption │
│ (DB, Vector DB, │ │ (at rest) │
│ Cache) │ └──────────────────┘
└────────┬────────┘
│
▼
┌─────────────────┐ ┌──────────────────┐
│ Retrieval │──────▶│ Access Control │
│ (RAG, Context) │ │ (Row-level) │
└────────┬────────┘ └──────────────────┘
│
▼
┌─────────────────┐
│ Model Access │
│ (Inference) │
└─────────────────┘
Stage 1: Data Ingestion
Validation
Reject invalid or suspicious data:
class DataValidator:
    """Validate ingested records against a simple field-type schema.

    Each schema rule is either a regex pattern (matched against the string
    form of the value) or a callable returning a bool. Per-field size limits
    guard against oversized payloads at ingestion time.
    """

    def __init__(self):
        # Regex rules for string-shaped fields; 'amount' uses a callable
        # range check instead of a pattern.
        self.schema_rules = {
            'email': r'^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$',
            'phone': r'^\+?1?\d{10}$',
            'date': r'^\d{4}-\d{2}-\d{2}$',
            'amount': lambda x: isinstance(x, (int, float)) and 0 <= x <= 1000000,
        }
        # Maximum length of the stringified value, per field name.
        self.size_limits = {
            'name': 100,
            'description': 5000,
            'file': 100 * 1024 * 1024,  # 100MB
        }

    def validate(self, data, schema):
        """Validate data against schema.

        Args:
            data: mapping of field name -> value.
            schema: mapping of field name -> type key in ``schema_rules``.

        Returns:
            dict with ``is_valid`` (bool) and ``errors`` (list of messages).
        """
        errors = []
        for field, value in data.items():
            if field not in schema:
                errors.append(f"Unknown field: {field}")
                continue
            field_type = schema[field]
            # Check type: apply whichever rule the schema names. Previously
            # only 'email' was enforced; phone/date/amount rules were declared
            # but never applied.
            rule = self.schema_rules.get(field_type)
            if callable(rule):
                if not rule(value):
                    errors.append(f"Invalid {field_type}: {field}")
            elif rule is not None:
                if not re.match(rule, str(value)):
                    errors.append(f"Invalid {field_type}: {field}")
            # Check size limits
            if field in self.size_limits:
                if len(str(value)) > self.size_limits[field]:
                    errors.append(f"Field too large: {field}")
        return {
            'is_valid': len(errors) == 0,
            'errors': errors
        }
Rate Limiting
Prevent denial-of-service (DoS) through excessive data ingestion:
from collections import defaultdict
import time
class IngestionRateLimiter:
    """Sliding-window (60s) rate limiter for per-source ingestion."""

    def __init__(self, max_records_per_minute=1000):
        # Per-source list of accepted-record timestamps within the window.
        self.max_records = max_records_per_minute
        self.ingestion_counts = defaultdict(list)

    def should_accept(self, source_id):
        """Return True (and record the event) if the source is under its limit."""
        now = time.time()
        cutoff = now - 60
        # Drop timestamps that have aged out of the one-minute window.
        recent = [stamp for stamp in self.ingestion_counts[source_id] if stamp > cutoff]
        self.ingestion_counts[source_id] = recent
        if len(recent) >= self.max_records:
            return False
        recent.append(now)
        return True
Stage 2: Processing & Transformation
PII Redaction During Processing
Remove sensitive data during transformation:
class PII_RedactingProcessor:
    """Transformation step that strips high-confidence PII from records."""

    def __init__(self):
        # Detector combining multiple PII-detection strategies (defined elsewhere).
        self.detector = HybridPIIDetector()

    def process(self, raw_data):
        """Process data with automatic PII redaction."""
        cleaned = {}
        for key, original in raw_data.items():
            # Detect PII in the stringified value; redact only when the
            # detector reports high-confidence findings.
            findings = self.detector.detect(str(original))
            if findings['high_confidence']:
                cleaned[key] = self.redact(original, findings)
            else:
                cleaned[key] = original
        return cleaned

    def redact(self, value, pii_findings):
        """Replace each high-confidence PII instance with a [REDACTED_<type>] tag."""
        text = str(value)
        for kind, occurrences in pii_findings['high_confidence'].items():
            for occurrence in occurrences:
                text = text.replace(occurrence, f'[REDACTED_{kind}]')
        return text
Deduplication
Remove duplicates that increase memorization risk:
import hashlib
def deduplicate_during_processing(data_batch):
    """Drop exact duplicate records from a batch, preserving first-seen order.

    Records are fingerprinted by SHA-256 over their canonical (sorted-key)
    JSON form, so two dicts with the same content but different key order
    count as the same record. Duplicate training data raises memorization
    risk, hence only the first occurrence is kept.
    """
    unique_records = []
    fingerprints = set()
    for item in data_batch:
        digest = hashlib.sha256(
            json.dumps(item, sort_keys=True).encode()
        ).hexdigest()
        if digest in fingerprints:
            continue  # exact duplicate — skip
        fingerprints.add(digest)
        unique_records.append(item)
    return unique_records
Stage 3: Storage
Encryption at Rest
Encrypt data in databases and storage:
from cryptography.fernet import Fernet
class EncryptedDataStore:
    """Wrap the backing database with symmetric encryption at rest."""

    def __init__(self, encryption_key):
        # Fernet provides authenticated symmetric encryption.
        self.cipher = Fernet(encryption_key)

    def store(self, data_id, data):
        """Serialize *data* to JSON, encrypt it, and persist under *data_id*."""
        ciphertext = self.cipher.encrypt(json.dumps(data).encode())
        db.store(data_id, ciphertext)

    def retrieve(self, data_id):
        """Fetch the ciphertext for *data_id*, decrypt it, and deserialize."""
        ciphertext = db.retrieve(data_id)
        plaintext = self.cipher.decrypt(ciphertext)
        return json.loads(plaintext)
Access Control in Databases
Implement row-level security:
class AccessControlledStore:
    """Row-level access control: users may read only data they were granted."""

    def __init__(self):
        # Maps user_id -> set of data_ids that user may read.
        self.permissions = {}

    def get_accessible_data(self, user_id, data_id):
        """Return the record, or raise PermissionError if access is denied."""
        if self.has_permission(user_id, data_id):
            return self.retrieve(data_id)
        raise PermissionError(f"User {user_id} cannot access {data_id}")

    def has_permission(self, user_id, data_id):
        """True iff *user_id* has been granted access to *data_id*."""
        granted = self.permissions.get(user_id)
        return granted is not None and data_id in granted

    def grant_permission(self, user_id, data_id):
        """Allow *user_id* to read *data_id*."""
        self.permissions.setdefault(user_id, set()).add(data_id)

    def revoke_permission(self, user_id, data_id):
        """Withdraw *user_id*'s access to *data_id* (no-op if never granted)."""
        granted = self.permissions.get(user_id)
        if granted is not None:
            granted.discard(data_id)
Stage 4: Vector Database Security
If you use embeddings stored in a vector database for retrieval-augmented generation (RAG), secure the vector store itself:
class SecureVectorDB:
    """Embedding store for RAG that filters retrieval results through
    access control and audit-logs each permitted read."""

    def __init__(self):
        self.vector_db = VectorDatabase()
        # data_id -> {'source': origin document, 'timestamp': stored-at,
        #             'access_log': list of {'user', 'timestamp'} entries}
        self.source_metadata = {}
        # NOTE(review): can_access() reads self.access_control, but it is
        # never assigned in this class — confirm it is injected elsewhere,
        # otherwise can_access() raises AttributeError.

    def store_embedding(self, data_id, embedding, source_document):
        """Store embedding with metadata."""
        # Store embedding
        self.vector_db.add(data_id, embedding)
        # Store metadata (what document this came from)
        self.source_metadata[data_id] = {
            'source': source_document,
            'timestamp': datetime.now(),  # presumably datetime is imported at module level — verify
            'access_log': []
        }

    def retrieve_similar(self, query_embedding, user_id):
        """Retrieve similar embeddings with access control."""
        results = self.vector_db.query(query_embedding, top_k=5)
        # Filter by access control: keep only hits whose source document the
        # requesting user may read, and audit-log each permitted access.
        accessible_results = []
        for result_id, similarity_score in results:
            if self.can_access(user_id, result_id):
                accessible_results.append((result_id, similarity_score))
                self.log_access(result_id, user_id)
        return accessible_results

    def can_access(self, user_id, data_id):
        """Check if user can access this embedding's source."""
        # Permission is decided on the ORIGIN document, not the embedding id,
        # so chunk-level embeddings inherit their document's ACL.
        source = self.source_metadata[data_id]['source']
        return self.access_control.can_user_access(user_id, source)

    def log_access(self, data_id, user_id):
        """Audit access to embeddings."""
        self.source_metadata[data_id]['access_log'].append({
            'user': user_id,
            'timestamp': datetime.now()
        })
Stage 5: Encryption in Transit
Protect data moving through pipelines:
import ssl
import requests
class SecureDataTransport:
    """Helpers that keep data protected while moving between services."""

    @staticmethod
    def send_to_llm_api(data, api_key):
        """Send data securely to LLM API."""
        # HTTPS session with certificate verification enforced.
        session = requests.Session()
        session.verify = True
        # Field-level encryption on top of TLS; key id lets the receiver
        # pick the matching decryption key.
        payload = {
            'encrypted_payload': encrypt(data),
            'key_id': 'key_v1'
        }
        return session.post(
            'https://api.llm.example.com/generate',
            json=payload,
            headers={'Authorization': f'Bearer {api_key}'}
        )

    @staticmethod
    def receive_from_database(query):
        """Retrieve data securely from database."""
        # Require an SSL/TLS connection so the result set is encrypted in transit.
        connection = db.connect(
            host='db.example.com',
            ssl_mode='require',
            ssl_cert='path/to/cert.pem'
        )
        return connection.execute(query)
Data Retention and Deletion
Implement data lifecycle policies:
class DataLifecycleManager:
    """Enforce per-type retention windows and delete expired records."""

    def __init__(self):
        # Retention in days per data type; unknown types fall back to 365
        # inside should_delete().
        self.retention_policies = {
            'training_data': 365 * 5,  # 5 years
            'logs': 90,                # 90 days
            'user_queries': 30,        # 30 days
            'context_windows': 1,      # 1 day
        }

    def should_delete(self, data_type, creation_date):
        """True when the record's age exceeds its retention window."""
        allowed_days = self.retention_policies.get(data_type, 365)
        age = (datetime.now() - creation_date).days
        return age > allowed_days

    def delete_expired_data(self):
        """Sweep every stored record and securely delete the expired ones."""
        for data_type, records in self.get_all_data().items():
            for record_id, created in records:
                if self.should_delete(data_type, created):
                    self.securely_delete(record_id)

    def securely_delete(self, data_id):
        """Delete data securely (cryptographic erasure)."""
        # Method 1: Overwrite (for traditional storage)
        # Overwrite with random data multiple times
        # Method 2: Cryptographic erasure (for encrypted data)
        # Delete the encryption key instead of data
        # Method 3: Hardware destruction (for sensitive data)
        log_audit_event('data_deleted', {'data_id': data_id})
Key Takeaway
Key Takeaway: Data security in AI systems requires protection at every stage of the pipeline: ingestion (validation, rate limiting), processing (PII redaction, deduplication), storage (encryption at rest, access control), retrieval (access control, logging), and transit (encryption in transit). Defense in depth at each stage is essential.
Exercise: Secure Your Data Pipeline
- Map your current pipeline: Document each stage
- Identify vulnerabilities: Where is data exposed?
- Implement controls: For each stage, implement at least one security control
- Test data protection: Verify encryption is working, access control is enforced
- Document policies: Create data retention and deletion policies
Next Lesson: Privacy-Preserving AI Techniques—advanced methods for secure AI.