From 0113d46558668d537657f964fb414746754bb14c Mon Sep 17 00:00:00 2001 From: popfido Date: Sat, 30 May 2026 21:44:07 +0800 Subject: [PATCH] turboquant: guard L=1 value kernels behind `not use_rht` The L=1 value-reconstruction Metal kernels (_metal_mse_weighted_sum, _metal_mse_weighted_sum_sum_from_scores) undo the codec rotation with matmul(weighted_rot, rotation), which is only the inverse for a plain rotation. _TurboQuantMSECodec defaults to use_rht=True (randomized Hadamard transform) whose inverse is _rht_inverse(.; signs), so under RHT these kernels return uncorrelated output (~140% reconstruction error at every bit depth). weighted_sum_from_scores already guards its kernel with `if not self.use_rht`; weighted_sum and weighted_sum_stats_from_scores did not, corrupting the slow (masked / array-mask) single-query decode path used by continuous-batching decode with per-request left-padding. Add the same guard so RHT falls back to the correct einsum + _rotate_inverse path. Verified 140% -> ~1%. --- mlx_vlm/turboquant.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/turboquant.py b/mlx_vlm/turboquant.py index 8c8d4461d..d5e3beaa2 100644 --- a/mlx_vlm/turboquant.py +++ b/mlx_vlm/turboquant.py @@ -4287,7 +4287,7 @@ def score(self, queries: mx.array, state: TurboQuantMSEState) -> mx.array: return self.score_prepared(self.prepare_queries(queries), state) def weighted_sum(self, weights: mx.array, state: TurboQuantMSEState) -> mx.array: - if weights.shape[-2] == 1: + if not self.use_rht and weights.shape[-2] == 1: fast_output = _metal_mse_weighted_sum( weights, state, @@ -4327,8 +4327,11 @@ def weighted_sum_stats_from_scores( self, scores: mx.array, state: TurboQuantMSEState ) -> tuple[mx.array, mx.array, mx.array]: max_scores = mx.max(scores, axis=-1) - # Metal kernel fast path: only for single-query decode (L=1) - if scores.ndim == 5 and scores.shape[-2] == 1: + # Metal kernel fast path: only for single-query decode (L=1). + # Skip under RHT: these L=1 value kernels undo the codec rotation with + # matmul(., rotation) and do not implement the RHT inverse, so they + # corrupt the output when use_rht is set (the codec default). + if not self.use_rht and scores.ndim == 5 and scores.shape[-2] == 1: max_scores_2d = max_scores.reshape( max_scores.shape[0], max_scores.shape[1],