From 3c044aaccd34ca0484035ddcc93b766be7618d72 Mon Sep 17 00:00:00 2001 From: Jack Han Date: Tue, 21 Apr 2026 16:14:52 +0900 Subject: [PATCH 1/8] Add MXFP4 packed export, precision-aware sensitivity scorer, and AWQ/GPTQ support - Add MXFP4 packing utilities (mxfp4_pack.py): FP4 E2M1 + E8M0 scale format achieving 3.76x compression ratio for OCP MX specification compliance - Fix sensitivity scorer to use actual target precision (OCP_MXFP4Spec) instead of hardcoded INT4 proxy, reducing false exclusions from 68% to 5% of layers - Wire AWQ/GPTQ algorithm support via LLMTemplate.get_config(algorithm=...) - Add pack_mxfp4 config flag to control packed vs BF16 export - Add solar_open model type mapping to qwen3_moe template - Fix Int4PerGroupSpec fallback for Quark 0.11.1 compatibility (ch_axis param) Tested on MI355 (gfx950) with Solar-Open-100B: Packed checkpoint: 53GB (vs 192GB original, 3.62x compression) MMLU: 76.14% (-1.44% from baseline 77.58%) KMMLU: 57.03% (-0.35% from baseline 57.38%) --- CLAUDE.md | 73 +++++ {references => contribs}/llm-compressor | 0 {references => contribs}/transformers | 0 {references => contribs}/vllm | 0 src/quanto/constants.py | 3 +- src/quanto/core/config.py | 15 + .../core/sensitivity/sequential_analyzer.py | 69 +++- src/quanto/core/unified_quantizer.py | 126 +++++++- src/quanto/utils/__init__.py | 4 + src/quanto/utils/mxfp4_pack.py | 304 ++++++++++++++++++ tests/test_mxfp4_pack.py | 186 +++++++++++ 11 files changed, 769 insertions(+), 11 deletions(-) create mode 100644 CLAUDE.md rename {references => contribs}/llm-compressor (100%) rename {references => contribs}/transformers (100%) rename {references => contribs}/vllm (100%) create mode 100644 src/quanto/utils/mxfp4_pack.py create mode 100644 tests/test_mxfp4_pack.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..77a3a5a --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,73 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## What is Quanto + +Quanto is an LLM quantization toolkit built on AMD Quark. It quantizes HuggingFace models to INT4/INT8/FP8/MXFP4/MXFP6 precisions with multiple memory strategies for different GPU constraints. Source code lives in `src/quanto/`. + +## Commands + +```bash +# Install +pip install -e ".[dev]" # dev (pytest, ruff) +pip install -e ".[nvidia]" # with NVIDIA extras +pip install -e ".[rocm]" # with ROCm extras + +# Tests +pytest tests/ -v # all tests +pytest tests/test_unified_quantizer.py -v # single file +pytest tests/test_unified_quantizer.py::TestUnifiedConfig::test_default_config -v # single test + +# Lint & format +ruff check src/ # lint +ruff check src/ --fix # lint with autofix +ruff format src/ # format + +# Quantize a model (CLI) +python -m quanto --model_path /path/to/model --output_dir ./output --precision int4 +python -m quanto --model_path /path --sensitivity_analysis --sensitivity_threshold 0.12 + +# Dequantize +python -m quanto --dequantize --model_path ./quantized --output_dir ./dequantized + +# Docker-based integration tests +./scripts/run_tests.sh --gpu nvidia --test all +``` + +## Architecture + +### Pipeline flow +CLI (`__main__.py`) -> `UnifiedConfig` (dataclass validation) -> `UnifiedQuantizer.run()` -> `QuantizationResult` + +### Core modules (`src/quanto/core/`) +- **`config.py`** — `UnifiedConfig` dataclass with ~23 fields and `__post_init__` validation. `QuantizationConfig` is a backward-compat alias. 
+- **`unified_quantizer.py`** — Main quantizer implementing 4 memory strategies: `full` (entire model on GPU), `layerwise_cpu` (model on CPU, layers quantized one-by-one on GPU), `lazy` (weights loaded on-demand from safetensors), `auto` (selects based on model size vs GPU memory). +- **`base_quantizer.py`** — Abstract base class, `QuantizationResult` dataclass. +- **`dequantize.py`** — INT4 -> BF16/FP16 conversion. +- **`sensitivity/`** — Sequential sensitivity analysis: `SequentialSensitivityAnalyzer` scores per-layer quantization impact, `ActivationCache` manages GPU/CPU caching, `SensitivityScorer` computes perplexity-based metrics. + +### Supporting modules +- **`constants.py`** — `PRECISION_TO_SCHEME` mapping (e.g., `"int4"` -> `"int4_wo_128"`), `MODEL_TYPE_MAPPINGS`, `DEFAULT_EXCLUDE_PATTERNS`. +- **`analysis/layer_analyzer.py`** — Automatic detection of layers to exclude (lm_head, MoE gates, embeddings/norms with aggressive mode). +- **`utils/calibration.py`** — `CalibrationDataManager` loads from HuggingFace datasets or local files. +- **`utils/int4_pack.py`** — INT4 <-> INT32 packing/unpacking. +- **`utils/memory.py`** — GPU memory tracking and cleanup. +- **`utils/model_utils.py`** — Model type detection and Quark template lookup. + +### External dependency +AMD Quark is vendored as a git submodule in `contribs/quark/`. It provides the quantization scheme templates for each model architecture. + +## Code style + +- Ruff configured: 100-char line length, Python 3.10 target +- Lint rules: E, W, F, I (isort), B (bugbear), C4, UP, ARG, SIM +- Double quotes, space indentation +- `contribs/` directory is excluded from linting + +## Key patterns + +- **Backward compatibility aliases**: `QuantizationConfig = UnifiedConfig`, `AutoQuantizer` wraps `UnifiedQuantizer` +- **Valid precisions**: `int4`, `int4_64`, `int4_32`, `int8`, `fp8`, `mxfp4`, `mxfp6`, `uint4` +- **Memory strategies**: `full`, `layerwise_cpu`, `lazy`, `auto` +- **Export formats**: `quark` (native, default), `awq`, `gptq` (vLLM compat, INT4 only) diff --git a/references/llm-compressor b/contribs/llm-compressor similarity index 100% rename from references/llm-compressor rename to contribs/llm-compressor diff --git a/references/transformers b/contribs/transformers similarity index 100% rename from references/transformers rename to contribs/transformers diff --git a/references/vllm b/contribs/vllm similarity index 100% rename from references/vllm rename to contribs/vllm diff --git a/src/quanto/constants.py b/src/quanto/constants.py index 40aa38b..7421a13 100644 --- a/src/quanto/constants.py +++ b/src/quanto/constants.py @@ -46,6 +46,7 @@ "phi": "phi", "phi3": "phi3", "phi4": "phi3", + "solar_open": "qwen3_moe", } # Default layers to exclude from quantization @@ -78,7 +79,7 @@ # Supported quantization algorithms SUPPORTED_ALGORITHMS: list[str] = [ + "rtn", "awq", "gptq", - "smoothquant", ] diff --git a/src/quanto/core/config.py b/src/quanto/core/config.py index 5f44c3a..db9c506 100644 --- a/src/quanto/core/config.py +++ b/src/quanto/core/config.py @@ -94,6 +94,13 @@ class UnifiedConfig: # Layer batch size for lazy mode (number of layers to process in parallel) layer_batch_size: int = 4 + # Pack MXFP4 weights to compressed format (FP4 + E8M0 scales) + # Set False to keep BF16-stored weights for evaluation with lm-eval + pack_mxfp4: bool = True + + # Quantization algorithm: "rtn" (round-to-nearest, default), "awq", "gptq" + algorithm: str = "rtn" + def __post_init__(self) -> None: """Validate configuration after 
initialization.""" self.validate() @@ -147,6 +154,13 @@ def validate(self) -> None: if self.max_iterations < 1: raise ValueError(f"max_iterations must be >= 1, got {self.max_iterations}") + # Validate algorithm + valid_algorithms = ["rtn", "awq", "gptq"] + if self.algorithm not in valid_algorithms: + raise ValueError( + f"Invalid algorithm '{self.algorithm}'. Must be one of: {valid_algorithms}" + ) + def to_dict(self) -> dict[str, Any]: """Convert configuration to dictionary.""" return { @@ -171,6 +185,7 @@ def to_dict(self) -> dict[str, Any]: "trust_remote_code": self.trust_remote_code, "debug_dir": self.debug_dir, "layer_batch_size": self.layer_batch_size, + "algorithm": self.algorithm, } @classmethod diff --git a/src/quanto/core/sensitivity/sequential_analyzer.py b/src/quanto/core/sensitivity/sequential_analyzer.py index 7c1fac9..a8e4cfa 100644 --- a/src/quanto/core/sensitivity/sequential_analyzer.py +++ b/src/quanto/core/sensitivity/sequential_analyzer.py @@ -52,6 +52,7 @@ def __init__( metric: SensitivityMetric = SensitivityMetric.RELATIVE_NORM, cache_on_gpu: bool = True, initial_exclude_layers: list[str] | None = None, + template: object | None = None, ): """ Initialize the analyzer. @@ -61,11 +62,13 @@ def __init__( metric: Sensitivity metric to use cache_on_gpu: Store activations on GPU by default initial_exclude_layers: Layers to skip during analysis (already excluded) + template: LLMTemplate instance for precision-aware quantization config """ self.config = config self.metric = metric self.cache_on_gpu = cache_on_gpu self.initial_exclude_layers = initial_exclude_layers or [] + self.template = template # Components self.cache = ActivationCache( @@ -324,6 +327,10 @@ def _quantize_layer(self, layer: nn.Module, layer_name: str) -> nn.Module: """ Quantize a single layer for sensitivity testing. + Uses the actual target precision (MXFP4, INT4, FP8, etc.) rather than + a hardcoded INT4 proxy, so sensitivity scores accurately reflect the + quantization scheme being applied. + Args: layer: The layer module to quantize layer_name: Name of the layer @@ -332,15 +339,10 @@ def _quantize_layer(self, layer: nn.Module, layer_name: str) -> nn.Module: Quantized layer """ from quark.torch import ModelQuantizer - from quark.torch.quantization.config.config import Int4PerGroupSpec, QConfig, QLayerConfig - - # Create quantization config - # ch_axis=0 for per-row quantization (output channel dimension) - quant_config = QConfig( - global_quant_config=QLayerConfig( - weight=Int4PerGroupSpec(ch_axis=0, group_size=128).to_quantization_spec() - ), - ) + + from ...constants import PRECISION_TO_SCHEME + + quant_config = self._build_quant_config_for_scoring() # Quantize quantizer = ModelQuantizer(quant_config) @@ -355,6 +357,55 @@ def _quantize_layer(self, layer: nn.Module, layer_name: str) -> nn.Module: return quantized_layer + def _build_quant_config_for_scoring(self): + """ + Build quantization config matching the target precision. + + Uses the LLMTemplate if available (produces architecture-specific configs), + otherwise falls back to building a config from the precision's Quark Spec class. 
+ """ + from quark.torch.quantization.config.config import QConfig, QLayerConfig + + from ...constants import PRECISION_TO_SCHEME + + precision = self.config.precision + scheme = PRECISION_TO_SCHEME.get(precision, precision) + + # Prefer template-based config (architecture-specific) + if self.template: + return self.template.get_config( + scheme=scheme, + exclude_layers=[], + ) + + # Fallback: build config from precision spec + if precision.startswith("mxfp4"): + from quark.torch.quantization.config.config import OCP_MXFP4Spec + + spec = OCP_MXFP4Spec(ch_axis=0).to_quantization_spec() + elif precision.startswith("mxfp6"): + from quark.torch.quantization.config.config import OCP_MXFP6E3M2Spec + + spec = OCP_MXFP6E3M2Spec(ch_axis=0).to_quantization_spec() + elif precision.startswith("int4") or precision.startswith("uint4"): + from quark.torch.quantization.config.config import Int4PerGroupSpec + + group_size = 128 + if "64" in precision: + group_size = 64 + elif "32" in precision: + group_size = 32 + spec = Int4PerGroupSpec(ch_axis=0, group_size=group_size).to_quantization_spec() + else: + # Default fallback to INT4 for unknown precisions + from quark.torch.quantization.config.config import Int4PerGroupSpec + + spec = Int4PerGroupSpec(ch_axis=0, group_size=128).to_quantization_spec() + + return QConfig( + global_quant_config=QLayerConfig(weight=spec), + ) + def analyze(self) -> AnalysisResult: """ Run sequential sensitivity analysis. diff --git a/src/quanto/core/unified_quantizer.py b/src/quanto/core/unified_quantizer.py index 8d78cf2..5df907f 100644 --- a/src/quanto/core/unified_quantizer.py +++ b/src/quanto/core/unified_quantizer.py @@ -262,6 +262,7 @@ def _run_sequential_sensitivity_analysis(self) -> list[str]: analyzer = SequentialSensitivityAnalyzer( config=self.config, cache_on_gpu=self.config.sensitivity_cache_on_gpu, + template=self.template, ) result = analyzer.analyze() @@ -332,6 +333,7 @@ def _run_iterative_sensitivity_analysis(self) -> list[str]: config=self.config, cache_on_gpu=cache_on_gpu, initial_exclude_layers=all_excluded, + template=self.template, ) # Run analysis @@ -447,16 +449,22 @@ def _create_quant_config(self, exclude_layers: list[str]) -> QConfig: quant_scheme = self._get_quant_scheme() self._log(f"Using quantization scheme: {quant_scheme}") + # Determine algorithm (None for RTN, "awq"/"gptq" for advanced) + algorithm = self.config.algorithm if self.config.algorithm != "rtn" else None + if algorithm: + self._log(f"Using quantization algorithm: {algorithm}") + # Create base quant config if self.template: quant_config = self.template.get_config( scheme=quant_scheme, + algorithm=algorithm, exclude_layers=exclude_layers, ) else: quant_config = QConfig( global_quant_config=QLayerConfig( - weight=Int4PerGroupSpec(group_size=128).to_quantization_spec() + weight=Int4PerGroupSpec(group_size=128, ch_axis=0).to_quantization_spec() ), exclude=exclude_layers, ) @@ -853,6 +861,11 @@ def _run_lazy_quantization(self) -> QuantizationResult: self._log("\n=== Assembling HuggingFace format ===") self._assemble_hf_format() + # Pack MXFP4 weights for actual compression + if self.config.precision.startswith("mxfp") and self.config.pack_mxfp4: + self._log("Packing MXFP4 weights...") + self._pack_mxfp4_weights(exclude_layers) + self.timing["total"] = time.time() - total_start result.success = True @@ -1160,6 +1173,11 @@ def _run_full_gpu_quantization(self) -> QuantizationResult: self.tokenizer.save_pretrained(self.config.output_dir) + # Pack MXFP4 weights for actual compression + if 
self.config.precision.startswith("mxfp") and self.config.pack_mxfp4: + self._log("Packing MXFP4 weights...") + self._pack_mxfp4_weights(exclude_layers) + self.timing["total"] = time.time() - total_start result.success = True @@ -1180,6 +1198,112 @@ def _run_full_gpu_quantization(self) -> QuantizationResult: return result + def _pack_mxfp4_weights(self, exclude_layers: list[str]) -> None: + """ + Post-process exported safetensors to pack MXFP4 weights. + + Replaces BF16 dequantized weights with packed FP4 + E8M0 scale format + for actual compression. Only processes quantized Linear weights + (those NOT in the exclude list). + """ + from ..utils.mxfp4_pack import pack_mxfp4 + + output_dir = Path(self.config.output_dir) + index_file = output_dir / "model.safetensors.index.json" + + if not index_file.exists(): + self._log("Warning: No safetensors index found, skipping MXFP4 packing") + return + + with open(index_file) as f: + index = json.load(f) + + weight_map = index.get("weight_map", {}) + group_size = 32 # MXFP4 default group size + + # Identify quantizable weight keys (2D weights not in exclude list) + import fnmatch + + def is_excluded(name: str) -> bool: + for pattern in exclude_layers: + if fnmatch.fnmatch(name, pattern) or pattern in name: + return True + return False + + # Group weights by shard file + shard_files = set(weight_map.values()) + total_packed = 0 + + for shard_name in sorted(shard_files): + shard_path = output_dir / shard_name + if not shard_path.exists(): + continue + + # Load all tensors from this shard + tensors = {} + with safe_open(str(shard_path), framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + + # Pack eligible weights + new_tensors = {} + new_weight_map_updates = {} + + for key, tensor in tensors.items(): + # Check if this is a quantizable Linear weight + if ( + key.endswith(".weight") + and len(tensor.shape) == 2 + and not is_excluded(key) + and tensor.shape[-1] % 2 == 0 # Must be even for packing + ): + # Strip .weight suffix to get layer name + layer_name = key[: -len(".weight")] + + # Check if this layer was actually quantized (not in exclude) + layer_excluded = is_excluded(layer_name) + if not layer_excluded: + packed = pack_mxfp4(tensor, group_size=group_size) + packed_key = f"{layer_name}.weight.packed" + scale_key = f"{layer_name}.weight.scale_e8m0" + + new_tensors[packed_key] = packed["weight.packed"] + new_tensors[scale_key] = packed["weight.scale_e8m0"] + new_weight_map_updates[packed_key] = shard_name + new_weight_map_updates[scale_key] = shard_name + total_packed += 1 + continue + + # Keep non-quantized tensors as-is + new_tensors[key] = tensor + + # Re-save shard with packed weights + save_file(new_tensors, str(shard_path)) + + # Update weight map + for old_key in list(weight_map.keys()): + if weight_map[old_key] == shard_name and old_key not in new_tensors: + del weight_map[old_key] + weight_map.update(new_weight_map_updates) + + # Save updated index + index["weight_map"] = weight_map + with open(index_file, "w") as f: + json.dump(index, f, indent=2) + + # Update config.json with packed format info + config_file = output_dir / "config.json" + if config_file.exists(): + with open(config_file) as f: + config = json.load(f) + if "quantization_config" in config: + config["quantization_config"]["weight_format"] = "packed_mxfp4" + config["quantization_config"]["mxfp4_group_size"] = group_size + with open(config_file, "w") as f: + json.dump(config, f, indent=2) + + self._log(f"Packed {total_packed} weight 
tensors to MXFP4 format") + def _print_summary(self, result: QuantizationResult) -> None: """Print quantization summary.""" print("\n" + "=" * 60) diff --git a/src/quanto/utils/__init__.py b/src/quanto/utils/__init__.py index 9f6a31f..74e59d9 100644 --- a/src/quanto/utils/__init__.py +++ b/src/quanto/utils/__init__.py @@ -13,6 +13,7 @@ quantize_to_int4, unpack_int32_to_int4, ) +from .mxfp4_pack import pack_mxfp4, unpack_mxfp4 from .logging import Timer, get_logger, log_with_timestamp from .memory import ( clear_gpu_memory, @@ -41,6 +42,9 @@ "pack_int4_to_int32", "unpack_int32_to_int4", "pack_layer_weights", + # MXFP4 Packing + "pack_mxfp4", + "unpack_mxfp4", # Logging "get_logger", "log_with_timestamp", diff --git a/src/quanto/utils/mxfp4_pack.py b/src/quanto/utils/mxfp4_pack.py new file mode 100644 index 0000000..c332864 --- /dev/null +++ b/src/quanto/utils/mxfp4_pack.py @@ -0,0 +1,304 @@ +""" +MXFP4 Weight Packing Utilities + +Packs BF16 weights into OCP MX FP4 format with E8M0 shared scales. +This reduces model size by ~3.76x compared to BF16. + +MXFP4 format (OCP MX Specification): +- FP4 (E2M1): 1 sign + 1 exponent + 2 mantissa = 4 bits per value +- E8M0 scale: 8-bit shared exponent per group of 32 elements +- 2 FP4 values packed per uint8 byte + +FP4 E2M1 representable values (magnitude): + 0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 +""" + +from __future__ import annotations + +import math + +import torch + + +# FP4 E2M1 encoding: 4-bit code -> float value +# Bit layout: [sign(1)] [exponent(1)] [mantissa(2)] +# Code 0b_s_e_mm +FP4_E2M1_TABLE = torch.tensor([ + 0.0, # 0b0_0_00 = 0 + 0.5, # 0b0_0_01 = 0.5 + 1.0, # 0b0_0_10 = 1.0 + 1.5, # 0b0_0_11 = 1.5 + 2.0, # 0b0_1_00 = 2.0 + 3.0, # 0b0_1_01 = 3.0 + 4.0, # 0b0_1_10 = 4.0 + 6.0, # 0b0_1_11 = 6.0 + -0.0, # 0b1_0_00 = -0 (treated as 0) + -0.5, # 0b1_0_01 = -0.5 + -1.0, # 0b1_0_10 = -1.0 + -1.5, # 0b1_0_11 = -1.5 + -2.0, # 0b1_1_00 = -2.0 + -3.0, # 0b1_1_01 = -3.0 + -4.0, # 0b1_1_10 = -4.0 + -6.0, # 0b1_1_11 = -6.0 +], dtype=torch.float32) + +# Max representable magnitude in FP4 E2M1 +FP4_MAX = 6.0 + + +def compute_e8m0_scales(weight: torch.Tensor, group_size: int = 32) -> torch.Tensor: + """ + Compute E8M0 shared scales for MXFP4 quantization. + + E8M0 format: 8-bit exponent only (no mantissa), representing power-of-2 scales. + scale = 2^(e8m0_code - 127) (IEEE 754 bias) + + The scale is chosen so that max(abs(group)) / scale <= FP4_MAX (6.0). 
+ + Args: + weight: [..., in_features] tensor to compute scales for + group_size: Number of elements per group (default 32) + + Returns: + E8M0 scale codes as uint8 tensor [..., num_groups] + """ + orig_shape = weight.shape + in_features = orig_shape[-1] + + # Pad if not divisible by group_size + if in_features % group_size != 0: + pad_size = group_size - (in_features % group_size) + weight = torch.nn.functional.pad(weight, (0, pad_size)) + in_features = weight.shape[-1] + + num_groups = in_features // group_size + + # Reshape to [..., num_groups, group_size] + grouped = weight.reshape(*orig_shape[:-1], num_groups, group_size) + + # Max absolute value per group + group_max = grouped.abs().amax(dim=-1).float() + + # Avoid log2(0) — use a small floor + group_max = group_max.clamp(min=1e-12) + + # E8M0 exponent: floor(log2(group_max / FP4_MAX)) + 127 (IEEE bias) + # scale = 2^(code - 127), so code = floor(log2(group_max / FP4_MAX)) + 127 + exponent = torch.floor(torch.log2(group_max / FP4_MAX)).to(torch.int32) + 127 + + # Clamp to valid E8M0 range [0, 254] (255 = NaN/Inf in E8M0) + exponent = exponent.clamp(0, 254) + + return exponent.to(torch.uint8) + + +def e8m0_to_float(e8m0_codes: torch.Tensor) -> torch.Tensor: + """Convert E8M0 scale codes to float scale values. + + Args: + e8m0_codes: uint8 tensor of E8M0 exponent codes + + Returns: + Float tensor of scale values: 2^(code - 127) + """ + return torch.pow(2.0, e8m0_codes.to(torch.float32) - 127.0) + + +def quantize_to_fp4( + weight: torch.Tensor, + e8m0_scales: torch.Tensor, + group_size: int = 32, +) -> torch.Tensor: + """ + Quantize BF16 weight to FP4 E2M1 codes using E8M0 scales. + + Args: + weight: [..., in_features] BF16/FP32 tensor + e8m0_scales: [..., num_groups] uint8 E8M0 scale codes + group_size: Elements per group + + Returns: + [..., in_features] uint8 tensor with FP4 codes (values 0-15) + """ + orig_shape = weight.shape + in_features = orig_shape[-1] + + # Pad if needed + if in_features % group_size != 0: + pad_size = group_size - (in_features % group_size) + weight = torch.nn.functional.pad(weight, (0, pad_size)) + in_features = weight.shape[-1] + + num_groups = in_features // group_size + + # Convert scales to float + scales = e8m0_to_float(e8m0_scales) # [..., num_groups] + + # Reshape weight to [..., num_groups, group_size] + grouped = weight.reshape(*orig_shape[:-1], num_groups, group_size).float() + + # Divide by scale: normalized = weight / scale + # scales shape [..., num_groups] -> [..., num_groups, 1] + normalized = grouped / scales.unsqueeze(-1) + + # Round to nearest FP4 value using the lookup table + # FP4 positive magnitudes: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + sign = normalized.sign() + magnitude = normalized.abs() + + # Magnitude-only FP4 table for nearest-value lookup + fp4_magnitudes = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], + dtype=torch.float32, + device=weight.device, + ) + + # Find nearest FP4 magnitude for each value + # Compute distance to each FP4 magnitude + diffs = (magnitude.unsqueeze(-1) - fp4_magnitudes).abs() + mag_codes = diffs.argmin(dim=-1) # index 0-7 + + # Combine sign and magnitude into 4-bit code + # sign bit = 0 for positive, 1 for negative (bit 3) + sign_bit = (sign < 0).to(torch.uint8) << 3 + fp4_codes = sign_bit | mag_codes.to(torch.uint8) + + # Reshape back to [..., in_features] + fp4_codes = fp4_codes.reshape(*orig_shape[:-1], in_features) + + # Trim padding + if in_features != orig_shape[-1]: + fp4_codes = fp4_codes[..., : orig_shape[-1]] + + return fp4_codes 
+ + +def pack_fp4_to_uint8(fp4_codes: torch.Tensor) -> torch.Tensor: + """ + Pack two FP4 (4-bit) values into one uint8 byte. + + Packing order: low nibble first, high nibble second. + byte = fp4_codes[2i] | (fp4_codes[2i+1] << 4) + + Args: + fp4_codes: [..., N] uint8 tensor with FP4 codes (0-15) + + Returns: + [..., N//2] uint8 tensor with packed FP4 pairs + """ + *batch, n = fp4_codes.shape + + # Pad if odd number of elements + if n % 2 != 0: + fp4_codes = torch.nn.functional.pad(fp4_codes, (0, 1)) + n = fp4_codes.shape[-1] + + # Reshape to [..., N//2, 2] + pairs = fp4_codes.reshape(*batch, n // 2, 2) + + # Pack: low nibble = first value, high nibble = second value + packed = (pairs[..., 0] & 0x0F) | ((pairs[..., 1] & 0x0F) << 4) + + return packed.to(torch.uint8) + + +def unpack_uint8_to_fp4(packed: torch.Tensor) -> torch.Tensor: + """ + Unpack uint8 bytes to two FP4 codes each. + + Args: + packed: [..., N] uint8 tensor + + Returns: + [..., N*2] uint8 tensor with FP4 codes (0-15) + """ + low = packed & 0x0F + high = (packed >> 4) & 0x0F + + # Interleave: [low0, high0, low1, high1, ...] + unpacked = torch.stack([low, high], dim=-1) + return unpacked.reshape(*packed.shape[:-1], packed.shape[-1] * 2) + + +def pack_mxfp4( + weight: torch.Tensor, + group_size: int = 32, +) -> dict[str, torch.Tensor]: + """ + Full pipeline: BF16 weight -> packed MXFP4. + + Args: + weight: [out_features, in_features] BF16 tensor + group_size: MXFP4 group size (default 32) + + Returns: + Dict with: + - "weight.packed": [out_features, in_features // 2] uint8 (packed FP4 pairs) + - "weight.scale_e8m0": [out_features, in_features // group_size] uint8 + """ + # Step 1: Compute E8M0 scales + e8m0_scales = compute_e8m0_scales(weight, group_size) + + # Step 2: Quantize to FP4 codes + fp4_codes = quantize_to_fp4(weight, e8m0_scales, group_size) + + # Step 3: Pack FP4 codes into uint8 + packed = pack_fp4_to_uint8(fp4_codes) + + return { + "weight.packed": packed, + "weight.scale_e8m0": e8m0_scales, + } + + +def unpack_mxfp4( + packed: torch.Tensor, + scale_e8m0: torch.Tensor, + group_size: int = 32, + target_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """ + Unpack MXFP4 to BF16/FP32 for inference. 
+ + Args: + packed: [out_features, in_features // 2] uint8 packed FP4 + scale_e8m0: [out_features, num_groups] uint8 E8M0 scales + group_size: MXFP4 group size + target_dtype: Output dtype + + Returns: + [out_features, in_features] tensor in target_dtype + """ + # Step 1: Unpack uint8 -> FP4 codes + fp4_codes = unpack_uint8_to_fp4(packed) + + # Step 2: Convert FP4 codes to float using lookup table + table = FP4_E2M1_TABLE.to(fp4_codes.device) + fp4_codes_long = fp4_codes.to(torch.long) + values = table[fp4_codes_long] + + # Step 3: Apply E8M0 scales + in_features = values.shape[-1] + num_groups = scale_e8m0.shape[-1] + + scales = e8m0_to_float(scale_e8m0).to(values.device) # [..., num_groups] + + # Reshape values to [..., num_groups, group_size] + # Trim or pad to match num_groups * group_size + expected_len = num_groups * group_size + if in_features > expected_len: + values = values[..., :expected_len] + elif in_features < expected_len: + values = torch.nn.functional.pad(values, (0, expected_len - in_features)) + + grouped = values.reshape(*values.shape[:-1], num_groups, group_size) + grouped = grouped * scales.unsqueeze(-1) + + # Reshape back + result = grouped.reshape(*values.shape[:-1], num_groups * group_size) + + # Trim to original in_features + if result.shape[-1] > in_features: + result = result[..., :in_features] + + return result.to(target_dtype) diff --git a/tests/test_mxfp4_pack.py b/tests/test_mxfp4_pack.py new file mode 100644 index 0000000..476bd50 --- /dev/null +++ b/tests/test_mxfp4_pack.py @@ -0,0 +1,186 @@ +"""Tests for MXFP4 packing/unpacking utilities.""" + +from __future__ import annotations + +import torch +import pytest + +from quanto.utils.mxfp4_pack import ( + FP4_E2M1_TABLE, + FP4_MAX, + compute_e8m0_scales, + e8m0_to_float, + pack_fp4_to_uint8, + pack_mxfp4, + quantize_to_fp4, + unpack_mxfp4, + unpack_uint8_to_fp4, +) + + +class TestFP4Table: + """Test FP4 E2M1 encoding table.""" + + def test_table_size(self): + assert len(FP4_E2M1_TABLE) == 16 + + def test_positive_values(self): + expected = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] + for i, v in enumerate(expected): + assert FP4_E2M1_TABLE[i].item() == v + + def test_negative_values(self): + expected = [0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0] + for i, v in enumerate(expected): + assert FP4_E2M1_TABLE[i + 8].item() == pytest.approx(v) + + def test_max_value(self): + assert FP4_MAX == 6.0 + + +class TestE8M0Scales: + """Test E8M0 scale computation.""" + + def test_unit_range(self): + """Values in [-6, 6] should have scale ~ 1.0.""" + weight = torch.tensor([[1.0, 2.0, 3.0, -1.0] * 8]) # max=3.0 + scales = compute_e8m0_scales(weight, group_size=32) + scale_float = e8m0_to_float(scales) + # scale = 2^floor(log2(3.0 / 6.0)) = 2^floor(-1) = 2^(-1) = 0.5 + assert scale_float[0, 0].item() == pytest.approx(0.5) + + def test_large_values(self): + """Large values should produce larger scales.""" + weight = torch.tensor([[100.0] * 32]) + scales = compute_e8m0_scales(weight, group_size=32) + scale_float = e8m0_to_float(scales) + # scale = 2^floor(log2(100/6)) = 2^floor(4.06) = 2^4 = 16 + assert scale_float[0, 0].item() == pytest.approx(16.0) + + def test_small_values(self): + """Small values should produce smaller scales.""" + weight = torch.tensor([[0.01] * 32]) + scales = compute_e8m0_scales(weight, group_size=32) + scale_float = e8m0_to_float(scales) + # Very small scale + assert scale_float[0, 0].item() < 0.01 + + def test_multiple_groups(self): + """Multiple groups should have independent scales.""" + weight = 
torch.tensor([[1.0] * 32 + [100.0] * 32]) + scales = compute_e8m0_scales(weight, group_size=32) + assert scales.shape == (1, 2) + scale_float = e8m0_to_float(scales) + assert scale_float[0, 1] > scale_float[0, 0] + + +class TestFP4Packing: + """Test FP4 uint8 packing/unpacking.""" + + def test_pack_unpack_roundtrip(self): + """Pack then unpack should recover original codes.""" + codes = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], dtype=torch.uint8) + packed = pack_fp4_to_uint8(codes) + assert packed.shape == (1, 8) + unpacked = unpack_uint8_to_fp4(packed) + assert unpacked.shape == (1, 16) + assert torch.equal(unpacked, codes) + + def test_pack_shape(self): + """Packed tensor should be half the size.""" + codes = torch.randint(0, 16, (4, 64), dtype=torch.uint8) + packed = pack_fp4_to_uint8(codes) + assert packed.shape == (4, 32) + + def test_pack_values(self): + """Check specific packing: byte = low_nibble | (high_nibble << 4).""" + codes = torch.tensor([[3, 12]], dtype=torch.uint8) # 3 and 12 + packed = pack_fp4_to_uint8(codes) + expected = (3) | (12 << 4) # 0x03 | 0xC0 = 0xC3 + assert packed[0, 0].item() == expected + + +class TestMXFP4FullPipeline: + """Test full pack/unpack pipeline.""" + + def test_roundtrip_small(self): + """Pack then unpack a small tensor.""" + weight = torch.randn(4, 64, dtype=torch.bfloat16) + result = pack_mxfp4(weight, group_size=32) + + assert "weight.packed" in result + assert "weight.scale_e8m0" in result + assert result["weight.packed"].dtype == torch.uint8 + assert result["weight.scale_e8m0"].dtype == torch.uint8 + assert result["weight.packed"].shape == (4, 32) # 64/2 + assert result["weight.scale_e8m0"].shape == (4, 2) # 64/32 + + def test_roundtrip_accuracy(self): + """Unpacked values should be close to original (within FP4 precision).""" + weight = torch.randn(8, 128, dtype=torch.bfloat16) * 2.0 + result = pack_mxfp4(weight, group_size=32) + + recovered = unpack_mxfp4( + result["weight.packed"], + result["weight.scale_e8m0"], + group_size=32, + ) + + # FP4 has only 16 values, so error can be significant + # But relative error should be bounded + rel_error = (weight.float() - recovered.float()).abs() / (weight.float().abs() + 1e-8) + mean_rel_error = rel_error.mean() + assert mean_rel_error < 0.5, f"Mean relative error too high: {mean_rel_error}" + + def test_compression_ratio(self): + """Packed format should achieve ~3.76x compression.""" + weight = torch.randn(1024, 4096, dtype=torch.bfloat16) + result = pack_mxfp4(weight, group_size=32) + + original_bytes = weight.numel() * 2 # BF16 = 2 bytes + packed_bytes = result["weight.packed"].numel() * 1 # uint8 = 1 byte + scale_bytes = result["weight.scale_e8m0"].numel() * 1 # uint8 = 1 byte + total_packed = packed_bytes + scale_bytes + + ratio = original_bytes / total_packed + # Expected: ~3.76x (64 bytes BF16 vs 17 bytes MXFP4 per 32 elements) + assert ratio > 3.5, f"Compression ratio too low: {ratio:.2f}x" + assert ratio < 4.0, f"Compression ratio too high: {ratio:.2f}x" + + def test_zero_tensor(self): + """Zero tensor should pack/unpack correctly.""" + weight = torch.zeros(4, 64, dtype=torch.bfloat16) + result = pack_mxfp4(weight, group_size=32) + recovered = unpack_mxfp4(result["weight.packed"], result["weight.scale_e8m0"]) + assert torch.allclose(recovered, weight.float(), atol=1e-6) + + def test_large_matrix(self): + """Test with sizes typical of LLM weight matrices.""" + weight = torch.randn(4096, 4096, dtype=torch.bfloat16) + result = pack_mxfp4(weight, group_size=32) + 
assert result["weight.packed"].shape == (4096, 2048) + assert result["weight.scale_e8m0"].shape == (4096, 128) + + +class TestConfigIntegration: + """Test that config accepts new algorithm field.""" + + def test_rtn_default(self): + from quanto.core.config import UnifiedConfig + config = UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out") + assert config.algorithm == "rtn" + + def test_awq_algorithm(self): + from quanto.core.config import UnifiedConfig + config = UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out", algorithm="awq") + assert config.algorithm == "awq" + + def test_gptq_algorithm(self): + from quanto.core.config import UnifiedConfig + config = UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out", algorithm="gptq") + assert config.algorithm == "gptq" + + def test_invalid_algorithm(self): + from quanto.core.config import UnifiedConfig + with pytest.raises(ValueError, match="Invalid algorithm"): + UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out", algorithm="invalid") From 0ec8bfa89685777b4c50ce05758556040ce8a202 Mon Sep 17 00:00:00 2001 From: Jack Han Date: Tue, 21 Apr 2026 16:16:57 +0900 Subject: [PATCH 2/8] Update .gitmodules for reorganized submodule paths --- .gitmodules | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.gitmodules b/.gitmodules index 025fa65..45661ca 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,12 @@ [submodule "contribs/quark"] path = contribs/quark url = https://github.com/amd/Quark.git +[submodule "contribs/llm-compressor"] + path = contribs/llm-compressor + url = https://github.com/vllm-project/llm-compressor.git +[submodule "contribs/transformers"] + path = contribs/transformers + url = https://github.com/huggingface/transformers.git +[submodule "contribs/vllm"] + path = contribs/vllm + url = https://github.com/vllm-project/vllm.git From bdc53a1913d972951656c14c94712c0ff1e23c5d Mon Sep 17 00:00:00 2001 From: Jack Han Date: Tue, 21 Apr 2026 19:27:41 +0900 Subject: [PATCH 3/8] Refactor MXFP4 to use Quark file2file quantization MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace custom mxfp4_pack.py with Quark's quantize_model_per_safetensor for MXFP4 quantization. This produces properly packed uint8 weights that vLLM loads natively as a Quark-quantized checkpoint. - Add _run_file2file_quantization() using Quark's file2file path - Route MXFP precision to file2file in run() dispatch - Remove custom mxfp4_pack.py and pack_mxfp4 config option - Resolve HF hub IDs to local paths for file2file compatibility Solar-Open-100B: 192GB → 53GB (3.62x), 73s quantization time. Matches AMD's official MXFP4 model format (Kimi-K2.5-MXFP4). 
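
For reviewers, a minimal driver sketch of the new path (illustrative only, not part of this patch): it assumes `UnifiedQuantizer` is constructed directly from a `UnifiedConfig`, as the CLI does, and uses placeholder paths. Any `mxfp*` precision is dispatched by `run()` to `_run_file2file_quantization()`.

```python
# Illustrative sketch; paths are placeholders and the UnifiedQuantizer(config)
# constructor is assumed from the CLI wiring described in CLAUDE.md.
from quanto.core.config import UnifiedConfig
from quanto.core.unified_quantizer import UnifiedQuantizer

config = UnifiedConfig(
    model_path="/models/Solar-Open-100B",  # local dir or HF hub id (hub ids are resolved via snapshot_download)
    output_dir="./solar-open-mxfp4",
    precision="mxfp4",                     # any "mxfp*" precision routes to the file2file path
    algorithm="rtn",                       # "awq" / "gptq" are the other values accepted by validate()
)
result = UnifiedQuantizer(config).run()
if not result.success:
    raise RuntimeError(result.error_message)
```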
--- src/quanto/core/config.py | 4 - src/quanto/core/unified_quantizer.py | 170 ++++++--------- src/quanto/utils/__init__.py | 4 - src/quanto/utils/mxfp4_pack.py | 304 --------------------------- tests/test_mxfp4_pack.py | 186 ---------------- 5 files changed, 70 insertions(+), 598 deletions(-) delete mode 100644 src/quanto/utils/mxfp4_pack.py delete mode 100644 tests/test_mxfp4_pack.py diff --git a/src/quanto/core/config.py b/src/quanto/core/config.py index db9c506..b59dd70 100644 --- a/src/quanto/core/config.py +++ b/src/quanto/core/config.py @@ -94,10 +94,6 @@ class UnifiedConfig: # Layer batch size for lazy mode (number of layers to process in parallel) layer_batch_size: int = 4 - # Pack MXFP4 weights to compressed format (FP4 + E8M0 scales) - # Set False to keep BF16-stored weights for evaluation with lm-eval - pack_mxfp4: bool = True - # Quantization algorithm: "rtn" (round-to-nearest, default), "awq", "gptq" algorithm: str = "rtn" diff --git a/src/quanto/core/unified_quantizer.py b/src/quanto/core/unified_quantizer.py index 5df907f..d671c7a 100644 --- a/src/quanto/core/unified_quantizer.py +++ b/src/quanto/core/unified_quantizer.py @@ -861,10 +861,6 @@ def _run_lazy_quantization(self) -> QuantizationResult: self._log("\n=== Assembling HuggingFace format ===") self._assemble_hf_format() - # Pack MXFP4 weights for actual compression - if self.config.precision.startswith("mxfp") and self.config.pack_mxfp4: - self._log("Packing MXFP4 weights...") - self._pack_mxfp4_weights(exclude_layers) self.timing["total"] = time.time() - total_start @@ -1173,10 +1169,6 @@ def _run_full_gpu_quantization(self) -> QuantizationResult: self.tokenizer.save_pretrained(self.config.output_dir) - # Pack MXFP4 weights for actual compression - if self.config.precision.startswith("mxfp") and self.config.pack_mxfp4: - self._log("Packing MXFP4 weights...") - self._pack_mxfp4_weights(exclude_layers) self.timing["total"] = time.time() - total_start @@ -1198,111 +1190,86 @@ def _run_full_gpu_quantization(self) -> QuantizationResult: return result - def _pack_mxfp4_weights(self, exclude_layers: list[str]) -> None: + def _run_file2file_quantization(self) -> QuantizationResult: """ - Post-process exported safetensors to pack MXFP4 weights. + Run file-to-file quantization using Quark's quantize_model_per_safetensor. - Replaces BF16 dequantized weights with packed FP4 + E8M0 scale format - for actual compression. Only processes quantized Linear weights - (those NOT in the exclude list). + Processes each safetensors shard independently without loading the full model + into memory. Produces properly packed uint8 weights with E8M0 scales that + vLLM can load natively as a Quark-quantized checkpoint. + + This is the recommended path for MXFP4/MXFP6 quantization, matching how AMD + publishes official MXFP4 models (e.g., Kimi-K2.5-MXFP4). 
""" - from ..utils.mxfp4_pack import pack_mxfp4 + from quark.torch.quantization.file2file_quantization import quantize_model_per_safetensor - output_dir = Path(self.config.output_dir) - index_file = output_dir / "model.safetensors.index.json" + total_start = time.time() + result = QuantizationResult(success=False) - if not index_file.exists(): - self._log("Warning: No safetensors index found, skipping MXFP4 packing") - return + try: + # Setup (load config, detect model type, get template) + self._setup() - with open(index_file) as f: - index = json.load(f) + # Determine exclusions (including sensitivity analysis if enabled) + exclude_layers = self._determine_exclude_layers() + result.exclude_layers_used = exclude_layers + self._log(f"Exclude layers: {exclude_layers}") - weight_map = index.get("weight_map", {}) - group_size = 32 # MXFP4 default group size + # Create quantization config + quant_config = self._create_quant_config(exclude_layers) - # Identify quantizable weight keys (2D weights not in exclude list) - import fnmatch + self._log(f"\n{'=' * 60}") + self._log("FILE-TO-FILE QUANTIZATION") + self._log(f"{'=' * 60}") + self._log(f"Model: {self.config.model_path}") + self._log(f"Output: {self.config.output_dir}") + self._log(f"Precision: {self.config.precision}") + self._log(f"Device: {self.config.device}") + self._log(f"{'=' * 60}") - def is_excluded(name: str) -> bool: - for pattern in exclude_layers: - if fnmatch.fnmatch(name, pattern) or pattern in name: - return True - return False + # Resolve model path to local directory + # file2file requires a local path with safetensors files, not a HF hub ID + model_path = self.config.model_path + if not os.path.isdir(model_path): + from huggingface_hub import snapshot_download - # Group weights by shard file - shard_files = set(weight_map.values()) - total_packed = 0 + self._log(f"Downloading model from HuggingFace: {model_path}") + model_path = snapshot_download(model_path) + self._log(f"Model downloaded to: {model_path}") - for shard_name in sorted(shard_files): - shard_path = output_dir / shard_name - if not shard_path.exists(): - continue + # Run file-to-file quantization + self._log("Running file-to-file quantization...") + quant_start = time.time() - # Load all tensors from this shard - tensors = {} - with safe_open(str(shard_path), framework="pt", device="cpu") as f: - for key in f.keys(): - tensors[key] = f.get_tensor(key) - - # Pack eligible weights - new_tensors = {} - new_weight_map_updates = {} - - for key, tensor in tensors.items(): - # Check if this is a quantizable Linear weight - if ( - key.endswith(".weight") - and len(tensor.shape) == 2 - and not is_excluded(key) - and tensor.shape[-1] % 2 == 0 # Must be even for packing - ): - # Strip .weight suffix to get layer name - layer_name = key[: -len(".weight")] - - # Check if this layer was actually quantized (not in exclude) - layer_excluded = is_excluded(layer_name) - if not layer_excluded: - packed = pack_mxfp4(tensor, group_size=group_size) - packed_key = f"{layer_name}.weight.packed" - scale_key = f"{layer_name}.weight.scale_e8m0" - - new_tensors[packed_key] = packed["weight.packed"] - new_tensors[scale_key] = packed["weight.scale_e8m0"] - new_weight_map_updates[packed_key] = shard_name - new_weight_map_updates[scale_key] = shard_name - total_packed += 1 - continue - - # Keep non-quantized tensors as-is - new_tensors[key] = tensor - - # Re-save shard with packed weights - save_file(new_tensors, str(shard_path)) - - # Update weight map - for old_key in 
list(weight_map.keys()): - if weight_map[old_key] == shard_name and old_key not in new_tensors: - del weight_map[old_key] - weight_map.update(new_weight_map_updates) - - # Save updated index - index["weight_map"] = weight_map - with open(index_file, "w") as f: - json.dump(index, f, indent=2) + quantize_model_per_safetensor( + pretrained_model_path=model_path, + quant_config=quant_config, + save_path=self.config.output_dir, + device=self.config.device, + ) + + self.timing["quantization"] = time.time() - quant_start + self._log(f"File-to-file quantization completed in {self.timing['quantization']:.2f}s") + + self.timing["total"] = time.time() - total_start + + result.success = True + result.output_dir = self.config.output_dir + result.model_type = self.model_type + result.quant_scheme = self._get_quant_scheme() + result.precision = self.config.precision + result.timing = self.timing - # Update config.json with packed format info - config_file = output_dir / "config.json" - if config_file.exists(): - with open(config_file) as f: - config = json.load(f) - if "quantization_config" in config: - config["quantization_config"]["weight_format"] = "packed_mxfp4" - config["quantization_config"]["mxfp4_group_size"] = group_size - with open(config_file, "w") as f: - json.dump(config, f, indent=2) + self._print_summary(result) + + except Exception as e: + result.success = False + result.error_message = str(e) + self._log(f"Error during quantization: {e}") + import traceback + traceback.print_exc() - self._log(f"Packed {total_packed} weight tensors to MXFP4 format") + return result def _print_summary(self, result: QuantizationResult) -> None: """Print quantization summary.""" @@ -1345,7 +1312,10 @@ def run(self) -> QuantizationResult: strategy = self.config.memory_strategy # Dispatch to appropriate strategy - if strategy == "lazy": + # Use file-to-file for MXFP precisions (produces vLLM-compatible packed uint8) + if self.config.precision.startswith("mxfp"): + return self._run_file2file_quantization() + elif strategy == "lazy": return self._run_lazy_quantization() elif strategy == "layerwise_cpu": return self._run_layerwise_cpu_quantization() diff --git a/src/quanto/utils/__init__.py b/src/quanto/utils/__init__.py index 74e59d9..9f6a31f 100644 --- a/src/quanto/utils/__init__.py +++ b/src/quanto/utils/__init__.py @@ -13,7 +13,6 @@ quantize_to_int4, unpack_int32_to_int4, ) -from .mxfp4_pack import pack_mxfp4, unpack_mxfp4 from .logging import Timer, get_logger, log_with_timestamp from .memory import ( clear_gpu_memory, @@ -42,9 +41,6 @@ "pack_int4_to_int32", "unpack_int32_to_int4", "pack_layer_weights", - # MXFP4 Packing - "pack_mxfp4", - "unpack_mxfp4", # Logging "get_logger", "log_with_timestamp", diff --git a/src/quanto/utils/mxfp4_pack.py b/src/quanto/utils/mxfp4_pack.py deleted file mode 100644 index c332864..0000000 --- a/src/quanto/utils/mxfp4_pack.py +++ /dev/null @@ -1,304 +0,0 @@ -""" -MXFP4 Weight Packing Utilities - -Packs BF16 weights into OCP MX FP4 format with E8M0 shared scales. -This reduces model size by ~3.76x compared to BF16. 
- -MXFP4 format (OCP MX Specification): -- FP4 (E2M1): 1 sign + 1 exponent + 2 mantissa = 4 bits per value -- E8M0 scale: 8-bit shared exponent per group of 32 elements -- 2 FP4 values packed per uint8 byte - -FP4 E2M1 representable values (magnitude): - 0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0 -""" - -from __future__ import annotations - -import math - -import torch - - -# FP4 E2M1 encoding: 4-bit code -> float value -# Bit layout: [sign(1)] [exponent(1)] [mantissa(2)] -# Code 0b_s_e_mm -FP4_E2M1_TABLE = torch.tensor([ - 0.0, # 0b0_0_00 = 0 - 0.5, # 0b0_0_01 = 0.5 - 1.0, # 0b0_0_10 = 1.0 - 1.5, # 0b0_0_11 = 1.5 - 2.0, # 0b0_1_00 = 2.0 - 3.0, # 0b0_1_01 = 3.0 - 4.0, # 0b0_1_10 = 4.0 - 6.0, # 0b0_1_11 = 6.0 - -0.0, # 0b1_0_00 = -0 (treated as 0) - -0.5, # 0b1_0_01 = -0.5 - -1.0, # 0b1_0_10 = -1.0 - -1.5, # 0b1_0_11 = -1.5 - -2.0, # 0b1_1_00 = -2.0 - -3.0, # 0b1_1_01 = -3.0 - -4.0, # 0b1_1_10 = -4.0 - -6.0, # 0b1_1_11 = -6.0 -], dtype=torch.float32) - -# Max representable magnitude in FP4 E2M1 -FP4_MAX = 6.0 - - -def compute_e8m0_scales(weight: torch.Tensor, group_size: int = 32) -> torch.Tensor: - """ - Compute E8M0 shared scales for MXFP4 quantization. - - E8M0 format: 8-bit exponent only (no mantissa), representing power-of-2 scales. - scale = 2^(e8m0_code - 127) (IEEE 754 bias) - - The scale is chosen so that max(abs(group)) / scale <= FP4_MAX (6.0). - - Args: - weight: [..., in_features] tensor to compute scales for - group_size: Number of elements per group (default 32) - - Returns: - E8M0 scale codes as uint8 tensor [..., num_groups] - """ - orig_shape = weight.shape - in_features = orig_shape[-1] - - # Pad if not divisible by group_size - if in_features % group_size != 0: - pad_size = group_size - (in_features % group_size) - weight = torch.nn.functional.pad(weight, (0, pad_size)) - in_features = weight.shape[-1] - - num_groups = in_features // group_size - - # Reshape to [..., num_groups, group_size] - grouped = weight.reshape(*orig_shape[:-1], num_groups, group_size) - - # Max absolute value per group - group_max = grouped.abs().amax(dim=-1).float() - - # Avoid log2(0) — use a small floor - group_max = group_max.clamp(min=1e-12) - - # E8M0 exponent: floor(log2(group_max / FP4_MAX)) + 127 (IEEE bias) - # scale = 2^(code - 127), so code = floor(log2(group_max / FP4_MAX)) + 127 - exponent = torch.floor(torch.log2(group_max / FP4_MAX)).to(torch.int32) + 127 - - # Clamp to valid E8M0 range [0, 254] (255 = NaN/Inf in E8M0) - exponent = exponent.clamp(0, 254) - - return exponent.to(torch.uint8) - - -def e8m0_to_float(e8m0_codes: torch.Tensor) -> torch.Tensor: - """Convert E8M0 scale codes to float scale values. - - Args: - e8m0_codes: uint8 tensor of E8M0 exponent codes - - Returns: - Float tensor of scale values: 2^(code - 127) - """ - return torch.pow(2.0, e8m0_codes.to(torch.float32) - 127.0) - - -def quantize_to_fp4( - weight: torch.Tensor, - e8m0_scales: torch.Tensor, - group_size: int = 32, -) -> torch.Tensor: - """ - Quantize BF16 weight to FP4 E2M1 codes using E8M0 scales. 
- - Args: - weight: [..., in_features] BF16/FP32 tensor - e8m0_scales: [..., num_groups] uint8 E8M0 scale codes - group_size: Elements per group - - Returns: - [..., in_features] uint8 tensor with FP4 codes (values 0-15) - """ - orig_shape = weight.shape - in_features = orig_shape[-1] - - # Pad if needed - if in_features % group_size != 0: - pad_size = group_size - (in_features % group_size) - weight = torch.nn.functional.pad(weight, (0, pad_size)) - in_features = weight.shape[-1] - - num_groups = in_features // group_size - - # Convert scales to float - scales = e8m0_to_float(e8m0_scales) # [..., num_groups] - - # Reshape weight to [..., num_groups, group_size] - grouped = weight.reshape(*orig_shape[:-1], num_groups, group_size).float() - - # Divide by scale: normalized = weight / scale - # scales shape [..., num_groups] -> [..., num_groups, 1] - normalized = grouped / scales.unsqueeze(-1) - - # Round to nearest FP4 value using the lookup table - # FP4 positive magnitudes: [0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] - sign = normalized.sign() - magnitude = normalized.abs() - - # Magnitude-only FP4 table for nearest-value lookup - fp4_magnitudes = torch.tensor( - [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], - dtype=torch.float32, - device=weight.device, - ) - - # Find nearest FP4 magnitude for each value - # Compute distance to each FP4 magnitude - diffs = (magnitude.unsqueeze(-1) - fp4_magnitudes).abs() - mag_codes = diffs.argmin(dim=-1) # index 0-7 - - # Combine sign and magnitude into 4-bit code - # sign bit = 0 for positive, 1 for negative (bit 3) - sign_bit = (sign < 0).to(torch.uint8) << 3 - fp4_codes = sign_bit | mag_codes.to(torch.uint8) - - # Reshape back to [..., in_features] - fp4_codes = fp4_codes.reshape(*orig_shape[:-1], in_features) - - # Trim padding - if in_features != orig_shape[-1]: - fp4_codes = fp4_codes[..., : orig_shape[-1]] - - return fp4_codes - - -def pack_fp4_to_uint8(fp4_codes: torch.Tensor) -> torch.Tensor: - """ - Pack two FP4 (4-bit) values into one uint8 byte. - - Packing order: low nibble first, high nibble second. - byte = fp4_codes[2i] | (fp4_codes[2i+1] << 4) - - Args: - fp4_codes: [..., N] uint8 tensor with FP4 codes (0-15) - - Returns: - [..., N//2] uint8 tensor with packed FP4 pairs - """ - *batch, n = fp4_codes.shape - - # Pad if odd number of elements - if n % 2 != 0: - fp4_codes = torch.nn.functional.pad(fp4_codes, (0, 1)) - n = fp4_codes.shape[-1] - - # Reshape to [..., N//2, 2] - pairs = fp4_codes.reshape(*batch, n // 2, 2) - - # Pack: low nibble = first value, high nibble = second value - packed = (pairs[..., 0] & 0x0F) | ((pairs[..., 1] & 0x0F) << 4) - - return packed.to(torch.uint8) - - -def unpack_uint8_to_fp4(packed: torch.Tensor) -> torch.Tensor: - """ - Unpack uint8 bytes to two FP4 codes each. - - Args: - packed: [..., N] uint8 tensor - - Returns: - [..., N*2] uint8 tensor with FP4 codes (0-15) - """ - low = packed & 0x0F - high = (packed >> 4) & 0x0F - - # Interleave: [low0, high0, low1, high1, ...] - unpacked = torch.stack([low, high], dim=-1) - return unpacked.reshape(*packed.shape[:-1], packed.shape[-1] * 2) - - -def pack_mxfp4( - weight: torch.Tensor, - group_size: int = 32, -) -> dict[str, torch.Tensor]: - """ - Full pipeline: BF16 weight -> packed MXFP4. 
- - Args: - weight: [out_features, in_features] BF16 tensor - group_size: MXFP4 group size (default 32) - - Returns: - Dict with: - - "weight.packed": [out_features, in_features // 2] uint8 (packed FP4 pairs) - - "weight.scale_e8m0": [out_features, in_features // group_size] uint8 - """ - # Step 1: Compute E8M0 scales - e8m0_scales = compute_e8m0_scales(weight, group_size) - - # Step 2: Quantize to FP4 codes - fp4_codes = quantize_to_fp4(weight, e8m0_scales, group_size) - - # Step 3: Pack FP4 codes into uint8 - packed = pack_fp4_to_uint8(fp4_codes) - - return { - "weight.packed": packed, - "weight.scale_e8m0": e8m0_scales, - } - - -def unpack_mxfp4( - packed: torch.Tensor, - scale_e8m0: torch.Tensor, - group_size: int = 32, - target_dtype: torch.dtype = torch.bfloat16, -) -> torch.Tensor: - """ - Unpack MXFP4 to BF16/FP32 for inference. - - Args: - packed: [out_features, in_features // 2] uint8 packed FP4 - scale_e8m0: [out_features, num_groups] uint8 E8M0 scales - group_size: MXFP4 group size - target_dtype: Output dtype - - Returns: - [out_features, in_features] tensor in target_dtype - """ - # Step 1: Unpack uint8 -> FP4 codes - fp4_codes = unpack_uint8_to_fp4(packed) - - # Step 2: Convert FP4 codes to float using lookup table - table = FP4_E2M1_TABLE.to(fp4_codes.device) - fp4_codes_long = fp4_codes.to(torch.long) - values = table[fp4_codes_long] - - # Step 3: Apply E8M0 scales - in_features = values.shape[-1] - num_groups = scale_e8m0.shape[-1] - - scales = e8m0_to_float(scale_e8m0).to(values.device) # [..., num_groups] - - # Reshape values to [..., num_groups, group_size] - # Trim or pad to match num_groups * group_size - expected_len = num_groups * group_size - if in_features > expected_len: - values = values[..., :expected_len] - elif in_features < expected_len: - values = torch.nn.functional.pad(values, (0, expected_len - in_features)) - - grouped = values.reshape(*values.shape[:-1], num_groups, group_size) - grouped = grouped * scales.unsqueeze(-1) - - # Reshape back - result = grouped.reshape(*values.shape[:-1], num_groups * group_size) - - # Trim to original in_features - if result.shape[-1] > in_features: - result = result[..., :in_features] - - return result.to(target_dtype) diff --git a/tests/test_mxfp4_pack.py b/tests/test_mxfp4_pack.py deleted file mode 100644 index 476bd50..0000000 --- a/tests/test_mxfp4_pack.py +++ /dev/null @@ -1,186 +0,0 @@ -"""Tests for MXFP4 packing/unpacking utilities.""" - -from __future__ import annotations - -import torch -import pytest - -from quanto.utils.mxfp4_pack import ( - FP4_E2M1_TABLE, - FP4_MAX, - compute_e8m0_scales, - e8m0_to_float, - pack_fp4_to_uint8, - pack_mxfp4, - quantize_to_fp4, - unpack_mxfp4, - unpack_uint8_to_fp4, -) - - -class TestFP4Table: - """Test FP4 E2M1 encoding table.""" - - def test_table_size(self): - assert len(FP4_E2M1_TABLE) == 16 - - def test_positive_values(self): - expected = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0] - for i, v in enumerate(expected): - assert FP4_E2M1_TABLE[i].item() == v - - def test_negative_values(self): - expected = [0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0] - for i, v in enumerate(expected): - assert FP4_E2M1_TABLE[i + 8].item() == pytest.approx(v) - - def test_max_value(self): - assert FP4_MAX == 6.0 - - -class TestE8M0Scales: - """Test E8M0 scale computation.""" - - def test_unit_range(self): - """Values in [-6, 6] should have scale ~ 1.0.""" - weight = torch.tensor([[1.0, 2.0, 3.0, -1.0] * 8]) # max=3.0 - scales = compute_e8m0_scales(weight, group_size=32) - scale_float = 
e8m0_to_float(scales) - # scale = 2^floor(log2(3.0 / 6.0)) = 2^floor(-1) = 2^(-1) = 0.5 - assert scale_float[0, 0].item() == pytest.approx(0.5) - - def test_large_values(self): - """Large values should produce larger scales.""" - weight = torch.tensor([[100.0] * 32]) - scales = compute_e8m0_scales(weight, group_size=32) - scale_float = e8m0_to_float(scales) - # scale = 2^floor(log2(100/6)) = 2^floor(4.06) = 2^4 = 16 - assert scale_float[0, 0].item() == pytest.approx(16.0) - - def test_small_values(self): - """Small values should produce smaller scales.""" - weight = torch.tensor([[0.01] * 32]) - scales = compute_e8m0_scales(weight, group_size=32) - scale_float = e8m0_to_float(scales) - # Very small scale - assert scale_float[0, 0].item() < 0.01 - - def test_multiple_groups(self): - """Multiple groups should have independent scales.""" - weight = torch.tensor([[1.0] * 32 + [100.0] * 32]) - scales = compute_e8m0_scales(weight, group_size=32) - assert scales.shape == (1, 2) - scale_float = e8m0_to_float(scales) - assert scale_float[0, 1] > scale_float[0, 0] - - -class TestFP4Packing: - """Test FP4 uint8 packing/unpacking.""" - - def test_pack_unpack_roundtrip(self): - """Pack then unpack should recover original codes.""" - codes = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]], dtype=torch.uint8) - packed = pack_fp4_to_uint8(codes) - assert packed.shape == (1, 8) - unpacked = unpack_uint8_to_fp4(packed) - assert unpacked.shape == (1, 16) - assert torch.equal(unpacked, codes) - - def test_pack_shape(self): - """Packed tensor should be half the size.""" - codes = torch.randint(0, 16, (4, 64), dtype=torch.uint8) - packed = pack_fp4_to_uint8(codes) - assert packed.shape == (4, 32) - - def test_pack_values(self): - """Check specific packing: byte = low_nibble | (high_nibble << 4).""" - codes = torch.tensor([[3, 12]], dtype=torch.uint8) # 3 and 12 - packed = pack_fp4_to_uint8(codes) - expected = (3) | (12 << 4) # 0x03 | 0xC0 = 0xC3 - assert packed[0, 0].item() == expected - - -class TestMXFP4FullPipeline: - """Test full pack/unpack pipeline.""" - - def test_roundtrip_small(self): - """Pack then unpack a small tensor.""" - weight = torch.randn(4, 64, dtype=torch.bfloat16) - result = pack_mxfp4(weight, group_size=32) - - assert "weight.packed" in result - assert "weight.scale_e8m0" in result - assert result["weight.packed"].dtype == torch.uint8 - assert result["weight.scale_e8m0"].dtype == torch.uint8 - assert result["weight.packed"].shape == (4, 32) # 64/2 - assert result["weight.scale_e8m0"].shape == (4, 2) # 64/32 - - def test_roundtrip_accuracy(self): - """Unpacked values should be close to original (within FP4 precision).""" - weight = torch.randn(8, 128, dtype=torch.bfloat16) * 2.0 - result = pack_mxfp4(weight, group_size=32) - - recovered = unpack_mxfp4( - result["weight.packed"], - result["weight.scale_e8m0"], - group_size=32, - ) - - # FP4 has only 16 values, so error can be significant - # But relative error should be bounded - rel_error = (weight.float() - recovered.float()).abs() / (weight.float().abs() + 1e-8) - mean_rel_error = rel_error.mean() - assert mean_rel_error < 0.5, f"Mean relative error too high: {mean_rel_error}" - - def test_compression_ratio(self): - """Packed format should achieve ~3.76x compression.""" - weight = torch.randn(1024, 4096, dtype=torch.bfloat16) - result = pack_mxfp4(weight, group_size=32) - - original_bytes = weight.numel() * 2 # BF16 = 2 bytes - packed_bytes = result["weight.packed"].numel() * 1 # uint8 = 1 byte - scale_bytes = 
result["weight.scale_e8m0"].numel() * 1 # uint8 = 1 byte - total_packed = packed_bytes + scale_bytes - - ratio = original_bytes / total_packed - # Expected: ~3.76x (64 bytes BF16 vs 17 bytes MXFP4 per 32 elements) - assert ratio > 3.5, f"Compression ratio too low: {ratio:.2f}x" - assert ratio < 4.0, f"Compression ratio too high: {ratio:.2f}x" - - def test_zero_tensor(self): - """Zero tensor should pack/unpack correctly.""" - weight = torch.zeros(4, 64, dtype=torch.bfloat16) - result = pack_mxfp4(weight, group_size=32) - recovered = unpack_mxfp4(result["weight.packed"], result["weight.scale_e8m0"]) - assert torch.allclose(recovered, weight.float(), atol=1e-6) - - def test_large_matrix(self): - """Test with sizes typical of LLM weight matrices.""" - weight = torch.randn(4096, 4096, dtype=torch.bfloat16) - result = pack_mxfp4(weight, group_size=32) - assert result["weight.packed"].shape == (4096, 2048) - assert result["weight.scale_e8m0"].shape == (4096, 128) - - -class TestConfigIntegration: - """Test that config accepts new algorithm field.""" - - def test_rtn_default(self): - from quanto.core.config import UnifiedConfig - config = UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out") - assert config.algorithm == "rtn" - - def test_awq_algorithm(self): - from quanto.core.config import UnifiedConfig - config = UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out", algorithm="awq") - assert config.algorithm == "awq" - - def test_gptq_algorithm(self): - from quanto.core.config import UnifiedConfig - config = UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out", algorithm="gptq") - assert config.algorithm == "gptq" - - def test_invalid_algorithm(self): - from quanto.core.config import UnifiedConfig - with pytest.raises(ValueError, match="Invalid algorithm"): - UnifiedConfig(model_path="/tmp/test", output_dir="/tmp/out", algorithm="invalid") From af77605461157359ed0becfccc2de01c5d463a0c Mon Sep 17 00:00:00 2001 From: Jack Han Date: Tue, 21 Apr 2026 23:06:01 +0900 Subject: [PATCH 4/8] Add vLLM fused layer alignment for exclude list compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit vLLM fuses certain projections into single linear layers (qkv_proj, gate_up_proj), requiring all members to share the same quantization scheme. Add _align_exclude_groups() to ensure that if any projection in a fused group is excluded, the entire group is excluded together. Fused groups handled: - self_attn: q_proj + k_proj + v_proj - mlp: gate_proj + up_proj - mlp.shared_experts: gate_proj + up_proj Solar-Open-100B: 16 → 32 excluded layers after alignment. --- src/quanto/core/unified_quantizer.py | 49 ++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/src/quanto/core/unified_quantizer.py b/src/quanto/core/unified_quantizer.py index d671c7a..0b94d86 100644 --- a/src/quanto/core/unified_quantizer.py +++ b/src/quanto/core/unified_quantizer.py @@ -216,6 +216,55 @@ def _determine_exclude_layers(self) -> list[str]: exclude.extend(sensitive_layers) # Remove duplicates + exclude = list(set(exclude)) + + # Align exclusions for vLLM fused layer compatibility + exclude = self._align_exclude_groups(exclude) + + return exclude + + def _align_exclude_groups(self, exclude: list[str]) -> list[str]: + """ + Ensure fused projection groups are excluded together for vLLM compatibility. 
+ + vLLM fuses certain projections into single linear layers: + - qkv_proj: q_proj + k_proj + v_proj (must all share same scheme) + - gate_up_proj: gate_proj + up_proj (must all share same scheme) + + If any projection in a group is excluded, exclude the entire group. + """ + # Define fused groups: suffixes that must be excluded together + fused_groups = [ + ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], + ["mlp.gate_proj", "mlp.up_proj"], + ["mlp.shared_experts.gate_proj", "mlp.shared_experts.up_proj"], + ] + + # Find layer prefixes from exclude list (e.g., "model.layers.0") + added = set() + for layer_name in list(exclude): + # Skip glob patterns + if "*" in layer_name: + continue + + for group in fused_groups: + # Check if this excluded layer belongs to a fused group + for suffix in group: + if layer_name.endswith(suffix): + # Extract the prefix (e.g., "model.layers.0") + prefix = layer_name[: -len(suffix)] + # Add all members of this group + for member_suffix in group: + member = prefix + member_suffix + if member not in exclude and member not in added: + added.add(member) + self._log(f" + {member} (aligned with {layer_name})") + break + + if added: + self._log(f"Aligned {len(added)} additional layers for vLLM fused layer compatibility") + + exclude.extend(added) return list(set(exclude)) def _analyze_sensitive_layers(self) -> list[str]: From 998daf9fa778f04fe2473772d8f6b94930920b29 Mon Sep 17 00:00:00 2001 From: Jack Han Date: Wed, 22 Apr 2026 07:52:42 +0900 Subject: [PATCH 5/8] Add MoE router gate to default exclude list for vLLM compatibility MoE router gates (*.gate, not gate_proj) must be excluded from MXFP4 quantization because vLLM's SolarOpenTopkRouter uses regular nn.Linear which cannot load packed uint8 weights. Also update CLAUDE.md with file2file quantization path, AWQ/GPTQ support, and fused layer alignment documentation. Verified: Solar-Open-100B MXFP4 checkpoint loads and runs inference on vLLM (MI355, TP=1, 53GB checkpoint). 
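An illustrative aside, not applied by this patch: assuming exclude patterns are matched with fnmatch-style globbing, the sketch below shows why "*.gate" catches only the MoE router gate, while the aggressive-mode "*gate*" pattern would also sweep up the gate_proj FFN layers.

```python
# Hypothetical layer names; fnmatch-style glob semantics are an assumption here.
from fnmatch import fnmatch

names = [
    "model.layers.0.mlp.gate",       # MoE router gate
    "model.layers.0.mlp.gate_proj",  # FFN projection
]
for name in names:
    router_hit = fnmatch(name, "*.gate")      # default router-gate exclude
    aggressive_hit = fnmatch(name, "*gate*")  # aggressive-exclusion pattern
    print(f"{name}: *.gate -> {router_hit}, *gate* -> {aggressive_hit}")
```

The router gate matches both patterns, but gate_proj matches only "*gate*", which is why the default exclude list uses the narrower "*.gate".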
--- CLAUDE.md | 61 +++++++++++++++++++--------- src/quanto/core/unified_quantizer.py | 3 ++ 2 files changed, 44 insertions(+), 20 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 77a3a5a..ea30deb 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -14,7 +14,7 @@ pip install -e ".[dev]" # dev (pytest, ruff) pip install -e ".[nvidia]" # with NVIDIA extras pip install -e ".[rocm]" # with ROCm extras -# Tests +# Tests (requires Quark — run on remote server with amd-quark installed) pytest tests/ -v # all tests pytest tests/test_unified_quantizer.py -v # single file pytest tests/test_unified_quantizer.py::TestUnifiedConfig::test_default_config -v # single test @@ -24,39 +24,56 @@ ruff check src/ # lint ruff check src/ --fix # lint with autofix ruff format src/ # format -# Quantize a model (CLI) -python -m quanto --model_path /path/to/model --output_dir ./output --precision int4 -python -m quanto --model_path /path --sensitivity_analysis --sensitivity_threshold 0.12 +# Quantize a model (Python API — preferred, CLI is incomplete) +python -c " +from quanto import UnifiedQuantizer, UnifiedConfig +config = UnifiedConfig( + model_path='model/path', output_dir='./output', + precision='mxfp4', sensitivity_analysis=True, + sensitivity_threshold=0.12, +) +UnifiedQuantizer(config).run() +" # Dequantize python -m quanto --dequantize --model_path ./quantized --output_dir ./dequantized # Docker-based integration tests -./scripts/run_tests.sh --gpu nvidia --test all +./scripts/run_e2e_tests.sh rocm # all ROCm tests +./scripts/run_e2e_tests.sh cuda 1,2 # specific CUDA tests ``` ## Architecture ### Pipeline flow -CLI (`__main__.py`) -> `UnifiedConfig` (dataclass validation) -> `UnifiedQuantizer.run()` -> `QuantizationResult` +`UnifiedConfig` (dataclass validation) -> `UnifiedQuantizer.run()` -> strategy dispatch -> `QuantizationResult` + +### Quantization paths + +**MXFP4/MXFP6** — Uses Quark's `quantize_model_per_safetensor` (file2file). Processes each safetensors shard independently without loading the full model. Produces packed uint8 weights + E8M0 scales compatible with vLLM's Quark loader. + +**INT4/INT8/FP8** — Uses in-memory quantization via `ModelQuantizer` + `export_safetensors`. Three memory strategies: +- `full` — entire model on GPU +- `layerwise_cpu` — model on CPU, layers quantized one-by-one on GPU +- `lazy` — weights loaded on-demand from safetensors ### Core modules (`src/quanto/core/`) -- **`config.py`** — `UnifiedConfig` dataclass with ~23 fields and `__post_init__` validation. `QuantizationConfig` is a backward-compat alias. -- **`unified_quantizer.py`** — Main quantizer implementing 4 memory strategies: `full` (entire model on GPU), `layerwise_cpu` (model on CPU, layers quantized one-by-one on GPU), `lazy` (weights loaded on-demand from safetensors), `auto` (selects based on model size vs GPU memory). -- **`base_quantizer.py`** — Abstract base class, `QuantizationResult` dataclass. -- **`dequantize.py`** — INT4 -> BF16/FP16 conversion. -- **`sensitivity/`** — Sequential sensitivity analysis: `SequentialSensitivityAnalyzer` scores per-layer quantization impact, `ActivationCache` manages GPU/CPU caching, `SensitivityScorer` computes perplexity-based metrics. +- **`config.py`** — `UnifiedConfig` dataclass. Key fields: `precision`, `memory_strategy`, `algorithm` (rtn/awq/gptq), `sensitivity_analysis`, `sensitivity_threshold`, `exclude_layers`. +- **`unified_quantizer.py`** — Main quantizer. 
`run()` dispatches to `_run_file2file_quantization()` for MXFP or `_run_full_gpu_quantization()` / `_run_lazy_quantization()` for INT4/INT8. Contains `_determine_exclude_layers()` with sensitivity analysis and `_align_exclude_groups()` for vLLM fused layer compatibility. +- **`sensitivity/sequential_analyzer.py`** — Iterative sensitivity analysis. Scores each layer using the actual target precision (MXFP4 uses `OCP_MXFP4Spec`, not INT4 proxy). `_build_quant_config_for_scoring()` maps precision to the correct Quark spec class. ### Supporting modules -- **`constants.py`** — `PRECISION_TO_SCHEME` mapping (e.g., `"int4"` -> `"int4_wo_128"`), `MODEL_TYPE_MAPPINGS`, `DEFAULT_EXCLUDE_PATTERNS`. -- **`analysis/layer_analyzer.py`** — Automatic detection of layers to exclude (lm_head, MoE gates, embeddings/norms with aggressive mode). +- **`constants.py`** — `PRECISION_TO_SCHEME` mapping, `MODEL_TYPE_MAPPINGS` (includes `solar_open` -> `qwen3_moe`), `SUPPORTED_ALGORITHMS`. +- **`utils/model_utils.py`** — `detect_model_type()` and `get_template()` for Quark `LLMTemplate` lookup. - **`utils/calibration.py`** — `CalibrationDataManager` loads from HuggingFace datasets or local files. - **`utils/int4_pack.py`** — INT4 <-> INT32 packing/unpacking. -- **`utils/memory.py`** — GPU memory tracking and cleanup. -- **`utils/model_utils.py`** — Model type detection and Quark template lookup. ### External dependency -AMD Quark is vendored as a git submodule in `contribs/quark/`. It provides the quantization scheme templates for each model architecture. +AMD Quark is vendored as a git submodule in `contribs/quark/`. Key Quark APIs used: +- `LLMTemplate.get_config(scheme, algorithm, exclude_layers)` — generates per-architecture quantization configs +- `quantize_model_per_safetensor()` — file-to-file quantization (MXFP4 path) +- `ModelQuantizer` / `export_safetensors()` — in-memory quantization (INT4/INT8 path) +- `OCP_MXFP4Spec`, `Int4PerGroupSpec` — precision-specific quantization specs ## Code style @@ -67,7 +84,11 @@ AMD Quark is vendored as a git submodule in `contribs/quark/`. It provides the q ## Key patterns -- **Backward compatibility aliases**: `QuantizationConfig = UnifiedConfig`, `AutoQuantizer` wraps `UnifiedQuantizer` -- **Valid precisions**: `int4`, `int4_64`, `int4_32`, `int8`, `fp8`, `mxfp4`, `mxfp6`, `uint4` -- **Memory strategies**: `full`, `layerwise_cpu`, `lazy`, `auto` -- **Export formats**: `quark` (native, default), `awq`, `gptq` (vLLM compat, INT4 only) +- **vLLM fused layer alignment**: `_align_exclude_groups()` ensures q/k/v projections and gate/up projections are excluded together (vLLM fuses these into `qkv_proj` and `gate_up_proj`) +- **AWQ/GPTQ**: Set `algorithm="awq"` or `"gptq"` in config — passed to `LLMTemplate.get_config(algorithm=...)`. Quark handles execution internally via `AwqProcessor`/`GptqProcessor`. +- **Backward compat aliases**: `QuantizationConfig = UnifiedConfig`, `AutoQuantizer = UnifiedQuantizer` +- **HF hub resolution**: File2file path auto-resolves HF hub IDs to local cache via `snapshot_download` + +## Testing environment + +Remote server mi355-gpu-16 (aac14 cluster) with MI355 GPUs. Use podman containers with `rocm/vllm-dev:nightly` image which includes PyTorch, Quark, and all dependencies. See `memory/reference_mi355_server.md` for access details. 
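A minimal usage sketch for the AWQ path described above (illustrative only; the model ID is hypothetical, and the fields follow the Python API example and the `algorithm` option documented in this file):

```python
# Sketch: INT4 + AWQ through the documented Python API.
from quanto import UnifiedConfig, UnifiedQuantizer

config = UnifiedConfig(
    model_path="org/example-model",   # hypothetical HF hub ID or local path
    output_dir="./quantized-awq-int4",
    precision="int4",                 # AWQ pairs with INT4 precisions
    algorithm="awq",                  # forwarded to LLMTemplate.get_config(algorithm=...)
)
UnifiedQuantizer(config).run()
```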
diff --git a/src/quanto/core/unified_quantizer.py b/src/quanto/core/unified_quantizer.py index 0b94d86..7527788 100644 --- a/src/quanto/core/unified_quantizer.py +++ b/src/quanto/core/unified_quantizer.py @@ -201,6 +201,9 @@ def _determine_exclude_layers(self) -> list[str]: # Add standard patterns exclude.extend(["*embed*", "*norm*"]) + # Exclude MoE router gates (not gate_proj FFN layers) + exclude.append("*.gate") + if self.config.aggressive_exclusion: exclude.extend(["*gate*"]) From 2b04bf730b50d6d5b28a8fa04165d881c51c546c Mon Sep 17 00:00:00 2001 From: Jack Han Date: Wed, 22 Apr 2026 11:48:56 +0900 Subject: [PATCH 6/8] Add AutoConfig fallback for unsupported model types and EXAONE mapping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add JSON config fallback in _setup() and detect_model_type() when AutoConfig fails for models not yet in transformers (e.g., exaone4_5) - Add graceful tokenizer fallback when AutoTokenizer fails - Add EXAONE model type mappings (exaone, exaone4_5, exaone4_5_text → llama) - Keep auto-strategy detection intact for non-MXFP paths - For multimodal models, merge text_config into top-level config Tested: EXAONE-4.5-33B MXFP4 quantization (64GB → 20GB, 3.2x, 20s) --- src/quanto/constants.py | 3 ++ src/quanto/core/unified_quantizer.py | 72 +++++++++++++++++++++------- src/quanto/utils/model_utils.py | 14 ++++-- 3 files changed, 68 insertions(+), 21 deletions(-) diff --git a/src/quanto/constants.py b/src/quanto/constants.py index 7421a13..b62a31d 100644 --- a/src/quanto/constants.py +++ b/src/quanto/constants.py @@ -47,6 +47,9 @@ "phi3": "phi3", "phi4": "phi3", "solar_open": "qwen3_moe", + "exaone": "llama", + "exaone4_5": "llama", + "exaone4_5_text": "llama", } # Default layers to exclude from quantization diff --git a/src/quanto/core/unified_quantizer.py b/src/quanto/core/unified_quantizer.py index 7527788..51ea759 100644 --- a/src/quanto/core/unified_quantizer.py +++ b/src/quanto/core/unified_quantizer.py @@ -117,18 +117,28 @@ def _setup(self) -> None: self._log("Setting up quantization...") # Load HuggingFace config (no weights) - self.hf_config = AutoConfig.from_pretrained( - self.config.model_path, trust_remote_code=self.config.trust_remote_code - ) + try: + self.hf_config = AutoConfig.from_pretrained( + self.config.model_path, trust_remote_code=self.config.trust_remote_code + ) + except (ValueError, KeyError) as e: + # Fallback for models not yet supported by transformers (e.g., exaone4_5) + self._log(f"AutoConfig failed ({e.__class__.__name__}), using JSON fallback") + self.hf_config = self._load_config_from_json() + self._detect_model_type() self._get_template() # Load tokenizer - self.tokenizer = AutoTokenizer.from_pretrained( - self.config.model_path, trust_remote_code=self.config.trust_remote_code - ) - if self.tokenizer.pad_token is None: - self.tokenizer.pad_token = self.tokenizer.eos_token + try: + self.tokenizer = AutoTokenizer.from_pretrained( + self.config.model_path, trust_remote_code=self.config.trust_remote_code + ) + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + except (ValueError, KeyError, OSError) as e: + self._log(f"AutoTokenizer failed ({e.__class__.__name__}), skipping tokenizer") + self.tokenizer = None # Create output directory os.makedirs(self.config.output_dir, exist_ok=True) @@ -137,6 +147,29 @@ def _setup(self) -> None: self.timing["setup"] = time.time() - start_time self._log(f"Setup completed in {self.timing['setup']:.2f}s") + def 
_load_config_from_json(self): + """Fallback config loading when AutoConfig fails (unsupported model types).""" + from pathlib import Path + from types import SimpleNamespace + + model_path = Path(self.config.model_path) + config_file = model_path / "config.json" + + # If model_path is a HF hub ID, resolve to local cache + if not config_file.exists(): + from huggingface_hub import hf_hub_download + + config_file = Path(hf_hub_download(self.config.model_path, "config.json")) + + with open(config_file) as f: + config_dict = json.load(f) + + # For multimodal models, text_config holds the LLM settings + text_config = config_dict.get("text_config", {}) + merged = {**config_dict, **text_config} + + return SimpleNamespace(**merged) + def _get_layer_info(self) -> dict[str, Any]: """Get layer information from config.""" info = { @@ -1352,22 +1385,25 @@ def run(self) -> QuantizationResult: Returns: QuantizationResult with details of the quantization """ - # Determine strategy + # Use file-to-file for MXFP precisions (produces vLLM-compatible packed uint8) + # This path skips auto-strategy detection since it doesn't need the model in memory + if self.config.precision.startswith("mxfp"): + return self._run_file2file_quantization() + + # Determine memory strategy for non-MXFP precisions if self.config.memory_strategy == "auto": - # Need to load config first for auto-detection - self.hf_config = AutoConfig.from_pretrained( - self.config.model_path, trust_remote_code=self.config.trust_remote_code - ) + try: + self.hf_config = AutoConfig.from_pretrained( + self.config.model_path, trust_remote_code=self.config.trust_remote_code + ) + except (ValueError, KeyError): + self.hf_config = self._load_config_from_json() strategy = self._auto_detect_strategy() self._log(f"Auto-detected memory strategy: {strategy}") else: strategy = self.config.memory_strategy - # Dispatch to appropriate strategy - # Use file-to-file for MXFP precisions (produces vLLM-compatible packed uint8) - if self.config.precision.startswith("mxfp"): - return self._run_file2file_quantization() - elif strategy == "lazy": + if strategy == "lazy": return self._run_lazy_quantization() elif strategy == "layerwise_cpu": return self._run_layerwise_cpu_quantization() diff --git a/src/quanto/utils/model_utils.py b/src/quanto/utils/model_utils.py index d979624..a224af6 100644 --- a/src/quanto/utils/model_utils.py +++ b/src/quanto/utils/model_utils.py @@ -35,9 +35,17 @@ def detect_model_type(model_path: str, trust_remote_code: bool = True) -> str: config = json.load(f) model_type = config.get("model_type", config.get("architectures", ["unknown"])[0]) else: - # Load config from model using transformers - config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) - model_type = getattr(config, "model_type", getattr(config, "architectures", ["unknown"])[0]) + # Try AutoConfig first, fall back to JSON download for unsupported model types + try: + config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code) + model_type = getattr(config, "model_type", getattr(config, "architectures", ["unknown"])[0]) + except (ValueError, KeyError): + from huggingface_hub import hf_hub_download + + config_file = hf_hub_download(model_path, "config.json") + with open(config_file) as f: + config = json.load(f) + model_type = config.get("model_type", config.get("architectures", ["unknown"])[0]) return model_type From 337faa5babba3e14db5f25e48c232951aae476be Mon Sep 17 00:00:00 2001 From: Jack Han Date: Thu, 23 Apr 2026 15:49:34 
+0900 Subject: [PATCH 7/8] Add quantization CLI and Kimi-K2 model support - Add main() to auto_quantize.py with full argparse CLI (--model_path, --precision, --exclude_layers_file, etc.) - Fix __main__.py dispatcher to pass args through to quantization mode - Add kimi_k2/kimi_k25 model type mapping in constants.py - Update CLAUDE.md with CLI usage examples - Remove project structure section from README.md --- CLAUDE.md | 22 +++++++-- README.md | 47 ------------------ src/quanto/__main__.py | 42 +++++++--------- src/quanto/constants.py | 3 ++ src/quanto/core/auto_quantize.py | 85 ++++++++++++++++++++++++++++++++ 5 files changed, 123 insertions(+), 76 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index ea30deb..7e579ff 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -24,8 +24,22 @@ ruff check src/ # lint ruff check src/ --fix # lint with autofix ruff format src/ # format -# Quantize a model (Python API — preferred, CLI is incomplete) -python -c " +# Quantize a model (CLI) +python -m quanto \ + --model_path model/path \ + --output_dir ./output \ + --precision mxfp4 \ + --sensitivity_analysis \ + --sensitivity_threshold 0.12 + +# Quantize with explicit exclude list (e.g., attn-excl strategy) +python -m quanto \ + --model_path model/path \ + --output_dir ./output \ + --precision mxfp4 \ + --exclude_layers_file exclude.json + +# Quantize (Python API) from quanto import UnifiedQuantizer, UnifiedConfig config = UnifiedConfig( model_path='model/path', output_dir='./output', @@ -33,7 +47,6 @@ config = UnifiedConfig( sensitivity_threshold=0.12, ) UnifiedQuantizer(config).run() -" # Dequantize python -m quanto --dequantize --model_path ./quantized --output_dir ./dequantized @@ -63,7 +76,8 @@ python -m quanto --dequantize --model_path ./quantized --output_dir ./dequantize - **`sensitivity/sequential_analyzer.py`** — Iterative sensitivity analysis. Scores each layer using the actual target precision (MXFP4 uses `OCP_MXFP4Spec`, not INT4 proxy). `_build_quant_config_for_scoring()` maps precision to the correct Quark spec class. ### Supporting modules -- **`constants.py`** — `PRECISION_TO_SCHEME` mapping, `MODEL_TYPE_MAPPINGS` (includes `solar_open` -> `qwen3_moe`), `SUPPORTED_ALGORITHMS`. +- **`constants.py`** — `PRECISION_TO_SCHEME` mapping, `MODEL_TYPE_MAPPINGS` (includes `solar_open` -> `qwen3_moe`, `kimi_k2` -> `kimi_k25`), `SUPPORTED_ALGORITHMS`. +- **`auto_quantize.py`** — CLI `main()` entry point. Parses args and creates `UnifiedConfig`. Supports `--exclude_layers_file` for JSON exclude lists. - **`utils/model_utils.py`** — `detect_model_type()` and `get_template()` for Quark `LLMTemplate` lookup. - **`utils/calibration.py`** — `CalibrationDataManager` loads from HuggingFace datasets or local files. - **`utils/int4_pack.py`** — INT4 <-> INT32 packing/unpacking. diff --git a/README.md b/README.md index a464c66..a0f3808 100644 --- a/README.md +++ b/README.md @@ -82,53 +82,6 @@ docker build -f docker/Dockerfile.rocm.dev -t quanto:rocm-dev . 
docker run --device=/dev/kfd --device=/dev/dri --group-add video -v $(pwd):/workspace -w /workspace quanto:rocm-dev bash ``` -## Project Structure - -``` -quanto/ -├── pyproject.toml # Package configuration -├── README.md # This file -├── requirements.txt # Base requirements -├── requirements-nvidia.txt # NVIDIA-specific deps -├── requirements-rocm.txt # ROCm-specific deps -├── contribs/ -│ └── quark/ # AMD Quark (submodule) -├── docker/ -│ ├── Dockerfile.cuda # Pre-built for CUDA -│ ├── Dockerfile.cuda.dev # Development for CUDA -│ ├── Dockerfile.rocm # Pre-built for ROCm -│ └── Dockerfile.rocm.dev # Development for ROCm -├── docs/ -│ └── examples.md # Experiment results -├── examples/ # Example scripts -├── scripts/ -│ └── repack.py # Weight packing utilities -├── src/quanto/ # Main package -│ ├── __init__.py -│ ├── __main__.py # CLI entry point -│ ├── constants.py # Shared constants -│ ├── core/ # Quantization engines -│ │ ├── base_quantizer.py -│ │ ├── auto_quantize.py -│ │ ├── layerwise_quant.py -│ │ ├── lazy_layerwise_quant.py -│ │ ├── iterative_quantizer.py -│ │ └── dequantize.py -│ ├── analysis/ # Layer analysis -│ │ ├── layer_analyzer.py -│ │ └── sensitivity_analyzer.py -│ ├── export/ # Export utilities -│ │ ├── hf_export.py -│ │ └── model_assembler.py -│ └── utils/ # Shared utilities -│ ├── calibration.py -│ ├── int4_pack.py -│ ├── logging.py -│ ├── memory.py -│ └── model_utils.py -└── tests/ # Test suite -``` - ## Usage ### Basic Usage diff --git a/src/quanto/__main__.py b/src/quanto/__main__.py index b81d8dc..fa797c0 100644 --- a/src/quanto/__main__.py +++ b/src/quanto/__main__.py @@ -14,39 +14,31 @@ def main() -> int: """Main entry point that dispatches to quantize or dequantize.""" - parser = argparse.ArgumentParser( - description="Quanto: LLM Quantization Tool", - add_help=False, - ) - - # Add --dequantize flag to detect mode - parser.add_argument("--dequantize", action="store_true", help="Run dequantization mode") - parser.add_argument("--help", "-h", action="store_true", help="Show help") + # Check if --dequantize is in args + if "--dequantize" in sys.argv: + from quanto.core.dequantize import main as dequant_main - # Parse known args to detect mode - args, remaining = parser.parse_known_args() + return dequant_main() - if args.help: - parser.print_help() - print("\nModes:") + # Show top-level help only when no args or just --help with no other flags + if len(sys.argv) <= 1 or (len(sys.argv) == 2 and sys.argv[1] in ("--help", "-h")): + print("usage: python -m quanto [--dequantize] [options]") + print() + print("Quanto: LLM Quantization Tool") + print() + print("Modes:") print( - " Quantization: python -m quanto --model_path ... --output_dir ... --precision int4" + " Quantization: python -m quanto --model_path ... --output_dir ... --precision mxfp4" ) print(" Dequantization: python -m quanto --dequantize --model_path ... 
--output_dir ...") + print() + print("Run 'python -m quanto --model_path x --output_dir y --help' for full quantization options.") return 0 - if args.dequantize: - # Run dequantization - from quanto.core.dequantize import main as dequant_main - - # Add back --dequantize flag since dequantize module expects it - sys.argv = [sys.argv[0], "--dequantize"] + remaining - return dequant_main() - else: - # Run quantization - from quanto.core.auto_quantize import main as quant_main + # Default: quantization mode + from quanto.core.auto_quantize import main as quant_main - return quant_main() + return quant_main() if __name__ == "__main__": diff --git a/src/quanto/constants.py b/src/quanto/constants.py index b62a31d..b58490d 100644 --- a/src/quanto/constants.py +++ b/src/quanto/constants.py @@ -50,6 +50,9 @@ "exaone": "llama", "exaone4_5": "llama", "exaone4_5_text": "llama", + "exaone_moe": "qwen3_moe", + "kimi_k2": "kimi_k25", + "kimi_k25": "kimi_k25", } # Default layers to exclude from quantization diff --git a/src/quanto/core/auto_quantize.py b/src/quanto/core/auto_quantize.py index 79dc6ea..63d8aa9 100644 --- a/src/quanto/core/auto_quantize.py +++ b/src/quanto/core/auto_quantize.py @@ -44,4 +44,89 @@ "QuantizationConfig", "UnifiedQuantizer", "UnifiedConfig", + "main", ] + + +def main() -> int: + """CLI entry point for quantization.""" + import argparse + import json + import sys + + parser = argparse.ArgumentParser( + description="Quanto: Quantize a model", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + # Required + parser.add_argument("--model_path", required=True, help="HuggingFace model ID or local path") + parser.add_argument("--output_dir", required=True, help="Output directory for quantized model") + + # Quantization settings + parser.add_argument( + "--precision", + default="mxfp4", + choices=["int4", "int4_64", "int4_32", "int8", "fp8", "mxfp4", "mxfp6", "uint4"], + help="Target precision", + ) + parser.add_argument("--algorithm", default="rtn", choices=["rtn", "awq", "gptq"], help="Quantization algorithm") + parser.add_argument("--memory_strategy", default="auto", choices=["full", "layerwise_cpu", "lazy", "auto"]) + + # Sensitivity analysis + parser.add_argument("--sensitivity_analysis", action="store_true", help="Enable iterative sensitivity analysis") + parser.add_argument("--sensitivity_threshold", type=float, default=0.12, help="Sensitivity threshold for layer exclusion") + parser.add_argument("--max_iterations", type=int, default=10, help="Max iterations for sensitivity analysis") + + # Layer exclusion + parser.add_argument("--exclude_layers", nargs="*", help="Layer name patterns to exclude from quantization") + parser.add_argument("--exclude_layers_file", help="JSON file containing exclude layer list") + + # Calibration + parser.add_argument("--calibration_data", default="pileval", help="Calibration dataset name or path") + parser.add_argument("--num_calib_samples", type=int, default=128, help="Number of calibration samples") + parser.add_argument("--seq_len", type=int, default=512, help="Sequence length for calibration") + + # Other + parser.add_argument("--device", default="cuda", help="Device (cuda, cuda:0, cpu)") + parser.add_argument("--trust_remote_code", action="store_true", default=True) + parser.add_argument("--no_trust_remote_code", action="store_true", help="Disable trust_remote_code") + parser.add_argument("--skip_evaluation", action="store_true", help="Skip perplexity evaluation") + parser.add_argument("--sensitivity_cache_on_gpu", 
action="store_true", default=True) + + args = parser.parse_args() + + # Handle exclude_layers from file + exclude_layers = args.exclude_layers + if args.exclude_layers_file: + with open(args.exclude_layers_file) as f: + exclude_layers = json.load(f) + + config = UnifiedConfig( + model_path=args.model_path, + output_dir=args.output_dir, + precision=args.precision, + algorithm=args.algorithm, + memory_strategy=args.memory_strategy, + sensitivity_analysis=args.sensitivity_analysis, + sensitivity_threshold=args.sensitivity_threshold, + max_iterations=args.max_iterations, + exclude_layers=exclude_layers, + calibration_data=args.calibration_data, + num_calib_samples=args.num_calib_samples, + seq_len=args.seq_len, + device=args.device, + trust_remote_code=not args.no_trust_remote_code, + skip_evaluation=args.skip_evaluation, + sensitivity_cache_on_gpu=args.sensitivity_cache_on_gpu, + ) + + quantizer = UnifiedQuantizer(config) + result = quantizer.run() + + if result.success: + print(json.dumps(result.to_dict(), indent=2)) + return 0 + else: + print(f"FAILED: {result.error_message}", file=sys.stderr) + return 1 From b7c49c664207a53bcde35bcd4f285d074f3e39dc Mon Sep 17 00:00:00 2001 From: Jack Han Date: Mon, 27 Apr 2026 03:09:06 +0900 Subject: [PATCH 8/8] Wire sensitivity metric through quantization flow --- CLAUDE.md | 7 +- src/quanto/constants.py | 8 ++ src/quanto/core/auto_quantize.py | 12 ++- src/quanto/core/config.py | 27 ++++++- .../core/sensitivity/sequential_analyzer.py | 6 +- src/quanto/core/unified_quantizer.py | 78 ++++++++++++++++--- tests/test_unified_quantizer.py | 75 ++++++++++++++++++ 7 files changed, 196 insertions(+), 17 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 7e579ff..0a0f52e 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -99,7 +99,12 @@ AMD Quark is vendored as a git submodule in `contribs/quark/`. Key Quark APIs us ## Key patterns - **vLLM fused layer alignment**: `_align_exclude_groups()` ensures q/k/v projections and gate/up projections are excluded together (vLLM fuses these into `qkv_proj` and `gate_up_proj`) -- **AWQ/GPTQ**: Set `algorithm="awq"` or `"gptq"` in config — passed to `LLMTemplate.get_config(algorithm=...)`. Quark handles execution internally via `AwqProcessor`/`GptqProcessor`. +- **AWQ/GPTQ algorithm support**: Enabled via config validation matrix in `constants.ALGORITHM_PRECISION_SUPPORT`. 
Valid combinations: + - RTN: all precisions (int4, int4_64, int4_32, int8, fp8, mxfp4, mxfp6, uint4) + - AWQ: INT4 only (int4, int4_64, int4_32) — activation-aware, Quark `AwqProcessor` + - GPTQ: INT4 only (int4) — Hessian-based, Quark `GptqProcessor` + - Invalid combos (e.g., AWQ+MXFP4, GPTQ+INT8) raise `ValueError` in `UnifiedConfig.validate()` +- **Sensitivity analysis algorithm-awareness**: `SequentialSensitivityAnalyzer._build_quant_config_for_scoring()` passes actual algorithm (not RTN proxy) to `LLMTemplate.get_config()` for correct Quark spec (critical for AWQ/GPTQ accuracy) - **Backward compat aliases**: `QuantizationConfig = UnifiedConfig`, `AutoQuantizer = UnifiedQuantizer` - **HF hub resolution**: File2file path auto-resolves HF hub IDs to local cache via `snapshot_download` diff --git a/src/quanto/constants.py b/src/quanto/constants.py index b58490d..f5c1715 100644 --- a/src/quanto/constants.py +++ b/src/quanto/constants.py @@ -89,3 +89,11 @@ "awq", "gptq", ] + +# Algorithm-Precision support matrix +# Defines which precisions are supported for each quantization algorithm +ALGORITHM_PRECISION_SUPPORT: dict[str, list[str]] = { + "rtn": ["int4", "int4_64", "int4_32", "int8", "fp8", "mxfp4", "mxfp6", "uint4"], + "awq": ["int4", "int4_64", "int4_32"], # AWQ is INT4-only (activation-aware) + "gptq": ["int4"], # GPTQ is INT4-only (Hessian-based) +} diff --git a/src/quanto/core/auto_quantize.py b/src/quanto/core/auto_quantize.py index 63d8aa9..092c67a 100644 --- a/src/quanto/core/auto_quantize.py +++ b/src/quanto/core/auto_quantize.py @@ -75,14 +75,21 @@ def main() -> int: # Sensitivity analysis parser.add_argument("--sensitivity_analysis", action="store_true", help="Enable iterative sensitivity analysis") - parser.add_argument("--sensitivity_threshold", type=float, default=0.12, help="Sensitivity threshold for layer exclusion") + parser.add_argument("--sensitivity_threshold", type=float, default=0.0, help="Sensitivity threshold for layer exclusion") + parser.add_argument( + "--sensitivity_metric", + type=str, + default="relative", + choices=["relative", "mse", "mae", "cosine", "kl"], + help="Metric used to rank sensitive layers", + ) parser.add_argument("--max_iterations", type=int, default=10, help="Max iterations for sensitivity analysis") # Layer exclusion parser.add_argument("--exclude_layers", nargs="*", help="Layer name patterns to exclude from quantization") parser.add_argument("--exclude_layers_file", help="JSON file containing exclude layer list") - # Calibration + # Calibration data parser.add_argument("--calibration_data", default="pileval", help="Calibration dataset name or path") parser.add_argument("--num_calib_samples", type=int, default=128, help="Number of calibration samples") parser.add_argument("--seq_len", type=int, default=512, help="Sequence length for calibration") @@ -110,6 +117,7 @@ def main() -> int: memory_strategy=args.memory_strategy, sensitivity_analysis=args.sensitivity_analysis, sensitivity_threshold=args.sensitivity_threshold, + sensitivity_metric=args.sensitivity_metric, max_iterations=args.max_iterations, exclude_layers=exclude_layers, calibration_data=args.calibration_data, diff --git a/src/quanto/core/config.py b/src/quanto/core/config.py index b59dd70..2cb5919 100644 --- a/src/quanto/core/config.py +++ b/src/quanto/core/config.py @@ -7,7 +7,7 @@ from __future__ import annotations -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import Any, Literal @@ -42,6 +42,7 @@ class UnifiedConfig: 
aggressive_exclusion: Use aggressive layer exclusion rules sensitivity_analysis: Enable sequential sensitivity analysis for layer exclusion sensitivity_threshold: Threshold for excluding sensitive layers + sensitivity_metric: Metric used to rank sensitive layers sensitivity_cache_on_gpu: Cache activations on GPU (faster, more memory) skip_evaluation: Skip perplexity evaluation trust_remote_code: Trust remote code when loading models @@ -81,6 +82,7 @@ class UnifiedConfig: # Sensitivity-based exclusion sensitivity_analysis: bool = False # Enable sequential sensitivity analysis sensitivity_threshold: float = 0.0 # Threshold for excluding sensitive layers (0 = disabled, typical values: 0.12-0.15 for INT4) + sensitivity_metric: str = "relative" # One of: relative, mse, mae, cosine, kl sensitivity_cache_on_gpu: bool = True # Cache activations on GPU (faster, uses more memory) max_iterations: int = 10 # Maximum iterations for iterative sensitivity analysis (1 = single-pass) @@ -150,13 +152,31 @@ def validate(self) -> None: if self.max_iterations < 1: raise ValueError(f"max_iterations must be >= 1, got {self.max_iterations}") - # Validate algorithm - valid_algorithms = ["rtn", "awq", "gptq"] + # Validate sensitivity metric + valid_sensitivity_metrics = ["relative", "mse", "mae", "cosine", "kl"] + if self.sensitivity_metric not in valid_sensitivity_metrics: + raise ValueError( + f"Invalid sensitivity_metric '{self.sensitivity_metric}'. " + f"Must be one of: {valid_sensitivity_metrics}" + ) + + # Validate algorithm and precision combination + from ..constants import ALGORITHM_PRECISION_SUPPORT + + valid_algorithms = list(ALGORITHM_PRECISION_SUPPORT.keys()) if self.algorithm not in valid_algorithms: raise ValueError( f"Invalid algorithm '{self.algorithm}'. Must be one of: {valid_algorithms}" ) + # Check if precision is supported for this algorithm + supported_precisions = ALGORITHM_PRECISION_SUPPORT[self.algorithm] + if self.precision not in supported_precisions: + raise ValueError( + f"Precision '{self.precision}' not supported for algorithm '{self.algorithm}'. " + f"Supported precisions: {supported_precisions}" + ) + def to_dict(self) -> dict[str, Any]: """Convert configuration to dictionary.""" return { @@ -175,6 +195,7 @@ def to_dict(self) -> dict[str, Any]: "aggressive_exclusion": self.aggressive_exclusion, "sensitivity_analysis": self.sensitivity_analysis, "sensitivity_threshold": self.sensitivity_threshold, + "sensitivity_metric": self.sensitivity_metric, "sensitivity_cache_on_gpu": self.sensitivity_cache_on_gpu, "max_iterations": self.max_iterations, "skip_evaluation": self.skip_evaluation, diff --git a/src/quanto/core/sensitivity/sequential_analyzer.py b/src/quanto/core/sensitivity/sequential_analyzer.py index a8e4cfa..9a9da09 100644 --- a/src/quanto/core/sensitivity/sequential_analyzer.py +++ b/src/quanto/core/sensitivity/sequential_analyzer.py @@ -359,7 +359,7 @@ def _quantize_layer(self, layer: nn.Module, layer_name: str) -> nn.Module: def _build_quant_config_for_scoring(self): """ - Build quantization config matching the target precision. + Build quantization config matching the target precision and algorithm. Uses the LLMTemplate if available (produces architecture-specific configs), otherwise falls back to building a config from the precision's Quark Spec class. 
@@ -369,12 +369,14 @@ def _build_quant_config_for_scoring(self): from ...constants import PRECISION_TO_SCHEME precision = self.config.precision + algorithm = (self.config.algorithm or "rtn").lower() scheme = PRECISION_TO_SCHEME.get(precision, precision) # Prefer template-based config (architecture-specific) if self.template: return self.template.get_config( scheme=scheme, + algorithm=algorithm if algorithm != "rtn" else None, exclude_layers=[], ) @@ -402,6 +404,8 @@ def _build_quant_config_for_scoring(self): spec = Int4PerGroupSpec(ch_axis=0, group_size=128).to_quantization_spec() + # Note: algorithm parameter (if needed by Quark) would be passed here + # Currently Quark's fallback configs don't take algorithm parameter return QConfig( global_quant_config=QLayerConfig(weight=spec), ) diff --git a/src/quanto/core/unified_quantizer.py b/src/quanto/core/unified_quantizer.py index 51ea759..b607518 100644 --- a/src/quanto/core/unified_quantizer.py +++ b/src/quanto/core/unified_quantizer.py @@ -43,6 +43,7 @@ from .base_quantizer import QuantizationResult from .config import UnifiedConfig from .sensitivity import SequentialSensitivityAnalyzer +from .sensitivity.scorer import SensitivityMetric class UnifiedQuantizer: @@ -78,6 +79,8 @@ def __init__(self, config: UnifiedConfig): self.safetensors_files = [] self.weight_index = {} # Maps weight name to file path self.timing = {} + self._resolved_algorithm: str | None = None + self._calibration_loader_cache = None def _log(self, message: str) -> None: """Print log message with timestamp.""" @@ -111,6 +114,61 @@ def _get_template(self) -> LLMTemplate | None: self._log(f"Warning: No template found for model type '{self.model_type}'") return self.template + def _resolve_algorithm(self) -> str | None: + """Resolve the requested quantization algorithm to pass to Quark. + + Returns: + None for RTN (Quark default), algorithm name for AWQ/GPTQ + """ + if self._resolved_algorithm is not None: + return self._resolved_algorithm + + algorithm = (self.config.algorithm or "rtn").lower() + if algorithm == "rtn": + self._resolved_algorithm = None + return None + + # AWQ and GPTQ are now supported via Quark's LLMTemplate + if algorithm in {"awq", "gptq"}: + self._resolved_algorithm = algorithm + return algorithm + + raise ValueError(f"Unsupported quantization algorithm: '{self.config.algorithm}'") + + def _resolve_sensitivity_metric(self) -> SensitivityMetric: + """Resolve configured sensitivity metric string to enum value.""" + metric_name = (self.config.sensitivity_metric or "relative").lower() + mapping = { + "relative": SensitivityMetric.RELATIVE_NORM, + "mse": SensitivityMetric.MSE, + "mae": SensitivityMetric.MAE, + "cosine": SensitivityMetric.COSINE, + "kl": SensitivityMetric.KL_DIVERGENCE, + } + try: + return mapping[metric_name] + except KeyError as exc: + valid = ", ".join(mapping.keys()) + raise ValueError( + f"Invalid sensitivity_metric '{self.config.sensitivity_metric}'. 
" + f"Must be one of: {valid}" + ) from exc + + def _get_calibration_dataloader(self): + """Load and cache the calibration dataloader.""" + if self._calibration_loader_cache is None: + if self.tokenizer is None: + raise RuntimeError("Tokenizer must be initialized before loading calibration data") + self._calibration_loader_cache = get_calib_dataloader( + dataset_name_or_path=self.config.calibration_data, + tokenizer=self.tokenizer, + batch_size=self.config.batch_size, + num_calib_data=self.config.num_calib_samples, + seqlen=self.config.seq_len, + device=self.config.device, + ) + return self._calibration_loader_cache + def _setup(self) -> None: """Load config, tokenizer, and build weight index.""" start_time = time.time() @@ -346,6 +404,7 @@ def _run_sequential_sensitivity_analysis(self) -> list[str]: analyzer = SequentialSensitivityAnalyzer( config=self.config, + metric=self._resolve_sensitivity_metric(), cache_on_gpu=self.config.sensitivity_cache_on_gpu, template=self.template, ) @@ -416,6 +475,7 @@ def _run_iterative_sensitivity_analysis(self) -> list[str]: # Create analyzer with current exclusion list analyzer = SequentialSensitivityAnalyzer( config=self.config, + metric=self._resolve_sensitivity_metric(), cache_on_gpu=cache_on_gpu, initial_exclude_layers=all_excluded, template=self.template, @@ -534,10 +594,12 @@ def _create_quant_config(self, exclude_layers: list[str]) -> QConfig: quant_scheme = self._get_quant_scheme() self._log(f"Using quantization scheme: {quant_scheme}") - # Determine algorithm (None for RTN, "awq"/"gptq" for advanced) - algorithm = self.config.algorithm if self.config.algorithm != "rtn" else None + # Determine algorithm (None for RTN, raise for unsupported) + algorithm = self._resolve_algorithm() if algorithm: self._log(f"Using quantization algorithm: {algorithm}") + else: + self._log("Using quantization algorithm: rtn") # Create base quant config if self.template: @@ -1221,14 +1283,7 @@ def _run_full_gpu_quantization(self) -> QuantizationResult: # Get calibration data self._log("Loading calibration data...") - calib_loader = get_calib_dataloader( - dataset_name_or_path=self.config.calibration_data, - tokenizer=self.tokenizer, - batch_size=self.config.batch_size, - num_calib_data=self.config.num_calib_samples, - seqlen=self.config.seq_len, - device=self.config.device, - ) + calib_loader = self._get_calibration_dataloader() # Quantize self._log("Quantizing model...") @@ -1385,6 +1440,9 @@ def run(self) -> QuantizationResult: Returns: QuantizationResult with details of the quantization """ + # Resolve algorithm early to provide immediate feedback + self._resolve_algorithm() + # Use file-to-file for MXFP precisions (produces vLLM-compatible packed uint8) # This path skips auto-strategy detection since it doesn't need the model in memory if self.config.precision.startswith("mxfp"): diff --git a/tests/test_unified_quantizer.py b/tests/test_unified_quantizer.py index e2b93ac..33f450b 100644 --- a/tests/test_unified_quantizer.py +++ b/tests/test_unified_quantizer.py @@ -21,6 +21,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent / "src")) from quanto import UnifiedConfig, UnifiedQuantizer +from quanto.core.sensitivity.scorer import SensitivityMetric class TestUnifiedConfig: @@ -39,6 +40,7 @@ def test_default_config(self): assert config.calibration_data == "pileval" assert config.num_calib_samples == 128 assert config.device == "cuda" + assert config.sensitivity_metric == "relative" def test_custom_config(self): """Test custom configuration values.""" @@ -99,6 
+101,27 @@ def test_to_dict(self): assert d["output_dir"] == "/tmp/output" assert d["precision"] == "int4" assert d["memory_strategy"] == "auto" + assert d["sensitivity_metric"] == "relative" + + def test_invalid_sensitivity_metric(self): + """Test that invalid sensitivity metric raises error.""" + with pytest.raises(ValueError, match="Invalid sensitivity_metric"): + UnifiedConfig( + model_path="/tmp/model", + output_dir="/tmp/output", + sensitivity_metric="invalid", + ) + + def test_sensitivity_metric_mapping(self): + """Test sensitivity metric mapping to enum values.""" + config = UnifiedConfig( + model_path="/tmp/model", + output_dir="/tmp/output", + sensitivity_metric="mse", + ) + quantizer = UnifiedQuantizer(config) + + assert quantizer._resolve_sensitivity_metric() == SensitivityMetric.MSE def test_from_dict(self): """Test config deserialization.""" @@ -112,6 +135,58 @@ def test_from_dict(self): assert config.model_path == "/tmp/model" assert config.precision == "fp8" + def test_algorithm_precision_validation(self): + """Test that algorithm-precision combinations are validated.""" + # AWQ only supports INT4 precisions + with pytest.raises(ValueError, match="Precision.*not supported.*algorithm"): + UnifiedConfig( + model_path="/tmp/model", + output_dir="/tmp/output", + algorithm="awq", + precision="mxfp4", # MXFP4 not supported for AWQ + ) + + # GPTQ only supports INT4 + with pytest.raises(ValueError, match="Precision.*not supported.*algorithm"): + UnifiedConfig( + model_path="/tmp/model", + output_dir="/tmp/output", + algorithm="gptq", + precision="int8", # INT8 not supported for GPTQ + ) + + def test_algorithm_precision_supported(self): + """Test that valid algorithm-precision combinations are accepted.""" + # AWQ + INT4 should be valid + config = UnifiedConfig( + model_path="/tmp/model", + output_dir="/tmp/output", + algorithm="awq", + precision="int4", + ) + assert config.algorithm == "awq" + assert config.precision == "int4" + + # GPTQ + INT4 should be valid + config = UnifiedConfig( + model_path="/tmp/model", + output_dir="/tmp/output", + algorithm="gptq", + precision="int4", + ) + assert config.algorithm == "gptq" + assert config.precision == "int4" + + # RTN + MXFP4 should be valid + config = UnifiedConfig( + model_path="/tmp/model", + output_dir="/tmp/output", + algorithm="rtn", + precision="mxfp4", + ) + assert config.algorithm == "rtn" + assert config.precision == "mxfp4" + class TestAutoDetectStrategy: """Tests for auto-detect strategy logic."""