Skip to content

feat(conversions): cutlass FP8 e4m3 per-channel + compile_target/metadata tagging#2

Draft
KavinKrishnan wants to merge 1 commit into
kavink/post-2389-kernel-compile-planfrom
kavink/post-2389-conversion-registry-extensions
Draft

feat(conversions): cutlass FP8 e4m3 per-channel + compile_target/metadata tagging#2
KavinKrishnan wants to merge 1 commit into
kavink/post-2389-kernel-compile-planfrom
kavink/post-2389-conversion-registry-extensions

Conversation

@KavinKrishnan
Copy link
Copy Markdown
Owner

Summary

Trainer-side half of the design fix for the live coworker complaint that prime-rl breaks on Qwen3-MoE with cutlass kernels. The receiver-side filtering API is already shipped in ai-dynamo/modelexpress#349 (Phase 3a/3b on kavink/post-2389-phase3-4); this PR is the registry-side that produces tags that PR can filter on.

Draft because it depends on the Phase 2 graduation (draft #1) landing before the publisher actually emits the new `compile_target` / `compile_metadata` fields onto the wire. The registry extensions themselves are useful immediately — `select_default_conversion` now stops raising `NotImplementedError` for cutlass-format checkpoints, and the per-channel quantize helpers are unit-tested + ready to swap into any existing trainer-side path.

Registry plumbing

`ConversionEntry` gains two safe-default fields:

  • `compile_target: str = "hf_raw"`
  • `compile_metadata: dict[str, Any] = {}`

Mirrors the canonical strings in `modelexpress.shape_descriptors` (Phase 3a) without a cross-repo import dependency. Existing entries tagged retroactively: `bf16_cast`/`fp32_cast` → `hf_raw`, `fp8_128x128` → `deep_gemm_fp8` with metadata `{block_size: [128,128], scale_layout: "blockwise", dtype: "e4m3"}`.

`select_default_conversion` refactored to a table-driven design — `_DEFAULT_RULES: list[(predicate, name)]` walked in order. Adding a new kernel = one new module that calls `register(...)` + `register_default_rule(predicate, name)`. Predicates that raise on a malformed config are treated as "doesn't match" and skipped, keeping the resolver robust to model-card weirdness.

The `AutoConfig` import was hoisted into the function body so the registry loads without requiring `transformers` (tests + tooling don't need HF download capability).

New conversion: `cutlass_fp8_e4m3_per_channel`

Property Value
Quant dtype FP8 e4m3 (max 448)
Scale layout per-output-channel
2D shape weight `(out, in)` → scale `(out,)`
3D shape weight `(E, out, in)` → scale `(E, out)`
Reduction axis innermost (`axis=-1`)
`compile_target` `"cutlass_fp8"`
`compile_metadata` `{dtype, scale_layout, scale_axis, activation_scheme}`

Default-resolver predicates:

  1. `quant_method="fp8"` + `quant_format="cutlass"` (explicit)
  2. `quant_method="fp8"` + `weight_block_size=None` + `activation_scheme="dynamic"` (vLLM's published convention)

Both run after the deep-gemm rule, so models with `block_size=[128,128]` AND `activation_scheme="dynamic"` still resolve to `fp8_128x128` — regression-tested.

Tests

19/19 green via `python3 -m pytest tests/unit/train/models/conversions/test_cutlass_fp8.py`, direct-load + `transformers` stub so no HF download needed.

Categories:

  • Per-channel quantize helper: 2D/3D shape, 1D rejection, bf16 dequant accuracy (≤5% median rel error), into-buffer write.
  • Registry: existing entries retain correct `compile_target` + `compile_metadata`, cutlass entry registered + listed, default-rule insert/append ordering, unknown-quant error message lists registered names.
  • Default-resolver dispatch: no-quant → bf16, [128,128] → fp8_128x128, `quant_format=cutlass` → cutlass, no block_size + dynamic → cutlass, deep-gemm wins when both rules match.
  • Conversion fn dispatch: 2D linear correctness, 3D MoE correctness, `requires_scale=True` enforced.

Adding more kernels later

One module per kernel, ~80 LOC each:

```python
def my_kernel_quantize(src, out, scale_out): ...
register("my_kernel", my_kernel_quantize, requires_scale=True,
compile_target="my_kernel",
compile_metadata={...})
register_default_rule(lambda quant: quant.get(...) == ..., "my_kernel")
```

Future candidates: `cutlass_fp8_e4m3_per_token`, `awq_int4`, `gptq_int4`, `mxfp4`, `trtllm_w4a16`.

Test plan

  • 19/19 new unit tests pass
  • No regression on existing `bf16_cast` / `fp8_128x128` entries (covered by registry tests + the existing `test_qwen3_moe.py` which exercises both via `select_default_conversion`)
  • Cluster validation on a real Qwen3-MoE cutlass FP8 checkpoint (deferred; requires a cutlass-quantized checkpoint to be available)
  • After Phase 2 graduation lands, plumb `ConversionEntry.compile_target` + `compile_metadata` into `MxV2TrainingPublisher` at publish time so receivers can filter via `discover_v2_sources(compile_target_filter=…)` (follow-up PR)

Context

  • Base RFC: `docs/proposals/post-pr2389-kernel-compile-plan.md` on the base branch.
  • Sister PR (MX side): #349 — receiver-side filtering API.
  • Sister PR (rendezvous): #1 — Phase 2 heartbeat/dedup/same-rank filter.
  • Coworker complaint that motivated this work: prime-rl breaking on Qwen3-MoE + cutlass + non-128x128 quant.

…data tagging

Extends prime_rl/trainer/models/conversions/ to address the live coworker
complaint that prime-rl breaks on Qwen3-MoE with cutlass kernels — the
registry currently has only `bf16_cast` and `fp8_128x128`; anything else
raises NotImplementedError, and there's no compile_target tag on the
publish so wrong-target receivers silently misinterpret bytes.

This is the trainer-side half of the design fix; the receiver-side
filtering API is already shipped in modelexpress as PR PrimeIntellect-ai#349 (Phase 3a/3b
on kavink/post-2389-phase3-4). Once Phase 2 graduation lands on
#1, the MxV2TrainingPublisher will read each
tensor's resolved ConversionEntry.compile_target + compile_metadata and
tag the v2 publish so receivers can filter via
discover_v2_sources(compile_target_filter=…, required_compile_metadata=…).

What lands:

ConversionEntry gains two new fields with safe defaults:
  - compile_target: str = "hf_raw"
  - compile_metadata: dict[str, Any] = {}
register(...) takes them as kwargs; existing call sites are unchanged.
Mirrors the constants in modelexpress.shape_descriptors (Phase 3a) but
without a hard import dep in either direction — both repos keep their
own canonical string set.

select_default_conversion is refactored to a table-driven design. The
old if/else chain is replaced by _DEFAULT_RULES: list[(predicate, name)]
which the resolver walks in order. Adding a new kernel = adding one row
via register_default_rule(predicate, name) from the kernel's own module
on import. A predicate that raises on a malformed config is treated as
"doesn't match" and skipped, keeping the resolver robust to model-card
weirdness without forcing every predicate to be defensive.

The AutoConfig import is deferred into the function body so the
registry loads without requiring `transformers` (the registry is
imported by tests + tooling that have no HF download capability).

Existing entries get their tags retroactively:
  - bf16_cast / fp32_cast: compile_target="hf_raw"
  - fp8_128x128: compile_target="deep_gemm_fp8" + metadata{block_size:
    [128,128], scale_layout:"blockwise", dtype:"e4m3"}

New conversion: cutlass_fp8_e4m3_per_channel
  - One scalar scale per output row (vs DeepGemm's per-128x128-block).
  - 2D dispatch: (out, in) weight → (out,) scale. 3D dispatch:
    (E, out, in) stacked MoE → (E, out) scale.
  - compile_target="cutlass_fp8", compile_metadata={dtype:"e4m3",
    scale_layout:"per_channel", scale_axis:-1, activation_scheme:
    "dynamic"} — matches cutlass scaled_mm + vLLM's native FP8 path.
  - Two default-resolver predicates:
      * quant_method="fp8" + quant_format="cutlass" (explicit)
      * quant_method="fp8" + weight_block_size=None +
        activation_scheme="dynamic" (the vLLM-published convention)
    Both predicates run AFTER the deep-gemm rule, so models with
    block_size=[128,128] AND activation_scheme="dynamic" still resolve
    to fp8_128x128 (regression-tested).

Per-channel helpers in trainer/models/fp8.py:
  - fp8_per_channel_quantize(weight) → (q_e4m3, scale_f32). Handles 2D
    and 3D via the same code path; reduction over the innermost axis.
  - fp8_per_channel_quantize_into(weight, out, sf) — writes into
    preallocated buffers, matches the convention of fp8_block_quantize.

Tests: 19/19 green via direct-load + transformers stub.
Categories:
  - Per-channel quantize: 2D shape, 3D shape, 1D rejected,
    bf16 dequant accuracy (≤5% median rel error), into-buffer write.
  - Registry: existing entries carry correct compile_target +
    compile_metadata, cutlass entry registered + listed, default-rule
    insert/append ordering works, unknown quant error message lists
    registered names.
  - select_default_conversion dispatch: no-quant → bf16, [128,128]
    blockwise → fp8_128x128, quant_format=cutlass → cutlass, no
    weight_block_size + dynamic → cutlass, deep-gemm wins when both
    rules match.
  - Conversion fn dispatch: 2D linear path correctness, 3D MoE path
    correctness, requires_scale=True enforced.

Adding a sibling kernel (per-token cutlass, awq, gptq, mxfp4, …) is now
one new module ~80 LOC: write the quant fn, register() it with
appropriate compile_target/metadata, register_default_rule() with its
HF-config predicate.

Branches off PR PrimeIntellect-ai#2389 head 79ea824. Independent of the Phase 2
graduation PR — these can land in parallel.
KavinKrishnan added a commit that referenced this pull request May 29, 2026
… (v0.7.x)

Captures the empirical findings from baking PRs #1 and #2 into an ARM64
GB200 image and running it on the kavin namespace for 8+ hours on
Qwen3-30B-A3B-Instruct-2507 with gsm8k.

Documents three real surprises the unit tests didn't cover:

1. Dockerfile.cuda's `uv sync` is missing `--extra disagg`, so modelexpress
   isn't installed in stock images; inference workers crash at the first
   import. Shipped v0.7.1 as a one-line overlay that adds the extra until
   the upstream Dockerfile.cuda can be updated.

2. `LD_PRELOAD` path for libcudart.so.12 — v0.5.2 had /usr/local/cuda
   present in the final stage; v0.7.0 (built from upstream Dockerfile.cuda
   as-is) doesn't. The pip-installed wheel path
   (/app/.venv/lib/python3.12/site-packages/nvidia/cuda_runtime/lib/) is
   the new canonical location.

3. The configmap monkeypatch (patch_nixl_mx.py) and Phase 2's source-baked
   fixes are complementary — they patch different layers (broadcast vs
   rendezvous-wait) and both should stay until PR #1 merges upstream.

Build experience numbers:
  - v0.7.0 from-scratch ARM64 build under QEMU: 6h45min (uv sync 45m,
    flash-attn from source 3h45m).
  - v0.7.1 overlay on top of v0.7.0: ~3 min.

Cluster observations from v0.5.2 + configmap monkeypatch (the
runtime-patched path our PR #1 codifies into source):
  - 183 successful RL refit cycles in one 66-min uninterrupted window
  - Reward variance 0.5-1.0 across orchestrator steps (real learning)
  - Off-policy level = 0 throughout
  - Zero NIXL data-plane errors
  - Recurring orchestrator wait_for_all_peers_ready timeout (~once per
    30-66 min) is the exact bug class Phase 2's rendezvous-level dedup
    eliminates

Also notes seven RFC updates queued in
pensieve/RL/PrimeRL/09_rfc_updates_needed.md, three of which are new from
this build experience (disagg extra, LD_PRELOAD path, vLLM PR #43375 /
Anyscale RDT positioning).

Companion to the RFC at docs/proposals/post-pr2389-kernel-compile-plan.md.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant