diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py index 20abc35..7c27a66 100644 --- a/examples/gsm8k_geo3k/train_colocate.py +++ b/examples/gsm8k_geo3k/train_colocate.py @@ -45,15 +45,16 @@ ensure_video_input_available() -from lightrft.datasets import PromptDatasetVL, SFTDatasetVL +from lightrft.datasets import DatasetConfig, DatasetLoader from lightrft.models.actor_language import ActorLanguage from lightrft.models.actor_vl import ActorVL from lightrft.strategy import get_strategy from lightrft.trainer.spmd_ppo_trainer import SPMDPPOTrainerVL -from lightrft.utils import blending_datasets, get_tokenizer_processor_vl +from lightrft.utils import get_tokenizer_processor_vl sys.path.append(os.path.dirname(os.path.abspath(__file__))) -from reward_models_utils import RECIPE, load_reward_models, reward_fn +from reward_models_utils import RECIPE +from lightrft.reward import RewardManager def train(args): @@ -155,16 +156,6 @@ def train(args): else: critic = None - # Load reward models (multiple types: value, safety, knowledge, etc.) - strategy.report_memory(f"before loaded reward models in main entry") - reward_models, reward_tokenizers, label_map = load_reward_models( - raw_reward_pretrain=args.reward_pretrain, - strategy=strategy, - use_engine=args.rm_use_engine, - ) - strategy.print(f"label_map: {label_map}") - strategy.report_memory(f"after loaded reward models in main entry") - strategy.print(actor) strategy.print(critic) @@ -203,70 +194,129 @@ def train(args): ) assert processor is not None, "processor is None" - # ==================== Data Loading Optimization ==================== - # The following sections now rely on the robust `blending_datasets` function. - # We add more logging for clarity. 
+ # Initialize reward manager (using rule-based rewards for gsm8k/geo3k) + strategy.report_memory(f"before loaded reward models in main entry") + + # For gsm8k/geo3k, we use rule-based rewards, so no neural models are needed + # Create a wrapper function for compatibility with trainer + def reward_fn( + model_reward_list, + labels, + queries, + refs, + label_map, + ): + """ + Wrapper function for RewardManager to match trainer's expected interface. + + For rule-based rewards, model_reward_list will be empty. + The reward manager will compute rewards based on labels and queries. + """ + # Determine rule type from labels (geo3k or gsm8k) + # Use the first label to determine rule type, or default to geo3k_combined + if labels: + first_label = labels[0] + if "gsm8k" in first_label.lower(): + rule_type = "gsm8k_combined" + elif "geo3k" in first_label.lower(): + rule_type = "geo3k_combined" + else: + rule_type = "geo3k_combined" # Default + else: + rule_type = "geo3k_combined" + + # Create a temporary reward manager with the correct rule type + # Note: We could optimize this by caching managers per rule type + # For rule-based rewards, tokenizer and strategy are not required + temp_reward_manager = RewardManager( + reward_type="rule", + rule_type=rule_type, + ) + + # Compute rewards using the reward manager + rewards, metrics = temp_reward_manager.compute( + queries=queries, + references=refs, + labels=labels, + ) + + return rewards, metrics + + label_map = {} # Empty for rule-based rewards + reward_models = [] # Empty for rule-based rewards + reward_tokenizers = [] # Empty for rule-based rewards + + strategy.print(f"Initialized rule-based reward manager for gsm8k/geo3k") + strategy.report_memory(f"after loaded reward models in main entry") + # ==================== Data Loading with New API ==================== + # Use DatasetConfig and DatasetLoader for unified dataset loading + + # Initialize dataset loader + dataset_loader = DatasetLoader( + tokenizer=tokenizer, + 
processor=processor, + strategy=strategy, + ) + # Prepare prompts dataset - strategy.print(f"Loading prompts dataset from: {args.prompt_data} with split: {args.prompt_split}") - prompts_data = blending_datasets( - args.prompt_data, - args.prompt_data_probs, - strategy, - args.seed, - return_eval=False, - train_split=args.prompt_split, + train_config = DatasetConfig.for_train( + data_path=args.prompt_data, + data_probs=args.prompt_data_probs, + split=args.prompt_split, + max_samples=args.max_samples, + seed=args.seed, + ) + prompts_dataset = dataset_loader.load_train_dataset( + config=train_config, + prompt_max_len=args.prompt_max_len, + input_template=args.input_template, ) - prompts_data = prompts_data.select(range(min(args.max_samples, len(prompts_data)))) - prompts_dataset = PromptDatasetVL(prompts_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template) - strategy.print(f"Loaded {len(prompts_dataset)} samples for prompts.") - # Prepare evaluation dataset eval_dataloader = None if args.eval_data or args.eval_split: eval_data_path = args.eval_data if args.eval_data else args.prompt_data if eval_data_path: - strategy.print(f"Loading evaluation dataset from {eval_data_path}, split='{args.eval_split}'") - eval_data = blending_datasets( - eval_data_path, "1.0", strategy, args.seed, return_eval=False, - # Note: `train_split` parameter is used to specify the desired split name for evaluation data. - train_split=args.eval_split, + eval_config = DatasetConfig.for_eval( + data_path=eval_data_path, + data_probs="1.0", + split=args.eval_split, + max_samples=args.max_eval_samples, + seed=args.seed, ) - if len(eval_data) == 0: - strategy.print(f"Warning: Evaluation dataset at {eval_data_path} with split '{args.eval_split}' is empty. 
Skipping evaluation.") - else: - eval_data = eval_data.select(range(min(args.max_eval_samples, len(eval_data)))) - - eval_dataset = PromptDatasetVL(eval_data, tokenizer, processor, args.prompt_max_len, strategy, input_template=args.input_template) + eval_dataset = dataset_loader.load_eval_dataset( + config=eval_config, + prompt_max_len=args.prompt_max_len, + input_template=args.input_template, + ) + if eval_dataset is not None: eval_dataloader = strategy.setup_dataloader( eval_dataset, args.rollout_batch_size // strategy.world_size, False, False, collate_fn=eval_dataset.collate_fn ) - strategy.print(f"Evaluation dataset loaded: {len(eval_dataset)} samples") else: strategy.print("Warning: eval_split specified but no data path available for evaluation.") # Prepare pretrain dataset pretrain_dataloader = None if args.pretrain_data: - strategy.print(f"Loading pretrain dataset from: {args.pretrain_data} with split: {args.pretrain_split}") - pretrain_data = blending_datasets( - args.pretrain_data, args.pretrain_data_probs, strategy, args.seed, - return_eval=False, train_split=args.pretrain_split, + pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len + # Calculate total samples needed for pretraining + total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt + + pretrain_config = DatasetConfig.for_pretrain( + data_path=args.pretrain_data, + data_probs=args.pretrain_data_probs, + split=args.pretrain_split, + max_samples=total_pretrain_samples, + seed=args.seed, ) - if len(pretrain_data) == 0: - strategy.print(f"Warning: Pretrain dataset at {args.pretrain_data} is empty. 
PTX loss will not be applied.") - pretrain_dataloader = None - else: - pretrain_max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len - # Calculate total samples needed for pretraining - total_pretrain_samples = args.max_epochs * len(prompts_dataset) * args.n_samples_per_prompt - pretrain_data_subset = pretrain_data.select(range(min(len(pretrain_data), total_pretrain_samples))) - - pretrain_dataset = SFTDatasetVL( - pretrain_data_subset, tokenizer, pretrain_max_len, strategy, pretrain_mode=True, - ) - strategy.print(f"Loaded {len(pretrain_dataset)} samples for pretraining.") + pretrain_dataset = dataset_loader.load_pretrain_dataset( + config=pretrain_config, + pretrain_max_len=pretrain_max_len, + ) + + if pretrain_dataset is not None: pretrain_dataloader = itertools.cycle( iter( strategy.setup_dataloader( @@ -282,21 +332,6 @@ def train(args): prompts_dataset, args.rollout_batch_size // strategy.world_size, True, True, collate_fn=prompts_dataset.collate_fn ) - if args.pretrain_data: - pretrain_dataloader = itertools.cycle( - iter( - strategy.setup_dataloader( - pretrain_dataset, - args.micro_train_batch_size, - True, - True, - pretrain_dataset.collate_fn, - ) - ) - ) - else: - pretrain_dataloader = None - # for scheduler num_update_steps_per_episodes = ( len(prompts_dataset) * args.n_samples_per_prompt // args.train_batch_size * args.max_epochs diff --git a/lightrft/datasets/__init__.py b/lightrft/datasets/__init__.py index 009ea04..50b8517 100644 --- a/lightrft/datasets/__init__.py +++ b/lightrft/datasets/__init__.py @@ -1,3 +1,32 @@ +""" +Dataset Module for LightRFT + +This module provides unified interfaces for loading datasets for training, +evaluation, and pretraining in RLHF workflows. 
+ +Main Features: + - Unified dataset configuration via DatasetConfig + - Consistent loading interface via DatasetLoader + - Support for train, eval, and pretrain datasets + - Automatic handling of blending_datasets parameters + +Classes: + DatasetConfig: Configuration class for dataset loading + DatasetLoader: Unified loader for all dataset types +""" + +# Import new unified interfaces first +from .config import DatasetConfig +from .loader import DatasetLoader + +# Import existing dataset classes +from .process_reward_dataset import ProcessRewardDataset +from .prompts_dataset import PromptDataset +from .prompts_dataset_vl import PromptDatasetVL +from .sft_dataset import SFTDataset +from .sft_dataset_vl import SFTDatasetVL + +# Import other dataset classes from .grm_dataset import GRMDataset from .srm_dataset import RankDatasetVL, RankDatasetAL from .omnirewardbench import * @@ -8,6 +37,7 @@ from .videodpo import * from .videogen_rewardbench import * from .genai_bench import * +from .rft_dataset import RFTDatasetVL from .utils import ( extract_answer, zero_pad_sequences, @@ -15,11 +45,14 @@ load_multimodal_content, BaseDataHandler, ) -from .process_reward_dataset import ProcessRewardDataset -from .prompts_dataset import PromptDataset -from .prompts_dataset_vl import PromptDatasetVL -from .sft_dataset import SFTDataset -from .sft_dataset_vl import SFTDatasetVL -from .rft_dataset import RFTDatasetVL -__all__ = ["ProcessRewardDataset", "PromptDataset", "PromptDatasetVL", "SFTDataset", "SFTDatasetVL", "RFTDatasetVL"] +__all__ = [ + "DatasetConfig", + "DatasetLoader", + "ProcessRewardDataset", + "PromptDataset", + "PromptDatasetVL", + "SFTDataset", + "SFTDatasetVL", + "RFTDatasetVL", +] diff --git a/lightrft/datasets/config.py b/lightrft/datasets/config.py new file mode 100644 index 0000000..f5f30b6 --- /dev/null +++ b/lightrft/datasets/config.py @@ -0,0 +1,204 @@ +""" +Dataset Configuration + +This module provides configuration classes for dataset loading, +unifying 
parameters for train, eval, and pretrain datasets. + +Main Features: + - Unified configuration for all dataset types + - Automatic normalization of data_path and data_probs + - Factory methods for train/eval/pretrain configurations + - Validation of configuration parameters + +Classes: + DatasetConfig: Dataclass for dataset configuration + +Author: lightrft Team +""" + +from dataclasses import dataclass +from typing import Optional, Union + + +@dataclass +class DatasetConfig: + """ + Configuration for dataset loading. + + This class unifies parameters for train, eval, and pretrain datasets, + providing a consistent interface for dataset configuration. + + :param data_path: Path(s) to dataset(s), can be string or list + :type data_path: Optional[Union[str, list]] + :param data_probs: Sampling probabilities for datasets. Default to "1.0" + :type data_probs: Optional[Union[str, list]] + :param split: Dataset split to use. Default to "train" + :type split: str + :param max_samples: Maximum number of samples to load + :type max_samples: Optional[int] + :param max_len: Maximum sequence length + :type max_len: Optional[int] + :param seed: Random seed. Default to 42 + :type seed: int + :param return_eval: Whether to return evaluation data. Default to False + :type return_eval: bool + """ + + # Data source + data_path: Optional[Union[str, list]] = None + data_probs: Optional[Union[str, list]] = "1.0" + + # Split configuration + split: str = "train" + + # Data filtering + max_samples: Optional[int] = None + max_len: Optional[int] = None + + # Additional parameters + seed: int = 42 + return_eval: bool = False + + def __post_init__(self): + """ + Validate configuration after initialization. 
+ + :raises ValueError: If data_path is None or if data_path and data_probs have mismatched lengths + """ + if self.data_path is None: + raise ValueError("data_path must be specified") + + # Normalize data_probs + if isinstance(self.data_probs, str): + # Parse comma-separated string + self.data_probs = [float(p.strip()) for p in self.data_probs.split(",")] + elif isinstance(self.data_probs, (int, float)): + self.data_probs = [float(self.data_probs)] + + # Normalize data_path + if isinstance(self.data_path, str): + self.data_path = [self.data_path] + + # Ensure data_path and data_probs have same length + if len(self.data_probs) == 1 and len(self.data_path) > 1: + # Repeat single prob for all paths + self.data_probs = self.data_probs * len(self.data_path) + elif len(self.data_probs) != len(self.data_path): + raise ValueError( + f"data_path and data_probs must have the same length, " + f"got {len(self.data_path)} and {len(self.data_probs)}" + ) + + @classmethod + def for_train( + cls, + data_path: Optional[Union[str, list]] = None, + data_probs: Optional[Union[str, list]] = "1.0", + split: str = "train", + max_samples: Optional[int] = None, + max_len: Optional[int] = None, + seed: int = 42, + ) -> "DatasetConfig": + """ + Create configuration for training dataset. 
+ + :param data_path: Path(s) to dataset(s) + :type data_path: Optional[Union[str, list]] + :param data_probs: Sampling probabilities for datasets + :type data_probs: Optional[Union[str, list]] + :param split: Dataset split to use + :type split: str + :param max_samples: Maximum number of samples to load + :type max_samples: Optional[int] + :param max_len: Maximum sequence length + :type max_len: Optional[int] + :param seed: Random seed + :type seed: int + :return: DatasetConfig instance for training + :rtype: DatasetConfig + """ + return cls( + data_path=data_path, + data_probs=data_probs, + split=split, + max_samples=max_samples, + max_len=max_len, + seed=seed, + return_eval=False, + ) + + @classmethod + def for_eval( + cls, + data_path: Optional[Union[str, list]] = None, + data_probs: Optional[Union[str, list]] = "1.0", + split: str = "test", + max_samples: Optional[int] = None, + max_len: Optional[int] = None, + seed: int = 42, + ) -> "DatasetConfig": + """ + Create configuration for evaluation dataset. 
+ + :param data_path: Path(s) to dataset(s) + :type data_path: Optional[Union[str, list]] + :param data_probs: Sampling probabilities for datasets + :type data_probs: Optional[Union[str, list]] + :param split: Dataset split to use + :type split: str + :param max_samples: Maximum number of samples to load + :type max_samples: Optional[int] + :param max_len: Maximum sequence length + :type max_len: Optional[int] + :param seed: Random seed + :type seed: int + :return: DatasetConfig instance for evaluation + :rtype: DatasetConfig + """ + return cls( + data_path=data_path, + data_probs=data_probs, + split=split, + max_samples=max_samples, + max_len=max_len, + seed=seed, + return_eval=False, + ) + + @classmethod + def for_pretrain( + cls, + data_path: Optional[Union[str, list]] = None, + data_probs: Optional[Union[str, list]] = "1.0", + split: str = "train", + max_samples: Optional[int] = None, + max_len: Optional[int] = None, + seed: int = 42, + ) -> "DatasetConfig": + """ + Create configuration for pretraining dataset. 
+ + :param data_path: Path(s) to dataset(s) + :type data_path: Optional[Union[str, list]] + :param data_probs: Sampling probabilities for datasets + :type data_probs: Optional[Union[str, list]] + :param split: Dataset split to use + :type split: str + :param max_samples: Maximum number of samples to load + :type max_samples: Optional[int] + :param max_len: Maximum sequence length + :type max_len: Optional[int] + :param seed: Random seed + :type seed: int + :return: DatasetConfig instance for pretraining + :rtype: DatasetConfig + """ + return cls( + data_path=data_path, + data_probs=data_probs, + split=split, + max_samples=max_samples, + max_len=max_len, + seed=seed, + return_eval=False, + ) diff --git a/lightrft/datasets/loader.py b/lightrft/datasets/loader.py new file mode 100644 index 0000000..8c58dd8 --- /dev/null +++ b/lightrft/datasets/loader.py @@ -0,0 +1,237 @@ +""" +Dataset Loader + +This module provides a unified interface for loading datasets for training, +evaluation, and pretraining, abstracting away the differences between +different dataset types and splits. + +Main Features: + - Unified loading interface for train, eval, and pretrain datasets + - Automatic handling of blending_datasets parameters + - Support for PromptDatasetVL and SFTDatasetVL + - Consistent logging via strategy + +Classes: + DatasetLoader: Unified loader for all dataset types + +Author: lightrft Team +""" + +from typing import Optional, Any + +from lightrft.utils import blending_datasets +from .prompts_dataset_vl import PromptDatasetVL +from .sft_dataset_vl import SFTDatasetVL +from .config import DatasetConfig + + +class DatasetLoader: + """ + Unified dataset loader for train, eval, and pretrain datasets. + + This class provides a consistent interface for loading datasets, + handling the differences between prompt datasets (for training/eval) + and SFT datasets (for pretraining). 
+ + :param tokenizer: Tokenizer for tokenizing text + :type tokenizer: Any + :param processor: Processor for multimodal data (optional) + :type processor: Optional[Any] + :param strategy: Training strategy (optional, for logging) + :type strategy: Optional[Any] + """ + def __init__( + self, + tokenizer: Any, + processor: Optional[Any] = None, + strategy: Optional[Any] = None, + ): + """ + Initialize dataset loader. + + :param tokenizer: Tokenizer for tokenizing text + :type tokenizer: Any + :param processor: Processor for multimodal data (optional) + :type processor: Optional[Any] + :param strategy: Training strategy (optional, for logging) + :type strategy: Optional[Any] + """ + self.tokenizer = tokenizer + self.processor = processor + self.strategy = strategy + + def _log(self, message: str): + """ + Log message if strategy is available. + + :param message: Message to log + :type message: str + """ + if self.strategy: + self.strategy.print(message) + else: + print(message) + + def load_train_dataset( + self, + config: DatasetConfig, + prompt_max_len: int, + input_template: Optional[str] = None, + ) -> PromptDatasetVL: + """ + Load training dataset. 
+ + :param config: Dataset configuration + :type config: DatasetConfig + :param prompt_max_len: Maximum prompt length + :type prompt_max_len: int + :param input_template: Input template for formatting prompts + :type input_template: Optional[str] + :return: PromptDatasetVL instance for training + :rtype: PromptDatasetVL + """ + # Convert data_path list to comma-separated string for blending_datasets + data_path_str = config.data_path if isinstance(config.data_path, str) else ",".join(config.data_path) + self._log(f"Loading training dataset from: {data_path_str} with split: {config.split}") + + # Load and blend datasets + data = blending_datasets( + data_path_str, + ",".join(map(str, config.data_probs)), + self.strategy, + config.seed, + return_eval=config.return_eval, + train_split=config.split, + ) + + # Limit samples if specified + if config.max_samples is not None: + data = data.select(range(min(config.max_samples, len(data)))) + + self._log(f"Loaded {len(data)} samples for training.") + + # Create dataset + dataset = PromptDatasetVL( + data, + self.tokenizer, + self.processor, + prompt_max_len, + self.strategy, + input_template=input_template, + ) + + return dataset + + def load_eval_dataset( + self, + config: DatasetConfig, + prompt_max_len: int, + input_template: Optional[str] = None, + ) -> Optional[PromptDatasetVL]: + """ + Load evaluation dataset. 
+ + :param config: Dataset configuration + :type config: DatasetConfig + :param prompt_max_len: Maximum prompt length + :type prompt_max_len: int + :param input_template: Input template for formatting prompts + :type input_template: Optional[str] + :return: PromptDatasetVL instance for evaluation, or None if no data + :rtype: Optional[PromptDatasetVL] + """ + if config.data_path is None: + return None + + # Convert data_path list to comma-separated string for blending_datasets + data_path_str = config.data_path if isinstance(config.data_path, str) else ",".join(config.data_path) + self._log(f"Loading evaluation dataset from {data_path_str}, split='{config.split}'") + + # Load and blend datasets + data = blending_datasets( + data_path_str, + ",".join(map(str, config.data_probs)), + self.strategy, + config.seed, + return_eval=config.return_eval, + train_split=config.split, + ) + + if len(data) == 0: + self._log( + f"Warning: Evaluation dataset at {data_path_str} with split '{config.split}' " + "is empty. Skipping evaluation." + ) + return None + + # Limit samples if specified + if config.max_samples is not None: + data = data.select(range(min(config.max_samples, len(data)))) + + self._log(f"Evaluation dataset loaded: {len(data)} samples") + + # Create dataset + dataset = PromptDatasetVL( + data, + self.tokenizer, + self.processor, + prompt_max_len, + self.strategy, + input_template=input_template, + ) + + return dataset + + def load_pretrain_dataset( + self, + config: DatasetConfig, + pretrain_max_len: int, + ) -> Optional[SFTDatasetVL]: + """ + Load pretraining dataset. 
+ + :param config: Dataset configuration + :type config: DatasetConfig + :param pretrain_max_len: Maximum sequence length for pretraining + :type pretrain_max_len: int + :return: SFTDatasetVL instance for pretraining, or None if no data + :rtype: Optional[SFTDatasetVL] + """ + if config.data_path is None: + return None + + # Convert data_path list to comma-separated string for blending_datasets + data_path_str = config.data_path if isinstance(config.data_path, str) else ",".join(config.data_path) + self._log(f"Loading pretrain dataset from: {data_path_str} with split: {config.split}") + + # Load and blend datasets + data = blending_datasets( + data_path_str, + ",".join(map(str, config.data_probs)), + self.strategy, + config.seed, + return_eval=config.return_eval, + train_split=config.split, + ) + + if len(data) == 0: + self._log(f"Warning: Pretrain dataset at {data_path_str} is empty. " + "PTX loss will not be applied.") + return None + + # Limit samples if specified + if config.max_samples is not None: + data = data.select(range(min(config.max_samples, len(data)))) + + self._log(f"Loaded {len(data)} samples for pretraining.") + + # Create dataset + dataset = SFTDatasetVL( + data, + self.tokenizer, + pretrain_max_len, + self.strategy, + pretrain_mode=True, + ) + + return dataset diff --git a/lightrft/reward/__init__.py b/lightrft/reward/__init__.py new file mode 100644 index 0000000..efcfebe --- /dev/null +++ b/lightrft/reward/__init__.py @@ -0,0 +1,34 @@ +""" +Reward Module for LightRLHF + +This module provides unified interfaces for different types of rewards in RLHF training. +It supports rule-based rewards, single reward models, and multiple reward model ensembles. 
 + +Main Features: + - Unified reward interface with consistent compute() method signature + - Rule-based reward functions (format checking, accuracy verification) + - Single and multiple reward model support + - Automatic reward type selection via RewardManager + +Classes: + BaseReward: Abstract base class for all reward types + RuleReward: Rule-based reward implementation + SingleRewardModel: Single reward model wrapper + MultiRewardModel: Multiple reward model ensemble + RewardManager: Unified manager for automatic reward type selection + +Author: lightrft Team +""" + +from .base import BaseReward +from .rule import RuleReward +from .model import SingleRewardModel, MultiRewardModel +from .manager import RewardManager + +__all__ = [ + "BaseReward", + "RuleReward", + "SingleRewardModel", + "MultiRewardModel", + "RewardManager", +] diff --git a/lightrft/reward/base.py b/lightrft/reward/base.py new file mode 100644 index 0000000..8a939fc --- /dev/null +++ b/lightrft/reward/base.py @@ -0,0 +1,82 @@ +""" +Base Reward Interface + +This module defines the abstract base class for all reward types in LightRFT, +ensuring a consistent interface across rule-based rewards, single reward +models, and multiple reward model ensembles. + +Main Features: + - Unified compute() method signature for all reward types + - Consistent return format: (rewards, metrics) + - Support for queries, references, and labels + +Classes: + BaseReward: Abstract base class for all reward implementations + +Author: lightrft Team +""" + +from abc import ABC, abstractmethod +from typing import Dict, Tuple, Sequence, Optional + +import torch + + +class BaseReward(ABC): + """ + Abstract base class for all reward types. + + This class defines the unified interface that all reward implementations + must follow, ensuring consistency across rule-based rewards, single reward + models, and multiple reward models. 
