"""
Convert filtered samples to rejection sampling training data format.

This script converts the filtered correct samples into the format required
for rejection sampling training, similar to imagegen-cot-reward dataset.
"""

import os
import json
import argparse
from typing import Dict, List, Optional

try:
    from loguru import logger
except ImportError:  # fallback keeps the module importable without loguru
    import logging

    logger = logging.getLogger(__name__)


# Default task instruction, hoisted to module level so it is built once.
# NOTE(review): the angle-bracket tag literals (<think>/<answer>) were stripped
# from the original source by text mangling; they are reconstructed below --
# confirm the exact tag names against the GRM training configuration.
DEFAULT_TASK_INSTRUCTION = """Given a caption and two images generated based on this caption, please analyze in detail the two provided images.
Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption),
aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail),
and any other factors you deem relevant. For each evaluation dimension,
provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score.
Calculate the total score for each image by summing all dimension scores.
Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within <think> and </think> tags.
Then, in the <answer> tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores.
No additional text is allowed in the <answer> section.
Example output format:
<think>
Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ...
Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ...
Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ...
[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ...
Total score:
Image 1: 9+8+8+6=31
Image 2: 7+8+5+8=28
</think>
<answer>Image 1 is better</answer>
Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images.
Your task is provided as follows:
Text Caption: {prompt}"""


def convert_to_rejection_sampling_format(
    filtered_samples_path: str,
    output_path: str,
    data_root: str,
    task_instruction_template: Optional[str] = None,
) -> List[Dict]:
    """
    Convert filtered samples to rejection sampling training format.

    :param filtered_samples_path: Path to filtered samples JSON file.
    :param output_path: Path to save converted training data (JSON).
    :param data_root: Root directory of the dataset, joined onto relative
        image paths.
    :param task_instruction_template: Optional template with a ``{prompt}``
        placeholder; falls back to :data:`DEFAULT_TASK_INSTRUCTION`.
    :return: List of training data items in imagegen-cot-reward format.
    """
    logger.info(f"Loading filtered samples from {filtered_samples_path}")
    with open(filtered_samples_path, 'r', encoding='utf-8') as f:
        filtered_samples = json.load(f)
    logger.info(f"Loaded {len(filtered_samples)} filtered samples")

    if task_instruction_template is None:
        task_instruction_template = DEFAULT_TASK_INSTRUCTION

    training_data = []

    for idx, sample in enumerate(filtered_samples):
        prompt = sample['prompt']
        path1 = sample['path1']  # preferred path (as stored by the filter step)
        path2 = sample['path2']  # rejected path
        preference = sample['preference']
        reasoning = sample.get('reasoning', '')

        # BUGFIX: the original always assigned Image 1 = preferred / Image 2 =
        # rejected while still emitting "Image 2 is better" for preference "B",
        # so the answer contradicted the image order.  Mirror the T2V
        # converter: order the images as they were shown at inference time so
        # the answer (and the CoT reasoning that refers to "Image 1/2")
        # stays consistent.
        answer = "Image 1 is better" if preference == "A" else "Image 2 is better"
        image1_path = path1 if preference == "A" else path2
        image2_path = path2 if preference == "A" else path1

        # Build the response with CoT reasoning.
        # NOTE(review): <think>/<answer> tags reconstructed (stripped in the
        # original source) -- confirm against the instruction format above.
        if reasoning:
            reasoning_clean = reasoning.strip()
            response = f"<think>\n{reasoning_clean}\n</think>\n<answer>{answer}</answer>"
        else:
            # No reasoning extracted during inference: emit a neutral placeholder.
            response = (
                "<think>\nBased on the evaluation of semantic consistency, aesthetics, "
                "and authenticity, I will compare the two images.\n</think>\n"
                f"<answer>{answer}</answer>"
            )

        task_instruction = task_instruction_template.format(prompt=prompt)

        training_data.append({
            "conversations": [
                {"from": "human", "value": task_instruction},
                {"from": "gpt", "value": response},
            ],
            "images": [
                image1_path if os.path.isabs(image1_path) else os.path.join(data_root, image1_path),
                image2_path if os.path.isabs(image2_path) else os.path.join(data_root, image2_path),
            ],
        })

        if (idx + 1) % 100 == 0:
            logger.info(f"Converted {idx + 1}/{len(filtered_samples)} samples")

    # `or "."` guards against output paths with no directory component.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Saved {len(training_data)} training samples to {output_path}")
    return training_data


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert filtered samples to rejection sampling training format")
    parser.add_argument("--filtered_samples_path", type=str, required=True,
                        help="Path to filtered samples JSON file")
    parser.add_argument("--output_path", type=str, required=True,
                        help="Path to save converted training data")
    parser.add_argument("--data_root", type=str, required=True,
                        help="Root directory of the dataset (for image paths)")
    parser.add_argument("--task_instruction", type=str, default=None,
                        help="Task instruction template (optional)")

    args = parser.parse_args()

    convert_to_rejection_sampling_format(
        filtered_samples_path=args.filtered_samples_path,
        output_path=args.output_path,
        data_root=args.data_root,
        task_instruction_template=args.task_instruction,
    )
"""
Convert filtered samples to rejection sampling training data format for T2V.

This script converts the filtered correct samples into the format required
for rejection sampling training, similar to imagegen-cot-reward dataset but for videos.
"""

import os
import json
import argparse
from typing import Dict, List, Optional

try:
    from loguru import logger
except ImportError:  # fallback keeps the module importable without loguru
    import logging

    logger = logging.getLogger(__name__)


# Default task instruction for T2V, hoisted to module level.
# NOTE(review): the angle-bracket tag literals (<think>/<answer>) were stripped
# from the original source by text mangling; they are reconstructed below --
# confirm the exact tag names against the GRM training configuration.
DEFAULT_TASK_INSTRUCTION_T2V = """Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos.
Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant.
For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score.
Calculate the total score for each video by summing all dimension scores.
Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within <think> and </think> tags. Then, in the <answer> tag, output exactly one of the following strings:
'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the <answer> section.
Example output format:
<think>
1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ...
2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ...
3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ...
...
[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ...
Total score:
Video 1: 9+8+7+6=30
Video 2: 7+6+5+8=26
</think>
<answer>Video 1 is better</answer>

Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos.
Your task is provided as follows:
Text Caption: {prompt}"""


def convert_to_rejection_sampling_format(
    filtered_samples_path: str,
    output_path: str,
    task_instruction_template: Optional[str] = None,
    video_fps: float = 2.0,
) -> List[Dict]:
    """
    Convert filtered samples to rejection sampling training format for T2V.

    :param filtered_samples_path: Path to filtered samples JSON file.
    :param output_path: Path to save converted training data (JSON).
    :param task_instruction_template: Optional template with a ``{prompt}``
        placeholder; falls back to :data:`DEFAULT_TASK_INSTRUCTION_T2V`.
    :param video_fps: FPS stored per item for downstream video processing.
    :return: List of training data items in imagegen-cot-reward format (videos).
    """
    logger.info(f"Loading filtered samples from {filtered_samples_path}")
    with open(filtered_samples_path, 'r', encoding='utf-8') as f:
        filtered_samples = json.load(f)
    logger.info(f"Loaded {len(filtered_samples)} filtered samples")

    if task_instruction_template is None:
        task_instruction_template = DEFAULT_TASK_INSTRUCTION_T2V

    training_data = []

    for idx, sample in enumerate(filtered_samples):
        prompt = sample['prompt']
        path1 = sample['path1']  # preferred path (as stored by the filter step)
        path2 = sample['path2']  # rejected path
        preference = sample['preference']
        reasoning = sample.get('reasoning', '')

        # Order the videos as they were shown at inference time so the answer
        # (and the CoT reasoning referring to "Video 1/2") stays consistent:
        # preference "A" -> Video 1 was the preferred clip, "B" -> Video 2 was.
        answer = "Video 1 is better" if preference == "A" else "Video 2 is better"
        video1_path = path1 if preference == "A" else path2
        video2_path = path2 if preference == "A" else path1

        # Build the response with CoT reasoning.
        # NOTE(review): <think>/<answer> tags reconstructed (stripped in the
        # original source) -- confirm against the instruction format above.
        if reasoning:
            reasoning_clean = reasoning.strip()
            response = f"<think>\n{reasoning_clean}\n</think>\n<answer>{answer}</answer>"
        else:
            response = (
                "<think>\nBased on the evaluation of semantic consistency, temporal coherence, "
                "and authenticity, I will compare the two videos.\n</think>\n"
                f"<answer>{answer}</answer>"
            )

        task_instruction = task_instruction_template.format(prompt=prompt)

        # The "images" field name is kept for compatibility with
        # ImageGenCoTRewardHandler; it carries video paths here, stored
        # verbatim (the original's isabs ternary had identical branches and
        # was dead code -- paths are intentionally NOT joined to a root).
        training_data.append({
            "conversations": [
                {"from": "human", "value": task_instruction},
                {"from": "gpt", "value": response},
            ],
            "images": [video1_path, video2_path],
            "video_fps": video_fps,  # FPS hint for video loading downstream
        })

        if (idx + 1) % 100 == 0:
            logger.info(f"Converted {idx + 1}/{len(filtered_samples)} samples")

    # `or "."` guards against output paths with no directory component.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(training_data, f, indent=2, ensure_ascii=False)

    logger.info(f"Saved {len(training_data)} training samples to {output_path}")
    return training_data


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert filtered samples to rejection sampling training format for T2V")
    parser.add_argument("--filtered_samples_path", type=str, required=True,
                        help="Path to filtered samples JSON file")
    parser.add_argument("--output_path", type=str, required=True,
                        help="Path to save converted training data")
    parser.add_argument("--task_instruction", type=str, default=None,
                        help="Task instruction template (optional)")
    parser.add_argument("--video_fps", type=float, default=2.0,
                        help="FPS for video processing")

    args = parser.parse_args()

    convert_to_rejection_sampling_format(
        filtered_samples_path=args.filtered_samples_path,
        output_path=args.output_path,
        task_instruction_template=args.task_instruction,
        video_fps=args.video_fps,
    )
"""
Rejection Sampling Inference Script

This script performs inference on a dataset using a trained GRM model,
filters out correctly predicted samples, and generates training data
with CoT reasoning for rejection sampling training.
"""

import os
import json
import argparse
import re
from typing import Dict, List, Optional

from loguru import logger
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, AutoTokenizer
from vllm import LLM, SamplingParams

from lightrft.datasets import GRMDataset, extract_answer

# Precompiled patterns for extracting the CoT reasoning block (hoisted out of
# the inference loop; the original also re-ran `import re` per sample).
# NOTE(review): the tag literals were garbled in the original source (angle
# brackets stripped, leaving always-matching patterns like r"(.*?)" and an
# always-true `"" in gen_text`); <think>/<reasoning>/<answer> are
# reconstructed here -- confirm against the tags the GRM model actually emits.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)
_REASONING_RE = re.compile(r"<reasoning>(.*?)</reasoning>", re.DOTALL)


TASK_INSTRUCTION_COT = """Given a caption and two images generated based on this caption, please analyze in detail the two provided images.
Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption),
aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail),
and any other factors you deem relevant. For each evaluation dimension,
provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score.
Calculate the total score for each image by summing all dimension scores.
Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within <think> and </think> tags.
Then, in the <answer> tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' based on the total scores.
No additional text is allowed in the <answer> section.
Example output format:
<think>
Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ...
Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ...
Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ...
[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ...
Total score:
Image 1: 9+8+8+6=31
Image 2: 7+8+5+8=28
</think>
<answer>Image 1 is better</answer>
Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images.
Your task is provided as follows:
Text Caption: {prompt}
"""


