ai code review and fix #1480

Open
pstjohn wants to merge 3 commits into NVIDIA:main from pstjohn:pstjohn/claude-review

Conversation

@pstjohn
Collaborator

@pstjohn pstjohn commented Feb 21, 2026

Miscellaneous semi-automated fixes that add documentation and additional tests throughout the recipes directories.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added Mixtral model support with TransformerEngine optimizations and comprehensive documentation
    • Introduced context-parallel tensor processing helpers across model implementations
    • Added token dropout support for different input formats (BSHD and THD)
    • New test utilities for checkpoint pruning and scheduler validation
  • Refactoring

    • Updated state transformation API across convert modules for consistency
    • Simplified checkpoint and dataset utilities
    • Refactored collators to use modular helper functions
  • Documentation

    • Expanded module docstrings describing state dict transformation system
    • Added documentation for attention input formats across models
    • Updated README references and configuration documentation
  • Tests

    • Added Mixtral export validation tests
    • New checkpoint pruning and scheduler tests with comprehensive edge-case coverage
    • Updated test base classes for consistency across model implementations

@copy-pr-bot

copy-pr-bot bot commented Feb 21, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai bot commented Feb 21, 2026

Important

Review skipped

Auto reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.


Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
@pstjohn pstjohn force-pushed the pstjohn/claude-review branch from 14b1640 to 773d581 Compare February 24, 2026 22:09
@pstjohn pstjohn marked this pull request as ready for review February 24, 2026 22:10
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
bionemo-recipes/recipes/esm2_native_te/checkpoint.py (1)

590-590: ⚠️ Potential issue | 🟠 Major

torch.load called without weights_only=True — arbitrary code execution risk.

Loading a pickle-backed .pt file without weights_only=True allows a crafted or corrupted checkpoint to execute arbitrary Python during deserialization. The llama3_native_te/checkpoint.py equivalent (line 444) already passes weights_only=True; this file should match it.

🛡️ Proposed fix
-    dataloader_state = torch.load(dataloader_path)
+    dataloader_state = torch.load(dataloader_path, weights_only=True)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/checkpoint.py` at line 590, The call
dataloader_state = torch.load(dataloader_path) in checkpoint.py is unsafe;
update the torch.load invocation to pass weights_only=True (matching the
llama3_native_te/checkpoint.py usage) so it deserializes only tensor data and
avoids executing arbitrary pickle code—locate the dataloader_state assignment in
the file and add the weights_only=True argument to torch.load.
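The deserialization risk behind this finding can be demonstrated with the standard library alone. The sketch below uses plain `pickle` as a stand-in for `torch.load` (which runs the same pickle machinery on `.pt` files when `weights_only` is not set); the payload here is a harmless `eval`, but a real attack could return any callable, such as `os.system`.

```python
import pickle


class Exploit:
    """A crafted object whose deserialization runs attacker-chosen code."""

    def __reduce__(self):
        # pickle will call eval("6 * 7") when this payload is loaded.
        # A malicious checkpoint could return (os.system, ("...",)) instead.
        return (eval, ("6 * 7",))


payload = pickle.dumps(Exploit())  # what a crafted checkpoint file contains
result = pickle.loads(payload)     # code executes here, during loading
print(result)                      # 42 -- proof the expression ran on load
```

Passing `weights_only=True` avoids this by restricting the unpickler to tensor and container types instead of arbitrary callables.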
♻️ Duplicate comments (3)
bionemo-recipes/models/llama3/tests/common/__init__.py (1)

21-21: Same docstring formatting issue as in the ESM2 __init__.py.

Two bullet items collapsed onto one line. See the fix proposed in the ESM2 review.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/llama3/tests/common/__init__.py` at line 21, The
module docstring in bionemo-recipes/models/llama3/tests/common/__init__.py has
two bullet items merged onto one line; update the top-level docstring so each
bullet is on its own line (separate the entries for BaseModelTest and
TestTolerances into distinct list items) and ensure the docstring follows the
same multiline bullet formatting used in the ESM2 __init__.py example.
bionemo-recipes/models/llama3/collator.py (1)

733-868: Duplicate of helpers already reviewed in models/esm2/collator.py.

These are identical implementations. See the code duplication comment on the ESM2 collator review — consider a shared utility module.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/llama3/collator.py` around lines 733 - 868, The
functions _find_seq_dim, _process_tensor_thd, and _process_tensor_bshd are
duplicates of helpers in models/esm2/collator.py; refactor by extracting these
helpers into a shared utility module (e.g., a new module like
models/shared/collators.py or similar) and replace the local definitions with
imports and usage of the shared functions; update the current file to import
_find_seq_dim, _process_tensor_thd, and _process_tensor_bshd from that shared
module and remove the duplicate definitions here, ensuring any referenced
symbols (seq_len, slice_sizes, cu_seqlens_padded, cp_rank, total_slices,
cp_world_size) match the shared helper signatures.
bionemo-recipes/models/mixtral/tests/common/__init__.py (1)

21-21: Same docstring formatting issue as in the ESM2 and Llama3 __init__.py files.

Two bullet items collapsed onto one line. See the fix proposed in the ESM2 review.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/common/__init__.py` at line 21, The
module docstring has two bullet items collapsed onto one line; update the
module-level docstring in __init__.py so each item is on its own line (use a
newline and proper bullet prefix for "BaseModelTest: Base test class with all
common test methods" and "TestTolerances: Dataclass for model-specific numerical
tolerances")—locate the docstring near the top of the file and adjust the
formatting to match the fixed style used in the ESM2/Llama3 __init__.py files.
🧹 Nitpick comments (8)
bionemo-recipes/models/mixtral/README.md (2)

11-19: Table cell padding is not mdformat-compliant.

mdformat normalises table cells to use minimal spacing (single space padding), but the current table uses wide right-padding to visually align columns. Running mdformat will reformat these cells, producing diff noise in future PRs.

Run mdformat bionemo-recipes/models/mixtral/README.md to normalise formatting. As per coding guidelines: "Use mdformat for Markdown formatting."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/README.md` around lines 11 - 19, The Markdown
table in README.md uses wide cell padding instead of mdformat's minimal
single-space padding; run mdformat on the file (e.g., mdformat
bionemo-recipes/models/mixtral/README.md) or manually reduce each table cell to
single-space padding so the table rows (the lines containing "Feature | Support"
and the subsequent pipe-separated rows like "**FP8** | ✅ Supported...") conform
to mdformat normalization and avoid future diff noise.

87-101: export.py is not mentioned in the Developer Guide.

Per project requirements, each model recipe must ship an export.py for Hugging Face Hub bundling. The README does not reference it, leaving users without guidance on how to package and publish the TE model. Consider adding a brief "Exporting to Hugging Face Hub" subsection that calls out export.py and its usage.

Based on learnings: "Models in bionemo-recipes/models/ must include: … export script (export.py) for Hugging Face Hub bundling."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/README.md` around lines 87 - 101, Add a new
"Exporting to Hugging Face Hub" subsection to the Developer Guide that documents
the required export.py script: state that the model directory must include
export.py, explain how to run it to create the HF bundle (e.g., run export.py
from the model directory or via python export.py with any required args), note
any dependencies or env vars needed for HF upload, and link to
recipes_local_test.py as the local test step before publishing; reference the
filename export.py and the test runner recipes_local_test.py so maintainers can
locate and update the script.
bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py (1)

101-108: Consider adding reduce_dtype=torch.float32 to the else branch for gradient stability.

The FSDP2 MixedPrecisionPolicy signature is MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True), confirming output_dtype is a valid field — the if branch is correct.

For the else branch (pure BF16, use_fp32_master_weights=False), MixedPrecisionPolicy() defaults leave reduce_dtype=None, meaning gradient all-reduces also happen in BF16. Gradients can vary significantly from rank to rank, and reducing in float32 can be critical for numerics. The if branch already does this correctly with reduce_dtype=torch.float32; omitting it in the else branch may cause training instability, especially for larger models.

💡 Suggested improvement for the `else` branch
     else:
-        mp_policy = MixedPrecisionPolicy()
+        mp_policy = MixedPrecisionPolicy(reduce_dtype=torch.float32)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py` around lines 101 -
108, The else branch leaves MixedPrecisionPolicy.reduce_dtype as None so
gradient all-reduces happen in BF16, which can destabilize training; update the
else branch that sets mp_policy (when args.use_fp32_master_weights is False) to
instantiate MixedPrecisionPolicy with reduce_dtype=torch.float32 (e.g.,
MixedPrecisionPolicy(reduce_dtype=torch.float32, param_dtype=torch.bfloat16,
output_dtype=torch.bfloat16 or just set reduce_dtype alongside the default
call)) so gradient reductions occur in FP32 while keeping the rest of the BF16
policy.
bionemo-recipes/models/mixtral/convert.py (1)

65-74: Add Args and Returns sections to complete the Google-style docstring.

The updated body text is well-written, but the docstring omits the Args and Returns sections that are present in every other function in this file, leaving num_experts and the return type undocumented.

📝 Proposed fix
 def _make_merge_experts_fn(num_experts: int):
     """Create a merge function with the correct number of named parameters.
 
     The state.py transform system maps function parameter names to source dict keys by inspecting
     the function signature. When ``source_key`` is a tuple, it pairs each tuple element with the
     corresponding named parameter via ``{param: source_key[i]}``. This means ``*args`` style
     parameters do not work -- the system cannot map positional varargs to specific source keys.
 
     Since the number of experts is dynamic (varies per model config), we use ``exec()`` to generate
     a function with exactly ``num_experts`` named parameters (weight0, weight1, ..., weightN-1).
+
+    Args:
+        num_experts: Number of experts; determines the count of named parameters in the generated function.
+
+    Returns:
+        A callable ``merge_experts(weight0, weight1, ..., weightN-1)`` that stacks its inputs along a new
+        leading dimension using ``torch.stack``.
     """

As per coding guidelines, "Ensure all Python files follow Google-style docstrings (pydocstyle convention)."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/convert.py` around lines 65 - 74, The
docstring for the function create_merge_fn (the factory that generates a merge
function with num_experts named parameters) is missing Google-style "Args" and
"Returns" sections; update the docstring to add an Args section documenting
num_experts (type: int, meaning the number of expert-weight parameters to
generate) and any other parameters, and add a Returns section describing the
returned callable (e.g., a function taking weight0..weightN-1 and returning the
merged result, include its type/signature). Keep wording consistent with other
functions in the file and follow the existing Google-style formatting used
elsewhere.
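The `exec()` technique the docstring describes can be sketched independently of torch. This illustrative version (factory name and merge behavior are stand-ins; the real function stacks with `torch.stack`) generates a function whose signature has exactly `num_experts` named parameters, which is what a signature-inspecting transform system needs:

```python
import inspect


def _make_merge_fn(num_experts: int):
    """Generate a merge function with num_experts named parameters.

    *args would not work here: a system that maps source keys to parameter
    names by inspecting the signature needs real named parameters.
    """
    params = ", ".join(f"weight{i}" for i in range(num_experts))
    src = f"def merge_experts({params}):\n    return ({params},)"
    namespace = {}
    exec(src, namespace)  # defines merge_experts with N named parameters
    return namespace["merge_experts"]


merge3 = _make_merge_fn(3)
print(list(inspect.signature(merge3).parameters))  # named params, not *args
print(merge3(1, 2, 3))                             # tuple stands in for torch.stack
```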
bionemo-recipes/recipes/esm2_native_te/tests/test_checkpoint_pruning.py (1)

89-103: Good coverage — consider adding a negative save_every_n_steps edge-case assertion.

The current suite covers the documented contract thoroughly. One untested, albeit well-defined, edge case is save_every_n_steps < 0: the existing guard save_every_n_steps > 0 makes it return False, but an explicit assertion would document the intended behaviour and guard against future regressions if the guard is ever changed.

✏️ Suggested addition
     # save_every_n_steps=0 should never save
     assert should_save_checkpoint(step=10, save_every_n_steps=0) is False
+
+    # Negative save_every_n_steps should never save
+    assert should_save_checkpoint(step=10, save_every_n_steps=-1) is False
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/recipes/esm2_native_te/tests/test_checkpoint_pruning.py`
around lines 89 - 103, Add an explicit test assertion for the edge case
save_every_n_steps < 0 in the test_should_save_checkpoint function: verify that
should_save_checkpoint(step=10, save_every_n_steps=-1) returns False to document
and lock in the intended behavior; update the same test (or add a new one
nearby) referencing the should_save_checkpoint function so future changes to the
guard save_every_n_steps > 0 will be caught by CI.
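A minimal sketch of the guard under test, assuming the contract described above (this is a hypothetical re-implementation, not the recipe's actual code): any non-positive interval disables saving, which makes the negative edge case explicit.

```python
def should_save_checkpoint(step: int, save_every_n_steps: int) -> bool:
    # Guard: a non-positive interval (0 or negative) never saves; otherwise
    # save whenever the step lands on the interval boundary.
    return save_every_n_steps > 0 and step % save_every_n_steps == 0


print(should_save_checkpoint(step=10, save_every_n_steps=5))   # 10 % 5 == 0
print(should_save_checkpoint(step=10, save_every_n_steps=0))   # interval 0 never saves
print(should_save_checkpoint(step=10, save_every_n_steps=-1))  # negative never saves
```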
bionemo-recipes/models/esm2/collator.py (2)

845-849: Minor: _process_tensor_bshd divisibility check is incomplete but safe in practice.

The error message says the sequence length "must be divisible by" total_chunks, but the check (chunk_size == 0) only catches when seq_len < total_chunks. If seq_len % total_chunks != 0, the remainder is silently dropped. This is safe because the upstream pad_thd_sequences_for_cp guarantees divisibility, but the error message is misleading. Consider adding a strict check:

🔧 Optional stricter validation
-    if chunk_size == 0:
+    if seq_len % total_chunks != 0:
         raise ValueError(
             f"Sequence length {seq_len} must be divisible by {total_chunks} "
             f"(2 * cp_world_size) for BSHD context parallelism"
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/esm2/collator.py` around lines 845 - 849, In
_process_tensor_bshd, replace the incomplete divisibility check that only tests
chunk_size == 0 with a strict modulus check (if seq_len % total_chunks != 0) and
raise a ValueError including seq_len and total_chunks in the message; reference
the relationship to pad_thd_sequences_for_cp in the message or a comment to
clarify why this should normally not trigger. This ensures remainder cases are
caught instead of silently dropping tokens and makes the error message accurate.

733-868: Significant code duplication across three model collators: extract shared helpers to reduce maintenance burden.

_find_seq_dim, _process_tensor_thd, and _process_tensor_bshd are duplicated verbatim in bionemo-recipes/models/esm2/collator.py, bionemo-recipes/models/llama3/collator.py, and bionemo-recipes/models/mixtral/collator.py (~113 lines total). While recipe duplication is justified by the self-containment guideline, the duplication across three model collators violates DRY principles. Consider extracting these helpers into a shared utility module under bionemo-recipes/models/ (e.g., bionemo-recipes/models/common/cp_utils.py) and importing from there in all three model collators.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/esm2/collator.py` around lines 733 - 868, Extract the
duplicated helpers _find_seq_dim, _process_tensor_thd, and _process_tensor_bshd
into a single shared utility module (e.g., cp_utils) and replace the verbatim
copies in each collator with imports from that module; specifically, move the
three functions as-is into the new module (preserving signatures and torch
usage), update the collator files to import _find_seq_dim, _process_tensor_thd,
and _process_tensor_bshd, and ensure any device/typing references still resolve
(add necessary imports like torch and typing in the new module) so behavior and
exceptions remain unchanged.
bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py (1)

241-290: Consider reusing _create_inference_params to avoid duplication.

The TE beam-search test re-implements the same KV-cache setup; reusing the helper keeps it consistent.

♻️ Suggested refactor
-        past_key_values = HFInferenceParams(
-            max_batch_size=2 * num_beams,
-            max_sequence_length=256,
-            num_heads_kv=config.num_key_value_heads,
-            head_dim_k=config.hidden_size // config.num_attention_heads,
-            dtype=torch.bfloat16,
-            qkv_format="thd",
-            max_ctx_len=256,
-        )
-        for layer_number in range(1, config.num_hidden_layers + 1):
-            past_key_values.allocate_memory(layer_number)
+        past_key_values = self._create_inference_params(
+            config,
+            batch_size=2,
+            max_seq_len=256,
+            num_beams=num_beams,
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py` around lines
241 - 290, The test test_te_mixtral_model_generate_with_cache_beam_search
duplicates KV-cache setup; replace the manual HFInferenceParams construction and
loop with a call to the existing helper _create_inference_params (or whatever
public helper is present) to build and allocate past_key_values for the model
config, then use that returned past_key_values in the generate() call; ensure
you pass the same args (dtype, qkv_format, max_ctx_len, max_batch_size, etc.)
into _create_inference_params so behavior remains identical and remove the
manual for-loop that calls past_key_values.allocate_memory.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@bionemo-recipes/models/esm2/tests/common/__init__.py`:
- Line 21: The module docstring currently has the two bullet points for
BaseModelTest and TestTolerances collapsed onto a single line; update the
top-level docstring in this tests/common __init__.py so each item is its own
bullet on its own line (e.g., "- BaseModelTest: Base test class..." newline "-
TestTolerances: Dataclass for model-specific numerical tolerances"), and apply
the same fix to the identical docstrings in the llama3 and mixtral __init__.py
files to ensure consistent rendering.

In `@bionemo-recipes/models/mixtral/collator.py`:
- Around line 811-868: The function _process_tensor_bshd currently uses floor
division for chunk_size which silently drops tail tokens; add an explicit
divisibility guard after computing total_chunks: if seq_len % total_chunks != 0
raise a ValueError (with a clear message referencing seq_len and total_chunks)
so we fail fast rather than truncating data; keep the existing check for
chunk_size==0 but add this new check before computing chunk indices and slicing.

In `@bionemo-recipes/models/mixtral/README.md`:
- Around line 46-48: The README example disables KV-cache by passing
use_cache=False to model_te.generate, contradicting the "KV-cache inference" ✅
claim; update the quick-start snippet to exercise KV-cache by removing the
use_cache override or setting use_cache=True when calling model_te.generate
(inside the with torch.no_grad() block) so the example actually uses the model's
KV-cache inference path and matches the Feature Support table.
- Around line 81-85: The "Validating Converted Models" section is circular and
lacks runnable comparison commands; update it in the README by either (a)
including a concrete, copy-pastable snippet under the "Inference Examples" /
"Validating Converted Models" headings that shows how to load the baseline
Hugging Face model and the converted model and compute/compare logits and loss
(e.g., commands or Python call sequence to run inference for both models and
diff their outputs), or (b) add a clear link and brief instruction pointing
readers to the golden-value test test_modeling_mixtral.py explaining exactly
which test function/assertion to run and how to interpret its outputs; reference
the "Inference Examples" section and the test file name test_modeling_mixtral.py
so readers can locate the code to run.
- Around line 25-50: The README code examples use bare imports like "from
convert import convert_mixtral_hf_to_te" and "from modeling_mixtral_te import
..." which will raise ModuleNotFoundError unless the current working directory
is bionemo-recipes/models/mixtral; update the snippets to include a one-line
preamble instructing users to either run the snippet from that directory (e.g.,
"cd bionemo-recipes/models/mixtral") or to set up the PYTHONPATH/sys.path or
install the package via the documented workflow (pip install -r
requirements.txt) before running, and add that same short note to every block
that uses convert or modeling_mixtral_te so users won't hit silent import
failures.

In `@bionemo-recipes/recipes/esm2_native_te/collator.py`:
- Around line 811-868: The function _process_tensor_bshd currently can drop tail
tokens when seq_len is not divisible by (2 * cp_world_size); add an explicit
divisibility check after computing total_chunks (or chunk_size) and if seq_len %
total_chunks != 0 raise a ValueError with a clear message (e.g. "Sequence length
{seq_len} must be divisible by {total_chunks} (2 * cp_world_size) for BSHD
context parallelism") so the function fails fast instead of silently dropping
tokens.

In `@bionemo-recipes/recipes/esm2_native_te/tests/test_scheduler.py`:
- Line 21: The test imports the recipe-local scheduler module using an absolute
import; replace the top-level import in test_scheduler.py so it explicitly
references the local scheduler module by using a relative import (e.g. import
get_linear_schedule_with_warmup from ..scheduler) so that
get_linear_schedule_with_warmup is resolved from the recipe's scheduler module
rather than relying on conftest.py sys.path manipulation.


ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6a60786 and 773d581.

📒 Files selected for processing (35)
  • bionemo-recipes/models/amplify/src/amplify/state.py
  • bionemo-recipes/models/esm2/README.md
  • bionemo-recipes/models/esm2/collator.py
  • bionemo-recipes/models/esm2/convert.py
  • bionemo-recipes/models/esm2/modeling_esm_te.py
  • bionemo-recipes/models/esm2/state.py
  • bionemo-recipes/models/esm2/tests/common/__init__.py
  • bionemo-recipes/models/llama3/collator.py
  • bionemo-recipes/models/llama3/convert.py
  • bionemo-recipes/models/llama3/modeling_llama_te.py
  • bionemo-recipes/models/llama3/state.py
  • bionemo-recipes/models/llama3/tests/common/__init__.py
  • bionemo-recipes/models/mixtral/README.md
  • bionemo-recipes/models/mixtral/collator.py
  • bionemo-recipes/models/mixtral/convert.py
  • bionemo-recipes/models/mixtral/modeling_mixtral_te.py
  • bionemo-recipes/models/mixtral/state.py
  • bionemo-recipes/models/mixtral/tests/common/__init__.py
  • bionemo-recipes/models/mixtral/tests/test_export.py
  • bionemo-recipes/models/mixtral/tests/test_modeling_mixtral.py
  • bionemo-recipes/recipes/esm2_accelerate_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/README.md
  • bionemo-recipes/recipes/esm2_native_te/checkpoint.py
  • bionemo-recipes/recipes/esm2_native_te/collator.py
  • bionemo-recipes/recipes/esm2_native_te/dataset.py
  • bionemo-recipes/recipes/esm2_native_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/esm2_native_te/tests/test_checkpoint_pruning.py
  • bionemo-recipes/recipes/esm2_native_te/tests/test_scheduler.py
  • bionemo-recipes/recipes/esm2_native_te/train_fsdp2.py
  • bionemo-recipes/recipes/esm2_peft_te/collator.py
  • bionemo-recipes/recipes/esm2_peft_te/example_8m_checkpoint/esm_nv.py
  • bionemo-recipes/recipes/llama3_native_te/checkpoint.py
  • bionemo-recipes/recipes/llama3_native_te/collator.py
  • bionemo-recipes/recipes/llama3_native_te/dataset.py
  • bionemo-recipes/recipes/llama3_native_te/modeling_llama_te.py

Collaborator @trvachov left a comment

One nit and one real comment about future security alerts.

export_hf_checkpoint("NeuralNovel/Mini-Mixtral-v0.2", tmp_path / "checkpoint_export")

_ = AutoTokenizer.from_pretrained(tmp_path / "checkpoint_export")
model = AutoModelForCausalLM.from_pretrained(tmp_path / "checkpoint_export", trust_remote_code=True)
Collaborator
Is this "trust_remote_code" going to trigger a security notice?

Collaborator
@coderabbitai what do you think?

Contributor

🧩 Analysis chain

🌐 Web query:

trust_remote_code security implications HuggingFace 2025

💡 Result:

trust_remote_code=True in Hugging Face (e.g., transformers / diffusers) tells the library it may download and execute Python files from the model repo (custom modeling.py, pipeline code, etc.). That means loading a model can become equivalent to running unvetted code from the internet, with the same impact as any supply-chain compromise: arbitrary code execution, credential/token exfiltration, data tampering, persistence, lateral movement—whatever the current process permissions allow. [1], [5]

Key security implications

  • Arbitrary code execution (RCE) by design: if the repo (or any dependency it pulls) is malicious/compromised, your environment is compromised. [1]
  • Update/maintainer compromise risk: even if the repo is safe today, a later update (or account takeover) can change executed code unless you pin versions. [1]
  • Model-hub supply-chain vectors (not limited to trust_remote_code): attackers also exploit ecosystem mechanics like name squatting / impersonation and namespace reuse/hijacking, tricking users/CI into downloading attacker-controlled artifacts. This increases the practical risk of enabling remote code. [3], [4]
  • Real-world vulnerability backdrop (2025): multiple 2025-era advisories/CVEs show how executing “untrusted code” around HF tooling/agents can become RCE when sandboxes or loaders are bypassed. (Different component than trust_remote_code, but same core risk pattern.) [2]

Practical mitigations recommended by HF/security guidance

  • Avoid trust_remote_code unless you must; prefer architectures supported natively by the library. [1]
  • If you must use it:
    • Inspect the repo’s Python files before running. [1]
    • Pin an immutable revision (commit hash/tag) to prevent silent code changes. [1]
    • Prefer safetensors and set use_safetensors to avoid unsafe deserialization formats (separate but commonly adjacent risk). [1]
  • Note that some platforms now block models requiring trust_remote_code unless verified/trusted (example: HF collection on Microsoft Foundry/Azure ML). [6]

Sources:
[1] Transformers security policy (“remote artefacts” + “remote code” + pin revision)
[2] NVD CVE-2025-5120 (smolagents sandbox escape to RCE)
[3] SC Media / Unit 42 report on Hugging Face namespace reuse risk (2025)
[4] Protect AI research on name squatting/impersonation on Hugging Face
[5] Diffusers discussion raising trust_remote_code=True arbitrary-code concerns (2025)
[6] Hugging Face docs for Microsoft Foundry/Azure ML security: trust_remote_code disallowed unless verified/trusted


@trvachov Yes, trust_remote_code=True will very likely trigger security notices from GitHub's security scanning, dependency checkers, and other security tools.

Why it's flagged

According to current security guidance, trust_remote_code=True allows the library to "download and execute Python files from the model repo," which means loading a model can execute arbitrary code from the internet. This creates supply-chain risks including code execution, credential exfiltration, and other compromises.

Recommended mitigations for this test

Since this is a test file, here are practical steps to reduce risk while keeping functionality:

  1. Pin a specific revision (commit hash) to prevent silent code changes:
model = AutoModelForCausalLM.from_pretrained(
    tmp_path / "checkpoint_export",
    trust_remote_code=True,
    revision="<specific-commit-hash>"  # Pin immutable version
)
  2. Check if trust_remote_code is actually necessary - Standard Mixtral architectures should work without it. You can test by removing the parameter to see if the model loads.

  3. Add a comment explaining why it's needed (if it is):

# trust_remote_code required for custom Mixtral implementation
  4. Consider using safetensors format (separate but related risk mitigation).

Since the test is already skipped in CI, the immediate security exposure is limited to local development environments. However, addressing this proactively will prevent security alerts and reduce risk if the test is ever run in CI or other environments.


🧠 Learnings used
Learnt from: CR
Repo: NVIDIA/bionemo-framework PR: 0
File: CLAUDE.md:0-0
Timestamp: 2026-01-22T19:29:40.552Z
Learning: Applies to bionemo-recipes/models/**/*.py : Models in bionemo-recipes/models/ must include: golden value tests proving TE model matches reference model, bidirectional conversion functions (convert_hf_to_te() and convert_te_to_hf()), export script (export.py) for Hugging Face Hub bundling, and open-source license

Learnt from: ohadmo
Repo: NVIDIA/bionemo-framework PR: 1139
File: bionemo-recipes/models/geneformer/tests/test_checkpoints_modeling_bert.py:306-353
Timestamp: 2025-11-10T23:19:53.726Z
Learning: In `models/geneformer/tests/test_checkpoints_modeling_bert.py`, the test code intentionally accesses the `.transform` attribute of `_unpack_qkv_weight` and `_unpack_qkv_bias` functions to bypass the `io.state_transform` decorator. This allows the test to use the underlying tensor transformation functions directly without requiring a full NeMo `TransformCTX` with complete source-target state dictionaries.

Collaborator Author @pstjohn Feb 24, 2026
silly coderabbit, trix are for kids

anyways, we don't have a revision here because we're loading from a local filesystem. We have to say "trust_remote_code" because we're testing whether the code we've dumped into that folder is executable, even though it's actually 'local' code in this case.

The alternative is to use the bare NVMixtralFor... classes, but that doesn't give us the same coverage that our "auto_map" dictionary is working correctly

Contributor
Seems like the humans are having a chat. I'll hop back into my burrow for now. If you need me again, just tag @coderabbitai in a new comment, and I'll come hopping out!

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
Collaborator Author

pstjohn commented Feb 25, 2026

@coderabbiai resolve

if val is None:
    return val

seq_dim = _find_seq_dim(val, seq_len)
Collaborator

Did it just make this function simply because there was too much code? The base function that this came from was from TE.

Collaborator

https://github.com/coderabbitai what do you think?

Collaborator Author

I think you need the "@"

@coderabbitai what do you think

Contributor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Check how many times _find_seq_dim is used in the file
rg -n "_find_seq_dim" bionemo-recipes/models/esm2/collator.py

# Look for references to the TransformerEngine PR
rg -n "TransformerEngine" bionemo-recipes/models/esm2/collator.py

Length of output: 274


🌐 Web query:

TransformerEngine PR 2387 context parallel _find_seq_dim

💡 Result:

PR #2387 in NVIDIA/TransformerEngine is titled “Enables specified cp rank slicing” and targets PyTorch context parallelism (the THD/variable-length path) by letting callers explicitly choose which context-parallel rank’s shard to extract from a full (unsliced) batch tensor. The PR description gives an example where, for cp_size=2, the sequence is split into two discontiguous shards per rank for load-balancing. [1]

Concretely, it updates get_batch_on_this_cp_rank(...) (in transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py) to add an optional cp_rank argument so you can “slice the batch for the provided rank” rather than always using the current process’s rank. [1]

Review feedback in the PR notes a couple of issues in the proposed validation logic (notably checking 0 <= cp_rank < cp_size before checking cp_rank is None, which can raise a TypeError when cp_rank=None), and flags backward-compatibility concerns around how cp_group/rank discovery is handled. [2]

On _find_seq_dim: I don’t see a symbol/function named _find_seq_dim in PR #2387’s diff or in the current context_parallel.py on main, so it’s likely from a different branch/version or another repo component you’re looking at (or a local patch). [2][3]


🏁 Script executed:

#!/bin/bash
# Check if _find_seq_dim or similar logic appears in _process_tensor_bshd
ast-grep --pattern $'def _process_tensor_bshd($$$) {
  $$$
}'

# Look at the full _process_tensor_thd function to see context
sed -n '761,820p' bionemo-recipes/models/esm2/collator.py

Length of output: 2093


@pstjohn Good catch on the mention!

@jomitchellnv You raise a fair point. Looking at the code, _find_seq_dim is indeed only used once (in _process_tensor_thd at line 786), which suggests it might have been extracted primarily for code organization rather than reusability.

That said, the extraction does provide some benefits:

  • Clear error messages: It encapsulates the dimension-finding logic with explicit error handling for different tensor shapes
  • Cleaner separation: THD format needs dynamic dimension detection (could be dim 0 or 1), while BSHD hardcodes dim=1
  • Readability: The function name makes the intent explicit in _process_tensor_thd

Given the TODO comment at line 928-929 indicating this will eventually be replaced with the TransformerEngine implementation once PR #2387 merges, it seems like this was interim code. The extraction might have been done to keep the temporary implementation cleaner and more maintainable until the upstream TE version is ready.

So yes, it's likely a combination of reducing complexity in _process_tensor_thd and making the code easier to understand/maintain temporarily. Whether to inline it or keep it extracted is a judgment call—both approaches are reasonable for a single-use helper.
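To make the dimension-detection idea under discussion concrete, here is a minimal, self-contained sketch. It works on plain shape tuples rather than real tensors, and `find_seq_dim` is a hypothetical stand-in for the recipe's `_find_seq_dim`, not the actual collator code:

```python
def find_seq_dim(shape, seq_len):
    """Return the index of the dimension matching seq_len.

    THD-format tensors may carry the sequence on dim 0 or dim 1, so the
    dimension is located dynamically; BSHD-format code can hardcode dim=1.
    Illustrative stand-in only, with explicit errors for unusual shapes.
    """
    candidates = [i for i, size in enumerate(shape) if size == seq_len]
    if not candidates:
        raise ValueError(f"No dimension of size {seq_len} in shape {shape}")
    if len(candidates) > 1:
        raise ValueError(f"Ambiguous sequence dimension in shape {shape}")
    return candidates[0]

# THD-style: (total_tokens, hidden) -> sequence on dim 0
print(find_seq_dim((4096, 768), 4096))  # 0
# BSHD-style: (batch, seq, hidden) -> sequence on dim 1
print(find_seq_dim((8, 512, 768), 512))  # 1
```

The explicit error branches are what the extracted helper buys over inlining the search: the caller gets a clear message when the shape is degenerate rather than a silent wrong slice.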

)
# Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
last_elem = cu_seqlens_padded[-1]
seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem
Collaborator @jomitchellnv Feb 26, 2026

Collaborator

although I think we've diverged since then

Collaborator

https://github.com/coderabbitai what edge case does this fix

Collaborator Author

my assumption is that this is for typing; @coderabbitai what edge case does this fix

Contributor

@pstjohn @jomitchellnv This change doesn't introduce a fix for a new edge case—it's a refactoring that preserves existing edge case handling. The edge case is explicitly documented in the comment above the code:

The edge case: When indexing a 1D tensor like cu_seqlens_padded[-1], the result is a 0-dimensional tensor, not a Python scalar. If you then use this value in contexts that expect a native Python int (like certain operations or comparisons), you can encounter issues.

The handling: The code checks if the result is a torch.Tensor and calls .item() to extract the Python int value, otherwise uses it as-is.

The refactoring moved this logic out of the nested process_tensor function (which was removed) to the top level, making it reusable for the new _process_tensor_thd helper function. The logic itself remained the same—just extracted and made more concise with a ternary expression.

So your assumption about typing is correct—it ensures seq_len_val is consistently a Python int rather than sometimes being a 0-dim tensor, which improves type safety and prevents potential runtime issues.
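The 0-dim tensor behavior described here is easy to reproduce in isolation. A small sketch (using a made-up `cu_seqlens_padded` value purely for illustration):

```python
import torch

cu_seqlens_padded = torch.tensor([0, 512, 1024], dtype=torch.int32)

# Indexing a 1D tensor returns a 0-dim tensor, not a Python scalar
last_elem = cu_seqlens_padded[-1]
print(last_elem.ndim)               # 0
print(isinstance(last_elem, int))   # False

# .item() extracts the native Python int
seq_len_val = last_elem.item() if isinstance(last_elem, torch.Tensor) else last_elem
print(type(seq_len_val))            # <class 'int'>
```

Downstream code that does arithmetic, slicing, or comparisons against `seq_len_val` then sees a plain `int` regardless of whether the upstream value arrived as a tensor or was already a scalar.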

reduce_dtype=torch.float32 if args.use_fp32_master_weights else None, # Gradient reductions in FP32
output_dtype=torch.bfloat16 if args.use_fp32_master_weights else None, # Forward output dtype
)
if args.use_fp32_master_weights:
Collaborator

this does the same thing, right? I think the left might even be cleaner; it avoids the if/else branch

Collaborator

https://github.com/coderabbitai what do you think?

Collaborator Author

@coderabbitai what do you think?

Collaborator Author

my preference would probably just be else: mp_policy = None

Contributor

🧩 Analysis chain

🌐 Web query:

PyTorch FSDP MixedPrecisionPolicy default parameters documentation

💡 Result:

In PyTorch’s FSDP2 API, the docs show MixedPrecisionPolicy with these default parameters:

  • Signature: torch.distributed.fsdp.MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True) [1]
  • Default when calling fully_shard: mp_policy=MixedPrecisionPolicy(param_dtype=None, reduce_dtype=None, output_dtype=None, cast_forward_inputs=True) [1]

Doc-noted default behaviors:

  • param_dtype=None: use the parameter’s original dtype for the unsharded parameter (compute/all-gather). [1]
  • reduce_dtype=None: if param_dtype is set (not None), reduction uses the compute dtype; otherwise uses the default behavior implied by None. [1]
  • output_dtype=None: no forced casting of floating-point forward outputs. [1]
  • cast_forward_inputs=True: cast floating-point forward inputs to param_dtype (when param_dtype is set). [1]

@pstjohn Both approaches are functionally identical since MixedPrecisionPolicy() defaults all parameters to None.

However, I have a slight preference for the new if/else approach for these reasons:

  1. Explicitness: It's clearer that when use_fp32_master_weights=False, you're deliberately using the default policy rather than manually passing None values
  2. Readability: The intent is more obvious at a glance - one case configures specific precision settings, the other uses defaults
  3. Maintainability: If the two branches need different configurations in the future, the structure is already in place

That said, @jomitchellnv makes a valid point about conciseness. The original ternary operator approach is more compact and keeps all the logic in one statement. If you value brevity and the team finds ternary operators readable, the original works well too.

My recommendation: Keep the new if/else approach for clarity, especially since the comments in the True branch explain what each parameter does.
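The equivalence claim rests on every field defaulting to None. A stdlib-only stand-in (a hypothetical `Policy` dataclass mimicking the documented defaults of `torch.distributed.fsdp.MixedPrecisionPolicy`, using dtype names as strings so the sketch runs without torch) shows the two styles construct equal policies:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Policy:
    """Stand-in for MixedPrecisionPolicy; all dtype fields default to None."""
    param_dtype: Optional[str] = None
    reduce_dtype: Optional[str] = None
    output_dtype: Optional[str] = None
    cast_forward_inputs: bool = True


use_fp32_master_weights = False

# Style 1: ternary expressions that fall back to None
policy_a = Policy(
    param_dtype="bfloat16" if use_fp32_master_weights else None,
    reduce_dtype="float32" if use_fp32_master_weights else None,
    output_dtype="bfloat16" if use_fp32_master_weights else None,
)

# Style 2: explicit if/else reaching for the defaults
if use_fp32_master_weights:
    policy_b = Policy(param_dtype="bfloat16", reduce_dtype="float32", output_dtype="bfloat16")
else:
    policy_b = Policy()

print(policy_a == policy_b)  # True
```

Flipping `use_fp32_master_weights` to True yields the same equality, which is why the choice between the two styles is purely one of readability.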

Collaborator

jomitchellnv commented Feb 26, 2026

Release notes make it clear this MR covers four buckets:
1. Mixtral + TE support / model features (context-parallel helpers, token dropout)
2. Refactors (state transform API, checkpoint/dataset utils, collators)
3. Docs (state-dict system + input format docs + README/config)
4. Tests (Mixtral export validation, checkpoint/scheduler tests, base class updates)

Would you be open to splitting this into 2-4 MRs along those lines? As-is it's difficult to review in one pass, and splitting would align better with our CONTRIBUTING guidance ("Make sure your PR does one thing" / "Have a clear answer to 'What does this PR do?'").

scale_factor = (1 - mask_ratio_train) / (1 - mask_ratio_observed)
reshaped_scale_factor = torch.repeat_interleave(scale_factor, src_lengths_padded, dim=0)
embeddings = (embeddings * reshaped_scale_factor.unsqueeze(-1)).to(embeddings.dtype)
embeddings = self._apply_token_dropout_bshd(embeddings, input_ids, attention_mask)
Collaborator @jomitchellnv Feb 26, 2026

This looks like GenAI finding areas where it thinks there is too much code and wrapping that chunk into a helper, which adds more indirection. It doesn't look like it's fixing anything either. I'm seeing a lot of these types of changes, but I'm not sure they're necessary or more helpful than what we had before.

https://github.com/coderabbitai what do you think? Is this method cleaner and more readable to the user or should we stick with the older one?

@jomitchellnv jomitchellnv self-requested a review February 26, 2026 23:53