Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,10 @@ rest of the cells run as-is.

| Tutorial | Summary | Difficulty | Framework | Launch |
|---|---|---|---|---|
| [`000_rl_basics`](tutorials/rl/000_rl_basics/000_rl_basics.ipynb) | Qwen3-4B haiku evaluation with verifiable rewards — serve, evaluate, train, compare | Beginner | `slime` | <a href="https://modal.com/notebooks/new/https://github.com/modal-projects/training-gym/blob/devin/1778791977-positive-rewards-001-sandboxes/tutorials/rl/000_rl_basics/000_rl_basics.ipynb" target="_blank" rel="nofollow noopener noreferrer"><img src="https://modal-cdn.com/open-in-modal.svg" alt="Open in Modal"></a> |
| [`001_sandboxes`](tutorials/rl/001_sandboxes/001_sandboxes.ipynb) | Code RL with Harbor hello-world and sandboxed verification | Intermediate | `slime` | <a href="https://modal.com/notebooks/new/https://github.com/modal-projects/training-gym/blob/devin/1778791977-positive-rewards-001-sandboxes/tutorials/rl/001_sandboxes/001_sandboxes.ipynb" target="_blank" rel="nofollow noopener noreferrer"><img src="https://modal-cdn.com/open-in-modal.svg" alt="Open in Modal"></a> |
| [`002_multiturn`](tutorials/rl/002_multiturn/002_multiturn.ipynb) | Multi-turn number-guessing RL with custom generate and reward functions | Intermediate | `slime` | <a href="https://modal.com/notebooks/new/https://github.com/modal-projects/training-gym/blob/devin/1778791977-positive-rewards-001-sandboxes/tutorials/rl/002_multiturn/002_multiturn.ipynb" target="_blank" rel="nofollow noopener noreferrer"><img src="https://modal-cdn.com/open-in-modal.svg" alt="Open in Modal"></a> |
| [`000_rl_basics`](tutorials/rl/000_rl_basics/000_rl_basics.ipynb) | Qwen3-4B haiku evaluation with verifiable rewards — serve, evaluate, train, compare | Beginner | `slime` | <a href="https://modal.com/notebooks/new/https://github.com/modal-projects/training-gym/blob/main/tutorials/rl/000_rl_basics/000_rl_basics.ipynb" target="_blank" rel="nofollow noopener noreferrer"><img src="https://modal-cdn.com/open-in-modal.svg" alt="Open in Modal"></a> |
| [`001_sandboxes`](tutorials/rl/001_sandboxes/001_sandboxes.ipynb) | Code RL with Harbor hello-world and sandboxed verification | Intermediate | `slime` | <a href="https://modal.com/notebooks/new/https://github.com/modal-projects/training-gym/blob/main/tutorials/rl/001_sandboxes/001_sandboxes.ipynb" target="_blank" rel="nofollow noopener noreferrer"><img src="https://modal-cdn.com/open-in-modal.svg" alt="Open in Modal"></a> |
| [`002_multiturn`](tutorials/rl/002_multiturn/002_multiturn.ipynb) | Multi-turn number-guessing RL with custom generate and reward functions | Intermediate | `slime` | <a href="https://modal.com/notebooks/new/https://github.com/modal-projects/training-gym/blob/main/tutorials/rl/002_multiturn/002_multiturn.ipynb" target="_blank" rel="nofollow noopener noreferrer"><img src="https://modal-cdn.com/open-in-modal.svg" alt="Open in Modal"></a> |
| [`003_glm_gsm8k`](tutorials/rl/003_glm_gsm8k/003_glm_gsm8k.ipynb) | GLM-4.7 (355B MoE) on GSM8K math — serve, evaluate, GRPO-train, compare | Advanced | `slime` | <a href="https://modal.com/notebooks/new/https://github.com/modal-projects/training-gym/blob/main/tutorials/rl/003_glm_gsm8k/003_glm_gsm8k.ipynb" target="_blank" rel="nofollow noopener noreferrer"><img src="https://modal-cdn.com/open-in-modal.svg" alt="Open in Modal"></a> |
<!-- END TUTORIAL TABLE -->