class GRMPromptDatasetVL:
    """
    Dataset wrapper for vLLM inference that returns raw prompts and
    image/video paths instead of tokenized inputs.

    Reuses :class:`GRMDataset` so the handlers' message-building logic is
    shared, while vLLM performs its own multimodal preprocessing from paths.
    """
    def __init__(
        self,
        dataset_paths: List[str],
        processor: AutoProcessor,
        tokenizer: AutoTokenizer,
        strategy=None,
        max_length: int = 8192,
        config: Dict = None,
        is_training: bool = False,
    ):
        self.base_dataset = GRMDataset(
            dataset_paths,
            processor=processor,
            tokenizer=tokenizer,
            strategy=strategy,
            max_length=max_length,
            config=config,
            is_training=is_training,
        )
        self.processor = processor
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        """Return ``(prompt_text, image_paths, video_paths, other)`` for *idx*."""
        item = self.base_dataset.data[idx]
        handler = self.base_dataset.handlers[item["source"]]

        # Media paths for this item (dict with 'images'/'videos', or a list).
        media_info = handler.get_media_info(item)

        # parse_item still needs the loaded pixels even though vLLM will
        # re-load the media from the raw paths itself.
        loaded_content = self.base_dataset.media_content_loader(media_info)
        if loaded_content is None:
            raise RuntimeError(f"Failed to load media content: {media_info}")

        messages, other = handler.parse_item(item, loaded_content, self.base_dataset.config)

        # Drop the trailing assistant turn so the model generates it.
        messages_for_prompt = messages[:-1] if len(messages) > 0 else messages
        prompt_text = self.processor.apply_chat_template(
            messages_for_prompt,
            tokenize=False,
            add_generation_prompt=True,
        )

        image_paths: List[str] = []
        video_paths: List[str] = []
        if isinstance(media_info, dict):
            image_paths = media_info.get('images', [])
            video_paths = media_info.get('videos', [])
        elif isinstance(media_info, list):
            # A bare list is assumed to contain image paths only.
            image_paths = media_info

        return prompt_text, image_paths, video_paths, other

    def collate_fn(self, batch):
        """Collate raw tuples; empty media lists become ``None`` for vLLM."""
        input_texts, image_inputs_list, video_inputs_list, extras = [], [], [], []
        for prompt_text, image_paths, video_paths, other in batch:
            input_texts.append(prompt_text)
            image_inputs_list.append(image_paths if image_paths else None)
            video_inputs_list.append(video_paths if video_paths else None)
            extras.append(other)
        return input_texts, image_inputs_list, video_inputs_list, extras


def _extract_reasoning(gen_text: str, predicted_answer: Optional[str]) -> Optional[str]:
    """
    Pull the CoT reasoning block out of *gen_text*.

    Tries the <think> tag first, then <reasoning> as a fallback; if neither is
    present, strips the answer span and returns the remaining text (or None
    when nothing is left).
    """
    for pattern in (_THINK_RE, _REASONING_RE):
        match = pattern.search(gen_text)
        if match:
            return match.group(1).strip()
    # No tagged block: remove the answer part and keep the rest as reasoning.
    if predicted_answer:
        gen_text = gen_text.replace(f"<answer>{predicted_answer}</answer>", "")
    candidate = gen_text.strip()
    return candidate or None


def inference_and_filter(
    model_path: str,
    data_path: List[str],
    output_path: str,
    config: dict = None,
    batch_size: int = 32,
    max_new_tokens: int = 2048,
    use_cot: bool = True,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.9,
):
    """
    Perform inference on dataset and filter correctly predicted samples.

    :param model_path: Path to the trained GRM model.
    :param data_path: List of dataset paths in format "source:path".
    :param output_path: Path to save filtered samples (JSON).
    :param config: Optional configuration dict for the dataset handlers.
    :param batch_size: Batch size for inference.
    :param max_new_tokens: Maximum tokens to generate per sample.
    :param use_cot: Whether to extract and keep the generated CoT reasoning.
    :param tensor_parallel_size: Number of GPUs for tensor parallelism.
    :param gpu_memory_utilization: GPU memory utilization ratio for vLLM.
    :return: List of correctly predicted samples with generated text/reasoning.
    """
    logger.info(f"Loading model from: {model_path}")

    llm = LLM(
        model=model_path,
        tensor_parallel_size=tensor_parallel_size,
        trust_remote_code=True,
        gpu_memory_utilization=gpu_memory_utilization,
        # Each prompt compares exactly two images (or two videos).
        limit_mm_per_prompt={"image": 2, "video": 2},
    )

    # temperature=0 -> greedy decoding, so filtering is reproducible.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=max_new_tokens)

    logger.info(f"Model loaded successfully from {model_path}.")

    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    dataset = GRMPromptDatasetVL(
        data_path,
        processor=processor,
        tokenizer=tokenizer,
        strategy=None,
        max_length=8192,
        config=config,
        is_training=False,
    )
    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=dataset.collate_fn,
    )

    logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {batch_size}")

    correct_samples = []
    total_samples = 0
    correct_count = 0

    for batch_idx, batch in enumerate(tqdm(data_loader)):
        try:
            input_texts, image_inputs_list, video_inputs_list, extras = batch

            inputs = []
            for prompt, image_inputs, video_inputs in zip(input_texts, image_inputs_list, video_inputs_list):
                mm_data = {}
                if image_inputs is not None:
                    mm_data["image"] = image_inputs
                if video_inputs is not None:
                    mm_data["video"] = video_inputs
                inputs.append({"prompt": prompt, "multi_modal_data": mm_data})

            outputs = llm.generate(inputs, sampling_params=sampling_params)
            gen_texts = [output.outputs[0].text for output in outputs]

            for gen_text, extra in zip(gen_texts, extras):
                total_samples += 1
                predicted_answer = extract_answer(gen_text)
                gt_preference = extra['preference']  # "A" or "B"

                # The handler may randomly swap the two images, but the stored
                # preference already refers to the *shown* order: "A" means
                # the first shown image (Image 1) is preferred.
                is_correct = (
                    (gt_preference == "A" and predicted_answer == "Image 1 is better")
                    or (gt_preference == "B" and predicted_answer == "Image 2 is better")
                )

                if is_correct:
                    correct_count += 1
                    sample = {
                        "prompt": extra['prompt'],
                        "path1": extra['preferred_path'],
                        "path2": extra['rejected_path'],
                        "preference": gt_preference,
                        "generated_text": gen_text,
                        "predicted_answer": predicted_answer,
                    }
                    if use_cot:
                        sample["reasoning"] = _extract_reasoning(gen_text, predicted_answer)
                    correct_samples.append(sample)

                if total_samples % 100 == 0:
                    logger.info(
                        f"Processed {total_samples} samples, {correct_count} correct "
                        f"({correct_count / total_samples * 100:.2f}%)"
                    )

        except Exception as e:
            logger.error(f"Error at batch {batch_idx}: {e}")
            raise

    accuracy = correct_count / total_samples if total_samples > 0 else 0.0
    logger.info(f"Inference completed. Accuracy: {accuracy * 100:.2f}% ({correct_count}/{total_samples})")
    logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training")

    # `or "."` guards against output paths with no directory component.
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(correct_samples, f, indent=2, ensure_ascii=False)

    logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}")

    # splitext is safer than str.replace, which could match '.json' mid-path.
    stats_path = os.path.splitext(output_path)[0] + '_stats.txt'
    with open(stats_path, 'w', encoding='utf-8') as f:
        f.write(f"Dataset paths: {data_path}\n")
        f.write(f"Model path: {model_path}\n")
        f.write(f"Total samples: {total_samples}\n")
        f.write(f"Correct samples: {correct_count}\n")
        f.write(f"Accuracy: {accuracy * 100:.2f}%\n")
        f.write(f"Filtered samples for training: {len(correct_samples)}\n")

    return correct_samples


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Rejection Sampling Inference")
    parser.add_argument("--model_path", type=str, required=True, help="Path to trained GRM model")
    parser.add_argument("--data_path", type=str, required=True,
                        help="Dataset path(s) in format 'source:path'; comma-separated for multiple")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save filtered samples")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size for inference")
    parser.add_argument("--max_new_tokens", type=int, default=2048, help="Maximum tokens to generate")
    # BooleanOptionalAction (Python 3.9+) adds --no-use_cot; the original
    # store_true with default=True could never be switched off.
    parser.add_argument("--use_cot", action=argparse.BooleanOptionalAction, default=True,
                        help="Use CoT instruction for reasoning")
    parser.add_argument("--task_instruction", type=str, default=TASK_INSTRUCTION_COT,
                        help="Task instruction template")

    # vLLM arguments
    parser.add_argument("--tensor_parallel_size", type=int, default=1,
                        help="Number of GPUs for tensor parallelism")
    parser.add_argument("--gpu_memory_utilization", type=float, default=0.9,
                        help="GPU memory utilization ratio")

    args = parser.parse_args()

    # argparse always yields a str here, so split unconditionally; a single
    # path without commas still becomes a one-element list.  (The original
    # isinstance check made the comma-split branch unreachable.)
    data_paths = args.data_path.split(",")

    config = {
        "task_instruction": args.task_instruction,
        "name": "rejection_sampling_inference",
    }

    inference_and_filter(
        model_path=args.model_path,
        data_path=data_paths,
        output_path=args.output_path,
        config=config,
        batch_size=args.batch_size,
        max_new_tokens=args.max_new_tokens,
        use_cot=args.use_cot,
        tensor_parallel_size=args.tensor_parallel_size,
        gpu_memory_utilization=args.gpu_memory_utilization,
    )
+ +For Rapidata-T2V, we compute gt_preference based on the sum of three dimensions: +Alignment + Coherence + Preference +""" + +import os +import json +import argparse +import re +from tqdm import tqdm +from typing import List, Dict +from loguru import logger +from torch.utils.data import DataLoader + +from transformers import AutoProcessor, AutoTokenizer +from vllm import LLM, SamplingParams +from lightrft.datasets import extract_answer, RFTDatasetVL + +# Import qwen_vl_utils for processing vision info +try: + from qwen_vl_utils import process_vision_info +except ImportError: + try: + from keye_vl_utils import process_vision_info + except ImportError: + raise ImportError("Neither qwen_vl_utils nor keye_vl_utils is available") + + +TASK_INSTRUCTION_COT_T2V = """Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... 
class GRMPromptDatasetVLT2V:
    """
    Dataset wrapper for vLLM inference that returns prompt strings and raw
    media inputs instead of tokenized tensors.  Adapted for T2V pairs on top
    of ``RFTDatasetVL``.
    """

    def __init__(
        self,
        dataset_paths: List[str],
        processor: AutoProcessor,
        tokenizer: AutoTokenizer,
        strategy=None,
        max_length: int = 8192,
        config: Dict = None,
        is_training: bool = False,
    ):
        # All record loading / per-source handler bookkeeping is delegated
        # to the underlying RFTDatasetVL instance.
        self.base_dataset = RFTDatasetVL(
            dataset_paths,
            processor=processor,
            tokenizer=tokenizer,
            strategy=strategy,
            max_length=max_length,
            config=config,
            is_train=is_training,
        )
        self.processor = processor
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        record = self.base_dataset.data[idx]
        handler = self.base_dataset.handlers[record["source"]]

        # Resolve media paths, then load the actual content — parse_item
        # needs the loaded media to build the conversation messages.
        media_info = handler.get_media_info(record)
        loaded_content = self.base_dataset.media_content_loader(media_info)
        if loaded_content is None:
            raise RuntimeError(f"Failed to load media content: {media_info}")

        # For a pair handler this returns one conversation per video plus a
        # dict of extra metadata.
        messages0, messages1, other = handler.parse_item(record, loaded_content, self.base_dataset.config)

        # Merge the two single-video conversations into one:
        # system prompt, then a labelled user turn per video
        # (same layout as HPDv3GRMHandler).
        messages = []
        if messages0 and messages0[0].get("role") == "system":
            messages.append(messages0[0])

        for label, per_video in (("**Video 1:**", messages0), ("**Video 2:**", messages1)):
            if len(per_video) > 1 and per_video[1].get("role") == "user":
                content = per_video[1]["content"]
                media_part = content[0] if isinstance(content, list) and len(content) > 0 else content
                messages.append({
                    "role": "user",
                    "content": [
                        {"type": "text", "text": label},
                        media_part,
                    ],
                })

        # Drop a trailing assistant turn, if any, so the chat template ends
        # with a generation prompt for inference.
        if messages and messages[-1].get("role") == "assistant":
            messages_for_prompt = messages[:-1]
        else:
            messages_for_prompt = messages
        prompt_text = self.processor.apply_chat_template(
            messages_for_prompt,
            tokenize=False,
            add_generation_prompt=True,
        )

        # process_vision_info returns (image_inputs, video_inputs, video_kwargs);
        # vLLM only needs the first two.
        image_inputs, video_inputs, _ = process_vision_info(
            messages_for_prompt,
            return_video_kwargs=True,
        )

        # Keep the raw record around so downstream code can read raw scores.
        other['_raw_item'] = record

        return prompt_text, image_inputs, video_inputs, other

    def collate_fn(self, batch):
        """Collate per-sample tuples into parallel lists for vLLM."""
        prompts, images, videos, metas = [], [], [], []
        for prompt_text, image_inputs, video_inputs, other in batch:
            prompts.append(prompt_text)
            # Empty/falsy media lists are normalized to None so callers can
            # test `is not None`.
            images.append(image_inputs if image_inputs else None)
            videos.append(video_inputs if video_inputs else None)
            metas.append(other)
        return prompts, images, videos, metas


