A GPU-Accelerated Mahjong Simulator for Reinforcement Learning in JAX
Note
Japanese Riichi Mahjong is a challenging multi-agent RL environment with imperfect information, stochastic dynamics, more than two players, and high-dimensional observations. Mahjax aims to make Mahjong research accessible to a broader RL community. For newcomers, please see our basic introduction and the bilingual visualization.
- 🚀 Vectorized Environment: Extremely fast (approx. 1.6M steps/sec on 8x A100 GPUs).
- 🎨 Rich Visualization: SVG-based visualization with bilingual support for those unfamiliar with Kanji.
- 🎮 Playable Interface: A web-based UI allows you to play directly against the agents you train.
- 📚 RL Examples: Simple Behavior Cloning and PPO examples in the examples/ directory.
For more details, please refer to the Documentation.
Mahjax is available on PyPI. Make sure your Python environment has jax and jaxlib installed appropriately for your hardware setup.
pip install mahjax

📣 Mahjax is currently under active development. If you prefer to use the latest codebase with the newest features, please clone the repository and install it in editable mode:
git clone https://github.com/nissymori/mahjax.git
cd mahjax
pip install -e .

Note
The current API is still provisional and under active development, so it may change in future releases.
We basically follow the Pgx API design.
import jax
import jax.numpy as jnp
import mahjax
batch_size = 10
rng = jax.random.PRNGKey(0)
# Initialize environment
env = mahjax.make(
"red_mahjong",
round_mode="single", # "single", "east" (tonpuusen), or "half" (hanchan)
observe_type="dict", # "dict" for Transformer, "2D" for CNN
order_points=[30, 10, -10, -30] # Final score bonuses (uma)
)
init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))
obs_fn = jax.jit(jax.vmap(env.observe))
# Initialize state
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
state = init_fn(rngs)
# Step
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
action = jnp.zeros((batch_size,), dtype=jnp.int8)
state = step_fn(state, action, rngs)
# Get observation
obs = obs_fn(state)

MahJax includes a web-based UI (FastAPI + JS) that allows you to play against built-in or custom agents directly in your browser.
Install dependencies and start the server:
pip install mahjax
uvicorn mahjax.ui.app:create_app --factory --host 0.0.0.0 --port 8000

Open http://localhost:8000 to start playing. The default agents are the random and rule_based ones.
You can register your trained agent to appear in the UI's agent selector.
Create a Python script (e.g., my_app.py) and register your agent's act function:
# my_app.py
from pathlib import Path
from mahjax.ui.app import create_app
app = create_app()
# Load your custom agent
app.state.manager.registry.load_callable_from_path(
file_path=Path("path/to/my_agent.py"),
attribute="act", # The function name to call: act(state, rng) -> action_id
description="My Custom Agent",
)

Run uvicorn my_app:app --port 8000.
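For reference, here is a minimal sketch of what path/to/my_agent.py might contain. It assumes a Pgx-style state that exposes a boolean legal_action_mask array — MahJax follows the Pgx API design, but check the actual State definition for the exact field names. The agent simply picks a uniformly random legal action:

```python
# my_agent.py -- hypothetical example agent (a sketch, not part of MahJax).
# Assumes the state carries a Pgx-style `legal_action_mask` boolean array.
import jax
import jax.numpy as jnp

def act(state, rng):
    # Give illegal actions -inf logits so they can never be sampled,
    # then draw uniformly among the legal ones.
    logits = jnp.where(state.legal_action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(rng, logits)
```

Any callable with the `act(state, rng) -> action_id` signature can be registered the same way, including a jitted policy network wrapped in a small adapter.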
Currently, MahJax supports the following rules:
| Rule | id | Status | Code | Speed (steps/sec) |
|---|---|---|---|---|
| No-Red Mahjong | no_red_mahjong | ✅ | no_red_mahjong | ~9M |
| Red Mahjong | red_mahjong | ✅ | red_mahjong | ~1.6M |
| Selective Rules | - | 🚧 | - | - |
| 3-player Mahjong | - | 🚧 | - | - |
red_mahjong implements standard 4-player riichi mahjong with red fives.
Its rules are designed to follow Tenhou, one of the most widely used online mahjong platforms in Japan, and we validate the implementation against downloaded Tenhou game logs.
For the detailed rule specification, see the official Tenhou rules.
no_red_mahjong implements 4-player riichi mahjong without red fives.
This variant is intentionally simplified for speed, and excludes some rules such as abortive draws (特殊流局), pao, and double ron.
If throughput is your priority, no_red_mahjong is the recommended option.
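When benchmarking throughput yourself, note that JAX dispatches work asynchronously, so you must block on the result and exclude compilation time. The pattern below is a self-contained sketch that uses a stand-in jitted function; substitute your real vmapped step_fn:

```python
import time
import jax
import jax.numpy as jnp

# Stand-in for a jitted, vmapped env.step; replace with your real step_fn.
step = jax.jit(lambda x: jnp.sin(x) + 1.0)

batch = jnp.zeros((1024,))
step(batch).block_until_ready()   # warm-up call: excludes compile time

n_iters = 100
start = time.perf_counter()
for _ in range(n_iters):
    batch = step(batch)
batch.block_until_ready()         # wait for async dispatch to finish
elapsed = time.perf_counter() - start
steps_per_sec = n_iters * batch.shape[0] / elapsed
```

Without the final block_until_ready, the loop only measures dispatch time and wildly overestimates throughput.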
You can configure the environment with:
- `id`: the rule set, such as `red_mahjong` or `no_red_mahjong`
- `round_mode`: `single` for a single round, `east` for tonpuusen (East-only), or `half` for hanchan (East-South)
- `observe_type`: `dict` for transformer-style inputs or `2D` for CNN-style inputs
- `order_points`: final placement bonuses (uma), for example `[30, 10, -10, -30]`
env = mahjax.make(
id="red_mahjong",
round_mode="single",
observe_type="dict",
order_points=[30, 10, -10, -30],
)

Note
The observation features are not yet finalized (though the current version suffices for RL with BC; see examples/).
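To illustrate how the `order_points` (uma) bonuses are applied, here is a hedged sketch — not MahJax's actual scoring code, and it ignores ties and seat-order tie-breaking: players are ranked by final score, and each receives the bonus for their placement.

```python
import jax.numpy as jnp

def apply_uma(scores, order_points):
    # Rank players best-first and add the placement bonus for each rank.
    # Illustration only: ties / seat-order tie-breaking are ignored.
    ranking = jnp.argsort(-scores)  # player indices, best score first
    bonuses = jnp.zeros_like(scores).at[ranking].set(
        jnp.asarray(order_points, dtype=scores.dtype)
    )
    return scores + bonuses

final = apply_uma(jnp.array([45.0, 20.0, 5.0, -70.0]), [30, 10, -10, -30])
# final -> [75., 30., -5., -100.]
```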
JAX-based environments
- Pgx: Board game environments such as Go, Chess, and Shogi.
- Brax: Robotics control.
- Gymnax: Popular small-scale RL environments such as CartPole or bsuite.
- Jumanji: A diverse suite of RL environments (packing, routing, etc.).
- Craftax: A JAX version of Crafter + Nethack.
- JaxMARL: Multi-agent environments such as Hanabi.
- Navix: A JAX version of MiniGrid.
Paper coming soon.
- sotetsuk: For general advice on the development of mahjax based on his experience developing pgx.
- habara-k: For developing core JAX components such as shanten and Yaku calculation.
- OkanoShinri: For the initial implementation of MahJax and its SVG visualization.
- easonyu0203: For advice on PPO implementation in a multi-player imperfect-information game.

