From b1ec70da156f05fc326948e24a5e48a3291527a7 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 21 May 2026 07:37:07 +0200 Subject: [PATCH 01/26] Fix Qwen3.5 batched left-padding drift --- mlx_vlm/models/qwen3_5/language.py | 113 +++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 880c18b40..d94f4e371 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -258,6 +258,28 @@ def _target_verify_embedding_as_linear(embedding, x: mx.array, target_verify: bo ) +def _extract_row_cache(cache_entry, row: int): + if isinstance(cache_entry, ArraysCache): + row_cache = ArraysCache(size=len(cache_entry.cache)) + row_cache.cache = [ + None if cached is None else cached[row : row + 1] + for cached in cache_entry.cache + ] + lengths = getattr(cache_entry, "lengths", None) + if lengths is not None: + row_cache.lengths = lengths[row : row + 1] + return row_cache + + if hasattr(cache_entry, "extract") and not cache_entry.empty(): + return cache_entry.extract(row) + + if hasattr(cache_entry, "left_padding"): + row_cache = KVCache() + return row_cache + + return cache_entry + + def _gated_delta_update_verify_decode( q: mx.array, k: mx.array, @@ -654,6 +676,75 @@ def __call__( if cache is None: cache = [None] * len(self.layers) + fa_cache = cache[self.fa_idx] + if ( + h.shape[0] > 1 + and hidden_sink is None + and gdn_sink is None + and fa_cache is not None + and hasattr(fa_cache, "extract") + and hasattr(fa_cache.__class__, "merge") + and isinstance(getattr(fa_cache, "offset", None), mx.array) + and fa_cache.offset.ndim > 0 + ): + query_left_padding = mx.minimum(mx.maximum(-fa_cache.offset, 0), h.shape[1]) + cache_left_padding = getattr(fa_cache, "left_padding", None) + has_left_padding = ( + isinstance(cache_left_padding, mx.array) + and cache_left_padding.ndim > 0 + and int(cache_left_padding.max().item()) > 0 + ) + if has_left_padding or int(query_left_padding.max().item()) > 0: + row_outputs = [] + row_caches = [[] for _ in cache] + for row, pad in enumerate(query_left_padding.tolist()): + pad = min(max(int(pad), 0), h.shape[1]) + row_inputs = inputs[row : row + 1, pad:] + row_embeds = h[row : row + 1, pad:] + row_position_ids = ( + None + if position_ids is None + else position_ids[:, row : row + 1, pad:] + ) + if pad > 0 and row_position_ids is not None: + row_position_ids = mx.broadcast_to( + mx.arange(row_inputs.shape[1])[None, None, :], + (3, 1, row_inputs.shape[1]), + ) + current_cache = [] + for cache_entry in cache: + if cache_entry is None: + current_cache.append(None) + else: + current_cache.append(_extract_row_cache(cache_entry, row)) + + row_out = self( + row_inputs, + inputs_embeds=row_embeds, + cache=current_cache, + position_ids=row_position_ids, + ) + if pad > 0: + row_out = mx.concatenate( + [ + mx.zeros( + (1, pad, row_out.shape[-1]), dtype=row_out.dtype + ), + row_out, + ], + axis=1, + ) + row_outputs.append(row_out) + for i, cache_entry in enumerate(current_cache): + row_caches[i].append(cache_entry) + + for i, entries in enumerate(row_caches): + if cache[i] is None: + continue + if hasattr(cache[i].__class__, "merge"): + cache[i] = cache[i].__class__.merge(entries) + return mx.concatenate(row_outputs, axis=0) + fa_mask = create_attention_mask(h, cache[self.fa_idx]) ssm_mask = create_ssm_mask(h, cache[self.ssm_idx]) @@ -1063,17 +1154,25 @@ def __call__( pixel_values = kwargs.pop("pixel_values", None) image_grid_thw = kwargs.pop("image_grid_thw", None) video_grid_thw = kwargs.pop("video_grid_thw", None) + attention_mask = kwargs.pop("attention_mask", None) capture_layer_ids = kwargs.pop("capture_layer_ids", None) return_hidden = kwargs.pop("return_hidden", False) return_shared_kv = kwargs.pop("return_shared_kv", False) skip_logits = kwargs.pop("skip_logits", False) rope_deltas_kw = kwargs.pop("rope_deltas", None) + if ( + mask is None + and attention_mask is not None + and attention_mask.shape[-1] == inputs.shape[-1] + ): + mask = attention_mask if pixel_values is not None: self._rope_deltas = None self._position_ids = None cache_offset = 0 cache_offsets = None # per-element offsets for batched caches + c0 = None if cache and cache[self.model.fa_idx] is not None: c0 = cache[self.model.fa_idx] cache_offset = c0._idx if hasattr(c0, "_idx") else c0.offset @@ -1084,6 +1183,16 @@ def __call__( ): cache_offsets = mx.maximum(c0.offset, 0) + if mask is None and c0 is not None and cache_offset == 0: + left_padding = getattr(c0, "left_padding", None) + if ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and left_padding.size >= inputs.shape[0] + ): + positions = mx.arange(inputs.shape[-1])[None, :] + mask = positions >= left_padding[: inputs.shape[0], None] + # Check if mask shape matches input shape (for chunked prefill compatibility) rope_mask = mask if mask is not None and mask.shape[-1] != inputs.shape[-1]: @@ -1113,6 +1222,10 @@ def __call__( position_ids, rope_deltas = self.get_rope_index( inputs, image_grid_thw, video_grid_thw, rope_mask ) + if image_grid_thw is None and video_grid_thw is None: + rope_deltas = mx.zeros( + (batch_size, 1), dtype=rope_deltas.dtype + ) self._rope_deltas = rope_deltas self._position_ids = position_ids else: From ac4f2d59b1a7170c2f6fbfc3bc4c8a0351d93962 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 21 May 2026 12:32:58 +0200 Subject: [PATCH 02/26] Fix Qwen target verify batch drift --- mlx_vlm/models/qwen3_5/language.py | 25 +++++++++++++++++++------ mlx_vlm/tests/test_speculative.py | 24 +++++++++++++++++++++--- 2 files changed, 40 insertions(+), 9 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index d94f4e371..98385f59b 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -220,6 +220,18 @@ def _target_verify_weight(weight: mx.array, x: mx.array) -> Optional[mx.array]: return out.reshape(B, L, O) +def _target_verify_singletons(fn, x: mx.array) -> mx.array: + rows = [] + for row in range(x.shape[0]): + rows.append( + mx.concatenate( + [fn(x[row : row + 1, i : i + 1]) for i in range(x.shape[1])], + axis=1, + ) + ) + return mx.concatenate(rows, axis=0) + + def _target_verify_linear(linear, x: mx.array, target_verify: bool) -> mx.array: if not _use_target_verify_dense(linear, x, target_verify): return linear(x) @@ -229,7 +241,7 @@ def _target_verify_linear(linear, x: mx.array, target_verify: bool) -> mx.array: if out is not None: return out - return mx.concatenate([linear(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + return _target_verify_singletons(linear, x) def _target_verify_linears(linears, x: mx.array, target_verify: bool): @@ -252,10 +264,7 @@ def _target_verify_embedding_as_linear(embedding, x: mx.array, target_verify: bo if out is not None: return out - return mx.concatenate( - [embedding.as_linear(x[:, i : i + 1]) for i in range(x.shape[1])], - axis=1, - ) + return _target_verify_singletons(embedding.as_linear, x) def _extract_row_cache(cache_entry, row: int): @@ -403,7 +412,11 @@ def __call__( values[:, :, : prefix_len + i + 1, :], cache=cache, scale=self.scale, - mask=None, + mask=( + mask[..., i : i + 1, : prefix_len + i + 1] + if isinstance(mask, mx.array) and mask.ndim >= 4 + else None + ), ) for i in range(L) ], diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index bc2d0b6f6..fe87dbfea 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -400,7 +400,16 @@ def test_qwen_target_verify_linear_matches_singleton_dense_gemv(): linear.bias = mx.random.normal((32,)).astype(mx.bfloat16) x = mx.random.normal((3, 4, 16)).astype(mx.bfloat16) - ref = mx.concatenate([linear(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + ref = mx.concatenate( + [ + mx.concatenate( + [linear(x[row : row + 1, i : i + 1]) for i in range(x.shape[1])], + axis=1, + ) + for row in range(x.shape[0]) + ], + axis=0, + ) out = qwen_language._target_verify_linear(linear, x, target_verify=True) mx.eval(ref, out) @@ -424,9 +433,18 @@ def test_qwen_target_verify_small_projection_matches_singleton_dense_gemv(): mx.random.seed(10) linear = nn.Linear(256, 8, bias=False) linear.weight = mx.random.normal((8, 256)).astype(mx.bfloat16) - x = mx.random.normal((1, 3, 256)).astype(mx.bfloat16) + x = mx.random.normal((3, 3, 256)).astype(mx.bfloat16) - ref = mx.concatenate([linear(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + ref = mx.concatenate( + [ + mx.concatenate( + [linear(x[row : row + 1, i : i + 1]) for i in range(x.shape[1])], + axis=1, + ) + for row in range(x.shape[0]) + ], + axis=0, + ) out = qwen_language._target_verify_linear(linear, x, target_verify=True) mx.eval(ref, out) From 256757b3057741112a7430c4428406e038b85c27 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 21 May 2026 14:39:21 +0200 Subject: [PATCH 03/26] Fix Qwen batch parity for padded vision rows --- mlx_vlm/models/qwen3_5/language.py | 86 ++++++++++++++++++++++------- mlx_vlm/models/qwen3_vl/language.py | 43 +++++++++------ mlx_vlm/tests/test_speculative.py | 73 +++++++++++++++++++++++- 3 files changed, 164 insertions(+), 38 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 98385f59b..6289895bf 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -289,6 +289,15 @@ def _extract_row_cache(cache_entry, row: int): return cache_entry +def _is_single_row_batch_cache(cache_entry) -> bool: + left_padding = getattr(cache_entry, "left_padding", None) + return ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and left_padding.size == 1 + ) + + def _gated_delta_update_verify_decode( q: mx.array, k: mx.array, @@ -690,6 +699,35 @@ def __call__( cache = [None] * len(self.layers) fa_cache = cache[self.fa_idx] + if ( + h.shape[0] == 1 + and hidden_sink is None + and gdn_sink is None + and fa_cache is not None + and _is_single_row_batch_cache(fa_cache) + ): + row_cache = [] + for cache_entry in cache: + if cache_entry is None: + row_cache.append(None) + elif _is_single_row_batch_cache(cache_entry): + row_cache.append(_extract_row_cache(cache_entry, 0)) + else: + row_cache.append(cache_entry) + + row_out = self( + inputs, + inputs_embeds=h, + cache=row_cache, + position_ids=position_ids, + ) + for i, cache_entry in enumerate(row_cache): + if cache[i] is None or cache_entry is None: + continue + if hasattr(cache[i].__class__, "merge"): + cache[i] = cache[i].__class__.merge([cache_entry]) + return row_out + if ( h.shape[0] > 1 and hidden_sink is None @@ -719,11 +757,6 @@ def __call__( if position_ids is None else position_ids[:, row : row + 1, pad:] ) - if pad > 0 and row_position_ids is not None: - row_position_ids = mx.broadcast_to( - mx.arange(row_inputs.shape[1])[None, None, :], - (3, 1, row_inputs.shape[1]), - ) current_cache = [] for cache_entry in cache: if cache_entry is None: @@ -1016,21 +1049,20 @@ def get_rope_index( ) image_index, video_index = 0, 0 for i, input_ids in enumerate(total_input_ids): - input_ids = mx.where( - attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids) - ) + row_mask = attention_mask[i].tolist() + input_tokens = [ + token + for token, keep in zip(input_ids.tolist(), row_mask) + if keep == 1 + ] image_nums, video_nums = 0, 0 - vision_start_indices = mx.sum( - mx.where( - input_ids == vision_start_token_id, - mx.arange(input_ids.shape[0]), - mx.zeros_like(input_ids), - ) - ) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum().item() - video_nums = (vision_tokens == video_token_id).sum().item() - input_tokens = input_ids.tolist() + vision_tokens = [ + input_tokens[idx + 1] + for idx, token in enumerate(input_tokens[:-1]) + if token == vision_start_token_id + ] + image_nums = sum(token == image_token_id for token in vision_tokens) + video_nums = sum(token == video_token_id for token in vision_tokens) llm_pos_ids_list: list = [] st = 0 remain_images, remain_videos = image_nums, video_nums @@ -1112,7 +1144,19 @@ def get_rope_index( llm_pos_ids_list.append(t_index + st_idx) llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - mask = mx.array(attention_mask[i] == 1) + compact_max_position = llm_positions.max() + padded_positions = [[1] * total_input_ids.shape[1] for _ in range(3)] + compact_positions = llm_positions.tolist() + compact_idx = 0 + for col, keep in enumerate(row_mask): + if keep == 1: + for dim in range(3): + padded_positions[dim][col] = compact_positions[dim][ + compact_idx + ] + compact_idx += 1 + llm_positions = mx.array(padded_positions, dtype=position_ids.dtype) + mask = mx.array(row_mask, dtype=mx.bool_) expanded_mask = mx.expand_dims(mask, axis=0) expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0])) expanded_positions = mx.expand_dims(llm_positions, axis=1) @@ -1129,7 +1173,7 @@ def get_rope_index( ) position_ids = updated_position_ids mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) + compact_max_position + 1 - len(input_tokens) ) mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1) return position_ids, mrope_position_deltas diff --git a/mlx_vlm/models/qwen3_vl/language.py b/mlx_vlm/models/qwen3_vl/language.py index f6b6486e6..81dc442ff 100644 --- a/mlx_vlm/models/qwen3_vl/language.py +++ b/mlx_vlm/models/qwen3_vl/language.py @@ -357,21 +357,20 @@ def get_rope_index( ) image_index, video_index = 0, 0 for i, input_ids in enumerate(total_input_ids): - input_ids = mx.where( - attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids) - ) + row_mask = attention_mask[i].tolist() + input_tokens = [ + token + for token, keep in zip(input_ids.tolist(), row_mask) + if keep == 1 + ] image_nums, video_nums = 0, 0 - vision_start_indices = mx.sum( - mx.where( - input_ids == vision_start_token_id, - mx.arange(input_ids.shape[0]), - mx.zeros_like(input_ids), - ) - ) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum().item() - video_nums = (vision_tokens == video_token_id).sum().item() - input_tokens = input_ids.tolist() + vision_tokens = [ + input_tokens[idx + 1] + for idx, token in enumerate(input_tokens[:-1]) + if token == vision_start_token_id + ] + image_nums = sum(token == image_token_id for token in vision_tokens) + video_nums = sum(token == video_token_id for token in vision_tokens) llm_pos_ids_list: list = [] st = 0 remain_images, remain_videos = image_nums, video_nums @@ -463,7 +462,19 @@ def get_rope_index( llm_pos_ids_list.append(t_index + st_idx) llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - mask = mx.array(attention_mask[i] == 1) + compact_max_position = llm_positions.max() + padded_positions = [[1] * total_input_ids.shape[1] for _ in range(3)] + compact_positions = llm_positions.tolist() + compact_idx = 0 + for col, keep in enumerate(row_mask): + if keep == 1: + for dim in range(3): + padded_positions[dim][col] = compact_positions[dim][ + compact_idx + ] + compact_idx += 1 + llm_positions = mx.array(padded_positions, dtype=position_ids.dtype) + mask = mx.array(row_mask, dtype=mx.bool_) expanded_mask = mx.expand_dims(mask, axis=0) expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0])) expanded_positions = mx.expand_dims(llm_positions, axis=1) @@ -480,7 +491,7 @@ def get_rope_index( ) position_ids = updated_position_ids mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) + compact_max_position + 1 - len(input_tokens) ) mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1) return position_ids, mrope_position_deltas diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index fe87dbfea..2e6c11ab8 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -16,7 +16,13 @@ import mlx_vlm.models.qwen3_5.language as qwen_language import mlx_vlm.speculative.mtp as mtp_utils -from mlx_vlm.models.cache import ArraysCache, BufferedRotatingKVCache, RotatingKVCache +from mlx_vlm.models.cache import ( + ArraysCache, + BatchKVCache, + BufferedRotatingKVCache, + KVCache, + RotatingKVCache, +) from mlx_vlm.speculative.common import _SpeculativeSamplerRNG from mlx_vlm.speculative.drafters import ( DEFAULT_DRAFTER_KIND, @@ -524,6 +530,71 @@ def test_qwen_gdn_verify_conv_matches_singleton_windows(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen3_5_rope_index_ignores_left_padding_for_vision_rows(): + model_config = qwen_language.ModelConfig( + text_config=_tiny_qwen3_5_text_config(), + vision_config=SimpleNamespace(spatial_merge_size=2), + model_type="qwen3_5", + image_token_id=101, + video_token_id=102, + image_token_index=101, + video_token_index=102, + vision_start_token_id=100, + vision_end_token_id=103, + vocab_size=128, + ) + lm = qwen_language.LanguageModel.__new__(qwen_language.LanguageModel) + lm.config = model_config + + singleton_ids = mx.array([[10, 100, 101, 11, 12]], dtype=mx.int32) + padded_ids = mx.array( + [[0, 10, 100, 101, 11, 12], [20, 21, 22, 23, 24, 25]], + dtype=mx.int32, + ) + attention_mask = mx.array( + [[0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], dtype=mx.int32 + ) + image_grid_thw = mx.array([[1, 2, 2]], dtype=mx.int32) + + singleton_pos, singleton_delta = lm.get_rope_index(singleton_ids, image_grid_thw) + padded_pos, padded_delta = lm.get_rope_index( + padded_ids, image_grid_thw, attention_mask=attention_mask + ) + mx.eval(singleton_pos, padded_pos, singleton_delta, padded_delta) + + assert padded_pos[:, 0, 1:].tolist() == singleton_pos[:, 0, :].tolist() + assert padded_delta[0, 0].item() == singleton_delta[0, 0].item() + assert padded_delta[1, 0].item() == 0 + + +def test_qwen3_5_single_row_batch_cache_matches_singleton_cache(): + text_config = _tiny_qwen3_5_text_config() + text_config.num_hidden_layers = 2 + text_config.full_attention_interval = 2 + model = qwen_language.Qwen3_5Model(text_config) + + singleton_cache = [ArraysCache(size=2), KVCache()] + batch_arrays = ArraysCache(size=2) + batch_arrays.left_padding = mx.array([0], dtype=mx.int32) + batch_cache = [batch_arrays, BatchKVCache([0])] + + prompt = mx.array([[1, 2, 3]], dtype=mx.int32) + singleton_prompt = model(prompt, cache=singleton_cache) + batch_prompt = model(prompt, cache=batch_cache) + mx.eval(singleton_prompt, batch_prompt) + + assert bool(mx.array_equal(singleton_prompt, batch_prompt).item()) + assert isinstance(batch_cache[1], BatchKVCache) + + decode = mx.array([[4]], dtype=mx.int32) + singleton_decode = model(decode, cache=singleton_cache) + batch_decode = model(decode, cache=batch_cache) + mx.eval(singleton_decode, batch_decode) + + assert bool(mx.array_equal(singleton_decode, batch_decode).item()) + assert isinstance(batch_cache[1], BatchKVCache) + + def test_speculative_walk_accepts_until_first_mismatch(): accepted, new_tokens = _speculative_walk( mx.array([[11, 12, 13]], dtype=mx.int32), From f8f570c582f151b1a695871fef5f89049c836e6f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Thu, 21 May 2026 22:43:50 +0200 Subject: [PATCH 04/26] Fix exact batched Qwen MTP verification --- mlx_vlm/models/qwen3_5/language.py | 95 ++++++++++++++++++++++++++---- mlx_vlm/server/generation.py | 8 ++- mlx_vlm/speculative/mtp.py | 76 +++++++++++++++++++++--- mlx_vlm/tests/test_server.py | 26 ++++++++ mlx_vlm/tests/test_speculative.py | 41 +++++++++++++ 5 files changed, 227 insertions(+), 19 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 6289895bf..1c1583782 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -298,6 +298,20 @@ def _is_single_row_batch_cache(cache_entry) -> bool: ) +def _pad_row_time(x: mx.array, pad: int, target_length: int) -> mx.array: + if pad <= 0: + return x + if x.shape[1] >= target_length: + return x + return mx.concatenate( + [ + mx.zeros((x.shape[0], pad, *x.shape[2:]), dtype=x.dtype), + x, + ], + axis=1, + ) + + def _gated_delta_update_verify_decode( q: mx.array, k: mx.array, @@ -324,6 +338,67 @@ def _gated_delta_update_verify_decode( ) +def _target_verify_left_padded_attention( + queries: mx.array, + keys: mx.array, + values: mx.array, + *, + cache, + scale: float, + mask: Optional[mx.array], +) -> Optional[mx.array]: + if hasattr(cache, "bits") or queries.ndim != 4 or keys.ndim != 4: + return None + + left_padding = getattr(cache, "left_padding", None) + if not ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and int(left_padding.max().item()) > 0 + ): + return None + + row_outputs = {} + pads = [int(p) for p in left_padding.tolist()] + for pad in sorted(set(pads)): + rows = [i for i, row_pad in enumerate(pads) if row_pad == pad] + row_idx = mx.array(rows, dtype=mx.int32) + group_queries = mx.take(queries, row_idx, axis=0) + group_keys = mx.take(keys, row_idx, axis=0)[:, :, pad:, :] + group_values = mx.take(values, row_idx, axis=0)[:, :, pad:, :] + + if group_queries.shape[2] > 1: + prefix_len = group_keys.shape[-2] - group_queries.shape[2] + group_output = mx.concatenate( + [ + scaled_dot_product_attention( + group_queries[:, :, i : i + 1, :], + group_keys[:, :, : prefix_len + i + 1, :], + group_values[:, :, : prefix_len + i + 1, :], + cache=None, + scale=scale, + mask=None, + ) + for i in range(group_queries.shape[2]) + ], + axis=2, + ) + else: + group_output = scaled_dot_product_attention( + group_queries, + group_keys, + group_values, + cache=None, + scale=scale, + mask=None, + ) + + for j, row in enumerate(rows): + row_outputs[row] = group_output[j : j + 1] + + return mx.concatenate([row_outputs[i] for i in range(queries.shape[0])], axis=0) + + class Qwen3_5Attention(nn.Module): def __init__(self, args: TextConfig): super().__init__() @@ -372,7 +447,6 @@ def __call__( target_verify: bool = False, ) -> mx.array: B, L, D = x.shape - q_proj_output, keys, values = _target_verify_linears( (self.q_proj, self.k_proj, self.v_proj), x, target_verify ) @@ -412,6 +486,13 @@ def __call__( keys, values = cache.update_and_fetch(keys, values) if target_verify and L > 1: + output = _target_verify_left_padded_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + else: + output = None + + if output is None and target_verify and L > 1: prefix_len = keys.shape[-2] - L output = mx.concatenate( [ @@ -431,7 +512,7 @@ def __call__( ], axis=2, ) - else: + elif output is None: output = scaled_dot_product_attention( queries, keys, values, cache=cache, scale=self.scale, mask=mask ) @@ -771,15 +852,7 @@ def __call__( position_ids=row_position_ids, ) if pad > 0: - row_out = mx.concatenate( - [ - mx.zeros( - (1, pad, row_out.shape[-1]), dtype=row_out.dtype - ), - row_out, - ], - axis=1, - ) + row_out = _pad_row_time(row_out, pad, h.shape[1]) row_outputs.append(row_out) for i, cache_entry in enumerate(current_cache): row_caches[i].append(cache_entry) diff --git a/mlx_vlm/server/generation.py b/mlx_vlm/server/generation.py index 4ebf960ac..a483cbdd7 100644 --- a/mlx_vlm/server/generation.py +++ b/mlx_vlm/server/generation.py @@ -91,6 +91,12 @@ def get_speculative_batch_coalesce_s(): return DEFAULT_SPECULATIVE_BATCH_COALESCE_MS / 1000.0 +def _sample_last_token(logits: mx.array, sampler: Callable[[mx.array], mx.array]): + logits = logits[:, -1, :] + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + return sampler(logprobs) + + def get_server_enable_thinking(): raw = os.environ.get("MLX_VLM_ENABLE_THINKING") if raw is None: @@ -993,7 +999,7 @@ def _run_speculative(self): out = lm(input_mx, cache=prompt_cache, **lm_call_kwargs) hidden = speculative_hidden_state(draft_kind, out) shared_kv_states = out.shared_kv_states if is_mtp else None - first_bonus = sampler(out.logits[:, -1:]).squeeze(-1) + first_bonus = _sample_last_token(out.logits, sampler) mx.eval(first_bonus, hidden, out.logits) prompt_elapsed = time.perf_counter() - prompt_started for uid in uids: diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index ccaafb6b9..30317f536 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -254,6 +254,53 @@ def _speculative_walk_batch_deferred_greedy( return accepted, new_tokens +def _speculative_walk_batch_deferred_uniform( + lm: nn.Module, + target_hidden: mx.array, + draft_tokens: mx.array, + sampler: Callable[[mx.array], mx.array], + budgets: List[int], +) -> Tuple[List[int], List[List[int]]]: + """Deferred walk for models whose batched drafter cache needs lockstep rows. + + Stop at the first rejection in any row. This avoids consuming target sampler + RNG for verifier positions that will be thrown away by the uniform clamp. + """ + B = draft_tokens.shape[0] + n_draft = draft_tokens.shape[1] + draft_lists = [[int(token) for token in row] for row in draft_tokens.tolist()] + budgets = [int(budget) for budget in budgets] + + accepted = 0 + for pos in range(n_draft + 1): + with mx.stream(generation_stream): + logits = lm.speculative_logits_from_hidden( + target_hidden[:, pos : pos + 1, :] + ) + if logits.ndim == 3 and logits.shape[1] == 1: + logits = logits[:, 0, :] + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + target_tokens = sampler(logprobs) + mx.eval(target_tokens) + target_list = [int(token) for token in target_tokens.reshape(-1).tolist()] + + if pos < n_draft and all( + target_list[row] == draft_lists[row][pos] for row in range(B) + ): + accepted += 1 + continue + + new_tokens: List[List[int]] = [] + for row, budget in enumerate(budgets): + row_tokens = draft_lists[row][:accepted] + if len(row_tokens) < budget: + row_tokens.append(target_list[row]) + new_tokens.append(row_tokens[:budget]) + return [accepted] * B, new_tokens + + return [accepted] * B, [draft_lists[row][: budgets[row]] for row in range(B)] + + def _mtp_acceptance_walk( lm: nn.Module, verify: _MTPVerifyResult, @@ -740,6 +787,7 @@ def _mtp_rounds_batch( verify_input, prompt_cache, sampler, + sample_target_tokens=greedy_sampling, ) hidden_full = verify.hidden # [B_active, bs, H] @@ -766,13 +814,27 @@ def _mtp_rounds_batch( ) else: sampler_rng.target_eval(hidden_full) - accepted_list, new_tokens_list = _speculative_walk_batch_deferred_greedy( - lm, - hidden_full, - draft_tokens, - sampler, - budgets, - ) + if ( + n_active > 1 + and getattr(draft_model, "requires_uniform_batch_acceptance", False) + ): + accepted_list, new_tokens_list = ( + _speculative_walk_batch_deferred_uniform( + lm, + hidden_full, + draft_tokens, + sampler, + budgets, + ) + ) + else: + accepted_list, new_tokens_list = _speculative_walk_batch_deferred_greedy( + lm, + hidden_full, + draft_tokens, + sampler, + budgets, + ) sampler_rng.target_sampled(sync_draft=True) # Keep the adaptive block-size history on a per-round basis so # batched MTP reacts like the singleton loop instead of letting diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py index 0b459aad7..8ff067ad6 100644 --- a/mlx_vlm/tests/test_server.py +++ b/mlx_vlm/tests/test_server.py @@ -54,6 +54,32 @@ def test_speculative_server_dispatches_mtp_batch_loop(): ) +def test_speculative_server_samples_first_bonus_like_decode_step(): + seen = {} + logits = mx.array( + [ + [[1.0, 2.0, 3.0]], + [[4.0, 1.0, 0.0]], + ], + dtype=mx.float32, + ) + + def sampler(logprobs): + seen["shape"] = logprobs.shape + seen["values"] = logprobs + return mx.argmax(logprobs, axis=-1) + + tokens = server_generation._sample_last_token(logits, sampler) + expected_logprobs = logits[:, -1, :] - mx.logsumexp( + logits[:, -1, :], axis=-1, keepdims=True + ) + mx.eval(tokens, seen["values"], expected_logprobs) + + assert seen["shape"] == (2, 3) + assert tokens.tolist() == [2, 0] + assert bool(mx.allclose(seen["values"], expected_logprobs).item()) + + def test_speculative_server_dispatches_eagle3_batch_loop(): assert ( speculative_utils.get_speculative_rounds_batch("eagle3") diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 2e6c11ab8..1763f474d 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -867,6 +867,47 @@ def as_linear(self, hidden): assert fake_head.calls == 3 +def test_speculative_walk_batch_deferred_uniform_stops_at_batch_rejection(): + class FakeEmbed: + def __init__(self): + self.calls = 0 + + def as_linear(self, hidden): + self.calls += 1 + return hidden + + fake_head = FakeEmbed() + lm = SimpleNamespace(speculative_logits_from_hidden=fake_head.as_linear) + target_hidden = mx.array( + [ + [ + [0, 0, 9, 0], + [0, 0, 0, 9], + [0, 9, 0, 0], + ], + [ + [0, 9, 0, 0], + [9, 0, 0, 0], + [0, 0, 9, 0], + ], + ], + dtype=mx.float32, + ) + draft_tokens = mx.array([[2, 3], [0, 2]], dtype=mx.int32) + + accepted, new_tokens = mtp_utils._speculative_walk_batch_deferred_uniform( + lm, + target_hidden, + draft_tokens, + lambda logits: mx.argmax(logits, axis=-1), + budgets=[3, 3], + ) + + assert accepted == [0, 0] + assert new_tokens == [[2], [1]] + assert fake_head.calls == 1 + + def test_mtp_draft_hidden_uses_model_hook(): hidden = mx.array([[[1.0, 2.0]]], dtype=mx.float32) lm = SimpleNamespace(speculative_draft_hidden=lambda h: h * 2) From 7722c96618aa79a6078ac80b7ef22fcf5fa5125f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 22 May 2026 00:16:36 +0200 Subject: [PATCH 05/26] Fix ragged Qwen3.5 MTP batch parity --- mlx_vlm/models/qwen3_5/language.py | 69 ++++++- .../drafters/qwen3_5_mtp/qwen3_5_mtp.py | 168 +++++++++++++----- mlx_vlm/speculative/mtp.py | 7 +- mlx_vlm/tests/test_speculative.py | 91 ++++++++++ 4 files changed, 278 insertions(+), 57 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 1c1583782..82fb1f6ef 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -8,7 +8,6 @@ from ..base import ( LanguageModelOutput, create_attention_mask, - create_ssm_mask, scaled_dot_product_attention, ) from ..cache import ArraysCache, KVCache @@ -197,7 +196,7 @@ def _use_target_verify_dense(linear, x: mx.array, target_verify: bool) -> bool: target_verify and x.ndim == 3 and x.shape[1] > 1 - and isinstance(linear, nn.Linear) + and isinstance(linear, (nn.Linear, nn.QuantizedLinear)) ) @@ -220,6 +219,10 @@ def _target_verify_weight(weight: mx.array, x: mx.array) -> Optional[mx.array]: return out.reshape(B, L, O) +def _target_verify_timewise(fn, x: mx.array) -> mx.array: + return mx.concatenate([fn(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + + def _target_verify_singletons(fn, x: mx.array) -> mx.array: rows = [] for row in range(x.shape[0]): @@ -236,7 +239,10 @@ def _target_verify_linear(linear, x: mx.array, target_verify: bool) -> mx.array: if not _use_target_verify_dense(linear, x, target_verify): return linear(x) - if "bias" not in linear: + if isinstance(linear, nn.QuantizedLinear): + return _target_verify_timewise(linear, x) + + if isinstance(linear, nn.Linear) and "bias" not in linear: out = _target_verify_weight(linear.weight, x) if out is not None: return out @@ -249,7 +255,9 @@ def _target_verify_linears(linears, x: mx.array, target_verify: bool): target_verify and x.ndim == 3 and x.shape[1] > 1 - and all(isinstance(linear, nn.Linear) for linear in linears) + and all( + isinstance(linear, (nn.Linear, nn.QuantizedLinear)) for linear in linears + ) ): return tuple(linear(x) for linear in linears) @@ -264,7 +272,7 @@ def _target_verify_embedding_as_linear(embedding, x: mx.array, target_verify: bo if out is not None: return out - return _target_verify_singletons(embedding.as_linear, x) + return _target_verify_timewise(embedding.as_linear, x) def _extract_row_cache(cache_entry, row: int): @@ -312,6 +320,21 @@ def _pad_row_time(x: mx.array, pad: int, target_length: int) -> mx.array: ) +def _create_qwen3_5_ssm_mask(h: mx.array, cache): + if not (cache and hasattr(cache, "make_mask")): + return None + + left_padding = getattr(cache, "left_padding", None) + if isinstance(left_padding, mx.array) and int(left_padding.max().item()) <= 0: + return None + + lengths = getattr(cache, "lengths", None) + if isinstance(lengths, mx.array) and int(lengths.min().item()) >= h.shape[1]: + return None + + return cache.make_mask(h.shape[1]) + + def _gated_delta_update_verify_decode( q: mx.array, k: mx.array, @@ -476,7 +499,13 @@ def __call__( cos, sin = self.rotary_emb(values, position_ids) if mask is not None and isinstance(mask, mx.array): - if isinstance(kv_seq_len, mx.array): + if ( + cache is not None + and hasattr(cache, "_idx") + and hasattr(cache, "left_padding") + ): + kv_seq_len = int(cache._idx) + L + elif isinstance(kv_seq_len, mx.array): kv_seq_len = kv_seq_len.max().item() mask = mask[..., : int(kv_seq_len)] @@ -865,7 +894,7 @@ def __call__( return mx.concatenate(row_outputs, axis=0) fa_mask = create_attention_mask(h, cache[self.fa_idx]) - ssm_mask = create_ssm_mask(h, cache[self.ssm_idx]) + ssm_mask = _create_qwen3_5_ssm_mask(h, cache[self.ssm_idx]) capture_set = set(capture_layer_ids) if capture_layer_ids else set() for i, (layer, c) in enumerate(zip(self.layers, cache)): @@ -921,7 +950,31 @@ def rollback_speculative_cache( if c.is_trimmable(): if trim > 0: c.trim(trim) - if is_batch and hasattr(c, "_idx") and c.keys is not None and max_a > 0: + right_trimmed = False + if is_batch and max_a > 0: + extra_trim = max_a - accepted + if int(extra_trim.max().item()) > 0: + prepare = getattr(c, "prepare", None) + finalize = getattr(c, "finalize", None) + if ( + c.keys is not None + and callable(prepare) + and callable(finalize) + ): + prepare( + right_padding=[ + int(extra) for extra in extra_trim.tolist() + ] + ) + finalize() + right_trimmed = True + if ( + is_batch + and not right_trimmed + and hasattr(c, "_idx") + and c.keys is not None + and max_a > 0 + ): kv_len = c._idx ve = valid_ends.tolist() verify_start = kv_len - n diff --git a/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py b/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py index 8b9792cf9..af9fca18f 100644 --- a/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py +++ b/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py @@ -5,7 +5,7 @@ import mlx.nn as nn from ....models.base import create_attention_mask -from ....models.cache import KVCache +from ....models.cache import BatchKVCache, KVCache from ....models.qwen3_5.language import Qwen3_5DecoderLayer from ....models.qwen3_5_moe.language import Qwen3_5MoeDecoderLayer from .config import Qwen3_5MTPConfig @@ -14,7 +14,8 @@ class Qwen3_5MTPDraftModel(nn.Module): supports_greedy_draft_argmax = True prefer_requested_block_size = True - requires_uniform_batch_acceptance = True + requires_uniform_batch_acceptance = False + supports_ragged_batch_acceptance = True def __init__(self, config: Qwen3_5MTPConfig): super().__init__() @@ -90,15 +91,19 @@ def bind(self, target_model) -> "Qwen3_5MTPDraftModel": ) return self - def make_cache(self) -> List[KVCache]: + def make_cache(self, left_padding: Optional[List[int]] = None) -> List[KVCache]: + if left_padding is not None: + return [BatchKVCache(left_padding) for _ in self.layers] return [KVCache() for _ in self.layers] - def reset(self, target_model) -> List[KVCache]: + def reset( + self, target_model, left_padding: Optional[List[int]] = None + ) -> List[KVCache]: self.bind(target_model) self.accept_lens = [] self.draft_lens = [] self._draft_round = 0 - self._cache = self.make_cache() + self._cache = self.make_cache(left_padding) self._seed_token = None self._seed_hidden = None self._next_position = 0 @@ -120,7 +125,8 @@ def set_shared_kv( position = kv_valid_len self._kv_valid_len = kv_valid_len self._position = position - if not self._cache or self._cache[0].offset == 0: + cache_empty = not self._cache or all(cache.empty() for cache in self._cache) + if cache_empty: self._next_position = kv_valid_len def _draft_start_position(self): @@ -266,12 +272,7 @@ def accept_verified_tokens_batch( token_dtype: mx.Dtype = mx.int32, greedy: bool = False, ) -> None: - """Extend the Qwen MTP drafter cache after a batched verify step. - - Qwen's native MTP drafter uses its own recurrent cache. For now the - batched path is restricted to lockstep acceptance so this cache keeps a - single scalar offset, matching normal decode's aligned batch invariant. - """ + """Extend the Qwen MTP drafter cache after a batched verify step.""" if len(accepted) <= 1: self.accept_verified_tokens( verify_hidden, @@ -284,43 +285,113 @@ def accept_verified_tokens_batch( ) return - accepted_set = {int(a) for a in accepted} - if len(accepted_set) != 1: - raise ValueError( - "Qwen MTP batched cache update requires uniform acceptance." - ) - accepted_i = accepted_set.pop() - - keep_appended = min(accepted_i, self._round_appended) - trim = self._round_appended - keep_appended - if trim > 0: - for cache in self._cache: - cache.trim(trim) - self._next_position = ( - self._next_position - trim - if isinstance(self._next_position, int) - else self._next_position - trim + accepted = [int(a) for a in accepted] + keep_appended = [min(a, self._round_appended) for a in accepted] + trims = [self._round_appended - keep for keep in keep_appended] + if any(trims): + if len(set(trims)) == 1: + trim = trims[0] + for cache in self._cache: + cache.trim(trim) + self._next_position = ( + self._next_position - trim + if isinstance(self._next_position, int) + else self._next_position - trim + ) + elif all(hasattr(cache, "prepare") for cache in self._cache) and all( + hasattr(cache, "finalize") for cache in self._cache + ): + for cache in self._cache: + cache.prepare(right_padding=trims) + for cache in self._cache: + cache.finalize() + if isinstance(self._next_position, mx.array): + self._next_position = self._next_position - mx.array( + trims, dtype=mx.int32 + ) + else: + self._next_position = mx.full( + (len(accepted),), self._next_position, dtype=mx.int32 + ) - mx.array(trims, dtype=mx.int32) + else: + raise ValueError( + "Qwen MTP ragged batch acceptance requires a batch-aware cache." + ) + + draft_rows = draft_tokens.tolist() + row_tokens = [] + row_hiddens = [] + for row, accepted_i in enumerate(accepted): + tokens_i = [] + hiddens_i = [] + for draft_idx in range(keep_appended[row], accepted_i): + tokens_i.append(int(draft_rows[row][draft_idx])) + hiddens_i.append( + verify_hidden[row : row + 1, draft_idx : draft_idx + 1, :] + ) + if new_tokens[row]: + tokens_i.append(int(new_tokens[row][-1])) + hiddens_i.append( + verify_hidden[row : row + 1, accepted_i : accepted_i + 1, :] + ) + row_tokens.append(tokens_i) + row_hiddens.append(hiddens_i) + + lengths = [len(tokens_i) for tokens_i in row_tokens] + max_len = max(lengths) if lengths else 0 + if max_len > 0: + token_data = [] + hidden_rows = [] + for tokens_i, hiddens_i in zip(row_tokens, row_hiddens): + token_data.extend(tokens_i) + pad = max_len - len(tokens_i) + if pad: + token_data.extend([0] * pad) + if hiddens_i: + hidden_row = mx.concatenate(hiddens_i, axis=1) + else: + hidden_row = mx.zeros( + (1, 0, verify_hidden.shape[-1]), dtype=verify_hidden.dtype + ) + if pad: + hidden_row = mx.concatenate( + [ + hidden_row, + mx.zeros( + (1, pad, verify_hidden.shape[-1]), + dtype=verify_hidden.dtype, + ), + ], + axis=1, + ) + hidden_rows.append(hidden_row) + + tokens = mx.array(token_data, dtype=token_dtype).reshape( + len(row_tokens), max_len ) + hiddens = mx.concatenate(hidden_rows, axis=0) + right_padding = [max_len - length for length in lengths] + if any(right_padding): + for cache in self._cache: + prepare = getattr(cache, "prepare", None) + if callable(prepare): + prepare(right_padding=right_padding, lengths=lengths) - token_chunks = [] - hidden_chunks = [] - for draft_idx in range(keep_appended, accepted_i): - token_chunks.append(draft_tokens[:, draft_idx : draft_idx + 1]) - hidden_chunks.append(verify_hidden[:, draft_idx : draft_idx + 1, :]) - - if all(new_tokens): - bonus = mx.array( - [[int(row_tokens[-1])] for row_tokens in new_tokens], - dtype=token_dtype, - ) - token_chunks.append(bonus) - hidden_chunks.append(verify_hidden[:, accepted_i : accepted_i + 1, :]) - - if token_chunks: - tokens = mx.concatenate(token_chunks, axis=1).astype(token_dtype) - hiddens = mx.concatenate(hidden_chunks, axis=1) h = self._forward_tokens(tokens, hiddens, token_dtype) - self._set_seed_from_hidden(h[:, -1:, :], sampler, greedy) + + if any(right_padding): + for cache in self._cache: + finalize = getattr(cache, "finalize", None) + if callable(finalize): + finalize() + if isinstance(self._next_position, mx.array): + self._next_position = self._next_position - mx.array( + right_padding, dtype=mx.int32 + ) + + last_idx = mx.array([length - 1 for length in lengths], dtype=mx.int32) + last_hidden = mx.take_along_axis(h, last_idx[:, None, None], axis=1) + self._set_seed_from_hidden(last_hidden, sampler, greedy) self._round_appended = 0 def filter_batch(self, keep) -> None: @@ -329,7 +400,10 @@ def filter_batch(self, keep) -> None: keep = mx.array(keep, dtype=mx.int32) for cache in self._cache: - if cache.keys is not None: + cache_filter = getattr(cache, "filter", None) + if callable(cache_filter): + cache_filter(keep) + elif cache.keys is not None: cache.keys = cache.keys[keep] cache.values = cache.values[keep] diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index 30317f536..ae1a12d4e 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -719,7 +719,10 @@ def _mtp_rounds_batch( B = first_bonus.shape[0] block_total = _dflash_block_total(draft_model, draft_block_size) configured_block_total = int(getattr(draft_model.config, "block_size", block_total)) - draft_model.reset(model) + if getattr(draft_model, "supports_ragged_batch_acceptance", False): + draft_model.reset(model, left_padding=[0] * B) + else: + draft_model.reset(model) sampler_rng = _SpeculativeSamplerRNG(draft_model, enabled=not greedy_sampling) # First-round hidden: prefill output may have shape [B, L, H]; reduce @@ -900,7 +903,7 @@ def _mtp_rounds_batch( # Rollback target cache (uniform trim by ``bs - max_a - 1`` plus # per-row tail-zero on rows that accepted less). - if max_a < bs - 1: + if any(a < bs - 1 for a in accepted_list): with mx.stream(generation_stream): lm.rollback_speculative_cache( prompt_cache, verify.gdn_states, accepted_arr, bs diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 1763f474d..e0ba0b4be 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -2017,6 +2017,70 @@ def test_qwen3_5_mtp_batch_accept_updates_uniform_cache(): assert drafter._next_position == 5 +def test_qwen3_5_mtp_batch_accept_updates_ragged_cache(): + text_config = _tiny_qwen3_5_text_config() + text_config.mtp_num_hidden_layers = 1 + drafter = Qwen3_5MTPDraftModel( + Qwen3_5MTPConfig(text_config=text_config, block_size=3) + ) + target = SimpleNamespace( + language_model=SimpleNamespace( + model=SimpleNamespace(embed_tokens=nn.Embedding(32, 16)) + ) + ) + drafter.reset(target, left_padding=[0, 0]) + drafter.set_shared_kv( + {}, + kv_offset=4, + position=mx.array([4, 4], dtype=mx.int32), + kv_valid_len=mx.array([4, 4], dtype=mx.int32), + ) + hidden = mx.zeros((2, 1, 16), dtype=mx.float32) + draft_tokens = drafter.draft_block( + mx.array([7, 8], dtype=mx.int32), + hidden, + None, + 3, + lambda logits: mx.argmax(logits, axis=-1), + mx.int32, + greedy=True, + ) + verify_hidden = mx.zeros((2, 3, 16), dtype=mx.float32) + drafter.accept_verified_tokens_batch( + verify_hidden, + draft_tokens, + accepted=[1, 0], + new_tokens=[[int(draft_tokens[0, 0].item()), 5], [6]], + sampler=lambda logits: mx.argmax(logits, axis=-1), + token_dtype=mx.int32, + greedy=True, + ) + + mx.eval(drafter._seed_token, drafter._cache[0].offset) + assert drafter._seed_token.shape == (2, 1) + assert drafter._round_appended == 0 + assert drafter._cache[0]._idx == 3 + assert drafter._cache[0].offset.tolist() == [2, 1] + assert drafter._cache[0].left_padding.tolist() == [1, 2] + assert drafter._next_position.tolist() == [6, 5] + + +def test_qwen3_5_rollback_speculative_cache_trims_batch_rows_ragged(): + text_config = _tiny_qwen3_5_text_config() + language = qwen_language.LanguageModel(args=text_config) + cache = qwen_language.KVCache.merge([qwen_language.KVCache(), qwen_language.KVCache()]) + keys = mx.arange(2 * 1 * 5 * 4, dtype=mx.float32).reshape(2, 1, 5, 4) + values = keys + 100 + cache.update_and_fetch(keys, values) + + language.rollback_speculative_cache([cache], [], mx.array([2, 0]), block_size=3) + + mx.eval(cache.keys, cache.values, cache.offset, cache.left_padding) + assert cache._idx == 5 + assert cache.offset.tolist() == [5, 3] + assert cache.left_padding.tolist() == [0, 2] + + def test_qwen3_5_mtp_filter_batch_keeps_drafter_state_aligned(): text_config = _tiny_qwen3_5_text_config() text_config.mtp_num_hidden_layers = 1 @@ -2063,6 +2127,33 @@ def test_qwen3_5_mtp_filter_batch_keeps_drafter_state_aligned(): assert drafter._next_position.tolist() == [6] +def test_qwen3_5_mtp_filter_batch_keeps_batch_cache_padding_aligned(): + text_config = _tiny_qwen3_5_text_config() + text_config.mtp_num_hidden_layers = 1 + drafter = Qwen3_5MTPDraftModel( + Qwen3_5MTPConfig(text_config=text_config, block_size=3) + ) + target = SimpleNamespace( + language_model=SimpleNamespace( + model=SimpleNamespace(embed_tokens=nn.Embedding(32, 16)) + ) + ) + drafter.reset(target, left_padding=[0, 1, 2]) + drafter._cache[0].update_and_fetch( + mx.zeros((3, 1, 2, 8), dtype=mx.float32), + mx.zeros((3, 1, 2, 8), dtype=mx.float32), + ) + drafter._next_position = mx.array([4, 5, 6], dtype=mx.int32) + + drafter.filter_batch(mx.array([0, 2], dtype=mx.int32)) + + mx.eval(drafter._cache[0].left_padding, drafter._cache[0].offset) + assert drafter._cache[0].keys.shape[0] == 2 + assert drafter._cache[0].left_padding.tolist() == [0, 2] + assert drafter._cache[0].offset.tolist() == [2, 0] + assert drafter._next_position.tolist() == [4, 6] + + def test_qwen3_5_mtp_sanitize_strips_prefix_and_offsets_norms(): weights = { "mtp.fc.weight": mx.ones((2, 4)), From a40c23bbe54435e04003ff7b1aa931a2d8dbc17b Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 22 May 2026 00:48:44 +0200 Subject: [PATCH 06/26] Keep Qwen MTP ragged greedy and uniform sampled parity Use uniform deferred verification for non-greedy batched MTP so target sampling consumes RNG in the same lockstep order as no-drafter batches. Keep ragged acceptance enabled for greedy decoding, where argmax has no RNG-order drift and preserves the faster batch path. --- mlx_vlm/speculative/mtp.py | 25 ++++++++++++++++++++++--- mlx_vlm/tests/test_speculative.py | 18 ++++++++++++++++++ 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index ae1a12d4e..e551fafed 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -301,6 +301,24 @@ def _speculative_walk_batch_deferred_uniform( return [accepted] * B, [draft_lists[row][: budgets[row]] for row in range(B)] +def _mtp_use_uniform_deferred_walk( + draft_model: nn.Module, + *, + n_active: int, + greedy_sampling: bool, +) -> bool: + if n_active <= 1: + return False + + if getattr(draft_model, "requires_uniform_batch_acceptance", False): + return True + + # Non-greedy sampling uses a single target RNG stream. If rows accept a + # ragged number of tokens, the verifier would draw future samples for some + # rows before no-drafter has drawn the next lockstep batch sample. + return not greedy_sampling + + def _mtp_acceptance_walk( lm: nn.Module, verify: _MTPVerifyResult, @@ -817,9 +835,10 @@ def _mtp_rounds_batch( ) else: sampler_rng.target_eval(hidden_full) - if ( - n_active > 1 - and getattr(draft_model, "requires_uniform_batch_acceptance", False) + if _mtp_use_uniform_deferred_walk( + draft_model, + n_active=n_active, + greedy_sampling=greedy_sampling, ): accepted_list, new_tokens_list = ( _speculative_walk_batch_deferred_uniform( diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index e0ba0b4be..33b359efb 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -908,6 +908,24 @@ def as_linear(self, hidden): assert fake_head.calls == 1 +def test_mtp_uses_uniform_deferred_walk_for_batched_sampling(): + ragged_drafter = SimpleNamespace(requires_uniform_batch_acceptance=False) + uniform_drafter = SimpleNamespace(requires_uniform_batch_acceptance=True) + + assert not mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, n_active=1, greedy_sampling=False + ) + assert not mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, n_active=2, greedy_sampling=True + ) + assert mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, n_active=2, greedy_sampling=False + ) + assert mtp_utils._mtp_use_uniform_deferred_walk( + uniform_drafter, n_active=2, greedy_sampling=True + ) + + def test_mtp_draft_hidden_uses_model_hook(): hidden = mx.array([[[1.0, 2.0]]], dtype=mx.float32) lm = SimpleNamespace(speculative_draft_hidden=lambda h: h * 2) From 11eb394cca125cdc452e3734989e794054fb4020 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 22 May 2026 01:51:11 +0200 Subject: [PATCH 07/26] Enable ragged Qwen MTP for sampled batches Add a positioned target sampler so no-drafter and MTP consume deterministic per-position target draws instead of relying on global RNG order. This keeps sampled batched decoding exact while allowing Qwen MTP to use the ragged acceptance path. --- mlx_vlm/generate.py | 27 ++++++++- mlx_vlm/server/app.py | 1 + mlx_vlm/server/generation.py | 99 +++++++++++++++++++++++++++---- mlx_vlm/speculative/mtp.py | 33 ++++++++++- mlx_vlm/speculative/utils.py | 2 + mlx_vlm/tests/test_server.py | 69 +++++++++++++++++++++ mlx_vlm/tests/test_speculative.py | 88 +++++++++++++++++++++++++-- 7 files changed, 301 insertions(+), 18 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index f63c1215a..b12b8884f 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -1563,6 +1563,19 @@ class PromptProgress: prompt_time: float = 0.0 +def _sample_with_positions( + sampler: Callable[[mx.array], mx.array], + logprobs: mx.array, + *, + row_ids: Optional[List[int]] = None, + positions: Optional[List[int]] = None, +) -> mx.array: + sample_target = getattr(sampler, "sample_target", None) + if callable(sample_target) and row_ids is not None and positions is not None: + return sample_target(logprobs, row_ids=row_ids, positions=positions) + return sampler(logprobs) + + class GenerationBatch: """ Batched token generator with double-buffered pipelining. @@ -1661,7 +1674,12 @@ def _step(self): logits = mx.concatenate(processed_logits, axis=0) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - sampled = self.sampler(logprobs) + sampled = _sample_with_positions( + self.sampler, + logprobs, + row_ids=[0] * len(self.uids), + positions=[n + 1 for n in self._num_tokens], + ) self._next_tokens = sampled prev_top_idx = self._next_top_idx @@ -2174,7 +2192,12 @@ def generate( logits = mx.concatenate(processed_logits, axis=0) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - first_tokens = sampler(logprobs) + first_tokens = _sample_with_positions( + sampler, + logprobs, + row_ids=[0] * len(self.uids), + positions=[0] * len(self.uids), + ) mx.async_eval(first_tokens) diff --git a/mlx_vlm/server/app.py b/mlx_vlm/server/app.py index c51787322..6cddc9bb7 100644 --- a/mlx_vlm/server/app.py +++ b/mlx_vlm/server/app.py @@ -105,6 +105,7 @@ def _build_gen_args( top_p=getattr(request, "top_p", DEFAULT_TOP_P), top_k=getattr(request, "top_k", 0), min_p=getattr(request, "min_p", 0.0), + seed=getattr(request, "seed", None), repetition_penalty=getattr(request, "repetition_penalty", None), logit_bias=logit_bias, enable_thinking=enable_thinking, diff --git a/mlx_vlm/server/generation.py b/mlx_vlm/server/generation.py index a483cbdd7..8956ce055 100644 --- a/mlx_vlm/server/generation.py +++ b/mlx_vlm/server/generation.py @@ -20,6 +20,7 @@ DEFAULT_MAX_TOKENS, DEFAULT_PREFILL_STEP_SIZE, DEFAULT_QUANTIZED_KV_START, + DEFAULT_SEED, DEFAULT_TEMPERATURE, DEFAULT_TOP_P, BatchGenerator, @@ -91,9 +92,83 @@ def get_speculative_batch_coalesce_s(): return DEFAULT_SPECULATIVE_BATCH_COALESCE_MS / 1000.0 -def _sample_last_token(logits: mx.array, sampler: Callable[[mx.array], mx.array]): +def _position_seed(seed: int, row_id: int, position: int) -> int: + x = (int(seed) ^ 0x9E3779B9) & 0xFFFFFFFF + x = (x + (int(row_id) + 1) * 0x85EBCA6B) & 0xFFFFFFFF + x = (x ^ ((int(position) + 1) * 0xC2B2AE35)) & 0xFFFFFFFF + x ^= x >> 16 + x = (x * 0x7FEB352D) & 0xFFFFFFFF + x ^= x >> 15 + return int(x & 0xFFFFFFFF) + + +def _position_keys(seed: int, row_ids: List[int], positions: List[int]) -> mx.array: + return mx.stack( + [ + mx.random.key(_position_seed(seed, row, pos)) + for row, pos in zip(row_ids, positions) + ] + ) + + +class _PositionedTargetSampler: + """Server sampler with stateless target draws for ragged verification.""" + + def __init__(self, *, temperature: float, top_p: float, seed: Optional[int]): + self.temperature = float(temperature) + self.top_p = float(top_p) + self.seed = DEFAULT_SEED if seed is None else int(seed) + + def __call__(self, logprobs: mx.array) -> mx.array: + if self.top_p > 0 and self.top_p < 1.0: + return top_p_sampling(logprobs, self.top_p, self.temperature) + return mx.random.categorical(logprobs * (1 / self.temperature)) + + def sample_target( + self, + logprobs: mx.array, + *, + row_ids: List[int], + positions: List[int], + ) -> mx.array: + if logprobs.shape[0] != len(row_ids) or len(row_ids) != len(positions): + raise ValueError("row_ids and positions must match logprobs batch size.") + keys = _position_keys(self.seed, row_ids, positions) + if self.top_p > 0 and self.top_p < 1.0: + return mx.vmap(self._sample_top_p_one, in_axes=(0, 0))(logprobs, keys) + return mx.vmap(self._sample_one, in_axes=(0, 0))(logprobs, keys) + + def _sample_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + return mx.random.categorical(logprobs * (1 / self.temperature), key=key) + + def _sample_top_p_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + if logprobs.dtype == mx.bfloat16: + logprobs = logprobs.astype(mx.float32) + probs = mx.softmax(logprobs / self.temperature, axis=-1) + sorted_indices = mx.argsort(probs, axis=-1) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) + cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + top_probs = mx.where( + cumulative_probs > 1 - self.top_p, + sorted_probs, + mx.zeros_like(sorted_probs), + ) + sampled_pos = mx.random.categorical(mx.log(top_probs), key=key) + return mx.take_along_axis(sorted_indices, sampled_pos[..., None], axis=-1)[0] + + +def _sample_last_token( + logits: mx.array, + sampler: Callable[[mx.array], mx.array], + *, + row_ids: Optional[List[int]] = None, + positions: Optional[List[int]] = None, +): logits = logits[:, -1, :] logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + sample_target = getattr(sampler, "sample_target", None) + if callable(sample_target) and row_ids is not None and positions is not None: + return sample_target(logprobs, row_ids=row_ids, positions=positions) return sampler(logprobs) @@ -698,14 +773,11 @@ def _cpu_preprocess(self, prompt, images=None, audio=None) -> dict: def _make_sampler(self, args: GenerationArguments) -> Optional[Callable]: if args.temperature == 0: return None - - def sampler(logprobs: mx.array) -> mx.array: - if args.top_p > 0 and args.top_p < 1.0: - return top_p_sampling(logprobs, args.top_p, args.temperature) - else: - return mx.random.categorical(logprobs * (1 / args.temperature)) - - return sampler + return _PositionedTargetSampler( + temperature=args.temperature, + top_p=args.top_p, + seed=args.seed, + ) def _gpu_embed(self, raw_inputs: dict, images=None) -> Tuple[mx.array, dict]: """GPU-only: run vision encoder if needed. Must run on GPU thread.""" @@ -999,7 +1071,13 @@ def _run_speculative(self): out = lm(input_mx, cache=prompt_cache, **lm_call_kwargs) hidden = speculative_hidden_state(draft_kind, out) shared_kv_states = out.shared_kv_states if is_mtp else None - first_bonus = _sample_last_token(out.logits, sampler) + sample_row_ids = [0] * B + first_bonus = _sample_last_token( + out.logits, + sampler, + row_ids=sample_row_ids, + positions=[0] * B, + ) mx.eval(first_bonus, hidden, out.logits) prompt_elapsed = time.perf_counter() - prompt_started for uid in uids: @@ -1070,6 +1148,7 @@ def stop_check(seq_idx, token_id): shared_kv_states=shared_kv_states, eos_token_ids=eos_set, prompt_tokens=input_mx, + row_ids=sample_row_ids, ) for tok_list, _ in rounds_iter: for j, tok in enumerate(tok_list): diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index e551fafed..5dc6dba4f 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -214,6 +214,8 @@ def _speculative_walk_batch_deferred_greedy( draft_tokens: mx.array, sampler: Callable[[mx.array], mx.array], budgets: List[int], + row_ids: Optional[List[int]] = None, + base_positions: Optional[List[int]] = None, ) -> Tuple[List[int], List[List[int]]]: """Batched greedy walk that projects target logits only until all rows stop.""" B = draft_tokens.shape[0] @@ -234,7 +236,19 @@ def _speculative_walk_batch_deferred_greedy( if logits.ndim == 3 and logits.shape[1] == 1: logits = logits[:, 0, :] logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - target_tokens = sampler(logprobs) + if _sampler_supports_positioned_target(sampler): + if row_ids is None or base_positions is None: + raise ValueError( + "positioned target sampling requires row_ids and " + "base_positions." + ) + target_tokens = sampler.sample_target( + logprobs, + row_ids=row_ids, + positions=[int(position) + pos for position in base_positions], + ) + else: + target_tokens = sampler(logprobs) mx.eval(target_tokens) target_list = [int(token) for token in target_tokens.reshape(-1).tolist()] @@ -301,11 +315,16 @@ def _speculative_walk_batch_deferred_uniform( return [accepted] * B, [draft_lists[row][: budgets[row]] for row in range(B)] +def _sampler_supports_positioned_target(sampler: Callable[[mx.array], mx.array]) -> bool: + return callable(getattr(sampler, "sample_target", None)) + + def _mtp_use_uniform_deferred_walk( draft_model: nn.Module, *, n_active: int, greedy_sampling: bool, + sampler: Callable[[mx.array], mx.array], ) -> bool: if n_active <= 1: return False @@ -313,6 +332,9 @@ def _mtp_use_uniform_deferred_walk( if getattr(draft_model, "requires_uniform_batch_acceptance", False): return True + if _sampler_supports_positioned_target(sampler): + return False + # Non-greedy sampling uses a single target RNG stream. If rows accept a # ragged number of tokens, the verifier would draw future samples for some # rows before no-drafter has drawn the next lockstep batch sample. @@ -718,6 +740,7 @@ def _mtp_rounds_batch( stop_check: Optional[Callable[[int, int], bool]] = None, eos_token_ids: Optional[set] = None, greedy_sampling: bool = False, + row_ids: Optional[List[int]] = None, ) -> Generator[Tuple[List[Optional[int]], None], None, None]: """Batched Gemma 4 MTP round loop (B > 1). @@ -735,6 +758,7 @@ def _mtp_rounds_batch( ) B = first_bonus.shape[0] + row_ids = list(range(B)) if row_ids is None else list(row_ids) block_total = _dflash_block_total(draft_model, draft_block_size) configured_block_total = int(getattr(draft_model.config, "block_size", block_total)) if getattr(draft_model, "supports_ragged_batch_acceptance", False): @@ -839,6 +863,7 @@ def _mtp_rounds_batch( draft_model, n_active=n_active, greedy_sampling=greedy_sampling, + sampler=sampler, ): accepted_list, new_tokens_list = ( _speculative_walk_batch_deferred_uniform( @@ -856,8 +881,12 @@ def _mtp_rounds_batch( draft_tokens, sampler, budgets, + row_ids=[row_ids[active_idx[j]] for j in range(n_active)], + base_positions=[emitted[active_idx[j]] for j in range(n_active)], ) - sampler_rng.target_sampled(sync_draft=True) + sampler_rng.target_sampled( + sync_draft=not _sampler_supports_positioned_target(sampler) + ) # Keep the adaptive block-size history on a per-round basis so # batched MTP reacts like the singleton loop instead of letting # batch size change the controller signal. diff --git a/mlx_vlm/speculative/utils.py b/mlx_vlm/speculative/utils.py index 0606e8065..a417cc316 100644 --- a/mlx_vlm/speculative/utils.py +++ b/mlx_vlm/speculative/utils.py @@ -132,6 +132,7 @@ def run_speculative_server_rounds( shared_kv_states: Optional[dict] = None, eos_token_ids: Optional[set] = None, prompt_tokens: Optional[mx.array] = None, + row_ids: Optional[List[int]] = None, ) -> Generator[Tuple[List[Optional[int]], None], None, None]: batch_size = int(first_bonus.shape[0]) if first_bonus.ndim > 0 else 1 @@ -206,6 +207,7 @@ def run_speculative_server_rounds( stop_check=stop_check, eos_token_ids=eos_token_ids, greedy_sampling=greedy_sampling, + row_ids=row_ids, ) return diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py index 8ff067ad6..06b17e0d7 100644 --- a/mlx_vlm/tests/test_server.py +++ b/mlx_vlm/tests/test_server.py @@ -80,6 +80,75 @@ def sampler(logprobs): assert bool(mx.allclose(seen["values"], expected_logprobs).item()) +def test_speculative_server_samples_first_bonus_with_positioned_sampler(): + seen = {} + logits = mx.array( + [ + [[1.0, 2.0, 3.0]], + [[4.0, 1.0, 0.0]], + ], + dtype=mx.float32, + ) + + class Sampler: + def __call__(self, logprobs): + raise AssertionError("positioned sampler was not used") + + def sample_target(self, logprobs, *, row_ids, positions): + seen["shape"] = logprobs.shape + seen["row_ids"] = list(row_ids) + seen["positions"] = list(positions) + return mx.argmax(logprobs, axis=-1) + + tokens = server_generation._sample_last_token( + logits, + Sampler(), + row_ids=[0, 0], + positions=[0, 0], + ) + mx.eval(tokens) + + assert seen == { + "shape": (2, 3), + "row_ids": [0, 0], + "positions": [0, 0], + } + assert tokens.tolist() == [2, 0] + + +def test_positioned_target_sampler_is_batch_grouping_invariant(): + sampler = server_generation._PositionedTargetSampler( + temperature=0.7, top_p=1.0, seed=42 + ) + logits = mx.array( + [ + [0.0, 1.0, 2.0, 3.0], + [3.0, 2.0, 1.0, 0.0], + ], + dtype=mx.float32, + ) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + + batched = sampler.sample_target( + logprobs, + row_ids=[0, 0], + positions=[5, 5], + ) + single_0 = sampler.sample_target( + logprobs[0:1], + row_ids=[0], + positions=[5], + ) + single_1 = sampler.sample_target( + logprobs[1:2], + row_ids=[0], + positions=[5], + ) + mx.eval(batched, single_0, single_1) + + assert batched.tolist() == [single_0.item(), single_1.item()] + + def test_speculative_server_dispatches_eagle3_batch_loop(): assert ( speculative_utils.get_speculative_rounds_batch("eagle3") diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 33b359efb..43d3bb245 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -867,6 +867,66 @@ def as_linear(self, hidden): assert fake_head.calls == 3 +def test_speculative_walk_batch_deferred_greedy_uses_positioned_sampler(): + class FakeEmbed: + def __init__(self): + self.calls = 0 + + def as_linear(self, hidden): + self.calls += 1 + return hidden + + class PositionedSampler: + def __init__(self): + self.calls = [] + + def __call__(self, logprobs): + raise AssertionError("positioned target sampler was not used") + + def sample_target(self, logprobs, *, row_ids, positions): + self.calls.append((list(row_ids), list(positions))) + return mx.argmax(logprobs, axis=-1) + + fake_head = FakeEmbed() + sampler = PositionedSampler() + lm = SimpleNamespace(speculative_logits_from_hidden=fake_head.as_linear) + target_hidden = mx.array( + [ + [ + [0, 0, 9, 0], + [0, 9, 0, 0], + [0, 0, 0, 9], + ], + [ + [9, 0, 0, 0], + [0, 0, 9, 0], + [0, 9, 0, 0], + ], + ], + dtype=mx.float32, + ) + draft_tokens = mx.array([[2, 3], [0, 2]], dtype=mx.int32) + + accepted, new_tokens = _speculative_walk_batch_deferred_greedy( + lm, + target_hidden, + draft_tokens, + sampler, + budgets=[3, 2], + row_ids=[10, 11], + base_positions=[7, 12], + ) + + assert accepted == [1, 2] + assert new_tokens == [[2, 1], [0, 2]] + assert fake_head.calls == 3 + assert sampler.calls == [ + ([10, 11], [7, 12]), + ([10, 11], [8, 13]), + ([10, 11], [9, 14]), + ] + + def test_speculative_walk_batch_deferred_uniform_stops_at_batch_rejection(): class FakeEmbed: def __init__(self): @@ -911,18 +971,38 @@ def as_linear(self, hidden): def test_mtp_uses_uniform_deferred_walk_for_batched_sampling(): ragged_drafter = SimpleNamespace(requires_uniform_batch_acceptance=False) uniform_drafter = SimpleNamespace(requires_uniform_batch_acceptance=True) + normal_sampler = lambda logits: mx.argmax(logits, axis=-1) + positioned_sampler = SimpleNamespace(sample_target=lambda *args, **kwargs: None) assert not mtp_utils._mtp_use_uniform_deferred_walk( - ragged_drafter, n_active=1, greedy_sampling=False + ragged_drafter, + n_active=1, + greedy_sampling=False, + sampler=normal_sampler, ) assert not mtp_utils._mtp_use_uniform_deferred_walk( - ragged_drafter, n_active=2, greedy_sampling=True + ragged_drafter, + n_active=2, + greedy_sampling=True, + sampler=normal_sampler, ) assert mtp_utils._mtp_use_uniform_deferred_walk( - ragged_drafter, n_active=2, greedy_sampling=False + ragged_drafter, + n_active=2, + greedy_sampling=False, + sampler=normal_sampler, + ) + assert not mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, + n_active=2, + greedy_sampling=False, + sampler=positioned_sampler, ) assert mtp_utils._mtp_use_uniform_deferred_walk( - uniform_drafter, n_active=2, greedy_sampling=True + uniform_drafter, + n_active=2, + greedy_sampling=True, + sampler=positioned_sampler, ) From 3aaba7baa41b675832da5368086944618cc6c55f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 22 May 2026 14:11:53 +0200 Subject: [PATCH 08/26] Fix server sampler reuse across idle batches --- mlx_vlm/server/generation.py | 5 ++ mlx_vlm/tests/test_server.py | 142 +++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+) diff --git a/mlx_vlm/server/generation.py b/mlx_vlm/server/generation.py index 8956ce055..dbde63343 100644 --- a/mlx_vlm/server/generation.py +++ b/mlx_vlm/server/generation.py @@ -897,6 +897,11 @@ def _run(self): except Exception: pass + if new_items and batch_gen is not None and not active: + if not batch_gen.has_work: + batch_gen.close() + batch_gen = None + for rqueue, raw_inputs, prompt_tokens, args, images in new_items: if batch_gen is None: batch_gen = BatchGenerator( diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py index 06b17e0d7..3bd07a841 100644 --- a/mlx_vlm/tests/test_server.py +++ b/mlx_vlm/tests/test_server.py @@ -2061,6 +2061,148 @@ def fake_initialize_model(): (str(uid * 10 + 1), "length"), ] + def test_idle_batch_generator_is_recreated_for_new_sampler(self, monkeypatch): + created = [] + next_uid = [1] + + class FakeDetokenizer: + def __init__(self): + self.last_segment = "" + + def reset(self): + self.last_segment = "" + + def add_token(self, token): + self.last_segment = str(token) + + def finalize(self): + pass + + class FakeBatchGenerator: + def __init__(self, *args, **kwargs): + del args + self.sampler = kwargs.get("sampler") + self.closed = False + self._active = {} + created.append(self) + + def insert(self, *args, **kwargs): + del args, kwargs + uid = next_uid[0] + next_uid[0] += 1 + self._active[uid] = True + return (uid,) + + @property + def has_work(self): + return bool(self._active) + + @property + def unprocessed_prompts(self): + return [] + + @property + def has_pending_prompts(self): + return False + + def next(self, **kwargs): + del kwargs + responses = [ + SimpleNamespace( + uid=uid, + token=uid, + token_logprob=0.0, + finish_reason="length", + ) + for uid in list(self._active) + ] + self._active.clear() + return [], responses + + def remove(self, uid): + return self._active.pop(uid, None) is not None + + def close(self): + self.closed = True + + monkeypatch.setattr(server_generation, "BatchGenerator", FakeBatchGenerator) + monkeypatch.setattr( + server_generation, + "make_streaming_detokenizer", + lambda _: FakeDetokenizer(), + ) + + gen = server.ResponseGenerator.__new__(server.ResponseGenerator) + gen.model_path = "demo" + gen.adapter_path = None + gen.model = None + gen.processor = None + gen.config = None + gen.stop_tokens = set() + gen.vision_cache = None + gen.draft_model = None + gen.draft_kind = None + gen.kv_bits = None + gen.kv_group_size = server.DEFAULT_KV_GROUP_SIZE + gen.kv_quant_scheme = server.DEFAULT_KV_QUANT_SCHEME + gen.quantized_kv_start = server.DEFAULT_QUANTIZED_KV_START + gen.top_logprobs_k = 0 + gen.apc_manager = None + gen.tokenizer = SimpleNamespace() + gen.requests = Queue() + gen._stop = False + gen._ready = Event() + gen._load_error = None + gen._cancelled = set() + gen._cancel_lock = Lock() + gen._make_sampler = lambda args: f"sampler-{args.temperature}" + + def fake_initialize_model(): + gen.model = SimpleNamespace(language_model=object()) + gen.processor = SimpleNamespace() + gen.config = SimpleNamespace() + gen.stop_tokens = set() + gen.draft_model = None + gen.draft_kind = None + gen.tokenizer = SimpleNamespace() + + gen._initialize_model = fake_initialize_model + gen._gpu_embed = lambda raw_inputs, images=None: ( + mx.array([[raw_inputs["request_id"]]], dtype=mx.int32), + {}, + ) + + worker = Thread(target=gen._run, daemon=True) + worker.start() + + def run_request(request_id, temperature): + rqueue = Queue() + gen.requests.put( + ( + rqueue, + {"request_id": request_id}, + 1, + server.GenerationArguments(max_tokens=1, temperature=temperature), + None, + ) + ) + ctx = rqueue.get(timeout=1) + assert isinstance(ctx, server.GenerationContext) + item = rqueue.get(timeout=1) + assert item.finish_reason == "length" + assert rqueue.get(timeout=1) is None + + try: + run_request(1, 0.0) + run_request(2, 0.6) + finally: + gen._stop = True + gen.requests.put(None) + worker.join(timeout=2) + + assert [bg.sampler for bg in created] == ["sampler-0.0", "sampler-0.6"] + assert created[0].closed is True + def test_step_attaches_prompt_tps_from_prompt_progress(self): class SimpleTokenizer: vocab = {"hi": 0} From 744372a3b4070d18ec66c26e82a70e0b066b2ffe Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Fri, 22 May 2026 17:30:26 +0200 Subject: [PATCH 09/26] Route server MTP singleton through batch path --- mlx_vlm/speculative/mtp.py | 2 +- mlx_vlm/speculative/utils.py | 19 ---------------- mlx_vlm/tests/test_speculative.py | 36 +++++++++++++++++++++++++++++++ 3 files changed, 37 insertions(+), 20 deletions(-) diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index 5dc6dba4f..041aa7273 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -742,7 +742,7 @@ def _mtp_rounds_batch( greedy_sampling: bool = False, row_ids: Optional[List[int]] = None, ) -> Generator[Tuple[List[Optional[int]], None], None, None]: - """Batched Gemma 4 MTP round loop (B > 1). + """Batched Gemma 4 MTP round loop (B >= 1). Mirrors ``_dflash_rounds_batch``: per-row state tracked by original index, continuous-batching filter on row finish. Differences vs DFlash diff --git a/mlx_vlm/speculative/utils.py b/mlx_vlm/speculative/utils.py index a417cc316..62648bf28 100644 --- a/mlx_vlm/speculative/utils.py +++ b/mlx_vlm/speculative/utils.py @@ -174,25 +174,6 @@ def run_speculative_server_rounds( return if draft_kind == "mtp": - if batch_size == 1: - yield from ( - ([tok], state) - for tok, state in _mtp_rounds( - model, - draft_model, - prompt_cache, - hidden, - shared_kv_states, - first_bonus=int(first_bonus.reshape(-1).item()), - max_tokens=max_tokens, - sampler=sampler, - draft_block_size=draft_block_size, - token_dtype=token_dtype, - greedy_sampling=greedy_sampling, - ) - ) - return - yield from _mtp_rounds_batch( model, draft_model, diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 43d3bb245..b22ba9783 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -968,6 +968,42 @@ def as_linear(self, hidden): assert fake_head.calls == 1 +def test_mtp_server_singleton_dispatches_batch_rounds(monkeypatch): + calls = [] + + def fake_batch(*args, **kwargs): + calls.append(("batch", args, kwargs)) + yield [3], None + + def fake_single(*args, **kwargs): + raise AssertionError("server MTP singleton should use batch round path") + + monkeypatch.setattr(speculative_utils, "_mtp_rounds_batch", fake_batch) + monkeypatch.setattr(speculative_utils, "_mtp_rounds", fake_single) + + result = list( + speculative_utils.run_speculative_server_rounds( + SimpleNamespace(language_model=SimpleNamespace()), + SimpleNamespace(), + prompt_cache=[], + hidden=mx.zeros((1, 1, 1), dtype=mx.float32), + shared_kv_states={}, + draft_kind="mtp", + first_bonus=mx.array([2], dtype=mx.int32), + max_tokens=4, + sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), + token_dtype=mx.int32, + greedy_sampling=False, + row_ids=[0], + ) + ) + + assert result == [([3], None)] + assert calls + assert calls[0][2]["first_bonus"].tolist() == [2] + assert calls[0][2]["row_ids"] == [0] + + def test_mtp_uses_uniform_deferred_walk_for_batched_sampling(): ragged_drafter = SimpleNamespace(requires_uniform_batch_acceptance=False) uniform_drafter = SimpleNamespace(requires_uniform_batch_acceptance=True) From c28fa1af20ca054839a1090d3a6687f0cf65a014 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 23 May 2026 09:41:21 +0200 Subject: [PATCH 10/26] Fix seeded Qwen MTP CLI parity --- mlx_vlm/generate.py | 113 ++++++++- mlx_vlm/models/qwen3_5/language.py | 384 ++++++++++++++++++++++++++++- mlx_vlm/speculative/mtp.py | 31 ++- mlx_vlm/tests/test_speculative.py | 61 +++++ 4 files changed, 578 insertions(+), 11 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index b12b8884f..b48303286 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -22,6 +22,7 @@ from . import apc as _apc from .models import cache from .prompt_utils import apply_chat_template +from .sample_utils import top_p_sampling from .speculative.utils import format_speculative_stats, run_speculative_rounds from .tokenizer_utils import make_streaming_detokenizer from .turboquant import BatchTurboQuantKVCache, TurboQuantKVCache, turboquant_enabled @@ -57,6 +58,71 @@ DEFAULT_PREFILL_STEP_SIZE = 2048 +def _position_seed(seed: int, row_id: int, position: int) -> int: + x = (int(seed) ^ 0x9E3779B9) & 0xFFFFFFFF + x = (x + (int(row_id) + 1) * 0x85EBCA6B) & 0xFFFFFFFF + x = (x ^ ((int(position) + 1) * 0xC2B2AE35)) & 0xFFFFFFFF + x ^= x >> 16 + x = (x * 0x7FEB352D) & 0xFFFFFFFF + x ^= x >> 15 + return int(x & 0xFFFFFFFF) + + +def _position_keys(seed: int, row_ids: List[int], positions: List[int]) -> mx.array: + return mx.stack( + [ + mx.random.key(_position_seed(seed, row, pos)) + for row, pos in zip(row_ids, positions) + ] + ) + + +class _PositionedTargetSampler: + """Sampler with stateless target draws keyed by generated-token position.""" + + def __init__(self, *, temperature: float, top_p: float, seed: int): + self.temperature = float(temperature) + self.top_p = float(top_p) + self.seed = int(seed) + + def __call__(self, logprobs: mx.array) -> mx.array: + if self.top_p > 0 and self.top_p < 1.0: + return top_p_sampling(logprobs, self.top_p, self.temperature) + return mx.random.categorical(logprobs * (1 / self.temperature)) + + def sample_target( + self, + logprobs: mx.array, + *, + row_ids: List[int], + positions: List[int], + ) -> mx.array: + if logprobs.shape[0] != len(row_ids) or len(row_ids) != len(positions): + raise ValueError("row_ids and positions must match logprobs batch size.") + keys = _position_keys(self.seed, row_ids, positions) + if self.top_p > 0 and self.top_p < 1.0: + return mx.vmap(self._sample_top_p_one, in_axes=(0, 0))(logprobs, keys) + return mx.vmap(self._sample_one, in_axes=(0, 0))(logprobs, keys) + + def _sample_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + return mx.random.categorical(logprobs * (1 / self.temperature), key=key) + + def _sample_top_p_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + if logprobs.dtype == mx.bfloat16: + logprobs = logprobs.astype(mx.float32) + probs = mx.softmax(logprobs / self.temperature, axis=-1) + sorted_indices = mx.argsort(probs, axis=-1) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) + cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + top_probs = mx.where( + cumulative_probs > 1 - self.top_p, + sorted_probs, + mx.zeros_like(sorted_probs), + ) + sampled_pos = mx.random.categorical(mx.log(top_probs), key=key) + return mx.take_along_axis(sorted_indices, sampled_pos[..., None], axis=-1)[0] + + def parse_arguments(): parser = argparse.ArgumentParser( description="Generate text from an image using a model." @@ -132,6 +198,12 @@ def parse_arguments(): default=DEFAULT_TEMPERATURE, help="Temperature for sampling.", ) + parser.add_argument( + "--seed", + type=int, + default=DEFAULT_SEED, + help="Seed for random sampling.", + ) parser.add_argument("--chat", action="store_true", help="Chat in multi-turn style.") parser.add_argument("--verbose", action="store_false", help="Detailed output.") parser.add_argument( @@ -463,6 +535,7 @@ def generate_step( *, max_tokens: int = DEFAULT_MAX_TOKENS, temperature: float = DEFAULT_TEMPERATURE, + seed: Optional[int] = None, repetition_penalty: Optional[float] = None, repetition_context_size: Optional[int] = DEFAULT_REPETITION_CONTEXT_SIZE, top_p: float = DEFAULT_TOP_P, @@ -495,6 +568,7 @@ def generate_step( mask: The attention mask (optional). max_tokens (int): Maximum number of tokens to generate. temperature (float): The temperature for sampling, if 0 the argmax is used. + seed (int, optional): Seed for deterministic target sampling. repetition_penalty (float, optional): The penalty factor for repeating tokens. repetition_context_size (int, optional): The number of tokens to @@ -544,12 +618,24 @@ def generate_step( sampler_is_greedy = sampler is None and temperature == 0 if sampler is None: - sampler = make_sampler( - temp=temperature, - top_p=top_p, - min_p=min_p, - top_k=top_k, - ) + if ( + seed is not None + and temperature > 0 + and min_p == DEFAULT_MIN_P + and top_k == DEFAULT_TOP_K + ): + sampler = _PositionedTargetSampler( + temperature=temperature, + top_p=top_p, + seed=seed, + ) + else: + sampler = make_sampler( + temp=temperature, + top_p=top_p, + min_p=min_p, + top_k=top_k, + ) processors = make_logits_processors( logit_bias, repetition_penalty, repetition_context_size @@ -559,6 +645,7 @@ def generate_step( y = input_ids tokens = mx.array([], dtype=input_ids.dtype) + target_sample_position = 0 thinking_budget_criteria = kwargs.pop("thinking_budget_criteria", None) @@ -596,7 +683,7 @@ def generate_step( lm._rope_deltas = None def _step(y, inputs_embeds=None): - nonlocal tokens, kwargs, last_outputs + nonlocal tokens, kwargs, last_outputs, target_sample_position with mx.stream(generation_stream): if "decoder_input_ids" in kwargs: @@ -624,7 +711,13 @@ def _step(y, inputs_embeds=None): quantize_cache_fn(prompt_cache) logprobs = logits - mx.logsumexp(logits) - y = sampler(logprobs) + y = _sample_with_positions( + sampler, + logprobs, + row_ids=[0] * int(logprobs.shape[0]), + positions=[target_sample_position] * int(logprobs.shape[0]), + ) + target_sample_position += 1 if outputs.cross_attention_states is not None: kwargs = {"cross_attention_states": outputs.cross_attention_states} @@ -3325,6 +3418,8 @@ def _generate_batch( def main(): args = parse_arguments() + mx.random.seed(args.seed) + if isinstance(args.image, str): args.image = [args.image] if isinstance(args.audio, str): @@ -3426,6 +3521,7 @@ def main(): stream_kwargs = { "max_tokens": args.max_tokens, "temperature": args.temperature, + "seed": args.seed, "vision_cache": vision_cache, **kwargs, } @@ -3456,6 +3552,7 @@ def main(): "video": args.video, "fps": args.fps, "temperature": args.temperature, + "seed": args.seed, "max_tokens": args.max_tokens, "verbose": args.verbose, "max_kv_size": args.max_kv_size, diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 82fb1f6ef..bc8cb6b1e 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -1,4 +1,4 @@ -from functools import partial +from functools import lru_cache, partial from typing import Any, List, Optional import mlx.core as mx @@ -219,6 +219,375 @@ def _target_verify_weight(weight: mx.array, x: mx.array) -> Optional[mx.array]: return out.reshape(B, L, O) +def _target_verify_qlinear_header(bits: int, group_size: int) -> str: + return ( + r""" + using namespace metal; + + constant constexpr int SIMD_SIZE = 32; + constant constexpr int BITS = __BITS__; + constant constexpr int GS = __GS__; + constant constexpr int PACK_FACTOR = (BITS == 5 ? 8 : 32 / BITS); + constant constexpr int BYTES_PER_PACK = (BITS == 5 ? 5 : 32 / 8); + constant constexpr int PACKS_PER_THREAD = 2; + constant constexpr int VALUES_PER_THREAD = PACK_FACTOR * PACKS_PER_THREAD; + constant constexpr int BLOCK_SIZE = VALUES_PER_THREAD * SIMD_SIZE; + constant constexpr int SCALE_STEP_PER_THREAD = GS / VALUES_PER_THREAD; + constant constexpr int RESULTS_PER_SIMDGROUP = 4; + constant constexpr int NUM_SIMDGROUPS = 2; + constant constexpr int BN = RESULTS_PER_SIMDGROUP * NUM_SIMDGROUPS; + + template + inline float load_vector_exact(const device T* x, thread float* x_thread) { + float sum = 0.0f; + if (BITS == 4) { + for (int i = 0; i < VALUES_PER_THREAD; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } else if (BITS == 5) { + for (int i = 0; i < VALUES_PER_THREAD; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + return sum; + } + + inline float qdot_exact( + const device uint8_t* w, + const thread float* x_thread, + float scale, + float bias, + float sum) { + float accum = 0.0f; + if (BITS == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (VALUES_PER_THREAD / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } else if (BITS == 5) { + for (int i = 0; i < (VALUES_PER_THREAD / 8); i++) { + const thread float* xt = x_thread + 8 * i; + const device uint8_t* wb = w + 5 * i; + + accum += (wb[0] & 0x1f) * xt[0]; + accum += (wb[0] & 0xe0) * xt[1]; + accum += (wb[1] & 0x3) * (xt[1] * 256.0f); + accum += (wb[1] & 0x7c) * xt[2]; + accum += (wb[1] & 0x80) * xt[3]; + accum += (wb[2] & 0xf) * (xt[3] * 256.0f); + accum += (wb[2] & 0xf0) * xt[4]; + accum += (wb[3] & 0x1) * (xt[4] * 256.0f); + accum += (wb[3] & 0x3e) * xt[5]; + accum += (wb[3] & 0xc0) * xt[6]; + accum += (wb[4] & 0x7) * (xt[6] * 256.0f); + accum += (wb[4] & 0xf8) * xt[7]; + } + } + return scale * accum + sum * bias; + } +""" + .replace("__BITS__", str(bits)) + .replace("__GS__", str(group_size)) + ) + + +_TARGET_VERIFY_QMV_SOURCE = r""" + uint n_tile = threadgroup_position_in_grid.y; + uint b_idx = threadgroup_position_in_grid.z; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + int out_row = int(n_tile) * BN + int(simd_gid) * RESULTS_PER_SIMDGROUP; + int in_vec_size_w = K_SIZE * BYTES_PER_PACK / PACK_FACTOR; + int in_vec_size_g = K_SIZE / GS; + + const device uint8_t* ws_base = + (const device uint8_t*)w + out_row * in_vec_size_w + + int(simd_lid) * PACKS_PER_THREAD * BYTES_PER_PACK; + const device T* scales_base = + scales + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* biases_base = + biases + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* x_base = + x + int(b_idx) * VERIFY_T * K_SIZE + int(simd_lid) * VALUES_PER_THREAD; + + float result[VERIFY_T][RESULTS_PER_SIMDGROUP]; + float x_thread[VERIFY_T][VALUES_PER_THREAD]; + for (int t = 0; t < VERIFY_T; ++t) { + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + result[t][row] = 0.0f; + } + } + + const device uint8_t* ws = ws_base; + const device T* sc = scales_base; + const device T* bs = biases_base; + const device T* xk = x_base; + + for (int k = 0; k < K_SIZE; k += BLOCK_SIZE) { + float sums[VERIFY_T]; + for (int t = 0; t < VERIFY_T; ++t) { + sums[t] = load_vector_exact(xk + t * K_SIZE, x_thread[t]); + } + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + const device uint8_t* wl = ws + row * in_vec_size_w; + const device T* sl = sc + row * in_vec_size_g; + const device T* bl = bs + row * in_vec_size_g; + float s = float(sl[0]); + float b = float(bl[0]); + for (int t = 0; t < VERIFY_T; ++t) { + result[t][row] += qdot_exact(wl, x_thread[t], s, b, sums[t]); + } + } + + ws += BLOCK_SIZE * BYTES_PER_PACK / PACK_FACTOR; + sc += BLOCK_SIZE / GS; + bs += BLOCK_SIZE / GS; + xk += BLOCK_SIZE; + } + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + int n = out_row + row; + for (int t = 0; t < VERIFY_T; ++t) { + float r = simd_sum(result[t][row]); + if (simd_lid == 0) { + y[(int(b_idx) * VERIFY_T + t) * N_SIZE + n] = T(r); + } + } + } +""" + + +_TARGET_VERIFY_QARGMAX_SOURCE = r""" + uint n_tile = threadgroup_position_in_grid.y; + uint b_idx = threadgroup_position_in_grid.z; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + int out_row = int(n_tile) * BN + int(simd_gid) * RESULTS_PER_SIMDGROUP; + int in_vec_size_w = K_SIZE * BYTES_PER_PACK / PACK_FACTOR; + int in_vec_size_g = K_SIZE / GS; + + threadgroup float tile_best_values[VERIFY_T][NUM_SIMDGROUPS]; + threadgroup int tile_best_indices[VERIFY_T][NUM_SIMDGROUPS]; + + const device uint8_t* ws_base = + (const device uint8_t*)w + out_row * in_vec_size_w + + int(simd_lid) * PACKS_PER_THREAD * BYTES_PER_PACK; + const device T* scales_base = + scales + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* biases_base = + biases + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* x_base = + x + int(b_idx) * VERIFY_T * K_SIZE + int(simd_lid) * VALUES_PER_THREAD; + + float result[VERIFY_T][RESULTS_PER_SIMDGROUP]; + float x_thread[VERIFY_T][VALUES_PER_THREAD]; + for (int t = 0; t < VERIFY_T; ++t) { + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + result[t][row] = 0.0f; + } + } + + const device uint8_t* ws = ws_base; + const device T* sc = scales_base; + const device T* bs = biases_base; + const device T* xk = x_base; + + for (int k = 0; k < K_SIZE; k += BLOCK_SIZE) { + float sums[VERIFY_T]; + for (int t = 0; t < VERIFY_T; ++t) { + sums[t] = load_vector_exact(xk + t * K_SIZE, x_thread[t]); + } + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + const device uint8_t* wl = ws + row * in_vec_size_w; + const device T* sl = sc + row * in_vec_size_g; + const device T* bl = bs + row * in_vec_size_g; + float s = float(sl[0]); + float b = float(bl[0]); + for (int t = 0; t < VERIFY_T; ++t) { + result[t][row] += qdot_exact(wl, x_thread[t], s, b, sums[t]); + } + } + + ws += BLOCK_SIZE * BYTES_PER_PACK / PACK_FACTOR; + sc += BLOCK_SIZE / GS; + bs += BLOCK_SIZE / GS; + xk += BLOCK_SIZE; + } + + for (int t = 0; t < VERIFY_T; ++t) { + float best_value = -3.4028234663852886e38f; + int best_index = 0; + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + int n = out_row + row; + if (n < N_SIZE) { + float rounded = float(T(simd_sum(result[t][row]))); + if (rounded > best_value) { + best_value = rounded; + best_index = n; + } + } + } + + if (simd_lid == 0) { + tile_best_values[t][simd_gid] = best_value; + tile_best_indices[t][simd_gid] = best_index; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (simd_gid == 0 && simd_lid == 0) { + for (int t = 0; t < VERIFY_T; ++t) { + float best = tile_best_values[t][0]; + int best_idx = tile_best_indices[t][0]; + for (int i = 1; i < NUM_SIMDGROUPS; ++i) { + float candidate = tile_best_values[t][i]; + int candidate_idx = tile_best_indices[t][i]; + if (candidate > best) { + best = candidate; + best_idx = candidate_idx; + } + } + int offset = (int(b_idx) * VERIFY_T + t) * NUM_TILES + int(n_tile); + tile_values[offset] = T(best); + tile_indices[offset] = best_idx; + } + } +""" + + +@lru_cache(maxsize=None) +def _target_verify_qmv_kernel(bits, group_size, dtype, verify_t, k_size, n_size): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=( + "qwen3_5_target_verify_qmv_" + f"b{bits}_gs{group_size}_t{verify_t}_k{k_size}_n{n_size}_{dtype_name}" + ), + input_names=["x", "w", "scales", "biases"], + output_names=["y"], + header=_target_verify_qlinear_header(bits, group_size), + source=_TARGET_VERIFY_QMV_SOURCE, + ) + + +@lru_cache(maxsize=None) +def _target_verify_qargmax_kernel(bits, group_size, dtype, verify_t, k_size, n_size): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=( + "qwen3_5_target_verify_qargmax_" + f"b{bits}_gs{group_size}_t{verify_t}_k{k_size}_n{n_size}_{dtype_name}" + ), + input_names=["x", "w", "scales", "biases"], + output_names=["tile_values", "tile_indices"], + header=_target_verify_qlinear_header(bits, group_size), + source=_TARGET_VERIFY_QARGMAX_SOURCE, + ) + + +def _can_target_verify_quantized(linear, x: mx.array) -> bool: + if ( + not isinstance(linear, nn.QuantizedLinear) + or x.ndim != 3 + or x.shape[1] <= 1 + or linear.bits not in (4, 5) + or linear.mode != "affine" + or linear.biases is None + or x.dtype not in (mx.bfloat16, mx.float16) + or linear.scales.dtype != x.dtype + or linear.biases.dtype != x.dtype + ): + return False + + _, _, K = x.shape + N = linear.weight.shape[0] + return ( + K == linear.weight.shape[1] * 32 // linear.bits + and K % 512 == 0 + and N % 8 == 0 + ) + + +def _target_verify_quantized_linear(linear, x: mx.array) -> Optional[mx.array]: + if not _can_target_verify_quantized(linear, x): + return None + + B, T, K = x.shape + N = linear.weight.shape[0] + + x = mx.contiguous(x) + kernel = _target_verify_qmv_kernel(linear.bits, linear.group_size, x.dtype, T, K, N) + out = kernel( + inputs=[x, linear.weight, linear.scales, linear.biases], + template=[ + ("T", x.dtype), + ("VERIFY_T", int(T)), + ("K_SIZE", int(K)), + ("N_SIZE", int(N)), + ], + grid=(32, 2 * (N // 8), B), + threadgroup=(32, 2, 1), + output_shapes=[(B, T, N)], + output_dtypes=[x.dtype], + )[0] + if "bias" in linear: + out = out + linear["bias"] + return out + + +def _target_verify_quantized_argmax(linear, x: mx.array) -> Optional[mx.array]: + if not _can_target_verify_quantized(linear, x) or "bias" in linear: + return None + + B, T, K = x.shape + N = linear.weight.shape[0] + num_tiles = N // 8 + + x = mx.contiguous(x) + kernel = _target_verify_qargmax_kernel( + linear.bits, linear.group_size, x.dtype, T, K, N + ) + tile_values, tile_indices = kernel( + inputs=[x, linear.weight, linear.scales, linear.biases], + template=[ + ("T", x.dtype), + ("VERIFY_T", int(T)), + ("K_SIZE", int(K)), + ("N_SIZE", int(N)), + ("NUM_TILES", int(num_tiles)), + ], + grid=(32, 2 * num_tiles, B), + threadgroup=(32, 2, 1), + output_shapes=[(B, T, num_tiles), (B, T, num_tiles)], + output_dtypes=[x.dtype, mx.int32], + ) + best_tile = mx.argmax(tile_values, axis=-1) + return mx.take_along_axis(tile_indices, best_tile[..., None], axis=-1).squeeze( + -1 + ) + + def _target_verify_timewise(fn, x: mx.array) -> mx.array: return mx.concatenate([fn(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) @@ -240,6 +609,11 @@ def _target_verify_linear(linear, x: mx.array, target_verify: bool) -> mx.array: return linear(x) if isinstance(linear, nn.QuantizedLinear): + if x.shape[0] == 1: + return linear(x) + out = _target_verify_quantized_linear(linear, x) + if out is not None: + return out return _target_verify_timewise(linear, x) if isinstance(linear, nn.Linear) and "bias" not in linear: @@ -1479,6 +1853,14 @@ def speculative_logits_from_hidden(self, hidden: mx.array) -> mx.array: return self.model.embed_tokens.as_linear(hidden) return self.lm_head(hidden) + def speculative_argmax_from_hidden(self, hidden: mx.array) -> Optional[mx.array]: + if not self.args.tie_word_embeddings: + out = _target_verify_quantized_argmax(self.lm_head, hidden) + if out is not None: + return out + logits = self.speculative_logits_from_hidden(hidden) + return mx.argmax(logits, axis=-1) + def speculative_verify_logits(self, inputs: mx.array, cache, sampler): out = self( inputs, diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index 041aa7273..c04ef1f21 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -142,6 +142,15 @@ def _mtp_verify_target( sample_target_tokens: bool = True, ) -> _MTPVerifyResult: if sample_target_tokens: + argmax_from_hidden = getattr(lm, "speculative_argmax_from_hidden", None) + if callable(argmax_from_hidden): + result = _mtp_verify_without_logits(lm, verify_input, prompt_cache) + if result is not None: + target_tokens = argmax_from_hidden(result.hidden) + if target_tokens is not None: + result.target_tokens = target_tokens + return result + result = _mtp_verify_with_model_method(lm, verify_input, prompt_cache, sampler) if result is not None: return result @@ -176,6 +185,8 @@ def _speculative_walk_deferred_greedy( draft_tokens: mx.array, sampler: Callable[[mx.array], mx.array], budget: int, + row_id: int = 0, + base_position: Optional[int] = None, ) -> Tuple[int, List[int]]: """Greedy MTP walk that projects target logits only until rejection.""" n_draft = draft_tokens.shape[1] @@ -191,7 +202,15 @@ def _speculative_walk_deferred_greedy( if logits.ndim == 3 and logits.shape[1] == 1: logits = logits[:, 0, :] logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - target_token = sampler(logprobs) + sample_target = getattr(sampler, "sample_target", None) + if callable(sample_target) and base_position is not None: + target_token = sample_target( + logprobs, + row_ids=[row_id], + positions=[int(base_position) + pos], + ) + else: + target_token = sampler(logprobs) mx.eval(target_token) token = int(target_token.reshape(-1).item()) @@ -347,6 +366,8 @@ def _mtp_acceptance_walk( draft_tokens: mx.array, sampler: Callable[[mx.array], mx.array], budget: int, + row_id: int = 0, + base_position: Optional[int] = None, ) -> Tuple[int, List[int]]: if verify.target_tokens is not None: mx.async_eval(verify.target_tokens, verify.hidden) @@ -359,6 +380,8 @@ def _mtp_acceptance_walk( draft_tokens, sampler, budget, + row_id=row_id, + base_position=base_position, ) @@ -557,8 +580,12 @@ def _mtp_rounds( draft_tokens, sampler, max_tokens - emitted, + row_id=0, + base_position=emitted, + ) + sampler_rng.target_sampled( + sync_draft=not _sampler_supports_positioned_target(sampler) ) - sampler_rng.target_sampled(sync_draft=True) _record_speculative_round(draft_model, accepted, bs - 1) for tok in new_tokens: diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index b22ba9783..363ecdd59 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -435,6 +435,34 @@ def test_qwen_target_verify_gemv_kernel_matches_singleton_dense_gemv(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen_target_verify_quantized_linear_matches_singleton_path(): + mx.random.seed(15) + linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((2, 3, 512)).astype(mx.bfloat16) + + ref = qwen_language._target_verify_timewise(linear, x) + out = qwen_language._target_verify_linear(linear, x, target_verify=True) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + +def test_qwen_target_verify_quantized_argmax_matches_singleton_path(): + mx.random.seed(16) + linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + + x = mx.random.normal((2, 3, 512)).astype(mx.bfloat16) + ref = mx.argmax(qwen_language._target_verify_timewise(linear, x), axis=-1) + out = qwen_language._target_verify_quantized_argmax(linear, x) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + def test_qwen_target_verify_small_projection_matches_singleton_dense_gemv(): mx.random.seed(10) linear = nn.Linear(256, 8, bias=False) @@ -1087,6 +1115,39 @@ def verify_logits(inputs, cache, sampler): assert result.target_tokens is target_tokens +def test_mtp_verify_target_prefers_argmax_hidden_hook_for_greedy_tokens(): + verify_input = mx.array([[7, 8]], dtype=mx.int32) + hidden = mx.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=mx.float32) + target_tokens = mx.array([[3, 4]], dtype=mx.int32) + calls = [] + + def verify_hidden(inputs, cache): + calls.append((inputs, cache)) + return hidden, {"full": ("k", "v")}, ["gdn"] + + lm = SimpleNamespace( + speculative_verify_hidden=verify_hidden, + speculative_argmax_from_hidden=lambda h: target_tokens, + speculative_verify_logits=lambda *args: (_ for _ in ()).throw( + AssertionError("full logits should not be used for greedy argmax") + ), + ) + + result = _mtp_verify_target( + lm, + verify_input, + prompt_cache=["cache"], + sampler=lambda logits: mx.argmax(logits, axis=-1), + ) + + assert calls[0][0] is verify_input + assert calls[0][1] == ["cache"] + assert result.hidden is hidden + assert result.shared_kv_states == {"full": ("k", "v")} + assert result.gdn_states == ["gdn"] + assert result.target_tokens is target_tokens + + def test_mtp_rounds_rolls_back_gemma_without_gdn_states(): class Draft: def __init__(self): From 8169230b4bf589605de263bcdb117d7284d039f6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 23 May 2026 10:19:31 +0200 Subject: [PATCH 11/26] Speed up exact positioned MTP sampling --- mlx_vlm/speculative/mtp.py | 73 ++++++++++++++++++++++++++----- mlx_vlm/tests/test_speculative.py | 66 ++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 11 deletions(-) diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index c04ef1f21..2c9fe7d56 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -227,6 +227,34 @@ def _speculative_walk_deferred_greedy( return accepted, new_tokens +def _positioned_target_tokens( + lm: nn.Module, + target_hidden: mx.array, + sampler: Callable[[mx.array], mx.array], + *, + row_id: int, + base_position: int, +) -> Optional[mx.array]: + sample_target = getattr(sampler, "sample_target", None) + if not callable(sample_target): + return None + + with mx.stream(generation_stream): + logits = lm.speculative_logits_from_hidden(target_hidden) + if logits.ndim == 3: + if logits.shape[0] != 1: + return None + logits = logits[0] + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + positions = [int(base_position) + pos for pos in range(logprobs.shape[0])] + target_tokens = sample_target( + logprobs, + row_ids=[int(row_id)] * len(positions), + positions=positions, + ) + return target_tokens[None, :] + + def _speculative_walk_batch_deferred_greedy( lm: nn.Module, target_hidden: mx.array, @@ -373,6 +401,18 @@ def _mtp_acceptance_walk( mx.async_eval(verify.target_tokens, verify.hidden) return _speculative_walk(draft_tokens, verify.target_tokens, budget) + if base_position is not None: + target_tokens = _positioned_target_tokens( + lm, + verify.hidden, + sampler, + row_id=row_id, + base_position=base_position, + ) + if target_tokens is not None: + mx.async_eval(target_tokens, verify.hidden) + return _speculative_walk(draft_tokens, target_tokens, budget) + mx.async_eval(verify.hidden) return _speculative_walk_deferred_greedy( lm, @@ -505,7 +545,10 @@ def _mtp_rounds( block_total = _dflash_block_total(draft_model, draft_block_size) configured_block_total = int(getattr(draft_model.config, "block_size", block_total)) draft_model.reset(model) - sampler_rng = _SpeculativeSamplerRNG(draft_model, enabled=not greedy_sampling) + sampler_rng = _SpeculativeSamplerRNG( + draft_model, + enabled=not greedy_sampling and not _sampler_supports_positioned_target(sampler), + ) # Hidden from prefill is full prompt-length; reduce to a single slot. # The semantically-correct choice is the *last* prompt token's hidden: @@ -524,7 +567,7 @@ def _mtp_rounds( first_bonus, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) if hidden.shape[1] > 1: @@ -560,7 +603,7 @@ def _mtp_rounds( bs, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) with mx.stream(generation_stream): @@ -604,7 +647,7 @@ def _mtp_rounds( new_tokens, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) # Hidden for next round: pick the slot of the newly accepted bonus. @@ -666,8 +709,13 @@ def _mtp_draft_position(kv_valid_len: Any) -> Any: return mx.maximum(mx.array(kv_valid_len, dtype=mx.int32) - 1, 0) -def _mtp_draft_kwargs(draft_model: nn.Module, greedy_sampling: bool) -> Dict[str, bool]: - if greedy_sampling and getattr(draft_model, "supports_greedy_draft_argmax", False): +def _mtp_draft_kwargs( + draft_model: nn.Module, + greedy_sampling: bool, + sampler: Optional[Callable[[mx.array], mx.array]] = None, +) -> Dict[str, bool]: + greedy_draft = greedy_sampling or _sampler_supports_positioned_target(sampler) + if greedy_draft and getattr(draft_model, "supports_greedy_draft_argmax", False): return {"greedy": True} return {} @@ -696,7 +744,7 @@ def _mtp_draft_block_active( block_size, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) positions_list = [int(position) for position in positions] @@ -709,7 +757,7 @@ def _mtp_draft_block_active( block_size, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) rowwise_tokens = [] @@ -736,7 +784,7 @@ def _mtp_draft_block_active( block_size, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) ) @@ -792,7 +840,10 @@ def _mtp_rounds_batch( draft_model.reset(model, left_padding=[0] * B) else: draft_model.reset(model) - sampler_rng = _SpeculativeSamplerRNG(draft_model, enabled=not greedy_sampling) + sampler_rng = _SpeculativeSamplerRNG( + draft_model, + enabled=not greedy_sampling and not _sampler_supports_positioned_target(sampler), + ) # First-round hidden: prefill output may have shape [B, L, H]; reduce # to a single slot per row (last prompt token's hidden — see comment in @@ -936,7 +987,7 @@ def _mtp_rounds_batch( new_tokens_list, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) # Per-row hidden: each row picks its own accepted slot from diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 363ecdd59..fc1dd9fe8 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -847,6 +847,72 @@ def as_linear(self, hidden): assert fake_head.calls == 2 +def test_mtp_acceptance_walk_samples_positioned_block_once(): + class FakeEmbed: + def __init__(self): + self.calls = 0 + + def as_linear(self, hidden): + self.calls += 1 + return hidden + + class PositionedSampler: + def __init__(self): + self.calls = [] + + def __call__(self, logprobs): + raise AssertionError("positioned target sampler was not used") + + def sample_target(self, logprobs, *, row_ids, positions): + self.calls.append((list(row_ids), list(positions))) + return mx.argmax(logprobs, axis=-1) + + fake_head = FakeEmbed() + sampler = PositionedSampler() + lm = SimpleNamespace(speculative_logits_from_hidden=fake_head.as_linear) + target_hidden = mx.array( + [ + [ + [0, 0, 9, 0], + [0, 0, 0, 9], + [0, 9, 0, 0], + [9, 0, 0, 0], + ] + ], + dtype=mx.float32, + ) + draft_tokens = mx.array([[2, 1, 3]], dtype=mx.int32) + verify = mtp_utils._MTPVerifyResult(hidden=target_hidden, shared_kv_states={}) + + accepted, new_tokens = mtp_utils._mtp_acceptance_walk( + lm, + verify, + draft_tokens, + sampler, + budget=4, + row_id=5, + base_position=7, + ) + + assert accepted == 1 + assert new_tokens == [2, 3] + assert fake_head.calls == 1 + assert sampler.calls == [([5, 5, 5, 5], [7, 8, 9, 10])] + + +def test_mtp_draft_kwargs_uses_greedy_for_positioned_sampler(): + class PositionedSampler: + def sample_target(self, logprobs, *, row_ids, positions): + return mx.argmax(logprobs, axis=-1) + + draft_model = SimpleNamespace(supports_greedy_draft_argmax=True) + + assert mtp_utils._mtp_draft_kwargs(draft_model, False) == {} + assert mtp_utils._mtp_draft_kwargs(draft_model, False, PositionedSampler()) == { + "greedy": True + } + + def test_speculative_walk_batch_deferred_greedy_matches_batch_walk(): class FakeEmbed: def __init__(self): From c3581f26d154a096c42aeb314444bcabe6895d3d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 23 May 2026 11:31:11 +0200 Subject: [PATCH 12/26] Use exact qmatvec for Qwen MTP verifier logits --- mlx_vlm/models/qwen3_5/language.py | 3 +++ mlx_vlm/tests/test_speculative.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index bc8cb6b1e..e6a80377c 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -1851,6 +1851,9 @@ def __call__( def speculative_logits_from_hidden(self, hidden: mx.array) -> mx.array: if self.args.tie_word_embeddings: return self.model.embed_tokens.as_linear(hidden) + out = _target_verify_quantized_linear(self.lm_head, hidden) + if out is not None: + return out return self.lm_head(hidden) def speculative_argmax_from_hidden(self, hidden: mx.array) -> Optional[mx.array]: diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index fc1dd9fe8..9451ac778 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -449,6 +449,20 @@ def test_qwen_target_verify_quantized_linear_matches_singleton_path(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen_target_verify_quantized_linear_matches_singleton_batch_path(): + mx.random.seed(17) + linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((1, 3, 512)).astype(mx.bfloat16) + + ref = linear(x) + out = qwen_language._target_verify_quantized_linear(linear, x) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + def test_qwen_target_verify_quantized_argmax_matches_singleton_path(): mx.random.seed(16) linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) From 9b25b25c8a50097c952a14b2295421bf3b34d1e6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 23 May 2026 14:52:30 +0200 Subject: [PATCH 13/26] Fuse Qwen GDN accepted state scatter --- mlx_vlm/models/qwen3_5/gated_delta.py | 142 ++++++++++++++++++++++++++ mlx_vlm/models/qwen3_5/language.py | 127 ++++++++++++++++------- mlx_vlm/tests/test_speculative.py | 33 ++++++ 3 files changed, 263 insertions(+), 39 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/gated_delta.py b/mlx_vlm/models/qwen3_5/gated_delta.py index 7111730ff..6d8ea7ae0 100644 --- a/mlx_vlm/models/qwen3_5/gated_delta.py +++ b/mlx_vlm/models/qwen3_5/gated_delta.py @@ -276,6 +276,69 @@ def _make_gated_delta_state_kernel(has_mask: bool = False): _gated_delta_state_kernel_masked = _make_gated_delta_state_kernel(True) +def _make_gated_delta_accept_states_kernel(): + if not mx.metal.is_available(): + return None + + return mx.fast.metal_kernel( + name="qwen3_5_gated_delta_accept_states", + input_names=[ + "intermediate_states", + "conv_input", + "live_state", + "live_conv", + "accepted", + ], + output_names=["state_out", "conv_out"], + source=r""" + uint idx = thread_position_in_grid.x; + + if (idx < StateTotal) { + uint dk = idx % Dk; + uint t0 = idx / Dk; + uint dv = t0 % Dv; + t0 /= Dv; + uint hv = t0 % Hv; + uint row = t0 / Hv; + + int step = int(accepted[row]); + bool use_intermediate = step >= 0 && step < T; + StT value; + if (use_intermediate) { + value = intermediate_states[ + ((((row * T + uint(step)) * Hv + hv) * Dv + dv) * Dk + dk) + ]; + } else { + value = live_state[((row * Hv + hv) * Dv + dv) * Dk + dk]; + } + state_out[idx] = static_cast(value); + } + + if (idx < ConvTotal) { + uint c = idx % C; + uint t0 = idx / C; + uint win = t0 % ConvW; + uint row = t0 / ConvW; + + int step = int(accepted[row]); + bool use_intermediate = step >= 0 && step < T; + ConvT value; + if (use_intermediate) { + value = conv_input[ + (row * ConvInputT + uint(step) + 1 + win) * C + c + ]; + } else { + value = live_conv[(row * ConvW + win) * C + c]; + } + conv_out[idx] = static_cast(value); + } + """, + ) + + +_gated_delta_accept_states_kernel = _make_gated_delta_accept_states_kernel() + + def _gated_delta_state_ops( k: mx.array, v: mx.array, @@ -305,6 +368,85 @@ def _gated_delta_state_ops( return state +def _gated_delta_accept_states_ops( + intermediate_states: mx.array, + conv_input: mx.array, + live_state: mx.array, + live_conv: mx.array, + accepted: mx.array, + kernel_size: int, +): + steps = [int(step) for step in accepted.tolist()] + state_rows = [] + conv_rows = [] + state_steps = intermediate_states.shape[1] + for row, step in enumerate(steps): + if 0 <= step < state_steps: + state_rows.append(intermediate_states[row, step]) + conv_rows.append(conv_input[row : row + 1, step + 1 : step + kernel_size]) + else: + state_rows.append(live_state[row]) + conv_rows.append(live_conv[row : row + 1]) + return mx.stack(state_rows, axis=0), mx.concatenate(conv_rows, axis=0) + + +def gated_delta_accept_states( + intermediate_states: mx.array, + conv_input: mx.array, + live_state: mx.array, + live_conv: mx.array, + accepted: mx.array, + kernel_size: int, + use_kernel: bool = True, +): + if accepted.dtype != mx.int32: + accepted = accepted.astype(mx.int32) + + if ( + not use_kernel + or mx.default_device() != mx.gpu + or not mx.metal.is_available() + or _gated_delta_accept_states_kernel is None + ): + return _gated_delta_accept_states_ops( + intermediate_states, + conv_input, + live_state, + live_conv, + accepted, + kernel_size, + ) + + rows, state_steps, Hv, Dv, Dk = intermediate_states.shape + conv_input_t = conv_input.shape[1] + conv_dim = conv_input.shape[-1] + conv_window = int(kernel_size) - 1 + state_total = rows * Hv * Dv * Dk + conv_total = rows * conv_window * conv_dim + total = max(state_total, conv_total) + + return _gated_delta_accept_states_kernel( + inputs=[intermediate_states, conv_input, live_state, live_conv, accepted], + template=[ + ("StT", intermediate_states.dtype), + ("ConvT", conv_input.dtype), + ("T", state_steps), + ("Hv", Hv), + ("Dv", Dv), + ("Dk", Dk), + ("C", conv_dim), + ("ConvW", conv_window), + ("ConvInputT", conv_input_t), + ("StateTotal", state_total), + ("ConvTotal", conv_total), + ], + grid=(total, 1, 1), + threadgroup=(256, 1, 1), + output_shapes=[live_state.shape, live_conv.shape], + output_dtypes=[intermediate_states.dtype, conv_input.dtype], + ) + + def gated_delta_state_update( k: mx.array, v: mx.array, diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index e6a80377c..5775d412f 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -13,6 +13,7 @@ from ..cache import ArraysCache, KVCache from .config import ModelConfig, TextConfig from .gated_delta import ( + gated_delta_accept_states, gated_delta_state_update, gated_delta_update, gated_delta_update_with_states, @@ -1365,47 +1366,95 @@ def rollback_speculative_cache( if all(len(s) > 11 and s[11] is not None for s in gdn_states): a0 = int(accepted[0].item()) if not is_batch else None - for j, c in enumerate(ssm_caches): - ( - _q, - _k, - _v, - _a, - _b, - _A_log, - _dt_bias, - _init_state, - _mask, - conv_input, - K, - intermediate_states, - *_, - ) = gdn_states[j] - if is_batch: - acc_list = accepted.tolist() - state_steps = intermediate_states.shape[1] - states = [ - ( - intermediate_states[bi, int(acc_list[bi])] - if int(acc_list[bi]) < state_steps - else c[1][bi] + if is_batch: + intermediate_parts = [] + conv_input_parts = [] + live_state_parts = [] + live_conv_parts = [] + layer_batch_sizes = [] + kernel_sizes = [] + + for j, c in enumerate(ssm_caches): + ( + _q, + _k, + _v, + _a, + _b, + _A_log, + _dt_bias, + _init_state, + _mask, + conv_input, + K, + intermediate_states, + *_, + ) = gdn_states[j] + rows = intermediate_states.shape[0] + layer_batch_sizes.append(rows) + kernel_sizes.append(int(K)) + intermediate_parts.append(intermediate_states) + conv_input_parts.append(conv_input) + + live_state = c[1] + if live_state is None: + live_state = mx.zeros( + ( + rows, + intermediate_states.shape[2], + intermediate_states.shape[3], + intermediate_states.shape[4], + ), + dtype=intermediate_states.dtype, ) - for bi in range(len(acc_list)) - ] - c[1] = mx.stack(states, axis=0) - slices = [ - ( - conv_input[ - bi : bi + 1, - int(acc_list[bi]) + 1 : int(acc_list[bi]) + K, - ] - if int(acc_list[bi]) < state_steps - else c[0][bi : bi + 1] + live_state_parts.append(live_state) + + live_conv = c[0] + if live_conv is None: + live_conv = mx.zeros( + (rows, int(K) - 1, conv_input.shape[-1]), + dtype=conv_input.dtype, ) - for bi in range(len(acc_list)) - ] - c[0] = mx.concatenate(slices, axis=0) - else: + live_conv_parts.append(live_conv) + + if len(set(kernel_sizes)) != 1: + raise ValueError("Qwen GDN layers must share conv kernel size.") + + accepted_bat = mx.concatenate( + [accepted.astype(mx.int32) for _ in ssm_caches], axis=0 + ) + state_bat, conv_bat = gated_delta_accept_states( + mx.concatenate(intermediate_parts, axis=0), + mx.concatenate(conv_input_parts, axis=0), + mx.concatenate(live_state_parts, axis=0), + mx.concatenate(live_conv_parts, axis=0), + accepted_bat, + kernel_sizes[0], + use_kernel=True, + ) + + offset = 0 + for c, rows in zip(ssm_caches, layer_batch_sizes): + c[1] = state_bat[offset : offset + rows] + c[0] = conv_bat[offset : offset + rows] + offset += rows + else: + for j, c in enumerate(ssm_caches): + ( + _q, + _k, + _v, + _a, + _b, + _A_log, + _dt_bias, + _init_state, + _mask, + conv_input, + K, + intermediate_states, + *_, + ) = gdn_states[j] if a0 < intermediate_states.shape[1]: c[1] = intermediate_states[:, a0] c[0] = conv_input[:, a0 + 1 : a0 + K] diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 9451ac778..75244e982 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -263,6 +263,39 @@ def test_qwen_rollback_speculative_cache_uses_intermediate_states(): ] +def test_qwen_gated_delta_accept_states_matches_python_gather(): + accepted = mx.array([0, 2, 1, 3], dtype=mx.int32) + intermediate_states = mx.arange(4 * 4 * 2 * 3 * 5, dtype=mx.float32).reshape( + 4, 4, 2, 3, 5 + ) + conv_input = mx.arange(4 * 7 * 6, dtype=mx.float32).reshape(4, 7, 6) + live_state = mx.full((4, 2, 3, 5), -1.0, dtype=mx.float32) + live_conv = mx.full((4, 3, 6), -2.0, dtype=mx.float32) + + ref_state, ref_conv = qwen_language.gated_delta_accept_states( + intermediate_states, + conv_input, + live_state, + live_conv, + accepted, + kernel_size=4, + use_kernel=False, + ) + out_state, out_conv = qwen_language.gated_delta_accept_states( + intermediate_states, + conv_input, + live_state, + live_conv, + accepted, + kernel_size=4, + use_kernel=True, + ) + mx.eval(ref_state, ref_conv, out_state, out_conv) + + assert bool(mx.array_equal(ref_state, out_state).item()) + assert bool(mx.array_equal(ref_conv, out_conv).item()) + + def test_qwen_rollback_speculative_cache_zero_inits_missing_state(): accepted = mx.array([1, 0], dtype=mx.int32) caches = [ArraysCache(size=2)] From a94e83c460dee047009e42e541276447c2d543b3 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 23 May 2026 17:21:22 +0200 Subject: [PATCH 14/26] Fix singleton batch generator cache performance --- mlx_vlm/generate.py | 36 ++++++++++++++++++++++++-- mlx_vlm/tests/test_generate.py | 46 ++++++++++++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index b48303286..fb2e569df 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -1522,8 +1522,33 @@ def _extend_cache(cache_a, cache_b): return cache_b if not cache_b: return cache_a - for ca, cb in zip(cache_a, cache_b): - ca.extend(cb) + + def kv_rows(c): + if isinstance(c, cache.KVCache): + return [c] + offset = getattr(c, "offset", None) + if ( + hasattr(c, "extract") + and isinstance(offset, mx.array) + and offset.ndim > 0 + ): + return [c.extract(i) for i in range(int(offset.shape[0]))] + return None + + for i, (ca, cb) in enumerate(zip(cache_a, cache_b)): + if ( + hasattr(ca, "extend") + and not isinstance(ca, cache.KVCache) + and not isinstance(cb, cache.KVCache) + ): + ca.extend(cb) + continue + ca_rows = kv_rows(ca) + cb_rows = kv_rows(cb) + if ca_rows is not None and cb_rows is not None: + cache_a[i] = cache.BatchKVCache.merge(ca_rows + cb_rows) + continue + raise ValueError(f"{type(ca)} does not yet support batch extension") return cache_a @@ -2076,6 +2101,13 @@ def __init__( if warm_cache is not None: self.prompt_cache = warm_cache + elif ( + len(input_ids) == 1 + and right_pad_per_row is None + and kv_bits is None + and hasattr(model, "make_cache") + ): + self.prompt_cache = cache.make_prompt_cache(model) else: self.prompt_cache = _make_cache( model, diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index 4858fa042..6ef4da3df 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -21,6 +21,7 @@ _prime_cached_prefix_rope_state, normalize_resize_shape, ) +from mlx_vlm.models import cache as cache_module from mlx_vlm.utils import ThinkingBudgetCriteria generate_module = sys.modules["mlx_vlm.generate"] @@ -1417,6 +1418,51 @@ def fake_prompt_batch(**kwargs): assert prompt_kwargs["per_layer_inputs"][3, :, 0, 0].tolist() == [0, 0, 0, 4] +def test_singleton_prompt_batch_uses_singleton_kv_cache(): + model = SimpleNamespace(make_cache=lambda: [cache_module.KVCache()]) + + batch = generate_module.PromptProcessingBatch( + model=model, + uids=[1], + input_ids=[[1, 2, 3]], + max_tokens=[4], + inputs_embeds=mx.ones((1, 3, 2)), + prompt_kwargs={}, + ) + + assert isinstance(batch.prompt_cache[0], cache_module.KVCache) + assert not isinstance(batch.prompt_cache[0], cache_module.BatchKVCache) + + +def test_extend_cache_promotes_singleton_kv_cache_to_batch(): + left = cache_module.KVCache() + right = cache_module.KVCache() + left.update_and_fetch(mx.ones((1, 1, 3, 2)), mx.ones((1, 1, 3, 2))) + right.update_and_fetch(mx.ones((1, 1, 2, 2)), mx.ones((1, 1, 2, 2))) + + merged = generate_module._extend_cache([left], [right]) + + assert isinstance(merged[0], cache_module.BatchKVCache) + assert merged[0].left_padding.tolist() == [0, 1] + assert merged[0].offset.tolist() == [3, 2] + + +def test_extend_cache_promotes_batch_and_singleton_kv_cache_to_batch(): + left_a = cache_module.KVCache() + left_b = cache_module.KVCache() + right = cache_module.KVCache() + left_a.update_and_fetch(mx.ones((1, 1, 3, 2)), mx.ones((1, 1, 3, 2))) + left_b.update_and_fetch(mx.ones((1, 1, 1, 2)), mx.ones((1, 1, 1, 2))) + right.update_and_fetch(mx.ones((1, 1, 2, 2)), mx.ones((1, 1, 2, 2))) + left = cache_module.BatchKVCache.merge([left_a, left_b]) + + merged = generate_module._extend_cache([left], [right]) + + assert isinstance(merged[0], cache_module.BatchKVCache) + assert merged[0].left_padding.tolist() == [0, 2, 1] + assert merged[0].offset.tolist() == [3, 1, 2] + + def test_mixed_apc_batch_strips_private_kwargs_before_prefill(): bg = object.__new__(BatchGenerator) bg.apc_manager = object() From 774c69e503f7c4ea00a2f7de48020520aa994864 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 24 May 2026 04:39:43 +0200 Subject: [PATCH 15/26] Route server MTP through batch generator --- mlx_vlm/generate.py | 352 +++++++++++++++++++++++++++-- mlx_vlm/models/qwen3_5/language.py | 2 +- mlx_vlm/server/app.py | 1 + mlx_vlm/server/generation.py | 23 +- mlx_vlm/tests/test_generate.py | 34 +++ mlx_vlm/tests/test_server.py | 148 ++++++++++++ 6 files changed, 534 insertions(+), 26 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index fb2e569df..b0a266b57 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -23,7 +23,14 @@ from .models import cache from .prompt_utils import apply_chat_template from .sample_utils import top_p_sampling -from .speculative.utils import format_speculative_stats, run_speculative_rounds +from .speculative.utils import ( + format_speculative_stats, + make_speculative_prompt_cache, + run_speculative_rounds, + run_speculative_server_rounds, + speculative_hidden_state, + speculative_prefill_kwargs, +) from .tokenizer_utils import make_streaming_detokenizer from .turboquant import BatchTurboQuantKVCache, TurboQuantKVCache, turboquant_enabled from .utils import ( @@ -1721,6 +1728,7 @@ def __init__( stop_criteria, max_tokens: List[int], top_logprobs_k: int = 0, + greedy_sampling: bool = False, token_context: Optional[List[List[int]]] = None, logits_processors: Optional[ List[Optional[List[Callable[[mx.array, mx.array], mx.array]]]] @@ -1736,6 +1744,7 @@ def __init__( self._num_tokens = [0] * len(uids) self.compute_logprobs = True self.top_logprobs_k = top_logprobs_k + self.greedy_sampling = greedy_sampling self.logits_processors = logits_processors or [] self.token_context = [list(ctx) for ctx in (token_context or [])] @@ -1752,6 +1761,36 @@ def __init__( def __len__(self): return len(self.uids) + def _greedy_argmax_step(self, inputs: mx.array, fwd_kwargs: dict): + if ( + not self.greedy_sampling + or self.compute_logprobs + or self.top_logprobs_k > 0 + or (self.logits_processors and any(self.logits_processors)) + ): + return None + + argmax_from_hidden = getattr( + self._language_model, "speculative_argmax_from_hidden", None + ) + if not callable(argmax_from_hidden): + return None + + output = self._language_model( + inputs[:, None], + cache=self.prompt_cache, + return_hidden=True, + skip_logits=True, + **fwd_kwargs, + ) + hidden = output.hidden_states[-1] + sampled = argmax_from_hidden(hidden) + if sampled is None: + return None + if sampled.ndim == 2 and sampled.shape[1] == 1: + sampled = sampled[:, 0] + return sampled + def _step(self): """Perform one generation step with double buffering.""" self._current_tokens = self._next_tokens @@ -1762,6 +1801,16 @@ def _step(self): if self._rope_deltas is not None: fwd_kwargs["rope_deltas"] = self._rope_deltas + sampled = self._greedy_argmax_step(inputs, fwd_kwargs) + if sampled is not None: + self._next_tokens = sampled + self._next_lps = None + self._next_top_idx = None + self._next_top_lp = None + mx.async_eval(self._next_tokens) + mx.eval(inputs) + return inputs.tolist(), None, None, None + output = self._language_model( inputs[:, None], cache=self.prompt_cache, **fwd_kwargs ) @@ -1977,7 +2026,13 @@ def next(self) -> List[Response]: @classmethod def empty( - cls, model, sampler, stop_criteria, compute_logprobs=True, top_logprobs_k=0 + cls, + model, + sampler, + stop_criteria, + compute_logprobs=True, + top_logprobs_k=0, + greedy_sampling: bool = False, ): """Create an empty generation batch.""" batch = cls.__new__(cls) @@ -1991,6 +2046,7 @@ def empty( batch._num_tokens = [] batch.compute_logprobs = compute_logprobs batch.top_logprobs_k = top_logprobs_k + batch.greedy_sampling = greedy_sampling batch.token_context = [] batch.logits_processors = [] batch._current_tokens = None @@ -2003,6 +2059,176 @@ def empty( return batch +class SpeculativeGenerationBatch: + """GenerationBatch-compatible wrapper for server-side MTP decode.""" + + is_speculative = True + Response = GenerationBatch.Response + + def __init__( + self, + model: nn.Module, + draft_model: nn.Module, + draft_kind: str, + uids: List[int], + first_tokens: mx.array, + prompt_cache: List[Any], + sampler: Callable[[mx.array], mx.array], + stop_criteria, + max_tokens: List[int], + hidden: mx.array, + shared_kv_states: Optional[dict], + prompt_tokens: mx.array, + *, + draft_block_size: Optional[int] = None, + token_dtype: mx.Dtype = mx.int32, + greedy_sampling: bool = False, + ): + self.model = model + self.draft_model = draft_model + self.draft_kind = draft_kind + self.uids = list(uids) + self._all_uids = list(uids) + self.first_tokens = first_tokens + self.prompt_cache = prompt_cache + self.sampler = sampler + self.stop_criteria = stop_criteria + self.max_tokens = list(max_tokens) + self.hidden = hidden + self.shared_kv_states = shared_kv_states + self.prompt_tokens = prompt_tokens + self.draft_block_size = draft_block_size + self.token_dtype = token_dtype + self.greedy_sampling = greedy_sampling + self._num_tokens = [0] * len(uids) + self._finished = [False] * len(uids) + self._sent_first = False + self._rounds_iter = None + + def __len__(self): + return sum(not done for done in self._finished) + + def _refresh_uids(self): + self.uids = [ + uid for uid, done in zip(self._all_uids, self._finished) if not done + ] + + def extend(self, other: "SpeculativeGenerationBatch"): + if len(self) == 0: + self.__dict__.update(other.__dict__) + return + raise RuntimeError("Cannot extend an active speculative generation batch.") + + def filter(self, keep: List[int]): + keep_uids = {self.uids[idx] for idx in keep} + for i, uid in enumerate(self._all_uids): + if uid not in keep_uids: + self._finished[i] = True + self._refresh_uids() + + def _finish_reason(self, row: int, token: int) -> Optional[str]: + if self.stop_criteria(token): + return "stop" + if self._num_tokens[row] >= self.max_tokens[row]: + return "length" + return None + + def _start_rounds(self): + if self._rounds_iter is not None: + return + + def stop_check(seq_idx, token_id): + return ( + self._finished[seq_idx] + or self.stop_criteria(token_id) + or self._num_tokens[seq_idx] >= self.max_tokens[seq_idx] + ) + + self._rounds_iter = run_speculative_server_rounds( + self.model, + self.draft_model, + self.prompt_cache, + self.hidden, + draft_kind=self.draft_kind, + first_bonus=self.first_tokens, + max_tokens=max(self.max_tokens) if self.max_tokens else 0, + sampler=self.sampler, + draft_block_size=self.draft_block_size, + token_dtype=self.token_dtype, + stop_check=stop_check, + greedy_sampling=self.greedy_sampling, + shared_kv_states=self.shared_kv_states, + eos_token_ids=None, + prompt_tokens=self.prompt_tokens, + row_ids=[0] * len(self._all_uids), + ) + + def next(self) -> List[Response]: + if len(self) == 0: + return [] + + responses: List[GenerationBatch.Response] = [] + if not self._sent_first: + self._sent_first = True + mx.eval(self.first_tokens) + for row, token in enumerate(self.first_tokens.tolist()): + if self._finished[row]: + continue + token = int(token) + self._num_tokens[row] += 1 + finish_reason = self._finish_reason(row, token) + if finish_reason is not None: + self._finished[row] = True + responses.append( + self.Response( + uid=self._all_uids[row], + token=token, + token_logprob=0.0, + finish_reason=finish_reason, + ) + ) + self._refresh_uids() + return responses + + self._start_rounds() + try: + tok_list, _ = next(self._rounds_iter) + except StopIteration: + for row, done in enumerate(self._finished): + if not done: + self._finished[row] = True + responses.append( + self.Response( + uid=self._all_uids[row], + token=None, + token_logprob=0.0, + finish_reason="length", + ) + ) + self._refresh_uids() + return responses + + for row, token in enumerate(tok_list): + if token is None or self._finished[row]: + continue + token = int(token) + self._num_tokens[row] += 1 + finish_reason = self._finish_reason(row, token) + if finish_reason is not None: + self._finished[row] = True + responses.append( + self.Response( + uid=self._all_uids[row], + token=token, + token_logprob=0.0, + finish_reason=finish_reason, + ) + ) + + self._refresh_uids() + return responses + + class PromptProcessingBatch: """ Handles VLM prompt processing with inputs_embeds and chunked prefill. @@ -2033,12 +2259,20 @@ def __init__( right_pad_per_row: Optional[List[int]] = None, suffix_lens: Optional[List[int]] = None, apc_mode: Optional[str] = None, + draft_model: Optional[nn.Module] = None, + draft_kind: Optional[str] = None, + draft_block_size: Optional[int] = None, + greedy_sampling: bool = False, ): self.model = model self.uids = uids self._prompt_uids = list(uids) self.max_tokens = max_tokens self.prefill_step_size = prefill_step_size + self.draft_model = draft_model + self.draft_kind = draft_kind + self.draft_block_size = draft_block_size + self.greedy_sampling = greedy_sampling lengths = [len(ids) for ids in input_ids] max_length = max(lengths) @@ -2101,6 +2335,20 @@ def __init__( if warm_cache is not None: self.prompt_cache = warm_cache + elif draft_model is not None and draft_kind is not None: + self.prompt_cache = make_speculative_prompt_cache( + model, + draft_kind=draft_kind, + batch_size=len(input_ids), + left_padding=left_padding, + make_cache=lambda lm, lp: _make_cache( + lm, + lp, + kv_bits=kv_bits, + kv_group_size=kv_group_size, + kv_quant_scheme=kv_quant_scheme, + ), + ) elif ( len(input_ids) == 1 and right_pad_per_row is None @@ -2287,11 +2535,17 @@ def generate( self, sampler, stop_criteria, compute_logprobs=True, top_logprobs_k=0 ) -> GenerationBatch: """Process final tokens and transition to GenerationBatch.""" + call_kwargs = dict(self._prompt_kwargs) + if self.draft_model is not None and self.draft_kind is not None: + call_kwargs.update( + speculative_prefill_kwargs(self.draft_kind, self.draft_model) + ) + output = self.model( self._input_ids, cache=self.prompt_cache, inputs_embeds=self._inputs_embeds, - **self._prompt_kwargs, + **call_kwargs, ) logits = output.logits if hasattr(output, "logits") else output if self._right_pad_per_row is not None and any(self._right_pad_per_row): @@ -2352,27 +2606,50 @@ def generate( self._suffix_lens, ) - gen_batch = GenerationBatch( - model=self.model, - uids=list(self.uids), - inputs=first_tokens, - prompt_cache=self.prompt_cache, - sampler=sampler, - stop_criteria=stop_criteria, - max_tokens=list(self.max_tokens), - top_logprobs_k=top_logprobs_k, - token_context=[list(ctx) for ctx in self._token_context], - logits_processors=list(self.logits_processors), - ) + if self.draft_model is not None and self.draft_kind is not None: + gen_batch = SpeculativeGenerationBatch( + model=self.model, + draft_model=self.draft_model, + draft_kind=self.draft_kind, + uids=list(self.uids), + first_tokens=first_tokens, + prompt_cache=self.prompt_cache, + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=list(self.max_tokens), + hidden=speculative_hidden_state(self.draft_kind, output), + shared_kv_states=( + output.shared_kv_states if self.draft_kind == "mtp" else None + ), + prompt_tokens=self._input_ids, + draft_block_size=self.draft_block_size, + token_dtype=self._input_ids.dtype, + greedy_sampling=self.greedy_sampling, + ) + compute_logprobs = False + else: + gen_batch = GenerationBatch( + model=self.model, + uids=list(self.uids), + inputs=first_tokens, + prompt_cache=self.prompt_cache, + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=list(self.max_tokens), + top_logprobs_k=top_logprobs_k, + greedy_sampling=self.greedy_sampling, + token_context=[list(ctx) for ctx in self._token_context], + logits_processors=list(self.logits_processors), + ) gen_batch.compute_logprobs = compute_logprobs - if compute_logprobs: + if compute_logprobs and isinstance(gen_batch, GenerationBatch): gen_batch._next_lps = logprobs[ mx.arange(first_tokens.shape[0]), first_tokens ] # Prime top-K buffers so the first token can emit top_logprobs too. - if top_logprobs_k > 0: + if top_logprobs_k > 0 and isinstance(gen_batch, GenerationBatch): k = top_logprobs_k sort_idx = mx.argsort(logprobs, axis=-1) top_idx = sort_idx[..., -k:][..., ::-1].astype(mx.int32) @@ -2522,6 +2799,10 @@ def __init__( ] = None, stream=None, apc_manager: Optional["_apc.APCManager"] = None, + draft_model: Optional[nn.Module] = None, + draft_kind: Optional[str] = None, + draft_block_size: Optional[int] = None, + greedy_sampling: bool = False, ): self.model = model self.max_tokens = max_tokens @@ -2533,6 +2814,17 @@ def __init__( self.compute_logprobs = compute_logprobs self.top_logprobs_k = top_logprobs_k self.logits_processors = logits_processors or [] + self.draft_model = draft_model + self.draft_kind = draft_kind + self.draft_block_size = draft_block_size + self.greedy_sampling = greedy_sampling or sampler is None + if self.draft_model is not None: + prefill_step_size = None + apc_manager = None + compute_logprobs = False + top_logprobs_k = 0 + self.compute_logprobs = False + self.top_logprobs_k = 0 # APC: opt-out for KV-quantized caches. Plain KV models use block APC; # mixed/custom cache models use exact prompt-cache snapshots. self.apc_mode = None @@ -2562,6 +2854,7 @@ def __init__( self.tokenizer.stopping_criteria, compute_logprobs=self.compute_logprobs, top_logprobs_k=self.top_logprobs_k, + greedy_sampling=self.greedy_sampling, ) self._prompt_batch: Optional[PromptProcessingBatch] = None self._unprocessed_sequences = [] @@ -2844,6 +3137,10 @@ def _build_mixed_prompt_batch( right_pad_per_row=right_pad_per_row, suffix_lens=suffix_lens, apc_mode=apc_mode, + draft_model=self.draft_model, + draft_kind=self.draft_kind, + draft_block_size=self.draft_block_size, + greedy_sampling=self.greedy_sampling, ) def _build_apc_meta_for_cold( @@ -2993,6 +3290,12 @@ def _prompt_batch_progress(prompt_batch) -> List[PromptProgress]: return progress() return [] + def _extend_generation_batch(self, gen_batch) -> None: + if len(self._generation_batch) == 0: + self._generation_batch = gen_batch + else: + self._generation_batch.extend(gen_batch) + def _next(self, **kwargs): generation_responses = [] prompt_responses = [] @@ -3006,6 +3309,11 @@ def _next(self, **kwargs): mx.eval([c.state for c in self._generation_batch.prompt_cache]) mx.clear_cache() + if getattr(self._generation_batch, "is_speculative", False) and len( + self._generation_batch + ) > 0: + return prompt_responses, generation_responses + if len(self._generation_batch) >= self.completion_batch_size: return prompt_responses, generation_responses @@ -3030,7 +3338,7 @@ def _next(self, **kwargs): self._prompt_time_counter += elapsed self._record_prompt_batch_time(self._prompt_batch, elapsed) prompt_responses = self._prompt_batch_progress(self._prompt_batch) - self._generation_batch.extend(gen_batch) + self._extend_generation_batch(gen_batch) self._prompt_batch = None mx.clear_cache() return prompt_responses, generation_responses @@ -3073,7 +3381,7 @@ def _next(self, **kwargs): self._prompt_time_counter += elapsed self._record_prompt_batch_time(self._prompt_batch, elapsed) prompt_responses = self._prompt_batch_progress(self._prompt_batch) - self._generation_batch.extend(gen_batch) + self._extend_generation_batch(gen_batch) self._prompt_batch = None mx.clear_cache() return prompt_responses, generation_responses @@ -3108,6 +3416,10 @@ def _next(self, **kwargs): apc_meta=apc_meta, apc_manager=self.apc_manager, apc_mode=self.apc_mode, + draft_model=self.draft_model, + draft_kind=self.draft_kind, + draft_block_size=self.draft_block_size, + greedy_sampling=self.greedy_sampling, ) self._prompt_tokens_counter += self._prompt_batch.total_prompt_tokens @@ -3129,7 +3441,7 @@ def _next(self, **kwargs): self._prompt_time_counter += elapsed self._record_prompt_batch_time(self._prompt_batch, elapsed) prompt_responses = self._prompt_batch_progress(self._prompt_batch) - self._generation_batch.extend(gen_batch) + self._extend_generation_batch(gen_batch) self._prompt_batch = None mx.clear_cache() diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 5775d412f..cb0907b87 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -511,7 +511,7 @@ def _can_target_verify_quantized(linear, x: mx.array) -> bool: if ( not isinstance(linear, nn.QuantizedLinear) or x.ndim != 3 - or x.shape[1] <= 1 + or x.shape[1] < 1 or linear.bits not in (4, 5) or linear.mode != "affine" or linear.biases is None diff --git a/mlx_vlm/server/app.py b/mlx_vlm/server/app.py index 6cddc9bb7..0b8f62ce7 100644 --- a/mlx_vlm/server/app.py +++ b/mlx_vlm/server/app.py @@ -106,6 +106,7 @@ def _build_gen_args( top_k=getattr(request, "top_k", 0), min_p=getattr(request, "min_p", 0.0), seed=getattr(request, "seed", None), + logprobs=bool(getattr(request, "logprobs", False)), repetition_penalty=getattr(request, "repetition_penalty", None), logit_bias=logit_bias, enable_thinking=enable_thinking, diff --git a/mlx_vlm/server/generation.py b/mlx_vlm/server/generation.py index dbde63343..94a5ae8af 100644 --- a/mlx_vlm/server/generation.py +++ b/mlx_vlm/server/generation.py @@ -506,6 +506,7 @@ class GenerationArguments: top_k: int = 0 min_p: float = 0.0 seed: Optional[int] = None + logprobs: bool = False repetition_penalty: Optional[float] = None logit_bias: Optional[dict] = None enable_thinking: bool = DEFAULT_ENABLE_THINKING @@ -865,7 +866,7 @@ def _run(self): self._ready.set() - if self.draft_model is not None: + if self.draft_model is not None and self.draft_kind != "mtp": self._run_speculative() return @@ -913,9 +914,14 @@ def _run(self): kv_group_size=self.kv_group_size, kv_quant_scheme=self.kv_quant_scheme, quantized_kv_start=self.quantized_kv_start, - top_logprobs_k=self.top_logprobs_k, + compute_logprobs=bool(args.logprobs), + top_logprobs_k=self.top_logprobs_k if args.logprobs else 0, stream=generation_stream, apc_manager=self.apc_manager, + draft_model=self.draft_model, + draft_kind=self.draft_kind, + draft_block_size=_get_draft_block_size_from_env(), + greedy_sampling=args.temperature == 0, ) # Vision encoder runs on the GPU thread; text tokenization @@ -978,6 +984,9 @@ def _run(self): mx.clear_cache() gc.collect() + if batch_gen is not None and callable(getattr(batch_gen, "close", None)): + batch_gen.close() + def _run_speculative(self): """GPU thread loop with DFlash, EAGLE-3, or MTP speculative decoding. @@ -1241,10 +1250,14 @@ def _step(self, batch_gen, active, gen_kwargs=None): rqueue = info["rqueue"] tok = r.token - if hasattr(tok, "item"): + if tok is None: + text = info["streamer"].finalize() + tok = 0 + elif hasattr(tok, "item"): tok = tok.item() - - text = self._stream_text(info, tok, r.finish_reason) + text = self._stream_text(info, tok, r.finish_reason) + else: + text = self._stream_text(info, tok, r.finish_reason) lp = r.token_logprob diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index 6ef4da3df..3b6aa2824 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -576,6 +576,40 @@ def force_token_2(tokens, logits): second = batch.next() assert [r.token for r in second] == [2, 2] + def test_generation_batch_uses_greedy_hidden_argmax_without_logprobs(self): + class FastArgmaxModel: + def __init__(self): + self.calls = [] + + def __call__(self, input_ids, cache=None, **kwargs): + del cache + self.calls.append(kwargs) + assert kwargs["return_hidden"] is True + assert kwargs["skip_logits"] is True + hidden = mx.ones((input_ids.shape[0], input_ids.shape[1], 3)) + return SimpleNamespace(hidden_states=[hidden]) + + def speculative_argmax_from_hidden(self, hidden): + return mx.full((hidden.shape[0], hidden.shape[1]), 7, dtype=mx.int32) + + model = FastArgmaxModel() + batch = GenerationBatch( + model=model, + uids=[0, 1], + inputs=mx.array([5, 6], dtype=mx.int32), + prompt_cache=[], + sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), + stop_criteria=lambda token: False, + max_tokens=[2, 2], + greedy_sampling=True, + ) + batch.compute_logprobs = False + + first = batch.next() + assert [r.token for r in first] == [5, 6] + assert batch._next_tokens.tolist() == [7, 7] + assert model.calls == [{"return_hidden": True, "skip_logits": True}] + def test_remove_from_unprocessed(self, mock_model, mock_processor): gen = BatchGenerator( model=mock_model.language_model, diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py index 3bd07a841..70c4d926d 100644 --- a/mlx_vlm/tests/test_server.py +++ b/mlx_vlm/tests/test_server.py @@ -2061,6 +2061,154 @@ def fake_initialize_model(): (str(uid * 10 + 1), "length"), ] + def test_run_routes_mtp_through_batch_generator(self, monkeypatch): + batch_state = {} + draft_model = object() + + class FakeDetokenizer: + def __init__(self): + self.last_segment = "" + + def reset(self): + self.last_segment = "" + + def add_token(self, token): + self.last_segment = str(token) + + def finalize(self): + pass + + class FakeBatchGenerator: + def __init__(self, *args, **kwargs): + del args + batch_state["kwargs"] = kwargs + self._next_uid = 1 + self._active = {} + self.next_active_sizes = [] + batch_state["instance"] = self + + def insert(self, *args, **kwargs): + del args, kwargs + uid = self._next_uid + self._next_uid += 1 + self._active[uid] = True + return (uid,) + + def remove(self, uid): + return self._active.pop(uid, None) is not None + + @property + def unprocessed_prompts(self): + return [] + + @property + def has_pending_prompts(self): + return False + + def next(self, **kwargs): + del kwargs + self.next_active_sizes.append(len(self._active)) + responses = [ + SimpleNamespace( + uid=uid, + token=uid + 100, + token_logprob=0.0, + finish_reason="length", + ) + for uid in sorted(self._active) + ] + self._active.clear() + return [], responses + + monkeypatch.setattr(server_generation, "BatchGenerator", FakeBatchGenerator) + monkeypatch.setattr( + server_generation, + "_get_draft_block_size_from_env", + lambda: 6, + ) + monkeypatch.setattr( + server_generation, + "make_streaming_detokenizer", + lambda _: FakeDetokenizer(), + ) + + gen = server.ResponseGenerator.__new__(server.ResponseGenerator) + gen.model_path = "demo" + gen.adapter_path = None + gen.model = None + gen.processor = None + gen.config = None + gen.stop_tokens = set() + gen.vision_cache = None + gen.draft_model = None + gen.draft_kind = None + gen.kv_bits = None + gen.kv_group_size = server.DEFAULT_KV_GROUP_SIZE + gen.kv_quant_scheme = server.DEFAULT_KV_QUANT_SCHEME + gen.quantized_kv_start = server.DEFAULT_QUANTIZED_KV_START + gen.top_logprobs_k = 0 + gen.apc_manager = None + gen.tokenizer = SimpleNamespace() + gen.requests = Queue() + gen._stop = False + gen._ready = Event() + gen._load_error = None + gen._cancelled = set() + gen._cancel_lock = Lock() + + def fake_initialize_model(): + gen.model = SimpleNamespace(language_model=object()) + gen.processor = SimpleNamespace() + gen.config = SimpleNamespace() + gen.stop_tokens = set() + gen.draft_model = draft_model + gen.draft_kind = "mtp" + gen.tokenizer = SimpleNamespace() + + gen._initialize_model = fake_initialize_model + gen._run_speculative = lambda: pytest.fail("MTP should use BatchGenerator") + gen._gpu_embed = lambda raw_inputs, images=None: ( + mx.array([[raw_inputs["request_id"]]], dtype=mx.int32), + {}, + ) + + request_queues = [] + for request_id in range(2): + rqueue = Queue() + request_queues.append(rqueue) + gen.requests.put( + ( + rqueue, + {"request_id": request_id}, + 1, + server.GenerationArguments(max_tokens=1, temperature=0), + None, + ) + ) + + worker = Thread(target=gen._run, daemon=True) + worker.start() + + try: + for rqueue in request_queues: + ctx = rqueue.get(timeout=1) + assert isinstance(ctx, server.GenerationContext) + item = rqueue.get(timeout=1) + assert item.finish_reason == "length" + assert rqueue.get(timeout=1) is None + finally: + gen._stop = True + gen.requests.put(None) + worker.join(timeout=2) + + kwargs = batch_state["kwargs"] + assert kwargs["draft_model"] is draft_model + assert kwargs["draft_kind"] == "mtp" + assert kwargs["draft_block_size"] == 6 + assert kwargs["greedy_sampling"] is True + assert kwargs["compute_logprobs"] is False + assert batch_state["instance"].next_active_sizes == [2] + def test_idle_batch_generator_is_recreated_for_new_sampler(self, monkeypatch): created = [] next_uid = [1] From e4471e0551a03e50b4f558ec6f6a5f9f0b251c56 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 24 May 2026 05:00:07 +0200 Subject: [PATCH 16/26] Speed up quantized Qwen batch decode --- mlx_vlm/generate.py | 145 ++++++++++++++++++++++++++++- mlx_vlm/models/qwen3_5/language.py | 7 ++ mlx_vlm/tests/test_generate.py | 47 ++++++++++ 3 files changed, 194 insertions(+), 5 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index b0a266b57..4a3236444 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -63,6 +63,20 @@ DEFAULT_THINKING_END_TOKEN = "" DEFAULT_QUANTIZED_KV_START = 5000 DEFAULT_PREFILL_STEP_SIZE = 2048 +DEFAULT_BATCH_CACHE_EVAL_INTERVAL = 50 + + +def _get_batch_cache_eval_interval() -> int: + raw = os.environ.get("MLX_VLM_BATCH_CACHE_EVAL_INTERVAL") + if raw is None: + return DEFAULT_BATCH_CACHE_EVAL_INTERVAL + try: + return max(0, int(raw)) + except ValueError: + logger.warning( + "Ignoring invalid MLX_VLM_BATCH_CACHE_EVAL_INTERVAL=%r", raw + ) + return DEFAULT_BATCH_CACHE_EVAL_INTERVAL def _position_seed(seed: int, row_id: int, position: int) -> int: @@ -1559,6 +1573,22 @@ def kv_rows(c): return cache_a +def _split_cache_rows(prompt_cache: List[Any], batch_size: int): + if batch_size == 0: + return [] + if batch_size == 1 and prompt_cache: + return [prompt_cache] + + rows = [[] for _ in range(batch_size)] + for cache_entry in prompt_cache: + extract = getattr(cache_entry, "extract", None) + if not callable(extract): + return None + for row in range(batch_size): + rows[row].append(extract(row)) + return rows + + def _make_cache( model, left_padding, @@ -1757,10 +1787,53 @@ def __init__( # Per-sequence MRoPE delta self._rope_deltas = None + self._row_prompt_caches = None def __len__(self): return len(self.uids) + def _prefer_rowwise_batch_decode(self) -> bool: + prefer = getattr(self._language_model, "prefer_rowwise_batch_decode", False) + return bool(prefer() if callable(prefer) else prefer) + + def _can_use_rowwise_greedy_decode(self) -> bool: + return ( + (len(self.uids) > 1 or self._row_prompt_caches is not None) + and self.greedy_sampling + and not self.compute_logprobs + and self.top_logprobs_k == 0 + and not (self.logits_processors and any(self.logits_processors)) + and callable( + getattr(self._language_model, "speculative_argmax_from_hidden", None) + ) + and self._prefer_rowwise_batch_decode() + ) + + def _cache_rows(self): + if self._row_prompt_caches is not None: + return self._row_prompt_caches + return _split_cache_rows(self.prompt_cache, len(self.uids)) + + def _ensure_rowwise_caches(self) -> bool: + if self._row_prompt_caches is not None: + return True + rows = _split_cache_rows(self.prompt_cache, len(self.uids)) + if rows is None: + return False + self._row_prompt_caches = rows + self.prompt_cache = [] + return True + + def cache_states(self): + if self._row_prompt_caches is not None: + return [ + c.state + for row_cache in self._row_prompt_caches + for c in row_cache + if hasattr(c, "state") + ] + return [c.state for c in self.prompt_cache if hasattr(c, "state")] + def _greedy_argmax_step(self, inputs: mx.array, fwd_kwargs: dict): if ( not self.greedy_sampling @@ -1791,6 +1864,38 @@ def _greedy_argmax_step(self, inputs: mx.array, fwd_kwargs: dict): sampled = sampled[:, 0] return sampled + def _step_rowwise_greedy(self, inputs: mx.array, fwd_kwargs: dict): + if not self._ensure_rowwise_caches(): + return None + + argmax_from_hidden = self._language_model.speculative_argmax_from_hidden + next_tokens = [] + for row, row_cache in enumerate(self._row_prompt_caches): + row_kwargs = dict(fwd_kwargs) + if "rope_deltas" in row_kwargs: + row_kwargs["rope_deltas"] = row_kwargs["rope_deltas"][row : row + 1] + output = self._language_model( + inputs[row : row + 1, None], + cache=row_cache, + return_hidden=True, + skip_logits=True, + **row_kwargs, + ) + sampled = argmax_from_hidden(output.hidden_states[-1]) + if sampled is None: + return None + if sampled.ndim == 2 and sampled.shape[1] == 1: + sampled = sampled[:, 0] + next_tokens.append(sampled) + + self._next_tokens = mx.concatenate(next_tokens, axis=0) + self._next_lps = None + self._next_top_idx = None + self._next_top_lp = None + mx.async_eval(self._next_tokens) + mx.eval(inputs) + return inputs.tolist(), None, None, None + def _step(self): """Perform one generation step with double buffering.""" self._current_tokens = self._next_tokens @@ -1801,6 +1906,11 @@ def _step(self): if self._rope_deltas is not None: fwd_kwargs["rope_deltas"] = self._rope_deltas + if self._can_use_rowwise_greedy_decode(): + rowwise = self._step_rowwise_greedy(inputs, fwd_kwargs) + if rowwise is not None: + return rowwise + sampled = self._greedy_argmax_step(inputs, fwd_kwargs) if sampled is not None: self._next_tokens = sampled @@ -1895,7 +2005,17 @@ def extend(self, other: "GenerationBatch"): """Extend this batch with another generation batch.""" self_was_empty = len(self.uids) == 0 self.uids.extend(other.uids) - self.prompt_cache = _extend_cache(self.prompt_cache, other.prompt_cache) + if self._row_prompt_caches is not None or other._row_prompt_caches is not None: + self_rows = self._cache_rows() + other_rows = other._cache_rows() + if self_rows is not None and other_rows is not None: + self._row_prompt_caches = self_rows + other_rows + self.prompt_cache = [] + else: + self.prompt_cache = _extend_cache(self.prompt_cache, other.prompt_cache) + self._row_prompt_caches = None + else: + self.prompt_cache = _extend_cache(self.prompt_cache, other.prompt_cache) self.max_tokens.extend(other.max_tokens) self._num_tokens.extend(other._num_tokens) self.token_context.extend(other.token_context) @@ -1960,6 +2080,7 @@ def filter(self, keep: List[int]): if not keep: self.prompt_cache.clear() + self._row_prompt_caches = None self._current_tokens = None self._current_lps = None self._next_tokens = None @@ -1971,8 +2092,13 @@ def filter(self, keep: List[int]): self.logits_processors = [] else: keep_arr = mx.array(keep, mx.int32) - for c in self.prompt_cache: - c.filter(keep_arr) + if self._row_prompt_caches is not None: + self._row_prompt_caches = [ + self._row_prompt_caches[idx] for idx in keep + ] + else: + for c in self.prompt_cache: + c.filter(keep_arr) if self._next_tokens is not None: self._next_tokens = self._next_tokens[keep_arr] if self._next_lps is not None: @@ -2056,6 +2182,7 @@ def empty( batch._next_top_idx = None batch._next_top_lp = None batch._rope_deltas = None + batch._row_prompt_caches = None return batch @@ -2863,6 +2990,7 @@ def __init__( self._prompt_time_counter = 0 self._gen_tokens_counter = 0 self._steps_counter = 0 + self._cache_eval_interval = _get_batch_cache_eval_interval() self._wire_stack = contextlib.ExitStack() self._wire_stack.enter_context(wired_limit(model, [self._stream])) @@ -3305,8 +3433,15 @@ def _next(self, **kwargs): generation_responses = self._generation_batch.next() self._gen_tokens_counter += len(generation_responses) self._steps_counter += 1 - if self._steps_counter % 50 == 0: - mx.eval([c.state for c in self._generation_batch.prompt_cache]) + if ( + self._cache_eval_interval > 0 + and self._steps_counter % self._cache_eval_interval == 0 + ): + cache_states = getattr(self._generation_batch, "cache_states", None) + if callable(cache_states): + mx.eval(cache_states()) + else: + mx.eval([c.state for c in self._generation_batch.prompt_cache]) mx.clear_cache() if getattr(self._generation_batch, "is_speculative", False) and len( diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index cb0907b87..eb5150e37 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -1913,6 +1913,13 @@ def speculative_argmax_from_hidden(self, hidden: mx.array) -> Optional[mx.array] logits = self.speculative_logits_from_hidden(hidden) return mx.argmax(logits, axis=-1) + @property + def prefer_rowwise_batch_decode(self) -> bool: + return ( + not self.args.tie_word_embeddings + and isinstance(getattr(self, "lm_head", None), nn.QuantizedLinear) + ) + def speculative_verify_logits(self, inputs: mx.array, cache, sampler): out = self( inputs, diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index 3b6aa2824..ad39486ad 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -610,6 +610,53 @@ def speculative_argmax_from_hidden(self, hidden): assert batch._next_tokens.tolist() == [7, 7] assert model.calls == [{"return_hidden": True, "skip_logits": True}] + def test_generation_batch_rowwise_decode_survives_filter_to_singleton(self): + class RowCache: + def __init__(self, token): + self.token = token + + class BatchCache: + def extract(self, idx): + return RowCache(10 + idx) + + class RowwiseModel: + prefer_rowwise_batch_decode = True + + def __init__(self): + self.seen = [] + + def __call__(self, input_ids, cache=None, **kwargs): + self.seen.append((input_ids.tolist(), cache[0].token, kwargs)) + hidden = mx.array([[[cache[0].token]]], dtype=mx.float32) + return SimpleNamespace(hidden_states=[hidden]) + + def speculative_argmax_from_hidden(self, hidden): + return hidden.astype(mx.int32)[:, :, 0] + + model = RowwiseModel() + batch = GenerationBatch( + model=model, + uids=[0, 1], + inputs=mx.array([5, 6], dtype=mx.int32), + prompt_cache=[BatchCache()], + sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), + stop_criteria=lambda token: False, + max_tokens=[3, 3], + greedy_sampling=True, + ) + batch.compute_logprobs = False + + first = batch.next() + assert [r.token for r in first] == [5, 6] + assert batch._next_tokens.tolist() == [10, 11] + assert [call[1] for call in model.seen] == [10, 11] + + batch.filter([1]) + second = batch.next() + assert [r.token for r in second] == [11] + assert batch._next_tokens.tolist() == [11] + assert model.seen[-1][1] == 11 + def test_remove_from_unprocessed(self, mock_model, mock_processor): gen = BatchGenerator( model=mock_model.language_model, From 7d62403eb1aa20a5205227e7ab680045b5cc298d Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 24 May 2026 11:29:07 +0200 Subject: [PATCH 17/26] Use true batched Qwen3.5 server decode --- mlx_vlm/generate.py | 117 +---------------------------- mlx_vlm/models/qwen3_5/language.py | 62 +++++++++++---- mlx_vlm/tests/test_generate.py | 47 ------------ 3 files changed, 49 insertions(+), 177 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 4a3236444..5aad78229 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -1573,22 +1573,6 @@ def kv_rows(c): return cache_a -def _split_cache_rows(prompt_cache: List[Any], batch_size: int): - if batch_size == 0: - return [] - if batch_size == 1 and prompt_cache: - return [prompt_cache] - - rows = [[] for _ in range(batch_size)] - for cache_entry in prompt_cache: - extract = getattr(cache_entry, "extract", None) - if not callable(extract): - return None - for row in range(batch_size): - rows[row].append(extract(row)) - return rows - - def _make_cache( model, left_padding, @@ -1785,53 +1769,12 @@ def __init__( self._next_top_idx = None self._next_top_lp = None - # Per-sequence MRoPE delta self._rope_deltas = None - self._row_prompt_caches = None def __len__(self): return len(self.uids) - def _prefer_rowwise_batch_decode(self) -> bool: - prefer = getattr(self._language_model, "prefer_rowwise_batch_decode", False) - return bool(prefer() if callable(prefer) else prefer) - - def _can_use_rowwise_greedy_decode(self) -> bool: - return ( - (len(self.uids) > 1 or self._row_prompt_caches is not None) - and self.greedy_sampling - and not self.compute_logprobs - and self.top_logprobs_k == 0 - and not (self.logits_processors and any(self.logits_processors)) - and callable( - getattr(self._language_model, "speculative_argmax_from_hidden", None) - ) - and self._prefer_rowwise_batch_decode() - ) - - def _cache_rows(self): - if self._row_prompt_caches is not None: - return self._row_prompt_caches - return _split_cache_rows(self.prompt_cache, len(self.uids)) - - def _ensure_rowwise_caches(self) -> bool: - if self._row_prompt_caches is not None: - return True - rows = _split_cache_rows(self.prompt_cache, len(self.uids)) - if rows is None: - return False - self._row_prompt_caches = rows - self.prompt_cache = [] - return True - def cache_states(self): - if self._row_prompt_caches is not None: - return [ - c.state - for row_cache in self._row_prompt_caches - for c in row_cache - if hasattr(c, "state") - ] return [c.state for c in self.prompt_cache if hasattr(c, "state")] def _greedy_argmax_step(self, inputs: mx.array, fwd_kwargs: dict): @@ -1864,38 +1807,6 @@ def _greedy_argmax_step(self, inputs: mx.array, fwd_kwargs: dict): sampled = sampled[:, 0] return sampled - def _step_rowwise_greedy(self, inputs: mx.array, fwd_kwargs: dict): - if not self._ensure_rowwise_caches(): - return None - - argmax_from_hidden = self._language_model.speculative_argmax_from_hidden - next_tokens = [] - for row, row_cache in enumerate(self._row_prompt_caches): - row_kwargs = dict(fwd_kwargs) - if "rope_deltas" in row_kwargs: - row_kwargs["rope_deltas"] = row_kwargs["rope_deltas"][row : row + 1] - output = self._language_model( - inputs[row : row + 1, None], - cache=row_cache, - return_hidden=True, - skip_logits=True, - **row_kwargs, - ) - sampled = argmax_from_hidden(output.hidden_states[-1]) - if sampled is None: - return None - if sampled.ndim == 2 and sampled.shape[1] == 1: - sampled = sampled[:, 0] - next_tokens.append(sampled) - - self._next_tokens = mx.concatenate(next_tokens, axis=0) - self._next_lps = None - self._next_top_idx = None - self._next_top_lp = None - mx.async_eval(self._next_tokens) - mx.eval(inputs) - return inputs.tolist(), None, None, None - def _step(self): """Perform one generation step with double buffering.""" self._current_tokens = self._next_tokens @@ -1906,11 +1817,6 @@ def _step(self): if self._rope_deltas is not None: fwd_kwargs["rope_deltas"] = self._rope_deltas - if self._can_use_rowwise_greedy_decode(): - rowwise = self._step_rowwise_greedy(inputs, fwd_kwargs) - if rowwise is not None: - return rowwise - sampled = self._greedy_argmax_step(inputs, fwd_kwargs) if sampled is not None: self._next_tokens = sampled @@ -2005,17 +1911,7 @@ def extend(self, other: "GenerationBatch"): """Extend this batch with another generation batch.""" self_was_empty = len(self.uids) == 0 self.uids.extend(other.uids) - if self._row_prompt_caches is not None or other._row_prompt_caches is not None: - self_rows = self._cache_rows() - other_rows = other._cache_rows() - if self_rows is not None and other_rows is not None: - self._row_prompt_caches = self_rows + other_rows - self.prompt_cache = [] - else: - self.prompt_cache = _extend_cache(self.prompt_cache, other.prompt_cache) - self._row_prompt_caches = None - else: - self.prompt_cache = _extend_cache(self.prompt_cache, other.prompt_cache) + self.prompt_cache = _extend_cache(self.prompt_cache, other.prompt_cache) self.max_tokens.extend(other.max_tokens) self._num_tokens.extend(other._num_tokens) self.token_context.extend(other.token_context) @@ -2080,7 +1976,6 @@ def filter(self, keep: List[int]): if not keep: self.prompt_cache.clear() - self._row_prompt_caches = None self._current_tokens = None self._current_lps = None self._next_tokens = None @@ -2092,13 +1987,8 @@ def filter(self, keep: List[int]): self.logits_processors = [] else: keep_arr = mx.array(keep, mx.int32) - if self._row_prompt_caches is not None: - self._row_prompt_caches = [ - self._row_prompt_caches[idx] for idx in keep - ] - else: - for c in self.prompt_cache: - c.filter(keep_arr) + for c in self.prompt_cache: + c.filter(keep_arr) if self._next_tokens is not None: self._next_tokens = self._next_tokens[keep_arr] if self._next_lps is not None: @@ -2182,7 +2072,6 @@ def empty( batch._next_top_idx = None batch._next_top_lp = None batch._rope_deltas = None - batch._row_prompt_caches = None return batch diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index eb5150e37..9d008df23 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -710,6 +710,32 @@ def _create_qwen3_5_ssm_mask(h: mx.array, cache): return cache.make_mask(h.shape[1]) +def _create_qwen3_5_attention_mask(h: mx.array, cache): + if cache is None: + return create_attention_mask(h, cache) + + if hasattr(cache, "_qwen3_5_decode_left_padding"): + delattr(cache, "_qwen3_5_decode_left_padding") + + left_padding = getattr(cache, "left_padding", None) + if ( + h.shape[1] == 1 + and isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + ): + padding_cache = getattr(cache, "_qwen3_5_left_padding_cache", None) + if padding_cache is None or padding_cache[0] is not left_padding: + pads = [int(p) for p in left_padding.tolist()] + padding_cache = (left_padding, pads, max(pads)) + cache._qwen3_5_left_padding_cache = padding_cache + pads = padding_cache[1] + if padding_cache[2] <= 0: + return None + cache._qwen3_5_decode_left_padding = pads + return "left_padded_decode" + return create_attention_mask(h, cache) + + def _gated_delta_update_verify_decode( q: mx.array, k: mx.array, @@ -748,16 +774,20 @@ def _target_verify_left_padded_attention( if hasattr(cache, "bits") or queries.ndim != 4 or keys.ndim != 4: return None - left_padding = getattr(cache, "left_padding", None) - if not ( - isinstance(left_padding, mx.array) - and left_padding.ndim > 0 - and int(left_padding.max().item()) > 0 - ): + pads = getattr(cache, "_qwen3_5_decode_left_padding", None) + if pads is None: + left_padding = getattr(cache, "left_padding", None) + if not ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and int(left_padding.max().item()) > 0 + ): + return None + pads = [int(p) for p in left_padding.tolist()] + if max(pads) <= 0: return None row_outputs = {} - pads = [int(p) for p in left_padding.tolist()] for pad in sorted(set(pads)): rows = [i for i, row_pad in enumerate(pads) if row_pad == pad] row_idx = mx.array(rows, dtype=mx.int32) @@ -889,7 +919,13 @@ def __call__( if cache is not None: keys, values = cache.update_and_fetch(keys, values) - if target_verify and L > 1: + left_padded_decode = ( + mask == "left_padded_decode" if isinstance(mask, str) else False + ) + if left_padded_decode: + mask = None + + if (target_verify and L > 1) or left_padded_decode: output = _target_verify_left_padded_attention( queries, keys, values, cache=cache, scale=self.scale, mask=mask ) @@ -1215,6 +1251,7 @@ def __call__( if ( h.shape[0] > 1 + and h.shape[1] > 1 and hidden_sink is None and gdn_sink is None and fa_cache is not None @@ -1268,7 +1305,7 @@ def __call__( cache[i] = cache[i].__class__.merge(entries) return mx.concatenate(row_outputs, axis=0) - fa_mask = create_attention_mask(h, cache[self.fa_idx]) + fa_mask = _create_qwen3_5_attention_mask(h, cache[self.fa_idx]) ssm_mask = _create_qwen3_5_ssm_mask(h, cache[self.ssm_idx]) capture_set = set(capture_layer_ids) if capture_layer_ids else set() @@ -1913,13 +1950,6 @@ def speculative_argmax_from_hidden(self, hidden: mx.array) -> Optional[mx.array] logits = self.speculative_logits_from_hidden(hidden) return mx.argmax(logits, axis=-1) - @property - def prefer_rowwise_batch_decode(self) -> bool: - return ( - not self.args.tie_word_embeddings - and isinstance(getattr(self, "lm_head", None), nn.QuantizedLinear) - ) - def speculative_verify_logits(self, inputs: mx.array, cache, sampler): out = self( inputs, diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index ad39486ad..3b6aa2824 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -610,53 +610,6 @@ def speculative_argmax_from_hidden(self, hidden): assert batch._next_tokens.tolist() == [7, 7] assert model.calls == [{"return_hidden": True, "skip_logits": True}] - def test_generation_batch_rowwise_decode_survives_filter_to_singleton(self): - class RowCache: - def __init__(self, token): - self.token = token - - class BatchCache: - def extract(self, idx): - return RowCache(10 + idx) - - class RowwiseModel: - prefer_rowwise_batch_decode = True - - def __init__(self): - self.seen = [] - - def __call__(self, input_ids, cache=None, **kwargs): - self.seen.append((input_ids.tolist(), cache[0].token, kwargs)) - hidden = mx.array([[[cache[0].token]]], dtype=mx.float32) - return SimpleNamespace(hidden_states=[hidden]) - - def speculative_argmax_from_hidden(self, hidden): - return hidden.astype(mx.int32)[:, :, 0] - - model = RowwiseModel() - batch = GenerationBatch( - model=model, - uids=[0, 1], - inputs=mx.array([5, 6], dtype=mx.int32), - prompt_cache=[BatchCache()], - sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), - stop_criteria=lambda token: False, - max_tokens=[3, 3], - greedy_sampling=True, - ) - batch.compute_logprobs = False - - first = batch.next() - assert [r.token for r in first] == [5, 6] - assert batch._next_tokens.tolist() == [10, 11] - assert [call[1] for call in model.seen] == [10, 11] - - batch.filter([1]) - second = batch.next() - assert [r.token for r in second] == [11] - assert batch._next_tokens.tolist() == [11] - assert model.seen[-1][1] == 11 - def test_remove_from_unprocessed(self, mock_model, mock_processor): gen = BatchGenerator( model=mock_model.language_model, From 35fd30e43b96eb285804ebbe5239e35512a19c92 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 24 May 2026 21:43:30 +0200 Subject: [PATCH 18/26] Add exact ragged Qwen3.5 decode attention --- mlx_vlm/models/qwen3_5/language.py | 456 +++++++++++++++++++++++++++++ mlx_vlm/tests/test_speculative.py | 108 +++++++ 2 files changed, 564 insertions(+) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 9d008df23..402f9b0f8 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -562,6 +562,11 @@ def _target_verify_quantized_argmax(linear, x: mx.array) -> Optional[mx.array]: return None B, T, K = x.shape + if T == 1 and 1 < B <= 4: + out = _target_verify_quantized_argmax(linear, x.transpose(1, 0, 2)) + if out is not None: + return out.transpose(1, 0) + N = linear.weight.shape[0] num_tiles = N // 8 @@ -762,6 +767,453 @@ def _gated_delta_update_verify_decode( ) +_QWEN3_5_RAGGED_SDPA_ONE_PASS_SOURCE = r""" + uint q_batch_head_idx = threadgroup_position_in_grid.y; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int qk_per_thread = D_SIZE / BD; + constexpr int v_per_thread = V_SIZE / BD; + + typedef float U; + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U o[v_per_thread]; + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + int batch_idx = int(q_batch_head_idx) / NUM_Q_HEADS; + int q_head_idx = int(q_batch_head_idx) - batch_idx * NUM_Q_HEADS; + int kv_head_idx = q_head_idx / GQA_FACTOR; + int pad = int(pads[batch_idx]); + int N = K_SIZE - pad; + + const device T* qptr = + queries + int(q_batch_head_idx) * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* kptr = + keys + (batch_idx * NUM_KV_HEADS + kv_head_idx) * K_SIZE * D_SIZE + + (pad + int(simd_gid)) * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* vptr = + values + (batch_idx * NUM_KV_HEADS + kv_head_idx) * K_SIZE * V_SIZE + + (pad + int(simd_gid)) * V_SIZE + int(simd_lid) * v_per_thread; + device T* optr = + out + int(q_batch_head_idx) * V_SIZE + int(simd_gid) * v_per_thread; + + U s = U(scale[0]); + for (int i = 0; i < qk_per_thread; i++) { + q[i] = s * qptr[i]; + } + for (int i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = -3.4028234663852886e38f; + U sum_exp_score = 0; + + for (int i = int(simd_gid); i < N; i += BN) { + for (int j = 0; j < qk_per_thread; j++) { + k[j] = kptr[j]; + } + + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * vptr[j]; + } + + kptr += BN * D_SIZE; + vptr += BN * V_SIZE; + } + + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + for (int i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (simd_lid == 0) { + for (int i = 0; i < v_per_thread; i++) { + optr[i] = static_cast(o[i]); + } + } +""" + + +_QWEN3_5_RAGGED_SDPA_TWO_PASS_1_SOURCE = r""" + uint simd_lid = thread_index_in_simdgroup; + uint kv_head_idx = threadgroup_position_in_grid.x; + uint batch_idx = threadgroup_position_in_grid.y; + uint block_idx = threadgroup_position_in_grid.z; + uint gqa_idx = thread_position_in_threadgroup.y; + + constexpr int BD = 32; + constexpr int qk_per_thread = D_SIZE / BD; + constexpr int v_per_thread = V_SIZE / BD; + + typedef float U; + thread U q[qk_per_thread]; + thread U o[v_per_thread] = {0}; + + int q_head_idx = int(GQA_FACTOR * kv_head_idx + gqa_idx); + int q_batch_head_idx = int(batch_idx) * NUM_Q_HEADS + q_head_idx; + int pad = int(pads[batch_idx]); + int N = K_SIZE - pad; + + const device T* qptr = + queries + q_batch_head_idx * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* kptr = + keys + (int(batch_idx) * NUM_KV_HEADS + int(kv_head_idx)) * + K_SIZE * D_SIZE + + (pad + int(block_idx)) * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* vptr = + values + (int(batch_idx) * NUM_KV_HEADS + int(kv_head_idx)) * + K_SIZE * V_SIZE + + (pad + int(block_idx)) * V_SIZE + int(simd_lid) * v_per_thread; + device T* optr = + partials + q_batch_head_idx * BLOCKS * V_SIZE + + int(block_idx) * V_SIZE + int(simd_lid) * v_per_thread; + device float* sump = + sums + q_batch_head_idx * BLOCKS + int(block_idx); + device float* maxp = + maxs + q_batch_head_idx * BLOCKS + int(block_idx); + + U s = U(scale[0]); + for (int i = 0; i < qk_per_thread; i++) { + q[i] = s * qptr[i]; + } + + U max_score = -3.4028234663852886e38f; + U sum_exp_score = 0; + + for (int i = int(block_idx); i < N; i += BLOCKS) { + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * kptr[j]; + } + score = simd_sum(score); + + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * vptr[j]; + } + + kptr += BLOCKS * D_SIZE; + vptr += BLOCKS * V_SIZE; + } + + if (simd_lid == 0) { + sump[0] = sum_exp_score; + maxp[0] = max_score; + } + for (int i = 0; i < v_per_thread; i++) { + optr[i] = static_cast(o[i]); + } +""" + + +_QWEN3_5_RAGGED_SDPA_TWO_PASS_2_SOURCE = r""" + uint q_batch_head_idx = threadgroup_position_in_grid.y; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = D_SIZE / BD; + + typedef float U; + thread U o[elem_per_thread] = {0}; + threadgroup U outputs[BN * BD]; + + const device T* part = + partials + int(q_batch_head_idx) * BLOCKS * D_SIZE + + int(simd_gid) * D_SIZE + int(simd_lid) * elem_per_thread; + const device float* sump = sums + int(q_batch_head_idx) * BLOCKS; + const device float* maxp = maxs + int(q_batch_head_idx) * BLOCKS; + device T* optr = + out + int(q_batch_head_idx) * D_SIZE + int(simd_gid) * elem_per_thread; + + U sum_exp_score = 0.0; + U max_score = -3.4028234663852886e38f; + + for (int b = 0; b < BLOCKS / BN; ++b) { + max_score = max(max_score, maxp[int(simd_lid) + BN * b]); + } + max_score = simd_max(max_score); + + for (int b = 0; b < BLOCKS / BN; ++b) { + U factor = fast::exp(maxp[int(simd_lid) + BN * b] - max_score); + sum_exp_score += factor * sump[int(simd_lid) + BN * b]; + } + sum_exp_score = simd_sum(sum_exp_score); + + for (int b = 0; b < BLOCKS / BN; ++b) { + U factor = fast::exp(maxp[int(simd_gid)] - max_score); + for (int i = 0; i < elem_per_thread; i++) { + o[i] += factor * static_cast(part[i]); + } + maxp += BN; + sump += BN; + part += BN * D_SIZE; + } + + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + optr[i] = static_cast(o[i]); + } + } +""" + + +@lru_cache(maxsize=1) +def _qwen3_5_device_arch_suffix() -> str: + info = mx.device_info() if hasattr(mx, "device_info") else mx.metal.device_info() + return str(info.get("architecture", ""))[-1:] + + +def _qwen3_5_sdpa_vector_blocks(seq_len: int, gqa_factor: int) -> int: + devc = _qwen3_5_device_arch_suffix() + n_simds = gqa_factor + if devc == "s": + blocks = 64 + if seq_len > 1024 and n_simds > 4: + if seq_len <= 8192: + blocks = 128 + elif seq_len <= 32768: + blocks = 256 + elif seq_len <= 65536: + blocks = 512 + else: + blocks = 1024 + return blocks + if devc == "d": + blocks = 128 + if n_simds <= 2 and seq_len > 8192: + blocks = 256 + elif n_simds >= 6: + if 16384 <= seq_len < 65536: + blocks = 512 + elif seq_len >= 65536: + blocks = 1024 + return blocks + if n_simds >= 4: + return 64 + return 32 + + +def _qwen3_5_sdpa_vector_plan(seq_len: int, q_heads: int, kv_heads: int): + devc = _qwen3_5_device_arch_suffix() + if (devc in {"d", "s"} and seq_len >= 1024) or ( + kv_heads < q_heads and seq_len >= 4096 + ): + return ("two_pass", _qwen3_5_sdpa_vector_blocks(seq_len, q_heads // kv_heads)) + return ("one_pass", 0) + + +@lru_cache(maxsize=None) +def _qwen3_5_ragged_sdpa_one_pass_kernel(dtype, d_size, v_size, k_size): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=f"qwen3_5_ragged_sdpa_1p_{dtype_name}_d{d_size}_v{v_size}_k{k_size}", + input_names=["queries", "keys", "values", "pads", "scale"], + output_names=["out"], + header="#include \nusing namespace metal;\n", + source=_QWEN3_5_RAGGED_SDPA_ONE_PASS_SOURCE, + ) + + +@lru_cache(maxsize=None) +def _qwen3_5_ragged_sdpa_two_pass_1_kernel(dtype, d_size, v_size, k_size, blocks): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=( + f"qwen3_5_ragged_sdpa_2p1_{dtype_name}_" + f"d{d_size}_v{v_size}_k{k_size}_b{blocks}" + ), + input_names=["queries", "keys", "values", "pads", "scale"], + output_names=["partials", "sums", "maxs"], + header="#include \nusing namespace metal;\n", + source=_QWEN3_5_RAGGED_SDPA_TWO_PASS_1_SOURCE, + ) + + +@lru_cache(maxsize=None) +def _qwen3_5_ragged_sdpa_two_pass_2_kernel(dtype, v_size, blocks): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=f"qwen3_5_ragged_sdpa_2p2_{dtype_name}_v{v_size}_b{blocks}", + input_names=["partials", "sums", "maxs"], + output_names=["out"], + header="#include \nusing namespace metal;\n", + source=_QWEN3_5_RAGGED_SDPA_TWO_PASS_2_SOURCE, + ) + + +def _qwen3_5_ragged_decode_attention( + queries: mx.array, + keys: mx.array, + values: mx.array, + pads: List[int], + scale: float, +) -> Optional[mx.array]: + if ( + queries.ndim != 4 + or keys.ndim != 4 + or values.ndim != 4 + or queries.shape[2] != 1 + or queries.dtype not in (mx.bfloat16, mx.float16) + or keys.dtype != queries.dtype + or values.dtype != queries.dtype + ): + return None + + batch, q_heads, _, d_size = queries.shape + if len(pads) != batch or any(p < 0 for p in pads): + return None + kv_heads = keys.shape[1] + k_size = keys.shape[2] + v_size = values.shape[-1] + if ( + q_heads % kv_heads != 0 + or d_size != v_size + or d_size not in (64, 96, 128, 256) + or any(p >= k_size for p in pads) + ): + return None + + plans = [_qwen3_5_sdpa_vector_plan(k_size - pad, q_heads, kv_heads) for pad in pads] + if len(set(plans)) != 1: + return _qwen3_5_ragged_decode_attention_grouped( + queries, keys, values, pads, scale, plans + ) + mode, blocks = plans[0] + + queries = mx.contiguous(queries) + keys = mx.contiguous(keys) + values = mx.contiguous(values) + pads_array = mx.array(pads, dtype=mx.int32) + scale_array = mx.array([scale], dtype=mx.float32) + template = [ + ("T", queries.dtype), + ("D_SIZE", int(d_size)), + ("V_SIZE", int(v_size)), + ("K_SIZE", int(k_size)), + ("NUM_Q_HEADS", int(q_heads)), + ("NUM_KV_HEADS", int(kv_heads)), + ("GQA_FACTOR", int(q_heads // kv_heads)), + ] + + if mode == "one_pass": + kernel = _qwen3_5_ragged_sdpa_one_pass_kernel( + queries.dtype, d_size, v_size, k_size + ) + return kernel( + inputs=[queries, keys, values, pads_array, scale_array], + template=template, + grid=(1024, batch * q_heads, 1), + threadgroup=(1024, 1, 1), + output_shapes=[(batch, q_heads, 1, v_size)], + output_dtypes=[queries.dtype], + )[0] + + kernel_1 = _qwen3_5_ragged_sdpa_two_pass_1_kernel( + queries.dtype, d_size, v_size, k_size, blocks + ) + partials, sums, maxs = kernel_1( + inputs=[queries, keys, values, pads_array, scale_array], + template=[*template, ("BLOCKS", int(blocks))], + grid=(32 * kv_heads, (q_heads // kv_heads) * batch, blocks), + threadgroup=(32, q_heads // kv_heads, 1), + output_shapes=[ + (batch, q_heads, 1, blocks, v_size), + (batch, q_heads, 1, blocks), + (batch, q_heads, 1, blocks), + ], + output_dtypes=[queries.dtype, mx.float32, mx.float32], + ) + kernel_2 = _qwen3_5_ragged_sdpa_two_pass_2_kernel( + queries.dtype, v_size, blocks + ) + return kernel_2( + inputs=[partials, sums, maxs], + template=[("T", queries.dtype), ("D_SIZE", int(v_size)), ("BLOCKS", int(blocks))], + grid=(1024, batch * q_heads, 1), + threadgroup=(1024, 1, 1), + output_shapes=[(batch, q_heads, 1, v_size)], + output_dtypes=[queries.dtype], + )[0] + + +def _qwen3_5_ragged_decode_attention_grouped( + queries: mx.array, + keys: mx.array, + values: mx.array, + pads: List[int], + scale: float, + plans: List[tuple], +) -> Optional[mx.array]: + grouped_rows = {} + for row, plan in enumerate(plans): + grouped_rows.setdefault(plan, []).append(row) + + row_outputs = {} + for rows in grouped_rows.values(): + row_idx = mx.array(rows, dtype=mx.int32) + group_output = _qwen3_5_ragged_decode_attention( + mx.take(queries, row_idx, axis=0), + mx.take(keys, row_idx, axis=0), + mx.take(values, row_idx, axis=0), + [pads[row] for row in rows], + scale, + ) + if group_output is None: + return None + for local_row, row in enumerate(rows): + row_outputs[row] = group_output[local_row : local_row + 1] + + return mx.concatenate([row_outputs[row] for row in range(queries.shape[0])], axis=0) + + def _target_verify_left_padded_attention( queries: mx.array, keys: mx.array, @@ -787,6 +1239,10 @@ def _target_verify_left_padded_attention( if max(pads) <= 0: return None + output = _qwen3_5_ragged_decode_attention(queries, keys, values, pads, scale) + if output is not None: + return output + row_outputs = {} for pad in sorted(set(pads)): rows = [i for i, row_pad in enumerate(pads) if row_pad == pad] diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 75244e982..ef1739e98 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -510,6 +510,114 @@ def test_qwen_target_verify_quantized_argmax_matches_singleton_path(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen3_5_quantized_argmax_batch_as_time_matches_rowwise(): + mx.random.seed(18) + linear = nn.QuantizedLinear(512, 32, bias=False, group_size=64, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((4, 1, 512), dtype=mx.bfloat16) + + out = qwen_language._target_verify_quantized_argmax(linear, x) + ref = mx.concatenate( + [ + qwen_language._target_verify_quantized_argmax(linear, x[row : row + 1]) + for row in range(x.shape[0]) + ], + axis=0, + ) + mx.eval(out, ref) + + assert bool(mx.array_equal(out, ref).item()) + + +def _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale): + return mx.concatenate( + [ + qwen_language.scaled_dot_product_attention( + queries[row : row + 1], + keys[row : row + 1, :, pad:, :], + values[row : row + 1, :, pad:, :], + cache=None, + scale=scale, + mask=None, + ) + for row, pad in enumerate(pads) + ], + axis=0, + ) + + +def test_qwen3_5_ragged_decode_attention_matches_one_pass_singleton(): + mx.random.seed(19) + pads = [17, 0] + scale = 64**-0.5 + queries = mx.random.normal((2, 4, 1, 64), dtype=mx.bfloat16) + keys = mx.random.normal((2, 2, 256, 64), dtype=mx.bfloat16) + values = mx.random.normal((2, 2, 256, 64), dtype=mx.bfloat16) + + out = qwen_language._qwen3_5_ragged_decode_attention( + queries, keys, values, pads, scale + ) + ref = _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale) + mx.eval(out, ref) + + assert out is not None + assert bool(mx.array_equal(out, ref).item()) + + +def test_qwen3_5_ragged_decode_attention_matches_two_pass_singleton(): + mx.random.seed(20) + pads = [7, 0] + scale = 64**-0.5 + key_length = ( + 1100 + if qwen_language._qwen3_5_device_arch_suffix() in {"d", "s"} + else 4112 + ) + queries = mx.random.normal((2, 4, 1, 64), dtype=mx.bfloat16) + keys = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + values = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + + out = qwen_language._qwen3_5_ragged_decode_attention( + queries, keys, values, pads, scale + ) + ref = _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale) + mx.eval(out, ref) + + assert out is not None + assert bool(mx.array_equal(out, ref).item()) + + +def test_qwen3_5_ragged_decode_attention_matches_mixed_plan_singleton(): + mx.random.seed(21) + scale = 64**-0.5 + if qwen_language._qwen3_5_device_arch_suffix() in {"d", "s"}: + key_length = 1100 + pads = [101, 0] + else: + key_length = 4112 + pads = [33, 0] + queries = mx.random.normal((2, 4, 1, 64), dtype=mx.bfloat16) + keys = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + values = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + + plans = [ + qwen_language._qwen3_5_sdpa_vector_plan( + key_length - pad, queries.shape[1], keys.shape[1] + ) + for pad in pads + ] + out = qwen_language._qwen3_5_ragged_decode_attention( + queries, keys, values, pads, scale + ) + ref = _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale) + mx.eval(out, ref) + + assert len(set(plans)) == 2 + assert out is not None + assert bool(mx.array_equal(out, ref).item()) + + def test_qwen_target_verify_small_projection_matches_singleton_dense_gemv(): mx.random.seed(10) linear = nn.Linear(256, 8, bias=False) From 1a08841260be1470beca9d6b928c54605d9909f6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 24 May 2026 23:01:08 +0200 Subject: [PATCH 19/26] Avoid slow mixed ragged attention dispatch --- mlx_vlm/models/qwen3_5/language.py | 35 +----------------------------- mlx_vlm/tests/test_speculative.py | 7 ++---- 2 files changed, 3 insertions(+), 39 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 402f9b0f8..836a75800 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -1123,9 +1123,7 @@ def _qwen3_5_ragged_decode_attention( plans = [_qwen3_5_sdpa_vector_plan(k_size - pad, q_heads, kv_heads) for pad in pads] if len(set(plans)) != 1: - return _qwen3_5_ragged_decode_attention_grouped( - queries, keys, values, pads, scale, plans - ) + return None mode, blocks = plans[0] queries = mx.contiguous(queries) @@ -1183,37 +1181,6 @@ def _qwen3_5_ragged_decode_attention( output_dtypes=[queries.dtype], )[0] - -def _qwen3_5_ragged_decode_attention_grouped( - queries: mx.array, - keys: mx.array, - values: mx.array, - pads: List[int], - scale: float, - plans: List[tuple], -) -> Optional[mx.array]: - grouped_rows = {} - for row, plan in enumerate(plans): - grouped_rows.setdefault(plan, []).append(row) - - row_outputs = {} - for rows in grouped_rows.values(): - row_idx = mx.array(rows, dtype=mx.int32) - group_output = _qwen3_5_ragged_decode_attention( - mx.take(queries, row_idx, axis=0), - mx.take(keys, row_idx, axis=0), - mx.take(values, row_idx, axis=0), - [pads[row] for row in rows], - scale, - ) - if group_output is None: - return None - for local_row, row in enumerate(rows): - row_outputs[row] = group_output[local_row : local_row + 1] - - return mx.concatenate([row_outputs[row] for row in range(queries.shape[0])], axis=0) - - def _target_verify_left_padded_attention( queries: mx.array, keys: mx.array, diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index ef1739e98..8ac586751 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -588,7 +588,7 @@ def test_qwen3_5_ragged_decode_attention_matches_two_pass_singleton(): assert bool(mx.array_equal(out, ref).item()) -def test_qwen3_5_ragged_decode_attention_matches_mixed_plan_singleton(): +def test_qwen3_5_ragged_decode_attention_rejects_mixed_plan(): mx.random.seed(21) scale = 64**-0.5 if qwen_language._qwen3_5_device_arch_suffix() in {"d", "s"}: @@ -610,12 +610,9 @@ def test_qwen3_5_ragged_decode_attention_matches_mixed_plan_singleton(): out = qwen_language._qwen3_5_ragged_decode_attention( queries, keys, values, pads, scale ) - ref = _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale) - mx.eval(out, ref) assert len(set(plans)) == 2 - assert out is not None - assert bool(mx.array_equal(out, ref).item()) + assert out is None def test_qwen_target_verify_small_projection_matches_singleton_dense_gemv(): From ceb904935b4bb8ca34ad050110c275ce80a1159c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 25 May 2026 21:43:16 +0200 Subject: [PATCH 20/26] Improve Qwen3.5 batched decode scaling --- mlx_vlm/generate.py | 75 ++++---- mlx_vlm/models/qwen3_5/language.py | 265 +++++++++++++++++++++++++++-- mlx_vlm/server/generation.py | 13 +- mlx_vlm/speculative/mtp.py | 2 +- mlx_vlm/tests/test_generate.py | 36 ++++ mlx_vlm/tests/test_server.py | 34 ++++ mlx_vlm/tests/test_speculative.py | 60 +++++++ 7 files changed, 437 insertions(+), 48 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 5aad78229..d5b51e739 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -2005,7 +2005,6 @@ def next(self) -> List[Response]: return [] tokens, lp_list, top_idx_list, top_lp_list = self._step() - keep = [] responses = [] for i in range(len(self.uids)): @@ -2149,6 +2148,28 @@ def _finish_reason(self, row: int, token: int) -> Optional[str]: return "length" return None + def _append_token_responses( + self, + responses: List[Response], + tok_list: List[Optional[int]], + ) -> None: + for row, token in enumerate(tok_list): + if token is None or self._finished[row]: + continue + token = int(token) + self._num_tokens[row] += 1 + finish_reason = self._finish_reason(row, token) + if finish_reason is not None: + self._finished[row] = True + responses.append( + self.Response( + uid=self._all_uids[row], + token=token, + token_logprob=0.0, + finish_reason=finish_reason, + ) + ) + def _start_rounds(self): if self._rounds_iter is not None: return @@ -2208,7 +2229,7 @@ def next(self) -> List[Response]: self._start_rounds() try: - tok_list, _ = next(self._rounds_iter) + tok_list, round_meta = next(self._rounds_iter) except StopIteration: for row, done in enumerate(self._finished): if not done: @@ -2224,22 +2245,17 @@ def next(self) -> List[Response]: self._refresh_uids() return responses - for row, token in enumerate(tok_list): - if token is None or self._finished[row]: - continue - token = int(token) - self._num_tokens[row] += 1 - finish_reason = self._finish_reason(row, token) - if finish_reason is not None: - self._finished[row] = True - responses.append( - self.Response( - uid=self._all_uids[row], - token=token, - token_logprob=0.0, - finish_reason=finish_reason, - ) - ) + self._append_token_responses(responses, tok_list) + while ( + isinstance(round_meta, dict) + and int(round_meta.get("round_pos", 0)) + 1 + < int(round_meta.get("round_len", 1)) + ): + try: + tok_list, round_meta = next(self._rounds_iter) + except StopIteration: + break + self._append_token_responses(responses, tok_list) self._refresh_uids() return responses @@ -3154,10 +3170,10 @@ def _build_mixed_prompt_batch( right_pad_per_row=right_pad_per_row, suffix_lens=suffix_lens, apc_mode=apc_mode, - draft_model=self.draft_model, - draft_kind=self.draft_kind, - draft_block_size=self.draft_block_size, - greedy_sampling=self.greedy_sampling, + draft_model=getattr(self, "draft_model", None), + draft_kind=getattr(self, "draft_kind", None), + draft_block_size=getattr(self, "draft_block_size", None), + greedy_sampling=getattr(self, "greedy_sampling", False), ) def _build_apc_meta_for_cold( @@ -3440,10 +3456,10 @@ def _next(self, **kwargs): apc_meta=apc_meta, apc_manager=self.apc_manager, apc_mode=self.apc_mode, - draft_model=self.draft_model, - draft_kind=self.draft_kind, - draft_block_size=self.draft_block_size, - greedy_sampling=self.greedy_sampling, + draft_model=getattr(self, "draft_model", None), + draft_kind=getattr(self, "draft_kind", None), + draft_block_size=getattr(self, "draft_block_size", None), + greedy_sampling=getattr(self, "greedy_sampling", False), ) self._prompt_tokens_counter += self._prompt_batch.total_prompt_tokens @@ -3786,7 +3802,8 @@ def _generate_batch( def main(): args = parse_arguments() - mx.random.seed(args.seed) + seed = getattr(args, "seed", DEFAULT_SEED) + mx.random.seed(seed) if isinstance(args.image, str): args.image = [args.image] @@ -3889,7 +3906,7 @@ def main(): stream_kwargs = { "max_tokens": args.max_tokens, "temperature": args.temperature, - "seed": args.seed, + "seed": seed, "vision_cache": vision_cache, **kwargs, } @@ -3920,7 +3937,7 @@ def main(): "video": args.video, "fps": args.fps, "temperature": args.temperature, - "seed": args.seed, + "seed": seed, "max_tokens": args.max_tokens, "verbose": args.verbose, "max_kv_size": args.max_kv_size, diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 836a75800..c2aecc4d8 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -377,6 +377,70 @@ def _target_verify_qlinear_header(bits: int, group_size: int) -> str: """ +_DECODE_BATCH_QMV_SOURCE = r""" + uint n_tile = threadgroup_position_in_grid.y; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + int batch_idx = int(simd_gid) / NUM_SIMDGROUPS; + int out_simd = int(simd_gid) - batch_idx * NUM_SIMDGROUPS; + if (batch_idx >= BATCH_N) { + return; + } + + int out_row = int(n_tile) * BN + out_simd * RESULTS_PER_SIMDGROUP; + int in_vec_size_w = K_SIZE * BYTES_PER_PACK / PACK_FACTOR; + int in_vec_size_g = K_SIZE / GS; + + const device uint8_t* ws_base = + (const device uint8_t*)w + out_row * in_vec_size_w + + int(simd_lid) * PACKS_PER_THREAD * BYTES_PER_PACK; + const device T* scales_base = + scales + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* biases_base = + biases + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* x_base = + x + batch_idx * K_SIZE + int(simd_lid) * VALUES_PER_THREAD; + + float result[RESULTS_PER_SIMDGROUP]; + float x_thread[VALUES_PER_THREAD]; + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + result[row] = 0.0f; + } + + const device uint8_t* ws = ws_base; + const device T* sc = scales_base; + const device T* bs = biases_base; + const device T* xk = x_base; + + for (int k = 0; k < K_SIZE; k += BLOCK_SIZE) { + float sum = load_vector_exact(xk, x_thread); + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + const device uint8_t* wl = ws + row * in_vec_size_w; + const device T* sl = sc + row * in_vec_size_g; + const device T* bl = bs + row * in_vec_size_g; + result[row] += qdot_exact(wl, x_thread, float(sl[0]), float(bl[0]), sum); + } + + ws += BLOCK_SIZE * BYTES_PER_PACK / PACK_FACTOR; + sc += BLOCK_SIZE / GS; + bs += BLOCK_SIZE / GS; + xk += BLOCK_SIZE; + } + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + int n = out_row + row; + if (n < N_SIZE) { + float r = simd_sum(result[row]); + if (simd_lid == 0) { + y[(batch_idx * N_SIZE) + n] = T(r); + } + } + } +""" + + _TARGET_VERIFY_QARGMAX_SOURCE = r""" uint n_tile = threadgroup_position_in_grid.y; uint b_idx = threadgroup_position_in_grid.z; @@ -507,6 +571,21 @@ def _target_verify_qargmax_kernel(bits, group_size, dtype, verify_t, k_size, n_s ) +@lru_cache(maxsize=None) +def _decode_batch_qmv_kernel(bits, group_size, dtype, batch_n, k_size, n_size): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=( + "qwen3_5_decode_batch_qmv_" + f"b{bits}_gs{group_size}_bn{batch_n}_k{k_size}_n{n_size}_{dtype_name}" + ), + input_names=["x", "w", "scales", "biases"], + output_names=["y"], + header=_target_verify_qlinear_header(bits, group_size), + source=_DECODE_BATCH_QMV_SOURCE, + ) + + def _can_target_verify_quantized(linear, x: mx.array) -> bool: if ( not isinstance(linear, nn.QuantizedLinear) @@ -530,6 +609,39 @@ def _can_target_verify_quantized(linear, x: mx.array) -> bool: ) +def _decode_quantized_linear_batch(linear, x: mx.array) -> Optional[mx.array]: + if ( + not _can_target_verify_quantized(linear, x) + or x.shape[1] != 1 + or x.shape[0] != 4 + ): + return None + + B, _T, K = x.shape + N = linear.weight.shape[0] + + x = mx.contiguous(x) + kernel = _decode_batch_qmv_kernel( + linear.bits, linear.group_size, x.dtype, B, K, N + ) + out = kernel( + inputs=[x, linear.weight, linear.scales, linear.biases], + template=[ + ("T", x.dtype), + ("BATCH_N", int(B)), + ("K_SIZE", int(K)), + ("N_SIZE", int(N)), + ], + grid=(32, 2 * B * (N // 8), 1), + threadgroup=(32, 2 * B, 1), + output_shapes=[(B, 1, N)], + output_dtypes=[x.dtype], + )[0] + if "bias" in linear: + out = out + linear["bias"] + return out + + def _target_verify_quantized_linear(linear, x: mx.array) -> Optional[mx.array]: if not _can_target_verify_quantized(linear, x): return None @@ -557,6 +669,60 @@ def _target_verify_quantized_linear(linear, x: mx.array) -> Optional[mx.array]: return out +def _decode_quantized_linears_fused(linears, x: mx.array): + if ( + x.ndim != 3 + or x.shape[1] != 1 + or len(linears) != 4 + or not all(isinstance(linear, nn.QuantizedLinear) for linear in linears) + ): + return None + + first = linears[0] + if not all( + linear.bits == first.bits + and linear.group_size == first.group_size + and linear.mode == first.mode + and linear.biases is not None + and linear.scales.dtype == x.dtype + and linear.biases.dtype == x.dtype + and "bias" not in linear + for linear in linears + ): + return None + + cache_key = tuple( + (id(linear.weight), id(linear.scales), id(linear.biases)) + for linear in linears + ) + cached = getattr(first, "_qwen3_5_fused_decode_linears", None) + if cached is None or cached[0] != cache_key: + weights = mx.concatenate([linear.weight for linear in linears], axis=0) + scales = mx.concatenate([linear.scales for linear in linears], axis=0) + biases = mx.concatenate([linear.biases for linear in linears], axis=0) + split_indices = [] + offset = 0 + for linear in linears[:-1]: + offset += linear.weight.shape[0] + split_indices.append(offset) + mx.eval(weights, scales, biases) + cached = (cache_key, weights, scales, biases, split_indices) + first._qwen3_5_fused_decode_linears = cached + + _, weights, scales, biases, split_indices = cached + output = mx.quantized_matmul( + x, + weights, + scales=scales, + biases=biases, + transpose=True, + group_size=first.group_size, + bits=first.bits, + mode=first.mode, + ) + return tuple(mx.split(output, split_indices, axis=-1)) + + def _target_verify_quantized_argmax(linear, x: mx.array) -> Optional[mx.array]: if not _can_target_verify_quantized(linear, x) or "bias" in linear: return None @@ -612,6 +778,10 @@ def _target_verify_singletons(fn, x: mx.array) -> mx.array: def _target_verify_linear(linear, x: mx.array, target_verify: bool) -> mx.array: if not _use_target_verify_dense(linear, x, target_verify): + if isinstance(linear, nn.QuantizedLinear): + out = _decode_quantized_linear_batch(linear, x) + if out is not None: + return out return linear(x) if isinstance(linear, nn.QuantizedLinear): @@ -639,7 +809,18 @@ def _target_verify_linears(linears, x: mx.array, target_verify: bool): isinstance(linear, (nn.Linear, nn.QuantizedLinear)) for linear in linears ) ): - return tuple(linear(x) for linear in linears) + out = _decode_quantized_linears_fused(linears, x) + if out is not None: + return out + outputs = [] + for linear in linears: + if isinstance(linear, nn.QuantizedLinear): + out = _decode_quantized_linear_batch(linear, x) + if out is not None: + outputs.append(out) + continue + outputs.append(linear(x)) + return tuple(outputs) return tuple(_target_verify_linear(linear, x, target_verify) for linear in linears) @@ -704,11 +885,23 @@ def _create_qwen3_5_ssm_mask(h: mx.array, cache): if not (cache and hasattr(cache, "make_mask")): return None + lengths = getattr(cache, "lengths", None) left_padding = getattr(cache, "left_padding", None) - if isinstance(left_padding, mx.array) and int(left_padding.max().item()) <= 0: - return None + if isinstance(left_padding, mx.array): + batch_size = int(left_padding.shape[0]) if left_padding.ndim > 0 else 1 + if ( + lengths is None + and getattr(cache, "_qwen3_5_ssm_no_mask_batch_size", None) + == batch_size + ): + return None + if int(left_padding.max().item()) <= 0: + if lengths is None: + cache._qwen3_5_ssm_no_mask_batch_size = batch_size + return None + if hasattr(cache, "_qwen3_5_ssm_no_mask_batch_size"): + delattr(cache, "_qwen3_5_ssm_no_mask_batch_size") - lengths = getattr(cache, "lengths", None) if isinstance(lengths, mx.array) and int(lengths.min().item()) >= h.shape[1]: return None @@ -741,6 +934,19 @@ def _create_qwen3_5_attention_mask(h: mx.array, cache): return create_attention_mask(h, cache) +def _set_qwen3_5_decode_left_padding(caches, layers, pads): + if caches is None: + return + for layer, cache_entry in zip(layers, caches): + if layer.is_linear or cache_entry is None: + continue + if pads is None: + if hasattr(cache_entry, "_qwen3_5_decode_left_padding"): + delattr(cache_entry, "_qwen3_5_decode_left_padding") + else: + cache_entry._qwen3_5_decode_left_padding = pads + + def _gated_delta_update_verify_decode( q: mx.array, k: mx.array, @@ -785,6 +991,7 @@ def _gated_delta_update_verify_decode( threadgroup U max_scores[BN]; threadgroup U sum_exp_scores[BN]; + int K_SIZE = int(k_size[0]); int batch_idx = int(q_batch_head_idx) / NUM_Q_HEADS; int q_head_idx = int(q_batch_head_idx) - batch_idx * NUM_Q_HEADS; int kv_head_idx = q_head_idx / GQA_FACTOR; @@ -881,6 +1088,7 @@ def _gated_delta_update_verify_decode( thread U q[qk_per_thread]; thread U o[v_per_thread] = {0}; + int K_SIZE = int(k_size[0]); int q_head_idx = int(GQA_FACTOR * kv_head_idx + gqa_idx); int q_batch_head_idx = int(batch_idx) * NUM_Q_HEADS + q_head_idx; int pad = int(pads[batch_idx]); @@ -1051,11 +1259,11 @@ def _qwen3_5_sdpa_vector_plan(seq_len: int, q_heads: int, kv_heads: int): @lru_cache(maxsize=None) -def _qwen3_5_ragged_sdpa_one_pass_kernel(dtype, d_size, v_size, k_size): +def _qwen3_5_ragged_sdpa_one_pass_kernel(dtype, d_size, v_size): dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") return mx.fast.metal_kernel( - name=f"qwen3_5_ragged_sdpa_1p_{dtype_name}_d{d_size}_v{v_size}_k{k_size}", - input_names=["queries", "keys", "values", "pads", "scale"], + name=f"qwen3_5_ragged_sdpa_1p_{dtype_name}_d{d_size}_v{v_size}", + input_names=["queries", "keys", "values", "pads", "scale", "k_size"], output_names=["out"], header="#include \nusing namespace metal;\n", source=_QWEN3_5_RAGGED_SDPA_ONE_PASS_SOURCE, @@ -1063,14 +1271,14 @@ def _qwen3_5_ragged_sdpa_one_pass_kernel(dtype, d_size, v_size, k_size): @lru_cache(maxsize=None) -def _qwen3_5_ragged_sdpa_two_pass_1_kernel(dtype, d_size, v_size, k_size, blocks): +def _qwen3_5_ragged_sdpa_two_pass_1_kernel(dtype, d_size, v_size, blocks): dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") return mx.fast.metal_kernel( name=( f"qwen3_5_ragged_sdpa_2p1_{dtype_name}_" - f"d{d_size}_v{v_size}_k{k_size}_b{blocks}" + f"d{d_size}_v{v_size}_b{blocks}" ), - input_names=["queries", "keys", "values", "pads", "scale"], + input_names=["queries", "keys", "values", "pads", "scale", "k_size"], output_names=["partials", "sums", "maxs"], header="#include \nusing namespace metal;\n", source=_QWEN3_5_RAGGED_SDPA_TWO_PASS_1_SOURCE, @@ -1131,22 +1339,20 @@ def _qwen3_5_ragged_decode_attention( values = mx.contiguous(values) pads_array = mx.array(pads, dtype=mx.int32) scale_array = mx.array([scale], dtype=mx.float32) + k_size_array = mx.array([k_size], dtype=mx.int32) template = [ ("T", queries.dtype), ("D_SIZE", int(d_size)), ("V_SIZE", int(v_size)), - ("K_SIZE", int(k_size)), ("NUM_Q_HEADS", int(q_heads)), ("NUM_KV_HEADS", int(kv_heads)), ("GQA_FACTOR", int(q_heads // kv_heads)), ] if mode == "one_pass": - kernel = _qwen3_5_ragged_sdpa_one_pass_kernel( - queries.dtype, d_size, v_size, k_size - ) + kernel = _qwen3_5_ragged_sdpa_one_pass_kernel(queries.dtype, d_size, v_size) return kernel( - inputs=[queries, keys, values, pads_array, scale_array], + inputs=[queries, keys, values, pads_array, scale_array, k_size_array], template=template, grid=(1024, batch * q_heads, 1), threadgroup=(1024, 1, 1), @@ -1155,10 +1361,10 @@ def _qwen3_5_ragged_decode_attention( )[0] kernel_1 = _qwen3_5_ragged_sdpa_two_pass_1_kernel( - queries.dtype, d_size, v_size, k_size, blocks + queries.dtype, d_size, v_size, blocks ) partials, sums, maxs = kernel_1( - inputs=[queries, keys, values, pads_array, scale_array], + inputs=[queries, keys, values, pads_array, scale_array, k_size_array], template=[*template, ("BLOCKS", int(blocks))], grid=(32 * kv_heads, (q_heads // kv_heads) * batch, blocks), threadgroup=(32, q_heads // kv_heads, 1), @@ -1447,6 +1653,19 @@ def __init__(self, config: TextConfig): def _causal_conv1d_verify(self, conv_input: mx.array, steps: int) -> mx.array: return self.conv1d(conv_input) + def _causal_conv1d_decode(self, conv_input: mx.array) -> mx.array: + cached = getattr(self, "_qwen3_5_decode_conv_weight", None) + cache_key = id(self.conv1d.weight) + if cached is None or cached[0] != cache_key: + weight = self.conv1d.weight[:, :, 0].T.astype(mx.float32) + mx.eval(weight) + cached = (cache_key, weight) + self._qwen3_5_decode_conv_weight = cached + + weight = cached[1] + out = mx.sum(conv_input.astype(mx.float32) * weight[None, :, :], axis=1) + return out.astype(self.conv1d.weight.dtype)[:, None, :] + def __call__( self, inputs: mx.array, @@ -1489,6 +1708,12 @@ def __call__( cache[0] = conv_input[:, -(self.conv_kernel_size - 1) :] if gdn_sink is not None: conv_out = nn.silu(self._causal_conv1d_verify(conv_input, S)) + elif ( + S == 1 + and conv_input.shape[1] == self.conv_kernel_size + and self.conv1d.weight.dtype in (mx.bfloat16, mx.float16) + ): + conv_out = nn.silu(self._causal_conv1d_decode(conv_input)) else: conv_out = nn.silu(self.conv1d(conv_input)) @@ -1730,6 +1955,12 @@ def __call__( fa_mask = _create_qwen3_5_attention_mask(h, cache[self.fa_idx]) ssm_mask = _create_qwen3_5_ssm_mask(h, cache[self.ssm_idx]) + decode_left_padding = ( + getattr(cache[self.fa_idx], "_qwen3_5_decode_left_padding", None) + if isinstance(fa_mask, str) and fa_mask == "left_padded_decode" + else None + ) + _set_qwen3_5_decode_left_padding(cache, self.layers, decode_left_padding) capture_set = set(capture_layer_ids) if capture_layer_ids else set() for i, (layer, c) in enumerate(zip(self.layers, cache)): diff --git a/mlx_vlm/server/generation.py b/mlx_vlm/server/generation.py index 94a5ae8af..c9d263d05 100644 --- a/mlx_vlm/server/generation.py +++ b/mlx_vlm/server/generation.py @@ -880,8 +880,19 @@ def _run(self): try: # Poll the request queue — non-blocking when generating, short # blocking wait when idle so we don't spin. + active_batch = bool(active) + coalesce_s = ( + get_speculative_batch_coalesce_s() + if ( + not active_batch + and self.draft_model is not None + and self.draft_kind == "mtp" + ) + else 0.0 + ) new_items, should_stop = self._collect_pending_requests( - active=bool(active) + active=active_batch, + coalesce_s=coalesce_s, ) if should_stop: break diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index 2c9fe7d56..adad88e4e 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -1018,7 +1018,7 @@ def _mtp_rounds_batch( finished[orig] = True if stop_check is not None and stop_check(orig, tok): finished[orig] = True - yield tokens_out, None + yield tokens_out, {"round_pos": pos, "round_len": max_new} # Update bonus tokens and per-row positions for j in range(n_active): diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index 3b6aa2824..d29642520 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -17,6 +17,7 @@ BatchStats, GenerationBatch, GenerationResult, + SpeculativeGenerationBatch, _left_pad_prompts, _prime_cached_prefix_rope_state, normalize_resize_shape, @@ -610,6 +611,41 @@ def speculative_argmax_from_hidden(self, hidden): assert batch._next_tokens.tolist() == [7, 7] assert model.calls == [{"return_hidden": True, "skip_logits": True}] + def test_speculative_generation_batch_drains_full_round(self, monkeypatch): + def fake_rounds(*args, **kwargs): + del args, kwargs + yield [1, 10], {"round_pos": 0, "round_len": 2} + yield [2, 11], {"round_pos": 1, "round_len": 2} + yield [3, 12], {"round_pos": 0, "round_len": 1} + + monkeypatch.setattr(generate_module, "run_speculative_server_rounds", fake_rounds) + + batch = SpeculativeGenerationBatch( + model=SimpleNamespace(), + draft_model=SimpleNamespace(), + draft_kind="mtp", + uids=[100, 200], + first_tokens=mx.array([0, 9], dtype=mx.int32), + prompt_cache=[], + sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), + stop_criteria=lambda token: False, + max_tokens=[10, 10], + hidden=mx.zeros((2, 1, 1)), + shared_kv_states=None, + prompt_tokens=mx.array([[0], [9]], dtype=mx.int32), + ) + + first = batch.next() + assert [(r.uid, r.token) for r in first] == [(100, 0), (200, 9)] + + second = batch.next() + assert [(r.uid, r.token) for r in second] == [ + (100, 1), + (200, 10), + (100, 2), + (200, 11), + ] + def test_remove_from_unprocessed(self, mock_model, mock_processor): gen = BatchGenerator( model=mock_model.language_model, diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py index 70c4d926d..359d02d45 100644 --- a/mlx_vlm/tests/test_server.py +++ b/mlx_vlm/tests/test_server.py @@ -2209,6 +2209,40 @@ def fake_initialize_model(): assert kwargs["compute_logprobs"] is False assert batch_state["instance"].next_active_sizes == [2] + def test_run_coalesces_idle_mtp_batch_generator(self, monkeypatch): + monkeypatch.setenv("MLX_VLM_SPEC_BATCH_COALESCE_MS", "37") + calls = [] + draft_model = object() + + gen = server.ResponseGenerator.__new__(server.ResponseGenerator) + gen.draft_model = None + gen.draft_kind = None + gen._stop = False + gen._ready = Event() + gen._load_error = None + + def fake_initialize_model(): + gen.model = SimpleNamespace(language_model=object()) + gen.processor = SimpleNamespace() + gen.config = SimpleNamespace() + gen.stop_tokens = set() + gen.draft_model = draft_model + gen.draft_kind = "mtp" + gen.tokenizer = SimpleNamespace() + + def fake_collect_pending_requests(*, active, idle_timeout=0.1, coalesce_s=0.0): + del idle_timeout + calls.append((active, coalesce_s)) + return [], True + + gen._initialize_model = fake_initialize_model + gen._run_speculative = lambda: pytest.fail("MTP should use BatchGenerator") + gen._collect_pending_requests = fake_collect_pending_requests + + gen._run() + + assert calls == [(False, 0.037)] + def test_idle_batch_generator_is_recreated_for_new_sampler(self, monkeypatch): created = [] next_uid = [1] diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 8ac586751..0158b0100 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -496,6 +496,42 @@ def test_qwen_target_verify_quantized_linear_matches_singleton_batch_path(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen3_5_decode_quantized_linears_fused_matches_separate(): + for bits in (4, 5): + mx.random.seed(170 + bits) + linears = [ + nn.QuantizedLinear(512, out_dim, bias=False, group_size=64, bits=bits) + for out_dim in (64, 64, 16, 16) + ] + for linear in linears: + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((4, 1, 512), dtype=mx.bfloat16) + + ref = tuple(linear(x) for linear in linears) + out = qwen_language._decode_quantized_linears_fused(tuple(linears), x) + mx.eval(*ref, *out) + + assert out is not None + assert all(bool(mx.array_equal(a, b).item()) for a, b in zip(ref, out)) + + +def test_qwen3_5_decode_quantized_linear_batch_matches_singleton_rows(): + for bits in (4, 5): + mx.random.seed(180 + bits) + linear = nn.QuantizedLinear(512, 32, bias=False, group_size=64, bits=bits) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((4, 1, 512), dtype=mx.bfloat16) + + ref = mx.concatenate([linear(x[row : row + 1]) for row in range(4)], axis=0) + out = qwen_language._decode_quantized_linear_batch(linear, x) + mx.eval(ref, out) + + assert out is not None + assert bool(mx.array_equal(ref, out).item()) + + def test_qwen_target_verify_quantized_argmax_matches_singleton_path(): mx.random.seed(16) linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) @@ -710,6 +746,30 @@ def test_qwen_gdn_verify_conv_matches_singleton_windows(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen_gdn_decode_conv_matches_conv1d(): + mx.random.seed(141) + config = SimpleNamespace( + hidden_size=16, + linear_num_value_heads=2, + linear_num_key_heads=2, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_conv_kernel_dim=4, + rms_norm_eps=1e-6, + ) + layer = qwen_language.Qwen3_5GatedDeltaNet(config) + layer.conv1d.weight = layer.conv1d.weight.astype(mx.bfloat16) + conv_input = mx.random.normal( + (3, layer.conv_kernel_size, layer.conv_dim), dtype=mx.bfloat16 + ) + + ref = layer.conv1d(conv_input) + out = layer._causal_conv1d_decode(conv_input) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + def test_qwen3_5_rope_index_ignores_left_padding_for_vision_rows(): model_config = qwen_language.ModelConfig( text_config=_tiny_qwen3_5_text_config(), From 159bf244c07969a5fb29ef4230523933c863b8d9 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 25 May 2026 21:57:09 +0200 Subject: [PATCH 21/26] Avoid MTP rollback syncs --- mlx_vlm/models/qwen3_5/language.py | 55 ++++++++++++++++++------------ mlx_vlm/speculative/mtp.py | 7 ++-- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index c2aecc4d8..b829fd86c 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -2000,13 +2000,31 @@ def rollback_speculative_cache( block_size: int, ) -> int: if isinstance(accepted, int): - accepted = mx.array([accepted]) + accepted_list = [int(accepted)] + elif isinstance(accepted, mx.array): + accepted_list = [int(x) for x in accepted.reshape(-1).tolist()] + else: + accepted_list = [int(x) for x in accepted] - max_a = int(accepted.max().item()) + max_a = max(accepted_list) n = max_a + 1 trim = block_size - n - is_batch = accepted.size > 1 - valid_ends = accepted + 1 + is_batch = len(accepted_list) > 1 + valid_ends_list = [a + 1 for a in accepted_list] + accepted_mx = None + valid_ends_mx = None + + def accepted_array(): + nonlocal accepted_mx + if accepted_mx is None: + accepted_mx = mx.array(accepted_list, dtype=mx.int32) + return accepted_mx + + def valid_ends_array(): + nonlocal valid_ends_mx + if valid_ends_mx is None: + valid_ends_mx = mx.array(valid_ends_list, dtype=mx.int32) + return valid_ends_mx # Separate trimmable (KV) caches from SSM caches. ssm_caches = [] @@ -2018,8 +2036,8 @@ def rollback_speculative_cache( c.trim(trim) right_trimmed = False if is_batch and max_a > 0: - extra_trim = max_a - accepted - if int(extra_trim.max().item()) > 0: + extra_trim_list = [max_a - a for a in accepted_list] + if any(extra_trim_list): prepare = getattr(c, "prepare", None) finalize = getattr(c, "finalize", None) if ( @@ -2027,11 +2045,7 @@ def rollback_speculative_cache( and callable(prepare) and callable(finalize) ): - prepare( - right_padding=[ - int(extra) for extra in extra_trim.tolist() - ] - ) + prepare(right_padding=extra_trim_list) finalize() right_trimmed = True if ( @@ -2042,10 +2056,9 @@ def rollback_speculative_cache( and max_a > 0 ): kv_len = c._idx - ve = valid_ends.tolist() verify_start = kv_len - n - for bi in range(accepted.shape[0]): - start = verify_start + int(ve[bi]) + for bi, ve in enumerate(valid_ends_list): + start = verify_start + ve if start < kv_len: c.keys[bi, :, start:kv_len, :] = 0 c.values[bi, :, start:kv_len, :] = 0 @@ -2056,7 +2069,7 @@ def rollback_speculative_cache( return max_a if all(len(s) > 11 and s[11] is not None for s in gdn_states): - a0 = int(accepted[0].item()) if not is_batch else None + a0 = accepted_list[0] if not is_batch else None if is_batch: intermediate_parts = [] conv_input_parts = [] @@ -2111,8 +2124,9 @@ def rollback_speculative_cache( if len(set(kernel_sizes)) != 1: raise ValueError("Qwen GDN layers must share conv kernel size.") + accepted_mx = accepted_array() accepted_bat = mx.concatenate( - [accepted.astype(mx.int32) for _ in ssm_caches], axis=0 + [accepted_mx for _ in ssm_caches], axis=0 ) state_bat, conv_bat = gated_delta_accept_states( mx.concatenate(intermediate_parts, axis=0), @@ -2187,7 +2201,7 @@ def rollback_speculative_cache( a_list.append(a) b_list.append(b) if is_batch: - steps_list.append(valid_ends.astype(mx.int32)) + steps_list.append(valid_ends_array()) else: steps_list.append(mx.full((batch_rows,), n, dtype=mx.int32)) A_log_list.append( @@ -2242,7 +2256,7 @@ def rollback_speculative_cache( ) # Scatter results back to individual caches. - a0 = int(accepted[0].item()) if not is_batch else None + a0 = accepted_list[0] if not is_batch else None state_offset = 0 for j, c in enumerate(ssm_caches): batch_rows = layer_batch_sizes[j] @@ -2250,13 +2264,12 @@ def rollback_speculative_cache( state_offset += batch_rows conv_input, K = conv_data[j] if is_batch: - acc_list = accepted.tolist() slices = [ conv_input[ bi : bi + 1, - int(acc_list[bi]) + 1 : int(acc_list[bi]) + K, + accepted_list[bi] + 1 : accepted_list[bi] + K, ] - for bi in range(accepted.shape[0]) + for bi in range(len(accepted_list)) ] c[0] = mx.concatenate(slices, axis=0) else: diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index adad88e4e..21ed858f8 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -662,7 +662,7 @@ def _mtp_rounds( next_shared_kv = _slice_shared_kv_after_reject( verify.shared_kv_states, bs - (accepted + 1) ) - kv_offset = _mtp_cache_offset_max(prompt_cache) + kv_offset += accepted + 1 draft_model.set_shared_kv( next_shared_kv, kv_offset, @@ -975,7 +975,6 @@ def _mtp_rounds_batch( ) max_a = max(accepted_list) - accepted_arr = mx.array(accepted_list) accept_verified = getattr(draft_model, "accept_verified_tokens_batch", None) if callable(accept_verified): @@ -1032,7 +1031,7 @@ def _mtp_rounds_batch( if any(a < bs - 1 for a in accepted_list): with mx.stream(generation_stream): lm.rollback_speculative_cache( - prompt_cache, verify.gdn_states, accepted_arr, bs + prompt_cache, verify.gdn_states, accepted_list, bs ) # Slice + tail-zero ``verify.shared_kv_states`` to match the @@ -1098,7 +1097,7 @@ def _mtp_rounds_batch( # Re-bind drafter with new shared_kv and per-row positions. positions_active = [positions[active_idx[j]] for j in range(len(active_idx))] - new_kv_offset = _mtp_cache_offset_max(prompt_cache) + new_kv_offset = max(positions_active) if positions_active else 0 draft_model.set_shared_kv( next_shared_kv, kv_offset=new_kv_offset, From 86623666519b3c5c97ed180a108eb9e34fbb1a27 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 26 May 2026 00:17:12 +0200 Subject: [PATCH 22/26] Remove slow Qwen decode qmv path --- mlx_vlm/models/qwen3_5/language.py | 126 +---------------------------- mlx_vlm/server/generation.py | 4 +- mlx_vlm/tests/test_speculative.py | 16 ---- 3 files changed, 3 insertions(+), 143 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index b829fd86c..a475930e2 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -377,70 +377,6 @@ def _target_verify_qlinear_header(bits: int, group_size: int) -> str: """ -_DECODE_BATCH_QMV_SOURCE = r""" - uint n_tile = threadgroup_position_in_grid.y; - uint simd_gid = simdgroup_index_in_threadgroup; - uint simd_lid = thread_index_in_simdgroup; - - int batch_idx = int(simd_gid) / NUM_SIMDGROUPS; - int out_simd = int(simd_gid) - batch_idx * NUM_SIMDGROUPS; - if (batch_idx >= BATCH_N) { - return; - } - - int out_row = int(n_tile) * BN + out_simd * RESULTS_PER_SIMDGROUP; - int in_vec_size_w = K_SIZE * BYTES_PER_PACK / PACK_FACTOR; - int in_vec_size_g = K_SIZE / GS; - - const device uint8_t* ws_base = - (const device uint8_t*)w + out_row * in_vec_size_w + - int(simd_lid) * PACKS_PER_THREAD * BYTES_PER_PACK; - const device T* scales_base = - scales + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; - const device T* biases_base = - biases + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; - const device T* x_base = - x + batch_idx * K_SIZE + int(simd_lid) * VALUES_PER_THREAD; - - float result[RESULTS_PER_SIMDGROUP]; - float x_thread[VALUES_PER_THREAD]; - for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { - result[row] = 0.0f; - } - - const device uint8_t* ws = ws_base; - const device T* sc = scales_base; - const device T* bs = biases_base; - const device T* xk = x_base; - - for (int k = 0; k < K_SIZE; k += BLOCK_SIZE) { - float sum = load_vector_exact(xk, x_thread); - - for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { - const device uint8_t* wl = ws + row * in_vec_size_w; - const device T* sl = sc + row * in_vec_size_g; - const device T* bl = bs + row * in_vec_size_g; - result[row] += qdot_exact(wl, x_thread, float(sl[0]), float(bl[0]), sum); - } - - ws += BLOCK_SIZE * BYTES_PER_PACK / PACK_FACTOR; - sc += BLOCK_SIZE / GS; - bs += BLOCK_SIZE / GS; - xk += BLOCK_SIZE; - } - - for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { - int n = out_row + row; - if (n < N_SIZE) { - float r = simd_sum(result[row]); - if (simd_lid == 0) { - y[(batch_idx * N_SIZE) + n] = T(r); - } - } - } -""" - - _TARGET_VERIFY_QARGMAX_SOURCE = r""" uint n_tile = threadgroup_position_in_grid.y; uint b_idx = threadgroup_position_in_grid.z; @@ -571,21 +507,6 @@ def _target_verify_qargmax_kernel(bits, group_size, dtype, verify_t, k_size, n_s ) -@lru_cache(maxsize=None) -def _decode_batch_qmv_kernel(bits, group_size, dtype, batch_n, k_size, n_size): - dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") - return mx.fast.metal_kernel( - name=( - "qwen3_5_decode_batch_qmv_" - f"b{bits}_gs{group_size}_bn{batch_n}_k{k_size}_n{n_size}_{dtype_name}" - ), - input_names=["x", "w", "scales", "biases"], - output_names=["y"], - header=_target_verify_qlinear_header(bits, group_size), - source=_DECODE_BATCH_QMV_SOURCE, - ) - - def _can_target_verify_quantized(linear, x: mx.array) -> bool: if ( not isinstance(linear, nn.QuantizedLinear) @@ -609,39 +530,6 @@ def _can_target_verify_quantized(linear, x: mx.array) -> bool: ) -def _decode_quantized_linear_batch(linear, x: mx.array) -> Optional[mx.array]: - if ( - not _can_target_verify_quantized(linear, x) - or x.shape[1] != 1 - or x.shape[0] != 4 - ): - return None - - B, _T, K = x.shape - N = linear.weight.shape[0] - - x = mx.contiguous(x) - kernel = _decode_batch_qmv_kernel( - linear.bits, linear.group_size, x.dtype, B, K, N - ) - out = kernel( - inputs=[x, linear.weight, linear.scales, linear.biases], - template=[ - ("T", x.dtype), - ("BATCH_N", int(B)), - ("K_SIZE", int(K)), - ("N_SIZE", int(N)), - ], - grid=(32, 2 * B * (N // 8), 1), - threadgroup=(32, 2 * B, 1), - output_shapes=[(B, 1, N)], - output_dtypes=[x.dtype], - )[0] - if "bias" in linear: - out = out + linear["bias"] - return out - - def _target_verify_quantized_linear(linear, x: mx.array) -> Optional[mx.array]: if not _can_target_verify_quantized(linear, x): return None @@ -778,10 +666,6 @@ def _target_verify_singletons(fn, x: mx.array) -> mx.array: def _target_verify_linear(linear, x: mx.array, target_verify: bool) -> mx.array: if not _use_target_verify_dense(linear, x, target_verify): - if isinstance(linear, nn.QuantizedLinear): - out = _decode_quantized_linear_batch(linear, x) - if out is not None: - return out return linear(x) if isinstance(linear, nn.QuantizedLinear): @@ -812,15 +696,7 @@ def _target_verify_linears(linears, x: mx.array, target_verify: bool): out = _decode_quantized_linears_fused(linears, x) if out is not None: return out - outputs = [] - for linear in linears: - if isinstance(linear, nn.QuantizedLinear): - out = _decode_quantized_linear_batch(linear, x) - if out is not None: - outputs.append(out) - continue - outputs.append(linear(x)) - return tuple(outputs) + return tuple(linear(x) for linear in linears) return tuple(_target_verify_linear(linear, x, target_verify) for linear in linears) diff --git a/mlx_vlm/server/generation.py b/mlx_vlm/server/generation.py index c9d263d05..5012193c5 100644 --- a/mlx_vlm/server/generation.py +++ b/mlx_vlm/server/generation.py @@ -1130,7 +1130,7 @@ def _run_speculative(self): token=tok, logprobs=0.0, finish_reason=finish, - peak_memory=mx.get_peak_memory() / 1e9, + peak_memory=mx.get_peak_memory() / 1e9 if finish else 0, prompt_tps=prompt_tps_map.get(uid), ) ) @@ -1197,7 +1197,7 @@ def stop_check(seq_idx, token_id): token=tok, logprobs=0.0, finish_reason=finish, - peak_memory=mx.get_peak_memory() / 1e9, + peak_memory=mx.get_peak_memory() / 1e9 if finish else 0, prompt_tps=prompt_tps_map.get(uid), ) ) diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 0158b0100..fba22e67c 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -516,22 +516,6 @@ def test_qwen3_5_decode_quantized_linears_fused_matches_separate(): assert all(bool(mx.array_equal(a, b).item()) for a, b in zip(ref, out)) -def test_qwen3_5_decode_quantized_linear_batch_matches_singleton_rows(): - for bits in (4, 5): - mx.random.seed(180 + bits) - linear = nn.QuantizedLinear(512, 32, bias=False, group_size=64, bits=bits) - linear.scales = linear.scales.astype(mx.bfloat16) - linear.biases = linear.biases.astype(mx.bfloat16) - x = mx.random.normal((4, 1, 512), dtype=mx.bfloat16) - - ref = mx.concatenate([linear(x[row : row + 1]) for row in range(4)], axis=0) - out = qwen_language._decode_quantized_linear_batch(linear, x) - mx.eval(ref, out) - - assert out is not None - assert bool(mx.array_equal(ref, out).item()) - - def test_qwen_target_verify_quantized_argmax_matches_singleton_path(): mx.random.seed(16) linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) From a8e73d2f5905dd24afcaeb9db46ce074fbd75610 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Tue, 26 May 2026 00:38:38 +0200 Subject: [PATCH 23/26] Reduce Qwen3.5 batched decode sync overhead --- mlx_vlm/models/qwen3_5/language.py | 114 +++++++++++++++++++++++++---- 1 file changed, 98 insertions(+), 16 deletions(-) diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index a475930e2..722f33969 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -116,6 +116,12 @@ def _precise_swiglu(h, gate, x): return (gate * x).astype(h.dtype) +@partial(mx.compile, shapeless=True) +def _qwen3_5_decode_depthwise_conv(conv_input: mx.array, weight: mx.array): + out = mx.sum(conv_input.astype(mx.float32) * weight[None, :, :], axis=1) + return out.astype(conv_input.dtype)[:, None, :] + + _TARGET_VERIFY_GEMV = mx.fast.metal_kernel( name="qwen3_5_target_verify_gemv", input_names=["x", "weight"], @@ -757,6 +763,68 @@ def _pad_row_time(x: mx.array, pad: int, target_length: int) -> mx.array: ) +def _qwen3_5_left_padding_info(cache): + left_padding = getattr(cache, "left_padding", None) + if not ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and left_padding.size > 0 + ): + return None + + cached = getattr(cache, "_qwen3_5_left_padding_info", None) + if cached is None or cached[0] is not left_padding: + pads = tuple(int(p) for p in left_padding.tolist()) + cached = (left_padding, pads, max(pads) if pads else 0) + cache._qwen3_5_left_padding_info = cached + return cached[1], cached[2] + + +def _qwen3_5_set_left_padding_info(cache, pads): + left_padding = getattr(cache, "left_padding", None) + if not isinstance(left_padding, mx.array): + return + pads = tuple(int(p) for p in pads) + cache._qwen3_5_left_padding_info = ( + left_padding, + pads, + max(pads) if pads else 0, + ) + + +def _qwen3_5_advance_left_padding_info(cache, steps: int): + cached = getattr(cache, "_qwen3_5_left_padding_info", None) + if cached is None: + return + _left_padding, pads, _max_pad = cached + _qwen3_5_set_left_padding_info(cache, (p - steps for p in pads)) + + +def _qwen3_5_lengths_info(cache): + lengths = getattr(cache, "lengths", None) + if not ( + isinstance(lengths, mx.array) + and lengths.ndim > 0 + and lengths.size > 0 + ): + return None + cached = getattr(cache, "_qwen3_5_lengths_info", None) + if cached is None or cached[0] is not lengths: + values = tuple(int(v) for v in lengths.tolist()) + cached = (lengths, min(values) if values else 0) + cache._qwen3_5_lengths_info = cached + return cached[1] + + +def _qwen3_5_advance_lengths_info(cache, steps: int): + lengths = getattr(cache, "lengths", None) + cached = getattr(cache, "_qwen3_5_lengths_info", None) + if cached is None or not isinstance(lengths, mx.array): + return + _lengths, min_value = cached + cache._qwen3_5_lengths_info = (lengths, min_value - steps) + + def _create_qwen3_5_ssm_mask(h: mx.array, cache): if not (cache and hasattr(cache, "make_mask")): return None @@ -771,14 +839,17 @@ def _create_qwen3_5_ssm_mask(h: mx.array, cache): == batch_size ): return None - if int(left_padding.max().item()) <= 0: + left_padding_info = _qwen3_5_left_padding_info(cache) + max_left_padding = left_padding_info[1] if left_padding_info else 0 + if max_left_padding <= 0: if lengths is None: cache._qwen3_5_ssm_no_mask_batch_size = batch_size return None if hasattr(cache, "_qwen3_5_ssm_no_mask_batch_size"): delattr(cache, "_qwen3_5_ssm_no_mask_batch_size") - if isinstance(lengths, mx.array) and int(lengths.min().item()) >= h.shape[1]: + lengths_min = _qwen3_5_lengths_info(cache) + if lengths_min is not None and lengths_min >= h.shape[1]: return None return cache.make_mask(h.shape[1]) @@ -799,8 +870,9 @@ def _create_qwen3_5_attention_mask(h: mx.array, cache): ): padding_cache = getattr(cache, "_qwen3_5_left_padding_cache", None) if padding_cache is None or padding_cache[0] is not left_padding: - pads = [int(p) for p in left_padding.tolist()] - padding_cache = (left_padding, pads, max(pads)) + left_padding_info = _qwen3_5_left_padding_info(cache) + pads = list(left_padding_info[0]) if left_padding_info else [] + padding_cache = (left_padding, pads, max(pads) if pads else 0) cache._qwen3_5_left_padding_cache = padding_cache pads = padding_cache[1] if padding_cache[2] <= 0: @@ -1173,6 +1245,19 @@ def _qwen3_5_ragged_sdpa_two_pass_2_kernel(dtype, v_size, blocks): ) +@lru_cache(maxsize=128) +def _qwen3_5_cached_i32_array(values): + return mx.array(values, dtype=mx.int32) + + +@lru_cache(maxsize=128) +def _qwen3_5_cached_sdpa_scalars(scale: float, k_size: int): + return ( + mx.array([scale], dtype=mx.float32), + mx.array([k_size], dtype=mx.int32), + ) + + def _qwen3_5_ragged_decode_attention( queries: mx.array, keys: mx.array, @@ -1192,6 +1277,7 @@ def _qwen3_5_ragged_decode_attention( return None batch, q_heads, _, d_size = queries.shape + pads = tuple(int(p) for p in pads) if len(pads) != batch or any(p < 0 for p in pads): return None kv_heads = keys.shape[1] @@ -1213,9 +1299,8 @@ def _qwen3_5_ragged_decode_attention( queries = mx.contiguous(queries) keys = mx.contiguous(keys) values = mx.contiguous(values) - pads_array = mx.array(pads, dtype=mx.int32) - scale_array = mx.array([scale], dtype=mx.float32) - k_size_array = mx.array([k_size], dtype=mx.int32) + pads_array = _qwen3_5_cached_i32_array(pads) + scale_array, k_size_array = _qwen3_5_cached_sdpa_scalars(float(scale), int(k_size)) template = [ ("T", queries.dtype), ("D_SIZE", int(d_size)), @@ -1277,14 +1362,10 @@ def _target_verify_left_padded_attention( pads = getattr(cache, "_qwen3_5_decode_left_padding", None) if pads is None: - left_padding = getattr(cache, "left_padding", None) - if not ( - isinstance(left_padding, mx.array) - and left_padding.ndim > 0 - and int(left_padding.max().item()) > 0 - ): + left_padding_info = _qwen3_5_left_padding_info(cache) + if left_padding_info is None or left_padding_info[1] <= 0: return None - pads = [int(p) for p in left_padding.tolist()] + pads = list(left_padding_info[0]) if max(pads) <= 0: return None @@ -1539,8 +1620,7 @@ def _causal_conv1d_decode(self, conv_input: mx.array) -> mx.array: self._qwen3_5_decode_conv_weight = cached weight = cached[1] - out = mx.sum(conv_input.astype(mx.float32) * weight[None, :, :], axis=1) - return out.astype(self.conv1d.weight.dtype)[:, None, :] + return _qwen3_5_decode_depthwise_conv(conv_input, weight) def __call__( self, @@ -1660,6 +1740,8 @@ def __call__( cache[1] = state if hasattr(cache, "advance"): cache.advance(S) + _qwen3_5_advance_left_padding_info(cache, S) + _qwen3_5_advance_lengths_info(cache, S) out = self.norm(out, z) return _target_verify_linear( From c807a22e6bb6b0c9aad3154e4b8aef2dfde1209e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Wed, 27 May 2026 10:09:24 +0200 Subject: [PATCH 24/26] Support PoolingCache in batched cache creation --- mlx_vlm/generate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 7f539b16c..34b19bd2f 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -1703,6 +1703,8 @@ def to_batch_cache(c, quantize=True): elif isinstance(c, cache.ArraysCache): c.left_padding = mx.array(left_padding) return c + elif isinstance(c, cache.PoolingCache): + return cache.BatchPoolingCache(c.ratio, left_padding) elif isinstance(c, cache.RotatingKVCache): if c.keep > 0: raise ValueError("RotatingKVCache with keep tokens is not supported.") From 0f726e0a1128261995a8567333a784b969438ef0 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 23:11:00 +0200 Subject: [PATCH 25/26] Apply pre-commit formatting --- mlx_vlm/generate.py | 25 +++++-------- mlx_vlm/models/qwen3_5/language.py | 58 +++++++++++------------------- mlx_vlm/speculative/mtp.py | 30 ++++++++++------ mlx_vlm/tests/test_generate.py | 4 ++- mlx_vlm/tests/test_speculative.py | 12 +++---- 5 files changed, 57 insertions(+), 72 deletions(-) diff --git a/mlx_vlm/generate.py b/mlx_vlm/generate.py index 34b19bd2f..701f756b4 100644 --- a/mlx_vlm/generate.py +++ b/mlx_vlm/generate.py @@ -73,9 +73,7 @@ def _get_batch_cache_eval_interval() -> int: try: return max(0, int(raw)) except ValueError: - logger.warning( - "Ignoring invalid MLX_VLM_BATCH_CACHE_EVAL_INTERVAL=%r", raw - ) + logger.warning("Ignoring invalid MLX_VLM_BATCH_CACHE_EVAL_INTERVAL=%r", raw) return DEFAULT_BATCH_CACHE_EVAL_INTERVAL @@ -1634,11 +1632,7 @@ def kv_rows(c): if isinstance(c, cache.KVCache): return [c] offset = getattr(c, "offset", None) - if ( - hasattr(c, "extract") - and isinstance(offset, mx.array) - and offset.ndim > 0 - ): + if hasattr(c, "extract") and isinstance(offset, mx.array) and offset.ndim > 0: return [c.extract(i) for i in range(int(offset.shape[0]))] return None @@ -2391,11 +2385,9 @@ def next(self) -> List[Response]: return responses self._append_token_responses(responses, tok_list) - while ( - isinstance(round_meta, dict) - and int(round_meta.get("round_pos", 0)) + 1 - < int(round_meta.get("round_len", 1)) - ): + while isinstance(round_meta, dict) and int( + round_meta.get("round_pos", 0) + ) + 1 < int(round_meta.get("round_len", 1)): try: tok_list, round_meta = next(self._rounds_iter) except StopIteration: @@ -3501,9 +3493,10 @@ def _next(self, **kwargs): mx.eval([c.state for c in self._generation_batch.prompt_cache]) mx.clear_cache() - if getattr(self._generation_batch, "is_speculative", False) and len( - self._generation_batch - ) > 0: + if ( + getattr(self._generation_batch, "is_speculative", False) + and len(self._generation_batch) > 0 + ): return prompt_responses, generation_responses if len(self._generation_batch) >= self.completion_batch_size: diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 37c9c0f36..6fcdf11a8 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -177,8 +177,7 @@ def _target_verify_weight(weight: mx.array, x: mx.array) -> Optional[mx.array]: def _target_verify_qlinear_header(bits: int, group_size: int) -> str: - return ( - r""" + return r""" using namespace metal; constant constexpr int SIMD_SIZE = 32; @@ -259,9 +258,10 @@ def _target_verify_qlinear_header(bits: int, group_size: int) -> str: } return scale * accum + sum * bias; } -""" - .replace("__BITS__", str(bits)) - .replace("__GS__", str(group_size)) +""".replace( + "__BITS__", str(bits) + ).replace( + "__GS__", str(group_size) ) @@ -480,9 +480,7 @@ def _can_target_verify_quantized(linear, x: mx.array) -> bool: _, _, K = x.shape N = linear.weight.shape[0] return ( - K == linear.weight.shape[1] * 32 // linear.bits - and K % 512 == 0 - and N % 8 == 0 + K == linear.weight.shape[1] * 32 // linear.bits and K % 512 == 0 and N % 8 == 0 ) @@ -536,8 +534,7 @@ def _decode_quantized_linears_fused(linears, x: mx.array): return None cache_key = tuple( - (id(linear.weight), id(linear.scales), id(linear.biases)) - for linear in linears + (id(linear.weight), id(linear.scales), id(linear.biases)) for linear in linears ) cached = getattr(first, "_qwen3_5_fused_decode_linears", None) if cached is None or cached[0] != cache_key: @@ -599,9 +596,7 @@ def _target_verify_quantized_argmax(linear, x: mx.array) -> Optional[mx.array]: output_dtypes=[x.dtype, mx.int32], ) best_tile = mx.argmax(tile_values, axis=-1) - return mx.take_along_axis(tile_indices, best_tile[..., None], axis=-1).squeeze( - -1 - ) + return mx.take_along_axis(tile_indices, best_tile[..., None], axis=-1).squeeze(-1) def _target_verify_timewise(fn, x: mx.array) -> mx.array: @@ -752,11 +747,7 @@ def _qwen3_5_advance_left_padding_info(cache, steps: int): def _qwen3_5_lengths_info(cache): lengths = getattr(cache, "lengths", None) - if not ( - isinstance(lengths, mx.array) - and lengths.ndim > 0 - and lengths.size > 0 - ): + if not (isinstance(lengths, mx.array) and lengths.ndim > 0 and lengths.size > 0): return None cached = getattr(cache, "_qwen3_5_lengths_info", None) if cached is None or cached[0] is not lengths: @@ -785,8 +776,7 @@ def _create_qwen3_5_ssm_mask(h: mx.array, cache): batch_size = int(left_padding.shape[0]) if left_padding.ndim > 0 else 1 if ( lengths is None - and getattr(cache, "_qwen3_5_ssm_no_mask_batch_size", None) - == batch_size + and getattr(cache, "_qwen3_5_ssm_no_mask_batch_size", None) == batch_size ): return None left_padding_info = _qwen3_5_left_padding_info(cache) @@ -813,11 +803,7 @@ def _create_qwen3_5_attention_mask(h: mx.array, cache): delattr(cache, "_qwen3_5_decode_left_padding") left_padding = getattr(cache, "left_padding", None) - if ( - h.shape[1] == 1 - and isinstance(left_padding, mx.array) - and left_padding.ndim > 0 - ): + if h.shape[1] == 1 and isinstance(left_padding, mx.array) and left_padding.ndim > 0: padding_cache = getattr(cache, "_qwen3_5_left_padding_cache", None) if padding_cache is None or padding_cache[0] is not left_padding: left_padding_info = _qwen3_5_left_padding_info(cache) @@ -1173,8 +1159,7 @@ def _qwen3_5_ragged_sdpa_two_pass_1_kernel(dtype, d_size, v_size, blocks): dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") return mx.fast.metal_kernel( name=( - f"qwen3_5_ragged_sdpa_2p1_{dtype_name}_" - f"d{d_size}_v{v_size}_b{blocks}" + f"qwen3_5_ragged_sdpa_2p1_{dtype_name}_" f"d{d_size}_v{v_size}_b{blocks}" ), input_names=["queries", "keys", "values", "pads", "scale", "k_size"], output_names=["partials", "sums", "maxs"], @@ -1286,18 +1271,21 @@ def _qwen3_5_ragged_decode_attention( ], output_dtypes=[queries.dtype, mx.float32, mx.float32], ) - kernel_2 = _qwen3_5_ragged_sdpa_two_pass_2_kernel( - queries.dtype, v_size, blocks - ) + kernel_2 = _qwen3_5_ragged_sdpa_two_pass_2_kernel(queries.dtype, v_size, blocks) return kernel_2( inputs=[partials, sums, maxs], - template=[("T", queries.dtype), ("D_SIZE", int(v_size)), ("BLOCKS", int(blocks))], + template=[ + ("T", queries.dtype), + ("D_SIZE", int(v_size)), + ("BLOCKS", int(blocks)), + ], grid=(1024, batch * q_heads, 1), threadgroup=(1024, 1, 1), output_shapes=[(batch, q_heads, 1, v_size)], output_dtypes=[queries.dtype], )[0] + def _target_verify_left_padded_attention( queries: mx.array, keys: mx.array, @@ -2060,9 +2048,7 @@ def valid_ends_array(): raise ValueError("Qwen GDN layers must share conv kernel size.") accepted_mx = accepted_array() - accepted_bat = mx.concatenate( - [accepted_mx for _ in ssm_caches], axis=0 - ) + accepted_bat = mx.concatenate([accepted_mx for _ in ssm_caches], axis=0) state_bat, conv_bat = gated_delta_accept_states( mx.concatenate(intermediate_parts, axis=0), mx.concatenate(conv_input_parts, axis=0), @@ -2480,9 +2466,7 @@ def __call__( inputs, image_grid_thw, video_grid_thw, rope_mask ) if image_grid_thw is None and video_grid_thw is None: - rope_deltas = mx.zeros( - (batch_size, 1), dtype=rope_deltas.dtype - ) + rope_deltas = mx.zeros((batch_size, 1), dtype=rope_deltas.dtype) self._rope_deltas = rope_deltas self._position_ids = position_ids else: diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index 4f05164d9..1445a8898 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -362,7 +362,9 @@ def _speculative_walk_batch_deferred_uniform( return [accepted] * B, [draft_lists[row][: budgets[row]] for row in range(B)] -def _sampler_supports_positioned_target(sampler: Callable[[mx.array], mx.array]) -> bool: +def _sampler_supports_positioned_target( + sampler: Callable[[mx.array], mx.array] +) -> bool: return callable(getattr(sampler, "sample_target", None)) @@ -554,7 +556,8 @@ def _mtp_rounds( draft_model.reset(model) sampler_rng = _SpeculativeSamplerRNG( draft_model, - enabled=not greedy_sampling and not _sampler_supports_positioned_target(sampler), + enabled=not greedy_sampling + and not _sampler_supports_positioned_target(sampler), ) # Hidden from prefill is full prompt-length; reduce to a single slot. @@ -849,7 +852,8 @@ def _mtp_rounds_batch( draft_model.reset(model) sampler_rng = _SpeculativeSamplerRNG( draft_model, - enabled=not greedy_sampling and not _sampler_supports_positioned_target(sampler), + enabled=not greedy_sampling + and not _sampler_supports_positioned_target(sampler), ) # First-round hidden: prefill output may have shape [B, L, H]; reduce @@ -960,14 +964,18 @@ def _mtp_rounds_batch( ) ) else: - accepted_list, new_tokens_list = _speculative_walk_batch_deferred_greedy( - lm, - hidden_full, - draft_tokens, - sampler, - budgets, - row_ids=[row_ids[active_idx[j]] for j in range(n_active)], - base_positions=[emitted[active_idx[j]] for j in range(n_active)], + accepted_list, new_tokens_list = ( + _speculative_walk_batch_deferred_greedy( + lm, + hidden_full, + draft_tokens, + sampler, + budgets, + row_ids=[row_ids[active_idx[j]] for j in range(n_active)], + base_positions=[ + emitted[active_idx[j]] for j in range(n_active) + ], + ) ) sampler_rng.target_sampled( sync_draft=not _sampler_supports_positioned_target(sampler) diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index 53313e21e..52d825bc1 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -691,7 +691,9 @@ def fake_rounds(*args, **kwargs): yield [2, 11], {"round_pos": 1, "round_len": 2} yield [3, 12], {"round_pos": 0, "round_len": 1} - monkeypatch.setattr(generate_module, "run_speculative_server_rounds", fake_rounds) + monkeypatch.setattr( + generate_module, "run_speculative_server_rounds", fake_rounds + ) batch = SpeculativeGenerationBatch( model=SimpleNamespace(), diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index b8bb20a29..d4e194e34 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -618,9 +618,7 @@ def test_qwen3_5_ragged_decode_attention_matches_two_pass_singleton(): pads = [7, 0] scale = 64**-0.5 key_length = ( - 1100 - if qwen_language._qwen3_5_device_arch_suffix() in {"d", "s"} - else 4112 + 1100 if qwen_language._qwen3_5_device_arch_suffix() in {"d", "s"} else 4112 ) queries = mx.random.normal((2, 4, 1, 64), dtype=mx.bfloat16) keys = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) @@ -803,9 +801,7 @@ def test_qwen3_5_rope_index_ignores_left_padding_for_vision_rows(): [[0, 10, 100, 101, 11, 12], [20, 21, 22, 23, 24, 25]], dtype=mx.int32, ) - attention_mask = mx.array( - [[0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], dtype=mx.int32 - ) + attention_mask = mx.array([[0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], dtype=mx.int32) image_grid_thw = mx.array([[1, 2, 2]], dtype=mx.int32) singleton_pos, singleton_delta = lm.get_rope_index(singleton_ids, image_grid_thw) @@ -2607,7 +2603,9 @@ def test_qwen3_5_mtp_batch_accept_updates_ragged_cache(): def test_qwen3_5_rollback_speculative_cache_trims_batch_rows_ragged(): text_config = _tiny_qwen3_5_text_config() language = qwen_language.LanguageModel(args=text_config) - cache = qwen_language.KVCache.merge([qwen_language.KVCache(), qwen_language.KVCache()]) + cache = qwen_language.KVCache.merge( + [qwen_language.KVCache(), qwen_language.KVCache()] + ) keys = mx.arange(2 * 1 * 5 * 4, dtype=mx.float32).reshape(2, 1, 5, 4) values = keys + 100 cache.update_and_fetch(keys, values) From 859d72499e64d4681b4449b93ae514debd630967 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 31 May 2026 02:31:37 +0200 Subject: [PATCH 26/26] Improve Qwen batch decode stability --- mlx_vlm/generate/ar.py | 28 ++++++++++++++++- mlx_vlm/models/qwen3_5/gated_delta.py | 43 +++++++++++++++++++++++---- mlx_vlm/models/qwen3_5/language.py | 11 +++---- mlx_vlm/tests/test_generate.py | 37 +++++++++++++++++++++++ 4 files changed, 108 insertions(+), 11 deletions(-) diff --git a/mlx_vlm/generate/ar.py b/mlx_vlm/generate/ar.py index 14fc7297e..42256dd52 100644 --- a/mlx_vlm/generate/ar.py +++ b/mlx_vlm/generate/ar.py @@ -647,9 +647,15 @@ def _extend_cache(cache_a, cache_b): return cache_b if not cache_b: return cache_a + extended = [] for ca, cb in zip(cache_a, cache_b): + if not hasattr(ca, "left_padding") and hasattr(ca.__class__, "merge"): + ca = ca.__class__.merge([ca]) + if not hasattr(cb, "left_padding") and hasattr(cb.__class__, "merge"): + cb = cb.__class__.merge([cb]) ca.extend(cb) - return cache_a + extended.append(ca) + return extended def _make_cache( @@ -1902,6 +1908,26 @@ def generate( rope_deltas = rope_deltas[:target_b] gen_batch._rope_deltas = rope_deltas + # Final prefill produces the first generated token and mutates the + # prompt cache. Materialize that boundary before the decode loop so + # the first decode step does not inherit the full lazy prefill graph. + cache_states = [] + for c in self.prompt_cache: + try: + cache_states.append(c.state) + except (AttributeError, TypeError): + pass + eval_targets = [first_tokens] + if cache_states: + eval_targets.append(cache_states) + if compute_logprobs and isinstance(gen_batch, GenerationBatch): + eval_targets.append(gen_batch._next_lps) + if top_logprobs_k > 0 and isinstance(gen_batch, GenerationBatch): + eval_targets.extend([gen_batch._next_top_idx, gen_batch._next_top_lp]) + if rope_deltas is not None: + eval_targets.append(rope_deltas) + mx.eval(*eval_targets) + # APC: harvest the post-prefill K/V into hashed blocks. Done after the # final prefill forward but before the cache references are released # so the block tensors snapshot the prompt prefix. diff --git a/mlx_vlm/models/qwen3_5/gated_delta.py b/mlx_vlm/models/qwen3_5/gated_delta.py index 6d8ea7ae0..e2a18cac5 100644 --- a/mlx_vlm/models/qwen3_5/gated_delta.py +++ b/mlx_vlm/models/qwen3_5/gated_delta.py @@ -1,7 +1,42 @@ +from functools import partial from typing import Optional import mlx.core as mx -from mlx_lm.models.gated_delta import compute_g, gated_delta_update # noqa: F401 +import mlx.nn as nn +from mlx_lm.models.gated_delta import gated_delta_kernel, gated_delta_ops + + +@partial(mx.compile, shapeless=True) +def compute_g(A_log, a, dt_bias): + return mx.exp(-mx.exp(A_log.astype(mx.float32)) * nn.softplus(a + dt_bias)) + + +@partial(mx.compile, shapeless=True) +def _compute_g_beta(A_log, a, b, dt_bias): + return compute_g(A_log, a, dt_bias), mx.sigmoid(b) + + +def gated_delta_update( + q: mx.array, + k: mx.array, + v: mx.array, + a: mx.array, + b: mx.array, + A_log: mx.array, + dt_bias: mx.array, + state: Optional[mx.array] = None, + mask: Optional[mx.array] = None, + use_kernel: bool = True, +): + g, beta = _compute_g_beta(A_log, a, b, dt_bias) + if state is None: + B, _, _Hk, Dk = q.shape + Hv, Dv = v.shape[-2:] + state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32) + + if not use_kernel or mx.default_device() != mx.gpu or not mx.metal.is_available(): + return gated_delta_ops(q, k, v, g, beta, state, mask) + return gated_delta_kernel(q, k, v, g, beta, state, mask) def _make_gated_delta_with_states_kernel(has_mask: bool = False): @@ -153,8 +188,7 @@ def gated_delta_update_with_states( mask: Optional[mx.array] = None, use_kernel: bool = True, ): - beta = mx.sigmoid(b) - g = compute_g(A_log, a, dt_bias) + g, beta = _compute_g_beta(A_log, a, b, dt_bias) if state is None: B, _, _Hk, Dk = q.shape Hv, Dv = v.shape[-2:] @@ -459,8 +493,7 @@ def gated_delta_state_update( mask: Optional[mx.array] = None, use_kernel: bool = True, ) -> mx.array: - beta = mx.sigmoid(b) - g = compute_g(A_log, a, dt_bias) + g, beta = _compute_g_beta(A_log, a, b, dt_bias) if state is None: B, _, _Hk, Dk = k.shape Hv, Dv = v.shape[-2:] diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 6fcdf11a8..4b20c6b6f 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -1834,11 +1834,12 @@ def __call__( pad = min(max(int(pad), 0), h.shape[1]) row_inputs = inputs[row : row + 1, pad:] row_embeds = h[row : row + 1, pad:] - row_position_ids = ( - None - if position_ids is None - else position_ids[:, row : row + 1, pad:] - ) + row_position_ids = None + if position_ids is not None: + if position_ids.ndim == 2: + row_position_ids = position_ids[row : row + 1, pad:] + else: + row_position_ids = position_ids[:, row : row + 1, pad:] current_cache = [] for cache_entry in cache: if cache_entry is None: diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index cbbc6aeea..ca1c37173 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -8,6 +8,7 @@ import mlx.core as mx import pytest +from mlx_lm.models.cache import BatchKVCache, KVCache from mlx_vlm import apc as apc_module from mlx_vlm.generate import ( @@ -808,6 +809,42 @@ def force_token_2(tokens, logits): assert [r.token for r in first] == [5, 6, 7] assert seen_contexts == [[30, 7]] + def test_generation_batch_extend_promotes_singleton_kv_cache(self): + def make_kv_cache(value): + c = KVCache() + keys = mx.full((1, 2, 3, 4), value, dtype=mx.float32) + values = mx.full((1, 2, 3, 4), value + 1, dtype=mx.float32) + c.update_and_fetch(keys, values) + return c + + sampler = lambda logprobs: mx.argmax(logprobs, axis=-1) + stop_criteria = lambda token: False + first = GenerationBatch( + model=MagicMock(), + uids=[0], + inputs=mx.array([5], dtype=mx.int32), + prompt_cache=[make_kv_cache(1.0)], + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=[2], + ) + second = GenerationBatch( + model=MagicMock(), + uids=[1], + inputs=mx.array([6], dtype=mx.int32), + prompt_cache=[make_kv_cache(3.0)], + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=[2], + ) + + first.extend(second) + + assert isinstance(first.prompt_cache[0], BatchKVCache) + assert first.prompt_cache[0].left_padding.tolist() == [0, 0] + assert first.prompt_cache[0].keys.shape[0] == 2 + assert first._next_tokens.tolist() == [5, 6] + def test_remove_from_unprocessed(self, mock_model, mock_processor): gen = BatchGenerator( model=mock_model.language_model,