diff --git a/mlx_vlm/turboquant.py b/mlx_vlm/turboquant.py index 8c8d4461d..d5e3beaa2 100644 --- a/mlx_vlm/turboquant.py +++ b/mlx_vlm/turboquant.py @@ -4287,7 +4287,7 @@ def score(self, queries: mx.array, state: TurboQuantMSEState) -> mx.array: return self.score_prepared(self.prepare_queries(queries), state) def weighted_sum(self, weights: mx.array, state: TurboQuantMSEState) -> mx.array: - if weights.shape[-2] == 1: + if not self.use_rht and weights.shape[-2] == 1: fast_output = _metal_mse_weighted_sum( weights, state, @@ -4327,8 +4327,11 @@ def weighted_sum_stats_from_scores( self, scores: mx.array, state: TurboQuantMSEState ) -> tuple[mx.array, mx.array, mx.array]: max_scores = mx.max(scores, axis=-1) - # Metal kernel fast path: only for single-query decode (L=1) - if scores.ndim == 5 and scores.shape[-2] == 1: + # Metal kernel fast path: only for single-query decode (L=1). + # Skip under RHT: these L=1 value kernels undo the codec rotation with + # matmul(., rotation) and do not implement the RHT inverse, so they + # corrupt the output when use_rht is set (the codec default). + if not self.use_rht and scores.ndim == 5 and scores.shape[-2] == 1: max_scores_2d = max_scores.reshape( max_scores.shape[0], max_scores.shape[1],