Optimize a JEPA world model for the PushT task — minimize validation loss under model size and training time constraints.
- Read the in-scope files:
  - `jepa.py`: JEPA model architecture. You modify this.
  - `module.py`: Transformer blocks, attention, embedder, predictor. You modify this.
  - `train.py`: Training script. You modify this.
  - `utils.py`: Image preprocessing, normalization, checkpointing utilities. You modify this.
  - `config/train/lewm.yaml`: Training hyperparameters. You modify this.
  - `config/train/data/pusht.yaml`: Dataset config. You modify this.
  - `eval/eval.sh`: Runs evaluation. Do not modify.
  - `prepare.sh`: Downloads data and installs deps. Do not modify.
- Run prepare: `bash prepare.sh` to install dependencies and download the PushT dataset.
- Verify data exists: check that `$STABLEWM_HOME` (defaults to `~/.stable-wm/`) contains `pusht_expert_train.h5`.
- Initialize results.tsv: create `results.tsv` with just the header row.
- Run baseline: `bash eval/eval.sh` to establish the starting score.
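The data check and the `results.tsv` initialization can also be scripted. The following is a minimal sketch that assumes the header uses the three column names from the logging step and is tab-separated (inferred from the `.tsv` extension). `dataset_path`, `check_dataset`, and `init_results` are hypothetical helpers, not part of the repository.

```python
import os
from pathlib import Path

# Column names come from the logging step; tab separation is an assumption
# based on the .tsv extension.
RESULTS_HEADER = ["experiment_name", "val_loss", "params"]


def dataset_path(filename: str = "pusht_expert_train.h5") -> Path:
    """Resolve a dataset file under $STABLEWM_HOME, defaulting to ~/.stable-wm/."""
    root = os.environ.get("STABLEWM_HOME") or "~/.stable-wm"
    return Path(root).expanduser() / filename


def check_dataset(filename: str = "pusht_expert_train.h5") -> Path:
    """Return the dataset path, or raise if prepare.sh has not produced it."""
    path = dataset_path(filename)
    if not path.is_file():
        raise FileNotFoundError(f"{path} not found; run `bash prepare.sh` first")
    return path


def init_results(path: str = "results.tsv") -> None:
    """Create results.tsv with only the header row; leave an existing file alone."""
    p = Path(path)
    if not p.exists():
        p.write_text("\t".join(RESULTS_HEADER) + "\n", encoding="utf-8")
```

`init_results` skips files that already exist, so re-running setup cannot wipe earlier experiment rows.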
The PushT task requires a world model to predict future latent states given a sequence of observations and actions. The model is trained on expert demonstrations of a T-shaped block pushing task. The primary metric is validation prediction loss — how well the model predicts next-state embeddings on held-out data. The baseline uses a ViT-tiny encoder (~15M params) with a 6-layer conditional transformer predictor.
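The 20M-parameter cap limits every architecture change, so it can help to estimate a transformer's size before spending a 10-minute run on it. This rough sketch assumes a standard pre-norm block: a fused QKV projection, an output projection, a two-layer MLP with expansion ratio 4, biases on every linear layer, and two LayerNorms. The conditional predictor blocks in `module.py` may carry extra parameters (for example, action conditioning), so treat the result as an estimate. For the exact figure, assuming the model is a PyTorch `nn.Module`, sum `p.numel()` over `model.parameters()` where `p.requires_grad` is true.

```python
BUDGET = 20_000_000  # trainable-parameter cap enforced by the eval


def block_params(dim: int, mlp_ratio: int = 4) -> int:
    """Parameters in one standard pre-norm transformer block (with biases)."""
    hidden = dim * mlp_ratio
    attn = (dim * 3 * dim + 3 * dim) + (dim * dim + dim)  # fused QKV + output projection
    mlp = (dim * hidden + hidden) + (hidden * dim + dim)  # fc1 + fc2
    norms = 2 * (2 * dim)  # two LayerNorms, each with weight and bias
    return attn + mlp + norms


def stack_params(dim: int, depth: int, mlp_ratio: int = 4) -> int:
    """Parameters in `depth` identical blocks (no embeddings or heads)."""
    return depth * block_params(dim, mlp_ratio)


def fits_budget(total: int, budget: int = BUDGET) -> bool:
    """True if `total` does not exceed the cap."""
    return total <= budget
```

With `mlp_ratio=4`, each block has 12·d² + 13·d parameters. At a hypothetical width of 192, a 6-layer stack comes to 2,669,184 parameters. At width 384, 12 layers need 21,293,568 parameters for the blocks alone, which is already over the cap before any embedder or projector is counted.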
What you CAN do:
- Modify `jepa.py`, `module.py`, `train.py`, `utils.py`, and config files under `config/train/`.
- Change the model architecture (encoder, predictor, projector, embedder).
- Change training hyperparameters (learning rate, batch size, optimizer, scheduler, etc.).
- Change the loss function or regularization.
- Add new modules or training techniques.
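One example of a scheduler change is a linear warmup followed by cosine decay, a widely used choice for short transformer runs. The sketch below is a generic formula, not what `train.py` or `lewm.yaml` currently uses. The function name and every step count and rate shown are hypothetical.

```python
import math


def warmup_cosine_lr(step: int, total_steps: int, base_lr: float,
                     warmup_steps: int = 0, min_lr: float = 0.0) -> float:
    """Linear warmup to base_lr, then cosine decay to min_lr at total_steps."""
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp linearly so the last warmup step reaches base_lr.
        return base_lr * (step + 1) / warmup_steps
    decay_steps = max(1, total_steps - warmup_steps)
    # Fraction of the decay phase completed, clamped so later steps stay at min_lr.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

In PyTorch, the same curve can drive `torch.optim.lr_scheduler.LambdaLR` by passing `base_lr=1.0`, because `LambdaLR` multiplies each parameter group's initial learning rate by the returned factor. Training is capped by wall-clock time, so `total_steps` has to be an estimate of how many steps fit in the budget.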
What you CANNOT do:
- Modify `eval/`, `prepare.sh`, or test data.
- Exceed 20M trainable parameters; the eval will fail.
- Exceed 10 minutes of training time; training is killed after 10 minutes.
- Remove the `val_loss.txt` output; the eval script reads this file.
- Use pre-trained weights or download external models.
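Training is killed at 10 minutes, but `val_loss.txt` must still be written. One option is to have the training loop track its own wall-clock budget and stop early, keeping a safety margin for the final validation pass. The following is a minimal sketch. `TimeBudget`, `run_until_budget`, `write_val_loss`, the 60-second margin, and the single-float file format are all assumptions, not taken from `train.py` or `eval/eval.sh`.

```python
import time
from pathlib import Path


class TimeBudget:
    """Tracks elapsed wall-clock time against a fixed limit minus a safety margin."""

    def __init__(self, limit_s: float = 600.0, margin_s: float = 60.0):
        # Create this as early as possible so startup time is counted too.
        self.deadline = time.monotonic() + limit_s - margin_s

    def remaining(self) -> float:
        return self.deadline - time.monotonic()

    def expired(self) -> bool:
        return self.remaining() <= 0


def run_until_budget(step_fn, max_steps: int, budget: TimeBudget) -> int:
    """Call step_fn(step) until max_steps or the budget runs out; return steps done."""
    steps = 0
    while steps < max_steps and not budget.expired():
        step_fn(steps)
        steps += 1
    return steps


def write_val_loss(value: float, path: str = "val_loss.txt") -> None:
    # Assumed format: one float per file; check what eval/eval.sh actually parses.
    # repr() round-trips the float exactly, so tiny losses are not truncated.
    Path(path).write_text(f"{float(value)!r}\n", encoding="utf-8")
```

Size the margin to cover the final validation pass plus checkpointing. If those steps take longer than the margin, the run can still be killed before `val_loss.txt` exists.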
The goal: minimize `val_loss`, the mean squared error between predicted and target embeddings on the validation set. Lower is better.
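A reference computation of the metric is useful for sanity-checking a change to the loss. This pure-Python sketch sums squared errors and element counts across all validation batches before dividing, so a smaller final batch is not over-weighted. This section does not say whether the eval averages per element, per embedding, or per batch, so the per-element pooling here is an assumption to check against `train.py`. `validation_mse` is a hypothetical helper.

```python
from typing import Iterable, Sequence, Tuple

Vector = Sequence[float]
Batch = Tuple[Sequence[Vector], Sequence[Vector]]  # (predicted, target) embeddings


def validation_mse(batches: Iterable[Batch]) -> float:
    """Element-wise MSE pooled over all validation batches."""
    total_sq = 0.0
    count = 0
    for preds, targets in batches:
        if len(preds) != len(targets):
            raise ValueError("prediction and target batch sizes differ")
        for pred, target in zip(preds, targets):
            if len(pred) != len(target):
                raise ValueError("embedding dimensions differ")
            total_sq += sum((p - t) ** 2 for p, t in zip(pred, target))
            count += len(pred)
    if count == 0:
        raise ValueError("no validation elements")
    return total_sq / count
```

Consider two batches: one with a squared-error sum of 1 over 4 elements, and one with 4 over 2 elements. Pooling gives 5/6 ≈ 0.833, while averaging the two per-batch means gives 1.125. In PyTorch, the pooled version corresponds to accumulating `torch.nn.functional.mse_loss(pred, target, reduction="sum")` and dividing by the summed `target.numel()`.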
Simplicity criterion: All else being equal, simpler is better.
- Modify the code/config.
- Run `bash eval/eval.sh` and check the output.
- Log results: append a row to `results.tsv` with columns `experiment_name`, `val_loss`, `params`.
- Iterate.
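The logging step can be a one-line call after each run. This sketch uses Python's `csv` module with a tab delimiter and writes the header first if the file is missing or empty. The tab separation is inferred from the `.tsv` extension, and `log_result` is a hypothetical helper.

```python
import csv
from pathlib import Path

FIELDS = ["experiment_name", "val_loss", "params"]


def log_result(experiment_name: str, val_loss: float, params: int,
               path: str = "results.tsv") -> None:
    """Append one tab-separated row; write the header first if the file is new or empty."""
    p = Path(path)
    needs_header = not p.exists() or p.stat().st_size == 0
    with p.open("a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        if needs_header:
            writer.writerow(FIELDS)
        writer.writerow([experiment_name, val_loss, params])
```

Using `csv.writer` instead of string joins means an experiment name that happens to contain a tab is quoted rather than silently shifting the columns.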
---
val_loss: <value>
params: <value>
total: 1
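If the lines above are the summary that `eval/eval.sh` prints (`key: value` lines after a `---` separator), the numbers can be extracted programmatically to fill a `results.tsv` row. This sketch assumes exactly that layout. `parse_eval_output` is a hypothetical helper, and because the meaning of `total` is not documented here, every value is kept as a raw string.

```python
def parse_eval_output(text: str) -> dict:
    """Collect `key: value` pairs from the lines after the last '---' separator."""
    lines = text.splitlines()
    separators = [i for i, line in enumerate(lines) if line.strip() == "---"]
    # Earlier log lines may contain colons too, so only parse the final block.
    start = separators[-1] + 1 if separators else 0
    summary = {}
    for line in lines[start:]:
        key, sep, value = line.partition(":")
        if sep and key.strip():
            summary[key.strip()] = value.strip()
    return summary
```

Convert fields only where needed, for example `float(summary["val_loss"])`, since the exact formatting of `params` in the real output is not specified.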