-
Notifications
You must be signed in to change notification settings - Fork 10
feature(sunjx): implement dynamic sampling strategy in DAPO #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -369,30 +369,85 @@ def fit( | |
| f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you should also add the similar implementation in |
||
| ) | ||
|
|
||
| for i, experience in enumerate( | ||
| self.experience_maker.make_experience_list( | ||
| rand_prompts, | ||
| rand_images, | ||
| all_videos=rand_videos, | ||
| all_references=rand_references, | ||
| all_labels=rand_labels, | ||
| **self.generate_kwargs | ||
| ) | ||
| ): | ||
| if i == 0: | ||
| output = self.tokenizer.batch_decode( | ||
| experience.sequences[0].unsqueeze(0), skip_special_tokens=True | ||
| # ========== Dynamic Sampling Loop (DAPO) ========== | ||
| # When dynamic_sampling is enabled, we may need to generate multiple batches | ||
| # to collect enough valid prompts (groups with varying rewards) | ||
| num_gen_batches = 0 | ||
| target_num_prompts = args.rollout_batch_size | ||
| n_samples = args.n_samples_per_prompt | ||
|
|
||
| while True: | ||
| num_gen_batches += 1 | ||
|
|
||
| # Generate experiences for current batch | ||
| for i, experience in enumerate( | ||
| self.experience_maker.make_experience_list( | ||
| rand_prompts, | ||
| rand_images, | ||
| all_videos=rand_videos, | ||
| all_references=rand_references, | ||
| all_labels=rand_labels, | ||
| **self.generate_kwargs | ||
| ) | ||
| self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) | ||
| ): | ||
| if i == 0: | ||
| output = self.tokenizer.batch_decode( | ||
| experience.sequences[0].unsqueeze(0), skip_special_tokens=True | ||
| ) | ||
| self.strategy.print("collect phase: experience.sequences w skip_special_tokens: ", output) | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove these two prints here |
||
| self.strategy.print( | ||
| f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa | ||
| ) | ||
|
|
||
| self.replay_buffer.append(experience) | ||
|
|
||
| # Check if dynamic sampling is enabled | ||
| if not self.strategy.config.dynamic_sampling: | ||
| # No dynamic sampling, exit after first batch | ||
| break | ||
|
|
||
| # Count valid prompts (groups with non-zero action masks after filtering) | ||
| # This check happens AFTER all experiences in the batch are generated | ||
| # Note: When micro_rollout_batch_size == n_samples_per_prompt, each experience | ||
| # contains all samples for one prompt. So we count experiences directly. | ||
| num_valid_prompts = 0 | ||
| for exp in self.replay_buffer.items: | ||
| # Check if this experience has any valid actions (not all filtered) | ||
| if exp.action_mask.sum() > 0: | ||
| num_valid_prompts += 1 | ||
|
|
||
| if self.strategy.is_rank_0(): | ||
| self.strategy.print( | ||
| f"collect phase: rand_prompts:\n {rand_prompts[0:2]}\n , rand_images:{rand_images[0:2]}\n , rand_references:{rand_references[0:2]}\n, rand_labels:{rand_labels[0:2]}\n " # noqa | ||
| f"Dynamic Sampling: num_valid_prompts={num_valid_prompts}, " | ||
| f"target={target_num_prompts}, num_gen_batches={num_gen_batches}" | ||
| ) | ||
| # print all | ||
| # self.strategy.print( | ||
| # f"rand_prompts:\n {rand_prompts}\n , rand_images:{rand_images}\n , rand_references:{rand_references}\n, rand_labels:{rand_labels}\n " # noqa | ||
| # ) | ||
|
|
||
| self.replay_buffer.append(experience) | ||
| # Check if we have enough valid prompts | ||
| if num_valid_prompts >= target_num_prompts: | ||
| # Trim to exact target size | ||
| target_num_experiences = target_num_prompts * n_samples | ||
| self.replay_buffer.items = self.replay_buffer.items[:target_num_experiences] | ||
| break | ||
|
|
||
| # Check if we've reached the maximum number of generation batches | ||
| max_num_gen_batches = self.strategy.config.max_num_gen_batches | ||
| if max_num_gen_batches > 0 and num_gen_batches >= max_num_gen_batches: | ||
| if self.strategy.is_rank_0(): | ||
| self.strategy.print( | ||
| f"Warning: Reached max_num_gen_batches={max_num_gen_batches} " | ||
| f"with only {num_valid_prompts} valid prompts. Proceeding with available data." | ||
| ) | ||
| break | ||
|
|
||
| # Need more samples, but current implementation only processes one batch | ||
| # In a full implementation, we would fetch the next batch from dataloader here | ||
| # For now, we proceed with what we have | ||
| if self.strategy.is_rank_0(): | ||
| self.strategy.print( | ||
| f"Warning: Dynamic sampling needs more batches, but current implementation " | ||
| f"processes one batch at a time. Proceeding with {num_valid_prompts} valid prompts." | ||
| ) | ||
| break | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you add
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there are bugs in you previous experiment, the current implementation can not accumulate enough data for training when dynamic_sampling is enabled. |
||
|
|
||
| self.strategy.report_memory('after replay_buffer ready') | ||
|
|
||
|
|
@@ -711,6 +766,14 @@ def training_step_actor(self, | |
| base_action_log_probs = experience.base_action_log_probs | ||
|
|
||
| if advantages is not None: | ||
| # Check if advantages is empty (can happen when dynamic sampling filters all samples) | ||
| if advantages.numel() == 0: | ||
| self.strategy.print( | ||
| "[Warning] Empty advantages tensor detected. This may occur when dynamic sampling " | ||
| "filters out all samples in a batch. Skipping this training step." | ||
| ) | ||
| return {} # Return empty status dict to skip this step | ||
|
|
||
| # Log max advantage before clipping for debugging (optional) | ||
| max_adv = advantages.max().item() | ||
| if max_adv > 10.0: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
raise RuntimeError here, not use fallback