diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index 7c53445722..f0137269bf 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -3,6 +3,7 @@ import argparse import json import os +import warnings from dataclasses import MISSING as dataclass_missing from dataclasses import asdict, dataclass, field, fields from enum import Enum @@ -2877,7 +2878,32 @@ def __post_init__(self): @dataclass -class TeacherConfig(PPOActorConfig): +class TeacherConfig: + engine_type: str = field( + default="rollout", + metadata={ + "help": "Teacher engine type. 'rollout' uses inference engine scoring; " + "'train' uses the legacy train-engine teacher path.", + "choices": ["rollout", "train"], + }, + ) + rollout: InferenceEngineConfig | None = field(default=None) + train: PPOActorConfig | None = field( + default=None, + metadata={ + "help": "Legacy train-engine teacher config. Required when engine_type='train'." + }, + ) + path: str = field( + default="", + metadata={ + "help": "Teacher model path. If set, overrides shared rollout backend model path." + }, + ) + offload: bool = field( + default=False, + metadata={"help": "Whether to offload teacher rollout model between steps"}, + ) rl_loss_weight: float = field( default=1.0, metadata={"help": "RL loss weight"}, @@ -2888,6 +2914,22 @@ class TeacherConfig(PPOActorConfig): metadata={"help": "Distillation loss weight"}, ) + def __post_init__(self): + if self.rollout is not None and self.train is not None: + warnings.warn( + "Both teacher.rollout and teacher.train are configured; " + f"teacher.engine_type={self.engine_type!r} selects which one is used.", + stacklevel=2, + ) + if self.engine_type == "rollout" and self.rollout is None: + raise ValueError( + "teacher.rollout must be provided when teacher.engine_type='rollout'." + ) + if self.engine_type == "train" and self.train is None: + raise ValueError( + "teacher.train must be provided when teacher.engine_type='train'." + ) + @dataclass class PPOConfig(BaseExperimentConfig): diff --git a/areal/api/engine_api.py b/areal/api/engine_api.py index 08cd31c45c..f3ca01187c 100644 --- a/areal/api/engine_api.py +++ b/areal/api/engine_api.py @@ -717,6 +717,15 @@ def get_version(self) -> int: """ raise NotImplementedError() + def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]: + """Compute token log-probabilities for teacher distillation. + + Implementations support this as an inference-side scoring API. + """ + raise NotImplementedError( + f"{self.__class__.__name__} does not implement compute_logp()." + ) + def submit( self, data: dict[str, Any], diff --git a/areal/engine/sglang_remote.py b/areal/engine/sglang_remote.py index 52039df326..901d6ca0d1 100644 --- a/areal/engine/sglang_remote.py +++ b/areal/engine/sglang_remote.py @@ -10,6 +10,7 @@ import numpy as np import pybase64 +import torch from torchdata.stateful_dataloader import StatefulDataLoader from areal.api import ( @@ -126,6 +127,40 @@ def parse_generation_response( routed_experts=routed_experts, ) + def build_score_request( + self, input_ids: list[int], target_len: int, with_lora: bool, version: int + ) -> HttpRequest: + payload: dict[str, Any] = { + "input_ids": input_ids, + "sampling_params": { + "max_new_tokens": 1, + "temperature": 0.0, + }, + "return_logprob": True, + "logprob_start_len": max(0, len(input_ids) - target_len - 1), + "top_logprobs_num": 0, + "stream": False, + } + if with_lora: + raise NotImplementedError( + "LoRA scoring request is not supported in SGLang teacher compute_logp yet." + ) + return HttpRequest(endpoint="/generate", payload=payload) + + def parse_score_response( + self, response: dict[str, Any], target_len: int + ) -> list[float]: + meta_info = response.get("meta_info") + if meta_info is None: + raise ValueError("SGLang response missing meta_info for score request") + # SGLang returns [logprob, token_id, ...] + all_logprobs = [float(x[0]) for x in meta_info.get("input_token_logprobs", [])] + if len(all_logprobs) < target_len: + raise ValueError( + f"SGLang returned insufficient input_token_logprobs: {len(all_logprobs)} < {target_len}" + ) + return all_logprobs[-target_len:] + def build_disk_weight_update_requests( self, meta: WeightUpdateMeta ) -> WeightUpdateRequests: @@ -502,6 +537,9 @@ def prepare_batch( dynamic_bs=dynamic_bs, ) + def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]: + return self._engine.compute_logp(data) + def pause(self): return self._engine.pause() diff --git a/areal/engine/vllm_remote.py b/areal/engine/vllm_remote.py index 5db6cc8f88..cabbae4995 100644 --- a/areal/engine/vllm_remote.py +++ b/areal/engine/vllm_remote.py @@ -8,6 +8,7 @@ from concurrent.futures import Future from typing import Any +import torch from torchdata.stateful_dataloader import StatefulDataLoader from areal.api import ( @@ -126,6 +127,46 @@ def parse_generation_response( stop_reason=stop_reason, ) + def build_score_request( + self, input_ids: list[int], target_len: int, with_lora: bool, version: int + ) -> HttpRequest: + payload: dict[str, Any] = { + "prompt": input_ids, + "max_tokens": 1, + "temperature": 0.0, + "logprobs": 1, + "prompt_logprobs": 1, + "echo": True, + } + if with_lora: + raise NotImplementedError( + "LoRA scoring request is not supported in vLLM teacher compute_logp yet." + ) + return HttpRequest(endpoint="/v1/completions", payload=payload) + + def parse_score_response( + self, response: dict[str, Any], target_len: int + ) -> list[float]: + choices = response.get("choices") + if not choices: + raise ValueError("vLLM response missing choices for score request") + prompt_logprobs = choices[0].get("prompt_logprobs") + if prompt_logprobs is None: + raise ValueError("vLLM response missing prompt_logprobs for score request") + if len(prompt_logprobs) < target_len + 1: + raise ValueError( + f"prompt_logprobs too short: got {len(prompt_logprobs)}, need {target_len + 1}" + ) + sliced = prompt_logprobs[-target_len:] + token_logps: list[float] = [] + for item in sliced: + if not item: + token_logps.append(0.0) + continue + top = next(iter(item.values())) + token_logps.append(float(top["logprob"] if isinstance(top, dict) else top)) + return token_logps + def build_disk_weight_update_requests( self, meta: WeightUpdateMeta ) -> WeightUpdateRequests: @@ -465,6 +506,9 @@ def prepare_batch( dynamic_bs=dynamic_bs, ) + def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]: + return self._engine.compute_logp(data) + def pause(self): return self._engine.pause() diff --git a/areal/infra/controller/rollout_controller.py b/areal/infra/controller/rollout_controller.py index ec327cc545..6f3b731fba 100644 --- a/areal/infra/controller/rollout_controller.py +++ b/areal/infra/controller/rollout_controller.py @@ -983,6 +983,46 @@ def task_input_generator(): trajectories = [r.trajectory if r is not None else None for r in results] return [t for t in trajectories if t is not None] + def compute_logp(self, data: list[dict[str, Any]]) -> list[Any]: + """Compute token log-probabilities for trajectories via remote workers.""" + if len(data) == 0: + return [] + + async def _compute(): + indexed_chunks: list[list[int]] = [] + tasks = [] + n_workers = len(self.workers) + if n_workers == 0: + raise RuntimeError("No workers available for compute_logp.") + + for rank, worker in enumerate(self.workers): + idxs = list(range(rank, len(data), n_workers)) + if not idxs: + continue + chunk = [data[i] for i in idxs] + indexed_chunks.append(idxs) + tasks.append( + self.scheduler.async_call_engine( + worker_id=worker.id, + method="compute_logp", + engine_name=self._engine_name(rank), + data=chunk, + http_timeout=self.config.request_timeout, + ) + ) + rpc_results = await asyncio.gather(*tasks) + merged: list[Any] = [None] * len(data) + for idxs, chunk_result in zip(indexed_chunks, rpc_results): + if len(chunk_result) != len(idxs): + raise RuntimeError( + f"compute_logp result length mismatch: got {len(chunk_result)}, expected {len(idxs)}" + ) + for out_idx, value in zip(idxs, chunk_result): + merged[out_idx] = value + return merged + + return run_async_task(_compute) + async def agenerate(self, req: ModelRequest) -> ModelResponse: """Asynchronously generate a response for the given request. diff --git a/areal/infra/remote_inf_engine.py b/areal/infra/remote_inf_engine.py index aacafbe20f..53d1311f02 100644 --- a/areal/infra/remote_inf_engine.py +++ b/areal/infra/remote_inf_engine.py @@ -20,6 +20,7 @@ import numpy as np import ray import requests +import torch import torch.distributed as dist import uvloop from torchdata.stateful_dataloader import StatefulDataLoader @@ -174,6 +175,18 @@ def parse_generation_response( """ ... + def build_score_request( + self, input_ids: list[int], target_len: int, with_lora: bool, version: int + ) -> HttpRequest: + """Build HTTP request for token log-prob scoring.""" + ... + + def parse_score_response( + self, response: dict[str, Any], target_len: int + ) -> list[float]: + """Parse token log-prob scoring response.""" + ... + def build_disk_weight_update_requests( self, meta: WeightUpdateMeta ) -> WeightUpdateRequests: @@ -502,6 +515,55 @@ def get_version(self): with self.lock: return self._version + def compute_logp(self, data: list[dict[str, Any]]) -> list[torch.Tensor]: + results: list[torch.Tensor] = [] + timeout = self.config.request_timeout + version = self.get_version() + for traj in data: + input_ids = traj["input_ids"] + loss_mask = traj["loss_mask"] + if input_ids.dim() != 2 or loss_mask.dim() != 2: + raise ValueError("input_ids and loss_mask must be 2D tensors") + bs = input_ids.shape[0] + out = torch.zeros_like(loss_mask, dtype=torch.float32) + for i in range(bs): + token_ids = input_ids[i].tolist() + target_len = int(loss_mask[i].sum().item()) + if target_len <= 0: + continue + if "attention_mask" in traj: + attn_mask = traj["attention_mask"][i] + active_idx = torch.nonzero(attn_mask, as_tuple=False).squeeze(-1) + token_ids = input_ids[i, active_idx].tolist() + else: + token_ids = input_ids[i].tolist() + server_addr = self.choose_server() + http_req = self.backend.build_score_request( + input_ids=token_ids, + target_len=target_len, + with_lora=self.config.use_lora, + version=version, + ) + response = requests.request( + http_req.method, + f"http://{server_addr}{http_req.endpoint}", + json=http_req.payload, + timeout=timeout, + ) + response.raise_for_status() + payload = response.json() + token_logps = self.backend.parse_score_response(payload, target_len) + if len(token_logps) != target_len: + raise ValueError( + f"Expected {target_len} token logprobs, got {len(token_logps)}" + ) + write_idx = torch.nonzero(loss_mask[i], as_tuple=False).squeeze(-1) + out[i, write_idx] = torch.tensor( + token_logps, device=out.device, dtype=out.dtype + ) + results.append(out) + return results + def set_proxy_gateway_addr(self, addr: str) -> None: """Set the proxy gateway address. diff --git a/areal/trainer/rl_trainer.py b/areal/trainer/rl_trainer.py index c7249eb281..5512b79554 100644 --- a/areal/trainer/rl_trainer.py +++ b/areal/trainer/rl_trainer.py @@ -181,11 +181,21 @@ def __init__( self.ref = self._create_train_engine(config.ref, ref_alloc) self.teacher = None + self.teacher_alloc = None if config.teacher is not None: - teacher_alloc = ModelAllocation.from_str( - config.teacher.backend, name="teacher" - ) - self.teacher = self._create_train_engine(config.teacher, teacher_alloc) + if config.teacher.engine_type == "rollout": + self.teacher_alloc = ModelAllocation.from_str( + config.teacher.rollout.backend, name="teacher" + ) + else: + assert config.teacher.train is not None + self.teacher_alloc = ModelAllocation.from_str( + self.config.teacher.train.backend, name="teacher" + ) + logger.warning( + "teacher.engine_type='train' uses legacy train-engine teacher path " + "and is deprecated; please migrate to engine_type='rollout'." + ) steps_per_epoch: int | None = None self.train_dataloader: StatefulDataLoader | _EmptyDataLoader @@ -281,7 +291,14 @@ def __init__( if self.ref is not None: self.ref.initialize(**engine_init_kwargs, role="ref") - if self.teacher is not None: + if ( + self.config.teacher is not None + and self.config.teacher.engine_type == "train" + ): + assert self.config.teacher.train is not None + self.teacher = self._create_train_engine( + self.config.teacher.train, self.teacher_alloc + ) self.teacher.initialize(**engine_init_kwargs, role="teacher") # Save initial LoRA weights if enabled (for inference server pre-loading) @@ -297,6 +314,11 @@ def __init__( self.eval_rollout = self._init_rollout( config.rollout, is_eval=True, lora_path=initial_lora_path ) + if ( + self.config.teacher is not None + and self.config.teacher.engine_type == "rollout" + ): + self.teacher = self._init_teacher_rollout(self.config.teacher.rollout) # Proxy worker initialization (lazy, for AgentWorkflow support) self._proxy_started = False @@ -634,7 +656,6 @@ def train( traj["distill_loss_weight"] = ( self.config.teacher.distill_loss_weight ) - self.teacher.get_device_stats().log("teacher logp") if self._should_offload_teacher: self._offload_model(self.teacher, role="teacher") @@ -835,6 +856,8 @@ def close(self): if self.eval_rollout is not None: self.eval_rollout.destroy() self.rollout.destroy() + if self.teacher is not None: + self.teacher.destroy() if self.ref is not None: self.ref.destroy() if self.critic is not None: @@ -1050,6 +1073,50 @@ def _init_rollout( controller.initialize(**init_kwargs) return controller + def _init_teacher_rollout( + self, rollout_config: InferenceEngineConfig + ) -> InferenceEngine | RolloutController: + if self.teacher_alloc is None: + raise RuntimeError("teacher_alloc is not initialized") + rollout_alloc = self.teacher_alloc + config = deepcopy(rollout_config) + if rollout_alloc.backend == "sglang": + engine_cls = RemoteSGLangEngine + teacher_sglang_cfg = deepcopy(self.config.sglang) + if self.config.teacher is not None and self.config.teacher.path: + teacher_sglang_cfg.model_path = self.config.teacher.path + server_args = SGLangConfig.build_args( + sglang_config=teacher_sglang_cfg, + tp_size=rollout_alloc.parallel.tp_size, + pp_size=rollout_alloc.parallel.pp_size, + base_gpu_id=0, + ) + elif rollout_alloc.backend == "vllm": + engine_cls = RemotevLLMEngine + teacher_vllm_cfg = deepcopy(self.config.vllm) + if self.config.teacher is not None and self.config.teacher.path: + teacher_vllm_cfg.model = self.config.teacher.path + if not rollout_config.tokenizer_path: + config.tokenizer_path = self.config.teacher.path + server_args = vLLMConfig.build_args( + vllm_config=teacher_vllm_cfg, + tp_size=rollout_alloc.parallel.tp_size, + pp_size=rollout_alloc.parallel.pp_size, + ) + else: + raise ValueError( + f"Invalid teacher rollout backend: {rollout_alloc.backend}, expected sglang or vllm" + ) + if not is_single_controller(): + engine = engine_cls(config) + engine.initialize( + train_data_parallel_size=self.actor_alloc.parallel.dp_size + ) + return engine + controller = engine_cls.as_controller(config, self.scheduler) + controller.initialize(role="teacher", server_args=server_args) + return controller + def _save_initial_lora_weights(self) -> str | None: """Save initial LoRA weights for inference server pre-loading. diff --git a/docs/en/algorithms/distillation.md b/docs/en/algorithms/distillation.md index f52ae69107..e4c36f0b2e 100644 --- a/docs/en/algorithms/distillation.md +++ b/docs/en/algorithms/distillation.md @@ -23,7 +23,9 @@ bias. A simple yet effective method is to maximize the log-likelihood on data generated by the teacher, known as supervised fine-tuning (SFT). This is equivalent to minimizing the -Forward KL divergence between $\pi_T$ and $\pi_\theta$: $$\arg \min_{\theta} +Forward KL divergence between $\pi_T$ and $\pi_\theta$: + +$$\arg \min_{\theta} D_{KL}(\pi_T \parallel \pi_\theta) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_T(\cdot|q)} [\log \pi_\theta(o|q)]$$ @@ -33,7 +35,9 @@ While SFT is efficient, training on off-policy data induces exposure bias: a mis between training on teacher-generated prefixes and inference on self-generated prefixes. This is especially severe for reasoning LLMs with long response chains. To alleviate this, we can train on self-generated trajectories, which is equivalent to minimizing the -Reverse KL divergence (RKL) [1]: $$\arg \min_{\theta} D_{KL}(\pi_\theta +Reverse KL divergence (RKL) [1]: + +$$\arg \min_{\theta} D_{KL}(\pi_\theta \parallel \pi_T) = \arg \max_{\theta} \mathbb{E}_{q \sim Q, o \sim \pi_\theta(\cdot|q)} \left[ \log \frac{\pi_T(o|q)}{\pi_\theta(o|q)} \right]$$ @@ -78,29 +82,74 @@ J_{RKL}(\theta)$. ## Running the example -Need to add teacher configuration to your yaml: +Need to add teacher configuration to your yaml. + +Teacher supports two modes via `teacher.engine_type`: + +- `rollout` (recommended): inference-only teacher (vLLM/SGLang) with lower memory + overhead. +- `train` (legacy, deprecated): train-engine teacher path kept for backward + compatibility. + +### Mode 1: rollout teacher (recommended) + + +```yaml +teacher: + engine_type: rollout + path: Qwen/Qwen2.5-14B-Instruct + rollout: + backend: "vllm:d1p1t2" # or sglang:d... + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + rl_loss_weight: 1.0 + distill_loss_weight: 5e-3 +``` + +Example command using local scheduler: + +```bash +python3 examples/math/gsm8k_rl.py \ + --config examples/distillation/gsm8k_grpo_distill_mode_rolloutEngine.yaml \ + scheduler.type=local \ + experiment_name=gsm8k-grpo-distillation \ + trial_name=trial0 +``` + +### Mode 2: legacy train teacher (deprecated) ```yaml teacher: - backend: fsdp:d1p1t4 + engine_type: train + train: + backend: fsdp:d1p1t4 + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen2.5-14B-Instruct + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_spec: ${actor.scheduling_spec} rl_loss_weight: 1.0 distill_loss_weight: 0.005 - experiment_name: ${experiment_name} - trial_name: ${trial_name} - path: Qwen/Qwen3-32B - init_from_scratch: false - disable_dropout: true - dtype: ${actor.dtype} - mb_spec: - max_tokens_per_mb: 10240 - optimizer: null - scheduling_spec: ${actor.scheduling_spec} ``` Example command using local scheduler: ```bash -python3 examples/math/gsm8k_rl.py --config examples/distillation/gsm8k_grpo_distill.yaml scheduler.type=local experiment_name=gsm8k-grpo-distillation trial_name=trial0 +python3 examples/math/gsm8k_rl.py \ + --config examples/distillation/gsm8k_grpo_distill_mode_trainEngine.yaml \ + scheduler.type=local \ + experiment_name=gsm8k-grpo-distillation \ + trial_name=trial0 ``` ## Result diff --git a/docs/en/algorithms/reward_curve.png b/docs/en/algorithms/reward_curve.png index 10d2db17e0..08b639e529 100644 Binary files a/docs/en/algorithms/reward_curve.png and b/docs/en/algorithms/reward_curve.png differ diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 63fad21047..cb640a8573 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -1233,71 +1233,12 @@ Configuration for per-session lifecycle tracing. Configuration class: TeacherConfig -| Parameter | Type | Default | Description | -| --------------------------- | --------------------------------------------------------------- | --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Forward/backward compute dtype. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer_dtype` | string | `"float32"` | Underlying parameter storage dtype, also the dtype of optimizer states (exp_avg, exp_avg_sq) since torch.optim.AdamW inherits dtype from model.parameters(). Default 'float32' maintains fp32 master weights matching DeepSpeed ZeRO-3 and Megatron precision-aware optimizer behavior. FSDP2's MixedPrecisionPolicy(param_dtype=`dtype`) will still cast forward/backward computation to `dtype` (e.g. bfloat16). Set to 'bfloat16' together with optimizer.type='adam_bf16' to reduce memory at the cost of needing Kahan summation for stability. Currently FSDP-only; Megatron uses use_precision_aware_optimizer instead and ignores this field. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | -| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | -| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `rejection_sampling` | [`RejectionSamplingConfig`](section-rejection-sampling) \| None | `None` | Rejection sampling configuration for filtering stale samples. None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). Only effective when use_decoupled_loss=True. | -| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | -| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. **Choices:** `recompute`, `loglinear`, `metrics` | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | -| `rl_loss_weight` | float | `1.0` | RL loss weight | -| `distill_loss_weight` | float | `0.005` | Distillation loss weight | +| Parameter | Type | Default | Description | +| --------------------- | ----------------------------------------------------------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------ | +| `engine_type` | string | `"rollout"` | Teacher engine type. 'rollout' uses inference engine scoring; 'train' uses the legacy train-engine teacher path. **Choices:** `rollout`, `train` | +| `rollout` | [`InferenceEngineConfig`](section-inference-engine) \| None | `None` | - | +| `train` | [`PPOActorConfig`](section-ppo-actor) \| None | `None` | Legacy train-engine teacher config. Required when engine_type='train'. | +| `path` | string | `""` | Teacher model path. If set, overrides shared rollout backend model path. | +| `offload` | boolean | `False` | Whether to offload teacher rollout model between steps | +| `rl_loss_weight` | float | `1.0` | RL loss weight | +| `distill_loss_weight` | float | `0.005` | Distillation loss weight | diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index 79cd7e754c..a73cf0eb3d 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -1231,71 +1231,12 @@ Configuration for per-session lifecycle tracing. Configuration class: TeacherConfig -| Parameter | Type | Default | Description | -| --------------------------- | --------------------------------------------------------------- | --------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `experiment_name` | string | **Required** | - | -| `trial_name` | string | **Required** | - | -| `path` | string | `""` | Path to HuggingFace checkpoint | -| `attn_impl` | string | `"flash_attention_2"` | Attention implementation for huggingface transformers model. Accepts builtin transformers backends or a Hugging Face kernels repo ID formatted as org/repo\[@revision\]\[:entrypoint\]. **Choices:** `eager`, `sdpa`, `flash_attention_2`, `flash_attention_3`, `flex_attention` | -| `use_kernels` | boolean | `False` | Enable Hugging Face kernels model kernelization after model creation. | -| `init_from_scratch` | boolean | `False` | Initialize model weights randomly | -| `is_critic` | boolean | `False` | Whether to use a critic/reward model | -| `temperature` | float | `1.0` | Temperature during generation. | -| `mb_spec` | [`MicroBatchSpec`](section-micro-batch) | **Required** | - | -| `pad_to_maximum` | boolean | `False` | Whether to pad each microbatch to the length upper bound specified by mb_spec. Can reduce memory fragmentation but slows down training. | -| `disable_dropout` | boolean | `False` | Disable dropout layers during training | -| `gradient_checkpointing` | boolean | `False` | Enable gradient checkpointing | -| `dtype` | string | `"bfloat16"` | Forward/backward compute dtype. | -| `grad_reduce_dtype` | string | `"float32"` | Gradient reduction data type. | -| `optimizer_dtype` | string | `"float32"` | Underlying parameter storage dtype, also the dtype of optimizer states (exp_avg, exp_avg_sq) since torch.optim.AdamW inherits dtype from model.parameters(). Default 'float32' maintains fp32 master weights matching DeepSpeed ZeRO-3 and Megatron precision-aware optimizer behavior. FSDP2's MixedPrecisionPolicy(param_dtype=`dtype`) will still cast forward/backward computation to `dtype` (e.g. bfloat16). Set to 'bfloat16' together with optimizer.type='adam_bf16' to reduce memory at the cost of needing Kahan summation for stability. Currently FSDP-only; Megatron uses use_precision_aware_optimizer instead and ignores this field. | -| `optimizer` | [`OptimizerConfig`](section-optimizer) \| None | `None` | Optimizer configuration. None means no training. | -| `weight_update_mode` | string | `"xccl"` | Weight update backend type. **Choices:** `disk`, `xccl` | -| `fsdp` | [`FSDPEngineConfig`](section-fsdp-engine) | **Required** | - | -| `archon` | [`ArchonEngineConfig`](section-archon-engine) | **Required** | - | -| `megatron` | [`MegatronEngineConfig`](section-megatron-engine) | **Required** | - | -| `offload` | boolean | `False` | Whether to offload model parameters and optimizer states to CPU. | -| `use_lora` | boolean | `False` | Whether to use LoRA. Only support FSDP. Note that should be enabled together with vLLM/SGLang. | -| `lora_rank` | integer | `32` | lora rank | -| `lora_alpha` | integer | `16` | lora alpha | -| `target_modules` | list of string | **Required** | lora target_modules. | -| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | -| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | -| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | -| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | -| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | -| `admin_api_key` | string | `"areal-admin-key"` | Admin API key used by gateway/router/data-proxy in controller v2. | -| `log_level` | string | `"warning"` | Gateway stack log level for controller v2. | -| `request_timeout` | float | `3600.0` | Gateway request timeout in seconds for controller v2. | -| `setup_timeout` | float | `3600.0` | Gateway setup timeout in seconds for controller v2. | -| `scheduling_strategy` | [`SchedulingStrategy`](section-scheduling-strategy) | **Required** | The scheduling strategy of this TrainEngine, either separation or colocation. Currently only used by the TrainController. | -| `ppo_n_minibatches` | integer | `4` | Number of minibatches for each PPO update | -| `eps_clip` | float | `0.2` | Clipping factor for policy ratio | -| `eps_clip_higher` | float \| None | `None` | Clipping factor (higher value) for policy ratio. Default is None. When eps_clip_higher is set (decoupled), eps_clip will be used as the lower value. | -| `c_clip` | float \| None | `None` | Dual clipping factor for policy ratio, must be > 1.0. None disables dual clipping. | -| `m2_threshold` | float \| None | `None` | The second momentum threshold for M2PO. | -| `reward_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for rewards | -| `reward_scaling` | float | `1.0` | Reward scaling factor | -| `reward_bias` | float | `0.0` | Reward bias | -| `reward_clip` | float | `20.0` | Maximum absolute value for reward clipping | -| `overlong_reward_penalty` | boolean | `False` | Penalty for overlong sequences. Used within DAPO. | -| `overlong_tokens` | integer \| None | `None` | Number of tokens in the tail that will receive a penalty | -| `overlong_penalty_factor` | float \| None | `None` | Penalty factor for tokens in the tail | -| `mask_no_eos_with_zero` | boolean | `False` | Mask truncated generations (no EOS token) and exclude from training | -| `discount` | float | `1.0` | Discount factor for future rewards | -| `gae_lambda` | float | `1.0` | Lambda parameter for GAE | -| `adv_norm` | [`NormConfig`](section-norm) \| None | `None` | Normalization configuration for advantages. | -| `kl_ctl` | float | `0.1` | KL divergence coefficient | -| `kl_estimator` | string | `"k1"` | KL divergence estimator **Choices:** `k1`, `k2`, `k3` | -| `use_sapo_loss` | boolean | `False` | Use SAPO loss (mutually exclusive with PPO clipping) | -| `sapo_tau_pos` | float | `1.0` | SAPO temperature for positive advantages | -| `sapo_tau_neg` | float | `1.05` | SAPO temperature for negative advantages | -| `recompute_logprob` | boolean | `False` | Recompute log probability and replace the log probability returned by inference. | -| `use_decoupled_loss` | boolean | `False` | Use the decoupled loss. Implicitly enables recompute_logprob. | -| `rejection_sampling` | [`RejectionSamplingConfig`](section-rejection-sampling) \| None | `None` | Rejection sampling configuration for filtering stale samples. None disables filtering (equivalent to old behave_imp_weight_mode='disabled'). Only effective when use_decoupled_loss=True. | -| `importance_sampling_level` | string | `"token"` | Level at which to compute importance sampling ratios. 'token': per-token ratios (standard PPO). 'sequence': sequence-level geometric mean of per-token ratios (GSPO). **Choices:** `token`, `sequence` | -| `prox_logp_method` | string | `"recompute"` | Method for computing proximal policy log-probabilities in decoupled PPO. Only effective when use_decoupled_loss=True. Options: 'recompute' (default): Standard decoupled PPO, recompute proximal policy via forward pass. 'loglinear': Use log-linear interpolation to approximate proximal policy (skip forward pass). 'metrics': Like 'recompute', but also compute approximation metrics for evaluation. **Choices:** `recompute`, `loglinear`, `metrics` | -| `log_agent_stats` | boolean | `False` | Log statistics for agent trajectories | -| `log_agent_stats_keys` | list of string | **Required** | Keys for logging agent trajectory statistics | -| `max_new_tokens` | integer | `1024` | Maximum number of new tokens to generate | -| `rl_loss_weight` | float | `1.0` | RL loss weight | -| `distill_loss_weight` | float | `0.005` | Distillation loss weight | +| Parameter | Type | Default | Description | +| --------------------- | ----------------------------------------------------------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------ | +| `engine_type` | string | `"rollout"` | Teacher engine type. 'rollout' uses inference engine scoring; 'train' uses the legacy train-engine teacher path. **Choices:** `rollout`, `train` | +| `rollout` | [`InferenceEngineConfig`](section-inference-engine) \| None | `None` | - | +| `train` | [`PPOActorConfig`](section-ppo-actor) \| None | `None` | Legacy train-engine teacher config. Required when engine_type='train'. | +| `path` | string | `""` | Teacher model path. If set, overrides shared rollout backend model path. | +| `offload` | boolean | `False` | Whether to offload teacher rollout model between steps | +| `rl_loss_weight` | float | `1.0` | RL loss weight | +| `distill_loss_weight` | float | `0.005` | Distillation loss weight | diff --git a/examples/distillation/gsm8k_grpo_distill_mode_rolloutEngine.yaml b/examples/distillation/gsm8k_grpo_distill_mode_rolloutEngine.yaml new file mode 100644 index 0000000000..f9554d791d --- /dev/null +++ b/examples/distillation/gsm8k_grpo_distill_mode_rolloutEngine.yaml @@ -0,0 +1,195 @@ +experiment_name: gsm8k-grpo +trial_name: trial0 + +seed: 1 +enable_offload: false +total_train_epochs: 10 +tokenizer_path: ${actor.path} + +cluster: + n_nodes: 1 + n_gpus_per_node: 8 + fileroot: /tmp/areal/experiments + name_resolve: + type: nfs + nfs_record_root: /tmp/areal/name_resolve + +scheduler: + type: null + +rollout: + backend: "vllm:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + max_concurrent_rollouts: 256 + queue_size: null + consumer_batch_size: ${train_dataset.batch_size} + max_head_offpolicyness: 2 + enable_rollout_tracing: false + scheduling_spec: ${actor.scheduling_spec} + fileroot: ${cluster.fileroot} + tokenizer_path: ${tokenizer_path} + dump_to_file: true + +gconfig: + n_samples: 4 + min_new_tokens: 0 + max_new_tokens: 2048 + greedy: false + temperature: 1.0 + +actor: + backend: "fsdp:d1p1t1" + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: Qwen/Qwen3-0.6B + init_from_scratch: false + disable_dropout: true + gradient_checkpointing: true + dtype: bfloat16 + mb_spec: + max_tokens_per_mb: 10240 + optimizer: + type: adam + lr: 1.70e-5 + weight_decay: 0.017 + beta1: 0.9 + beta2: 0.999 + eps: 1e-8 + lr_scheduler_type: constant + gradient_clipping: 1.0 + warmup_steps_proportion: 0.001 + eps_clip: 0.4 + temperature: ${gconfig.temperature} + reward_scaling: 10.0 + reward_bias: -0.5 + kl_ctl: 0.0 + ppo_n_minibatches: 1 + recompute_logprob: true + use_decoupled_loss: true + rejection_sampling: + metric: ratio + upper: 5.0 + reward_norm: + mean_level: group + std_level: group + group_size: ${gconfig.n_samples} + adv_norm: + mean_level: batch + std_level: batch + weight_update_mode: xccl + max_new_tokens: ${gconfig.max_new_tokens} + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + +teacher: + engine_type: rollout + path: Qwen/Qwen2.5-14B-Instruct + rollout: + backend: "vllm:d1p1t2" + scheduling_spec: + - task_type: worker + port_count: 2 + gpu: 1 + mem: 32 + cmd: python3 -m areal.infra.rpc.rpc_server + env_vars: {} + rl_loss_weight: 1.0 + distill_loss_weight: 5e-3 + +ref: + backend: ${actor.backend} + experiment_name: ${experiment_name} + trial_name: ${trial_name} + path: ${actor.path} + init_from_scratch: false + disable_dropout: true + dtype: ${actor.dtype} + mb_spec: + max_tokens_per_mb: 10240 + optimizer: null + scheduling_strategy: + type: colocation + target: actor + scheduling_spec: ${actor.scheduling_spec} + +# SGLang +sglang: + model_path: ${actor.path} + random_seed: ${seed} + skip_tokenizer_init: true + dtype: ${actor.dtype} + max_running_requests: null + context_length: 32768 + mem_fraction_static: 0.8 + +vllm: + model: ${actor.path} + seed: ${seed} + skip_tokenizer_init: false + dtype: ${actor.dtype} + max_model_len: 32768 + gpu_memory_utilization: 0.8 + +# datasets +train_dataset: + batch_size: 256 + shuffle: true + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + max_length: 1024 + +valid_dataset: + batch_size: 256 + pin_memory: true + num_workers: 4 + path: openai/gsm8k + type: rl + +# Utilities +saver: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +recover: + mode: disabled + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: 3600 + +evaluator: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + freq_epochs: 1 + freq_steps: null + freq_secs: null + +stats_logger: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + wandb: + mode: disabled + +perf_tracer: + experiment_name: ${experiment_name} + trial_name: ${trial_name} + fileroot: ${cluster.fileroot} + enabled: false + session_tracer: + enabled: false diff --git a/examples/distillation/gsm8k_grpo_distill.yaml b/examples/distillation/gsm8k_grpo_distill_mode_trainEngine.yaml similarity index 100% rename from examples/distillation/gsm8k_grpo_distill.yaml rename to examples/distillation/gsm8k_grpo_distill_mode_trainEngine.yaml