remunds/symbolic_options

Symbolic Option Learning

Installation

UV project manager

You can install and run the project with uv:

  • CUDA users probably want to enable GPU acceleration:
uv add "jax[cuda12]"
  • Then run an example (e.g. the hierarchical Seaquest agent with a fixed meta-policy):
uv run main.py +alg=pqn_jaxtari_sea3_hier_llm

Python venv

Alternatively, you can use a Python virtual environment (venv):

python3 -m venv .venv
source .venv/bin/activate

python3 -m pip install -U pip
pip3 install -e .
  • Optionally enable CUDA acceleration:
pip3 install -U "jax[cuda12]"
  • Run an example:
python3 main.py +alg=pqn_jaxtari_sea3_hier_llm
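
After installing the CUDA extra, it can be useful to confirm that JAX actually sees the GPU before launching a long run. The following is a minimal sketch, not part of this repository; the helper name `jax_backend_summary` is made up for illustration, and it only relies on the public `jax.devices()` call.

```python
def jax_backend_summary():
    """Return a one-line summary of the devices JAX can see.

    Hypothetical helper for illustration; it is not part of this repository.
    """
    try:
        import jax
    except ImportError:
        return "jax is not installed in this environment"

    devices = jax.devices()
    # Each device reports its platform, e.g. "cpu" or "gpu".
    platforms = sorted({device.platform for device in devices})
    return f"{len(devices)} device(s) on platform(s): {', '.join(platforms)}"


print(jax_backend_summary())
```

If only a cpu platform is listed after installing "jax[cuda12]", the CUDA build is probably not being picked up and training will fall back to the CPU.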

Docker container (CUDA enabled)

docker build -t symbol_opt .
docker run -it --rm --gpus device=0 -v "$(pwd)":/app -w /app symbol_opt uv run --active main.py +alg=pqn_jaxtari_sea3_hier_llm

Reproducing paper experiments

Each combination of environment and algorithm type has its own config file under src/symbolic_options/config/alg/. For example, the fixed/symbolic meta-policy experiments are labelled 'hier_llm', while the soft/neuro-symbolic meta-policy experiments are labelled 'hier_comb'.
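
To get an overview of the available experiments, the config directory can be scanned and grouped by this naming convention. This is only a sketch under assumptions: the helper names are hypothetical, the `*.yaml` extension is a guess about how the config files are stored, and the grouping assumes that 'hier_llm' or 'hier_comb' appears in the file name, as it does in `pqn_jaxtari_sea3_hier_llm`.

```python
from pathlib import Path

# Config directory as named in this README.
CONFIG_DIR = Path("src/symbolic_options/config/alg")


def meta_policy_kind(config_name):
    """Classify a config name by the labels described in this README."""
    if "hier_llm" in config_name:
        return "fixed/symbolic"
    if "hier_comb" in config_name:
        return "soft/neuro-symbolic"
    return "other"


def group_configs(config_dir=CONFIG_DIR, pattern="*.yaml"):
    """Group config names by meta-policy kind.

    The "*.yaml" pattern is an assumption about the file extension.
    Returns an empty dict if the directory does not exist.
    """
    config_dir = Path(config_dir)
    if not config_dir.is_dir():
        return {}
    groups = {}
    for path in sorted(config_dir.glob(pattern)):
        groups.setdefault(meta_policy_kind(path.stem), []).append(path.stem)
    return groups


for kind, names in group_configs().items():
    print(f"{kind}: {', '.join(names)}")
```

Run from the repository root, this prints one line per group; each listed name is a candidate value for the `+alg=` argument.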

The corresponding reward functions and meta-policy functions (generated by LLMs and then adapted) are available under src/symbolic_options/reward_functions/.

Each experiment can be run by passing the name of the corresponding config file via the +alg= argument. For instance, the example command above runs the fixed meta-policy experiment for Seaquest.
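
When several experiments need to be launched in sequence, the command line can be assembled programmatically. The sketch below is illustrative only: `build_command` is a hypothetical helper, and it simply reproduces the `main.py +alg=<config name>` invocation shown in the installation section for either the venv or the uv launcher.

```python
import shlex


def build_command(alg_config, launcher=("python3",)):
    """Build the argument list for one experiment run.

    Hypothetical helper; `launcher` is ("python3",) for the venv setup
    or ("uv", "run") for the uv setup described above.
    """
    return [*launcher, "main.py", f"+alg={alg_config}"]


# The only config name confirmed by this README.
config = "pqn_jaxtari_sea3_hier_llm"

print(shlex.join(build_command(config)))
print(shlex.join(build_command(config, launcher=("uv", "run"))))
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)` from the repository root to start the run.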

The method itself is implemented in src/symbolic_options/hierarchical_pqn_jaxtari.py and src/symbolic_options/hierarchical_pqn_craftax.py.

Finally, all plots from the paper can be generated with the src/symbolic_options/plots/plots.ipynb notebook. Note that this requires the wandb run IDs of the finished experiments.

Environments

We conduct most of our experiments on JAXAtari, the JAX-based reimplementation of the Arcade Learning Environment.

All our methods utilize the object-centric observations provided by the environment.

Additional experiments were conducted on Craftax Classic, the JAX-based reimplementation of Crafter. As before, we use the symbolic observations.

Acknowledgements

  • PQN implementation from https://github.com/mttga/purejaxql/tree/main
  • PPO implementation from https://github.com/luchris429/purejaxrl
