Advanced

Multi-Modal and Multi-Index RAG

Lesson 1 of 4 — Estimated Time: 55 min

Multi-Modal and Multi-Index RAG

Traditional RAG systems process only text. Multi-modal RAG extends retrieval to images, tables, charts, and structured data, enabling question-answering over rich documents. Knowledge graphs complement this by capturing complex relationships between entities, which supports deeper reasoning.

Multi-Modal Embeddings

Multi-modal embeddings represent text and images in the same vector space, enabling cross-modal retrieval.

from typing import Union, List, Tuple
from dataclasses import dataclass
import numpy as np
from PIL import Image

@dataclass
class MultiModalDocument:
    """Document containing multiple modalities.

    A value of None for images/tables/metadata means the modality is
    absent. None is used (rather than a mutable default such as []) so
    the dataclass defaults are safe to share across instances.
    """
    doc_id: str  # unique document identifier
    text: str  # main body text (always present)
    images: List[Image.Image] = None  # optional images; None when absent
    tables: List[dict] = None  # optional tables, each a dict with "rows" (and optionally "headers")
    metadata: dict = None  # optional free-form metadata

class MultiModalEmbedder:
    """Generate embeddings for text and images in a shared vector space.

    A CLIP-style model maps both modalities into the same 512-d space so
    text queries can retrieve images and vice versa. The model calls are
    mocked here; the production code path is shown in comments.
    """

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """Store the model name; in production, load processor and model.

        Args:
            model_name: Hugging Face id of the CLIP checkpoint to use.
        """
        self.model_name = model_name
        # In production: from transformers import CLIPProcessor, CLIPModel
        # self.processor = CLIPProcessor.from_pretrained(model_name)
        # self.model = CLIPModel.from_pretrained(model_name)

    def embed_text(self, text: str) -> np.ndarray:
        """Return a 512-d embedding for *text* (mocked with random values)."""
        # In production:
        # inputs = self.processor(text=text, return_tensors="pt")
        # text_features = self.model.get_text_features(**inputs)
        # return text_features[0].detach().numpy()
        return np.random.randn(512)  # Mock

    def embed_image(self, image: "Image.Image") -> np.ndarray:
        """Return a 512-d embedding for *image* (mocked with random values)."""
        # In production:
        # inputs = self.processor(images=image, return_tensors="pt")
        # image_features = self.model.get_image_features(**inputs)
        # return image_features[0].detach().numpy()
        return np.random.randn(512)  # Mock

    def embed_document(self, doc: "MultiModalDocument") -> dict:
        """Generate embeddings for every modality present in *doc*.

        Returns:
            Dict with keys "text" (one vector), "images" and "tables"
            (lists of vectors; empty when the modality is absent/None).
        """
        return {
            "text": self.embed_text(doc.text),
            "images": [self.embed_image(img) for img in (doc.images or [])],
            # Tables are flattened to text so the text encoder can embed them.
            "tables": [
                self.embed_text(self._table_to_text(table))
                for table in (doc.tables or [])
            ],
        }

    def _table_to_text(self, table: dict) -> str:
        """Flatten a table into pipe-delimited lines for text embedding.

        FIX: include the header row (when present) so column names
        contribute to the embedding instead of being silently discarded.
        Rows are dicts; only their values are serialized.
        """
        lines = []
        headers = table.get("headers")
        if headers:
            lines.append(" | ".join(str(h) for h in headers))
        for row in table.get("rows", []):
            lines.append(" | ".join(str(v) for v in row.values()))
        return "\n".join(lines)

Multi-Index Retrieval

Use separate indexes for different modalities and retrieve from all simultaneously.

from typing import Optional, Dict
import faiss

