perf(turboquant): RHT-correct L=1 value kernels (keep the fast path under RHT)#1252
Merged
Merged
Conversation
…nder RHT) The single-token value kernels (_metal_mse_weighted_sum and friends) return the weighted value sum in the codec's rotated space and undo it with a hard-coded matmul(weighted_rot, rotation). That inverse is only correct for the dense-rotation codec, so Blaizzy#1244 disabled these kernels whenever the codec uses the Randomized Hadamard Transform (use_rht) — every RHT decode fell back to the slower einsum path. Pass the codec's RHT signs into the wrappers and apply the matching inverse (_rht_inverse when signs are set, else the dense matmul) via a small _value_rotate_inverse helper, mirroring _TurboQuantMSECodec._rotate_inverse and the already-correct fused-decode path. The kernel computes weighted_rot correctly regardless of rotation type; only the post-kernel inverse needed fixing. The not-self.use_rht guards are dropped so RHT decode takes the kernel again. Prod-mode fused-decode call sites pass no signs (default None) and keep the exact dense matmul — byte-for-byte unchanged. Verified: RHT weighted_sum / weighted_sum_stats_from_scores match the einsum fallback and the dequantize ground truth to <1e-4 / <1e-3 across dims {64,128,256} x bits {2,3,4,8} x repeats {1,4}; the L=1 value reconstruction runs 3.5-13.7x faster than the einsum fallback (the gap grows with context length).
4 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Follow-up to #1244. That PR made the single-token (L=1) value Metal kernels correct under RHT by declining them — every RHT decode falls back to the slower einsum path. This PR makes the kernels themselves RHT-correct so the fast path stays engaged.
Root cause.
_metal_mse_weighted_sum/_metal_mse_weighted_sum_from_scores/_metal_mse_weighted_sum_sum_from_scoresreturn the weighted value sum in the codec's rotated space and undid it with a hard-codedmatmul(weighted_rot, rotation). That inverse is only valid for the dense-rotation codec; under the Randomized Hadamard Transform (use_rht, the default for power-of-2 head dims) the correct inverse is_rht_inverse. The kernel computesweighted_rotcorrectly regardless of rotation type — only the post-kernel inverse was wrong, which is exactly why #1244 had to skip it.Fix. Thread the codec's RHT
signsinto the three wrappers and apply the matching inverse through a small_value_rotate_inversehelper (_rht_inversewhen signs are set, else the dense matmul) — mirroring_TurboQuantMSECodec._rotate_inverseand the already-correct fused-decode path (value_codec._rotate_inverse(out_rotated)). Thenot self.use_rhtguards from #1244 are dropped so RHT decode takes the kernel again.Prod-mode fused-decode call sites pass no
signs(defaultNone) → identical dense matmul, byte-for-byte unchanged.Verification
New
mlx_vlm/tests/test_turboquant_rht_value_kernels.py:{64,128,256}× bits{2,3,4,8}× repeats{1,4}, plus theweighted_sum_stats_from_scoresdecode path. Agreement<1e-4(kernel vs einsum) and<1e-3(kernel vs dequant truth).weighted_sumruns the Metal kernel and never the einsum fallback (which unpacks low-bit indices). This test fails onmaintoday (RHT is skipped) and passes here.test_turboquant.pysuite: 23 passed, no regressions.Benchmark (M-series, L=1 value reconstruction, dim=128 bits=4 B=4 H=8 repeats=4)
The gap grows with context length because the einsum fallback re-unpacks every token each decode step.
Scope / risk
Low. Kernel math is unchanged; only the post-kernel inverse moves to the RHT-aware
_value_rotate_inverse. The dense-rotation and prod-mode paths are untouched (nosigns→ same matmul).