From f01ea234c5672572268d3cb53c05e8fa8de12b08 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Thu, 25 Dec 2025 15:01:44 +0800 Subject: [PATCH 1/7] add rejection sampling --- .../REJECTION_SAMPLING_README.md | 167 +++++++++ .../convert_to_rejection_sampling_data.py | 158 +++++++++ .../rejection_sampling_inference.py | 333 ++++++++++++++++++ .../run_rejection_sampling.sh | 170 +++++++++ lightrft/models/grm_vl.py | 36 +- 5 files changed, 860 insertions(+), 4 deletions(-) create mode 100644 examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md create mode 100644 examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py create mode 100644 examples/grm_training/rejection_sampling/rejection_sampling_inference.py create mode 100644 examples/grm_training/rejection_sampling/run_rejection_sampling.sh diff --git a/examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md b/examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md new file mode 100644 index 0000000..ab8e1fd --- /dev/null +++ b/examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md @@ -0,0 +1,167 @@ +# Rejection Sampling 实现说明 + +本文档说明如何在 LightRFT 框架下实现 rejection_sampling 训练流程。 + +## 概述 + +Rejection Sampling 是 UnifiedReward-Think 训练流程的第二阶段,主要步骤包括: + +1. **推理阶段**:使用 cold-start 阶段训练好的模型对大规模数据进行推理 +2. **筛选阶段**:筛选出模型预测正确的样本 +3. **数据转换**:将筛选出的样本转换为包含 CoT reasoning 的训练数据格式 +4. **训练阶段**:使用筛选出的正确样本进行监督学习训练 + +## 文件说明 + +- `rejection_sampling_inference.py`: 推理脚本,对数据集进行推理并筛选正确样本 +- `convert_to_rejection_sampling_data.py`: 数据转换脚本,将筛选出的样本转换为训练格式 +- `run_rejection_sampling.sh`: 完整的运行脚本,整合整个流程 + +## 使用方法 + +### 方法一:使用完整脚本(推荐) + +直接运行完整的 rejection sampling 流程: + +```bash +cd /mnt/shared-storage-user/sunjiaxuan/dec/LightRFT + +# 修改脚本中的配置(如需要) +# - MODEL_PATH: 你的 cold-start 模型路径 +# - DATA_PATH: 数据集路径 +# - DATA_ROOT: 数据集根目录 + +bash examples/grm_training/rejection_sampling/run_rejection_sampling.sh +``` + +### 方法二:分步执行 + +#### 步骤 1: 推理和筛选 + +```bash +python examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ + --model_path /mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000 \ + --data_path "hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_subset_5percent.json" \ + --output_path ./results/filtered_samples.json \ + --batch_size 32 \ + --max_new_tokens 2048 \ + --use_cot +``` + +#### 步骤 2: 数据转换 + +```bash +python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py \ + --filtered_samples_path ./results/filtered_samples.json \ + --output_path ./results/rejection_sampling_train.json \ + --data_root /mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3 +``` + +#### 步骤 3: 训练 + +```bash +torchrun --nnodes 1 --nproc-per-node 8 \ + examples/grm_training/train_grm_vl.py \ + --pretrain /mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000 \ + --save_path ./results/rejection_sampling_checkpoint \ + --train_data "imagegen-cot-reward-5k:./results/rejection_sampling_train.json" \ + --train_batch_size 8 \ + --micro_train_batch_size 1 \ + --max_epochs 3 \ + --prompt_max_len 13000 \ + --actor_learning_rate 2.5e-6 \ + --zero_stage 3 \ + --bf16 \ + --gradient_checkpointing \ + --flash_attn +``` + +## 配置说明 + +### 推理阶段参数 + +- `--model_path`: Cold-start 阶段训练好的模型路径 +- `--data_path`: 数据集路径,格式为 `"source:path"`,例如 `"hpdv3:/path/to/data.json"` +- `--output_path`: 筛选出的样本保存路径 +- `--batch_size`: 推理批次大小(默认 32) +- `--max_new_tokens`: 最大生成 token 数(默认 2048) +- `--use_cot`: 是否使用 CoT 指令生成推理过程 + +### 训练阶段参数 + +- `--pretrain`: 预训练模型路径(通常是 cold-start 模型) +- `--train_data`: 训练数据路径,格式为 `"source:path"`,使用 `imagegen-cot-reward-5k` 作为 source +- `--train_batch_size`: 全局训练批次大小 +- `--micro_train_batch_size`: 每张 GPU 的微批次大小 +- `--max_epochs`: 训练轮数(默认 3) +- `--prompt_max_len`: 最大序列长度(默认 13000,支持长 CoT) +- `--actor_learning_rate`: 学习率(默认 2.5e-6) + +## 数据格式 + +### 输入数据格式(HPDv3) + +```json +{ + "path1": "images/image1.jpg", + "path2": "images/image2.jpg", + "prompt": "A beautiful landscape", + "confidence": null, + "choice_dist": null, + "model1": "model_name", + "model2": "model_name" +} +``` + +### 输出训练数据格式 + +```json +{ + "conversations": [ + { + "from": "human", + "value": "Task instruction with {prompt} placeholder..." + }, + { + "from": "gpt", + "value": "\nCoT reasoning here...\n\nImage 1 is better" + } + ], + "images": [ + "/path/to/image1.jpg", + "/path/to/image2.jpg" + ] +} +``` + +## 注意事项 + +1. **模型路径**:确保 cold-start 模型路径正确 +2. **数据路径**:确保数据集路径和根目录配置正确 +3. **显存要求**:训练时可能需要较大的显存,建议使用梯度检查点和 ZeRO Stage 3 +4. **CoT 格式**:生成的 CoT reasoning 应该包含在 `...` 标签中 +5. **答案格式**:最终答案应该在 `...` 标签中,格式为 "Image 1 is better" 或 "Image 2 is better" + +## 输出文件 + +运行完成后,会在输出目录生成以下文件: + +- `filtered_samples.json`: 筛选出的正确样本(原始格式) +- `filtered_samples_stats.txt`: 推理统计信息 +- `rejection_sampling_train.json`: 转换后的训练数据 +- `checkpoint/`: 训练好的模型检查点 +- `logs/`: 各阶段的日志文件 + +## 故障排查 + +1. **推理阶段失败**:检查模型路径和数据路径是否正确 +2. **数据转换失败**:检查图像路径是否存在,确保 `data_root` 配置正确 +3. **训练阶段 OOM**:减小 `micro_train_batch_size` 或启用 `--gradient_checkpointing` +4. **准确率低**:检查模型是否在 cold-start 阶段训练充分 + +## 参考 + +- UnifiedReward-Think 论文: https://arxiv.org/pdf/2505.03318 +- LightRFT 文档: 查看项目 README 和文档目录 + + diff --git a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py new file mode 100644 index 0000000..ab996c1 --- /dev/null +++ b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py @@ -0,0 +1,158 @@ +""" +Convert filtered samples to rejection sampling training data format. + +This script converts the filtered correct samples into the format required +for rejection sampling training, similar to imagegen-cot-reward dataset. +""" + +import os +import json +import argparse +from typing import List, Dict +from loguru import logger + + +def convert_to_rejection_sampling_format( + filtered_samples_path: str, + output_path: str, + data_root: str, + task_instruction_template: str = None, +): + """ + Convert filtered samples to rejection sampling training format. + + Args: + filtered_samples_path: Path to filtered samples JSON file + output_path: Path to save converted training data + data_root: Root directory of the dataset (for image paths) + task_instruction_template: Template for task instruction + """ + logger.info(f"Loading filtered samples from {filtered_samples_path}") + + with open(filtered_samples_path, 'r', encoding='utf-8') as f: + filtered_samples = json.load(f) + + logger.info(f"Loaded {len(filtered_samples)} filtered samples") + + # Default task instruction template + if task_instruction_template is None: + task_instruction_template = """Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: {prompt}""" + + training_data = [] + + for idx, sample in enumerate(filtered_samples): + prompt = sample['prompt'] + path1 = sample['path1'] + path2 = sample['path2'] + preference = sample['preference'] + generated_text = sample.get('generated_text', '') + reasoning = sample.get('reasoning', '') + + # Determine which image is better based on preference + # In HPDv3, path1 is the preferred path, path2 is the rejected path + # preference "A" means Image 1 (which is path1) is better + # preference "B" means Image 2 (which is path2) is better, but this means path2 was randomly chosen as Image 1 + # Actually, in HPDv3GRMHandler, preference A means image0 (first shown) is preferred + # So we need to check which path corresponds to which image + + # Since we stored preferred_path and rejected_path, we know: + # - preferred_path (path1) should be the better one + # - rejected_path (path2) should be the worse one + # But the handler randomly assigns them to Image 1 or Image 2 + + # For training data, we always use: Image 1 = preferred, Image 2 = rejected + # This ensures consistency + answer = "Image 1 is better" + image1_path = path1 # preferred + image2_path = path2 # rejected + + # Build the response with CoT reasoning + # Note: We use instead of to match the instruction format + if reasoning: + # Use the extracted reasoning from inference + # Clean up the reasoning text + reasoning_clean = reasoning.strip() + response = f"\n{reasoning_clean}\n\n{answer}" + else: + # If no reasoning was extracted, create a placeholder + # In practice, you might want to regenerate this or use a template + response = f"\nBased on the evaluation of semantic consistency, aesthetics, and authenticity, I will compare the two images.\n\n{answer}" + + # Build conversations format + task_instruction = task_instruction_template.format(prompt=prompt) + + # Create training data item in imagegen-cot-reward format + training_item = { + "conversations": [ + { + "from": "human", + "value": task_instruction + }, + { + "from": "gpt", + "value": response + } + ], + "images": [ + image1_path if os.path.isabs(image1_path) else os.path.join(data_root, image1_path), + image2_path if os.path.isabs(image2_path) else os.path.join(data_root, image2_path), + ] + } + + training_data.append(training_item) + + if (idx + 1) % 100 == 0: + logger.info(f"Converted {idx + 1}/{len(filtered_samples)} samples") + + # Save training data + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(training_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(training_data)} training samples to {output_path}") + + return training_data + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert filtered samples to rejection sampling training format") + parser.add_argument("--filtered_samples_path", type=str, required=True, + help="Path to filtered samples JSON file") + parser.add_argument("--output_path", type=str, required=True, + help="Path to save converted training data") + parser.add_argument("--data_root", type=str, required=True, + help="Root directory of the dataset (for image paths)") + parser.add_argument("--task_instruction", type=str, default=None, + help="Task instruction template (optional)") + + args = parser.parse_args() + + convert_to_rejection_sampling_format( + filtered_samples_path=args.filtered_samples_path, + output_path=args.output_path, + data_root=args.data_root, + task_instruction_template=args.task_instruction, + ) + diff --git a/examples/grm_training/rejection_sampling/rejection_sampling_inference.py b/examples/grm_training/rejection_sampling/rejection_sampling_inference.py new file mode 100644 index 0000000..18bc707 --- /dev/null +++ b/examples/grm_training/rejection_sampling/rejection_sampling_inference.py @@ -0,0 +1,333 @@ +""" +Rejection Sampling Inference Script + +This script performs inference on a dataset using a trained GRM model, +filters out correctly predicted samples, and generates training data +with CoT reasoning for rejection sampling training. +""" + +import os +import json +import argparse +import torch +from tqdm import tqdm +from typing import List, Dict +from loguru import logger +from torch.utils.data import DataLoader + +from lightrft.models import GenerativeRewardModelVL +from transformers import AutoProcessor +from lightrft.datasets import GRMDataset, extract_answer + + +TASK_INSTRUCTION_COT = """Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: {prompt} +""" + + +@torch.no_grad() +def inference_and_filter( + model_path: str, + data_path: List[str], + output_path: str, + config: dict = None, + batch_size: int = 32, + max_new_tokens: int = 2048, + use_cot: bool = True, +): + """ + Perform inference on dataset and filter correctly predicted samples. + + Args: + model_path: Path to the trained GRM model + data_path: List of dataset paths in format "source:path" + output_path: Path to save filtered samples + config: Configuration dict for dataset + batch_size: Batch size for inference + max_new_tokens: Maximum tokens to generate + use_cot: Whether to use CoT instruction (for generating reasoning) + """ + logger.info(f"Loading model from: {model_path}") + + # Load Model + # Note: Qwen2.5-VL doesn't support dtype parameter, so we disable flash attention + # and handle dtype conversion manually in the model class + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + logger.info(f"Found {num_gpus} GPU(s)") + + # Use DataParallel if multiple GPUs are available + if num_gpus > 1: + device = "cuda" + use_data_parallel = True + logger.info(f"Using DataParallel with {num_gpus} GPUs") + else: + device = f"cuda:{torch.cuda.current_device()}" + use_data_parallel = False + else: + device = "cpu" + use_data_parallel = False + + model = GenerativeRewardModelVL( + model_path, + bf16=True, + lora_rank=0, + lora_alpha=0, + target_modules=None, + ds_config=None, + device_map=None, # We'll move to device manually + use_flash_attention_2=False, # Disable to avoid dtype issues with Qwen2.5-VL + ) + logger.info(f"Model loaded successfully from {model_path}.") + + # Move model to device + model.model = model.model.to(device) + + # Use DataParallel for multi-GPU inference + if use_data_parallel: + model.model = torch.nn.DataParallel(model.model) + logger.info("Model wrapped with DataParallel") + + model.eval() + + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=False) + + # Load Dataset + dataset = GRMDataset( + data_path, + tokenizer=processor.tokenizer, + strategy=None, + processor=processor, + max_length=8192, + config=config, + is_training=False, + ) + + # Reduce batch size if it's too large to avoid OOM + # For Qwen2.5-VL with images, smaller batch size is recommended + effective_batch_size = min(batch_size, 4) # Limit to 4 for safety + if batch_size > effective_batch_size: + logger.warning(f"Reducing batch size from {batch_size} to {effective_batch_size} to avoid OOM") + + data_loader = DataLoader( + dataset, + batch_size=effective_batch_size, + shuffle=False, + drop_last=False, + pin_memory=False, # Disable pin_memory to save memory + collate_fn=dataset.collate_fn, + num_workers=2, # Reduce workers to save memory + ) + + logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {effective_batch_size}") + + correct_samples = [] + total_samples = 0 + correct_count = 0 + + # Clear cache before starting + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + for batch_idx, batch in enumerate(tqdm(data_loader)): + try: + ids, mask, pixel_values, image_grid_thws, pixel_values_videos, video_grid_thws, labels, extras = batch + + # Ensure device is a string for .to() method + # For DataParallel, use "cuda" to let it handle device placement + if use_data_parallel: + device_str = "cuda" + else: + device_str = str(device) if isinstance(device, torch.device) else device + + ids = ids.squeeze(1).to(device_str, non_blocking=False) + mask = mask.squeeze(1).to(device_str, non_blocking=False) + + if pixel_values is not None: + pixel_values = pixel_values.to(device_str, non_blocking=False) + image_grid_thws = image_grid_thws.to(device_str, non_blocking=False) + + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.to(device_str, non_blocking=False) + video_grid_thws = video_grid_thws.to(device_str, non_blocking=False) + + # Generate with unified max_new_tokens + # Use torch.cuda.amp for mixed precision to save memory + with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): + # Handle DataParallel wrapper + model_to_use = model.model.module if isinstance(model.model, torch.nn.DataParallel) else model.model + gen_ids = model_to_use.generate( + input_ids=ids, + attention_mask=mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thws, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thws, + max_new_tokens=max_new_tokens, + do_sample=False, + temperature=0.0, + ) + + # Move to CPU and clear GPU memory immediately + ids_cpu = ids.cpu() + gen_ids = gen_ids.cpu() + + # Decode (gen_ids is already on CPU) + gen_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(ids_cpu, gen_ids)] + gen_texts = processor.batch_decode(gen_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) + + # Clear GPU memory immediately + del ids, mask, pixel_values, image_grid_thws, gen_ids, gen_ids_trimmed, ids_cpu + if pixel_values_videos is not None: + del pixel_values_videos, video_grid_thws + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + # Evaluate and filter + for i, (gen_text, extra) in enumerate(zip(gen_texts, extras)): + total_samples += 1 + predicted_answer = extract_answer(gen_text) + gt_preference = extra['preference'] # A or B + + # Mapping logic: + # In HPDv3GRMHandler, preference "A" means Image 1 (first shown) is preferred + # preference "B" means Image 2 (second shown) is preferred + # But the handler randomly swaps images, so we need to check the actual mapping + # The handler stores: preferred_path (path1) and rejected_path (path2) + # When preference is "A", image0 (which could be preferred or rejected) is shown as Image 1 + # When preference is "B", image1 (which could be preferred or rejected) is shown as Image 1 + + # Since the handler randomly assigns, we check based on the stored preference + # If gt_preference is "A", it means Image 1 (first shown) is better + # If gt_preference is "B", it means Image 2 (second shown) is better + is_correct = False + if gt_preference == "A" and predicted_answer == "Image 1 is better": + is_correct = True + elif gt_preference == "B" and predicted_answer == "Image 2 is better": + is_correct = True + + if is_correct: + correct_count += 1 + # Prepare sample for rejection sampling training + sample = { + "prompt": extra['prompt'], + "path1": extra['preferred_path'], + "path2": extra['rejected_path'], + "preference": gt_preference, + "generated_text": gen_text, + "predicted_answer": predicted_answer, + } + + # If we want to use the generated CoT reasoning, extract it + if use_cot: + # Extract reasoning from generated text + # Try both and tags (in case of different formats) + reasoning_match = None + import re + # Try first (standard format) + if "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + # Try as fallback + elif "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + + if reasoning_match: + reasoning = reasoning_match.group(1).strip() + sample["reasoning"] = reasoning + else: + # If no reasoning found, we'll use the full generated text (excluding answer) + # or generate it during training data preparation + # Remove answer part to get reasoning + answer_part = f"{predicted_answer}" if predicted_answer else "" + reasoning_candidate = gen_text.replace(answer_part, "").strip() + sample["reasoning"] = reasoning_candidate if reasoning_candidate else None + + correct_samples.append(sample) + + if total_samples % 100 == 0: + logger.info(f"Processed {total_samples} samples, {correct_count} correct ({correct_count/total_samples*100:.2f}%)") + + except torch.cuda.OutOfMemoryError as e: + logger.error(f"OOM error at batch {batch_idx}. Please restart with --batch_size 1") + # Clear cache + torch.cuda.empty_cache() + raise + + # Summary + accuracy = correct_count / total_samples if total_samples > 0 else 0.0 + logger.info(f"Inference completed. Accuracy: {accuracy*100:.2f}% ({correct_count}/{total_samples})") + logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training") + + # Save filtered samples + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(correct_samples, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") + + # Save statistics + stats_path = output_path.replace('.json', '_stats.txt') + with open(stats_path, 'w', encoding='utf-8') as f: + f.write(f"Dataset paths: {data_path}\n") + f.write(f"Model path: {model_path}\n") + f.write(f"Total samples: {total_samples}\n") + f.write(f"Correct samples: {correct_count}\n") + f.write(f"Accuracy: {accuracy*100:.2f}%\n") + f.write(f"Filtered samples for training: {len(correct_samples)}\n") + + return correct_samples + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Rejection Sampling Inference") + parser.add_argument("--model_path", type=str, required=True, help="Path to trained GRM model") + parser.add_argument("--data_path", type=str, required=True, help="Dataset path in format 'source:path'") + parser.add_argument("--output_path", type=str, required=True, help="Path to save filtered samples") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for inference") + parser.add_argument("--max_new_tokens", type=int, default=2048, help="Maximum tokens to generate") + parser.add_argument("--use_cot", action="store_true", default=True, help="Use CoT instruction for reasoning") + parser.add_argument("--task_instruction", type=str, default=TASK_INSTRUCTION_COT, help="Task instruction template") + + args = parser.parse_args() + + # Parse data path + data_paths = [args.data_path] if isinstance(args.data_path, str) else args.data_path.split(",") + + config = { + "task_instruction": args.task_instruction, + "name": "rejection_sampling_inference", + } + + inference_and_filter( + model_path=args.model_path, + data_path=data_paths, + output_path=args.output_path, + config=config, + batch_size=args.batch_size, + max_new_tokens=args.max_new_tokens, + use_cot=args.use_cot, + ) + diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh new file mode 100644 index 0000000..7df48b9 --- /dev/null +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh @@ -0,0 +1,170 @@ +#!/bin/bash + +# Rejection Sampling Training Script +# This script performs the complete rejection sampling pipeline: +# 1. Inference on dataset and filter correct samples +# 2. Convert filtered samples to training format +# 3. Train the model on filtered samples + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (cold-start model) +MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" + +# Dataset configuration +DATA_PATH="hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_subset_5percent.json" +DATA_ROOT="/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3" + +# Output paths +OUTPUT_DIR="./results/rejection_sampling_$(date +%Y%m%d_%H%M%S)" +FILTERED_SAMPLES_PATH="${OUTPUT_DIR}/filtered_samples.json" +TRAINING_DATA_PATH="${OUTPUT_DIR}/rejection_sampling_train.json" +FINAL_CHECKPOINT_PATH="${OUTPUT_DIR}/checkpoint" + +# Training hyperparameters +TBS=8 +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=32 + +# Inference parameters (reduced to avoid OOM) +INFERENCE_BATCH_SIZE=8 +MAX_NEW_TOKENS=2048 + +# Task instruction for CoT reasoning +TASK_INSTRUCTION="""Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: {prompt}""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Create output directory +mkdir -p ${OUTPUT_DIR} +LOG_BASE="${OUTPUT_DIR}/logs" +mkdir -p ${LOG_BASE} + +echo "==========================================" +echo "Rejection Sampling Training Pipeline" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Data: ${DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "==========================================" + +############################### Step 1: Inference and Filter ########################## +echo "" +echo "Step 1: Running inference and filtering correct samples..." +echo "==========================================" + +python examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ + --model_path ${MODEL_PATH} \ + --data_path ${DATA_PATH} \ + --output_path ${FILTERED_SAMPLES_PATH} \ + --batch_size ${INFERENCE_BATCH_SIZE} \ + --max_new_tokens ${MAX_NEW_TOKENS} \ + --use_cot \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_BASE}/inference.log + +if [ ! -f "${FILTERED_SAMPLES_PATH}" ]; then + echo "Error: Filtered samples file not created!" + exit 1 +fi + +echo "Step 1 completed. Filtered samples saved to: ${FILTERED_SAMPLES_PATH}" + +############################### Step 2: Convert to Training Format ########################## +echo "" +echo "Step 2: Converting filtered samples to training format..." +echo "==========================================" + +python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py \ + --filtered_samples_path ${FILTERED_SAMPLES_PATH} \ + --output_path ${TRAINING_DATA_PATH} \ + --data_root ${DATA_ROOT} \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_BASE}/convert.log + +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not created!" + exit 1 +fi + +echo "Step 2 completed. Training data saved to: ${TRAINING_DATA_PATH}" + +############################### Step 3: Training ########################## +echo "" +echo "Step 3: Training on filtered samples..." +echo "==========================================" + +# Use imagegen-cot-reward handler for the converted data +TRAINING_DATA_SOURCE="imagegen-cot-reward-5k:${TRAINING_DATA_PATH}" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${FINAL_CHECKPOINT_PATH} \ + --ckpt_path ${FINAL_CHECKPOINT_PATH} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps 2.0 \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCE} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 8 \ + --use_tensorboard "${OUTPUT_DIR}/tensorboard" \ + --l2 0.0 \ + --flash_attn \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_BASE}/training.log + +echo "" +echo "==========================================" +echo "Rejection Sampling Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${FINAL_CHECKPOINT_PATH}/final_checkpoint" +echo "All outputs saved to: ${OUTPUT_DIR}" +echo "==========================================" + diff --git a/lightrft/models/grm_vl.py b/lightrft/models/grm_vl.py index 7a2d311..a463702 100644 --- a/lightrft/models/grm_vl.py +++ b/lightrft/models/grm_vl.py @@ -61,13 +61,41 @@ def __init__( else: dschf = None + # Check if model is Qwen2.5-VL which doesn't support dtype parameter + # We can check by looking at the model path or config + from transformers import AutoConfig + try: + config = AutoConfig.from_pretrained(pretrain_or_model, trust_remote_code=True) + is_qwen_vl = "qwen" in config.model_type.lower() and "vl" in config.model_type.lower() + except: + # If we can't load config, assume it might be Qwen2.5-VL if path contains qwen + is_qwen_vl = "qwen" in pretrain_or_model.lower() and ("vl" in pretrain_or_model.lower() or "vision" in pretrain_or_model.lower()) + + # Build base kwargs + load_kwargs = { + "trust_remote_code": True, + "attn_implementation": attn_implementation, + } + + if device_map is not None: + load_kwargs["device_map"] = device_map + + # Qwen2.5-VL doesn't support dtype parameter, so skip it for these models + if not is_qwen_vl: + if bf16: + load_kwargs["dtype"] = torch.bfloat16 + else: + load_kwargs["dtype"] = "auto" + + # Load model self.model = AutoModelForVision2Seq.from_pretrained( pretrain_or_model, - trust_remote_code=True, - attn_implementation=attn_implementation, - dtype=torch.bfloat16 if bf16 else "auto", - device_map=device_map, + **load_kwargs ) + + # For Qwen2.5-VL, convert to bfloat16 manually if requested + if is_qwen_vl and bf16: + self.model = self.model.to(torch.bfloat16) # LoRA if lora_rank > 0: From 029a1114012ea24868c3b0b1b6bd533055706f75 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Thu, 25 Dec 2025 15:04:46 +0800 Subject: [PATCH 2/7] Modify path --- .../REJECTION_SAMPLING_README.md | 167 ------------------ .../run_rejection_sampling.sh | 25 ++- 2 files changed, 22 insertions(+), 170 deletions(-) delete mode 100644 examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md diff --git a/examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md b/examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md deleted file mode 100644 index ab8e1fd..0000000 --- a/examples/grm_training/rejection_sampling/REJECTION_SAMPLING_README.md +++ /dev/null @@ -1,167 +0,0 @@ -# Rejection Sampling 实现说明 - -本文档说明如何在 LightRFT 框架下实现 rejection_sampling 训练流程。 - -## 概述 - -Rejection Sampling 是 UnifiedReward-Think 训练流程的第二阶段,主要步骤包括: - -1. **推理阶段**:使用 cold-start 阶段训练好的模型对大规模数据进行推理 -2. **筛选阶段**:筛选出模型预测正确的样本 -3. **数据转换**:将筛选出的样本转换为包含 CoT reasoning 的训练数据格式 -4. **训练阶段**:使用筛选出的正确样本进行监督学习训练 - -## 文件说明 - -- `rejection_sampling_inference.py`: 推理脚本,对数据集进行推理并筛选正确样本 -- `convert_to_rejection_sampling_data.py`: 数据转换脚本,将筛选出的样本转换为训练格式 -- `run_rejection_sampling.sh`: 完整的运行脚本,整合整个流程 - -## 使用方法 - -### 方法一:使用完整脚本(推荐) - -直接运行完整的 rejection sampling 流程: - -```bash -cd /mnt/shared-storage-user/sunjiaxuan/dec/LightRFT - -# 修改脚本中的配置(如需要) -# - MODEL_PATH: 你的 cold-start 模型路径 -# - DATA_PATH: 数据集路径 -# - DATA_ROOT: 数据集根目录 - -bash examples/grm_training/rejection_sampling/run_rejection_sampling.sh -``` - -### 方法二:分步执行 - -#### 步骤 1: 推理和筛选 - -```bash -python examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ - --model_path /mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000 \ - --data_path "hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_subset_5percent.json" \ - --output_path ./results/filtered_samples.json \ - --batch_size 32 \ - --max_new_tokens 2048 \ - --use_cot -``` - -#### 步骤 2: 数据转换 - -```bash -python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py \ - --filtered_samples_path ./results/filtered_samples.json \ - --output_path ./results/rejection_sampling_train.json \ - --data_root /mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3 -``` - -#### 步骤 3: 训练 - -```bash -torchrun --nnodes 1 --nproc-per-node 8 \ - examples/grm_training/train_grm_vl.py \ - --pretrain /mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000 \ - --save_path ./results/rejection_sampling_checkpoint \ - --train_data "imagegen-cot-reward-5k:./results/rejection_sampling_train.json" \ - --train_batch_size 8 \ - --micro_train_batch_size 1 \ - --max_epochs 3 \ - --prompt_max_len 13000 \ - --actor_learning_rate 2.5e-6 \ - --zero_stage 3 \ - --bf16 \ - --gradient_checkpointing \ - --flash_attn -``` - -## 配置说明 - -### 推理阶段参数 - -- `--model_path`: Cold-start 阶段训练好的模型路径 -- `--data_path`: 数据集路径,格式为 `"source:path"`,例如 `"hpdv3:/path/to/data.json"` -- `--output_path`: 筛选出的样本保存路径 -- `--batch_size`: 推理批次大小(默认 32) -- `--max_new_tokens`: 最大生成 token 数(默认 2048) -- `--use_cot`: 是否使用 CoT 指令生成推理过程 - -### 训练阶段参数 - -- `--pretrain`: 预训练模型路径(通常是 cold-start 模型) -- `--train_data`: 训练数据路径,格式为 `"source:path"`,使用 `imagegen-cot-reward-5k` 作为 source -- `--train_batch_size`: 全局训练批次大小 -- `--micro_train_batch_size`: 每张 GPU 的微批次大小 -- `--max_epochs`: 训练轮数(默认 3) -- `--prompt_max_len`: 最大序列长度(默认 13000,支持长 CoT) -- `--actor_learning_rate`: 学习率(默认 2.5e-6) - -## 数据格式 - -### 输入数据格式(HPDv3) - -```json -{ - "path1": "images/image1.jpg", - "path2": "images/image2.jpg", - "prompt": "A beautiful landscape", - "confidence": null, - "choice_dist": null, - "model1": "model_name", - "model2": "model_name" -} -``` - -### 输出训练数据格式 - -```json -{ - "conversations": [ - { - "from": "human", - "value": "Task instruction with {prompt} placeholder..." - }, - { - "from": "gpt", - "value": "\nCoT reasoning here...\n\nImage 1 is better" - } - ], - "images": [ - "/path/to/image1.jpg", - "/path/to/image2.jpg" - ] -} -``` - -## 注意事项 - -1. **模型路径**:确保 cold-start 模型路径正确 -2. **数据路径**:确保数据集路径和根目录配置正确 -3. **显存要求**:训练时可能需要较大的显存,建议使用梯度检查点和 ZeRO Stage 3 -4. **CoT 格式**:生成的 CoT reasoning 应该包含在 `...` 标签中 -5. **答案格式**:最终答案应该在 `...` 标签中,格式为 "Image 1 is better" 或 "Image 2 is better" - -## 输出文件 - -运行完成后,会在输出目录生成以下文件: - -- `filtered_samples.json`: 筛选出的正确样本(原始格式) -- `filtered_samples_stats.txt`: 推理统计信息 -- `rejection_sampling_train.json`: 转换后的训练数据 -- `checkpoint/`: 训练好的模型检查点 -- `logs/`: 各阶段的日志文件 - -## 故障排查 - -1. **推理阶段失败**:检查模型路径和数据路径是否正确 -2. **数据转换失败**:检查图像路径是否存在,确保 `data_root` 配置正确 -3. **训练阶段 OOM**:减小 `micro_train_batch_size` 或启用 `--gradient_checkpointing` -4. **准确率低**:检查模型是否在 cold-start 阶段训练充分 - -## 参考 - -- UnifiedReward-Think 论文: https://arxiv.org/pdf/2505.03318 -- LightRFT 文档: 查看项目 README 和文档目录 - - diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh index 7df48b9..cdd211e 100644 --- a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh @@ -15,11 +15,14 @@ unset HTTPS_PROXY ############################# Configuration ########################## # Model path (cold-start model) -MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" +# Please set your model path here +MODEL_PATH="" # Dataset configuration -DATA_PATH="hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_subset_5percent.json" -DATA_ROOT="/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3" +# Please set your dataset path here (format: "source:path") +DATA_PATH="" +# Please set your dataset root directory here +DATA_ROOT="" # Output paths OUTPUT_DIR="./results/rejection_sampling_$(date +%Y%m%d_%H%M%S)" @@ -73,6 +76,22 @@ export MASTER_PORT=${MASTER_PORT:-29500} export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) +# Validate required configuration +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +if [ -z "${DATA_PATH}" ]; then + echo "Error: DATA_PATH is not set. Please configure it in the script." + exit 1 +fi + +if [ -z "${DATA_ROOT}" ]; then + echo "Error: DATA_ROOT is not set. Please configure it in the script." + exit 1 +fi + # Create output directory mkdir -p ${OUTPUT_DIR} LOG_BASE="${OUTPUT_DIR}/logs" From c5344719871dbf774ac1fed440724f8e5251b3b0 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Fri, 26 Dec 2025 18:47:34 +0800 Subject: [PATCH 3/7] Improve comments and add distributed training. --- .../convert_to_rejection_sampling_data.py | 15 +- .../rejection_sampling_inference.py | 202 ++++++++++++++---- .../run_rejection_sampling.sh | 5 +- lightrft/models/grm_vl.py | 38 +--- 4 files changed, 177 insertions(+), 83 deletions(-) diff --git a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py index ab996c1..3046c84 100644 --- a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py +++ b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py @@ -21,11 +21,16 @@ def convert_to_rejection_sampling_format( """ Convert filtered samples to rejection sampling training format. - Args: - filtered_samples_path: Path to filtered samples JSON file - output_path: Path to save converted training data - data_root: Root directory of the dataset (for image paths) - task_instruction_template: Template for task instruction + :param filtered_samples_path: Path to filtered samples JSON file + :type filtered_samples_path: str + :param output_path: Path to save converted training data + :type output_path: str + :param data_root: Root directory of the dataset (for image paths) + :type data_root: str + :param task_instruction_template: Template for task instruction + :type task_instruction_template: str, optional + :return: List of training data items in imagegen-cot-reward format + :rtype: List[Dict] """ logger.info(f"Loading filtered samples from {filtered_samples_path}") diff --git a/examples/grm_training/rejection_sampling/rejection_sampling_inference.py b/examples/grm_training/rejection_sampling/rejection_sampling_inference.py index 18bc707..f299c07 100644 --- a/examples/grm_training/rejection_sampling/rejection_sampling_inference.py +++ b/examples/grm_training/rejection_sampling/rejection_sampling_inference.py @@ -10,10 +10,12 @@ import json import argparse import torch +import torch.distributed as dist from tqdm import tqdm from typing import List, Dict from loguru import logger from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from lightrft.models import GenerativeRewardModelVL from transformers import AutoProcessor @@ -55,38 +57,67 @@ def inference_and_filter( batch_size: int = 32, max_new_tokens: int = 2048, use_cot: bool = True, + local_rank: int = -1, + world_size: int = 1, ): """ Perform inference on dataset and filter correctly predicted samples. - Args: - model_path: Path to the trained GRM model - data_path: List of dataset paths in format "source:path" - output_path: Path to save filtered samples - config: Configuration dict for dataset - batch_size: Batch size for inference - max_new_tokens: Maximum tokens to generate - use_cot: Whether to use CoT instruction (for generating reasoning) + :param model_path: Path to the trained GRM model + :type model_path: str + :param data_path: List of dataset paths in format "source:path" + :type data_path: List[str] + :param output_path: Path to save filtered samples + :type output_path: str + :param config: Configuration dict for dataset + :type config: dict, optional + :param batch_size: Batch size for inference + :type batch_size: int + :param max_new_tokens: Maximum tokens to generate + :type max_new_tokens: int + :param use_cot: Whether to use CoT instruction (for generating reasoning) + :type use_cot: bool + :param local_rank: Local rank for distributed inference (-1 for single GPU) + :type local_rank: int + :param world_size: World size for distributed inference (1 for single GPU) + :type world_size: int + :return: List of correctly predicted samples with their generated text and reasoning + :rtype: List[Dict] """ + # Initialize distributed training if needed + use_distributed = local_rank >= 0 and world_size > 1 + if use_distributed: + dist.init_process_group(backend='nccl') + torch.cuda.set_device(local_rank) + device = torch.device(f'cuda:{local_rank}') + logger.info(f"Using distributed inference: rank {local_rank}/{world_size} on device {device}") + else: + device = None + logger.info(f"Loading model from: {model_path}") # Load Model # Note: Qwen2.5-VL doesn't support dtype parameter, so we disable flash attention # and handle dtype conversion manually in the model class if torch.cuda.is_available(): - num_gpus = torch.cuda.device_count() - logger.info(f"Found {num_gpus} GPU(s)") - - # Use DataParallel if multiple GPUs are available - if num_gpus > 1: - device = "cuda" - use_data_parallel = True - logger.info(f"Using DataParallel with {num_gpus} GPUs") - else: - device = f"cuda:{torch.cuda.current_device()}" + if use_distributed: + num_gpus = world_size + logger.info(f"Using distributed inference with {num_gpus} GPU(s)") use_data_parallel = False + else: + num_gpus = torch.cuda.device_count() + logger.info(f"Found {num_gpus} GPU(s)") + + # Use DataParallel if multiple GPUs are available + if num_gpus > 1: + device = torch.device("cuda:0") # Use first GPU as main device + use_data_parallel = True + logger.info(f"Using DataParallel with {num_gpus} GPUs") + else: + device = torch.device(f"cuda:{torch.cuda.current_device()}") + use_data_parallel = False else: - device = "cpu" + device = torch.device("cpu") use_data_parallel = False model = GenerativeRewardModelVL( @@ -104,10 +135,19 @@ def inference_and_filter( # Move model to device model.model = model.model.to(device) - # Use DataParallel for multi-GPU inference + # Use DataParallel for multi-GPU inference (non-distributed) if use_data_parallel: model.model = torch.nn.DataParallel(model.model) logger.info("Model wrapped with DataParallel") + elif use_distributed: + # For distributed inference, wrap with DDP + model.model = torch.nn.parallel.DistributedDataParallel( + model.model, + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=False + ) + logger.info(f"Model wrapped with DistributedDataParallel on rank {local_rank}") model.eval() @@ -124,16 +164,38 @@ def inference_and_filter( is_training=False, ) - # Reduce batch size if it's too large to avoid OOM - # For Qwen2.5-VL with images, smaller batch size is recommended - effective_batch_size = min(batch_size, 4) # Limit to 4 for safety - if batch_size > effective_batch_size: - logger.warning(f"Reducing batch size from {batch_size} to {effective_batch_size} to avoid OOM") + # Adjust batch size based on available GPUs + # For distributed inference, batch_size is per GPU + # For single GPU with DataParallel, batch_size is total across all GPUs + if use_distributed: + # In distributed mode, batch_size is already per GPU + effective_batch_size = batch_size + logger.info(f"Using batch size {effective_batch_size} per GPU (distributed mode)") + else: + # For single GPU or DataParallel, use the provided batch_size + # DataParallel will automatically split across GPUs + effective_batch_size = batch_size + logger.info(f"Using batch size {effective_batch_size} (single GPU or DataParallel mode)") + + # Use DistributedSampler for distributed inference + if use_distributed: + sampler = DistributedSampler( + dataset, + num_replicas=world_size, + rank=local_rank, + shuffle=False, + drop_last=False + ) + shuffle = False + else: + sampler = None + shuffle = False data_loader = DataLoader( dataset, batch_size=effective_batch_size, - shuffle=False, + shuffle=shuffle, + sampler=sampler, drop_last=False, pin_memory=False, # Disable pin_memory to save memory collate_fn=dataset.collate_fn, @@ -154,10 +216,12 @@ def inference_and_filter( try: ids, mask, pixel_values, image_grid_thws, pixel_values_videos, video_grid_thws, labels, extras = batch - # Ensure device is a string for .to() method + # Ensure device is correct for .to() method # For DataParallel, use "cuda" to let it handle device placement if use_data_parallel: device_str = "cuda" + elif use_distributed: + device_str = device else: device_str = str(device) if isinstance(device, torch.device) else device @@ -175,8 +239,13 @@ def inference_and_filter( # Generate with unified max_new_tokens # Use torch.cuda.amp for mixed precision to save memory with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): - # Handle DataParallel wrapper - model_to_use = model.model.module if isinstance(model.model, torch.nn.DataParallel) else model.model + # Handle DataParallel and DDP wrapper + if isinstance(model.model, torch.nn.DataParallel): + model_to_use = model.model.module + elif isinstance(model.model, torch.nn.parallel.DistributedDataParallel): + model_to_use = model.model.module + else: + model_to_use = model.model gen_ids = model_to_use.generate( input_ids=ids, attention_mask=mask, @@ -276,27 +345,56 @@ def inference_and_filter( torch.cuda.empty_cache() raise + # Gather results from all processes if using distributed inference + if use_distributed: + # Gather all correct_samples from all processes + gather_list = [None] * world_size + dist.all_gather_object(gather_list, correct_samples) + + # Flatten the list (only on rank 0) + if local_rank == 0: + correct_samples = [item for sublist in gather_list for item in sublist] + + # Gather statistics using tensors + total_samples_tensor = torch.tensor([total_samples], dtype=torch.long, device=device) + correct_count_tensor = torch.tensor([correct_count], dtype=torch.long, device=device) + dist.all_reduce(total_samples_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(correct_count_tensor, op=dist.ReduceOp.SUM) + total_samples = int(total_samples_tensor.item()) + correct_count = int(correct_count_tensor.item()) + + # Only rank 0 continues to save results + if local_rank != 0: + if use_distributed: + dist.destroy_process_group() + return correct_samples + # Summary accuracy = correct_count / total_samples if total_samples > 0 else 0.0 logger.info(f"Inference completed. Accuracy: {accuracy*100:.2f}% ({correct_count}/{total_samples})") logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training") - # Save filtered samples - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(correct_samples, f, indent=2, ensure_ascii=False) - - logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") + # Save filtered samples (only on rank 0 for distributed) + if not use_distributed or local_rank == 0: + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(correct_samples, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") + + # Save statistics + stats_path = output_path.replace('.json', '_stats.txt') + with open(stats_path, 'w', encoding='utf-8') as f: + f.write(f"Dataset paths: {data_path}\n") + f.write(f"Model path: {model_path}\n") + f.write(f"Total samples: {total_samples}\n") + f.write(f"Correct samples: {correct_count}\n") + f.write(f"Accuracy: {accuracy*100:.2f}%\n") + f.write(f"Filtered samples for training: {len(correct_samples)}\n") - # Save statistics - stats_path = output_path.replace('.json', '_stats.txt') - with open(stats_path, 'w', encoding='utf-8') as f: - f.write(f"Dataset paths: {data_path}\n") - f.write(f"Model path: {model_path}\n") - f.write(f"Total samples: {total_samples}\n") - f.write(f"Correct samples: {correct_count}\n") - f.write(f"Accuracy: {accuracy*100:.2f}%\n") - f.write(f"Filtered samples for training: {len(correct_samples)}\n") + # Clean up distributed process group + if use_distributed: + dist.destroy_process_group() return correct_samples @@ -311,8 +409,22 @@ def inference_and_filter( parser.add_argument("--use_cot", action="store_true", default=True, help="Use CoT instruction for reasoning") parser.add_argument("--task_instruction", type=str, default=TASK_INSTRUCTION_COT, help="Task instruction template") + # Distributed training arguments (set by torchrun) + parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed inference") + parser.add_argument("--world_size", type=int, default=1, help="World size for distributed inference") + args = parser.parse_args() + # Get distributed settings from environment if available + local_rank = args.local_rank + world_size = args.world_size + + # Try to get from environment variables (set by torchrun) + if local_rank == -1: + local_rank = int(os.environ.get('LOCAL_RANK', -1)) + if world_size == 1: + world_size = int(os.environ.get('WORLD_SIZE', 1)) + # Parse data path data_paths = [args.data_path] if isinstance(args.data_path, str) else args.data_path.split(",") @@ -329,5 +441,7 @@ def inference_and_filter( batch_size=args.batch_size, max_new_tokens=args.max_new_tokens, use_cot=args.use_cot, + local_rank=local_rank, + world_size=world_size, ) diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh index cdd211e..ca0bed6 100644 --- a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh @@ -110,7 +110,10 @@ echo "" echo "Step 1: Running inference and filtering correct samples..." echo "==========================================" -python examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ +# Use torchrun for distributed inference to utilize multiple GPUs +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ --model_path ${MODEL_PATH} \ --data_path ${DATA_PATH} \ --output_path ${FILTERED_SAMPLES_PATH} \ diff --git a/lightrft/models/grm_vl.py b/lightrft/models/grm_vl.py index a463702..fccaa0f 100644 --- a/lightrft/models/grm_vl.py +++ b/lightrft/models/grm_vl.py @@ -61,41 +61,13 @@ def __init__( else: dschf = None - # Check if model is Qwen2.5-VL which doesn't support dtype parameter - # We can check by looking at the model path or config - from transformers import AutoConfig - try: - config = AutoConfig.from_pretrained(pretrain_or_model, trust_remote_code=True) - is_qwen_vl = "qwen" in config.model_type.lower() and "vl" in config.model_type.lower() - except: - # If we can't load config, assume it might be Qwen2.5-VL if path contains qwen - is_qwen_vl = "qwen" in pretrain_or_model.lower() and ("vl" in pretrain_or_model.lower() or "vision" in pretrain_or_model.lower()) - - # Build base kwargs - load_kwargs = { - "trust_remote_code": True, - "attn_implementation": attn_implementation, - } - - if device_map is not None: - load_kwargs["device_map"] = device_map - - # Qwen2.5-VL doesn't support dtype parameter, so skip it for these models - if not is_qwen_vl: - if bf16: - load_kwargs["dtype"] = torch.bfloat16 - else: - load_kwargs["dtype"] = "auto" - - # Load model self.model = AutoModelForVision2Seq.from_pretrained( pretrain_or_model, - **load_kwargs + trust_remote_code=True, + attn_implementation=attn_implementation, + dtype=torch.bfloat16 if bf16 else "auto", + device_map=device_map, ) - - # For Qwen2.5-VL, convert to bfloat16 manually if requested - if is_qwen_vl and bf16: - self.model = self.model.to(torch.bfloat16) # LoRA if lora_rank > 0: @@ -245,4 +217,4 @@ def print_trainable_parameters(self): actor.print_trainable_parameters() # Output: trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058 """ - self.model.print_trainable_parameters() + self.model.print_trainable_parameters() \ No newline at end of file From ef9b3dfcfe5fd86d437f6e0249f12d9651d68626 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Mon, 5 Jan 2026 14:16:31 +0800 Subject: [PATCH 4/7] feature(sunjx): add t2i and t2v rejection sampling --- .../convert_to_rejection_sampling_data_t2v.py | 155 ++++++ .../rejection_sampling_inference.py | 397 ++++++--------- .../rejection_sampling_inference_t2v.py | 477 ++++++++++++++++++ .../run_rejection_sampling.sh | 21 +- .../run_rejection_sampling_t2v.sh | 197 ++++++++ .../train_rejection_sampling.sh | 134 +++++ lightrft/datasets/__init__.py | 25 +- lightrft/datasets/hpdv3.py | 4 + lightrft/datasets/imagegen_cot_reward.py | 168 ++++-- lightrft/datasets/omnirewardbench.py | 16 + lightrft/datasets/rapidata.py | 225 ++++++++- lightrft/datasets/rft_dataset.py | 171 +++++++ lightrft/models/grm_vl.py | 2 +- lightrft/strategy/strategy_base.py | 2 +- 14 files changed, 1675 insertions(+), 319 deletions(-) create mode 100644 examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py create mode 100644 examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py create mode 100755 examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh create mode 100755 examples/grm_training/rejection_sampling/train_rejection_sampling.sh create mode 100644 lightrft/datasets/rft_dataset.py diff --git a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py new file mode 100644 index 0000000..08613cf --- /dev/null +++ b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py @@ -0,0 +1,155 @@ +""" +Convert filtered samples to rejection sampling training data format for T2V. + +This script converts the filtered correct samples into the format required +for rejection sampling training, similar to imagegen-cot-reward dataset but for videos. +""" + +import os +import json +import argparse +from typing import List, Dict +from loguru import logger + + +def convert_to_rejection_sampling_format( + filtered_samples_path: str, + output_path: str, + task_instruction_template: str = None, + video_fps: float = 2.0, +): + """ + Convert filtered samples to rejection sampling training format for T2V. + + :param filtered_samples_path: Path to filtered samples JSON file + :type filtered_samples_path: str + :param output_path: Path to save converted training data + :type output_path: str + :param task_instruction_template: Template for task instruction + :type task_instruction_template: str, optional + :param video_fps: FPS for video processing + :type video_fps: float + :return: List of training data items in imagegen-cot-reward format (for videos) + :rtype: List[Dict] + """ + logger.info(f"Loading filtered samples from {filtered_samples_path}") + + with open(filtered_samples_path, 'r', encoding='utf-8') as f: + filtered_samples = json.load(f) + + logger.info(f"Loaded {len(filtered_samples)} filtered samples") + + # Default task instruction template for T2V + if task_instruction_template is None: + task_instruction_template = """Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: {prompt}""" + + training_data = [] + + for idx, sample in enumerate(filtered_samples): + prompt = sample['prompt'] + path1 = sample['path1'] + path2 = sample['path2'] + preference = sample['preference'] + generated_text = sample.get('generated_text', '') + reasoning = sample.get('reasoning', '') + + # Determine which video is better based on preference + # preference "A" means Video 1 is better + # preference "B" means Video 2 is better + # For training data, we always use: Video 1 = preferred, Video 2 = rejected + # This ensures consistency + answer = "Video 1 is better" if preference == "A" else "Video 2 is better" + video1_path = path1 if preference == "A" else path2 # preferred + video2_path = path2 if preference == "A" else path1 # rejected + + # Build the response with CoT reasoning + if reasoning: + # Use the extracted reasoning from inference + # Clean up the reasoning text + reasoning_clean = reasoning.strip() + response = f"\n{reasoning_clean}\n\n{answer}" + else: + # If no reasoning was extracted, create a placeholder + response = f"\nBased on the evaluation of semantic consistency, temporal coherence, and authenticity, I will compare the two videos.\n\n{answer}" + + # Build conversations format + task_instruction = task_instruction_template.format(prompt=prompt) + + # Create training data item in imagegen-cot-reward format (but for videos) + # We use "images" field name to be compatible with ImageGenCoTRewardHandler + # but store video paths - the handler will need to be modified to support videos + # For now, we store relative paths from data_root + training_item = { + "conversations": [ + { + "from": "human", + "value": task_instruction + }, + { + "from": "gpt", + "value": response + } + ], + "images": [ + video1_path if os.path.isabs(video1_path) else video1_path, + video2_path if os.path.isabs(video2_path) else video2_path, + ], + "video_fps": video_fps, # Store FPS for video processing + } + + training_data.append(training_item) + + if (idx + 1) % 100 == 0: + logger.info(f"Converted {idx + 1}/{len(filtered_samples)} samples") + + # Save training data + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(training_data, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(training_data)} training samples to {output_path}") + + return training_data + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert filtered samples to rejection sampling training format for T2V") + parser.add_argument("--filtered_samples_path", type=str, required=True, + help="Path to filtered samples JSON file") + parser.add_argument("--output_path", type=str, required=True, + help="Path to save converted training data") + parser.add_argument("--task_instruction", type=str, default=None, + help="Task instruction template (optional)") + parser.add_argument("--video_fps", type=float, default=2.0, + help="FPS for video processing") + + args = parser.parse_args() + + convert_to_rejection_sampling_format( + filtered_samples_path=args.filtered_samples_path, + output_path=args.output_path, + task_instruction_template=args.task_instruction, + video_fps=args.video_fps, + ) + diff --git a/examples/grm_training/rejection_sampling/rejection_sampling_inference.py b/examples/grm_training/rejection_sampling/rejection_sampling_inference.py index f299c07..a6546a4 100644 --- a/examples/grm_training/rejection_sampling/rejection_sampling_inference.py +++ b/examples/grm_training/rejection_sampling/rejection_sampling_inference.py @@ -9,16 +9,14 @@ import os import json import argparse -import torch -import torch.distributed as dist +import re from tqdm import tqdm from typing import List, Dict from loguru import logger from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from lightrft.models import GenerativeRewardModelVL -from transformers import AutoProcessor +from transformers import AutoProcessor, AutoTokenizer +from vllm import LLM, SamplingParams from lightrft.datasets import GRMDataset, extract_answer @@ -48,7 +46,90 @@ """ -@torch.no_grad() +class GRMPromptDatasetVL: + """ + Dataset wrapper for vLLM inference that returns prompts and image/video paths + instead of tokenized inputs. + """ + def __init__( + self, + dataset_paths: List[str], + processor: AutoProcessor, + tokenizer: AutoTokenizer, + strategy=None, + max_length: int = 8192, + config: Dict = None, + is_training: bool = False, + ): + self.base_dataset = GRMDataset( + dataset_paths, + processor=processor, + tokenizer=tokenizer, + strategy=strategy, + max_length=max_length, + config=config, + is_training=is_training, + ) + self.processor = processor + self.tokenizer = tokenizer + + def __len__(self): + return len(self.base_dataset) + + def __getitem__(self, idx): + item = self.base_dataset.data[idx] + source = item["source"] + handler = self.base_dataset.handlers[source] + + # Get media info (paths) + media_info = handler.get_media_info(item) + + # Load media content (needed for parse_item) + loaded_content = self.base_dataset.media_content_loader(media_info) + if loaded_content is None: + raise RuntimeError(f"Failed to load media content: {media_info}") + + # Parse item to get messages + messages, other = handler.parse_item(item, loaded_content, self.base_dataset.config) + + # Get prompt text (exclude the last assistant message for inference) + messages_for_prompt = messages[:-1] if len(messages) > 0 else messages + prompt_text = self.processor.apply_chat_template( + messages_for_prompt, + tokenize=False, + add_generation_prompt=True, + ) + + # Extract image and video paths from media_info for vLLM + image_paths = [] + video_paths = [] + + if media_info: + # media_info is typically a dict with 'images' and 'videos' keys + if isinstance(media_info, dict): + image_paths = media_info.get('images', []) + video_paths = media_info.get('videos', []) + elif isinstance(media_info, list): + # If it's a list, assume all are images + image_paths = media_info + + return prompt_text, image_paths, video_paths, other + + def collate_fn(self, batch): + input_texts = [] + image_inputs_list = [] + video_inputs_list = [] + extras = [] + + for prompt_text, image_paths, video_paths, other in batch: + input_texts.append(prompt_text) + image_inputs_list.append(image_paths if image_paths else None) + video_inputs_list.append(video_paths if video_paths else None) + extras.append(other) + + return input_texts, image_inputs_list, video_inputs_list, extras + + def inference_and_filter( model_path: str, data_path: List[str], @@ -57,8 +138,8 @@ def inference_and_filter( batch_size: int = 32, max_new_tokens: int = 2048, use_cot: bool = True, - local_rank: int = -1, - world_size: int = 1, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, ): """ Perform inference on dataset and filter correctly predicted samples. @@ -77,201 +158,90 @@ def inference_and_filter( :type max_new_tokens: int :param use_cot: Whether to use CoT instruction (for generating reasoning) :type use_cot: bool - :param local_rank: Local rank for distributed inference (-1 for single GPU) - :type local_rank: int - :param world_size: World size for distributed inference (1 for single GPU) - :type world_size: int + :param tensor_parallel_size: Number of GPUs for tensor parallelism + :type tensor_parallel_size: int + :param gpu_memory_utilization: GPU memory utilization ratio + :type gpu_memory_utilization: float :return: List of correctly predicted samples with their generated text and reasoning :rtype: List[Dict] """ - # Initialize distributed training if needed - use_distributed = local_rank >= 0 and world_size > 1 - if use_distributed: - dist.init_process_group(backend='nccl') - torch.cuda.set_device(local_rank) - device = torch.device(f'cuda:{local_rank}') - logger.info(f"Using distributed inference: rank {local_rank}/{world_size} on device {device}") - else: - device = None - logger.info(f"Loading model from: {model_path}") - # Load Model - # Note: Qwen2.5-VL doesn't support dtype parameter, so we disable flash attention - # and handle dtype conversion manually in the model class - if torch.cuda.is_available(): - if use_distributed: - num_gpus = world_size - logger.info(f"Using distributed inference with {num_gpus} GPU(s)") - use_data_parallel = False - else: - num_gpus = torch.cuda.device_count() - logger.info(f"Found {num_gpus} GPU(s)") - - # Use DataParallel if multiple GPUs are available - if num_gpus > 1: - device = torch.device("cuda:0") # Use first GPU as main device - use_data_parallel = True - logger.info(f"Using DataParallel with {num_gpus} GPUs") - else: - device = torch.device(f"cuda:{torch.cuda.current_device()}") - use_data_parallel = False - else: - device = torch.device("cpu") - use_data_parallel = False + # Initialize vLLM + llm = LLM( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + trust_remote_code=True, + gpu_memory_utilization=gpu_memory_utilization, + limit_mm_per_prompt={ + "image": 2, + "video": 2 + }, + ) - model = GenerativeRewardModelVL( - model_path, - bf16=True, - lora_rank=0, - lora_alpha=0, - target_modules=None, - ds_config=None, - device_map=None, # We'll move to device manually - use_flash_attention_2=False, # Disable to avoid dtype issues with Qwen2.5-VL + sampling_params = SamplingParams( + temperature=0.0, # For deterministic output + max_tokens=max_new_tokens, ) - logger.info(f"Model loaded successfully from {model_path}.") - # Move model to device - model.model = model.model.to(device) + logger.info(f"Model loaded successfully from {model_path}.") - # Use DataParallel for multi-GPU inference (non-distributed) - if use_data_parallel: - model.model = torch.nn.DataParallel(model.model) - logger.info("Model wrapped with DataParallel") - elif use_distributed: - # For distributed inference, wrap with DDP - model.model = torch.nn.parallel.DistributedDataParallel( - model.model, - device_ids=[local_rank], - output_device=local_rank, - find_unused_parameters=False - ) - logger.info(f"Model wrapped with DistributedDataParallel on rank {local_rank}") + # Load Processor and Tokenizer for Dataset + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - model.eval() - - processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True, use_fast=False) - # Load Dataset - dataset = GRMDataset( + dataset = GRMPromptDatasetVL( data_path, - tokenizer=processor.tokenizer, - strategy=None, processor=processor, + tokenizer=tokenizer, + strategy=None, max_length=8192, config=config, is_training=False, ) - - # Adjust batch size based on available GPUs - # For distributed inference, batch_size is per GPU - # For single GPU with DataParallel, batch_size is total across all GPUs - if use_distributed: - # In distributed mode, batch_size is already per GPU - effective_batch_size = batch_size - logger.info(f"Using batch size {effective_batch_size} per GPU (distributed mode)") - else: - # For single GPU or DataParallel, use the provided batch_size - # DataParallel will automatically split across GPUs - effective_batch_size = batch_size - logger.info(f"Using batch size {effective_batch_size} (single GPU or DataParallel mode)") - - # Use DistributedSampler for distributed inference - if use_distributed: - sampler = DistributedSampler( - dataset, - num_replicas=world_size, - rank=local_rank, - shuffle=False, - drop_last=False - ) - shuffle = False - else: - sampler = None - shuffle = False data_loader = DataLoader( dataset, - batch_size=effective_batch_size, - shuffle=shuffle, - sampler=sampler, + batch_size=batch_size, + shuffle=False, drop_last=False, - pin_memory=False, # Disable pin_memory to save memory collate_fn=dataset.collate_fn, - num_workers=2, # Reduce workers to save memory ) - - logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {effective_batch_size}") + + logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {batch_size}") correct_samples = [] total_samples = 0 correct_count = 0 - # Clear cache before starting - if torch.cuda.is_available(): - torch.cuda.empty_cache() - for batch_idx, batch in enumerate(tqdm(data_loader)): try: - ids, mask, pixel_values, image_grid_thws, pixel_values_videos, video_grid_thws, labels, extras = batch - - # Ensure device is correct for .to() method - # For DataParallel, use "cuda" to let it handle device placement - if use_data_parallel: - device_str = "cuda" - elif use_distributed: - device_str = device - else: - device_str = str(device) if isinstance(device, torch.device) else device - - ids = ids.squeeze(1).to(device_str, non_blocking=False) - mask = mask.squeeze(1).to(device_str, non_blocking=False) - - if pixel_values is not None: - pixel_values = pixel_values.to(device_str, non_blocking=False) - image_grid_thws = image_grid_thws.to(device_str, non_blocking=False) - - if pixel_values_videos is not None: - pixel_values_videos = pixel_values_videos.to(device_str, non_blocking=False) - video_grid_thws = video_grid_thws.to(device_str, non_blocking=False) - - # Generate with unified max_new_tokens - # Use torch.cuda.amp for mixed precision to save memory - with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16): - # Handle DataParallel and DDP wrapper - if isinstance(model.model, torch.nn.DataParallel): - model_to_use = model.model.module - elif isinstance(model.model, torch.nn.parallel.DistributedDataParallel): - model_to_use = model.model.module - else: - model_to_use = model.model - gen_ids = model_to_use.generate( - input_ids=ids, - attention_mask=mask, - pixel_values=pixel_values, - image_grid_thw=image_grid_thws, - pixel_values_videos=pixel_values_videos, - video_grid_thw=video_grid_thws, - max_new_tokens=max_new_tokens, - do_sample=False, - temperature=0.0, - ) + input_texts, image_inputs_list, video_inputs_list, extras = batch - # Move to CPU and clear GPU memory immediately - ids_cpu = ids.cpu() - gen_ids = gen_ids.cpu() + # Prepare inputs for vLLM + inputs = [] + for i in range(len(input_texts)): + prompt = input_texts[i] + image_inputs = image_inputs_list[i] + video_inputs = video_inputs_list[i] + + mm_data = {} + if image_inputs is not None: + mm_data["image"] = image_inputs + if video_inputs is not None: + mm_data["video"] = video_inputs + + inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data + }) - # Decode (gen_ids is already on CPU) - gen_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(ids_cpu, gen_ids)] - gen_texts = processor.batch_decode(gen_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False) + # Generate with vLLM + outputs = llm.generate(inputs, sampling_params=sampling_params) - # Clear GPU memory immediately - del ids, mask, pixel_values, image_grid_thws, gen_ids, gen_ids_trimmed, ids_cpu - if pixel_values_videos is not None: - del pixel_values_videos, video_grid_thws - if torch.cuda.is_available(): - torch.cuda.empty_cache() + # Decode + gen_texts = [output.outputs[0].text for output in outputs] # Evaluate and filter for i, (gen_text, extra) in enumerate(zip(gen_texts, extras)): @@ -339,62 +309,31 @@ def inference_and_filter( if total_samples % 100 == 0: logger.info(f"Processed {total_samples} samples, {correct_count} correct ({correct_count/total_samples*100:.2f}%)") - except torch.cuda.OutOfMemoryError as e: - logger.error(f"OOM error at batch {batch_idx}. Please restart with --batch_size 1") - # Clear cache - torch.cuda.empty_cache() + except Exception as e: + logger.error(f"Error at batch {batch_idx}: {e}") raise - - # Gather results from all processes if using distributed inference - if use_distributed: - # Gather all correct_samples from all processes - gather_list = [None] * world_size - dist.all_gather_object(gather_list, correct_samples) - - # Flatten the list (only on rank 0) - if local_rank == 0: - correct_samples = [item for sublist in gather_list for item in sublist] - - # Gather statistics using tensors - total_samples_tensor = torch.tensor([total_samples], dtype=torch.long, device=device) - correct_count_tensor = torch.tensor([correct_count], dtype=torch.long, device=device) - dist.all_reduce(total_samples_tensor, op=dist.ReduceOp.SUM) - dist.all_reduce(correct_count_tensor, op=dist.ReduceOp.SUM) - total_samples = int(total_samples_tensor.item()) - correct_count = int(correct_count_tensor.item()) - - # Only rank 0 continues to save results - if local_rank != 0: - if use_distributed: - dist.destroy_process_group() - return correct_samples # Summary accuracy = correct_count / total_samples if total_samples > 0 else 0.0 logger.info(f"Inference completed. Accuracy: {accuracy*100:.2f}% ({correct_count}/{total_samples})") logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training") - - # Save filtered samples (only on rank 0 for distributed) - if not use_distributed or local_rank == 0: - os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, 'w', encoding='utf-8') as f: - json.dump(correct_samples, f, indent=2, ensure_ascii=False) - - logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") - - # Save statistics - stats_path = output_path.replace('.json', '_stats.txt') - with open(stats_path, 'w', encoding='utf-8') as f: - f.write(f"Dataset paths: {data_path}\n") - f.write(f"Model path: {model_path}\n") - f.write(f"Total samples: {total_samples}\n") - f.write(f"Correct samples: {correct_count}\n") - f.write(f"Accuracy: {accuracy*100:.2f}%\n") - f.write(f"Filtered samples for training: {len(correct_samples)}\n") - # Clean up distributed process group - if use_distributed: - dist.destroy_process_group() + # Save filtered samples + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(correct_samples, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") + + # Save statistics + stats_path = output_path.replace('.json', '_stats.txt') + with open(stats_path, 'w', encoding='utf-8') as f: + f.write(f"Dataset paths: {data_path}\n") + f.write(f"Model path: {model_path}\n") + f.write(f"Total samples: {total_samples}\n") + f.write(f"Correct samples: {correct_count}\n") + f.write(f"Accuracy: {accuracy*100:.2f}%\n") + f.write(f"Filtered samples for training: {len(correct_samples)}\n") return correct_samples @@ -409,22 +348,12 @@ def inference_and_filter( parser.add_argument("--use_cot", action="store_true", default=True, help="Use CoT instruction for reasoning") parser.add_argument("--task_instruction", type=str, default=TASK_INSTRUCTION_COT, help="Task instruction template") - # Distributed training arguments (set by torchrun) - parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed inference") - parser.add_argument("--world_size", type=int, default=1, help="World size for distributed inference") + # vLLM arguments + parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism") + parser.add_argument("--gpu_memory_utilization", type=float, default=0.9, help="GPU memory utilization ratio") args = parser.parse_args() - # Get distributed settings from environment if available - local_rank = args.local_rank - world_size = args.world_size - - # Try to get from environment variables (set by torchrun) - if local_rank == -1: - local_rank = int(os.environ.get('LOCAL_RANK', -1)) - if world_size == 1: - world_size = int(os.environ.get('WORLD_SIZE', 1)) - # Parse data path data_paths = [args.data_path] if isinstance(args.data_path, str) else args.data_path.split(",") @@ -441,7 +370,7 @@ def inference_and_filter( batch_size=args.batch_size, max_new_tokens=args.max_new_tokens, use_cot=args.use_cot, - local_rank=local_rank, - world_size=world_size, + tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=args.gpu_memory_utilization, ) diff --git a/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py b/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py new file mode 100644 index 0000000..65d4e3a --- /dev/null +++ b/examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py @@ -0,0 +1,477 @@ +""" +Rejection Sampling Inference Script for Text-to-Video (T2V) + +This script performs inference on a dataset using a trained GRM model, +filters out correctly predicted samples, and generates training data +with CoT reasoning for rejection sampling training. + +For Rapidata-T2V, we compute gt_preference based on the sum of three dimensions: +Alignment + Coherence + Preference +""" + +import os +import json +import argparse +import re +from tqdm import tqdm +from typing import List, Dict +from loguru import logger +from torch.utils.data import DataLoader + +from transformers import AutoProcessor, AutoTokenizer +from vllm import LLM, SamplingParams +from lightrft.datasets import extract_answer, RFTDatasetVL + +# Import qwen_vl_utils for processing vision info +try: + from qwen_vl_utils import process_vision_info +except ImportError: + try: + from keye_vl_utils import process_vision_info + except ImportError: + raise ImportError("Neither qwen_vl_utils nor keye_vl_utils is available") + + +TASK_INSTRUCTION_COT_T2V = """Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + + +class GRMPromptDatasetVLT2V: + """ + Dataset wrapper for vLLM inference that returns prompts and video paths + instead of tokenized inputs. Adapted for T2V with RFTDatasetVL. + """ + def __init__( + self, + dataset_paths: List[str], + processor: AutoProcessor, + tokenizer: AutoTokenizer, + strategy=None, + max_length: int = 8192, + config: Dict = None, + is_training: bool = False, + ): + self.base_dataset = RFTDatasetVL( + dataset_paths, + processor=processor, + tokenizer=tokenizer, + strategy=strategy, + max_length=max_length, + config=config, + is_train=is_training, + ) + self.processor = processor + self.tokenizer = tokenizer + + def __len__(self): + return len(self.base_dataset) + + def __getitem__(self, idx): + item = self.base_dataset.data[idx] + source = item["source"] + handler = self.base_dataset.handlers[source] + + # Get media info (paths) + media_info = handler.get_media_info(item) + + # Load media content (needed for parse_item) + loaded_content = self.base_dataset.media_content_loader(media_info) + if loaded_content is None: + raise RuntimeError(f"Failed to load media content: {media_info}") + + # Parse item to get messages (returns messages0, messages1, other for PairHandler) + messages0, messages1, other = handler.parse_item(item, loaded_content, self.base_dataset.config) + + # Combine messages0 and messages1 to show both videos in the same conversation + # Similar to HPDv3GRMHandler format: system prompt + Video 1 + Video 2 + messages = [] + + # Add system prompt (from messages0) + if len(messages0) > 0 and messages0[0].get("role") == "system": + messages.append(messages0[0]) + + # Add Video 1 with label + if len(messages0) > 1 and messages0[1].get("role") == "user": + video1_content = messages0[1]["content"] + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "**Video 1:**" + }, + video1_content[0] if isinstance(video1_content, list) and len(video1_content) > 0 else video1_content + ] + }) + + # Add Video 2 with label (from messages1) + if len(messages1) > 1 and messages1[1].get("role") == "user": + video2_content = messages1[1]["content"] + messages.append({ + "role": "user", + "content": [ + { + "type": "text", + "text": "**Video 2:**" + }, + video2_content[0] if isinstance(video2_content, list) and len(video2_content) > 0 else video2_content + ] + }) + + # Get prompt text (exclude the last assistant message for inference) + messages_for_prompt = messages[:-1] if len(messages) > 0 and messages[-1].get("role") == "assistant" else messages + prompt_text = self.processor.apply_chat_template( + messages_for_prompt, + tokenize=False, + add_generation_prompt=True, + ) + + # Extract video information from messages using process_vision_info + # This is the same way test_grm_vl_vllm.py does it + # process_vision_info returns (image_inputs, video_inputs, video_kwargs) + # but we only need image_inputs and video_inputs for vLLM + image_inputs, video_inputs, _ = process_vision_info( + messages_for_prompt, + return_video_kwargs=True, + ) + + # Store original item for accessing raw scores + other['_raw_item'] = item + + return prompt_text, image_inputs, video_inputs, other + + def collate_fn(self, batch): + input_texts = [] + image_inputs_list = [] + video_inputs_list = [] + extras = [] + + for prompt_text, image_inputs, video_inputs, other in batch: + input_texts.append(prompt_text) + image_inputs_list.append(image_inputs if image_inputs else None) + video_inputs_list.append(video_inputs if video_inputs else None) + extras.append(other) + + return input_texts, image_inputs_list, video_inputs_list, extras + + +def safe_get_score(item: Dict, key: str, default: float = 0.0) -> float: + """ + Safely get score value from item, handling None values. + + :param item: Dictionary containing score values + :param key: Key to look up in the dictionary + :param default: Default value to use if key is missing or value is None + :return: Float score value + """ + value = item.get(key, default) + return default if value is None else float(value) + + +def compute_total_score(item: Dict, video_num: int) -> float: + """ + Compute total score for a video based on three dimensions. + + :param item: Dictionary containing score values + :param video_num: Video number (1 or 2) + :return: Total score (Alignment + Coherence + Preference) + """ + alignment = safe_get_score(item, f"weighted_results{video_num}_Alignment", 0.0) + coherence = safe_get_score(item, f"weighted_results{video_num}_Coherence", 0.0) + preference = safe_get_score(item, f"weighted_results{video_num}_Preference", 0.0) + return alignment + coherence + preference + + +def compute_gt_preference_from_scores(item: Dict) -> str: + """ + Compute ground truth preference based on sum of three dimensions: + Alignment + Coherence + Preference + + Returns "A" if video1 has higher total score, "B" if video2 has higher total score. + """ + total_score1 = compute_total_score(item, 1) + total_score2 = compute_total_score(item, 2) + + if total_score1 > total_score2: + return "A" # Video 1 is better + elif total_score1 < total_score2: + return "B" # Video 2 is better + else: + return "C" # Equal (shouldn't happen often, but handle it) + + +def inference_and_filter( + model_path: str, + data_path: List[str], + output_path: str, + config: dict = None, + batch_size: int = 32, + max_new_tokens: int = 2048, + use_cot: bool = True, + tensor_parallel_size: int = 1, + gpu_memory_utilization: float = 0.9, + video_fps: float = 2.0, +): + """ + Perform inference on dataset and filter correctly predicted samples. + + :param model_path: Path to the trained GRM model + :type model_path: str + :param data_path: List of dataset paths in format "source:path" + :type data_path: List[str] + :param output_path: Path to save filtered samples + :type output_path: str + :param config: Configuration dict for dataset + :type config: dict, optional + :param batch_size: Batch size for inference + :type batch_size: int + :param max_new_tokens: Maximum tokens to generate + :type max_new_tokens: int + :param use_cot: Whether to use CoT instruction (for generating reasoning) + :type use_cot: bool + :param tensor_parallel_size: Number of GPUs for tensor parallelism + :type tensor_parallel_size: int + :param gpu_memory_utilization: GPU memory utilization ratio + :type gpu_memory_utilization: float + :param video_fps: FPS for video processing + :type video_fps: float + :return: List of correctly predicted samples with their generated text and reasoning + :rtype: List[Dict] + """ + logger.info(f"Loading model from: {model_path}") + + # Initialize vLLM + llm = LLM( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + trust_remote_code=True, + gpu_memory_utilization=gpu_memory_utilization, + limit_mm_per_prompt={ + "image": 0, + "video": 2 + }, + ) + + sampling_params = SamplingParams( + temperature=0.0, # For deterministic output + max_tokens=max_new_tokens, + ) + + logger.info(f"Model loaded successfully from {model_path}.") + + # Load Processor and Tokenizer for Dataset + processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + # Load Dataset + dataset = GRMPromptDatasetVLT2V( + data_path, + processor=processor, + tokenizer=tokenizer, + strategy=None, + max_length=8192, + config=config, + is_training=False, + ) + + data_loader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + collate_fn=dataset.collate_fn, + ) + + logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {batch_size}") + + correct_samples = [] + total_samples = 0 + correct_count = 0 + + for batch_idx, batch in enumerate(tqdm(data_loader)): + try: + input_texts, image_inputs_list, video_inputs_list, extras = batch + + # Prepare inputs for vLLM (same format as test_grm_vl_vllm.py) + inputs = [] + for i in range(len(input_texts)): + prompt = input_texts[i] + image_inputs = image_inputs_list[i] + video_inputs = video_inputs_list[i] + + mm_data = {} + if image_inputs is not None: + mm_data["image"] = image_inputs + if video_inputs is not None: + mm_data["video"] = video_inputs + + inputs.append({ + "prompt": prompt, + "multi_modal_data": mm_data + }) + + # Generate with vLLM + outputs = llm.generate(inputs, sampling_params=sampling_params) + + # Decode + gen_texts = [output.outputs[0].text for output in outputs] + + # Evaluate and filter + for i, (gen_text, extra) in enumerate(zip(gen_texts, extras)): + total_samples += 1 + predicted_answer = extract_answer(gen_text) + + # Get raw item to compute gt_preference from scores + raw_item = extra.get('_raw_item', {}) + gt_preference = compute_gt_preference_from_scores(raw_item) + + # Mapping logic: + # "A" means Video 1 is better + # "B" means Video 2 is better + is_correct = False + if gt_preference == "A" and predicted_answer == "Video 1 is better": + is_correct = True + elif gt_preference == "B" and predicted_answer == "Video 2 is better": + is_correct = True + elif gt_preference == "C": + # Handle tie case (should be rare) + logger.warning(f"Tie detected in sample {total_samples}, skipping") + continue + + if is_correct: + correct_count += 1 + # Get video paths from raw item + data_root = raw_item.get('data_root', '') + video1_path = os.path.join(data_root, "videos", raw_item.get('file_name1', '')) + video2_path = os.path.join(data_root, "videos", raw_item.get('file_name2', '')) + + # Prepare sample for rejection sampling training + sample = { + "prompt": raw_item.get('prompt', ''), + "path1": video1_path, + "path2": video2_path, + "preference": gt_preference, + "generated_text": gen_text, + "predicted_answer": predicted_answer, + "score1_total": compute_total_score(raw_item, 1), + "score2_total": compute_total_score(raw_item, 2), + } + + # If we want to use the generated CoT reasoning, extract it + if use_cot: + # Extract reasoning from generated text + reasoning_match = None + # Try tag first + if "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + # Try as fallback (in case model uses different format) + elif "" in gen_text: + reasoning_pattern = r"(.*?)" + reasoning_match = re.search(reasoning_pattern, gen_text, re.DOTALL) + + if reasoning_match: + reasoning = reasoning_match.group(1).strip() + sample["reasoning"] = reasoning + else: + # If no reasoning found, use the full generated text (excluding answer) + answer_part = f"{predicted_answer}" if predicted_answer else "" + reasoning_candidate = gen_text.replace(answer_part, "").strip() + sample["reasoning"] = reasoning_candidate if reasoning_candidate else None + + correct_samples.append(sample) + + if total_samples % 100 == 0: + logger.info(f"Processed {total_samples} samples, {correct_count} correct ({correct_count/total_samples*100:.2f}%)") + + except Exception as e: + logger.error(f"Error at batch {batch_idx}: {e}") + import traceback + traceback.print_exc() + raise + + # Summary + accuracy = correct_count / total_samples if total_samples > 0 else 0.0 + logger.info(f"Inference completed. Accuracy: {accuracy*100:.2f}% ({correct_count}/{total_samples})") + logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training") + + # Save filtered samples + os.makedirs(os.path.dirname(output_path), exist_ok=True) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(correct_samples, f, indent=2, ensure_ascii=False) + + logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}") + + # Save statistics + stats_path = output_path.replace('.json', '_stats.txt') + with open(stats_path, 'w', encoding='utf-8') as f: + f.write(f"Dataset paths: {data_path}\n") + f.write(f"Model path: {model_path}\n") + f.write(f"Total samples: {total_samples}\n") + f.write(f"Correct samples: {correct_count}\n") + f.write(f"Accuracy: {accuracy*100:.2f}%\n") + f.write(f"Filtered samples for training: {len(correct_samples)}\n") + + return correct_samples + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Rejection Sampling Inference for T2V") + parser.add_argument("--model_path", type=str, required=True, help="Path to trained GRM model") + parser.add_argument("--data_path", type=str, required=True, help="Dataset path(s) in format 'source:path' (comma-separated for multiple)") + parser.add_argument("--output_path", type=str, required=True, help="Path to save filtered samples") + parser.add_argument("--batch_size", type=int, default=32, help="Batch size for inference") + parser.add_argument("--max_new_tokens", type=int, default=2048, help="Maximum tokens to generate") + parser.add_argument("--use_cot", action="store_true", default=True, help="Use CoT instruction for reasoning") + parser.add_argument("--task_instruction", type=str, default=TASK_INSTRUCTION_COT_T2V, help="Task instruction template") + + # vLLM arguments + parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism") + parser.add_argument("--gpu_memory_utilization", type=float, default=0.9, help="GPU memory utilization ratio") + parser.add_argument("--video_fps", type=float, default=2.0, help="FPS for video processing") + + args = parser.parse_args() + + # Parse data path + data_paths = args.data_path.split(",") if isinstance(args.data_path, str) else args.data_path + + config = { + "task_instruction": args.task_instruction, + "name": "rejection_sampling_inference_t2v", + "video_fps": args.video_fps, + } + + inference_and_filter( + model_path=args.model_path, + data_path=data_paths, + output_path=args.output_path, + config=config, + batch_size=args.batch_size, + max_new_tokens=args.max_new_tokens, + use_cot=args.use_cot, + tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=args.gpu_memory_utilization, + video_fps=args.video_fps, + ) + diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh index ca0bed6..277533b 100644 --- a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh @@ -16,13 +16,13 @@ unset HTTPS_PROXY ############################# Configuration ########################## # Model path (cold-start model) # Please set your model path here -MODEL_PATH="" +MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" # Dataset configuration # Please set your dataset path here (format: "source:path") -DATA_PATH="" +DATA_PATH="hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_subset_5percent.json" # Please set your dataset root directory here -DATA_ROOT="" +DATA_ROOT="/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3" # Output paths OUTPUT_DIR="./results/rejection_sampling_$(date +%Y%m%d_%H%M%S)" @@ -49,8 +49,8 @@ aesthetics (composition, color usage, artistic expression), authenticity (realis and any other factors you deem relevant. For each evaluation dimension, provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. Calculate the total score for each image by summing all dimension scores. -Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. -Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' or 'Both are equal' based on the total scores. No additional text is allowed in the section. Example output format: @@ -65,7 +65,8 @@ Image 2: 7+8+5+8=28 Image 1 is better Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. Your task is provided as follows: -Text Caption: {prompt}""" +Text Caption: **{prompt}** +""" ############################### Environment ##################### export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs @@ -110,10 +111,8 @@ echo "" echo "Step 1: Running inference and filtering correct samples..." echo "==========================================" -# Use torchrun for distributed inference to utilize multiple GPUs -torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ - --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ - examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ +# Use vLLM for inference (vLLM handles multi-GPU internally via tensor_parallel_size) +python examples/grm_training/rejection_sampling/rejection_sampling_inference.py \ --model_path ${MODEL_PATH} \ --data_path ${DATA_PATH} \ --output_path ${FILTERED_SAMPLES_PATH} \ @@ -121,6 +120,8 @@ torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ --max_new_tokens ${MAX_NEW_TOKENS} \ --use_cot \ --task_instruction "${TASK_INSTRUCTION}" \ + --tensor_parallel_size ${GPUS_PER_NODE} \ + --gpu_memory_utilization 0.9 \ 2>&1 | tee ${LOG_BASE}/inference.log if [ ! -f "${FILTERED_SAMPLES_PATH}" ]; then diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh new file mode 100755 index 0000000..8b623c8 --- /dev/null +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling_t2v.sh @@ -0,0 +1,197 @@ +#!/bin/bash + +# Rejection Sampling Training Script for Text-to-Video (T2V) +# This script performs the complete rejection sampling pipeline: +# 1. Inference on dataset and filter correct samples +# 2. Convert filtered samples to training format +# 3. Train the model on filtered samples + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (cold-start model) +# Please set your model path here +MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" + +# Dataset configuration +# Multiple rapidata-t2v datasets +DATA_PATH=( + "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences-veo3/data/train-00000-of-00001.parquet" + "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences-pika2.2/data/train-00000-of-00001.parquet" + "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences-wan2.1/data/train-00000-of-00001.parquet" + "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences/data/train-00000-of-00001.parquet" +) + +# Output paths +OUTPUT_DIR="./results/rejection_sampling_t2v_$(date +%Y%m%d_%H%M%S)" +FILTERED_SAMPLES_PATH="${OUTPUT_DIR}/filtered_samples.json" +TRAINING_DATA_PATH="${OUTPUT_DIR}/rejection_sampling_train.json" +FINAL_CHECKPOINT_PATH="${OUTPUT_DIR}/checkpoint" + +# Training hyperparameters +TBS=8 +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=32 + +# Inference parameters (reduced to avoid OOM) +INFERENCE_BATCH_SIZE=8 +MAX_NEW_TOKENS=2048 + +# Video FPS configuration +VIDEO_FPS=2.0 + +# Task instruction for CoT reasoning (T2V) +TASK_INSTRUCTION="""Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Validate required configuration +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +if [ ${#DATA_PATH[@]} -eq 0 ]; then + echo "Error: DATA_PATH is not set. Please configure it in the script." + exit 1 +fi + +# Create output directory +mkdir -p ${OUTPUT_DIR} +LOG_BASE="${OUTPUT_DIR}/logs" +mkdir -p ${LOG_BASE} + +echo "==========================================" +echo "Rejection Sampling Training Pipeline (T2V)" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Data: ${DATA_PATH[@]}" +echo "Output: ${OUTPUT_DIR}" +echo "==========================================" + +############################### Step 1: Inference and Filter ########################## +echo "" +echo "Step 1: Running inference and filtering correct samples..." +echo "==========================================" + +# Convert array to comma-separated string for Python script +DATA_PATH_STR=$(IFS=','; echo "${DATA_PATH[*]}") + +# Use vLLM for inference (vLLM handles multi-GPU internally via tensor_parallel_size) +python examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py \ + --model_path ${MODEL_PATH} \ + --data_path ${DATA_PATH_STR} \ + --output_path ${FILTERED_SAMPLES_PATH} \ + --batch_size ${INFERENCE_BATCH_SIZE} \ + --max_new_tokens ${MAX_NEW_TOKENS} \ + --use_cot \ + --task_instruction "${TASK_INSTRUCTION}" \ + --tensor_parallel_size ${GPUS_PER_NODE} \ + --gpu_memory_utilization 0.9 \ + --video_fps ${VIDEO_FPS} \ + 2>&1 | tee ${LOG_BASE}/inference.log + +if [ ! -f "${FILTERED_SAMPLES_PATH}" ]; then + echo "Error: Filtered samples file not created!" + exit 1 +fi + +echo "Step 1 completed. Filtered samples saved to: ${FILTERED_SAMPLES_PATH}" + +############################### Step 2: Convert to Training Format ########################## +echo "" +echo "Step 2: Converting filtered samples to training format..." +echo "==========================================" + +python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py \ + --filtered_samples_path ${FILTERED_SAMPLES_PATH} \ + --output_path ${TRAINING_DATA_PATH} \ + --task_instruction "${TASK_INSTRUCTION}" \ + --video_fps ${VIDEO_FPS} \ + 2>&1 | tee ${LOG_BASE}/convert.log + +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not created!" + exit 1 +fi + +echo "Step 2 completed. Training data saved to: ${TRAINING_DATA_PATH}" + +############################### Step 3: Training ########################## +echo "" +echo "Step 3: Training on filtered samples..." +echo "==========================================" + +# Use imagegen-cot-reward handler for the converted data (it supports video too) +TRAINING_DATA_SOURCE="imagegen-cot-reward-5k:${TRAINING_DATA_PATH}" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${FINAL_CHECKPOINT_PATH} \ + --ckpt_path ${FINAL_CHECKPOINT_PATH} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps ${VIDEO_FPS} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCE} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 8 \ + --use_tensorboard "${OUTPUT_DIR}/tensorboard" \ + --l2 0.0 \ + --flash_attn \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_BASE}/training.log + +echo "" +echo "==========================================" +echo "Rejection Sampling Training Completed (T2V)!" +echo "==========================================" +echo "Final checkpoint: ${FINAL_CHECKPOINT_PATH}/final_checkpoint" +echo "All outputs saved to: ${OUTPUT_DIR}" +echo "==========================================" + diff --git a/examples/grm_training/rejection_sampling/train_rejection_sampling.sh b/examples/grm_training/rejection_sampling/train_rejection_sampling.sh new file mode 100755 index 0000000..aa6cb53 --- /dev/null +++ b/examples/grm_training/rejection_sampling/train_rejection_sampling.sh @@ -0,0 +1,134 @@ +#!/bin/bash + +# Training script for rejection sampling data +# This script trains the model on the filtered rejection sampling data + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (pretrained model to continue training from) +MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" + +# Training data path (already converted rejection sampling data) +TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260102_022303/rejection_sampling_train.json" + +# Output directory for checkpoints +OUTPUT_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260102_022303/checkpoint" +LOG_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260102_022303/logs" + +# Training hyperparameters +TBS=4 # Reduced from 8 to save memory +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=16 # Increase to maintain effective batch size (4 * 16 = 64) + +# Task instruction for CoT reasoning (must match the one used during inference) +TASK_INSTRUCTION="""Given a caption and two images generated based on this caption, please analyze in detail the two provided images. +Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption), +aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail), +and any other factors you deem relevant. For each evaluation dimension, +provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each image by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within tags. +Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' or 'Both are equal' based on the total scores. +No additional text is allowed in the section. +Example output format: + +Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ... +Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ... +Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ... +[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ... +Total score: +Image 1: 9+8+8+6=31 +Image 2: 7+8+5+8=28 + +Image 1 is better +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs by default +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Memory optimization: reduce fragmentation +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Validate required configuration +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not found: ${TRAINING_DATA_PATH}" + exit 1 +fi + +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +# Create output directories +mkdir -p ${OUTPUT_DIR} +mkdir -p ${LOG_DIR} + +echo "==========================================" +echo "Rejection Sampling Training" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Training Data: ${TRAINING_DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "==========================================" + +# Use imagegen-cot-reward handler for the converted data +TRAINING_DATA_SOURCE="imagegen-cot-reward-5k:${TRAINING_DATA_PATH}" + +############################### Training ########################## +echo "" +echo "Starting training on rejection sampling data..." +echo "==========================================" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${OUTPUT_DIR} \ + --ckpt_path ${OUTPUT_DIR} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps 2.0 \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCE} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 2 \ + --use_tensorboard "${OUTPUT_DIR}/../tensorboard" \ + --l2 0.0 \ + --flash_attn \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_DIR}/training.log + +echo "" +echo "==========================================" +echo "Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint" +echo "Training logs: ${LOG_DIR}/training.log" +echo "==========================================" + diff --git a/lightrft/datasets/__init__.py b/lightrft/datasets/__init__.py index 06e06ac..c5ab354 100644 --- a/lightrft/datasets/__init__.py +++ b/lightrft/datasets/__init__.py @@ -17,5 +17,28 @@ from .prompts_dataset_vl import PromptDatasetVL from .sft_dataset import SFTDataset from .sft_dataset_vl import SFTDatasetVL +from .rft_dataset import RFTDatasetVL -__all__ = ["ProcessRewardDataset", "PromptDataset", "PromptDatasetVL", "SFTDataset", "SFTDatasetVL"] +# Import PairHandlers for RFTDatasetVL +from .rapidata import RapidataT2VPairHandler, RapidataI2VPairHandler +from .hpdv3 import HPDv3PairHandler +from .omnirewardbench import ( + OmniRewardBenchT2IPairHandler, + OmniRewardBenchT2VPairHandler, + VideoGenRewardBenchPairHandler, +) + +__all__ = [ + "ProcessRewardDataset", + "PromptDataset", + "PromptDatasetVL", + "SFTDataset", + "SFTDatasetVL", + "RFTDatasetVL", + "RapidataT2VPairHandler", + "RapidataI2VPairHandler", + "HPDv3PairHandler", + "OmniRewardBenchT2IPairHandler", + "OmniRewardBenchT2VPairHandler", + "VideoGenRewardBenchPairHandler", +] diff --git a/lightrft/datasets/hpdv3.py b/lightrft/datasets/hpdv3.py index 9558a2a..0ab74e1 100644 --- a/lightrft/datasets/hpdv3.py +++ b/lightrft/datasets/hpdv3.py @@ -133,6 +133,10 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], return messages0, messages1, other +# Alias for RFTDatasetVL compatibility +HPDv3PairHandler = HPDv3Handler + + class HPDv3GRMHandler(HPDv3Handler): """ Data Handler for HPDv3 dataset with Generative Reward Model (GRM) training. diff --git a/lightrft/datasets/imagegen_cot_reward.py b/lightrft/datasets/imagegen_cot_reward.py index 1e699c3..1b9aa2b 100644 --- a/lightrft/datasets/imagegen_cot_reward.py +++ b/lightrft/datasets/imagegen_cot_reward.py @@ -31,23 +31,53 @@ def load_data(self, path: str) -> List[Dict[str, Any]]: def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]: """ - Extract path info for the two images. + Extract path info for the two images or videos. + Supports both images and videos based on file extension or video_fps field. """ data_root = item['data_root'] if not data_root: - raise ValueError(f"Missing 'data_root' in item. Cannot resolve image paths.") - images = item['images'] - image0_full_path = os.path.join(data_root, images[0]) - image1_full_path = os.path.join(data_root, images[1]) - - return { - 'image0': { - 'image_local_path': image0_full_path - }, - 'image1': { - 'image_local_path': image1_full_path - }, - } + raise ValueError(f"Missing 'data_root' in item. Cannot resolve media paths.") + media_paths = item['images'] # Can contain image or video paths + path0 = media_paths[0] if isinstance(media_paths[0], str) else media_paths[0] + path1 = media_paths[1] if isinstance(media_paths[1], str) else media_paths[1] + + # Check if paths are absolute or relative + if os.path.isabs(path0): + media0_full_path = path0 + else: + media0_full_path = os.path.join(data_root, path0) + + if os.path.isabs(path1): + media1_full_path = path1 + else: + media1_full_path = os.path.join(data_root, path1) + + # Check if it's video based on file extension or video_fps field + video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'} + is_video = ( + item.get('video_fps') is not None or + any(media0_full_path.lower().endswith(ext) for ext in video_extensions) or + any(media1_full_path.lower().endswith(ext) for ext in video_extensions) + ) + + if is_video: + return { + 'video0': { + 'video_local_path': media0_full_path + }, + 'video1': { + 'video_local_path': media1_full_path + }, + } + else: + return { + 'image0': { + 'image_local_path': media0_full_path + }, + 'image1': { + 'image_local_path': media1_full_path + }, + } def parse_item( self, @@ -55,41 +85,87 @@ def parse_item( media_content: Dict[str, Any], config: Dict[str, Any] | None, ) -> Tuple[List[Dict], List[Dict], Dict]: - - image0 = media_content['image0'] - image1 = media_content['image1'] - - if not all([image0, image1]): - raise ValueError(f"Missing visual content for 'image0' or 'image1'.") - - # Get conversations from data item - conversations = item["conversations"] - system_prompt = conversations[0]['value'] - response = conversations[-1]['value'] - - # Build messages - messages = [{ - "role": "system", - "content": system_prompt - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "**Image 1:**" + + # Check if it's video or image + is_video = 'video0' in media_content or 'video1' in media_content + + if is_video: + video0 = media_content.get('video0') + video1 = media_content.get('video1') + + if not all([video0, video1]): + raise ValueError(f"Missing visual content for 'video0' or 'video1'.") + + # Get FPS from config or item + fps = config.get("video_fps") if config else item.get("video_fps", 2.0) + + # Get conversations from data item + conversations = item["conversations"] + system_prompt = conversations[0]['value'] + response = conversations[-1]['value'] + + # Build messages for video + messages = [{ + "role": "system", + "content": system_prompt }, { - "type": "image", - "image": image0 + "role": "user", + "content": [{ + "type": "text", + "text": "**Video 1:**" + }, { + "type": "video", + "video": video0 if isinstance(video0, str) else video0.get('video_local_path'), + "fps": fps, + "max_pixels": 720 * 480 + }] + }, { + "role": "user", + "content": [{ + "type": "text", + "text": "**Video 2:**" + }, { + "type": "video", + "video": video1 if isinstance(video1, str) else video1.get('video_local_path'), + "fps": fps, + "max_pixels": 720 * 480 + }] }] - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "**Image 2:**" + else: + image0 = media_content.get('image0') + image1 = media_content.get('image1') + + if not all([image0, image1]): + raise ValueError(f"Missing visual content for 'image0' or 'image1'.") + + # Get conversations from data item + conversations = item["conversations"] + system_prompt = conversations[0]['value'] + response = conversations[-1]['value'] + + # Build messages for image + messages = [{ + "role": "system", + "content": system_prompt + }, { + "role": "user", + "content": [{ + "type": "text", + "text": "**Image 1:**" + }, { + "type": "image", + "image": image0 + }] }, { - "type": "image", - "image": image1 + "role": "user", + "content": [{ + "type": "text", + "text": "**Image 2:**" + }, { + "type": "image", + "image": image1 + }] }] - }] # During evaluation, we do not include the response part in the messages is_training = config.get("is_training", True) @@ -97,7 +173,7 @@ def parse_item( messages.append({"role": "assistant", "content": response}) other = { - "source": item['source'], + "source": item.get('source', 'imagegen-cot-reward-5k'), "data_item": item, "system_prompt": system_prompt, "response": response, diff --git a/lightrft/datasets/omnirewardbench.py b/lightrft/datasets/omnirewardbench.py index b5ba9f5..6735b0a 100644 --- a/lightrft/datasets/omnirewardbench.py +++ b/lightrft/datasets/omnirewardbench.py @@ -371,3 +371,19 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "model2": item['model2'], } return messages, other + + +# Aliases for RFTDatasetVL compatibility +OmniRewardBenchT2IPairHandler = OmniRewardBenchT2IHandler +OmniRewardBenchT2VPairHandler = OmniRewardBenchT2VHandler + + +# Placeholder for VideoGenRewardBenchPairHandler +# TODO: Implement this handler when VideoGenRewardBench dataset is available +class VideoGenRewardBenchPairHandler(OmniRewardBenchT2VHandler): + """ + Placeholder handler for VideoGenRewardBench dataset. + Currently uses OmniRewardBenchT2VHandler as a fallback. + TODO: Implement proper handler when dataset is available. + """ + pass diff --git a/lightrft/datasets/rapidata.py b/lightrft/datasets/rapidata.py index cc58546..daaa7c7 100644 --- a/lightrft/datasets/rapidata.py +++ b/lightrft/datasets/rapidata.py @@ -23,6 +23,8 @@ class RapidataT2VHandler(BaseDataHandler): Dataset Repo: https://huggingface.co/Rapidata/datasets """ + task_type = "text-to-video" + def load_data(self, path: str) -> List[Dict[str, Any]]: """ Loads data from parquet file. @@ -52,16 +54,23 @@ def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]: if 'file_name1' not in item or 'file_name2' not in item: raise ValueError(f"Item missing 'file_name1' or 'file_name2'.") - full_path1 = os.path.join(data_root, "videos", item['file_name1']) - full_path2 = os.path.join(data_root, "videos", item['file_name2']) + # Try both "videos" and "Videos" + video_dir = "videos" + if not os.path.exists(os.path.join(data_root, video_dir)): + if os.path.exists(os.path.join(data_root, "Videos")): + video_dir = "Videos" + + full_path1 = os.path.join(data_root, video_dir, item['file_name1']) + full_path2 = os.path.join(data_root, video_dir, item['file_name2']) return {'video1': {'video_local_path': full_path1}, 'video2': {'video_local_path': full_path2}} def _get_label(self, val1: float, val2: float) -> str: """ Helper to determine preference label based on two scores. - A > B, B > A, C == C """ + if val1 is None or val2 is None: + return "C" if val1 > val2: return "A" elif val1 < val2: @@ -85,6 +94,9 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], task_instruction_template = config["task_instruction"] task_instruction = task_instruction_template.format(prompt=video_gen_prompt) + # Get max_pixels from config (default to 720 * 480 if not provided) + max_pixels = config.get("max_pixels", 720 * 480) + # Get FPS from config fps = config["video_fps"] @@ -100,8 +112,8 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "type": "video", "video": video1, "fps": fps, - "max_pixels": 720 * 480 - } # 480p limit to reduce memory + "max_pixels": max_pixels + } ] } ] @@ -115,21 +127,25 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "type": "video", "video": video2, "fps": fps, - "max_pixels": 720 * 480 + "max_pixels": max_pixels }] }] - # Get human preference labels based on weighted scores - pref_label = self._get_label(item["weighted_results1_Preference"], item["weighted_results2_Preference"]) - cohe_label = self._get_label(item["weighted_results1_Coherence"], item["weighted_results2_Coherence"]) - align_label = self._get_label(item['weighted_results1_Alignment'], item['weighted_results2_Alignment']) + # Get human preference labels and total scores based on weighted metrics + metrics = ['Preference', 'Coherence', 'Alignment'] + labels = { + f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), item.get(f'weighted_results2_{m}')) + for m in metrics + } + + score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics) + score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics) other = { - "preference": pref_label, - "coherence": cohe_label, - "alignment": align_label, + "preference": self._get_label(score1, score2), + **labels, "source": item['source'], - "task_type": "t2v", + "task_type": self.task_type, } return messages0, messages1, other @@ -142,6 +158,8 @@ class RapidataI2VHandler(RapidataT2VHandler): Dataset Repo: https://huggingface.co/Rapidata/datasets """ + task_type = "image-to-video" + def __init__(self): super().__init__() @@ -204,6 +222,9 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], task_instruction_template = config["task_instruction"] task_instruction = task_instruction_template.format(prompt=prompt_text) + # Get max_pixels from config (default to 720 * 480 if not provided) + max_pixels = config.get("max_pixels", 720 * 480) + # Get FPS from config fps = config["video_fps"] @@ -216,12 +237,12 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "content": [{ "type": "image", "image": copy.deepcopy(init_image), - "max_pixels": 720 * 480 + "max_pixels": max_pixels }, { "type": "video", "video": video1, "fps": fps, - "max_pixels": 720 * 480 + "max_pixels": max_pixels }] }] @@ -233,25 +254,177 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "content": [{ "type": "image", "image": copy.deepcopy(init_image), - "max_pixels": 720 * 480 + "max_pixels": max_pixels }, { "type": "video", "video": video2, "fps": fps, - "max_pixels": 720 * 480 + "max_pixels": max_pixels }] }] - # Get human preference labels based on weighted scores - pref_label = self._get_label(item['weighted_results1_Preference'], item['weighted_results2_Preference']) - cohe_label = self._get_label(item['weighted_results1_Coherence'], item['weighted_results2_Coherence']) - align_label = self._get_label(item['weighted_results1_Alignment'], item['weighted_results2_Alignment']) + # Get human preference labels and total scores based on weighted metrics + metrics = ['Preference', 'Coherence', 'Alignment'] + labels = { + f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), item.get(f'weighted_results2_{m}')) + for m in metrics + } + + score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics) + score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics) other = { - "preference": pref_label, - "coherence": cohe_label, - "alignment": align_label, + "preference": self._get_label(score1, score2), + **labels, "source": item['source'], - "task_type": "t2v", # Text-to-Video + "task_type": self.task_type, } return messages0, messages1, other + + +class RapidataT2VPairHandler(RapidataT2VHandler): + """ + Data Handler for Rapidata text-to-video human preferences dataset in pairwise format. + """ + def __init__(self): + super().__init__() + + def parse_item(self, + item: Dict[str, Any], + media_content: Dict[str, Any], + config: Dict[str, Any] + ) -> Tuple[List[Dict], List[Dict], Dict]: + + video1 = media_content['video1'] + video2 = media_content['video2'] + + if not all([video1, video2]): + raise ValueError(f"Missing visual content for 'video1' or 'video2'.") + + # Get generation prompt from data item + video_gen_prompt = item["prompt"] + + # Get system prompts from config + task_instruction_template = config["task_instruction"] + task_instruction = task_instruction_template.format(prompt=video_gen_prompt) + + # Get max_pixels from config (default to 720 * 480 if not provided) + max_pixels = config.get("max_pixels", 720 * 480) + + # Get FPS from config + fps = config["video_fps"] + + # Build messages + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": task_instruction}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "The following is the first video."}, + {"type": "video", "video": video1, "fps": fps, "max_pixels": max_pixels}, + + {"type": "text", "text": "The following is the second video."}, + {"type": "video", "video": video2, "fps": fps, "max_pixels": max_pixels}, + ] + } + ] + + # Get human preference labels and total scores based on weighted metrics + metrics = ['Preference', 'Coherence', 'Alignment'] + labels = { + f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), item.get(f'weighted_results2_{m}')) + for m in metrics + } + + score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics) + score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics) + + other = { + "preference": self._get_label(score1, score2), + **labels, + "source": item['source'], + "task_type": self.task_type, + } + return messages, other + + +class RapidataI2VPairHandler(RapidataI2VHandler): + """ + Data Handler for Rapidata image-to-video human preferences dataset in pairwise format. + """ + task_type = "image-to-video" + + def __init__(self): + super().__init__() + + def parse_item(self, + item: Dict[str, Any], + media_content: Dict[str, Any], + config: Dict[str, Any] + ) -> Tuple[List[Dict], List[Dict], Dict]: + + video1 = media_content['video1'] + video2 = media_content['video2'] + init_image = media_content['init_image'] + + if not all([video1, video2, init_image]): + raise ValueError("Missing visual content for 'video1' or 'video2' or 'init_image'.") + + # Get generation prompt from data item + prompt_text = item["prompt"] + + # Get system prompts from config + task_instruction_template = config["task_instruction"] + task_instruction = task_instruction_template.format(prompt=prompt_text) + + # Get FPS from config + fps = config["video_fps"] + max_pixels = config.get("max_pixels", 720 * 480) + + # Build messages + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": task_instruction}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Reference Image:"}, + {"type": "image", "image": init_image, "max_pixels": max_pixels}, + + {"type": "text", "text": "The following is the first video."}, + {"type": "video", "video": video1, "fps": fps, "max_pixels": max_pixels}, + + {"type": "text", "text": "The following is the second video."}, + {"type": "video", "video": video2, "fps": fps, "max_pixels": max_pixels}, + ] + } + ] + + # Get human preference labels and total scores based on weighted metrics + metrics = ['Preference', 'Coherence', 'Alignment'] + labels = { + f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), item.get(f'weighted_results2_{m}')) + for m in metrics + } + + score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics) + score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics) + + other = { + "preference": self._get_label(score1, score2), + **labels, + "source": item['source'], + "task_type": self.task_type, + } + return messages, other + + + +# Alias for RFTDatasetVL compatibility +RapidataT2VPairHandler = RapidataT2VHandler +RapidataI2VPairHandler = RapidataI2VHandler diff --git a/lightrft/datasets/rft_dataset.py b/lightrft/datasets/rft_dataset.py new file mode 100644 index 0000000..c2b8d60 --- /dev/null +++ b/lightrft/datasets/rft_dataset.py @@ -0,0 +1,171 @@ +import random +import copy + +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, AutoProcessor + +from loguru import logger +from typing import Any, Dict, List, Tuple, Union + +from .utils import load_multimodal_content +from lightrft.datasets import ( + RapidataI2VPairHandler, + RapidataT2VPairHandler, + VideoGenRewardBenchPairHandler, + HPDv3PairHandler, + ImageGenCoTRewardHandler, + OmniRewardBenchT2IPairHandler, + OmniRewardBenchT2VPairHandler, +) + + +class RFTDatasetVL(Dataset): + """ + Dataset for Reinforcement Fine-Tuning (RFT) with vision-language models. + + RFTDatasetVL supports multiple data sources through pluggable Data Handlers + and is designed for training models using reinforcement learning. + + It loads data items, processes multimodal content (images, videos), and + prepares inputs suitable for the model. + + :param dataset_paths: List of dataset file paths or directories. The + handler is determined by the source keyword (e.g. "rapidata-t2v", + "videogen-rewardbench"). The format is "source:path". + e.g. "rapidata-t2v:/path/to/file.parquet" + :type dataset_paths: List[str] + :param processor: Multimodal processor used for tokenization and visual + processing. + :type processor: transformers.AutoProcessor + :param tokenizer: Tokenizer used for text tokenization. + :type tokenizer: transformers.AutoTokenizer + :param strategy: Optional data loading strategy. + :type strategy: Any + :param max_length: Maximum sequence length for tokenization/truncation. + Defaults to 4096. + :type max_length: int + :param is_train: Whether the dataset is used for training or evaluation. + Defaults to True. + :type is_train: bool + :param config: Additional configuration options. + :type config: Dict[str, Any] + """ + def __init__( + self, + dataset_paths: List[str], + processor: AutoProcessor, + tokenizer: AutoTokenizer, + strategy = None, + max_length: int = 4096, + is_train: bool = True, + config: Dict[str, Any] = None, + ): + + super().__init__() + self.processor = processor + self.tokenizer = tokenizer + self.strategy = strategy + self.max_length = max_length + self.is_train = is_train + self.config = config if config else {} + + self.media_content_loader = load_multimodal_content + + if "qwen" in self.processor.__class__.__name__.lower(): + from qwen_vl_utils import process_vision_info + self.process_vision_info = process_vision_info + else: + raise NotImplementedError(f"Processor type {self.processor.__class__.__name__} not supported yet.") + + self.handlers = { + "rapidata-i2v": RapidataI2VPairHandler(), + "rapidata-t2v": RapidataT2VPairHandler(), + "videogen-rewardbench": VideoGenRewardBenchPairHandler(), + "hpdv3": HPDv3PairHandler(), + "imagegen-cot-reward-5k": ImageGenCoTRewardHandler(), + "omnirewardbench-t2i": OmniRewardBenchT2IPairHandler(), + "omnirewardbench-t2v": OmniRewardBenchT2VPairHandler(), + } + + # Load data from all specified dataset paths + # We expect dataset_paths to be in the format: "source:path" + # e.g. "rapidata-t2v:/path/to/file.parquet" + self.data = [] + for item in dataset_paths: + try: + source, path = item.split(":", 1) + except ValueError: + raise ValueError(f"Dataset path '{item}' is not in the expected format 'source:path'.") + + if source not in self.handlers: + raise NotImplementedError(f"The data handler for source {source} is not implemented.") + + handler = self.handlers[source] + try: + loaded_items = handler.load_data(path) + for item in loaded_items: + item["source"] = source + self.data.extend(loaded_items) + except Exception as e: + logger.error(f"Failed to load data {path} (source: {source}): {e}") + + logger.info(f"Loaded {len(self.data)} items in total, sources: {[s for s in dataset_paths]}") + random.shuffle(self.data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + source = item["source"] + + handler = self.handlers[source] + + # Get paths for all media content + media_info = handler.get_media_info(item) + + # Load all media content at once + loaded_content = self.media_content_loader(media_info) + if loaded_content is None: + raise RuntimeError(f"Failed to load media content: {media_info}") + + # Select prompt based on task type or source if task_instruction is a dict + config = copy.deepcopy(self.config) + task_instruction = config.get("task_instruction") + if isinstance(task_instruction, dict): + if hasattr(handler, "task_type"): + prompt = task_instruction.get(handler.task_type) + if prompt is None: + raise ValueError(f"Task instruction for {handler.task_type} not found.") + else: + raise ValueError(f"Handler for source {source} does not specify a task_type.") + + config["task_instruction"] = prompt + + # Pass the loaded content dict to parse_item + messages, reference = handler.parse_item(item, loaded_content, config) + + # Prepare inputs from message sequences + input_text, image_inputs, video_inputs = self._prepare_inputs(messages) + + # Configure label, by default "general" + label = "general" + + return input_text, image_inputs, video_inputs, reference, label + + def _prepare_inputs(self, messages): + if not self.is_train: + # For evaluation, we only need the input text without generation prompt + # Remove messages with role "assistant" + messages = [msg for msg in messages if msg["role"] != "assistant"] + + input_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + image_inputs, video_inputs = self.process_vision_info(messages, return_video_kwargs=False) + + return input_text, image_inputs, video_inputs + + def collate_fn(self, batch): + input_texts, image_inputs_list, video_inputs_list, references, labels = zip(*batch) + return list(input_texts), list(image_inputs_list), list(video_inputs_list), list(references), list(labels) \ No newline at end of file diff --git a/lightrft/models/grm_vl.py b/lightrft/models/grm_vl.py index fccaa0f..983c33e 100644 --- a/lightrft/models/grm_vl.py +++ b/lightrft/models/grm_vl.py @@ -65,7 +65,7 @@ def __init__( pretrain_or_model, trust_remote_code=True, attn_implementation=attn_implementation, - dtype=torch.bfloat16 if bf16 else "auto", + torch_dtype=torch.bfloat16 if bf16 else "auto", device_map=device_map, ) diff --git a/lightrft/strategy/strategy_base.py b/lightrft/strategy/strategy_base.py index ab62265..1d3c248 100644 --- a/lightrft/strategy/strategy_base.py +++ b/lightrft/strategy/strategy_base.py @@ -37,7 +37,7 @@ set_sequence_parallel_group, ) from lightrft.strategy.utils.statistic import GenLenAnalyser -from .sglang_utils import get_sglang_engine_for_rollout +# from .sglang_utils import get_sglang_engine_for_rollout from .vllm_utils import get_vllm_engine_for_rollout from lightrft.strategy.config import StrategyConfig From 812e94167a2cb85b1e3306c2b17f82491dfbd2c1 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Thu, 15 Jan 2026 11:24:54 +0800 Subject: [PATCH 5/7] feature(sunjx): add t2v rejection sampling dataset handler --- .../train_rejection_sampling_t2v.sh | 138 ++++++++++++++ lightrft/datasets/grm_dataset.py | 2 + lightrft/datasets/rejection_sampling_t2v.py | 177 ++++++++++++++++++ 3 files changed, 317 insertions(+) create mode 100755 examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh create mode 100644 lightrft/datasets/rejection_sampling_t2v.py diff --git a/examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh b/examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh new file mode 100755 index 0000000..1f05bc7 --- /dev/null +++ b/examples/grm_training/rejection_sampling/train_rejection_sampling_t2v.sh @@ -0,0 +1,138 @@ +#!/bin/bash + +# Training script for rejection sampling T2V (text-to-video) data +# This script trains the model on the filtered rejection sampling video data + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# Model path (pretrained model to continue training from) +MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" + +# Training data path (already converted rejection sampling data) +# This should be the output from convert_to_rejection_sampling_data_t2v.py +TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/rejection_sampling_train.json" + +# Output directory for checkpoints +OUTPUT_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/checkpoint" +LOG_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/logs" + +# Training hyperparameters +TBS=8 +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=32 # Increase to maintain effective batch size + +# Video FPS configuration +VIDEO_FPS=2.0 + +# Task instruction for CoT reasoning (T2V) - must match the one used during inference +TASK_INSTRUCTION="""Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs by default +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Memory optimization: reduce fragmentation +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Validate required configuration +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not found: ${TRAINING_DATA_PATH}" + echo "Please run convert_to_rejection_sampling_data_t2v.py first to convert filtered samples." + exit 1 +fi + +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +# Create output directories +mkdir -p ${OUTPUT_DIR} +mkdir -p ${LOG_DIR} + +echo "==========================================" +echo "Rejection Sampling Training (T2V)" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Training Data: ${TRAINING_DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "Video FPS: ${VIDEO_FPS}" +echo "==========================================" + +# Use rejection-sampling-t2v handler for the converted data +TRAINING_DATA_SOURCE="rejection-sampling-t2v:${TRAINING_DATA_PATH}" + +############################### Training ########################## +echo "" +echo "Starting training on rejection sampling T2V data..." +echo "==========================================" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${OUTPUT_DIR} \ + --ckpt_path ${OUTPUT_DIR} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps ${VIDEO_FPS} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCE} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 8 \ + --use_tensorboard "${OUTPUT_DIR}/../tensorboard" \ + --l2 0.0 \ + --flash_attn \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_DIR}/training.log + +echo "" +echo "==========================================" +echo "Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint" +echo "Training logs: ${LOG_DIR}/training.log" +echo "==========================================" diff --git a/lightrft/datasets/grm_dataset.py b/lightrft/datasets/grm_dataset.py index f65b6ee..2e1bbb1 100644 --- a/lightrft/datasets/grm_dataset.py +++ b/lightrft/datasets/grm_dataset.py @@ -8,6 +8,7 @@ from .omnirewardbench import OmniRewardBenchT2IGRMHandler from .imagegen_cot_reward import ImageGenCoTRewardHandler from .hpdv3 import HPDv3GRMHandler +from .rejection_sampling_t2v import RejectionSamplingT2VHandler from .utils import zero_pad_sequences, load_multimodal_content, find_subsequence @@ -86,6 +87,7 @@ def __init__( "imagegen-cot-reward-5k": ImageGenCoTRewardHandler(), "omnirewardbench-t2i": OmniRewardBenchT2IGRMHandler(), "hpdv3": HPDv3GRMHandler(), + "rejection-sampling-t2v": RejectionSamplingT2VHandler(), } # Load data from all specified dataset paths diff --git a/lightrft/datasets/rejection_sampling_t2v.py b/lightrft/datasets/rejection_sampling_t2v.py new file mode 100644 index 0000000..1388def --- /dev/null +++ b/lightrft/datasets/rejection_sampling_t2v.py @@ -0,0 +1,177 @@ +import os +import copy +import json +from typing import List, Dict, Any, Tuple +from loguru import logger + +from .utils import BaseDataHandler + + +class RejectionSamplingT2VHandler(BaseDataHandler): + """ + Data handler for Rejection Sampling text-to-video training data. + This handler processes video pairs for GRM training, similar to RapidataT2VHandler format. + + The data format is similar to imagegen-cot-reward but specifically for videos. + Each item contains: + - conversations: [{"from": "human", "value": task_instruction}, {"from": "gpt", "value": response}] + - images: [video1_path, video2_path] # Note: uses "images" field name for compatibility + - video_fps: float + """ + task_type = "text-to-video" + + def load_data(self, path: str) -> List[Dict[str, Any]]: + """ + Loads data from json file. + """ + raw_data = [] + with open(path, 'rb') as f: + raw_data = json.load(f) + + data_root = os.path.dirname(path) + for item in raw_data: + item['data_root'] = data_root + + logger.info(f"Loaded {len(raw_data)} samples from {path}") + return raw_data + + def _resolve_video_path(self, path: str, data_root: str) -> str: + """ + Resolve video path, handling case-insensitive 'videos'/'Videos' directory. + Similar to RapidataT2VHandler format. + """ + # Check if path is absolute or relative + if os.path.isabs(path): + full_path = path + else: + full_path = os.path.join(data_root, path) + + # If file exists, return it directly + if os.path.exists(full_path): + return full_path + + # Try to handle videos/Videos case sensitivity issue + # Replace 'videos' with 'Videos' or vice versa in the path + if '/videos/' in full_path: + alt_path = full_path.replace('/videos/', '/Videos/') + if os.path.exists(alt_path): + return alt_path + elif '/Videos/' in full_path: + alt_path = full_path.replace('/Videos/', '/videos/') + if os.path.exists(alt_path): + return alt_path + + # Also handle case where path ends with '/videos' or '/Videos' + if full_path.endswith('/videos'): + alt_path = full_path[:-7] + '/Videos' + if os.path.exists(alt_path): + return alt_path + elif full_path.endswith('/Videos'): + alt_path = full_path[:-7] + '/videos' + if os.path.exists(alt_path): + return alt_path + + # If still not found, return original path (will raise error later) + return full_path + + def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]: + """ + Extract path info for the two videos. + Similar to RapidataT2VHandler format. + Handles case-insensitive 'videos'/'Videos' directory names. + """ + data_root = item['data_root'] + if not data_root: + raise ValueError(f"Missing 'data_root' in item. Cannot resolve video paths.") + + # Get video paths from "images" field (for compatibility with conversion script) + media_paths = item.get('images', []) + if len(media_paths) < 2: + raise ValueError(f"Item must contain at least 2 video paths in 'images' field.") + + path1 = media_paths[0] + path2 = media_paths[1] + + # Resolve paths with case-insensitive handling + video1_full_path = self._resolve_video_path(path1, data_root) + video2_full_path = self._resolve_video_path(path2, data_root) + + return { + 'video1': { + 'video_local_path': video1_full_path + }, + 'video2': { + 'video_local_path': video2_full_path + } + } + + def parse_item( + self, + item: Dict[str, Any], + media_content: Dict[str, Any], + config: Dict[str, Any] | None, + ) -> Tuple[List[Dict], Dict]: + """ + Parse item into messages format, similar to RapidataT2VHandler but for GRM training. + Returns messages in the format expected by GRMDataset. + """ + video1 = media_content.get('video1') + video2 = media_content.get('video2') + + if not all([video1, video2]): + raise ValueError(f"Missing visual content for 'video1' or 'video2'.") + + # Get FPS from config or item + fps = config.get("video_fps") if config else item.get("video_fps", 2.0) + + # Get max_pixels from config (default to 720 * 480 if not provided) + max_pixels = config.get("max_pixels", 720 * 480) if config else 720 * 480 + + # Get conversations from data item + conversations = item["conversations"] + system_prompt = conversations[0]['value'] + response = conversations[-1]['value'] + + # Build messages for video, similar to RapidataT2VHandler format + # But using the format from imagegen_cot_reward for GRM training + # Note: video1 and video2 are already loaded by load_multimodal_content + messages = [{ + "role": "system", + "content": system_prompt + }, { + "role": "user", + "content": [{ + "type": "text", + "text": "**Video 1:**" + }, { + "type": "video", + "video": video1, + "fps": fps, + "max_pixels": max_pixels + }] + }, { + "role": "user", + "content": [{ + "type": "text", + "text": "**Video 2:**" + }, { + "type": "video", + "video": video2, + "fps": fps, + "max_pixels": max_pixels + }] + }] + + # During evaluation, we do not include the response part in the messages + is_training = config.get("is_training", True) if config else True + if is_training: + messages.append({"role": "assistant", "content": response}) + + other = { + "source": item.get('source', 'rejection-sampling-t2v'), + "data_item": item, + "system_prompt": system_prompt, + "response": response, + "task_type": self.task_type, + } + return messages, other From a52bc6961e729ebb95f735b2121981d5713504f3 Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Thu, 15 Jan 2026 17:12:59 +0800 Subject: [PATCH 6/7] feature(sunjx): fix image 1 bugs --- .../rejection_sampling/convert_to_rejection_sampling_data.py | 2 +- .../grm_training/rejection_sampling/run_rejection_sampling.sh | 2 +- .../grm_training/rejection_sampling/train_rejection_sampling.sh | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py index 3046c84..c36e36b 100644 --- a/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py +++ b/examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py @@ -89,7 +89,7 @@ def convert_to_rejection_sampling_format( # For training data, we always use: Image 1 = preferred, Image 2 = rejected # This ensures consistency - answer = "Image 1 is better" + answer = "Image 1 is better" if preference == "A" else "Image 2 is better" image1_path = path1 # preferred image2_path = path2 # rejected diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh index 277533b..e5ab07f 100644 --- a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh @@ -20,7 +20,7 @@ MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5 # Dataset configuration # Please set your dataset path here (format: "source:path") -DATA_PATH="hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_subset_5percent.json" +DATA_PATH="hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_sub5k.json" # Please set your dataset root directory here DATA_ROOT="/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3" diff --git a/examples/grm_training/rejection_sampling/train_rejection_sampling.sh b/examples/grm_training/rejection_sampling/train_rejection_sampling.sh index aa6cb53..eff3c60 100755 --- a/examples/grm_training/rejection_sampling/train_rejection_sampling.sh +++ b/examples/grm_training/rejection_sampling/train_rejection_sampling.sh @@ -15,7 +15,7 @@ unset HTTPS_PROXY MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" # Training data path (already converted rejection sampling data) -TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260102_022303/rejection_sampling_train.json" +TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20251226_185045/filtered_samples.json" # Output directory for checkpoints OUTPUT_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260102_022303/checkpoint" From 5ede7b15d91e8e97c3fb0130deb14c323494f76f Mon Sep 17 00:00:00 2001 From: Sun Jiaxuan Date: Thu, 15 Jan 2026 19:30:01 +0800 Subject: [PATCH 7/7] feature(sunjx): mix train --- .../train_rejection_sampling_mix.sh | 130 ++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh diff --git a/examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh b/examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh new file mode 100644 index 0000000..9b30790 --- /dev/null +++ b/examples/grm_training/rejection_sampling/train_rejection_sampling_mix.sh @@ -0,0 +1,130 @@ +#!/bin/bash + +# Mixed training script for rejection sampling Image (T2I) + Video (T2V) data +# 使用同一个 GRM 模型,在图像和视频拒绝采样数据上进行联合训练。 + +set -e + +unset http_proxy +unset https_proxy +unset HTTP_PROXY +unset HTTPS_PROXY + +############################# Configuration ########################## +# 预训练模型路径(从该模型继续训练) +MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000" + +# 图像拒绝采样数据(convert_to_rejection_sampling_data.py 的输出) +T2I_TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260115_170931/rejection_sampling_train.json" + +# 视频拒绝采样数据(convert_to_rejection_sampling_data_t2v.py 的输出) +T2V_TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/rejection_sampling_train.json" + +# 输出目录 +OUTPUT_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_mix_$(date +%Y%m%d_%H%M%S)/checkpoint" +LOG_DIR="$(dirname "${OUTPUT_DIR}")/logs" + +# 训练超参(可以按需修改) +TBS=8 # global train batch size +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=32 + +# 视频 FPS 配置 +VIDEO_FPS=2.0 + +# 注意: +# - Image 数据的 system prompt(task_instruction)从 T2I_TRAINING_DATA_PATH 对应的 json 里读取; +# - Video 数据的 system prompt 从 T2V_TRAINING_DATA_PATH 对应的 json 里读取; +# 每条样本在 json 的 conversations[0]['value'] 里已经带了各自的 CoT 说明,因此这里不再额外传统一的 TASK_INSTRUCTION。 + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# 减少显存碎片 +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# 检查配置 +if [ ! -f "${T2I_TRAINING_DATA_PATH}" ]; then + echo "Error: T2I training data file not found: ${T2I_TRAINING_DATA_PATH}" + exit 1 +fi + +if [ ! -f "${T2V_TRAINING_DATA_PATH}" ]; then + echo "Error: T2V training data file not found: ${T2V_TRAINING_DATA_PATH}" + exit 1 +fi + +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +# 创建输出目录 +mkdir -p "${OUTPUT_DIR}" +mkdir -p "${LOG_DIR}" + +echo "==========================================" +echo "Rejection Sampling Mixed Training (T2I + T2V)" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "T2I Training Data: ${T2I_TRAINING_DATA_PATH}" +echo "T2V Training Data: ${T2V_TRAINING_DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "Video FPS: ${VIDEO_FPS}" +echo "==========================================" + +# 这里利用 GRMDataset 中的多 handler 能力: +# args.train_data 会在 train_grm_vl.py 中被按逗号切分成 list, +# 每个元素是 "source:path" 的形式。 +T2I_SOURCE="imagegen-cot-reward-5k:${T2I_TRAINING_DATA_PATH}" +T2V_SOURCE="rejection-sampling-t2v:${T2V_TRAINING_DATA_PATH}" + +TRAINING_DATA_SOURCES="${T2I_SOURCE},${T2V_SOURCE}" + +############################### Training ########################## +echo "" +echo "Starting mixed training on T2I + T2V rejection sampling data..." +echo "==========================================" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${OUTPUT_DIR} \ + --ckpt_path ${OUTPUT_DIR} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps ${VIDEO_FPS} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCES} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 8 \ + --use_tensorboard "$(dirname "${OUTPUT_DIR}")/tensorboard" \ + --l2 0.0 \ + --flash_attn \ + 2>&1 | tee ${LOG_DIR}/training.log + +echo "" +echo "==========================================" +echo "Mixed Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint" +echo "Training logs: ${LOG_DIR}/training.log" +echo "==========================================" +