class MultiModalIndex:
    """Per-modality FAISS indexes (text / image / table) with provenance.

    Each modality gets its own flat L2 index; ``doc_mapping`` maps index
    positions ("{modality}_{pos}") back to (doc_id, source, metadata) so
    search hits can be traced to their documents. Smaller L2 distance
    means a better match.
    """

    def __init__(self, embedding_dim: int = 512):
        """Create empty indexes.

        Args:
            embedding_dim: Dimensionality of all stored embeddings.
        """
        self.embedding_dim = embedding_dim
        self.text_index = faiss.IndexFlatL2(embedding_dim)
        self.image_index = faiss.IndexFlatL2(embedding_dim)
        self.table_index = faiss.IndexFlatL2(embedding_dim)

        self.doc_mapping = {}  # "{modality}_{pos}" -> (doc_id, source, metadata)
        self.position_counter = {"text": 0, "image": 0, "table": 0}

    def _register(self, index, modality: str, embedding: np.ndarray,
                  doc_id: str, source: str, metadata: dict):
        """Append one vector to *index* and record where it came from."""
        index.add(embedding.reshape(1, -1).astype('float32'))
        pos = self.position_counter[modality]
        self.doc_mapping[f"{modality}_{pos}"] = (doc_id, source, metadata)
        self.position_counter[modality] += 1

    def add_document(
        self,
        doc_id: str,
        embeddings: dict,
        metadata: dict = None
    ):
        """Add a document's embeddings (as produced by MultiModalEmbedder).

        Args:
            doc_id: Identifier stored alongside every vector.
            embeddings: Dict with optional keys "text" (one vector),
                "images" and "tables" (lists of vectors).
            metadata: Optional metadata stored with each mapping entry.
        """
        if "text" in embeddings:
            self._register(self.text_index, "text", embeddings["text"],
                           doc_id, "text", metadata)
        if "images" in embeddings:
            for i, img_embedding in enumerate(embeddings["images"]):
                self._register(self.image_index, "image", img_embedding,
                               doc_id, f"image_{i}", metadata)
        if "tables" in embeddings:
            for i, table_embedding in enumerate(embeddings["tables"]):
                self._register(self.table_index, "table", table_embedding,
                               doc_id, f"table_{i}", metadata)

    def _search_one(self, index, modality: str, query: np.ndarray, k: int):
        """Search one modality index; return (doc_id, source, distance) hits."""
        hits = []
        distances, indices = index.search(query, k)
        for dist, idx in zip(distances[0], indices[0]):
            # FAISS returns idx = -1 for empty slots; .get() skips those.
            entry = self.doc_mapping.get(f"{modality}_{idx}")
            if entry is not None:
                doc_id, source, _metadata = entry
                hits.append((doc_id, source, float(dist)))
        return hits

    def search(
        self,
        query_embedding: np.ndarray,
        k: int = 5,
        modalities: List[str] = None
    ) -> List[Tuple[str, str, float]]:
        """Search the requested modalities and merge results by distance.

        BUG FIX: the original implementation listed "table" in the
        default modalities but never searched the table index, so table
        embeddings were silently unreachable.

        Returns:
            Up to *k* (doc_id, source, distance) tuples, best (smallest
            distance) first.
        """
        if modalities is None:
            modalities = ["text", "image", "table"]

        query = query_embedding.reshape(1, -1).astype('float32')
        index_by_modality = {
            "text": self.text_index,
            "image": self.image_index,
            "table": self.table_index,
        }

        results = []
        for modality in ("text", "image", "table"):
            if modality in modalities:
                results.extend(
                    self._search_one(index_by_modality[modality], modality, query, k)
                )

        # Merge across modalities: smallest L2 distance first, keep top k.
        results.sort(key=lambda x: x[2])
        return results[:k]

Knowledge Graphs

Knowledge graphs represent entities and relationships, enabling structured reasoning.

from typing import Set, Optional

@dataclass
class Entity:
    """A node in the knowledge graph.

    Entities are stored in KnowledgeGraph.entities keyed by ``id``;
    ``name`` is used for naive substring matching when extracting
    entity mentions from question text.
    """
    id: str  # unique identifier (dictionary key in the graph)
    name: str  # human-readable name, matched against question text
    entity_type: str  # category label, rendered as "name (type)" in context
    attributes: dict = None  # optional extra properties; None means none

@dataclass
class Relation:
    """A directed edge between two entities.

    Graph traversal follows edges from source_id to target_id only;
    relations are not treated as bidirectional.
    """
    source_id: str  # id of the entity the edge starts from
    target_id: str  # id of the entity the edge points to
    relation_type: str  # edge label used for filtering and grouping
    weight: float = 1.0  # edge weight; not read by any traversal code shown here
    attributes: dict = None  # optional extra edge properties

