diff --git a/examples/gsm8k_geo3k/run_grpo_geo3k_lora_qwen2.5_vl_7b.sh b/examples/gsm8k_geo3k/run_grpo_geo3k_lora_qwen2.5_vl_7b.sh new file mode 100644 index 0000000..d0110c9 --- /dev/null +++ b/examples/gsm8k_geo3k/run_grpo_geo3k_lora_qwen2.5_vl_7b.sh @@ -0,0 +1,214 @@ +#!/bin/bash +# +# LightRFT Multi-Modal LoRA Training Script for the Geo3K Dataset. +# This script is designed for fine-tuning a large multi-modal model using the GRPO algorithm with LoRA. +# +################################################################################ +# LoRA Pipeline Introduction # +# # +# This pipeline is a LoRA-based (Low-Rank Adaptation) modification of the # +# standard GRPO training process. It enables parameter-efficient fine-tuning # +# of large-scale models (e.g., Qwen2.5-VL-7B) by updating only a small # +# percentage of weights. # +# # +# Main modifications for LoRA: # +# - Parameter Efficiency: Significantly reduces VRAM usage for 7B+ models. # +# - Targeted Adaptation: Adapts all linear layers to maintain necessary # +# model capacity. # +# - Memory Optimization: Fully compatible with FSDP and ZeRO strategies. # +################################################################################ +# +# Key Features: +# 1. Parameter-Efficient Fine-Tuning (LoRA): +# - Enables training large 7B+ multi-modal models on consumer-grade or limited GPU setups. +# - Configured with a high LoRA rank (128) and alpha (256) to capture complex reasoning patterns. +# 2. PURE RULE-BASED REWARD: +# - Eliminates the need for a separate reward model, reducing computational overhead. +# - Reward is calculated based on: +# - Format Correctness (10%): Adherence to the required ... and \boxed{} format. +# - Answer Accuracy (90%): Correctness of the final answer. +# + +################################################################################ +# Part 1: User Configuration # +# Please update the following paths and settings to match your environment. 
# +################################################################################ + +# --- Model and Dataset Paths --- +# Path to the base model. Can be a Hugging Face model name (e.g., "Qwen/Qwen2.5-VL-7B-Instruct") +# or a local directory containing the model files. +PATH_TO_YOUR_BASE_MODEL="/path/to/your/base/model" + +# Path to the preprocessed geo3k dataset. +# See "Usage Instructions" at the end of the script for preprocessing steps. +PATH_TO_YOUR_GEO3K_DATASET="/path/to/your/preprocessed/geo3k_dataset" + +# --- Experiment and Logging --- +# A descriptive name for your experiment. Used for organizing logs and checkpoints. +EXPERIMENT_NAME="lightrft-geo3k-grpo-lora-training" + +# Your Weights & Biases API key. +# Set to an empty string "" if you are not using W&B. +# It is strongly recommended to set this as an environment variable instead of hardcoding. +export WANDB_API_KEY="YOUR_WANDB_API_KEY" +export WANDB_PROJECT="LightRFT-Geo3K-Experiments" + + +################################################################################ +# Part 2: Training Hyperparameters # +# These settings control the training process. Adjust them as needed. # +################################################################################ + +# --- GRPO Settings --- +GROUP_METHOD="normal" +N_SAMPLES=8 # Number of samples per prompt for GRPO (must be > 1). +EPISODE=20 # Total number of training episodes. +WARMUP=0.03 # Learning rate warmup ratio. +RBS=128 # Rollout Batch Size. +TBS=128 # Training Batch Size. + +# --- Learning and Model Settings --- +KL=0.01 # KL divergence coefficient. +LR=1e-6 # Actor learning rate. +MAX_LENGTH=3072 # Max sequence length (prompt + generation). +PROMPT_MAX_LEN=1024 # Max length of the input prompt. +GENERATE_MAX_LEN=2048 # Max length of the generated response. +LORA_RANK=128 # LoRA rank. +LORA_ALPHA=256 # LoRA alpha. +TARGET_MODULES="all-linear" # Target modules for LoRA. 
+ +# --- Multi-modal Settings --- +limit_mm_image_per_prompt=10 # Max number of images per prompt. + +# --- Evaluation Settings --- +EVAL_SPLIT="test" # Dataset split to use for evaluation ("test", "validation"). +MAX_EVAL_SAMPLES=700 # Max samples for evaluation to keep it fast. +# Note: hiyouga/geometry3k dataset splits: train (2.1k), validation (300), test (601). + + +################################################################################ +# Part 3: Distributed Training Setup # +# Configure settings for multi-GPU and multi-node training. # +################################################################################ + +# --- PyTorch Distributed Environment Variables --- +export NNODES=1 # Number of nodes. +export GPUS_PER_NODE=8 # Number of GPUs per node. +export NODE_RANK=0 # Rank of the current node. +export MASTER_ADDR="localhost" # IP address of the master node (node 0). +export MASTER_PORT=20091 # Port for the master node. + +# --- vLLM/SGLang Engine Settings --- +ENGINE_TP=2 # Tensor parallelism size for the inference engine. Adjust based on your model and GPU setup. + + +################################################################################ +# Part 4: Execution and Logging # +# This section prepares and launches the training command. 
# +################################################################################ + +# --- Generate dynamic names and paths --- +current_time=$(date +"%Y%m%d_%H%M%S") +SAVE_MODEL_NAME="${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-lora_rank-${LORA_RANK}-alpha_${LORA_ALPHA}-${current_time}" +WANDB_RUN_NAME="${EXPERIMENT_NAME}-${current_time}" + +# --- Create directories for logs and checkpoints --- +mkdir -p "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" +mkdir -p "rft_logs/${EXPERIMENT_NAME}" + +# --- System and Environment Optimizations --- +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export NCCL_DEBUG="WARN" +export IGNORE_EOS=0 +export WANDB_MODE="offline" # Set to "online" for real-time W&B logging. + +# --- Set execution verbosity --- +set -x + + +################################################################################ +# Part 5: Main Training Command # +################################################################################ + +torchrun \ + --nnodes $NNODES \ + --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK \ + --master-port $MASTER_PORT \ + --master-addr $MASTER_ADDR \ + examples/gsm8k_geo3k/train_colocate.py \ + --pretrain "${PATH_TO_YOUR_BASE_MODEL}" \ + --save_trajectories \ + --print_replay_buffer_stats \ + --loss_agg_mode "seq-mean-token-mean" \ + --fsdp \ + --rm_use_engine \ + --mixed_mm_data \ + --reward_pretrain "{}" \ + --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \ + --micro_train_batch_size 4 \ + --train_batch_size ${TBS} \ + --micro_rollout_batch_size 8 \ + --rollout_batch_size ${RBS} \ + --advantage_estimator "group_norm" \ + --max_epochs 1 \ + --num_episodes ${EPISODE} \ + --lr_warmup_ratio ${WARMUP} \ + --n_samples_per_prompt $N_SAMPLES \ + --prompt_max_len $PROMPT_MAX_LEN \ + --generate_max_len $GENERATE_MAX_LEN \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate $LR \ + --lora_rank $LORA_RANK \ + --lora_alpha $LORA_ALPHA \ + 
--target_modules $TARGET_MODULES \ + --use_kl_loss \ + --init_kl_coef $KL \ + --kl_estimator "k3" \ + --prompt_data "${PATH_TO_YOUR_GEO3K_DATASET}" \ + --input_key "prompt" \ + --images_key "images" \ + --label_key "label" \ + --eval_steps 20 \ + --eval_split "${EVAL_SPLIT}" \ + --apply_chat_template \ + --flash_attn \ + --gradient_checkpointing \ + --save_steps 20 \ + --max_ckpt_num 2 \ + --engine_type sglang \ + --engine_mem_util 0.6 \ + --engine_tp_size $ENGINE_TP \ + --enable_engine_sleep \ + --system_prompt 'A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. The reasoning process should be enclosed within , followed directly by the final thought and answer, the final answer MUST BE put in \\boxed{}, like this: reasoning process here final thought and \\boxed{answer} here.' \ + --l2 1.0e-2 \ + --freeze_prefix \ + --adam_offload \ + --limit_mm_image_per_prompt $limit_mm_image_per_prompt \ + --use_wandb "${WANDB_API_KEY}" \ + --wandb_project "${WANDB_PROJECT}" \ + --wandb_run_name "${WANDB_RUN_NAME}" \ + 2>&1 | tee "rft_logs/${EXPERIMENT_NAME}/node${NODE_RANK}_${current_time}.log" + + +################################################################################ +# Usage Instructions # +# # +# Step 1: Preprocess the Geo3K Dataset # +# Run the provided preprocessing script to prepare the dataset. # +# Make sure the output directory matches `PATH_TO_YOUR_GEO3K_DATASET`. # +# # +# `python ./examples/data_preprocess/geo3k.py --local_save_dir /path/to/your/preprocessed/geo3k_dataset` +# # +# Step 2: Configure the Script # +# Edit "Part 1: User Configuration" at the top of this file. Set the paths # +# to your base model and the preprocessed dataset. # +# # +# Step 3: Run the Training Script # +# Execute this script from your shell. 
# +# # +# `bash /path/to/this/script.sh` # +# # +################################################################################ \ No newline at end of file diff --git a/examples/gsm8k_geo3k/test_geo3k_lora.py b/examples/gsm8k_geo3k/test_geo3k_lora.py new file mode 100644 index 0000000..34ed53f --- /dev/null +++ b/examples/gsm8k_geo3k/test_geo3k_lora.py @@ -0,0 +1,293 @@ +#!/usr/bin/env python3 +""" +LoRA Evaluation Script for Geo3K Dataset + +This script evaluates a LoRA-fine-tuned vision-language model (e.g. Qwen2.5-VL) +on mathematical reasoning benchmarks (e.g. Geo3K). It automatically merges the LoRA adapter +into the base weights and runs high-throughput inference using vLLM engine. + +Key Features: + - LoRA Merging: Combines parameter-efficient adapters with base vision-language models. + - Fast Inference: Utilizes vLLM for high-throughput batch generation. + - Consistency: Adopts the exact data pipeline and prompt mappings as training. + - Rule-Based Evaluation: Computes format correctness and mathematical accuracy rewards. + +Execution Flow: + 1. LoRA Merging: Loads the base model, merges the given LoRA adapter, and saves the unified model to `/merged_model` (skipped if this directory already exists). + 2. Data Loading: Prepares the Geo3K test dataset applying chat templates using LightRFT's `PromptDatasetVL`. + 3. vLLM Initialization: Loads the fully-merged weights into the vLLM engine for high-throughput batch generation. + 4. Generation: Generates responses for all prompts in the evaluation split. + 5. Scoring: Evaluates the generated answers against ground truth using `geo3k_accuracy_reward_fn` and `geo3k_format_reward_fn`. + 6. Reporting: Averages the rewards and saves the full prediction results to `/eval_results.json`. 
+ +Usage: + python test_geo3k_lora.py \ + --base_model /path/to/base_model \ + --lora_path /path/to/lora \ + --eval_data /path/to/data \ + --output_dir ./eval_output +""" + +import os +import argparse +import json +import torch +import sys +from tqdm import tqdm +from transformers import AutoModelForVision2Seq +from peft import PeftModel +from vllm import LLM, SamplingParams + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) +from lightrft.datasets import PromptDatasetVL +from lightrft.utils import blending_datasets, get_tokenizer_processor_vl +from examples.gsm8k_geo3k.reward_models_utils import geo3k_accuracy_reward_fn, geo3k_format_reward_fn + + +# ============================================================================ +# Configuration and Utilities +# ============================================================================ + +class MockStrategy: + """ + Mock Strategy to avoid deepspeed/sglang imports when just running inference. + """ + def __init__(self, args): + self.args = args + + def print(self, msg): + print(msg) + + def is_rank_0(self): + return True + +class MockArgs: + """ + Mock arguments class to simulate training arguments for data blending. + + This guarantees compatibility with LightRFT's dataset loading utilities + which expect a parsed arguments object containing specific configuration keys. + """ + + def __init__(self, seed: int = 42, **kwargs): + self.seed = seed + self.input_key = "prompt" + self.images_key = "images" + self.reference_key = "ground_truth" + self.label_key = "label" + self.apply_chat_template = True + self.system_prompt = 'A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. 
The reasoning process should be enclosed within , followed directly by the final thought and answer, the final answer MUST BE put in \\boxed{}, like this: reasoning process here final thought and \\boxed{answer} here.' + for k, v in kwargs.items(): + setattr(self, k, v) + + +def parse_args() -> argparse.Namespace: + """ + Parse command-line arguments for evaluation. + + :return: Parsed arguments + :rtype: argparse.Namespace + """ + parser = argparse.ArgumentParser(description="Merge LoRA and test on Geo3K dataset") + parser.add_argument("--base_model", type=str, required=True, help="Path to base model (e.g. Qwen2.5-VL-7B-Instruct)") + parser.add_argument("--lora_path", type=str, required=True, help="Path to LoRA weights (e.g. /path/to/global_step5_lora)") + parser.add_argument("--eval_data", type=str, required=True, help="Path to Geo3K data directory") + parser.add_argument("--eval_split", type=str, default="test", help="Dataset split to evaluate") + parser.add_argument("--output_dir", type=str, default="./geo3k_lora_eval_results", help="Directory to save merged model and evaluation results") + parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size for vLLM") + parser.add_argument("--max_samples", type=int, default=None, help="Max samples to evaluate") + parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for prompt") + parser.add_argument("--generate_max_len", type=int, default=2048, help="Max tokens to generate") + parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature") + parser.add_argument("--top_p", type=float, default=1.0, help="Top-p for sampling (match training by default)") + parser.add_argument( + "--system_prompt", + type=str, + default='A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. 
The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. The reasoning process should be enclosed within , followed directly by the final thought and answer, the final answer MUST BE put in \\boxed{}, like this: reasoning process here final thought and \\boxed{answer} here.', + help="System prompt used when applying chat template; should match training.", + ) + parser.add_argument("--seed", type=int, default=42, help="Random seed") + return parser.parse_args() + + +# ============================================================================ +# Merging and Evaluation Functions +# ============================================================================ + +def merge_lora_weights(base_model_path: str, lora_path: str, save_dir: str) -> str: + """ + Merge LoRA weights into the base model and save to disk. + + This function loads the base model and LoRA adapter, merges them, and saves + the resulting full weights along with the tokenizer and processor. This ensures + that vLLM can load a unified model directly for high-throughput inference. 
+ + :param base_model_path: Path to the underlying base model + :type base_model_path: str + :param lora_path: Path to the LoRA adapter weights + :type lora_path: str + :param save_dir: Directory to save the merged model + :type save_dir: str + :return: Path to the directory containing the merged model + :rtype: str + """ + print(f"Loading base model from {base_model_path}...") + model = AutoModelForVision2Seq.from_pretrained( + base_model_path, + torch_dtype=torch.bfloat16, + device_map="auto" + ) + + print(f"Loading LoRA from {lora_path}...") + model = PeftModel.from_pretrained(model, lora_path) + + print("Merging adapter...") + model = model.merge_and_unload() + + print(f"Saving merged model to {save_dir}...") + model.save_pretrained(save_dir, safe_serialization=True) + + tokenizer, processor = get_tokenizer_processor_vl(base_model_path, model, "left", use_fast=True) + tokenizer.save_pretrained(save_dir) + processor.save_pretrained(save_dir) + + print("Merged model saved successfully.") + return save_dir + + +def evaluate_model(model_path: str, args: argparse.Namespace) -> None: + """ + Evaluate the merged model using vLLM on the specified dataset. + + Loads configuration and data identically to the training pipeline, + generates responses using vLLM, and calculates accuracy and format rewards + using the predefined reward functions. 
+ + :param model_path: Path to the merged model weights + :type model_path: str + :param args: Parsed command-line arguments with evaluation settings + :type args: argparse.Namespace + """ + print(f"Loading tokenizer and processor from {model_path}...") + mock_args = MockArgs(seed=args.seed, system_prompt=args.system_prompt) + # create a dynamic object that PromptDatasetVL expects for strategy.args + mock_strategy = type('MockStrategyParams', (), {'args': mock_args, 'print': print, 'is_rank_0': lambda self: True})() + + tokenizer, processor = get_tokenizer_processor_vl(model_path, None, "left", use_fast=True) + + print(f"Loading evaluation data from {args.eval_data}, split='{args.eval_split}'...") + eval_data = blending_datasets( + args.eval_data, "1.0", mock_strategy, args.seed, return_eval=False, + train_split=args.eval_split + ) + + if args.max_samples: + eval_data = eval_data.select(range(min(args.max_samples, len(eval_data)))) + + eval_dataset = PromptDatasetVL( + dataset=eval_data, + tokenizer=tokenizer, + processor=processor, + max_length=args.prompt_max_len, + strategy=mock_strategy, + input_template=None + ) + + print(f"Evaluation dataset loaded: {len(eval_dataset)} samples") + + print("Initializing vLLM engine...") + engine = LLM( + model=model_path, + tensor_parallel_size=args.tp_size, + trust_remote_code=True, + max_model_len=4096, + limit_mm_per_prompt={"image": 10} + ) + + im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>") + stop_token_ids = [tokenizer.eos_token_id] + if im_end_id is not None: + stop_token_ids.append(im_end_id) + + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.generate_max_len, + stop_token_ids=stop_token_ids + ) + + vllm_inputs = [] + refs = [] + + for i in range(len(eval_dataset)): + prompt, images, reference, label = eval_dataset[i] + + inp = {"prompt": prompt} + if images and len(images) > 0: + inp["multi_modal_data"] = {"image": images} + + vllm_inputs.append(inp) + 
+ refs.append(reference) + + print("Running vLLM inference generation...") + outputs = engine.generate(vllm_inputs, sampling_params) + + results = [] + total_acc = 0 + total_fmt = 0 + + for i, output in enumerate(outputs): + generated_text = output.outputs[0].text + gt = refs[i] + + if isinstance(gt, list) and len(gt) > 0: + gt = gt[0] + + acc_reward = geo3k_accuracy_reward_fn(generated_text, str(gt)) + fmt_reward = geo3k_format_reward_fn(generated_text) + + total_acc += acc_reward + total_fmt += fmt_reward + + results.append({ + "prompt": vllm_inputs[i]["prompt"], + "generated": generated_text, + "ground_truth": gt, + "accuracy": acc_reward, + "format": fmt_reward, + }) + + avg_acc = total_acc / len(outputs) if len(outputs) > 0 else 0 + avg_fmt = total_fmt / len(outputs) if len(outputs) > 0 else 0 + + print(f"\n{'='*40}") + print(f"--- Final Evaluation Results ---") + print(f"Total Evaluated Samples: {len(outputs)}") + print(f"Average Accuracy Reward: {avg_acc:.4f} ({(avg_acc*100):.2f}%)") + print(f"Average Format Correctness: {avg_fmt:.4f} ({(avg_fmt*100):.2f}%)") + print(f"{'='*40}\n") + + output_json_path = os.path.join(args.output_dir, "eval_results.json") + with open(output_json_path, "w", encoding="utf-8") as f: + json.dump(results, f, indent=4, ensure_ascii=False) + print(f"Full results saved to {output_json_path}") + + +# ============================================================================ +# Main Entry Point +# ============================================================================ + +if __name__ == "__main__": + args = parse_args() + + os.makedirs(args.output_dir, exist_ok=True) + merged_model_dir = os.path.join(args.output_dir, "merged_model") + + if not os.path.exists(merged_model_dir): + print(f"Starting LoRA merge process...") + merge_lora_weights(args.base_model, args.lora_path, merged_model_dir) + else: + print(f"Merged model path '{merged_model_dir}' already exists. 
Skipping merging step...") + + evaluate_model(merged_model_dir, args) diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 818bcce..9fe724b 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -406,14 +406,14 @@ def train(args): strategy.save_model( ema_model if args.enable_ema else actor, tokenizer, - args.save_path, + os.path.join(args.save_path, "final_ckpt"), ) if args.critic_pretrain and args.save_value_network: strategy.save_model( critic, tokenizer, - args.save_path + "_critic", + os.path.join(args.save_path, "critic"), ) @@ -540,7 +540,7 @@ def train(args): parser.add_argument("--lora_rank", type=int, default=0) parser.add_argument("--lora_alpha", type=int, default=16) parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear") - parser.add_argument("--lora_dropout", type=float, default=0) + parser.add_argument("--lora_dropout", type=float, default=0.0) # Models parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path") diff --git a/lightrft/models/utils.py b/lightrft/models/utils.py index 2cff71d..28ba8f6 100644 --- a/lightrft/models/utils.py +++ b/lightrft/models/utils.py @@ -346,7 +346,7 @@ def apply_lora_configuration( model.enable_input_require_grads() # Auto-detect target modules if not provided - if target_modules is None: + if target_modules is None or "all-linear" in target_modules: target_modules = find_all_linear_modules(model, freeze_vision_tower) print("target_modules: ", target_modules) diff --git a/lightrft/strategy/fsdp/fsdpv2.py b/lightrft/strategy/fsdp/fsdpv2.py index 8067307..b8b2c99 100755 --- a/lightrft/strategy/fsdp/fsdpv2.py +++ b/lightrft/strategy/fsdp/fsdpv2.py @@ -12,7 +12,7 @@ import shutil from collections import defaultdict from contextlib import contextmanager -from typing import List, Tuple, Union +from typing import Any, List, Optional, Tuple, Union import torch from torch import 
distributed as dist @@ -39,6 +39,7 @@ import torch.distributed.checkpoint as dcp from torch.distributed.checkpoint.state_dict import ( + StateDictOptions, get_model_state_dict, get_optimizer_state_dict, set_model_state_dict, @@ -451,9 +452,10 @@ def unwrap_model(self, model) -> nn.Module: else: return model - def save_model(self, *args, **kwargs) -> None: + def save_model(self, model: nn.Module, tokenizer: Any, output_dir: str, **kwargs) -> None: """ - Save the model, its configuration, and tokenizer. + Save the model, its configuration, and tokenizer in Hugging Face format. + In LoRA mode, this saves the adapter. In full mode, this saves the full model. This method handles gathering and saving the full model parameters in a distributed setting. Only rank 0 process saves the model to disk. @@ -465,7 +467,38 @@ def save_model(self, *args, **kwargs) -> None: :type output_dir: str :param kwargs: Additional arguments to pass to model.save_pretrained """ - self.print("FSDP save model is not implemented, please use offline tools to convert to huggingface model") + # Determine the model to save (unwrap ActorVL or similar wrappers) + actual_model = model.model if hasattr(model, "model") else model + + # [Gather Configuration] + # In this environment, get_model_state_dict uses 'options' and 'StateDictOptions'. + # We want the full state dict collected to rank 0 (cpu_offload=True to avoid OOM). + opts = StateDictOptions(full_state_dict=True, cpu_offload=True) + + # get_model_state_dict is a collective call, must be called on ALL ranks. + # It internally interacts with FSDP modules to perform the All-Gather. 
+ state_dict = get_model_state_dict(actual_model, options=opts) + + if self.is_rank_0(): + os.makedirs(output_dir, exist_ok=True) + + # Use save_pretrained if available (handles HF and PEFT) + # PEFT's save_pretrained will automatically filter for adapter weights if state_dict is provided + if hasattr(actual_model, "save_pretrained"): + actual_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=True) + else: + # Fallback to torch.save + save_path = os.path.join(output_dir, "pytorch_model.bin") + torch.save(state_dict, save_path) + + # Save the tokenizer + if tokenizer is not None: + tokenizer.save_pretrained(output_dir) + + self.print(f"Hugging Face model saved to {output_dir}") + + # Ensure all ranks wait for rank 0 to finish saving + dist.barrier() def save_ckpt( self, @@ -537,7 +570,13 @@ def save_ckpt( dist.barrier() - fsdp_state_dict = get_model_state_dict(model) + # Determine the model to save (unwrap ActorVL or similar wrappers) + actual_model = model.model if is_actor(model) or hasattr(model, "model") else model + + # [Sharded State Dict] + # For DCP, we want the sharded state dict (full_state_dict=False by default) + opts = StateDictOptions(full_state_dict=False, cpu_offload=False) + fsdp_state_dict = get_model_state_dict(actual_model, options=opts) fp = os.path.join(save_dir, tag) os.makedirs(fp, exist_ok=True) @@ -552,7 +591,7 @@ def save_ckpt( torch.save(optimizer.state_dict(), opt_ckpt_path) else: # DCP can only be use with naive optimizer - fsdp_optim_state_dict = get_optimizer_state_dict(model, optimizer) + fsdp_optim_state_dict = get_optimizer_state_dict(actual_model, optimizer) dcp.save(fsdp_optim_state_dict, checkpoint_id=opt_base_dir) client_ckpt_path = os.path.join(fp, "client_state.pt") diff --git a/lightrft/strategy/utils/broadcast_utils.py b/lightrft/strategy/utils/broadcast_utils.py index 96df20d..efd8fc0 100755 --- a/lightrft/strategy/utils/broadcast_utils.py +++ b/lightrft/strategy/utils/broadcast_utils.py @@ -7,6 
+7,7 @@ parameters and efficiently transferring them to inference engines like vllm and sglang. """ +from collections import OrderedDict from typing import Any import deepspeed @@ -64,9 +65,22 @@ def _map_weight_name_for_sglang(self, name: str) -> str: :param name: Original weight name from training model :return: Mapped weight name for SGLang """ - # Step 1: Remove outermost "model." prefix if present - if name.startswith("model."): - name = name[6:] # Remove "model." + # Step 0: Handle PEFT/LoRA and other potential wrapping prefixes + # PEFT models have weights like base_model.model. + # We recursively strip "base_model.model." or "model." prefixes until we find + # core components like "visual" or "language_model" + while name.startswith("base_model.model.") or name.startswith("model."): + if name.startswith("base_model.model."): + name = name[len("base_model.model."):] + elif name.startswith("model."): + # We strip "model." and let the following steps handle it. + # If "language_model" follows, it will be added back as "model." + # for SGLang's expectation. + name = name[len("model."):] + + # PEFT models also rename original weights to include ".base_layer." 
+ # we need to strip this to match standard weight names + name = name.replace(".base_layer.", ".") # Step 2: Handle language_model prefix mapping if name.startswith("language_model."): @@ -103,97 +117,126 @@ def _deepspeed_broadcast(self): # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3): + # Both engines: LoRA adapters are already merged, no need to broadcast them + if ".lora_" in name: + continue + shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape kwargs = dict( name=name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params) ) if self.strategy.engine_type == "vllm": + if ".base_layer" in name or "base_model" in name: + raise NotImplementedError("vLLM name mapping is not supported for LoRA broadcasting yet.") self.inference_engine.llm_engine.model_executor.collective_rpc("update_weight", kwargs=kwargs) elif self.strategy.engine_type == "sglang": - if self.strategy.args.text_only: - # for LLM - self.inference_engine.update_weights_from_tensor( - name, param.data, flush_cache=(count == num_params) - ) - else: - # for VLM - # Map weight names from training model to SGLang format - # Training model: model.visual.xxx, model.language_model.xxx - # SGLang expects: visual.xxx, model.xxx (for language model), lm_head - sglang_name = self._map_weight_name_for_sglang(name) - self.inference_engine.update_weights_from_tensor( - sglang_name, param.data, flush_cache=(count == num_params) - ) + sglang_name = self._map_weight_name_for_sglang(name) + self.inference_engine.update_weights_from_tensor( + sglang_name, param.data, flush_cache=(count == num_params) + ) + else: + raise RuntimeError(f"Unsupported engine type: {self.strategy.engine_type}") def _fsdp_v2_broadcast(self): """ Broadcast model weights using PyTorch's FSDP v2. 
- This method uses the state_dict approach to gather and broadcast weights - for FSDP v2, which has a different API compared to v1. It handles DTensor - parameters by converting them to full tensors before broadcasting. - - :raises NotImplementedError: If sglang is used as the inference engine, which doesn't support FSDP v2 + Specialized for LoRA/PEFT: + Instead of calling merge_adapter() which fails on FSDP DTensors, + we manually gather base and lora weights and merge them on the fly. """ model = self.actor.model - count, num_params = 0, len(list(model.named_parameters())) + param_dict = OrderedDict(model.named_parameters()) + count, num_params = 0, len(param_dict) dst_dtype = torch.bfloat16 if self.strategy.args.bf16 else torch.float16 - for name, param in model.named_parameters(): - count += 1 # empty_cache at last param - param_on_device = param.to(get_current_device()) - if isinstance(param, DTensor): - full_param = param_on_device.full_tensor().to(dst_dtype) + + # Get PEFT config for scaling + is_peft = hasattr(model, "peft_config") + lora_config = model.peft_config.get("default") if is_peft else None + scaling = lora_config.lora_alpha / lora_config.r if lora_config else 1.0 + + for name, param in param_dict.items(): + count += 1 + + # Skip LoRA adapters directly, they will be merged when processing base_layer + if ".lora_" in name: + continue + + # Identify if this is a PEFT base layer + effective_name = name + full_weight = None + + if ".base_layer.weight" in name: + if self.strategy.engine_type == "vllm": + raise NotImplementedError("vLLM is not supported for FSDP LoRA broadcasting yet.") + # This is a LoRA-enabled layer + prefix = name.replace(".base_layer.weight", "") + lora_a_name = f"{prefix}.lora_A.default.weight" + lora_b_name = f"{prefix}.lora_B.default.weight" + + # Gather Base, LoRA A, and LoRA B + w_base = param.to(get_current_device()).full_tensor().to(torch.float32) + w_a = 
def broadcast_to_engine(self):
    """
    Push the current actor weights to the rollout/inference engine.

    Dispatches on the configured engine type ("vllm" or "sglang") and on the
    distributed backend: the FSDP v2 path delegates entirely to
    :meth:`_fsdp_v2_broadcast` (which handles LoRA merging internally), while
    the DeepSpeed path merges any PEFT adapters before broadcasting and
    guarantees they are unmerged afterwards, even on failure.

    :raises RuntimeError: If the configured engine type is not supported.
    """
    engine = self.strategy.engine_type
    if engine not in ("vllm", "sglang"):
        raise RuntimeError(f"Unsupported engine type: {self.strategy.engine_type}")

    if self.strategy.args.fsdp:
        # FSDP v2 merges LoRA weights itself during the broadcast.
        self._fsdp_v2_broadcast()
    else:
        # DeepSpeed path: a PEFT-wrapped model exposes merge_adapter, and the
        # engine must see the merged (base + LoRA) weights.
        has_adapter = hasattr(self.actor.model, "merge_adapter")
        if has_adapter:
            self.strategy.print("Merging LoRA adapters for weight synchronization...")
            self.actor.model.merge_adapter()
        try:
            self._deepspeed_broadcast()
        finally:
            # Always restore the unmerged state so training continues on the
            # adapter weights only.
            if has_adapter:
                self.strategy.print("Unmerging LoRA adapters after synchronization...")
                self.actor.model.unmerge_adapter()

    self.strategy.print("Finished weight broadcasting to inference engine.")
