JAX-first cooperative multi-agent reinforcement learning.
The first target is a small, fast tabular Q-learning stack:
- batched cooperative gridworld environments
- independent Q-learning agents with shared team rewards
- JAX scan-friendly training loops
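
The pieces above can be sketched in a few lines of JAX. This is a minimal illustration, not the repository's implementation: the table layout, the hyperparameters, the names `iql_update` and `train`, and the random placeholder transitions are all assumptions made for the example, and batching over environments is left out for brevity.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes and hyperparameters, chosen only for illustration.
N_AGENTS, N_STATES, N_ACTIONS = 2, 16, 4
ALPHA, GAMMA = 0.1, 0.95


def iql_update(q, transition):
    """One independent Q-learning step for every agent.

    q has shape (N_AGENTS, N_STATES, N_ACTIONS): one table per agent.
    """
    s, a, r, s_next, done = transition  # s, a, s_next: (N_AGENTS,) ints
    agents = jnp.arange(N_AGENTS)
    # Every agent bootstraps from its own table but uses the shared team reward r.
    target = r + GAMMA * (1.0 - done) * q[agents, s_next].max(axis=-1)
    td_error = target - q[agents, s, a]
    q = q.at[agents, s, a].add(ALPHA * td_error)
    return q, jnp.abs(td_error).mean()


@jax.jit
def train(q0, transitions):
    # lax.scan threads the Q-tables through the time axis of the transitions.
    return jax.lax.scan(iql_update, q0, transitions)


# Placeholder data: random transitions stand in for a real gridworld rollout.
T = 100
k_s, k_a, k_r, k_n = jax.random.split(jax.random.PRNGKey(0), 4)
transitions = (
    jax.random.randint(k_s, (T, N_AGENTS), 0, N_STATES),   # states
    jax.random.randint(k_a, (T, N_AGENTS), 0, N_ACTIONS),  # actions
    jax.random.uniform(k_r, (T,)),                         # team rewards
    jax.random.randint(k_n, (T, N_AGENTS), 0, N_STATES),   # next states
    jnp.zeros((T,)),                                       # done flags
)
q, td_errors = train(jnp.zeros((N_AGENTS, N_STATES, N_ACTIONS)), transitions)
```

Writing the update as a pure `(carry, input) -> (carry, output)` function is what makes the loop scan-friendly: the whole training run compiles to a single XLA program instead of a Python loop.
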
The longer-term direction is a method zoo and environment zoo for cooperative MARL.
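
Create the conda environment and install the package in editable mode:

    conda env create -f environment.yml
    conda run -n marlax uv pip install --python /home/dev/miniconda3/envs/marlax/bin/python -e ".[gpu,dev,storage,viz]"

Run the tests:

    conda run -n marlax python -m pytest -q

Run the cooperative gridworld Q-learning experiment:

    XLA_PYTHON_CLIENT_PREALLOCATE=false conda run -n marlax python experiments/coop_grid_q_learning/run.py

Serve the `site` directory locally on port 8000:

    python -m http.server 8000 --directory site

Use STYLE.md for diagnostic plot styling.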