class KnowledgeGraph:
    """Graph-based knowledge representation.

    Entities are stored by id. Relations are kept both as a flat list
    (public, insertion-ordered) and in per-source lookup structures so
    traversals cost O(out-degree) per node instead of scanning every
    relation in the graph.
    """

    def __init__(self):
        self.entities: Dict[str, Entity] = {}
        self.relations: List[Relation] = []
        # (source_id, relation_type) -> set of target ids.
        # FIX: tuple keys; the old f"{source}_{type}" string keys were
        # ambiguous whenever ids themselves contained underscores.
        self.relation_index: Dict[Tuple[str, str], Set[str]] = {}
        # source_id -> outgoing relations, in insertion order (adjacency
        # list used by get_neighbors / find_path / find_related_entities).
        self._outgoing: Dict[str, List[Relation]] = {}

    def add_entity(self, entity: "Entity"):
        """Add (or replace) an entity, keyed by its id."""
        self.entities[entity.id] = entity

    def add_relation(self, relation: "Relation"):
        """Add a directed relationship and update the lookup indexes."""
        self.relations.append(relation)
        self._outgoing.setdefault(relation.source_id, []).append(relation)
        key = (relation.source_id, relation.relation_type)
        self.relation_index.setdefault(key, set()).add(relation.target_id)

    def get_neighbors(
        self,
        entity_id: str,
        relation_type: Optional[str] = None,
        distance: int = 1
    ) -> Dict[str, Set[str]]:
        """Get entities reachable within *distance* hops, grouped by type.

        When *relation_type* is given, only edges of that type are
        followed at every hop. As in the original implementation, each
        hop expands from every entity discovered so far.

        Returns:
            Mapping relation_type -> set of target entity ids.
        """
        neighbors: Dict[str, Set[str]] = {}
        frontier = {entity_id}

        for _ in range(distance):
            for source_id in frontier:
                for rel in self._outgoing.get(source_id, []):
                    if relation_type is None or rel.relation_type == relation_type:
                        neighbors.setdefault(rel.relation_type, set()).add(rel.target_id)
            # Next hop starts from everything found so far (original semantics).
            frontier = set()
            for found in neighbors.values():
                frontier.update(found)

        return neighbors

    def find_path(
        self,
        start_id: str,
        end_id: str,
        max_depth: int = 3
    ) -> Optional[List[str]]:
        """Find a shortest directed path between two entities via BFS.

        Returns:
            The path as a list of entity ids (including both endpoints),
            or None if no path exists within *max_depth* nodes.
        """
        from collections import deque

        queue = deque([(start_id, [start_id])])
        visited = {start_id}

        while queue:
            current_id, path = queue.popleft()

            if current_id == end_id:
                return path

            # Paths are bounded by node count, matching the original.
            if len(path) >= max_depth:
                continue

            for rel in self._outgoing.get(current_id, []):
                if rel.target_id not in visited:
                    visited.add(rel.target_id)
                    queue.append((rel.target_id, path + [rel.target_id]))

        return None

    def find_related_entities(
        self,
        entity_id: str,
        relation_types: List[str]
    ) -> Dict[str, List[str]]:
        """Find direct targets of *entity_id* for each requested type.

        Returns:
            Mapping relation_type -> list of target ids in insertion
            order (types with no matches map to an empty list).
        """
        return {
            rel_type: [
                rel.target_id
                for rel in self._outgoing.get(entity_id, [])
                if rel.relation_type == rel_type
            ]
            for rel_type in relation_types
        }

GraphRAG Integration

Combine graph-based retrieval with LLM for structured reasoning.

import anthropic

