Skip to content

Latest commit

 

History

History
456 lines (392 loc) · 29.7 KB

File metadata and controls

456 lines (392 loc) · 29.7 KB

Spatial-RAG World Model: Architecture Design Document

Overview

This document describes the architecture of a Spatial Retrieval-Augmented Generation (Spatial-RAG) system for latent world models. The system combines:

  • Latent World Models: Learn compressed representations of environment dynamics
  • Spatial Memory: Store experiences with geospatial metadata for efficient retrieval
  • Hybrid Retrieval: Combine spatial proximity filtering with latent similarity search

The goal is to build embodied AI systems that can leverage past experiences for improved prediction, planning, and decision-making.

System Architecture

High-Level Overview

┌─────────────────────────────────────────────────────────────────────────────┐
│                           DEPLOYMENT STACK                                   │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                              │
│  ┌──────────────┐     ┌──────────────┐     ┌──────────────┐                │
│  │   Next.js    │◄───►│   FastAPI    │◄───►│   Qdrant     │                │
│  │   UI :3000   │ SSE │   API :8080  │     │   DB :6333   │                │
│  └──────────────┘     └──────┬───────┘     └──────────────┘                │
│                              │                                              │
│                    ┌─────────┴─────────┐                                   │
│                    │  TorchScript      │                                   │
│                    │  Models           │                                   │
│                    └─────────┬─────────┘                                   │
│                              │                                              │
│  ┌──────────────────────────┼──────────────────────────┐                  │
│  │           ROS2 INTEGRATION (Optional)               │                  │
│  │  ┌─────────────┐    ┌─────────────┐    ┌─────────┐ │                  │
│  │  │  /camera    │───►│  Latent     │───►│ /latent │ │                  │
│  │  │  /image_raw │    │  Publisher  │    │ topic   │ │                  │
│  │  └─────────────┘    └──────┬──────┘    └─────────┘ │                  │
│  │                            │                        │                  │
│  │  ┌─────────────┐    ┌──────┴──────┐                │                  │
│  │  │  /actions   │───►│  Transition │                │                  │
│  │  │  topic      │    │  Predictor  │                │                  │
│  │  └─────────────┘    └─────────────┘                │                  │
│  └─────────────────────────────────────────────────────┘                  │
└─────────────────────────────────────────────────────────────────────────────┘

Core Architecture

┌─────────────────────────────────────────────────────────────────────┐
│                        PERCEPTION PIPELINE                          │
├─────────────────────────────────────────────────────────────────────┤
│  RGB Image ──┐                                                      │
│  Depth Map ──┼──► [Encoder E_φ] ──► Latent z_t ──► Memory Bank     │
│  Proprio ────┘                            │                         │
│                                           │                         │
│                                           ▼                         │
│                              ┌────────────────────────┐             │
│                              │   RETRIEVAL MODULE     │             │
│                              │   ┌─────────────────┐  │             │
│   Query: z_t, position ────► │   │ Spatial Filter  │  │             │
│                              │   │ (H3 / BBox)     │  │             │
│                              │   └────────┬────────┘  │             │
│                              │            ▼           │             │
│                              │   ┌─────────────────┐  │             │
│                              │   │ Latent KNN      │  │             │
│                              │   │ (Faiss/Qdrant)  │  │             │
│                              │   └────────┬────────┘  │             │
│                              └────────────┼──────────┘             │
│                                           ▼                         │
│                              Retrieved Latents R = {z_r1, ..., z_rK}│
│                                           │                         │
├───────────────────────────────────────────┼─────────────────────────┤
│                        DYNAMICS MODEL                               │
├───────────────────────────────────────────┼─────────────────────────┤
│                                           ▼                         │
│   z_t ──┐                    ┌────────────────────────┐             │
│   a_t ──┼──────────────────► │  Transition f_θ       │             │
│   R   ──┘                    │  (GRU + Attention)    │             │
│                              └───────────┬───────────┘             │
│                                          ▼                          │
│                                    z_{t+1} (predicted)              │
│                                          │                          │
│                              ┌───────────┴───────────┐              │
│                              ▼                       ▼              │
│                    [Decoder D_ψ]            [Policy/Planner]        │
│                         │                        │                  │
│                         ▼                        ▼                  │
│                   Reconstruction           Action Selection         │
└─────────────────────────────────────────────────────────────────────┘

