
feature(sunjx): implement dynamic sampling strategy in DAPO#40

Closed
Jiaxuan-Sun wants to merge 3 commits into opendilab:main from Jiaxuan-Sun:feature/dynamic-sampling

Conversation

@Jiaxuan-Sun
Contributor

Implement Dynamic Sampling (DAPO) for GRPO Training

This PR implements the dynamic sampling strategy from DAPO (Decoupled Clip and Dynamic Sampling Policy Optimization) to improve GRPO training efficiency.

Key Features

  • Group filtering: Filters out prompt groups where all responses have the same metric value (all correct or all incorrect), as they provide no useful gradient information for relative policy optimization
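The group-filtering idea can be illustrated with a minimal sketch (all names here are illustrative, not the PR's actual code): when every response in a prompt group receives the same metric value, the group-relative advantages are all zero, so the group contributes no gradient and is dropped.

```python
# Hypothetical sketch of DAPO-style group filtering; `filter_groups` and its
# input shape are assumptions for illustration, not the PR's implementation.
def filter_groups(groups):
    """Keep only prompt groups whose per-response metric values are not all equal.

    `groups` maps prompt -> list of per-response metric values (e.g. 0/1 accuracy).
    """
    return {
        prompt: values
        for prompt, values in groups.items()
        if len(set(values)) > 1  # mixed outcomes -> non-zero relative advantages
    }

groups = {
    "p1": [1, 1, 1, 1],  # all correct   -> filtered out
    "p2": [0, 0, 0, 0],  # all incorrect -> filtered out
    "p3": [1, 0, 1, 0],  # mixed         -> kept
}
print(filter_groups(groups))  # {'p3': [1, 0, 1, 0]}
```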

@puyuan1996 puyuan1996 changed the title Feature(sunjx): Implement Dynamic Sampling (DAPO) for GRPO Training feature(sunjx): implement dynamic sampling strategy in DAPO Feb 10, 2026
if "accuracy" in exp.info:
    metric_values = exp.info["accuracy"]
else:
    # Fallback: treat reward as binary accuracy
Member

Raise a RuntimeError here instead of using a fallback.
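The reviewer's suggestion, sketched under assumed names (`get_metric_values` and the `exp.info` shape are stand-ins for the PR's objects): fail fast when the metric is missing rather than silently reinterpreting the reward.

```python
from types import SimpleNamespace

def get_metric_values(exp):
    # Fail fast per the review: no silent fallback to reward-as-accuracy.
    if "accuracy" in exp.info:
        return exp.info["accuracy"]
    raise RuntimeError(
        "Dynamic sampling requires an 'accuracy' metric in exp.info; "
        "refusing to fall back to treating reward as binary accuracy."
    )

exp = SimpleNamespace(info={"accuracy": [1, 0, 1]})
print(get_metric_values(exp))  # [1, 0, 1]
```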

f"Warning: Dynamic sampling needs more batches, but current implementation "
f"processes one batch at a time. Proceeding with {num_valid_prompts} valid prompts."
)
break
Member

If you add `break` here, this `while` loop is unnecessary and can be omitted. Maybe you should use `continue` in the `for` loop instead: only when `num_valid_prompts >= target_num_prompts` should the following code be executed.
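One way to read this suggestion is the following sketch (function and variable names are hypothetical, not the trainer's real API): keep accumulating valid prompts with `continue`, and only fall through to the training step once the target count is reached.

```python
# Illustrative restructuring along the review comment: accumulate across
# batches instead of breaking out early with a partial set of valid prompts.
def collect_valid_prompts(batches, target_num_prompts, is_valid):
    valid_prompts = []
    for batch in batches:
        valid_prompts.extend(p for p in batch if is_valid(p))
        if len(valid_prompts) < target_num_prompts:
            continue  # not enough yet: sample/process the next batch
        # Enough valid prompts accumulated; proceed to the training step.
        return valid_prompts[:target_num_prompts]
    # Ran out of batches before reaching the target.
    return valid_prompts

# Toy usage: "valid" here just means even-valued.
batches = [[1, 2, 3], [4, 5, 6], [7, 8]]
print(collect_valid_prompts(batches, 3, lambda p: p % 2 == 0))  # [2, 4, 6]
```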

Member

I think there are bugs in your previous experiment; the current implementation cannot accumulate enough data for training when dynamic_sampling is enabled.

output = self.tokenizer.batch_decode(
experience.sequences[0].unsqueeze(0), skip_special_tokens=True
)
self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output)
Member

Remove these two print statements.

@@ -369,30 +369,85 @@ def fit(
f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa
Member

You should also add a similar implementation in ppo_trainer.py.

@puyuan1996
Collaborator

We have a new PR: #51

@puyuan1996 puyuan1996 closed this Mar 11, 2026