def safe_get_score(item: Dict, key: str, default: float = 0.0) -> float:
    """
    Safely read a numeric score from *item*, treating a stored ``None`` the
    same as a missing key.

    :param item: Dictionary containing score values
    :param key: Key to look up in the dictionary
    :param default: Value to use if the key is missing or maps to None
    :return: Float score value
    """
    value = item.get(key, default)
    if value is None:
        return default
    return float(value)


def compute_total_score(item: Dict, video_num: int) -> float:
    """
    Total score for video *video_num* (1 or 2):
    Alignment + Coherence + Preference.

    :param item: Dictionary containing ``weighted_results{n}_{dim}`` entries
    :param video_num: Video number (1 or 2)
    :return: Sum of the three dimension scores
    """
    return sum(
        safe_get_score(item, f"weighted_results{video_num}_{dim}", 0.0)
        for dim in ("Alignment", "Coherence", "Preference")
    )
def compute_gt_preference_from_scores(item: Dict) -> str:
    """
    Compute the ground-truth preference from the sum of three dimensions:
    Alignment + Coherence + Preference.

    :param item: Raw dataset record containing the weighted score fields
    :return: "A" if video 1 has the higher total score, "B" if video 2 does,
        "C" on an exact tie
    """
    total_score1 = compute_total_score(item, 1)
    total_score2 = compute_total_score(item, 2)

    if total_score1 > total_score2:
        return "A"  # Video 1 is better
    elif total_score1 < total_score2:
        return "B"  # Video 2 is better
    else:
        return "C"  # Equal (shouldn't happen often, but handle it)


def inference_and_filter(
    model_path: str,
    data_path: List[str],
    output_path: str,
    config: dict = None,
    batch_size: int = 32,
    max_new_tokens: int = 2048,
    use_cot: bool = True,
    tensor_parallel_size: int = 1,
    gpu_memory_utilization: float = 0.9,
    video_fps: float = 2.0,
):
    """
    Perform inference on dataset and filter correctly predicted samples.

    :param model_path: Path to the trained GRM model
    :type model_path: str
    :param data_path: List of dataset paths in format "source:path"
    :type data_path: List[str]
    :param output_path: Path to save filtered samples
    :type output_path: str
    :param config: Configuration dict for dataset
    :type config: dict, optional
    :param batch_size: Batch size for inference
    :type batch_size: int
    :param max_new_tokens: Maximum tokens to generate
    :type max_new_tokens: int
    :param use_cot: Whether to use CoT instruction (for generating reasoning)
    :type use_cot: bool
    :param tensor_parallel_size: Number of GPUs for tensor parallelism
    :type tensor_parallel_size: int
    :param gpu_memory_utilization: GPU memory utilization ratio
    :type gpu_memory_utilization: float
    :param video_fps: FPS for video processing
    :type video_fps: float
    :return: List of correctly predicted samples with their generated text and reasoning
    :rtype: List[Dict]
    """
    logger.info(f"Loading model from: {model_path}")

    # Initialize vLLM; each prompt carries exactly two videos and no images.
    llm = LLM(
        model=model_path,
        tensor_parallel_size=tensor_parallel_size,
        trust_remote_code=True,
        gpu_memory_utilization=gpu_memory_utilization,
        limit_mm_per_prompt={
            "image": 0,
            "video": 2
        },
    )

    sampling_params = SamplingParams(
        temperature=0.0,  # For deterministic output
        max_tokens=max_new_tokens,
    )

    logger.info(f"Model loaded successfully from {model_path}.")

    # Load processor and tokenizer for the dataset wrapper.
    processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    dataset = GRMPromptDatasetVLT2V(
        data_path,
        processor=processor,
        tokenizer=tokenizer,
        strategy=None,
        max_length=8192,
        config=config,
        is_training=False,
    )

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=dataset.collate_fn,
    )

    logger.info(f"Starting inference with CoT: {use_cot}, batch_size: {batch_size}")

    correct_samples = []
    total_samples = 0
    correct_count = 0

    for batch_idx, batch in enumerate(tqdm(data_loader)):
        try:
            input_texts, image_inputs_list, video_inputs_list, extras = batch

            # Prepare inputs for vLLM (same format as test_grm_vl_vllm.py).
            inputs = []
            for prompt, image_inputs, video_inputs in zip(input_texts, image_inputs_list, video_inputs_list):
                mm_data = {}
                if image_inputs is not None:
                    mm_data["image"] = image_inputs
                if video_inputs is not None:
                    mm_data["video"] = video_inputs
                inputs.append({
                    "prompt": prompt,
                    "multi_modal_data": mm_data
                })

            outputs = llm.generate(inputs, sampling_params=sampling_params)
            gen_texts = [output.outputs[0].text for output in outputs]

            # Evaluate each generation against the score-derived ground truth.
            for gen_text, extra in zip(gen_texts, extras):
                total_samples += 1
                predicted_answer = extract_answer(gen_text)

                # Ground truth comes from the raw record's weighted scores.
                raw_item = extra.get('_raw_item', {})
                gt_preference = compute_gt_preference_from_scores(raw_item)

                # Mapping: "A" means Video 1 is better, "B" means Video 2 is better.
                is_correct = False
                if gt_preference == "A" and predicted_answer == "Video 1 is better":
                    is_correct = True
                elif gt_preference == "B" and predicted_answer == "Video 2 is better":
                    is_correct = True
                elif gt_preference == "C":
                    # Ties carry no usable label; skip them (should be rare).
                    logger.warning(f"Tie detected in sample {total_samples}, skipping")
                    continue

                if is_correct:
                    correct_count += 1
                    data_root = raw_item.get('data_root', '')
                    video1_path = os.path.join(data_root, "videos", raw_item.get('file_name1', ''))
                    video2_path = os.path.join(data_root, "videos", raw_item.get('file_name2', ''))

                    # Sample in the rejection-sampling training format.
                    sample = {
                        "prompt": raw_item.get('prompt', ''),
                        "path1": video1_path,
                        "path2": video2_path,
                        "preference": gt_preference,
                        "generated_text": gen_text,
                        "predicted_answer": predicted_answer,
                        "score1_total": compute_total_score(raw_item, 1),
                        "score2_total": compute_total_score(raw_item, 2),
                    }

                    if use_cot:
                        # BUG FIX: the tag literals here were empty strings
                        # (markup stripped), so `"" in gen_text` was always
                        # true and the regex matched empty text.  Restored to
                        # the <think>/<answer> markup the task instruction
                        # asks the model to emit.
                        # NOTE(review): confirm the exact tag names against
                        # the prompt template actually used in training.
                        reasoning_match = None
                        if "<think>" in gen_text:
                            reasoning_match = re.search(r"<think>(.*?)</think>", gen_text, re.DOTALL)
                        elif "<reasoning>" in gen_text:
                            # Fallback in case the model uses a different format.
                            reasoning_match = re.search(r"<reasoning>(.*?)</reasoning>", gen_text, re.DOTALL)

                        if reasoning_match:
                            sample["reasoning"] = reasoning_match.group(1).strip()
                        else:
                            # No explicit reasoning tags: keep the generated
                            # text minus the answer span.
                            answer_part = f"<answer>{predicted_answer}</answer>" if predicted_answer else ""
                            reasoning_candidate = gen_text.replace(answer_part, "").strip()
                            sample["reasoning"] = reasoning_candidate if reasoning_candidate else None

                    correct_samples.append(sample)

                if total_samples % 100 == 0:
                    logger.info(f"Processed {total_samples} samples, {correct_count} correct ({correct_count/total_samples*100:.2f}%)")

        except Exception as e:
            logger.error(f"Error at batch {batch_idx}: {e}")
            import traceback
            traceback.print_exc()
            raise

    # Summary
    accuracy = correct_count / total_samples if total_samples > 0 else 0.0
    logger.info(f"Inference completed. Accuracy: {accuracy*100:.2f}% ({correct_count}/{total_samples})")
    logger.info(f"Filtered {len(correct_samples)} correct samples for rejection sampling training")

    # Save filtered samples.  BUG FIX: guard the dirname — os.makedirs("")
    # raises when output_path is a bare filename in the CWD.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(correct_samples, f, indent=2, ensure_ascii=False)

    logger.info(f"Saved {len(correct_samples)} correct samples to {output_path}")

    # Save run statistics next to the samples.
    stats_path = output_path.replace('.json', '_stats.txt')
    with open(stats_path, 'w', encoding='utf-8') as f:
        f.write(f"Dataset paths: {data_path}\n")
        f.write(f"Model path: {model_path}\n")
        f.write(f"Total samples: {total_samples}\n")
        f.write(f"Correct samples: {correct_count}\n")
        f.write(f"Accuracy: {accuracy*100:.2f}%\n")
        f.write(f"Filtered samples for training: {len(correct_samples)}\n")

    return correct_samples
default=TASK_INSTRUCTION_COT_T2V, help="Task instruction template") + + # vLLM arguments + parser.add_argument("--tensor_parallel_size", type=int, default=1, help="Number of GPUs for tensor parallelism") + parser.add_argument("--gpu_memory_utilization", type=float, default=0.9, help="GPU memory utilization ratio") + parser.add_argument("--video_fps", type=float, default=2.0, help="FPS for video processing") + + args = parser.parse_args() + + # Parse data path + data_paths = args.data_path.split(",") if isinstance(args.data_path, str) else args.data_path + + config = { + "task_instruction": args.task_instruction, + "name": "rejection_sampling_inference_t2v", + "video_fps": args.video_fps, + } + + inference_and_filter( + model_path=args.model_path, + data_path=data_paths, + output_path=args.output_path, + config=config, + batch_size=args.batch_size, + max_new_tokens=args.max_new_tokens, + use_cot=args.use_cot, + tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=args.gpu_memory_utilization, + video_fps=args.video_fps, + ) + diff --git a/examples/grm_training/rejection_sampling/run_rejection_sampling.sh b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh new file mode 100644 index 0000000..e5ab07f --- /dev/null +++ b/examples/grm_training/rejection_sampling/run_rejection_sampling.sh @@ -0,0 +1,193 @@ +#!/bin/bash + +# Rejection Sampling Training Script +# This script performs the complete rejection sampling pipeline: +# 1. Inference on dataset and filter correct samples +# 2. Convert filtered samples to training format +# 3. 
#!/bin/bash

# Rejection Sampling Training Script
# This script performs the complete rejection sampling pipeline:
# 1. Inference on dataset and filter correct samples
# 2. Convert filtered samples to training format
# 3. Train the model on filtered samples

set -e
# BUG FIX: every long step below is `cmd 2>&1 | tee log`; without pipefail
# the pipeline exit status is tee's, so `set -e` never aborts on a failing
# python/torchrun step.
set -o pipefail

unset http_proxy
unset https_proxy
unset HTTP_PROXY
unset HTTPS_PROXY

############################# Configuration ##########################
# Model path (cold-start model)
# Please set your model path here
MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000"

# Dataset configuration
# Please set your dataset path here (format: "source:path")
DATA_PATH="hpdv3:/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3/train_sub5k.json"
# Please set your dataset root directory here
DATA_ROOT="/mnt/shared-storage-user/puyuan/wanzunian/datasets/HPDv3"