class GraphRAG:
    """RAG system combining semantic retrieval with graph expansion.

    Questions are answered by (1) linking question text to known
    entities, (2) retrieving semantically similar documents, (3)
    expanding the entity set along graph neighbors, and (4) prompting
    the LLM with the combined context.
    """

    def __init__(
        self,
        knowledge_graph: "KnowledgeGraph",
        embedder: "MultiModalEmbedder",
        index: "MultiModalIndex"
    ):
        """Wire together the graph, the embedder, and the vector index."""
        self.kg = knowledge_graph
        self.embedder = embedder
        self.index = index
        self.client = anthropic.Anthropic()

    def answer_question(self, question: str) -> str:
        """Answer *question* using graph and semantic retrieval.

        Returns:
            The LLM's text response.
        """
        # Link the question to known entities (naive substring matching).
        entities = self._extract_entities(question)

        # Retrieve the top semantically similar items across modalities.
        query_embedding = self.embedder.embed_text(question)
        doc_results = self.index.search(query_embedding, k=3)

        # Expand the entity set with 2-hop graph neighbors for context.
        context_entities = set(entities)
        for entity_id in entities:
            neighbors = self.kg.get_neighbors(entity_id, distance=2)
            for neighbor_ids in neighbors.values():
                context_entities.update(neighbor_ids)

        context = self._build_context(context_entities, doc_results)

        response = self.client.messages.create(
            model="claude-3-5-sonnet-20241022",
            max_tokens=1024,
            system="You are an expert assistant answering questions using provided context and knowledge graph information.",
            messages=[{
                "role": "user",
                "content": f"Question: {question}\n\nContext:\n{context}"
            }]
        )

        return response.content[0].text

    def _extract_entities(self, text: str) -> Set[str]:
        """Extract entity ids whose names appear in *text* (case-insensitive).

        In production this would use an NER model or entity linker.
        """
        text_lower = text.lower()  # hoisted: avoid re-lowercasing per entity
        return {
            entity_id
            for entity_id, entity in self.kg.entities.items()
            if entity.name.lower() in text_lower
        }

    def _build_context(
        self,
        entity_ids: Set[str],
        doc_results: List[Tuple[str, str, float]]
    ) -> str:
        """Build the LLM context string from entities and retrieved docs."""
        context_parts = ["Entities:"]
        for entity_id in entity_ids:
            if entity_id in self.kg.entities:
                entity = self.kg.entities[entity_id]
                context_parts.append(f"  - {entity.name} ({entity.entity_type})")

        context_parts.append("\nRelevant Documents:")
        for doc_id, source, distance in doc_results:
            # FIX: the results carry L2 distances (smaller = better);
            # labeling 1 - distance as "similarity" was misleading and
            # could even go negative. Report the distance directly.
            context_parts.append(f"  - {doc_id}: {source} (distance: {distance:.2f})")

        return "\n".join(context_parts)

Table and Structured Data Processing

Extract meaning from tables and structured data.

from typing import Any

class TableProcessor:
    """Process and summarize table data.

    Tables are dicts with "headers" (list of column names) and "rows"
    (list of row lists, positionally aligned with the headers).
    """

    @staticmethod
    def _is_numeric(value) -> bool:
        """True for real numbers.

        FIX: bool is excluded explicitly — it subclasses int, so the
        original isinstance check silently treated True/False columns
        as numeric.
        """
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    def extract_columns(self, table: dict) -> dict[str, list]:
        """Pivot row-major table data into {header: column values}.

        Rows shorter than the header list are padded with None.
        """
        columns = {}
        if "headers" in table:
            headers = table["headers"]
            rows = table.get("rows", [])
            for i, header in enumerate(headers):
                columns[header] = [row[i] if i < len(row) else None for row in rows]
        return columns

    def summarize_table(self, table: dict) -> str:
        """Generate a one-line text summary of column types and sizes."""
        columns = self.extract_columns(table)
        summaries = []

        for column_name, values in columns.items():
            non_null = [v for v in values if v is not None]
            # FIX: skip columns with no usable values. The original code
            # reported an all-None column as "numeric column with 0
            # values" because all() over an empty generator is True.
            if not non_null:
                continue

            if all(self._is_numeric(v) for v in non_null):
                summary = f"{column_name}: numeric column with {len(non_null)} values"
            else:
                unique = len(set(non_null))
                summary = f"{column_name}: categorical with {unique} unique values"

            summaries.append(summary)

        return f"Table structure: {', '.join(summaries)}"

    def find_numeric_patterns(self, table: dict) -> dict:
        """Compute min/max/avg/count for each numeric column.

        Returns:
            Mapping column name -> stats dict; columns without numeric
            values are omitted.
        """
        columns = self.extract_columns(table)
        patterns = {}

        for col_name, values in columns.items():
            numeric_values = [v for v in values if self._is_numeric(v)]
            if numeric_values:
                patterns[col_name] = {
                    "min": min(numeric_values),
                    "max": max(numeric_values),
                    "avg": sum(numeric_values) / len(numeric_values),
                    "count": len(numeric_values),
                }

        return patterns

Key Takeaway

Multi-modal RAG processes text, images, and tables using shared embeddings. Knowledge graphs capture relationships enabling structured reasoning. Combined, they enable sophisticated question-answering over rich, complex documents.

Exercises

  1. Implement multi-modal embeddings for text and images
  2. Build separate indexes for text, image, and table modalities
  3. Create knowledge graph with 50+ entities and relation types
  4. Implement graph traversal for multi-hop reasoning
  5. Combine graph and semantic retrieval for answer synthesis
  6. Process tables with column type detection and summarization
  7. Build end-to-end QA system over mixed-media documents