diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index d033260..6b8b29b 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -409,7 +409,8 @@ def train(args): top_p=args.top_p, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, - # reward model + # reward model / teacher model URL (used for OPD) + remote_rm_url=args.remote_rm_url, reward_fn=reward_fn, reward_fn_label_map=label_map, reward_recipe=RECIPE, @@ -560,12 +561,13 @@ def train(args): parser.add_argument( "--advantage_estimator", type=str, - choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++"], + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++", "on_policy_distillation", "on_policy_distillation_hybrid"], default="gae", help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, reinforce++", ) parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO") + parser.add_argument("--opd_kl_coef", type=float, default=1.0, help="KL coefficient for on-policy distillation penalty") # LoRA parser.add_argument("--load_in_4bit", action="store_true", default=False) @@ -652,7 +654,7 @@ def train(args): elif args.critic_pretrain is None: args.critic_pretrain = args.pretrain - if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]: + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm", "on_policy_distillation", "on_policy_distillation_hybrid"]: assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" if args.use_kl_loss: diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md new file mode 100644 index 0000000..af20882 --- /dev/null +++ b/examples/on_policy_distillation/README.md @@ -0,0 +1,232 @@ +# On-Policy Distillation (OPD) for LightRFT + +On-policy knowledge distillation enables smaller student models to learn from larger teacher models during reinforcement learning training. + +## Overview + +| Aspect | Description | +|--------|-------------| +| **Teacher** | Large model providing token-level log probability supervision | +| **Student** | Smaller model being trained to match teacher's distribution | +| **Training** | On-policy: teacher evaluates student's actual generated responses | +| **Reward** | Teacher's log probabilities serve as the learning signal | + +## Quick Start + +### 1. Start Teacher Model Server + +```bash +# Launch SGLang server for teacher model +CUDA_VISIBLE_DEVICES=7 python3 -m sglang.launch_server \ + --model-path "Qwen/Qwen2.5-7B-Instruct" \ + --host 0.0.0.0 \ + --port 13141 \ + --tp 1 \ + --mem-fraction-static 0.6 +``` + +### 2. Run Training + +```bash +bash examples/on_policy_distillation/run_opd_qwen_2.sh +``` + +Or manually: + +```bash +torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "Qwen/Qwen2.5-0.5B-Instruct" \ + --advantage_estimator "on_policy_distillation" \ + --remote_rm_url "http://127.0.0.1:13141/generate" \ + --reward_pretrain "" \ + --n_samples_per_prompt 4 \ + --actor_learning_rate 1e-6 \ + --init_kl_coef 0.01 \ + --num_episodes 30 +``` + +## Architecture + +``` +┌─────────────────────────────────────────────────────────────┐ +│ OPD Training Pipeline │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ 1. Generate Prompt ──► [Student] ──► Response │ +│ │ +│ 2. Evaluate [Prompt + Response] ──► [Teacher Server] │ +│ │ │ +│ ▼ │ +│ Teacher Log Probs │ +│ │ +│ 3. Compute Advantage = Teacher_logp - Student_logp │ +│ │ +│ 4. Update Student ◄── Policy Gradient Loss │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## Key Components + +### 1. Advantage Calculator + +**File**: `lightrft/trainer/advantage_calculator.py` + +```python +class OnPolicyDistillationCalculator(AdvantageCalculator): + def compute(self, experience, ...): + teacher_log_probs = experience.info["teacher_log_probs"] + student_log_probs = experience.action_log_probs + + # Reverse KL: encourages student to match teacher + reverse_kl = student_log_probs - teacher_log_probs + advantages = -reverse_kl # = teacher - student + + return advantages, returns, {"opd_reverse_kl": reverse_kl} +``` + +### 2. Teacher Log Prob Fetcher + +**File**: `examples/on_policy_distillation/on_policy_distillation_reward.py` + +- Async HTTP requests to teacher server +- Supports SGLang and vLLM response formats +- Automatic retry with exponential backoff + +### 3. Experience Maker Integration + +**File**: `lightrft/trainer/fast_exp_maker.py` + +- `--remote_rm_url` is used as teacher URL (not reward model) when `--advantage_estimator "on_policy_distillation"` +- Teacher log probs stored in `experience.info["teacher_log_probs"]` +- OPD metrics (`opd_reverse_kl_mean/std/min/max`) logged to wandb + +## Configuration + +### Required Arguments + +| Argument | Value | Description | +|----------|-------|-------------| +| `--advantage_estimator` | `"on_policy_distillation"` | Enable OPD mode | +| `--remote_rm_url` | `"http://host:port/generate"` | Teacher server URL | +| `--reward_pretrain` | `""` | Empty (no reward model needed) | + +### Recommended Hyperparameters + +| Parameter | Value | Description | +|-----------|-------|-------------| +| `--n_samples_per_prompt` | 4 | Responses per prompt | +| `--actor_learning_rate` | 1e-6 | Student learning rate | +| `--init_kl_coef` | 0.01 | KL regularization coefficient | +| `--num_episodes` | 30 | Training episodes | + +## Teacher Server Formats + +### SGLang (Recommended) + +```json +{ + "meta_info": { + "input_token_logprobs": [[logprob, rank, token], ...] + } +} +``` + +### vLLM + +```json +{ + "token_logprobs": [logprob1, logprob2, ...] +} +``` + +## Monitoring + +### Logged Metrics + +| Metric | Description | +|--------|-------------| +| `opd_reverse_kl_mean` | Average KL(student \|\| teacher) | +| `opd_reverse_kl_std` | Standard deviation of reverse KL | +| `advantages_mean` | Average advantage (should center ~0) | +| `policy_loss` | Should decrease during training | + +### Console Output + +``` +📊 Detailed Step Statistics +============================================================ +🎁 Total Reward: 0.0000 ± 0.0000 (placeholder for OPD) +📈 Advantages: 0.0012 ± 0.8234 (...) +🎓 OPD Reverse KL: 0.1523 ± 0.0891 (...) +============================================================ +``` + +## Troubleshooting + +### Teacher Server Issues + +```bash +# Check if port is in use +lsof -i :13141 + +# Check GPU availability +nvidia-smi + +# Reduce memory if OOM +--mem-fraction-static 0.5 +``` + +### Training OOM + +```bash +--micro_train_batch_size 2 +--micro_rollout_batch_size 2 +--gradient_checkpointing +--zero_stage 3 +``` + +### Slow Convergence + +```bash +--n_samples_per_prompt 8 +--actor_learning_rate 5e-7 +--num_episodes 50 +``` + +## Comparison with Other Methods + +| Method | Reward Signal | Mode | Requires RM | +|--------|--------------|------|-------------| +| GRPO | Task-specific reward | Online | Yes | +| DPO | Preference pairs | Offline | No | +| **OPD** | Teacher log probs | Online | No (uses teacher) | + +### Advantages + +- No separate reward model training required +- Token-level supervision (finer than sequence-level) +- On-policy: adapts to student's changing distribution +- Works for any task with a good teacher model + +### Limitations + +- Requires running teacher model (inference overhead) +- Student cannot exceed teacher's capabilities +- Needs sufficient compute for teacher inference + +## File Structure + +``` +examples/on_policy_distillation/ +├── README.md # This file +├── README_zh.md # Chinese version +├── run_opd_qwen_2.sh # Training script +└── on_policy_distillation_reward.py # Teacher logprob fetcher +``` + +## References + +- [LightRFT Documentation](../../README.md) +- [Advantage Calculator Source](../../lightrft/trainer/advantage_calculator.py) +- [Fast Experience Maker Source](../../lightrft/trainer/fast_exp_maker.py) diff --git a/examples/on_policy_distillation/README_zh.md b/examples/on_policy_distillation/README_zh.md new file mode 100644 index 0000000..c623912 --- /dev/null +++ b/examples/on_policy_distillation/README_zh.md @@ -0,0 +1,232 @@ +# LightRFT 在线策略蒸馏 (OPD) + +在线策略知识蒸馏使小型学生模型能够在强化学习训练过程中从大型教师模型学习。 + +## 概述 + +| 方面 | 描述 | +|------|------| +| **教师模型** | 提供 token 级别对数概率监督的大型模型 | +| **学生模型** | 被训练以匹配教师分布的小型模型 | +| **训练方式** | 在线策略:教师评估学生实际生成的响应 | +| **奖励信号** | 教师的对数概率作为学习信号 | + +## 快速开始 + +### 1. 启动教师模型服务器 + +```bash +# 启动 SGLang 服务器运行教师模型 +CUDA_VISIBLE_DEVICES=7 python3 -m sglang.launch_server \ + --model-path "Qwen/Qwen2.5-7B-Instruct" \ + --host 0.0.0.0 \ + --port 13141 \ + --tp 1 \ + --mem-fraction-static 0.6 +``` + +### 2. 运行训练 + +```bash +bash examples/on_policy_distillation/run_opd_qwen_2.sh +``` + +或手动运行: + +```bash +torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "Qwen/Qwen2.5-0.5B-Instruct" \ + --advantage_estimator "on_policy_distillation" \ + --remote_rm_url "http://127.0.0.1:13141/generate" \ + --reward_pretrain "" \ + --n_samples_per_prompt 4 \ + --actor_learning_rate 1e-6 \ + --init_kl_coef 0.01 \ + --num_episodes 30 +``` + +## 架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ OPD 训练流程 │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ 1. 生成 提示词 ──► [学生模型] ──► 响应 │ +│ │ +│ 2. 评估 [提示词 + 响应] ──► [教师服务器] │ +│ │ │ +│ ▼ │ +│ 教师对数概率 │ +│ │ +│ 3. 计算 优势值 = 教师_logp - 学生_logp │ +│ │ +│ 4. 更新 学生模型 ◄── 策略梯度损失 │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## 核心组件 + +### 1. 优势值计算器 + +**文件**: `lightrft/trainer/advantage_calculator.py` + +```python +class OnPolicyDistillationCalculator(AdvantageCalculator): + def compute(self, experience, ...): + teacher_log_probs = experience.info["teacher_log_probs"] + student_log_probs = experience.action_log_probs + + # 反向 KL:鼓励学生匹配教师 + reverse_kl = student_log_probs - teacher_log_probs + advantages = -reverse_kl # = teacher - student + + return advantages, returns, {"opd_reverse_kl": reverse_kl} +``` + +### 2. 教师对数概率获取器 + +**文件**: `examples/on_policy_distillation/on_policy_distillation_reward.py` + +- 异步 HTTP 请求到教师服务器 +- 支持 SGLang 和 vLLM 响应格式 +- 自动重试(指数退避) + +### 3. Experience Maker 集成 + +**文件**: `lightrft/trainer/fast_exp_maker.py` + +- 当 `--advantage_estimator "on_policy_distillation"` 时,`--remote_rm_url` 作为教师 URL(而非奖励模型) +- 教师对数概率存储在 `experience.info["teacher_log_probs"]` +- OPD 指标(`opd_reverse_kl_mean/std/min/max`)记录到 wandb + +## 配置 + +### 必需参数 + +| 参数 | 值 | 描述 | +|------|---|------| +| `--advantage_estimator` | `"on_policy_distillation"` | 启用 OPD 模式 | +| `--remote_rm_url` | `"http://host:port/generate"` | 教师服务器 URL | +| `--reward_pretrain` | `""` | 空值(不需要奖励模型) | + +### 推荐超参数 + +| 参数 | 值 | 描述 | +|------|---|------| +| `--n_samples_per_prompt` | 4 | 每个提示词的响应数 | +| `--actor_learning_rate` | 1e-6 | 学生学习率 | +| `--init_kl_coef` | 0.01 | KL 正则化系数 | +| `--num_episodes` | 30 | 训练轮数 | + +## 教师服务器格式 + +### SGLang(推荐) + +```json +{ + "meta_info": { + "input_token_logprobs": [[logprob, rank, token], ...] + } +} +``` + +### vLLM + +```json +{ + "token_logprobs": [logprob1, logprob2, ...] +} +``` + +## 监控 + +### 记录的指标 + +| 指标 | 描述 | +|------|------| +| `opd_reverse_kl_mean` | 平均 KL(学生 \|\| 教师) | +| `opd_reverse_kl_std` | 反向 KL 的标准差 | +| `advantages_mean` | 平均优势值(应接近 0) | +| `policy_loss` | 训练中应下降 | + +### 控制台输出 + +``` +📊 详细步骤统计 +============================================================ +🎁 总奖励: 0.0000 ± 0.0000 (OPD 占位符) +📈 优势值: 0.0012 ± 0.8234 (...) +🎓 OPD 反向 KL: 0.1523 ± 0.0891 (...) +============================================================ +``` + +## 故障排除 + +### 教师服务器问题 + +```bash +# 检查端口是否被占用 +lsof -i :13141 + +# 检查 GPU 可用性 +nvidia-smi + +# 内存不足时减少内存占用 +--mem-fraction-static 0.5 +``` + +### 训练内存不足 + +```bash +--micro_train_batch_size 2 +--micro_rollout_batch_size 2 +--gradient_checkpointing +--zero_stage 3 +``` + +### 收敛缓慢 + +```bash +--n_samples_per_prompt 8 +--actor_learning_rate 5e-7 +--num_episodes 50 +``` + +## 与其他方法对比 + +| 方法 | 奖励信号 | 模式 | 需要 RM | +|------|---------|------|---------| +| GRPO | 任务特定奖励 | 在线 | 是 | +| DPO | 偏好对 | 离线 | 否 | +| **OPD** | 教师对数概率 | 在线 | 否(使用教师) | + +### 优势 + +- 无需单独训练奖励模型 +- Token 级监督(比序列级更精细) +- 在线策略:适应学生不断变化的分布 +- 适用于任何有好教师模型的任务 + +### 局限性 + +- 需要运行教师模型(推理开销) +- 学生无法超越教师的能力 +- 需要足够的计算资源进行教师推理 + +## 文件结构 + +``` +examples/on_policy_distillation/ +├── README.md # 英文文档 +├── README_zh.md # 本文件 +├── run_opd_qwen_2.sh # 训练脚本 +└── on_policy_distillation_reward.py # 教师对数概率获取器 +``` + +## 参考资料 + +- [LightRFT 文档](../../README.md) +- [优势值计算器源码](../../lightrft/trainer/advantage_calculator.py) +- [Fast Experience Maker 源码](../../lightrft/trainer/fast_exp_maker.py) diff --git a/examples/on_policy_distillation/on_policy_distillation_reward.py b/examples/on_policy_distillation/on_policy_distillation_reward.py new file mode 100644 index 0000000..7bc6550 --- /dev/null +++ b/examples/on_policy_distillation/on_policy_distillation_reward.py @@ -0,0 +1,256 @@ +""" +On-Policy Distillation Reward Function for LightRFT + +This module provides functions to query a teacher model for log probabilities +used in knowledge distillation during RL training. + +The teacher model runs as a separate inference server (SGLang), +and this module queries it to get token-level log probabilities for +the sequences generated by the student model. + +Key design decisions (aligned with Slime reference implementation): +- Uses input_ids (not text) to ensure token-level alignment +- Returns per-sample tensors (response tokens only) for dimension alignment +- Includes retry logic for robustness +""" + +import asyncio +import aiohttp +import torch +import logging +from typing import List, Dict, Any, Optional + +logger = logging.getLogger(__name__) + + +async def get_teacher_logprobs_by_ids( + url: str, + input_ids_list: List[List[int]], + response_lengths: List[int], + max_retries: int = 3, + retry_delay: float = 1.0, +) -> List[torch.Tensor]: + """ + Query teacher model using input_ids and return per-sample log prob tensors. + + This is the primary function for OPD, using input_ids to ensure exact + token-level alignment between teacher and student log probabilities. + + :param url: URL of the teacher model inference server (SGLang /generate endpoint) + :param input_ids_list: List of token id sequences [prompt + response] + :param response_lengths: Number of response tokens for each sequence + :param max_retries: Maximum retry attempts per query + :param retry_delay: Initial delay between retries (exponential backoff) + :return: List of tensors, each containing teacher log probs for response tokens only + """ + timeout = aiohttp.ClientTimeout(total=300) + async with aiohttp.ClientSession(timeout=timeout) as session: + + async def query_single(input_ids: List[int], attempt: int = 0) -> Dict[str, Any]: + payload = { + "input_ids": input_ids, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": 0, + } + try: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + return await resp.json() + except Exception as e: + if attempt < max_retries - 1: + delay = retry_delay * (2 ** attempt) + logger.warning( + f"Teacher query failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {delay:.1f}s..." + ) + await asyncio.sleep(delay) + return await query_single(input_ids, attempt + 1) + else: + raise RuntimeError( + f"Failed to query teacher model after {max_retries} attempts: {e}" + ) from e + + tasks = [query_single(ids) for ids in input_ids_list] + results = await asyncio.gather(*tasks) + + # Extract log probs for response tokens only + teacher_lp_list = [] + for response, resp_len in zip(results, response_lengths): + logprobs = response["meta_info"]["input_token_logprobs"] + # Extract logprob values; skip first token (no logprob for BOS) + lp_values = [item[0] if isinstance(item, list) else item for item in logprobs] + teacher_lp = torch.tensor(lp_values[1:], dtype=torch.float32) + # Take the last resp_len tokens (response part only) + teacher_lp = teacher_lp[-resp_len:] + teacher_lp_list.append(teacher_lp) + + return teacher_lp_list + + +async def get_teacher_logprobs_async( + url: str, + sequences: List[str], + session: Optional[aiohttp.ClientSession] = None, + max_retries: int = 3, + retry_delay: float = 1.0, +) -> List[Dict[str, Any]]: + """ + [Legacy] Asynchronously query teacher model using text sequences. + + Prefer get_teacher_logprobs_by_ids() for better token-level alignment. + """ + should_close_session = session is None + if session is None: + timeout = aiohttp.ClientTimeout(total=300) + session = aiohttp.ClientSession(timeout=timeout) + + async def query_single(sequence: str, attempt: int = 0) -> Dict[str, Any]: + payload = { + "text": sequence, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": 0, + } + try: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + return await resp.json() + except Exception as e: + if attempt < max_retries - 1: + delay = retry_delay * (2 ** attempt) + logger.warning( + f"Teacher query failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {delay:.1f}s..." + ) + await asyncio.sleep(delay) + return await query_single(sequence, attempt + 1) + else: + raise RuntimeError( + f"Failed to query teacher model after {max_retries} attempts: {e}" + ) from e + + try: + tasks = [query_single(seq) for seq in sequences] + results = await asyncio.gather(*tasks) + return results + finally: + if should_close_session: + await session.close() + + +def extract_teacher_logprobs( + teacher_responses: List[Dict[str, Any]], + response_lengths: List[int], + device: str = "cpu" +) -> List[torch.Tensor]: + """ + Extract teacher log probabilities for response tokens only. + """ + teacher_log_probs_list = [] + + for response, response_length in zip(teacher_responses, response_lengths): + if "meta_info" in response and "input_token_logprobs" in response["meta_info"]: + logprobs = response["meta_info"]["input_token_logprobs"] + logprob_values = [item[0] if isinstance(item, list) else item for item in logprobs] + teacher_log_probs = torch.tensor(logprob_values[1:], dtype=torch.float32) + teacher_log_probs = teacher_log_probs[-response_length:] + else: + raise ValueError( + f"Unknown response format from teacher model. " + f"Expected 'meta_info' with 'input_token_logprobs'. " + f"Got keys: {response.keys()}" + ) + + if len(teacher_log_probs) < response_length: + padding_length = response_length - len(teacher_log_probs) + teacher_log_probs = torch.cat([ + torch.zeros(padding_length, dtype=torch.float32), + teacher_log_probs + ]) + + teacher_log_probs_list.append(teacher_log_probs.to(device)) + + return teacher_log_probs_list + + +def reward_func(queries: List[str], prompts: List[str], **kwargs) -> torch.Tensor: + """ + Placeholder reward function for on-policy distillation. + Returns zeros; actual learning signal comes from teacher log probs + task rewards. + """ + return torch.zeros(len(queries), dtype=torch.float32) + + +async def get_teacher_logprobs_for_experiences( + teacher_url: str, + sequences: List[str], + response_lengths: List[int], + device: str = "cpu" +) -> torch.Tensor: + """ + [Legacy] Get teacher log probabilities using text sequences. + + Prefer get_teacher_logprobs_by_ids() called from _fetch_teacher_logprobs(). + """ + if not sequences: + return torch.tensor([], device=device) + + responses = await get_teacher_logprobs_async(teacher_url, sequences) + teacher_log_probs_list = extract_teacher_logprobs(responses, response_lengths, device) + + max_length = max(response_lengths) + padded_log_probs = [] + for log_probs, response_length in zip(teacher_log_probs_list, response_lengths): + current_len = len(log_probs) + if current_len < max_length: + padding = torch.zeros(max_length - current_len, dtype=torch.float32, device=device) + log_probs = torch.cat([log_probs, padding]) + elif current_len > max_length: + log_probs = log_probs[:max_length] + padded_log_probs.append(log_probs) + + return torch.stack(padded_log_probs) + + +def get_teacher_logprobs_sync( + teacher_url: str, + sequences: List[str], + response_lengths: List[int], + device: str = "cpu" +) -> torch.Tensor: + """ + [Legacy] Synchronous wrapper for text-based teacher log prob queries. + """ + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit( + asyncio.run, + get_teacher_logprobs_for_experiences( + teacher_url, sequences, response_lengths, device + ) + ) + return future.result() + else: + return loop.run_until_complete( + get_teacher_logprobs_for_experiences( + teacher_url, sequences, response_lengths, device + ) + ) + except RuntimeError: + return asyncio.run( + get_teacher_logprobs_for_experiences( + teacher_url, sequences, response_lengths, device + ) + ) diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh new file mode 100644 index 0000000..56116a9 --- /dev/null +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -0,0 +1,271 @@ +#!/bin/bash +# +# LightRFT On-Policy Distillation Training Script (Template) +# Knowledge distillation from a teacher model to a student model. +# +# Features: +# - Auto GPU detection and allocation (teacher + training) +# - Robust teacher server with health monitoring +# - Two OPD modes: pure distillation / hybrid (GRPO + OPD) +# +# Usage: +# # Edit paths below, then: +# bash examples/on_policy_distillation/run_opd_qwen.sh +# OPD_MODE=hybrid bash examples/on_policy_distillation/run_opd_qwen.sh +# + +set -euo pipefail + +################################################################################ +# Part 1: User Configuration # +################################################################################ + +# --- Model Paths (EDIT THESE) --- +TEACHER_MODEL_PATH="${TEACHER_MODEL_PATH:-Qwen/Qwen2.5-7B-Instruct}" +STUDENT_MODEL_PATH="${STUDENT_MODEL_PATH:-Qwen/Qwen2.5-0.5B-Instruct}" +DATASET_PATH="${DATASET_PATH:-/path/to/your/dataset.jsonl}" + +# --- Experiment --- +EXPERIMENT_NAME="${EXPERIMENT_NAME:-opd-qwen-7b-to-0.5b}" +export WANDB_API_KEY="${WANDB_API_KEY:-YOUR_WANDB_API_KEY}" +export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-OnPolicyDistillation}" + +# --- Teacher Server --- +TEACHER_IP="127.0.0.1" +TEACHER_PORT=${TEACHER_PORT:-13141} + +################################################################################ +# Part 2: Auto GPU Detection & Allocation # +################################################################################ + +TOTAL_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "$TOTAL_GPUS" -lt 2 ]; then + echo "ERROR: Need at least 2 GPUs (1 teacher + 1 training). Found: $TOTAL_GPUS" + exit 1 +fi + +# Last GPU for teacher, rest for training +TEACHER_GPU=$((TOTAL_GPUS - 1)) +TRAIN_GPUS=$((TOTAL_GPUS - 1)) + +if [ "$TRAIN_GPUS" -ge 2 ]; then + ENGINE_TP=2 +else + ENGINE_TP=1 +fi + +export NNODES=1 +export GPUS_PER_NODE=$TRAIN_GPUS +export NODE_RANK=0 +export MASTER_ADDR="localhost" +export MASTER_PORT=${MASTER_PORT:-20090} + +echo "GPU Allocation: ${TOTAL_GPUS} total → Teacher: GPU ${TEACHER_GPU}, Training: GPU 0-$((TRAIN_GPUS-1)) (TP=${ENGINE_TP})" + +################################################################################ +# Part 3: Training Hyperparameters # +################################################################################ + +# --- OPD Mode (override via env: OPD_MODE=hybrid) --- +# "pure" - Pure distillation (Slime default): rewards=0, only OPD KL signal +# "hybrid" - GRPO task rewards + OPD KL penalty with advantage whitening +OPD_MODE="${OPD_MODE:-pure}" + +N_SAMPLES=${N_SAMPLES:-8} +EPISODE=${EPISODE:-30} +WARMUP=${WARMUP:-0.03} +OPD_KL_COEF=${OPD_KL_COEF:-1.0} +MICRO_TRAIN_BS=${MICRO_TRAIN_BS:-4} +MICRO_ROLLOUT_BS=${MICRO_ROLLOUT_BS:-4} + +# Auto-align batch sizes to (micro_batch * world_size) +WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) +ALIGN=$((MICRO_TRAIN_BS * WORLD_SIZE)) +RBS=$(( (${RBS:-128} / ALIGN) * ALIGN )) +TBS=$(( (${TBS:-128} / ALIGN) * ALIGN )) +[ "$RBS" -lt "$ALIGN" ] && RBS=$ALIGN +[ "$TBS" -lt "$ALIGN" ] && TBS=$ALIGN + +echo "Batch sizes: TBS=${TBS}, RBS=${RBS} (aligned to micro=${MICRO_TRAIN_BS} * world=${WORLD_SIZE} = ${ALIGN})" + +if [ "$OPD_MODE" = "hybrid" ]; then + ADVANTAGE_ESTIMATOR="on_policy_distillation_hybrid" + KL=${KL:-0.01} + LR=${LR:-5e-7} +else + ADVANTAGE_ESTIMATOR="on_policy_distillation" + KL=${KL:-0.00} + LR=${LR:-5e-7} +fi + +PROMPT_MAX_LEN=${PROMPT_MAX_LEN:-1024} +GENERATE_MAX_LEN=${GENERATE_MAX_LEN:-2048} + +################################################################################ +# Part 4: Teacher Model Server # +################################################################################ + +TEACHER_PID="" +LOG_DIR="rft_logs/${EXPERIMENT_NAME}" +mkdir -p "$LOG_DIR" +TEACHER_LOG="${LOG_DIR}/teacher_model_$(date +%Y%m%d_%H%M%S).log" + +cleanup() { + echo "" + echo "=== Cleaning up ===" + if [ -n "$TEACHER_PID" ] && kill -0 "$TEACHER_PID" 2>/dev/null; then + echo "Stopping teacher server (PID: $TEACHER_PID)..." + kill "$TEACHER_PID" 2>/dev/null || true + sleep 2 + kill -9 "$TEACHER_PID" 2>/dev/null || true + fi + pkill -f "sglang.launch_server.*${TEACHER_PORT}" 2>/dev/null || true + echo "Done." +} +trap cleanup EXIT INT TERM + +start_teacher_server() { + if lsof -Pi :"$TEACHER_PORT" -sTCP:LISTEN -t >/dev/null 2>&1; then + echo "Port $TEACHER_PORT in use, killing existing process..." + lsof -ti:"$TEACHER_PORT" | xargs kill -9 2>/dev/null || true + sleep 3 + fi + + echo "Starting teacher server on GPU $TEACHER_GPU..." + CUDA_VISIBLE_DEVICES=$TEACHER_GPU python3 -m sglang.launch_server \ + --model-path "$TEACHER_MODEL_PATH" \ + --host 0.0.0.0 \ + --port "$TEACHER_PORT" \ + --tp 1 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static 0.7 \ + --disable-radix-cache \ + --request-timeout 300 \ + --max-running-requests 64 \ + >> "$TEACHER_LOG" 2>&1 & + + TEACHER_PID=$! + echo "Teacher PID: $TEACHER_PID, Log: $TEACHER_LOG" + + local max_wait=600 waited=0 + while ! curl -sf "http://$TEACHER_IP:$TEACHER_PORT/health" >/dev/null 2>&1; do + if [ $waited -ge $max_wait ]; then + echo "ERROR: Teacher server failed to start in ${max_wait}s" + tail -30 "$TEACHER_LOG" + exit 1 + fi + if ! kill -0 "$TEACHER_PID" 2>/dev/null; then + echo "ERROR: Teacher server process died" + tail -30 "$TEACHER_LOG" + exit 1 + fi + printf "." + sleep 5 + waited=$((waited + 5)) + done + echo "" + echo "Teacher server ready at $TEACHER_IP:$TEACHER_PORT" +} + +# Validate model paths +for p in "$TEACHER_MODEL_PATH" "$STUDENT_MODEL_PATH"; do + [ -e "$p" ] || { echo "ERROR: Path not found: $p"; exit 1; } +done + +start_teacher_server +sleep 3 + +################################################################################ +# Part 5: Launch Training # +################################################################################ + +current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-${OPD_MODE}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}" +WANDB_RUN_NAME="${EXPERIMENT_NAME}-${OPD_MODE}-${current_time}" +TEACHER_URL="http://$TEACHER_IP:$TEACHER_PORT/generate" + +mkdir -p "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="WARN" +export NCCL_TIMEOUT=3600 +export IGNORE_EOS=0 +export WANDB_MODE="${WANDB_MODE:-offline}" + +echo "=========================================" +echo "On-Policy Distillation Training" +echo "=========================================" +echo "Mode: $OPD_MODE ($ADVANTAGE_ESTIMATOR)" +echo "Student: $STUDENT_MODEL_PATH" +echo "Teacher: $TEACHER_URL" +echo "GPUs: Training=0-$((TRAIN_GPUS-1)), Teacher=$TEACHER_GPU" +echo "=========================================" + +set -x + +torchrun \ + --nnodes $NNODES \ + --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK \ + --master-port $MASTER_PORT \ + --master-addr $MASTER_ADDR \ + examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "$STUDENT_MODEL_PATH" \ + --save_trajectories \ + --advantage_estimator "${ADVANTAGE_ESTIMATOR}" \ + --opd_kl_coef ${OPD_KL_COEF} \ + --fsdp \ + --use_kl_loss \ + --flash_attn \ + --engine_type sglang \ + --enable_engine_sleep \ + --rm_use_engine \ + --reward_pretrain "" \ + --remote_rm_url "$TEACHER_URL" \ + --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --micro_train_batch_size ${MICRO_TRAIN_BS} \ + --train_batch_size ${TBS} \ + --micro_rollout_batch_size ${MICRO_ROLLOUT_BS} \ + --rollout_batch_size ${RBS} \ + --max_epochs 1 \ + --num_episodes ${EPISODE} \ + --lr_warmup_ratio ${WARMUP} \ + --n_samples_per_prompt $N_SAMPLES \ + --prompt_max_len $PROMPT_MAX_LEN \ + --generate_max_len $GENERATE_MAX_LEN \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate $LR \ + --init_kl_coef $KL \ + --kl_estimator "k3" \ + --prompt_data "$DATASET_PATH" \ + --input_key "prompt" \ + --label_key "label" \ + --eval_steps 20 \ + --eval_split "test" \ + --apply_chat_template \ + --gradient_checkpointing \ + --save_steps 20 \ + --max_ckpt_num 3 \ + --engine_mem_util 0.6 \ + --engine_tp_size $ENGINE_TP \ + --l2 1.0e-2 \ + --freeze_prefix \ + --adam_offload \ + --text_only \ + --use_wandb "${WANDB_API_KEY}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${WANDB_RUN_NAME}" \ + 2>&1 | tee "${LOG_DIR}/node${NODE_RANK}_${current_time}.log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} +set +x + +echo "" +echo "=========================================" +echo "Training Complete (exit: $TRAINING_EXIT_CODE)" +echo "Model: results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +echo "=========================================" + +exit $TRAINING_EXIT_CODE diff --git a/examples/on_policy_distillation/test_opd.py b/examples/on_policy_distillation/test_opd.py new file mode 100644 index 0000000..6318e26 --- /dev/null +++ b/examples/on_policy_distillation/test_opd.py @@ -0,0 +1,453 @@ +""" +Test script for On-Policy Distillation implementation in LightRFT. + +Tests both pure and hybrid OPD modes, advantage whitening, +teacher logprob extraction, and dimension alignment. +""" + +import torch +import sys +from pathlib import Path + +lightrft_path = Path(__file__).parent.parent.parent +sys.path.insert(0, str(lightrft_path)) + + +class MockConfig: + """Reusable mock config for advantage calculators.""" + def __init__(self, **kwargs): + self.advantages_norm = False + self.advantage_clip = 0.0 + self.opd_kl_coef = 1.0 + self.n_samples_per_prompt = 4 + self.dynamic_sampling = False + self.micro_train_batch_size = 4 + for k, v in kwargs.items(): + setattr(self, k, v) + + +class MockExperience: + """Reusable mock experience object.""" + def __init__(self, batch_size=4, num_actions=10, teacher_offset=-0.5): + self.action_log_probs = torch.randn(batch_size, num_actions) * 0.5 - 1.0 + self.action_mask = torch.ones(batch_size, num_actions, dtype=torch.bool) + self.action_mask[:, -2:] = False # last 2 tokens are padding + self.info = { + "teacher_log_probs": self.action_log_probs + teacher_offset, + "reward": torch.rand(batch_size), + "response_length": torch.full((batch_size,), num_actions), + } + + +# ============================================================================ +# Test: Factory registration +# ============================================================================ + +def test_factory(): + """All estimators including both OPD modes are registered.""" + print("Test: Factory registration") + + from lightrft.trainer.advantage_calculator import get_advantage_calculator + + config = MockConfig() + estimators = [ + "gae", "reinforce", "rloo", "reinforce_baseline", + "group_norm", "grpo", "cpgd", + "on_policy_distillation", + "on_policy_distillation_hybrid", + ] + + for name in estimators: + calc = get_advantage_calculator(name, config) + print(f" {name} -> {calc.__class__.__name__}") + + try: + get_advantage_calculator("nonexistent", config) + assert False, "Should raise ValueError" + except ValueError: + pass + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Pure OPD calculator +# ============================================================================ + +def test_pure_opd(): + """Pure distillation: rewards zeroed, advantages from KL only.""" + print("Test: Pure OPD calculator") + + from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator + + config = MockConfig(opd_kl_coef=1.0) + calc = OnPolicyDistillationCalculator(config) + + # preprocess_rewards should zero out rewards + rewards = torch.tensor([0.5, 0.8, 0.3, 0.9, 0.1, 0.7, 0.2, 0.4]) + experiences = [MockExperience(batch_size=4), MockExperience(batch_size=4)] + exps, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) + for chunk in reward_chunks: + assert (chunk == 0).all(), f"Pure OPD should zero rewards, got {chunk}" + print(" preprocess_rewards zeros out rewards: OK") + + # compute should produce whitened advantages from KL + exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.5) + final_reward = torch.zeros(4, 10) + adv, ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + + assert adv.shape == (4, 10), f"Wrong shape: {adv.shape}" + assert (adv[:, -2:] == 0).all(), "Padding positions should be 0" + + # Whitened: mean should be ~0 + masked = adv[exp.action_mask] + assert abs(masked.mean()) < 0.1, f"Whitened mean should be ~0, got {masked.mean():.4f}" + print(f" advantages whitened (mean={masked.mean():.4f}, std={masked.std():.4f}): OK") + + # opd_reverse_kl metric should be present + assert "opd_reverse_kl" in info, "Missing opd_reverse_kl metric" + print(f" opd_reverse_kl metric present: OK") + + # Missing teacher_log_probs should raise + exp_bad = MockExperience() + del exp_bad.info["teacher_log_probs"] + try: + calc.compute(exp_bad, final_reward, gamma=1.0, generate_kwargs={}) + assert False, "Should raise ValueError" + except ValueError: + pass + print(" missing teacher_log_probs raises ValueError: OK") + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Hybrid OPD calculator +# ============================================================================ + +def test_hybrid_opd(): + """Hybrid: GRPO base advantages + OPD KL penalty, then whitened.""" + print("Test: Hybrid OPD calculator") + + from lightrft.trainer.advantage_calculator import OnPolicyDistillationHybridCalculator + + config = MockConfig(opd_kl_coef=1.0, n_samples_per_prompt=4) + calc = OnPolicyDistillationHybridCalculator(config) + + # preprocess_rewards should apply GRPO normalization (not zero) + rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0]) + experiences = [MockExperience(batch_size=4), MockExperience(batch_size=4)] + exps, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) + combined = torch.cat(reward_chunks) + assert not (combined == 0).all(), "Hybrid should NOT zero rewards" + print(f" preprocess_rewards applies GRPO normalization: OK") + + # compute should combine GRPO + OPD and whiten + exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.3) + # Simulate GRPO-normalized reward broadcast to tokens + final_reward = torch.randn(4, 10) * 0.5 + adv, ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + + assert adv.shape == (4, 10), f"Wrong shape: {adv.shape}" + masked = adv[exp.action_mask] + assert abs(masked.mean()) < 0.1, f"Whitened mean should be ~0, got {masked.mean():.4f}" + print(f" advantages whitened (mean={masked.mean():.4f}, std={masked.std():.4f}): OK") + assert "opd_reverse_kl" in info + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Advantage whitening +# ============================================================================ + +def test_whiten_advantages(): + """Whitening normalizes to ~zero mean, ~unit std.""" + print("Test: Advantage whitening") + + from lightrft.trainer.advantage_calculator import _whiten_advantages + + # Large-scale advantages (simulating raw OPD KL) + adv = torch.randn(4, 20) * 10 + 5 # mean=5, std=10 + mask = torch.ones(4, 20, dtype=torch.bool) + mask[:, -3:] = False + + whitened = _whiten_advantages(adv, mask) + + masked_vals = whitened[mask] + assert abs(masked_vals.mean()) < 0.01, f"Mean should be ~0, got {masked_vals.mean():.4f}" + assert abs(masked_vals.std() - 1.0) < 0.1, f"Std should be ~1, got {masked_vals.std():.4f}" + print(f" mean={masked_vals.mean():.4f}, std={masked_vals.std():.4f}: OK") + + # Edge case: very small batch + adv_small = torch.tensor([[1.0, 2.0]]) + mask_small = torch.ones(1, 2, dtype=torch.bool) + whitened_small = _whiten_advantages(adv_small, mask_small) + assert whitened_small.shape == (1, 2) + print(" small batch handled: OK") + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: OPD KL penalty helper +# ============================================================================ + +def test_opd_kl_penalty(): + """_apply_opd_kl_penalty computes correct penalty direction.""" + print("Test: OPD KL penalty") + + from lightrft.trainer.advantage_calculator import _apply_opd_kl_penalty + + # Teacher is better (higher log probs) → student should get positive advantage + student_lp = torch.tensor([[-2.0, -3.0, -1.5]]) + teacher_lp = torch.tensor([[-1.0, -1.5, -0.5]]) + mask = torch.ones(1, 3, dtype=torch.bool) + + opd_adv, info = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=1.0) + + # reverse_kl = student - teacher = negative (student worse) + # opd_adv = -1.0 * reverse_kl = positive (encourage matching teacher) + assert (opd_adv > 0).all(), f"Should be positive when teacher > student, got {opd_adv}" + print(f" teacher > student → positive advantage: OK ({opd_adv.tolist()})") + + # Student is overconfident → negative advantage + student_lp2 = torch.tensor([[-0.5, -0.3, -0.2]]) + teacher_lp2 = torch.tensor([[-2.0, -2.5, -3.0]]) + opd_adv2, _ = _apply_opd_kl_penalty(student_lp2, teacher_lp2, mask, opd_kl_coef=1.0) + assert (opd_adv2 < 0).all(), f"Should be negative when student > teacher, got {opd_adv2}" + print(f" student > teacher → negative advantage: OK ({opd_adv2.tolist()})") + + # opd_kl_coef scales the penalty + opd_adv_scaled, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=0.5) + assert torch.allclose(opd_adv_scaled, opd_adv * 0.5) + print(" opd_kl_coef scaling: OK") + + # Mask is respected + mask_partial = torch.tensor([[True, True, False]]) + opd_adv_masked, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask_partial, opd_kl_coef=1.0) + assert opd_adv_masked[0, 2] == 0, "Masked position should be 0" + print(" action mask respected: OK") + + assert "opd_reverse_kl" in info + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Teacher logprob extraction (get_teacher_logprobs_by_ids mock) +# ============================================================================ + +def test_teacher_logprob_extraction(): + """extract_teacher_logprobs handles SGLang format correctly.""" + print("Test: Teacher logprob extraction") + + from examples.on_policy_distillation.on_policy_distillation_reward import ( + extract_teacher_logprobs + ) + + # SGLang format: [logprob, rank, token_str] tuples + response = { + "meta_info": { + "input_token_logprobs": [ + None, # BOS token + [-0.1, 1, "hello"], + [-0.2, 2, "world"], + [-0.15, 1, "!"], + [-0.3, 3, "."], + [-0.25, 2, "end"], + ] + } + } + + # Extract last 3 tokens as response + lp_list = extract_teacher_logprobs([response], response_lengths=[3], device="cpu") + assert len(lp_list) == 1 + assert len(lp_list[0]) == 3 + expected = torch.tensor([-0.3, -0.25, -0.15]) # Wait, let me check... + + # logprob_values = [-0.1, -0.2, -0.15, -0.3, -0.25] (skip None, take [0] from each) + # teacher_log_probs[-3:] = [-0.15, -0.3, -0.25] + # Hmm, actually the last 3 of [-0.1, -0.2, -0.15, -0.3, -0.25] = [-0.15, -0.3, -0.25] + expected = torch.tensor([-0.15, -0.3, -0.25]) + assert torch.allclose(lp_list[0], expected), f"Got {lp_list[0]}, expected {expected}" + print(f" SGLang format extraction: OK ({lp_list[0].tolist()})") + + # Test padding when response_length > available logprobs + lp_list2 = extract_teacher_logprobs([response], response_lengths=[10], device="cpu") + assert len(lp_list2[0]) == 10, f"Should pad to 10, got {len(lp_list2[0])}" + print(f" padding for short sequences: OK (len={len(lp_list2[0])})") + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Dimension alignment (teacher_log_probs vs action_log_probs) +# ============================================================================ + +def test_dimension_alignment(): + """teacher_log_probs must match action_log_probs shape [batch, num_actions].""" + print("Test: Dimension alignment") + + from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator + + config = MockConfig(opd_kl_coef=1.0) + calc = OnPolicyDistillationCalculator(config) + + # Simulate different response lengths within a batch + # action_log_probs: [batch=2, num_actions=8] (padded to max response length) + # action_mask: first sample has 6 real tokens, second has 8 + batch_size, num_actions = 2, 8 + exp = MockExperience.__new__(MockExperience) + exp.action_log_probs = torch.randn(batch_size, num_actions) + exp.action_mask = torch.ones(batch_size, num_actions, dtype=torch.bool) + exp.action_mask[0, :2] = False # first 2 positions are prompt padding for sample 0 + exp.info = { + "teacher_log_probs": torch.randn(batch_size, num_actions), # same shape + } + + final_reward = torch.zeros(batch_size, num_actions) + adv, _, _ = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + assert adv.shape == (batch_size, num_actions) + assert adv[0, 0] == 0 and adv[0, 1] == 0, "Masked positions should be 0" + print(f" shape match [batch={batch_size}, num_actions={num_actions}]: OK") + + # Mismatched shapes should fail + exp_bad = MockExperience.__new__(MockExperience) + exp_bad.action_log_probs = torch.randn(2, 8) + exp_bad.action_mask = torch.ones(2, 8, dtype=torch.bool) + exp_bad.info = {"teacher_log_probs": torch.randn(2, 12)} # wrong dim! + + try: + calc.compute(exp_bad, final_reward, gamma=1.0, generate_kwargs={}) + assert False, "Should fail on shape mismatch" + except RuntimeError: + pass + print(" shape mismatch correctly raises RuntimeError: OK") + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Pure vs Hybrid produce different results +# ============================================================================ + +def test_pure_vs_hybrid(): + """Pure and hybrid modes produce meaningfully different advantages.""" + print("Test: Pure vs Hybrid comparison") + + from lightrft.trainer.advantage_calculator import ( + OnPolicyDistillationCalculator, + OnPolicyDistillationHybridCalculator, + ) + + config = MockConfig(opd_kl_coef=1.0, n_samples_per_prompt=4) + pure_calc = OnPolicyDistillationCalculator(config) + hybrid_calc = OnPolicyDistillationHybridCalculator(config) + + torch.manual_seed(42) + exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.5) + + # Pure: rewards are zeroed + rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.5, 0.5, 0.8, 0.2]) + _, pure_rewards = pure_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) + _, hybrid_rewards = hybrid_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) + + pure_r = torch.cat(pure_rewards) + hybrid_r = torch.cat(hybrid_rewards) + assert (pure_r == 0).all(), "Pure should zero rewards" + assert not (hybrid_r == 0).all(), "Hybrid should keep rewards" + print(f" reward preprocessing differs: OK (pure=0, hybrid has signal)") + + # Both should produce valid advantages + final_reward_pure = torch.zeros(4, 10) + final_reward_hybrid = torch.randn(4, 10) * 0.5 + + adv_pure, _, _ = pure_calc.compute(exp, final_reward_pure, 1.0, {}) + adv_hybrid, _, _ = hybrid_calc.compute(exp, final_reward_hybrid, 1.0, {}) + + assert adv_pure.shape == adv_hybrid.shape + # They should differ (different reward signals) + assert not torch.allclose(adv_pure, adv_hybrid, atol=0.01) + print(f" advantages differ between modes: OK") + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: reward_func returns zeros +# ============================================================================ + +def test_reward_func(): + """Placeholder reward_func returns zeros.""" + print("Test: reward_func placeholder") + + from examples.on_policy_distillation.on_policy_distillation_reward import reward_func + + result = reward_func( + queries=["q1", "q2", "q3"], + prompts=["p1", "p2", "p3"], + ) + assert isinstance(result, torch.Tensor) + assert result.shape == (3,) + assert (result == 0).all() + print(" returns zeros: OK") + + print(" PASS\n") + return True + + +# ============================================================================ +# Runner +# ============================================================================ + +def run_all_tests(): + print("=" * 60) + print("On-Policy Distillation Test Suite (v3)") + print("=" * 60) + print() + + tests = [ + ("Factory registration", test_factory), + ("OPD KL penalty", test_opd_kl_penalty), + ("Advantage whitening", test_whiten_advantages), + ("Pure OPD calculator", test_pure_opd), + ("Hybrid OPD calculator", test_hybrid_opd), + ("Dimension alignment", test_dimension_alignment), + ("Pure vs Hybrid", test_pure_vs_hybrid), + ("Teacher logprob extraction", test_teacher_logprob_extraction), + ("Reward func placeholder", test_reward_func), + ] + + results = [] + for name, fn in tests: + try: + ok = fn() + results.append((name, ok)) + except Exception as e: + print(f" FAIL: {e}") + import traceback + traceback.print_exc() + results.append((name, False)) + print() + + print("=" * 60) + print("Summary") + print("=" * 60) + passed = sum(1 for _, ok in results if ok) + for name, ok in results: + print(f" {'PASS' if ok else 'FAIL'}: {name}") + print(f"\n{passed}/{len(results)} passed") + print("=" * 60) + + return 0 if passed == len(results) else 1 + + +if __name__ == "__main__": + sys.exit(run_all_tests()) diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index c699300..f88ae0f 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -114,6 +114,8 @@ class StrategyConfig: dynamic_sampling: bool = False # (str): Advantage estimator method, defaults to "gae" advantage_estimator: str = "group_norm" + # (float): OPD KL coefficient for on-policy distillation penalty + opd_kl_coef: float = 1.0 # KL loss and estimation # (bool): Use KL loss in training, defaults to False diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 1ce9a8f..0160de4 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -717,6 +717,150 @@ def compute( return advantages, returns, info_dict +def _apply_opd_kl_penalty( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + action_mask: Optional[torch.Tensor], + opd_kl_coef: float, +) -> Tuple[torch.Tensor, Dict]: + """ + Compute OPD reverse KL penalty: -opd_kl_coef * (student_logp - teacher_logp). + + Shared helper for both pure and hybrid OPD modes. + + :return: Tuple of (opd_advantages, info_dict with opd_reverse_kl metric) + """ + reverse_kl = student_log_probs - teacher_log_probs + opd_adv = -opd_kl_coef * reverse_kl + + if action_mask is not None: + opd_adv = opd_adv * action_mask + + info_dict = {} + if action_mask is not None: + masked_rkl = reverse_kl * action_mask + info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / action_mask.sum(-1).clamp(min=1) + + return opd_adv, info_dict + + +def _whiten_advantages(advantages: torch.Tensor, action_mask: Optional[torch.Tensor], eps: float = 1e-8) -> torch.Tensor: + """ + Whiten advantages using masked mean/std (matching Slime's distributed_masked_whiten). + + This normalizes advantages to zero mean and unit variance, which stabilizes + training when OPD KL penalty has different scale from base advantages. + """ + if action_mask is not None: + mask = action_mask.bool() + masked_adv = torch.masked_select(advantages, mask) + else: + masked_adv = advantages.flatten() + + if masked_adv.numel() < 2: + return advantages + + mean = masked_adv.mean() + std = masked_adv.std() + return (advantages - mean) / (std + eps) + + +class OnPolicyDistillationCalculator(AdvantageCalculator): + """ + On-Policy Distillation calculator — pure distillation mode. + + Following Slime's design: + - Task rewards are zeroed out + - The ONLY learning signal is the OPD KL penalty: + advantages = -opd_kl_coef * (student_logp - teacher_logp) + - Advantage whitening is applied for training stability + + Use --advantage_estimator on_policy_distillation + """ + def __init__(self, config): + super().__init__(config) + self.opd_kl_coef = getattr(config, 'opd_kl_coef', 1.0) + + def preprocess_rewards(self, rewards, experiences, max_new_tokens): + """Zero out all rewards — pure distillation mode.""" + zero_rewards = torch.zeros_like(rewards) + return experiences, list(zero_rewards.chunk(len(experiences))) + + def compute(self, experience, final_reward, gamma, generate_kwargs): + """advantages = -opd_kl_coef * (student_logp - teacher_logp). + Whitening is done cross-batch in normalize_advantages_cross_batch.""" + if "teacher_log_probs" not in experience.info: + raise ValueError("teacher_log_probs not found in experience.info.") + + teacher_lp = experience.info["teacher_log_probs"].to(experience.action_log_probs.device) + student_lp = experience.action_log_probs + + advantages, info_dict = _apply_opd_kl_penalty( + student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef + ) + + returns = advantages.clone() + + if self.config.advantage_clip > 0: + clip_val = self.config.advantage_clip + info_dict["advantage_clip_frac"] = compute_clip_fraction(advantages, clip_val, -clip_val) + advantages = torch.clamp(advantages, -clip_val, clip_val) + + return advantages, returns, info_dict + + +class OnPolicyDistillationHybridCalculator(AdvantageCalculator): + """ + Hybrid On-Policy Distillation calculator — GRPO task rewards + OPD KL penalty. + + Combines GRPO (group normalization) base advantages from task rewards with + OPD KL penalty from teacher model. Advantage whitening is applied AFTER + combining both signals to resolve scale mismatch. + + advantages = whiten(GRPO_base_advantages + OPD_KL_penalty) + + Use --advantage_estimator on_policy_distillation_hybrid + """ + def __init__(self, config): + super().__init__(config) + self.opd_kl_coef = getattr(config, 'opd_kl_coef', 1.0) + self.base_calculator = GroupNormCalculator(config) + + def preprocess_rewards(self, rewards, experiences, max_new_tokens): + """Delegate to GRPO for task reward normalization.""" + return self.base_calculator.preprocess_rewards(rewards, experiences, max_new_tokens) + + def compute(self, experience, final_reward, gamma, generate_kwargs): + """advantages = GRPO_base + OPD_KL_penalty. + Whitening is done cross-batch in normalize_advantages_cross_batch.""" + # Step 1: GRPO base advantages from task rewards + base_advantages, returns, info_dict = self.base_calculator.compute( + experience, final_reward, gamma, generate_kwargs + ) + + # Step 2: OPD KL penalty + if "teacher_log_probs" not in experience.info: + raise ValueError("teacher_log_probs not found in experience.info.") + + teacher_lp = experience.info["teacher_log_probs"].to(base_advantages.device) + student_lp = experience.action_log_probs + + opd_adv, opd_info = _apply_opd_kl_penalty( + student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef + ) + info_dict.update(opd_info) + + # Step 3: Combine (whitening done cross-batch later) + advantages = base_advantages + opd_adv + + if self.config.advantage_clip > 0: + clip_val = self.config.advantage_clip + info_dict["advantage_clip_frac"] = compute_clip_fraction(advantages, clip_val, -clip_val) + advantages = torch.clamp(advantages, -clip_val, clip_val) + + return advantages, returns, info_dict + + # ============================================================================ # Factory Function # ============================================================================ @@ -728,7 +872,8 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator :param estimator_name: Name of the advantage estimation method Options: "gae", "cpgd", "reinforce", "rloo", - "reinforce_baseline", "group_norm", "grpo" + "reinforce_baseline", "group_norm", "grpo", + "on_policy_distillation" :type estimator_name: str :param config: Configuration object containing training parameters :type config: object @@ -744,6 +889,8 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator "reinforce_baseline": REINFORCEBaselineCalculator, "cpgd": CPGDCalculator, "grpo": GroupNormCalculator, # Alias for group_norm + "on_policy_distillation": OnPolicyDistillationCalculator, + "on_policy_distillation_hybrid": OnPolicyDistillationHybridCalculator, } calculator_class = calculator_map.get(estimator_name) @@ -779,7 +926,10 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str :return: List of Experience objects with normalized advantages. :rtype: List """ - if advantage_estimator not in ["gae", "reinforce", "reinforce_baseline"]: + if advantage_estimator not in [ + "gae", "reinforce", "reinforce_baseline", + "on_policy_distillation", "on_policy_distillation_hybrid", + ]: return experiences # Collect all advantages and action masks diff --git a/lightrft/trainer/experience_maker.py b/lightrft/trainer/experience_maker.py index d7fd089..c5f63ea 100644 --- a/lightrft/trainer/experience_maker.py +++ b/lightrft/trainer/experience_maker.py @@ -400,6 +400,15 @@ def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kw generate_kwargs["gamma"], ) experience.advantages = deepcopy(experience.returns) + elif self.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + # OPD: cumulative returns from rewards (0 for pure, task rewards for hybrid) + # The OPD KL penalty is applied in the advantage calculator + experience.returns = self.get_cumulative_returns( + reward, + experience.action_mask, + generate_kwargs["gamma"], + ) + experience.advantages = deepcopy(experience.returns) else: raise Exception(f"Unknown advantage_estimator {self.advantage_estimator}") @@ -575,6 +584,83 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper """ args = self.strategy.args + # On-policy distillation: query teacher model for log probs, then use GRPO reward shaping + if args.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + if self.remote_rm_url is None or len(self.remote_rm_url) == 0: + raise ValueError( + "On-policy distillation requires a teacher model URL. " + "Please set --remote_rm_url to the teacher model inference server." + ) + + import asyncio + teacher_url = self.remote_rm_url[0] if isinstance(self.remote_rm_url, list) else self.remote_rm_url + + # Collect all sequences as input_ids and response lengths + all_input_ids = [] + all_response_lengths = [] + for experience in experiences: + sequences_batch = experience.sequences + response_lengths = experience.info["response_length"] + for i, seq in enumerate(sequences_batch): + all_input_ids.append(seq.cpu().tolist()) + all_response_lengths.append(int(response_lengths[i].item())) + + # Query teacher model for log probs using input_ids + try: + from examples.on_policy_distillation.on_policy_distillation_reward import ( + get_teacher_logprobs_by_ids + ) + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + teacher_lp_list = loop.run_until_complete( + get_teacher_logprobs_by_ids( + url=teacher_url, + input_ids_list=all_input_ids, + response_lengths=all_response_lengths, + ) + ) + finally: + loop.close() + + # Align teacher log probs to action_log_probs shape [batch, num_actions] + idx = 0 + for experience in experiences: + batch_size = experience.sequences.size(0) + num_actions = experience.action_mask.shape[1] + aligned = torch.zeros(batch_size, num_actions, dtype=torch.float32) + for j in range(batch_size): + tlp = teacher_lp_list[idx + j] + resp_len = all_response_lengths[idx + j] + actual_len = min(len(tlp), resp_len, num_actions) + start_pos = num_actions - resp_len + if start_pos >= 0: + aligned[j, start_pos:start_pos + actual_len] = tlp[:actual_len] + else: + aligned[j, :] = tlp[-num_actions:] + experience.info["teacher_log_probs"] = aligned + idx += batch_size + + except Exception as e: + logger.error(f"Failed to get teacher log probs: {e}") + raise + + # Return rewards based on mode + if args.advantage_estimator == "on_policy_distillation": + # Pure distillation: zero rewards, learning signal from OPD KL only + zero_rewards = torch.zeros(sum(exp.sequences.size(0) for exp in experiences)) + rewards = zero_rewards.chunk(len(experiences)) + return experiences, list(rewards) + else: + # Hybrid: use task rewards with GRPO normalization + rewards = torch.cat([experience.info["reward"] for experience in experiences]) + rewards = rewards.reshape(-1, args.n_samples_per_prompt) + baseline = rewards.mean(-1, keepdim=True) + rewards = (rewards - baseline) / (rewards.std(-1, keepdim=True) + 1e-9) + rewards = rewards.flatten().chunk(len(experiences)) + return experiences, list(rewards) + # Reward shaping for RLOO if args.advantage_estimator == "rloo": rewards = torch.cat([experience.info["reward"] for experience in experiences]) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 98c91d2..094bbbd 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -58,6 +58,17 @@ from .image_utils import normalize_images, get_images_num from .video_utils import normalize_videos, get_videos_num +# On-Policy Distillation imports +try: + from examples.on_policy_distillation.on_policy_distillation_reward import ( + get_teacher_logprobs_for_experiences, + get_teacher_logprobs_by_ids, + ) + import asyncio + OPD_AVAILABLE = True +except ImportError: + OPD_AVAILABLE = False + # ============================================================================ # Data Structures # ============================================================================ @@ -441,7 +452,13 @@ def __init__( :type packing_samples: bool """ self.reward_model = reward_model - self.remote_rm_url = remote_rm_url + # Ensure remote_rm_url is a list for consistent iteration + if remote_rm_url is None: + self.remote_rm_url = None + elif isinstance(remote_rm_url, str): + self.remote_rm_url = [remote_rm_url] + else: + self.remote_rm_url = list(remote_rm_url) self.custom_reward_func = custom_reward_func self.reward_fn = reward_fn self.reward_fn_label_map = reward_fn_label_map or {} @@ -923,9 +940,20 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg else: self.multimodal_processor = None + # For On-Policy Distillation (OPD), remote_rm_url is used for teacher model, + # not for reward model. So we don't pass it to RewardComputationEngine. + # Instead, we store it separately for _fetch_teacher_logprobs(). + if advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + # Store teacher URL separately for OPD + self.teacher_model_url = self.remote_rm_url + rm_url_for_reward_engine = None # Don't use remote_rm_url for rewards in OPD mode + else: + self.teacher_model_url = None + rm_url_for_reward_engine = self.remote_rm_url + self.reward_engine = RewardComputationEngine( reward_model=self.reward_model, - remote_rm_url=self.remote_rm_url, + remote_rm_url=rm_url_for_reward_engine, custom_reward_func=getattr(self, "custom_reward_func", None), reward_fn=self.reward_fn, reward_fn_label_map=getattr(self, "reward_fn_label_map", None), @@ -1046,6 +1074,10 @@ def make_experience_list( self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num) + # ========== Stage 6.5: On-Policy Distillation Teacher Log-Probs ========== + if config.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + self._fetch_teacher_logprobs(experiences) + # ========== Stage 7: Advantage Computation ========== experiences = self._compute_advantages_and_returns(experiences, rewards, generate_kwargs) @@ -1414,6 +1446,91 @@ def _process_multi_image_video_thws( else: experience.video_grid_thws = [None] * len(micro_videos_num) + def _fetch_teacher_logprobs( + self, + experiences: List[ExperienceVL], + ) -> None: + """ + Fetch teacher log probabilities for on-policy distillation. + + Uses input_ids (not text) to query teacher model, ensuring token-level + alignment between teacher and student log probabilities. + + Teacher log probs are aligned to the same shape as action_log_probs + [batch_size, seq_len], with zeros for prompt positions. + + :param experiences: List of experiences to add teacher log probs to + :type experiences: List[Union[Experience, ExperienceVL]] + """ + if not OPD_AVAILABLE: + raise RuntimeError( + "On-policy distillation module not available. " + "Make sure examples/on_policy_distillation/on_policy_distillation_reward.py exists." + ) + + # Get teacher URL from config + teacher_url = self.teacher_model_url + if isinstance(teacher_url, list): + teacher_url = teacher_url[0] if teacher_url else None + if teacher_url is None: + raise ValueError( + "Teacher model URL not specified. " + "Please set --remote_rm_url to the teacher model server URL." + ) + + Timer.start(' fetch_teacher_logprobs') + + for exp in experiences: + sequences = exp.sequences # [batch_size, seq_len] + action_mask = exp.action_mask # [batch_size, num_actions] + + # Get response lengths (number of generated tokens per sequence) + response_lengths = action_mask.sum(dim=-1).tolist() + num_actions = action_mask.shape[1] # action_log_probs dim + + # Use input_ids for teacher query to ensure token-level alignment + input_ids_list = sequences.cpu().tolist() + + # Query teacher model for log probs + try: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + teacher_lp_list = loop.run_until_complete( + get_teacher_logprobs_by_ids( + url=teacher_url, + input_ids_list=input_ids_list, + response_lengths=response_lengths, + ) + ) + finally: + loop.close() + + # Align teacher log probs to action_log_probs shape [batch_size, num_actions] + # teacher_lp_list[i] has shape [resp_len_i], need to pad/align to [num_actions] + batch_size = sequences.shape[0] + aligned_teacher_lp = torch.zeros(batch_size, num_actions, dtype=torch.float32) + for i, (tlp, resp_len) in enumerate(zip(teacher_lp_list, response_lengths)): + # Right-align: teacher log probs fill the last resp_len positions + # (matching where action_mask == 1) + actual_len = min(len(tlp), resp_len, num_actions) + start_pos = num_actions - resp_len + if start_pos >= 0: + aligned_teacher_lp[i, start_pos:start_pos + actual_len] = tlp[:actual_len] + else: + # resp_len > num_actions (shouldn't happen, but handle gracefully) + aligned_teacher_lp[i, :] = tlp[-num_actions:] + + # Store in experience info for advantage calculator + exp.info["teacher_log_probs"] = aligned_teacher_lp + + except Exception as e: + raise RuntimeError( + f"Failed to fetch teacher log probs from {teacher_url}: {e}" + ) from e + + Timer.stop(' fetch_teacher_logprobs') + def _process_experiences( self, experiences: List[ExperienceVL], diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index d79a745..0010a9b 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -321,6 +321,7 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train all_advantages = [] all_returns = [] all_response_lengths = [] + all_opd_reverse_kl = [] # For on-policy distillation metrics for item in self.replay_buffer.items: # Collect rewards @@ -347,6 +348,10 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: all_response_lengths.append(item.info['response_length']) + # Collect on-policy distillation reverse KL + if hasattr(item, 'info') and item.info is not None and 'opd_reverse_kl' in item.info: + all_opd_reverse_kl.append(item.info['opd_reverse_kl']) + # Compute statistics # [TENSOR-FIX] Handle both tensor lists and scalar lists for all reward types if all_rewards: @@ -422,6 +427,22 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train status_mean["response_length_mean"] = lengths_tensor.float().mean().item() status_mean["response_length_std"] = lengths_tensor.float().std().item() + # On-Policy Distillation metrics + if all_opd_reverse_kl: + # Collect reverse KL from all experiences + if isinstance(all_opd_reverse_kl[0], torch.Tensor): + opd_kl_tensor = torch.cat([t.to(device).float() for t in all_opd_reverse_kl]) + else: + opd_kl_tensor = torch.tensor(all_opd_reverse_kl, dtype=torch.float32, device=device) + # Mask out zero values (padding) + non_zero_mask = opd_kl_tensor != 0 + if non_zero_mask.any(): + masked_kl = opd_kl_tensor[non_zero_mask] + status_mean["opd_reverse_kl_mean"] = masked_kl.mean().item() + status_mean["opd_reverse_kl_std"] = masked_kl.std().item() + status_mean["opd_reverse_kl_max"] = masked_kl.max().item() + status_mean["opd_reverse_kl_min"] = masked_kl.min().item() + # Print detailed reward breakdown (only on rank 0) if self.print_replay_buffer_stats and self.strategy.is_rank_0(): self.strategy.print("\n" + "=" * 60) @@ -460,6 +481,12 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train f"📏 Response Length: {status_mean['response_length_mean']:.1f} ± {status_mean['response_length_std']:.1f} tokens" # noqa ) + if all_opd_reverse_kl and 'opd_reverse_kl_mean' in status_mean: + self.strategy.print( + f"🎓 OPD Reverse KL: {status_mean['opd_reverse_kl_mean']:.4f} ± {status_mean['opd_reverse_kl_std']:.4f} " # noqa + f"(min={status_mean['opd_reverse_kl_min']:.4f}, max={status_mean['opd_reverse_kl_max']:.4f})" + ) + self.strategy.print("=" * 60 + "\n") torch.cuda.empty_cache()