From 0f3cd68f576216ba68c048b95db5d00283e2ffe5 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Tue, 10 Feb 2026 23:46:38 +0800 Subject: [PATCH 01/18] feature(pu): add init version of on_policy_distillation --- examples/on_policy_distillation/README.md | 291 ++++++++++++++++++ .../on_policy_distillation_reward.py | 202 ++++++++++++ .../on_policy_distillation/run_opd_qwen.sh | 258 ++++++++++++++++ examples/on_policy_distillation/test_opd.py | 260 ++++++++++++++++ lightrft/trainer/advantage_calculator.py | 80 ++++- lightrft/trainer/experience_maker.py | 70 +++++ 6 files changed, 1160 insertions(+), 1 deletion(-) create mode 100644 examples/on_policy_distillation/README.md create mode 100644 examples/on_policy_distillation/on_policy_distillation_reward.py create mode 100644 examples/on_policy_distillation/run_opd_qwen.sh create mode 100644 examples/on_policy_distillation/test_opd.py diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md new file mode 100644 index 00000000..754112f0 --- /dev/null +++ b/examples/on_policy_distillation/README.md @@ -0,0 +1,291 @@ +# On-Policy Distillation for LightRFT + +This directory contains a complete implementation of on-policy knowledge distillation for LightRFT, enabling smaller student models to learn from larger teacher models during reinforcement learning. + +## Overview + +On-policy distillation is a technique where: +- A **teacher model** (large, powerful) provides token-level supervision +- A **student model** (small, efficient) learns to match the teacher's probability distribution +- Training happens **on-policy**: teacher evaluates student's actual generated responses +- No separate reward model is needed - teacher's log probabilities serve as the learning signal + +## Quick Start + +### 1. Installation + +Ensure you have LightRFT installed with SGLang support: + +```bash +pip install lightrft +pip install sglang # For teacher model inference server +``` + +### 2. Prepare Your Dataset + +Your dataset should be in JSONL format with prompts: + +```json +{"prompt": "Solve: What is 2 + 2?"} +{"prompt": "Explain the theory of relativity."} +``` + +### 3. Run Training + +```bash +# Edit the configuration in run_opd_qwen.sh +bash examples/on_policy_distillation/run_opd_qwen.sh +``` + +## How It Works + +### Architecture + +``` +┌─────────────────────────────────────────────────────────┐ +│ Training Pipeline │ +└─────────────────────────────────────────────────────────┘ + +1. Student generates responses: + Prompt → [Student Model] → Response + +2. Teacher evaluates responses: + [Prompt + Response] → [Teacher Model] → Teacher Log Probs + +3. Advantage calculation: + Advantage = Teacher Log Probs - Student Log Probs + +4. Student optimization: + Update student to increase probability where teacher has high probability +``` + +### Key Components + +#### 1. OnPolicyDistillationCalculator (`lightrft/trainer/advantage_calculator.py`) + +Computes advantages using teacher log probabilities: + +```python +advantage = teacher_log_probs - student_log_probs +``` + +This encourages the student to match the teacher's token-level distribution. + +#### 2. Teacher Logprob Function (`on_policy_distillation_reward.py`) + +Queries the teacher model inference server to get log probabilities: + +```python +teacher_log_probs = get_teacher_logprobs_sync( + teacher_url=teacher_url, + sequences=sequences, + response_lengths=response_lengths +) +``` + +#### 3. Experience Maker Integration + +Modified `experience_maker.py` to: +- Query teacher model during experience collection +- Store teacher log probs in `experience.info["teacher_log_probs"]` +- Use OnPolicyDistillationCalculator for advantage computation + +## Configuration + +### Required Arguments + +```bash +--advantage_estimator "on_policy_distillation" # Enable on-policy distillation +--remote_rm_url "http://localhost:13141/generate" # Teacher model URL +--pretrain "Qwen/Qwen2.5-0.5B-Instruct" # Student model +``` + +### Recommended Hyperparameters + +| Parameter | Value | Description | +|-----------|-------|-------------| +| `n_samples_per_prompt` | 4 | Number of responses per prompt | +| `actor_learning_rate` | 1e-6 | Learning rate for student | +| `init_kl_coef` | 0.01 | KL coefficient for regularization | +| `num_episodes` | 30 | Number of training episodes | + +## Example Use Cases + +### 1. Math Reasoning (GSM8K) + +Train a small model to solve math problems like a larger model: + +```bash +TEACHER_MODEL_PATH="Qwen/Qwen2.5-7B-Instruct" +STUDENT_MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" +DATASET_PATH="path/to/gsm8k.jsonl" +``` + +### 2. General Instruction Following + +Distill instruction-following capabilities: + +```bash +TEACHER_MODEL_PATH="Qwen/Qwen2.5-14B-Instruct" +STUDENT_MODEL_PATH="Qwen/Qwen2.5-1.5B-Instruct" +DATASET_PATH="path/to/instruction_data.jsonl" +``` + +### 3. Domain-Specific Tasks + +Transfer domain expertise from a fine-tuned teacher to a smaller student: + +```bash +TEACHER_MODEL_PATH="path/to/finetuned_teacher" +STUDENT_MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" +DATASET_PATH="path/to/domain_data.jsonl" +``` + +## Technical Details + +### Advantage Computation + +The advantage estimator computes: + +```python +# Get teacher and student log probs for each token +teacher_log_probs = experience.info["teacher_log_probs"] +student_log_probs = experience.action_log_probs + +# Compute advantage (encourages matching teacher) +advantages = teacher_log_probs - student_log_probs + +# Apply action mask (only consider generated tokens) +advantages = advantages * experience.action_mask + +# Optional: normalize advantages +if config.advantages_norm: + advantages = (advantages - mean) / (std + 1e-8) +``` + +### Teacher Server Format + +The implementation supports both SGLang and vLLM formats: + +**SGLang format:** +```json +{ + "meta_info": { + "input_token_logprobs": [[logprob, rank, token], ...] + } +} +``` + +**vLLM format:** +```json +{ + "token_logprobs": [logprob1, logprob2, ...] +} +``` + +## Performance Tips + +### 1. GPU Memory Optimization + +- Run teacher on separate GPU(s) from training +- Use tensor parallelism for large teachers: `--tp 2` +- Adjust memory fraction: `--mem-fraction-static 0.6` + +### 2. Training Speed + +- Increase `n_samples_per_prompt` for more stable gradients (but slower) +- Use larger batch sizes if memory permits +- Enable gradient checkpointing for memory-intensive models + +### 3. Convergence + +- Start with lower learning rate (1e-6) for stable distillation +- Use KL coefficient to prevent student from diverging too far +- Monitor teacher-student log prob difference in W&B + +## Troubleshooting + +### Teacher server won't start + +```bash +# Check GPU availability +nvidia-smi + +# Check if port is already in use +lsof -i :13141 + +# Try different memory fraction +--mem-fraction-static 0.5 +``` + +### Training OOM (Out of Memory) + +```bash +# Reduce batch sizes +--micro_train_batch_size 2 +--micro_rollout_batch_size 2 + +# Enable gradient checkpointing +--gradient_checkpointing + +# Use ZeRO-3 optimization +--zero_stage 3 +``` + +### Slow convergence + +```bash +# Increase samples per prompt +--n_samples_per_prompt 8 + +# Adjust learning rate +--actor_learning_rate 5e-7 + +# Increase training episodes +--num_episodes 50 +``` + +## Comparison with Other Methods + +| Method | Reward Signal | Offline/Online | Requires RM | +|--------|--------------|----------------|-------------| +| GRPO | Task-specific reward | Online | Yes | +| DPO | Preference pairs | Offline | No | +| **On-Policy Distillation** | Teacher log probs | Online | No (uses teacher) | + +**Advantages:** +- ✅ No need to train a separate reward model +- ✅ Token-level supervision (finer-grained than sequence-level rewards) +- ✅ On-policy: adapts to student's changing distribution +- ✅ Works for any task where you have a good teacher model + +**Limitations:** +- ⚠️ Requires a teacher model (inference overhead) +- ⚠️ Student cannot exceed teacher's capabilities +- ⚠️ Needs sufficient compute for teacher inference + +## References + +- [Original slime implementation](https://github.com/OpenRLHF/slime) +- [LightRFT Documentation](../../README.md) +- [On-Policy Distillation Paper](https://arxiv.org/abs/XXXX.XXXXX) + +## Citation + +If you use this implementation, please cite: + +```bibtex +@software{lightrft_opd, + title={On-Policy Distillation for LightRFT}, + author={LightRFT Team}, + year={2024}, + url={https://github.com/yourusername/LightRFT} +} +``` + +## Support + +For questions or issues: +- Open an issue on GitHub +- Check the [FAQ](../../docs/source/best_practice/faq.md) +- Review [troubleshooting guide](../../docs/source/best_practice/troubleshooting.md) diff --git a/examples/on_policy_distillation/on_policy_distillation_reward.py b/examples/on_policy_distillation/on_policy_distillation_reward.py new file mode 100644 index 00000000..1c5629e0 --- /dev/null +++ b/examples/on_policy_distillation/on_policy_distillation_reward.py @@ -0,0 +1,202 @@ +""" +On-Policy Distillation Reward Function for LightRFT + +This module provides a custom reward function that queries a teacher model +to obtain log probabilities for knowledge distillation during RL training. + +The teacher model runs as a separate inference server (vLLM or SGLang), +and this function queries it to get token-level log probabilities for +the sequences generated by the student model. +""" + +import asyncio +import aiohttp +import torch +import numpy as np +from typing import List, Dict, Any, Optional + + +async def get_teacher_logprobs_async( + url: str, + sequences: List[str], + session: Optional[aiohttp.ClientSession] = None +) -> List[Dict[str, Any]]: + """ + Asynchronously query teacher model for log probabilities. + + :param url: URL of the teacher model inference server + :type url: str + :param sequences: List of full sequences (prompt + response) + :type sequences: List[str] + :param session: Optional aiohttp session for connection reuse + :type session: Optional[aiohttp.ClientSession] + :return: List of response dictionaries containing log probabilities + :rtype: List[Dict[str, Any]] + """ + should_close_session = session is None + if session is None: + session = aiohttp.ClientSession() + + try: + tasks = [] + for sequence in sequences: + payload = { + "text": sequence, + "sampling_params": { + "temperature": 0, + "max_tokens": 0, # No new generation, just logprobs + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": 0, # Get logprobs from the beginning + } + tasks.append(session.post(url, json=payload)) + + responses = await asyncio.gather(*tasks) + results = [] + for resp in responses: + resp.raise_for_status() + results.append(await resp.json()) + + return results + finally: + if should_close_session: + await session.close() + + +def extract_teacher_logprobs( + teacher_responses: List[Dict[str, Any]], + response_lengths: List[int], + device: str = "cpu" +) -> List[torch.Tensor]: + """ + Extract teacher log probabilities for the response tokens only. + + :param teacher_responses: List of teacher model API responses + :type teacher_responses: List[Dict[str, Any]] + :param response_lengths: Number of response tokens for each sequence + :type response_lengths: List[int] + :param device: Target device for tensors + :type device: str + :return: List of tensors containing teacher log probs for response tokens + :rtype: List[torch.Tensor] + """ + teacher_log_probs_list = [] + + for response, response_length in zip(teacher_responses, response_lengths): + # Extract log probabilities from teacher response + # The format depends on the inference server (vLLM/SGLang) + if "meta_info" in response and "input_token_logprobs" in response["meta_info"]: + # SGLang format + logprobs = response["meta_info"]["input_token_logprobs"] + # logprobs is a list of [logprob, rank, decoded_token] tuples + # Extract just the logprob values + logprob_values = [item[0] if isinstance(item, list) else item for item in logprobs] + # Skip the first token (it doesn't have a logprob) and take last response_length tokens + teacher_log_probs = torch.tensor(logprob_values[1:], dtype=torch.float32) + teacher_log_probs = teacher_log_probs[-response_length:] + elif "prompt_logprobs" in response or "token_logprobs" in response: + # vLLM format + logprobs = response.get("token_logprobs", response.get("prompt_logprobs", [])) + # Filter out None values and convert to tensor + logprob_values = [lp for lp in logprobs if lp is not None] + teacher_log_probs = torch.tensor(logprob_values, dtype=torch.float32) + teacher_log_probs = teacher_log_probs[-response_length:] + else: + raise ValueError( + f"Unknown response format from teacher model. " + f"Expected 'meta_info' (SGLang) or 'token_logprobs' (vLLM). " + f"Got keys: {response.keys()}" + ) + + teacher_log_probs_list.append(teacher_log_probs.to(device)) + + return teacher_log_probs_list + + +def reward_func(queries: List[str], prompts: List[str], **kwargs) -> torch.Tensor: + """ + Custom reward function for on-policy distillation. + + This function is called by LightRFT's experience maker to compute rewards. + It queries the teacher model and returns a placeholder reward tensor. + The actual teacher log probs are stored separately and used by the + OnPolicyDistillationCalculator. + + :param queries: List of full sequences (prompt + response) + :type queries: List[str] + :param prompts: List of prompts (unused, for compatibility) + :type prompts: List[str] + :return: Placeholder reward tensor (zeros) + :rtype: torch.Tensor + """ + # Return placeholder rewards + # The actual advantage computation happens in OnPolicyDistillationCalculator + # using teacher log probs stored in experience.info + return torch.zeros(len(queries), dtype=torch.float32) + + +async def get_teacher_logprobs_for_experiences( + teacher_url: str, + sequences: List[str], + response_lengths: List[int], + device: str = "cpu" +) -> torch.Tensor: + """ + Get teacher log probabilities for a batch of sequences. + + This is the main entry point for obtaining teacher log probs during training. + + :param teacher_url: URL of the teacher model server + :type teacher_url: str + :param sequences: List of full sequences (prompt + response) + :type sequences: List[str] + :param response_lengths: Number of response tokens for each sequence + :type response_lengths: List[int] + :param device: Target device for tensors + :type device: str + :return: Tensor of teacher log probs, padded to match response lengths + :rtype: torch.Tensor + """ + # Query teacher model + responses = await get_teacher_logprobs_async(teacher_url, sequences) + + # Extract log probs + teacher_log_probs_list = extract_teacher_logprobs(responses, response_lengths, device) + + # Pad to uniform length if needed + max_length = max(response_lengths) + padded_log_probs = [] + for log_probs, response_length in zip(teacher_log_probs_list, response_lengths): + if len(log_probs) < max_length: + padding = torch.zeros(max_length - len(log_probs), dtype=torch.float32, device=device) + log_probs = torch.cat([log_probs, padding]) + padded_log_probs.append(log_probs) + + return torch.stack(padded_log_probs) + + +# Synchronous wrapper for compatibility with LightRFT +def get_teacher_logprobs_sync( + teacher_url: str, + sequences: List[str], + response_lengths: List[int], + device: str = "cpu" +) -> torch.Tensor: + """ + Synchronous wrapper for getting teacher log probabilities. + + :param teacher_url: URL of the teacher model server + :type teacher_url: str + :param sequences: List of full sequences (prompt + response) + :type sequences: List[str] + :param response_lengths: Number of response tokens for each sequence + :type response_lengths: List[int] + :param device: Target device for tensors + :type device: str + :return: Tensor of teacher log probs + :rtype: torch.Tensor + """ + return asyncio.run( + get_teacher_logprobs_for_experiences(teacher_url, sequences, response_lengths, device) + ) diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh new file mode 100644 index 00000000..aaa0aef9 --- /dev/null +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -0,0 +1,258 @@ +#!/bin/bash +# +# LightRFT On-Policy Distillation Training Script +# This script demonstrates knowledge distillation from Qwen2.5-7B (teacher) to Qwen2.5-0.5B (student) +# using on-policy distillation during reinforcement learning. +# +# Key Features: +# - No separate reward model needed - teacher model provides the learning signal +# - Token-level supervision from teacher log probabilities +# - On-policy: teacher evaluates student's actual generated responses +# + +set -e + +################################################################################ +# Part 1: User Configuration # +################################################################################ + +# --- Model Paths --- +# Teacher model (larger, provides learning signal) +TEACHER_MODEL_PATH="Qwen/Qwen2.5-7B-Instruct" + +# Student model (smaller, being trained) +STUDENT_MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" + +# --- Dataset Path --- +# Path to your training dataset (JSONL format) +# Each line should be a JSON object with a "prompt" field +DATASET_PATH="/path/to/your/dataset.jsonl" + +# --- Teacher Model Server Configuration --- +TEACHER_IP="127.0.0.1" +TEACHER_PORT=13141 +TEACHER_GPU=7 # GPU to run teacher model on + +# --- Experiment Configuration --- +EXPERIMENT_NAME="opd-qwen-7b-to-0.5b" +export WANDB_API_KEY="YOUR_WANDB_API_KEY" +export WANDB_PROJECT="LightRFT-OnPolicyDistillation" + +################################################################################ +# Part 2: Training Hyperparameters # +################################################################################ + +# --- Distillation Settings --- +N_SAMPLES=4 # Number of samples per prompt +EPISODE=30 # Total number of training episodes +WARMUP=0.03 # Learning rate warmup ratio + +# --- Batch Size Configuration --- +RBS=128 # Rollout Batch Size +TBS=128 # Train Batch Size + +# --- Learning Settings --- +KL=0.01 # KL divergence coefficient (for regularization) +LR=1e-6 # Student learning rate +MAX_LENGTH=3072 # Max sequence length +PROMPT_MAX_LEN=1024 # Max prompt length +GENERATE_MAX_LEN=2048 # Max generation length + +################################################################################ +# Part 3: Distributed Training Setup # +################################################################################ + +# --- Single-Node Setup --- +export MLP_WORKER_NUM=1 +export MLP_WORKER_GPU=8 +export MLP_ROLE_INDEX=0 +export MLP_WORKER_0_HOST="localhost" +export MLP_WORKER_0_PORT=20090 + +# --- PyTorch Distributed Variables --- +export MASTER_ADDR=$MLP_WORKER_0_HOST +export MASTER_PORT=$MLP_WORKER_0_PORT +export NNODES=$MLP_WORKER_NUM +export NODE_RANK=$MLP_ROLE_INDEX +export GPUS_PER_NODE=$MLP_WORKER_GPU + +# --- vLLM Engine Settings --- +ENGINE_TP=2 # Tensor parallelism for inference engine + +################################################################################ +# Part 4: Start Teacher Model Server # +################################################################################ + +echo "=========================================" +echo "Starting Teacher Model Server" +echo "=========================================" + +# Generate unique log file for teacher server +LOG_FILE="/tmp/teacher_model_$(date +%Y%m%d_%H%M%S).log" + +# Launch teacher model server in background +CUDA_VISIBLE_DEVICES=$TEACHER_GPU python3 -m sglang.launch_server \ + --model-path "$TEACHER_MODEL_PATH" \ + --host 0.0.0.0 \ + --port $TEACHER_PORT \ + --tp 1 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static 0.6 \ + > "$LOG_FILE" 2>&1 & + +TEACHER_PID=$! +echo "Teacher model server starting (PID: $TEACHER_PID)..." +echo "Logs: $LOG_FILE" + +# Wait for teacher model server to be ready +MAX_WAIT=300 # Maximum wait time in seconds +WAITED=0 +until curl -sf http://$TEACHER_IP:$TEACHER_PORT/health > /dev/null 2>&1; do + if [ $WAITED -ge $MAX_WAIT ]; then + echo "ERROR: Teacher model server failed to start within $MAX_WAIT seconds" + echo "Last 20 lines of log:" + tail -n 20 "$LOG_FILE" + kill $TEACHER_PID 2>/dev/null || true + exit 1 + fi + echo "Waiting for teacher model server to start... ($WAITED/$MAX_WAIT seconds)" + tail -n 5 "$LOG_FILE" + sleep 5 + WAITED=$((WAITED + 5)) +done + +echo "✓ Teacher model server is up and running at $TEACHER_IP:$TEACHER_PORT" +sleep 5 + +################################################################################ +# Part 5: Training Setup # +################################################################################ + +# --- Generate dynamic names --- +current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}" +WANDB_RUN_NAME="${EXPERIMENT_NAME}-${current_time}" + +# --- Create directories --- +mkdir -p "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +mkdir -p "rft_logs/${EXPERIMENT_NAME}" + +# --- Environment optimizations --- +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="WARN" +export IGNORE_EOS=0 +export WANDB_MODE="offline" # Set to "online" for real-time logging + +# --- Teacher model URL for distillation --- +TEACHER_URL="http://$TEACHER_IP:$TEACHER_PORT/generate" + +set -x + +################################################################################ +# Part 6: Launch Training # +################################################################################ + +echo "=========================================" +echo "Starting On-Policy Distillation Training" +echo "=========================================" + +# Function to cleanup on exit +cleanup() { + echo "Cleaning up..." + kill $TEACHER_PID 2>/dev/null || true + pkill -f "sglang.launch_server" 2>/dev/null || true + echo "Cleanup complete" +} +trap cleanup EXIT + +torchrun \ + --nnodes $NNODES \ + --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK \ + --master-port $MASTER_PORT \ + --master-addr $MASTER_ADDR \ + examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "$STUDENT_MODEL_PATH" \ + --save_trajectories \ + --advantage_estimator "on_policy_distillation" \ + --fsdp \ + --use_kl_loss \ + --flash_attn \ + --rm_use_engine \ + --reward_pretrain "{}" \ + --remote_rm_url "$TEACHER_URL" \ + --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --micro_train_batch_size 4 \ + --train_batch_size ${TBS} \ + --micro_rollout_batch_size 4 \ + --rollout_batch_size ${RBS} \ + --max_epochs 1 \ + --num_episodes ${EPISODE} \ + --lr_warmup_ratio ${WARMUP} \ + --n_samples_per_prompt $N_SAMPLES \ + --prompt_max_len $PROMPT_MAX_LEN \ + --generate_max_len $GENERATE_MAX_LEN \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate $LR \ + --init_kl_coef $KL \ + --kl_estimator "k3" \ + --prompt_data "$DATASET_PATH" \ + --max_ckpt_num 3 \ + --max_ckpt_mem 160 \ + --use_wandb \ + --wandb_project "$WANDB_PROJECT" \ + --wandb_run_name "$WANDB_RUN_NAME" \ + --logging_steps 1 \ + --eval_steps -1 \ + --rm_engine_tp $ENGINE_TP + +echo "Training complete!" + +################################################################################ +# Part 7: Usage Instructions # +################################################################################ + +: <<'USAGE' +=============================================================================== +Usage Instructions +=============================================================================== + +1. Prerequisites: + - Install LightRFT and dependencies + - Install SGLang: pip install sglang + - Prepare your training dataset in JSONL format + +2. Configure the script: + - Set TEACHER_MODEL_PATH to your teacher model + - Set STUDENT_MODEL_PATH to your student model + - Set DATASET_PATH to your training data + - Adjust hyperparameters as needed + +3. Run the script: + bash examples/on_policy_distillation/run_opd_qwen.sh + +4. Monitor training: + - Check W&B dashboard for training metrics + - Logs are saved in rft_logs/${EXPERIMENT_NAME}/ + - Checkpoints are saved in results/${EXPERIMENT_NAME}/ + +5. Key Parameters: + - N_SAMPLES: Number of responses per prompt (higher = more stable but slower) + - LR: Learning rate (typically 1e-6 for distillation) + - KL: KL coefficient for regularization (keeps student close to initialization) + - EPISODE: Number of training episodes + +6. Expected Behavior: + - Student model should gradually match teacher's probability distribution + - Training loss should decrease over episodes + - Student responses should become more similar to teacher's style + +7. Troubleshooting: + - If teacher server fails: Check GPU memory and CUDA availability + - If training OOMs: Reduce batch sizes or enable gradient checkpointing + - If convergence is slow: Adjust learning rate or increase N_SAMPLES + +=============================================================================== +USAGE diff --git a/examples/on_policy_distillation/test_opd.py b/examples/on_policy_distillation/test_opd.py new file mode 100644 index 00000000..e2476c48 --- /dev/null +++ b/examples/on_policy_distillation/test_opd.py @@ -0,0 +1,260 @@ +""" +Test script for On-Policy Distillation implementation in LightRFT. + +This script validates the core components of the on-policy distillation mechanism. +""" + +import torch +import sys +from pathlib import Path + +# Add LightRFT to path +lightrft_path = Path(__file__).parent.parent.parent +sys.path.insert(0, str(lightrft_path)) + +def test_advantage_calculator(): + """Test OnPolicyDistillationCalculator.""" + print("Testing OnPolicyDistillationCalculator...") + + from lightrft.trainer.advantage_calculator import get_advantage_calculator + + # Create mock config + class MockConfig: + advantages_norm = True + advantage_clip = 10.0 + + config = MockConfig() + + # Create calculator + calculator = get_advantage_calculator("on_policy_distillation", config) + print(f"✓ Created calculator: {calculator.__class__.__name__}") + + # Create mock experience + class MockExperience: + def __init__(self): + self.action_log_probs = torch.tensor([[0.1, 0.2, 0.3], [0.2, 0.3, 0.4]]) + self.action_mask = torch.tensor([[True, True, True], [True, True, False]]) + self.info = { + "teacher_log_probs": torch.tensor([[0.3, 0.4, 0.5], [0.4, 0.5, 0.6]]) + } + + experience = MockExperience() + final_reward = torch.zeros_like(experience.action_log_probs) + generate_kwargs = {} + + # Compute advantages + advantages, returns, info = calculator.compute( + experience, final_reward, gamma=1.0, generate_kwargs=generate_kwargs + ) + + print(f"✓ Computed advantages: {advantages.shape}") + print(f" Advantages sample: {advantages[0]}") + print(f" Expected positive (teacher > student): {(advantages > 0).float().mean():.2f}") + + # Test error case (missing teacher_log_probs) + experience_no_teacher = MockExperience() + del experience_no_teacher.info["teacher_log_probs"] + + try: + calculator.compute(experience_no_teacher, final_reward, gamma=1.0, generate_kwargs=generate_kwargs) + print("✗ Should have raised ValueError for missing teacher_log_probs") + return False + except ValueError as e: + print(f"✓ Correctly raised ValueError: {str(e)[:50]}...") + + print("✓ OnPolicyDistillationCalculator tests passed\n") + return True + + +def test_reward_function(): + """Test teacher logprob extraction functions.""" + print("Testing teacher logprob functions...") + + from examples.on_policy_distillation.on_policy_distillation_reward import ( + extract_teacher_logprobs + ) + + # Test SGLang format + sglang_response = { + "meta_info": { + "input_token_logprobs": [ + None, # First token has no logprob + [-0.1, 1, "hello"], + [-0.2, 2, "world"], + [-0.15, 1, "!"], + ] + } + } + + teacher_log_probs = extract_teacher_logprobs( + [sglang_response], + response_lengths=[3], + device="cpu" + ) + + print(f"✓ Extracted teacher log probs (SGLang format): {teacher_log_probs[0]}") + assert len(teacher_log_probs[0]) == 3, "Should extract exactly 3 response tokens" + assert torch.allclose(teacher_log_probs[0], torch.tensor([-0.1, -0.2, -0.15])), "Values mismatch" + + # Test vLLM format + vllm_response = { + "token_logprobs": [None, -0.1, -0.2, -0.15, -0.3] + } + + teacher_log_probs = extract_teacher_logprobs( + [vllm_response], + response_lengths=[3], + device="cpu" + ) + + print(f"✓ Extracted teacher log probs (vLLM format): {teacher_log_probs[0]}") + assert len(teacher_log_probs[0]) == 3, "Should extract exactly 3 response tokens" + + print("✓ Teacher logprob extraction tests passed\n") + return True + + +def test_factory_function(): + """Test that on_policy_distillation is registered in factory.""" + print("Testing factory function registration...") + + from lightrft.trainer.advantage_calculator import get_advantage_calculator + + class MockConfig: + advantages_norm = False + advantage_clip = 0.0 + + config = MockConfig() + + # Test valid estimators + estimators = [ + "gae", + "reinforce", + "rloo", + "reinforce_baseline", + "group_norm", + "cpgd", + "on_policy_distillation" + ] + + for estimator in estimators: + try: + calc = get_advantage_calculator(estimator, config) + print(f"✓ Created {estimator}: {calc.__class__.__name__}") + except Exception as e: + print(f"✗ Failed to create {estimator}: {e}") + return False + + # Test invalid estimator + try: + get_advantage_calculator("invalid_estimator", config) + print("✗ Should have raised ValueError for invalid estimator") + return False + except ValueError: + print("✓ Correctly raised ValueError for invalid estimator") + + print("✓ Factory function tests passed\n") + return True + + +def test_integration(): + """Test basic integration flow.""" + print("Testing integration flow...") + + # This is a simplified test to ensure components work together + from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator + + class MockConfig: + advantages_norm = True + advantage_clip = 5.0 + + calculator = OnPolicyDistillationCalculator(MockConfig()) + + # Simulate a batch of experiences + batch_size = 4 + seq_len = 10 + + class MockExperience: + def __init__(self): + # Student generated these log probs + self.action_log_probs = torch.randn(batch_size, seq_len) * 0.5 - 1.0 + + # Teacher evaluated and got these log probs + # Teacher is better, so generally higher log probs + self.info = { + "teacher_log_probs": torch.randn(batch_size, seq_len) * 0.3 - 0.5 + } + + # Action mask (last 2 tokens are padding) + self.action_mask = torch.ones(batch_size, seq_len, dtype=torch.bool) + self.action_mask[:, -2:] = False + + experience = MockExperience() + final_reward = torch.zeros_like(experience.action_log_probs) + + advantages, returns, info = calculator.compute( + experience, final_reward, gamma=1.0, generate_kwargs={} + ) + + print(f"✓ Computed advantages for batch: {advantages.shape}") + print(f" Mean advantage: {advantages.mean():.4f}") + print(f" Std advantage: {advantages.std():.4f}") + print(f" Masked correctly: {advantages[:, -2:].sum() == 0}") + + # Check that advantages are normalized (approximately) + masked_adv = advantages[experience.action_mask] + assert abs(masked_adv.mean()) < 0.1, "Advantages should be normalized (mean ≈ 0)" + print(f"✓ Advantages are normalized (mean={masked_adv.mean():.4f})") + + print("✓ Integration tests passed\n") + return True + + +def run_all_tests(): + """Run all tests.""" + print("=" * 70) + print("On-Policy Distillation Test Suite") + print("=" * 70) + print() + + tests = [ + ("Factory Function", test_factory_function), + ("Advantage Calculator", test_advantage_calculator), + ("Reward Function", test_reward_function), + ("Integration", test_integration), + ] + + results = [] + for name, test_fn in tests: + try: + result = test_fn() + results.append((name, result)) + except Exception as e: + print(f"✗ {name} failed with exception: {e}") + import traceback + traceback.print_exc() + results.append((name, False)) + + print("=" * 70) + print("Test Summary") + print("=" * 70) + + all_passed = True + for name, passed in results: + status = "✓ PASS" if passed else "✗ FAIL" + print(f"{status}: {name}") + if not passed: + all_passed = False + + print("=" * 70) + if all_passed: + print("✓ All tests passed!") + return 0 + else: + print("✗ Some tests failed") + return 1 + + +if __name__ == "__main__": + exit_code = run_all_tests() + sys.exit(exit_code) diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 1ce9a8f7..7f916804 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -717,6 +717,82 @@ def compute( return advantages, returns, info_dict +class OnPolicyDistillationCalculator(AdvantageCalculator): + """ + On-Policy Distillation calculator. + + Uses teacher model's log probabilities as learning signal for knowledge distillation. + The advantage is computed as the difference between teacher and student log probabilities, + encouraging the student to match the teacher's token-level distribution. + + Reference: On-policy distillation from teacher models during RL training + """ + def compute( + self, + experience, + final_reward: torch.Tensor, + gamma: Optional[float], + generate_kwargs: Dict, + ) -> Tuple[torch.Tensor, torch.Tensor, Dict]: + """ + Compute advantages using teacher log probabilities. + + The advantage is computed as: + advantage = teacher_log_probs - student_log_probs + + This encourages the student model to match the teacher's probability distribution. + + :param experience: Experience object containing teacher_log_probs in info dict + :type experience: object + :param final_reward: Unused for on_policy_distillation + :type final_reward: torch.Tensor + :param gamma: Discount factor. Unused for on_policy_distillation. + :type gamma: Optional[float] + :param generate_kwargs: Unused + :type generate_kwargs: Dict + :return: Tuple of (advantages, returns, info_dict) + :rtype: Tuple[torch.Tensor, torch.Tensor, Dict] + """ + # Get teacher log probs from experience info + if "teacher_log_probs" not in experience.info: + raise ValueError( + "teacher_log_probs not found in experience.info. " + "Make sure to use the on_policy_distillation reward function." + ) + + teacher_log_probs = experience.info["teacher_log_probs"].to(final_reward.device) + + # Student log probs are already computed in experience.action_log_probs + student_log_probs = experience.action_log_probs + + # Compute advantage as teacher - student log probs + # This encourages student to increase probability where teacher has higher probability + advantages = teacher_log_probs - student_log_probs + + # Apply action mask to ensure we only consider generated tokens + if experience.action_mask is not None: + advantages = advantages * experience.action_mask + + # Returns are the same as advantages for distillation + returns = deepcopy(advantages) + + # Advantage whitening (normalization) + info_dict = {} + if self.config.advantages_norm: + masked_adv = torch.masked_select(advantages, experience.action_mask) + adv_mean = masked_adv.mean() + adv_std = masked_adv.std() + advantages = (advantages - adv_mean) / (adv_std + 1e-8) + + # Advantage clipping + if self.config.advantage_clip > 0: + clip_val = self.config.advantage_clip + info_dict["advantage_clip_frac"] = compute_clip_fraction(advantages, clip_val, -clip_val) + advantages = torch.clamp(advantages, -clip_val, clip_val) + + return advantages, returns, info_dict + + # ============================================================================ # Factory Function # ============================================================================ @@ -728,7 +804,8 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator :param estimator_name: Name of the advantage estimation method Options: "gae", "cpgd", "reinforce", "rloo", - "reinforce_baseline", "group_norm", "grpo" + "reinforce_baseline", "group_norm", "grpo", + "on_policy_distillation" :type estimator_name: str :param config: Configuration object containing training parameters :type config: object @@ -744,6 +821,7 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator "reinforce_baseline": REINFORCEBaselineCalculator, "cpgd": CPGDCalculator, "grpo": GroupNormCalculator, # Alias for group_norm + "on_policy_distillation": OnPolicyDistillationCalculator, } calculator_class = calculator_map.get(estimator_name) diff --git a/lightrft/trainer/experience_maker.py b/lightrft/trainer/experience_maker.py index 0ef2065a..c545275e 100644 --- a/lightrft/trainer/experience_maker.py +++ b/lightrft/trainer/experience_maker.py @@ -390,6 +390,20 @@ def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kw generate_kwargs["gamma"], ) experience.advantages = deepcopy(experience.returns) + elif self.advantage_estimator == "on_policy_distillation": + # For on_policy_distillation, use teacher log probs from experience.info + # The actual advantage computation happens in OnPolicyDistillationCalculator + # Here we just set placeholder values + teacher_log_probs = experience.info.get("teacher_log_probs") + if teacher_log_probs is None: + raise ValueError( + "teacher_log_probs not found in experience.info. " + "This should have been set in process_experiences()." + ) + # Set placeholder returns and advantages + # They will be properly computed in the advantage calculator + experience.returns = torch.zeros_like(experience.action_log_probs) + experience.advantages = torch.zeros_like(experience.action_log_probs) else: raise Exception(f"Unknown advantage_estimator {self.advantage_estimator}") @@ -544,6 +558,62 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper """ args = self.strategy.args + # On-policy distillation: query teacher model for log probs + if args.advantage_estimator == "on_policy_distillation": + if self.remote_rm_url is None or len(self.remote_rm_url) == 0: + raise ValueError( + "On-policy distillation requires a teacher model URL. " + "Please set --remote_rm_url to the teacher model inference server." + ) + + # Import the teacher logprob function + import asyncio + import sys + import os.path + teacher_url = self.remote_rm_url[0] if isinstance(self.remote_rm_url, list) else self.remote_rm_url + + # Collect all sequences and response lengths + all_sequences = [] + all_response_lengths = [] + for experience in experiences: + sequences_batch = experience.sequences + response_lengths = experience.info["response_length"] + + # Decode sequences to text + for i, seq in enumerate(sequences_batch): + decoded_seq = self.tokenizer.decode(seq.cpu().tolist(), skip_special_tokens=False) + all_sequences.append(decoded_seq) + all_response_lengths.append(int(response_lengths[i].item())) + + # Query teacher model for log probs + try: + # Import the custom teacher logprob function + from examples.on_policy_distillation.on_policy_distillation_reward import ( + get_teacher_logprobs_sync + ) + + teacher_log_probs = get_teacher_logprobs_sync( + teacher_url=teacher_url, + sequences=all_sequences, + response_lengths=all_response_lengths, + device="cpu" + ) + + # Split and store teacher log probs in each experience + idx = 0 + for experience in experiences: + batch_size = experience.sequences.size(0) + experience.info["teacher_log_probs"] = teacher_log_probs[idx:idx + batch_size] + idx += batch_size + + except Exception as e: + logger.error(f"Failed to get teacher log probs: {e}") + raise + + # Return placeholder rewards (actual learning signal comes from teacher log probs) + rewards = [torch.zeros_like(experience.info["reward"]) for experience in experiences] + return experiences, rewards + # Reward shaping for RLOO if args.advantage_estimator == "rloo": rewards = torch.cat([experience.info["reward"] for experience in experiences]) From 5b5f9fb8abbe671945bd48efd453966b94e0261e Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Mon, 16 Mar 2026 16:41:34 +0800 Subject: [PATCH 02/18] fix(pu): fix opd implementation --- examples/gsm8k_geo3k/train_colocate.py | 5 +- .../on_policy_distillation_reward.py | 138 +++++++++++++----- .../on_policy_distillation/run_opd_qwen.sh | 3 +- lightrft/trainer/advantage_calculator.py | 17 ++- lightrft/trainer/fast_exp_maker.py | 91 ++++++++++++ 5 files changed, 214 insertions(+), 40 deletions(-) diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 818bcce2..43435291 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -384,7 +384,8 @@ def train(args): top_p=args.top_p, pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id, - # reward model + # reward model / teacher model URL (used for OPD) + remote_rm_url=args.remote_rm_url, reward_fn=reward_fn, reward_fn_label_map=label_map, reward_recipe=RECIPE, @@ -528,7 +529,7 @@ def train(args): parser.add_argument( "--advantage_estimator", type=str, - choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++"], + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++", "on_policy_distillation"], default="gae", help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, reinforce++", ) diff --git a/examples/on_policy_distillation/on_policy_distillation_reward.py b/examples/on_policy_distillation/on_policy_distillation_reward.py index 1c5629e0..0729d6f8 100644 --- a/examples/on_policy_distillation/on_policy_distillation_reward.py +++ b/examples/on_policy_distillation/on_policy_distillation_reward.py @@ -7,19 +7,29 @@ The teacher model runs as a separate inference server (vLLM or SGLang), and this function queries it to get token-level log probabilities for the sequences generated by the student model. + +Key differences from Slime implementation: +- Uses text input (SGLang will tokenize internally) +- Uses async HTTP requests for efficiency +- Includes retry logic for robustness """ import asyncio import aiohttp import torch import numpy as np -from typing import List, Dict, Any, Optional +import logging +from typing import List, Dict, Any, Optional, Union + +logger = logging.getLogger(__name__) async def get_teacher_logprobs_async( url: str, sequences: List[str], - session: Optional[aiohttp.ClientSession] = None + session: Optional[aiohttp.ClientSession] = None, + max_retries: int = 3, + retry_delay: float = 1.0, ) -> List[Dict[str, Any]]: """ Asynchronously query teacher model for log probabilities. @@ -30,34 +40,52 @@ async def get_teacher_logprobs_async( :type sequences: List[str] :param session: Optional aiohttp session for connection reuse :type session: Optional[aiohttp.ClientSession] + :param max_retries: Maximum number of retry attempts + :type max_retries: int + :param retry_delay: Initial delay between retries (exponential backoff) + :type retry_delay: float :return: List of response dictionaries containing log probabilities :rtype: List[Dict[str, Any]] """ should_close_session = session is None if session is None: - session = aiohttp.ClientSession() + timeout = aiohttp.ClientTimeout(total=300) # 5 minute timeout + session = aiohttp.ClientSession(timeout=timeout) - try: - tasks = [] - for sequence in sequences: - payload = { - "text": sequence, - "sampling_params": { - "temperature": 0, - "max_tokens": 0, # No new generation, just logprobs - "skip_special_tokens": False, - }, - "return_logprob": True, - "logprob_start_len": 0, # Get logprobs from the beginning - } - tasks.append(session.post(url, json=payload)) - - responses = await asyncio.gather(*tasks) - results = [] - for resp in responses: - resp.raise_for_status() - results.append(await resp.json()) + async def query_single(sequence: str, attempt: int = 0) -> Dict[str, Any]: + """Query teacher for a single sequence with retry logic.""" + payload = { + "text": sequence, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, # Correct SGLang parameter name + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": 0, # Get logprobs from the beginning + } + try: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + return await resp.json() + except Exception as e: + if attempt < max_retries - 1: + delay = retry_delay * (2 ** attempt) + logger.warning( + f"Teacher query failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {delay:.1f}s..." + ) + await asyncio.sleep(delay) + return await query_single(sequence, attempt + 1) + else: + raise RuntimeError( + f"Failed to query teacher model after {max_retries} attempts: {e}" + ) from e + + try: + tasks = [query_single(seq) for seq in sequences] + results = await asyncio.gather(*tasks) return results finally: if should_close_session: @@ -92,7 +120,8 @@ def extract_teacher_logprobs( # logprobs is a list of [logprob, rank, decoded_token] tuples # Extract just the logprob values logprob_values = [item[0] if isinstance(item, list) else item for item in logprobs] - # Skip the first token (it doesn't have a logprob) and take last response_length tokens + # Skip the first token (it doesn't have a logprob from the sequence start) + # Then take the last response_length tokens teacher_log_probs = torch.tensor(logprob_values[1:], dtype=torch.float32) teacher_log_probs = teacher_log_probs[-response_length:] elif "prompt_logprobs" in response or "token_logprobs" in response: @@ -109,6 +138,15 @@ def extract_teacher_logprobs( f"Got keys: {response.keys()}" ) + # Ensure we have the right number of tokens + if len(teacher_log_probs) < response_length: + # Pad with zeros if teacher returned fewer logprobs than expected + padding_length = response_length - len(teacher_log_probs) + teacher_log_probs = torch.cat([ + torch.zeros(padding_length, dtype=torch.float32), + teacher_log_probs + ]) + teacher_log_probs_list.append(teacher_log_probs.to(device)) return teacher_log_probs_list @@ -119,9 +157,8 @@ def reward_func(queries: List[str], prompts: List[str], **kwargs) -> torch.Tenso Custom reward function for on-policy distillation. This function is called by LightRFT's experience maker to compute rewards. - It queries the teacher model and returns a placeholder reward tensor. - The actual teacher log probs are stored separately and used by the - OnPolicyDistillationCalculator. + It returns a placeholder reward tensor. The actual teacher log probs + are fetched separately by FastExperienceMaker._fetch_teacher_logprobs(). :param queries: List of full sequences (prompt + response) :type queries: List[str] @@ -155,9 +192,12 @@ async def get_teacher_logprobs_for_experiences( :type response_lengths: List[int] :param device: Target device for tensors :type device: str - :return: Tensor of teacher log probs, padded to match response lengths + :return: Tensor of teacher log probs, padded to match max response length :rtype: torch.Tensor """ + if not sequences: + return torch.tensor([], device=device) + # Query teacher model responses = await get_teacher_logprobs_async(teacher_url, sequences) @@ -168,9 +208,13 @@ async def get_teacher_logprobs_for_experiences( max_length = max(response_lengths) padded_log_probs = [] for log_probs, response_length in zip(teacher_log_probs_list, response_lengths): - if len(log_probs) < max_length: - padding = torch.zeros(max_length - len(log_probs), dtype=torch.float32, device=device) + current_len = len(log_probs) + if current_len < max_length: + padding = torch.zeros(max_length - current_len, dtype=torch.float32, device=device) log_probs = torch.cat([log_probs, padding]) + elif current_len > max_length: + # Truncate if longer (shouldn't happen, but handle gracefully) + log_probs = log_probs[:max_length] padded_log_probs.append(log_probs) return torch.stack(padded_log_probs) @@ -186,6 +230,10 @@ def get_teacher_logprobs_sync( """ Synchronous wrapper for getting teacher log probabilities. + Note: This creates a new event loop, which may conflict with existing loops + in distributed training environments. Prefer using the async version + via FastExperienceMaker._fetch_teacher_logprobs() which handles this properly. + :param teacher_url: URL of the teacher model server :type teacher_url: str :param sequences: List of full sequences (prompt + response) @@ -197,6 +245,30 @@ def get_teacher_logprobs_sync( :return: Tensor of teacher log probs :rtype: torch.Tensor """ - return asyncio.run( - get_teacher_logprobs_for_experiences(teacher_url, sequences, response_lengths, device) - ) + try: + # Try to get existing event loop + loop = asyncio.get_event_loop() + if loop.is_running(): + # We're in an async context, need to use different approach + import concurrent.futures + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit( + asyncio.run, + get_teacher_logprobs_for_experiences( + teacher_url, sequences, response_lengths, device + ) + ) + return future.result() + else: + return loop.run_until_complete( + get_teacher_logprobs_for_experiences( + teacher_url, sequences, response_lengths, device + ) + ) + except RuntimeError: + # No event loop exists, create one + return asyncio.run( + get_teacher_logprobs_for_experiences( + teacher_url, sequences, response_lengths, device + ) + ) diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh index aaa0aef9..2b5b41bd 100644 --- a/examples/on_policy_distillation/run_opd_qwen.sh +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -88,7 +88,8 @@ echo "Starting Teacher Model Server" echo "=========================================" # Generate unique log file for teacher server -LOG_FILE="/tmp/teacher_model_$(date +%Y%m%d_%H%M%S).log" +LOG_FILE="rft_logs/${EXPERIMENT_NAME}/teacher_model_$(date +%Y%m%d_%H%M%S).log" +mkdir -p rft_logs # Launch teacher model server in background CUDA_VISIBLE_DEVICES=$TEACHER_GPU python3 -m sglang.launch_server \ diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 7f916804..b750f68a 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -765,19 +765,28 @@ def compute( # Student log probs are already computed in experience.action_log_probs student_log_probs = experience.action_log_probs - # Compute advantage as teacher - student log probs - # This encourages student to increase probability where teacher has higher probability - advantages = teacher_log_probs - student_log_probs + # Compute reverse KL divergence: student - teacher + # This is the correct direction for on-policy distillation: + # - When student > teacher: positive penalty (discourage student from being overconfident) + # - When student < teacher: negative penalty (encourage student to match teacher) + # The final advantage is: base_advantage - opd_kl_coef * reverse_kl + # Since we don't have a base advantage here, we use: -reverse_kl = teacher - student + # which encourages minimizing KL(student || teacher) + reverse_kl = student_log_probs - teacher_log_probs + advantages = -reverse_kl # This equals: teacher - student # Apply action mask to ensure we only consider generated tokens if experience.action_mask is not None: advantages = advantages * experience.action_mask + reverse_kl = reverse_kl * experience.action_mask # Returns are the same as advantages for distillation returns = deepcopy(advantages) + # Store reverse KL for metrics logging + info_dict = {"opd_reverse_kl": reverse_kl} + # Advantage whitening (normalization) - info_dict = {} if self.config.advantages_norm: masked_adv = torch.masked_select(advantages, experience.action_mask) adv_mean = masked_adv.mean() diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 1c2935b3..ce53a9ee 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -59,6 +59,16 @@ from .image_utils import normalize_images, get_images_num from .video_utils import normalize_videos, get_videos_num +# On-Policy Distillation imports +try: + from examples.on_policy_distillation.on_policy_distillation_reward import ( + get_teacher_logprobs_for_experiences + ) + import asyncio + OPD_AVAILABLE = True +except ImportError: + OPD_AVAILABLE = False + # ============================================================================ # Data Structures # ============================================================================ @@ -1047,6 +1057,10 @@ def make_experience_list( self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num) + # ========== Stage 6.5: On-Policy Distillation Teacher Log-Probs ========== + if config.advantage_estimator == "on_policy_distillation": + self._fetch_teacher_logprobs(experiences) + # ========== Stage 7: Advantage Computation ========== experiences = self._compute_advantages_and_returns(experiences, rewards, generate_kwargs) @@ -1410,6 +1424,83 @@ def _process_multi_image_video_thws( else: experience.video_grid_thws = [None] * len(micro_videos_num) + def _fetch_teacher_logprobs( + self, + experiences: List[ExperienceVL], + ) -> None: + """ + Fetch teacher log probabilities for on-policy distillation. + + This method queries the teacher model server to get log probabilities + for each generated sequence, which are used by OnPolicyDistillationCalculator + to compute advantages. + + :param experiences: List of experiences to add teacher log probs to + :type experiences: List[Union[Experience, ExperienceVL]] + """ + if not OPD_AVAILABLE: + raise RuntimeError( + "On-policy distillation module not available. " + "Make sure examples/on_policy_distillation/on_policy_distillation_reward.py exists." + ) + + # Get teacher URL from config + # remote_rm_url may be a string or list of strings + teacher_url = self.remote_rm_url + if isinstance(teacher_url, list): + teacher_url = teacher_url[0] if teacher_url else None + if teacher_url is None: + raise ValueError( + "Teacher model URL not specified. " + "Please set --remote_rm_url to the teacher model server URL." + ) + + Timer.start(' fetch_teacher_logprobs') + + for exp in experiences: + # Decode sequences to text for teacher model query + sequences = exp.sequences + attention_mask = exp.attention_mask + action_mask = exp.action_mask + + # Get response lengths (number of generated tokens per sequence) + response_lengths = action_mask.sum(dim=-1).tolist() + + # Decode full sequences to text + # Using batch_decode for efficiency + sequence_texts = self.tokenizer.batch_decode( + sequences, + skip_special_tokens=False, + clean_up_tokenization_spaces=False, + ) + + # Query teacher model for log probs + try: + # Use asyncio to run the async function + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + teacher_log_probs = loop.run_until_complete( + get_teacher_logprobs_for_experiences( + teacher_url=teacher_url, + sequences=sequence_texts, + response_lengths=response_lengths, + device="cpu", # Store on CPU, will move to GPU when needed + ) + ) + finally: + loop.close() + + # Store in experience info for advantage calculator + exp.info["teacher_log_probs"] = teacher_log_probs + + except Exception as e: + raise RuntimeError( + f"Failed to fetch teacher log probs from {teacher_url}: {e}" + ) from e + + Timer.stop(' fetch_teacher_logprobs') + def _process_experiences( self, experiences: List[ExperienceVL], From a03536f30f4bf9107993942e55ad34d8900569ae Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Mon, 16 Mar 2026 18:08:31 +0800 Subject: [PATCH 03/18] fix(pu): fix remote_rm_url bug --- lightrft/trainer/fast_exp_maker.py | 25 +++++++++++++++++++++---- lightrft/trainer/spmd_ppo_trainer.py | 27 +++++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 4 deletions(-) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index ce53a9ee..ddf46eb3 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -452,7 +452,13 @@ def __init__( :type packing_samples: bool """ self.reward_model = reward_model - self.remote_rm_url = remote_rm_url + # Ensure remote_rm_url is a list for consistent iteration + if remote_rm_url is None: + self.remote_rm_url = None + elif isinstance(remote_rm_url, str): + self.remote_rm_url = [remote_rm_url] + else: + self.remote_rm_url = list(remote_rm_url) self.custom_reward_func = custom_reward_func self.reward_fn = reward_fn self.reward_fn_label_map = reward_fn_label_map or {} @@ -934,9 +940,20 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg else: self.multimodal_processor = None + # For On-Policy Distillation (OPD), remote_rm_url is used for teacher model, + # not for reward model. So we don't pass it to RewardComputationEngine. + # Instead, we store it separately for _fetch_teacher_logprobs(). + if advantage_estimator == "on_policy_distillation": + # Store teacher URL separately for OPD + self.teacher_model_url = self.remote_rm_url + rm_url_for_reward_engine = None # Don't use remote_rm_url for rewards in OPD mode + else: + self.teacher_model_url = None + rm_url_for_reward_engine = self.remote_rm_url + self.reward_engine = RewardComputationEngine( reward_model=self.reward_model, - remote_rm_url=self.remote_rm_url, + remote_rm_url=rm_url_for_reward_engine, custom_reward_func=getattr(self, "custom_reward_func", None), reward_fn=self.reward_fn, reward_fn_label_map=getattr(self, "reward_fn_label_map", None), @@ -1445,8 +1462,8 @@ def _fetch_teacher_logprobs( ) # Get teacher URL from config - # remote_rm_url may be a string or list of strings - teacher_url = self.remote_rm_url + # Use self.teacher_model_url which is set during __init__ for OPD mode + teacher_url = self.teacher_model_url if isinstance(teacher_url, list): teacher_url = teacher_url[0] if teacher_url else None if teacher_url is None: diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index d79a7458..0010a9b9 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -321,6 +321,7 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train all_advantages = [] all_returns = [] all_response_lengths = [] + all_opd_reverse_kl = [] # For on-policy distillation metrics for item in self.replay_buffer.items: # Collect rewards @@ -347,6 +348,10 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train if hasattr(item, 'info') and item.info is not None and 'response_length' in item.info: all_response_lengths.append(item.info['response_length']) + # Collect on-policy distillation reverse KL + if hasattr(item, 'info') and item.info is not None and 'opd_reverse_kl' in item.info: + all_opd_reverse_kl.append(item.info['opd_reverse_kl']) + # Compute statistics # [TENSOR-FIX] Handle both tensor lists and scalar lists for all reward types if all_rewards: @@ -422,6 +427,22 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train status_mean["response_length_mean"] = lengths_tensor.float().mean().item() status_mean["response_length_std"] = lengths_tensor.float().std().item() + # On-Policy Distillation metrics + if all_opd_reverse_kl: + # Collect reverse KL from all experiences + if isinstance(all_opd_reverse_kl[0], torch.Tensor): + opd_kl_tensor = torch.cat([t.to(device).float() for t in all_opd_reverse_kl]) + else: + opd_kl_tensor = torch.tensor(all_opd_reverse_kl, dtype=torch.float32, device=device) + # Mask out zero values (padding) + non_zero_mask = opd_kl_tensor != 0 + if non_zero_mask.any(): + masked_kl = opd_kl_tensor[non_zero_mask] + status_mean["opd_reverse_kl_mean"] = masked_kl.mean().item() + status_mean["opd_reverse_kl_std"] = masked_kl.std().item() + status_mean["opd_reverse_kl_max"] = masked_kl.max().item() + status_mean["opd_reverse_kl_min"] = masked_kl.min().item() + # Print detailed reward breakdown (only on rank 0) if self.print_replay_buffer_stats and self.strategy.is_rank_0(): self.strategy.print("\n" + "=" * 60) @@ -460,6 +481,12 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train f"📏 Response Length: {status_mean['response_length_mean']:.1f} ± {status_mean['response_length_std']:.1f} tokens" # noqa ) + if all_opd_reverse_kl and 'opd_reverse_kl_mean' in status_mean: + self.strategy.print( + f"🎓 OPD Reverse KL: {status_mean['opd_reverse_kl_mean']:.4f} ± {status_mean['opd_reverse_kl_std']:.4f} " # noqa + f"(min={status_mean['opd_reverse_kl_min']:.4f}, max={status_mean['opd_reverse_kl_max']:.4f})" + ) + self.strategy.print("=" * 60 + "\n") torch.cuda.empty_cache() From ef7c6bf7d1728de9fde56a39b30aeb737e10c784 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 17 Mar 2026 17:37:40 +0800 Subject: [PATCH 04/18] fix(pu): fix no grpo adv bug, use input_ids in get_teacher_logprobs_by_ids --- examples/gsm8k_geo3k/train_colocate.py | 3 +- examples/on_policy_distillation/README.md | 317 +++++++----------- examples/on_policy_distillation/README_zh.md | 232 +++++++++++++ .../on_policy_distillation_reward.py | 198 +++++------ lightrft/strategy/config.py | 2 + lightrft/trainer/advantage_calculator.py | 92 ++--- lightrft/trainer/experience_maker.py | 82 +++-- lightrft/trainer/fast_exp_maker.py | 55 +-- 8 files changed, 585 insertions(+), 396 deletions(-) create mode 100644 examples/on_policy_distillation/README_zh.md diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 43435291..fe85f336 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -535,6 +535,7 @@ def train(args): ) parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO") + parser.add_argument("--opd_kl_coef", type=float, default=1.0, help="KL coefficient for on-policy distillation penalty") # LoRA parser.add_argument("--load_in_4bit", action="store_true", default=False) @@ -621,7 +622,7 @@ def train(args): elif args.critic_pretrain is None: args.critic_pretrain = args.pretrain - if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm"]: + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm", "on_policy_distillation"]: assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" if args.use_kl_loss: diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md index 754112f0..af20882d 100644 --- a/examples/on_policy_distillation/README.md +++ b/examples/on_policy_distillation/README.md @@ -1,173 +1,129 @@ -# On-Policy Distillation for LightRFT +# On-Policy Distillation (OPD) for LightRFT -This directory contains a complete implementation of on-policy knowledge distillation for LightRFT, enabling smaller student models to learn from larger teacher models during reinforcement learning. +On-policy knowledge distillation enables smaller student models to learn from larger teacher models during reinforcement learning training. ## Overview -On-policy distillation is a technique where: -- A **teacher model** (large, powerful) provides token-level supervision -- A **student model** (small, efficient) learns to match the teacher's probability distribution -- Training happens **on-policy**: teacher evaluates student's actual generated responses -- No separate reward model is needed - teacher's log probabilities serve as the learning signal +| Aspect | Description | +|--------|-------------| +| **Teacher** | Large model providing token-level log probability supervision | +| **Student** | Smaller model being trained to match teacher's distribution | +| **Training** | On-policy: teacher evaluates student's actual generated responses | +| **Reward** | Teacher's log probabilities serve as the learning signal | ## Quick Start -### 1. Installation - -Ensure you have LightRFT installed with SGLang support: +### 1. Start Teacher Model Server ```bash -pip install lightrft -pip install sglang # For teacher model inference server +# Launch SGLang server for teacher model +CUDA_VISIBLE_DEVICES=7 python3 -m sglang.launch_server \ + --model-path "Qwen/Qwen2.5-7B-Instruct" \ + --host 0.0.0.0 \ + --port 13141 \ + --tp 1 \ + --mem-fraction-static 0.6 ``` -### 2. Prepare Your Dataset - -Your dataset should be in JSONL format with prompts: +### 2. Run Training -```json -{"prompt": "Solve: What is 2 + 2?"} -{"prompt": "Explain the theory of relativity."} +```bash +bash examples/on_policy_distillation/run_opd_qwen_2.sh ``` -### 3. Run Training +Or manually: ```bash -# Edit the configuration in run_opd_qwen.sh -bash examples/on_policy_distillation/run_opd_qwen.sh +torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "Qwen/Qwen2.5-0.5B-Instruct" \ + --advantage_estimator "on_policy_distillation" \ + --remote_rm_url "http://127.0.0.1:13141/generate" \ + --reward_pretrain "" \ + --n_samples_per_prompt 4 \ + --actor_learning_rate 1e-6 \ + --init_kl_coef 0.01 \ + --num_episodes 30 ``` -## How It Works - -### Architecture +## Architecture ``` -┌─────────────────────────────────────────────────────────┐ -│ Training Pipeline │ -└─────────────────────────────────────────────────────────┘ - -1. Student generates responses: - Prompt → [Student Model] → Response - -2. Teacher evaluates responses: - [Prompt + Response] → [Teacher Model] → Teacher Log Probs - -3. Advantage calculation: - Advantage = Teacher Log Probs - Student Log Probs - -4. Student optimization: - Update student to increase probability where teacher has high probability +┌─────────────────────────────────────────────────────────────┐ +│ OPD Training Pipeline │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ 1. Generate Prompt ──► [Student] ──► Response │ +│ │ +│ 2. Evaluate [Prompt + Response] ──► [Teacher Server] │ +│ │ │ +│ ▼ │ +│ Teacher Log Probs │ +│ │ +│ 3. Compute Advantage = Teacher_logp - Student_logp │ +│ │ +│ 4. Update Student ◄── Policy Gradient Loss │ +│ │ +└─────────────────────────────────────────────────────────────┘ ``` -### Key Components +## Key Components -#### 1. OnPolicyDistillationCalculator (`lightrft/trainer/advantage_calculator.py`) +### 1. Advantage Calculator -Computes advantages using teacher log probabilities: +**File**: `lightrft/trainer/advantage_calculator.py` ```python -advantage = teacher_log_probs - student_log_probs +class OnPolicyDistillationCalculator(AdvantageCalculator): + def compute(self, experience, ...): + teacher_log_probs = experience.info["teacher_log_probs"] + student_log_probs = experience.action_log_probs + + # Reverse KL: encourages student to match teacher + reverse_kl = student_log_probs - teacher_log_probs + advantages = -reverse_kl # = teacher - student + + return advantages, returns, {"opd_reverse_kl": reverse_kl} ``` -This encourages the student to match the teacher's token-level distribution. +### 2. Teacher Log Prob Fetcher -#### 2. Teacher Logprob Function (`on_policy_distillation_reward.py`) +**File**: `examples/on_policy_distillation/on_policy_distillation_reward.py` -Queries the teacher model inference server to get log probabilities: +- Async HTTP requests to teacher server +- Supports SGLang and vLLM response formats +- Automatic retry with exponential backoff -```python -teacher_log_probs = get_teacher_logprobs_sync( - teacher_url=teacher_url, - sequences=sequences, - response_lengths=response_lengths -) -``` +### 3. Experience Maker Integration -#### 3. Experience Maker Integration +**File**: `lightrft/trainer/fast_exp_maker.py` -Modified `experience_maker.py` to: -- Query teacher model during experience collection -- Store teacher log probs in `experience.info["teacher_log_probs"]` -- Use OnPolicyDistillationCalculator for advantage computation +- `--remote_rm_url` is used as teacher URL (not reward model) when `--advantage_estimator "on_policy_distillation"` +- Teacher log probs stored in `experience.info["teacher_log_probs"]` +- OPD metrics (`opd_reverse_kl_mean/std/min/max`) logged to wandb ## Configuration ### Required Arguments -```bash ---advantage_estimator "on_policy_distillation" # Enable on-policy distillation ---remote_rm_url "http://localhost:13141/generate" # Teacher model URL ---pretrain "Qwen/Qwen2.5-0.5B-Instruct" # Student model -``` +| Argument | Value | Description | +|----------|-------|-------------| +| `--advantage_estimator` | `"on_policy_distillation"` | Enable OPD mode | +| `--remote_rm_url` | `"http://host:port/generate"` | Teacher server URL | +| `--reward_pretrain` | `""` | Empty (no reward model needed) | ### Recommended Hyperparameters | Parameter | Value | Description | |-----------|-------|-------------| -| `n_samples_per_prompt` | 4 | Number of responses per prompt | -| `actor_learning_rate` | 1e-6 | Learning rate for student | -| `init_kl_coef` | 0.01 | KL coefficient for regularization | -| `num_episodes` | 30 | Number of training episodes | - -## Example Use Cases - -### 1. Math Reasoning (GSM8K) - -Train a small model to solve math problems like a larger model: - -```bash -TEACHER_MODEL_PATH="Qwen/Qwen2.5-7B-Instruct" -STUDENT_MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" -DATASET_PATH="path/to/gsm8k.jsonl" -``` - -### 2. General Instruction Following - -Distill instruction-following capabilities: - -```bash -TEACHER_MODEL_PATH="Qwen/Qwen2.5-14B-Instruct" -STUDENT_MODEL_PATH="Qwen/Qwen2.5-1.5B-Instruct" -DATASET_PATH="path/to/instruction_data.jsonl" -``` - -### 3. Domain-Specific Tasks - -Transfer domain expertise from a fine-tuned teacher to a smaller student: - -```bash -TEACHER_MODEL_PATH="path/to/finetuned_teacher" -STUDENT_MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" -DATASET_PATH="path/to/domain_data.jsonl" -``` - -## Technical Details - -### Advantage Computation - -The advantage estimator computes: +| `--n_samples_per_prompt` | 4 | Responses per prompt | +| `--actor_learning_rate` | 1e-6 | Student learning rate | +| `--init_kl_coef` | 0.01 | KL regularization coefficient | +| `--num_episodes` | 30 | Training episodes | -```python -# Get teacher and student log probs for each token -teacher_log_probs = experience.info["teacher_log_probs"] -student_log_probs = experience.action_log_probs - -# Compute advantage (encourages matching teacher) -advantages = teacher_log_probs - student_log_probs +## Teacher Server Formats -# Apply action mask (only consider generated tokens) -advantages = advantages * experience.action_mask - -# Optional: normalize advantages -if config.advantages_norm: - advantages = (advantages - mean) / (std + 1e-8) -``` +### SGLang (Recommended) -### Teacher Server Format - -The implementation supports both SGLang and vLLM formats: - -**SGLang format:** ```json { "meta_info": { @@ -176,116 +132,101 @@ The implementation supports both SGLang and vLLM formats: } ``` -**vLLM format:** +### vLLM + ```json { "token_logprobs": [logprob1, logprob2, ...] } ``` -## Performance Tips - -### 1. GPU Memory Optimization +## Monitoring -- Run teacher on separate GPU(s) from training -- Use tensor parallelism for large teachers: `--tp 2` -- Adjust memory fraction: `--mem-fraction-static 0.6` +### Logged Metrics -### 2. Training Speed +| Metric | Description | +|--------|-------------| +| `opd_reverse_kl_mean` | Average KL(student \|\| teacher) | +| `opd_reverse_kl_std` | Standard deviation of reverse KL | +| `advantages_mean` | Average advantage (should center ~0) | +| `policy_loss` | Should decrease during training | -- Increase `n_samples_per_prompt` for more stable gradients (but slower) -- Use larger batch sizes if memory permits -- Enable gradient checkpointing for memory-intensive models +### Console Output -### 3. Convergence - -- Start with lower learning rate (1e-6) for stable distillation -- Use KL coefficient to prevent student from diverging too far -- Monitor teacher-student log prob difference in W&B +``` +📊 Detailed Step Statistics +============================================================ +🎁 Total Reward: 0.0000 ± 0.0000 (placeholder for OPD) +📈 Advantages: 0.0012 ± 0.8234 (...) +🎓 OPD Reverse KL: 0.1523 ± 0.0891 (...) +============================================================ +``` ## Troubleshooting -### Teacher server won't start +### Teacher Server Issues ```bash +# Check if port is in use +lsof -i :13141 + # Check GPU availability nvidia-smi -# Check if port is already in use -lsof -i :13141 - -# Try different memory fraction +# Reduce memory if OOM --mem-fraction-static 0.5 ``` -### Training OOM (Out of Memory) +### Training OOM ```bash -# Reduce batch sizes --micro_train_batch_size 2 --micro_rollout_batch_size 2 - -# Enable gradient checkpointing --gradient_checkpointing - -# Use ZeRO-3 optimization --zero_stage 3 ``` -### Slow convergence +### Slow Convergence ```bash -# Increase samples per prompt --n_samples_per_prompt 8 - -# Adjust learning rate --actor_learning_rate 5e-7 - -# Increase training episodes --num_episodes 50 ``` ## Comparison with Other Methods -| Method | Reward Signal | Offline/Online | Requires RM | -|--------|--------------|----------------|-------------| +| Method | Reward Signal | Mode | Requires RM | +|--------|--------------|------|-------------| | GRPO | Task-specific reward | Online | Yes | | DPO | Preference pairs | Offline | No | -| **On-Policy Distillation** | Teacher log probs | Online | No (uses teacher) | +| **OPD** | Teacher log probs | Online | No (uses teacher) | -**Advantages:** -- ✅ No need to train a separate reward model -- ✅ Token-level supervision (finer-grained than sequence-level rewards) -- ✅ On-policy: adapts to student's changing distribution -- ✅ Works for any task where you have a good teacher model +### Advantages -**Limitations:** -- ⚠️ Requires a teacher model (inference overhead) -- ⚠️ Student cannot exceed teacher's capabilities -- ⚠️ Needs sufficient compute for teacher inference +- No separate reward model training required +- Token-level supervision (finer than sequence-level) +- On-policy: adapts to student's changing distribution +- Works for any task with a good teacher model -## References - -- [Original slime implementation](https://github.com/OpenRLHF/slime) -- [LightRFT Documentation](../../README.md) -- [On-Policy Distillation Paper](https://arxiv.org/abs/XXXX.XXXXX) +### Limitations -## Citation +- Requires running teacher model (inference overhead) +- Student cannot exceed teacher's capabilities +- Needs sufficient compute for teacher inference -If you use this implementation, please cite: +## File Structure -```bibtex -@software{lightrft_opd, - title={On-Policy Distillation for LightRFT}, - author={LightRFT Team}, - year={2024}, - url={https://github.com/yourusername/LightRFT} -} +``` +examples/on_policy_distillation/ +├── README.md # This file +├── README_zh.md # Chinese version +├── run_opd_qwen_2.sh # Training script +└── on_policy_distillation_reward.py # Teacher logprob fetcher ``` -## Support +## References -For questions or issues: -- Open an issue on GitHub -- Check the [FAQ](../../docs/source/best_practice/faq.md) -- Review [troubleshooting guide](../../docs/source/best_practice/troubleshooting.md) +- [LightRFT Documentation](../../README.md) +- [Advantage Calculator Source](../../lightrft/trainer/advantage_calculator.py) +- [Fast Experience Maker Source](../../lightrft/trainer/fast_exp_maker.py) diff --git a/examples/on_policy_distillation/README_zh.md b/examples/on_policy_distillation/README_zh.md new file mode 100644 index 00000000..c6239123 --- /dev/null +++ b/examples/on_policy_distillation/README_zh.md @@ -0,0 +1,232 @@ +# LightRFT 在线策略蒸馏 (OPD) + +在线策略知识蒸馏使小型学生模型能够在强化学习训练过程中从大型教师模型学习。 + +## 概述 + +| 方面 | 描述 | +|------|------| +| **教师模型** | 提供 token 级别对数概率监督的大型模型 | +| **学生模型** | 被训练以匹配教师分布的小型模型 | +| **训练方式** | 在线策略:教师评估学生实际生成的响应 | +| **奖励信号** | 教师的对数概率作为学习信号 | + +## 快速开始 + +### 1. 启动教师模型服务器 + +```bash +# 启动 SGLang 服务器运行教师模型 +CUDA_VISIBLE_DEVICES=7 python3 -m sglang.launch_server \ + --model-path "Qwen/Qwen2.5-7B-Instruct" \ + --host 0.0.0.0 \ + --port 13141 \ + --tp 1 \ + --mem-fraction-static 0.6 +``` + +### 2. 运行训练 + +```bash +bash examples/on_policy_distillation/run_opd_qwen_2.sh +``` + +或手动运行: + +```bash +torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "Qwen/Qwen2.5-0.5B-Instruct" \ + --advantage_estimator "on_policy_distillation" \ + --remote_rm_url "http://127.0.0.1:13141/generate" \ + --reward_pretrain "" \ + --n_samples_per_prompt 4 \ + --actor_learning_rate 1e-6 \ + --init_kl_coef 0.01 \ + --num_episodes 30 +``` + +## 架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ OPD 训练流程 │ +├─────────────────────────────────────────────────────────────┤ +│ │ +│ 1. 生成 提示词 ──► [学生模型] ──► 响应 │ +│ │ +│ 2. 评估 [提示词 + 响应] ──► [教师服务器] │ +│ │ │ +│ ▼ │ +│ 教师对数概率 │ +│ │ +│ 3. 计算 优势值 = 教师_logp - 学生_logp │ +│ │ +│ 4. 更新 学生模型 ◄── 策略梯度损失 │ +│ │ +└─────────────────────────────────────────────────────────────┘ +``` + +## 核心组件 + +### 1. 优势值计算器 + +**文件**: `lightrft/trainer/advantage_calculator.py` + +```python +class OnPolicyDistillationCalculator(AdvantageCalculator): + def compute(self, experience, ...): + teacher_log_probs = experience.info["teacher_log_probs"] + student_log_probs = experience.action_log_probs + + # 反向 KL:鼓励学生匹配教师 + reverse_kl = student_log_probs - teacher_log_probs + advantages = -reverse_kl # = teacher - student + + return advantages, returns, {"opd_reverse_kl": reverse_kl} +``` + +### 2. 教师对数概率获取器 + +**文件**: `examples/on_policy_distillation/on_policy_distillation_reward.py` + +- 异步 HTTP 请求到教师服务器 +- 支持 SGLang 和 vLLM 响应格式 +- 自动重试(指数退避) + +### 3. Experience Maker 集成 + +**文件**: `lightrft/trainer/fast_exp_maker.py` + +- 当 `--advantage_estimator "on_policy_distillation"` 时,`--remote_rm_url` 作为教师 URL(而非奖励模型) +- 教师对数概率存储在 `experience.info["teacher_log_probs"]` +- OPD 指标(`opd_reverse_kl_mean/std/min/max`)记录到 wandb + +## 配置 + +### 必需参数 + +| 参数 | 值 | 描述 | +|------|---|------| +| `--advantage_estimator` | `"on_policy_distillation"` | 启用 OPD 模式 | +| `--remote_rm_url` | `"http://host:port/generate"` | 教师服务器 URL | +| `--reward_pretrain` | `""` | 空值(不需要奖励模型) | + +### 推荐超参数 + +| 参数 | 值 | 描述 | +|------|---|------| +| `--n_samples_per_prompt` | 4 | 每个提示词的响应数 | +| `--actor_learning_rate` | 1e-6 | 学生学习率 | +| `--init_kl_coef` | 0.01 | KL 正则化系数 | +| `--num_episodes` | 30 | 训练轮数 | + +## 教师服务器格式 + +### SGLang(推荐) + +```json +{ + "meta_info": { + "input_token_logprobs": [[logprob, rank, token], ...] + } +} +``` + +### vLLM + +```json +{ + "token_logprobs": [logprob1, logprob2, ...] +} +``` + +## 监控 + +### 记录的指标 + +| 指标 | 描述 | +|------|------| +| `opd_reverse_kl_mean` | 平均 KL(学生 \|\| 教师) | +| `opd_reverse_kl_std` | 反向 KL 的标准差 | +| `advantages_mean` | 平均优势值(应接近 0) | +| `policy_loss` | 训练中应下降 | + +### 控制台输出 + +``` +📊 详细步骤统计 +============================================================ +🎁 总奖励: 0.0000 ± 0.0000 (OPD 占位符) +📈 优势值: 0.0012 ± 0.8234 (...) +🎓 OPD 反向 KL: 0.1523 ± 0.0891 (...) +============================================================ +``` + +## 故障排除 + +### 教师服务器问题 + +```bash +# 检查端口是否被占用 +lsof -i :13141 + +# 检查 GPU 可用性 +nvidia-smi + +# 内存不足时减少内存占用 +--mem-fraction-static 0.5 +``` + +### 训练内存不足 + +```bash +--micro_train_batch_size 2 +--micro_rollout_batch_size 2 +--gradient_checkpointing +--zero_stage 3 +``` + +### 收敛缓慢 + +```bash +--n_samples_per_prompt 8 +--actor_learning_rate 5e-7 +--num_episodes 50 +``` + +## 与其他方法对比 + +| 方法 | 奖励信号 | 模式 | 需要 RM | +|------|---------|------|---------| +| GRPO | 任务特定奖励 | 在线 | 是 | +| DPO | 偏好对 | 离线 | 否 | +| **OPD** | 教师对数概率 | 在线 | 否(使用教师) | + +### 优势 + +- 无需单独训练奖励模型 +- Token 级监督(比序列级更精细) +- 在线策略:适应学生不断变化的分布 +- 适用于任何有好教师模型的任务 + +### 局限性 + +- 需要运行教师模型(推理开销) +- 学生无法超越教师的能力 +- 需要足够的计算资源进行教师推理 + +## 文件结构 + +``` +examples/on_policy_distillation/ +├── README.md # 英文文档 +├── README_zh.md # 本文件 +├── run_opd_qwen_2.sh # 训练脚本 +└── on_policy_distillation_reward.py # 教师对数概率获取器 +``` + +## 参考资料 + +- [LightRFT 文档](../../README.md) +- [优势值计算器源码](../../lightrft/trainer/advantage_calculator.py) +- [Fast Experience Maker 源码](../../lightrft/trainer/fast_exp_maker.py) diff --git a/examples/on_policy_distillation/on_policy_distillation_reward.py b/examples/on_policy_distillation/on_policy_distillation_reward.py index 0729d6f8..7bc65503 100644 --- a/examples/on_policy_distillation/on_policy_distillation_reward.py +++ b/examples/on_policy_distillation/on_policy_distillation_reward.py @@ -1,29 +1,97 @@ """ On-Policy Distillation Reward Function for LightRFT -This module provides a custom reward function that queries a teacher model -to obtain log probabilities for knowledge distillation during RL training. +This module provides functions to query a teacher model for log probabilities +used in knowledge distillation during RL training. -The teacher model runs as a separate inference server (vLLM or SGLang), -and this function queries it to get token-level log probabilities for +The teacher model runs as a separate inference server (SGLang), +and this module queries it to get token-level log probabilities for the sequences generated by the student model. -Key differences from Slime implementation: -- Uses text input (SGLang will tokenize internally) -- Uses async HTTP requests for efficiency +Key design decisions (aligned with Slime reference implementation): +- Uses input_ids (not text) to ensure token-level alignment +- Returns per-sample tensors (response tokens only) for dimension alignment - Includes retry logic for robustness """ import asyncio import aiohttp import torch -import numpy as np import logging -from typing import List, Dict, Any, Optional, Union +from typing import List, Dict, Any, Optional logger = logging.getLogger(__name__) +async def get_teacher_logprobs_by_ids( + url: str, + input_ids_list: List[List[int]], + response_lengths: List[int], + max_retries: int = 3, + retry_delay: float = 1.0, +) -> List[torch.Tensor]: + """ + Query teacher model using input_ids and return per-sample log prob tensors. + + This is the primary function for OPD, using input_ids to ensure exact + token-level alignment between teacher and student log probabilities. + + :param url: URL of the teacher model inference server (SGLang /generate endpoint) + :param input_ids_list: List of token id sequences [prompt + response] + :param response_lengths: Number of response tokens for each sequence + :param max_retries: Maximum retry attempts per query + :param retry_delay: Initial delay between retries (exponential backoff) + :return: List of tensors, each containing teacher log probs for response tokens only + """ + timeout = aiohttp.ClientTimeout(total=300) + async with aiohttp.ClientSession(timeout=timeout) as session: + + async def query_single(input_ids: List[int], attempt: int = 0) -> Dict[str, Any]: + payload = { + "input_ids": input_ids, + "sampling_params": { + "temperature": 0, + "max_new_tokens": 0, + "skip_special_tokens": False, + }, + "return_logprob": True, + "logprob_start_len": 0, + } + try: + async with session.post(url, json=payload) as resp: + resp.raise_for_status() + return await resp.json() + except Exception as e: + if attempt < max_retries - 1: + delay = retry_delay * (2 ** attempt) + logger.warning( + f"Teacher query failed (attempt {attempt + 1}/{max_retries}): {e}. " + f"Retrying in {delay:.1f}s..." + ) + await asyncio.sleep(delay) + return await query_single(input_ids, attempt + 1) + else: + raise RuntimeError( + f"Failed to query teacher model after {max_retries} attempts: {e}" + ) from e + + tasks = [query_single(ids) for ids in input_ids_list] + results = await asyncio.gather(*tasks) + + # Extract log probs for response tokens only + teacher_lp_list = [] + for response, resp_len in zip(results, response_lengths): + logprobs = response["meta_info"]["input_token_logprobs"] + # Extract logprob values; skip first token (no logprob for BOS) + lp_values = [item[0] if isinstance(item, list) else item for item in logprobs] + teacher_lp = torch.tensor(lp_values[1:], dtype=torch.float32) + # Take the last resp_len tokens (response part only) + teacher_lp = teacher_lp[-resp_len:] + teacher_lp_list.append(teacher_lp) + + return teacher_lp_list + + async def get_teacher_logprobs_async( url: str, sequences: List[str], @@ -32,39 +100,26 @@ async def get_teacher_logprobs_async( retry_delay: float = 1.0, ) -> List[Dict[str, Any]]: """ - Asynchronously query teacher model for log probabilities. - - :param url: URL of the teacher model inference server - :type url: str - :param sequences: List of full sequences (prompt + response) - :type sequences: List[str] - :param session: Optional aiohttp session for connection reuse - :type session: Optional[aiohttp.ClientSession] - :param max_retries: Maximum number of retry attempts - :type max_retries: int - :param retry_delay: Initial delay between retries (exponential backoff) - :type retry_delay: float - :return: List of response dictionaries containing log probabilities - :rtype: List[Dict[str, Any]] + [Legacy] Asynchronously query teacher model using text sequences. + + Prefer get_teacher_logprobs_by_ids() for better token-level alignment. """ should_close_session = session is None if session is None: - timeout = aiohttp.ClientTimeout(total=300) # 5 minute timeout + timeout = aiohttp.ClientTimeout(total=300) session = aiohttp.ClientSession(timeout=timeout) async def query_single(sequence: str, attempt: int = 0) -> Dict[str, Any]: - """Query teacher for a single sequence with retry logic.""" payload = { "text": sequence, "sampling_params": { "temperature": 0, - "max_new_tokens": 0, # Correct SGLang parameter name + "max_new_tokens": 0, "skip_special_tokens": False, }, "return_logprob": True, - "logprob_start_len": 0, # Get logprobs from the beginning + "logprob_start_len": 0, } - try: async with session.post(url, json=payload) as resp: resp.raise_for_status() @@ -98,49 +153,24 @@ def extract_teacher_logprobs( device: str = "cpu" ) -> List[torch.Tensor]: """ - Extract teacher log probabilities for the response tokens only. - - :param teacher_responses: List of teacher model API responses - :type teacher_responses: List[Dict[str, Any]] - :param response_lengths: Number of response tokens for each sequence - :type response_lengths: List[int] - :param device: Target device for tensors - :type device: str - :return: List of tensors containing teacher log probs for response tokens - :rtype: List[torch.Tensor] + Extract teacher log probabilities for response tokens only. """ teacher_log_probs_list = [] for response, response_length in zip(teacher_responses, response_lengths): - # Extract log probabilities from teacher response - # The format depends on the inference server (vLLM/SGLang) if "meta_info" in response and "input_token_logprobs" in response["meta_info"]: - # SGLang format logprobs = response["meta_info"]["input_token_logprobs"] - # logprobs is a list of [logprob, rank, decoded_token] tuples - # Extract just the logprob values logprob_values = [item[0] if isinstance(item, list) else item for item in logprobs] - # Skip the first token (it doesn't have a logprob from the sequence start) - # Then take the last response_length tokens teacher_log_probs = torch.tensor(logprob_values[1:], dtype=torch.float32) teacher_log_probs = teacher_log_probs[-response_length:] - elif "prompt_logprobs" in response or "token_logprobs" in response: - # vLLM format - logprobs = response.get("token_logprobs", response.get("prompt_logprobs", [])) - # Filter out None values and convert to tensor - logprob_values = [lp for lp in logprobs if lp is not None] - teacher_log_probs = torch.tensor(logprob_values, dtype=torch.float32) - teacher_log_probs = teacher_log_probs[-response_length:] else: raise ValueError( f"Unknown response format from teacher model. " - f"Expected 'meta_info' (SGLang) or 'token_logprobs' (vLLM). " + f"Expected 'meta_info' with 'input_token_logprobs'. " f"Got keys: {response.keys()}" ) - # Ensure we have the right number of tokens if len(teacher_log_probs) < response_length: - # Pad with zeros if teacher returned fewer logprobs than expected padding_length = response_length - len(teacher_log_probs) teacher_log_probs = torch.cat([ torch.zeros(padding_length, dtype=torch.float32), @@ -154,22 +184,9 @@ def extract_teacher_logprobs( def reward_func(queries: List[str], prompts: List[str], **kwargs) -> torch.Tensor: """ - Custom reward function for on-policy distillation. - - This function is called by LightRFT's experience maker to compute rewards. - It returns a placeholder reward tensor. The actual teacher log probs - are fetched separately by FastExperienceMaker._fetch_teacher_logprobs(). - - :param queries: List of full sequences (prompt + response) - :type queries: List[str] - :param prompts: List of prompts (unused, for compatibility) - :type prompts: List[str] - :return: Placeholder reward tensor (zeros) - :rtype: torch.Tensor + Placeholder reward function for on-policy distillation. + Returns zeros; actual learning signal comes from teacher log probs + task rewards. """ - # Return placeholder rewards - # The actual advantage computation happens in OnPolicyDistillationCalculator - # using teacher log probs stored in experience.info return torch.zeros(len(queries), dtype=torch.float32) @@ -180,31 +197,16 @@ async def get_teacher_logprobs_for_experiences( device: str = "cpu" ) -> torch.Tensor: """ - Get teacher log probabilities for a batch of sequences. - - This is the main entry point for obtaining teacher log probs during training. + [Legacy] Get teacher log probabilities using text sequences. - :param teacher_url: URL of the teacher model server - :type teacher_url: str - :param sequences: List of full sequences (prompt + response) - :type sequences: List[str] - :param response_lengths: Number of response tokens for each sequence - :type response_lengths: List[int] - :param device: Target device for tensors - :type device: str - :return: Tensor of teacher log probs, padded to match max response length - :rtype: torch.Tensor + Prefer get_teacher_logprobs_by_ids() called from _fetch_teacher_logprobs(). """ if not sequences: return torch.tensor([], device=device) - # Query teacher model responses = await get_teacher_logprobs_async(teacher_url, sequences) - - # Extract log probs teacher_log_probs_list = extract_teacher_logprobs(responses, response_lengths, device) - # Pad to uniform length if needed max_length = max(response_lengths) padded_log_probs = [] for log_probs, response_length in zip(teacher_log_probs_list, response_lengths): @@ -213,14 +215,12 @@ async def get_teacher_logprobs_for_experiences( padding = torch.zeros(max_length - current_len, dtype=torch.float32, device=device) log_probs = torch.cat([log_probs, padding]) elif current_len > max_length: - # Truncate if longer (shouldn't happen, but handle gracefully) log_probs = log_probs[:max_length] padded_log_probs.append(log_probs) return torch.stack(padded_log_probs) -# Synchronous wrapper for compatibility with LightRFT def get_teacher_logprobs_sync( teacher_url: str, sequences: List[str], @@ -228,28 +228,11 @@ def get_teacher_logprobs_sync( device: str = "cpu" ) -> torch.Tensor: """ - Synchronous wrapper for getting teacher log probabilities. - - Note: This creates a new event loop, which may conflict with existing loops - in distributed training environments. Prefer using the async version - via FastExperienceMaker._fetch_teacher_logprobs() which handles this properly. - - :param teacher_url: URL of the teacher model server - :type teacher_url: str - :param sequences: List of full sequences (prompt + response) - :type sequences: List[str] - :param response_lengths: Number of response tokens for each sequence - :type response_lengths: List[int] - :param device: Target device for tensors - :type device: str - :return: Tensor of teacher log probs - :rtype: torch.Tensor + [Legacy] Synchronous wrapper for text-based teacher log prob queries. """ try: - # Try to get existing event loop loop = asyncio.get_event_loop() if loop.is_running(): - # We're in an async context, need to use different approach import concurrent.futures with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit( @@ -266,7 +249,6 @@ def get_teacher_logprobs_sync( ) ) except RuntimeError: - # No event loop exists, create one return asyncio.run( get_teacher_logprobs_for_experiences( teacher_url, sequences, response_lengths, device diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index c6993005..f88ae0ff 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -114,6 +114,8 @@ class StrategyConfig: dynamic_sampling: bool = False # (str): Advantage estimator method, defaults to "gae" advantage_estimator: str = "group_norm" + # (float): OPD KL coefficient for on-policy distillation penalty + opd_kl_coef: float = 1.0 # KL loss and estimation # (bool): Use KL loss in training, defaults to False diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index b750f68a..5f6a352f 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -719,14 +719,34 @@ def compute( class OnPolicyDistillationCalculator(AdvantageCalculator): """ - On-Policy Distillation calculator. + On-Policy Distillation calculator (GRPO + OPD KL penalty). - Uses teacher model's log probabilities as learning signal for knowledge distillation. - The advantage is computed as the difference between teacher and student log probabilities, - encouraging the student to match the teacher's token-level distribution. + Combines GRPO (group normalization) as the base advantage estimator with + an on-policy distillation KL penalty from a teacher model. This is orthogonal + to the base advantage estimator, following the Slime framework design: + + advantages = base_advantages(GRPO) - opd_kl_coef * (student_logp - teacher_logp) Reference: On-policy distillation from teacher models during RL training """ + def __init__(self, config): + super().__init__(config) + self.opd_kl_coef = getattr(config, 'opd_kl_coef', 1.0) + # Use GroupNormCalculator (GRPO) as the base advantage estimator + self.base_calculator = GroupNormCalculator(config) + + def preprocess_rewards( + self, + rewards: torch.Tensor, + experiences: List, + max_new_tokens: int, + ) -> Tuple[List, List[torch.Tensor]]: + """ + Delegate reward preprocessing to GRPO base calculator. + This applies group normalization (mean/std) to task rewards. + """ + return self.base_calculator.preprocess_rewards(rewards, experiences, max_new_tokens) + def compute( self, experience, @@ -735,65 +755,57 @@ def compute( generate_kwargs: Dict, ) -> Tuple[torch.Tensor, torch.Tensor, Dict]: """ - Compute advantages using teacher log probabilities. + Compute advantages = GRPO base advantages - opd_kl_coef * reverse_KL. - The advantage is computed as: - advantage = teacher_log_probs - student_log_probs - - This encourages the student model to match the teacher's probability distribution. + Step 1: Compute GRPO base advantages from task rewards (e.g., GSM8K accuracy). + Step 2: Apply OPD KL penalty: advantages -= opd_kl_coef * (student_logp - teacher_logp). + This encourages the student to match teacher's distribution while still + optimizing for task performance. :param experience: Experience object containing teacher_log_probs in info dict :type experience: object - :param final_reward: Unused for on_policy_distillation + :param final_reward: Processed reward tensor (from task rewards) :type final_reward: torch.Tensor - :param gamma: Discount factor. Unused for on_policy_distillation. + :param gamma: Discount factor for GRPO base advantages :type gamma: Optional[float] - :param generate_kwargs: Unused + :param generate_kwargs: Generation parameters :type generate_kwargs: Dict :return: Tuple of (advantages, returns, info_dict) :rtype: Tuple[torch.Tensor, torch.Tensor, Dict] """ - # Get teacher log probs from experience info + # Step 1: Compute GRPO base advantages from task rewards + base_advantages, returns, info_dict = self.base_calculator.compute( + experience, final_reward, gamma, generate_kwargs + ) + + # Step 2: Apply OPD KL penalty (if teacher_log_probs available) if "teacher_log_probs" not in experience.info: raise ValueError( "teacher_log_probs not found in experience.info. " - "Make sure to use the on_policy_distillation reward function." + "Make sure to use the on_policy_distillation reward function " + "and that _fetch_teacher_logprobs() was called." ) - teacher_log_probs = experience.info["teacher_log_probs"].to(final_reward.device) - - # Student log probs are already computed in experience.action_log_probs + teacher_log_probs = experience.info["teacher_log_probs"].to(base_advantages.device) student_log_probs = experience.action_log_probs - # Compute reverse KL divergence: student - teacher - # This is the correct direction for on-policy distillation: - # - When student > teacher: positive penalty (discourage student from being overconfident) - # - When student < teacher: negative penalty (encourage student to match teacher) - # The final advantage is: base_advantage - opd_kl_coef * reverse_kl - # Since we don't have a base advantage here, we use: -reverse_kl = teacher - student - # which encourages minimizing KL(student || teacher) + # Compute reverse KL: student_logp - teacher_logp + # Penalty: when student diverges from teacher, reverse_kl > 0 reverse_kl = student_log_probs - teacher_log_probs - advantages = -reverse_kl # This equals: teacher - student - # Apply action mask to ensure we only consider generated tokens + # Apply OPD penalty to base advantages + advantages = base_advantages - self.opd_kl_coef * reverse_kl + + # Apply action mask if experience.action_mask is not None: advantages = advantages * experience.action_mask - reverse_kl = reverse_kl * experience.action_mask - # Returns are the same as advantages for distillation - returns = deepcopy(advantages) - - # Store reverse KL for metrics logging - info_dict = {"opd_reverse_kl": reverse_kl} - - # Advantage whitening (normalization) - if self.config.advantages_norm: - masked_adv = torch.masked_select(advantages, experience.action_mask) - adv_mean = masked_adv.mean() - adv_std = masked_adv.std() - advantages = (advantages - adv_mean) / (adv_std + 1e-8) + # Store metrics for logging + if experience.action_mask is not None: + masked_rkl = reverse_kl * experience.action_mask + info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / experience.action_mask.sum(-1).clamp(min=1) - # Advantage clipping + # Advantage clipping (skip advantages_norm since GRPO already normalized rewards) if self.config.advantage_clip > 0: clip_val = self.config.advantage_clip info_dict["advantage_clip_frac"] = compute_clip_fraction(advantages, clip_val, -clip_val) diff --git a/lightrft/trainer/experience_maker.py b/lightrft/trainer/experience_maker.py index c545275e..f282087c 100644 --- a/lightrft/trainer/experience_maker.py +++ b/lightrft/trainer/experience_maker.py @@ -391,19 +391,15 @@ def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kw ) experience.advantages = deepcopy(experience.returns) elif self.advantage_estimator == "on_policy_distillation": - # For on_policy_distillation, use teacher log probs from experience.info - # The actual advantage computation happens in OnPolicyDistillationCalculator - # Here we just set placeholder values - teacher_log_probs = experience.info.get("teacher_log_probs") - if teacher_log_probs is None: - raise ValueError( - "teacher_log_probs not found in experience.info. " - "This should have been set in process_experiences()." - ) - # Set placeholder returns and advantages - # They will be properly computed in the advantage calculator - experience.returns = torch.zeros_like(experience.action_log_probs) - experience.advantages = torch.zeros_like(experience.action_log_probs) + # OPD uses GRPO base advantages + OPD KL penalty + # Here compute GRPO-style cumulative returns from task rewards + # The OPD KL penalty is applied in OnPolicyDistillationCalculator + experience.returns = self.get_cumulative_returns( + reward, + experience.action_mask, + generate_kwargs["gamma"], + ) + experience.advantages = deepcopy(experience.returns) else: raise Exception(f"Unknown advantage_estimator {self.advantage_estimator}") @@ -558,7 +554,7 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper """ args = self.strategy.args - # On-policy distillation: query teacher model for log probs + # On-policy distillation: query teacher model for log probs, then use GRPO reward shaping if args.advantage_estimator == "on_policy_distillation": if self.remote_rm_url is None or len(self.remote_rm_url) == 0: raise ValueError( @@ -566,52 +562,66 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper "Please set --remote_rm_url to the teacher model inference server." ) - # Import the teacher logprob function import asyncio - import sys - import os.path teacher_url = self.remote_rm_url[0] if isinstance(self.remote_rm_url, list) else self.remote_rm_url - # Collect all sequences and response lengths - all_sequences = [] + # Collect all sequences as input_ids and response lengths + all_input_ids = [] all_response_lengths = [] for experience in experiences: sequences_batch = experience.sequences response_lengths = experience.info["response_length"] - - # Decode sequences to text for i, seq in enumerate(sequences_batch): - decoded_seq = self.tokenizer.decode(seq.cpu().tolist(), skip_special_tokens=False) - all_sequences.append(decoded_seq) + all_input_ids.append(seq.cpu().tolist()) all_response_lengths.append(int(response_lengths[i].item())) - # Query teacher model for log probs + # Query teacher model for log probs using input_ids try: - # Import the custom teacher logprob function from examples.on_policy_distillation.on_policy_distillation_reward import ( - get_teacher_logprobs_sync + get_teacher_logprobs_by_ids ) - teacher_log_probs = get_teacher_logprobs_sync( - teacher_url=teacher_url, - sequences=all_sequences, - response_lengths=all_response_lengths, - device="cpu" - ) + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + teacher_lp_list = loop.run_until_complete( + get_teacher_logprobs_by_ids( + url=teacher_url, + input_ids_list=all_input_ids, + response_lengths=all_response_lengths, + ) + ) + finally: + loop.close() - # Split and store teacher log probs in each experience + # Align teacher log probs to action_log_probs shape [batch, num_actions] idx = 0 for experience in experiences: batch_size = experience.sequences.size(0) - experience.info["teacher_log_probs"] = teacher_log_probs[idx:idx + batch_size] + num_actions = experience.action_mask.shape[1] + aligned = torch.zeros(batch_size, num_actions, dtype=torch.float32) + for j in range(batch_size): + tlp = teacher_lp_list[idx + j] + resp_len = all_response_lengths[idx + j] + actual_len = min(len(tlp), resp_len, num_actions) + start_pos = num_actions - resp_len + if start_pos >= 0: + aligned[j, start_pos:start_pos + actual_len] = tlp[:actual_len] + else: + aligned[j, :] = tlp[-num_actions:] + experience.info["teacher_log_probs"] = aligned idx += batch_size except Exception as e: logger.error(f"Failed to get teacher log probs: {e}") raise - # Return placeholder rewards (actual learning signal comes from teacher log probs) - rewards = [torch.zeros_like(experience.info["reward"]) for experience in experiences] + # Use GRPO reward shaping (group normalization) on task rewards + rewards = torch.cat([experience.info["reward"] for experience in experiences]) + rewards = rewards.reshape(-1, args.n_samples_per_prompt) + baseline = rewards.mean(-1, keepdim=True) + rewards = (rewards - baseline) / (rewards.std(1, keepdim=True) + 1e-8) + rewards = rewards.flatten().chunk(len(experiences)) return experiences, rewards # Reward shaping for RLOO diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index ddf46eb3..20fd7be2 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -62,7 +62,8 @@ # On-Policy Distillation imports try: from examples.on_policy_distillation.on_policy_distillation_reward import ( - get_teacher_logprobs_for_experiences + get_teacher_logprobs_for_experiences, + get_teacher_logprobs_by_ids, ) import asyncio OPD_AVAILABLE = True @@ -1448,9 +1449,11 @@ def _fetch_teacher_logprobs( """ Fetch teacher log probabilities for on-policy distillation. - This method queries the teacher model server to get log probabilities - for each generated sequence, which are used by OnPolicyDistillationCalculator - to compute advantages. + Uses input_ids (not text) to query teacher model, ensuring token-level + alignment between teacher and student log probabilities. + + Teacher log probs are aligned to the same shape as action_log_probs + [batch_size, seq_len], with zeros for prompt positions. :param experiences: List of experiences to add teacher log probs to :type experiences: List[Union[Experience, ExperienceVL]] @@ -1462,7 +1465,6 @@ def _fetch_teacher_logprobs( ) # Get teacher URL from config - # Use self.teacher_model_url which is set during __init__ for OPD mode teacher_url = self.teacher_model_url if isinstance(teacher_url, list): teacher_url = teacher_url[0] if teacher_url else None @@ -1475,41 +1477,48 @@ def _fetch_teacher_logprobs( Timer.start(' fetch_teacher_logprobs') for exp in experiences: - # Decode sequences to text for teacher model query - sequences = exp.sequences - attention_mask = exp.attention_mask - action_mask = exp.action_mask + sequences = exp.sequences # [batch_size, seq_len] + action_mask = exp.action_mask # [batch_size, num_actions] # Get response lengths (number of generated tokens per sequence) response_lengths = action_mask.sum(dim=-1).tolist() + num_actions = action_mask.shape[1] # action_log_probs dim - # Decode full sequences to text - # Using batch_decode for efficiency - sequence_texts = self.tokenizer.batch_decode( - sequences, - skip_special_tokens=False, - clean_up_tokenization_spaces=False, - ) + # Use input_ids for teacher query to ensure token-level alignment + input_ids_list = sequences.cpu().tolist() # Query teacher model for log probs try: - # Use asyncio to run the async function loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: - teacher_log_probs = loop.run_until_complete( - get_teacher_logprobs_for_experiences( - teacher_url=teacher_url, - sequences=sequence_texts, + teacher_lp_list = loop.run_until_complete( + get_teacher_logprobs_by_ids( + url=teacher_url, + input_ids_list=input_ids_list, response_lengths=response_lengths, - device="cpu", # Store on CPU, will move to GPU when needed ) ) finally: loop.close() + # Align teacher log probs to action_log_probs shape [batch_size, num_actions] + # teacher_lp_list[i] has shape [resp_len_i], need to pad/align to [num_actions] + batch_size = sequences.shape[0] + aligned_teacher_lp = torch.zeros(batch_size, num_actions, dtype=torch.float32) + for i, (tlp, resp_len) in enumerate(zip(teacher_lp_list, response_lengths)): + # Right-align: teacher log probs fill the last resp_len positions + # (matching where action_mask == 1) + actual_len = min(len(tlp), resp_len, num_actions) + start_pos = num_actions - resp_len + if start_pos >= 0: + aligned_teacher_lp[i, start_pos:start_pos + actual_len] = tlp[:actual_len] + else: + # resp_len > num_actions (shouldn't happen, but handle gracefully) + aligned_teacher_lp[i, :] = tlp[-num_actions:] + # Store in experience info for advantage calculator - exp.info["teacher_log_probs"] = teacher_log_probs + exp.info["teacher_log_probs"] = aligned_teacher_lp except Exception as e: raise RuntimeError( From 7b83fe24e1784f11a4fb66681fea7373bf6d0390 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 17 Mar 2026 21:25:22 +0800 Subject: [PATCH 05/18] fix(pu): add two opd_mode, add advantage_whitening for opd_kl --- examples/gsm8k_geo3k/train_colocate.py | 4 +- .../on_policy_distillation/run_opd_qwen.sh | 355 +++++++++--------- lightrft/trainer/advantage_calculator.py | 184 +++++---- lightrft/trainer/experience_maker.py | 30 +- lightrft/trainer/fast_exp_maker.py | 4 +- 5 files changed, 318 insertions(+), 259 deletions(-) diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index f6ad862c..6b8b29b3 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -561,7 +561,7 @@ def train(args): parser.add_argument( "--advantage_estimator", type=str, - choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++", "on_policy_distillation"], + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++", "on_policy_distillation", "on_policy_distillation_hybrid"], default="gae", help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, reinforce++", ) @@ -654,7 +654,7 @@ def train(args): elif args.critic_pretrain is None: args.critic_pretrain = args.pretrain - if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm", "on_policy_distillation"]: + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm", "on_policy_distillation", "on_policy_distillation_hybrid"]: assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" if args.use_kl_loss: diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh index 2b5b41bd..ceeebaa6 100644 --- a/examples/on_policy_distillation/run_opd_qwen.sh +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -1,170 +1,198 @@ #!/bin/bash # -# LightRFT On-Policy Distillation Training Script -# This script demonstrates knowledge distillation from Qwen2.5-7B (teacher) to Qwen2.5-0.5B (student) -# using on-policy distillation during reinforcement learning. +# LightRFT On-Policy Distillation Training Script (Template) +# Knowledge distillation from a teacher model to a student model. # -# Key Features: -# - No separate reward model needed - teacher model provides the learning signal -# - Token-level supervision from teacher log probabilities -# - On-policy: teacher evaluates student's actual generated responses +# Features: +# - Auto GPU detection and allocation (teacher + training) +# - Robust teacher server with health monitoring +# - Two OPD modes: pure distillation / hybrid (GRPO + OPD) +# +# Usage: +# # Edit paths below, then: +# bash examples/on_policy_distillation/run_opd_qwen.sh +# OPD_MODE=hybrid bash examples/on_policy_distillation/run_opd_qwen.sh # -set -e +set -euo pipefail ################################################################################ # Part 1: User Configuration # ################################################################################ -# --- Model Paths --- -# Teacher model (larger, provides learning signal) -TEACHER_MODEL_PATH="Qwen/Qwen2.5-7B-Instruct" - -# Student model (smaller, being trained) -STUDENT_MODEL_PATH="Qwen/Qwen2.5-0.5B-Instruct" +# --- Model Paths (EDIT THESE) --- +TEACHER_MODEL_PATH="${TEACHER_MODEL_PATH:-Qwen/Qwen2.5-7B-Instruct}" +STUDENT_MODEL_PATH="${STUDENT_MODEL_PATH:-Qwen/Qwen2.5-0.5B-Instruct}" +DATASET_PATH="${DATASET_PATH:-/path/to/your/dataset.jsonl}" -# --- Dataset Path --- -# Path to your training dataset (JSONL format) -# Each line should be a JSON object with a "prompt" field -DATASET_PATH="/path/to/your/dataset.jsonl" +# --- Experiment --- +EXPERIMENT_NAME="${EXPERIMENT_NAME:-opd-qwen-7b-to-0.5b}" +export WANDB_API_KEY="${WANDB_API_KEY:-YOUR_WANDB_API_KEY}" +export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-OnPolicyDistillation}" -# --- Teacher Model Server Configuration --- +# --- Teacher Server --- TEACHER_IP="127.0.0.1" -TEACHER_PORT=13141 -TEACHER_GPU=7 # GPU to run teacher model on - -# --- Experiment Configuration --- -EXPERIMENT_NAME="opd-qwen-7b-to-0.5b" -export WANDB_API_KEY="YOUR_WANDB_API_KEY" -export WANDB_PROJECT="LightRFT-OnPolicyDistillation" +TEACHER_PORT=${TEACHER_PORT:-13141} ################################################################################ -# Part 2: Training Hyperparameters # +# Part 2: Auto GPU Detection & Allocation # ################################################################################ -# --- Distillation Settings --- -N_SAMPLES=4 # Number of samples per prompt -EPISODE=30 # Total number of training episodes -WARMUP=0.03 # Learning rate warmup ratio +TOTAL_GPUS=$(nvidia-smi -L 2>/dev/null | wc -l) +if [ "$TOTAL_GPUS" -lt 2 ]; then + echo "ERROR: Need at least 2 GPUs (1 teacher + 1 training). Found: $TOTAL_GPUS" + exit 1 +fi + +# Last GPU for teacher, rest for training +TEACHER_GPU=$((TOTAL_GPUS - 1)) +TRAIN_GPUS=$((TOTAL_GPUS - 1)) + +if [ "$TRAIN_GPUS" -ge 2 ]; then + ENGINE_TP=2 +else + ENGINE_TP=1 +fi -# --- Batch Size Configuration --- -RBS=128 # Rollout Batch Size -TBS=128 # Train Batch Size +export NNODES=1 +export GPUS_PER_NODE=$TRAIN_GPUS +export NODE_RANK=0 +export MASTER_ADDR="localhost" +export MASTER_PORT=${MASTER_PORT:-20090} -# --- Learning Settings --- -KL=0.01 # KL divergence coefficient (for regularization) -LR=1e-6 # Student learning rate -MAX_LENGTH=3072 # Max sequence length -PROMPT_MAX_LEN=1024 # Max prompt length -GENERATE_MAX_LEN=2048 # Max generation length +echo "GPU Allocation: ${TOTAL_GPUS} total → Teacher: GPU ${TEACHER_GPU}, Training: GPU 0-$((TRAIN_GPUS-1)) (TP=${ENGINE_TP})" ################################################################################ -# Part 3: Distributed Training Setup # +# Part 3: Training Hyperparameters # ################################################################################ -# --- Single-Node Setup --- -export MLP_WORKER_NUM=1 -export MLP_WORKER_GPU=8 -export MLP_ROLE_INDEX=0 -export MLP_WORKER_0_HOST="localhost" -export MLP_WORKER_0_PORT=20090 +# --- OPD Mode (override via env: OPD_MODE=hybrid) --- +# "pure" - Pure distillation (Slime default): rewards=0, only OPD KL signal +# "hybrid" - GRPO task rewards + OPD KL penalty with advantage whitening +OPD_MODE="${OPD_MODE:-pure}" -# --- PyTorch Distributed Variables --- -export MASTER_ADDR=$MLP_WORKER_0_HOST -export MASTER_PORT=$MLP_WORKER_0_PORT -export NNODES=$MLP_WORKER_NUM -export NODE_RANK=$MLP_ROLE_INDEX -export GPUS_PER_NODE=$MLP_WORKER_GPU +N_SAMPLES=${N_SAMPLES:-8} +EPISODE=${EPISODE:-30} +WARMUP=${WARMUP:-0.03} +OPD_KL_COEF=${OPD_KL_COEF:-1.0} -# --- vLLM Engine Settings --- -ENGINE_TP=2 # Tensor parallelism for inference engine +RBS=${RBS:-128} +TBS=${TBS:-128} + +if [ "$OPD_MODE" = "hybrid" ]; then + ADVANTAGE_ESTIMATOR="on_policy_distillation_hybrid" + KL=${KL:-0.01} + LR=${LR:-5e-7} +else + ADVANTAGE_ESTIMATOR="on_policy_distillation" + KL=${KL:-0.00} + LR=${LR:-5e-7} +fi + +PROMPT_MAX_LEN=${PROMPT_MAX_LEN:-1024} +GENERATE_MAX_LEN=${GENERATE_MAX_LEN:-2048} ################################################################################ -# Part 4: Start Teacher Model Server # +# Part 4: Teacher Model Server # ################################################################################ -echo "=========================================" -echo "Starting Teacher Model Server" -echo "=========================================" +TEACHER_PID="" +LOG_DIR="rft_logs/${EXPERIMENT_NAME}" +mkdir -p "$LOG_DIR" +TEACHER_LOG="${LOG_DIR}/teacher_model_$(date +%Y%m%d_%H%M%S).log" -# Generate unique log file for teacher server -LOG_FILE="rft_logs/${EXPERIMENT_NAME}/teacher_model_$(date +%Y%m%d_%H%M%S).log" -mkdir -p rft_logs - -# Launch teacher model server in background -CUDA_VISIBLE_DEVICES=$TEACHER_GPU python3 -m sglang.launch_server \ - --model-path "$TEACHER_MODEL_PATH" \ - --host 0.0.0.0 \ - --port $TEACHER_PORT \ - --tp 1 \ - --chunked-prefill-size 4096 \ - --mem-fraction-static 0.6 \ - > "$LOG_FILE" 2>&1 & - -TEACHER_PID=$! -echo "Teacher model server starting (PID: $TEACHER_PID)..." -echo "Logs: $LOG_FILE" - -# Wait for teacher model server to be ready -MAX_WAIT=300 # Maximum wait time in seconds -WAITED=0 -until curl -sf http://$TEACHER_IP:$TEACHER_PORT/health > /dev/null 2>&1; do - if [ $WAITED -ge $MAX_WAIT ]; then - echo "ERROR: Teacher model server failed to start within $MAX_WAIT seconds" - echo "Last 20 lines of log:" - tail -n 20 "$LOG_FILE" - kill $TEACHER_PID 2>/dev/null || true - exit 1 +cleanup() { + echo "" + echo "=== Cleaning up ===" + if [ -n "$TEACHER_PID" ] && kill -0 "$TEACHER_PID" 2>/dev/null; then + echo "Stopping teacher server (PID: $TEACHER_PID)..." + kill "$TEACHER_PID" 2>/dev/null || true + sleep 2 + kill -9 "$TEACHER_PID" 2>/dev/null || true + fi + pkill -f "sglang.launch_server.*${TEACHER_PORT}" 2>/dev/null || true + echo "Done." +} +trap cleanup EXIT INT TERM + +start_teacher_server() { + if lsof -Pi :"$TEACHER_PORT" -sTCP:LISTEN -t >/dev/null 2>&1; then + echo "Port $TEACHER_PORT in use, killing existing process..." + lsof -ti:"$TEACHER_PORT" | xargs kill -9 2>/dev/null || true + sleep 3 fi - echo "Waiting for teacher model server to start... ($WAITED/$MAX_WAIT seconds)" - tail -n 5 "$LOG_FILE" - sleep 5 - WAITED=$((WAITED + 5)) + + echo "Starting teacher server on GPU $TEACHER_GPU..." + CUDA_VISIBLE_DEVICES=$TEACHER_GPU python3 -m sglang.launch_server \ + --model-path "$TEACHER_MODEL_PATH" \ + --host 0.0.0.0 \ + --port "$TEACHER_PORT" \ + --tp 1 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static 0.7 \ + --disable-radix-cache \ + --request-timeout 300 \ + --max-running-requests 64 \ + >> "$TEACHER_LOG" 2>&1 & + + TEACHER_PID=$! + echo "Teacher PID: $TEACHER_PID, Log: $TEACHER_LOG" + + local max_wait=600 waited=0 + while ! curl -sf "http://$TEACHER_IP:$TEACHER_PORT/health" >/dev/null 2>&1; do + if [ $waited -ge $max_wait ]; then + echo "ERROR: Teacher server failed to start in ${max_wait}s" + tail -30 "$TEACHER_LOG" + exit 1 + fi + if ! kill -0 "$TEACHER_PID" 2>/dev/null; then + echo "ERROR: Teacher server process died" + tail -30 "$TEACHER_LOG" + exit 1 + fi + printf "." + sleep 5 + waited=$((waited + 5)) + done + echo "" + echo "Teacher server ready at $TEACHER_IP:$TEACHER_PORT" +} + +# Validate model paths +for p in "$TEACHER_MODEL_PATH" "$STUDENT_MODEL_PATH"; do + [ -e "$p" ] || { echo "ERROR: Path not found: $p"; exit 1; } done -echo "✓ Teacher model server is up and running at $TEACHER_IP:$TEACHER_PORT" -sleep 5 +start_teacher_server +sleep 3 ################################################################################ -# Part 5: Training Setup # +# Part 5: Launch Training # ################################################################################ -# --- Generate dynamic names --- current_time=$(date +"%Y%m%d_%H%M%S") -SAVE_MODEL_NAME="${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}" -WANDB_RUN_NAME="${EXPERIMENT_NAME}-${current_time}" +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-${OPD_MODE}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}" +WANDB_RUN_NAME="${EXPERIMENT_NAME}-${OPD_MODE}-${current_time}" +TEACHER_URL="http://$TEACHER_IP:$TEACHER_PORT/generate" -# --- Create directories --- mkdir -p "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" -mkdir -p "rft_logs/${EXPERIMENT_NAME}" -# --- Environment optimizations --- export TORCH_NCCL_AVOID_RECORD_STREAMS=1 export NCCL_DEBUG="WARN" +export NCCL_TIMEOUT=3600 export IGNORE_EOS=0 -export WANDB_MODE="offline" # Set to "online" for real-time logging - -# --- Teacher model URL for distillation --- -TEACHER_URL="http://$TEACHER_IP:$TEACHER_PORT/generate" - -set -x - -################################################################################ -# Part 6: Launch Training # -################################################################################ +export WANDB_MODE="${WANDB_MODE:-offline}" echo "=========================================" -echo "Starting On-Policy Distillation Training" +echo "On-Policy Distillation Training" +echo "=========================================" +echo "Mode: $OPD_MODE ($ADVANTAGE_ESTIMATOR)" +echo "Student: $STUDENT_MODEL_PATH" +echo "Teacher: $TEACHER_URL" +echo "GPUs: Training=0-$((TRAIN_GPUS-1)), Teacher=$TEACHER_GPU" echo "=========================================" -# Function to cleanup on exit -cleanup() { - echo "Cleaning up..." - kill $TEACHER_PID 2>/dev/null || true - pkill -f "sglang.launch_server" 2>/dev/null || true - echo "Cleanup complete" -} -trap cleanup EXIT +set -x torchrun \ --nnodes $NNODES \ @@ -175,12 +203,15 @@ torchrun \ examples/gsm8k_geo3k/train_colocate.py \ --pretrain "$STUDENT_MODEL_PATH" \ --save_trajectories \ - --advantage_estimator "on_policy_distillation" \ + --advantage_estimator "${ADVANTAGE_ESTIMATOR}" \ + --opd_kl_coef ${OPD_KL_COEF} \ --fsdp \ --use_kl_loss \ --flash_attn \ + --engine_type sglang \ + --enable_engine_sleep \ --rm_use_engine \ - --reward_pretrain "{}" \ + --reward_pretrain "" \ --remote_rm_url "$TEACHER_URL" \ --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ @@ -200,60 +231,32 @@ torchrun \ --init_kl_coef $KL \ --kl_estimator "k3" \ --prompt_data "$DATASET_PATH" \ + --input_key "prompt" \ + --label_key "label" \ + --eval_steps 20 \ + --eval_split "test" \ + --apply_chat_template \ + --gradient_checkpointing \ + --save_steps 20 \ --max_ckpt_num 3 \ - --max_ckpt_mem 160 \ - --use_wandb \ - --wandb_project "$WANDB_PROJECT" \ - --wandb_run_name "$WANDB_RUN_NAME" \ - --logging_steps 1 \ - --eval_steps -1 \ - --rm_engine_tp $ENGINE_TP - -echo "Training complete!" - -################################################################################ -# Part 7: Usage Instructions # -################################################################################ + --engine_mem_util 0.6 \ + --engine_tp_size $ENGINE_TP \ + --l2 1.0e-2 \ + --freeze_prefix \ + --adam_offload \ + --text_only \ + --use_wandb "${WANDB_API_KEY}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${WANDB_RUN_NAME}" \ + 2>&1 | tee "${LOG_DIR}/node${NODE_RANK}_${current_time}.log" + +TRAINING_EXIT_CODE=${PIPESTATUS[0]} +set +x + +echo "" +echo "=========================================" +echo "Training Complete (exit: $TRAINING_EXIT_CODE)" +echo "Model: results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +echo "=========================================" -: <<'USAGE' -=============================================================================== -Usage Instructions -=============================================================================== - -1. Prerequisites: - - Install LightRFT and dependencies - - Install SGLang: pip install sglang - - Prepare your training dataset in JSONL format - -2. Configure the script: - - Set TEACHER_MODEL_PATH to your teacher model - - Set STUDENT_MODEL_PATH to your student model - - Set DATASET_PATH to your training data - - Adjust hyperparameters as needed - -3. Run the script: - bash examples/on_policy_distillation/run_opd_qwen.sh - -4. Monitor training: - - Check W&B dashboard for training metrics - - Logs are saved in rft_logs/${EXPERIMENT_NAME}/ - - Checkpoints are saved in results/${EXPERIMENT_NAME}/ - -5. Key Parameters: - - N_SAMPLES: Number of responses per prompt (higher = more stable but slower) - - LR: Learning rate (typically 1e-6 for distillation) - - KL: KL coefficient for regularization (keeps student close to initialization) - - EPISODE: Number of training episodes - -6. Expected Behavior: - - Student model should gradually match teacher's probability distribution - - Training loss should decrease over episodes - - Student responses should become more similar to teacher's style - -7. Troubleshooting: - - If teacher server fails: Check GPU memory and CUDA availability - - If training OOMs: Reduce batch sizes or enable gradient checkpointing - - If convergence is slow: Adjust learning rate or increase N_SAMPLES - -=============================================================================== -USAGE +exit $TRAINING_EXIT_CODE diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 5f6a352f..b67b021e 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -717,95 +717,144 @@ def compute( return advantages, returns, info_dict +def _apply_opd_kl_penalty( + student_log_probs: torch.Tensor, + teacher_log_probs: torch.Tensor, + action_mask: Optional[torch.Tensor], + opd_kl_coef: float, +) -> Tuple[torch.Tensor, Dict]: + """ + Compute OPD reverse KL penalty: -opd_kl_coef * (student_logp - teacher_logp). + + Shared helper for both pure and hybrid OPD modes. + + :return: Tuple of (opd_advantages, info_dict with opd_reverse_kl metric) + """ + reverse_kl = student_log_probs - teacher_log_probs + opd_adv = -opd_kl_coef * reverse_kl + + if action_mask is not None: + opd_adv = opd_adv * action_mask + + info_dict = {} + if action_mask is not None: + masked_rkl = reverse_kl * action_mask + info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / action_mask.sum(-1).clamp(min=1) + + return opd_adv, info_dict + + +def _whiten_advantages(advantages: torch.Tensor, action_mask: Optional[torch.Tensor], eps: float = 1e-8) -> torch.Tensor: + """ + Whiten advantages using masked mean/std (matching Slime's distributed_masked_whiten). + + This normalizes advantages to zero mean and unit variance, which stabilizes + training when OPD KL penalty has different scale from base advantages. + """ + if action_mask is not None: + mask = action_mask.bool() + masked_adv = torch.masked_select(advantages, mask) + else: + masked_adv = advantages.flatten() + + if masked_adv.numel() < 2: + return advantages + + mean = masked_adv.mean() + std = masked_adv.std() + return (advantages - mean) / (std + eps) + + class OnPolicyDistillationCalculator(AdvantageCalculator): """ - On-Policy Distillation calculator (GRPO + OPD KL penalty). + On-Policy Distillation calculator — pure distillation mode. + + Following Slime's design: + - Task rewards are zeroed out + - The ONLY learning signal is the OPD KL penalty: + advantages = -opd_kl_coef * (student_logp - teacher_logp) + - Advantage whitening is applied for training stability + + Use --advantage_estimator on_policy_distillation + """ + def __init__(self, config): + super().__init__(config) + self.opd_kl_coef = getattr(config, 'opd_kl_coef', 1.0) + + def preprocess_rewards(self, rewards, experiences, max_new_tokens): + """Zero out all rewards — pure distillation mode.""" + zero_rewards = torch.zeros_like(rewards) + return experiences, list(zero_rewards.chunk(len(experiences))) + + def compute(self, experience, final_reward, gamma, generate_kwargs): + """advantages = -opd_kl_coef * (student_logp - teacher_logp), then whiten.""" + if "teacher_log_probs" not in experience.info: + raise ValueError("teacher_log_probs not found in experience.info.") + + teacher_lp = experience.info["teacher_log_probs"].to(experience.action_log_probs.device) + student_lp = experience.action_log_probs - Combines GRPO (group normalization) as the base advantage estimator with - an on-policy distillation KL penalty from a teacher model. This is orthogonal - to the base advantage estimator, following the Slime framework design: + advantages, info_dict = _apply_opd_kl_penalty( + student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef + ) + + # Whiten advantages for training stability + advantages = _whiten_advantages(advantages, experience.action_mask) + + returns = advantages.clone() - advantages = base_advantages(GRPO) - opd_kl_coef * (student_logp - teacher_logp) + if self.config.advantage_clip > 0: + clip_val = self.config.advantage_clip + info_dict["advantage_clip_frac"] = compute_clip_fraction(advantages, clip_val, -clip_val) + advantages = torch.clamp(advantages, -clip_val, clip_val) + + return advantages, returns, info_dict - Reference: On-policy distillation from teacher models during RL training + +class OnPolicyDistillationHybridCalculator(AdvantageCalculator): + """ + Hybrid On-Policy Distillation calculator — GRPO task rewards + OPD KL penalty. + + Combines GRPO (group normalization) base advantages from task rewards with + OPD KL penalty from teacher model. Advantage whitening is applied AFTER + combining both signals to resolve scale mismatch. + + advantages = whiten(GRPO_base_advantages + OPD_KL_penalty) + + Use --advantage_estimator on_policy_distillation_hybrid """ def __init__(self, config): super().__init__(config) self.opd_kl_coef = getattr(config, 'opd_kl_coef', 1.0) - # Use GroupNormCalculator (GRPO) as the base advantage estimator self.base_calculator = GroupNormCalculator(config) - def preprocess_rewards( - self, - rewards: torch.Tensor, - experiences: List, - max_new_tokens: int, - ) -> Tuple[List, List[torch.Tensor]]: - """ - Delegate reward preprocessing to GRPO base calculator. - This applies group normalization (mean/std) to task rewards. - """ + def preprocess_rewards(self, rewards, experiences, max_new_tokens): + """Delegate to GRPO for task reward normalization.""" return self.base_calculator.preprocess_rewards(rewards, experiences, max_new_tokens) - def compute( - self, - experience, - final_reward: torch.Tensor, - gamma: Optional[float], - generate_kwargs: Dict, - ) -> Tuple[torch.Tensor, torch.Tensor, Dict]: - """ - Compute advantages = GRPO base advantages - opd_kl_coef * reverse_KL. - - Step 1: Compute GRPO base advantages from task rewards (e.g., GSM8K accuracy). - Step 2: Apply OPD KL penalty: advantages -= opd_kl_coef * (student_logp - teacher_logp). - This encourages the student to match teacher's distribution while still - optimizing for task performance. - - :param experience: Experience object containing teacher_log_probs in info dict - :type experience: object - :param final_reward: Processed reward tensor (from task rewards) - :type final_reward: torch.Tensor - :param gamma: Discount factor for GRPO base advantages - :type gamma: Optional[float] - :param generate_kwargs: Generation parameters - :type generate_kwargs: Dict - :return: Tuple of (advantages, returns, info_dict) - :rtype: Tuple[torch.Tensor, torch.Tensor, Dict] - """ - # Step 1: Compute GRPO base advantages from task rewards + def compute(self, experience, final_reward, gamma, generate_kwargs): + """advantages = whiten(GRPO_base + OPD_KL_penalty).""" + # Step 1: GRPO base advantages from task rewards base_advantages, returns, info_dict = self.base_calculator.compute( experience, final_reward, gamma, generate_kwargs ) - # Step 2: Apply OPD KL penalty (if teacher_log_probs available) + # Step 2: OPD KL penalty if "teacher_log_probs" not in experience.info: - raise ValueError( - "teacher_log_probs not found in experience.info. " - "Make sure to use the on_policy_distillation reward function " - "and that _fetch_teacher_logprobs() was called." - ) - - teacher_log_probs = experience.info["teacher_log_probs"].to(base_advantages.device) - student_log_probs = experience.action_log_probs + raise ValueError("teacher_log_probs not found in experience.info.") - # Compute reverse KL: student_logp - teacher_logp - # Penalty: when student diverges from teacher, reverse_kl > 0 - reverse_kl = student_log_probs - teacher_log_probs + teacher_lp = experience.info["teacher_log_probs"].to(base_advantages.device) + student_lp = experience.action_log_probs - # Apply OPD penalty to base advantages - advantages = base_advantages - self.opd_kl_coef * reverse_kl - - # Apply action mask - if experience.action_mask is not None: - advantages = advantages * experience.action_mask + opd_adv, opd_info = _apply_opd_kl_penalty( + student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef + ) + info_dict.update(opd_info) - # Store metrics for logging - if experience.action_mask is not None: - masked_rkl = reverse_kl * experience.action_mask - info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / experience.action_mask.sum(-1).clamp(min=1) + # Step 3: Combine and whiten to resolve scale mismatch + advantages = base_advantages + opd_adv + advantages = _whiten_advantages(advantages, experience.action_mask) - # Advantage clipping (skip advantages_norm since GRPO already normalized rewards) if self.config.advantage_clip > 0: clip_val = self.config.advantage_clip info_dict["advantage_clip_frac"] = compute_clip_fraction(advantages, clip_val, -clip_val) @@ -843,6 +892,7 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator "cpgd": CPGDCalculator, "grpo": GroupNormCalculator, # Alias for group_norm "on_policy_distillation": OnPolicyDistillationCalculator, + "on_policy_distillation_hybrid": OnPolicyDistillationHybridCalculator, } calculator_class = calculator_map.get(estimator_name) diff --git a/lightrft/trainer/experience_maker.py b/lightrft/trainer/experience_maker.py index cf4a2a37..c5f63eac 100644 --- a/lightrft/trainer/experience_maker.py +++ b/lightrft/trainer/experience_maker.py @@ -400,10 +400,9 @@ def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kw generate_kwargs["gamma"], ) experience.advantages = deepcopy(experience.returns) - elif self.advantage_estimator == "on_policy_distillation": - # OPD uses GRPO base advantages + OPD KL penalty - # Here compute GRPO-style cumulative returns from task rewards - # The OPD KL penalty is applied in OnPolicyDistillationCalculator + elif self.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + # OPD: cumulative returns from rewards (0 for pure, task rewards for hybrid) + # The OPD KL penalty is applied in the advantage calculator experience.returns = self.get_cumulative_returns( reward, experience.action_mask, @@ -586,7 +585,7 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper args = self.strategy.args # On-policy distillation: query teacher model for log probs, then use GRPO reward shaping - if args.advantage_estimator == "on_policy_distillation": + if args.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): if self.remote_rm_url is None or len(self.remote_rm_url) == 0: raise ValueError( "On-policy distillation requires a teacher model URL. " @@ -647,13 +646,20 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper logger.error(f"Failed to get teacher log probs: {e}") raise - # Use GRPO reward shaping (group normalization) on task rewards - rewards = torch.cat([experience.info["reward"] for experience in experiences]) - rewards = rewards.reshape(-1, args.n_samples_per_prompt) - baseline = rewards.mean(-1, keepdim=True) - rewards = (rewards - baseline) / (rewards.std(1, keepdim=True) + 1e-8) - rewards = rewards.flatten().chunk(len(experiences)) - return experiences, rewards + # Return rewards based on mode + if args.advantage_estimator == "on_policy_distillation": + # Pure distillation: zero rewards, learning signal from OPD KL only + zero_rewards = torch.zeros(sum(exp.sequences.size(0) for exp in experiences)) + rewards = zero_rewards.chunk(len(experiences)) + return experiences, list(rewards) + else: + # Hybrid: use task rewards with GRPO normalization + rewards = torch.cat([experience.info["reward"] for experience in experiences]) + rewards = rewards.reshape(-1, args.n_samples_per_prompt) + baseline = rewards.mean(-1, keepdim=True) + rewards = (rewards - baseline) / (rewards.std(-1, keepdim=True) + 1e-9) + rewards = rewards.flatten().chunk(len(experiences)) + return experiences, list(rewards) # Reward shaping for RLOO if args.advantage_estimator == "rloo": diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 840d0a5b..094bbbdc 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -943,7 +943,7 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg # For On-Policy Distillation (OPD), remote_rm_url is used for teacher model, # not for reward model. So we don't pass it to RewardComputationEngine. # Instead, we store it separately for _fetch_teacher_logprobs(). - if advantage_estimator == "on_policy_distillation": + if advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): # Store teacher URL separately for OPD self.teacher_model_url = self.remote_rm_url rm_url_for_reward_engine = None # Don't use remote_rm_url for rewards in OPD mode @@ -1075,7 +1075,7 @@ def make_experience_list( self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num) # ========== Stage 6.5: On-Policy Distillation Teacher Log-Probs ========== - if config.advantage_estimator == "on_policy_distillation": + if config.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): self._fetch_teacher_logprobs(experiences) # ========== Stage 7: Advantage Computation ========== From 0daefc60c2911965ca14065a7bcbf58a62720f5b Mon Sep 17 00:00:00 2001 From: root Date: Tue, 17 Mar 2026 21:35:40 +0800 Subject: [PATCH 06/18] fix(pu): auto-align batch sizes to (micro_batch * world_size) --- .../on_policy_distillation/run_opd_qwen.sh | 17 +- examples/on_policy_distillation/test_opd.py | 541 ++++++++++++------ 2 files changed, 380 insertions(+), 178 deletions(-) diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh index ceeebaa6..56116a9a 100644 --- a/examples/on_policy_distillation/run_opd_qwen.sh +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -75,9 +75,18 @@ N_SAMPLES=${N_SAMPLES:-8} EPISODE=${EPISODE:-30} WARMUP=${WARMUP:-0.03} OPD_KL_COEF=${OPD_KL_COEF:-1.0} +MICRO_TRAIN_BS=${MICRO_TRAIN_BS:-4} +MICRO_ROLLOUT_BS=${MICRO_ROLLOUT_BS:-4} -RBS=${RBS:-128} -TBS=${TBS:-128} +# Auto-align batch sizes to (micro_batch * world_size) +WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) +ALIGN=$((MICRO_TRAIN_BS * WORLD_SIZE)) +RBS=$(( (${RBS:-128} / ALIGN) * ALIGN )) +TBS=$(( (${TBS:-128} / ALIGN) * ALIGN )) +[ "$RBS" -lt "$ALIGN" ] && RBS=$ALIGN +[ "$TBS" -lt "$ALIGN" ] && TBS=$ALIGN + +echo "Batch sizes: TBS=${TBS}, RBS=${RBS} (aligned to micro=${MICRO_TRAIN_BS} * world=${WORLD_SIZE} = ${ALIGN})" if [ "$OPD_MODE" = "hybrid" ]; then ADVANTAGE_ESTIMATOR="on_policy_distillation_hybrid" @@ -215,9 +224,9 @@ torchrun \ --remote_rm_url "$TEACHER_URL" \ --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ - --micro_train_batch_size 4 \ + --micro_train_batch_size ${MICRO_TRAIN_BS} \ --train_batch_size ${TBS} \ - --micro_rollout_batch_size 4 \ + --micro_rollout_batch_size ${MICRO_ROLLOUT_BS} \ --rollout_batch_size ${RBS} \ --max_epochs 1 \ --num_episodes ${EPISODE} \ diff --git a/examples/on_policy_distillation/test_opd.py b/examples/on_policy_distillation/test_opd.py index e2476c48..6318e260 100644 --- a/examples/on_policy_distillation/test_opd.py +++ b/examples/on_policy_distillation/test_opd.py @@ -1,260 +1,453 @@ """ Test script for On-Policy Distillation implementation in LightRFT. -This script validates the core components of the on-policy distillation mechanism. +Tests both pure and hybrid OPD modes, advantage whitening, +teacher logprob extraction, and dimension alignment. """ import torch import sys from pathlib import Path -# Add LightRFT to path lightrft_path = Path(__file__).parent.parent.parent sys.path.insert(0, str(lightrft_path)) -def test_advantage_calculator(): - """Test OnPolicyDistillationCalculator.""" - print("Testing OnPolicyDistillationCalculator...") - from lightrft.trainer.advantage_calculator import get_advantage_calculator +class MockConfig: + """Reusable mock config for advantage calculators.""" + def __init__(self, **kwargs): + self.advantages_norm = False + self.advantage_clip = 0.0 + self.opd_kl_coef = 1.0 + self.n_samples_per_prompt = 4 + self.dynamic_sampling = False + self.micro_train_batch_size = 4 + for k, v in kwargs.items(): + setattr(self, k, v) + + +class MockExperience: + """Reusable mock experience object.""" + def __init__(self, batch_size=4, num_actions=10, teacher_offset=-0.5): + self.action_log_probs = torch.randn(batch_size, num_actions) * 0.5 - 1.0 + self.action_mask = torch.ones(batch_size, num_actions, dtype=torch.bool) + self.action_mask[:, -2:] = False # last 2 tokens are padding + self.info = { + "teacher_log_probs": self.action_log_probs + teacher_offset, + "reward": torch.rand(batch_size), + "response_length": torch.full((batch_size,), num_actions), + } + + +# ============================================================================ +# Test: Factory registration +# ============================================================================ + +def test_factory(): + """All estimators including both OPD modes are registered.""" + print("Test: Factory registration") - # Create mock config - class MockConfig: - advantages_norm = True - advantage_clip = 10.0 + from lightrft.trainer.advantage_calculator import get_advantage_calculator config = MockConfig() + estimators = [ + "gae", "reinforce", "rloo", "reinforce_baseline", + "group_norm", "grpo", "cpgd", + "on_policy_distillation", + "on_policy_distillation_hybrid", + ] - # Create calculator - calculator = get_advantage_calculator("on_policy_distillation", config) - print(f"✓ Created calculator: {calculator.__class__.__name__}") - - # Create mock experience - class MockExperience: - def __init__(self): - self.action_log_probs = torch.tensor([[0.1, 0.2, 0.3], [0.2, 0.3, 0.4]]) - self.action_mask = torch.tensor([[True, True, True], [True, True, False]]) - self.info = { - "teacher_log_probs": torch.tensor([[0.3, 0.4, 0.5], [0.4, 0.5, 0.6]]) - } - - experience = MockExperience() - final_reward = torch.zeros_like(experience.action_log_probs) - generate_kwargs = {} - - # Compute advantages - advantages, returns, info = calculator.compute( - experience, final_reward, gamma=1.0, generate_kwargs=generate_kwargs - ) + for name in estimators: + calc = get_advantage_calculator(name, config) + print(f" {name} -> {calc.__class__.__name__}") - print(f"✓ Computed advantages: {advantages.shape}") - print(f" Advantages sample: {advantages[0]}") - print(f" Expected positive (teacher > student): {(advantages > 0).float().mean():.2f}") + try: + get_advantage_calculator("nonexistent", config) + assert False, "Should raise ValueError" + except ValueError: + pass + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Pure OPD calculator +# ============================================================================ - # Test error case (missing teacher_log_probs) - experience_no_teacher = MockExperience() - del experience_no_teacher.info["teacher_log_probs"] +def test_pure_opd(): + """Pure distillation: rewards zeroed, advantages from KL only.""" + print("Test: Pure OPD calculator") + from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator + + config = MockConfig(opd_kl_coef=1.0) + calc = OnPolicyDistillationCalculator(config) + + # preprocess_rewards should zero out rewards + rewards = torch.tensor([0.5, 0.8, 0.3, 0.9, 0.1, 0.7, 0.2, 0.4]) + experiences = [MockExperience(batch_size=4), MockExperience(batch_size=4)] + exps, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) + for chunk in reward_chunks: + assert (chunk == 0).all(), f"Pure OPD should zero rewards, got {chunk}" + print(" preprocess_rewards zeros out rewards: OK") + + # compute should produce whitened advantages from KL + exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.5) + final_reward = torch.zeros(4, 10) + adv, ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + + assert adv.shape == (4, 10), f"Wrong shape: {adv.shape}" + assert (adv[:, -2:] == 0).all(), "Padding positions should be 0" + + # Whitened: mean should be ~0 + masked = adv[exp.action_mask] + assert abs(masked.mean()) < 0.1, f"Whitened mean should be ~0, got {masked.mean():.4f}" + print(f" advantages whitened (mean={masked.mean():.4f}, std={masked.std():.4f}): OK") + + # opd_reverse_kl metric should be present + assert "opd_reverse_kl" in info, "Missing opd_reverse_kl metric" + print(f" opd_reverse_kl metric present: OK") + + # Missing teacher_log_probs should raise + exp_bad = MockExperience() + del exp_bad.info["teacher_log_probs"] try: - calculator.compute(experience_no_teacher, final_reward, gamma=1.0, generate_kwargs=generate_kwargs) - print("✗ Should have raised ValueError for missing teacher_log_probs") - return False - except ValueError as e: - print(f"✓ Correctly raised ValueError: {str(e)[:50]}...") + calc.compute(exp_bad, final_reward, gamma=1.0, generate_kwargs={}) + assert False, "Should raise ValueError" + except ValueError: + pass + print(" missing teacher_log_probs raises ValueError: OK") + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Hybrid OPD calculator +# ============================================================================ + +def test_hybrid_opd(): + """Hybrid: GRPO base advantages + OPD KL penalty, then whitened.""" + print("Test: Hybrid OPD calculator") + + from lightrft.trainer.advantage_calculator import OnPolicyDistillationHybridCalculator + + config = MockConfig(opd_kl_coef=1.0, n_samples_per_prompt=4) + calc = OnPolicyDistillationHybridCalculator(config) + + # preprocess_rewards should apply GRPO normalization (not zero) + rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0]) + experiences = [MockExperience(batch_size=4), MockExperience(batch_size=4)] + exps, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) + combined = torch.cat(reward_chunks) + assert not (combined == 0).all(), "Hybrid should NOT zero rewards" + print(f" preprocess_rewards applies GRPO normalization: OK") + + # compute should combine GRPO + OPD and whiten + exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.3) + # Simulate GRPO-normalized reward broadcast to tokens + final_reward = torch.randn(4, 10) * 0.5 + adv, ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + + assert adv.shape == (4, 10), f"Wrong shape: {adv.shape}" + masked = adv[exp.action_mask] + assert abs(masked.mean()) < 0.1, f"Whitened mean should be ~0, got {masked.mean():.4f}" + print(f" advantages whitened (mean={masked.mean():.4f}, std={masked.std():.4f}): OK") + assert "opd_reverse_kl" in info + print(" PASS\n") + return True + + +# ============================================================================ +# Test: Advantage whitening +# ============================================================================ + +def test_whiten_advantages(): + """Whitening normalizes to ~zero mean, ~unit std.""" + print("Test: Advantage whitening") + + from lightrft.trainer.advantage_calculator import _whiten_advantages + + # Large-scale advantages (simulating raw OPD KL) + adv = torch.randn(4, 20) * 10 + 5 # mean=5, std=10 + mask = torch.ones(4, 20, dtype=torch.bool) + mask[:, -3:] = False + + whitened = _whiten_advantages(adv, mask) + + masked_vals = whitened[mask] + assert abs(masked_vals.mean()) < 0.01, f"Mean should be ~0, got {masked_vals.mean():.4f}" + assert abs(masked_vals.std() - 1.0) < 0.1, f"Std should be ~1, got {masked_vals.std():.4f}" + print(f" mean={masked_vals.mean():.4f}, std={masked_vals.std():.4f}: OK") + + # Edge case: very small batch + adv_small = torch.tensor([[1.0, 2.0]]) + mask_small = torch.ones(1, 2, dtype=torch.bool) + whitened_small = _whiten_advantages(adv_small, mask_small) + assert whitened_small.shape == (1, 2) + print(" small batch handled: OK") + + print(" PASS\n") + return True + + +# ============================================================================ +# Test: OPD KL penalty helper +# ============================================================================ + +def test_opd_kl_penalty(): + """_apply_opd_kl_penalty computes correct penalty direction.""" + print("Test: OPD KL penalty") + + from lightrft.trainer.advantage_calculator import _apply_opd_kl_penalty + + # Teacher is better (higher log probs) → student should get positive advantage + student_lp = torch.tensor([[-2.0, -3.0, -1.5]]) + teacher_lp = torch.tensor([[-1.0, -1.5, -0.5]]) + mask = torch.ones(1, 3, dtype=torch.bool) + + opd_adv, info = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=1.0) + + # reverse_kl = student - teacher = negative (student worse) + # opd_adv = -1.0 * reverse_kl = positive (encourage matching teacher) + assert (opd_adv > 0).all(), f"Should be positive when teacher > student, got {opd_adv}" + print(f" teacher > student → positive advantage: OK ({opd_adv.tolist()})") + + # Student is overconfident → negative advantage + student_lp2 = torch.tensor([[-0.5, -0.3, -0.2]]) + teacher_lp2 = torch.tensor([[-2.0, -2.5, -3.0]]) + opd_adv2, _ = _apply_opd_kl_penalty(student_lp2, teacher_lp2, mask, opd_kl_coef=1.0) + assert (opd_adv2 < 0).all(), f"Should be negative when student > teacher, got {opd_adv2}" + print(f" student > teacher → negative advantage: OK ({opd_adv2.tolist()})") + + # opd_kl_coef scales the penalty + opd_adv_scaled, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=0.5) + assert torch.allclose(opd_adv_scaled, opd_adv * 0.5) + print(" opd_kl_coef scaling: OK") - print("✓ OnPolicyDistillationCalculator tests passed\n") + # Mask is respected + mask_partial = torch.tensor([[True, True, False]]) + opd_adv_masked, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask_partial, opd_kl_coef=1.0) + assert opd_adv_masked[0, 2] == 0, "Masked position should be 0" + print(" action mask respected: OK") + + assert "opd_reverse_kl" in info + print(" PASS\n") return True -def test_reward_function(): - """Test teacher logprob extraction functions.""" - print("Testing teacher logprob functions...") +# ============================================================================ +# Test: Teacher logprob extraction (get_teacher_logprobs_by_ids mock) +# ============================================================================ + +def test_teacher_logprob_extraction(): + """extract_teacher_logprobs handles SGLang format correctly.""" + print("Test: Teacher logprob extraction") from examples.on_policy_distillation.on_policy_distillation_reward import ( extract_teacher_logprobs ) - # Test SGLang format - sglang_response = { + # SGLang format: [logprob, rank, token_str] tuples + response = { "meta_info": { "input_token_logprobs": [ - None, # First token has no logprob + None, # BOS token [-0.1, 1, "hello"], [-0.2, 2, "world"], [-0.15, 1, "!"], + [-0.3, 3, "."], + [-0.25, 2, "end"], ] } } - teacher_log_probs = extract_teacher_logprobs( - [sglang_response], - response_lengths=[3], - device="cpu" - ) + # Extract last 3 tokens as response + lp_list = extract_teacher_logprobs([response], response_lengths=[3], device="cpu") + assert len(lp_list) == 1 + assert len(lp_list[0]) == 3 + expected = torch.tensor([-0.3, -0.25, -0.15]) # Wait, let me check... + + # logprob_values = [-0.1, -0.2, -0.15, -0.3, -0.25] (skip None, take [0] from each) + # teacher_log_probs[-3:] = [-0.15, -0.3, -0.25] + # Hmm, actually the last 3 of [-0.1, -0.2, -0.15, -0.3, -0.25] = [-0.15, -0.3, -0.25] + expected = torch.tensor([-0.15, -0.3, -0.25]) + assert torch.allclose(lp_list[0], expected), f"Got {lp_list[0]}, expected {expected}" + print(f" SGLang format extraction: OK ({lp_list[0].tolist()})") + + # Test padding when response_length > available logprobs + lp_list2 = extract_teacher_logprobs([response], response_lengths=[10], device="cpu") + assert len(lp_list2[0]) == 10, f"Should pad to 10, got {len(lp_list2[0])}" + print(f" padding for short sequences: OK (len={len(lp_list2[0])})") + + print(" PASS\n") + return True - print(f"✓ Extracted teacher log probs (SGLang format): {teacher_log_probs[0]}") - assert len(teacher_log_probs[0]) == 3, "Should extract exactly 3 response tokens" - assert torch.allclose(teacher_log_probs[0], torch.tensor([-0.1, -0.2, -0.15])), "Values mismatch" - # Test vLLM format - vllm_response = { - "token_logprobs": [None, -0.1, -0.2, -0.15, -0.3] - } +# ============================================================================ +# Test: Dimension alignment (teacher_log_probs vs action_log_probs) +# ============================================================================ - teacher_log_probs = extract_teacher_logprobs( - [vllm_response], - response_lengths=[3], - device="cpu" - ) - - print(f"✓ Extracted teacher log probs (vLLM format): {teacher_log_probs[0]}") - assert len(teacher_log_probs[0]) == 3, "Should extract exactly 3 response tokens" +def test_dimension_alignment(): + """teacher_log_probs must match action_log_probs shape [batch, num_actions].""" + print("Test: Dimension alignment") - print("✓ Teacher logprob extraction tests passed\n") - return True + from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator + config = MockConfig(opd_kl_coef=1.0) + calc = OnPolicyDistillationCalculator(config) + + # Simulate different response lengths within a batch + # action_log_probs: [batch=2, num_actions=8] (padded to max response length) + # action_mask: first sample has 6 real tokens, second has 8 + batch_size, num_actions = 2, 8 + exp = MockExperience.__new__(MockExperience) + exp.action_log_probs = torch.randn(batch_size, num_actions) + exp.action_mask = torch.ones(batch_size, num_actions, dtype=torch.bool) + exp.action_mask[0, :2] = False # first 2 positions are prompt padding for sample 0 + exp.info = { + "teacher_log_probs": torch.randn(batch_size, num_actions), # same shape + } -def test_factory_function(): - """Test that on_policy_distillation is registered in factory.""" - print("Testing factory function registration...") + final_reward = torch.zeros(batch_size, num_actions) + adv, _, _ = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + assert adv.shape == (batch_size, num_actions) + assert adv[0, 0] == 0 and adv[0, 1] == 0, "Masked positions should be 0" + print(f" shape match [batch={batch_size}, num_actions={num_actions}]: OK") - from lightrft.trainer.advantage_calculator import get_advantage_calculator + # Mismatched shapes should fail + exp_bad = MockExperience.__new__(MockExperience) + exp_bad.action_log_probs = torch.randn(2, 8) + exp_bad.action_mask = torch.ones(2, 8, dtype=torch.bool) + exp_bad.info = {"teacher_log_probs": torch.randn(2, 12)} # wrong dim! - class MockConfig: - advantages_norm = False - advantage_clip = 0.0 + try: + calc.compute(exp_bad, final_reward, gamma=1.0, generate_kwargs={}) + assert False, "Should fail on shape mismatch" + except RuntimeError: + pass + print(" shape mismatch correctly raises RuntimeError: OK") - config = MockConfig() + print(" PASS\n") + return True - # Test valid estimators - estimators = [ - "gae", - "reinforce", - "rloo", - "reinforce_baseline", - "group_norm", - "cpgd", - "on_policy_distillation" - ] - for estimator in estimators: - try: - calc = get_advantage_calculator(estimator, config) - print(f"✓ Created {estimator}: {calc.__class__.__name__}") - except Exception as e: - print(f"✗ Failed to create {estimator}: {e}") - return False +# ============================================================================ +# Test: Pure vs Hybrid produce different results +# ============================================================================ - # Test invalid estimator - try: - get_advantage_calculator("invalid_estimator", config) - print("✗ Should have raised ValueError for invalid estimator") - return False - except ValueError: - print("✓ Correctly raised ValueError for invalid estimator") +def test_pure_vs_hybrid(): + """Pure and hybrid modes produce meaningfully different advantages.""" + print("Test: Pure vs Hybrid comparison") - print("✓ Factory function tests passed\n") - return True + from lightrft.trainer.advantage_calculator import ( + OnPolicyDistillationCalculator, + OnPolicyDistillationHybridCalculator, + ) + config = MockConfig(opd_kl_coef=1.0, n_samples_per_prompt=4) + pure_calc = OnPolicyDistillationCalculator(config) + hybrid_calc = OnPolicyDistillationHybridCalculator(config) -def test_integration(): - """Test basic integration flow.""" - print("Testing integration flow...") + torch.manual_seed(42) + exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.5) - # This is a simplified test to ensure components work together - from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator + # Pure: rewards are zeroed + rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.5, 0.5, 0.8, 0.2]) + _, pure_rewards = pure_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) + _, hybrid_rewards = hybrid_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) - class MockConfig: - advantages_norm = True - advantage_clip = 5.0 + pure_r = torch.cat(pure_rewards) + hybrid_r = torch.cat(hybrid_rewards) + assert (pure_r == 0).all(), "Pure should zero rewards" + assert not (hybrid_r == 0).all(), "Hybrid should keep rewards" + print(f" reward preprocessing differs: OK (pure=0, hybrid has signal)") - calculator = OnPolicyDistillationCalculator(MockConfig()) + # Both should produce valid advantages + final_reward_pure = torch.zeros(4, 10) + final_reward_hybrid = torch.randn(4, 10) * 0.5 - # Simulate a batch of experiences - batch_size = 4 - seq_len = 10 + adv_pure, _, _ = pure_calc.compute(exp, final_reward_pure, 1.0, {}) + adv_hybrid, _, _ = hybrid_calc.compute(exp, final_reward_hybrid, 1.0, {}) - class MockExperience: - def __init__(self): - # Student generated these log probs - self.action_log_probs = torch.randn(batch_size, seq_len) * 0.5 - 1.0 + assert adv_pure.shape == adv_hybrid.shape + # They should differ (different reward signals) + assert not torch.allclose(adv_pure, adv_hybrid, atol=0.01) + print(f" advantages differ between modes: OK") - # Teacher evaluated and got these log probs - # Teacher is better, so generally higher log probs - self.info = { - "teacher_log_probs": torch.randn(batch_size, seq_len) * 0.3 - 0.5 - } + print(" PASS\n") + return True - # Action mask (last 2 tokens are padding) - self.action_mask = torch.ones(batch_size, seq_len, dtype=torch.bool) - self.action_mask[:, -2:] = False - experience = MockExperience() - final_reward = torch.zeros_like(experience.action_log_probs) +# ============================================================================ +# Test: reward_func returns zeros +# ============================================================================ - advantages, returns, info = calculator.compute( - experience, final_reward, gamma=1.0, generate_kwargs={} - ) +def test_reward_func(): + """Placeholder reward_func returns zeros.""" + print("Test: reward_func placeholder") - print(f"✓ Computed advantages for batch: {advantages.shape}") - print(f" Mean advantage: {advantages.mean():.4f}") - print(f" Std advantage: {advantages.std():.4f}") - print(f" Masked correctly: {advantages[:, -2:].sum() == 0}") + from examples.on_policy_distillation.on_policy_distillation_reward import reward_func - # Check that advantages are normalized (approximately) - masked_adv = advantages[experience.action_mask] - assert abs(masked_adv.mean()) < 0.1, "Advantages should be normalized (mean ≈ 0)" - print(f"✓ Advantages are normalized (mean={masked_adv.mean():.4f})") + result = reward_func( + queries=["q1", "q2", "q3"], + prompts=["p1", "p2", "p3"], + ) + assert isinstance(result, torch.Tensor) + assert result.shape == (3,) + assert (result == 0).all() + print(" returns zeros: OK") - print("✓ Integration tests passed\n") + print(" PASS\n") return True +# ============================================================================ +# Runner +# ============================================================================ + def run_all_tests(): - """Run all tests.""" - print("=" * 70) - print("On-Policy Distillation Test Suite") - print("=" * 70) + print("=" * 60) + print("On-Policy Distillation Test Suite (v3)") + print("=" * 60) print() tests = [ - ("Factory Function", test_factory_function), - ("Advantage Calculator", test_advantage_calculator), - ("Reward Function", test_reward_function), - ("Integration", test_integration), + ("Factory registration", test_factory), + ("OPD KL penalty", test_opd_kl_penalty), + ("Advantage whitening", test_whiten_advantages), + ("Pure OPD calculator", test_pure_opd), + ("Hybrid OPD calculator", test_hybrid_opd), + ("Dimension alignment", test_dimension_alignment), + ("Pure vs Hybrid", test_pure_vs_hybrid), + ("Teacher logprob extraction", test_teacher_logprob_extraction), + ("Reward func placeholder", test_reward_func), ] results = [] - for name, test_fn in tests: + for name, fn in tests: try: - result = test_fn() - results.append((name, result)) + ok = fn() + results.append((name, ok)) except Exception as e: - print(f"✗ {name} failed with exception: {e}") + print(f" FAIL: {e}") import traceback traceback.print_exc() results.append((name, False)) + print() - print("=" * 70) - print("Test Summary") - print("=" * 70) - - all_passed = True - for name, passed in results: - status = "✓ PASS" if passed else "✗ FAIL" - print(f"{status}: {name}") - if not passed: - all_passed = False + print("=" * 60) + print("Summary") + print("=" * 60) + passed = sum(1 for _, ok in results if ok) + for name, ok in results: + print(f" {'PASS' if ok else 'FAIL'}: {name}") + print(f"\n{passed}/{len(results)} passed") + print("=" * 60) - print("=" * 70) - if all_passed: - print("✓ All tests passed!") - return 0 - else: - print("✗ Some tests failed") - return 1 + return 0 if passed == len(results) else 1 if __name__ == "__main__": - exit_code = run_all_tests() - sys.exit(exit_code) + sys.exit(run_all_tests()) From 5446e4b9dd0a1e1d27f215b88060c6315a71f3be Mon Sep 17 00:00:00 2001 From: root Date: Tue, 17 Mar 2026 21:59:38 +0800 Subject: [PATCH 07/18] fix(pu): use normalize_advantages_cross_batch for opd --- lightrft/trainer/advantage_calculator.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index b67b021e..0160de4b 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -787,7 +787,8 @@ def preprocess_rewards(self, rewards, experiences, max_new_tokens): return experiences, list(zero_rewards.chunk(len(experiences))) def compute(self, experience, final_reward, gamma, generate_kwargs): - """advantages = -opd_kl_coef * (student_logp - teacher_logp), then whiten.""" + """advantages = -opd_kl_coef * (student_logp - teacher_logp). + Whitening is done cross-batch in normalize_advantages_cross_batch.""" if "teacher_log_probs" not in experience.info: raise ValueError("teacher_log_probs not found in experience.info.") @@ -798,9 +799,6 @@ def compute(self, experience, final_reward, gamma, generate_kwargs): student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef ) - # Whiten advantages for training stability - advantages = _whiten_advantages(advantages, experience.action_mask) - returns = advantages.clone() if self.config.advantage_clip > 0: @@ -833,7 +831,8 @@ def preprocess_rewards(self, rewards, experiences, max_new_tokens): return self.base_calculator.preprocess_rewards(rewards, experiences, max_new_tokens) def compute(self, experience, final_reward, gamma, generate_kwargs): - """advantages = whiten(GRPO_base + OPD_KL_penalty).""" + """advantages = GRPO_base + OPD_KL_penalty. + Whitening is done cross-batch in normalize_advantages_cross_batch.""" # Step 1: GRPO base advantages from task rewards base_advantages, returns, info_dict = self.base_calculator.compute( experience, final_reward, gamma, generate_kwargs @@ -851,9 +850,8 @@ def compute(self, experience, final_reward, gamma, generate_kwargs): ) info_dict.update(opd_info) - # Step 3: Combine and whiten to resolve scale mismatch + # Step 3: Combine (whitening done cross-batch later) advantages = base_advantages + opd_adv - advantages = _whiten_advantages(advantages, experience.action_mask) if self.config.advantage_clip > 0: clip_val = self.config.advantage_clip @@ -928,7 +926,10 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str :return: List of Experience objects with normalized advantages. :rtype: List """ - if advantage_estimator not in ["gae", "reinforce", "reinforce_baseline"]: + if advantage_estimator not in [ + "gae", "reinforce", "reinforce_baseline", + "on_policy_distillation", "on_policy_distillation_hybrid", + ]: return experiences # Collect all advantages and action masks From fc667924b973c5c0affe8f2aaccfa2094804370d Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Thu, 19 Mar 2026 23:18:05 +0800 Subject: [PATCH 08/18] tmp --- examples/gsm8k_geo3k/reward_models_utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/gsm8k_geo3k/reward_models_utils.py b/examples/gsm8k_geo3k/reward_models_utils.py index 06047e86..cf33e7e4 100644 --- a/examples/gsm8k_geo3k/reward_models_utils.py +++ b/examples/gsm8k_geo3k/reward_models_utils.py @@ -369,6 +369,9 @@ def mix_rewards( # GSM8K pure rule-based reward (format + accuracy) acc_r = gsm8k_accuracy_reward_fn(sol_completion, gt) fmt_r = gsm8k_format_reward_fn(sol_completion) + # only for debug now + # acc_r = gsm8k_accuracy_reward_fn(sol, gt) + # fmt_r = gsm8k_format_reward_fn(sol) combined_r = (1.0 - 0.1) * acc_r + 0.1 * fmt_r r += w * combined_r From 6f7ab07b11b9001b80fe6f0860481458e1d4ec2e Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 20 Mar 2026 00:47:52 +0800 Subject: [PATCH 09/18] fix(pu): fix _map_weight_name_for_sglang bug in text-only model --- examples/gsm8k_geo3k/reward_models_utils.py | 3 -- lightrft/strategy/utils/broadcast_utils.py | 56 ++++++++++++--------- 2 files changed, 32 insertions(+), 27 deletions(-) diff --git a/examples/gsm8k_geo3k/reward_models_utils.py b/examples/gsm8k_geo3k/reward_models_utils.py index cf33e7e4..06047e86 100644 --- a/examples/gsm8k_geo3k/reward_models_utils.py +++ b/examples/gsm8k_geo3k/reward_models_utils.py @@ -369,9 +369,6 @@ def mix_rewards( # GSM8K pure rule-based reward (format + accuracy) acc_r = gsm8k_accuracy_reward_fn(sol_completion, gt) fmt_r = gsm8k_format_reward_fn(sol_completion) - # only for debug now - # acc_r = gsm8k_accuracy_reward_fn(sol, gt) - # fmt_r = gsm8k_format_reward_fn(sol) combined_r = (1.0 - 0.1) * acc_r + 0.1 * fmt_r r += w * combined_r diff --git a/lightrft/strategy/utils/broadcast_utils.py b/lightrft/strategy/utils/broadcast_utils.py index efd8fc0f..010b79e9 100755 --- a/lightrft/strategy/utils/broadcast_utils.py +++ b/lightrft/strategy/utils/broadcast_utils.py @@ -65,36 +65,44 @@ def _map_weight_name_for_sglang(self, name: str) -> str: :param name: Original weight name from training model :return: Mapped weight name for SGLang """ - # Step 0: Handle PEFT/LoRA and other potential wrapping prefixes + # Step 0: Handle PEFT/LoRA wrapping prefixes # PEFT models have weights like base_model.model. - # We recursively strip "base_model.model." or "model." prefixes until we find - # core components like "visual" or "language_model" - while name.startswith("base_model.model.") or name.startswith("model."): - if name.startswith("base_model.model."): - name = name[len("base_model.model."):] - elif name.startswith("model."): - # We strip "model." and let the following steps handle it. - # If "language_model" follows, it will be added back as "model." - # for SGLang's expectation. - name = name[len("model."):] + # Strip "base_model.model." prefix (possibly nested) to get the original name. + while name.startswith("base_model.model."): + name = name[len("base_model.model."):] # PEFT models also rename original weights to include ".base_layer." # we need to strip this to match standard weight names name = name.replace(".base_layer.", ".") - # Step 2: Handle language_model prefix mapping - if name.startswith("language_model."): - # Remove "language_model." prefix - name = name[15:] # Remove "language_model." - - # For lm_head, keep as is (no "model." prefix) - if name.startswith("lm_head"): - return name - - # For other components (embed_tokens, layers, norm), add "model." prefix - return f"model.{name}" - - # Step 3: Return as is for other cases (e.g., visual.xxx) + # Step 1: Handle VLM models wrapped by ActorVL + # ActorVL wraps the HF model as self.model, so parameter names get an extra "model." prefix: + # Training (ActorVL): model.visual.xxx, model.model.layers.xxx, model.lm_head.xxx + # SGLang expects: visual.xxx, model.layers.xxx, lm_head.xxx + # Also handle the "model.language_model." pattern (some VLM architectures): + # Training: model.language_model.model.layers.xxx + # SGLang expects: model.layers.xxx + if name.startswith("model.language_model."): + inner = name[len("model.language_model."):] + if inner.startswith("lm_head"): + return inner + return f"model.{inner}" + + if name.startswith("model.visual."): + return name[len("model."):] + + if name.startswith("model.lm_head"): + return name[len("model."):] + + # Handle VLM's double "model.model." prefix (ActorVL.model -> HF model.layers) + # model.model.layers.xxx -> model.layers.xxx + # model.model.embed_tokens.xxx -> model.embed_tokens.xxx + if name.startswith("model.model."): + return name[len("model."):] + + # Step 2: For text-only models (e.g., Qwen2.5-0.5B-Instruct), parameter names + # are already in SGLang's expected format: model.layers.xxx, model.embed_tokens.xxx, + # model.norm.xxx, lm_head.xxx. Return as-is without stripping "model." prefix. return name def _deepspeed_broadcast(self): From 3fb087d685ce1fb0962cee0e94fc7a42b74cfc2e Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 20 Mar 2026 15:24:29 +0800 Subject: [PATCH 10/18] fix(pu): polish norm_adv_cross_batch, fix teacher_logprob --- .../on_policy_distillation_reward.py | 12 ++-- lightrft/trainer/advantage_calculator.py | 69 +++++++++++++++---- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/examples/on_policy_distillation/on_policy_distillation_reward.py b/examples/on_policy_distillation/on_policy_distillation_reward.py index 7bc65503..3f75c789 100644 --- a/examples/on_policy_distillation/on_policy_distillation_reward.py +++ b/examples/on_policy_distillation/on_policy_distillation_reward.py @@ -82,9 +82,10 @@ async def query_single(input_ids: List[int], attempt: int = 0) -> Dict[str, Any] teacher_lp_list = [] for response, resp_len in zip(results, response_lengths): logprobs = response["meta_info"]["input_token_logprobs"] - # Extract logprob values; skip first token (no logprob for BOS) - lp_values = [item[0] if isinstance(item, list) else item for item in logprobs] - teacher_lp = torch.tensor(lp_values[1:], dtype=torch.float32) + # Align with Slime: first [1:] to skip BOS/None, then take item[0] for each + # SGLang returns list of [logprob, token_id] pairs, first element is None (BOS) + lp_values = [item[0] for item in logprobs[1:]] + teacher_lp = torch.tensor(lp_values, dtype=torch.float32) # Take the last resp_len tokens (response part only) teacher_lp = teacher_lp[-resp_len:] teacher_lp_list.append(teacher_lp) @@ -160,8 +161,9 @@ def extract_teacher_logprobs( for response, response_length in zip(teacher_responses, response_lengths): if "meta_info" in response and "input_token_logprobs" in response["meta_info"]: logprobs = response["meta_info"]["input_token_logprobs"] - logprob_values = [item[0] if isinstance(item, list) else item for item in logprobs] - teacher_log_probs = torch.tensor(logprob_values[1:], dtype=torch.float32) + # Align with Slime: first [1:] to skip BOS/None, then take item[0] + logprob_values = [item[0] for item in logprobs[1:]] + teacher_log_probs = torch.tensor(logprob_values, dtype=torch.float32) teacher_log_probs = teacher_log_probs[-response_length:] else: raise ValueError( diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 0160de4b..c67f2cbe 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -27,6 +27,7 @@ from typing import Dict, List, Optional, Tuple, Union import torch +import torch.distributed as dist import warnings from .utils import RunningMoments, compute_clip_fraction @@ -727,19 +728,39 @@ def _apply_opd_kl_penalty( Compute OPD reverse KL penalty: -opd_kl_coef * (student_logp - teacher_logp). Shared helper for both pure and hybrid OPD modes. + Aligned with Slime's apply_opd_kl_to_advantages. + + Safety guards: + - Masks out positions where teacher_log_probs == 0 (uninitialized/padded), + since teacher_logp=0 means P(token)=1, which is nonsensical and would + produce a large spurious KL penalty. + - Clamps per-token reverse KL to [-20, 20] to prevent extreme gradients. :return: Tuple of (opd_advantages, info_dict with opd_reverse_kl metric) """ reverse_kl = student_log_probs - teacher_log_probs + + # Clamp per-token reverse KL to prevent extreme values from teacher-student + # capability gap (e.g. 7B teacher vs 0.5B student) + reverse_kl = torch.clamp(reverse_kl, min=-20.0, max=20.0) + opd_adv = -opd_kl_coef * reverse_kl + # Mask out positions where teacher_log_probs == 0 (padded/uninitialized) + # teacher_logp=0 means exp(0)=1 probability, which is nonsensical + valid_teacher_mask = (teacher_log_probs != 0.0).float() + if action_mask is not None: - opd_adv = opd_adv * action_mask + effective_mask = action_mask * valid_teacher_mask + else: + effective_mask = valid_teacher_mask + + opd_adv = opd_adv * effective_mask info_dict = {} if action_mask is not None: - masked_rkl = reverse_kl * action_mask - info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / action_mask.sum(-1).clamp(min=1) + masked_rkl = reverse_kl * effective_mask + info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / effective_mask.sum(-1).clamp(min=1) return opd_adv, info_dict @@ -911,11 +932,10 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator @torch.no_grad() def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str, args) -> List: """ - Apply cross-batch advantage normalization for GAE, REINFORCE, and REINFORCE-baseline. + Apply cross-batch advantage normalization across all data-parallel ranks. - This method normalizes advantages across all experiences in a batch using their action masks. - Reference: https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/trainer/ppo_utils/ - experience_maker.py#L794-L816 + Matches Slime's distributed_masked_whiten: computes global mean/variance + via all_reduce across all ranks, then whitens locally. :param experiences: List of Experience objects. :type experiences: List @@ -941,21 +961,42 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str # Concatenate into vectors advantages_vector = torch.cat(all_advantages, dim=0).float() - action_masks_vector = torch.cat(all_action_masks, dim=0) - num_actions = action_masks_vector.sum() + action_masks_vector = torch.cat(all_action_masks, dim=0).float() + + # Compute local intermediate statistics + local_sum = (advantages_vector * action_masks_vector).sum() + local_sum_sq = ((advantages_vector ** 2) * action_masks_vector).sum() + local_count = action_masks_vector.sum() + + # Aggregate across all data-parallel ranks via all_reduce + # (matching Slime's distributed_masked_whiten) + stats = torch.stack([local_sum, local_sum_sq, local_count]).to( + device=advantages_vector.device, dtype=torch.float32 + ) + if dist.is_initialized(): + dist.all_reduce(stats, op=dist.ReduceOp.SUM) + + global_sum, global_sum_sq, global_count = stats + + if global_count.item() == 0: + return experiences + + global_mean = global_sum / global_count + global_var = global_sum_sq / global_count - global_mean ** 2 - # Compute mean - mean = (advantages_vector * action_masks_vector).sum() / num_actions + # Bessel's correction for unbiased variance estimate (matching Slime) + if global_count.item() >= 2: + bessel_correction = global_count / (global_count - 1) + global_var = global_var * bessel_correction # Compute std (if not disabled) if not getattr(args, "no_advantage_std_norm", False): - var = ((advantages_vector - mean).pow(2) * action_masks_vector).sum() / num_actions - rstd = var.clamp(min=1e-8).rsqrt() + rstd = global_var.clamp(min=1e-8).rsqrt() else: rstd = 1 # Apply normalization to each experience for exp in experiences: - exp.advantages = (exp.advantages - mean) * rstd + exp.advantages = (exp.advantages - global_mean) * rstd return experiences From 9aa39a99b4f96aee1abd355a42b07141074ba876 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 3 Apr 2026 15:23:05 +0800 Subject: [PATCH 11/18] fix(pu): fix pad-token bug in _fetch_teacher_logprobs, fix aligned_teacher_lp, not use adv_norm in opd pure mode --- lightrft/trainer/advantage_calculator.py | 26 ++++--------- lightrft/trainer/fast_exp_maker.py | 47 ++++++++++++------------ 2 files changed, 32 insertions(+), 41 deletions(-) diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index c67f2cbe..3b5ec6bf 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -730,12 +730,6 @@ def _apply_opd_kl_penalty( Shared helper for both pure and hybrid OPD modes. Aligned with Slime's apply_opd_kl_to_advantages. - Safety guards: - - Masks out positions where teacher_log_probs == 0 (uninitialized/padded), - since teacher_logp=0 means P(token)=1, which is nonsensical and would - produce a large spurious KL penalty. - - Clamps per-token reverse KL to [-20, 20] to prevent extreme gradients. - :return: Tuple of (opd_advantages, info_dict with opd_reverse_kl metric) """ reverse_kl = student_log_probs - teacher_log_probs @@ -746,21 +740,17 @@ def _apply_opd_kl_penalty( opd_adv = -opd_kl_coef * reverse_kl - # Mask out positions where teacher_log_probs == 0 (padded/uninitialized) - # teacher_logp=0 means exp(0)=1 probability, which is nonsensical - valid_teacher_mask = (teacher_log_probs != 0.0).float() - + # Rely solely on action_mask for padding filtering. + # Previous code also masked teacher_log_probs == 0.0, but log_prob=0 is a + # legitimate value (P=1). With the padding-strip fix in _fetch_teacher_logprobs, + # teacher logprobs are now correctly aligned and action_mask is sufficient. if action_mask is not None: - effective_mask = action_mask * valid_teacher_mask - else: - effective_mask = valid_teacher_mask - - opd_adv = opd_adv * effective_mask + opd_adv = opd_adv * action_mask info_dict = {} if action_mask is not None: - masked_rkl = reverse_kl * effective_mask - info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / effective_mask.sum(-1).clamp(min=1) + masked_rkl = reverse_kl * action_mask + info_dict["opd_reverse_kl"] = masked_rkl.sum(-1) / action_mask.sum(-1).clamp(min=1) return opd_adv, info_dict @@ -948,7 +938,7 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str """ if advantage_estimator not in [ "gae", "reinforce", "reinforce_baseline", - "on_policy_distillation", "on_policy_distillation_hybrid", + "on_policy_distillation_hybrid", ]: return experiences diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 094bbbdc..dc027ccc 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -1481,17 +1481,23 @@ def _fetch_teacher_logprobs( Timer.start(' fetch_teacher_logprobs') for exp in experiences: - sequences = exp.sequences # [batch_size, seq_len] - action_mask = exp.action_mask # [batch_size, num_actions] + sequences = exp.sequences # [batch_size, seq_len] + attention_mask = exp.attention_mask # [batch_size, seq_len] + action_mask = exp.action_mask # [batch_size, num_actions] + + # response_lengths must be int for slicing + response_lengths = action_mask.sum(dim=-1).int().tolist() + num_actions = action_mask.shape[1] + + # Strip padding tokens before sending to SGLang. + # sequences is [prompt, response, eos, pad, pad, ...] — the padding + # tokens would cause SGLang to return logprobs for pad positions, + # making the [-resp_len:] slice grab wrong tokens. + input_ids_list = [] + for j in range(sequences.shape[0]): + valid_len = int(attention_mask[j].sum().item()) + input_ids_list.append(sequences[j, :valid_len].cpu().tolist()) - # Get response lengths (number of generated tokens per sequence) - response_lengths = action_mask.sum(dim=-1).tolist() - num_actions = action_mask.shape[1] # action_log_probs dim - - # Use input_ids for teacher query to ensure token-level alignment - input_ids_list = sequences.cpu().tolist() - - # Query teacher model for log probs try: loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -1506,22 +1512,17 @@ def _fetch_teacher_logprobs( finally: loop.close() - # Align teacher log probs to action_log_probs shape [batch_size, num_actions] - # teacher_lp_list[i] has shape [resp_len_i], need to pad/align to [num_actions] + # Align teacher log probs to action_log_probs shape [batch_size, num_actions]. + # Use action_mask indices directly — works regardless of left/right padding. batch_size = sequences.shape[0] aligned_teacher_lp = torch.zeros(batch_size, num_actions, dtype=torch.float32) for i, (tlp, resp_len) in enumerate(zip(teacher_lp_list, response_lengths)): - # Right-align: teacher log probs fill the last resp_len positions - # (matching where action_mask == 1) - actual_len = min(len(tlp), resp_len, num_actions) - start_pos = num_actions - resp_len - if start_pos >= 0: - aligned_teacher_lp[i, start_pos:start_pos + actual_len] = tlp[:actual_len] - else: - # resp_len > num_actions (shouldn't happen, but handle gracefully) - aligned_teacher_lp[i, :] = tlp[-num_actions:] - - # Store in experience info for advantage calculator + if resp_len == 0: + continue + valid_indices = torch.where(action_mask[i] == 1)[0] + actual_len = min(len(tlp), len(valid_indices)) + aligned_teacher_lp[i, valid_indices[:actual_len]] = tlp[:actual_len] + exp.info["teacher_log_probs"] = aligned_teacher_lp except Exception as e: From 8637296a8b75e62dce7594800c691ca791357016 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Thu, 16 Apr 2026 20:11:34 +0800 Subject: [PATCH 12/18] polish(pu): polish num_actions, use_task_reward, teacher_model_url, test_opd.py and seperate start scripts --- examples/gsm8k_geo3k/reward_models_utils.py | 8 +- examples/gsm8k_geo3k/train_colocate.py | 6 +- examples/on_policy_distillation/README.md | 24 +- examples/on_policy_distillation/README_zh.md | 29 +- .../on_policy_distillation/run_opd_qwen.sh | 5 +- examples/on_policy_distillation/test_opd.py | 800 ++++++++---------- lightrft/strategy/config.py | 5 + lightrft/trainer/advantage_calculator.py | 40 +- lightrft/trainer/experience_maker.py | 47 +- lightrft/trainer/fast_exp_maker.py | 64 +- 10 files changed, 510 insertions(+), 518 deletions(-) diff --git a/examples/gsm8k_geo3k/reward_models_utils.py b/examples/gsm8k_geo3k/reward_models_utils.py index 06047e86..dcb7fd3b 100644 --- a/examples/gsm8k_geo3k/reward_models_utils.py +++ b/examples/gsm8k_geo3k/reward_models_utils.py @@ -295,6 +295,7 @@ def mix_rewards( label_map: Dict[str, int], solution_strs: Sequence[str], refs: Sequence[str], + use_task_reward: bool = True, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Mix rewards from multiple sources according to recipe configuration. @@ -383,6 +384,10 @@ def mix_rewards( final_reward[i] = r + # When use_task_reward=False, zero out final rewards but keep metrics for logging + if not use_task_reward: + final_reward.zero_() + return final_reward, metrics_dict @@ -392,6 +397,7 @@ def reward_fn( queries: Sequence[str], refs: Sequence[str], label_map: Dict[str, int], + use_task_reward: bool = True, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ External unified interface for computing final rewards. @@ -429,4 +435,4 @@ def reward_fn( model_scores = torch.zeros(0, B, dtype=torch.float32, device="cuda") # Call mix_rewards to compute final rewards - return mix_rewards(labels, model_scores, label_map, queries, refs) + return mix_rewards(labels, model_scores, label_map, queries, refs, use_task_reward=use_task_reward) diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 6b8b29b3..bcba57b3 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -56,6 +56,7 @@ sys.path.append(os.path.dirname(os.path.abspath(__file__))) from reward_models_utils import RECIPE, load_reward_models, reward_fn +import functools import torch.multiprocessing # Fix "multiprocessing.context.AuthenticationError: digest received was wrong" error @@ -411,7 +412,7 @@ def train(args): eos_token_id=tokenizer.eos_token_id, # reward model / teacher model URL (used for OPD) remote_rm_url=args.remote_rm_url, - reward_fn=reward_fn, + reward_fn=functools.partial(reward_fn, use_task_reward=args.use_task_reward), reward_fn_label_map=label_map, reward_recipe=RECIPE, reward_tokenizers=reward_tokenizers, @@ -568,6 +569,8 @@ def train(args): parser.add_argument("--use_kl_loss", action="store_true", default=False, help="whether to use KL loss from GRPO") parser.add_argument("--opd_kl_coef", type=float, default=1.0, help="KL coefficient for on-policy distillation penalty") + parser.add_argument("--use_task_reward", action="store_true", dest="use_task_reward", default=True, help="Use task reward in final reward (default)") + parser.add_argument("--no_task_reward", action="store_false", dest="use_task_reward", help="Zero out task reward (metrics still logged)") # LoRA parser.add_argument("--load_in_4bit", action="store_true", default=False) @@ -580,6 +583,7 @@ def train(args): parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") parser.add_argument("--reward_pretrain", type=str, default=None, help="HF model name or path") parser.add_argument("--remote_rm_url", type=str, default=None, help="remote RM API") + parser.add_argument("--teacher_model_url", type=str, default=None, help="Teacher model URL for OPD (overrides --remote_rm_url for teacher)") parser.add_argument("--critic_pretrain", type=str, default=None, help="HF model name or path") parser.add_argument("--value_head_prefix", type=str, default="score") diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md index af20882d..569be2d1 100644 --- a/examples/on_policy_distillation/README.md +++ b/examples/on_policy_distillation/README.md @@ -37,7 +37,8 @@ Or manually: torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ --pretrain "Qwen/Qwen2.5-0.5B-Instruct" \ --advantage_estimator "on_policy_distillation" \ - --remote_rm_url "http://127.0.0.1:13141/generate" \ + --teacher_model_url "http://127.0.0.1:13141/generate" \ + --no_task_reward \ --reward_pretrain "" \ --n_samples_per_prompt 4 \ --actor_learning_rate 1e-6 \ @@ -45,6 +46,18 @@ torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ --num_episodes 30 ``` +### Separate Deployment + +For multi-node setups or when teacher and training run on different machines: + +```bash +# Terminal 1: Start teacher server +TEACHER_GPU=7 bash examples/on_policy_distillation/start_teacher.sh + +# Terminal 2: Start training (after teacher is ready) +TEACHER_URL=http://127.0.0.1:13141/generate bash examples/on_policy_distillation/start_training.sh +``` + ## Architecture ``` @@ -97,7 +110,7 @@ class OnPolicyDistillationCalculator(AdvantageCalculator): **File**: `lightrft/trainer/fast_exp_maker.py` -- `--remote_rm_url` is used as teacher URL (not reward model) when `--advantage_estimator "on_policy_distillation"` +- `--teacher_model_url` specifies the teacher server when `--advantage_estimator "on_policy_distillation"` (falls back to `--remote_rm_url` with deprecation warning) - Teacher log probs stored in `experience.info["teacher_log_probs"]` - OPD metrics (`opd_reverse_kl_mean/std/min/max`) logged to wandb @@ -108,7 +121,7 @@ class OnPolicyDistillationCalculator(AdvantageCalculator): | Argument | Value | Description | |----------|-------|-------------| | `--advantage_estimator` | `"on_policy_distillation"` | Enable OPD mode | -| `--remote_rm_url` | `"http://host:port/generate"` | Teacher server URL | +| `--teacher_model_url` | `"http://host:port/generate"` | Teacher server URL | | `--reward_pretrain` | `""` | Empty (no reward model needed) | ### Recommended Hyperparameters @@ -221,7 +234,10 @@ nvidia-smi examples/on_policy_distillation/ ├── README.md # This file ├── README_zh.md # Chinese version -├── run_opd_qwen_2.sh # Training script +├── run_opd_qwen.sh # All-in-one training script +├── start_teacher.sh # Teacher server only +├── start_training.sh # Training only (requires TEACHER_URL) +├── test_opd.py # Unit tests └── on_policy_distillation_reward.py # Teacher logprob fetcher ``` diff --git a/examples/on_policy_distillation/README_zh.md b/examples/on_policy_distillation/README_zh.md index c6239123..a0800b2a 100644 --- a/examples/on_policy_distillation/README_zh.md +++ b/examples/on_policy_distillation/README_zh.md @@ -37,7 +37,8 @@ bash examples/on_policy_distillation/run_opd_qwen_2.sh torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ --pretrain "Qwen/Qwen2.5-0.5B-Instruct" \ --advantage_estimator "on_policy_distillation" \ - --remote_rm_url "http://127.0.0.1:13141/generate" \ + --teacher_model_url "http://127.0.0.1:13141/generate" \ + --no_task_reward \ --reward_pretrain "" \ --n_samples_per_prompt 4 \ --actor_learning_rate 1e-6 \ @@ -45,6 +46,18 @@ torchrun --nproc-per-node 2 examples/gsm8k_geo3k/train_colocate.py \ --num_episodes 30 ``` +### 分离部署 + +教师服务器和训练可以在不同终端或不同机器上运行: + +```bash +# 终端 1:启动教师服务器 +TEACHER_GPU=7 bash examples/on_policy_distillation/start_teacher.sh + +# 终端 2:启动训练(教师就绪后) +TEACHER_URL=http://127.0.0.1:13141/generate bash examples/on_policy_distillation/start_training.sh +``` + ## 架构 ``` @@ -91,13 +104,13 @@ class OnPolicyDistillationCalculator(AdvantageCalculator): - 异步 HTTP 请求到教师服务器 - 支持 SGLang 和 vLLM 响应格式 -- 自动重试(指数退避) +- 自动重试(指数退避):当教师服务器出现瞬态故障时,重试延迟按指数增长(1s → 2s → 4s → ...),避免在服务器暂时过载期间发送大量重试请求 ### 3. Experience Maker 集成 **文件**: `lightrft/trainer/fast_exp_maker.py` -- 当 `--advantage_estimator "on_policy_distillation"` 时,`--remote_rm_url` 作为教师 URL(而非奖励模型) +- 当 `--advantage_estimator "on_policy_distillation"` 时,`--teacher_model_url` 指定教师 URL(也可通过 `--remote_rm_url` 传递,但已弃用) - 教师对数概率存储在 `experience.info["teacher_log_probs"]` - OPD 指标(`opd_reverse_kl_mean/std/min/max`)记录到 wandb @@ -108,7 +121,7 @@ class OnPolicyDistillationCalculator(AdvantageCalculator): | 参数 | 值 | 描述 | |------|---|------| | `--advantage_estimator` | `"on_policy_distillation"` | 启用 OPD 模式 | -| `--remote_rm_url` | `"http://host:port/generate"` | 教师服务器 URL | +| `--teacher_model_url` | `"http://host:port/generate"` | 教师服务器 URL | | `--reward_pretrain` | `""` | 空值(不需要奖励模型) | ### 推荐超参数 @@ -204,7 +217,7 @@ nvidia-smi ### 优势 -- 无需单独训练奖励模型 +- 无需单独训练奖励模型:与 GRPO/PPO 等方法不同,OPD 不需要序列级的结果奖励模型(Outcome Reward Model)。教师模型本身充当 token 粒度的奖励信号——通过在每个 token 位置提供对数概率监督,教师直接指导学生在生成过程中的每一步决策,而非仅在完整序列结束后给出单一的好/坏评分。 - Token 级监督(比序列级更精细) - 在线策略:适应学生不断变化的分布 - 适用于任何有好教师模型的任务 @@ -221,7 +234,10 @@ nvidia-smi examples/on_policy_distillation/ ├── README.md # 英文文档 ├── README_zh.md # 本文件 -├── run_opd_qwen_2.sh # 训练脚本 +├── run_opd_qwen.sh # 一体化训练脚本 +├── start_teacher.sh # 仅启动教师服务器 +├── start_training.sh # 仅启动训练(需要 TEACHER_URL) +├── test_opd.py # 单元测试 └── on_policy_distillation_reward.py # 教师对数概率获取器 ``` @@ -230,3 +246,4 @@ examples/on_policy_distillation/ - [LightRFT 文档](../../README.md) - [优势值计算器源码](../../lightrft/trainer/advantage_calculator.py) - [Fast Experience Maker 源码](../../lightrft/trainer/fast_exp_maker.py) +- [On-Policy Distillation Blog](https://thinkingmachines.ai/blog/on-policy-distillation/) diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh index 56116a9a..d210acb1 100644 --- a/examples/on_policy_distillation/run_opd_qwen.sh +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -92,10 +92,12 @@ if [ "$OPD_MODE" = "hybrid" ]; then ADVANTAGE_ESTIMATOR="on_policy_distillation_hybrid" KL=${KL:-0.01} LR=${LR:-5e-7} + TASK_REWARD_FLAG="--use_task_reward" else ADVANTAGE_ESTIMATOR="on_policy_distillation" KL=${KL:-0.00} LR=${LR:-5e-7} + TASK_REWARD_FLAG="--no_task_reward" fi PROMPT_MAX_LEN=${PROMPT_MAX_LEN:-1024} @@ -221,7 +223,8 @@ torchrun \ --enable_engine_sleep \ --rm_use_engine \ --reward_pretrain "" \ - --remote_rm_url "$TEACHER_URL" \ + --teacher_model_url "$TEACHER_URL" \ + ${TASK_REWARD_FLAG} \ --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ --micro_train_batch_size ${MICRO_TRAIN_BS} \ diff --git a/examples/on_policy_distillation/test_opd.py b/examples/on_policy_distillation/test_opd.py index 6318e260..a5fad991 100644 --- a/examples/on_policy_distillation/test_opd.py +++ b/examples/on_policy_distillation/test_opd.py @@ -1,453 +1,401 @@ """ -Test script for On-Policy Distillation implementation in LightRFT. +Pytest suite for On-Policy Distillation implementation in LightRFT. -Tests both pure and hybrid OPD modes, advantage whitening, -teacher logprob extraction, and dimension alignment. +Tests both pure and hybrid OPD modes, KL penalty computation, +teacher logprob extraction, dimension alignment, and reward engine validation. """ +import pytest import torch -import sys -from pathlib import Path - -lightrft_path = Path(__file__).parent.parent.parent -sys.path.insert(0, str(lightrft_path)) - - -class MockConfig: - """Reusable mock config for advantage calculators.""" - def __init__(self, **kwargs): - self.advantages_norm = False - self.advantage_clip = 0.0 - self.opd_kl_coef = 1.0 - self.n_samples_per_prompt = 4 - self.dynamic_sampling = False - self.micro_train_batch_size = 4 - for k, v in kwargs.items(): - setattr(self, k, v) - - -class MockExperience: - """Reusable mock experience object.""" - def __init__(self, batch_size=4, num_actions=10, teacher_offset=-0.5): - self.action_log_probs = torch.randn(batch_size, num_actions) * 0.5 - 1.0 - self.action_mask = torch.ones(batch_size, num_actions, dtype=torch.bool) - self.action_mask[:, -2:] = False # last 2 tokens are padding - self.info = { - "teacher_log_probs": self.action_log_probs + teacher_offset, + +from lightrft.trainer.advantage_calculator import ( + OnPolicyDistillationCalculator, + OnPolicyDistillationHybridCalculator, + _apply_opd_kl_penalty, + get_advantage_calculator, + normalize_advantages_cross_batch, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture +def mock_config(): + """Factory fixture for mock config objects.""" + def _make(**kwargs): + defaults = dict( + advantages_norm=False, + advantage_clip=0.0, + opd_kl_coef=1.0, + n_samples_per_prompt=4, + dynamic_sampling=False, + micro_train_batch_size=4, + ) + defaults.update(kwargs) + + class _Cfg: + pass + + cfg = _Cfg() + for k, v in defaults.items(): + setattr(cfg, k, v) + return cfg + + return _make + + +@pytest.fixture +def mock_experience(): + """Factory fixture for mock experience objects (num_tokens naming).""" + def _make(batch_size=4, num_tokens=10, teacher_offset=-0.5): + class _Exp: + pass + + exp = _Exp() + exp.action_log_probs = torch.randn(batch_size, num_tokens) * 0.5 - 1.0 + exp.action_mask = torch.ones(batch_size, num_tokens, dtype=torch.bool) + exp.action_mask[:, -2:] = False # last 2 tokens are padding + exp.info = { + "teacher_log_probs": exp.action_log_probs + teacher_offset, "reward": torch.rand(batch_size), - "response_length": torch.full((batch_size,), num_actions), + "response_length": torch.full((batch_size,), num_tokens), } + return exp + return _make -# ============================================================================ -# Test: Factory registration -# ============================================================================ - -def test_factory(): - """All estimators including both OPD modes are registered.""" - print("Test: Factory registration") - - from lightrft.trainer.advantage_calculator import get_advantage_calculator - - config = MockConfig() - estimators = [ - "gae", "reinforce", "rloo", "reinforce_baseline", - "group_norm", "grpo", "cpgd", - "on_policy_distillation", - "on_policy_distillation_hybrid", - ] - - for name in estimators: - calc = get_advantage_calculator(name, config) - print(f" {name} -> {calc.__class__.__name__}") - - try: - get_advantage_calculator("nonexistent", config) - assert False, "Should raise ValueError" - except ValueError: - pass - - print(" PASS\n") - return True - -# ============================================================================ +# --------------------------------------------------------------------------- +# Test: Factory registration +# --------------------------------------------------------------------------- + +class TestFactory: + def test_all_estimators_registered(self, mock_config): + """All estimators including both OPD modes are registered.""" + config = mock_config() + estimators = [ + "gae", "reinforce", "rloo", "reinforce_baseline", + "group_norm", "grpo", "cpgd", + "on_policy_distillation", + "on_policy_distillation_hybrid", + ] + for name in estimators: + calc = get_advantage_calculator(name, config) + assert calc is not None + + def test_unknown_estimator_raises(self, mock_config): + """Unknown estimator name raises ValueError.""" + with pytest.raises(ValueError): + get_advantage_calculator("nonexistent", mock_config()) + + +# --------------------------------------------------------------------------- # Test: Pure OPD calculator -# ============================================================================ - -def test_pure_opd(): - """Pure distillation: rewards zeroed, advantages from KL only.""" - print("Test: Pure OPD calculator") - - from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator - - config = MockConfig(opd_kl_coef=1.0) - calc = OnPolicyDistillationCalculator(config) - - # preprocess_rewards should zero out rewards - rewards = torch.tensor([0.5, 0.8, 0.3, 0.9, 0.1, 0.7, 0.2, 0.4]) - experiences = [MockExperience(batch_size=4), MockExperience(batch_size=4)] - exps, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) - for chunk in reward_chunks: - assert (chunk == 0).all(), f"Pure OPD should zero rewards, got {chunk}" - print(" preprocess_rewards zeros out rewards: OK") - - # compute should produce whitened advantages from KL - exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.5) - final_reward = torch.zeros(4, 10) - adv, ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) - - assert adv.shape == (4, 10), f"Wrong shape: {adv.shape}" - assert (adv[:, -2:] == 0).all(), "Padding positions should be 0" - - # Whitened: mean should be ~0 - masked = adv[exp.action_mask] - assert abs(masked.mean()) < 0.1, f"Whitened mean should be ~0, got {masked.mean():.4f}" - print(f" advantages whitened (mean={masked.mean():.4f}, std={masked.std():.4f}): OK") - - # opd_reverse_kl metric should be present - assert "opd_reverse_kl" in info, "Missing opd_reverse_kl metric" - print(f" opd_reverse_kl metric present: OK") - - # Missing teacher_log_probs should raise - exp_bad = MockExperience() - del exp_bad.info["teacher_log_probs"] - try: - calc.compute(exp_bad, final_reward, gamma=1.0, generate_kwargs={}) - assert False, "Should raise ValueError" - except ValueError: - pass - print(" missing teacher_log_probs raises ValueError: OK") - - print(" PASS\n") - return True - - -# ============================================================================ +# --------------------------------------------------------------------------- + +class TestPureOPD: + def test_preprocess_rewards_passthrough(self, mock_config, mock_experience): + """Pure OPD preprocess_rewards passes through rewards (zeroing done upstream).""" + calc = OnPolicyDistillationCalculator(mock_config(opd_kl_coef=1.0)) + rewards = torch.tensor([0.5, 0.8, 0.3, 0.9, 0.1, 0.7, 0.2, 0.4]) + experiences = [mock_experience(batch_size=4), mock_experience(batch_size=4)] + _, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) + # Rewards are passed through; upstream --no_task_reward zeroes them + combined = torch.cat(reward_chunks) + assert combined.shape == rewards.shape + + def test_compute_advantages_shape_and_masking(self, mock_config, mock_experience): + """Advantages have correct shape and padding positions are zero.""" + calc = OnPolicyDistillationCalculator(mock_config(opd_kl_coef=1.0)) + exp = mock_experience(batch_size=4, num_tokens=10, teacher_offset=-0.5) + final_reward = torch.zeros(4, 10) + adv, _ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + + assert adv.shape == (4, 10) + assert (adv[:, -2:] == 0).all(), "Padding positions should be 0" + assert "opd_reverse_kl" in info + + def test_missing_teacher_logprobs_raises(self, mock_config, mock_experience): + """Missing teacher_log_probs in experience.info raises ValueError.""" + calc = OnPolicyDistillationCalculator(mock_config(opd_kl_coef=1.0)) + exp = mock_experience() + del exp.info["teacher_log_probs"] + final_reward = torch.zeros(4, 10) + with pytest.raises(ValueError): + calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + + +# --------------------------------------------------------------------------- # Test: Hybrid OPD calculator -# ============================================================================ - -def test_hybrid_opd(): - """Hybrid: GRPO base advantages + OPD KL penalty, then whitened.""" - print("Test: Hybrid OPD calculator") - - from lightrft.trainer.advantage_calculator import OnPolicyDistillationHybridCalculator - - config = MockConfig(opd_kl_coef=1.0, n_samples_per_prompt=4) - calc = OnPolicyDistillationHybridCalculator(config) +# --------------------------------------------------------------------------- + +class TestHybridOPD: + def test_preprocess_rewards_grpo_normalization(self, mock_config, mock_experience): + """Hybrid mode applies GRPO normalization (rewards not zeroed).""" + calc = OnPolicyDistillationHybridCalculator( + mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) + ) + rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0]) + experiences = [mock_experience(batch_size=4), mock_experience(batch_size=4)] + _, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) + combined = torch.cat(reward_chunks) + assert not (combined == 0).all(), "Hybrid should NOT zero rewards" + + def test_compute_advantages(self, mock_config, mock_experience): + """Hybrid compute produces correct shape and includes KL metric.""" + calc = OnPolicyDistillationHybridCalculator( + mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) + ) + exp = mock_experience(batch_size=4, num_tokens=10, teacher_offset=-0.3) + final_reward = torch.randn(4, 10) * 0.5 + adv, _ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + + assert adv.shape == (4, 10) + assert "opd_reverse_kl" in info + + +# --------------------------------------------------------------------------- +# Test: OPD KL penalty helper +# --------------------------------------------------------------------------- + +class TestOPDKLPenalty: + def test_teacher_better_positive_advantage(self): + """When teacher > student (higher logprobs), advantage should be positive.""" + student_lp = torch.tensor([[-2.0, -3.0, -1.5]]) + teacher_lp = torch.tensor([[-1.0, -1.5, -0.5]]) + mask = torch.ones(1, 3, dtype=torch.bool) + + opd_adv, info = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=1.0) + assert (opd_adv > 0).all() + assert "opd_reverse_kl" in info + + def test_student_overconfident_negative_advantage(self): + """When student > teacher, advantage should be negative.""" + student_lp = torch.tensor([[-0.5, -0.3, -0.2]]) + teacher_lp = torch.tensor([[-2.0, -2.5, -3.0]]) + mask = torch.ones(1, 3, dtype=torch.bool) + + opd_adv, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=1.0) + assert (opd_adv < 0).all() + + def test_coef_scaling(self): + """opd_kl_coef scales the penalty linearly.""" + student_lp = torch.tensor([[-2.0, -3.0, -1.5]]) + teacher_lp = torch.tensor([[-1.0, -1.5, -0.5]]) + mask = torch.ones(1, 3, dtype=torch.bool) + + adv_1, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=1.0) + adv_half, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=0.5) + assert torch.allclose(adv_half, adv_1 * 0.5) + + def test_mask_respected(self): + """Masked positions should have zero advantage.""" + student_lp = torch.tensor([[-2.0, -3.0, -1.5]]) + teacher_lp = torch.tensor([[-1.0, -1.5, -0.5]]) + mask = torch.tensor([[True, True, False]]) + + opd_adv, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=1.0) + assert opd_adv[0, 2] == 0, "Masked position should be 0" + + +# --------------------------------------------------------------------------- +# Test: Advantage normalization (cross-batch) +# --------------------------------------------------------------------------- + +class TestNormalizeAdvantages: + def test_normalize_advantages_cross_batch_shape(self, mock_experience): + """normalize_advantages_cross_batch preserves experience structure.""" + exp1 = mock_experience(batch_size=4, num_tokens=10) + exp2 = mock_experience(batch_size=4, num_tokens=10) + # Add advantages attribute (normally set by compute) + exp1.advantages = torch.randn(4, 10) * 5 + 2 + exp2.advantages = torch.randn(4, 10) * 5 + 2 + + class _Args: + pass + + args = _Args() + # on_policy_distillation_hybrid triggers normalization + result = normalize_advantages_cross_batch( + [exp1, exp2], "on_policy_distillation_hybrid", args + ) + assert len(result) == 2 + assert result[0].advantages.shape == (4, 10) + + def test_pure_opd_skips_normalization(self, mock_experience): + """Pure OPD mode skips cross-batch normalization.""" + exp = mock_experience(batch_size=4, num_tokens=10) + exp.advantages = torch.randn(4, 10) * 5 + 2 + original = exp.advantages.clone() + + class _Args: + pass + + result = normalize_advantages_cross_batch( + [exp], "on_policy_distillation", _Args() + ) + # Should return unchanged (not in whitening list) + assert torch.equal(result[0].advantages, original) + + +# --------------------------------------------------------------------------- +# Test: Teacher logprob extraction +# --------------------------------------------------------------------------- + +class TestTeacherLogprobExtraction: + def test_sglang_format(self): + """extract_teacher_logprobs handles SGLang format correctly.""" + from examples.on_policy_distillation.on_policy_distillation_reward import ( + extract_teacher_logprobs, + ) + + response = { + "meta_info": { + "input_token_logprobs": [ + None, # BOS token + [-0.1, 1, "hello"], + [-0.2, 2, "world"], + [-0.15, 1, "!"], + [-0.3, 3, "."], + [-0.25, 2, "end"], + ] + } + } - # preprocess_rewards should apply GRPO normalization (not zero) - rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0]) - experiences = [MockExperience(batch_size=4), MockExperience(batch_size=4)] - exps, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) - combined = torch.cat(reward_chunks) - assert not (combined == 0).all(), "Hybrid should NOT zero rewards" - print(f" preprocess_rewards applies GRPO normalization: OK") + lp_list = extract_teacher_logprobs([response], response_lengths=[3], device="cpu") + assert len(lp_list) == 1 + assert len(lp_list[0]) == 3 + # logprob_values = [-0.1, -0.2, -0.15, -0.3, -0.25]; last 3 = [-0.15, -0.3, -0.25] + expected = torch.tensor([-0.15, -0.3, -0.25]) + assert torch.allclose(lp_list[0], expected) + + def test_padding_for_long_response(self): + """Pads to response_length when requested length > available logprobs.""" + from examples.on_policy_distillation.on_policy_distillation_reward import ( + extract_teacher_logprobs, + ) + + response = { + "meta_info": { + "input_token_logprobs": [ + None, + [-0.1, 1, "a"], + [-0.2, 2, "b"], + ] + } + } - # compute should combine GRPO + OPD and whiten - exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.3) - # Simulate GRPO-normalized reward broadcast to tokens - final_reward = torch.randn(4, 10) * 0.5 - adv, ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + lp_list = extract_teacher_logprobs([response], response_lengths=[10], device="cpu") + assert len(lp_list[0]) == 10 - assert adv.shape == (4, 10), f"Wrong shape: {adv.shape}" - masked = adv[exp.action_mask] - assert abs(masked.mean()) < 0.1, f"Whitened mean should be ~0, got {masked.mean():.4f}" - print(f" advantages whitened (mean={masked.mean():.4f}, std={masked.std():.4f}): OK") - assert "opd_reverse_kl" in info - print(" PASS\n") - return True +# --------------------------------------------------------------------------- +# Test: Dimension alignment +# --------------------------------------------------------------------------- -# ============================================================================ -# Test: Advantage whitening -# ============================================================================ +class TestDimensionAlignment: + def test_matching_shapes(self, mock_config): + """teacher_log_probs matching action_log_probs shape [batch, num_tokens] works.""" + calc = OnPolicyDistillationCalculator(mock_config(opd_kl_coef=1.0)) + batch_size, num_tokens = 2, 8 -def test_whiten_advantages(): - """Whitening normalizes to ~zero mean, ~unit std.""" - print("Test: Advantage whitening") + class _Exp: + pass - from lightrft.trainer.advantage_calculator import _whiten_advantages + exp = _Exp() + exp.action_log_probs = torch.randn(batch_size, num_tokens) + exp.action_mask = torch.ones(batch_size, num_tokens, dtype=torch.bool) + exp.action_mask[0, :2] = False # prompt padding for sample 0 + exp.info = {"teacher_log_probs": torch.randn(batch_size, num_tokens)} - # Large-scale advantages (simulating raw OPD KL) - adv = torch.randn(4, 20) * 10 + 5 # mean=5, std=10 - mask = torch.ones(4, 20, dtype=torch.bool) - mask[:, -3:] = False + final_reward = torch.zeros(batch_size, num_tokens) + adv, _, _ = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) + assert adv.shape == (batch_size, num_tokens) + assert adv[0, 0] == 0 and adv[0, 1] == 0, "Masked positions should be 0" - whitened = _whiten_advantages(adv, mask) + def test_mismatched_shapes_raises(self, mock_config): + """Mismatched teacher_log_probs shape raises RuntimeError.""" + calc = OnPolicyDistillationCalculator(mock_config(opd_kl_coef=1.0)) - masked_vals = whitened[mask] - assert abs(masked_vals.mean()) < 0.01, f"Mean should be ~0, got {masked_vals.mean():.4f}" - assert abs(masked_vals.std() - 1.0) < 0.1, f"Std should be ~1, got {masked_vals.std():.4f}" - print(f" mean={masked_vals.mean():.4f}, std={masked_vals.std():.4f}: OK") + class _Exp: + pass - # Edge case: very small batch - adv_small = torch.tensor([[1.0, 2.0]]) - mask_small = torch.ones(1, 2, dtype=torch.bool) - whitened_small = _whiten_advantages(adv_small, mask_small) - assert whitened_small.shape == (1, 2) - print(" small batch handled: OK") + exp = _Exp() + exp.action_log_probs = torch.randn(2, 8) + exp.action_mask = torch.ones(2, 8, dtype=torch.bool) + exp.info = {"teacher_log_probs": torch.randn(2, 12)} # wrong dim - print(" PASS\n") - return True + final_reward = torch.zeros(2, 8) + with pytest.raises(RuntimeError): + calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) -# ============================================================================ -# Test: OPD KL penalty helper -# ============================================================================ - -def test_opd_kl_penalty(): - """_apply_opd_kl_penalty computes correct penalty direction.""" - print("Test: OPD KL penalty") - - from lightrft.trainer.advantage_calculator import _apply_opd_kl_penalty - - # Teacher is better (higher log probs) → student should get positive advantage - student_lp = torch.tensor([[-2.0, -3.0, -1.5]]) - teacher_lp = torch.tensor([[-1.0, -1.5, -0.5]]) - mask = torch.ones(1, 3, dtype=torch.bool) - - opd_adv, info = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=1.0) - - # reverse_kl = student - teacher = negative (student worse) - # opd_adv = -1.0 * reverse_kl = positive (encourage matching teacher) - assert (opd_adv > 0).all(), f"Should be positive when teacher > student, got {opd_adv}" - print(f" teacher > student → positive advantage: OK ({opd_adv.tolist()})") - - # Student is overconfident → negative advantage - student_lp2 = torch.tensor([[-0.5, -0.3, -0.2]]) - teacher_lp2 = torch.tensor([[-2.0, -2.5, -3.0]]) - opd_adv2, _ = _apply_opd_kl_penalty(student_lp2, teacher_lp2, mask, opd_kl_coef=1.0) - assert (opd_adv2 < 0).all(), f"Should be negative when student > teacher, got {opd_adv2}" - print(f" student > teacher → negative advantage: OK ({opd_adv2.tolist()})") - - # opd_kl_coef scales the penalty - opd_adv_scaled, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask, opd_kl_coef=0.5) - assert torch.allclose(opd_adv_scaled, opd_adv * 0.5) - print(" opd_kl_coef scaling: OK") - - # Mask is respected - mask_partial = torch.tensor([[True, True, False]]) - opd_adv_masked, _ = _apply_opd_kl_penalty(student_lp, teacher_lp, mask_partial, opd_kl_coef=1.0) - assert opd_adv_masked[0, 2] == 0, "Masked position should be 0" - print(" action mask respected: OK") - - assert "opd_reverse_kl" in info - print(" PASS\n") - return True - - -# ============================================================================ -# Test: Teacher logprob extraction (get_teacher_logprobs_by_ids mock) -# ============================================================================ - -def test_teacher_logprob_extraction(): - """extract_teacher_logprobs handles SGLang format correctly.""" - print("Test: Teacher logprob extraction") - - from examples.on_policy_distillation.on_policy_distillation_reward import ( - extract_teacher_logprobs - ) - - # SGLang format: [logprob, rank, token_str] tuples - response = { - "meta_info": { - "input_token_logprobs": [ - None, # BOS token - [-0.1, 1, "hello"], - [-0.2, 2, "world"], - [-0.15, 1, "!"], - [-0.3, 3, "."], - [-0.25, 2, "end"], - ] - } - } - - # Extract last 3 tokens as response - lp_list = extract_teacher_logprobs([response], response_lengths=[3], device="cpu") - assert len(lp_list) == 1 - assert len(lp_list[0]) == 3 - expected = torch.tensor([-0.3, -0.25, -0.15]) # Wait, let me check... - - # logprob_values = [-0.1, -0.2, -0.15, -0.3, -0.25] (skip None, take [0] from each) - # teacher_log_probs[-3:] = [-0.15, -0.3, -0.25] - # Hmm, actually the last 3 of [-0.1, -0.2, -0.15, -0.3, -0.25] = [-0.15, -0.3, -0.25] - expected = torch.tensor([-0.15, -0.3, -0.25]) - assert torch.allclose(lp_list[0], expected), f"Got {lp_list[0]}, expected {expected}" - print(f" SGLang format extraction: OK ({lp_list[0].tolist()})") - - # Test padding when response_length > available logprobs - lp_list2 = extract_teacher_logprobs([response], response_lengths=[10], device="cpu") - assert len(lp_list2[0]) == 10, f"Should pad to 10, got {len(lp_list2[0])}" - print(f" padding for short sequences: OK (len={len(lp_list2[0])})") - - print(" PASS\n") - return True - - -# ============================================================================ -# Test: Dimension alignment (teacher_log_probs vs action_log_probs) -# ============================================================================ - -def test_dimension_alignment(): - """teacher_log_probs must match action_log_probs shape [batch, num_actions].""" - print("Test: Dimension alignment") - - from lightrft.trainer.advantage_calculator import OnPolicyDistillationCalculator - - config = MockConfig(opd_kl_coef=1.0) - calc = OnPolicyDistillationCalculator(config) - - # Simulate different response lengths within a batch - # action_log_probs: [batch=2, num_actions=8] (padded to max response length) - # action_mask: first sample has 6 real tokens, second has 8 - batch_size, num_actions = 2, 8 - exp = MockExperience.__new__(MockExperience) - exp.action_log_probs = torch.randn(batch_size, num_actions) - exp.action_mask = torch.ones(batch_size, num_actions, dtype=torch.bool) - exp.action_mask[0, :2] = False # first 2 positions are prompt padding for sample 0 - exp.info = { - "teacher_log_probs": torch.randn(batch_size, num_actions), # same shape - } - - final_reward = torch.zeros(batch_size, num_actions) - adv, _, _ = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) - assert adv.shape == (batch_size, num_actions) - assert adv[0, 0] == 0 and adv[0, 1] == 0, "Masked positions should be 0" - print(f" shape match [batch={batch_size}, num_actions={num_actions}]: OK") - - # Mismatched shapes should fail - exp_bad = MockExperience.__new__(MockExperience) - exp_bad.action_log_probs = torch.randn(2, 8) - exp_bad.action_mask = torch.ones(2, 8, dtype=torch.bool) - exp_bad.info = {"teacher_log_probs": torch.randn(2, 12)} # wrong dim! - - try: - calc.compute(exp_bad, final_reward, gamma=1.0, generate_kwargs={}) - assert False, "Should fail on shape mismatch" - except RuntimeError: - pass - print(" shape mismatch correctly raises RuntimeError: OK") - - print(" PASS\n") - return True - - -# ============================================================================ +# --------------------------------------------------------------------------- # Test: Pure vs Hybrid produce different results -# ============================================================================ - -def test_pure_vs_hybrid(): - """Pure and hybrid modes produce meaningfully different advantages.""" - print("Test: Pure vs Hybrid comparison") - - from lightrft.trainer.advantage_calculator import ( - OnPolicyDistillationCalculator, - OnPolicyDistillationHybridCalculator, - ) - - config = MockConfig(opd_kl_coef=1.0, n_samples_per_prompt=4) - pure_calc = OnPolicyDistillationCalculator(config) - hybrid_calc = OnPolicyDistillationHybridCalculator(config) - - torch.manual_seed(42) - exp = MockExperience(batch_size=4, num_actions=10, teacher_offset=-0.5) - - # Pure: rewards are zeroed - rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.5, 0.5, 0.8, 0.2]) - _, pure_rewards = pure_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) - _, hybrid_rewards = hybrid_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) - - pure_r = torch.cat(pure_rewards) - hybrid_r = torch.cat(hybrid_rewards) - assert (pure_r == 0).all(), "Pure should zero rewards" - assert not (hybrid_r == 0).all(), "Hybrid should keep rewards" - print(f" reward preprocessing differs: OK (pure=0, hybrid has signal)") - - # Both should produce valid advantages - final_reward_pure = torch.zeros(4, 10) - final_reward_hybrid = torch.randn(4, 10) * 0.5 - - adv_pure, _, _ = pure_calc.compute(exp, final_reward_pure, 1.0, {}) - adv_hybrid, _, _ = hybrid_calc.compute(exp, final_reward_hybrid, 1.0, {}) - - assert adv_pure.shape == adv_hybrid.shape - # They should differ (different reward signals) - assert not torch.allclose(adv_pure, adv_hybrid, atol=0.01) - print(f" advantages differ between modes: OK") - - print(" PASS\n") - return True - - -# ============================================================================ -# Test: reward_func returns zeros -# ============================================================================ - -def test_reward_func(): - """Placeholder reward_func returns zeros.""" - print("Test: reward_func placeholder") - - from examples.on_policy_distillation.on_policy_distillation_reward import reward_func - - result = reward_func( - queries=["q1", "q2", "q3"], - prompts=["p1", "p2", "p3"], - ) - assert isinstance(result, torch.Tensor) - assert result.shape == (3,) - assert (result == 0).all() - print(" returns zeros: OK") - - print(" PASS\n") - return True - - -# ============================================================================ -# Runner -# ============================================================================ - -def run_all_tests(): - print("=" * 60) - print("On-Policy Distillation Test Suite (v3)") - print("=" * 60) - print() - - tests = [ - ("Factory registration", test_factory), - ("OPD KL penalty", test_opd_kl_penalty), - ("Advantage whitening", test_whiten_advantages), - ("Pure OPD calculator", test_pure_opd), - ("Hybrid OPD calculator", test_hybrid_opd), - ("Dimension alignment", test_dimension_alignment), - ("Pure vs Hybrid", test_pure_vs_hybrid), - ("Teacher logprob extraction", test_teacher_logprob_extraction), - ("Reward func placeholder", test_reward_func), - ] - - results = [] - for name, fn in tests: - try: - ok = fn() - results.append((name, ok)) - except Exception as e: - print(f" FAIL: {e}") - import traceback - traceback.print_exc() - results.append((name, False)) - print() - - print("=" * 60) - print("Summary") - print("=" * 60) - passed = sum(1 for _, ok in results if ok) - for name, ok in results: - print(f" {'PASS' if ok else 'FAIL'}: {name}") - print(f"\n{passed}/{len(results)} passed") - print("=" * 60) - - return 0 if passed == len(results) else 1 - - -if __name__ == "__main__": - sys.exit(run_all_tests()) +# --------------------------------------------------------------------------- + +class TestPureVsHybrid: + def test_advantages_differ(self, mock_config, mock_experience): + """Pure and hybrid modes produce meaningfully different advantages.""" + config = mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) + pure_calc = OnPolicyDistillationCalculator(config) + hybrid_calc = OnPolicyDistillationHybridCalculator(config) + + torch.manual_seed(42) + exp = mock_experience(batch_size=4, num_tokens=10, teacher_offset=-0.5) + + rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.5, 0.5, 0.8, 0.2]) + _, pure_rewards = pure_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) + _, hybrid_rewards = hybrid_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) + + hybrid_r = torch.cat(hybrid_rewards) + assert not (hybrid_r == 0).all(), "Hybrid should keep rewards" + + final_reward_pure = torch.zeros(4, 10) + final_reward_hybrid = torch.randn(4, 10) * 0.5 + + adv_pure, _, _ = pure_calc.compute(exp, final_reward_pure, 1.0, {}) + adv_hybrid, _, _ = hybrid_calc.compute(exp, final_reward_hybrid, 1.0, {}) + + assert adv_pure.shape == adv_hybrid.shape + assert not torch.allclose(adv_pure, adv_hybrid, atol=0.01) + + +# --------------------------------------------------------------------------- +# Test: Reward func placeholder +# --------------------------------------------------------------------------- + +class TestRewardFunc: + def test_reward_func_returns_zeros(self): + """Placeholder reward_func returns zeros.""" + from examples.on_policy_distillation.on_policy_distillation_reward import reward_func + + result = reward_func(queries=["q1", "q2", "q3"], prompts=["p1", "p2", "p3"]) + assert isinstance(result, torch.Tensor) + assert result.shape == (3,) + assert (result == 0).all() + + +# --------------------------------------------------------------------------- +# Test: RewardComputationEngine TypeError +# --------------------------------------------------------------------------- + +class TestRewardEngineTypeError: + def test_invalid_remote_rm_url_type_raises(self): + """Passing an invalid type for remote_rm_url raises TypeError.""" + from lightrft.trainer.fast_exp_maker import RewardComputationEngine + + with pytest.raises(TypeError, match="remote_rm_url must be str, list, tuple, or None"): + RewardComputationEngine( + reward_model=None, + remote_rm_url=12345, # int is invalid + custom_reward_func=None, + reward_fn=None, + reward_fn_label_map=None, + reward_recipe=None, + tokenizer=None, + strategy=None, + packing_samples=False, + ) diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py index f88ae0ff..008fccd9 100644 --- a/lightrft/strategy/config.py +++ b/lightrft/strategy/config.py @@ -116,6 +116,11 @@ class StrategyConfig: advantage_estimator: str = "group_norm" # (float): OPD KL coefficient for on-policy distillation penalty opd_kl_coef: float = 1.0 + # (str): Dedicated teacher model URL for OPD (falls back to remote_rm_url if not set) + teacher_model_url: Optional[str] = None + # (bool): Use task reward in final reward computation. When False, rewards are zeroed + # but metrics (accuracy_reward, etc.) are still computed for logging. + use_task_reward: bool = True # KL loss and estimation # (bool): Use KL loss in training, defaults to False diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 3b5ec6bf..2ba4f5d4 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -728,7 +728,6 @@ def _apply_opd_kl_penalty( Compute OPD reverse KL penalty: -opd_kl_coef * (student_logp - teacher_logp). Shared helper for both pure and hybrid OPD modes. - Aligned with Slime's apply_opd_kl_to_advantages. :return: Tuple of (opd_advantages, info_dict with opd_reverse_kl metric) """ @@ -745,7 +744,7 @@ def _apply_opd_kl_penalty( # legitimate value (P=1). With the padding-strip fix in _fetch_teacher_logprobs, # teacher logprobs are now correctly aligned and action_mask is sufficient. if action_mask is not None: - opd_adv = opd_adv * action_mask + opd_adv *= action_mask info_dict = {} if action_mask is not None: @@ -755,26 +754,6 @@ def _apply_opd_kl_penalty( return opd_adv, info_dict -def _whiten_advantages(advantages: torch.Tensor, action_mask: Optional[torch.Tensor], eps: float = 1e-8) -> torch.Tensor: - """ - Whiten advantages using masked mean/std (matching Slime's distributed_masked_whiten). - - This normalizes advantages to zero mean and unit variance, which stabilizes - training when OPD KL penalty has different scale from base advantages. - """ - if action_mask is not None: - mask = action_mask.bool() - masked_adv = torch.masked_select(advantages, mask) - else: - masked_adv = advantages.flatten() - - if masked_adv.numel() < 2: - return advantages - - mean = masked_adv.mean() - std = masked_adv.std() - return (advantages - mean) / (std + eps) - class OnPolicyDistillationCalculator(AdvantageCalculator): """ @@ -793,9 +772,8 @@ def __init__(self, config): self.opd_kl_coef = getattr(config, 'opd_kl_coef', 1.0) def preprocess_rewards(self, rewards, experiences, max_new_tokens): - """Zero out all rewards — pure distillation mode.""" - zero_rewards = torch.zeros_like(rewards) - return experiences, list(zero_rewards.chunk(len(experiences))) + """Pass through rewards — zeroing is handled upstream via --no_task_reward.""" + return experiences, list(rewards.chunk(len(experiences))) def compute(self, experience, final_reward, gamma, generate_kwargs): """advantages = -opd_kl_coef * (student_logp - teacher_logp). @@ -950,18 +928,18 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str all_action_masks.append(exp.action_mask.flatten()) # Concatenate into vectors - advantages_vector = torch.cat(all_advantages, dim=0).float() - action_masks_vector = torch.cat(all_action_masks, dim=0).float() + advantages = torch.cat(all_advantages, dim=0).float() + action_masks = torch.cat(all_action_masks, dim=0).float() # Compute local intermediate statistics - local_sum = (advantages_vector * action_masks_vector).sum() - local_sum_sq = ((advantages_vector ** 2) * action_masks_vector).sum() - local_count = action_masks_vector.sum() + local_sum = (advantages * action_masks).sum() + local_sum_sq = ((advantages ** 2) * action_masks).sum() + local_count = action_masks.sum() # Aggregate across all data-parallel ranks via all_reduce # (matching Slime's distributed_masked_whiten) stats = torch.stack([local_sum, local_sum_sq, local_count]).to( - device=advantages_vector.device, dtype=torch.float32 + device=advantages.device, dtype=torch.float32 ) if dist.is_initialized(): dist.all_reduce(stats, op=dist.ReduceOp.SUM) diff --git a/lightrft/trainer/experience_maker.py b/lightrft/trainer/experience_maker.py index c5f63eac..7fee825a 100644 --- a/lightrft/trainer/experience_maker.py +++ b/lightrft/trainer/experience_maker.py @@ -301,6 +301,7 @@ def __init__( self.reward_recipe = reward_recipe self.perf_stats = None self.advantage_estimator = strategy.args.advantage_estimator + self.teacher_model_url = getattr(strategy.args, 'teacher_model_url', None) # Custom reward function for reinforced fine-tuning self.custom_reward_func = None @@ -393,16 +394,7 @@ def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kw generate_kwargs["gamma"], generate_kwargs["lambd"], ) - elif self.advantage_estimator in ["reinforce", "rloo", "reinforce_baseline", "group_norm"]: - experience.returns = self.get_cumulative_returns( - reward, - experience.action_mask, - generate_kwargs["gamma"], - ) - experience.advantages = deepcopy(experience.returns) - elif self.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): - # OPD: cumulative returns from rewards (0 for pure, task rewards for hybrid) - # The OPD KL penalty is applied in the advantage calculator + elif self.advantage_estimator in ["reinforce", "rloo", "reinforce_baseline", "group_norm", "on_policy_distillation", "on_policy_distillation_hybrid"]: experience.returns = self.get_cumulative_returns( reward, experience.action_mask, @@ -586,14 +578,25 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper # On-policy distillation: query teacher model for log probs, then use GRPO reward shaping if args.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): - if self.remote_rm_url is None or len(self.remote_rm_url) == 0: - raise ValueError( - "On-policy distillation requires a teacher model URL. " - "Please set --remote_rm_url to the teacher model inference server." - ) + # Prefer dedicated teacher_model_url, fall back to remote_rm_url + teacher_url = self.teacher_model_url + if teacher_url is None: + if self.remote_rm_url is not None and len(self.remote_rm_url) > 0: + import warnings + warnings.warn( + "Using --remote_rm_url as teacher URL is deprecated. " + "Use --teacher_model_url instead.", + DeprecationWarning, + stacklevel=2, + ) + teacher_url = self.remote_rm_url[0] if isinstance(self.remote_rm_url, list) else self.remote_rm_url + else: + raise ValueError( + "On-policy distillation requires a teacher model URL. " + "Please set --teacher_model_url to the teacher model inference server." + ) import asyncio - teacher_url = self.remote_rm_url[0] if isinstance(self.remote_rm_url, list) else self.remote_rm_url # Collect all sequences as input_ids and response lengths all_input_ids = [] @@ -624,21 +627,21 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper finally: loop.close() - # Align teacher log probs to action_log_probs shape [batch, num_actions] + # Align teacher log probs to action_log_probs shape [batch, num_tokens] idx = 0 for experience in experiences: batch_size = experience.sequences.size(0) - num_actions = experience.action_mask.shape[1] - aligned = torch.zeros(batch_size, num_actions, dtype=torch.float32) + num_tokens = experience.action_mask.shape[1] + aligned = torch.zeros(batch_size, num_tokens, dtype=torch.float32) for j in range(batch_size): tlp = teacher_lp_list[idx + j] resp_len = all_response_lengths[idx + j] - actual_len = min(len(tlp), resp_len, num_actions) - start_pos = num_actions - resp_len + actual_len = min(len(tlp), resp_len, num_tokens) + start_pos = num_tokens - resp_len if start_pos >= 0: aligned[j, start_pos:start_pos + actual_len] = tlp[:actual_len] else: - aligned[j, :] = tlp[-num_actions:] + aligned[j, :] = tlp[-num_tokens:] experience.info["teacher_log_probs"] = aligned idx += batch_size diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index dc027ccc..82ccdae0 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -57,17 +57,12 @@ from .advantage_calculator import get_advantage_calculator, normalize_advantages_cross_batch from .image_utils import normalize_images, get_images_num from .video_utils import normalize_videos, get_videos_num +from examples.on_policy_distillation.on_policy_distillation_reward import ( + get_teacher_logprobs_for_experiences, + get_teacher_logprobs_by_ids, +) +import asyncio -# On-Policy Distillation imports -try: - from examples.on_policy_distillation.on_policy_distillation_reward import ( - get_teacher_logprobs_for_experiences, - get_teacher_logprobs_by_ids, - ) - import asyncio - OPD_AVAILABLE = True -except ImportError: - OPD_AVAILABLE = False # ============================================================================ # Data Structures @@ -457,8 +452,13 @@ def __init__( self.remote_rm_url = None elif isinstance(remote_rm_url, str): self.remote_rm_url = [remote_rm_url] - else: + elif isinstance(remote_rm_url, (list, tuple)): self.remote_rm_url = list(remote_rm_url) + else: + raise TypeError( + f"remote_rm_url must be str, list, tuple, or None, got {type(remote_rm_url).__name__}" + ) + self.custom_reward_func = custom_reward_func self.reward_fn = reward_fn self.reward_fn_label_map = reward_fn_label_map or {} @@ -940,12 +940,23 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg else: self.multimodal_processor = None - # For On-Policy Distillation (OPD), remote_rm_url is used for teacher model, - # not for reward model. So we don't pass it to RewardComputationEngine. - # Instead, we store it separately for _fetch_teacher_logprobs(). + # For On-Policy Distillation (OPD), prefer dedicated teacher_model_url. + # Fall back to remote_rm_url with deprecation warning for backwards compatibility. if advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): - # Store teacher URL separately for OPD - self.teacher_model_url = self.remote_rm_url + teacher_url = getattr(self.strategy.args, 'teacher_model_url', None) + if teacher_url is not None: + self.teacher_model_url = teacher_url + elif self.remote_rm_url is not None: + import warnings + warnings.warn( + "Using --remote_rm_url as teacher URL is deprecated. " + "Use --teacher_model_url instead.", + DeprecationWarning, + stacklevel=2, + ) + self.teacher_model_url = self.remote_rm_url + else: + self.teacher_model_url = None rm_url_for_reward_engine = None # Don't use remote_rm_url for rewards in OPD mode else: self.teacher_model_url = None @@ -1462,11 +1473,6 @@ def _fetch_teacher_logprobs( :param experiences: List of experiences to add teacher log probs to :type experiences: List[Union[Experience, ExperienceVL]] """ - if not OPD_AVAILABLE: - raise RuntimeError( - "On-policy distillation module not available. " - "Make sure examples/on_policy_distillation/on_policy_distillation_reward.py exists." - ) # Get teacher URL from config teacher_url = self.teacher_model_url @@ -1475,7 +1481,7 @@ def _fetch_teacher_logprobs( if teacher_url is None: raise ValueError( "Teacher model URL not specified. " - "Please set --remote_rm_url to the teacher model server URL." + "Please set --teacher_model_url to the teacher model server URL." ) Timer.start(' fetch_teacher_logprobs') @@ -1483,11 +1489,11 @@ def _fetch_teacher_logprobs( for exp in experiences: sequences = exp.sequences # [batch_size, seq_len] attention_mask = exp.attention_mask # [batch_size, seq_len] - action_mask = exp.action_mask # [batch_size, num_actions] + action_mask = exp.action_mask # [batch_size, num_tokens] # response_lengths must be int for slicing response_lengths = action_mask.sum(dim=-1).int().tolist() - num_actions = action_mask.shape[1] + num_tokens = action_mask.shape[1] # Strip padding tokens before sending to SGLang. # sequences is [prompt, response, eos, pad, pad, ...] — the padding @@ -1512,10 +1518,16 @@ def _fetch_teacher_logprobs( finally: loop.close() - # Align teacher log probs to action_log_probs shape [batch_size, num_actions]. + # Align teacher log probs to action_log_probs shape [batch_size, num_tokens]. # Use action_mask indices directly — works regardless of left/right padding. + # + # Correctness: teacher_lp_list[i] contains teacher logprobs for response + # tokens in first→last order (from teacher_lp[-resp_len:]). + # valid_indices = ascending positions where action_mask==1 (real response tokens). + # So aligned_teacher_lp[i, valid_indices[k]] = tlp[k] correctly maps the k-th + # teacher logprob to the k-th response token position, regardless of padding direction. batch_size = sequences.shape[0] - aligned_teacher_lp = torch.zeros(batch_size, num_actions, dtype=torch.float32) + aligned_teacher_lp = torch.zeros(batch_size, num_tokens, dtype=torch.float32) for i, (tlp, resp_len) in enumerate(zip(teacher_lp_list, response_lengths)): if resp_len == 0: continue From 4de9be9e60d76aeac1e4eaa7736c9f10aa6dc25c Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Thu, 16 Apr 2026 20:48:55 +0800 Subject: [PATCH 13/18] fix(pu): fix little bugs --- examples/on_policy_distillation/README.md | 6 +- examples/on_policy_distillation/README_zh.md | 8 +- .../on_policy_distillation/run_opd_qwen.sh | 31 ++-- .../on_policy_distillation/start_teacher.sh | 80 ++++++++++ .../on_policy_distillation/start_training.sh | 137 ++++++++++++++++++ 5 files changed, 241 insertions(+), 21 deletions(-) create mode 100644 examples/on_policy_distillation/start_teacher.sh create mode 100644 examples/on_policy_distillation/start_training.sh diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md index 569be2d1..cda58bcd 100644 --- a/examples/on_policy_distillation/README.md +++ b/examples/on_policy_distillation/README.md @@ -234,9 +234,9 @@ nvidia-smi examples/on_policy_distillation/ ├── README.md # This file ├── README_zh.md # Chinese version -├── run_opd_qwen.sh # All-in-one training script -├── start_teacher.sh # Teacher server only -├── start_training.sh # Training only (requires TEACHER_URL) +├── run_opd_qwen.sh # All-in-one training script (template) +├── start_teacher.sh # Teacher server only (template) +├── start_training.sh # Training only (template, requires TEACHER_URL) ├── test_opd.py # Unit tests └── on_policy_distillation_reward.py # Teacher logprob fetcher ``` diff --git a/examples/on_policy_distillation/README_zh.md b/examples/on_policy_distillation/README_zh.md index a0800b2a..f1b2c3b8 100644 --- a/examples/on_policy_distillation/README_zh.md +++ b/examples/on_policy_distillation/README_zh.md @@ -28,7 +28,7 @@ CUDA_VISIBLE_DEVICES=7 python3 -m sglang.launch_server \ ### 2. 运行训练 ```bash -bash examples/on_policy_distillation/run_opd_qwen_2.sh +bash examples/on_policy_distillation/run_opd_qwen.sh ``` 或手动运行: @@ -234,9 +234,9 @@ nvidia-smi examples/on_policy_distillation/ ├── README.md # 英文文档 ├── README_zh.md # 本文件 -├── run_opd_qwen.sh # 一体化训练脚本 -├── start_teacher.sh # 仅启动教师服务器 -├── start_training.sh # 仅启动训练(需要 TEACHER_URL) +├── run_opd_qwen.sh # 一体化训练脚本(模板) +├── start_teacher.sh # 仅启动教师服务器(模板) +├── start_training.sh # 仅启动训练(模板,需要 TEACHER_URL) ├── test_opd.py # 单元测试 └── on_policy_distillation_reward.py # 教师对数概率获取器 ``` diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh index d210acb1..eb43c5d9 100644 --- a/examples/on_policy_distillation/run_opd_qwen.sh +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -6,12 +6,12 @@ # Features: # - Auto GPU detection and allocation (teacher + training) # - Robust teacher server with health monitoring -# - Two OPD modes: pure distillation / hybrid (GRPO + OPD) +# - Two advantage estimators: pure distillation / hybrid (GRPO + OPD) # # Usage: # # Edit paths below, then: # bash examples/on_policy_distillation/run_opd_qwen.sh -# OPD_MODE=hybrid bash examples/on_policy_distillation/run_opd_qwen.sh +# ADVANTAGE_ESTIMATOR=on_policy_distillation_hybrid USE_TASK_REWARD=true bash examples/on_policy_distillation/run_opd_qwen.sh # set -euo pipefail @@ -66,10 +66,11 @@ echo "GPU Allocation: ${TOTAL_GPUS} total → Teacher: GPU ${TEACHER_GPU}, Train # Part 3: Training Hyperparameters # ################################################################################ -# --- OPD Mode (override via env: OPD_MODE=hybrid) --- -# "pure" - Pure distillation (Slime default): rewards=0, only OPD KL signal -# "hybrid" - GRPO task rewards + OPD KL penalty with advantage whitening -OPD_MODE="${OPD_MODE:-pure}" +# --- Mode control (override via env) --- +# ADVANTAGE_ESTIMATOR=on_policy_distillation - Pure distillation: rewards=0, only OPD KL signal +# ADVANTAGE_ESTIMATOR=on_policy_distillation_hybrid - GRPO task rewards + OPD KL penalty +ADVANTAGE_ESTIMATOR="${ADVANTAGE_ESTIMATOR:-on_policy_distillation}" +USE_TASK_REWARD="${USE_TASK_REWARD:-false}" N_SAMPLES=${N_SAMPLES:-8} EPISODE=${EPISODE:-30} @@ -88,16 +89,18 @@ TBS=$(( (${TBS:-128} / ALIGN) * ALIGN )) echo "Batch sizes: TBS=${TBS}, RBS=${RBS} (aligned to micro=${MICRO_TRAIN_BS} * world=${WORLD_SIZE} = ${ALIGN})" -if [ "$OPD_MODE" = "hybrid" ]; then - ADVANTAGE_ESTIMATOR="on_policy_distillation_hybrid" +if [ "$USE_TASK_REWARD" = "true" ]; then + TASK_REWARD_FLAG="--use_task_reward" +else + TASK_REWARD_FLAG="--no_task_reward" +fi + +if [ "$ADVANTAGE_ESTIMATOR" = "on_policy_distillation_hybrid" ]; then KL=${KL:-0.01} LR=${LR:-5e-7} - TASK_REWARD_FLAG="--use_task_reward" else - ADVANTAGE_ESTIMATOR="on_policy_distillation" KL=${KL:-0.00} LR=${LR:-5e-7} - TASK_REWARD_FLAG="--no_task_reward" fi PROMPT_MAX_LEN=${PROMPT_MAX_LEN:-1024} @@ -182,8 +185,8 @@ sleep 3 ################################################################################ current_time=$(date +"%Y%m%d_%H%M%S") -SAVE_MODEL_NAME="${EXPERIMENT_NAME}-${OPD_MODE}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}" -WANDB_RUN_NAME="${EXPERIMENT_NAME}-${OPD_MODE}-${current_time}" +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-${ADVANTAGE_ESTIMATOR}-ep${EPISODE}-kl${KL}-lr${LR}-${current_time}" +WANDB_RUN_NAME="${EXPERIMENT_NAME}-${ADVANTAGE_ESTIMATOR}-${current_time}" TEACHER_URL="http://$TEACHER_IP:$TEACHER_PORT/generate" mkdir -p "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" @@ -197,7 +200,7 @@ export WANDB_MODE="${WANDB_MODE:-offline}" echo "=========================================" echo "On-Policy Distillation Training" echo "=========================================" -echo "Mode: $OPD_MODE ($ADVANTAGE_ESTIMATOR)" +echo "Estimator: $ADVANTAGE_ESTIMATOR" echo "Student: $STUDENT_MODEL_PATH" echo "Teacher: $TEACHER_URL" echo "GPUs: Training=0-$((TRAIN_GPUS-1)), Teacher=$TEACHER_GPU" diff --git a/examples/on_policy_distillation/start_teacher.sh b/examples/on_policy_distillation/start_teacher.sh new file mode 100644 index 00000000..363d65fc --- /dev/null +++ b/examples/on_policy_distillation/start_teacher.sh @@ -0,0 +1,80 @@ +#!/bin/bash +# +# Start SGLang teacher server for On-Policy Distillation. +# Prints TEACHER_URL on success for use with start_training.sh. +# +# Usage: +# bash examples/on_policy_distillation/start_teacher.sh +# # Then in another terminal: +# TEACHER_URL=http://127.0.0.1:13141/generate bash examples/on_policy_distillation/start_training.sh +# + +set -euo pipefail + +TEACHER_MODEL_PATH="${TEACHER_MODEL_PATH:-Qwen/Qwen2.5-7B-Instruct}" +TEACHER_IP="${TEACHER_IP:-127.0.0.1}" +TEACHER_PORT="${TEACHER_PORT:-13141}" +TEACHER_GPU="${TEACHER_GPU:-0}" +MEM_FRACTION="${MEM_FRACTION:-0.7}" + +LOG_DIR="rft_logs/teacher" +mkdir -p "$LOG_DIR" +TEACHER_LOG="${LOG_DIR}/teacher_$(date +%Y%m%d_%H%M%S).log" + +# Kill any existing process on the port +if lsof -Pi :"$TEACHER_PORT" -sTCP:LISTEN -t >/dev/null 2>&1; then + echo "Port $TEACHER_PORT in use, killing existing process..." + lsof -ti:"$TEACHER_PORT" | xargs kill -9 2>/dev/null || true + sleep 3 +fi + +echo "Starting teacher server on GPU $TEACHER_GPU..." +echo " Model: $TEACHER_MODEL_PATH" +echo " Log: $TEACHER_LOG" + +CUDA_VISIBLE_DEVICES=$TEACHER_GPU python3 -m sglang.launch_server \ + --model-path "$TEACHER_MODEL_PATH" \ + --host 0.0.0.0 \ + --port "$TEACHER_PORT" \ + --tp 1 \ + --chunked-prefill-size 4096 \ + --mem-fraction-static "$MEM_FRACTION" \ + --disable-radix-cache \ + --max-running-requests 64 \ + >> "$TEACHER_LOG" 2>&1 & + +TEACHER_PID=$! + +# Wait for health check +max_wait=600 +waited=0 +while ! curl -sf "http://$TEACHER_IP:$TEACHER_PORT/health" >/dev/null 2>&1; do + if [ $waited -ge $max_wait ]; then + echo "ERROR: Teacher server failed to start in ${max_wait}s" + tail -30 "$TEACHER_LOG" + kill "$TEACHER_PID" 2>/dev/null || true + exit 1 + fi + if ! kill -0 "$TEACHER_PID" 2>/dev/null; then + echo "ERROR: Teacher server process died" + tail -30 "$TEACHER_LOG" + exit 1 + fi + printf "." + sleep 5 + waited=$((waited + 5)) +done + +TEACHER_URL="http://$TEACHER_IP:$TEACHER_PORT/generate" +echo "" +echo "=========================================" +echo "Teacher server ready!" +echo " PID: $TEACHER_PID" +echo " URL: $TEACHER_URL" +echo "=========================================" +echo "" +echo "Export for training:" +echo " export TEACHER_URL=$TEACHER_URL" + +# Keep running in foreground +wait "$TEACHER_PID" diff --git a/examples/on_policy_distillation/start_training.sh b/examples/on_policy_distillation/start_training.sh new file mode 100644 index 00000000..6be89d85 --- /dev/null +++ b/examples/on_policy_distillation/start_training.sh @@ -0,0 +1,137 @@ +#!/bin/bash +# +# Start OPD training. Requires TEACHER_URL env var. +# +# Usage: +# TEACHER_URL=http://127.0.0.1:13141/generate bash examples/on_policy_distillation/start_training.sh +# + +set -euo pipefail + +if [ -z "${TEACHER_URL:-}" ]; then + echo "ERROR: TEACHER_URL not set." + echo "Start the teacher first: bash examples/on_policy_distillation/start_teacher.sh" + echo "Then export TEACHER_URL=http://host:port/generate" + exit 1 +fi + +# --- Configuration --- +STUDENT_MODEL_PATH="${STUDENT_MODEL_PATH:-Qwen/Qwen2.5-0.5B-Instruct}" +DATASET_PATH="${DATASET_PATH:-/path/to/your/dataset.jsonl}" +EXPERIMENT_NAME="${EXPERIMENT_NAME:-opd-qwen}" + +export WANDB_API_KEY="${WANDB_API_KEY:-YOUR_WANDB_API_KEY}" +export WANDB_PROJECT="${WANDB_PROJECT:-LightRFT-OnPolicyDistillation}" +export WANDB_MODE="${WANDB_MODE:-offline}" + +# --- GPU setup --- +GPUS_PER_NODE="${GPUS_PER_NODE:-$(nvidia-smi -L 2>/dev/null | wc -l)}" +ENGINE_TP=$( [ "$GPUS_PER_NODE" -ge 2 ] && echo 2 || echo 1 ) + +# --- Hyperparameters --- +N_SAMPLES=${N_SAMPLES:-8} +EPISODE=${EPISODE:-30} +OPD_KL_COEF=${OPD_KL_COEF:-1.0} +MICRO_TRAIN_BS=${MICRO_TRAIN_BS:-4} +MICRO_ROLLOUT_BS=${MICRO_ROLLOUT_BS:-4} +LR=${LR:-5e-7} + +WORLD_SIZE=$((1 * GPUS_PER_NODE)) +ALIGN=$((MICRO_TRAIN_BS * WORLD_SIZE)) +RBS=$(( (${RBS:-128} / ALIGN) * ALIGN )) +TBS=$(( (${TBS:-128} / ALIGN) * ALIGN )) +[ "$RBS" -lt "$ALIGN" ] && RBS=$ALIGN +[ "$TBS" -lt "$ALIGN" ] && TBS=$ALIGN + +ADVANTAGE_ESTIMATOR="${ADVANTAGE_ESTIMATOR:-on_policy_distillation}" +USE_TASK_REWARD="${USE_TASK_REWARD:-false}" + +if [ "$USE_TASK_REWARD" = "true" ]; then + TASK_REWARD_FLAG="--use_task_reward" +else + TASK_REWARD_FLAG="--no_task_reward" +fi + +if [ "$ADVANTAGE_ESTIMATOR" = "on_policy_distillation_hybrid" ]; then + KL=${KL:-0.01} +else + KL=${KL:-0.00} +fi + +current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-${ADVANTAGE_ESTIMATOR}-ep${EPISODE}-lr${LR}-${current_time}" +LOG_DIR="rft_logs/${EXPERIMENT_NAME}" +mkdir -p "$LOG_DIR" "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" + +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="WARN" +export NCCL_TIMEOUT=3600 + +echo "=========================================" +echo "On-Policy Distillation Training" +echo " Estimator: $ADVANTAGE_ESTIMATOR" +echo " Student: $STUDENT_MODEL_PATH" +echo " Teacher: $TEACHER_URL" +echo " GPUs: $GPUS_PER_NODE" +echo "=========================================" + +set -x + +torchrun \ + --nnodes 1 \ + --nproc-per-node $GPUS_PER_NODE \ + --node_rank 0 \ + --master-port ${MASTER_PORT:-20090} \ + --master-addr localhost \ + examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "$STUDENT_MODEL_PATH" \ + --save_trajectories \ + --advantage_estimator "${ADVANTAGE_ESTIMATOR}" \ + --opd_kl_coef ${OPD_KL_COEF} \ + --fsdp \ + --use_kl_loss \ + --flash_attn \ + --engine_type sglang \ + --enable_engine_sleep \ + --rm_use_engine \ + --reward_pretrain "" \ + --teacher_model_url "$TEACHER_URL" \ + ${TASK_REWARD_FLAG} \ + --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --micro_train_batch_size ${MICRO_TRAIN_BS} \ + --train_batch_size ${TBS} \ + --micro_rollout_batch_size ${MICRO_ROLLOUT_BS} \ + --rollout_batch_size ${RBS} \ + --max_epochs 1 \ + --num_episodes ${EPISODE} \ + --lr_warmup_ratio 0.03 \ + --n_samples_per_prompt $N_SAMPLES \ + --prompt_max_len ${PROMPT_MAX_LEN:-1024} \ + --generate_max_len ${GENERATE_MAX_LEN:-2048} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate $LR \ + --init_kl_coef $KL \ + --kl_estimator "k3" \ + --prompt_data "$DATASET_PATH" \ + --input_key "prompt" \ + --label_key "label" \ + --eval_steps 20 \ + --eval_split "test" \ + --apply_chat_template \ + --gradient_checkpointing \ + --save_steps 20 \ + --max_ckpt_num 3 \ + --engine_mem_util 0.6 \ + --engine_tp_size $ENGINE_TP \ + --l2 1.0e-2 \ + --freeze_prefix \ + --adam_offload \ + --text_only \ + --use_wandb "${WANDB_API_KEY}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${EXPERIMENT_NAME}-${ADVANTAGE_ESTIMATOR}-${current_time}" \ + 2>&1 | tee "${LOG_DIR}/train_${current_time}.log" + +exit ${PIPESTATUS[0]} From fa644790aa7191a0f52644c6a2db5fe19bbfd8b4 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Thu, 16 Apr 2026 21:04:47 +0800 Subject: [PATCH 14/18] polish(pu): delete opd related code in ExperienceMaker --- lightrft/trainer/experience_maker.py | 94 +++------------------------- 1 file changed, 7 insertions(+), 87 deletions(-) diff --git a/lightrft/trainer/experience_maker.py b/lightrft/trainer/experience_maker.py index 7fee825a..64b101c7 100644 --- a/lightrft/trainer/experience_maker.py +++ b/lightrft/trainer/experience_maker.py @@ -394,7 +394,7 @@ def make_experience_list(self, all_prompts: Union[str, List[str]], **generate_kw generate_kwargs["gamma"], generate_kwargs["lambd"], ) - elif self.advantage_estimator in ["reinforce", "rloo", "reinforce_baseline", "group_norm", "on_policy_distillation", "on_policy_distillation_hybrid"]: + elif self.advantage_estimator in ["reinforce", "rloo", "reinforce_baseline", "group_norm"]: experience.returns = self.get_cumulative_returns( reward, experience.action_mask, @@ -576,93 +576,13 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper """ args = self.strategy.args - # On-policy distillation: query teacher model for log probs, then use GRPO reward shaping + # On-policy distillation is only supported via FastExperienceMaker if args.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): - # Prefer dedicated teacher_model_url, fall back to remote_rm_url - teacher_url = self.teacher_model_url - if teacher_url is None: - if self.remote_rm_url is not None and len(self.remote_rm_url) > 0: - import warnings - warnings.warn( - "Using --remote_rm_url as teacher URL is deprecated. " - "Use --teacher_model_url instead.", - DeprecationWarning, - stacklevel=2, - ) - teacher_url = self.remote_rm_url[0] if isinstance(self.remote_rm_url, list) else self.remote_rm_url - else: - raise ValueError( - "On-policy distillation requires a teacher model URL. " - "Please set --teacher_model_url to the teacher model inference server." - ) - - import asyncio - - # Collect all sequences as input_ids and response lengths - all_input_ids = [] - all_response_lengths = [] - for experience in experiences: - sequences_batch = experience.sequences - response_lengths = experience.info["response_length"] - for i, seq in enumerate(sequences_batch): - all_input_ids.append(seq.cpu().tolist()) - all_response_lengths.append(int(response_lengths[i].item())) - - # Query teacher model for log probs using input_ids - try: - from examples.on_policy_distillation.on_policy_distillation_reward import ( - get_teacher_logprobs_by_ids - ) - - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - teacher_lp_list = loop.run_until_complete( - get_teacher_logprobs_by_ids( - url=teacher_url, - input_ids_list=all_input_ids, - response_lengths=all_response_lengths, - ) - ) - finally: - loop.close() - - # Align teacher log probs to action_log_probs shape [batch, num_tokens] - idx = 0 - for experience in experiences: - batch_size = experience.sequences.size(0) - num_tokens = experience.action_mask.shape[1] - aligned = torch.zeros(batch_size, num_tokens, dtype=torch.float32) - for j in range(batch_size): - tlp = teacher_lp_list[idx + j] - resp_len = all_response_lengths[idx + j] - actual_len = min(len(tlp), resp_len, num_tokens) - start_pos = num_tokens - resp_len - if start_pos >= 0: - aligned[j, start_pos:start_pos + actual_len] = tlp[:actual_len] - else: - aligned[j, :] = tlp[-num_tokens:] - experience.info["teacher_log_probs"] = aligned - idx += batch_size - - except Exception as e: - logger.error(f"Failed to get teacher log probs: {e}") - raise - - # Return rewards based on mode - if args.advantage_estimator == "on_policy_distillation": - # Pure distillation: zero rewards, learning signal from OPD KL only - zero_rewards = torch.zeros(sum(exp.sequences.size(0) for exp in experiences)) - rewards = zero_rewards.chunk(len(experiences)) - return experiences, list(rewards) - else: - # Hybrid: use task rewards with GRPO normalization - rewards = torch.cat([experience.info["reward"] for experience in experiences]) - rewards = rewards.reshape(-1, args.n_samples_per_prompt) - baseline = rewards.mean(-1, keepdim=True) - rewards = (rewards - baseline) / (rewards.std(-1, keepdim=True) + 1e-9) - rewards = rewards.flatten().chunk(len(experiences)) - return experiences, list(rewards) + raise NotImplementedError( + "On-policy distillation is only supported with FastExperienceMaker " + "(use train_colocate.py / SpmdPPOTrainer). " + "NaiveExperienceMaker does not support OPD." + ) # Reward shaping for RLOO if args.advantage_estimator == "rloo": From 3d4c091f8d3e99b0ff217a6325d289bd4ec2e261 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Thu, 16 Apr 2026 22:11:35 +0800 Subject: [PATCH 15/18] polish(pu): merge on_policy_distillation_hybrid into on_policy_distillation --- examples/gsm8k_geo3k/train_colocate.py | 4 +- examples/on_policy_distillation/README.md | 2 +- .../on_policy_distillation/run_opd_qwen.sh | 9 +- .../on_policy_distillation/start_training.sh | 2 +- examples/on_policy_distillation/test_opd.py | 125 +++++++----------- lightrft/trainer/advantage_calculator.py | 67 ++-------- lightrft/trainer/experience_maker.py | 2 +- lightrft/trainer/fast_exp_maker.py | 4 +- lightrft/trainer/ppo_trainer_vl.py | 10 +- lightrft/trainer/spmd_ppo_trainer.py | 5 +- 10 files changed, 74 insertions(+), 156 deletions(-) diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index bcba57b3..9bc14225 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -562,7 +562,7 @@ def train(args): parser.add_argument( "--advantage_estimator", type=str, - choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++", "on_policy_distillation", "on_policy_distillation_hybrid"], + choices=["gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "cpgd", "reinforce++", "on_policy_distillation"], default="gae", help="Choose advantage estimation method: gae, reinforce, rloo, reinforce_baseline, group_norm, reinforce++", ) @@ -658,7 +658,7 @@ def train(args): elif args.critic_pretrain is None: args.critic_pretrain = args.pretrain - if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm", "on_policy_distillation", "on_policy_distillation_hybrid"]: + if args.advantage_estimator in ["rloo", "reinforce_baseline", "group_norm", "on_policy_distillation"]: assert args.n_samples_per_prompt > 1, f"{args.advantage_estimator} requires n_samples_per_prompt > 1" if args.use_kl_loss: diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md index cda58bcd..0c030ea6 100644 --- a/examples/on_policy_distillation/README.md +++ b/examples/on_policy_distillation/README.md @@ -28,7 +28,7 @@ CUDA_VISIBLE_DEVICES=7 python3 -m sglang.launch_server \ ### 2. Run Training ```bash -bash examples/on_policy_distillation/run_opd_qwen_2.sh +bash examples/on_policy_distillation/run_opd_qwen.sh ``` Or manually: diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh index eb43c5d9..34fe74c7 100644 --- a/examples/on_policy_distillation/run_opd_qwen.sh +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -6,12 +6,11 @@ # Features: # - Auto GPU detection and allocation (teacher + training) # - Robust teacher server with health monitoring -# - Two advantage estimators: pure distillation / hybrid (GRPO + OPD) # # Usage: # # Edit paths below, then: # bash examples/on_policy_distillation/run_opd_qwen.sh -# ADVANTAGE_ESTIMATOR=on_policy_distillation_hybrid USE_TASK_REWARD=true bash examples/on_policy_distillation/run_opd_qwen.sh +# USE_TASK_REWARD=true bash examples/on_policy_distillation/run_opd_qwen.sh # set -euo pipefail @@ -67,8 +66,8 @@ echo "GPU Allocation: ${TOTAL_GPUS} total → Teacher: GPU ${TEACHER_GPU}, Train ################################################################################ # --- Mode control (override via env) --- -# ADVANTAGE_ESTIMATOR=on_policy_distillation - Pure distillation: rewards=0, only OPD KL signal -# ADVANTAGE_ESTIMATOR=on_policy_distillation_hybrid - GRPO task rewards + OPD KL penalty +# USE_TASK_REWARD=false - Pure distillation: rewards=0, only OPD KL signal +# USE_TASK_REWARD=true - GRPO task rewards + OPD KL penalty ADVANTAGE_ESTIMATOR="${ADVANTAGE_ESTIMATOR:-on_policy_distillation}" USE_TASK_REWARD="${USE_TASK_REWARD:-false}" @@ -95,7 +94,7 @@ else TASK_REWARD_FLAG="--no_task_reward" fi -if [ "$ADVANTAGE_ESTIMATOR" = "on_policy_distillation_hybrid" ]; then +if [ "$USE_TASK_REWARD" = "true" ]; then KL=${KL:-0.01} LR=${LR:-5e-7} else diff --git a/examples/on_policy_distillation/start_training.sh b/examples/on_policy_distillation/start_training.sh index 6be89d85..2cfc9107 100644 --- a/examples/on_policy_distillation/start_training.sh +++ b/examples/on_policy_distillation/start_training.sh @@ -52,7 +52,7 @@ else TASK_REWARD_FLAG="--no_task_reward" fi -if [ "$ADVANTAGE_ESTIMATOR" = "on_policy_distillation_hybrid" ]; then +if [ "$USE_TASK_REWARD" = "true" ]; then KL=${KL:-0.01} else KL=${KL:-0.00} diff --git a/examples/on_policy_distillation/test_opd.py b/examples/on_policy_distillation/test_opd.py index a5fad991..76235977 100644 --- a/examples/on_policy_distillation/test_opd.py +++ b/examples/on_policy_distillation/test_opd.py @@ -1,7 +1,7 @@ """ Pytest suite for On-Policy Distillation implementation in LightRFT. -Tests both pure and hybrid OPD modes, KL penalty computation, +Tests the unified OPD calculator, KL penalty computation, teacher logprob extraction, dimension alignment, and reward engine validation. """ @@ -10,7 +10,6 @@ from lightrft.trainer.advantage_calculator import ( OnPolicyDistillationCalculator, - OnPolicyDistillationHybridCalculator, _apply_opd_kl_penalty, get_advantage_calculator, normalize_advantages_cross_batch, @@ -73,13 +72,12 @@ class _Exp: class TestFactory: def test_all_estimators_registered(self, mock_config): - """All estimators including both OPD modes are registered.""" + """All estimators are registered.""" config = mock_config() estimators = [ "gae", "reinforce", "rloo", "reinforce_baseline", "group_norm", "grpo", "cpgd", "on_policy_distillation", - "on_policy_distillation_hybrid", ] for name in estimators: calc = get_advantage_calculator(name, config) @@ -90,21 +88,28 @@ def test_unknown_estimator_raises(self, mock_config): with pytest.raises(ValueError): get_advantage_calculator("nonexistent", mock_config()) + def test_hybrid_removed(self, mock_config): + """on_policy_distillation_hybrid is no longer registered.""" + with pytest.raises(ValueError): + get_advantage_calculator("on_policy_distillation_hybrid", mock_config()) + # --------------------------------------------------------------------------- -# Test: Pure OPD calculator +# Test: Unified OPD calculator # --------------------------------------------------------------------------- -class TestPureOPD: - def test_preprocess_rewards_passthrough(self, mock_config, mock_experience): - """Pure OPD preprocess_rewards passes through rewards (zeroing done upstream).""" - calc = OnPolicyDistillationCalculator(mock_config(opd_kl_coef=1.0)) - rewards = torch.tensor([0.5, 0.8, 0.3, 0.9, 0.1, 0.7, 0.2, 0.4]) +class TestOPDCalculator: + def test_preprocess_rewards_grpo_normalization(self, mock_config, mock_experience): + """OPD applies GRPO normalization to rewards.""" + calc = OnPolicyDistillationCalculator( + mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) + ) + rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0]) experiences = [mock_experience(batch_size=4), mock_experience(batch_size=4)] _, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) - # Rewards are passed through; upstream --no_task_reward zeroes them combined = torch.cat(reward_chunks) - assert combined.shape == rewards.shape + # Non-uniform rewards should produce non-zero normalized values + assert not (combined == 0).all(), "Should apply GRPO normalization" def test_compute_advantages_shape_and_masking(self, mock_config, mock_experience): """Advantages have correct shape and padding positions are zero.""" @@ -127,35 +132,6 @@ def test_missing_teacher_logprobs_raises(self, mock_config, mock_experience): calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) -# --------------------------------------------------------------------------- -# Test: Hybrid OPD calculator -# --------------------------------------------------------------------------- - -class TestHybridOPD: - def test_preprocess_rewards_grpo_normalization(self, mock_config, mock_experience): - """Hybrid mode applies GRPO normalization (rewards not zeroed).""" - calc = OnPolicyDistillationHybridCalculator( - mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) - ) - rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0]) - experiences = [mock_experience(batch_size=4), mock_experience(batch_size=4)] - _, reward_chunks = calc.preprocess_rewards(rewards, experiences, max_new_tokens=100) - combined = torch.cat(reward_chunks) - assert not (combined == 0).all(), "Hybrid should NOT zero rewards" - - def test_compute_advantages(self, mock_config, mock_experience): - """Hybrid compute produces correct shape and includes KL metric.""" - calc = OnPolicyDistillationHybridCalculator( - mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) - ) - exp = mock_experience(batch_size=4, num_tokens=10, teacher_offset=-0.3) - final_reward = torch.randn(4, 10) * 0.5 - adv, _ret, info = calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) - - assert adv.shape == (4, 10) - assert "opd_reverse_kl" in info - - # --------------------------------------------------------------------------- # Test: OPD KL penalty helper # --------------------------------------------------------------------------- @@ -217,15 +193,15 @@ class _Args: pass args = _Args() - # on_policy_distillation_hybrid triggers normalization + # on_policy_distillation triggers normalization result = normalize_advantages_cross_batch( - [exp1, exp2], "on_policy_distillation_hybrid", args + [exp1, exp2], "on_policy_distillation", args ) assert len(result) == 2 assert result[0].advantages.shape == (4, 10) - def test_pure_opd_skips_normalization(self, mock_experience): - """Pure OPD mode skips cross-batch normalization.""" + def test_group_norm_skips_normalization(self, mock_experience): + """group_norm mode skips cross-batch normalization.""" exp = mock_experience(batch_size=4, num_tokens=10) exp.advantages = torch.randn(4, 10) * 5 + 2 original = exp.advantages.clone() @@ -234,12 +210,38 @@ class _Args: pass result = normalize_advantages_cross_batch( - [exp], "on_policy_distillation", _Args() + [exp], "group_norm", _Args() ) # Should return unchanged (not in whitening list) assert torch.equal(result[0].advantages, original) +# --------------------------------------------------------------------------- +# Test: Zero vs non-zero rewards (replaces TestPureVsHybrid) +# --------------------------------------------------------------------------- + +class TestZeroVsNonZeroRewards: + def test_advantages_differ_with_rewards(self, mock_config, mock_experience): + """OPD with zero rewards vs non-zero rewards produces different advantages.""" + config = mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) + calc = OnPolicyDistillationCalculator(config) + + torch.manual_seed(42) + exp = mock_experience(batch_size=4, num_tokens=10, teacher_offset=-0.5) + + # Zero rewards (pure distillation mode) + final_reward_zero = torch.zeros(4, 10) + adv_zero, _, _ = calc.compute(exp, final_reward_zero, 1.0, {}) + + # Non-zero rewards (hybrid mode with task rewards) + final_reward_nonzero = torch.randn(4, 10) * 0.5 + adv_nonzero, _, _ = calc.compute(exp, final_reward_nonzero, 1.0, {}) + + assert adv_zero.shape == adv_nonzero.shape + # With non-zero task rewards, advantages should differ + assert not torch.allclose(adv_zero, adv_nonzero, atol=0.01) + + # --------------------------------------------------------------------------- # Test: Teacher logprob extraction # --------------------------------------------------------------------------- @@ -332,37 +334,6 @@ class _Exp: calc.compute(exp, final_reward, gamma=1.0, generate_kwargs={}) -# --------------------------------------------------------------------------- -# Test: Pure vs Hybrid produce different results -# --------------------------------------------------------------------------- - -class TestPureVsHybrid: - def test_advantages_differ(self, mock_config, mock_experience): - """Pure and hybrid modes produce meaningfully different advantages.""" - config = mock_config(opd_kl_coef=1.0, n_samples_per_prompt=4) - pure_calc = OnPolicyDistillationCalculator(config) - hybrid_calc = OnPolicyDistillationHybridCalculator(config) - - torch.manual_seed(42) - exp = mock_experience(batch_size=4, num_tokens=10, teacher_offset=-0.5) - - rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 0.5, 0.5, 0.8, 0.2]) - _, pure_rewards = pure_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) - _, hybrid_rewards = hybrid_calc.preprocess_rewards(rewards.clone(), [exp, exp], 100) - - hybrid_r = torch.cat(hybrid_rewards) - assert not (hybrid_r == 0).all(), "Hybrid should keep rewards" - - final_reward_pure = torch.zeros(4, 10) - final_reward_hybrid = torch.randn(4, 10) * 0.5 - - adv_pure, _, _ = pure_calc.compute(exp, final_reward_pure, 1.0, {}) - adv_hybrid, _, _ = hybrid_calc.compute(exp, final_reward_hybrid, 1.0, {}) - - assert adv_pure.shape == adv_hybrid.shape - assert not torch.allclose(adv_pure, adv_hybrid, atol=0.01) - - # --------------------------------------------------------------------------- # Test: Reward func placeholder # --------------------------------------------------------------------------- diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index 2ba4f5d4..cdbc686c 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -757,58 +757,15 @@ def _apply_opd_kl_penalty( class OnPolicyDistillationCalculator(AdvantageCalculator): """ - On-Policy Distillation calculator — pure distillation mode. + On-Policy Distillation calculator. - Following Slime's design: - - Task rewards are zeroed out - - The ONLY learning signal is the OPD KL penalty: - advantages = -opd_kl_coef * (student_logp - teacher_logp) - - Advantage whitening is applied for training stability + When USE_TASK_REWARD=true: advantages = GRPO_base(task_rewards) + OPD_KL_penalty + When USE_TASK_REWARD=false: task rewards are all zeros, so advantages = OPD_KL_penalty only - Use --advantage_estimator on_policy_distillation - """ - def __init__(self, config): - super().__init__(config) - self.opd_kl_coef = getattr(config, 'opd_kl_coef', 1.0) - - def preprocess_rewards(self, rewards, experiences, max_new_tokens): - """Pass through rewards — zeroing is handled upstream via --no_task_reward.""" - return experiences, list(rewards.chunk(len(experiences))) - - def compute(self, experience, final_reward, gamma, generate_kwargs): - """advantages = -opd_kl_coef * (student_logp - teacher_logp). - Whitening is done cross-batch in normalize_advantages_cross_batch.""" - if "teacher_log_probs" not in experience.info: - raise ValueError("teacher_log_probs not found in experience.info.") - - teacher_lp = experience.info["teacher_log_probs"].to(experience.action_log_probs.device) - student_lp = experience.action_log_probs - - advantages, info_dict = _apply_opd_kl_penalty( - student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef - ) - - returns = advantages.clone() - - if self.config.advantage_clip > 0: - clip_val = self.config.advantage_clip - info_dict["advantage_clip_frac"] = compute_clip_fraction(advantages, clip_val, -clip_val) - advantages = torch.clamp(advantages, -clip_val, clip_val) + This unified calculator handles both cases automatically. + Whitening is done cross-batch in normalize_advantages_cross_batch. - return advantages, returns, info_dict - - -class OnPolicyDistillationHybridCalculator(AdvantageCalculator): - """ - Hybrid On-Policy Distillation calculator — GRPO task rewards + OPD KL penalty. - - Combines GRPO (group normalization) base advantages from task rewards with - OPD KL penalty from teacher model. Advantage whitening is applied AFTER - combining both signals to resolve scale mismatch. - - advantages = whiten(GRPO_base_advantages + OPD_KL_penalty) - - Use --advantage_estimator on_policy_distillation_hybrid + Use --advantage_estimator on_policy_distillation """ def __init__(self, config): super().__init__(config) @@ -820,9 +777,8 @@ def preprocess_rewards(self, rewards, experiences, max_new_tokens): return self.base_calculator.preprocess_rewards(rewards, experiences, max_new_tokens) def compute(self, experience, final_reward, gamma, generate_kwargs): - """advantages = GRPO_base + OPD_KL_penalty. - Whitening is done cross-batch in normalize_advantages_cross_batch.""" - # Step 1: GRPO base advantages from task rewards + """advantages = GRPO_base + OPD_KL_penalty.""" + # Step 1: GRPO base advantages from task rewards (zeros if no_task_reward) base_advantages, returns, info_dict = self.base_calculator.compute( experience, final_reward, gamma, generate_kwargs ) @@ -839,7 +795,7 @@ def compute(self, experience, final_reward, gamma, generate_kwargs): ) info_dict.update(opd_info) - # Step 3: Combine (whitening done cross-batch later) + # Step 3: Combine advantages = base_advantages + opd_adv if self.config.advantage_clip > 0: @@ -859,7 +815,7 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator """ Factory function to create an advantage calculator instance. - :param estimator_name: Name of the advantage estimation method + :param estimator_name: Name of the advantage estimation method. Options: "gae", "cpgd", "reinforce", "rloo", "reinforce_baseline", "group_norm", "grpo", "on_policy_distillation" @@ -879,7 +835,6 @@ def get_advantage_calculator(estimator_name: str, config) -> AdvantageCalculator "cpgd": CPGDCalculator, "grpo": GroupNormCalculator, # Alias for group_norm "on_policy_distillation": OnPolicyDistillationCalculator, - "on_policy_distillation_hybrid": OnPolicyDistillationHybridCalculator, } calculator_class = calculator_map.get(estimator_name) @@ -916,7 +871,7 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str """ if advantage_estimator not in [ "gae", "reinforce", "reinforce_baseline", - "on_policy_distillation_hybrid", + "on_policy_distillation", ]: return experiences diff --git a/lightrft/trainer/experience_maker.py b/lightrft/trainer/experience_maker.py index 64b101c7..477737ef 100644 --- a/lightrft/trainer/experience_maker.py +++ b/lightrft/trainer/experience_maker.py @@ -577,7 +577,7 @@ def process_experiences(self, experiences: List[Experience]) -> Tuple[List[Exper args = self.strategy.args # On-policy distillation is only supported via FastExperienceMaker - if args.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + if args.advantage_estimator == "on_policy_distillation": raise NotImplementedError( "On-policy distillation is only supported with FastExperienceMaker " "(use train_colocate.py / SpmdPPOTrainer). " diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 82ccdae0..4359db23 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -942,7 +942,7 @@ def __init__(self, *args, packing_samples: bool = False, processor=None, **kwarg # For On-Policy Distillation (OPD), prefer dedicated teacher_model_url. # Fall back to remote_rm_url with deprecation warning for backwards compatibility. - if advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + if advantage_estimator == "on_policy_distillation": teacher_url = getattr(self.strategy.args, 'teacher_model_url', None) if teacher_url is not None: self.teacher_model_url = teacher_url @@ -1086,7 +1086,7 @@ def make_experience_list( self._process_multi_image_video_thws(experiences, expanded_images_num, expanded_videos_num) # ========== Stage 6.5: On-Policy Distillation Teacher Log-Probs ========== - if config.advantage_estimator in ("on_policy_distillation", "on_policy_distillation_hybrid"): + if config.advantage_estimator == "on_policy_distillation": self._fetch_teacher_logprobs(experiences) # ========== Stage 7: Advantage Computation ========== diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py index e3b44091..a8126a09 100644 --- a/lightrft/trainer/ppo_trainer_vl.py +++ b/lightrft/trainer/ppo_trainer_vl.py @@ -458,10 +458,7 @@ def fit( format_tensor = torch.tensor(all_format_rewards, dtype=torch.float32, device=device) mean_format_reward = format_tensor.mean().item() - - # Only display if mean is significantly non-zero - if abs(mean_format_reward) > 1e-6: - rollout_status["rollout_format_reward"] = mean_format_reward + rollout_status["rollout_format_reward"] = mean_format_reward if all_accuracy_rewards: # [TENSOR-FIX] Handle both tensor lists and scalar lists @@ -471,10 +468,7 @@ def fit( accuracy_tensor = torch.tensor(all_accuracy_rewards, dtype=torch.float32, device=device) mean_accuracy_reward = accuracy_tensor.mean().item() - - # Only display if mean is significantly non-zero - if abs(mean_accuracy_reward) > 1e-6: - rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward + rollout_status["rollout_accuracy_reward"] = mean_accuracy_reward if all_response_lengths: # [TENSOR-FIX] Handle both tensor lists and scalar lists diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py index 0010a9b9..8d505649 100644 --- a/lightrft/trainer/spmd_ppo_trainer.py +++ b/lightrft/trainer/spmd_ppo_trainer.py @@ -400,9 +400,8 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train rule_tensor = torch.cat([t.to(device).float() for t in all_rule_rewards]) else: rule_tensor = torch.tensor(all_rule_rewards, dtype=torch.float32, device=device) - if rule_tensor.abs().sum() > 0: # Only log if rule rewards are non-zero - status_mean["rule_reward_mean"] = rule_tensor.mean().item() - self.strategy.print(f"rule_reward_mean: {status_mean['rule_reward_mean']}") + status_mean["rule_reward_mean"] = rule_tensor.mean().item() + self.strategy.print(f"rule_reward_mean: {status_mean['rule_reward_mean']}") # For advantages, returns, and lengths, they are already lists of tensors, # so torch.cat() is the correct function to use. From eb7ff841c49c938f4c925558333a1c2adcb3f8b9 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Thu, 16 Apr 2026 22:12:51 +0800 Subject: [PATCH 16/18] style(pu): make format --- lightrft/trainer/advantage_calculator.py | 13 +++++-------- lightrft/trainer/fast_exp_maker.py | 15 +++++---------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py index cdbc686c..37ca073b 100644 --- a/lightrft/trainer/advantage_calculator.py +++ b/lightrft/trainer/advantage_calculator.py @@ -754,7 +754,6 @@ def _apply_opd_kl_penalty( return opd_adv, info_dict - class OnPolicyDistillationCalculator(AdvantageCalculator): """ On-Policy Distillation calculator. @@ -790,9 +789,7 @@ def compute(self, experience, final_reward, gamma, generate_kwargs): teacher_lp = experience.info["teacher_log_probs"].to(base_advantages.device) student_lp = experience.action_log_probs - opd_adv, opd_info = _apply_opd_kl_penalty( - student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef - ) + opd_adv, opd_info = _apply_opd_kl_penalty(student_lp, teacher_lp, experience.action_mask, self.opd_kl_coef) info_dict.update(opd_info) # Step 3: Combine @@ -870,7 +867,9 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str :rtype: List """ if advantage_estimator not in [ - "gae", "reinforce", "reinforce_baseline", + "gae", + "reinforce", + "reinforce_baseline", "on_policy_distillation", ]: return experiences @@ -893,9 +892,7 @@ def normalize_advantages_cross_batch(experiences: List, advantage_estimator: str # Aggregate across all data-parallel ranks via all_reduce # (matching Slime's distributed_masked_whiten) - stats = torch.stack([local_sum, local_sum_sq, local_count]).to( - device=advantages.device, dtype=torch.float32 - ) + stats = torch.stack([local_sum, local_sum_sq, local_count]).to(device=advantages.device, dtype=torch.float32) if dist.is_initialized(): dist.all_reduce(stats, op=dist.ReduceOp.SUM) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 4359db23..2bcb04b9 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -63,7 +63,6 @@ ) import asyncio - # ============================================================================ # Data Structures # ============================================================================ @@ -455,10 +454,8 @@ def __init__( elif isinstance(remote_rm_url, (list, tuple)): self.remote_rm_url = list(remote_rm_url) else: - raise TypeError( - f"remote_rm_url must be str, list, tuple, or None, got {type(remote_rm_url).__name__}" - ) - + raise TypeError(f"remote_rm_url must be str, list, tuple, or None, got {type(remote_rm_url).__name__}") + self.custom_reward_func = custom_reward_func self.reward_fn = reward_fn self.reward_fn_label_map = reward_fn_label_map or {} @@ -1487,9 +1484,9 @@ def _fetch_teacher_logprobs( Timer.start(' fetch_teacher_logprobs') for exp in experiences: - sequences = exp.sequences # [batch_size, seq_len] + sequences = exp.sequences # [batch_size, seq_len] attention_mask = exp.attention_mask # [batch_size, seq_len] - action_mask = exp.action_mask # [batch_size, num_tokens] + action_mask = exp.action_mask # [batch_size, num_tokens] # response_lengths must be int for slicing response_lengths = action_mask.sum(dim=-1).int().tolist() @@ -1538,9 +1535,7 @@ def _fetch_teacher_logprobs( exp.info["teacher_log_probs"] = aligned_teacher_lp except Exception as e: - raise RuntimeError( - f"Failed to fetch teacher log probs from {teacher_url}: {e}" - ) from e + raise RuntimeError(f"Failed to fetch teacher log probs from {teacher_url}: {e}") from e Timer.stop(' fetch_teacher_logprobs') From 651f6e0a76241ec94bfbe3c4227be5ad416e7836 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 17 Apr 2026 17:15:39 +0800 Subject: [PATCH 17/18] polish(pu): polish opd_utils.py --- examples/on_policy_distillation/README.md | 5 ++--- examples/on_policy_distillation/README_zh.md | 5 ++--- examples/on_policy_distillation/run_opd_qwen.sh | 3 --- examples/on_policy_distillation/start_training.sh | 3 --- examples/on_policy_distillation/test_opd.py | 10 +++------- lightrft/trainer/fast_exp_maker.py | 2 +- .../trainer/opd_utils.py | 2 +- lightrft/utils/cli_args.py | 2 +- 8 files changed, 10 insertions(+), 22 deletions(-) rename examples/on_policy_distillation/on_policy_distillation_reward.py => lightrft/trainer/opd_utils.py (99%) diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md index 0c030ea6..fc84d2d8 100644 --- a/examples/on_policy_distillation/README.md +++ b/examples/on_policy_distillation/README.md @@ -100,7 +100,7 @@ class OnPolicyDistillationCalculator(AdvantageCalculator): ### 2. Teacher Log Prob Fetcher -**File**: `examples/on_policy_distillation/on_policy_distillation_reward.py` +**File**: `lightrft/trainer/opd_utils.py` - Async HTTP requests to teacher server - Supports SGLang and vLLM response formats @@ -237,8 +237,7 @@ examples/on_policy_distillation/ ├── run_opd_qwen.sh # All-in-one training script (template) ├── start_teacher.sh # Teacher server only (template) ├── start_training.sh # Training only (template, requires TEACHER_URL) -├── test_opd.py # Unit tests -└── on_policy_distillation_reward.py # Teacher logprob fetcher +└── test_opd.py # Unit tests ``` ## References diff --git a/examples/on_policy_distillation/README_zh.md b/examples/on_policy_distillation/README_zh.md index f1b2c3b8..e9ed1012 100644 --- a/examples/on_policy_distillation/README_zh.md +++ b/examples/on_policy_distillation/README_zh.md @@ -100,7 +100,7 @@ class OnPolicyDistillationCalculator(AdvantageCalculator): ### 2. 教师对数概率获取器 -**文件**: `examples/on_policy_distillation/on_policy_distillation_reward.py` +**文件**: `lightrft/trainer/opd_utils.py` - 异步 HTTP 请求到教师服务器 - 支持 SGLang 和 vLLM 响应格式 @@ -237,8 +237,7 @@ examples/on_policy_distillation/ ├── run_opd_qwen.sh # 一体化训练脚本(模板) ├── start_teacher.sh # 仅启动教师服务器(模板) ├── start_training.sh # 仅启动训练(模板,需要 TEACHER_URL) -├── test_opd.py # 单元测试 -└── on_policy_distillation_reward.py # 教师对数概率获取器 +└── test_opd.py # 单元测试 ``` ## 参考资料 diff --git a/examples/on_policy_distillation/run_opd_qwen.sh b/examples/on_policy_distillation/run_opd_qwen.sh index 34fe74c7..29d031a0 100644 --- a/examples/on_policy_distillation/run_opd_qwen.sh +++ b/examples/on_policy_distillation/run_opd_qwen.sh @@ -222,8 +222,6 @@ torchrun \ --use_kl_loss \ --flash_attn \ --engine_type sglang \ - --enable_engine_sleep \ - --rm_use_engine \ --reward_pretrain "" \ --teacher_model_url "$TEACHER_URL" \ ${TASK_REWARD_FLAG} \ @@ -233,7 +231,6 @@ torchrun \ --train_batch_size ${TBS} \ --micro_rollout_batch_size ${MICRO_ROLLOUT_BS} \ --rollout_batch_size ${RBS} \ - --max_epochs 1 \ --num_episodes ${EPISODE} \ --lr_warmup_ratio ${WARMUP} \ --n_samples_per_prompt $N_SAMPLES \ diff --git a/examples/on_policy_distillation/start_training.sh b/examples/on_policy_distillation/start_training.sh index 2cfc9107..7f63e1a1 100644 --- a/examples/on_policy_distillation/start_training.sh +++ b/examples/on_policy_distillation/start_training.sh @@ -92,8 +92,6 @@ torchrun \ --use_kl_loss \ --flash_attn \ --engine_type sglang \ - --enable_engine_sleep \ - --rm_use_engine \ --reward_pretrain "" \ --teacher_model_url "$TEACHER_URL" \ ${TASK_REWARD_FLAG} \ @@ -103,7 +101,6 @@ torchrun \ --train_batch_size ${TBS} \ --micro_rollout_batch_size ${MICRO_ROLLOUT_BS} \ --rollout_batch_size ${RBS} \ - --max_epochs 1 \ --num_episodes ${EPISODE} \ --lr_warmup_ratio 0.03 \ --n_samples_per_prompt $N_SAMPLES \ diff --git a/examples/on_policy_distillation/test_opd.py b/examples/on_policy_distillation/test_opd.py index 76235977..705d5d73 100644 --- a/examples/on_policy_distillation/test_opd.py +++ b/examples/on_policy_distillation/test_opd.py @@ -249,9 +249,7 @@ def test_advantages_differ_with_rewards(self, mock_config, mock_experience): class TestTeacherLogprobExtraction: def test_sglang_format(self): """extract_teacher_logprobs handles SGLang format correctly.""" - from examples.on_policy_distillation.on_policy_distillation_reward import ( - extract_teacher_logprobs, - ) + from lightrft.trainer.opd_utils import extract_teacher_logprobs response = { "meta_info": { @@ -275,9 +273,7 @@ def test_sglang_format(self): def test_padding_for_long_response(self): """Pads to response_length when requested length > available logprobs.""" - from examples.on_policy_distillation.on_policy_distillation_reward import ( - extract_teacher_logprobs, - ) + from lightrft.trainer.opd_utils import extract_teacher_logprobs response = { "meta_info": { @@ -341,7 +337,7 @@ class _Exp: class TestRewardFunc: def test_reward_func_returns_zeros(self): """Placeholder reward_func returns zeros.""" - from examples.on_policy_distillation.on_policy_distillation_reward import reward_func + from lightrft.trainer.opd_utils import reward_func result = reward_func(queries=["q1", "q2", "q3"], prompts=["p1", "p2", "p3"]) assert isinstance(result, torch.Tensor) diff --git a/lightrft/trainer/fast_exp_maker.py b/lightrft/trainer/fast_exp_maker.py index 2bcb04b9..bbd13d6d 100644 --- a/lightrft/trainer/fast_exp_maker.py +++ b/lightrft/trainer/fast_exp_maker.py @@ -57,7 +57,7 @@ from .advantage_calculator import get_advantage_calculator, normalize_advantages_cross_batch from .image_utils import normalize_images, get_images_num from .video_utils import normalize_videos, get_videos_num -from examples.on_policy_distillation.on_policy_distillation_reward import ( +from lightrft.trainer.opd_utils import ( get_teacher_logprobs_for_experiences, get_teacher_logprobs_by_ids, ) diff --git a/examples/on_policy_distillation/on_policy_distillation_reward.py b/lightrft/trainer/opd_utils.py similarity index 99% rename from examples/on_policy_distillation/on_policy_distillation_reward.py rename to lightrft/trainer/opd_utils.py index 3f75c789..fd5680a3 100644 --- a/examples/on_policy_distillation/on_policy_distillation_reward.py +++ b/lightrft/trainer/opd_utils.py @@ -1,5 +1,5 @@ """ -On-Policy Distillation Reward Function for LightRFT +On-Policy Distillation Utility Functions for LightRFT This module provides functions to query a teacher model for log probabilities used in knowledge distillation during RL training. diff --git a/lightrft/utils/cli_args.py b/lightrft/utils/cli_args.py index db879ecf..c282b422 100644 --- a/lightrft/utils/cli_args.py +++ b/lightrft/utils/cli_args.py @@ -122,7 +122,7 @@ def add_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument( "--rm_use_engine", action="store_true", - default=False, + default=True, help="Use the high-throughput inference engine (e.g., vLLM) for the reward model during RLHF training. " "Can significantly speed up reward evaluation compared to standard forward passes.", ) From f0d9fb336bc790ad9ccbb0915d0bb328d308c917 Mon Sep 17 00:00:00 2001 From: puyuan1996 Date: Fri, 17 Apr 2026 17:24:52 +0800 Subject: [PATCH 18/18] polish(pu): polish opd reference in readme --- examples/on_policy_distillation/README.md | 8 +------- examples/on_policy_distillation/README_zh.md | 9 +-------- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/examples/on_policy_distillation/README.md b/examples/on_policy_distillation/README.md index fc84d2d8..452497d7 100644 --- a/examples/on_policy_distillation/README.md +++ b/examples/on_policy_distillation/README.md @@ -1,6 +1,6 @@ # On-Policy Distillation (OPD) for LightRFT -On-policy knowledge distillation enables smaller student models to learn from larger teacher models during reinforcement learning training. +[On-Policy Distillation](https://thinkingmachines.ai/blog/on-policy-distillation/) enables smaller student models to learn from larger teacher models during reinforcement learning training. ## Overview @@ -239,9 +239,3 @@ examples/on_policy_distillation/ ├── start_training.sh # Training only (template, requires TEACHER_URL) └── test_opd.py # Unit tests ``` - -## References - -- [LightRFT Documentation](../../README.md) -- [Advantage Calculator Source](../../lightrft/trainer/advantage_calculator.py) -- [Fast Experience Maker Source](../../lightrft/trainer/fast_exp_maker.py) diff --git a/examples/on_policy_distillation/README_zh.md b/examples/on_policy_distillation/README_zh.md index e9ed1012..93f4c76d 100644 --- a/examples/on_policy_distillation/README_zh.md +++ b/examples/on_policy_distillation/README_zh.md @@ -1,6 +1,6 @@ # LightRFT 在线策略蒸馏 (OPD) -在线策略知识蒸馏使小型学生模型能够在强化学习训练过程中从大型教师模型学习。 +在线策略蒸馏([On-Policy Distillation Blog](https://thinkingmachines.ai/blog/on-policy-distillation/))使小型学生模型能够在强化学习训练过程中从大型教师模型学习。 ## 概述 @@ -239,10 +239,3 @@ examples/on_policy_distillation/ ├── start_training.sh # 仅启动训练(模板,需要 TEACHER_URL) └── test_opd.py # 单元测试 ``` - -## 参考资料 - -- [LightRFT 文档](../../README.md) -- [优势值计算器源码](../../lightrft/trainer/advantage_calculator.py) -- [Fast Experience Maker 源码](../../lightrft/trainer/fast_exp_maker.py) -- [On-Policy Distillation Blog](https://thinkingmachines.ai/blog/on-policy-distillation/)