diff --git a/examples/gsm8k_geo3k/run_grpo_geo3k_lora_qwen2.5_vl_7b.sh b/examples/gsm8k_geo3k/run_grpo_geo3k_lora_qwen2.5_vl_7b.sh
new file mode 100644
index 0000000..d0110c9
--- /dev/null
+++ b/examples/gsm8k_geo3k/run_grpo_geo3k_lora_qwen2.5_vl_7b.sh
@@ -0,0 +1,214 @@
+#!/bin/bash
+#
+# LightRFT Multi-Modal LoRA Training Script for the Geo3K Dataset.
+# This script is designed for fine-tuning a large multi-modal model using the GRPO algorithm with LoRA.
+#
+################################################################################
+# LoRA Pipeline Introduction #
+# #
+# This pipeline is a LoRA-based (Low-Rank Adaptation) modification of the #
+# standard GRPO training process. It enables parameter-efficient fine-tuning #
+# of large-scale models (e.g., Qwen2.5-VL-7B) by updating only a small #
+# percentage of weights. #
+# #
+# Main modifications for LoRA: #
+# - Parameter Efficiency: Significantly reduces VRAM usage for 7B+ models. #
+# - Targeted Adaptation: Adapts all linear layers to maintain necessary #
+# model capacity. #
+# - Memory Optimization: Fully compatible with FSDP and ZeRO strategies. #
+################################################################################
+#
+# Key Features:
+# 1. Parameter-Efficient Fine-Tuning (LoRA):
+# - Enables training large 7B+ multi-modal models on consumer-grade or limited GPU setups.
+# - Configured with a high LoRA rank (128) and alpha (256) to capture complex reasoning patterns.
+# 2. PURE RULE-BASED REWARD:
+# - Eliminates the need for a separate reward model, reducing computational overhead.
+# - Reward is calculated based on:
+#      - Format Correctness (10%): Adherence to the required <think>...</think> and \boxed{} format.
+# - Answer Accuracy (90%): Correctness of the final answer.
+#
+
+################################################################################
+# Part 1: User Configuration #
+# Please update the following paths and settings to match your environment. #
+################################################################################
+
+# --- Model and Dataset Paths ---
+# Path to the base model. Can be a Hugging Face model name (e.g., "Qwen/Qwen2.5-VL-7B-Instruct")
+# or a local directory containing the model files.
+PATH_TO_YOUR_BASE_MODEL="/path/to/your/base/model"
+
+# Path to the preprocessed geo3k dataset.
+# See "Usage Instructions" at the end of the script for preprocessing steps.
+PATH_TO_YOUR_GEO3K_DATASET="/path/to/your/preprocessed/geo3k_dataset"
+
+# --- Experiment and Logging ---
+# A descriptive name for your experiment. Used for organizing logs and checkpoints.
+EXPERIMENT_NAME="lightrft-geo3k-grpo-lora-training"
+
+# Your Weights & Biases API key.
+# Set to an empty string "" if you are not using W&B.
+# It is strongly recommended to set this as an environment variable instead of hardcoding.
+export WANDB_API_KEY="YOUR_WANDB_API_KEY"
+export WANDB_PROJECT="LightRFT-Geo3K-Experiments"
+
+
+################################################################################
+# Part 2: Training Hyperparameters #
+# These settings control the training process. Adjust them as needed. #
+################################################################################
+
+# --- GRPO Settings ---
+GROUP_METHOD="normal"
+N_SAMPLES=8 # Number of samples per prompt for GRPO (must be > 1).
+EPISODE=20 # Total number of training episodes.
+WARMUP=0.03 # Learning rate warmup ratio.
+RBS=128 # Rollout Batch Size.
+TBS=128 # Training Batch Size.
+
+# --- Learning and Model Settings ---
+KL=0.01 # KL divergence coefficient.
+LR=1e-6 # Actor learning rate.
+MAX_LENGTH=3072 # Max sequence length (prompt + generation).
+PROMPT_MAX_LEN=1024 # Max length of the input prompt.
+GENERATE_MAX_LEN=2048 # Max length of the generated response.
+LORA_RANK=128 # LoRA rank.
+LORA_ALPHA=256 # LoRA alpha.
+TARGET_MODULES="all-linear" # Target modules for LoRA.
+
+# --- Multi-modal Settings ---
+limit_mm_image_per_prompt=10 # Max number of images per prompt.
+
+# --- Evaluation Settings ---
+EVAL_SPLIT="test" # Dataset split to use for evaluation ("test", "validation").
+MAX_EVAL_SAMPLES=700 # Max samples for evaluation to keep it fast.
+# Note: hiyouga/geometry3k dataset splits: train (2.1k), validation (300), test (601).
+
+
+################################################################################
+# Part 3: Distributed Training Setup #
+# Configure settings for multi-GPU and multi-node training. #
+################################################################################
+
+# --- PyTorch Distributed Environment Variables ---
+export NNODES=1 # Number of nodes.
+export GPUS_PER_NODE=8 # Number of GPUs per node.
+export NODE_RANK=0 # Rank of the current node.
+export MASTER_ADDR="localhost" # IP address of the master node (node 0).
+export MASTER_PORT=20091 # Port for the master node.
+
+# --- vLLM/SGLang Engine Settings ---
+ENGINE_TP=2 # Tensor parallelism size for the inference engine. Adjust based on your model and GPU setup.
+
+
+################################################################################
+# Part 4: Execution and Logging #
+# This section prepares and launches the training command. #
+################################################################################
+
+# --- Generate dynamic names and paths ---
+current_time=$(date +"%Y%m%d_%H%M%S")
+SAVE_MODEL_NAME="${EXPERIMENT_NAME}-ep${EPISODE}-kl${KL}-lr${LR}-lora_rank-${LORA_RANK}-alpha_${LORA_ALPHA}-${current_time}"
+WANDB_RUN_NAME="${EXPERIMENT_NAME}-${current_time}"
+
+# --- Create directories for logs and checkpoints ---
+mkdir -p "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}"
+mkdir -p "rft_logs/${EXPERIMENT_NAME}"
+
+# --- System and Environment Optimizations ---
+export TORCH_NCCL_AVOID_RECORD_STREAMS=1
+export NCCL_DEBUG="WARN"
+export IGNORE_EOS=0
+export WANDB_MODE="offline" # Set to "online" for real-time W&B logging.
+
+# --- Set execution verbosity ---
+set -x
+
+
+################################################################################
+# Part 5: Main Training Command #
+################################################################################
+
+torchrun \
+ --nnodes $NNODES \
+ --nproc-per-node $GPUS_PER_NODE \
+ --node_rank $NODE_RANK \
+ --master-port $MASTER_PORT \
+ --master-addr $MASTER_ADDR \
+ examples/gsm8k_geo3k/train_colocate.py \
+ --pretrain "${PATH_TO_YOUR_BASE_MODEL}" \
+ --save_trajectories \
+ --print_replay_buffer_stats \
+ --loss_agg_mode "seq-mean-token-mean" \
+ --fsdp \
+ --rm_use_engine \
+ --mixed_mm_data \
+ --reward_pretrain "{}" \
+ --save_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \
+ --ckpt_path "results/${EXPERIMENT_NAME}/${SAVE_MODEL_NAME}" \
+ --micro_train_batch_size 4 \
+ --train_batch_size ${TBS} \
+ --micro_rollout_batch_size 8 \
+ --rollout_batch_size ${RBS} \
+ --advantage_estimator "group_norm" \
+ --max_epochs 1 \
+ --num_episodes ${EPISODE} \
+ --lr_warmup_ratio ${WARMUP} \
+ --n_samples_per_prompt $N_SAMPLES \
+ --prompt_max_len $PROMPT_MAX_LEN \
+ --generate_max_len $GENERATE_MAX_LEN \
+ --zero_stage 3 \
+ --bf16 \
+ --actor_learning_rate $LR \
+ --lora_rank $LORA_RANK \
+ --lora_alpha $LORA_ALPHA \
+ --target_modules $TARGET_MODULES \
+ --use_kl_loss \
+ --init_kl_coef $KL \
+ --kl_estimator "k3" \
+ --prompt_data "${PATH_TO_YOUR_GEO3K_DATASET}" \
+ --input_key "prompt" \
+ --images_key "images" \
+ --label_key "label" \
+ --eval_steps 20 \
+ --eval_split "${EVAL_SPLIT}" \
+ --apply_chat_template \
+ --flash_attn \
+ --gradient_checkpointing \
+ --save_steps 20 \
+ --max_ckpt_num 2 \
+ --engine_type sglang \
+ --engine_mem_util 0.6 \
+ --engine_tp_size $ENGINE_TP \
+ --enable_engine_sleep \
+    --system_prompt 'A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. The reasoning process should be enclosed within <think> </think>, followed directly by the final thought and answer, the final answer MUST BE put in \\boxed{}, like this: <think> reasoning process here </think> final thought and \\boxed{answer} here.' \
+ --l2 1.0e-2 \
+ --freeze_prefix \
+ --adam_offload \
+ --limit_mm_image_per_prompt $limit_mm_image_per_prompt \
+ --use_wandb "${WANDB_API_KEY}" \
+ --wandb_project "${WANDB_PROJECT}" \
+ --wandb_run_name "${WANDB_RUN_NAME}" \
+ 2>&1 | tee "rft_logs/${EXPERIMENT_NAME}/node${NODE_RANK}_${current_time}.log"
+
+
+################################################################################
+# Usage Instructions #
+# #
+# Step 1: Preprocess the Geo3K Dataset #
+# Run the provided preprocessing script to prepare the dataset. #
+# Make sure the output directory matches `PATH_TO_YOUR_GEO3K_DATASET`. #
+# #
+# `python ./examples/data_preprocess/geo3k.py --local_save_dir /path/to/your/preprocessed/geo3k_dataset`
+# #
+# Step 2: Configure the Script #
+# Edit "Part 1: User Configuration" at the top of this file. Set the paths #
+# to your base model and the preprocessed dataset. #
+# #
+# Step 3: Run the Training Script #
+# Execute this script from your shell. #
+# #
+# `bash /path/to/this/script.sh` #
+# #
+################################################################################
\ No newline at end of file
diff --git a/examples/gsm8k_geo3k/test_geo3k_lora.py b/examples/gsm8k_geo3k/test_geo3k_lora.py
new file mode 100644
index 0000000..34ed53f
--- /dev/null
+++ b/examples/gsm8k_geo3k/test_geo3k_lora.py
@@ -0,0 +1,293 @@
+#!/usr/bin/env python3
+"""
+LoRA Evaluation Script for Geo3K Dataset
+
+This script evaluates a LoRA-fine-tuned vision-language model (e.g. Qwen2.5-VL)
+on mathematical reasoning benchmarks (e.g. Geo3K). It automatically merges the LoRA adapter
+into the base weights and runs high-throughput inference using vLLM engine.
+
+Key Features:
+ - LoRA Merging: Combines parameter-efficient adapters with base vision-language models.
+ - Fast Inference: Utilizes vLLM for high-throughput batch generation.
+ - Consistency: Adopts the exact data pipeline and prompt mappings as training.
+ - Rule-Based Evaluation: Computes format correctness and mathematical accuracy rewards.
+
+Execution Flow:
+    1. LoRA Merging: Loads the base model, merges the given LoRA adapter, and saves the unified model to `<output_dir>/merged_model` (skipped if this directory already exists).
+ 2. Data Loading: Prepares the Geo3K test dataset applying chat templates using LightRFT's `PromptDatasetVL`.
+ 3. vLLM Initialization: Loads the fully-merged weights into the vLLM engine for high-throughput batch generation.
+ 4. Generation: Generates responses for all prompts in the evaluation split.
+ 5. Scoring: Evaluates the generated answers against ground truth using `geo3k_accuracy_reward_fn` and `geo3k_format_reward_fn`.
+    6. Reporting: Averages the rewards and saves the full prediction results to `<output_dir>/eval_results.json`.
+
+Usage:
+    python test_geo3k_lora.py \
+ --base_model /path/to/base_model \
+ --lora_path /path/to/lora \
+ --eval_data /path/to/data \
+ --output_dir ./eval_output
+"""
+
+import os
+import argparse
+import json
+import torch
+import sys
+from tqdm import tqdm
+from transformers import AutoModelForVision2Seq
+from peft import PeftModel
+from vllm import LLM, SamplingParams
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
+from lightrft.datasets import PromptDatasetVL
+from lightrft.utils import blending_datasets, get_tokenizer_processor_vl
+from examples.gsm8k_geo3k.reward_models_utils import geo3k_accuracy_reward_fn, geo3k_format_reward_fn
+
+
+# ============================================================================
+# Configuration and Utilities
+# ============================================================================
+
+class MockStrategy:
+ """
+ Mock Strategy to avoid deepspeed/sglang imports when just running inference.
+ """
+ def __init__(self, args):
+ self.args = args
+
+ def print(self, msg):
+ print(msg)
+
+ def is_rank_0(self):
+ return True
+
+class MockArgs:
+ """
+ Mock arguments class to simulate training arguments for data blending.
+
+ This guarantees compatibility with LightRFT's dataset loading utilities
+ which expect a parsed arguments object containing specific configuration keys.
+ """
+
+ def __init__(self, seed: int = 42, **kwargs):
+ self.seed = seed
+ self.input_key = "prompt"
+ self.images_key = "images"
+ self.reference_key = "ground_truth"
+ self.label_key = "label"
+ self.apply_chat_template = True
+        self.system_prompt = 'A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. The reasoning process should be enclosed within <think> </think>, followed directly by the final thought and answer, the final answer MUST BE put in \\boxed{}, like this: <think> reasoning process here </think> final thought and \\boxed{answer} here.'
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+
+def parse_args() -> argparse.Namespace:
+ """
+ Parse command-line arguments for evaluation.
+
+ :return: Parsed arguments
+ :rtype: argparse.Namespace
+ """
+ parser = argparse.ArgumentParser(description="Merge LoRA and test on Geo3K dataset")
+ parser.add_argument("--base_model", type=str, required=True, help="Path to base model (e.g. Qwen2.5-VL-7B-Instruct)")
+ parser.add_argument("--lora_path", type=str, required=True, help="Path to LoRA weights (e.g. /path/to/global_step5_lora)")
+ parser.add_argument("--eval_data", type=str, required=True, help="Path to Geo3K data directory")
+ parser.add_argument("--eval_split", type=str, default="test", help="Dataset split to evaluate")
+ parser.add_argument("--output_dir", type=str, default="./geo3k_lora_eval_results", help="Directory to save merged model and evaluation results")
+ parser.add_argument("--tp_size", type=int, default=1, help="Tensor parallel size for vLLM")
+ parser.add_argument("--max_samples", type=int, default=None, help="Max samples to evaluate")
+ parser.add_argument("--prompt_max_len", type=int, default=1024, help="Max tokens for prompt")
+ parser.add_argument("--generate_max_len", type=int, default=2048, help="Max tokens to generate")
+ parser.add_argument("--temperature", type=float, default=0.0, help="Sampling temperature")
+ parser.add_argument("--top_p", type=float, default=1.0, help="Top-p for sampling (match training by default)")
+ parser.add_argument(
+ "--system_prompt",
+ type=str,
+        default='A conversation between the User and Assistant. The User asks a question, and the Assistant provides a solution. The Assistant first thinks through the reasoning process internally with self-reflection and consistency check and then gives the final analysis and answer. The reasoning process should be enclosed within <think> </think>, followed directly by the final thought and answer, the final answer MUST BE put in \\boxed{}, like this: <think> reasoning process here </think> final thought and \\boxed{answer} here.',
+ help="System prompt used when applying chat template; should match training.",
+ )
+ parser.add_argument("--seed", type=int, default=42, help="Random seed")
+ return parser.parse_args()
+
+
+# ============================================================================
+# Merging and Evaluation Functions
+# ============================================================================
+
+def merge_lora_weights(base_model_path: str, lora_path: str, save_dir: str) -> str:
+ """
+ Merge LoRA weights into the base model and save to disk.
+
+ This function loads the base model and LoRA adapter, merges them, and saves
+ the resulting full weights along with the tokenizer and processor. This ensures
+ that vLLM can load a unified model directly for high-throughput inference.
+
+ :param base_model_path: Path to the underlying base model
+ :type base_model_path: str
+ :param lora_path: Path to the LoRA adapter weights
+ :type lora_path: str
+ :param save_dir: Directory to save the merged model
+ :type save_dir: str
+ :return: Path to the directory containing the merged model
+ :rtype: str
+ """
+ print(f"Loading base model from {base_model_path}...")
+ model = AutoModelForVision2Seq.from_pretrained(
+ base_model_path,
+ torch_dtype=torch.bfloat16,
+ device_map="auto"
+ )
+
+ print(f"Loading LoRA from {lora_path}...")
+ model = PeftModel.from_pretrained(model, lora_path)
+
+ print("Merging adapter...")
+ model = model.merge_and_unload()
+
+ print(f"Saving merged model to {save_dir}...")
+ model.save_pretrained(save_dir, safe_serialization=True)
+
+ tokenizer, processor = get_tokenizer_processor_vl(base_model_path, model, "left", use_fast=True)
+ tokenizer.save_pretrained(save_dir)
+ processor.save_pretrained(save_dir)
+
+ print("Merged model saved successfully.")
+ return save_dir
+
+
+def evaluate_model(model_path: str, args: argparse.Namespace) -> None:
+ """
+ Evaluate the merged model using vLLM on the specified dataset.
+
+ Loads configuration and data identically to the training pipeline,
+ generates responses using vLLM, and calculates accuracy and format rewards
+ using the predefined reward functions.
+
+ :param model_path: Path to the merged model weights
+ :type model_path: str
+ :param args: Parsed command-line arguments with evaluation settings
+ :type args: argparse.Namespace
+ """
+ print(f"Loading tokenizer and processor from {model_path}...")
+ mock_args = MockArgs(seed=args.seed, system_prompt=args.system_prompt)
+ # create a dynamic object that PromptDatasetVL expects for strategy.args
+ mock_strategy = type('MockStrategyParams', (), {'args': mock_args, 'print': print, 'is_rank_0': lambda self: True})()
+
+ tokenizer, processor = get_tokenizer_processor_vl(model_path, None, "left", use_fast=True)
+
+ print(f"Loading evaluation data from {args.eval_data}, split='{args.eval_split}'...")
+ eval_data = blending_datasets(
+ args.eval_data, "1.0", mock_strategy, args.seed, return_eval=False,
+ train_split=args.eval_split
+ )
+
+ if args.max_samples:
+ eval_data = eval_data.select(range(min(args.max_samples, len(eval_data))))
+
+ eval_dataset = PromptDatasetVL(
+ dataset=eval_data,
+ tokenizer=tokenizer,
+ processor=processor,
+ max_length=args.prompt_max_len,
+ strategy=mock_strategy,
+ input_template=None
+ )
+
+ print(f"Evaluation dataset loaded: {len(eval_dataset)} samples")
+
+ print("Initializing vLLM engine...")
+ engine = LLM(
+ model=model_path,
+ tensor_parallel_size=args.tp_size,
+ trust_remote_code=True,
+ max_model_len=4096,
+ limit_mm_per_prompt={"image": 10}
+ )
+
+ im_end_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
+ stop_token_ids = [tokenizer.eos_token_id]
+ if im_end_id is not None:
+ stop_token_ids.append(im_end_id)
+
+ sampling_params = SamplingParams(
+ temperature=args.temperature,
+ top_p=args.top_p,
+ max_tokens=args.generate_max_len,
+ stop_token_ids=stop_token_ids
+ )
+
+ vllm_inputs = []
+ refs = []
+
+ for i in range(len(eval_dataset)):
+ prompt, images, reference, label = eval_dataset[i]
+
+ inp = {"prompt": prompt}
+ if images and len(images) > 0:
+ inp["multi_modal_data"] = {"image": images}
+
+ vllm_inputs.append(inp)
+
+ refs.append(reference)
+
+ print("Running vLLM inference generation...")
+ outputs = engine.generate(vllm_inputs, sampling_params)
+
+ results = []
+ total_acc = 0
+ total_fmt = 0
+
+ for i, output in enumerate(outputs):
+ generated_text = output.outputs[0].text
+ gt = refs[i]
+
+ if isinstance(gt, list) and len(gt) > 0:
+ gt = gt[0]
+
+ acc_reward = geo3k_accuracy_reward_fn(generated_text, str(gt))
+ fmt_reward = geo3k_format_reward_fn(generated_text)
+
+ total_acc += acc_reward
+ total_fmt += fmt_reward
+
+ results.append({
+ "prompt": vllm_inputs[i]["prompt"],
+ "generated": generated_text,
+ "ground_truth": gt,
+ "accuracy": acc_reward,
+ "format": fmt_reward,
+ })
+
+ avg_acc = total_acc / len(outputs) if len(outputs) > 0 else 0
+ avg_fmt = total_fmt / len(outputs) if len(outputs) > 0 else 0
+
+ print(f"\n{'='*40}")
+ print(f"--- Final Evaluation Results ---")
+ print(f"Total Evaluated Samples: {len(outputs)}")
+ print(f"Average Accuracy Reward: {avg_acc:.4f} ({(avg_acc*100):.2f}%)")
+ print(f"Average Format Correctness: {avg_fmt:.4f} ({(avg_fmt*100):.2f}%)")
+ print(f"{'='*40}\n")
+
+ output_json_path = os.path.join(args.output_dir, "eval_results.json")
+ with open(output_json_path, "w", encoding="utf-8") as f:
+ json.dump(results, f, indent=4, ensure_ascii=False)
+ print(f"Full results saved to {output_json_path}")
+
+
+# ============================================================================
+# Main Entry Point
+# ============================================================================
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ merged_model_dir = os.path.join(args.output_dir, "merged_model")
+
+ if not os.path.exists(merged_model_dir):
+ print(f"Starting LoRA merge process...")
+ merge_lora_weights(args.base_model, args.lora_path, merged_model_dir)
+ else:
+ print(f"Merged model path '{merged_model_dir}' already exists. Skipping merging step...")
+
+ evaluate_model(merged_model_dir, args)
diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py
index 818bcce..9fe724b 100644
--- a/examples/gsm8k_geo3k/train_colocate.py
+++ b/examples/gsm8k_geo3k/train_colocate.py
@@ -406,14 +406,14 @@ def train(args):
strategy.save_model(
ema_model if args.enable_ema else actor,
tokenizer,
- args.save_path,
+ os.path.join(args.save_path, "final_ckpt"),
)
if args.critic_pretrain and args.save_value_network:
strategy.save_model(
critic,
tokenizer,
- args.save_path + "_critic",
+ os.path.join(args.save_path, "critic"),
)
@@ -540,7 +540,7 @@ def train(args):
parser.add_argument("--lora_rank", type=int, default=0)
parser.add_argument("--lora_alpha", type=int, default=16)
parser.add_argument("--target_modules", type=str, nargs="*", default="all-linear")
- parser.add_argument("--lora_dropout", type=float, default=0)
+ parser.add_argument("--lora_dropout", type=float, default=0.0)
# Models
parser.add_argument("--pretrain", type=str, default=None, help="HF model name or path")
diff --git a/lightrft/models/utils.py b/lightrft/models/utils.py
index 2cff71d..28ba8f6 100644
--- a/lightrft/models/utils.py
+++ b/lightrft/models/utils.py
@@ -346,7 +346,7 @@ def apply_lora_configuration(
model.enable_input_require_grads()
# Auto-detect target modules if not provided
- if target_modules is None:
+ if target_modules is None or "all-linear" in target_modules:
target_modules = find_all_linear_modules(model, freeze_vision_tower)
print("target_modules: ", target_modules)
diff --git a/lightrft/strategy/fsdp/fsdpv2.py b/lightrft/strategy/fsdp/fsdpv2.py
index 8067307..b8b2c99 100755
--- a/lightrft/strategy/fsdp/fsdpv2.py
+++ b/lightrft/strategy/fsdp/fsdpv2.py
@@ -12,7 +12,7 @@
import shutil
from collections import defaultdict
from contextlib import contextmanager
-from typing import List, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
import torch
from torch import distributed as dist
@@ -39,6 +39,7 @@
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
+ StateDictOptions,
get_model_state_dict,
get_optimizer_state_dict,
set_model_state_dict,
@@ -451,9 +452,10 @@ def unwrap_model(self, model) -> nn.Module:
else:
return model
- def save_model(self, *args, **kwargs) -> None:
+ def save_model(self, model: nn.Module, tokenizer: Any, output_dir: str, **kwargs) -> None:
"""
- Save the model, its configuration, and tokenizer.
+ Save the model, its configuration, and tokenizer in Hugging Face format.
+ In LoRA mode, this saves the adapter. In full mode, this saves the full model.
This method handles gathering and saving the full model parameters in a distributed setting.
Only rank 0 process saves the model to disk.
@@ -465,7 +467,38 @@ def save_model(self, *args, **kwargs) -> None:
:type output_dir: str
:param kwargs: Additional arguments to pass to model.save_pretrained
"""
- self.print("FSDP save model is not implemented, please use offline tools to convert to huggingface model")
+ # Determine the model to save (unwrap ActorVL or similar wrappers)
+ actual_model = model.model if hasattr(model, "model") else model
+
+ # [Gather Configuration]
+ # In this environment, get_model_state_dict uses 'options' and 'StateDictOptions'.
+ # We want the full state dict collected to rank 0 (cpu_offload=True to avoid OOM).
+ opts = StateDictOptions(full_state_dict=True, cpu_offload=True)
+
+ # get_model_state_dict is a collective call, must be called on ALL ranks.
+ # It internally interacts with FSDP modules to perform the All-Gather.
+ state_dict = get_model_state_dict(actual_model, options=opts)
+
+ if self.is_rank_0():
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Use save_pretrained if available (handles HF and PEFT)
+ # PEFT's save_pretrained will automatically filter for adapter weights if state_dict is provided
+ if hasattr(actual_model, "save_pretrained"):
+ actual_model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=True)
+ else:
+ # Fallback to torch.save
+ save_path = os.path.join(output_dir, "pytorch_model.bin")
+ torch.save(state_dict, save_path)
+
+ # Save the tokenizer
+ if tokenizer is not None:
+ tokenizer.save_pretrained(output_dir)
+
+ self.print(f"Hugging Face model saved to {output_dir}")
+
+ # Ensure all ranks wait for rank 0 to finish saving
+ dist.barrier()
def save_ckpt(
self,
@@ -537,7 +570,13 @@ def save_ckpt(
dist.barrier()
- fsdp_state_dict = get_model_state_dict(model)
+ # Determine the model to save (unwrap ActorVL or similar wrappers)
+ actual_model = model.model if is_actor(model) or hasattr(model, "model") else model
+
+ # [Sharded State Dict]
+ # For DCP, we want the sharded state dict (full_state_dict=False by default)
+ opts = StateDictOptions(full_state_dict=False, cpu_offload=False)
+ fsdp_state_dict = get_model_state_dict(actual_model, options=opts)
fp = os.path.join(save_dir, tag)
os.makedirs(fp, exist_ok=True)
@@ -552,7 +591,7 @@ def save_ckpt(
torch.save(optimizer.state_dict(), opt_ckpt_path)
else:
# DCP can only be use with naive optimizer
- fsdp_optim_state_dict = get_optimizer_state_dict(model, optimizer)
+ fsdp_optim_state_dict = get_optimizer_state_dict(actual_model, optimizer)
dcp.save(fsdp_optim_state_dict, checkpoint_id=opt_base_dir)
client_ckpt_path = os.path.join(fp, "client_state.pt")
diff --git a/lightrft/strategy/utils/broadcast_utils.py b/lightrft/strategy/utils/broadcast_utils.py
index 96df20d..efd8fc0 100755
--- a/lightrft/strategy/utils/broadcast_utils.py
+++ b/lightrft/strategy/utils/broadcast_utils.py
@@ -7,6 +7,7 @@
parameters and efficiently transferring them to inference engines like vllm and sglang.
"""
+from collections import OrderedDict
from typing import Any
import deepspeed
@@ -64,9 +65,22 @@ def _map_weight_name_for_sglang(self, name: str) -> str:
:param name: Original weight name from training model
:return: Mapped weight name for SGLang
"""
- # Step 1: Remove outermost "model." prefix if present
- if name.startswith("model."):
- name = name[6:] # Remove "model."
+ # Step 0: Handle PEFT/LoRA and other potential wrapping prefixes
+ # PEFT models have weights like base_model.model.
+ # We recursively strip "base_model.model." or "model." prefixes until we find
+ # core components like "visual" or "language_model"
+ while name.startswith("base_model.model.") or name.startswith("model."):
+ if name.startswith("base_model.model."):
+ name = name[len("base_model.model."):]
+ elif name.startswith("model."):
+ # We strip "model." and let the following steps handle it.
+ # If "language_model" follows, it will be added back as "model."
+ # for SGLang's expectation.
+ name = name[len("model."):]
+
+ # PEFT models also rename original weights to include ".base_layer."
+ # we need to strip this to match standard weight names
+ name = name.replace(".base_layer.", ".")
# Step 2: Handle language_model prefix mapping
if name.startswith("language_model."):
@@ -103,97 +117,126 @@ def _deepspeed_broadcast(self):
# For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0
with deepspeed.zero.GatheredParameters([param], enabled=self.strategy.args.zero_stage == 3):
+ # Both engines: LoRA adapters are already merged, no need to broadcast them
+ if ".lora_" in name:
+ continue
+
shape = param.shape if self.strategy.args.zero_stage != 3 else param.ds_shape
kwargs = dict(
name=name, dtype=param.dtype, shape=shape, weight=param.data, empty_cache=(count == num_params)
)
if self.strategy.engine_type == "vllm":
+ if ".base_layer" in name or "base_model" in name:
+ raise NotImplementedError("vLLM name mapping is not supported for LoRA broadcasting yet.")
self.inference_engine.llm_engine.model_executor.collective_rpc("update_weight", kwargs=kwargs)
elif self.strategy.engine_type == "sglang":
- if self.strategy.args.text_only:
- # for LLM
- self.inference_engine.update_weights_from_tensor(
- name, param.data, flush_cache=(count == num_params)
- )
- else:
- # for VLM
- # Map weight names from training model to SGLang format
- # Training model: model.visual.xxx, model.language_model.xxx
- # SGLang expects: visual.xxx, model.xxx (for language model), lm_head
- sglang_name = self._map_weight_name_for_sglang(name)
- self.inference_engine.update_weights_from_tensor(
- sglang_name, param.data, flush_cache=(count == num_params)
- )
+ sglang_name = self._map_weight_name_for_sglang(name)
+ self.inference_engine.update_weights_from_tensor(
+ sglang_name, param.data, flush_cache=(count == num_params)
+ )
+ else:
+ raise RuntimeError(f"Unsupported engine type: {self.strategy.engine_type}")
def _fsdp_v2_broadcast(self):
"""
Broadcast model weights using PyTorch's FSDP v2.
- This method uses the state_dict approach to gather and broadcast weights
- for FSDP v2, which has a different API compared to v1. It handles DTensor
- parameters by converting them to full tensors before broadcasting.
-
- :raises NotImplementedError: If sglang is used as the inference engine, which doesn't support FSDP v2
+ Specialized for LoRA/PEFT:
+ Instead of calling merge_adapter() which fails on FSDP DTensors,
+ we manually gather base and lora weights and merge them on the fly.
"""
model = self.actor.model
- count, num_params = 0, len(list(model.named_parameters()))
+ param_dict = OrderedDict(model.named_parameters())
+ count, num_params = 0, len(param_dict)
dst_dtype = torch.bfloat16 if self.strategy.args.bf16 else torch.float16
- for name, param in model.named_parameters():
- count += 1 # empty_cache at last param
- param_on_device = param.to(get_current_device())
- if isinstance(param, DTensor):
- full_param = param_on_device.full_tensor().to(dst_dtype)
+
+ # Get PEFT config for scaling
+ is_peft = hasattr(model, "peft_config")
+ lora_config = model.peft_config.get("default") if is_peft else None
+ scaling = lora_config.lora_alpha / lora_config.r if lora_config else 1.0
+
+ for name, param in param_dict.items():
+ count += 1
+
+ # Skip LoRA adapters directly, they will be merged when processing base_layer
+ if ".lora_" in name:
+ continue
+
+ # Identify if this is a PEFT base layer
+ effective_name = name
+ full_weight = None
+
+ if ".base_layer.weight" in name:
+ if self.strategy.engine_type == "vllm":
+ raise NotImplementedError("vLLM is not supported for FSDP LoRA broadcasting yet.")
+ # This is a LoRA-enabled layer
+ prefix = name.replace(".base_layer.weight", "")
+ lora_a_name = f"{prefix}.lora_A.default.weight"
+ lora_b_name = f"{prefix}.lora_B.default.weight"
+
+ # Gather Base, LoRA A, and LoRA B
+ w_base = param.to(get_current_device()).full_tensor().to(torch.float32)
+ w_a = param_dict[lora_a_name].to(get_current_device()).full_tensor().to(torch.float32)
+ w_b = param_dict[lora_b_name].to(get_current_device()).full_tensor().to(torch.float32)
+
+ # Merge: W = W + scale * (B @ A)
+ full_weight = (w_base + scaling * (w_b @ w_a)).to(dst_dtype)
+
+ # Clean up intermediate huge gathered tensors
+ del w_base, w_a, w_b
else:
- full_param = param_on_device.to(dst_dtype)
+ # Normal layer (e.g. vision tower or non-lora layer)
+ param_on_device = param.to(get_current_device())
+ if isinstance(param, DTensor):
+ full_weight = param_on_device.full_tensor().to(dst_dtype)
+ else:
+ full_weight = param_on_device.to(dst_dtype)
+ del param_on_device
+ # Broadcast to engine
if self.strategy.engine_type == "vllm":
+                # TODO: map weight name for vllm
kwargs = dict(
name=name,
- dtype=full_param.dtype,
- shape=full_param.shape,
- weight=full_param.data,
+ dtype=full_weight.dtype,
+ shape=full_weight.shape,
+ weight=full_weight.data,
empty_cache=(count == num_params),
)
self.inference_engine.llm_engine.model_executor.collective_rpc("update_weight", kwargs=kwargs)
elif self.strategy.engine_type == "sglang":
- if self.strategy.args.text_only:
- # for LLM
- self.inference_engine.update_weights_from_tensor(
- name, param.data, flush_cache=(count == num_params)
- )
- else:
- # for VLM
- # Map weight names from training model to SGLang format
- # Training model: model.visual.xxx, model.language_model.xxx
- # SGLang expects: visual.xxx, model.xxx (for language model), lm_head
- sglang_name = self._map_weight_name_for_sglang(name)
- self.inference_engine.update_weights_from_tensor(
- sglang_name, param.data, flush_cache=(count == num_params)
- )
+ sglang_name = self._map_weight_name_for_sglang(effective_name)
+ self.inference_engine.update_weights_from_tensor(
+ sglang_name, full_weight.data, flush_cache=(count == num_params)
+ )
- del param_on_device
- del full_param
+ del full_weight
def broadcast_to_engine(self):
"""
Broadcast model weights to the inference engine.
This method selects the appropriate broadcasting strategy based on the
- distributed training configuration (DeepSpeed, FSDP v2). It automatically
- detects whether to use DeepSpeed or FSDP broadcasting based on the strategy
- configuration.
-
- Example::
-
- # Initialize the broadcast manager
- broadcast_manager = BroadcastManager(actor_model, strategy, inference_engine)
-
- # Broadcast weights to inference engine
- broadcast_manager.broadcast_to_engine()
-
- :raises NotImplementedError: If an unsupported configuration is used
+ distributed training configuration (DeepSpeed, FSDP v2).
"""
- if self.strategy.args.fsdp:
- self._fsdp_v2_broadcast()
+ if self.strategy.engine_type in ("vllm", "sglang"):
+ if self.strategy.args.fsdp:
+ # FSDP handles merging manually inside _fsdp_v2_broadcast
+ self._fsdp_v2_broadcast()
+ else:
+ # DeepSpeed path
+ is_peft = hasattr(self.actor.model, "merge_adapter")
+ if is_peft:
+ self.strategy.print("Merging LoRA adapters for weight synchronization...")
+ self.actor.model.merge_adapter()
+
+ try:
+ self._deepspeed_broadcast()
+ finally:
+ if is_peft:
+ self.strategy.print("Unmerging LoRA adapters after synchronization...")
+ self.actor.model.unmerge_adapter()
else:
- self._deepspeed_broadcast()
+ raise RuntimeError(f"Unsupported engine type: {self.strategy.engine_type}")
+
+ self.strategy.print("Finished weight broadcasting to inference engine.")
diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py
index 4584129..d9a3ef0 100644
--- a/lightrft/trainer/ppo_trainer_vl.py
+++ b/lightrft/trainer/ppo_trainer_vl.py
@@ -15,6 +15,7 @@
from lightrft.models.actor_modality import ActorModality, get_supported_parameters
from lightrft.models.utils import masked_mean, unpacking_samples, compute_approx_kl
from lightrft.utils.distributed_sampler import DistributedSampler
+from lightrft.utils import rotate_ckpt_dirs
from lightrft.trainer import AdaptiveKLController, ExperienceVL, FixedKLController, NaiveExperienceMakerVL, NaiveReplayBufferVL # noqa
@@ -161,6 +162,7 @@ def __init__(
self.reward_fn = reward_fn
self.reward_fn_label_map = reward_fn_label_map
self.reward_recipe = reward_recipe
+ self.is_lora = getattr(self.args, "lora_rank", 0) > 0
self.actor = actor
self.critic = critic
@@ -1138,10 +1140,11 @@ def _save_checkpoint(self, args, tag, client_states):
:param client_states: Client state for checkpoint recovery.
:type client_states: dict
"""
- if not self.disable_ds_ckpt:
+ ckpt_path = args.ckpt_path
+ if not self.disable_ds_ckpt and not self.is_lora:
self.strategy.save_ckpt(
self.actor.model,
- os.path.join(args.ckpt_path, "_actor"),
+ os.path.join(ckpt_path, "_actor"),
tag,
args.max_ckpt_num,
args.max_ckpt_mem,
@@ -1149,11 +1152,24 @@ def _save_checkpoint(self, args, tag, client_states):
)
if self.critic is not None:
self.strategy.save_ckpt(
- self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
+ self.critic, os.path.join(ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
)
- if self.save_hf_ckpt:
- save_path = os.path.join(args.ckpt_path, f"{tag}_hf")
+ # For LoRA, we ALWAYS save the HF adapter as it is much smaller and more convenient for deployment.
+ if self.save_hf_ckpt or self.is_lora:
+ # Rotate HF checkpoints
+ if self.strategy.is_rank_0():
+ os.makedirs(ckpt_path, exist_ok=True)
+ max_num = getattr(args, "max_ckpt_num", 3)
+ rotate_ckpt_dirs(
+ ckpt_path,
+ max_num,
+ suffix="_lora",
+ strategy=self.strategy,
+ label="HF ckpt",
+ )
+
+ save_path = os.path.join(ckpt_path, f"{tag}_lora")
self.strategy.save_model(self.actor, self.tokenizer, save_path)
def evaluate(self, eval_dataloader, global_step):
diff --git a/lightrft/utils/__init__.py b/lightrft/utils/__init__.py
index 66e6056..0160007 100644
--- a/lightrft/utils/__init__.py
+++ b/lightrft/utils/__init__.py
@@ -12,7 +12,7 @@
from .processor import get_processor, reward_normalization
from .utils import (
blending_datasets, get_tokenizer, get_tokenizer_processor_vl, print_rank_0, get_current_device, get_torch_profiler,
- ensure_video_input_available, all_gather_and_flatten, all_reduce_dict
+ ensure_video_input_available, all_gather_and_flatten, all_reduce_dict, rotate_ckpt_dirs
)
from .cli_args import add_arguments
@@ -42,6 +42,7 @@
"ensure_video_input_available",
"all_gather_and_flatten",
"all_reduce_dict",
+ "rotate_ckpt_dirs",
# cli_args
"add_arguments",
diff --git a/lightrft/utils/utils.py b/lightrft/utils/utils.py
index 3ad1272..9b6a3a2 100644
--- a/lightrft/utils/utils.py
+++ b/lightrft/utils/utils.py
@@ -1,5 +1,6 @@
import os
import sys
+import shutil
from typing import Any, Dict, List, Optional, Tuple, Union
from datasets import interleave_datasets, load_dataset, load_from_disk, Dataset, DatasetDict
@@ -538,3 +539,42 @@ def all_reduce_dict(metrics_dict: Dict[str, float],
reduced_values = tensor.tolist()
return {k: v for k, v in zip(keys, reduced_values)}
+
+
+def rotate_ckpt_dirs(
+ ckpt_path: str,
+ max_num: int,
+ suffix: Optional[str] = None,
+ strategy: Optional[Any] = None,
+ label: str = "checkpoint",
+) -> None:
+ """
+ Remove oldest checkpoint directories under a path when reaching a limit.
+
+ :param ckpt_path: Parent checkpoint directory.
+ :type ckpt_path: str
+ :param max_num: Maximum number of checkpoints to keep.
+ :type max_num: int
+ :param suffix: Optional directory suffix filter (e.g., "_lora").
+ :type suffix: Optional[str]
+ :param strategy: Optional strategy for logging.
+ :type strategy: Optional[Any]
+ :param label: Label used in log messages.
+ :type label: str
+ """
+ if max_num is None or max_num <= 0 or not os.path.isdir(ckpt_path):
+ return
+
+ subdirs = sorted(
+ [(os.path.join(ckpt_path, d), os.path.getmtime(os.path.join(ckpt_path, d)))
+ for d in os.listdir(ckpt_path)
+ if os.path.isdir(os.path.join(ckpt_path, d)) and (suffix is None or d.endswith(suffix))],
+ key=lambda x: x[1],
+ )
+
+ while len(subdirs) >= max_num:
+ oldest_dir = subdirs.pop(0)[0]
+ if os.path.exists(oldest_dir):
+ shutil.rmtree(oldest_dir)
+ if strategy is not None:
+ strategy.print(f"Deleted oldest {label} {oldest_dir}")