Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions examples/gsm8k_geo3k/train_colocate.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,8 +440,10 @@ def train(args):
parser.add_argument("--max_ckpt_mem", type=int, default=1e8)
parser.add_argument("--load_checkpoint", action="store_true", default=False)

# DAPO
parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy")
# DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization)
parser.add_argument("--dynamic_sampling", action="store_true", default=False, help="Enable DAPO dynamic sampling strategy to filter out groups with uniform rewards")
parser.add_argument("--dynamic_sampling_metric", type=str, default="reward", help="Metric for dynamic sampling filtering: 'reward', 'acc', etc.")
parser.add_argument("--max_num_gen_batches", type=int, default=10, help="Maximum number of generation batches for dynamic sampling (<=0 means no limit)")
parser.add_argument("--overlong_buffer", action="store_true", default=False, help="Apply overlong sequence buffer in DAPO")
parser.add_argument("--overlong_buffer_len", type=int, default=1024, help="Max token threshold for overlong buffer")
parser.add_argument("--overlong_buffer_penalty_factor", type=float, default=1.0, help="Penalty scaling factor for overlong sequences, <1 discourages long outputs; >1 encourages them")
Expand Down
8 changes: 6 additions & 2 deletions lightrft/strategy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,8 +110,12 @@ class StrategyConfig:
overlong_buffer_penalty_factor: float = 1.0

# Dynamic sampling and advantage estimation
# (bool): Enable dynamic sampling for advantage estimation, defaults to False
# (bool): Enable dynamic sampling for advantage estimation (DAPO), defaults to False
dynamic_sampling: bool = False
# (str): Metric to filter groups in dynamic sampling: "reward", "acc", etc., defaults to "reward"
dynamic_sampling_metric: str = "reward"
# (int): Maximum number of generation batches for dynamic sampling, <=0 means no limit, defaults to 10
max_num_gen_batches: int = 10
# (str): Advantage estimator method, defaults to "gae"
advantage_estimator: str = "group_norm"

Expand Down Expand Up @@ -280,7 +284,7 @@ def print_config_summary(self) -> None:

