fix(quant): prevent device mismatch in GEMV guard and UnquantLinear forward#2089

Open
Jamesrobertsonldn wants to merge 1 commit into EricLBuehler:master from Jamesrobertsonldn:upstream-fix

Conversation

@Jamesrobertsonldn

Summary

Fixes a crash that occurs when weights end up on a different device than the activations during inference. This can happen when weights are loaded from the UQFF cache or created by fusing dequantized layers.

Bug 1: GEMV guard only checks activation device

should_use_gemv() checks x.device().is_cuda() but never verifies w.device().is_cuda(). When w is on CPU, the CUDA GEMV kernel is dispatched and crashes.

Fix: Add w.device().is_cuda() to the guard.

Bug 2: UnquantLinear assumes weight device matches activation

UnquantLinear::forward() passes self.w directly to matmul operations. When the weight is on CPU but the activation is on GPU, candle's matmul fails with "device mismatch".

Fix: Lazily migrate weight to activation's device using Cow<Tensor> — zero-cost when devices already match, one-time migration on first mismatch.

Reproduction

Observed on Qwen3.5-4B with Q4K quantization on Apple Metal (M4):

  • Model loads from UQFF cache successfully
  • First inference attempt crashes with device mismatch in matmul
  • Affects all quantized model tiers that use ISQ/UQFF

Changes

  • mistralrs-quant/src/gemv/mod.rs: Add w.device() check to should_use_gemv()
  • mistralrs-quant/src/unquantized/mod.rs: Device migration in UnquantLinear::forward()

🤖 Generated with Claude Code
