
wezteoh/gameplay-trajectory-diffusion


gameplay-trajectory-diffusion

A denoising diffusion probabilistic model (DDPM) for sketch-guided trajectory simulation: given a sparse, sketch-like set of observed positions across agents and timesteps as conditioning, the model generates a complete multi-agent gameplay simulation.

Figure: sketched conditioning (left) and simulated rollout (right).

The left panel illustrates the sparse conditioning observations passed to the model; the right shows a full gameplay simulation rolled out by the diffusion model.
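
For readers new to DDPMs, the forward (noising) process that the model learns to invert can be sketched as below. This is the textbook DDPM formulation, x_t = sqrt(ᾱ_t)·x_0 + sqrt(1 − ᾱ_t)·ε, with the linear β schedule from the original DDPM paper. The schedule, step count, and trajectory normalization used in this repository may differ, and linear_alpha_bar / q_sample are illustrative names, not functions from src/modules/diffusion.

```python
import numpy as np


def linear_alpha_bar(num_steps: int = 1000, beta_start: float = 1e-4,
                     beta_end: float = 0.02) -> np.ndarray:
    """Cumulative product of (1 - beta_t) for a linear beta schedule (illustrative)."""
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def q_sample(x0: np.ndarray, t: int, alpha_bar: np.ndarray,
             rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]:
    """Draw x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I)."""
    noise = rng.standard_normal(x0.shape)
    xt = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * noise
    return xt, noise


rng = np.random.default_rng(0)
# Hypothetical (T, A, 2) trajectory, assumed already normalized to roughly unit scale.
x0 = rng.standard_normal((30, 11, 2))
alpha_bar = linear_alpha_bar()
xt, eps = q_sample(x0, t=500, alpha_bar=alpha_bar, rng=rng)
print(xt.shape, round(float(alpha_bar[500]), 4))
```

A denoiser is trained to predict eps (or x0) from xt and t; sampling then runs the learned reverse process from pure noise, with the sparse observations supplied as conditioning.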

Setup

uv sync
source .venv/bin/activate

Trajectory data is loaded from .npy files of shape (N, T, A, 2), i.e. (number of sequences, timesteps, agents, court XY). Defaults in configs/data/trajectory_nba_filling.yaml point to data/nba_train.npy and data/nba_test.npy; place the files there or override data.params.train_path / data.params.val_path via Hydra.
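
Before training on your own data, it can help to confirm that a file has the expected layout. A minimal sketch, assuming only the (N, T, A, 2) shape stated above; load_trajectories is a hypothetical helper, not part of this repository, and the sizes in the example are placeholders.

```python
from __future__ import annotations

import tempfile
from pathlib import Path

import numpy as np


def load_trajectories(path: str | Path) -> np.ndarray:
    """Load an .npy trajectory file and check it is shaped (N, T, A, 2)."""
    arr = np.load(path)
    if arr.ndim != 4 or arr.shape[-1] != 2:
        raise ValueError(f"expected shape (N, T, A, 2), got {arr.shape}")
    return arr


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "nba_train.npy"
    # Placeholder sizes: N=4 sequences, T=30 timesteps, A=11 agents.
    np.save(path, np.zeros((4, 30, 11, 2), dtype=np.float32))
    traj = load_trajectories(path)

N, T, A, _ = traj.shape
print(f"{N} sequences, {T} timesteps, {A} agents")
```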

Training

The Hydra entry point is train.py, which uses configs/train_trajectory_filling_ddpm.yaml by default. Select a different config with --config-name, for example:

python train.py --config-name=train_trajectory_filling_ddpm_dit_mixedmask_learnsigma

Each run writes its checkpoints and a resolved config.yaml under checkpoints/<wandb_project>/<run_id>/. The sampling script reads that config.yaml from next to the .ckpt, so keep them together.
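
The pairing rule can be illustrated with a small lookup: given a .ckpt path, the run config is expected in the same directory. This is a hedged sketch of the convention only, not the sampling script's actual code; run_config_for and the project, run, and file names in the example are hypothetical.

```python
from __future__ import annotations

import tempfile
from pathlib import Path


def run_config_for(ckpt_path: str | Path) -> Path:
    """Return the config.yaml stored next to a checkpoint (hypothetical helper)."""
    ckpt = Path(ckpt_path)
    if ckpt.suffix != ".ckpt":
        raise ValueError(f"not a .ckpt file: {ckpt}")
    config = ckpt.parent / "config.yaml"
    if not config.is_file():
        raise FileNotFoundError(
            f"expected {config} alongside {ckpt.name}; keep both in the same run directory"
        )
    return config


with tempfile.TemporaryDirectory() as tmp:
    # Mimic checkpoints/<wandb_project>/<run_id>/ with placeholder names.
    run_dir = Path(tmp) / "checkpoints" / "my_project" / "run_0001"
    run_dir.mkdir(parents=True)
    (run_dir / "last.ckpt").touch()
    (run_dir / "config.yaml").write_text("model: {}\n")
    found = run_config_for(run_dir / "last.ckpt")
    print(found.relative_to(tmp))
```

Moving only the .ckpt to another directory breaks this lookup, which is why the checkpoint and its config.yaml should be copied together.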

Validation masks

For comparable validation metrics across epochs, pregenerate a fixed mask of shape [N_val, T, A] and point the data config at it:

python scripts/pregenerate_trajectory_masks.py \
    --input data/nba_test.npy \
    --output tmp/val_masks_pregen.npy \
    --seed 42

Edit the MASKING dict at the top of scripts/pregenerate_trajectory_masks.py to match your validation protocol (same schema as data.params.masking.train). Then wire it into training:

python train.py data.params.masking.val_mask_path=tmp/val_masks_pregen.npy
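
Pregenerating pays off because a seeded generator reproduces the same observation pattern, so every epoch is scored on identical inputs. The sketch below illustrates the idea with a simple Bernoulli rule. It is an assumption, not the MASKING schema used by scripts/pregenerate_trajectory_masks.py: True is taken to mean "observed", and make_fixed_val_masks, observe_prob, and the sizes are hypothetical.

```python
import numpy as np


def make_fixed_val_masks(n_val: int, num_steps: int, num_agents: int,
                         observe_prob: float = 0.1, seed: int = 42) -> np.ndarray:
    """Hypothetical rule: each (sample, timestep, agent) is observed with prob. observe_prob."""
    rng = np.random.default_rng(seed)
    # Boolean mask of shape [N_val, T, A]; True marks an observed position (assumed convention).
    return rng.random((n_val, num_steps, num_agents)) < observe_prob


masks_a = make_fixed_val_masks(n_val=100, num_steps=30, num_agents=11, seed=42)
masks_b = make_fixed_val_masks(n_val=100, num_steps=30, num_agents=11, seed=42)
same = np.array_equal(masks_a, masks_b)  # same seed -> identical masks across runs
print(masks_a.shape, masks_a.dtype, same)
```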

Sampling

Use scripts/sample_trajectory_ddpm.py with a trained checkpoint:

python scripts/sample_trajectory_ddpm.py \
    --checkpoint checkpoints/<project>/<run_id>/<file>.ckpt \
    --num-samples 8 \
    --save-videos

You can create your own conditioning sketches via this app. To condition on them, pass --input-dir <dir>, where each immediate subdirectory contains a traj.npy (court XY, shape (T, A, 2)) and a mask.npy (shape (T, A)). Outputs are written under tmp/<output-subdir>/<subdir>/.
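
If you produce sketches programmatically rather than with the app, the expected directory layout can be written as sketched below. Only the file names and shapes come from this README; the float32/bool dtypes, the True = observed convention, write_sketch, and the placeholder sizes and positions are assumptions.

```python
from __future__ import annotations

import tempfile
from pathlib import Path

import numpy as np


def write_sketch(input_dir: str | Path, name: str,
                 traj: np.ndarray, mask: np.ndarray) -> Path:
    """Write one sketch as <input_dir>/<name>/traj.npy and mask.npy (hypothetical helper)."""
    if traj.ndim != 3 or traj.shape[-1] != 2:
        raise ValueError(f"traj must have shape (T, A, 2), got {traj.shape}")
    if mask.shape != traj.shape[:2]:
        raise ValueError(f"mask must have shape {traj.shape[:2]}, got {mask.shape}")
    sub = Path(input_dir) / name
    sub.mkdir(parents=True, exist_ok=True)
    np.save(sub / "traj.npy", traj.astype(np.float32))  # assumed dtype
    np.save(sub / "mask.npy", mask.astype(bool))        # assumed: True = observed
    return sub


T, A = 30, 11  # placeholder sequence length and agent count
traj = np.zeros((T, A, 2), dtype=np.float32)
mask = np.zeros((T, A), dtype=bool)
# Placeholder positions; use the same court coordinate units as your training data.
traj[0] = np.random.default_rng(0).uniform(0.0, 1.0, size=(A, 2))
mask[0] = True  # observe every agent at the first timestep only

input_dir = tempfile.mkdtemp(prefix="sketches_")
sub = write_sketch(input_dir, "sketch_000", traj, mask)
print(f"pass --input-dir {input_dir}")
```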

Pretrained models

A model pretrained on the SportVU NBA dataset is available at https://huggingface.co/wezteoh/traj-diffusion

Acknowledgments

The trajectory datasets used in this project are kindly open-sourced by MediaBrain-SJTU/LED. Implementation details in src/modules/diffusion/gaussian_diffusion.py and src/modules/diffusion/diffusion_utils.py heavily reference facebookresearch/DiT.
