From 7ef792f2c9a49989a41a7e0f0c88abd7ce81d7a6 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 01:08:43 +0200 Subject: [PATCH 01/19] Add Nemotron Labs Diffusion model --- mlx_vlm/generate/dispatch.py | 31 +- .../models/nemotron_labs_diffusion/README.md | 170 +++ .../nemotron_labs_diffusion/__init__.py | 4 + .../models/nemotron_labs_diffusion/config.py | 59 + .../nemotron_labs_diffusion/language.py | 1055 +++++++++++++++++ .../nemotron_labs_diffusion.py | 77 ++ mlx_vlm/prompt_utils.py | 1 + mlx_vlm/tests/test_diffusion_models.py | 149 +++ mlx_vlm/utils.py | 1 + 9 files changed, 1546 insertions(+), 1 deletion(-) create mode 100644 mlx_vlm/models/nemotron_labs_diffusion/README.md create mode 100644 mlx_vlm/models/nemotron_labs_diffusion/__init__.py create mode 100644 mlx_vlm/models/nemotron_labs_diffusion/config.py create mode 100644 mlx_vlm/models/nemotron_labs_diffusion/language.py create mode 100644 mlx_vlm/models/nemotron_labs_diffusion/nemotron_labs_diffusion.py diff --git a/mlx_vlm/generate/dispatch.py b/mlx_vlm/generate/dispatch.py index 102865083..e5e0c363f 100644 --- a/mlx_vlm/generate/dispatch.py +++ b/mlx_vlm/generate/dispatch.py @@ -379,6 +379,13 @@ def parse_arguments(): help="Extra processor kwargs as JSON. " 'Example: --processor-kwargs \'{"cropping": false, "max_patches": 3}\'', ) + parser.add_argument( + "--gen-kwargs", + type=json.loads, + default={}, + help="Extra generation kwargs as JSON. 
" + "Example: --gen-kwargs '{\"linear_speculative\": true}'", + ) parser.add_argument( "--prefill-step-size", type=int, @@ -606,6 +613,21 @@ def is_masked_diffusion_text_model(model: nn.Module) -> bool: return getattr(config, "mask_token_id", None) is not None +def _use_masked_diffusion_text_path(model: nn.Module, kwargs: Dict[str, Any]) -> bool: + if not is_masked_diffusion_text_model(model): + return False + + config = getattr(model, "config", None) + if getattr(config, "default_generation_mode", None) != "ar": + return True + + generation_mode = kwargs.get("generation_mode") + if generation_mode is not None: + return generation_mode in ("diffusion", "linear_speculative") + + return bool(kwargs.get("linear_speculative", False)) + + def _prime_cached_prefix_rope_state( model: nn.Module, full_input_ids: mx.array, @@ -738,7 +760,7 @@ def stream_generate( } kwargs.update(data_kwargs) - if is_masked_diffusion_text_model(model): + if _use_masked_diffusion_text_path(model, kwargs): if image is not None or audio is not None or video is not None: raise ValueError("Diffusion text generation models are text-only.") @@ -789,6 +811,8 @@ def stream_generate( tokenizer=tokenizer, skip_special_tokens=skip_special_tokens, stats=generation_stats, + linear_speculative=kwargs.get("linear_speculative", False) + or kwargs.get("generation_mode") == "linear_speculative", ) mx.eval(generated) total_time = time.perf_counter() - tic @@ -1325,6 +1349,7 @@ def main(): "editing_threshold": None, "max_post_steps": None, "stability_steps": None, + "gen_kwargs": {}, } for name, default in diffusion_arg_defaults.items(): if not hasattr(args, name): @@ -1410,6 +1435,10 @@ def main(): if args.processor_kwargs: kwargs.update(args.processor_kwargs) + # Add generation kwargs from JSON + if args.gen_kwargs: + kwargs.update(args.gen_kwargs) + # Add thinking kwargs kwargs["enable_thinking"] = args.enable_thinking if args.thinking_budget is not None: diff --git 
a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md new file mode 100644 index 000000000..564b83dc8 --- /dev/null +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -0,0 +1,170 @@ +# Nemotron Labs Diffusion + +Nemotron Labs Diffusion is a text-only diffusion language model from NVIDIA. The same checkpoint supports autoregressive decoding, block diffusion decoding, and linear self-speculative decoding. + +Capabilities: +- **Text generation** - normal autoregressive generation through the standard `mlx_vlm.generate` path +- **Diffusion generation** - masked block denoising with live visualization when `--verbose` is enabled +- **Linear self-speculation** - diffusion drafting with autoregressive verification using `--gen-kwargs` +- **Thinking mode** - chat-template support through `--enable-thinking` + +## Model + +| Model | Type | Params | Context | Modalities | +|---|---|---:|---:|---| +| `nvidia/Nemotron-Labs-Diffusion-8B` | Dense diffusion LM | 8B | 262k | Text | + +## Install + +```sh +pip install -U mlx-vlm +``` + +## CLI + +### Autoregressive generation + +By default, Nemotron uses the normal autoregressive generation path. + +```sh +mlx_vlm.generate \ + --model nvidia/Nemotron-Labs-Diffusion-8B \ + --prompt "Write a short story about a clockmaker." \ + --max-tokens 256 \ + --temperature 0.0 +``` + +### Diffusion generation + +Pass `generation_mode="diffusion"` to use the masked diffusion path. + +```sh +mlx_vlm.generate \ + --model nvidia/Nemotron-Labs-Diffusion-8B \ + --prompt "Write a short story about a clockmaker." \ + --max-tokens 256 \ + --max-denoising-steps 16 \ + --temperature 0.0 \ + --gen-kwargs '{"generation_mode": "diffusion"}' \ + --verbose +``` + +### Linear self-speculative generation + +Use `--gen-kwargs` for model-specific generation options. The bundled `linear_spec_lora` adapter is loaded automatically when available. 
+ +```sh +mlx_vlm.generate \ + --model nvidia/Nemotron-Labs-Diffusion-8B \ + --prompt "Write a short story about a clockmaker." \ + --max-tokens 256 \ + --temperature 0.0 \ + --gen-kwargs '{"generation_mode": "linear_speculative"}' +``` + +### Thinking mode + +```sh +mlx_vlm.generate \ + --model nvidia/Nemotron-Labs-Diffusion-8B \ + --prompt "Solve this step by step: if a train travels 180 km in 2.5 hours, what is its average speed?" \ + --enable-thinking \ + --max-tokens 512 \ + --temperature 0.0 +``` + +## Python + +### Basic text generation + +```python +from mlx_vlm import generate, load +from mlx_vlm.prompt_utils import apply_chat_template + +model, processor = load("nvidia/Nemotron-Labs-Diffusion-8B") + +prompt = apply_chat_template( + processor, + model.config, + "Write a short story about a clockmaker.", +) + +result = generate( + model=model, + processor=processor, + prompt=prompt, + max_tokens=256, + temperature=0.0, +) +print(result.text) +``` + +### Diffusion generation + +```python +from mlx_vlm import generate, load +from mlx_vlm.prompt_utils import apply_chat_template + +model, processor = load("nvidia/Nemotron-Labs-Diffusion-8B") + +prompt = apply_chat_template( + processor, + model.config, + "Write a short story about a clockmaker.", +) + +result = generate( + model=model, + processor=processor, + prompt=prompt, + max_tokens=256, + max_denoising_steps=16, + temperature=0.0, + generation_mode="diffusion", +) +print(result.text) +``` + +### Linear self-speculative generation + +```python +from mlx_vlm import generate, load +from mlx_vlm.prompt_utils import apply_chat_template + +model, processor = load("nvidia/Nemotron-Labs-Diffusion-8B") + +prompt = apply_chat_template( + processor, + model.config, + "Write a short story about a clockmaker.", +) + +result = generate( + model=model, + processor=processor, + prompt=prompt, + max_tokens=256, + temperature=0.0, + generation_mode="linear_speculative", +) +print(result.text) +``` + +## Architecture + +- 
**Backbone** - dense decoder-only Ministral-style transformer +- **Layers** - 34 transformer layers +- **Hidden size** - 4096 +- **Attention** - 32 query heads, 8 KV heads, 128 head dimension +- **MLP** - SwiGLU with 14336 intermediate size +- **RoPE** - long-context YaRN/Llama 4-style scaling parameters from the checkpoint +- **Diffusion head** - untied output projection over the 131072-token vocabulary +- **Mask token** - `mask_token_id=100` + +## Notes + +- The model is text-only. Image, audio, and video inputs are not supported. +- AR generation should use the normal CLI without diffusion-specific arguments. +- Diffusion generation uses masked block denoising. `--verbose` shows the block visualization as masks are filled. +- Diffusion and linear self-speculative generation are exposed through `generation_mode`, for example `--gen-kwargs '{"generation_mode": "diffusion"}'`. +- The optional `linear_spec_lora` adapter included in the Hugging Face repo is used only during the diffusion draft phase of linear self-speculation. 
diff --git a/mlx_vlm/models/nemotron_labs_diffusion/__init__.py b/mlx_vlm/models/nemotron_labs_diffusion/__init__.py new file mode 100644 index 000000000..96e24ef23 --- /dev/null +++ b/mlx_vlm/models/nemotron_labs_diffusion/__init__.py @@ -0,0 +1,4 @@ +from .config import ModelConfig +from .nemotron_labs_diffusion import Model + +__all__ = ["Model", "ModelConfig"] diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py new file mode 100644 index 000000000..b0694a52c --- /dev/null +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -0,0 +1,59 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +from ..base import BaseModelConfig + + +@dataclass +class ModelConfig(BaseModelConfig): + model_type: str = "nemotron_labs_diffusion" + vocab_size: int = 131072 + hidden_size: int = 4096 + intermediate_size: int = 14336 + num_hidden_layers: int = 34 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + head_dim: Optional[int] = 128 + hidden_act: str = "silu" + max_position_embeddings: int = 262144 + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-5 + use_cache: bool = False + pad_token_id: Optional[int] = None + bos_token_id: Optional[int] = 1 + eos_token_id: Optional[Union[int, list[int]]] = 11 + tie_word_embeddings: bool = False + rope_theta: float = 1000000.0 + rope_parameters: Optional[Dict[str, Any]] = None + rope_scaling: Optional[Dict[str, Any]] = None + attention_bias: bool = False + attention_dropout: float = 0.0 + mlp_bias: bool = False + sliding_window: Optional[int] = None + attn_implementation: str = "sdpa" + mask_token_id: int = 100 + default_generation_mode: str = "ar" + dlm_paradigm: str = "bidirectional" + block_size: int = 32 + dlm_loss_weight: Optional[float] = None + ar_loss_weight: float = 1.0 + dp_varying_mask_ratio: bool = False + + def __post_init__(self): + if self.head_dim is None: + self.head_dim = self.hidden_size // 
self.num_attention_heads + + rope_parameters = ( + dict(self.rope_parameters) + if self.rope_parameters is not None + else ( + dict(self.rope_scaling) + if self.rope_scaling is not None + else {"rope_type": "default", "rope_theta": self.rope_theta} + ) + ) + rope_parameters.setdefault("rope_type", "default") + rope_parameters.setdefault("rope_theta", self.rope_theta) + self.rope_parameters = rope_parameters + self.rope_scaling = rope_parameters + self.rope_theta = float(rope_parameters.get("rope_theta", self.rope_theta)) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py new file mode 100644 index 000000000..38758e271 --- /dev/null +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -0,0 +1,1055 @@ +import shutil +import sys +import time +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx_lm.models.activations import swiglu +from mlx_lm.models.rope_utils import initialize_rope + +from ..base import ( + LanguageModelOutput, + create_attention_mask, + scaled_dot_product_attention, +) +from ..cache import KVCache +from .config import ModelConfig + + +def _topk(x: mx.array, k: int, axis: int = -1) -> Tuple[mx.array, mx.array]: + indices = mx.argpartition(-x, kth=k - 1, axis=axis)[..., :k] + values = mx.take_along_axis(x, indices, axis=axis) + order = mx.argsort(-values, axis=axis) + return mx.take_along_axis(values, order, axis=axis), mx.take_along_axis( + indices, order, axis=axis + ) + + +def _first_token_index(tokens: mx.array, token_ids: set[int]) -> Optional[int]: + values = tokens.tolist() + return next( + (index for index, token_id in enumerate(values) if token_id in token_ids), + None, + ) + + +def _wrap_text(text: str, width: int) -> str: + lines = [] + while len(text) > width: + split_at = text.rfind(" ", 0, width + 1) + if split_at <= 0: + split_at = width + lines.append(text[:split_at].rstrip()) + text = 
text[split_at:].lstrip() + if text: + lines.append(text) + return "\n".join(lines) + + +def _make_bidirectional_mask( + attention_mask: Optional[mx.array], x: mx.array +) -> Optional[mx.array]: + if attention_mask is None: + return None + if attention_mask.ndim == 4: + if attention_mask.dtype == mx.bool_: + return attention_mask + return mx.where(attention_mask.astype(mx.bool_), 0.0, mx.finfo(x.dtype).min) + if attention_mask.ndim != 2: + return attention_mask + + if attention_mask.shape[-1] == 0 or bool(mx.all(attention_mask).item()): + return None + mask = attention_mask[:, None, None, :].astype(mx.bool_) + return mx.where(mask, 0.0, mx.finfo(x.dtype).min).astype(x.dtype) + + +def _llama4_attention_scale( + config: ModelConfig, length: int, offset: Any, dtype: mx.Dtype +) -> mx.array: + beta = config.rope_parameters.get("llama_4_scaling_beta") + original_max = config.rope_parameters.get("original_max_position_embeddings") + if beta is None or original_max is None: + return mx.array(1.0, dtype=dtype) + positions = mx.arange(length, dtype=mx.float32) + offset + scale = 1.0 + float(beta) * mx.log1p(mx.floor(positions / float(original_max))) + return scale.astype(dtype)[None, None, :, None] + + +class MLP(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.gate_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=config.mlp_bias + ) + self.up_proj = nn.Linear( + config.hidden_size, config.intermediate_size, bias=config.mlp_bias + ) + self.down_proj = nn.Linear( + config.intermediate_size, config.hidden_size, bias=config.mlp_bias + ) + + def __call__(self, x: mx.array) -> mx.array: + return self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x))) + + +class DraftLoRALinear(nn.Module): + def __init__(self, linear: nn.Linear, rank: int, scale: float): + super().__init__() + self.linear = linear + self.scale = scale + out_dim, in_dim = linear.weight.shape + self.lora_a = mx.zeros((in_dim, rank), 
class Attention(nn.Module):
    """Self-attention layer with rotary embeddings and an optional KV cache.

    Queries use ``config.num_attention_heads`` heads while keys and values
    use ``config.num_key_value_heads`` heads, all of size ``config.head_dim``.
    """

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim**-0.5

        query_width = config.num_attention_heads * self.head_dim
        kv_width = config.num_key_value_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, query_width, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_width, bias=False)
        self.o_proj = nn.Linear(
            query_width, config.hidden_size, bias=config.attention_bias
        )
        self.rope = initialize_rope(
            self.head_dim,
            base=config.rope_theta,
            traditional=False,
            scaling_config=config.rope_parameters,
            max_position_embeddings=config.max_position_embeddings,
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
        use_cache: bool = True,
    ) -> mx.array:
        """Attend over ``x`` (and any cached prefix) and project the result.

        Args:
            x: Input whose shape unpacks as ``(batch, length, features)``.
            mask: Attention mask forwarded to the attention kernel.
            cache: Optional per-layer cache exposing ``offset``, ``keys``,
                ``values`` and ``update_and_fetch``.
            use_cache: When true the new keys/values are written into
                ``cache``. When false the cached prefix is only read and the
                cache is left untouched.

        Returns:
            The output projection of the attention result, one vector per
            input position.
        """
        batch, length, _ = x.shape

        def split_heads(projected: mx.array, heads: int) -> mx.array:
            # (batch, length, heads * head_dim) -> (batch, heads, length, head_dim)
            return projected.reshape(batch, length, heads, self.head_dim).transpose(
                0, 2, 1, 3
            )

        queries = split_heads(self.q_proj(x), self.num_heads)
        keys = split_heads(self.k_proj(x), self.num_key_value_heads)
        values = split_heads(self.v_proj(x), self.num_key_value_heads)

        # Positions continue from whatever the cache already holds.
        offset = 0 if cache is None else cache.offset
        queries = self.rope(queries, offset=offset)
        keys = self.rope(keys, offset=offset)
        queries = queries * _llama4_attention_scale(
            self.config, length, offset, queries.dtype
        )

        if cache is not None and use_cache:
            keys, values = cache.update_and_fetch(keys, values)
        elif cache is not None and cache.keys is not None:
            # Read-only use of the cache: prepend the stored prefix without
            # committing the new keys/values to it.
            prefix = cache.offset
            keys = mx.concatenate([cache.keys[..., :prefix, :], keys], axis=2)
            values = mx.concatenate([cache.values[..., :prefix, :], values], axis=2)

        attended = scaled_dot_product_attention(
            queries, keys, values, cache=cache, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(batch, length, -1)
        return self.o_proj(merged)
@staticmethod
def _top_k_logits(logits: mx.array, k: Optional[int]) -> mx.array:
    """Mask every logit that is not among the ``k`` largest on the last axis.

    Args:
        logits: Scores whose last axis is the candidate (vocabulary) axis.
        k: Number of candidates to keep. ``None``, a non-positive value, or a
            value that covers the whole last axis disables filtering.

    Returns:
        ``logits`` unchanged when filtering is disabled; otherwise an array
        of the same shape in which entries strictly below the k-th largest
        value are replaced by the smallest finite value of the dtype.
    """
    if k is None or k <= 0:
        return logits
    # Keeping at least as many candidates as exist filters nothing. Returning
    # early also avoids asking ``_topk`` to partition at ``kth = k - 1``
    # beyond the end of the axis.
    if k >= logits.shape[-1]:
        return logits
    values = _topk(logits, k=k, axis=-1)[0]
    # ``values`` is sorted descending, so its last column is the k-th largest.
    return mx.where(logits < values[..., -1:], mx.finfo(logits.dtype).min, logits)
axis=-1 + ) + sorted_mask = cumulative_probs > p + sorted_mask = mx.concatenate( + [mx.zeros_like(sorted_mask[..., :1]), sorted_mask[..., :-1]], axis=-1 + ) + inverse_indices = mx.argsort(sorted_indices, axis=-1) + mask = mx.take_along_axis(sorted_mask, inverse_indices, axis=-1) + return mx.where(mask, mx.finfo(logits.dtype).min, logits) + + def _sample_with_temperature_topk_topp( + self, + logits: mx.array, + temperature: float = 0.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ): + if temperature == 0.0: + token = mx.argmax(logits, axis=-1) + probs = mx.softmax(logits.astype(mx.float32), axis=-1, precise=True) + token_prob = mx.take_along_axis(probs, token[..., None], axis=-1)[..., 0] + return token, token_prob + + if temperature != 1.0: + logits = logits / temperature + logits = self._top_k_logits(logits, top_k) + logits = self._top_p_logits(logits, top_p) + token = mx.random.categorical(logits.astype(mx.float32), axis=-1) + probs = mx.softmax(logits.astype(mx.float32), axis=-1, precise=True) + token_prob = mx.take_along_axis(probs, token[..., None], axis=-1)[..., 0] + return token, token_prob + + def _project_hidden(self, hidden_states: mx.array) -> mx.array: + if self.config.tie_word_embeddings: + return self.model.embed_tokens.as_linear(hidden_states) + return self.diffusion_head(hidden_states) + + def _sample_tokens( + self, + logits: mx.array, + temperature: float = 0.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + ) -> mx.array: + if temperature == 0.0: + return mx.argmax(logits, axis=-1) + + if temperature != 1.0: + logits = logits / temperature + logits = self._top_k_logits(logits, top_k) + logits = self._top_p_logits(logits, top_p) + return mx.random.categorical(logits.astype(mx.float32), axis=-1) + + @staticmethod + def _trim_cache(cache, max_length: int) -> None: + for layer_cache in cache: + excess = max(0, int(layer_cache.offset) - int(max_length)) + if excess: + layer_cache.trim(excess) + + def 
load_linear_spec_lora(self, adapter_path: str | Path) -> bool: + adapter_path = Path(adapter_path) + adapter_file = adapter_path / "adapter_model.safetensors" + if not adapter_file.exists(): + return False + weights = mx.load(str(adapter_file)) + rank = 128 + scale = 4.0 + + for layer_idx, layer in enumerate(self.model.layers): + o_proj = layer.self_attn.o_proj + if not isinstance(o_proj, DraftLoRALinear): + o_proj = DraftLoRALinear(o_proj, rank=rank, scale=scale) + layer.self_attn.o_proj = o_proj + + prefix = "base_model.model.encoder.layers." f"{layer_idx}.self_attn.o_proj" + key_a = f"{prefix}.lora_A.weight" + key_b = f"{prefix}.lora_B.weight" + if key_a not in weights or key_b not in weights: + return False + o_proj.lora_a = weights[key_a].T.astype(o_proj.linear.weight.dtype) + o_proj.lora_b = weights[key_b].T.astype(o_proj.linear.weight.dtype) + + self._linear_spec_lora_loaded = True + return True + + def set_linear_spec_lora_enabled(self, enabled: bool) -> None: + for layer in self.model.layers: + o_proj = layer.self_attn.o_proj + if isinstance(o_proj, DraftLoRALinear): + o_proj.enabled = enabled + + def generate( + self, + inputs: mx.array, + temperature: float = 0.0, + block_length: int = 32, + steps: int = 32, + gen_length: int = 2048, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + eos_early_stop: bool = False, + minimal_topk: int = 1, + threshold: float = 0.95, + min_threshold: Optional[float] = None, + editing_threshold: float = 0.9, + max_post_steps: int = 16, + eos_id: Optional[int] = None, + mask_id: Optional[int] = None, + num_to_transfer: int = 1, + max_transfer_per_step: Optional[int] = None, + stability_steps: int = 2, + visualize: bool = False, + tokenizer: Optional[Any] = None, + skip_special_tokens: bool = False, + stats: Optional[Dict[str, float]] = None, + linear_speculative: bool = False, + ) -> mx.array: + if inputs.shape[0] != 1: + raise ValueError( + "Nemotron Labs Diffusion generation currently supports batch size 1." 
+ ) + + eos_id = self.config.eos_token_id if eos_id is None else eos_id + mask_id = self.config.mask_token_id if mask_id is None else mask_id + if linear_speculative: + if not self._linear_spec_lora_loaded: + model_path = getattr(self, "model_path", None) + if model_path is not None: + self.load_linear_spec_lora(Path(model_path) / "linear_spec_lora") + output, _ = self.linear_spec_generate( + inputs, + max_new_tokens=gen_length, + block_length=block_length, + temperature=temperature, + top_p=top_p, + top_k=top_k, + mask_token_id=mask_id, + eos_token_id=eos_id, + threshold=0.0, + stats=stats, + ) + return output[:, inputs.shape[1] :] + + eos_token_ids = ( + set(eos_id) if isinstance(eos_id, (list, tuple, set)) else {eos_id} + ) + if block_length <= 0: + raise ValueError("block_length must be a positive integer.") + steps = max(1, int(steps)) + if max_transfer_per_step is not None: + max_transfer_per_step = min( + block_length, max(1, int(max_transfer_per_step)) + ) + + visualizer_state = { + "active": visualize and sys.stdout.isatty(), + "alternate_screen": False, + "rows": 0, + "last_draw": 0.0, + "min_interval": 0.1, + "token_ids": None, + "pieces": None, + "canvas": "", + } + + def clear_visualizer() -> None: + if not visualizer_state["active"]: + return + if visualizer_state["alternate_screen"]: + print("\033[H\033[2J", end="", flush=True) + visualizer_state["rows"] = 0 + return + if visualizer_state["rows"] == 0: + return + controls = ["\r\033[2K"] + for _ in range(visualizer_state["rows"] - 1): + controls.append("\033[1A\r\033[2K") + print("".join(controls), end="", flush=True) + visualizer_state["rows"] = 0 + + def finish_visualizer() -> None: + if not visualizer_state["active"]: + return + if visualizer_state["alternate_screen"]: + print("\033[H\033[2J\033[?25h\033[?1049l", end="", flush=True) + visualizer_state["alternate_screen"] = False + visualizer_state["rows"] = 0 + else: + clear_visualizer() + + def decode_token(token_id: int) -> str: + if tokenizer 
is None: + return str(token_id) + piece = tokenizer.decode( + [token_id], skip_special_tokens=skip_special_tokens + ) + return piece.replace("\n", "\\n") or " " + + def visualize_tokens(tokens: mx.array, force: bool = False) -> None: + if not visualizer_state["active"]: + return + now = time.perf_counter() + if ( + not force + and now - visualizer_state["last_draw"] + < visualizer_state["min_interval"] + ): + return + token_ids = tokens[0].tolist() + pieces = visualizer_state["pieces"] + previous_token_ids = visualizer_state["token_ids"] + if ( + pieces is None + or previous_token_ids is None + or len(previous_token_ids) != len(token_ids) + ): + pieces = ["[MASK]"] * len(token_ids) + previous_token_ids = [mask_id] * len(token_ids) + + found_eos = False + for i, token_id in enumerate(token_ids): + previous_token_id = previous_token_ids[i] + if found_eos: + if previous_token_id != mask_id: + pieces[i] = "[MASK]" + continue + if token_id == mask_id: + if previous_token_id != mask_id: + pieces[i] = "[MASK]" + elif token_id in eos_token_ids: + if previous_token_id != token_id: + pieces[i] = decode_token(token_id) or "" + found_eos = True + elif previous_token_id != token_id: + pieces[i] = decode_token(token_id) + + visualizer_state["pieces"] = pieces + visualizer_state["token_ids"] = token_ids + terminal_size = shutil.get_terminal_size((120, 20)) + terminal_width = max(20, terminal_size.columns - 1) + canvas = _wrap_text("".join(pieces), terminal_width) + if not force and canvas == visualizer_state["canvas"]: + return + rows = max(1, canvas.count("\n") + 1) + if ( + rows >= max(1, terminal_size.lines - 2) + and not visualizer_state["alternate_screen"] + ): + print("\033[?1049h\033[?25l\033[H\033[2J", end="", flush=True) + visualizer_state["alternate_screen"] = True + clear_visualizer() + print(canvas, end="", flush=True) + visualizer_state["rows"] = rows + visualizer_state["last_draw"] = now + visualizer_state["canvas"] = canvas + + generated_blocks = [] + prompt_tic = 
time.perf_counter() + recorded_prompt_time = False + cache = self.make_cache() + prefill_logits = self( + inputs, + cache=cache, + use_cache=True, + use_causal_mask=True, + ).logits + next_token = self._sample_tokens( + prefill_logits[:, -1, :], + temperature=temperature, + top_k=top_k, + top_p=top_p, + )[:, None] + mx.eval(next_token) + if stats is not None: + stats["prompt_time"] = time.perf_counter() - prompt_tic + stats["prompt_tokens"] = float(inputs.size) + recorded_prompt_time = True + + total_generated = 0 + num_blocks = (gen_length + block_length - 1) // block_length + for _ in range(num_blocks): + remaining = gen_length - total_generated + if remaining <= 0: + break + current_block_length = min(block_length, remaining) + block_positions = mx.arange(block_length) + block = mx.full((1, block_length), mask_id, dtype=inputs.dtype) + block[:, 0] = next_token[:, 0] + if visualizer_state["active"]: + preview = ( + mx.concatenate(generated_blocks + [block], axis=1) + if generated_blocks + else block + ) + visualize_tokens(preview, force=True) + + for step_idx in range(steps): + mask_index = block == mask_id + if not bool(mask_index.any().item()): + break + logits = self( + block, + cache=cache, + use_cache=False, + use_causal_mask=False, + ).logits + x0, token_probs = self._sample_with_temperature_topk_topp( + logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + if stats is not None and not recorded_prompt_time: + mx.eval(x0, token_probs) + stats["prompt_time"] = time.perf_counter() - prompt_tic + stats["prompt_tokens"] = float(inputs.size) + recorded_prompt_time = True + x0 = mx.where(mask_index, x0, block) + confidence = mx.where(mask_index, token_probs, -mx.inf) + remaining_steps = max(1, steps - step_idx) + masked_count = int(mask_index.sum().item()) + transfer_count = max( + 1, (masked_count + remaining_steps - 1) // remaining_steps + ) + if max_transfer_per_step is not None: + transfer_count = min(transfer_count, max_transfer_per_step) + 
def ar_generate(
    self,
    prompt_ids: mx.array,
    max_new_tokens: int = 128,
    temperature: float = 0.0,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    stats: Optional[Dict[str, float]] = None,
    **kwargs,
) -> tuple[mx.array, int]:
    """Decode token by token with a causal mask and a KV cache.

    Args:
        prompt_ids: Prompt token ids; new tokens are appended along axis 1.
        max_new_tokens: Upper bound on the number of sampled tokens.
        temperature: Sampling temperature forwarded to ``_sample_tokens``.
        top_p: Nucleus threshold forwarded to ``_sample_tokens``.
        top_k: Top-k cutoff forwarded to ``_sample_tokens``.
        eos_token_id: A single id or a collection of ids that stop decoding.
            Defaults to ``config.eos_token_id``.
        stats: Optional dict that receives ``prompt_time`` and
            ``prompt_tokens`` after the prefill pass.
        **kwargs: Accepted and ignored.

    Returns:
        ``(tokens, nfe)`` where ``tokens`` is the prompt followed by the
        sampled tokens and ``nfe`` is the number of sampling steps taken.
    """
    stop_spec = self.config.eos_token_id if eos_token_id is None else eos_token_id
    if isinstance(stop_spec, (list, tuple, set)):
        eos_token_ids = set(stop_spec)
    else:
        eos_token_ids = {stop_spec}

    # Prefill: run the whole prompt once and time it.
    started = time.perf_counter()
    cache = self.make_cache()
    prefill = self(
        prompt_ids,
        cache=cache,
        use_cache=True,
        use_causal_mask=True,
    ).logits
    mx.eval(prefill)
    if stats is not None:
        stats["prompt_time"] = time.perf_counter() - started
        stats["prompt_tokens"] = float(prompt_ids.size)

    sampled = []
    step_logits = prefill[:, -1, :]
    nfe = 0
    for _ in range(max_new_tokens):
        nfe += 1
        token = self._sample_tokens(
            step_logits,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
        )[:, None]
        sampled.append(token)
        # Stop once every sequence in the batch has just produced an EOS id.
        if all(tok in eos_token_ids for tok in token[:, 0].tolist()):
            break
        step_logits = self(
            token,
            cache=cache,
            use_cache=True,
            use_causal_mask=True,
        ).logits[:, -1, :]

    if not sampled:
        return prompt_ids, nfe
    continuation = mx.concatenate(sampled, axis=1)
    return mx.concatenate([prompt_ids, continuation], axis=1), nfe