+ + All reward implementations should return: + - rewards: torch.Tensor of shape (batch_size,) containing reward values + - metrics: Dict[str, torch.Tensor] containing detailed reward metrics + """ + @abstractmethod + def compute( + self, + queries: Sequence[str], + references: Optional[Sequence[str]] = None, + labels: Optional[Sequence[str]] = None, + **kwargs + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Compute rewards for given queries. + + :param queries: List of query/solution strings (length B) + :type queries: Sequence[str] + :param references: List of reference answers (length B), optional + :type references: Optional[Sequence[str]] + :param labels: List of data labels indicating reward type (length B), optional + :type labels: Optional[Sequence[str]] + :param kwargs: Additional arguments for specific reward types + :return: Tuple of (rewards, metrics) where rewards is torch.Tensor of shape (B,) + and metrics is Dict[str, torch.Tensor] with keys like 'format_reward', + 'accuracy_reward', 'model_reward', 'rule_reward' + :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]] + """ + pass + + def __call__( + self, + queries: Sequence[str], + references: Optional[Sequence[str]] = None, + labels: Optional[Sequence[str]] = None, + **kwargs + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + """ + Make the reward object callable. 
+ + :param queries: List of query/solution strings + :type queries: Sequence[str] + :param references: List of reference answers, optional + :type references: Optional[Sequence[str]] + :param labels: List of data labels, optional + :type labels: Optional[Sequence[str]] + :param kwargs: Additional arguments + :return: Tuple of (rewards, metrics) + :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]] + """ + return self.compute(queries, references, labels, **kwargs) diff --git a/lightrft/reward/manager.py b/lightrft/reward/manager.py new file mode 100644 index 0000000..a0f2c23 --- /dev/null +++ b/lightrft/reward/manager.py @@ -0,0 +1,231 @@ +""" +Reward Manager + +This module provides a unified manager for different reward types, +automatically selecting and using the appropriate reward implementation +based on configuration. + +Main Features: + - Automatic reward type selection (rule, single, multi) + - Unified compute_rewards() interface + - Configuration-based initialization + - Support for factory pattern via from_config() + +Classes: + RewardManager: Unified manager for all reward types + +Author: lightrft Team +""" + +from typing import Dict, Sequence, Optional, List, Tuple, Any, Union + +import torch + +from .base import BaseReward +from .rule import RuleReward +from .model import SingleRewardModel, MultiRewardModel + + +class RewardManager(BaseReward): + """ + Unified reward manager that automatically selects the appropriate + reward implementation based on configuration. 
+ + Supports: + - Rule-based rewards (pure rule functions) + - Single reward model + - Multiple reward models (with recipe-based aggregation) + + :param reward_type: Type of reward to use ("rule", "single", "multi") + :type reward_type: str + :param reward_model: Single reward model or list of reward models + :type reward_model: Optional[Union[Any, List[Any]]] + :param reward_tokenizers: List of tokenizers for reward models + :type reward_tokenizers: Optional[List[Any]] + :param reward_fn: Aggregation function for multiple reward models + :type reward_fn: Optional[Any] + :param reward_fn_label_map: Mapping from reward type to model index + :type reward_fn_label_map: Optional[Dict[str, int]] + :param reward_recipe: Recipe configuration for combining rewards + :type reward_recipe: Optional[Dict[str, List[Tuple[str, Optional[str], float]]]] + :param rule_type: Type of rule reward (e.g., "geo3k_combined", "gsm8k_combined") + :type rule_type: Optional[str] + :param tokenizer: Tokenizer for decoding sequences + :type tokenizer: Optional[Any] + :param strategy: Training strategy (for model loading/offloading) + :type strategy: Optional[Any] + :param packing_samples: Whether samples are packed. Default to False + :type packing_samples: bool + :param device: Device to place reward tensors on + :type device: Optional[torch.device] + """ + def __init__( + self, + reward_type: str = "multi", # "rule", "single", "multi" + reward_model: Optional[Union[Any, List[Any]]] = None, + reward_tokenizers: Optional[List[Any]] = None, + reward_fn: Optional[Any] = None, + reward_fn_label_map: Optional[Dict[str, int]] = None, + reward_recipe: Optional[Dict[str, List[Tuple[str, Optional[str], float]]]] = None, + rule_type: Optional[str] = None, + tokenizer: Optional[Any] = None, + strategy: Optional[Any] = None, + packing_samples: bool = False, + device: Optional[torch.device] = None, + ): + """ + Initialize reward manager. 
+
+        :param reward_type: Type of reward to use ("rule", "single", "multi")
+        :type reward_type: str
+        :param reward_model: Single reward model or list of reward models
+        :type reward_model: Optional[Union[Any, List[Any]]]
+        :param reward_tokenizers: List of tokenizers for reward models
+        :type reward_tokenizers: Optional[List[Any]]
+        :param reward_fn: Aggregation function for multiple reward models
+        :type reward_fn: Optional[Any]
+        :param reward_fn_label_map: Mapping from reward type to model index
+        :type reward_fn_label_map: Optional[Dict[str, int]]
+        :param reward_recipe: Recipe configuration for combining rewards
+        :type reward_recipe: Optional[Dict[str, List[Tuple[str, Optional[str], float]]]]
+        :param rule_type: Type of rule reward (e.g., "geo3k_combined", "gsm8k_combined")
+        :type rule_type: Optional[str]
+        :param tokenizer: Tokenizer for decoding sequences
+        :type tokenizer: Optional[Any]
+        :param strategy: Training strategy (for model loading/offloading)
+        :type strategy: Optional[Any]
+        :param packing_samples: Whether samples are packed
+        :type packing_samples: bool
+        :param device: Device to place reward tensors on
+        :type device: Optional[torch.device]
+        :raises ValueError: If required parameters are missing for the specified reward_type
+        """
+        super().__init__()
+        self.reward_type = reward_type
+        self.device = device or (torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu"))  # parens: conditional binds looser than `or`, so an explicit device must not be discarded on CPU-only hosts
+
+        # Initialize the appropriate reward implementation
+        if reward_type == "rule":
+            if rule_type is None:
+                raise ValueError("rule_type must be specified for rule-based rewards")
+            self.reward_impl = RuleReward(
+                rule_type=rule_type,
+                device=self.device,
+            )
+        elif reward_type == "single":
+            if reward_model is None:
+                raise ValueError("reward_model must be specified for single reward model")
+            if tokenizer is None:
+                raise ValueError("tokenizer must be specified for single reward model")
+            if strategy is None:
+                raise ValueError("strategy must be specified for single reward model")
+
+            # Ensure reward_model is a single model, not a list
+            if isinstance(reward_model, (list, tuple)):
+                if len(reward_model) != 1:
+                    raise ValueError("reward_model must be a single model for reward_type='single'")
+                reward_model = reward_model[0]
+
+            self.reward_impl = SingleRewardModel(
+                reward_model=reward_model,
+                tokenizer=tokenizer,
+                strategy=strategy,
+                packing_samples=packing_samples,
+                device=self.device,
+            )
+        elif reward_type == "multi":
+            if reward_model is None:
+                raise ValueError("reward_model must be specified for multiple reward models")
+            if reward_fn is None:
+                raise ValueError("reward_fn must be specified for multiple reward models")
+            if tokenizer is None:
+                raise ValueError("tokenizer must be specified for multiple reward models")
+            if strategy is None:
+                raise ValueError("strategy must be specified for multiple reward models")
+
+            # Ensure reward_model is a list
+            if not isinstance(reward_model, (list, tuple)):
+                reward_model = [reward_model]
+
+            self.reward_impl = MultiRewardModel(
+                reward_models=reward_model,
+                reward_tokenizers=reward_tokenizers or [],
+                reward_fn=reward_fn,
+                reward_fn_label_map=reward_fn_label_map or {},
+                reward_recipe=reward_recipe or {},
+                tokenizer=tokenizer,
+                strategy=strategy,
+                packing_samples=packing_samples,
+                device=self.device,
+            )
+        else:
+            raise ValueError(f"Unknown reward_type: {reward_type}. Must be 'rule', 'single', or 'multi'")
+
+    def compute(
+        self,
+        queries: Sequence[str],
+        references: Optional[Sequence[str]] = None,
+        labels: Optional[Sequence[str]] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Compute rewards using the configured reward implementation.
+ + :param queries: List of query/solution strings (length B) + :type queries: Sequence[str] + :param references: List of reference answers (length B), optional + :type references: Optional[Sequence[str]] + :param labels: List of data labels indicating reward type (length B), optional + :type labels: Optional[Sequence[str]] + :param kwargs: Additional arguments passed to the reward implementation + :return: Tuple of (rewards, metrics) where rewards is torch.Tensor of shape (B,) + and metrics contains detailed reward metrics + :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]] + """ + return self.reward_impl.compute(queries, references, labels, **kwargs) + + @classmethod + def from_config( + cls, + config: Dict[str, Any], + reward_models: Optional[Union[Any, List[Any]]] = None, + reward_tokenizers: Optional[List[Any]] = None, + tokenizer: Optional[Any] = None, + strategy: Optional[Any] = None, + ) -> "RewardManager": + """ + Create RewardManager from configuration dictionary. + + :param config: Configuration dictionary with keys: + - reward_type: "rule", "single", or "multi" + - rule_type: Type of rule (for rule-based rewards) + - reward_fn: Aggregation function (for multi rewards) + - reward_fn_label_map: Label map (for multi rewards) + - reward_recipe: Recipe config (for multi rewards) + - packing_samples: Whether samples are packed + - device: Device to use + :type config: Dict[str, Any] + :param reward_models: Reward model(s) to use + :type reward_models: Optional[Union[Any, List[Any]]] + :param reward_tokenizers: Tokenizers for reward models + :type reward_tokenizers: Optional[List[Any]] + :param tokenizer: Tokenizer for decoding sequences + :type tokenizer: Optional[Any] + :param strategy: Training strategy + :type strategy: Optional[Any] + :return: RewardManager instance + :rtype: RewardManager + """ + return cls( + reward_type=config.get("reward_type", "multi"), + reward_model=reward_models, + reward_tokenizers=reward_tokenizers, + 
+            reward_fn=config.get("reward_fn"),
+            reward_fn_label_map=config.get("reward_fn_label_map"),
+            reward_recipe=config.get("reward_recipe"),
+            rule_type=config.get("rule_type"),
+            tokenizer=tokenizer,
+            strategy=strategy,
+            packing_samples=config.get("packing_samples", False),
+            device=config.get("device"),
+        )
diff --git a/lightrft/reward/model.py b/lightrft/reward/model.py
new file mode 100644
index 0000000..7f85388
--- /dev/null
+++ b/lightrft/reward/model.py
@@ -0,0 +1,356 @@
+"""
+Reward Model Implementation
+
+This module provides implementations for single and multiple reward models,
+encapsulating the logic for computing rewards using neural models.
+
+Main Features:
+    - Single reward model wrapper with automatic loading/offloading
+    - Multiple reward model ensemble with recipe-based aggregation
+    - Support for both standard PyTorch models and custom engine models
+    - Consistent interface with BaseReward
+
+Classes:
+    SingleRewardModel: Wrapper for single reward model
+    MultiRewardModel: Ensemble of multiple reward models with aggregation
+"""
+
+from typing import Dict, Sequence, Optional, List, Tuple, Any
+
+import torch
+import torch.nn as nn
+
+from .base import BaseReward
+
+
+class SingleRewardModel(BaseReward):
+    """
+    Single reward model implementation.
+
+    This class encapsulates the logic for computing rewards using a single
+    neural reward model. It handles both standard PyTorch models and custom
+    engine models (e.g., SGLang).
+
+    :param reward_model: PyTorch reward model instance
+    :type reward_model: nn.Module
+    :param tokenizer: Tokenizer for decoding sequences
+    :type tokenizer: Any
+    :param strategy: Training strategy (for model loading/offloading)
+    :type strategy: Any
+    :param packing_samples: Whether samples are packed. Default to False
+    :type packing_samples: bool
+    :param device: Device to place reward tensors on
+    :type device: Optional[torch.device]
+    """
+    def __init__(
+        self,
+        reward_model: nn.Module,
+        tokenizer: Any,
+        strategy: Any,
+        packing_samples: bool = False,
+        device: Optional[torch.device] = None,
+    ):
+        """
+        Initialize single reward model.
+
+        :param reward_model: PyTorch reward model instance
+        :type reward_model: nn.Module
+        :param tokenizer: Tokenizer for decoding sequences
+        :type tokenizer: Any
+        :param strategy: Training strategy (for model loading/offloading)
+        :type strategy: Any
+        :param packing_samples: Whether samples are packed
+        :type packing_samples: bool
+        :param device: Device to place reward tensors on
+        :type device: Optional[torch.device]
+        """
+        super().__init__()
+        self.reward_model = reward_model
+        self.tokenizer = tokenizer
+        self.strategy = strategy
+        self.packing_samples = packing_samples
+        self.device = device or (torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu"))  # parens: conditional binds looser than `or`; without them a passed device is dropped on CPU-only hosts
+
+    def compute(
+        self,
+        queries: Sequence[str],
+        references: Optional[Sequence[str]] = None,
+        labels: Optional[Sequence[str]] = None,
+        sequences: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        prompt_and_output: Optional[Sequence[str]] = None,
+        raw_images: Optional[List] = None,
+        img_num: Optional[List[int]] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Compute rewards using a single reward model.
+
+        :param queries: List of query/solution strings (length B)
+        :type queries: Sequence[str]
+        :param references: List of reference answers (length B), optional
+        :type references: Optional[Sequence[str]]
+        :param labels: Not used for single RM, kept for interface consistency
+        :type labels: Optional[Sequence[str]]
+        :param sequences: Token ID sequences [B, seq_len], optional
+        :type sequences: Optional[torch.Tensor]
+        :param attention_mask: Attention mask for sequences, optional
+        :type attention_mask: Optional[torch.Tensor]
+        :param prompt_and_output: List of prompt+output strings, optional
+        :type prompt_and_output: Optional[Sequence[str]]
+        :param raw_images: List of PIL images, optional
+        :type raw_images: Optional[List]
+        :param img_num: List of image counts per sample, optional
+        :type img_num: Optional[List[int]]
+        :param kwargs: Additional arguments for reward model forward pass
+        :return: Tuple of (rewards, metrics) where rewards is torch.Tensor of shape (B,)
+                 and metrics contains 'model_reward'
+        :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]]
+        """
+        # Load model to GPU if needed
+        if isinstance(self.reward_model, torch.nn.Module):
+            self.strategy.reload_model(self.reward_model)
+
+        # Prepare inputs
+        if sequences is not None:
+            # Standard PyTorch model path
+            rm_output = self.reward_model(
+                sequences,
+                attention_mask,
+                prompt_and_output=prompt_and_output,
+                raw_images=raw_images,
+                img_num=img_num,
+                **kwargs
+            )
+        else:
+            # Custom engine model path
+            rm_output = self.reward_model(
+                None,
+                None,
+                prompt_and_outputs=prompt_and_output if prompt_and_output else queries,
+                raw_images=raw_images,
+                img_num=img_num,
+                references=references,
+                labels=labels,
+                **kwargs
+            )
+
+        # Extract scores
+        if isinstance(rm_output, dict):
+            scores = rm_output["score"]
+        else:
+            scores = rm_output
+
+        # Ensure tensor format
+        if not isinstance(scores, torch.Tensor):
+            scores = torch.as_tensor(scores, dtype=torch.float32, device=self.device)
+        else:
+            scores = scores.to(self.device)
+
+        # Offload model after use
+        if isinstance(self.reward_model, torch.nn.Module):
+            self.strategy.offload_model(self.reward_model)
+
+        # Create metrics
+        metrics = {
+            'model_reward': scores.clone(),
+        }
+
+        return scores, metrics
+
+
+class MultiRewardModel(BaseReward):
+    """
+    Multiple reward model implementation.
+
+    This class encapsulates the logic for computing rewards using multiple
+    reward models and aggregating them according to a recipe configuration.
+
+    :param reward_models: List of reward model instances
+    :type reward_models: List[nn.Module]
+    :param reward_tokenizers: List of corresponding tokenizers
+    :type reward_tokenizers: List[Any]
+    :param reward_fn: Aggregation function for combining rewards
+    :type reward_fn: Any
+    :param reward_fn_label_map: Mapping from reward type to model index
+    :type reward_fn_label_map: Dict[str, int]
+    :param reward_recipe: Recipe configuration for combining rewards
+    :type reward_recipe: Dict[str, List[Tuple[str, Optional[str], float]]]
+    :param tokenizer: Tokenizer for decoding sequences
+    :type tokenizer: Any
+    :param strategy: Training strategy (for model loading/offloading)
+    :type strategy: Any
+    :param packing_samples: Whether samples are packed. Default to False
+    :type packing_samples: bool
+    :param device: Device to place reward tensors on
+    :type device: Optional[torch.device]
+    """
+    def __init__(
+        self,
+        reward_models: List[nn.Module],
+        reward_tokenizers: List[Any],
+        reward_fn: Any,
+        reward_fn_label_map: Dict[str, int],
+        reward_recipe: Dict[str, List[Tuple[str, Optional[str], float]]],
+        tokenizer: Any,
+        strategy: Any,
+        packing_samples: bool = False,
+        device: Optional[torch.device] = None,
+    ):
+        """
+        Initialize multiple reward models.
+
+        :param reward_models: List of reward model instances
+        :type reward_models: List[nn.Module]
+        :param reward_tokenizers: List of corresponding tokenizers
+        :type reward_tokenizers: List[Any]
+        :param reward_fn: Aggregation function for combining rewards
+        :type reward_fn: Any
+        :param reward_fn_label_map: Mapping from reward type to model index
+        :type reward_fn_label_map: Dict[str, int]
+        :param reward_recipe: Recipe configuration for combining rewards
+        :type reward_recipe: Dict[str, List[Tuple[str, Optional[str], float]]]
+        :param tokenizer: Tokenizer for decoding sequences
+        :type tokenizer: Any
+        :param strategy: Training strategy (for model loading/offloading)
+        :type strategy: Any
+        :param packing_samples: Whether samples are packed
+        :type packing_samples: bool
+        :param device: Device to place reward tensors on
+        :type device: Optional[torch.device]
+        """
+        super().__init__()
+        self.reward_models = reward_models
+        self.reward_tokenizers = reward_tokenizers
+        self.reward_fn = reward_fn
+        self.reward_fn_label_map = reward_fn_label_map
+        self.reward_recipe = reward_recipe
+        self.tokenizer = tokenizer
+        self.strategy = strategy
+        self.packing_samples = packing_samples
+        self.device = device or (torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu"))  # parens: conditional binds looser than `or`; without them a passed device is dropped on CPU-only hosts
+
+    def compute(
+        self,
+        queries: Sequence[str],
+        references: Optional[Sequence[str]] = None,
+        labels: Optional[Sequence[str]] = None,
+        sequences: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        prompt_and_output: Optional[Sequence[str]] = None,
+        raw_images: Optional[List] = None,
+        img_num: Optional[List[int]] = None,
+        **kwargs
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Compute rewards using multiple reward models and aggregate them.
+ + :param queries: List of query/solution strings (length B) + :type queries: Sequence[str] + :param references: List of reference answers (length B), optional + :type references: Optional[Sequence[str]] + :param labels: List of data labels indicating reward type (length B), required + :type labels: Optional[Sequence[str]] + :param sequences: Token ID sequences [B, seq_len], optional + :type sequences: Optional[torch.Tensor] + :param attention_mask: Attention mask for sequences, optional + :type attention_mask: Optional[torch.Tensor] + :param prompt_and_output: List of prompt+output strings, optional + :type prompt_and_output: Optional[Sequence[str]] + :param raw_images: List of PIL images, optional + :type raw_images: Optional[List] + :param img_num: List of image counts per sample, optional + :type img_num: Optional[List[int]] + :param kwargs: Additional arguments for reward model forward pass + :return: Tuple of (rewards, metrics) where rewards is torch.Tensor of shape (B,) + and metrics contains detailed reward metrics + :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]] + :raises ValueError: If labels are not provided + """ + if labels is None: + raise ValueError("labels are required for MultiRewardModel") + + B = len(queries) + + # Load all models to GPU + for rm in self.reward_models: + if isinstance(rm, torch.nn.Module): + self.strategy.reload_model(rm) + + # Compute rewards for each RM + model_reward_list = [] + + for rm_idx, rm in enumerate(self.reward_models): + # Check if this is a custom engine model + is_custom_engine = ( + isinstance(rm, torch.nn.Module) and hasattr(rm, "base_model") + and not isinstance(rm.base_model, torch.nn.Module) + ) + + if is_custom_engine: + # Custom engine model path + rm_output = rm( + None, + None, + prompt_and_outputs=prompt_and_output if prompt_and_output else queries, + raw_images=raw_images, + img_num=img_num, + references=references, + labels=labels, + **kwargs + ) + else: + # Standard PyTorch model path + 
rm_output = rm( + sequences, + attention_mask, + prompt_and_output=prompt_and_output, + raw_images=raw_images, + img_num=img_num, + **kwargs + ) + + # Extract scores + if isinstance(rm_output, dict): + scores = rm_output["score"] + else: + scores = rm_output + + # Ensure tensor format + if not isinstance(scores, torch.Tensor): + scores = torch.as_tensor(scores, dtype=torch.float32, device=self.device) + else: + scores = scores.to(self.device) + + model_reward_list.append(scores) + + # Offload model after use + if isinstance(rm, torch.nn.Module): + self.strategy.offload_model(rm) + + # Aggregate rewards using reward_fn + rewards, reward_metrics = self.reward_fn( + model_reward_list=model_reward_list, + labels=labels, + queries=queries, + refs=references if references else [""] * B, + label_map=self.reward_fn_label_map, + ) + + # Ensure rewards are on correct device + if not isinstance(rewards, torch.Tensor): + rewards = torch.as_tensor(rewards, dtype=torch.float32, device=self.device) + else: + rewards = rewards.to(self.device) + + # Ensure metrics are tensors + if reward_metrics is not None: + for key, value in reward_metrics.items(): + if not isinstance(value, torch.Tensor): + reward_metrics[key] = torch.as_tensor(value, dtype=torch.float32, device=self.device) + else: + reward_metrics[key] = value.to(self.device) + else: + reward_metrics = {} + + return rewards, reward_metrics diff --git a/lightrft/reward/rule.py b/lightrft/reward/rule.py new file mode 100644 index 0000000..53520fd --- /dev/null +++ b/lightrft/reward/rule.py @@ -0,0 +1,340 @@ +""" +Rule-based Reward Implementation + +This module provides rule-based reward functions that evaluate model outputs +based on heuristics and format checking rather than neural models. 
+
+Main Features:
+    - Format checking (e.g., <think> tags, \\boxed{} notation)
+    - Accuracy verification using mathruler grader
+    - Language consistency checking
+    - Registry pattern for custom rule types
+
+Supported Rule Types:
+    - default: Basic format checking
+    - geo3k_accuracy: Geo3K accuracy verification
+    - geo3k_format: Geo3K format checking
+    - geo3k_combined: Combined format + accuracy
+    - gsm8k_accuracy: GSM8K accuracy verification
+    - gsm8k_format: GSM8K format checking
+    - gsm8k_combined: Combined format + accuracy
+
+Author: lightrft Team
+"""
+
+import re
+from typing import Dict, Sequence, Optional, Callable
+
+import torch
+
+from .base import BaseReward
+
+
+class RuleReward(BaseReward):
+    """
+    Rule-based reward implementation.
+
+    This class encapsulates various rule-based reward functions such as:
+    - Format checking (e.g., <think> tags)
+    - Accuracy checking (e.g., math answer verification)
+    - Language consistency checking
+
+    Supports multiple rule types through a registry pattern.
+
+    :param rule_type: Type of rule to use (e.g., "geo3k_combined", "gsm8k_combined", "default")
+    :type rule_type: str
+    :param format_weight: Weight for format reward when combining with accuracy. Default to 0.1
+    :type format_weight: float
+    :param device: Device to place reward tensors on
+    :type device: Optional[torch.device]
+    """
+
+    # Registry for rule reward functions
+    _RULE_FUNCTIONS: Dict[str, Callable] = {}
+
+    @classmethod
+    def register_rule(cls, name: str):
+        """
+        Decorator to register a rule reward function.
+
+        :param name: Name of the rule type
+        :type name: str
+        :return: Decorator function
+        :rtype: Callable
+
+        Example::
+
+            @RuleReward.register_rule("geo3k")
+            def geo3k_rule(sol: str, gt: str) -> float:
+                ...
+        """
+        def decorator(func: Callable):
+            cls._RULE_FUNCTIONS[name] = func
+            return func
+
+        return decorator
+
+    def __init__(
+        self,
+        rule_type: str = "default",
+        format_weight: float = 0.1,
+        device: Optional[torch.device] = None,
+    ):
+        """
+        Initialize rule-based reward.
+
+        :param rule_type: Type of rule to use (e.g., "geo3k", "gsm8k", "default")
+        :type rule_type: str
+        :param format_weight: Weight for format reward when combining with accuracy
+        :type format_weight: float
+        :param device: Device to place reward tensors on
+        :type device: Optional[torch.device]
+        :raises ValueError: If rule_type is not registered
+        """
+        super().__init__()
+        self.rule_type = rule_type
+        self.format_weight = format_weight
+        self.device = device or (torch.cuda.current_device() if torch.cuda.is_available() else torch.device("cpu"))  # parens: conditional binds looser than `or`; without them a passed device is dropped on CPU-only hosts
+
+        # Get the rule function
+        if rule_type not in self._RULE_FUNCTIONS:
+            raise ValueError(
+                f"Unknown rule type: {rule_type}. "
+                f"Available types: {list(self._RULE_FUNCTIONS.keys())}"
+            )
+        self.rule_func = self._RULE_FUNCTIONS[rule_type]
+
+    def compute(
+        self,
+        queries: Sequence[str],
+        references: Optional[Sequence[str]] = None,
+        labels: Optional[Sequence[str]] = None,
+        **kwargs
+    ) -> tuple[torch.Tensor, Dict[str, torch.Tensor]]:
+        """
+        Compute rule-based rewards.
+ + :param queries: List of solution strings (length B) + :type queries: Sequence[str] + :param references: List of ground truth answers (length B), required for accuracy checking + :type references: Optional[Sequence[str]] + :param labels: Not used for rule rewards, kept for interface consistency + :type labels: Optional[Sequence[str]] + :param kwargs: Additional arguments + :return: Tuple of (rewards, metrics) where rewards is torch.Tensor of shape (B,) + and metrics contains 'format_reward', 'accuracy_reward', 'rule_reward' + :rtype: Tuple[torch.Tensor, Dict[str, torch.Tensor]] + """ + if references is None: + references = [""] * len(queries) + + B = len(queries) + device = self.device + + # Initialize metrics + metrics = { + 'format_reward': torch.zeros(B, dtype=torch.float32, device=device), + 'accuracy_reward': torch.zeros(B, dtype=torch.float32, device=device), + 'rule_reward': torch.zeros(B, dtype=torch.float32, device=device), + } + + rewards = torch.zeros(B, dtype=torch.float32, device=device) + + # Compute rewards for each query + for i, (sol, gt) in enumerate(zip(queries, references)): + reward_value = self.rule_func(sol, gt) + rewards[i] = reward_value + metrics['rule_reward'][i] = reward_value + + # For combined rules (geo3k, gsm8k), extract individual components + if self.rule_type in ["geo3k_combined", "gsm8k_combined"]: + # These rule functions return combined reward, but we can extract components + # by calling the individual functions if available + if hasattr(self, '_extract_components'): + fmt_r, acc_r = self._extract_components(sol, gt) + metrics['format_reward'][i] = fmt_r + metrics['accuracy_reward'][i] = acc_r + + return rewards, metrics + + +# ============================================================================ +# Default Rule Reward Functions +# ============================================================================ + + +def _default_rule_reward_fn(sol: str, gt: str) -> float: + """ + Default rule reward: format checking. 
+
+    Checks if solution matches format: <think>...</think> + non-empty content.
+
+    :param sol: Solution string to check
+    :type sol: str
+    :param gt: Ground truth (not used in format check)
+    :type gt: str
+    :return: 1.0 if format is valid, 0.0 otherwise
+    :rtype: float
+    """
+    pattern = r"<think>.*</think>.+?\s*\S+"  # literal tags restored; the tag-less pattern matched almost anything — TODO confirm tag names against upstream
+    return 1.0 if re.match(pattern, sol, re.DOTALL) else 0.0
+
+
+RuleReward.register_rule("default")(_default_rule_reward_fn)
+
+# ============================================================================
+# Geo3K Rule Reward Functions
+# ============================================================================
+
+
+def _geo3k_accuracy_reward_fn(sol: str, gt: str) -> float:
+    """
+    Geo3K accuracy reward function.
+
+    Extract answer from \\boxed{} notation and use mathruler to verify correctness.
+
+    :param sol: Solution string from model (should contain \\boxed{answer})
+    :type sol: str
+    :param gt: Ground truth answer
+    :type gt: str
+    :return: 1.0 if answer is correct, 0.0 otherwise
+    :rtype: float
+    """
+    try:
+        from mathruler.grader import extract_boxed_content, grade_answer
+        pred = extract_boxed_content(sol)
+        return 1.0 if grade_answer(pred, gt) else 0.0
+    except ImportError:
+        # Fallback if mathruler is not available
+        return 0.0
+
+
+def _geo3k_format_reward_fn(sol: str, gt: str) -> float:
+    """
+    Geo3K format reward function.
+
+    Check if the solution follows the required format:
+    - Contains <think>...</think> tags for reasoning
+    - Contains \\boxed{} for final answer
+    - The think tags must appear BEFORE the boxed answer
+
+    :param sol: Solution string from model
+    :type sol: str
+    :param gt: Ground truth (not used in format check)
+    :type gt: str
+    :return: 1.0 if format is correct, 0.0 otherwise
+    :rtype: float
+    """
+    sol_stripped = sol.strip()
+
+    think_match = re.search(r'<think>.*?</think>', sol_stripped, re.DOTALL)  # tags restored: bare r'.*?' matched empty at pos 0, making the ordering check vacuous
+    boxed_match = re.search(r'\\boxed\{.*?\}', sol_stripped, re.DOTALL)
+
+    if think_match and boxed_match:
+        think_end = think_match.end()
+        boxed_start = boxed_match.start()
+        return 1.0 if think_end <= boxed_start else 0.0
+    else:
+        return 0.0
+
+
+def _geo3k_combined_reward_fn(sol: str, gt: str) -> float:
+    """
+    Geo3K combined reward function.
+
+    Combines format reward and accuracy reward with weights.
+    Default: 90% accuracy + 10% format.
+
+    :param sol: Solution string from model
+    :type sol: str
+    :param gt: Ground truth answer
+    :type gt: str
+    :return: Weighted combination of format and accuracy rewards
+    :rtype: float
+    """
+    acc_r = _geo3k_accuracy_reward_fn(sol, gt)
+    fmt_r = _geo3k_format_reward_fn(sol, gt)
+    return 0.9 * acc_r + 0.1 * fmt_r
+
+
+RuleReward.register_rule("geo3k_accuracy")(_geo3k_accuracy_reward_fn)
+RuleReward.register_rule("geo3k_format")(_geo3k_format_reward_fn)
+RuleReward.register_rule("geo3k_combined")(_geo3k_combined_reward_fn)
+
+# ============================================================================
+# GSM8K Rule Reward Functions
+# ============================================================================
+
+
+def _gsm8k_accuracy_reward_fn(sol: str, gt: str) -> float:
+    """
+    GSM8K accuracy reward function.
+
+    Extract answer from \\boxed{} notation and use mathruler to verify correctness.
+
+    :param sol: Solution string from model (should contain \\boxed{answer})
+    :type sol: str
+    :param gt: Ground truth answer
+    :type gt: str
+    :return: 1.0 if answer is correct, 0.0 otherwise
+    :rtype: float
+    """
+    try:
+        from mathruler.grader import extract_boxed_content, grade_answer
+        pred = extract_boxed_content(sol)
+        return 1.0 if grade_answer(pred, gt) else 0.0
+    except ImportError:
+        return 0.0
+
+
+def _gsm8k_format_reward_fn(sol: str, gt: str) -> float:
+    """
+    GSM8K format reward function.
+
+    Check if the solution follows the required format:
+    - Contains <think>...</think> tags for reasoning
+    - Contains \\boxed{} for final answer
+    - The think tags must appear BEFORE the boxed answer
+
+    :param sol: Solution string from model
+    :type sol: str
+    :param gt: Ground truth (not used in format check)
+    :type gt: str
+    :return: 1.0 if format is correct, 0.0 otherwise
+    :rtype: float
+    """
+    sol_stripped = sol.strip()
+
+    think_match = re.search(r'<think>.*?</think>', sol_stripped, re.DOTALL)  # tags restored: bare r'.*?' matched empty at pos 0, making the ordering check vacuous
+    boxed_match = re.search(r'\\boxed\{.*?\}', sol_stripped, re.DOTALL)
+
+    if think_match and boxed_match:
+        think_end = think_match.end()
+        boxed_start = boxed_match.start()
+        return 1.0 if think_end <= boxed_start else 0.0
+    else:
+        return 0.0
+
+
+def _gsm8k_combined_reward_fn(sol: str, gt: str) -> float:
+    """
+    GSM8K combined reward function.
+
+    Combines format reward and accuracy reward with weights.
+    Default: 90% accuracy + 10% format.
+
+    :param sol: Solution string from model
+    :type sol: str
+    :param gt: Ground truth answer
+    :type gt: str
+    :return: Weighted combination of format and accuracy rewards
+    :rtype: float
+    """
+    acc_r = _gsm8k_accuracy_reward_fn(sol, gt)
+    fmt_r = _gsm8k_format_reward_fn(sol, gt)
+    return 0.9 * acc_r + 0.1 * fmt_r
+
+
+RuleReward.register_rule("gsm8k_accuracy")(_gsm8k_accuracy_reward_fn)
+RuleReward.register_rule("gsm8k_format")(_gsm8k_format_reward_fn)
+RuleReward.register_rule("gsm8k_combined")(_gsm8k_combined_reward_fn)