-
-
Notifications
You must be signed in to change notification settings - Fork 552
Add Nemotron Labs Diffusion model #1239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+2,757
−6
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
7ef792f
Add Nemotron Labs Diffusion model
Blaizzy ae05744
Tune Nemotron diffusion mode defaults
Blaizzy d674202
Restore quality-first Nemotron diffusion defaults
Blaizzy 50212cd
Align Nemotron diffusion threshold transfer
Blaizzy 0f408cc
Reduce Nemotron diffusion materialization
Blaizzy 9267d1a
Preserve bf16 in Nemotron diffusion
Blaizzy 2491854
Accept upstream Nemotron mode aliases
Blaizzy 380e25c
Optimize Nemotron diffusion small-block inference
Blaizzy 1b021e9
Optimize Nemotron diffusion small-row kernels
Blaizzy 4cc91a1
Optimize Nemotron linear speculative decoding
Blaizzy 384a12a
Merge branch 'main' into pc/nemotron-labs-diffusion
Blaizzy b65feb5
Optimize Nemotron diffusion scoring
Blaizzy bfcbea0
Add Nemotron diffusion sampler controls
Blaizzy 8e84d0c
Fix Nemotron quantized generation paths
Blaizzy 2e6d060
Reduce Nemotron diffusion denoise syncs
Blaizzy a8a5422
Optimize Nemotron diffusion MLP matmuls
Blaizzy 5fdacab
Improve Nemotron diffusion token acceptance
Blaizzy bf7eb98
Match Nemotron native diffusion parity
Blaizzy a1c528b
Add Streaming-dLLM sampler for Nemotron diffusion
Blaizzy fc10bd8
Remove Nemotron Streaming-dLLM sampler
Blaizzy 835ee17
Merge branch 'main' into pc/nemotron-labs-diffusion
Blaizzy File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -379,6 +379,13 @@ def parse_arguments(): | |
| help="Extra processor kwargs as JSON. " | ||
| 'Example: --processor-kwargs \'{"cropping": false, "max_patches": 3}\'', | ||
| ) | ||
| parser.add_argument( | ||
| "--gen-kwargs", | ||
| type=json.loads, | ||
| default={}, | ||
| help="Extra generation kwargs as JSON. " | ||
| "Example: --gen-kwargs '{\"linear_speculative\": true}'", | ||
| ) | ||
| parser.add_argument( | ||
| "--prefill-step-size", | ||
| type=int, | ||
|
|
@@ -606,6 +613,29 @@ def is_masked_diffusion_text_model(model: nn.Module) -> bool: | |
| return getattr(config, "mask_token_id", None) is not None | ||
|
|
||
|
|
||
| def _use_masked_diffusion_text_path(model: nn.Module, kwargs: Dict[str, Any]) -> bool: | ||
| if not is_masked_diffusion_text_model(model): | ||
| return False | ||
|
|
||
| config = getattr(model, "config", None) | ||
| if getattr(config, "default_generation_mode", None) != "ar": | ||
| return True | ||
|
|
||
| generation_mode = kwargs.get("generation_mode") | ||
| if generation_mode is not None: | ||
| return generation_mode in ( | ||
| "diffusion", | ||
| "dlm", | ||
| "linear_speculative", | ||
| "linear_spec", | ||
| ) | ||
|
|
||
| return bool( | ||
| kwargs.get("linear_speculative", False) | ||
| or kwargs.get("linear_speculation", False) | ||
| ) | ||
|
|
||
|
|
||
| def _prime_cached_prefix_rope_state( | ||
| model: nn.Module, | ||
| full_input_ids: mx.array, | ||
|
|
@@ -738,7 +768,7 @@ def stream_generate( | |
| } | ||
| kwargs.update(data_kwargs) | ||
|
|
||
| if is_masked_diffusion_text_model(model): | ||
| if _use_masked_diffusion_text_path(model, kwargs): | ||
| if image is not None or audio is not None or video is not None: | ||
| raise ValueError("Diffusion text generation models are text-only.") | ||
|
|
||
|
|
@@ -748,14 +778,24 @@ def stream_generate( | |
| top_k = kwargs.get("top_k", DEFAULT_TOP_K) | ||
| max_denoising_steps = kwargs.get("max_denoising_steps") | ||
| if max_denoising_steps is None: | ||
| max_denoising_steps = kwargs.get("steps", 32) | ||
| config = getattr(model, "config", None) | ||
| max_denoising_steps = kwargs.get( | ||
| "steps", getattr(config, "default_diffusion_steps", 32) | ||
| ) | ||
| num_to_transfer = kwargs.get( | ||
| "num_to_transfer", DEFAULT_MASKED_DIFFUSION_NUM_TO_TRANSFER | ||
| ) | ||
| threshold = kwargs.get("threshold", DEFAULT_MASKED_DIFFUSION_THRESHOLD) | ||
| min_threshold = kwargs.get( | ||
| "min_threshold", DEFAULT_MASKED_DIFFUSION_MIN_THRESHOLD | ||
| ) | ||
| config = getattr(model, "config", None) | ||
| if getattr(config, "default_generation_mode", None) == "ar": | ||
| threshold = kwargs.get( | ||
| "threshold", getattr(config, "default_diffusion_threshold", None) | ||
| ) | ||
| min_threshold = kwargs.get("min_threshold") | ||
| else: | ||
| threshold = kwargs.get("threshold", DEFAULT_MASKED_DIFFUSION_THRESHOLD) | ||
| min_threshold = kwargs.get( | ||
| "min_threshold", DEFAULT_MASKED_DIFFUSION_MIN_THRESHOLD | ||
| ) | ||
| editing_threshold = kwargs.get( | ||
| "editing_threshold", DEFAULT_MASKED_DIFFUSION_EDITING_THRESHOLD | ||
| ) | ||
|
|
@@ -768,6 +808,30 @@ def stream_generate( | |
| ) | ||
|
|
||
| generation_stats = {} | ||
| handled_generation_kwargs = { | ||
| "max_tokens", | ||
| "temperature", | ||
| "top_p", | ||
| "top_k", | ||
| "max_denoising_steps", | ||
| "steps", | ||
| "block_length", | ||
| "threshold", | ||
| "min_threshold", | ||
| "editing_threshold", | ||
| "max_post_steps", | ||
| "num_to_transfer", | ||
| "max_transfer_per_step", | ||
| "stability_steps", | ||
| "linear_speculative", | ||
| "linear_speculation", | ||
| "generation_mode", | ||
| } | ||
| model_generate_kwargs = { | ||
| key: value | ||
| for key, value in kwargs.items() | ||
| if key not in handled_generation_kwargs | ||
| } | ||
| tic = time.perf_counter() | ||
| generated = model.language_model.generate( | ||
| input_ids, | ||
|
|
@@ -789,6 +853,10 @@ def stream_generate( | |
| tokenizer=tokenizer, | ||
| skip_special_tokens=skip_special_tokens, | ||
| stats=generation_stats, | ||
| linear_speculative=kwargs.get("linear_speculative", False) | ||
| or kwargs.get("linear_speculation", False) | ||
| or kwargs.get("generation_mode") in ("linear_speculative", "linear_spec"), | ||
| **model_generate_kwargs, | ||
|
Comment on lines
+856
to
+859
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To fix |
||
| ) | ||
| mx.eval(generated) | ||
| total_time = time.perf_counter() - tic | ||
|
|
@@ -1326,6 +1394,7 @@ def main(): | |
| "editing_threshold": None, | ||
| "max_post_steps": None, | ||
| "stability_steps": None, | ||
| "gen_kwargs": {}, | ||
| } | ||
| for name, default in diffusion_arg_defaults.items(): | ||
| if not hasattr(args, name): | ||
|
|
@@ -1411,6 +1480,10 @@ def main(): | |
| if args.processor_kwargs: | ||
| kwargs.update(args.processor_kwargs) | ||
|
|
||
| # Add generation kwargs from JSON | ||
| if args.gen_kwargs: | ||
| kwargs.update(args.gen_kwargs) | ||
|
|
||
| # Add thinking kwargs | ||
| kwargs["enable_thinking"] = args.enable_thinking | ||
| if args.thinking_budget is not None: | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,182 @@ | ||
| # Nemotron Labs Diffusion | ||
|
|
||
| Nemotron Labs Diffusion is a text-only diffusion language model from NVIDIA. The same checkpoint supports autoregressive decoding, block diffusion decoding, and linear self-speculative decoding. | ||
|
|
||
| Capabilities: | ||
| - **Text generation** - normal autoregressive generation through the standard `mlx_vlm.generate` path | ||
| - **Diffusion generation** - masked block denoising with live visualization when `--verbose` is enabled | ||
| - **Linear self-speculation** - diffusion drafting with autoregressive verification using `--gen-kwargs` | ||
| - **Thinking mode** - chat-template support through `--enable-thinking` | ||
|
|
||
| ## Model | ||
|
|
||
| | Model | Type | Params | Context | Modalities | | ||
| |---|---|---:|---:|---| | ||
| | `nvidia/Nemotron-Labs-Diffusion-8B` | Dense diffusion LM | 8B | 262k | Text | | ||
|
|
||
| ## Install | ||
|
|
||
| ```sh | ||
| pip install -U mlx-vlm | ||
| ``` | ||
|
|
||
| ## CLI | ||
|
|
||
| ### Autoregressive generation | ||
|
|
||
| By default, Nemotron uses the normal autoregressive generation path. | ||
|
|
||
| ```sh | ||
| mlx_vlm.generate \ | ||
| --model nvidia/Nemotron-Labs-Diffusion-8B \ | ||
| --prompt "Write a short story about a clockmaker." \ | ||
| --max-tokens 256 \ | ||
| --temperature 0.0 | ||
| ``` | ||
|
|
||
| ### Diffusion generation | ||
|
|
||
| Pass `generation_mode="diffusion"` to use the masked diffusion path. | ||
| Nemotron defaults to the upstream/Transformers transfer policy with a 32-step denoising cap and a 0.9 transfer threshold. | ||
| This native mode also uses a Transformers-parity runtime for the denoise encoder. | ||
| The upstream mode alias `generation_mode="dlm"` is also accepted. | ||
| Sampler variants from the NVIDIA evaluation harness can be selected with `sampler`. | ||
| Supported values are `native` (default), `confidence_threshold_bound`, `fixed`, `confidence_threshold_ref`, and `cumulative_error`. | ||
| For faster MLX experiments, opt into the bounded sampler with `sampler="confidence_threshold_bound"`; it uses `min_threshold=0.45` by default and keeps the optimized MLX kernels. | ||
| For profiling, `head_scoring="chunked"` scores masked rows without concatenating full vocabulary logits; the default remains `head_scoring="full"` because it is usually faster on MLX's optimized matmul path. | ||
| For mixed AR+dLM experiments, pass `ar_weight` between `0.0` and `1.0`; this adds an AR causal block forward during denoising and is disabled by default. | ||
|
|
||
| ```sh | ||
| mlx_vlm.generate \ | ||
| --model nvidia/Nemotron-Labs-Diffusion-8B \ | ||
| --prompt "Write a short story about a clockmaker." \ | ||
| --max-tokens 256 \ | ||
| --max-denoising-steps 16 \ | ||
| --temperature 0.0 \ | ||
| --gen-kwargs '{"generation_mode": "diffusion"}' \ | ||
| --verbose | ||
| ``` | ||
|
|
||
| ### Linear self-speculative generation | ||
|
|
||
| Use `--gen-kwargs` for model-specific generation options. The bundled `linear_spec_lora` adapter is loaded automatically when available. | ||
| The upstream mode alias `generation_mode="linear_spec"` is also accepted. | ||
|
|
||
| ```sh | ||
| mlx_vlm.generate \ | ||
| --model nvidia/Nemotron-Labs-Diffusion-8B \ | ||
| --prompt "Write a short story about a clockmaker." \ | ||
| --max-tokens 256 \ | ||
| --temperature 0.0 \ | ||
| --gen-kwargs '{"generation_mode": "linear_speculative"}' | ||
| ``` | ||
|
|
||
| ### Thinking mode | ||
|
|
||
| ```sh | ||
| mlx_vlm.generate \ | ||
| --model nvidia/Nemotron-Labs-Diffusion-8B \ | ||
| --prompt "Solve this step by step: if a train travels 180 km in 2.5 hours, what is its average speed?" \ | ||
| --enable-thinking \ | ||
| --max-tokens 512 \ | ||
| --temperature 0.0 | ||
| ``` | ||
|
|
||
| ## Python | ||
|
|
||
| ### Basic text generation | ||
|
|
||
| ```python | ||
| from mlx_vlm import generate, load | ||
| from mlx_vlm.prompt_utils import apply_chat_template | ||
|
|
||
| model, processor = load("nvidia/Nemotron-Labs-Diffusion-8B") | ||
|
|
||
| prompt = apply_chat_template( | ||
| processor, | ||
| model.config, | ||
| "Write a short story about a clockmaker.", | ||
| ) | ||
|
|
||
| result = generate( | ||
| model=model, | ||
| processor=processor, | ||
| prompt=prompt, | ||
| max_tokens=256, | ||
| temperature=0.0, | ||
| ) | ||
| print(result.text) | ||
| ``` | ||
|
|
||
| ### Diffusion generation | ||
|
|
||
| ```python | ||
| from mlx_vlm import generate, load | ||
| from mlx_vlm.prompt_utils import apply_chat_template | ||
|
|
||
| model, processor = load("nvidia/Nemotron-Labs-Diffusion-8B") | ||
|
|
||
| prompt = apply_chat_template( | ||
| processor, | ||
| model.config, | ||
| "Write a short story about a clockmaker.", | ||
| ) | ||
|
|
||
| result = generate( | ||
| model=model, | ||
| processor=processor, | ||
| prompt=prompt, | ||
| max_tokens=256, | ||
| max_denoising_steps=16, | ||
| temperature=0.0, | ||
| generation_mode="diffusion", | ||
| ) | ||
| print(result.text) | ||
| ``` | ||
|
|
||
| ### Linear self-speculative generation | ||
|
|
||
| ```python | ||
| from mlx_vlm import generate, load | ||
| from mlx_vlm.prompt_utils import apply_chat_template | ||
|
|
||
| model, processor = load("nvidia/Nemotron-Labs-Diffusion-8B") | ||
|
|
||
| prompt = apply_chat_template( | ||
| processor, | ||
| model.config, | ||
| "Write a short story about a clockmaker.", | ||
| ) | ||
|
|
||
| result = generate( | ||
| model=model, | ||
| processor=processor, | ||
| prompt=prompt, | ||
| max_tokens=256, | ||
| temperature=0.0, | ||
| generation_mode="linear_speculative", | ||
| ) | ||
| print(result.text) | ||
| ``` | ||
|
|
||
| ## Architecture | ||
|
|
||
| - **Backbone** - dense decoder-only Ministral-style transformer | ||
| - **Layers** - 34 transformer layers | ||
| - **Hidden size** - 4096 | ||
| - **Attention** - 32 query heads, 8 KV heads, 128 head dimension | ||
| - **MLP** - SwiGLU with 14336 intermediate size | ||
| - **RoPE** - long-context YaRN/Llama 4-style scaling parameters from the checkpoint | ||
| - **Diffusion head** - untied output projection over the 131072-token vocabulary | ||
| - **Mask token** - `mask_token_id=100` | ||
|
|
||
| ## Notes | ||
|
|
||
| - The model is text-only. Image, audio, and video inputs are not supported. | ||
| - AR generation should use the normal CLI without diffusion-specific arguments. | ||
| - Diffusion generation uses masked block denoising. `--verbose` shows the block visualization as masks are filled. | ||
| - The default diffusion schedule uses 32 denoising steps and a 0.9 confidence threshold. Lower `--max-denoising-steps` for speed experiments, but quality can degrade quickly. | ||
| - Diffusion generation records model-level stats such as `diffusion_denoise_nfe`, `diffusion_post_block_nfe`, and `diffusion_tokens_per_denoise_forward`. Use `head_scoring="chunked"` to profile the non-materializing confidence scorer. | ||
| - Diffusion and linear self-speculative generation are exposed through `generation_mode`, for example `--gen-kwargs '{"generation_mode": "diffusion"}'`. | ||
| - Upstream mode names are accepted as aliases: `dlm` for diffusion and `linear_spec` for linear self-speculation. | ||
| - The optional `linear_spec_lora` adapter included in the Hugging Face repo is used only during the diffusion draft phase of linear self-speculation. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| from .config import ModelConfig | ||
| from .nemotron_labs_diffusion import Model | ||
|
|
||
| __all__ = ["Model", "ModelConfig"] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,64 @@ | ||
| from dataclasses import dataclass | ||
| from typing import Any, Dict, Optional, Union | ||
|
|
||
| from ..base import BaseModelConfig | ||
|
|
||
|
|
||
| @dataclass | ||
| class ModelConfig(BaseModelConfig): | ||
| model_type: str = "nemotron_labs_diffusion" | ||
| vocab_size: int = 131072 | ||
| hidden_size: int = 4096 | ||
| intermediate_size: int = 14336 | ||
| num_hidden_layers: int = 34 | ||
| num_attention_heads: int = 32 | ||
| num_key_value_heads: int = 8 | ||
| head_dim: Optional[int] = 128 | ||
| hidden_act: str = "silu" | ||
| max_position_embeddings: int = 262144 | ||
| initializer_range: float = 0.02 | ||
| rms_norm_eps: float = 1e-5 | ||
| use_cache: bool = False | ||
| pad_token_id: Optional[int] = None | ||
| bos_token_id: Optional[int] = 1 | ||
| eos_token_id: Optional[Union[int, list[int]]] = 11 | ||
| tie_word_embeddings: bool = False | ||
| rope_theta: float = 1000000.0 | ||
| rope_parameters: Optional[Dict[str, Any]] = None | ||
| rope_scaling: Optional[Dict[str, Any]] = None | ||
| attention_bias: bool = False | ||
| attention_dropout: float = 0.0 | ||
| mlp_bias: bool = False | ||
| sliding_window: Optional[int] = None | ||
| attn_implementation: str = "sdpa" | ||
| mask_token_id: int = 100 | ||
| default_generation_mode: str = "ar" | ||
| default_diffusion_sampler: str = "native" | ||
| default_diffusion_steps: int = 32 | ||
| default_diffusion_threshold: Optional[float] = 0.9 | ||
| default_diffusion_min_threshold: Optional[float] = 0.45 | ||
| default_diffusion_sampling_scaling_factor: float = 2.0 | ||
| dlm_paradigm: str = "bidirectional" | ||
| block_size: int = 32 | ||
| dlm_loss_weight: Optional[float] = None | ||
| ar_loss_weight: float = 1.0 | ||
| dp_varying_mask_ratio: bool = False | ||
|
|
||
| def __post_init__(self): | ||
| if self.head_dim is None: | ||
| self.head_dim = self.hidden_size // self.num_attention_heads | ||
|
|
||
| rope_parameters = ( | ||
| dict(self.rope_parameters) | ||
| if self.rope_parameters is not None | ||
| else ( | ||
| dict(self.rope_scaling) | ||
| if self.rope_scaling is not None | ||
| else {"rope_type": "default", "rope_theta": self.rope_theta} | ||
| ) | ||
| ) | ||
| rope_parameters.setdefault("rope_type", "default") | ||
| rope_parameters.setdefault("rope_theta", self.rope_theta) | ||
| self.rope_parameters = rope_parameters | ||
| self.rope_scaling = rope_parameters | ||
| self.rope_theta = float(rope_parameters.get("rope_theta", self.rope_theta)) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To fix