Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
b1ec70d
Fix Qwen3.5 batched left-padding drift
Blaizzy May 21, 2026
ac4f2d5
Fix Qwen target verify batch drift
Blaizzy May 21, 2026
256757b
Fix Qwen batch parity for padded vision rows
Blaizzy May 21, 2026
f8f570c
Fix exact batched Qwen MTP verification
Blaizzy May 21, 2026
7722c96
Fix ragged Qwen3.5 MTP batch parity
Blaizzy May 21, 2026
a40c23b
Keep Qwen MTP ragged greedy and uniform sampled parity
Blaizzy May 21, 2026
11eb394
Enable ragged Qwen MTP for sampled batches
Blaizzy May 21, 2026
3aaba7b
Fix server sampler reuse across idle batches
Blaizzy May 22, 2026
744372a
Route server MTP singleton through batch path
Blaizzy May 22, 2026
c28fa1a
Fix seeded Qwen MTP CLI parity
Blaizzy May 23, 2026
8169230
Speed up exact positioned MTP sampling
Blaizzy May 23, 2026
c3581f2
Use exact qmatvec for Qwen MTP verifier logits
Blaizzy May 23, 2026
9b25b25
Fuse Qwen GDN accepted state scatter
Blaizzy May 23, 2026
a94e83c
Fix singleton batch generator cache performance
Blaizzy May 23, 2026
774c69e
Route server MTP through batch generator
Blaizzy May 24, 2026
e4471e0
Speed up quantized Qwen batch decode
Blaizzy May 24, 2026
7d62403
Use true batched Qwen3.5 server decode
Blaizzy May 24, 2026
35fd30e
Add exact ragged Qwen3.5 decode attention
Blaizzy May 24, 2026
1a08841
Avoid slow mixed ragged attention dispatch
Blaizzy May 24, 2026
ceb9049
Improve Qwen3.5 batched decode scaling
Blaizzy May 25, 2026
159bf24
Avoid MTP rollback syncs
Blaizzy May 25, 2026
8662366
Remove slow Qwen decode qmv path
Blaizzy May 25, 2026
a8e73d2
Reduce Qwen3.5 batched decode sync overhead
Blaizzy May 25, 2026
1bfecb6
Merge remote-tracking branch 'origin/main' into pc/qwen-mtp-batch-drift
Blaizzy May 27, 2026
c807a22
Support PoolingCache in batched cache creation
Blaizzy May 27, 2026
0f726e0
Apply pre-commit formatting
Blaizzy May 30, 2026
367a90e
Merge main into qwen MTP batch drift
Blaizzy May 30, 2026
859d724
Improve Qwen batch decode stability
Blaizzy May 31, 2026
b938a64
Merge remote-tracking branch 'origin/main' into pc/qwen-mtp-batch-drift
Blaizzy Jun 1, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
568 changes: 534 additions & 34 deletions mlx_vlm/generate/ar.py

Large diffs are not rendered by default.

185 changes: 180 additions & 5 deletions mlx_vlm/models/qwen3_5/gated_delta.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,42 @@
from functools import partial
from typing import Optional

import mlx.core as mx
from mlx_lm.models.gated_delta import compute_g, gated_delta_update # noqa: F401
import mlx.nn as nn
from mlx_lm.models.gated_delta import gated_delta_kernel, gated_delta_ops


@partial(mx.compile, shapeless=True)
def compute_g(A_log, a, dt_bias):
return mx.exp(-mx.exp(A_log.astype(mx.float32)) * nn.softplus(a + dt_bias))


@partial(mx.compile, shapeless=True)
def _compute_g_beta(A_log, a, b, dt_bias):
return compute_g(A_log, a, dt_bias), mx.sigmoid(b)


def gated_delta_update(
q: mx.array,
k: mx.array,
v: mx.array,
a: mx.array,
b: mx.array,
A_log: mx.array,
dt_bias: mx.array,
state: Optional[mx.array] = None,
mask: Optional[mx.array] = None,
use_kernel: bool = True,
):
g, beta = _compute_g_beta(A_log, a, b, dt_bias)
if state is None:
B, _, _Hk, Dk = q.shape
Hv, Dv = v.shape[-2:]
state = mx.zeros((B, Hv, Dv, Dk), dtype=mx.float32)

if not use_kernel or mx.default_device() != mx.gpu or not mx.metal.is_available():
return gated_delta_ops(q, k, v, g, beta, state, mask)
return gated_delta_kernel(q, k, v, g, beta, state, mask)


def _make_gated_delta_with_states_kernel(has_mask: bool = False):
Expand Down Expand Up @@ -153,8 +188,7 @@ def gated_delta_update_with_states(
mask: Optional[mx.array] = None,
use_kernel: bool = True,
):
beta = mx.sigmoid(b)
g = compute_g(A_log, a, dt_bias)
g, beta = _compute_g_beta(A_log, a, b, dt_bias)
if state is None:
B, _, _Hk, Dk = q.shape
Hv, Dv = v.shape[-2:]
Expand Down Expand Up @@ -276,6 +310,69 @@ def _make_gated_delta_state_kernel(has_mask: bool = False):
_gated_delta_state_kernel_masked = _make_gated_delta_state_kernel(True)


