diff --git a/examples/gsm8k_geo3k/train_colocate.py b/examples/gsm8k_geo3k/train_colocate.py
index 818bcce..da036e9 100644
--- a/examples/gsm8k_geo3k/train_colocate.py
+++ b/examples/gsm8k_geo3k/train_colocate.py
@@ -440,8 +440,10 @@ def train(args):
     parser.add_argument("--max_ckpt_mem", type=int, default=1e8)
     parser.add_argument("--load_checkpoint", action="store_true", default=False)
 
-    # DAPO
-    parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy")
+    # DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization)
+    parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy to filter out groups with uniform rewards")
+    parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", help="Metric for dynamic sampling filtering: 'reward', 'acc', etc.")
+    parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Maximum number of generation batches for dynamic sampling (<=0 means no limit)")
     parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO")
     parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer")
     parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them")
diff --git a/lightrft/strategy/config.py b/lightrft/strategy/config.py
index c699300..70318e6 100644
--- a/lightrft/strategy/config.py
+++ b/lightrft/strategy/config.py
@@ -110,8 +110,12 @@ class StrategyConfig:
     overlong_buffer_penalty_factor: float = 1.0
 
     # Dynamic sampling and advantage estimation
-    # (bool): Enable dynamic sampling for advantage estimation, defaults to False
+    # (bool): Enable dynamic sampling for advantage estimation (DAPO), defaults to False
     dynamic_sampling: bool = False
+    # (str): Metric to filter groups in dynamic sampling: "reward", "acc", etc., defaults to "reward"
+    dynamic_sampling_metric: str = "reward"
+    # (int): Maximum number of generation batches for dynamic sampling, <=0 means no limit, defaults to 10
+    max_num_gen_batches: int = 10
 
     # (str): Advantage estimator method, defaults to "gae"
     advantage_estimator: str = "group_norm"
@@ -280,7 +284,7 @@ def print_config_summary(self) -> None:
 
         # Dynamic Sampling and Advantage Estimation Parameters
         print("\nDynamic Sampling and Advantage Estimation Parameters:")
-        for attr in ['dynamic_sampling', 'advantage_estimator']:
+        for attr in ['dynamic_sampling', 'dynamic_sampling_metric', 'max_num_gen_batches', 'advantage_estimator']:
             current = getattr(self, attr)
             default = getattr(default_config, attr)
             status = "Overridden" if current != default else "Default"
