fix(quant): prevent device mismatch in GEMV guard and UnquantLinear forward #2089
Open
Jamesrobertsonldn wants to merge 1 commit into EricLBuehler:master from
Conversation
…orward

Two bugs can cause crashes when weights end up on a different device than activations (e.g. CPU weights with CUDA/Metal activations):

1. `should_use_gemv()` only checks `x.device().is_cuda()` but never verifies `w.device().is_cuda()`. When `w` is on CPU, the CUDA GEMV kernel crashes with a device mismatch.
2. `UnquantLinear::forward()` assumes `self.w` is on the same device as the activation. When weights are loaded from the UQFF cache or created by fusing dequantized layers, they can end up on CPU. The standard matmul path also fails with "device mismatch in matmul".

Fix:
- Add a `w.device()` check to the `should_use_gemv()` guard
- Lazily migrate the weight to the activation's device in `forward()` using `Cow<Tensor>` (zero-cost when devices already match)
Summary
Fixes a crash when weights end up on a different device than activations during inference. This can happen when weights are loaded from UQFF cache or created by fusing dequantized layers.
Bug 1: GEMV guard only checks activation device
`should_use_gemv()` checks `x.device().is_cuda()` but never verifies `w.device().is_cuda()`. When `w` is on CPU, the CUDA GEMV kernel is dispatched and crashes. Fix: add a `w.device().is_cuda()` check to the guard.

Bug 2: UnquantLinear assumes weight device matches activation
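The corrected guard logic can be sketched as follows. This is a minimal stand-in, not the PR's actual code: `Device` and `Tensor` here are mock types standing in for candle's, and the `should_use_gemv` signature is illustrative.

```rust
// Stand-in types for candle's Device / Tensor, kept minimal on purpose.
#[derive(Clone, Copy, PartialEq, Debug)]
enum Device {
    Cpu,
    Cuda,
}

struct Tensor {
    device: Device,
}

impl Tensor {
    fn device(&self) -> Device {
        self.device
    }
}

// Before the fix the guard only inspected the activation `x`; after the fix,
// BOTH tensors must be CUDA-resident before the CUDA GEMV kernel is dispatched.
fn should_use_gemv(x: &Tensor, w: &Tensor) -> bool {
    x.device() == Device::Cuda && w.device() == Device::Cuda
}

fn main() {
    let x = Tensor { device: Device::Cuda };
    let w_cpu = Tensor { device: Device::Cpu };
    let w_gpu = Tensor { device: Device::Cuda };
    // A CPU weight with a CUDA activation must fall through to the generic path.
    assert!(!should_use_gemv(&x, &w_cpu));
    assert!(should_use_gemv(&x, &w_gpu));
}
```

Without the second condition, the pre-fix guard would return `true` for the `(x, w_cpu)` pair and dispatch the CUDA kernel against a CPU-resident weight.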
`UnquantLinear::forward()` passes `self.w` directly to matmul operations. When the weight is on CPU but the activation is on GPU, candle's matmul fails with "device mismatch". Fix: lazily migrate the weight to the activation's device using `Cow<Tensor>`: zero-cost when devices already match, a one-time migration on first mismatch.

Reproduction
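The `Cow<Tensor>` pattern described above can be sketched like this. Again a hedged mock: `Device`, `Tensor`, and the helper `weight_on` are stand-ins (candle's real `Tensor::to_device` copies data across devices; here it only swaps the tag).

```rust
use std::borrow::Cow;

#[derive(Clone, Copy, PartialEq, Debug)]
enum Device {
    Cpu,
    Cuda,
}

#[derive(Clone)]
struct Tensor {
    device: Device,
}

impl Tensor {
    fn device(&self) -> Device {
        self.device
    }
    // Stand-in for candle's `Tensor::to_device`; real code copies the buffer.
    fn to_device(&self, dev: Device) -> Tensor {
        Tensor { device: dev }
    }
}

// Borrow the weight when devices already match; materialize a copy on the
// activation's device only on a mismatch. `Cow` makes the matching case
// allocation-free, which is the "zero-cost" property the fix relies on.
fn weight_on<'a>(w: &'a Tensor, x: &Tensor) -> Cow<'a, Tensor> {
    if w.device() == x.device() {
        Cow::Borrowed(w)
    } else {
        Cow::Owned(w.to_device(x.device()))
    }
}

fn main() {
    let x = Tensor { device: Device::Cuda };
    let w = Tensor { device: Device::Cpu };
    // Mismatch: an owned, migrated copy is produced.
    let migrated = weight_on(&w, &x);
    assert!(matches!(migrated, Cow::Owned(_)));
    assert_eq!(migrated.device(), Device::Cuda);
    // Match: a plain borrow, no copy.
    assert!(matches!(weight_on(&x, &x), Cow::Borrowed(_)));
}
```

In `forward()` the returned `Cow` would then be dereferenced for the matmul call, so the happy path pays no clone.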
Observed on Qwen3.5-4B with Q4K quantization on Apple Metal (M4):
Changes
- `mistralrs-quant/src/gemv/mod.rs`: Add `w.device()` check to `should_use_gemv()`
- `mistralrs-quant/src/unquantized/mod.rs`: Device migration in `UnquantLinear::forward()`

🤖 Generated with Claude Code