feat(conversions): cutlass FP8 e4m3 per-channel + compile_target/metadata tagging #2
Draft
KavinKrishnan wants to merge 1 commit into
…data tagging

Extends prime_rl/trainer/models/conversions/ to address the live coworker complaint that prime-rl breaks on Qwen3-MoE with cutlass kernels: the registry currently has only `bf16_cast` and `fp8_128x128`; anything else raises NotImplementedError, and there's no compile_target tag on the publish so wrong-target receivers silently misinterpret bytes.

This is the trainer-side half of the design fix; the receiver-side filtering API is already shipped in modelexpress as PR PrimeIntellect-ai#349 (Phase 3a/3b on kavink/post-2389-phase3-4). Once Phase 2 graduation lands on #1, the MxV2TrainingPublisher will read each tensor's resolved ConversionEntry.compile_target + compile_metadata and tag the v2 publish so receivers can filter via discover_v2_sources(compile_target_filter=…, required_compile_metadata=…).

What lands:

ConversionEntry gains two new fields with safe defaults:
- compile_target: str = "hf_raw"
- compile_metadata: dict[str, Any] = {}
register(...) takes them as kwargs; existing call sites are unchanged. Mirrors the constants in modelexpress.shape_descriptors (Phase 3a) but without a hard import dep in either direction; both repos keep their own canonical string set.

select_default_conversion is refactored to a table-driven design. The old if/else chain is replaced by _DEFAULT_RULES: list[(predicate, name)] which the resolver walks in order. Adding a new kernel = adding one row via register_default_rule(predicate, name) from the kernel's own module on import. A predicate that raises on a malformed config is treated as "doesn't match" and skipped, keeping the resolver robust to model-card weirdness without forcing every predicate to be defensive. The AutoConfig import is deferred into the function body so the registry loads without requiring `transformers` (the registry is imported by tests + tooling that have no HF download capability).

Existing entries get their tags retroactively:
- bf16_cast / fp32_cast: compile_target="hf_raw"
- fp8_128x128: compile_target="deep_gemm_fp8" + metadata{block_size: [128,128], scale_layout:"blockwise", dtype:"e4m3"}

New conversion: cutlass_fp8_e4m3_per_channel
- One scalar scale per output row (vs DeepGemm's per-128x128-block).
- 2D dispatch: (out, in) weight → (out,) scale. 3D dispatch: (E, out, in) stacked MoE → (E, out) scale.
- compile_target="cutlass_fp8", compile_metadata={dtype:"e4m3", scale_layout:"per_channel", scale_axis:-1, activation_scheme: "dynamic"}; matches cutlass scaled_mm + vLLM's native FP8 path.
- Two default-resolver predicates:
  * quant_method="fp8" + quant_format="cutlass" (explicit)
  * quant_method="fp8" + weight_block_size=None + activation_scheme="dynamic" (the vLLM-published convention)
  Both predicates run AFTER the deep-gemm rule, so models with block_size=[128,128] AND activation_scheme="dynamic" still resolve to fp8_128x128 (regression-tested).

Per-channel helpers in trainer/models/fp8.py:
- fp8_per_channel_quantize(weight) → (q_e4m3, scale_f32). Handles 2D and 3D via the same code path; reduction over the innermost axis.
- fp8_per_channel_quantize_into(weight, out, sf): writes into preallocated buffers, matches the convention of fp8_block_quantize.

Tests: 19/19 green via direct-load + transformers stub. Categories:
- Per-channel quantize: 2D shape, 3D shape, 1D rejected, bf16 dequant accuracy (≤5% median rel error), into-buffer write.
- Registry: existing entries carry correct compile_target + compile_metadata, cutlass entry registered + listed, default-rule insert/append ordering works, unknown quant error message lists registered names.
- select_default_conversion dispatch: no-quant → bf16, [128,128] blockwise → fp8_128x128, quant_format=cutlass → cutlass, no weight_block_size + dynamic → cutlass, deep-gemm wins when both rules match.
- Conversion fn dispatch: 2D linear path correctness, 3D MoE path correctness, requires_scale=True enforced.

Adding a sibling kernel (per-token cutlass, awq, gptq, mxfp4, …) is now one new module ~80 LOC: write the quant fn, register() it with appropriate compile_target/metadata, register_default_rule() with its HF-config predicate.

Branches off PR PrimeIntellect-ai#2389 head 79ea824. Independent of the Phase 2 graduation PR; these can land in parallel.
KavinKrishnan added a commit that referenced this pull request May 29, 2026
… (v0.7.x)

Captures the empirical findings from baking PRs #1 and #2 into an ARM64 GB200 image and running it on the kavin namespace for 8+ hours on Qwen3-30B-A3B-Instruct-2507 with gsm8k. Documents three real surprises the unit tests didn't cover:

1. Dockerfile.cuda's `uv sync` is missing `--extra disagg`, so modelexpress isn't installed in stock images; inference workers crash at the first import. Shipped v0.7.1 as a one-line overlay that adds the extra until the upstream Dockerfile.cuda can be updated.
2. `LD_PRELOAD` path for libcudart.so.12: v0.5.2 had /usr/local/cuda present in the final stage; v0.7.0 (built from upstream Dockerfile.cuda as-is) doesn't. The pip-installed wheel path (/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/) is the new canonical location.
3. The configmap monkeypatch (patch_nixl_mx.py) and Phase 2's source-baked fixes are complementary: they patch different layers (broadcast vs rendezvous-wait) and both should stay until PR #1 merges upstream.

Build experience numbers:
- v0.7.0 from-scratch ARM64 build under QEMU: 6h45min (uv sync 45m, flash-attn from source 3h45m).
- v0.7.1 overlay on top of v0.7.0: ~3 min.

Cluster observations from v0.5.2 + configmap monkeypatch (the runtime-patched path our PR #1 codifies into source):
- 183 successful RL refit cycles in one 66-min uninterrupted window
- Reward variance 0.5-1.0 across orchestrator steps (real learning)
- Off-policy level = 0 throughout
- Zero NIXL data-plane errors
- Recurring orchestrator wait_for_all_peers_ready timeout (~once per 30-66 min) is the exact bug class Phase 2's rendezvous-level dedup eliminates

Also notes seven RFC updates queued in pensieve/RL/PrimeRL/09_rfc_updates_needed.md, three of which are new from this build experience (disagg extra, LD_PRELOAD path, vLLM PR #43375 / Anyscale RDT positioning).

Companion to the RFC at docs/proposals/post-pr2389-kernel-compile-plan.md.
Summary
Trainer-side half of the design fix for the live coworker complaint that prime-rl breaks on Qwen3-MoE with cutlass kernels. The receiver-side filtering API is already shipped in ai-dynamo/modelexpress#349 (Phase 3a/3b on kavink/post-2389-phase3-4); this PR is the registry side, which produces the tags that PR filters on.

Draft because it depends on the Phase 2 graduation (draft #1) landing before the publisher actually emits the new `compile_target` / `compile_metadata` fields onto the wire. The registry extensions themselves are useful immediately: `select_default_conversion` no longer raises `NotImplementedError` for cutlass-format checkpoints, and the per-channel quantize helpers are unit-tested and ready to swap into any existing trainer-side path.
Registry plumbing
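`ConversionEntry` gains two fields with safe defaults: `compile_target: str = "hf_raw"` and `compile_metadata: dict[str, Any] = {}`. `register(...)` takes them as kwargs; existing call sites are unchanged.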
Mirrors the canonical strings in `modelexpress.shape_descriptors` (Phase 3a) without a cross-repo import dependency. Existing entries tagged retroactively: `bf16_cast`/`fp32_cast` → `hf_raw`, `fp8_128x128` → `deep_gemm_fp8` with metadata `{block_size: [128,128], scale_layout: "blockwise", dtype: "e4m3"}`.
`select_default_conversion` refactored to a table-driven design — `_DEFAULT_RULES: list[(predicate, name)]` walked in order. Adding a new kernel = one new module that calls `register(...)` + `register_default_rule(predicate, name)`. Predicates that raise on a malformed config are treated as "doesn't match" and skipped, keeping the resolver robust to model-card weirdness.
The `AutoConfig` import is deferred into the function body so the registry loads without requiring `transformers` (tests + tooling don't need HF download capability).
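The deferred-import pattern looks roughly like this. The function name `load_quantization_config` is hypothetical, and the handling of `quantization_config` (a dict, or an object with `to_dict()`) is an assumption about what the loaded config carries.

```python
from typing import Any, Optional


def load_quantization_config(model_name: str) -> Optional[dict[str, Any]]:
    """Fetch the quantization config for a model, importing transformers lazily."""
    # Deferred import: importing this module never touches `transformers`;
    # the dependency is only needed when this function is actually called.
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(model_name)
    quant = getattr(config, "quantization_config", None)
    if quant is None:
        return None
    # Assumption: the attribute is either a plain dict or exposes to_dict().
    return quant if isinstance(quant, dict) else quant.to_dict()
```

Anything that only needs the registry table can import the module without `transformers` installed; an `ImportError` surfaces only at call time.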
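New conversion: `cutlass_fp8_e4m3_per_channel`
- One scalar scale per output row (vs DeepGemm's per-128x128 block).
- 2D dispatch: `(out, in)` weight → `(out,)` scale. 3D dispatch: `(E, out, in)` stacked MoE → `(E, out)` scale.
- `compile_target="cutlass_fp8"`, `compile_metadata={dtype: "e4m3", scale_layout: "per_channel", scale_axis: -1, activation_scheme: "dynamic"}`, matching cutlass scaled_mm and vLLM's native FP8 path.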
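To make the per-channel layout concrete (one scale per output row, reduced over the innermost axis), here is a NumPy sketch of the scale computation. It is not the PR's `fp8_per_channel_quantize`: the function name is hypothetical, and because NumPy has no FP8 dtype the final cast to e4m3 is left out, so values are only scaled and clamped to the e4m3 finite range of ±448.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite magnitude representable in e4m3 (fn variant)


def per_channel_scale_sketch(weight: np.ndarray):
    """Return (scaled_weight, scale) with one scale per output row.

    2D (out, in)    -> scale shape (out,)
    3D (E, out, in) -> scale shape (E, out)
    """
    if weight.ndim not in (2, 3):
        raise ValueError(f"expected a 2D or 3D weight, got {weight.ndim}D")
    w = weight.astype(np.float32)
    # Reduce over the innermost (input) axis: one absolute maximum per output row.
    amax = np.abs(w).max(axis=-1)
    # Guard against all-zero rows so the division below stays finite.
    scale = (np.maximum(amax, 1e-12) / FP8_E4M3_MAX).astype(np.float32)
    scaled = np.clip(w / scale[..., None], -FP8_E4M3_MAX, FP8_E4M3_MAX)
    # A real implementation would now cast `scaled` to an FP8 e4m3 dtype.
    return scaled, scale


# Dequantization is the inverse: scaled * scale[..., None] recovers the weight
# (up to rounding introduced by the FP8 cast, which is omitted here).
linear_scaled, linear_scale = per_channel_scale_sketch(np.random.randn(8, 16))
moe_scaled, moe_scale = per_channel_scale_sketch(np.random.randn(4, 8, 16))
```

The same code path serves both ranks because the reduction always targets `axis=-1`.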
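Default-resolver predicates:
- `quant_method="fp8"` + `quant_format="cutlass"` (explicit)
- `quant_method="fp8"` + `weight_block_size=None` + `activation_scheme="dynamic"` (the vLLM-published convention)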
Both run after the deep-gemm rule, so models with `block_size=[128,128]` AND `activation_scheme="dynamic"` still resolve to `fp8_128x128` — regression-tested.
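The ordering guarantee can be illustrated with plain functions over the HF quantization-config dict. The two cutlass predicates follow the conditions stated in this PR; the deep-gemm predicate and the helper names are assumptions made for the sketch.

```python
from typing import Any, Optional


def is_deep_gemm_blockwise(quant: dict[str, Any]) -> bool:
    # Assumed shape of the existing deep-gemm rule.
    return (quant.get("quant_method") == "fp8"
            and quant.get("weight_block_size") == [128, 128])


def is_cutlass_explicit(quant: dict[str, Any]) -> bool:
    return (quant.get("quant_method") == "fp8"
            and quant.get("quant_format") == "cutlass")


def is_cutlass_vllm_convention(quant: dict[str, Any]) -> bool:
    return (quant.get("quant_method") == "fp8"
            and quant.get("weight_block_size") is None
            and quant.get("activation_scheme") == "dynamic")


# Order matters: the deep-gemm rule sits ahead of both cutlass rules.
RULES = [
    (is_deep_gemm_blockwise, "fp8_128x128"),
    (is_cutlass_explicit, "cutlass_fp8_e4m3_per_channel"),
    (is_cutlass_vllm_convention, "cutlass_fp8_e4m3_per_channel"),
]


def first_match(quant: dict[str, Any]) -> Optional[str]:
    """Return the name of the first rule whose predicate accepts `quant`."""
    for predicate, name in RULES:
        if predicate(quant):
            return name
    return None


# Blockwise checkpoint that also declares dynamic activations: deep-gemm wins.
blockwise_dynamic = {"quant_method": "fp8",
                     "weight_block_size": [128, 128],
                     "activation_scheme": "dynamic"}
chosen = first_match(blockwise_dynamic)
```

Here `chosen` is `"fp8_128x128"`; removing `weight_block_size` from the dict would route the same config to the cutlass entry.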
Tests
19/19 green via `python3 -m pytest tests/unit/train/models/conversions/test_cutlass_fp8.py`, direct-load + `transformers` stub so no HF download needed.
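The test harness itself is not shown here. One common way to get "direct-load + stub" behaviour with only the standard library is sketched below; the helper names and the throwaway `fake_registry` module are hypothetical stand-ins for the real test setup.

```python
import importlib.util
import sys
import tempfile
import types
from pathlib import Path


def install_transformers_stub():
    """Put a fake `transformers` module in sys.modules so imports of it succeed."""
    stub = types.ModuleType("transformers")

    class AutoConfig:
        @staticmethod
        def from_pretrained(name, **kwargs):
            # No network access: return an object with no quantization config.
            return types.SimpleNamespace(quantization_config=None)

    stub.AutoConfig = AutoConfig
    # In a pytest suite, monkeypatch.setitem(sys.modules, ...) would undo this
    # automatically after the test.
    sys.modules["transformers"] = stub
    return stub


def direct_load(module_name, path):
    """Load one .py file as a module without importing its parent package."""
    spec = importlib.util.spec_from_file_location(module_name, path)
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


# Demonstration with a throwaway module that imports transformers lazily.
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "fake_registry.py"
    path.write_text(
        "def uses_hf():\n"
        "    from transformers import AutoConfig\n"
        "    return AutoConfig.from_pretrained('any')\n"
    )
    install_transformers_stub()
    registry = direct_load("fake_registry", path)
    result = registry.uses_hf()
```

Direct loading sidesteps the package `__init__` chain, which is what lets a registry module be tested without the rest of the trainer's dependencies.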
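Categories:
- Per-channel quantize: 2D shape, 3D shape, 1D rejected, bf16 dequant accuracy (≤5% median rel error), into-buffer write.
- Registry: existing entries carry correct `compile_target` + `compile_metadata`, cutlass entry registered + listed, default-rule insert/append ordering works, unknown quant error message lists registered names.
- `select_default_conversion` dispatch: no-quant → bf16, `[128,128]` blockwise → `fp8_128x128`, `quant_format=cutlass` → cutlass, no `weight_block_size` + dynamic → cutlass, deep-gemm wins when both rules match.
- Conversion fn dispatch: 2D linear path correctness, 3D MoE path correctness, `requires_scale=True` enforced.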
Adding more kernels later
One module per kernel, ~80 LOC each:
```python
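# Quantize `src`, writing the result into `out` and the scales into `scale_out`.
def my_kernel_quantize(src, out, scale_out): ...

# Register the conversion with its compile tags.
register("my_kernel", my_kernel_quantize, requires_scale=True,
         compile_target="my_kernel",
         compile_metadata={...})
# Route matching HF quantization configs to this kernel by default.
register_default_rule(lambda quant: quant.get(...) == ..., "my_kernel")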
```
Future candidates: `cutlass_fp8_e4m3_per_token`, `awq_int4`, `gptq_int4`, `mxfp4`, `trtllm_w4a16`.
Test plan
Context