Add NVIDIA LocateAnything-3B (MoonViT + Qwen2.5, AR + Parallel Box Decoding)#1242
Open
beshkenadze wants to merge 16 commits into
Conversation
lucasnewman
reviewed
May 30, 2026
lucasnewman
reviewed
May 30, 2026
beshkenadze
added a commit
to beshkenadze/mlx-vlm
that referenced
this pull request
May 30, 2026
…atch PBD hook
- chat_template.json is loaded from the model repo at runtime (base.load_chat_template / processor.from_pretrained); the bundled copy was unused. nvidia/LocateAnything-3B and the mlx-community quant repos all ship it. (@lucasnewman)
- Remove the model-specific PBD branch from generate/dispatch.py; fast/hybrid are now reached only via model.pbd_generate(...), leaving the dispatcher for AR/diffusion/image-gen path selection. AR (slow) stays the default. (@lucasnewman)
beshkenadze
added a commit
to beshkenadze/mlx-vlm
that referenced
this pull request
May 30, 2026
…e flag (Blaizzy#1242)
Re-add fast/hybrid access (removed earlier as a model-specific hook) as a generic, capability-based route: a --generation-mode {slow,fast,hybrid} flag plus a hasattr(model, 'pbd_generate') check in stream_generate. This is consistent with how the dispatcher already routes image generation by model capability. The default slow/AR path is unchanged for every other model; generation_mode is consumed so it never propagates to generate_step. (@lucasnewman review)
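The capability-based route described in this commit message can be pictured roughly as follows. This is a minimal sketch, not the code in the PR: `stream_generate_sketch`, `default_ar_generate`, the toy model classes, and the `pbd_generate(prompt, mode=...)` signature are all hypothetical stand-ins, and falling back to AR when a model lacks `pbd_generate` is an assumption.

```python
from typing import Any, Iterator

VALID_MODES = {"slow", "fast", "hybrid"}


def default_ar_generate(model: Any, prompt: str, **kwargs: Any) -> Iterator[str]:
    """Hypothetical stand-in for the ordinary autoregressive path."""
    # The mode flag must have been consumed before reaching this point.
    assert "generation_mode" not in kwargs
    yield f"ar:{prompt}"


def stream_generate_sketch(model: Any, prompt: str, **kwargs: Any) -> Iterator[str]:
    """Route by capability: use PBD only if the model exposes an entry point."""
    # pop() consumes the flag so it is never forwarded to the AR step.
    mode = kwargs.pop("generation_mode", "slow")
    if mode not in VALID_MODES:
        raise ValueError(f"unknown generation mode: {mode!r}")
    if mode != "slow" and hasattr(model, "pbd_generate"):
        yield from model.pbd_generate(prompt, mode=mode, **kwargs)
        return
    # Assumed behaviour: models without pbd_generate stay on the AR path.
    yield from default_ar_generate(model, prompt, **kwargs)


class PlainModel:
    """Toy model without a PBD entry point."""


class PBDModel:
    """Toy model exposing a hypothetical pbd_generate method."""

    def pbd_generate(self, prompt: str, mode: str, **kwargs: Any) -> Iterator[str]:
        yield f"pbd[{mode}]:{prompt}"


print(list(stream_generate_sketch(PBDModel(), "cat", generation_mode="fast")))
print(list(stream_generate_sketch(PlainModel(), "cat", generation_mode="fast")))
```

The point of the `hasattr` check is that the shared code path never names a specific model type; whether that is generic enough for a shared file is the question raised later in review.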
…nector/model, processor, tests
…e, no crop) + add parity harness
Parity vs HF reference (RTX 4090, transformers 4.51, fp32, identical inputs): vision_model cos=0.999937, mlp1 cos=0.999898 (grid 64x64, no pos-emb interp). The image processor previously center-cropped down (grid 34x44, dropping border pixels); HF bicubic-resizes up (grid 36x46). Now matched -> 442 prompt tokens, identical to HF. The residual ~1% on non-square grids comes from the shared bicubic pos-emb interpolation kernel (MLX vs torch) and does not affect output correctness.
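The difference between the two grids quoted above is floor-vs-ceil rounding to a whole number of merged patches. The sketch below assumes a 14-pixel patch and a 480x640 input, neither of which is stated in this PR; only the 2x2 merge comes from the description. With those assumed values the arithmetic reproduces the 34x44 and 36x46 grids, which suggests but does not confirm that they match the actual configuration.

```python
PATCH = 14  # assumed patch size in pixels (not stated in the PR)
MERGE = 2   # 2x2 patch merge, as described in the PR summary


def grid_center_crop(height: int, width: int) -> tuple:
    """Patch grid when the image is cropped DOWN to a multiple of PATCH * MERGE."""
    unit = PATCH * MERGE
    return (height // unit) * MERGE, (width // unit) * MERGE


def grid_ceil_resize(height: int, width: int) -> tuple:
    """Patch grid when the image is resized UP to a multiple of PATCH * MERGE."""
    unit = PATCH * MERGE
    # -(-a // b) is integer ceiling division.
    return -(-height // unit) * MERGE, -(-width // unit) * MERGE


h, w = 480, 640  # assumed input size, chosen for illustration
print(grid_center_crop(h, w))   # (34, 44)
print(grid_ceil_resize(h, w))   # (36, 46)
```

Cropping discards the border rows and columns outright, whereas resizing up keeps every pixel at the cost of a slightly larger grid.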
Implements PBD, the headline LocateAnything-3B feature, as an opt-in multi-token-prediction (MTP) block decoder on top of the existing AR port.
- language.py: magi non-causal block-attention mask builder (build_magi_block_mask, dense equivalent of HF build_magi_ranges) plus an explicit-position RoPE path for the duplicated bridge token. The causal AR path is untouched (position_ids=None preserves original behaviour).
- pbd.py: PBD decode loop (MTP forward -> sample block -> accept / AR fallback) with ported decode utils (decode_bbox_avg, decode_ref, handle_pattern, is_valid_box_frame). KV cache rewind via KVCache.trim after each block.
- locateanything.py: Model.pbd_generate + make_cache entry points.
- config.py: block_size, causal_attn, text_mask/null/switch token ids, n_future_tokens.
- dispatch.py: additive, triple-gated opt-in hook routing locateanything fast/hybrid to pbd_generate; slow and every other model stay on default AR.
Verified on COCO cats image (greedy): fast == hybrid == slow == AR oracle (byte-identical). PBD ~2x faster than slow. 16 unit tests pass.
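As a rough mental model of the "forward -> sample block -> accept / AR fallback" loop, here is a toy sketch. It is not the pbd.py implementation: `propose_block`, `ar_step`, and `is_valid_block` are hypothetical callables, token ids are plain integers, and the KV cache is stood in for by a Python list whose tail is dropped to mimic a trim.

```python
from typing import Callable, List, Sequence


def block_decode(
    propose_block: Callable[[Sequence[int]], List[int]],
    ar_step: Callable[[Sequence[int]], int],
    is_valid_block: Callable[[Sequence[int]], bool],
    prompt: Sequence[int],
    eos: int,
    max_tokens: int,
) -> List[int]:
    """Decode by accepting whole blocks when valid, else one AR token."""
    cache = list(prompt)  # toy stand-in for a KV cache: one entry per token
    out: List[int] = []
    while len(out) < max_tokens:
        block = propose_block(cache)
        cache.extend(block)  # the speculative block enters the cache
        if block and is_valid_block(block):
            new = list(block)
        else:
            # Rewind the cache to its pre-block length, then take one AR step.
            del cache[len(cache) - len(block):]
            token = ar_step(cache)
            cache.append(token)
            new = [token]
        out.extend(new)
        if eos in new:
            out = out[: out.index(eos) + 1]
            break
    # A full block may overshoot the budget, so clip at the end.
    return out[:max_tokens]


# Toy drivers: blocks proposed at odd positions are deliberately invalid.
TARGET = [10, 1, 2, 3, 4, 11, 99]
PROMPT = [0]


def toy_propose(cache):
    pos = len(cache) - len(PROMPT)
    return TARGET[pos:pos + 3] if pos % 2 == 0 else [77, 77, 77]


def toy_ar(cache):
    return TARGET[len(cache) - len(PROMPT)]


def toy_valid(block):
    return 77 not in block


print(block_decode(toy_propose, toy_ar, toy_valid, PROMPT, eos=99, max_tokens=16))
```

In this toy, the block path and the fallback path reconstruct the same sequence, which mirrors the byte-identical check reported above; the real validity test and cache handling are more involved.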
…(review finding 2)
…flash path (#3)
A single image's block mask is all-True (a no-op), but passing it explicitly forced mx.fast.scaled_dot_product_attention off the flash kernel and materialized a dense [1,heads,S,S] fp32 score tensor -> 15.58GB / OOM on large frames (e.g. 2304x1296 -> 15604 patches). Now pass mask=None for a single image (flash, O(N) memory); multi-image batches keep the block-diagonal mask. Single-image output is unchanged (verified: identical COCO boxes); +3 mask-logic tests.
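The 15.58GB figure can be checked with a back-of-the-envelope calculation. The sketch below assumes 16 attention heads, which is not stated in this PR; that value is used only because it reproduces the quoted size. `block_mask` is a hypothetical helper that illustrates the "None for one image, block-diagonal otherwise" rule with nested lists rather than MLX arrays.

```python
from typing import List, Optional, Sequence


def dense_score_bytes(seq_len: int, heads: int, bytes_per_elem: int = 4) -> int:
    """Size of a dense [1, heads, S, S] attention-score tensor in bytes."""
    return heads * seq_len * seq_len * bytes_per_elem


def block_mask(seq_lens: Sequence[int]) -> Optional[List[List[bool]]]:
    """Return None for a single image, else a boolean block-diagonal mask."""
    if len(seq_lens) <= 1:
        # One image: every patch may attend to every patch, so no mask is needed.
        return None
    # owner[i] is the index of the image that patch i belongs to.
    owner = [img for img, n in enumerate(seq_lens) for _ in range(n)]
    return [[q == k for k in owner] for q in owner]


size = dense_score_bytes(seq_len=15604, heads=16)  # heads=16 is an assumption
print(size, round(size / 1e9, 2))  # 15583028224 15.58
print(block_mask([15604]))         # None
print(block_mask([2, 1]))
```

The quadratic term is what makes large frames expensive: the score tensor grows with the square of the patch count, whereas the fused kernel avoids materializing it.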
- pbd: truncate generated tokens to max_tokens (fast/hybrid could overrun the budget by appending a full block past the limit, e.g. max_tokens<block_size).
- image processor: convert mx.array -> PIL before HF validation (make_list_of_images rejected mx.array, making the advertised array path dead code); reject unknown types consistently.
+3 regression tests (22 total).
Parity/oracle/upload helpers were local dev artifacts; they don't belong in the model port. Removed so the PR contains only the locateanything package + the additive prompt_utils/dispatch hooks + tests.
d09a6b7 to
df7254d
Compare
Blaizzy
reviewed
May 30, 2026
Owner
Missing save_pretrained() method
Contributor
Author
I'll check it out, since it ran without any errors on a separate Mac mini.
Blaizzy
reviewed
May 30, 2026
Owner
I would revert this because it's model-specific; only one model uses it, so it's not meant to be here.
In #1239 I'm adding gen-kwargs which you will be able to use here.
Summary
Ports nvidia/LocateAnything-3B, a visual-grounding VLM (object detection / referring-expression grounding / pointing / GUI & text localization), into mlx-vlm so it runs on Apple Silicon via mlx_vlm.generate.

Architecture: MoonViT-SO-400M vision tower (shared with Kimi-VL) + Qwen2.5-3B text backbone + a 2-layer MLP connector (mlp1). Output is structured coordinate tokens, e.g. <ref>remote</ref><box><64><152><273><244></box>, with coordinates quantized to <0>…<1000> (normalized).

What's implemented

- mlx_vlm/models/locateanything/:
  - config.py: VisionConfig (moonvit) / TextConfig (qwen2) / ModelConfig with grounding token ids.
  - vision.py: MoonViT tower (2D RoPE, per-image block attention, 2×2 patch merge), ported from kimi_vl and reconciled to LocateAnything's weight names, PytorchGELUTanh activation, and LayerNorm eps.
  - language.py: standard Qwen2.5-3B causal LM (1D RoPE, GQA 16/2, tied embeddings) + the non-causal "magi" block mask for PBD.
  - locateanything.py: mlp1 connector + Model (vision → projector → scatter at <IMG_CONTEXT>) + sanitize().
  - pbd.py: Parallel Box Decoding (MTP block decoder + bbox decode utils).
  - image_processing_locateanything.py / processing_locateanything.py + chat_template.json (preprocessing matched bit-for-bit to the HF reference: bicubic ceil-resize, not center-crop).
- prompt_utils.py: register locateanything → LIST_WITH_IMAGE_FIRST.
- generate/dispatch.py: a small additive, opt-in hook routing fast/hybrid to PBD (gated on model_type == "locateanything"); slow and every other model are unchanged.
- tests/test_locateanything.py: config, vision/language/full-forward shapes, sanitize coverage, magi-mask, PBD decode utils, max-tokens, image-input handling (22 tests, all green).

Decoding modes

- slow (default)
- fast
- hybrid

* 16-token COCO run. All three modes produce byte-identical grounding output.
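Since all three modes emit the same coordinate-token string, downstream code only needs one parser. The sketch below turns that string into pixel boxes. It assumes the four values in each box are ordered x1, y1, x2, y2 and that dividing by 1000 yields a fraction of the image width or height; the PR only says the coordinates are quantized to <0>…<1000> and normalized, so both points are assumptions, and `parse_grounding` is a hypothetical helper rather than part of the port.

```python
import re
from typing import List, Tuple

# One <ref>label</ref> followed by one or more <box>…</box> groups.
REF_WITH_BOXES = re.compile(r"<ref>(.*?)</ref>((?:<box>(?:<\d+>){4}</box>)+)")
BOX = re.compile(r"<box><(\d+)><(\d+)><(\d+)><(\d+)></box>")

Box = Tuple[float, float, float, float]


def parse_grounding(text: str, width: int, height: int, scale: int = 1000) -> List[Tuple[str, Box]]:
    """Extract (label, pixel box) pairs from a grounding output string."""
    results: List[Tuple[str, Box]] = []
    for label, boxes in REF_WITH_BOXES.findall(text):
        for match in BOX.finditer(boxes):
            # Assumed order: x1, y1, x2, y2 on a 0..scale grid.
            x1, y1, x2, y2 = (int(v) for v in match.groups())
            results.append(
                (label, (x1 * width / scale, y1 * height / scale,
                         x2 * width / scale, y2 * height / scale))
            )
    return results


sample = "<ref>remote</ref><box><64><152><273><244></box><box><522><160><578><390></box>"
for label, box in parse_grounding(sample, width=1000, height=1000):
    print(label, box)
```

Passing the real image width and height instead of 1000 scales the same boxes into pixel space.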
Parity vs the PyTorch (CUDA) reference
Verified the MLX port numerically against the original HF/PyTorch model on an RTX 4090 (WSL, transformers==4.51, fp32). The HF vision_model + mlp1 were dumped on identical pixel_values, then fed to the MLX modules in fp32 and compared:

- vision_model
- mlp1 (connector)

The residual on non-square grids arises because the MLX bicubic_interpolate kernel uses a = -0.5 vs PyTorch's a = -0.75. It is localized to the additive pos-emb (hence connector cos > vision cos: the connector's LayerNorm partly cancels it), affects every MLX MoonViT port (incl. kimi_vl), and does not change the grounding output. Tracked in #1241 ("bicubic_interpolate uses a=-0.5 (Keys') instead of PyTorch's a=-0.75 - degrades MoonViT pos-emb parity").

fast/hybrid are verified byte-identical to the slow (AR) path on the same input, so the verified AR path is itself the oracle for the parallel path.

Verification

- python -m unittest mlx_vlm.tests.test_locateanything → 22 passed.
- COCO cats image: output <ref>remote</ref><box><64><152><273><244></box><box><522><160><578><390></box> (boxes match the two remotes); prompt token count matches the HF processor exactly.
- Large frames: a single image's all-True block mask was forcing mx.fast.scaled_dot_product_attention off the flash kernel and materializing a dense [1,heads,S,S] score tensor (OOM on big frames). Now mask=None for one image (flash, O(N) memory); multi-image keeps the block-diagonal mask.
- Review (codex exec review): no P0/P1; two P2 edge cases (PBD max_tokens, mx.array image input) fixed.

Blast radius
Everything PBD/model-specific lives in the locateanything package; the only shared edits are one additive line in prompt_utils.py and one additive, gated hook in dispatch.py. No other model is affected, and neither are the vision tower, the processor, or shared SDPA; slow AR remains the default.

Quantized weights

MLX builds published to mlx-community: bf16, 8bit, 4bit (mixed 4/8-bit; pure 4-bit degrades the tied coordinate-token embedding).