Skip to content

Add NVIDIA LocateAnything-3B (MoonViT + Qwen2.5, AR + Parallel Box Decoding)#1242

Open
beshkenadze wants to merge 16 commits into
Blaizzy:mainfrom
beshkenadze:feat/locateanything-3b
Open

Add NVIDIA LocateAnything-3B (MoonViT + Qwen2.5, AR + Parallel Box Decoding)#1242
beshkenadze wants to merge 16 commits into
Blaizzy:mainfrom
beshkenadze:feat/locateanything-3b

Conversation

@beshkenadze
Copy link
Copy Markdown
Contributor

Summary

Ports nvidia/LocateAnything-3B — a visual-grounding VLM (object detection / referring-expression grounding / pointing / GUI & text localization) — into mlx-vlm so it runs on Apple Silicon via mlx_vlm.generate.

Architecture: MoonViT-SO-400M vision tower (shared with Kimi-VL) + Qwen2.5-3B text backbone + a 2-layer MLP connector (mlp1). Output is structured coordinate tokens, e.g. <ref>remote</ref><box><64><152><273><244></box>, with coordinates quantized to <0><1000> (normalized).

What's implemented

  • Model package mlx_vlm/models/locateanything/:
    • config.pyVisionConfig (moonvit) / TextConfig (qwen2) / ModelConfig with grounding token ids.
    • vision.py — MoonViT tower (2D RoPE, per-image block attention, 2×2 patch merge), ported from kimi_vl and reconciled to LocateAnything's weight names, PytorchGELUTanh activation, and LayerNorm eps.
    • language.py — standard Qwen2.5-3B causal LM (1D RoPE, GQA 16/2, tied embeddings) + the non-causal "magi" block mask for PBD.
    • locateanything.pymlp1 connector + Model (vision → projector → scatter at <IMG_CONTEXT>) + sanitize().
    • pbd.py — Parallel Box Decoding (MTP block decoder + bbox decode utils).
    • image_processing_locateanything.py / processing_locateanything.py + chat_template.json (preprocessing matched bit-for-bit to the HF reference — bicubic ceil-resize, not center-crop).
  • prompt_utils.py — register locateanythingLIST_WITH_IMAGE_FIRST.
  • generate/dispatch.py — a small additive, opt-in hook routing fast/hybrid to PBD (gated on model_type == "locateanything"); slow and every other model are unchanged.
  • tests/test_locateanything.py — config, vision/language/full-forward shapes, sanitize coverage, magi-mask, PBD decode utils, max-tokens, image-input handling (22 tests, all green).

Decoding modes

mode description throughput*
slow (default) pure autoregressive
fast Parallel Box Decoding (MTP, parallel blocks) ~2×
hybrid PBD with AR fallback on format irregularity ~2×

* 16-token COCO run. All three modes produce byte-identical grounding output.

Parity vs the PyTorch (CUDA) reference

Verified the MLX port numerically against the original HF/PyTorch model on an RTX 4090 (WSL, transformers==4.51, fp32). The HF vision_model + mlp1 were dumped on identical pixel_values, then fed to the MLX modules in fp32 and compared:

stage grid 64×64 (no pos-emb interp) grid 36×46 (with interp)
vision_model cos 0.999937, mean|Δ| ≈ 1e-3 cos 0.9894
mlp1 (connector) cos 0.999898 cos 0.9971
  • The port's math is numerically faithful (cos ≈ 0.99994 on identical inputs).
  • The entire residual on non-square grids comes from one op: the learnable 2D pos-emb bicubic interpolation — MLX's shared bicubic_interpolate kernel uses a = -0.5 vs PyTorch's a = -0.75. It is localized to the additive pos-emb (hence connector cos > vision cos: the connector's LayerNorm partly cancels it), affects every MLX MoonViT port (incl. kimi_vl), and does not change the grounding output. Tracked in bicubic_interpolate uses a=-0.5 (Keys') instead of PyTorch's a=-0.75 — degrades MoonViT pos-emb parity #1241.
  • PBD parity: by the model's design invariant (hybrid falls back to AR for consistency), fast/hybrid are verified byte-identical to the slow (AR) path on the same input — so the verified AR path is itself the oracle for the parallel path.

Verification

  • Unit tests: python -m unittest mlx_vlm.tests.test_locateanything → 22 passed.
  • End-to-end (real weights, Apple Silicon):
    python -m mlx_vlm.generate --model nvidia/LocateAnything-3B \
      --image http://images.cocodataset.org/val2017/000000039769.jpg \
      --prompt "Detect all objects in the image." --max-tokens 128 --temperature 0.0
    
    <ref>remote</ref><box><64><152><273><244></box><box><522><160><578><390></box> (boxes match the two remotes); prompt token count matches the HF processor exactly.
  • Large images: fixed a vision-attention OOM — a single image's all-True block mask was passed explicitly, forcing mx.fast.scaled_dot_product_attention off the flash kernel and materializing a dense [1,heads,S,S] score tensor (OOM on big frames). Now mask=None for one image (flash, O(N) memory); multi-image keeps the block-diagonal mask.
  • Codex CLI review (codex exec review): no P0/P1; two P2 edge-cases (PBD max_tokens, mx.array image input) fixed.

Blast radius

Everything PBD/model-specific lives in the locateanything package; the only shared edits are one additive line in prompt_utils.py and one additive, gated hook in dispatch.py. No other model, the vision tower, the processor, or shared SDPA is affected; slow AR remains the default.

Quantized weights

MLX builds published to mlx-community: bf16, 8bit, 4bit (mixed 4/8-bit — pure 4-bit degrades the tied coordinate-token embedding).

Comment thread mlx_vlm/models/locateanything/chat_template.json Outdated
Comment thread mlx_vlm/generate/dispatch.py Outdated
beshkenadze added a commit to beshkenadze/mlx-vlm that referenced this pull request May 30, 2026
…atch PBD hook

- chat_template.json is loaded from the model repo at runtime (base.load_chat_template
  / processor.from_pretrained); the bundled copy was unused. nvidia/LocateAnything-3B
  and the mlx-community quant repos all ship it. (@lucasnewman)
- Remove the model-specific PBD branch from generate/dispatch.py; fast/hybrid are now
  reached only via model.pbd_generate(...), leaving the dispatcher for AR/diffusion/
  image-gen path selection. AR (slow) stays the default. (@lucasnewman)
beshkenadze added a commit to beshkenadze/mlx-vlm that referenced this pull request May 30, 2026
…e flag (Blaizzy#1242)

Re-add fast/hybrid access (removed as a model-specific hook earlier) as a generic,
capability-based route: a --generation-mode {slow,fast,hybrid} flag plus a
hasattr(model, 'pbd_generate') check in stream_generate. Consistent with how the
dispatcher already routes image-generation by model capability. Default slow/AR is
unchanged for every other model; generation_mode is consumed so it never propagates
to generate_step. (@lucasnewman review)
…e, no crop) + add parity harness

Parity vs HF reference (RTX 4090, transformers 4.51, fp32, identical inputs):
vision_model cos=0.999937, mlp1 cos=0.999898 (grid 64x64, no pos-emb interp).
Image processor previously center-cropped down (grid 34x44, dropped border
pixels); HF bicubic-resizes up (grid 36x46). Now matched -> 442 prompt tokens
identical to HF. Residual ~1% on non-square grids is the shared bicubic pos-emb
interpolation kernel (MLX vs torch), which does not affect output correctness.
Implements PBD — the headline LocateAnything-3B feature — as an opt-in
multi-token-prediction (MTP) block decoder on top of the existing AR port.

- language.py: magi non-causal block-attention mask builder
  (build_magi_block_mask, dense equivalent of HF build_magi_ranges) plus an
  explicit-position RoPE path for the duplicated bridge token. The causal AR
  path is untouched (position_ids=None preserves original behaviour).
- pbd.py: PBD decode loop (MTP forward -> sample block -> accept / AR fallback)
  with ported decode utils (decode_bbox_avg, decode_ref, handle_pattern,
  is_valid_box_frame). KV cache rewind via KVCache.trim after each block.
- locateanything.py: Model.pbd_generate + make_cache entry points.
- config.py: block_size, causal_attn, text_mask/null/switch token ids,
  n_future_tokens.
- dispatch.py: additive, triple-gated opt-in hook routing locateanything
  fast/hybrid to pbd_generate; slow and every other model stay on default AR.

Verified on COCO cats image (greedy): fast == hybrid == slow == AR oracle
(byte-identical). PBD ~2x faster than slow. 16 unit tests pass.
…flash path (#3)

A single image's block mask is all-True (no-op), but passing it explicitly
forced mx.fast.scaled_dot_product_attention off the flash kernel and
materialized a dense [1,heads,S,S] fp32 score tensor -> 15.58GB / OOM on large
frames (e.g. 2304x1296 -> 15604 patches). Now pass mask=None for a single image
(flash, O(N) memory); multi-image batches keep the block-diagonal mask. Single-
image output is unchanged (verified: identical COCO boxes); +3 mask-logic tests.
- pbd: truncate generated tokens to max_tokens (fast/hybrid could overrun the
  budget by appending a full block past the limit, e.g. max_tokens<block_size).
- image processor: convert mx.array -> PIL before HF validation (make_list_of_images
  rejected mx.array, making the advertised array path dead code); reject unknown
  types consistently. +3 regression tests (22 total).
Parity/oracle/upload helpers were local dev artifacts; they don't belong in
the model port. Removed so the PR contains only the locateanything package +
the additive prompt_utils/dispatch hooks + tests.
…atch PBD hook

- chat_template.json is loaded from the model repo at runtime (base.load_chat_template
  / processor.from_pretrained); the bundled copy was unused. nvidia/LocateAnything-3B
  and the mlx-community quant repos all ship it. (@lucasnewman)
- Remove the model-specific PBD branch from generate/dispatch.py; fast/hybrid are now
  reached only via model.pbd_generate(...), leaving the dispatcher for AR/diffusion/
  image-gen path selection. AR (slow) stays the default. (@lucasnewman)
…e flag (Blaizzy#1242)

Re-add fast/hybrid access (removed as a model-specific hook earlier) as a generic,
capability-based route: a --generation-mode {slow,fast,hybrid} flag plus a
hasattr(model, 'pbd_generate') check in stream_generate. Consistent with how the
dispatcher already routes image-generation by model capability. Default slow/AR is
unchanged for every other model; generation_mode is consumed so it never propagates
to generate_step. (@lucasnewman review)
@beshkenadze beshkenadze force-pushed the feat/locateanything-3b branch from d09a6b7 to df7254d Compare May 30, 2026 17:05
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing save_pretrained() method

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll check it out, since it ran without any errors on a separate Mac mini.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would revert this this, because its model specific, only one model has this model so it's not meant to be here.

In #1239 I'm adding gen-kwargs which you will be able to use here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, np.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants