fix(quant): resolve dtype mismatch casting F8E4M3 to BF16 in UnquantLinear (#2072) #2096
glaziermag wants to merge 1 commit into EricLBuehler:master
Conversation
fix(quant): resolve dtype mismatch casting F8E4M3 to BF16 in UnquantLinear

This fixes issue EricLBuehler#2072 where FP8-quantized weights could not be directly loaded into UnquantLinear fallback configurations due to unsupported native CUDA PTX casts in candle. UnquantLinear now safely routes F8E4M3 via scalar_fp8::ops::fp8_to_dtype.

Signed-off-by: Gabe <gabe@example.com>
Fixes Issue #2072.
This PR resolves a dtype mismatch panic that occurs when `UnquantLinear` attempts to dynamically cast weights from FP8-quantized tensors (`DType::F8E4M3`) natively. Since `candle_core::Tensor::to_dtype` does not yet emit the CUDA PTX instructions needed for a standard `F8E4M3 -> BF16` cast, this branch explicitly routes the cast through the `scalar_fp8::ops::fp8_to_dtype` bridge instead; a sketch of the routing follows.
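To make the change easier to review at a glance, here is a minimal sketch of that routing. It is illustrative only: the helper name `cast_weight_for_matmul` is hypothetical, and it assumes `scalar_fp8::ops::fp8_to_dtype` takes a tensor reference and a target `DType` and returns a `Result<Tensor>` (the actual dispatch lives inside `UnquantLinear::forward`).

```rust
use candle_core::{DType, Result, Tensor};

// Illustrative sketch, not the literal diff: weights that arrive as F8E4M3
// are routed through the scalar_fp8 bridge because Tensor::to_dtype has no
// native CUDA PTX cast for F8E4M3 -> BF16 yet.
fn cast_weight_for_matmul(w: &Tensor, target: DType) -> Result<Tensor> {
    match w.dtype() {
        // FP8 path: explicit bridge (assumed signature: (&Tensor, DType) -> Result<Tensor>).
        DType::F8E4M3 => crate::scalar_fp8::ops::fp8_to_dtype(w, target),
        // Already in the target dtype: reuse the tensor as-is.
        dt if dt == target => Ok(w.clone()),
        // Everything else keeps the existing native cast.
        _ => w.to_dtype(target),
    }
}
```

On non-FP8 inputs the behavior is unchanged, which is the property the baseline checks below exercise.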
I've run some local verification tests on an L4 GPU (`g2-standard-32`) to make sure this doesn't break existing `UnquantLinear` fallback logic for native models (like Llama-3 in BF16).

1. Blast Radius Assessment
- Files touched: `mistralrs-quant/src/unquantized/mod.rs` (`UnquantLinear::forward`).
- The `UnquantLinear` fallback algebra does not panic on baseline tests for unquantized models.

2. Execution Logs (Before vs. After)
Before logs (reproducing the issue):

$ cargo run --release --features cuda -- serve --port 1234 -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8

After logs (with this patch applied):

$ cargo run --release --features cuda -- serve --port 1234 -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8

3. Regression Tests Run
Ran the standard `mistralrs-quant` test suite natively with CUDA to help verify that the native operations over BF16 remain structurally intact:

$ cargo test --release --features cuda --package mistralrs-quant

A hypothetical targeted check for the new cast path is sketched below.
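For illustration only (not part of this PR), a targeted regression test for the new path might look like the following sketch. The test name and shapes are hypothetical; it assumes F8E4M3 tensors can be constructed with `Tensor::zeros` and that `scalar_fp8::ops::fp8_to_dtype` is callable on the chosen device (swap `Device::Cpu` for a CUDA device if the bridge op is CUDA-only).

```rust
#[cfg(test)]
mod fp8_bridge_cast_tests {
    use candle_core::{DType, Device, Result, Tensor};

    // Hypothetical regression check: casting an F8E4M3 weight to BF16 via the
    // bridge referenced in this PR should not panic, should produce BF16, and
    // should preserve the tensor shape.
    #[test]
    fn f8e4m3_bridge_cast_targets_bf16_and_keeps_shape() -> Result<()> {
        let dev = Device::Cpu;
        let w = Tensor::zeros((4, 8), DType::F8E4M3, &dev)?;
        let casted = crate::scalar_fp8::ops::fp8_to_dtype(&w, DType::BF16)?;
        assert_eq!(casted.dtype(), DType::BF16);
        assert_eq!(casted.dims(), w.dims());
        Ok(())
    }
}
```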
I'd appreciate any feedback on this approach! Thank you for reviewing!