You can either use uv:
- CUDA users probably want to enable GPU acceleration:
uv add "jax[cuda12]"
- Then run an example (e.g., the hierarchical Seaquest agent with a fixed meta-policy):
uv run main.py +alg=pqn_jaxtari_sea3_hier_llm

Alternatively, you can use venv:
python3 -m venv .venv
source .venv/bin/activate
python3 -m pip install -U pip
pip3 install -e .
- Optionally, enable CUDA acceleration:
pip3 install -U "jax[cuda12]"
- Run the example:
python3 main.py +alg=pqn_jaxtari_sea3_hier_llm

Alternatively, you can build and run a Docker image (the run command below uses GPU 0):
docker build -t symbol_opt .
docker run -it --rm --gpus device=0 -v "$(pwd)":/app -w /app symbol_opt uv run --active main.py +alg=pqn_jaxtari_sea3_hier_llm

Each environment and algorithm type has its own config file under src/symbolic_options/config/alg/.
For example, experiments with a fixed/symbolic meta-policy have 'hier_llm' in their config name, while those with a soft/neuro-symbolic meta-policy use 'hier_comb'.
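The difference between the two meta-policy types can be sketched in JAX as follows. This is a hypothetical illustration, not the code in src/symbolic_options/: the option indices, the inputs (`oxygen`, `divers_collected`), the thresholds and the additive-prior mixing in `soft_meta_policy` are all assumptions. A fixed meta-policy always returns the option chosen by its rule, whereas a soft one only biases a learned distribution over options, so training can still override the rule.

```python
import jax
import jax.numpy as jnp

# Hypothetical option indices for a Seaquest-like game (illustrative only).
SHOOT_ENEMY, SURFACE_FOR_AIR, RESCUE_DIVER = 0, 1, 2
NUM_OPTIONS = 3


def fixed_meta_policy(oxygen: jnp.ndarray, divers_collected: jnp.ndarray) -> jnp.ndarray:
    """Fixed/symbolic sketch: a hand-written rule picks exactly one option.

    The thresholds (16 oxygen, 6 divers) are assumptions for illustration.
    """
    return jnp.where(
        oxygen < 16,
        SURFACE_FOR_AIR,
        jnp.where(divers_collected < 6, RESCUE_DIVER, SHOOT_ENEMY),
    )


def soft_meta_policy(
    symbolic_option: jnp.ndarray,
    learned_logits: jnp.ndarray,
    prior_weight: float = 2.0,
) -> jnp.ndarray:
    """Soft/neuro-symbolic sketch: the rule's choice is added as a bias to learned logits.

    Returns a probability distribution over options. The additive one-hot prior is
    an assumed mixing scheme, not necessarily what the 'hier_comb' configs use.
    """
    prior = prior_weight * jax.nn.one_hot(symbolic_option, NUM_OPTIONS)
    return jax.nn.softmax(learned_logits + prior)
```

Both functions are pure, so they can be wrapped in `jax.jit` or `jax.vmap` over a batch of environments.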
The corresponding reward functions and meta-policy functions (generated by LLMs and then adapted) are available under src/symbolic_options/reward_functions/.
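For orientation, an option reward function of this kind might look like the sketch below. The function name, its signature and the observation fields it reads (player position, diver positions and an activity mask) are assumptions for illustration; the actual functions in src/symbolic_options/reward_functions/ may use different fields of the object-centric observation.

```python
import jax.numpy as jnp


def reward_approach_nearest_diver(
    player_xy: jnp.ndarray,       # (2,) current player position (hypothetical field)
    prev_player_xy: jnp.ndarray,  # (2,) player position at the previous step
    diver_xy: jnp.ndarray,        # (N, 2) diver positions (hypothetical field)
    diver_active: jnp.ndarray,    # (N,) bool mask of divers currently on screen
) -> jnp.ndarray:
    """Hypothetical option reward: positive when the player moves closer to the nearest active diver."""

    def nearest_distance(pos: jnp.ndarray) -> jnp.ndarray:
        # Euclidean distance from `pos` to every diver; inactive slots are masked with +inf.
        dists = jnp.linalg.norm(diver_xy - pos, axis=-1)
        return jnp.min(jnp.where(diver_active, dists, jnp.inf))

    before = nearest_distance(prev_player_xy)
    after = nearest_distance(player_xy)
    # Progress toward the nearest diver; fall back to 0 when no diver is active
    # (in that case both distances are inf and their difference would be nan).
    return jnp.where(jnp.any(diver_active), before - after, 0.0)
```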
Each experiment is run by passing its config name via +alg=<name>. For instance, the example above runs the fixed meta-policy experiment for Seaquest.
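To see which experiments are available, a small helper along these lines could list the config files and classify them by the 'hier_llm'/'hier_comb' naming convention. It assumes the configs are `.yaml` files in src/symbolic_options/config/alg/; the helper names are hypothetical and not part of the repository.

```python
from pathlib import Path

# Config directory named in this README (relative to the repository root).
CONFIG_DIR = Path("src/symbolic_options/config/alg")


def meta_policy_kind(config_name: str) -> str:
    """Classify an experiment by its config name (naming convention from this README)."""
    if "hier_llm" in config_name:
        return "fixed/symbolic"
    if "hier_comb" in config_name:
        return "soft/neuro-symbolic"
    return "other"


def list_experiments(config_dir: Path = CONFIG_DIR) -> dict[str, str]:
    """Map each config name (file name without the assumed .yaml suffix) to its meta-policy kind."""
    return {p.stem: meta_policy_kind(p.stem) for p in sorted(config_dir.glob("*.yaml"))}


def run_command(config_name: str) -> list[str]:
    """Build the uv command used in the run examples for a given config name."""
    return ["uv", "run", "main.py", f"+alg={config_name}"]
```

From the repository root, `subprocess.run(run_command(name), check=True)` would then launch a single experiment.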
The method itself is implemented in src/symbolic_options/hierarchical_pqn_jaxtari.py and src/symbolic_options/hierarchical_pqn_craftax.py.
Finally, all plots from the paper can be generated with the src/symbolic_options/plots/plots.ipynb notebook.
Note that this requires the wandb run ids of the finished experiments.
We conduct most of our experiments on JAXAtari, a JAX-based reimplementation of the Arcade Learning Environment (ALE).
All our methods utilize the object-centric observations provided by the environment.
Additional experiments were conducted on Craftax-Classic, a JAX-based reimplementation of Crafter. As before, we use the symbolic observations.
PQN implementation from https://github.com/mttga/purejaxql/tree/main
PPO implementation from https://github.com/luchris429/purejaxrl