# Dynamic Sampling and Advantage Estimation Parameters
print("\nDynamic Sampling and Advantage Estimation Parameters:")
for attr in ['dynamic_sampling', 'advantage_estimator']:
for attr in ['dynamic_sampling', 'dynamic_sampling_metric', 'max_num_gen_batches', 'advantage_estimator']:
current = getattr(self, attr)
default = getattr(default_config, attr)
status = "Overridden" if current != default else "Default"
Expand Down
50 changes: 38 additions & 12 deletions lightrft/trainer/advantage_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,9 +627,12 @@ class GroupNormCalculator(BaseREINFORCECalculator):
"""
Group normalization calculator (GRPO).

Normalizes rewards within each group and optionally filters degenerate cases.
Normalizes rewards within each group and optionally filters degenerate cases
using dynamic sampling strategy (DAPO).

Reference: GRPO: https://arxiv.org/pdf/2402.03300
Reference:
- GRPO: https://arxiv.org/pdf/2402.03300
- DAPO: https://arxiv.org/abs/2503.14476
"""
def preprocess_rewards(
self,
Expand All @@ -640,6 +643,10 @@ def preprocess_rewards(
"""
Preprocess rewards using group normalization with optional dynamic filtering.

Dynamic sampling (DAPO) filters out groups where all samples have the same metric value
(e.g., all rewards are 0 or all are 1), as these groups provide no learning signal.
This is achieved by setting action_mask to all zeros for filtered groups.

:param rewards: Concatenated reward tensor
:type rewards: torch.Tensor
:param experiences: List of experiences (may be filtered)
Expand All @@ -652,17 +659,36 @@ def preprocess_rewards(
config = self.config
n_samples = config.n_samples_per_prompt

# Dynamic sampling filtering
# Dynamic sampling filtering (DAPO)
# Filter out groups where all outputs have the same metric value
if config.dynamic_sampling:
step_size = n_samples // config.micro_train_batch_size
for i in range(0, len(experiences), step_size):
chunk = experiences[i:i + step_size]
chunk_rewards = torch.cat([exp.info["reward"] for exp in chunk])

# Filter out degenerate cases (all 0s or all 1s)
if torch.all(chunk_rewards == 0) or torch.all(chunk_rewards == 1):
for exp in chunk:
exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool)
metric = config.dynamic_sampling_metric

# When micro_rollout_batch_size == n_samples_per_prompt, each experience
# contains all samples for one prompt in batched format
# exp.info["reward"] has shape=[n_samples], representing all samples for that prompt
for exp in experiences:
reward_tensor = exp.info["reward"] # shape=[n_samples]

# Extract metric values (all samples within this experience/prompt)
if metric == "reward":
metric_values = reward_tensor
elif metric == "acc":
# Use accuracy if available
if "accuracy" in exp.info:
metric_values = exp.info["accuracy"]
else:
# Fallback: treat reward as binary accuracy
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

raise RuntimeError here, not use fallback

metric_values = reward_tensor
else:
# Default to reward
metric_values = reward_tensor

# Check if all samples have the same metric value (degenerate group)
# This prompt provides no learning signal for relative comparison
if torch.all(metric_values == metric_values[0]):
# Mark this experience for filtering by zeroing out action mask
exp.action_mask = torch.zeros_like(exp.action_mask, dtype=torch.bool)

# Group normalization
rewards = rewards.reshape(-1, n_samples).to("cuda")
Expand Down
103 changes: 83 additions & 20 deletions lightrft/trainer/ppo_trainer_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,30 +369,85 @@ def fit(
f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you should also add the similar implementation in ppo_trainer.py

)

for i, experience in enumerate(
self.experience_maker.make_experience_list(
rand_prompts,
rand_images,
all_videos=rand_videos,
all_references=rand_references,
all_labels=rand_labels,
**self.generate_kwargs
)
):
if i == 0:
output = self.tokenizer.batch_decode(
experience.sequences[0].unsqueeze(0), skip_special_tokens=True
# ========== Dynamic Sampling Loop (DAPO) ==========
# When dynamic_sampling is enabled, we may need to generate multiple batches
# to collect enough valid prompts (groups with varying rewards)
num_gen_batches = 0
target_num_prompts = args.rollout_batch_size
n_samples = args.n_samples_per_prompt

while True:
num_gen_batches += 1

# Generate experiences for current batch
for i, experience in enumerate(
self.experience_maker.make_experience_list(
rand_prompts,
rand_images,
all_videos=rand_videos,
all_references=rand_references,
all_labels=rand_labels,
**self.generate_kwargs
)
self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output)
):
if i == 0:
output = self.tokenizer.batch_decode(
experience.sequences[0].unsqueeze(0), skip_special_tokens=True
)
self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove these two prints here

self.strategy.print(
f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa
)

self.replay_buffer.append(experience)

# Check if dynamic sampling is enabled
if not self.strategy.config.dynamic_sampling:
# No dynamic sampling, exit after first batch
break

# Count valid prompts (groups with non-zero action masks after filtering)
# This check happens AFTER all experiences in the batch are generated
# Note: When micro_rollout_batch_size == n_samples_per_prompt, each experience
# contains all samples for one prompt. So we count experiences directly.
num_valid_prompts = 0
for exp in self.replay_buffer.items:
# Check if this experience has any valid actions (not all filtered)
if exp.action_mask.sum() > 0:
num_valid_prompts += 1

if self.strategy.is_rank_0():
self.strategy.print(
f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa
f"Dynamic Sampling: num_valid_prompts={num_valid_prompts}, "
f"target={target_num_prompts}, num_gen_batches={num_gen_batches}"
)
# print all
# self.strategy.print(
# f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa
# )

self.replay_buffer.append(experience)
# Check if we have enough valid prompts
if num_valid_prompts >= target_num_prompts:
# Trim to exact target size
target_num_experiences = target_num_prompts * n_samples
self.replay_buffer.items = self.replay_buffer.items[:target_num_experiences]
break

# Check if we've reached the maximum number of generation batches
max_num_gen_batches = self.strategy.config.max_num_gen_batches
if max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches:
if self.strategy.is_rank_0():
self.strategy.print(
f"Warning: Reached max_num_gen_batches={max_num_gen_batches} "
f"with only {num_valid_prompts} valid prompts. Proceeding with available data."
)
break

# Need more samples, but current implementation only processes one batch
# In a full implementation, we would fetch the next batch from dataloader here
# For now, we proceed with what we have
if self.strategy.is_rank_0():
self.strategy.print(
f"Warning: Dynamic sampling needs more batches, but current implementation "
f"processes one batch at a time. Proceeding with {num_valid_prompts} valid prompts."
)
break
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you add break here, this while loop is not necessary, maybe you can omit it. Maybe you should set continue in the for loop. Only in the case that num_valid_prompts >= target_num_prompts, the following code should be executed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there are bugs in you previous experiment, the current implementation can not accumulate enough data for training when dynamic_sampling is enabled.


self.strategy.report_memory('after replay_buffer ready')

Expand Down Expand Up @@ -711,6 +766,14 @@ def training_step_actor(self,
base_action_log_probs = experience.base_action_log_probs

if advantages is not None:
# Check if advantages is empty (can happen when dynamic sampling filters all samples)
if advantages.numel() == 0:
self.strategy.print(
"[Warning] Empty advantages tensor detected. This may occur when dynamic sampling "
"filters out all samples in a batch. Skipping this training step."
)
return {} # Return empty status dict to skip this step

# Log max advantage before clipping for debugging (optional)
max_adv = advantages.max().item()
if max_adv > 10.0:
Expand Down
11 changes: 11 additions & 0 deletions lightrft/trainer/spmd_ppo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,17 @@ def ppo_train(self, global_steps=0): # Currently using this rewritten ppo_train
)
should_skip_local = not is_valid

# Check for empty advantages (can happen when dynamic sampling filters all samples)
if not should_skip_local and hasattr(experience, 'advantages') and experience.advantages is not None:
if isinstance(experience.advantages, list):
# Packed samples: check if any advantages are empty
if any(adv.numel() == 0 for adv in experience.advantages):
should_skip_local = True
else:
# Single tensor: check if empty
if experience.advantages.numel() == 0:
should_skip_local = True

# Step 2: Synchronize skip decision across all ranks via all_reduce
# This ensures all ranks agree on whether to skip, preventing execution divergence
skip_flag = torch.tensor([1.0 if should_skip_local else 0.0], device=device)
Expand Down