A proof-of-concept implementation using the official Samsung SAIL Montreal TRM for grid-based navigation. This project demonstrates how to train a neural network to imitate optimal pathfinding behavior using behavioral cloning with recursive reasoning.
Goal: demonstrate that TRM can achieve results comparable to A* on grid navigation tasks:
- Success rate ≥ 85%
- Path length ratio ≤ 1.3 (TRM path / A* optimal path)
This project is based on the Tiny Recursive Model (TRM) architecture introduced by Samsung SAIL Montreal:
"Less is More: Recursive Reasoning with Tiny Networks" Alexia Jolicoeur-Martineau, Samsung SAIL Montreal, 2025 arXiv:2510.04871
Official implementation: SamsungSAILMontreal/TinyRecursiveModels
The key insight of TRM is that recursive computation can substitute for model depth. Instead of building deeper networks with more parameters, TRM applies a smaller network multiple times to iteratively refine its representations. This mimics how humans often "think twice" about a problem, revisiting and refining their reasoning.
Traditional Deep Network:
Input → Layer1 → Layer2 → ... → LayerN → Output
(Many parameters, single pass)
TRM Approach:
Input → Small Network (output fed back in, loop N times) → Output
(Fewer parameters, multiple passes)
How it works:
- Input is embedded into a latent representation
- A small network (MLP-Mixer) processes the representation
- The output is fed back as input for another refinement pass
- After N iterations, the refined representation is used for prediction
Navigation is an interesting testbed for recursive reasoning because:
- Iterative Decision Making: Real navigation involves continuously reassessing your position relative to the goal. TRM's recursive passes mirror this "look-think-adjust" loop.
- Spatial Reasoning: Finding a path requires understanding spatial relationships between obstacles, current position, and goal. Multiple refinement passes allow the model to propagate information across the grid representation.
- Efficiency vs. Capability Trade-off: Traditional pathfinding (A*) is optimal but requires explicit graph search. Can a tiny learned model with recursive refinement achieve comparable results? This is the core research question.
- Resource-Constrained Deployment: For embedded robotics (the "nano" in NanoNav), having a small model that can "think longer" on hard problems is more practical than a large model that's fast but memory-hungry.
This PoC explores:
- Can behavioral cloning from A* demonstrations teach spatial reasoning?
- Does recursive refinement help with navigation decisions?
- What's the trade-off between model size, recursion depth, and accuracy?
- How does a learned policy generalize to unseen grid configurations?
# 1. Clone with submodules (includes official TRM)
git clone --recursive https://github.com/your-repo/nanonav.git
cd nanonav
# 2. Create virtual environment
python -m venv venv
source venv/bin/activate # Linux/Mac
# or: venv\Scripts\activate # Windows
# 3. Install PyTorch nightly (required for official TRM)
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
# 4. Install dependencies
# Option A: Using pip (traditional)
pip install -r requirements.txt
# Option B: Using Poetry (modern package manager)
# First install Poetry: https://python-poetry.org/docs/#installation
poetry install
# 5. Generate training data (with augmentation for better results)
python scripts/generate_dataset.py --num-train 50000 --augment
# 6. Train with GPU (official TRM + RTX 4080/4090 support)
python -m trm_nav.train --device cuda --grid-size 16 --max-recursion 64
# 7. Visual test
python scripts/test_trm_visual.py --checkpoint checkpoints/best.pt
# 8. Full benchmark
python scripts/run_benchmark.py --checkpoint checkpoints/best.pt
trm_nav/
├── trm_nav/ # Main package
│ ├── __init__.py
│ ├── a_star.py # A* pathfinding algorithm (teacher/oracle)
│ ├── map_generator.py # Random solvable map generation
│ ├── dataset.py # Dataset creation with augmentation
│ ├── model.py # TRM model wrapper (uses official Samsung TRM)
│ ├── train.py # Training loop with GPU support and regularization
│ ├── evaluate.py # Benchmarking and rollouts
│ ├── visualize.py # ASCII and matplotlib visualization
│ └── official_trm/ # Thin wrapper around the external TRM submodule
│ └── navigation_trm_submodule.py # Navigation wrapper for official TRM
├── external/ # Git submodules
│ └── trm/ # Official Samsung TRM repository (submodule)
├── scripts/
│ ├── generate_dataset.py # Dataset generation script
│ ├── run_benchmark.py # Full evaluation benchmark
│ ├── test_astar_visual.py # Visual A* test
│ └── test_trm_visual.py # Visual TRM vs A* comparison
├── tests/ # Unit tests
├── docs/ # Documentation
│ ├── IMPLEMENTATION-NOTES.md
│ └── TRAINING-GUIDE.md
├── data/ # Generated datasets (*.pt files)
├── checkpoints/ # Saved model checkpoints
└── results/ # Benchmark outputs
Random Grid → A* Optimal Path → (state, action) pairs
- Generate random 8×8 grids with ~20% obstacle density
- Ensure start and goal are reachable (solvable maps only)
- Run A* to get the optimal path
- For each position along the path, create a training sample:
  - Input: Current grid state + current position + goal position
  - Output: Optimal action to take (from A*)
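The pipeline above can be sketched in plain Python. This is a minimal illustration, not the project's map_generator.py or a_star.py: the README says a_star.py is built on NetworkX, whereas this sketch swaps in a hand-written heapq A* with a Manhattan heuristic so it has no dependencies. The 0 = free / 1 = obstacle internal encoding, the rejection-sampling loop, the action indices (following the UP, DOWN, LEFT, RIGHT order listed for the model output), and the choice to emit no sample at the goal cell are all assumptions.

```python
import heapq
import random

FREE, OBSTACLE = 0, 1  # hypothetical internal cell encoding (tokens later shift these to 1 and 2)
# Move delta -> action index, following the UP, DOWN, LEFT, RIGHT order of the model output.
ACTIONS = {(-1, 0): 0, (1, 0): 1, (0, -1): 2, (0, 1): 3}


def astar(grid, start, goal):
    """4-connected A* with a Manhattan heuristic. Returns a list of cells, or None."""
    rows, cols = len(grid), len(grid[0])

    def h(p):
        return abs(p[0] - goal[0]) + abs(p[1] - goal[1])

    open_heap = [(h(start), 0, start)]  # entries are (f = g + h, g, cell)
    came_from = {start: None}
    g_cost = {start: 0}
    while open_heap:
        _, g, cur = heapq.heappop(open_heap)
        if cur == goal:  # walk parent links back to start
            path = []
            while cur is not None:
                path.append(cur)
                cur = came_from[cur]
            return path[::-1]
        if g > g_cost[cur]:
            continue  # stale heap entry: a cheaper route to cur was found later
        for dr, dc in ACTIONS:
            nxt = (cur[0] + dr, cur[1] + dc)
            if 0 <= nxt[0] < rows and 0 <= nxt[1] < cols and grid[nxt[0]][nxt[1]] == FREE:
                if g + 1 < g_cost.get(nxt, float("inf")):
                    g_cost[nxt] = g + 1
                    came_from[nxt] = cur
                    heapq.heappush(open_heap, (g + 1 + h(nxt), g + 1, nxt))
    return None  # goal unreachable


def generate_map(size=8, density=0.2, rng=None):
    """Rejection-sample random grids until start and goal are connected (solvable maps only)."""
    rng = rng or random.Random()
    while True:
        grid = [[OBSTACLE if rng.random() < density else FREE for _ in range(size)]
                for _ in range(size)]
        free = [(r, c) for r in range(size) for c in range(size) if grid[r][c] == FREE]
        if len(free) < 2:
            continue
        start, goal = rng.sample(free, 2)
        path = astar(grid, start, goal)
        if path is not None:
            return grid, start, goal, path


def path_to_samples(grid, path):
    """One (grid, current, goal, action) sample per step along the optimal path."""
    goal = path[-1]
    return [(grid, cur, goal, ACTIONS[(nxt[0] - cur[0], nxt[1] - cur[1])])
            for cur, nxt in zip(path, path[1:])]
```

Because the Manhattan heuristic never overestimates on a 4-connected grid with unit step costs, the first time the goal is popped the path is optimal, which is what makes A* usable as the teacher.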
Grid (8×8) → Flatten → 64 tokens
+ 4 coordinate tokens (start_row, start_col, goal_row, goal_col), where "start" is the agent's current position
= 68 tokens total
Token values:
- 1 = Free cell
- 2 = Obstacle
- 3+ = Coordinate values (offset by 3 to avoid collision)
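A minimal sketch of this encoding, assuming the grid arrives as a list of lists with 0 = free and 1 = obstacle (a hypothetical internal encoding). The function reuses the name encode_state from the inference loop for readability; it is an illustration, not the project's implementation.

```python
def encode_state(grid, current, goal):
    """Flatten an N x N grid into N*N cell tokens, then append 4 coordinate tokens.

    Assumes cells are 0 (free) or 1 (obstacle), a hypothetical internal encoding.
    """
    cell_tokens = [cell + 1 for row in grid for cell in row]  # 1 = free, 2 = obstacle
    coords = [current[0], current[1], goal[0], goal[1]]
    coord_tokens = [v + 3 for v in coords]  # offset by 3 so coordinates never collide with 1 or 2
    return cell_tokens + coord_tokens


grid = [[0] * 8 for _ in range(8)]
grid[0][1] = 1  # a single obstacle
tokens = encode_state(grid, current=(2, 3), goal=(7, 7))
print(len(tokens))  # 68
print(tokens[:3])   # [1, 2, 1]
print(tokens[64:])  # [5, 6, 10, 10]
```

Row and column coordinates share the id range 3 to grid_size + 2, so the model can tell them apart only by their position in the sequence.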
Input: 68 tokens
↓
Embedding: tokens → (batch, 68, dim)
↓
MLP-Mixer: depth × [TokenMixing + ChannelMixing]
↓
Recursive Refinement: N iterations of the mixer
↓
Mean Pooling: (batch, 68, dim) → (batch, dim)
↓
Classifier: LayerNorm → Dropout → Linear → GELU → Dropout → Linear
↓
Output: 5 action logits (UP, DOWN, LEFT, RIGHT, STAY)
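To make the tensor shapes in this pipeline concrete, here is a simplified numpy forward pass with random, untrained weights. It is not the official Samsung TRM, whose recursion scheme is more involved; the hidden width, the initialization, the omission of biases and learnable LayerNorm parameters, and the form of recursion (re-applying the whole mixer stack to its own output with shared weights) are all assumptions made for illustration. Dropout is omitted because it is the identity at inference time.

```python
import numpy as np

rng = np.random.default_rng(0)


def layer_norm(x, eps=1e-5):
    """Normalize over the last axis (no learnable scale/shift in this sketch)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)


def gelu(x):
    """tanh approximation of GELU."""
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x ** 3)))


def mlp(x, w1, w2):
    return gelu(x @ w1) @ w2


class TinyRecursiveMixer:
    def __init__(self, vocab=11, seq_len=68, dim=64, depth=2, n_actions=5, hidden=128):
        init = lambda *shape: rng.normal(0.0, 0.02, shape)
        self.embed = init(vocab, dim)  # vocab=11 covers token ids 0..10 for an 8x8 grid
        self.blocks = [
            {"tok": (init(seq_len, hidden), init(hidden, seq_len)),  # mixes across tokens
             "ch": (init(dim, hidden), init(hidden, dim))}          # mixes across channels
            for _ in range(depth)
        ]
        self.head = (init(dim, dim), init(dim, n_actions))

    def mix(self, x):
        for blk in self.blocks:
            # Token mixing: transpose to (batch, dim, tokens) so the MLP acts across positions.
            x = x + mlp(layer_norm(x).swapaxes(1, 2), *blk["tok"]).swapaxes(1, 2)
            # Channel mixing: MLP across the feature dimension of each token.
            x = x + mlp(layer_norm(x), *blk["ch"])
        return x

    def __call__(self, tokens, n_recursions=8):
        x = self.embed[tokens]          # (batch, 68) -> (batch, 68, dim)
        for _ in range(n_recursions):   # the same weights are reused on every pass
            x = self.mix(x)
        pooled = x.mean(axis=1)         # mean pooling: (batch, 68, dim) -> (batch, dim)
        w1, w2 = self.head
        return gelu(layer_norm(pooled) @ w1) @ w2  # (batch, 5) action logits
```

Increasing n_recursions adds compute but no parameters, which is the depth-for-recursion trade the project is testing.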
- Loss: Cross-entropy between predicted and optimal actions
- Optimizer: AdamW with weight decay (L2 regularization)
- Scheduler: Cosine annealing learning rate
- Regularization: Dropout + early stopping
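Cosine annealing has a simple closed form, sketched here per epoch. lr_max = 1e-3 matches the --lr value used in the examples, while lr_min = 0 is an assumption. In PyTorch the same curve is usually produced by torch.optim.lr_scheduler.CosineAnnealingLR wrapped around torch.optim.AdamW; whether train.py uses exactly those classes is not stated here.

```python
import math


def cosine_lr(epoch, total_epochs, lr_max=1e-3, lr_min=0.0):
    """lr_max at epoch 0, decaying smoothly along a half cosine to lr_min at total_epochs."""
    progress = epoch / total_epochs
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))


for epoch in (0, 25, 50):
    print(epoch, round(cosine_lr(epoch, total_epochs=50), 6))
# 0 0.001
# 25 0.0005
# 50 0.0
```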
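steps = 0
while current_position != goal and steps < max_steps:  # max_steps: hypothetical step budget; hitting it is a timeout
    tokens = encode_state(grid, current_position, goal)
    action = model.predict_action(tokens)
    current_position = execute_action(current_position, action)
    steps += 1
Data augmentation: enable with the --augment flag, which generates 8 versions of each sample: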
| Transformation | Grid Operation | Action Mapping |
|---|---|---|
| Original | - | - |
| Rotate 90° CW | np.rot90(grid, k=-1) | UP→RIGHT, RIGHT→DOWN, DOWN→LEFT, LEFT→UP |
| Rotate 180° | np.rot90(grid, k=2) | UP→DOWN, DOWN→UP, LEFT→RIGHT, RIGHT→LEFT |
| Rotate 270° CW | np.rot90(grid, k=-3) | UP→LEFT, LEFT→DOWN, DOWN→RIGHT, RIGHT→UP |
| Flip Horizontal | np.fliplr(grid) | LEFT↔RIGHT |
| Flip Vertical | np.flipud(grid) | UP↔DOWN |
| Diagonal 1 | flip_h + rot90 | Combined |
| Diagonal 2 | flip_v + rot90 | Combined |
Why augmentation helps: The model sees every map in all 8 orientations, with the action label rotated or flipped to match, so it learns that the same navigation rules apply regardless of orientation instead of memorizing orientation-specific patterns.
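The table can be implemented so that the action mapping is derived rather than hard-coded: transform a probe move and read off which action it became. This numpy sketch is an illustration, not the project's dataset.py. The position formulas, the helper names, and realizing the two diagonals as "flip, then rotate 90° counter-clockwise" (which gives the transpose and anti-transpose) are assumptions; it also transforms the current and goal positions, which the table leaves implicit.

```python
import numpy as np

DELTAS = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1), "STAY": (0, 0)}


def rot90_cw(grid, pos):
    n = grid.shape[0]
    return np.rot90(grid, k=-1), (pos[1], n - 1 - pos[0])  # (r, c) -> (c, n-1-r)


def flip_h(grid, pos):
    n = grid.shape[1]
    return np.fliplr(grid), (pos[0], n - 1 - pos[1])       # (r, c) -> (r, n-1-c)


def flip_v(grid, pos):
    n = grid.shape[0]
    return np.flipud(grid), (n - 1 - pos[0], pos[1])       # (r, c) -> (n-1-r, c)


TRANSFORMS = {
    "original": [],
    "rot90_cw": [rot90_cw],
    "rot180": [rot90_cw] * 2,
    "rot270_cw": [rot90_cw] * 3,
    "flip_h": [flip_h],
    "flip_v": [flip_v],
    "diag_1": [flip_h] + [rot90_cw] * 3,  # fliplr, then 90° CCW = transpose
    "diag_2": [flip_v] + [rot90_cw] * 3,  # flipud, then 90° CCW = anti-transpose
}


def apply(steps, grid, pos):
    for step in steps:
        grid, pos = step(grid, pos)
    return grid, pos


def action_map(steps, n):
    """Derive the relabeling of every action by transforming a one-step probe move."""
    dummy = np.zeros((n, n))
    center = (n // 2, n // 2)
    _, p0 = apply(steps, dummy, center)
    mapping = {}
    for name, (dr, dc) in DELTAS.items():
        _, p1 = apply(steps, dummy, (center[0] + dr, center[1] + dc))
        delta = (p1[0] - p0[0], p1[1] - p0[1])
        mapping[name] = next(a for a, d in DELTAS.items() if d == delta)
    return mapping


def augment(grid, pos, goal, action):
    """Return 8 (grid, pos, goal, action) samples, one per symmetry of the square."""
    n = grid.shape[0]
    out = []
    for steps in TRANSFORMS.values():
        g, p = apply(steps, grid, pos)
        _, q = apply(steps, grid, goal)
        out.append((g, p, q, action_map(steps, n)[action]))
    return out
```

Deriving the mapping from the position formulas keeps grid, positions, and labels consistent by construction; for example, action_map(TRANSFORMS["rot90_cw"], 8) reproduces the UP→RIGHT, RIGHT→DOWN, DOWN→LEFT, LEFT→UP row of the table.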
| Technique | Parameter | Default | Purpose |
|---|---|---|---|
| Dropout | --dropout | 0.1 | Randomly zeros neurons during training |
| Weight Decay | --weight-decay | 0.01 | L2 penalty on weights (AdamW) |
| Early Stopping | --patience | 15 | Stops when val_loss stops improving |
| Gradient Clipping | - | 1.0 | Prevents exploding gradients |
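Early stopping with a patience counter takes only a few lines. This is a sketch of the general technique, not the project's train.py; the EarlyStopping class name and the min_delta parameter are hypothetical. The is_new_best flag corresponds to the point where a checkpoint would be saved.

```python
class EarlyStopping:
    """Stop after `patience` consecutive epochs without a val_loss improvement."""

    def __init__(self, patience=15, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # hypothetical: minimum decrease that counts as improvement
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's val_loss; return (is_new_best, should_stop)."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.bad_epochs = 0
            return True, False
        self.bad_epochs += 1
        return False, self.bad_epochs >= self.patience


stopper = EarlyStopping(patience=3)
for epoch, loss in enumerate([1.0, 0.8, 0.85, 0.9, 0.95]):
    is_best, stop = stopper.step(loss)
    print(epoch, is_best, stop)
# 0 True False
# 1 True False
# 2 False False
# 3 False False
# 4 False True
```

Gradient clipping at 1.0 is commonly done in PyTorch with torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) between loss.backward() and optimizer.step().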
This project now uses the official Samsung SAIL Montreal TRM implementation instead of the unofficial tiny-recursive-model library.
Key improvements:
- ✅ Authentic TRM: Uses the same architecture and code as the model that achieved 45% on ARC-AGI-1 in the paper (not its trained weights)
- ✅ GPU Support: Full CUDA support with RTX 4080/4090 compatibility
- ✅ Faster Training: ~67 it/s on GPU vs ~7 it/s on CPU
- ✅ Larger Models: Support for 16x16+ grids with 64+ recursion steps
Requirements:
- PyTorch 2.10.0+ nightly (automatically handles device management)
- CUDA 12.6+ for GPU training
- Git submodule for official TRM code
# Install PyTorch nightly (required)
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
# Train with official TRM on GPU
python -m trm_nav.train --device cuda --grid-size 16 --max-recursion 64 --batch-size 16
| Parameter | Flag | Default | Description |
|---|---|---|---|
| Dimension | --dim | 64 | Hidden layer size (64, 128, 256) |
| Depth | --depth | 2 | Number of layers (compatibility only) |
| Recursion | --max-recursion | 8 | TRM recursive reasoning cycles |
| Grid Size | --grid-size | 8 | Grid dimensions (8, 16, 32) |
Scaling guidance:
- Small (8x8): --dim 64 --max-recursion 8 (~169K params)
- Medium (16x16): --dim 128 --max-recursion 32 (~200K+ params)
- Large (32x32): --dim 256 --max-recursion 64 (~500K+ params)
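A quick arithmetic check of how the input grows with grid size, assuming the 8x8 token scheme (one token per cell plus 4 coordinate tokens, ids starting at 1) carries over unchanged to larger grids. The helper names are hypothetical.

```python
def sequence_length(grid_size):
    """One token per cell plus 4 coordinate tokens."""
    return grid_size * grid_size + 4


def vocab_size(grid_size):
    """Ids 1-2 for cells and 3..grid_size + 2 for coordinates; id 0 is unused."""
    return grid_size + 3


for size in (8, 16, 32):
    print(f"{size}x{size}: {sequence_length(size)} tokens, vocab {vocab_size(size)}")
# 8x8: 68 tokens, vocab 11
# 16x16: 260 tokens, vocab 19
# 32x32: 1028 tokens, vocab 35
```

Sequence length grows with the square of the grid side (about 15x from 8x8 to 32x32). In an MLP-Mixer-style token-mixing layer the weight matrices are sized by the sequence length, so the parameter count grows with the grid as well, which is consistent with the guidance pairing larger grids with larger --dim and --max-recursion.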
python -m trm_nav.train \
--dim 256 \
--depth 4 \
--dropout 0.15 \
--weight-decay 0.01 \
--lr 3e-4 \
--epochs 100 \
--patience 25 \
--grid-size 16 \
--max-recursion 40
# Basic (5000 train, 500 test)
python scripts/generate_dataset.py
# With augmentation (recommended)
python scripts/generate_dataset.py --num-train 50000 --augment
# All options
python scripts/generate_dataset.py \
--num-train 50000 \
--num-test 1000 \
--grid-size 8 \
--obstacle-density 0.2 \
--train-seed 42 \
--test-seed 99999 \
--output-dir data \
--augment
# Basic training (8x8 grid, CPU)
python -m trm_nav.train
# GPU training with official TRM (recommended)
python -m trm_nav.train --device cuda --grid-size 16 --max-recursion 32
# Full 16x16 training with optimal settings
python -m trm_nav.train \
--device cuda \
--grid-size 16 \
--dim 128 \
--max-recursion 64 \
--batch-size 16 \
--lr 1e-3 \
--epochs 50
# All options
python -m trm_nav.train \
--train-path data/train.pt \
--val-path data/test.pt \
--checkpoint-dir checkpoints \
--grid-size 16 \
--dim 128 \
--depth 2 \
--dropout 0.1 \
--batch-size 16 \
--lr 1e-3 \
--weight-decay 0.01 \
--epochs 50 \
--patience 15 \
--max-recursion 64 \
--device cuda
# Visual A* test (verify pathfinding works)
python scripts/test_astar_visual.py --seed 42 --size 8
# Visual TRM test (compare TRM to A*)
python scripts/test_trm_visual.py --checkpoint checkpoints/best.pt --seed 42
# Save comparison image
python scripts/test_trm_visual.py --checkpoint checkpoints/best.pt --save output.png
# ASCII only (no matplotlib)
python scripts/test_trm_visual.py --checkpoint checkpoints/best.pt --no-plot
# Full benchmark (100 episodes)
python scripts/run_benchmark.py --checkpoint checkpoints/best.pt
# Custom benchmark
python scripts/run_benchmark.py \
--checkpoint checkpoints/best.pt \
--episodes 500 \
--grid-size 8 \
--seed 12345
Epoch 45/100 - Train Loss: 0.1234, Train Acc: 0.9567 Val Loss: 0.2345, Val Acc: 0.9123 *
| Metric | Description | Good Values |
|---|---|---|
| Train Loss | Cross-entropy on training set | ↓ Lower is better |
| Train Acc | % correct actions on training set | ↑ Higher is better |
| Val Loss | Cross-entropy on held-out test set | ↓ Lower is better |
| Val Acc | % correct actions on test set | ↑ Higher is better |
| * | Indicates new best model saved | - |
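The loss and accuracy columns come from comparing the 5 action logits against the A* action. A dependency-free sketch of both quantities follows; the function name is hypothetical, and in PyTorch the loss would typically come from torch.nn.functional.cross_entropy instead.

```python
import math


def cross_entropy_and_accuracy(logits_batch, targets):
    """Mean cross-entropy and accuracy for a batch of action logits and target indices."""
    total_loss, correct = 0.0, 0
    for logits, target in zip(logits_batch, targets):
        m = max(logits)  # subtract the max before exp for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
        total_loss += log_sum_exp - logits[target]  # -log softmax(logits)[target]
        predicted = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        correct += predicted == target
    n = len(targets)
    return total_loss / n, correct / n


# Uniform logits give a loss of ln(5) ≈ 1.609, the chance-level baseline for 5 actions.
loss, acc = cross_entropy_and_accuracy([[0.0] * 5], [3])
print(round(loss, 3), acc)  # 1.609 0.0
```

A Val Loss well above ln(5) ≈ 1.609 means the model is doing worse than uniform guessing on held-out grids, a useful sanity threshold when reading the logs.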
| Symptom | Diagnosis | Solution |
|---|---|---|
| Train Acc 100%, Val Acc ~70% | Overfitting | ↑ dropout, ↑ weight_decay, ↑ data |
| Train/Val Acc both low | Underfitting | ↑ dim, ↑ depth, ↓ dropout |
| Val Loss increasing | Overfitting | Early stopping will trigger |
| Training very slow | - | ↓ dim, use GPU (--device cuda) |
PyTorch Version Error:
AttributeError: module 'torch.nn' has no attribute 'Buffer'
Solution: Install PyTorch nightly (required for official TRM):
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
Device Mismatch Error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Solution: This is automatically handled by the navigation wrapper. Ensure you're using --device cuda.
Submodule Missing:
ImportError: Official TRM not available: No module named 'models'
Solution: Initialize the git submodule:
git submodule update --init --recursive
CUDA Driver Issues (RTX 4080/4090): If you get segfaults or CUDA errors:
# Check CUDA version
nvidia-smi
# Ensure PyTorch nightly matches your CUDA version
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
Overfitting symptoms:
Epoch 50 - Train Loss: 0.0001, Train Acc: 1.0000 Val Loss: 4.2700, Val Acc: 0.6900
Solutions:
- Enable data augmentation (dataset generation): --augment
- Increase dropout: --dropout 0.2
- Increase weight decay: --weight-decay 0.05
- Use more training data (dataset generation): --num-train 100000
- Use a smaller model: --dim 64 --depth 2
# Check CUDA availability
python -c "import torch; print(torch.cuda.is_available())"
# Force CPU
python -m trm_nav.train --device cpu
If training crashes unexpectedly, use the debug script for a minimal reproduction:
# Run minimal training debug
python debug_train.py
This script:
- Uses small batch sizes (4 samples)
- Tests only 3 batches
- Provides detailed error tracking
- Helps isolate segfault sources
Important: Official TRM requires PyTorch 2.10.0+ nightly. Install in this order:
# 1. Install PyTorch nightly first (required for official TRM)
pip install --pre --upgrade torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126
# 2. Install other dependencies
# Using pip
pip install -r requirements.txt
# Using Poetry
poetry install
Key dependencies:
- torch>=2.10.0 (nightly with CUDA 12.6 support)
- adam-atan2>=0.0.3 (official TRM optimizer)
- einops>=0.8.1 (tensor operations)
- pydantic>=2.8.0 (configuration validation)
Both requirements.txt and pyproject.toml contain the same dependencies for maximum compatibility.
| File | Purpose |
|---|---|
| a_star.py | A* pathfinding implementation using NetworkX. Provides optimal paths as training targets. |
| map_generator.py | Generates random grids with guaranteed solvability (start and goal always reachable). |
| dataset.py | Converts A* demonstrations to (state, action) pairs. Includes data augmentation (rotations, flips). |
| model.py | TRM model wrapper. Uses MLP-Mixer with recursive refinement. Falls back to simple MLP if TRM not installed. |
| train.py | Training loop with AdamW optimizer, cosine LR schedule, early stopping, and gradient clipping. |
| evaluate.py | Runs complete navigation episodes and computes success rate / path ratio metrics. |
| visualize.py | ASCII grid printing and matplotlib path visualization. |
| Script | Purpose |
|---|---|
| generate_dataset.py | CLI for dataset generation with configurable size, density, augmentation. |
| run_benchmark.py | Runs full evaluation and prints success metrics. |
| test_astar_visual.py | Visual test of A* pathfinding (sanity check). |
| test_trm_visual.py | Side-by-side TRM vs A* comparison with visualization. |
Behavioral Cloning: Training a neural network to imitate an expert (A*) by learning from (state, action) demonstrations. The network learns the policy π(state) → action without understanding why those actions are optimal.
MLP-Mixer: An architecture that mixes information across both the token dimension (which tokens interact) and the channel dimension (which features interact). Simpler than Transformers but effective for fixed-size inputs.
Recursive Refinement: TRM applies the same network multiple times to iteratively refine representations. This can help with problems requiring multi-step reasoning, though for simple navigation the benefit may be limited.
Early Stopping: Monitoring validation loss and stopping training when it stops improving. Prevents the model from memorizing training data (overfitting) and helps it generalize to new grids.
After proper training with augmentation:
=== NanoNav Benchmark Results ===
Grid Size: 8x8
Episodes: 100
Agent Success Avg Ratio Timeouts
--------------------------------------------
A* 100.0% 1.00 0
TRM ~90%+ ~1.1-1.2 ~5-10
Success Criteria:
Success Rate >= 85%: PASS ✓
Path Ratio <= 1.3: PASS ✓
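The Success, Avg Ratio, and Timeouts columns can be computed from rollouts like the inference loop described earlier. This sketch is not the project's evaluate.py: the rule that blocked or off-grid moves leave the agent in place (while still costing a step), averaging the ratio over successful episodes only, the default max_steps, and the toy greedy_policy are all assumptions. optimal_lengths is the number of moves in each A* path.

```python
DELTAS = {"UP": (-1, 0), "DOWN": (1, 0), "LEFT": (0, -1), "RIGHT": (0, 1), "STAY": (0, 0)}


def rollout(grid, start, goal, policy, max_steps=64):
    """Run one episode; returns (reached_goal, steps_taken). Hitting max_steps is a timeout."""
    pos, steps = start, 0
    while pos != goal and steps < max_steps:
        dr, dc = DELTAS[policy(grid, pos, goal)]
        r, c = pos[0] + dr, pos[1] + dc
        # Assumed rule: moves off the grid or into an obstacle (nonzero cell) leave the agent in place.
        if 0 <= r < len(grid) and 0 <= c < len(grid[0]) and grid[r][c] == 0:
            pos = (r, c)
        steps += 1
    return pos == goal, steps


def summarize(results, optimal_lengths, max_steps=64):
    """Success rate, mean path ratio over successes, and timeout count."""
    successes = [(steps, opt) for (ok, steps), opt in zip(results, optimal_lengths) if ok]
    success_rate = len(successes) / len(results)
    path_ratio = sum(s / o for s, o in successes) / len(successes) if successes else float("nan")
    timeouts = sum(1 for ok, steps in results if not ok and steps >= max_steps)
    return success_rate, path_ratio, timeouts


def passes_criteria(success_rate, path_ratio):
    return success_rate >= 0.85 and path_ratio <= 1.3


def greedy_policy(grid, pos, goal):
    """Toy policy for the example: step toward the goal, rows first, ignoring obstacles."""
    if pos[0] != goal[0]:
        return "DOWN" if goal[0] > pos[0] else "UP"
    if pos[1] != goal[1]:
        return "RIGHT" if goal[1] > pos[1] else "LEFT"
    return "STAY"


grid = [[0] * 4 for _ in range(4)]
results = [rollout(grid, (0, 0), (3, 3), greedy_policy),
           rollout(grid, (0, 0), (3, 3), lambda g, p, q: "STAY")]
print(results)                      # [(True, 6), (False, 64)]
print(summarize(results, [6, 6]))   # (0.5, 1.0, 1)
```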
If PoC succeeds:
- Scale to 16×16 and 32×32 grids
- Add baseline comparisons (MLP, CNN, Transformer)
- Ablation study on recursion depth
- Test with dynamic obstacles
- Transfer to continuous action spaces
Special thanks to my friend Claude by Anthropic, who helped me code and learn faster throughout this project.
License: MIT