
nissymori/mahjax



MahJax


A GPU-Accelerated Mahjong Simulator for Reinforcement Learning in JAX

Note

Japanese Riichi Mahjong is a challenging multi-agent RL environment with imperfect information, stochastic dynamics, more than two players, and high-dimensional observations. Mahjax aims to make Mahjong research more accessible to the broader RL community. Newcomers may want to start with our basic introduction and the bilingual visualization.

Overview

  • 🚀 Vectorized Environment: Extremely fast (approx. 1.6M steps/sec on 8x A100 GPUs).
  • 🎨 Rich Visualization: SVG-based visualization with bilingual support for those unfamiliar with Kanji.
  • 🎮 Playable Interface: A web-based UI allows you to play directly against the agents you train.
  • 📚 RL Examples: Simple Behavior Cloning + PPO examples in the examples/ directory.

For more details, please refer to the Documentation.

Quick Start

Install

Mahjax is available on PyPI. Make sure jax and jaxlib are installed in your Python environment, using the build that matches your hardware (CPU, GPU, or TPU).

pip install mahjax

📣 Mahjax is currently under active development. If you prefer to use the latest codebase with the newest features, please clone the repository and install it in editable mode:

git clone https://github.com/nissymori/mahjax.git
cd mahjax
pip install -e .

Note

The current API is still provisional and under active development, so it may change in future releases.

Basic Usage

The API largely follows the design of Pgx.

import jax
import jax.numpy as jnp
import mahjax

batch_size = 10
rng = jax.random.PRNGKey(0)

# Initialize environment
env = mahjax.make(
    "red_mahjong",
    round_mode="single", # "single", "east" (tonpuusen), or "half" (hanchan)
    observe_type="dict", # "dict" for Transformer, "2D" for CNN
    order_points=[30, 10, -10, -30] # Final score bonuses (uma)
)

init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))
obs_fn = jax.jit(jax.vmap(env.observe))

# Initialize state
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
state = init_fn(rngs)

# Step
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
action = jnp.zeros((batch_size,), dtype=jnp.int8)
state = step_fn(state, action, rngs)

# Get observation
obs = obs_fn(state)
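The action above is a placeholder of zeros, which need not be a legal move in every state. Below is a minimal sketch of sampling one uniformly random legal action per environment. It assumes, as in Pgx, that the state exposes a boolean legal_action_mask of shape (batch_size, num_actions); that field name and the helper sample_legal_actions are assumptions, not confirmed MahJax API.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sample_legal_actions(rng, legal_action_mask):
    """Sample one uniformly random legal action per environment.

    legal_action_mask: bool array of shape (batch_size, num_actions),
    assumed to follow the Pgx convention. Every row must contain at
    least one legal action.
    """
    batch_size = legal_action_mask.shape[0]
    keys = jax.random.split(rng, batch_size)
    # Illegal actions get -inf logits, so they are never sampled;
    # legal actions share equal logits, so sampling is uniform over them.
    logits = jnp.where(legal_action_mask, 0.0, -jnp.inf)
    return jax.vmap(jax.random.categorical)(keys, logits)


# Toy mask standing in for state.legal_action_mask (hypothetical shapes).
mask = jnp.array([[True, False, True, False],
                  [False, False, False, True]])
actions = sample_legal_actions(jax.random.PRNGKey(0), mask)
```

Under the same assumption about the field name, this would replace the zeros in the loop above, e.g. action = sample_legal_actions(subrng, state.legal_action_mask).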

User interface

MahJax includes a web-based UI (FastAPI + JS) that allows you to play against built-in or custom agents directly in your browser.

Running the UI

Install dependencies and start the server:

pip install mahjax
uvicorn mahjax.ui.app:create_app --factory --host 0.0.0.0 --port 8000

Open http://localhost:8000 in your browser to start playing. The built-in agents are random and rule_based.

Playing Against Your Agent

You can register your trained agent to appear in the UI's agent selector. Create a Python script (e.g., my_app.py) and register your agent's act function:

### my_app.py
from pathlib import Path
from mahjax.ui.app import create_app

app = create_app()

# Load your custom agent
app.state.manager.registry.load_callable_from_path(
    file_path=Path("path/to/my_agent.py"),
    attribute="act", # The function name to call: act(state, rng) -> action_id
    description="My Custom Agent",
)

Then run: uvicorn my_app:app --port 8000
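As a starting point for path/to/my_agent.py, here is a minimal deterministic agent that always plays the lowest-indexed legal action. It follows the act(state, rng) -> action_id signature noted in the comment in my_app.py, and assumes that an unbatched state carries a 1-D boolean legal_action_mask (the Pgx convention). That field name, and returning a plain Python int, are assumptions about what the UI expects.

```python
### my_agent.py (hypothetical example agent)
import jax
import jax.numpy as jnp


@jax.jit
def _first_legal(legal_action_mask):
    # argmax over a boolean array returns the index of the first True.
    return jnp.argmax(legal_action_mask)


def act(state, rng):
    """Return the lowest-indexed legal action as a Python int.

    Assumes state.legal_action_mask is a 1-D bool array (Pgx convention).
    """
    del rng  # deterministic agent: the PRNG key is not needed
    return int(_first_legal(state.legal_action_mask))
```

A deterministic baseline like this is mainly useful for checking that registration works before plugging in a trained policy, which would typically use rng for sampling.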

Supported Rules

Currently, MahJax supports the following rules:

| Rule | id | Status | Code | Speed (steps/sec) |
|---|---|---|---|---|
| No-Red Mahjong | no_red_mahjong | ✅ | no_red_mahjong | ~9M |
| Red Mahjong | red_mahjong | ✅ | red_mahjong | ~1.6M |
| Selective Rules | - | 🚧 | - | - |
| 3-player Mahjong | - | 🚧 | - | - |

red_mahjong implements standard 4-player riichi mahjong with red fives. Its rules are designed to follow Tenhou, one of the most widely used online mahjong platforms in Japan, and we validate the implementation against downloaded Tenhou game logs. For the detailed rule specification, see the official Tenhou rules.

no_red_mahjong implements 4-player riichi mahjong without red fives. This variant is intentionally simplified for speed, and excludes some rules such as abortive draws (特殊流局), pao, and double ron. If throughput is your priority, no_red_mahjong is the recommended option.

You can configure the environment with:

  • id: the rule set, such as red_mahjong or no_red_mahjong
  • round_mode: single for a single round, east for tonpuusen (East-only), or half for hanchan (East-South)
  • observe_type: dict for transformer-style inputs or 2D for CNN-style inputs
  • order_points: final placement bonuses (uma), for example [30, 10, -10, -30]

env = mahjax.make(
    id="red_mahjong",
    round_mode="single",
    observe_type="dict",
    order_points=[30, 10, -10, -30],
)
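To make order_points concrete, here is a minimal sketch that maps final point totals to placement bonuses. It assumes players are indexed by seat with seat 0 as the starting dealer, and it breaks ties by seat order (the usual Tenhou convention). How MahJax actually folds these bonuses into its rewards is not specified here, so placement_bonus is a hypothetical helper for illustration only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def placement_bonus(final_points, order_points):
    """Map final point totals to uma, ranking ties by seat order.

    final_points: int array of shape (4,), indexed by seat (seat 0 = first dealer).
    order_points: array of shape (4,), bonus for 1st, 2nd, 3rd, 4th place.
    """
    seats = jnp.arange(final_points.shape[0])
    p_i = final_points[:, None]  # player being ranked (rows)
    p_j = final_points[None, :]  # every other player (columns)
    # Player j outranks player i if j has more points, or equal points
    # and an earlier seat.
    outranks = (p_j > p_i) | ((p_j == p_i) & (seats[None, :] < seats[:, None]))
    ranks = outranks.sum(axis=1)  # 0 = first place
    return jnp.asarray(order_points)[ranks]


bonus = placement_bonus(jnp.array([25000, 35000, 25000, 15000]),
                        jnp.array([30, 10, -10, -30]))
# Seats 0 and 2 tie at 25000; seat 0 ranks higher, so bonus is [10, 30, -10, -30].
```

Counting how many players outrank each seat avoids relying on sort stability and stays fully vectorized, so it composes with jax.vmap over a batch of games.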

Note

The observation features are not yet finalized, although the current version is sufficient for BC and RL training (see examples/).

See also

JAX-based environments

  • Pgx: Board game environments such as Go, Chess, and Shogi.
  • Brax: Robotics control.
  • Gymnax: Popular small-scale RL environments such as CartPole or bsuite.
  • Jumanji: A diverse suite of RL environments (packing, routing, etc.).
  • Craftax: A JAX version of Crafter + NetHack.
  • JaxMARL: Multi-agent environments such as Hanabi.
  • Navix: A JAX version of MiniGrid.

Cite us

Paper coming soon.

Acknowledgement

  • sotetsuk: For general advice on the development of MahJax, drawing on his experience developing Pgx.
  • habara-k: For developing core JAX components such as shanten and yaku calculation.
  • OkanoShinri: For the initial implementation of MahJax and its SVG visualization.
  • easonyu0203: For advice on PPO implementation in a multi-player imperfect-information game.