See [`tutorials/README.md`](tutorials/README.md) for how to run the `.py`
Expand Down
4 changes: 4 additions & 0 deletions modal_training_gym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
"EvalConfigDurable": ("modal_training_gym.common.eval", "EvalConfigDurable"),
"EvalResult": ("modal_training_gym.common.eval", "EvalResult"),
"EvalRowResult": ("modal_training_gym.common.eval", "EvalRowResult"),
"GLM_4_7": ("modal_training_gym.common.models", "GLM_4_7"),
"GLM_4_7_Flash": ("modal_training_gym.common.models", "GLM_4_7_Flash"),
"HFModelConfiguration": (
"modal_training_gym.common.models",
"HFModelConfiguration",
Expand Down Expand Up @@ -60,6 +62,8 @@
"ModelConfig",
"ModelDeployment",
"MultiTurn",
"GLM_4_7",
"GLM_4_7_Flash",
"Qwen3_0_6B",
"Qwen3_1_7B",
"Qwen3_4B",
Expand Down
4 changes: 4 additions & 0 deletions modal_training_gym/common/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
ModelArchitecture,
ModelConfig,
)
from .glm_4_7 import GLM_4_7
from .glm_4_7_flash import GLM_4_7_Flash
from .qwen3_0_6b import Qwen3_0_6B
from .qwen3_1_7b import Qwen3_1_7B
from .qwen3_4b import Qwen3_4B
Expand All @@ -12,6 +14,8 @@
from .qwen3_32b import Qwen3_32B