# Output paths
OUTPUT_DIR="./results/rejection_sampling_$(date +%Y%m%d_%H%M%S)"
FILTERED_SAMPLES_PATH="${OUTPUT_DIR}/filtered_samples.json"
TRAINING_DATA_PATH="${OUTPUT_DIR}/rejection_sampling_train.json"
FINAL_CHECKPOINT_PATH="${OUTPUT_DIR}/checkpoint"

# Training hyperparameters
TBS=8
LR=2.5e-6
MAX_LENGTH=13000
MAX_EPOCHS=3
MICRO_BATCH_SIZE=1
# NOTE(review): GRADIENT_ACCUMULATION_STEPS is defined but never passed to
# any command below — confirm whether train_grm_vl.py should receive it.
GRADIENT_ACCUMULATION_STEPS=32

# Inference parameters (reduced to avoid OOM)
INFERENCE_BATCH_SIZE=8
MAX_NEW_TOKENS=2048

# Task instruction for CoT reasoning
# NOTE(review): the <think>/<answer> tag markup appears to have been stripped
# from this template text ("within tags", "in the tag") — confirm against the
# template the model was trained with before relying on it.
TASK_INSTRUCTION="""Given a caption and two images generated based on this caption, please analyze in detail the two provided images.
Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption),
aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail),
and any other factors you deem relevant. For each evaluation dimension,
provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score.
Calculate the total score for each image by summing all dimension scores.
Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within tags.
Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' or 'Both are equal' based on the total scores.
No additional text is allowed in the section.
Example output format:

Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ...
Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ...
Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ...
[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ...
Total score:
Image 1: 9+8+8+6=31
Image 2: 7+8+5+8=28

Image 1 is better
Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images.
Your task is provided as follows:
Text Caption: **{prompt}**
"""

############################### Environment #####################
export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs
export NNODES=${NNODES:-1}
export NODE_RANK=${RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-"localhost"}
export MASTER_PORT=${MASTER_PORT:-29500}

export WORLD_SIZE=$((NNODES * GPUS_PER_NODE))

# Validate required configuration
if [ -z "${MODEL_PATH}" ]; then
    echo "Error: MODEL_PATH is not set. Please configure it in the script."
    exit 1
fi

if [ -z "${DATA_PATH}" ]; then
    echo "Error: DATA_PATH is not set. Please configure it in the script."
    exit 1
fi

if [ -z "${DATA_ROOT}" ]; then
    echo "Error: DATA_ROOT is not set. Please configure it in the script."
    exit 1
fi

# Create output directory
mkdir -p ${OUTPUT_DIR}
LOG_BASE="${OUTPUT_DIR}/logs"
mkdir -p ${LOG_BASE}

echo "=========================================="
echo "Rejection Sampling Training Pipeline"
echo "=========================================="
echo "Model: ${MODEL_PATH}"
echo "Data: ${DATA_PATH}"
echo "Output: ${OUTPUT_DIR}"
echo "=========================================="

############################### Step 1: Inference and Filter ##########################
echo ""
echo "Step 1: Running inference and filtering correct samples..."
echo "=========================================="

# Use vLLM for inference (vLLM handles multi-GPU internally via tensor_parallel_size)
python examples/grm_training/rejection_sampling/rejection_sampling_inference.py \
    --model_path ${MODEL_PATH} \
    --data_path ${DATA_PATH} \
    --output_path ${FILTERED_SAMPLES_PATH} \
    --batch_size ${INFERENCE_BATCH_SIZE} \
    --max_new_tokens ${MAX_NEW_TOKENS} \
    --use_cot \
    --task_instruction "${TASK_INSTRUCTION}" \
    --tensor_parallel_size ${GPUS_PER_NODE} \
    --gpu_memory_utilization 0.9 \
    2>&1 | tee ${LOG_BASE}/inference.log

if [ ! -f "${FILTERED_SAMPLES_PATH}" ]; then
    echo "Error: Filtered samples file not created!"
    exit 1
fi

echo "Step 1 completed. Filtered samples saved to: ${FILTERED_SAMPLES_PATH}"

############################### Step 2: Convert to Training Format ##########################
echo ""
echo "Step 2: Converting filtered samples to training format..."
echo "=========================================="

python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data.py \
    --filtered_samples_path ${FILTERED_SAMPLES_PATH} \
    --output_path ${TRAINING_DATA_PATH} \
    --data_root ${DATA_ROOT} \
    --task_instruction "${TASK_INSTRUCTION}" \
    2>&1 | tee ${LOG_BASE}/convert.log

if [ ! -f "${TRAINING_DATA_PATH}" ]; then
    echo "Error: Training data file not created!"
    exit 1
fi

echo "Step 2 completed. Training data saved to: ${TRAINING_DATA_PATH}"

############################### Step 3: Training ##########################
echo ""
echo "Step 3: Training on filtered samples..."
echo "=========================================="

# Use imagegen-cot-reward handler for the converted data
TRAINING_DATA_SOURCE="imagegen-cot-reward-5k:${TRAINING_DATA_PATH}"

torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \
    --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \
    examples/grm_training/train_grm_vl.py \
    --pretrain ${MODEL_PATH} \
    --save_path ${FINAL_CHECKPOINT_PATH} \
    --ckpt_path ${FINAL_CHECKPOINT_PATH} \
    --train_batch_size ${TBS} \
    --micro_train_batch_size ${MICRO_BATCH_SIZE} \
    --max_epochs ${MAX_EPOCHS} \
    --lr_warmup_ratio 0.03 \
    --prompt_max_len ${MAX_LENGTH} \
    --fps 2.0 \
    --zero_stage 3 \
    --bf16 \
    --actor_learning_rate ${LR} \
    --train_data ${TRAINING_DATA_SOURCE} \
    --gradient_checkpointing \
    --save_steps 1000 \
    --max_ckpt_num 8 \
    --use_tensorboard "${OUTPUT_DIR}/tensorboard" \
    --l2 0.0 \
    --flash_attn \
    --task_instruction "${TASK_INSTRUCTION}" \
    2>&1 | tee ${LOG_BASE}/training.log

echo ""
echo "=========================================="
echo "Rejection Sampling Training Completed!"
echo "=========================================="
echo "Final checkpoint: ${FINAL_CHECKPOINT_PATH}/final_checkpoint"
echo "All outputs saved to: ${OUTPUT_DIR}"
echo "=========================================="
#!/bin/bash

# Rejection Sampling Training Script for Text-to-Video (T2V)
# This script performs the complete rejection sampling pipeline:
# 1. Inference on dataset and filter correct samples
# 2. Convert filtered samples to training format
# 3. Train the model on filtered samples

set -e
# BUG FIX: every long step below is `cmd 2>&1 | tee log`; without pipefail
# the pipeline exit status is tee's, so `set -e` never aborts on a failing
# python/torchrun step.
set -o pipefail

unset http_proxy
unset https_proxy
unset HTTP_PROXY
unset HTTPS_PROXY

############################# Configuration ##########################
# Model path (cold-start model)
# Please set your model path here
MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000"

# Dataset configuration
# Multiple rapidata-t2v datasets
DATA_PATH=(
    "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences-veo3/data/train-00000-of-00001.parquet"
    "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences-pika2.2/data/train-00000-of-00001.parquet"
    "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences-wan2.1/data/train-00000-of-00001.parquet"
    "rapidata-t2v:/mnt/shared-storage-user/puyuan/wanzunian/datasets/rapidata/text-2-video-human-preferences/data/train-00000-of-00001.parquet"
)

# Output paths
OUTPUT_DIR="./results/rejection_sampling_t2v_$(date +%Y%m%d_%H%M%S)"
FILTERED_SAMPLES_PATH="${OUTPUT_DIR}/filtered_samples.json"
TRAINING_DATA_PATH="${OUTPUT_DIR}/rejection_sampling_train.json"
FINAL_CHECKPOINT_PATH="${OUTPUT_DIR}/checkpoint"

# Training hyperparameters
TBS=8
LR=2.5e-6
MAX_LENGTH=13000
MAX_EPOCHS=3
MICRO_BATCH_SIZE=1
# NOTE(review): GRADIENT_ACCUMULATION_STEPS is defined but never passed to
# any command below — confirm whether train_grm_vl.py should receive it.
GRADIENT_ACCUMULATION_STEPS=32

# Inference parameters (reduced to avoid OOM)
INFERENCE_BATCH_SIZE=8
MAX_NEW_TOKENS=2048

# Video FPS configuration
VIDEO_FPS=2.0

# Task instruction for CoT reasoning (T2V)
# NOTE(review): the <think>/<answer> tag markup appears to have been stripped
# from this template text ("within and tags", "in the tag") — confirm against
# the template the model was trained with before relying on it.
TASK_INSTRUCTION="""Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos.
Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant.
For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score.
Calculate the total score for each video by summing all dimension scores.
Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings:
'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section.
Example output format:

1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ...
2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ...
3. Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ...
...
[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ...
Total score:
Video 1: 9+8+7+6=30
Video 2: 7+6+5+8=26

Video 1 is better

Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos.
Your task is provided as follows:
Text Caption: **{prompt}**
"""

############################### Environment #####################
export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs
export NNODES=${NNODES:-1}
export NODE_RANK=${RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-"localhost"}
export MASTER_PORT=${MASTER_PORT:-29500}

export WORLD_SIZE=$((NNODES * GPUS_PER_NODE))

# Validate required configuration
if [ -z "${MODEL_PATH}" ]; then
    echo "Error: MODEL_PATH is not set. Please configure it in the script."
    exit 1
fi

