diff --git a/mlx_vlm/generate/ar.py b/mlx_vlm/generate/ar.py index 409ee15a9..054583f24 100644 --- a/mlx_vlm/generate/ar.py +++ b/mlx_vlm/generate/ar.py @@ -19,7 +19,14 @@ from .. import apc as _apc from ..models import cache from ..prompt_utils import apply_chat_template -from ..speculative.utils import run_speculative_rounds +from ..sample_utils import top_p_sampling +from ..speculative.utils import ( + make_speculative_prompt_cache, + run_speculative_rounds, + run_speculative_server_rounds, + speculative_hidden_state, + speculative_prefill_kwargs, +) from ..turboquant import BatchTurboQuantKVCache, turboquant_enabled from ..utils import group_images_by_shape, prepare_inputs from .common import ( @@ -42,6 +49,83 @@ DEFAULT_PREFILL_STEP_SIZE = 2048 DEFAULT_COMPLETION_BATCH_SIZE = 32 DEFAULT_PREFILL_BATCH_SIZE = 8 +DEFAULT_BATCH_CACHE_EVAL_INTERVAL = 50 + + +def _get_batch_cache_eval_interval() -> int: + raw = os.environ.get("MLX_VLM_BATCH_CACHE_EVAL_INTERVAL") + if raw is None: + return DEFAULT_BATCH_CACHE_EVAL_INTERVAL + try: + return max(0, int(raw)) + except ValueError: + logger.warning("Ignoring invalid MLX_VLM_BATCH_CACHE_EVAL_INTERVAL=%r", raw) + return DEFAULT_BATCH_CACHE_EVAL_INTERVAL + + +def _position_seed(seed: int, row_id: int, position: int) -> int: + x = (int(seed) ^ 0x9E3779B9) & 0xFFFFFFFF + x = (x + (int(row_id) + 1) * 0x85EBCA6B) & 0xFFFFFFFF + x = (x ^ ((int(position) + 1) * 0xC2B2AE35)) & 0xFFFFFFFF + x ^= x >> 16 + x = (x * 0x7FEB352D) & 0xFFFFFFFF + x ^= x >> 15 + return int(x & 0xFFFFFFFF) + + +def _position_keys(seed: int, row_ids: List[int], positions: List[int]) -> mx.array: + return mx.stack( + [ + mx.random.key(_position_seed(seed, row, pos)) + for row, pos in zip(row_ids, positions) + ] + ) + + +class _PositionedTargetSampler: + """Sampler with stateless target draws keyed by generated-token position.""" + + def __init__(self, *, temperature: float, top_p: float, seed: int): + self.temperature = float(temperature) + self.top_p = float(top_p) + self.seed = int(seed) + + def __call__(self, logprobs: mx.array) -> mx.array: + if self.top_p > 0 and self.top_p < 1.0: + return top_p_sampling(logprobs, self.top_p, self.temperature) + return mx.random.categorical(logprobs * (1 / self.temperature)) + + def sample_target( + self, + logprobs: mx.array, + *, + row_ids: List[int], + positions: List[int], + ) -> mx.array: + if logprobs.shape[0] != len(row_ids) or len(row_ids) != len(positions): + raise ValueError("row_ids and positions must match logprobs batch size.") + keys = _position_keys(self.seed, row_ids, positions) + if self.top_p > 0 and self.top_p < 1.0: + return mx.vmap(self._sample_top_p_one, in_axes=(0, 0))(logprobs, keys) + return mx.vmap(self._sample_one, in_axes=(0, 0))(logprobs, keys) + + def _sample_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + return mx.random.categorical(logprobs * (1 / self.temperature), key=key) + + def _sample_top_p_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + if logprobs.dtype == mx.bfloat16: + logprobs = logprobs.astype(mx.float32) + probs = mx.softmax(logprobs / self.temperature, axis=-1) + sorted_indices = mx.argsort(probs, axis=-1) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) + cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + top_probs = mx.where( + cumulative_probs > 1 - self.top_p, + sorted_probs, + mx.zeros_like(sorted_probs), + ) + sampled_pos = mx.random.categorical(mx.log(top_probs), key=key) + return mx.take_along_axis(sorted_indices, sampled_pos[..., None], axis=-1)[0] def _generate_module_override(name: str, fallback): @@ -93,6 +177,7 @@ def generate_step( draft_block_size: Optional[int] = None, prompt_cache_checkpoint: Optional[Callable[[int, List[Any]], None]] = None, prompt_cache_checkpoint_len: Optional[int] = None, + seed: Optional[int] = None, verbose: bool = False, **kwargs, ) -> Generator[Tuple[mx.array, mx.array], None, None]: @@ -163,12 +248,24 @@ def generate_step( sampler_is_greedy = sampler is None and temperature == 0 if sampler is None: - sampler = _generate_module_override("make_sampler", make_sampler)( - temp=temperature, - top_p=top_p, - min_p=min_p, - top_k=top_k, - ) + if ( + seed is not None + and temperature > 0 + and min_p == DEFAULT_MIN_P + and top_k == DEFAULT_TOP_K + ): + sampler = _PositionedTargetSampler( + temperature=temperature, + top_p=top_p, + seed=seed, + ) + else: + sampler = _generate_module_override("make_sampler", make_sampler)( + temp=temperature, + top_p=top_p, + min_p=min_p, + top_k=top_k, + ) processors = _generate_module_override( "make_logits_processors", make_logits_processors @@ -186,6 +283,7 @@ def generate_step( y = input_ids tokens = mx.array([], dtype=input_ids.dtype) + target_sample_position = 0 thinking_budget_criteria = kwargs.pop("thinking_budget_criteria", None) @@ -226,7 +324,7 @@ def generate_step( lm._rope_deltas = None def _step(y, inputs_embeds=None): - nonlocal tokens, kwargs, last_outputs + nonlocal tokens, kwargs, last_outputs, target_sample_position with mx.stream(generation_stream): if "decoder_input_ids" in kwargs: @@ -254,7 +352,18 @@ def _step(y, inputs_embeds=None): quantize_cache_fn(prompt_cache) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - y = sampler(logprobs) + y = _sample_with_positions( + sampler, + logprobs, + row_ids=[0] * logprobs.shape[0], + positions=list( + range( + target_sample_position, + target_sample_position + logprobs.shape[0], + ) + ), + ) + target_sample_position += logprobs.shape[0] if outputs.cross_attention_states is not None: kwargs = {"cross_attention_states": outputs.cross_attention_states} @@ -541,9 +650,15 @@ def _extend_cache(cache_a, cache_b): return cache_b if not cache_b: return cache_a + extended = [] for ca, cb in zip(cache_a, cache_b): + if not hasattr(ca, "left_padding") and hasattr(ca.__class__, "merge"): + ca = ca.__class__.merge([ca]) + if not hasattr(cb, "left_padding") and hasattr(cb.__class__, "merge"): + cb = cb.__class__.merge([cb]) ca.extend(cb) - return cache_a + extended.append(ca) + return extended def _make_cache( @@ -676,6 +791,19 @@ class PromptProgress: cached_tokens: int = 0 +def _sample_with_positions( + sampler: Callable[[mx.array], mx.array], + logprobs: mx.array, + *, + row_ids: Optional[List[int]] = None, + positions: Optional[List[int]] = None, +) -> mx.array: + sample_target = getattr(sampler, "sample_target", None) + if callable(sample_target) and row_ids is not None and positions is not None: + return sample_target(logprobs, row_ids=row_ids, positions=positions) + return sampler(logprobs) + + class GenerationBatch: """ Batched token generator with double-buffered pipelining. @@ -703,6 +831,7 @@ def __init__( stop_criteria, max_tokens: List[int], top_logprobs_k: int = 0, + greedy_sampling: bool = False, token_context: Optional[List[List[int]]] = None, logits_processors: Optional[ List[Optional[List[Callable[[mx.array, mx.array], mx.array]]]] @@ -719,6 +848,7 @@ def __init__( self._num_tokens = [0] * len(uids) self.compute_logprobs = True self.top_logprobs_k = top_logprobs_k + self.greedy_sampling = greedy_sampling self.logits_processors = logits_processors or [] self.thinking_budget_criteria = thinking_budget_criteria or [] self.token_context = [list(ctx) for ctx in (token_context or [])] @@ -737,6 +867,9 @@ def __init__( def __len__(self): return len(self.uids) + def cache_states(self): + return [c.state for c in self.prompt_cache if hasattr(c, "state")] + def _ensure_token_context(self, *, force: bool = False): if not (force or (self.logits_processors and any(self.logits_processors))): if not self.logits_processors: @@ -748,6 +881,36 @@ def _ensure_token_context(self, *, force: bool = False): elif len(self.token_context) > len(self.uids): self.token_context = self.token_context[: len(self.uids)] + def _greedy_argmax_step(self, inputs: mx.array, fwd_kwargs: dict): + if ( + not self.greedy_sampling + or self.compute_logprobs + or self.top_logprobs_k > 0 + or (self.logits_processors and any(self.logits_processors)) + ): + return None + + argmax_from_hidden = getattr( + self._language_model, "speculative_argmax_from_hidden", None + ) + if not callable(argmax_from_hidden): + return None + + output = self._language_model( + inputs[:, None], + cache=self.prompt_cache, + return_hidden=True, + skip_logits=True, + **fwd_kwargs, + ) + hidden = output.hidden_states[-1] + sampled = argmax_from_hidden(hidden) + if sampled is None: + return None + if sampled.ndim == 2 and sampled.shape[1] == 1: + sampled = sampled[:, 0] + return sampled + def _step(self): """Perform one generation step with double buffering.""" self._current_tokens = self._next_tokens @@ -758,6 +921,16 @@ def _step(self): if self._rope_deltas is not None: fwd_kwargs["rope_deltas"] = self._rope_deltas + sampled = self._greedy_argmax_step(inputs, fwd_kwargs) + if sampled is not None: + self._next_tokens = sampled + self._next_lps = None + self._next_top_idx = None + self._next_top_lp = None + mx.async_eval(self._next_tokens) + mx.eval(inputs) + return inputs.tolist(), None, None, None + output = self._language_model( inputs[:, None], cache=self.prompt_cache, **fwd_kwargs ) @@ -787,7 +960,12 @@ def _step(self): logits = mx.concatenate(processed_logits, axis=0) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - sampled = self.sampler(logprobs) + sampled = _sample_with_positions( + self.sampler, + logprobs, + row_ids=[0] * len(self.uids), + positions=[n + 1 for n in self._num_tokens], + ) self._next_tokens = sampled prev_top_idx = self._next_top_idx @@ -1038,7 +1216,13 @@ def next(self) -> List[Response]: @classmethod def empty( - cls, model, sampler, stop_criteria, compute_logprobs=True, top_logprobs_k=0 + cls, + model, + sampler, + stop_criteria, + compute_logprobs=True, + top_logprobs_k=0, + greedy_sampling: bool = False, ): """Create an empty generation batch.""" batch = cls.__new__(cls) @@ -1052,6 +1236,7 @@ def empty( batch._num_tokens = [] batch.compute_logprobs = compute_logprobs batch.top_logprobs_k = top_logprobs_k + batch.greedy_sampling = greedy_sampling batch.token_context = [] batch.logits_processors = [] batch.thinking_budget_criteria = [] @@ -1065,6 +1250,194 @@ def empty( return batch +class SpeculativeGenerationBatch: + """GenerationBatch-compatible wrapper for server-side MTP decode.""" + + is_speculative = True + Response = GenerationBatch.Response + + def __init__( + self, + model: nn.Module, + draft_model: nn.Module, + draft_kind: str, + uids: List[int], + first_tokens: mx.array, + prompt_cache: List[Any], + sampler: Callable[[mx.array], mx.array], + stop_criteria, + max_tokens: List[int], + hidden: mx.array, + shared_kv_states: Optional[dict], + prompt_tokens: mx.array, + *, + draft_block_size: Optional[int] = None, + token_dtype: mx.Dtype = mx.int32, + greedy_sampling: bool = False, + ): + self.model = model + self.draft_model = draft_model + self.draft_kind = draft_kind + self.uids = list(uids) + self._all_uids = list(uids) + self.first_tokens = first_tokens + self.prompt_cache = prompt_cache + self.sampler = sampler + self.stop_criteria = stop_criteria + self.max_tokens = list(max_tokens) + self.hidden = hidden + self.shared_kv_states = shared_kv_states + self.prompt_tokens = prompt_tokens + self.draft_block_size = draft_block_size + self.token_dtype = token_dtype + self.greedy_sampling = greedy_sampling + self._num_tokens = [0] * len(uids) + self._finished = [False] * len(uids) + self._sent_first = False + self._rounds_iter = None + + def __len__(self): + return sum(not done for done in self._finished) + + def _refresh_uids(self): + self.uids = [ + uid for uid, done in zip(self._all_uids, self._finished) if not done + ] + + def extend(self, other: "SpeculativeGenerationBatch"): + if len(self) == 0: + self.__dict__.update(other.__dict__) + return + raise RuntimeError("Cannot extend an active speculative generation batch.") + + def filter(self, keep: List[int]): + keep_uids = {self.uids[idx] for idx in keep} + for i, uid in enumerate(self._all_uids): + if uid not in keep_uids: + self._finished[i] = True + self._refresh_uids() + + def cache_states(self): + return [c.state for c in self.prompt_cache if hasattr(c, "state")] + + def _finish_reason(self, row: int, token: int) -> Optional[str]: + if self.stop_criteria(token): + return "stop" + if self._num_tokens[row] >= self.max_tokens[row]: + return "length" + return None + + def _append_token_responses( + self, + responses: List[GenerationBatch.Response], + tok_list: List[Optional[int]], + ) -> None: + for row, token in enumerate(tok_list): + if token is None or self._finished[row]: + continue + token = int(token) + self._num_tokens[row] += 1 + finish_reason = self._finish_reason(row, token) + if finish_reason is not None: + self._finished[row] = True + responses.append( + self.Response( + uid=self._all_uids[row], + token=token, + token_logprob=0.0, + finish_reason=finish_reason, + ) + ) + + def _start_rounds(self): + if self._rounds_iter is not None: + return + + def stop_check(seq_idx, token_id): + return ( + self._finished[seq_idx] + or self.stop_criteria(token_id) + or self._num_tokens[seq_idx] >= self.max_tokens[seq_idx] + ) + + self._rounds_iter = run_speculative_server_rounds( + self.model, + self.draft_model, + self.prompt_cache, + self.hidden, + draft_kind=self.draft_kind, + first_bonus=self.first_tokens, + max_tokens=max(self.max_tokens) if self.max_tokens else 0, + sampler=self.sampler, + draft_block_size=self.draft_block_size, + token_dtype=self.token_dtype, + stop_check=stop_check, + greedy_sampling=self.greedy_sampling, + shared_kv_states=self.shared_kv_states, + eos_token_ids=None, + prompt_tokens=self.prompt_tokens, + row_ids=[0] * len(self._all_uids), + ) + + def next(self) -> List[GenerationBatch.Response]: + if len(self) == 0: + return [] + + responses: List[GenerationBatch.Response] = [] + if not self._sent_first: + self._sent_first = True + mx.eval(self.first_tokens) + for row, token in enumerate(self.first_tokens.tolist()): + if self._finished[row]: + continue + token = int(token) + self._num_tokens[row] += 1 + finish_reason = self._finish_reason(row, token) + if finish_reason is not None: + self._finished[row] = True + responses.append( + self.Response( + uid=self._all_uids[row], + token=token, + token_logprob=0.0, + finish_reason=finish_reason, + ) + ) + self._refresh_uids() + return responses + + self._start_rounds() + try: + tok_list, round_meta = next(self._rounds_iter) + except StopIteration: + for row, done in enumerate(self._finished): + if not done: + self._finished[row] = True + responses.append( + self.Response( + uid=self._all_uids[row], + token=None, + token_logprob=0.0, + finish_reason="length", + ) + ) + self._refresh_uids() + return responses + + self._append_token_responses(responses, tok_list) + while isinstance(round_meta, dict) and int( + round_meta.get("round_pos", 0) + ) + 1 < int(round_meta.get("round_len", 1)): + try: + tok_list, round_meta = next(self._rounds_iter) + except StopIteration: + break + self._append_token_responses(responses, tok_list) + + self._refresh_uids() + return responses + + class PromptProcessingBatch: """ Handles VLM prompt processing with inputs_embeds and chunked prefill. @@ -1096,12 +1469,20 @@ def __init__( right_pad_per_row: Optional[List[int]] = None, suffix_lens: Optional[List[int]] = None, apc_mode: Optional[str] = None, + draft_model: Optional[nn.Module] = None, + draft_kind: Optional[str] = None, + draft_block_size: Optional[int] = None, + greedy_sampling: bool = False, ): self.model = model self.uids = uids self._prompt_uids = list(uids) self.max_tokens = max_tokens self.prefill_step_size = prefill_step_size + self.draft_model = draft_model + self.draft_kind = draft_kind + self.draft_block_size = draft_block_size + self.greedy_sampling = greedy_sampling lengths = [len(ids) for ids in input_ids] max_length = max(lengths) @@ -1169,6 +1550,27 @@ def __init__( if warm_cache is not None: self.prompt_cache = warm_cache + elif draft_model is not None and draft_kind is not None: + self.prompt_cache = make_speculative_prompt_cache( + model, + draft_kind=draft_kind, + batch_size=len(input_ids), + left_padding=left_padding, + make_cache=lambda lm, lp: _make_cache( + lm, + lp, + kv_bits=kv_bits, + kv_group_size=kv_group_size, + kv_quant_scheme=kv_quant_scheme, + ), + ) + elif ( + len(input_ids) == 1 + and right_pad_per_row is None + and kv_bits is None + and hasattr(model, "make_cache") + ): + self.prompt_cache = cache.make_prompt_cache(model) else: self.prompt_cache = _make_cache( model, @@ -1351,11 +1753,17 @@ def generate( self, sampler, stop_criteria, compute_logprobs=True, top_logprobs_k=0 ) -> GenerationBatch: """Process final tokens and transition to GenerationBatch.""" + call_kwargs = dict(self._prompt_kwargs) + if self.draft_model is not None and self.draft_kind is not None: + call_kwargs.update( + speculative_prefill_kwargs(self.draft_kind, self.draft_model) + ) + output = self.model( self._input_ids, cache=self.prompt_cache, inputs_embeds=self._inputs_embeds, - **self._prompt_kwargs, + **call_kwargs, ) logits = output.logits if hasattr(output, "logits") else output if self._right_pad_per_row is not None and any(self._right_pad_per_row): @@ -1381,7 +1789,12 @@ def generate( logits = mx.concatenate(processed_logits, axis=0) logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - first_tokens = sampler(logprobs) + first_tokens = _sample_with_positions( + sampler, + logprobs, + row_ids=[0] * len(self.uids), + positions=[0] * len(self.uids), + ) mx.async_eval(first_tokens) @@ -1411,28 +1824,51 @@ def generate( self._suffix_lens, ) - gen_batch = GenerationBatch( - model=self.model, - uids=list(self.uids), - inputs=first_tokens, - prompt_cache=self.prompt_cache, - sampler=sampler, - stop_criteria=stop_criteria, - max_tokens=list(self.max_tokens), - top_logprobs_k=top_logprobs_k, - token_context=[list(ctx) for ctx in self._token_context], - logits_processors=list(self.logits_processors), - thinking_budget_criteria=list(self.thinking_budget_criteria), - ) + if self.draft_model is not None and self.draft_kind is not None: + gen_batch = SpeculativeGenerationBatch( + model=self.model, + draft_model=self.draft_model, + draft_kind=self.draft_kind, + uids=list(self.uids), + first_tokens=first_tokens, + prompt_cache=self.prompt_cache, + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=list(self.max_tokens), + hidden=speculative_hidden_state(self.draft_kind, output), + shared_kv_states=( + output.shared_kv_states if self.draft_kind == "mtp" else None + ), + prompt_tokens=self._input_ids, + draft_block_size=self.draft_block_size, + token_dtype=self._input_ids.dtype, + greedy_sampling=self.greedy_sampling, + ) + compute_logprobs = False + else: + gen_batch = GenerationBatch( + model=self.model, + uids=list(self.uids), + inputs=first_tokens, + prompt_cache=self.prompt_cache, + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=list(self.max_tokens), + top_logprobs_k=top_logprobs_k, + greedy_sampling=self.greedy_sampling, + token_context=[list(ctx) for ctx in self._token_context], + logits_processors=list(self.logits_processors), + thinking_budget_criteria=list(self.thinking_budget_criteria), + ) gen_batch.compute_logprobs = compute_logprobs - if compute_logprobs: + if compute_logprobs and isinstance(gen_batch, GenerationBatch): gen_batch._next_lps = logprobs[ mx.arange(first_tokens.shape[0]), first_tokens ] # Prime top-K buffers so the first token can emit top_logprobs too. - if top_logprobs_k > 0: + if top_logprobs_k > 0 and isinstance(gen_batch, GenerationBatch): k = top_logprobs_k sort_idx = mx.argsort(logprobs, axis=-1) top_idx = sort_idx[..., -k:][..., ::-1].astype(mx.int32) @@ -1475,6 +1911,26 @@ def generate( rope_deltas = rope_deltas[:target_b] gen_batch._rope_deltas = rope_deltas + # Final prefill produces the first generated token and mutates the + # prompt cache. Materialize that boundary before the decode loop so + # the first decode step does not inherit the full lazy prefill graph. + cache_states = [] + for c in self.prompt_cache: + try: + cache_states.append(c.state) + except (AttributeError, TypeError): + pass + eval_targets = [first_tokens] + if cache_states: + eval_targets.append(cache_states) + if compute_logprobs and isinstance(gen_batch, GenerationBatch): + eval_targets.append(gen_batch._next_lps) + if top_logprobs_k > 0 and isinstance(gen_batch, GenerationBatch): + eval_targets.extend([gen_batch._next_top_idx, gen_batch._next_top_lp]) + if rope_deltas is not None: + eval_targets.append(rope_deltas) + mx.eval(*eval_targets) + # APC: harvest the post-prefill K/V into hashed blocks. Done after the # final prefill forward but before the cache references are released # so the block tensors snapshot the prompt prefix. @@ -1582,6 +2038,10 @@ def __init__( ] = None, stream=None, apc_manager: Optional["_apc.APCManager"] = None, + draft_model: Optional[nn.Module] = None, + draft_kind: Optional[str] = None, + draft_block_size: Optional[int] = None, + greedy_sampling: bool = False, ): self.model = model self.max_tokens = max_tokens @@ -1593,6 +2053,17 @@ def __init__( self.compute_logprobs = compute_logprobs self.top_logprobs_k = top_logprobs_k self.logits_processors = logits_processors or [] + self.draft_model = draft_model + self.draft_kind = draft_kind + self.draft_block_size = draft_block_size + self.greedy_sampling = greedy_sampling or sampler is None + if self.draft_model is not None: + prefill_step_size = None + apc_manager = None + compute_logprobs = False + top_logprobs_k = 0 + self.compute_logprobs = False + self.top_logprobs_k = 0 # APC: opt-out for KV-quantized caches. Plain KV models use block APC; # mixed/custom cache models use exact prompt-cache snapshots. self.apc_mode = None @@ -1622,6 +2093,7 @@ def __init__( self.tokenizer.stopping_criteria, compute_logprobs=self.compute_logprobs, top_logprobs_k=self.top_logprobs_k, + greedy_sampling=self.greedy_sampling, ) self._prompt_batch: Optional[PromptProcessingBatch] = None self._unprocessed_sequences = [] @@ -1630,6 +2102,7 @@ def __init__( self._prompt_time_counter = 0 self._gen_tokens_counter = 0 self._steps_counter = 0 + self._cache_eval_interval = _get_batch_cache_eval_interval() self._wire_stack = contextlib.ExitStack() self._wire_stack.enter_context(wired_limit(model, [self._stream])) @@ -1909,6 +2382,10 @@ def _build_mixed_prompt_batch( right_pad_per_row=right_pad_per_row, suffix_lens=suffix_lens, apc_mode=apc_mode, + draft_model=getattr(self, "draft_model", None), + draft_kind=getattr(self, "draft_kind", None), + draft_block_size=getattr(self, "draft_block_size", None), + greedy_sampling=getattr(self, "greedy_sampling", False), ) def _build_apc_meta_for_cold( @@ -2069,6 +2546,12 @@ def _prompt_batch_progress(prompt_batch) -> List[PromptProgress]: return progress() return [] + def _extend_generation_batch(self, gen_batch) -> None: + if len(self._generation_batch) == 0: + self._generation_batch = gen_batch + else: + self._generation_batch.extend(gen_batch) + def _next(self, **kwargs): generation_responses = [] prompt_responses = [] @@ -2078,10 +2561,23 @@ def _next(self, **kwargs): generation_responses = self._generation_batch.next() self._gen_tokens_counter += len(generation_responses) self._steps_counter += 1 - if self._steps_counter % 50 == 0: - mx.eval([c.state for c in self._generation_batch.prompt_cache]) + if ( + self._cache_eval_interval > 0 + and self._steps_counter % self._cache_eval_interval == 0 + ): + cache_states = getattr(self._generation_batch, "cache_states", None) + if callable(cache_states): + mx.eval(cache_states()) + else: + mx.eval([c.state for c in self._generation_batch.prompt_cache]) mx.clear_cache() + if ( + getattr(self._generation_batch, "is_speculative", False) + and len(self._generation_batch) > 0 + ): + return prompt_responses, generation_responses + if len(self._generation_batch) >= self.completion_batch_size: return prompt_responses, generation_responses @@ -2106,7 +2602,7 @@ def _next(self, **kwargs): self._prompt_time_counter += elapsed self._record_prompt_batch_time(self._prompt_batch, elapsed) prompt_responses = self._prompt_batch_progress(self._prompt_batch) - self._generation_batch.extend(gen_batch) + self._extend_generation_batch(gen_batch) self._prompt_batch = None mx.clear_cache() return prompt_responses, generation_responses @@ -2149,7 +2645,7 @@ def _next(self, **kwargs): self._prompt_time_counter += elapsed self._record_prompt_batch_time(self._prompt_batch, elapsed) prompt_responses = self._prompt_batch_progress(self._prompt_batch) - self._generation_batch.extend(gen_batch) + self._extend_generation_batch(gen_batch) self._prompt_batch = None mx.clear_cache() return prompt_responses, generation_responses @@ -2189,6 +2685,10 @@ def _next(self, **kwargs): apc_meta=apc_meta, apc_manager=self.apc_manager, apc_mode=self.apc_mode, + draft_model=getattr(self, "draft_model", None), + draft_kind=getattr(self, "draft_kind", None), + draft_block_size=getattr(self, "draft_block_size", None), + greedy_sampling=getattr(self, "greedy_sampling", False), ) self._prompt_tokens_counter += self._prompt_batch.total_prompt_tokens @@ -2210,7 +2710,7 @@ def _next(self, **kwargs): self._prompt_time_counter += elapsed self._record_prompt_batch_time(self._prompt_batch, elapsed) prompt_responses = self._prompt_batch_progress(self._prompt_batch) - self._generation_batch.extend(gen_batch) + self._extend_generation_batch(gen_batch) self._prompt_batch = None mx.clear_cache() diff --git a/mlx_vlm/models/qwen3_5/gated_delta.py b/mlx_vlm/models/qwen3_5/gated_delta.py index 7111730ff..e2a18cac5 100644 --- a/mlx_vlm/models/qwen3_5/gated_delta.py +++ b/mlx_vlm/models/qwen3_5/gated_delta.py @@ -1,7 +1,42 @@ +from functools import partial from typing import Optional import mlx.core as mx -from mlx_lm.models.gated_delta import compute_g, gated_delta_update # noqa: F401 +import mlx.nn as nn +from mlx_lm.models.gated_delta import gated_delta_kernel, gated_delta_ops + + +@partial(mx.compile, shapeless=True) +def compute_g(A_log, a, dt_bias): + return mx.exp(-mx.exp(A_log.astype(mx.float32)) * nn.softplus(a + dt_bias)) + + +@partial(mx.compile, shapeless=True) +def _compute_g_beta(A_log, a, b, dt_bias): + return compute_g(A_log, a, dt_bias), mx.sigmoid(b) + + +def gated_delta_update( + q: mx.array, + k: mx.array, + v: mx.array, + a: mx.array, + b: mx.array, + A_log: mx.array, + dt_bias: mx.array, + state: Optional[mx.array] = None, + mask: Optional[mx.array] = None, + use_kernel: bool = True, +): + g, beta = _compute_g_beta(A_log, a, b, dt_bias) + if state is None: + B, _, _Hk, Dk = q.shape + Hv, Dv = v.shape[-2:] + state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32) + + if not use_kernel or mx.default_device() != mx.gpu or not mx.metal.is_available(): + return gated_delta_ops(q, k, v, g, beta, state, mask) + return gated_delta_kernel(q, k, v, g, beta, state, mask) def _make_gated_delta_with_states_kernel(has_mask: bool = False): @@ -153,8 +188,7 @@ def gated_delta_update_with_states( mask: Optional[mx.array] = None, use_kernel: bool = True, ): - beta = mx.sigmoid(b) - g = compute_g(A_log, a, dt_bias) + g, beta = _compute_g_beta(A_log, a, b, dt_bias) if state is None: B, _, _Hk, Dk = q.shape Hv, Dv = v.shape[-2:] @@ -276,6 +310,69 @@ def _make_gated_delta_state_kernel(has_mask: bool = False): _gated_delta_state_kernel_masked = _make_gated_delta_state_kernel(True) +def _make_gated_delta_accept_states_kernel(): + if not mx.metal.is_available(): + return None + + return mx.fast.metal_kernel( + name="qwen3_5_gated_delta_accept_states", + input_names=[ + "intermediate_states", + "conv_input", + "live_state", + "live_conv", + "accepted", + ], + output_names=["state_out", "conv_out"], + source=r""" + uint idx = thread_position_in_grid.x; + + if (idx < StateTotal) { + uint dk = idx % Dk; + uint t0 = idx / Dk; + uint dv = t0 % Dv; + t0 /= Dv; + uint hv = t0 % Hv; + uint row = t0 / Hv; + + int step = int(accepted[row]); + bool use_intermediate = step >= 0 && step < T; + StT value; + if (use_intermediate) { + value = intermediate_states[ + ((((row * T + uint(step)) * Hv + hv) * Dv + dv) * Dk + dk) + ]; + } else { + value = live_state[((row * Hv + hv) * Dv + dv) * Dk + dk]; + } + state_out[idx] = static_cast(value); + } + + if (idx < ConvTotal) { + uint c = idx % C; + uint t0 = idx / C; + uint win = t0 % ConvW; + uint row = t0 / ConvW; + + int step = int(accepted[row]); + bool use_intermediate = step >= 0 && step < T; + ConvT value; + if (use_intermediate) { + value = conv_input[ + (row * ConvInputT + uint(step) + 1 + win) * C + c + ]; + } else { + value = live_conv[(row * ConvW + win) * C + c]; + } + conv_out[idx] = static_cast(value); + } + """, + ) + + +_gated_delta_accept_states_kernel = _make_gated_delta_accept_states_kernel() + + def _gated_delta_state_ops( k: mx.array, v: mx.array, @@ -305,6 +402,85 @@ def _gated_delta_state_ops( return state +def _gated_delta_accept_states_ops( + intermediate_states: mx.array, + conv_input: mx.array, + live_state: mx.array, + live_conv: mx.array, + accepted: mx.array, + kernel_size: int, +): + steps = [int(step) for step in accepted.tolist()] + state_rows = [] + conv_rows = [] + state_steps = intermediate_states.shape[1] + for row, step in enumerate(steps): + if 0 <= step < state_steps: + state_rows.append(intermediate_states[row, step]) + conv_rows.append(conv_input[row : row + 1, step + 1 : step + kernel_size]) + else: + state_rows.append(live_state[row]) + conv_rows.append(live_conv[row : row + 1]) + return mx.stack(state_rows, axis=0), mx.concatenate(conv_rows, axis=0) + + +def gated_delta_accept_states( + intermediate_states: mx.array, + conv_input: mx.array, + live_state: mx.array, + live_conv: mx.array, + accepted: mx.array, + kernel_size: int, + use_kernel: bool = True, +): + if accepted.dtype != mx.int32: + accepted = accepted.astype(mx.int32) + + if ( + not use_kernel + or mx.default_device() != mx.gpu + or not mx.metal.is_available() + or _gated_delta_accept_states_kernel is None + ): + return _gated_delta_accept_states_ops( + intermediate_states, + conv_input, + live_state, + live_conv, + accepted, + kernel_size, + ) + + rows, state_steps, Hv, Dv, Dk = intermediate_states.shape + conv_input_t = conv_input.shape[1] + conv_dim = conv_input.shape[-1] + conv_window = int(kernel_size) - 1 + state_total = rows * Hv * Dv * Dk + conv_total = rows * conv_window * conv_dim + total = max(state_total, conv_total) + + return _gated_delta_accept_states_kernel( + inputs=[intermediate_states, conv_input, live_state, live_conv, accepted], + template=[ + ("StT", intermediate_states.dtype), + ("ConvT", conv_input.dtype), + ("T", state_steps), + ("Hv", Hv), + ("Dv", Dv), + ("Dk", Dk), + ("C", conv_dim), + ("ConvW", conv_window), + ("ConvInputT", conv_input_t), + ("StateTotal", state_total), + ("ConvTotal", conv_total), + ], + grid=(total, 1, 1), + threadgroup=(256, 1, 1), + output_shapes=[live_state.shape, live_conv.shape], + output_dtypes=[intermediate_states.dtype, conv_input.dtype], + ) + + def gated_delta_state_update( k: mx.array, v: mx.array, @@ -317,8 +493,7 @@ def gated_delta_state_update( mask: Optional[mx.array] = None, use_kernel: bool = True, ) -> mx.array: - beta = mx.sigmoid(b) - g = compute_g(A_log, a, dt_bias) + g, beta = _compute_g_beta(A_log, a, b, dt_bias) if state is None: B, _, _Hk, Dk = k.shape Hv, Dv = v.shape[-2:] diff --git a/mlx_vlm/models/qwen3_5/language.py b/mlx_vlm/models/qwen3_5/language.py index 1aeba2db2..4b20c6b6f 100644 --- a/mlx_vlm/models/qwen3_5/language.py +++ b/mlx_vlm/models/qwen3_5/language.py @@ -1,4 +1,4 @@ -from functools import partial +from functools import lru_cache, partial from typing import Any, List, Optional import mlx.core as mx @@ -8,7 +8,6 @@ from ..base import ( LanguageModelOutput, create_attention_mask, - create_ssm_mask, scaled_dot_product_attention, ) from ..cache import ArraysCache, KVCache @@ -16,6 +15,7 @@ from ..rope_utils import apply_multimodal_rotary_pos_emb as _apply_mrope from .config import ModelConfig, TextConfig from .gated_delta import ( + gated_delta_accept_states, gated_delta_state_update, gated_delta_update, gated_delta_update_with_states, @@ -66,6 +66,12 @@ def _precise_swiglu(h, gate, x): return (gate * x).astype(h.dtype) +@partial(mx.compile, shapeless=True) +def _qwen3_5_decode_depthwise_conv(conv_input: mx.array, weight: mx.array): + out = mx.sum(conv_input.astype(mx.float32) * weight[None, :, :], axis=1) + return out.astype(conv_input.dtype)[:, None, :] + + _TARGET_VERIFY_GEMV = mx.fast.metal_kernel( name="qwen3_5_target_verify_gemv", input_names=["x", "weight"], @@ -147,7 +153,7 @@ def _use_target_verify_dense(linear, x: mx.array, target_verify: bool) -> bool: target_verify and x.ndim == 3 and x.shape[1] > 1 - and isinstance(linear, nn.Linear) + and isinstance(linear, (nn.Linear, nn.QuantizedLinear)) ) @@ -170,16 +176,463 @@ def _target_verify_weight(weight: mx.array, x: mx.array) -> Optional[mx.array]: return out.reshape(B, L, O) +def _target_verify_qlinear_header(bits: int, group_size: int) -> str: + return r""" + using namespace metal; + + constant constexpr int SIMD_SIZE = 32; + constant constexpr int BITS = __BITS__; + constant constexpr int GS = __GS__; + constant constexpr int PACK_FACTOR = (BITS == 5 ? 8 : 32 / BITS); + constant constexpr int BYTES_PER_PACK = (BITS == 5 ? 5 : 32 / 8); + constant constexpr int PACKS_PER_THREAD = 2; + constant constexpr int VALUES_PER_THREAD = PACK_FACTOR * PACKS_PER_THREAD; + constant constexpr int BLOCK_SIZE = VALUES_PER_THREAD * SIMD_SIZE; + constant constexpr int SCALE_STEP_PER_THREAD = GS / VALUES_PER_THREAD; + constant constexpr int RESULTS_PER_SIMDGROUP = 4; + constant constexpr int NUM_SIMDGROUPS = 2; + constant constexpr int BN = RESULTS_PER_SIMDGROUP * NUM_SIMDGROUPS; + + template + inline float load_vector_exact(const device T* x, thread float* x_thread) { + float sum = 0.0f; + if (BITS == 4) { + for (int i = 0; i < VALUES_PER_THREAD; i += 4) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 16.0f; + x_thread[i + 2] = x[i + 2] / 256.0f; + x_thread[i + 3] = x[i + 3] / 4096.0f; + } + } else if (BITS == 5) { + for (int i = 0; i < VALUES_PER_THREAD; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + return sum; + } + + inline float qdot_exact( + const device uint8_t* w, + const thread float* x_thread, + float scale, + float bias, + float sum) { + float accum = 0.0f; + if (BITS == 4) { + const device uint16_t* ws = (const device uint16_t*)w; + for (int i = 0; i < (VALUES_PER_THREAD / 4); i++) { + accum += + (x_thread[4 * i] * (ws[i] & 0x000f) + + x_thread[4 * i + 1] * (ws[i] & 0x00f0) + + x_thread[4 * i + 2] * (ws[i] & 0x0f00) + + x_thread[4 * i + 3] * (ws[i] & 0xf000)); + } + } else if (BITS == 5) { + for (int i = 0; i < (VALUES_PER_THREAD / 8); i++) { + const thread float* xt = x_thread + 8 * i; + const device uint8_t* wb = w + 5 * i; + + accum += (wb[0] & 0x1f) * xt[0]; + accum += (wb[0] & 0xe0) * xt[1]; + accum += (wb[1] & 0x3) * (xt[1] * 256.0f); + accum += (wb[1] & 0x7c) * xt[2]; + accum += (wb[1] & 0x80) * xt[3]; + accum += (wb[2] & 0xf) * (xt[3] * 256.0f); + accum += (wb[2] & 0xf0) * xt[4]; + accum += (wb[3] & 0x1) * (xt[4] * 256.0f); + accum += (wb[3] & 0x3e) * xt[5]; + accum += (wb[3] & 0xc0) * xt[6]; + accum += (wb[4] & 0x7) * (xt[6] * 256.0f); + accum += (wb[4] & 0xf8) * xt[7]; + } + } + return scale * accum + sum * bias; + } +""".replace( + "__BITS__", str(bits) + ).replace( + "__GS__", str(group_size) + ) + + +_TARGET_VERIFY_QMV_SOURCE = r""" + uint n_tile = threadgroup_position_in_grid.y; + uint b_idx = threadgroup_position_in_grid.z; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + int out_row = int(n_tile) * BN + int(simd_gid) * RESULTS_PER_SIMDGROUP; + int in_vec_size_w = K_SIZE * BYTES_PER_PACK / PACK_FACTOR; + int in_vec_size_g = K_SIZE / GS; + + const device uint8_t* ws_base = + (const device uint8_t*)w + out_row * in_vec_size_w + + int(simd_lid) * PACKS_PER_THREAD * BYTES_PER_PACK; + const device T* scales_base = + scales + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* biases_base = + biases + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* x_base = + x + int(b_idx) * VERIFY_T * K_SIZE + int(simd_lid) * VALUES_PER_THREAD; + + float result[VERIFY_T][RESULTS_PER_SIMDGROUP]; + float x_thread[VERIFY_T][VALUES_PER_THREAD]; + for (int t = 0; t < VERIFY_T; ++t) { + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + result[t][row] = 0.0f; + } + } + + const device uint8_t* ws = ws_base; + const device T* sc = scales_base; + const device T* bs = biases_base; + const device T* xk = x_base; + + for (int k = 0; k < K_SIZE; k += BLOCK_SIZE) { + float sums[VERIFY_T]; + for (int t = 0; t < VERIFY_T; ++t) { + sums[t] = load_vector_exact(xk + t * K_SIZE, x_thread[t]); + } + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + const device uint8_t* wl = ws + row * in_vec_size_w; + const device T* sl = sc + row * in_vec_size_g; + const device T* bl = bs + row * in_vec_size_g; + float s = float(sl[0]); + float b = float(bl[0]); + for (int t = 0; t < VERIFY_T; ++t) { + result[t][row] += qdot_exact(wl, x_thread[t], s, b, sums[t]); + } + } + + ws += BLOCK_SIZE * BYTES_PER_PACK / PACK_FACTOR; + sc += BLOCK_SIZE / GS; + bs += BLOCK_SIZE / GS; + xk += BLOCK_SIZE; + } + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + int n = out_row + row; + for (int t = 0; t < VERIFY_T; ++t) { + float r = simd_sum(result[t][row]); + if (simd_lid == 0) { + y[(int(b_idx) * VERIFY_T + t) * N_SIZE + n] = T(r); + } + } + } +""" + + +_TARGET_VERIFY_QARGMAX_SOURCE = r""" + uint n_tile = threadgroup_position_in_grid.y; + uint b_idx = threadgroup_position_in_grid.z; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + int out_row = int(n_tile) * BN + int(simd_gid) * RESULTS_PER_SIMDGROUP; + int in_vec_size_w = K_SIZE * BYTES_PER_PACK / PACK_FACTOR; + int in_vec_size_g = K_SIZE / GS; + + threadgroup float tile_best_values[VERIFY_T][NUM_SIMDGROUPS]; + threadgroup int tile_best_indices[VERIFY_T][NUM_SIMDGROUPS]; + + const device uint8_t* ws_base = + (const device uint8_t*)w + out_row * in_vec_size_w + + int(simd_lid) * PACKS_PER_THREAD * BYTES_PER_PACK; + const device T* scales_base = + scales + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* biases_base = + biases + out_row * in_vec_size_g + int(simd_lid) / SCALE_STEP_PER_THREAD; + const device T* x_base = + x + int(b_idx) * VERIFY_T * K_SIZE + int(simd_lid) * VALUES_PER_THREAD; + + float result[VERIFY_T][RESULTS_PER_SIMDGROUP]; + float x_thread[VERIFY_T][VALUES_PER_THREAD]; + for (int t = 0; t < VERIFY_T; ++t) { + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + result[t][row] = 0.0f; + } + } + + const device uint8_t* ws = ws_base; + const device T* sc = scales_base; + const device T* bs = biases_base; + const device T* xk = x_base; + + for (int k = 0; k < K_SIZE; k += BLOCK_SIZE) { + float sums[VERIFY_T]; + for (int t = 0; t < VERIFY_T; ++t) { + sums[t] = load_vector_exact(xk + t * K_SIZE, x_thread[t]); + } + + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + const device uint8_t* wl = ws + row * in_vec_size_w; + const device T* sl = sc + row * in_vec_size_g; + const device T* bl = bs + row * in_vec_size_g; + float s = float(sl[0]); + float b = float(bl[0]); + for (int t = 0; t < VERIFY_T; ++t) { + result[t][row] += qdot_exact(wl, x_thread[t], s, b, sums[t]); + } + } + + ws += BLOCK_SIZE * BYTES_PER_PACK / PACK_FACTOR; + sc += BLOCK_SIZE / GS; + bs += BLOCK_SIZE / GS; + xk += BLOCK_SIZE; + } + + for (int t = 0; t < VERIFY_T; ++t) { + float best_value = -3.4028234663852886e38f; + int best_index = 0; + for (int row = 0; row < RESULTS_PER_SIMDGROUP; ++row) { + int n = out_row + row; + if (n < N_SIZE) { + float rounded = float(T(simd_sum(result[t][row]))); + if (rounded > best_value) { + best_value = rounded; + best_index = n; + } + } + } + + if (simd_lid == 0) { + tile_best_values[t][simd_gid] = best_value; + tile_best_indices[t][simd_gid] = best_index; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (simd_gid == 0 && simd_lid == 0) { + for (int t = 0; t < VERIFY_T; ++t) { + float best = tile_best_values[t][0]; + int best_idx = tile_best_indices[t][0]; + for (int i = 1; i < NUM_SIMDGROUPS; ++i) { + float candidate = tile_best_values[t][i]; + int candidate_idx = tile_best_indices[t][i]; + if (candidate > best) { + best = candidate; + best_idx = candidate_idx; + } + } + int offset = (int(b_idx) * VERIFY_T + t) * NUM_TILES + int(n_tile); + tile_values[offset] = T(best); + tile_indices[offset] = best_idx; + } + } +""" + + +@lru_cache(maxsize=None) +def _target_verify_qmv_kernel(bits, group_size, dtype, verify_t, k_size, n_size): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=( + "qwen3_5_target_verify_qmv_" + f"b{bits}_gs{group_size}_t{verify_t}_k{k_size}_n{n_size}_{dtype_name}" + ), + input_names=["x", "w", "scales", "biases"], + output_names=["y"], + header=_target_verify_qlinear_header(bits, group_size), + source=_TARGET_VERIFY_QMV_SOURCE, + ) + + +@lru_cache(maxsize=None) +def _target_verify_qargmax_kernel(bits, group_size, dtype, verify_t, k_size, n_size): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=( + "qwen3_5_target_verify_qargmax_" + f"b{bits}_gs{group_size}_t{verify_t}_k{k_size}_n{n_size}_{dtype_name}" + ), + input_names=["x", "w", "scales", "biases"], + output_names=["tile_values", "tile_indices"], + header=_target_verify_qlinear_header(bits, group_size), + source=_TARGET_VERIFY_QARGMAX_SOURCE, + ) + + +def _can_target_verify_quantized(linear, x: mx.array) -> bool: + if ( + not isinstance(linear, nn.QuantizedLinear) + or x.ndim != 3 + or x.shape[1] < 1 + or linear.bits not in (4, 5) + or linear.mode != "affine" + or linear.biases is None + or x.dtype not in (mx.bfloat16, mx.float16) + or linear.scales.dtype != x.dtype + or linear.biases.dtype != x.dtype + ): + return False + + _, _, K = x.shape + N = linear.weight.shape[0] + return ( + K == linear.weight.shape[1] * 32 // linear.bits and K % 512 == 0 and N % 8 == 0 + ) + + +def _target_verify_quantized_linear(linear, x: mx.array) -> Optional[mx.array]: + if not _can_target_verify_quantized(linear, x): + return None + + B, T, K = x.shape + N = linear.weight.shape[0] + + x = mx.contiguous(x) + kernel = _target_verify_qmv_kernel(linear.bits, linear.group_size, x.dtype, T, K, N) + out = kernel( + inputs=[x, linear.weight, linear.scales, linear.biases], + template=[ + ("T", x.dtype), + ("VERIFY_T", int(T)), + ("K_SIZE", int(K)), + ("N_SIZE", int(N)), + ], + grid=(32, 2 * (N // 8), B), + threadgroup=(32, 2, 1), + output_shapes=[(B, T, N)], + output_dtypes=[x.dtype], + )[0] + if "bias" in linear: + out = out + linear["bias"] + return out + + +def _decode_quantized_linears_fused(linears, x: mx.array): + if ( + x.ndim != 3 + or x.shape[1] != 1 + or len(linears) != 4 + or not all(isinstance(linear, nn.QuantizedLinear) for linear in linears) + ): + return None + + first = linears[0] + if not all( + linear.bits == first.bits + and linear.group_size == first.group_size + and linear.mode == first.mode + and linear.biases is not None + and linear.scales.dtype == x.dtype + and linear.biases.dtype == x.dtype + and "bias" not in linear + for linear in linears + ): + return None + + cache_key = tuple( + (id(linear.weight), id(linear.scales), id(linear.biases)) for linear in linears + ) + cached = getattr(first, "_qwen3_5_fused_decode_linears", None) + if cached is None or cached[0] != cache_key: + weights = mx.concatenate([linear.weight for linear in linears], axis=0) + scales = mx.concatenate([linear.scales for linear in linears], axis=0) + biases = mx.concatenate([linear.biases for linear in linears], axis=0) + split_indices = [] + offset = 0 + for linear in linears[:-1]: + offset += linear.weight.shape[0] + split_indices.append(offset) + mx.eval(weights, scales, biases) + cached = (cache_key, weights, scales, biases, split_indices) + first._qwen3_5_fused_decode_linears = cached + + _, weights, scales, biases, split_indices = cached + output = mx.quantized_matmul( + x, + weights, + scales=scales, + biases=biases, + transpose=True, + group_size=first.group_size, + bits=first.bits, + mode=first.mode, + ) + return tuple(mx.split(output, split_indices, axis=-1)) + + +def _target_verify_quantized_argmax(linear, x: mx.array) -> Optional[mx.array]: + if not _can_target_verify_quantized(linear, x) or "bias" in linear: + return None + + B, T, K = x.shape + if T == 1 and 1 < B <= 4: + out = _target_verify_quantized_argmax(linear, x.transpose(1, 0, 2)) + if out is not None: + return out.transpose(1, 0) + + N = linear.weight.shape[0] + num_tiles = N // 8 + + x = mx.contiguous(x) + kernel = _target_verify_qargmax_kernel( + linear.bits, linear.group_size, x.dtype, T, K, N + ) + tile_values, tile_indices = kernel( + inputs=[x, linear.weight, linear.scales, linear.biases], + template=[ + ("T", x.dtype), + ("VERIFY_T", int(T)), + ("K_SIZE", int(K)), + ("N_SIZE", int(N)), + ("NUM_TILES", int(num_tiles)), + ], + grid=(32, 2 * num_tiles, B), + threadgroup=(32, 2, 1), + output_shapes=[(B, T, num_tiles), (B, T, num_tiles)], + output_dtypes=[x.dtype, mx.int32], + ) + best_tile = mx.argmax(tile_values, axis=-1) + return mx.take_along_axis(tile_indices, best_tile[..., None], axis=-1).squeeze(-1) + + +def _target_verify_timewise(fn, x: mx.array) -> mx.array: + return mx.concatenate([fn(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + + +def _target_verify_singletons(fn, x: mx.array) -> mx.array: + rows = [] + for row in range(x.shape[0]): + rows.append( + mx.concatenate( + [fn(x[row : row + 1, i : i + 1]) for i in range(x.shape[1])], + axis=1, + ) + ) + return mx.concatenate(rows, axis=0) + + def _target_verify_linear(linear, x: mx.array, target_verify: bool) -> mx.array: if not _use_target_verify_dense(linear, x, target_verify): return linear(x) - if "bias" not in linear: + if isinstance(linear, nn.QuantizedLinear): + if x.shape[0] == 1: + return linear(x) + out = _target_verify_quantized_linear(linear, x) + if out is not None: + return out + return _target_verify_timewise(linear, x) + + if isinstance(linear, nn.Linear) and "bias" not in linear: out = _target_verify_weight(linear.weight, x) if out is not None: return out - return mx.concatenate([linear(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + return _target_verify_singletons(linear, x) def _target_verify_linears(linears, x: mx.array, target_verify: bool): @@ -187,8 +640,13 @@ def _target_verify_linears(linears, x: mx.array, target_verify: bool): target_verify and x.ndim == 3 and x.shape[1] > 1 - and all(isinstance(linear, nn.Linear) for linear in linears) + and all( + isinstance(linear, (nn.Linear, nn.QuantizedLinear)) for linear in linears + ) ): + out = _decode_quantized_linears_fused(linears, x) + if out is not None: + return out return tuple(linear(x) for linear in linears) return tuple(_target_verify_linear(linear, x, target_verify) for linear in linears) @@ -202,12 +660,177 @@ def _target_verify_embedding_as_linear(embedding, x: mx.array, target_verify: bo if out is not None: return out + return _target_verify_timewise(embedding.as_linear, x) + + +def _extract_row_cache(cache_entry, row: int): + if isinstance(cache_entry, ArraysCache): + row_cache = ArraysCache(size=len(cache_entry.cache)) + row_cache.cache = [ + None if cached is None else cached[row : row + 1] + for cached in cache_entry.cache + ] + lengths = getattr(cache_entry, "lengths", None) + if lengths is not None: + row_cache.lengths = lengths[row : row + 1] + return row_cache + + if hasattr(cache_entry, "extract") and not cache_entry.empty(): + return cache_entry.extract(row) + + if hasattr(cache_entry, "left_padding"): + row_cache = KVCache() + return row_cache + + return cache_entry + + +def _is_single_row_batch_cache(cache_entry) -> bool: + left_padding = getattr(cache_entry, "left_padding", None) + return ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and left_padding.size == 1 + ) + + +def _pad_row_time(x: mx.array, pad: int, target_length: int) -> mx.array: + if pad <= 0: + return x + if x.shape[1] >= target_length: + return x return mx.concatenate( - [embedding.as_linear(x[:, i : i + 1]) for i in range(x.shape[1])], + [ + mx.zeros((x.shape[0], pad, *x.shape[2:]), dtype=x.dtype), + x, + ], axis=1, ) +def _qwen3_5_left_padding_info(cache): + left_padding = getattr(cache, "left_padding", None) + if not ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and left_padding.size > 0 + ): + return None + + cached = getattr(cache, "_qwen3_5_left_padding_info", None) + if cached is None or cached[0] is not left_padding: + pads = tuple(int(p) for p in left_padding.tolist()) + cached = (left_padding, pads, max(pads) if pads else 0) + cache._qwen3_5_left_padding_info = cached + return cached[1], cached[2] + + +def _qwen3_5_set_left_padding_info(cache, pads): + left_padding = getattr(cache, "left_padding", None) + if not isinstance(left_padding, mx.array): + return + pads = tuple(int(p) for p in pads) + cache._qwen3_5_left_padding_info = ( + left_padding, + pads, + max(pads) if pads else 0, + ) + + +def _qwen3_5_advance_left_padding_info(cache, steps: int): + cached = getattr(cache, "_qwen3_5_left_padding_info", None) + if cached is None: + return + _left_padding, pads, _max_pad = cached + _qwen3_5_set_left_padding_info(cache, (p - steps for p in pads)) + + +def _qwen3_5_lengths_info(cache): + lengths = getattr(cache, "lengths", None) + if not (isinstance(lengths, mx.array) and lengths.ndim > 0 and lengths.size > 0): + return None + cached = getattr(cache, "_qwen3_5_lengths_info", None) + if cached is None or cached[0] is not lengths: + values = tuple(int(v) for v in lengths.tolist()) + cached = (lengths, min(values) if values else 0) + cache._qwen3_5_lengths_info = cached + return cached[1] + + +def _qwen3_5_advance_lengths_info(cache, steps: int): + lengths = getattr(cache, "lengths", None) + cached = getattr(cache, "_qwen3_5_lengths_info", None) + if cached is None or not isinstance(lengths, mx.array): + return + _lengths, min_value = cached + cache._qwen3_5_lengths_info = (lengths, min_value - steps) + + +def _create_qwen3_5_ssm_mask(h: mx.array, cache): + if not (cache and hasattr(cache, "make_mask")): + return None + + lengths = getattr(cache, "lengths", None) + left_padding = getattr(cache, "left_padding", None) + if isinstance(left_padding, mx.array): + batch_size = int(left_padding.shape[0]) if left_padding.ndim > 0 else 1 + if ( + lengths is None + and getattr(cache, "_qwen3_5_ssm_no_mask_batch_size", None) == batch_size + ): + return None + left_padding_info = _qwen3_5_left_padding_info(cache) + max_left_padding = left_padding_info[1] if left_padding_info else 0 + if max_left_padding <= 0: + if lengths is None: + cache._qwen3_5_ssm_no_mask_batch_size = batch_size + return None + if hasattr(cache, "_qwen3_5_ssm_no_mask_batch_size"): + delattr(cache, "_qwen3_5_ssm_no_mask_batch_size") + + lengths_min = _qwen3_5_lengths_info(cache) + if lengths_min is not None and lengths_min >= h.shape[1]: + return None + + return cache.make_mask(h.shape[1]) + + +def _create_qwen3_5_attention_mask(h: mx.array, cache): + if cache is None: + return create_attention_mask(h, cache) + + if hasattr(cache, "_qwen3_5_decode_left_padding"): + delattr(cache, "_qwen3_5_decode_left_padding") + + left_padding = getattr(cache, "left_padding", None) + if h.shape[1] == 1 and isinstance(left_padding, mx.array) and left_padding.ndim > 0: + padding_cache = getattr(cache, "_qwen3_5_left_padding_cache", None) + if padding_cache is None or padding_cache[0] is not left_padding: + left_padding_info = _qwen3_5_left_padding_info(cache) + pads = list(left_padding_info[0]) if left_padding_info else [] + padding_cache = (left_padding, pads, max(pads) if pads else 0) + cache._qwen3_5_left_padding_cache = padding_cache + pads = padding_cache[1] + if padding_cache[2] <= 0: + return None + cache._qwen3_5_decode_left_padding = pads + return "left_padded_decode" + return create_attention_mask(h, cache) + + +def _set_qwen3_5_decode_left_padding(caches, layers, pads): + if caches is None: + return + for layer, cache_entry in zip(layers, caches): + if layer.is_linear or cache_entry is None: + continue + if pads is None: + if hasattr(cache_entry, "_qwen3_5_decode_left_padding"): + delattr(cache_entry, "_qwen3_5_decode_left_padding") + else: + cache_entry._qwen3_5_decode_left_padding = pads + + def _gated_delta_update_verify_decode( q: mx.array, k: mx.array, @@ -234,6 +857,500 @@ def _gated_delta_update_verify_decode( ) +_QWEN3_5_RAGGED_SDPA_ONE_PASS_SOURCE = r""" + uint q_batch_head_idx = threadgroup_position_in_grid.y; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int qk_per_thread = D_SIZE / BD; + constexpr int v_per_thread = V_SIZE / BD; + + typedef float U; + thread U q[qk_per_thread]; + thread U k[qk_per_thread]; + thread U o[v_per_thread]; + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + int K_SIZE = int(k_size[0]); + int batch_idx = int(q_batch_head_idx) / NUM_Q_HEADS; + int q_head_idx = int(q_batch_head_idx) - batch_idx * NUM_Q_HEADS; + int kv_head_idx = q_head_idx / GQA_FACTOR; + int pad = int(pads[batch_idx]); + int N = K_SIZE - pad; + + const device T* qptr = + queries + int(q_batch_head_idx) * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* kptr = + keys + (batch_idx * NUM_KV_HEADS + kv_head_idx) * K_SIZE * D_SIZE + + (pad + int(simd_gid)) * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* vptr = + values + (batch_idx * NUM_KV_HEADS + kv_head_idx) * K_SIZE * V_SIZE + + (pad + int(simd_gid)) * V_SIZE + int(simd_lid) * v_per_thread; + device T* optr = + out + int(q_batch_head_idx) * V_SIZE + int(simd_gid) * v_per_thread; + + U s = U(scale[0]); + for (int i = 0; i < qk_per_thread; i++) { + q[i] = s * qptr[i]; + } + for (int i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = -3.4028234663852886e38f; + U sum_exp_score = 0; + + for (int i = int(simd_gid); i < N; i += BN) { + for (int j = 0; j < qk_per_thread; j++) { + k[j] = kptr[j]; + } + + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * k[j]; + } + score = simd_sum(score); + + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * vptr[j]; + } + + kptr += BN * D_SIZE; + vptr += BN * V_SIZE; + } + + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + for (int i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (simd_lid == 0) { + for (int i = 0; i < v_per_thread; i++) { + optr[i] = static_cast(o[i]); + } + } +""" + + +_QWEN3_5_RAGGED_SDPA_TWO_PASS_1_SOURCE = r""" + uint simd_lid = thread_index_in_simdgroup; + uint kv_head_idx = threadgroup_position_in_grid.x; + uint batch_idx = threadgroup_position_in_grid.y; + uint block_idx = threadgroup_position_in_grid.z; + uint gqa_idx = thread_position_in_threadgroup.y; + + constexpr int BD = 32; + constexpr int qk_per_thread = D_SIZE / BD; + constexpr int v_per_thread = V_SIZE / BD; + + typedef float U; + thread U q[qk_per_thread]; + thread U o[v_per_thread] = {0}; + + int K_SIZE = int(k_size[0]); + int q_head_idx = int(GQA_FACTOR * kv_head_idx + gqa_idx); + int q_batch_head_idx = int(batch_idx) * NUM_Q_HEADS + q_head_idx; + int pad = int(pads[batch_idx]); + int N = K_SIZE - pad; + + const device T* qptr = + queries + q_batch_head_idx * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* kptr = + keys + (int(batch_idx) * NUM_KV_HEADS + int(kv_head_idx)) * + K_SIZE * D_SIZE + + (pad + int(block_idx)) * D_SIZE + int(simd_lid) * qk_per_thread; + const device T* vptr = + values + (int(batch_idx) * NUM_KV_HEADS + int(kv_head_idx)) * + K_SIZE * V_SIZE + + (pad + int(block_idx)) * V_SIZE + int(simd_lid) * v_per_thread; + device T* optr = + partials + q_batch_head_idx * BLOCKS * V_SIZE + + int(block_idx) * V_SIZE + int(simd_lid) * v_per_thread; + device float* sump = + sums + q_batch_head_idx * BLOCKS + int(block_idx); + device float* maxp = + maxs + q_batch_head_idx * BLOCKS + int(block_idx); + + U s = U(scale[0]); + for (int i = 0; i < qk_per_thread; i++) { + q[i] = s * qptr[i]; + } + + U max_score = -3.4028234663852886e38f; + U sum_exp_score = 0; + + for (int i = int(block_idx); i < N; i += BLOCKS) { + U score = 0; + for (int j = 0; j < qk_per_thread; j++) { + score += q[j] * kptr[j]; + } + score = simd_sum(score); + + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * vptr[j]; + } + + kptr += BLOCKS * D_SIZE; + vptr += BLOCKS * V_SIZE; + } + + if (simd_lid == 0) { + sump[0] = sum_exp_score; + maxp[0] = max_score; + } + for (int i = 0; i < v_per_thread; i++) { + optr[i] = static_cast(o[i]); + } +""" + + +_QWEN3_5_RAGGED_SDPA_TWO_PASS_2_SOURCE = r""" + uint q_batch_head_idx = threadgroup_position_in_grid.y; + uint simd_gid = simdgroup_index_in_threadgroup; + uint simd_lid = thread_index_in_simdgroup; + + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int elem_per_thread = D_SIZE / BD; + + typedef float U; + thread U o[elem_per_thread] = {0}; + threadgroup U outputs[BN * BD]; + + const device T* part = + partials + int(q_batch_head_idx) * BLOCKS * D_SIZE + + int(simd_gid) * D_SIZE + int(simd_lid) * elem_per_thread; + const device float* sump = sums + int(q_batch_head_idx) * BLOCKS; + const device float* maxp = maxs + int(q_batch_head_idx) * BLOCKS; + device T* optr = + out + int(q_batch_head_idx) * D_SIZE + int(simd_gid) * elem_per_thread; + + U sum_exp_score = 0.0; + U max_score = -3.4028234663852886e38f; + + for (int b = 0; b < BLOCKS / BN; ++b) { + max_score = max(max_score, maxp[int(simd_lid) + BN * b]); + } + max_score = simd_max(max_score); + + for (int b = 0; b < BLOCKS / BN; ++b) { + U factor = fast::exp(maxp[int(simd_lid) + BN * b] - max_score); + sum_exp_score += factor * sump[int(simd_lid) + BN * b]; + } + sum_exp_score = simd_sum(sum_exp_score); + + for (int b = 0; b < BLOCKS / BN; ++b) { + U factor = fast::exp(maxp[int(simd_gid)] - max_score); + for (int i = 0; i < elem_per_thread; i++) { + o[i] += factor * static_cast(part[i]); + } + maxp += BN; + sump += BN; + part += BN * D_SIZE; + } + + for (int i = 0; i < elem_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid]); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + if (simd_lid == 0) { + for (int i = 0; i < elem_per_thread; i++) { + optr[i] = static_cast(o[i]); + } + } +""" + + +@lru_cache(maxsize=1) +def _qwen3_5_device_arch_suffix() -> str: + info = mx.device_info() if hasattr(mx, "device_info") else mx.metal.device_info() + return str(info.get("architecture", ""))[-1:] + + +def _qwen3_5_sdpa_vector_blocks(seq_len: int, gqa_factor: int) -> int: + devc = _qwen3_5_device_arch_suffix() + n_simds = gqa_factor + if devc == "s": + blocks = 64 + if seq_len > 1024 and n_simds > 4: + if seq_len <= 8192: + blocks = 128 + elif seq_len <= 32768: + blocks = 256 + elif seq_len <= 65536: + blocks = 512 + else: + blocks = 1024 + return blocks + if devc == "d": + blocks = 128 + if n_simds <= 2 and seq_len > 8192: + blocks = 256 + elif n_simds >= 6: + if 16384 <= seq_len < 65536: + blocks = 512 + elif seq_len >= 65536: + blocks = 1024 + return blocks + if n_simds >= 4: + return 64 + return 32 + + +def _qwen3_5_sdpa_vector_plan(seq_len: int, q_heads: int, kv_heads: int): + devc = _qwen3_5_device_arch_suffix() + if (devc in {"d", "s"} and seq_len >= 1024) or ( + kv_heads < q_heads and seq_len >= 4096 + ): + return ("two_pass", _qwen3_5_sdpa_vector_blocks(seq_len, q_heads // kv_heads)) + return ("one_pass", 0) + + +@lru_cache(maxsize=None) +def _qwen3_5_ragged_sdpa_one_pass_kernel(dtype, d_size, v_size): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=f"qwen3_5_ragged_sdpa_1p_{dtype_name}_d{d_size}_v{v_size}", + input_names=["queries", "keys", "values", "pads", "scale", "k_size"], + output_names=["out"], + header="#include \nusing namespace metal;\n", + source=_QWEN3_5_RAGGED_SDPA_ONE_PASS_SOURCE, + ) + + +@lru_cache(maxsize=None) +def _qwen3_5_ragged_sdpa_two_pass_1_kernel(dtype, d_size, v_size, blocks): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=( + f"qwen3_5_ragged_sdpa_2p1_{dtype_name}_" f"d{d_size}_v{v_size}_b{blocks}" + ), + input_names=["queries", "keys", "values", "pads", "scale", "k_size"], + output_names=["partials", "sums", "maxs"], + header="#include \nusing namespace metal;\n", + source=_QWEN3_5_RAGGED_SDPA_TWO_PASS_1_SOURCE, + ) + + +@lru_cache(maxsize=None) +def _qwen3_5_ragged_sdpa_two_pass_2_kernel(dtype, v_size, blocks): + dtype_name = {mx.bfloat16: "bf16", mx.float16: "fp16"}.get(dtype, "unk") + return mx.fast.metal_kernel( + name=f"qwen3_5_ragged_sdpa_2p2_{dtype_name}_v{v_size}_b{blocks}", + input_names=["partials", "sums", "maxs"], + output_names=["out"], + header="#include \nusing namespace metal;\n", + source=_QWEN3_5_RAGGED_SDPA_TWO_PASS_2_SOURCE, + ) + + +@lru_cache(maxsize=128) +def _qwen3_5_cached_i32_array(values): + return mx.array(values, dtype=mx.int32) + + +@lru_cache(maxsize=128) +def _qwen3_5_cached_sdpa_scalars(scale: float, k_size: int): + return ( + mx.array([scale], dtype=mx.float32), + mx.array([k_size], dtype=mx.int32), + ) + + +def _qwen3_5_ragged_decode_attention( + queries: mx.array, + keys: mx.array, + values: mx.array, + pads: List[int], + scale: float, +) -> Optional[mx.array]: + if ( + queries.ndim != 4 + or keys.ndim != 4 + or values.ndim != 4 + or queries.shape[2] != 1 + or queries.dtype not in (mx.bfloat16, mx.float16) + or keys.dtype != queries.dtype + or values.dtype != queries.dtype + ): + return None + + batch, q_heads, _, d_size = queries.shape + pads = tuple(int(p) for p in pads) + if len(pads) != batch or any(p < 0 for p in pads): + return None + kv_heads = keys.shape[1] + k_size = keys.shape[2] + v_size = values.shape[-1] + if ( + q_heads % kv_heads != 0 + or d_size != v_size + or d_size not in (64, 96, 128, 256) + or any(p >= k_size for p in pads) + ): + return None + + plans = [_qwen3_5_sdpa_vector_plan(k_size - pad, q_heads, kv_heads) for pad in pads] + if len(set(plans)) != 1: + return None + mode, blocks = plans[0] + + queries = mx.contiguous(queries) + keys = mx.contiguous(keys) + values = mx.contiguous(values) + pads_array = _qwen3_5_cached_i32_array(pads) + scale_array, k_size_array = _qwen3_5_cached_sdpa_scalars(float(scale), int(k_size)) + template = [ + ("T", queries.dtype), + ("D_SIZE", int(d_size)), + ("V_SIZE", int(v_size)), + ("NUM_Q_HEADS", int(q_heads)), + ("NUM_KV_HEADS", int(kv_heads)), + ("GQA_FACTOR", int(q_heads // kv_heads)), + ] + + if mode == "one_pass": + kernel = _qwen3_5_ragged_sdpa_one_pass_kernel(queries.dtype, d_size, v_size) + return kernel( + inputs=[queries, keys, values, pads_array, scale_array, k_size_array], + template=template, + grid=(1024, batch * q_heads, 1), + threadgroup=(1024, 1, 1), + output_shapes=[(batch, q_heads, 1, v_size)], + output_dtypes=[queries.dtype], + )[0] + + kernel_1 = _qwen3_5_ragged_sdpa_two_pass_1_kernel( + queries.dtype, d_size, v_size, blocks + ) + partials, sums, maxs = kernel_1( + inputs=[queries, keys, values, pads_array, scale_array, k_size_array], + template=[*template, ("BLOCKS", int(blocks))], + grid=(32 * kv_heads, (q_heads // kv_heads) * batch, blocks), + threadgroup=(32, q_heads // kv_heads, 1), + output_shapes=[ + (batch, q_heads, 1, blocks, v_size), + (batch, q_heads, 1, blocks), + (batch, q_heads, 1, blocks), + ], + output_dtypes=[queries.dtype, mx.float32, mx.float32], + ) + kernel_2 = _qwen3_5_ragged_sdpa_two_pass_2_kernel(queries.dtype, v_size, blocks) + return kernel_2( + inputs=[partials, sums, maxs], + template=[ + ("T", queries.dtype), + ("D_SIZE", int(v_size)), + ("BLOCKS", int(blocks)), + ], + grid=(1024, batch * q_heads, 1), + threadgroup=(1024, 1, 1), + output_shapes=[(batch, q_heads, 1, v_size)], + output_dtypes=[queries.dtype], + )[0] + + +def _target_verify_left_padded_attention( + queries: mx.array, + keys: mx.array, + values: mx.array, + *, + cache, + scale: float, + mask: Optional[mx.array], +) -> Optional[mx.array]: + if hasattr(cache, "bits") or queries.ndim != 4 or keys.ndim != 4: + return None + + pads = getattr(cache, "_qwen3_5_decode_left_padding", None) + if pads is None: + left_padding_info = _qwen3_5_left_padding_info(cache) + if left_padding_info is None or left_padding_info[1] <= 0: + return None + pads = list(left_padding_info[0]) + if max(pads) <= 0: + return None + + output = _qwen3_5_ragged_decode_attention(queries, keys, values, pads, scale) + if output is not None: + return output + + row_outputs = {} + for pad in sorted(set(pads)): + rows = [i for i, row_pad in enumerate(pads) if row_pad == pad] + row_idx = mx.array(rows, dtype=mx.int32) + group_queries = mx.take(queries, row_idx, axis=0) + group_keys = mx.take(keys, row_idx, axis=0)[:, :, pad:, :] + group_values = mx.take(values, row_idx, axis=0)[:, :, pad:, :] + + if group_queries.shape[2] > 1: + prefix_len = group_keys.shape[-2] - group_queries.shape[2] + group_output = mx.concatenate( + [ + scaled_dot_product_attention( + group_queries[:, :, i : i + 1, :], + group_keys[:, :, : prefix_len + i + 1, :], + group_values[:, :, : prefix_len + i + 1, :], + cache=None, + scale=scale, + mask=None, + ) + for i in range(group_queries.shape[2]) + ], + axis=2, + ) + else: + group_output = scaled_dot_product_attention( + group_queries, + group_keys, + group_values, + cache=None, + scale=scale, + mask=None, + ) + + for j, row in enumerate(rows): + row_outputs[row] = group_output[j : j + 1] + + return mx.concatenate([row_outputs[i] for i in range(queries.shape[0])], axis=0) + + class Qwen3_5Attention(nn.Module): def __init__(self, args: TextConfig): super().__init__() @@ -283,7 +1400,6 @@ def __call__( target_verify: bool = False, ) -> mx.array: B, L, D = x.shape - q_proj_output, keys, values = _target_verify_linears( (self.q_proj, self.k_proj, self.v_proj), x, target_verify ) @@ -322,14 +1438,33 @@ def __call__( queries, keys = apply_multimodal_rotary_pos_emb(queries, keys, cos, sin) if mask is not None and isinstance(mask, mx.array): - if isinstance(kv_seq_len, mx.array): + if ( + cache is not None + and hasattr(cache, "_idx") + and hasattr(cache, "left_padding") + ): + kv_seq_len = int(cache._idx) + L + elif isinstance(kv_seq_len, mx.array): kv_seq_len = kv_seq_len.max().item() mask = mask[..., : int(kv_seq_len)] if cache is not None: keys, values = cache.update_and_fetch(keys, values) - if target_verify and L > 1: + left_padded_decode = ( + mask == "left_padded_decode" if isinstance(mask, str) else False + ) + if left_padded_decode: + mask = None + + if (target_verify and L > 1) or left_padded_decode: + output = _target_verify_left_padded_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + else: + output = None + + if output is None and target_verify and L > 1: prefix_len = keys.shape[-2] - L output = mx.concatenate( [ @@ -339,13 +1474,17 @@ def __call__( values[:, :, : prefix_len + i + 1, :], cache=cache, scale=self.scale, - mask=None, + mask=( + mask[..., i : i + 1, : prefix_len + i + 1] + if isinstance(mask, mx.array) and mask.ndim >= 4 + else None + ), ) for i in range(L) ], axis=2, ) - else: + elif output is None: output = scaled_dot_product_attention( queries, keys, values, cache=cache, scale=self.scale, mask=mask ) @@ -417,6 +1556,18 @@ def __init__(self, config: TextConfig): def _causal_conv1d_verify(self, conv_input: mx.array, steps: int) -> mx.array: return self.conv1d(conv_input) + def _causal_conv1d_decode(self, conv_input: mx.array) -> mx.array: + cached = getattr(self, "_qwen3_5_decode_conv_weight", None) + cache_key = id(self.conv1d.weight) + if cached is None or cached[0] != cache_key: + weight = self.conv1d.weight[:, :, 0].T.astype(mx.float32) + mx.eval(weight) + cached = (cache_key, weight) + self._qwen3_5_decode_conv_weight = cached + + weight = cached[1] + return _qwen3_5_decode_depthwise_conv(conv_input, weight) + def __call__( self, inputs: mx.array, @@ -465,6 +1616,12 @@ def __call__( cache[0] = mx.contiguous(conv_input[:, -n_keep:, :]) if gdn_sink is not None: conv_out = nn.silu(self._causal_conv1d_verify(conv_input, S)) + elif ( + S == 1 + and conv_input.shape[1] == self.conv_kernel_size + and self.conv1d.weight.dtype in (mx.bfloat16, mx.float16) + ): + conv_out = nn.silu(self._causal_conv1d_decode(conv_input)) else: conv_out = nn.silu(self.conv1d(conv_input)) @@ -535,6 +1692,8 @@ def __call__( cache[1] = state if hasattr(cache, "advance"): cache.advance(S) + _qwen3_5_advance_left_padding_info(cache, S) + _qwen3_5_advance_lengths_info(cache, S) out = self.norm(out, z) return _target_verify_linear( @@ -620,8 +1779,102 @@ def __call__( if cache is None: cache = [None] * len(self.layers) - fa_mask = create_attention_mask(h, cache[self.fa_idx]) - ssm_mask = create_ssm_mask(h, cache[self.ssm_idx]) + fa_cache = cache[self.fa_idx] + if ( + h.shape[0] == 1 + and hidden_sink is None + and gdn_sink is None + and fa_cache is not None + and _is_single_row_batch_cache(fa_cache) + ): + row_cache = [] + for cache_entry in cache: + if cache_entry is None: + row_cache.append(None) + elif _is_single_row_batch_cache(cache_entry): + row_cache.append(_extract_row_cache(cache_entry, 0)) + else: + row_cache.append(cache_entry) + + row_out = self( + inputs, + inputs_embeds=h, + cache=row_cache, + position_ids=position_ids, + ) + for i, cache_entry in enumerate(row_cache): + if cache[i] is None or cache_entry is None: + continue + if hasattr(cache[i].__class__, "merge"): + cache[i] = cache[i].__class__.merge([cache_entry]) + return row_out + + if ( + h.shape[0] > 1 + and h.shape[1] > 1 + and hidden_sink is None + and gdn_sink is None + and fa_cache is not None + and hasattr(fa_cache, "extract") + and hasattr(fa_cache.__class__, "merge") + and isinstance(getattr(fa_cache, "offset", None), mx.array) + and fa_cache.offset.ndim > 0 + ): + query_left_padding = mx.minimum(mx.maximum(-fa_cache.offset, 0), h.shape[1]) + cache_left_padding = getattr(fa_cache, "left_padding", None) + has_left_padding = ( + isinstance(cache_left_padding, mx.array) + and cache_left_padding.ndim > 0 + and int(cache_left_padding.max().item()) > 0 + ) + if has_left_padding or int(query_left_padding.max().item()) > 0: + row_outputs = [] + row_caches = [[] for _ in cache] + for row, pad in enumerate(query_left_padding.tolist()): + pad = min(max(int(pad), 0), h.shape[1]) + row_inputs = inputs[row : row + 1, pad:] + row_embeds = h[row : row + 1, pad:] + row_position_ids = None + if position_ids is not None: + if position_ids.ndim == 2: + row_position_ids = position_ids[row : row + 1, pad:] + else: + row_position_ids = position_ids[:, row : row + 1, pad:] + current_cache = [] + for cache_entry in cache: + if cache_entry is None: + current_cache.append(None) + else: + current_cache.append(_extract_row_cache(cache_entry, row)) + + row_out = self( + row_inputs, + inputs_embeds=row_embeds, + cache=current_cache, + position_ids=row_position_ids, + ) + if pad > 0: + row_out = _pad_row_time(row_out, pad, h.shape[1]) + row_outputs.append(row_out) + for i, cache_entry in enumerate(current_cache): + row_caches[i].append(cache_entry) + + for i, entries in enumerate(row_caches): + if cache[i] is None: + continue + if hasattr(cache[i].__class__, "merge"): + cache[i] = cache[i].__class__.merge(entries) + return mx.concatenate(row_outputs, axis=0) + + fa_mask = _create_qwen3_5_attention_mask(h, cache[self.fa_idx]) + ssm_mask = _create_qwen3_5_ssm_mask(h, cache[self.ssm_idx]) + decode_left_padding = ( + getattr(cache[self.fa_idx], "_qwen3_5_decode_left_padding", None) + if isinstance(fa_mask, str) and fa_mask == "left_padded_decode" + else None + ) + _set_qwen3_5_decode_left_padding(cache, self.layers, decode_left_padding) + position_embeddings = None if position_ids is not None: for layer in self.layers: @@ -671,13 +1924,31 @@ def rollback_speculative_cache( block_size: int, ) -> int: if isinstance(accepted, int): - accepted = mx.array([accepted]) + accepted_list = [int(accepted)] + elif isinstance(accepted, mx.array): + accepted_list = [int(x) for x in accepted.reshape(-1).tolist()] + else: + accepted_list = [int(x) for x in accepted] - max_a = int(accepted.max().item()) + max_a = max(accepted_list) n = max_a + 1 trim = block_size - n - is_batch = accepted.size > 1 - valid_ends = accepted + 1 + is_batch = len(accepted_list) > 1 + valid_ends_list = [a + 1 for a in accepted_list] + accepted_mx = None + valid_ends_mx = None + + def accepted_array(): + nonlocal accepted_mx + if accepted_mx is None: + accepted_mx = mx.array(accepted_list, dtype=mx.int32) + return accepted_mx + + def valid_ends_array(): + nonlocal valid_ends_mx + if valid_ends_mx is None: + valid_ends_mx = mx.array(valid_ends_list, dtype=mx.int32) + return valid_ends_mx # Separate trimmable (KV) caches from SSM caches. ssm_caches = [] @@ -687,12 +1958,31 @@ def rollback_speculative_cache( if c.is_trimmable(): if trim > 0: c.trim(trim) - if is_batch and hasattr(c, "_idx") and c.keys is not None and max_a > 0: + right_trimmed = False + if is_batch and max_a > 0: + extra_trim_list = [max_a - a for a in accepted_list] + if any(extra_trim_list): + prepare = getattr(c, "prepare", None) + finalize = getattr(c, "finalize", None) + if ( + c.keys is not None + and callable(prepare) + and callable(finalize) + ): + prepare(right_padding=extra_trim_list) + finalize() + right_trimmed = True + if ( + is_batch + and not right_trimmed + and hasattr(c, "_idx") + and c.keys is not None + and max_a > 0 + ): kv_len = c._idx - ve = valid_ends.tolist() verify_start = kv_len - n - for bi in range(accepted.shape[0]): - start = verify_start + int(ve[bi]) + for bi, ve in enumerate(valid_ends_list): + start = verify_start + ve if start < kv_len: c.keys[bi, :, start:kv_len, :] = 0 c.values[bi, :, start:kv_len, :] = 0 @@ -703,48 +1993,95 @@ def rollback_speculative_cache( return max_a if all(len(s) > 11 and s[11] is not None for s in gdn_states): - a0 = int(accepted[0].item()) if not is_batch else None - for j, c in enumerate(ssm_caches): - ( - _q, - _k, - _v, - _a, - _b, - _A_log, - _dt_bias, - _init_state, - _mask, - conv_input, - K, - intermediate_states, - *_, - ) = gdn_states[j] - if is_batch: - acc_list = accepted.tolist() - state_steps = intermediate_states.shape[1] - states = [ - ( - intermediate_states[bi, int(acc_list[bi])] - if int(acc_list[bi]) < state_steps - else c[1][bi] + a0 = accepted_list[0] if not is_batch else None + if is_batch: + intermediate_parts = [] + conv_input_parts = [] + live_state_parts = [] + live_conv_parts = [] + layer_batch_sizes = [] + kernel_sizes = [] + + for j, c in enumerate(ssm_caches): + ( + _q, + _k, + _v, + _a, + _b, + _A_log, + _dt_bias, + _init_state, + _mask, + conv_input, + K, + intermediate_states, + *_, + ) = gdn_states[j] + rows = intermediate_states.shape[0] + layer_batch_sizes.append(rows) + kernel_sizes.append(int(K)) + intermediate_parts.append(intermediate_states) + conv_input_parts.append(conv_input) + + live_state = c[1] + if live_state is None: + live_state = mx.zeros( + ( + rows, + intermediate_states.shape[2], + intermediate_states.shape[3], + intermediate_states.shape[4], + ), + dtype=intermediate_states.dtype, ) - for bi in range(len(acc_list)) - ] - c[1] = mx.stack(states, axis=0) - slices = [ - ( - conv_input[ - bi : bi + 1, - int(acc_list[bi]) + 1 : int(acc_list[bi]) + K, - ] - if int(acc_list[bi]) < state_steps - else c[0][bi : bi + 1] + live_state_parts.append(live_state) + + live_conv = c[0] + if live_conv is None: + live_conv = mx.zeros( + (rows, int(K) - 1, conv_input.shape[-1]), + dtype=conv_input.dtype, ) - for bi in range(len(acc_list)) - ] - c[0] = mx.concatenate(slices, axis=0) - else: + live_conv_parts.append(live_conv) + + if len(set(kernel_sizes)) != 1: + raise ValueError("Qwen GDN layers must share conv kernel size.") + + accepted_mx = accepted_array() + accepted_bat = mx.concatenate([accepted_mx for _ in ssm_caches], axis=0) + state_bat, conv_bat = gated_delta_accept_states( + mx.concatenate(intermediate_parts, axis=0), + mx.concatenate(conv_input_parts, axis=0), + mx.concatenate(live_state_parts, axis=0), + mx.concatenate(live_conv_parts, axis=0), + accepted_bat, + kernel_sizes[0], + use_kernel=True, + ) + + offset = 0 + for c, rows in zip(ssm_caches, layer_batch_sizes): + c[1] = state_bat[offset : offset + rows] + c[0] = conv_bat[offset : offset + rows] + offset += rows + else: + for j, c in enumerate(ssm_caches): + ( + _q, + _k, + _v, + _a, + _b, + _A_log, + _dt_bias, + _init_state, + _mask, + conv_input, + K, + intermediate_states, + *_, + ) = gdn_states[j] if a0 < intermediate_states.shape[1]: c[1] = intermediate_states[:, a0] c[0] = conv_input[:, a0 + 1 : a0 + K] @@ -786,7 +2123,7 @@ def rollback_speculative_cache( a_list.append(a) b_list.append(b) if is_batch: - steps_list.append(valid_ends.astype(mx.int32)) + steps_list.append(valid_ends_array()) else: steps_list.append(mx.full((batch_rows,), n, dtype=mx.int32)) A_log_list.append( @@ -841,7 +2178,7 @@ def rollback_speculative_cache( ) # Scatter results back to individual caches. - a0 = int(accepted[0].item()) if not is_batch else None + a0 = accepted_list[0] if not is_batch else None state_offset = 0 for j, c in enumerate(ssm_caches): batch_rows = layer_batch_sizes[j] @@ -849,13 +2186,12 @@ def rollback_speculative_cache( state_offset += batch_rows conv_input, K = conv_data[j] if is_batch: - acc_list = accepted.tolist() slices = [ conv_input[ bi : bi + 1, - int(acc_list[bi]) + 1 : int(acc_list[bi]) + K, + accepted_list[bi] + 1 : accepted_list[bi] + K, ] - for bi in range(accepted.shape[0]) + for bi in range(len(accepted_list)) ] c[0] = mx.concatenate(slices, axis=0) else: @@ -888,21 +2224,20 @@ def get_rope_index( ) image_index, video_index = 0, 0 for i, input_ids in enumerate(total_input_ids): - input_ids = mx.where( - attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids) - ) + row_mask = attention_mask[i].tolist() + input_tokens = [ + token + for token, keep in zip(input_ids.tolist(), row_mask) + if keep == 1 + ] image_nums, video_nums = 0, 0 - vision_start_indices = mx.sum( - mx.where( - input_ids == vision_start_token_id, - mx.arange(input_ids.shape[0]), - mx.zeros_like(input_ids), - ) - ) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum().item() - video_nums = (vision_tokens == video_token_id).sum().item() - input_tokens = input_ids.tolist() + vision_tokens = [ + input_tokens[idx + 1] + for idx, token in enumerate(input_tokens[:-1]) + if token == vision_start_token_id + ] + image_nums = sum(token == image_token_id for token in vision_tokens) + video_nums = sum(token == video_token_id for token in vision_tokens) llm_pos_ids_list: list = [] st = 0 remain_images, remain_videos = image_nums, video_nums @@ -984,7 +2319,19 @@ def get_rope_index( llm_pos_ids_list.append(t_index + st_idx) llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - mask = mx.array(attention_mask[i] == 1) + compact_max_position = llm_positions.max() + padded_positions = [[1] * total_input_ids.shape[1] for _ in range(3)] + compact_positions = llm_positions.tolist() + compact_idx = 0 + for col, keep in enumerate(row_mask): + if keep == 1: + for dim in range(3): + padded_positions[dim][col] = compact_positions[dim][ + compact_idx + ] + compact_idx += 1 + llm_positions = mx.array(padded_positions, dtype=position_ids.dtype) + mask = mx.array(row_mask, dtype=mx.bool_) expanded_mask = mx.expand_dims(mask, axis=0) expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0])) expanded_positions = mx.expand_dims(llm_positions, axis=1) @@ -1001,7 +2348,7 @@ def get_rope_index( ) position_ids = updated_position_ids mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) + compact_max_position + 1 - len(input_tokens) ) mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1) return position_ids, mrope_position_deltas @@ -1036,17 +2383,25 @@ def __call__( pixel_values = kwargs.pop("pixel_values", None) image_grid_thw = kwargs.pop("image_grid_thw", None) video_grid_thw = kwargs.pop("video_grid_thw", None) + attention_mask = kwargs.pop("attention_mask", None) capture_layer_ids = kwargs.pop("capture_layer_ids", None) return_hidden = kwargs.pop("return_hidden", False) return_shared_kv = kwargs.pop("return_shared_kv", False) skip_logits = kwargs.pop("skip_logits", False) rope_deltas_kw = kwargs.pop("rope_deltas", None) + if ( + mask is None + and attention_mask is not None + and attention_mask.shape[-1] == inputs.shape[-1] + ): + mask = attention_mask if pixel_values is not None: self._rope_deltas = None self._position_ids = None cache_offset = 0 cache_offsets = None # per-element offsets for batched caches + c0 = None if cache and cache[self.model.fa_idx] is not None: c0 = cache[self.model.fa_idx] cache_offset = c0._idx if hasattr(c0, "_idx") else c0.offset @@ -1057,6 +2412,16 @@ def __call__( ): cache_offsets = mx.maximum(c0.offset, 0) + if mask is None and c0 is not None and cache_offset == 0: + left_padding = getattr(c0, "left_padding", None) + if ( + isinstance(left_padding, mx.array) + and left_padding.ndim > 0 + and left_padding.size >= inputs.shape[0] + ): + positions = mx.arange(inputs.shape[-1])[None, :] + mask = positions >= left_padding[: inputs.shape[0], None] + # Check if mask shape matches input shape (for chunked prefill compatibility) rope_mask = mask if mask is not None and mask.shape[-1] != inputs.shape[-1]: @@ -1101,6 +2466,8 @@ def __call__( position_ids, rope_deltas = self.get_rope_index( inputs, image_grid_thw, video_grid_thw, rope_mask ) + if image_grid_thw is None and video_grid_thw is None: + rope_deltas = mx.zeros((batch_size, 1), dtype=rope_deltas.dtype) self._rope_deltas = rope_deltas self._position_ids = position_ids else: @@ -1166,8 +2533,19 @@ def __call__( def speculative_logits_from_hidden(self, hidden: mx.array) -> mx.array: if self.args.tie_word_embeddings: return self.model.embed_tokens.as_linear(hidden) + out = _target_verify_quantized_linear(self.lm_head, hidden) + if out is not None: + return out return self.lm_head(hidden) + def speculative_argmax_from_hidden(self, hidden: mx.array) -> Optional[mx.array]: + if not self.args.tie_word_embeddings: + out = _target_verify_quantized_argmax(self.lm_head, hidden) + if out is not None: + return out + logits = self.speculative_logits_from_hidden(hidden) + return mx.argmax(logits, axis=-1) + def speculative_verify_logits(self, inputs: mx.array, cache, sampler): out = self( inputs, diff --git a/mlx_vlm/models/qwen3_vl/language.py b/mlx_vlm/models/qwen3_vl/language.py index 7d8e16622..e0f513653 100644 --- a/mlx_vlm/models/qwen3_vl/language.py +++ b/mlx_vlm/models/qwen3_vl/language.py @@ -316,21 +316,20 @@ def get_rope_index( ) image_index, video_index = 0, 0 for i, input_ids in enumerate(total_input_ids): - input_ids = mx.where( - attention_mask[i] == 1, input_ids, mx.zeros_like(input_ids) - ) + row_mask = attention_mask[i].tolist() + input_tokens = [ + token + for token, keep in zip(input_ids.tolist(), row_mask) + if keep == 1 + ] image_nums, video_nums = 0, 0 - vision_start_indices = mx.sum( - mx.where( - input_ids == vision_start_token_id, - mx.arange(input_ids.shape[0]), - mx.zeros_like(input_ids), - ) - ) - vision_tokens = input_ids[vision_start_indices + 1] - image_nums = (vision_tokens == image_token_id).sum().item() - video_nums = (vision_tokens == video_token_id).sum().item() - input_tokens = input_ids.tolist() + vision_tokens = [ + input_tokens[idx + 1] + for idx, token in enumerate(input_tokens[:-1]) + if token == vision_start_token_id + ] + image_nums = sum(token == image_token_id for token in vision_tokens) + video_nums = sum(token == video_token_id for token in vision_tokens) llm_pos_ids_list: list = [] st = 0 remain_images, remain_videos = image_nums, video_nums @@ -422,7 +421,19 @@ def get_rope_index( llm_pos_ids_list.append(t_index + st_idx) llm_positions = mx.concatenate(llm_pos_ids_list, axis=1).reshape(3, -1) - mask = mx.array(attention_mask[i] == 1) + compact_max_position = llm_positions.max() + padded_positions = [[1] * total_input_ids.shape[1] for _ in range(3)] + compact_positions = llm_positions.tolist() + compact_idx = 0 + for col, keep in enumerate(row_mask): + if keep == 1: + for dim in range(3): + padded_positions[dim][col] = compact_positions[dim][ + compact_idx + ] + compact_idx += 1 + llm_positions = mx.array(padded_positions, dtype=position_ids.dtype) + mask = mx.array(row_mask, dtype=mx.bool_) expanded_mask = mx.expand_dims(mask, axis=0) expanded_mask = mx.broadcast_to(expanded_mask, (3, 1, mask.shape[0])) expanded_positions = mx.expand_dims(llm_positions, axis=1) @@ -439,7 +450,7 @@ def get_rope_index( ) position_ids = updated_position_ids mrope_position_deltas.append( - llm_positions.max() + 1 - len(total_input_ids[i]) + compact_max_position + 1 - len(input_tokens) ) mrope_position_deltas = mx.array(mrope_position_deltas).reshape(-1, 1) return position_ids, mrope_position_deltas diff --git a/mlx_vlm/server/app.py b/mlx_vlm/server/app.py index a7bd338a5..c71265cbf 100644 --- a/mlx_vlm/server/app.py +++ b/mlx_vlm/server/app.py @@ -115,6 +115,8 @@ def _build_gen_args( top_p=getattr(request, "top_p", DEFAULT_TOP_P), top_k=getattr(request, "top_k", 0), min_p=getattr(request, "min_p", 0.0), + seed=getattr(request, "seed", None), + logprobs=bool(getattr(request, "logprobs", False)), repetition_penalty=getattr(request, "repetition_penalty", None), repetition_context_size=_request_field_or_default( request, diff --git a/mlx_vlm/server/generation.py b/mlx_vlm/server/generation.py index ea8376143..009334629 100644 --- a/mlx_vlm/server/generation.py +++ b/mlx_vlm/server/generation.py @@ -22,6 +22,7 @@ DEFAULT_PREFILL_STEP_SIZE, DEFAULT_QUANTIZED_KV_START, DEFAULT_REPETITION_CONTEXT_SIZE, + DEFAULT_SEED, DEFAULT_TEMPERATURE, DEFAULT_THINKING_END_TOKEN, DEFAULT_THINKING_START_TOKEN, @@ -104,6 +105,86 @@ def get_speculative_batch_coalesce_s(): return DEFAULT_SPECULATIVE_BATCH_COALESCE_MS / 1000.0 +def _position_seed(seed: int, row_id: int, position: int) -> int: + x = (int(seed) ^ 0x9E3779B9) & 0xFFFFFFFF + x = (x + (int(row_id) + 1) * 0x85EBCA6B) & 0xFFFFFFFF + x = (x ^ ((int(position) + 1) * 0xC2B2AE35)) & 0xFFFFFFFF + x ^= x >> 16 + x = (x * 0x7FEB352D) & 0xFFFFFFFF + x ^= x >> 15 + return int(x & 0xFFFFFFFF) + + +def _position_keys(seed: int, row_ids: List[int], positions: List[int]) -> mx.array: + return mx.stack( + [ + mx.random.key(_position_seed(seed, row, pos)) + for row, pos in zip(row_ids, positions) + ] + ) + + +class _PositionedTargetSampler: + """Server sampler with stateless target draws for ragged verification.""" + + def __init__(self, *, temperature: float, top_p: float, seed: Optional[int]): + self.temperature = float(temperature) + self.top_p = float(top_p) + self.seed = DEFAULT_SEED if seed is None else int(seed) + + def __call__(self, logprobs: mx.array) -> mx.array: + if self.top_p > 0 and self.top_p < 1.0: + return top_p_sampling(logprobs, self.top_p, self.temperature) + return mx.random.categorical(logprobs * (1 / self.temperature)) + + def sample_target( + self, + logprobs: mx.array, + *, + row_ids: List[int], + positions: List[int], + ) -> mx.array: + if logprobs.shape[0] != len(row_ids) or len(row_ids) != len(positions): + raise ValueError("row_ids and positions must match logprobs batch size.") + keys = _position_keys(self.seed, row_ids, positions) + if self.top_p > 0 and self.top_p < 1.0: + return mx.vmap(self._sample_top_p_one, in_axes=(0, 0))(logprobs, keys) + return mx.vmap(self._sample_one, in_axes=(0, 0))(logprobs, keys) + + def _sample_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + return mx.random.categorical(logprobs * (1 / self.temperature), key=key) + + def _sample_top_p_one(self, logprobs: mx.array, key: mx.array) -> mx.array: + if logprobs.dtype == mx.bfloat16: + logprobs = logprobs.astype(mx.float32) + probs = mx.softmax(logprobs / self.temperature, axis=-1) + sorted_indices = mx.argsort(probs, axis=-1) + sorted_probs = mx.take_along_axis(probs, sorted_indices, axis=-1) + cumulative_probs = mx.cumsum(sorted_probs, axis=-1) + top_probs = mx.where( + cumulative_probs > 1 - self.top_p, + sorted_probs, + mx.zeros_like(sorted_probs), + ) + sampled_pos = mx.random.categorical(mx.log(top_probs), key=key) + return mx.take_along_axis(sorted_indices, sampled_pos[..., None], axis=-1)[0] + + +def _sample_last_token( + logits: mx.array, + sampler: Callable[[mx.array], mx.array], + *, + row_ids: Optional[List[int]] = None, + positions: Optional[List[int]] = None, +): + logits = logits[:, -1, :] + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + sample_target = getattr(sampler, "sample_target", None) + if callable(sample_target) and row_ids is not None and positions is not None: + return sample_target(logprobs, row_ids=row_ids, positions=positions) + return sampler(logprobs) + + def get_server_enable_thinking(): raw = os.environ.get("MLX_VLM_ENABLE_THINKING") if raw is None: @@ -438,6 +519,7 @@ class GenerationArguments: top_k: int = 0 min_p: float = 0.0 seed: Optional[int] = None + logprobs: bool = False repetition_penalty: Optional[float] = None repetition_context_size: Optional[int] = DEFAULT_REPETITION_CONTEXT_SIZE presence_penalty: Optional[float] = None @@ -465,6 +547,8 @@ def to_generate_kwargs(self) -> dict: "min_p": self.min_p, "enable_thinking": self.enable_thinking, } + if self.seed is not None: + kw["seed"] = self.seed if self.repetition_penalty is not None: kw["repetition_penalty"] = self.repetition_penalty if self.repetition_context_size is not None: @@ -799,14 +883,11 @@ def _cpu_preprocess(self, prompt, images=None, audio=None) -> dict: def _make_sampler(self, args: GenerationArguments) -> Optional[Callable]: if args.temperature == 0: return None - - def sampler(logprobs: mx.array) -> mx.array: - if args.top_p > 0 and args.top_p < 1.0: - return top_p_sampling(logprobs, args.top_p, args.temperature) - else: - return mx.random.categorical(logprobs * (1 / args.temperature)) - - return sampler + return _PositionedTargetSampler( + temperature=args.temperature, + top_p=args.top_p, + seed=args.seed, + ) def _make_logits_processors( self, args: GenerationArguments @@ -932,7 +1013,7 @@ def _run(self): self._ready.set() - if self.draft_model is not None: + if self.draft_model is not None and self.draft_kind != "mtp": self._run_speculative() return @@ -946,8 +1027,19 @@ def _run(self): try: # Poll the request queue — non-blocking when generating, short # blocking wait when idle so we don't spin. + active_batch = bool(active) + coalesce_s = ( + get_speculative_batch_coalesce_s() + if ( + not active_batch + and self.draft_model is not None + and self.draft_kind == "mtp" + ) + else 0.0 + ) new_items, should_stop = self._collect_pending_requests( - active=bool(active) + active=active_batch, + coalesce_s=coalesce_s, ) if should_stop: break @@ -964,6 +1056,11 @@ def _run(self): except Exception: pass + if new_items and batch_gen is not None and not active: + if not batch_gen.has_work: + batch_gen.close() + batch_gen = None + for rqueue, raw_inputs, prompt_tokens, args, images in new_items: if batch_gen is None: batch_gen = BatchGenerator( @@ -975,9 +1072,14 @@ def _run(self): kv_group_size=self.kv_group_size, kv_quant_scheme=self.kv_quant_scheme, quantized_kv_start=self.quantized_kv_start, - top_logprobs_k=self.top_logprobs_k, + compute_logprobs=bool(args.logprobs), + top_logprobs_k=self.top_logprobs_k if args.logprobs else 0, stream=generation_stream, apc_manager=self.apc_manager, + draft_model=self.draft_model, + draft_kind=self.draft_kind, + draft_block_size=_get_draft_block_size_from_env(), + greedy_sampling=args.temperature == 0, ) # Vision encoder runs on the GPU thread; text tokenization @@ -1045,6 +1147,9 @@ def _run(self): mx.clear_cache() gc.collect() + if batch_gen is not None and callable(getattr(batch_gen, "close", None)): + batch_gen.close() + def _run_speculative(self): """GPU thread loop with DFlash, EAGLE-3, or MTP speculative decoding. @@ -1143,7 +1248,13 @@ def _run_speculative(self): out = lm(input_mx, cache=prompt_cache, **lm_call_kwargs) hidden = speculative_hidden_state(draft_kind, out) shared_kv_states = out.shared_kv_states if is_mtp else None - first_bonus = sampler(out.logits[:, -1:]).squeeze(-1) + sample_row_ids = [0] * B + first_bonus = _sample_last_token( + out.logits, + sampler, + row_ids=sample_row_ids, + positions=[0] * B, + ) mx.eval(first_bonus, hidden, out.logits) prompt_elapsed = time.perf_counter() - prompt_started for uid in uids: @@ -1171,7 +1282,7 @@ def _run_speculative(self): token=tok, logprobs=0.0, finish_reason=finish, - peak_memory=mx.get_peak_memory() / 1e9, + peak_memory=mx.get_peak_memory() / 1e9 if finish else 0, prompt_tps=prompt_tps_map.get(uid), ) ) @@ -1214,6 +1325,7 @@ def stop_check(seq_idx, token_id): shared_kv_states=shared_kv_states, eos_token_ids=eos_set, prompt_tokens=input_mx, + row_ids=sample_row_ids, ) for tok_list, _ in rounds_iter: for j, tok in enumerate(tok_list): @@ -1237,7 +1349,7 @@ def stop_check(seq_idx, token_id): token=tok, logprobs=0.0, finish_reason=finish, - peak_memory=mx.get_peak_memory() / 1e9, + peak_memory=mx.get_peak_memory() / 1e9 if finish else 0, prompt_tps=prompt_tps_map.get(uid), ) ) @@ -1304,10 +1416,14 @@ def _step(self, batch_gen, active, gen_kwargs=None): rqueue = info["rqueue"] tok = r.token - if hasattr(tok, "item"): + if tok is None: + text = info["streamer"].finalize() + tok = 0 + elif hasattr(tok, "item"): tok = tok.item() - - text = self._stream_text(info, tok, r.finish_reason) + text = self._stream_text(info, tok, r.finish_reason) + else: + text = self._stream_text(info, tok, r.finish_reason) lp = r.token_logprob diff --git a/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py b/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py index dcf27a2ec..223a55cbe 100644 --- a/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py +++ b/mlx_vlm/speculative/drafters/qwen3_5_mtp/qwen3_5_mtp.py @@ -5,7 +5,7 @@ import mlx.nn as nn from ....models.base import create_attention_mask -from ....models.cache import KVCache +from ....models.cache import BatchKVCache, KVCache from ....models.qwen3_5.language import Qwen3_5DecoderLayer from ....models.qwen3_5_moe.language import Qwen3_5MoeDecoderLayer from .config import Qwen3_5MTPConfig @@ -14,7 +14,8 @@ class Qwen3_5MTPDraftModel(nn.Module): supports_greedy_draft_argmax = True prefer_requested_block_size = True - requires_uniform_batch_acceptance = True + requires_uniform_batch_acceptance = False + supports_ragged_batch_acceptance = True def __init__(self, config: Qwen3_5MTPConfig): super().__init__() @@ -90,15 +91,19 @@ def bind(self, target_model) -> "Qwen3_5MTPDraftModel": ) return self - def make_cache(self) -> List[KVCache]: + def make_cache(self, left_padding: Optional[List[int]] = None) -> List[KVCache]: + if left_padding is not None: + return [BatchKVCache(left_padding) for _ in self.layers] return [KVCache() for _ in self.layers] - def reset(self, target_model) -> List[KVCache]: + def reset( + self, target_model, left_padding: Optional[List[int]] = None + ) -> List[KVCache]: self.bind(target_model) self.accept_lens = [] self.draft_lens = [] self._draft_round = 0 - self._cache = self.make_cache() + self._cache = self.make_cache(left_padding) self._seed_token = None self._seed_hidden = None self._next_position = 0 @@ -126,7 +131,8 @@ def set_shared_kv( position = kv_valid_len self._kv_valid_len = kv_valid_len self._position = position - if not self._cache or self._cache[0].offset == 0: + cache_empty = not self._cache or all(cache.empty() for cache in self._cache) + if cache_empty: self._next_position = kv_valid_len def _draft_start_position(self): @@ -272,12 +278,7 @@ def accept_verified_tokens_batch( token_dtype: mx.Dtype = mx.int32, greedy: bool = False, ) -> None: - """Extend the Qwen MTP drafter cache after a batched verify step. - - Qwen's native MTP drafter uses its own recurrent cache. For now the - batched path is restricted to lockstep acceptance so this cache keeps a - single scalar offset, matching normal decode's aligned batch invariant. - """ + """Extend the Qwen MTP drafter cache after a batched verify step.""" if len(accepted) <= 1: self.accept_verified_tokens( verify_hidden, @@ -290,43 +291,113 @@ def accept_verified_tokens_batch( ) return - accepted_set = {int(a) for a in accepted} - if len(accepted_set) != 1: - raise ValueError( - "Qwen MTP batched cache update requires uniform acceptance." - ) - accepted_i = accepted_set.pop() - - keep_appended = min(accepted_i, self._round_appended) - trim = self._round_appended - keep_appended - if trim > 0: - for cache in self._cache: - cache.trim(trim) - self._next_position = ( - self._next_position - trim - if isinstance(self._next_position, int) - else self._next_position - trim + accepted = [int(a) for a in accepted] + keep_appended = [min(a, self._round_appended) for a in accepted] + trims = [self._round_appended - keep for keep in keep_appended] + if any(trims): + if len(set(trims)) == 1: + trim = trims[0] + for cache in self._cache: + cache.trim(trim) + self._next_position = ( + self._next_position - trim + if isinstance(self._next_position, int) + else self._next_position - trim + ) + elif all(hasattr(cache, "prepare") for cache in self._cache) and all( + hasattr(cache, "finalize") for cache in self._cache + ): + for cache in self._cache: + cache.prepare(right_padding=trims) + for cache in self._cache: + cache.finalize() + if isinstance(self._next_position, mx.array): + self._next_position = self._next_position - mx.array( + trims, dtype=mx.int32 + ) + else: + self._next_position = mx.full( + (len(accepted),), self._next_position, dtype=mx.int32 + ) - mx.array(trims, dtype=mx.int32) + else: + raise ValueError( + "Qwen MTP ragged batch acceptance requires a batch-aware cache." + ) + + draft_rows = draft_tokens.tolist() + row_tokens = [] + row_hiddens = [] + for row, accepted_i in enumerate(accepted): + tokens_i = [] + hiddens_i = [] + for draft_idx in range(keep_appended[row], accepted_i): + tokens_i.append(int(draft_rows[row][draft_idx])) + hiddens_i.append( + verify_hidden[row : row + 1, draft_idx : draft_idx + 1, :] + ) + if new_tokens[row]: + tokens_i.append(int(new_tokens[row][-1])) + hiddens_i.append( + verify_hidden[row : row + 1, accepted_i : accepted_i + 1, :] + ) + row_tokens.append(tokens_i) + row_hiddens.append(hiddens_i) + + lengths = [len(tokens_i) for tokens_i in row_tokens] + max_len = max(lengths) if lengths else 0 + if max_len > 0: + token_data = [] + hidden_rows = [] + for tokens_i, hiddens_i in zip(row_tokens, row_hiddens): + token_data.extend(tokens_i) + pad = max_len - len(tokens_i) + if pad: + token_data.extend([0] * pad) + if hiddens_i: + hidden_row = mx.concatenate(hiddens_i, axis=1) + else: + hidden_row = mx.zeros( + (1, 0, verify_hidden.shape[-1]), dtype=verify_hidden.dtype + ) + if pad: + hidden_row = mx.concatenate( + [ + hidden_row, + mx.zeros( + (1, pad, verify_hidden.shape[-1]), + dtype=verify_hidden.dtype, + ), + ], + axis=1, + ) + hidden_rows.append(hidden_row) + + tokens = mx.array(token_data, dtype=token_dtype).reshape( + len(row_tokens), max_len ) + hiddens = mx.concatenate(hidden_rows, axis=0) + right_padding = [max_len - length for length in lengths] + if any(right_padding): + for cache in self._cache: + prepare = getattr(cache, "prepare", None) + if callable(prepare): + prepare(right_padding=right_padding, lengths=lengths) - token_chunks = [] - hidden_chunks = [] - for draft_idx in range(keep_appended, accepted_i): - token_chunks.append(draft_tokens[:, draft_idx : draft_idx + 1]) - hidden_chunks.append(verify_hidden[:, draft_idx : draft_idx + 1, :]) - - if all(new_tokens): - bonus = mx.array( - [[int(row_tokens[-1])] for row_tokens in new_tokens], - dtype=token_dtype, - ) - token_chunks.append(bonus) - hidden_chunks.append(verify_hidden[:, accepted_i : accepted_i + 1, :]) - - if token_chunks: - tokens = mx.concatenate(token_chunks, axis=1).astype(token_dtype) - hiddens = mx.concatenate(hidden_chunks, axis=1) h = self._forward_tokens(tokens, hiddens, token_dtype) - self._set_seed_from_hidden(h[:, -1:, :], sampler, greedy) + + if any(right_padding): + for cache in self._cache: + finalize = getattr(cache, "finalize", None) + if callable(finalize): + finalize() + if isinstance(self._next_position, mx.array): + self._next_position = self._next_position - mx.array( + right_padding, dtype=mx.int32 + ) + + last_idx = mx.array([length - 1 for length in lengths], dtype=mx.int32) + last_hidden = mx.take_along_axis(h, last_idx[:, None, None], axis=1) + self._set_seed_from_hidden(last_hidden, sampler, greedy) self._round_appended = 0 def filter_batch(self, keep) -> None: @@ -335,7 +406,10 @@ def filter_batch(self, keep) -> None: keep = mx.array(keep, dtype=mx.int32) for cache in self._cache: - if cache.keys is not None: + cache_filter = getattr(cache, "filter", None) + if callable(cache_filter): + cache_filter(keep) + elif cache.keys is not None: cache.keys = cache.keys[keep] cache.values = cache.values[keep] diff --git a/mlx_vlm/speculative/mtp.py b/mlx_vlm/speculative/mtp.py index 4253ce24f..1445a8898 100644 --- a/mlx_vlm/speculative/mtp.py +++ b/mlx_vlm/speculative/mtp.py @@ -142,6 +142,15 @@ def _mtp_verify_target( sample_target_tokens: bool = True, ) -> _MTPVerifyResult: if sample_target_tokens: + argmax_from_hidden = getattr(lm, "speculative_argmax_from_hidden", None) + if callable(argmax_from_hidden): + result = _mtp_verify_without_logits(lm, verify_input, prompt_cache) + if result is not None: + target_tokens = argmax_from_hidden(result.hidden) + if target_tokens is not None: + result.target_tokens = target_tokens + return result + result = _mtp_verify_with_model_method(lm, verify_input, prompt_cache, sampler) if result is not None: return result @@ -176,6 +185,8 @@ def _speculative_walk_deferred_greedy( draft_tokens: mx.array, sampler: Callable[[mx.array], mx.array], budget: int, + row_id: int = 0, + base_position: Optional[int] = None, ) -> Tuple[int, List[int]]: """Greedy MTP walk that projects target logits only until rejection.""" n_draft = draft_tokens.shape[1] @@ -191,7 +202,15 @@ def _speculative_walk_deferred_greedy( if logits.ndim == 3 and logits.shape[1] == 1: logits = logits[:, 0, :] logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - target_token = sampler(logprobs) + sample_target = getattr(sampler, "sample_target", None) + if callable(sample_target) and base_position is not None: + target_token = sample_target( + logprobs, + row_ids=[row_id], + positions=[int(base_position) + pos], + ) + else: + target_token = sampler(logprobs) mx.eval(target_token) token = int(target_token.reshape(-1).item()) @@ -208,12 +227,42 @@ def _speculative_walk_deferred_greedy( return accepted, new_tokens +def _positioned_target_tokens( + lm: nn.Module, + target_hidden: mx.array, + sampler: Callable[[mx.array], mx.array], + *, + row_id: int, + base_position: int, +) -> Optional[mx.array]: + sample_target = getattr(sampler, "sample_target", None) + if not callable(sample_target): + return None + + with mx.stream(generation_stream): + logits = lm.speculative_logits_from_hidden(target_hidden) + if logits.ndim == 3: + if logits.shape[0] != 1: + return None + logits = logits[0] + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + positions = [int(base_position) + pos for pos in range(logprobs.shape[0])] + target_tokens = sample_target( + logprobs, + row_ids=[int(row_id)] * len(positions), + positions=positions, + ) + return target_tokens[None, :] + + def _speculative_walk_batch_deferred_greedy( lm: nn.Module, target_hidden: mx.array, draft_tokens: mx.array, sampler: Callable[[mx.array], mx.array], budgets: List[int], + row_ids: Optional[List[int]] = None, + base_positions: Optional[List[int]] = None, ) -> Tuple[List[int], List[List[int]]]: """Batched greedy walk that projects target logits only until all rows stop.""" B = draft_tokens.shape[0] @@ -234,7 +283,19 @@ def _speculative_walk_batch_deferred_greedy( if logits.ndim == 3 and logits.shape[1] == 1: logits = logits[:, 0, :] logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) - target_tokens = sampler(logprobs) + if _sampler_supports_positioned_target(sampler): + if row_ids is None or base_positions is None: + raise ValueError( + "positioned target sampling requires row_ids and " + "base_positions." + ) + target_tokens = sampler.sample_target( + logprobs, + row_ids=row_ids, + positions=[int(position) + pos for position in base_positions], + ) + else: + target_tokens = sampler(logprobs) mx.eval(target_tokens) target_list = [int(token) for token in target_tokens.reshape(-1).tolist()] @@ -254,17 +315,106 @@ def _speculative_walk_batch_deferred_greedy( return accepted, new_tokens +def _speculative_walk_batch_deferred_uniform( + lm: nn.Module, + target_hidden: mx.array, + draft_tokens: mx.array, + sampler: Callable[[mx.array], mx.array], + budgets: List[int], +) -> Tuple[List[int], List[List[int]]]: + """Deferred walk for models whose batched drafter cache needs lockstep rows. + + Stop at the first rejection in any row. This avoids consuming target sampler + RNG for verifier positions that will be thrown away by the uniform clamp. + """ + B = draft_tokens.shape[0] + n_draft = draft_tokens.shape[1] + draft_lists = [[int(token) for token in row] for row in draft_tokens.tolist()] + budgets = [int(budget) for budget in budgets] + + accepted = 0 + for pos in range(n_draft + 1): + with mx.stream(generation_stream): + logits = lm.speculative_logits_from_hidden( + target_hidden[:, pos : pos + 1, :] + ) + if logits.ndim == 3 and logits.shape[1] == 1: + logits = logits[:, 0, :] + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + target_tokens = sampler(logprobs) + mx.eval(target_tokens) + target_list = [int(token) for token in target_tokens.reshape(-1).tolist()] + + if pos < n_draft and all( + target_list[row] == draft_lists[row][pos] for row in range(B) + ): + accepted += 1 + continue + + new_tokens: List[List[int]] = [] + for row, budget in enumerate(budgets): + row_tokens = draft_lists[row][:accepted] + if len(row_tokens) < budget: + row_tokens.append(target_list[row]) + new_tokens.append(row_tokens[:budget]) + return [accepted] * B, new_tokens + + return [accepted] * B, [draft_lists[row][: budgets[row]] for row in range(B)] + + +def _sampler_supports_positioned_target( + sampler: Callable[[mx.array], mx.array] +) -> bool: + return callable(getattr(sampler, "sample_target", None)) + + +def _mtp_use_uniform_deferred_walk( + draft_model: nn.Module, + *, + n_active: int, + greedy_sampling: bool, + sampler: Callable[[mx.array], mx.array], +) -> bool: + if n_active <= 1: + return False + + if getattr(draft_model, "requires_uniform_batch_acceptance", False): + return True + + if _sampler_supports_positioned_target(sampler): + return False + + # Non-greedy sampling uses a single target RNG stream. If rows accept a + # ragged number of tokens, the verifier would draw future samples for some + # rows before no-drafter has drawn the next lockstep batch sample. + return not greedy_sampling + + def _mtp_acceptance_walk( lm: nn.Module, verify: _MTPVerifyResult, draft_tokens: mx.array, sampler: Callable[[mx.array], mx.array], budget: int, + row_id: int = 0, + base_position: Optional[int] = None, ) -> Tuple[int, List[int]]: if verify.target_tokens is not None: mx.async_eval(verify.target_tokens, verify.hidden) return _speculative_walk(draft_tokens, verify.target_tokens, budget) + if base_position is not None: + target_tokens = _positioned_target_tokens( + lm, + verify.hidden, + sampler, + row_id=row_id, + base_position=base_position, + ) + if target_tokens is not None: + mx.async_eval(target_tokens, verify.hidden) + return _speculative_walk(draft_tokens, target_tokens, budget) + mx.async_eval(verify.hidden) return _speculative_walk_deferred_greedy( lm, @@ -272,6 +422,8 @@ def _mtp_acceptance_walk( draft_tokens, sampler, budget, + row_id=row_id, + base_position=base_position, ) @@ -402,7 +554,11 @@ def _mtp_rounds( block_total = _dflash_block_total(draft_model, draft_block_size) configured_block_total = int(getattr(draft_model.config, "block_size", block_total)) draft_model.reset(model) - sampler_rng = _SpeculativeSamplerRNG(draft_model, enabled=not greedy_sampling) + sampler_rng = _SpeculativeSamplerRNG( + draft_model, + enabled=not greedy_sampling + and not _sampler_supports_positioned_target(sampler), + ) # Hidden from prefill is full prompt-length; reduce to a single slot. # The semantically-correct choice is the *last* prompt token's hidden: @@ -421,7 +577,7 @@ def _mtp_rounds( first_bonus, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) if hidden.shape[1] > 1: @@ -457,7 +613,7 @@ def _mtp_rounds( bs, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) with mx.stream(generation_stream): @@ -477,8 +633,12 @@ def _mtp_rounds( draft_tokens, sampler, max_tokens - emitted, + row_id=0, + base_position=emitted, + ) + sampler_rng.target_sampled( + sync_draft=not _sampler_supports_positioned_target(sampler) ) - sampler_rng.target_sampled(sync_draft=True) _record_speculative_round(draft_model, accepted, bs - 1) for tok in new_tokens: @@ -497,7 +657,7 @@ def _mtp_rounds( new_tokens, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) # Hidden for next round: pick the slot of the newly accepted bonus. @@ -512,7 +672,7 @@ def _mtp_rounds( next_shared_kv = _slice_shared_kv_after_reject( verify.shared_kv_states, bs - (accepted + 1) ) - kv_offset = _mtp_cache_offset_max(prompt_cache) + kv_offset += accepted + 1 draft_model.set_shared_kv( next_shared_kv, kv_offset, @@ -559,8 +719,13 @@ def _mtp_draft_position(kv_valid_len: Any) -> Any: return mx.maximum(mx.array(kv_valid_len, dtype=mx.int32) - 1, 0) -def _mtp_draft_kwargs(draft_model: nn.Module, greedy_sampling: bool) -> Dict[str, bool]: - if greedy_sampling and getattr(draft_model, "supports_greedy_draft_argmax", False): +def _mtp_draft_kwargs( + draft_model: nn.Module, + greedy_sampling: bool, + sampler: Optional[Callable[[mx.array], mx.array]] = None, +) -> Dict[str, bool]: + greedy_draft = greedy_sampling or _sampler_supports_positioned_target(sampler) + if greedy_draft and getattr(draft_model, "supports_greedy_draft_argmax", False): return {"greedy": True} return {} @@ -589,7 +754,7 @@ def _mtp_draft_block_active( block_size, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) positions_list = [int(position) for position in positions] @@ -602,7 +767,7 @@ def _mtp_draft_block_active( block_size, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) rowwise_tokens = [] @@ -629,7 +794,7 @@ def _mtp_draft_block_active( block_size, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) ) @@ -660,8 +825,9 @@ def _mtp_rounds_batch( stop_check: Optional[Callable[[int, int], bool]] = None, eos_token_ids: Optional[set] = None, greedy_sampling: bool = False, + row_ids: Optional[List[int]] = None, ) -> Generator[Tuple[List[Optional[int]], None], None, None]: - """Batched Gemma 4 MTP round loop (B > 1). + """Batched Gemma 4 MTP round loop (B >= 1). Mirrors ``_dflash_rounds_batch``: per-row state tracked by original index, continuous-batching filter on row finish. Differences vs DFlash @@ -677,10 +843,18 @@ def _mtp_rounds_batch( ) B = first_bonus.shape[0] + row_ids = list(range(B)) if row_ids is None else list(row_ids) block_total = _dflash_block_total(draft_model, draft_block_size) configured_block_total = int(getattr(draft_model.config, "block_size", block_total)) - draft_model.reset(model) - sampler_rng = _SpeculativeSamplerRNG(draft_model, enabled=not greedy_sampling) + if getattr(draft_model, "supports_ragged_batch_acceptance", False): + draft_model.reset(model, left_padding=[0] * B) + else: + draft_model.reset(model) + sampler_rng = _SpeculativeSamplerRNG( + draft_model, + enabled=not greedy_sampling + and not _sampler_supports_positioned_target(sampler), + ) # First-round hidden: prefill output may have shape [B, L, H]; reduce # to a single slot per row (last prompt token's hidden — see comment in @@ -747,6 +921,7 @@ def _mtp_rounds_batch( verify_input, prompt_cache, sampler, + sample_target_tokens=greedy_sampling, ) hidden_full = verify.hidden # [B_active, bs, H] @@ -773,14 +948,38 @@ def _mtp_rounds_batch( ) else: sampler_rng.target_eval(hidden_full) - accepted_list, new_tokens_list = _speculative_walk_batch_deferred_greedy( - lm, - hidden_full, - draft_tokens, - sampler, - budgets, + if _mtp_use_uniform_deferred_walk( + draft_model, + n_active=n_active, + greedy_sampling=greedy_sampling, + sampler=sampler, + ): + accepted_list, new_tokens_list = ( + _speculative_walk_batch_deferred_uniform( + lm, + hidden_full, + draft_tokens, + sampler, + budgets, + ) + ) + else: + accepted_list, new_tokens_list = ( + _speculative_walk_batch_deferred_greedy( + lm, + hidden_full, + draft_tokens, + sampler, + budgets, + row_ids=[row_ids[active_idx[j]] for j in range(n_active)], + base_positions=[ + emitted[active_idx[j]] for j in range(n_active) + ], + ) + ) + sampler_rng.target_sampled( + sync_draft=not _sampler_supports_positioned_target(sampler) ) - sampler_rng.target_sampled(sync_draft=True) # Keep the adaptive block-size history on a per-round basis so # batched MTP reacts like the singleton loop instead of letting # batch size change the controller signal. @@ -791,7 +990,6 @@ def _mtp_rounds_batch( ) max_a = max(accepted_list) - accepted_arr = mx.array(accepted_list) accept_verified = getattr(draft_model, "accept_verified_tokens_batch", None) if callable(accept_verified): @@ -803,7 +1001,7 @@ def _mtp_rounds_batch( new_tokens_list, sampler, token_dtype, - **_mtp_draft_kwargs(draft_model, greedy_sampling), + **_mtp_draft_kwargs(draft_model, greedy_sampling, sampler), ) # Per-row hidden: each row picks its own accepted slot from @@ -834,7 +1032,7 @@ def _mtp_rounds_batch( finished[orig] = True if stop_check is not None and stop_check(orig, tok): finished[orig] = True - yield tokens_out, None + yield tokens_out, {"round_pos": pos, "round_len": max_new} # Update bonus tokens and per-row positions for j in range(n_active): @@ -845,10 +1043,10 @@ def _mtp_rounds_batch( # Rollback target cache (uniform trim by ``bs - max_a - 1`` plus # per-row tail-zero on rows that accepted less). - if max_a < bs - 1: + if any(a < bs - 1 for a in accepted_list): with mx.stream(generation_stream): lm.rollback_speculative_cache( - prompt_cache, verify.gdn_states, accepted_arr, bs + prompt_cache, verify.gdn_states, accepted_list, bs ) # Slice + tail-zero ``verify.shared_kv_states`` to match the @@ -914,7 +1112,7 @@ def _mtp_rounds_batch( # Re-bind drafter with new shared_kv and per-row positions. positions_active = [positions[active_idx[j]] for j in range(len(active_idx))] - new_kv_offset = _mtp_cache_offset_max(prompt_cache) + new_kv_offset = max(positions_active) if positions_active else 0 draft_model.set_shared_kv( next_shared_kv, kv_offset=new_kv_offset, diff --git a/mlx_vlm/speculative/utils.py b/mlx_vlm/speculative/utils.py index 0606e8065..62648bf28 100644 --- a/mlx_vlm/speculative/utils.py +++ b/mlx_vlm/speculative/utils.py @@ -132,6 +132,7 @@ def run_speculative_server_rounds( shared_kv_states: Optional[dict] = None, eos_token_ids: Optional[set] = None, prompt_tokens: Optional[mx.array] = None, + row_ids: Optional[List[int]] = None, ) -> Generator[Tuple[List[Optional[int]], None], None, None]: batch_size = int(first_bonus.shape[0]) if first_bonus.ndim > 0 else 1 @@ -173,25 +174,6 @@ def run_speculative_server_rounds( return if draft_kind == "mtp": - if batch_size == 1: - yield from ( - ([tok], state) - for tok, state in _mtp_rounds( - model, - draft_model, - prompt_cache, - hidden, - shared_kv_states, - first_bonus=int(first_bonus.reshape(-1).item()), - max_tokens=max_tokens, - sampler=sampler, - draft_block_size=draft_block_size, - token_dtype=token_dtype, - greedy_sampling=greedy_sampling, - ) - ) - return - yield from _mtp_rounds_batch( model, draft_model, @@ -206,6 +188,7 @@ def run_speculative_server_rounds( stop_check=stop_check, eos_token_ids=eos_token_ids, greedy_sampling=greedy_sampling, + row_ids=row_ids, ) return diff --git a/mlx_vlm/tests/test_generate.py b/mlx_vlm/tests/test_generate.py index 37fda65ce..8770cabfb 100644 --- a/mlx_vlm/tests/test_generate.py +++ b/mlx_vlm/tests/test_generate.py @@ -8,6 +8,7 @@ import mlx.core as mx import pytest +from mlx_lm.models.cache import BatchKVCache, KVCache from mlx_vlm import apc as apc_module from mlx_vlm.generate import ( @@ -18,6 +19,7 @@ GenerationBatch, GenerationResult, PromptProcessingBatch, + SpeculativeGenerationBatch, _left_pad_prompts, _prime_cached_prefix_rope_state, ) @@ -691,6 +693,75 @@ def apply_forced_token(self, next_y): second = batch.next() assert [r.token for r in second] == [3] + def test_generation_batch_uses_greedy_hidden_argmax_without_logprobs(self): + class FastArgmaxModel: + def __init__(self): + self.calls = [] + + def __call__(self, input_ids, cache=None, **kwargs): + del cache + self.calls.append(kwargs) + assert kwargs["return_hidden"] is True + assert kwargs["skip_logits"] is True + hidden = mx.ones((input_ids.shape[0], input_ids.shape[1], 3)) + return SimpleNamespace(hidden_states=[hidden]) + + def speculative_argmax_from_hidden(self, hidden): + return mx.full((hidden.shape[0], hidden.shape[1]), 7, dtype=mx.int32) + + model = FastArgmaxModel() + batch = GenerationBatch( + model=model, + uids=[0, 1], + inputs=mx.array([5, 6], dtype=mx.int32), + prompt_cache=[], + sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), + stop_criteria=lambda token: False, + max_tokens=[2, 2], + greedy_sampling=True, + ) + batch.compute_logprobs = False + + first = batch.next() + assert [r.token for r in first] == [5, 6] + assert batch._next_tokens.tolist() == [7, 7] + assert model.calls == [{"return_hidden": True, "skip_logits": True}] + + def test_speculative_generation_batch_drains_full_round(self, monkeypatch): + def fake_rounds(*args, **kwargs): + del args, kwargs + yield [1, 10], {"round_pos": 0, "round_len": 2} + yield [2, 11], {"round_pos": 1, "round_len": 2} + yield [3, 12], {"round_pos": 0, "round_len": 1} + + monkeypatch.setattr(ar_module, "run_speculative_server_rounds", fake_rounds) + + batch = SpeculativeGenerationBatch( + model=SimpleNamespace(), + draft_model=SimpleNamespace(), + draft_kind="mtp", + uids=[100, 200], + first_tokens=mx.array([0, 9], dtype=mx.int32), + prompt_cache=[], + sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), + stop_criteria=lambda token: False, + max_tokens=[10, 10], + hidden=mx.zeros((2, 1, 1)), + shared_kv_states=None, + prompt_tokens=mx.array([[0], [9]], dtype=mx.int32), + ) + + first = batch.next() + assert [(r.uid, r.token) for r in first] == [(100, 0), (200, 9)] + + second = batch.next() + assert [(r.uid, r.token) for r in second] == [ + (100, 1), + (200, 10), + (100, 2), + (200, 11), + ] + def test_generation_batch_extend_keeps_processor_context_aligned(self): class FixedLogitModel: def __call__(self, input_ids, cache=None, **kwargs): @@ -738,6 +809,42 @@ def force_token_2(tokens, logits): assert [r.token for r in first] == [5, 6, 7] assert seen_contexts == [[30, 7]] + def test_generation_batch_extend_promotes_singleton_kv_cache(self): + def make_kv_cache(value): + c = KVCache() + keys = mx.full((1, 2, 3, 4), value, dtype=mx.float32) + values = mx.full((1, 2, 3, 4), value + 1, dtype=mx.float32) + c.update_and_fetch(keys, values) + return c + + sampler = lambda logprobs: mx.argmax(logprobs, axis=-1) + stop_criteria = lambda token: False + first = GenerationBatch( + model=MagicMock(), + uids=[0], + inputs=mx.array([5], dtype=mx.int32), + prompt_cache=[make_kv_cache(1.0)], + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=[2], + ) + second = GenerationBatch( + model=MagicMock(), + uids=[1], + inputs=mx.array([6], dtype=mx.int32), + prompt_cache=[make_kv_cache(3.0)], + sampler=sampler, + stop_criteria=stop_criteria, + max_tokens=[2], + ) + + first.extend(second) + + assert isinstance(first.prompt_cache[0], BatchKVCache) + assert first.prompt_cache[0].left_padding.tolist() == [0, 0] + assert first.prompt_cache[0].keys.shape[0] == 2 + assert first._next_tokens.tolist() == [5, 6] + def test_remove_from_unprocessed(self, mock_model, mock_processor): gen = BatchGenerator( model=mock_model.language_model, diff --git a/mlx_vlm/tests/test_server.py b/mlx_vlm/tests/test_server.py index fffe3db1b..a0bb0fe89 100644 --- a/mlx_vlm/tests/test_server.py +++ b/mlx_vlm/tests/test_server.py @@ -61,6 +61,101 @@ def test_speculative_server_dispatches_mtp_batch_loop(): ) +def test_speculative_server_samples_first_bonus_like_decode_step(): + seen = {} + logits = mx.array( + [ + [[1.0, 2.0, 3.0]], + [[4.0, 1.0, 0.0]], + ], + dtype=mx.float32, + ) + + def sampler(logprobs): + seen["shape"] = logprobs.shape + seen["values"] = logprobs + return mx.argmax(logprobs, axis=-1) + + tokens = server_generation._sample_last_token(logits, sampler) + expected_logprobs = logits[:, -1, :] - mx.logsumexp( + logits[:, -1, :], axis=-1, keepdims=True + ) + mx.eval(tokens, seen["values"], expected_logprobs) + + assert seen["shape"] == (2, 3) + assert tokens.tolist() == [2, 0] + assert bool(mx.allclose(seen["values"], expected_logprobs).item()) + + +def test_speculative_server_samples_first_bonus_with_positioned_sampler(): + seen = {} + logits = mx.array( + [ + [[1.0, 2.0, 3.0]], + [[4.0, 1.0, 0.0]], + ], + dtype=mx.float32, + ) + + class Sampler: + def __call__(self, logprobs): + raise AssertionError("positioned sampler was not used") + + def sample_target(self, logprobs, *, row_ids, positions): + seen["shape"] = logprobs.shape + seen["row_ids"] = list(row_ids) + seen["positions"] = list(positions) + return mx.argmax(logprobs, axis=-1) + + tokens = server_generation._sample_last_token( + logits, + Sampler(), + row_ids=[0, 0], + positions=[0, 0], + ) + mx.eval(tokens) + + assert seen == { + "shape": (2, 3), + "row_ids": [0, 0], + "positions": [0, 0], + } + assert tokens.tolist() == [2, 0] + + +def test_positioned_target_sampler_is_batch_grouping_invariant(): + sampler = server_generation._PositionedTargetSampler( + temperature=0.7, top_p=1.0, seed=42 + ) + logits = mx.array( + [ + [0.0, 1.0, 2.0, 3.0], + [3.0, 2.0, 1.0, 0.0], + ], + dtype=mx.float32, + ) + logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True) + + batched = sampler.sample_target( + logprobs, + row_ids=[0, 0], + positions=[5, 5], + ) + single_0 = sampler.sample_target( + logprobs[0:1], + row_ids=[0], + positions=[5], + ) + single_1 = sampler.sample_target( + logprobs[1:2], + row_ids=[0], + positions=[5], + ) + mx.eval(batched, single_0, single_1) + + assert batched.tolist() == [single_0.item(), single_1.item()] + + def test_speculative_server_dispatches_eagle3_batch_loop(): assert ( speculative_utils.get_speculative_rounds_batch("eagle3") @@ -3029,6 +3124,330 @@ def fake_initialize_model(): (str(uid * 10 + 1), "length"), ] + def test_run_routes_mtp_through_batch_generator(self, monkeypatch): + batch_state = {} + draft_model = object() + + class FakeDetokenizer: + def __init__(self): + self.last_segment = "" + + def reset(self): + self.last_segment = "" + + def add_token(self, token): + self.last_segment = str(token) + + def finalize(self): + pass + + class FakeBatchGenerator: + def __init__(self, *args, **kwargs): + del args + batch_state["kwargs"] = kwargs + self._next_uid = 1 + self._active = {} + self.next_active_sizes = [] + batch_state["instance"] = self + + def insert(self, *args, **kwargs): + del args, kwargs + uid = self._next_uid + self._next_uid += 1 + self._active[uid] = True + return (uid,) + + def remove(self, uid): + return self._active.pop(uid, None) is not None + + @property + def unprocessed_prompts(self): + return [] + + @property + def has_pending_prompts(self): + return False + + def next(self, **kwargs): + del kwargs + self.next_active_sizes.append(len(self._active)) + responses = [ + SimpleNamespace( + uid=uid, + token=uid + 100, + token_logprob=0.0, + finish_reason="length", + ) + for uid in sorted(self._active) + ] + self._active.clear() + return [], responses + + monkeypatch.setattr(server_generation, "BatchGenerator", FakeBatchGenerator) + monkeypatch.setattr( + server_generation, + "_get_draft_block_size_from_env", + lambda: 6, + ) + monkeypatch.setattr( + server_generation, + "make_streaming_detokenizer", + lambda _: FakeDetokenizer(), + ) + + gen = server.ResponseGenerator.__new__(server.ResponseGenerator) + gen.model_path = "demo" + gen.adapter_path = None + gen.model = None + gen.processor = None + gen.config = None + gen.stop_tokens = set() + gen.vision_cache = None + gen.draft_model = None + gen.draft_kind = None + gen.kv_bits = None + gen.kv_group_size = server.DEFAULT_KV_GROUP_SIZE + gen.kv_quant_scheme = server.DEFAULT_KV_QUANT_SCHEME + gen.quantized_kv_start = server.DEFAULT_QUANTIZED_KV_START + gen.top_logprobs_k = 0 + gen.apc_manager = None + gen.tokenizer = SimpleNamespace() + gen.requests = Queue() + gen._stop = False + gen._ready = Event() + gen._load_error = None + gen._cancelled = set() + gen._cancel_lock = Lock() + + def fake_initialize_model(): + gen.model = SimpleNamespace(language_model=object()) + gen.processor = SimpleNamespace() + gen.config = SimpleNamespace() + gen.stop_tokens = set() + gen.draft_model = draft_model + gen.draft_kind = "mtp" + gen.tokenizer = SimpleNamespace() + + gen._initialize_model = fake_initialize_model + gen._run_speculative = lambda: pytest.fail("MTP should use BatchGenerator") + gen._gpu_embed = lambda raw_inputs, images=None: ( + mx.array([[raw_inputs["request_id"]]], dtype=mx.int32), + {}, + ) + + request_queues = [] + for request_id in range(2): + rqueue = Queue() + request_queues.append(rqueue) + gen.requests.put( + ( + rqueue, + {"request_id": request_id}, + 1, + server.GenerationArguments(max_tokens=1, temperature=0), + None, + ) + ) + + worker = Thread(target=gen._run, daemon=True) + worker.start() + + try: + for rqueue in request_queues: + ctx = rqueue.get(timeout=1) + assert isinstance(ctx, server.GenerationContext) + item = rqueue.get(timeout=1) + assert item.finish_reason == "length" + assert rqueue.get(timeout=1) is None + finally: + gen._stop = True + gen.requests.put(None) + worker.join(timeout=2) + + kwargs = batch_state["kwargs"] + assert kwargs["draft_model"] is draft_model + assert kwargs["draft_kind"] == "mtp" + assert kwargs["draft_block_size"] == 6 + assert kwargs["greedy_sampling"] is True + assert kwargs["compute_logprobs"] is False + assert batch_state["instance"].next_active_sizes == [2] + + def test_run_coalesces_idle_mtp_batch_generator(self, monkeypatch): + monkeypatch.setenv("MLX_VLM_SPEC_BATCH_COALESCE_MS", "37") + calls = [] + draft_model = object() + + gen = server.ResponseGenerator.__new__(server.ResponseGenerator) + gen.draft_model = None + gen.draft_kind = None + gen._stop = False + gen._ready = Event() + gen._load_error = None + + def fake_initialize_model(): + gen.model = SimpleNamespace(language_model=object()) + gen.processor = SimpleNamespace() + gen.config = SimpleNamespace() + gen.stop_tokens = set() + gen.draft_model = draft_model + gen.draft_kind = "mtp" + gen.tokenizer = SimpleNamespace() + + def fake_collect_pending_requests(*, active, idle_timeout=0.1, coalesce_s=0.0): + del idle_timeout + calls.append((active, coalesce_s)) + return [], True + + gen._initialize_model = fake_initialize_model + gen._run_speculative = lambda: pytest.fail("MTP should use BatchGenerator") + gen._collect_pending_requests = fake_collect_pending_requests + + gen._run() + + assert calls == [(False, 0.037)] + + def test_idle_batch_generator_is_recreated_for_new_sampler(self, monkeypatch): + created = [] + next_uid = [1] + + class FakeDetokenizer: + def __init__(self): + self.last_segment = "" + + def reset(self): + self.last_segment = "" + + def add_token(self, token): + self.last_segment = str(token) + + def finalize(self): + pass + + class FakeBatchGenerator: + def __init__(self, *args, **kwargs): + del args + self.sampler = kwargs.get("sampler") + self.closed = False + self._active = {} + created.append(self) + + def insert(self, *args, **kwargs): + del args, kwargs + uid = next_uid[0] + next_uid[0] += 1 + self._active[uid] = True + return (uid,) + + @property + def has_work(self): + return bool(self._active) + + @property + def unprocessed_prompts(self): + return [] + + @property + def has_pending_prompts(self): + return False + + def next(self, **kwargs): + del kwargs + responses = [ + SimpleNamespace( + uid=uid, + token=uid, + token_logprob=0.0, + finish_reason="length", + ) + for uid in list(self._active) + ] + self._active.clear() + return [], responses + + def remove(self, uid): + return self._active.pop(uid, None) is not None + + def close(self): + self.closed = True + + monkeypatch.setattr(server_generation, "BatchGenerator", FakeBatchGenerator) + monkeypatch.setattr( + server_generation, + "make_streaming_detokenizer", + lambda _: FakeDetokenizer(), + ) + + gen = server.ResponseGenerator.__new__(server.ResponseGenerator) + gen.model_path = "demo" + gen.adapter_path = None + gen.model = None + gen.processor = None + gen.config = None + gen.stop_tokens = set() + gen.vision_cache = None + gen.draft_model = None + gen.draft_kind = None + gen.kv_bits = None + gen.kv_group_size = server.DEFAULT_KV_GROUP_SIZE + gen.kv_quant_scheme = server.DEFAULT_KV_QUANT_SCHEME + gen.quantized_kv_start = server.DEFAULT_QUANTIZED_KV_START + gen.top_logprobs_k = 0 + gen.apc_manager = None + gen.tokenizer = SimpleNamespace() + gen.requests = Queue() + gen._stop = False + gen._ready = Event() + gen._load_error = None + gen._cancelled = set() + gen._cancel_lock = Lock() + gen._make_sampler = lambda args: f"sampler-{args.temperature}" + + def fake_initialize_model(): + gen.model = SimpleNamespace(language_model=object()) + gen.processor = SimpleNamespace() + gen.config = SimpleNamespace() + gen.stop_tokens = set() + gen.draft_model = None + gen.draft_kind = None + gen.tokenizer = SimpleNamespace() + + gen._initialize_model = fake_initialize_model + gen._gpu_embed = lambda raw_inputs, images=None: ( + mx.array([[raw_inputs["request_id"]]], dtype=mx.int32), + {}, + ) + + worker = Thread(target=gen._run, daemon=True) + worker.start() + + def run_request(request_id, temperature): + rqueue = Queue() + gen.requests.put( + ( + rqueue, + {"request_id": request_id}, + 1, + server.GenerationArguments(max_tokens=1, temperature=temperature), + None, + ) + ) + ctx = rqueue.get(timeout=1) + assert isinstance(ctx, server.GenerationContext) + item = rqueue.get(timeout=1) + assert item.finish_reason == "length" + assert rqueue.get(timeout=1) is None + + try: + run_request(1, 0.0) + run_request(2, 0.6) + finally: + gen._stop = True + gen.requests.put(None) + worker.join(timeout=2) + + assert [bg.sampler for bg in created] == ["sampler-0.0", "sampler-0.6"] + assert created[0].closed is True + def test_step_attaches_prompt_metrics_from_prompt_progress(self): class SimpleTokenizer: vocab = {"hi": 0} diff --git a/mlx_vlm/tests/test_speculative.py b/mlx_vlm/tests/test_speculative.py index 6f578f6ff..f4fa7986f 100644 --- a/mlx_vlm/tests/test_speculative.py +++ b/mlx_vlm/tests/test_speculative.py @@ -21,8 +21,10 @@ import mlx_vlm.speculative.mtp as mtp_utils from mlx_vlm.models.cache import ( ArraysCache, + BatchKVCache, BufferedRotatingKVCache, CacheList, + KVCache, PoolingCache, RotatingKVCache, ) @@ -337,6 +339,39 @@ def test_qwen_rollback_speculative_cache_uses_intermediate_states(): ] +def test_qwen_gated_delta_accept_states_matches_python_gather(): + accepted = mx.array([0, 2, 1, 3], dtype=mx.int32) + intermediate_states = mx.arange(4 * 4 * 2 * 3 * 5, dtype=mx.float32).reshape( + 4, 4, 2, 3, 5 + ) + conv_input = mx.arange(4 * 7 * 6, dtype=mx.float32).reshape(4, 7, 6) + live_state = mx.full((4, 2, 3, 5), -1.0, dtype=mx.float32) + live_conv = mx.full((4, 3, 6), -2.0, dtype=mx.float32) + + ref_state, ref_conv = qwen_language.gated_delta_accept_states( + intermediate_states, + conv_input, + live_state, + live_conv, + accepted, + kernel_size=4, + use_kernel=False, + ) + out_state, out_conv = qwen_language.gated_delta_accept_states( + intermediate_states, + conv_input, + live_state, + live_conv, + accepted, + kernel_size=4, + use_kernel=True, + ) + mx.eval(ref_state, ref_conv, out_state, out_conv) + + assert bool(mx.array_equal(ref_state, out_state).item()) + assert bool(mx.array_equal(ref_conv, out_conv).item()) + + def test_qwen_rollback_speculative_cache_zero_inits_missing_state(): accepted = mx.array([1, 0], dtype=mx.int32) caches = [ArraysCache(size=2)] @@ -480,7 +515,16 @@ def test_qwen_target_verify_linear_matches_singleton_dense_gemv(): linear.bias = mx.random.normal((32,)).astype(mx.bfloat16) x = mx.random.normal((3, 4, 16)).astype(mx.bfloat16) - ref = mx.concatenate([linear(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + ref = mx.concatenate( + [ + mx.concatenate( + [linear(x[row : row + 1, i : i + 1]) for i in range(x.shape[1])], + axis=1, + ) + for row in range(x.shape[0]) + ], + axis=0, + ) out = qwen_language._target_verify_linear(linear, x, target_verify=True) mx.eval(ref, out) @@ -500,13 +544,187 @@ def test_qwen_target_verify_gemv_kernel_matches_singleton_dense_gemv(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen_target_verify_quantized_linear_matches_singleton_path(): + mx.random.seed(15) + linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((2, 3, 512)).astype(mx.bfloat16) + + ref = qwen_language._target_verify_timewise(linear, x) + out = qwen_language._target_verify_linear(linear, x, target_verify=True) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + +def test_qwen_target_verify_quantized_linear_matches_singleton_batch_path(): + mx.random.seed(17) + linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((1, 3, 512)).astype(mx.bfloat16) + + ref = linear(x) + out = qwen_language._target_verify_quantized_linear(linear, x) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + +def test_qwen3_5_decode_quantized_linears_fused_matches_separate(): + for bits in (4, 5): + mx.random.seed(170 + bits) + linears = [ + nn.QuantizedLinear(512, out_dim, bias=False, group_size=64, bits=bits) + for out_dim in (64, 64, 16, 16) + ] + for linear in linears: + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((4, 1, 512), dtype=mx.bfloat16) + + ref = tuple(linear(x) for linear in linears) + out = qwen_language._decode_quantized_linears_fused(tuple(linears), x) + mx.eval(*ref, *out) + + assert out is not None + assert all(bool(mx.array_equal(a, b).item()) for a, b in zip(ref, out)) + + +def test_qwen_target_verify_quantized_argmax_matches_singleton_path(): + mx.random.seed(16) + linear = nn.QuantizedLinear(512, 16, bias=False, group_size=32, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + + x = mx.random.normal((2, 3, 512)).astype(mx.bfloat16) + ref = mx.argmax(qwen_language._target_verify_timewise(linear, x), axis=-1) + out = qwen_language._target_verify_quantized_argmax(linear, x) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + +def test_qwen3_5_quantized_argmax_batch_as_time_matches_rowwise(): + mx.random.seed(18) + linear = nn.QuantizedLinear(512, 32, bias=False, group_size=64, bits=4) + linear.scales = linear.scales.astype(mx.bfloat16) + linear.biases = linear.biases.astype(mx.bfloat16) + x = mx.random.normal((4, 1, 512), dtype=mx.bfloat16) + + out = qwen_language._target_verify_quantized_argmax(linear, x) + ref = mx.concatenate( + [ + qwen_language._target_verify_quantized_argmax(linear, x[row : row + 1]) + for row in range(x.shape[0]) + ], + axis=0, + ) + mx.eval(out, ref) + + assert bool(mx.array_equal(out, ref).item()) + + +def _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale): + return mx.concatenate( + [ + qwen_language.scaled_dot_product_attention( + queries[row : row + 1], + keys[row : row + 1, :, pad:, :], + values[row : row + 1, :, pad:, :], + cache=None, + scale=scale, + mask=None, + ) + for row, pad in enumerate(pads) + ], + axis=0, + ) + + +def test_qwen3_5_ragged_decode_attention_matches_one_pass_singleton(): + mx.random.seed(19) + pads = [17, 0] + scale = 64**-0.5 + queries = mx.random.normal((2, 4, 1, 64), dtype=mx.bfloat16) + keys = mx.random.normal((2, 2, 256, 64), dtype=mx.bfloat16) + values = mx.random.normal((2, 2, 256, 64), dtype=mx.bfloat16) + + out = qwen_language._qwen3_5_ragged_decode_attention( + queries, keys, values, pads, scale + ) + ref = _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale) + mx.eval(out, ref) + + assert out is not None + assert bool(mx.array_equal(out, ref).item()) + + +def test_qwen3_5_ragged_decode_attention_matches_two_pass_singleton(): + mx.random.seed(20) + pads = [7, 0] + scale = 64**-0.5 + key_length = ( + 1100 if qwen_language._qwen3_5_device_arch_suffix() in {"d", "s"} else 4112 + ) + queries = mx.random.normal((2, 4, 1, 64), dtype=mx.bfloat16) + keys = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + values = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + + out = qwen_language._qwen3_5_ragged_decode_attention( + queries, keys, values, pads, scale + ) + ref = _qwen3_5_ragged_attention_reference(queries, keys, values, pads, scale) + mx.eval(out, ref) + + assert out is not None + assert bool(mx.array_equal(out, ref).item()) + + +def test_qwen3_5_ragged_decode_attention_rejects_mixed_plan(): + mx.random.seed(21) + scale = 64**-0.5 + if qwen_language._qwen3_5_device_arch_suffix() in {"d", "s"}: + key_length = 1100 + pads = [101, 0] + else: + key_length = 4112 + pads = [33, 0] + queries = mx.random.normal((2, 4, 1, 64), dtype=mx.bfloat16) + keys = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + values = mx.random.normal((2, 2, key_length, 64), dtype=mx.bfloat16) + + plans = [ + qwen_language._qwen3_5_sdpa_vector_plan( + key_length - pad, queries.shape[1], keys.shape[1] + ) + for pad in pads + ] + out = qwen_language._qwen3_5_ragged_decode_attention( + queries, keys, values, pads, scale + ) + + assert len(set(plans)) == 2 + assert out is None + + def test_qwen_target_verify_small_projection_matches_singleton_dense_gemv(): mx.random.seed(10) linear = nn.Linear(256, 8, bias=False) linear.weight = mx.random.normal((8, 256)).astype(mx.bfloat16) - x = mx.random.normal((1, 3, 256)).astype(mx.bfloat16) + x = mx.random.normal((3, 3, 256)).astype(mx.bfloat16) - ref = mx.concatenate([linear(x[:, i : i + 1]) for i in range(x.shape[1])], axis=1) + ref = mx.concatenate( + [ + mx.concatenate( + [linear(x[row : row + 1, i : i + 1]) for i in range(x.shape[1])], + axis=1, + ) + for row in range(x.shape[0]) + ], + axis=0, + ) out = qwen_language._target_verify_linear(linear, x, target_verify=True) mx.eval(ref, out) @@ -586,6 +804,93 @@ def test_qwen_gdn_verify_conv_matches_singleton_windows(): assert bool(mx.array_equal(ref, out).item()) +def test_qwen_gdn_decode_conv_matches_conv1d(): + mx.random.seed(141) + config = SimpleNamespace( + hidden_size=16, + linear_num_value_heads=2, + linear_num_key_heads=2, + linear_key_head_dim=4, + linear_value_head_dim=4, + linear_conv_kernel_dim=4, + rms_norm_eps=1e-6, + ) + layer = qwen_language.Qwen3_5GatedDeltaNet(config) + layer.conv1d.weight = layer.conv1d.weight.astype(mx.bfloat16) + conv_input = mx.random.normal( + (3, layer.conv_kernel_size, layer.conv_dim), dtype=mx.bfloat16 + ) + + ref = layer.conv1d(conv_input) + out = layer._causal_conv1d_decode(conv_input) + mx.eval(ref, out) + + assert bool(mx.array_equal(ref, out).item()) + + +def test_qwen3_5_rope_index_ignores_left_padding_for_vision_rows(): + model_config = qwen_language.ModelConfig( + text_config=_tiny_qwen3_5_text_config(), + vision_config=SimpleNamespace(spatial_merge_size=2), + model_type="qwen3_5", + image_token_id=101, + video_token_id=102, + image_token_index=101, + video_token_index=102, + vision_start_token_id=100, + vision_end_token_id=103, + vocab_size=128, + ) + lm = qwen_language.LanguageModel.__new__(qwen_language.LanguageModel) + lm.config = model_config + + singleton_ids = mx.array([[10, 100, 101, 11, 12]], dtype=mx.int32) + padded_ids = mx.array( + [[0, 10, 100, 101, 11, 12], [20, 21, 22, 23, 24, 25]], + dtype=mx.int32, + ) + attention_mask = mx.array([[0, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]], dtype=mx.int32) + image_grid_thw = mx.array([[1, 2, 2]], dtype=mx.int32) + + singleton_pos, singleton_delta = lm.get_rope_index(singleton_ids, image_grid_thw) + padded_pos, padded_delta = lm.get_rope_index( + padded_ids, image_grid_thw, attention_mask=attention_mask + ) + mx.eval(singleton_pos, padded_pos, singleton_delta, padded_delta) + + assert padded_pos[:, 0, 1:].tolist() == singleton_pos[:, 0, :].tolist() + assert padded_delta[0, 0].item() == singleton_delta[0, 0].item() + assert padded_delta[1, 0].item() == 0 + + +def test_qwen3_5_single_row_batch_cache_matches_singleton_cache(): + text_config = _tiny_qwen3_5_text_config() + text_config.num_hidden_layers = 2 + text_config.full_attention_interval = 2 + model = qwen_language.Qwen3_5Model(text_config) + + singleton_cache = [ArraysCache(size=2), KVCache()] + batch_arrays = ArraysCache(size=2) + batch_arrays.left_padding = mx.array([0], dtype=mx.int32) + batch_cache = [batch_arrays, BatchKVCache([0])] + + prompt = mx.array([[1, 2, 3]], dtype=mx.int32) + singleton_prompt = model(prompt, cache=singleton_cache) + batch_prompt = model(prompt, cache=batch_cache) + mx.eval(singleton_prompt, batch_prompt) + + assert bool(mx.array_equal(singleton_prompt, batch_prompt).item()) + assert isinstance(batch_cache[1], BatchKVCache) + + decode = mx.array([[4]], dtype=mx.int32) + singleton_decode = model(decode, cache=singleton_cache) + batch_decode = model(decode, cache=batch_cache) + mx.eval(singleton_decode, batch_decode) + + assert bool(mx.array_equal(singleton_decode, batch_decode).item()) + assert isinstance(batch_cache[1], BatchKVCache) + + def test_speculative_walk_accepts_until_first_mismatch(): accepted, new_tokens = _speculative_walk( mx.array([[11, 12, 13]], dtype=mx.int32), @@ -829,6 +1134,72 @@ def as_linear(self, hidden): assert fake_head.calls == 2 +def test_mtp_acceptance_walk_samples_positioned_block_once(): + class FakeEmbed: + def __init__(self): + self.calls = 0 + + def as_linear(self, hidden): + self.calls += 1 + return hidden + + class PositionedSampler: + def __init__(self): + self.calls = [] + + def __call__(self, logprobs): + raise AssertionError("positioned target sampler was not used") + + def sample_target(self, logprobs, *, row_ids, positions): + self.calls.append((list(row_ids), list(positions))) + return mx.argmax(logprobs, axis=-1) + + fake_head = FakeEmbed() + sampler = PositionedSampler() + lm = SimpleNamespace(speculative_logits_from_hidden=fake_head.as_linear) + target_hidden = mx.array( + [ + [ + [0, 0, 9, 0], + [0, 0, 0, 9], + [0, 9, 0, 0], + [9, 0, 0, 0], + ] + ], + dtype=mx.float32, + ) + draft_tokens = mx.array([[2, 1, 3]], dtype=mx.int32) + verify = mtp_utils._MTPVerifyResult(hidden=target_hidden, shared_kv_states={}) + + accepted, new_tokens = mtp_utils._mtp_acceptance_walk( + lm, + verify, + draft_tokens, + sampler, + budget=4, + row_id=5, + base_position=7, + ) + + assert accepted == 1 + assert new_tokens == [2, 3] + assert fake_head.calls == 1 + assert sampler.calls == [([5, 5, 5, 5], [7, 8, 9, 10])] + + +def test_mtp_draft_kwargs_uses_greedy_for_positioned_sampler(): + class PositionedSampler: + def sample_target(self, logprobs, *, row_ids, positions): + return mx.argmax(logprobs, axis=-1) + + draft_model = SimpleNamespace(supports_greedy_draft_argmax=True) + + assert mtp_utils._mtp_draft_kwargs(draft_model, False) == {} + assert mtp_utils._mtp_draft_kwargs(draft_model, False, PositionedSampler()) == { + "greedy": True + } + + def test_speculative_walk_batch_deferred_greedy_matches_batch_walk(): class FakeEmbed: def __init__(self): @@ -877,6 +1248,181 @@ def as_linear(self, hidden): assert fake_head.calls == 3 +def test_speculative_walk_batch_deferred_greedy_uses_positioned_sampler(): + class FakeEmbed: + def __init__(self): + self.calls = 0 + + def as_linear(self, hidden): + self.calls += 1 + return hidden + + class PositionedSampler: + def __init__(self): + self.calls = [] + + def __call__(self, logprobs): + raise AssertionError("positioned target sampler was not used") + + def sample_target(self, logprobs, *, row_ids, positions): + self.calls.append((list(row_ids), list(positions))) + return mx.argmax(logprobs, axis=-1) + + fake_head = FakeEmbed() + sampler = PositionedSampler() + lm = SimpleNamespace(speculative_logits_from_hidden=fake_head.as_linear) + target_hidden = mx.array( + [ + [ + [0, 0, 9, 0], + [0, 9, 0, 0], + [0, 0, 0, 9], + ], + [ + [9, 0, 0, 0], + [0, 0, 9, 0], + [0, 9, 0, 0], + ], + ], + dtype=mx.float32, + ) + draft_tokens = mx.array([[2, 3], [0, 2]], dtype=mx.int32) + + accepted, new_tokens = _speculative_walk_batch_deferred_greedy( + lm, + target_hidden, + draft_tokens, + sampler, + budgets=[3, 2], + row_ids=[10, 11], + base_positions=[7, 12], + ) + + assert accepted == [1, 2] + assert new_tokens == [[2, 1], [0, 2]] + assert fake_head.calls == 3 + assert sampler.calls == [ + ([10, 11], [7, 12]), + ([10, 11], [8, 13]), + ([10, 11], [9, 14]), + ] + + +def test_speculative_walk_batch_deferred_uniform_stops_at_batch_rejection(): + class FakeEmbed: + def __init__(self): + self.calls = 0 + + def as_linear(self, hidden): + self.calls += 1 + return hidden + + fake_head = FakeEmbed() + lm = SimpleNamespace(speculative_logits_from_hidden=fake_head.as_linear) + target_hidden = mx.array( + [ + [ + [0, 0, 9, 0], + [0, 0, 0, 9], + [0, 9, 0, 0], + ], + [ + [0, 9, 0, 0], + [9, 0, 0, 0], + [0, 0, 9, 0], + ], + ], + dtype=mx.float32, + ) + draft_tokens = mx.array([[2, 3], [0, 2]], dtype=mx.int32) + + accepted, new_tokens = mtp_utils._speculative_walk_batch_deferred_uniform( + lm, + target_hidden, + draft_tokens, + lambda logits: mx.argmax(logits, axis=-1), + budgets=[3, 3], + ) + + assert accepted == [0, 0] + assert new_tokens == [[2], [1]] + assert fake_head.calls == 1 + + +def test_mtp_server_singleton_dispatches_batch_rounds(monkeypatch): + calls = [] + + def fake_batch(*args, **kwargs): + calls.append(("batch", args, kwargs)) + yield [3], None + + def fake_single(*args, **kwargs): + raise AssertionError("server MTP singleton should use batch round path") + + monkeypatch.setattr(speculative_utils, "_mtp_rounds_batch", fake_batch) + monkeypatch.setattr(speculative_utils, "_mtp_rounds", fake_single) + + result = list( + speculative_utils.run_speculative_server_rounds( + SimpleNamespace(language_model=SimpleNamespace()), + SimpleNamespace(), + prompt_cache=[], + hidden=mx.zeros((1, 1, 1), dtype=mx.float32), + shared_kv_states={}, + draft_kind="mtp", + first_bonus=mx.array([2], dtype=mx.int32), + max_tokens=4, + sampler=lambda logprobs: mx.argmax(logprobs, axis=-1), + token_dtype=mx.int32, + greedy_sampling=False, + row_ids=[0], + ) + ) + + assert result == [([3], None)] + assert calls + assert calls[0][2]["first_bonus"].tolist() == [2] + assert calls[0][2]["row_ids"] == [0] + + +def test_mtp_uses_uniform_deferred_walk_for_batched_sampling(): + ragged_drafter = SimpleNamespace(requires_uniform_batch_acceptance=False) + uniform_drafter = SimpleNamespace(requires_uniform_batch_acceptance=True) + normal_sampler = lambda logits: mx.argmax(logits, axis=-1) + positioned_sampler = SimpleNamespace(sample_target=lambda *args, **kwargs: None) + + assert not mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, + n_active=1, + greedy_sampling=False, + sampler=normal_sampler, + ) + assert not mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, + n_active=2, + greedy_sampling=True, + sampler=normal_sampler, + ) + assert mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, + n_active=2, + greedy_sampling=False, + sampler=normal_sampler, + ) + assert not mtp_utils._mtp_use_uniform_deferred_walk( + ragged_drafter, + n_active=2, + greedy_sampling=False, + sampler=positioned_sampler, + ) + assert mtp_utils._mtp_use_uniform_deferred_walk( + uniform_drafter, + n_active=2, + greedy_sampling=True, + sampler=positioned_sampler, + ) + + def test_mtp_draft_hidden_uses_model_hook(): hidden = mx.array([[[1.0, 2.0]]], dtype=mx.float32) lm = SimpleNamespace(speculative_draft_hidden=lambda h: h * 2) @@ -922,6 +1468,39 @@ def verify_logits(inputs, cache, sampler): assert result.target_tokens is target_tokens +def test_mtp_verify_target_prefers_argmax_hidden_hook_for_greedy_tokens(): + verify_input = mx.array([[7, 8]], dtype=mx.int32) + hidden = mx.array([[[1.0, 0.0], [0.0, 1.0]]], dtype=mx.float32) + target_tokens = mx.array([[3, 4]], dtype=mx.int32) + calls = [] + + def verify_hidden(inputs, cache): + calls.append((inputs, cache)) + return hidden, {"full": ("k", "v")}, ["gdn"] + + lm = SimpleNamespace( + speculative_verify_hidden=verify_hidden, + speculative_argmax_from_hidden=lambda h: target_tokens, + speculative_verify_logits=lambda *args: (_ for _ in ()).throw( + AssertionError("full logits should not be used for greedy argmax") + ), + ) + + result = _mtp_verify_target( + lm, + verify_input, + prompt_cache=["cache"], + sampler=lambda logits: mx.argmax(logits, axis=-1), + ) + + assert calls[0][0] is verify_input + assert calls[0][1] == ["cache"] + assert result.hidden is hidden + assert result.shared_kv_states == {"full": ("k", "v")} + assert result.gdn_states == ["gdn"] + assert result.target_tokens is target_tokens + + def test_mtp_rounds_rolls_back_gemma_without_gdn_states(): class Draft: def __init__(self): @@ -2115,6 +2694,72 @@ def test_qwen3_5_mtp_batch_accept_updates_uniform_cache(): assert drafter._next_position == 5 +def test_qwen3_5_mtp_batch_accept_updates_ragged_cache(): + text_config = _tiny_qwen3_5_text_config() + text_config.mtp_num_hidden_layers = 1 + drafter = Qwen3_5MTPDraftModel( + Qwen3_5MTPConfig(text_config=text_config, block_size=3) + ) + target = SimpleNamespace( + language_model=SimpleNamespace( + model=SimpleNamespace(embed_tokens=nn.Embedding(32, 16)) + ) + ) + drafter.reset(target, left_padding=[0, 0]) + drafter.set_shared_kv( + {}, + kv_offset=4, + position=mx.array([4, 4], dtype=mx.int32), + kv_valid_len=mx.array([4, 4], dtype=mx.int32), + ) + hidden = mx.zeros((2, 1, 16), dtype=mx.float32) + draft_tokens = drafter.draft_block( + mx.array([7, 8], dtype=mx.int32), + hidden, + None, + 3, + lambda logits: mx.argmax(logits, axis=-1), + mx.int32, + greedy=True, + ) + verify_hidden = mx.zeros((2, 3, 16), dtype=mx.float32) + drafter.accept_verified_tokens_batch( + verify_hidden, + draft_tokens, + accepted=[1, 0], + new_tokens=[[int(draft_tokens[0, 0].item()), 5], [6]], + sampler=lambda logits: mx.argmax(logits, axis=-1), + token_dtype=mx.int32, + greedy=True, + ) + + mx.eval(drafter._seed_token, drafter._cache[0].offset) + assert drafter._seed_token.shape == (2, 1) + assert drafter._round_appended == 0 + assert drafter._cache[0]._idx == 3 + assert drafter._cache[0].offset.tolist() == [2, 1] + assert drafter._cache[0].left_padding.tolist() == [1, 2] + assert drafter._next_position.tolist() == [6, 5] + + +def test_qwen3_5_rollback_speculative_cache_trims_batch_rows_ragged(): + text_config = _tiny_qwen3_5_text_config() + language = qwen_language.LanguageModel(args=text_config) + cache = qwen_language.KVCache.merge( + [qwen_language.KVCache(), qwen_language.KVCache()] + ) + keys = mx.arange(2 * 1 * 5 * 4, dtype=mx.float32).reshape(2, 1, 5, 4) + values = keys + 100 + cache.update_and_fetch(keys, values) + + language.rollback_speculative_cache([cache], [], mx.array([2, 0]), block_size=3) + + mx.eval(cache.keys, cache.values, cache.offset, cache.left_padding) + assert cache._idx == 5 + assert cache.offset.tolist() == [5, 3] + assert cache.left_padding.tolist() == [0, 2] + + def test_qwen3_5_mtp_filter_batch_keeps_drafter_state_aligned(): text_config = _tiny_qwen3_5_text_config() text_config.mtp_num_hidden_layers = 1 @@ -2161,6 +2806,33 @@ def test_qwen3_5_mtp_filter_batch_keeps_drafter_state_aligned(): assert drafter._next_position.tolist() == [6] +def test_qwen3_5_mtp_filter_batch_keeps_batch_cache_padding_aligned(): + text_config = _tiny_qwen3_5_text_config() + text_config.mtp_num_hidden_layers = 1 + drafter = Qwen3_5MTPDraftModel( + Qwen3_5MTPConfig(text_config=text_config, block_size=3) + ) + target = SimpleNamespace( + language_model=SimpleNamespace( + model=SimpleNamespace(embed_tokens=nn.Embedding(32, 16)) + ) + ) + drafter.reset(target, left_padding=[0, 1, 2]) + drafter._cache[0].update_and_fetch( + mx.zeros((3, 1, 2, 8), dtype=mx.float32), + mx.zeros((3, 1, 2, 8), dtype=mx.float32), + ) + drafter._next_position = mx.array([4, 5, 6], dtype=mx.int32) + + drafter.filter_batch(mx.array([0, 2], dtype=mx.int32)) + + mx.eval(drafter._cache[0].left_padding, drafter._cache[0].offset) + assert drafter._cache[0].keys.shape[0] == 2 + assert drafter._cache[0].left_padding.tolist() == [0, 2] + assert drafter._cache[0].offset.tolist() == [2, 0] + assert drafter._next_position.tolist() == [4, 6] + + def test_qwen3_5_mtp_sanitize_strips_prefix_and_offsets_norms(): weights = { "mtp.fc.weight": mx.ones((2, 4)),