diff --git a/lightrft/trainer/advantage_calculator.py b/lightrft/trainer/advantage_calculator.py
index 1ce9a8f..cb3844f 100644
--- a/lightrft/trainer/advantage_calculator.py
+++ b/lightrft/trainer/advantage_calculator.py
@@ -627,9 +627,12 @@ class GroupNormCalculator(BaseREINFORCECalculator):
     """
     Group normalization calculator (GRPO).
 
-    Normalizes rewards within each group and optionally filters degenerate cases.
+    Normalizes rewards within each group and optionally filters degenerate cases
+    using dynamic sampling strategy (DAPO).
 
-    Reference: GRPO: https://arxiv.org/pdf/2402.03300
+    Reference:
+    - GRPO: https://arxiv.org/pdf/2402.03300
+    - DAPO: https://arxiv.org/abs/2503.14476
     """
     def preprocess_rewards(
         self,
@@ -640,6 +643,10 @@ class GroupNormCalculator(BaseREINFORCECalculator):
         """
         Preprocess rewards using group normalization with optional dynamic filtering.
 
+        Dynamic sampling (DAPO) filters out groups where all samples have the same metric value
+        (e.g., all rewards are 0 or all are 1), as these groups provide no learning signal.
+        This is achieved by setting action_mask to all zeros for filtered groups.
+
         :param rewards: Concatenated reward tensor
         :type rewards: torch.Tensor
         :param experiences: List of experiences (may be filtered)
@@ -652,17 +659,36 @@ class GroupNormCalculator(BaseREINFORCECalculator):
         config = self.config
         n_samples = config.n_samples_per_prompt
 
-        # Dynamic sampling filtering
+        # Dynamic sampling filtering (DAPO)
+        # Filter out groups where all outputs have the same metric value
         if config.dynamic_sampling:
-            step_size = n_samples // config.micro_train_batch_size
-            for i in range(0, len(experiences), step_size):
-                chunk = experiences[i:i + step_size]
-                chunk_rewards = torch.cat([exp.info["reward"] for exp in chunk])
-
-                # Filter out degenerate cases (all 0s or all 1s)
-                if torch.all(chunk_rewards == 0) or torch.all(chunk_rewards == 1):
-                    for exp in chunk:
-                        exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool)
+            metric = config.dynamic_sampling_metric
+
+            # When micro_rollout_batch_size == n_samples_per_prompt, each experience
+            # contains all samples for one prompt in batched format
+            # exp.info["reward"] has shape=[n_samples], representing all samples for that prompt
+            for exp in experiences:
+                reward_tensor = exp.info["reward"]  # shape=[n_samples]
+
+                # Extract metric values (all samples within this experience/prompt)
+                if metric == "reward":
+                    metric_values = reward_tensor
+                elif metric == "acc":
+                    # Use accuracy if available
+                    if "accuracy" in exp.info:
+                        metric_values = exp.info["accuracy"]
+                    else:
+                        # Fallback: treat reward as binary accuracy
+                        metric_values = reward_tensor
+                else:
+                    # Default to reward
+                    metric_values = reward_tensor
+
+                # Check if all samples have the same metric value (degenerate group)
+                # This prompt provides no learning signal for relative comparison
+                if torch.all(metric_values == metric_values[0]):
+                    # Mark this experience for filtering by zeroing out action mask
+                    exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool)
 
         # Group normalization
         rewards = rewards.reshape(-1, n_samples).to("cuda")
diff --git a/lightrft/trainer/ppo_trainer_vl.py b/lightrft/trainer/ppo_trainer_vl.py
index 4584129..ba00a3e 100644
--- a/lightrft/trainer/ppo_trainer_vl.py
+++ b/lightrft/trainer/ppo_trainer_vl.py
@@ -369,30 +369,85 @@ def fit(
             self.strategy.print(
                 f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa
             )
 
-            for i, experience in enumerate(
-                self.experience_maker.make_experience_list(
-                    rand_prompts,
-                    rand_images,
-                    all_videos=rand_videos,
-                    all_references=rand_references,
-                    all_labels=rand_labels,
-                    **self.generate_kwargs
-                )
-            ):
-                if i == 0:
-                    output = self.tokenizer.batch_decode(
-                        experience.sequences[0].unsqueeze(0), skip_special_tokens=True
+            # ========== Dynamic Sampling Loop (DAPO) ==========
+            # When dynamic_sampling is enabled, we may need to generate multiple batches
+            # to collect enough valid prompts (groups with varying rewards)
+            num_gen_batches = 0
+            target_num_prompts = args.rollout_batch_size
+            n_samples = args.n_samples_per_prompt
+
+            while True:
+                num_gen_batches += 1
+
+                # Generate experiences for current batch
+                for i, experience in enumerate(
+                    self.experience_maker.make_experience_list(
+                        rand_prompts,
+                        rand_images,
+                        all_videos=rand_videos,
+                        all_references=rand_references,
+                        all_labels=rand_labels,
+                        **self.generate_kwargs
                     )
-                    self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output)
+                ):
+                    if i == 0:
+                        output = self.tokenizer.batch_decode(
+                            experience.sequences[0].unsqueeze(0), skip_special_tokens=True
+                        )
+                        self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output)
+                        self.strategy.print(
+                            f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa
+                        )
+
+                    self.replay_buffer.append(experience)
+
+                # Check if dynamic sampling is enabled
+                if not self.strategy.config.dynamic_sampling:
+                    # No dynamic sampling, exit after first batch
+                    break
+
+                # Count valid prompts (groups with non-zero action masks after filtering)
+                # This check happens AFTER all experiences in the batch are generated
+                # Note: When micro_rollout_batch_size == n_samples_per_prompt, each experience
+                # contains all samples for one prompt. So we count experiences directly.
+                num_valid_prompts = 0
+                for exp in self.replay_buffer.items:
+                    # Check if this experience has any valid actions (not all filtered)
+                    if exp.action_mask.sum() > 0:
+                        num_valid_prompts += 1
+
+                if self.strategy.is_rank_0():
                     self.strategy.print(
-                        f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa
+                        f"Dynamic Sampling: num_valid_prompts={num_valid_prompts}, "
+                        f"target={target_num_prompts}, num_gen_batches={num_gen_batches}"
                     )
-                # print all
-                # self.strategy.print(
-                #     f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa
-                # )
-                self.replay_buffer.append(experience)
+                # Check if we have enough valid prompts
+                if num_valid_prompts >= target_num_prompts:
+                    # Trim by prompt count: each buffer item holds all samples for one prompt
+                    target_num_experiences = target_num_prompts
+                    self.replay_buffer.items = self.replay_buffer.items[:target_num_experiences]
+                    break
+
+                # Check if we've reached the maximum number of generation batches
+                max_num_gen_batches = self.strategy.config.max_num_gen_batches
+                if max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches:
+                    if self.strategy.is_rank_0():
+                        self.strategy.print(
+                            f"Warning: Reached max_num_gen_batches={max_num_gen_batches} "
+                            f"with only {num_valid_prompts} valid prompts. Proceeding with available data."
+                        )
+                    break
+
+                # Need more samples, but current implementation only processes one batch
+                # In a full implementation, we would fetch the next batch from dataloader here
+                # For now, we proceed with what we have
+                if self.strategy.is_rank_0():
+                    self.strategy.print(
+                        f"Warning: Dynamic sampling needs more batches, but current implementation "
+                        f"processes one batch at a time. Proceeding with {num_valid_prompts} valid prompts."
+                    )
+                break
 
         self.strategy.report_memory('after replay_buffer ready')
 
@@ -711,6 +766,14 @@ def training_step_actor(self,
         base_action_log_probs = experience.base_action_log_probs
 
         if advantages is not None:
+            # Check if advantages is empty (can happen when dynamic sampling filters all samples)
+            if advantages.numel() == 0:
+                self.strategy.print(
+                    "[Warning] Empty advantages tensor detected. This may occur when dynamic sampling "
+                    "filters out all samples in a batch. Skipping this training step."
+                )
+                return {} # Return empty status dict to skip this step
+
             # Log max advantage before clipping for debugging (optional)
             max_adv = advantages.max().item()
             if max_adv > 10.0:
diff --git a/lightrft/trainer/spmd_ppo_trainer.py b/lightrft/trainer/spmd_ppo_trainer.py
index d79a745..7b7b5bd 100644
--- a/lightrft/trainer/spmd_ppo_trainer.py
+++ b/lightrft/trainer/spmd_ppo_trainer.py
@@ -232,6 +232,17 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train
             )
             should_skip_local = not is_valid
 
+            # Check for empty advantages (can happen when dynamic sampling filters all samples)
+            if not should_skip_local and hasattr(experience, 'advantages') and experience.advantages is not None:
+                if isinstance(experience.advantages, list):
+                    # Packed samples: check if any advantages are empty
+                    if any(adv.numel() == 0 for adv in experience.advantages):
+                        should_skip_local = True
+                else:
+                    # Single tensor: check if empty
+                    if experience.advantages.numel() == 0:
+                        should_skip_local = True
+
             # Step 2: Synchronize skip decision across all ranks via all_reduce
             # This ensures all ranks agree on whether to skip, preventing execution divergence
             skip_flag = torch.tensor([1.0 if should_skip_local else 0.0], device=device)