Docker Architecture

Optimized Image Sharing

All services share a single 13GB base image for efficiency:

┌─────────────────────────────────────────────────────────────────────┐
│                    DOCKER IMAGE ARCHITECTURE                         │
├─────────────────────────────────────────────────────────────────────┤
│                                                                      │
│  ┌─────────────────────────────────────────────────────────────┐   │
│  │                 spatial-rag-base:latest                      │   │
│  │                      (13.2 GB)                               │   │
│  │  ┌─────────────────────────────────────────────────────┐   │   │
│  │  │  Python 3.10 + PyTorch + Dependencies               │   │   │
│  │  └─────────────────────────────────────────────────────┘   │   │
│  └────────────────────────────┬────────────────────────────────┘   │
│                               │                                      │
│         ┌─────────────────────┼─────────────────────┐               │
│         ▼                     ▼                     ▼               │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐          │
│  │ generate-   │     │   train     │     │  experiment │          │
│  │    data     │     │             │     │             │          │
│  └─────────────┘     └─────────────┘     └─────────────┘          │
│         ▼                     ▼                     ▼               │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐          │
│  │   reports   │     │   collect   │     │  (+ more)   │          │
│  └─────────────┘     └─────────────┘     └─────────────┘          │
│                                                                      │
│  Separate Images:                                                    │
│  ┌─────────────┐     ┌─────────────┐     ┌─────────────┐          │
│  │     api     │     │    ros2     │     │     ui      │          │
│  │  (12.1 GB)  │     │  (13.5 GB)  │     │  (alpine)   │          │
│  └─────────────┘     └─────────────┘     └─────────────┘          │
│                                                                      │
│  Total: ~25 GB (down from ~117 GB with separate images)             │
└─────────────────────────────────────────────────────────────────────┘

Service Profiles

profiles:
  core:    [qdrant, api]          # Essential services
  ui:      [ui]                   # Dashboard
  ros2:    [ros2]                 # Robotics
  data:    [generate-data]        # Data generation
  train:   [train]                # Training
  experiment: [experiment]        # Experiments
  reports: [reports]              # Report generation
  collect: [collect]              # Data collection

Components

1. Encoder (E_φ)

Purpose: Map observations to latent representations

Architecture: ResNet18 backbone with projection head

Input:

  • RGB images: [B, 3, H, W]
  • Optional depth: [B, 1, H, W]
  • Optional proprioception: [B, proprio_dim]

Output:

  • Latent vector: z ∈ ℝ^{z_dim}
  • Feature map (optional): [B, 512, H/32, W/32]

2. Memory Bank

Purpose: Store latent states with rich metadata for retrieval

Storage Format:

{
    "id": str,                    # Unique identifier
    "vector": np.ndarray,         # Latent z [z_dim]
    "position": {"x": float, "y": float, "z": float},
    "h3_cell": str,              # H3 geospatial cell ID
    "action": np.ndarray,         # Action that led here
    "timestamp": float,           # Unix timestamp
}

Backends:

  1. Qdrant (Production): Vector DB with payload filtering
  2. Faiss (Development): Local CPU-based

3. Retrieval Module

Algorithm:

1. Spatial Prefilter:
   - Compute H3 cell for query position
   - Get neighboring cells (k-ring)
   - Filter memory to candidates
   
2. Latent KNN:
   - Search top-K nearest in latent space
   
3. Return:
   - Retrieved latents: R = [z_r1, ..., z_rK]
   - Metadata for each retrieved state

4. Transition Model (f_θ)

Architecture: GRU + Cross-Attention

def forward(z_t, a_t, retrieved_z):
    a_emb = action_proj(a_t)
    h = gru([z_t || a_emb])
    attn_out = multihead_attention(query=h, key=retrieved_z, value=retrieved_z)
    gate = sigmoid(gate_proj([h || attn_out]))
    fused = gate * attn_out + (1 - gate) * h
    z_t1 = mlp(fused) + z_t  # Residual
    return z_t1

5. Decoder (D_ψ)

Purpose: Reconstruct observations from latent representations

Architecture: ConvTranspose upsampling network

ROS2 Node Architecture