__all__ = [
"GLM_4_7",
"GLM_4_7_Flash",
"HFModelConfiguration",
"ModelArchitecture",
"ModelConfig",
Expand Down
17 changes: 17 additions & 0 deletions modal_training_gym/common/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ class ModelArchitecture:
use_rotary_position_embeddings: bool = True
rotary_base: int = 10000

# MoE (Mixture of Experts)
num_experts: int = 0
moe_router_topk: int = 0
moe_ffn_hidden_size: int = 0
num_shared_experts: int = 0
first_k_dense_replace: int = 0

def to_megatron_args(self) -> list[str]:
"""Generate Megatron-LM CLI flags from this architecture spec."""
args: list[str] = []
Expand Down Expand Up @@ -111,6 +118,16 @@ def to_megatron_args(self) -> list[str]:
args += ["--position-embedding-type", "rope"]
if self.rotary_base != 10000:
args += ["--rotary-base", str(self.rotary_base)]
if self.num_experts:
args += ["--num-experts", str(self.num_experts)]
if self.moe_router_topk:
args += ["--moe-router-topk", str(self.moe_router_topk)]
if self.moe_ffn_hidden_size:
args += ["--moe-ffn-hidden-size", str(self.moe_ffn_hidden_size)]
if self.num_shared_experts:
args += ["--num-shared-experts", str(self.num_shared_experts)]
if self.first_k_dense_replace:
args += ["--first-k-dense-replace", str(self.first_k_dense_replace)]
return args


Expand Down
38 changes: 38 additions & 0 deletions modal_training_gym/common/models/glm_4_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""GLM-4.7 (355B-A32B MoE) model spec."""

from .base import HFModelConfiguration, ModelArchitecture


class GLM_4_7(HFModelConfiguration):
    """Model spec for GLM-4.7 (355B total, 32B active), an MoE model from Zhipu AI.

    The architecture describes 92 transformer layers: the first 3 are dense
    and the remaining 89 are MoE layers that route each token to 8 of 160
    experts, with 1 additional shared expert. Attention is grouped-query
    (96 query heads over 8 KV groups).

    Weights are identified by the HuggingFace repo ``zai-org/GLM-4.7``.

    NOTE(review): the model is described as using partial RoPE, but no
    rotary fraction is set below — confirm whether ``ModelArchitecture``
    needs one.
    """

    model_name = "zai-org/GLM-4.7"
    architecture = ModelArchitecture(
        # Overall shape.
        vocab_size=151552,
        num_layers=92,
        hidden_size=5120,
        ffn_hidden_size=12288,
        # Attention: grouped-query, 96 query heads sharing 8 KV groups.
        num_attention_heads=96,
        group_query_attention=True,
        num_query_groups=8,
        kv_channels=128,
        qk_layernorm=True,
        # Normalization, activation and bias.
        normalization="RMSNorm",
        norm_epsilon=1e-5,
        swiglu=True,
        # NOTE(review): False leaves bias enabled on linear layers — confirm
        # this matches the checkpoint's configuration.
        disable_bias_linear=False,
        untie_embeddings_and_output_weights=True,
        # Rotary position embeddings.
        use_rotary_position_embeddings=True,
        rotary_base=1000000,
        # Mixture of experts: top-8 of 160 routed experts plus 1 shared
        # expert; layers 0-2 stay dense.
        num_experts=160,
        moe_router_topk=8,
        moe_ffn_hidden_size=1536,
        num_shared_experts=1,
        first_k_dense_replace=3,
    )
36 changes: 36 additions & 0 deletions modal_training_gym/common/models/glm_4_7_flash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""GLM-4.7-Flash (30B-A3B MoE) model spec."""

from .base import HFModelConfiguration, ModelArchitecture


class GLM_4_7_Flash(HFModelConfiguration):
    """Model spec for GLM-4.7-Flash (30B total, 3B active), an MoE model from Zhipu AI.

    The architecture describes 47 transformer layers: the first one is dense
    and the remaining 46 are MoE layers that route each token to 4 of 64
    experts, with 1 additional shared expert.

    Weights are identified by the HuggingFace repo ``zai-org/GLM-4.7-Flash``.

    NOTE(review): the model is described as using Multi-head Latent
    Attention (MLA) and multi-token prediction, but the fields below only
    describe 20 attention heads with grouped-query attention turned off;
    nothing MLA- or MTP-specific is set here — confirm this is sufficient.
    """

    model_name = "zai-org/GLM-4.7-Flash"
    architecture = ModelArchitecture(
        # Overall shape.
        vocab_size=154880,
        num_layers=47,
        hidden_size=2048,
        ffn_hidden_size=10240,
        # Attention: 20 heads, one query group per head (no GQA sharing).
        num_attention_heads=20,
        group_query_attention=False,
        num_query_groups=20,
        kv_channels=128,
        qk_layernorm=True,
        # Normalization, activation and bias.
        normalization="RMSNorm",
        norm_epsilon=1e-5,
        swiglu=True,
        disable_bias_linear=True,
        # Rotary position embeddings.
        use_rotary_position_embeddings=True,
        rotary_base=1000000,
        # Mixture of experts: top-4 of 64 routed experts plus 1 shared
        # expert; layer 0 stays dense.
        num_experts=64,
        moe_router_topk=4,
        moe_ffn_hidden_size=1536,
        num_shared_experts=1,
        first_k_dense_replace=1,
    )
8 changes: 8 additions & 0 deletions modal_training_gym/deploy_recipes/sglang_recipe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from modal_training_gym.deploy_recipes.sglang_recipe.recipe import SglangRecipe
from modal_training_gym.deploy_recipes.sglang_recipe.glm_4_7 import (
GLM_4_7_SglangRecipe,
)
from modal_training_gym.deploy_recipes.sglang_recipe.glm_4_7_flash import (
GLM_4_7_Flash_SglangRecipe,
)
from modal_training_gym.deploy_recipes.sglang_recipe.qwen3_0_6b import (
Qwen3_0_6b_SglangRecipe,
)
Expand All @@ -22,6 +28,8 @@
)