else mask_token_id + ) + if eos_token_id is None: + eos_token_id = self.config.eos_token_id + eos_token_ids = ( + set(eos_token_id) + if isinstance(eos_token_id, (list, tuple, set)) + else {eos_token_id} + ) + + prompt_tic = time.perf_counter() + cache = self.make_cache() + prefill = self( + prompt_ids, + cache=cache, + use_cache=True, + use_causal_mask=True, + ).logits + mx.eval(prefill) + if stats is not None: + stats["prompt_time"] = time.perf_counter() - prompt_tic + stats["prompt_tokens"] = float(prompt_ids.size) + + next_token = self._sample_tokens( + prefill[:, -1, :], + temperature=temperature, + top_k=top_k, + top_p=top_p, + )[:, None] + generated = [next_token] + total_generated = 1 + nfe = 1 + + if next_token.item() in eos_token_ids: + return mx.concatenate([prompt_ids, next_token], axis=1), nfe + + while total_generated < max_new_tokens: + cache_len = cache[0].offset + block = mx.full((1, block_length), mask_token_id, dtype=prompt_ids.dtype) + block[:, 0] = next_token[:, 0] + + while bool((block == mask_token_id).any().item()): + self.set_linear_spec_lora_enabled(True) + draft_logits = self( + block, + cache=cache, + use_cache=False, + use_causal_mask=False, + ).logits + nfe += 1 + is_mask = block == mask_token_id + if threshold > 0: + draft_tokens, draft_probs = self._sample_with_temperature_topk_topp( + draft_logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + draft_conf = mx.where(is_mask, draft_probs, -mx.inf) + unmask = draft_conf >= threshold + if not bool(unmask.any().item()): + _, best_idx = _topk(draft_conf, 1) + positions = mx.arange(block_length) + unmask = (positions[None, None, :] == best_idx[..., None]).any( + axis=1 + ) + block = mx.where(unmask, draft_tokens, block) + else: + draft_tokens = self._sample_tokens( + draft_logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + block = mx.where(is_mask, draft_tokens, block) + break + + self.set_linear_spec_lora_enabled(False) + verify_logits = self( + block, 
+ cache=cache, + use_cache=True, + use_causal_mask=True, + ).logits + nfe += 1 + ar_tokens = self._sample_tokens( + verify_logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + ar_token_ids = ar_tokens[0].tolist() + block_ids = block[0].tolist() + accepted = 1 + for i in range(block_length - 1): + if ar_token_ids[i] == block_ids[i + 1]: + accepted += 1 + else: + break + accepted = min(accepted, max_new_tokens - total_generated) + accepted_tokens = ar_tokens[:, :accepted] + generated.append(accepted_tokens) + total_generated += accepted + + self._trim_cache(cache, cache_len + accepted) + next_token = ar_tokens[:, accepted - 1 : accepted] + + eos_index = _first_token_index(accepted_tokens[0], eos_token_ids) + if eos_index is not None: + generated[-1] = accepted_tokens[:, : eos_index + 1] + break + + return ( + mx.concatenate([prompt_ids, mx.concatenate(generated, axis=1)], axis=1), + nfe, + ) + + def stream_linear_spec_generate( + self, + prompt_ids: mx.array, + max_new_tokens: int = 128, + block_length: int = 32, + temperature: float = 0.0, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + mask_token_id: Optional[int] = None, + eos_token_id: Optional[int] = None, + threshold: float = 0.0, + stats: Optional[Dict[str, float]] = None, + **kwargs, + ): + if prompt_ids.shape[0] != 1: + raise ValueError("Linear speculative decoding requires batch size 1.") + if block_length <= 0: + raise ValueError("block_length must be a positive integer.") + + mask_token_id = ( + self.config.mask_token_id if mask_token_id is None else mask_token_id + ) + if eos_token_id is None: + eos_token_id = self.config.eos_token_id + eos_token_ids = ( + set(eos_token_id) + if isinstance(eos_token_id, (list, tuple, set)) + else {eos_token_id} + ) + + prompt_tic = time.perf_counter() + cache = self.make_cache() + prefill = self( + prompt_ids, + cache=cache, + use_cache=True, + use_causal_mask=True, + ).logits + mx.eval(prefill) + if stats is not None: + 
stats["prompt_time"] = time.perf_counter() - prompt_tic + stats["prompt_tokens"] = float(prompt_ids.size) + + next_token = self._sample_tokens( + prefill[:, -1, :], + temperature=temperature, + top_k=top_k, + top_p=top_p, + )[:, None] + mx.eval(next_token) + yield next_token + total_generated = 1 + + if next_token.item() in eos_token_ids: + return + + while total_generated < max_new_tokens: + cache_len = cache[0].offset + block = mx.full((1, block_length), mask_token_id, dtype=prompt_ids.dtype) + block[:, 0] = next_token[:, 0] + + while bool((block == mask_token_id).any().item()): + self.set_linear_spec_lora_enabled(True) + draft_logits = self( + block, + cache=cache, + use_cache=False, + use_causal_mask=False, + ).logits + is_mask = block == mask_token_id + if threshold > 0: + draft_tokens, draft_probs = self._sample_with_temperature_topk_topp( + draft_logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + draft_conf = mx.where(is_mask, draft_probs, -mx.inf) + unmask = draft_conf >= threshold + if not bool(unmask.any().item()): + _, best_idx = _topk(draft_conf, 1) + positions = mx.arange(block_length) + unmask = (positions[None, None, :] == best_idx[..., None]).any( + axis=1 + ) + block = mx.where(unmask, draft_tokens, block) + else: + draft_tokens = self._sample_tokens( + draft_logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + block = mx.where(is_mask, draft_tokens, block) + break + + self.set_linear_spec_lora_enabled(False) + verify_logits = self( + block, + cache=cache, + use_cache=True, + use_causal_mask=True, + ).logits + ar_tokens = self._sample_tokens( + verify_logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + + ar_token_ids = ar_tokens[0].tolist() + block_ids = block[0].tolist() + accepted = 1 + for i in range(block_length - 1): + if ar_token_ids[i] == block_ids[i + 1]: + accepted += 1 + else: + break + accepted = min(accepted, max_new_tokens - total_generated) + accepted_tokens = ar_tokens[:, 
:accepted] + + self._trim_cache(cache, cache_len + accepted) + next_token = ar_tokens[:, accepted - 1 : accepted] + + eos_index = _first_token_index(accepted_tokens[0], eos_token_ids) + if eos_index is not None: + accepted_tokens = accepted_tokens[:, : eos_index + 1] + mx.eval(accepted_tokens) + yield accepted_tokens + total_generated += accepted_tokens.shape[1] + if eos_index is not None: + break + + def sanitize(self, weights): + if self.config.tie_word_embeddings: + weights.pop("diffusion_head.weight", None) + + return { + k: v + for k, v in weights.items() + if "rotary_emb.inv_freq" not in k + and not k.endswith(".self_attn.k_scale") + and not k.endswith(".self_attn.v_scale") + } + + @property + def layers(self): + return self.model.layers + + @property + def head_dim(self): + return self.config.head_dim + + @property + def n_kv_heads(self): + return self.config.num_key_value_heads + + def make_cache(self): + return [KVCache() for _ in self.layers] diff --git a/mlx_vlm/models/nemotron_labs_diffusion/nemotron_labs_diffusion.py b/mlx_vlm/models/nemotron_labs_diffusion/nemotron_labs_diffusion.py new file mode 100644 index 000000000..93e2a55c6 --- /dev/null +++ b/mlx_vlm/models/nemotron_labs_diffusion/nemotron_labs_diffusion.py @@ -0,0 +1,77 @@ +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from ..base import InputEmbeddingsFeatures, LanguageModelOutput +from .config import ModelConfig +from .language import LanguageModel + + +class Model(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + self.model_type = config.model_type + self.language_model = LanguageModel(config) + self._model_path = None + + @property + def model_path(self): + return self._model_path + + @model_path.setter + def model_path(self, value): + self._model_path = value + self.language_model.model_path = value + + def get_input_embeddings( + self, + input_ids: Optional[mx.array] = None, + pixel_values: 
Optional[mx.array] = None, + **kwargs, + ) -> InputEmbeddingsFeatures: + if pixel_values is not None: + raise ValueError("Nemotron Labs Diffusion is a text-only model.") + if input_ids is None: + raise ValueError("input_ids are required for Nemotron Labs Diffusion.") + return InputEmbeddingsFeatures( + inputs_embeds=self.language_model.model.embed_tokens(input_ids) + ) + + def __call__( + self, + input_ids: mx.array, + pixel_values: mx.array = None, + mask: mx.array = None, + cache=None, + **kwargs, + ) -> LanguageModelOutput: + input_embeddings_features = self.get_input_embeddings(input_ids, pixel_values) + return self.language_model( + input_ids, + mask=mask, + cache=cache, + inputs_embeds=input_embeddings_features.inputs_embeds, + **kwargs, + ) + + def sanitize(self, weights): + def transform_key(key): + if key.startswith("language_model."): + return key + if key.startswith("encoder."): + return f"language_model.model.{key[len('encoder.'):]}" + if key.startswith("diffusion_head."): + return f"language_model.{key}" + return key + + weights = {transform_key(k): v for k, v in weights.items()} + return self.language_model.sanitize(weights) + + @property + def layers(self): + return self.language_model.layers + + def make_cache(self): + return self.language_model.make_cache() diff --git a/mlx_vlm/prompt_utils.py b/mlx_vlm/prompt_utils.py index 11e533dba..f17fe6c7c 100644 --- a/mlx_vlm/prompt_utils.py +++ b/mlx_vlm/prompt_utils.py @@ -92,6 +92,7 @@ class MessageFormat(Enum): "falcon_ocr": MessageFormat.PROMPT_ONLY, "paligemma": MessageFormat.PROMPT_WITH_IMAGE_TOKEN, "laguna": MessageFormat.TEXT_ONLY, + "nemotron_labs_diffusion": MessageFormat.TEXT_ONLY, "deepseek_v4": MessageFormat.TEXT_ONLY, "hrm_text": MessageFormat.TEXT_ONLY, } diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index c1b5a03d4..69a6eecf3 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -19,8 +19,31 @@ def 
decode(self, tokens, skip_special_tokens=False): return "decoded" +class _Detokenizer: + def __init__(self): + self.text = "" + self.offset = 0 + + def reset(self): + self.text = "" + self.offset = 0 + + def add_token(self, token, skip_special_token_ids=None): + self.text += "decoded" + + def finalize(self): + pass + + @property + def last_segment(self): + segment = self.text[self.offset :] + self.offset = len(self.text) + return segment + + class _Processor: tokenizer = _Tokenizer() + detokenizer = _Detokenizer() class TestDiffusionModels(unittest.TestCase): @@ -201,3 +224,129 @@ def counted_call(self, *args, **kwargs): llada_language.LLaDA2MoeModel.__call__ = original_call self.assertLessEqual(calls["count"], 6) + + def test_nemotron_labs_diffusion(self): + from mlx_vlm.models import nemotron_labs_diffusion + + config = nemotron_labs_diffusion.ModelConfig( + model_type="nemotron_labs_diffusion", + vocab_size=128, + hidden_size=64, + intermediate_size=128, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=16, + max_position_embeddings=128, + rope_parameters={ + "rope_type": "default", + "rope_theta": 1000000.0, + "llama_4_scaling_beta": 0.1, + "original_max_position_embeddings": 64, + }, + eos_token_id=3, + mask_token_id=127, + ) + model = nemotron_labs_diffusion.Model(config) + self.dtype_consistency_test_runner( + model.language_model, + config.model_type, + config.num_hidden_layers, + ) + + generated = model.language_model.generate( + mx.array([[4]], dtype=mx.int32), + block_length=4, + steps=1, + gen_length=8, + max_post_steps=4, + mask_id=127, + eos_id=999, + ) + self.assertEqual(generated.shape, (1, 8)) + + ar_generated, ar_nfe = model.language_model.ar_generate( + mx.array([[4]], dtype=mx.int32), + max_new_tokens=2, + eos_token_id=3, + ) + mx.eval(ar_generated) + self.assertEqual(ar_generated.shape[0], 1) + self.assertLessEqual(ar_generated.shape[1], 3) + self.assertGreaterEqual(ar_nfe, 1) + + spec_generated, spec_nfe = 
model.language_model.linear_spec_generate( + mx.array([[4]], dtype=mx.int32), + max_new_tokens=2, + block_length=2, + eos_token_id=3, + mask_token_id=127, + ) + mx.eval(spec_generated) + self.assertEqual(spec_generated.shape[0], 1) + self.assertLessEqual(spec_generated.shape[1], 3) + self.assertGreaterEqual(spec_nfe, 1) + + def unexpected_diffusion_generate(*args, **kwargs): + raise AssertionError("Default Nemotron generation should use AR") + + model.language_model.generate = unexpected_diffusion_generate + default_results = list( + stream_generate( + model, + _Processor(), + prompt="ignored", + input_ids=mx.array([[4]], dtype=mx.int32), + max_tokens=1, + temperature=0.0, + ) + ) + self.assertEqual(default_results[-1].generation_tokens, 1) + + diffusion_calls = {} + + def diffusion_generate(input_ids, **kwargs): + diffusion_calls["kwargs"] = kwargs + kwargs["stats"]["prompt_time"] = 1.0 + return mx.array([[5, 3]], dtype=mx.int32) + + model.language_model.generate = diffusion_generate + diffusion_results = list( + stream_generate( + model, + _Processor(), + prompt="ignored", + input_ids=mx.array([[4]], dtype=mx.int32), + max_tokens=2, + generation_mode="diffusion", + temperature=0.0, + ) + ) + self.assertTrue(diffusion_calls["kwargs"]) + self.assertFalse(diffusion_calls["kwargs"]["linear_speculative"]) + self.assertEqual(diffusion_results[-1].generation_tokens, 2) + + linear_calls = {} + + def generate(input_ids, **kwargs): + linear_calls["kwargs"] = kwargs + kwargs["stats"]["prompt_time"] = 1.0 + return mx.array([[5, 3]], dtype=mx.int32) + + model.language_model.generate = generate + results = list( + stream_generate( + model, + _Processor(), + prompt="ignored", + input_ids=mx.array([[4]], dtype=mx.int32), + max_tokens=2, + generation_mode="linear_speculative", + temperature=0.0, + ) + ) + self.assertTrue(linear_calls["kwargs"]) + self.assertTrue(linear_calls["kwargs"]["linear_speculative"]) + self.assertGreaterEqual(len(results), 1) + 
self.assertEqual(results[-1].generation_tokens, 2) + self.assertEqual(results[-1].finish_reason, "stop") diff --git a/mlx_vlm/utils.py b/mlx_vlm/utils.py index 01c8e3696..81c542bcc 100644 --- a/mlx_vlm/utils.py +++ b/mlx_vlm/utils.py @@ -380,6 +380,7 @@ def get_class_predicate(p, m): if not lazy: mx.eval(model.parameters()) + model.model_path = model_path model.eval() return model From ae0574488bdfd0218dfb6044b76cc6770401ad1f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 01:41:01 +0200 Subject: [PATCH 02/19] Tune Nemotron diffusion mode defaults --- mlx_vlm/generate/dispatch.py | 18 +++++++++++++----- .../models/nemotron_labs_diffusion/README.md | 2 ++ .../models/nemotron_labs_diffusion/config.py | 1 + mlx_vlm/tests/test_diffusion_models.py | 2 ++ 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/mlx_vlm/generate/dispatch.py b/mlx_vlm/generate/dispatch.py index e5e0c363f..1fb148d4b 100644 --- a/mlx_vlm/generate/dispatch.py +++ b/mlx_vlm/generate/dispatch.py @@ -770,14 +770,22 @@ def stream_generate( top_k = kwargs.get("top_k", DEFAULT_TOP_K) max_denoising_steps = kwargs.get("max_denoising_steps") if max_denoising_steps is None: - max_denoising_steps = kwargs.get("steps", 32) + config = getattr(model, "config", None) + max_denoising_steps = kwargs.get( + "steps", getattr(config, "default_diffusion_steps", 32) + ) num_to_transfer = kwargs.get( "num_to_transfer", DEFAULT_MASKED_DIFFUSION_NUM_TO_TRANSFER ) - threshold = kwargs.get("threshold", DEFAULT_MASKED_DIFFUSION_THRESHOLD) - min_threshold = kwargs.get( - "min_threshold", DEFAULT_MASKED_DIFFUSION_MIN_THRESHOLD - ) + config = getattr(model, "config", None) + if getattr(config, "default_generation_mode", None) == "ar": + threshold = kwargs.get("threshold") + min_threshold = kwargs.get("min_threshold") + else: + threshold = kwargs.get("threshold", DEFAULT_MASKED_DIFFUSION_THRESHOLD) + min_threshold = kwargs.get( + "min_threshold", DEFAULT_MASKED_DIFFUSION_MIN_THRESHOLD + ) 
editing_threshold = kwargs.get( "editing_threshold", DEFAULT_MASKED_DIFFUSION_EDITING_THRESHOLD ) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index 564b83dc8..844e59112 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -37,6 +37,7 @@ mlx_vlm.generate \ ### Diffusion generation Pass `generation_mode="diffusion"` to use the masked diffusion path. +Nemotron defaults to 8 denoising steps in this mode so the diffusion path transfers multiple tokens per forward pass. ```sh mlx_vlm.generate \ @@ -166,5 +167,6 @@ print(result.text) - The model is text-only. Image, audio, and video inputs are not supported. - AR generation should use the normal CLI without diffusion-specific arguments. - Diffusion generation uses masked block denoising. `--verbose` shows the block visualization as masks are filled. +- The default diffusion schedule uses 8 denoising steps and no confidence threshold for speed. Increase `--max-denoising-steps` or pass `--threshold` if you want more conservative denoising. - Diffusion and linear self-speculative generation are exposed through `generation_mode`, for example `--gen-kwargs '{"generation_mode": "diffusion"}'`. - The optional `linear_spec_lora` adapter included in the Hugging Face repo is used only during the diffusion draft phase of linear self-speculation. 
diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py index b0694a52c..b11346a46 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/config.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -33,6 +33,7 @@ class ModelConfig(BaseModelConfig): attn_implementation: str = "sdpa" mask_token_id: int = 100 default_generation_mode: str = "ar" + default_diffusion_steps: int = 8 dlm_paradigm: str = "bidirectional" block_size: int = 32 dlm_loss_weight: Optional[float] = None diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index 69a6eecf3..8838526ba 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -324,6 +324,8 @@ def diffusion_generate(input_ids, **kwargs): ) self.assertTrue(diffusion_calls["kwargs"]) self.assertFalse(diffusion_calls["kwargs"]["linear_speculative"]) + self.assertEqual(diffusion_calls["kwargs"]["steps"], 8) + self.assertIsNone(diffusion_calls["kwargs"]["threshold"]) self.assertEqual(diffusion_results[-1].generation_tokens, 2) linear_calls = {} From d6742029d6e0d08654d857fc742496485c3e7739 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 01:43:55 +0200 Subject: [PATCH 03/19] Restore quality-first Nemotron diffusion defaults --- mlx_vlm/models/nemotron_labs_diffusion/README.md | 4 ++-- mlx_vlm/models/nemotron_labs_diffusion/config.py | 2 +- mlx_vlm/tests/test_diffusion_models.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index 844e59112..748ad6406 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -37,7 +37,7 @@ mlx_vlm.generate \ ### Diffusion generation Pass `generation_mode="diffusion"` to use the masked diffusion path. 
-Nemotron defaults to 8 denoising steps in this mode so the diffusion path transfers multiple tokens per forward pass. +Nemotron defaults to 32 denoising steps in this mode for a quality-first diffusion schedule. ```sh mlx_vlm.generate \ @@ -167,6 +167,6 @@ print(result.text) - The model is text-only. Image, audio, and video inputs are not supported. - AR generation should use the normal CLI without diffusion-specific arguments. - Diffusion generation uses masked block denoising. `--verbose` shows the block visualization as masks are filled. -- The default diffusion schedule uses 8 denoising steps and no confidence threshold for speed. Increase `--max-denoising-steps` or pass `--threshold` if you want more conservative denoising. +- The default diffusion schedule uses 32 denoising steps and no confidence threshold. Lower `--max-denoising-steps` for speed experiments, but quality can degrade quickly. - Diffusion and linear self-speculative generation are exposed through `generation_mode`, for example `--gen-kwargs '{"generation_mode": "diffusion"}'`. - The optional `linear_spec_lora` adapter included in the Hugging Face repo is used only during the diffusion draft phase of linear self-speculation. 
diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py index b11346a46..6f4d3cd91 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/config.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -33,7 +33,7 @@ class ModelConfig(BaseModelConfig): attn_implementation: str = "sdpa" mask_token_id: int = 100 default_generation_mode: str = "ar" - default_diffusion_steps: int = 8 + default_diffusion_steps: int = 32 dlm_paradigm: str = "bidirectional" block_size: int = 32 dlm_loss_weight: Optional[float] = None diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index 8838526ba..854890ce9 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -324,7 +324,7 @@ def diffusion_generate(input_ids, **kwargs): ) self.assertTrue(diffusion_calls["kwargs"]) self.assertFalse(diffusion_calls["kwargs"]["linear_speculative"]) - self.assertEqual(diffusion_calls["kwargs"]["steps"], 8) + self.assertEqual(diffusion_calls["kwargs"]["steps"], 32) self.assertIsNone(diffusion_calls["kwargs"]["threshold"]) self.assertEqual(diffusion_results[-1].generation_tokens, 2) From 50212cd3f9ca8070880099e30b81b8d0b47399f2 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 01:48:06 +0200 Subject: [PATCH 04/19] Align Nemotron diffusion threshold transfer --- mlx_vlm/generate/dispatch.py | 4 +++- mlx_vlm/models/nemotron_labs_diffusion/README.md | 4 ++-- mlx_vlm/models/nemotron_labs_diffusion/config.py | 1 + mlx_vlm/models/nemotron_labs_diffusion/language.py | 9 ++++++--- mlx_vlm/tests/test_diffusion_models.py | 2 +- 5 files changed, 13 insertions(+), 7 deletions(-) diff --git a/mlx_vlm/generate/dispatch.py b/mlx_vlm/generate/dispatch.py index 1fb148d4b..f19e117e3 100644 --- a/mlx_vlm/generate/dispatch.py +++ b/mlx_vlm/generate/dispatch.py @@ -779,7 +779,9 @@ def stream_generate( ) config = getattr(model, "config", None) if getattr(config, 
"default_generation_mode", None) == "ar": - threshold = kwargs.get("threshold") + threshold = kwargs.get( + "threshold", getattr(config, "default_diffusion_threshold", None) + ) min_threshold = kwargs.get("min_threshold") else: threshold = kwargs.get("threshold", DEFAULT_MASKED_DIFFUSION_THRESHOLD) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index 748ad6406..f6dbb0eb8 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -37,7 +37,7 @@ mlx_vlm.generate \ ### Diffusion generation Pass `generation_mode="diffusion"` to use the masked diffusion path. -Nemotron defaults to 32 denoising steps in this mode for a quality-first diffusion schedule. +Nemotron defaults to 32 denoising steps and a 0.9 transfer threshold in this mode, matching the upstream dLM example. ```sh mlx_vlm.generate \ @@ -167,6 +167,6 @@ print(result.text) - The model is text-only. Image, audio, and video inputs are not supported. - AR generation should use the normal CLI without diffusion-specific arguments. - Diffusion generation uses masked block denoising. `--verbose` shows the block visualization as masks are filled. -- The default diffusion schedule uses 32 denoising steps and no confidence threshold. Lower `--max-denoising-steps` for speed experiments, but quality can degrade quickly. +- The default diffusion schedule uses 32 denoising steps and a 0.9 confidence threshold. Lower `--max-denoising-steps` for speed experiments, but quality can degrade quickly. - Diffusion and linear self-speculative generation are exposed through `generation_mode`, for example `--gen-kwargs '{"generation_mode": "diffusion"}'`. - The optional `linear_spec_lora` adapter included in the Hugging Face repo is used only during the diffusion draft phase of linear self-speculation. 
diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py index 6f4d3cd91..4e400e2b4 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/config.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -34,6 +34,7 @@ class ModelConfig(BaseModelConfig): mask_token_id: int = 100 default_generation_mode: str = "ar" default_diffusion_steps: int = 32 + default_diffusion_threshold: Optional[float] = 0.9 dlm_paradigm: str = "bidirectional" block_size: int = 32 dlm_loss_weight: Optional[float] = None diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 38758e271..9793afe9a 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -615,9 +615,12 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: confidence = mx.where(mask_index, token_probs, -mx.inf) remaining_steps = max(1, steps - step_idx) masked_count = int(mask_index.sum().item()) - transfer_count = max( - 1, (masked_count + remaining_steps - 1) // remaining_steps - ) + if threshold is not None: + transfer_count = masked_count + else: + transfer_count = max( + 1, (masked_count + remaining_steps - 1) // remaining_steps + ) if max_transfer_per_step is not None: transfer_count = min(transfer_count, max_transfer_per_step) _, indices = _topk(confidence, min(transfer_count, masked_count)) diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index 854890ce9..c8d9b40b7 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -325,7 +325,7 @@ def diffusion_generate(input_ids, **kwargs): self.assertTrue(diffusion_calls["kwargs"]) self.assertFalse(diffusion_calls["kwargs"]["linear_speculative"]) self.assertEqual(diffusion_calls["kwargs"]["steps"], 32) - self.assertIsNone(diffusion_calls["kwargs"]["threshold"]) + 
self.assertEqual(diffusion_calls["kwargs"]["threshold"], 0.9) self.assertEqual(diffusion_results[-1].generation_tokens, 2) linear_calls = {} From 0f408cce9e8a2eeaa7e75416a8eebb66aca40a1e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 01:59:45 +0200 Subject: [PATCH 05/19] Reduce Nemotron diffusion materialization --- .../nemotron_labs_diffusion/language.py | 70 +++++++++---------- 1 file changed, 34 insertions(+), 36 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 9793afe9a..aa9b7045c 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -312,8 +312,11 @@ def _sample_with_temperature_topk_topp( ): if temperature == 0.0: token = mx.argmax(logits, axis=-1) - probs = mx.softmax(logits.astype(mx.float32), axis=-1, precise=True) - token_prob = mx.take_along_axis(probs, token[..., None], axis=-1)[..., 0] + logits_f32 = logits.astype(mx.float32) + token_logit = mx.take_along_axis(logits_f32, token[..., None], axis=-1)[ + ..., 0 + ] + token_prob = mx.exp(token_logit - mx.logsumexp(logits_f32, axis=-1)) return token, token_prob if temperature != 1.0: @@ -321,8 +324,9 @@ def _sample_with_temperature_topk_topp( logits = self._top_k_logits(logits, top_k) logits = self._top_p_logits(logits, top_p) token = mx.random.categorical(logits.astype(mx.float32), axis=-1) - probs = mx.softmax(logits.astype(mx.float32), axis=-1, precise=True) - token_prob = mx.take_along_axis(probs, token[..., None], axis=-1)[..., 0] + logits_f32 = logits.astype(mx.float32) + token_logit = mx.take_along_axis(logits_f32, token[..., None], axis=-1)[..., 0] + token_prob = mx.exp(token_logit - mx.logsumexp(logits_f32, axis=-1)) return token, token_prob def _project_hidden(self, hidden_states: mx.array) -> mx.array: @@ -573,6 +577,7 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: recorded_prompt_time = True 
total_generated = 0 + end_length = None num_blocks = (gen_length + block_length - 1) // block_length for _ in range(num_blocks): remaining = gen_length - total_generated @@ -613,29 +618,29 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: recorded_prompt_time = True x0 = mx.where(mask_index, x0, block) confidence = mx.where(mask_index, token_probs, -mx.inf) - remaining_steps = max(1, steps - step_idx) - masked_count = int(mask_index.sum().item()) if threshold is not None: - transfer_count = masked_count + high_confidence = (confidence >= threshold) & mask_index + _, best_index = _topk(confidence, 1) + best_mask = ( + block_positions[None, None, :] == best_index[..., None] + ).any(axis=1) + transfer_mask = mx.where( + high_confidence.any(axis=1)[:, None], + high_confidence, + best_mask, + ) else: + remaining_steps = max(1, steps - step_idx) + masked_count = int(mask_index.sum().item()) transfer_count = max( 1, (masked_count + remaining_steps - 1) // remaining_steps ) - if max_transfer_per_step is not None: - transfer_count = min(transfer_count, max_transfer_per_step) - _, indices = _topk(confidence, min(transfer_count, masked_count)) - transfer_mask = ( - block_positions[None, None, :] == indices[..., None] - ).any(axis=1) - if threshold is not None: - high_confidence = (confidence >= threshold) & mask_index - if bool(high_confidence.any().item()): - transfer_mask = transfer_mask & high_confidence - else: - _, best_index = _topk(confidence, 1) - transfer_mask = ( - block_positions[None, None, :] == best_index[..., None] - ).any(axis=1) + if max_transfer_per_step is not None: + transfer_count = min(transfer_count, max_transfer_per_step) + _, indices = _topk(confidence, min(transfer_count, masked_count)) + transfer_mask = ( + block_positions[None, None, :] == indices[..., None] + ).any(axis=1) block = mx.where(transfer_mask, x0, block) if visualizer_state["active"] and bool(transfer_mask.any().item()): preview = ( @@ -660,27 +665,20 @@ def 
visualize_tokens(tokens: mx.array, force: bool = False) -> None: generated_block = block[:, :current_block_length] generated_blocks.append(generated_block) total_generated += current_block_length - if ( - eos_early_stop - and _first_token_index(generated_block[0], eos_token_ids) is not None - ): - break + if eos_early_stop: + eos_index = _first_token_index(generated_block[0], eos_token_ids) + if eos_index is not None: + end_length = total_generated - current_block_length + eos_index + 1 + break generated = ( mx.concatenate(generated_blocks, axis=1) if generated_blocks else mx.zeros((1, 0), dtype=inputs.dtype) ) - generated_ids = generated[0].tolist() - end = next( - ( - i + 1 - for i, token_id in enumerate(generated_ids) - if token_id in eos_token_ids - ), - generated.shape[1], - ) + end = end_length if end_length is not None else generated.shape[1] if visualizer_state["active"]: + generated_ids = generated[0].tolist() finish_visualizer() if tokenizer is not None: final_text = tokenizer.decode( From 9267d1a2f579d738f400749e09a3b77148063bdf Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 02:04:26 +0200 Subject: [PATCH 06/19] Preserve bf16 in Nemotron diffusion --- .../nemotron_labs_diffusion/language.py | 39 ++++++++++++------- mlx_vlm/tests/test_diffusion_models.py | 16 ++++++++ 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index aa9b7045c..6d8fc643b 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -51,19 +51,21 @@ def _wrap_text(text: str, width: int) -> str: def _make_bidirectional_mask( attention_mask: Optional[mx.array], x: mx.array ) -> Optional[mx.array]: + zero = mx.array(0, dtype=x.dtype) + neg_large = mx.array(mx.finfo(x.dtype).min, dtype=x.dtype) if attention_mask is None: return None if attention_mask.ndim == 4: if attention_mask.dtype == 
mx.bool_: return attention_mask - return mx.where(attention_mask.astype(mx.bool_), 0.0, mx.finfo(x.dtype).min) + return mx.where(attention_mask.astype(mx.bool_), zero, neg_large) if attention_mask.ndim != 2: return attention_mask if attention_mask.shape[-1] == 0 or bool(mx.all(attention_mask).item()): return None mask = attention_mask[:, None, None, :].astype(mx.bool_) - return mx.where(mask, 0.0, mx.finfo(x.dtype).min).astype(x.dtype) + return mx.where(mask, zero, neg_large) def _llama4_attention_scale( @@ -284,7 +286,8 @@ def _top_k_logits(logits: mx.array, k: Optional[int]) -> mx.array: if k is None or k <= 0: return logits values = _topk(logits, k=k, axis=-1)[0] - return mx.where(logits < values[..., -1:], mx.finfo(logits.dtype).min, logits) + neg_large = mx.array(mx.finfo(logits.dtype).min, dtype=logits.dtype) + return mx.where(logits < values[..., -1:], neg_large, logits) @staticmethod def _top_p_logits(logits: mx.array, p: Optional[float]) -> mx.array: @@ -301,7 +304,8 @@ def _top_p_logits(logits: mx.array, p: Optional[float]) -> mx.array: ) inverse_indices = mx.argsort(sorted_indices, axis=-1) mask = mx.take_along_axis(sorted_mask, inverse_indices, axis=-1) - return mx.where(mask, mx.finfo(logits.dtype).min, logits) + neg_large = mx.array(mx.finfo(logits.dtype).min, dtype=logits.dtype) + return mx.where(mask, neg_large, logits) def _sample_with_temperature_topk_topp( self, @@ -312,11 +316,8 @@ def _sample_with_temperature_topk_topp( ): if temperature == 0.0: token = mx.argmax(logits, axis=-1) - logits_f32 = logits.astype(mx.float32) - token_logit = mx.take_along_axis(logits_f32, token[..., None], axis=-1)[ - ..., 0 - ] - token_prob = mx.exp(token_logit - mx.logsumexp(logits_f32, axis=-1)) + token_logit = mx.take_along_axis(logits, token[..., None], axis=-1)[..., 0] + token_prob = mx.exp(token_logit - mx.logsumexp(logits, axis=-1)) return token, token_prob if temperature != 1.0: @@ -324,9 +325,8 @@ def _sample_with_temperature_topk_topp( logits = 
self._top_k_logits(logits, top_k) logits = self._top_p_logits(logits, top_p) token = mx.random.categorical(logits.astype(mx.float32), axis=-1) - logits_f32 = logits.astype(mx.float32) - token_logit = mx.take_along_axis(logits_f32, token[..., None], axis=-1)[..., 0] - token_prob = mx.exp(token_logit - mx.logsumexp(logits_f32, axis=-1)) + token_logit = mx.take_along_axis(logits, token[..., None], axis=-1)[..., 0] + token_prob = mx.exp(token_logit - mx.logsumexp(logits, axis=-1)) return token, token_prob def _project_hidden(self, hidden_states: mx.array) -> mx.array: @@ -617,7 +617,10 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: stats["prompt_tokens"] = float(inputs.size) recorded_prompt_time = True x0 = mx.where(mask_index, x0, block) - confidence = mx.where(mask_index, token_probs, -mx.inf) + neg_large = mx.array( + mx.finfo(token_probs.dtype).min, dtype=token_probs.dtype + ) + confidence = mx.where(mask_index, token_probs, neg_large) if threshold is not None: high_confidence = (confidence >= threshold) & mask_index _, best_index = _topk(confidence, 1) @@ -835,7 +838,10 @@ def linear_spec_generate( top_k=top_k, top_p=top_p, ) - draft_conf = mx.where(is_mask, draft_probs, -mx.inf) + neg_large = mx.array( + mx.finfo(draft_probs.dtype).min, dtype=draft_probs.dtype + ) + draft_conf = mx.where(is_mask, draft_probs, neg_large) unmask = draft_conf >= threshold if not bool(unmask.any().item()): _, best_idx = _topk(draft_conf, 1) @@ -972,7 +978,10 @@ def stream_linear_spec_generate( top_k=top_k, top_p=top_p, ) - draft_conf = mx.where(is_mask, draft_probs, -mx.inf) + neg_large = mx.array( + mx.finfo(draft_probs.dtype).min, dtype=draft_probs.dtype + ) + draft_conf = mx.where(is_mask, draft_probs, neg_large) unmask = draft_conf >= threshold if not bool(unmask.any().item()): _, best_idx = _topk(draft_conf, 1) diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index c8d9b40b7..d0030e6f7 100644 --- 
a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -254,6 +254,21 @@ def test_nemotron_labs_diffusion(self): config.num_hidden_layers, ) + model.language_model.update( + tree_map(lambda p: p.astype(mx.bfloat16), model.language_model.parameters()) + ) + bf16_outputs = model.language_model( + mx.array([[1, 2, 3]], dtype=mx.int32), + attention_mask=mx.array([[1, 1, 0]], dtype=mx.int32), + ) + self.assertEqual(bf16_outputs.logits.dtype, mx.bfloat16) + bf16_filtered = model.language_model._top_k_logits(bf16_outputs.logits, 2) + self.assertEqual(bf16_filtered.dtype, mx.bfloat16) + _, bf16_probs = model.language_model._sample_with_temperature_topk_topp( + bf16_outputs.logits + ) + self.assertEqual(bf16_probs.dtype, mx.bfloat16) + generated = model.language_model.generate( mx.array([[4]], dtype=mx.int32), block_length=4, @@ -281,6 +296,7 @@ def test_nemotron_labs_diffusion(self): block_length=2, eos_token_id=3, mask_token_id=127, + threshold=0.5, ) mx.eval(spec_generated) self.assertEqual(spec_generated.shape[0], 1) From 249185435fe03291e29a32888af761264862306f Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 02:06:22 +0200 Subject: [PATCH 07/19] Accept upstream Nemotron mode aliases --- mlx_vlm/generate/dispatch.py | 15 +++++-- .../models/nemotron_labs_diffusion/README.md | 3 ++ mlx_vlm/tests/test_diffusion_models.py | 43 +++++++++++++++++++ 3 files changed, 58 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/generate/dispatch.py b/mlx_vlm/generate/dispatch.py index f19e117e3..d6c70e96e 100644 --- a/mlx_vlm/generate/dispatch.py +++ b/mlx_vlm/generate/dispatch.py @@ -623,9 +623,17 @@ def _use_masked_diffusion_text_path(model: nn.Module, kwargs: Dict[str, Any]) -> generation_mode = kwargs.get("generation_mode") if generation_mode is not None: - return generation_mode in ("diffusion", "linear_speculative") + return generation_mode in ( + "diffusion", + "dlm", + "linear_speculative", + "linear_spec", + ) - return 
bool(kwargs.get("linear_speculative", False)) + return bool( + kwargs.get("linear_speculative", False) + or kwargs.get("linear_speculation", False) + ) def _prime_cached_prefix_rope_state( @@ -822,7 +830,8 @@ def stream_generate( skip_special_tokens=skip_special_tokens, stats=generation_stats, linear_speculative=kwargs.get("linear_speculative", False) - or kwargs.get("generation_mode") == "linear_speculative", + or kwargs.get("linear_speculation", False) + or kwargs.get("generation_mode") in ("linear_speculative", "linear_spec"), ) mx.eval(generated) total_time = time.perf_counter() - tic diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index f6dbb0eb8..a77a6d826 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -38,6 +38,7 @@ mlx_vlm.generate \ Pass `generation_mode="diffusion"` to use the masked diffusion path. Nemotron defaults to 32 denoising steps and a 0.9 transfer threshold in this mode, matching the upstream dLM example. +The upstream mode alias `generation_mode="dlm"` is also accepted. ```sh mlx_vlm.generate \ @@ -53,6 +54,7 @@ mlx_vlm.generate \ ### Linear self-speculative generation Use `--gen-kwargs` for model-specific generation options. The bundled `linear_spec_lora` adapter is loaded automatically when available. +The upstream mode alias `generation_mode="linear_spec"` is also accepted. ```sh mlx_vlm.generate \ @@ -169,4 +171,5 @@ print(result.text) - Diffusion generation uses masked block denoising. `--verbose` shows the block visualization as masks are filled. - The default diffusion schedule uses 32 denoising steps and a 0.9 confidence threshold. Lower `--max-denoising-steps` for speed experiments, but quality can degrade quickly. - Diffusion and linear self-speculative generation are exposed through `generation_mode`, for example `--gen-kwargs '{"generation_mode": "diffusion"}'`. 
+- Upstream mode names are accepted as aliases: `dlm` for diffusion and `linear_spec` for linear self-speculation. - The optional `linear_spec_lora` adapter included in the Hugging Face repo is used only during the diffusion draft phase of linear self-speculation. diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index d0030e6f7..a4c176934 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -344,6 +344,21 @@ def diffusion_generate(input_ids, **kwargs): self.assertEqual(diffusion_calls["kwargs"]["threshold"], 0.9) self.assertEqual(diffusion_results[-1].generation_tokens, 2) + diffusion_calls.clear() + list( + stream_generate( + model, + _Processor(), + prompt="ignored", + input_ids=mx.array([[4]], dtype=mx.int32), + max_tokens=2, + generation_mode="dlm", + temperature=0.0, + ) + ) + self.assertTrue(diffusion_calls["kwargs"]) + self.assertFalse(diffusion_calls["kwargs"]["linear_speculative"]) + linear_calls = {} def generate(input_ids, **kwargs): @@ -368,3 +383,31 @@ def generate(input_ids, **kwargs): self.assertGreaterEqual(len(results), 1) self.assertEqual(results[-1].generation_tokens, 2) self.assertEqual(results[-1].finish_reason, "stop") + + linear_calls.clear() + list( + stream_generate( + model, + _Processor(), + prompt="ignored", + input_ids=mx.array([[4]], dtype=mx.int32), + max_tokens=2, + generation_mode="linear_spec", + temperature=0.0, + ) + ) + self.assertTrue(linear_calls["kwargs"]["linear_speculative"]) + + linear_calls.clear() + list( + stream_generate( + model, + _Processor(), + prompt="ignored", + input_ids=mx.array([[4]], dtype=mx.int32), + max_tokens=2, + linear_speculation=True, + temperature=0.0, + ) + ) + self.assertTrue(linear_calls["kwargs"]["linear_speculative"]) From 380e25c9b0f0eacac1be0fbc3f67b1bfa3bee3ef Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 15:56:49 +0200 Subject: [PATCH 08/19] Optimize Nemotron diffusion small-block 
inference --- .../nemotron_labs_diffusion/language.py | 99 +++++++++++++++---- 1 file changed, 80 insertions(+), 19 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 6d8fc643b..0adc8364d 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -83,6 +83,8 @@ def _llama4_attention_scale( class MLP(nn.Module): def __init__(self, config: ModelConfig): super().__init__() + self.small_sequence_chunks = 7 if config.intermediate_size % 7 == 0 else 1 + self.tiny_sequence_chunks = 28 if config.intermediate_size % 28 == 0 else 1 self.gate_proj = nn.Linear( config.hidden_size, config.intermediate_size, bias=config.mlp_bias ) @@ -93,8 +95,31 @@ def __init__(self, config: ModelConfig): config.intermediate_size, config.hidden_size, bias=config.mlp_bias ) + @staticmethod + def _chunked_linear(linear: nn.Linear, x: mx.array, chunks: int) -> mx.array: + weight_chunks = mx.split(linear.weight, chunks, axis=0) + outputs = [mx.matmul(x, weight.T) for weight in weight_chunks] + bias = getattr(linear, "bias", None) + if bias is not None: + bias_chunks = mx.split(bias, chunks, axis=0) + outputs = [output + bias for output, bias in zip(outputs, bias_chunks)] + return mx.concatenate(outputs, axis=-1) + def __call__(self, x: mx.array) -> mx.array: - return self.down_proj(swiglu(self.gate_proj(x), self.up_proj(x))) + sequence_length = x.shape[-2] + chunks = 1 + if 2 <= sequence_length <= 8: + chunks = self.tiny_sequence_chunks + elif sequence_length <= 16: + chunks = self.small_sequence_chunks + + if chunks > 1: + gate = self._chunked_linear(self.gate_proj, x, chunks) + up = self._chunked_linear(self.up_proj, x, chunks) + else: + gate = self.gate_proj(x) + up = self.up_proj(x) + return self.down_proj(swiglu(gate, up)) class DraftLoRALinear(nn.Module): @@ -151,6 +176,7 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Any] = 
None, use_cache: bool = True, + attention_scale: Optional[mx.array] = None, ) -> mx.array: B, L, _ = x.shape queries = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim) @@ -164,9 +190,11 @@ def __call__( offset = cache.offset if cache is not None else 0 queries = self.rope(queries, offset=offset) keys = self.rope(keys, offset=offset) - queries = queries * _llama4_attention_scale( - self.config, L, offset, queries.dtype - ) + if attention_scale is None: + attention_scale = _llama4_attention_scale( + self.config, L, offset, queries.dtype + ) + queries = queries * attention_scale if cache is not None: if use_cache: @@ -202,9 +230,14 @@ def __call__( mask: Optional[mx.array] = None, cache: Optional[Any] = None, use_cache: bool = True, + attention_scale: Optional[mx.array] = None, ) -> mx.array: r = self.self_attn( - self.input_layernorm(x), mask=mask, cache=cache, use_cache=use_cache + self.input_layernorm(x), + mask=mask, + cache=cache, + use_cache=use_cache, + attention_scale=attention_scale, ) h = x + r return h + self.mlp(self.post_attention_layernorm(h)) @@ -237,8 +270,19 @@ def __call__( layer_mask = _make_bidirectional_mask( mask if mask is not None else attention_mask, h ) + first_cache = cache[0] if cache else None + offset = first_cache.offset if first_cache is not None else 0 + attention_scale = _llama4_attention_scale( + self.config, h.shape[1], offset, h.dtype + ) for layer, layer_cache in zip(self.layers, cache): - h = layer(h, mask=layer_mask, cache=layer_cache, use_cache=use_cache) + h = layer( + h, + mask=layer_mask, + cache=layer_cache, + use_cache=use_cache, + attention_scale=attention_scale, + ) return self.norm(h) @@ -256,6 +300,9 @@ def __init__(self, config: ModelConfig): self.diffusion_head = nn.Linear( config.hidden_size, config.vocab_size, bias=False ) + self.small_sequence_head_chunks = ( + 32 if config.vocab_size >= 4096 and config.vocab_size % 32 == 0 else 1 + ) self._linear_spec_lora_loaded = False def __call__( @@ -275,11 +322,7 
@@ def __call__( use_cache=kwargs.get("use_cache", True), use_causal_mask=kwargs.get("use_causal_mask", True), ) - if self.config.tie_word_embeddings: - out = self.model.embed_tokens.as_linear(out) - else: - out = self.diffusion_head(out) - return LanguageModelOutput(logits=out) + return LanguageModelOutput(logits=self._project_hidden(out)) @staticmethod def _top_k_logits(logits: mx.array, k: Optional[int]) -> mx.array: @@ -332,6 +375,14 @@ def _sample_with_temperature_topk_topp( def _project_hidden(self, hidden_states: mx.array) -> mx.array: if self.config.tie_word_embeddings: return self.model.embed_tokens.as_linear(hidden_states) + if self.small_sequence_head_chunks > 1 and 2 <= hidden_states.shape[-2] <= 16: + weight_chunks = mx.split( + self.diffusion_head.weight, self.small_sequence_head_chunks, axis=0 + ) + return mx.concatenate( + [mx.matmul(hidden_states, weight.T) for weight in weight_chunks], + axis=-1, + ) return self.diffusion_head(hidden_states) def _sample_tokens( @@ -558,14 +609,15 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: prompt_tic = time.perf_counter() recorded_prompt_time = False cache = self.make_cache() - prefill_logits = self( + prefill_hidden = self.model( inputs, cache=cache, use_cache=True, use_causal_mask=True, - ).logits + ) + prefill_logits = self._project_hidden(prefill_hidden[:, -1:, :])[:, -1, :] next_token = self._sample_tokens( - prefill_logits[:, -1, :], + prefill_logits, temperature=temperature, top_k=top_k, top_p=top_p, @@ -595,7 +647,8 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: ) visualize_tokens(preview, force=True) - for step_idx in range(steps): + denoise_steps = max(1, min(steps, block_length)) + for step_idx in range(denoise_steps): mask_index = block == mask_id if not bool(mask_index.any().item()): break @@ -621,7 +674,10 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: mx.finfo(token_probs.dtype).min, dtype=token_probs.dtype ) confidence = 
mx.where(mask_index, token_probs, neg_large) - if threshold is not None: + force_completion = step_idx == denoise_steps - 1 + if force_completion: + transfer_mask = mask_index + elif threshold is not None: high_confidence = (confidence >= threshold) & mask_index _, best_index = _topk(confidence, 1) best_mask = ( @@ -633,8 +689,10 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: best_mask, ) else: - remaining_steps = max(1, steps - step_idx) masked_count = int(mask_index.sum().item()) + if masked_count == 0: + break + remaining_steps = max(1, denoise_steps - step_idx) transfer_count = max( 1, (masked_count + remaining_steps - 1) // remaining_steps ) @@ -652,15 +710,18 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: else block ) visualize_tokens(preview) + if not bool((block == mask_id).any().item()): + break - output = self( + output_hidden = self.model( block, cache=cache, use_cache=True, use_causal_mask=True, ) + next_logits = self._project_hidden(output_hidden[:, -1:, :])[:, -1, :] next_token = self._sample_tokens( - output.logits[:, -1, :], + next_logits, temperature=temperature, top_k=top_k, top_p=top_p, From 1b021e95102bb592e695fabd95ca66e0950922e4 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 17:44:27 +0200 Subject: [PATCH 09/19] Optimize Nemotron diffusion small-row kernels --- .../nemotron_labs_diffusion/language.py | 319 +++++++++++++++++- 1 file changed, 316 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 0adc8364d..11112458f 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -17,6 +17,8 @@ from ..cache import KVCache from .config import ModelConfig +_HAS_METAL = mx.metal.is_available() + def _topk(x: mx.array, k: int, axis: int = -1) -> Tuple[mx.array, mx.array]: indices = mx.argpartition(-x, kth=k - 1, 
axis=axis)[..., :k] @@ -80,6 +82,280 @@ def _llama4_attention_scale( return scale.astype(dtype)[None, None, :, None] +_SMALL_ROW_GEMV_KERNEL = ( + mx.fast.metal_kernel( + name="nemotron_small_row_gemv", + input_names=["x", "weight"], + output_names=["out"], + header="#include <metal_stdlib>\nusing namespace metal;\n", + source=r""" + uint lane = thread_position_in_grid.x; + uint out_block = thread_position_in_grid.y; + uint row = thread_position_in_grid.z; + + constexpr int TM = 4; + constexpr int TN = 4; + constexpr int SN = 32; + constexpr int blockN = SN * TN; + + if (row >= R) { + return; + } + + int out_row = int(out_block * TM); + if (out_row >= O) { + return; + } + + const device T* in_vec = x + row * K; + const device T* mat = weight + out_row * K; + + float result[TM] = {0.0f, 0.0f, 0.0f, 0.0f}; + int col = int(lane * TN); + int n_iter = K / blockN; + int leftover = K - blockN * n_iter; + + for (int iter = 0; iter < n_iter; ++iter) { + float v[TN]; + for (int tn = 0; tn < TN; ++tn) { + v[tn] = static_cast<float>(in_vec[col + tn]); + } + + for (int tm = 0; tm < TM; ++tm) { + for (int tn = 0; tn < TN; ++tn) { + result[tm] += static_cast<float>(mat[tm * K + col + tn]) * v[tn]; + } + } + + col += blockN; + } + + if (leftover > 0) { + float v[TN]; + for (int tn = 0; tn < TN; ++tn) { + v[tn] = (col + tn < K) ? static_cast<float>(in_vec[col + tn]) : 0.0f; + } + + for (int tm = 0; tm < TM; ++tm) { + for (int tn = 0; tn < TN; ++tn) { + T m = (col + tn < K) ?
mat[tm * K + col + tn] : T(0); + result[tm] += static_cast<float>(m) * v[tn]; + } + } + } + + for (int tm = 0; tm < TM; ++tm) { + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + result[tm] += simd_shuffle_down(result[tm], sn); + } + } + + if (lane == 0) { + for (int tm = 0; tm < TM; ++tm) { + out[row * O + out_row + tm] = static_cast<T>(result[tm]); + } + } + """, + ) + if _HAS_METAL + else None +) + + +def _small_row_gemv_weight( + weight: mx.array, x: mx.array, max_sequence_length: int +) -> Optional[mx.array]: + if ( + _SMALL_ROW_GEMV_KERNEL is None + or x.ndim != 3 + or x.dtype != weight.dtype + or x.dtype not in (mx.bfloat16, mx.float16, mx.float32) + or not (2 <= x.shape[1] <= max_sequence_length) + ): + return None + + batch, length, in_dim = x.shape + out_dim, weight_in_dim = weight.shape + if in_dim != weight_in_dim or out_dim < 4 or out_dim % 4 != 0: + return None + + rows = batch * length + rows8 = ((rows + 7) // 8) * 8 + out = _SMALL_ROW_GEMV_KERNEL( + inputs=[x.reshape(rows, in_dim), weight], + template=[ + ("T", x.dtype), + ("K", in_dim), + ("O", out_dim), + ("R", rows), + ], + grid=(32, out_dim // 4, rows8), + threadgroup=(32, 1, 8), + output_shapes=[(rows, out_dim)], + output_dtypes=[x.dtype], + )[0] + return out.reshape(batch, length, out_dim) + + +def _small_row_linear( + linear: nn.Linear, x: mx.array, max_sequence_length: int +) -> Optional[mx.array]: + if not isinstance(linear, nn.Linear): + return None + out = _small_row_gemv_weight(linear.weight, x, max_sequence_length) + if out is None: + return None + bias = getattr(linear, "bias", None) + if bias is not None: + out = out + bias.astype(out.dtype) + return out + + +_SMALL_ROW_SWIGLU_KERNEL = ( + mx.fast.metal_kernel( + name="nemotron_small_row_swiglu", + input_names=["x", "gate_weight", "up_weight"], + output_names=["out"], + header="#include <metal_stdlib>\nusing namespace metal;\n", + source=r""" + uint lane = thread_position_in_grid.x; + uint out_block = thread_position_in_grid.y; + uint row =
thread_position_in_grid.z; + + constexpr int TM = 4; + constexpr int TN = 4; + constexpr int SN = 32; + constexpr int blockN = SN * TN; + + if (row >= R) { + return; + } + + int out_row = int(out_block * TM); + if (out_row >= O) { + return; + } + + const device T* in_vec = x + row * K; + const device T* gate_mat = gate_weight + out_row * K; + const device T* up_mat = up_weight + out_row * K; + + float gate_result[TM] = {0.0f, 0.0f, 0.0f, 0.0f}; + float up_result[TM] = {0.0f, 0.0f, 0.0f, 0.0f}; + int col = int(lane * TN); + int n_iter = K / blockN; + int leftover = K - blockN * n_iter; + + for (int iter = 0; iter < n_iter; ++iter) { + float v[TN]; + for (int tn = 0; tn < TN; ++tn) { + v[tn] = static_cast<float>(in_vec[col + tn]); + } + + for (int tm = 0; tm < TM; ++tm) { + for (int tn = 0; tn < TN; ++tn) { + float value = v[tn]; + gate_result[tm] += + static_cast<float>(gate_mat[tm * K + col + tn]) * value; + up_result[tm] += + static_cast<float>(up_mat[tm * K + col + tn]) * value; + } + } + + col += blockN; + } + + if (leftover > 0) { + float v[TN]; + for (int tn = 0; tn < TN; ++tn) { + v[tn] = (col + tn < K) ? static_cast<float>(in_vec[col + tn]) : 0.0f; + } + + for (int tm = 0; tm < TM; ++tm) { + for (int tn = 0; tn < TN; ++tn) { + float value = v[tn]; + T gate_weight_value = + (col + tn < K) ? gate_mat[tm * K + col + tn] : T(0); + T up_weight_value = + (col + tn < K) ?
up_mat[tm * K + col + tn] : T(0); + gate_result[tm] += static_cast<float>(gate_weight_value) * value; + up_result[tm] += static_cast<float>(up_weight_value) * value; + } + } + } + + for (int tm = 0; tm < TM; ++tm) { + for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) { + gate_result[tm] += simd_shuffle_down(gate_result[tm], sn); + up_result[tm] += simd_shuffle_down(up_result[tm], sn); + } + } + + if (lane == 0) { + for (int tm = 0; tm < TM; ++tm) { + float gate = static_cast<float>(static_cast<T>(gate_result[tm])); + float up = static_cast<float>(static_cast<T>(up_result[tm])); + float activated = gate / (1.0f + exp(-gate)); + out[row * O + out_row + tm] = static_cast<T>(activated * up); + } + } + """, + ) + if _HAS_METAL + else None +) + + +def _small_row_swiglu( + gate_proj: nn.Linear, + up_proj: nn.Linear, + x: mx.array, + max_sequence_length: int, +) -> Optional[mx.array]: + if ( + _SMALL_ROW_SWIGLU_KERNEL is None + or not isinstance(gate_proj, nn.Linear) + or not isinstance(up_proj, nn.Linear) + or getattr(gate_proj, "bias", None) is not None + or getattr(up_proj, "bias", None) is not None + or x.ndim != 3 + or x.dtype != gate_proj.weight.dtype + or x.dtype != up_proj.weight.dtype + or x.dtype not in (mx.bfloat16, mx.float16, mx.float32) + or not (2 <= x.shape[1] <= max_sequence_length) + ): + return None + + batch, length, in_dim = x.shape + out_dim, weight_in_dim = gate_proj.weight.shape + up_out_dim, up_weight_in_dim = up_proj.weight.shape + if ( + in_dim != weight_in_dim + or in_dim != up_weight_in_dim + or out_dim != up_out_dim + or out_dim < 4 + or out_dim % 4 != 0 + ): + return None + + rows = batch * length + rows8 = ((rows + 7) // 8) * 8 + out = _SMALL_ROW_SWIGLU_KERNEL( + inputs=[x.reshape(rows, in_dim), gate_proj.weight, up_proj.weight], + template=[ + ("T", x.dtype), + ("K", in_dim), + ("O", out_dim), + ("R", rows), + ], + grid=(32, out_dim // 4, rows8), + threadgroup=(32, 1, 8), + output_shapes=[(rows, out_dim)], + output_dtypes=[x.dtype], + )[0] + return out.reshape(batch, length,
out_dim) + + class MLP(nn.Module): def __init__(self, config: ModelConfig): super().__init__() @@ -107,6 +383,23 @@ def _chunked_linear(linear: nn.Linear, x: mx.array, chunks: int) -> mx.array: def __call__(self, x: mx.array) -> mx.array: sequence_length = x.shape[-2] + if 2 <= sequence_length <= 8: + hidden = _small_row_swiglu( + self.gate_proj, + self.up_proj, + x, + max_sequence_length=8, + ) + if hidden is not None: + down = _small_row_linear(self.down_proj, hidden, max_sequence_length=8) + if down is not None: + return down + return self.down_proj(hidden) + gate = _small_row_linear(self.gate_proj, x, max_sequence_length=8) + up = _small_row_linear(self.up_proj, x, max_sequence_length=8) + if gate is not None and up is not None: + return self.down_proj(swiglu(gate, up)) + chunks = 1 if 2 <= sequence_length <= 8: chunks = self.tiny_sequence_chunks @@ -179,7 +472,10 @@ def __call__( attention_scale: Optional[mx.array] = None, ) -> mx.array: B, L, _ = x.shape - queries = self.q_proj(x).reshape(B, L, self.num_heads, self.head_dim) + queries = _small_row_linear(self.q_proj, x, max_sequence_length=8) + if queries is None: + queries = self.q_proj(x) + queries = queries.reshape(B, L, self.num_heads, self.head_dim) keys = self.k_proj(x).reshape(B, L, self.num_key_value_heads, self.head_dim) values = self.v_proj(x).reshape(B, L, self.num_key_value_heads, self.head_dim) @@ -211,6 +507,9 @@ def __call__( queries, keys, values, cache=cache, scale=self.scale, mask=mask ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + projected = _small_row_linear(self.o_proj, output, max_sequence_length=8) + if projected is not None: + return projected return self.o_proj(output) @@ -374,7 +673,21 @@ def _sample_with_temperature_topk_topp( def _project_hidden(self, hidden_states: mx.array) -> mx.array: if self.config.tie_word_embeddings: + out = _small_row_gemv_weight( + self.model.embed_tokens.weight, + hidden_states, + max_sequence_length=8, + ) + if out is not None: + return out 
return self.model.embed_tokens.as_linear(hidden_states) + out = _small_row_linear( + self.diffusion_head, + hidden_states, + max_sequence_length=8, + ) + if out is not None: + return out if self.small_sequence_head_chunks > 1 and 2 <= hidden_states.shape[-2] <= 16: weight_chunks = mx.split( self.diffusion_head.weight, self.small_sequence_head_chunks, axis=0 @@ -650,8 +963,6 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: denoise_steps = max(1, min(steps, block_length)) for step_idx in range(denoise_steps): mask_index = block == mask_id - if not bool(mask_index.any().item()): - break logits = self( block, cache=cache, @@ -710,6 +1021,8 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: else block ) visualize_tokens(preview) + if force_completion: + break if not bool((block == mask_id).any().item()): break From 4cc91a1bfb2ad1d0315ca616f208e3ea2a7de214 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 20:18:46 +0200 Subject: [PATCH 10/19] Optimize Nemotron linear speculative decoding --- .../nemotron_labs_diffusion/language.py | 252 +++++++++++------- 1 file changed, 163 insertions(+), 89 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 11112458f..3d95bcc9c 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -359,8 +359,9 @@ def _small_row_swiglu( class MLP(nn.Module): def __init__(self, config: ModelConfig): super().__init__() - self.small_sequence_chunks = 7 if config.intermediate_size % 7 == 0 else 1 + self.small_sequence_chunks = 4 if config.intermediate_size % 4 == 0 else 1 self.tiny_sequence_chunks = 28 if config.intermediate_size % 28 == 0 else 1 + self.medium_sequence_chunks = 4 if config.intermediate_size % 4 == 0 else 1 self.gate_proj = nn.Linear( config.hidden_size, config.intermediate_size, bias=config.mlp_bias ) @@ -405,6 +406,8 @@ def 
__call__(self, x: mx.array) -> mx.array: chunks = self.tiny_sequence_chunks elif sequence_length <= 16: chunks = self.small_sequence_chunks + elif sequence_length <= 32: + chunks = self.medium_sequence_chunks if chunks > 1: gate = self._chunked_linear(self.gate_proj, x, chunks) @@ -698,6 +701,40 @@ def _project_hidden(self, hidden_states: mx.array) -> mx.array: ) return self.diffusion_head(hidden_states) + def _greedy_sample_hidden( + self, hidden_states: mx.array, return_prob: bool = False + ) -> mx.array | Tuple[mx.array, mx.array]: + logits = self._project_hidden(hidden_states) + if return_prob: + return self._sample_with_temperature_topk_topp(logits, temperature=0.0) + return self._sample_tokens(logits, temperature=0.0) + + def _sample_from_hidden( + self, + hidden_states: mx.array, + temperature: float = 0.0, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + return_prob: bool = False, + ) -> mx.array | Tuple[mx.array, mx.array]: + if temperature == 0.0: + return self._greedy_sample_hidden(hidden_states, return_prob=return_prob) + + logits = self._project_hidden(hidden_states) + if return_prob: + return self._sample_with_temperature_topk_topp( + logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + return self._sample_tokens( + logits, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + def _sample_tokens( self, logits: mx.array, @@ -928,13 +965,12 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: use_cache=True, use_causal_mask=True, ) - prefill_logits = self._project_hidden(prefill_hidden[:, -1:, :])[:, -1, :] - next_token = self._sample_tokens( - prefill_logits, + next_token = self._sample_from_hidden( + prefill_hidden[:, -1:, :], temperature=temperature, top_k=top_k, top_p=top_p, - )[:, None] + ) mx.eval(next_token) if stats is not None: stats["prompt_time"] = time.perf_counter() - prompt_tic @@ -949,8 +985,8 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: if remaining 
<= 0: break current_block_length = min(block_length, remaining) - block_positions = mx.arange(block_length) - block = mx.full((1, block_length), mask_id, dtype=inputs.dtype) + block_positions = mx.arange(current_block_length) + block = mx.full((1, current_block_length), mask_id, dtype=inputs.dtype) block[:, 0] = next_token[:, 0] if visualizer_state["active"]: preview = ( @@ -960,35 +996,49 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: ) visualize_tokens(preview, force=True) - denoise_steps = max(1, min(steps, block_length)) - for step_idx in range(denoise_steps): + denoise_steps = max(1, min(steps, current_block_length)) + denoise_range = range(denoise_steps) if current_block_length > 1 else () + for step_idx in denoise_range: mask_index = block == mask_id - logits = self( + force_completion = step_idx == denoise_steps - 1 + hidden_states = self.model( block, cache=cache, use_cache=False, use_causal_mask=False, - ).logits - x0, token_probs = self._sample_with_temperature_topk_topp( - logits, - temperature=temperature, - top_k=top_k, - top_p=top_p, ) + if force_completion: + x0 = self._sample_from_hidden( + hidden_states, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + token_probs = None + else: + x0, token_probs = self._sample_from_hidden( + hidden_states, + temperature=temperature, + top_k=top_k, + top_p=top_p, + return_prob=True, + ) if stats is not None and not recorded_prompt_time: - mx.eval(x0, token_probs) + if token_probs is None: + mx.eval(x0) + else: + mx.eval(x0, token_probs) stats["prompt_time"] = time.perf_counter() - prompt_tic stats["prompt_tokens"] = float(inputs.size) recorded_prompt_time = True x0 = mx.where(mask_index, x0, block) - neg_large = mx.array( - mx.finfo(token_probs.dtype).min, dtype=token_probs.dtype - ) - confidence = mx.where(mask_index, token_probs, neg_large) - force_completion = step_idx == denoise_steps - 1 if force_completion: transfer_mask = mask_index elif threshold is not None: + 
neg_large = mx.array( + mx.finfo(token_probs.dtype).min, dtype=token_probs.dtype + ) + confidence = mx.where(mask_index, token_probs, neg_large) high_confidence = (confidence >= threshold) & mask_index _, best_index = _topk(confidence, 1) best_mask = ( @@ -1000,6 +1050,10 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: best_mask, ) else: + neg_large = mx.array( + mx.finfo(token_probs.dtype).min, dtype=token_probs.dtype + ) + confidence = mx.where(mask_index, token_probs, neg_large) masked_count = int(mask_index.sum().item()) if masked_count == 0: break @@ -1026,27 +1080,29 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: if not bool((block == mask_id).any().item()): break + generated_block = block[:, :current_block_length] + generated_blocks.append(generated_block) + total_generated += current_block_length + if eos_early_stop: + eos_index = _first_token_index(generated_block[0], eos_token_ids) + if eos_index is not None: + end_length = total_generated - current_block_length + eos_index + 1 + break + if total_generated >= gen_length: + break + output_hidden = self.model( block, cache=cache, use_cache=True, use_causal_mask=True, ) - next_logits = self._project_hidden(output_hidden[:, -1:, :])[:, -1, :] - next_token = self._sample_tokens( - next_logits, + next_token = self._sample_from_hidden( + output_hidden[:, -1:, :], temperature=temperature, top_k=top_k, top_p=top_p, - )[:, None] - generated_block = block[:, :current_block_length] - generated_blocks.append(generated_block) - total_generated += current_block_length - if eos_early_stop: - eos_index = _first_token_index(generated_block[0], eos_token_ids) - if eos_index is not None: - end_length = total_generated - current_block_length + eos_index + 1 - break + ) generated = ( mx.concatenate(generated_blocks, axis=1) @@ -1089,28 +1145,27 @@ def ar_generate( prompt_tic = time.perf_counter() cache = self.make_cache() - prefill = self( + prefill_hidden = self.model( 
prompt_ids, cache=cache, use_cache=True, use_causal_mask=True, - ).logits - mx.eval(prefill) + ) + next_token = self._sample_from_hidden( + prefill_hidden[:, -1:, :], + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + mx.eval(next_token) if stats is not None: stats["prompt_time"] = time.perf_counter() - prompt_tic stats["prompt_tokens"] = float(prompt_ids.size) generated = [] - next_logits = prefill[:, -1, :] nfe = 0 for _ in range(max_new_tokens): nfe += 1 - next_token = self._sample_tokens( - next_logits, - temperature=temperature, - top_k=top_k, - top_p=top_p, - )[:, None] generated.append(next_token) if bool( mx.array( @@ -1120,12 +1175,18 @@ def ar_generate( .item() ): break - next_logits = self( + next_hidden = self.model( next_token, cache=cache, use_cache=True, use_causal_mask=True, - ).logits[:, -1, :] + ) + next_token = self._sample_from_hidden( + next_hidden[:, -1:, :], + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) if not generated: return prompt_ids, nfe @@ -1152,6 +1213,9 @@ def linear_spec_generate( raise ValueError("Linear speculative decoding requires batch size 1.") if block_length <= 0: raise ValueError("block_length must be a positive integer.") + # Treat block_length as a cap; acceptance is low beyond 8 and larger + # windows make each draft/verify pair slower without changing output. 
+ draft_window = min(block_length, 8) mask_token_id = ( self.config.mask_token_id if mask_token_id is None else mask_token_id @@ -1166,23 +1230,23 @@ def linear_spec_generate( prompt_tic = time.perf_counter() cache = self.make_cache() - prefill = self( + prefill_hidden = self.model( prompt_ids, cache=cache, use_cache=True, use_causal_mask=True, - ).logits - mx.eval(prefill) + ) + next_token = self._sample_from_hidden( + prefill_hidden[:, -1:, :], + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) + mx.eval(next_token) if stats is not None: stats["prompt_time"] = time.perf_counter() - prompt_tic stats["prompt_tokens"] = float(prompt_ids.size) - next_token = self._sample_tokens( - prefill[:, -1, :], - temperature=temperature, - top_k=top_k, - top_p=top_p, - )[:, None] generated = [next_token] total_generated = 1 nfe = 1 @@ -1192,25 +1256,29 @@ def linear_spec_generate( while total_generated < max_new_tokens: cache_len = cache[0].offset - block = mx.full((1, block_length), mask_token_id, dtype=prompt_ids.dtype) + current_block_length = min(draft_window, max_new_tokens - total_generated) + block = mx.full( + (1, current_block_length), mask_token_id, dtype=prompt_ids.dtype + ) block[:, 0] = next_token[:, 0] while bool((block == mask_token_id).any().item()): self.set_linear_spec_lora_enabled(True) - draft_logits = self( + draft_hidden = self.model( block, cache=cache, use_cache=False, use_causal_mask=False, - ).logits + ) nfe += 1 is_mask = block == mask_token_id if threshold > 0: - draft_tokens, draft_probs = self._sample_with_temperature_topk_topp( - draft_logits, + draft_tokens, draft_probs = self._sample_from_hidden( + draft_hidden, temperature=temperature, top_k=top_k, top_p=top_p, + return_prob=True, ) neg_large = mx.array( mx.finfo(draft_probs.dtype).min, dtype=draft_probs.dtype @@ -1225,8 +1293,8 @@ def linear_spec_generate( ) block = mx.where(unmask, draft_tokens, block) else: - draft_tokens = self._sample_tokens( - draft_logits, + draft_tokens = 
self._sample_from_hidden( + draft_hidden, temperature=temperature, top_k=top_k, top_p=top_p, @@ -1235,15 +1303,15 @@ def linear_spec_generate( break self.set_linear_spec_lora_enabled(False) - verify_logits = self( + verify_hidden = self.model( block, cache=cache, use_cache=True, use_causal_mask=True, - ).logits + ) nfe += 1 - ar_tokens = self._sample_tokens( - verify_logits, + ar_tokens = self._sample_from_hidden( + verify_hidden, temperature=temperature, top_k=top_k, top_p=top_p, @@ -1252,7 +1320,7 @@ def linear_spec_generate( ar_token_ids = ar_tokens[0].tolist() block_ids = block[0].tolist() accepted = 1 - for i in range(block_length - 1): + for i in range(current_block_length - 1): if ar_token_ids[i] == block_ids[i + 1]: accepted += 1 else: @@ -1293,6 +1361,9 @@ def stream_linear_spec_generate( raise ValueError("Linear speculative decoding requires batch size 1.") if block_length <= 0: raise ValueError("block_length must be a positive integer.") + # Treat block_length as a cap; acceptance is low beyond 8 and larger + # windows make each draft/verify pair slower without changing output. 
+ draft_window = min(block_length, 8) mask_token_id = ( self.config.mask_token_id if mask_token_id is None else mask_token_id @@ -1307,24 +1378,23 @@ def stream_linear_spec_generate( prompt_tic = time.perf_counter() cache = self.make_cache() - prefill = self( + prefill_hidden = self.model( prompt_ids, cache=cache, use_cache=True, use_causal_mask=True, - ).logits - mx.eval(prefill) - if stats is not None: - stats["prompt_time"] = time.perf_counter() - prompt_tic - stats["prompt_tokens"] = float(prompt_ids.size) - - next_token = self._sample_tokens( - prefill[:, -1, :], + ) + next_token = self._sample_from_hidden( + prefill_hidden[:, -1:, :], temperature=temperature, top_k=top_k, top_p=top_p, - )[:, None] + ) mx.eval(next_token) + if stats is not None: + stats["prompt_time"] = time.perf_counter() - prompt_tic + stats["prompt_tokens"] = float(prompt_ids.size) + yield next_token total_generated = 1 @@ -1333,24 +1403,28 @@ def stream_linear_spec_generate( while total_generated < max_new_tokens: cache_len = cache[0].offset - block = mx.full((1, block_length), mask_token_id, dtype=prompt_ids.dtype) + current_block_length = min(draft_window, max_new_tokens - total_generated) + block = mx.full( + (1, current_block_length), mask_token_id, dtype=prompt_ids.dtype + ) block[:, 0] = next_token[:, 0] while bool((block == mask_token_id).any().item()): self.set_linear_spec_lora_enabled(True) - draft_logits = self( + draft_hidden = self.model( block, cache=cache, use_cache=False, use_causal_mask=False, - ).logits + ) is_mask = block == mask_token_id if threshold > 0: - draft_tokens, draft_probs = self._sample_with_temperature_topk_topp( - draft_logits, + draft_tokens, draft_probs = self._sample_from_hidden( + draft_hidden, temperature=temperature, top_k=top_k, top_p=top_p, + return_prob=True, ) neg_large = mx.array( mx.finfo(draft_probs.dtype).min, dtype=draft_probs.dtype @@ -1365,8 +1439,8 @@ def stream_linear_spec_generate( ) block = mx.where(unmask, draft_tokens, block) else: - 
draft_tokens = self._sample_tokens( - draft_logits, + draft_tokens = self._sample_from_hidden( + draft_hidden, temperature=temperature, top_k=top_k, top_p=top_p, @@ -1375,14 +1449,14 @@ def stream_linear_spec_generate( break self.set_linear_spec_lora_enabled(False) - verify_logits = self( + verify_hidden = self.model( block, cache=cache, use_cache=True, use_causal_mask=True, - ).logits - ar_tokens = self._sample_tokens( - verify_logits, + ) + ar_tokens = self._sample_from_hidden( + verify_hidden, temperature=temperature, top_k=top_k, top_p=top_p, @@ -1391,7 +1465,7 @@ def stream_linear_spec_generate( ar_token_ids = ar_tokens[0].tolist() block_ids = block[0].tolist() accepted = 1 - for i in range(block_length - 1): + for i in range(current_block_length - 1): if ar_token_ids[i] == block_ids[i + 1]: accepted += 1 else: From b65feb520342849953068ef38438a20ed7e5f86a Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 22:14:38 +0200 Subject: [PATCH 11/19] Optimize Nemotron diffusion scoring --- .../nemotron_labs_diffusion/language.py | 137 ++++++++++++------ 1 file changed, 95 insertions(+), 42 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 3d95bcc9c..393361e29 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -1000,74 +1000,113 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: denoise_range = range(denoise_steps) if current_block_length > 1 else () for step_idx in denoise_range: mask_index = block == mask_id + masked_count = int(mask_index.sum().item()) + if masked_count == 0: + break force_completion = step_idx == denoise_steps - 1 + masked_positions = mx.sort( + mx.where(mask_index[0], block_positions, current_block_length) + )[:masked_count] hidden_states = self.model( block, cache=cache, use_cache=False, use_causal_mask=False, ) - if force_completion: - x0 = 
self._sample_from_hidden( - hidden_states, + masked_hidden_states = mx.take(hidden_states, masked_positions, axis=1) + need_confidence = not force_completion and masked_count > 1 + if need_confidence: + sampled_tokens, token_probs = self._sample_from_hidden( + masked_hidden_states, temperature=temperature, top_k=top_k, top_p=top_p, + return_prob=True, ) - token_probs = None else: - x0, token_probs = self._sample_from_hidden( - hidden_states, + sampled_tokens = self._sample_from_hidden( + masked_hidden_states, temperature=temperature, top_k=top_k, top_p=top_p, - return_prob=True, ) + token_probs = None if stats is not None and not recorded_prompt_time: if token_probs is None: - mx.eval(x0) + mx.eval(sampled_tokens) else: - mx.eval(x0, token_probs) + mx.eval(sampled_tokens, token_probs) stats["prompt_time"] = time.perf_counter() - prompt_tic stats["prompt_tokens"] = float(inputs.size) recorded_prompt_time = True - x0 = mx.where(mask_index, x0, block) - if force_completion: + + position_matches = block_positions[None, :] == masked_positions[:, None] + sampled_block = mx.sum( + mx.where( + position_matches, + sampled_tokens[0, :, None], + mx.zeros( + (masked_count, current_block_length), dtype=block.dtype + ), + ), + axis=0, + keepdims=True, + ).astype(block.dtype) + sampled_block = mx.where(mask_index, sampled_block, block) + + if force_completion or masked_count == 1: transfer_mask = mask_index elif threshold is not None: - neg_large = mx.array( - mx.finfo(token_probs.dtype).min, dtype=token_probs.dtype + sorted_indices = mx.argsort(-token_probs, axis=-1) + sorted_confidence = mx.take_along_axis( + token_probs, sorted_indices, axis=-1 ) - confidence = mx.where(mask_index, token_probs, neg_large) - high_confidence = (confidence >= threshold) & mask_index - _, best_index = _topk(confidence, 1) - best_mask = ( - block_positions[None, None, :] == best_index[..., None] - ).any(axis=1) - transfer_mask = mx.where( - high_confidence.any(axis=1)[:, None], - high_confidence, - 
best_mask, + sorted_block_positions = mx.take_along_axis( + masked_positions[None, :], sorted_indices, axis=-1 ) - else: - neg_large = mx.array( - mx.finfo(token_probs.dtype).min, dtype=token_probs.dtype + sorted_positions = mx.arange(masked_count)[None, :] + positional_threshold = 1.0 - 1.0 / ( + sorted_positions.astype(sorted_confidence.dtype) + 2.0 ) - confidence = mx.where(mask_index, token_probs, neg_large) - masked_count = int(mask_index.sum().item()) - if masked_count == 0: - break + positional_threshold = mx.where( + sorted_positions == 0, + mx.array( + mx.finfo(sorted_confidence.dtype).min, + dtype=sorted_confidence.dtype, + ), + positional_threshold, + ) + lower_bound = 0.5 if min_threshold is None else min_threshold + keep_sorted = (sorted_confidence >= threshold) | ( + (sorted_confidence >= lower_bound) + & (sorted_confidence >= positional_threshold) + ) + if max_transfer_per_step is not None: + keep_sorted = keep_sorted & ( + sorted_positions < max_transfer_per_step + ) + keep_sorted = keep_sorted | (sorted_positions == 0) + kept_positions = mx.where( + keep_sorted, sorted_block_positions, current_block_length + ) + transfer_mask = ( + block_positions[None, None, :] == kept_positions[..., None] + ).any(axis=1) & mask_index + else: remaining_steps = max(1, denoise_steps - step_idx) transfer_count = max( 1, (masked_count + remaining_steps - 1) // remaining_steps ) if max_transfer_per_step is not None: transfer_count = min(transfer_count, max_transfer_per_step) - _, indices = _topk(confidence, min(transfer_count, masked_count)) + _, indices = _topk(token_probs, min(transfer_count, masked_count)) + transfer_positions = mx.take_along_axis( + masked_positions[None, :], indices, axis=-1 + ) transfer_mask = ( - block_positions[None, None, :] == indices[..., None] - ).any(axis=1) - block = mx.where(transfer_mask, x0, block) + block_positions[None, None, :] == transfer_positions[..., None] + ).any(axis=1) & mask_index + block = mx.where(transfer_mask, 
sampled_block, block) if visualizer_state["active"] and bool(transfer_mask.any().item()): preview = ( mx.concatenate(generated_blocks + [block], axis=1) @@ -1213,9 +1252,9 @@ def linear_spec_generate( raise ValueError("Linear speculative decoding requires batch size 1.") if block_length <= 0: raise ValueError("block_length must be a positive integer.") - # Treat block_length as a cap; acceptance is low beyond 8 and larger - # windows make each draft/verify pair slower without changing output. - draft_window = min(block_length, 8) + max_draft_window = min(block_length, 32) + base_draft_window = min(max_draft_window, 8) + draft_window = base_draft_window mask_token_id = ( self.config.mask_token_id if mask_token_id is None else mask_token_id @@ -1287,7 +1326,7 @@ def linear_spec_generate( unmask = draft_conf >= threshold if not bool(unmask.any().item()): _, best_idx = _topk(draft_conf, 1) - positions = mx.arange(block_length) + positions = mx.arange(current_block_length) unmask = (positions[None, None, :] == best_idx[..., None]).any( axis=1 ) @@ -1337,6 +1376,13 @@ def linear_spec_generate( if eos_index is not None: generated[-1] = accepted_tokens[:, : eos_index + 1] break + if accepted == current_block_length and draft_window < max_draft_window: + draft_window = min(max_draft_window, draft_window * 2) + elif ( + accepted <= max(1, current_block_length // 2) + and draft_window > base_draft_window + ): + draft_window = max(base_draft_window, draft_window // 2) return ( mx.concatenate([prompt_ids, mx.concatenate(generated, axis=1)], axis=1), @@ -1361,9 +1407,9 @@ def stream_linear_spec_generate( raise ValueError("Linear speculative decoding requires batch size 1.") if block_length <= 0: raise ValueError("block_length must be a positive integer.") - # Treat block_length as a cap; acceptance is low beyond 8 and larger - # windows make each draft/verify pair slower without changing output. 
- draft_window = min(block_length, 8) + max_draft_window = min(block_length, 32) + base_draft_window = min(max_draft_window, 8) + draft_window = base_draft_window mask_token_id = ( self.config.mask_token_id if mask_token_id is None else mask_token_id @@ -1433,7 +1479,7 @@ def stream_linear_spec_generate( unmask = draft_conf >= threshold if not bool(unmask.any().item()): _, best_idx = _topk(draft_conf, 1) - positions = mx.arange(block_length) + positions = mx.arange(current_block_length) unmask = (positions[None, None, :] == best_idx[..., None]).any( axis=1 ) @@ -1484,6 +1530,13 @@ def stream_linear_spec_generate( total_generated += accepted_tokens.shape[1] if eos_index is not None: break + if accepted == current_block_length and draft_window < max_draft_window: + draft_window = min(max_draft_window, draft_window * 2) + elif ( + accepted <= max(1, current_block_length // 2) + and draft_window > base_draft_window + ): + draft_window = max(base_draft_window, draft_window // 2) def sanitize(self, weights): if self.config.tie_word_embeddings: From bfcbea035277600fce02350c0bffaff905be32f8 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 22:55:44 +0200 Subject: [PATCH 12/19] Add Nemotron diffusion sampler controls --- mlx_vlm/generate/dispatch.py | 25 ++ .../models/nemotron_labs_diffusion/README.md | 7 +- .../nemotron_labs_diffusion/language.py | 296 ++++++++++++++++-- mlx_vlm/tests/test_diffusion_models.py | 81 +++++ 4 files changed, 387 insertions(+), 22 deletions(-) diff --git a/mlx_vlm/generate/dispatch.py b/mlx_vlm/generate/dispatch.py index d6c70e96e..15a121e5c 100644 --- a/mlx_vlm/generate/dispatch.py +++ b/mlx_vlm/generate/dispatch.py @@ -808,6 +808,30 @@ def stream_generate( ) generation_stats = {} + handled_generation_kwargs = { + "max_tokens", + "temperature", + "top_p", + "top_k", + "max_denoising_steps", + "steps", + "block_length", + "threshold", + "min_threshold", + "editing_threshold", + "max_post_steps", + "num_to_transfer", + 
"max_transfer_per_step", + "stability_steps", + "linear_speculative", + "linear_speculation", + "generation_mode", + } + model_generate_kwargs = { + key: value + for key, value in kwargs.items() + if key not in handled_generation_kwargs + } tic = time.perf_counter() generated = model.language_model.generate( input_ids, @@ -832,6 +856,7 @@ def stream_generate( linear_speculative=kwargs.get("linear_speculative", False) or kwargs.get("linear_speculation", False) or kwargs.get("generation_mode") in ("linear_speculative", "linear_spec"), + **model_generate_kwargs, ) mx.eval(generated) total_time = time.perf_counter() - tic diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index a77a6d826..b2317939d 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -39,6 +39,10 @@ mlx_vlm.generate \ Pass `generation_mode="diffusion"` to use the masked diffusion path. Nemotron defaults to 32 denoising steps and a 0.9 transfer threshold in this mode, matching the upstream dLM example. The upstream mode alias `generation_mode="dlm"` is also accepted. +Sampler variants from the NVIDIA evaluation harness can be selected with `sampler`. +Supported values are `confidence_threshold_bound` (default), `native`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. +For profiling, `head_scoring="chunked"` scores masked rows without concatenating full vocabulary logits; the default remains `head_scoring="full"` because it is usually faster on MLX's optimized matmul path. +For mixed AR+dLM experiments, pass `ar_weight` between `0.0` and `1.0`; this adds an AR causal block forward during denoising and is disabled by default. 
```sh mlx_vlm.generate \ @@ -47,7 +51,7 @@ mlx_vlm.generate \ --max-tokens 256 \ --max-denoising-steps 16 \ --temperature 0.0 \ - --gen-kwargs '{"generation_mode": "diffusion"}' \ + --gen-kwargs '{"generation_mode": "diffusion", "sampler": "confidence_threshold_bound"}' \ --verbose ``` @@ -170,6 +174,7 @@ print(result.text) - AR generation should use the normal CLI without diffusion-specific arguments. - Diffusion generation uses masked block denoising. `--verbose` shows the block visualization as masks are filled. - The default diffusion schedule uses 32 denoising steps and a 0.9 confidence threshold. Lower `--max-denoising-steps` for speed experiments, but quality can degrade quickly. +- Diffusion generation records model-level stats such as `diffusion_denoise_nfe`, `diffusion_post_block_nfe`, and `diffusion_tokens_per_denoise_forward`. Use `head_scoring="chunked"` to profile the non-materializing confidence scorer. - Diffusion and linear self-speculative generation are exposed through `generation_mode`, for example `--gen-kwargs '{"generation_mode": "diffusion"}'`. - Upstream mode names are accepted as aliases: `dlm` for diffusion and `linear_spec` for linear self-speculation. - The optional `linear_spec_lora` adapter included in the Hugging Face repo is used only during the diffusion draft phase of linear self-speculation. 
diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 393361e29..a16cd3515 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -211,6 +211,69 @@ def _small_row_linear( return out +def _chunked_greedy_score_weight( + weight: mx.array, + x: mx.array, + chunks: int, + return_prob: bool, +) -> Optional[mx.array | Tuple[mx.array, mx.array]]: + if ( + chunks <= 1 + or x.ndim != 3 + or x.dtype != weight.dtype + or x.dtype not in (mx.bfloat16, mx.float16, mx.float32) + ): + return None + + _, _, in_dim = x.shape + out_dim, weight_in_dim = weight.shape + if in_dim != weight_in_dim or out_dim % chunks != 0: + return None + + best_token = None + best_logit = None + normalizer_max = None + normalizer_sum = None + offset = 0 + for weight_chunk in mx.split(weight, chunks, axis=0): + logits = mx.matmul(x, weight_chunk.T) + chunk_token = mx.argmax(logits, axis=-1).astype(mx.int32) + chunk_logit = mx.take_along_axis(logits, chunk_token[..., None], axis=-1)[ + ..., 0 + ] + chunk_token = chunk_token + offset + + if best_logit is None: + best_logit = chunk_logit + best_token = chunk_token + else: + take_chunk = chunk_logit > best_logit + best_logit = mx.where(take_chunk, chunk_logit, best_logit) + best_token = mx.where(take_chunk, chunk_token, best_token) + + if return_prob: + logits = logits.astype(mx.float32) + chunk_max = mx.max(logits, axis=-1) + chunk_sum = mx.sum(mx.exp(logits - chunk_max[..., None]), axis=-1) + if normalizer_max is None: + normalizer_max = chunk_max + normalizer_sum = chunk_sum + else: + new_max = mx.maximum(normalizer_max, chunk_max) + normalizer_sum = normalizer_sum * mx.exp( + normalizer_max - new_max + ) + chunk_sum * mx.exp(chunk_max - new_max) + normalizer_max = new_max + + offset += weight_chunk.shape[0] + + if not return_prob: + return best_token + + log_prob = best_logit.astype(mx.float32) - (normalizer_max + 
mx.log(normalizer_sum)) + return best_token, mx.exp(log_prob).astype(x.dtype) + + _SMALL_ROW_SWIGLU_KERNEL = ( mx.fast.metal_kernel( name="nemotron_small_row_swiglu", @@ -605,6 +668,9 @@ def __init__(self, config: ModelConfig): self.small_sequence_head_chunks = ( 32 if config.vocab_size >= 4096 and config.vocab_size % 32 == 0 else 1 ) + self.greedy_score_chunks = ( + 4 if config.vocab_size >= 4096 and config.vocab_size % 4 == 0 else 1 + ) self._linear_spec_lora_loaded = False def __call__( @@ -702,13 +768,37 @@ def _project_hidden(self, hidden_states: mx.array) -> mx.array: return self.diffusion_head(hidden_states) def _greedy_sample_hidden( - self, hidden_states: mx.array, return_prob: bool = False + self, + hidden_states: mx.array, + return_prob: bool = False, + chunked_score: bool = False, ) -> mx.array | Tuple[mx.array, mx.array]: + if return_prob and chunked_score: + scored = self._chunked_greedy_score_hidden(hidden_states, return_prob=True) + if scored is not None: + return scored logits = self._project_hidden(hidden_states) if return_prob: return self._sample_with_temperature_topk_topp(logits, temperature=0.0) return self._sample_tokens(logits, temperature=0.0) + def _chunked_greedy_score_hidden( + self, hidden_states: mx.array, return_prob: bool + ) -> Optional[mx.array | Tuple[mx.array, mx.array]]: + if hidden_states.shape[-2] > 32: + return None + weight = ( + self.model.embed_tokens.weight + if self.config.tie_word_embeddings + else self.diffusion_head.weight + ) + return _chunked_greedy_score_weight( + weight, + hidden_states, + chunks=self.greedy_score_chunks, + return_prob=return_prob, + ) + def _sample_from_hidden( self, hidden_states: mx.array, @@ -716,9 +806,14 @@ def _sample_from_hidden( top_k: Optional[int] = None, top_p: Optional[float] = None, return_prob: bool = False, + chunked_score: bool = False, ) -> mx.array | Tuple[mx.array, mx.array]: if temperature == 0.0: - return self._greedy_sample_hidden(hidden_states, return_prob=return_prob) + 
return self._greedy_sample_hidden( + hidden_states, + return_prob=return_prob, + chunked_score=chunked_score, + ) logits = self._project_hidden(hidden_states) if return_prob: @@ -815,6 +910,7 @@ def generate( skip_special_tokens: bool = False, stats: Optional[Dict[str, float]] = None, linear_speculative: bool = False, + **kwargs, ) -> mx.array: if inputs.shape[0] != 1: raise ValueError( @@ -848,10 +944,84 @@ def generate( if block_length <= 0: raise ValueError("block_length must be a positive integer.") steps = max(1, int(steps)) + sampler = kwargs.get("sampler") + sampling_scaling_factor = float( + kwargs.get("sampling_scaling_factor", kwargs.get("factor", 1.0)) + ) + ar_weight = kwargs.get("ar_weight", 0.0) + head_scoring = kwargs.get("head_scoring") + ar_weight = float(ar_weight) + if ar_weight < 0.0 or ar_weight > 1.0: + raise ValueError("ar_weight must be between 0.0 and 1.0.") + sampler_name = (sampler or "confidence_threshold_bound").lower() + sampler_aliases = { + "default": "confidence_threshold_bound", + "optimized": "confidence_threshold_bound", + "threshold_bound": "confidence_threshold_bound", + "bound": "confidence_threshold_bound", + "hf": "native", + "upstream": "native", + "confidence_threshold": "native", + "threshold": "native", + "threshold_ref": "confidence_threshold_ref", + "ref": "confidence_threshold_ref", + "cumulative": "cumulative_error", + } + sampler_name = sampler_aliases.get(sampler_name, sampler_name) + valid_samplers = { + "native", + "fixed", + "confidence_threshold_ref", + "confidence_threshold_bound", + "cumulative_error", + } + if sampler_name not in valid_samplers: + raise ValueError( + "Unsupported Nemotron diffusion sampler " + f"{sampler!r}. Expected one of {sorted(valid_samplers)}." 
+ ) + head_scoring_name = (head_scoring or "full").lower() + head_scoring_aliases = { + "default": "full", + "full_logits": "full", + "project_full_logits": "full", + "chunked_masked": "chunked", + } + head_scoring_name = head_scoring_aliases.get( + head_scoring_name, head_scoring_name + ) + if head_scoring_name not in {"full", "chunked"}: + raise ValueError( + "Unsupported Nemotron head_scoring " + f"{head_scoring!r}. Expected 'full' or 'chunked'." + ) + use_chunked_scoring = head_scoring_name == "chunked" if max_transfer_per_step is not None: max_transfer_per_step = min( block_length, max(1, int(max_transfer_per_step)) ) + if stats is not None: + stats["diffusion_sampler"] = sampler_name + stats["diffusion_head_scoring"] = ( + "chunked_masked" + if use_chunked_scoring and self.greedy_score_chunks > 1 + else "project_full_logits" + ) + for key in ( + "diffusion_blocks", + "diffusion_denoise_nfe", + "diffusion_post_block_nfe", + "diffusion_confidence_steps", + "diffusion_argmax_only_steps", + "diffusion_masked_rows_scored", + "diffusion_accepted_tokens", + "diffusion_mixed_ar_forwards", + ): + stats.setdefault(key, 0.0) + + def add_stat(key: str, value: float = 1.0) -> None: + if stats is not None: + stats[key] = stats.get(key, 0.0) + float(value) visualizer_state = { "active": visualize and sys.stdout.isatty(), @@ -984,6 +1154,7 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: remaining = gen_length - total_generated if remaining <= 0: break + add_stat("diffusion_blocks") current_block_length = min(block_length, remaining) block_positions = mx.arange(current_block_length) block = mx.full((1, current_block_length), mask_id, dtype=inputs.dtype) @@ -1004,6 +1175,8 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: if masked_count == 0: break force_completion = step_idx == denoise_steps - 1 + add_stat("diffusion_denoise_nfe") + add_stat("diffusion_masked_rows_scored", masked_count) masked_positions = mx.sort( 
mx.where(mask_index[0], block_positions, current_block_length) )[:masked_count] @@ -1013,8 +1186,29 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: use_cache=False, use_causal_mask=False, ) + if ar_weight > 0.0: + causal_hidden_states = self.model( + block, + cache=cache, + use_cache=False, + use_causal_mask=True, + ) + shifted_causal_hidden_states = mx.concatenate( + [hidden_states[:, :1, :], causal_hidden_states[:, :-1, :]], + axis=1, + ) + hidden_states = ( + (1.0 - ar_weight) * hidden_states + + ar_weight * shifted_causal_hidden_states + ).astype(hidden_states.dtype) + add_stat("diffusion_mixed_ar_forwards") masked_hidden_states = mx.take(hidden_states, masked_positions, axis=1) need_confidence = not force_completion and masked_count > 1 + add_stat( + "diffusion_confidence_steps" + if need_confidence + else "diffusion_argmax_only_steps" + ) if need_confidence: sampled_tokens, token_probs = self._sample_from_hidden( masked_hidden_states, @@ -1022,6 +1216,7 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: top_k=top_k, top_p=top_p, return_prob=True, + chunked_score=use_chunked_scoring, ) else: sampled_tokens = self._sample_from_hidden( @@ -1065,26 +1260,69 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: masked_positions[None, :], sorted_indices, axis=-1 ) sorted_positions = mx.arange(masked_count)[None, :] - positional_threshold = 1.0 - 1.0 / ( - sorted_positions.astype(sorted_confidence.dtype) + 2.0 - ) - positional_threshold = mx.where( - sorted_positions == 0, - mx.array( - mx.finfo(sorted_confidence.dtype).min, - dtype=sorted_confidence.dtype, - ), - positional_threshold, - ) - lower_bound = 0.5 if min_threshold is None else min_threshold - keep_sorted = (sorted_confidence >= threshold) | ( - (sorted_confidence >= lower_bound) - & (sorted_confidence >= positional_threshold) - ) + transfer_limit = masked_count + if sampler_name == "fixed": + transfer_limit = min(masked_count, max(1, 
int(num_to_transfer))) if max_transfer_per_step is not None: - keep_sorted = keep_sorted & ( - sorted_positions < max_transfer_per_step + transfer_limit = min(transfer_limit, max_transfer_per_step) + + if sampler_name == "native": + keep_sorted = sorted_confidence >= threshold + elif sampler_name == "fixed": + keep_sorted = sorted_positions < transfer_limit + keep_sorted = keep_sorted & (sorted_confidence >= threshold) + elif sampler_name == "confidence_threshold_ref": + positional_threshold = 1.0 - sampling_scaling_factor / ( + sorted_positions.astype(sorted_confidence.dtype) + 2.0 + ) + positional_threshold = mx.where( + sorted_positions == 0, + mx.array( + mx.finfo(sorted_confidence.dtype).min, + dtype=sorted_confidence.dtype, + ), + positional_threshold, + ) + criteria = (sorted_confidence >= threshold) & ( + sorted_confidence >= positional_threshold + ) + keep_sorted = mx.cumprod( + criteria.astype(mx.int32), axis=1 + ).astype(mx.bool_) + keep_sorted = keep_sorted & (sorted_positions < transfer_limit) + elif sampler_name == "cumulative_error": + confidence_floor = mx.array( + 1e-12, dtype=sorted_confidence.dtype + ) + cumulative_log_confidence = mx.cumsum( + mx.log(mx.maximum(sorted_confidence, confidence_floor)), + axis=1, ) + keep_sorted = cumulative_log_confidence >= mx.log( + mx.array(max(float(threshold), 1e-12)) + ) + keep_sorted = keep_sorted & (sorted_positions < transfer_limit) + else: + positional_threshold = 1.0 - sampling_scaling_factor / ( + sorted_positions.astype(sorted_confidence.dtype) + 2.0 + ) + positional_threshold = mx.where( + sorted_positions == 0, + mx.array( + mx.finfo(sorted_confidence.dtype).min, + dtype=sorted_confidence.dtype, + ), + positional_threshold, + ) + lower_bound = 0.5 if min_threshold is None else min_threshold + keep_sorted = (sorted_confidence >= threshold) | ( + (sorted_confidence >= lower_bound) + & (sorted_confidence >= positional_threshold) + ) + if max_transfer_per_step is not None: + keep_sorted = keep_sorted & ( 
+ sorted_positions < transfer_limit + ) keep_sorted = keep_sorted | (sorted_positions == 0) kept_positions = mx.where( keep_sorted, sorted_block_positions, current_block_length @@ -1107,6 +1345,11 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: block_positions[None, None, :] == transfer_positions[..., None] ).any(axis=1) & mask_index block = mx.where(transfer_mask, sampled_block, block) + remaining_masked_count = int((block == mask_id).sum().item()) + add_stat( + "diffusion_accepted_tokens", + masked_count - remaining_masked_count, + ) if visualizer_state["active"] and bool(transfer_mask.any().item()): preview = ( mx.concatenate(generated_blocks + [block], axis=1) @@ -1116,7 +1359,7 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: visualize_tokens(preview) if force_completion: break - if not bool((block == mask_id).any().item()): + if remaining_masked_count == 0: break generated_block = block[:, :current_block_length] @@ -1136,6 +1379,7 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: use_cache=True, use_causal_mask=True, ) + add_stat("diffusion_post_block_nfe") next_token = self._sample_from_hidden( output_hidden[:, -1:, :], temperature=temperature, @@ -1149,6 +1393,16 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: else mx.zeros((1, 0), dtype=inputs.dtype) ) end = end_length if end_length is not None else generated.shape[1] + if stats is not None: + stats["diffusion_generated_tokens"] = float(end) + stats["diffusion_total_nfe"] = stats.get( + "diffusion_denoise_nfe", 0.0 + ) + stats.get("diffusion_post_block_nfe", 0.0) + denoise_nfe = stats.get("diffusion_denoise_nfe", 0.0) + if denoise_nfe: + stats["diffusion_tokens_per_denoise_forward"] = ( + stats.get("diffusion_accepted_tokens", 0.0) / denoise_nfe + ) if visualizer_state["active"]: generated_ids = generated[0].tolist() finish_visualizer() diff --git a/mlx_vlm/tests/test_diffusion_models.py 
b/mlx_vlm/tests/test_diffusion_models.py index a4c176934..585737a52 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -227,6 +227,9 @@ def counted_call(self, *args, **kwargs): def test_nemotron_labs_diffusion(self): from mlx_vlm.models import nemotron_labs_diffusion + from mlx_vlm.models.nemotron_labs_diffusion.language import ( + _chunked_greedy_score_weight, + ) config = nemotron_labs_diffusion.ModelConfig( model_type="nemotron_labs_diffusion", @@ -269,6 +272,21 @@ def test_nemotron_labs_diffusion(self): ) self.assertEqual(bf16_probs.dtype, mx.bfloat16) + score_hidden = mx.random.normal((1, 3, 16)).astype(mx.float32) + score_weight = mx.random.normal((4096, 16)).astype(mx.float32) + score_tokens, score_probs = _chunked_greedy_score_weight( + score_weight, score_hidden, chunks=16, return_prob=True + ) + score_logits = score_hidden @ score_weight.T + ref_tokens = mx.argmax(score_logits, axis=-1).astype(mx.int32) + ref_logits = mx.take_along_axis(score_logits, ref_tokens[..., None], axis=-1)[ + ..., 0 + ] + ref_probs = mx.exp(ref_logits - mx.logsumexp(score_logits, axis=-1)) + self.assertEqual(score_tokens.tolist(), ref_tokens.tolist()) + self.assertTrue(bool(mx.allclose(score_probs, ref_probs).item())) + + diffusion_stats = {} generated = model.language_model.generate( mx.array([[4]], dtype=mx.int32), block_length=4, @@ -277,8 +295,65 @@ def test_nemotron_labs_diffusion(self): max_post_steps=4, mask_id=127, eos_id=999, + stats=diffusion_stats, ) self.assertEqual(generated.shape, (1, 8)) + self.assertEqual( + diffusion_stats["diffusion_sampler"], "confidence_threshold_bound" + ) + self.assertGreaterEqual(diffusion_stats["diffusion_denoise_nfe"], 1) + self.assertGreaterEqual(diffusion_stats["diffusion_accepted_tokens"], 1) + self.assertIn("diffusion_tokens_per_denoise_forward", diffusion_stats) + + for sampler in ( + "native", + "fixed", + "confidence_threshold_ref", + "confidence_threshold_bound", + 
"cumulative_error", + ): + with self.subTest(sampler=sampler): + sampled = model.language_model.generate( + mx.array([[4]], dtype=mx.int32), + block_length=2, + steps=2, + gen_length=2, + mask_id=127, + eos_id=999, + sampler=sampler, + threshold=0.5, + ) + self.assertEqual(sampled.shape, (1, 2)) + + with self.assertRaises(ValueError): + model.language_model.generate( + mx.array([[4]], dtype=mx.int32), + block_length=2, + gen_length=2, + mask_id=127, + eos_id=999, + sampler="bogus", + ) + + mixed = model.language_model.generate( + mx.array([[4]], dtype=mx.int32), + block_length=2, + gen_length=2, + mask_id=127, + eos_id=999, + ar_weight=0.5, + ) + self.assertEqual(mixed.shape, (1, 2)) + + with self.assertRaises(ValueError): + model.language_model.generate( + mx.array([[4]], dtype=mx.int32), + block_length=2, + gen_length=2, + mask_id=127, + eos_id=999, + ar_weight=1.5, + ) ar_generated, ar_nfe = model.language_model.ar_generate( mx.array([[4]], dtype=mx.int32), @@ -335,6 +410,9 @@ def diffusion_generate(input_ids, **kwargs): input_ids=mx.array([[4]], dtype=mx.int32), max_tokens=2, generation_mode="diffusion", + sampler="native", + sampling_scaling_factor=2.0, + head_scoring="chunked", temperature=0.0, ) ) @@ -342,6 +420,9 @@ def diffusion_generate(input_ids, **kwargs): self.assertFalse(diffusion_calls["kwargs"]["linear_speculative"]) self.assertEqual(diffusion_calls["kwargs"]["steps"], 32) self.assertEqual(diffusion_calls["kwargs"]["threshold"], 0.9) + self.assertEqual(diffusion_calls["kwargs"]["sampler"], "native") + self.assertEqual(diffusion_calls["kwargs"]["sampling_scaling_factor"], 2.0) + self.assertEqual(diffusion_calls["kwargs"]["head_scoring"], "chunked") self.assertEqual(diffusion_results[-1].generation_tokens, 2) diffusion_calls.clear() From 8e84d0c61e1d0d5f99afd871b1c6b863b125af81 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sat, 30 May 2026 23:03:33 +0200 Subject: [PATCH 13/19] Fix Nemotron quantized generation paths --- 
.../nemotron_labs_diffusion/language.py | 28 +++++++++++++++---- 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index a16cd3515..63a6a6fd4 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -437,6 +437,8 @@ def __init__(self, config: ModelConfig): @staticmethod def _chunked_linear(linear: nn.Linear, x: mx.array, chunks: int) -> mx.array: + if not isinstance(linear, nn.Linear): + return linear(x) weight_chunks = mx.split(linear.weight, chunks, axis=0) outputs = [mx.matmul(x, weight.T) for weight in weight_chunks] bias = getattr(linear, "bias", None) @@ -486,9 +488,15 @@ def __init__(self, linear: nn.Linear, rank: int, scale: float): super().__init__() self.linear = linear self.scale = scale - out_dim, in_dim = linear.weight.shape - self.lora_a = mx.zeros((in_dim, rank), dtype=linear.weight.dtype) - self.lora_b = mx.zeros((rank, out_dim), dtype=linear.weight.dtype) + out_dim, packed_or_in_dim = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + in_dim = (packed_or_in_dim * 32) // linear.bits + self.lora_dtype = linear.scales.dtype + else: + in_dim = packed_or_in_dim + self.lora_dtype = linear.weight.dtype + self.lora_a = mx.zeros((in_dim, rank), dtype=self.lora_dtype) + self.lora_b = mx.zeros((rank, out_dim), dtype=self.lora_dtype) self.enabled = False def __call__(self, x: mx.array) -> mx.array: @@ -757,7 +765,11 @@ def _project_hidden(self, hidden_states: mx.array) -> mx.array: ) if out is not None: return out - if self.small_sequence_head_chunks > 1 and 2 <= hidden_states.shape[-2] <= 16: + if ( + isinstance(self.diffusion_head, nn.Linear) + and self.small_sequence_head_chunks > 1 + and 2 <= hidden_states.shape[-2] <= 16 + ): weight_chunks = mx.split( self.diffusion_head.weight, self.small_sequence_head_chunks, axis=0 ) @@ -792,6 +804,10 @@ def 
_chunked_greedy_score_hidden( if self.config.tie_word_embeddings else self.diffusion_head.weight ) + if not self.config.tie_word_embeddings and not isinstance( + self.diffusion_head, nn.Linear + ): + return None return _chunked_greedy_score_weight( weight, hidden_states, @@ -873,8 +889,8 @@ def load_linear_spec_lora(self, adapter_path: str | Path) -> bool: key_b = f"{prefix}.lora_B.weight" if key_a not in weights or key_b not in weights: return False - o_proj.lora_a = weights[key_a].T.astype(o_proj.linear.weight.dtype) - o_proj.lora_b = weights[key_b].T.astype(o_proj.linear.weight.dtype) + o_proj.lora_a = weights[key_a].T.astype(o_proj.lora_dtype) + o_proj.lora_b = weights[key_b].T.astype(o_proj.lora_dtype) self._linear_spec_lora_loaded = True return True From 2e6d0600bbaa1139c18fab1bf119520749df43af Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 31 May 2026 01:29:39 +0200 Subject: [PATCH 14/19] Reduce Nemotron diffusion denoise syncs --- .../models/nemotron_labs_diffusion/language.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 63a6a6fd4..cea149c56 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -1185,11 +1185,11 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: denoise_steps = max(1, min(steps, current_block_length)) denoise_range = range(denoise_steps) if current_block_length > 1 else () + masked_count = max(0, current_block_length - 1) for step_idx in denoise_range: - mask_index = block == mask_id - masked_count = int(mask_index.sum().item()) if masked_count == 0: break + mask_index = block == mask_id force_completion = step_idx == denoise_steps - 1 add_stat("diffusion_denoise_nfe") add_stat("diffusion_masked_rows_scored", masked_count) @@ -1267,6 +1267,7 @@ def visualize_tokens(tokens: mx.array, force: bool = 
False) -> None: if force_completion or masked_count == 1: transfer_mask = mask_index + accepted_count = masked_count elif threshold is not None: sorted_indices = mx.argsort(-token_probs, axis=-1) sorted_confidence = mx.take_along_axis( @@ -1346,6 +1347,7 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: transfer_mask = ( block_positions[None, None, :] == kept_positions[..., None] ).any(axis=1) & mask_index + accepted_count = int(transfer_mask.sum().item()) else: remaining_steps = max(1, denoise_steps - step_idx) transfer_count = max( @@ -1360,12 +1362,9 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: transfer_mask = ( block_positions[None, None, :] == transfer_positions[..., None] ).any(axis=1) & mask_index + accepted_count = min(transfer_count, masked_count) block = mx.where(transfer_mask, sampled_block, block) - remaining_masked_count = int((block == mask_id).sum().item()) - add_stat( - "diffusion_accepted_tokens", - masked_count - remaining_masked_count, - ) + add_stat("diffusion_accepted_tokens", accepted_count) if visualizer_state["active"] and bool(transfer_mask.any().item()): preview = ( mx.concatenate(generated_blocks + [block], axis=1) @@ -1375,7 +1374,8 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: visualize_tokens(preview) if force_completion: break - if remaining_masked_count == 0: + masked_count -= accepted_count + if masked_count == 0: break generated_block = block[:, :current_block_length] From a8a54226f9c40a9e24112e45c96657c02800c33c Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 31 May 2026 02:44:09 +0200 Subject: [PATCH 15/19] Optimize Nemotron diffusion MLP matmuls --- .../nemotron_labs_diffusion/language.py | 139 ++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index cea149c56..c17a6ab5f 100644 --- 
a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -20,6 +20,99 @@ _HAS_METAL = mx.metal.is_available() +def _load_mlx_steel_gemm_header() -> Optional[str]: + include_root = Path(mx.__file__).parent / "include" + if not include_root.exists(): + return None + + seen: set[Path] = set() + + def expand(include_path: str) -> str: + path = include_root / include_path + if path in seen: + return "" + seen.add(path) + lines = [] + for line in path.read_text().splitlines(): + stripped = line.strip() + if stripped.startswith('#include "mlx/') and stripped.endswith('"'): + lines.append(expand(stripped[len('#include "') : -1])) + elif stripped != "#pragma once": + lines.append(line) + return "\n".join(lines) + + try: + return expand("mlx/backend/metal/kernels/steel/gemm/gemm.h") + except OSError: + return None + + +def _make_bm32_linear_kernel(): + header = _load_mlx_steel_gemm_header() + if header is None: + return None + + return mx.fast.metal_kernel( + name="nemotron_bm32_steel_linear_nt", + input_names=["x", "weight"], + output_names=["out"], + header=header + "\nusing namespace metal;\nusing namespace mlx::steel;\n", + source=r""" + constexpr short BM = 32; + constexpr short BN = 64; + constexpr short BK = 16; + constexpr short WM = 2; + constexpr short WN = 2; + + using gemm_kernel = GEMMKernel< + T, T, BM, BN, BK, WM, WN, + false, true, true, true, float>; + using loader_a_t = typename gemm_kernel::loader_a_t; + using loader_b_t = typename gemm_kernel::loader_b_t; + using mma_t = typename gemm_kernel::mma_t; + + const uint tid_x = threadgroup_position_in_grid.x; + const uint tid_y = threadgroup_position_in_grid.y; + const int c_row = int(tid_y) * BM; + const int c_col = int(tid_x) * BN; + + const device T* A = x + c_row * K; + const device T* B = weight + c_col * K; + device T* D = out + c_row * O + c_col; + + threadgroup T As[gemm_kernel::tgp_mem_size_a]; + threadgroup T Bs[gemm_kernel::tgp_mem_size_b]; + 
threadgroup_barrier(mem_flags::mem_none); + + thread mma_t mma_op( + simdgroup_index_in_threadgroup, + thread_index_in_simdgroup); + thread loader_a_t loader_a( + A, K, As, simdgroup_index_in_threadgroup, thread_index_in_simdgroup); + thread loader_b_t loader_b( + B, K, Bs, simdgroup_index_in_threadgroup, thread_index_in_simdgroup); + + for (int kk = 0; kk < K / BK; ++kk) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + + loader_a.next(); + loader_b.next(); + } + + threadgroup_barrier(mem_flags::mem_none); + mma_op.store_result(D, O); + """, + ) + + +_BM32_LINEAR_KERNEL = _make_bm32_linear_kernel() if _HAS_METAL else None + + def _topk(x: mx.array, k: int, axis: int = -1) -> Tuple[mx.array, mx.array]: indices = mx.argpartition(-x, kth=k - 1, axis=axis)[..., :k] values = mx.take_along_axis(x, indices, axis=axis) @@ -211,6 +304,46 @@ def _small_row_linear( return out +def _bm32_linear(linear: nn.Linear, x: mx.array) -> Optional[mx.array]: + if ( + _BM32_LINEAR_KERNEL is None + or not isinstance(linear, nn.Linear) + or x.ndim != 3 + or x.dtype != linear.weight.dtype + or x.dtype != mx.bfloat16 + or x.shape[-2] != 32 + ): + return None + + batch, length, in_dim = x.shape + rows = batch * length + out_dim, weight_in_dim = linear.weight.shape + if ( + in_dim != weight_in_dim + or rows % 32 != 0 + or out_dim % 64 != 0 + or in_dim % 16 != 0 + ): + return None + + out = _BM32_LINEAR_KERNEL( + inputs=[x.reshape(rows, in_dim), linear.weight], + template=[ + ("T", x.dtype), + ("K", in_dim), + ("O", out_dim), + ], + grid=(128 * (out_dim // 64), rows // 32, 1), + threadgroup=(128, 1, 1), + output_shapes=[(rows, out_dim)], + output_dtypes=[x.dtype], + )[0].reshape(batch, length, out_dim) + bias = getattr(linear, "bias", None) + if bias is not None: + out = out + bias.astype(out.dtype) + return out + + def _chunked_greedy_score_weight( weight: 
mx.array, x: mx.array, @@ -449,6 +582,12 @@ def _chunked_linear(linear: nn.Linear, x: mx.array, chunks: int) -> mx.array: def __call__(self, x: mx.array) -> mx.array: sequence_length = x.shape[-2] + if sequence_length == 32: + gate = _bm32_linear(self.gate_proj, x) + up = _bm32_linear(self.up_proj, x) + if gate is not None and up is not None: + return self.down_proj(swiglu(gate, up)) + if 2 <= sequence_length <= 8: hidden = _small_row_swiglu( self.gate_proj, From 5fdacab37b37a34200dce5171b27fd170ff58d65 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 31 May 2026 11:01:50 +0200 Subject: [PATCH 16/19] Improve Nemotron diffusion token acceptance --- mlx_vlm/models/nemotron_labs_diffusion/config.py | 1 + mlx_vlm/models/nemotron_labs_diffusion/language.py | 14 +++++++++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py index 4e400e2b4..0bd5adc56 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/config.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -35,6 +35,7 @@ class ModelConfig(BaseModelConfig): default_generation_mode: str = "ar" default_diffusion_steps: int = 32 default_diffusion_threshold: Optional[float] = 0.9 + default_diffusion_sampling_scaling_factor: float = 2.0 dlm_paradigm: str = "bidirectional" block_size: int = 32 dlm_loss_weight: Optional[float] = None diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index c17a6ab5f..07bc327ca 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -1100,9 +1100,6 @@ def generate( raise ValueError("block_length must be a positive integer.") steps = max(1, int(steps)) sampler = kwargs.get("sampler") - sampling_scaling_factor = float( - kwargs.get("sampling_scaling_factor", kwargs.get("factor", 1.0)) - ) ar_weight = kwargs.get("ar_weight", 0.0) 
head_scoring = kwargs.get("head_scoring") ar_weight = float(ar_weight) @@ -1135,6 +1132,17 @@ def generate( "Unsupported Nemotron diffusion sampler " f"{sampler!r}. Expected one of {sorted(valid_samplers)}." ) + sampling_scaling_factor_arg = kwargs.get( + "sampling_scaling_factor", kwargs.get("factor") + ) + if sampling_scaling_factor_arg is None: + sampling_scaling_factor = ( + getattr(self.config, "default_diffusion_sampling_scaling_factor", 2.0) + if sampler_name == "confidence_threshold_bound" + else 1.0 + ) + else: + sampling_scaling_factor = float(sampling_scaling_factor_arg) head_scoring_name = (head_scoring or "full").lower() head_scoring_aliases = { "default": "full", From bf7eb9889ad8c47f20e273ad2f6a596951e1d50e Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Sun, 31 May 2026 16:31:07 +0200 Subject: [PATCH 17/19] Match Nemotron native diffusion parity --- .../models/nemotron_labs_diffusion/README.md | 8 +- .../models/nemotron_labs_diffusion/config.py | 2 + .../nemotron_labs_diffusion/language.py | 82 +++++++++++++++++-- mlx_vlm/tests/test_diffusion_models.py | 7 +- 4 files changed, 87 insertions(+), 12 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index b2317939d..b8ff08e91 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -37,10 +37,12 @@ mlx_vlm.generate \ ### Diffusion generation Pass `generation_mode="diffusion"` to use the masked diffusion path. -Nemotron defaults to 32 denoising steps and a 0.9 transfer threshold in this mode, matching the upstream dLM example. +Nemotron defaults to the upstream/Transformers transfer policy with a 32-step denoising cap and a 0.9 transfer threshold. +This native mode also uses a Transformers-parity runtime for the denoise encoder. The upstream mode alias `generation_mode="dlm"` is also accepted. 
Sampler variants from the NVIDIA evaluation harness can be selected with `sampler`. -Supported values are `confidence_threshold_bound` (default), `native`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. +Supported values are `native` (default), `confidence_threshold_bound`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. +For faster MLX experiments, opt into the bounded sampler with `sampler="confidence_threshold_bound"`; it uses `min_threshold=0.45` by default and keeps the optimized MLX kernels. For profiling, `head_scoring="chunked"` scores masked rows without concatenating full vocabulary logits; the default remains `head_scoring="full"` because it is usually faster on MLX's optimized matmul path. For mixed AR+dLM experiments, pass `ar_weight` between `0.0` and `1.0`; this adds an AR causal block forward during denoising and is disabled by default. @@ -51,7 +53,7 @@ mlx_vlm.generate \ --max-tokens 256 \ --max-denoising-steps 16 \ --temperature 0.0 \ - --gen-kwargs '{"generation_mode": "diffusion", "sampler": "confidence_threshold_bound"}' \ + --gen-kwargs '{"generation_mode": "diffusion"}' \ --verbose ``` diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py index 0bd5adc56..03af64399 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/config.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -33,8 +33,10 @@ class ModelConfig(BaseModelConfig): attn_implementation: str = "sdpa" mask_token_id: int = 100 default_generation_mode: str = "ar" + default_diffusion_sampler: str = "native" default_diffusion_steps: int = 32 default_diffusion_threshold: Optional[float] = 0.9 + default_diffusion_min_threshold: Optional[float] = 0.45 default_diffusion_sampling_scaling_factor: float = 2.0 dlm_paradigm: str = "bidirectional" block_size: int = 32 diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 
07bc327ca..a3ddc9131 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -175,6 +175,44 @@ def _llama4_attention_scale( return scale.astype(dtype)[None, None, :, None] +def _transformers_eager_attention( + queries: mx.array, + keys: mx.array, + values: mx.array, + scale: float, + mask: Optional[mx.array | str], +) -> mx.array: + if keys.shape[1] != queries.shape[1]: + repeats = queries.shape[1] // keys.shape[1] + keys = mx.repeat(keys, repeats, axis=1) + values = mx.repeat(values, repeats, axis=1) + + scores = ( + mx.matmul( + queries.astype(mx.float32), + keys.astype(mx.float32).transpose(0, 1, 3, 2), + ) + * scale + ) + if isinstance(mask, str): + if mask != "causal": + raise ValueError(f"Unsupported attention mask {mask!r}.") + query_length = queries.shape[-2] + key_length = keys.shape[-2] + query_positions = mx.arange(query_length)[:, None] + ( + key_length - query_length + ) + key_positions = mx.arange(key_length)[None, :] + causal = key_positions <= query_positions + neg_large = mx.array(mx.finfo(scores.dtype).min, dtype=scores.dtype) + scores = mx.where(causal[None, None, :, :], scores, neg_large) + elif mask is not None: + scores = scores + mask.astype(scores.dtype) + + weights = mx.softmax(scores, axis=-1).astype(queries.dtype) + return mx.matmul(weights, values) + + _SMALL_ROW_GEMV_KERNEL = ( mx.fast.metal_kernel( name="nemotron_small_row_gemv", @@ -555,6 +593,7 @@ def _small_row_swiglu( class MLP(nn.Module): def __init__(self, config: ModelConfig): super().__init__() + self.use_bm32 = True self.small_sequence_chunks = 4 if config.intermediate_size % 4 == 0 else 1 self.tiny_sequence_chunks = 28 if config.intermediate_size % 28 == 0 else 1 self.medium_sequence_chunks = 4 if config.intermediate_size % 4 == 0 else 1 @@ -582,7 +621,7 @@ def _chunked_linear(linear: nn.Linear, x: mx.array, chunks: int) -> mx.array: def __call__(self, x: mx.array) -> mx.array: sequence_length = x.shape[-2] - 
if sequence_length == 32: + if self.use_bm32 and sequence_length == 32: gate = _bm32_linear(self.gate_proj, x) up = _bm32_linear(self.up_proj, x) if gate is not None and up is not None: @@ -650,6 +689,7 @@ class Attention(nn.Module): def __init__(self, config: ModelConfig): super().__init__() self.config = config + self.use_transformers_eager_attention = False self.num_heads = config.num_attention_heads self.num_key_value_heads = config.num_key_value_heads self.head_dim = config.head_dim @@ -716,9 +756,14 @@ def __call__( [cache.values[..., : cache.offset, :], values], axis=2 ) - output = scaled_dot_product_attention( - queries, keys, values, cache=cache, scale=self.scale, mask=mask - ) + if self.use_transformers_eager_attention: + output = _transformers_eager_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + else: + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) projected = _small_row_linear(self.o_proj, output, max_sequence_length=8) if projected is not None: @@ -820,6 +865,11 @@ def __init__(self, config: ModelConfig): ) self._linear_spec_lora_loaded = False + def _set_transformers_parity_runtime(self, enabled: bool) -> None: + for layer in self.model.layers: + layer.self_attn.use_transformers_eager_attention = enabled + layer.mlp.use_bm32 = not enabled + def __call__( self, inputs: mx.array, @@ -1105,9 +1155,10 @@ def generate( ar_weight = float(ar_weight) if ar_weight < 0.0 or ar_weight > 1.0: raise ValueError("ar_weight must be between 0.0 and 1.0.") - sampler_name = (sampler or "confidence_threshold_bound").lower() + default_sampler_name = getattr(self.config, "default_diffusion_sampler", "native") + sampler_name = (sampler or default_sampler_name).lower() sampler_aliases = { - "default": "confidence_threshold_bound", + "default": default_sampler_name.lower(), "optimized": "confidence_threshold_bound", "threshold_bound": 
"confidence_threshold_bound", "bound": "confidence_threshold_bound", @@ -1143,6 +1194,17 @@ def generate( ) else: sampling_scaling_factor = float(sampling_scaling_factor_arg) + if min_threshold is None and sampler_name == "confidence_threshold_bound": + min_threshold = getattr(self.config, "default_diffusion_min_threshold", 0.4) + if min_threshold is not None: + min_threshold = float(min_threshold) + transformers_parity_arg = kwargs.get("transformers_parity") + transformers_parity = ( + sampler_name == "native" + if transformers_parity_arg is None + else bool(transformers_parity_arg) + ) + self._set_transformers_parity_runtime(transformers_parity) head_scoring_name = (head_scoring or "full").lower() head_scoring_aliases = { "default": "full", @@ -1170,6 +1232,11 @@ def generate( if use_chunked_scoring and self.greedy_score_chunks > 1 else "project_full_logits" ) + stats["diffusion_min_threshold"] = ( + float(min_threshold) if min_threshold is not None else float("nan") + ) + stats["diffusion_sampling_scaling_factor"] = sampling_scaling_factor + stats["diffusion_transformers_parity"] = float(transformers_parity) for key in ( "diffusion_blocks", "diffusion_denoise_nfe", @@ -1591,6 +1658,7 @@ def ar_generate( stats: Optional[Dict[str, float]] = None, **kwargs, ) -> tuple[mx.array, int]: + self._set_transformers_parity_runtime(False) if eos_token_id is None: eos_token_id = self.config.eos_token_id eos_token_ids = ( @@ -1665,6 +1733,7 @@ def linear_spec_generate( stats: Optional[Dict[str, float]] = None, **kwargs, ) -> tuple[mx.array, int]: + self._set_transformers_parity_runtime(False) if prompt_ids.shape[0] != 1: raise ValueError("Linear speculative decoding requires batch size 1.") if block_length <= 0: @@ -1820,6 +1889,7 @@ def stream_linear_spec_generate( stats: Optional[Dict[str, float]] = None, **kwargs, ): + self._set_transformers_parity_runtime(False) if prompt_ids.shape[0] != 1: raise ValueError("Linear speculative decoding requires batch size 1.") if 
block_length <= 0: diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index 585737a52..3d724cd7c 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -1,3 +1,4 @@ +import math import unittest import mlx.core as mx @@ -298,9 +299,9 @@ def test_nemotron_labs_diffusion(self): stats=diffusion_stats, ) self.assertEqual(generated.shape, (1, 8)) - self.assertEqual( - diffusion_stats["diffusion_sampler"], "confidence_threshold_bound" - ) + self.assertEqual(diffusion_stats["diffusion_sampler"], "native") + self.assertTrue(math.isnan(diffusion_stats["diffusion_min_threshold"])) + self.assertEqual(diffusion_stats["diffusion_transformers_parity"], 1.0) self.assertGreaterEqual(diffusion_stats["diffusion_denoise_nfe"], 1) self.assertGreaterEqual(diffusion_stats["diffusion_accepted_tokens"], 1) self.assertIn("diffusion_tokens_per_denoise_forward", diffusion_stats) From a1c528bf9e9d11b324967a32d75230eb5d1496fa Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 1 Jun 2026 07:55:25 +0200 Subject: [PATCH 18/19] Add Streaming-dLLM sampler for Nemotron diffusion --- .../models/nemotron_labs_diffusion/README.md | 4 +- .../models/nemotron_labs_diffusion/config.py | 1 + .../nemotron_labs_diffusion/language.py | 49 ++++++++++++++++++- mlx_vlm/tests/test_diffusion_models.py | 39 +++++++++++++++ 4 files changed, 91 insertions(+), 2 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index b8ff08e91..2ab3a25b3 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -41,8 +41,10 @@ Nemotron defaults to the upstream/Transformers transfer policy with a 32-step de This native mode also uses a Transformers-parity runtime for the denoise encoder. The upstream mode alias `generation_mode="dlm"` is also accepted. 
Sampler variants from the NVIDIA evaluation harness can be selected with `sampler`. -Supported values are `native` (default), `confidence_threshold_bound`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. +Supported values are `native` (default), `confidence_threshold_bound`, `streaming_dllm`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. For faster MLX experiments, opt into the bounded sampler with `sampler="confidence_threshold_bound"`; it uses `min_threshold=0.45` by default and keeps the optimized MLX kernels. +The `streaming_dllm` sampler implements the Streaming-dLLM dynamic confidence rule: the transfer threshold is reduced as the current block fills, controlled by `confidence_alpha` (default `0.9`). +Aliases `streaming`, `dynamic_confidence`, and `context_aware` are also accepted. For profiling, `head_scoring="chunked"` scores masked rows without concatenating full vocabulary logits; the default remains `head_scoring="full"` because it is usually faster on MLX's optimized matmul path. For mixed AR+dLM experiments, pass `ar_weight` between `0.0` and `1.0`; this adds an AR causal block forward during denoising and is disabled by default. 
diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py index 03af64399..dbd67e3bd 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/config.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -38,6 +38,7 @@ class ModelConfig(BaseModelConfig): default_diffusion_threshold: Optional[float] = 0.9 default_diffusion_min_threshold: Optional[float] = 0.45 default_diffusion_sampling_scaling_factor: float = 2.0 + default_diffusion_confidence_alpha: float = 0.9 dlm_paradigm: str = "bidirectional" block_size: int = 32 dlm_loss_weight: Optional[float] = None diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index a3ddc9131..9a944ddbe 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -1162,6 +1162,11 @@ def generate( "optimized": "confidence_threshold_bound", "threshold_bound": "confidence_threshold_bound", "bound": "confidence_threshold_bound", + "streaming": "streaming_dllm", + "streaming-dllm": "streaming_dllm", + "dynamic": "streaming_dllm", + "dynamic_confidence": "streaming_dllm", + "context_aware": "streaming_dllm", "hf": "native", "upstream": "native", "confidence_threshold": "native", @@ -1176,6 +1181,7 @@ def generate( "fixed", "confidence_threshold_ref", "confidence_threshold_bound", + "streaming_dllm", "cumulative_error", } if sampler_name not in valid_samplers: @@ -1198,6 +1204,17 @@ def generate( min_threshold = getattr(self.config, "default_diffusion_min_threshold", 0.4) if min_threshold is not None: min_threshold = float(min_threshold) + confidence_alpha = float( + kwargs.get( + "confidence_alpha", + kwargs.get( + "dynamic_confidence_alpha", + getattr(self.config, "default_diffusion_confidence_alpha", 0.3), + ), + ) + ) + if confidence_alpha < 0.0 or confidence_alpha > 1.0: + raise ValueError("confidence_alpha must be between 0.0 and 1.0.") 
transformers_parity_arg = kwargs.get("transformers_parity") transformers_parity = ( sampler_name == "native" @@ -1236,6 +1253,7 @@ def generate( float(min_threshold) if min_threshold is not None else float("nan") ) stats["diffusion_sampling_scaling_factor"] = sampling_scaling_factor + stats["diffusion_confidence_alpha"] = confidence_alpha stats["diffusion_transformers_parity"] = float(transformers_parity) for key in ( "diffusion_blocks", @@ -1521,6 +1539,18 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: criteria.astype(mx.int32), axis=1 ).astype(mx.bool_) keep_sorted = keep_sorted & (sorted_positions < transfer_limit) + elif sampler_name == "streaming_dllm": + mask_ratio = masked_count / current_block_length + adjusted_threshold = float(threshold) * ( + 1.0 - confidence_alpha * (1.0 - mask_ratio) + ) + keep_sorted = sorted_confidence >= mx.array( + adjusted_threshold, dtype=sorted_confidence.dtype + ) + if max_transfer_per_step is not None: + keep_sorted = keep_sorted & ( + sorted_positions < transfer_limit + ) elif sampler_name == "cumulative_error": confidence_floor = mx.array( 1e-12, dtype=sorted_confidence.dtype @@ -1579,6 +1609,19 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: accepted_count = min(transfer_count, masked_count) block = mx.where(transfer_mask, sampled_block, block) add_stat("diffusion_accepted_tokens", accepted_count) + if ( + eos_early_stop + and sampler_name == "streaming_dllm" + and end_length is None + ): + block_ids = block[0].tolist() + for eos_position, token_id in enumerate(block_ids): + if token_id in eos_token_ids and mask_id not in block_ids[ + : eos_position + 1 + ]: + end_length = total_generated + eos_position + 1 + add_stat("diffusion_eos_early_exit") + break if visualizer_state["active"] and bool(transfer_mask.any().item()): preview = ( mx.concatenate(generated_blocks + [block], axis=1) @@ -1588,6 +1631,8 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: 
visualize_tokens(preview) if force_completion: break + if end_length is not None: + break masked_count -= accepted_count if masked_count == 0: break @@ -1595,11 +1640,13 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: generated_block = block[:, :current_block_length] generated_blocks.append(generated_block) total_generated += current_block_length - if eos_early_stop: + if eos_early_stop and end_length is None: eos_index = _first_token_index(generated_block[0], eos_token_ids) if eos_index is not None: end_length = total_generated - current_block_length + eos_index + 1 break + if end_length is not None: + break if total_generated >= gen_length: break diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index 3d724cd7c..c45d69b7d 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -226,6 +226,44 @@ def counted_call(self, *args, **kwargs): self.assertLessEqual(calls["count"], 6) + def test_nemotron_streaming_dllm_sampler(self): + from mlx_vlm.models import nemotron_labs_diffusion + + config = nemotron_labs_diffusion.ModelConfig( + vocab_size=64, + hidden_size=32, + intermediate_size=64, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=2, + head_dim=8, + max_position_embeddings=64, + eos_token_id=11, + mask_token_id=10, + ) + model = nemotron_labs_diffusion.Model(config) + stats = {} + + generated = model.language_model.generate( + mx.array([[1, 2, 3]], dtype=mx.int32), + gen_length=8, + block_length=4, + steps=4, + sampler="dynamic_confidence", + threshold=0.9, + confidence_alpha=0.4, + temperature=0.0, + mask_id=10, + eos_id=11, + stats=stats, + ) + mx.eval(generated) + + self.assertEqual(generated.shape, (1, 8)) + self.assertEqual(stats["diffusion_sampler"], "streaming_dllm") + self.assertEqual(stats["diffusion_confidence_alpha"], 0.4) + self.assertEqual(stats["diffusion_transformers_parity"], 0.0) + def 
test_nemotron_labs_diffusion(self): from mlx_vlm.models import nemotron_labs_diffusion from mlx_vlm.models.nemotron_labs_diffusion.language import ( @@ -311,6 +349,7 @@ def test_nemotron_labs_diffusion(self): "fixed", "confidence_threshold_ref", "confidence_threshold_bound", + "streaming_dllm", "cumulative_error", ): with self.subTest(sampler=sampler): From fc10bd8a298e111ec5987382b4e4efdc9f2a5e18 Mon Sep 17 00:00:00 2001 From: Prince Canuma Date: Mon, 1 Jun 2026 21:46:23 +0200 Subject: [PATCH 19/19] Remove Nemotron Streaming-dLLM sampler --- .../models/nemotron_labs_diffusion/README.md | 4 +- .../models/nemotron_labs_diffusion/config.py | 1 - .../nemotron_labs_diffusion/language.py | 51 ++----------------- mlx_vlm/tests/test_diffusion_models.py | 39 -------------- 4 files changed, 5 insertions(+), 90 deletions(-) diff --git a/mlx_vlm/models/nemotron_labs_diffusion/README.md b/mlx_vlm/models/nemotron_labs_diffusion/README.md index 2ab3a25b3..b8ff08e91 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/README.md +++ b/mlx_vlm/models/nemotron_labs_diffusion/README.md @@ -41,10 +41,8 @@ Nemotron defaults to the upstream/Transformers transfer policy with a 32-step de This native mode also uses a Transformers-parity runtime for the denoise encoder. The upstream mode alias `generation_mode="dlm"` is also accepted. Sampler variants from the NVIDIA evaluation harness can be selected with `sampler`. -Supported values are `native` (default), `confidence_threshold_bound`, `streaming_dllm`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. +Supported values are `native` (default), `confidence_threshold_bound`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. For faster MLX experiments, opt into the bounded sampler with `sampler="confidence_threshold_bound"`; it uses `min_threshold=0.45` by default and keeps the optimized MLX kernels. 
-The `streaming_dllm` sampler implements the Streaming-dLLM dynamic confidence rule: the transfer threshold is reduced as the current block fills, controlled by `confidence_alpha` (default `0.9`). -Aliases `streaming`, `dynamic_confidence`, and `context_aware` are also accepted. For profiling, `head_scoring="chunked"` scores masked rows without concatenating full vocabulary logits; the default remains `head_scoring="full"` because it is usually faster on MLX's optimized matmul path. For mixed AR+dLM experiments, pass `ar_weight` between `0.0` and `1.0`; this adds an AR causal block forward during denoising and is disabled by default. diff --git a/mlx_vlm/models/nemotron_labs_diffusion/config.py b/mlx_vlm/models/nemotron_labs_diffusion/config.py index dbd67e3bd..03af64399 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/config.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/config.py @@ -38,7 +38,6 @@ class ModelConfig(BaseModelConfig): default_diffusion_threshold: Optional[float] = 0.9 default_diffusion_min_threshold: Optional[float] = 0.45 default_diffusion_sampling_scaling_factor: float = 2.0 - default_diffusion_confidence_alpha: float = 0.9 dlm_paradigm: str = "bidirectional" block_size: int = 32 dlm_loss_weight: Optional[float] = None diff --git a/mlx_vlm/models/nemotron_labs_diffusion/language.py b/mlx_vlm/models/nemotron_labs_diffusion/language.py index 9a944ddbe..3b42b657c 100644 --- a/mlx_vlm/models/nemotron_labs_diffusion/language.py +++ b/mlx_vlm/models/nemotron_labs_diffusion/language.py @@ -199,9 +199,7 @@ def _transformers_eager_attention( raise ValueError(f"Unsupported attention mask {mask!r}.") query_length = queries.shape[-2] key_length = keys.shape[-2] - query_positions = mx.arange(query_length)[:, None] + ( - key_length - query_length - ) + query_positions = mx.arange(query_length)[:, None] + (key_length - query_length) key_positions = mx.arange(key_length)[None, :] causal = key_positions <= query_positions neg_large = 
mx.array(mx.finfo(scores.dtype).min, dtype=scores.dtype) @@ -1155,18 +1153,15 @@ def generate( ar_weight = float(ar_weight) if ar_weight < 0.0 or ar_weight > 1.0: raise ValueError("ar_weight must be between 0.0 and 1.0.") - default_sampler_name = getattr(self.config, "default_diffusion_sampler", "native") + default_sampler_name = getattr( + self.config, "default_diffusion_sampler", "native" + ) sampler_name = (sampler or default_sampler_name).lower() sampler_aliases = { "default": default_sampler_name.lower(), "optimized": "confidence_threshold_bound", "threshold_bound": "confidence_threshold_bound", "bound": "confidence_threshold_bound", - "streaming": "streaming_dllm", - "streaming-dllm": "streaming_dllm", - "dynamic": "streaming_dllm", - "dynamic_confidence": "streaming_dllm", - "context_aware": "streaming_dllm", "hf": "native", "upstream": "native", "confidence_threshold": "native", @@ -1181,7 +1176,6 @@ def generate( "fixed", "confidence_threshold_ref", "confidence_threshold_bound", - "streaming_dllm", "cumulative_error", } if sampler_name not in valid_samplers: @@ -1204,17 +1198,6 @@ def generate( min_threshold = getattr(self.config, "default_diffusion_min_threshold", 0.4) if min_threshold is not None: min_threshold = float(min_threshold) - confidence_alpha = float( - kwargs.get( - "confidence_alpha", - kwargs.get( - "dynamic_confidence_alpha", - getattr(self.config, "default_diffusion_confidence_alpha", 0.3), - ), - ) - ) - if confidence_alpha < 0.0 or confidence_alpha > 1.0: - raise ValueError("confidence_alpha must be between 0.0 and 1.0.") transformers_parity_arg = kwargs.get("transformers_parity") transformers_parity = ( sampler_name == "native" @@ -1253,7 +1236,6 @@ def generate( float(min_threshold) if min_threshold is not None else float("nan") ) stats["diffusion_sampling_scaling_factor"] = sampling_scaling_factor - stats["diffusion_confidence_alpha"] = confidence_alpha stats["diffusion_transformers_parity"] = float(transformers_parity) for key in ( 
"diffusion_blocks", @@ -1539,18 +1521,6 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: criteria.astype(mx.int32), axis=1 ).astype(mx.bool_) keep_sorted = keep_sorted & (sorted_positions < transfer_limit) - elif sampler_name == "streaming_dllm": - mask_ratio = masked_count / current_block_length - adjusted_threshold = float(threshold) * ( - 1.0 - confidence_alpha * (1.0 - mask_ratio) - ) - keep_sorted = sorted_confidence >= mx.array( - adjusted_threshold, dtype=sorted_confidence.dtype - ) - if max_transfer_per_step is not None: - keep_sorted = keep_sorted & ( - sorted_positions < transfer_limit - ) elif sampler_name == "cumulative_error": confidence_floor = mx.array( 1e-12, dtype=sorted_confidence.dtype @@ -1609,19 +1579,6 @@ def visualize_tokens(tokens: mx.array, force: bool = False) -> None: accepted_count = min(transfer_count, masked_count) block = mx.where(transfer_mask, sampled_block, block) add_stat("diffusion_accepted_tokens", accepted_count) - if ( - eos_early_stop - and sampler_name == "streaming_dllm" - and end_length is None - ): - block_ids = block[0].tolist() - for eos_position, token_id in enumerate(block_ids): - if token_id in eos_token_ids and mask_id not in block_ids[ - : eos_position + 1 - ]: - end_length = total_generated + eos_position + 1 - add_stat("diffusion_eos_early_exit") - break if visualizer_state["active"] and bool(transfer_mask.any().item()): preview = ( mx.concatenate(generated_blocks + [block], axis=1) diff --git a/mlx_vlm/tests/test_diffusion_models.py b/mlx_vlm/tests/test_diffusion_models.py index c45d69b7d..3d724cd7c 100644 --- a/mlx_vlm/tests/test_diffusion_models.py +++ b/mlx_vlm/tests/test_diffusion_models.py @@ -226,44 +226,6 @@ def counted_call(self, *args, **kwargs): self.assertLessEqual(calls["count"], 6) - def test_nemotron_streaming_dllm_sampler(self): - from mlx_vlm.models import nemotron_labs_diffusion - - config = nemotron_labs_diffusion.ModelConfig( - vocab_size=64, - hidden_size=32, - 
intermediate_size=64, - num_hidden_layers=1, - num_attention_heads=4, - num_key_value_heads=2, - head_dim=8, - max_position_embeddings=64, - eos_token_id=11, - mask_token_id=10, - ) - model = nemotron_labs_diffusion.Model(config) - stats = {} - - generated = model.language_model.generate( - mx.array([[1, 2, 3]], dtype=mx.int32), - gen_length=8, - block_length=4, - steps=4, - sampler="dynamic_confidence", - threshold=0.9, - confidence_alpha=0.4, - temperature=0.0, - mask_id=10, - eos_id=11, - stats=stats, - ) - mx.eval(generated) - - self.assertEqual(generated.shape, (1, 8)) - self.assertEqual(stats["diffusion_sampler"], "streaming_dllm") - self.assertEqual(stats["diffusion_confidence_alpha"], 0.4) - self.assertEqual(stats["diffusion_transformers_parity"], 0.0) - def test_nemotron_labs_diffusion(self): from mlx_vlm.models import nemotron_labs_diffusion from mlx_vlm.models.nemotron_labs_diffusion.language import ( @@ -349,7 +311,6 @@ def test_nemotron_labs_diffusion(self): "fixed", "confidence_threshold_ref", "confidence_threshold_bound", - "streaming_dllm", "cumulative_error", ): with self.subTest(sampler=sampler):