
fix(quant): resolve dtype mismatch casting F8E4M3 to BF16 in UnquantLinear (#2072) #2096

Open

glaziermag wants to merge 1 commit into EricLBuehler:master from glaziermag:fix-fp8-unquantlinear-mismatch
Conversation

Contributor

@glaziermag commented Apr 10, 2026

Fixes Issue #2072.

This PR resolves a dtype mismatch panic that occurs when UnquantLinear attempts to cast weights from FP8-quantized tensors (DType::F8E4M3) at runtime. Since candle's Tensor::to_dtype does not yet ship CUDA PTX kernels for the F8E4M3 -> BF16 conversion, this branch routes that cast through the scalar_fp8::ops::fp8_to_dtype bridge instead of the native cast path.
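To make the dispatch concrete, here is a minimal, self-contained sketch of the branching logic described above. The DType enum, cast_native, and fp8_to_dtype below are mock stand-ins (the real types live in candle_core and mistralrs-quant, and the real functions operate on tensors, not dtypes); only the routing decision reflects the patch.

```rust
// Simplified sketch of the cast routing in UnquantLinear::forward.
// DType, cast_native, and fp8_to_dtype are hypothetical stand-ins for
// the candle / mistralrs-quant APIs; only the branching logic is real.
#[derive(Clone, Copy, PartialEq, Debug)]
enum DType {
    F8E4M3,
    BF16,
    F32,
}

// Stand-in for the scalar_fp8 bridge: dequantizes F8E4M3 to the target dtype.
fn fp8_to_dtype(target: DType) -> Result<DType, String> {
    Ok(target)
}

// Stand-in for the native cast: candle has no F8E4M3 -> BF16 PTX kernel,
// so a native cast from F8E4M3 fails with a dtype mismatch.
fn cast_native(src: DType, target: DType) -> Result<DType, String> {
    if src == DType::F8E4M3 && src != target {
        return Err(format!("unexpected dtype, expected: {target:?}, got: {src:?}"));
    }
    Ok(target)
}

/// Route F8E4M3 through the scalar_fp8 bridge; everything else casts natively.
fn cast_weight(src: DType, target: DType) -> Result<DType, String> {
    if src == DType::F8E4M3 && src != target {
        fp8_to_dtype(target)
    } else {
        cast_native(src, target)
    }
}

fn main() {
    // F8E4M3 weights now reach BF16 via the bridge instead of erroring.
    assert_eq!(cast_weight(DType::F8E4M3, DType::BF16), Ok(DType::BF16));
    // Native-precision models are unaffected.
    assert_eq!(cast_weight(DType::BF16, DType::BF16), Ok(DType::BF16));
    println!("dispatch ok");
}
```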

I've run local verification tests on an L4 GPU (g2-standard-32) to confirm this doesn't break the existing UnquantLinear fallback path for native-precision models (e.g. Llama-3 in BF16).

1. Blast Radius Assessment

  • Component: mistralrs-quant/src/unquantized/mod.rs (UnquantLinear::forward).
  • Potential Impact: Unquantized LLM inference paths that load weights in a non-native precision.
  • Verification: Tests confirm the UnquantLinear fallback path does not panic on baseline runs with unquantized models.
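For context on what the F8E4M3 -> BF16 bridge has to compute: per the OCP FP8 specification, E4M3 packs 1 sign bit, 4 exponent bits (bias 7), and 3 mantissa bits, has no infinities, and reserves only S.1111.111 for NaN (largest finite value: 448). A minimal, self-contained decoder sketch (illustrative only; not the kernel this PR uses):

```rust
/// Decode one OCP FP8 E4M3 byte to f32.
/// Layout: 1 sign bit, 4 exponent bits (bias 7), 3 mantissa bits.
/// E4M3 has no infinities; 0x7F / 0xFF are the only NaN encodings.
fn f8e4m3_to_f32(byte: u8) -> f32 {
    let sign = if byte & 0x80 != 0 { -1.0f32 } else { 1.0 };
    let exp = (byte >> 3) & 0x0F;
    let man = (byte & 0x07) as f32;
    if exp == 0x0F && byte & 0x07 == 0x07 {
        return f32::NAN; // the only non-finite encoding
    }
    if exp == 0 {
        // Subnormal: (mantissa / 8) * 2^(1 - bias) = (man / 8) * 2^-6
        sign * (man / 8.0) * (2.0f32).powi(-6)
    } else {
        // Normal: (1 + mantissa / 8) * 2^(exp - bias)
        sign * (1.0 + man / 8.0) * (2.0f32).powi(exp as i32 - 7)
    }
}

fn main() {
    assert_eq!(f8e4m3_to_f32(0x00), 0.0);   // +0
    assert_eq!(f8e4m3_to_f32(0x38), 1.0);   // exp=7, mantissa=0 -> 1.0
    assert_eq!(f8e4m3_to_f32(0xB8), -1.0);  // same magnitude, sign bit set
    assert_eq!(f8e4m3_to_f32(0x7E), 448.0); // largest finite E4M3 value
    println!("e4m3 decode ok");
}
```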

2. Execution Logs (Before vs. After)

Before logs (reproducing the issue):
$ cargo run --release --features cuda -- serve --port 1234 -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8

2026-04-10T02:03:31.595045Z  INFO mistralrs_core::models::llama: Using fp8 quantization: 8 bits.
2026-04-10T02:04:13.296123Z ERROR mistralrs_core::engine: step - Model failed with error: unexpected dtype, expected: BF16, got: F8E4M3

After logs (with this patch applied):
$ cargo run --release --features cuda -- serve --port 1234 -m neuralmagic/Meta-Llama-3-8B-Instruct-FP8

2026-04-10T02:15:16.111630Z  INFO mistralrs_server_core::mistralrs_for_server_builder: Model loaded.
2026-04-10T02:15:16.112250Z  INFO mistralrs_core: Beginning dummy run.
2026-04-10T02:15:24.210607Z  INFO mistralrs_core: Dummy run completed in 8.098182927s.
2026-04-10T02:15:24.214426Z  INFO mistralrs::commands::serve: Server listening on http://0.0.0.0:1234

3. Regression Tests Run

Ran the standard mistralrs-quant test suite with CUDA enabled to verify that the existing BF16 operations remain intact.
$ cargo test --release --features cuda --package mistralrs-quant

...
test fp8::quantize::tests::test_cublaslt_matmul ... ok
test hqq::quantize::test::test_quantize_hqq ... ok
test cublaslt::tests::test_fused_batch_matmul_f8e4m3_nobias ... ok
test cublaslt::tests::test_fused_batch_matmul_f8e4m3_out_bf16 ... ok
test blockwise_fp8::ops::tests::test_blockwise_fp8_gemm ... ok

test result: ok. 41 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 8.96s

I'd appreciate any feedback on this approach! Thank you for reviewing!

…inear

This fixes issue EricLBuehler#2072 where FP8-quantized weights could not be directly loaded into UnquantLinear fallback configurations due to unsupported native CUDA PTX casts in candle. UnquantLinear now safely routes F8E4M3 via scalar_fp8::ops::fp8_to_dtype.

Signed-off-by: Gabe <gabe@example.com>