Add NVIDIA LocateAnything-3B (MoonViT + Qwen2.5, autoregressive mode) #1
Closed
beshkenadze wants to merge 11 commits into
…nector/model, processor, tests
…e, no crop) + add parity harness

Parity vs HF reference (RTX 4090, transformers 4.51, fp32, identical inputs): vision_model cos=0.999937, mlp1 cos=0.999898 (grid 64x64, no pos-emb interp).

The image processor previously center-cropped down (grid 34x44, dropping border pixels); HF bicubic-resizes up (grid 36x46). The two now match, giving 442 prompt tokens identical to HF.

The residual ~1% on non-square grids comes from the shared bicubic pos-emb interpolation kernel (MLX vs torch) and does not affect output correctness.
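The two grids quoted above (34x44 when cropping down, 36x46 when resizing up) can be reproduced with a small sketch. It is not the PR's image processor. It assumes a 14-pixel patch and a 640x480 input, neither of which is stated in the commit message, together with the 2x2 patch merge; `grid_size` is a hypothetical helper.

```python
def grid_size(height, width, patch=14, merge=2, mode="ceil"):
    """Return (rows, cols) of the patch grid after rounding each side
    to a multiple of patch * merge.

    patch=14 is an assumption; merge=2 mirrors the 2x2 patch merge.
    mode="floor" models cropping down, mode="ceil" models resizing up.
    """
    unit = patch * merge
    if mode == "ceil":
        units_h = (height + unit - 1) // unit  # integer ceiling division
        units_w = (width + unit - 1) // unit
    else:
        units_h = height // unit
        units_w = width // unit
    return units_h * merge, units_w * merge


# Assumed 640x480 input (width x height).
print(grid_size(480, 640, mode="floor"))  # (34, 44)
print(grid_size(480, 640, mode="ceil"))   # (36, 46)
```

Under the same assumptions, the ceil rule gives a 94x166 grid for a 2304x1296 frame, i.e. 15604 patches.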
Implements PBD, the headline LocateAnything-3B feature, as an opt-in multi-token-prediction (MTP) block decoder on top of the existing AR port.

- language.py: magi non-causal block-attention mask builder (build_magi_block_mask, dense equivalent of HF build_magi_ranges) plus an explicit-position RoPE path for the duplicated bridge token. The causal AR path is untouched (position_ids=None preserves original behaviour).
- pbd.py: PBD decode loop (MTP forward -> sample block -> accept / AR fallback) with ported decode utils (decode_bbox_avg, decode_ref, handle_pattern, is_valid_box_frame). KV cache rewind via KVCache.trim after each block.
- locateanything.py: Model.pbd_generate + make_cache entry points.
- config.py: block_size, causal_attn, text_mask/null/switch token ids, n_future_tokens.
- dispatch.py: additive, triple-gated opt-in hook routing locateanything fast/hybrid to pbd_generate; slow and every other model stay on default AR.

Verified on COCO cats image (greedy): fast == hybrid == slow == AR oracle (byte-identical). PBD ~2x faster than slow. 16 unit tests pass.
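To make the idea of a non-causal block mask concrete, here is a minimal sketch of one plausible shape: ordinary causal attention over the prefix, with the trailing block fully visible to itself. This is an assumption about what "non-causal block attention" means here. It is not the PR's build_magi_block_mask and may differ from HF build_magi_ranges; `causal_with_trailing_block` is a hypothetical name.

```python
def causal_with_trailing_block(prefix_len, block_size):
    """Boolean attention mask as nested lists; mask[i][j] is True when
    query position i may attend to key position j.

    Assumed rule (illustrative only):
      - every position sees itself and all earlier positions (causal);
      - positions inside the trailing block also see each other,
        including later positions in the same block (non-causal).
    """
    n = prefix_len + block_size
    return [
        [j <= i or (i >= prefix_len and j >= prefix_len) for j in range(n)]
        for i in range(n)
    ]


# 3 prefix tokens followed by a 2-token block.
for row in causal_with_trailing_block(3, 2):
    print("".join("x" if allowed else "." for allowed in row))
```

Prefix rows stay strictly lower-triangular, so the prefix never sees the block; only the block rows gain the extra bidirectional entries.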
…(review finding 2)
…flash path (#3)

A single image's block mask is all-True (no-op), but passing it explicitly forced mx.fast.scaled_dot_product_attention off the flash kernel and materialized a dense [1,heads,S,S] fp32 score tensor -> 15.58GB / OOM on large frames (e.g. 2304x1296 -> 15604 patches). Now pass mask=None for a single image (flash, O(N) memory); multi-image batches keep the block-diagonal mask.

Single-image output is unchanged (verified: identical COCO boxes); +3 mask-logic tests.
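The 15.58GB figure follows directly from the tensor shape. The sketch below works through the arithmetic; the head count of 16 is an assumption (it is not stated in the commit message) chosen because it reproduces the reported size, and `dense_score_bytes` is a hypothetical helper.

```python
def dense_score_bytes(seq_len, heads, batch=1, bytes_per_elem=4):
    """Size in bytes of a dense [batch, heads, S, S] attention-score tensor.

    bytes_per_elem=4 corresponds to fp32.
    """
    return batch * heads * seq_len * seq_len * bytes_per_elem


# S = 15604 patches (from the commit message); heads = 16 is assumed.
size = dense_score_bytes(seq_len=15604, heads=16)
print(size)                  # 15583028224
print(round(size / 1e9, 2))  # 15.58
```

Because the cost grows with S squared, doubling the patch count quadruples this tensor, which is why avoiding the dense path matters most on large frames.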
5e471bb to 994806d
- pbd: truncate generated tokens to max_tokens (fast/hybrid could overrun the budget by appending a full block past the limit, e.g. max_tokens<block_size).
- image processor: convert mx.array -> PIL before HF validation (make_list_of_images rejected mx.array, making the advertised array path dead code); reject unknown types consistently.

+3 regression tests (22 total).
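The max_tokens overrun is easy to picture with a toy version of the append step. This is a sketch of the general fix (clip each block to the remaining budget), not the code in pbd.py; `append_block` is a hypothetical helper.

```python
def append_block(generated, block, max_tokens):
    """Append a decoded block without exceeding max_tokens.

    Returns (new_token_list, done), where done is True once the
    budget is exhausted. Only the part of the block that still fits
    is kept, so a block longer than the remaining budget is clipped.
    """
    remaining = max_tokens - len(generated)
    if remaining <= 0:
        return list(generated), True
    out = list(generated) + list(block[:remaining])
    return out, len(out) >= max_tokens


# max_tokens smaller than the block size: the block is clipped to 3 tokens.
tokens, done = append_block([], [10, 11, 12, 13, 14, 15], max_tokens=3)
print(tokens, done)  # [10, 11, 12] True
```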
Parity/oracle/upload helpers were local dev artifacts; they don't belong in the model port. Removed so the PR contains only the locateanything package + the additive prompt_utils/dispatch hooks + tests.
Superseded by the upstream PR → Blaizzy#1242 (same branch). Continuing review there.
Summary
Ports `nvidia/LocateAnything-3B`, a visual-grounding VLM (object detection / referring-expression grounding / pointing / GUI & text localization), into mlx-vlm so it runs on Apple Silicon via `mlx_vlm.generate`.

Architecture: MoonViT-SO-400M vision tower (shared with Kimi-VL) + Qwen2.5-3B text backbone + a 2-layer MLP connector (`mlp1`). Output is structured coordinate tokens, e.g. `<ref>remote</ref><box><64><152><273><244></box>`, with coordinates quantized to `<0>`…`<1000>` (normalized).

What's implemented
- `mlx_vlm/models/locateanything/`:
  - `config.py`: `VisionConfig` (moonvit) / `TextConfig` (qwen2) / `ModelConfig` with grounding token ids.
  - `vision.py`: MoonViT tower (2D RoPE, per-image block attention, 2×2 patch merge), ported from `kimi_vl` and reconciled to LocateAnything's weight names, `PytorchGELUTanh` activation, and LayerNorm eps.
  - `language.py`: standard Qwen2.5-3B causal LM (1D RoPE, GQA 16/2, tied embeddings) + the non-causal "magi" block mask for PBD.
  - `locateanything.py`: `mlp1` connector + `Model` (vision → projector → scatter at `<IMG_CONTEXT>`) + `sanitize()`.
  - `pbd.py`: Parallel Box Decoding (MTP block decoder + bbox decode utils).
  - `image_processing_locateanything.py` / `processing_locateanything.py` + `chat_template.json` (preprocessing matched bit-for-bit to the HF reference: bicubic ceil-resize, not center-crop).
- `prompt_utils.py`: register `locateanything` → `LIST_WITH_IMAGE_FIRST`.
- `generate/dispatch.py`: a small additive, opt-in hook routing `fast`/`hybrid` to PBD (gated on `model_type == "locateanything"`); `slow` and every other model are unchanged.
- `tests/test_locateanything.py`: config, vision/language/full-forward shapes, sanitize coverage, magi-mask, PBD decode utils, max-tokens, image-input handling (22 tests, all green).

Decoding modes
- `slow` (default)
- `fast`
- `hybrid`

\* 16-token COCO run. All three modes produce byte-identical grounding output.
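Whichever mode produces it, the grounding output is the coordinate-token string described in the Summary. A minimal parsing sketch follows. It assumes each box holds four values in (x1, y1, x2, y2) order and that values map linearly from the 0 to 1000 range onto image width and height; the PR states neither explicitly. `parse_grounding` and `to_pixels` are hypothetical helpers, not part of mlx-vlm.

```python
import re

# One <ref>label</ref> followed by one or more <box>...</box> groups.
REF_RE = re.compile(r"<ref>(.*?)</ref>((?:<box>(?:<\d+>){4}</box>)+)")
BOX_RE = re.compile(r"<box><(\d+)><(\d+)><(\d+)><(\d+)></box>")


def parse_grounding(text):
    """Return a list of (label, [box, ...]) with integer box tuples."""
    results = []
    for label, boxes_blob in REF_RE.findall(text):
        boxes = [tuple(int(v) for v in m) for m in BOX_RE.findall(boxes_blob)]
        results.append((label, boxes))
    return results


def to_pixels(box, width, height, bins=1000):
    """Scale a quantized box to pixels (assumed x1, y1, x2, y2 order)."""
    x1, y1, x2, y2 = box
    return (x1 * width / bins, y1 * height / bins,
            x2 * width / bins, y2 * height / bins)


out = "<ref>remote</ref><box><64><152><273><244></box><box><522><160><578><390></box>"
print(parse_grounding(out))
# [('remote', [(64, 152, 273, 244), (522, 160, 578, 390)])]
```

Since the modes are reported as byte-identical, a parser like this can be written once against `slow` output and reused for `fast` and `hybrid`.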
Parity vs the PyTorch (CUDA) reference
Verified the MLX port numerically against the original HF/PyTorch model on an RTX 4090 (WSL, `transformers==4.51`, fp32). The HF `vision_model` + `mlp1` were dumped on identical `pixel_values`, then fed to the MLX modules in fp32 and compared (`scripts/la_parity_{ref,mlx}.py`):

- `vision_model`
- `mlp1` (connector)

The MLX `bicubic_interpolate` kernel uses `a = -0.5` vs PyTorch's `a = -0.75`. The resulting difference is localized to the additive pos-emb (hence connector cos > vision cos: the connector's `LayerNorm` partly cancels it), affects every MLX MoonViT port (incl. `kimi_vl`), and does not change the grounding output. Tracked in #2 (Align bicubic_interpolate with PyTorch (a=-0.75, not -0.5) for MoonViT pos-emb parity).

`fast`/`hybrid` are verified byte-identical to the `slow` (AR) path on the same input, so the verified AR path is itself the oracle for the parallel path.

Verification
- `python -m unittest mlx_vlm.tests.test_locateanything` → 22 passed.
- `<ref>remote</ref><box><64><152><273><244></box><box><522><160><578><390></box>` (boxes match the two remotes); prompt token count matches the HF processor exactly.
- The explicit `[S,S]` mask forced SDPA off the flash path; now flash, O(N) memory. See #3 (LocateAnything-3B: dense O(N²) vision attention OOMs on large images (Metal single-buffer cap)).
- Review (`codex exec review`): no P0/P1; two P2 edge cases (PBD `max_tokens`, `mx.array` image input) fixed.

Blast radius
Everything PBD/model-specific lives in the `locateanything` package; the only shared edits are one additive line in `prompt_utils.py` and one additive, gated hook in `dispatch.py`. No other model, the vision tower, the processor, or shared SDPA is affected; `slow` AR remains the default.

Quantized weights
MLX builds published to mlx-community: bf16, 8bit, 4bit (mixed 4/8-bit, since pure 4-bit degrades the tied coordinate-token embedding).