gyaradax is a JAX code for local flux-tube gyrokinetic simulations, based on GKW. It currently provides a differentiable solver for the electrostatic, collisionless Vlasov-Poisson system.
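"Differentiable" means that gradients of simulation outputs with respect to inputs can be taken with jax.grad straight through the time-stepping loop. The toy sketch below shows the idea on a scalar decay equation dy/dt = -k y, integrated with explicit Euler inside jax.lax.scan. final_state is a hypothetical function and is not part of gyaradax. The gradient can be checked against the analytic derivative of y_n = (1 - k dt)^n.

```python
import jax
import jax.numpy as jnp


def final_state(k, y0=1.0, dt=0.01, n_steps=100):
    """Integrate dy/dt = -k * y with explicit Euler and return y at t = n_steps * dt."""

    def step(y, _):
        # One explicit Euler step; the second return value is an unused scan output.
        return y + dt * (-k * y), None

    y_final, _ = jax.lax.scan(step, jnp.asarray(y0, dtype=jnp.float32), None, length=n_steps)
    return y_final


# Sensitivity of the final state to the decay rate, via reverse-mode autodiff.
dydk = jax.grad(final_state)(2.0)

# Analytic check: y_n = (1 - k*dt)**n, so dy_n/dk = -n*dt*(1 - k*dt)**(n - 1).
expected = -100 * 0.01 * (1 - 2.0 * 0.01) ** 99
print(float(dydk), expected)
```

The same mechanism, applied to a full solver, is what makes inverse problems and sensitivity studies possible without hand-written adjoints.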
gyaradax was developed with extensive use of agentic workflows; PROMPT.md contains the prompt used to obtain its initial working version.
See our whitepaper, or the agent notes, for a detailed walkthrough of GKW and this reimplementation.
Create a conda environment with CUDA toolkit and all dependencies:
conda env create -f environment.yml
conda activate gyaradax_env
pip install -e ".[dev]"
This installs gyaradax in editable mode with JAX (CUDA 13, local toolkit), NumPy, and dev tools (pytest, ruff, black). The conda environment provides the CUDA toolkit (>= 13.1), cuDNN, CMake, and a C++ compiler.
The optional CUDA backend provides fused kernels for the linear RHS stencils and the nonlinear Poisson bracket (a graph-captured cuFFT pipeline). It requires a GPU with compute capability >= 8.0 (sm_80, Ampere or newer).
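For reference, the sketch below evaluates a Poisson bracket [a, b] = (da/dx)(db/dy) - (da/dy)(db/dx) pseudo-spectrally in plain JAX. Derivatives are taken in Fourier space, the product is formed in real space, and the 2/3 rule removes aliased modes. It assumes a doubly periodic 2D grid and leaves out gyaradax's normalization factors, sign conventions, and the parallel and velocity-space dimensions. It illustrates the technique and is not the CUDA kernel's exact operator. The function names are hypothetical.

```python
import math

import jax.numpy as jnp


def dealias_mask(nx, ny):
    """2/3-rule mask: keep integer mode numbers |m| <= n // 3 in each direction."""
    mx = jnp.round(jnp.fft.fftfreq(nx) * nx)
    my = jnp.round(jnp.fft.fftfreq(ny) * ny)
    return (jnp.abs(mx) <= nx // 3)[:, None] & (jnp.abs(my) <= ny // 3)[None, :]


def poisson_bracket(a, b, lx=2 * math.pi, ly=2 * math.pi):
    """Pseudo-spectral [a, b] = da/dx * db/dy - da/dy * db/dx on a periodic grid.

    Axis 0 is x and axis 1 is y; both fields are real-valued.
    """
    nx, ny = a.shape
    kx = 2 * math.pi * jnp.fft.fftfreq(nx, d=lx / nx)
    ky = 2 * math.pi * jnp.fft.fftfreq(ny, d=ly / ny)
    mask = dealias_mask(nx, ny)

    a_hat = jnp.fft.fft2(a) * mask
    b_hat = jnp.fft.fft2(b) * mask

    def deriv(f_hat, k, axis):
        # Spectral derivative: multiply by i*k along one axis, then return to real space.
        k = k[:, None] if axis == 0 else k[None, :]
        return jnp.real(jnp.fft.ifft2(1j * k * f_hat))

    bracket = deriv(a_hat, kx, 0) * deriv(b_hat, ky, 1) - deriv(a_hat, ky, 1) * deriv(b_hat, kx, 0)
    # Filter the real-space product once more so only de-aliased modes survive.
    return jnp.real(jnp.fft.ifft2(jnp.fft.fft2(bracket) * mask))


n = 32
x = 2 * math.pi * jnp.arange(n) / n
X, Y = jnp.meshgrid(x, x, indexing="ij")
# For a = sin(x) and b = sin(y), the exact bracket is cos(x) * cos(y).
result = poisson_bracket(jnp.sin(X), jnp.sin(Y))
print(float(jnp.max(jnp.abs(result - jnp.cos(X) * jnp.cos(Y)))))
```

The FFT, multiply, inverse-FFT, product, and filter sequence is the kind of multi-launch pipeline that benefits from fusion and graph capture on the GPU.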
From gyaradax/backends/cuda_kernels/:
mkdir -p _build && cd _build && cmake .. -DCMAKE_BUILD_TYPE=Release && cmake --build . -j$(nproc) && cmake --install . && cd ..
To target a specific GPU architecture (e.g., Ampere, sm_80), pass it when configuring from _build/:
cmake .. -DCMAKE_BUILD_TYPE=Release -DGPU_ARCHITECTURES="80"
CMake prints the detected compute capability, jaxlib version, and CUDA toolkit version. Check that these are correct before building.
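Before building, you can check from Python which visible GPUs meet the sm_80 requirement. This sketch assumes that JAX CUDA device objects expose a compute_capability string such as "8.0". The getattr fallback keeps it safe on CPU-only installs. cc_to_int and gpus_meeting_minimum are hypothetical helpers and are not part of gyaradax.

```python
import jax


def cc_to_int(cc):
    """Convert a compute-capability string such as "8.0" or "8.6" to 80 or 86 (the sm_XX number)."""
    major, minor = cc.split(".")
    return int(major) * 10 + int(minor)


def gpus_meeting_minimum(minimum=80):
    """Return visible JAX devices whose compute capability is at least `minimum`."""
    ok = []
    for dev in jax.devices():
        # Assumed attribute on CUDA devices; CPU devices lack it, hence the getattr default.
        cc = getattr(dev, "compute_capability", None)
        if cc is not None and cc_to_int(cc) >= minimum:
            ok.append(dev)
    return ok


print(gpus_meeting_minimum(80))
```

An empty list means JAX sees no device that satisfies the requirement. In that case the JAX backend remains the option.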
- solver.py: linear and nonlinear terms (I-VIII), RK4 integrator.
- simulate.py: interface for trajectory generation.
- integrals.py: field solvers and flux integrals.
- geometry.py: parsers for GKW geometry files and metric tensor coefficients.
- params.py: configuration pytrees.
- stencils.py: finite-difference stencil definitions.
- diag.py: diagnostics (growth rate, frequency, spectra).
- backends/: backend dispatch (JAX, CUDA). See the CUDA build instructions.
- plot_utils.py: visualization.
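The sketch below illustrates two of these roles on a single linear mode, dphi/dt = (gamma - i*omega) phi. It takes classic fourth-order Runge-Kutta steps inside jax.lax.scan, then recovers gamma and omega from the ratio of successive complex amplitudes. The function names are hypothetical, and so is the convention phi ~ exp((gamma - i*omega) t). gyaradax's integrator and diagnostics act on the full distribution function and may use different sign conventions.

```python
import jax
import jax.numpy as jnp


def rk4_step(rhs, y, dt):
    """Classic fourth-order Runge-Kutta step for dy/dt = rhs(y)."""
    k1 = rhs(y)
    k2 = rhs(y + 0.5 * dt * k1)
    k3 = rhs(y + 0.5 * dt * k2)
    k4 = rhs(y + dt * k3)
    return y + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def integrate(rhs, y0, dt, n_steps):
    """Return the trajectory [y0, y1, ..., y_n] stacked along a new leading axis."""

    def body(y, _):
        y_next = rk4_step(rhs, y, dt)
        return y_next, y_next

    _, traj = jax.lax.scan(body, y0, None, length=n_steps)
    return jnp.concatenate([y0[None], traj])


def growth_rate_and_frequency(phi, dt):
    """Estimate (gamma, omega) from a complex series phi(t) ~ exp((gamma - i*omega) t).

    Uses the last pair of samples; assumes |omega * dt| < pi so the complex log does not wrap.
    """
    log_ratio = jnp.log(phi[-1] / phi[-2])
    return jnp.real(log_ratio) / dt, -jnp.imag(log_ratio) / dt


gamma_true, omega_true, dt = 0.2, 1.5, 0.01
phi0 = jnp.asarray([1.0 + 0.0j], dtype=jnp.complex64)
traj = integrate(lambda phi: (gamma_true - 1j * omega_true) * phi, phi0, dt, 1000)
gamma_est, omega_est = growth_rate_and_frequency(traj[:, 0], dt)
print(float(gamma_est), float(omega_est))
```

Using the ratio of consecutive samples rather than the whole history means the estimate converges once a single eigenmode dominates. That is the usual situation late in a linear run.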
The scripts/run.py script provides a convenient way to execute simulations, supporting single or multiple configuration files, batch execution, and runtime options such as the target device and the number of blocks.
# Run a single configuration
python -u -m scripts.run configs/iteration_13.yaml --device 0
When multiple YAML configuration files are provided and they share the same grid resolution and static parameters, scripts/run.py can automatically batch them with jax.vmap for parallel execution on a single device.
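Batching requires matching shapes and static settings because jax.vmap maps one traced program over a leading batch axis. Every configuration must therefore produce arrays of the same shape and share the same Python-level constants. The toy sketch below stacks two parameter dictionaries into one pytree and vmaps a solver over it. The keys nu and drive, and the run function, are hypothetical; they are not gyaradax's parameter names.

```python
import jax
import jax.numpy as jnp


def run(params, y0, dt=0.01, n_steps=200):
    """Toy relaxation model dy/dt = -nu * y + drive, integrated with explicit Euler.

    dt and n_steps are static Python values shared by every configuration in a batch.
    """

    def step(y, _):
        return y + dt * (-params["nu"] * y + params["drive"]), None

    y_final, _ = jax.lax.scan(step, y0, None, length=n_steps)
    return y_final


# Two hypothetical configurations with the same grid (16 points) and the same static settings.
configs = [{"nu": 0.5, "drive": 1.0}, {"nu": 2.0, "drive": 0.5}]

# Stack matching leaves so every parameter gains a leading batch axis of size 2.
batched_params = jax.tree_util.tree_map(
    lambda *leaves: jnp.stack([jnp.asarray(v, dtype=jnp.float32) for v in leaves]), *configs
)
y0 = jnp.zeros((len(configs), 16), dtype=jnp.float32)

# One compiled program evaluates both configurations in parallel on the same device.
batched_out = jax.jit(jax.vmap(run))(batched_params, y0)
print(batched_out.shape)  # (2, 16)
```

If the grid sizes differed, jnp.stack would fail on mismatched shapes, which is why only compatible configs can share a batch.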
# Run two configurations in parallel on device 0
python -u -m scripts.run configs/adiabatic_a.yaml configs/adiabatic_b.yaml --device 0
Simulations can also be run from Python:
from gyaradax.simulate import gk_from_config, gksimulate
# load yaml and run with IO/checkpointing
df, geometry, params, state, pre = gk_from_config("configs/my_sim.yaml")
df, phi, fluxes, state = gksimulate(
df, geometry, params, state, 1200, pre=pre,
output_dir="outputs", checkpoint_interval=120
)
gyaradax can resume from GKW binary K-files. The simplest way is gk_from_gkw_dir, which automatically loads the geometry, the parameters, and the last K-file:
from gyaradax.simulate import gk_from_gkw_dir, gksimulate
# loads input.dat, geometry, and resumes from the last K-file
df, geometry, params, state, pre = gk_from_gkw_dir("/path/to/gkw/run/")
df, phi, fluxes, state = gksimulate(df, geometry, params, state, 120, pre=pre)
If you have an existing GKW run, you can extract its parameters and geometry into a YAML config:
python -m scripts.gkw_to_yaml /path/to/gkw_run configs/my_sim.yaml
Once compiled, the CUDA backend is auto-detected:
# auto-detect (uses CUDA if available, falls back to JAX)
params = GKParams(backend="auto")
# force CUDA
params = GKParams(backend="cuda")
Or via the config YAML:
solver:
  backend: cuda
Run the test suite with:
python -m pytest tests/ -x -q
Most tests require GKW reference data. Set the GKW_DATA_ROOT environment variable to the directory containing the reference runs (e.g. iteration_8/, iteration_13/, kinetic_electrons/). Tests whose reference data is not available are skipped.
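One common way to get this skip-if-missing behavior in pytest is a skipif marker keyed on the environment variable. The sketch below is hypothetical: gkw_reference_dir, requires_iteration_13, and the test body are illustrative names. It also assumes each reference run directory contains GKW's input.dat. The real tests in tests/ may be organized differently.

```python
import os
from pathlib import Path

import pytest


def gkw_reference_dir(name):
    """Return GKW_DATA_ROOT/<name> if that directory exists, otherwise None."""
    root = os.environ.get("GKW_DATA_ROOT")
    if not root:
        return None
    path = Path(root) / name
    return path if path.is_dir() else None


# Evaluated at import time: the test is skipped (not failed) when the reference run is missing.
requires_iteration_13 = pytest.mark.skipif(
    gkw_reference_dir("iteration_13") is None,
    reason="GKW reference data not found; set GKW_DATA_ROOT",
)


@requires_iteration_13
def test_reference_run_has_input_file():
    run_dir = gkw_reference_dir("iteration_13")
    assert (run_dir / "input.dat").is_file()
```

With this pattern, `python -m pytest tests/ -x -q` still passes on machines without the data. Skipped tests are reported as "s" in the quiet output.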
Verification:
- Empirical validation against reference GKW trajectories.
- Analytical validation on the Rosenbluth-Hinton (RH) test and the Cyclone Base Case.
- Differentiable programming: inverse problem and sensitivity analysis.
- GKW tests and benchmarks (see the GKW paper and Chapter 11 of the manual).
- Solver-in-the-Loop and PINNs as an ML showcase.
- Portable unit tests.
Physics and solver extensions:
- Linear solver.
- Adiabatic electron corrections and cases (ions only, single species).
- Kinetic electrons (multi-species).
- Electromagnetic effects.
- Collisionality.
Optimization:
- JAX-based improvements.
- CUDA LTO backend (fused linear stencil and nonlinear solve).
- Fully spectral solver.
- Implicit/explicit integration (IMEX).
@misc{galletti2026gyaradax,
title={gyaradax: Local Gyrokinetics JAX Code},
author={Gianluca Galletti and Eric Volkmann and Johannes Brandstetter},
year={2026},
primaryClass={physics.plasm-ph},
url={https://arxiv.org/abs/2604.06085},
}

