Skip to content

perf(turboquant): RHT-correct L=1 value kernels (keep the fast path under RHT)#1252

Merged
Blaizzy merged 2 commits into
Blaizzy:mainfrom
popfido:fix/turboquant-rht-value-kernels
Jun 1, 2026
Merged

perf(turboquant): RHT-correct L=1 value kernels (keep the fast path under RHT)#1252
Blaizzy merged 2 commits into
Blaizzy:mainfrom
popfido:fix/turboquant-rht-value-kernels

Conversation

@popfido
Copy link
Copy Markdown
Contributor

@popfido popfido commented May 31, 2026

Summary

Follow-up to #1244. That PR made the single-token (L=1) value Metal kernels correct under RHT by declining them — every RHT decode falls back to the slower einsum path. This PR makes the kernels themselves RHT-correct so the fast path stays engaged.

Root cause. _metal_mse_weighted_sum / _metal_mse_weighted_sum_from_scores / _metal_mse_weighted_sum_sum_from_scores return the weighted value sum in the codec's rotated space and undid it with a hard-coded matmul(weighted_rot, rotation). That inverse is only valid for the dense-rotation codec; under the Randomized Hadamard Transform (use_rht, the default for power-of-2 head dims) the correct inverse is _rht_inverse. The kernel computes weighted_rot correctly regardless of rotation type — only the post-kernel inverse was wrong, which is exactly why #1244 had to skip it.

Fix. Thread the codec's RHT signs into the three wrappers and apply the matching inverse through a small _value_rotate_inverse helper (_rht_inverse when signs are set, else the dense matmul) — mirroring _TurboQuantMSECodec._rotate_inverse and the already-correct fused-decode path (value_codec._rotate_inverse(out_rotated)). The not self.use_rht guards from #1244 are dropped so RHT decode takes the kernel again.

Prod-mode fused-decode call sites pass no signs (default None) → identical dense matmul, byte-for-byte unchanged.

Verification

New mlx_vlm/tests/test_turboquant_rht_value_kernels.py:

  • Equivalence — kernel path vs einsum fallback vs dequantize ground truth across dims {64,128,256} × bits {2,3,4,8} × repeats {1,4}, plus the weighted_sum_stats_from_scores decode path. Agreement <1e-4 (kernel vs einsum) and <1e-3 (kernel vs dequant truth).
  • Kernel-path guard — asserts RHT weighted_sum runs the Metal kernel and never the einsum fallback (which unpacks low-bit indices). This test fails on main today (RHT is skipped) and passes here.
  • Existing test_turboquant.py suite: 23 passed, no regressions.

Benchmark (M-series, L=1 value reconstruction, dim=128 bits=4 B=4 H=8 repeats=4)

context T einsum fallback RHT kernel speedup
512 1.85 ms 0.53 ms 3.5×
2048 3.05 ms 0.43 ms 7.0×
8192 15.56 ms 1.14 ms 13.7×

The gap grows with context length because the einsum fallback re-unpacks every token each decode step.

Scope / risk

Low. Kernel math is unchanged; only the post-kernel inverse moves to the RHT-aware _value_rotate_inverse. The dense-rotation and prod-mode paths are untouched (no signs → same matmul).

popfido and others added 2 commits May 31, 2026 11:23
…nder RHT)

The single-token value kernels (_metal_mse_weighted_sum and friends) return
the weighted value sum in the codec's rotated space and undo it with a
hard-coded matmul(weighted_rot, rotation). That inverse is only correct for
the dense-rotation codec, so Blaizzy#1244 disabled these kernels whenever the codec
uses the Randomized Hadamard Transform (use_rht) — every RHT decode fell back
to the slower einsum path.

Pass the codec's RHT signs into the wrappers and apply the matching inverse
(_rht_inverse when signs are set, else the dense matmul) via a small
_value_rotate_inverse helper, mirroring _TurboQuantMSECodec._rotate_inverse
and the already-correct fused-decode path. The kernel computes weighted_rot
correctly regardless of rotation type; only the post-kernel inverse needed
fixing. The not-self.use_rht guards are dropped so RHT decode takes the
kernel again.

Prod-mode fused-decode call sites pass no signs (default None) and keep the
exact dense matmul — byte-for-byte unchanged.

Verified: RHT weighted_sum / weighted_sum_stats_from_scores match the einsum
fallback and the dequantize ground truth to <1e-4 / <1e-3 across
dims {64,128,256} x bits {2,3,4,8} x repeats {1,4}; the L=1 value
reconstruction runs 3.5-13.7x faster than the einsum fallback (the gap grows
with context length).
Copy link
Copy Markdown
Owner

@Blaizzy Blaizzy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@Blaizzy Blaizzy merged commit 1090328 into Blaizzy:main Jun 1, 2026
1 check passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants