This document describes the architecture of a Spatial Retrieval-Augmented Generation (Spatial-RAG) system for latent world models. The system combines:
- Latent World Models: Learn compressed representations of environment dynamics
- Spatial Memory: Store experiences with geospatial metadata for efficient retrieval
- Hybrid Retrieval: Combine spatial proximity filtering with latent similarity search
The goal is to build embodied AI systems that can leverage past experiences for improved prediction, planning, and decision-making.
┌─────────────────────────────────────────────────────────────────────────────┐
│ DEPLOYMENT STACK │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
│ │ Next.js │◄───►│ FastAPI │◄───►│ Qdrant │ │
│ │ UI :3000 │ SSE │ API :8080 │ │ DB :6333 │ │
│ └──────────────┘ └──────┬───────┘ └──────────────┘ │
│ │ │
│ ┌─────────┴─────────┐ │
│ │ TorchScript │ │
│ │ Models │ │
│ └─────────┬─────────┘ │
│ │ │
│ ┌──────────────────────────┼──────────────────────────┐ │
│ │ ROS2 INTEGRATION (Optional) │ │
│ │ ┌─────────────┐ ┌─────────────┐ ┌─────────┐ │ │
│ │ │ /camera │───►│ Latent │───►│ /latent │ │ │
│ │ │ /image_raw │ │ Publisher │ │ topic │ │ │
│ │ └─────────────┘ └──────┬──────┘ └─────────┘ │ │
│ │ │ │ │
│ │ ┌─────────────┐ ┌──────┴──────┐ │ │
│ │ │ /actions │───►│ Transition │ │ │
│ │ │ topic │ │ Predictor │ │ │
│ │ └─────────────┘ └─────────────┘ │ │
│ └─────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────────┐
│ PERCEPTION PIPELINE │
├─────────────────────────────────────────────────────────────────────┤
│ RGB Image ──┐ │
│ Depth Map ──┼──► [Encoder E_φ] ──► Latent z_t ──► Memory Bank │
│ Proprio ────┘ │ │
│ │ │
│ ▼ │
│ ┌────────────────────────┐ │
│ │ RETRIEVAL MODULE │ │
│ │ ┌─────────────────┐ │ │
│ Query: z_t, position ────► │ │ Spatial Filter │ │ │
│ │ │ (H3 / BBox) │ │ │
│ │ └────────┬────────┘ │ │
│ │ ▼ │ │
│ │ ┌─────────────────┐ │ │
│ │ │ Latent KNN │ │ │
│ │ │ (Faiss/Qdrant) │ │ │
│ │ └────────┬────────┘ │ │
│ └────────────┼──────────┘ │
│ ▼ │
│ Retrieved Latents R = {z_r1, ..., z_rK}│
│ │ │
├───────────────────────────────────────────┼─────────────────────────┤
│ DYNAMICS MODEL │
├───────────────────────────────────────────┼─────────────────────────┤
│ ▼ │
│ z_t ──┐ ┌────────────────────────┐ │
│ a_t ──┼──────────────────► │ Transition f_θ │ │
│ R ──┘ │ (GRU + Attention) │ │
│ └───────────┬───────────┘ │
│ ▼ │
│ z_{t+1} (predicted) │
│ │ │
│ ┌───────────┴───────────┐ │
│ ▼ ▼ │
│ [Decoder D_ψ] [Policy/Planner] │
│ │ │ │
│ ▼ ▼ │
│ Reconstruction Action Selection │
└─────────────────────────────────────────────────────────────────────┘
All services share a single 13GB base image for efficiency:
┌─────────────────────────────────────────────────────────────────────┐
│ DOCKER IMAGE ARCHITECTURE │
├─────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────┐ │
│ │ spatial-rag-base:latest │ │
│ │ (13.2 GB) │ │
│ │ ┌─────────────────────────────────────────────────────┐ │ │
│ │ │ Python 3.10 + PyTorch + Dependencies │ │ │
│ │ └─────────────────────────────────────────────────────┘ │ │
│ └────────────────────────────┬────────────────────────────────┘ │
│ │ │
│ ┌─────────────────────┼─────────────────────┐ │
│ ▼ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ generate- │ │ train │ │ experiment │ │
│ │ data │ │ │ │ │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ ▼ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ reports │ │ collect │ │ (+ more) │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ Separate Images: │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ api │ │ ros2 │ │ ui │ │
│ │ (12.1 GB) │ │ (13.5 GB) │ │ (alpine) │ │
│ └─────────────┘ └─────────────┘ └─────────────┘ │
│ │
│ Total: ~25 GB (down from ~117 GB with separate images) │
└─────────────────────────────────────────────────────────────────────┘
profiles:
core: [qdrant, api] # Essential services
ui: [ui] # Dashboard
ros2: [ros2] # Robotics
data: [generate-data] # Data generation
train: [train] # Training
experiment: [experiment] # Experiments
reports: [reports] # Report generation
collect: [collect] # Data collectionPurpose: Map observations to latent representations
Architecture: ResNet18 backbone with projection head
Input:
- RGB images: [B, 3, H, W]
- Optional depth: [B, 1, H, W]
- Optional proprioception: [B, proprio_dim]
Output:
- Latent vector: z ∈ ℝ^{z_dim}
- Feature map (optional): [B, 512, H/32, W/32]
Purpose: Store latent states with rich metadata for retrieval
Storage Format:
{
"id": str, # Unique identifier
"vector": np.ndarray, # Latent z [z_dim]
"position": {"x": float, "y": float, "z": float},
"h3_cell": str, # H3 geospatial cell ID
"action": np.ndarray, # Action that led here
"timestamp": float, # Unix timestamp
}Backends:
- Qdrant (Production): Vector DB with payload filtering
- Faiss (Development): Local CPU-based
Algorithm:
1. Spatial Prefilter:
- Compute H3 cell for query position
- Get neighboring cells (k-ring)
- Filter memory to candidates
2. Latent KNN:
- Search top-K nearest in latent space
3. Return:
- Retrieved latents: R = [z_r1, ..., z_rK]
- Metadata for each retrieved state
Architecture: GRU + Cross-Attention
def forward(z_t, a_t, retrieved_z):
a_emb = action_proj(a_t)
h = gru([z_t || a_emb])
attn_out = multihead_attention(query=h, key=retrieved_z, value=retrieved_z)
gate = sigmoid(gate_proj([h || attn_out]))
fused = gate * attn_out + (1 - gate) * h
z_t1 = mlp(fused) + z_t # Residual
return z_t1Purpose: Reconstruct observations from latent representations
Architecture: ConvTranspose upsampling network
┌─────────────────────────────────────────────────────────┐
│ ROS2 Latent Publisher Node │
├─────────────────────────────────────────────────────────┤
│ │
│ Subscriptions: │
│ ┌────────────────────┐ │
│ │ /camera/image_raw │ ─┐ │
│ │ (sensor_msgs/Image)│ │ │
│ └────────────────────┘ │ │
│ ▼ │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ /actions │──►│ LatentPublisher │ │
│ │ (Float32MultiArray)│ │ │ │
│ └────────────────────┘ │ - Image encoder │ │
│ │ - Transition │ │
│ ┌────────────────────┐ │ predictor │ │
│ │ /pose │──►│ │ │
│ │ (PoseStamped) │ └─────────┬──────────┘ │
│ └────────────────────┘ │ │
│ ▼ │
│ Publishers: │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ /latent │◄─┤ Encoded latent │ │
│ │ (Float32MultiArray)│ │ vector z_t (32d) │ │
│ └────────────────────┘ └────────────────────┘ │
│ │
│ ┌────────────────────┐ ┌────────────────────┐ │
│ │ /latent_next │◄─┤ Predicted next │ │
│ │ (Float32MultiArray)│ │ latent z_{t+1} │ │
│ └────────────────────┘ └────────────────────┘ │
│ │
│ Parameters: │
│ - model_path: TorchScript model file │
│ - z_dim: 32 (latent dimension) │
│ - publish_rate: 20.0 Hz │
└─────────────────────────────────────────────────────────┘
$ ros2 topic list
/actions
/camera/image_raw
/latent
/latent_next
/parameter_events
/pose
/rosoutlayout:
dim:
- label: z
size: 32
stride: 32
data: [0.021, -0.132, -0.055, ...] # 32-dim vector @ 20Hz┌─────────────────────────────────────────────────────────┐
│ FastAPI Application │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ /encode │ │ /predict │ │ /rollout │ │
│ │ POST │ │ POST │ │ POST │ │
│ └──────┬──────┘ └──────┬──────┘ └──────┬──────┘ │
│ │ │ │ │
│ ▼ ▼ ▼ │
│ ┌─────────────────────────────────────────────┐ │
│ │ Model Components │ │
│ │ ┌─────────┐ ┌──────────┐ ┌─────────┐ │ │
│ │ │ Encoder │ │Transition│ │ Decoder │ │ │
│ │ └─────────┘ └──────────┘ └─────────┘ │ │
│ └────────────────────┬────────────────────────┘ │
│ │ │
│ ┌─────────────┴─────────────┐ │
│ ▼ ▼ │
│ ┌─────────────┐ ┌─────────────┐ │
│ │ Retriever │◄───────────►│ Qdrant │ │
│ └─────────────┘ └─────────────┘ │
│ │
├─────────────────────────────────────────────────────────┤
│ SSE Streaming: /stream-rollout │
│ Real-time frame prediction with decoded images │
└─────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────────┐
│ TRAINING PIPELINE │
├─────────────────────────────────────────────────────────────────┤
│ │
│ ┌──────────────────┐ │
│ │ Data Generation │ │
│ │ ┌────────────┐ │ │
│ │ │ Synthetic │ │ docker compose run --rm generate-data │
│ │ │ 2D Env │ │ → data/trajectories/ │
│ │ └────────────┘ │ │
│ │ ┌────────────┐ │ │
│ │ │ Robot Data │ │ docker compose exec collect ... │
│ │ │ Collection │ │ → data/robot_trajectories/ │
│ │ └────────────┘ │ │
│ └────────┬─────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────┐ │
│ │ Training Script │ │
│ │ docker compose │ │
│ │ run --rm train│ │
│ │ │ │
│ │ ┌────────────┐ │ │
│ │ │ Encoder │ │ │
│ │ │ Transition │ │ ← Uses spatial-rag-base image │
│ │ │ Decoder │ │ │
│ │ └────────────┘ │ │
│ └────────┬─────────┘ │
│ │ │
│ ▼ │
│ ┌──────────────────┐ │
│ │ Checkpoint │ checkpoints/model.pt │
│ │ Export │ artifacts/encoder_transition.pt │
│ └──────────────────┘ │
│ │
└─────────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────┐
│ Next.js Dashboard │
├─────────────────────────────────────────────────────────┤
│ │
│ ┌───────────────────────────────────────────────┐ │
│ │ page.tsx │ │
│ │ ┌─────────────────┐ ┌─────────────────┐ │ │
│ │ │ Control Panel │ │ Visualization │ │ │
│ │ │ - Latent Input │ │ ┌─────────────┐ │ │ │
│ │ │ - Horizon Slider│ │ │LiveRollout │ │ │ │
│ │ │ - Start/Stop │ │ │ Component │ │ │ │
│ │ └─────────────────┘ │ └─────────────┘ │ │ │
│ │ └─────────────────┘ │ │
│ └───────────────────────────────────────────────┘ │
│ │ │
│ ▼ │
│ ┌───────────────────────────────────────────────┐ │
│ │ LiveRollout.tsx (SSE Client) │ │
│ │ │ │
│ │ EventSource ──► /api/stream-rollout ──► │ │
│ │ ┌─────────┐ ┌─────────┐ ┌─────────┐ │ │
│ │ │ Frame 0 │ │ Frame 1 │ │ Frame N │ │ │
│ │ └─────────┘ └─────────┘ └─────────┘ │ │
│ └───────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────┘
| Scale | Solution |
|---|---|
| 10K vectors | Faiss Flat (exact search) |
| 100K vectors | Faiss IVF (approximate) |
| 1M+ vectors | Qdrant/Milvus (distributed) |
| Use Case | Target Latency | Hardware |
|---|---|---|
| Real-time control | < 10ms | GPU (CUDA) |
| Interactive UI | < 50ms | CPU/GPU |
| ROS2 node (20 Hz) | < 50ms | CPU/Jetson |
| Variable | Default | Description |
|---|---|---|
Z_DIM |
32 | Latent dimension |
ACTION_DIM |
2 | Action vector dimension |
TOPK |
8 | Retrieved memory count |
IMAGE_SIZE |
64 | Input image size |
USE_TINY_MODELS |
true | Use lightweight models |
QDRANT_HOST |
localhost | Qdrant server host |
QDRANT_PORT |
6333 | Qdrant server port |
WM_CHECKPOINT_PATH |
- | Model checkpoint path |
| Traditional AI | Spatial-RAG |
|---|---|
| Sees current state only | Remembers past experiences |
| No memory of past | Retrieves relevant memories |
| Learns from scratch | Reuses past knowledge |
| Slow to adapt | Fast adaptation |
- Autonomous Robots/Drones: Navigate using past flight memories
- Self-Driving Cars: Predict pedestrian behavior based on location history
- Warehouse Robots: Remember item locations for faster picking
- Home Assistants: Learn house layout, remember where things are
- AR Navigation: Predictive overlays based on spatial memory
| Benefit | Description |
|---|---|
| Sample Efficiency | 15-30% fewer examples needed |
| Better Predictions | Context-aware forecasting |
| Faster Adaptation | Reuse similar past experiences |
| Safety | Predict before acting |
See docs/PRACTICAL_GUIDE.md for detailed examples.
- DreamerV3: Mastering Diverse Domains through World Models
- Retrieval-Augmented Generation for Knowledge-Intensive NLP Tasks
- H3: Uber's Hexagonal Hierarchical Spatial Index
- MAGE: World Model as Agent Environment