if [ ${#DATA_PATH[@]} -eq 0 ]; then
    echo "Error: DATA_PATH is not set. Please configure it in the script."
    exit 1
fi

# Create output directory
mkdir -p ${OUTPUT_DIR}
LOG_BASE="${OUTPUT_DIR}/logs"
mkdir -p ${LOG_BASE}

echo "=========================================="
echo "Rejection Sampling Training Pipeline (T2V)"
echo "=========================================="
echo "Model: ${MODEL_PATH}"
echo "Data: ${DATA_PATH[@]}"
echo "Output: ${OUTPUT_DIR}"
echo "=========================================="

############################### Step 1: Inference and Filter ##########################
echo ""
echo "Step 1: Running inference and filtering correct samples..."
echo "=========================================="

# Convert array to comma-separated string for Python script
DATA_PATH_STR=$(IFS=','; echo "${DATA_PATH[*]}")

# Use vLLM for inference (vLLM handles multi-GPU internally via tensor_parallel_size)
python examples/grm_training/rejection_sampling/rejection_sampling_inference_t2v.py \
    --model_path ${MODEL_PATH} \
    --data_path ${DATA_PATH_STR} \
    --output_path ${FILTERED_SAMPLES_PATH} \
    --batch_size ${INFERENCE_BATCH_SIZE} \
    --max_new_tokens ${MAX_NEW_TOKENS} \
    --use_cot \
    --task_instruction "${TASK_INSTRUCTION}" \
    --tensor_parallel_size ${GPUS_PER_NODE} \
    --gpu_memory_utilization 0.9 \
    --video_fps ${VIDEO_FPS} \
    2>&1 | tee ${LOG_BASE}/inference.log

if [ ! -f "${FILTERED_SAMPLES_PATH}" ]; then
    echo "Error: Filtered samples file not created!"
    exit 1
fi

echo "Step 1 completed. Filtered samples saved to: ${FILTERED_SAMPLES_PATH}"

############################### Step 2: Convert to Training Format ##########################
echo ""
echo "Step 2: Converting filtered samples to training format..."
echo "=========================================="

python examples/grm_training/rejection_sampling/convert_to_rejection_sampling_data_t2v.py \
    --filtered_samples_path ${FILTERED_SAMPLES_PATH} \
    --output_path ${TRAINING_DATA_PATH} \
    --task_instruction "${TASK_INSTRUCTION}" \
    --video_fps ${VIDEO_FPS} \
    2>&1 | tee ${LOG_BASE}/convert.log

if [ ! -f "${TRAINING_DATA_PATH}" ]; then
    echo "Error: Training data file not created!"
    exit 1
fi

echo "Step 2 completed. Training data saved to: ${TRAINING_DATA_PATH}"

############################### Step 3: Training ##########################
echo ""
echo "Step 3: Training on filtered samples..."
echo "=========================================="

# Use imagegen-cot-reward handler for the converted data (it supports video too)
TRAINING_DATA_SOURCE="imagegen-cot-reward-5k:${TRAINING_DATA_PATH}"

torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \
    --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \
    examples/grm_training/train_grm_vl.py \
    --pretrain ${MODEL_PATH} \
    --save_path ${FINAL_CHECKPOINT_PATH} \
    --ckpt_path ${FINAL_CHECKPOINT_PATH} \
    --train_batch_size ${TBS} \
    --micro_train_batch_size ${MICRO_BATCH_SIZE} \
    --max_epochs ${MAX_EPOCHS} \
    --lr_warmup_ratio 0.03 \
    --prompt_max_len ${MAX_LENGTH} \
    --fps ${VIDEO_FPS} \
    --zero_stage 3 \
    --bf16 \
    --actor_learning_rate ${LR} \
    --train_data ${TRAINING_DATA_SOURCE} \
    --gradient_checkpointing \
    --save_steps 1000 \
    --max_ckpt_num 8 \
    --use_tensorboard "${OUTPUT_DIR}/tensorboard" \
    --l2 0.0 \
    --flash_attn \
    --task_instruction "${TASK_INSTRUCTION}" \
    2>&1 | tee ${LOG_BASE}/training.log

echo ""
echo "=========================================="
echo "Rejection Sampling Training Completed (T2V)!"
echo "=========================================="
echo "Final checkpoint: ${FINAL_CHECKPOINT_PATH}/final_checkpoint"
echo "All outputs saved to: ${OUTPUT_DIR}"
echo "=========================================="
#!/bin/bash

# Training script for rejection sampling data
# This script trains the model on the filtered rejection sampling data

set -e
# BUG FIX: the torchrun step is piped through `tee`; without pipefail the
# pipeline exit status is tee's, so `set -e` never aborts on a training
# failure.
set -o pipefail

unset http_proxy
unset https_proxy
unset HTTP_PROXY
unset HTTPS_PROXY

############################# Configuration ##########################
# Model path (pretrained model to continue training from)
MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000"

# Training data path (already converted rejection sampling data)
# NOTE(review): this points at filtered_samples.json, but the comment above
# says "already converted" — the converted file is normally
# rejection_sampling_train.json; confirm which file the trainer expects.
TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20251226_185045/filtered_samples.json"

# Output directory for checkpoints
OUTPUT_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260102_022303/checkpoint"
LOG_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260102_022303/logs"

# Training hyperparameters
TBS=4 # Reduced from 8 to save memory
LR=2.5e-6
MAX_LENGTH=13000
MAX_EPOCHS=3
MICRO_BATCH_SIZE=1
GRADIENT_ACCUMULATION_STEPS=16 # Increase to maintain effective batch size (4 * 16 = 64)

# Task instruction for CoT reasoning (must match the one used during inference)
# NOTE(review): the <think>/<answer> tag markup appears to have been stripped
# from this template text ("within tags", "in the tag") — confirm against the
# template used during inference before relying on it.
TASK_INSTRUCTION="""Given a caption and two images generated based on this caption, please analyze in detail the two provided images.
Evaluate them on various dimensions such as semantic consistency (how closely the image content aligns with the caption),
aesthetics (composition, color usage, artistic expression), authenticity (realism and attention to detail),
and any other factors you deem relevant. For each evaluation dimension,
provide a score between 1-10 for both images (e.g., Image 1: 8/10, Image 2: 6/10) and provide a concise rationale for the score.
Calculate the total score for each image by summing all dimension scores.
Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within tags.
Then, in the tag, output exactly one of the following strings: 'Image 1 is better' or 'Image 2 is better' or 'Both are equal' based on the total scores.
No additional text is allowed in the section.
Example output format:

Semantic consistency: Image 1 (9/10) - ...; Image 2 (7/10) - ...
Aesthetics: Image 2 (8/10) - ...; Image 1 (8/10) - ...
Authenticity: Image 1 (8/10) - ...; Image 2 (5/10) - ...
[Additional dimensions if any]: Image 2 (8/10) - ...; Image 1 (6/10) - ...
Total score:
Image 1: 9+8+8+6=31
Image 2: 7+8+5+8=28

Image 1 is better
Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given images.
Your task is provided as follows:
Text Caption: **{prompt}**
"""

############################### Environment #####################
export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs by default
export NNODES=${NNODES:-1}
export NODE_RANK=${RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-"localhost"}
export MASTER_PORT=${MASTER_PORT:-29500}

export WORLD_SIZE=$((NNODES * GPUS_PER_NODE))

# Memory optimization: reduce fragmentation
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Validate required configuration
if [ ! -f "${TRAINING_DATA_PATH}" ]; then
    echo "Error: Training data file not found: ${TRAINING_DATA_PATH}"
    exit 1
fi

if [ -z "${MODEL_PATH}" ]; then
    echo "Error: MODEL_PATH is not set. Please configure it in the script."
    exit 1
fi

# Create output directories
mkdir -p ${OUTPUT_DIR}
mkdir -p ${LOG_DIR}

echo "=========================================="
echo "Rejection Sampling Training"
echo "=========================================="
echo "Model: ${MODEL_PATH}"
echo "Training Data: ${TRAINING_DATA_PATH}"
echo "Output: ${OUTPUT_DIR}"
echo "GPUs: ${GPUS_PER_NODE}"
echo "=========================================="

# Use imagegen-cot-reward handler for the converted data
TRAINING_DATA_SOURCE="imagegen-cot-reward-5k:${TRAINING_DATA_PATH}"

############################### Training ##########################
echo ""
echo "Starting training on rejection sampling data..."
echo "=========================================="

torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \
    --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \
    examples/grm_training/train_grm_vl.py \
    --pretrain ${MODEL_PATH} \
    --save_path ${OUTPUT_DIR} \
    --ckpt_path ${OUTPUT_DIR} \
    --train_batch_size ${TBS} \
    --micro_train_batch_size ${MICRO_BATCH_SIZE} \
    --max_epochs ${MAX_EPOCHS} \
    --lr_warmup_ratio 0.03 \
    --prompt_max_len ${MAX_LENGTH} \
    --fps 2.0 \
    --zero_stage 3 \
    --bf16 \
    --actor_learning_rate ${LR} \
    --train_data ${TRAINING_DATA_SOURCE} \
    --gradient_checkpointing \
    --save_steps 1000 \
    --max_ckpt_num 2 \
    --use_tensorboard "${OUTPUT_DIR}/../tensorboard" \
    --l2 0.0 \
    --flash_attn \
    --task_instruction "${TASK_INSTRUCTION}" \
    2>&1 | tee ${LOG_DIR}/training.log

echo ""
echo "=========================================="
echo "Training Completed!"
echo "=========================================="
echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint"
echo "Training logs: ${LOG_DIR}/training.log"
echo "=========================================="
#!/bin/bash

# Mixed training script for rejection sampling Image (T2I) + Video (T2V) data
# Jointly trains a single GRM model on both image and video rejection-sampling
# datasets.

set -e
# BUG FIX: the torchrun step is piped through `tee`; without pipefail the
# pipeline exit status is tee's, so `set -e` never aborts on a training
# failure.
set -o pipefail

unset http_proxy
unset https_proxy
unset HTTP_PROXY
unset HTTPS_PROXY

############################# Configuration ##########################
# Pretrained model path (training continues from this model)
MODEL_PATH="/mnt/shared-storage-user/puyuan/wanzunian/models/lightrlhf-grm-lr1e5-imagegen_cot_reward-qwen2.5vl3B-gs3000"

# Image rejection-sampling data (output of convert_to_rejection_sampling_data.py)
T2I_TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_20260115_170931/rejection_sampling_train.json"

# Video rejection-sampling data (output of convert_to_rejection_sampling_data_t2v.py)
T2V_TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/rejection_sampling_train.json"

# Output directories
OUTPUT_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_mix_$(date +%Y%m%d_%H%M%S)/checkpoint"
LOG_DIR="$(dirname "${OUTPUT_DIR}")/logs"

# Training hyperparameters (adjust as needed)
TBS=8 # global train batch size
LR=2.5e-6
MAX_LENGTH=13000
MAX_EPOCHS=3
MICRO_BATCH_SIZE=1
GRADIENT_ACCUMULATION_STEPS=32

# Video FPS configuration
VIDEO_FPS=2.0

# Note:
# - The system prompt (task_instruction) for image data is read from the json
#   at T2I_TRAINING_DATA_PATH;
# - The system prompt for video data is read from the json at
#   T2V_TRAINING_DATA_PATH;
# Each sample already carries its own CoT instruction in
# conversations[0]['value'], so no unified TASK_INSTRUCTION is passed here.

############################### Environment #####################
export GPUS_PER_NODE=${GPUS_PER_NODE:-2}
export NNODES=${NNODES:-1}
export NODE_RANK=${RANK:-0}
export MASTER_ADDR=${MASTER_ADDR:-"localhost"}
export MASTER_PORT=${MASTER_PORT:-29500}

export WORLD_SIZE=$((NNODES * GPUS_PER_NODE))

# Reduce GPU memory fragmentation
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

# Validate required configuration
if [ ! -f "${T2I_TRAINING_DATA_PATH}" ]; then
    echo "Error: T2I training data file not found: ${T2I_TRAINING_DATA_PATH}"
    exit 1
fi

if [ ! -f "${T2V_TRAINING_DATA_PATH}" ]; then
    echo "Error: T2V training data file not found: ${T2V_TRAINING_DATA_PATH}"
    exit 1
fi

if [ -z "${MODEL_PATH}" ]; then
    echo "Error: MODEL_PATH is not set. Please configure it in the script."
    exit 1
fi

# Create output directories
mkdir -p "${OUTPUT_DIR}"
mkdir -p "${LOG_DIR}"

echo "=========================================="
echo "Rejection Sampling Mixed Training (T2I + T2V)"
echo "=========================================="
echo "Model: ${MODEL_PATH}"
echo "T2I Training Data: ${T2I_TRAINING_DATA_PATH}"
echo "T2V Training Data: ${T2V_TRAINING_DATA_PATH}"
echo "Output: ${OUTPUT_DIR}"
echo "GPUs: ${GPUS_PER_NODE}"
echo "Video FPS: ${VIDEO_FPS}"
echo "=========================================="

# This relies on GRMDataset's multi-handler support:
# args.train_data is split on commas inside train_grm_vl.py into a list where
# each element has the form "source:path".
T2I_SOURCE="imagegen-cot-reward-5k:${T2I_TRAINING_DATA_PATH}"
T2V_SOURCE="rejection-sampling-t2v:${T2V_TRAINING_DATA_PATH}"

TRAINING_DATA_SOURCES="${T2I_SOURCE},${T2V_SOURCE}"

############################### Training ##########################
echo ""
echo "Starting mixed training on T2I + T2V rejection sampling data..."
echo "=========================================="

torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \
    --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \
    examples/grm_training/train_grm_vl.py \
    --pretrain ${MODEL_PATH} \
    --save_path ${OUTPUT_DIR} \
    --ckpt_path ${OUTPUT_DIR} \
    --train_batch_size ${TBS} \
    --micro_train_batch_size ${MICRO_BATCH_SIZE} \
    --max_epochs ${MAX_EPOCHS} \
    --lr_warmup_ratio 0.03 \
    --prompt_max_len ${MAX_LENGTH} \
    --fps ${VIDEO_FPS} \
    --zero_stage 3 \
    --bf16 \
    --actor_learning_rate ${LR} \
    --train_data ${TRAINING_DATA_SOURCES} \
    --gradient_checkpointing \
    --save_steps 1000 \
    --max_ckpt_num 8 \
    --use_tensorboard "$(dirname "${OUTPUT_DIR}")/tensorboard" \
    --l2 0.0 \
    --flash_attn \
    2>&1 | tee ${LOG_DIR}/training.log

echo ""
echo "=========================================="
echo "Mixed Training Completed!"
echo "=========================================="
echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint"
echo "Training logs: ${LOG_DIR}/training.log"
echo "=========================================="
rejection sampling data) +# This should be the output from convert_to_rejection_sampling_data_t2v.py +TRAINING_DATA_PATH="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/rejection_sampling_train.json" + +# Output directory for checkpoints +OUTPUT_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/checkpoint" +LOG_DIR="/mnt/shared-storage-user/sunjiaxuan/dec/LightRFT/results/rejection_sampling_t2v_20260104_193830/logs" + +# Training hyperparameters +TBS=8 +LR=2.5e-6 +MAX_LENGTH=13000 +MAX_EPOCHS=3 +MICRO_BATCH_SIZE=1 +GRADIENT_ACCUMULATION_STEPS=32 # Increase to maintain effective batch size + +# Video FPS configuration +VIDEO_FPS=2.0 + +# Task instruction for CoT reasoning (T2V) - must match the one used during inference +TASK_INSTRUCTION="""Given a caption and two videos generated based on this caption, please analyze in detail the two provided videos. +Evaluate them on various dimensions such as semantic consistency (how closely the video content aligns with the caption), temporal coherence (smoothness and logical flow of motion across frames), authenticity (realism and attention to detail), and any other factors you deem relevant. +For each evaluation dimension, provide a score between 1-10 for both videos (e.g., Video 1: 8/10, Video 2: 6/10) and provide a concise rationale for the score. +Calculate the total score for each video by summing all dimension scores. +Use a chain-of-thought process to detail your reasoning steps, and enclose all your detailed reasoning within and tags. Then, in the tag, output exactly one of the following strings: +'Video 1 is better' or 'Video 2 is better' based on the total scores. No additional text is allowed in the section. +Example output format: + +1. Semantic consistency: Video 1 (9/10) - ...; Video 2 (7/10) - ... +2. Temporal coherence: Video 1 (8/10) - ...; Video 2 (6/10) - ... +3. 
Authenticity: Video 1 (7/10) - ...; Video 2 (5/10) - ... +... +[Additional dimensions if any]: Video 2 (8/10) - ...; Video 1 (6/10) - ... +Total score: +Video 1: 9+8+7+6=30 +Video 2: 7+6+5+8=26 + +Video 1 is better + +Note: In the example above, scores and the final answer are placeholders meant only to demonstrate the format. Your actual evaluation should be based on the quality of two given videos. +Your task is provided as follows: +Text Caption: **{prompt}** +""" + +############################### Environment ##################### +export GPUS_PER_NODE=${GPUS_PER_NODE:-2} # Use 2 GPUs by default +export NNODES=${NNODES:-1} +export NODE_RANK=${RANK:-0} +export MASTER_ADDR=${MASTER_ADDR:-"localhost"} +export MASTER_PORT=${MASTER_PORT:-29500} + +export WORLD_SIZE=$((NNODES * GPUS_PER_NODE)) + +# Memory optimization: reduce fragmentation +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + +# Validate required configuration +if [ ! -f "${TRAINING_DATA_PATH}" ]; then + echo "Error: Training data file not found: ${TRAINING_DATA_PATH}" + echo "Please run convert_to_rejection_sampling_data_t2v.py first to convert filtered samples." + exit 1 +fi + +if [ -z "${MODEL_PATH}" ]; then + echo "Error: MODEL_PATH is not set. Please configure it in the script." + exit 1 +fi + +# Create output directories +mkdir -p ${OUTPUT_DIR} +mkdir -p ${LOG_DIR} + +echo "==========================================" +echo "Rejection Sampling Training (T2V)" +echo "==========================================" +echo "Model: ${MODEL_PATH}" +echo "Training Data: ${TRAINING_DATA_PATH}" +echo "Output: ${OUTPUT_DIR}" +echo "GPUs: ${GPUS_PER_NODE}" +echo "Video FPS: ${VIDEO_FPS}" +echo "==========================================" + +# Use rejection-sampling-t2v handler for the converted data +TRAINING_DATA_SOURCE="rejection-sampling-t2v:${TRAINING_DATA_PATH}" + +############################### Training ########################## +echo "" +echo "Starting training on rejection sampling T2V data..." 
+echo "==========================================" + +torchrun --nnodes $NNODES --nproc-per-node $GPUS_PER_NODE \ + --node_rank $NODE_RANK --master-port $MASTER_PORT --master-addr $MASTER_ADDR \ + examples/grm_training/train_grm_vl.py \ + --pretrain ${MODEL_PATH} \ + --save_path ${OUTPUT_DIR} \ + --ckpt_path ${OUTPUT_DIR} \ + --train_batch_size ${TBS} \ + --micro_train_batch_size ${MICRO_BATCH_SIZE} \ + --max_epochs ${MAX_EPOCHS} \ + --lr_warmup_ratio 0.03 \ + --prompt_max_len ${MAX_LENGTH} \ + --fps ${VIDEO_FPS} \ + --zero_stage 3 \ + --bf16 \ + --actor_learning_rate ${LR} \ + --train_data ${TRAINING_DATA_SOURCE} \ + --gradient_checkpointing \ + --save_steps 1000 \ + --max_ckpt_num 8 \ + --use_tensorboard "${OUTPUT_DIR}/../tensorboard" \ + --l2 0.0 \ + --flash_attn \ + --task_instruction "${TASK_INSTRUCTION}" \ + 2>&1 | tee ${LOG_DIR}/training.log + +echo "" +echo "==========================================" +echo "Training Completed!" +echo "==========================================" +echo "Final checkpoint: ${OUTPUT_DIR}/final_checkpoint" +echo "Training logs: ${LOG_DIR}/training.log" +echo "==========================================" diff --git a/lightrft/datasets/__init__.py b/lightrft/datasets/__init__.py index 06e06ac..c5ab354 100644 --- a/lightrft/datasets/__init__.py +++ b/lightrft/datasets/__init__.py @@ -17,5 +17,28 @@ from .prompts_dataset_vl import PromptDatasetVL from .sft_dataset import SFTDataset from .sft_dataset_vl import SFTDatasetVL +from .rft_dataset import RFTDatasetVL -__all__ = ["ProcessRewardDataset", "PromptDataset", "PromptDatasetVL", "SFTDataset", "SFTDatasetVL"] +# Import PairHandlers for RFTDatasetVL +from .rapidata import RapidataT2VPairHandler, RapidataI2VPairHandler +from .hpdv3 import HPDv3PairHandler +from .omnirewardbench import ( + OmniRewardBenchT2IPairHandler, + OmniRewardBenchT2VPairHandler, + VideoGenRewardBenchPairHandler, +) + +__all__ = [ + "ProcessRewardDataset", + "PromptDataset", + "PromptDatasetVL", + "SFTDataset", + 
"SFTDatasetVL", + "RFTDatasetVL", + "RapidataT2VPairHandler", + "RapidataI2VPairHandler", + "HPDv3PairHandler", + "OmniRewardBenchT2IPairHandler", + "OmniRewardBenchT2VPairHandler", + "VideoGenRewardBenchPairHandler", +] diff --git a/lightrft/datasets/grm_dataset.py b/lightrft/datasets/grm_dataset.py index f65b6ee..2e1bbb1 100644 --- a/lightrft/datasets/grm_dataset.py +++ b/lightrft/datasets/grm_dataset.py @@ -8,6 +8,7 @@ from .omnirewardbench import OmniRewardBenchT2IGRMHandler from .imagegen_cot_reward import ImageGenCoTRewardHandler from .hpdv3 import HPDv3GRMHandler +from .rejection_sampling_t2v import RejectionSamplingT2VHandler from .utils import zero_pad_sequences, load_multimodal_content, find_subsequence @@ -86,6 +87,7 @@ def __init__( "imagegen-cot-reward-5k": ImageGenCoTRewardHandler(), "omnirewardbench-t2i": OmniRewardBenchT2IGRMHandler(), "hpdv3": HPDv3GRMHandler(), + "rejection-sampling-t2v": RejectionSamplingT2VHandler(), } # Load data from all specified dataset paths diff --git a/lightrft/datasets/hpdv3.py b/lightrft/datasets/hpdv3.py index 9558a2a..0ab74e1 100644 --- a/lightrft/datasets/hpdv3.py +++ b/lightrft/datasets/hpdv3.py @@ -133,6 +133,10 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], return messages0, messages1, other +# Alias for RFTDatasetVL compatibility +HPDv3PairHandler = HPDv3Handler + + class HPDv3GRMHandler(HPDv3Handler): """ Data Handler for HPDv3 dataset with Generative Reward Model (GRM) training. diff --git a/lightrft/datasets/imagegen_cot_reward.py b/lightrft/datasets/imagegen_cot_reward.py index 1e699c3..1b9aa2b 100644 --- a/lightrft/datasets/imagegen_cot_reward.py +++ b/lightrft/datasets/imagegen_cot_reward.py @@ -31,23 +31,53 @@ def load_data(self, path: str) -> List[Dict[str, Any]]: def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]: """ - Extract path info for the two images. + Extract path info for the two images or videos. 
+ Supports both images and videos based on file extension or video_fps field. """ data_root = item['data_root'] if not data_root: - raise ValueError(f"Missing 'data_root' in item. Cannot resolve image paths.") - images = item['images'] - image0_full_path = os.path.join(data_root, images[0]) - image1_full_path = os.path.join(data_root, images[1]) - - return { - 'image0': { - 'image_local_path': image0_full_path - }, - 'image1': { - 'image_local_path': image1_full_path - }, - } + raise ValueError(f"Missing 'data_root' in item. Cannot resolve media paths.") + media_paths = item['images'] # Can contain image or video paths + path0 = media_paths[0] if isinstance(media_paths[0], str) else media_paths[0] + path1 = media_paths[1] if isinstance(media_paths[1], str) else media_paths[1] + + # Check if paths are absolute or relative + if os.path.isabs(path0): + media0_full_path = path0 + else: + media0_full_path = os.path.join(data_root, path0) + + if os.path.isabs(path1): + media1_full_path = path1 + else: + media1_full_path = os.path.join(data_root, path1) + + # Check if it's video based on file extension or video_fps field + video_extensions = {'.mp4', '.avi', '.mov', '.mkv', '.webm'} + is_video = ( + item.get('video_fps') is not None or + any(media0_full_path.lower().endswith(ext) for ext in video_extensions) or + any(media1_full_path.lower().endswith(ext) for ext in video_extensions) + ) + + if is_video: + return { + 'video0': { + 'video_local_path': media0_full_path + }, + 'video1': { + 'video_local_path': media1_full_path + }, + } + else: + return { + 'image0': { + 'image_local_path': media0_full_path + }, + 'image1': { + 'image_local_path': media1_full_path + }, + } def parse_item( self, @@ -55,41 +85,87 @@ def parse_item( media_content: Dict[str, Any], config: Dict[str, Any] | None, ) -> Tuple[List[Dict], List[Dict], Dict]: - - image0 = media_content['image0'] - image1 = media_content['image1'] - - if not all([image0, image1]): - raise ValueError(f"Missing visual 
content for 'image0' or 'image1'.") - - # Get conversations from data item - conversations = item["conversations"] - system_prompt = conversations[0]['value'] - response = conversations[-1]['value'] - - # Build messages - messages = [{ - "role": "system", - "content": system_prompt - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "**Image 1:**" + + # Check if it's video or image + is_video = 'video0' in media_content or 'video1' in media_content + + if is_video: + video0 = media_content.get('video0') + video1 = media_content.get('video1') + + if not all([video0, video1]): + raise ValueError(f"Missing visual content for 'video0' or 'video1'.") + + # Get FPS from config or item + fps = config.get("video_fps") if config else item.get("video_fps", 2.0) + + # Get conversations from data item + conversations = item["conversations"] + system_prompt = conversations[0]['value'] + response = conversations[-1]['value'] + + # Build messages for video + messages = [{ + "role": "system", + "content": system_prompt }, { - "type": "image", - "image": image0 + "role": "user", + "content": [{ + "type": "text", + "text": "**Video 1:**" + }, { + "type": "video", + "video": video0 if isinstance(video0, str) else video0.get('video_local_path'), + "fps": fps, + "max_pixels": 720 * 480 + }] + }, { + "role": "user", + "content": [{ + "type": "text", + "text": "**Video 2:**" + }, { + "type": "video", + "video": video1 if isinstance(video1, str) else video1.get('video_local_path'), + "fps": fps, + "max_pixels": 720 * 480 + }] }] - }, { - "role": "user", - "content": [{ - "type": "text", - "text": "**Image 2:**" + else: + image0 = media_content.get('image0') + image1 = media_content.get('image1') + + if not all([image0, image1]): + raise ValueError(f"Missing visual content for 'image0' or 'image1'.") + + # Get conversations from data item + conversations = item["conversations"] + system_prompt = conversations[0]['value'] + response = conversations[-1]['value'] + + # Build 
messages for image + messages = [{ + "role": "system", + "content": system_prompt + }, { + "role": "user", + "content": [{ + "type": "text", + "text": "**Image 1:**" + }, { + "type": "image", + "image": image0 + }] }, { - "type": "image", - "image": image1 + "role": "user", + "content": [{ + "type": "text", + "text": "**Image 2:**" + }, { + "type": "image", + "image": image1 + }] }] - }] # During evaluation, we do not include the response part in the messages is_training = config.get("is_training", True) @@ -97,7 +173,7 @@ def parse_item( messages.append({"role": "assistant", "content": response}) other = { - "source": item['source'], + "source": item.get('source', 'imagegen-cot-reward-5k'), "data_item": item, "system_prompt": system_prompt, "response": response, diff --git a/lightrft/datasets/omnirewardbench.py b/lightrft/datasets/omnirewardbench.py index b5ba9f5..6735b0a 100644 --- a/lightrft/datasets/omnirewardbench.py +++ b/lightrft/datasets/omnirewardbench.py @@ -371,3 +371,19 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "model2": item['model2'], } return messages, other + + +# Aliases for RFTDatasetVL compatibility +OmniRewardBenchT2IPairHandler = OmniRewardBenchT2IHandler +OmniRewardBenchT2VPairHandler = OmniRewardBenchT2VHandler + + +# Placeholder for VideoGenRewardBenchPairHandler +# TODO: Implement this handler when VideoGenRewardBench dataset is available +class VideoGenRewardBenchPairHandler(OmniRewardBenchT2VHandler): + """ + Placeholder handler for VideoGenRewardBench dataset. + Currently uses OmniRewardBenchT2VHandler as a fallback. + TODO: Implement proper handler when dataset is available. 
+ """ + pass diff --git a/lightrft/datasets/rapidata.py b/lightrft/datasets/rapidata.py index cc58546..daaa7c7 100644 --- a/lightrft/datasets/rapidata.py +++ b/lightrft/datasets/rapidata.py @@ -23,6 +23,8 @@ class RapidataT2VHandler(BaseDataHandler): Dataset Repo: https://huggingface.co/Rapidata/datasets """ + task_type = "text-to-video" + def load_data(self, path: str) -> List[Dict[str, Any]]: """ Loads data from parquet file. @@ -52,16 +54,23 @@ def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]: if 'file_name1' not in item or 'file_name2' not in item: raise ValueError(f"Item missing 'file_name1' or 'file_name2'.") - full_path1 = os.path.join(data_root, "videos", item['file_name1']) - full_path2 = os.path.join(data_root, "videos", item['file_name2']) + # Try both "videos" and "Videos" + video_dir = "videos" + if not os.path.exists(os.path.join(data_root, video_dir)): + if os.path.exists(os.path.join(data_root, "Videos")): + video_dir = "Videos" + + full_path1 = os.path.join(data_root, video_dir, item['file_name1']) + full_path2 = os.path.join(data_root, video_dir, item['file_name2']) return {'video1': {'video_local_path': full_path1}, 'video2': {'video_local_path': full_path2}} def _get_label(self, val1: float, val2: float) -> str: """ Helper to determine preference label based on two scores. 
- A > B, B > A, C == C """ + if val1 is None or val2 is None: + return "C" if val1 > val2: return "A" elif val1 < val2: @@ -85,6 +94,9 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], task_instruction_template = config["task_instruction"] task_instruction = task_instruction_template.format(prompt=video_gen_prompt) + # Get max_pixels from config (default to 720 * 480 if not provided) + max_pixels = config.get("max_pixels", 720 * 480) + # Get FPS from config fps = config["video_fps"] @@ -100,8 +112,8 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "type": "video", "video": video1, "fps": fps, - "max_pixels": 720 * 480 - } # 480p limit to reduce memory + "max_pixels": max_pixels + } ] } ] @@ -115,21 +127,25 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "type": "video", "video": video2, "fps": fps, - "max_pixels": 720 * 480 + "max_pixels": max_pixels }] }] - # Get human preference labels based on weighted scores - pref_label = self._get_label(item["weighted_results1_Preference"], item["weighted_results2_Preference"]) - cohe_label = self._get_label(item["weighted_results1_Coherence"], item["weighted_results2_Coherence"]) - align_label = self._get_label(item['weighted_results1_Alignment'], item['weighted_results2_Alignment']) + # Get human preference labels and total scores based on weighted metrics + metrics = ['Preference', 'Coherence', 'Alignment'] + labels = { + f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), item.get(f'weighted_results2_{m}')) + for m in metrics + } + + score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics) + score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics) other = { - "preference": pref_label, - "coherence": cohe_label, - "alignment": align_label, + "preference": self._get_label(score1, score2), + **labels, "source": item['source'], - "task_type": "t2v", + "task_type": self.task_type, } return 
messages0, messages1, other @@ -142,6 +158,8 @@ class RapidataI2VHandler(RapidataT2VHandler): Dataset Repo: https://huggingface.co/Rapidata/datasets """ + task_type = "image-to-video" + def __init__(self): super().__init__() @@ -204,6 +222,9 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], task_instruction_template = config["task_instruction"] task_instruction = task_instruction_template.format(prompt=prompt_text) + # Get max_pixels from config (default to 720 * 480 if not provided) + max_pixels = config.get("max_pixels", 720 * 480) + # Get FPS from config fps = config["video_fps"] @@ -216,12 +237,12 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "content": [{ "type": "image", "image": copy.deepcopy(init_image), - "max_pixels": 720 * 480 + "max_pixels": max_pixels }, { "type": "video", "video": video1, "fps": fps, - "max_pixels": 720 * 480 + "max_pixels": max_pixels }] }] @@ -233,25 +254,177 @@ def parse_item(self, item: Dict[str, Any], media_content: Dict[str, Any], "content": [{ "type": "image", "image": copy.deepcopy(init_image), - "max_pixels": 720 * 480 + "max_pixels": max_pixels }, { "type": "video", "video": video2, "fps": fps, - "max_pixels": 720 * 480 + "max_pixels": max_pixels }] }] - # Get human preference labels based on weighted scores - pref_label = self._get_label(item['weighted_results1_Preference'], item['weighted_results2_Preference']) - cohe_label = self._get_label(item['weighted_results1_Coherence'], item['weighted_results2_Coherence']) - align_label = self._get_label(item['weighted_results1_Alignment'], item['weighted_results2_Alignment']) + # Get human preference labels and total scores based on weighted metrics + metrics = ['Preference', 'Coherence', 'Alignment'] + labels = { + f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), item.get(f'weighted_results2_{m}')) + for m in metrics + } + + score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics) + 
score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics) other = { - "preference": pref_label, - "coherence": cohe_label, - "alignment": align_label, + "preference": self._get_label(score1, score2), + **labels, "source": item['source'], - "task_type": "t2v", # Text-to-Video + "task_type": self.task_type, } return messages0, messages1, other + + +class RapidataT2VPairHandler(RapidataT2VHandler): + """ + Data Handler for Rapidata text-to-video human preferences dataset in pairwise format. + """ + def __init__(self): + super().__init__() + + def parse_item(self, + item: Dict[str, Any], + media_content: Dict[str, Any], + config: Dict[str, Any] + ) -> Tuple[List[Dict], List[Dict], Dict]: + + video1 = media_content['video1'] + video2 = media_content['video2'] + + if not all([video1, video2]): + raise ValueError(f"Missing visual content for 'video1' or 'video2'.") + + # Get generation prompt from data item + video_gen_prompt = item["prompt"] + + # Get system prompts from config + task_instruction_template = config["task_instruction"] + task_instruction = task_instruction_template.format(prompt=video_gen_prompt) + + # Get max_pixels from config (default to 720 * 480 if not provided) + max_pixels = config.get("max_pixels", 720 * 480) + + # Get FPS from config + fps = config["video_fps"] + + # Build messages + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": task_instruction}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "The following is the first video."}, + {"type": "video", "video": video1, "fps": fps, "max_pixels": max_pixels}, + + {"type": "text", "text": "The following is the second video."}, + {"type": "video", "video": video2, "fps": fps, "max_pixels": max_pixels}, + ] + } + ] + + # Get human preference labels and total scores based on weighted metrics + metrics = ['Preference', 'Coherence', 'Alignment'] + labels = { + f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), 
item.get(f'weighted_results2_{m}')) + for m in metrics + } + + score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics) + score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics) + + other = { + "preference": self._get_label(score1, score2), + **labels, + "source": item['source'], + "task_type": self.task_type, + } + return messages, other + + +class RapidataI2VPairHandler(RapidataI2VHandler): + """ + Data Handler for Rapidata image-to-video human preferences dataset in pairwise format. + """ + task_type = "image-to-video" + + def __init__(self): + super().__init__() + + def parse_item(self, + item: Dict[str, Any], + media_content: Dict[str, Any], + config: Dict[str, Any] + ) -> Tuple[List[Dict], List[Dict], Dict]: + + video1 = media_content['video1'] + video2 = media_content['video2'] + init_image = media_content['init_image'] + + if not all([video1, video2, init_image]): + raise ValueError("Missing visual content for 'video1' or 'video2' or 'init_image'.") + + # Get generation prompt from data item + prompt_text = item["prompt"] + + # Get system prompts from config + task_instruction_template = config["task_instruction"] + task_instruction = task_instruction_template.format(prompt=prompt_text) + + # Get FPS from config + fps = config["video_fps"] + max_pixels = config.get("max_pixels", 720 * 480) + + # Build messages + messages = [ + { + "role": "system", + "content": [{"type": "text", "text": task_instruction}] + }, + { + "role": "user", + "content": [ + {"type": "text", "text": "Reference Image:"}, + {"type": "image", "image": init_image, "max_pixels": max_pixels}, + + {"type": "text", "text": "The following is the first video."}, + {"type": "video", "video": video1, "fps": fps, "max_pixels": max_pixels}, + + {"type": "text", "text": "The following is the second video."}, + {"type": "video", "video": video2, "fps": fps, "max_pixels": max_pixels}, + ] + } + ] + + # Get human preference labels and total scores based on weighted 
metrics
+        metrics = ['Preference', 'Coherence', 'Alignment']
+        labels = {
+            f"{m.lower()}_label": self._get_label(item.get(f'weighted_results1_{m}'), item.get(f'weighted_results2_{m}'))
+            for m in metrics
+        }
+
+        score1 = sum(item.get(f'weighted_results1_{m}') or 0.0 for m in metrics)
+        score2 = sum(item.get(f'weighted_results2_{m}') or 0.0 for m in metrics)
+
+        other = {
+            "preference": self._get_label(score1, score2),
+            **labels,
+            "source": item['source'],
+            "task_type": self.task_type,
+        }
+        return messages, other
+
+
+# NOTE: RapidataT2VPairHandler / RapidataI2VPairHandler are the classes
+# defined above. Do NOT re-add aliases rebinding them to the base handlers:
+# that shadowed the pairwise parse_item() and broke RFTDatasetVL consumers.
diff --git a/lightrft/datasets/rejection_sampling_t2v.py b/lightrft/datasets/rejection_sampling_t2v.py
new file mode 100644
index 0000000..1388def
--- /dev/null
+++ b/lightrft/datasets/rejection_sampling_t2v.py
@@ -0,0 +1,177 @@
+import os
+import copy
+import json
+from typing import List, Dict, Any, Tuple
+from loguru import logger
+
+from .utils import BaseDataHandler
+
+
+class RejectionSamplingT2VHandler(BaseDataHandler):
+    """
+    Data handler for Rejection Sampling text-to-video training data.
+    This handler processes video pairs for GRM training, similar to RapidataT2VHandler format.
+
+    The data format is similar to imagegen-cot-reward but specifically for videos.
+    Each item contains:
+    - conversations: [{"from": "human", "value": task_instruction}, {"from": "gpt", "value": response}]
+    - images: [video1_path, video2_path]  # Note: uses "images" field name for compatibility
+    - video_fps: float
+    """
+    task_type = "text-to-video"
+
+    def load_data(self, path: str) -> List[Dict[str, Any]]:
+        """
+        Loads data from json file.
+ """ + raw_data = [] + with open(path, 'rb') as f: + raw_data = json.load(f) + + data_root = os.path.dirname(path) + for item in raw_data: + item['data_root'] = data_root + + logger.info(f"Loaded {len(raw_data)} samples from {path}") + return raw_data + + def _resolve_video_path(self, path: str, data_root: str) -> str: + """ + Resolve video path, handling case-insensitive 'videos'/'Videos' directory. + Similar to RapidataT2VHandler format. + """ + # Check if path is absolute or relative + if os.path.isabs(path): + full_path = path + else: + full_path = os.path.join(data_root, path) + + # If file exists, return it directly + if os.path.exists(full_path): + return full_path + + # Try to handle videos/Videos case sensitivity issue + # Replace 'videos' with 'Videos' or vice versa in the path + if '/videos/' in full_path: + alt_path = full_path.replace('/videos/', '/Videos/') + if os.path.exists(alt_path): + return alt_path + elif '/Videos/' in full_path: + alt_path = full_path.replace('/Videos/', '/videos/') + if os.path.exists(alt_path): + return alt_path + + # Also handle case where path ends with '/videos' or '/Videos' + if full_path.endswith('/videos'): + alt_path = full_path[:-7] + '/Videos' + if os.path.exists(alt_path): + return alt_path + elif full_path.endswith('/Videos'): + alt_path = full_path[:-7] + '/videos' + if os.path.exists(alt_path): + return alt_path + + # If still not found, return original path (will raise error later) + return full_path + + def get_media_info(self, item: Dict[str, Any]) -> Dict[str, Dict[str, str]]: + """ + Extract path info for the two videos. + Similar to RapidataT2VHandler format. + Handles case-insensitive 'videos'/'Videos' directory names. + """ + data_root = item['data_root'] + if not data_root: + raise ValueError(f"Missing 'data_root' in item. 
Cannot resolve video paths.")
+
+        # Get video paths from "images" field (for compatibility with conversion script)
+        media_paths = item.get('images', [])
+        if len(media_paths) < 2:
+            raise ValueError(f"Item must contain at least 2 video paths in 'images' field.")
+
+        path1 = media_paths[0]
+        path2 = media_paths[1]
+
+        # Resolve paths with case-insensitive handling
+        video1_full_path = self._resolve_video_path(path1, data_root)
+        video2_full_path = self._resolve_video_path(path2, data_root)
+
+        return {
+            'video1': {
+                'video_local_path': video1_full_path
+            },
+            'video2': {
+                'video_local_path': video2_full_path
+            }
+        }
+
+    def parse_item(
+        self,
+        item: Dict[str, Any],
+        media_content: Dict[str, Any],
+        config: Dict[str, Any] | None,
+    ) -> Tuple[List[Dict], Dict]:
+        """
+        Parse item into messages format, similar to RapidataT2VHandler but for GRM training.
+        Returns messages in the format expected by GRMDataset.
+        """
+        video1 = media_content.get('video1')
+        video2 = media_content.get('video2')
+
+        if not all([video1, video2]):
+            raise ValueError(f"Missing visual content for 'video1' or 'video2'.")
+
+        # Get FPS from config, falling back to the item's own value. A config
+        # dict lacking 'video_fps' previously produced fps=None here.
+        fps = (config.get("video_fps") if config else None) or item.get("video_fps", 2.0)
+
+        # Get max_pixels from config (default to 720 * 480 if not provided)
+        max_pixels = config.get("max_pixels", 720 * 480) if config else 720 * 480
+
+        # Get conversations from data item
+        conversations = item["conversations"]
+        system_prompt = conversations[0]['value']
+        response = conversations[-1]['value']
+
+        # Build messages for video, similar to RapidataT2VHandler format
+        # But using the format from imagegen_cot_reward for GRM training
+        # Note: video1 and video2 are already loaded by load_multimodal_content
+        messages = [{
+            "role": "system",
+            "content": system_prompt
+        }, {
+            "role": "user",
+            "content": [{
+                "type": "text",
+                "text": "**Video 1:**"
+            }, {
+                "type": "video",
+                "video": video1,
+                "fps": fps,
+                "max_pixels": max_pixels
+            }]
+        }, {
+
"role": "user", + "content": [{ + "type": "text", + "text": "**Video 2:**" + }, { + "type": "video", + "video": video2, + "fps": fps, + "max_pixels": max_pixels + }] + }] + + # During evaluation, we do not include the response part in the messages + is_training = config.get("is_training", True) if config else True + if is_training: + messages.append({"role": "assistant", "content": response}) + + other = { + "source": item.get('source', 'rejection-sampling-t2v'), + "data_item": item, + "system_prompt": system_prompt, + "response": response, + "task_type": self.task_type, + } + return messages, other diff --git a/lightrft/datasets/rft_dataset.py b/lightrft/datasets/rft_dataset.py new file mode 100644 index 0000000..c2b8d60 --- /dev/null +++ b/lightrft/datasets/rft_dataset.py @@ -0,0 +1,171 @@ +import random +import copy + +from torch.utils.data import Dataset +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoTokenizer, AutoProcessor + +from loguru import logger +from typing import Any, Dict, List, Tuple, Union + +from .utils import load_multimodal_content +from lightrft.datasets import ( + RapidataI2VPairHandler, + RapidataT2VPairHandler, + VideoGenRewardBenchPairHandler, + HPDv3PairHandler, + ImageGenCoTRewardHandler, + OmniRewardBenchT2IPairHandler, + OmniRewardBenchT2VPairHandler, +) + + +class RFTDatasetVL(Dataset): + """ + Dataset for Reinforcement Fine-Tuning (RFT) with vision-language models. + + RFTDatasetVL supports multiple data sources through pluggable Data Handlers + and is designed for training models using reinforcement learning. + + It loads data items, processes multimodal content (images, videos), and + prepares inputs suitable for the model. + + :param dataset_paths: List of dataset file paths or directories. The + handler is determined by the source keyword (e.g. "rapidata-t2v", + "videogen-rewardbench"). The format is "source:path". + e.g. 
"rapidata-t2v:/path/to/file.parquet" + :type dataset_paths: List[str] + :param processor: Multimodal processor used for tokenization and visual + processing. + :type processor: transformers.AutoProcessor + :param tokenizer: Tokenizer used for text tokenization. + :type tokenizer: transformers.AutoTokenizer + :param strategy: Optional data loading strategy. + :type strategy: Any + :param max_length: Maximum sequence length for tokenization/truncation. + Defaults to 4096. + :type max_length: int + :param is_train: Whether the dataset is used for training or evaluation. + Defaults to True. + :type is_train: bool + :param config: Additional configuration options. + :type config: Dict[str, Any] + """ + def __init__( + self, + dataset_paths: List[str], + processor: AutoProcessor, + tokenizer: AutoTokenizer, + strategy = None, + max_length: int = 4096, + is_train: bool = True, + config: Dict[str, Any] = None, + ): + + super().__init__() + self.processor = processor + self.tokenizer = tokenizer + self.strategy = strategy + self.max_length = max_length + self.is_train = is_train + self.config = config if config else {} + + self.media_content_loader = load_multimodal_content + + if "qwen" in self.processor.__class__.__name__.lower(): + from qwen_vl_utils import process_vision_info + self.process_vision_info = process_vision_info + else: + raise NotImplementedError(f"Processor type {self.processor.__class__.__name__} not supported yet.") + + self.handlers = { + "rapidata-i2v": RapidataI2VPairHandler(), + "rapidata-t2v": RapidataT2VPairHandler(), + "videogen-rewardbench": VideoGenRewardBenchPairHandler(), + "hpdv3": HPDv3PairHandler(), + "imagegen-cot-reward-5k": ImageGenCoTRewardHandler(), + "omnirewardbench-t2i": OmniRewardBenchT2IPairHandler(), + "omnirewardbench-t2v": OmniRewardBenchT2VPairHandler(), + } + + # Load data from all specified dataset paths + # We expect dataset_paths to be in the format: "source:path" + # e.g. 
"rapidata-t2v:/path/to/file.parquet" + self.data = [] + for item in dataset_paths: + try: + source, path = item.split(":", 1) + except ValueError: + raise ValueError(f"Dataset path '{item}' is not in the expected format 'source:path'.") + + if source not in self.handlers: + raise NotImplementedError(f"The data handler for source {source} is not implemented.") + + handler = self.handlers[source] + try: + loaded_items = handler.load_data(path) + for item in loaded_items: + item["source"] = source + self.data.extend(loaded_items) + except Exception as e: + logger.error(f"Failed to load data {path} (source: {source}): {e}") + + logger.info(f"Loaded {len(self.data)} items in total, sources: {[s for s in dataset_paths]}") + random.shuffle(self.data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + item = self.data[idx] + source = item["source"] + + handler = self.handlers[source] + + # Get paths for all media content + media_info = handler.get_media_info(item) + + # Load all media content at once + loaded_content = self.media_content_loader(media_info) + if loaded_content is None: + raise RuntimeError(f"Failed to load media content: {media_info}") + + # Select prompt based on task type or source if task_instruction is a dict + config = copy.deepcopy(self.config) + task_instruction = config.get("task_instruction") + if isinstance(task_instruction, dict): + if hasattr(handler, "task_type"): + prompt = task_instruction.get(handler.task_type) + if prompt is None: + raise ValueError(f"Task instruction for {handler.task_type} not found.") + else: + raise ValueError(f"Handler for source {source} does not specify a task_type.") + + config["task_instruction"] = prompt + + # Pass the loaded content dict to parse_item + messages, reference = handler.parse_item(item, loaded_content, config) + + # Prepare inputs from message sequences + input_text, image_inputs, video_inputs = self._prepare_inputs(messages) + + # Configure label, by default "general" + 
label = "general" + + return input_text, image_inputs, video_inputs, reference, label + + def _prepare_inputs(self, messages): + if not self.is_train: + # For evaluation, we only need the input text without generation prompt + # Remove messages with role "assistant" + messages = [msg for msg in messages if msg["role"] != "assistant"] + + input_text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + image_inputs, video_inputs = self.process_vision_info(messages, return_video_kwargs=False) + + return input_text, image_inputs, video_inputs + + def collate_fn(self, batch): + input_texts, image_inputs_list, video_inputs_list, references, labels = zip(*batch) + return list(input_texts), list(image_inputs_list), list(video_inputs_list), list(references), list(labels) \ No newline at end of file diff --git a/lightrft/models/grm_vl.py b/lightrft/models/grm_vl.py index 7a2d311..983c33e 100644 --- a/lightrft/models/grm_vl.py +++ b/lightrft/models/grm_vl.py @@ -65,7 +65,7 @@ def __init__( pretrain_or_model, trust_remote_code=True, attn_implementation=attn_implementation, - dtype=torch.bfloat16 if bf16 else "auto", + torch_dtype=torch.bfloat16 if bf16 else "auto", device_map=device_map, ) @@ -217,4 +217,4 @@ def print_trainable_parameters(self): actor.print_trainable_parameters() # Output: trainable params: 4,194,304 || all params: 7,241,732,096 || trainable%: 0.058 """ - self.model.print_trainable_parameters() + self.model.print_trainable_parameters() \ No newline at end of file diff --git a/lightrft/strategy/strategy_base.py b/lightrft/strategy/strategy_base.py index ab62265..1d3c248 100644 --- a/lightrft/strategy/strategy_base.py +++ b/lightrft/strategy/strategy_base.py @@ -37,7 +37,7 @@ set_sequence_parallel_group, ) from lightrft.strategy.utils.statistic import GenLenAnalyser -from .sglang_utils import get_sglang_engine_for_rollout +# from .sglang_utils import get_sglang_engine_for_rollout from .vllm_utils import 
get_vllm_engine_for_rollout from lightrft.strategy.config import StrategyConfig