┌─────────────────────────────────────────────────────────┐
│               ROS2 Latent Publisher Node                 │
├─────────────────────────────────────────────────────────┤
│                                                          │
│  Subscriptions:                                          │
│  ┌────────────────────┐                                 │
│  │ /camera/image_raw  │ ─┐                              │
│  │ (sensor_msgs/Image)│  │                              │
│  └────────────────────┘  │                              │
│                          ▼                              │
│  ┌────────────────────┐  ┌────────────────────┐        │
│  │ /actions           │──►│  LatentPublisher   │        │
│  │ (Float32MultiArray)│  │                    │        │
│  └────────────────────┘  │  - Image encoder   │        │
│                          │  - Transition      │        │
│  ┌────────────────────┐  │    predictor      │        │
│  │ /pose              │──►│                    │        │
│  │ (PoseStamped)      │  └─────────┬──────────┘        │
│  └────────────────────┘            │                    │
│                                    ▼                    │
│  Publishers:                                            │
│  ┌────────────────────┐  ┌────────────────────┐        │
│  │ /latent            │◄─┤  Encoded latent    │        │
│  │ (Float32MultiArray)│  │  vector z_t (32d)  │        │
│  └────────────────────┘  └────────────────────┘        │
│                                                          │
│  ┌────────────────────┐  ┌────────────────────┐        │
│  │ /latent_next       │◄─┤  Predicted next    │        │
│  │ (Float32MultiArray)│  │  latent z_{t+1}    │        │
│  └────────────────────┘  └────────────────────┘        │
│                                                          │
│  Parameters:                                             │
│  - model_path: TorchScript model file                   │
│  - z_dim: 32 (latent dimension)                         │
│  - publish_rate: 20.0 Hz                                │
└─────────────────────────────────────────────────────────┘

Verified ROS2 Topics

$ ros2 topic list
/actions
/camera/image_raw
/latent
/latent_next
/parameter_events
/pose
/rosout

Sample Latent Output

layout:
  dim:
  - label: z
    size: 32
    stride: 32
data: [0.021, -0.132, -0.055, ...]  # 32-dim vector @ 20Hz

FastAPI Server Architecture

┌─────────────────────────────────────────────────────────┐
│                   FastAPI Application                    │
├─────────────────────────────────────────────────────────┤
│                                                          │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐     │
│  │  /encode    │  │  /predict   │  │  /rollout   │     │
│  │  POST       │  │  POST       │  │  POST       │     │
│  └──────┬──────┘  └──────┬──────┘  └──────┬──────┘     │
│         │                │                │             │
│         ▼                ▼                ▼             │
│  ┌─────────────────────────────────────────────┐       │
│  │              Model Components               │       │
│  │  ┌─────────┐  ┌──────────┐  ┌─────────┐   │       │
│  │  │ Encoder │  │Transition│  │ Decoder │   │       │
│  │  └─────────┘  └──────────┘  └─────────┘   │       │
│  └────────────────────┬────────────────────────┘       │
│                       │                                 │
│         ┌─────────────┴─────────────┐                  │
│         ▼                           ▼                  │
│  ┌─────────────┐             ┌─────────────┐          │
│  │  Retriever  │◄───────────►│   Qdrant    │          │
│  └─────────────┘             └─────────────┘          │
│                                                          │
├─────────────────────────────────────────────────────────┤
│  SSE Streaming: /stream-rollout                         │
│  Real-time frame prediction with decoded images         │
└─────────────────────────────────────────────────────────┘

Training Pipeline

┌─────────────────────────────────────────────────────────────────┐
│                    TRAINING PIPELINE                            │
├─────────────────────────────────────────────────────────────────┤
│                                                                 │
│  ┌──────────────────┐                                          │
│  │  Data Generation │                                          │
│  │  ┌────────────┐  │                                          │
│  │  │ Synthetic  │  │  docker compose run --rm generate-data   │
│  │  │  2D Env    │  │  → data/trajectories/                   │
│  │  └────────────┘  │                                          │
│  │  ┌────────────┐  │                                          │
│  │  │ Robot Data │  │  docker compose exec collect ...         │
│  │  │ Collection │  │  → data/robot_trajectories/             │
│  │  └────────────┘  │                                          │
│  └────────┬─────────┘                                          │
│           │                                                      │
│           ▼                                                      │
│  ┌──────────────────┐                                          │
│  │ Training Script │                                          │
│  │ docker compose  │                                          │
│  │   run --rm train│                                          │
│  │                  │                                          │
│  │  ┌────────────┐  │                                          │
│  │  │ Encoder    │  │                                          │
│  │  │ Transition │  │  ← Uses spatial-rag-base image          │
│  │  │ Decoder    │  │                                          │
│  │  └────────────┘  │                                          │
│  └────────┬─────────┘                                          │
│           │                                                      │
│           ▼                                                      │
│  ┌──────────────────┐                                          │
│  │  Checkpoint     │  checkpoints/model.pt                     │
│  │  Export         │  artifacts/encoder_transition.pt          │
│  └──────────────────┘                                          │
│                                                                 │
└─────────────────────────────────────────────────────────────────┘

Next.js UI Architecture

┌─────────────────────────────────────────────────────────┐
│                   Next.js Dashboard                      │
├─────────────────────────────────────────────────────────┤
│                                                          │
│  ┌───────────────────────────────────────────────┐     │
│  │                  page.tsx                      │     │
│  │  ┌─────────────────┐  ┌─────────────────┐    │     │
│  │  │ Control Panel   │  │ Visualization   │    │     │
│  │  │ - Latent Input  │  │ ┌─────────────┐ │    │     │
│  │  │ - Horizon Slider│  │ │LiveRollout  │ │    │     │
│  │  │ - Start/Stop    │  │ │ Component   │ │    │     │
│  │  └─────────────────┘  │ └─────────────┘ │    │     │
│  │                       └─────────────────┘    │     │
│  └───────────────────────────────────────────────┘     │
│                              │                          │
│                              ▼                          │
│  ┌───────────────────────────────────────────────┐     │
│  │           LiveRollout.tsx (SSE Client)        │     │
│  │                                               │     │
│  │  EventSource ──► /api/stream-rollout ──►      │     │
│  │  ┌─────────┐  ┌─────────┐  ┌─────────┐      │     │
│  │  │ Frame 0 │  │ Frame 1 │  │ Frame N │      │     │
│  │  └─────────┘  └─────────┘  └─────────┘      │     │
│  └───────────────────────────────────────────────┘     │
└─────────────────────────────────────────────────────────┘

Scaling

Memory Store Scaling

Scale Solution
10K vectors Faiss Flat (exact search)
100K vectors Faiss IVF (approximate)
1M+ vectors Qdrant/Milvus (distributed)

Inference Latency Targets

Use Case Target Latency Hardware
Real-time control < 10ms GPU (CUDA)
Interactive UI < 50ms CPU/GPU
ROS2 node (20 Hz) < 50ms CPU/Jetson

Environment Variables

Variable Default Description
Z_DIM 32 Latent dimension
ACTION_DIM 2 Action vector dimension
TOPK 8 Retrieved memory count
IMAGE_SIZE 64 Input image size
USE_TINY_MODELS true Use lightweight models
QDRANT_HOST localhost Qdrant server host
QDRANT_PORT 6333 Qdrant server port
WM_CHECKPOINT_PATH - Model checkpoint path

Practical Applications

Why Spatial-RAG Matters

Traditional AI Spatial-RAG
Sees current state only Remembers past experiences
No memory of past Retrieves relevant memories
Learns from scratch Reuses past knowledge
Slow to adapt Fast adaptation

Real-World Use Cases

  1. Autonomous Robots/Drones: Navigate using past flight memories
  2. Self-Driving Cars: Predict pedestrian behavior based on location history
  3. Warehouse Robots: Remember item locations for faster picking
  4. Home Assistants: Learn house layout, remember where things are
  5. AR Navigation: Predictive overlays based on spatial memory

Benefits

Benefit Description
Sample Efficiency 15-30% fewer examples needed
Better Predictions Context-aware forecasting
Faster Adaptation Reuse similar past experiences
Safety Predict before acting

See docs/PRACTICAL_GUIDE.md for detailed examples.

References

  • DreamerV3: Mastering Diverse Domains through World Models
  • Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
  • H3: Uber's Hexagonal Hierarchical Spatial Index
  • MAGE: World Model as Agent Environment