[fsdp, algo] no grad for entropy and kl if the loss coef is 0 #6519
huaiyizhao wants to merge 2 commits into
Code Review
This pull request optimizes gradient computation during training. In transformer_impl.py, it dynamically determines if entropy calculations require gradients based on the entropy_coeff value, running under torch.no_grad() when possible. In losses.py, it optimizes PPO loss when kl_loss_coef is 0.0 by detaching log_prob and computing the KL penalty inside a torch.no_grad() block. The reviewer suggested making the retrieval of entropy_coeff more robust by supporting both object attributes and dictionary lookups for loss_config to prevent unnecessary gradient tracking.
    entropy_needs_grad = calculate_entropy
    if calculate_entropy and logits_processor_func is not None:
        loss_config = getattr(logits_processor_func, "keywords", {}).get("config")
        entropy_coeff = getattr(loss_config, "entropy_coeff", None)
        if entropy_coeff is not None:
            try:
                entropy_needs_grad = float(entropy_coeff) != 0.0
            except (TypeError, ValueError):
                entropy_needs_grad = True
The current implementation uses getattr(loss_config, "entropy_coeff", None) to retrieve the entropy coefficient. However, if loss_config is a standard Python dictionary (which is common in some testing environments or custom pipelines), getattr will return None instead of retrieving the key. This will silently cause entropy_needs_grad to default to True, unnecessarily extending the backward graph and consuming extra memory.
To make this robust, we should support retrieving entropy_coeff from both object attributes and dictionary keys.
    entropy_needs_grad = calculate_entropy
    if calculate_entropy and logits_processor_func is not None:
        loss_config = getattr(logits_processor_func, "keywords", {}).get("config")
        entropy_coeff = None
        if loss_config is not None:
            if isinstance(loss_config, dict):
                entropy_coeff = loss_config.get("entropy_coeff")
            else:
                entropy_coeff = getattr(loss_config, "entropy_coeff", None)
                if entropy_coeff is None and hasattr(loss_config, "get"):
                    entropy_coeff = loss_config.get("entropy_coeff", None)
        if entropy_coeff is not None:
            try:
                entropy_needs_grad = float(entropy_coeff) != 0.0
            except (TypeError, ValueError):
                entropy_needs_grad = True
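The lookup discussed in this comment can also be isolated in a small helper and checked without a trainer. The following is a minimal sketch, not the PR's code: `read_entropy_coeff`, `entropy_needs_grad`, and `dummy_loss` are hypothetical names, and `SimpleNamespace` stands in for whatever attribute-style config object is actually passed.

```python
from functools import partial
from types import SimpleNamespace


def read_entropy_coeff(loss_config):
    """Return entropy_coeff from a dict or attribute-style config, else None."""
    if loss_config is None:
        return None
    if isinstance(loss_config, dict):
        return loss_config.get("entropy_coeff")
    # Attribute access covers dataclasses and namespace-style configs.
    return getattr(loss_config, "entropy_coeff", None)


def entropy_needs_grad(calculate_entropy, logits_processor_func):
    """Decide whether the entropy computation must stay on the autograd graph."""
    if not calculate_entropy:
        return False
    # functools.partial exposes its bound keyword arguments as `.keywords`.
    keywords = getattr(logits_processor_func, "keywords", None) or {}
    coeff = read_entropy_coeff(keywords.get("config"))
    if coeff is None:
        return True  # unknown coefficient: keep gradients to be safe
    try:
        return float(coeff) != 0.0
    except (TypeError, ValueError):
        return True  # unparsable coefficient: keep gradients to be safe


def dummy_loss(logits, config=None):
    """Hypothetical stand-in for the real loss function."""
    return logits
```

With this shape, `partial(dummy_loss, config={"entropy_coeff": 0.0})` and `partial(dummy_loss, config=SimpleNamespace(entropy_coeff=0.0))` are treated the same way, which is the behavior the comment asks for.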
    policy_loss += kl_loss * config.kl_loss_coef
    kl_loss_coeff = float(config.kl_loss_coef)
    if kl_loss_coeff == 0.0:
Unify the computation with nullcontext() as well.
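One way to read this suggestion: choose the context manager first, then run a single copy of the computation under it. The sketch below shows the pattern on an entropy calculation. `entropy_from_logits` and `compute_entropy` are hypothetical names for illustration and are not taken from the PR.

```python
from contextlib import nullcontext

import torch


def entropy_from_logits(logits):
    """Entropy of the categorical distribution defined by the last dimension."""
    log_probs = torch.log_softmax(logits, dim=-1)
    return -(log_probs.exp() * log_probs).sum(dim=-1)


def compute_entropy(logits, needs_grad):
    # One code path: pick the context manager, then run the same computation.
    grad_ctx = nullcontext() if needs_grad else torch.no_grad()
    with grad_ctx:
        return entropy_from_logits(logits)


logits = torch.randn(2, 5, requires_grad=True)
tracked = compute_entropy(logits, needs_grad=True)       # part of the backward graph
logged_only = compute_entropy(logits, needs_grad=False)  # metric only, no graph
```

The values are identical in both modes. Only whether autograd records the operations differs, so the branch on the coefficient does not need to duplicate the loss code.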
What does this PR do?
This PR prevents logging-only entropy and KL metrics from unnecessarily extending the actor backward graph.
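As a rough illustration of the KL half of this change, the sketch below keeps the KL term off the autograd graph when its coefficient is zero. It rests on assumptions: `add_kl_term` is a hypothetical helper, the mean log-probability difference is a simplified stand-in for the actual KL penalty, and the placeholder `policy_loss` is not the real PPO objective.

```python
import torch


def add_kl_term(policy_loss, log_prob, ref_log_prob, kl_loss_coef):
    """Return (loss, kl_metric); KL joins the graph only for a nonzero coefficient."""
    coef = float(kl_loss_coef)
    if coef == 0.0:
        # Logging only: compute the metric without recording autograd history.
        with torch.no_grad():
            kl = (log_prob.detach() - ref_log_prob).mean()
        return policy_loss, kl
    kl = (log_prob - ref_log_prob).mean()
    return policy_loss + coef * kl, kl


log_prob = torch.randn(4, requires_grad=True)
ref_log_prob = torch.randn(4)
policy_loss = (log_prob * 2.0).sum()  # placeholder for the real objective

loss_off, kl_off = add_kl_term(policy_loss, log_prob, ref_log_prob, 0.0)
loss_on, kl_on = add_kl_term(policy_loss, log_prob, ref_log_prob, 0.1)
```

With a zero coefficient the returned loss is the original `policy_loss` object and `kl_off` carries no gradient history, while the reported metric value matches the tracked one.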
Checklist Before Starting
- [{modules}] {type}: {description} (This will be checked by the CI)
- {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off, like [megatron, fsdp, doc]
- {type} is in feat, fix, refactor, chore, test
- [BREAKING] to the beginning of the title.
- [BREAKING][fsdp, megatron] feat: dynamic batching
Test
API and Usage Example
# Add code snippet or script demonstrating how to use this
Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always
- ci-request channel in the verl Slack workspace. (If not accessible, please try the Feishu group (飞书群).)
- recipe submodule, please also update the reference to the submodule commit via git submodule update --remote or cd recipe && git pull origin main.