gerkone/gyaradax


gyaradax: Gyrokinetics in JAX

gyaradax is a JAX code for local flux-tube gyrokinetic simulations. It is based on GKW. At the current stage, it provides a differentiable solver for the electrostatic, collisionless Vlasov-Poisson system.
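Because the solver is written in JAX, gradients can flow through the time integration itself. The toy below is a minimal illustration of that idea and assumes nothing about gyaradax's internals: a scalar linear growth equation is stepped with a hand-written RK4 loop and differentiated with jax.grad. The names rk4_step, growth_rhs and final_amplitude are hypothetical.

```python
import jax
import jax.numpy as jnp


def rk4_step(f, y, t, dt, gamma):
    """One classical fourth-order Runge-Kutta step for dy/dt = f(t, y, gamma)."""
    k1 = f(t, y, gamma)
    k2 = f(t + dt / 2, y + dt / 2 * k1, gamma)
    k3 = f(t + dt / 2, y + dt / 2 * k2, gamma)
    k4 = f(t + dt, y + dt * k3, gamma)
    return y + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def growth_rhs(t, y, gamma):
    # Toy linear "mode" growing at rate gamma (a stand-in for a real RHS).
    return gamma * y


def final_amplitude(gamma, y0=1.0, dt=0.01, n_steps=100):
    y_init = jnp.asarray(y0, dtype=jnp.float32)

    def body(_, y):
        return rk4_step(growth_rhs, y, 0.0, dt, gamma)

    # Static trip count, so fori_loop lowers to a scan and is reverse-differentiable.
    return jax.lax.fori_loop(0, n_steps, body, y_init)


# Sensitivity of the final amplitude to the growth rate, obtained by
# differentiating straight through the time loop.
dy_dgamma = jax.grad(final_amplitude)(0.5)
```

For this linear problem the gradient can be checked against the closed-form RK4 amplification factor, which is a handy pattern for testing any differentiable integrator.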

This was made possible by extensive use of agentic workflows. PROMPT.md contains the prompt used to obtain the initial working version of gyaradax.

Check out our whitepaper, or see agent notes for a detailed walkthrough of GKW and this reimplementation.

[Figure: Nonlinear ITG turbulence on a torus]

Installation

Create a conda environment with CUDA toolkit and all dependencies:

conda env create -f environment.yml
conda activate gyaradax_env
pip install -e ".[dev]"

This installs gyaradax in editable mode with JAX (CUDA 13, local toolkit), numpy, and dev tools (pytest, ruff, black). The conda environment provides the CUDA toolkit (>= 13.1), cuDNN, cmake, and a C++ compiler.
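To confirm that the installed JAX actually targets the GPU, a quick check like the following can help. It is a sketch, not part of gyaradax: jax_environment is a hypothetical helper, and it only uses jax.__version__, jax.default_backend(), jax.devices() and the standard-library importlib.metadata.

```python
from importlib.metadata import version

import jax


def jax_environment():
    """Collect JAX/jaxlib versions and visible devices for a quick sanity check."""
    return {
        "jax": jax.__version__,
        "jaxlib": version("jaxlib"),
        # Reports "gpu" when the CUDA-enabled jaxlib is picked up, "cpu" otherwise.
        "backend": jax.default_backend(),
        "devices": [str(d) for d in jax.devices()],
    }


if __name__ == "__main__":
    for key, value in jax_environment().items():
        print(f"{key}: {value}")
```

If "backend" reports "cpu" on a GPU machine, the CUDA toolkit from the conda environment is most likely not being found.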

CUDA Backend

The optional CUDA backend provides fused kernels for the linear RHS stencils and the nonlinear Poisson bracket (a graph-captured cuFFT pipeline). It requires a GPU with compute capability >= 8.0 (sm_80, Ampere or newer).
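Before building, you can check from Python whether a visible GPU meets the sm_80 requirement. This is a sketch under assumptions: it reads the compute_capability attribute that recent jaxlib CUDA builds expose on GPU devices (treated as optional via getattr), and parse_capability and gpus_meet_minimum are hypothetical names.

```python
import jax


def parse_capability(cc: str) -> tuple[int, int]:
    """Turn a compute-capability string such as "8.6" into (8, 6)."""
    major, minor = cc.split(".")[:2]
    return int(major), int(minor)


def gpus_meet_minimum(minimum: tuple[int, int] = (8, 0)) -> bool:
    """True if at least one visible GPU reports compute capability >= minimum."""
    try:
        gpus = jax.devices("gpu")
    except (RuntimeError, ValueError):  # no GPU backend available
        return False
    for dev in gpus:
        # Not guaranteed on every jaxlib build, hence getattr with a default.
        cc = getattr(dev, "compute_capability", None)
        if cc is not None and parse_capability(str(cc)) >= minimum:
            return True
    return False
```

Comparing (major, minor) integer tuples avoids the pitfall of comparing strings such as "10.0" and "8.6" lexically.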

From gyaradax/backends/cuda_kernels/, configure, build, and install the kernels:

mkdir -p _build && cd _build && cmake .. -DCMAKE_BUILD_TYPE=Release && cmake --build . -j$(nproc) && cmake --install . && cd ..

To target a specific GPU architecture (e.g., Ampere sm_80):

cmake .. -DCMAKE_BUILD_TYPE=Release -DGPU_ARCHITECTURES="80"

CMake prints the detected compute capability, jaxlib version, and CUDA toolkit version. Check that these are correct before proceeding.

Structure

  • solver.py: Linear and nonlinear terms (I-VIII) and the RK4 integrator.
  • simulate.py: Interface for trajectory generation.
  • integrals.py: Field solvers and flux integrals.
  • geometry.py: Parsers for GKW geometry files and metric tensor coefficients.
  • params.py: Configuration pytrees.
  • stencils.py: Finite difference stencil definitions.
  • diag.py: Diagnostics (growth rate, frequency, spectral).
  • backends/: Backend dispatch (JAX, CUDA). See the CUDA Backend section for build instructions.
  • plot_utils.py: Visualization.
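For readers new to finite-difference stencils, here is a generic fourth-order central difference on a periodic grid in JAX. It is an illustration only: the stencils gyaradax defines in stencils.py (following GKW) may use different orders, boundary treatment, or upwinding, and ddx_central4 is a hypothetical name.

```python
import jax.numpy as jnp


def ddx_central4(f, dx):
    """Fourth-order central difference of a periodic 1D array.

    f'(x_i) ~ (-f[i+2] + 8 f[i+1] - 8 f[i-1] + f[i-2]) / (12 dx)
    """
    # jnp.roll(f, -k)[i] == f[i + k], with indices wrapping around (periodic boundary).
    return (
        -jnp.roll(f, -2) + 8 * jnp.roll(f, -1) - 8 * jnp.roll(f, 1) + jnp.roll(f, 2)
    ) / (12 * dx)


# Check against the exact derivative of sin(x) on a 64-point periodic grid.
n = 64
dx = 2 * jnp.pi / n
x = jnp.arange(n) * dx
err = jnp.max(jnp.abs(ddx_central4(jnp.sin(x), dx) - jnp.cos(x)))
```

Expressing the stencil with jnp.roll keeps it a pure array expression, so it can be jit-compiled, vmapped and differentiated like the rest of a JAX solver.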

Running Simulations

Basic usage

The scripts/run.py script runs simulations from one or more configuration files, supports batched execution, and accepts runtime options such as the device and the number of blocks.

# Run a single configuration
python -u -m scripts.run configs/iteration_13.yaml --device 0

When multiple YAML configuration files share the same grid resolution and static parameters, scripts/run.py can automatically batch them with jax.vmap and run them in parallel on a single device.

# Run two configurations in parallel on device 0
python -u -m scripts.run configs/adiabatic_a.yaml configs/adiabatic_b.yaml --device 0
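The idea behind this batching can be sketched with a toy integrator, under the assumption that only per-run parameter values differ between configurations. The function simulate below is a hypothetical stand-in, not gyaradax's solver; jax.vmap maps it over a batch of growth rates so that all runs execute in one compiled call.

```python
import jax
import jax.numpy as jnp


def simulate(gamma, y0=1.0, dt=0.01, n_steps=200):
    """Toy 'simulation': forward-Euler integration of dy/dt = gamma * y."""
    y_init = jnp.asarray(y0, dtype=jnp.float32)

    def body(_, y):
        return y + dt * gamma * y

    return jax.lax.fori_loop(0, n_steps, body, y_init)


# Parameters that differ between runs form the batch axis; the grid and
# static settings (here dt and n_steps) must be identical for vmap to apply.
gammas = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32)
batched = jax.vmap(simulate)(gammas)
```

This is also why the runs must share grid resolution and static parameters: vmap needs every batch member to have the same array shapes and the same control flow.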

Usage

Run a simulation

from gyaradax.simulate import gk_from_config, gksimulate

# load yaml and run with IO/checkpointing
df, geometry, params, state, pre = gk_from_config("configs/my_sim.yaml")
df, phi, fluxes, state = gksimulate(
  df, geometry, params, state, 1200, pre=pre,
  output_dir="outputs", checkpoint_interval=120
)
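A checkpointing pattern like this can be pictured as running the integrator in fixed-size chunks and saving the state between chunks. The following is a hypothetical sketch (run_with_checkpoints and decay are illustrative names) and does not claim to reproduce how gksimulate writes its outputs.

```python
import jax
import jax.numpy as jnp


def run_with_checkpoints(step, state, n_steps, checkpoint_interval):
    """Advance `state` n_steps times, keeping a snapshot every checkpoint_interval steps.

    Snapshots are collected in a list here; a real driver would write them to disk.
    """

    @jax.jit
    def chunk(s):
        # lax.scan runs checkpoint_interval steps inside one compiled call.
        s, _ = jax.lax.scan(lambda c, _: (step(c), None), s, None, length=checkpoint_interval)
        return s

    snapshots = []
    for _ in range(n_steps // checkpoint_interval):
        state = chunk(state)
        snapshots.append(state)
    # Remaining steps after the last full chunk, if any.
    for _ in range(n_steps % checkpoint_interval):
        state = step(state)
    return state, snapshots


def decay(y):
    return 0.99 * y
```

With 1200 steps and an interval of 120, this sketch produces 10 snapshots; compiling one chunk and reusing it keeps Python overhead to a single call per checkpoint.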

Resume from GKW checkpoints

gyaradax can resume from GKW binary K-files. The simplest way is gk_from_gkw_dir, which automatically loads the geometry, the parameters, and the last K-file:

from gyaradax.simulate import gk_from_gkw_dir, gksimulate

# loads input.dat, geometry, and resumes from the last K-file
df, geometry, params, state, pre = gk_from_gkw_dir("/path/to/gkw/run/")
df, phi, fluxes, state = gksimulate(df, geometry, params, state, 120, pre=pre)
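Picking "the last K-file" requires ordering the files by their index. The sketch below assumes, purely for illustration, that K-files are named "K" followed by digits (K1, K2, ..., K10); the actual GKW naming may differ, and last_k_file is a hypothetical helper rather than part of gyaradax.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming for illustration only: "K" followed by a numeric index.
K_FILE = re.compile(r"^K(\d+)$")


def last_k_file(run_dir) -> Optional[Path]:
    """Return the K-file with the highest numeric index in run_dir, or None if there is none."""
    candidates = []
    for path in Path(run_dir).iterdir():
        match = K_FILE.match(path.name)
        if match and path.is_file():
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    # Compare indices as integers so that K10 sorts after K9.
    return max(candidates)[1]
```

Sorting on the parsed integer rather than the file name avoids the lexical trap where "K10" would come before "K9".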

Configuration from GKW

If you have an existing GKW run, you can extract its parameters and geometry into a YAML config:

python -m scripts.gkw_to_yaml /path/to/gkw_run configs/my_sim.yaml

CUDA backend

Once compiled, the CUDA backend is auto-detected:

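# GKParams is assumed to live in gyaradax.params (the configuration pytrees module)
from gyaradax.params import GKParams

# auto-detect (uses CUDA if available, falls back to JAX)
params = GKParams(backend="auto")

# force CUDA
params = GKParams(backend="cuda")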

Or via config YAML:

solver:
  backend: cuda
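The auto/force semantics described above can be summarised as a small decision function. This is a hypothetical sketch: resolve_backend, VALID_BACKENDS and the cuda_available probe are illustrative names, and the README does not say whether gyaradax raises or falls back when CUDA is forced but missing; this sketch chooses to raise.

```python
from typing import Callable

VALID_BACKENDS = ("auto", "cuda", "jax")


def resolve_backend(requested: str, cuda_available: Callable[[], bool]) -> str:
    """Map a requested backend name to the backend that will actually run."""
    if requested not in VALID_BACKENDS:
        raise ValueError(f"unknown backend {requested!r}; expected one of {VALID_BACKENDS}")
    if requested == "auto":
        return "cuda" if cuda_available() else "jax"
    if requested == "cuda" and not cuda_available():
        # Design choice of this sketch: fail loudly instead of silently falling back.
        raise RuntimeError("backend='cuda' requested but the compiled CUDA kernels were not found")
    return requested


# A YAML `solver:` block parses into a nested dict like this one.
config = {"solver": {"backend": "auto"}}
backend = resolve_backend(config["solver"].get("backend", "auto"), cuda_available=lambda: False)
```

Passing the availability probe as a callable keeps the decision logic testable on machines without a GPU.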

Testing

python -m pytest tests/ -x -q

Most tests require GKW reference data. Set the GKW_DATA_ROOT environment variable to the directory containing the reference runs (e.g. iteration_8/, iteration_13/, kinetic_electrons/). Tests that need this data are skipped when it is not available.
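The skip behaviour can be expressed with a pytest marker keyed on GKW_DATA_ROOT. This is a minimal sketch; gkw_reference and requires_iteration_13 are hypothetical names, and gyaradax's own test suite may implement the check differently.

```python
import os
from pathlib import Path

import pytest


def gkw_reference(subdir: str):
    """Path to a GKW reference run under $GKW_DATA_ROOT, or None if it is not available."""
    root = os.environ.get("GKW_DATA_ROOT")
    if not root:
        return None
    path = Path(root) / subdir
    return path if path.is_dir() else None


# Skip (rather than fail) when the reference data is missing.
requires_iteration_13 = pytest.mark.skipif(
    gkw_reference("iteration_13") is None,
    reason="GKW reference data not found; set GKW_DATA_ROOT",
)


@requires_iteration_13
def test_reference_run_present():
    assert gkw_reference("iteration_13") is not None
```

Running `python -m pytest tests/ -rs` lists the skip reasons, which makes it easy to see which reference runs are missing.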

State of the project

Verification:

  • Empirical validation against reference GKW trajectories.
  • Analytical validation on the Rosenbluth-Hinton (RH) test and on the Cyclone Base Case.
  • Differentiable programming: inverse problem and sensitivity analysis.
  • GKW tests and benchmarks (see the GKW paper and Chapter 11 of the manual).
  • Solver-in-the-Loop and PINNs as an ML showcase.
  • Portable unit tests.
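Assuming RH refers to the Rosenbluth-Hinton zonal-flow test, the analytical target is the long-time residual of an initially imposed zonal potential. The sketch below evaluates the classic large-aspect-ratio estimate for that residual; rosenbluth_hinton_residual is a hypothetical name, and a validation suite may instead use a refined formula with finite-aspect-ratio corrections.

```python
import math


def rosenbluth_hinton_residual(q: float, eps: float) -> float:
    """Long-time zonal-flow residual phi(t -> inf) / phi(0) (Rosenbluth & Hinton estimate).

    Uses 1 / (1 + 1.6 q^2 / sqrt(eps)), with q the safety factor and
    eps = r/R the local inverse aspect ratio.
    """
    return 1.0 / (1.0 + 1.6 * q**2 / math.sqrt(eps))


# Cyclone-Base-Case-like values: q = 1.4, eps = 0.18
residual = rosenbluth_hinton_residual(1.4, 0.18)
```

For these values the estimate gives a residual of roughly 0.12, i.e. about 12% of the initial zonal potential survives the geodesic-acoustic-mode damping.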

Physics and solver extensions:

  • Linear solver.
  • Adiabatic-electron corrections and cases (ions only, single species).
  • Kinetic electrons (multi-species).
  • Electromagnetic effects.
  • Collisionality.

Optimization:

  • JAX-based improvements.
  • CUDA LTO backend (fused linear stencil and nonlinear solve).
  • Fully spectral solver.
  • Implicit/explicit integration (IMEX).

Citing

@misc{galletti2026gyaradax,
      title={gyaradax: Local Gyrokinetics JAX Code},
      author={Gianluca Galletti and Eric Volkmann and Johannes Brandstetter},
      year={2026},
      eprint={2604.06085},
      archivePrefix={arXiv},
      primaryClass={physics.plasm-ph},
      url={https://arxiv.org/abs/2604.06085},
}

About

Your local gyrokinetics jax code 🐉
