refactor: simplify forward() permutation logic for compile-friendly execution#344
Merged
fracape merged 2 commits intoInterDigitalInc:masterfrom Oct 23, 2025
studyingeugene:master
Merged
refactor: simplify forward() permutation logic for compile-friendly execution#344fracape merged 2 commits intoInterDigitalInc:masterfrom studyingeugene:master
fracape merged 2 commits intoInterDigitalInc:masterfrom
studyingeugene:master
Conversation
…xecution What's changed - Replace tensor-based perm construction with list-based version - Add explicit inverse permutation for correctness - Remove TorchScript-specific branches Why - Compile-friendly: torch.compile/AOTAutograd prefer static Python control flow and index lists over device tensor construction inside forward. Replacing torch.tensor([...]), torch.arange(...), and torch.cat(...) with plain Python lists reduces graph breaks and guard complexity, improving compilation stability and cache reuse.
Fix lint errors in entropy_models.py
fracape
approved these changes
Oct 23, 2025
Contributor
Author
|
@fracape Thank you for reviewing and accepting my pull request! |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
I treated this as a small bug fix rather than a feature addition, so I submitted a PR directly without an issue.
Apologies if that’s against the usual workflow — I’ll be glad to open an issue if preferred.
What's changed
torch.tensor,torch.arange,torch.cat) with pure Python list versioninv_perm[p] = i)is_scripting()guard used for perm/perm_inv construction)Why
The previous implementation created small tensors on device each forward call, e.g.:
Old one causes:
torch.compile()due to dynamic tensor creationChanged one improves:
torch.compileEvaluation
Please see the attached test script:
test_script.zip
The script:
mbt2018_meanmodels (model_old,model_new)EntropyBottleneck.forward()ofmodel_newtorch.allcloseand compilation checkstorch._dynamo.explain()Results:
Below are excerpts from the output logs:
These results confirm that the refactored version compiles cleanly and runs efficiently without any graph breaks.
These result confirms that no functional difference exists between the original and refactored implementations.
Addition
Since only the forward() method was modified, all existing parameters and buffers remain valid and can be reused without any reinitialization.
Thanks for reading
I appreciate your time reviewing this change.