
LeWorldModel (PushT)

Optimize a JEPA world model for the PushT task: minimize validation loss while staying under 20M trainable parameters and 10 minutes of training time.

Setup

  1. Read the in-scope files:
    • jepa.py — JEPA model architecture. You modify this.
    • module.py — Transformer blocks, attention, embedder, predictor. You modify this.
    • train.py — Training script. You modify this.
    • utils.py — Image preprocessing, normalization, checkpointing utilities. You modify this.
    • config/train/lewm.yaml — Training hyperparameters. You modify this.
    • config/train/data/pusht.yaml — Dataset config. You modify this.
    • eval/eval.sh — Runs evaluation. Do not modify.
    • prepare.sh — Downloads data and installs deps. Do not modify.
  2. Run prepare: run bash prepare.sh to install dependencies and download the PushT dataset.
  3. Verify data exists: Check that $STABLEWM_HOME (defaults to ~/.stable-wm/) contains pusht_expert_train.h5.
  4. Initialize results.tsv: Create results.tsv with just the header row.
  5. Run baseline: bash eval/eval.sh to establish the starting score.
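Steps 3 and 4 can be scripted. The sketch below relies only on what those steps state: the dataset file name, the STABLEWM_HOME default, and the results.tsv columns listed under the experiment loop. The helper names (dataset_path, init_results, check_setup) are hypothetical, and the tab-separated header is an assumption based on the .tsv extension.

```python
import os
from pathlib import Path

# Columns from the experiment-loop instructions; tab-separated (assumed from .tsv).
HEADER = ["experiment_name", "val_loss", "params"]


def dataset_path() -> Path:
    """Return the expected location of the PushT training file."""
    # $STABLEWM_HOME defaults to ~/.stable-wm/ when it is not set.
    home = os.environ.get("STABLEWM_HOME", "~/.stable-wm")
    return Path(home).expanduser() / "pusht_expert_train.h5"


def init_results(path: Path = Path("results.tsv"), overwrite: bool = False) -> bool:
    """Write results.tsv with only the header row; returns False if it was left alone."""
    if path.exists() and not overwrite:
        return False  # keep earlier experiment rows unless told otherwise
    path.write_text("\t".join(HEADER) + "\n")
    return True


def check_setup(results: Path = Path("results.tsv")) -> Path:
    """Raise if the dataset is missing; otherwise make sure results.tsv exists."""
    data = dataset_path()
    if not data.is_file():
        raise FileNotFoundError(f"missing {data}; did bash prepare.sh finish?")
    init_results(results)
    return data
```

Calling check_setup() once after prepare.sh covers both steps; init_results refuses to clobber an existing file so rerunning it does not erase logged experiments.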

The benchmark

The PushT task requires a world model to predict future latent states given a sequence of observations and actions. The model is trained on expert demonstrations of a T-shaped block pushing task. The primary metric is validation prediction loss — how well the model predicts next-state embeddings on held-out data. The baseline uses a ViT-tiny encoder (~15M params) with a 6-layer conditional transformer predictor.
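To make the prediction target concrete, here is a minimal sketch of how (history, actions, next embedding) examples could be cut from one episode. It assumes each frame has already been encoded to an embedding z_t, that action a_t takes frame t to frame t+1, and that the predictor sees a fixed-length history. The function name and the history parameter are hypothetical and not taken from train.py or the dataset config.

```python
from typing import List, Sequence, Tuple

Vec = List[float]


def prediction_windows(
    z: Sequence[Vec], a: Sequence[Vec], history: int
) -> List[Tuple[List[Vec], List[Vec], Vec]]:
    """Build (past embeddings, past actions, next embedding) examples.

    z holds T + 1 frame embeddings z_0..z_T and a holds T actions a_0..a_{T-1}.
    Example t uses up to `history` steps ending at t and targets z_{t+1}.
    """
    if len(z) != len(a) + 1:
        raise ValueError("need exactly one more embedding than actions")
    if history < 1:
        raise ValueError("history must be >= 1")
    examples = []
    for t in range(len(a)):
        start = max(0, t - history + 1)  # clip the window at the episode start
        examples.append((list(z[start : t + 1]), list(a[start : t + 1]), z[t + 1]))
    return examples
```

A causal transformer predictor typically scores every position of a window in one pass with an attention mask instead of looping like this; the loop only shows which embedding each prediction is compared against.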

Experimentation

What you CAN do:

  • Modify jepa.py, module.py, train.py, utils.py, and config files under config/train/.
  • Change the model architecture (encoder, predictor, projector, embedder).
  • Change training hyperparameters (learning rate, batch size, optimizer, scheduler, etc.).
  • Change the loss function or regularization.
  • Add new modules or training techniques.
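As one example of a scheduler change, the sketch below computes a learning rate that warms up linearly and then follows a cosine decay. This is a common choice, not necessarily what train.py uses; lr_at and its arguments are hypothetical.

```python
import math


def lr_at(
    step: int,
    total_steps: int,
    base_lr: float,
    warmup_steps: int = 0,
    min_lr: float = 0.0,
) -> float:
    """Linear warmup to base_lr, then cosine decay to min_lr at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp from base_lr / warmup_steps up to base_lr.
        return base_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    # Clamp so steps past total_steps stay at min_lr.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch the value would be applied each step by setting g["lr"] for every g in optimizer.param_groups. Because runs are bounded by wall-clock time rather than a fixed step count, another option is to derive progress from elapsed seconds over the 10-minute budget so the decay completes before the run is killed.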

What you CANNOT do:

  • Modify eval/, prepare.sh, or test data.
  • Exceed 20M trainable parameters — the eval will fail.
  • Exceed 10 minutes of training time; training is killed at the 10-minute mark.
  • Remove the val_loss.txt output — the eval script reads this file.
  • Use pre-trained weights or download external models.
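Both budgets are easiest to respect with explicit checks in train.py. The sketch below is hypothetical: the constant and function names, the 60-second safety margin, and the single-number format of val_loss.txt are assumptions (check eval/eval.sh for the path and format it actually reads). In PyTorch the trainable count to pass in is usually sum(p.numel() for p in model.parameters() if p.requires_grad).

```python
import time
from pathlib import Path

PARAM_LIMIT = 20_000_000  # trainable-parameter cap enforced by the eval
TIME_LIMIT_S = 10 * 60    # training is killed after 10 minutes


def check_param_budget(n_trainable: int) -> None:
    """Fail before training starts if the model is over the cap."""
    if n_trainable > PARAM_LIMIT:
        raise ValueError(f"{n_trainable:,} trainable params exceeds {PARAM_LIMIT:,}")


class TimeBudget:
    """Wall-clock guard that reserves a margin for final validation and file writes."""

    def __init__(self, margin_s: float = 60.0, limit_s: float = TIME_LIMIT_S):
        # monotonic() is unaffected by system clock changes.
        self.deadline = time.monotonic() + limit_s - margin_s

    def expired(self) -> bool:
        return time.monotonic() >= self.deadline


def write_val_loss(val_loss: float, path: Path = Path("val_loss.txt")) -> None:
    # Assumed format: one number on one line; verify against eval/eval.sh.
    path.write_text(f"{val_loss}\n")
```

Create the TimeBudget at the top of the script, break out of the training loop once expired() returns True, then run validation and call write_val_loss. The margin covers the final validation pass and any startup time that may count against the 10 minutes.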

The goal: minimize val_loss. This is the mean squared error between predicted and target embeddings on the validation set. Lower is better.
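For reference, a plain-Python version of the metric, assuming the mean is taken over every element of every predicted embedding (the default reduction="mean" behavior of PyTorch's torch.nn.functional.mse_loss). The function name is hypothetical. If the code instead sums over the embedding dimension before averaging, val_loss scales with the embedding width, so numbers are only comparable across runs that use the same reduction.

```python
from typing import Sequence


def embedding_mse(
    pred: Sequence[Sequence[float]], target: Sequence[Sequence[float]]
) -> float:
    """Mean squared error averaged over every element of every embedding."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and the same length")
    total, count = 0.0, 0
    for p, t in zip(pred, target):
        if len(p) != len(t):
            raise ValueError("embedding dimensions differ")
        total += sum((pi - ti) ** 2 for pi, ti in zip(p, t))
        count += len(p)
    if count == 0:
        raise ValueError("embeddings are empty")
    return total / count
```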

Simplicity criterion: All else being equal, simpler is better.

Experiment loop

  1. Modify the code/config.
  2. Run bash eval/eval.sh and check the output.
  3. Log results: append a row to results.tsv with columns: experiment_name, val_loss, params.
  4. Iterate.
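Step 3 can be a small helper. This sketch assumes results.tsv already has the header row from setup and uses tab delimiters; log_result is a hypothetical name.

```python
import csv
from pathlib import Path


def log_result(
    experiment_name: str, val_loss: float, params: int, path: Path = Path("results.tsv")
) -> None:
    """Append one experiment_name / val_loss / params row to results.tsv."""
    with path.open("a", newline="") as f:
        # csv defaults to "\r\n" line endings; use "\n" to match a plain header line.
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow([experiment_name, val_loss, params])
```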

Output format

---
val_loss:         <value>
params:           <value>
total:            1
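To avoid copying numbers by hand, the summary can be parsed from the script's output. This sketch assumes eval/eval.sh prints the key: value lines shown above to stdout; parse_eval_output is a hypothetical helper, and every value is read as a float.

```python
import re
from typing import Dict

# Matches lines such as "val_loss:         <value>" (after stripping whitespace).
_LINE = re.compile(r"^(\w+):\s+(\S+)\s*$")


def parse_eval_output(text: str) -> Dict[str, float]:
    """Collect numeric `key: value` lines (e.g. val_loss, params) from eval output."""
    results: Dict[str, float] = {}
    for line in text.splitlines():
        m = _LINE.match(line.strip())
        if m:
            try:
                results[m.group(1)] = float(m.group(2))
            except ValueError:
                pass  # skip non-numeric values such as a literal "<value>"
    return results
```

For example, pass it subprocess.run(["bash", "eval/eval.sh"], capture_output=True, text=True).stdout, then log the resulting val_loss and params to results.tsv.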