__all__ = [
"GLM_4_7_SglangRecipe",
"GLM_4_7_Flash_SglangRecipe",
"SglangRecipe",
"Qwen3_0_6b_SglangRecipe",
"Qwen3_1_7b_SglangRecipe",
Expand Down
26 changes: 26 additions & 0 deletions modal_training_gym/deploy_recipes/sglang_recipe/glm_4_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass

from modal_training_gym.deploy_recipes.sglang_recipe.recipe import SglangRecipe

# Recipe-specific defaults, keyed by ``SglangRecipe`` field name. These replace
# the base-class defaults in ``GLM_4_7_SglangRecipe.__post_init__``.
_GLM_4_7_DEFAULTS = {
    "gpu": "H100",
    "tp": 8,
    "context_length": 32768,
    "mem_fraction_static": 0.80,
    "chunked_prefill_size": 8192,
    "max_running_requests": 16,
    "extra_server_args": {"--trust-remote-code": ""},
}


# A stock recipe instance, used only to look up the base default of each field.
_SGLANG_DEFAULTS = SglangRecipe()


@dataclass
class GLM_4_7_SglangRecipe(SglangRecipe):
    """GLM-4.7 (355B) on 8×H100 — tensor-parallel MoE serving.

    Any field the caller left at the ``SglangRecipe`` default is replaced by
    the matching entry of ``_GLM_4_7_DEFAULTS``. A caller who explicitly
    passes a value equal to the base default cannot be told apart from one
    who passed nothing, so that value is replaced too.

    NOTE(review): 355B parameters at 16-bit precision is roughly 710 GB of
    weights — confirm the served checkpoint and GPU memory actually fit on
    8 GPUs.
    """

    def __post_init__(self) -> None:
        """Apply the GLM-4.7 defaults to every field still at its base default."""
        # Function-scope import: the file's import block is left untouched.
        import copy

        for key, val in _GLM_4_7_DEFAULTS.items():
            if getattr(self, key) == getattr(_SGLANG_DEFAULTS, key):
                # Deep-copy so mutable defaults (the ``extra_server_args``
                # dict) are not shared between instances or with the
                # module-level table.
                object.__setattr__(self, key, copy.deepcopy(val))
27 changes: 27 additions & 0 deletions modal_training_gym/deploy_recipes/sglang_recipe/glm_4_7_flash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from dataclasses import dataclass

from modal_training_gym.deploy_recipes.sglang_recipe.recipe import SglangRecipe

# Recipe-specific defaults, keyed by ``SglangRecipe`` field name. These replace
# the base-class defaults in ``GLM_4_7_Flash_SglangRecipe.__post_init__``.
_GLM_4_7_FLASH_DEFAULTS = {
    "gpu": "H100",
    "tp": 1,
    "dp": 8,
    "context_length": 32768,
    "mem_fraction_static": 0.80,
    "chunked_prefill_size": 8192,
    "max_running_requests": 16,
    "extra_server_args": {"--trust-remote-code": ""},
}


# A stock recipe instance, used only to look up the base default of each field.
_SGLANG_DEFAULTS = SglangRecipe()


@dataclass
class GLM_4_7_Flash_SglangRecipe(SglangRecipe):
    """GLM-4.7-Flash on 8×H100 — DP-attention MoE serving.

    Any field the caller left at the ``SglangRecipe`` default is replaced by
    the matching entry of ``_GLM_4_7_FLASH_DEFAULTS``. A caller who
    explicitly passes a value equal to the base default cannot be told apart
    from one who passed nothing, so that value is replaced too.

    NOTE(review): every key in the defaults table, including ``dp``, must be
    a field of ``SglangRecipe``; otherwise ``getattr`` raises
    ``AttributeError`` — confirm against the base class.
    """

    def __post_init__(self) -> None:
        """Apply the GLM-4.7-Flash defaults to every field still at its base default."""
        # Function-scope import: the file's import block is left untouched.
        import copy

        for key, val in _GLM_4_7_FLASH_DEFAULTS.items():
            if getattr(self, key) == getattr(_SGLANG_DEFAULTS, key):
                # Deep-copy so mutable defaults (the ``extra_server_args``
                # dict) are not shared between instances or with the
                # module-level table.
                object.__setattr__(self, key, copy.deepcopy(val))
8 changes: 8 additions & 0 deletions modal_training_gym/deploy_recipes/vllm_recipe/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from modal_training_gym.deploy_recipes.vllm_recipe.recipe import VllmRecipe
from modal_training_gym.deploy_recipes.vllm_recipe.glm_4_7 import (
GLM_4_7_VllmRecipe,
)
from modal_training_gym.deploy_recipes.vllm_recipe.glm_4_7_flash import (
GLM_4_7_Flash_VllmRecipe,
)
from modal_training_gym.deploy_recipes.vllm_recipe.qwen3_0_6b import (
Qwen3_0_6b_VllmRecipe,
)
Expand All @@ -12,6 +18,8 @@
from modal_training_gym.deploy_recipes.vllm_recipe.qwen3_32b import Qwen3_32b_VllmRecipe

__all__ = [
"GLM_4_7_VllmRecipe",
"GLM_4_7_Flash_VllmRecipe",
"VllmRecipe",
"Qwen3_0_6b_VllmRecipe",
"Qwen3_1_7b_VllmRecipe",
Expand Down
21 changes: 21 additions & 0 deletions modal_training_gym/deploy_recipes/vllm_recipe/glm_4_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from modal_training_gym.deploy_recipes.vllm_recipe.recipe import VllmRecipe

# Recipe-specific defaults, keyed by ``VllmRecipe`` field name. These replace
# the base-class defaults in ``GLM_4_7_VllmRecipe.__post_init__``.
_GLM_4_7_DEFAULTS = {
    "gpu": "H100",
    "n_gpu": 8,
    "extra_vllm_args": ["--trust-remote-code"],
}

# A stock recipe instance, used only to look up the base default of each field.
_VLLM_DEFAULTS = VllmRecipe()


@dataclass
class GLM_4_7_VllmRecipe(VllmRecipe):
    """GLM-4.7 (355B) on 8×H100 — tensor-parallel MoE serving.

    Any field the caller left at the ``VllmRecipe`` default is replaced by
    the matching entry of ``_GLM_4_7_DEFAULTS``. A caller who explicitly
    passes a value equal to the base default cannot be told apart from one
    who passed nothing, so that value is replaced too.

    NOTE(review): 355B parameters at 16-bit precision is roughly 710 GB of
    weights — confirm the served checkpoint and GPU memory actually fit on
    8 GPUs.
    """

    def __post_init__(self) -> None:
        """Apply the GLM-4.7 defaults to every field still at its base default."""
        # Function-scope import: the file's import block is left untouched.
        import copy

        for key, val in _GLM_4_7_DEFAULTS.items():
            if getattr(self, key) == getattr(_VLLM_DEFAULTS, key):
                # Deep-copy so the mutable ``extra_vllm_args`` list is not
                # shared between instances or with the module-level table.
                object.__setattr__(self, key, copy.deepcopy(val))
21 changes: 21 additions & 0 deletions modal_training_gym/deploy_recipes/vllm_recipe/glm_4_7_flash.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from dataclasses import dataclass

from modal_training_gym.deploy_recipes.vllm_recipe.recipe import VllmRecipe

# Recipe-specific defaults, keyed by ``VllmRecipe`` field name. These replace
# the base-class defaults in ``GLM_4_7_Flash_VllmRecipe.__post_init__``.
_GLM_4_7_FLASH_DEFAULTS = {
    "gpu": "H100",
    "n_gpu": 2,
    "extra_vllm_args": ["--trust-remote-code"],
}

# A stock recipe instance, used only to look up the base default of each field.
_VLLM_DEFAULTS = VllmRecipe()


@dataclass
class GLM_4_7_Flash_VllmRecipe(VllmRecipe):
    """GLM-4.7-Flash on 2×H100 — tensor-parallel MoE serving.

    Any field the caller left at the ``VllmRecipe`` default is replaced by
    the matching entry of ``_GLM_4_7_FLASH_DEFAULTS``. A caller who
    explicitly passes a value equal to the base default cannot be told apart
    from one who passed nothing, so that value is replaced too.
    """

    def __post_init__(self) -> None:
        """Apply the GLM-4.7-Flash defaults to every field still at its base default."""
        # Function-scope import: the file's import block is left untouched.
        import copy

        for key, val in _GLM_4_7_FLASH_DEFAULTS.items():
            if getattr(self, key) == getattr(_VLLM_DEFAULTS, key):
                # Deep-copy so the mutable ``extra_vllm_args`` list is not
                # shared between instances or with the module-level table.
                object.__setattr__(self, key, copy.deepcopy(val))
6 changes: 6 additions & 0 deletions modal_training_gym/train_recipes/slime_recipe/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,19 @@
SlimeRecipeBlock,
)
from modal_training_gym.train_recipes.slime_recipe.recipe import SlimeRecipe
from modal_training_gym.train_recipes.slime_recipe.glm_4_7 import GLM_4_7_Recipe
from modal_training_gym.train_recipes.slime_recipe.glm_4_7_flash import (
GLM_4_7_Flash_Recipe,
)
from modal_training_gym.train_recipes.slime_recipe.qwen3_1_7b import Qwen3_1_7b_Recipe
from modal_training_gym.train_recipes.slime_recipe.qwen3_8b import Qwen3_8b_Recipe
from modal_training_gym.train_recipes.slime_recipe.qwen3_14b import Qwen3_14b_Recipe
from modal_training_gym.train_recipes.slime_recipe.qwen3_32b import Qwen3_32b_Recipe
from modal_training_gym.train_recipes.slime_recipe.qwen3_4b import Qwen3_4b_Recipe

__all__ = [
"GLM_4_7_Recipe",
"GLM_4_7_Flash_Recipe",
"MultiTurn",
"SlimeRecipe",
"SlimeRecipeBlock",
Expand Down
51 changes: 51 additions & 0 deletions modal_training_gym/train_recipes/slime_recipe/glm_4_7.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from pydantic import ConfigDict
from pydantic.dataclasses import dataclass

from modal_training_gym.train_recipes.slime_recipe.recipe import SlimeRecipe


@dataclass(config=ConfigDict(extra="forbid", arbitrary_types_allowed=True))
class GLM_4_7_Recipe(SlimeRecipe):
    """Colocated GRPO training recipe for GLM-4.7 (355B-A32B MoE) on 8×8×H100.

    Layout: 8 nodes × 8 GPUs = 64 GPUs, split as TP=8 × PP=4 × CP=2, with
    expert parallelism EP=16. The optimizer is offloaded to CPU because of
    the large parameter count.

    Field order is kept as declared, since it determines the generated
    ``__init__`` signature.
    """

    # --- Cluster shape and actor parallelism ---
    gpu_type: str = "H100"
    colocate: bool = True
    actor_num_nodes: int = 8
    actor_num_gpus_per_node: int = 8
    tensor_model_parallel_size: int = 8
    sequence_parallel: bool = True
    rollout_num_gpus_per_engine: int = 32

    # --- MoE / pipeline / context parallelism ---
    expert_model_parallel_size: int = 16
    expert_tensor_parallel_size: int = 1
    pipeline_model_parallel_size: int = 4
    context_parallel_size: int = 2
    attention_backend: str | None = "flash"

    # --- Rollout generation ---
    # NOTE(review): with a single rollout, the save/eval intervals of 10
    # below look like they would never trigger — confirm intended.
    num_rollout: int = 1
    rollout_batch_size: int = 64
    rollout_max_response_len: int = 4096
    rollout_temperature: float = 1.0
    sglang_mem_fraction_static: float = 0.70

    # --- Checkpointing ---
    save_interval: int = 10

    # --- Training ---
    # global_batch_size equals rollout_batch_size × n_samples_per_prompt
    # (64 × 8 = 512).
    n_samples_per_prompt: int = 8
    global_batch_size: int = 512
    lr: float = 1e-6
    max_tokens_per_gpu: int = 16384

    # --- Optimizer offloading (required for the 355B model) ---
    optimizer_cpu_offload: bool = True
    overlap_cpu_optimizer_d2h_h2d: bool = True
    use_precision_aware_optimizer: bool = True

    # --- Evaluation ---
    eval_interval: int | None = 10
    eval_max_response_len: int = 4096
Loading
Loading