NaiveExperienceMakerVL, NaiveReplayBufferVL # noqa @@ -161,6 +162,7 @@ def __init__( self.reward_fn = reward_fn self.reward_fn_label_map = reward_fn_label_map self.reward_recipe = reward_recipe + self.is_lora = getattr(self.args, "lora_rank", 0) > 0 self.actor = actor self.critic = critic @@ -1138,10 +1140,11 @@ def _save_checkpoint(self, args, tag, client_states): :param client_states: Client state for checkpoint recovery. :type client_states: dict """ - if not self.disable_ds_ckpt: + ckpt_path = args.ckpt_path + if not self.disable_ds_ckpt and not self.is_lora: self.strategy.save_ckpt( self.actor.model, - os.path.join(args.ckpt_path, "_actor"), + os.path.join(ckpt_path, "_actor"), tag, args.max_ckpt_num, args.max_ckpt_mem, @@ -1149,11 +1152,24 @@ def _save_checkpoint(self, args, tag, client_states): ) if self.critic is not None: self.strategy.save_ckpt( - self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem + self.critic, os.path.join(ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem ) - if self.save_hf_ckpt: - save_path = os.path.join(args.ckpt_path, f"{tag}_hf") + # For LoRA, we ALWAYS save the HF adapter as it is much smaller and more convenient for deployment. 
def rotate_ckpt_dirs(
    ckpt_path: str,
    max_num: int,
    suffix: Optional[str] = None,
    strategy: Optional[Any] = None,
    label: str = "checkpoint",
) -> None:
    """
    Delete the oldest checkpoint directories under ``ckpt_path`` so that at
    most ``max_num - 1`` matching directories remain — leaving room for the
    one new checkpoint the caller is about to save.

    Directories are ordered by modification time (oldest first). Intended to
    be called on a single rank (e.g. rank 0) before writing a new checkpoint.

    :param ckpt_path: Parent checkpoint directory.
    :type ckpt_path: str
    :param max_num: Maximum number of checkpoints to keep, counting the one
        about to be saved. ``None`` or a non-positive value disables rotation.
    :type max_num: int
    :param suffix: Optional directory suffix filter (e.g., "_lora").
    :type suffix: Optional[str]
    :param strategy: Optional strategy object used for logging via ``print``.
    :type strategy: Optional[Any]
    :param label: Label used in log messages.
    :type label: str
    """
    if max_num is None or max_num <= 0 or not os.path.isdir(ckpt_path):
        return

    # Collect (mtime, path) for every matching subdirectory. Stat each path
    # defensively: another process may remove a directory between listdir()
    # and getmtime().
    candidates = []
    for entry in os.listdir(ckpt_path):
        full = os.path.join(ckpt_path, entry)
        if not os.path.isdir(full):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        try:
            mtime = os.path.getmtime(full)
        except OSError:
            continue  # vanished concurrently; nothing to rotate
        candidates.append((mtime, full))
    candidates.sort()  # oldest first

    # Keep max_num - 1 directories so the caller's new save brings the total
    # back up to max_num.
    excess = len(candidates) - (max_num - 1)
    for _, oldest_dir in candidates[: max(excess, 0)]:
        try:
            shutil.rmtree(oldest_dir)
        except FileNotFoundError:
            continue  # already gone — do not log a deletion that didn't happen
        if strategy is not None:
            strategy.print(f"Deleted oldest {label} {oldest_dir}")