Multi-Modal and Multi-Index RAG
Traditional RAG systems process only text. Multi-modal RAG extends retrieval to images, tables, charts, and structured data, enabling question-answering over rich documents. Knowledge graphs complement this by capturing relationships between entities, supporting more structured reasoning.
Multi-Modal Embeddings
Multi-modal embeddings represent text and images in the same vector space, enabling cross-modal retrieval.
from typing import Union, List, Tuple
from dataclasses import dataclass
import numpy as np
from PIL import Image
@dataclass
class MultiModalDocument:
    """Document containing multiple modalities.

    A single logical document whose content may span free text plus
    optional images and tables; embedded per-modality by
    MultiModalEmbedder.embed_document.
    """
    doc_id: str  # unique document identifier
    text: str  # main body text (always present)
    # None (not []) marks "no images"; callers test truthiness, so both work.
    images: Union[List[Image.Image], None] = None
    # Each table is a dict; downstream code reads "headers" and "rows" keys.
    tables: Union[List[dict], None] = None
    metadata: Union[dict, None] = None  # arbitrary extra properties
class MultiModalEmbedder:
    """Generate embeddings for text and images in a shared vector space.

    In production this wraps a CLIP model (see commented code); here the
    model calls are mocked. Bug fix: the mock previously returned a fresh
    random vector on every call, so embedding the same text twice never
    matched — indexing then querying identical content could not retrieve
    it. The mock is now a deterministic function of its input.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        self.model_name = model_name
        # In production: from transformers import CLIPProcessor, CLIPModel
        # self.processor = CLIPProcessor.from_pretrained(model_name)
        # self.model = CLIPModel.from_pretrained(model_name)

    def embed_text(self, text: str) -> np.ndarray:
        """Generate a 512-d embedding for text."""
        # In production:
        # inputs = self.processor(text=text, return_tensors="pt")
        # text_features = self.model.get_text_features(**inputs)
        # return text_features[0].detach().numpy()
        return self._mock_embedding(text.encode("utf-8"))

    def embed_image(self, image: Image.Image) -> np.ndarray:
        """Generate a 512-d embedding for an image."""
        # In production:
        # inputs = self.processor(images=image, return_tensors="pt")
        # image_features = self.model.get_image_features(**inputs)
        # return image_features[0].detach().numpy()
        return self._mock_embedding(image.tobytes())

    def embed_document(self, doc: MultiModalDocument) -> dict:
        """Generate embeddings for all modalities of `doc`.

        Returns {"text": ndarray, "images": [ndarray, ...],
        "tables": [ndarray, ...]}; tables are embedded via their text
        rendering.
        """
        embeddings = {
            "text": self.embed_text(doc.text),
            "images": [],
            "tables": []
        }
        if doc.images:
            embeddings["images"] = [self.embed_image(img) for img in doc.images]
        if doc.tables:
            embeddings["tables"] = [
                self.embed_text(self._table_to_text(table)) for table in doc.tables
            ]
        return embeddings

    def _table_to_text(self, table: dict) -> str:
        """Render a table as pipe-separated lines for embedding.

        Fix: include the header row when present — column names carry
        most of a table's meaning for retrieval; the original embedded
        cell values only. Rows are assumed to be dicts whose values are
        joined in insertion order — TODO confirm against callers.
        """
        lines = []
        if table.get("headers"):
            lines.append(" | ".join(str(h) for h in table["headers"]))
        for row in table.get("rows", []):
            lines.append(" | ".join(str(v) for v in row.values()))
        return "\n".join(lines)

    @staticmethod
    def _mock_embedding(data: bytes) -> np.ndarray:
        """Deterministic mock embedding: seed an RNG from a content hash.

        Uses sha256 (not hash()) so results are stable across processes
        regardless of PYTHONHASHSEED.
        """
        import hashlib
        seed = int.from_bytes(hashlib.sha256(data).digest()[:8], "big")
        return np.random.default_rng(seed).standard_normal(512)
Multi-Index Retrieval
Use separate indexes for different modalities and retrieve from all simultaneously.
from typing import Optional, Dict
import faiss
class MultiModalIndex:
    """Maintain one FAISS index per modality and search across them.

    Positions within each per-modality index are mapped back to document
    ids via `doc_mapping`, keyed "<modality>_<position>".

    Bug fix: the original `search` listed "table" in the default
    modalities but never queried the table index; all three modalities
    are now searched through one shared code path.
    """

    def __init__(self, embedding_dim: int = 512):
        self.embedding_dim = embedding_dim
        self.text_index = faiss.IndexFlatL2(embedding_dim)
        self.image_index = faiss.IndexFlatL2(embedding_dim)
        self.table_index = faiss.IndexFlatL2(embedding_dim)
        self.doc_mapping = {}  # "<modality>_<pos>" -> (doc_id, source, metadata)
        self.position_counter = {"text": 0, "image": 0, "table": 0}

    def _get_index(self, modality: str):
        """Return the FAISS index for `modality`, or None if unknown."""
        return {
            "text": self.text_index,
            "image": self.image_index,
            "table": self.table_index,
        }.get(modality)

    def _add_embedding(
        self,
        modality: str,
        embedding: np.ndarray,
        doc_id: str,
        source: str,
        metadata: dict
    ):
        """Append one embedding to a modality index and record its origin."""
        self._get_index(modality).add(embedding.reshape(1, -1).astype('float32'))
        pos = self.position_counter[modality]
        self.doc_mapping[f"{modality}_{pos}"] = (doc_id, source, metadata)
        self.position_counter[modality] += 1

    def add_document(
        self,
        doc_id: str,
        embeddings: dict,
        metadata: dict = None
    ):
        """Add document embeddings to the per-modality indexes.

        `embeddings` uses the layout produced by
        MultiModalEmbedder.embed_document: a single vector under "text"
        and lists of vectors under "images" / "tables".
        """
        if "text" in embeddings:
            self._add_embedding("text", embeddings["text"], doc_id, "text", metadata)
        if "images" in embeddings:
            for i, img_embedding in enumerate(embeddings["images"]):
                self._add_embedding("image", img_embedding, doc_id, f"image_{i}", metadata)
        if "tables" in embeddings:
            for i, table_embedding in enumerate(embeddings["tables"]):
                self._add_embedding("table", table_embedding, doc_id, f"table_{i}", metadata)

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
        modalities: List[str] = None
    ) -> List[Tuple[str, str, float]]:
        """Search the requested modalities; returns (doc_id, source, distance)
        tuples sorted by ascending L2 distance (lower = closer), top k overall.
        """
        if modalities is None:
            modalities = ["text", "image", "table"]
        query = query_embedding.reshape(1, -1).astype('float32')
        results = []
        for modality in modalities:
            index = self._get_index(modality)
            if index is None or index.ntotal == 0:
                continue  # unknown modality name, or nothing indexed yet
            distances, indices = index.search(query, k)
            for dist, idx in zip(distances[0], indices[0]):
                key = f"{modality}_{idx}"
                # faiss pads with index -1 when fewer than k vectors exist;
                # those keys are absent from doc_mapping and get skipped.
                if key in self.doc_mapping:
                    doc_id, source, metadata = self.doc_mapping[key]
                    results.append((doc_id, source, float(dist)))
        results.sort(key=lambda x: x[2])
        return results[:k]
Knowledge Graphs
Knowledge graphs represent entities and relationships, enabling structured reasoning.
from typing import Set, Optional
@dataclass
class Entity:
    """Knowledge graph entity (a node)."""
    id: str  # unique identifier; key into KnowledgeGraph.entities
    name: str  # human-readable name; GraphRAG matches question text against it
    entity_type: str  # category label shown alongside the name in built context
    attributes: Optional[dict] = None  # arbitrary extra properties
@dataclass
class Relation:
    """Directed relationship (edge) between two entities."""
    source_id: str  # id of the entity the edge leaves
    target_id: str  # id of the entity the edge points to
    relation_type: str  # edge label, e.g. used to group/filter traversals
    weight: float = 1.0  # edge weight; not used by the traversal code shown here
    attributes: Optional[dict] = None  # arbitrary extra properties
class KnowledgeGraph:
    """Graph-based knowledge representation.

    Entities are nodes keyed by id; Relations are directed edges.
    Improvement: the original built `relation_index` but never used it,
    and every traversal scanned the full relation list per node (O(E)
    each). An adjacency map is now maintained in `add_relation` so
    traversals touch only each node's outgoing edges; traversal
    semantics and result ordering are unchanged.
    """

    def __init__(self):
        self.entities: Dict[str, Entity] = {}
        self.relations: List[Relation] = []
        # "<source_id>_<relation_type>" -> set of target ids (kept for compat)
        self.relation_index: Dict[str, Set[str]] = {}
        # source_id -> outgoing relations, in insertion order
        self._adjacency: Dict[str, List[Relation]] = {}

    def add_entity(self, entity: Entity):
        """Add entity to graph (replaces any entity with the same id)."""
        self.entities[entity.id] = entity

    def add_relation(self, relation: Relation):
        """Add a directed relationship and index it for fast lookup."""
        self.relations.append(relation)
        self._adjacency.setdefault(relation.source_id, []).append(relation)
        key = f"{relation.source_id}_{relation.relation_type}"
        self.relation_index.setdefault(key, set()).add(relation.target_id)

    def _outgoing(self, source_id: str) -> List[Relation]:
        """All relations leaving `source_id` (empty list if none)."""
        return self._adjacency.get(source_id, [])

    def get_neighbors(
        self,
        entity_id: str,
        relation_type: Optional[str] = None,
        distance: int = 1
    ) -> Dict[str, Set[str]]:
        """Get entities reachable within `distance` hops, grouped by relation type.

        Matches the original semantics: each hop re-expands from every
        entity found so far, and the relation_type filter (when given)
        applies at every hop.
        """
        neighbors: Dict[str, Set[str]] = {}
        for hop in range(distance):
            if hop == 0:
                sources = {entity_id}
            else:
                sources = set()
                for found in neighbors.values():
                    sources.update(found)
            for source in sources:
                for rel in self._outgoing(source):
                    if relation_type is None or rel.relation_type == relation_type:
                        neighbors.setdefault(rel.relation_type, set()).add(rel.target_id)
        return neighbors

    def find_path(
        self,
        start_id: str,
        end_id: str,
        max_depth: int = 3
    ) -> Optional[List[str]]:
        """Find a shortest path (list of entity ids) via BFS.

        Returns None when no path of at most `max_depth` nodes exists.
        """
        from collections import deque
        queue = deque([(start_id, [start_id])])
        visited = {start_id}
        while queue:
            current_id, path = queue.popleft()
            if current_id == end_id:
                return path
            if len(path) >= max_depth:
                continue  # would exceed depth budget; don't expand further
            for rel in self._outgoing(current_id):
                if rel.target_id not in visited:
                    visited.add(rel.target_id)
                    queue.append((rel.target_id, path + [rel.target_id]))
        return None

    def find_related_entities(
        self,
        entity_id: str,
        relation_types: List[str]
    ) -> Dict[str, List[str]]:
        """Map each requested relation type to one-hop targets (insertion order)."""
        result: Dict[str, List[str]] = {rt: [] for rt in relation_types}
        for rel in self._outgoing(entity_id):
            if rel.relation_type in result:
                result[rel.relation_type].append(rel.target_id)
        return result
GraphRAG Integration
Combine graph-based retrieval with LLM for structured reasoning.
import anthropic
class GraphRAG:
    """RAG system with knowledge graph reasoning.

    Combines vector retrieval over a MultiModalIndex with entity
    expansion over a KnowledgeGraph, then asks an LLM to answer from the
    assembled context.
    """

    def __init__(
        self,
        knowledge_graph: KnowledgeGraph,
        embedder: MultiModalEmbedder,
        index: MultiModalIndex
    ):
        self.kg = knowledge_graph
        self.embedder = embedder
        self.index = index
        self.client = anthropic.Anthropic()

    def answer_question(self, question: str) -> str:
        """Answer `question` using graph and semantic retrieval.

        Pipeline: extract entity mentions -> vector search (top 3) ->
        expand mentioned entities two hops through the graph -> build a
        context string -> generate the answer with the LLM.
        """
        # Extract entities from question
        entities = self._extract_entities(question)
        # Retrieve relevant documents
        query_embedding = self.embedder.embed_text(question)
        doc_results = self.index.search(query_embedding, k=3)
        # Expand with graph neighbors (2 hops)
        context_entities = set(entities)
        for entity_id in entities:
            neighbors = self.kg.get_neighbors(entity_id, distance=2)
            for rel_type, neighbor_ids in neighbors.items():
                context_entities.update(neighbor_ids)
        # Build context from entities and documents
        context = self._build_context(context_entities, doc_results)
        # Generate response with LLM
        response = self.client.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=1024,
            system="You are an expert assistant answering questions using provided context and knowledge graph information.",
            messages=[{
                "role": "user",
                "content": f"Question: {question}\n\nContext:\n{context}"
            }]
        )
        return response.content[0].text

    def _extract_entities(self, text: str) -> Set[str]:
        """Extract entity mentions via case-insensitive substring match.

        In production, replace with a proper NER model or entity linker;
        substring matching over-matches short entity names.
        """
        entity_ids = set()
        lowered = text.lower()  # hoisted: lowercase the question once
        for entity_id, entity in self.kg.entities.items():
            if entity.name.lower() in lowered:
                entity_ids.add(entity_id)
        return entity_ids

    def _build_context(
        self,
        entity_ids: Set[str],
        doc_results: List[Tuple[str, str, float]]
    ) -> str:
        """Build the context string from entities and retrieved documents."""
        context_parts = []
        # Add entity information (ids not present in the graph are skipped)
        context_parts.append("Entities:")
        for entity_id in entity_ids:
            if entity_id in self.kg.entities:
                entity = self.kg.entities[entity_id]
                context_parts.append(f" - {entity.name} ({entity.entity_type})")
        # Add retrieved documents
        context_parts.append("\nRelevant Documents:")
        for doc_id, source, distance in doc_results:
            # Bug fix: MultiModalIndex returns L2 distances; the original
            # reported "1 - distance" as a similarity, which is not a
            # similarity and can be negative. Report the raw distance.
            context_parts.append(f" - {doc_id}: {source} (distance: {distance:.2f})")
        return "\n".join(context_parts)
Table and Structured Data Processing
Extract meaning from tables and structured data.
from typing import Any
class TableProcessor:
    """Process and summarize table data.

    Tables are dicts with a "headers" list and a "rows" list of lists
    (row cells aligned positionally with headers).

    Fixes vs. the original: columns containing only None were summarized
    as "numeric column with 0 values" (the `all()` over an empty
    generator is vacuously True) — they are now skipped; and bools (a
    subclass of int) were counted as numeric — they are now treated as
    categorical and excluded from numeric statistics.
    """

    def extract_columns(self, table: dict) -> dict[str, list]:
        """Extract columns from table; short rows are padded with None.

        Returns {} when the table has no "headers" key.
        """
        columns: dict[str, list] = {}
        if "headers" in table:
            headers = table["headers"]
            rows = table.get("rows", [])
            for i, header in enumerate(headers):
                columns[header] = [row[i] if i < len(row) else None for row in rows]
        return columns

    @staticmethod
    def _is_number(value) -> bool:
        """True for int/float but not bool (bool is a subclass of int)."""
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def summarize_table(self, table: dict) -> str:
        """Generate a one-line text summary of the table's columns."""
        columns = self.extract_columns(table)
        summaries = []
        for column_name, values in columns.items():
            non_null = [v for v in values if v is not None]
            if not non_null:
                continue  # fix: skip empty and all-None columns entirely
            if all(self._is_number(v) for v in non_null):
                summary = f"{column_name}: numeric column with {len(non_null)} values"
            else:
                summary = f"{column_name}: categorical with {len(set(non_null))} unique values"
            summaries.append(summary)
        return f"Table structure: {', '.join(summaries)}"

    def find_numeric_patterns(self, table: dict) -> dict:
        """Compute min/max/avg/count for each column's numeric values.

        Columns with no numeric values are omitted from the result.
        """
        columns = self.extract_columns(table)
        patterns = {}
        for col_name, values in columns.items():
            numeric_values = [v for v in values if self._is_number(v)]
            if numeric_values:
                patterns[col_name] = {
                    "min": min(numeric_values),
                    "max": max(numeric_values),
                    "avg": sum(numeric_values) / len(numeric_values),
                    "count": len(numeric_values)
                }
        return patterns
Key Takeaway
Multi-modal RAG processes text, images, and tables using shared embeddings. Knowledge graphs capture relationships enabling structured reasoning. Combined, they enable sophisticated question-answering over rich, complex documents.
Exercises
- Implement multi-modal embeddings for text and images
- Build separate indexes for text, image, and table modalities
- Create knowledge graph with 50+ entities and relation types
- Implement graph traversal for multi-hop reasoning
- Combine graph and semantic retrieval for answer synthesis
- Process tables with column type detection and summarization
- Build end-to-end QA system over mixed-media documents