def _make_gated_delta_accept_states_kernel():
if not mx.metal.is_available():
return None

return mx.fast.metal_kernel(
name="qwen3_5_gated_delta_accept_states",
input_names=[
"intermediate_states",
"conv_input",
"live_state",
"live_conv",
"accepted",
],
output_names=["state_out", "conv_out"],
source=r"""
uint idx = thread_position_in_grid.x;

if (idx < StateTotal) {
uint dk = idx % Dk;
uint t0 = idx / Dk;
uint dv = t0 % Dv;
t0 /= Dv;
uint hv = t0 % Hv;
uint row = t0 / Hv;

int step = int(accepted[row]);
bool use_intermediate = step >= 0 && step < T;
StT value;
if (use_intermediate) {
value = intermediate_states[
((((row * T + uint(step)) * Hv + hv) * Dv + dv) * Dk + dk)
];
} else {
value = live_state[((row * Hv + hv) * Dv + dv) * Dk + dk];
}
state_out[idx] = static_cast<StT>(value);
}

if (idx < ConvTotal) {
uint c = idx % C;
uint t0 = idx / C;
uint win = t0 % ConvW;
uint row = t0 / ConvW;

int step = int(accepted[row]);
bool use_intermediate = step >= 0 && step < T;
ConvT value;
if (use_intermediate) {
value = conv_input[
(row * ConvInputT + uint(step) + 1 + win) * C + c
];
} else {
value = live_conv[(row * ConvW + win) * C + c];
}
conv_out[idx] = static_cast<ConvT>(value);
}
""",
)


_gated_delta_accept_states_kernel = _make_gated_delta_accept_states_kernel()


def _gated_delta_state_ops(
k: mx.array,
v: mx.array,
Expand Down Expand Up @@ -305,6 +402,85 @@ def _gated_delta_state_ops(
return state


def _gated_delta_accept_states_ops(
intermediate_states: mx.array,
conv_input: mx.array,
live_state: mx.array,
live_conv: mx.array,
accepted: mx.array,
kernel_size: int,
):
steps = [int(step) for step in accepted.tolist()]
state_rows = []
conv_rows = []
state_steps = intermediate_states.shape[1]
for row, step in enumerate(steps):
if 0 <= step < state_steps:
state_rows.append(intermediate_states[row, step])
conv_rows.append(conv_input[row : row + 1, step + 1 : step + kernel_size])
else:
state_rows.append(live_state[row])
conv_rows.append(live_conv[row : row + 1])
return mx.stack(state_rows, axis=0), mx.concatenate(conv_rows, axis=0)


def gated_delta_accept_states(
intermediate_states: mx.array,
conv_input: mx.array,
live_state: mx.array,
live_conv: mx.array,
accepted: mx.array,
kernel_size: int,
use_kernel: bool = True,
):
if accepted.dtype != mx.int32:
accepted = accepted.astype(mx.int32)

if (
not use_kernel
or mx.default_device() != mx.gpu
or not mx.metal.is_available()
or _gated_delta_accept_states_kernel is None
):
return _gated_delta_accept_states_ops(
intermediate_states,
conv_input,
live_state,
live_conv,
accepted,
kernel_size,
)

rows, state_steps, Hv, Dv, Dk = intermediate_states.shape
conv_input_t = conv_input.shape[1]
conv_dim = conv_input.shape[-1]
conv_window = int(kernel_size) - 1
state_total = rows * Hv * Dv * Dk
conv_total = rows * conv_window * conv_dim
total = max(state_total, conv_total)

return _gated_delta_accept_states_kernel(
inputs=[intermediate_states, conv_input, live_state, live_conv, accepted],
template=[
("StT", intermediate_states.dtype),
("ConvT", conv_input.dtype),
("T", state_steps),
("Hv", Hv),
("Dv", Dv),
("Dk", Dk),
("C", conv_dim),
("ConvW", conv_window),
("ConvInputT", conv_input_t),
("StateTotal", state_total),
("ConvTotal", conv_total),
],
grid=(total, 1, 1),
threadgroup=(256, 1, 1),
output_shapes=[live_state.shape, live_conv.shape],
output_dtypes=[intermediate_states.dtype, conv_input.dtype],
)


def gated_delta_state_update(
k: mx.array,
v: mx.array,
Expand All @@ -317,8 +493,7 @@ def gated_delta_state_update(
mask: Optional[mx.array] = None,
use_kernel: bool = True,
) -> mx.array:
beta = mx.sigmoid(b)
g = compute_g(A_log, a, dt_bias)
g, beta = _compute_g_beta(A_log, a, b, dt_bias)
if state is None:
B, _, _Hk, Dk = k.shape
Hv, Dv = v.shape[-2:]
Expand Down
Loading
Loading