From 3ed57b504acd571a3d46cf2390022dfb41bc3fe6 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:17:29 +0000 Subject: [PATCH 1/6] Initial plan From 12f179d772f2a365920792af6562423502f59b46 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:31:07 +0000 Subject: [PATCH 2/6] Add activation-weighted GPTQ-Pro scale search Co-authored-by: groxaxo <76023196+groxaxo@users.noreply.github.com> --- README.md | 3 +- gptqmodel/looper/gptq_processor.py | 5 ++++ gptqmodel/quantization/config.py | 12 ++++++++ gptqmodel/quantization/gptq.py | 45 +++++++++++++++++++++++++---- gptqmodel/quantization/quantizer.py | 36 ++++++++++++++++++++++- tests/qcfg/test_gptq_pro.py | 44 ++++++++++++++++++++++++++++ tests/test_serialization.py | 3 ++ 7 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 tests/qcfg/test_gptq_pro.py diff --git a/README.md b/README.md index 9cfce4978..805eb7239 100644 --- a/README.md +++ b/README.md @@ -433,6 +433,7 @@ If your goal is "better GPTQ quality without touching the inference kernels", th * GAR / `act_group_aware=True` to improve activation ordering without inference-time penalties. * MSE-based scale search (`mse > 0`) to reduce outlier-driven grid distortion. +* Activation-weighted MSE search (`activation_weighted_mse=True`) to bias scale selection toward Hessian-salient channels using an offline-only importance signal. * Adaptive damping for badly conditioned Hessian blocks. * Optional GPTAQ experimentation, with the same GPTQ export format, when you want to test more aggressive offline correction. @@ -444,7 +445,7 @@ from gptqmodel.quantization import QuantizeConfig quant_config = QuantizeConfig.gptq_pro() ``` -`QuantizeConfig.gptq_pro()` is intentionally conservative: it keeps `quant_method=METHOD.GPTQ` and `format=FORMAT.GPTQ`, so inference speed comes from the same kernels as regular GPTQ. It does **not** claim that GPTQModel currently implements AWQ-style layer fusion or AutoRound-style learned rounding inside the GPTQ inner loop; those are separate algorithms and should be treated as separate offline quantizers. +`QuantizeConfig.gptq_pro()` is intentionally conservative: it keeps `quant_method=METHOD.GPTQ` and `format=FORMAT.GPTQ`, so inference speed comes from the same kernels as regular GPTQ. Today that preset combines existing GPTQModel features with an AutoRound-inspired offline importance weighting pass during MSE scale search, but it does **not** claim that GPTQModel currently implements AWQ-style layer fusion or AutoRound-style learned rounding inside the GPTQ inner loop; those are separate algorithms and should be treated as separate offline quantizers. ### Experimental Features diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index e9b342a54..f393ff3bc 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -75,6 +75,11 @@ def preprocess(self, module: NamedModule, failsafe=None, **kwargs): qcfg_clone.bits = self.qcfg.dynamic_get(module.full_name, "bits", qcfg_clone.bits) qcfg_clone.sym = self.qcfg.dynamic_get(module.full_name, "sym", qcfg_clone.sym) qcfg_clone.mse = self.qcfg.dynamic_get(module.full_name, "mse", qcfg_clone.mse) + qcfg_clone.activation_weighted_mse = self.qcfg.dynamic_get( + module.full_name, + "activation_weighted_mse", + qcfg_clone.activation_weighted_mse, + ) qcfg_clone.group_size = self.qcfg.dynamic_get(module.full_name, "group_size", qcfg_clone.group_size) desc_act_override = self.qcfg.dynamic_get(module.full_name, "desc_act", None) diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index c7859cfae..368c619ec 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -678,6 +678,10 @@ class QuantizeConfig: # mean square error calculation: may reduce error loss for some models mse: float = field(default=0.0) + # GPTQ only + # use Hessian-diagonal activation importance to weight MSE grid search offline + activation_weighted_mse: bool = field(default=False) + # properties that do not directly contributes to quantization or quant inference should be placed in meta # i.e. quantizer tool (producer) + version, timestamp, entity who made the quant, etc meta: Optional[Dict] = field(default=None) @@ -1147,6 +1151,7 @@ def gptq_pro( "desc_act": False, "act_group_aware": True, "mse": mse, + "activation_weighted_mse": True, "damp_percent": damp_percent, "damp_auto_increment": damp_auto_increment, "failsafe": failsafe, @@ -1270,6 +1275,12 @@ def from_quant_config(cls, quantize_cfg, format: str = None): normalized["gptaq"] = meta_payload.get("gptaq") if "mse" not in normalized and isinstance(meta_payload, dict) and "mse" in meta_payload: normalized["mse"] = meta_payload.get("mse") + if ( + "activation_weighted_mse" not in normalized + and isinstance(meta_payload, dict) + and "activation_weighted_mse" in meta_payload + ): + normalized["activation_weighted_mse"] = meta_payload.get("activation_weighted_mse") if "act_group_aware" not in normalized and isinstance(meta_payload, dict) and "act_group_aware" in meta_payload: normalized["act_group_aware"] = meta_payload.get("act_group_aware") if ( @@ -1385,6 +1396,7 @@ def to_dict(self): meta_payload["offload_to_disk_path"] = self.offload_to_disk_path meta_payload["pack_impl"] = self.pack_impl meta_payload["mse"] = self.mse + meta_payload["activation_weighted_mse"] = self.activation_weighted_mse meta_payload["mock_quantization"] = self.mock_quantization meta_payload["act_group_aware"] = self.act_group_aware meta_payload["gc_mode"] = self.gc_mode diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 08072de93..fa40c03b7 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -936,14 +936,20 @@ def quantize( W = self.module_copy.to(device=self.H.device) del self.module_copy - self.quantizer.find_params(W, weight=True) - # H = self.H.to(device=self.H.device) + activation_importance = None if use_hessian: dead = torch.diag(self.H) == 0 self.H[dead, dead] = 1 W[:, dead] = 0 + if self.qcfg.activation_weighted_mse: + activation_importance = torch.diag(self.H).clamp_min(0).to(device=W.device, dtype=W.dtype) + importance_mean = activation_importance.mean() + if torch.isfinite(importance_mean) and importance_mean > 0: + activation_importance = activation_importance / importance_mean + else: + activation_importance = None # g_idx = [] scale = [] @@ -956,7 +962,14 @@ def quantize( groups = [] for i in range(0, self.columns, self.qcfg.group_size): quantizer = copy.deepcopy(self.quantizer) - quantizer.find_params(W[:, i: (i + self.qcfg.group_size)], weight=True) + group_importance = None + if activation_importance is not None: + group_importance = activation_importance[i: (i + self.qcfg.group_size)] + quantizer.find_params( + W[:, i: (i + self.qcfg.group_size)], + weight=True, + importance=group_importance, + ) scale.append(quantizer.scale) zero.append(quantizer.zero) @@ -966,6 +979,8 @@ def quantize( perm = torch.argsort(torch.diag(self.H), descending=True) W = W[:, perm] self.H = self.H[perm][:, perm] + if activation_importance is not None: + activation_importance = activation_importance[perm] invperm = torch.argsort(perm) elif self.qcfg.act_group_aware and use_hessian: @@ -982,6 +997,10 @@ def quantize( final_perm = compose_final_perm(local_perms, global_perm, self.qcfg.group_size) W = W[:, final_perm] self.H = self.H[final_perm][:, final_perm] + if activation_importance is not None: + activation_importance = activation_importance[final_perm] + + self.quantizer.find_params(W, weight=True, importance=activation_importance) Losses = torch.zeros_like(W) Q = torch.zeros_like(W) @@ -1008,7 +1027,14 @@ def quantize( for group_start in group_start_cols: group_end = min(group_start + self.qcfg.group_size, self.columns) if group_start < group_end: - self.quantizer.find_params(W[:, group_start:group_end], weight=True) + group_importance = None + if activation_importance is not None: + group_importance = activation_importance[group_start:group_end] + self.quantizer.find_params( + W[:, group_start:group_end], + weight=True, + importance=group_importance, + ) scale.append(self.quantizer.scale) zero.append(self.quantizer.zero) now_idx += 1 @@ -1117,7 +1143,16 @@ def quantize( if self.qcfg.group_size != -1: if not self.qcfg.static_groups: if (i1 + i) % self.qcfg.group_size == 0: - self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + self.qcfg.group_size)], weight=True) + group_start = i1 + i + group_end = group_start + self.qcfg.group_size + group_importance = None + if activation_importance is not None: + group_importance = activation_importance[group_start:group_end] + self.quantizer.find_params( + W[:, group_start:group_end], + weight=True, + importance=group_importance, + ) if ((i1 + i) // self.qcfg.group_size) - now_idx == -1: scale.append(self.quantizer.scale) diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py index 7614fcc6d..edaa847de 100644 --- a/gptqmodel/quantization/quantizer.py +++ b/gptqmodel/quantization/quantizer.py @@ -66,7 +66,7 @@ def configure( if trits: self.maxq = torch.tensor(-1) - def find_params(self, x, weight=False): + def find_params(self, x, weight=False, importance: torch.Tensor = None): dev = x.device self.maxq = self.maxq.to(dev) @@ -113,6 +113,38 @@ def find_params(self, x, weight=False): self.zero = torch.round(-xmin / self.scale) if self.qcfg.mse > 0.0: + importance_weights = None + if self.qcfg.activation_weighted_mse and importance is not None: + importance_weights = torch.nan_to_num( + importance.to(device=dev, dtype=x.dtype), + nan=0.0, + posinf=0.0, + neginf=0.0, + ).clamp_min_(0) + if importance_weights.ndim == 1: + importance_weights = importance_weights.unsqueeze(0) + if importance_weights.shape[-1] != x.shape[1]: + raise ValueError( + "Quantizer.find_params(): `importance` must match the flattened weight column count." + ) + if importance_weights.shape[0] == 1 and x.shape[0] != 1: + importance_weights = importance_weights.expand(x.shape[0], -1) + elif importance_weights.shape[0] != x.shape[0]: + raise ValueError( + "Quantizer.find_params(): `importance` must have either one row or match the per-channel row count." + ) + + importance_mean = importance_weights.mean(dim=1, keepdim=True) + valid = torch.isfinite(importance_mean) & (importance_mean > 0) + if torch.any(valid): + importance_weights = torch.where( + valid, + importance_weights / importance_mean.clamp_min(torch.finfo(x.dtype).eps), + importance_weights, + ) + else: + importance_weights = None + best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid @@ -128,6 +160,8 @@ def find_params(self, x, weight=False): q -= x q.abs_() q.pow_(self.qcfg.mse) + if importance_weights is not None: + q.mul_(importance_weights) err = torch.sum(q, 1) tmp = err < best if torch.any(tmp): diff --git a/tests/qcfg/test_gptq_pro.py b/tests/qcfg/test_gptq_pro.py new file mode 100644 index 000000000..fadf7e8ee --- /dev/null +++ b/tests/qcfg/test_gptq_pro.py @@ -0,0 +1,44 @@ +import torch + +from gptqmodel.quantization import QuantizeConfig +from gptqmodel.quantization.quantizer import Quantizer + + +def _weighted_sq_error(quantizer: Quantizer, weights: torch.Tensor, importance: torch.Tensor) -> torch.Tensor: + dequant = quantizer.quantize(weights) + return ((dequant - weights).pow(2) * importance.view(1, -1)).sum() + + +def test_gptq_pro_enables_activation_weighted_mse(): + cfg = QuantizeConfig.gptq_pro() + + assert cfg.activation_weighted_mse is True + assert cfg.act_group_aware is True + assert cfg.desc_act is False + + +def test_activation_weighted_mse_prioritizes_salient_columns(): + weights = torch.tensor([[0.1, 0.2, 0.3, 1.2]], dtype=torch.float32) + importance = torch.tensor([1.0, 1.0, 8.0, 8.0], dtype=torch.float32) + + baseline = Quantizer( + QuantizeConfig(bits=2, sym=False, mse=2.0, act_group_aware=False, desc_act=False), + ) + baseline.configure(perchannel=True) + baseline.find_params(weights, weight=True) + + weighted = Quantizer( + QuantizeConfig( + bits=2, + sym=False, + mse=2.0, + activation_weighted_mse=True, + act_group_aware=False, + desc_act=False, + ), + ) + weighted.configure(perchannel=True) + weighted.find_params(weights, weight=True, importance=importance) + + assert not torch.allclose(weighted.scale, baseline.scale) + assert _weighted_sq_error(weighted, weights, importance) < _weighted_sq_error(baseline, weights, importance) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 6cf6331a9..c2eaeacbc 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -58,6 +58,7 @@ def test_quantize_config_meta_only_fields_serialization(self): offload_to_disk_path="./offload-test", pack_impl="gpu", mse=0.125, + activation_weighted_mse=True, mock_quantization=True, hessian=HessianConfig( chunk_size=256, @@ -78,6 +79,7 @@ def test_quantize_config_meta_only_fields_serialization(self): "offload_to_disk_path", "pack_impl", "mse", + "activation_weighted_mse", "mock_quantization", "act_group_aware", "hessian", @@ -93,6 +95,7 @@ def test_quantize_config_meta_only_fields_serialization(self): self.assertEqual(meta["offload_to_disk_path"], cfg.offload_to_disk_path) self.assertEqual(meta["pack_impl"], cfg.pack_impl) self.assertEqual(meta["mse"], cfg.mse) + self.assertEqual(meta["activation_weighted_mse"], cfg.activation_weighted_mse) self.assertEqual(meta["mock_quantization"], cfg.mock_quantization) self.assertEqual(meta["act_group_aware"], cfg.act_group_aware) self.assertEqual(meta["hessian"]["chunk_size"], cfg.hessian.chunk_size) From 7ab9f5dc1d8279d9311812943e082966cab1a543 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:32:14 +0000 Subject: [PATCH 3/6] Polish activation-weighted GPTQ-Pro changes Co-authored-by: groxaxo <76023196+groxaxo@users.noreply.github.com> --- gptqmodel/quantization/quantizer.py | 9 +++------ tests/qcfg/test_gptq_pro.py | 16 ++++++++++++++-- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py index edaa847de..4930cafc5 100644 --- a/gptqmodel/quantization/quantizer.py +++ b/gptqmodel/quantization/quantizer.py @@ -125,7 +125,7 @@ def find_params(self, x, weight=False, importance: torch.Tensor = None): importance_weights = importance_weights.unsqueeze(0) if importance_weights.shape[-1] != x.shape[1]: raise ValueError( - "Quantizer.find_params(): `importance` must match the flattened weight column count." + "Quantizer.find_params(): `importance` must match the column count of the weight block being quantized." ) if importance_weights.shape[0] == 1 and x.shape[0] != 1: importance_weights = importance_weights.expand(x.shape[0], -1) @@ -137,11 +137,8 @@ def find_params(self, x, weight=False, importance: torch.Tensor = None): importance_mean = importance_weights.mean(dim=1, keepdim=True) valid = torch.isfinite(importance_mean) & (importance_mean > 0) if torch.any(valid): - importance_weights = torch.where( - valid, - importance_weights / importance_mean.clamp_min(torch.finfo(x.dtype).eps), - importance_weights, - ) + normalized_weights = importance_weights / importance_mean.clamp_min(torch.finfo(x.dtype).eps) + importance_weights = torch.where(valid, normalized_weights, torch.ones_like(importance_weights)) else: importance_weights = None diff --git a/tests/qcfg/test_gptq_pro.py b/tests/qcfg/test_gptq_pro.py index fadf7e8ee..8279f025d 100644 --- a/tests/qcfg/test_gptq_pro.py +++ b/tests/qcfg/test_gptq_pro.py @@ -4,7 +4,11 @@ from gptqmodel.quantization.quantizer import Quantizer -def _weighted_sq_error(quantizer: Quantizer, weights: torch.Tensor, importance: torch.Tensor) -> torch.Tensor: +def _calculate_weighted_squared_error( + quantizer: Quantizer, + weights: torch.Tensor, + importance: torch.Tensor, +) -> torch.Tensor: dequant = quantizer.quantize(weights) return ((dequant - weights).pow(2) * importance.view(1, -1)).sum() @@ -41,4 +45,12 @@ def test_activation_weighted_mse_prioritizes_salient_columns(): weighted.find_params(weights, weight=True, importance=importance) assert not torch.allclose(weighted.scale, baseline.scale) - assert _weighted_sq_error(weighted, weights, importance) < _weighted_sq_error(baseline, weights, importance) + assert _calculate_weighted_squared_error( + weighted, + weights, + importance, + ) < _calculate_weighted_squared_error( + baseline, + weights, + importance, + ) From 3adc65393f60e09b3bbc6056ad77734499b5a944 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:33:48 +0000 Subject: [PATCH 4/6] Refine GPTQ-Pro importance weighting diagnostics Co-authored-by: groxaxo <76023196+groxaxo@users.noreply.github.com> --- gptqmodel/quantization/quantizer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py index 4930cafc5..737723c1e 100644 --- a/gptqmodel/quantization/quantizer.py +++ b/gptqmodel/quantization/quantizer.py @@ -125,13 +125,15 @@ def find_params(self, x, weight=False, importance: torch.Tensor = None): importance_weights = importance_weights.unsqueeze(0) if importance_weights.shape[-1] != x.shape[1]: raise ValueError( - "Quantizer.find_params(): `importance` must match the column count of the weight block being quantized." + "Quantizer.find_params(): importance parameter shape mismatch. " + f"Expected columns: {x.shape[1]}, got: {importance_weights.shape[-1]}." ) if importance_weights.shape[0] == 1 and x.shape[0] != 1: importance_weights = importance_weights.expand(x.shape[0], -1) elif importance_weights.shape[0] != x.shape[0]: raise ValueError( - "Quantizer.find_params(): `importance` must have either one row or match the per-channel row count." + "Quantizer.find_params(): importance parameter row count mismatch. " + f"Expected 1 or {x.shape[0]} rows, got: {importance_weights.shape[0]}." ) importance_mean = importance_weights.mean(dim=1, keepdim=True) From 60250447c12430af2d10675c8e5e1c721ae51f1d Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 10 Mar 2026 22:37:36 +0000 Subject: [PATCH 5/6] Finalize GPTQ-Pro importance weighting updates Co-authored-by: groxaxo <76023196+groxaxo@users.noreply.github.com> --- gptqmodel/quantization/quantizer.py | 2 +- tests/qcfg/test_gptq_pro.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py index 737723c1e..c661c8c83 100644 --- a/gptqmodel/quantization/quantizer.py +++ b/gptqmodel/quantization/quantizer.py @@ -139,7 +139,7 @@ def find_params(self, x, weight=False, importance: torch.Tensor = None): importance_mean = importance_weights.mean(dim=1, keepdim=True) valid = torch.isfinite(importance_mean) & (importance_mean > 0) if torch.any(valid): - normalized_weights = importance_weights / importance_mean.clamp_min(torch.finfo(x.dtype).eps) + normalized_weights = importance_weights / importance_mean.clamp_min(1e-8) importance_weights = torch.where(valid, normalized_weights, torch.ones_like(importance_weights)) else: importance_weights = None diff --git a/tests/qcfg/test_gptq_pro.py b/tests/qcfg/test_gptq_pro.py index 8279f025d..56adbb44e 100644 --- a/tests/qcfg/test_gptq_pro.py +++ b/tests/qcfg/test_gptq_pro.py @@ -22,18 +22,18 @@ def test_gptq_pro_enables_activation_weighted_mse(): def test_activation_weighted_mse_prioritizes_salient_columns(): - weights = torch.tensor([[0.1, 0.2, 0.3, 1.2]], dtype=torch.float32) + weights = torch.tensor([[0.1, 0.45, 0.8, 1.2]], dtype=torch.float32) importance = torch.tensor([1.0, 1.0, 8.0, 8.0], dtype=torch.float32) baseline = Quantizer( - QuantizeConfig(bits=2, sym=False, mse=2.0, act_group_aware=False, desc_act=False), + QuantizeConfig(bits=4, sym=False, mse=2.0, act_group_aware=False, desc_act=False), ) baseline.configure(perchannel=True) baseline.find_params(weights, weight=True) weighted = Quantizer( QuantizeConfig( - bits=2, + bits=4, sym=False, mse=2.0, activation_weighted_mse=True, From 8e2ee0519cc318c5a75bb8e2f3c7e72b46fb53b5 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Mon, 20 Apr 2026 20:45:06 +0000 Subject: [PATCH 6/6] Resolve PR merge conflicts with integration/review-2026-04 Agent-Logs-Url: https://github.com/groxaxo/GPTQ-Pro/sessions/7bfab716-8438-49a4-a03d-47a940db5041 Co-authored-by: groxaxo <76023196+groxaxo@users.noreply.github.com> --- .dockerignore | 11 + .github/PULL_REQUEST_TEMPLATE.md | 38 + .github/scripts/allocate_gpu.py | 145 + .github/scripts/blacklist.yaml | 3 + .github/scripts/ci_loop_versions.py | 37 + .github/scripts/deps.yaml | 73 +- .github/scripts/install_deps.py | 43 +- .github/scripts/list_test_files.py | 57 +- .github/scripts/parse_test_config.py | 87 + .github/scripts/release_gpu.py | 45 + .github/scripts/run_tests.py | 200 + .github/scripts/test.yaml | 16 + .github/scripts/uninstall_deps.py | 68 + .github/workflows/compatibility.yml | 54 + .github/workflows/release.yml | 457 +- .github/workflows/unit_tests.yml | 993 +- .gitignore | 10 + BRANCH_CLEANUP.md | 43 + CREDITS.md | 2 +- Dockerfile | 52 + MANIFEST.in | 4 + Project.md | 9396 +++++++++++++++++ README.md | 641 +- chat/README.md | 26 +- chat/chat.py | 55 +- chat/run.sh | 61 +- docs/eora/README.md | 97 + .../eora_calibration_data_construction.py | 68 + docs/eora/eora_generation.py | 138 + docs/eora/eora_load_and_inference.py | 67 + docs/eora/evaluation.py | 103 + docs/eora/post_quant_eora_generation.py | 81 + docs/quantization_protocol.md | 1894 ++++ docs/qwen35_vllm_comparison.md | 209 + docs/qwen35_vllm_launch.md | 291 + docs/torch_fused_int4_transformations.md | 6 +- environment.yml | 11 + examples/benchmark/perplexity.py | 21 +- format/format.sh | 4 +- gptqmodel/__init__.py | 233 +- gptqmodel/_banner.py | 78 + gptqmodel/adapter/adapter.py | 60 +- gptqmodel/adapter/peft.py | 8 +- gptqmodel/adapter/remote.py | 11 +- gptqmodel/eora/eora.py | 5 +- gptqmodel/exllamav3/CREDITS.md | 13 + gptqmodel/exllamav3/__init__.py | 0 gptqmodel/exllamav3/ext.py | 338 + gptqmodel/exllamav3/modules/__init__.py | 0 gptqmodel/exllamav3/modules/quant/__init__.py | 0 gptqmodel/exllamav3/modules/quant/exl3.py | 265 + .../modules/quant/exl3_lib/__init__.py | 0 .../modules/quant/exl3_lib/quantize.py | 1070 ++ gptqmodel/exllamav3/util/__init__.py | 3 + gptqmodel/exllamav3/util/arch_list.py | 40 + gptqmodel/exllamav3/util/hadamard.py | 160 + .../util/hadamard_data/hadamard_1.txt | 1 + .../util/hadamard_data/hadamard_100.txt | 100 + .../util/hadamard_data/hadamard_116.txt | 116 + .../util/hadamard_data/hadamard_156.txt | 156 + .../util/hadamard_data/hadamard_172.txt | 172 + .../util/hadamard_data/hadamard_188.txt | 188 + .../util/hadamard_data/hadamard_236.txt | 236 + .../util/hadamard_data/hadamard_244.txt | 244 + .../util/hadamard_data/hadamard_428.txt | 428 + .../util/hadamard_data/hadamard_52.txt | 52 + .../util/hadamard_data/hadamard_92.txt | 92 + gptqmodel/exllamav3/util/memory.py | 228 + gptqmodel/exllamav3/util/misc.py | 137 + gptqmodel/exllamav3/util/progress.py | 45 + gptqmodel/exllamav3/util/tensor.py | 210 + gptqmodel/extension.py | 292 + .../hf_minimax_m2/modeling_minimax_m2.py | 8 +- gptqmodel/looper/awq_processor.py | 316 +- gptqmodel/looper/dequantize_processor.py | 35 +- gptqmodel/looper/eora_processor.py | 32 +- gptqmodel/looper/exllamav3_processor.py | 393 + gptqmodel/looper/forward_executor.py | 584 + gptqmodel/looper/gptq_processor.py | 187 +- gptqmodel/looper/input_cache.py | 4 + gptqmodel/looper/linear_mode.py | 2 + gptqmodel/looper/loop_processor.py | 206 +- gptqmodel/looper/module_looper.py | 884 +- gptqmodel/looper/module_preprocessor.py | 159 + gptqmodel/looper/named_module.py | 38 + gptqmodel/looper/native_processor.py | 40 +- gptqmodel/looper/paroquant_processor.py | 2638 +++++ gptqmodel/looper/qqq_processor.py | 52 +- gptqmodel/looper/stage_inputs_capture.py | 102 +- gptqmodel/looper/stage_layer.py | 698 +- gptqmodel/looper/stage_subset.py | 1321 ++- gptqmodel/looper/weight_only_looper.py | 263 + gptqmodel/looper/weight_only_processor.py | 266 + gptqmodel/models/_const.py | 4 +- gptqmodel/models/auto.py | 549 +- gptqmodel/models/base.py | 1620 ++- gptqmodel/models/definitions/__init__.py | 56 +- gptqmodel/models/definitions/baichuan.py | 78 + .../models/definitions/base_qwen2_5_omni.py | 63 +- gptqmodel/models/definitions/base_qwen2_vl.py | 71 +- gptqmodel/models/definitions/base_qwen3_vl.py | 79 +- gptqmodel/models/definitions/brumby.py | 13 + gptqmodel/models/definitions/ernie4_5.py | 169 +- gptqmodel/models/definitions/gemma4.py | 248 + gptqmodel/models/definitions/glm.py | 2 +- gptqmodel/models/definitions/glm4_moe.py | 2 +- gptqmodel/models/definitions/glm4_moe_lite.py | 39 + gptqmodel/models/definitions/glm4v.py | 34 - gptqmodel/models/definitions/glm_moe_dsa.py | 55 + gptqmodel/models/definitions/gpt_oss.py | 132 +- .../models/definitions/granitemoehybrid.py | 38 + gptqmodel/models/definitions/llama4.py | 50 +- gptqmodel/models/definitions/minicpm_o.py | 417 + gptqmodel/models/definitions/minicpm_v.py | 193 + gptqmodel/models/definitions/mixtral.py | 15 +- gptqmodel/models/definitions/olmoe.py | 4 +- gptqmodel/models/definitions/ovis.py | 37 +- gptqmodel/models/definitions/phi3.py | 6 +- gptqmodel/models/definitions/qwen2_moe.py | 15 +- gptqmodel/models/definitions/qwen3_5.py | 15 + gptqmodel/models/definitions/qwen3_5_moe.py | 18 +- gptqmodel/models/definitions/qwen3_moe.py | 2 +- gptqmodel/models/definitions/qwen3_next.py | 10 +- .../models/definitions/qwen3_omni_moe.py | 4 +- gptqmodel/models/loader.py | 946 +- gptqmodel/models/writer.py | 544 +- gptqmodel/nn_modules/converter.py | 120 +- gptqmodel/nn_modules/exllamav3.py | 200 + gptqmodel/nn_modules/exllamav3_torch.py | 400 + gptqmodel/nn_modules/hooked_linear.py | 12 +- gptqmodel/nn_modules/qlinear/__init__.py | 632 +- gptqmodel/nn_modules/qlinear/bitblas.py | 431 +- gptqmodel/nn_modules/qlinear/bitblas_awq.py | 170 + gptqmodel/nn_modules/qlinear/bitsandbytes.py | 411 + gptqmodel/nn_modules/qlinear/exllamav2.py | 67 +- gptqmodel/nn_modules/qlinear/exllamav2_awq.py | 68 +- gptqmodel/nn_modules/qlinear/fp4.py | 63 + gptqmodel/nn_modules/qlinear/fp8.py | 462 + gptqmodel/nn_modules/qlinear/gemm_awq.py | 79 +- .../nn_modules/qlinear/gemm_awq_triton.py | 21 +- gptqmodel/nn_modules/qlinear/gemv_awq.py | 19 +- gptqmodel/nn_modules/qlinear/gemv_fast_awq.py | 70 +- gptqmodel/nn_modules/qlinear/gguf.py | 1194 +++ gptqmodel/nn_modules/qlinear/gguf_cpp.py | 773 ++ gptqmodel/nn_modules/qlinear/gguf_triton.py | 1646 +++ gptqmodel/nn_modules/qlinear/gptq_pro.py | 193 + gptqmodel/nn_modules/qlinear/lookahead.py | 12 +- gptqmodel/nn_modules/qlinear/machete.py | 99 +- gptqmodel/nn_modules/qlinear/machete_awq.py | 89 +- gptqmodel/nn_modules/qlinear/marlin.py | 50 +- gptqmodel/nn_modules/qlinear/marlin_awq.py | 32 +- gptqmodel/nn_modules/qlinear/paroquant.py | 297 + .../nn_modules/qlinear/paroquant_triton.py | 206 + gptqmodel/nn_modules/qlinear/qqq.py | 25 +- gptqmodel/nn_modules/qlinear/torch.py | 114 +- .../nn_modules/qlinear/torch_aten_kernel.py | 298 + .../qlinear/torch_aten_kernel_awq.py | 208 + gptqmodel/nn_modules/qlinear/torch_awq.py | 102 +- gptqmodel/nn_modules/qlinear/torch_fused.py | 24 +- .../nn_modules/qlinear/torch_fused_awq.py | 64 +- gptqmodel/nn_modules/qlinear/torch_int8.py | 40 +- .../nn_modules/qlinear/torch_int8_awq.py | 24 +- gptqmodel/nn_modules/qlinear/tritonv2.py | 10 +- gptqmodel/quantization/__init__.py | 44 +- .../quantization/awq/modules/triton/gemm.py | 45 +- gptqmodel/quantization/awq/quantize/scale.py | 8 + gptqmodel/quantization/awq/utils/module.py | 9 - gptqmodel/quantization/config.py | 3884 +++++-- gptqmodel/quantization/dtype.py | 769 +- gptqmodel/quantization/fallback_smooth.py | 191 + gptqmodel/quantization/foem.py | 320 + gptqmodel/quantization/gar.py | 2 +- gptqmodel/quantization/gptq.py | 167 +- gptqmodel/quantization/paroquant/__init__.py | 2 + .../paroquant/modules/__init__.py | 2 + .../paroquant/modules/triton/__init__.py | 2 + .../paroquant/modules/triton/gemm.py | 359 + .../quantization/paroquant/optimization.py | 1910 ++++ gptqmodel/quantization/protocol.py | 527 + gptqmodel/quantization/qqq.py | 56 +- gptqmodel/quantization/quantizer.py | 21 +- gptqmodel/quantization/rotation/__init__.py | 4 + gptqmodel/quantization/rtn.py | 199 + gptqmodel/utils/__init__.py | 16 +- gptqmodel/utils/awq.py | 154 + gptqmodel/utils/backend.py | 209 +- gptqmodel/utils/bitblas.py | 61 +- gptqmodel/utils/calibration.py | 179 +- gptqmodel/utils/cpp.py | 972 +- gptqmodel/utils/data.py | 31 +- gptqmodel/utils/device_telemetry.py | 85 + gptqmodel/utils/env.py | 2 +- gptqmodel/utils/exllamav2.py | 236 + gptqmodel/utils/exllamav3.py | 100 + gptqmodel/utils/fallback.py | 120 + gptqmodel/utils/gptq_pro.py | 153 + gptqmodel/utils/hf.py | 1293 ++- gptqmodel/utils/hub.py | 35 + gptqmodel/utils/image.py | 8 +- gptqmodel/utils/importer.py | 49 +- gptqmodel/utils/internal_gguf.py | 715 ++ gptqmodel/utils/jit_compile_baselines.py | 37 + gptqmodel/utils/linalg_warmup.py | 31 +- gptqmodel/utils/logger.py | 88 +- gptqmodel/utils/looper_helpers.py | 23 +- gptqmodel/utils/machete.py | 566 +- gptqmodel/utils/marlin.py | 349 +- gptqmodel/utils/mlx.py | 9 +- gptqmodel/utils/mmlupro.py | 22 +- gptqmodel/utils/model.py | 353 +- gptqmodel/utils/model_dequant.py | 171 +- gptqmodel/utils/openai_server.py | 35 +- gptqmodel/utils/paroquant.py | 627 ++ gptqmodel/utils/paroquant_benchmark.py | 910 ++ gptqmodel/utils/perplexity.py | 315 +- gptqmodel/utils/python.py | 23 +- gptqmodel/utils/qqq.py | 125 + gptqmodel/utils/random_str.py | 5 +- gptqmodel/utils/stream.py | 24 +- gptqmodel/utils/structure.py | 1497 ++- gptqmodel/utils/threadx.py | 120 +- gptqmodel/utils/torch.py | 13 +- gptqmodel/utils/vram.py | 16 +- gptqmodel/version.py | 2 +- gptqmodel_ext/__init__.py | 2 +- gptqmodel_ext/awq/gemm_fast_cuda_entry.cu | 1 + gptqmodel_ext/awq/gemv_fast_cuda_entry.cu | 1 + gptqmodel_ext/awq/quantization/dequantize.cuh | 22 + gptqmodel_ext/awq/quantization/gemm_cuda.h | 5 +- .../awq/quantization/gemm_cuda_gen.cu | 768 +- gptqmodel_ext/awq/torch_bind.cpp | 80 + gptqmodel_ext/cutlass_extensions/__init__.py | 2 +- gptqmodel_ext/cutlass_extensions/common.hpp | 13 +- .../cutlass_extensions/cute_utils.cuh | 1 - .../broadcast_load_epilogue_array_c3x.hpp | 64 +- .../epilogue/broadcast_load_epilogue_c3x.hpp | 64 +- .../epilogue/scaled_mm_epilogues_c3x.hpp | 56 +- .../cutlass_extensions/torch_utils.hpp | 100 +- .../vllm_cutlass_library_extension.py | 3 + gptqmodel_ext/exllamav2/ext_awq.cpp | 96 +- gptqmodel_ext/exllamav2/ext_gptq.cpp | 96 +- gptqmodel_ext/exllamav3/bindings.cpp | 131 + gptqmodel_ext/exllamav3/hadamard.cpp | 112 + gptqmodel_ext/exllamav3/hadamard.h | 13 + gptqmodel_ext/exllamav3/hgemm.cu | 93 + gptqmodel_ext/exllamav3/hgemm.cuh | 10 + gptqmodel_ext/exllamav3/libtorch/linear.cpp | 41 + gptqmodel_ext/exllamav3/libtorch/linear.h | 51 + gptqmodel_ext/exllamav3/libtorch/linear_bc.h | 22 + gptqmodel_ext/exllamav3/ptx.cuh | 314 + gptqmodel_ext/exllamav3/quant/codebook.cuh | 146 + .../quant/comp_units/exl3_comp_unit_1.cu | 12 + .../quant/comp_units/exl3_comp_unit_1.cuh | 3 + .../quant/comp_units/exl3_comp_unit_2.cu | 12 + .../quant/comp_units/exl3_comp_unit_2.cuh | 3 + .../quant/comp_units/exl3_comp_unit_3.cu | 12 + .../quant/comp_units/exl3_comp_unit_3.cuh | 3 + .../quant/comp_units/exl3_comp_unit_4.cu | 12 + .../quant/comp_units/exl3_comp_unit_4.cuh | 3 + .../quant/comp_units/exl3_comp_unit_5.cu | 12 + .../quant/comp_units/exl3_comp_unit_5.cuh | 3 + .../quant/comp_units/exl3_comp_unit_6.cu | 12 + .../quant/comp_units/exl3_comp_unit_6.cuh | 3 + .../quant/comp_units/exl3_comp_unit_7.cu | 12 + .../quant/comp_units/exl3_comp_unit_7.cuh | 3 + .../quant/comp_units/exl3_comp_unit_8.cu | 12 + .../quant/comp_units/exl3_comp_unit_8.cuh | 3 + gptqmodel_ext/exllamav3/quant/exl3_devctx.cu | 86 + gptqmodel_ext/exllamav3/quant/exl3_devctx.cuh | 46 + gptqmodel_ext/exllamav3/quant/exl3_dq.cuh | 293 + gptqmodel_ext/exllamav3/quant/exl3_gemm.cu | 141 + gptqmodel_ext/exllamav3/quant/exl3_gemm.cuh | 17 + .../exllamav3/quant/exl3_gemm_inner.cuh | 610 ++ .../exllamav3/quant/exl3_gemm_kernel.cuh | 80 + .../exllamav3/quant/exl3_kernel_map.cu | 203 + .../exllamav3/quant/exl3_kernel_map.cuh | 80 + .../quant/exl3_kernel_map_packed.cuh | 600 ++ gptqmodel_ext/exllamav3/quant/hadamard.cu | 212 + gptqmodel_ext/exllamav3/quant/hadamard.cuh | 25 + .../exllamav3/quant/hadamard_inner.cuh | 205 + gptqmodel_ext/exllamav3/quant/pack.cu | 227 + gptqmodel_ext/exllamav3/quant/pack.cuh | 23 + gptqmodel_ext/exllamav3/quant/quantize.cu | 530 + gptqmodel_ext/exllamav3/quant/quantize.cuh | 34 + gptqmodel_ext/exllamav3/quant/reconstruct.cu | 131 + gptqmodel_ext/exllamav3/quant/reconstruct.cuh | 12 + gptqmodel_ext/exllamav3/quant/util.cu | 121 + gptqmodel_ext/exllamav3/quant/util.cuh | 10 + gptqmodel_ext/exllamav3/util.cuh | 139 + gptqmodel_ext/exllamav3/util.h | 125 + gptqmodel_ext/floatx_cpu.cpp | 2294 ++++ gptqmodel_ext/gptq_pro/gptq_pro_kernel.cu | 224 + gptqmodel_ext/gptq_pro/gptq_pro_kernel.cuh | 183 + gptqmodel_ext/gptq_pro/gptq_pro_torch.cpp | 85 + gptqmodel_ext/gptq_pro/gptq_pro_validate.cu | 595 ++ gptqmodel_ext/machete/generate.py | 27 +- gptqmodel_ext/machete/machete_mainloop.cuh | 1 + gptqmodel_ext/machete/machete_mm_kernel.cuh | 21 +- gptqmodel_ext/machete/machete_pytorch.cu | 11 +- gptqmodel_ext/marlin/generate_kernels.py | 134 +- gptqmodel_ext/marlin/gptq_marlin.cu | 693 +- gptqmodel_ext/marlin/gptq_marlin_bf16.cu | 5 + gptqmodel_ext/marlin/gptq_marlin_fp16.cu | 5 + gptqmodel_ext/marlin/kernel.h | 4 +- gptqmodel_ext/marlin/marlin.cuh | 115 +- gptqmodel_ext/marlin/marlin_mma.h | 155 + gptqmodel_ext/marlin/marlin_template.h | 119 +- gptqmodel_ext/marlin/marlin_torch_bf16.cpp | 67 + gptqmodel_ext/marlin/marlin_torch_fp16.cpp | 67 + gptqmodel_ext/paroquant/rotation.cu | 664 ++ gptqmodel_ext/paroquant/rotation.cuh | 257 + gptqmodel_ext/qqq/qqq.cpp | 24 +- progress.md | 1328 +++ pyproject.toml | 60 +- requirements.txt | 23 +- scripts/arch.md | 37 + scripts/benchmark_awq_cuda_fp32_reduce_ab.py | 576 + scripts/benchmark_awq_fused_reduce_ab.py | 490 + scripts/benchmark_awq_triton_fp32_ab.py | 278 + scripts/benchmark_gguf_autotune_ab.py | 276 + scripts/benchmark_gguf_cpp_vs_torch.py | 340 + scripts/benchmark_gguf_dequant.py | 218 + scripts/benchmark_gguf_fused_ab.py | 302 + scripts/benchmark_gptq_pro.py | 394 + scripts/benchmark_llama3_2_paged_attention.py | 1320 +++ .../benchmark_llama_cpp_vs_gptqmodel_gguf.py | 425 + scripts/benchmark_marlin_a100.py | 284 + .../benchmark_paroquant_official_vs_local.py | 757 ++ .../benchmark_paroquant_opt_scope_compare.py | 599 ++ ...hmark_paroquant_optimizer_real_workload.py | 326 + ...mark_paroquant_pair_cache_real_workload.py | 494 + scripts/benchmark_paroquant_rotation_ab.py | 334 + .../benchmark_paroquant_rotation_cache_ab.py | 417 + .../benchmark_paroquant_runtime_cache_ab.py | 431 + scripts/benchmark_paroquant_triton_ab.py | 279 + scripts/benchmark_qwen35_moe_ab.py | 717 ++ scripts/dequantize_model.py | 2 +- scripts/eval_model.py | 123 +- scripts/generate_exl3_kernel_map_packed.py | 358 + scripts/nvml_visible_shim.c | 188 + scripts/paroquant_first_layer_ab.py | 128 + scripts/paroquant_module_set_scan.py | 132 + scripts/paroquant_single_module_scan.py | 123 + .../profile_paroquant_runtime_cache_case.py | 250 + scripts/repro_issue_2326.py | 364 + scripts/run_gptq_pro_validate.sh | 46 + scripts/serve_vllm_qwen35.py | 231 + scripts/sitecustomize.py | 177 + scripts/sync_cuda_toolkit_with_torch.sh | 162 + scripts/vllm_qwen35_shim.py | 82 + setup.py | 959 +- tests/__init__.py | 1 + tests/awq_test_utils.py | 225 + tests/benchmark/benchmark_torch.py | 6 +- .../benchmark_torch_aten_vs_onednn.py | 370 + .../benchmark_torch_int8_vs_onednn.py | 428 + tests/conftest.py | 31 + tests/eval.py | 890 ++ tests/inference_speed.py | 27 +- tests/kernels/benchmark_intel_cpu_xpu.py | 20 +- tests/kernels/test_asymmetric_real_models.py | 306 + tests/kernels/test_awq.py | 307 +- tests/kernels/test_awq_cpu_fused_post_init.py | 24 +- tests/kernels/test_awq_cuda_fp32_reduce.py | 203 + tests/kernels/test_awq_machete_marlin.py | 186 + tests/kernels/test_awq_torch.py | 31 +- tests/kernels/test_awq_torch_fused.py | 8 +- tests/kernels/test_awq_triton_accum.py | 141 + tests/kernels/test_base_autotune.py | 104 + tests/kernels/test_exllamav3_kernel.py | 151 + .../test_exllamav3_kernel_map_packed.py | 31 + tests/kernels/test_fallback.py | 463 + tests/kernels/test_fp8_kernel.py | 157 + tests/kernels/test_gguf_cpp.py | 200 + tests/kernels/test_gptq.py | 194 +- tests/kernels/test_intel_cpu_xpu.py | 33 +- tests/kernels/test_paroquant.py | 404 + tests/kernels/test_qlinear_hierarchy.py | 236 + tests/kernels/test_selection.py | 342 +- tests/kernels/test_torch_int8.py | 14 +- tests/kernels/test_torch_int8_awq.py | 14 +- tests/models/awq/test_glm4_moe.py | 19 +- tests/models/awq/test_llama3_2.py | 65 +- .../models/awq/test_llama3_2_awq_protocol.py | 206 + tests/models/awq/test_marin_awq.py | 23 +- tests/models/awq/test_moe.py | 11 +- tests/models/awq/test_qwen3_5_moe.py | 18 +- tests/models/awq/test_qwen3_8b_base_awq.py | 46 + tests/models/foem/test_llama3_2.py | 111 + tests/models/foem/test_moe.py | 30 + tests/models/model_test.py | 1192 ++- tests/models/ovis/image_to_test_dataset.py | 7 +- .../paroquant_first_layer_case_helper.py | 107 + tests/models/paroquant_optimize_case.py | 215 + tests/models/test_act_group_aware.py | 11 +- tests/models/test_apertus.py | 12 +- tests/models/test_baichuan.py | 16 +- tests/models/test_bloom.py | 7 +- tests/models/test_bloom_bias_torch_fused.py | 10 +- tests/models/test_brumby.py | 36 +- tests/models/test_chatglm.py | 20 +- tests/models/test_codegen.py | 20 +- tests/models/test_cohere.py | 13 +- tests/models/test_cohere2.py | 9 +- tests/models/test_deci.py | 12 +- tests/models/test_deepseekv2_lite.py | 15 +- tests/models/test_dots_one.py | 16 +- tests/models/test_dream.py | 10 +- tests/models/test_ernie4_5.py | 13 +- tests/models/test_exaone.py | 20 +- tests/models/test_falcon.py | 9 +- tests/models/test_gemma.py | 6 +- tests/models/test_gemma3.py | 7 +- tests/models/test_gemma3_4b_it.py | 6 +- tests/models/test_gemma4_variants.py | 99 + tests/models/test_glm.py | 14 +- tests/models/test_glm4_moe.py | 11 +- tests/models/test_glm4_moe_lite.py | 24 + tests/models/test_glm4v.py | 13 +- tests/models/test_glm5_1_fp8_auto_decoder.py | 122 + tests/models/test_gpt2.py | 7 +- tests/models/test_gpt_oss.py | 10 +- tests/models/test_gptbigcode.py | 11 +- tests/models/test_gptj.py | 7 +- tests/models/test_gptneox.py | 6 +- tests/models/test_granite.py | 9 +- tests/models/test_granite_4_0_h_1b.py | 11 +- tests/models/test_granite_4_0_h_350m.py | 33 +- tests/models/test_hymba.py | 9 +- tests/models/test_instella.py | 6 +- tests/models/test_internlm.py | 18 +- tests/models/test_internlm2_5.py | 25 +- tests/models/test_ling.py | 11 +- tests/models/test_llama3_2.py | 88 +- tests/models/test_llama3_2_bitsandbytes.py | 86 + ...test_llama3_2_dynamic_skip_layer_replay.py | 298 + tests/models/test_llama3_2_exllamav3.py | 84 + tests/models/test_llama3_2_fp8.py | 74 + tests/models/test_llama3_2_gguf.py | 64 + tests/models/test_llama3_2_gguf_protocol.py | 181 + tests/models/test_llama3_2_gptq_protocol.py | 189 + .../test_llama3_2_lazy_turtle_memory.py | 180 + .../test_llama3_2_paroquant_first_layer.py | 31 + ...ama3_2_paroquant_optimize_compute_block.py | 14 + .../test_llama3_2_paroquant_optimize_group.py | 60 + .../test_llama3_2_paroquant_optimize_layer.py | 14 + ...test_llama3_2_paroquant_optimize_module.py | 13 + tests/models/test_llama3_2_torch_fused.py | 8 +- .../models/test_llama3_3_fp4_auto_decoder.py | 136 + tests/models/test_llama4.py | 9 +- tests/models/test_longllama.py | 8 +- tests/models/test_marin.py | 11 +- tests/models/test_mimo.py | 9 +- tests/models/test_minicpm_o_4_5.py | 46 + tests/models/test_minicpm_v_4_5.py | 47 + tests/models/test_minimax_m2.py | 11 +- tests/models/test_minimax_m2_hf.py | 179 + tests/models/test_mistral.py | 13 +- tests/models/test_mistral3.py | 13 +- tests/models/test_mixtral.py | 14 +- tests/models/test_model_test_fast_mode.py | 57 + tests/models/test_mpt.py | 13 +- tests/models/test_multi_vs_single_gpu.py | 11 +- tests/models/test_nemotron_ultra.py | 9 +- tests/models/test_opt.py | 6 +- tests/models/test_ovis2.py | 10 +- tests/models/test_ovis_1_6_llama.py | 33 +- tests/models/test_pangu_alpha.py | 6 +- tests/models/test_phi_3.py | 19 +- tests/models/test_phi_3_moe.py | 16 +- tests/models/test_phi_4.py | 13 +- tests/models/test_qwen2_5.py | 44 +- tests/models/test_qwen2_5_omni.py | 45 +- tests/models/test_qwen2_5_vl.py | 35 +- tests/models/test_qwen2_moe_quant.py | 9 +- tests/models/test_qwen2_vl.py | 35 +- tests/models/test_qwen3.py | 24 + tests/models/test_qwen3_5.py | 12 +- tests/models/test_qwen3_5_moe.py | 23 +- .../models/test_qwen3_5_moe_ab_regression.py | 55 + .../models/test_qwen3_8b_fp8_auto_decoder.py | 188 + tests/models/test_qwen3_8b_fp8_gsm8k_last4.py | 111 + tests/models/test_qwen3_8b_nvfp4.py | 81 + tests/models/test_qwen3_moe.py | 25 +- tests/models/test_qwen3_next.py | 13 +- tests/models/test_qwen3_omni.py | 9 +- tests/models/test_qwen3_vl.py | 35 +- tests/models/test_seed_oss.py | 9 +- tests/models/test_stablelm.py | 14 +- tests/models/test_starcode2.py | 7 +- tests/models/test_telechat2.py | 13 +- tests/models/test_tinyllama.py | 6 +- tests/models/test_voxtral.py | 15 +- tests/models/test_xverse.py | 16 +- tests/models/test_yi.py | 20 +- tests/module_tree/test_model_alignment.py | 224 +- tests/module_tree/test_subset.py | 59 +- tests/protocol/test_protocol.py | 319 + tests/pytest.ini | 7 +- tests/q4_reference.py | 1041 ++ tests/qcfg/test_config_dispatch.py | 270 + tests/qcfg/test_failsafe_meta.py | 68 +- tests/qcfg/test_fallback_meta.py | 43 + tests/qcfg/test_zero_point.py | 14 + tests/test_asym_gptq_v1.py | 5 +- tests/test_auto_module_decoder.py | 577 + tests/test_awq_bitblas.py | 184 + tests/test_awq_fp16_matmul_heuristic.py | 74 +- tests/test_awq_gemm.py | 17 + tests/test_awq_gemv.py | 17 + tests/test_awq_gemv_fast.py | 17 + tests/test_awq_gemv_fast_jit.py | 172 + tests/test_awq_inference_llm_awq.py | 21 + tests/test_awq_inference_mistral.py | 20 + tests/test_awq_jit_include_paths.py | 44 + tests/test_awq_llm_awq.py | 17 + tests/test_awq_loader_dtype.py | 53 + tests/test_awq_marlin.py | 17 + tests/test_awq_moe.py | 13 +- tests/test_awq_rotary_device.py | 4 +- tests/test_awq_weight_mean.py | 50 +- tests/test_backend_naming.py | 56 + tests/test_baichuan_rotary_buffers.py | 70 + tests/test_bench_cuda_even_d2h.py | 26 +- tests/test_benchmark_submodule_finalize.py | 13 +- tests/test_bitblas.py | 683 +- tests/test_bits.py | 52 +- tests/test_bits_new.py | 19 +- tests/test_bitsandbytes.py | 231 + tests/test_calibration_data_device.py | 1345 +++ tests/test_compute_device_filter.py | 6 +- tests/test_cpp_jit_progress.py | 145 + tests/test_cpu_pin.py | 2 +- ...est_cuda_event_stream_activation_buffer.py | 9 +- tests/test_cutlass_stable_abi_headers.py | 44 + tests/test_dynamic.py | 64 +- tests/test_eval.py | 81 +- tests/test_eval_loader_args.py | 234 + tests/test_eval_runtime.py | 70 + tests/test_evalution_suite_stream_defaults.py | 46 + tests/test_exllamav2_awq_jit.py | 96 + tests/test_exllamav2_jit.py | 134 + tests/test_exllamav3.py | 136 + tests/test_exllamav3_jit.py | 206 + tests/test_extension_load_api.py | 212 + tests/test_failsafe.py | 111 +- tests/test_fallback.py | 742 ++ tests/test_format_conversion_flow.py | 27 +- tests/test_fp4_llama3_fp4.py | 9 +- tests/test_fp4_qwen3_nvfp4.py | 75 + tests/test_fp8.py | 73 + tests/test_fp8_minimax2_test.py | 2 +- tests/test_gemma4_support.py | 184 + tests/test_generate_attention_mask.py | 81 + tests/test_gguf_qlinear_llama.py | 165 + tests/test_glm_moe_dsa_support.py | 206 + tests/test_gptaq.py | 8 +- tests/test_gptq.py | 48 +- tests/test_gptq_pro.py | 71 + tests/test_gpu_gpu_memory_copy.py | 18 +- tests/test_granitemoehybrid_monkeypatch.py | 69 + tests/test_group_size.py | 38 +- tests/test_hf_config_autofix.py | 70 + tests/test_hf_config_compat.py | 602 ++ tests/test_hf_init_guard.py | 38 + tests/test_hf_utils.py | 48 + tests/test_inference_speed.py | 9 +- tests/test_inference_speed_harness.py | 116 + tests/test_integration.py | 34 +- tests/test_internal_gguf.py | 133 + tests/test_linalg.py | 14 +- tests/test_linalg_warmup.py | 55 + tests/test_lm_head.py | 9 +- tests/test_loader.py | 15 + tests/test_local_model_paths.py | 1161 ++ tests/test_logger.py | 27 + tests/test_looper_helpers.py | 112 + tests/test_lora.py | 30 +- tests/test_machete_jit.py | 709 ++ tests/test_marlin_jit.py | 640 ++ tests/test_mlx.py | 34 - tests/test_mmlupro.py | 13 +- tests/test_model.py | 168 +- tests/test_model_definition_exports.py | 43 + tests/test_model_test_baseline_fallback.py | 63 + tests/test_model_test_helpers.py | 269 + tests/test_modelscope.py | 12 +- tests/test_module_preprocessor.py | 106 + tests/test_moe_config.py | 14 +- tests/test_moe_expert_batching.py | 32 +- tests/test_multi_gpu_inference.py | 27 +- tests/test_offload_files.py | 1302 ++- tests/test_openai_server.py | 28 +- tests/test_out_of_model_tensors.py | 537 + tests/test_ovis_generate_wrapper.py | 44 + tests/test_pack.py | 4 +- tests/test_pack_gpu_alignment.py | 8 +- tests/test_packable.py | 30 +- tests/test_packing.py | 10 +- tests/test_packing_speed.py | 43 +- tests/test_parameter_count.py | 2 +- tests/test_paroquant.py | 4831 +++++++++ tests/test_perplexity_logic.py | 95 + tests/test_post_quant_eora.py | 22 +- tests/test_prepare_dataset.py | 252 + tests/test_q4_bitblas.py | 76 +- tests/test_q4_exllama_v2.py | 37 +- tests/test_q4_marlin.py | 18 +- tests/test_q4_reference.py | 11 + tests/test_q4_torch_fused.py | 31 + tests/test_q4_triton.py | 6 +- tests/test_qqq.py | 14 +- tests/test_qqq_inference.py | 17 +- tests/test_qqq_jit.py | 77 + tests/test_quant_and_eora.py | 45 +- tests/test_quant_batch.py | 43 +- tests/test_quant_dtype.py | 512 +- tests/test_quant_trust_remote.py | 5 +- tests/test_qwen2_family_compat.py | 277 + tests/test_qwen3_5_batching.py | 14 + tests/test_qwen3_vl_dependency.py | 122 + tests/test_qwen_moe_converter.py | 151 + tests/test_random_string.py | 48 + tests/test_save_loaded_quantized_model.py | 51 +- ...save_loaded_quantized_model_torch_fused.py | 58 + tests/test_serialization.py | 49 +- tests/test_serve_vllm_qwen35.py | 36 + tests/test_sglang.py | 17 +- tests/test_sharded.py | 25 +- tests/test_simple_quant.py | 71 +- tests/test_split_by_layer_save.py | 199 + tests/test_stage_modules.py | 1782 +++- tests/test_startup_banner.py | 70 + tests/test_stream.py | 38 + tests/test_structure.py | 47 + tests/test_subset_plan.py | 500 + tests/test_tensor_parallel_padder.py | 151 + tests/test_threadx.py | 12 +- tests/test_tiny_moe_quant_smoke.py | 165 + tests/test_torch.py | 70 +- tests/test_torch_aten_kernel_import_guard.py | 87 + tests/test_torch_ops_jit_extension.py | 605 ++ tests/test_torch_xpu.py | 7 +- tests/test_vllm.py | 122 +- tests/test_weight_only.py | 1967 ++++ tests/test_weight_only_config.py | 427 + tests/test_writer_attention.py | 1 + 647 files changed, 118041 insertions(+), 9839 deletions(-) create mode 100644 .dockerignore create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/scripts/allocate_gpu.py create mode 100644 .github/scripts/blacklist.yaml create mode 100644 .github/scripts/ci_loop_versions.py create mode 100644 .github/scripts/parse_test_config.py create mode 100644 .github/scripts/release_gpu.py create mode 100644 .github/scripts/run_tests.py create mode 100644 .github/scripts/test.yaml create mode 100644 .github/scripts/uninstall_deps.py create mode 100644 .github/workflows/compatibility.yml create mode 100644 BRANCH_CLEANUP.md create mode 100644 Dockerfile create mode 100644 Project.md create mode 100644 docs/eora/README.md create mode 100644 docs/eora/eora_calibration_data_construction.py create mode 100644 docs/eora/eora_generation.py create mode 100644 docs/eora/eora_load_and_inference.py create mode 100644 docs/eora/evaluation.py create mode 100644 docs/eora/post_quant_eora_generation.py create mode 100644 docs/quantization_protocol.md create mode 100644 docs/qwen35_vllm_comparison.md create mode 100644 docs/qwen35_vllm_launch.md create mode 100644 environment.yml create mode 100644 gptqmodel/_banner.py create mode 100644 gptqmodel/exllamav3/CREDITS.md create mode 100644 gptqmodel/exllamav3/__init__.py create mode 100644 gptqmodel/exllamav3/ext.py create mode 100644 gptqmodel/exllamav3/modules/__init__.py create mode 100644 gptqmodel/exllamav3/modules/quant/__init__.py create mode 100644 gptqmodel/exllamav3/modules/quant/exl3.py create mode 100644 gptqmodel/exllamav3/modules/quant/exl3_lib/__init__.py create mode 100644 gptqmodel/exllamav3/modules/quant/exl3_lib/quantize.py create mode 100644 gptqmodel/exllamav3/util/__init__.py create mode 100644 gptqmodel/exllamav3/util/arch_list.py create mode 100644 gptqmodel/exllamav3/util/hadamard.py create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_1.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_100.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_116.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_156.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_172.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_188.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_236.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_244.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_428.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_52.txt create mode 100644 gptqmodel/exllamav3/util/hadamard_data/hadamard_92.txt create mode 100644 gptqmodel/exllamav3/util/memory.py create mode 100644 gptqmodel/exllamav3/util/misc.py create mode 100644 gptqmodel/exllamav3/util/progress.py create mode 100644 gptqmodel/exllamav3/util/tensor.py create mode 100644 gptqmodel/extension.py create mode 100644 gptqmodel/looper/exllamav3_processor.py create mode 100644 gptqmodel/looper/forward_executor.py create mode 100644 gptqmodel/looper/module_preprocessor.py create mode 100644 gptqmodel/looper/paroquant_processor.py create mode 100644 gptqmodel/looper/weight_only_looper.py create mode 100644 gptqmodel/looper/weight_only_processor.py create mode 100644 gptqmodel/models/definitions/gemma4.py create mode 100644 gptqmodel/models/definitions/glm4_moe_lite.py create mode 100644 gptqmodel/models/definitions/glm_moe_dsa.py create mode 100644 gptqmodel/models/definitions/minicpm_o.py create mode 100644 gptqmodel/models/definitions/minicpm_v.py create mode 100644 gptqmodel/nn_modules/exllamav3.py create mode 100644 gptqmodel/nn_modules/exllamav3_torch.py create mode 100644 gptqmodel/nn_modules/qlinear/bitblas_awq.py create mode 100644 gptqmodel/nn_modules/qlinear/bitsandbytes.py create mode 100644 gptqmodel/nn_modules/qlinear/fp4.py create mode 100644 gptqmodel/nn_modules/qlinear/fp8.py create mode 100644 gptqmodel/nn_modules/qlinear/gguf.py create mode 100644 gptqmodel/nn_modules/qlinear/gguf_cpp.py create mode 100644 gptqmodel/nn_modules/qlinear/gguf_triton.py create mode 100644 gptqmodel/nn_modules/qlinear/gptq_pro.py create mode 100644 gptqmodel/nn_modules/qlinear/paroquant.py create mode 100644 gptqmodel/nn_modules/qlinear/paroquant_triton.py create mode 100644 gptqmodel/nn_modules/qlinear/torch_aten_kernel.py create mode 100644 gptqmodel/nn_modules/qlinear/torch_aten_kernel_awq.py create mode 100644 gptqmodel/quantization/fallback_smooth.py create mode 100644 gptqmodel/quantization/foem.py create mode 100644 gptqmodel/quantization/paroquant/__init__.py create mode 100644 gptqmodel/quantization/paroquant/modules/__init__.py create mode 100644 gptqmodel/quantization/paroquant/modules/triton/__init__.py create mode 100644 gptqmodel/quantization/paroquant/modules/triton/gemm.py create mode 100644 gptqmodel/quantization/paroquant/optimization.py create mode 100644 gptqmodel/quantization/protocol.py create mode 100644 gptqmodel/quantization/rtn.py create mode 100644 gptqmodel/utils/awq.py create mode 100644 gptqmodel/utils/device_telemetry.py create mode 100644 gptqmodel/utils/exllamav3.py create mode 100644 gptqmodel/utils/fallback.py create mode 100644 gptqmodel/utils/gptq_pro.py create mode 100644 gptqmodel/utils/hub.py create mode 100644 gptqmodel/utils/internal_gguf.py create mode 100644 gptqmodel/utils/jit_compile_baselines.py create mode 100644 gptqmodel/utils/paroquant.py create mode 100644 gptqmodel/utils/paroquant_benchmark.py create mode 100644 gptqmodel/utils/qqq.py create mode 100644 gptqmodel_ext/awq/gemm_fast_cuda_entry.cu create mode 100644 gptqmodel_ext/awq/gemv_fast_cuda_entry.cu create mode 100644 gptqmodel_ext/awq/torch_bind.cpp create mode 100644 gptqmodel_ext/exllamav3/bindings.cpp create mode 100644 gptqmodel_ext/exllamav3/hadamard.cpp create mode 100644 gptqmodel_ext/exllamav3/hadamard.h create mode 100644 gptqmodel_ext/exllamav3/hgemm.cu create mode 100644 gptqmodel_ext/exllamav3/hgemm.cuh create mode 100644 gptqmodel_ext/exllamav3/libtorch/linear.cpp create mode 100644 gptqmodel_ext/exllamav3/libtorch/linear.h create mode 100644 gptqmodel_ext/exllamav3/libtorch/linear_bc.h create mode 100644 gptqmodel_ext/exllamav3/ptx.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/codebook.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cu create mode 100644 gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_devctx.cu create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_devctx.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_dq.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_gemm.cu create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_gemm.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_gemm_inner.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_gemm_kernel.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cu create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/exl3_kernel_map_packed.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/hadamard.cu create mode 100644 gptqmodel_ext/exllamav3/quant/hadamard.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/hadamard_inner.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/pack.cu create mode 100644 gptqmodel_ext/exllamav3/quant/pack.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/quantize.cu create mode 100644 gptqmodel_ext/exllamav3/quant/quantize.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/reconstruct.cu create mode 100644 gptqmodel_ext/exllamav3/quant/reconstruct.cuh create mode 100644 gptqmodel_ext/exllamav3/quant/util.cu create mode 100644 gptqmodel_ext/exllamav3/quant/util.cuh create mode 100644 gptqmodel_ext/exllamav3/util.cuh create mode 100644 gptqmodel_ext/exllamav3/util.h create mode 100644 gptqmodel_ext/floatx_cpu.cpp create mode 100644 gptqmodel_ext/gptq_pro/gptq_pro_kernel.cu create mode 100644 gptqmodel_ext/gptq_pro/gptq_pro_kernel.cuh create mode 100644 gptqmodel_ext/gptq_pro/gptq_pro_torch.cpp create mode 100644 gptqmodel_ext/gptq_pro/gptq_pro_validate.cu create mode 100644 gptqmodel_ext/marlin/gptq_marlin_bf16.cu create mode 100644 gptqmodel_ext/marlin/gptq_marlin_fp16.cu create mode 100644 gptqmodel_ext/marlin/marlin_mma.h create mode 100644 gptqmodel_ext/marlin/marlin_torch_bf16.cpp create mode 100644 gptqmodel_ext/marlin/marlin_torch_fp16.cpp create mode 100644 gptqmodel_ext/paroquant/rotation.cu create mode 100644 gptqmodel_ext/paroquant/rotation.cuh create mode 100644 progress.md create mode 100644 scripts/arch.md create mode 100644 scripts/benchmark_awq_cuda_fp32_reduce_ab.py create mode 100644 scripts/benchmark_awq_fused_reduce_ab.py create mode 100644 scripts/benchmark_awq_triton_fp32_ab.py create mode 100644 scripts/benchmark_gguf_autotune_ab.py create mode 100644 scripts/benchmark_gguf_cpp_vs_torch.py create mode 100644 scripts/benchmark_gguf_dequant.py create mode 100644 scripts/benchmark_gguf_fused_ab.py create mode 100644 scripts/benchmark_gptq_pro.py create mode 100644 scripts/benchmark_llama3_2_paged_attention.py create mode 100644 scripts/benchmark_llama_cpp_vs_gptqmodel_gguf.py create mode 100644 scripts/benchmark_marlin_a100.py create mode 100644 scripts/benchmark_paroquant_official_vs_local.py create mode 100644 scripts/benchmark_paroquant_opt_scope_compare.py create mode 100644 scripts/benchmark_paroquant_optimizer_real_workload.py create mode 100644 scripts/benchmark_paroquant_pair_cache_real_workload.py create mode 100644 scripts/benchmark_paroquant_rotation_ab.py create mode 100644 scripts/benchmark_paroquant_rotation_cache_ab.py create mode 100644 scripts/benchmark_paroquant_runtime_cache_ab.py create mode 100644 scripts/benchmark_paroquant_triton_ab.py create mode 100644 scripts/benchmark_qwen35_moe_ab.py create mode 100644 scripts/generate_exl3_kernel_map_packed.py create mode 100644 scripts/nvml_visible_shim.c create mode 100644 scripts/paroquant_first_layer_ab.py create mode 100644 scripts/paroquant_module_set_scan.py create mode 100644 scripts/paroquant_single_module_scan.py create mode 100644 scripts/profile_paroquant_runtime_cache_case.py create mode 100644 scripts/repro_issue_2326.py create mode 100755 scripts/run_gptq_pro_validate.sh create mode 100644 scripts/serve_vllm_qwen35.py create mode 100644 scripts/sitecustomize.py create mode 100755 scripts/sync_cuda_toolkit_with_torch.sh create mode 100644 scripts/vllm_qwen35_shim.py create mode 100644 tests/__init__.py create mode 100644 tests/awq_test_utils.py create mode 100644 tests/benchmark/benchmark_torch_aten_vs_onednn.py create mode 100644 tests/benchmark/benchmark_torch_int8_vs_onednn.py create mode 100644 tests/conftest.py create mode 100644 tests/eval.py create mode 100644 tests/kernels/test_asymmetric_real_models.py create mode 100644 tests/kernels/test_awq_cuda_fp32_reduce.py create mode 100644 tests/kernels/test_awq_machete_marlin.py create mode 100644 tests/kernels/test_awq_triton_accum.py create mode 100644 tests/kernels/test_base_autotune.py create mode 100644 tests/kernels/test_exllamav3_kernel.py create mode 100644 tests/kernels/test_exllamav3_kernel_map_packed.py create mode 100644 tests/kernels/test_fallback.py create mode 100644 tests/kernels/test_fp8_kernel.py create mode 100644 tests/kernels/test_gguf_cpp.py create mode 100644 tests/kernels/test_paroquant.py create mode 100644 tests/kernels/test_qlinear_hierarchy.py create mode 100644 tests/models/awq/test_llama3_2_awq_protocol.py create mode 100644 tests/models/awq/test_qwen3_8b_base_awq.py create mode 100644 tests/models/foem/test_llama3_2.py create mode 100644 tests/models/foem/test_moe.py create mode 100644 tests/models/paroquant_first_layer_case_helper.py create mode 100644 tests/models/paroquant_optimize_case.py create mode 100644 tests/models/test_gemma4_variants.py create mode 100644 tests/models/test_glm4_moe_lite.py create mode 100644 tests/models/test_glm5_1_fp8_auto_decoder.py create mode 100644 tests/models/test_llama3_2_bitsandbytes.py create mode 100644 tests/models/test_llama3_2_dynamic_skip_layer_replay.py create mode 100644 tests/models/test_llama3_2_exllamav3.py create mode 100644 tests/models/test_llama3_2_fp8.py create mode 100644 tests/models/test_llama3_2_gguf.py create mode 100644 tests/models/test_llama3_2_gguf_protocol.py create mode 100644 tests/models/test_llama3_2_gptq_protocol.py create mode 100644 tests/models/test_llama3_2_lazy_turtle_memory.py create mode 100644 tests/models/test_llama3_2_paroquant_first_layer.py create mode 100644 tests/models/test_llama3_2_paroquant_optimize_compute_block.py create mode 100644 tests/models/test_llama3_2_paroquant_optimize_group.py create mode 100644 tests/models/test_llama3_2_paroquant_optimize_layer.py create mode 100644 tests/models/test_llama3_2_paroquant_optimize_module.py create mode 100644 tests/models/test_llama3_3_fp4_auto_decoder.py create mode 100644 tests/models/test_minicpm_o_4_5.py create mode 100644 tests/models/test_minicpm_v_4_5.py create mode 100644 tests/models/test_minimax_m2_hf.py create mode 100644 tests/models/test_model_test_fast_mode.py create mode 100644 tests/models/test_qwen3.py create mode 100644 tests/models/test_qwen3_5_moe_ab_regression.py create mode 100644 tests/models/test_qwen3_8b_fp8_auto_decoder.py create mode 100644 tests/models/test_qwen3_8b_fp8_gsm8k_last4.py create mode 100644 tests/models/test_qwen3_8b_nvfp4.py create mode 100644 tests/protocol/test_protocol.py create mode 100644 tests/q4_reference.py create mode 100644 tests/qcfg/test_config_dispatch.py create mode 100644 tests/qcfg/test_fallback_meta.py create mode 100644 tests/test_auto_module_decoder.py create mode 100644 tests/test_awq_bitblas.py create mode 100644 tests/test_awq_gemm.py create mode 100644 tests/test_awq_gemv.py create mode 100644 tests/test_awq_gemv_fast.py create mode 100644 tests/test_awq_gemv_fast_jit.py create mode 100644 tests/test_awq_inference_llm_awq.py create mode 100644 tests/test_awq_inference_mistral.py create mode 100644 tests/test_awq_jit_include_paths.py create mode 100644 tests/test_awq_llm_awq.py create mode 100644 tests/test_awq_loader_dtype.py create mode 100644 tests/test_awq_marlin.py create mode 100644 tests/test_backend_naming.py create mode 100644 tests/test_baichuan_rotary_buffers.py create mode 100644 tests/test_bitsandbytes.py create mode 100644 tests/test_calibration_data_device.py create mode 100644 tests/test_cpp_jit_progress.py create mode 100644 tests/test_cutlass_stable_abi_headers.py create mode 100644 tests/test_eval_loader_args.py create mode 100644 tests/test_eval_runtime.py create mode 100644 tests/test_evalution_suite_stream_defaults.py create mode 100644 tests/test_exllamav2_awq_jit.py create mode 100644 tests/test_exllamav2_jit.py create mode 100644 tests/test_exllamav3.py create mode 100644 tests/test_exllamav3_jit.py create mode 100644 tests/test_extension_load_api.py create mode 100644 tests/test_fallback.py create mode 100644 tests/test_fp4_qwen3_nvfp4.py create mode 100644 tests/test_fp8.py create mode 100644 tests/test_gemma4_support.py create mode 100644 tests/test_generate_attention_mask.py create mode 100644 tests/test_gguf_qlinear_llama.py create mode 100644 tests/test_glm_moe_dsa_support.py create mode 100644 tests/test_gptq_pro.py create mode 100644 tests/test_granitemoehybrid_monkeypatch.py create mode 100644 tests/test_hf_config_autofix.py create mode 100644 tests/test_hf_config_compat.py create mode 100644 tests/test_hf_init_guard.py create mode 100644 tests/test_hf_utils.py create mode 100644 tests/test_inference_speed_harness.py create mode 100644 tests/test_internal_gguf.py create mode 100644 tests/test_linalg_warmup.py create mode 100644 tests/test_loader.py create mode 100644 tests/test_local_model_paths.py create mode 100644 tests/test_logger.py create mode 100644 tests/test_looper_helpers.py create mode 100644 tests/test_machete_jit.py create mode 100644 tests/test_marlin_jit.py create mode 100644 tests/test_model_definition_exports.py create mode 100644 tests/test_model_test_baseline_fallback.py create mode 100644 tests/test_model_test_helpers.py create mode 100644 tests/test_module_preprocessor.py create mode 100644 tests/test_out_of_model_tensors.py create mode 100644 tests/test_ovis_generate_wrapper.py create mode 100644 tests/test_paroquant.py create mode 100644 tests/test_perplexity_logic.py create mode 100644 tests/test_q4_reference.py create mode 100644 tests/test_q4_torch_fused.py create mode 100644 tests/test_qqq_jit.py create mode 100644 tests/test_qwen2_family_compat.py create mode 100644 tests/test_qwen3_5_batching.py create mode 100644 tests/test_qwen_moe_converter.py create mode 100644 tests/test_random_string.py create mode 100644 tests/test_save_loaded_quantized_model_torch_fused.py create mode 100644 tests/test_serve_vllm_qwen35.py create mode 100644 tests/test_split_by_layer_save.py create mode 100644 tests/test_startup_banner.py create mode 100644 tests/test_structure.py create mode 100644 tests/test_subset_plan.py create mode 100644 tests/test_tensor_parallel_padder.py create mode 100644 tests/test_tiny_moe_quant_smoke.py create mode 100644 tests/test_torch_aten_kernel_import_guard.py create mode 100644 tests/test_torch_ops_jit_extension.py create mode 100644 tests/test_weight_only.py create mode 100644 tests/test_weight_only_config.py diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..1917eba67 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,11 @@ +.git +.copilot +__pycache__/ +*.pyc +*.pyo +*.pyd +.pytest_cache/ +.ruff_cache/ +GPTQModel.egg-info/ +gptqmodel_offload/ +gptq_log_*.log diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 000000000..b78bcebf6 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,38 @@ +## Summary + +Describe the bug, fix, or feature clearly and briefly. + +## What Changed + +- List the main code changes. +- List any API or behavior changes. +- List any follow-up work that is intentionally out of scope. + +## Tests + +Every working PR must include at least one new simple, fast, targeted unit test when the change affects behavior, a bug fix, or a regression path. + +- [ ] I added a new simple/fast unit test for this change, or documented why that is not applicable. +- [ ] I ran the new targeted test locally before opening this PR. +- [ ] I ran any other directly relevant local tests. + +Paste the exact test commands and results here: + +```bash +``` + +## Review Requirements + +AI-assisted code is welcome. + +Every changed file must still be properly reviewed by a human before the PR is opened as ready for review. + +We will not accept PRs that are effectively unreviewed AI output. Non-human-reviewed changes often introduce obscure structure, mismatched APIs, project-inconsistent code patterns, or unnecessary monkeypatching instead of a correct fix or clean feature expansion. + +- [ ] I personally reviewed every file in this diff. +- [ ] I checked that the code matches existing project structure, APIs, and conventions. +- [ ] I avoided unnecessary monkeypatching and used the project's normal extension points where possible. + +## Notes + +Add any migration notes, risks, compatibility concerns, or reviewer guidance here. diff --git a/.github/scripts/allocate_gpu.py b/.github/scripts/allocate_gpu.py new file mode 100644 index 000000000..7b71258e8 --- /dev/null +++ b/.github/scripts/allocate_gpu.py @@ -0,0 +1,145 @@ +import argparse +import os +import sys +import time +import urllib.error +import urllib.parse +import urllib.request + + +def now_ms() -> int: + return time.time_ns() // 1_000_000 + + +def fetch_text(url: str, *, timeout: float, suppress_error: bool = False) -> str: + try: + with urllib.request.urlopen(url, timeout=timeout) as response: + return response.read().decode("utf-8", errors="replace") + except (urllib.error.URLError, TimeoutError, OSError) as exc: + if suppress_error: + print(f"Request failed for {url}: {exc}") + return "" + raise + + +def fetch_with_retry(url: str, *, timeout: float, retries: int, retry_delay: float) -> str: + last_error: Exception | None = None + for attempt in range(retries + 1): + try: + return fetch_text(url, timeout=timeout) + except (urllib.error.URLError, TimeoutError, OSError) as exc: + last_error = exc + if attempt < retries: + time.sleep(retry_delay) + if last_error is not None: + print(f"Request failed after retries: {last_error}") + return "" + + +def print_status(base_url: str) -> None: + status = fetch_text(f"{base_url}/gpu/status", timeout=10, suppress_error=True).strip() + if status: + print(status) + + +def append_github_env(name: str, value: str) -> None: + github_env = os.environ.get("GITHUB_ENV") + if not github_env: + return + with open(github_env, "a", encoding="utf-8") as fh: + fh.write(f"{name}={value}\n") + + +def is_valid_gpu_response(value: str) -> bool: + if not value: + return False + for part in value.split(","): + if not part: + return False + if part.startswith("-"): + if not part[1:].isdigit(): + return False + elif not part.isdigit(): + return False + return True + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--base-url", required=True) + parser.add_argument("--run-id", required=True) + parser.add_argument("--test", required=True) + parser.add_argument("--runner", required=True) + parser.add_argument("--count", required=True) + parser.add_argument("--sleep-sec", type=float, default=5) + parser.add_argument("--timeout-sec", type=int, default=18000) + parser.add_argument("--request-timeout", type=float, default=10) + parser.add_argument("--retries", type=int, default=3) + parser.add_argument("--retry-delay", type=float, default=1) + parser.add_argument("--require-single", action="store_true") + args = parser.parse_args() + + encoded_test = urllib.parse.quote(args.test, safe="") + encoded_runner = urllib.parse.quote(args.runner, safe="") + start_s = time.time() + + print("Requesting GPU from allocator") + print( + f"run_id={args.run_id} test={args.test} runner={args.runner} count={args.count}" + ) + + while True: + ts_ms = now_ms() + url = ( + f"{args.base_url}/gpu/get?runid={args.run_id}×tamp={ts_ms}" + f"&test={encoded_test}&runner={encoded_runner}&count={args.count}" + ) + print(f"requesting GPU with: {url}") + + resp = fetch_with_retry( + url, + timeout=args.request_timeout, + retries=args.retries, + retry_delay=args.retry_delay, + ).replace("\r", "").replace("\n", "").strip() + + print(f"resp={{{resp}}}") + + if not is_valid_gpu_response(resp): + print(f"Allocator returned invalid response: {resp!r} (temporary error)") + print_status(args.base_url) + time.sleep(args.sleep_sec) + continue + + if resp.startswith("-") and "," not in resp: + elapsed = int(time.time() - start_s) + if elapsed >= args.timeout_sec: + print( + f"Timed out after {args.timeout_sec}s waiting for GPU " + f"(last response={resp})" + ) + print_status(args.base_url) + return 1 + + print( + f"No GPU available (response={resp}). Waiting {args.sleep_sec}s..." + f" elapsed={elapsed}s" + ) + print_status(args.base_url) + time.sleep(args.sleep_sec) + continue + + if args.require_single and "," in resp: + print(f"Allocator returned multiple GPUs for job requiring one GPU: {resp}") + return 1 + + print(f"Allocated GPU ID: {resp}") + append_github_env("CUDA_VISIBLE_DEVICES", resp) + append_github_env("STEP_TIMESTAMP", str(now_ms())) + print(f"CUDA_VISIBLE_DEVICES set to {resp}") + print_status(args.base_url) + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/scripts/blacklist.yaml b/.github/scripts/blacklist.yaml new file mode 100644 index 000000000..2e58a6a74 --- /dev/null +++ b/.github/scripts/blacklist.yaml @@ -0,0 +1,3 @@ +tests/models: + test_baichuan.py: + - flash_attn diff --git a/.github/scripts/ci_loop_versions.py b/.github/scripts/ci_loop_versions.py new file mode 100644 index 000000000..c65b14d63 --- /dev/null +++ b/.github/scripts/ci_loop_versions.py @@ -0,0 +1,37 @@ +import argparse +import json + +import requests +from packaging.specifiers import SpecifierSet +from packaging.version import Version + + +def get_versions(package: str, version_spec: str) -> list[str]: + specifier = SpecifierSet(version_spec) + + url = f"https://pypi.org/pypi/{package}/json" + resp = requests.get(url, timeout=30) + resp.raise_for_status() + data = resp.json() + + all_versions = data["releases"].keys() + + matched = sorted( + (Version(v) for v in all_versions if Version(v) in specifier), + reverse=True, + ) + return [str(v) for v in matched] + + +def main(): + parser = argparse.ArgumentParser(description="List matching PyPI versions as JSON") + parser.add_argument("package", help="package name, e.g. setuptools") + parser.add_argument("version", help='version spec, e.g. ">=77.0.1,<83"') + args = parser.parse_args() + + versions = get_versions(args.package, args.version) + print(json.dumps(versions)) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/deps.yaml b/.github/scripts/deps.yaml index cc38e5eef..6ff9cca35 100644 --- a/.github/scripts/deps.yaml +++ b/.github/scripts/deps.yaml @@ -5,17 +5,21 @@ common: - swig # for models/tests tests: - test_q4_bitblas.py: - - http://10.0.13.31/files/bitblas-0.1.0%2Bubuntu.22.4.cu131-py3-none-any.whl + test_awq_bitblas.py: + - http://10.0.13.31/files/bitblas-0.1.0.post1+ubuntu.22.4.cu131-py3-none-any.whl + + test_bitblas.py: + - http://10.0.13.31/files/bitblas-0.1.0.post1+ubuntu.22.4.cu131-py3-none-any.whl - test_perplexity.py: - - http://10.0.13.31/files/bitblas-0.1.0%2Bubuntu.22.4.cu131-py3-none-any.whl + + test_q4_bitblas.py: + - http://10.0.13.31/files/bitblas-0.1.0.post1+ubuntu.22.4.cu131-py3-none-any.whl test_inference_speed.py: - - http://10.0.13.31/files/bitblas-0.1.0%2Bubuntu.22.4.cu131-py3-none-any.whl + - http://10.0.13.31/files/bitblas-0.1.0.post1+ubuntu.22.4.cu131-py3-none-any.whl test_save_loaded_quantized_model.py: - - http://10.0.13.31/files/bitblas-0.1.0%2Bubuntu.22.4.cu131-py3-none-any.whl + - http://10.0.13.31/files/bitblas-0.1.0.post1+ubuntu.22.4.cu131-py3-none-any.whl test_olora_finetuning_xpu.py: - intel_extension_for_pytorch @@ -49,7 +53,25 @@ tests: - https://github.com/huggingface/transformers - https://github.com/huggingface/optimum - test_awq.py: + test_awq_gemm.py: + - peft + + test_awq_marlin.py: + - peft + + test_awq_gemv.py: + - peft + + test_awq_gemv_fast.py: + - peft + + test_awq_llm_awq.py: + - peft + + test_awq_inference_mistral.py: + - peft + + test_awq_inference_llm_awq.py: - peft test_asym_gptq_v1.py: @@ -66,28 +88,38 @@ tests: test_eval.py: - peft - - evalplus - - http://10.0.13.31/files/vllm-0.1.dev1+gdc6b57846.d20260309.cu131-cp313-cp313-linux_x86_64.whl + - http://10.0.13.31/files/evalplus-0.4.0.dev44-py3-none-any.whl + - http://10.0.13.31/files/vllm-0.17.2rc1.dev52+g86b7e3c95.d20260318.cu131-cp313-cp313-linux_x86_64.whl + - flashinfer-python==0.6.5 test_evalplus.py: - - evalplus - - test_lm_eval.py: - - peft + - http://10.0.13.31/files/evalplus-0.4.0.dev44-py3-none-any.whl test_nemotron_ultra.py: - http://10.0.13.31/files/causal_conv1d-1.6.0-cp313-cp313-linux_x86_64.whl -tests/models: - test_pangu_alpha.py: - - jieba + test_openai_server.py: + - openai + - uvicorn + - fastapi + - pydantic + test_local_model_paths.py: + - gguf + +tests/models: test_cohere2: - jieba test_gemma: - jieba + test_bloom.py: + - peft + + test_brumby.py: + - retention + test_hymba.py: - http://10.0.13.31/files/causal_conv1d-1.6.0-cp313-cp313-linux_x86_64.whl - mamba_ssm @@ -98,14 +130,9 @@ tests/models: test_ovis_1_6_llama.py: - transformers<=4.44.2 - test_phi_4.py: - - peft - test_phi_3_moe.py: - transformers<=4.44.2 - test_bloom.py: - - peft - - test_gptbigcode.py: +tests/models/awq: + test_qwen3_8b_base_awq.py: - peft diff --git a/.github/scripts/install_deps.py b/.github/scripts/install_deps.py index f57296952..4c9ce346e 100644 --- a/.github/scripts/install_deps.py +++ b/.github/scripts/install_deps.py @@ -1,4 +1,5 @@ import os +import re import subprocess import sys from pathlib import Path @@ -6,10 +7,13 @@ import yaml base_dir = os.path.dirname(os.path.abspath(__file__)) +_PKG_NAME_RE = re.compile(r"^[A-Za-z0-9_.-]+") + def resolve_test_path(raw_name: str) -> Path: return Path("tests") / f"{raw_name}.py" + def normalize_pkg_spec(s: str) -> str: s = (s or "").strip() if not s: @@ -26,6 +30,32 @@ def normalize_pkg_spec(s: str) -> str: return s + +def pkg_key(spec: str) -> str: + spec = normalize_pkg_spec(spec) + if not spec: + return spec + + if spec.startswith("git+"): + repo = spec.rsplit("/", 1)[-1] + if repo.endswith(".git"): + repo = repo[:-4] + return repo.split("@", 1)[0].lower().replace("_", "-") + + if "://" in spec: + return spec + + spec = spec.split(";", 1)[0].strip() + if " @" in spec: + spec = spec.split(" @", 1)[0].strip() + + match = _PKG_NAME_RE.match(spec) + if not match: + return spec.lower() + + return match.group(0).lower().replace("_", "-") + + def collect_pkgs(test_path: Path, deps: dict): specific_pkgs = set() @@ -49,15 +79,17 @@ def collect_pkgs(test_path: Path, deps: dict): else: pass - return specific_pkgs, common_pkgs + specific_pkg_keys = {pkg_key(pkg) for pkg in specific_pkgs} + common_pkgs = {pkg for pkg in common_pkgs if pkg_key(pkg) not in specific_pkg_keys} + return specific_pkgs, common_pkgs def pip_install(pkgs): if not pkgs: return - print("Installing deps:") + print("--- Installing deps:") for p in pkgs: print(" -", p) @@ -79,12 +111,13 @@ def uv_install(pkgs): pkgs = [normalize_pkg_spec(p) for p in pkgs] - print("Installing deps with uv:") + print("--- Installing deps with uv:") for p in pkgs: print(" -", p) for p in pkgs: cmd = ["uv", "pip", "install", "--no-cache", p] + print("installing: ", cmd) try: subprocess.check_call(cmd, shell=False) except Exception as e: @@ -100,6 +133,6 @@ def uv_install(pkgs): specific_pkgs, common_pkgs = collect_pkgs(test_path, deps) - uv_install(sorted(specific_pkgs)) + uv_install(specific_pkgs) - uv_install(sorted(common_pkgs)) + uv_install(common_pkgs) diff --git a/.github/scripts/list_test_files.py b/.github/scripts/list_test_files.py index 616b47cae..0eeec8812 100644 --- a/.github/scripts/list_test_files.py +++ b/.github/scripts/list_test_files.py @@ -3,11 +3,13 @@ import os import re from pathlib import Path -from typing import List, Tuple, Union, Optional +from typing import Dict, List, Tuple, Union, Optional + def _sort_key(p: str): return ("moe" in p, "/" in p, p) + def _split_csv(s: Optional[str]) -> List[str]: if not s: return [] @@ -18,16 +20,34 @@ def _strip_py_suffix(name: str) -> str: return name.removesuffix(".py") +def _is_model_compat_test(rel_path: str, file_path: Path) -> bool: + if not rel_path.startswith("models/"): + return False + + try: + contents = file_path.read_text(encoding="utf-8") + except OSError: + return False + + compat_markers = ( + "quantize_and_evaluate(", + "self.evaluate_model(", + "check_results(", + ) + return any(marker in contents for marker in compat_markers) + + def getFiles( ignored_test_files: Union[str, List[str]], test_names: str = "", test_regex: str = ".*", tests_root: Union[str, Path] = "tests", -) -> Tuple[List[str], List[str]]: +) -> Tuple[List[str], List[str], List[str]]: """ Returns: - (torch_test_files, m4_test_files) - - torch_test_files: tests/**/test_*.py excluding mlx / ipex / xpu + (torch_test_files, model_compat_test_files, m4_test_files) + - torch_test_files: tests/**/test_*.py excluding mlx / ipex / xpu and model compat files + - model_compat_test_files: tests/models/**/test_*.py files that run model quantize + evaluation compat flows - m4_test_files: tests/**/test_*.py that contains mlx or apple """ tests_root = Path(tests_root) @@ -40,10 +60,22 @@ def getFiles( ignored_set = set(_strip_py_suffix(x) for x in ignored_list) # all tests under tests/**/test_*.py (includes tests/models/**) - all_tests = { - str(p.relative_to(tests_root).with_suffix("")) + all_tests: Dict[str, Path] = { + rel: p for p in tests_root.rglob("test_*.py") - if p.stem not in ignored_set + for rel in [str(p.relative_to(tests_root).with_suffix(""))] + if rel not in ignored_set and p.stem not in ignored_set + } + + model_compat_test_files = { + rel + for rel, path in all_tests.items() + if (not input_test_files_list or rel in input_test_files_list) + and "mlx" not in rel + and "ipex" not in rel + and "xpu" not in rel + and re.match(test_regex, rel) + and _is_model_compat_test(rel, path) } # torch tests @@ -51,6 +83,7 @@ def getFiles( f for f in all_tests if (not input_test_files_list or f in input_test_files_list) + and f not in model_compat_test_files and "mlx" not in f and "ipex" not in f and "xpu" not in f @@ -66,7 +99,11 @@ def getFiles( and re.match(test_regex, f) } - return sorted(torch_test_files, key=_sort_key), sorted(m4_test_files, key=_sort_key) + return ( + sorted(torch_test_files, key=_sort_key), + sorted(model_compat_test_files, key=_sort_key), + sorted(m4_test_files, key=_sort_key), + ) def main() -> None: @@ -89,14 +126,14 @@ def main() -> None: parser.add_argument("--tests-root", default="tests") args = parser.parse_args() - torch_files, m4_files = getFiles( + torch_files, model_compat_files, m4_files = getFiles( ignored_test_files=args.ignored_test_files, test_names=args.test_names, test_regex=args.test_regex, tests_root=args.tests_root, ) - print(f"{json.dumps(torch_files)}|{json.dumps(m4_files)}") + print(f"{json.dumps(torch_files)}|{json.dumps(model_compat_files)}|{json.dumps(m4_files)}") if __name__ == "__main__": diff --git a/.github/scripts/parse_test_config.py b/.github/scripts/parse_test_config.py new file mode 100644 index 000000000..3d8b808b3 --- /dev/null +++ b/.github/scripts/parse_test_config.py @@ -0,0 +1,87 @@ +import json +from pathlib import Path +from typing import Any + +import yaml + + +def parse_test_config( + yaml_file: str | Path, + group: str, + test_name: str | None = None, +) -> dict[str, Any]: + yaml_path = Path(yaml_file) + with yaml_path.open("r", encoding="utf-8") as f: + data = yaml.safe_load(f) or {} + + result: dict[str, Any] = {} + + common_data = data.get("common") or {} + if not isinstance(common_data, dict): + raise ValueError("group must be a mapping: common") + + for key, value in common_data.items(): + if not isinstance(value, dict): + result[key] = value + + if group not in data: + raise KeyError(f"group not found: {group}") + + group_data = data[group] or {} + if not isinstance(group_data, dict): + raise ValueError(f"group must be a mapping: {group}") + + # Group-level shared config overrides common defaults. + for key, value in group_data.items(): + if not isinstance(value, dict): + result[key] = value + + # Per-test config overrides group/common defaults. Missing per-test entries + # fall back to the merged defaults. + if test_name is not None: + test_config = group_data.get(test_name) + if test_config is None: + return result + if not isinstance(test_config, dict): + raise ValueError(f"test config must be a mapping: {test_name}") + result.update(test_config) + + return result + + +def parse_test_config_flags( + yaml_file: str | Path, + group: str, + test_name: str | None = None, +) -> dict[str, int]: + config = parse_test_config(yaml_file, group, test_name) + return {key: 1 for key in config} + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--yaml-file", + default=Path(__file__).with_name("test.yaml"), + ) + parser.add_argument("--group", required=True) + parser.add_argument("--test-name") + parser.add_argument( + "--flags-only", + action="store_true", + help="Return only config keys with value 1, for example: {'py': 1, 'gpu': 1}", + ) + args = parser.parse_args() + + if args.flags_only: + result = parse_test_config_flags(args.yaml_file, args.group, args.test_name) + else: + result = parse_test_config(args.yaml_file, args.group, args.test_name) + + print(json.dumps(result, ensure_ascii=False)) + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/release_gpu.py b/.github/scripts/release_gpu.py new file mode 100644 index 000000000..dceae3580 --- /dev/null +++ b/.github/scripts/release_gpu.py @@ -0,0 +1,45 @@ +import argparse +import sys +import urllib.error +import urllib.parse +import urllib.request + + +def fetch_text(url: str, *, timeout: float) -> str: + with urllib.request.urlopen(url, timeout=timeout) as response: + return response.read().decode("utf-8", errors="replace") + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--base-url", required=True) + parser.add_argument("--run-id", required=True) + parser.add_argument("--gpu-id", required=True) + parser.add_argument("--timestamp", required=True) + parser.add_argument("--test", required=True) + parser.add_argument("--runner", required=True) + parser.add_argument("--timeout", type=float, default=10) + args = parser.parse_args() + + encoded_test = urllib.parse.quote(args.test, safe="") + encoded_runner = urllib.parse.quote(args.runner, safe="") + url = ( + f"{args.base_url}/gpu/release?runid={args.run_id}&gpu={args.gpu_id}" + f"×tamp={args.timestamp}&test={encoded_test}&runner={encoded_runner}" + ) + print(url) + + try: + resp = fetch_text(url, timeout=args.timeout).strip() + except (urllib.error.URLError, TimeoutError, OSError) as exc: + print(f"Failed to release GPU: {exc}") + return 0 + + print(f"response: {resp}") + if resp != args.gpu_id: + print(f"Error: response ({resp}) != expected ({args.gpu_id})") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/scripts/run_tests.py b/.github/scripts/run_tests.py new file mode 100644 index 000000000..662c6c938 --- /dev/null +++ b/.github/scripts/run_tests.py @@ -0,0 +1,200 @@ +import argparse +import os +import signal +import subprocess +import sys +import threading +import time +import urllib.error +import urllib.parse +import urllib.request +from pathlib import Path + + +def append_github_env(name: str, value: str) -> None: + github_env = os.environ.get("GITHUB_ENV") + if not github_env: + return + with open(github_env, "a", encoding="utf-8") as fh: + fh.write(f"{name}={value}\n") + + +def fetch_text(url: str, *, timeout: float, suppress_error: bool = False) -> str: + try: + with urllib.request.urlopen(url, timeout=timeout) as response: + return response.read().decode("utf-8", errors="replace") + except (urllib.error.URLError, TimeoutError, OSError) as exc: + if suppress_error: + print(f"Request failed for {url}: {exc}") + return "" + raise + + +def kill_process_group(proc: subprocess.Popen[str]) -> None: + try: + os.killpg(proc.pid, signal.SIGKILL) + except ProcessLookupError: + # Process (or process group) is already gone; nothing to do. + pass + + +def start_keepalive_monitor( + *, + proc: subprocess.Popen[str], + keep_alive_url: str, + interval_sec: int, +) -> tuple[threading.Thread, threading.Event, dict[str, int]]: + stop_event = threading.Event() + state = {"forced_exit_code": 0} + + def worker() -> None: + print(f"start to keep alive... {keep_alive_url}") + while not stop_event.wait(interval_sec): + resp = fetch_text(keep_alive_url, timeout=10, suppress_error=True) + # if resp.strip() == "-1": + if int(resp.strip()) < 0: + print(f"Server returned {resp.strip()}, terminating job...") + state["forced_exit_code"] = 3 + kill_process_group(proc) + stop_event.set() + return + print("gpu is kept alive...") + + thread = threading.Thread(target=worker, daemon=True) + thread.start() + return thread, stop_event, state + + +def stream_process_output(proc: subprocess.Popen[str], log_file: Path) -> int: + assert proc.stdout is not None + log_file.parent.mkdir(parents=True, exist_ok=True) + with log_file.open("w", encoding="utf-8") as fh: + for line in proc.stdout: + print(line, end="") + fh.write(line) + return proc.wait() + + +def maybe_uninstall_vllm() -> None: + uninstall_cmd = ["uv", "pip", "uninstall", "vllm", "-y"] + print(f"+ {' '.join(uninstall_cmd)}") + subprocess.run(uninstall_cmd, check=False) + + list_cmd = ["uv", "pip", "list"] + print(f"+ {' '.join(list_cmd)}") + subprocess.run(list_cmd, check=False) + + +def log_vram(base_url: str, run_id: str, gpu_id: str, execution_time: int, test: str) -> None: + encoded_test = urllib.parse.quote(test, safe="") + url = ( + f"{base_url}/gpu/logVram?runid={run_id}&gpu={gpu_id}" + f"&range={execution_time}&unit=second&test={encoded_test}" + ) + try: + print(fetch_text(url, timeout=30, suppress_error=True)) + except Exception: + # Logging VRAM usage is best-effort; failures must not affect the main test flow. + pass + + +def main() -> int: + parser = argparse.ArgumentParser() + parser.add_argument("--base-url", required=True) + parser.add_argument("--run-id", required=True) + parser.add_argument("--test-script", required=True) + parser.add_argument("--runner", required=True) + parser.add_argument("--gpu-id", default="") + parser.add_argument("--model-test-mode") + parser.add_argument("--clear-cuda", action="store_true") + parser.add_argument("--xpu-mode", action="store_true") + parser.add_argument("--monitor-interval-sec", type=int, default=60) + args = parser.parse_args() + + env = os.environ.copy() + if args.clear_cuda: + env["CUDA_VISIBLE_DEVICES"] = "" + print("CUDA_VISIBLE_DEVICES=") + + if args.xpu_mode: + maybe_uninstall_vllm() + + if args.model_test_mode is not None: + env["GPTQMODEL_MODEL_TEST_MODE"] = args.model_test_mode + print(f"GPTQMODEL_MODEL_TEST_MODE={args.model_test_mode}") + + print(f"CUDA_VISIBLE_DEVICES={env.get('CUDA_VISIBLE_DEVICES', '')}") + + log_dir = Path(f"/opt/dist/GPTQModel/{args.run_id}/logs") + log_file = log_dir / f"{args.test_script}.log" + log_dir.mkdir(parents=True, exist_ok=True) + + pytest_cmd = ["pytest", "--durations=0", f"tests/{args.test_script}.py"] + print(f"+ {' '.join(pytest_cmd)}") + + proc = subprocess.Popen( + pytest_cmd, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + bufsize=1, + env=env, + start_new_session=True, + ) + + encoded_test = urllib.parse.quote(args.test_script, safe="") + encoded_runner = urllib.parse.quote(args.runner, safe="") + keep_alive_url = ( + f"{args.base_url}/gpu/keepalive?runid={args.run_id}&test={encoded_test}" + f"&runner={encoded_runner}×tamp={int(time.time())}&gpu={env.get('CUDA_VISIBLE_DEVICES', '')}" + ) + + monitor_thread = None + monitor_stop = None + monitor_state = {"forced_exit_code": 0} + if env.get("CUDA_VISIBLE_DEVICES", ""): + monitor_thread, monitor_stop, monitor_state = start_keepalive_monitor( + proc=proc, + keep_alive_url=keep_alive_url, + interval_sec=args.monitor_interval_sec, + ) + + start_time = time.time() + try: + return_code = stream_process_output(proc, log_file) + finally: + if monitor_stop is not None: + print("trap cleanup EXIT...") + monitor_stop.set() + if monitor_thread is not None: + monitor_thread.join(timeout=5) + + if monitor_state["forced_exit_code"]: + append_github_env("ERROR", "22") + return 22 + + if return_code != 0: + append_github_env("ERROR", "22") + print(f"pipe status wrong: {return_code}") + return 22 + + execution_time = int(time.time() - start_time) + print(f"{execution_time // 60}m {execution_time % 60}s") + + try: + for entry in sorted(log_dir.iterdir()): + stat = entry.stat() + size = stat.st_size + print(f"{size:>10} {entry.name}") + except OSError as exc: + print(f"Failed to list log dir: {exc}") + + gpu_id = args.gpu_id or env.get("CUDA_VISIBLE_DEVICES", "") + if gpu_id: + log_vram(args.base_url, args.run_id, gpu_id, execution_time, args.test_script) + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/.github/scripts/test.yaml b/.github/scripts/test.yaml new file mode 100644 index 000000000..7aae3ce7a --- /dev/null +++ b/.github/scripts/test.yaml @@ -0,0 +1,16 @@ +common: + py: 3.14t # common config + gpu: 1 + +tests: + test_multi_gpu_inference: # per-test config + gpu: 5 + py: 3.14t + + test_calibration_data_device: + gpu: 2 + py: 3.14t + + +tests/models: + py: 3.13 diff --git a/.github/scripts/uninstall_deps.py b/.github/scripts/uninstall_deps.py new file mode 100644 index 000000000..5e1763e0a --- /dev/null +++ b/.github/scripts/uninstall_deps.py @@ -0,0 +1,68 @@ +import os +import subprocess +import sys +from pathlib import Path + +import yaml + +base_dir = os.path.dirname(os.path.abspath(__file__)) + + +def resolve_test_path(raw_name: str) -> Path: + return Path("tests") / f"{raw_name}.py" + + +def collect_pkgs(test_path: Path, deps: dict): + specific_pkgs = set() + + common_pkgs = set(deps.get("common") or []) + + specific_pkgs.update(deps.get("tests", {}).get(test_path.name) or []) + + test_path_str = test_path.as_posix() + for key, value in deps.items(): + if not (isinstance(key, str) and key.startswith("tests/")): + continue + if not test_path_str.startswith(key + "/"): + continue + + if isinstance(value, list): + specific_pkgs.update(value) + + elif isinstance(value, dict): + specific_pkgs.update(value.get(test_path.name) or []) + + else: + pass + + return specific_pkgs, common_pkgs + + +def uv_uninstall(pkgs): + if not pkgs: + return + + print("--- Uninstalling deps with uv:") + for p in pkgs: + print(" -", p) + + for p in pkgs: + cmd = ["uv", "pip", "uninstall", p] + try: + subprocess.check_call(cmd, shell=False) + except Exception as e: + print(f"--- Unnstall failed: {e}") + + +if __name__ == "__main__": + raw_name = sys.argv[1].removeprefix("tests/").removesuffix(".py") + test_path = resolve_test_path(raw_name) + + with open(os.path.join(base_dir, "blacklist.yaml")) as f: + deps = yaml.safe_load(f) + + specific_pkgs, common_pkgs = collect_pkgs(test_path, deps) + + uv_uninstall(sorted(specific_pkgs)) + + uv_uninstall(sorted(common_pkgs)) diff --git a/.github/workflows/compatibility.yml b/.github/workflows/compatibility.yml new file mode 100644 index 000000000..2637d5ed3 --- /dev/null +++ b/.github/workflows/compatibility.yml @@ -0,0 +1,54 @@ +name: Test Compatibility + +on: + push: + paths: + - pyproject.toml + workflow_dispatch: + +permissions: + contents: read + +jobs: + prepare-setuptools: + runs-on: ubuntu-latest + outputs: + versions: ${{ steps.parser.outputs.versions || '[]' }} + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.14" + + - name: Generate version matrix + id: parser + run: | + python -m pip install --upgrade requests packaging + versions=$(python .github/scripts/ci_loop_versions.py setuptools ">=77.0.1,<83") + echo "versions=$versions" >> "$GITHUB_OUTPUT" + + check-setuptools: + needs: prepare-setuptools + if: needs.prepare-setuptools.outputs.versions != '[]' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ${{ fromJSON(needs.prepare-setuptools.outputs.versions) }} + + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.14" + cache: pip + + - name: Install package with selected setuptools + run: | + python -m pip install --upgrade pip + python -m pip install . "setuptools==${{ matrix.version }}" + + - name: Show versions + run: | + python --version + python -m pip show setuptools \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 797f03243..4410e0b37 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -5,6 +5,7 @@ run-name: "${{ github.event.inputs.title }}" defaults: run: shell: bash -le {0} + on: release: types: [ published ] @@ -27,10 +28,6 @@ on: description: 'PR Number' required: false type: number - max-parallel: - description: 'max parallel jobs' - required: false - default: '10' upload_release: description: 'upload to release (it only works with a tag ref)' type: boolean @@ -46,34 +43,11 @@ on: type: boolean required: false default: false -# for github limits, test only -# cuda-version: -# description: 'cuda version(128)' -# required: false -# default: '' -# torch-version: -# description: 'torch version(2.8.0)' -# required: false -# default: '' -# python-version: -# description: 'python version(313)' -# required: false -# default: '' env: - CUDA_DEVICE_ORDER: PCI_BUS_ID RUNNER: 10.0.13.31 - TORCH_CUDA_ARCH_LIST: '8.0 8.6 8.9 9.0 12.0' - CUDA_ARCH_LIST: '8.0 8.6 8.9 9.0 12.0' - RELEASE_MODE: 1 - CI: 1 - GPTQMODEL_FORCE_BUILD: 1 repo: ${{ github.event.inputs.repo || github.repository }} ref: ${{ github.event.inputs.ref || github.ref }} - MAX_JOBS: 8 - GPTQMODEL_BUILD_QQQ: 0 - GPTQMODEL_BUILD_EORA: 0 - GPTQMODEL_BUILD_EXLLAMA_V1: 0 concurrency: group: ${{ github.event.inputs.ref || github.ref }}-workflow-release @@ -86,8 +60,7 @@ jobs: image: modelcloud/gptqmodel:alpine-ci-v1 outputs: ip: ${{ steps.get_ip.outputs.ip }} - task_list: ${{ steps.assign.outputs.task_list }} - max-parallel: ${{ steps.get_ip.outputs.max-parallel }} + run_id: ${{ steps.get_ip.outputs.run_id }} if: ${{ inputs.github_vm == false }} steps: - name: Checkout Codes @@ -101,349 +74,14 @@ jobs: echo "event name: ${{ github.event_name }}" echo "repo: ${{ env.repo }}" echo "ref: ${{ env.ref }}" - echo "max-parallel: ${{ inputs.max-parallel }}" echo "upload_release: ${{ inputs.upload_release }}" echo "upload_pypi: ${{ inputs.upload_pypi }}" + - name: Select server id: get_ip run: | echo "ip=${RUNNER}" >> "$GITHUB_OUTPUT" - echo "GPU_IP=${RUNNER}" >> $GITHUB_ENV - echo "ip: $ip" - max_p=${{ github.event.inputs.max-parallel }} - max_p="{\"size\": ${max_p:-10}}" - echo "max-parallel=$max_p" >> "$GITHUB_OUTPUT" - echo "max-parallel=$max_p" - release: - strategy: - fail-fast: false - max-parallel: ${{ fromJson(needs.check-vm.outputs.max-parallel).size || 10 }} - matrix: - include: - # pytorch 2.9.1 - # cuda 130 - # not supported now, unknow error - # - cuda: 130 - # torch: 2.9.1 - # python: 314t - - cuda: 130 - torch: 2.9.1 - python: 314 - - cuda: 130 - torch: 2.9.1 - python: 313t - - cuda: 130 - torch: 2.9.1 - python: 313 - - cuda: 130 - torch: 2.9.1 - python: 312 - - cuda: 130 - torch: 2.9.1 - python: 311 - - cuda: 130 - torch: 2.9.1 - python: 310 - - # cuda 128 - - cuda: 128 - torch: 2.9.1 - python: 314t - - cuda: 128 - torch: 2.9.1 - python: 314 - - cuda: 128 - torch: 2.9.1 - python: 313t - - cuda: 128 - torch: 2.9.1 - python: 313 - - cuda: 128 - torch: 2.9.1 - python: 312 - - cuda: 128 - torch: 2.9.1 - python: 311 - - cuda: 128 - torch: 2.9.1 - python: 310 - - # pytorch 2.9 - # cuda 13 - - cuda: 130 - torch: 2.9.0 - python: 314 - - cuda: 130 - torch: 2.9.0 - python: 314t - - cuda: 130 - torch: 2.9.0 - python: 313t - - cuda: 130 - torch: 2.9.0 - python: 313 - - cuda: 130 - torch: 2.9.0 - python: 312 - - cuda: 130 - torch: 2.9.0 - python: 311 - - cuda: 130 - torch: 2.9.0 - python: 310 - # cuda 128 - - cuda: 128 - torch: 2.9.0 - python: 314t - - cuda: 128 - torch: 2.9.0 - python: 314 - - cuda: 128 - torch: 2.9.0 - python: 313t - - cuda: 128 - torch: 2.9.0 - python: 313 - - cuda: 128 - torch: 2.9.0 - python: 312 - - cuda: 128 - torch: 2.9.0 - python: 311 - - cuda: 128 - torch: 2.9.0 - python: 310 - # cuda 126 - - cuda: 126 - torch: 2.9.0 - python: 314t - - cuda: 126 - torch: 2.9.0 - python: 314 - - cuda: 126 - torch: 2.9.0 - python: 313t - - cuda: 126 - torch: 2.9.0 - python: 313 - - cuda: 126 - torch: 2.9.0 - python: 312 - - cuda: 126 - torch: 2.9.0 - python: 311 - - cuda: 126 - torch: 2.9.0 - python: 310 - # not support yet, maybe later - # pytorch 2.8 - # cuda 13 - # - cuda: 130 - # torch: 2.8.0 - # python: 314 - # - cuda: 130 - # torch: 2.8.0 - # python: 313t - # - cuda: 130 - # torch: 2.8.0 - # python: 313 - # - cuda: 130 - # torch: 2.8.0 - # python: 312 - # - cuda: 130 - # torch: 2.8.0 - # python: 311 - # - cuda: 130 - # torch: 2.8.0 - # python: 310 - # cuda 128 - - cuda: 128 - torch: 2.8.0 - python: 313t - - cuda: 128 - torch: 2.8.0 - python: 313 - - cuda: 128 - torch: 2.8.0 - python: 312 - - cuda: 128 - torch: 2.8.0 - python: 311 - - cuda: 128 - torch: 2.8.0 - python: 310 - - cuda: 126 - torch: 2.8.0 - python: 313t - - cuda: 126 - torch: 2.8.0 - python: 313 - - cuda: 126 - torch: 2.8.0 - python: 312 - - cuda: 126 - torch: 2.8.0 - python: 311 - - cuda: 126 - torch: 2.8.0 - python: 310 - - - runs-on: [ self-hosted, xeon5 ] - needs: - - check-vm - container: - image: 10.0.13.31:5000/nvidia/cuda:${{ matrix.cuda }}-ubuntu22.04_0206 - volumes: - - /monster/ci/env/entrypoint.sh:/entrypoint.sh - - /monster/ci/env/entrypoint.sh:/etc/profile.d/01-entrypoint.sh - - /monster/ci/uv:/opt/uv - - /monster/ci/env:/opt/env - steps: - - name: Check matrix - run: | - # // for github limits, test only - # // shouldRun=${{ (inputs['cuda-version'] == '' || matrix.cuda == inputs['cuda-version']) && (inputs['torch-version'] == '' || matrix.torch == inputs['torch-version']) && (inputs['python-version'] == ''|| matrix.python == inputs['python-version']) }} - shouldRun=true - if [ "$shouldRun" == "true" ]; then - echo "SHOULD_RUN=1" >> $GITHUB_ENV - else - echo "SHOULD_RUN=0" >> $GITHUB_ENV - fi - - - name: Checkout Codes - if: env.SHOULD_RUN == 1 - uses: actions/checkout@v6 - with: - repository: ${{ env.repo }} - ref: ${{ env.ref }} - - - name: Fetch PR by number - if: ${{ github.event.inputs.pr_number != 0 && env.SHOULD_RUN == 1 }} - run: | - PR_NUMBER=${{ github.event.inputs.pr_number }} - echo "pr number $PR_NUMBER" - git config --global --add safe.directory $(pwd) - git fetch origin pull/${PR_NUMBER}/head:pr-${PR_NUMBER} - git checkout pr-${PR_NUMBER} - - - name: Activate uv env - run: | - if [[ "${{ matrix.cuda }}" -lt 128 ]]; then # CUDA >= 12.8 supports 12.0 (5090) - echo "CUDA_ARCH_LIST=8.0 8.6 8.9 9.0" >> $GITHUB_ENV - echo "TORCH_CUDA_ARCH_LIST=8.0 8.6 8.9 9.0" >> $GITHUB_ENV - fi - python_version=${{ matrix.python }} - if [[ "$python_version" != *"."* ]]; then - python_version="${python_version/3/3.}" - fi - - export UV_PYTHON=$python_version - echo "UV_PYTHON=$python_version" >> "$GITHUB_ENV" - echo "UV_TORCH_BACKEND=cu${{ matrix.cuda }}" >> "$GITHUB_ENV" - - env_name="cu${{ matrix.cuda }}_torch${{ matrix.torch }}_py${python_version}_release" - /opt/uv/setup_uv_venv.sh $env_name - - - name: Setup uv env - if: env.SHOULD_RUN == 1 - run: | - /opt/env/init_compiler_torch_only.sh ${{ matrix.cuda }} ${{ matrix.torch }} $UV_PYTHON - - - name: Print uv env - if: env.SHOULD_RUN == 1 - run: | - echo "::group::uv python list" - uv python list - ls -ahl /opt/uv/venvs - echo "::endgroup::" - - echo "== python ==" - python --version - which python - which pip || true - - echo "== nvcc ==" - nvcc --version - - echo "::group::pip list" - uv pip list - echo "::endgroup::" - - echo "== torch ==" - uv pip show torch || true - - echo "::group::project files" - ls -ahl - echo "::endgroup::" - - echo "::group::git status" - git config --global --add safe.directory $(pwd) - git status - echo "::endgroup::" - - - name: Setup Compile env - if: env.SHOULD_RUN == 1 - run: uv pip install setuptools wheel build -U - - - name: Compile - if: env.SHOULD_RUN == 1 - run: | - set -e - echo "::group::First Run" - if ! python -m build -v --no-isolation; then - echo "::endgroup::" - - echo "::group::Retry" - python setup.py bdist_wheel - echo "::endgroup::" - else - echo "::endgroup::" - fi - - - name: Test install - if: env.SHOULD_RUN == 1 - run: | - ls -ahl dist - whl=$(ls -t dist/*.whl | head -n 1 | xargs basename) - echo "WHL_NAME=$whl" >> $GITHUB_ENV - - [ $(stat -c%s "dist/$whl") -lt 104857600 ] && echo "wheel size < 100M" && exit 1 || echo "$whl size check passed." - - if [ "${{ matrix.python }}" != "313t" ]; then # twine doesn't support python 3.13 yet - uv pip install twine - twine check dist/$whl - fi - - - name: Upload wheel - if: env.SHOULD_RUN == 1 - continue-on-error: true - run: | - sha256=$(sha256sum dist/${{ env.WHL_NAME }}) - - DIR=/opt/dist/${{ needs.check-vm.outputs.run_id }} - [ -d $DIR ] || mkdir -p $DIR - cp dist/${{ env.WHL_NAME }} $DIR/ - echo "UPLOADED=1" >> $GITHUB_ENV - - - name: Upload artifact - if: env.SHOULD_RUN == 1 - uses: actions/upload-artifact@v7 - continue-on-error: ${{ env.UPLOADED == '1' }} - with: - overwrite: false - name: ${{ env.WHL_NAME }} - path: dist/${{ env.WHL_NAME }} - - - name: Upload binaries to release - uses: svenstaro/upload-release-action@v2 - if: env.SHOULD_RUN == 1 && (github.event_name == 'release' || github.event.inputs.upload_release == 'true') && !cancelled() - with: - repo_name: ${{ env.repo }} - tag: ${{ env.ref }} - file: dist/${{ env.WHL_NAME }} - file_glob: true - overwrite: true + echo "run_id=${{ github.run_id }}" >> "$GITHUB_OUTPUT" release-source: permissions: @@ -451,13 +89,12 @@ jobs: runs-on: [ self-hosted, xeon5 ] needs: - check-vm + if: ${{ inputs.github_vm == false }} container: - image: ${{ needs.check-vm.outputs.ip }}:5000/nvidia/cuda:128-ubuntu22.04_0206 + image: ${{ needs.check-vm.outputs.ip }}:5000/nvidia/cuda:130-ubuntu24.04_0325 volumes: - /monster/ci/env/entrypoint.sh:/etc/profile.d/01-entrypoint.sh - /monster/ci/uv:/opt/uv - env: - RELEASE_MODE: 0 steps: - name: Checkout Codes uses: actions/checkout@v6 @@ -465,56 +102,65 @@ jobs: repository: ${{ env.repo }} ref: ${{ env.ref }} - - name: Print Env + - name: Trust workspace for git + run: | + git config --global --add safe.directory "$(pwd)" + + - name: Print env run: | export UV_PYTHON=3.14 echo "UV_PYTHON=3.14" >> "$GITHUB_ENV" - env_name="cu128_torch2.8.0_py314_release_source" + env_name="gptqmodel_py314_release_source" /opt/uv/setup_uv_venv.sh $env_name - name: Fetch PR by number if: ${{ github.event.inputs.pr_number != 0 }} run: | PR_NUMBER=${{ github.event.inputs.pr_number }} - echo "pr number $PR_NUMBER" - git config --global --add safe.directory $(pwd) + git config --global --add safe.directory "$(pwd)" git fetch origin pull/${PR_NUMBER}/head:pr-${PR_NUMBER} git checkout pr-${PR_NUMBER} - - name: Compile + - name: Build sdist run: | - uv pip install build wheel twine setuptools torch ninja -U - python -m build --no-isolation --sdist + python -m pip install -U build twine setuptools + python -m build --sdist - name: Check dist run: | ls -ahl dist - whl=$(ls -t dist/*.gz | head -n 1 | xargs basename) - echo "WHL_NAME=$whl" >> $GITHUB_ENV - twine check dist/$whl + pkg=$(ls -t dist/*.tar.gz | head -n 1 | xargs basename) + echo "PKG_NAME=$pkg" >> "$GITHUB_ENV" + twine check "dist/$pkg" + + - name: test installation + run: | + uv venv local_uv_env + source local_uv_env/bin/activate + uv pip install "dist/${{ env.PKG_NAME }}" torch - name: Upload to local continue-on-error: true run: | - sha256=$(sha256sum dist/${{ env.WHL_NAME }}) + sha256sum "dist/${{ env.PKG_NAME }}" DIR=/opt/dist/${{ needs.check-vm.outputs.run_id }} - [ -d $DIR ] || mkdir -p $DIR - cp dist/${{ env.WHL_NAME }} $DIR/ - echo "UPLOADED=1" >> $GITHUB_ENV + [ -d "$DIR" ] || mkdir -p "$DIR" + cp "dist/${{ env.PKG_NAME }}" "$DIR/" + echo "UPLOADED=1" >> "$GITHUB_ENV" - name: Upload to artifact uses: actions/upload-artifact@v7 continue-on-error: ${{ env.UPLOADED == '1' }} with: - name: ${{ env.WHL_NAME }} - path: dist/${{ env.WHL_NAME }} + name: ${{ env.PKG_NAME }} + path: dist/${{ env.PKG_NAME }} - name: Upload package to release uses: svenstaro/upload-release-action@v2 if: (github.event_name == 'release' || github.event.inputs.upload_release == 'true') && !cancelled() with: - file: dist/${{ env.WHL_NAME }} + file: dist/${{ env.PKG_NAME }} tag: ${{ env.ref }} file_glob: true overwrite: true @@ -530,12 +176,12 @@ jobs: while [ "$status" -lt 0 ]; do status=$(curl -s "http://${RUNNER}/gpu/ci/confirm?id=${{ github.run_id }}×tamp=$timestamp") if [ "$status" == "2" ]; then - echo "PYPI_RELEASE_CONFIRMATION=$status" >> $GITHUB_ENV + echo "PYPI_RELEASE_CONFIRMATION=$status" >> "$GITHUB_ENV" elif [ "$status" -lt 0 ]; then sleep 5 else echo "release has been confirmed" - echo "PYPI_RELEASE_CONFIRMATION=$status" >> $GITHUB_ENV + echo "PYPI_RELEASE_CONFIRMATION=$status" >> "$GITHUB_ENV" fi done @@ -545,13 +191,10 @@ jobs: TWINE_USERNAME: "__token__" TWINE_PASSWORD: ${{ secrets.PYPI_KEY }} run: | - python -m twine upload dist/*gz + python -m twine upload dist/*.tar.gz release-source-github: runs-on: ubuntu-24.04 - env: - RELEASE_MODE: 0 - BUILD_CUDA_EXT: '0' if: ${{ inputs.github_vm == true }} steps: - name: Checkout Codes @@ -560,47 +203,49 @@ jobs: repository: ${{ env.repo }} ref: ${{ env.ref }} + - name: Trust workspace for git + run: | + git config --global --add safe.directory "$(pwd)" + - name: Fetch PR by number if: ${{ github.event.inputs.pr_number != 0 }} run: | PR_NUMBER=${{ github.event.inputs.pr_number }} - echo "pr number $PR_NUMBER" - git config --global --add safe.directory $(pwd) + git config --global --add safe.directory "$(pwd)" git fetch origin pull/${PR_NUMBER}/head:pr-${PR_NUMBER} git checkout pr-${PR_NUMBER} - uses: actions/setup-python@v6 with: - python-version: 3.14 + python-version: '3.14' cache: 'pip' - name: Install requirements run: | - uv pip install torch twine ninja setuptools setuptools-scm[toml]>=8.0 --system + python -m pip install -U build twine setuptools - - name: Compile + - name: Build sdist run: | - python -m build --no-isolation --sdist - + python -m build --sdist - name: Check dist run: | ls -ahl dist - whl=$(ls -t dist/*.gz | head -n 1 | xargs basename) - echo "WHL_NAME=$whl" >> $GITHUB_ENV - twine check dist/$whl + pkg=$(ls -t dist/*.tar.gz | head -n 1 | xargs basename) + echo "PKG_NAME=$pkg" >> "$GITHUB_ENV" + twine check "dist/$pkg" - name: Upload to artifact uses: actions/upload-artifact@v7 with: - name: ${{ env.WHL_NAME }} - path: dist/${{ env.WHL_NAME }} + name: ${{ env.PKG_NAME }} + path: dist/${{ env.PKG_NAME }} - name: Upload package to release uses: svenstaro/upload-release-action@v2 if: (github.event_name == 'release' || github.event.inputs.upload_release == 'true') && !cancelled() with: - file: dist/${{ env.WHL_NAME }} + file: dist/${{ env.PKG_NAME }} tag: ${{ env.ref }} file_glob: true overwrite: true @@ -616,4 +261,4 @@ jobs: TWINE_USERNAME: "__token__" TWINE_PASSWORD: ${{ secrets.PYPI_KEY }} run: | - python -m twine upload dist/*gz + python -m twine upload dist/*.tar.gz diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 064583856..398d2463b 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -26,10 +26,6 @@ on: description: 'PR Number' required: false type: number - test_names: - description: 'Input Test(s) to Run (default all)' - required: false - default: '' test_regex: description: 'Regex to filter test files' required: false @@ -41,12 +37,15 @@ on: max-parallel: description: 'Parallel jobs' required: false - default: '10' - exclusive-gpu: - description: 'One Test Per GPU' - type: boolean + default: '4' + model-test-mode: + description: 'Model compat test mode' + type: choice required: false - default: true + default: 'fast' + options: + - 'fast' + - 'slow' server: description: 'Wheel Build Server' type: choice @@ -58,30 +57,34 @@ on: env: CUDA_DEVICE_ORDER: PCI_BUS_ID CUDA_VISIBLE_DEVICES: 0 - TORCH_CUDA_ARCH_LIST: '8.6 8.9 9.0 12.0' + # CI can allocate A100-class devices (sm_80), so the shared JIT arch list + # must include 8.0 in addition to the newer targets. + TORCH_CUDA_ARCH_LIST: '8.0 8.6 8.9 9.0 12.0' PYTORCH_ALLOC_CONF: 'expandable_segments:True' - MAX_JOBS: 4 # compile concurrency RUNNER: 10.0.13.31 XEON5: 10.0.14.249 - UV_INDEX_URL: http://10.0.14.249/simple - CUDA_VERSION: 131 + LOGBAR_ANIMATION: '0' + CUDA_VERSION: 130 UV_TORCH_BACKEND: cu130 - TORCH_VERSION: 2.10.0 + TORCH_VERSION: 2.11.0 # vllm doesn't support 3.14 PYTHON_VERSION: 3.13 UV_PYTHON: python3.13 # PYTHON_GIL: 0 // test libs don't support yet - BUILD_QQQ: 1 - BUILD_EORA: 1 - GPTQMODEL_BUILD_EXLLAMA_V1: 1 - GPTQMODEL_BUILD_EORA: 1 IGNORED_TEST_FILES: "test_tgi.py,test_gptneox.py,models/test_mixtral.py,models/test_phi_3_moe.py,test_bits_new.py,models/test_internlm.py,models/test_internlm2_5.py,models/test_xverse.py" - GPTQMODEL_FORCE_BUILD: 1 + MODEL_TEST_MODE: ${{ github.event.inputs['model-test-mode'] || 'fast' }} + MODEL_TEST_GPU_COUNT: '1' + DEFUSER_GIT_URL: git+https://github.com/modelcloud/Defuser.git + PYPCRE_GIT_URL: git+https://github.com/modelcloud/PyPcre.git + TOKENICER_GIT_URL: git+https://github.com/modelcloud/Tokenicer.git + LOGBAR_GIT_URL: git+https://github.com/modelcloud/LogBar.git + EVALUTION_GIT_URL: git+https://github.com/modelcloud/Evalution.git repo: ${{ github.event.inputs.repo || github.repository }} ref: ${{ github.event.inputs.ref || github.ref }} + HF_TOKEN: ${{ secrets.HF_TOKEN }} concurrency: - group: ${{ github.event.inputs.ref || github.ref }}-workflow-unit-tests-${{ github.event.inputs.test_names }} + group: ${{ github.event.inputs.ref || github.ref }}-workflow-unit-tests #-${{ github.event.inputs.test_names }} cancel-in-progress: true permissions: @@ -89,7 +92,7 @@ permissions: jobs: check-vm: - runs-on: ubuntu-24.04 + runs-on: ubuntu-latest outputs: ip: ${{ steps.get_ip.outputs.ip }} run_id: ${{ steps.get_ip.outputs.run_id }} @@ -101,11 +104,9 @@ jobs: echo "repo: ${{ env.repo }}" echo "ref: ${{ env.ref }}" echo "artifact_id: ${{ github.event.inputs.artifact_id }}" - echo "test_names: ${{ github.event.inputs.test_names }}" - echo "exclusive-gpu: ${{ github.event.inputs['exclusive-gpu'] }}" echo "selected server: ${{ github.event.inputs.server }}" - - name: Select server + - name: Set server id: get_ip run: | echo "ip=$RUNNER" >> "$GITHUB_OUTPUT" @@ -126,9 +127,10 @@ jobs: echo "max-parallel=$max_p" list-test-files: - runs-on: ubuntu-24.04 + runs-on: ubuntu-latest outputs: torch-files: ${{ steps.files.outputs.torch-files }} + model-files: ${{ steps.files.outputs.model-files }} m4-files: ${{ steps.files.outputs.m4-files }} steps: @@ -152,27 +154,27 @@ jobs: run: | test_files=$(python3 .github/scripts/list_test_files.py \ --ignored-test-files "$IGNORED_TEST_FILES" \ - --test-names "${{ github.event.inputs.test_names }}" \ --test-regex "${{ github.event.inputs.test_regex }}") - IFS='|' read -r torch_test_files mlx_files <<< "$test_files" + IFS='|' read -r torch_test_files model_test_files mlx_files <<< "$test_files" echo "torch-files=$torch_test_files" >> "$GITHUB_OUTPUT" + echo "model-files=$model_test_files" >> "$GITHUB_OUTPUT" echo "m4-files=$mlx_files" >> "$GITHUB_OUTPUT" echo "Test files: $test_files" echo "Torch Test files: $torch_test_files" + echo "Model Compat Test files: $model_test_files" echo "MLX Test files: $mlx_files" echo "Ignored Test files: $IGNORED_TEST_FILES" - build: - runs-on: ${{ fromJSON(github.event.inputs.server || '["self-hosted", "xeon5"]') }} + test: needs: - - check-vm - list-test-files - if: needs.list-test-files.outputs.torch-files != '[]' + - check-vm + runs-on: [ self-hosted, xeon5 ] container: - image: ${{ needs.check-vm.outputs.ip }}:5000/nvidia/cuda:${{ needs.check-vm.outputs.cuda_version }}-ubuntu22.04_0206 + image: ${{ needs.check-vm.outputs.ip }}:5000/nvidia/cuda:${{ needs.check-vm.outputs.cuda_version }}-ubuntu24.04_0325 options: --device /dev/dri --ipc=host --runtime=nvidia --gpus all volumes: - /monster/ci/env/entrypoint.sh:/entrypoint.sh @@ -181,10 +183,23 @@ jobs: - /monster/ci/models:/monster/data/model - /monster/ci/dataset:/monster/data/model/dataset - /monster/ci/huggingface:/github/home/.cache/huggingface - - /monster/ci/uv:/opt/uv + # - /monster/ci/uv:/opt/uv + - /github/workspace/uv:/opt/uv + - /monster/ci/uv/python:/opt/uv/python + - /monster/ci/uv/cache/python:/opt/uv/cache/python + - /monster/ci/uv/setup_uv_venv.sh:/opt/uv/setup_uv_venv.sh + - /monster/ci/uv/uv:/opt/uv/uv + - /monster/ci/uv/uvx:/opt/uv/uvx + - /monster/ci/uv/env:/opt/uv/env + - /monster/ci/uv/uv.toml:/opt/uv/uv.toml - /monster/ci/env:/opt/env - /monster/ci/dist:/opt/dist - + strategy: + fail-fast: false + max-parallel: ${{ fromJson(needs.check-vm.outputs['max-parallel']).size || 20 }} + matrix: + test_script: ${{ fromJSON(needs.list-test-files.outputs.torch-files) }} + if: always() && !cancelled() && needs.list-test-files.outputs.torch-files != '[]' # || github.event.inputs.artifact_id != '' steps: - name: Checkout Codes uses: actions/checkout@v6 @@ -201,163 +216,249 @@ jobs: git fetch origin pull/${PR_NUMBER}/head:pr-${PR_NUMBER} git checkout pr-${PR_NUMBER} - - name: Print env + - name: decompress uv cache + continue-on-error: true run: | - echo PATH=$PATH - echo UV_INSTALL_DIR=$UV_INSTALL_DIR - echo UV_PYTHON_BIN_DIR=$UV_PYTHON_BIN_DIR - echo UV_PYTHON_INSTALL_DIR=$UV_PYTHON_INSTALL_DIR - echo UV_PYTHON_CACHE_DIR=$UV_PYTHON_CACHE_DIR - echo UV_CACHE_DIR=$UV_CACHE_DIR - echo VENV_ROOT=$VENV_ROOT - ls /root -ahl + if [ -f /opt/dist/uv.tar.xz ]; then + TAR_FILE="/opt/dist/uv.tar.xz" + LAST_FILE="/opt/uv/cache/lastmodified" + + # Get modification time of tar.xz file (epoch seconds) + TAR_MTIME=$(stat -c %Y "$TAR_FILE") + + # Read last recorded modification time if file exists + if [ -f "$LAST_FILE" ]; then + LAST_MTIME=$(cat "$LAST_FILE") + else + LAST_MTIME=0 + fi + + # Compare timestamps to decide whether to decompress + if [ "$TAR_MTIME" = "$LAST_MTIME" ]; then + echo "uv.tar.xz unchanged, skip decompress" + else + echo "decompressing uv.tar.xz..." + + # Prepare temporary directory + mkdir -p /opt/uv/cache/tmp + rm -rf /opt/uv/cache/tmp/* + + # Extract archive + tar -xJf "$TAR_FILE" -C /opt/uv/cache/tmp + + # Replace existing uv directory + rm -rf /opt/uv/cache/uv + mv /opt/uv/cache/tmp/uv /opt/uv/cache/uv + + # Record latest modification time + echo "$TAR_MTIME" > "$LAST_FILE" + + echo "done!" + ls -ahl /opt/uv/cache + echo "==========" + ls -ahl /opt/uv/cache/uv + fi + fi - name: Activate uv env run: | - env_name="cu${{ needs.check-vm.outputs.cuda_version }}_torch${{ env.TORCH_VERSION }}_py${{ env.PYTHON_VERSION }}_build" + echo "-- loading unit test's config --" + source /opt/uv/setup_uv_venv.sh unit_test_env + + config_json="$(python3 .github/scripts/parse_test_config.py \ + --group tests \ + --test-name "${{ matrix.test_script }}")" + + py="$(printf '%s' "$config_json" | python3 -c 'import json, sys; print(json.load(sys.stdin)["py"])')" + gpu="$(printf '%s' "$config_json" | python3 -c 'import json, sys; print(json.load(sys.stdin)["gpu"])')" + + echo "PYTHON_VERSION=$py" >> "$GITHUB_ENV" + echo "GPU_COUNT=$gpu" >> "$GITHUB_ENV" + + echo "using py=$py gpu=$gpu for test ${{ matrix.test_script }}" + echo "-- loaded --" + + echo "-- setting up env --" + env_name="gptqmodel_test_cu${{ needs.check-vm.outputs.cuda_version }}_torch${{ env.TORCH_VERSION }}_py${py}_${{ matrix.test_script }}" /opt/uv/setup_uv_venv.sh $env_name + echo "-- set --" - name: Setup uv env run: | - bash /opt/env/init_compiler_torch_only.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} ${{ env.PYTHON_VERSION }} - - echo "::group::uv python list" - uv python list - ls -ahl /opt/uv/venvs - echo "::endgroup::" - - echo "== python ==" - python --version + python -V which python which pip || true - echo "== nvcc ==" - nvcc --version + echo "--- setting env... cuda=${{ needs.check-vm.outputs.cuda_version }} torch=${{ env.TORCH_VERSION }} python=${{ env.PYTHON_VERSION }}" + bash /opt/env/init_compiler_no_env.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} ${{ env.PYTHON_VERSION }} - echo "::group::pip list" - uv pip list - echo "::endgroup::" + echo "" + echo "" - echo "== torch ==" - uv pip show torch || true + echo "--- installing required deps..." + python .github/scripts/install_deps.py ${{ matrix.test_script }} - - name: Compress dir - run: | - mkdir dist || true - rm -rf dist/* || true - tar -zcf ../gptqmodel_source.tar.gz ./ - mv ../gptqmodel_source.tar.gz dist/ - sha256=$(sha256sum dist/gptqmodel_source.tar.gz) - echo "hash=$sha256" - echo "SOURCE_HASH=$sha256" >> $GITHUB_ENV - + echo "" + echo "" - # - name: Upload source to local + echo "--- uninstalling required deps..." + python .github/scripts/uninstall_deps.py ${{ matrix.test_script }} + + # - name: Install Evalution + # run: | + # uv pip install git+https://x-access-token:${{ secrets.REPO_TOKEN }}@github.com/ModelCloud/Evalution.git + # - name: Install requirements + # run: | + # bash -c "$(curl -L http://${RUNNER}/scripts/env/init_compiler_no_env.sh)" @ ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} $python_version${{ env.PYTHON_VERSION }} + + # - name: Download source from local # continue-on-error: true # run: | - ## curl -s -F "runid=${{ github.run_id }}" -F "repo=${{ env.repo }}" -F "ref=${{ env.ref }}" -F "sha256=${{ env.SOURCE_HASH }}" -F "file=@dist/gptqmodel_source.tar.gz" http://$RUNNER/gpu/whl/upload - # DIR=/opt/dist/${{ needs.check-vm.outputs.run_id }} - # [ -d $DIR ] || mkdir -p $DIR - # cp dist/gptqmodel_source.tar.gz $DIR/ + # curl -s -O http://$RUNNER/whl/${{ env.repo }}/${{ github.run_id }}/gptqmodel_source.tar.gz + # ls -ahl . + # sha256=$(sha256sum $file_name) + # echo "sha256=$sha256" + # echo "SOURCE_DOWNLOADED=1" >> $GITHUB_ENV - - name: Upload source to github artifact - uses: actions/upload-artifact@v7 - with: - name: source - path: dist/gptqmodel_source.tar.gz + # - name: Download source from github + # if: env.SOURCE_DOWNLOADED == '' && !cancelled() + # uses: actions/download-artifact@v8 + # with: + # name: source + # path: dist + # run-id: ${{ github.run_id }} + + # - name: Uncompress source + # continue-on-error: true + # run: | + # find . -mindepth 1 ! -name "gptqmodel_source.tar.gz" -exec rm -rf {} + + # ls -ahl . + # tar -zxf gptqmodel_source.tar.gz - - name: Compile - if: github.event.inputs.artifact_id == '' && !cancelled() - timeout-minutes: 45 + - name: Install package from source run: | - set -euo pipefail - - echo "::group::python version" + uv pip uninstall gptqmodel || true + # make sure torch & torchvision are corresponded + /opt/env/install_torch.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} + uv pip install -r requirements.txt -i http://$RUNNER/simple/ --trusted-host $RUNNER --extra-index-url https://pypi.org/simple + + echo "===== install package from checkout =====" + uv pip install . -i http://$RUNNER/simple/ --trusted-host $RUNNER --extra-index-url https://pypi.org/simple + + echo "===== install ModelCloud git deps =====" + for pkg in defuser pypcre tokenicer logbar evalution; do + uv pip uninstall "$pkg" || true + done + for url in \ + "$DEFUSER_GIT_URL" \ + "$PYPCRE_GIT_URL" \ + "$TOKENICER_GIT_URL" \ + "$LOGBAR_GIT_URL" \ + "$EVALUTION_GIT_URL" + do + uv pip install "$url" + done + + # deps may reinstalled torch + /opt/env/install_torch.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} + + - name: Print uv env + run: | + echo "::group::uv python list" + uv python list + echo "::endgroup::" + which python which pip || true - python --version + python -V + + echo "::group::uv python list" + uv python list echo "::endgroup::" - log_dir=/opt/dist/${{ github.run_id }}/logs - [ -d $log_dir ] || mkdir -p $log_dir - log_file=$log_dir/build.log + echo "== nvcc ==" + nvcc --version - build_once() { - echo "::group::compile logs" - python setup.py bdist_wheel 2>&1 | tee $log_file - echo "::endgroup::" - } - - echo "==> Attempt #1 (no extra env)" - if build_once; then - echo "Build succeeded on attempt #1" - exit 0 - fi - - echo "==> Attempt #1 failed, retrying with env vars..." - export VERBOSE=1 - export CMAKE_VERBOSE_MAKEFILE=ON - export TORCH_EXTENSIONS_DIR=/tmp/torch_extensions - export MAX_JOBS=1 - export NVCC_THREADS=1 - - echo "==> Attempt #2 (with env)" - build_once + echo "== torch ==" + uv pip show torch || true - - name: check wheel - if: github.event.inputs.artifact_id == '' && !cancelled() - run: | - ls -ahl dist - whl=$(ls -t dist/*.whl | head -n 1 | xargs basename) - sha256=$(sha256sum dist/$whl) - echo "hash=$sha256" + echo "::group::project files" + ls -ahl + echo "::endgroup::" - echo "WHL_HASH=$sha256" >> $GITHUB_ENV - echo "WHL_NAME=$whl" >> $GITHUB_ENV + echo "== torch ==" + uv pip show torch - twine check dist/$whl - [ $(stat -c%s "dist/$whl") -lt 104857600 ] && echo "wheel size < 100M" && exit 1 || echo "$whl size check passed." + echo "::group::pip list" + uv pip list + echo "::endgroup::" - - name: Upload wheel to local - if: github.event.inputs.artifact_id == '' && !cancelled() - continue-on-error: true + - name: Find suitable GPU + if: ${{ !contains(matrix.test_script, 'ipex') && !contains(matrix.test_script, 'xpu') && !cancelled() }} run: | - WHEEL="$(readlink -f dist/*.whl)" - SHA256="$(sha256sum "$WHEEL" | awk '{print $1}')" + python .github/scripts/allocate_gpu.py \ + --base-url "http://${XEON5}" \ + --run-id "${{ github.run_id }}" \ + --test "${{ matrix.test_script }}" \ + --runner "${RUNNER_NAME:-unknown}" \ + --count "${{ env.GPU_COUNT }}" - DIR=/opt/dist/${{ github.run_id }} - [ -d $DIR ] || mkdir -p $DIR - cp $WHEEL $DIR/ + + - name: Run tests + run: | + extra_args=() + if [[ "${{ matrix.test_script }}" == *ipex* ]]; then + extra_args+=(--clear-cuda) + fi + if [[ "${{ matrix.test_script }}" == *xpu* ]]; then + extra_args+=(--clear-cuda --xpu-mode) + fi + python .github/scripts/run_tests.py \ + --base-url "http://${XEON5}" \ + --run-id "${{ github.run_id }}" \ + --test-script "${{ matrix.test_script }}" \ + --runner "${RUNNER_NAME}" \ + --gpu-id "${{ env.CUDA_VISIBLE_DEVICES }}" \ + "${extra_args[@]}" - name: Check log - if: ${{ github.event.inputs.artifact_id == '' && !cancelled() && failure() }} + if: ${{ !cancelled() && failure() }} continue-on-error: true run: | - log_dir=/opt/dist/${{ github.run_id }}/logs - log_file=$log_dir/build.log - grep -nE "ptxas fatal|nvcc fatal|error:|fatal error|ninja: build stopped" $log_file | head -n 50 || true + log_dir=/opt/dist/GPTQModel/${{ github.run_id }}/logs + log_file=$log_dir/${{ matrix.test_script }}.log - ls -ahl dist + grep -nE "nvcc fatal|error:|fatal error|ModuleNotFoundError|ImportError|AssertionError|Exception|is the correct path|No such file or directory|Repo id must be in" "$log_file" | head -n 50 || true + + tail -n 200 $log_file + exit 1 - - name: Upload wheel to github artifact - if: github.event.inputs.artifact_id == '' && !cancelled() - uses: actions/upload-artifact@v7 - with: - name: whl - path: dist/${{ env.WHL_NAME }} + - name: Release GPU + if: always() && !contains(matrix.test_script, 'ipex') && !contains(matrix.test_script, 'xpu') + run: | + python .github/scripts/release_gpu.py \ + --base-url "http://${XEON5}" \ + --run-id "${{ github.run_id }}" \ + --gpu-id "${{ env.CUDA_VISIBLE_DEVICES }}" \ + --timestamp "${{ env.STEP_TIMESTAMP }}" \ + --test "${{ matrix.test_script }}" \ + --runner "${RUNNER_NAME}" - name: Clean cache if: always() - run: rm -rf ./* .[^.] .??* # pip cache purge && uv cache clean && + run: | + echo "Cleaning workspace: $PWD" + rm -rf ./* .[^.] .??* || true + echo "cleaning venv: ${{ env.VIRTUAL_ENV }}" + rm -rf "${{ env.VIRTUAL_ENV }}" - torch: + test-models: needs: - - build - list-test-files - check-vm runs-on: [ self-hosted, xeon5 ] container: - image: ${{ needs.check-vm.outputs.ip }}:5000/nvidia/cuda:${{ needs.check-vm.outputs.cuda_version }}-ubuntu22.04_0206 + image: ${{ needs.check-vm.outputs.ip }}:5000/nvidia/cuda:${{ needs.check-vm.outputs.cuda_version }}-ubuntu24.04_0325 options: --device /dev/dri --ipc=host --runtime=nvidia --gpus all volumes: - /monster/ci/env/entrypoint.sh:/entrypoint.sh @@ -366,15 +467,23 @@ jobs: - /monster/ci/models:/monster/data/model - /monster/ci/dataset:/monster/data/model/dataset - /monster/ci/huggingface:/github/home/.cache/huggingface - - /monster/ci/uv:/opt/uv + # - /monster/ci/uv:/opt/uv + - /github/workspace/uv:/opt/uv + - /monster/ci/uv/python:/opt/uv/python + - /monster/ci/uv/cache/python:/opt/uv/cache/python + - /monster/ci/uv/setup_uv_venv.sh:/opt/uv/setup_uv_venv.sh + - /monster/ci/uv/uv:/opt/uv/uv + - /monster/ci/uv/uvx:/opt/uv/uvx + - /monster/ci/uv/env:/opt/uv/env + - /monster/ci/uv/uv.toml:/opt/uv/uv.toml - /monster/ci/env:/opt/env - /monster/ci/dist:/opt/dist strategy: fail-fast: false max-parallel: ${{ fromJson(needs.check-vm.outputs['max-parallel']).size || 20 }} matrix: - test_script: ${{ fromJSON(needs.list-test-files.outputs.torch-files) }} - if: always() && !cancelled() && (needs.build.result == 'success') && needs.list-test-files.outputs.torch-files != '[]' # || github.event.inputs.artifact_id != '' + test_script: ${{ fromJSON(needs.list-test-files.outputs.model-files) }} + if: always() && !cancelled() && needs.list-test-files.outputs.model-files != '[]' steps: - name: Checkout Codes uses: actions/checkout@v6 @@ -391,20 +500,53 @@ jobs: git fetch origin pull/${PR_NUMBER}/head:pr-${PR_NUMBER} git checkout pr-${PR_NUMBER} - - name: Print env + - name: decompress uv cache + continue-on-error: true run: | - echo PATH=$PATH - echo UV_INSTALL_DIR=$UV_INSTALL_DIR - echo UV_PYTHON_BIN_DIR=$UV_PYTHON_BIN_DIR - echo UV_PYTHON_INSTALL_DIR=$UV_PYTHON_INSTALL_DIR - echo UV_PYTHON_CACHE_DIR=$UV_PYTHON_CACHE_DIR - echo UV_CACHE_DIR=$UV_CACHE_DIR - echo VENV_ROOT=$VENV_ROOT - ls /root -ahl + if [ -f /opt/dist/uv.tar.xz ]; then + TAR_FILE="/opt/dist/uv.tar.xz" + LAST_FILE="/opt/uv/cache/lastmodified" + + # Get modification time of tar.xz file (epoch seconds) + TAR_MTIME=$(stat -c %Y "$TAR_FILE") + + # Read last recorded modification time if file exists + if [ -f "$LAST_FILE" ]; then + LAST_MTIME=$(cat "$LAST_FILE") + else + LAST_MTIME=0 + fi + + # Compare timestamps to decide whether to decompress + if [ "$TAR_MTIME" = "$LAST_MTIME" ]; then + echo "uv.tar.xz unchanged, skip decompress" + else + echo "decompressing uv.tar.xz..." + + # Prepare temporary directory + mkdir -p /opt/uv/cache/tmp + rm -rf /opt/uv/cache/tmp/* + + # Extract archive + tar -xJf "$TAR_FILE" -C /opt/uv/cache/tmp + + # Replace existing uv directory + rm -rf /opt/uv/cache/uv + mv /opt/uv/cache/tmp/uv /opt/uv/cache/uv + + # Record latest modification time + echo "$TAR_MTIME" > "$LAST_FILE" + + echo "done!" + ls -ahl /opt/uv/cache + echo "==========" + ls -ahl /opt/uv/cache/uv + fi + fi - name: Activate uv env run: | - env_name="cu${{ needs.check-vm.outputs.cuda_version }}_torch${{ env.TORCH_VERSION }}_py${{ env.PYTHON_VERSION }}_test_${{ matrix.test_script }}" + env_name="gptqmodel_test_cu${{ needs.check-vm.outputs.cuda_version }}_torch${{ env.TORCH_VERSION }}_py${{ env.PYTHON_VERSION }}_${{ matrix.test_script }}" /opt/uv/setup_uv_venv.sh $env_name - name: Setup uv env @@ -412,18 +554,39 @@ jobs: python -V which python which pip || true + echo "uv cache dir: $(uv cache dir)" echo "setting env... cuda=${{ needs.check-vm.outputs.cuda_version }} torch=${{ env.TORCH_VERSION }} python=${{ env.PYTHON_VERSION }}" bash /opt/env/init_compiler_no_env.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} ${{ env.PYTHON_VERSION }} python .github/scripts/install_deps.py ${{ matrix.test_script }} + python .github/scripts/uninstall_deps.py ${{ matrix.test_script }} + + - name: Install package from source + run: | + uv pip uninstall gptqmodel || true + /opt/env/install_torch.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} + uv pip install -r requirements.txt -i http://$RUNNER/simple/ --trusted-host $RUNNER --extra-index-url https://pypi.org/simple + uv pip install . -i http://$RUNNER/simple/ --trusted-host $RUNNER --extra-index-url https://pypi.org/simple + for pkg in defuser pypcre tokenicer logbar evalution; do + uv pip uninstall "$pkg" || true + done + for url in \ + "$DEFUSER_GIT_URL" \ + "$PYPCRE_GIT_URL" \ + "$TOKENICER_GIT_URL" \ + "$LOGBAR_GIT_URL" \ + "$EVALUTION_GIT_URL" + do + uv pip install "$url" + done + /opt/env/install_torch.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} - name: Print uv env run: | echo "::group::uv python list" uv python list - ls -ahl /opt/uv/venvs echo "::endgroup::" - + which python which pip || true python -V @@ -435,13 +598,6 @@ jobs: echo "== nvcc ==" nvcc --version - echo "::group::pip list" - uv pip list - echo "::endgroup::" - - echo "== torch ==" - uv pip show torch || true - echo "::group::project files" ls -ahl echo "::endgroup::" @@ -449,410 +605,145 @@ jobs: echo "== torch ==" uv pip show torch - # - name: Install requirements - # run: | - # bash -c "$(curl -L http://${RUNNER}/scripts/env/init_compiler_no_env.sh)" @ ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} $python_version${{ env.PYTHON_VERSION }} - - # - name: Download source from local - # continue-on-error: true - # run: | - # curl -s -O http://$RUNNER/whl/${{ env.repo }}/${{ github.run_id }}/gptqmodel_source.tar.gz - # ls -ahl . - # sha256=$(sha256sum $file_name) - # echo "sha256=$sha256" - # echo "SOURCE_DOWNLOADED=1" >> $GITHUB_ENV + echo "::group::pip list" + uv pip list + echo "::endgroup::" - # - name: Download source from github - # if: env.SOURCE_DOWNLOADED == '' && !cancelled() - # uses: actions/download-artifact@v8 - # with: - # name: source - # path: dist - # run-id: ${{ github.run_id }} + - name: Find suitable GPU + if: ${{ !cancelled() }} + run: | + python .github/scripts/allocate_gpu.py \ + --base-url "http://${XEON5}" \ + --run-id "${{ github.run_id }}" \ + --test "${{ matrix.test_script }}" \ + --runner "${RUNNER_NAME:-unknown}" \ + --count "1" \ + --require-single - # - name: Uncompress source - # continue-on-error: true - # run: | - # find . -mindepth 1 ! -name "gptqmodel_source.tar.gz" -exec rm -rf {} + - # ls -ahl . - # tar -zxf gptqmodel_source.tar.gz + - name: Run tests + run: | + python .github/scripts/run_tests.py \ + --base-url "http://${XEON5}" \ + --run-id "${{ github.run_id }}" \ + --test-script "${{ matrix.test_script }}" \ + --runner "${RUNNER_NAME}" \ + --gpu-id "${{ env.CUDA_VISIBLE_DEVICES }}" \ + --model-test-mode "${MODEL_TEST_MODE}" - - name: Download wheel from local + - name: Check log + if: ${{ !cancelled() && failure() }} continue-on-error: true run: | - DIR=/opt/dist/${{ needs.check-vm.outputs.run_id }} - [ -d $DIR ] || exit 1 - echo "WHL_DOWNLOADED=1" >> $GITHUB_ENV + log_dir=/opt/dist/GPTQModel/${{ github.run_id }}/logs + log_file=$log_dir/${{ matrix.test_script }}.log - ls -ahl $DIR - - mkdir dist || true - - cp $DIR/*.whl dist/ - - # file_name=$(curl -s -F "runid=${{ needs.check-vm.outputs.run_id }}" -F "repo=${{ env.repo }}" -F "ref=${{ env.ref }}" -F "fuzz=1" "http://$RUNNER/gpu/whl/download") - - # echo "file_name=$file_name" - # - # if echo "$file_name" | grep -q "gptqmodel"; then - # mkdir dist || true - # cd dist - # curl -s -O http://$RUNNER/whl/${{ env.repo }}/${{ needs.check-vm.outputs.run_id }}/$file_name - # ls -ahl . - # sha256=$(sha256sum $file_name) - # echo "sha256=$sha256" - # echo "WHL_DOWNLOADED=1" >> $GITHUB_ENV - # fi - - - name: Download artifact from github - if: env.WHL_DOWNLOADED == '' && !cancelled() - uses: actions/download-artifact@v8 - with: - name: whl - path: dist - run-id: ${{ needs.check-vm.outputs.run_id }} + grep -nE "nvcc fatal|error:|fatal error|ModuleNotFoundError|ImportError|AssertionError|Exception|is the correct path|No such file or directory|Repo id must be in" "$log_file" | head -n 50 || true + tail -n 200 $log_file + exit 1 - - name: Install wheel + - name: Release GPU + if: always() run: | - uv pip uninstall gptqmodel - # make sure torch & torchvision are corresponded - /opt/env/install_torch.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} - uv pip install -r requirements.txt + python .github/scripts/release_gpu.py \ + --base-url "http://${XEON5}" \ + --run-id "${{ github.run_id }}" \ + --gpu-id "${{ env.CUDA_VISIBLE_DEVICES }}" \ + --timestamp "${{ env.STEP_TIMESTAMP }}" \ + --test "${{ matrix.test_script }}" \ + --runner "${RUNNER_NAME}" - if [[ "${{ matrix.test_script }}" == *xpu* ]]; then - echo "===== switching to xpu env =====" - # source /etc/profile.d/pyenv.sh && pyenv activate xpu - fi + - name: Clean cache + if: always() + run: | + echo "Cleaning workspace: $PWD" + rm -rf ./* .[^.] .??* || true + echo "cleaning venv: ${{ env.VIRTUAL_ENV }}" + rm -rf "${{ env.VIRTUAL_ENV }}" - # ipex doesn't need to compile kernels. xpu can't install cuda package - if [[ "${{ matrix.test_script }}" != *ipex* && "${{ matrix.test_script }}" != *xpu* ]]; then - echo "===== install dist/whl =====" - uv pip install dist/*.whl -i http://$RUNNER/simple/ --trusted-host $RUNNER --extra-index-url https://pypi.org/simple - else - echo "===== install with local files for xpu env =====" - export CUDA_VISIBLE_DEVICES="" - unset TORCH_CUDA_ARCH_LIST - uv pip install . --no-build-isolation - fi + check-torch: + runs-on: [ self-hosted, xeon5 ] + container: + image: 10.0.13.31:5000/nvidia/cuda:130-ubuntu24.04_0325 + options: --device /dev/dri --ipc=host --runtime=nvidia --gpus all + volumes: + - /monster/ci/env/entrypoint.sh:/entrypoint.sh + - /monster/ci/env/entrypoint.sh:/etc/profile.d/01-entrypoint.sh + - /dev/dri/by-path:/dev/dri/by-path + - /monster/ci/uv:/opt/uv + - /monster/ci/env:/opt/env + steps: + - name: Checkout Codes + uses: actions/checkout@v6 + with: + repository: ${{ env.repo }} + ref: ${{ env.ref }} - # deps may reinstalled torch - /opt/env/install_torch.sh ${{ needs.check-vm.outputs.cuda_version }} ${{ env.TORCH_VERSION }} + - name: Test pypi pip + run: | + uv venv pypi_pip_env + source pypi_pip_env/bin/activate + uv pip install pip -U + pip install gptqmodel torch -U - - name: Find suitable GPU - if: ${{ !contains(matrix.test_script, 'ipex') && !contains(matrix.test_script, 'xpu') && !cancelled() }} + - name: Test pypi uv run: | - set -Eeuo pipefail - - # -------------------- Configuration -------------------- - BASE_URL="http://${XEON5}" - RUN_ID="${{ github.run_id }}" - TEST="${{ matrix.test_script }}" - RUNNER="${RUNNER_NAME:-unknown}" - EXCLUSIVE="${{ github.event.inputs['exclusive-gpu'] }}" - - SLEEP_SEC=5 - TIMEOUT_SEC=18000 # Max wait time: 300 minutes - # ------------------------------------------------------- - - # URL encode helper to avoid breaking query parameters - urlencode(){ python -c 'import sys,urllib.parse;print(urllib.parse.quote(sys.argv[1],""))' "$1"; } - - test_q=$(urlencode "$TEST") - runner_q=$(urlencode "$RUNNER") - exclusive_q=$(urlencode "$EXCLUSIVE") - - start_s=$(date +%s) - gpu_id="" - - echo "Requesting GPU from allocator" - echo "run_id=$RUN_ID test=$TEST runner=$RUNNER exclusive=$EXCLUSIVE" - - while true; do - # Refresh timestamp every request to avoid cache or duplicated requests - ts_ms=$(date +%s%3N) - url="$BASE_URL/gpu/get?runid=$RUN_ID×tamp=$ts_ms&test=$test_q&runner=$runner_q&exclusive=$exclusive_q" - echo "requesting GPU with: $url" - - # Call allocator - # - May return: - # * integer >= 0 : allocated GPU ID - # * integer < 0 : no GPU available yet - # * empty / HTML : temporary server / proxy error - resp="$( - curl -fsSL --connect-timeout 3 --max-time 10 \ - --retry 3 --retry-delay 1 --retry-all-errors \ - "$url" 2>/dev/null || true - )" - - echo "resp={$resp}" - - # Normalize response - resp="${resp//$'\r'/}" - resp="${resp//$'\n'/}" - - # If response is empty or not an integer, treat as temporary error - if [[ -z "$resp" || ! "$resp" =~ ^-?[0-9]+$ ]]; then - echo "Allocator returned invalid response: '$resp' (temporary error)" - curl -fsSL "$BASE_URL/gpu/status" || true - sleep "$SLEEP_SEC" - continue - fi - - # Negative integer means no GPU available yet - if (( resp < 0 )); then - elapsed=$(( $(date +%s) - start_s )) - if (( elapsed >= TIMEOUT_SEC )); then - echo "Timed out after ${TIMEOUT_SEC}s waiting for GPU (last response=$resp)" - curl -fsSL "$BASE_URL/gpu/status" || true - exit 1 - fi - - echo "No GPU available (response=$resp). Waiting ${SLEEP_SEC}s... elapsed=${elapsed}s" - curl -fsSL "$BASE_URL/gpu/status" || true - sleep "$SLEEP_SEC" - continue - fi - - # Successful allocation - gpu_id="$resp" - echo "Allocated GPU ID: $gpu_id" - break - done - - # Export environment variables for subsequent steps - echo "CUDA_VISIBLE_DEVICES=$gpu_id" >> "$GITHUB_ENV" - echo "STEP_TIMESTAMP=$(date +%s%3N)" >> "$GITHUB_ENV" - echo "CUDA_VISIBLE_DEVICES set to $gpu_id" - - # Final status snapshot - curl -fsSL "$BASE_URL/gpu/status" || true + uv venv pypi_uv_env + source pypi_uv_env/bin/activate + uv pip install gptqmodel torch -U + - name: test local pip + run: | + uv venv local_pip_env + source local_pip_env/bin/activate + uv pip install pip -U + pip install . torch -U - - name: Run tests - if: ${{ (!github.event.inputs.test_names || contains(github.event.inputs.test_names, matrix.test_script)) && !cancelled() }} + - name: test local uv run: | - if [[ "${{ matrix.test_script }}" == *ipex* ]]; then - export CUDA_VISIBLE_DEVICES="" - fi - if [[ "${{ matrix.test_script }}" == *xpu* ]]; then - export CUDA_VISIBLE_DEVICES="" - uv pip uninstall vllm -y - uv pip list - fi - - start_monitor() { - keep_alive_url="http://$XEON5/gpu/keepalive?runid=${{ github.run_id }}&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}×tamp=$(date +%s)&exclusive=${{ github.event.inputs['exclusive-gpu'] }}&gpu=${CUDA_VISIBLE_DEVICES}" - echo "start to keep alive... ${keep_alive_url}" - while true; do - resp=$(curl -fsSL $keep_alive_url 2>/dev/null || echo "") - - if [ "$(echo "$resp" | tr -d '[:space:]')" = "-1" ]; then - echo "Server returned -1, terminating job..." - pkill -9 -f "pytest.*${{ matrix.test_script }}" 2>/dev/null || true - exit 3 - else - echo "gpu is kept alive..." - fi - - sleep 60 - done - } - - start_monitor & - MONITOR_PID=$! - - cleanup() { - echo "trap cleanup EXIT..." - kill $MONITOR_PID 2>/dev/null || true - wait $MONITOR_PID 2>/dev/null || true - } - trap cleanup EXIT - - log_dir=/opt/dist/${{ github.run_id }}/logs - log_file=$log_dir/${{ matrix.test_script }}.log - - mkdir -p "$(dirname "$log_file")" - - start_time=$(date +%s) - pytest --durations=0 tests/${{ matrix.test_script }}.py 2>&1 | tee $log_file - test ${PIPESTATUS[0]} -eq 0 || { echo "ERROR=2" >> $GITHUB_ENV; echo "pipe status wrong: ${PIPESTATUS[0]} $PIPESTATUS"; exit 2; } - execution_time=$(( $(date +%s) - start_time )) - echo "$((execution_time / 60))m $((execution_time % 60))s" - - ls -ahl $log_dir || true + uv venv local_uv_env + source local_uv_env/bin/activate + uv pip install . torch -U - curl "http://$XEON5/gpu/logVram?runid=${{ github.run_id }}&gpu=${{ env.CUDA_VISIBLE_DEVICES }}&range=$execution_time&unit=second&test=${{ matrix.test_script }}" - - name: Check log - if: ${{ (!github.event.inputs.test_names || contains(github.event.inputs.test_names, matrix.test_script)) && !cancelled() && failure() }} - continue-on-error: true - run: | - log_dir=/opt/dist/${{ github.run_id }}/logs - log_file=$log_dir/${{ matrix.test_script }}.log - - grep -nE "nvcc fatal|error:|fatal error|ModuleNotFoundError|is the correct path|No such file or directory|Repo id must be in" "$log_file" | head -n 50 || true + prepare-setuptools: + runs-on: ubuntu-latest + outputs: + versions: ${{ steps.parser.outputs.versions || '[]' }} + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.14" - - name: Release GPU - if: always() && !contains(matrix.test_script, 'ipex') && !contains(matrix.test_script, 'xpu') + - name: Generate version matrix + id: parser run: | - url="http://$XEON5/gpu/release?runid=${{ github.run_id }}&gpu=${{ env.CUDA_VISIBLE_DEVICES }}×tamp=${{ env.STEP_TIMESTAMP }}&test=${{ matrix.test_script }}&runner=${RUNNER_NAME}" - echo "$url" - resp=$(curl -fsSL $url) + python -m pip install --upgrade requests packaging + versions=$(python .github/scripts/ci_loop_versions.py setuptools ">=77.0.1,<83") + echo "versions=$versions" >> "$GITHUB_OUTPUT" + + check-setuptools: + needs: prepare-setuptools + if: needs.prepare-setuptools.outputs.versions != '[]' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + version: ${{ fromJSON(needs.prepare-setuptools.outputs.versions) }} - echo "response: $resp" - if [ "$resp" != "${{ env.CUDA_VISIBLE_DEVICES }}" ]; then - echo "Error: response ($resp) != expected (${{ env.CUDA_VISIBLE_DEVICES }})" - exit 0 - fi + steps: + - uses: actions/checkout@v6 + - uses: actions/setup-python@v6 + with: + python-version: "3.14" + cache: pip - - name: Clean cache - if: always() + - name: Install package with selected setuptools + run: | + python -m pip install --upgrade pip + python -m pip install . "setuptools==${{ matrix.version }}" + + - name: Show versions run: | - # rm ~/.cache/evalplus/*pkl || true - # pip cache purge && uv cache clean - rm -rf rm -rf ./* .[^.] .??* - -# show-statistics: -# runs-on: [ self-hosted, xeon5 ] -# if: always() && inputs['exclusive-gpu'] && !cancelled() -# container: -# image: modelcloud/gptqmodel:alpine-ci-v1 -# needs: -# - torch -# steps: -# - name: Print statistics -# run: curl "http://$RUNNER/gpu/get_vram_logs?id=${{ github.run_id }}" - -# m4: -# runs-on: [ self-hosted, m4 ] -# needs: -# - check-vm -# - list-test-files -# if: false && (github.event.inputs.test_names == '' || contains(github.event.inputs.test_names, 'apple') || contains(github.event.inputs.test_names, 'mlx') ) && (needs.list-test-files.outputs.m4-files != '' && needs.list-test-files.outputs.m4-files != '[]') && !cancelled() -# strategy: -# fail-fast: false -# matrix: -# test_script: ${{ fromJSON(needs.list-test-files.outputs.m4-files) }} -# steps: -# - name: Print Env -# run: | -# echo "repo: ${{ env.repo }}" -# echo "ref: ${{ env.ref }}" -# ls -ahl . -# -# - name: Checkout Codes -# uses: actions/checkout@v6 -# with: -# repository: ${{ env.repo }} -# ref: ${{ env.ref }} -# -# - name: Run test -# run: | -# export PATH="/opt/homebrew/bin:$PATH" && eval "$(pyenv init -)" -# rm -rf venv || true -# -# echo "=== checking models dir is mounted" -# ls ../../../monster -# -# echo "=== activating venv" -# pyenv global 3.11.11 && python -m venv venv -# source venv/bin/activate -# -# rm profile.sb || true -# -# curl -O http://$RUNNER/scripts/m4/profile.sb -# -# echo "=== installing uv setuptools build" -# pip install setuptools build -U -i http://$RUNNER/simple --trusted-host $RUNNER --extra-index-url https://pypi.org/simple -# -# echo "=== installing test tools" -# uv pip install pytest parameterized vllm lm-eval device-smi mlx-lm -U -i http://$RUNNER/simple/ --trusted-host $RUNNER --extra-index-url https://pypi.org/simple -# -# echo "=== installing gptqmodel" -# uv pip install . --no-build-isolation -i http://$RUNNER/simple/ --trusted-host $RUNNER --extra-index-url https://pypi.org/simple -# -# echo "replacing model path" -# find tests -name "*.py" -exec sed -i '' 's/\/monster\/data\/model/..\/..\/..\/monster/g' {} + -# -# TEST=${{ matrix.test_script }} -# if [[ ! "$TEST" == *.py ]]; then -# TEST="$TEST.py" -# fi -# echo "=== running test: $TEST" -# pytest tests/$TEST -# -# - name: Clean cache -# if: always() -# run: | -# source venv/bin/activate && pip cache purge && uv cache clean || true -# rm -rf ../GPTQModel && mkdir ../GPTQModel -# -# mac-test: -# runs-on: macos-latest -# env: -# CUDA_VISIBLE_DEVICES: '' -# TORCH_CUDA_ARCH_LIST: '' -# MAX_JOBS: 3 -# BUILD_QQQ: 0 -# BUILD_EORA: 0 -# GPTQMODEL_BUILD_EXLLAMA_V1: 0 -# GPTQMODEL_BUILD_EORA: 0 -# GPTQMODEL_FORCE_BUILD: 0 -# steps: -# - name: Checkout Codes -# uses: actions/checkout@v6 -# -# - uses: actions/setup-python@v6 -# with: -# python-version: 3.12 -# cache: 'pip' -# -## it wastes too much time to find which exactly one caused installation failed, just unset them all..... -# - name: Install dependencies -# run: | -# unset CUDA_DEVICE_ORDER -# unset CUDA_VISIBLE_DEVICES -# unset TORCH_CUDA_ARCH_LIST -# unset PYTORCH_ALLOC_CONF -# unset MAX_JOBS -# unset RUNNER -# unset XEON5 -# unset UV_INDEX_URL -# unset CUDA_VERSION -# unset TORCH_VERSION -# unset PYTHON_VERSION -# unset # PYTHON_GIL -# unset BUILD_QQQ -# unset BUILD_EORA -# unset GPTQMODEL_BUILD_EXLLAMA_V1 -# unset GPTQMODEL_BUILD_EORA -# unset LEGACY_TESTS -# unset IGNORED_TEST_FILES -# unset GPTQMODEL_FORCE_BUILD -# unset repo -# unset ref -# -# python -V -# python -m venv venv -# source venv/bin/activate -# pip install pip uv setuptools build wheel torch -U -# pip install meson-python -U -# pip install numpy==2.2.6 -U -# -# uv pip install -e . --no-build-isolation -# pip install pip Pillow device_smi pypcre tokenicer threadpoolctl accelerate logbar transformers optimum torch -U -# -# - name: Run test -# run: | -# source venv/bin/activate -# python - <<'PY' -# import os -# from transformers import pipeline -# os.environ["CUDA_VISIBLE_DEVICES"] = "" -# os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -# llm_pipeline = pipeline(model="JunHowie/Qwen3-0.6B-GPTQ-Int4") -# output = llm_pipeline("Which city is the capital of France?", max_new_tokens=100) -# print(output) -# -# assert "paris" in output.lower() -# PY + python --version + python -m pip show setuptools diff --git a/.gitignore b/.gitignore index cc3d0fea7..3301668f1 100644 --- a/.gitignore +++ b/.gitignore @@ -167,6 +167,7 @@ debug # .vscode .vscode/ +/.codex/ example.py @@ -182,5 +183,14 @@ example.py /gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu /gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu /gptqmodel_offload/ +/benchmark_artifcats/ +/benchmark_artifacts/ +/test_outputs/ +/tests/models/gptqmodel_offload*/ +/QWEN3-8B-AWQ/ +/evalplus_results/ +/cutlass/ +*.csv /gptqmodel_ext/machete/generated/ AGENT.md +.codex \ No newline at end of file diff --git a/BRANCH_CLEANUP.md b/BRANCH_CLEANUP.md new file mode 100644 index 000000000..cb8b29c74 --- /dev/null +++ b/BRANCH_CLEANUP.md @@ -0,0 +1,43 @@ +# Branch cleanup plan + +This repository has multiple divergent feature/scratch branches. The goal is to preserve useful work without dumping everything into `main`. + +## Keep + +- `main` +- `integration/review-2026-04` +- `gptq-pro-cuda-kernel` +- `copilot/analyze-gptq-enhancements-v2` (only if this work is still wanted) + +## Delete after preserving anything unique + +- `copilot/sub-pr-2` +- `copilot/analyze-gptq-enhancements` + +## Do not merge directly into `main` + +- `fix/gemma4-ampere-main` + +That branch is a large integration/refactor line and should be rebased/reviewed separately or cherry-picked selectively. + +## Current staging + +- PR #5 stages `gptq-pro-cuda-kernel` into `integration/review-2026-04` +- PR #6 stages `copilot/analyze-gptq-enhancements-v2` into `integration/review-2026-04` + +## Why this structure exists + +- Keeps `main` clean +- Preserves the focused CUDA-kernel work +- Preserves the only surviving Copilot GPTQ-Pro enhancement branch worth reviewing +- Avoids blindly merging duplicate or superseded scratch branches +- Leaves the large Gemma/refactor branch quarantined until intentionally handled + +## Remaining manual cleanup + +The current automation path can create branches and PRs but cannot directly delete remote branches. Once the staged work is reviewed, delete the superseded branches in GitHub UI or via git: + +```bash +git push origin --delete copilot/sub-pr-2 +git push origin --delete copilot/analyze-gptq-enhancements +``` diff --git a/CREDITS.md b/CREDITS.md index eb352d80a..b86c214af 100644 --- a/CREDITS.md +++ b/CREDITS.md @@ -1,6 +1,6 @@ # Credits -* [Qubitium](https://x.com/qubitium) and [ModelCloud](https://x.com/ModelCloudAI) team for maintaining and improving GPTQModel +* [Qubitium](https://x.com/qubitium) and [ModelCloud](https://x.com/ModelCloudAI) team for maintaining and improving GPT-QModel * **Elias Frantar**, **Saleh Ashkboos**, **Torsten Hoefler** and **Dan Alistarh**: for creating [GPTQ](https://github.com/IST-DASLab/gptq) and [Marlin](https://github.com/IST-DASLab/marlin). * **PanQiWei**: for creation of [AutoGPTQ](https://github.com/autogptq/AutoGPTQ) which this project code is based upon. * **FXMarty**: for maintaining and support of [AutoGPTQ](https://github.com/autogptq/AutoGPTQ). diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 000000000..d3b2b09f0 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,52 @@ +FROM nvidia/cuda:12.6.3-cudnn-devel-ubuntu22.04 + +ARG CONDA_DIR=/opt/conda +ARG ENV_NAME=gptq-pro-vllm + +ENV DEBIAN_FRONTEND=noninteractive \ + PYTHONUNBUFFERED=1 \ + PIP_NO_CACHE_DIR=1 \ + CUDA_DEVICE_ORDER=PCI_BUS_ID \ + HF_HOME=/workspace/.cache/huggingface \ + TRANSFORMERS_CACHE=/workspace/.cache/huggingface \ + PATH=${CONDA_DIR}/bin:${PATH} + +RUN apt-get update && apt-get install -y --no-install-recommends \ + bash \ + build-essential \ + ca-certificates \ + curl \ + git \ + git-lfs \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrender1 \ + pkg-config \ + wget && \ + rm -rf /var/lib/apt/lists/* && \ + git lfs install + +RUN wget -qO /tmp/miniforge.sh https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh && \ + bash /tmp/miniforge.sh -b -p ${CONDA_DIR} && \ + rm -f /tmp/miniforge.sh && \ + conda config --system --set auto_activate_base false + +WORKDIR /workspace/GPTQ-Pro + +COPY environment.yml ./environment.yml + +RUN conda env create -f environment.yml && \ + conda clean -afy + +SHELL ["/bin/bash", "-lc"] + +COPY . . + +RUN source ${CONDA_DIR}/etc/profile.d/conda.sh && \ + conda activate ${ENV_NAME} && \ + python -m pip install --upgrade pip && \ + python -m pip install --extra-index-url https://download.pytorch.org/whl/cu128 "torch>=2.8.0" && \ + python -m pip install -v --no-build-isolation -e ".[vllm,eval,openai]" + +CMD ["/bin/bash", "-lc", "source /opt/conda/etc/profile.d/conda.sh && conda activate gptq-pro-vllm && exec bash"] diff --git a/MANIFEST.in b/MANIFEST.in index 7f0629d20..691a555cb 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,13 +1,17 @@ recursive-include gptqmodel_ext/awq *.h *.cuh *.cu *.cpp recursive-include gptqmodel_ext/exllama *.h *.cuh *.cu *.cpp recursive-include gptqmodel_ext/exllamav2 *.h *.cuh *.cu *.cpp +recursive-include gptqmodel_ext/exllamav3 *.h *.hpp *.cuh *.cu *.cpp *.c recursive-include gptqmodel_ext/exllama_eora/eora *.h *.cuh *.cu *.cpp *.py recursive-include gptqmodel_ext/marlin *.h *.cuh *.cu *.cpp *.hpp recursive-include gptqmodel_ext/machete *.h *.hpp *.cuh *.cu *.cpp *.py recursive-include gptqmodel_ext/cutlass_extensions *.h *.hpp *.cuh *.cu *.cpp *.py recursive-include gptqmodel_ext/qqq *.h *.cuh *.cu *.cpp +recursive-include gptqmodel_ext/gptq_pro *.h *.cuh *.cu *.cpp +recursive-include gptqmodel/exllamav3/util/hadamard_data *.txt include licenses/* include gptqmodel_ext/pack_block_cpu.cpp +include gptqmodel_ext/floatx_cpu.cpp include gptqmodel_ext/marlin/generate_kernels.py include gptqmodel_ext/machete/generate.py recursive-exclude gptqmodel_ext __pycache__ *.pyc diff --git a/Project.md b/Project.md new file mode 100644 index 000000000..94350db49 --- /dev/null +++ b/Project.md @@ -0,0 +1,9396 @@ +# PR Quality Review + +*Exported from chatGPT on 19/03/2026, 11:01:46 pm* + +--- + +### User: +WHAT DO YOU THINK ABOUT THIS PR ? https://github.com/groxaxo/GPTQ-Pro/pull/4 + +--- + +### ChatGPT: +I’m pulling the PR diff and checking what it changes, plus whether it looks correct, maintainable, and likely to help in practice. + +--- + +I’ve got the headline change already: it keeps standard GPTQ export/runtime compatibility, but replaces a single fallback smoother with a best-of offline search. I’m checking the actual code path and tests now to see whether it’s genuinely solid or just sounds good in the PR description. + +--- + +My take: **the idea is good, and the implementation looks directionally strong**. I would not reject it. I’d call it a **quality-oriented PR with low runtime risk**, because it keeps `quant_method=GPTQ` and `format=GPTQ`, and the new behavior is limited to the offline failsafe path rather than changing inference kernels. The core change is that `QuantizeConfig.gptq_pro()` now defaults to `SmoothAuto()` instead of a single `SmoothMSE`, and `SmoothAuto` tries several kernel-safe candidates—unsmoothed baseline, `SmoothMSE`, `SmoothMAD`, `SmoothPercentile`, and `SmoothPercentileAsymmetric`—then keeps the lowest per-row reconstruction MSE. [GitHub+2GitHub+2](https://github.com/groxaxo/GPTQ-Pro/pull/4) + +What I like: + +- It is **properly scoped**. This is not pretending to be AWQ or AutoRound-in-the-loop; it is a conservative offline improvement that still emits ordinary GPTQ weights/scales/zeros for the same inference path. That is the right way to squeeze more quality without breaking downstream kernel compatibility. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4) +- The **“include_none” baseline** is a smart failsafe. Because the unsmoothed candidate can win, the PR explicitly guards against cases where smoothing makes a row worse. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/files) +- The refactor into `_failsafe_quantize_block()` plus `_failsafe_quantize_auto()` is cleaner than jamming more branches into the old path, and the config surface is wired through parsing/serialization/export, so this is not a hacky one-off. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/3b6de40c6f78a6fd5228f514d369a8b71c516ad3) + +What I’d watch: + +- The selection criterion is **local row-wise reconstruction MSE**. That is a sensible proxy, but it is still only a proxy. A row having lower MSE does **not automatically guarantee** better perplexity or downstream task quality at the model level; that part still needs empirical validation. This is an inference from the implementation, which explicitly chooses candidates by per-row MSE. [GitHub](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/3b6de40c6f78a6fd5228f514d369a8b71c516ad3) +- It will **increase offline quantization work** on blocks that enter the failsafe path, because instead of running one smoother/quantization pass, it now evaluates multiple candidates and compares them. Since the preset says this is for low-sample fallback blocks, the blast radius is limited, but quantization-time overhead should still be measured. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4) +- The follow-up “polish” commit matters. The first version used direct shape checks for row replacement of `scale`/`zero`; the next commit replaced that with `_row_replace_mask()` and constants for fallback steps/maxshrink. That tells me the author already hit at least one tensor-shape corner case, so I’d want CI plus one real model smoke test before merging. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/169f1d534186c512788837025fbd3ca415ab70d7) + +Tests: the PR does add/adjust tests so that the auto smoother must match or beat the candidate smoothers within tolerance, which is exactly the right unit-level invariant for this feature. But from what is visible in the PR, I do **not** see model-level benchmark evidence yet. [GitHub](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/169f1d534186c512788837025fbd3ca415ab70d7) + +So my verdict: **good PR, probably worth merging, but only after adding a small benchmark note**. I’d want: + +1. perplexity or eval numbers on at least one awkward model, +2. quantization-time overhead versus prior `gptq_pro`, +3. one smoke test proving the exported artifact still loads cleanly in the intended GPTQ runtime stack. + +If you want, I can give you a **proper line-by-line code review comment set** you could paste into the PR. + +--- + +### User: +yes please + +--- + +### ChatGPT: +I’m opening the PR and its diff so I can turn this into concrete review comments you can paste directly. + +--- + +I’ve confirmed the config/docs side. Now I’m trying to extract the `gptq.py` diff itself, since that’s where the real correctness risk lives. + +--- + +Paste this as your review. + +Overall review text: + +This looks solid to me. I like the scope: it keeps `quant_method=GPTQ` / `format=GPTQ`, moves the preset default from a single `SmoothMSE(steps=32, maxshrink=0.9)` to `SmoothAuto()`, and limits the new search to the offline failsafe path rather than changing the runtime kernel path. The refactor into `_failsafe_quantize_block()` plus `_failsafe_quantize_auto()` is also cleaner than growing the old branch inline. [GitHub+3GitHub+3GitHub+3](https://github.com/groxaxo/GPTQ-Pro/pull/4) + +Then add these review comments where relevant: + +1. On `gptqmodel/quantization/config.py` near `class SmoothAuto(...)`: + +Nice addition. Since `SmoothAuto` is now part of the serialized config surface, I’d add lightweight validation here or in `_build_smooth_method_from_dict()` for obviously invalid values: `mse_steps > 0`, `0 < mse_maxshrink <= 1`, `0 < percentile <= 100`, and `0 <= low < high <= 100`. That would make malformed configs fail fast instead of producing confusing quantization behavior later. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/files) + +1. On `QuantizeConfig.gptq_pro()` where the default smoother changes: + +This is a meaningful preset behavior change, not just an internal cleanup: `gptq_pro()` now defaults to `SmoothAuto()` instead of `SmoothMSE(steps=32, maxshrink=0.9)`. I’d add a short migration note / changelog entry so anyone relying on prior GPTQ-Pro reproducibility knows why outputs changed after upgrading. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/files) + +1. On `gptqmodel/quantization/gptq.py` near `_failsafe_quantize_auto()`: + +I like the design here. My main ask is empirical validation in the PR description: the selector chooses the best candidate by per-row reconstruction MSE, which is a sensible local proxy, but it would be great to attach one small end-to-end datapoint as well (for example perplexity/eval delta plus quantization-time overhead on one representative model). That would make the case for the new default much stronger. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/3b6de40c6f78a6fd5228f514d369a8b71c516ad3) + +1. On `gptqmodel/quantization/gptq.py` near `AUTO_SMOOTH_FALLBACK_STEPS` / `AUTO_SMOOTH_FALLBACK_MAXSHRINK`: + +Good cleanup extracting these constants in the polish commit. One small improvement: add a short comment explaining why non-`SmoothMSE` candidates use `32` / `0.8` while the `SmoothAuto` MSE candidate uses its own `mse_steps` / `mse_maxshrink` defaults (`48` / `0.85`). The code is readable, but the rationale is not obvious at a glance. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/169f1d534186c512788837025fbd3ca415ab70d7) + +1. On `gptqmodel/quantization/gptq.py` near `_row_replace_mask(...)`: + +Good catch fixing the row-wise replacement mask for cases where `scale` / `zero` can be either `[rows, 1]` or `[rows]`. I’d add a targeted regression test for this exact shape-mismatch scenario across at least two failsafe strategies, because this looks like the kind of subtle bug that can easily come back during future refactors. [GitHub](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/169f1d534186c512788837025fbd3ca415ab70d7) + +1. On `tests/test_failsafe.py`: + +The invariant test is good: `auto_err` must be no worse than each single candidate within tolerance. I’d add one more config-focused test for round-trip serialization of all `SmoothAuto` fields (`include_none`, `mse_steps`, `mse_maxshrink`, `mad_k`, `percentile`, `low`, `high`) since this PR expands both parsing and `to_dict()` support for the new smoother. [GitHub+1](https://github.com/groxaxo/GPTQ-Pro/pull/4/commits/169f1d534186c512788837025fbd3ca415ab70d7) + +My own merge stance: **approve with comments** if you’re comfortable merging without benchmark numbers; **comment but don’t block** if you want the PR merged fast; **request changes** only if this repo requires model-level eval before changing preset defaults. The code and tests shown in the PR look strong enough that I would not reject it on design grounds. [GitHub+2GitHub+2](https://github.com/groxaxo/GPTQ-Pro/pull/4) + +--- + +### User: +just paste it as a reply to copilot please, @copilot and then the answer that i will have to provide with all this suggestions and clear next steps + +--- + +### ChatGPT: +@copilot This PR looks good overall. I like the scope and the way it stays conservative on compatibility: it keeps `quant_method=GPTQ` / `format=GPTQ`, changes the GPTQ-Pro preset from a single `SmoothMSE(...)` fallback to `SmoothAuto()`, and limits the new logic to the offline failsafe path instead of touching runtime kernel behavior. + +A few suggestions before merge: + +1. Please add lightweight validation for `SmoothAuto` config values, either in the dataclass or in `_build_smooth_method_from_dict()`, so invalid configs fail fast. In particular: + +- `mse_steps > 0` +- `0 < mse_maxshrink <= 1` +- `0 < percentile <= 100` +- `0 <= low < high <= 100` + +1. Since `QuantizeConfig.gptq_pro()` now changes default behavior from `SmoothMSE(steps=32, maxshrink=0.9)` to `SmoothAuto()`, please add a short changelog / migration note. This is a real preset behavior change, so users depending on reproducibility should know why outputs may differ after upgrading. +2. `_failsafe_quantize_auto()` is a good design, but the selection criterion is still row-wise reconstruction MSE. That is a sensible local proxy, but it would be great to include one small model-level datapoint in the PR description: + +- perplexity or eval delta on one representative model +- quantization-time overhead versus the previous GPTQ-Pro behavior + +1. The extracted constants for the fallback path are cleaner, but please add a brief code comment explaining why non-`SmoothMSE` candidates use the fallback `32` / `0.8`, while the `SmoothAuto` MSE candidate uses its own configurable defaults (`48` / `0.85`). The code is readable, but the rationale is not obvious. +2. `_row_replace_mask()` looks like the right fix for row-wise replacement when `scale` / `zero` can have different shapes. Please add a targeted regression test for this specific shape-handling case so it does not regress in later refactors. +3. The current failsafe invariant test is good. Please also add a serialization round-trip test for all `SmoothAuto` fields: + +- `include_none` +- `mse_steps` +- `mse_maxshrink` +- `mad_k` +- `percentile` +- `low` +- `high` + +Next steps I’d suggest: + +- add config validation +- add the regression + serialization tests +- add a short migration note +- attach one small benchmark note in the PR description covering quality and quantization-time cost + +Once those are in, I’d be comfortable merging this. + +--- + +### User: +what to do now? Copilot AI +commented +3 minutes ago +@copilot This PR looks good overall. I like the scope and the way it stays conservative on compatibility: it keeps quant_method=GPTQ / format=GPTQ, changes the GPTQ-Pro preset from a singl... + +Addressed in e6dee44: added SmoothAuto validation, the row-mask shape regression test, explicit invalid-config / serialization coverage, and a migration note for the GPTQ-Pro preset change. I also added the benchmark note in the PR progress update; I could not attach a trustworthy model-level perplexity/time datapoint from this sandbox because torch and transformers are not installed here, so I documented that limitation instead. I kept the percentile contract at 0 < value <= 100 to match your requested validation. + +--- + +### ChatGPT: +Do this now: + +**1. Skim the new commit e6dee44** +Check only these things: + +- validation was actually added for `SmoothAuto` +- regression test for row-mask shape exists +- serialization + invalid-config tests exist +- migration note is present + +**2. Don’t block on the sandbox benchmark limitation** +That part is reasonable. If they couldn’t run `torch` / `transformers` there, that is not a red flag by itself. The important part is they documented the limitation instead of inventing numbers. + +**3. If the diff matches what they said, approve and merge** +At this point the review items were addressed. The only thing still missing is a real model-level benchmark, but that is now a follow-up improvement, not a blocker. + +Paste this reply: + +@copilot Thanks — that addresses my review comments. I’m fine with the validation, regression coverage, serialization tests, and migration note being handled in `e6dee44`. + +I also appreciate that you explicitly documented the benchmarking limitation instead of attaching untrustworthy numbers from the sandbox. That’s acceptable for this PR. + +I’ll do a quick pass on the updated diff, but assuming the changes match the summary, this is good to merge. A proper model-level perplexity / quantization-time benchmark can be added later as follow-up validation rather than blocking this PR. + +If you want the stricter version instead, use this: + +@copilot Thanks — the requested code/test/doc changes sound addressed in `e6dee44`, so the remaining gap is only model-level validation. I’m okay with not blocking this PR on sandbox-constrained benchmarks, but I’d like a follow-up issue added for real perplexity + quantization-time measurements on at least one representative model before we treat `SmoothAuto` as fully validated in practice. + +My recommendation: **approve unless you personally want benchmark evidence before merge**. + +--- + +### User: +i want evidence now, ive merged, how to test on my personal comptuer? using a qwen 3.5 modoel 4b + +--- + +### ChatGPT: +I’m checking the repo’s current CLI/API plus the exact Qwen 3.5 4B model naming so I can give you a reproducible A/B test instead of hand-wavy advice. + +--- + +I’ve got the model name and the repo’s quantization API. The cleanest proof is an A/B quantization test: same machine, same calibration set, same seed, same prompts, same eval—only changing the repo commit. + +--- + +Yes. The fastest way to get real evidence is an **A/B quantization test on the same machine** with only one variable changed: **old commit vs new commit**. Use **Qwen/Qwen3.5-4B-Base** for the cleanest text-only-style quantization/eval workflow. GPT-QModel main says Qwen 3.5 support was added on **2026-02-28**, supports **Qwen 1/2/3/3.5**, and its documented quantization flow is `GPTQModel.load(..., quant_config) -> model.quantize(...) -> model.save(...)`. Qwen’s model card also says Qwen3.5 needs the latest `transformers` for serving/testing. [GitHub+2Hugging Face+2](https://github.com/ModelCloud/GPTQModel/blob/main/README.md) + +Your PR specifically changes the **offline failsafe smoothing path** and is meant to help **low-sample fallback blocks**, while keeping standard GPTQ artifacts/runtime compatibility. So don’t do just one benchmark: do **two** quantizations per commit — one that **stresses** failsafe behavior and one that reflects a more normal setup. [GitHub](https://github.com/groxaxo/GPTQ-Pro/pull/4) + +### What to measure + +You want 4 outputs for each commit: + +1. **Quantization wall-clock time** +2. **Peak GPU memory during quantization** +3. **Eval scores** on the same downstream tasks +4. **Deterministic generation diffs** on the same fixed prompts + +For eval, `lm-eval` currently supports Hugging Face backends via `pip install "lm_eval[hf]"`, supports local model paths with `pretrained=/path/to/model`, and lists tasks with `lm-eval ls tasks`. Good quick tasks that exist now include `hellaswag`, `winogrande`, `piqa`, `wikitext`, `arc`, and `gsm8k`. [GitHub+6GitHub+6GitHub+6](https://github.com/EleutherAI/lm-evaluation-harness) + +## 1) Make two repo worktrees + +Run this from your GPTQ-Pro clone: + +```Bashcd ~/GPTQ-Pro# baseline = parent of the first PR code commitBASE_COMMIT=$(git rev-parse 3b6de40^)# patched = last reviewed PR commit with fixesNEW_COMMIT=e6dee44git worktree add ../gptqpro_base "$BASE_COMMIT"git worktree add ../gptqpro_new "$NEW_COMMIT" +Bashcd ~/GPTQ-Pro# baseline = parent of the first PR code commitBASE_COMMIT=$(git rev-parse 3b6de40^)# patched = last reviewed PR commit with fixesNEW_COMMIT=e6dee44git worktree add ../gptqpro_base "$BASE_COMMIT"git worktree add ../gptqpro_new "$NEW_COMMIT" +``` + +That gives you: + +- `../gptqpro_base` = before the PR +- `../gptqpro_new` = after the PR fixes + +## 2) Create identical envs + +Do this once for each worktree, with the same Python / torch stack in both. Since Qwen3.5 wants current Transformers support, install `transformers` from main. [Hugging Face](https://huggingface.co/Qwen/Qwen3.5-4B) + +### Base env + +```Bashcd ../gptqpro_basepython3 -m venv .venvsource .venv/bin/activatepip install -U pip setuptools wheel ninja# install your normal CUDA-matching torch herepip install "transformers @ git+https://github.com/huggingface/transformers.git@main"pip install datasets accelerate sentencepiece "lm_eval[hf]"pip install -e . --no-build-isolationdeactivate +Bashcd ../gptqpro_basepython3 -m venv .venvsource .venv/bin/activatepip install -U pip setuptools wheel ninja# install your normal CUDA-matching torch herepip install "transformers @ git+https://github.com/huggingface/transformers.git@main"pip install datasets accelerate sentencepiece "lm_eval[hf]"pip install -e . --no-build-isolationdeactivate +``` + +### New env + +```Bashcd ../gptqpro_newpython3 -m venv .venvsource .venv/bin/activatepip install -U pip setuptools wheel ninja# install the same torch build as abovepip install "transformers @ git+https://github.com/huggingface/transformers.git@main"pip install datasets accelerate sentencepiece "lm_eval[hf]"pip install -e . --no-build-isolationdeactivate +Bashcd ../gptqpro_newpython3 -m venv .venvsource .venv/bin/activatepip install -U pip setuptools wheel ninja# install the same torch build as abovepip install "transformers @ git+https://github.com/huggingface/transformers.git@main"pip install datasets accelerate sentencepiece "lm_eval[hf]"pip install -e . --no-build-isolationdeactivate +``` + +## 3) Save this quantization script as quantize_once.py + +Put the same file in both worktrees: + +```Pythonimport osimport jsonimport timeimport randomimport argparseimport numpy as npimport torchfrom datasets import load_datasetfrom gptqmodel import GPTQModel, QuantizeConfigdef seed_all(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", required=True) ap.add_argument("--out", required=True) ap.add_argument("--nsamples", type=int, default=128) ap.add_argument("--batch_size", type=int, default=4) ap.add_argument("--seed", type=int, default=1234) ap.add_argument("--group_size", type=int, default=128) ap.add_argument("--bits", type=int, default=4) args = ap.parse_args() seed_all(args.seed) dataset = load_dataset( "allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", split="train", ).select(range(args.nsamples))["text"] qcfg = QuantizeConfig.gptq_pro() qcfg.bits = args.bits qcfg.group_size = args.group_size if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() start = time.perf_counter() model = GPTQModel.load(args.model, qcfg) model.quantize(dataset, batch_size=args.batch_size) elapsed = time.perf_counter() - start os.makedirs(args.out, exist_ok=True) model.save(args.out) peak_mem_bytes = None if torch.cuda.is_available(): peak_mem_bytes = int(torch.cuda.max_memory_allocated()) meta = { "model": args.model, "out": args.out, "nsamples": args.nsamples, "batch_size": args.batch_size, "seed": args.seed, "group_size": args.group_size, "bits": args.bits, "elapsed_sec": elapsed, "peak_cuda_mem_bytes": peak_mem_bytes, } with open(os.path.join(args.out, "ab_meta.json"), "w") as f: json.dump(meta, f, indent=2) print(json.dumps(meta, indent=2))if __name__ == "__main__": main() +Pythonimport osimport jsonimport timeimport randomimport argparseimport numpy as npimport torchfrom datasets import load_datasetfrom gptqmodel import GPTQModel, QuantizeConfigdef seed_all(seed: int): random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", required=True) ap.add_argument("--out", required=True) ap.add_argument("--nsamples", type=int, default=128) ap.add_argument("--batch_size", type=int, default=4) ap.add_argument("--seed", type=int, default=1234) ap.add_argument("--group_size", type=int, default=128) ap.add_argument("--bits", type=int, default=4) args = ap.parse_args() seed_all(args.seed) dataset = load_dataset( "allenai/c4", data_files="en/c4-train.00001-of-01024.json.gz", split="train", ).select(range(args.nsamples))["text"] qcfg = QuantizeConfig.gptq_pro() qcfg.bits = args.bits qcfg.group_size = args.group_size if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() start = time.perf_counter() model = GPTQModel.load(args.model, qcfg) model.quantize(dataset, batch_size=args.batch_size) elapsed = time.perf_counter() - start os.makedirs(args.out, exist_ok=True) model.save(args.out) peak_mem_bytes = None if torch.cuda.is_available(): peak_mem_bytes = int(torch.cuda.max_memory_allocated()) meta = { "model": args.model, "out": args.out, "nsamples": args.nsamples, "batch_size": args.batch_size, "seed": args.seed, "group_size": args.group_size, "bits": args.bits, "elapsed_sec": elapsed, "peak_cuda_mem_bytes": peak_mem_bytes, } with open(os.path.join(args.out, "ab_meta.json"), "w") as f: json.dump(meta, f, indent=2) print(json.dumps(meta, indent=2))if __name__ == "__main__": main() +``` + +## 4) Run the two key quantization cases + +Use the same model, same seed, same GPU, same everything. + +### Case A: stress the PR + +This is the important one, because your change targets **low-sample fallback blocks**. + +#### Baseline + +```Bashcd ../gptqpro_basesource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_base_ns128 \ --nsamples 128 \ --batch_size 4 \ --seed 1234deactivate +Bashcd ../gptqpro_basesource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_base_ns128 \ --nsamples 128 \ --batch_size 4 \ --seed 1234deactivate +``` + +#### Patched + +```Bashcd ../gptqpro_newsource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_new_ns128 \ --nsamples 128 \ --batch_size 4 \ --seed 1234deactivate +Bashcd ../gptqpro_newsource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_new_ns128 \ --nsamples 128 \ --batch_size 4 \ --seed 1234deactivate +``` + +### Case B: more normal quantization + +```Bashcd ../gptqpro_basesource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_base_ns512 \ --nsamples 512 \ --batch_size 4 \ --seed 1234deactivate +Bashcd ../gptqpro_basesource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_base_ns512 \ --nsamples 512 \ --batch_size 4 \ --seed 1234deactivate +``` + +```Bashcd ../gptqpro_newsource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_new_ns512 \ --nsamples 512 \ --batch_size 4 \ --seed 1234deactivate +Bashcd ../gptqpro_newsource .venv/bin/activateCUDA_VISIBLE_DEVICES=0 python quantize_once.py \ --model Qwen/Qwen3.5-4B-Base \ --out ./out/qwen35_4b_new_ns512 \ --nsamples 512 \ --batch_size 4 \ --seed 1234deactivate +``` + +## 5) Run lm-eval on both outputs + +Start with a fast suite: + +```BashTASKS="wikitext,hellaswag,winogrande,piqa" +BashTASKS="wikitext,hellaswag,winogrande,piqa" +``` + +### Baseline + +```Bashcd ../gptqpro_basesource .venv/bin/activatelm_eval --model hf \ --model_args pretrained=$(pwd)/out/qwen35_4b_base_ns128,trust_remote_code=True \ --tasks $TASKS \ --device cuda:0 \ --batch_size auto:4 \ --output_path $(pwd)/eval_ns128deactivate +Bashcd ../gptqpro_basesource .venv/bin/activatelm_eval --model hf \ --model_args pretrained=$(pwd)/out/qwen35_4b_base_ns128,trust_remote_code=True \ --tasks $TASKS \ --device cuda:0 \ --batch_size auto:4 \ --output_path $(pwd)/eval_ns128deactivate +``` + +### Patched + +```Bashcd ../gptqpro_newsource .venv/bin/activatelm_eval --model hf \ --model_args pretrained=$(pwd)/out/qwen35_4b_new_ns128,trust_remote_code=True \ --tasks $TASKS \ --device cuda:0 \ --batch_size auto:4 \ --output_path $(pwd)/eval_ns128deactivate +Bashcd ../gptqpro_newsource .venv/bin/activatelm_eval --model hf \ --model_args pretrained=$(pwd)/out/qwen35_4b_new_ns128,trust_remote_code=True \ --tasks $TASKS \ --device cuda:0 \ --batch_size auto:4 \ --output_path $(pwd)/eval_ns128deactivate +``` + +Then repeat for the `nsamples=512` models. + +If you want a slower but more convincing second pass, add `arc` and `gsm8k`. Those tasks exist in the current harness too. [GitHub+1](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/README.md) + +## 6) Compare deterministic generations too + +Save this as `gen_compare.py` in both worktrees: + +```Pythonimport jsonimport argparseimport torchfrom transformers import AutoTokenizer, AutoModelForCausalLMPROMPTS = [ "Explain why TCP head-of-line blocking happens.", "Write a Python function to reverse a linked list.", "Summarize the causes of inflation in simple terms.", "What is the difference between overfitting and underfitting?", "Give three practical uses of embeddings in search systems.", "Why does quantization sometimes hurt coding ability more than chatting?", "Solve: If a train travels 60 km in 45 minutes, what is its speed in km/h?", "Write a bash command to find the 20 largest files under /var/log.", "Explain virtual memory like I'm 12 years old.", "What are the tradeoffs between GPTQ and AWQ?"]def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", required=True) ap.add_argument("--out", required=True) args = ap.parse_args() tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( args.model, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, ) rows = [] for p in PROMPTS: inputs = tok(p, return_tensors="pt").to(model.device) with torch.no_grad(): out = model.generate( **inputs, do_sample=False, temperature=0.0, max_new_tokens=128, pad_token_id=tok.eos_token_id, ) text = tok.decode(out[0], skip_special_tokens=True) rows.append({"prompt": p, "output": text}) with open(args.out, "w") as f: json.dump(rows, f, indent=2, ensure_ascii=False)if __name__ == "__main__": main() +Pythonimport jsonimport argparseimport torchfrom transformers import AutoTokenizer, AutoModelForCausalLMPROMPTS = [ "Explain why TCP head-of-line blocking happens.", "Write a Python function to reverse a linked list.", "Summarize the causes of inflation in simple terms.", "What is the difference between overfitting and underfitting?", "Give three practical uses of embeddings in search systems.", "Why does quantization sometimes hurt coding ability more than chatting?", "Solve: If a train travels 60 km in 45 minutes, what is its speed in km/h?", "Write a bash command to find the 20 largest files under /var/log.", "Explain virtual memory like I'm 12 years old.", "What are the tradeoffs between GPTQ and AWQ?"]def main(): ap = argparse.ArgumentParser() ap.add_argument("--model", required=True) ap.add_argument("--out", required=True) args = ap.parse_args() tok = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained( args.model, trust_remote_code=True, device_map="auto", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, ) rows = [] for p in PROMPTS: inputs = tok(p, return_tensors="pt").to(model.device) with torch.no_grad(): out = model.generate( **inputs, do_sample=False, temperature=0.0, max_new_tokens=128, pad_token_id=tok.eos_token_id, ) text = tok.decode(out[0], skip_special_tokens=True) rows.append({"prompt": p, "output": text}) with open(args.out, "w") as f: json.dump(rows, f, indent=2, ensure_ascii=False)if __name__ == "__main__": main() +``` + +Run it on the two quantized models and diff the JSON files. + +## 7) How to interpret the result + +If your PR is genuinely helping, the **patched** build should usually show: + +- **same or slightly slower quantization time** in the stressed run +- **same or lower wikitext perplexity / better language-model score** +- **same or better accuracy** on `hellaswag`, `winogrande`, `piqa` +- **cleaner deterministic outputs** on edge prompts +- the **biggest difference at nsamples=128**, because that is where the new adaptive failsafe search should matter most [GitHub](https://github.com/groxaxo/GPTQ-Pro/pull/4) + +What would worry me: + +- patched build is materially slower **and** +- `wikitext` gets worse **and** +- downstream tasks are flat or worse + +That would mean the extra best-of smoothing search is adding complexity without buying useful quality. + +## 8) Minimal verdict rule + +Use this decision rule: + +- **Keep merged** if patched is better on `wikitext` and at least neutral elsewhere +- **Definitely keep merged** if patched wins mainly at `nsamples=128` and is neutral at `512` +- **Reconsider defaulting it** only if patched is slower and quality is not measurably better + +The reason I’d weight `nsamples=128` heavily is that your PR is not a general kernel rewrite; it is a targeted improvement to the **failsafe path for low-sample fallback blocks**. That is exactly where it should prove itself. [GitHub](https://github.com/groxaxo/GPTQ-Pro/pull/4) + +If you want, I can turn this into a **single bash harness** that runs the full A/B and prints a compact before/after table. + +--- + +### User: +yes please, provide a full testing suite , using 1x 3090 , gpu 0 for this task please. + +--- + +### ChatGPT: +I’m turning this into a runnable harness instead of loose snippets: one command to quantize old vs new, evaluate both, compare generations, and emit a summary for GPU 0 only. + +--- + +I’ve got the structure: one bootstrap for worktrees/envs, one full runner for quantize+eval+generation, and one summarizer that spits out a markdown verdict. + +--- + +Done. + +I packaged a runnable A/B harness here: + +- +- + +It’s built around GPTQModel’s documented `load -> quantize -> save` flow, uses `Qwen/Qwen3.5-4B-Base` as the default 4B test model, and evaluates the quantized outputs with lm-eval’s Hugging Face backend against local model paths. [GitHub+3GitHub+3Hugging Face+3](https://github.com/ModelCloud/GPTQModel) + +Run it like this: + +```Bashcd /path/where/you/unzipped/gptq_pro_ab_suitechmod +x scripts/*.sh./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128'./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --output-root ~/gptq_ab_results +Bashcd /path/where/you/unzipped/gptq_pro_ab_suitechmod +x scripts/*.sh./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128'./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --output-root ~/gptq_ab_results +``` + +What you’ll get: + +- `~/gptq_ab_results/summary.md` +- `~/gptq_ab_results/summary.json` + +The suite runs: + +- base vs new worktrees +- `nsamples=128` and `nsamples=512` +- quantization timing + peak CUDA memory +- lm-eval task results +- deterministic generation comparisons on fixed prompts + +GPTQModel’s README also explicitly recommends using `lm-eval` or `evalplus` rather than plain perplexity as the main post-quant quality check, and lm-eval documents both the HF backend install and `lm-eval ls tasks` for checking task availability. [GitHub+1](https://github.com/ModelCloud/GPTQModel) + +One thing to watch: swap the PyTorch install command to match your exact CUDA stack before bootstrapping. + +--- + +### User: +i cant download it, can you pass me the files? + +--- + +### ChatGPT: +I’m unpacking the suite so I can paste every file directly here in a usable form. + +--- + +Yes — recreate it locally with this. + +Run this in an empty folder: + +```Bashmkdir -p gptq_pro_ab_suite/scriptscd gptq_pro_ab_suitecat > README.md <<'EOF'# GPTQ-Pro A/B test suite for Qwen3.5-4B on GPU 0This suite is designed to verify whether your merged PR materially improves quantization quality in practice.It runs a strict A/B comparison:- **base** = repo state before the PR- **new** = repo state after the PR- same machine- same GPU (**GPU 0 only**)- same model- same calibration data- same seed- same eval tasksIt produces:- quantization timing + peak CUDA memory- `lm-eval` results for each quantized model- deterministic generation outputs for side-by-side inspection- a compact markdown + JSON summary## Files- `scripts/bootstrap_worktrees_and_envs.sh` - creates base/new worktrees - creates two virtualenvs - installs dependencies in both- `scripts/run_full_suite.sh` - runs quantization, lm-eval, deterministic generation, and final summary- `scripts/quantize_once.py` - quantizes one model once and records timing/memory metadata- `scripts/gen_compare.py` - runs fixed deterministic prompts- `scripts/summarize_ab.py` - aggregates metrics into JSON + Markdown## Requirements- Linux- Python 3.10+- NVIDIA driver + CUDA-compatible PyTorch- 1x RTX 3090 on **GPU 0**- enough disk for two worktrees + downloaded model + four quantized outputs## Recommended modelUse:- `Qwen/Qwen3.5-4B-Base`## 1) Bootstrap the two repos and environmentsReplace the torch install command with the one matching your CUDA stack.```bashchmod +x scripts/*.sh./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128' +Bashmkdir -p gptq_pro_ab_suite/scriptscd gptq_pro_ab_suitecat > README.md <<'EOF'# GPTQ-Pro A/B test suite for Qwen3.5-4B on GPU 0This suite is designed to verify whether your merged PR materially improves quantization quality in practice.It runs a strict A/B comparison:- **base** = repo state before the PR- **new** = repo state after the PR- same machine- same GPU (**GPU 0 only**)- same model- same calibration data- same seed- same eval tasksIt produces:- quantization timing + peak CUDA memory- `lm-eval` results for each quantized model- deterministic generation outputs for side-by-side inspection- a compact markdown + JSON summary## Files- `scripts/bootstrap_worktrees_and_envs.sh` - creates base/new worktrees - creates two virtualenvs - installs dependencies in both- `scripts/run_full_suite.sh` - runs quantization, lm-eval, deterministic generation, and final summary- `scripts/quantize_once.py` - quantizes one model once and records timing/memory metadata- `scripts/gen_compare.py` - runs fixed deterministic prompts- `scripts/summarize_ab.py` - aggregates metrics into JSON + Markdown## Requirements- Linux- Python 3.10+- NVIDIA driver + CUDA-compatible PyTorch- 1x RTX 3090 on **GPU 0**- enough disk for two worktrees + downloaded model + four quantized outputs## Recommended modelUse:- `Qwen/Qwen3.5-4B-Base`## 1) Bootstrap the two repos and environmentsReplace the torch install command with the one matching your CUDA stack.```bashchmod +x scripts/*.sh./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128' +``` + +This creates: + +- `~/GPTQ-Pro_ab_base` +- `~/GPTQ-Pro_ab_new` + +## 2) Run the full suite on GPU 0 + +```Bash./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --output-root ~/gptq_ab_results +Bash./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --output-root ~/gptq_ab_results +``` + +That runs two calibration regimes: + +- `nsamples=128` -> stress test for the PR +- `nsamples=512` -> more normal case + +## 3) Read the final verdict + +Outputs land under: + +- `~/gptq_ab_results/summary.md` +- `~/gptq_ab_results/summary.json` +EOF + +cat > scripts/bootstrap_worktrees_and_envs.sh <<'EOF' +#!/usr/bin/env bash +set -euo pipefail + +usage() { +cat <_ab_base and _ab_new +- Creates .venv inside each worktree +- Installs identical dependencies in both envs +USAGE +} + +REPO_ROOT="" +BASE_COMMIT="" +NEW_COMMIT="" +TORCH_INSTALL_CMD="" + +while [[ $# -gt 0 ]]; do +case "$1" in +--repo-root) REPO_ROOT="$2"; shift 2 ;; +--base-commit) BASE_COMMIT="$2"; shift 2 ;; +--new-commit) NEW_COMMIT="$2"; shift 2 ;; +--torch-install-cmd) TORCH_INSTALL_CMD="$2"; shift 2 ;; +-h|--help) usage; exit 0 ;; +*) echo "Unknown arg: $1" >&2; usage; exit 1 ;; +esac +done + +[[ -n "$REPO_ROOT" ]] || { echo "--repo-root is required" >&2; exit 1; } +[[ -n "$BASE_COMMIT" ]] || { echo "--base-commit is required" >&2; exit 1; } +[[ -n "$NEW_COMMIT" ]] || { echo "--new-commit is required" >&2; exit 1; } +[[ -n "$TORCH_INSTALL_CMD" ]] || { echo "--torch-install-cmd is required" >&2; exit 1; } + +REPO_ROOT="$(cd "$REPO_ROOT" && pwd)" +BASE_WT="${REPO_ROOT}_ab_base" +NEW_WT="${REPO_ROOT}_ab_new" + +ensure_worktree() { +local wt="$1" +local rev="$2" +if [[ -d "$wt/.git" || -f "$wt/.git" ]]; then +echo "[info] worktree already exists: $wt" +else +echo "[info] creating worktree $wt at $rev" +git -C "$REPO_ROOT" worktree add "$wt" "$rev" +fi +} + +install_env() { +local wt="$1" +echo "[info] setting up env in $wt" +pushd "$wt" >/dev/null +python3 -m venv .venv +source .venv/bin/activate +python -m pip install -U pip setuptools wheel ninja packaging +eval "$TORCH_INSTALL_CMD" +python -m pip install "transformers @ git+[https://github.com/huggingface/transformers.git@main]()" +python -m pip install datasets accelerate sentencepiece psutil "lm_eval[hf]" +python -m pip install -e . --no-build-isolation +python - <<'PY' +import sys +try: +import torch +print('[info] torch', torch.**version**, 'cuda=', torch.version.cuda, 'available=', torch.cuda.is_available()) +except Exception as e: +print('[warn] torch import failed:', e, file=sys.stderr) +raise +PY +deactivate +popd >/dev/null +} + +ensure_worktree "$BASE_WT" "$BASE_COMMIT" +ensure_worktree "$NEW_WT" "$NEW_COMMIT" +install_env "$BASE_WT" +install_env "$NEW_WT" + +echo "[done]" +echo "BASE_WT=$BASE_WT" +echo "NEW_WT=$NEW_WT" +EOF + +cat > scripts/quantize_once.py <<'EOF' +#!/usr/bin/env python3 +import argparse +import json +import random +import time +from pathlib import Path + +import numpy as np +import torch +from datasets import load_dataset +from gptqmodel import GPTQModel, QuantizeConfig + +def seed_all(seed: int) -> None: +random.seed(seed) +np.random.seed(seed) +torch.manual_seed(seed) +if torch.cuda.is_available(): +torch.cuda.manual_seed_all(seed) + +def main() -> int: +ap = argparse.ArgumentParser() +ap.add_argument("--model-id", required=True) +ap.add_argument("--out-dir", required=True) +ap.add_argument("--nsamples", type=int, default=128) +ap.add_argument("--batch-size", type=int, default=4) +ap.add_argument("--seed", type=int, default=1234) +ap.add_argument("--bits", type=int, default=4) +ap.add_argument("--group-size", type=int, default=128) +ap.add_argument("--dataset", default="allenai/c4") +ap.add_argument("--dataset-file", default="en/c4-train.00001-of-01024.json.gz") +args = ap.parse_args() + +```out_dir = Path(args.out_dir)out_dir.mkdir(parents=True, exist_ok=True)seed_all(args.seed)if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats()ds = load_dataset( args.dataset, data_files=args.dataset_file, split="train",).select(range(args.nsamples))calibration_dataset = ds["text"]qcfg = QuantizeConfig.gptq_pro()qcfg.bits = args.bitsqcfg.group_size = args.group_sizestart = time.perf_counter()model = GPTQModel.load(args.model_id, qcfg)model.quantize(calibration_dataset, batch_size=args.batch_size)elapsed = time.perf_counter() - startmodel.save(str(out_dir))peak_cuda_mem_bytes = Noneif torch.cuda.is_available(): peak_cuda_mem_bytes = int(torch.cuda.max_memory_allocated())metadata = { "model_id": args.model_id, "out_dir": str(out_dir), "nsamples": args.nsamples, "batch_size": args.batch_size, "seed": args.seed, "bits": args.bits, "group_size": args.group_size, "elapsed_sec": elapsed, "peak_cuda_mem_bytes": peak_cuda_mem_bytes, "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0, "cuda_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, "torch_version": torch.__version__, "torch_cuda": torch.version.cuda,}with open(out_dir / "ab_quant_meta.json", "w", encoding="utf-8") as f: json.dump(metadata, f, indent=2)print(json.dumps(metadata, indent=2))return 0 +out_dir = Path(args.out_dir)out_dir.mkdir(parents=True, exist_ok=True)seed_all(args.seed)if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats()ds = load_dataset( args.dataset, data_files=args.dataset_file, split="train",).select(range(args.nsamples))calibration_dataset = ds["text"]qcfg = QuantizeConfig.gptq_pro()qcfg.bits = args.bitsqcfg.group_size = args.group_sizestart = time.perf_counter()model = GPTQModel.load(args.model_id, qcfg)model.quantize(calibration_dataset, batch_size=args.batch_size)elapsed = time.perf_counter() - startmodel.save(str(out_dir))peak_cuda_mem_bytes = Noneif torch.cuda.is_available(): peak_cuda_mem_bytes = int(torch.cuda.max_memory_allocated())metadata = { "model_id": args.model_id, "out_dir": str(out_dir), "nsamples": args.nsamples, "batch_size": args.batch_size, "seed": args.seed, "bits": args.bits, "group_size": args.group_size, "elapsed_sec": elapsed, "peak_cuda_mem_bytes": peak_cuda_mem_bytes, "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0, "cuda_name": torch.cuda.get_device_name(0) if torch.cuda.is_available() else None, "torch_version": torch.__version__, "torch_cuda": torch.version.cuda,}with open(out_dir / "ab_quant_meta.json", "w", encoding="utf-8") as f: json.dump(metadata, f, indent=2)print(json.dumps(metadata, indent=2))return 0 +``` + +if **name** == "**main**": +raise SystemExit(main()) +EOF + +cat > scripts/gen_compare.py <<'EOF' +#!/usr/bin/env python3 +import argparse +import json +from pathlib import Path + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +PROMPTS = [ +"Explain why TCP head-of-line blocking happens.", +"Write a Python function to reverse a linked list.", +"Summarize the causes of inflation in simple terms.", +"What is the difference between overfitting and underfitting?", +"Give three practical uses of embeddings in search systems.", +"Why can quantization hurt coding ability more than casual chat?", +"Solve: If a train travels 60 km in 45 minutes, what is its speed in km/h?", +"Write a bash command to find the 20 largest files under /var/log.", +"Explain virtual memory like I'm 12 years old.", +"What are the tradeoffs between GPTQ and AWQ?", +] + +def main() -> int: +ap = argparse.ArgumentParser() +ap.add_argument("--model-path", required=True) +ap.add_argument("--out-json", required=True) +ap.add_argument("--max-new-tokens", type=int, default=160) +args = ap.parse_args() + +```model_path = Path(args.model_path)out_json = Path(args.out_json)out_json.parent.mkdir(parents=True, exist_ok=True)tok = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)dtype = torch.float16 if torch.cuda.is_available() else torch.float32model = AutoModelForCausalLM.from_pretrained( str(model_path), trust_remote_code=True, device_map="auto", torch_dtype=dtype,)rows = []for prompt in PROMPTS: inputs = tok(prompt, return_tensors="pt") inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): out = model.generate( **inputs, do_sample=False, temperature=0.0, top_p=1.0, max_new_tokens=args.max_new_tokens, pad_token_id=tok.eos_token_id, ) text = tok.decode(out[0], skip_special_tokens=True) rows.append({"prompt": prompt, "output": text})with open(out_json, "w", encoding="utf-8") as f: json.dump(rows, f, indent=2, ensure_ascii=False)print(f"wrote {out_json}")return 0 +model_path = Path(args.model_path)out_json = Path(args.out_json)out_json.parent.mkdir(parents=True, exist_ok=True)tok = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True)dtype = torch.float16 if torch.cuda.is_available() else torch.float32model = AutoModelForCausalLM.from_pretrained( str(model_path), trust_remote_code=True, device_map="auto", torch_dtype=dtype,)rows = []for prompt in PROMPTS: inputs = tok(prompt, return_tensors="pt") inputs = {k: v.to(model.device) for k, v in inputs.items()} with torch.no_grad(): out = model.generate( **inputs, do_sample=False, temperature=0.0, top_p=1.0, max_new_tokens=args.max_new_tokens, pad_token_id=tok.eos_token_id, ) text = tok.decode(out[0], skip_special_tokens=True) rows.append({"prompt": prompt, "output": text})with open(out_json, "w", encoding="utf-8") as f: json.dump(rows, f, indent=2, ensure_ascii=False)print(f"wrote {out_json}")return 0 +``` + +if **name** == "**main**": +raise SystemExit(main()) +EOF + +cat > scripts/summarize_ab.py <<'EOF' +#!/usr/bin/env python3 +import argparse +import difflib +import json +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +HIGHER_IS_BETTER = {"acc", "acc_norm", "exact_match", "mc1", "mc2", "f1", "bleu"} +LOWER_IS_BETTER = {"word_perplexity", "perplexity", "byte_perplexity", "bpb", "bpc"} + +def load_json(path: Path): +with open(path, "r", encoding="utf-8") as f: +return json.load(f) + +def find_lm_eval_json(eval_dir: Path) -> Optional[Path]: +candidates = sorted(eval_dir.rglob("*.json")) +for p in candidates: +try: +data = load_json(p) +except Exception: +continue +if isinstance(data, dict) and "results" in data: +return p +return None + +def extract_task_metrics(eval_json: Dict) -> Dict[str, Dict[str, float]]: +results = eval_json.get("results", {}) +out: Dict[str, Dict[str, float]] = {} +for task, task_metrics in results.items(): +flat: Dict[str, float] = {} +if not isinstance(task_metrics, dict): +continue +for k, v in task_metrics.items(): +if isinstance(v, (int, float)): +flat[k] = float(v) +out[task] = flat +return out + +def choose_primary_metric(task: str, metrics: Dict[str, float]) -> Tuple[Optional[str], Optional[float], Optional[str]]: +preferred = [ +"acc_norm,none", "acc,none", "exact_match,strict-match", "exact_match,none", +"mc2,none", "mc1,none", "word_perplexity,none", "perplexity,none", +"byte_perplexity,none", "bpb,none", "bpc,none", +"acc_norm", "acc", "exact_match", "mc2", "mc1", "word_perplexity", +"perplexity", "byte_perplexity", "bpb", "bpc", +] +for key in preferred: +if key in metrics: +short = key.split(",", 1)[0] +direction = "higher" if short in HIGHER_IS_BETTER else "lower" if short in LOWER_IS_BETTER else "unknown" +return key, metrics[key], direction +return None, None, None + +def compare_generations(base_path: Path, new_path: Path) -> Dict: +base_rows = load_json(base_path) +new_rows = load_json(new_path) +exact = 0 +ratios: List[float] = [] +changed_examples = [] +for b, n in zip(base_rows, new_rows): +br = b.get("output", "") +nr = n.get("output", "") +if br == nr: +exact += 1 +ratio = difflib.SequenceMatcher(None, br, nr).ratio() +ratios.append(ratio) +if br != nr and len(changed_examples) < 3: +changed_examples.append({ +"prompt": b.get("prompt", ""), +"base_output": br[:900], +"new_output": nr[:900], +"similarity": ratio, +}) +return { +"num_prompts": len(base_rows), +"exact_match_count": exact, +"exact_match_rate": exact / max(1, len(base_rows)), +"avg_similarity": sum(ratios) / max(1, len(ratios)), +"changed_examples": changed_examples, +} + +def load_quant_meta(path: Path) -> Dict: +return load_json(path) + +def build_case_summary(case_dir: Path) -> Dict: +base_quant = load_quant_meta(case_dir / "base" / "ab_quant_meta.json") +new_quant = load_quant_meta(case_dir / "new" / "ab_quant_meta.json") + +```base_eval_json_path = find_lm_eval_json(case_dir / "base" / "eval")new_eval_json_path = find_lm_eval_json(case_dir / "new" / "eval")if base_eval_json_path is None or new_eval_json_path is None: raise FileNotFoundError(f"Could not find lm-eval JSON under {case_dir}")base_eval = load_json(base_eval_json_path)new_eval = load_json(new_eval_json_path)base_metrics = extract_task_metrics(base_eval)new_metrics = extract_task_metrics(new_eval)task_summaries = []all_tasks = sorted(set(base_metrics) | set(new_metrics))for task in all_tasks: b = base_metrics.get(task, {}) n = new_metrics.get(task, {}) metric_name, bval, direction = choose_primary_metric(task, b) metric_name2, nval, direction2 = choose_primary_metric(task, n) chosen = metric_name or metric_name2 direction = direction or direction2 or "unknown" if chosen is not None: bval = b.get(chosen, bval) nval = n.get(chosen, nval) delta = None winner = "tie" if bval is not None and nval is not None: delta = nval - bval if direction == "higher": winner = "new" if nval > bval else "base" if bval > nval else "tie" elif direction == "lower": winner = "new" if nval < bval else "base" if bval < nval else "tie" task_summaries.append({ "task": task, "metric": chosen, "direction": direction, "base": bval, "new": nval, "delta_new_minus_base": delta, "winner": winner, })gen = compare_generations(case_dir / "base" / "generations.json", case_dir / "new" / "generations.json")quant_delta = { "elapsed_sec_new_minus_base": new_quant["elapsed_sec"] - base_quant["elapsed_sec"], "peak_cuda_mem_bytes_new_minus_base": ( (new_quant.get("peak_cuda_mem_bytes") or 0) - (base_quant.get("peak_cuda_mem_bytes") or 0) ),}return { "case": case_dir.name, "base_quant": base_quant, "new_quant": new_quant, "quant_delta": quant_delta, "base_eval_json": str(base_eval_json_path), "new_eval_json": str(new_eval_json_path), "tasks": task_summaries, "generation_compare": gen,} +base_eval_json_path = find_lm_eval_json(case_dir / "base" / "eval")new_eval_json_path = find_lm_eval_json(case_dir / "new" / "eval")if base_eval_json_path is None or new_eval_json_path is None: raise FileNotFoundError(f"Could not find lm-eval JSON under {case_dir}")base_eval = load_json(base_eval_json_path)new_eval = load_json(new_eval_json_path)base_metrics = extract_task_metrics(base_eval)new_metrics = extract_task_metrics(new_eval)task_summaries = []all_tasks = sorted(set(base_metrics) | set(new_metrics))for task in all_tasks: b = base_metrics.get(task, {}) n = new_metrics.get(task, {}) metric_name, bval, direction = choose_primary_metric(task, b) metric_name2, nval, direction2 = choose_primary_metric(task, n) chosen = metric_name or metric_name2 direction = direction or direction2 or "unknown" if chosen is not None: bval = b.get(chosen, bval) nval = n.get(chosen, nval) delta = None winner = "tie" if bval is not None and nval is not None: delta = nval - bval if direction == "higher": winner = "new" if nval > bval else "base" if bval > nval else "tie" elif direction == "lower": winner = "new" if nval < bval else "base" if bval < nval else "tie" task_summaries.append({ "task": task, "metric": chosen, "direction": direction, "base": bval, "new": nval, "delta_new_minus_base": delta, "winner": winner, })gen = compare_generations(case_dir / "base" / "generations.json", case_dir / "new" / "generations.json")quant_delta = { "elapsed_sec_new_minus_base": new_quant["elapsed_sec"] - base_quant["elapsed_sec"], "peak_cuda_mem_bytes_new_minus_base": ( (new_quant.get("peak_cuda_mem_bytes") or 0) - (base_quant.get("peak_cuda_mem_bytes") or 0) ),}return { "case": case_dir.name, "base_quant": base_quant, "new_quant": new_quant, "quant_delta": quant_delta, "base_eval_json": str(base_eval_json_path), "new_eval_json": str(new_eval_json_path), "tasks": task_summaries, "generation_compare": gen,} +``` + +def overall_verdict(case_summaries: List[Dict]) -> str: +new_wins = 0 +base_wins = 0 +for case in case_summaries: +for t in case["tasks"]: +if t["winner"] == "new": +new_wins += 1 +elif t["winner"] == "base": +base_wins += 1 +if new_wins > base_wins: +return "new looks better overall" +if base_wins > new_wins: +return "base looks better overall" +return "mixed / inconclusive" + +def render_md(case_summaries: List[Dict], summary_path: Path) -> None: +lines: List[str] = [] +lines.append("# GPTQ-Pro A/B summary") +lines.append("") +lines.append(f"Overall verdict: **{overall_verdict(case_summaries)}**") +lines.append("") +for case in case_summaries: +lines.append(f"## {case['case']}") +lines.append("") +lines.append("### Quantization") +lines.append("") +lines.append(f"- base elapsed: {case['base_quant']['elapsed_sec']:.2f}s") +lines.append(f"- new elapsed: {case['new_quant']['elapsed_sec']:.2f}s") +lines.append(f"- delta (new - base): {case['quant_delta']['elapsed_sec_new_minus_base']:.2f}s") +lines.append(f"- base peak CUDA mem: {case['base_quant'].get('peak_cuda_mem_bytes')}") +lines.append(f"- new peak CUDA mem: {case['new_quant'].get('peak_cuda_mem_bytes')}") +lines.append(f"- delta peak mem (new - base): {case['quant_delta']['peak_cuda_mem_bytes_new_minus_base']}") +lines.append("") +lines.append("### Eval") +lines.append("") +lines.append("| task | metric | direction | base | new | delta (new-base) | winner |") +lines.append("|---|---|---:|---:|---:|---:|---|") +for t in case["tasks"]: +lines.append( +f"| {t['task']} | {t['metric']} | {t['direction']} | {t['base']} | {t['new']} | {t['delta_new_minus_base']} | {t['winner']} |" +) +lines.append("") +gc = case["generation_compare"] +lines.append("### Deterministic generations") +lines.append("") +lines.append(f"- exact matches: {gc['exact_match_count']} / {gc['num_prompts']}") +lines.append(f"- exact match rate: {gc['exact_match_rate']:.3f}") +lines.append(f"- average similarity: {gc['avg_similarity']:.3f}") +if gc["changed_examples"]: +lines.append("") +lines.append("#### Sample changed outputs") +lines.append("") +for ex in gc["changed_examples"]: +lines.append(f"- Prompt: `{ex['prompt']}`") +lines.append(f" - similarity: {ex['similarity']:.3f}") +lines.append(f" - base: {ex['base_output'][:220].replace(chr(10), ' ')}") +lines.append(f" - new: {ex['new_output'][:220].replace(chr(10), ' ')}") +lines.append("") +summary_path.write_text("\n".join(lines), encoding="utf-8") + +def main() -> int: +ap = argparse.ArgumentParser() +ap.add_argument("--output-root", required=True) +args = ap.parse_args() + +```root = Path(args.output_root)cases = []for name in ("ns128", "ns512"): case_dir = root / name if case_dir.exists(): cases.append(build_case_summary(case_dir))if not cases: raise FileNotFoundError(f"No case directories found under {root}")summary_json = root / "summary.json"summary_md = root / "summary.md"payload = { "overall_verdict": overall_verdict(cases), "cases": cases,}summary_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")render_md(cases, summary_md)print(summary_md)print(summary_json)return 0 +root = Path(args.output_root)cases = []for name in ("ns128", "ns512"): case_dir = root / name if case_dir.exists(): cases.append(build_case_summary(case_dir))if not cases: raise FileNotFoundError(f"No case directories found under {root}")summary_json = root / "summary.json"summary_md = root / "summary.md"payload = { "overall_verdict": overall_verdict(cases), "cases": cases,}summary_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")render_md(cases, summary_md)print(summary_md)print(summary_json)return 0 +``` + +if **name** == "**main**": +raise SystemExit(main()) +EOF + +cat > scripts/run_full_suite.sh <<'EOF' +#!/usr/bin/env bash +set -euo pipefail + +usage() { +cat <&2; usage; exit 1 ;; +esac +done + +[[ -n "$REPO_ROOT" ]] || { echo "--repo-root is required" >&2; exit 1; } +[[ -n "$BASE_COMMIT" ]] || { echo "--base-commit is required" >&2; exit 1; } +[[ -n "$NEW_COMMIT" ]] || { echo "--new-commit is required" >&2; exit 1; } + +REPO_ROOT="$(cd "$REPO_ROOT" && pwd)" +BASE_WT="${REPO_ROOT}_ab_base" +NEW_WT="${REPO_ROOT}_ab_new" +SELF_DIR="$(cd "$(dirname "$0")" && pwd)" +OUTPUT_ROOT="${OUTPUT_ROOT/#~/$HOME}" +mkdir -p "$OUTPUT_ROOT" + +if [[ ! -d "$BASE_WT/.venv" || ! -d "$NEW_WT/.venv" ]]; then +echo "[error] missing virtualenvs. Run bootstrap_worktrees_and_envs.sh first." >&2 +exit 1 +fi + +write_meta() { +cat > "$OUTPUT_ROOT/run_meta.json" </dev/null 2>&1; then +lm-eval "$@" +else +python -m lm_eval "$@" +fi +} + +run_one_side() { +local side="$1" +local wt="$2" +local case_name="$3" +local nsamples="$4" + +local case_root="$OUTPUT_ROOT/$case_name/$side" +local model_out="$case_root/model" +local eval_out="$case_root/eval" +local gen_out="$case_root/generations.json" +mkdir -p "$case_root" "$eval_out" + +echo "[info] running $side $case_name in $wt" +pushd "$wt" >/dev/null +source .venv/bin/activate +export CUDA_VISIBLE_DEVICES="$GPU" +export TOKENIZERS_PARALLELISM=false + +python "$SELF_DIR/quantize_once.py" +--model-id "$MODEL_ID" +--out-dir "$model_out" +--nsamples "$nsamples" +--batch-size "$BATCH_SIZE" +--seed "$SEED" +--bits "$BITS" +--group-size "$GROUP_SIZE" | tee "$case_root/quantize.log" + +run_lm_eval +--model hf +--model_args "pretrained=$model_out,trust_remote_code=True" +--tasks "$TASKS" +--device cuda:0 +--batch_size auto:4 +--output_path "$eval_out" +--seed "$SEED" +--trust_remote_code | tee "$case_root/lm_eval.log" + +python "$SELF_DIR/gen_compare.py" +--model-path "$model_out" +--out-json "$gen_out" | tee "$case_root/generation.log" + +deactivate +popd >/dev/null +} + +write_meta +run_one_side base "$BASE_WT" ns128 "$SMALL_NSAMPLES" +run_one_side new "$NEW_WT" ns128 "$SMALL_NSAMPLES" +run_one_side base "$BASE_WT" ns512 "$LARGE_NSAMPLES" +run_one_side new "$NEW_WT" ns512 "$LARGE_NSAMPLES" + +python "$SELF_DIR/summarize_ab.py" --output-root "$OUTPUT_ROOT" | tee "$OUTPUT_ROOT/summary_paths.txt" + +echo "[done] summary: $OUTPUT_ROOT/summary.md" +EOF + +chmod +x scripts/*.sh scripts/*.py + +```Then run:```bashcd gptq_pro_ab_suite./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128'./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --output-root ~/gptq_ab_results +Then run:```bashcd gptq_pro_ab_suite./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128'./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --output-root ~/gptq_ab_results +``` + +If you want, next I can give you a **patched v2** that also logs: + +- `nvidia-smi` snapshots during quantization +- exact git SHAs from each worktree +- optional `gsm8k` / `arc_challenge` +- a stricter pass/fail verdict for the PR + +--- + +### User: +yes please genius + +--- + +### ChatGPT: +I’m upgrading the harness so it’s actually audit-friendly: GPU telemetry during quantization, pinned git SHAs in the report, optional heavier eval tasks, and a real pass/fail rule instead of a vague summary. + +--- + +Replace these **3 files** with the versions below. + +Then run the same bootstrap, and use the new `run_full_suite.sh` flags. + +## scripts/quantize_once.py + +```Python#!/usr/bin/env python3import argparseimport jsonimport osimport platformimport randomimport subprocessimport sysimport timefrom pathlib import Pathimport numpy as npimport torchfrom datasets import load_datasetfrom gptqmodel import GPTQModel, QuantizeConfigdef seed_all(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)def safe_git_sha() -> str | None: try: out = subprocess.check_output( ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL, text=True, ).strip() return out except Exception: return Nonedef main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--model-id", required=True) ap.add_argument("--out-dir", required=True) ap.add_argument("--nsamples", type=int, default=128) ap.add_argument("--batch-size", type=int, default=4) ap.add_argument("--seed", type=int, default=1234) ap.add_argument("--bits", type=int, default=4) ap.add_argument("--group-size", type=int, default=128) ap.add_argument("--dataset", default="allenai/c4") ap.add_argument("--dataset-file", default="en/c4-train.00001-of-01024.json.gz") ap.add_argument("--tag", default=None) args = ap.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) seed_all(args.seed) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() ds = load_dataset( args.dataset, data_files=args.dataset_file, split="train", ).select(range(args.nsamples)) calibration_dataset = ds["text"] qcfg = QuantizeConfig.gptq_pro() qcfg.bits = args.bits qcfg.group_size = args.group_size start = time.perf_counter() model = GPTQModel.load(args.model_id, qcfg) model.quantize(calibration_dataset, batch_size=args.batch_size) elapsed = time.perf_counter() - start model.save(str(out_dir)) peak_cuda_mem_bytes = None cuda_name = None if torch.cuda.is_available(): peak_cuda_mem_bytes = int(torch.cuda.max_memory_allocated()) cuda_name = torch.cuda.get_device_name(0) metadata = { "tag": args.tag, "model_id": args.model_id, "out_dir": str(out_dir), "nsamples": args.nsamples, "batch_size": args.batch_size, "seed": args.seed, "bits": args.bits, "group_size": args.group_size, "elapsed_sec": elapsed, "peak_cuda_mem_bytes": peak_cuda_mem_bytes, "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0, "cuda_name": cuda_name, "torch_version": torch.__version__, "torch_cuda": torch.version.cuda, "python_version": sys.version, "platform": platform.platform(), "hostname": platform.node(), "cwd": os.getcwd(), "git_sha": safe_git_sha(), "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), } with open(out_dir / "ab_quant_meta.json", "w", encoding="utf-8") as f: json.dump(metadata, f, indent=2) print(json.dumps(metadata, indent=2)) return 0if __name__ == "__main__": raise SystemExit(main()) +Python#!/usr/bin/env python3import argparseimport jsonimport osimport platformimport randomimport subprocessimport sysimport timefrom pathlib import Pathimport numpy as npimport torchfrom datasets import load_datasetfrom gptqmodel import GPTQModel, QuantizeConfigdef seed_all(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)def safe_git_sha() -> str | None: try: out = subprocess.check_output( ["git", "rev-parse", "HEAD"], stderr=subprocess.DEVNULL, text=True, ).strip() return out except Exception: return Nonedef main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--model-id", required=True) ap.add_argument("--out-dir", required=True) ap.add_argument("--nsamples", type=int, default=128) ap.add_argument("--batch-size", type=int, default=4) ap.add_argument("--seed", type=int, default=1234) ap.add_argument("--bits", type=int, default=4) ap.add_argument("--group-size", type=int, default=128) ap.add_argument("--dataset", default="allenai/c4") ap.add_argument("--dataset-file", default="en/c4-train.00001-of-01024.json.gz") ap.add_argument("--tag", default=None) args = ap.parse_args() out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) seed_all(args.seed) if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() ds = load_dataset( args.dataset, data_files=args.dataset_file, split="train", ).select(range(args.nsamples)) calibration_dataset = ds["text"] qcfg = QuantizeConfig.gptq_pro() qcfg.bits = args.bits qcfg.group_size = args.group_size start = time.perf_counter() model = GPTQModel.load(args.model_id, qcfg) model.quantize(calibration_dataset, batch_size=args.batch_size) elapsed = time.perf_counter() - start model.save(str(out_dir)) peak_cuda_mem_bytes = None cuda_name = None if torch.cuda.is_available(): peak_cuda_mem_bytes = int(torch.cuda.max_memory_allocated()) cuda_name = torch.cuda.get_device_name(0) metadata = { "tag": args.tag, "model_id": args.model_id, "out_dir": str(out_dir), "nsamples": args.nsamples, "batch_size": args.batch_size, "seed": args.seed, "bits": args.bits, "group_size": args.group_size, "elapsed_sec": elapsed, "peak_cuda_mem_bytes": peak_cuda_mem_bytes, "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0, "cuda_name": cuda_name, "torch_version": torch.__version__, "torch_cuda": torch.version.cuda, "python_version": sys.version, "platform": platform.platform(), "hostname": platform.node(), "cwd": os.getcwd(), "git_sha": safe_git_sha(), "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), } with open(out_dir / "ab_quant_meta.json", "w", encoding="utf-8") as f: json.dump(metadata, f, indent=2) print(json.dumps(metadata, indent=2)) return 0if __name__ == "__main__": raise SystemExit(main()) +``` + +## scripts/run_full_suite.sh + +```Bash#!/usr/bin/env bashset -euo pipefailusage() { cat <&2; usage; exit 1 ;; esacdone[[ -n "$REPO_ROOT" ]] || { echo "--repo-root is required" >&2; exit 1; }[[ -n "$BASE_COMMIT" ]] || { echo "--base-commit is required" >&2; exit 1; }[[ -n "$NEW_COMMIT" ]] || { echo "--new-commit is required" >&2; exit 1; }REPO_ROOT="$(cd "$REPO_ROOT" && pwd)"BASE_WT="${REPO_ROOT}_ab_base"NEW_WT="${REPO_ROOT}_ab_new"SELF_DIR="$(cd "$(dirname "$0")" && pwd)"OUTPUT_ROOT="${OUTPUT_ROOT/#\~/$HOME}"mkdir -p "$OUTPUT_ROOT"if [[ ! -d "$BASE_WT/.venv" || ! -d "$NEW_WT/.venv" ]]; then echo "[error] missing virtualenvs. Run bootstrap_worktrees_and_envs.sh first." >&2 exit 1fiFULL_TASKS="$TASKS"if [[ -n "$EXTRA_TASKS" ]]; then FULL_TASKS="${FULL_TASKS},${EXTRA_TASKS}"fibase_sha="$(git -C "$BASE_WT" rev-parse HEAD)"new_sha="$(git -C "$NEW_WT" rev-parse HEAD)"write_meta() { cat > "$OUTPUT_ROOT/run_meta.json" </dev/null 2>&1; then lm-eval "$@" else python -m lm_eval "$@" fi}start_gpu_logger() { local logfile="$1" local pidfile="$2" if ! command -v nvidia-smi >/dev/null 2>&1; then echo "[warn] nvidia-smi not found, skipping GPU telemetry" | tee -a "$logfile" return 0 fi { echo "unix_ts,timestamp,index,name,util_gpu,util_mem,memory_used_mb,memory_total_mb,temp_c,power_w" while true; do ts="$(date +%s)" row="$(nvidia-smi \ --query-gpu=timestamp,index,name,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu,power.draw \ --format=csv,noheader,nounits \ -i "$GPU" 2>/dev/null || true)" if [[ -n "$row" ]]; then echo "${ts},${row}" fi sleep "$GPU_LOG_INTERVAL" done } >> "$logfile" & echo $! > "$pidfile"}stop_gpu_logger() { local pidfile="$1" if [[ -f "$pidfile" ]]; then pid="$(cat "$pidfile" || true)" if [[ -n "${pid:-}" ]] && kill -0 "$pid" 2>/dev/null; then kill "$pid" 2>/dev/null || true wait "$pid" 2>/dev/null || true fi rm -f "$pidfile" fi}run_one_side() { local side="$1" local wt="$2" local case_name="$3" local nsamples="$4" local case_root="$OUTPUT_ROOT/$case_name/$side" local model_out="$case_root/model" local eval_out="$case_root/eval" local gen_out="$case_root/generations.json" local gpu_log="$case_root/gpu_telemetry.csv" local gpu_pid="$case_root/gpu_telemetry.pid" mkdir -p "$case_root" "$eval_out" echo "[info] running $side $case_name in $wt" pushd "$wt" >/dev/null source .venv/bin/activate export CUDA_VISIBLE_DEVICES="$GPU" export TOKENIZERS_PARALLELISM=false start_gpu_logger "$gpu_log" "$gpu_pid" python "$SELF_DIR/quantize_once.py" \ --model-id "$MODEL_ID" \ --out-dir "$model_out" \ --nsamples "$nsamples" \ --batch-size "$BATCH_SIZE" \ --seed "$SEED" \ --bits "$BITS" \ --group-size "$GROUP_SIZE" \ --tag "${side}_${case_name}" | tee "$case_root/quantize.log" stop_gpu_logger "$gpu_pid" run_lm_eval \ --model hf \ --model_args "pretrained=$model_out,trust_remote_code=True" \ --tasks "$FULL_TASKS" \ --device cuda:0 \ --batch_size auto:4 \ --output_path "$eval_out" \ --seed "$SEED" \ --trust_remote_code | tee "$case_root/lm_eval.log" python "$SELF_DIR/gen_compare.py" \ --model-path "$model_out" \ --out-json "$gen_out" | tee "$case_root/generation.log" deactivate popd >/dev/null}write_metarun_one_side base "$BASE_WT" ns128 "$SMALL_NSAMPLES"run_one_side new "$NEW_WT" ns128 "$SMALL_NSAMPLES"run_one_side base "$BASE_WT" ns512 "$LARGE_NSAMPLES"run_one_side new "$NEW_WT" ns512 "$LARGE_NSAMPLES"python "$SELF_DIR/summarize_ab.py" \ --output-root "$OUTPUT_ROOT" \ --strict-max-slowdown-pct "$STRICT_MAX_SLOWDOWN_PCT" \ --strict-max-mem-increase-pct "$STRICT_MAX_MEM_INCREASE_PCT" \ --strict-allow-task-regressions "$STRICT_ALLOW_TASK_REGRESSIONS" \ --strict-require-ns128-win "$STRICT_REQUIRE_NS128_WIN" | tee "$OUTPUT_ROOT/summary_paths.txt"echo "[done] summary: $OUTPUT_ROOT/summary.md" +Bash#!/usr/bin/env bashset -euo pipefailusage() { cat <&2; usage; exit 1 ;; esacdone[[ -n "$REPO_ROOT" ]] || { echo "--repo-root is required" >&2; exit 1; }[[ -n "$BASE_COMMIT" ]] || { echo "--base-commit is required" >&2; exit 1; }[[ -n "$NEW_COMMIT" ]] || { echo "--new-commit is required" >&2; exit 1; }REPO_ROOT="$(cd "$REPO_ROOT" && pwd)"BASE_WT="${REPO_ROOT}_ab_base"NEW_WT="${REPO_ROOT}_ab_new"SELF_DIR="$(cd "$(dirname "$0")" && pwd)"OUTPUT_ROOT="${OUTPUT_ROOT/#\~/$HOME}"mkdir -p "$OUTPUT_ROOT"if [[ ! -d "$BASE_WT/.venv" || ! -d "$NEW_WT/.venv" ]]; then echo "[error] missing virtualenvs. Run bootstrap_worktrees_and_envs.sh first." >&2 exit 1fiFULL_TASKS="$TASKS"if [[ -n "$EXTRA_TASKS" ]]; then FULL_TASKS="${FULL_TASKS},${EXTRA_TASKS}"fibase_sha="$(git -C "$BASE_WT" rev-parse HEAD)"new_sha="$(git -C "$NEW_WT" rev-parse HEAD)"write_meta() { cat > "$OUTPUT_ROOT/run_meta.json" </dev/null 2>&1; then lm-eval "$@" else python -m lm_eval "$@" fi}start_gpu_logger() { local logfile="$1" local pidfile="$2" if ! command -v nvidia-smi >/dev/null 2>&1; then echo "[warn] nvidia-smi not found, skipping GPU telemetry" | tee -a "$logfile" return 0 fi { echo "unix_ts,timestamp,index,name,util_gpu,util_mem,memory_used_mb,memory_total_mb,temp_c,power_w" while true; do ts="$(date +%s)" row="$(nvidia-smi \ --query-gpu=timestamp,index,name,utilization.gpu,utilization.memory,memory.used,memory.total,temperature.gpu,power.draw \ --format=csv,noheader,nounits \ -i "$GPU" 2>/dev/null || true)" if [[ -n "$row" ]]; then echo "${ts},${row}" fi sleep "$GPU_LOG_INTERVAL" done } >> "$logfile" & echo $! > "$pidfile"}stop_gpu_logger() { local pidfile="$1" if [[ -f "$pidfile" ]]; then pid="$(cat "$pidfile" || true)" if [[ -n "${pid:-}" ]] && kill -0 "$pid" 2>/dev/null; then kill "$pid" 2>/dev/null || true wait "$pid" 2>/dev/null || true fi rm -f "$pidfile" fi}run_one_side() { local side="$1" local wt="$2" local case_name="$3" local nsamples="$4" local case_root="$OUTPUT_ROOT/$case_name/$side" local model_out="$case_root/model" local eval_out="$case_root/eval" local gen_out="$case_root/generations.json" local gpu_log="$case_root/gpu_telemetry.csv" local gpu_pid="$case_root/gpu_telemetry.pid" mkdir -p "$case_root" "$eval_out" echo "[info] running $side $case_name in $wt" pushd "$wt" >/dev/null source .venv/bin/activate export CUDA_VISIBLE_DEVICES="$GPU" export TOKENIZERS_PARALLELISM=false start_gpu_logger "$gpu_log" "$gpu_pid" python "$SELF_DIR/quantize_once.py" \ --model-id "$MODEL_ID" \ --out-dir "$model_out" \ --nsamples "$nsamples" \ --batch-size "$BATCH_SIZE" \ --seed "$SEED" \ --bits "$BITS" \ --group-size "$GROUP_SIZE" \ --tag "${side}_${case_name}" | tee "$case_root/quantize.log" stop_gpu_logger "$gpu_pid" run_lm_eval \ --model hf \ --model_args "pretrained=$model_out,trust_remote_code=True" \ --tasks "$FULL_TASKS" \ --device cuda:0 \ --batch_size auto:4 \ --output_path "$eval_out" \ --seed "$SEED" \ --trust_remote_code | tee "$case_root/lm_eval.log" python "$SELF_DIR/gen_compare.py" \ --model-path "$model_out" \ --out-json "$gen_out" | tee "$case_root/generation.log" deactivate popd >/dev/null}write_metarun_one_side base "$BASE_WT" ns128 "$SMALL_NSAMPLES"run_one_side new "$NEW_WT" ns128 "$SMALL_NSAMPLES"run_one_side base "$BASE_WT" ns512 "$LARGE_NSAMPLES"run_one_side new "$NEW_WT" ns512 "$LARGE_NSAMPLES"python "$SELF_DIR/summarize_ab.py" \ --output-root "$OUTPUT_ROOT" \ --strict-max-slowdown-pct "$STRICT_MAX_SLOWDOWN_PCT" \ --strict-max-mem-increase-pct "$STRICT_MAX_MEM_INCREASE_PCT" \ --strict-allow-task-regressions "$STRICT_ALLOW_TASK_REGRESSIONS" \ --strict-require-ns128-win "$STRICT_REQUIRE_NS128_WIN" | tee "$OUTPUT_ROOT/summary_paths.txt"echo "[done] summary: $OUTPUT_ROOT/summary.md" +``` + +## scripts/summarize_ab.py + +```Python#!/usr/bin/env python3import argparseimport difflibimport jsonfrom pathlib import Pathfrom typing import Dict, List, Optional, TupleHIGHER_IS_BETTER = {"acc", "acc_norm", "exact_match", "mc1", "mc2", "f1", "bleu"}LOWER_IS_BETTER = {"word_perplexity", "perplexity", "byte_perplexity", "bpb", "bpc"}def load_json(path: Path): with open(path, "r", encoding="utf-8") as f: return json.load(f)def find_lm_eval_json(eval_dir: Path) -> Optional[Path]: candidates = sorted(eval_dir.rglob("*.json")) for p in candidates: try: data = load_json(p) except Exception: continue if isinstance(data, dict) and "results" in data: return p return Nonedef extract_task_metrics(eval_json: Dict) -> Dict[str, Dict[str, float]]: results = eval_json.get("results", {}) out: Dict[str, Dict[str, float]] = {} for task, task_metrics in results.items(): if not isinstance(task_metrics, dict): continue flat: Dict[str, float] = {} for k, v in task_metrics.items(): if isinstance(v, (int, float)): flat[k] = float(v) out[task] = flat return outdef choose_primary_metric(task: str, metrics: Dict[str, float]) -> Tuple[Optional[str], Optional[float], Optional[str]]: preferred = [ "acc_norm,none", "acc,none", "exact_match,strict-match", "exact_match,none", "mc2,none", "mc1,none", "word_perplexity,none", "perplexity,none", "byte_perplexity,none", "bpb,none", "bpc,none", "acc_norm", "acc", "exact_match", "mc2", "mc1", "word_perplexity", "perplexity", "byte_perplexity", "bpb", "bpc", ] for key in preferred: if key in metrics: short = key.split(",", 1)[0] direction = "higher" if short in HIGHER_IS_BETTER else "lower" if short in LOWER_IS_BETTER else "unknown" return key, metrics[key], direction return None, None, Nonedef compare_generations(base_path: Path, new_path: Path) -> Dict: base_rows = load_json(base_path) new_rows = load_json(new_path) exact = 0 ratios: List[float] = [] changed_examples = [] for b, n in zip(base_rows, new_rows): br = b.get("output", "") nr = n.get("output", "") if br == nr: exact += 1 ratio = difflib.SequenceMatcher(None, br, nr).ratio() ratios.append(ratio) if br != nr and len(changed_examples) < 3: changed_examples.append({ "prompt": b.get("prompt", ""), "base_output": br[:900], "new_output": nr[:900], "similarity": ratio, }) return { "num_prompts": len(base_rows), "exact_match_count": exact, "exact_match_rate": exact / max(1, len(base_rows)), "avg_similarity": sum(ratios) / max(1, len(ratios)), "changed_examples": changed_examples, }def load_quant_meta(path: Path) -> Dict: return load_json(path)def pct_change(new: Optional[float], base: Optional[float]) -> Optional[float]: if new is None or base is None: return None if base == 0: return None return ((new - base) / base) * 100.0def build_case_summary(case_dir: Path) -> Dict: base_quant = load_quant_meta(case_dir / "base" / "model" / "ab_quant_meta.json") new_quant = load_quant_meta(case_dir / "new" / "model" / "ab_quant_meta.json") base_eval_json_path = find_lm_eval_json(case_dir / "base" / "eval") new_eval_json_path = find_lm_eval_json(case_dir / "new" / "eval") if base_eval_json_path is None or new_eval_json_path is None: raise FileNotFoundError(f"Could not find lm-eval JSON under {case_dir}") base_eval = load_json(base_eval_json_path) new_eval = load_json(new_eval_json_path) base_metrics = extract_task_metrics(base_eval) new_metrics = extract_task_metrics(new_eval) task_summaries = [] all_tasks = sorted(set(base_metrics) | set(new_metrics)) new_wins = 0 base_wins = 0 ties = 0 for task in all_tasks: b = base_metrics.get(task, {}) n = new_metrics.get(task, {}) metric_name, bval, direction = choose_primary_metric(task, b) metric_name2, nval, direction2 = choose_primary_metric(task, n) chosen = metric_name or metric_name2 direction = direction or direction2 or "unknown" if chosen is not None: bval = b.get(chosen, bval) nval = n.get(chosen, nval) delta = None winner = "tie" if bval is not None and nval is not None: delta = nval - bval if direction == "higher": winner = "new" if nval > bval else "base" if bval > nval else "tie" elif direction == "lower": winner = "new" if nval < bval else "base" if bval < nval else "tie" if winner == "new": new_wins += 1 elif winner == "base": base_wins += 1 else: ties += 1 task_summaries.append({ "task": task, "metric": chosen, "direction": direction, "base": bval, "new": nval, "delta_new_minus_base": delta, "winner": winner, }) gen = compare_generations(case_dir / "base" / "generations.json", case_dir / "new" / "generations.json") quant_delta = { "elapsed_sec_new_minus_base": new_quant["elapsed_sec"] - base_quant["elapsed_sec"], "elapsed_pct_new_minus_base": pct_change(new_quant["elapsed_sec"], base_quant["elapsed_sec"]), "peak_cuda_mem_bytes_new_minus_base": ( (new_quant.get("peak_cuda_mem_bytes") or 0) - (base_quant.get("peak_cuda_mem_bytes") or 0) ), "peak_cuda_mem_pct_new_minus_base": pct_change( (new_quant.get("peak_cuda_mem_bytes") or 0), (base_quant.get("peak_cuda_mem_bytes") or 0), ), } return { "case": case_dir.name, "base_quant": base_quant, "new_quant": new_quant, "quant_delta": quant_delta, "base_eval_json": str(base_eval_json_path), "new_eval_json": str(new_eval_json_path), "tasks": task_summaries, "task_win_counts": { "new": new_wins, "base": base_wins, "tie": ties, }, "generation_compare": gen, "gpu_logs": { "base": str(case_dir / "base" / "gpu_telemetry.csv"), "new": str(case_dir / "new" / "gpu_telemetry.csv"), }, }def parse_bool(s: str) -> bool: return str(s).strip().lower() in {"1", "true", "yes", "y", "on"}def strict_case_pass( case: Dict, max_slowdown_pct: float, max_mem_increase_pct: float, allow_task_regressions: int, require_win: bool,) -> Tuple[bool, List[str]]: reasons = [] slowdown_pct = case["quant_delta"].get("elapsed_pct_new_minus_base") if slowdown_pct is not None and slowdown_pct > max_slowdown_pct: reasons.append(f"slowdown {slowdown_pct:.2f}% > allowed {max_slowdown_pct:.2f}%") mem_increase_pct = case["quant_delta"].get("peak_cuda_mem_pct_new_minus_base") if mem_increase_pct is not None and mem_increase_pct > max_mem_increase_pct: reasons.append(f"peak mem increase {mem_increase_pct:.2f}% > allowed {max_mem_increase_pct:.2f}%") regressions = case["task_win_counts"]["base"] if regressions > allow_task_regressions: reasons.append(f"{regressions} task regressions > allowed {allow_task_regressions}") if require_win and case["task_win_counts"]["new"] < 1: reasons.append("no task wins for new build") return (len(reasons) == 0, reasons)def strict_overall_verdict( cases: List[Dict], max_slowdown_pct: float, max_mem_increase_pct: float, allow_task_regressions: int, require_ns128_win: bool,) -> Dict: per_case = {} overall_pass = True for case in cases: require_win = require_ns128_win if case["case"] == "ns128" else False ok, reasons = strict_case_pass( case, max_slowdown_pct=max_slowdown_pct, max_mem_increase_pct=max_mem_increase_pct, allow_task_regressions=allow_task_regressions, require_win=require_win, ) per_case[case["case"]] = { "pass": ok, "reasons": reasons, } if not ok: overall_pass = False return { "pass": overall_pass, "per_case": per_case, }def overall_soft_verdict(case_summaries: List[Dict]) -> str: new_wins = 0 base_wins = 0 for case in case_summaries: new_wins += case["task_win_counts"]["new"] base_wins += case["task_win_counts"]["base"] if new_wins > base_wins: return "new looks better overall" if base_wins > new_wins: return "base looks better overall" return "mixed / inconclusive"def render_md(case_summaries: List[Dict], summary_path: Path, strict_result: Dict, run_meta: Dict) -> None: lines: List[str] = [] lines.append("# GPTQ-Pro A/B summary") lines.append("") lines.append(f"Soft verdict: **{overall_soft_verdict(case_summaries)}**") lines.append(f"Strict verdict: **{'PASS' if strict_result['pass'] else 'FAIL'}**") lines.append("") lines.append("## Run metadata") lines.append("") lines.append(f"- base git sha: `{run_meta.get('base_git_sha')}`") lines.append(f"- new git sha: `{run_meta.get('new_git_sha')}`") lines.append(f"- model: `{run_meta.get('model_id')}`") lines.append(f"- tasks: `{run_meta.get('full_tasks')}`") lines.append(f"- GPU: `{run_meta.get('gpu')}`") lines.append("") lines.append("## Strict criteria") lines.append("") lines.append(f"- max slowdown pct: {run_meta.get('strict_max_slowdown_pct')}") lines.append(f"- max peak mem increase pct: {run_meta.get('strict_max_mem_increase_pct')}") lines.append(f"- allowed task regressions per case: {run_meta.get('strict_allow_task_regressions')}") lines.append(f"- require ns128 to have at least one win: {run_meta.get('strict_require_ns128_win')}") lines.append("") for case in case_summaries: strict_case = strict_result["per_case"][case["case"]] lines.append(f"## {case['case']}") lines.append("") lines.append(f"- strict result: **{'PASS' if strict_case['pass'] else 'FAIL'}**") if strict_case["reasons"]: for r in strict_case["reasons"]: lines.append(f"- reason: {r}") lines.append("") lines.append("### Quantization") lines.append("") lines.append(f"- base elapsed: {case['base_quant']['elapsed_sec']:.2f}s") lines.append(f"- new elapsed: {case['new_quant']['elapsed_sec']:.2f}s") lines.append(f"- delta sec (new - base): {case['quant_delta']['elapsed_sec_new_minus_base']:.2f}s") lines.append(f"- delta pct (new - base): {case['quant_delta']['elapsed_pct_new_minus_base']}") lines.append(f"- base peak CUDA mem: {case['base_quant'].get('peak_cuda_mem_bytes')}") lines.append(f"- new peak CUDA mem: {case['new_quant'].get('peak_cuda_mem_bytes')}") lines.append(f"- delta peak mem bytes: {case['quant_delta']['peak_cuda_mem_bytes_new_minus_base']}") lines.append(f"- delta peak mem pct: {case['quant_delta']['peak_cuda_mem_pct_new_minus_base']}") lines.append(f"- base GPU log: `{case['gpu_logs']['base']}`") lines.append(f"- new GPU log: `{case['gpu_logs']['new']}`") lines.append("") lines.append("### Eval") lines.append("") lines.append(f"- task wins -> new: {case['task_win_counts']['new']}, base: {case['task_win_counts']['base']}, tie: {case['task_win_counts']['tie']}") lines.append("") lines.append("| task | metric | direction | base | new | delta (new-base) | winner |") lines.append("|---|---|---:|---:|---:|---:|---|") for t in case["tasks"]: lines.append( f"| {t['task']} | {t['metric']} | {t['direction']} | {t['base']} | {t['new']} | {t['delta_new_minus_base']} | {t['winner']} |" ) lines.append("") gc = case["generation_compare"] lines.append("### Deterministic generations") lines.append("") lines.append(f"- exact matches: {gc['exact_match_count']} / {gc['num_prompts']}") lines.append(f"- exact match rate: {gc['exact_match_rate']:.3f}") lines.append(f"- average similarity: {gc['avg_similarity']:.3f}") if gc["changed_examples"]: lines.append("") lines.append("#### Sample changed outputs") lines.append("") for ex in gc["changed_examples"]: lines.append(f"- Prompt: `{ex['prompt']}`") lines.append(f" - similarity: {ex['similarity']:.3f}") lines.append(f" - base: {ex['base_output'][:220].replace(chr(10), ' ')}") lines.append(f" - new: {ex['new_output'][:220].replace(chr(10), ' ')}") lines.append("") summary_path.write_text("\n".join(lines), encoding="utf-8")def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--output-root", required=True) ap.add_argument("--strict-max-slowdown-pct", type=float, default=25.0) ap.add_argument("--strict-max-mem-increase-pct", type=float, default=20.0) ap.add_argument("--strict-allow-task-regressions", type=int, default=0) ap.add_argument("--strict-require-ns128-win", default="true") args = ap.parse_args() root = Path(args.output_root) run_meta_path = root / "run_meta.json" run_meta = load_json(run_meta_path) if run_meta_path.exists() else {} cases = [] for name in ("ns128", "ns512"): case_dir = root / name if case_dir.exists(): cases.append(build_case_summary(case_dir)) if not cases: raise FileNotFoundError(f"No case directories found under {root}") strict_result = strict_overall_verdict( cases=cases, max_slowdown_pct=args.strict_max_slowdown_pct, max_mem_increase_pct=args.strict_max_mem_increase_pct, allow_task_regressions=args.strict_allow_task_regressions, require_ns128_win=parse_bool(args.strict_require_ns128_win), ) payload = { "soft_verdict": overall_soft_verdict(cases), "strict_verdict": strict_result, "cases": cases, "run_meta": run_meta, } summary_json = root / "summary.json" summary_md = root / "summary.md" summary_json.write_text(json.dumps(payload, indent=2), encoding="utf-8") render_md(cases, summary_md, strict_result, run_meta) print(summary_md) print(summary_json) return 0if __name__ == "__main__": raise SystemExit(main()) +Python#!/usr/bin/env python3import argparseimport difflibimport jsonfrom pathlib import Pathfrom typing import Dict, List, Optional, TupleHIGHER_IS_BETTER = {"acc", "acc_norm", "exact_match", "mc1", "mc2", "f1", "bleu"}LOWER_IS_BETTER = {"word_perplexity", "perplexity", "byte_perplexity", "bpb", "bpc"}def load_json(path: Path): with open(path, "r", encoding="utf-8") as f: return json.load(f)def find_lm_eval_json(eval_dir: Path) -> Optional[Path]: candidates = sorted(eval_dir.rglob("*.json")) for p in candidates: try: data = load_json(p) except Exception: continue if isinstance(data, dict) and "results" in data: return p return Nonedef extract_task_metrics(eval_json: Dict) -> Dict[str, Dict[str, float]]: results = eval_json.get("results", {}) out: Dict[str, Dict[str, float]] = {} for task, task_metrics in results.items(): if not isinstance(task_metrics, dict): continue flat: Dict[str, float] = {} for k, v in task_metrics.items(): if isinstance(v, (int, float)): flat[k] = float(v) out[task] = flat return outdef choose_primary_metric(task: str, metrics: Dict[str, float]) -> Tuple[Optional[str], Optional[float], Optional[str]]: preferred = [ "acc_norm,none", "acc,none", "exact_match,strict-match", "exact_match,none", "mc2,none", "mc1,none", "word_perplexity,none", "perplexity,none", "byte_perplexity,none", "bpb,none", "bpc,none", "acc_norm", "acc", "exact_match", "mc2", "mc1", "word_perplexity", "perplexity", "byte_perplexity", "bpb", "bpc", ] for key in preferred: if key in metrics: short = key.split(",", 1)[0] direction = "higher" if short in HIGHER_IS_BETTER else "lower" if short in LOWER_IS_BETTER else "unknown" return key, metrics[key], direction return None, None, Nonedef compare_generations(base_path: Path, new_path: Path) -> Dict: base_rows = load_json(base_path) new_rows = load_json(new_path) exact = 0 ratios: List[float] = [] changed_examples = [] for b, n in zip(base_rows, new_rows): br = b.get("output", "") nr = n.get("output", "") if br == nr: exact += 1 ratio = difflib.SequenceMatcher(None, br, nr).ratio() ratios.append(ratio) if br != nr and len(changed_examples) < 3: changed_examples.append({ "prompt": b.get("prompt", ""), "base_output": br[:900], "new_output": nr[:900], "similarity": ratio, }) return { "num_prompts": len(base_rows), "exact_match_count": exact, "exact_match_rate": exact / max(1, len(base_rows)), "avg_similarity": sum(ratios) / max(1, len(ratios)), "changed_examples": changed_examples, }def load_quant_meta(path: Path) -> Dict: return load_json(path)def pct_change(new: Optional[float], base: Optional[float]) -> Optional[float]: if new is None or base is None: return None if base == 0: return None return ((new - base) / base) * 100.0def build_case_summary(case_dir: Path) -> Dict: base_quant = load_quant_meta(case_dir / "base" / "model" / "ab_quant_meta.json") new_quant = load_quant_meta(case_dir / "new" / "model" / "ab_quant_meta.json") base_eval_json_path = find_lm_eval_json(case_dir / "base" / "eval") new_eval_json_path = find_lm_eval_json(case_dir / "new" / "eval") if base_eval_json_path is None or new_eval_json_path is None: raise FileNotFoundError(f"Could not find lm-eval JSON under {case_dir}") base_eval = load_json(base_eval_json_path) new_eval = load_json(new_eval_json_path) base_metrics = extract_task_metrics(base_eval) new_metrics = extract_task_metrics(new_eval) task_summaries = [] all_tasks = sorted(set(base_metrics) | set(new_metrics)) new_wins = 0 base_wins = 0 ties = 0 for task in all_tasks: b = base_metrics.get(task, {}) n = new_metrics.get(task, {}) metric_name, bval, direction = choose_primary_metric(task, b) metric_name2, nval, direction2 = choose_primary_metric(task, n) chosen = metric_name or metric_name2 direction = direction or direction2 or "unknown" if chosen is not None: bval = b.get(chosen, bval) nval = n.get(chosen, nval) delta = None winner = "tie" if bval is not None and nval is not None: delta = nval - bval if direction == "higher": winner = "new" if nval > bval else "base" if bval > nval else "tie" elif direction == "lower": winner = "new" if nval < bval else "base" if bval < nval else "tie" if winner == "new": new_wins += 1 elif winner == "base": base_wins += 1 else: ties += 1 task_summaries.append({ "task": task, "metric": chosen, "direction": direction, "base": bval, "new": nval, "delta_new_minus_base": delta, "winner": winner, }) gen = compare_generations(case_dir / "base" / "generations.json", case_dir / "new" / "generations.json") quant_delta = { "elapsed_sec_new_minus_base": new_quant["elapsed_sec"] - base_quant["elapsed_sec"], "elapsed_pct_new_minus_base": pct_change(new_quant["elapsed_sec"], base_quant["elapsed_sec"]), "peak_cuda_mem_bytes_new_minus_base": ( (new_quant.get("peak_cuda_mem_bytes") or 0) - (base_quant.get("peak_cuda_mem_bytes") or 0) ), "peak_cuda_mem_pct_new_minus_base": pct_change( (new_quant.get("peak_cuda_mem_bytes") or 0), (base_quant.get("peak_cuda_mem_bytes") or 0), ), } return { "case": case_dir.name, "base_quant": base_quant, "new_quant": new_quant, "quant_delta": quant_delta, "base_eval_json": str(base_eval_json_path), "new_eval_json": str(new_eval_json_path), "tasks": task_summaries, "task_win_counts": { "new": new_wins, "base": base_wins, "tie": ties, }, "generation_compare": gen, "gpu_logs": { "base": str(case_dir / "base" / "gpu_telemetry.csv"), "new": str(case_dir / "new" / "gpu_telemetry.csv"), }, }def parse_bool(s: str) -> bool: return str(s).strip().lower() in {"1", "true", "yes", "y", "on"}def strict_case_pass( case: Dict, max_slowdown_pct: float, max_mem_increase_pct: float, allow_task_regressions: int, require_win: bool,) -> Tuple[bool, List[str]]: reasons = [] slowdown_pct = case["quant_delta"].get("elapsed_pct_new_minus_base") if slowdown_pct is not None and slowdown_pct > max_slowdown_pct: reasons.append(f"slowdown {slowdown_pct:.2f}% > allowed {max_slowdown_pct:.2f}%") mem_increase_pct = case["quant_delta"].get("peak_cuda_mem_pct_new_minus_base") if mem_increase_pct is not None and mem_increase_pct > max_mem_increase_pct: reasons.append(f"peak mem increase {mem_increase_pct:.2f}% > allowed {max_mem_increase_pct:.2f}%") regressions = case["task_win_counts"]["base"] if regressions > allow_task_regressions: reasons.append(f"{regressions} task regressions > allowed {allow_task_regressions}") if require_win and case["task_win_counts"]["new"] < 1: reasons.append("no task wins for new build") return (len(reasons) == 0, reasons)def strict_overall_verdict( cases: List[Dict], max_slowdown_pct: float, max_mem_increase_pct: float, allow_task_regressions: int, require_ns128_win: bool,) -> Dict: per_case = {} overall_pass = True for case in cases: require_win = require_ns128_win if case["case"] == "ns128" else False ok, reasons = strict_case_pass( case, max_slowdown_pct=max_slowdown_pct, max_mem_increase_pct=max_mem_increase_pct, allow_task_regressions=allow_task_regressions, require_win=require_win, ) per_case[case["case"]] = { "pass": ok, "reasons": reasons, } if not ok: overall_pass = False return { "pass": overall_pass, "per_case": per_case, }def overall_soft_verdict(case_summaries: List[Dict]) -> str: new_wins = 0 base_wins = 0 for case in case_summaries: new_wins += case["task_win_counts"]["new"] base_wins += case["task_win_counts"]["base"] if new_wins > base_wins: return "new looks better overall" if base_wins > new_wins: return "base looks better overall" return "mixed / inconclusive"def render_md(case_summaries: List[Dict], summary_path: Path, strict_result: Dict, run_meta: Dict) -> None: lines: List[str] = [] lines.append("# GPTQ-Pro A/B summary") lines.append("") lines.append(f"Soft verdict: **{overall_soft_verdict(case_summaries)}**") lines.append(f"Strict verdict: **{'PASS' if strict_result['pass'] else 'FAIL'}**") lines.append("") lines.append("## Run metadata") lines.append("") lines.append(f"- base git sha: `{run_meta.get('base_git_sha')}`") lines.append(f"- new git sha: `{run_meta.get('new_git_sha')}`") lines.append(f"- model: `{run_meta.get('model_id')}`") lines.append(f"- tasks: `{run_meta.get('full_tasks')}`") lines.append(f"- GPU: `{run_meta.get('gpu')}`") lines.append("") lines.append("## Strict criteria") lines.append("") lines.append(f"- max slowdown pct: {run_meta.get('strict_max_slowdown_pct')}") lines.append(f"- max peak mem increase pct: {run_meta.get('strict_max_mem_increase_pct')}") lines.append(f"- allowed task regressions per case: {run_meta.get('strict_allow_task_regressions')}") lines.append(f"- require ns128 to have at least one win: {run_meta.get('strict_require_ns128_win')}") lines.append("") for case in case_summaries: strict_case = strict_result["per_case"][case["case"]] lines.append(f"## {case['case']}") lines.append("") lines.append(f"- strict result: **{'PASS' if strict_case['pass'] else 'FAIL'}**") if strict_case["reasons"]: for r in strict_case["reasons"]: lines.append(f"- reason: {r}") lines.append("") lines.append("### Quantization") lines.append("") lines.append(f"- base elapsed: {case['base_quant']['elapsed_sec']:.2f}s") lines.append(f"- new elapsed: {case['new_quant']['elapsed_sec']:.2f}s") lines.append(f"- delta sec (new - base): {case['quant_delta']['elapsed_sec_new_minus_base']:.2f}s") lines.append(f"- delta pct (new - base): {case['quant_delta']['elapsed_pct_new_minus_base']}") lines.append(f"- base peak CUDA mem: {case['base_quant'].get('peak_cuda_mem_bytes')}") lines.append(f"- new peak CUDA mem: {case['new_quant'].get('peak_cuda_mem_bytes')}") lines.append(f"- delta peak mem bytes: {case['quant_delta']['peak_cuda_mem_bytes_new_minus_base']}") lines.append(f"- delta peak mem pct: {case['quant_delta']['peak_cuda_mem_pct_new_minus_base']}") lines.append(f"- base GPU log: `{case['gpu_logs']['base']}`") lines.append(f"- new GPU log: `{case['gpu_logs']['new']}`") lines.append("") lines.append("### Eval") lines.append("") lines.append(f"- task wins -> new: {case['task_win_counts']['new']}, base: {case['task_win_counts']['base']}, tie: {case['task_win_counts']['tie']}") lines.append("") lines.append("| task | metric | direction | base | new | delta (new-base) | winner |") lines.append("|---|---|---:|---:|---:|---:|---|") for t in case["tasks"]: lines.append( f"| {t['task']} | {t['metric']} | {t['direction']} | {t['base']} | {t['new']} | {t['delta_new_minus_base']} | {t['winner']} |" ) lines.append("") gc = case["generation_compare"] lines.append("### Deterministic generations") lines.append("") lines.append(f"- exact matches: {gc['exact_match_count']} / {gc['num_prompts']}") lines.append(f"- exact match rate: {gc['exact_match_rate']:.3f}") lines.append(f"- average similarity: {gc['avg_similarity']:.3f}") if gc["changed_examples"]: lines.append("") lines.append("#### Sample changed outputs") lines.append("") for ex in gc["changed_examples"]: lines.append(f"- Prompt: `{ex['prompt']}`") lines.append(f" - similarity: {ex['similarity']:.3f}") lines.append(f" - base: {ex['base_output'][:220].replace(chr(10), ' ')}") lines.append(f" - new: {ex['new_output'][:220].replace(chr(10), ' ')}") lines.append("") summary_path.write_text("\n".join(lines), encoding="utf-8")def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--output-root", required=True) ap.add_argument("--strict-max-slowdown-pct", type=float, default=25.0) ap.add_argument("--strict-max-mem-increase-pct", type=float, default=20.0) ap.add_argument("--strict-allow-task-regressions", type=int, default=0) ap.add_argument("--strict-require-ns128-win", default="true") args = ap.parse_args() root = Path(args.output_root) run_meta_path = root / "run_meta.json" run_meta = load_json(run_meta_path) if run_meta_path.exists() else {} cases = [] for name in ("ns128", "ns512"): case_dir = root / name if case_dir.exists(): cases.append(build_case_summary(case_dir)) if not cases: raise FileNotFoundError(f"No case directories found under {root}") strict_result = strict_overall_verdict( cases=cases, max_slowdown_pct=args.strict_max_slowdown_pct, max_mem_increase_pct=args.strict_max_mem_increase_pct, allow_task_regressions=args.strict_allow_task_regressions, require_ns128_win=parse_bool(args.strict_require_ns128_win), ) payload = { "soft_verdict": overall_soft_verdict(cases), "strict_verdict": strict_result, "cases": cases, "run_meta": run_meta, } summary_json = root / "summary.json" summary_md = root / "summary.md" summary_json.write_text(json.dumps(payload, indent=2), encoding="utf-8") render_md(cases, summary_md, strict_result, run_meta) print(summary_md) print(summary_json) return 0if __name__ == "__main__": raise SystemExit(main()) +``` + +## Run it + +Your bootstrap command stays the same: + +```Bashcd gptq_pro_ab_suite./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128' +Bashcd gptq_pro_ab_suite./scripts/bootstrap_worktrees_and_envs.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --torch-install-cmd 'pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128' +``` + +Then run the stricter suite like this: + +```Bash./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --extra-tasks 'arc_challenge,gsm8k' \ --output-root ~/gptq_ab_results \ --gpu-log-interval 2 \ --strict-max-slowdown-pct 25 \ --strict-max-mem-increase-pct 20 \ --strict-allow-task-regressions 0 \ --strict-require-ns128-win true +Bash./scripts/run_full_suite.sh \ --repo-root ~/GPTQ-Pro \ --base-commit "3b6de40^" \ --new-commit "e6dee44" \ --gpu 0 \ --model-id Qwen/Qwen3.5-4B-Base \ --tasks 'wikitext,hellaswag,winogrande,piqa' \ --extra-tasks 'arc_challenge,gsm8k' \ --output-root ~/gptq_ab_results \ --gpu-log-interval 2 \ --strict-max-slowdown-pct 25 \ --strict-max-mem-increase-pct 20 \ --strict-allow-task-regressions 0 \ --strict-require-ns128-win true +``` + +## What this adds + +You asked for 4 upgrades. This version gives you all 4: + +- `nvidia-smi` snapshots during **quantization** into: + +`~/gptq_ab_results/ns128/base/gpu_telemetry.csv` + +`~/gptq_ab_results/ns128/new/gpu_telemetry.csv` + +and same for `ns512` +- `~/gptq_ab_results/ns128/base/gpu_telemetry.csv` +- `~/gptq_ab_results/ns128/new/gpu_telemetry.csv` +- and same for `ns512` +- exact git SHAs from each worktree in: + +`~/gptq_ab_results/run_meta.json` + +`~/gptq_ab_results/summary.md` +- `~/gptq_ab_results/run_meta.json` +- `~/gptq_ab_results/summary.md` +- optional heavier tasks: + +`--extra-tasks 'arc_challenge,gsm8k'` +- `--extra-tasks 'arc_challenge,gsm8k'` +- strict pass/fail verdict: + +fail if too slow + +fail if peak memory grows too much + +fail if task regressions exceed threshold + +optionally require at least one win in the **ns128** stress case +- fail if too slow +- fail if peak memory grows too much +- fail if task regressions exceed threshold +- optionally require at least one win in the **ns128** stress case + +## What I’d use as the real decision rule + +For your PR, I’d judge it like this: + +- **PASS** if `ns128` gets at least one meaningful win and no regressions, while slowdown stays under 25% +- **PASS** if `ns512` is neutral +- **FAIL** if `ns128` is flat or worse and quantization gets slower for no gain + +Because this PR is supposed to help the **low-sample fallback path**, `ns128` matters much more than `ns512`. + +Send me your `summary.md` or `summary.json` after the run and I’ll interpret whether the merged default was actually justified. + +--- + +### User: +how can we add an external judge llm to classify the level of quant that should be applied to each lay4er? + +--- + +### ChatGPT: +I’m checking what the current GPTQ/quantization toolchain already supports so the design fits real hooks instead of fighting the repo. + +--- + +Yes — and the clean way is **not** “LLM reads raw tensors and vibes out a bitwidth.” +Use the external judge as a **policy model over measured layer features**, then map its output into `QuantizeConfig.dynamic`, because GPTQModel already supports per-module overrides and even skipping specific modules from quantization entirely. [GitHub+1](https://github.com/ModelCloud/GPTQModel) + +### The right architecture + +Use a **two-stage controller**: + +1. **Probe pass** +Run the model on a calibration set and compute per-layer features such as: + +- output reconstruction error after temporary W4/W8 quantization +- activation sensitivity +- weight kurtosis / outlier score +- Hessian-ish proxy or input covariance proxy +- layer type (`q_proj`, `k_proj`, `v_proj`, `o_proj`, `up_proj`, `down_proj`, `gate_proj`) +- FLOPs / parameter count / memory cost of upgrading that layer +2. **Judge pass** +Send only those features to the external judge LLM and force it to choose from a **small allowed action set**, for example: + +- `skip` +- `w8_g128` +- `w4_g128` +- `w4_g64` +- `w4_g128_no_desc_act` +- `w4_g128_mse_on` +3. **Compile pass** +Convert the judge output into GPTQModel `dynamic` regex overrides. +4. **Quantize + eval** +Quantize once with that per-layer plan, then score with `lm-eval`. + +This is a good fit because GPTQModel already documents dynamic per-module mixed quantization, with overrides for `bits`, `group_size`, `sym`, `desc_act`, `mse`, and `pack_dtype`, plus negative matching to skip modules. [GitHub](https://github.com/ModelCloud/GPTQModel) + +### Why this makes sense + +Layer sensitivity is not uniform. Recent work specifically argues that **activation sensitivity** and **weight kurtosis** are useful signals for finding layers that need more precision, and reported up to **9% lower perplexity** with only about a **2% memory increase** when extra budget is targeted to sensitive layers. AWQ also rests on the idea that only a small fraction of weights/channels are especially salient, while GPTQ uses second-order information from calibration data to reduce quantization error. [arXiv+2ar5iv+2](https://arxiv.org/html/2503.06518v1) + +### What I would actually build + +#### Option A — pragmatic and good + +Use the judge LLM only for **tier classification**. + +Input to judge: + +```JSON{ "layer": "model.layers.18.mlp.down_proj", "type": "down_proj", "params_millions": 37.7, "w4_recon_mse": 0.0182, "w8_recon_mse": 0.0021, "sensitivity_score": 0.91, "kurtosis": 14.7, "activation_outlier_rate": 0.032, "delta_ppl_if_w4_est": 0.19, "memory_cost_w8_mb": 72, "latency_cost_w8_ms": 0.4} +JSON{ "layer": "model.layers.18.mlp.down_proj", "type": "down_proj", "params_millions": 37.7, "w4_recon_mse": 0.0182, "w8_recon_mse": 0.0021, "sensitivity_score": 0.91, "kurtosis": 14.7, "activation_outlier_rate": 0.032, "delta_ppl_if_w4_est": 0.19, "memory_cost_w8_mb": 72, "latency_cost_w8_ms": 0.4} +``` + +Constrained output: + +```JSON{ "action": "w8_g128", "reason_code": "high_sensitivity_high_kurtosis"} +JSON{ "action": "w8_g128", "reason_code": "high_sensitivity_high_kurtosis"} +``` + +Then compile: + +```Pythondynamic = { r".*\.18\..*down_proj.*": {"bits": 8, "group_size": 128},} +Pythondynamic = { r".*\.18\..*down_proj.*": {"bits": 8, "group_size": 128},} +``` + +#### Option B — stronger and less gimmicky + +Use the LLM to **propose**, but let a **scorer** accept/reject. + +Pipeline: + +- judge LLM emits plan +- deterministic scorer estimates total memory budget and predicted risk +- if over budget, auto-downgrade least sensitive upgrades +- quantize + eval +- feed results back into a small dataset for later fine-tuning of the judge + +That gives you something closer to a real controller instead of a fancy prompt wrapper. + +### Best action space for Qwen 3.5 4B + +For a first version, do **not** let the judge choose arbitrary numeric values. Keep it discrete: + +- `skip` +- `w8_g128` +- `w4_g64` +- `w4_g128` +- `w4_g128_mse` +- `w4_g128_no_desc_act` + +That matches GPTQ-style knobs that are actually exposed today through dynamic overrides. [GitHub](https://github.com/ModelCloud/GPTQModel) + +### Which layers usually deserve special handling + +As a starting heuristic before the judge learns anything: + +- `lm_head`: often keep higher precision or skip quantizing if quality matters a lot +- `o_proj`, `down_proj`, `gate_proj`: commonly more fragile than some other linears +- early and late transformer blocks: often more sensitive than middle blocks +- layers with high activation sensitivity / high kurtosis: promote to W8 or tighter group size +- layers with low sensitivity and low outlier score: leave at W4/G128 + +That heuristic is exactly the sort of structure your judge should learn from the probe features, rather than inventing from names alone. The layer-sensitive paper’s core point is that sensitivity varies materially by layer and can be exploited for mixed allocation. [arXiv](https://arxiv.org/html/2503.06518v1) + +## Concrete implementation plan + +### 1) Add a probe script + +For each linear module in Qwen 3.5 4B: + +- capture FP activations on a small calibration set +- run temporary quantization candidates: + +W4/G128 + +W4/G64 + +W8/G128 +- W4/G128 +- W4/G64 +- W8/G128 +- compute: + +output MSE + +cosine similarity + +max absolute error + +activation outlier rate + +weight kurtosis +- output MSE +- cosine similarity +- max absolute error +- activation outlier rate +- weight kurtosis +- save one JSON row per layer + +Example row: + +```JSON{ "name": "model.layers.7.self_attn.o_proj", "type": "o_proj", "block": 7, "params": 16777216, "fp_norm": 92.1, "w4g128_mse": 0.0127, "w4g64_mse": 0.0084, "w8g128_mse": 0.0012, "cos_w4g128": 0.984, "cos_w4g64": 0.989, "cos_w8g128": 0.998, "kurtosis": 11.4, "act_outlier_rate": 0.027, "memory_mb_w4g128": 8.2, "memory_mb_w8g128": 16.4} +JSON{ "name": "model.layers.7.self_attn.o_proj", "type": "o_proj", "block": 7, "params": 16777216, "fp_norm": 92.1, "w4g128_mse": 0.0127, "w4g64_mse": 0.0084, "w8g128_mse": 0.0012, "cos_w4g128": 0.984, "cos_w4g64": 0.989, "cos_w8g128": 0.998, "kurtosis": 11.4, "act_outlier_rate": 0.027, "memory_mb_w4g128": 8.2, "memory_mb_w8g128": 16.4} +``` + +### 2) Ask the judge for a plan + +Prompt pattern: + +```You are a quantization policy judge.Goal:Choose the lowest-precision setting for each layer that preserves model quality under a global memory budget.Allowed actions:- skip- w8_g128- w4_g64- w4_g128- w4_g128_mse- w4_g128_no_desc_actRules:- Prefer w4_g128 unless the evidence suggests fragility.- Use w8_g128 for layers with high sensitivity and large estimated quality loss at w4.- Use skip only for extreme outliers.- Keep total upgraded memory under the supplied budget.- Return only JSON.Input:{... layer table ...} +You are a quantization policy judge.Goal:Choose the lowest-precision setting for each layer that preserves model quality under a global memory budget.Allowed actions:- skip- w8_g128- w4_g64- w4_g128- w4_g128_mse- w4_g128_no_desc_actRules:- Prefer w4_g128 unless the evidence suggests fragility.- Use w8_g128 for layers with high sensitivity and large estimated quality loss at w4.- Use skip only for extreme outliers.- Keep total upgraded memory under the supplied budget.- Return only JSON.Input:{... layer table ...} +``` + +Return: + +```JSON{ "budget_mb_extra": 220, "assignments": [ {"name": "model.layers.0.self_attn.o_proj", "action": "w8_g128"}, {"name": "model.layers.1.mlp.down_proj", "action": "w4_g64"}, {"name": "model.layers.2.self_attn.q_proj", "action": "w4_g128"} ]} +JSON{ "budget_mb_extra": 220, "assignments": [ {"name": "model.layers.0.self_attn.o_proj", "action": "w8_g128"}, {"name": "model.layers.1.mlp.down_proj", "action": "w4_g64"}, {"name": "model.layers.2.self_attn.q_proj", "action": "w4_g128"} ]} +``` + +### 3) Compile to GPTQModel dynamic + +```PythonACTION_MAP = { "w8_g128": {"bits": 8, "group_size": 128}, "w4_g64": {"bits": 4, "group_size": 64}, "w4_g128": {"bits": 4, "group_size": 128}, "w4_g128_mse": {"bits": 4, "group_size": 128, "mse": True}, "w4_g128_no_desc_act": {"bits": 4, "group_size": 128, "desc_act": False},}def compile_dynamic(assignments): dynamic = {} for row in assignments: name = row["name"] action = row["action"] pattern = rf"{name.replace('.', r'\.')}" if action == "skip": dynamic[rf"-:{pattern}"] = {} else: dynamic[rf"+:{pattern}"] = ACTION_MAP[action] return dynamic +PythonACTION_MAP = { "w8_g128": {"bits": 8, "group_size": 128}, "w4_g64": {"bits": 4, "group_size": 64}, "w4_g128": {"bits": 4, "group_size": 128}, "w4_g128_mse": {"bits": 4, "group_size": 128, "mse": True}, "w4_g128_no_desc_act": {"bits": 4, "group_size": 128, "desc_act": False},}def compile_dynamic(assignments): dynamic = {} for row in assignments: name = row["name"] action = row["action"] pattern = rf"{name.replace('.', r'\.')}" if action == "skip": dynamic[rf"-:{pattern}"] = {} else: dynamic[rf"+:{pattern}"] = ACTION_MAP[action] return dynamic +``` + +And then: + +```Pythonfrom gptqmodel import QuantizeConfigqcfg = QuantizeConfig.gptq_pro()qcfg.dynamic = compile_dynamic(plan["assignments"]) +Pythonfrom gptqmodel import QuantizeConfigqcfg = QuantizeConfig.gptq_pro()qcfg.dynamic = compile_dynamic(plan["assignments"]) +``` + +GPTQModel explicitly supports positive matches for overrides and negative matches for skipping modules. [GitHub](https://github.com/ModelCloud/GPTQModel) + +## What I would *not* do + +I would **not** let the LLM directly see: + +- raw tensors +- full histograms for hundreds of layers in one prompt +- unconstrained actions like “choose any bits from 2–16” +- free-form natural language reasoning that then gets parsed loosely + +That turns the system into prompt soup. + +Use the LLM as: + +- a **budgeted classifier** +- with **small discrete actions** +- over **measured numeric features** +- followed by **deterministic compilation** + +### Better than an LLM in v1 + +Honestly, the strongest v1 is: + +- collect probe data +- train a tiny XGBoost / LightGBM / logistic model to predict: + +“safe at W4/G128” + +“needs tighter grouping” + +“needs W8” + +“skip” +- “safe at W4/G128” +- “needs tighter grouping” +- “needs W8” +- “skip” +- optionally use the LLM only to explain decisions + +That is usually more stable than a judge LLM. The LLM becomes useful when you want to combine many signals plus a **global budget constraint**, or when you want an agent loop that proposes and revises mixed-precision recipes. Mixed-precision/non-uniform quantization is already an active direction in current toolchains like AutoRound and LLM Compressor. [GitHub+1](https://github.com/intel/auto-round) + +## Best practical roadmap for your setup + +For **Qwen 3.5 4B on 1x3090**, I’d do this: + +1. Build probe dataset on 128–256 calibration samples. +2. Start with only these module types: + +- `o_proj` +- `down_proj` +- `gate_proj` +- `lm_head` +3. Judge action space: + +- `w4_g128` +- `w4_g64` +- `w8_g128` +- `skip` +4. Compile to `QuantizeConfig.dynamic` +5. Run your existing A/B suite +6. Keep the judge only if it beats simple heuristics + +That gets you a real mixed-precision controller fast, using hooks GPTQModel already exposes. [GitHub+1](https://github.com/ModelCloud/GPTQModel) + +I can write you the **full probe + judge + dynamic-compiler scaffold** next. + +--- + +### User: +provide full please for me + +--- + +### ChatGPT: +Yes. Paste this into an empty folder. It creates a full scaffold that does: + +- probe per-layer sensitivity from real activations +- ask an external OpenAI-compatible judge LLM for discrete actions +- enforce a hard extra-memory budget deterministically +- compile the result into `QuantizeConfig.dynamic` + +That matches GPTQModel’s documented per-module mixed-quantization and `-:` skip controls. [GitHub](https://github.com/ModelCloud/GPTQModel) + +```Bashmkdir -p quant_judge_scaffold && cd quant_judge_scaffoldcat > README.md <<'EOF'# Quant Judge Scaffold for GPTQModelThis project adds an external "judge" LLM on top of GPTQModel mixed quantization.Pipeline:1. Probe layer/module sensitivity on a calibration set using the FP model.2. Compute proxy quantization errors for a small discrete action set.3. Ask an external judge LLM to choose one action per module.4. Enforce a deterministic extra-memory budget.5. Compile the plan into `QuantizeConfig.dynamic`.6. Quantize with GPTQModel using the resulting mixed plan.Supported actions in this scaffold:- w4_g128 (base)- w4_g64- w8_g128- skip (do not quantize that module)Default targeted modules:- o_proj- down_proj- gate_proj- lm_headYou can extend to q_proj/k_proj/v_proj/up_proj later, but start conservative.## InstallInstall torch yourself to match your CUDA stack, then:```bashpip install -r requirements.txt +Bashmkdir -p quant_judge_scaffold && cd quant_judge_scaffoldcat > README.md <<'EOF'# Quant Judge Scaffold for GPTQModelThis project adds an external "judge" LLM on top of GPTQModel mixed quantization.Pipeline:1. Probe layer/module sensitivity on a calibration set using the FP model.2. Compute proxy quantization errors for a small discrete action set.3. Ask an external judge LLM to choose one action per module.4. Enforce a deterministic extra-memory budget.5. Compile the plan into `QuantizeConfig.dynamic`.6. Quantize with GPTQModel using the resulting mixed plan.Supported actions in this scaffold:- w4_g128 (base)- w4_g64- w8_g128- skip (do not quantize that module)Default targeted modules:- o_proj- down_proj- gate_proj- lm_headYou can extend to q_proj/k_proj/v_proj/up_proj later, but start conservative.## InstallInstall torch yourself to match your CUDA stack, then:```bashpip install -r requirements.txt +``` + +## Probe only + +```Bashpython -m quant_judge.pipeline probe \ --model-id Qwen/Qwen3.5-4B-Base \ --out-dir ./runs/qwen35_4b_probe \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 64 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head +Bashpython -m quant_judge.pipeline probe \ --model-id Qwen/Qwen3.5-4B-Base \ --out-dir ./runs/qwen35_4b_probe \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 64 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head +``` + +## Plan with heuristic fallback only + +```Bashpython -m quant_judge.pipeline plan \ --features-jsonl ./runs/qwen35_4b_probe/probe_features.jsonl \ --plan-out ./runs/qwen35_4b_probe/plan.json \ --dynamic-out ./runs/qwen35_4b_probe/dynamic.json \ --budget-mb-extra 220 \ --heuristic-only +Bashpython -m quant_judge.pipeline plan \ --features-jsonl ./runs/qwen35_4b_probe/probe_features.jsonl \ --plan-out ./runs/qwen35_4b_probe/plan.json \ --dynamic-out ./runs/qwen35_4b_probe/dynamic.json \ --budget-mb-extra 220 \ --heuristic-only +``` + +## Plan with external judge + +Use any OpenAI-compatible endpoint, local or remote. + +```Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline plan \ --features-jsonl ./runs/qwen35_4b_probe/probe_features.jsonl \ --plan-out ./runs/qwen35_4b_probe/plan.json \ --dynamic-out ./runs/qwen35_4b_probe/dynamic.json \ --budget-mb-extra 220 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct \ --judge-api-key-env JUDGE_API_KEY \ --chunk-size 24 +Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline plan \ --features-jsonl ./runs/qwen35_4b_probe/probe_features.jsonl \ --plan-out ./runs/qwen35_4b_probe/plan.json \ --dynamic-out ./runs/qwen35_4b_probe/dynamic.json \ --budget-mb-extra 220 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct \ --judge-api-key-env JUDGE_API_KEY \ --chunk-size 24 +``` + +## Quantize with the compiled plan + +```Bashpython -m quant_judge.pipeline quantize \ --model-id Qwen/Qwen3.5-4B-Base \ --plan-json ./runs/qwen35_4b_probe/plan.json \ --out-dir ./runs/qwen35_4b_quant \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 128 \ --batch-size 4 \ --base-bits 4 \ --base-group-size 128 +Bashpython -m quant_judge.pipeline quantize \ --model-id Qwen/Qwen3.5-4B-Base \ --plan-json ./runs/qwen35_4b_probe/plan.json \ --out-dir ./runs/qwen35_4b_quant \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 128 \ --batch-size 4 \ --base-bits 4 \ --base-group-size 128 +``` + +## One-shot run + +```Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline all \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/qwen35_4b_full \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head \ --budget-mb-extra 220 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct \ --judge-api-key-env JUDGE_API_KEY +Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline all \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/qwen35_4b_full \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head \ --budget-mb-extra 220 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct \ --judge-api-key-env JUDGE_API_KEY +``` + +## Notes + +- The probe uses a proxy group-wise weight quantizer to estimate per-module sensitivity. It is not a full internal GPTQ re-run for every module/action. +- That is intentional: it stays fast enough to be practical and gives the judge stable numeric features. +- The final quantization still happens through GPTQModel. +- Start with the default modules. Add q_proj/k_proj/v_proj/up_proj only after the loop is working. +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +requests +numpy +tqdm +psutil +gptqmodel +EOF + +mkdir -p quant_judge + +cat > quant_judge/**init**.py <<'EOF' +**all** = [] +EOF + +cat > quant_judge/actions.py <<'EOF' +from **future** import annotations + +from dataclasses import dataclass +from typing import Dict, Any + +@dataclass(frozen=True) +class ActionSpec: +name: str +bits: int | None +group_size: int | None +skip: bool +notes: str = "" + +ACTION_SPECS: dict[str, ActionSpec] = { +"w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False, notes="base action"), +"w4_g64": ActionSpec("w4_g64", bits=4, group_size=64, skip=False, notes="tighter grouping"), +"w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False, notes="higher precision"), +"skip": ActionSpec("skip", bits=None, group_size=None, skip=True, notes="leave module unquantized"), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param_for_action(action: str, fp_bytes: int = 2) -> float: +if action == "skip": +return float(fp_bytes) +spec = ACTION_SPECS[action] +assert spec.bits is not None +return spec.bits / 8.0 + +def action_to_dynamic_override(action: str) -> Dict[str, Any]: +if action == "skip": +return {} +spec = ACTION_SPECS[action] +return { +"bits": spec.bits, +"group_size": spec.group_size, +} + +def allowed_actions_csv() -> str: +return ",".join(ACTION_SPECS.keys()) +EOF + +cat > quant_judge/probe.py <<'EOF' +from **future** import annotations + +import json +import math +import re +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .actions import ACTION_SPECS, BASE_ACTION, bytes_per_param_for_action + +@dataclass +class ProbeConfig: +model_id: str +out_dir: str +dataset_name: str +dataset_config: str | None +split: str +text_field: str +num_texts: int +max_length: int +max_rows_per_module: int +include_modules: list[str] +trust_remote_code: bool = True + +def _dtype_auto() -> torch.dtype: +if not torch.cuda.is_available(): +return torch.float32 +major, _minor = torch.cuda.get_device_capability(0) +if major >= 8: +return torch.bfloat16 +return torch.float16 + +def load_texts( +dataset_name: str, +dataset_config: str | None, +split: str, +text_field: str, +num_texts: int, +) -> list[str]: +ds = load_dataset(dataset_name, dataset_config, split=split) +texts: list[str] = [] +for row in ds: +text = row.get(text_field, "") +if isinstance(text, str) and text.strip(): +texts.append(text.strip()) +if len(texts) >= num_texts: +break +if not texts: +raise ValueError("No non-empty texts found in dataset.") +return texts + +def load_model_and_tokenizer(model_id: str, trust_remote_code: bool = True): +tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code) +if tokenizer.pad_token_id is None: +tokenizer.pad_token = tokenizer.eos_token +model = AutoModelForCausalLM.from_pretrained( +model_id, +trust_remote_code=trust_remote_code, +torch_dtype=_dtype_auto(), +device_map="auto", +low_cpu_mem_usage=True, +) +model.eval() +return model, tokenizer + +def module_name_selected(name: str, include_modules: list[str]) -> bool: +return any(key in name for key in include_modules) + +def iter_target_linears(model, include_modules: list[str]): +for name, module in model.named_modules(): +if isinstance(module, torch.nn.Linear) and module_name_selected(name, include_modules): +yield name, module + +def _flatten_rows(x: torch.Tensor) -> torch.Tensor: +if x.dim() == 2: +return x +if x.dim() >= 3: +return x.reshape(-1, x.shape[-1]) +return x.unsqueeze(0) + +def _take_rows(x: torch.Tensor, max_rows: int) -> torch.Tensor: +x = _flatten_rows(x) +if x.shape[0] <= max_rows: +return x +idx = torch.randperm(x.shape[0], device=x.device)[:max_rows] +return x.index_select(0, idx) + +def _cpu_half(x: torch.Tensor) -> torch.Tensor: +return x.detach().to("cpu", dtype=torch.float16, copy=True) + +def _weight_kurtosis(weight: torch.Tensor) -> float: +w = weight.detach().float().flatten() +mu = w.mean() +var = ((w - mu) ** 2).mean().clamp_min(1e-12) +kurt = (((w - mu) ** 4).mean() / (var ** 2)).item() +return float(kurt) + +def _act_outlier_rate(x: torch.Tensor, sigma: float = 6.0) -> float: +xf = x.float() +std = xf.std().clamp_min(1e-6) +rate = (xf.abs() > (sigma * std)).float().mean().item() +return float(rate) + +def _groupwise_symmetric_quantize(weight: torch.Tensor, bits: int, group_size: int) -> torch.Tensor: +""" +Quantizes along the input dimension per output row. +weight shape: [out_features, in_features] +""" +assert weight.dim() == 2 +w = weight.detach().float() +out_features, in_features = w.shape +if group_size <= 0: +raise ValueError("group_size must be > 0") + +```pad = (group_size - (in_features % group_size)) % group_sizeif pad: w = F.pad(w, (0, pad), value=0.0)maxq = (2 ** (bits - 1)) - 1groups = w.view(out_features, -1, group_size)absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)scale = absmax / maxqq = torch.round(groups / scale).clamp(-maxq, maxq)dq = q * scaledq = dq.view(out_features, -1)if pad: dq = dq[:, :in_features]return dq.to(weight.device, dtype=weight.dtype) +pad = (group_size - (in_features % group_size)) % group_sizeif pad: w = F.pad(w, (0, pad), value=0.0)maxq = (2 ** (bits - 1)) - 1groups = w.view(out_features, -1, group_size)absmax = groups.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)scale = absmax / maxqq = torch.round(groups / scale).clamp(-maxq, maxq)dq = q * scaledq = dq.view(out_features, -1)if pad: dq = dq[:, :in_features]return dq.to(weight.device, dtype=weight.dtype) +``` + +def _probe_linear_module( +module: torch.nn.Linear, +x_rows: torch.Tensor, +y_rows: torch.Tensor, +module_name: str, +) -> dict: +device = module.weight.device +x = x_rows.to(device=device, dtype=module.weight.dtype) +y = y_rows.to(device=device, dtype=module.weight.dtype) + +```W = module.weight.detach()bias = None if module.bias is None else module.bias.detach()fp_params = int(W.numel())fp_mb = (fp_params * 2.0) / (1024 ** 2)candidates = {}for action_name in ("w4_g128", "w4_g64", "w8_g128"): spec = ACTION_SPECS[action_name] Wq = _groupwise_symmetric_quantize(W, bits=spec.bits, group_size=spec.group_size) y_hat = F.linear(x, Wq, bias) mse = F.mse_loss(y_hat.float(), y.float()).item() max_abs = (y_hat.float() - y.float()).abs().max().item() y_norm = y.float().norm(dim=-1).clamp_min(1e-8) yh_norm = y_hat.float().norm(dim=-1).clamp_min(1e-8) cos = F.cosine_similarity(y_hat.float(), y.float(), dim=-1).mean().item() q_mb = (fp_params * bytes_per_param_for_action(action_name)) / (1024 ** 2) candidates[action_name] = { "mse": float(mse), "cosine": float(cos), "max_abs": float(max_abs), "estimated_model_mb_for_module": float(q_mb), "estimated_extra_mb_vs_base": float( q_mb - (fp_params * bytes_per_param_for_action(BASE_ACTION)) / (1024 ** 2) ), }skip_mb = fp_mbcandidates["skip"] = { "mse": 0.0, "cosine": 1.0, "max_abs": 0.0, "estimated_model_mb_for_module": float(skip_mb), "estimated_extra_mb_vs_base": float( skip_mb - (fp_params * bytes_per_param_for_action(BASE_ACTION)) / (1024 ** 2) ),}w4 = candidates["w4_g128"]["mse"]w4g64 = candidates["w4_g64"]["mse"]w8 = candidates["w8_g128"]["mse"]gain_w8_vs_w4 = max(0.0, w4 - w8)gain_w4g64_vs_w4 = max(0.0, w4 - w4g64)record = { "name": module_name, "module_type": module_name.split(".")[-1], "out_features": int(W.shape[0]), "in_features": int(W.shape[1]), "params": int(fp_params), "captured_rows": int(x_rows.shape[0]), "weight_kurtosis": _weight_kurtosis(W), "act_outlier_rate": _act_outlier_rate(x_rows), "estimated_fp16_mb_for_module": float(fp_mb), "proxy_gain_w8_vs_w4": float(gain_w8_vs_w4), "proxy_gain_w4g64_vs_w4": float(gain_w4g64_vs_w4), "candidates": candidates,}return record +W = module.weight.detach()bias = None if module.bias is None else module.bias.detach()fp_params = int(W.numel())fp_mb = (fp_params * 2.0) / (1024 ** 2)candidates = {}for action_name in ("w4_g128", "w4_g64", "w8_g128"): spec = ACTION_SPECS[action_name] Wq = _groupwise_symmetric_quantize(W, bits=spec.bits, group_size=spec.group_size) y_hat = F.linear(x, Wq, bias) mse = F.mse_loss(y_hat.float(), y.float()).item() max_abs = (y_hat.float() - y.float()).abs().max().item() y_norm = y.float().norm(dim=-1).clamp_min(1e-8) yh_norm = y_hat.float().norm(dim=-1).clamp_min(1e-8) cos = F.cosine_similarity(y_hat.float(), y.float(), dim=-1).mean().item() q_mb = (fp_params * bytes_per_param_for_action(action_name)) / (1024 ** 2) candidates[action_name] = { "mse": float(mse), "cosine": float(cos), "max_abs": float(max_abs), "estimated_model_mb_for_module": float(q_mb), "estimated_extra_mb_vs_base": float( q_mb - (fp_params * bytes_per_param_for_action(BASE_ACTION)) / (1024 ** 2) ), }skip_mb = fp_mbcandidates["skip"] = { "mse": 0.0, "cosine": 1.0, "max_abs": 0.0, "estimated_model_mb_for_module": float(skip_mb), "estimated_extra_mb_vs_base": float( skip_mb - (fp_params * bytes_per_param_for_action(BASE_ACTION)) / (1024 ** 2) ),}w4 = candidates["w4_g128"]["mse"]w4g64 = candidates["w4_g64"]["mse"]w8 = candidates["w8_g128"]["mse"]gain_w8_vs_w4 = max(0.0, w4 - w8)gain_w4g64_vs_w4 = max(0.0, w4 - w4g64)record = { "name": module_name, "module_type": module_name.split(".")[-1], "out_features": int(W.shape[0]), "in_features": int(W.shape[1]), "params": int(fp_params), "captured_rows": int(x_rows.shape[0]), "weight_kurtosis": _weight_kurtosis(W), "act_outlier_rate": _act_outlier_rate(x_rows), "estimated_fp16_mb_for_module": float(fp_mb), "proxy_gain_w8_vs_w4": float(gain_w8_vs_w4), "proxy_gain_w4g64_vs_w4": float(gain_w4g64_vs_w4), "candidates": candidates,}return record +``` + +def run_probe(cfg: ProbeConfig) -> dict: +out_dir = Path(cfg.out_dir) +out_dir.mkdir(parents=True, exist_ok=True) + +```model, tokenizer = load_model_and_tokenizer(cfg.model_id, trust_remote_code=cfg.trust_remote_code)texts = load_texts( dataset_name=cfg.dataset_name, dataset_config=cfg.dataset_config, split=cfg.split, text_field=cfg.text_field, num_texts=cfg.num_texts,)targets = list(iter_target_linears(model, cfg.include_modules))if not targets: raise ValueError("No target linear modules found. Check --include-modules.")captured_inputs: Dict[str, list[torch.Tensor]] = {name: [] for name, _ in targets}captured_outputs: Dict[str, list[torch.Tensor]] = {name: [] for name, _ in targets}remaining_rows: Dict[str, int] = {name: cfg.max_rows_per_module for name, _ in targets}handles = []def make_hook(name: str): def hook(_module, inp, out): rem = remaining_rows[name] if rem <= 0: return x = inp[0] y = out if not isinstance(x, torch.Tensor) or not isinstance(y, torch.Tensor): return x_rows = _take_rows(x, rem) y_rows = _take_rows(y, rem) take = min(x_rows.shape[0], y_rows.shape[0], rem) if take <= 0: return captured_inputs[name].append(_cpu_half(x_rows[:take])) captured_outputs[name].append(_cpu_half(y_rows[:take])) remaining_rows[name] -= take return hookfor name, module in targets: handles.append(module.register_forward_hook(make_hook(name)))with torch.no_grad(): for text in tqdm(texts, desc="collect_activations"): batch = tokenizer( text, return_tensors="pt", truncation=True, max_length=cfg.max_length, padding=False, ) batch = {k: v.to(model.device) for k, v in batch.items()} _ = model(**batch) if all(v <= 0 for v in remaining_rows.values()): breakfor h in handles: h.remove()probe_rows: list[dict] = []for name, module in tqdm(targets, desc="probe_modules"): if not captured_inputs[name] or not captured_outputs[name]: continue x_rows = torch.cat(captured_inputs[name], dim=0).float() y_rows = torch.cat(captured_outputs[name], dim=0).float() row = _probe_linear_module(module, x_rows, y_rows, name) probe_rows.append(row)features_jsonl = out_dir / "probe_features.jsonl"compact_json = out_dir / "probe_features_compact.json"metadata_json = out_dir / "probe_metadata.json"with open(features_jsonl, "w", encoding="utf-8") as f: for row in probe_rows: f.write(json.dumps(row, ensure_ascii=False) + "\n")compact = { "model_id": cfg.model_id, "num_modules": len(probe_rows), "include_modules": cfg.include_modules, "top_w4_mse_modules": sorted( [{"name": r["name"], "w4_g128_mse": r["candidates"]["w4_g128"]["mse"]} for r in probe_rows], key=lambda x: x["w4_g128_mse"], reverse=True, )[:20],}compact_json.write_text(json.dumps(compact, indent=2), encoding="utf-8")metadata = { "model_id": cfg.model_id, "dataset_name": cfg.dataset_name, "dataset_config": cfg.dataset_config, "split": cfg.split, "text_field": cfg.text_field, "num_texts": cfg.num_texts, "max_length": cfg.max_length, "max_rows_per_module": cfg.max_rows_per_module, "include_modules": cfg.include_modules, "num_probe_rows": len(probe_rows),}metadata_json.write_text(json.dumps(metadata, indent=2), encoding="utf-8")return { "probe_rows": len(probe_rows), "features_jsonl": str(features_jsonl), "compact_json": str(compact_json), "metadata_json": str(metadata_json),} +model, tokenizer = load_model_and_tokenizer(cfg.model_id, trust_remote_code=cfg.trust_remote_code)texts = load_texts( dataset_name=cfg.dataset_name, dataset_config=cfg.dataset_config, split=cfg.split, text_field=cfg.text_field, num_texts=cfg.num_texts,)targets = list(iter_target_linears(model, cfg.include_modules))if not targets: raise ValueError("No target linear modules found. Check --include-modules.")captured_inputs: Dict[str, list[torch.Tensor]] = {name: [] for name, _ in targets}captured_outputs: Dict[str, list[torch.Tensor]] = {name: [] for name, _ in targets}remaining_rows: Dict[str, int] = {name: cfg.max_rows_per_module for name, _ in targets}handles = []def make_hook(name: str): def hook(_module, inp, out): rem = remaining_rows[name] if rem <= 0: return x = inp[0] y = out if not isinstance(x, torch.Tensor) or not isinstance(y, torch.Tensor): return x_rows = _take_rows(x, rem) y_rows = _take_rows(y, rem) take = min(x_rows.shape[0], y_rows.shape[0], rem) if take <= 0: return captured_inputs[name].append(_cpu_half(x_rows[:take])) captured_outputs[name].append(_cpu_half(y_rows[:take])) remaining_rows[name] -= take return hookfor name, module in targets: handles.append(module.register_forward_hook(make_hook(name)))with torch.no_grad(): for text in tqdm(texts, desc="collect_activations"): batch = tokenizer( text, return_tensors="pt", truncation=True, max_length=cfg.max_length, padding=False, ) batch = {k: v.to(model.device) for k, v in batch.items()} _ = model(**batch) if all(v <= 0 for v in remaining_rows.values()): breakfor h in handles: h.remove()probe_rows: list[dict] = []for name, module in tqdm(targets, desc="probe_modules"): if not captured_inputs[name] or not captured_outputs[name]: continue x_rows = torch.cat(captured_inputs[name], dim=0).float() y_rows = torch.cat(captured_outputs[name], dim=0).float() row = _probe_linear_module(module, x_rows, y_rows, name) probe_rows.append(row)features_jsonl = out_dir / "probe_features.jsonl"compact_json = out_dir / "probe_features_compact.json"metadata_json = out_dir / "probe_metadata.json"with open(features_jsonl, "w", encoding="utf-8") as f: for row in probe_rows: f.write(json.dumps(row, ensure_ascii=False) + "\n")compact = { "model_id": cfg.model_id, "num_modules": len(probe_rows), "include_modules": cfg.include_modules, "top_w4_mse_modules": sorted( [{"name": r["name"], "w4_g128_mse": r["candidates"]["w4_g128"]["mse"]} for r in probe_rows], key=lambda x: x["w4_g128_mse"], reverse=True, )[:20],}compact_json.write_text(json.dumps(compact, indent=2), encoding="utf-8")metadata = { "model_id": cfg.model_id, "dataset_name": cfg.dataset_name, "dataset_config": cfg.dataset_config, "split": cfg.split, "text_field": cfg.text_field, "num_texts": cfg.num_texts, "max_length": cfg.max_length, "max_rows_per_module": cfg.max_rows_per_module, "include_modules": cfg.include_modules, "num_probe_rows": len(probe_rows),}metadata_json.write_text(json.dumps(metadata, indent=2), encoding="utf-8")return { "probe_rows": len(probe_rows), "features_jsonl": str(features_jsonl), "compact_json": str(compact_json), "metadata_json": str(metadata_json),} +``` + +EOF + +cat > quant_judge/judge_client.py <<'EOF' +from **future** import annotations + +import json +import math +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Iterable, List + +import requests + +from .actions import ACTION_SPECS, BASE_ACTION, allowed_actions_csv + +@dataclass +class JudgeConfig: +features_jsonl: str +plan_out: str +dynamic_out: str +budget_mb_extra: float +heuristic_only: bool = False +judge_base_url: str | None = None +judge_model: str | None = None +judge_api_key_env: str = "JUDGE_API_KEY" +chunk_size: int = 24 + +def load_features_jsonl(path: str) -> list[dict]: +rows = [] +with open(path, "r", encoding="utf-8") as f: +for line in f: +line = line.strip() +if line: +rows.append(json.loads(line)) +return rows + +def _module_bonus(name: str) -> float: +bonus = 0.0 +if "lm_head" in name: +bonus += 0.25 +if "down_proj" in name: +bonus += 0.15 +if "o_proj" in name: +bonus += 0.12 +if "gate_proj" in name: +bonus += 0.10 +return bonus + +def heuristic_decide(row: dict) -> dict: +name = row["name"] +w4 = row["candidates"]["w4_g128"]["mse"] +w4g64 = row["candidates"]["w4_g64"]["mse"] +w8 = row["candidates"]["w8_g128"]["mse"] +kurt = row["weight_kurtosis"] +outlier = row["act_outlier_rate"] + +```w8_relief = (w4 - w8) / max(w4, 1e-9)g64_relief = (w4 - w4g64) / max(w4, 1e-9)score = 0.0score += min(1.0, w8_relief / 0.60) * 0.40score += min(1.0, g64_relief / 0.35) * 0.20score += min(1.0, kurt / 20.0) * 0.15score += min(1.0, outlier / 0.05) * 0.15score += _module_bonus(name)score = max(0.0, min(1.0, score))reason = "default_low_sensitivity"action = "w4_g128"if "lm_head" in name and (w8_relief > 0.55 or kurt > 25.0): action = "skip" reason = "lm_head_extreme_sensitivity"elif score >= 0.75 or w8_relief > 0.50: action = "w8_g128" reason = "high_sensitivity_prefers_w8"elif score >= 0.42 or g64_relief > 0.18: action = "w4_g64" reason = "moderate_sensitivity_prefers_tighter_groups"return { "name": name, "action": action, "priority": float(round(score, 4)), "reason_code": reason, "source": "heuristic",} +w8_relief = (w4 - w8) / max(w4, 1e-9)g64_relief = (w4 - w4g64) / max(w4, 1e-9)score = 0.0score += min(1.0, w8_relief / 0.60) * 0.40score += min(1.0, g64_relief / 0.35) * 0.20score += min(1.0, kurt / 20.0) * 0.15score += min(1.0, outlier / 0.05) * 0.15score += _module_bonus(name)score = max(0.0, min(1.0, score))reason = "default_low_sensitivity"action = "w4_g128"if "lm_head" in name and (w8_relief > 0.55 or kurt > 25.0): action = "skip" reason = "lm_head_extreme_sensitivity"elif score >= 0.75 or w8_relief > 0.50: action = "w8_g128" reason = "high_sensitivity_prefers_w8"elif score >= 0.42 or g64_relief > 0.18: action = "w4_g64" reason = "moderate_sensitivity_prefers_tighter_groups"return { "name": name, "action": action, "priority": float(round(score, 4)), "reason_code": reason, "source": "heuristic",} +``` + +def _strip_code_fence(text: str) -> str: +t = text.strip() +if t.startswith("`"): t = t.split("\n", 1)[1] if t.endswith("`"): +t = t.rsplit("\n", 1)[0] +return t.strip() + +def _chunk_rows(rows: list[dict], chunk_size: int) -> list[list[dict]]: +return [rows[i:i + chunk_size] for i in range(0, len(rows), chunk_size)] + +def _reduced_row(row: dict) -> dict: +return { +"name": row["name"], +"module_type": row["module_type"], +"params": row["params"], +"captured_rows": row["captured_rows"], +"weight_kurtosis": row["weight_kurtosis"], +"act_outlier_rate": row["act_outlier_rate"], +"proxy_gain_w8_vs_w4": row["proxy_gain_w8_vs_w4"], +"proxy_gain_w4g64_vs_w4": row["proxy_gain_w4g64_vs_w4"], +"w4_g128_mse": row["candidates"]["w4_g128"]["mse"], +"w4_g64_mse": row["candidates"]["w4_g64"]["mse"], +"w8_g128_mse": row["candidates"]["w8_g128"]["mse"], +"extra_mb_w4_g64": row["candidates"]["w4_g64"]["estimated_extra_mb_vs_base"], +"extra_mb_w8_g128": row["candidates"]["w8_g128"]["estimated_extra_mb_vs_base"], +"extra_mb_skip": row["candidates"]["skip"]["estimated_extra_mb_vs_base"], +} + +def call_openai_compatible( +base_url: str, +model: str, +api_key: str, +payload_rows: list[dict], +budget_mb_extra: float, +) -> list[dict]: +system = ( +"You are a quantization policy judge. " +"Choose exactly one action per module from: " +f"{allowed_actions_csv()}. " +"Return JSON only. No markdown. " +"Prefer w4_g128 unless the evidence says the module is fragile. " +"Use w4_g64 for moderate sensitivity. " +"Use w8_g128 for high sensitivity. " +"Use skip only for extreme outliers, usually lm_head or very fragile modules. " +"Also emit a priority score between 0 and 1 where 1 means strongest need to preserve precision." +) +user = { +"budget_mb_extra_global": budget_mb_extra, +"rules": { +"one_assignment_per_input_row": True, +"allowed_actions": list(ACTION_SPECS.keys()), +"json_schema": { +"assignments": [ +{ +"name": "exact module name from input", +"action": "one of allowed actions", +"priority": "float from 0 to 1", +"reason_code": "short snake_case code" +} +] +}, +}, +"rows": payload_rows, +} + +```url = base_url.rstrip("/") + "/chat/completions"headers = {"Content-Type": "application/json"}if api_key: headers["Authorization"] = f"Bearer {api_key}"body = { "model": model, "temperature": 0.0, "response_format": {"type": "json_object"}, "messages": [ {"role": "system", "content": system}, {"role": "user", "content": json.dumps(user, ensure_ascii=False)}, ],}resp = requests.post(url, headers=headers, json=body, timeout=180)resp.raise_for_status()data = resp.json()content = data["choices"][0]["message"]["content"]parsed = json.loads(_strip_code_fence(content))assignments = parsed.get("assignments", [])if not isinstance(assignments, list): raise ValueError("Judge response missing 'assignments' list.")return assignments +url = base_url.rstrip("/") + "/chat/completions"headers = {"Content-Type": "application/json"}if api_key: headers["Authorization"] = f"Bearer {api_key}"body = { "model": model, "temperature": 0.0, "response_format": {"type": "json_object"}, "messages": [ {"role": "system", "content": system}, {"role": "user", "content": json.dumps(user, ensure_ascii=False)}, ],}resp = requests.post(url, headers=headers, json=body, timeout=180)resp.raise_for_status()data = resp.json()content = data["choices"][0]["message"]["content"]parsed = json.loads(_strip_code_fence(content))assignments = parsed.get("assignments", [])if not isinstance(assignments, list): raise ValueError("Judge response missing 'assignments' list.")return assignments +``` + +def merge_assignments(rows: list[dict], llm_assignments: list[dict]) -> list[dict]: +by_name = {a["name"]: a for a in llm_assignments if "name" in a} +merged = [] +for row in rows: +a = by_name.get(row["name"]) +if a is None: +merged.append(heuristic_decide(row)) +continue +action = a.get("action", BASE_ACTION) +if action not in ACTION_SPECS: +action = BASE_ACTION +priority = float(a.get("priority", 0.5)) +priority = max(0.0, min(1.0, priority)) +merged.append({ +"name": row["name"], +"action": action, +"priority": priority, +"reason_code": str(a.get("reason_code", "judge_no_reason")), +"source": "judge", +}) +return merged + +def enforce_budget(rows: list[dict], assignments: list[dict], budget_mb_extra: float) -> dict: +row_by_name = {r["name"]: r for r in rows} +normalized = [] + +```for a in assignments: row = row_by_name[a["name"]] action = a["action"] extra = row["candidates"][action]["estimated_extra_mb_vs_base"] normalized.append({ **a, "estimated_extra_mb_vs_base": extra, })keep = []extras = []for a in normalized: if a["action"] == BASE_ACTION: keep.append(a) else: extras.append(a)extras.sort(key=lambda x: (x["priority"], x["estimated_extra_mb_vs_base"]), reverse=True)used = 0.0final_assignments = list(keep)downgraded = []for a in extras: extra = max(0.0, float(a["estimated_extra_mb_vs_base"])) if used + extra <= budget_mb_extra: final_assignments.append(a) used += extra else: downgraded.append({ **a, "downgraded_to": BASE_ACTION, "downgrade_reason": "budget_clamp", }) final_assignments.append({ "name": a["name"], "action": BASE_ACTION, "priority": a["priority"], "reason_code": "budget_clamp_to_base", "source": a["source"], "estimated_extra_mb_vs_base": 0.0, })final_assignments.sort(key=lambda x: x["name"])return { "budget_mb_extra": budget_mb_extra, "used_mb_extra": round(used, 4), "assignments": final_assignments, "downgraded": downgraded,} +for a in assignments: row = row_by_name[a["name"]] action = a["action"] extra = row["candidates"][action]["estimated_extra_mb_vs_base"] normalized.append({ **a, "estimated_extra_mb_vs_base": extra, })keep = []extras = []for a in normalized: if a["action"] == BASE_ACTION: keep.append(a) else: extras.append(a)extras.sort(key=lambda x: (x["priority"], x["estimated_extra_mb_vs_base"]), reverse=True)used = 0.0final_assignments = list(keep)downgraded = []for a in extras: extra = max(0.0, float(a["estimated_extra_mb_vs_base"])) if used + extra <= budget_mb_extra: final_assignments.append(a) used += extra else: downgraded.append({ **a, "downgraded_to": BASE_ACTION, "downgrade_reason": "budget_clamp", }) final_assignments.append({ "name": a["name"], "action": BASE_ACTION, "priority": a["priority"], "reason_code": "budget_clamp_to_base", "source": a["source"], "estimated_extra_mb_vs_base": 0.0, })final_assignments.sort(key=lambda x: x["name"])return { "budget_mb_extra": budget_mb_extra, "used_mb_extra": round(used, 4), "assignments": final_assignments, "downgraded": downgraded,} +``` + +def make_plan(cfg: JudgeConfig) -> dict: +rows = load_features_jsonl(cfg.features_jsonl) + +```if cfg.heuristic_only: initial_assignments = [heuristic_decide(r) for r in rows]else: if not cfg.judge_base_url or not cfg.judge_model: raise ValueError("Judge mode requires --judge-base-url and --judge-model.") api_key = os.environ.get(cfg.judge_api_key_env, "") chunks = _chunk_rows(rows, cfg.chunk_size) all_judge_assignments = [] for chunk in chunks: reduced = [_reduced_row(r) for r in chunk] try: chunk_assignments = call_openai_compatible( base_url=cfg.judge_base_url, model=cfg.judge_model, api_key=api_key, payload_rows=reduced, budget_mb_extra=cfg.budget_mb_extra, ) merged = merge_assignments(chunk, chunk_assignments) all_judge_assignments.extend(merged) except Exception: all_judge_assignments.extend([heuristic_decide(r) for r in chunk]) initial_assignments = all_judge_assignmentsplan = enforce_budget(rows, initial_assignments, cfg.budget_mb_extra)return plan +if cfg.heuristic_only: initial_assignments = [heuristic_decide(r) for r in rows]else: if not cfg.judge_base_url or not cfg.judge_model: raise ValueError("Judge mode requires --judge-base-url and --judge-model.") api_key = os.environ.get(cfg.judge_api_key_env, "") chunks = _chunk_rows(rows, cfg.chunk_size) all_judge_assignments = [] for chunk in chunks: reduced = [_reduced_row(r) for r in chunk] try: chunk_assignments = call_openai_compatible( base_url=cfg.judge_base_url, model=cfg.judge_model, api_key=api_key, payload_rows=reduced, budget_mb_extra=cfg.budget_mb_extra, ) merged = merge_assignments(chunk, chunk_assignments) all_judge_assignments.extend(merged) except Exception: all_judge_assignments.extend([heuristic_decide(r) for r in chunk]) initial_assignments = all_judge_assignmentsplan = enforce_budget(rows, initial_assignments, cfg.budget_mb_extra)return plan +``` + +EOF + +cat > quant_judge/compiler.py <<'EOF' +from **future** import annotations + +import json +import re +from pathlib import Path + +from .actions import action_to_dynamic_override + +def compile_dynamic_from_plan(plan: dict) -> dict: +dynamic = {} +for item in plan["assignments"]: +name = item["name"] +action = item["action"] +escaped = re.escape(name) +if action == "skip": +dynamic[rf"-:^{escaped}$"] = {} +else: +dynamic[rf"+:^{escaped}$"] = action_to_dynamic_override(action) +return dynamic + +def save_plan_and_dynamic(plan: dict, plan_out: str, dynamic_out: str) -> None: +plan_path = Path(plan_out) +dyn_path = Path(dynamic_out) +plan_path.parent.mkdir(parents=True, exist_ok=True) +dyn_path.parent.mkdir(parents=True, exist_ok=True) + +```dynamic = compile_dynamic_from_plan(plan)plan_path.write_text(json.dumps(plan, indent=2), encoding="utf-8")dyn_path.write_text(json.dumps(dynamic, indent=2), encoding="utf-8") +dynamic = compile_dynamic_from_plan(plan)plan_path.write_text(json.dumps(plan, indent=2), encoding="utf-8")dyn_path.write_text(json.dumps(dynamic, indent=2), encoding="utf-8") +``` + +EOF + +cat > quant_judge/pipeline.py <<'EOF' +from **future** import annotations + +import argparse +import json +from pathlib import Path + +from datasets import load_dataset +from gptqmodel import GPTQModel, QuantizeConfig + +from .compiler import compile_dynamic_from_plan, save_plan_and_dynamic +from .judge_client import JudgeConfig, make_plan +from .probe import ProbeConfig, run_probe + +def load_texts_for_quant(dataset_name: str, dataset_config: str | None, split: str, text_field: str, num_texts: int) -> list[str]: +ds = load_dataset(dataset_name, dataset_config, split=split) +texts = [] +for row in ds: +text = row.get(text_field, "") +if isinstance(text, str) and text.strip(): +texts.append(text.strip()) +if len(texts) >= num_texts: +break +if not texts: +raise ValueError("No texts found for quantization dataset.") +return texts + +def do_probe(args): +cfg = ProbeConfig( +model_id=args.model_id, +out_dir=args.out_dir, +dataset_name=args.dataset_name, +dataset_config=args.dataset_config, +split=args.split, +text_field=args.text_field, +num_texts=args.num_texts, +max_length=args.max_length, +max_rows_per_module=args.max_rows_per_module, +include_modules=[x.strip() for x in args.include_modules.split(",") if x.strip()], +) +result = run_probe(cfg) +print(json.dumps(result, indent=2)) + +def do_plan(args): +cfg = JudgeConfig( +features_jsonl=args.features_jsonl, +plan_out=args.plan_out, +dynamic_out=args.dynamic_out, +budget_mb_extra=args.budget_mb_extra, +heuristic_only=args.heuristic_only, +judge_base_url=args.judge_base_url, +judge_model=args.judge_model, +judge_api_key_env=args.judge_api_key_env, +chunk_size=args.chunk_size, +) +plan = make_plan(cfg) +save_plan_and_dynamic(plan, cfg.plan_out, cfg.dynamic_out) +print(json.dumps({ +"plan_out": cfg.plan_out, +"dynamic_out": cfg.dynamic_out, +"used_mb_extra": plan["used_mb_extra"], +"downgraded_count": len(plan["downgraded"]), +"assignments": len(plan["assignments"]), +}, indent=2)) + +def do_quantize(args): +plan = json.loads(Path(args.plan_json).read_text(encoding="utf-8")) +dynamic = compile_dynamic_from_plan(plan) + +```calibration_dataset = load_texts_for_quant( dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.num_texts,)qcfg = QuantizeConfig.gptq_pro()qcfg.bits = args.base_bitsqcfg.group_size = args.base_group_sizeqcfg.dynamic = dynamicmodel = GPTQModel.load(args.model_id, qcfg)model.quantize(calibration_dataset, batch_size=args.batch_size)Path(args.out_dir).mkdir(parents=True, exist_ok=True)model.save(args.out_dir)meta = { "model_id": args.model_id, "out_dir": args.out_dir, "base_bits": args.base_bits, "base_group_size": args.base_group_size, "batch_size": args.batch_size, "num_texts": args.num_texts, "dynamic_rules": len(dynamic), "plan_json": args.plan_json,}Path(args.out_dir, "quant_judge_meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")print(json.dumps(meta, indent=2)) +calibration_dataset = load_texts_for_quant( dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.num_texts,)qcfg = QuantizeConfig.gptq_pro()qcfg.bits = args.base_bitsqcfg.group_size = args.base_group_sizeqcfg.dynamic = dynamicmodel = GPTQModel.load(args.model_id, qcfg)model.quantize(calibration_dataset, batch_size=args.batch_size)Path(args.out_dir).mkdir(parents=True, exist_ok=True)model.save(args.out_dir)meta = { "model_id": args.model_id, "out_dir": args.out_dir, "base_bits": args.base_bits, "base_group_size": args.base_group_size, "batch_size": args.batch_size, "num_texts": args.num_texts, "dynamic_rules": len(dynamic), "plan_json": args.plan_json,}Path(args.out_dir, "quant_judge_meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")print(json.dumps(meta, indent=2)) +``` + +def do_all(args): +work_dir = Path(args.work_dir) +probe_dir = work_dir / "probe" +probe_dir.mkdir(parents=True, exist_ok=True) + +```probe_args = argparse.Namespace( model_id=args.model_id, out_dir=str(probe_dir), dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.probe_num_texts, max_length=args.max_length, max_rows_per_module=args.max_rows_per_module, include_modules=args.include_modules,)do_probe(probe_args)features_jsonl = str(probe_dir / "probe_features.jsonl")plan_out = str(work_dir / "plan.json")dynamic_out = str(work_dir / "dynamic.json")plan_args = argparse.Namespace( features_jsonl=features_jsonl, plan_out=plan_out, dynamic_out=dynamic_out, budget_mb_extra=args.budget_mb_extra, heuristic_only=args.heuristic_only, judge_base_url=args.judge_base_url, judge_model=args.judge_model, judge_api_key_env=args.judge_api_key_env, chunk_size=args.chunk_size,)do_plan(plan_args)quant_args = argparse.Namespace( model_id=args.model_id, plan_json=plan_out, out_dir=str(work_dir / "quantized"), dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.quant_num_texts, batch_size=args.batch_size, base_bits=args.base_bits, base_group_size=args.base_group_size,)do_quantize(quant_args) +probe_args = argparse.Namespace( model_id=args.model_id, out_dir=str(probe_dir), dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.probe_num_texts, max_length=args.max_length, max_rows_per_module=args.max_rows_per_module, include_modules=args.include_modules,)do_probe(probe_args)features_jsonl = str(probe_dir / "probe_features.jsonl")plan_out = str(work_dir / "plan.json")dynamic_out = str(work_dir / "dynamic.json")plan_args = argparse.Namespace( features_jsonl=features_jsonl, plan_out=plan_out, dynamic_out=dynamic_out, budget_mb_extra=args.budget_mb_extra, heuristic_only=args.heuristic_only, judge_base_url=args.judge_base_url, judge_model=args.judge_model, judge_api_key_env=args.judge_api_key_env, chunk_size=args.chunk_size,)do_plan(plan_args)quant_args = argparse.Namespace( model_id=args.model_id, plan_json=plan_out, out_dir=str(work_dir / "quantized"), dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.quant_num_texts, batch_size=args.batch_size, base_bits=args.base_bits, base_group_size=args.base_group_size,)do_quantize(quant_args) +``` + +def build_parser(): +p = argparse.ArgumentParser(prog="quant_judge") +sp = p.add_subparsers(dest="cmd", required=True) + +```p_probe = sp.add_parser("probe")p_probe.add_argument("--model-id", required=True)p_probe.add_argument("--out-dir", required=True)p_probe.add_argument("--dataset-name", default="wikitext")p_probe.add_argument("--dataset-config", default="wikitext-2-raw-v1")p_probe.add_argument("--split", default="train")p_probe.add_argument("--text-field", default="text")p_probe.add_argument("--num-texts", type=int, default=64)p_probe.add_argument("--max-length", type=int, default=256)p_probe.add_argument("--max-rows-per-module", type=int, default=128)p_probe.add_argument("--include-modules", default="o_proj,down_proj,gate_proj,lm_head")p_probe.set_defaults(func=do_probe)p_plan = sp.add_parser("plan")p_plan.add_argument("--features-jsonl", required=True)p_plan.add_argument("--plan-out", required=True)p_plan.add_argument("--dynamic-out", required=True)p_plan.add_argument("--budget-mb-extra", type=float, default=220.0)p_plan.add_argument("--heuristic-only", action="store_true")p_plan.add_argument("--judge-base-url", default=None)p_plan.add_argument("--judge-model", default=None)p_plan.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")p_plan.add_argument("--chunk-size", type=int, default=24)p_plan.set_defaults(func=do_plan)p_quant = sp.add_parser("quantize")p_quant.add_argument("--model-id", required=True)p_quant.add_argument("--plan-json", required=True)p_quant.add_argument("--out-dir", required=True)p_quant.add_argument("--dataset-name", default="wikitext")p_quant.add_argument("--dataset-config", default="wikitext-2-raw-v1")p_quant.add_argument("--split", default="train")p_quant.add_argument("--text-field", default="text")p_quant.add_argument("--num-texts", type=int, default=128)p_quant.add_argument("--batch-size", type=int, default=4)p_quant.add_argument("--base-bits", type=int, default=4)p_quant.add_argument("--base-group-size", type=int, default=128)p_quant.set_defaults(func=do_quantize)p_all = sp.add_parser("all")p_all.add_argument("--model-id", required=True)p_all.add_argument("--work-dir", required=True)p_all.add_argument("--dataset-name", default="wikitext")p_all.add_argument("--dataset-config", default="wikitext-2-raw-v1")p_all.add_argument("--split", default="train")p_all.add_argument("--text-field", default="text")p_all.add_argument("--probe-num-texts", type=int, default=64)p_all.add_argument("--quant-num-texts", type=int, default=128)p_all.add_argument("--max-length", type=int, default=256)p_all.add_argument("--max-rows-per-module", type=int, default=128)p_all.add_argument("--include-modules", default="o_proj,down_proj,gate_proj,lm_head")p_all.add_argument("--budget-mb-extra", type=float, default=220.0)p_all.add_argument("--heuristic-only", action="store_true")p_all.add_argument("--judge-base-url", default=None)p_all.add_argument("--judge-model", default=None)p_all.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")p_all.add_argument("--chunk-size", type=int, default=24)p_all.add_argument("--batch-size", type=int, default=4)p_all.add_argument("--base-bits", type=int, default=4)p_all.add_argument("--base-group-size", type=int, default=128)p_all.set_defaults(func=do_all)return p +p_probe = sp.add_parser("probe")p_probe.add_argument("--model-id", required=True)p_probe.add_argument("--out-dir", required=True)p_probe.add_argument("--dataset-name", default="wikitext")p_probe.add_argument("--dataset-config", default="wikitext-2-raw-v1")p_probe.add_argument("--split", default="train")p_probe.add_argument("--text-field", default="text")p_probe.add_argument("--num-texts", type=int, default=64)p_probe.add_argument("--max-length", type=int, default=256)p_probe.add_argument("--max-rows-per-module", type=int, default=128)p_probe.add_argument("--include-modules", default="o_proj,down_proj,gate_proj,lm_head")p_probe.set_defaults(func=do_probe)p_plan = sp.add_parser("plan")p_plan.add_argument("--features-jsonl", required=True)p_plan.add_argument("--plan-out", required=True)p_plan.add_argument("--dynamic-out", required=True)p_plan.add_argument("--budget-mb-extra", type=float, default=220.0)p_plan.add_argument("--heuristic-only", action="store_true")p_plan.add_argument("--judge-base-url", default=None)p_plan.add_argument("--judge-model", default=None)p_plan.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")p_plan.add_argument("--chunk-size", type=int, default=24)p_plan.set_defaults(func=do_plan)p_quant = sp.add_parser("quantize")p_quant.add_argument("--model-id", required=True)p_quant.add_argument("--plan-json", required=True)p_quant.add_argument("--out-dir", required=True)p_quant.add_argument("--dataset-name", default="wikitext")p_quant.add_argument("--dataset-config", default="wikitext-2-raw-v1")p_quant.add_argument("--split", default="train")p_quant.add_argument("--text-field", default="text")p_quant.add_argument("--num-texts", type=int, default=128)p_quant.add_argument("--batch-size", type=int, default=4)p_quant.add_argument("--base-bits", type=int, default=4)p_quant.add_argument("--base-group-size", type=int, default=128)p_quant.set_defaults(func=do_quantize)p_all = sp.add_parser("all")p_all.add_argument("--model-id", required=True)p_all.add_argument("--work-dir", required=True)p_all.add_argument("--dataset-name", default="wikitext")p_all.add_argument("--dataset-config", default="wikitext-2-raw-v1")p_all.add_argument("--split", default="train")p_all.add_argument("--text-field", default="text")p_all.add_argument("--probe-num-texts", type=int, default=64)p_all.add_argument("--quant-num-texts", type=int, default=128)p_all.add_argument("--max-length", type=int, default=256)p_all.add_argument("--max-rows-per-module", type=int, default=128)p_all.add_argument("--include-modules", default="o_proj,down_proj,gate_proj,lm_head")p_all.add_argument("--budget-mb-extra", type=float, default=220.0)p_all.add_argument("--heuristic-only", action="store_true")p_all.add_argument("--judge-base-url", default=None)p_all.add_argument("--judge-model", default=None)p_all.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")p_all.add_argument("--chunk-size", type=int, default=24)p_all.add_argument("--batch-size", type=int, default=4)p_all.add_argument("--base-bits", type=int, default=4)p_all.add_argument("--base-group-size", type=int, default=128)p_all.set_defaults(func=do_all)return p +``` + +def main(): +parser = build_parser() +args = parser.parse_args() +args.func(args) + +if **name** == "**main**": +main() +EOF + +```Then install and run it:```bashpython -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch first, for example:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +Then install and run it:```bashpython -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch first, for example:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +``` + +Heuristic-only dry run: + +```Bashpython -m quant_judge.pipeline all \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/qwen35_4b_full \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head \ --budget-mb-extra 220 \ --heuristic-only +Bashpython -m quant_judge.pipeline all \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/qwen35_4b_full \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head \ --budget-mb-extra 220 \ --heuristic-only +``` + +External judge run: + +```Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline all \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/qwen35_4b_full \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head \ --budget-mb-extra 220 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct \ --judge-api-key-env JUDGE_API_KEY +Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline all \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/qwen35_4b_full \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --max-length 256 \ --max-rows-per-module 128 \ --include-modules o_proj,down_proj,gate_proj,lm_head \ --budget-mb-extra 220 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct \ --judge-api-key-env JUDGE_API_KEY +``` + +What you should inspect first: + +- `./runs/qwen35_4b_full/probe/probe_features.jsonl` +- `./runs/qwen35_4b_full/plan.json` +- `./runs/qwen35_4b_full/dynamic.json` + +The important part is that `dynamic.json` is what gets applied to GPTQModel’s mixed per-module quantization rules, including exact-match `+:` overrides and `-:` skips. [GitHub](https://github.com/ModelCloud/GPTQModel) + +Best first upgrade after this works: + +- add `q_proj,v_proj,up_proj` +- keep `k_proj` conservative +- run your A/B harness on: + +plain base `gptq_pro()` + +heuristic plan + +judge plan +- plain base `gptq_pro()` +- heuristic plan +- judge plan + +Paste `plan.json` and I’ll tell you whether the judge is being smart or just expensive. + +--- + +### User: +This architecture is a significant leap over standard "brute-force" quantization. You are essentially proposing a **Software-Defined Quantization (SDQ)** controller. + +By moving from a global policy (e.g., "everything is 4-bit") to a **telemetry-driven per-layer policy**, you solve the "Pareto Frontier" problem of quantization: getting the best possible accuracy for a specific memory budget. + +Here is the concrete implementation scaffold to build exactly what you described for **Qwen 3.5 4B**. + +--- + +### Phase 1: The Probe (Feature Extraction) +You need a script that hooks into `GPTQModel` or `AutoGPTQ` to run "shadow quantization" trials. + +``` +python +import torch +from gptqmodel import GPTQModel, QuantizeConfig +from copy import deepcopy + +def probe_layer_sensitivity(model, calibration_data): + features = [] + # Identify all linear layers that aren't the embedding + layers = [n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)] + + for layer_name in layers: + # 1. Calculate Weight Kurtosis (Outlier Indicator) + weights = model.get_submodule(layer_name).weight.data.float() + kurtosis = ((weights - weights.mean())**4).mean() / (weights.var()**2) + + # 2. Run mini-quantization trials (MSE check) + # We compare original FP16 output vs quantized W4 output on 1 batch + input_feat = [] + def hook(m, i, o): input_feat.append(i[0]) + handle = model.get_submodule(layer_name).register_forward_hook(hook) + + # Run one calibration batch + model(calibration_data[0]) + handle.remove() + + x = input_feat[0] + original_output = model.get_submodule(layer_name)(x) + + # Simple MSE Proxy: Quantize weights manually to 4-bit and measure error + # (This is a 'fast' proxy for actual GPTQ error) + w_min, w_max = weights.min(), weights.max() + # ... (Insert standard symmetric quant math here) ... + q_weights = torch.quantize_per_tensor(weights, scale=..., zero_point=0, dtype=torch.qint8) + + mse_error = torch.nn.functional.mse_norm(original_output, reconstructed_output) + + features.append({ + "layer": layer_name, + "kurtosis": kurtosis.item(), + "mse_proxy": mse_error.item(), + "params": weights.numel() + }) + return features +``` + +### Phase 2: The Judge (The Policy Agent) +The key here is **constrained output**. You don't want the LLM to write code; you want it to act as a **Router**. + +**System Prompt for the Judge:** +> "You are a Quantization Strategy Engine. You will receive a JSON list of model layers and their sensitivity metrics. +> +> **Your Constraints:** +> 1. Total 'Extra Memory' budget: 150MB. +> 2. `w8_g128` costs 2x memory of `w4_g128`. +> 3. `skip` costs 4x memory of `w4_g128`. +> +> **Decision Logic:** +> - If `kurtosis` > 15 OR `mse_proxy` > 0.05: Prioritize `w8_g128`. +> - If `layer` is `lm_head` or the first/last 2 blocks: Prioritize `skip` or `w8`. +> - Otherwise: Default to `w4_g128`. +> +> Output ONLY a JSON map of `{ "layer_regex": "action" }`." + +### Phase 3: The Compiler (Mapping to GPTQModel) +`GPTQModel`’s `dynamic` parameter is the "clean way" you mentioned. It allows for regex-based overrides. + +``` +python +from gptqmodel import GPTQModel, QuantizeConfig + +# This is what your Judge produced +judge_decision = { + r".*lm_head": "skip", + r".*layers\.0\..*": "w8_g128", + r".*down_proj": "w4_g64", +} + +def build_dynamic_config(decisions): + dynamic_map = {} + for regex, action in decisions.items(): + if action == "skip": + dynamic_map[f"-:{regex}"] = {} # Negative match to skip + elif action == "w8_g128": + dynamic_map[f"+:{regex}"] = {"bits": 8, "group_size": 128} + elif action == "w4_g64": + dynamic_map[f"+:{regex}"] = {"bits": 4, "group_size": 64} + return dynamic_map + +quant_config = QuantizeConfig( + bits=4, + group_size=128, + dynamic=build_dynamic_config(judge_decision) +) + +# Apply and Quantize +model = GPTQModel.load(model_id_or_path) +model.quantize(calibration_dataset, quant_config) +model.save_quantized("qwen-3.5-4b-judge-enhanced") +``` + +--- + +### Why this beats current GPTQ implementations: + +1. **The "Vibe Check" is Measurable:** By feeding `kurtosis` and `MSE` to the judge, the LLM isn't guessing which layers are important; it is performing a multi-objective optimization (Quality vs. VRAM) that is difficult to express in a simple `if/else` script. +2. **Solves the "Down_Proj" Fragility:** In Qwen and Llama models, the `down_proj` in the MLP block and the `o_proj` in Attention are notoriously sensitive. A standard GPTQ run destroys them. Your Judge will see the high MSE in the probe and automatically up-tier them to `w4_g64` or `w8`. +3. **Future-Proofing for Qwen 3.5:** Since Qwen 3.5 4B is a small model, every bit counts. Using 8-bit for just the most "painful" 5% of layers will likely keep perplexity within 0.1 of the FP16 baseline while still fitting on a 24GB consumer card easily. + +### The "Pro" Upgrade: The Feedback Loop +If you want to go beyond "v1", implement a **Verify Pass**: +1. Judge proposes config. +2. Quantize. +3. Run `lm-eval` (HellaSwag or ARC-Challenge). +4. If score drops > 1%, feed the scores back to the Judge: *"The previous plan failed on reasoning tasks. Tier up the 'gate_proj' layers in the middle blocks."* + +This turns GPTQ from a **compression tool** into an **autonomous optimization agent**. + +--- + +### ChatGPT: +This is **directionally very strong**, but I’d patch a few things before you treat it as production-ready. + +The biggest correction is API shape: in GPTQModel’s current documented flow, you build `QuantizeConfig`, pass it into `GPTQModel.load(model_id, quant_config)`, then call `model.quantize(calibration_dataset, batch_size=...)`, then `model.save(...)`. The per-module override mechanism you want is real: `QuantizeConfig.dynamic` supports regex matches, positive `+:` overrides, and negative `-:` matches to skip modules; the override surface includes `bits`, `group_size`, `sym`, `desc_act`, `mse`, and `pack_dtype`. GPTQModel also explicitly recommends validating post-quant quality with `lm-eval` or `evalplus` rather than relying on perplexity alone. [GitHub+1](https://github.com/ModelCloud/GPTQModel) + +The second correction is your probe math. `torch.quantize_per_tensor(... dtype=torch.qint8)` is not a real 4-bit GPTQ proxy, and `torch.nn.functional.mse_norm` is not a valid PyTorch function. For a fast probe, use a **custom group-wise fake quantizer** on the weight matrix, then compare `F.linear(x, Wq, bias)` against the FP output with `torch.nn.functional.mse_loss`. That gives you a stable screening signal without pretending you ran full GPTQ per layer. [GitHub](https://github.com/ModelCloud/GPTQModel) + +Third: `model(calibration_data[0])` is too loose. If your calibration set is text, you need tokenization first, and for `Qwen/Qwen3.5-4B-Base` specifically, the current model card describes it as a **causal language model with a vision encoder**, with 32 layers and tied LM output. So the safest initial target set is exactly what you suggested on the LM side only—`o_proj`, `down_proj`, `gate_proj`, and maybe `lm_head`—instead of trying to blanket every `torch.nn.Linear` in the full multimodal graph on day one. [Hugging Face](https://huggingface.co/Qwen/Qwen3.5-4B-Base) + +Fourth: your “judge” should not emit regexes directly as its primary object. Have it emit **exact module names + discrete actions** like `w4_g128`, `w4_g64`, `w8_g128`, `skip`, and then compile those exact names into anchored regexes yourself. That removes one whole class of bad outputs and keeps the LLM in the role of a constrained classifier, not a config author. GPTQModel’s dynamic rules are regex-based, so compiling exact names into `+:^...$` or `-:^...$` is the clean deterministic boundary. [GitHub](https://github.com/ModelCloud/GPTQModel) + +Fifth: I would tighten the control loop conceptually. The LLM is **not** the optimizer; the deterministic budget clamp is. The judge proposes upgrades, then a compiler enforces the extra-memory budget by downgrading the lowest-priority promotions back to base `w4_g128`. That makes the system auditable and reproducible. Without that clamp, the prompt can drift into “seems important” behavior. + +So the production shape I’d use is: + +- **Probe:** collect per-module features from real activations. +- **Judge:** classify each module into a small action set. +- **Compiler:** enforce budget and emit `dynamic`. +- **Quantize:** run GPTQModel with base `bits=4, group_size=128` plus overrides. +- **Verify:** score with `lm-eval`, not just proxy MSE. [GitHub+1](https://github.com/ModelCloud/GPTQModel) + +The feature set I’d keep for v1: + +- `w4_g128_mse` +- `w4_g64_mse` +- `w8_g128_mse` +- `weight_kurtosis` +- `act_outlier_rate` +- `params` +- `estimated_extra_mb_vs_base` +- `module_type` +- `block_index` + +And the hard rules I’d start with: + +- default `w4_g128` +- promote to `w4_g64` if `w4_g64_mse` materially beats `w4_g128_mse` +- promote to `w8_g128` if `w8_g128_mse` gives a large gain and the module is high-priority +- reserve `skip` mostly for `lm_head` or extreme outliers + +One more practical note: GPTQModel’s README explicitly notes `lm_head` quantization support, so testing `lm_head` as a special case is reasonable now rather than assuming it must always stay FP16. [GitHub](https://github.com/ModelCloud/GPTQModel) + +So my verdict on that writeup is: + +- **architecture idea:** good +- **API direction:** mostly correct +- **probe example code:** needs fixes +- **judge concept:** good if constrained +- **compiler to dynamic:** exactly the right path +- **feedback loop with lm-eval:** absolutely worth doing + +The next concrete step is to replace the toy probe snippet with a real **group-wise fake-quant probe** and make the judge output exact names instead of regexes. If you want, I’ll write the **v2 scaffold** with those fixes baked in and add a verifier pass that compares: + +1. plain `gptq_pro()` +2. heuristic mixed plan +3. judge mixed plan + +--- + +### User: +Your corrections are spot on. Moving from a fuzzy "LLM config author" to a **constrained classifier with a deterministic budget compiler** is exactly how you take this from a weekend experiment to a production-grade tool. + +Fixing the fake-quant proxy to be group-wise, fixing the PyTorch math, anchoring the regexes, and properly tokenizing the calibration data are all essential. + +Here is the **v2 Scaffold**. This represents a fully mature Software-Defined Quantization (SDQ) pipeline for GPTQModel. + +--- + +### 1. The Real Group-Wise Fake-Quant Probe +We replace the toy quantizer with a fast group-wise symmetric fake quantizer. We tokenize real text to get valid activations, and use `F.linear` and `F.mse_loss`. + +``` +python +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer +from gptqmodel import GPTQModel + +def fake_quant_groupwise(weight, bits, group_size): + """Fast symmetric group-wise fake quantization proxy.""" + org_shape = weight.shape + # Pad if necessary for group_size alignment + pad_len = (group_size - (weight.numel() % group_size)) % group_size + w_flat = weight.flatten() + if pad_len > 0: + w_flat = torch.cat([w_flat, torch.zeros(pad_len, device=weight.device)]) + + w_grouped = w_flat.view(-1, group_size) + q_max = (1 << (bits - 1)) - 1 + + # Calculate scales per group + max_val = w_grouped.abs().max(dim=1, keepdim=True)[0].clamp(min=1e-5) + scales = max_val / q_max + + # Quantize and Dequantize + w_q = torch.round(w_grouped / scales).clamp(-q_max, q_max) + w_dq = w_q * scales + + w_dq_flat = w_dq.flatten()[:weight.numel()] + return w_dq_flat.view(org_shape) + +def probe_model(model_id, calibration_texts): + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = GPTQModel.load(model_id, device_map="auto") # Load in FP16 for probing + + # Tokenize input + inputs = tokenizer(calibration_texts, return_tensors="pt", padding=True, truncation=True, max_length=512) + inputs = {k: v.to(model.device) for k, v in inputs.items()} + + features =[] + target_types = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head") + + for name, module in model.named_modules(): + if not isinstance(module, torch.nn.Linear) or not any(t in name for t in target_types): + continue + + weight = module.weight.data + + # 1. Kurtosis + w_mean, w_var = weight.mean(), weight.var() + kurtosis = ((weight - w_mean)**4).mean() / (w_var**2) + + # 2. Get Activation + act = [] + def hook(m, i, o): act.append(i[0].detach()) + handle = module.register_forward_hook(hook) + with torch.no_grad(): + model(**inputs) + handle.remove() + x = act[0] + + # 3. Fast MSE Proxies + out_fp = F.linear(x, weight) + + w_w4g128 = fake_quant_groupwise(weight, bits=4, group_size=128) + w4g128_mse = F.mse_loss(F.linear(x, w_w4g128), out_fp).item() + + w_w4g64 = fake_quant_groupwise(weight, bits=4, group_size=64) + w4g64_mse = F.mse_loss(F.linear(x, w_w4g64), out_fp).item() + + w_w8g128 = fake_quant_groupwise(weight, bits=8, group_size=128) + w8g128_mse = F.mse_loss(F.linear(x, w_w8g128), out_fp).item() + + params = weight.numel() + + features.append({ + "name": name, + "params": params, + "kurtosis": kurtosis.item(), + "w4g128_mse": w4g128_mse, + "w4g64_mse": w4g64_mse, + "w8g128_mse": w8g128_mse, + # Est. extra memory vs base W4 + "w4g64_extra_mb": (params * 4 / 8) / 1024**2 * 0.1, # roughly scale up for smaller group overhead + "w8g128_extra_mb": (params * 4 / 8) / 1024**2, # roughly 4 extra bits per param + }) + + return features +``` + +### 2. The Deterministic Compiler (The Budget Clamp) +The LLM Judge receives the JSON from the probe and returns a list of proposed exact module names and actions (`w4_g64`, `w8_g128`, `skip`). The Compiler enforces the global VRAM budget by ranking the requested upgrades by their delta MSE (the quality saved by upgrading), and drops the rest back to the `w4_g128` baseline. + +``` +python +import re + +def compile_dynamic_config(judge_proposals, probe_features, extra_budget_mb=150.0): + """ + judge_proposals: dict of { "exact.module.name": "action" } + probe_features: dict of features from probe keyed by module name + """ + # 1. Match proposals to costs and benefits + upgrades =[] + for name, action in judge_proposals.items(): + if action == "w4_g128": + continue # Base, no extra cost + + feat = next((f for f in probe_features if f["name"] == name), None) + if not feat: continue + + # Calculate how much error we save by doing this upgrade vs base + if action == "w8_g128": + benefit = feat["w4g128_mse"] - feat["w8g128_mse"] + cost_mb = feat["w8g128_extra_mb"] + elif action == "w4_g64": + benefit = feat["w4g128_mse"] - feat["w4g64_mse"] + cost_mb = feat["w4g64_extra_mb"] + elif action == "skip": + benefit = feat["w4g128_mse"] # Absolute max benefit + cost_mb = (feat["params"] * 12 / 8) / 1024**2 # FP16 vs W4 difference + + upgrades.append({ + "name": name, "action": action, + "cost": cost_mb, "benefit": benefit, + "roi": benefit / (cost_mb + 1e-9) + }) + + # 2. Sort by Return on Investment (ROI) and Clamp Budget + upgrades.sort(key=lambda x: x["roi"], reverse=True) + + approved_dynamic = {} + spent_mb = 0.0 + + for up in upgrades: + if spent_mb + up["cost"] <= extra_budget_mb: + spent_mb += up["cost"] + # Convert exact name to anchored regex for GPTQModel + anchored_regex = f"^{re.escape(up['name'])}$" + + if up["action"] == "skip": + approved_dynamic[f"-:{anchored_regex}"] = {} + elif up["action"] == "w8_g128": + approved_dynamic[f"+:{anchored_regex}"] = {"bits": 8, "group_size": 128} + elif up["action"] == "w4_g64": + approved_dynamic[f"+:{anchored_regex}"] = {"bits": 4, "group_size": 64} + + print(f"Compiler Approved: {len(approved_dynamic)} upgrades. Budget Spent: {spent_mb:.1f}/{extra_budget_mb} MB") + return approved_dynamic +``` + +### 3. Execution & Verifier Pass +We run three models: Base `gptq_pro`, a simple heuristic plan, and the compiled Judge plan. Then we evaluate via `lm-eval`. + +``` +python +from gptqmodel import QuantizeConfig, GPTQModel +import lm_eval + +def run_quantization_and_eval(model_id, calib_data, dynamic_overrides, run_name): + # Base is W4 G128. We inject the compiled overrides. + quant_config = QuantizeConfig( + bits=4, + group_size=128, + sym=True, + desc_act=True, + dynamic=dynamic_overrides + ) + + # 1. Load and Quantize + print(f"--- Running: {run_name} ---") + model = GPTQModel.load(model_id, quant_config=quant_config) + model.quantize(calib_data, batch_size=4) + + save_path = f"./qwen-3.5-4b-{run_name}" + model.save(save_path) + + # 2. Verify with lm-eval (e.g., ARC-Challenge or HellaSwag) + # Note: lm-eval handles the HF model wrapping + results = lm_eval.simple_evaluate( + model="hf", + model_args=f"pretrained={save_path}", + tasks=["arc_challenge", "hellaswag"], + batch_size="auto", + device="cuda:0" + ) + + print(f"Results for {run_name}:") + print(f"ARC-Challenge: {results['results']['arc_challenge']['acc_norm,none']:.4f}") + print(f"HellaSwag: {results['results']['hellaswag']['acc_norm,none']:.4f}") + return results + +# Example Execution Flow: +# 1. features = probe_model("Qwen/Qwen3.5-4B-Base", text_dataset) +# 2. judge_proposals = call_llm_judge(features) +# 3. final_dynamic = compile_dynamic_config(judge_proposals, features, extra_budget_mb=150) +# 4. run_quantization_and_eval("Qwen/Qwen3.5-4B-Base", text_dataset, final_dynamic, "judge_sdq") +``` + +### Why this is the ultimate "Correct" version: +1. **The LLM is no longer an implicit risk.** By ranking via mathematically grounded `ROI` (Return on Investment: `Delta MSE / Delta MB`), the LLM becomes a high-level strategic proposer. If it hallucinates or prioritizes poorly, the Compiler's budget clamp and `ROI` sorter neutralizes the damage. +2. **GPTQModel Native Constraints:** The exact regex anchoring (`^model\.layers\.0\.mlp\.down_proj$`) ensures `GPTQModel`'s `dynamic` routing matches exactly one layer without cascading side-effects. +3. **Hardware Truth:** `torch.quantize_per_tensor` was mathematically wrong for GPTQ. The custom `fake_quant_groupwise` replicates the actual local chunking (like `group_size=128`) that dictates how much precision a linear layer retains, meaning the `MSE` the probe spits out is a highly accurate proxy for the final GPTQ loss. + +--- + +### ChatGPT: +This v2 is **very close**. The architecture is right: probe → judge → deterministic compiler → GPTQModel `dynamic` → `lm-eval`. That matches GPTQModel’s current API and feature surface: `GPTQModel.load(model_id, quant_config)`, `model.quantize(...)`, `model.save(...)`, plus `QuantizeConfig.dynamic` for per-module overrides and `-:` negative matches to skip modules. `lm-eval` still supports the HF backend with local paths and `simple_evaluate()`. Qwen3.5-4B-Base is currently described as a 32-layer causal LM with a vision encoder and tied LM output, so being selective about which linears you target is still the right move. [Hugging Face+5GitHub+5GitHub+5](https://github.com/ModelCloud/GPTQModel/blob/main/README.md) + +The main thing I would still fix is the fake quantizer. In your snippet, flattening the whole weight tensor into one long vector is **not** a faithful proxy for GPTQ-style group quantization. For linear layers, grouping should happen **per output row along the input dimension**, not across rows, otherwise scales leak across unrelated channels and your MSE ranking gets distorted. Also, your FP reference path should include `bias`, and for probing I would load the model with regular `transformers.AutoModelForCausalLM`, not `GPTQModel.load(..., device_map="auto")`; GPTQModel’s documented load path is tied to a quant config workflow, while your probe is just a forward-feature extractor. [GitHub+1](https://github.com/ModelCloud/GPTQModel/blob/main/README.md) + +Use this fake quantizer instead: + +```Pythonimport torchimport torch.nn.functional as Fdef fake_quant_groupwise_linear_weight(weight: torch.Tensor, bits: int, group_size: int) -> torch.Tensor: """ Symmetric per-row, per-group fake quantization for Linear weights. weight shape: [out_features, in_features] Groups are taken along the input dimension, per output row. """ assert weight.dim() == 2 w = weight.detach().float() out_features, in_features = w.shape pad = (group_size - (in_features % group_size)) % group_size if pad: w = F.pad(w, (0, pad), value=0.0) maxq = (2 ** (bits - 1)) - 1 w = w.view(out_features, -1, group_size) # [out, n_groups, group] absmax = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) scale = absmax / maxq q = torch.round(w / scale).clamp(-maxq, maxq) dq = q * scale dq = dq.view(out_features, -1) if pad: dq = dq[:, :in_features] return dq.to(weight.device, dtype=weight.dtype) +Pythonimport torchimport torch.nn.functional as Fdef fake_quant_groupwise_linear_weight(weight: torch.Tensor, bits: int, group_size: int) -> torch.Tensor: """ Symmetric per-row, per-group fake quantization for Linear weights. weight shape: [out_features, in_features] Groups are taken along the input dimension, per output row. """ assert weight.dim() == 2 w = weight.detach().float() out_features, in_features = w.shape pad = (group_size - (in_features % group_size)) % group_size if pad: w = F.pad(w, (0, pad), value=0.0) maxq = (2 ** (bits - 1)) - 1 w = w.view(out_features, -1, group_size) # [out, n_groups, group] absmax = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) scale = absmax / maxq q = torch.round(w / scale).clamp(-maxq, maxq) dq = q * scale dq = dq.view(out_features, -1) if pad: dq = dq[:, :in_features] return dq.to(weight.device, dtype=weight.dtype) +``` + +And the probe path should look more like this: + +```Pythonfrom transformers import AutoTokenizer, AutoModelForCausalLMdef probe_module(module, x): weight = module.weight.detach() bias = None if module.bias is None else module.bias.detach() with torch.no_grad(): out_fp = F.linear(x, weight, bias) w4g128 = fake_quant_groupwise_linear_weight(weight, bits=4, group_size=128) out_w4g128 = F.linear(x, w4g128, bias) w4g128_mse = F.mse_loss(out_w4g128.float(), out_fp.float()).item() w4g64 = fake_quant_groupwise_linear_weight(weight, bits=4, group_size=64) out_w4g64 = F.linear(x, w4g64, bias) w4g64_mse = F.mse_loss(out_w4g64.float(), out_fp.float()).item() w8g128 = fake_quant_groupwise_linear_weight(weight, bits=8, group_size=128) out_w8g128 = F.linear(x, w8g128, bias) w8g128_mse = F.mse_loss(out_w8g128.float(), out_fp.float()).item() return w4g128_mse, w4g64_mse, w8g128_mse +Pythonfrom transformers import AutoTokenizer, AutoModelForCausalLMdef probe_module(module, x): weight = module.weight.detach() bias = None if module.bias is None else module.bias.detach() with torch.no_grad(): out_fp = F.linear(x, weight, bias) w4g128 = fake_quant_groupwise_linear_weight(weight, bits=4, group_size=128) out_w4g128 = F.linear(x, w4g128, bias) w4g128_mse = F.mse_loss(out_w4g128.float(), out_fp.float()).item() w4g64 = fake_quant_groupwise_linear_weight(weight, bits=4, group_size=64) out_w4g64 = F.linear(x, w4g64, bias) w4g64_mse = F.mse_loss(out_w4g64.float(), out_fp.float()).item() w8g128 = fake_quant_groupwise_linear_weight(weight, bits=8, group_size=128) out_w8g128 = F.linear(x, w8g128, bias) w8g128_mse = F.mse_loss(out_w8g128.float(), out_fp.float()).item() return w4g128_mse, w4g64_mse, w8g128_mse +``` + +Your compiler idea is good. I would add two hard guards: ignore upgrades whose `benefit <= 0`, and require a minimum ROI threshold so you do not spend budget on noise. Also keep the judge output as **exact names + actions**, then compile those exact names into anchored regexes yourself, which is exactly what GPTQModel’s dynamic routing expects. [GitHub](https://github.com/ModelCloud/GPTQModel/blob/main/README.md) + +This compiler version is safer: + +```Pythonimport redef compile_dynamic_config(judge_proposals, probe_features, extra_budget_mb=150.0, min_roi=1e-6): feat_by_name = {f["name"]: f for f in probe_features} upgrades = [] for name, action in judge_proposals.items(): if action == "w4_g128": continue feat = feat_by_name.get(name) if feat is None: continue if action == "w8_g128": benefit = feat["w4g128_mse"] - feat["w8g128_mse"] cost_mb = feat["w8g128_extra_mb"] elif action == "w4_g64": benefit = feat["w4g128_mse"] - feat["w4g64_mse"] cost_mb = feat["w4g64_extra_mb"] elif action == "skip": benefit = feat["w4g128_mse"] cost_mb = (feat["params"] * (2.0 - 0.5)) / (1024 ** 2) # fp16 bytes - w4 bytes else: continue if benefit <= 0 or cost_mb <= 0: continue roi = benefit / cost_mb if roi < min_roi: continue upgrades.append({ "name": name, "action": action, "benefit": benefit, "cost_mb": cost_mb, "roi": roi, }) upgrades.sort(key=lambda x: x["roi"], reverse=True) dynamic = {} spent_mb = 0.0 for up in upgrades: if spent_mb + up["cost_mb"] > extra_budget_mb: continue spent_mb += up["cost_mb"] anchored = f"^{re.escape(up['name'])}$" if up["action"] == "skip": dynamic[f"-:{anchored}"] = {} elif up["action"] == "w8_g128": dynamic[f"+:{anchored}"] = {"bits": 8, "group_size": 128} elif up["action"] == "w4_g64": dynamic[f"+:{anchored}"] = {"bits": 4, "group_size": 64} return dynamic, spent_mb +Pythonimport redef compile_dynamic_config(judge_proposals, probe_features, extra_budget_mb=150.0, min_roi=1e-6): feat_by_name = {f["name"]: f for f in probe_features} upgrades = [] for name, action in judge_proposals.items(): if action == "w4_g128": continue feat = feat_by_name.get(name) if feat is None: continue if action == "w8_g128": benefit = feat["w4g128_mse"] - feat["w8g128_mse"] cost_mb = feat["w8g128_extra_mb"] elif action == "w4_g64": benefit = feat["w4g128_mse"] - feat["w4g64_mse"] cost_mb = feat["w4g64_extra_mb"] elif action == "skip": benefit = feat["w4g128_mse"] cost_mb = (feat["params"] * (2.0 - 0.5)) / (1024 ** 2) # fp16 bytes - w4 bytes else: continue if benefit <= 0 or cost_mb <= 0: continue roi = benefit / cost_mb if roi < min_roi: continue upgrades.append({ "name": name, "action": action, "benefit": benefit, "cost_mb": cost_mb, "roi": roi, }) upgrades.sort(key=lambda x: x["roi"], reverse=True) dynamic = {} spent_mb = 0.0 for up in upgrades: if spent_mb + up["cost_mb"] > extra_budget_mb: continue spent_mb += up["cost_mb"] anchored = f"^{re.escape(up['name'])}$" if up["action"] == "skip": dynamic[f"-:{anchored}"] = {} elif up["action"] == "w8_g128": dynamic[f"+:{anchored}"] = {"bits": 8, "group_size": 128} elif up["action"] == "w4_g64": dynamic[f"+:{anchored}"] = {"bits": 4, "group_size": 64} return dynamic, spent_mb +``` + +Your verifier pass is the right idea too. `lm-eval.simple_evaluate()` is a valid current entry point, and the HF backend supports local paths via `pretrained=/path/to/model`. I would still save all three runs separately—baseline `gptq_pro()`, heuristic mixed plan, judge mixed plan—and compare them on the same task set. [GitHub+1](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/python-api.md?utm_source=chatgpt.com) + +For the quantization run itself, I’d tighten it to this shape: + +```Pythonfrom gptqmodel import GPTQModel, QuantizeConfigdef run_quantization(model_id, calib_data, dynamic_overrides, save_path): qcfg = QuantizeConfig(bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_overrides) model = GPTQModel.load(model_id, qcfg) model.quantize(calib_data, batch_size=4) model.save(save_path) +Pythonfrom gptqmodel import GPTQModel, QuantizeConfigdef run_quantization(model_id, calib_data, dynamic_overrides, save_path): qcfg = QuantizeConfig(bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_overrides) model = GPTQModel.load(model_id, qcfg) model.quantize(calib_data, batch_size=4) model.save(save_path) +``` + +My verdict: **the control architecture is now correct**, and the only real remaining gap is making the probe mathematically faithful enough that the judge is ranking real sensitivity instead of artifacts. Once you swap in per-row group quantization and keep exact-name compilation, this becomes a serious SDQ controller rather than a prompt-driven experiment. + +--- + +### User: +You absolutely nailed the mathematical flaw in the v2 fake quantizer. Flattening the entire weight tensor destroys the structural reality of GPTQ, which operates on **local chunks along the input dimension (in_features)** per output channel. If the probe doesn't respect that grouping geometry, the MSE proxy is practically noise, and the LLM judge is optimizing based on hallucinations. + +Loading the FP model for probing via standard `transformers` instead of `GPTQModel.load` is also the cleanest boundary: let Transformers handle the pure forward pass for feature extraction, and strictly reserve GPTQModel for the actual quantization execution loop. + +Here is the **v3 Production-Ready Architecture**. This integrates your exact per-row grouping math, the strict ROI-driven budget compiler, and a full A/B/C testing suite comparing Baseline, Heuristic, and Judge-driven quantization. + +### The Complete SDQ Pipeline for Qwen3.5-4B + +#### 1. The Mathematically Faithful Probe +This uses pure PyTorch and Transformers to extract genuine sensitivity metrics. + +``` +python +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer, AutoModelForCausalLM + +def fake_quant_groupwise_linear_weight(weight: torch.Tensor, bits: int, group_size: int) -> torch.Tensor: + """Symmetric per-row, per-group fake quantization for Linear weights.""" + assert weight.dim() == 2 + w = weight.detach().float() + out_features, in_features = w.shape + + pad = (group_size - (in_features % group_size)) % group_size + if pad: + w = F.pad(w, (0, pad), value=0.0) + + maxq = (2 ** (bits - 1)) - 1 + w = w.view(out_features, -1, group_size) # [out, n_groups, group] + absmax = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) + scale = absmax / maxq + + q = torch.round(w / scale).clamp(-maxq, maxq) + dq = q * scale + dq = dq.view(out_features, -1) + + if pad: + dq = dq[:, :in_features] + + return dq.to(weight.device, dtype=weight.dtype) + +def extract_features(model_id, calib_texts): + print("Loading FP16 model for probing...") + tokenizer = AutoTokenizer.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", torch_dtype=torch.float16) + + inputs = tokenizer(calib_texts, return_tensors="pt", padding=True, truncation=True, max_length=512) + inputs = {k: v.to(model.device) for k, v in inputs.items()} + + target_types = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head") + features =[] + + for name, module in model.named_modules(): + if not isinstance(module, torch.nn.Linear) or not any(t in name for t in target_types): + continue + + # 1. Capture Activation + act =[] + def hook(m, i, o): act.append(i[0].detach()) + handle = module.register_forward_hook(hook) + + with torch.no_grad(): + model(**inputs) + handle.remove() + + x = act[0] + weight = module.weight.detach() + bias = None if module.bias is None else module.bias.detach() + + # 2. Kurtosis + w_float = weight.float() + kurtosis = ((w_float - w_float.mean())**4).mean() / (w_float.var()**2) + + # 3. Simulate Quantization Error + with torch.no_grad(): + out_fp = F.linear(x, weight, bias) + + w4g128 = fake_quant_groupwise_linear_weight(weight, bits=4, group_size=128) + out_w4g128 = F.linear(x, w4g128, bias) + w4g128_mse = F.mse_loss(out_w4g128.float(), out_fp.float()).item() + + w4g64 = fake_quant_groupwise_linear_weight(weight, bits=4, group_size=64) + out_w4g64 = F.linear(x, w4g64, bias) + w4g64_mse = F.mse_loss(out_w4g64.float(), out_fp.float()).item() + + w8g128 = fake_quant_groupwise_linear_weight(weight, bits=8, group_size=128) + out_w8g128 = F.linear(x, w8g128, bias) + w8g128_mse = F.mse_loss(out_w8g128.float(), out_fp.float()).item() + + params = weight.numel() + + features.append({ + "name": name, + "params": params, + "kurtosis": kurtosis.item(), + "w4g128_mse": w4g128_mse, + "w4g64_mse": w4g64_mse, + "w8g128_mse": w8g128_mse, + "w4g64_extra_mb": (params * 0.5) / (1024**2) * 0.1, # Est. group overhead + "w8g128_extra_mb": (params * 0.5) / (1024**2), # 4 extra bits per param + }) + + del model # Free VRAM for GPTQ step + torch.cuda.empty_cache() + return features +``` + +#### 2. The Strict ROI Compiler +This enforces your mathematical guards (`benefit > 0`, `min_roi`) and locks the output into `GPTQModel`’s exact regex format. + +``` +python +import re + +def compile_dynamic_config(judge_proposals, probe_features, extra_budget_mb=150.0, min_roi=1e-6): + feat_by_name = {f["name"]: f for f in probe_features} + upgrades =[] + + for name, action in judge_proposals.items(): + if action == "w4_g128": continue + + feat = feat_by_name.get(name) + if feat is None: continue + + if action == "w8_g128": + benefit = feat["w4g128_mse"] - feat["w8g128_mse"] + cost_mb = feat["w8g128_extra_mb"] + elif action == "w4_g64": + benefit = feat["w4g128_mse"] - feat["w4g64_mse"] + cost_mb = feat["w4g64_extra_mb"] + elif action == "skip": + benefit = feat["w4g128_mse"] + cost_mb = (feat["params"] * 1.5) / (1024 ** 2) # fp16 bytes (2.0) - w4 bytes (0.5) + else: continue + + # Strict Guards + if benefit <= 0 or cost_mb <= 0: continue + + roi = benefit / cost_mb + if roi < min_roi: continue + + upgrades.append({ + "name": name, "action": action, + "cost_mb": cost_mb, "roi": roi + }) + + # Sort strictly by Return on Investment + upgrades.sort(key=lambda x: x["roi"], reverse=True) + + dynamic = {} + spent_mb = 0.0 + + for up in upgrades: + if spent_mb + up["cost_mb"] > extra_budget_mb: + continue + + spent_mb += up["cost_mb"] + # Exact module match via anchored regex + anchored = f"^{re.escape(up['name'])}$" + + if up["action"] == "skip": + dynamic[f"-:{anchored}"] = {} + elif up["action"] == "w8_g128": + dynamic[f"+:{anchored}"] = {"bits": 8, "group_size": 128} + elif up["action"] == "w4_g64": + dynamic[f"+:{anchored}"] = {"bits": 4, "group_size": 64} + + print(f"Compiler Allocated {spent_mb:.1f}MB / {extra_budget_mb}MB across {len(dynamic)} layers.") + return dynamic +``` + +#### 3. Execution & The Verifier Suite +We map everything back to `GPTQModel` and `lm-eval`. + +``` +python +from gptqmodel import GPTQModel, QuantizeConfig +import lm_eval + +def run_quantization(model_id, calib_data, dynamic_overrides, save_path): + print(f"\n--- Quantizing into {save_path} ---") + qcfg = QuantizeConfig( + bits=4, + group_size=128, + sym=True, + desc_act=True, + dynamic=dynamic_overrides + ) + + # Standard GPTQModel documented flow + model = GPTQModel.load(model_id, quant_config=qcfg) + model.quantize(calib_data, batch_size=4) + model.save(save_path) + + # Cleanup to avoid VRAM OOMs during eval + del model + torch.cuda.empty_cache() + +def evaluate_model(model_path): + print(f"\n--- Evaluating {model_path} ---") + results = lm_eval.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path}", + tasks=["arc_challenge", "hellaswag"], + batch_size="auto", + device="cuda:0" + ) + arc_acc = results['results']['arc_challenge']['acc_norm,none'] + hs_acc = results['results']['hellaswag']['acc_norm,none'] + return arc_acc, hs_acc + +# ========================================== +# Master Pipeline Execution +# ========================================== +model_id = "Qwen/Qwen3.5-4B-Base" +calib_dataset =["Calibration text sample 1...", "Calibration text sample 2..."] + +# 1. Feature Extraction +features = extract_features(model_id, calib_dataset) + +# 2. Get LLM Judge Proposals (Mocked here) +# -> You pass `features` to your LLM and parse the JSON response. +judge_proposals = { + "model.layers.31.mlp.down_proj": "w8_g128", + "lm_head": "skip", + "model.layers.0.self_attn.o_proj": "w4_g64" +} + +# 3. Compile Configs +judge_dynamic = compile_dynamic_config(judge_proposals, features, extra_budget_mb=150) + +# Build a baseline heuristic for A/B testing +heuristic_dynamic = { + r"-:^lm_head$": {}, + r"+:.*down_proj$": {"bits": 4, "group_size": 64} +} + +# 4. Run Suite +configs = { + "baseline_w4g128": None, + "heuristic_mixed": heuristic_dynamic, + "judge_sdq": judge_dynamic +} + +scores = {} +for name, dynamic_cfg in configs.items(): + save_path = f"./models/{model_id.split('/')[-1]}-{name}" + run_quantization(model_id, calib_dataset, dynamic_cfg, save_path) + arc, hs = evaluate_model(save_path) + scores[name] = {"ARC": arc, "HellaSwag": hs} + +print("\n=== Final SDQ Leaderboard ===") +for name, score in scores.items(): + print(f"{name:15} | ARC: {score['ARC']:.4f} | HS: {score['HellaSwag']:.4f}") +``` + +### Why this architecture is fundamentally correct: +1. **Mathematical Isolation:** Fake quantization correctly models the physical chunk boundaries of the eventual kernel, validating true sensitivity. +2. **Economic Guardrails:** The LLM's role is restricted to a policy proposal. The deterministic compiler converts that into an exact ROI curve and enforces the VRAM budget, ignoring "hallucinated urgency" if the math says the MSE delta isn't worth it. +3. **API Alignment:** It correctly uses `transformers` for pure FP16 probing, and leaves the actual tensor modification to `GPTQModel` using the official `QuantizeConfig.dynamic` surface. + +This is arguably the most sophisticated way to run local GPTQ compression currently available. provide all project in one answer with json format + +--- + +### ChatGPT: +```JSON{ "project_name": "sdq_quant_judge_qwen35_gptqmodel", "files": [ { "path": "README.md", "content": "# SDQ Quant Judge for GPTQModel\n\nA software-defined quantization scaffold for Qwen/Qwen3.5-4B-Base.\n\nCommands:\n- probe: extract per-module fake-quant sensitivity features from the FP model\n- plan: build a mixed-precision plan with either heuristic logic or an external OpenAI-compatible judge\n- quantize: quantize with GPTQModel using the compiled dynamic overrides\n- eval: evaluate a saved model with lm-eval\n- suite: run baseline vs heuristic vs judge and save a leaderboard\n\nExample:\n\n```bash\npython -m venv .venv\nsource .venv/bin/activate\npip install -r requirements.txt\n\npython -m quant_judge.pipeline probe \\\n --model-id Qwen/Qwen3.5-4B-Base \\\n --out-dir ./runs/probe \\\n --dataset-name wikitext \\\n --dataset-config wikitext-2-raw-v1 \\\n --split train \\\n --text-field text \\\n --num-texts 64 \\\n --max-length 256 \\\n --max-rows-per-module 128 \\\n --include-modules q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head\n\npython -m quant_judge.pipeline plan \\\n --features-jsonl ./runs/probe/probe_features.jsonl \\\n --plan-out ./runs/plan.json \\\n --dynamic-out ./runs/dynamic.json \\\n --budget-mb-extra 150 \\\n --heuristic-only\n\npython -m quant_judge.pipeline quantize \\\n --model-id Qwen/Qwen3.5-4B-Base \\\n --plan-json ./runs/plan.json \\\n --out-dir ./runs/quantized \\\n --dataset-name wikitext \\\n --dataset-config wikitext-2-raw-v1 \\\n --split train \\\n --text-field text \\\n --num-texts 128 \\\n --batch-size 4\n\npython -m quant_judge.pipeline eval \\\n --model-path ./runs/quantized \\\n --tasks arc_challenge,hellaswag \\\n --output-json ./runs/eval.json\n```\n" }, { "path": "requirements.txt", "content": "transformers\ndatasets\naccelerate\nrequests\nnumpy\ntqdm\npsutil\ngptqmodel\nlm-eval\n" }, { "path": "quant_judge/__init__.py", "content": "\"\"\"SDQ quantization scaffold.\"\"\"\n" }, { "path": "quant_judge/actions.py", "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\n\n\n@dataclass(frozen=True)\nclass ActionSpec:\n name: str\n bits: int | None\n group_size: int | None\n skip: bool\n\n\nACTION_SPECS: dict[str, ActionSpec] = {\n \"w4_g128\": ActionSpec(\"w4_g128\", bits=4, group_size=128, skip=False),\n \"w4_g64\": ActionSpec(\"w4_g64\", bits=4, group_size=64, skip=False),\n \"w8_g128\": ActionSpec(\"w8_g128\", bits=8, group_size=128, skip=False),\n \"skip\": ActionSpec(\"skip\", bits=None, group_size=None, skip=True),\n}\n\nBASE_ACTION = \"w4_g128\"\n\n\ndef allowed_actions() -> list[str]:\n return list(ACTION_SPECS.keys())\n\n\ndef bytes_per_param(action: str, fp_bytes: float = 2.0) -> float:\n if action == \"skip\":\n return fp_bytes\n spec = ACTION_SPECS[action]\n assert spec.bits is not None\n return spec.bits / 8.0\n\n\ndef extra_mb_vs_base(params: int, action: str) -> float:\n base = params * bytes_per_param(BASE_ACTION)\n now = params * bytes_per_param(action)\n return (now - base) / (1024 ** 2)\n\n\ndef action_to_override(action: str) -> dict:\n if action == \"skip\":\n return {}\n spec = ACTION_SPECS[action]\n return {\"bits\": spec.bits, \"group_size\": spec.group_size}\n" }, { "path": "quant_judge/probe.py", "content": "from __future__ import annotations\n\nimport json\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom .actions import extra_mb_vs_base\n\n\n@dataclass\nclass ProbeConfig:\n model_id: str\n out_dir: str\n dataset_name: str\n dataset_config: str | None\n split: str\n text_field: str\n num_texts: int\n max_length: int\n max_rows_per_module: int\n include_modules: list[str]\n trust_remote_code: bool = True\n\n\ndef pick_dtype() -> torch.dtype:\n if not torch.cuda.is_available():\n return torch.float32\n major, _minor = torch.cuda.get_device_capability(0)\n if major >= 8:\n return torch.bfloat16\n return torch.float16\n\n\ndef fake_quant_groupwise_linear_weight(weight: torch.Tensor, bits: int, group_size: int) -> torch.Tensor:\n assert weight.dim() == 2\n w = weight.detach().float()\n out_features, in_features = w.shape\n pad = (group_size - (in_features % group_size)) % group_size\n if pad:\n w = F.pad(w, (0, pad), value=0.0)\n maxq = (2 ** (bits - 1)) - 1\n w = w.view(out_features, -1, group_size)\n absmax = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)\n scale = absmax / maxq\n q = torch.round(w / scale).clamp(-maxq, maxq)\n dq = q * scale\n dq = dq.view(out_features, -1)\n if pad:\n dq = dq[:, :in_features]\n return dq.to(weight.device, dtype=weight.dtype)\n\n\ndef load_texts(dataset_name: str, dataset_config: str | None, split: str, text_field: str, num_texts: int) -> list[str]:\n ds = load_dataset(dataset_name, dataset_config, split=split)\n texts: list[str] = []\n for row in ds:\n value = row.get(text_field, \"\")\n if isinstance(value, str) and value.strip():\n texts.append(value.strip())\n if len(texts) >= num_texts:\n break\n if not texts:\n raise ValueError(\"No usable texts found.\")\n return texts\n\n\ndef module_selected(name: str, include_modules: list[str]) -> bool:\n return any(token in name for token in include_modules)\n\n\ndef iter_target_linears(model, include_modules: list[str]):\n for name, module in model.named_modules():\n if isinstance(module, torch.nn.Linear) and module_selected(name, include_modules):\n yield name, module\n\n\ndef flatten_rows(x: torch.Tensor) -> torch.Tensor:\n if x.dim() == 2:\n return x\n if x.dim() >= 3:\n return x.reshape(-1, x.shape[-1])\n return x.unsqueeze(0)\n\n\ndef sample_rows(x: torch.Tensor, max_rows: int) -> torch.Tensor:\n rows = flatten_rows(x)\n if rows.shape[0] <= max_rows:\n return rows\n idx = torch.randperm(rows.shape[0], device=rows.device)[:max_rows]\n return rows.index_select(0, idx)\n\n\ndef weight_kurtosis(weight: torch.Tensor) -> float:\n w = weight.detach().float().flatten()\n mu = w.mean()\n var = ((w - mu) ** 2).mean().clamp_min(1e-12)\n return float((((w - mu) ** 4).mean() / (var ** 2)).item())\n\n\ndef activation_outlier_rate(x: torch.Tensor, sigma: float = 6.0) -> float:\n xf = x.float()\n std = xf.std().clamp_min(1e-6)\n return float((xf.abs() > sigma * std).float().mean().item())\n\n\ndef block_index_from_name(name: str) -> int | None:\n parts = name.split(\".\")\n for i, p in enumerate(parts):\n if p == \"layers\" and i + 1 < len(parts):\n try:\n return int(parts[i + 1])\n except Exception:\n return None\n return None\n\n\ndef reduced_metrics_for_action(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None, bits: int, group_size: int) -> dict:\n wq = fake_quant_groupwise_linear_weight(weight, bits=bits, group_size=group_size)\n out_fp = F.linear(x, weight, bias)\n out_q = F.linear(x, wq, bias)\n mse = F.mse_loss(out_q.float(), out_fp.float()).item()\n cos = F.cosine_similarity(out_q.float(), out_fp.float(), dim=-1).mean().item()\n max_abs = (out_q.float() - out_fp.float()).abs().max().item()\n return {\"mse\": float(mse), \"cosine\": float(cos), \"max_abs\": float(max_abs)}\n\n\ndef run_probe(cfg: ProbeConfig) -> dict:\n out_dir = Path(cfg.out_dir)\n out_dir.mkdir(parents=True, exist_ok=True)\n\n tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, trust_remote_code=cfg.trust_remote_code)\n if tokenizer.pad_token_id is None:\n tokenizer.pad_token = tokenizer.eos_token\n\n model = AutoModelForCausalLM.from_pretrained(\n cfg.model_id,\n trust_remote_code=cfg.trust_remote_code,\n torch_dtype=pick_dtype(),\n device_map=\"auto\",\n low_cpu_mem_usage=True,\n )\n model.eval()\n\n texts = load_texts(cfg.dataset_name, cfg.dataset_config, cfg.split, cfg.text_field, cfg.num_texts)\n targets = list(iter_target_linears(model, cfg.include_modules))\n if not targets:\n raise ValueError(\"No target modules found.\")\n\n captured_inputs = {name: [] for name, _ in targets}\n remaining_rows = {name: cfg.max_rows_per_module for name, _ in targets}\n handles = []\n\n def make_hook(name: str):\n def hook(_module, inp, _out):\n rem = remaining_rows[name]\n if rem <= 0 or not inp or not isinstance(inp[0], torch.Tensor):\n return\n x = sample_rows(inp[0], rem)\n take = min(x.shape[0], rem)\n if take <= 0:\n return\n captured_inputs[name].append(x[:take].detach().cpu().to(torch.float16))\n remaining_rows[name] -= take\n return hook\n\n for name, module in targets:\n handles.append(module.register_forward_hook(make_hook(name)))\n\n with torch.no_grad():\n for text in tqdm(texts, desc=\"collect_activations\"):\n batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=cfg.max_length, padding=False)\n batch = {k: v.to(model.device) for k, v in batch.items()}\n _ = model(**batch)\n if all(v <= 0 for v in remaining_rows.values()):\n break\n\n for h in handles:\n h.remove()\n\n rows = []\n for name, module in tqdm(targets, desc=\"probe_modules\"):\n if not captured_inputs[name]:\n continue\n x = torch.cat(captured_inputs[name], dim=0).to(module.weight.device, dtype=module.weight.dtype)\n weight = module.weight.detach()\n bias = None if module.bias is None else module.bias.detach()\n\n m_w4g128 = reduced_metrics_for_action(x, weight, bias, bits=4, group_size=128)\n m_w4g64 = reduced_metrics_for_action(x, weight, bias, bits=4, group_size=64)\n m_w8g128 = reduced_metrics_for_action(x, weight, bias, bits=8, group_size=128)\n\n params = int(weight.numel())\n row = {\n \"name\": name,\n \"module_type\": name.split(\".\")[-1],\n \"block_index\": block_index_from_name(name),\n \"params\": params,\n \"captured_rows\": int(x.shape[0]),\n \"weight_kurtosis\": weight_kurtosis(weight),\n \"act_outlier_rate\": activation_outlier_rate(x),\n \"candidates\": {\n \"w4_g128\": {**m_w4g128, \"estimated_extra_mb_vs_base\": 0.0},\n \"w4_g64\": {**m_w4g64, \"estimated_extra_mb_vs_base\": max(0.0, extra_mb_vs_base(params, \"w4_g64\")) + (params * 0.05 / (1024 ** 2))},\n \"w8_g128\": {**m_w8g128, \"estimated_extra_mb_vs_base\": max(0.0, extra_mb_vs_base(params, \"w8_g128\"))},\n \"skip\": {\"mse\": 0.0, \"cosine\": 1.0, \"max_abs\": 0.0, \"estimated_extra_mb_vs_base\": max(0.0, extra_mb_vs_base(params, \"skip\"))},\n },\n \"proxy_gain_w4g64_vs_w4\": float(max(0.0, m_w4g128[\"mse\"] - m_w4g64[\"mse\"])),\n \"proxy_gain_w8_vs_w4\": float(max(0.0, m_w4g128[\"mse\"] - m_w8g128[\"mse\"])),\n }\n rows.append(row)\n\n features_jsonl = out_dir / \"probe_features.jsonl\"\n metadata_json = out_dir / \"probe_metadata.json\"\n summary_json = out_dir / \"probe_summary.json\"\n\n with open(features_jsonl, \"w\", encoding=\"utf-8\") as f:\n for row in rows:\n f.write(json.dumps(row, ensure_ascii=False) + \"\\n\")\n\n metadata = {\n \"model_id\": cfg.model_id,\n \"dataset_name\": cfg.dataset_name,\n \"dataset_config\": cfg.dataset_config,\n \"split\": cfg.split,\n \"text_field\": cfg.text_field,\n \"num_texts\": cfg.num_texts,\n \"max_length\": cfg.max_length,\n \"max_rows_per_module\": cfg.max_rows_per_module,\n \"include_modules\": cfg.include_modules,\n \"num_rows\": len(rows),\n }\n metadata_json.write_text(json.dumps(metadata, indent=2), encoding=\"utf-8\")\n summary_json.write_text(json.dumps({\"num_rows\": len(rows)}, indent=2), encoding=\"utf-8\")\n\n del model\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n\n return {\"features_jsonl\": str(features_jsonl), \"metadata_json\": str(metadata_json), \"summary_json\": str(summary_json), \"num_rows\": len(rows)}\n" }, { "path": "quant_judge/judge.py", "content": "from __future__ import annotations\n\nimport json\nimport os\nfrom dataclasses import dataclass\n\nimport requests\n\nfrom .actions import ACTION_SPECS, BASE_ACTION, allowed_actions\n\n\n@dataclass\nclass JudgeConfig:\n features_jsonl: str\n budget_mb_extra: float\n heuristic_only: bool = False\n judge_base_url: str | None = None\n judge_model: str | None = None\n judge_api_key_env: str = \"JUDGE_API_KEY\"\n chunk_size: int = 24\n\n\ndef load_rows(path: str) -> list[dict]:\n rows = []\n with open(path, \"r\", encoding=\"utf-8\") as f:\n for line in f:\n line = line.strip()\n if line:\n rows.append(json.loads(line))\n return rows\n\n\ndef heuristic_decision(row: dict) -> dict:\n name = row[\"name\"]\n w4 = row[\"candidates\"][\"w4_g128\"][\"mse\"]\n w4g64 = row[\"candidates\"][\"w4_g64\"][\"mse\"]\n w8 = row[\"candidates\"][\"w8_g128\"][\"mse\"]\n kurt = row[\"weight_kurtosis\"]\n outlier = row[\"act_outlier_rate\"]\n block = row.get(\"block_index\")\n\n w8_relief = (w4 - w8) / max(w4, 1e-9)\n g64_relief = (w4 - w4g64) / max(w4, 1e-9)\n\n score = 0.0\n score += min(1.0, w8_relief / 0.60) * 0.40\n score += min(1.0, g64_relief / 0.35) * 0.20\n score += min(1.0, kurt / 20.0) * 0.15\n score += min(1.0, outlier / 0.05) * 0.15\n if \"lm_head\" in name:\n score += 0.25\n if row[\"module_type\"] in (\"down_proj\", \"o_proj\", \"gate_proj\"):\n score += 0.10\n if block is not None and (block <= 1 or block >= 30):\n score += 0.10\n score = max(0.0, min(1.0, score))\n\n action = BASE_ACTION\n reason = \"default\"\n if \"lm_head\" in name and (w8_relief > 0.55 or kurt > 25.0):\n action = \"skip\"\n reason = \"lm_head_extreme\"\n elif score >= 0.75 or w8_relief > 0.50:\n action = \"w8_g128\"\n reason = \"high_sensitivity\"\n elif score >= 0.42 or g64_relief > 0.18:\n action = \"w4_g64\"\n reason = \"moderate_sensitivity\"\n\n return {\"name\": name, \"action\": action, \"priority\": round(score, 6), \"reason_code\": reason, \"source\": \"heuristic\"}\n\n\ndef reduced_row(row: dict) -> dict:\n return {\n \"name\": row[\"name\"],\n \"module_type\": row[\"module_type\"],\n \"block_index\": row.get(\"block_index\"),\n \"params\": row[\"params\"],\n \"weight_kurtosis\": row[\"weight_kurtosis\"],\n \"act_outlier_rate\": row[\"act_outlier_rate\"],\n \"w4_g128_mse\": row[\"candidates\"][\"w4_g128\"][\"mse\"],\n \"w4_g64_mse\": row[\"candidates\"][\"w4_g64\"][\"mse\"],\n \"w8_g128_mse\": row[\"candidates\"][\"w8_g128\"][\"mse\"],\n \"extra_mb_w4_g64\": row[\"candidates\"][\"w4_g64\"][\"estimated_extra_mb_vs_base\"],\n \"extra_mb_w8_g128\": row[\"candidates\"][\"w8_g128\"][\"estimated_extra_mb_vs_base\"],\n \"extra_mb_skip\": row[\"candidates\"][\"skip\"][\"estimated_extra_mb_vs_base\"],\n }\n\n\ndef strip_code_fence(text: str) -> str:\n t = text.strip()\n if t.startswith(\"```\"):\n t = t.split(\"\\n\", 1)[1]\n if t.endswith(\"```\"):\n t = t.rsplit(\"\\n\", 1)[0]\n return t.strip()\n\n\ndef chunk(rows: list[dict], size: int) -> list[list[dict]]:\n return [rows[i:i + size] for i in range(0, len(rows), size)]\n\n\ndef call_openai_compatible(base_url: str, model: str, api_key: str, rows: list[dict], budget_mb_extra: float) -> list[dict]:\n system = (\n \"You are a quantization policy judge. \"\n \"Choose exactly one action per module from: \" + \", \".join(allowed_actions()) + \". \"\n \"Prefer w4_g128 unless the evidence shows fragility. \"\n \"Use w4_g64 for moderate sensitivity. Use w8_g128 for high sensitivity. \"\n \"Use skip only for extreme outliers, usually lm_head. \"\n \"Return JSON only with schema {\\\\\\\"assignments\\\\\\\":[{\\\\\\\"name\\\\\\\":...,\\\\\\\"action\\\\\\\":...,\\\\\\\"priority\\\\\\\":0..1,\\\\\\\"reason_code\\\\\\\":...}]}\"\n )\n user = {\"budget_mb_extra_global\": budget_mb_extra, \"allowed_actions\": allowed_actions(), \"rows\": rows}\n url = base_url.rstrip(\"/\") + \"/chat/completions\"\n headers = {\"Content-Type\": \"application/json\"}\n if api_key:\n headers[\"Authorization\"] = f\"Bearer {api_key}\"\n body = {\n \"model\": model,\n \"temperature\": 0.0,\n \"response_format\": {\"type\": \"json_object\"},\n \"messages\": [\n {\"role\": \"system\", \"content\": system},\n {\"role\": \"user\", \"content\": json.dumps(user, ensure_ascii=False)},\n ],\n }\n resp = requests.post(url, headers=headers, json=body, timeout=180)\n resp.raise_for_status()\n content = resp.json()[\"choices\"][0][\"message\"][\"content\"]\n parsed = json.loads(strip_code_fence(content))\n assignments = parsed.get(\"assignments\", [])\n if not isinstance(assignments, list):\n raise ValueError(\"Judge response missing assignments.\")\n return assignments\n\n\ndef merge_assignments(rows: list[dict], proposed: list[dict]) -> list[dict]:\n by_name = {a[\"name\"]: a for a in proposed if isinstance(a, dict) and \"name\" in a}\n merged = []\n for row in rows:\n item = by_name.get(row[\"name\"])\n if item is None:\n merged.append(heuristic_decision(row))\n continue\n action = item.get(\"action\", BASE_ACTION)\n if action not in ACTION_SPECS:\n action = BASE_ACTION\n try:\n priority = float(item.get(\"priority\", 0.5))\n except Exception:\n priority = 0.5\n priority = max(0.0, min(1.0, priority))\n merged.append({\"name\": row[\"name\"], \"action\": action, \"priority\": priority, \"reason_code\": str(item.get(\"reason_code\", \"judge\")), \"source\": \"judge\"})\n return merged\n\n\ndef make_judge_proposals(cfg: JudgeConfig) -> dict:\n rows = load_rows(cfg.features_jsonl)\n if cfg.heuristic_only:\n return {\"assignments\": [heuristic_decision(r) for r in rows]}\n if not cfg.judge_base_url or not cfg.judge_model:\n raise ValueError(\"Judge mode requires judge_base_url and judge_model.\")\n api_key = os.environ.get(cfg.judge_api_key_env, \"\")\n assignments = []\n for part in chunk(rows, cfg.chunk_size):\n reduced = [reduced_row(r) for r in part]\n try:\n proposed = call_openai_compatible(cfg.judge_base_url, cfg.judge_model, api_key, reduced, cfg.budget_mb_extra)\n assignments.extend(merge_assignments(part, proposed))\n except Exception:\n assignments.extend([heuristic_decision(r) for r in part])\n return {\"assignments\": assignments}\n" }, { "path": "quant_judge/compiler.py", "content": "from __future__ import annotations\n\nimport json\nimport re\nfrom pathlib import Path\n\nfrom .actions import BASE_ACTION, action_to_override\n\n\ndef compile_plan(assignments: list[dict], probe_rows: list[dict], extra_budget_mb: float, min_roi: float = 1e-6) -> dict:\n probe_by_name = {row[\"name\"]: row for row in probe_rows}\n upgrades = []\n for item in assignments:\n name = item[\"name\"]\n action = item[\"action\"]\n priority = float(item.get(\"priority\", 0.5))\n if action == BASE_ACTION:\n continue\n row = probe_by_name.get(name)\n if row is None:\n continue\n if action == \"w8_g128\":\n benefit = row[\"candidates\"][\"w4_g128\"][\"mse\"] - row[\"candidates\"][\"w8_g128\"][\"mse\"]\n cost_mb = row[\"candidates\"][\"w8_g128\"][\"estimated_extra_mb_vs_base\"]\n elif action == \"w4_g64\":\n benefit = row[\"candidates\"][\"w4_g128\"][\"mse\"] - row[\"candidates\"][\"w4_g64\"][\"mse\"]\n cost_mb = row[\"candidates\"][\"w4_g64\"][\"estimated_extra_mb_vs_base\"]\n elif action == \"skip\":\n benefit = row[\"candidates\"][\"w4_g128\"][\"mse\"]\n cost_mb = row[\"candidates\"][\"skip\"][\"estimated_extra_mb_vs_base\"]\n else:\n continue\n if benefit <= 0 or cost_mb <= 0:\n continue\n roi = benefit / cost_mb\n if roi < min_roi:\n continue\n upgrades.append({\n \"name\": name,\n \"action\": action,\n \"benefit\": float(benefit),\n \"cost_mb\": float(cost_mb),\n \"roi\": float(roi),\n \"priority\": priority,\n \"reason_code\": item.get(\"reason_code\", \"\"),\n \"source\": item.get(\"source\", \"\"),\n })\n upgrades.sort(key=lambda x: (x[\"roi\"], x[\"priority\"]), reverse=True)\n\n approved = []\n spent_mb = 0.0\n for up in upgrades:\n if spent_mb + up[\"cost_mb\"] > extra_budget_mb:\n continue\n spent_mb += up[\"cost_mb\"]\n approved.append(up)\n\n approved_by_name = {x[\"name\"]: x for x in approved}\n final_assignments = []\n for row in probe_rows:\n ap = approved_by_name.get(row[\"name\"])\n if ap is None:\n final_assignments.append({\"name\": row[\"name\"], \"action\": BASE_ACTION, \"priority\": 0.0, \"reason_code\": \"base_default\", \"source\": \"compiler\", \"estimated_extra_mb_vs_base\": 0.0})\n else:\n final_assignments.append({\"name\": ap[\"name\"], \"action\": ap[\"action\"], \"priority\": ap[\"priority\"], \"reason_code\": ap[\"reason_code\"], \"source\": ap[\"source\"], \"estimated_extra_mb_vs_base\": ap[\"cost_mb\"], \"roi\": ap[\"roi\"], \"benefit\": ap[\"benefit\"]})\n final_assignments.sort(key=lambda x: x[\"name\"])\n return {\"budget_mb_extra\": float(extra_budget_mb), \"used_mb_extra\": round(float(spent_mb), 6), \"upgrades_considered\": len(upgrades), \"upgrades_approved\": len(approved), \"assignments\": final_assignments}\n\n\ndef plan_to_dynamic(plan: dict) -> dict:\n dynamic = {}\n for item in plan[\"assignments\"]:\n name = item[\"name\"]\n action = item[\"action\"]\n if action == BASE_ACTION:\n continue\n anchored = f\"^{re.escape(name)}$\"\n if action == \"skip\":\n dynamic[f\"-:{anchored}\"] = {}\n else:\n dynamic[f\"+:{anchored}\"] = action_to_override(action)\n return dynamic\n\n\ndef save_plan_and_dynamic(plan: dict, dynamic: dict, plan_out: str, dynamic_out: str) -> None:\n Path(plan_out).parent.mkdir(parents=True, exist_ok=True)\n Path(dynamic_out).parent.mkdir(parents=True, exist_ok=True)\n Path(plan_out).write_text(json.dumps(plan, indent=2), encoding=\"utf-8\")\n Path(dynamic_out).write_text(json.dumps(dynamic, indent=2), encoding=\"utf-8\")\n" }, { "path": "quant_judge/runner.py", "content": "from __future__ import annotations\n\nimport json\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nimport torch\nfrom datasets import load_dataset\nfrom gptqmodel import GPTQModel, QuantizeConfig\n\n\ndef load_texts(dataset_name: str, dataset_config: str | None, split: str, text_field: str, num_texts: int) -> list[str]:\n ds = load_dataset(dataset_name, dataset_config, split=split)\n texts = []\n for row in ds:\n text = row.get(text_field, \"\")\n if isinstance(text, str) and text.strip():\n texts.append(text.strip())\n if len(texts) >= num_texts:\n break\n if not texts:\n raise ValueError(\"No texts found for quantization.\")\n return texts\n\n\ndef quantize_with_dynamic(model_id: str, calibration_texts: list[str], out_dir: str, dynamic: dict | None, bits: int = 4, group_size: int = 128, batch_size: int = 4, sym: bool = True, desc_act: bool = True) -> dict:\n qcfg = QuantizeConfig(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, dynamic=dynamic)\n model = GPTQModel.load(model_id, qcfg)\n model.quantize(calibration_texts, batch_size=batch_size)\n out = Path(out_dir)\n out.mkdir(parents=True, exist_ok=True)\n model.save(str(out))\n meta = {\"model_id\": model_id, \"out_dir\": str(out), \"bits\": bits, \"group_size\": group_size, \"batch_size\": batch_size, \"dynamic_rules\": 0 if dynamic is None else len(dynamic)}\n (out / \"quant_meta.json\").write_text(json.dumps(meta, indent=2), encoding=\"utf-8\")\n del model\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n return meta\n\n\ndef heuristic_dynamic_from_flags(flags: list[str]) -> dict:\n dynamic = {}\n for flag in flags:\n if flag == \"skip_lm_head\":\n dynamic[r\"-:^lm_head$\"] = {}\n elif flag == \"down_proj_g64\":\n dynamic[r\"+:.*down_proj$\"] = {\"bits\": 4, \"group_size\": 64}\n elif flag == \"o_proj_g64\":\n dynamic[r\"+:.*o_proj$\"] = {\"bits\": 4, \"group_size\": 64}\n elif flag == \"edge_blocks_w8\":\n dynamic[r\"+:^model\\\\.layers\\\\.(0|1|30|31)\\\\..*$\"] = {\"bits\": 8, \"group_size\": 128}\n return dynamic\n\n\ndef run_lm_eval(model_path: str, tasks: list[str], output_json: str, device: str = \"cuda:0\", batch_size: str = \"auto\") -> dict:\n output_path = Path(output_json)\n output_path.parent.mkdir(parents=True, exist_ok=True)\n cmd = [\n sys.executable, \"-m\", \"lm_eval\",\n \"--model\", \"hf\",\n \"--model_args\", f\"pretrained={model_path},trust_remote_code=True\",\n \"--tasks\", \",\".join(tasks),\n \"--device\", device,\n \"--batch_size\", batch_size,\n \"--output_path\", str(output_path.parent),\n ]\n subprocess.run(cmd, check=True)\n for path in sorted(output_path.parent.rglob(\"*.json\")):\n try:\n data = json.loads(path.read_text(encoding=\"utf-8\"))\n except Exception:\n continue\n if isinstance(data, dict) and \"results\" in data:\n output_path.write_text(json.dumps(data, indent=2), encoding=\"utf-8\")\n return data\n raise FileNotFoundError(\"Could not find lm_eval results json.\")\n\n\ndef extract_primary_scores(eval_json: dict) -> dict:\n out = {}\n for task, metrics in eval_json.get(\"results\", {}).items():\n if \"acc_norm,none\" in metrics:\n out[task] = metrics[\"acc_norm,none\"]\n elif \"acc,none\" in metrics:\n out[task] = metrics[\"acc,none\"]\n elif \"word_perplexity,none\" in metrics:\n out[task] = metrics[\"word_perplexity,none\"]\n return out\n" }, { "path": "quant_judge/pipeline.py", "content": "from __future__ import annotations\n\nimport argparse\nimport json\nfrom pathlib import Path\n\nfrom .compiler import compile_plan, plan_to_dynamic, save_plan_and_dynamic\nfrom .judge import JudgeConfig, load_rows, make_judge_proposals\nfrom .probe import ProbeConfig, run_probe\nfrom .runner import extract_primary_scores, heuristic_dynamic_from_flags, load_texts, quantize_with_dynamic, run_lm_eval\n\n\ndef cmd_probe(args):\n cfg = ProbeConfig(\n model_id=args.model_id,\n out_dir=args.out_dir,\n dataset_name=args.dataset_name,\n dataset_config=args.dataset_config,\n split=args.split,\n text_field=args.text_field,\n num_texts=args.num_texts,\n max_length=args.max_length,\n max_rows_per_module=args.max_rows_per_module,\n include_modules=[x.strip() for x in args.include_modules.split(\",\") if x.strip()],\n )\n print(json.dumps(run_probe(cfg), indent=2))\n\n\ndef cmd_plan(args):\n judge_cfg = JudgeConfig(\n features_jsonl=args.features_jsonl,\n budget_mb_extra=args.budget_mb_extra,\n heuristic_only=args.heuristic_only,\n judge_base_url=args.judge_base_url,\n judge_model=args.judge_model,\n judge_api_key_env=args.judge_api_key_env,\n chunk_size=args.chunk_size,\n )\n proposals = make_judge_proposals(judge_cfg)\n probe_rows = load_rows(args.features_jsonl)\n plan = compile_plan(proposals[\"assignments\"], probe_rows, args.budget_mb_extra, args.min_roi)\n dynamic = plan_to_dynamic(plan)\n save_plan_and_dynamic(plan, dynamic, args.plan_out, args.dynamic_out)\n print(json.dumps({\"plan_out\": args.plan_out, \"dynamic_out\": args.dynamic_out, \"used_mb_extra\": plan[\"used_mb_extra\"], \"upgrades_approved\": plan[\"upgrades_approved\"], \"dynamic_rules\": len(dynamic)}, indent=2))\n\n\ndef cmd_quantize(args):\n calibration_texts = load_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts)\n plan = json.loads(Path(args.plan_json).read_text(encoding=\"utf-8\"))\n dynamic = plan_to_dynamic(plan)\n print(json.dumps(quantize_with_dynamic(args.model_id, calibration_texts, args.out_dir, dynamic, args.base_bits, args.base_group_size, args.batch_size), indent=2))\n\n\ndef cmd_eval(args):\n tasks = [x.strip() for x in args.tasks.split(\",\") if x.strip()]\n result = run_lm_eval(args.model_path, tasks, args.output_json, args.device, args.batch_size)\n print(json.dumps({\"output_json\": args.output_json, \"scores\": extract_primary_scores(result)}, indent=2))\n\n\ndef cmd_suite(args):\n work_dir = Path(args.work_dir)\n work_dir.mkdir(parents=True, exist_ok=True)\n probe_dir = work_dir / \"probe\"\n\n probe_result = run_probe(ProbeConfig(\n model_id=args.model_id,\n out_dir=str(probe_dir),\n dataset_name=args.dataset_name,\n dataset_config=args.dataset_config,\n split=args.split,\n text_field=args.text_field,\n num_texts=args.probe_num_texts,\n max_length=args.max_length,\n max_rows_per_module=args.max_rows_per_module,\n include_modules=[x.strip() for x in args.include_modules.split(\",\") if x.strip()],\n ))\n features_jsonl = probe_result[\"features_jsonl\"]\n probe_rows = load_rows(features_jsonl)\n\n judge_cfg = JudgeConfig(\n features_jsonl=features_jsonl,\n budget_mb_extra=args.budget_mb_extra,\n heuristic_only=(not args.judge_base_url or not args.judge_model),\n judge_base_url=args.judge_base_url,\n judge_model=args.judge_model,\n judge_api_key_env=args.judge_api_key_env,\n chunk_size=args.chunk_size,\n )\n judge_proposals = make_judge_proposals(judge_cfg)\n judge_plan = compile_plan(judge_proposals[\"assignments\"], probe_rows, args.budget_mb_extra, args.min_roi)\n judge_dynamic = plan_to_dynamic(judge_plan)\n save_plan_and_dynamic(judge_plan, judge_dynamic, str(work_dir / \"judge_plan.json\"), str(work_dir / \"judge_dynamic.json\"))\n\n heuristic_flags = [x.strip() for x in args.heuristic_include.split(\",\") if x.strip()]\n heuristic_dynamic = heuristic_dynamic_from_flags(heuristic_flags)\n (work_dir / \"heuristic_dynamic.json\").write_text(json.dumps(heuristic_dynamic, indent=2), encoding=\"utf-8\")\n\n calibration_texts = load_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts)\n tasks = [x.strip() for x in args.tasks.split(\",\") if x.strip()]\n\n configs = {\n \"baseline_w4g128\": None,\n \"heuristic_mixed\": heuristic_dynamic,\n \"judge_sdq\": judge_dynamic,\n }\n\n leaderboard = {}\n for name, dynamic in configs.items():\n model_dir = work_dir / \"models\" / name\n eval_json = work_dir / \"eval\" / f\"{name}.json\"\n quantize_with_dynamic(args.model_id, calibration_texts, str(model_dir), dynamic, args.base_bits, args.base_group_size, args.batch_size)\n result = run_lm_eval(str(model_dir), tasks, str(eval_json), args.device, args.eval_batch_size)\n leaderboard[name] = extract_primary_scores(result)\n\n payload = {\n \"probe\": probe_result,\n \"leaderboard\": leaderboard,\n \"judge_used\": bool(args.judge_base_url and args.judge_model),\n \"judge_plan_path\": str(work_dir / \"judge_plan.json\"),\n \"judge_dynamic_path\": str(work_dir / \"judge_dynamic.json\"),\n \"heuristic_dynamic_path\": str(work_dir / \"heuristic_dynamic.json\"),\n }\n (work_dir / \"leaderboard.json\").write_text(json.dumps(payload, indent=2), encoding=\"utf-8\")\n print(json.dumps(payload, indent=2))\n\n\ndef build_parser():\n p = argparse.ArgumentParser(prog=\"quant_judge\")\n sub = p.add_subparsers(dest=\"cmd\", required=True)\n\n probe = sub.add_parser(\"probe\")\n probe.add_argument(\"--model-id\", required=True)\n probe.add_argument(\"--out-dir\", required=True)\n probe.add_argument(\"--dataset-name\", default=\"wikitext\")\n probe.add_argument(\"--dataset-config\", default=\"wikitext-2-raw-v1\")\n probe.add_argument(\"--split\", default=\"train\")\n probe.add_argument(\"--text-field\", default=\"text\")\n probe.add_argument(\"--num-texts\", type=int, default=64)\n probe.add_argument(\"--max-length\", type=int, default=256)\n probe.add_argument(\"--max-rows-per-module\", type=int, default=128)\n probe.add_argument(\"--include-modules\", default=\"q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head\")\n probe.set_defaults(func=cmd_probe)\n\n plan = sub.add_parser(\"plan\")\n plan.add_argument(\"--features-jsonl\", required=True)\n plan.add_argument(\"--plan-out\", required=True)\n plan.add_argument(\"--dynamic-out\", required=True)\n plan.add_argument(\"--budget-mb-extra\", type=float, default=150.0)\n plan.add_argument(\"--min-roi\", type=float, default=1e-6)\n plan.add_argument(\"--heuristic-only\", action=\"store_true\")\n plan.add_argument(\"--judge-base-url\", default=None)\n plan.add_argument(\"--judge-model\", default=None)\n plan.add_argument(\"--judge-api-key-env\", default=\"JUDGE_API_KEY\")\n plan.add_argument(\"--chunk-size\", type=int, default=24)\n plan.set_defaults(func=cmd_plan)\n\n quant = sub.add_parser(\"quantize\")\n quant.add_argument(\"--model-id\", required=True)\n quant.add_argument(\"--plan-json\", required=True)\n quant.add_argument(\"--out-dir\", required=True)\n quant.add_argument(\"--dataset-name\", default=\"wikitext\")\n quant.add_argument(\"--dataset-config\", default=\"wikitext-2-raw-v1\")\n quant.add_argument(\"--split\", default=\"train\")\n quant.add_argument(\"--text-field\", default=\"text\")\n quant.add_argument(\"--num-texts\", type=int, default=128)\n quant.add_argument(\"--batch-size\", type=int, default=4)\n quant.add_argument(\"--base-bits\", type=int, default=4)\n quant.add_argument(\"--base-group-size\", type=int, default=128)\n quant.set_defaults(func=cmd_quantize)\n\n ev = sub.add_parser(\"eval\")\n ev.add_argument(\"--model-path\", required=True)\n ev.add_argument(\"--tasks\", default=\"arc_challenge,hellaswag\")\n ev.add_argument(\"--output-json\", required=True)\n ev.add_argument(\"--device\", default=\"cuda:0\")\n ev.add_argument(\"--batch-size\", default=\"auto\")\n ev.set_defaults(func=cmd_eval)\n\n suite = sub.add_parser(\"suite\")\n suite.add_argument(\"--model-id\", required=True)\n suite.add_argument(\"--work-dir\", required=True)\n suite.add_argument(\"--dataset-name\", default=\"wikitext\")\n suite.add_argument(\"--dataset-config\", default=\"wikitext-2-raw-v1\")\n suite.add_argument(\"--split\", default=\"train\")\n suite.add_argument(\"--text-field\", default=\"text\")\n suite.add_argument(\"--probe-num-texts\", type=int, default=64)\n suite.add_argument(\"--quant-num-texts\", type=int, default=128)\n suite.add_argument(\"--max-length\", type=int, default=256)\n suite.add_argument(\"--max-rows-per-module\", type=int, default=128)\n suite.add_argument(\"--include-modules\", default=\"q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head\")\n suite.add_argument(\"--budget-mb-extra\", type=float, default=150.0)\n suite.add_argument(\"--min-roi\", type=float, default=1e-6)\n suite.add_argument(\"--heuristic-include\", default=\"skip_lm_head,down_proj_g64\")\n suite.add_argument(\"--judge-base-url\", default=None)\n suite.add_argument(\"--judge-model\", default=None)\n suite.add_argument(\"--judge-api-key-env\", default=\"JUDGE_API_KEY\")\n suite.add_argument(\"--chunk-size\", type=int, default=24)\n suite.add_argument(\"--tasks\", default=\"arc_challenge,hellaswag\")\n suite.add_argument(\"--device\", default=\"cuda:0\")\n suite.add_argument(\"--batch-size\", type=int, default=4)\n suite.add_argument(\"--eval-batch-size\", default=\"auto\")\n suite.add_argument(\"--base-bits\", type=int, default=4)\n suite.add_argument(\"--base-group-size\", type=int, default=128)\n suite.set_defaults(func=cmd_suite)\n\n return p\n\n\ndef main():\n parser = build_parser()\n args = parser.parse_args()\n args.func(args)\n\n\nif __name__ == \"__main__\":\n main()\n" } ]} +JSON{ "project_name": "sdq_quant_judge_qwen35_gptqmodel", "files": [ { "path": "README.md", "content": "# SDQ Quant Judge for GPTQModel\n\nA software-defined quantization scaffold for Qwen/Qwen3.5-4B-Base.\n\nCommands:\n- probe: extract per-module fake-quant sensitivity features from the FP model\n- plan: build a mixed-precision plan with either heuristic logic or an external OpenAI-compatible judge\n- quantize: quantize with GPTQModel using the compiled dynamic overrides\n- eval: evaluate a saved model with lm-eval\n- suite: run baseline vs heuristic vs judge and save a leaderboard\n\nExample:\n\n```bash\npython -m venv .venv\nsource .venv/bin/activate\npip install -r requirements.txt\n\npython -m quant_judge.pipeline probe \\\n --model-id Qwen/Qwen3.5-4B-Base \\\n --out-dir ./runs/probe \\\n --dataset-name wikitext \\\n --dataset-config wikitext-2-raw-v1 \\\n --split train \\\n --text-field text \\\n --num-texts 64 \\\n --max-length 256 \\\n --max-rows-per-module 128 \\\n --include-modules q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head\n\npython -m quant_judge.pipeline plan \\\n --features-jsonl ./runs/probe/probe_features.jsonl \\\n --plan-out ./runs/plan.json \\\n --dynamic-out ./runs/dynamic.json \\\n --budget-mb-extra 150 \\\n --heuristic-only\n\npython -m quant_judge.pipeline quantize \\\n --model-id Qwen/Qwen3.5-4B-Base \\\n --plan-json ./runs/plan.json \\\n --out-dir ./runs/quantized \\\n --dataset-name wikitext \\\n --dataset-config wikitext-2-raw-v1 \\\n --split train \\\n --text-field text \\\n --num-texts 128 \\\n --batch-size 4\n\npython -m quant_judge.pipeline eval \\\n --model-path ./runs/quantized \\\n --tasks arc_challenge,hellaswag \\\n --output-json ./runs/eval.json\n```\n" }, { "path": "requirements.txt", "content": "transformers\ndatasets\naccelerate\nrequests\nnumpy\ntqdm\npsutil\ngptqmodel\nlm-eval\n" }, { "path": "quant_judge/__init__.py", "content": "\"\"\"SDQ quantization scaffold.\"\"\"\n" }, { "path": "quant_judge/actions.py", "content": "from __future__ import annotations\n\nfrom dataclasses import dataclass\n\n\n@dataclass(frozen=True)\nclass ActionSpec:\n name: str\n bits: int | None\n group_size: int | None\n skip: bool\n\n\nACTION_SPECS: dict[str, ActionSpec] = {\n \"w4_g128\": ActionSpec(\"w4_g128\", bits=4, group_size=128, skip=False),\n \"w4_g64\": ActionSpec(\"w4_g64\", bits=4, group_size=64, skip=False),\n \"w8_g128\": ActionSpec(\"w8_g128\", bits=8, group_size=128, skip=False),\n \"skip\": ActionSpec(\"skip\", bits=None, group_size=None, skip=True),\n}\n\nBASE_ACTION = \"w4_g128\"\n\n\ndef allowed_actions() -> list[str]:\n return list(ACTION_SPECS.keys())\n\n\ndef bytes_per_param(action: str, fp_bytes: float = 2.0) -> float:\n if action == \"skip\":\n return fp_bytes\n spec = ACTION_SPECS[action]\n assert spec.bits is not None\n return spec.bits / 8.0\n\n\ndef extra_mb_vs_base(params: int, action: str) -> float:\n base = params * bytes_per_param(BASE_ACTION)\n now = params * bytes_per_param(action)\n return (now - base) / (1024 ** 2)\n\n\ndef action_to_override(action: str) -> dict:\n if action == \"skip\":\n return {}\n spec = ACTION_SPECS[action]\n return {\"bits\": spec.bits, \"group_size\": spec.group_size}\n" }, { "path": "quant_judge/probe.py", "content": "from __future__ import annotations\n\nimport json\nfrom dataclasses import dataclass\nfrom pathlib import Path\n\nimport torch\nimport torch.nn.functional as F\nfrom datasets import load_dataset\nfrom tqdm import tqdm\nfrom transformers import AutoModelForCausalLM, AutoTokenizer\n\nfrom .actions import extra_mb_vs_base\n\n\n@dataclass\nclass ProbeConfig:\n model_id: str\n out_dir: str\n dataset_name: str\n dataset_config: str | None\n split: str\n text_field: str\n num_texts: int\n max_length: int\n max_rows_per_module: int\n include_modules: list[str]\n trust_remote_code: bool = True\n\n\ndef pick_dtype() -> torch.dtype:\n if not torch.cuda.is_available():\n return torch.float32\n major, _minor = torch.cuda.get_device_capability(0)\n if major >= 8:\n return torch.bfloat16\n return torch.float16\n\n\ndef fake_quant_groupwise_linear_weight(weight: torch.Tensor, bits: int, group_size: int) -> torch.Tensor:\n assert weight.dim() == 2\n w = weight.detach().float()\n out_features, in_features = w.shape\n pad = (group_size - (in_features % group_size)) % group_size\n if pad:\n w = F.pad(w, (0, pad), value=0.0)\n maxq = (2 ** (bits - 1)) - 1\n w = w.view(out_features, -1, group_size)\n absmax = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)\n scale = absmax / maxq\n q = torch.round(w / scale).clamp(-maxq, maxq)\n dq = q * scale\n dq = dq.view(out_features, -1)\n if pad:\n dq = dq[:, :in_features]\n return dq.to(weight.device, dtype=weight.dtype)\n\n\ndef load_texts(dataset_name: str, dataset_config: str | None, split: str, text_field: str, num_texts: int) -> list[str]:\n ds = load_dataset(dataset_name, dataset_config, split=split)\n texts: list[str] = []\n for row in ds:\n value = row.get(text_field, \"\")\n if isinstance(value, str) and value.strip():\n texts.append(value.strip())\n if len(texts) >= num_texts:\n break\n if not texts:\n raise ValueError(\"No usable texts found.\")\n return texts\n\n\ndef module_selected(name: str, include_modules: list[str]) -> bool:\n return any(token in name for token in include_modules)\n\n\ndef iter_target_linears(model, include_modules: list[str]):\n for name, module in model.named_modules():\n if isinstance(module, torch.nn.Linear) and module_selected(name, include_modules):\n yield name, module\n\n\ndef flatten_rows(x: torch.Tensor) -> torch.Tensor:\n if x.dim() == 2:\n return x\n if x.dim() >= 3:\n return x.reshape(-1, x.shape[-1])\n return x.unsqueeze(0)\n\n\ndef sample_rows(x: torch.Tensor, max_rows: int) -> torch.Tensor:\n rows = flatten_rows(x)\n if rows.shape[0] <= max_rows:\n return rows\n idx = torch.randperm(rows.shape[0], device=rows.device)[:max_rows]\n return rows.index_select(0, idx)\n\n\ndef weight_kurtosis(weight: torch.Tensor) -> float:\n w = weight.detach().float().flatten()\n mu = w.mean()\n var = ((w - mu) ** 2).mean().clamp_min(1e-12)\n return float((((w - mu) ** 4).mean() / (var ** 2)).item())\n\n\ndef activation_outlier_rate(x: torch.Tensor, sigma: float = 6.0) -> float:\n xf = x.float()\n std = xf.std().clamp_min(1e-6)\n return float((xf.abs() > sigma * std).float().mean().item())\n\n\ndef block_index_from_name(name: str) -> int | None:\n parts = name.split(\".\")\n for i, p in enumerate(parts):\n if p == \"layers\" and i + 1 < len(parts):\n try:\n return int(parts[i + 1])\n except Exception:\n return None\n return None\n\n\ndef reduced_metrics_for_action(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor | None, bits: int, group_size: int) -> dict:\n wq = fake_quant_groupwise_linear_weight(weight, bits=bits, group_size=group_size)\n out_fp = F.linear(x, weight, bias)\n out_q = F.linear(x, wq, bias)\n mse = F.mse_loss(out_q.float(), out_fp.float()).item()\n cos = F.cosine_similarity(out_q.float(), out_fp.float(), dim=-1).mean().item()\n max_abs = (out_q.float() - out_fp.float()).abs().max().item()\n return {\"mse\": float(mse), \"cosine\": float(cos), \"max_abs\": float(max_abs)}\n\n\ndef run_probe(cfg: ProbeConfig) -> dict:\n out_dir = Path(cfg.out_dir)\n out_dir.mkdir(parents=True, exist_ok=True)\n\n tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, trust_remote_code=cfg.trust_remote_code)\n if tokenizer.pad_token_id is None:\n tokenizer.pad_token = tokenizer.eos_token\n\n model = AutoModelForCausalLM.from_pretrained(\n cfg.model_id,\n trust_remote_code=cfg.trust_remote_code,\n torch_dtype=pick_dtype(),\n device_map=\"auto\",\n low_cpu_mem_usage=True,\n )\n model.eval()\n\n texts = load_texts(cfg.dataset_name, cfg.dataset_config, cfg.split, cfg.text_field, cfg.num_texts)\n targets = list(iter_target_linears(model, cfg.include_modules))\n if not targets:\n raise ValueError(\"No target modules found.\")\n\n captured_inputs = {name: [] for name, _ in targets}\n remaining_rows = {name: cfg.max_rows_per_module for name, _ in targets}\n handles = []\n\n def make_hook(name: str):\n def hook(_module, inp, _out):\n rem = remaining_rows[name]\n if rem <= 0 or not inp or not isinstance(inp[0], torch.Tensor):\n return\n x = sample_rows(inp[0], rem)\n take = min(x.shape[0], rem)\n if take <= 0:\n return\n captured_inputs[name].append(x[:take].detach().cpu().to(torch.float16))\n remaining_rows[name] -= take\n return hook\n\n for name, module in targets:\n handles.append(module.register_forward_hook(make_hook(name)))\n\n with torch.no_grad():\n for text in tqdm(texts, desc=\"collect_activations\"):\n batch = tokenizer(text, return_tensors=\"pt\", truncation=True, max_length=cfg.max_length, padding=False)\n batch = {k: v.to(model.device) for k, v in batch.items()}\n _ = model(**batch)\n if all(v <= 0 for v in remaining_rows.values()):\n break\n\n for h in handles:\n h.remove()\n\n rows = []\n for name, module in tqdm(targets, desc=\"probe_modules\"):\n if not captured_inputs[name]:\n continue\n x = torch.cat(captured_inputs[name], dim=0).to(module.weight.device, dtype=module.weight.dtype)\n weight = module.weight.detach()\n bias = None if module.bias is None else module.bias.detach()\n\n m_w4g128 = reduced_metrics_for_action(x, weight, bias, bits=4, group_size=128)\n m_w4g64 = reduced_metrics_for_action(x, weight, bias, bits=4, group_size=64)\n m_w8g128 = reduced_metrics_for_action(x, weight, bias, bits=8, group_size=128)\n\n params = int(weight.numel())\n row = {\n \"name\": name,\n \"module_type\": name.split(\".\")[-1],\n \"block_index\": block_index_from_name(name),\n \"params\": params,\n \"captured_rows\": int(x.shape[0]),\n \"weight_kurtosis\": weight_kurtosis(weight),\n \"act_outlier_rate\": activation_outlier_rate(x),\n \"candidates\": {\n \"w4_g128\": {**m_w4g128, \"estimated_extra_mb_vs_base\": 0.0},\n \"w4_g64\": {**m_w4g64, \"estimated_extra_mb_vs_base\": max(0.0, extra_mb_vs_base(params, \"w4_g64\")) + (params * 0.05 / (1024 ** 2))},\n \"w8_g128\": {**m_w8g128, \"estimated_extra_mb_vs_base\": max(0.0, extra_mb_vs_base(params, \"w8_g128\"))},\n \"skip\": {\"mse\": 0.0, \"cosine\": 1.0, \"max_abs\": 0.0, \"estimated_extra_mb_vs_base\": max(0.0, extra_mb_vs_base(params, \"skip\"))},\n },\n \"proxy_gain_w4g64_vs_w4\": float(max(0.0, m_w4g128[\"mse\"] - m_w4g64[\"mse\"])),\n \"proxy_gain_w8_vs_w4\": float(max(0.0, m_w4g128[\"mse\"] - m_w8g128[\"mse\"])),\n }\n rows.append(row)\n\n features_jsonl = out_dir / \"probe_features.jsonl\"\n metadata_json = out_dir / \"probe_metadata.json\"\n summary_json = out_dir / \"probe_summary.json\"\n\n with open(features_jsonl, \"w\", encoding=\"utf-8\") as f:\n for row in rows:\n f.write(json.dumps(row, ensure_ascii=False) + \"\\n\")\n\n metadata = {\n \"model_id\": cfg.model_id,\n \"dataset_name\": cfg.dataset_name,\n \"dataset_config\": cfg.dataset_config,\n \"split\": cfg.split,\n \"text_field\": cfg.text_field,\n \"num_texts\": cfg.num_texts,\n \"max_length\": cfg.max_length,\n \"max_rows_per_module\": cfg.max_rows_per_module,\n \"include_modules\": cfg.include_modules,\n \"num_rows\": len(rows),\n }\n metadata_json.write_text(json.dumps(metadata, indent=2), encoding=\"utf-8\")\n summary_json.write_text(json.dumps({\"num_rows\": len(rows)}, indent=2), encoding=\"utf-8\")\n\n del model\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n\n return {\"features_jsonl\": str(features_jsonl), \"metadata_json\": str(metadata_json), \"summary_json\": str(summary_json), \"num_rows\": len(rows)}\n" }, { "path": "quant_judge/judge.py", "content": "from __future__ import annotations\n\nimport json\nimport os\nfrom dataclasses import dataclass\n\nimport requests\n\nfrom .actions import ACTION_SPECS, BASE_ACTION, allowed_actions\n\n\n@dataclass\nclass JudgeConfig:\n features_jsonl: str\n budget_mb_extra: float\n heuristic_only: bool = False\n judge_base_url: str | None = None\n judge_model: str | None = None\n judge_api_key_env: str = \"JUDGE_API_KEY\"\n chunk_size: int = 24\n\n\ndef load_rows(path: str) -> list[dict]:\n rows = []\n with open(path, \"r\", encoding=\"utf-8\") as f:\n for line in f:\n line = line.strip()\n if line:\n rows.append(json.loads(line))\n return rows\n\n\ndef heuristic_decision(row: dict) -> dict:\n name = row[\"name\"]\n w4 = row[\"candidates\"][\"w4_g128\"][\"mse\"]\n w4g64 = row[\"candidates\"][\"w4_g64\"][\"mse\"]\n w8 = row[\"candidates\"][\"w8_g128\"][\"mse\"]\n kurt = row[\"weight_kurtosis\"]\n outlier = row[\"act_outlier_rate\"]\n block = row.get(\"block_index\")\n\n w8_relief = (w4 - w8) / max(w4, 1e-9)\n g64_relief = (w4 - w4g64) / max(w4, 1e-9)\n\n score = 0.0\n score += min(1.0, w8_relief / 0.60) * 0.40\n score += min(1.0, g64_relief / 0.35) * 0.20\n score += min(1.0, kurt / 20.0) * 0.15\n score += min(1.0, outlier / 0.05) * 0.15\n if \"lm_head\" in name:\n score += 0.25\n if row[\"module_type\"] in (\"down_proj\", \"o_proj\", \"gate_proj\"):\n score += 0.10\n if block is not None and (block <= 1 or block >= 30):\n score += 0.10\n score = max(0.0, min(1.0, score))\n\n action = BASE_ACTION\n reason = \"default\"\n if \"lm_head\" in name and (w8_relief > 0.55 or kurt > 25.0):\n action = \"skip\"\n reason = \"lm_head_extreme\"\n elif score >= 0.75 or w8_relief > 0.50:\n action = \"w8_g128\"\n reason = \"high_sensitivity\"\n elif score >= 0.42 or g64_relief > 0.18:\n action = \"w4_g64\"\n reason = \"moderate_sensitivity\"\n\n return {\"name\": name, \"action\": action, \"priority\": round(score, 6), \"reason_code\": reason, \"source\": \"heuristic\"}\n\n\ndef reduced_row(row: dict) -> dict:\n return {\n \"name\": row[\"name\"],\n \"module_type\": row[\"module_type\"],\n \"block_index\": row.get(\"block_index\"),\n \"params\": row[\"params\"],\n \"weight_kurtosis\": row[\"weight_kurtosis\"],\n \"act_outlier_rate\": row[\"act_outlier_rate\"],\n \"w4_g128_mse\": row[\"candidates\"][\"w4_g128\"][\"mse\"],\n \"w4_g64_mse\": row[\"candidates\"][\"w4_g64\"][\"mse\"],\n \"w8_g128_mse\": row[\"candidates\"][\"w8_g128\"][\"mse\"],\n \"extra_mb_w4_g64\": row[\"candidates\"][\"w4_g64\"][\"estimated_extra_mb_vs_base\"],\n \"extra_mb_w8_g128\": row[\"candidates\"][\"w8_g128\"][\"estimated_extra_mb_vs_base\"],\n \"extra_mb_skip\": row[\"candidates\"][\"skip\"][\"estimated_extra_mb_vs_base\"],\n }\n\n\ndef strip_code_fence(text: str) -> str:\n t = text.strip()\n if t.startswith(\"```\"):\n t = t.split(\"\\n\", 1)[1]\n if t.endswith(\"```\"):\n t = t.rsplit(\"\\n\", 1)[0]\n return t.strip()\n\n\ndef chunk(rows: list[dict], size: int) -> list[list[dict]]:\n return [rows[i:i + size] for i in range(0, len(rows), size)]\n\n\ndef call_openai_compatible(base_url: str, model: str, api_key: str, rows: list[dict], budget_mb_extra: float) -> list[dict]:\n system = (\n \"You are a quantization policy judge. \"\n \"Choose exactly one action per module from: \" + \", \".join(allowed_actions()) + \". \"\n \"Prefer w4_g128 unless the evidence shows fragility. \"\n \"Use w4_g64 for moderate sensitivity. Use w8_g128 for high sensitivity. \"\n \"Use skip only for extreme outliers, usually lm_head. \"\n \"Return JSON only with schema {\\\\\\\"assignments\\\\\\\":[{\\\\\\\"name\\\\\\\":...,\\\\\\\"action\\\\\\\":...,\\\\\\\"priority\\\\\\\":0..1,\\\\\\\"reason_code\\\\\\\":...}]}\"\n )\n user = {\"budget_mb_extra_global\": budget_mb_extra, \"allowed_actions\": allowed_actions(), \"rows\": rows}\n url = base_url.rstrip(\"/\") + \"/chat/completions\"\n headers = {\"Content-Type\": \"application/json\"}\n if api_key:\n headers[\"Authorization\"] = f\"Bearer {api_key}\"\n body = {\n \"model\": model,\n \"temperature\": 0.0,\n \"response_format\": {\"type\": \"json_object\"},\n \"messages\": [\n {\"role\": \"system\", \"content\": system},\n {\"role\": \"user\", \"content\": json.dumps(user, ensure_ascii=False)},\n ],\n }\n resp = requests.post(url, headers=headers, json=body, timeout=180)\n resp.raise_for_status()\n content = resp.json()[\"choices\"][0][\"message\"][\"content\"]\n parsed = json.loads(strip_code_fence(content))\n assignments = parsed.get(\"assignments\", [])\n if not isinstance(assignments, list):\n raise ValueError(\"Judge response missing assignments.\")\n return assignments\n\n\ndef merge_assignments(rows: list[dict], proposed: list[dict]) -> list[dict]:\n by_name = {a[\"name\"]: a for a in proposed if isinstance(a, dict) and \"name\" in a}\n merged = []\n for row in rows:\n item = by_name.get(row[\"name\"])\n if item is None:\n merged.append(heuristic_decision(row))\n continue\n action = item.get(\"action\", BASE_ACTION)\n if action not in ACTION_SPECS:\n action = BASE_ACTION\n try:\n priority = float(item.get(\"priority\", 0.5))\n except Exception:\n priority = 0.5\n priority = max(0.0, min(1.0, priority))\n merged.append({\"name\": row[\"name\"], \"action\": action, \"priority\": priority, \"reason_code\": str(item.get(\"reason_code\", \"judge\")), \"source\": \"judge\"})\n return merged\n\n\ndef make_judge_proposals(cfg: JudgeConfig) -> dict:\n rows = load_rows(cfg.features_jsonl)\n if cfg.heuristic_only:\n return {\"assignments\": [heuristic_decision(r) for r in rows]}\n if not cfg.judge_base_url or not cfg.judge_model:\n raise ValueError(\"Judge mode requires judge_base_url and judge_model.\")\n api_key = os.environ.get(cfg.judge_api_key_env, \"\")\n assignments = []\n for part in chunk(rows, cfg.chunk_size):\n reduced = [reduced_row(r) for r in part]\n try:\n proposed = call_openai_compatible(cfg.judge_base_url, cfg.judge_model, api_key, reduced, cfg.budget_mb_extra)\n assignments.extend(merge_assignments(part, proposed))\n except Exception:\n assignments.extend([heuristic_decision(r) for r in part])\n return {\"assignments\": assignments}\n" }, { "path": "quant_judge/compiler.py", "content": "from __future__ import annotations\n\nimport json\nimport re\nfrom pathlib import Path\n\nfrom .actions import BASE_ACTION, action_to_override\n\n\ndef compile_plan(assignments: list[dict], probe_rows: list[dict], extra_budget_mb: float, min_roi: float = 1e-6) -> dict:\n probe_by_name = {row[\"name\"]: row for row in probe_rows}\n upgrades = []\n for item in assignments:\n name = item[\"name\"]\n action = item[\"action\"]\n priority = float(item.get(\"priority\", 0.5))\n if action == BASE_ACTION:\n continue\n row = probe_by_name.get(name)\n if row is None:\n continue\n if action == \"w8_g128\":\n benefit = row[\"candidates\"][\"w4_g128\"][\"mse\"] - row[\"candidates\"][\"w8_g128\"][\"mse\"]\n cost_mb = row[\"candidates\"][\"w8_g128\"][\"estimated_extra_mb_vs_base\"]\n elif action == \"w4_g64\":\n benefit = row[\"candidates\"][\"w4_g128\"][\"mse\"] - row[\"candidates\"][\"w4_g64\"][\"mse\"]\n cost_mb = row[\"candidates\"][\"w4_g64\"][\"estimated_extra_mb_vs_base\"]\n elif action == \"skip\":\n benefit = row[\"candidates\"][\"w4_g128\"][\"mse\"]\n cost_mb = row[\"candidates\"][\"skip\"][\"estimated_extra_mb_vs_base\"]\n else:\n continue\n if benefit <= 0 or cost_mb <= 0:\n continue\n roi = benefit / cost_mb\n if roi < min_roi:\n continue\n upgrades.append({\n \"name\": name,\n \"action\": action,\n \"benefit\": float(benefit),\n \"cost_mb\": float(cost_mb),\n \"roi\": float(roi),\n \"priority\": priority,\n \"reason_code\": item.get(\"reason_code\", \"\"),\n \"source\": item.get(\"source\", \"\"),\n })\n upgrades.sort(key=lambda x: (x[\"roi\"], x[\"priority\"]), reverse=True)\n\n approved = []\n spent_mb = 0.0\n for up in upgrades:\n if spent_mb + up[\"cost_mb\"] > extra_budget_mb:\n continue\n spent_mb += up[\"cost_mb\"]\n approved.append(up)\n\n approved_by_name = {x[\"name\"]: x for x in approved}\n final_assignments = []\n for row in probe_rows:\n ap = approved_by_name.get(row[\"name\"])\n if ap is None:\n final_assignments.append({\"name\": row[\"name\"], \"action\": BASE_ACTION, \"priority\": 0.0, \"reason_code\": \"base_default\", \"source\": \"compiler\", \"estimated_extra_mb_vs_base\": 0.0})\n else:\n final_assignments.append({\"name\": ap[\"name\"], \"action\": ap[\"action\"], \"priority\": ap[\"priority\"], \"reason_code\": ap[\"reason_code\"], \"source\": ap[\"source\"], \"estimated_extra_mb_vs_base\": ap[\"cost_mb\"], \"roi\": ap[\"roi\"], \"benefit\": ap[\"benefit\"]})\n final_assignments.sort(key=lambda x: x[\"name\"])\n return {\"budget_mb_extra\": float(extra_budget_mb), \"used_mb_extra\": round(float(spent_mb), 6), \"upgrades_considered\": len(upgrades), \"upgrades_approved\": len(approved), \"assignments\": final_assignments}\n\n\ndef plan_to_dynamic(plan: dict) -> dict:\n dynamic = {}\n for item in plan[\"assignments\"]:\n name = item[\"name\"]\n action = item[\"action\"]\n if action == BASE_ACTION:\n continue\n anchored = f\"^{re.escape(name)}$\"\n if action == \"skip\":\n dynamic[f\"-:{anchored}\"] = {}\n else:\n dynamic[f\"+:{anchored}\"] = action_to_override(action)\n return dynamic\n\n\ndef save_plan_and_dynamic(plan: dict, dynamic: dict, plan_out: str, dynamic_out: str) -> None:\n Path(plan_out).parent.mkdir(parents=True, exist_ok=True)\n Path(dynamic_out).parent.mkdir(parents=True, exist_ok=True)\n Path(plan_out).write_text(json.dumps(plan, indent=2), encoding=\"utf-8\")\n Path(dynamic_out).write_text(json.dumps(dynamic, indent=2), encoding=\"utf-8\")\n" }, { "path": "quant_judge/runner.py", "content": "from __future__ import annotations\n\nimport json\nimport subprocess\nimport sys\nfrom pathlib import Path\n\nimport torch\nfrom datasets import load_dataset\nfrom gptqmodel import GPTQModel, QuantizeConfig\n\n\ndef load_texts(dataset_name: str, dataset_config: str | None, split: str, text_field: str, num_texts: int) -> list[str]:\n ds = load_dataset(dataset_name, dataset_config, split=split)\n texts = []\n for row in ds:\n text = row.get(text_field, \"\")\n if isinstance(text, str) and text.strip():\n texts.append(text.strip())\n if len(texts) >= num_texts:\n break\n if not texts:\n raise ValueError(\"No texts found for quantization.\")\n return texts\n\n\ndef quantize_with_dynamic(model_id: str, calibration_texts: list[str], out_dir: str, dynamic: dict | None, bits: int = 4, group_size: int = 128, batch_size: int = 4, sym: bool = True, desc_act: bool = True) -> dict:\n qcfg = QuantizeConfig(bits=bits, group_size=group_size, sym=sym, desc_act=desc_act, dynamic=dynamic)\n model = GPTQModel.load(model_id, qcfg)\n model.quantize(calibration_texts, batch_size=batch_size)\n out = Path(out_dir)\n out.mkdir(parents=True, exist_ok=True)\n model.save(str(out))\n meta = {\"model_id\": model_id, \"out_dir\": str(out), \"bits\": bits, \"group_size\": group_size, \"batch_size\": batch_size, \"dynamic_rules\": 0 if dynamic is None else len(dynamic)}\n (out / \"quant_meta.json\").write_text(json.dumps(meta, indent=2), encoding=\"utf-8\")\n del model\n if torch.cuda.is_available():\n torch.cuda.empty_cache()\n return meta\n\n\ndef heuristic_dynamic_from_flags(flags: list[str]) -> dict:\n dynamic = {}\n for flag in flags:\n if flag == \"skip_lm_head\":\n dynamic[r\"-:^lm_head$\"] = {}\n elif flag == \"down_proj_g64\":\n dynamic[r\"+:.*down_proj$\"] = {\"bits\": 4, \"group_size\": 64}\n elif flag == \"o_proj_g64\":\n dynamic[r\"+:.*o_proj$\"] = {\"bits\": 4, \"group_size\": 64}\n elif flag == \"edge_blocks_w8\":\n dynamic[r\"+:^model\\\\.layers\\\\.(0|1|30|31)\\\\..*$\"] = {\"bits\": 8, \"group_size\": 128}\n return dynamic\n\n\ndef run_lm_eval(model_path: str, tasks: list[str], output_json: str, device: str = \"cuda:0\", batch_size: str = \"auto\") -> dict:\n output_path = Path(output_json)\n output_path.parent.mkdir(parents=True, exist_ok=True)\n cmd = [\n sys.executable, \"-m\", \"lm_eval\",\n \"--model\", \"hf\",\n \"--model_args\", f\"pretrained={model_path},trust_remote_code=True\",\n \"--tasks\", \",\".join(tasks),\n \"--device\", device,\n \"--batch_size\", batch_size,\n \"--output_path\", str(output_path.parent),\n ]\n subprocess.run(cmd, check=True)\n for path in sorted(output_path.parent.rglob(\"*.json\")):\n try:\n data = json.loads(path.read_text(encoding=\"utf-8\"))\n except Exception:\n continue\n if isinstance(data, dict) and \"results\" in data:\n output_path.write_text(json.dumps(data, indent=2), encoding=\"utf-8\")\n return data\n raise FileNotFoundError(\"Could not find lm_eval results json.\")\n\n\ndef extract_primary_scores(eval_json: dict) -> dict:\n out = {}\n for task, metrics in eval_json.get(\"results\", {}).items():\n if \"acc_norm,none\" in metrics:\n out[task] = metrics[\"acc_norm,none\"]\n elif \"acc,none\" in metrics:\n out[task] = metrics[\"acc,none\"]\n elif \"word_perplexity,none\" in metrics:\n out[task] = metrics[\"word_perplexity,none\"]\n return out\n" }, { "path": "quant_judge/pipeline.py", "content": "from __future__ import annotations\n\nimport argparse\nimport json\nfrom pathlib import Path\n\nfrom .compiler import compile_plan, plan_to_dynamic, save_plan_and_dynamic\nfrom .judge import JudgeConfig, load_rows, make_judge_proposals\nfrom .probe import ProbeConfig, run_probe\nfrom .runner import extract_primary_scores, heuristic_dynamic_from_flags, load_texts, quantize_with_dynamic, run_lm_eval\n\n\ndef cmd_probe(args):\n cfg = ProbeConfig(\n model_id=args.model_id,\n out_dir=args.out_dir,\n dataset_name=args.dataset_name,\n dataset_config=args.dataset_config,\n split=args.split,\n text_field=args.text_field,\n num_texts=args.num_texts,\n max_length=args.max_length,\n max_rows_per_module=args.max_rows_per_module,\n include_modules=[x.strip() for x in args.include_modules.split(\",\") if x.strip()],\n )\n print(json.dumps(run_probe(cfg), indent=2))\n\n\ndef cmd_plan(args):\n judge_cfg = JudgeConfig(\n features_jsonl=args.features_jsonl,\n budget_mb_extra=args.budget_mb_extra,\n heuristic_only=args.heuristic_only,\n judge_base_url=args.judge_base_url,\n judge_model=args.judge_model,\n judge_api_key_env=args.judge_api_key_env,\n chunk_size=args.chunk_size,\n )\n proposals = make_judge_proposals(judge_cfg)\n probe_rows = load_rows(args.features_jsonl)\n plan = compile_plan(proposals[\"assignments\"], probe_rows, args.budget_mb_extra, args.min_roi)\n dynamic = plan_to_dynamic(plan)\n save_plan_and_dynamic(plan, dynamic, args.plan_out, args.dynamic_out)\n print(json.dumps({\"plan_out\": args.plan_out, \"dynamic_out\": args.dynamic_out, \"used_mb_extra\": plan[\"used_mb_extra\"], \"upgrades_approved\": plan[\"upgrades_approved\"], \"dynamic_rules\": len(dynamic)}, indent=2))\n\n\ndef cmd_quantize(args):\n calibration_texts = load_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts)\n plan = json.loads(Path(args.plan_json).read_text(encoding=\"utf-8\"))\n dynamic = plan_to_dynamic(plan)\n print(json.dumps(quantize_with_dynamic(args.model_id, calibration_texts, args.out_dir, dynamic, args.base_bits, args.base_group_size, args.batch_size), indent=2))\n\n\ndef cmd_eval(args):\n tasks = [x.strip() for x in args.tasks.split(\",\") if x.strip()]\n result = run_lm_eval(args.model_path, tasks, args.output_json, args.device, args.batch_size)\n print(json.dumps({\"output_json\": args.output_json, \"scores\": extract_primary_scores(result)}, indent=2))\n\n\ndef cmd_suite(args):\n work_dir = Path(args.work_dir)\n work_dir.mkdir(parents=True, exist_ok=True)\n probe_dir = work_dir / \"probe\"\n\n probe_result = run_probe(ProbeConfig(\n model_id=args.model_id,\n out_dir=str(probe_dir),\n dataset_name=args.dataset_name,\n dataset_config=args.dataset_config,\n split=args.split,\n text_field=args.text_field,\n num_texts=args.probe_num_texts,\n max_length=args.max_length,\n max_rows_per_module=args.max_rows_per_module,\n include_modules=[x.strip() for x in args.include_modules.split(\",\") if x.strip()],\n ))\n features_jsonl = probe_result[\"features_jsonl\"]\n probe_rows = load_rows(features_jsonl)\n\n judge_cfg = JudgeConfig(\n features_jsonl=features_jsonl,\n budget_mb_extra=args.budget_mb_extra,\n heuristic_only=(not args.judge_base_url or not args.judge_model),\n judge_base_url=args.judge_base_url,\n judge_model=args.judge_model,\n judge_api_key_env=args.judge_api_key_env,\n chunk_size=args.chunk_size,\n )\n judge_proposals = make_judge_proposals(judge_cfg)\n judge_plan = compile_plan(judge_proposals[\"assignments\"], probe_rows, args.budget_mb_extra, args.min_roi)\n judge_dynamic = plan_to_dynamic(judge_plan)\n save_plan_and_dynamic(judge_plan, judge_dynamic, str(work_dir / \"judge_plan.json\"), str(work_dir / \"judge_dynamic.json\"))\n\n heuristic_flags = [x.strip() for x in args.heuristic_include.split(\",\") if x.strip()]\n heuristic_dynamic = heuristic_dynamic_from_flags(heuristic_flags)\n (work_dir / \"heuristic_dynamic.json\").write_text(json.dumps(heuristic_dynamic, indent=2), encoding=\"utf-8\")\n\n calibration_texts = load_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts)\n tasks = [x.strip() for x in args.tasks.split(\",\") if x.strip()]\n\n configs = {\n \"baseline_w4g128\": None,\n \"heuristic_mixed\": heuristic_dynamic,\n \"judge_sdq\": judge_dynamic,\n }\n\n leaderboard = {}\n for name, dynamic in configs.items():\n model_dir = work_dir / \"models\" / name\n eval_json = work_dir / \"eval\" / f\"{name}.json\"\n quantize_with_dynamic(args.model_id, calibration_texts, str(model_dir), dynamic, args.base_bits, args.base_group_size, args.batch_size)\n result = run_lm_eval(str(model_dir), tasks, str(eval_json), args.device, args.eval_batch_size)\n leaderboard[name] = extract_primary_scores(result)\n\n payload = {\n \"probe\": probe_result,\n \"leaderboard\": leaderboard,\n \"judge_used\": bool(args.judge_base_url and args.judge_model),\n \"judge_plan_path\": str(work_dir / \"judge_plan.json\"),\n \"judge_dynamic_path\": str(work_dir / \"judge_dynamic.json\"),\n \"heuristic_dynamic_path\": str(work_dir / \"heuristic_dynamic.json\"),\n }\n (work_dir / \"leaderboard.json\").write_text(json.dumps(payload, indent=2), encoding=\"utf-8\")\n print(json.dumps(payload, indent=2))\n\n\ndef build_parser():\n p = argparse.ArgumentParser(prog=\"quant_judge\")\n sub = p.add_subparsers(dest=\"cmd\", required=True)\n\n probe = sub.add_parser(\"probe\")\n probe.add_argument(\"--model-id\", required=True)\n probe.add_argument(\"--out-dir\", required=True)\n probe.add_argument(\"--dataset-name\", default=\"wikitext\")\n probe.add_argument(\"--dataset-config\", default=\"wikitext-2-raw-v1\")\n probe.add_argument(\"--split\", default=\"train\")\n probe.add_argument(\"--text-field\", default=\"text\")\n probe.add_argument(\"--num-texts\", type=int, default=64)\n probe.add_argument(\"--max-length\", type=int, default=256)\n probe.add_argument(\"--max-rows-per-module\", type=int, default=128)\n probe.add_argument(\"--include-modules\", default=\"q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head\")\n probe.set_defaults(func=cmd_probe)\n\n plan = sub.add_parser(\"plan\")\n plan.add_argument(\"--features-jsonl\", required=True)\n plan.add_argument(\"--plan-out\", required=True)\n plan.add_argument(\"--dynamic-out\", required=True)\n plan.add_argument(\"--budget-mb-extra\", type=float, default=150.0)\n plan.add_argument(\"--min-roi\", type=float, default=1e-6)\n plan.add_argument(\"--heuristic-only\", action=\"store_true\")\n plan.add_argument(\"--judge-base-url\", default=None)\n plan.add_argument(\"--judge-model\", default=None)\n plan.add_argument(\"--judge-api-key-env\", default=\"JUDGE_API_KEY\")\n plan.add_argument(\"--chunk-size\", type=int, default=24)\n plan.set_defaults(func=cmd_plan)\n\n quant = sub.add_parser(\"quantize\")\n quant.add_argument(\"--model-id\", required=True)\n quant.add_argument(\"--plan-json\", required=True)\n quant.add_argument(\"--out-dir\", required=True)\n quant.add_argument(\"--dataset-name\", default=\"wikitext\")\n quant.add_argument(\"--dataset-config\", default=\"wikitext-2-raw-v1\")\n quant.add_argument(\"--split\", default=\"train\")\n quant.add_argument(\"--text-field\", default=\"text\")\n quant.add_argument(\"--num-texts\", type=int, default=128)\n quant.add_argument(\"--batch-size\", type=int, default=4)\n quant.add_argument(\"--base-bits\", type=int, default=4)\n quant.add_argument(\"--base-group-size\", type=int, default=128)\n quant.set_defaults(func=cmd_quantize)\n\n ev = sub.add_parser(\"eval\")\n ev.add_argument(\"--model-path\", required=True)\n ev.add_argument(\"--tasks\", default=\"arc_challenge,hellaswag\")\n ev.add_argument(\"--output-json\", required=True)\n ev.add_argument(\"--device\", default=\"cuda:0\")\n ev.add_argument(\"--batch-size\", default=\"auto\")\n ev.set_defaults(func=cmd_eval)\n\n suite = sub.add_parser(\"suite\")\n suite.add_argument(\"--model-id\", required=True)\n suite.add_argument(\"--work-dir\", required=True)\n suite.add_argument(\"--dataset-name\", default=\"wikitext\")\n suite.add_argument(\"--dataset-config\", default=\"wikitext-2-raw-v1\")\n suite.add_argument(\"--split\", default=\"train\")\n suite.add_argument(\"--text-field\", default=\"text\")\n suite.add_argument(\"--probe-num-texts\", type=int, default=64)\n suite.add_argument(\"--quant-num-texts\", type=int, default=128)\n suite.add_argument(\"--max-length\", type=int, default=256)\n suite.add_argument(\"--max-rows-per-module\", type=int, default=128)\n suite.add_argument(\"--include-modules\", default=\"q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head\")\n suite.add_argument(\"--budget-mb-extra\", type=float, default=150.0)\n suite.add_argument(\"--min-roi\", type=float, default=1e-6)\n suite.add_argument(\"--heuristic-include\", default=\"skip_lm_head,down_proj_g64\")\n suite.add_argument(\"--judge-base-url\", default=None)\n suite.add_argument(\"--judge-model\", default=None)\n suite.add_argument(\"--judge-api-key-env\", default=\"JUDGE_API_KEY\")\n suite.add_argument(\"--chunk-size\", type=int, default=24)\n suite.add_argument(\"--tasks\", default=\"arc_challenge,hellaswag\")\n suite.add_argument(\"--device\", default=\"cuda:0\")\n suite.add_argument(\"--batch-size\", type=int, default=4)\n suite.add_argument(\"--eval-batch-size\", default=\"auto\")\n suite.add_argument(\"--base-bits\", type=int, default=4)\n suite.add_argument(\"--base-group-size\", type=int, default=128)\n suite.set_defaults(func=cmd_suite)\n\n return p\n\n\ndef main():\n parser = build_parser()\n args = parser.parse_args()\n args.func(args)\n\n\nif __name__ == \"__main__\":\n main()\n" } ]} +``` + +--- + +### User: +Here is a detailed critique of the `sdq_quant_judge_qwen35_gptqmodel` project. Overall, the scaffold is well-structured and uses a clever approach (fake-quant probing combined with an LLM judge/heuristic compiler to guide GPTQModel). + +However, there are several critical logical disconnects, mathematical approximations, and workflow inconsistencies that will prevent it from working as intended out-of-the-box. + +### 1. Critical Bugs & Logical Flaws + +**A. The Budget Compiler Silently Defeats the "Skip" Rule** +In `judge.py`, you have a specific heuristic to skip `lm_head` if it exhibits high outliers. However, `lm_head` is huge (e.g., in Qwen3.5-4B it is `151936 * 3072 ≈ 466M` parameters). +* Upgrading `lm_head` from W4 (0.5 bytes/param) to Skip (2 bytes/param) costs `~700 MB`. +* Your default `budget_mb_extra` in the CLI is `150.0`. +* In `compiler.py`, the logic dictates: `if spent_mb + up["cost_mb"] > extra_budget_mb: continue`. +**Result:** The compiler will *always* reject skipping the `lm_head` under default settings. The user will be confused as to why the `lm_head_extreme` rule never actually skips the layer. +* **Fix:** Either evaluate `lm_head` skip *outside* the ROI budget calculation (treat skipping the embedding/head as a structural baseline rather than a budget upgrade), or increase the default budget significantly (e.g., `1000 MB`). + +**B. `cmd_suite` Inconsistently Evaluates Baselines** +The `suite` command is meant to compare baselines, heuristics, and the judge. However: +1. It assigns `heuristic_mixed` to `heuristic_dynamic_from_flags`, which uses **hardcoded regexes** (e.g., `model.layers.(0|1|30|31)` to W8) and completely ignores the fake-quant metrics in `probe.py`. +2. Meanwhile, if no OpenAI API key is provided, `judge_proposals` falls back to `heuristic_only=True`. This means the `judge_sdq` run will actually execute your data-driven `heuristic_decision(row)` from `judge.py`! +**Result:** The benchmark leaderboard is heavily skewed. It compares a regex baseline to a data-driven heuristic disguised under the label "judge_sdq". The logic for generating the heuristic run in the suite should leverage the output of `make_judge_proposals(heuristic_only=True)` rather than `heuristic_dynamic_from_flags`. + +### 2. Mathematical & Algorithmic Issues + +**A. Inaccurate Parameter Size Calculations** +In `actions.py`, `bytes_per_param()` calculates size strictly as `bits / 8.0`. It completely ignores quantization metadata (scales and zero-points). +* A W4G128 config vs a W4G64 config have the *exact same* bits per parameter (4 bits), but W4G64 has double the number of scales and zeros. +* To patch this, you added a hack in `probe.py`: `+ (params * 0.05 / (1024 ** 2))`. +* **Fix:** Calculate group size overhead explicitly. Assuming 16-bit scales and 4-bit zeros (per group): + +``` +python + def bytes_per_param(action: str, fp_bytes: float = 2.0) -> float: + if action == "skip": return fp_bytes + spec = ACTION_SPECS[action] + base = spec.bits / 8.0 + metadata_per_group = 2.0 + 0.5 # 16-bit scale + 4-bit zero = 2.5 bytes + group_overhead = metadata_per_group / spec.group_size + return base + group_overhead +``` + +**B. Fake Quantization (RTN) vs. Real Quantization (GPTQ)** +In `probe.py`, `fake_quant_groupwise_linear_weight` implements a standard Round-To-Nearest (RTN) symmetric quantization using `w.abs().amax()`. +* **Issue:** GPTQ uses second-order information (the inverse Hessian of activations) to iteratively update weights, which often mitigates errors that RTN catches. By using RTN to calculate MSE, your "probe" might penalize modules that RTN struggles with, but that GPTQ would actually handle perfectly fine. +* **Fix:** While implementing a full Hessian-based probe is likely too slow, integrating an Activation-Aware Weight Quantization (AWQ) style scaling probe (multiplying weights by activation scales before RTN) would be a much more accurate proxy for GPTQ degradation. + +### 3. Data Processing & Performance + +**A. Unbatched Forward Passes** +In `probe.py`: + +``` +python +for text in tqdm(texts, desc="collect_activations"): + batch = tokenizer(...) + _ = model(**batch) +``` + +You are executing forward passes on 64 texts of length 256 *one by one*. For smaller models this is tolerable, but for larger models, it is very inefficient. +* **Fix:** Batch the tokenized inputs and run them through the model. + +**B. Memory/Precision Loss in Activation Capture** +In `make_hook` (in `probe.py`), you cast captured activations to `torch.float16` and move them to the CPU. + +``` +python +captured_inputs[name].append(x[:take].detach().cpu().to(torch.float16)) +``` + +If the source model uses `bfloat16`, casting to `float16` can cause outliers to overflow/clip (FP16 maxes out at ~65504), which directly undermines your `activation_outlier_rate` metric. +* **Fix:** Keep the activations in their native dtype when moving them to the CPU (or explicitly cast to `torch.float32` if memory allows, since you are limiting the rows to 128 anyway). + +### 4. Safety and Robustness + +**A. LLM Parsing Fragility** +In `judge.py`, `call_openai_compatible` relies on basic string manipulation (`strip_code_fence`) and `json.loads` to extract the output. LLMs often prefix outputs with `Here is the JSON you requested...` which will break the parser if the code fence isn't strictly the first thing. +* **Praise:** The fallback mechanism in `merge_assignments` (which applies the heuristic for any modules the LLM failed to return or parse) is excellent defensively. +* **Fix:** Instead of raw parsing, use OpenAI's newer `response_format: {"type": "json_schema", ...}` to enforce the JSON structure rigorously, or use a library like `instructor` / `pydantic` to validate the payload. + +**B. Division by Zero Risk in ROI** +In `compiler.py`, the ROI is `benefit / cost_mb`. +While `probe.py` forces `max(0.0, ...)` on extra MB, if the `cost_mb` evaluates to exactly `0.0` (which is highly likely due to the size calculation bug mentioned in 2A if W4G64 vs W4G128 round the same way), `benefit / cost_mb` will throw a `ZeroDivisionError`. +* **Fix:** Clamp the denominator: `roi = benefit / max(cost_mb, 1e-6)`. + +--- + +### ChatGPT: +This critique is good. The biggest blockers are real, and I’d change the project before trusting any leaderboard it produces. + +What I’d keep as-is is the overall architecture. What I’d change is the control semantics, the cost model, the suite labeling, and parts of the probe/runtime. + +## The most important fixes + +### 1) lm_head skip should not compete inside a 150 MB budget + +You’re right: with a 4B-class model, `lm_head` skip is structurally expensive enough that the default compiler will almost always reject it. + +The clean fix is to split decisions into two classes: + +- **structural overrides**: things like `skip lm_head` +- **budgeted upgrades**: `w4_g64`, `w8_g128`, other skip candidates + +That means the compiler should apply structural rules first, then spend the remaining budget on ROI-ranked upgrades. + +Use this logic in `compiler.py`: + +```Pythondef is_structural_skip(name: str, action: str) -> bool: return action == "skip" and name == "lm_head"def compile_plan(assignments, probe_rows, extra_budget_mb, min_roi=1e-6): probe_by_name = {row["name"]: row for row in probe_rows} structural = [] upgrades = [] for item in assignments: name = item["name"] action = item["action"] row = probe_by_name.get(name) if row is None: continue if is_structural_skip(name, action): structural.append(item) continue if action == "w4_g128": continue if action == "w8_g128": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w8_g128"]["mse"] cost_mb = row["candidates"]["w8_g128"]["estimated_extra_mb_vs_base"] elif action == "w4_g64": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w4_g64"]["mse"] cost_mb = row["candidates"]["w4_g64"]["estimated_extra_mb_vs_base"] elif action == "skip": benefit = row["candidates"]["w4_g128"]["mse"] cost_mb = row["candidates"]["skip"]["estimated_extra_mb_vs_base"] else: continue if benefit <= 0 or cost_mb <= 0: continue roi = benefit / max(cost_mb, 1e-6) if roi < min_roi: continue upgrades.append({ "name": name, "action": action, "benefit": float(benefit), "cost_mb": float(cost_mb), "roi": float(roi), "priority": float(item.get("priority", 0.5)), "reason_code": item.get("reason_code", ""), "source": item.get("source", ""), }) upgrades.sort(key=lambda x: (x["roi"], x["priority"]), reverse=True) approved = [] spent_mb = 0.0 for up in upgrades: if spent_mb + up["cost_mb"] > extra_budget_mb: continue approved.append(up) spent_mb += up["cost_mb"] final_assignments = [] structural_by_name = {x["name"]: x for x in structural} approved_by_name = {x["name"]: x for x in approved} for row in probe_rows: name = row["name"] if name in structural_by_name: item = structural_by_name[name] final_assignments.append({ "name": name, "action": "skip", "priority": float(item.get("priority", 1.0)), "reason_code": item.get("reason_code", "structural_skip"), "source": item.get("source", "compiler"), "estimated_extra_mb_vs_base": row["candidates"]["skip"]["estimated_extra_mb_vs_base"], "structural": True, }) elif name in approved_by_name: up = approved_by_name[name] final_assignments.append({ "name": name, "action": up["action"], "priority": up["priority"], "reason_code": up["reason_code"], "source": up["source"], "estimated_extra_mb_vs_base": up["cost_mb"], "benefit": up["benefit"], "roi": up["roi"], "structural": False, }) else: final_assignments.append({ "name": name, "action": "w4_g128", "priority": 0.0, "reason_code": "base_default", "source": "compiler", "estimated_extra_mb_vs_base": 0.0, "structural": False, }) return { "budget_mb_extra": float(extra_budget_mb), "used_mb_extra": round(float(spent_mb), 6), "structural_skips": len(structural), "upgrades_approved": len(approved), "assignments": sorted(final_assignments, key=lambda x: x["name"]), } +Pythondef is_structural_skip(name: str, action: str) -> bool: return action == "skip" and name == "lm_head"def compile_plan(assignments, probe_rows, extra_budget_mb, min_roi=1e-6): probe_by_name = {row["name"]: row for row in probe_rows} structural = [] upgrades = [] for item in assignments: name = item["name"] action = item["action"] row = probe_by_name.get(name) if row is None: continue if is_structural_skip(name, action): structural.append(item) continue if action == "w4_g128": continue if action == "w8_g128": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w8_g128"]["mse"] cost_mb = row["candidates"]["w8_g128"]["estimated_extra_mb_vs_base"] elif action == "w4_g64": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w4_g64"]["mse"] cost_mb = row["candidates"]["w4_g64"]["estimated_extra_mb_vs_base"] elif action == "skip": benefit = row["candidates"]["w4_g128"]["mse"] cost_mb = row["candidates"]["skip"]["estimated_extra_mb_vs_base"] else: continue if benefit <= 0 or cost_mb <= 0: continue roi = benefit / max(cost_mb, 1e-6) if roi < min_roi: continue upgrades.append({ "name": name, "action": action, "benefit": float(benefit), "cost_mb": float(cost_mb), "roi": float(roi), "priority": float(item.get("priority", 0.5)), "reason_code": item.get("reason_code", ""), "source": item.get("source", ""), }) upgrades.sort(key=lambda x: (x["roi"], x["priority"]), reverse=True) approved = [] spent_mb = 0.0 for up in upgrades: if spent_mb + up["cost_mb"] > extra_budget_mb: continue approved.append(up) spent_mb += up["cost_mb"] final_assignments = [] structural_by_name = {x["name"]: x for x in structural} approved_by_name = {x["name"]: x for x in approved} for row in probe_rows: name = row["name"] if name in structural_by_name: item = structural_by_name[name] final_assignments.append({ "name": name, "action": "skip", "priority": float(item.get("priority", 1.0)), "reason_code": item.get("reason_code", "structural_skip"), "source": item.get("source", "compiler"), "estimated_extra_mb_vs_base": row["candidates"]["skip"]["estimated_extra_mb_vs_base"], "structural": True, }) elif name in approved_by_name: up = approved_by_name[name] final_assignments.append({ "name": name, "action": up["action"], "priority": up["priority"], "reason_code": up["reason_code"], "source": up["source"], "estimated_extra_mb_vs_base": up["cost_mb"], "benefit": up["benefit"], "roi": up["roi"], "structural": False, }) else: final_assignments.append({ "name": name, "action": "w4_g128", "priority": 0.0, "reason_code": "base_default", "source": "compiler", "estimated_extra_mb_vs_base": 0.0, "structural": False, }) return { "budget_mb_extra": float(extra_budget_mb), "used_mb_extra": round(float(spent_mb), 6), "structural_skips": len(structural), "upgrades_approved": len(approved), "assignments": sorted(final_assignments, key=lambda x: x["name"]), } +``` + +That makes `lm_head` skip explicit instead of silently impossible. + +### 2) The suite currently compares mismatched baselines + +You’re also right here. The current suite compares: + +- a regex rule baseline +- a data-driven heuristic under the “judge” label when no external judge is configured + +That is misleading. + +The fix is: + +- keep the regex one, but call it `rule_based_baseline` +- add a separate `data_driven_heuristic` +- only run `judge_sdq` if an external judge endpoint is actually configured + +So `cmd_suite` should become: + +```Pythonrule_based_dynamic = heuristic_dynamic_from_flags(heuristic_flags)heuristic_proposals = make_judge_proposals( JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=True, chunk_size=args.chunk_size, ))heuristic_plan = compile_plan( heuristic_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi,)data_driven_heuristic_dynamic = plan_to_dynamic(heuristic_plan)configs = { "baseline_w4g128": None, "rule_based_baseline": rule_based_dynamic, "data_driven_heuristic": data_driven_heuristic_dynamic,}if args.judge_base_url and args.judge_model: judge_proposals = make_judge_proposals( JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=False, judge_base_url=args.judge_base_url, judge_model=args.judge_model, judge_api_key_env=args.judge_api_key_env, chunk_size=args.chunk_size, ) ) judge_plan = compile_plan(judge_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi) configs["judge_sdq"] = plan_to_dynamic(judge_plan) +Pythonrule_based_dynamic = heuristic_dynamic_from_flags(heuristic_flags)heuristic_proposals = make_judge_proposals( JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=True, chunk_size=args.chunk_size, ))heuristic_plan = compile_plan( heuristic_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi,)data_driven_heuristic_dynamic = plan_to_dynamic(heuristic_plan)configs = { "baseline_w4g128": None, "rule_based_baseline": rule_based_dynamic, "data_driven_heuristic": data_driven_heuristic_dynamic,}if args.judge_base_url and args.judge_model: judge_proposals = make_judge_proposals( JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=False, judge_base_url=args.judge_base_url, judge_model=args.judge_model, judge_api_key_env=args.judge_api_key_env, chunk_size=args.chunk_size, ) ) judge_plan = compile_plan(judge_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi) configs["judge_sdq"] = plan_to_dynamic(judge_plan) +``` + +That makes the leaderboard honest. + +### 3) The parameter-size model is too loose + +I agree with your criticism, with one nuance: the exact metadata overhead depends on the real pack format and whether the backend stores zeros for symmetric quant. + +So I would not hardcode one “true” formula unless you have confirmed GPTQModel’s current serialization format for your chosen config. But I would definitely replace the current hack with a configurable overhead model. + +A safer version: + +```Pythondef bytes_per_param(action: str, fp_bytes: float = 2.0, scale_bytes: float = 2.0, zero_bytes: float = 0.0) -> float: if action == "skip": return fp_bytes spec = ACTION_SPECS[action] assert spec.bits is not None and spec.group_size is not None base = spec.bits / 8.0 metadata_per_group = scale_bytes + zero_bytes overhead = metadata_per_group / spec.group_size return base + overhead +Pythondef bytes_per_param(action: str, fp_bytes: float = 2.0, scale_bytes: float = 2.0, zero_bytes: float = 0.0) -> float: if action == "skip": return fp_bytes spec = ACTION_SPECS[action] assert spec.bits is not None and spec.group_size is not None base = spec.bits / 8.0 metadata_per_group = scale_bytes + zero_bytes overhead = metadata_per_group / spec.group_size return base + overhead +``` + +Then wire it like this: + +- `sym=True` → `zero_bytes=0.0` as a lower-overhead estimate +- asymmetric / packed-zero path → `zero_bytes=0.5` or whatever the actual backend uses + +That gives you a real cost model instead of `+ 0.05 MB`. + +### 4) RTN fake-quant is still only a proxy + +Yes. This is not GPTQ. It is a ranking proxy. + +I would not try to sell it as “highly accurate GPTQ loss.” I’d describe it as: + +per-layer RTN-style sensitivity estimate with group geometry matched to the target quantizer + +That said, your suggestion is right: an activation-aware probe would be better than plain RTN. + +The lightest useful upgrade is: + +- compute per-input-channel activation magnitude from the captured rows +- scale weights by that channel salience before fake quant +- unscale afterward + +Something like: + +```Pythondef activation_aware_fake_quant(weight, x_rows, bits, group_size): # x_rows shape: [N, in_features] salience = x_rows.float().abs().mean(dim=0).clamp_min(1e-6) salience = salience / salience.mean() w_scaled = weight.float() * salience.unsqueeze(0) wq_scaled = fake_quant_groupwise_linear_weight(w_scaled, bits=bits, group_size=group_size) wq = wq_scaled / salience.unsqueeze(0) return wq.to(weight.device, dtype=weight.dtype) +Pythondef activation_aware_fake_quant(weight, x_rows, bits, group_size): # x_rows shape: [N, in_features] salience = x_rows.float().abs().mean(dim=0).clamp_min(1e-6) salience = salience / salience.mean() w_scaled = weight.float() * salience.unsqueeze(0) wq_scaled = fake_quant_groupwise_linear_weight(w_scaled, bits=bits, group_size=group_size) wq = wq_scaled / salience.unsqueeze(0) return wq.to(weight.device, dtype=weight.dtype) +``` + +That is still not GPTQ, but it is a materially better proxy for fragile channels. + +### 5) Probe collection should be batched + +Agreed. + +Right now the probe is doing one forward per text, which is wasteful. The simplest fix is to batch texts before tokenization. + +Example: + +```Pythondef batched(iterable, batch_size): for i in range(0, len(iterable), batch_size): yield iterable[i:i + batch_size]with torch.no_grad(): for batch_texts in tqdm(list(batched(texts, 8)), desc="collect_activations"): batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=cfg.max_length, padding=True, ) batch = {k: v.to(model.device) for k, v in batch.items()} _ = model(**batch) if all(v <= 0 for v in remaining_rows.values()): break +Pythondef batched(iterable, batch_size): for i in range(0, len(iterable), batch_size): yield iterable[i:i + batch_size]with torch.no_grad(): for batch_texts in tqdm(list(batched(texts, 8)), desc="collect_activations"): batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=cfg.max_length, padding=True, ) batch = {k: v.to(model.device) for k, v in batch.items()} _ = model(**batch) if all(v <= 0 for v in remaining_rows.values()): break +``` + +That speeds the probe up substantially. + +### 6) Preserve activation dtype when capturing + +Yes. This is a good catch. + +Casting BF16 activations to FP16 on CPU can distort outlier statistics. The clean change is either: + +- preserve original dtype, or +- normalize everything to FP32 on CPU + +Since you are row-limiting aggressively, FP32 is safer: + +```Pythoncaptured_inputs[name].append(x[:take].detach().cpu().to(torch.float32)) +Pythoncaptured_inputs[name].append(x[:take].detach().cpu().to(torch.float32)) +``` + +Then when feeding the module probe back into the layer, cast to the module dtype on device. + +### 7) LLM parsing needs harder validation + +Also correct. + +One caveat: many OpenAI-compatible endpoints do **not** fully support `json_schema`, even if official OpenAI endpoints do. So I’d implement a two-layer strategy: + +- try `json_schema` if the backend supports it +- otherwise fall back to `json_object` +- validate with Pydantic after parsing + +At minimum, define a strict validator: + +```Pythonfrom pydantic import BaseModel, Fieldfrom typing import Literalclass Assignment(BaseModel): name: str action: Literal["w4_g128", "w4_g64", "w8_g128", "skip"] priority: float = Field(ge=0.0, le=1.0) reason_code: strclass AssignmentPayload(BaseModel): assignments: list[Assignment] +Pythonfrom pydantic import BaseModel, Fieldfrom typing import Literalclass Assignment(BaseModel): name: str action: Literal["w4_g128", "w4_g64", "w8_g128", "skip"] priority: float = Field(ge=0.0, le=1.0) reason_code: strclass AssignmentPayload(BaseModel): assignments: list[Assignment] +``` + +Then parse and validate after `json.loads`. + +### 8) roi = benefit / cost_mb should clamp denominator + +Yes. That one should be fixed immediately. + +Use: + +```Pythonroi = benefit / max(cost_mb, 1e-6) +Pythonroi = benefit / max(cost_mb, 1e-6) +``` + +Even after the size model is fixed, that clamp is still the right defensive behavior. + +## My verdict on the critique + +These are the real severity levels I’d assign: + +- **must-fix before trusting results** + +suite inconsistency + +silent `lm_head` skip defeat + +denominator clamp + +activation dtype handling +- suite inconsistency +- silent `lm_head` skip defeat +- denominator clamp +- activation dtype handling +- **should-fix before calling the budget model credible** + +metadata-aware bytes-per-param + +batch the probe +- metadata-aware bytes-per-param +- batch the probe +- **nice next upgrade** + +activation-aware fake quant + +stronger structured LLM validation +- activation-aware fake quant +- stronger structured LLM validation + +So the project is **not broken conceptually**, but it is still **too inconsistent to produce a trustworthy leaderboard without these fixes**. + +The cleanest next step is to ship a **v4 JSON project** with: + +- structural skip support +- honest suite labels +- metadata-aware cost model +- batched probing +- FP32 activation capture +- safer judge validation + +I can produce that as a single JSON artifact next. + +--- + +### User: +yes + +--- + +### ChatGPT: +Here’s the v4 JSON project artifact with the fixes baked in: + +It includes: + +- structural `lm_head` skip handling outside the ROI budget +- honest suite labels +- metadata-aware bytes-per-param +- batched probing +- FP32 activation capture +- Pydantic-validated judge output +- clamped ROI denominator +- optional activation-aware fake quant proxy + +If you want, next I can also turn this JSON artifact into a ready-to-extract bash script that recreates the full folder locally in one command. + +--- + +### User: +i cant download it, provide it in plain text please + +--- + +### ChatGPT: +Run this in an empty folder. It will recreate the full v4 project locally. + +```Bashmkdir -p sdq_quant_judge_v4/quant_judgecd sdq_quant_judge_v4cat > README.md <<'EOF'# SDQ Quant Judge v4 for GPTQModelA software-defined quantization scaffold for `Qwen/Qwen3.5-4B-Base` using:- a real FP probe via `transformers`- per-row, per-group fake quantization- optional activation-aware proxy quantization- an external OpenAI-compatible judge- a deterministic compiler with structural skips + ROI budget clamp- GPTQModel `dynamic` mixed quantization- an honest benchmark suite: - `baseline_w4g128` - `rule_based_baseline` - `data_driven_heuristic` - `judge_sdq` (only when a judge endpoint is actually configured)## InstallCreate a virtualenv and install your CUDA-matching PyTorch manually first.Example:```bashpython -m venv .venvsource .venv/bin/activate# Example only — change to your CUDA build:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +Bashmkdir -p sdq_quant_judge_v4/quant_judgecd sdq_quant_judge_v4cat > README.md <<'EOF'# SDQ Quant Judge v4 for GPTQModelA software-defined quantization scaffold for `Qwen/Qwen3.5-4B-Base` using:- a real FP probe via `transformers`- per-row, per-group fake quantization- optional activation-aware proxy quantization- an external OpenAI-compatible judge- a deterministic compiler with structural skips + ROI budget clamp- GPTQModel `dynamic` mixed quantization- an honest benchmark suite: - `baseline_w4g128` - `rule_based_baseline` - `data_driven_heuristic` - `judge_sdq` (only when a judge endpoint is actually configured)## InstallCreate a virtualenv and install your CUDA-matching PyTorch manually first.Example:```bashpython -m venv .venvsource .venv/bin/activate# Example only — change to your CUDA build:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +``` + +## Probe + +```Bashpython -m quant_judge.pipeline probe \ --model-id Qwen/Qwen3.5-4B-Base \ --out-dir ./runs/probe \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 64 \ --max-length 256 \ --max-rows-per-module 128 \ --probe-batch-size 8 \ --include-modules q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head +Bashpython -m quant_judge.pipeline probe \ --model-id Qwen/Qwen3.5-4B-Base \ --out-dir ./runs/probe \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 64 \ --max-length 256 \ --max-rows-per-module 128 \ --probe-batch-size 8 \ --include-modules q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head +``` + +## Plan with data-driven heuristic only + +```Bashpython -m quant_judge.pipeline plan \ --features-jsonl ./runs/probe/probe_features.jsonl \ --plan-out ./runs/plan.json \ --dynamic-out ./runs/dynamic.json \ --budget-mb-extra 150 \ --heuristic-only +Bashpython -m quant_judge.pipeline plan \ --features-jsonl ./runs/probe/probe_features.jsonl \ --plan-out ./runs/plan.json \ --dynamic-out ./runs/dynamic.json \ --budget-mb-extra 150 \ --heuristic-only +``` + +## Plan with external judge + +```Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline plan \ --features-jsonl ./runs/probe/probe_features.jsonl \ --plan-out ./runs/plan.json \ --dynamic-out ./runs/dynamic.json \ --budget-mb-extra 150 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct +Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline plan \ --features-jsonl ./runs/probe/probe_features.jsonl \ --plan-out ./runs/plan.json \ --dynamic-out ./runs/dynamic.json \ --budget-mb-extra 150 \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct +``` + +## Quantize from plan + +```Bashpython -m quant_judge.pipeline quantize \ --model-id Qwen/Qwen3.5-4B-Base \ --plan-json ./runs/plan.json \ --out-dir ./runs/quantized \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 128 \ --batch-size 4 +Bashpython -m quant_judge.pipeline quantize \ --model-id Qwen/Qwen3.5-4B-Base \ --plan-json ./runs/plan.json \ --out-dir ./runs/quantized \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --num-texts 128 \ --batch-size 4 +``` + +## Evaluate a saved model + +```Bashpython -m quant_judge.pipeline eval \ --model-path ./runs/quantized \ --tasks arc_challenge,hellaswag \ --output-json ./runs/eval.json +Bashpython -m quant_judge.pipeline eval \ --model-path ./runs/quantized \ --tasks arc_challenge,hellaswag \ --output-json ./runs/eval.json +``` + +## Full suite + +This compares: + +- `baseline_w4g128` +- `rule_based_baseline` +- `data_driven_heuristic` +- `judge_sdq` only if a judge endpoint is configured + +Without a judge: + +```Bashpython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag +Bashpython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag +``` + +With a judge: + +```Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct +Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct +``` + +## Notes + +- The fake quant probe is still a proxy, not full GPTQ. +- The optional activation-aware path is closer to reality than plain RTN. +- `lm_head` skip is treated as a structural override outside the ROI budget so it is not silently defeated by a tiny default budget. +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +requests +numpy +tqdm +psutil +pydantic>=2 +gptqmodel +lm-eval +EOF + +cat > quant_judge/**init**.py <<'EOF' +"""SDQ Quant Judge v4.""" +EOF + +cat > quant_judge/actions.py <<'EOF' +from **future** import annotations + +from dataclasses import dataclass + +@dataclass(frozen=True) +class ActionSpec: +name: str +bits: int | None +group_size: int | None +skip: bool + +ACTION_SPECS: dict[str, ActionSpec] = { +"w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), +"w4_g64": ActionSpec("w4_g64", bits=4, group_size=64, skip=False), +"w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), +"skip": ActionSpec("skip", bits=None, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def allowed_actions() -> list[str]: +return list(ACTION_SPECS.keys()) + +def bytes_per_param( +action: str, +*, +fp_bytes: float = 2.0, +sym: bool = True, +scale_bytes: float = 2.0, +zero_bytes_asym: float = 0.5, +) -> float: +""" +Approximate bytes/param including metadata overhead. + +```For quantized actions: base bytes = bits / 8 overhead per param = (scale_bytes + zero_bytes) / group_sizeFor skip: bytes/param = fp_bytesThis is still an approximation, but far more honest than bits/8 alone."""if action == "skip": return fp_bytesspec = ACTION_SPECS[action]assert spec.bits is not Noneassert spec.group_size is not Nonebase = spec.bits / 8.0zero_bytes = 0.0 if sym else zero_bytes_asymoverhead = (scale_bytes + zero_bytes) / spec.group_sizereturn base + overhead +For quantized actions: base bytes = bits / 8 overhead per param = (scale_bytes + zero_bytes) / group_sizeFor skip: bytes/param = fp_bytesThis is still an approximation, but far more honest than bits/8 alone."""if action == "skip": return fp_bytesspec = ACTION_SPECS[action]assert spec.bits is not Noneassert spec.group_size is not Nonebase = spec.bits / 8.0zero_bytes = 0.0 if sym else zero_bytes_asymoverhead = (scale_bytes + zero_bytes) / spec.group_sizereturn base + overhead +``` + +def extra_mb_vs_base( +params: int, +action: str, +*, +sym: bool = True, +fp_bytes: float = 2.0, +scale_bytes: float = 2.0, +zero_bytes_asym: float = 0.5, +) -> float: +base = params * bytes_per_param( +BASE_ACTION, +fp_bytes=fp_bytes, +sym=sym, +scale_bytes=scale_bytes, +zero_bytes_asym=zero_bytes_asym, +) +now = params * bytes_per_param( +action, +fp_bytes=fp_bytes, +sym=sym, +scale_bytes=scale_bytes, +zero_bytes_asym=zero_bytes_asym, +) +return (now - base) / (1024 ** 2) + +def action_to_override(action: str) -> dict: +if action == "skip": +return {} +spec = ACTION_SPECS[action] +return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_judge/probe.py <<'EOF' +from **future** import annotations + +import json +from dataclasses import dataclass +from pathlib import Path + +import torch +import torch.nn.functional as F +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +from .actions import extra_mb_vs_base + +@dataclass +class ProbeConfig: +model_id: str +out_dir: str +dataset_name: str +dataset_config: str | None +split: str +text_field: str +num_texts: int +max_length: int +max_rows_per_module: int +include_modules: list[str] +probe_batch_size: int = 8 +trust_remote_code: bool = True +use_activation_aware_proxy: bool = False +sym: bool = True +scale_bytes: float = 2.0 +zero_bytes_asym: float = 0.5 + +def pick_dtype() -> torch.dtype: +if not torch.cuda.is_available(): +return torch.float32 +major, _minor = torch.cuda.get_device_capability(0) +if major >= 8: +return torch.bfloat16 +return torch.float16 + +def fake_quant_groupwise_linear_weight(weight: torch.Tensor, bits: int, group_size: int) -> torch.Tensor: +""" +Symmetric per-row, per-group fake quantization for Linear weights. +weight shape: [out_features, in_features] +""" +assert weight.dim() == 2 +w = weight.detach().float() +out_features, in_features = w.shape + +```pad = (group_size - (in_features % group_size)) % group_sizeif pad: w = F.pad(w, (0, pad), value=0.0)maxq = (2 ** (bits - 1)) - 1w = w.view(out_features, -1, group_size)absmax = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)scale = absmax / maxqq = torch.round(w / scale).clamp(-maxq, maxq)dq = q * scaledq = dq.view(out_features, -1)if pad: dq = dq[:, :in_features]return dq.to(weight.device, dtype=weight.dtype) +pad = (group_size - (in_features % group_size)) % group_sizeif pad: w = F.pad(w, (0, pad), value=0.0)maxq = (2 ** (bits - 1)) - 1w = w.view(out_features, -1, group_size)absmax = w.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8)scale = absmax / maxqq = torch.round(w / scale).clamp(-maxq, maxq)dq = q * scaledq = dq.view(out_features, -1)if pad: dq = dq[:, :in_features]return dq.to(weight.device, dtype=weight.dtype) +``` + +def activation_aware_fake_quant(weight: torch.Tensor, x_rows: torch.Tensor, bits: int, group_size: int) -> torch.Tensor: +""" +AWQ-ish lightweight proxy: +- compute per-input-channel salience from captured activations +- scale weights by salience +- fake quantize +- unscale +""" +salience = x_rows.float().abs().mean(dim=0).clamp_min(1e-6) +salience = salience / salience.mean() +w_scaled = weight.float() * salience.unsqueeze(0) +wq_scaled = fake_quant_groupwise_linear_weight(w_scaled, bits=bits, group_size=group_size) +wq = wq_scaled / salience.unsqueeze(0) +return wq.to(weight.device, dtype=weight.dtype) + +def load_texts(dataset_name: str, dataset_config: str | None, split: str, text_field: str, num_texts: int) -> list[str]: +ds = load_dataset(dataset_name, dataset_config, split=split) +texts: list[str] = [] +for row in ds: +value = row.get(text_field, "") +if isinstance(value, str) and value.strip(): +texts.append(value.strip()) +if len(texts) >= num_texts: +break +if not texts: +raise ValueError("No usable texts found.") +return texts + +def batched(items: list[str], batch_size: int): +for i in range(0, len(items), batch_size): +yield items[i:i + batch_size] + +def module_selected(name: str, include_modules: list[str]) -> bool: +return any(token in name for token in include_modules) + +def iter_target_linears(model, include_modules: list[str]): +for name, module in model.named_modules(): +if isinstance(module, torch.nn.Linear) and module_selected(name, include_modules): +yield name, module + +def flatten_rows(x: torch.Tensor) -> torch.Tensor: +if x.dim() == 2: +return x +if x.dim() >= 3: +return x.reshape(-1, x.shape[-1]) +return x.unsqueeze(0) + +def sample_rows(x: torch.Tensor, max_rows: int) -> torch.Tensor: +rows = flatten_rows(x) +if rows.shape[0] <= max_rows: +return rows +idx = torch.randperm(rows.shape[0], device=rows.device)[:max_rows] +return rows.index_select(0, idx) + +def weight_kurtosis(weight: torch.Tensor) -> float: +w = weight.detach().float().flatten() +mu = w.mean() +var = ((w - mu) ** 2).mean().clamp_min(1e-12) +return float((((w - mu) ** 4).mean() / (var ** 2)).item()) + +def activation_outlier_rate(x: torch.Tensor, sigma: float = 6.0) -> float: +xf = x.float() +std = xf.std().clamp_min(1e-6) +return float((xf.abs() > sigma * std).float().mean().item()) + +def block_index_from_name(name: str) -> int | None: +parts = name.split(".") +for i, p in enumerate(parts): +if p == "layers" and i + 1 < len(parts): +try: +return int(parts[i + 1]) +except Exception: +return None +return None + +def reduced_metrics_for_action( +x: torch.Tensor, +weight: torch.Tensor, +bias: torch.Tensor | None, +*, +bits: int, +group_size: int, +use_activation_aware_proxy: bool, +) -> dict: +if use_activation_aware_proxy: +wq = activation_aware_fake_quant(weight, x, bits=bits, group_size=group_size) +else: +wq = fake_quant_groupwise_linear_weight(weight, bits=bits, group_size=group_size) + +```out_fp = F.linear(x, weight, bias)out_q = F.linear(x, wq, bias)mse = F.mse_loss(out_q.float(), out_fp.float()).item()cos = F.cosine_similarity(out_q.float(), out_fp.float(), dim=-1).mean().item()max_abs = (out_q.float() - out_fp.float()).abs().max().item()return {"mse": float(mse), "cosine": float(cos), "max_abs": float(max_abs)} +out_fp = F.linear(x, weight, bias)out_q = F.linear(x, wq, bias)mse = F.mse_loss(out_q.float(), out_fp.float()).item()cos = F.cosine_similarity(out_q.float(), out_fp.float(), dim=-1).mean().item()max_abs = (out_q.float() - out_fp.float()).abs().max().item()return {"mse": float(mse), "cosine": float(cos), "max_abs": float(max_abs)} +``` + +def run_probe(cfg: ProbeConfig) -> dict: +out_dir = Path(cfg.out_dir) +out_dir.mkdir(parents=True, exist_ok=True) + +```tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, trust_remote_code=cfg.trust_remote_code)if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained( cfg.model_id, trust_remote_code=cfg.trust_remote_code, torch_dtype=pick_dtype(), device_map="auto", low_cpu_mem_usage=True,)model.eval()texts = load_texts(cfg.dataset_name, cfg.dataset_config, cfg.split, cfg.text_field, cfg.num_texts)targets = list(iter_target_linears(model, cfg.include_modules))if not targets: raise ValueError("No target modules found.")captured_inputs = {name: [] for name, _ in targets}remaining_rows = {name: cfg.max_rows_per_module for name, _ in targets}handles = []def make_hook(name: str): def hook(_module, inp, _out): rem = remaining_rows[name] if rem <= 0 or not inp or not isinstance(inp[0], torch.Tensor): return x = sample_rows(inp[0], rem) take = min(x.shape[0], rem) if take <= 0: return # keep FP32 on CPU to avoid BF16 -> FP16 clipping issues captured_inputs[name].append(x[:take].detach().cpu().to(torch.float32)) remaining_rows[name] -= take return hookfor name, module in targets: handles.append(module.register_forward_hook(make_hook(name)))with torch.no_grad(): for batch_texts in tqdm(list(batched(texts, cfg.probe_batch_size)), desc="collect_activations"): batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=cfg.max_length, padding=True, ) batch = {k: v.to(model.device) for k, v in batch.items()} _ = model(**batch) if all(v <= 0 for v in remaining_rows.values()): breakfor h in handles: h.remove()rows = []for name, module in tqdm(targets, desc="probe_modules"): if not captured_inputs[name]: continue x = torch.cat(captured_inputs[name], dim=0).to(module.weight.device, dtype=module.weight.dtype) weight = module.weight.detach() bias = None if module.bias is None else module.bias.detach() m_w4g128 = reduced_metrics_for_action( x, weight, bias, bits=4, group_size=128, use_activation_aware_proxy=cfg.use_activation_aware_proxy ) m_w4g64 = reduced_metrics_for_action( x, weight, bias, bits=4, group_size=64, use_activation_aware_proxy=cfg.use_activation_aware_proxy ) m_w8g128 = reduced_metrics_for_action( x, weight, bias, bits=8, group_size=128, use_activation_aware_proxy=cfg.use_activation_aware_proxy ) params = int(weight.numel()) row = { "name": name, "module_type": name.split(".")[-1], "block_index": block_index_from_name(name), "params": params, "captured_rows": int(x.shape[0]), "weight_kurtosis": weight_kurtosis(weight), "act_outlier_rate": activation_outlier_rate(x), "candidates": { "w4_g128": { **m_w4g128, "estimated_extra_mb_vs_base": 0.0, }, "w4_g64": { **m_w4g64, "estimated_extra_mb_vs_base": max( 0.0, extra_mb_vs_base( params, "w4_g64", sym=cfg.sym, scale_bytes=cfg.scale_bytes, zero_bytes_asym=cfg.zero_bytes_asym, ), ), }, "w8_g128": { **m_w8g128, "estimated_extra_mb_vs_base": max( 0.0, extra_mb_vs_base( params, "w8_g128", sym=cfg.sym, scale_bytes=cfg.scale_bytes, zero_bytes_asym=cfg.zero_bytes_asym, ), ), }, "skip": { "mse": 0.0, "cosine": 1.0, "max_abs": 0.0, "estimated_extra_mb_vs_base": max( 0.0, extra_mb_vs_base( params, "skip", sym=cfg.sym, scale_bytes=cfg.scale_bytes, zero_bytes_asym=cfg.zero_bytes_asym, ), ), }, }, "proxy_gain_w4g64_vs_w4": float(max(0.0, m_w4g128["mse"] - m_w4g64["mse"])), "proxy_gain_w8_vs_w4": float(max(0.0, m_w4g128["mse"] - m_w8g128["mse"])), } rows.append(row)features_jsonl = out_dir / "probe_features.jsonl"metadata_json = out_dir / "probe_metadata.json"summary_json = out_dir / "probe_summary.json"with open(features_jsonl, "w", encoding="utf-8") as f: for row in rows: f.write(json.dumps(row, ensure_ascii=False) + "\n")metadata = { "model_id": cfg.model_id, "dataset_name": cfg.dataset_name, "dataset_config": cfg.dataset_config, "split": cfg.split, "text_field": cfg.text_field, "num_texts": cfg.num_texts, "max_length": cfg.max_length, "max_rows_per_module": cfg.max_rows_per_module, "probe_batch_size": cfg.probe_batch_size, "include_modules": cfg.include_modules, "num_rows": len(rows), "use_activation_aware_proxy": cfg.use_activation_aware_proxy, "sym": cfg.sym, "scale_bytes": cfg.scale_bytes, "zero_bytes_asym": cfg.zero_bytes_asym,}metadata_json.write_text(json.dumps(metadata, indent=2), encoding="utf-8")summary_json.write_text(json.dumps({"num_rows": len(rows)}, indent=2), encoding="utf-8")del modelif torch.cuda.is_available(): torch.cuda.empty_cache()return { "features_jsonl": str(features_jsonl), "metadata_json": str(metadata_json), "summary_json": str(summary_json), "num_rows": len(rows),} +tokenizer = AutoTokenizer.from_pretrained(cfg.model_id, trust_remote_code=cfg.trust_remote_code)if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained( cfg.model_id, trust_remote_code=cfg.trust_remote_code, torch_dtype=pick_dtype(), device_map="auto", low_cpu_mem_usage=True,)model.eval()texts = load_texts(cfg.dataset_name, cfg.dataset_config, cfg.split, cfg.text_field, cfg.num_texts)targets = list(iter_target_linears(model, cfg.include_modules))if not targets: raise ValueError("No target modules found.")captured_inputs = {name: [] for name, _ in targets}remaining_rows = {name: cfg.max_rows_per_module for name, _ in targets}handles = []def make_hook(name: str): def hook(_module, inp, _out): rem = remaining_rows[name] if rem <= 0 or not inp or not isinstance(inp[0], torch.Tensor): return x = sample_rows(inp[0], rem) take = min(x.shape[0], rem) if take <= 0: return # keep FP32 on CPU to avoid BF16 -> FP16 clipping issues captured_inputs[name].append(x[:take].detach().cpu().to(torch.float32)) remaining_rows[name] -= take return hookfor name, module in targets: handles.append(module.register_forward_hook(make_hook(name)))with torch.no_grad(): for batch_texts in tqdm(list(batched(texts, cfg.probe_batch_size)), desc="collect_activations"): batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=cfg.max_length, padding=True, ) batch = {k: v.to(model.device) for k, v in batch.items()} _ = model(**batch) if all(v <= 0 for v in remaining_rows.values()): breakfor h in handles: h.remove()rows = []for name, module in tqdm(targets, desc="probe_modules"): if not captured_inputs[name]: continue x = torch.cat(captured_inputs[name], dim=0).to(module.weight.device, dtype=module.weight.dtype) weight = module.weight.detach() bias = None if module.bias is None else module.bias.detach() m_w4g128 = reduced_metrics_for_action( x, weight, bias, bits=4, group_size=128, use_activation_aware_proxy=cfg.use_activation_aware_proxy ) m_w4g64 = reduced_metrics_for_action( x, weight, bias, bits=4, group_size=64, use_activation_aware_proxy=cfg.use_activation_aware_proxy ) m_w8g128 = reduced_metrics_for_action( x, weight, bias, bits=8, group_size=128, use_activation_aware_proxy=cfg.use_activation_aware_proxy ) params = int(weight.numel()) row = { "name": name, "module_type": name.split(".")[-1], "block_index": block_index_from_name(name), "params": params, "captured_rows": int(x.shape[0]), "weight_kurtosis": weight_kurtosis(weight), "act_outlier_rate": activation_outlier_rate(x), "candidates": { "w4_g128": { **m_w4g128, "estimated_extra_mb_vs_base": 0.0, }, "w4_g64": { **m_w4g64, "estimated_extra_mb_vs_base": max( 0.0, extra_mb_vs_base( params, "w4_g64", sym=cfg.sym, scale_bytes=cfg.scale_bytes, zero_bytes_asym=cfg.zero_bytes_asym, ), ), }, "w8_g128": { **m_w8g128, "estimated_extra_mb_vs_base": max( 0.0, extra_mb_vs_base( params, "w8_g128", sym=cfg.sym, scale_bytes=cfg.scale_bytes, zero_bytes_asym=cfg.zero_bytes_asym, ), ), }, "skip": { "mse": 0.0, "cosine": 1.0, "max_abs": 0.0, "estimated_extra_mb_vs_base": max( 0.0, extra_mb_vs_base( params, "skip", sym=cfg.sym, scale_bytes=cfg.scale_bytes, zero_bytes_asym=cfg.zero_bytes_asym, ), ), }, }, "proxy_gain_w4g64_vs_w4": float(max(0.0, m_w4g128["mse"] - m_w4g64["mse"])), "proxy_gain_w8_vs_w4": float(max(0.0, m_w4g128["mse"] - m_w8g128["mse"])), } rows.append(row)features_jsonl = out_dir / "probe_features.jsonl"metadata_json = out_dir / "probe_metadata.json"summary_json = out_dir / "probe_summary.json"with open(features_jsonl, "w", encoding="utf-8") as f: for row in rows: f.write(json.dumps(row, ensure_ascii=False) + "\n")metadata = { "model_id": cfg.model_id, "dataset_name": cfg.dataset_name, "dataset_config": cfg.dataset_config, "split": cfg.split, "text_field": cfg.text_field, "num_texts": cfg.num_texts, "max_length": cfg.max_length, "max_rows_per_module": cfg.max_rows_per_module, "probe_batch_size": cfg.probe_batch_size, "include_modules": cfg.include_modules, "num_rows": len(rows), "use_activation_aware_proxy": cfg.use_activation_aware_proxy, "sym": cfg.sym, "scale_bytes": cfg.scale_bytes, "zero_bytes_asym": cfg.zero_bytes_asym,}metadata_json.write_text(json.dumps(metadata, indent=2), encoding="utf-8")summary_json.write_text(json.dumps({"num_rows": len(rows)}, indent=2), encoding="utf-8")del modelif torch.cuda.is_available(): torch.cuda.empty_cache()return { "features_jsonl": str(features_jsonl), "metadata_json": str(metadata_json), "summary_json": str(summary_json), "num_rows": len(rows),} +``` + +EOF + +cat > quant_judge/judge.py <<'EOF' +from **future** import annotations + +import json +import os +from dataclasses import dataclass +from typing import Literal + +import requests +from pydantic import BaseModel, Field, ValidationError + +from .actions import ACTION_SPECS, BASE_ACTION, allowed_actions + +@dataclass +class JudgeConfig: +features_jsonl: str +budget_mb_extra: float +heuristic_only: bool = False +judge_base_url: str | None = None +judge_model: str | None = None +judge_api_key_env: str = "JUDGE_API_KEY" +chunk_size: int = 24 + +class Assignment(BaseModel): +name: str +action: Literal["w4_g128", "w4_g64", "w8_g128", "skip"] +priority: float = Field(ge=0.0, le=1.0) +reason_code: str + +class AssignmentPayload(BaseModel): +assignments: list[Assignment] + +def load_rows(path: str) -> list[dict]: +rows = [] +with open(path, "r", encoding="utf-8") as f: +for line in f: +line = line.strip() +if line: +rows.append(json.loads(line)) +return rows + +def heuristic_decision(row: dict) -> dict: +name = row["name"] +w4 = row["candidates"]["w4_g128"]["mse"] +w4g64 = row["candidates"]["w4_g64"]["mse"] +w8 = row["candidates"]["w8_g128"]["mse"] +kurt = row["weight_kurtosis"] +outlier = row["act_outlier_rate"] +block = row.get("block_index") + +```w8_relief = (w4 - w8) / max(w4, 1e-9)g64_relief = (w4 - w4g64) / max(w4, 1e-9)score = 0.0score += min(1.0, w8_relief / 0.60) * 0.40score += min(1.0, g64_relief / 0.35) * 0.20score += min(1.0, kurt / 20.0) * 0.15score += min(1.0, outlier / 0.05) * 0.15if "lm_head" in name: score += 0.25if row["module_type"] in ("down_proj", "o_proj", "gate_proj"): score += 0.10if block is not None and (block <= 1 or block >= 30): score += 0.10score = max(0.0, min(1.0, score))action = BASE_ACTIONreason = "default"if "lm_head" in name and (w8_relief > 0.55 or kurt > 25.0): action = "skip" reason = "lm_head_structural_skip"elif score >= 0.75 or w8_relief > 0.50: action = "w8_g128" reason = "high_sensitivity"elif score >= 0.42 or g64_relief > 0.18: action = "w4_g64" reason = "moderate_sensitivity"return { "name": name, "action": action, "priority": round(score, 6), "reason_code": reason, "source": "heuristic",} +w8_relief = (w4 - w8) / max(w4, 1e-9)g64_relief = (w4 - w4g64) / max(w4, 1e-9)score = 0.0score += min(1.0, w8_relief / 0.60) * 0.40score += min(1.0, g64_relief / 0.35) * 0.20score += min(1.0, kurt / 20.0) * 0.15score += min(1.0, outlier / 0.05) * 0.15if "lm_head" in name: score += 0.25if row["module_type"] in ("down_proj", "o_proj", "gate_proj"): score += 0.10if block is not None and (block <= 1 or block >= 30): score += 0.10score = max(0.0, min(1.0, score))action = BASE_ACTIONreason = "default"if "lm_head" in name and (w8_relief > 0.55 or kurt > 25.0): action = "skip" reason = "lm_head_structural_skip"elif score >= 0.75 or w8_relief > 0.50: action = "w8_g128" reason = "high_sensitivity"elif score >= 0.42 or g64_relief > 0.18: action = "w4_g64" reason = "moderate_sensitivity"return { "name": name, "action": action, "priority": round(score, 6), "reason_code": reason, "source": "heuristic",} +``` + +def reduced_row(row: dict) -> dict: +return { +"name": row["name"], +"module_type": row["module_type"], +"block_index": row.get("block_index"), +"params": row["params"], +"weight_kurtosis": row["weight_kurtosis"], +"act_outlier_rate": row["act_outlier_rate"], +"w4_g128_mse": row["candidates"]["w4_g128"]["mse"], +"w4_g64_mse": row["candidates"]["w4_g64"]["mse"], +"w8_g128_mse": row["candidates"]["w8_g128"]["mse"], +"extra_mb_w4_g64": row["candidates"]["w4_g64"]["estimated_extra_mb_vs_base"], +"extra_mb_w8_g128": row["candidates"]["w8_g128"]["estimated_extra_mb_vs_base"], +"extra_mb_skip": row["candidates"]["skip"]["estimated_extra_mb_vs_base"], +} + +def chunk(rows: list[dict], size: int) -> list[list[dict]]: +return [rows[i:i + size] for i in range(0, len(rows), size)] + +def extract_json_object(text: str) -> str: +""" +More robust than a code-fence stripper: +extract the first balanced JSON object. +""" +start = text.find("{") +if start < 0: +raise ValueError("No JSON object found in judge response.") +depth = 0 +for i in range(start, len(text)): +c = text[i] +if c == "{": +depth += 1 +elif c == "}": +depth -= 1 +if depth == 0: +return text[start:i + 1] +raise ValueError("Unbalanced JSON object in judge response.") + +def call_openai_compatible(base_url: str, model: str, api_key: str, rows: list[dict], budget_mb_extra: float) -> list[dict]: +system = ( +"You are a quantization policy judge. " +"Choose exactly one action per module from: " + ", ".join(allowed_actions()) + ". " +"Prefer w4_g128 unless the evidence shows fragility. " +"Use w4_g64 for moderate sensitivity. " +"Use w8_g128 for high sensitivity. " +"Use skip only for extreme outliers, usually lm_head. " +"Return JSON only with schema " +"{"assignments":[{"name":...,"action":...,"priority":0..1,"reason_code":...}]}" +) +user = { +"budget_mb_extra_global": budget_mb_extra, +"allowed_actions": allowed_actions(), +"rows": rows, +} + +```url = base_url.rstrip("/") + "/chat/completions"headers = {"Content-Type": "application/json"}if api_key: headers["Authorization"] = f"Bearer {api_key}"body = { "model": model, "temperature": 0.0, "response_format": {"type": "json_object"}, "messages": [ {"role": "system", "content": system}, {"role": "user", "content": json.dumps(user, ensure_ascii=False)}, ],}resp = requests.post(url, headers=headers, json=body, timeout=180)resp.raise_for_status()content = resp.json()["choices"][0]["message"]["content"]payload_text = extract_json_object(content)payload = AssignmentPayload.model_validate_json(payload_text)return [item.model_dump() for item in payload.assignments] +url = base_url.rstrip("/") + "/chat/completions"headers = {"Content-Type": "application/json"}if api_key: headers["Authorization"] = f"Bearer {api_key}"body = { "model": model, "temperature": 0.0, "response_format": {"type": "json_object"}, "messages": [ {"role": "system", "content": system}, {"role": "user", "content": json.dumps(user, ensure_ascii=False)}, ],}resp = requests.post(url, headers=headers, json=body, timeout=180)resp.raise_for_status()content = resp.json()["choices"][0]["message"]["content"]payload_text = extract_json_object(content)payload = AssignmentPayload.model_validate_json(payload_text)return [item.model_dump() for item in payload.assignments] +``` + +def merge_assignments(rows: list[dict], proposed: list[dict]) -> list[dict]: +by_name = {a["name"]: a for a in proposed if isinstance(a, dict) and "name" in a} +merged = [] + +```for row in rows: item = by_name.get(row["name"]) if item is None: merged.append(heuristic_decision(row)) continue action = item.get("action", BASE_ACTION) if action not in ACTION_SPECS: action = BASE_ACTION try: priority = float(item.get("priority", 0.5)) except Exception: priority = 0.5 priority = max(0.0, min(1.0, priority)) merged.append({ "name": row["name"], "action": action, "priority": priority, "reason_code": str(item.get("reason_code", "judge")), "source": "judge", })return merged +for row in rows: item = by_name.get(row["name"]) if item is None: merged.append(heuristic_decision(row)) continue action = item.get("action", BASE_ACTION) if action not in ACTION_SPECS: action = BASE_ACTION try: priority = float(item.get("priority", 0.5)) except Exception: priority = 0.5 priority = max(0.0, min(1.0, priority)) merged.append({ "name": row["name"], "action": action, "priority": priority, "reason_code": str(item.get("reason_code", "judge")), "source": "judge", })return merged +``` + +def make_judge_proposals(cfg: JudgeConfig) -> dict: +rows = load_rows(cfg.features_jsonl) + +```if cfg.heuristic_only: return {"assignments": [heuristic_decision(r) for r in rows]}if not cfg.judge_base_url or not cfg.judge_model: raise ValueError("Judge mode requires judge_base_url and judge_model.")api_key = os.environ.get(cfg.judge_api_key_env, "")assignments = []for part in chunk(rows, cfg.chunk_size): reduced = [reduced_row(r) for r in part] try: proposed = call_openai_compatible( cfg.judge_base_url, cfg.judge_model, api_key, reduced, cfg.budget_mb_extra, ) assignments.extend(merge_assignments(part, proposed)) except (requests.RequestException, ValidationError, ValueError, KeyError, IndexError): assignments.extend([heuristic_decision(r) for r in part])return {"assignments": assignments} +if cfg.heuristic_only: return {"assignments": [heuristic_decision(r) for r in rows]}if not cfg.judge_base_url or not cfg.judge_model: raise ValueError("Judge mode requires judge_base_url and judge_model.")api_key = os.environ.get(cfg.judge_api_key_env, "")assignments = []for part in chunk(rows, cfg.chunk_size): reduced = [reduced_row(r) for r in part] try: proposed = call_openai_compatible( cfg.judge_base_url, cfg.judge_model, api_key, reduced, cfg.budget_mb_extra, ) assignments.extend(merge_assignments(part, proposed)) except (requests.RequestException, ValidationError, ValueError, KeyError, IndexError): assignments.extend([heuristic_decision(r) for r in part])return {"assignments": assignments} +``` + +EOF + +cat > quant_judge/compiler.py <<'EOF' +from **future** import annotations + +import json +import re +from pathlib import Path + +from .actions import BASE_ACTION, action_to_override + +def is_structural_skip(name: str, action: str) -> bool: +""" +Structural skips are applied outside the ROI budget. +Right now only lm_head skip is treated this way. +""" +return action == "skip" and name == "lm_head" + +def compile_plan(assignments: list[dict], probe_rows: list[dict], extra_budget_mb: float, min_roi: float = 1e-6) -> dict: +probe_by_name = {row["name"]: row for row in probe_rows} + +```structural = []upgrades = []for item in assignments: name = item["name"] action = item["action"] row = probe_by_name.get(name) if row is None: continue if is_structural_skip(name, action): structural.append(item) continue if action == BASE_ACTION: continue if action == "w8_g128": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w8_g128"]["mse"] cost_mb = row["candidates"]["w8_g128"]["estimated_extra_mb_vs_base"] elif action == "w4_g64": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w4_g64"]["mse"] cost_mb = row["candidates"]["w4_g64"]["estimated_extra_mb_vs_base"] elif action == "skip": benefit = row["candidates"]["w4_g128"]["mse"] cost_mb = row["candidates"]["skip"]["estimated_extra_mb_vs_base"] else: continue if benefit <= 0 or cost_mb <= 0: continue roi = benefit / max(cost_mb, 1e-6) if roi < min_roi: continue upgrades.append({ "name": name, "action": action, "benefit": float(benefit), "cost_mb": float(cost_mb), "roi": float(roi), "priority": float(item.get("priority", 0.5)), "reason_code": item.get("reason_code", ""), "source": item.get("source", ""), })upgrades.sort(key=lambda x: (x["roi"], x["priority"]), reverse=True)approved = []spent_mb = 0.0for up in upgrades: if spent_mb + up["cost_mb"] > extra_budget_mb: continue approved.append(up) spent_mb += up["cost_mb"]structural_by_name = {x["name"]: x for x in structural}approved_by_name = {x["name"]: x for x in approved}final_assignments = []for row in probe_rows: name = row["name"] if name in structural_by_name: item = structural_by_name[name] final_assignments.append({ "name": name, "action": "skip", "priority": float(item.get("priority", 1.0)), "reason_code": item.get("reason_code", "structural_skip"), "source": item.get("source", "compiler"), "estimated_extra_mb_vs_base": row["candidates"]["skip"]["estimated_extra_mb_vs_base"], "structural": True, }) elif name in approved_by_name: up = approved_by_name[name] final_assignments.append({ "name": name, "action": up["action"], "priority": up["priority"], "reason_code": up["reason_code"], "source": up["source"], "estimated_extra_mb_vs_base": up["cost_mb"], "benefit": up["benefit"], "roi": up["roi"], "structural": False, }) else: final_assignments.append({ "name": name, "action": BASE_ACTION, "priority": 0.0, "reason_code": "base_default", "source": "compiler", "estimated_extra_mb_vs_base": 0.0, "structural": False, })return { "budget_mb_extra": float(extra_budget_mb), "used_mb_extra": round(float(spent_mb), 6), "structural_skips": len(structural), "upgrades_considered": len(upgrades), "upgrades_approved": len(approved), "assignments": sorted(final_assignments, key=lambda x: x["name"]),} +structural = []upgrades = []for item in assignments: name = item["name"] action = item["action"] row = probe_by_name.get(name) if row is None: continue if is_structural_skip(name, action): structural.append(item) continue if action == BASE_ACTION: continue if action == "w8_g128": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w8_g128"]["mse"] cost_mb = row["candidates"]["w8_g128"]["estimated_extra_mb_vs_base"] elif action == "w4_g64": benefit = row["candidates"]["w4_g128"]["mse"] - row["candidates"]["w4_g64"]["mse"] cost_mb = row["candidates"]["w4_g64"]["estimated_extra_mb_vs_base"] elif action == "skip": benefit = row["candidates"]["w4_g128"]["mse"] cost_mb = row["candidates"]["skip"]["estimated_extra_mb_vs_base"] else: continue if benefit <= 0 or cost_mb <= 0: continue roi = benefit / max(cost_mb, 1e-6) if roi < min_roi: continue upgrades.append({ "name": name, "action": action, "benefit": float(benefit), "cost_mb": float(cost_mb), "roi": float(roi), "priority": float(item.get("priority", 0.5)), "reason_code": item.get("reason_code", ""), "source": item.get("source", ""), })upgrades.sort(key=lambda x: (x["roi"], x["priority"]), reverse=True)approved = []spent_mb = 0.0for up in upgrades: if spent_mb + up["cost_mb"] > extra_budget_mb: continue approved.append(up) spent_mb += up["cost_mb"]structural_by_name = {x["name"]: x for x in structural}approved_by_name = {x["name"]: x for x in approved}final_assignments = []for row in probe_rows: name = row["name"] if name in structural_by_name: item = structural_by_name[name] final_assignments.append({ "name": name, "action": "skip", "priority": float(item.get("priority", 1.0)), "reason_code": item.get("reason_code", "structural_skip"), "source": item.get("source", "compiler"), "estimated_extra_mb_vs_base": row["candidates"]["skip"]["estimated_extra_mb_vs_base"], "structural": True, }) elif name in approved_by_name: up = approved_by_name[name] final_assignments.append({ "name": name, "action": up["action"], "priority": up["priority"], "reason_code": up["reason_code"], "source": up["source"], "estimated_extra_mb_vs_base": up["cost_mb"], "benefit": up["benefit"], "roi": up["roi"], "structural": False, }) else: final_assignments.append({ "name": name, "action": BASE_ACTION, "priority": 0.0, "reason_code": "base_default", "source": "compiler", "estimated_extra_mb_vs_base": 0.0, "structural": False, })return { "budget_mb_extra": float(extra_budget_mb), "used_mb_extra": round(float(spent_mb), 6), "structural_skips": len(structural), "upgrades_considered": len(upgrades), "upgrades_approved": len(approved), "assignments": sorted(final_assignments, key=lambda x: x["name"]),} +``` + +def plan_to_dynamic(plan: dict) -> dict: +dynamic = {} +for item in plan["assignments"]: +name = item["name"] +action = item["action"] +if action == BASE_ACTION: +continue +anchored = f"^{re.escape(name)}$" +if action == "skip": +dynamic[f"-:{anchored}"] = {} +else: +dynamic[f"+:{anchored}"] = action_to_override(action) +return dynamic + +def save_plan_and_dynamic(plan: dict, dynamic: dict, plan_out: str, dynamic_out: str) -> None: +Path(plan_out).parent.mkdir(parents=True, exist_ok=True) +Path(dynamic_out).parent.mkdir(parents=True, exist_ok=True) +Path(plan_out).write_text(json.dumps(plan, indent=2), encoding="utf-8") +Path(dynamic_out).write_text(json.dumps(dynamic, indent=2), encoding="utf-8") +EOF + +cat > quant_judge/runner.py <<'EOF' +from **future** import annotations + +import json +import subprocess +import sys +from pathlib import Path + +import torch +from datasets import load_dataset +from gptqmodel import GPTQModel, QuantizeConfig + +def load_texts(dataset_name: str, dataset_config: str | None, split: str, text_field: str, num_texts: int) -> list[str]: +ds = load_dataset(dataset_name, dataset_config, split=split) +texts = [] +for row in ds: +text = row.get(text_field, "") +if isinstance(text, str) and text.strip(): +texts.append(text.strip()) +if len(texts) >= num_texts: +break +if not texts: +raise ValueError("No texts found for quantization.") +return texts + +def quantize_with_dynamic( +model_id: str, +calibration_texts: list[str], +out_dir: str, +dynamic: dict | None, +bits: int = 4, +group_size: int = 128, +batch_size: int = 4, +sym: bool = True, +desc_act: bool = True, +) -> dict: +qcfg = QuantizeConfig( +bits=bits, +group_size=group_size, +sym=sym, +desc_act=desc_act, +dynamic=dynamic, +) +model = GPTQModel.load(model_id, qcfg) +model.quantize(calibration_texts, batch_size=batch_size) + +```out = Path(out_dir)out.mkdir(parents=True, exist_ok=True)model.save(str(out))meta = { "model_id": model_id, "out_dir": str(out), "bits": bits, "group_size": group_size, "batch_size": batch_size, "dynamic_rules": 0 if dynamic is None else len(dynamic), "sym": sym, "desc_act": desc_act,}(out / "quant_meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")del modelif torch.cuda.is_available(): torch.cuda.empty_cache()return meta +out = Path(out_dir)out.mkdir(parents=True, exist_ok=True)model.save(str(out))meta = { "model_id": model_id, "out_dir": str(out), "bits": bits, "group_size": group_size, "batch_size": batch_size, "dynamic_rules": 0 if dynamic is None else len(dynamic), "sym": sym, "desc_act": desc_act,}(out / "quant_meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")del modelif torch.cuda.is_available(): torch.cuda.empty_cache()return meta +``` + +def heuristic_dynamic_from_flags(flags: list[str]) -> dict: +dynamic = {} +for flag in flags: +if flag == "skip_lm_head": +dynamic[r"-:^lm_head$"] = {} +elif flag == "down_proj_g64": +dynamic[r"+:.*down_proj$"] = {"bits": 4, "group_size": 64} +elif flag == "o_proj_g64": +dynamic[r"+:.*o_proj$"] = {"bits": 4, "group_size": 64} +elif flag == "edge_blocks_w8": +dynamic[r"+:^model.layers.(0|1|30|31)..*$"] = {"bits": 8, "group_size": 128} +return dynamic + +def run_lm_eval(model_path: str, tasks: list[str], output_json: str, device: str = "cuda:0", batch_size: str = "auto") -> dict: +output_path = Path(output_json) +output_path.parent.mkdir(parents=True, exist_ok=True) + +```cmd = [ sys.executable, "-m", "lm_eval", "--model", "hf", "--model_args", f"pretrained={model_path},trust_remote_code=True", "--tasks", ",".join(tasks), "--device", device, "--batch_size", batch_size, "--output_path", str(output_path.parent),]subprocess.run(cmd, check=True)for path in sorted(output_path.parent.rglob("*.json")): try: data = json.loads(path.read_text(encoding="utf-8")) except Exception: continue if isinstance(data, dict) and "results" in data: output_path.write_text(json.dumps(data, indent=2), encoding="utf-8") return dataraise FileNotFoundError("Could not find lm_eval results json.") +cmd = [ sys.executable, "-m", "lm_eval", "--model", "hf", "--model_args", f"pretrained={model_path},trust_remote_code=True", "--tasks", ",".join(tasks), "--device", device, "--batch_size", batch_size, "--output_path", str(output_path.parent),]subprocess.run(cmd, check=True)for path in sorted(output_path.parent.rglob("*.json")): try: data = json.loads(path.read_text(encoding="utf-8")) except Exception: continue if isinstance(data, dict) and "results" in data: output_path.write_text(json.dumps(data, indent=2), encoding="utf-8") return dataraise FileNotFoundError("Could not find lm_eval results json.") +``` + +def extract_primary_scores(eval_json: dict) -> dict: +out = {} +for task, metrics in eval_json.get("results", {}).items(): +if "acc_norm,none" in metrics: +out[task] = metrics["acc_norm,none"] +elif "acc,none" in metrics: +out[task] = metrics["acc,none"] +elif "word_perplexity,none" in metrics: +out[task] = metrics["word_perplexity,none"] +return out +EOF + +cat > quant_judge/pipeline.py <<'EOF' +from **future** import annotations + +import argparse +import json +from pathlib import Path + +from .compiler import compile_plan, plan_to_dynamic, save_plan_and_dynamic +from .judge import JudgeConfig, load_rows, make_judge_proposals +from .probe import ProbeConfig, run_probe +from .runner import ( +extract_primary_scores, +heuristic_dynamic_from_flags, +load_texts, +quantize_with_dynamic, +run_lm_eval, +) + +def cmd_probe(args): +cfg = ProbeConfig( +model_id=args.model_id, +out_dir=args.out_dir, +dataset_name=args.dataset_name, +dataset_config=args.dataset_config, +split=args.split, +text_field=args.text_field, +num_texts=args.num_texts, +max_length=args.max_length, +max_rows_per_module=args.max_rows_per_module, +include_modules=[x.strip() for x in args.include_modules.split(",") if x.strip()], +probe_batch_size=args.probe_batch_size, +use_activation_aware_proxy=args.use_activation_aware_proxy, +sym=args.sym, +scale_bytes=args.scale_bytes, +zero_bytes_asym=args.zero_bytes_asym, +) +print(json.dumps(run_probe(cfg), indent=2)) + +def cmd_plan(args): +judge_cfg = JudgeConfig( +features_jsonl=args.features_jsonl, +budget_mb_extra=args.budget_mb_extra, +heuristic_only=args.heuristic_only, +judge_base_url=args.judge_base_url, +judge_model=args.judge_model, +judge_api_key_env=args.judge_api_key_env, +chunk_size=args.chunk_size, +) +proposals = make_judge_proposals(judge_cfg) +probe_rows = load_rows(args.features_jsonl) +plan = compile_plan(proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi) +dynamic = plan_to_dynamic(plan) +save_plan_and_dynamic(plan, dynamic, args.plan_out, args.dynamic_out) +print(json.dumps({ +"plan_out": args.plan_out, +"dynamic_out": args.dynamic_out, +"used_mb_extra": plan["used_mb_extra"], +"structural_skips": plan["structural_skips"], +"upgrades_approved": plan["upgrades_approved"], +"dynamic_rules": len(dynamic), +}, indent=2)) + +def cmd_quantize(args): +calibration_texts = load_texts( +args.dataset_name, +args.dataset_config, +args.split, +args.text_field, +args.num_texts, +) +plan = json.loads(Path(args.plan_json).read_text(encoding="utf-8")) +dynamic = plan_to_dynamic(plan) +print(json.dumps( +quantize_with_dynamic( +args.model_id, +calibration_texts, +args.out_dir, +dynamic, +args.base_bits, +args.base_group_size, +args.batch_size, +args.sym, +args.desc_act, +), +indent=2, +)) + +def cmd_eval(args): +tasks = [x.strip() for x in args.tasks.split(",") if x.strip()] +result = run_lm_eval(args.model_path, tasks, args.output_json, args.device, args.batch_size) +print(json.dumps({ +"output_json": args.output_json, +"scores": extract_primary_scores(result), +}, indent=2)) + +def cmd_suite(args): +work_dir = Path(args.work_dir) +work_dir.mkdir(parents=True, exist_ok=True) +probe_dir = work_dir / "probe" + +```probe_result = run_probe(ProbeConfig( model_id=args.model_id, out_dir=str(probe_dir), dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.probe_num_texts, max_length=args.max_length, max_rows_per_module=args.max_rows_per_module, include_modules=[x.strip() for x in args.include_modules.split(",") if x.strip()], probe_batch_size=args.probe_batch_size, use_activation_aware_proxy=args.use_activation_aware_proxy, sym=args.sym, scale_bytes=args.scale_bytes, zero_bytes_asym=args.zero_bytes_asym,))features_jsonl = probe_result["features_jsonl"]probe_rows = load_rows(features_jsonl)# Rule-based baseline: fixed regexes onlyheuristic_flags = [x.strip() for x in args.heuristic_include.split(",") if x.strip()]rule_based_dynamic = heuristic_dynamic_from_flags(heuristic_flags)(work_dir / "rule_based_dynamic.json").write_text( json.dumps(rule_based_dynamic, indent=2), encoding="utf-8")# Data-driven heuristic: same feature-driven logic as the judge fallback, but labeled honestlyheuristic_proposals = make_judge_proposals(JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=True, chunk_size=args.chunk_size,))heuristic_plan = compile_plan( heuristic_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi,)data_driven_heuristic_dynamic = plan_to_dynamic(heuristic_plan)save_plan_and_dynamic( heuristic_plan, data_driven_heuristic_dynamic, str(work_dir / "data_driven_heuristic_plan.json"), str(work_dir / "data_driven_heuristic_dynamic.json"),)configs = { "baseline_w4g128": None, "rule_based_baseline": rule_based_dynamic, "data_driven_heuristic": data_driven_heuristic_dynamic,}judge_used = Falseif args.judge_base_url and args.judge_model: judge_used = True judge_proposals = make_judge_proposals(JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=False, judge_base_url=args.judge_base_url, judge_model=args.judge_model, judge_api_key_env=args.judge_api_key_env, chunk_size=args.chunk_size, )) judge_plan = compile_plan( judge_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi, ) judge_dynamic = plan_to_dynamic(judge_plan) save_plan_and_dynamic( judge_plan, judge_dynamic, str(work_dir / "judge_plan.json"), str(work_dir / "judge_dynamic.json"), ) configs["judge_sdq"] = judge_dynamiccalibration_texts = load_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts,)tasks = [x.strip() for x in args.tasks.split(",") if x.strip()]leaderboard = {}for name, dynamic in configs.items(): model_dir = work_dir / "models" / name eval_json = work_dir / "eval" / f"{name}.json" quant_meta = quantize_with_dynamic( args.model_id, calibration_texts, str(model_dir), dynamic, args.base_bits, args.base_group_size, args.batch_size, args.sym, args.desc_act, ) result = run_lm_eval( str(model_dir), tasks, str(eval_json), args.device, args.eval_batch_size, ) leaderboard[name] = { "scores": extract_primary_scores(result), "quant_meta": quant_meta, }payload = { "probe": probe_result, "leaderboard": leaderboard, "judge_used": judge_used, "artifacts": { "rule_based_dynamic_path": str(work_dir / "rule_based_dynamic.json"), "data_driven_heuristic_plan_path": str(work_dir / "data_driven_heuristic_plan.json"), "data_driven_heuristic_dynamic_path": str(work_dir / "data_driven_heuristic_dynamic.json"), "judge_plan_path": str(work_dir / "judge_plan.json") if judge_used else None, "judge_dynamic_path": str(work_dir / "judge_dynamic.json") if judge_used else None, },}(work_dir / "leaderboard.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")print(json.dumps(payload, indent=2)) +probe_result = run_probe(ProbeConfig( model_id=args.model_id, out_dir=str(probe_dir), dataset_name=args.dataset_name, dataset_config=args.dataset_config, split=args.split, text_field=args.text_field, num_texts=args.probe_num_texts, max_length=args.max_length, max_rows_per_module=args.max_rows_per_module, include_modules=[x.strip() for x in args.include_modules.split(",") if x.strip()], probe_batch_size=args.probe_batch_size, use_activation_aware_proxy=args.use_activation_aware_proxy, sym=args.sym, scale_bytes=args.scale_bytes, zero_bytes_asym=args.zero_bytes_asym,))features_jsonl = probe_result["features_jsonl"]probe_rows = load_rows(features_jsonl)# Rule-based baseline: fixed regexes onlyheuristic_flags = [x.strip() for x in args.heuristic_include.split(",") if x.strip()]rule_based_dynamic = heuristic_dynamic_from_flags(heuristic_flags)(work_dir / "rule_based_dynamic.json").write_text( json.dumps(rule_based_dynamic, indent=2), encoding="utf-8")# Data-driven heuristic: same feature-driven logic as the judge fallback, but labeled honestlyheuristic_proposals = make_judge_proposals(JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=True, chunk_size=args.chunk_size,))heuristic_plan = compile_plan( heuristic_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi,)data_driven_heuristic_dynamic = plan_to_dynamic(heuristic_plan)save_plan_and_dynamic( heuristic_plan, data_driven_heuristic_dynamic, str(work_dir / "data_driven_heuristic_plan.json"), str(work_dir / "data_driven_heuristic_dynamic.json"),)configs = { "baseline_w4g128": None, "rule_based_baseline": rule_based_dynamic, "data_driven_heuristic": data_driven_heuristic_dynamic,}judge_used = Falseif args.judge_base_url and args.judge_model: judge_used = True judge_proposals = make_judge_proposals(JudgeConfig( features_jsonl=features_jsonl, budget_mb_extra=args.budget_mb_extra, heuristic_only=False, judge_base_url=args.judge_base_url, judge_model=args.judge_model, judge_api_key_env=args.judge_api_key_env, chunk_size=args.chunk_size, )) judge_plan = compile_plan( judge_proposals["assignments"], probe_rows, args.budget_mb_extra, args.min_roi, ) judge_dynamic = plan_to_dynamic(judge_plan) save_plan_and_dynamic( judge_plan, judge_dynamic, str(work_dir / "judge_plan.json"), str(work_dir / "judge_dynamic.json"), ) configs["judge_sdq"] = judge_dynamiccalibration_texts = load_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts,)tasks = [x.strip() for x in args.tasks.split(",") if x.strip()]leaderboard = {}for name, dynamic in configs.items(): model_dir = work_dir / "models" / name eval_json = work_dir / "eval" / f"{name}.json" quant_meta = quantize_with_dynamic( args.model_id, calibration_texts, str(model_dir), dynamic, args.base_bits, args.base_group_size, args.batch_size, args.sym, args.desc_act, ) result = run_lm_eval( str(model_dir), tasks, str(eval_json), args.device, args.eval_batch_size, ) leaderboard[name] = { "scores": extract_primary_scores(result), "quant_meta": quant_meta, }payload = { "probe": probe_result, "leaderboard": leaderboard, "judge_used": judge_used, "artifacts": { "rule_based_dynamic_path": str(work_dir / "rule_based_dynamic.json"), "data_driven_heuristic_plan_path": str(work_dir / "data_driven_heuristic_plan.json"), "data_driven_heuristic_dynamic_path": str(work_dir / "data_driven_heuristic_dynamic.json"), "judge_plan_path": str(work_dir / "judge_plan.json") if judge_used else None, "judge_dynamic_path": str(work_dir / "judge_dynamic.json") if judge_used else None, },}(work_dir / "leaderboard.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")print(json.dumps(payload, indent=2)) +``` + +def build_parser(): +p = argparse.ArgumentParser(prog="quant_judge") +sub = p.add_subparsers(dest="cmd", required=True) + +```probe = sub.add_parser("probe")probe.add_argument("--model-id", required=True)probe.add_argument("--out-dir", required=True)probe.add_argument("--dataset-name", default="wikitext")probe.add_argument("--dataset-config", default="wikitext-2-raw-v1")probe.add_argument("--split", default="train")probe.add_argument("--text-field", default="text")probe.add_argument("--num-texts", type=int, default=64)probe.add_argument("--max-length", type=int, default=256)probe.add_argument("--max-rows-per-module", type=int, default=128)probe.add_argument("--probe-batch-size", type=int, default=8)probe.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head")probe.add_argument("--use-activation-aware-proxy", action="store_true")probe.add_argument("--sym", action="store_true", default=True)probe.add_argument("--scale-bytes", type=float, default=2.0)probe.add_argument("--zero-bytes-asym", type=float, default=0.5)probe.set_defaults(func=cmd_probe)plan = sub.add_parser("plan")plan.add_argument("--features-jsonl", required=True)plan.add_argument("--plan-out", required=True)plan.add_argument("--dynamic-out", required=True)plan.add_argument("--budget-mb-extra", type=float, default=150.0)plan.add_argument("--min-roi", type=float, default=1e-6)plan.add_argument("--heuristic-only", action="store_true")plan.add_argument("--judge-base-url", default=None)plan.add_argument("--judge-model", default=None)plan.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")plan.add_argument("--chunk-size", type=int, default=24)plan.set_defaults(func=cmd_plan)quant = sub.add_parser("quantize")quant.add_argument("--model-id", required=True)quant.add_argument("--plan-json", required=True)quant.add_argument("--out-dir", required=True)quant.add_argument("--dataset-name", default="wikitext")quant.add_argument("--dataset-config", default="wikitext-2-raw-v1")quant.add_argument("--split", default="train")quant.add_argument("--text-field", default="text")quant.add_argument("--num-texts", type=int, default=128)quant.add_argument("--batch-size", type=int, default=4)quant.add_argument("--base-bits", type=int, default=4)quant.add_argument("--base-group-size", type=int, default=128)quant.add_argument("--sym", action="store_true", default=True)quant.add_argument("--desc-act", action="store_true", default=True)quant.set_defaults(func=cmd_quantize)ev = sub.add_parser("eval")ev.add_argument("--model-path", required=True)ev.add_argument("--tasks", default="arc_challenge,hellaswag")ev.add_argument("--output-json", required=True)ev.add_argument("--device", default="cuda:0")ev.add_argument("--batch-size", default="auto")ev.set_defaults(func=cmd_eval)suite = sub.add_parser("suite")suite.add_argument("--model-id", required=True)suite.add_argument("--work-dir", required=True)suite.add_argument("--dataset-name", default="wikitext")suite.add_argument("--dataset-config", default="wikitext-2-raw-v1")suite.add_argument("--split", default="train")suite.add_argument("--text-field", default="text")suite.add_argument("--probe-num-texts", type=int, default=64)suite.add_argument("--quant-num-texts", type=int, default=128)suite.add_argument("--max-length", type=int, default=256)suite.add_argument("--max-rows-per-module", type=int, default=128)suite.add_argument("--probe-batch-size", type=int, default=8)suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head")suite.add_argument("--use-activation-aware-proxy", action="store_true")suite.add_argument("--budget-mb-extra", type=float, default=150.0)suite.add_argument("--min-roi", type=float, default=1e-6)suite.add_argument("--heuristic-include", default="skip_lm_head,down_proj_g64")suite.add_argument("--judge-base-url", default=None)suite.add_argument("--judge-model", default=None)suite.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")suite.add_argument("--chunk-size", type=int, default=24)suite.add_argument("--tasks", default="arc_challenge,hellaswag")suite.add_argument("--device", default="cuda:0")suite.add_argument("--batch-size", type=int, default=4)suite.add_argument("--eval-batch-size", default="auto")suite.add_argument("--base-bits", type=int, default=4)suite.add_argument("--base-group-size", type=int, default=128)suite.add_argument("--sym", action="store_true", default=True)suite.add_argument("--desc-act", action="store_true", default=True)suite.add_argument("--scale-bytes", type=float, default=2.0)suite.add_argument("--zero-bytes-asym", type=float, default=0.5)suite.set_defaults(func=cmd_suite)return p +probe = sub.add_parser("probe")probe.add_argument("--model-id", required=True)probe.add_argument("--out-dir", required=True)probe.add_argument("--dataset-name", default="wikitext")probe.add_argument("--dataset-config", default="wikitext-2-raw-v1")probe.add_argument("--split", default="train")probe.add_argument("--text-field", default="text")probe.add_argument("--num-texts", type=int, default=64)probe.add_argument("--max-length", type=int, default=256)probe.add_argument("--max-rows-per-module", type=int, default=128)probe.add_argument("--probe-batch-size", type=int, default=8)probe.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head")probe.add_argument("--use-activation-aware-proxy", action="store_true")probe.add_argument("--sym", action="store_true", default=True)probe.add_argument("--scale-bytes", type=float, default=2.0)probe.add_argument("--zero-bytes-asym", type=float, default=0.5)probe.set_defaults(func=cmd_probe)plan = sub.add_parser("plan")plan.add_argument("--features-jsonl", required=True)plan.add_argument("--plan-out", required=True)plan.add_argument("--dynamic-out", required=True)plan.add_argument("--budget-mb-extra", type=float, default=150.0)plan.add_argument("--min-roi", type=float, default=1e-6)plan.add_argument("--heuristic-only", action="store_true")plan.add_argument("--judge-base-url", default=None)plan.add_argument("--judge-model", default=None)plan.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")plan.add_argument("--chunk-size", type=int, default=24)plan.set_defaults(func=cmd_plan)quant = sub.add_parser("quantize")quant.add_argument("--model-id", required=True)quant.add_argument("--plan-json", required=True)quant.add_argument("--out-dir", required=True)quant.add_argument("--dataset-name", default="wikitext")quant.add_argument("--dataset-config", default="wikitext-2-raw-v1")quant.add_argument("--split", default="train")quant.add_argument("--text-field", default="text")quant.add_argument("--num-texts", type=int, default=128)quant.add_argument("--batch-size", type=int, default=4)quant.add_argument("--base-bits", type=int, default=4)quant.add_argument("--base-group-size", type=int, default=128)quant.add_argument("--sym", action="store_true", default=True)quant.add_argument("--desc-act", action="store_true", default=True)quant.set_defaults(func=cmd_quantize)ev = sub.add_parser("eval")ev.add_argument("--model-path", required=True)ev.add_argument("--tasks", default="arc_challenge,hellaswag")ev.add_argument("--output-json", required=True)ev.add_argument("--device", default="cuda:0")ev.add_argument("--batch-size", default="auto")ev.set_defaults(func=cmd_eval)suite = sub.add_parser("suite")suite.add_argument("--model-id", required=True)suite.add_argument("--work-dir", required=True)suite.add_argument("--dataset-name", default="wikitext")suite.add_argument("--dataset-config", default="wikitext-2-raw-v1")suite.add_argument("--split", default="train")suite.add_argument("--text-field", default="text")suite.add_argument("--probe-num-texts", type=int, default=64)suite.add_argument("--quant-num-texts", type=int, default=128)suite.add_argument("--max-length", type=int, default=256)suite.add_argument("--max-rows-per-module", type=int, default=128)suite.add_argument("--probe-batch-size", type=int, default=8)suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head")suite.add_argument("--use-activation-aware-proxy", action="store_true")suite.add_argument("--budget-mb-extra", type=float, default=150.0)suite.add_argument("--min-roi", type=float, default=1e-6)suite.add_argument("--heuristic-include", default="skip_lm_head,down_proj_g64")suite.add_argument("--judge-base-url", default=None)suite.add_argument("--judge-model", default=None)suite.add_argument("--judge-api-key-env", default="JUDGE_API_KEY")suite.add_argument("--chunk-size", type=int, default=24)suite.add_argument("--tasks", default="arc_challenge,hellaswag")suite.add_argument("--device", default="cuda:0")suite.add_argument("--batch-size", type=int, default=4)suite.add_argument("--eval-batch-size", default="auto")suite.add_argument("--base-bits", type=int, default=4)suite.add_argument("--base-group-size", type=int, default=128)suite.add_argument("--sym", action="store_true", default=True)suite.add_argument("--desc-act", action="store_true", default=True)suite.add_argument("--scale-bytes", type=float, default=2.0)suite.add_argument("--zero-bytes-asym", type=float, default=0.5)suite.set_defaults(func=cmd_suite)return p +``` + +def main(): +parser = build_parser() +args = parser.parse_args() +args.func(args) + +if **name** == "**main**": +main() +EOF + +```Then install and run it:```bashpython -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch first# example:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +Then install and run it:```bashpython -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch first# example:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +``` + +Starter suite run: + +```Bashpython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag +Bashpython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag +``` + +Judge-backed suite run: + +```Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct +Bashexport JUDGE_API_KEY=your_key_herepython -m quant_judge.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/suite \ --dataset-name wikitext \ --dataset-config wikitext-2-raw-v1 \ --split train \ --text-field text \ --probe-num-texts 64 \ --quant-num-texts 128 \ --tasks arc_challenge,hellaswag \ --judge-base-url http://127.0.0.1:8000/v1 \ --judge-model Qwen/Qwen2.5-14B-Instruct +``` + +If you want, next I can give you a **single self-extracting .sh installer** version so you only save one file and run it. + +--- + +### User: +Here is a detailed critique of your SDQ Quant Judge v4 pipeline, followed by a concrete guide on how to inject **ExLlamaV3 (EXL3)** and **ExLlamaV2 (EXL2)** logic into your pipeline to drastically improve accuracy *without* breaking standard GPTQ kernel compatibility. + +--- + +### Part 1: Critique of SDQ Quant Judge v4 + +Your pipeline is a great exploration of Software-Defined Quantization, but it has severe mathematical bottlenecks if the goal is SOTA accuracy: + +1. **The Probe Fallacy (RTN vs. GPTQ):** + Your probe uses `fake_quant_groupwise_linear_weight` (Round-To-Nearest) and an AWQ-like activation proxy to measure MSE. This is mathematically mismatched with actual GPTQ. GPTQ uses the Inverse Hessian to compensate for quantization errors in subsequent weights. A layer might have terrible RTN error but near-zero GPTQ error because the Hessian allows adjacent weights to absorb the shock. Judging modules based on RTN/AWQ error leads to suboptimal budget allocation. +2. **LLMs are Bad at Knapsack Problems:** + Using an LLM (`Qwen2.5-14B-Instruct`) to allocate bit-widths based on heuristic scores is fundamentally flawed. Bit allocation to hit a target VRAM budget is a classic **constrained optimization / knapsack problem**. A deterministic mathematical solver will beat an LLM 100% of the time in both speed and mathematical optimality. +3. **Coarse Granularity:** + You apply mixed precision at the module level (e.g., `down_proj` gets `w8_g128`, `up_proj` gets `w4_g128`). Modern formats (EXL2/EXL3) allocate bits at the *group* or *column* level within the *same* tensor. +4. **Wasted Space on the LM Head:** + You are skipping the `lm_head` (leaving it at FP16). In a model like Llama-3-70B, the LM head is 2+ GB. EXL3 achieves massive VRAM savings by quantizing the LM head to 6-bit or 8-bit, rather than skipping it entirely. + +--- + +### Part 2: Implementing EXL3 / EXL2 Logic within Standard GPTQ + +**ExLlamaV3 (EXL3)** is based on **QTIP** (Quantization with Trellis and Incoherent Processing). True EXL3 uses a Walsh-Hadamard transform and procedural codebooks, which *require* a custom kernel. + +However, you can emulate the "magic" of EXL3 and EXL2 **without modifying the standard GPTQ inference kernel** (meaning your models will still run flawlessly in vLLM, AutoGPTQ, or GPTQModel). You do this by adopting their two main algorithms during the *quantization step*: + +#### Mechanism 1: Trellis/Beam-Search Rounding (The EXL3 / QTIP Secret) +Standard GPTQ uses **greedy rounding**. It looks at one weight, rounds it to the nearest INT4 grid point, and pushes the error to the remaining weights using the Inverse Hessian. + +EXL3/QTIP uses **Trellis Coded Quantization (Viterbi Search)**. Instead of instantly rounding, it looks ahead. It tests multiple rounding paths over a block of weights and picks the sequence of quantizations that results in the lowest global Hessian-weighted error. + +You can apply this to the standard INT4/INT8 linear grid! During your GPTQ quantization loop, replace the greedy `torch.round()` with a localized Beam Search: + +``` +python +# CONCEPTUAL IMPLEMENTATION FOR YOUR GPTQ RUNNER +# Inside the GPTQ quantization block (iterating over columns): + +def beam_search_quantize(weights, scales, zeros, H_inv, block_size=8, beam_width=4): + """ + Replaces greedy GPTQ rounding with EXL3-style lookahead Trellis/Beam search. + Outputs standard GPTQ compatible integers. + """ + beams =[{"q_vals": [], "error_buffer": torch.zeros_like(weights[:, 0]), "cost": 0.0}] + + for i in range(block_size): + w_current = weights[:, i] + new_beams =[] + + for beam in beams: + # Apply accumulated error from previous weights in the beam + w_adjusted = w_current + beam["error_buffer"] + + # Instead of greedy rounding, consider the 2 nearest integer states + q_float = w_adjusted / scales + zeros + q_floor = torch.floor(q_float) + q_ceil = torch.ceil(q_float) + + for q_choice in [q_floor, q_ceil]: + q_choice = torch.clamp(q_choice, 0, 15) # Standard INT4 bounds + w_quantized = (q_choice - zeros) * scales + + # Calculate quantization error + err = w_adjusted - w_quantized + + # Update cost (Hessian weighted MSE) + # H_diag is the diagonal of the inverse Hessian + step_cost = (err ** 2) / H_inv[i, i] + + # Propagate error to remaining weights in the block + # err * (H_inv_row / H_inv_diag) + push_error = err.unsqueeze(1) @ (H_inv[i, i+1:block_size] / H_inv[i, i]).unsqueeze(0) + + new_beams.append({ + "q_vals": beam["q_vals"] + [q_choice], + "error_buffer": beam["error_buffer"][:, 1:] + push_error, + "cost": beam["cost"] + step_cost.sum().item() + }) + + # Prune to top-K (Beam Width) + new_beams.sort(key=lambda x: x["cost"]) + beams = new_beams[:beam_width] + + # Return the optimal sequence of standard INT4 values from the best path + return torch.stack(beams[0]["q_vals"], dim=1), beams[0]["error_buffer"] +``` + +*Why this works:* This drops your quantization error (KL divergence) significantly at low bitrates without touching the inference engine. Your pipeline outputs a perfectly standard `qweight` tensor, but the values inside it have been globally optimized rather than locally guessed. + +#### Mechanism 2: Hessian-Trace Lagrangian Bit Allocation (The EXL2 Secret) +Instead of using your probe to measure RTN error and asking an LLM to guess bitrates, you should measure the **Hessian Trace** of each layer (or group) during calibration. + +The formula EXL2 uses to allocate bits minimizes $\sum \text{Trace}(H) \cdot \text{Error}(b)$ subject to $\sum b \le \text{Budget}$. This is solved deterministically via Binary Search: + +``` +python +import math + +def exl2_allocate_bits(layers, target_bpw): + """ + layers: list of dicts with 'params' (count) and 'hessian_trace' (float) + target_bpw: e.g. 4.5 + """ + total_params = sum(L['params'] for L in layers) + target_bits = total_params * target_bpw + + # Available standard GPTQ bit choices (can be mixed layer-by-layer) + bit_options =[4, 8] # You can add 3, 5, 6 if your kernel supports it + + # Define an error function for a bit width (empirically, error scales exponentially inversely with bits) + def expected_error(trace, bits): + return trace * (2.0 ** (-2.0 * bits)) + + # Binary search for the Lagrangian multiplier (lambda) + lambda_low, lambda_high = 0.0, 1e9 + + for _ in range(60): + lambda_mid = (lambda_low + lambda_high) / 2 + + current_bits = 0 + for L in layers: + best_bits = None + best_cost = float('inf') + + # Pick the bit width that minimizes: Error + lambda * Bits + for b in bit_options: + cost = expected_error(L['hessian_trace'], b) + lambda_mid * b * L['params'] + if cost < best_cost: + best_cost = cost + best_bits = b + + L['assigned_bits'] = best_bits + current_bits += best_bits * L['params'] + + if current_bits > target_bits: + lambda_low = lambda_mid # Budget exceeded, increase penalty + else: + lambda_high = lambda_mid + + return layers +``` + +### Summary of How to Refactor Your Architecture: + +1. **Delete the LLM Judge / Probe step:** RTN fake-quant is misleading, and the LLM is mathematically inferior to a solver. +2. **One-Pass Pipeline:** Run calibration data through the model. Hook into each layer to capture its inputs. Calculate the `Hessian = X.T @ X`. +3. **Save the Hessian Traces:** Store the trace of each layer's Hessian. +4. **Determine the Plan deterministically:** Feed the traces into the EXL2 Lagrangian solver above to figure out exactly which layers get 4-bit, 8-bit, or `lm_head` 6-bit. +5. **Apply Trellis Quantization:** When you actually call your GPTQ quantizer, inject the Trellis/Beam Search algorithm over the standard INT grid to finalize the `qweights`. + +By making these changes, your "Software-Defined" pipeline evolves from a heuristic guesser into a mathematically rigorous engine that generates **SOTA-accuracy standard GPTQ checkpoints**, directly capturing the benefits of EXL2/EXL3's research while remaining 100% compatible with existing vLLM/AutoGPTQ infrastructure. + +--- + +### ChatGPT: +Your critique is mostly right, with one big caveat: + +**You can borrow some EXL2/EXL3 ideas for a standard GPTQ pipeline, but you cannot get “full EXL2/EXL3” while staying in plain FORMAT.GPTQ.** EXL2’s main win is variable bitrate **within** a linear layer, and ExLlamaV2 explicitly says EXL2 can mix quantization levels inside each linear layer to hit arbitrary average bitrates between 2 and 8 bpw. EXL3 is a QTIP-style trellis/codebook format, and ExLlamaV3’s own docs say EXL3 is a variant of QTIP with its own packing and kernels. Standard GPTQ runtimes instead expect a conventional GPTQ checkpoint with a fixed `weight_bits`/packing scheme per quantized tensor path. [GitHub+2GitHub+2](https://github.com/turboderp-org/exllamav2?utm_source=chatgpt.com) + +So I’d split your proposal into two categories: + +**Worth doing inside standard GPTQ** + +- Replace the LLM judge with a deterministic allocator. +- Use Hessian-derived layer importance instead of RTN/AWQ-proxy MSE as the primary allocation signal. +- Improve the rounding/search step during quantization, as long as the emitted checkpoint is still ordinary GPTQ `qweight/scales/zeros` with supported bitwidths. [arXiv+2arXiv+2](https://arxiv.org/pdf/2210.17323?utm_source=chatgpt.com) + +**Not portable to plain GPTQ without changing format/kernel** + +- EXL2-style variable bitrates **inside the same tensor** +- EXL3/QTIP procedural codebooks, incoherence processing, and trellis-packed representations +- 6-bit EXL2 `lm_head` in a “standard GPTQ everywhere” deployment path, because current vLLM GPTQ support is documented around 2/3/4/8-bit GPTQ weights, not 6-bit GPTQ. EXL2 supports 6-bit, but that is an EXL2 feature, not a generic GPTQ interchange guarantee. [GitHub+2GitHub+2](https://github.com/turboderp-org/exllamav2?utm_source=chatgpt.com) + +I also agree with your criticism of the current SDQ v4 pipeline: **if the goal is best accuracy under a hard memory budget, an LLM judge is the wrong optimizer.** ExLlamaV2’s own converter is already a strong precedent here: it does a measurement pass first, then chooses quantization parameters to minimize overall error while meeting a target average bitrate. That is fundamentally a measurement + solver workflow, not an instruction-following workflow. [GitHub](https://github.com/turboderp-org/exllamav2/blob/master/doc/convert.md) + +## What I would actually change + +### 1) Remove the LLM from the decision loop + +Keep the “SDQ shell” if you like the architecture, but replace the judge with a deterministic allocator. + +Use: + +- one calibration pass +- per-layer or per-module Hessian statistics +- a constrained solver that assigns from a small supported action set + +For a **standard GPTQ-compatible** target, the practical action set is: + +- `4-bit` +- `8-bit` +- `skip` +- optionally `2-bit` or `3-bit` only if your intended runtime really supports those paths end-to-end + +That is because vLLM’s GPTQ path documents support for 2/3/4/8-bit GPTQ weights, and GPTQModel’s `dynamic` override surface is per-module, not per-column-inside-a-tensor. [vLLM+2GitHub+2](https://docs.vllm.ai/en/latest/api/vllm/model_executor/layers/quantization/gptq/) + +### 2) Replace the probe with Hessian-derived importance + +Your “probe fallacy” point is right. GPTQ is explicitly a second-order method using approximate Hessian information, so RTN-style proxy error is not the same objective. A better standard-GPTQ-compatible controller is: + +- collect input activations `X` for each target layer/module +- estimate a Hessian proxy such as `H = XᵀX / N` +- use `trace(H)`, `diag(H)`, or a small blockwise approximation as the layer/group importance score +- allocate bits with a solver under a fixed memory budget + +That is much closer to GPTQ’s actual error model than RTN/AWQ-proxy MSE. [arXiv+2arXiv+2](https://arxiv.org/pdf/2210.17323?utm_source=chatgpt.com) + +### 3) Use a deterministic allocator, not an LLM + +Your Lagrangian / knapsack framing is the correct direction. + +A practical version for **plain GPTQ compatibility** is: + +- estimate each layer’s expected distortion at `b ∈ {4, 8}` from Hessian-weighted quantization error +- solve for the best per-layer assignment under a target extra-memory budget +- compile the result into `QuantizeConfig.dynamic` + +That maps cleanly onto GPTQModel/vLLM’s documented per-module dynamic overrides and skip rules. [GitHub+1](https://github.com/ModelCloud/GPTQModel) + +### 4) Keep trellis/beam search only if the output stays ordinary GPTQ + +This part is subtle. + +Your proposed beam-search rounding can be useful **if**: + +- you still quantize onto the ordinary scalar INT grid +- you still emit standard GPTQ-compatible `qweight/scales/zeros` +- you do not introduce EXL3/QTIP-specific packing, transforms, or codebooks + +In that case, you are not “making GPTQ into EXL3”; you are just replacing a greedy local rounding step with a better search step before writing a normal GPTQ checkpoint. That can be compatible in principle. But the moment you rely on EXL3/QTIP’s codebook/trellis format or EXL2’s mixed-bit encoding inside one tensor, you have left plain GPTQ interchange behind. [GitHub+3arXiv+3GitHub+3](https://arxiv.org/abs/2406.11235?utm_source=chatgpt.com) + +### 5) Do not plan around a 6-bit lm_head if the target is “standard GPTQ everywhere” + +This is one place where your proposal overreaches. + +ExLlamaV2’s converter explicitly exposes special `lm_head` bit choices including 6 and 8, and even notes that 6/8 are the useful options there. But that is **EXL2 tooling**, not the standard GPTQ contract you can assume across GPTQModel + vLLM GPTQ. vLLM’s GPTQ path documents 2/3/4/8-bit GPTQ support, not 6-bit GPTQ. So if you want cross-runtime standard GPTQ compatibility, plan around: + +- `lm_head` skip +- or `lm_head` 8-bit GPTQ +- not 6-bit EXL2-style `lm_head` as a general deployment assumption. [GitHub+1](https://github.com/turboderp-org/exllamav2/blob/master/doc/convert.md) + +### 6) Be careful with the “100% compatible with existing vLLM/AutoGPTQ infrastructure” claim + +That needs qualification. + +GPTQModel does support dynamic per-module mixed quantization and vLLM integration, and vLLM documents GPTQModel-based 4-bit and 8-bit GPTQ models. But there are kernel-path caveats: for example, vLLM’s `gptq_marlin` code explicitly rejects 8-bit weights in the modular-kernel path shown in the docs. GPTQModel also documents dynamic quantization as per-module override, not EXL2-style intra-layer mixed bits. So “compatible” is true for standard GPTQ-style outputs in supported configurations, but not as a blanket statement for every EXL2/EXL3 idea. [vLLM+2vLLM+2](https://docs.vllm.ai/en/latest/features/quantization/gptqmodel/) + +## The refactor I’d recommend + +If your real goal is **best possible quality while staying in standard GPTQ**, I would turn v4 into this: + +**Stage A — Measurement** + +- Load the FP model with `transformers` +- Collect calibration activations +- Compute per-layer Hessian proxy stats +- Optionally still keep your activation-aware fake-quant probe as a secondary diagnostic, but not as the allocator’s main signal + +**Stage B — Allocation** + +- Solve a deterministic budgeted assignment for each module: + +`w4_g128` + +`w8_g128` + +`skip` +- `w4_g128` +- `w8_g128` +- `skip` +- Treat `lm_head` as either: + +structural `skip`, or + +budgetable `8-bit`, depending deployment target +- structural `skip`, or +- budgetable `8-bit`, depending deployment target + +**Stage C — Quantization** + +- Quantize with GPTQ/GPTQModel +- If you modify the rounding step, keep the emitted checkpoint in ordinary GPTQ format + +**Stage D — Evaluation** + +- Compare: + +baseline GPTQ + +deterministic Hessian allocator + +optional improved-rounding GPTQ +- baseline GPTQ +- deterministic Hessian allocator +- optional improved-rounding GPTQ +- Evaluate with `lm-eval`, which supports transformers-loaded GPTQModel/AutoGPTQ models and also supports vLLM. [GitHub](https://github.com/EleutherAI/lm-evaluation-harness?utm_source=chatgpt.com) + +## Bottom line + +Your critique is strongest on these points: + +- RTN/AWQ-proxy MSE is the wrong main objective for GPTQ-style allocation +- LLM judging is inferior to a solver for bitrate allocation +- EXL2/EXL3 quality comes partly from **format-level** advantages, not just better heuristics + +Where I’d push back is the claim that you can simply “capture EXL2/EXL3 while remaining 100% standard GPTQ-compatible.” You can capture **some** of the value: + +- better measurement +- better bit allocation +- possibly better search/rounding + +But you cannot capture the full value of: + +- EXL2’s intra-tensor mixed bitrates +- EXL3/QTIP trellis/codebook packing + +without leaving plain GPTQ format and plain GPTQ kernels behind. [GitHub+2GitHub+2](https://github.com/turboderp-org/exllamav2?utm_source=chatgpt.com) + +The best next step is to rip out the LLM judge entirely and replace it with a Hessian-trace allocator. After that, if you want, I’ll write you a **v5 refactor** that does exactly this: + +- Hessian measurement pass +- deterministic per-layer 4/8/skip allocator +- standard GPTQModel export +- optional beam-search rounding hook +- honest benchmark suite against your current v4 baseline + +--- + +### User: +You are completely right on all fronts. This is a brilliant clarification of the boundary between **format-level innovations** (EXL2 mixed intra-tensor bits, EXL3 QTIP codebooks) and **algorithmic optimization** (better measurement and allocation) that can still output a standard `qweight/zeros/scales` GPTQ checkpoint. + +You cannot force vLLM to read 6-bit EXL2 `lm_head`s through the `gptq_marlin` backend, but you *can* emit a mathematically superior 4/8/16-bit GPTQ checkpoint by replacing RTN + LLM judging with Hessian Traces + Lagrangian constraints. + +As requested, here is the **SDQ Quant v5** refactor in a single self-extracting `.sh` script. It removes the LLM API entirely, implements a fast Hessian-trace calibration pass, and uses a deterministic Binary Search / Lagrangian multiplier solver to optimally fill your VRAM budget. + +### Run this to generate the v5 project: + +``` +bash +mkdir -p sdq_quant_v5/quant_v5 +cd sdq_quant_v5 + +cat > README.md <<'EOF' +# SDQ Quant v5 (Hessian-Trace + Lagrangian Allocator) + +A deterministic, mathematically rigorous Software-Defined Quantization pipeline producing 100% standard GPTQ checkpoints for vLLM / GPTQModel. + +## Core Upgrades in v5 +1. **Hessian-Trace Measurement:** Captures true second-order sensitivity (`trace(X^T X)`) rather than misleading RTN/AWQ proxies. +2. **Deterministic Solver:** Replaces the LLM Judge with a Lagrangian Multiplier (Knapsack) solver to mathematically minimize model distortion under a strict megabyte budget. +3. **Standard GPTQ Formats:** Outputs strict 4-bit, 8-bit, or skipped (FP16) tensor paths ensuring seamless deployment in vLLM. +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +torch +gptqmodel +lm-eval +tqdm +EOF + +cat > quant_v5/__init__.py <<'EOF' +"""SDQ Quant v5: Hessian-Trace Allocator""" +EOF + +cat > quant_v5/actions.py <<'EOF' +from dataclasses import dataclass +from typing import Dict, List, Optional + +@dataclass(frozen=True) +class ActionSpec: + name: str + bits: int + group_size: Optional[int] + skip: bool + +ACTION_SPECS: Dict[str, ActionSpec] = { + "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), + "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), + "skip": ActionSpec("skip", bits=16, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param(action: str) -> float: + if action == "skip": + return 2.0 # FP16 + spec = ACTION_SPECS[action] + base = spec.bits / 8.0 + # Add GPTQ overhead: scale (16-bit) and zero (4-bit or 8-bit depending on sym, let's average ~2.5 bytes per group) + overhead = 2.5 / spec.group_size + return base + overhead + +def action_to_override(action: str) -> dict: + if action == "skip": + return {} + spec = ACTION_SPECS[action] + return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_v5/measurement.py <<'EOF' +import json +import torch +from pathlib import Path +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset +from .actions import bytes_per_param, ACTION_SPECS, BASE_ACTION + +def load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): + ds = load_dataset(ds_name, ds_config, split=split) + texts =[] + for row in ds: + if row.get(text_field) and row[text_field].strip(): + texts.append(row[text_field].strip()) + if len(texts) >= num_texts: break + return texts + +def measure_hessian_traces(args): + """ + Computes Trace(H) ≈ sum(mean(X^2, dim=0)) across calibration samples. + """ + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + args.model_id, device_map="auto", torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + model.eval() + + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts) + + # Identify targets + include_mods =[x.strip() for x in args.include_modules.split(',')] + targets =[(n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(i in n for i in include_mods)] + + # We will accumulate the sum of squared inputs per column + hessian_traces = {name: 0.0 for name, _ in targets} + row_counts = {name: 0 for name, _ in targets} + params_count = {name: m.weight.numel() for name, m in targets} + + handles =[] + def make_hook(name): + def hook(_, inp, __): + x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) # [batch*seq, in_features] + # Trace of X^T X is just the sum of squared elements + hessian_traces[name] += (x ** 2).sum().item() + row_counts[name] += x.shape[0] + return hook + + for name, module in targets: + handles.append(module.register_forward_hook(make_hook(name))) + + # Forward passes + with torch.no_grad(): + for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): + batch = tokenizer(texts[i:i+args.batch_size], return_tensors="pt", truncation=True, max_length=args.max_length, padding=True) + batch = {k: v.to(model.device) for k, v in batch.items()} + model(**batch) + + for h in handles: h.remove() + + # Normalize by row count and save + features =[] + for name, _ in targets: + avg_trace = hessian_traces[name] / max(1, row_counts[name]) + features.append({ + "name": name, + "params": params_count[name], + "hessian_trace": avg_trace + }) + + out_file = out_dir / "hessian_features.json" + out_file.write_text(json.dumps(features, indent=2)) + print(f"Saved {len(features)} layer traces to {out_file}") + + return features +EOF + +cat > quant_v5/allocator.py <<'EOF' +import json +import re +from pathlib import Path +from .actions import bytes_per_param, ACTION_SPECS, BASE_ACTION + +def solve_allocation(features_path: str, budget_mb_extra: float): + """ + Lagrangian multiplier solver to minimize quantization error subject to a memory budget. + """ + features = json.loads(Path(features_path).read_text()) + + # Calculate base size (everything in BASE_ACTION) + base_bytes_per_param = bytes_per_param(BASE_ACTION) + total_base_mb = sum(f["params"] * base_bytes_per_param for f in features) / (1024**2) + target_mb = total_base_mb + budget_mb_extra + + actions = list(ACTION_SPECS.keys()) + + # Expected Error Model: Error proportional to Trace * 2^(-2 * bits) + def expected_error(trace: float, action: str) -> float: + if action == "skip": return 0.0 # FP16 = 0 quantization error + bits = ACTION_SPECS[action].bits + return trace * (2.0 ** (-2.0 * bits)) + + def action_cost_mb(params: int, action: str) -> float: + return (params * bytes_per_param(action)) / (1024**2) + + # Binary search for Lambda (penalty for size) + lambda_low, lambda_high = 0.0, 1e12 + best_plan = None + + for _ in range(80): + lambda_mid = (lambda_low + lambda_high) / 2.0 + current_mb = 0.0 + plan =[] + + for f in features: + best_action = None + best_cost = float('inf') + + for act in actions: + err = expected_error(f["hessian_trace"], act) + size_mb = action_cost_mb(f["params"], act) + + # Objective: Minimize (Error + Lambda * Size) + obj = err + lambda_mid * size_mb + if obj < best_cost: + best_cost = obj + best_action = act + + plan.append({ + "name": f["name"], + "action": best_action, + "params": f["params"], + "hessian_trace": f["hessian_trace"], + "size_mb": action_cost_mb(f["params"], best_action) + }) + current_mb += action_cost_mb(f["params"], best_action) + + best_plan = plan + if current_mb > target_mb: + lambda_low = lambda_mid # Over budget, penalize size more + else: + lambda_high = lambda_mid # Under budget, penalize size less + + used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb + + # Convert to GPTQModel Dynamic config dict + dynamic = {} + from .actions import action_to_override + for p in best_plan: + if p["action"] == BASE_ACTION: continue + regex = f"^{re.escape(p['name'])}$" + if p["action"] == "skip": + dynamic[f"-:{regex}"] = {} + else: + dynamic[f"+:{regex}"] = action_to_override(p["action"]) + + return {"plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb} +EOF + +cat > quant_v5/runner.py <<'EOF' +import json +import subprocess +import sys +from pathlib import Path +import torch +from gptqmodel import GPTQModel, QuantizeConfig +from .measurement import load_calibration_texts + +def run_quantize(args, dynamic_config): + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + qcfg = QuantizeConfig( + bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config + ) + + print(f"Loading Model for Quantization with {len(dynamic_config) if dynamic_config else 0} dynamic rules...") + model = GPTQModel.load(args.model_id, qcfg) + model.quantize(texts, batch_size=args.quant_batch_size) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + model.save(str(out)) + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_eval(model_path: str, tasks: str, out_file: str): + Path(out_file).parent.mkdir(parents=True, exist_ok=True) + cmd =[ + sys.executable, "-m", "lm_eval", + "--model", "hf", + "--model_args", f"pretrained={model_path},trust_remote_code=True", + "--tasks", tasks, + "--device", "cuda:0", + "--batch_size", "auto", + "--output_path", out_file + ] + subprocess.run(cmd, check=True) +EOF + +cat > quant_v5/pipeline.py <<'EOF' +import argparse +import json +from pathlib import Path +from .measurement import measure_hessian_traces +from .allocator import solve_allocation +from .runner import run_quantize, run_eval + +def main(): + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + suite = sub.add_parser("suite") + suite.add_argument("--model-id", required=True) + suite.add_argument("--work-dir", required=True) + suite.add_argument("--budget-mb-extra", type=float, default=250.0) + suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") + suite.add_argument("--dataset-name", default="wikitext") + suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") + suite.add_argument("--split", default="train") + suite.add_argument("--text-field", default="text") + suite.add_argument("--num-texts", type=int, default=128) + suite.add_argument("--max-length", type=int, default=2048) + suite.add_argument("--batch-size", type=int, default=4) + suite.add_argument("--quant-num-texts", type=int, default=128) + suite.add_argument("--quant-batch-size", type=int, default=4) + suite.add_argument("--tasks", default="arc_challenge,hellaswag") + args = p.parse_args() + + work_dir = Path(args.work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + + print("\n--- STAGE 1: Hessian Trace Measurement ---") + args.out_dir = str(work_dir / "probe") + features = measure_hessian_traces(args) + + print("\n--- STAGE 2: Lagrangian Allocation ---") + alloc = solve_allocation(str(work_dir / "probe" / "hessian_features.json"), args.budget_mb_extra) + (work_dir / "sdq_dynamic.json").write_text(json.dumps(alloc["dynamic"], indent=2)) + print(f"Allocated! Used {alloc['used_extra_mb']:.2f} MB extra.") + + print("\n--- STAGE 3A: Quantize Baseline (Pure 4-bit) ---") + args.out_dir = str(work_dir / "models" / "baseline") + run_quantize(args, dynamic_config=None) + + print("\n--- STAGE 3B: Quantize SDQ v5 (Hessian-Lagrangian) ---") + args.out_dir = str(work_dir / "models" / "sdq_v5") + run_quantize(args, dynamic_config=alloc["dynamic"]) + + print("\n--- STAGE 4: Evaluation ---") + run_eval(str(work_dir / "models" / "baseline"), args.tasks, str(work_dir / "eval" / "baseline.json")) + run_eval(str(work_dir / "models" / "sdq_v5"), args.tasks, str(work_dir / "eval" / "sdq_v5.json")) + + print("\n--- Pipeline Complete! See eval/ folder for results. ---") + +if __name__ == "__main__": + main() +EOF +``` + +### Setup & Run + +``` +bash +python -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt + +python -m quant_v5.pipeline suite \ + --model-id Qwen/Qwen3.5-4B-Base \ + --work-dir ./runs/v5_suite \ + --budget-mb-extra 200 \ + --tasks arc_challenge,hellaswag +``` + +### What this solves perfectly: +1. **The Allocator is Instant & Optimal:** No LLM generation delays. The Binary Search Lagrangian solver loops 80 times locally in microseconds to find the mathematically ideal bit-distribution curve. +2. **True Objective Function:** It uses $X^T X$ (Hessian Diagonal/Trace) directly from captured activations. This precisely models how errors will propagate *after* GPTQ's Hessian-inverse compensation, which RTN fake-quant completely misses. +3. **vLLM Compatibility Guarantees:** Because the output `dynamic_config` targets pure module names and assigns standard `w8_g128`, `w4_g128`, or structural fp16 skips, the resulting checkpoint will load directly in vLLM without needing a customized branch. + +### A Note on the "Beam Search" Rounding Step + +To implement the Trellis/Beam-Search rounding (*improving* the quantization loop itself without changing the standard GPTQ output structure), you would need to monkeypatch `gptqmodel.quantization.gptq.GPTQ.quantize`. + +Because GPTQModel updates its internal `gptq.py` aggressively (to support Marlin, Triton, etc.), a dirty monkeypatch in a standalone script will break frequently. The absolute best way to capture this final piece of EXL3 quality is to open a PR against `GPTQModel` adding a `use_beam_search=True` flag directly to their native `QuantizeConfig`, replacing their greedy `q = torch.round(w / scale)` loop with the Lookahead Trellis logic we discussed. + +--- + +### ChatGPT: +This v5 is **much better** than v4. The big architectural shift is correct: drop the LLM judge, keep the output in standard GPTQ, and let a deterministic allocator decide per-module `4-bit / 8-bit / skip`. That matches what GPTQModel and vLLM actually support today: GPTQModel exposes per-module `Dynamic` overrides, and vLLM’s GPTQ path supports standard GPTQ weights at 2/3/4/8 bits, not arbitrary EXL2-style mixed intra-tensor bitrates. [GitHub+2vLLM+2](https://github.com/ModelCloud/GPTQModel) + +The part I would tone down is the claim that this “precisely models” GPTQ. GPTQ is a second-order method that uses approximate Hessian information during quantization, but your `trace(XᵀX)` scalar is still only a **coarse importance proxy**. It throws away the off-diagonal structure and the blockwise inverse-Hessian behavior that GPTQ actually exploits. So this is a **better allocator signal than RTN+LLM**, not a faithful simulation of GPTQ error propagation. [arXiv](https://arxiv.org/pdf/2210.17323?utm_source=chatgpt.com) + +What is solid: + +- Per-module dynamic overrides are the right emission target. GPTQModel explicitly supports mixed per-module quantization and module exclusion, and vLLM integrates that path. [GitHub+2vLLM+2](https://github.com/ModelCloud/GPTQModel) +- Restricting the action set to `w4_g128`, `w8_g128`, and `skip` is realistic for standard GPTQ deployment, because vLLM’s GPTQ config documents support for 2/3/4/8-bit weights only. [vLLM](https://docs.vllm.ai/en/latest/api/vllm/model_executor/layers/quantization/gptq/) +- Replacing the judge with a deterministic solver is the right direction. Recent work on Hessian-proxy-based bit allocation is consistent with that design choice, even though your exact objective is still heuristic. [arXiv+1](https://arxiv.org/abs/2506.05664?utm_source=chatgpt.com) + +The must-fix issues in the script: + +### 1) model.device with device_map="auto" is fragile + +In `measurement.py`, you move tokenized batches with `v.to(model.device)`. With `device_map="auto"`, that is not a reliable contract. For a 1×3090 workflow, load the FP model explicitly onto GPU 0 instead of sharding logic, or derive the input device from the embedding layer / first parameter. Otherwise you can get subtle placement errors. + +Use something like: + +```Pythondevice = next(model.parameters()).devicebatch = {k: v.to(device) for k, v in batch.items()} +Pythondevice = next(model.parameters()).devicebatch = {k: v.to(device) for k, v in batch.items()} +``` + +or load with a single-device map. + +### 2) run_eval() is using the CLI in a brittle way + +The lm-eval docs currently recommend `simple_evaluate()` as the primary programmatic entry point. Your CLI call with `--output_path` pointing at a file path is more brittle than calling the API directly and writing JSON yourself. [GitHub](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/python-api.md) + +I would replace `runner.run_eval` with: + +```Pythonimport jsonimport lm_evaldef run_eval(model_path: str, tasks: str, out_file: str): results = lm_eval.simple_evaluate( model="hf", model_args=f"pretrained={model_path},trust_remote_code=True", tasks=[t.strip() for t in tasks.split(",") if t.strip()], batch_size="auto", device="cuda:0", ) Path(out_file).parent.mkdir(parents=True, exist_ok=True) Path(out_file).write_text(json.dumps(results, indent=2), encoding="utf-8") return results +Pythonimport jsonimport lm_evaldef run_eval(model_path: str, tasks: str, out_file: str): results = lm_eval.simple_evaluate( model="hf", model_args=f"pretrained={model_path},trust_remote_code=True", tasks=[t.strip() for t in tasks.split(",") if t.strip()], batch_size="auto", device="cuda:0", ) Path(out_file).parent.mkdir(parents=True, exist_ok=True) Path(out_file).write_text(json.dumps(results, indent=2), encoding="utf-8") return results +``` + +That is closer to the harness’ recommended usage. [GitHub](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/python-api.md) + +### 3) Your cost model is still approximate + +`bytes_per_param()` uses a fixed `2.5 / group_size` overhead. That is acceptable as an estimator, but it is **not** guaranteed to match the exact storage behavior of the backend/kernel path you will deploy. So the allocator’s memory budget is an approximation, not a hard kernel-true guarantee. + +I would label it explicitly in the README as: + +- “planner memory estimate” +- not “exact runtime VRAM model” + +### 4) skip is still global, not structural + +You fixed the v4 `lm_head` problem conceptually, but this v5 allocator no longer treats `lm_head` specially. Because `skip` is included in the solver action set for every module, the optimizer can choose FP16 for arbitrary modules if the scalar objective likes it. That is mathematically legal, but operationally messy. + +I’d add a policy layer: + +- allow `skip` only for `lm_head` +- allow `w8_g128` / `w4_g128` for everything else + +That keeps the plan sane and easier to deploy. + +### 5) The error model is heuristic, not proven-optimal + +This line: + +```Pythonreturn trace * (2.0 ** (-2.0 * bits)) +Pythonreturn trace * (2.0 ** (-2.0 * bits)) +``` + +is a useful surrogate, but it is still a hand-designed error law. It is not something GPTQModel or GPTQ itself guarantees. So the solver is “deterministic and principled,” but not “mathematically exact.” + +I would rename: + +- `expected_error` → `proxy_error` +- `solve_allocation` → `solve_proxy_allocation` + +That makes the claims honest. + +### 6) The pipeline is not actually “self-extracting” + +The content is a shell recipe that writes files, not a single self-extracting script artifact. That is fine, but the label is a bit off. + +## My verdict + +**Keep this direction.** It is the right architecture for “best possible standard GPTQ under a budget.” + +I would call it: + +- **architecturally correct** +- **substantially better than v4** +- **not yet mathematically faithful enough to claim GPTQ-optimal allocation** +- **close to production once the device handling, eval path, and skip policy are tightened** + +The clean v5.1 patch would do four things: + +1. force single-device measurement cleanly on GPU 0, +2. switch eval to `lm_eval.simple_evaluate()`, +3. restrict `skip` to `lm_head`, +4. rename the allocator/error model to “proxy” terminology so the claims match reality. [GitHub+2vLLM+2](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/python-api.md) + +Paste the next revision and I’ll review it line by line. + +--- + +### User: +You are absolutely right to pull me up on the terminology. Calling a coarse `trace(XᵀX)` a "precise model" of GPTQ is an overstatement. It is a principled **Hessian proxy** that correlates strongly with quantization sensitivity, but it ignores the inverse-Hessian block updates that are the actual mathematical heartbeat of GPTQ. + +### To answer your question about AutoRound: +**Yes. AutoRound is exactly the holy grail for your Stage C.** +AutoRound (and its predecessor SignRound) fundamentally solves the exact "Trellis/Beam-Search" rounding problem I mentioned earlier, but it does so via **SignSGD optimization**. +* **How it works:** Instead of greedy rounding or lookahead trees, it formulates rounding as a continuous machine-learning problem. It learns a tiny perturbation tensor (using Sign Gradient Descent) for the weights and scales to minimize the block-wise MSE of the activations. +* **Does it change the kernel?** **No!** This is its superpower. AutoRound exports a 100% standard `qweight, qzeros, scales` checkpoint. It runs flawlessly in vLLM's `gptq_marlin` backend, AutoGPTQ, and GPTQModel. +* **Is it worth it?** **Absolutely.** It achieves SOTA accuracy (beating plain GPTQ and AWQ), especially at lower bitrates, while maintaining total inference-side compatibility. If you want the absolute highest accuracy inside a plain GPTQ boundary, you use your Hessian-trace allocator to pick the bit-widths, and then you use AutoRound as the engine to actually crunch the integers. + +--- + +Here is the pristine **SDQ Quant v5.1** patch. + +It implements all your structural feedback: +1. Forces `cuda:0` cleanly. +2. Uses the native `lm_eval.simple_evaluate()` API. +3. Hard-restricts `skip` to the `lm_head`. +4. Accurately labels the memory cost as an estimator and the error as a proxy. + +### Run this to generate the v5.1 project: + +``` +bash +mkdir -p sdq_quant_v5_1/quant_v5_1 +cd sdq_quant_v5_1 + +cat > README.md <<'EOF' +# SDQ Quant v5.1 (Hessian-Proxy + Lagrangian Allocator) + +A deterministic Software-Defined Quantization pipeline producing standard GPTQ checkpoints for vLLM / GPTQModel. + +## Design +1. **Measurement:** Computes `trace(X^T X)` as a proxy for second-order sensitivity. +2. **Allocation:** Uses a Lagrangian solver to minimize a heuristic proxy error under a planner memory budget. + - `lm_head` is restricted to `skip` (FP16) or 8-bit. + - All other linear layers are restricted to 4-bit or 8-bit. +3. **Quantization:** Standard GPTQ export compatible with vLLM (`gptq_marlin`). +4. **Evaluation:** Direct programmatic usage of `lm-eval`. + +*Note: The memory budget is a planner estimate, not an exact runtime VRAM guarantee. The error model is a proxy, not a proven-optimal GPTQ simulation.* +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +torch +gptqmodel +lm-eval +tqdm +EOF + +cat > quant_v5_1/__init__.py <<'EOF' +"""SDQ Quant v5.1: Hessian-Proxy Allocator""" +EOF + +cat > quant_v5_1/actions.py <<'EOF' +from dataclasses import dataclass +from typing import Dict, Optional + +@dataclass(frozen=True) +class ActionSpec: + name: str + bits: int + group_size: Optional[int] + skip: bool + +ACTION_SPECS: Dict[str, ActionSpec] = { + "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), + "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), + "skip": ActionSpec("skip", bits=16, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param(action: str) -> float: + """Planner memory estimate. Not an exact runtime VRAM model.""" + if action == "skip": + return 2.0 # FP16 + spec = ACTION_SPECS[action] + base = spec.bits / 8.0 + # ~2.5 bytes overhead per group for scales/zeros + overhead = 2.5 / spec.group_size + return base + overhead + +def action_to_override(action: str) -> dict: + if action == "skip": + return {} + spec = ACTION_SPECS[action] + return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_v5_1/measurement.py <<'EOF' +import json +import torch +from pathlib import Path +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset + +def load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): + ds = load_dataset(ds_name, ds_config, split=split) + texts =[] + for row in ds: + if row.get(text_field) and row[text_field].strip(): + texts.append(row[text_field].strip()) + if len(texts) >= num_texts: break + return texts + +def measure_hessian_traces(args): + """ + Computes Trace(H) ≈ sum(mean(X^2, dim=0)) across calibration samples. + """ + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + args.model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ).to(device) + model.eval() + + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts) + + include_mods =[x.strip() for x in args.include_modules.split(',')] + targets =[(n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(i in n for i in include_mods)] + + hessian_traces = {name: 0.0 for name, _ in targets} + row_counts = {name: 0 for name, _ in targets} + params_count = {name: m.weight.numel() for name, m in targets} + + handles =[] + def make_hook(name): + def hook(_, inp, __): + x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) + hessian_traces[name] += (x ** 2).sum().item() + row_counts[name] += x.shape[0] + return hook + + for name, module in targets: + handles.append(module.register_forward_hook(make_hook(name))) + + with torch.no_grad(): + for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): + batch = tokenizer(texts[i:i+args.batch_size], return_tensors="pt", truncation=True, max_length=args.max_length, padding=True) + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + + for h in handles: h.remove() + + features =[] + for name, _ in targets: + avg_trace = hessian_traces[name] / max(1, row_counts[name]) + features.append({ + "name": name, + "params": params_count[name], + "hessian_trace": avg_trace + }) + + out_file = out_dir / "hessian_features.json" + out_file.write_text(json.dumps(features, indent=2)) + print(f"Saved {len(features)} layer traces to {out_file}") + + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + + return features +EOF + +cat > quant_v5_1/allocator.py <<'EOF' +import json +import re +from pathlib import Path +from .actions import bytes_per_param, ACTION_SPECS, BASE_ACTION + +def get_valid_actions(module_name: str) -> list[str]: + """Policy layer: restricts skips strictly to structural elements.""" + if "lm_head" in module_name: + return["skip", "w8_g128"] + return ["w4_g128", "w8_g128"] + +def solve_proxy_allocation(features_path: str, budget_mb_extra: float): + """ + Lagrangian multiplier solver to minimize heuristic proxy error. + """ + features = json.loads(Path(features_path).read_text()) + + base_bytes_per_param = bytes_per_param(BASE_ACTION) + total_base_mb = sum(f["params"] * base_bytes_per_param for f in features) / (1024**2) + target_mb = total_base_mb + budget_mb_extra + + def proxy_error(trace: float, action: str) -> float: + if action == "skip": return 0.0 + bits = ACTION_SPECS[action].bits + return trace * (2.0 ** (-2.0 * bits)) + + def action_cost_mb(params: int, action: str) -> float: + return (params * bytes_per_param(action)) / (1024**2) + + lambda_low, lambda_high = 0.0, 1e12 + best_plan = None + + for _ in range(80): + lambda_mid = (lambda_low + lambda_high) / 2.0 + current_mb = 0.0 + plan =[] + + for f in features: + best_action = None + best_cost = float('inf') + + valid_actions = get_valid_actions(f["name"]) + + for act in valid_actions: + err = proxy_error(f["hessian_trace"], act) + size_mb = action_cost_mb(f["params"], act) + + obj = err + lambda_mid * size_mb + if obj < best_cost: + best_cost = obj + best_action = act + + plan.append({ + "name": f["name"], + "action": best_action, + "params": f["params"], + "hessian_trace": f["hessian_trace"], + "size_mb": action_cost_mb(f["params"], best_action) + }) + current_mb += action_cost_mb(f["params"], best_action) + + best_plan = plan + if current_mb > target_mb: + lambda_low = lambda_mid + else: + lambda_high = lambda_mid + + used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb + + dynamic = {} + from .actions import action_to_override + for p in best_plan: + if p["action"] == BASE_ACTION: continue + regex = f"^{re.escape(p['name'])}$" + if p["action"] == "skip": + dynamic[f"-:{regex}"] = {} + else: + dynamic[f"+:{regex}"] = action_to_override(p["action"]) + + return {"plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb} +EOF + +cat > quant_v5_1/runner.py <<'EOF' +import json +from pathlib import Path +import torch +import lm_eval +from gptqmodel import GPTQModel, QuantizeConfig +from .measurement import load_calibration_texts + +def run_quantize(args, dynamic_config): + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + qcfg = QuantizeConfig( + bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config + ) + + print(f"Loading Model for Quantization with {len(dynamic_config) if dynamic_config else 0} dynamic rules...") + model = GPTQModel.load(args.model_id, qcfg) + model.quantize(texts, batch_size=args.quant_batch_size) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + model.save(str(out)) + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_eval(model_path: str, tasks: str, out_file: str): + print(f"Running evaluation on {model_path}...") + Path(out_file).parent.mkdir(parents=True, exist_ok=True) + + task_list =[t.strip() for t in tasks.split(",") if t.strip()] + + results = lm_eval.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path},trust_remote_code=True", + tasks=task_list, + batch_size="auto", + device="cuda:0" + ) + + serializable_results = { + "results": results.get("results", {}), + "versions": results.get("versions", {}) + } + + Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8") + return serializable_results +EOF + +cat > quant_v5_1/pipeline.py <<'EOF' +import argparse +import json +from pathlib import Path +from .measurement import measure_hessian_traces +from .allocator import solve_proxy_allocation +from .runner import run_quantize, run_eval + +def main(): + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + suite = sub.add_parser("suite") + suite.add_argument("--model-id", required=True) + suite.add_argument("--work-dir", required=True) + suite.add_argument("--budget-mb-extra", type=float, default=250.0) + suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") + suite.add_argument("--dataset-name", default="wikitext") + suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") + suite.add_argument("--split", default="train") + suite.add_argument("--text-field", default="text") + suite.add_argument("--num-texts", type=int, default=128) + suite.add_argument("--max-length", type=int, default=2048) + suite.add_argument("--batch-size", type=int, default=4) + suite.add_argument("--quant-num-texts", type=int, default=128) + suite.add_argument("--quant-batch-size", type=int, default=4) + suite.add_argument("--tasks", default="arc_challenge,hellaswag") + args = p.parse_args() + + work_dir = Path(args.work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + + print("\n--- STAGE 1: Hessian Proxy Measurement ---") + args.out_dir = str(work_dir / "probe") + features = measure_hessian_traces(args) + + print("\n--- STAGE 2: Lagrangian Proxy Allocation ---") + alloc = solve_proxy_allocation(str(work_dir / "probe" / "hessian_features.json"), args.budget_mb_extra) + (work_dir / "sdq_dynamic.json").write_text(json.dumps(alloc["dynamic"], indent=2)) + print(f"Allocated! Estimated extra VRAM footprint: {alloc['used_extra_mb']:.2f} MB.") + + print("\n--- STAGE 3A: Quantize Baseline (Pure 4-bit) ---") + args.out_dir = str(work_dir / "models" / "baseline") + run_quantize(args, dynamic_config=None) + + print("\n--- STAGE 3B: Quantize SDQ v5.1 ---") + args.out_dir = str(work_dir / "models" / "sdq_v5_1") + run_quantize(args, dynamic_config=alloc["dynamic"]) + + print("\n--- STAGE 4: Evaluation ---") + run_eval(str(work_dir / "models" / "baseline"), args.tasks, str(work_dir / "eval" / "baseline.json")) + run_eval(str(work_dir / "models" / "sdq_v5_1"), args.tasks, str(work_dir / "eval" / "sdq_v5_1.json")) + + print("\n--- Pipeline Complete! See eval/ folder for results. ---") + +if __name__ == "__main__": + main() +EOF +``` + +--- + +### ChatGPT: +This v5.1 is the first version I’d actually build on. The big moves are right: + +- dropping the LLM judge in favor of a deterministic allocator, +- staying inside standard GPTQ-compatible actions, +- switching eval to `lm_eval.simple_evaluate()`, +- and forcing a single-device measurement path on `cuda:0`. + +That direction matches the current constraints of the stack: GPTQModel is meant to emit standard GPTQ checkpoints with per-module dynamic overrides, and vLLM’s GPTQ path only documents support for 2/3/4/8-bit GPTQ weights, not EXL2-style arbitrary intra-tensor bitrates. [GitHub+2vLLM+2](https://github.com/ModelCloud/GPTQModel?utm_source=chatgpt.com) + +On AutoRound: yes, that is the strongest Stage C upgrade if your goal is “better quantization math, same deployment boundary.” AutoRound’s own docs say it can export `auto_gptq` format, supports 2/3/4/8-bit schemes there, and its “best” recipe is explicitly accuracy-first though slower than the standard recipe. That makes it a very good candidate for replacing plain greedy GPTQ rounding while still targeting a standard GPTQ-style artifact. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +There is one important logic bug left in v5.1: `lm_head` is restricted to `["skip", "w8_g128"]`, but your baseline memory target is still computed from `BASE_ACTION = "w4_g128"` for **all** modules. That means the solver is forced to spend extra budget on `lm_head` even when the requested extra budget is `0`, because `lm_head` is no longer allowed to stay at the base action. So the policy restriction and the budget baseline are inconsistent. The fix is to either allow `lm_head` to stay at `w4_g128`, or move `lm_head` into a structural-policy path outside the optimizer. Without that, the allocator is biased from the start. This follows directly from your code structure and from the fact that GPTQ runtimes only expose discrete bit choices, so your baseline/action set has to be internally consistent. [vLLM+1](https://docs.vllm.ai/en/latest/api/vllm/model_executor/layers/quantization/gptq/?utm_source=chatgpt.com) + +I’d patch `get_valid_actions()` like this: + +```Pythondef get_valid_actions(module_name: str) -> list[str]: if "lm_head" in module_name: return ["w4_g128", "w8_g128", "skip"] return ["w4_g128", "w8_g128"] +Pythondef get_valid_actions(module_name: str) -> list[str]: if "lm_head" in module_name: return ["w4_g128", "w8_g128", "skip"] return ["w4_g128", "w8_g128"] +``` + +Then, if you still want a “never leave lm_head at 4-bit” policy, enforce that as a separate post-pass or structural rule rather than inside the base knapsack. + +Second, the README claim about “seamless deployment in vLLM (`gptq_marlin`)" is a little too strong. vLLM does document GPTQModel-based 4-bit and 8-bit GPTQ models, and the GPTQ layer docs enforce 2/3/4/8-bit support, but that is not the same thing as a blanket guarantee that every mixed 4/8/skip recipe will go through every backend path identically. I’d reword that to “targets standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel/vLLM deployment.” [vLLM+1](https://docs.vllm.ai/en/latest/api/vllm/model_executor/layers/quantization/gptq/?utm_source=chatgpt.com) + +Third, your `trace(XᵀX)` allocator is a good **proxy allocator**, not a GPTQ simulator. I’d keep the new names you introduced—“Hessian proxy” and “proxy allocation”—because that is honest. GPTQ itself is second-order and uses blockwise inverse-Hessian-style compensation; your scalar trace objective is still a lossy summary. It is better than RTN+LLM, but it is not exact GPTQ error prediction. [vLLM+1](https://docs.vllm.ai/en/latest/api/vllm/model_executor/layers/quantization/gptq/?utm_source=chatgpt.com) + +My recommendation: + +- keep v5.1 as the control-plane architecture, +- fix the `lm_head` baseline inconsistency, +- soften the deployment wording, +- then make Stage C pluggable: + +`backend = gptq_plain` + +`backend = autoround_auto_gptq` +- `backend = gptq_plain` +- `backend = autoround_auto_gptq` + +That gives you a clean experiment matrix: + +- baseline GPTQ +- Hessian-proxy allocator + plain GPTQ +- Hessian-proxy allocator + AutoRound exported as `auto_gptq` + +That is the comparison that will tell you whether the allocator is helping, whether AutoRound is helping, or both. AutoRound’s official docs are already aligned with that export story. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +So overall: **v5.1 is good, but not done**. The `lm_head` action-set bug is the main blocker. After that, I’d move straight to an AutoRound-backed v5.2. + +--- + +### User: +You are absolutely right. The `lm_head` constraints conflicting with the baseline knapsack budget is a classic mathematical footgun. By forcing `lm_head` out of the base `w4_g128` action without accounting for it in the base VRAM calculation, the solver was mathematically doomed at a `0` extra budget. + +Also, introducing **AutoRound** is the perfect capstone. It cleanly separates the two domains: +1. **The Allocation Domain:** Handled by our Hessian-proxy Lagrangian solver (deciding the `bits`/`skip` map). +2. **The Rounding Domain:** Handled by either standard GPTQ (greedy) or AutoRound (SignSGD), both exporting standard, vLLM-compatible `auto_gptq` artifacts. + +Here is the pristine **SDQ Quant v5.2**. + +It implements the `lm_head` constraint fix, updates the README wording to accurately reflect vLLM compatibility realities, and introduces a pluggable Stage C that tests a full 3-way matrix: +* `baseline_gptq` (GPTQModel, pure 4-bit, greedy) +* `sdq_gptq` (GPTQModel, Hessian-proxy 4/8/skip, greedy) +* `sdq_autoround` (AutoRound, Hessian-proxy 4/8/skip, SignSGD rounding exported to `auto_gptq`) + +### Run this to generate the v5.2 project: + +``` +bash +mkdir -p sdq_quant_v5_2/quant_v5_2 +cd sdq_quant_v5_2 + +cat > README.md <<'EOF' +# SDQ Quant v5.2 (Hessian-Proxy + AutoRound) + +A deterministic Software-Defined Quantization pipeline targeting standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel/vLLM deployment. + +## Design +1. **Measurement:** Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity. +2. **Allocation:** Uses a Lagrangian solver to minimize a heuristic proxy error under a planner memory budget. + - `lm_head` can select `w4_g128`, `w8_g128`, or `skip`. + - All other linear layers are restricted to `w4_g128` or `w8_g128`. +3. **Quantization:** Pluggable Stage C. + - **GPTQModel:** Fast, greedy rounding. + - **AutoRound:** SOTA SignSGD continuous rounding, exported natively to standard `auto_gptq` format. +4. **Evaluation:** Direct programmatic usage of `lm-eval`. + +*Note: The memory budget is a planner estimate, not an exact runtime VRAM guarantee. The error model is a Hessian proxy, not a proven-optimal GPTQ blockwise inverse-Hessian simulation.* +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +torch +gptqmodel +auto-round +lm-eval +tqdm +EOF + +cat > quant_v5_2/__init__.py <<'EOF' +"""SDQ Quant v5.2: Hessian-Proxy Allocator + AutoRound""" +EOF + +cat > quant_v5_2/actions.py <<'EOF' +from dataclasses import dataclass +from typing import Dict, Optional + +@dataclass(frozen=True) +class ActionSpec: + name: str + bits: int + group_size: Optional[int] + skip: bool + +ACTION_SPECS: Dict[str, ActionSpec] = { + "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), + "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), + "skip": ActionSpec("skip", bits=16, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param(action: str) -> float: + """Planner memory estimate. Not an exact runtime VRAM model.""" + if action == "skip": + return 2.0 # FP16 + spec = ACTION_SPECS[action] + base = spec.bits / 8.0 + overhead = 2.5 / spec.group_size + return base + overhead + +def action_to_override(action: str) -> dict: + if action == "skip": + return {} + spec = ACTION_SPECS[action] + return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_v5_2/measurement.py <<'EOF' +import json +import torch +from pathlib import Path +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset + +def load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): + ds = load_dataset(ds_name, ds_config, split=split) + texts =[] + for row in ds: + if row.get(text_field) and row[text_field].strip(): + texts.append(row[text_field].strip()) + if len(texts) >= num_texts: break + return texts + +def measure_hessian_traces(args): + """ + Computes Trace(H) ≈ sum(mean(X^2, dim=0)) cleanly on cuda:0. + """ + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + args.model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ).to(device) + model.eval() + + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts) + + include_mods = [x.strip() for x in args.include_modules.split(',')] + targets =[(n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(i in n for i in include_mods)] + + hessian_traces = {name: 0.0 for name, _ in targets} + row_counts = {name: 0 for name, _ in targets} + params_count = {name: m.weight.numel() for name, m in targets} + + handles =[] + def make_hook(name): + def hook(_, inp, __): + x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) + hessian_traces[name] += (x ** 2).sum().item() + row_counts[name] += x.shape[0] + return hook + + for name, module in targets: + handles.append(module.register_forward_hook(make_hook(name))) + + with torch.no_grad(): + for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): + batch = tokenizer(texts[i:i+args.batch_size], return_tensors="pt", truncation=True, max_length=args.max_length, padding=True) + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + + for h in handles: h.remove() + + features =[] + for name, _ in targets: + avg_trace = hessian_traces[name] / max(1, row_counts[name]) + features.append({ + "name": name, + "params": params_count[name], + "hessian_trace": avg_trace + }) + + out_file = out_dir / "hessian_features.json" + out_file.write_text(json.dumps(features, indent=2)) + print(f"Saved {len(features)} layer traces to {out_file}") + + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + + return features +EOF + +cat > quant_v5_2/allocator.py <<'EOF' +import json +import re +from pathlib import Path +from .actions import bytes_per_param, ACTION_SPECS, BASE_ACTION + +def get_valid_actions(module_name: str) -> list[str]: + """Policy layer: restricts structural FP16 skips strictly to lm_head.""" + if "lm_head" in module_name: + return["w4_g128", "w8_g128", "skip"] + return["w4_g128", "w8_g128"] + +def solve_proxy_allocation(features_path: str, budget_mb_extra: float): + """ + Lagrangian multiplier solver to minimize heuristic proxy error. + """ + features = json.loads(Path(features_path).read_text()) + + base_bytes_per_param = bytes_per_param(BASE_ACTION) + total_base_mb = sum(f["params"] * base_bytes_per_param for f in features) / (1024**2) + target_mb = total_base_mb + budget_mb_extra + + def proxy_error(trace: float, action: str) -> float: + if action == "skip": return 0.0 + bits = ACTION_SPECS[action].bits + return trace * (2.0 ** (-2.0 * bits)) + + def action_cost_mb(params: int, action: str) -> float: + return (params * bytes_per_param(action)) / (1024**2) + + lambda_low, lambda_high = 0.0, 1e12 + best_plan = None + + for _ in range(80): + lambda_mid = (lambda_low + lambda_high) / 2.0 + current_mb = 0.0 + plan =[] + + for f in features: + best_action = None + best_cost = float('inf') + + valid_actions = get_valid_actions(f["name"]) + + for act in valid_actions: + err = proxy_error(f["hessian_trace"], act) + size_mb = action_cost_mb(f["params"], act) + + obj = err + lambda_mid * size_mb + if obj < best_cost: + best_cost = obj + best_action = act + + plan.append({ + "name": f["name"], + "action": best_action, + "params": f["params"], + "hessian_trace": f["hessian_trace"], + "size_mb": action_cost_mb(f["params"], best_action) + }) + current_mb += action_cost_mb(f["params"], best_action) + + best_plan = plan + if current_mb > target_mb: + lambda_low = lambda_mid + else: + lambda_high = lambda_mid + + used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb + + dynamic = {} + from .actions import action_to_override + for p in best_plan: + if p["action"] == BASE_ACTION: continue + regex = f"^{re.escape(p['name'])}$" + if p["action"] == "skip": + dynamic[f"-:{regex}"] = {} + else: + dynamic[f"+:{regex}"] = action_to_override(p["action"]) + + return {"plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb} +EOF + +cat > quant_v5_2/runner.py <<'EOF' +import json +from pathlib import Path +import torch +import lm_eval +from gptqmodel import GPTQModel, QuantizeConfig +from transformers import AutoModelForCausalLM, AutoTokenizer +from auto_round import AutoRound +from .measurement import load_calibration_texts + +def run_quantize_gptqmodel(args, dynamic_config): + """Standard greedy GPTQ rounding via GPTQModel.""" + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + qcfg = QuantizeConfig( + bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config + ) + + print(f"Loading Model for GPTQModel Quantization (greedy rounding)...") + model = GPTQModel.load(args.model_id, qcfg) + model.quantize(texts, batch_size=args.quant_batch_size) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + model.save(str(out)) + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_quantize_autoround(args, plan): + """SOTA SignSGD continuous rounding via AutoRound, exporting to standard auto_gptq format.""" + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + # Map SDQ allocation plan to AutoRound's layer_config syntax + layer_config = {} + if plan is not None: + for p in plan: + if p["action"] == "skip": + layer_config[p["name"]] = {"bits": 16} + elif p["action"] != "w4_g128": + from .actions import ACTION_SPECS + spec = ACTION_SPECS[p["action"]] + layer_config[p["name"]] = {"bits": spec.bits, "group_size": spec.group_size} + + print(f"Loading Model for AutoRound Quantization (SignSGD rounding)...") + model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + + autoround = AutoRound( + model, + tokenizer, + bits=4, + group_size=128, + sym=True, + dataset=texts, + seqlen=args.max_length, + batch_size=args.quant_batch_size, + layer_config=layer_config, + iters=args.autoround_iters + ) + + print("Executing AutoRound... (This will take significantly longer than standard GPTQ)") + autoround.quantize() + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + autoround.save_quantized(str(out), format="auto_gptq", inplace=False) + + del model + del autoround + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_eval(model_path: str, tasks: str, out_file: str): + print(f"Running programmatic lm-eval on {model_path}...") + Path(out_file).parent.mkdir(parents=True, exist_ok=True) + + task_list = [t.strip() for t in tasks.split(",") if t.strip()] + + results = lm_eval.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path},trust_remote_code=True", + tasks=task_list, + batch_size="auto", + device="cuda:0" + ) + + serializable_results = { + "results": results.get("results", {}), + "versions": results.get("versions", {}) + } + + Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8") + return serializable_results +EOF + +cat > quant_v5_2/pipeline.py <<'EOF' +import argparse +import json +from pathlib import Path +from .measurement import measure_hessian_traces +from .allocator import solve_proxy_allocation +from .runner import run_quantize_gptqmodel, run_quantize_autoround, run_eval + +def main(): + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + suite = sub.add_parser("suite") + suite.add_argument("--model-id", required=True) + suite.add_argument("--work-dir", required=True) + suite.add_argument("--budget-mb-extra", type=float, default=250.0) + suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") + suite.add_argument("--dataset-name", default="wikitext") + suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") + suite.add_argument("--split", default="train") + suite.add_argument("--text-field", default="text") + suite.add_argument("--num-texts", type=int, default=128) + suite.add_argument("--max-length", type=int, default=2048) + suite.add_argument("--batch-size", type=int, default=4) + suite.add_argument("--quant-num-texts", type=int, default=128) + suite.add_argument("--quant-batch-size", type=int, default=4) + suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD (default 200)") + suite.add_argument("--tasks", default="arc_challenge,hellaswag") + args = p.parse_args() + + work_dir = Path(args.work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + + print("\n--- STAGE 1: Hessian Proxy Measurement ---") + args.out_dir = str(work_dir / "probe") + features = measure_hessian_traces(args) + + print("\n--- STAGE 2: Lagrangian Proxy Allocation ---") + alloc = solve_proxy_allocation(str(work_dir / "probe" / "hessian_features.json"), args.budget_mb_extra) + (work_dir / "sdq_plan.json").write_text(json.dumps(alloc["plan"], indent=2)) + (work_dir / "sdq_dynamic.json").write_text(json.dumps(alloc["dynamic"], indent=2)) + print(f"Allocated! Estimated extra VRAM footprint: {alloc['used_extra_mb']:.2f} MB.") + + print("\n--- STAGE 3A: Quantize Baseline (Pure 4-bit, GPTQModel) ---") + args.out_dir = str(work_dir / "models" / "baseline_gptq") + run_quantize_gptqmodel(args, dynamic_config=None) + + print("\n--- STAGE 3B: Quantize SDQ (Proxy Allocation, GPTQModel) ---") + args.out_dir = str(work_dir / "models" / "sdq_gptq") + run_quantize_gptqmodel(args, dynamic_config=alloc["dynamic"]) + + print("\n--- STAGE 3C: Quantize SDQ + AutoRound (Proxy Allocation, SignSGD AutoRound) ---") + args.out_dir = str(work_dir / "models" / "sdq_autoround") + run_quantize_autoround(args, plan=alloc["plan"]) + + print("\n--- STAGE 4: Evaluation Matrix ---") + run_eval(str(work_dir / "models" / "baseline_gptq"), args.tasks, str(work_dir / "eval" / "baseline_gptq.json")) + run_eval(str(work_dir / "models" / "sdq_gptq"), args.tasks, str(work_dir / "eval" / "sdq_gptq.json")) + run_eval(str(work_dir / "models" / "sdq_autoround"), args.tasks, str(work_dir / "eval" / "sdq_autoround.json")) + + print("\n--- Pipeline Complete! See eval/ folder for full matrix results. ---") + +if __name__ == "__main__": + main() +EOF +``` + +### Running the Suite Matrix +You can now directly test if the allocator logic is winning, if AutoRound's continuous rounding is winning, or (most likely) if combining both nets you SOTA accuracy under a strict boundary footprint. + +``` +bash +python -m venv .venv +source .venv/bin/activate +pip install -r requirements.txt + +# Run the 3-way matrix. Note: AutoRound is extremely heavy computationally. +# Consider dropping --autoround-iters to 50 for a faster integration test, +# though 200 is recommended for SOTA fidelity. + +python -m quant_v5_2.pipeline suite \ + --model-id Qwen/Qwen3.5-4B-Base \ + --work-dir ./runs/v5_2_suite \ + --budget-mb-extra 200 \ + --autoround-iters 50 \ + --tasks arc_challenge,hellaswag +``` + +--- + +### ChatGPT: +v5.2 is **the right shape**. The `lm_head` baseline bug is fixed, the README wording is much cleaner, and splitting the experiment into: + +- `baseline_gptq` +- `sdq_gptq` +- `sdq_autoround` + +is exactly the right matrix. + +I’d patch **4 things** before spending hours on long runs. + +### 1) AutoRound mixed-bit auto_gptq is the main risk + +This is the biggest one. + +AutoRound’s current docs explicitly say: + +- `auto_gptq` supports `[2,3,4,8]` bits, +- but **Auto-GPTQ only supports a limited set of mixed-bit configurations**, and +- **if you’re unsure, use the auto_round format** instead. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +So your `sdq_autoround` arm is conceptually valid, but the claim that it will cleanly export every arbitrary per-layer 4/8/skip recipe to `auto_gptq` is too strong. That export path exists, but mixed-bit support is constrained. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +What I’d do: + +- keep `sdq_autoround` as an experiment, +- but treat `format="auto_gptq"` as **best effort**, +- and keep a fallback to `format="auto_round"` if export or downstream loading gets weird. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +### 2) Exclude QKV from mixed-bit in the AutoRound arm + +AutoRound’s docs explicitly warn that **vLLM and SGLang fuse MoE and QKV layers**, so it is recommended **not** to assign different bit widths to those layers. Your current allocation set includes `q_proj`, `k_proj`, and `v_proj`, which is fine for measurement, but risky for the AutoRound mixed-bit export path. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +For `sdq_autoround`, I would: + +- either remove `q_proj,k_proj,v_proj` from the allocatable set, +- or force them to move together as one tied policy. + +A clean rule is: + +- AutoRound mixed bits only on `o_proj`, `down_proj`, `up_proj`, `gate_proj`, `lm_head` +- leave QKV uniform. + +### 3) Use the API AutoRound recommends + +The docs’ recommended Python flow is: + +```Pythonfrom auto_round import AutoRoundar = AutoRound(model_name_or_path, scheme="W4A16")ar.quantize_and_save(output_dir, format="auto_gptq") +Pythonfrom auto_round import AutoRoundar = AutoRound(model_name_or_path, scheme="W4A16")ar.quantize_and_save(output_dir, format="auto_gptq") +``` + +not the lower-level `quantize()` then `save_quantized()` sequence you wrote. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +So I’d rewrite that arm to use `quantize_and_save()` unless you have a specific reason not to. + +### 4) Keep the vLLM compatibility claim narrow + +Your new wording is much better, but I’d still keep it precise: + +- vLLM’s GPTQ path documents support for **2/3/4/8-bit** GPTQ weights. [vLLM+1](https://docs.vllm.ai/en/stable/api/vllm/model_executor/layers/quantization/gptq/?utm_source=chatgpt.com) +- vLLM’s GPTQModel docs explicitly describe creating **4-bit or 8-bit** GPTQ quantized models with GPTQModel. [vLLM](https://docs.vllm.ai/en/latest/features/quantization/gptqmodel/?utm_source=chatgpt.com) +- GPTQModel itself supports per-module dynamic mixed quantization. [GitHub](https://github.com/ModelCloud/GPTQModel?utm_source=chatgpt.com) + +That supports your `sdq_gptq` arm very well. The weaker point is the AutoRound mixed-bit `auto_gptq` export path, not the GPTQModel arm. + +## My verdict + +**Ship this as v5.2-beta, not final.** + +What’s strong: + +- allocator design +- action-space discipline +- `lm_head` policy fix +- proper lm-eval API usage +- correct experiment matrix + +What still needs hardening: + +- AutoRound mixed-bit export caveat +- QKV fusion constraint for the AutoRound arm +- switch to `quantize_and_save()` + +If you want the cleanest next move, I’d make **v5.2.1** with just these two changes: + +1. AutoRound arm only mixes `o_proj/down_proj/up_proj/gate_proj/lm_head` +2. AutoRound arm uses `quantize_and_save(..., format="auto_gptq")` + +Then your 3-way benchmark will be much more trustworthy. + +--- + +### User: +You are spot on. The QKV fusion constraint is a massive operational detail for vLLM and SGLang. If `q_proj`, `k_proj`, and `v_proj` get assigned different bit-widths by the solver, the backend cannot fuse them into a single `qkv_proj` kernel, completely tanking inference performance (or outright crashing the model loader). + +Locking QKV to a uniform base bitrate is the only mathematically safe way to handle this without writing a custom tied-variable allocator. + +Here is **SDQ Quant v5.2.1**, which addresses all four points: +1. **QKV Fusion Lock:** The allocator's `get_valid_actions()` now permanently locks `q_proj`, `k_proj`, and `v_proj` to the base action (`w4_g128`). They are still measured and accounted for in the VRAM budget, but the solver cannot break their uniformity. +2. **Narrowed Claims:** The README has been adjusted to accurately reflect vLLM/GPTQModel 2/3/4/8-bit support limits, heavily caveating the AutoRound `auto_gptq` export path. +3. **AutoRound Export Best-Effort:** Explicitly documents the `auto_gptq` mixed-bit constraint and relies on the cleanest documented AutoRound save API. + +### Run this to generate v5.2.1: + +``` +bash +mkdir -p sdq_quant_v5_2_1/quant_v5_2_1 +cd sdq_quant_v5_2_1 + +cat > README.md <<'EOF' +# SDQ Quant v5.2.1 (Hessian-Proxy + AutoRound) + +A deterministic Software-Defined Quantization pipeline targeting standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel/vLLM deployment. + +## Design +1. **Measurement:** Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity. +2. **Allocation:** Uses a Lagrangian solver to minimize a heuristic proxy error under a planner memory budget. + - `q_proj`, `k_proj`, `v_proj` are **locked to 4-bit** to preserve vLLM/SGLang QKV fusion. + - `lm_head` can select `w4_g128`, `w8_g128`, or `skip`. + - All other linear layers are restricted to `w4_g128` or `w8_g128`. +3. **Quantization:** Pluggable Stage C. + - **GPTQModel:** Fast, greedy rounding. (Fully supports dynamic per-module mixed quantization). + - **AutoRound:** SOTA SignSGD continuous rounding. Exports to `auto_gptq`. *(Note: Auto-GPTQ has limited support for mixed-bit configurations; this export is best-effort and may require fallback to `auto_round` format if downstream loaders complain).* +4. **Evaluation:** Direct programmatic usage of `lm-eval`. + +*Note: The memory budget is a planner estimate, not an exact runtime VRAM guarantee. The error model is a Hessian proxy, not a proven-optimal GPTQ blockwise inverse-Hessian simulation.* +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +torch +gptqmodel +auto-round +lm-eval +tqdm +EOF + +cat > quant_v5_2_1/__init__.py <<'EOF' +"""SDQ Quant v5.2.1: Hessian-Proxy Allocator + AutoRound + QKV Lock""" +EOF + +cat > quant_v5_2_1/actions.py <<'EOF' +from dataclasses import dataclass +from typing import Dict, Optional + +@dataclass(frozen=True) +class ActionSpec: + name: str + bits: int + group_size: Optional[int] + skip: bool + +ACTION_SPECS: Dict[str, ActionSpec] = { + "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), + "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), + "skip": ActionSpec("skip", bits=16, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param(action: str) -> float: + """Planner memory estimate. Not an exact runtime VRAM model.""" + if action == "skip": + return 2.0 # FP16 + spec = ACTION_SPECS[action] + base = spec.bits / 8.0 + overhead = 2.5 / spec.group_size + return base + overhead + +def action_to_override(action: str) -> dict: + if action == "skip": + return {} + spec = ACTION_SPECS[action] + return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_v5_2_1/measurement.py <<'EOF' +import json +import torch +from pathlib import Path +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset + +def load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): + ds = load_dataset(ds_name, ds_config, split=split) + texts =[] + for row in ds: + if row.get(text_field) and row[text_field].strip(): + texts.append(row[text_field].strip()) + if len(texts) >= num_texts: break + return texts + +def measure_hessian_traces(args): + """ + Computes Trace(H) ≈ sum(mean(X^2, dim=0)) cleanly on cuda:0. + """ + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + args.model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ).to(device) + model.eval() + + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts) + + include_mods =[x.strip() for x in args.include_modules.split(',')] + targets =[(n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(i in n for i in include_mods)] + + hessian_traces = {name: 0.0 for name, _ in targets} + row_counts = {name: 0 for name, _ in targets} + params_count = {name: m.weight.numel() for name, m in targets} + + handles =[] + def make_hook(name): + def hook(_, inp, __): + x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) + hessian_traces[name] += (x ** 2).sum().item() + row_counts[name] += x.shape[0] + return hook + + for name, module in targets: + handles.append(module.register_forward_hook(make_hook(name))) + + with torch.no_grad(): + for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): + batch = tokenizer(texts[i:i+args.batch_size], return_tensors="pt", truncation=True, max_length=args.max_length, padding=True) + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + + for h in handles: h.remove() + + features =[] + for name, _ in targets: + avg_trace = hessian_traces[name] / max(1, row_counts[name]) + features.append({ + "name": name, + "params": params_count[name], + "hessian_trace": avg_trace + }) + + out_file = out_dir / "hessian_features.json" + out_file.write_text(json.dumps(features, indent=2)) + print(f"Saved {len(features)} layer traces to {out_file}") + + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + + return features +EOF + +cat > quant_v5_2_1/allocator.py <<'EOF' +import json +import re +from pathlib import Path +from .actions import bytes_per_param, ACTION_SPECS, BASE_ACTION + +def get_valid_actions(module_name: str) -> list[str]: + """Policy layer: enforces QKV fusion and bounds skip to lm_head.""" + if any(k in module_name for k in ["q_proj", "k_proj", "v_proj"]): + # Lock QKV to the base action to preserve vLLM/SGLang kernel fusion + return [BASE_ACTION] + if "lm_head" in module_name: + return["w4_g128", "w8_g128", "skip"] + return["w4_g128", "w8_g128"] + +def solve_proxy_allocation(features_path: str, budget_mb_extra: float): + """ + Lagrangian multiplier solver to minimize heuristic proxy error. + """ + features = json.loads(Path(features_path).read_text()) + + base_bytes_per_param = bytes_per_param(BASE_ACTION) + total_base_mb = sum(f["params"] * base_bytes_per_param for f in features) / (1024**2) + target_mb = total_base_mb + budget_mb_extra + + def proxy_error(trace: float, action: str) -> float: + if action == "skip": return 0.0 + bits = ACTION_SPECS[action].bits + return trace * (2.0 ** (-2.0 * bits)) + + def action_cost_mb(params: int, action: str) -> float: + return (params * bytes_per_param(action)) / (1024**2) + + lambda_low, lambda_high = 0.0, 1e12 + best_plan = None + + for _ in range(80): + lambda_mid = (lambda_low + lambda_high) / 2.0 + current_mb = 0.0 + plan =[] + + for f in features: + best_action = None + best_cost = float('inf') + + valid_actions = get_valid_actions(f["name"]) + + for act in valid_actions: + err = proxy_error(f["hessian_trace"], act) + size_mb = action_cost_mb(f["params"], act) + + obj = err + lambda_mid * size_mb + if obj < best_cost: + best_cost = obj + best_action = act + + plan.append({ + "name": f["name"], + "action": best_action, + "params": f["params"], + "hessian_trace": f["hessian_trace"], + "size_mb": action_cost_mb(f["params"], best_action) + }) + current_mb += action_cost_mb(f["params"], best_action) + + best_plan = plan + if current_mb > target_mb: + lambda_low = lambda_mid + else: + lambda_high = lambda_mid + + used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb + + dynamic = {} + from .actions import action_to_override + for p in best_plan: + if p["action"] == BASE_ACTION: continue + regex = f"^{re.escape(p['name'])}$" + if p["action"] == "skip": + dynamic[f"-:{regex}"] = {} + else: + dynamic[f"+:{regex}"] = action_to_override(p["action"]) + + return {"plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb} +EOF + +cat > quant_v5_2_1/runner.py <<'EOF' +import json +from pathlib import Path +import torch +import lm_eval +from gptqmodel import GPTQModel, QuantizeConfig +from transformers import AutoModelForCausalLM, AutoTokenizer +from auto_round import AutoRound +from .measurement import load_calibration_texts + +def run_quantize_gptqmodel(args, dynamic_config): + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + qcfg = QuantizeConfig( + bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config + ) + + print(f"Loading Model for GPTQModel Quantization (greedy rounding)...") + model = GPTQModel.load(args.model_id, qcfg) + model.quantize(texts, batch_size=args.quant_batch_size) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + model.save(str(out)) + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_quantize_autoround(args, plan): + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + # Map SDQ allocation plan to AutoRound's layer_config syntax + layer_config = {} + if plan is not None: + for p in plan: + if p["action"] == "skip": + layer_config[p["name"]] = {"bits": 16} + elif p["action"] != "w4_g128": + from .actions import ACTION_SPECS + spec = ACTION_SPECS[p["action"]] + layer_config[p["name"]] = {"bits": spec.bits, "group_size": spec.group_size} + + print(f"Loading Model for AutoRound Quantization (SignSGD rounding)...") + model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + + autoround = AutoRound( + model, + tokenizer, + bits=4, + group_size=128, + sym=True, + dataset=texts, + seqlen=args.max_length, + batch_size=args.quant_batch_size, + layer_config=layer_config, + iters=args.autoround_iters + ) + + print("Executing AutoRound... (This will take significantly longer than standard GPTQ)") + autoround.quantize() + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + + # auto_gptq export for mixed bits is best-effort. If this fails downstream, + # the fallback is changing format="auto_round". + autoround.save_quantized(str(out), format="auto_gptq", inplace=False) + + del model + del autoround + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_eval(model_path: str, tasks: str, out_file: str): + print(f"Running programmatic lm-eval on {model_path}...") + Path(out_file).parent.mkdir(parents=True, exist_ok=True) + + task_list =[t.strip() for t in tasks.split(",") if t.strip()] + + results = lm_eval.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path},trust_remote_code=True", + tasks=task_list, + batch_size="auto", + device="cuda:0" + ) + + serializable_results = { + "results": results.get("results", {}), + "versions": results.get("versions", {}) + } + + Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8") + return serializable_results +EOF + +cat > quant_v5_2_1/pipeline.py <<'EOF' +import argparse +import json +from pathlib import Path +from .measurement import measure_hessian_traces +from .allocator import solve_proxy_allocation +from .runner import run_quantize_gptqmodel, run_quantize_autoround, run_eval + +def main(): + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + suite = sub.add_parser("suite") + suite.add_argument("--model-id", required=True) + suite.add_argument("--work-dir", required=True) + suite.add_argument("--budget-mb-extra", type=float, default=250.0) + suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") + suite.add_argument("--dataset-name", default="wikitext") + suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") + suite.add_argument("--split", default="train") + suite.add_argument("--text-field", default="text") + suite.add_argument("--num-texts", type=int, default=128) + suite.add_argument("--max-length", type=int, default=2048) + suite.add_argument("--batch-size", type=int, default=4) + suite.add_argument("--quant-num-texts", type=int, default=128) + suite.add_argument("--quant-batch-size", type=int, default=4) + suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD") + suite.add_argument("--tasks", default="arc_challenge,hellaswag") + args = p.parse_args() + + work_dir = Path(args.work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + + print("\n--- STAGE 1: Hessian Proxy Measurement ---") + args.out_dir = str(work_dir / "probe") + features = measure_hessian_traces(args) + + print("\n--- STAGE 2: Lagrangian Proxy Allocation ---") + alloc = solve_proxy_allocation(str(work_dir / "probe" / "hessian_features.json"), args.budget_mb_extra) + (work_dir / "sdq_plan.json").write_text(json.dumps(alloc["plan"], indent=2)) + (work_dir / "sdq_dynamic.json").write_text(json.dumps(alloc["dynamic"], indent=2)) + print(f"Allocated! Estimated extra VRAM footprint: {alloc['used_extra_mb']:.2f} MB.") + + print("\n--- STAGE 3A: Quantize Baseline (Pure 4-bit, GPTQModel) ---") + args.out_dir = str(work_dir / "models" / "baseline_gptq") + run_quantize_gptqmodel(args, dynamic_config=None) + + print("\n--- STAGE 3B: Quantize SDQ (Proxy Allocation, GPTQModel) ---") + args.out_dir = str(work_dir / "models" / "sdq_gptq") + run_quantize_gptqmodel(args, dynamic_config=alloc["dynamic"]) + + print("\n--- STAGE 3C: Quantize SDQ + AutoRound (Proxy Allocation, SignSGD AutoRound) ---") + args.out_dir = str(work_dir / "models" / "sdq_autoround") + run_quantize_autoround(args, plan=alloc["plan"]) + + print("\n--- STAGE 4: Evaluation Matrix ---") + run_eval(str(work_dir / "models" / "baseline_gptq"), args.tasks, str(work_dir / "eval" / "baseline_gptq.json")) + run_eval(str(work_dir / "models" / "sdq_gptq"), args.tasks, str(work_dir / "eval" / "sdq_gptq.json")) + run_eval(str(work_dir / "models" / "sdq_autoround"), args.tasks, str(work_dir / "eval" / "sdq_autoround.json")) + + print("\n--- Pipeline Complete! See eval/ folder for full matrix results. ---") + +if __name__ == "__main__": + main() +EOF +``` + +This matrix guarantees an extremely rigorous evaluation of where the quality gains are coming from: +* `baseline_gptq` -> The floor. +* `sdq_gptq` -> Quality gained purely from the Hessian-proxy layer budget allocator. +* `sdq_autoround` -> Quality gained by replacing greedy layer-wise quantization with continuous SignSGD layer-wise quantization on top of the budget map. + +--- + +### ChatGPT: +v5.2.1 is the first version I’d call **operationally sane**. + +The QKV lock is the right fix. AutoRound’s docs explicitly warn that vLLM and SGLang fuse QKV layers and recommend not assigning different bit widths to them, and they also expose a `shared_layers` mechanism for tied configurations if you later want something less conservative than hard-locking them to 4-bit. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +Two important fixes are still left in the **AutoRound** arm: + +### 1) lm_head quantization is not safe with your current auto_gptq export assumption + +AutoRound’s docs say **quantized lm-head is currently only supported in AutoRound format inference**. In other words, if your allocator sets `lm_head` to `w8_g128` and you export `auto_gptq`, that path is not something the docs currently bless. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +So for `sdq_autoround`, you should do one of these: + +- **force lm_head -> skip** whenever `format="auto_gptq"`, or +- if `lm_head` is quantized, export **format="auto_round"** instead. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +That is the biggest remaining blocker. + +### 2) Use quantize_and_save() instead of quantize() + save_quantized() + +The documented API path is: + +```Pythonfrom auto_round import AutoRoundar = AutoRound(...)ar.quantize_and_save(output_dir, format="auto_gptq,auto_awq,auto_round") +Pythonfrom auto_round import AutoRoundar = AutoRound(...)ar.quantize_and_save(output_dir, format="auto_gptq,auto_awq,auto_round") +``` + +That is the clean supported path in the current docs. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +So I’d change this: + +```Pythonautoround.quantize()autoround.save_quantized(str(out), format="auto_gptq", inplace=False) +Pythonautoround.quantize()autoround.save_quantized(str(out), format="auto_gptq", inplace=False) +``` + +to this: + +```Pythonautoround.quantize_and_save(str(out), format="auto_gptq") +Pythonautoround.quantize_and_save(str(out), format="auto_gptq") +``` + +or `format="auto_round"` when `lm_head` is quantized. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +### 3) Your mixed-bit caveat is now correct, but it should be enforced in code + +The docs are clear that **Auto-GPTQ only supports a limited set of mixed-bit configurations**, and if you are unsure, they recommend the `auto_round` format instead. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +So I’d make the export logic explicit: + +- `sdq_autoround_auto_gptq`: + +no quantized `lm_head` + +QKV locked + +best-effort mixed 4/8 elsewhere +- no quantized `lm_head` +- QKV locked +- best-effort mixed 4/8 elsewhere +- `sdq_autoround_auto_round`: + +allowed to quantize `lm_head` + +fallback when the mixed-bit recipe is outside Auto-GPTQ’s comfort zone +- allowed to quantize `lm_head` +- fallback when the mixed-bit recipe is outside Auto-GPTQ’s comfort zone + +### 4) One optional upgrade: tied QKV instead of hard lock + +Your hard lock is safe. But AutoRound’s docs also show `shared_layers`, which is the cleaner long-term mechanism if you later want QKV to move together instead of staying fixed forever. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +## My verdict + +For the **GPTQModel** path: + +- `baseline_gptq` +- `sdq_gptq` + +this is good enough to run now. GPTQModel supports dynamic per-module quantization, and vLLM’s GPTQ docs support standard 2/3/4/8-bit GPTQ weights, with GPTQModel docs explicitly calling out 4-bit and 8-bit quantized models. [GitHub+2vLLM+2](https://github.com/ModelCloud/GPTQModel/blob/main/README.md?utm_source=chatgpt.com) + +For the **AutoRound** path: + +- it is **almost there** +- but I would **not** run `sdq_autoround` as written until you fix the `lm_head` export condition and switch to `quantize_and_save()`. [GitHub+1](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +So the clean next revision is: + +- keep QKV locked, +- force `lm_head=skip` for `auto_gptq`, +- add a second AutoRound branch using `format="auto_round"` for cases where `lm_head` is quantized, +- switch to `quantize_and_save()`. + +That would make the matrix genuinely trustworthy. + +--- + +### User: +You are completely right. The AutoRound documentation explicitly states that a quantized `lm_head` is not supported for `auto_gptq` export, and the modern unified API is `ar.quantize_and_save()`. Ignoring these constraints risks generating corrupted checkpoints or silent loader failures down the line. + +To handle this cleanly, the allocator needs to be aware of the *target export format*. If the target is `auto_gptq`, the policy must force `lm_head` to `skip` (FP16). If the target is `auto_round` (or `GPTQModel`), the policy can safely allow `lm_head` to be quantized to 4 or 8 bits. + +Here is **SDQ Quant v5.3**, which solidifies the entire pipeline. It splits the allocator into two policy modes (`strict` for `auto_gptq` and `flex` for `auto_round`/`GPTQModel`) and tests a complete, operationally safe 4-way matrix. + +### Run this to generate v5.3: + +``` +bash +mkdir -p sdq_quant_v5_3/quant_v5_3 +cd sdq_quant_v5_3 + +cat > README.md <<'EOF' +# SDQ Quant v5.3 (Hessian-Proxy + AutoRound + Format Policies) + +A deterministic Software-Defined Quantization pipeline targeting standard GPTQ/AutoRound deployments. + +## Design +1. **Measurement:** Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity. +2. **Allocation:** Uses a Lagrangian solver. + - `q_proj`, `k_proj`, `v_proj` are **locked to 4-bit** to preserve vLLM/SGLang QKV fusion. + - For `auto_gptq` export, `lm_head` is **locked to skip (FP16)** per format constraints. + - For `auto_round` / `GPTQModel` exports, `lm_head` dynamically selects 4/8/skip. +3. **Quantization Matrix:** + - **Baseline GPTQ:** Greedy 4-bit baseline. + - **SDQ GPTQModel:** Greedy SDQ allocation (flex `lm_head`). + - **SDQ AutoRound (auto_gptq):** SignSGD SDQ allocation (strict `lm_head` skip). + - **SDQ AutoRound (auto_round):** SignSGD SDQ allocation (flex `lm_head`). +4. **Evaluation:** Direct programmatic usage of `lm-eval`. +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +torch +gptqmodel +auto-round +lm-eval +tqdm +EOF + +cat > quant_v5_3/__init__.py <<'EOF' +"""SDQ Quant v5.3: Hessian-Proxy Allocator + AutoRound Format Policies""" +EOF + +cat > quant_v5_3/actions.py <<'EOF' +from dataclasses import dataclass +from typing import Dict, Optional + +@dataclass(frozen=True) +class ActionSpec: + name: str + bits: int + group_size: Optional[int] + skip: bool + +ACTION_SPECS: Dict[str, ActionSpec] = { + "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), + "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), + "skip": ActionSpec("skip", bits=16, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param(action: str) -> float: + if action == "skip": + return 2.0 # FP16 + spec = ACTION_SPECS[action] + base = spec.bits / 8.0 + overhead = 2.5 / spec.group_size + return base + overhead + +def action_to_override(action: str) -> dict: + if action == "skip": + return {} + spec = ACTION_SPECS[action] + return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_v5_3/measurement.py <<'EOF' +import json +import torch +from pathlib import Path +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer +from datasets import load_dataset + +def load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): + ds = load_dataset(ds_name, ds_config, split=split) + texts =[] + for row in ds: + if row.get(text_field) and row[text_field].strip(): + texts.append(row[text_field].strip()) + if len(texts) >= num_texts: break + return texts + +def measure_hessian_traces(args): + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + args.model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True + ).to(device) + model.eval() + + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts) + + include_mods =[x.strip() for x in args.include_modules.split(',')] + targets =[(n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(i in n for i in include_mods)] + + hessian_traces = {name: 0.0 for name, _ in targets} + row_counts = {name: 0 for name, _ in targets} + params_count = {name: m.weight.numel() for name, m in targets} + + handles =[] + def make_hook(name): + def hook(_, inp, __): + x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) + hessian_traces[name] += (x ** 2).sum().item() + row_counts[name] += x.shape[0] + return hook + + for name, module in targets: + handles.append(module.register_forward_hook(make_hook(name))) + + with torch.no_grad(): + for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): + batch = tokenizer(texts[i:i+args.batch_size], return_tensors="pt", truncation=True, max_length=args.max_length, padding=True) + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + + for h in handles: h.remove() + + features =[] + for name, _ in targets: + avg_trace = hessian_traces[name] / max(1, row_counts[name]) + features.append({ + "name": name, + "params": params_count[name], + "hessian_trace": avg_trace + }) + + out_file = out_dir / "hessian_features.json" + out_file.write_text(json.dumps(features, indent=2)) + print(f"Saved {len(features)} layer traces to {out_file}") + + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + + return features +EOF + +cat > quant_v5_3/allocator.py <<'EOF' +import json +import re +from pathlib import Path +from .actions import bytes_per_param, ACTION_SPECS, BASE_ACTION + +def get_valid_actions(module_name: str, allow_quantized_lm_head: bool) -> list[str]: + """Policy layer: enforces QKV fusion and handles format-specific lm_head constraints.""" + if any(k in module_name for k in ["q_proj", "k_proj", "v_proj"]): + return[BASE_ACTION] + if "lm_head" in module_name: + return["w4_g128", "w8_g128", "skip"] if allow_quantized_lm_head else ["skip"] + return["w4_g128", "w8_g128"] + +def solve_proxy_allocation(features_path: str, budget_mb_extra: float, allow_quantized_lm_head: bool): + """Lagrangian multiplier solver to minimize heuristic proxy error.""" + features = json.loads(Path(features_path).read_text()) + + base_bytes_per_param = bytes_per_param(BASE_ACTION) + total_base_mb = sum(f["params"] * base_bytes_per_param for f in features) / (1024**2) + target_mb = total_base_mb + budget_mb_extra + + def proxy_error(trace: float, action: str) -> float: + if action == "skip": return 0.0 + bits = ACTION_SPECS[action].bits + return trace * (2.0 ** (-2.0 * bits)) + + def action_cost_mb(params: int, action: str) -> float: + return (params * bytes_per_param(action)) / (1024**2) + + lambda_low, lambda_high = 0.0, 1e12 + best_plan = None + + for _ in range(80): + lambda_mid = (lambda_low + lambda_high) / 2.0 + current_mb = 0.0 + plan =[] + + for f in features: + best_action = None + best_cost = float('inf') + + valid_actions = get_valid_actions(f["name"], allow_quantized_lm_head) + + for act in valid_actions: + err = proxy_error(f["hessian_trace"], act) + size_mb = action_cost_mb(f["params"], act) + + obj = err + lambda_mid * size_mb + if obj < best_cost: + best_cost = obj + best_action = act + + plan.append({ + "name": f["name"], + "action": best_action, + "params": f["params"], + "hessian_trace": f["hessian_trace"], + "size_mb": action_cost_mb(f["params"], best_action) + }) + current_mb += action_cost_mb(f["params"], best_action) + + best_plan = plan + if current_mb > target_mb: + lambda_low = lambda_mid + else: + lambda_high = lambda_mid + + used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb + + dynamic = {} + from .actions import action_to_override + for p in best_plan: + if p["action"] == BASE_ACTION: continue + regex = f"^{re.escape(p['name'])}$" + if p["action"] == "skip": + dynamic[f"-:{regex}"] = {} + else: + dynamic[f"+:{regex}"] = action_to_override(p["action"]) + + return {"plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb} +EOF + +cat > quant_v5_3/runner.py <<'EOF' +import json +from pathlib import Path +import torch +import lm_eval +from gptqmodel import GPTQModel, QuantizeConfig +from transformers import AutoModelForCausalLM, AutoTokenizer +from auto_round import AutoRound +from .measurement import load_calibration_texts + +def run_quantize_gptqmodel(args, dynamic_config): + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + qcfg = QuantizeConfig( + bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config + ) + + print(f"Loading Model for GPTQModel Quantization (greedy rounding)...") + model = GPTQModel.load(args.model_id, qcfg) + model.quantize(texts, batch_size=args.quant_batch_size) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + model.save(str(out)) + del model + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_quantize_autoround(args, plan, export_format): + texts = load_calibration_texts(args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts) + + layer_config = {} + if plan is not None: + for p in plan: + if p["action"] == "skip": + layer_config[p["name"]] = {"bits": 16} + elif p["action"] != "w4_g128": + from .actions import ACTION_SPECS + spec = ACTION_SPECS[p["action"]] + layer_config[p["name"]] = {"bits": spec.bits, "group_size": spec.group_size} + + print(f"Loading Model for AutoRound Quantization ({export_format} export)...") + model = AutoModelForCausalLM.from_pretrained(args.model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto") + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + + autoround = AutoRound( + model, + tokenizer, + bits=4, + group_size=128, + sym=True, + dataset=texts, + seqlen=args.max_length, + batch_size=args.quant_batch_size, + layer_config=layer_config, + iters=args.autoround_iters + ) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + + print(f"Executing AutoRound and saving via quantize_and_save(format='{export_format}')...") + autoround.quantize_and_save(str(out), format=export_format) + + del model + del autoround + if torch.cuda.is_available(): torch.cuda.empty_cache() + +def run_eval(model_path: str, tasks: str, out_file: str): + print(f"Running programmatic lm-eval on {model_path}...") + Path(out_file).parent.mkdir(parents=True, exist_ok=True) + + task_list =[t.strip() for t in tasks.split(",") if t.strip()] + + results = lm_eval.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path},trust_remote_code=True", + tasks=task_list, + batch_size="auto", + device="cuda:0" + ) + + serializable_results = { + "results": results.get("results", {}), + "versions": results.get("versions", {}) + } + + Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8") + return serializable_results +EOF + +cat > quant_v5_3/pipeline.py <<'EOF' +import argparse +import json +from pathlib import Path +from .measurement import measure_hessian_traces +from .allocator import solve_proxy_allocation +from .runner import run_quantize_gptqmodel, run_quantize_autoround, run_eval + +def main(): + p = argparse.ArgumentParser() + sub = p.add_subparsers(dest="cmd", required=True) + + suite = sub.add_parser("suite") + suite.add_argument("--model-id", required=True) + suite.add_argument("--work-dir", required=True) + suite.add_argument("--budget-mb-extra", type=float, default=250.0) + suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") + suite.add_argument("--dataset-name", default="wikitext") + suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") + suite.add_argument("--split", default="train") + suite.add_argument("--text-field", default="text") + suite.add_argument("--num-texts", type=int, default=128) + suite.add_argument("--max-length", type=int, default=2048) + suite.add_argument("--batch-size", type=int, default=4) + suite.add_argument("--quant-num-texts", type=int, default=128) + suite.add_argument("--quant-batch-size", type=int, default=4) + suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD") + suite.add_argument("--tasks", default="arc_challenge,hellaswag") + args = p.parse_args() + + work_dir = Path(args.work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + + print("\n--- STAGE 1: Hessian Proxy Measurement ---") + args.out_dir = str(work_dir / "probe") + features = measure_hessian_traces(args) + + print("\n--- STAGE 2: Lagrangian Proxy Allocation ---") + + # Plan A: Flex lm_head (for GPTQModel and AutoRound native formats) + alloc_flex = solve_proxy_allocation( + str(work_dir / "probe" / "hessian_features.json"), args.budget_mb_extra, allow_quantized_lm_head=True + ) + (work_dir / "sdq_plan_flex.json").write_text(json.dumps(alloc_flex["plan"], indent=2)) + + # Plan B: Strict lm_head skip (for auto_gptq export) + alloc_strict = solve_proxy_allocation( + str(work_dir / "probe" / "hessian_features.json"), args.budget_mb_extra, allow_quantized_lm_head=False + ) + (work_dir / "sdq_plan_strict.json").write_text(json.dumps(alloc_strict["plan"], indent=2)) + + print(f"Allocated! Estimated extra VRAM footprint: Flex: {alloc_flex['used_extra_mb']:.2f} MB | Strict: {alloc_strict['used_extra_mb']:.2f} MB") + + print("\n--- STAGE 3A: Quantize Baseline (Pure 4-bit, GPTQModel) ---") + args.out_dir = str(work_dir / "models" / "baseline_gptq") + run_quantize_gptqmodel(args, dynamic_config=None) + + print("\n--- STAGE 3B: Quantize SDQ (Proxy Allocation, GPTQModel) ---") + args.out_dir = str(work_dir / "models" / "sdq_gptq") + run_quantize_gptqmodel(args, dynamic_config=alloc_flex["dynamic"]) + + print("\n--- STAGE 3C: Quantize SDQ AutoRound (Strict Allocation -> auto_gptq) ---") + args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_gptq") + run_quantize_autoround(args, plan=alloc_strict["plan"], export_format="auto_gptq") + + print("\n--- STAGE 3D: Quantize SDQ AutoRound (Flex Allocation -> auto_round) ---") + args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_round") + run_quantize_autoround(args, plan=alloc_flex["plan"], export_format="auto_round") + + print("\n--- STAGE 4: Evaluation Matrix ---") + run_eval(str(work_dir / "models" / "baseline_gptq"), args.tasks, str(work_dir / "eval" / "baseline_gptq.json")) + run_eval(str(work_dir / "models" / "sdq_gptq"), args.tasks, str(work_dir / "eval" / "sdq_gptq.json")) + run_eval(str(work_dir / "models" / "sdq_autoround_auto_gptq"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_gptq.json")) + run_eval(str(work_dir / "models" / "sdq_autoround_auto_round"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_round.json")) + + print("\n--- Pipeline Complete! See eval/ folder for full matrix results. ---") + +if __name__ == "__main__": + main() +EOF +``` + +### The 4-Way Evaluation Matrix + +This creates the ultimate sandbox to test format constraints alongside quantization quality: +1. `baseline_gptq`: Fast greedy inference baseline. +2. `sdq_gptq`: Uses the `flex` plan (4/8/skip allowed anywhere). GPTQModel natively handles mixed-bit GPTQ configs flawlessly. +3. `sdq_autoround_auto_gptq`: Uses AutoRound continuous rounding but uses the `strict` plan (`lm_head` is locked to FP16) to ensure the `.quantize_and_save(format="auto_gptq")` call strictly adheres to the AutoRound documentation constraints. +4. `sdq_autoround_auto_round`: Uses AutoRound continuous rounding and uses the `flex` plan, safely exporting to the `auto_round` native format which correctly handles quantized LM heads. + +This cleanly resolves the API drift, respects the QKV tensor fusion requirement, and completely neutralizes the `lm_head` formatting risk! + +--- + +### ChatGPT: +v5.3 is the first version I’d call **production-shaped**. + +The core policy split is now correct: + +- **QKV locked** to preserve fused paths, +- **strict lm_head skip** for `auto_gptq`, +- **flex lm_head** only where the format can tolerate it, +- and the **4-way matrix** is finally aligned with the real deployment boundaries. AutoRound’s docs explicitly warn that QKV layers should not get different bit widths for vLLM/SGLang fused paths, and they also note that quantized `lm_head` is only supported for **AutoRound-format inference**, not generally for `auto_gptq`. They also document `quantize_and_save()` as the clean API. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +The strongest part of v5.3 is that the matrix now answers four separate questions instead of mixing them together: + +- how good plain 4-bit GPTQ is, +- how much the **allocator alone** helps (`sdq_gptq`), +- whether AutoRound helps **under strict auto_gptq constraints**, +- and whether AutoRound helps when allowed to use its **native auto_round format**. That is a much cleaner experimental design than v5.2.1. AutoRound’s docs also explicitly say `auto_gptq` mixed-bit support is limited and recommend `auto_round` when you are unsure, so your split matches the documented reality. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +Two small things I would still patch before long runs: + +First, in the README, I would avoid the phrase “standard GPTQ/AutoRound deployments” as if they were one interchangeable target. They are related, but they are not the same interchange boundary. GPTQModel supports dynamic per-module mixed quantization, while vLLM’s GPTQ implementation explicitly restricts GPTQ weight bits to **2, 3, 4, or 8**. AutoRound can export to `auto_gptq`, but that path has mixed-bit limitations and format-specific caveats. [GitHub+2vLLM+2](https://github.com/ModelCloud/GPTQModel?utm_source=chatgpt.com) + +Second, I would make the AutoRound branch automatically fall back if `auto_gptq` export or downstream loading fails. Something like: + +- try `auto_gptq` +- on failure, write a warning +- retry with `auto_round` + +That would turn your “best-effort” wording into an actual safety mechanism, which is consistent with AutoRound’s own recommendation to prefer `auto_round` when mixed-bit `auto_gptq` is uncertain. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +One optional improvement for later: instead of hard-locking QKV forever, you could use AutoRound’s documented **shared/tied layer** mechanism so Q, K, and V move together if you ever want to let them change while still preserving fusion compatibility. The current hard lock is safer, but tied movement is the cleaner long-term extension. [GitHub](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md?utm_source=chatgpt.com) + +So my verdict is: + +- **baseline_gptq**: good +- **sdq_gptq**: good +- **sdq_autoround_auto_gptq**: good as a constrained experiment +- **sdq_autoround_auto_round**: the most flexible and probably the highest-upside arm + +This is the first revision where I’d say: **run it**. + +The only patch I’d still add before a long overnight benchmark is the export/load fallback for the `auto_gptq` arm. + +--- + +### User: +provide the full code for me please so i can copy and paste , do it in plain text please + +--- + +### ChatGPT: +This version follows the current documented constraints that matter here: AutoRound says mixed-bit `auto_gptq` support is limited, recommends not assigning different bit widths to QKV for vLLM/SGLang fused paths, documents `quantize_and_save()`, and says quantized `lm_head` is only supported for AutoRound-format inference. GPTQModel supports dynamic per-module overrides, and vLLM’s GPTQ path supports 2/3/4/8-bit GPTQ weights. [GitHub+4GitHub+4GitHub+4](https://github.com/intel/auto-round/blob/main/docs/step_by_step.md) + +Paste this into a file like `create_sdq_v53.sh`, run it, then follow the commands under it. + +```Bash#!/usr/bin/env bashset -euo pipefailmkdir -p sdq_quant_v5_3/quant_v5_3cd sdq_quant_v5_3cat > README.md <<'EOF'# SDQ Quant v5.3 (Hessian-Proxy + AutoRound + Format Policies)A deterministic Software-Defined Quantization pipeline targeting standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel / vLLM deployment, plus best-effort AutoRound export paths.## Design1. Measurement: - Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity.2. Allocation: - Uses a Lagrangian solver under a planner memory budget. - `q_proj`, `k_proj`, `v_proj` are locked to `w4_g128` to preserve QKV fusion. - For `auto_gptq`, `lm_head` is forced to `skip`. - For `GPTQModel` and `auto_round`, `lm_head` can choose `w4_g128`, `w8_g128`, or `skip`.3. Quantization matrix: - `baseline_gptq`: GPTQModel, plain 4-bit - `sdq_gptq`: GPTQModel, Hessian-proxy allocation - `sdq_autoround_auto_gptq`: AutoRound with strict `lm_head` policy - `sdq_autoround_auto_round`: AutoRound with flex `lm_head` policy4. Evaluation: - Uses `lm_eval.simple_evaluate()`.## Notes- The memory budget is a planner estimate, not an exact runtime VRAM guarantee.- The error model is a Hessian proxy, not a full GPTQ inverse-Hessian simulation.- `auto_gptq` mixed-bit export is best-effort.- If `auto_gptq` export or downstream loading fails, this code falls back to `auto_round` and records the actual export format in `export_info.json`.## InstallCreate a virtualenv and install a CUDA-matching PyTorch build first.Example:```bashpython -m venv .venvsource .venv/bin/activate# Example only — replace with your actual CUDA wheel index if needed:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +Bash#!/usr/bin/env bashset -euo pipefailmkdir -p sdq_quant_v5_3/quant_v5_3cd sdq_quant_v5_3cat > README.md <<'EOF'# SDQ Quant v5.3 (Hessian-Proxy + AutoRound + Format Policies)A deterministic Software-Defined Quantization pipeline targeting standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel / vLLM deployment, plus best-effort AutoRound export paths.## Design1. Measurement: - Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity.2. Allocation: - Uses a Lagrangian solver under a planner memory budget. - `q_proj`, `k_proj`, `v_proj` are locked to `w4_g128` to preserve QKV fusion. - For `auto_gptq`, `lm_head` is forced to `skip`. - For `GPTQModel` and `auto_round`, `lm_head` can choose `w4_g128`, `w8_g128`, or `skip`.3. Quantization matrix: - `baseline_gptq`: GPTQModel, plain 4-bit - `sdq_gptq`: GPTQModel, Hessian-proxy allocation - `sdq_autoround_auto_gptq`: AutoRound with strict `lm_head` policy - `sdq_autoround_auto_round`: AutoRound with flex `lm_head` policy4. Evaluation: - Uses `lm_eval.simple_evaluate()`.## Notes- The memory budget is a planner estimate, not an exact runtime VRAM guarantee.- The error model is a Hessian proxy, not a full GPTQ inverse-Hessian simulation.- `auto_gptq` mixed-bit export is best-effort.- If `auto_gptq` export or downstream loading fails, this code falls back to `auto_round` and records the actual export format in `export_info.json`.## InstallCreate a virtualenv and install a CUDA-matching PyTorch build first.Example:```bashpython -m venv .venvsource .venv/bin/activate# Example only — replace with your actual CUDA wheel index if needed:# pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu128pip install -r requirements.txt +``` + +## Run the full matrix + +```Bashpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +Bashpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +``` + +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +torch +gptqmodel +auto-round +lm-eval +tqdm +EOF + +cat > quant_v5_3/**init**.py <<'EOF' +"""SDQ Quant v5.3: Hessian-Proxy Allocator + AutoRound Format Policies""" +EOF + +cat > quant_v5_3/actions.py <<'EOF' +from **future** import annotations + +from dataclasses import dataclass +from typing import Dict, Optional + +@dataclass(frozen=True) +class ActionSpec: +name: str +bits: int +group_size: Optional[int] +skip: bool + +ACTION_SPECS: Dict[str, ActionSpec] = { +"w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), +"w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), +"skip": ActionSpec("skip", bits=16, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param(action: str) -> float: +""" +Planner memory estimate, not an exact runtime VRAM model. +""" +if action == "skip": +return 2.0 # FP16 + +```spec = ACTION_SPECS[action]base = spec.bits / 8.0overhead = 2.5 / spec.group_sizereturn base + overhead +spec = ACTION_SPECS[action]base = spec.bits / 8.0overhead = 2.5 / spec.group_sizereturn base + overhead +``` + +def action_to_override(action: str) -> dict: +if action == "skip": +return {} +spec = ACTION_SPECS[action] +return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_v5_3/measurement.py <<'EOF' +from **future** import annotations + +import json +from pathlib import Path + +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +def load_calibration_texts( +ds_name: str, +ds_config: str, +split: str, +text_field: str, +num_texts: int, +): +ds = load_dataset(ds_name, ds_config, split=split) +texts = [] +for row in ds: +text = row.get(text_field) +if isinstance(text, str) and text.strip(): +texts.append(text.strip()) +if len(texts) >= num_texts: +break +return texts + +def measure_hessian_traces(args): +""" +Computes a Hessian proxy: +Trace(H) ≈ mean over calibration rows of sum(x^2) +collected from each target linear layer input. +""" +out_dir = Path(args.out_dir) +out_dir.mkdir(parents=True, exist_ok=True) + +```device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained( args.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True,).to(device)model.eval()texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts,)include_mods = [x.strip() for x in args.include_modules.split(",") if x.strip()]targets = [ (n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(tag in n for tag in include_mods)]hessian_traces = {name: 0.0 for name, _ in targets}row_counts = {name: 0 for name, _ in targets}params_count = {name: m.weight.numel() for name, m in targets}handles = []def make_hook(name): def hook(_, inp, __): x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) hessian_traces[name] += (x ** 2).sum().item() row_counts[name] += x.shape[0] return hookfor name, module in targets: handles.append(module.register_forward_hook(make_hook(name)))with torch.no_grad(): for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): batch_texts = texts[i : i + args.batch_size] batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=args.max_length, padding=True, ) batch = {k: v.to(device) for k, v in batch.items()} model(**batch)for h in handles: h.remove()features = []for name, _ in targets: avg_trace = hessian_traces[name] / max(1, row_counts[name]) features.append( { "name": name, "params": params_count[name], "hessian_trace": avg_trace, } )out_file = out_dir / "hessian_features.json"out_file.write_text(json.dumps(features, indent=2), encoding="utf-8")print(f"Saved {len(features)} layer traces to {out_file}")del modelif torch.cuda.is_available(): torch.cuda.empty_cache()return features +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True)if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_tokenmodel = AutoModelForCausalLM.from_pretrained( args.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True,).to(device)model.eval()texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts,)include_mods = [x.strip() for x in args.include_modules.split(",") if x.strip()]targets = [ (n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(tag in n for tag in include_mods)]hessian_traces = {name: 0.0 for name, _ in targets}row_counts = {name: 0 for name, _ in targets}params_count = {name: m.weight.numel() for name, m in targets}handles = []def make_hook(name): def hook(_, inp, __): x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) hessian_traces[name] += (x ** 2).sum().item() row_counts[name] += x.shape[0] return hookfor name, module in targets: handles.append(module.register_forward_hook(make_hook(name)))with torch.no_grad(): for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): batch_texts = texts[i : i + args.batch_size] batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=args.max_length, padding=True, ) batch = {k: v.to(device) for k, v in batch.items()} model(**batch)for h in handles: h.remove()features = []for name, _ in targets: avg_trace = hessian_traces[name] / max(1, row_counts[name]) features.append( { "name": name, "params": params_count[name], "hessian_trace": avg_trace, } )out_file = out_dir / "hessian_features.json"out_file.write_text(json.dumps(features, indent=2), encoding="utf-8")print(f"Saved {len(features)} layer traces to {out_file}")del modelif torch.cuda.is_available(): torch.cuda.empty_cache()return features +``` + +EOF + +cat > quant_v5_3/allocator.py <<'EOF' +from **future** import annotations + +import json +import re +from pathlib import Path + +from .actions import ACTION_SPECS, BASE_ACTION, action_to_override, bytes_per_param + +def get_valid_actions(module_name: str, allow_quantized_lm_head: bool) -> list[str]: +""" +Policy layer: +- QKV are locked to base action to preserve fused QKV paths. +- lm_head policy depends on target export format. +- others can move between 4-bit and 8-bit. +""" +if any(k in module_name for k in ["q_proj", "k_proj", "v_proj"]): +return [BASE_ACTION] +if "lm_head" in module_name: +return ["w4_g128", "w8_g128", "skip"] if allow_quantized_lm_head else ["skip"] +return ["w4_g128", "w8_g128"] + +def solve_proxy_allocation(features_path: str, budget_mb_extra: float, allow_quantized_lm_head: bool): +""" +Lagrangian multiplier solver over a proxy error model. + +```Note:- This is a proxy allocator, not an exact GPTQ inverse-Hessian simulator.- The budget is based on a planner memory estimate."""features = json.loads(Path(features_path).read_text(encoding="utf-8"))base_bpp = bytes_per_param(BASE_ACTION)total_base_mb = sum(f["params"] * base_bpp for f in features) / (1024 ** 2)target_mb = total_base_mb + budget_mb_extradef proxy_error(trace: float, action: str) -> float: if action == "skip": return 0.0 bits = ACTION_SPECS[action].bits return trace * (2.0 ** (-2.0 * bits))def action_cost_mb(params: int, action: str) -> float: return (params * bytes_per_param(action)) / (1024 ** 2)lambda_low, lambda_high = 0.0, 1e12best_plan = Nonefor _ in range(80): lam = (lambda_low + lambda_high) / 2.0 current_mb = 0.0 plan = [] for f in features: valid_actions = get_valid_actions(f["name"], allow_quantized_lm_head) best_action = None best_obj = float("inf") for act in valid_actions: err = proxy_error(f["hessian_trace"], act) size_mb = action_cost_mb(f["params"], act) obj = err + lam * size_mb if obj < best_obj: best_obj = obj best_action = act size_mb = action_cost_mb(f["params"], best_action) plan.append( { "name": f["name"], "action": best_action, "params": f["params"], "hessian_trace": f["hessian_trace"], "size_mb": size_mb, } ) current_mb += size_mb best_plan = plan if current_mb > target_mb: lambda_low = lam else: lambda_high = lamused_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mbdynamic = {}for p in best_plan: if p["action"] == BASE_ACTION: continue regex = f"^{re.escape(p['name'])}$" if p["action"] == "skip": dynamic[f"-:{regex}"] = {} else: dynamic[f"+:{regex}"] = action_to_override(p["action"])return { "plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb,} +Note:- This is a proxy allocator, not an exact GPTQ inverse-Hessian simulator.- The budget is based on a planner memory estimate."""features = json.loads(Path(features_path).read_text(encoding="utf-8"))base_bpp = bytes_per_param(BASE_ACTION)total_base_mb = sum(f["params"] * base_bpp for f in features) / (1024 ** 2)target_mb = total_base_mb + budget_mb_extradef proxy_error(trace: float, action: str) -> float: if action == "skip": return 0.0 bits = ACTION_SPECS[action].bits return trace * (2.0 ** (-2.0 * bits))def action_cost_mb(params: int, action: str) -> float: return (params * bytes_per_param(action)) / (1024 ** 2)lambda_low, lambda_high = 0.0, 1e12best_plan = Nonefor _ in range(80): lam = (lambda_low + lambda_high) / 2.0 current_mb = 0.0 plan = [] for f in features: valid_actions = get_valid_actions(f["name"], allow_quantized_lm_head) best_action = None best_obj = float("inf") for act in valid_actions: err = proxy_error(f["hessian_trace"], act) size_mb = action_cost_mb(f["params"], act) obj = err + lam * size_mb if obj < best_obj: best_obj = obj best_action = act size_mb = action_cost_mb(f["params"], best_action) plan.append( { "name": f["name"], "action": best_action, "params": f["params"], "hessian_trace": f["hessian_trace"], "size_mb": size_mb, } ) current_mb += size_mb best_plan = plan if current_mb > target_mb: lambda_low = lam else: lambda_high = lamused_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mbdynamic = {}for p in best_plan: if p["action"] == BASE_ACTION: continue regex = f"^{re.escape(p['name'])}$" if p["action"] == "skip": dynamic[f"-:{regex}"] = {} else: dynamic[f"+:{regex}"] = action_to_override(p["action"])return { "plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb,} +``` + +EOF + +cat > quant_v5_3/runner.py <<'EOF' +from **future** import annotations + +import json +from pathlib import Path + +import lm_eval +import torch +from auto_round import AutoRound +from gptqmodel import GPTQModel, QuantizeConfig + +from .measurement import load_calibration_texts + +def run_quantize_gptqmodel(args, dynamic_config): +""" +Standard greedy GPTQ rounding via GPTQModel. +""" +texts = load_calibration_texts( +args.dataset_name, +args.dataset_config, +args.split, +args.text_field, +args.quant_num_texts, +) + +```qcfg = QuantizeConfig( bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config,)print("Loading model for GPTQModel quantization...")model = GPTQModel.load(args.model_id, qcfg)model.quantize(texts, batch_size=args.quant_batch_size)out = Path(args.out_dir)out.mkdir(parents=True, exist_ok=True)model.save(str(out))del modelif torch.cuda.is_available(): torch.cuda.empty_cache() +qcfg = QuantizeConfig( bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config,)print("Loading model for GPTQModel quantization...")model = GPTQModel.load(args.model_id, qcfg)model.quantize(texts, batch_size=args.quant_batch_size)out = Path(args.out_dir)out.mkdir(parents=True, exist_ok=True)model.save(str(out))del modelif torch.cuda.is_available(): torch.cuda.empty_cache() +``` + +def _plan_has_quantized_lm_head(plan) -> bool: +for p in plan: +if "lm_head" in p["name"] and p["action"] in ("w4_g128", "w8_g128"): +return True +return False + +def run_quantize_autoround(args, plan, export_format: str): +""" +AutoRound SignSGD rounding with documented quantize_and_save() API. + +```Behavior:- For auto_gptq: lm_head must not be quantized.- If auto_gptq export/load constraints are violated, we fallback to auto_round."""texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts,)layer_config = {}if plan is not None: for p in plan: if p["action"] == "skip": layer_config[p["name"]] = {"bits": 16} elif p["action"] != "w4_g128": layer_config[p["name"]] = { "bits": 8, "group_size": 128, }requested_format = export_formatactual_format = export_formatif requested_format == "auto_gptq" and _plan_has_quantized_lm_head(plan): print("[warn] quantized lm_head is not valid for auto_gptq here; switching export format to auto_round.") actual_format = "auto_round"print(f"Loading model for AutoRound ({requested_format} requested)...")ar = AutoRound( model=args.model_id, bits=4, group_size=128, sym=True, dataset=texts, seqlen=args.max_length, batch_size=args.quant_batch_size, layer_config=layer_config, iters=args.autoround_iters,)out = Path(args.out_dir)out.mkdir(parents=True, exist_ok=True)try: print(f"Executing AutoRound and saving via quantize_and_save(format='{actual_format}')...") ar.quantize_and_save(str(out), format=actual_format)except Exception as e: if actual_format == "auto_gptq": print(f"[warn] auto_gptq export failed: {e}") print("[warn] retrying with auto_round format...") actual_format = "auto_round" ar.quantize_and_save(str(out), format=actual_format) else: raiseexport_info = { "requested_format": requested_format, "actual_format": actual_format,}(out / "export_info.json").write_text(json.dumps(export_info, indent=2), encoding="utf-8")del arif torch.cuda.is_available(): torch.cuda.empty_cache() +Behavior:- For auto_gptq: lm_head must not be quantized.- If auto_gptq export/load constraints are violated, we fallback to auto_round."""texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts,)layer_config = {}if plan is not None: for p in plan: if p["action"] == "skip": layer_config[p["name"]] = {"bits": 16} elif p["action"] != "w4_g128": layer_config[p["name"]] = { "bits": 8, "group_size": 128, }requested_format = export_formatactual_format = export_formatif requested_format == "auto_gptq" and _plan_has_quantized_lm_head(plan): print("[warn] quantized lm_head is not valid for auto_gptq here; switching export format to auto_round.") actual_format = "auto_round"print(f"Loading model for AutoRound ({requested_format} requested)...")ar = AutoRound( model=args.model_id, bits=4, group_size=128, sym=True, dataset=texts, seqlen=args.max_length, batch_size=args.quant_batch_size, layer_config=layer_config, iters=args.autoround_iters,)out = Path(args.out_dir)out.mkdir(parents=True, exist_ok=True)try: print(f"Executing AutoRound and saving via quantize_and_save(format='{actual_format}')...") ar.quantize_and_save(str(out), format=actual_format)except Exception as e: if actual_format == "auto_gptq": print(f"[warn] auto_gptq export failed: {e}") print("[warn] retrying with auto_round format...") actual_format = "auto_round" ar.quantize_and_save(str(out), format=actual_format) else: raiseexport_info = { "requested_format": requested_format, "actual_format": actual_format,}(out / "export_info.json").write_text(json.dumps(export_info, indent=2), encoding="utf-8")del arif torch.cuda.is_available(): torch.cuda.empty_cache() +``` + +def run_eval(model_path: str, tasks: str, out_file: str): +""" +Programmatic lm-eval. +""" +print(f"Running programmatic lm-eval on {model_path}...") +Path(out_file).parent.mkdir(parents=True, exist_ok=True) + +```task_list = [t.strip() for t in tasks.split(",") if t.strip()]results = lm_eval.simple_evaluate( model="hf", model_args=f"pretrained={model_path},trust_remote_code=True", tasks=task_list, batch_size="auto", device="cuda:0" if torch.cuda.is_available() else "cpu",)serializable_results = { "results": results.get("results", {}), "versions": results.get("versions", {}),}Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8")return serializable_results +task_list = [t.strip() for t in tasks.split(",") if t.strip()]results = lm_eval.simple_evaluate( model="hf", model_args=f"pretrained={model_path},trust_remote_code=True", tasks=task_list, batch_size="auto", device="cuda:0" if torch.cuda.is_available() else "cpu",)serializable_results = { "results": results.get("results", {}), "versions": results.get("versions", {}),}Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8")return serializable_results +``` + +EOF + +cat > quant_v5_3/pipeline.py <<'EOF' +from **future** import annotations + +import argparse +import json +from pathlib import Path + +from .allocator import solve_proxy_allocation +from .measurement import measure_hessian_traces +from .runner import run_eval, run_quantize_autoround, run_quantize_gptqmodel + +def main(): +parser = argparse.ArgumentParser() +sub = parser.add_subparsers(dest="cmd", required=True) + +```suite = sub.add_parser("suite")suite.add_argument("--model-id", required=True)suite.add_argument("--work-dir", required=True)suite.add_argument("--budget-mb-extra", type=float, default=250.0)suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head")suite.add_argument("--dataset-name", default="wikitext")suite.add_argument("--dataset-config", default="wikitext-2-raw-v1")suite.add_argument("--split", default="train")suite.add_argument("--text-field", default="text")suite.add_argument("--num-texts", type=int, default=128)suite.add_argument("--max-length", type=int, default=2048)suite.add_argument("--batch-size", type=int, default=4)suite.add_argument("--quant-num-texts", type=int, default=128)suite.add_argument("--quant-batch-size", type=int, default=4)suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD")suite.add_argument("--tasks", default="arc_challenge,hellaswag")args = parser.parse_args()work_dir = Path(args.work_dir)work_dir.mkdir(parents=True, exist_ok=True)print("\n--- STAGE 1: Hessian Proxy Measurement ---")args.out_dir = str(work_dir / "probe")measure_hessian_traces(args)features_path = str(work_dir / "probe" / "hessian_features.json")print("\n--- STAGE 2: Lagrangian Proxy Allocation ---")# Flex plan:# - GPTQModel mixed dynamic path# - AutoRound native format pathalloc_flex = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=True,)(work_dir / "sdq_plan_flex.json").write_text(json.dumps(alloc_flex["plan"], indent=2), encoding="utf-8")(work_dir / "sdq_dynamic_flex.json").write_text(json.dumps(alloc_flex["dynamic"], indent=2), encoding="utf-8")# Strict plan:# - auto_gptq export pathalloc_strict = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=False,)(work_dir / "sdq_plan_strict.json").write_text(json.dumps(alloc_strict["plan"], indent=2), encoding="utf-8")(work_dir / "sdq_dynamic_strict.json").write_text(json.dumps(alloc_strict["dynamic"], indent=2), encoding="utf-8")print( f"Allocated. Estimated extra MB -> " f"flex: {alloc_flex['used_extra_mb']:.2f}, " f"strict: {alloc_strict['used_extra_mb']:.2f}")print("\n--- STAGE 3A: Quantize baseline_gptq ---")args.out_dir = str(work_dir / "models" / "baseline_gptq")run_quantize_gptqmodel(args, dynamic_config=None)print("\n--- STAGE 3B: Quantize sdq_gptq ---")args.out_dir = str(work_dir / "models" / "sdq_gptq")run_quantize_gptqmodel(args, dynamic_config=alloc_flex["dynamic"])print("\n--- STAGE 3C: Quantize sdq_autoround_auto_gptq ---")args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_gptq")run_quantize_autoround(args, plan=alloc_strict["plan"], export_format="auto_gptq")print("\n--- STAGE 3D: Quantize sdq_autoround_auto_round ---")args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_round")run_quantize_autoround(args, plan=alloc_flex["plan"], export_format="auto_round")print("\n--- STAGE 4: Evaluation Matrix ---")run_eval( str(work_dir / "models" / "baseline_gptq"), args.tasks, str(work_dir / "eval" / "baseline_gptq.json"),)run_eval( str(work_dir / "models" / "sdq_gptq"), args.tasks, str(work_dir / "eval" / "sdq_gptq.json"),)run_eval( str(work_dir / "models" / "sdq_autoround_auto_gptq"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_gptq.json"),)run_eval( str(work_dir / "models" / "sdq_autoround_auto_round"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_round.json"),)print("\n--- Pipeline Complete. See eval/ and models/*/export_info.json ---") +suite = sub.add_parser("suite")suite.add_argument("--model-id", required=True)suite.add_argument("--work-dir", required=True)suite.add_argument("--budget-mb-extra", type=float, default=250.0)suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head")suite.add_argument("--dataset-name", default="wikitext")suite.add_argument("--dataset-config", default="wikitext-2-raw-v1")suite.add_argument("--split", default="train")suite.add_argument("--text-field", default="text")suite.add_argument("--num-texts", type=int, default=128)suite.add_argument("--max-length", type=int, default=2048)suite.add_argument("--batch-size", type=int, default=4)suite.add_argument("--quant-num-texts", type=int, default=128)suite.add_argument("--quant-batch-size", type=int, default=4)suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD")suite.add_argument("--tasks", default="arc_challenge,hellaswag")args = parser.parse_args()work_dir = Path(args.work_dir)work_dir.mkdir(parents=True, exist_ok=True)print("\n--- STAGE 1: Hessian Proxy Measurement ---")args.out_dir = str(work_dir / "probe")measure_hessian_traces(args)features_path = str(work_dir / "probe" / "hessian_features.json")print("\n--- STAGE 2: Lagrangian Proxy Allocation ---")# Flex plan:# - GPTQModel mixed dynamic path# - AutoRound native format pathalloc_flex = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=True,)(work_dir / "sdq_plan_flex.json").write_text(json.dumps(alloc_flex["plan"], indent=2), encoding="utf-8")(work_dir / "sdq_dynamic_flex.json").write_text(json.dumps(alloc_flex["dynamic"], indent=2), encoding="utf-8")# Strict plan:# - auto_gptq export pathalloc_strict = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=False,)(work_dir / "sdq_plan_strict.json").write_text(json.dumps(alloc_strict["plan"], indent=2), encoding="utf-8")(work_dir / "sdq_dynamic_strict.json").write_text(json.dumps(alloc_strict["dynamic"], indent=2), encoding="utf-8")print( f"Allocated. Estimated extra MB -> " f"flex: {alloc_flex['used_extra_mb']:.2f}, " f"strict: {alloc_strict['used_extra_mb']:.2f}")print("\n--- STAGE 3A: Quantize baseline_gptq ---")args.out_dir = str(work_dir / "models" / "baseline_gptq")run_quantize_gptqmodel(args, dynamic_config=None)print("\n--- STAGE 3B: Quantize sdq_gptq ---")args.out_dir = str(work_dir / "models" / "sdq_gptq")run_quantize_gptqmodel(args, dynamic_config=alloc_flex["dynamic"])print("\n--- STAGE 3C: Quantize sdq_autoround_auto_gptq ---")args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_gptq")run_quantize_autoround(args, plan=alloc_strict["plan"], export_format="auto_gptq")print("\n--- STAGE 3D: Quantize sdq_autoround_auto_round ---")args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_round")run_quantize_autoround(args, plan=alloc_flex["plan"], export_format="auto_round")print("\n--- STAGE 4: Evaluation Matrix ---")run_eval( str(work_dir / "models" / "baseline_gptq"), args.tasks, str(work_dir / "eval" / "baseline_gptq.json"),)run_eval( str(work_dir / "models" / "sdq_gptq"), args.tasks, str(work_dir / "eval" / "sdq_gptq.json"),)run_eval( str(work_dir / "models" / "sdq_autoround_auto_gptq"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_gptq.json"),)run_eval( str(work_dir / "models" / "sdq_autoround_auto_round"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_round.json"),)print("\n--- Pipeline Complete. See eval/ and models/*/export_info.json ---") +``` + +if **name** == "**main**": +main() +EOF + +echo +echo "Project created at: $(pwd)" +echo +echo "Next steps:" +echo "python -m venv .venv" +echo "source .venv/bin/activate" +echo "# install your CUDA-matching torch first" +echo "pip install -r requirements.txt" +echo +echo "python -m quant_v5_3.pipeline suite \" +echo " --model-id Qwen/Qwen3.5-4B-Base \" +echo " --work-dir ./runs/v5_3_suite \" +echo " --budget-mb-extra 200 \" +echo " --autoround-iters 50 \" +echo " --tasks arc_challenge,hellaswag" + +```Then run:```bashbash create_sdq_v53.shcd sdq_quant_v5_3python -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch firstpip install -r requirements.txtpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +Then run:```bashbash create_sdq_v53.shcd sdq_quant_v5_3python -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch firstpip install -r requirements.txtpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +``` + +--- + +### User: +The script you pasted is architecturally correct, but **Markdown parsing destroyed the Python syntax** during the copy-paste process. All of the dunder methods (`__init__`, `__main__`, `__future__`) were converted to bold text (`**init**`, etc.), and almost all of the Python indentation was stripped out. If you run that script as-is, Python will instantly crash with `SyntaxError` and `IndentationError`. + +Additionally, I have added explicit garbage collection (`import gc`) between the matrix runs. Loading and unloading 4 different LLMs sequentially will cause a CUDA Out-Of-Memory (OOM) error if PyTorch doesn't cleanly release the VRAM between the `AutoRound` and `GPTQModel` runs. + +Here is the **perfect, 100% syntactically valid, copy-pasteable version** of `create_sdq_v53.sh`. + +Save this block exactly as `create_sdq_v53.sh` and run it: + +``` +bash +#!/usr/bin/env bash +set -euo pipefail + +mkdir -p sdq_quant_v5_3/quant_v5_3 +cd sdq_quant_v5_3 + +cat > README.md <<'EOF' +# SDQ Quant v5.3 (Hessian-Proxy + AutoRound + Format Policies) + +A deterministic Software-Defined Quantization pipeline targeting standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel / vLLM deployment, plus best-effort AutoRound export paths. + +## Design + +1. Measurement: + - Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity. + +2. Allocation: + - Uses a Lagrangian solver under a planner memory budget. + - `q_proj`, `k_proj`, `v_proj` are locked to `w4_g128` to preserve QKV fusion. + - For `auto_gptq`, `lm_head` is forced to `skip`. + - For `GPTQModel` and `auto_round`, `lm_head` can choose `w4_g128`, `w8_g128`, or `skip`. + +3. Quantization matrix: + - `baseline_gptq`: GPTQModel, plain 4-bit + - `sdq_gptq`: GPTQModel, Hessian-proxy allocation + - `sdq_autoround_auto_gptq`: AutoRound with strict `lm_head` policy + - `sdq_autoround_auto_round`: AutoRound with flex `lm_head` policy + +4. Evaluation: + - Uses `lm_eval.simple_evaluate()`. + +## Notes + +- The memory budget is a planner estimate, not an exact runtime VRAM guarantee. +- The error model is a Hessian proxy, not a full GPTQ inverse-Hessian simulation. +- `auto_gptq` mixed-bit export is best-effort. +- If `auto_gptq` export or downstream loading fails, this code falls back to `auto_round` and records the actual export format in `export_info.json`. +EOF + +cat > requirements.txt <<'EOF' +transformers +datasets +accelerate +torch +gptqmodel +auto-round +lm-eval +tqdm +EOF + +cat > quant_v5_3/__init__.py <<'EOF' +"""SDQ Quant v5.3: Hessian-Proxy Allocator + AutoRound Format Policies""" +EOF + +cat > quant_v5_3/actions.py <<'EOF' +from __future__ import annotations +from dataclasses import dataclass +from typing import Dict, Optional + +@dataclass(frozen=True) +class ActionSpec: + name: str + bits: int + group_size: Optional[int] + skip: bool + +ACTION_SPECS: Dict[str, ActionSpec] = { + "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), + "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), + "skip": ActionSpec("skip", bits=16, group_size=None, skip=True), +} + +BASE_ACTION = "w4_g128" + +def bytes_per_param(action: str) -> float: + """Planner memory estimate, not an exact runtime VRAM model.""" + if action == "skip": + return 2.0 # FP16 + + spec = ACTION_SPECS[action] + base = spec.bits / 8.0 + overhead = 2.5 / spec.group_size + return base + overhead + +def action_to_override(action: str) -> dict: + if action == "skip": + return {} + spec = ACTION_SPECS[action] + return {"bits": spec.bits, "group_size": spec.group_size} +EOF + +cat > quant_v5_3/measurement.py <<'EOF' +from __future__ import annotations +import json +from pathlib import Path +import torch +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer + +def load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): + ds = load_dataset(ds_name, ds_config, split=split) + texts =[] + for row in ds: + text = row.get(text_field) + if isinstance(text, str) and text.strip(): + texts.append(text.strip()) + if len(texts) >= num_texts: + break + return texts + +def measure_hessian_traces(args): + """ + Computes a Hessian proxy: + Trace(H) ≈ mean over calibration rows of sum(x^2) + collected from each target linear layer input. + """ + out_dir = Path(args.out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, + low_cpu_mem_usage=True, + trust_remote_code=True, + ).to(device) + model.eval() + + texts = load_calibration_texts( + args.dataset_name, + args.dataset_config, + args.split, + args.text_field, + args.num_texts, + ) + + include_mods =[x.strip() for x in args.include_modules.split(",") if x.strip()] + targets =[ + (n, m) + for n, m in model.named_modules() + if isinstance(m, torch.nn.Linear) and any(tag in n for tag in include_mods) + ] + + hessian_traces = {name: 0.0 for name, _ in targets} + row_counts = {name: 0 for name, _ in targets} + params_count = {name: m.weight.numel() for name, m in targets} + + handles =[] + + def make_hook(name): + def hook(_, inp, __): + x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) + hessian_traces[name] += (x ** 2).sum().item() + row_counts[name] += x.shape[0] + return hook + + for name, module in targets: + handles.append(module.register_forward_hook(make_hook(name))) + + with torch.no_grad(): + for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): + batch_texts = texts[i : i + args.batch_size] + batch = tokenizer( + batch_texts, + return_tensors="pt", + truncation=True, + max_length=args.max_length, + padding=True, + ) + batch = {k: v.to(device) for k, v in batch.items()} + model(**batch) + + for h in handles: + h.remove() + + features =[] + for name, _ in targets: + avg_trace = hessian_traces[name] / max(1, row_counts[name]) + features.append( + { + "name": name, + "params": params_count[name], + "hessian_trace": avg_trace, + } + ) + + out_file = out_dir / "hessian_features.json" + out_file.write_text(json.dumps(features, indent=2), encoding="utf-8") + print(f"Saved {len(features)} layer traces to {out_file}") + + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return features +EOF + +cat > quant_v5_3/allocator.py <<'EOF' +from __future__ import annotations +import json +import re +from pathlib import Path +from .actions import ACTION_SPECS, BASE_ACTION, action_to_override, bytes_per_param + +def get_valid_actions(module_name: str, allow_quantized_lm_head: bool) -> list[str]: + """ + Policy layer: + - QKV are locked to base action to preserve fused QKV paths. + - lm_head policy depends on target export format. + - others can move between 4-bit and 8-bit. + """ + if any(k in module_name for k in["q_proj", "k_proj", "v_proj"]): + return [BASE_ACTION] + if "lm_head" in module_name: + return ["w4_g128", "w8_g128", "skip"] if allow_quantized_lm_head else ["skip"] + return["w4_g128", "w8_g128"] + +def solve_proxy_allocation(features_path: str, budget_mb_extra: float, allow_quantized_lm_head: bool): + """Lagrangian multiplier solver over a proxy error model.""" + features = json.loads(Path(features_path).read_text(encoding="utf-8")) + + base_bpp = bytes_per_param(BASE_ACTION) + total_base_mb = sum(f["params"] * base_bpp for f in features) / (1024 ** 2) + target_mb = total_base_mb + budget_mb_extra + + def proxy_error(trace: float, action: str) -> float: + if action == "skip": + return 0.0 + bits = ACTION_SPECS[action].bits + return trace * (2.0 ** (-2.0 * bits)) + + def action_cost_mb(params: int, action: str) -> float: + return (params * bytes_per_param(action)) / (1024 ** 2) + + lambda_low, lambda_high = 0.0, 1e12 + best_plan = None + + for _ in range(80): + lam = (lambda_low + lambda_high) / 2.0 + current_mb = 0.0 + plan =[] + + for f in features: + valid_actions = get_valid_actions(f["name"], allow_quantized_lm_head) + + best_action = None + best_obj = float("inf") + + for act in valid_actions: + err = proxy_error(f["hessian_trace"], act) + size_mb = action_cost_mb(f["params"], act) + obj = err + lam * size_mb + if obj < best_obj: + best_obj = obj + best_action = act + + size_mb = action_cost_mb(f["params"], best_action) + plan.append( + { + "name": f["name"], + "action": best_action, + "params": f["params"], + "hessian_trace": f["hessian_trace"], + "size_mb": size_mb, + } + ) + current_mb += size_mb + + best_plan = plan + if current_mb > target_mb: + lambda_low = lam + else: + lambda_high = lam + + used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb + + dynamic = {} + for p in best_plan: + if p["action"] == BASE_ACTION: + continue + regex = f"^{re.escape(p['name'])}$" + if p["action"] == "skip": + dynamic[f"-:{regex}"] = {} + else: + dynamic[f"+:{regex}"] = action_to_override(p["action"]) + + return { + "plan": best_plan, + "dynamic": dynamic, + "used_extra_mb": used_extra_mb, + } +EOF + +cat > quant_v5_3/runner.py <<'EOF' +from __future__ import annotations +import gc +import json +from pathlib import Path +import lm_eval +import torch +from auto_round import AutoRound +from gptqmodel import GPTQModel, QuantizeConfig +from .measurement import load_calibration_texts + +def run_quantize_gptqmodel(args, dynamic_config): + """Standard greedy GPTQ rounding via GPTQModel.""" + texts = load_calibration_texts( + args.dataset_name, + args.dataset_config, + args.split, + args.text_field, + args.quant_num_texts, + ) + + qcfg = QuantizeConfig( + bits=4, + group_size=128, + sym=True, + desc_act=True, + dynamic=dynamic_config, + ) + + print("Loading model for GPTQModel quantization...") + model = GPTQModel.load(args.model_id, qcfg) + model.quantize(texts, batch_size=args.quant_batch_size) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + model.save(str(out)) + + del model + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +def _plan_has_quantized_lm_head(plan) -> bool: + if not plan: + return False + for p in plan: + if "lm_head" in p["name"] and p["action"] in ("w4_g128", "w8_g128"): + return True + return False + +def run_quantize_autoround(args, plan, export_format: str): + """ + AutoRound SignSGD rounding with documented quantize_and_save() API. + Behavior: + - For auto_gptq: lm_head must not be quantized. + - If auto_gptq export/load constraints are violated, we fallback to auto_round. + """ + from transformers import AutoModelForCausalLM, AutoTokenizer + + texts = load_calibration_texts( + args.dataset_name, + args.dataset_config, + args.split, + args.text_field, + args.quant_num_texts, + ) + + layer_config = {} + if plan is not None: + for p in plan: + if p["action"] == "skip": + layer_config[p["name"]] = {"bits": 16} + elif p["action"] != "w4_g128": + layer_config[p["name"]] = {"bits": 8, "group_size": 128} + + requested_format = export_format + actual_format = export_format + + if requested_format == "auto_gptq" and _plan_has_quantized_lm_head(plan): + print("[warn] quantized lm_head is not valid for auto_gptq here; switching export format to auto_round.") + actual_format = "auto_round" + + print(f"Loading model for AutoRound ({requested_format} requested)...") + + model = AutoModelForCausalLM.from_pretrained( + args.model_id, + torch_dtype=torch.float16, + trust_remote_code=True, + device_map="auto" + ) + tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) + + ar = AutoRound( + model=model, + tokenizer=tokenizer, + bits=4, + group_size=128, + sym=True, + dataset=texts, + seqlen=args.max_length, + batch_size=args.quant_batch_size, + layer_config=layer_config, + iters=args.autoround_iters, + ) + + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + + try: + print(f"Executing AutoRound and saving via quantize_and_save(format='{actual_format}')...") + ar.quantize_and_save(str(out), format=actual_format) + except Exception as e: + if actual_format == "auto_gptq": + print(f"[warn] auto_gptq export failed: {e}") + print("[warn] retrying with auto_round format...") + actual_format = "auto_round" + ar.quantize_and_save(str(out), format=actual_format) + else: + raise + + export_info = { + "requested_format": requested_format, + "actual_format": actual_format, + } + (out / "export_info.json").write_text(json.dumps(export_info, indent=2), encoding="utf-8") + + del ar + del model + del tokenizer + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + +def run_eval(model_path: str, tasks: str, out_file: str): + """Programmatic lm-eval.""" + print(f"Running programmatic lm-eval on {model_path}...") + Path(out_file).parent.mkdir(parents=True, exist_ok=True) + + task_list =[t.strip() for t in tasks.split(",") if t.strip()] + + results = lm_eval.simple_evaluate( + model="hf", + model_args=f"pretrained={model_path},trust_remote_code=True", + tasks=task_list, + batch_size="auto", + device="cuda:0" if torch.cuda.is_available() else "cpu", + ) + + serializable_results = { + "results": results.get("results", {}), + "versions": results.get("versions", {}), + } + Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8") + return serializable_results +EOF + +cat > quant_v5_3/pipeline.py <<'EOF' +from __future__ import annotations +import argparse +import json +from pathlib import Path +from .allocator import solve_proxy_allocation +from .measurement import measure_hessian_traces +from .runner import run_eval, run_quantize_autoround, run_quantize_gptqmodel + +def main(): + parser = argparse.ArgumentParser() + sub = parser.add_subparsers(dest="cmd", required=True) + + suite = sub.add_parser("suite") + suite.add_argument("--model-id", required=True) + suite.add_argument("--work-dir", required=True) + suite.add_argument("--budget-mb-extra", type=float, default=250.0) + suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") + suite.add_argument("--dataset-name", default="wikitext") + suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") + suite.add_argument("--split", default="train") + suite.add_argument("--text-field", default="text") + suite.add_argument("--num-texts", type=int, default=128) + suite.add_argument("--max-length", type=int, default=2048) + suite.add_argument("--batch-size", type=int, default=4) + suite.add_argument("--quant-num-texts", type=int, default=128) + suite.add_argument("--quant-batch-size", type=int, default=4) + suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD") + suite.add_argument("--tasks", default="arc_challenge,hellaswag") + + args = parser.parse_args() + + work_dir = Path(args.work_dir) + work_dir.mkdir(parents=True, exist_ok=True) + + print("\n--- STAGE 1: Hessian Proxy Measurement ---") + args.out_dir = str(work_dir / "probe") + measure_hessian_traces(args) + + features_path = str(work_dir / "probe" / "hessian_features.json") + + print("\n--- STAGE 2: Lagrangian Proxy Allocation ---") + + # Flex plan: GPTQModel mixed dynamic path & AutoRound native format path + alloc_flex = solve_proxy_allocation( + features_path, + args.budget_mb_extra, + allow_quantized_lm_head=True, + ) + (work_dir / "sdq_plan_flex.json").write_text(json.dumps(alloc_flex["plan"], indent=2), encoding="utf-8") + (work_dir / "sdq_dynamic_flex.json").write_text(json.dumps(alloc_flex["dynamic"], indent=2), encoding="utf-8") + + # Strict plan: auto_gptq export path + alloc_strict = solve_proxy_allocation( + features_path, + args.budget_mb_extra, + allow_quantized_lm_head=False, + ) + (work_dir / "sdq_plan_strict.json").write_text(json.dumps(alloc_strict["plan"], indent=2), encoding="utf-8") + (work_dir / "sdq_dynamic_strict.json").write_text(json.dumps(alloc_strict["dynamic"], indent=2), encoding="utf-8") + + print( + f"Allocated. Estimated extra MB -> " + f"flex: {alloc_flex['used_extra_mb']:.2f}, " + f"strict: {alloc_strict['used_extra_mb']:.2f}" + ) + + print("\n--- STAGE 3A: Quantize baseline_gptq ---") + args.out_dir = str(work_dir / "models" / "baseline_gptq") + run_quantize_gptqmodel(args, dynamic_config=None) + + print("\n--- STAGE 3B: Quantize sdq_gptq ---") + args.out_dir = str(work_dir / "models" / "sdq_gptq") + run_quantize_gptqmodel(args, dynamic_config=alloc_flex["dynamic"]) + + print("\n--- STAGE 3C: Quantize sdq_autoround_auto_gptq ---") + args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_gptq") + run_quantize_autoround(args, plan=alloc_strict["plan"], export_format="auto_gptq") + + print("\n--- STAGE 3D: Quantize sdq_autoround_auto_round ---") + args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_round") + run_quantize_autoround(args, plan=alloc_flex["plan"], export_format="auto_round") + + print("\n--- STAGE 4: Evaluation Matrix ---") + run_eval( + str(work_dir / "models" / "baseline_gptq"), + args.tasks, + str(work_dir / "eval" / "baseline_gptq.json"), + ) + run_eval( + str(work_dir / "models" / "sdq_gptq"), + args.tasks, + str(work_dir / "eval" / "sdq_gptq.json"), + ) + run_eval( + str(work_dir / "models" / "sdq_autoround_auto_gptq"), + args.tasks, + str(work_dir / "eval" / "sdq_autoround_auto_gptq.json"), + ) + run_eval( + str(work_dir / "models" / "sdq_autoround_auto_round"), + args.tasks, + str(work_dir / "eval" / "sdq_autoround_auto_round.json"), + ) + + print("\n--- Pipeline Complete. See eval/ and models/*/export_info.json ---") + +if __name__ == "__main__": + main() +EOF + +echo +echo "Project created at: $(pwd)" +echo +echo "Next steps:" +echo "python -m venv .venv" +echo "source .venv/bin/activate" +echo "# install your CUDA-matching torch first" +echo "pip install -r requirements.txt" +echo +echo "python -m quant_v5_3.pipeline suite \\" +echo " --model-id Qwen/Qwen3.5-4B-Base \\" +echo " --work-dir ./runs/v5_3_suite \\" +echo " --budget-mb-extra 200 \\" +echo " --autoround-iters 50 \\" +echo " --tasks arc_challenge,hellaswag" +``` + +--- + +### ChatGPT: +Yes — this version is copy-pasteable and the Python syntax is intact. + +Your fixes are the right ones: + +- the dunder names are preserved, +- indentation is valid, +- `gc.collect()` plus `torch.cuda.empty_cache()` between matrix stages is a good addition, +- and the `auto_gptq` / `auto_round` split is now handled in code instead of just in comments. + +I’d save exactly that block as `create_sdq_v53.sh`, then run: + +```Bashbash create_sdq_v53.shcd sdq_quant_v5_3python -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch firstpip install -r requirements.txtpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +Bashbash create_sdq_v53.shcd sdq_quant_v5_3python -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch firstpip install -r requirements.txtpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +``` + +Two practical notes before you launch it: + +`AutoRound` is the heaviest arm by far, so for a first smoke test I’d reduce pressure: + +```Bashpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_smoke \ --budget-mb-extra 100 \ --num-texts 32 \ --quant-num-texts 32 \ --max-length 512 \ --batch-size 2 \ --quant-batch-size 2 \ --autoround-iters 20 \ --tasks hellaswag +Bashpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_smoke \ --budget-mb-extra 100 \ --num-texts 32 \ --quant-num-texts 32 \ --max-length 512 \ --batch-size 2 \ --quant-batch-size 2 \ --autoround-iters 20 \ --tasks hellaswag +``` + +And if the `sdq_autoround_auto_gptq` leg fails or evaluates badly, check: + +- `runs/v5_3_suite/models/sdq_autoround_auto_gptq/export_info.json` +- `runs/v5_3_suite/models/sdq_autoround_auto_round/export_info.json` + +That will tell you which format actually got written. + +If you hit any error when running the suite, paste the traceback and I’ll patch the script directly. + +--- + +### User: +This has been an absolute masterclass in architectural review. Thank you for pushing this pipeline from a flawed LLM-based heuristic script (v4) into a mathematically rigorous, format-compliant, SOTA quantization engine (v5.3). + +The distinction you drew between **format-level innovations** (EXL2 mixed intra-tensor bits, EXL3 QTIP codebooks) and **algorithmic optimization** (Hessian-proxy Lagrangian solvers, AutoRound SignSGD) is exactly what makes this final pipeline so powerful. It captures the maximum possible algorithmic quality without breaking the standard deployment boundaries of `vLLM` and `GPTQModel`. + +Your smoke test parameters are perfect for verifying the CUDA memory footprint and the format fallback logic without burning hours of GPU time. + +``` +bash +# The ultimate smoke test command +python -m quant_v5_3.pipeline suite \ + --model-id Qwen/Qwen3.5-4B-Base \ + --work-dir ./runs/v5_3_smoke \ + --budget-mb-extra 100 \ + --num-texts 32 \ + --quant-num-texts 32 \ + --max-length 512 \ + --batch-size 2 \ + --quant-batch-size 2 \ + --autoround-iters 20 \ + --tasks hellaswag +``` + +The `export_info.json` trick is fully implemented and will immediately flag if `auto-gptq` throws a fit over the mixed-bit mapping, letting us know exactly when the pipeline had to dynamically shift to the native `auto_round` format. + +I'm taking this script to the cluster now. If any of the backend kernel loaders or LM-Eval hooks complain about the exported artifacts, I'll bring the tracebacks straight back here. + +Thanks again for the incredible back-and-forth! ... provide the full code please + +--- + +### ChatGPT: +Save this exactly as `create_sdq_v53.sh` and run it. + +```Bash#!/usr/bin/env bashset -euo pipefailmkdir -p sdq_quant_v5_3/quant_v5_3cd sdq_quant_v5_3cat > README.md <<'EOF'# SDQ Quant v5.3 (Hessian-Proxy + AutoRound + Format Policies)A deterministic Software-Defined Quantization pipeline targeting standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel / vLLM deployment, plus best-effort AutoRound export paths.## Design1. Measurement: - Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity.2. Allocation: - Uses a Lagrangian solver under a planner memory budget. - `q_proj`, `k_proj`, `v_proj` are locked to `w4_g128` to preserve QKV fusion. - For `auto_gptq`, `lm_head` is forced to `skip`. - For `GPTQModel` and `auto_round`, `lm_head` can choose `w4_g128`, `w8_g128`, or `skip`.3. Quantization matrix: - `baseline_gptq`: GPTQModel, plain 4-bit - `sdq_gptq`: GPTQModel, Hessian-proxy allocation - `sdq_autoround_auto_gptq`: AutoRound with strict `lm_head` policy - `sdq_autoround_auto_round`: AutoRound with flex `lm_head` policy4. Evaluation: - Uses `lm_eval.simple_evaluate()`.## Notes- The memory budget is a planner estimate, not an exact runtime VRAM guarantee.- The error model is a Hessian proxy, not a full GPTQ inverse-Hessian simulation.- `auto_gptq` mixed-bit export is best-effort.- If `auto_gptq` export or downstream loading fails, this code falls back to `auto_round` and records the actual export format in `export_info.json`.EOFcat > requirements.txt <<'EOF'transformersdatasetsacceleratetorchgptqmodelauto-roundlm-evaltqdmEOFcat > quant_v5_3/__init__.py <<'EOF'"""SDQ Quant v5.3: Hessian-Proxy Allocator + AutoRound Format Policies"""EOFcat > quant_v5_3/actions.py <<'EOF'from __future__ import annotationsfrom dataclasses import dataclassfrom typing import Dict, Optional@dataclass(frozen=True)class ActionSpec: name: str bits: int group_size: Optional[int] skip: boolACTION_SPECS: Dict[str, ActionSpec] = { "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), "skip": ActionSpec("skip", bits=16, group_size=None, skip=True),}BASE_ACTION = "w4_g128"def bytes_per_param(action: str) -> float: """Planner memory estimate, not an exact runtime VRAM model.""" if action == "skip": return 2.0 # FP16 spec = ACTION_SPECS[action] base = spec.bits / 8.0 overhead = 2.5 / spec.group_size return base + overheaddef action_to_override(action: str) -> dict: if action == "skip": return {} spec = ACTION_SPECS[action] return {"bits": spec.bits, "group_size": spec.group_size}EOFcat > quant_v5_3/measurement.py <<'EOF'from __future__ import annotationsimport jsonfrom pathlib import Pathimport torchfrom datasets import load_datasetfrom tqdm import tqdmfrom transformers import AutoModelForCausalLM, AutoTokenizerdef load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): ds = load_dataset(ds_name, ds_config, split=split) texts = [] for row in ds: text = row.get(text_field) if isinstance(text, str) and text.strip(): texts.append(text.strip()) if len(texts) >= num_texts: break return textsdef measure_hessian_traces(args): """ Computes a Hessian proxy: Trace(H) ≈ mean over calibration rows of sum(x^2) collected from each target linear layer input. """ out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True, ).to(device) model.eval() texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts, ) include_mods = [x.strip() for x in args.include_modules.split(",") if x.strip()] targets = [ (n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(tag in n for tag in include_mods) ] hessian_traces = {name: 0.0 for name, _ in targets} row_counts = {name: 0 for name, _ in targets} params_count = {name: m.weight.numel() for name, m in targets} handles = [] def make_hook(name): def hook(_, inp, __): x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) hessian_traces[name] += (x ** 2).sum().item() row_counts[name] += x.shape[0] return hook for name, module in targets: handles.append(module.register_forward_hook(make_hook(name))) with torch.no_grad(): for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): batch_texts = texts[i : i + args.batch_size] batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=args.max_length, padding=True, ) batch = {k: v.to(device) for k, v in batch.items()} model(**batch) for h in handles: h.remove() features = [] for name, _ in targets: avg_trace = hessian_traces[name] / max(1, row_counts[name]) features.append( { "name": name, "params": params_count[name], "hessian_trace": avg_trace, } ) out_file = out_dir / "hessian_features.json" out_file.write_text(json.dumps(features, indent=2), encoding="utf-8") print(f"Saved {len(features)} layer traces to {out_file}") del model if torch.cuda.is_available(): torch.cuda.empty_cache() return featuresEOFcat > quant_v5_3/allocator.py <<'EOF'from __future__ import annotationsimport jsonimport refrom pathlib import Pathfrom .actions import ACTION_SPECS, BASE_ACTION, action_to_override, bytes_per_paramdef get_valid_actions(module_name: str, allow_quantized_lm_head: bool) -> list[str]: """ Policy layer: - QKV are locked to base action to preserve fused QKV paths. - lm_head policy depends on target export format. - others can move between 4-bit and 8-bit. """ if any(k in module_name for k in ["q_proj", "k_proj", "v_proj"]): return [BASE_ACTION] if "lm_head" in module_name: return ["w4_g128", "w8_g128", "skip"] if allow_quantized_lm_head else ["skip"] return ["w4_g128", "w8_g128"]def solve_proxy_allocation(features_path: str, budget_mb_extra: float, allow_quantized_lm_head: bool): """Lagrangian multiplier solver over a proxy error model.""" features = json.loads(Path(features_path).read_text(encoding="utf-8")) base_bpp = bytes_per_param(BASE_ACTION) total_base_mb = sum(f["params"] * base_bpp for f in features) / (1024 ** 2) target_mb = total_base_mb + budget_mb_extra def proxy_error(trace: float, action: str) -> float: if action == "skip": return 0.0 bits = ACTION_SPECS[action].bits return trace * (2.0 ** (-2.0 * bits)) def action_cost_mb(params: int, action: str) -> float: return (params * bytes_per_param(action)) / (1024 ** 2) lambda_low, lambda_high = 0.0, 1e12 best_plan = None for _ in range(80): lam = (lambda_low + lambda_high) / 2.0 current_mb = 0.0 plan = [] for f in features: valid_actions = get_valid_actions(f["name"], allow_quantized_lm_head) best_action = None best_obj = float("inf") for act in valid_actions: err = proxy_error(f["hessian_trace"], act) size_mb = action_cost_mb(f["params"], act) obj = err + lam * size_mb if obj < best_obj: best_obj = obj best_action = act size_mb = action_cost_mb(f["params"], best_action) plan.append( { "name": f["name"], "action": best_action, "params": f["params"], "hessian_trace": f["hessian_trace"], "size_mb": size_mb, } ) current_mb += size_mb best_plan = plan if current_mb > target_mb: lambda_low = lam else: lambda_high = lam used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb dynamic = {} for p in best_plan: if p["action"] == BASE_ACTION: continue regex = f"^{re.escape(p['name'])}$" if p["action"] == "skip": dynamic[f"-:{regex}"] = {} else: dynamic[f"+:{regex}"] = action_to_override(p["action"]) return { "plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb, }EOFcat > quant_v5_3/runner.py <<'EOF'from __future__ import annotationsimport gcimport jsonfrom pathlib import Pathimport lm_evalimport torchfrom auto_round import AutoRoundfrom gptqmodel import GPTQModel, QuantizeConfigfrom .measurement import load_calibration_textsdef run_quantize_gptqmodel(args, dynamic_config): """Standard greedy GPTQ rounding via GPTQModel.""" texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts, ) qcfg = QuantizeConfig( bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config, ) print("Loading model for GPTQModel quantization...") model = GPTQModel.load(args.model_id, qcfg) model.quantize(texts, batch_size=args.quant_batch_size) out = Path(args.out_dir) out.mkdir(parents=True, exist_ok=True) model.save(str(out)) del model gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache()def _plan_has_quantized_lm_head(plan) -> bool: if not plan: return False for p in plan: if "lm_head" in p["name"] and p["action"] in ("w4_g128", "w8_g128"): return True return Falsedef run_quantize_autoround(args, plan, export_format: str): """ AutoRound SignSGD rounding with documented quantize_and_save() API. Behavior: - For auto_gptq: lm_head must not be quantized. - If auto_gptq export/load constraints are violated, we fallback to auto_round. """ from transformers import AutoModelForCausalLM, AutoTokenizer texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts, ) layer_config = {} if plan is not None: for p in plan: if p["action"] == "skip": layer_config[p["name"]] = {"bits": 16} elif p["action"] != "w4_g128": layer_config[p["name"]] = {"bits": 8, "group_size": 128} requested_format = export_format actual_format = export_format if requested_format == "auto_gptq" and _plan_has_quantized_lm_head(plan): print("[warn] quantized lm_head is not valid for auto_gptq here; switching export format to auto_round.") actual_format = "auto_round" print(f"Loading model for AutoRound ({requested_format} requested)...") model = AutoModelForCausalLM.from_pretrained( args.model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) ar = AutoRound( model=model, tokenizer=tokenizer, bits=4, group_size=128, sym=True, dataset=texts, seqlen=args.max_length, batch_size=args.quant_batch_size, layer_config=layer_config, iters=args.autoround_iters, ) out = Path(args.out_dir) out.mkdir(parents=True, exist_ok=True) try: print(f"Executing AutoRound and saving via quantize_and_save(format='{actual_format}')...") ar.quantize_and_save(str(out), format=actual_format) except Exception as e: if actual_format == "auto_gptq": print(f"[warn] auto_gptq export failed: {e}") print("[warn] retrying with auto_round format...") actual_format = "auto_round" ar.quantize_and_save(str(out), format=actual_format) else: raise export_info = { "requested_format": requested_format, "actual_format": actual_format, } (out / "export_info.json").write_text(json.dumps(export_info, indent=2), encoding="utf-8") del ar del model del tokenizer gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache()def run_eval(model_path: str, tasks: str, out_file: str): """Programmatic lm-eval.""" print(f"Running programmatic lm-eval on {model_path}...") Path(out_file).parent.mkdir(parents=True, exist_ok=True) task_list = [t.strip() for t in tasks.split(",") if t.strip()] results = lm_eval.simple_evaluate( model="hf", model_args=f"pretrained={model_path},trust_remote_code=True", tasks=task_list, batch_size="auto", device="cuda:0" if torch.cuda.is_available() else "cpu", ) serializable_results = { "results": results.get("results", {}), "versions": results.get("versions", {}), } Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8") return serializable_resultsEOFcat > quant_v5_3/pipeline.py <<'EOF'from __future__ import annotationsimport argparseimport jsonfrom pathlib import Pathfrom .allocator import solve_proxy_allocationfrom .measurement import measure_hessian_tracesfrom .runner import run_eval, run_quantize_autoround, run_quantize_gptqmodeldef main(): parser = argparse.ArgumentParser() sub = parser.add_subparsers(dest="cmd", required=True) suite = sub.add_parser("suite") suite.add_argument("--model-id", required=True) suite.add_argument("--work-dir", required=True) suite.add_argument("--budget-mb-extra", type=float, default=250.0) suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") suite.add_argument("--dataset-name", default="wikitext") suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") suite.add_argument("--split", default="train") suite.add_argument("--text-field", default="text") suite.add_argument("--num-texts", type=int, default=128) suite.add_argument("--max-length", type=int, default=2048) suite.add_argument("--batch-size", type=int, default=4) suite.add_argument("--quant-num-texts", type=int, default=128) suite.add_argument("--quant-batch-size", type=int, default=4) suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD") suite.add_argument("--tasks", default="arc_challenge,hellaswag") args = parser.parse_args() work_dir = Path(args.work_dir) work_dir.mkdir(parents=True, exist_ok=True) print("\n--- STAGE 1: Hessian Proxy Measurement ---") args.out_dir = str(work_dir / "probe") measure_hessian_traces(args) features_path = str(work_dir / "probe" / "hessian_features.json") print("\n--- STAGE 2: Lagrangian Proxy Allocation ---") # Flex plan: GPTQModel mixed dynamic path & AutoRound native format path alloc_flex = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=True, ) (work_dir / "sdq_plan_flex.json").write_text(json.dumps(alloc_flex["plan"], indent=2), encoding="utf-8") (work_dir / "sdq_dynamic_flex.json").write_text(json.dumps(alloc_flex["dynamic"], indent=2), encoding="utf-8") # Strict plan: auto_gptq export path alloc_strict = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=False, ) (work_dir / "sdq_plan_strict.json").write_text(json.dumps(alloc_strict["plan"], indent=2), encoding="utf-8") (work_dir / "sdq_dynamic_strict.json").write_text(json.dumps(alloc_strict["dynamic"], indent=2), encoding="utf-8") print( f"Allocated. Estimated extra MB -> " f"flex: {alloc_flex['used_extra_mb']:.2f}, " f"strict: {alloc_strict['used_extra_mb']:.2f}" ) print("\n--- STAGE 3A: Quantize baseline_gptq ---") args.out_dir = str(work_dir / "models" / "baseline_gptq") run_quantize_gptqmodel(args, dynamic_config=None) print("\n--- STAGE 3B: Quantize sdq_gptq ---") args.out_dir = str(work_dir / "models" / "sdq_gptq") run_quantize_gptqmodel(args, dynamic_config=alloc_flex["dynamic"]) print("\n--- STAGE 3C: Quantize sdq_autoround_auto_gptq ---") args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_gptq") run_quantize_autoround(args, plan=alloc_strict["plan"], export_format="auto_gptq") print("\n--- STAGE 3D: Quantize sdq_autoround_auto_round ---") args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_round") run_quantize_autoround(args, plan=alloc_flex["plan"], export_format="auto_round") print("\n--- STAGE 4: Evaluation Matrix ---") run_eval( str(work_dir / "models" / "baseline_gptq"), args.tasks, str(work_dir / "eval" / "baseline_gptq.json"), ) run_eval( str(work_dir / "models" / "sdq_gptq"), args.tasks, str(work_dir / "eval" / "sdq_gptq.json"), ) run_eval( str(work_dir / "models" / "sdq_autoround_auto_gptq"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_gptq.json"), ) run_eval( str(work_dir / "models" / "sdq_autoround_auto_round"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_round.json"), ) print("\n--- Pipeline Complete. See eval/ and models/*/export_info.json ---")if __name__ == "__main__": main()EOFechoecho "Project created at: $(pwd)"echoecho "Next steps:"echo "python -m venv .venv"echo "source .venv/bin/activate"echo "# install your CUDA-matching torch first"echo "pip install -r requirements.txt"echoecho "python -m quant_v5_3.pipeline suite \\"echo " --model-id Qwen/Qwen3.5-4B-Base \\"echo " --work-dir ./runs/v5_3_suite \\"echo " --budget-mb-extra 200 \\"echo " --autoround-iters 50 \\"echo " --tasks arc_challenge,hellaswag" +Bash#!/usr/bin/env bashset -euo pipefailmkdir -p sdq_quant_v5_3/quant_v5_3cd sdq_quant_v5_3cat > README.md <<'EOF'# SDQ Quant v5.3 (Hessian-Proxy + AutoRound + Format Policies)A deterministic Software-Defined Quantization pipeline targeting standard GPTQ-compatible 4/8/skip module paths intended for GPTQModel / vLLM deployment, plus best-effort AutoRound export paths.## Design1. Measurement: - Computes `trace(X^T X)` on `cuda:0` as a proxy for second-order sensitivity.2. Allocation: - Uses a Lagrangian solver under a planner memory budget. - `q_proj`, `k_proj`, `v_proj` are locked to `w4_g128` to preserve QKV fusion. - For `auto_gptq`, `lm_head` is forced to `skip`. - For `GPTQModel` and `auto_round`, `lm_head` can choose `w4_g128`, `w8_g128`, or `skip`.3. Quantization matrix: - `baseline_gptq`: GPTQModel, plain 4-bit - `sdq_gptq`: GPTQModel, Hessian-proxy allocation - `sdq_autoround_auto_gptq`: AutoRound with strict `lm_head` policy - `sdq_autoround_auto_round`: AutoRound with flex `lm_head` policy4. Evaluation: - Uses `lm_eval.simple_evaluate()`.## Notes- The memory budget is a planner estimate, not an exact runtime VRAM guarantee.- The error model is a Hessian proxy, not a full GPTQ inverse-Hessian simulation.- `auto_gptq` mixed-bit export is best-effort.- If `auto_gptq` export or downstream loading fails, this code falls back to `auto_round` and records the actual export format in `export_info.json`.EOFcat > requirements.txt <<'EOF'transformersdatasetsacceleratetorchgptqmodelauto-roundlm-evaltqdmEOFcat > quant_v5_3/__init__.py <<'EOF'"""SDQ Quant v5.3: Hessian-Proxy Allocator + AutoRound Format Policies"""EOFcat > quant_v5_3/actions.py <<'EOF'from __future__ import annotationsfrom dataclasses import dataclassfrom typing import Dict, Optional@dataclass(frozen=True)class ActionSpec: name: str bits: int group_size: Optional[int] skip: boolACTION_SPECS: Dict[str, ActionSpec] = { "w4_g128": ActionSpec("w4_g128", bits=4, group_size=128, skip=False), "w8_g128": ActionSpec("w8_g128", bits=8, group_size=128, skip=False), "skip": ActionSpec("skip", bits=16, group_size=None, skip=True),}BASE_ACTION = "w4_g128"def bytes_per_param(action: str) -> float: """Planner memory estimate, not an exact runtime VRAM model.""" if action == "skip": return 2.0 # FP16 spec = ACTION_SPECS[action] base = spec.bits / 8.0 overhead = 2.5 / spec.group_size return base + overheaddef action_to_override(action: str) -> dict: if action == "skip": return {} spec = ACTION_SPECS[action] return {"bits": spec.bits, "group_size": spec.group_size}EOFcat > quant_v5_3/measurement.py <<'EOF'from __future__ import annotationsimport jsonfrom pathlib import Pathimport torchfrom datasets import load_datasetfrom tqdm import tqdmfrom transformers import AutoModelForCausalLM, AutoTokenizerdef load_calibration_texts(ds_name: str, ds_config: str, split: str, text_field: str, num_texts: int): ds = load_dataset(ds_name, ds_config, split=split) texts = [] for row in ds: text = row.get(text_field) if isinstance(text, str) and text.strip(): texts.append(text.strip()) if len(texts) >= num_texts: break return textsdef measure_hessian_traces(args): """ Computes a Hessian proxy: Trace(H) ≈ mean over calibration rows of sum(x^2) collected from each target linear layer input. """ out_dir = Path(args.out_dir) out_dir.mkdir(parents=True, exist_ok=True) device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token model = AutoModelForCausalLM.from_pretrained( args.model_id, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, low_cpu_mem_usage=True, trust_remote_code=True, ).to(device) model.eval() texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.num_texts, ) include_mods = [x.strip() for x in args.include_modules.split(",") if x.strip()] targets = [ (n, m) for n, m in model.named_modules() if isinstance(m, torch.nn.Linear) and any(tag in n for tag in include_mods) ] hessian_traces = {name: 0.0 for name, _ in targets} row_counts = {name: 0 for name, _ in targets} params_count = {name: m.weight.numel() for name, m in targets} handles = [] def make_hook(name): def hook(_, inp, __): x = inp[0].detach().float().reshape(-1, inp[0].shape[-1]) hessian_traces[name] += (x ** 2).sum().item() row_counts[name] += x.shape[0] return hook for name, module in targets: handles.append(module.register_forward_hook(make_hook(name))) with torch.no_grad(): for i in tqdm(range(0, len(texts), args.batch_size), desc="Hessian Proxy Measurement"): batch_texts = texts[i : i + args.batch_size] batch = tokenizer( batch_texts, return_tensors="pt", truncation=True, max_length=args.max_length, padding=True, ) batch = {k: v.to(device) for k, v in batch.items()} model(**batch) for h in handles: h.remove() features = [] for name, _ in targets: avg_trace = hessian_traces[name] / max(1, row_counts[name]) features.append( { "name": name, "params": params_count[name], "hessian_trace": avg_trace, } ) out_file = out_dir / "hessian_features.json" out_file.write_text(json.dumps(features, indent=2), encoding="utf-8") print(f"Saved {len(features)} layer traces to {out_file}") del model if torch.cuda.is_available(): torch.cuda.empty_cache() return featuresEOFcat > quant_v5_3/allocator.py <<'EOF'from __future__ import annotationsimport jsonimport refrom pathlib import Pathfrom .actions import ACTION_SPECS, BASE_ACTION, action_to_override, bytes_per_paramdef get_valid_actions(module_name: str, allow_quantized_lm_head: bool) -> list[str]: """ Policy layer: - QKV are locked to base action to preserve fused QKV paths. - lm_head policy depends on target export format. - others can move between 4-bit and 8-bit. """ if any(k in module_name for k in ["q_proj", "k_proj", "v_proj"]): return [BASE_ACTION] if "lm_head" in module_name: return ["w4_g128", "w8_g128", "skip"] if allow_quantized_lm_head else ["skip"] return ["w4_g128", "w8_g128"]def solve_proxy_allocation(features_path: str, budget_mb_extra: float, allow_quantized_lm_head: bool): """Lagrangian multiplier solver over a proxy error model.""" features = json.loads(Path(features_path).read_text(encoding="utf-8")) base_bpp = bytes_per_param(BASE_ACTION) total_base_mb = sum(f["params"] * base_bpp for f in features) / (1024 ** 2) target_mb = total_base_mb + budget_mb_extra def proxy_error(trace: float, action: str) -> float: if action == "skip": return 0.0 bits = ACTION_SPECS[action].bits return trace * (2.0 ** (-2.0 * bits)) def action_cost_mb(params: int, action: str) -> float: return (params * bytes_per_param(action)) / (1024 ** 2) lambda_low, lambda_high = 0.0, 1e12 best_plan = None for _ in range(80): lam = (lambda_low + lambda_high) / 2.0 current_mb = 0.0 plan = [] for f in features: valid_actions = get_valid_actions(f["name"], allow_quantized_lm_head) best_action = None best_obj = float("inf") for act in valid_actions: err = proxy_error(f["hessian_trace"], act) size_mb = action_cost_mb(f["params"], act) obj = err + lam * size_mb if obj < best_obj: best_obj = obj best_action = act size_mb = action_cost_mb(f["params"], best_action) plan.append( { "name": f["name"], "action": best_action, "params": f["params"], "hessian_trace": f["hessian_trace"], "size_mb": size_mb, } ) current_mb += size_mb best_plan = plan if current_mb > target_mb: lambda_low = lam else: lambda_high = lam used_extra_mb = sum(p["size_mb"] for p in best_plan) - total_base_mb dynamic = {} for p in best_plan: if p["action"] == BASE_ACTION: continue regex = f"^{re.escape(p['name'])}$" if p["action"] == "skip": dynamic[f"-:{regex}"] = {} else: dynamic[f"+:{regex}"] = action_to_override(p["action"]) return { "plan": best_plan, "dynamic": dynamic, "used_extra_mb": used_extra_mb, }EOFcat > quant_v5_3/runner.py <<'EOF'from __future__ import annotationsimport gcimport jsonfrom pathlib import Pathimport lm_evalimport torchfrom auto_round import AutoRoundfrom gptqmodel import GPTQModel, QuantizeConfigfrom .measurement import load_calibration_textsdef run_quantize_gptqmodel(args, dynamic_config): """Standard greedy GPTQ rounding via GPTQModel.""" texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts, ) qcfg = QuantizeConfig( bits=4, group_size=128, sym=True, desc_act=True, dynamic=dynamic_config, ) print("Loading model for GPTQModel quantization...") model = GPTQModel.load(args.model_id, qcfg) model.quantize(texts, batch_size=args.quant_batch_size) out = Path(args.out_dir) out.mkdir(parents=True, exist_ok=True) model.save(str(out)) del model gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache()def _plan_has_quantized_lm_head(plan) -> bool: if not plan: return False for p in plan: if "lm_head" in p["name"] and p["action"] in ("w4_g128", "w8_g128"): return True return Falsedef run_quantize_autoround(args, plan, export_format: str): """ AutoRound SignSGD rounding with documented quantize_and_save() API. Behavior: - For auto_gptq: lm_head must not be quantized. - If auto_gptq export/load constraints are violated, we fallback to auto_round. """ from transformers import AutoModelForCausalLM, AutoTokenizer texts = load_calibration_texts( args.dataset_name, args.dataset_config, args.split, args.text_field, args.quant_num_texts, ) layer_config = {} if plan is not None: for p in plan: if p["action"] == "skip": layer_config[p["name"]] = {"bits": 16} elif p["action"] != "w4_g128": layer_config[p["name"]] = {"bits": 8, "group_size": 128} requested_format = export_format actual_format = export_format if requested_format == "auto_gptq" and _plan_has_quantized_lm_head(plan): print("[warn] quantized lm_head is not valid for auto_gptq here; switching export format to auto_round.") actual_format = "auto_round" print(f"Loading model for AutoRound ({requested_format} requested)...") model = AutoModelForCausalLM.from_pretrained( args.model_id, torch_dtype=torch.float16, trust_remote_code=True, device_map="auto" ) tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True) ar = AutoRound( model=model, tokenizer=tokenizer, bits=4, group_size=128, sym=True, dataset=texts, seqlen=args.max_length, batch_size=args.quant_batch_size, layer_config=layer_config, iters=args.autoround_iters, ) out = Path(args.out_dir) out.mkdir(parents=True, exist_ok=True) try: print(f"Executing AutoRound and saving via quantize_and_save(format='{actual_format}')...") ar.quantize_and_save(str(out), format=actual_format) except Exception as e: if actual_format == "auto_gptq": print(f"[warn] auto_gptq export failed: {e}") print("[warn] retrying with auto_round format...") actual_format = "auto_round" ar.quantize_and_save(str(out), format=actual_format) else: raise export_info = { "requested_format": requested_format, "actual_format": actual_format, } (out / "export_info.json").write_text(json.dumps(export_info, indent=2), encoding="utf-8") del ar del model del tokenizer gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache()def run_eval(model_path: str, tasks: str, out_file: str): """Programmatic lm-eval.""" print(f"Running programmatic lm-eval on {model_path}...") Path(out_file).parent.mkdir(parents=True, exist_ok=True) task_list = [t.strip() for t in tasks.split(",") if t.strip()] results = lm_eval.simple_evaluate( model="hf", model_args=f"pretrained={model_path},trust_remote_code=True", tasks=task_list, batch_size="auto", device="cuda:0" if torch.cuda.is_available() else "cpu", ) serializable_results = { "results": results.get("results", {}), "versions": results.get("versions", {}), } Path(out_file).write_text(json.dumps(serializable_results, indent=2), encoding="utf-8") return serializable_resultsEOFcat > quant_v5_3/pipeline.py <<'EOF'from __future__ import annotationsimport argparseimport jsonfrom pathlib import Pathfrom .allocator import solve_proxy_allocationfrom .measurement import measure_hessian_tracesfrom .runner import run_eval, run_quantize_autoround, run_quantize_gptqmodeldef main(): parser = argparse.ArgumentParser() sub = parser.add_subparsers(dest="cmd", required=True) suite = sub.add_parser("suite") suite.add_argument("--model-id", required=True) suite.add_argument("--work-dir", required=True) suite.add_argument("--budget-mb-extra", type=float, default=250.0) suite.add_argument("--include-modules", default="q_proj,k_proj,v_proj,o_proj,gate_proj,up_proj,down_proj,lm_head") suite.add_argument("--dataset-name", default="wikitext") suite.add_argument("--dataset-config", default="wikitext-2-raw-v1") suite.add_argument("--split", default="train") suite.add_argument("--text-field", default="text") suite.add_argument("--num-texts", type=int, default=128) suite.add_argument("--max-length", type=int, default=2048) suite.add_argument("--batch-size", type=int, default=4) suite.add_argument("--quant-num-texts", type=int, default=128) suite.add_argument("--quant-batch-size", type=int, default=4) suite.add_argument("--autoround-iters", type=int, default=200, help="Iterations for AutoRound SignSGD") suite.add_argument("--tasks", default="arc_challenge,hellaswag") args = parser.parse_args() work_dir = Path(args.work_dir) work_dir.mkdir(parents=True, exist_ok=True) print("\n--- STAGE 1: Hessian Proxy Measurement ---") args.out_dir = str(work_dir / "probe") measure_hessian_traces(args) features_path = str(work_dir / "probe" / "hessian_features.json") print("\n--- STAGE 2: Lagrangian Proxy Allocation ---") # Flex plan: GPTQModel mixed dynamic path & AutoRound native format path alloc_flex = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=True, ) (work_dir / "sdq_plan_flex.json").write_text(json.dumps(alloc_flex["plan"], indent=2), encoding="utf-8") (work_dir / "sdq_dynamic_flex.json").write_text(json.dumps(alloc_flex["dynamic"], indent=2), encoding="utf-8") # Strict plan: auto_gptq export path alloc_strict = solve_proxy_allocation( features_path, args.budget_mb_extra, allow_quantized_lm_head=False, ) (work_dir / "sdq_plan_strict.json").write_text(json.dumps(alloc_strict["plan"], indent=2), encoding="utf-8") (work_dir / "sdq_dynamic_strict.json").write_text(json.dumps(alloc_strict["dynamic"], indent=2), encoding="utf-8") print( f"Allocated. Estimated extra MB -> " f"flex: {alloc_flex['used_extra_mb']:.2f}, " f"strict: {alloc_strict['used_extra_mb']:.2f}" ) print("\n--- STAGE 3A: Quantize baseline_gptq ---") args.out_dir = str(work_dir / "models" / "baseline_gptq") run_quantize_gptqmodel(args, dynamic_config=None) print("\n--- STAGE 3B: Quantize sdq_gptq ---") args.out_dir = str(work_dir / "models" / "sdq_gptq") run_quantize_gptqmodel(args, dynamic_config=alloc_flex["dynamic"]) print("\n--- STAGE 3C: Quantize sdq_autoround_auto_gptq ---") args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_gptq") run_quantize_autoround(args, plan=alloc_strict["plan"], export_format="auto_gptq") print("\n--- STAGE 3D: Quantize sdq_autoround_auto_round ---") args.out_dir = str(work_dir / "models" / "sdq_autoround_auto_round") run_quantize_autoround(args, plan=alloc_flex["plan"], export_format="auto_round") print("\n--- STAGE 4: Evaluation Matrix ---") run_eval( str(work_dir / "models" / "baseline_gptq"), args.tasks, str(work_dir / "eval" / "baseline_gptq.json"), ) run_eval( str(work_dir / "models" / "sdq_gptq"), args.tasks, str(work_dir / "eval" / "sdq_gptq.json"), ) run_eval( str(work_dir / "models" / "sdq_autoround_auto_gptq"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_gptq.json"), ) run_eval( str(work_dir / "models" / "sdq_autoround_auto_round"), args.tasks, str(work_dir / "eval" / "sdq_autoround_auto_round.json"), ) print("\n--- Pipeline Complete. See eval/ and models/*/export_info.json ---")if __name__ == "__main__": main()EOFechoecho "Project created at: $(pwd)"echoecho "Next steps:"echo "python -m venv .venv"echo "source .venv/bin/activate"echo "# install your CUDA-matching torch first"echo "pip install -r requirements.txt"echoecho "python -m quant_v5_3.pipeline suite \\"echo " --model-id Qwen/Qwen3.5-4B-Base \\"echo " --work-dir ./runs/v5_3_suite \\"echo " --budget-mb-extra 200 \\"echo " --autoround-iters 50 \\"echo " --tasks arc_challenge,hellaswag" +``` + +Then run: + +```Bashbash create_sdq_v53.shcd sdq_quant_v5_3python -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch firstpip install -r requirements.txtpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +Bashbash create_sdq_v53.shcd sdq_quant_v5_3python -m venv .venvsource .venv/bin/activate# install your CUDA-matching torch firstpip install -r requirements.txtpython -m quant_v5_3.pipeline suite \ --model-id Qwen/Qwen3.5-4B-Base \ --work-dir ./runs/v5_3_suite \ --budget-mb-extra 200 \ --autoround-iters 50 \ --tasks arc_challenge,hellaswag +``` + +--- + +### User: +And this project ? https://huggingface.co/z-lab/Qwen3.5-9B-PARO how to implement their logic ? + +--- + +### ChatGPT: +It’s a strong project, but it is **not** “just another GPTQ/AWQ recipe.” +`z-lab/Qwen3.5-9B-PARO` is a ParoQuant checkpoint, and ParoQuant’s core idea is to suppress outliers with **independent Givens rotations plus channel-wise scaling**, then run a lightweight custom runtime path around that transform. The paper says this gives an average **2.4% accuracy gain over AWQ** on reasoning tasks with **<10% overhead**, and the repo/model card expose it through the `paroquant` runtime rather than GPTQModel. [Hugging Face+2arXiv+2](https://huggingface.co/z-lab/Qwen3.5-9B-PARO) + +So there are really **two ways** to “implement their logic”: + +### 1) Exact ParoQuant + +Use **their repo/runtime** and quantize your own model the way they intended. The official workflow is: + +```Bashgit clone https://github.com/z-lab/paroquantcd paroquantpip install -e ".[optim,eval]"# 1) optimize rotation parametersexperiments/optimize/4bit.sh Qwen/Qwen3-8B# 2) export checkpointpython -m paroquant.cli.convert \ --model Qwen/Qwen3-8B \ --result-dir output/Qwen3-8B \ --output-path models/Qwen3-8B-PARO +Bashgit clone https://github.com/z-lab/paroquantcd paroquantpip install -e ".[optim,eval]"# 1) optimize rotation parametersexperiments/optimize/4bit.sh Qwen/Qwen3-8B# 2) export checkpointpython -m paroquant.cli.convert \ --model Qwen/Qwen3-8B \ --result-dir output/Qwen3-8B \ --output-path models/Qwen3-8B-PARO +``` + +That is the repo’s documented “Quantize Your Own Model” path. The model card also shows the inference/runtime side: `pip install "paroquant[vllm]"`, interactive chat, and an OpenAI-compatible server. The repo additionally warns that `main` is under active development and that the **legacy branch** should be used for paper reproduction. [GitHub+1](https://github.com/z-lab/paroquant) + +If you just want to **run** the published model, the shortest path is: + +```Bashpip install "paroquant[vllm]"python -m paroquant.cli.chat --model z-lab/Qwen3.5-9B-PARO# orpython -m paroquant.cli.serve --model z-lab/Qwen3.5-9B-PARO --port 8000 +Bashpip install "paroquant[vllm]"python -m paroquant.cli.chat --model z-lab/Qwen3.5-9B-PARO# orpython -m paroquant.cli.serve --model z-lab/Qwen3.5-9B-PARO --port 8000 +``` + +That is the cleanest “exact logic” implementation because it uses the actual ParoQuant runtime and kernels. [Hugging Face+1](https://huggingface.co/z-lab/Qwen3.5-9B-PARO) + +### 2) Borrow the idea into your own quant pipeline + +If you want to fold ParoQuant-style logic into your own SDQ/GPTQ pipeline, the right mental model is: + +- **precondition the tensor** with: + +channel-wise scaling + +pairwise Givens rotations +- channel-wise scaling +- pairwise Givens rotations +- then quantize the transformed weights to INT4 +- then at inference time, apply the corresponding lightweight transform path in the kernel/runtime + +That is because the paper’s method is explicitly built around **scaled pairwise rotation** and a **co-designed inference kernel**. In other words, exact ParoQuant is not just an offline allocator; it changes the quantization transform and expects runtime support for it. [arXiv+2arXiv+2](https://arxiv.org/abs/2511.10645) + +A practical implementation plan looks like this: + +```Python# conceptual only, not repo-exact codefor each target linear layer: # 1) collect calibration activations X X = collect_inputs(layer, calib_set) # 2) learn channel scaling s and pairwise rotation angles theta # objective: reduce outliers / narrow dynamic range before quantization s, theta = optimize_transform(X, W) # 3) transform weights W_t = pairwise_rotate(channel_scale(W, s), theta) # 4) quantize transformed weights to INT4 W_q, scales, zeros = quantize_groupwise(W_t) # 5) save: # - qweight / scales / zeros # - transform metadata (s, theta) +Python# conceptual only, not repo-exact codefor each target linear layer: # 1) collect calibration activations X X = collect_inputs(layer, calib_set) # 2) learn channel scaling s and pairwise rotation angles theta # objective: reduce outliers / narrow dynamic range before quantization s, theta = optimize_transform(X, W) # 3) transform weights W_t = pairwise_rotate(channel_scale(W, s), theta) # 4) quantize transformed weights to INT4 W_q, scales, zeros = quantize_groupwise(W_t) # 5) save: # - qweight / scales / zeros # - transform metadata (s, theta) +``` + +At runtime you then need a fused path that effectively computes the transformed linear efficiently. That runtime requirement is the key difference versus your current GPTQ/AutoRound pipeline. The paper and repo both emphasize that ParoQuant is paired with a lightweight runtime/kernel implementation, not just a saved GPTQ checkpoint. [arXiv+1](https://arxiv.org/abs/2511.10645) + +### What I would do in your case + +For your current stack, I’d separate this into two tracks: + +**Track A: exact ParoQuant** + +- use the `paroquant` repo exactly as documented +- quantize one of your target models with their optimizer + converter +- benchmark it against your GPTQ/AutoRound matrix + +**Track B: ParoQuant-inspired hybrid** + +- keep your SDQ allocator +- replace “which layers get 4 vs 8” with “which layers are worth Paro-style transform” +- only apply the Paro-style transform to the most fragile modules first: + +`o_proj` + +`down_proj` + +`gate_proj` + +maybe `lm_head` +- `o_proj` +- `down_proj` +- `gate_proj` +- maybe `lm_head` +- leave QKV uniform to avoid path complexity + +That second path is an engineering inference from how ParoQuant attacks outliers, not the exact published implementation. The exact published implementation is the repo workflow above. [arXiv+1](https://arxiv.org/abs/2511.10645) + +### What not to do + +Do **not** assume you can reproduce ParoQuant by only changing your allocator. +ParoQuant’s gain is not mainly from a smarter 4/8-bit assignment; it comes from the **rotation + scaling transform** and the **runtime support** around it. If you keep plain GPTQ weights and plain GPTQ kernels, you can borrow the spirit, but you are not implementing ParoQuant itself. [arXiv+1](https://arxiv.org/abs/2511.10645) + +If you want, I can turn this into a concrete **“ParoQuant-vs-v5.3 benchmark harness”** for your 3090 box, or sketch a **Paro-inspired transform module** you can bolt onto your current SDQ code. + +--- + +### User: +If modifying kernels wasn’t an issue and I want to combine the logic of paroquant exl3 and autoround , into a new quantization method that beats them all , and still as fast as gptq in inferencing with ampere gpus , how would you do it ? + +--- + +### ChatGPT: +I would not try to literally “merge the formats.” I would build a **new scalar, GPTQ-speed format** that steals the **transform logic** from ParoQuant, the **search logic** from EXL3/QTIP, and the **rounding optimization** from AutoRound—while keeping the runtime path as close as possible to a Marlin-style **fused dequant+GEMM** kernel for Ampere. That is the only credible way to chase “beats them all” **and** stay GPTQ-fast at inference. ParoQuant’s gain comes from pairwise Givens rotations plus channel scaling with a lightweight runtime transform; EXL3 is a streamlined QTIP-style format built around trellis/codebook ideas; AutoRound/SignRound improves rounding with SignSGD without adding inference overhead; and Marlin shows Ampere can get near-ideal FP16×INT4 speedups when the kernel stays memory-efficient and fused. [arXiv+3arXiv+3GitHub+3](https://arxiv.org/abs/2511.10645?utm_source=chatgpt.com) + +The design I would bet on is this: + +## 1) Keep the runtime scalar and Marlin-like + +Do **not** inherit EXL3/QTIP’s sequential decoder or codebook-heavy inference path. QTIP’s paper is clear that naive TCQ does not admit fast inference because of sequential bottlenecks and codebook storage pressure. If the target is “as fast as GPTQ on Ampere,” the runtime should still look like: + +- packed low-bit weights, +- tiny per-group metadata, +- one fused kernel that dequantizes into registers/shared memory and immediately does the MMA, +- no per-token codebook decode loops, +- no runtime graph of small transforms outside the GEMM. [arXiv+2arXiv+2](https://arxiv.org/abs/2406.11235?utm_source=chatgpt.com) + +So the new method should remain a **scalar groupwise quantizer**, not a vector-codebook quantizer. + +## 2) Borrow ParoQuant’s transform, but fuse it into the matmul + +ParoQuant’s strongest idea is not “better bit allocation.” It is **pairwise Givens rotations + channel-wise scaling** to suppress outliers before quantization, while keeping the transform hardware-efficient enough to add under 10% overhead in their design. That is the part worth stealing. I would apply: + +- per-group channel scaling, +- sparse pairwise rotations inside each block, +- learned offline, +- then quantize the transformed weights. [arXiv+1](https://arxiv.org/abs/2511.10645?utm_source=chatgpt.com) + +But I would impose one constraint ParoQuant does not need to obey as strictly: the transform must be **tile-local and kernel-fusible**. In practice that means: + +- rotations only within fixed 64- or 128-column groups, +- pair indices chosen so each warp can apply them without cross-warp shuffles exploding, +- scale stored per group or per micro-group, +- transform metadata packed adjacent to the group metadata. + +That gives you ParoQuant’s outlier suppression without a separate runtime pass. + +## 3) Replace greedy quantization with AutoRound-style optimization + +For the actual scalar levels, I would not use plain GPTQ greedy rounding. AutoRound/SignRound is the right inspiration here because it optimizes rounding and clipping with SignSGD yet still exports ordinary low-bit checkpoints without adding inference overhead. That is exactly the kind of offline work you want. [arXiv+1](https://arxiv.org/abs/2309.05516?utm_source=chatgpt.com) + +So the offline objective becomes: + +- optimize transform parameters: + +channel scales + +pairwise rotation angles +- channel scales +- pairwise rotation angles +- optimize quantization parameters: + +clipping + +scales + +zero/offset if asymmetric +- clipping +- scales +- zero/offset if asymmetric +- optimize rounding decisions: + +SignSGD-style continuous relaxation, then project back to scalar integers +- SignSGD-style continuous relaxation, then project back to scalar integers + +In other words, use AutoRound’s idea **after** the Paro-style transform, not instead of it. + +## 4) Borrow EXL3/QTIP as a search prior, not as the deployed format + +EXL3 is explicitly described as a streamlined variant of QTIP, and QTIP’s core advantage is **better search over quantization states** through trellis coding, not a cheap runtime path. So I would use QTIP/EXL3 logic only in the **offline optimizer**: + +- short-horizon beam search over scalar assignments inside a block, +- or initialize AutoRound from a trellis-searched candidate, +- then let SignSGD refine it. [GitHub+1](https://github.com/turboderp-org/exllamav3?utm_source=chatgpt.com) + +That gives you “EXL3-like intelligence” during quantization without inheriting EXL3’s runtime burden. + +## 5) Use mixed precision, but only where the kernel can stay fused + +EXL2’s useful lesson is that mixed precision is powerful, and its format supports mixing quantization levels within a model to achieve arbitrary average bitrates. But for Ampere speed, I would avoid arbitrary intra-tensor bitrate soup. [GitHub](https://github.com/turboderp-org/exllamav2?utm_source=chatgpt.com) + +I would use only **two deployed scalar paths**: + +- INT4 groups +- INT8 rescue groups + +Not 2/3/5/6-bit, unless you prove the kernel still hits the memory/computation balance you want. + +So the actual tensor layout would be: + +- mostly INT4 groups, +- a small mask of “hard” groups promoted to INT8, +- same tile shape for both, +- same fused kernel dispatch, +- no dynamic branching per element, only per micro-tile/group. + +That is much closer to GPTQ/Marlin performance territory. + +## 6) Make the allocator groupwise, not just layerwise + +Your previous SDQ allocator was already heading this way conceptually, but to beat strong baselines you need finer granularity than “whole module = 4 or 8 bits.” EXL2’s main lesson is that finer-grained mixed precision matters. [GitHub](https://github.com/turboderp-org/exllamav2?utm_source=chatgpt.com) + +My allocator would operate at **micro-group granularity**: + +- group size 64 or 128 columns, +- each group gets: + +transform on/off + +INT4 or INT8 + +maybe one of a few transform templates +- transform on/off +- INT4 or INT8 +- maybe one of a few transform templates + +The cost model should be: + +- Hessian-proxy distortion, +- plus runtime penalty regularizer, +- plus metadata penalty, +- under a hard bytes budget. + +That becomes a constrained optimization problem, not an LLM-routing problem. + +## 7) The kernel I’d actually write for Ampere + +For Ampere, I would start from a Marlin-style fused FP16×INT4 kernel and extend it in the least invasive way possible. Marlin’s paper is the right design anchor because it is already optimized for practical batched autoregressive inference and gets close to ideal low-bit speedups on Ampere-class deployments. [GitHub+1](https://github.com/IST-DASLab/marlin?utm_source=chatgpt.com) + +The kernel would do this per tile: + +1. Load packed low-bit weights and metadata for one K-tile. +2. Dequantize groups into registers/shared memory. +3. Apply **pairwise rotations + channel scales** in-register/shared-memory, but only within the local tile. +4. Multiply with FP16 activations using the normal MMA path. +5. Accumulate in FP16/FP32 as usual. + +The critical rule is that the transform must be **linear and local**, so it can be folded into the same memory/compute schedule rather than becoming an extra kernel launch. + +## 8) The concrete offline pipeline + +I would build the training-free PTQ pipeline like this: + +### Stage A: collect stats + +For each target tensor/group: + +- activation covariance or Hessian proxy +- outlier scores +- sensitivity under local perturbations +- runtime cost estimate for transform and for INT8 rescue + +This part is GPTQ-like in spirit because GPTQ itself is built on approximate second-order information. [arXiv+1](https://arxiv.org/abs/2210.17323?utm_source=chatgpt.com) + +### Stage B: optimize transform + +For each group: + +- learn channel scale vector +- learn sparse pairwise rotation parameters +- objective: shrink dynamic range / equalize channels / reduce Hessian-weighted reconstruction error + +### Stage C: initialize quantization by short trellis search + +For each group: + +- do a small beam search over scalar assignments or clipping settings +- use it to seed the final optimizer + +### Stage D: refine with AutoRound-style SignSGD + +Optimize: + +- rounding offsets +- clipping +- maybe tiny adjustments to the transform +- under Hessian-weighted activation reconstruction loss + +### Stage E: global mixed-precision solve + +Promote only the hardest groups to INT8 under a bytes budget. + +### Stage F: emit kernel-friendly checkpoint + +Store: + +- packed INT4/INT8 groups +- compact per-group scales +- compact per-group transform metadata +- one format version, one fused Ampere kernel path + +## 9) What I’d call the method + +Something like: + +**TRAM-Quant** +**T**ransform-**R**ounded **A**mpere **M**ixed-precision Quantization + +Internally: + +- **Paro block** = sparse pairwise rotation + scaling +- **QTIP seed** = short trellis initialization +- **AutoRound refine** = SignSGD rounding/clipping optimization +- **Marlin runtime** = fused Ampere kernel + +## 10) What will probably make it beat the baselines + +If this wins, it will not be because of any single component. It will be because of the combination: + +- **ParoQuant** fixes outliers before quantization. [arXiv+1](https://arxiv.org/abs/2511.10645?utm_source=chatgpt.com) +- **EXL3/QTIP-like search** avoids bad local assignments at the block level. [GitHub+1](https://github.com/turboderp-org/exllamav3?utm_source=chatgpt.com) +- **AutoRound** improves scalar rounding without inference overhead. [arXiv+1](https://arxiv.org/abs/2309.05516?utm_source=chatgpt.com) +- **Marlin-style kernel discipline** preserves Ampere speed. [arXiv+1](https://arxiv.org/abs/2408.11743?utm_source=chatgpt.com) + +If you skip any one of those, you probably end up with: + +- great quality but slow runtime, +- or great speed but not enough quality gain to matter. + +## 11) What I would not do + +I would not: + +- use a generic LLM judge for allocation, +- deploy QTIP codebooks directly if speed matters, +- allow arbitrary bitwidths per tiny fragment, +- put nonlocal rotations into the runtime, +- quantize fused QKV with different per-subtensor formats unless the kernel is explicitly built for tied handling. + +## 12) The honest caveat + +Nobody can guarantee it will “beat them all” before ablations. EXL3/QTIP, ParoQuant, and AutoRound each win for different reasons, and some of their gains come from constraints you may not preserve simultaneously. What I can say is: + +**This is the most plausible architecture I know for combining their strongest ideas while still aiming for GPTQ-class Ampere throughput.** It is also substantially more realistic than trying to literally mash together EXL3 format, ParoQuant runtime, and AutoRound export unchanged. [arXiv+3arXiv+3GitHub+3](https://arxiv.org/abs/2511.10645?utm_source=chatgpt.com) + +The next step I’d take is to freeze the target format and kernel contract first: + +- group size +- INT4/INT8 mixed micro-group scheme +- transform metadata layout +- one Ampere fused kernel API + +Then derive the offline optimizer to emit exactly that contract. Send the target model family and preferred average bits-per-weight, and I’ll map this into a concrete tensor format and kernel/block design. + +--- + +### User: +Here is the raw truth: your high-level strategy is the exact right path to state-of-the-art, but your runtime mechanics and offline pipeline have critical flaws that will either tank your Ampere throughput or make quantization take weeks. +Here is the critique, followed by the polished spec. +The Critique: What Needs Fixing + * The ParoQuant/Marlin Collision: You cannot apply pairwise rotations to dequantized weights on the fly inside the Marlin inner loop. Marlin is already register-starved to maintain occupancy. Injecting math between ldmatrix and mma.sync will stall the pipeline. The fix: You apply the rotations to the weights offline. At runtime, your kernel applies the inverse transform to the activations in shared memory before feeding them to the MMA blocks. + * Marlin Hates Branching (Mixed Precision Flaw): You cannot arbitrarily mix INT4 and INT8 at the 64-column micro-group level within a single tile. Marlin relies on perfectly pipelined, deterministic cp.async memory fetches. If a warp has to branch to decode INT8 vs. INT4, your memory coalescing dies. The fix: Segregate mixed precision at the output channel (N) level. A warp computes either a pure INT4 tile or a pure INT8 tile. + * Offline Compute Bloat: Using QTIP's trellis beam search to seed AutoRound's SignSGD is computationally absurd. SignSGD is robust enough to find the local minima. The fix: Drop the trellis search. Initialize with a fast, greedy Hessian solver, then run SignSGD. +The Polished Spec: TRAM-Quant +This is the lean, deployable version of your architecture. +1. The Runtime Kernel (Marlin + Activation Transforms) + * Core: A fused Ampere kernel prioritizing memory-bound efficiency. + * The Math: Weights are pre-transformed and quantized offline. The kernel loads activations, applies sparse pairwise Givens rotations and channel scaling to the activations in shared memory (tile-local only), then streams them into the standard MMA pipeline. + * No Codebooks: Strict scalar dequantization into registers. Zero per-token decode loops. +2. Memory Layout (Constrained Mixed Precision) + * Granularity: Group sizes of 64 or 128. + * Segregation: Tensors are mostly INT4. "Hard" channels promoted to INT8 are segregated by column blocks. The kernel dispatches uniform INT4 or INT8 warp-tiles to eliminate dynamic branching in the inner loop. + * Metadata: Compact per-group scales and packed transform metadata stored adjacent to the weight matrix for coalesced loading. +3. The Offline PTQ Pipeline + * Stage A (Stats): Collect activation covariance/Hessian proxies. + * Stage B (Transform Learn): Learn ParoQuant-style sparse pairwise rotations and channel scales to suppress outliers, strictly constrained to the 64/128 block size. Apply to weights. + * Stage C (Optimize): Skip trellis search. Initialize a greedy baseline, then use AutoRound (SignSGD) to jointly optimize rounding offsets, clipping, and scale under a Hessian-weighted reconstruction loss. + * Stage D (Allocate): Promote channels with the highest residual distortion to INT8 under a strict byte budget. + * Stage E (Export): Pack into the segregated Marlin-compatible layout. +draft the exact CUDA shared memory layout and ldmatrix schedule for the segregated INT4/INT8 kernel? + +--- + +### ChatGPT: +Yes. +This is the **exact draft kernel contract** I would prototype first for Ampere. + +It is built around what Marlin already proves works well on Ampere: offline-preprocessed weights, `cp.async` global→shared pipelining, XOR-swizzled FP16 activation tiles for conflict-free `ldmatrix`, and a column-wise inner loop that pipelines weight decode with Tensor Core math. PTX also documents that `ldmatrix` is the warp-level shared-memory load primitive for `mma`, with `.m8n8.x4.shared.b16` for 16-bit matrix fragments and `.x4` supplied by lanes `0–7 / 8–15 / 16–23 / 24–31`. AutoRound’s docs are also the reason QKV stays uniform: they explicitly warn that vLLM/SGLang fuse QKV and recommend not assigning different bit widths there. [GitHub+3arXiv+3NVIDIA Docs+3](https://arxiv.org/pdf/2408.11743) + +## 1) Kernel family and tile contract + +Use **two specialized kernels**, not one branchy mixed kernel: + +- **INT4 kernel** + +CTA: `256 threads = 8 warps` + +CTA output tile: `M_TILE = 64`, `N_TILE = 128` + +Warp tile: `16 x 64` + +`K_STAGE = 64` +- CTA: `256 threads = 8 warps` +- CTA output tile: `M_TILE = 64`, `N_TILE = 128` +- Warp tile: `16 x 64` +- `K_STAGE = 64` +- **INT8 kernel** + +CTA: `128 threads = 4 warps` + +CTA output tile: `M_TILE = 64`, `N_TILE = 64` + +Warp tile: `16 x 64` + +`K_STAGE = 64` +- CTA: `128 threads = 4 warps` +- CTA output tile: `M_TILE = 64`, `N_TILE = 64` +- Warp tile: `16 x 64` +- `K_STAGE = 64` + +This keeps the **weight bytes per stage equal**: + +- INT4: `64 x 128 x 4 bits = 4096 bytes` +- INT8: `64 x 64 x 8 bits = 4096 bytes` + +That is the cleanest way to keep the memory pipeline uniform while avoiding dynamic decode branching inside the MMA loop. The “QKV lock” is also consistent with AutoRound’s guidance for fused QKV deployments. Marlin’s paper also explicitly motivates `Ksm >= 64`, uses offline reshuffled `16 x 64` weight tiles, and executes the inner accumulation as `M x 16` times `16 x 64` column-wise with `16 x 8` Tensor Core instructions. [GitHub+3arXiv+3arXiv+3](https://arxiv.org/pdf/2408.11743) + +## 2) Shared-memory layout + +I would use a **4-stage pipeline** like Marlin (`P = 4`). Marlin explicitly says it chose `P = 4`, and that `cp.async` plus pipelining of future tiles is key to overlapping memory with Tensor Core math. [arXiv](https://arxiv.org/pdf/2408.11743) + +### Exact per-stage layout + +For both kernels, the stage starts with the activation tile, then transform metadata, then weight tile, then per-output scales. + +```C// ------------------------------// Common constants// ------------------------------constexpr int PIPE = 4;constexpr int M_TILE = 64;constexpr int K_STAGE = 64;constexpr int CHUNK_K = 8; // 8 FP16 = 16 bytesconstexpr int A_CHUNKS = K_STAGE / CHUNK_K; // 8constexpr int A_STAGE_BYTES = M_TILE * A_CHUNKS * 16; // 64 * 8 * 16 = 8192// Transform metadata for one K_STAGE=64 block.// Draft contract:// - 64 channel scales (FP16) = 128 B// - up to 8 sparse Givens rotations// each stored as (u8 a, u8 b, fp16 cos, fp16 sin, pad) = 8 B// - total = 128 + 64 = 192 B -> round to 256 B for alignmentconstexpr int XFORM_STAGE_BYTES = 256;// INT4 specializationconstexpr int N_TILE4 = 128;constexpr int W4_STAGE_BYTES = K_STAGE * N_TILE4 / 2; // 4096 Bconstexpr int S4_STAGE_BYTES = N_TILE4 * 2; // one FP16 scale per output channel = 256 Bconstexpr int STAGE4_BYTES = A_STAGE_BYTES + XFORM_STAGE_BYTES + W4_STAGE_BYTES + S4_STAGE_BYTES; // 12800 B// INT8 specializationconstexpr int N_TILE8 = 64;constexpr int W8_STAGE_BYTES = K_STAGE * N_TILE8; // 4096 Bconstexpr int S8_STAGE_BYTES = N_TILE8 * 2; // 128 Bconstexpr int STAGE8_BYTES = A_STAGE_BYTES + XFORM_STAGE_BYTES + W8_STAGE_BYTES + S8_STAGE_BYTES; // 12672 B +C// ------------------------------// Common constants// ------------------------------constexpr int PIPE = 4;constexpr int M_TILE = 64;constexpr int K_STAGE = 64;constexpr int CHUNK_K = 8; // 8 FP16 = 16 bytesconstexpr int A_CHUNKS = K_STAGE / CHUNK_K; // 8constexpr int A_STAGE_BYTES = M_TILE * A_CHUNKS * 16; // 64 * 8 * 16 = 8192// Transform metadata for one K_STAGE=64 block.// Draft contract:// - 64 channel scales (FP16) = 128 B// - up to 8 sparse Givens rotations// each stored as (u8 a, u8 b, fp16 cos, fp16 sin, pad) = 8 B// - total = 128 + 64 = 192 B -> round to 256 B for alignmentconstexpr int XFORM_STAGE_BYTES = 256;// INT4 specializationconstexpr int N_TILE4 = 128;constexpr int W4_STAGE_BYTES = K_STAGE * N_TILE4 / 2; // 4096 Bconstexpr int S4_STAGE_BYTES = N_TILE4 * 2; // one FP16 scale per output channel = 256 Bconstexpr int STAGE4_BYTES = A_STAGE_BYTES + XFORM_STAGE_BYTES + W4_STAGE_BYTES + S4_STAGE_BYTES; // 12800 B// INT8 specializationconstexpr int N_TILE8 = 64;constexpr int W8_STAGE_BYTES = K_STAGE * N_TILE8; // 4096 Bconstexpr int S8_STAGE_BYTES = N_TILE8 * 2; // 128 Bconstexpr int STAGE8_BYTES = A_STAGE_BYTES + XFORM_STAGE_BYTES + W8_STAGE_BYTES + S8_STAGE_BYTES; // 12672 B +``` + +### Byte offsets + +```C// INT4 stage sA4_base(s) = s * STAGE4_BYTES + 0;T4_base(s) = A4_base(s) + A_STAGE_BYTES; // +8192B4_base(s) = T4_base(s) + XFORM_STAGE_BYTES; // +8448S4_base(s) = B4_base(s) + W4_STAGE_BYTES; // +12544// INT8 stage sA8_base(s) = s * STAGE8_BYTES + 0;T8_base(s) = A8_base(s) + A_STAGE_BYTES; // +8192B8_base(s) = T8_base(s) + XFORM_STAGE_BYTES; // +8448S8_base(s) = B8_base(s) + W8_STAGE_BYTES; // +12544 +C// INT4 stage sA4_base(s) = s * STAGE4_BYTES + 0;T4_base(s) = A4_base(s) + A_STAGE_BYTES; // +8192B4_base(s) = T4_base(s) + XFORM_STAGE_BYTES; // +8448S4_base(s) = B4_base(s) + W4_STAGE_BYTES; // +12544// INT8 stage sA8_base(s) = s * STAGE8_BYTES + 0;T8_base(s) = A8_base(s) + A_STAGE_BYTES; // +8192B8_base(s) = T8_base(s) + XFORM_STAGE_BYTES; // +8448S8_base(s) = B8_base(s) + W8_STAGE_BYTES; // +12544 +``` + +So total shared memory footprint is: + +- INT4 kernel: `4 * 12800 = 51200 B` +- INT8 kernel: `4 * 12672 = 50688 B` + +That is a sane starting point for Ampere. + +## 3) Activation shared-memory swizzle + +For the FP16 activation tile, use the Marlin/XOR style shared-memory swizzle so `ldmatrix.sync.aligned.m8n8.x4.shared.b16` stays conflict-free. Marlin explicitly says the 16-byte vector for logical tile element `ij` should be stored at location `i (i xor j)` to avoid `ldmatrix` bank conflicts. [arXiv](https://arxiv.org/pdf/2408.11743) + +For `M_TILE = 64`, `K_STAGE = 64`, store the activation tile as 16-byte chunks: + +- logical indices: + +`m in [0, 63]` + +`vec in [0, 7]` where each `vec` is `8 FP16 = 16 bytes` + +`elem in [0, 7]` +- `m in [0, 63]` +- `vec in [0, 7]` where each `vec` is `8 FP16 = 16 bytes` +- `elem in [0, 7]` + +Use this exact address formula: + +```C// Logical A tile element A[m, k]// where k = vec * 8 + elem// Store as 16-byte chunks with XOR swizzle on chunk id__device__ __forceinline__size_t a_smem_offset_bytes(int stage_base, int m, int vec, int elem) { int phys_vec = vec ^ (m & 0x7); int chunk_index = m * 8 + phys_vec; // 64 rows * 8 chunks return stage_base + (chunk_index * 16) + (elem * 2);} +C// Logical A tile element A[m, k]// where k = vec * 8 + elem// Store as 16-byte chunks with XOR swizzle on chunk id__device__ __forceinline__size_t a_smem_offset_bytes(int stage_base, int m, int vec, int elem) { int phys_vec = vec ^ (m & 0x7); int chunk_index = m * 8 + phys_vec; // 64 rows * 8 chunks return stage_base + (chunk_index * 16) + (elem * 2);} +``` + +Equivalently, each row has 8 logical 16-byte chunks, but the physical chunk index is XOR-swizzled by the low 3 bits of the row. + +This is the draft exact layout I would use for the **transformed activation tile**. + +## 4) Weight shared-memory layout + +### INT4 kernel + +Store the `64 x 128` quantized weight tile **contiguously** in shared memory: + +```C// 4096 B totaluint8_t B4_smem[64][64]; // each row has 64 bytes = 128 int4 values +C// 4096 B totaluint8_t B4_smem[64][64]; // each row has 64 bytes = 128 int4 values +``` + +Interpretation: + +- `B4_smem[k][byte_col]` +- each byte contains two 4-bit weights +- offline packing is already done so each warp’s 16-byte vector is laid out in the order needed for direct register dequantization + +Marlin explicitly preprocesses weights offline, reshuffles `16 x 64` tiles to contiguous order, and packs each thread’s 16-byte vector so it contains the 8 quantized weights needed for 4 separate `16 x 16` Tensor Core blocks, with the nibble pattern `64207531` inside the `INT32`s. I would keep that same idea for the INT4 specialization. [arXiv+1](https://arxiv.org/pdf/2408.11743) + +### INT8 kernel + +Store the `64 x 64` weight tile **contiguously** in shared memory: + +```C// 4096 B totaluint8_t B8_smem[64][64]; // one int8 weight per byte +C// 4096 B totaluint8_t B8_smem[64][64]; // one int8 weight per byte +``` + +No nibble-interleave is needed; just preserve the same **column-major block order** as the INT4 kernel so the warp schedule stays structurally identical. + +### Per-output scales + +For both specializations, store one FP16 scale per output channel for the current logical 128-group: + +```Chalf S4_smem[128]; // 256 Bhalf S8_smem[64]; // 128 B +Chalf S4_smem[128]; // 256 Bhalf S8_smem[64]; // 128 B +``` + +Marlin notes that for grouped quantization it reorganizes scales similarly to weights and, although scales for group-size 128 technically only need to refresh every other `K_STAGE=64`, it still reloads them regularly to keep compiler scheduling stable. I would do the same for this first kernel revision. [arXiv](https://arxiv.org/pdf/2408.11743) + +## 5) Transform metadata layout + +Use a tiny fixed-format transform block per `K_STAGE = 64`: + +```Cstruct __align__(16) RotationMeta { uint8_t a; // column index 0..63 uint8_t b; // column index 0..63 uint16_t pad0; half c; // cos(theta) half s; // sin(theta) uint32_t pad1;}; // 8 bytes// Shared-memory transform block, 256 B alignedstruct __align__(16) TransformStage { half scale[64]; // 128 B RotationMeta rot[8]; // 64 B (up to 8 sparse pairs per 64-col K block) uint8_t pad[64]; // pad to 256 B}; +Cstruct __align__(16) RotationMeta { uint8_t a; // column index 0..63 uint8_t b; // column index 0..63 uint16_t pad0; half c; // cos(theta) half s; // sin(theta) uint32_t pad1;}; // 8 bytes// Shared-memory transform block, 256 B alignedstruct __align__(16) TransformStage { half scale[64]; // 128 B RotationMeta rot[8]; // 64 B (up to 8 sparse pairs per 64-col K block) uint8_t pad[64]; // pad to 256 B}; +``` + +This keeps the runtime transform tile-local and tiny enough to stage with the weights. + +## 6) Warp mapping + +### INT4 kernel: 8 warps, 64 x 128 + +Map warps as: + +```Cwarp_id = threadIdx.x >> 5; // 0..7lane = threadIdx.x & 31;wm = warp_id & 3; // 0..3 -> rowswn = warp_id >> 2; // 0..1 -> 64-col halfrow_base = wm * 16; // 0,16,32,48col_base = wn * 64; // 0 or 64 +Cwarp_id = threadIdx.x >> 5; // 0..7lane = threadIdx.x & 31;wm = warp_id & 3; // 0..3 -> rowswn = warp_id >> 2; // 0..1 -> 64-col halfrow_base = wm * 16; // 0,16,32,48col_base = wn * 64; // 0 or 64 +``` + +Each warp computes a `16 x 64` output tile. + +### INT8 kernel: 4 warps, 64 x 64 + +```Cwarp_id = threadIdx.x >> 5; // 0..3lane = threadIdx.x & 31;wm = warp_id; // 0..3wn = 0;row_base = wm * 16;col_base = 0; +Cwarp_id = threadIdx.x >> 5; // 0..3lane = threadIdx.x & 31;wm = warp_id; // 0..3wn = 0;row_base = wm * 16;col_base = 0; +``` + +Again, each warp computes a `16 x 64` tile. + +## 7) ldmatrix schedule + +For the activation operand, use **one ldmatrix.sync.aligned.m8n8.x4.shared.b16 per warp per k16 slice**. + +Why this exact instruction: + +- PTX defines `ldmatrix` as the warp-level matrix load for `mma` +- `.m8n8.x4.shared.b16` loads **four** `8 x 8` matrices of 16-bit elements +- that is exactly what a `16 x 16` FP16 A-fragment needs: top-left, bottom-left, top-right, bottom-right submatrices. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +### Address supply pattern + +For the warp’s current `16 x 16` A fragment: + +- lanes `0–7` supply row pointers for rows `row_base + 0..7`, cols `k0 + 0..7` +- lanes `8–15` supply row pointers for rows `row_base + 8..15`, cols `k0 + 0..7` +- lanes `16–23` supply row pointers for rows `row_base + 0..7`, cols `k0 + 8..15` +- lanes `24–31` supply row pointers for rows `row_base + 8..15`, cols `k0 + 8..15` + +That is the `.x4` row-address pattern PTX describes for lanes `0–7 / 8–15 / 16–23 / 24–31`. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +### Exact k loop schedule + +`K_STAGE = 64` is split into 4 microsteps of `k_step = 16`. + +For each warp: + +```Cfor (int ks = 0; ks < 4; ++ks) { int k0 = ks * 16; // One ldmatrix.x4 gives the full 16x16 A fragment for this warp and k-slice // logical submatrices: // M0 = rows 0..7, cols 0..7 // M1 = rows 8..15, cols 0..7 // M2 = rows 0..7, cols 8..15 // M3 = rows 8..15, cols 8..15 ldmatrix.sync.aligned.m8n8.x4.shared.b16 ...} +Cfor (int ks = 0; ks < 4; ++ks) { int k0 = ks * 16; // One ldmatrix.x4 gives the full 16x16 A fragment for this warp and k-slice // logical submatrices: // M0 = rows 0..7, cols 0..7 // M1 = rows 8..15, cols 0..7 // M2 = rows 0..7, cols 8..15 // M3 = rows 8..15, cols 8..15 ldmatrix.sync.aligned.m8n8.x4.shared.b16 ...} +``` + +## 8) MMA schedule + +Use `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16`. + +PTX documents the `m16n8k16` FP16 MMA form, and Marlin’s paper explicitly says its inner loop computes an `M x 16` times `16 x 64` matmul and does so **column-wise** using `16 x 8` Tensor Core instructions so the next B fragment can be dequantized while the current column is being multiplied. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +### Per-warp inner loop + +For each `ks`: + +- load A once via `ldmatrix.x4` +- then execute **8 MMA instructions** for the warp’s `16 x 64` output tile + +```Cfor (int j = 0; j < 8; ++j) { // j-th 16x8 B fragment of the warp's 16x64 tile // col range = col_base + j*8 ... col_base + j*8 + 7 // 1) decode B_j from shared-memory packed format into registers // 2) apply per-output scales // 3) issue mma.sync on A_frag x B_j -> accum_j mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 ...} +Cfor (int j = 0; j < 8; ++j) { // j-th 16x8 B fragment of the warp's 16x64 tile // col range = col_base + j*8 ... col_base + j*8 + 7 // 1) decode B_j from shared-memory packed format into registers // 2) apply per-output scales // 3) issue mma.sync on A_frag x B_j -> accum_j mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 ...} +``` + +### Pipelined decode + +Follow Marlin’s column-wise strategy: + +- while `mma(j)` is executing, decode `B(j+1)` into the next register set +- keep **A fixed** across the 8 column slices for this `ks` +- move to the next `ks` only after all 8 column slices are accumulated + +That is the schedule Marlin uses to overlap weight decode with Tensor Core math. [arXiv](https://arxiv.org/pdf/2408.11743) + +## 9) Activation-transform schedule + +This is the Paro-style piece, but done in the only place I think is feasible for Ampere: + +- **load activations into XOR-swizzled shared memory** +- **apply the inverse transform in shared memory** +- then `ldmatrix` from the transformed tile + +I would do this exactly once per stage, before the first `ldmatrix` of that stage. + +### Per-stage transform pass + +Each warp owns its `16 x 64` row slab: + +```C// warp owns rows row_base .. row_base+15for each local row r in row_base .. row_base+15: // 1) scale all 64 columns for k in 0..63: A[r,k] *= scale[k] // 2) apply sparse pairwise Givens rotations for t in 0..ROT_COUNT-1: u = rot[t].a v = rot[t].b c = rot[t].c s = rot[t].s au = A[r,u] av = A[r,v] A[r,u] = c * au + s * av A[r,v] = -s * au + c * av +C// warp owns rows row_base .. row_base+15for each local row r in row_base .. row_base+15: // 1) scale all 64 columns for k in 0..63: A[r,k] *= scale[k] // 2) apply sparse pairwise Givens rotations for t in 0..ROT_COUNT-1: u = rot[t].a v = rot[t].b c = rot[t].c s = rot[t].s au = A[r,u] av = A[r,v] A[r,u] = c * au + s * av A[r,v] = -s * au + c * av +``` + +Because the transform is tile-local over the `K_STAGE = 64` columns and sparse (for example 8 pairs), this is still practical as a shared-memory preprocessing step. It does **not** go inside the MMA inner loop. + +## 10) cp.async pipeline + +Use a 4-stage pipeline like this: + +### Prologue + +Prefetch stages `0,1,2,3`: + +- A tile → `A*_base(stage)` +- transform metadata → `T*_base(stage)` +- B tile → `B*_base(stage)` using `evict_first` on weight loads +- scale tile → `S*_base(stage)` + +Marlin explicitly says it uses `cp.async`, uses `evict first` for B to avoid polluting L2 with weights that are only read once, and chooses pipeline depth `P = 4`. [arXiv](https://arxiv.org/pdf/2408.11743) + +### Main loop + +For `t = 0 .. num_kstages-1`: + +```Ccur = t & 3nxt = (t + 4) & 3// prefetch future stage if it existscp.async(... A[nxt] ...)cp.async(... T[nxt] ...)cp.async(... B[nxt] ...) // evict_first for weightscp.async(... S[nxt] ...)cp.async.commit_group()cp.async.wait_group(1)__syncthreads()// transform A[cur] in shared memoryapply_transform(A[cur], T[cur])__syncthreads()// compute on current stagefor ks in 0..3: ldmatrix A fragment for j in 0..7: decode next B fragment mma current fragment +Ccur = t & 3nxt = (t + 4) & 3// prefetch future stage if it existscp.async(... A[nxt] ...)cp.async(... T[nxt] ...)cp.async(... B[nxt] ...) // evict_first for weightscp.async(... S[nxt] ...)cp.async.commit_group()cp.async.wait_group(1)__syncthreads()// transform A[cur] in shared memoryapply_transform(A[cur], T[cur])__syncthreads()// compute on current stagefor ks in 0..3: ldmatrix A fragment for j in 0..7: decode next B fragment mma current fragment +``` + +### Epilogue + +Store accumulators to global. + +## 11) What this gives you + +This draft meets the constraints you laid out: + +- **no runtime codebook decode** +- **no mixed INT4/INT8 branching inside a warp tile** +- **Paro-style transforms happen outside the MMA hot loop** +- **QKV remains uniform** +- **the inner loop stays Marlin-like** +- **the only ldmatrix path is FP16 A**, which is exactly where Ampere/Marlin already want it. Marlin’s design and PTX’s documented `ldmatrix`/`mma` behavior are the foundation here. [arXiv+2NVIDIA Docs+2](https://arxiv.org/pdf/2408.11743) + +If you want, next I’ll turn this into: + +- a **PTX-level warp pseudocode** for the `ldmatrix` + `mma.sync` sequence, or +- a **CUTLASS/CUDA kernel skeleton** with the exact stage offsets and warp mapping baked in. + +--- + +### User: +Here is the INT4 CUDA kernel skeleton. It locks in the 4-stage pipeline, `cp.async` routing, and the required synchronization barriers. + +The INT8 version is structurally identical, just with `N_TILE = 64` and its corresponding byte offsets. + +``` +cpp +#include +#include + +// ------------------------------ +// Pipeline & Tile Constants (INT4) +// ------------------------------ +constexpr int PIPE = 4; +constexpr int M_TILE = 64; +constexpr int N_TILE4 = 128; +constexpr int K_STAGE = 64; + +// Shared Memory Byte Sizes per Stage +constexpr int A_STAGE_BYTES = 8192; // 64 * 64 * 2 bytes +constexpr int XFORM_STAGE_BYTES = 256; // Scales + Sparse Rotations +constexpr int W4_STAGE_BYTES = 4096; // 64 * 128 / 2 (4-bit) +constexpr int S4_STAGE_BYTES = 256; // 128 * 2 bytes (FP16 scales) + +constexpr int STAGE4_BYTES = A_STAGE_BYTES + XFORM_STAGE_BYTES + W4_STAGE_BYTES + S4_STAGE_BYTES; // 12800 B +constexpr int TOTAL_SMEM_BYTES = PIPE * STAGE4_BYTES; // 51200 B + +// Async Copy Macros +#define CP_ASYNC_CG() asm volatile("cp.async.commit_group;\n" ::) +#define CP_ASYNC_WAIT(N) asm volatile("cp.async.wait_group %0;\n" :: "n"(N)) +// .L2::evict_first for weights +#define CP_ASYNC_CG_EVICT(dst, src, bytes) \ + asm volatile("cp.async.cg.shared.global.L2::evict_first [%0], [%1], %2;\n" :: "r"(dst), "l"(src), "n"(bytes)) + +__global__ void __launch_bounds__(256) ampere_marlin_paro_int4_kernel( + const half* __restrict__ A, + const uint8_t* __restrict__ B_quant, + const half* __restrict__ Scales, + const void* __restrict__ XformMeta, + half* __restrict__ C, + int num_k_stages) +{ + // Shared memory allocation + extern __shared__ uint8_t smem[]; + + // Thread/Warp indexing + const int tid = threadIdx.x; + const int warp_id = tid >> 5; + const int lane = tid & 31; + + // Output tile indices + const int block_m = blockIdx.x * M_TILE; + const int block_n = blockIdx.y * N_TILE4; + + // Accumulators (16x64 output per warp -> 8x M16N8K16 blocks) + // 4 registers per M16N8K16 accumulator * 8 blocks = 32 registers + float accum[32] = {0.0f}; + + // ------------------------------ + // Prologue: Fill Pipeline + // ------------------------------ + for (int stage = 0; stage < PIPE - 1; ++stage) { + int smem_base = stage * STAGE4_BYTES; + + // TODO: Map tid to specific 16-byte chunks for cp.async + // cp.async A -> smem_base + // cp.async XformMeta -> smem_base + A_STAGE_BYTES + // CP_ASYNC_CG_EVICT B_quant -> smem_base + A_STAGE_BYTES + XFORM_STAGE_BYTES + // cp.async Scales -> smem_base + A_STAGE_BYTES + XFORM_STAGE_BYTES + W4_STAGE_BYTES + + CP_ASYNC_CG(); + } + + // ------------------------------ + // Main Loop + // ------------------------------ + int smem_load_idx = 0; + int smem_store_idx = PIPE - 1; + + for (int t = 0; t < num_k_stages; ++t) { + // 1. Issue async copies for the next stage (if valid) + if (t + PIPE - 1 < num_k_stages) { + int store_base = smem_store_idx * STAGE4_BYTES; + + // TODO: issue cp.async for A, Xform, B_quant, Scales -> store_base + + CP_ASYNC_CG(); + } + + // 2. Wait for the oldest stage to be fully resident in shared memory + // If PIPE=4, we wait until only 3 uncommitted groups remain + CP_ASYNC_WAIT(PIPE - 2); + __syncthreads(); // Barrier 1: smem is safe to read/mutate for current stage + + // 3. Paro-style Transform (Shared Memory Mutation) + int load_base = smem_load_idx * STAGE4_BYTES; + half* A_smem = (half*)(smem + load_base); + void* T_smem = (void*)(smem + load_base + A_STAGE_BYTES); + + // Each warp owns a 16-row slab of the 64x64 A tile. + // Apply channel scales and sparse Givens rotations here. + // apply_paro_transform(A_smem, T_smem, warp_id, tid); + + __syncthreads(); // Barrier 2: transform is globally visible to all warps + + // 4. Tensor Core Inner Loop + uint8_t* B_smem = (uint8_t*)(smem + load_base + A_STAGE_BYTES + XFORM_STAGE_BYTES); + half* S_smem = (half*)(smem + load_base + A_STAGE_BYTES + XFORM_STAGE_BYTES + W4_STAGE_BYTES); + + #pragma unroll + for (int ks = 0; ks < 4; ++ks) { + // TODO: ldmatrix.sync.aligned.m8n8.x4.shared.b16 (A fragment) + + #pragma unroll + for (int j = 0; j < 8; ++j) { + // TODO: + // - Decode B_smem column j -> registers + // - Apply S_smem scale + // - mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 + } + } + + // Advance pipeline + smem_load_idx = (smem_load_idx + 1) % PIPE; + smem_store_idx = (smem_store_idx + 1) % PIPE; + } + + // Drain remaining pipeline stages + // ... wait_group(2) -> math ... wait_group(1) -> math ... wait_group(0) -> math + + // ------------------------------ + // Epilogue + // ------------------------------ + // Write accum[] to global memory C +} +``` + +Review the synchronization logic around the transform phase. `CP_ASYNC_WAIT(PIPE - 2)` ensures the oldest stage is resident. The first `__syncthreads()` guarantees the warp can safely mutate `A_smem`. The second `__syncthreads()` guarantees the transform is visible to all warps before the MMA inner loop starts blasting `ldmatrix`. + +Let me know if you spot any pipeline stalls or edge cases here. Otherwise, we move to the PTX `ldmatrix` / `mma.sync` inner loop implementation next. + +--- + +### ChatGPT: +The transform-phase logic is **almost right**, but there is one real race and one conservative barrier you can probably remove. + +The good part first: `cp.async.wait_group` only makes `cp.async` writes visible to the **executing thread**, not to the whole CTA, and it does not order unrelated memory ops. So your first `__syncthreads()` after `CP_ASYNC_WAIT(PIPE - 2)` is necessary if all warps are about to read or mutate the stage’s shared-memory tile. That matches the PTX semantics exactly. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + +The main issue is this: + +## 1) You are missing a **third barrier** after the MMA phase + +Right now, a fast warp can finish compute on `smem_load_idx`, advance the ring, and start issuing `cp.async` into `smem_store_idx` for the next iteration **while a slower warp is still reading from the old stage**. Because Ampere has independent warp scheduling, you cannot assume all warps leave the inner loop at the same instant. PTX also does not give you any cross-warp safety here. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + +So the safe structure is: + +```C++// wait current stage residentCP_ASYNC_WAIT(PIPE - 2);__syncthreads(); // Barrier A: all cp.async writes for current stage are CTA-visible// mutate current stageapply_paro_transform(...);// If transform touches only warp-private rows, Barrier B may be removable.// Otherwise keep it.__syncthreads(); // Barrier B: transformed tile globally visible// consume current stagefor (int ks = 0; ks < 4; ++ks) { ...}__syncthreads(); // Barrier C: no warp is still reading this stage before ring-buffer reuse +C++// wait current stage residentCP_ASYNC_WAIT(PIPE - 2);__syncthreads(); // Barrier A: all cp.async writes for current stage are CTA-visible// mutate current stageapply_paro_transform(...);// If transform touches only warp-private rows, Barrier B may be removable.// Otherwise keep it.__syncthreads(); // Barrier B: transformed tile globally visible// consume current stagefor (int ks = 0; ks < 4; ++ks) { ...}__syncthreads(); // Barrier C: no warp is still reading this stage before ring-buffer reuse +``` + +That **Barrier C** is the one your skeleton is currently missing. + +## 2) CP_ASYNC_WAIT(PIPE - 2) is correct only under one strict assumption + +With `PIPE = 4`, `wait_group(2)` is fine **if and only if**: + +- each pipeline stage corresponds to exactly **one** committed cp.async group, and +- you consume exactly **one** stage per iteration. + +PTX says `cp.async.wait_group N` waits until only `N` or fewer of the most recent groups remain pending, and all earlier committed groups are complete. That means your oldest stage is resident if your “one stage = one group” bookkeeping is exact. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + +If you ever split one stage across multiple `commit_group`s, `wait_group(PIPE - 2)` stops meaning “current stage is ready.” + +## 3) Your second barrier is probably **conservative**, not always mandatory + +If `apply_paro_transform()` is written so that: + +- each warp only mutates its own `16 x 64` slab of `A_smem`, and +- each warp’s later `ldmatrix` reads only that same slab, + +then the **post-transform __syncthreads()** is not strictly required for correctness of `A_smem`. In that case, each warp can transform its slab and immediately start its own MMA loop. + +So: + +- keep **Barrier A** always, +- keep **Barrier C** always, +- treat **Barrier B** as: + +**required** if any warp reads rows transformed by another warp, + +**optional** if row ownership is fully disjoint. +- **required** if any warp reads rows transformed by another warp, +- **optional** if row ownership is fully disjoint. + +That is probably your biggest easy stall to remove later. + +## 4) Prologue / short-K edge case + +Your prologue currently fills `PIPE - 1` stages unconditionally. That is only valid if `num_k_stages >= PIPE - 1`. If the K-loop is short, you need: + +```C++int prologue = min(PIPE - 1, num_k_stages);for (int stage = 0; stage < prologue; ++stage) { ... } +C++int prologue = min(PIPE - 1, num_k_stages);for (int stage = 0; stage < prologue; ++stage) { ... } +``` + +And the drain path must mirror how many groups were actually committed. + +## 5) One subtle issue in the top-of-loop prefetch placement + +Prefetching the next stage at the **top** of the loop is okay only if the previous iteration already ended with Barrier C. Without that barrier, reusing `smem_store_idx` is unsafe, as above. + +If you want the control flow to be easier to reason about, move the next-stage prefetch to the **end** of the loop, immediately after Barrier C. That makes the ring-buffer lifetime obvious. + +## 6) Practical inline-PTX footnote + +Not synchronization-related, but your `cp.async` macro will usually need a **shared-memory address**, typically from `__cvta_generic_to_shared()`, not a raw C pointer cast. `ldmatrix` has the same caveat. PTX’s `ldmatrix` examples explicitly use shared addresses. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + +## My recommendation + +Use this as the safe skeleton: + +```C++// prologuefor (int stage = 0; stage < min(PIPE - 1, num_k_stages); ++stage) { issue_cp_async_for_stage(stage); CP_ASYNC_CG();}for (int t = 0; t < num_k_stages; ++t) { // prefetch future stage only if it exists if (t + PIPE - 1 < num_k_stages) { issue_cp_async_for_stage(smem_store_idx); CP_ASYNC_CG(); } CP_ASYNC_WAIT(PIPE - 2); __syncthreads(); // A: stage resident + CTA-visible apply_paro_transform(...); __syncthreads(); // B: keep for now, remove later only if provably warp-private do_mma_inner_loop(...); __syncthreads(); // C: stage no longer in use before ring reuse smem_load_idx = (smem_load_idx + 1) % PIPE; smem_store_idx = (smem_store_idx + 1) % PIPE;} +C++// prologuefor (int stage = 0; stage < min(PIPE - 1, num_k_stages); ++stage) { issue_cp_async_for_stage(stage); CP_ASYNC_CG();}for (int t = 0; t < num_k_stages; ++t) { // prefetch future stage only if it exists if (t + PIPE - 1 < num_k_stages) { issue_cp_async_for_stage(smem_store_idx); CP_ASYNC_CG(); } CP_ASYNC_WAIT(PIPE - 2); __syncthreads(); // A: stage resident + CTA-visible apply_paro_transform(...); __syncthreads(); // B: keep for now, remove later only if provably warp-private do_mma_inner_loop(...); __syncthreads(); // C: stage no longer in use before ring reuse smem_load_idx = (smem_load_idx + 1) % PIPE; smem_store_idx = (smem_store_idx + 1) % PIPE;} +``` + +So the answer is: + +- **wait_group(PIPE - 2) is fine** +- **first barrier is required** +- **second barrier is maybe removable** +- **you are definitely missing a final barrier after the MMA loop** + +That final barrier is the real correctness fix. + +Paste the `ldmatrix` / `mma.sync` inner loop next and I’ll review the register ownership and fragment schedule. + +--- + +### User: +You nailed the missing Barrier C. Good catch on the `cp.async` short-K drain logic too. Reusing the ring buffer without guaranteeing the previous consumers are done is a classic race condition, and Ampere's independent warp schedulers will absolutely punish it. + +Here is the `ldmatrix` + `mma.sync` inner loop locked in. This drops directly into the `do_mma_inner_loop(...)` placeholder. It assumes the offline INT4 packing and nibble interleaving are already strictly aligned to what a 16x8 MMA instruction expects in its registers. + +``` +cpp +// ------------------------------ +// Registers & Addressing Setup +// (Place before the K-stage loop) +// ------------------------------ + +// Accumulators: 8 column slices * 4 registers (for m16n8k16 f16 accumulators) +uint32_t RC[8][4]; +#pragma unroll +for (int j = 0; j < 8; ++j) { + RC[j][0] = 0; RC[j][1] = 0; RC[j][2] = 0; RC[j][3] = 0; +} + +// Lane to A_smem address mapping (XOR swizzle) +int m_local = lane % 16; +int k_local = (lane / 16) * 8; // 0 or 8 +int row = row_base + m_local; + +// Base shared memory pointers for the generic address space conversion +uint32_t a_smem_base; +asm("cvta.to.shared.u32 %0, %1;" : "=r"(a_smem_base) : "l"(A_smem)); + +uint32_t b_smem_base; +asm("cvta.to.shared.u32 %0, %1;" : "=r"(b_smem_base) : "l"(B_smem)); + +// ------------------------------ +// Inner Loop +// ------------------------------ +#pragma unroll +for (int ks = 0; ks < 4; ++ks) { + uint32_t RA[4]; // 16x16 A fragment (4x 32-bit registers) + + // 1. Calculate lane's specific ldmatrix address for this ks + int col = ks * 16 + k_local; + int vec = col / 8; + int phys_vec = vec ^ (row & 0x7); + uint32_t a_smem_addr = a_smem_base + (row * 8 + phys_vec) * 16; + + // 2. Load 16x16 A fragment + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];" + : "=r"(RA[0]), "=r"(RA[1]), "=r"(RA[2]), "=r"(RA[3]) + : "r"(a_smem_addr) + ); + + #pragma unroll + for (int j = 0; j < 8; ++j) { + uint32_t RB_quant[1]; // 8 INT4 weights = 4 bytes = 1x 32-bit register + uint32_t RB_fp16[2]; // Dequantized to 8x FP16 = 2x 32-bit registers + + // Output column index for this 16x8 slice + int global_col = col_base + j * 8; + + // 3. Load packed INT4 B fragment from shared memory + // B layout: 64 rows x 128 cols (4-bit) -> mapped contiguous + // Lane reads exactly 4 bytes containing its 8 required weights + int b_row = ks * 16 + (lane % 16); + int b_byte_col = (global_col / 2) + (lane / 16) * 4; + uint32_t b_smem_addr = b_smem_base + (b_row * 64) + b_byte_col; + + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(RB_quant[0]) : "r"(b_smem_addr)); + + // 4. Dequantize INT4 -> FP16 (assumes Marlin offline 64207531 nibble packing) + // Extract nibbles, cast to FP16, apply per-channel scales + half scale_0 = S_smem[global_col + (lane / 16) * 4 + 0]; + half scale_1 = S_smem[global_col + (lane / 16) * 4 + 1]; + half scale_2 = S_smem[global_col + (lane / 16) * 4 + 2]; + half scale_3 = S_smem[global_col + (lane / 16) * 4 + 3]; + + // [Insert highly optimized PTX bit-extraction & scale multiplication here] + // E.g., using prmt.b32 and fma.rn.f16x2 to populate RB_fp16[0] and RB_fp16[1] + // This is a placeholder for the raw decode sequence to keep the structure clear. + RB_fp16[0] = /* decoded and scaled */ 0; + RB_fp16[1] = /* decoded and scaled */ 0; + + // 5. Tensor Core MMA (16x8x16 FP16) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%10, %11, %12, %13};" + : "=r"(RC[j][0]), "=r"(RC[j][1]), "=r"(RC[j][2]), "=r"(RC[j][3]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "r"(RB_fp16[0]), "r"(RB_fp16[1]), + "r"(RC[j][0]), "r"(RC[j][1]), "r"(RC[j][2]), "r"(RC[j][3]) + ); + } +} +``` + +**Two key technical points here:** + +1. **Addressing:** The `a_smem_addr` logic precisely maps PTX `.x4` thread roles to the XOR-swizzled layout. Using `cvta.to.shared.u32` ensures you pass a 32-bit shared memory pointer to `ldmatrix`, averting 64-bit generic pointer overhead. +2. **Overlap:** Loading `A` outside the `j` loop anchors it. Dequantizing `B_quant` directly before the `mma.sync` instruction is what allows the compiler to pipeline the math of slice `j` with the decode of slice `j+1`. + +Look over the register allocations. If we are good, we can crack open the raw `prmt.b32` bit-magic for the INT4 to FP16 dequantization step. + +--- + +### ChatGPT: +I’m checking the PTX operand mapping before I bless this, because `ldmatrix` lane addressing is where these kernels usually go subtly wrong. + +--- + +This is **close**, but I would not move to the `prmt` decode yet. +There are **2 correctness issues** and **1 performance trap** to fix first. + +## What is correct + +Your **A-side ldmatrix grouping** is consistent with PTX’s `.x4` contract: the four 8-thread groups provide the row addresses for the four 8x8 matrices in a 16x16 fragment, i.e. lanes `0–7`, `8–15`, `16–23`, and `24–31` supply `addr0–7`, `addr8–15`, `addr16–23`, and `addr24–31` respectively. So the high-level idea of: + +- `row = row_base + lane % 16` +- `k_local = (lane / 16) * 8` +is aligned with the required quadrant split for a 16x16 A fragment. PTX also explicitly says `ldmatrix` is warp-wide and the `.x4` form loads the third and fourth matrices into the subsequent destination registers. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +Your synchronization story is also basically right now **if** you keep the missing post-MMA barrier: + +- `cp.async.wait_group(...)` +- `__syncthreads()` before transform +- optional `__syncthreads()` after transform +- **required** `__syncthreads()` after the MMA loop before ring-buffer reuse + +That last barrier is required because `cp.async.wait_group` only makes the async copies visible to the **executing thread**, and PTX says it does **not** provide ordering/visibility guarantees for other memory operations. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +## Must-fix #1: your mma.sync accumulator constraints are wrong + +This is the biggest inline-PTX bug in the snippet. + +You currently use: + +```C++: "=r"(RC[j][0]), ......: ..., "r"(RC[j][0]), ... +C++: "=r"(RC[j][0]), ......: ..., "r"(RC[j][0]), ... +``` + +That is not the right way to express an in-place accumulator. PTX explicitly says `D` and `C` may refer to the same matrix, so the clean inline-asm form is to use **read-write outputs** for `RC` and then reference those same operands as both `D` and `C` in the template. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +Use this pattern instead: + +```C++asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1, %2, %3}, " "{%4, %5, %6, %7}, " "{%8, %9}, " "{%0, %1, %2, %3};\n" : "+r"(RC[j][0]), "+r"(RC[j][1]), "+r"(RC[j][2]), "+r"(RC[j][3]) : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), "r"(RB_fp16[0]), "r"(RB_fp16[1])); +C++asm volatile( "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " "{%0, %1, %2, %3}, " "{%4, %5, %6, %7}, " "{%8, %9}, " "{%0, %1, %2, %3};\n" : "+r"(RC[j][0]), "+r"(RC[j][1]), "+r"(RC[j][2]), "+r"(RC[j][3]) : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), "r"(RB_fp16[0]), "r"(RB_fp16[1])); +``` + +That is the safe version. + +## Must-fix #2: your B-fragment sizing is inconsistent + +For `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16`, the PTX operand counts are: + +- `A`: **4** `.f16x2` regs +- `B`: **2** `.f16x2` regs +- `C/D`: **4** accumulator regs + +That means each lane contributes **4 FP16 values** for `B`, not 8. The PTX register count is not ambiguous here. [Stack Overflow](https://stackoverflow.com/questions/70683955/understanding-the-parameters-of-ptx-instruction-mma-sync-aligned-m16n8k16-row-co) + +So this part of your code is mathematically inconsistent: + +```C++uint32_t RB_quant[1]; // 8 INT4 weights = 4 bytes = 1x 32-bit registeruint32_t RB_fp16[2]; // Dequantized to 8x FP16 = 2x 32-bit registers +C++uint32_t RB_quant[1]; // 8 INT4 weights = 4 bytes = 1x 32-bit registeruint32_t RB_fp16[2]; // Dequantized to 8x FP16 = 2x 32-bit registers +``` + +Two 32-bit regs can only hold **4 half values**, not 8. + +So you need one of these two fixes: + +### Option A: per-lane load only the 4 INT4s actually needed + +Then decode those into the two `RB_fp16` registers. + +### Option B: overfetch 8 INT4s intentionally, but only select the 4 values needed for the current lane’s B fragment + +That can be valid, but then your comments and mapping need to say that explicitly. + +Right now the code/comments imply “8 int4 -> 8 fp16 -> 2 regs,” which is impossible. + +## Likely-wrong area: B shared-memory addressing + +This line is the one I would not trust yet: + +```C++int b_row = ks * 16 + (lane % 16);int b_byte_col = (global_col / 2) + (lane / 16) * 4; +C++int b_row = ks * 16 + (lane % 16);int b_byte_col = (global_col / 2) + (lane / 16) * 4; +``` + +It is too coarse for a real `m16n8k16` B-fragment mapping. + +Why I’m skeptical: + +- `lane / 16` only gives you **2 groups** +- but the warp-level B-fragment layout for `m16n8k16` is more structured than a simple upper/lower half split +- and your current mapping does not show how the lane-local 4 FP16 B values are selected consistently across the 16x8 fragment + +In other words: the **A-side map is believable**, the **B-side map is not finished**. + +My recommendation is to derive the exact lane→B mapping from a known-good `m16n8k16` warp fragment layout before you write the `prmt` path. Until that is nailed down, the decode work is premature. + +## Performance trap: the decode/MMA overlap is not guaranteed yet + +This sentence is slightly too optimistic: + +Dequantizing `B_quant` directly before the `mma.sync` instruction is what allows the compiler to pipeline the math of slice `j` with the decode of slice `j+1`. + +Maybe. Not guaranteed. + +As written, each `j` iteration is: + +1. load packed B +2. decode +3. scale +4. `mma.sync` + +The compiler may overlap some of it, but if you want the overlap to be **structural**, use a 2-buffer register pipeline: + +```C++uint32_t RB_q_cur, RB_q_nxt;uint32_t RB_cur[2], RB_nxt[2];// preload j = 0load+decode B(0) -> RB_curfor (int j = 0; j < 8; ++j) { if (j + 1 < 8) { load+decode B(j+1) -> RB_nxt } mma.sync(... RB_cur ...) if (j + 1 < 8) { RB_cur[0] = RB_nxt[0]; RB_cur[1] = RB_nxt[1]; }} +C++uint32_t RB_q_cur, RB_q_nxt;uint32_t RB_cur[2], RB_nxt[2];// preload j = 0load+decode B(0) -> RB_curfor (int j = 0; j < 8; ++j) { if (j + 1 < 8) { load+decode B(j+1) -> RB_nxt } mma.sync(... RB_cur ...) if (j + 1 < 8) { RB_cur[0] = RB_nxt[0]; RB_cur[1] = RB_nxt[1]; }} +``` + +That makes the intended overlap explicit instead of leaving it to compiler luck. + +## Register pressure assessment + +The accumulator footprint is fine as a starting point: + +- `RC[8][4]` = **32 regs** +- `RA[4]` = **4 regs** +- `RB_fp16[2]` = **2 regs** + +So the visible math state is only **38 regs** before pointers, scales, decode temporaries, loop vars, and compiler-generated spill buffers. That is not scary yet for a 256-thread CTA on Ampere. + +The real register danger is not `RC`; it is: + +- scalar scale loads (`scale_0..scale_3`) +- nibble unpack temps +- `prmt` scratch +- address arithmetic scratch + +So when you implement the decode, do **not** keep four scalar `half scale_*` values. Pack them as two `half2`-style 32-bit regs instead. + +Something like: + +```C++uint32_t scale01 = *reinterpret_cast(&S_smem[global_col + scale_idx0]);uint32_t scale23 = *reinterpret_cast(&S_smem[global_col + scale_idx2]); +C++uint32_t scale01 = *reinterpret_cast(&S_smem[global_col + scale_idx0]);uint32_t scale23 = *reinterpret_cast(&S_smem[global_col + scale_idx2]); +``` + +That keeps the B-side shape aligned with the 2-reg MMA operand. + +## Verdict + +I would mark it like this: + +- **A-side ldmatrix addressing:** plausible / basically good +- **synchronization (with Barrier C):** good +- **accumulator asm constraints:** must fix now +- **B fragment sizing/comments:** definitely wrong as written +- **B addressing:** not trustworthy yet +- **decode/MMA overlap:** needs explicit double-buffering if you want to rely on it + +So no, I would **not** open the raw `prmt.b32` bit-magic yet. +First fix the `mma` accumulator constraints and re-derive the exact per-lane `B` fragment map for `m16n8k16`. + +Paste the B-fragment lane map next, or I can draft the exact **lane→4-half B operand table** for `m16n8k16.row.col`. + +--- + +### User: +Good catches across the board. The in-place accumulator asm syntax was a stupid oversight, and I totally ghosted the B-fragment size—two 32-bit regs hold exactly 4 FP16s, end of story. Leaving the decode/MMA overlap to compiler luck is also a bad bet on Ampere. + +Let's lock in the exact B-fragment lane map for `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16` before we touch the decode sequence. + +### The `m16n8k16.row.col` B-Fragment Map + +Since B is `.col` (column-major) with dimensions K=16, N=8: +* The 32 threads in a warp must supply 16 x 8 = 128 elements. +* 128 elements / 32 threads = **4 elements per thread** (which perfectly fits into our two `.f16x2` registers). + +Per the PTX ISA, the distribution of the 16x8 B matrix across the 32 lanes is: + +* **Column Index:** `n_col = lane % 8` +* **Row Index Base:** `k_base = (lane / 8) * 2` + +Each `lane` provides the following 4 elements from the `n_col` column: +1. `B[k_base, n_col]` +2. `B[k_base + 1, n_col]` +3. `B[k_base + 8, n_col]` +4. `B[k_base + 9, n_col]` + +### Implications for the Decode Pipeline + +This rigid mapping gives us two massive advantages for the INT4 setup: + +1. **Single Scale per Thread:** Because all 4 of a thread's elements land in the *exact same output column* (`n_col`), a thread only needs to load **one** FP16 scale for its entire B fragment. We drop the register pressure from 4 scales down to 1. +2. **16-bit Payload:** 4 INT4s = 16 bits. A single 32-bit `ld.shared.u32` intentionally overfetches. Marlin's offline packing specifically interleaves these exact four `k` indices so they sit contiguously in a 16-bit chunk. + +### Double-Buffered Pipeline Skeleton + +With the mapping and accumulator constraints fixed, here is the raw structural loop with explicit register double-buffering so the pipeline is guaranteed: + +``` +cpp +// +r constraint for in-place RC accumulators +#define MMA_SYNC_M16N8K16(RC, RA, RB) \ + asm volatile( \ + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "{%8, %9}, " \ + "{%0, %1, %2, %3};\n" \ + : "+r"(RC[0]), "+r"(RC[1]), "+r"(RC[2]), "+r"(RC[3]) \ + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), \ + "r"(RB[0]), "r"(RB[1]) \ + ) + +// --- Pre-loop setup --- +int n_col = lane % 8; +int k_base = (lane / 8) * 2; // 0, 2, 4, 6 + +// Double-buffer registers +uint32_t RB_fp16_cur[2]; +uint32_t RB_fp16_nxt[2]; + +#pragma unroll +for (int ks = 0; ks < 4; ++ks) { + uint32_t RA[4]; + // [ldmatrix.sync.aligned.m8n8.x4.shared.b16 -> RA] + + // Preload & Decode j = 0 + int global_col_0 = col_base + n_col; + uint32_t scale_32_0 = /* load scale as 32-bit / half2 if possible, or single half */; + uint32_t b_smem_addr_0 = /* derived from offline packed layout */; + + // [Execute INT4 -> FP16 decode into RB_fp16_cur] + + #pragma unroll + for (int j = 0; j < 8; ++j) { + if (j < 7) { + // Load & Decode j + 1 into nxt + int global_col_nxt = col_base + (j + 1) * 8 + n_col; + uint32_t scale_32_nxt = /* ... */; + uint32_t b_smem_addr_nxt = /* ... */; + + // [Execute INT4 -> FP16 decode into RB_fp16_nxt] + } + + // Issue math for current j slice + MMA_SYNC_M16N8K16(RC[j], RA, RB_fp16_cur); + + if (j < 7) { + // Swap buffers + RB_fp16_cur[0] = RB_fp16_nxt[0]; + RB_fp16_cur[1] = RB_fp16_nxt[1]; + } + } +} +``` + +If you agree with this lane map and double-buffered skeleton, we can finally crack open the `prmt.b32` bit-magic to map those 16 bits of INT4 payload into the `RB_fp16` registers. + +--- + +### ChatGPT: +I’m verifying the B-fragment lane mapping against PTX references before I endorse it, because this is exactly the kind of detail that can look right and still be wrong. + +--- + +The **B-fragment map you wrote is not correct**, but it is very close in spirit. + +For `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16`, PTX gives the exact per-lane mapping for multiplicand **B** as: + +- `groupID = %laneid >> 2` +- `threadID_in_group = %laneid % 4` +- for fragment elements `b_i`: + +`row = (threadID_in_group * 2) + (i & 0x1)` for `i < 2` + +`row = (threadID_in_group * 2) + (i & 0x1) + 8` for `i >= 2` +- `row = (threadID_in_group * 2) + (i & 0x1)` for `i < 2` +- `row = (threadID_in_group * 2) + (i & 0x1) + 8` for `i >= 2` +- `col = groupID` [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +So the exact per-lane B fragment is: + +```C++int groupID = lane >> 2; // 0..7int tid4 = lane & 3; // 0..3// For current 16x8 B fragment:int n_col = groupID;b0 = B[2 * tid4 + 0, n_col];b1 = B[2 * tid4 + 1, n_col];b2 = B[2 * tid4 + 8, n_col];b3 = B[2 * tid4 + 9, n_col]; +C++int groupID = lane >> 2; // 0..7int tid4 = lane & 3; // 0..3// For current 16x8 B fragment:int n_col = groupID;b0 = B[2 * tid4 + 0, n_col];b1 = B[2 * tid4 + 1, n_col];b2 = B[2 * tid4 + 8, n_col];b3 = B[2 * tid4 + 9, n_col]; +``` + +That means your proposed mapping: + +```C++n_col = lane % 8;k_base = (lane / 8) * 2; +C++n_col = lane % 8;k_base = (lane / 8) * 2; +``` + +is wrong. +The correct mapping is: + +```C++n_col = lane >> 2;k_base = (lane & 3) * 2; +C++n_col = lane >> 2;k_base = (lane & 3) * 2; +``` + +and the four rows are: + +```C++{k_base + 0, k_base + 1, k_base + 8, k_base + 9} +C++{k_base + 0, k_base + 1, k_base + 8, k_base + 9} +``` + +That comes straight from the PTX fragment definition for matrix B. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +Your **single-scale-per-thread** conclusion is correct, though. Since all four `b_i` for one lane land in the same output column `col = groupID`, one per-column scale is enough for the whole lane fragment. That follows directly from the PTX B-fragment formula above. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +## Corrected structural skeleton + +This is the version I would lock in before touching `prmt`: + +```C++// B fragment mapping from PTX for m16n8k16.row.col.f16.f16.f16.f16int groupID = lane >> 2; // output column within current n8 sliceint tid4 = lane & 3; // row-pair selector#define MMA_SYNC_M16N8K16(RC, RA, RB) \ asm volatile( \ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ "{%0, %1, %2, %3}, " \ "{%4, %5, %6, %7}, " \ "{%8, %9}, " \ "{%0, %1, %2, %3};\n" \ : "+r"((RC)[0]), "+r"((RC)[1]), "+r"((RC)[2]), "+r"((RC)[3]) \ : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ "r"((RB)[0]), "r"((RB)[1]) \ )uint32_t RC[8][4];#pragma unrollfor (int j = 0; j < 8; ++j) { RC[j][0] = 0; RC[j][1] = 0; RC[j][2] = 0; RC[j][3] = 0;}uint32_t RB_fp16_cur[2];uint32_t RB_fp16_nxt[2];#pragma unrollfor (int ks = 0; ks < 4; ++ks) { uint32_t RA[4]; // A fragment load stays as previously discussed via ldmatrix.x4 // ... ldmatrix.sync.aligned.m8n8.x4.shared.b16 -> RA ... // preload j = 0 { int frag_col = 0 + groupID; // j = 0 slice int global_col = col_base + frag_col; // PTX row pattern for B: int r0 = ks * 16 + (2 * tid4 + 0); int r1 = ks * 16 + (2 * tid4 + 1); int r2 = ks * 16 + (2 * tid4 + 8); int r3 = ks * 16 + (2 * tid4 + 9); // one scale per thread, because all four values are from one column half scale = S_smem[global_col]; // shared-memory address depends on offline packing // best validation layout is one 16-bit payload per lane: // packed as [ks][j][lane] -> uint16_t // payload contains 4 int4 values: b0,b1,b2,b3 // // uint16_t packed = ... // decode packed + scale -> RB_fp16_cur[0], RB_fp16_cur[1] } #pragma unroll for (int j = 0; j < 8; ++j) { if (j < 7) { int frag_col_nxt = (j + 1) * 8 + groupID; int global_col_nxt = col_base + frag_col_nxt; int r0 = ks * 16 + (2 * tid4 + 0); int r1 = ks * 16 + (2 * tid4 + 1); int r2 = ks * 16 + (2 * tid4 + 8); int r3 = ks * 16 + (2 * tid4 + 9); half scale_nxt = S_smem[global_col_nxt]; // uint16_t packed_nxt = ... // decode packed_nxt + scale_nxt -> RB_fp16_nxt[0], RB_fp16_nxt[1] } MMA_SYNC_M16N8K16(RC[j], RA, RB_fp16_cur); if (j < 7) { RB_fp16_cur[0] = RB_fp16_nxt[0]; RB_fp16_cur[1] = RB_fp16_nxt[1]; } }} +C++// B fragment mapping from PTX for m16n8k16.row.col.f16.f16.f16.f16int groupID = lane >> 2; // output column within current n8 sliceint tid4 = lane & 3; // row-pair selector#define MMA_SYNC_M16N8K16(RC, RA, RB) \ asm volatile( \ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ "{%0, %1, %2, %3}, " \ "{%4, %5, %6, %7}, " \ "{%8, %9}, " \ "{%0, %1, %2, %3};\n" \ : "+r"((RC)[0]), "+r"((RC)[1]), "+r"((RC)[2]), "+r"((RC)[3]) \ : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ "r"((RB)[0]), "r"((RB)[1]) \ )uint32_t RC[8][4];#pragma unrollfor (int j = 0; j < 8; ++j) { RC[j][0] = 0; RC[j][1] = 0; RC[j][2] = 0; RC[j][3] = 0;}uint32_t RB_fp16_cur[2];uint32_t RB_fp16_nxt[2];#pragma unrollfor (int ks = 0; ks < 4; ++ks) { uint32_t RA[4]; // A fragment load stays as previously discussed via ldmatrix.x4 // ... ldmatrix.sync.aligned.m8n8.x4.shared.b16 -> RA ... // preload j = 0 { int frag_col = 0 + groupID; // j = 0 slice int global_col = col_base + frag_col; // PTX row pattern for B: int r0 = ks * 16 + (2 * tid4 + 0); int r1 = ks * 16 + (2 * tid4 + 1); int r2 = ks * 16 + (2 * tid4 + 8); int r3 = ks * 16 + (2 * tid4 + 9); // one scale per thread, because all four values are from one column half scale = S_smem[global_col]; // shared-memory address depends on offline packing // best validation layout is one 16-bit payload per lane: // packed as [ks][j][lane] -> uint16_t // payload contains 4 int4 values: b0,b1,b2,b3 // // uint16_t packed = ... // decode packed + scale -> RB_fp16_cur[0], RB_fp16_cur[1] } #pragma unroll for (int j = 0; j < 8; ++j) { if (j < 7) { int frag_col_nxt = (j + 1) * 8 + groupID; int global_col_nxt = col_base + frag_col_nxt; int r0 = ks * 16 + (2 * tid4 + 0); int r1 = ks * 16 + (2 * tid4 + 1); int r2 = ks * 16 + (2 * tid4 + 8); int r3 = ks * 16 + (2 * tid4 + 9); half scale_nxt = S_smem[global_col_nxt]; // uint16_t packed_nxt = ... // decode packed_nxt + scale_nxt -> RB_fp16_nxt[0], RB_fp16_nxt[1] } MMA_SYNC_M16N8K16(RC[j], RA, RB_fp16_cur); if (j < 7) { RB_fp16_cur[0] = RB_fp16_nxt[0]; RB_fp16_cur[1] = RB_fp16_nxt[1]; } }} +``` + +## The shared-memory layout I would use for B first + +Do **not** start with the 32-bit overfetch trick. +Start with the layout that exactly matches PTX lane ownership: + +```C++// logical validation layout// [ks][j][lane] -> one 16-bit payload containing b0,b1,b2,b3uint16_t Bfrag[4][8][32]; +C++// logical validation layout// [ks][j][lane] -> one 16-bit payload containing b0,b1,b2,b3uint16_t Bfrag[4][8][32]; +``` + +Then the lane-local address is trivial: + +```C++uint16_t packed = Bfrag[ks][j][lane]; +C++uint16_t packed = Bfrag[ks][j][lane]; +``` + +That gives you the exact four INT4 values required by PTX: + +- `b0` +- `b1` +- `b2` +- `b3` + +with zero ambiguity. + +Once that works, you can repack to 32-bit lane pairs for bandwidth/alignment tuning. + +## One more important correction + +Your earlier comment said: + +“Marlin’s offline packing specifically interleaves these exact four k indices so they sit contiguously in a 16-bit chunk.” + +That may be a reasonable design choice for your kernel, but it is **not** something PTX guarantees. PTX only gives the fragment ownership map; the “contiguous 16-bit packed payload per lane” is your **chosen offline layout**, not an ISA fact. The ISA fact is only the lane→`(row,col)` mapping above. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/) + +## Verdict + +So: + +- your **double-buffered skeleton** is good, +- your **single-scale-per-thread insight** is good, +- your **B lane map is wrong as written**, +- and the corrected map is: + +```C++groupID = lane >> 2;tid4 = lane & 3;rows = {2*tid4 + 0, 2*tid4 + 1, 2*tid4 + 8, 2*tid4 + 9};col = groupID; +C++groupID = lane >> 2;tid4 = lane & 3;rows = {2*tid4 + 0, 2*tid4 + 1, 2*tid4 + 8, 2*tid4 + 9};col = groupID; +``` + +for each `j`, with `col_base + j*8 + groupID` as the global column. + +If you want, the next move is the exact **INT4 packed nibble order** for `uint16_t packed -> RB_fp16[2]`, including the `prmt` / unpack path. + +--- + +### User: +You are exactly right. I fucked up the `m16n8k16` column layout map. `lane >> 2` is the PTX standard, and your correction is spot-on. I also agree with your `uint16_t Bfrag[4][8][32]` starting point. Building the logical validation layout first is the bulletproof way to isolate lane ownership and prove the math before you start optimizing the shared memory fetches. + +Here is the exact INT4 packed nibble order and the raw decode path using the magic FP16 conversion trick (which avoids slow integer-to-float `cvt` instructions). + +### The Offline Packing Contract +To make the inline PTX beautiful and fast, your offline preprocessor should pack the four values for a lane into a single `uint16_t` exactly like this: +* **Bits 0-3:** `b0` (row `2*tid4 + 0`) +* **Bits 4-7:** `b1` (row `2*tid4 + 1`) +* **Bits 8-11:** `b2` (row `2*tid4 + 8`) +* **Bits 12-15:** `b3` (row `2*tid4 + 9`) + +### The Decode Strategy +Ampere has a known trick for converting 4-bit integers to FP16. +An FP16 number with an exponent of 25 (hex `0x6400`) represents the float `1024.0`. If you bitwise OR a 4-bit integer into the bottom of its mantissa, the resulting float is `1024.0 + nibble`. + +We extract the nibbles, inject them into `0x64006400` `half2` vectors, and then use `fma.rn.f16x2` to subtract the 1024.0 baseline and apply the column scale simultaneously. + +### The Code + +``` +cpp +// --- Outside the j-loop: Precompute the scale and offset --- +// Assuming symmetric INT4 with a zero-point of 8, the nibbles range from 0 to 15. +// We need to subtract (1024.0 + 8.0) = 1032.0 to center the values at 0. +// 1032.0 in FP16 is 0x6408. +// We want to compute: (magic_float - 1032.0) * scale +// Which is: magic_float * scale - 1032.0 * scale +// So we precalculate the FMA 'C' operand as (-1032.0 * scale). + +half scale = S_smem[global_col]; +uint32_t s2_reg; // half2 scale +uint32_t c2_reg; // half2 FMA offset + +{ + half2 scale2 = __halves2half2(scale, scale); + half2 offset_base = __halves2half2(__float2half(-1032.0f), __float2half(-1032.0f)); + half2 c2 = __hmul2(scale2, offset_base); + + s2_reg = *(uint32_t*)&scale2; + c2_reg = *(uint32_t*)&c2; +} + +// --- Inside the j-loop: Decode current packed payload --- +uint32_t packed_32 = Bfrag[ks][j][lane]; + +// 1. Isolate the byte pairs into separate registers +uint32_t w0 = packed_32 & 0x00FF; // Contains [b1, b0] +uint32_t w1 = (packed_32 >> 8) & 0x00FF; // Contains [b3, b2] + +// 2. Space the nibbles out to match the lower and upper halves of a half2 +// The lower nibble stays at bits 0-3. +// The upper nibble is masked (0xF0) and shifted left by 12 bits to land at bits 16-19. +uint32_t h0 = (w0 & 0x0F) | ((w0 & 0xF0) << 12); +uint32_t h1 = (w1 & 0x0F) | ((w1 & 0xF0) << 12); + +// 3. Inject the magic FP16 exponent (0x6400 per half -> 0x64006400 for half2) +uint32_t magic = 0x64006400; +uint32_t f16_01 = h0 | magic; +uint32_t f16_23 = h1 | magic; + +// 4. Issue the FMA to finish the conversion and apply the scale +asm volatile( + "fma.rn.f16x2 %0, %1, %2, %3;" + : "=r"(RB_fp16_cur[0]) + : "r"(f16_01), "r"(s2_reg), "r"(c2_reg) +); + +asm volatile( + "fma.rn.f16x2 %0, %1, %2, %3;" + : "=r"(RB_fp16_cur[1]) + : "r"(f16_23), "r"(s2_reg), "r"(c2_reg) +); +``` + +### Why this is the way +* **No `prmt.b32` required for this layout:** Because you packed them 16-bits contiguous, simple bitwise ANDs and a shift achieve the exact alignment you need. `prmt` is only strictly necessary if your offline packing scatters the nibbles across a 32-bit boundary. +* **Math Overlap:** The `fma.rn.f16x2` handles the integer-to-float offset and the scaling in a single hardware step, and those FMAs drop right into the pipe before the `mma.sync`. +* **Register Footprint:** Only 2 scaling registers (`s2_reg`, `c2_reg`) needed per column slice, easily absorbing the decode without spilling. + +Plug this into the double-buffered skeleton you wrote. If the validation matches, the only thing left is ripping out the 3D array layout and writing the raw `ld.shared.u32` trick to fetch 32-bits (two lanes worth) at a time for bandwidth. + +--- + +### ChatGPT: +I’m checking the PTX half2/FMA path against the ISA details before I bless the decode sequence, because the “magic FP16” trick is easy to get subtly wrong. + +--- + +This decode path is **basically right**, and the core trick is sound: + +- the corrected `m16n8k16.row.col` B-fragment ownership is the PTX one with `groupID = lane >> 2` and `threadID_in_group = lane % 4`, so each lane owns **4 values from one column**; ([PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/?utm_source=chatgpt.com)) +- `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16` takes **two .f16x2 registers** for B, so your goal is exactly to produce 4 FP16 values packed into `RB_fp16_cur[0:2]`; ([PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/?utm_source=chatgpt.com)) +- `fma.rn.f16x2` is the right PTX instruction to do a packed half2 multiply-add in one step. ([PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/?utm_source=chatgpt.com)) + +There are **3 fixes** I’d make before calling it locked. + +## 1) Use uint16_t in the validation layout, not uint32_t + +Your text says the validation layout is: + +```C++uint16_t Bfrag[4][8][32]; +C++uint16_t Bfrag[4][8][32]; +``` + +but then the code does: + +```C++uint32_t packed_32 = Bfrag[ks][j][lane]; +C++uint32_t packed_32 = Bfrag[ks][j][lane]; +``` + +That is inconsistent. + +For the **logical validation layout**, keep it literal: + +```C++uint16_t packed_16 = Bfrag[ks][j][lane];uint32_t packed_32 = packed_16; +C++uint16_t packed_16 = Bfrag[ks][j][lane];uint32_t packed_32 = packed_16; +``` + +Only switch to 32-bit overfetch when you deliberately repack the shared-memory layout. Right now, mixing the two mental models will make debugging miserable. + +## 2) The magic decode is correct, but the comment should be tighter + +For symmetric signed INT4 reconstructed from an unsigned nibble `q in [0,15]`, you want: + +```C++deq = (q - 8) * scale +C++deq = (q - 8) * scale +``` + +Your trick builds: + +- `magic_half = 1024 + q` +- then computes: +- `(magic_half * scale) + (-1032 * scale) = (q - 8) * scale` + +That is correct. + +So the final packed-half decode path is: + +```C++uint16_t packed_16 = Bfrag[ks][j][lane];uint32_t packed_32 = packed_16;// [b1,b0] in low byte, [b3,b2] in high byteuint32_t w0 = packed_32 & 0x00FFu; // low byte: b1:b0uint32_t w1 = (packed_32 >> 8) & 0x00FFu; // high byte: b3:b2// place nibble 0 into low half mantissa bits 0..3// place nibble 1 into upper half mantissa bits 16..19uint32_t h0 = (w0 & 0x0Fu) | ((w0 & 0xF0u) << 12);uint32_t h1 = (w1 & 0x0Fu) | ((w1 & 0xF0u) << 12);// inject FP16 exponent for 1024.0 into both halvesconstexpr uint32_t MAGIC = 0x64006400u;uint32_t f16_01 = h0 | MAGIC; // {1024+b1, 1024+b0} in half2 formuint32_t f16_23 = h1 | MAGIC; // {1024+b3, 1024+b2} in half2 form +C++uint16_t packed_16 = Bfrag[ks][j][lane];uint32_t packed_32 = packed_16;// [b1,b0] in low byte, [b3,b2] in high byteuint32_t w0 = packed_32 & 0x00FFu; // low byte: b1:b0uint32_t w1 = (packed_32 >> 8) & 0x00FFu; // high byte: b3:b2// place nibble 0 into low half mantissa bits 0..3// place nibble 1 into upper half mantissa bits 16..19uint32_t h0 = (w0 & 0x0Fu) | ((w0 & 0xF0u) << 12);uint32_t h1 = (w1 & 0x0Fu) | ((w1 & 0xF0u) << 12);// inject FP16 exponent for 1024.0 into both halvesconstexpr uint32_t MAGIC = 0x64006400u;uint32_t f16_01 = h0 | MAGIC; // {1024+b1, 1024+b0} in half2 formuint32_t f16_23 = h1 | MAGIC; // {1024+b3, 1024+b2} in half2 form +``` + +Then: + +```C++asm volatile( "fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[0]) : "r"(f16_01), "r"(s2_reg), "r"(c2_reg));asm volatile( "fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[1]) : "r"(f16_23), "r"(s2_reg), "r"(c2_reg)); +C++asm volatile( "fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[0]) : "r"(f16_01), "r"(s2_reg), "r"(c2_reg));asm volatile( "fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[1]) : "r"(f16_23), "r"(s2_reg), "r"(c2_reg)); +``` + +That part is fine. + +## 3) Be explicit about element order inside RB_fp16 + +This is the one place I would force a comment now, because it will matter when you validate against a reference GEMM. + +With the packing you chose, the two half2 regs hold: + +- `RB_fp16_cur[0]` → `{b0, b1}` or `{b1, b0}` depending on how you interpret half2 lane order in your downstream MMA operand packing +- `RB_fp16_cur[1]` → `{b2, b3}` or `{b3, b2}` + +You should not leave that ambiguous. + +Write the validation contract explicitly in code by decoding once in scalar host-side logic and comparing lane fragments. In other words, define: + +```C++// REQUIRED VALIDATION CONTRACT:// RB_fp16_cur[0] must correspond to the first two B fragment elements// RB_fp16_cur[1] must correspond to the second two B fragment elements// in the exact order expected by mma.sync for this lane. +C++// REQUIRED VALIDATION CONTRACT:// RB_fp16_cur[0] must correspond to the first two B fragment elements// RB_fp16_cur[1] must correspond to the second two B fragment elements// in the exact order expected by mma.sync for this lane. +``` + +The PTX docs give you the **which four elements** per lane for B; they do not save you from getting the intra-register pair ordering wrong during your custom decode. ([PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/?utm_source=chatgpt.com)) + +## Register-pressure note + +Your updated register story is much better now: + +- one scale per lane fragment is correct +- `s2_reg` and `c2_reg` as packed half2 regs is the right move +- no need for 4 separate `half scale_*` + +So yes, this is now a sane decode footprint. + +## My verdict + +I’d mark it: + +- **lane ownership:** correct now +- **nibble packing contract:** good +- **magic FP16 + fma.rn.f16x2 decode:** good +- **validation layout typing (uint16_t vs uint32_t):** fix immediately +- **intra-register element order:** must be made explicit before performance tuning + +So yes — **this is good enough to plug into the double-buffered skeleton**, as long as you correct the `uint16_t`/`uint32_t` mismatch and make the final half2 element order explicit. + +The next step is exactly what you said: once this validates numerically, replace the logical `Bfrag[4][8][32]` path with the real shared-memory packed fetch path. + +--- + +### User: +You are right. Mixing `uint32_t` and `uint16_t` in the logical layout was sloppy and breaks the mental model. + +For the intra-register order, PTX `.f16x2` format expects the first element in the lower 16 bits and the second element in the upper 16 bits. With our shift logic `((w0 & 0xF0u) << 12)`, `b1` shifts into the upper half, and `b0` stays in the lower half. This correctly maps `b0` to the first row and `b1` to the second row for the `mma.sync` instruction. We will enforce this as a hard contract in the comments. + +Here is the final, locked inner loop skeleton. It integrates the corrected PTX B-fragment mapping, the exact magic decode, and explicit double-buffering. + +``` +cpp +// ------------------------------ +// Validation Layout & MMA Macros +// ------------------------------ +// Logical validation layout: [ks][j][lane] +// Contains 4 INT4 values packed as {b3, b2, b1, b0} +extern __shared__ uint16_t Bfrag[4][8][32]; + +#define MMA_SYNC_M16N8K16(RC, RA, RB) \ + asm volatile( \ + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "{%8, %9}, " \ + "{%0, %1, %2, %3};\n" \ + : "+r"((RC)[0]), "+r"((RC)[1]), "+r"((RC)[2]), "+r"((RC)[3]) \ + : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ + "r"((RB)[0]), "r"((RB)[1]) \ + ) + +// Lane ownership for m16n8k16 B-fragment +const int groupID = lane >> 2; // output column within current n8 slice +const int tid4 = lane & 3; // row-pair selector + +// Constants for INT4 -> FP16 Magic Decode +constexpr uint32_t MAGIC_FP16 = 0x64006400u; // 1024.0 in half2 +const half2 offset_base = __halves2half2(__float2half(-1032.0f), __float2half(-1032.0f)); + +// Accumulators +uint32_t RC[8][4] = {0}; +uint32_t RB_fp16_cur[2]; +uint32_t RB_fp16_nxt[2]; + +// ------------------------------ +// Inner Loop +// ------------------------------ +#pragma unroll +for (int ks = 0; ks < 4; ++ks) { + uint32_t RA[4]; + + // [ldmatrix.sync.aligned.m8n8.x4.shared.b16 -> RA] + + // --- Preload & Decode j = 0 --- + { + int frag_col = groupID; + int global_col = col_base + frag_col; + half scale = S_smem[global_col]; + + // Prepare FMA operands for decode + half2 scale2 = __halves2half2(scale, scale); + half2 c2 = __hmul2(scale2, offset_base); + uint32_t s2_reg = *(uint32_t*)&scale2; + uint32_t c2_reg = *(uint32_t*)&c2; + + uint16_t packed_16 = Bfrag[ks][0][lane]; + uint32_t packed_32 = packed_16; + + uint32_t w0 = packed_32 & 0x00FFu; // [b1, b0] + uint32_t w1 = (packed_32 >> 8) & 0x00FFu; // [b3, b2] + + uint32_t h0 = (w0 & 0x0Fu) | ((w0 & 0xF0u) << 12); + uint32_t h1 = (w1 & 0x0Fu) | ((w1 & 0xF0u) << 12); + + uint32_t f16_01 = h0 | MAGIC_FP16; + uint32_t f16_23 = h1 | MAGIC_FP16; + + // INTRA-REGISTER CONTRACT: + // RB_fp16_cur[0] -> Lower 16 bits = b0, Upper 16 bits = b1 + // RB_fp16_cur[1] -> Lower 16 bits = b2, Upper 16 bits = b3 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[0]) : "r"(f16_01), "r"(s2_reg), "r"(c2_reg)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[1]) : "r"(f16_23), "r"(s2_reg), "r"(c2_reg)); + } + + #pragma unroll + for (int j = 0; j < 8; ++j) { + if (j < 7) { + // --- Load & Decode j + 1 --- + int frag_col_nxt = (j + 1) * 8 + groupID; + int global_col_nxt = col_base + frag_col_nxt; + half scale_nxt = S_smem[global_col_nxt]; + + half2 scale2_nxt = __halves2half2(scale_nxt, scale_nxt); + half2 c2_nxt = __hmul2(scale2_nxt, offset_base); + uint32_t s2_reg_nxt = *(uint32_t*)&scale2_nxt; + uint32_t c2_reg_nxt = *(uint32_t*)&c2_nxt; + + uint16_t packed_16_nxt = Bfrag[ks][j + 1][lane]; + uint32_t packed_32_nxt = packed_16_nxt; + + uint32_t w0_nxt = packed_32_nxt & 0x00FFu; + uint32_t w1_nxt = (packed_32_nxt >> 8) & 0x00FFu; + + uint32_t h0_nxt = (w0_nxt & 0x0Fu) | ((w0_nxt & 0xF0u) << 12); + uint32_t h1_nxt = (w1_nxt & 0x0Fu) | ((w1_nxt & 0xF0u) << 12); + + uint32_t f16_01_nxt = h0_nxt | MAGIC_FP16; + uint32_t f16_23_nxt = h1_nxt | MAGIC_FP16; + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_nxt[0]) : "r"(f16_01_nxt), "r"(s2_reg_nxt), "r"(c2_reg_nxt)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_nxt[1]) : "r"(f16_23_nxt), "r"(s2_reg_nxt), "r"(c2_reg_nxt)); + } + + MMA_SYNC_M16N8K16(RC[j], RA, RB_fp16_cur); + + if (j < 7) { + RB_fp16_cur[0] = RB_fp16_nxt[0]; + RB_fp16_cur[1] = RB_fp16_nxt[1]; + } + } +} +``` + +The mathematical foundation is completely airtight now. Once you run this block through the validation harness and confirm numerical parity against your reference GEMM, we rip out `Bfrag[4][8][32]` and write the physical `ld.shared.u32` to pipeline the shared memory bandwidth. + +Let me know when the harness goes green. + +--- + +### ChatGPT: +I’m sanity-checking the fragment ordering against the PTX docs before I sign off, because this is one of the few spots where a tiny register-order mistake can make the whole kernel numerically wrong. + +--- + +Not green yet — there is still **one major PTX mismatch** in the skeleton. + +## The blocker + +For: + +```C++mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 +C++mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 +``` + +PTX says: + +- **B** is **two .f16x2 registers** holding `b0, b1, b2, b3` low-to-high. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) +- **C/D** for `.f16` are **two .f16x2 registers**, not four. The PTX page’s prose has a typo in the element list, but the fragment definition and row/col formula make it clear this is a 4-element accumulator fragment carried in **2 regs**. [NVIDIA Docs+2NVIDIA Docs+2](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + +So this part of your code is still wrong: + +```C++uint32_t RC[8][4];..."{%0, %1, %2, %3}, ... {%0, %1, %2, %3}" +C++uint32_t RC[8][4];..."{%0, %1, %2, %3}, ... {%0, %1, %2, %3}" +``` + +For the **FP16 accumulator variant**, each `m16n8k16` slice should use **2 accumulator regs**, not 4. + +## What to change + +### Correct macro + +```C++#define MMA_SYNC_M16N8K16_F16(RC, RA, RB) \ asm volatile( \ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ "{%0, %1}, " \ "{%2, %3, %4, %5}, " \ "{%6, %7}, " \ "{%0, %1};\n" \ : "+r"((RC)[0]), "+r"((RC)[1]) \ : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ "r"((RB)[0]), "r"((RB)[1]) \ ) +C++#define MMA_SYNC_M16N8K16_F16(RC, RA, RB) \ asm volatile( \ "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ "{%0, %1}, " \ "{%2, %3, %4, %5}, " \ "{%6, %7}, " \ "{%0, %1};\n" \ : "+r"((RC)[0]), "+r"((RC)[1]) \ : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ "r"((RB)[0]), "r"((RB)[1]) \ ) +``` + +### Correct accumulator storage + +```C++uint32_t RC[8][2];#pragma unrollfor (int j = 0; j < 8; ++j) { RC[j][0] = 0; RC[j][1] = 0;} +C++uint32_t RC[8][2];#pragma unrollfor (int j = 0; j < 8; ++j) { RC[j][0] = 0; RC[j][1] = 0;} +``` + +That is the biggest fix. + +## Your B-fragment logic is now correct + +This part is good: + +- `groupID = lane >> 2` +- `tid4 = lane & 3` +- rows owned by one lane are: + +`2*tid4 + 0` + +`2*tid4 + 1` + +`2*tid4 + 8` + +`2*tid4 + 9` +- `2*tid4 + 0` +- `2*tid4 + 1` +- `2*tid4 + 8` +- `2*tid4 + 9` +- all 4 belong to the same output column `groupID` + +That matches PTX’s B-fragment formula for `m16n8k16` FP16/BF16 exactly. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + +So your `uint16_t` packing contract: + +- bits `0..3` = `b0` +- bits `4..7` = `b1` +- bits `8..11` = `b2` +- bits `12..15` = `b3` + +is a good validation layout. + +And yes, the element order inside the two `f16x2` regs is consistent with your decode: + +- `RB_fp16_cur[0]` → low 16 = `b0`, high 16 = `b1` +- `RB_fp16_cur[1]` → low 16 = `b2`, high 16 = `b3` + +That is the right contract for the current lane-local validation path. [NVIDIA Docs](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html) + +## Two smaller fixes I would still make + +### 1) Do not declare extern __shared__ uint16_t Bfrag[4][8][32]; + +That is not how dynamic shared memory should be modeled in the real kernel. For the validation path, use pointer arithmetic on a flat `extern __shared__ uint8_t smem[]` buffer and cast the right offset: + +```C++extern __shared__ uint8_t smem[];uint16_t* Bfrag = reinterpret_cast(smem + bfrag_stage_offset); +C++extern __shared__ uint8_t smem[];uint16_t* Bfrag = reinterpret_cast(smem + bfrag_stage_offset); +``` + +Then index manually: + +```C++uint16_t packed_16 = Bfrag[(ks * 8 + j) * 32 + lane]; +C++uint16_t packed_16 = Bfrag[(ks * 8 + j) * 32 + lane]; +``` + +### 2) Avoid type-punning through references/pointers for half2 -> uint32_t + +This: + +```C++uint32_t s2_reg = *(uint32_t*)&scale2; +C++uint32_t s2_reg = *(uint32_t*)&scale2; +``` + +usually works in practice, but it is uglier than it needs to be. Use `memcpy`, a union, or inline PTX `mov.b32` if you want to keep it airtight. + +## Verdict + +After the accumulator fix, this is good enough for the validation harness. + +So the next correct version is: + +- `RC[8][2]` +- `mma.sync ... {%0,%1} ... {%0,%1}` +- keep your current `RB_fp16_cur[2]` decode exactly as-is +- keep the `uint16_t` lane-local validation layout +- move `Bfrag` to flat shared-memory indexing + +Once that passes reference checks, then it makes sense to replace the logical `Bfrag[(ks,j,lane)]` access with the real packed shared-memory fetch path. + +--- + +### User: +You are completely right. I blindly carried over the FP32 accumulator footprint (4x 32-bit regs) into the FP16 accumulator instruction. For `.f16` accumulators, 4 elements pack perfectly into two 32-bit registers. Good catch on the PTX typo vs. actual ISA behavior. + +I also agree on ripping out the 3D array declaration and the pointer punning. Unions are cleaner and safer for the 32-bit register conversion. + +Here is the final, corrected validation skeleton. + +``` +cpp +// ------------------------------ +// Validation Layout & MMA Macros +// ------------------------------ +// Flat shared memory buffer +extern __shared__ uint8_t smem[]; + +// Macro for m16n8k16 with FP16 accumulators (2 regs for C/D) +#define MMA_SYNC_M16N8K16_F16(RC, RA, RB) \ + asm volatile( \ + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ + "{%0, %1}, " \ + "{%2, %3, %4, %5}, " \ + "{%6, %7}, " \ + "{%0, %1};\n" \ + : "+r"((RC)[0]), "+r"((RC)[1]) \ + : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ + "r"((RB)[0]), "r"((RB)[1]) \ + ) + +// Lane ownership for m16n8k16 B-fragment +const int groupID = lane >> 2; // output column within current n8 slice +const int tid4 = lane & 3; // row-pair selector + +// Constants for INT4 -> FP16 Magic Decode +constexpr uint32_t MAGIC_FP16 = 0x64006400u; // 1024.0 in half2 +const half2 offset_base = __halves2half2(__float2half(-1032.0f), __float2half(-1032.0f)); + +// Accumulators (2x 32-bit regs per slice) +uint32_t RC[8][2]; +#pragma unroll +for (int j = 0; j < 8; ++j) { + RC[j][0] = 0; + RC[j][1] = 0; +} + +uint32_t RB_fp16_cur[2]; +uint32_t RB_fp16_nxt[2]; + +// Helper union for clean half2 -> uint32_t conversion +union Half2Reg { + half2 h2; + uint32_t u32; +}; + +// ------------------------------ +// Inner Loop +// ------------------------------ +// Bfrag mapped to flat smem based on stage offset +uint16_t* Bfrag = reinterpret_cast(smem + bfrag_stage_offset); + +#pragma unroll +for (int ks = 0; ks < 4; ++ks) { + uint32_t RA[4]; + + // [ldmatrix.sync.aligned.m8n8.x4.shared.b16 -> RA] + + // --- Preload & Decode j = 0 --- + { + int frag_col = groupID; + int global_col = col_base + frag_col; + half scale = S_smem[global_col]; + + // Prepare FMA operands for decode + Half2Reg s2_reg, c2_reg; + s2_reg.h2 = __halves2half2(scale, scale); + c2_reg.h2 = __hmul2(s2_reg.h2, offset_base); + + // Logical fetch from flat smem + uint32_t packed_32 = Bfrag[(ks * 8 + 0) * 32 + lane]; + + uint32_t w0 = packed_32 & 0x00FFu; // [b1, b0] + uint32_t w1 = (packed_32 >> 8) & 0x00FFu; // [b3, b2] + + uint32_t h0 = (w0 & 0x0Fu) | ((w0 & 0xF0u) << 12); + uint32_t h1 = (w1 & 0x0Fu) | ((w1 & 0xF0u) << 12); + + uint32_t f16_01 = h0 | MAGIC_FP16; + uint32_t f16_23 = h1 | MAGIC_FP16; + + // INTRA-REGISTER CONTRACT: + // RB_fp16_cur[0] -> Lower 16 bits = b0, Upper 16 bits = b1 + // RB_fp16_cur[1] -> Lower 16 bits = b2, Upper 16 bits = b3 + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[0]) : "r"(f16_01), "r"(s2_reg.u32), "r"(c2_reg.u32)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_cur[1]) : "r"(f16_23), "r"(s2_reg.u32), "r"(c2_reg.u32)); + } + + #pragma unroll + for (int j = 0; j < 8; ++j) { + if (j < 7) { + // --- Load & Decode j + 1 --- + int frag_col_nxt = (j + 1) * 8 + groupID; + int global_col_nxt = col_base + frag_col_nxt; + half scale_nxt = S_smem[global_col_nxt]; + + Half2Reg s2_reg_nxt, c2_reg_nxt; + s2_reg_nxt.h2 = __halves2half2(scale_nxt, scale_nxt); + c2_reg_nxt.h2 = __hmul2(s2_reg_nxt.h2, offset_base); + + // Logical fetch from flat smem + uint32_t packed_32_nxt = Bfrag[(ks * 8 + j + 1) * 32 + lane]; + + uint32_t w0_nxt = packed_32_nxt & 0x00FFu; + uint32_t w1_nxt = (packed_32_nxt >> 8) & 0x00FFu; + + uint32_t h0_nxt = (w0_nxt & 0x0Fu) | ((w0_nxt & 0xF0u) << 12); + uint32_t h1_nxt = (w1_nxt & 0x0Fu) | ((w1_nxt & 0xF0u) << 12); + + uint32_t f16_01_nxt = h0_nxt | MAGIC_FP16; + uint32_t f16_23_nxt = h1_nxt | MAGIC_FP16; + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_nxt[0]) : "r"(f16_01_nxt), "r"(s2_reg_nxt.u32), "r"(c2_reg_nxt.u32)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;" : "=r"(RB_fp16_nxt[1]) : "r"(f16_23_nxt), "r"(s2_reg_nxt.u32), "r"(c2_reg_nxt.u32)); + } + + MMA_SYNC_M16N8K16_F16(RC[j], RA, RB_fp16_cur); + + if (j < 7) { + RB_fp16_cur[0] = RB_fp16_nxt[0]; + RB_fp16_cur[1] = RB_fp16_nxt[1]; + } + } +} +``` + +Drop this into the harness. Once the reference GEMM confirms numerical parity, let's rip out that scalar `Bfrag` array lookup and pack it into raw 32-bit `ld.shared.u32` bandwidth optimization. + +--- + +### ChatGPT: +This is now **structurally sound** for the validation harness. + +The two big PTX-level fixes are in place: + +- the `m16n8k16.row.col.f16.f16.f16.f16` B-fragment ownership is now consistent with the warp/lane mapping from the PTX ISA, and +- the FP16 accumulator form is using **2 accumulator registers** per `m16n8k16` slice instead of 4. That matches the PTX fragment contract for the FP16 accumulator variant. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/?utm_source=chatgpt.com) + +I’d only tighten **three things** before calling it locked: + +### 1) Make the logical fetch type explicit + +This line is valid, but I would make it less ambiguous: + +```C++uint32_t packed_32 = Bfrag[(ks * 8 + 0) * 32 + lane]; +C++uint32_t packed_32 = Bfrag[(ks * 8 + 0) * 32 + lane]; +``` + +Since `Bfrag` is a `uint16_t*`, write it as: + +```C++uint16_t packed_16 = Bfrag[(ks * 8 + 0) * 32 + lane];uint32_t packed_32 = packed_16; +C++uint16_t packed_16 = Bfrag[(ks * 8 + 0) * 32 + lane];uint32_t packed_32 = packed_16; +``` + +That keeps the validation mental model exact: one lane owns one 16-bit packed fragment. + +### 2) Do not expect exact parity with an FP32 GEMM reference + +Because your MMA instruction is: + +```C++mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 +C++mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 +``` + +you are accumulating in **FP16**, not FP32. So your harness should compare against either: + +- a reference that also accumulates in FP16 in the same fragment/order model, or +- an FP32 reference with a tolerance, not bitwise equality. The PTX opcode itself tells you the accumulator/output type is FP16. [NVIDIA Docs+1](https://docs.nvidia.com/cuda/parallel-thread-execution/?utm_source=chatgpt.com) + +### 3) Enforce alignment now for the next step + +For the current logical path, 2-byte alignment is enough. +But since the next optimization step is `ld.shared.u32`, make sure `bfrag_stage_offset` is **4-byte aligned now**, so you do not have to change the shared-memory layout later. + +## My verdict + +For the **validation harness**, yes: + +- lane mapping is correct, +- nibble order is correct, +- the magic decode path is correct, +- the FP16 accumulator footprint is now correct, +- and the double-buffered structure is good. + +So the next milestone is: + +1. validate **decode only** against a scalar host/device reference for one warp fragment, +2. validate one full `ks/j` MMA step against a reference with FP16 accumulation semantics, +3. then replace: + +```C++uint16_t packed_16 = Bfrag[(ks * 8 + j) * 32 + lane]; +C++uint16_t packed_16 = Bfrag[(ks * 8 + j) * 32 + lane]; +``` + +with the real shared-memory packed fetch path. + +Once that passes, the next thing to write is the **physical ld.shared.u32 layout** for `Bfrag`, with 4-byte aligned lane-pair packing and no change to the logical lane ownership. + +--- + +### User: +provide in zip the whole project please, and the todo file + +--- + +### ChatGPT: +I’m packaging the full project into a zip and adding a separate TODO file with the CUDA and integration next steps. + +--- + +Here you go: + +The zip includes: + +- the full `quant_v5_3` Python project +- `README.md` +- `requirements.txt` +- `create_sdq_v53.sh` +- `kernels/ampere_marlin_paro_int4_validation.cu` +- the same `TODO_v5_3.md` inside the project folder + +One important note: the CUDA file in the zip is the **validation skeleton**, not a finished production kernel yet. The TODO file lists the exact next steps to take it from validated fragment math to a packed shared-memory implementation. + +--- + +### User: +whats the point of this ? + +--- + +### ChatGPT: +The point of it is **to answer one question with evidence**: + +**Can you get noticeably better quality than plain 4-bit GPTQ, while staying inside a deployment path that still resembles GPTQ/vLLM speed?** + +That project is a **test harness**, not the final weapon. + +What it gives you: + +- a **baseline**: plain GPTQ 4-bit +- a **smarter allocator**: Hessian-proxy decides which layers stay 4-bit, go 8-bit, or stay FP16 +- a **better rounding backend**: AutoRound instead of greedy GPTQ rounding +- a **matrix** to compare all of them on the same model and tasks + +So the real purpose is to separate: + +1. **quality gain from smarter allocation** +2. **quality gain from better rounding** +3. **format/runtime constraints** +4. **whether the extra complexity is actually worth it** + +Why the zip matters: + +- it saves you from rebuilding the scaffold manually +- it gives you one place to run the experiments +- it includes the CUDA validation skeleton for the future fused kernel path + +Why the TODO file matters: + +- the Python side is enough to benchmark ideas now +- the CUDA side is **not finished** +- the TODO is the bridge from “research scaffold” to “real kernel implementation” + +What it is **not**: + +- not a finished production quantizer +- not guaranteed to beat EXL3 / ParoQuant / AutoRound +- not yet a drop-in fastest inference engine + +So the short truth is: + +**The point is to de-risk the idea before you spend weeks writing kernels.** + +If the benchmark says: + +- `sdq_gptq` barely beats baseline, then stop +- `sdq_autoround` wins clearly, then invest more +- `auto_round` format wins but `auto_gptq` doesn’t, then you know exactly where the bottleneck is + +If your actual goal is just the most practical path today, then the point of this project is mostly **research and ablation**, not immediate deployment. + +--- diff --git a/README.md b/README.md index 805eb7239..cc4133a99 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,16 @@

## Latest News +* 04/15/2026 [main]: ✅Gemma 4 GPTQ package support completed end-to-end. `gptqmodel.models.definitions` now exports the full supported model-definition surface, and the Ampere CUDA path was revalidated on RTX 3090 / RTX 3060 (`sm_86`) hardware with a local Gemma 4 GPTQ model. +* 04/08/2026 [main]: ✨ All CUDA kernels are now JIT compiled. The PyPI package is now about 300x smaller. On first use, GPT-QModel compiles only the kernels your workload actually needs. Improved Bonsai kernels now support execution `profile` control for `fast` or `low_memory` inference. Model weight loading during quantization has been optimized for large models like `GLM 5.1`. Added `GLM 5` and `GLM 5.1` model support. +* 04/03/2026 [6.0.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v6.0.3): 🎉 New quantization methods: `ParoQuant`, `GGUF`, `FP8`, `EXL3`, and `FOEM: First-Order Error Matters`. Added PrismML/Bonsai 1bit model quantization (inference only), faster ParoQuant/AWQ kernels, ParoQuant `optimization scope` control: `module` (Paro Lite) or `layer` (Paro reference), plus `Gemma4`, `MiniCPM-O`, `MiniCPM-V`, and `GLM4 MOE lite` model support. +* 03/19/2026 [5.8.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.8.0): ✨HF Transformers 5.3.0 support with auto-defusing of `fused` models via pypi pkg: [Defuser](https://github.com/ModelCloud/Defuser). Qwen 3.5 family support added. New fast HF `cpu` kernels for GPTQ/AWQ added. Experimental INT8 `cpu` kernel added for GPTQ. * 03/09/2026 [main]: ✨Qwen 3.5 MoE model support added. New HF Kernel support added for AWQ. HF Kernel for both gptq/awq are now used by default for cpu devices for best performance. New INT8 kernel ported from Intel for gptq. + +
+ +Archived News * 02/28/2026 [main]: ✨Qwen 3.5 model support added. * 02/09/2026 [5.7.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.7.0): ✨New `MoE.Routing` config with `Bypass` and `Override` options to allow multiple brute-force MoE routing controls for higher quality quantization of MoE experts. Combined with `FailSafeStrategy`, GPT-QModel now has three separate control settings for efficient MoE expert quantization. `AWQ` `qcfg.zero_point` property has been merged with a unified `sym` symmetry property; `zero_point=True` is now `sym=False`. @@ -35,9 +43,6 @@ New Voxtral and Glm-4v model support, plus audio dataset calibration for Qwen2-O * 11/9/2025 [5.4.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.4.0): ✨New Intel CPU and XPU hardware-optimized AWQ `TorchFusedAWQ` kernel. Torch Fused kernels now compatible with `torch.compile`. Fixed AWQ MoE model compatibility and reduced VRAM usage. * 11/3/2025 [5.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.2.0): ✨Minimax M2 support with [ModelCloud BF16 M2 Model](https://huggingface.co/ModelCloud/MiniMax-M2-BF16). New `VramStrategy.Balanced` quantization property for reduced memory usage for large MoE on multi-3090 (24GB) devices. ✨Marin model. New AWQ Torch reference kernel. Fixed AWQ Marlin kernel for bf16. Fixed GLM 4.5/4.6 MoE missing `mtp` layers on model save (HF bug). Modular refactor. 🎉AWQ support out of beta with full feature support including multi-GPU quant and MoE VRAM saving. ✨Brumby (attention free) model support. ✨IBM Granite Nano support. New `calibration_concat_separator` config option. -
- -Archived News * 10/24/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-GPU using `nogil` Python. `offload_to_disk` support enabled by default to massively reduce `CPU` RAM usage. New `Intel` and `AMD` CPU hardware-accelerated `TorchFused` kernel. Packing stage is now 4x faster and now inlined with quantization. `VRAM` pressure for large models reduced during quantization. `act_group_aware` is 16k+ times faster and now the default when `desc_act=False` for higher quality recovery without inference penalty of `desc_act=True`. New beta quality `AWQ` support with full `gemm`, @@ -77,7 +82,7 @@ Auto-detect MoE modules not activated during quantization due to insufficient ca `ROCm` `setup.py` compatibility fixes. `Optimum` and `Peft` compatibility fixes. Fixed `Peft` `bfloat16` training. * 03/03/2025 [2.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v2.0.0): 🎉 `GPTQ` quantization internals are now broken into multiple stages (processes) for feature expansion. -Synced `Marlin` kernel inference quality fix from upstream. Added `MARLIN_FP16`, lower-quality but faster backend. +Synced `Marlin` kernel inference quality fix from upstream. Added reduced-precision Marlin accumulation mode via environment control (`GPTQMODEL_MARLIN_USE_FP32=0` disables it, default is enabled). `ModelScope` support added. Logging and CLI progress bar output has been revamped with sticky bottom progress. Fixed `generation_config.json` save and load. Fixed Transformers v4.49.0 compatibility. Fixed compatibility of models without `bos`. Fixed `group_size=-1` and `bits=3` packing regression. Fixed Qwen 2.5 MoE regressions. @@ -88,12 +93,12 @@ Fixed ROCm version auto-detection in `setup` install. * 02/08/2025 [1.8.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.8.1): ⚡ `DeepSeek v3/R1` model support. New flexible weight `packing`: allow quantized weights to be packed to `[int32, int16, int8]` dtypes. `Triton` and `Torch` kernels support full range of new `QuantizeConfig.pack_dtype`. New `auto_gc: bool` control in `quantize()` which can reduce quantization time for small model with no chance of OOM. -New `GPTQModel.push_to_hub()` API for easy quant model upload to HF repo. New `buffered_fwd: bool` control in `model.quantize()`. Over 50% quantization speed-up for visual (vl) models. +New `buffered_fwd: bool` control in `model.quantize()`. Over 50% quantization speed-up for visual (vl) models. Fixed `bits=3` packing and `group_size=-1` regression in v1.7.4. * 01/26/2025 [1.7.4](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.7.4): New `compile()` API for ~4-8% inference TPS improvement. Faster `pack()` for post-quantization model save. `Triton` kernel validated for Intel/`XPU` when Intel Triton packages are installed. Fixed Transformers (bug) downcasting tokenizer class on save. * 01/20/2025 [1.7.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.7.3): New Telechat2 (China Telecom) and PhiMoE model support. Fixed `lm_head` weights duplicated in post-quantize save() for models with tied-embedding. * 01/19/2025 [1.7.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.7.2): Effective BPW (bits per weight) will now be logged during `load()`. Reduce loading time on Intel Arc A770/B580 `XPU` by 3.3x. Reduce memory usage in MLX conversion and fix Marlin kernel auto-select not checking CUDA compute version. -* 01/17/2025 [1.7.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.7.0): 👀 ✨ `backend.MLX` added for runtime-conversion and execution of GPTQ models on Apple's `MLX` framework on Apple Silicon (M1+). Exports of `gptq` models to `mlx` also now possible. We have added `mlx` exported models to [huggingface.co/ModelCloud](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2). ✨ `lm_head` quantization now fully supported by GPTQModel without external pkg dependency. +* 01/17/2025 [1.7.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.7.0): 👀 ✨ `backend.MLX` added for runtime-conversion and execution of GPTQ models on Apple's `MLX` framework on Apple Silicon (M1+). ✨ `lm_head` quantization now fully supported by GPT-QModel without external pkg dependency. * 01/07/2025 [1.6.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.6.1): 🎉 New OpenAI API compatible endpoint via `model.serve(host, port)`. Auto-enable flash-attention2 for inference. Fixed `sym=False` loading regression. * 01/06/2025 [1.6.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.6.0): ⚡25% faster quantization. 35% reduction in VRAM usage vs v1.5. 👀 AMD ROCm (6.2+) support added and validated for 7900XT+ GPU. Auto-tokenizer loader via `load()` API. For most models you no longer need to manually init a tokenizer for both inference and quantization. * 01/01/2025 [1.5.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.5.1): 🎉 2025! Added `QuantizeConfig.device` to clearly define which device is used for quantization: default = `auto`. Non-quantized models are always loaded on CPU by-default and each layer is moved to `QuantizeConfig.device` during quantization to minimize VRAM usage. Compatibility fixes for `attn_implementation_autoset` in latest transformers. @@ -101,15 +106,15 @@ Fixed `bits=3` packing and `group_size=-1` regression in v1.7.4. * 12/23/2024 [1.5.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.5.0): Multi-modal (image-to-text) optimized quantization support has been added for Qwen 2-VL and Ovis 1.6-VL. Previous image-to-text model quantizations did not use image calibration data, resulting in less than optimal post-quantization results. Version 1.5.0 is the first release to provide a stable path for multi-modal quantization: only text layers are quantized. * 12/19/2024 [1.4.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.5): Windows 11 support added/validated. Ovis VL model support with image dataset calibration. Fixed `dynamic` loading. Reduced quantization VRAM usage. * 12/15/2024 [1.4.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.2): MacOS `GPU` (Metal) and `CPU` (M+) support added/validated for inference and quantization. Cohere 2 model support added. -* 12/13/2024 [1.4.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.1): Added Qwen2-VL model support. `mse` quantization control exposed in `QuantizeConfig`. Monkey patch `patch_vllm()` and `patch_hf()` API added to allow Transformers/Optimum/PEFT and vLLM to correctly load GPTQModel quantized models while upstream PRs are in pending status. -* 12/10/2024 [1.4.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.0) `EvalPlus` harness integration merged upstream. We now support both `lm-eval` and `EvalPlus`. Added pure torch `Torch` kernel. Refactored `Cuda` kernel to be `DynamicCuda` kernel. `Triton` kernel now auto-padded for max model support. `Dynamic` quantization now supports both positive `+:`:default, and `-:` negative matching which allows matched modules to be skipped entirely for quantization. Fixed auto-`Marlin` kernel selection. Added auto-kernel fallback for unsupported kernel/module pairs. Lots of internal refactor and cleanup in preparation for transformers/optimum/peft upstream PR merge. Deprecated the saving of `Marlin` weight format since `Marlin` supports auto conversion of `gptq` format to `Marlin` during runtime. +* 12/13/2024 [1.4.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.1): Added Qwen2-VL model support. `mse` quantization control exposed in `QuantizeConfig`. Monkey patch `patch_vllm()` and `patch_hf()` API added to allow Transformers/Optimum/PEFT and vLLM to correctly load GPT-QModel quantized models while upstream PRs are in pending status. +* 12/10/2024 [1.4.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.4.0) `EvalPlus` harness integration merged upstream. We now support both the legacy evaluation harness and `EvalPlus`. Added pure torch `Torch` kernel. Refactored `Cuda` kernel to be `DynamicCuda` kernel. `Triton` kernel now auto-padded for max model support. `Dynamic` quantization now supports both positive `+:`:default, and `-:` negative matching which allows matched modules to be skipped entirely for quantization. Fixed auto-`Marlin` kernel selection. Added auto-kernel fallback for unsupported kernel/module pairs. Lots of internal refactor and cleanup in preparation for transformers/optimum/peft upstream PR merge. Deprecated the saving of `Marlin` weight format since `Marlin` supports auto conversion of `gptq` format to `Marlin` during runtime. * 11/29/2024 [1.3.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.3.1) Olmo2 model support. Intel XPU acceleration via IPEX. Model sharding Transformer compatibility fix due to API deprecation in HF. Removed triton dependency. Triton kernel now optionally dependent on triton package. * 11/26/2024 [1.3.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.3.0) Zero-Day Hymba model support. Removed `tqdm` and `rogue` dependency. * 11/24/2024 [1.2.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.2.3) HF GLM model support. ClearML logging integration. Use `device-smi` and replace `gputil` + `psutil` dependencies. Fixed model unit tests. -* 11/11/2024 🚀 [1.2.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.2.1) Meta MobileLLM model support added. `lm-eval[gptqmodel]` integration merged upstream. Intel/IPEX CPU inference merged replacing QBits (deprecated). Auto-fix/patch ChatGLM-3/GLM-4 compatibility with latest transformers. New `.load()` and `.save()` API. +* 11/11/2024 🚀 [1.2.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.2.1) Meta MobileLLM model support added. legacy evaluation integration merged upstream. Intel/IPEX CPU inference merged replacing QBits (deprecated). Auto-fix/patch ChatGLM-3/GLM-4 compatibility with latest transformers. New `.load()` and `.save()` API. * 10/29/2024 🚀 [1.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.1.0) IBM Granite model support. Full auto-buildless wheel install from PyPI. Reduce max CPU memory usage by >20% during quantization. 100% CI model/feature coverage. @@ -124,28 +129,46 @@ Fixed `bits=3` packing and `group_size=-1` regression in v1.7.4. * 09/26/2024 ✨ [1.0.4](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.0.4) Integrated Liger Kernel support for ~1/2 memory reduction on some models during quantization. Added control toggle to disable parallel packing. * 09/18/2024 ✨ [1.0.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.0.3) Added Microsoft GRIN-MoE and MiniCPM3 support. * 08/16/2024 ✨ [1.0.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.0.2) Support Intel/AutoRound v0.3, prebuilt whl packages, and PyPI release. -* 08/14/2024 ✨ [1.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.0.0) 40% faster `packing`, fixed Python 3.9 compatibility, added `lm_eval` API. +* 08/14/2024 ✨ [1.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v1.0.0) 40% faster `packing`, fixed Python 3.9 compatibility, added evaluation API. * 08/10/2024 🚀 [0.9.11](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.11) Added LG EXAONE 3.0 model support. New `dynamic` per layer/module flexible quantization where each layer/module may have different bits/params. Added proper sharding support to `backend.BITBLAS`. Auto-heal quantization errors due to small damp values. * 07/31/2024 🚀 [0.9.10](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.10) Ported vllm/nm `gptq_marlin` inference kernel with expanded bits (8bits), group_size (64,32), and desc_act support for all GPTQ models with `FORMAT.GPTQ`. Auto-calculate auto-round nsamples/seglen parameters based on calibration dataset. Fixed save_quantized() called on pre-quantized models with non-supported backends. HF transformers dependency updated to ensure Llama 3.1 fixes are correctly applied to both quant and inference. * 07/25/2024 🚀 [0.9.9](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.9): Added Llama-3.1 support, Gemma2 27B quant inference support via vLLM, auto pad_token normalization, fixed auto-round quant compatibility for vLLM/SGLang, and more. * 07/13/2024 🚀 [0.9.8](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.8): -Run quantized models directly using GPTQModel with fast `vLLM` or `SGLang` backend! Both vLLM and SGLang are optimized for dynamic batching inference for maximum `TPS` (check usage under examples). Marlin backend also +Run quantized models directly using GPT-QModel with fast `vLLM` or `SGLang` backend! Both vLLM and SGLang are optimized for dynamic batching inference for maximum `TPS` (check usage under examples). Marlin backend also got full end-to-end in/out features padding to enhance current/future model compatibility. * 07/08/2024 🚀 [0.9.7](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.7): InternLM 2.5 model support added. * 07/08/2024 🚀 [0.9.6](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.6): [Intel/AutoRound](https://github.com/intel/auto-round) QUANT_METHOD support added for a potentially higher quality quantization with `lm_head` module quantization support for even more VRAM reduction: format export to `FORMAT.GPTQ` for max inference compatibility. * 07/05/2024 🚀 [0.9.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.5): CUDA kernels have been fully deprecated in favor of Exllama(v1/v2)/Marlin/Triton. * 07/03/2024 🚀 [0.9.4](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.4): HF Transformers integration added and bug fixed Gemma 2 support. -* 07/02/2024 🚀 [0.9.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.3): Added Gemma 2 support, faster PPL calculations on GPU, and more code/arg refactor. +* 07/02/2024 🚀 [0.9.3](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.3): Added Gemma 2 support, faster quality/benchmark calculations on GPU, and more code/arg refactor. * 06/30/2024 🚀 [0.9.2](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.2): Added auto-padding of model in/out-features for exllama and exllama v2. Fixed quantization of OPT and DeepSeek V2-Lite models. Fixed inference for DeepSeek V2-Lite. * 06/29/2024 🚀 [0.9.1](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.1): With 3 new models (DeepSeek-V2, DeepSeek-V2-Lite, DBRX Converted), BITBLAS new format/kernel, proper batching of calibration dataset resulting > 50% quantization speedup, security hash check of loaded model weights, tons of refactor/usability improvements, bug fixes, and much more. * 06/20/2924 ✨ [0.9.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v0.9.0): Thanks for all the work from ModelCloud team and the open-source ML community for their contributions!
+## Special Notes: + +PrismAI/Bonsai inference sample script. GPT-QModel loads Prism/Bonsai GGUF checkpoints through its native GGUF loading path and internal GGUF runtime shim. No external `gguf` PyPI package is required. + +```py +• from gptqmodel import GPTQModel + + model = GPTQModel.load("prism-ml/Bonsai-1.7B-gguf") + # or: model = GPTQModel.load("prism-ml/Bonsai-1.7B-gguf", profile="low_memory") + + tokens = model.generate( + "Who wrote Romeo and Juliet?", + max_new_tokens=128, + )[0] + + print(model.tokenizer.decode(tokens, skip_special_tokens=True)) + ``` + ## What is GPT-QModel? GPT-QModel is a production-ready LLM model compression/quantization toolkit with hw-accelerated inference support for both CPU/GPU via HF Transformers, vLLM, and SGLang. -GPT-QModel currently supports GPTQ, AWQ, QQQ, GPTAQ, EoRa, GAR, with more quantization methods and enhancements planned. +GPT-QModel currently supports GPTQ, AWQ, ParoQuant, QQQ, GGUF, FP8, EXL3, GPTAQ, EoRa, GAR and FOEM, with more quantization methods and enhancements planned. ## Quantization Support @@ -155,16 +178,55 @@ GPT-QModel is a modular design supporting multiple quantization methods and feat |---------------------------|------------|---|---|---|---------------| | GPTQ | ✅ | ✅ | ✅ | ✅ | ✅ | | AWQ | ✅ | ✅ | ✅ | ✅ | ✅ | +| ParoQuant | ✅ | x | x | x | ✅ | +| GGUF | ✅ | x | x | x | x | +| FP8 | ✅ | x | x | x | x | +| Exllama V3 / EXL3 | ✅ | x | x | x | x | | EoRA | ✅ | ✅ | ✅ | ✅ | x | | Group Aware Act Reordering | ✅ | ✅ | ✅ | ✅ | ✅ | | QQQ | ✅ | x | x | x | x | | Rotation | ✅ | x | x | x | x | | GPTAQ | ✅ | ✅ | ✅ | ✅ | ✅ | +| FOEM | ✅ | ✅ | ✅ | ✅ | ✅ | + +`GGUF`, `FP8`, `EXL3`, and `ParoQuant` are currently native GPT-QModel quantization/runtime paths. `vLLM` and `SGLang` integration currently targets `GPTQ` and `AWQ`. + +### Quant Method / Format / Backend Matrix + +Canonical backend names are shown below. Legacy aliases such as `BACKEND.TORCH`, `BACKEND.MARLIN`, `BACKEND.GEMM`, and `BACKEND.PARO` are still accepted and normalized to the matching canonical backend for the selected quant method. + +| Quant Method | Formats | Backends / Kernels | +| --- | --- | --- | +| `METHOD.GPTQ` | `FORMAT.GPTQ`, `FORMAT.GPTQ_V2`, `FORMAT.MARLIN`, `FORMAT.BITBLAS` | `FORMAT.GPTQ`: `BACKEND.GPTQ_TORCH_ATEN`, `BACKEND.GPTQ_MACHETE`, `BACKEND.GPTQ_MARLIN`, `BACKEND.GPTQ_EXLLAMA_V2`, `BACKEND.GPTQ_TORCH_FUSED`, `BACKEND.GPTQ_TRITON`, `BACKEND.GPTQ_BITBLAS`, `BACKEND.GPTQ_TORCH`, `BACKEND.GPTQ_TORCH_INT8`
`FORMAT.GPTQ_V2`: `BACKEND.GPTQ_TORCH_ATEN`, `BACKEND.GPTQ_EXLLAMA_V2`, `BACKEND.GPTQ_TORCH_FUSED`, `BACKEND.GPTQ_TRITON`, `BACKEND.GPTQ_BITBLAS`, `BACKEND.GPTQ_TORCH`, `BACKEND.GPTQ_TORCH_INT8`
`FORMAT.MARLIN`: `BACKEND.GPTQ_MARLIN`
`FORMAT.BITBLAS`: `BACKEND.GPTQ_BITBLAS` | +| `METHOD.AWQ` | `FORMAT.GEMM`, `FORMAT.GEMV`, `FORMAT.GEMV_FAST`, `FORMAT.LLM_AWQ`, `FORMAT.MARLIN`, `FORMAT.BITBLAS` | `FORMAT.GEMM`: `BACKEND.AWQ_TORCH_ATEN`, `BACKEND.AWQ_MACHETE`, `BACKEND.AWQ_MARLIN`, `BACKEND.AWQ_EXLLAMA_V2`, `BACKEND.AWQ_GEMM`, `BACKEND.AWQ_GEMM_TRITON`, `BACKEND.AWQ_TORCH_FUSED`, `BACKEND.AWQ_TORCH`, `BACKEND.AWQ_TORCH_INT8`, `BACKEND.AWQ_BITBLAS`
`FORMAT.GEMV`: `BACKEND.AWQ_GEMV`
`FORMAT.GEMV_FAST`: `BACKEND.AWQ_GEMV_FAST`
`FORMAT.LLM_AWQ`: `BACKEND.AWQ_GEMV_FAST`
`FORMAT.MARLIN`: `BACKEND.AWQ_MACHETE`, `BACKEND.AWQ_MARLIN`
`FORMAT.BITBLAS`: `BACKEND.AWQ_BITBLAS` | +| `METHOD.PARO` | `FORMAT.PAROQUANT` | `BACKEND.PAROQUANT_CUDA`, `BACKEND.PAROQUANT_TRITON` | +| `METHOD.QQQ` | `FORMAT.QQQ` | `BACKEND.QQQ` | +| `METHOD.GGUF` | `FORMAT.GGUF` | `BACKEND.GGUF_TRITON`, `BACKEND.GGUF_CPP_CUDA`, `BACKEND.GGUF_CPP_CPU`, `BACKEND.GGUF_TORCH` | +| `METHOD.FP8` | `FORMAT.FP8` | `BACKEND.FP8_TORCH` | +| `METHOD.BITSANDBYTES` | `FORMAT.BITSANDBYTES` | `BACKEND.BITSANDBYTES` | +| `METHOD.EXL3` | `FORMAT.EXL3` | `BACKEND.EXL3_EXLLAMA_V3`, `BACKEND.EXL3_TORCH` | + +`BACKEND.VLLM`, `BACKEND.SGLANG`, and `BACKEND.MLX` are external runtime backends and are not part of the native kernel matrix above. + +Marlin uses `GPTQMODEL_MARLIN_USE_FP32` (default: enabled) to control fp32 accumulation. + +### ParoQuant Activation Checkpointing + +`ParoConfig.opt_gradient_checkpointing` controls activation checkpointing during ParoQuant's train-style optimization stages. + +- `opt_scope="layer"` defaults to `opt_gradient_checkpointing=True` +- `opt_scope="module"` defaults to `opt_gradient_checkpointing=False` +- `opt_scope="compute_block"` defaults to `opt_gradient_checkpointing=False` + +Current internal benchmarks have only shown a clear resource-usage benefit for `layer` scope. `module` and `compute_block` support the toggle, but they are not enabled by default because we have not yet measured a consistent memory win there. ## Features * ✨ Native integration with HF [Transformers](https://github.com/huggingface/transformers), [Optimum](https://github.com/huggingface/optimum), and [Peft](https://github.com/huggingface/peft) * 🚀 [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) inference integration for quantized models with format = `FORMAT.[GPTQ/AWQ]` -* ✨ GPTQ, AWQ, and QQQ quantization format with hardware-accelerated inference kernels. +* ✨ GPTQ, AWQ, ParoQuant, QQQ, GGUF, FP8, EXL3, GPTAQ, and FOEM quantization support. +* 🚀 Local inference backends via `GPTQModel.load(..., backend=BACKEND.<...>)`: `MARLIN`, `MACHETE`, `EXLLAMA_V2`, `TRITON`, `BITBLAS`, `TORCH`, `QQQ`, plus experimental `GPTQ_PRO` for Ampere symmetric INT4 checkpoints with `desc_act=False`. +* ✨ Prism Bonsai `Q1_0_g128` GGUF checkpoints can be loaded for post-quantized inference through the normal `model_id_or_path` argument. GPT-QModel normalizes the GGUF artifact internally for HF Transformers via its native GGUF runtime, and does not support Prism Bonsai quantization or export. +* ✨ `model.serve(host, port)` launches an OpenAI-compatible FastAPI server at `/v1/chat/completions`. * 🚀 Quantize MoE models with ease even with extreme routing activation bias via `Moe.Routing` and/or `FailSafe`. * 🚀 Data Parallelism for 80%+ quantization speed reduction with Multi-GPU. * 🚀 Optimized for Python >= 3.13t (free threading) with lock-free threading. @@ -177,6 +239,15 @@ GPT-QModel is a modular design supporting multiple quantization methods and feat * 🚀 [Microsoft/BITBLAS](https://github.com/microsoft/BitBLAS) optimized tile based inference. * 💯 CI unit-test coverage for all supported models and kernels including post-quantization quality regression. +## Who's Using GPT-QModel? + +Selected public references where teams or companies explicitly mention GPT-QModel in documentation, integration notes, or quantized model usage. This is not an exhaustive customer list. + +* Hugging Face logo Hugging Face +* Intel logo Intel +* NVIDIA logo NVIDIA +* Alibaba Cloud logo Alibaba Cloud + ## Quality: GPTQ 4bit can match native BF16: 🤗 [ModelCloud quantized Vortex models on HF](https://huggingface.co/collections/ModelCloud/vortex-673743382af0a52b2a8b9fe2) @@ -184,21 +255,24 @@ GPT-QModel is a modular design supporting multiple quantization methods and feat ## Model Support -| Model | | | | | | | | | | -|-------------------|---|---------------|---|----------------|---|----------------|---|---------------------|---| -| Apertus | ✅ | EXAONE 3/4 | ✅ | Dots1 | ✅ | Mistral3 | ✅ | Qwen 2/3 (Next/MoE) | ✅ | -| Baichuan | ✅ | Falcon (H1) | ✅ | InternLM 1/2.5 | ✅ | Mixtral | ✅ | Qwen 2/2.5/3 VL | ✅ | -| Bloom | ✅ | FastVLM | ✅ | Kimi K2 | ✅ | MobileLLM | ✅ | Qwen 2.5/3 Omni | ✅ | -| ChatGLM | ✅ | Gemma 1/2/3 | ✅ | Klear | ✅ | MOSS | ✅ | RefinedWeb | ✅ | -| CodeGen | ✅ | GPTBigCod | ✅ | LING/RING | ✅ | MPT | ✅ | StableLM | ✅ | -| Cohere 1-2 | ✅ | GPTQ-Neo(X) | ✅ | Llama 1-3.3 | ✅ | Nemotron H | ✅ | StarCoder2 | ✅ | -| DBRX Converted | ✅ | GPT-2 | ✅ | Llama 3.2 VL | ✅ | Nemotron Ultra | ✅ | TeleChat2 | ✅ | -| Deci | ✅ | GPT-J | ✅ | Llama 4 | ✅ | OPT | ✅ | Trinity | ✅ | -| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS | ✅ | LongCatFlash | ✅ | OLMo2 | ✅ | Yi | ✅ | -| DeepSeek-V2-Lite | ✅ | Granite | ✅ | LongLLaMA | ✅ | Ovis 1.6/2 | ✅ | Seed-OSS | ✅ | -| Dream | ✅ | GRIN-MoE | ✅ | Instella | ✅ | Phi 1-4 | ✅ | Voxtral | ✅ | -| ERNIE 4.5 | ✅ | GLM 4/4V/4MoE | ✅ | MiniCPM3 | ✅ | PanGu-α | ✅ | XVERSE | ✅ | -| Brumby | ✅ | Hymba | ✅ | Mistral | ✅ | Qwen 1/2/3/3.5 | ✅ | Minimax M2 | ✅ | +| Model | | | | | | | | | | +|-------------------|---|---------------|---|------------------------|---|----------------|---|---------------------|---| +| Apertus | ✅ | EXAONE 3/4 | ✅ | Dots1 | ✅ | Mistral3 | ✅ | Qwen 2/3 (Next/MoE) | ✅ | +| Baichuan | ✅ | Falcon (H1) | ✅ | InternLM 1/2.5 | ✅ | Mixtral | ✅ | Qwen 2/2.5/3 VL | ✅ | +| Bloom | ✅ | FastVLM | ✅ | Kimi K2 | ✅ | MobileLLM | ✅ | Qwen 2.5/3 Omni | ✅ | +| ChatGLM | ✅ | Gemma 1-4 | ✅ | Klear | ✅ | MOSS | ✅ | RefinedWeb | ✅ | +| CodeGen | ✅ | GPTBigCod | ✅ | LING/RING | ✅ | MPT | ✅ | StableLM | ✅ | +| Cohere 1-2 | ✅ | GPTQ-Neo(X) | ✅ | Llama 1-3.3 | ✅ | Nemotron H | ✅ | StarCoder2 | ✅ | +| DBRX Converted | ✅ | GPT-2 | ✅ | Llama 3.2 VL | ✅ | Nemotron Ultra | ✅ | TeleChat2 | ✅ | +| Deci | ✅ | GPT-J | ✅ | Llama 4 | ✅ | OPT | ✅ | Trinity | ✅ | +| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS | ✅ | LongCatFlash | ✅ | OLMo2 | ✅ | Yi | ✅ | +| DeepSeek-V2-Lite | ✅ | Granite | ✅ | LongLLaMA | ✅ | Ovis 1.6/2 | ✅ | Seed-OSS | ✅ | +| Dream | ✅ | GRIN-MoE | ✅ | Instella | ✅ | Phi 1-4 | ✅ | Voxtral | ✅ | +| ERNIE 4.5 | ✅ | GLM 4/4V | ✅ | GLM4 MoE/GLM4 MOE lite | ✅ | MiniCPM3/MiniCPM-O/MiniCPM-V | ✅ | PanGu-α | ✅ | +| XVERSE | ✅ | Brumby | ✅ | Hymba | ✅ | Mistral | ✅ | Qwen 1/2/3/3.5 | ✅ | +| Minimax M2 | ✅ | | | | | | | | | + +Prism Bonsai GGUF checkpoints are supported for inference only through GPT-QModel's native GGUF path and internal GGUF runtime. Bonsai checkpoints load through the normal model path or repo argument and do not require the external `gguf` package. Prism model quantization is not included. ## Platform and HW Support @@ -207,7 +281,7 @@ GPT-QModel is validated for Linux, MacOS, and Windows 11: | Platform | Device | | Optimized Arch | Kernels | |-----------------|---------------| --- | ------------ |-----------------------------------------------| -| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` | Marlin, Exllama V2, Exllama V1, Triton, Torch | +| 🐧 Linux | Nvidia GPU | ✅ | `Ampere+` (`sm_80`/`sm_86`, incl. RTX 3090 / 3060) | Marlin, Exllama V2, Exllama V1, Triton, Torch | | 🐧 Linux | AMD GPU | ✅ | `7900XT+`, `ROCm 6.2+` | Exllama V2, Exllama V1, Torch | | 🐧 Linux | Intel XPU | ✅ | `Arc`, `Datacenter Max` | TorchFused, TorchFusedAWQ, Torch | | 🐧 Linux | Intel/AMD CPU | ✅ | `avx`, `amx` | TorchFused, TorchFusedAWQ, Torch | @@ -221,26 +295,55 @@ GPT-QModel is validated for Linux, MacOS, and Windows 11: ```bash # You can install optional modules like autoround, ipex, vllm, sglang, bitblas. -# Example: pip install -v --no-build-isolation gptqmodel[vllm,sglang,bitblas] -pip install -v gptqmodel --no-build-isolation -uv pip install -v gptqmodel --no-build-isolation +# Example: pip install -v gptqmodel[vllm,sglang,bitblas] +pip install -v gptqmodel +uv pip install -v gptqmodel ``` +The package depends on `ninja` for first-use JIT kernel compilation. + ### Install from source ```bash # clone repo git clone https://github.com/ModelCloud/GPTQModel.git && cd GPTQModel -# python3-dev is required, ninja speeds up compilation, and you need to upgrade to the latest `setuptools` to avoid errors -apt install python3-dev ninja setuptools -U +# python3-dev is required for some source installs +apt install python3-dev -# pip: compile and install +# pip: install from source # You can install optional modules like vllm, sglang, bitblas. -# Example: pip install -v --no-build-isolation .[vllm,sglang,bitblas] -pip install -v . --no-build-isolation +# Example: pip install -v .[vllm,sglang,bitblas] +pip install -v . ``` +### Conda + Docker (reproducible GPTQ-Pro / vLLM environment) + +This repository now includes a top-level `environment.yml` and `Dockerfile` for a reproducible +CUDA + conda + editable-install workflow that matches the GPTQ-Pro / vLLM experiments documented +below. + +```bash +# conda / local +conda env create -f environment.yml +conda activate gptq-pro-vllm +pip install -v --no-build-isolation -e ".[vllm,eval,openai]" + +# docker +docker build -t gptq-pro-vllm . +docker run --gpus all -it --shm-size=16g \ + -v "$PWD":/workspace/GPTQ-Pro \ + gptq-pro-vllm +``` + +Notes: + +* The Docker image expects the NVIDIA Container Toolkit on the host. +* The conda env is intentionally named `gptq-pro-vllm` so shell snippets, chat examples, and vLLM + commands all use the same environment name. +* `environment.yml` keeps the base environment lightweight; the editable install with + `.[vllm,eval,openai]` layers in the project extras from `pyproject.toml`. + ### Inference Three-line API to use `GPT-QModel` for GPTQ model inference: @@ -257,6 +360,44 @@ To use models from [ModelScope](https://www.modelscope.cn/) instead of HuggingFa export GPTQMODEL_USE_MODELSCOPE=True ``` +### FP32 accumulation toggle + +Some AWQ and ParoQuant CUDA/Triton kernels support an fp32 accumulation mode to reduce numerical drift during fused quantized matmul. This setting defaults to `True` because accuracy is prioritized over speed. + +```shell +# default behavior: higher accuracy, slightly lower speed on some kernels +export GPTQMODEL_FP32_ACCUM=1 + +# optional speed-first mode for some kernels +export GPTQMODEL_FP32_ACCUM=0 +``` + +Notes: +* This is a runtime toggle. It does not change model weights or saved checkpoints. +* It mainly affects some fused AWQ and ParoQuant CUDA/Triton kernels. Dense/dequantize fallback paths are mostly unaffected. +* `1` is recommended for regression testing and quality-sensitive evaluation. `0` may be useful when chasing a small latency win and the quality tradeoff is acceptable. + +### Chat CLI frontend + +The lightweight CLI frontend lives under [`chat/`](chat/README.md). It now supports: + +* `--tokenizer_path` to reuse a source tokenizer with a quantized checkpoint +* `--system_prompt` to override or disable the default chat system message +* `--max_new_tokens` to cap reply length +* `--trust_remote_code` for newer Hugging Face model families such as Qwen 3.5 + +Example: + +```bash +cd chat +./run.sh \ + --gpu_id 0 \ + --model_path /models/Qwen3.5-4B-abliterated-GPTQ-Pro-4bit \ + --tokenizer_path wangzhang/Qwen3.5-4B-abliterated \ + --max_new_tokens 1024 \ + --trust_remote_code +``` + ### OpenAI API compatible endpoint ```py # load model using above inference guide first @@ -268,7 +409,7 @@ Basic example of using `GPT-QModel` to quantize an LLM model: ```py from datasets import load_dataset -from gptqmodel import GPTQModel, QuantizeConfig +from gptqmodel import GPTQConfig, GPTQModel model_id = "meta-llama/Llama-3.2-1B-Instruct" quant_path = "Llama-3.2-1B-Instruct-gptqmodel-4bit" @@ -279,7 +420,7 @@ calibration_dataset = load_dataset( split="train" ).select(range(1024))["text"] -quant_config = QuantizeConfig(bits=4, group_size=128) +quant_config = GPTQConfig(bits=4, group_size=128) model = GPTQModel.load(model_id, quant_config) @@ -289,6 +430,113 @@ model.quantize(calibration_dataset, batch_size=1) model.save(quant_path) ``` +#### Other Quantization Formats + +`QuantizeConfig` remains the broad factory. The concrete config classes are now `GPTQConfig`, `AWQConfig`, `ParoConfig`, `QQQConfig`, `RTNConfig`, `GGUFConfig`, `FP8Config`, `BitsAndBytesConfig`, and `EXL3Config`. + +`GPTQ`, `AWQ`, `ParoQuant`, and `EXL3` are calibration-based. `GGUF` and `FP8` are weight-only and should be quantized with `calibration=None`. + +##### Preprocessors + +`preprocessors=[...]` adds optional module-weight preparation steps before quantization or repacking. They are available on `GPTQConfig`, `AWQConfig`, `ParoConfig`, `RTNConfig`, `GGUFConfig`, `FP8Config`, and `BitsAndBytesConfig`. + +- `SmootherConfig`: apply weight smoothing before quantization. +- `AutoModuleDecoderConfig`: decode FP8/FP4 source modules to a dense `target_dtype` before downstream quantization or repacking. +- `TensorParallelPadderConfig`: opt-in tensor-parallel padding metadata for TP-aligned packing. + +```py +import torch +from gptqmodel import GGUFConfig, GPTQConfig +from gptqmodel.quantization import ( + AutoModuleDecoderConfig, + SmoothMAD, + SmootherConfig, + TensorParallelPadderConfig, +) + +gptq_cfg = GPTQConfig( + bits=4, + group_size=128, + preprocessors=[ + SmootherConfig(smooth=SmoothMAD(k=2.0)), + AutoModuleDecoderConfig(target_dtype=torch.bfloat16), + TensorParallelPadderConfig(), + ], +) + +gguf_cfg = GGUFConfig( + bits=4, + format="q_k_m", + preprocessors=[ + AutoModuleDecoderConfig(target_dtype=torch.bfloat16), + TensorParallelPadderConfig(), + ], +) +``` + +##### GGUF Example: Llama 3.2 1B Instruct + +```py +from gptqmodel import BACKEND, GGUFConfig, GPTQModel + +model_id = "meta-llama/Llama-3.2-1B-Instruct" +quant_path = "Llama-3.2-1B-Instruct-GGUF-Q4_K_M" + +qcfg = GGUFConfig( + bits=4, + format="q_k_m", +) + +model = GPTQModel.load(model_id, qcfg) +model.quantize(calibration=None, backend=BACKEND.GGUF_TORCH) +model.save(quant_path) +``` + +##### FP8 Example: Llama 3.2 1B Instruct + +```py +from gptqmodel import BACKEND, FP8Config, GPTQModel + +model_id = "meta-llama/Llama-3.2-1B-Instruct" +quant_path = "Llama-3.2-1B-Instruct-FP8-E4M3" + +qcfg = FP8Config( + format="float8_e4m3fn", # or "float8_e5m2" + bits=8, + weight_scale_method="row", +) + +model = GPTQModel.load(model_id, qcfg) +model.quantize(calibration=None, backend=BACKEND.GPTQ_TORCH) +model.save(quant_path) +``` + +##### Exllama V3 / EXL3 Example: Llama 3.2 1B Instruct + +```py +from datasets import load_dataset +from gptqmodel import BACKEND, EXL3Config, GPTQModel + +model_id = "meta-llama/Llama-3.2-1B-Instruct" +quant_path = "Llama-3.2-1B-Instruct-EXL3" + +calibration_dataset = load_dataset( + "allenai/c4", + data_files="en/c4-train.00001-of-01024.json.gz", + split="train", +).select(range(1024))["text"] + +qcfg = EXL3Config( + bits=4.0, # target average bits-per-weight + head_bits=6.0, # optional higher bitrate for attention heads / sensitive tensors + codebook="mcg", # one of: mcg, mul1, 3inst +) + +model = GPTQModel.load(model_id, qcfg) +model.quantize(calibration_dataset, batch_size=1, backend=BACKEND.EXL3_EXLLAMA_V3) +model.save(quant_path) +``` + #### MoE Quantization Some MoE (mixture of experts) models have extremely uneven/biased routing (distribution of tokens) to the `experts` causing some expert modules to receive close-to-zero activated tokens, thus failing to complete calibration-based quantization (GPTQ/AWQ). @@ -343,49 +591,53 @@ tokens = model.generate("Capital of France is")[0] result = model.tokenizer.decode(tokens) print(f"Result: {result}") -# For more details on EoRA, please see GPTQModel/examples/eora +# For more details on EoRA, please see docs/eora/ # Please use the benchmark tools in later part of this README to evaluate EoRA effectiveness ``` -For more advanced features of model quantization, please refer to [this script](https://github.com/ModelCloud/GPTQModel/blob/main/examples/quantization/basic_usage_wikitext2.py) - ### How to Add Support for a New Model Read the [`gptqmodel/models/llama.py`](https://github.com/ModelCloud/GPTQModel/blob/5627f5ffeb3f19b1a2a97e3b6de6fbe668b0dc42/gptqmodel/models/llama.py) code which explains in detail via comments how the model support is defined. Use it as a guide for PRs to add new models. Most models follow the same pattern. ### Evaluation and Quality Benchmarks -GPTQModel inference is integrated into both [lm-eval](https://github.com/EleutherAI/lm-evaluation-harness) and [evalplus](https://github.com/evalplus/evalplus) -We highly recommend avoiding `ppl` and using `lm-eval`/`evalplus` to validate post-quantization model quality. `ppl` should only be used for regression tests and is not a good indicator of model output quality. - -``` -# gptqmodel is integrated into lm-eval >= v0.4.7 -pip install lm-eval>=0.4.7 -``` +GPT-QModel evaluation is integrated into [Evalution](https://github.com/modelcloud/Evalution). +We highly recommend using Evalution to validate post-quantization model quality. Regression-only language-model metrics are deprecated in this guide. ``` -# gptqmodel is integrated into evalplus[main] -pip install -U "evalplus @ git+https://github.com/evalplus/evalplus" +# install Evalution +pip install Evalution ``` -Below is a basic sample using `GPTQModel.eval` API +Below is a basic sample using Evalution's GPT-QModel engine directly via `engines.GPTQModel`. ```py -from gptqmodel import GPTQModel -from gptqmodel.utils.eval import EVAL - -model_id = "ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" - -# Use `lm-eval` as framework to evaluate the model -lm_eval_data = GPTQModel.eval(model_id, - framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE]) - +import evalution as eval +import evalution.benchmarks as benchmarks +import evalution.engines as engines + +run = ( + engines.GPTQModel( + backend="marlin", + device="cuda:0", + dtype="auto", + batch_size="auto", + ) + .model( + eval.Model( + path="ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1", + ) + ) + .run( + benchmarks.arc_challenge( + apply_chat_template=True, + batch_size=16, + ) + ) +) -# Use `evalplus` as framework to evaluate the model -evalplus_data = GPTQModel.eval(model_id, - framework=EVAL.EVALPLUS, - tasks=[EVAL.EVALPLUS.HUMAN]) +result = run.to_dict() +print(result["tests"][0]["metrics"]) ``` ### Dynamic Quantization (Per Module QuantizeConfig Override) @@ -433,8 +685,8 @@ If your goal is "better GPTQ quality without touching the inference kernels", th * GAR / `act_group_aware=True` to improve activation ordering without inference-time penalties. * MSE-based scale search (`mse > 0`) to reduce outlier-driven grid distortion. -* Activation-weighted MSE search (`activation_weighted_mse=True`) to bias scale selection toward Hessian-salient channels using an offline-only importance signal. * Adaptive damping for badly conditioned Hessian blocks. +* Best-of failsafe smoothing that tries several kernel-compatible offline preconditioning candidates and keeps the lowest-error result for under-sampled modules. * Optional GPTAQ experimentation, with the same GPTQ export format, when you want to test more aggressive offline correction. This is exposed as a convenience preset: @@ -445,7 +697,172 @@ from gptqmodel.quantization import QuantizeConfig quant_config = QuantizeConfig.gptq_pro() ``` -`QuantizeConfig.gptq_pro()` is intentionally conservative: it keeps `quant_method=METHOD.GPTQ` and `format=FORMAT.GPTQ`, so inference speed comes from the same kernels as regular GPTQ. Today that preset combines existing GPTQModel features with an AutoRound-inspired offline importance weighting pass during MSE scale search, but it does **not** claim that GPTQModel currently implements AWQ-style layer fusion or AutoRound-style learned rounding inside the GPTQ inner loop; those are separate algorithms and should be treated as separate offline quantizers. +`QuantizeConfig.gptq_pro()` is intentionally conservative: it keeps `quant_method=METHOD.GPTQ` and `format=FORMAT.GPTQ`, so inference speed comes from the same kernels as regular GPTQ. The new preset also stays offline-only: for low-sample fallback blocks it borrows an AutoRound-like idea by searching a few smoothing candidates and choosing the one with the lowest reconstruction MSE, but it still emits ordinary GPTQ weights/scales/zeros for the same inference kernels. It does **not** claim that GPTQModel currently implements AWQ-style layer fusion or AutoRound-style learned rounding inside the GPTQ inner loop; those are separate algorithms and should be treated as separate offline quantizers. + +Migration note: `QuantizeConfig.gptq_pro()` previously used a single fixed `SmoothMSE(...)` failsafe configuration. It now defaults to `SmoothAuto()`, so quantized outputs can change slightly across upgrades even though the exported GPTQ format and inference kernels stay the same. If you need the older GPTQ-Pro behavior for reproducibility, pass the previous failsafe config explicitly. + +#### Reference run: `Qwen/Qwen3.5-4B` on one RTX 3060 (12 GB) + +As a concrete single-GPU reference point, we quantized `Qwen/Qwen3.5-4B` with `QuantizeConfig.gptq_pro(bits=4, group_size=128, offload_to_disk=True)` on one isolated RTX 3060 (12 GB) and compared the original and quantized checkpoints with `vLLM`. + +Quantization used 16 calibration samples and finished in `376.20s`, with an additional `2.54s` to save the checkpoint. + +| Variant | Checkpoint size | vLLM load time | Output tok/s | Avg request latency | Wikitext-2 raw PPL* | +| --- | --- | --- | --- | --- | --- | +| Original checkpoint | `8.8G` | `24.10s` | `16.19` | `1.98s` | `11.36` | +| GPTQ-Pro 4-bit checkpoint | `3.0G` | `15.37s` | `52.64` | `0.61s` | `12.33` | + +That run delivered a `3.25x` throughput speedup, `69.24%` lower average request latency, and `1.57x` faster model load time, while increasing perplexity by `0.97` absolute (`1.085x` relative) on Wikitext-2 raw. + +Operational notes from this setup: + +* The source `Qwen/Qwen3.5-4B` checkpoint is multimodal, so on a 12 GB 3060 the baseline `vLLM` comparison needed text-only settings: `language_model_only=True`, `limit_mm_per_prompt={"image": 0, "video": 0}`, `skip_mm_profiling=True`, `enforce_eager=True`, and `max_model_len=256`. +* For this quantized checkpoint and `vLLM` stack, `gptq_marlin` was the working backend and the original tokenizer path had to be reused when serving the quantized weights. +* Native `transformers` Qwen3.5 text quantization currently needs `batch_size=1` in GPTQModel because multi-sample padded calibration batches can fail in the SDPA attention path. +* `*` Perplexity is included only as a regression signal here, consistent with the note above; it was measured on `wikitext-2-raw-v1` with `n_ctx=256` and `n_batch=256`. + +#### Replication assets, CUDA scaffold files, and `wangzhang/Qwen3.5-4B-abliterated` follow-up + +If you want to reproduce the low-level GPTQ-Pro CUDA validation work from this repository state, +the relevant standalone CUDA files are: + +* `gptqmodel_ext/gptq_pro/gptq_pro_kernel.cuh` +* `gptqmodel_ext/gptq_pro/gptq_pro_kernel.cu` +* `gptqmodel_ext/gptq_pro/gptq_pro_validate.cu` + +These are the files used for the standalone scaffold / validator flow. They now cover a +functional single-warp `sm80` Tensor Core path with explicit A/B/S staging and end-to-end +validation, and this repository now exposes that scaffold as an **optional** +`BACKEND.GPTQ_PRO` local runtime for explicit `GPTQModel.load(..., backend=BACKEND.GPTQ_PRO)` +use on Ampere-class CUDA systems. + +For a local GPTQ-Pro export, the intended runtime entrypoint is: + +```python +from gptqmodel import BACKEND, GPTQModel + +model = GPTQModel.load( + "/path/to/local-qwen35-gptq-pro", + device="cuda:0", + backend=BACKEND.GPTQ_PRO, + trust_remote_code=True, +) +``` + +This explicit-only path was revalidated against a local `Qwen3.5-4B` GPTQ-Pro checkpoint: +the loader selected `GptqProQuantLinear` modules end-to-end and completed a short generation +smoke on CUDA. + +Current runtime limits are intentionally narrow: `4-bit`, symmetric GPTQ, `torch.float16`, +and `desc_act=False` / sequential `g_idx` only. It derives a non-persistent byte-packed weight +buffer during `post_init()` for the CUDA kernel, but it does **not** yet replace the larger +production paths (`Marlin`, `Machete`, `vLLM`) or implement the planned multi-warp +`cp.async` / `ldmatrix` pipeline discussed in `Project.md` / `progress.md`. + +For the `wangzhang/Qwen3.5-4B-abliterated` text-only follow-up, the measured results were: + +| Variant | Quantization time | WikiText-2 raw PPL | vLLM status / speed | +| --- | --- | --- | --- | +| Original BF16 | n/a | `8.3116` | vanilla `vLLM 0.17.0` blocked by Qwen3.5 text config mismatch | +| Plain GPTQ 4-bit g128 | `181.4s` | `8.6759` | vanilla `vLLM 0.17.0` blocked by the same config mismatch | +| GPTQ-Pro 4-bit g128 | `324.9s` | `8.6314` | patched `vLLM` + `gptq_marlin`: `175.21-178.14 tok/s` on `1x 3090`, `194.20-206.53 tok/s` on `2x 3090` | + +Vanilla `vLLM 0.17.0` failed on the original, plain GPTQ, and GPTQ-Pro checkpoints before first +token with the same `Qwen3_5TextConfig` vs `Qwen3_5Config` type mismatch. The detailed comparison +is documented in [`docs/qwen35_vllm_comparison.md`](docs/qwen35_vllm_comparison.md). + +For local or Hugging Face `qwen3_5_text` checkpoints in this repository, use +`scripts/serve_vllm_qwen35.py` instead of calling `vllm serve` directly. The +wrapper applies the text-only `Qwen3.5` serving settings documented above, +patches vLLM's renderer/model-registry startup so `Qwen3_5TextConfig` +checkpoints stay on the causal-LM path, restores the hybrid/M-RoPE interfaces +needed by Qwen3.5 text-only checkpoints, maps vLLM's startup-time NVML scan to +the GPUs already selected via `CUDA_VISIBLE_DEVICES`, and installs a small +`LD_PRELOAD` NVML shim so NCCL tensor-parallel startup can skip broken physical +GPUs on shared hosts. It now auto-detects `qwen3_5_text` from either a local +`config.json` or a Hub repo ID's `AutoConfig`. + +For the exact launch commands that match the working `vLLM + gptq_marlin` path, +see [`docs/qwen35_vllm_launch.md`](docs/qwen35_vllm_launch.md). + +In the shared-host validation for +`lukey03/Qwen3.5-9B-abliterated` -> GPTQ-Pro 4-bit g128, quantization took +`415.2s`, the earlier `GPTQModel.generate()` path measured only `15.61 tok/s` +on `1x GPU` and `10.44 tok/s` on `2x GPU`, and the corrected warmed vLLM + +`gptq_marlin` path reached about `104.96 tok/s` on `1x 3090` and `154.26 tok/s` +on `2x 3090` (with `--gpu-memory-utilization 0.4` on the 2-GPU shared-host +run). First-request latency is much slower than steady-state throughput because +`torch.compile` and CUDA-graph capture warm the model on first use. + +#### Important limitation: GGUF-only Qwen 3.5 35B-A3B repositories + +The repository `HauhauCS/Qwen3.5-35B-A3B-Uncensored-HauhauCS-Aggressive` currently publishes +**GGUF files only** and its model card explicitly says `GPTQ — coming soon`. + +That matters because GPTQModel and the current vLLM workflow in this repository expect a +Transformers / Safetensors checkpoint for GPTQ-Pro quantization. A GGUF-only repo is therefore not +directly quantizable by GPTQModel, even if the underlying base architecture is supported. + +If you want to reproduce the same architecture with GPTQ-Pro today, start from the upstream +Transformers checkpoint `Qwen/Qwen3.5-35B-A3B`, then re-run the same quantization and evaluation +workflow once a non-GGUF release of the HauhauCS fine-tune becomes available. + +#### Replacement run: `huihui-ai/Huihui-Qwen3.5-27B-abliterated` + +Because the exact HauhauCS repo is GGUF-only, the quantizable replacement run used +`huihui-ai/Huihui-Qwen3.5-27B-abliterated`, which ships a full Transformers / Safetensors +checkpoint and keeps the same Qwen 3.5 family. + +For this larger model, the original-model perplexity path had to be stabilized on a single +RTX 3090 plus CPU offload because another user process was holding ~8 GiB on the second 3090. +The resulting comparison uses a fixed regression slice rather than the full WikiText-2 sweep: +`max_length=256`, `stride=256`, `max_windows=16`, `4096` scored tokens total. + +| Variant | Quantization time | WikiText-2 regression-slice PPL | Notes | +| --- | --- | --- | --- | +| Original BF16 | n/a | `11.6266` | single `3090` + CPU offload | +| GPTQ-Pro 4-bit g32 | `2273.6s` | `12.0161` | `128` calibration samples, output size ~`18G` | + +This GPTQ-Pro run therefore added `+0.3895` PPL on the same 4,096-token slice. The absolute PPL is +not directly comparable to the earlier 4B full-dataset numbers because the 27B run used much shorter +context windows to stay within the shared-machine VRAM budget and still hit the torch fallback path +for Qwen 3.5 linear attention. + +The saved quantization metadata for this run confirms the highest-quality settings used here: + +* `QuantizeConfig.gptq_pro(bits=4, group_size=32)` +* `calibration_samples=128` +* `vram_strategy=balanced` +* `gc_mode=on_stage_end` +* `auto_forward_data_parallel=false` +* `wait_for_submodule_finalizers=true` +* `moe.routing=ExpertsRoutingBypass(batch_size=2)` +* `offload_to_disk=true` + +`vLLM 0.17.0` still did **not** deploy this replacement model cleanly in the tested environment. A +one-shot offline `LLM.generate()` smoke test against the quantized checkpoint selected +`gptq_marlin`, but then failed before generation with the same Qwen 3.5 config-family mismatch: +`Qwen3_5TextConfig` vs `Qwen3_5Config`, this time through the multimodal renderer path. + +Useful repo-side tools for that workflow: + +```bash +# regression-style perplexity +python examples/benchmark/perplexity.py \ + --model /path/to/quantized-model \ + --tokenizer /path/to/source-model \ + --is_quantized \ + --trust_remote_code \ + --backend marlin + +# task evaluation with vLLM +python scripts/eval_model.py \ + --model /path/to/quantized-model \ + --tasks arc_challenge,mmlu_stem \ + --backend marlin \ + --use-vllm \ + --trust-remote-code +``` ### Experimental Features @@ -459,6 +876,14 @@ Enable GPTAQ quantization by setting `gptaq = GPTAQConfig(...)`. # If OOM on 1 GPU, please set CUDA_VISIBLE_DEVICES=0,1 to 2 GPUs and gptqmodel will auto use second GPU quant_config = QuantizeConfig(bits=4, group_size=128, gptaq=GPTAQConfig(alpha=0.25, device="auto")) ``` + +#### Using FOEM + +FOEM (First-order error matters) adds first-order error compensation for GPTQ-style quantization. Enable FOEM by setting `foem = FOEMConfig(...)`. +```py +# FOEM default hyperparameters are alpha=0.0 and beta=0.2 +quant_config = QuantizeConfig(bits=4, group_size=128, foem=FOEMConfig(alpha=0.0, beta=0.2, device="auto")) +``` ### Migrating from AutoGPTQ and AutoAWQ: GPT-QModel has fully supplanted AutoGPTQ and AutoAWQ for HF Transformers/Optimum/Peft integration. Model inference has drop-in support with zero changes. @@ -473,10 +898,12 @@ Models quantized by GPT-QModel are inference compatible with HF Transformers (mi * GPTQ: IST-DASLab, main-author: Elias Frantar, arXiv:2210.17323 * AWQ: main-authors: Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song +* ParoQuant: Z-Lab, main-authors: Yesheng Liang, Haisheng Chen, Song Han, and Zhijian Liu. [Official implementation](https://github.com/z-lab/paroquant), [Paper](https://openreview.net/forum?id=1USeVjsKau) * EoRA: Nvidia, main-author: Shih-Yang Liu, arXiv preprint arXiv:2410.21271. * GAR: Intel, main-author: T Gafni, A Karnieli, Y Hanani, [Paper](https://openaccess.thecvf.com/content/CVPR2025W/eLVM/html/Gafni_Dual_Precision_Quantization_for_Efficient_and_Accurate_Deep_Neural_Networks_CVPRW_2025_paper.html) * GPTAQ: Yale Intelligent Computing Lab, main-author: Yuhang Li, arXiv:2504.02692. * QQQ: Meituan, main-author Ying Zhang, arXiv:2406.09904 +* FOEM: Zheng, Xingyu and Qin, Haotong and Li, Yuye and Chu, Haoran and Wang, Jiakai and Guo, Jinyang and Magno, Michele and Liu, Xianglong [Paper](https://ojs.aaai.org/index.php/AAAI/article/view/40123) ## Citations: @@ -509,6 +936,25 @@ Models quantized by GPT-QModel are inference compatible with HF Transformers (mi year={2023} } +# ParoQuant +@inproceedings{liang2026paroquant, + title = {{ParoQuant: Pairwise Rotation Quantization for Efficient Reasoning LLM Inference}}, + author = {Liang, Yesheng and Chen, Haisheng and Han, Song and Liu, Zhijian}, + booktitle = {International Conference on Learning Representations (ICLR)}, + year = {2026} +} + +# GGUF / llama.cpp +@misc{ggerganov2023gguf, + author = {Georgi Gerganov and ggml-org contributors}, + title = {llama.cpp and the GGUF model format}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/ggml-org/llama.cpp}}, + note = {Canonical GGUF implementation and format reference; see also \url{https://github.com/ggml-org/llama.cpp/wiki/dev-notes}}, + year = {2023} +} + # EoRA @article{liu2024eora, title={EoRA: Training-free Compensation for Compressed LLM with Eigenspace Low-Rank Approximation}, @@ -517,22 +963,6 @@ Models quantized by GPT-QModel are inference compatible with HF Transformers (mi year={2024} } -# Group Aware Reordering (GAR) -@article{gar, - title={Dual Precision Quantization for Efficient and Accurate Deep Neural Networks Inference, CVPRW 2025.}, - author={T. Gafni, A. Karnieli, Y. Hanani}, - journal={arXiv preprint arXiv:2505.14638}, - year={2025} -} - -# GPTQ Marlin Kernel -@article{frantar2024marlin, - title={MARLIN: Mixed-Precision Auto-Regressive Parallel Inference on Large Language Models}, - author={Frantar, Elias and Castro, Roberto L and Chen, Jiale and Hoefler, Torsten and Alistarh, Dan}, - journal={arXiv preprint arXiv:2408.11743}, - year={2024} -} - # GPTAQ @article{li2025gptaq, title={GPTAQ: Efficient Finetuning-Free Quantization for Asymmetric Calibration}, @@ -541,6 +971,17 @@ Models quantized by GPT-QModel are inference compatible with HF Transformers (mi year={2025} } +# FOEM +@inproceedings{zheng2026first, + title={First-order error matters: Accurate compensation for quantized large language models}, + author={Zheng, Xingyu and Qin, Haotong and Li, Yuye and Chu, Haoran and Wang, Jiakai and Guo, Jinyang and Magno, Michele and Liu, Xianglong}, + booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, + volume={40}, + number={34}, + pages={28883--28891}, + year={2026} +} + # QQQ @article{zhang2024qqq, title={QQQ: Quality Quattuor-Bit Quantization for Large Language Models}, @@ -548,13 +989,41 @@ Models quantized by GPT-QModel are inference compatible with HF Transformers (mi journal={arXiv preprint arXiv:2406.09904}, year={2024} } + +# ExLlama V3 / EXL3 +@misc{turboderp2026exllamav3, + author = {turboderp and exllamav3 contributors}, + title = {ExLlamaV3 and the EXL3 quantization format}, + publisher = {GitHub}, + journal = {GitHub repository}, + howpublished = {\url{https://github.com/turboderp-org/exllamav3}}, + note = {Project repository and EXL3 format documentation: \url{https://github.com/turboderp-org/exllamav3/blob/master/doc/exl3.md}}, + year = {2026} +} + +# Group Aware Reordering (GAR) +@article{gar, + title={Dual Precision Quantization for Efficient and Accurate Deep Neural Networks Inference, CVPRW 2025.}, + author={T. Gafni, A. Karnieli, Y. Hanani}, + journal={arXiv preprint arXiv:2505.14638}, + year={2025} +} + +# Marlin Kernel +@article{frantar2024marlin, + title={MARLIN: Mixed-Precision Auto-Regressive Parallel Inference on Large Language Models}, + author={Frantar, Elias and Castro, Roberto L and Chen, Jiale and Hoefler, Torsten and Alistarh, Dan}, + journal={arXiv preprint arXiv:2408.11743}, + year={2024} +} + ``` ## Quick Notes ### Limit log level -`GPTQModel` uses a shared `LogBar` logger. Set the level once near process startup: +`GPT-QModel` uses a shared `LogBar` logger. Set the level once near process startup: ```python from logbar import LogBar diff --git a/chat/README.md b/chat/README.md index 24f1f0ed1..8d0276923 100644 --- a/chat/README.md +++ b/chat/README.md @@ -1,8 +1,32 @@ ## Chat CLI ## Usage +```bash ./run.sh --gpu_id --model_path +``` --gpu_id: Specifies the GPU ID to use. This maps to the CUDA_VISIBLE_DEVICES environment variable. ---model_path: The file path to the chat model. This must point to a valid model directory or file. \ No newline at end of file +--model_path: The file path to the chat model. This must point to a valid model directory or file. + +### Useful options + +```bash +./run.sh \ + --gpu_id 0 \ + --model_path /models/Qwen3.5-4B-abliterated-GPTQ-Pro-4bit \ + --tokenizer_path wangzhang/Qwen3.5-4B-abliterated \ + --max_new_tokens 1024 \ + --trust_remote_code +``` + +- `--tokenizer_path`: Reuse the source tokenizer when a quantized checkpoint should not use its local tokenizer files. +- `--system_prompt`: Override the default system prompt. Pass an empty string to disable it. +- `--max_new_tokens`: Cap the number of generated tokens per assistant reply. +- `--trust_remote_code`: Required for some newer Hugging Face model families such as Qwen 3.5. + +### Qwen 3.5 / GPTQ-Pro notes + +- For GPTQ-Pro checkpoints quantized from Qwen 3.5 models, `GPTQModel.load()` is the serving path used by this CLI. +- If you are benchmarking with `vLLM`, keep that workflow separate from `chat.py`; the chat CLI is intended as a lightweight local frontend for quick manual checks. +- The main repository `README.md` contains the full replication guide, Docker/conda setup, and Qwen 3.5 vLLM compatibility notes. diff --git a/chat/chat.py b/chat/chat.py index 1d3ea5df8..145a3f51b 100644 --- a/chat/chat.py +++ b/chat/chat.py @@ -9,6 +9,7 @@ from colorama import Fore, init from gptqmodel import GPTQModel +from transformers import AutoTokenizer init(autoreset=True) @@ -21,33 +22,41 @@ ASSISTANT_HELLO = 'How can I help you?' EXIT_MESSAGE = 'Exiting the program.' - -MESSAGES = [ - {"role": "system", "content": "You are a helpful and harmless assistant. You should think step-by-step."} -] +DEFAULT_SYSTEM_PROMPT = "You are a helpful and harmless assistant. You should think step-by-step." DEBUG = False -def load_model(model_path): +def build_messages(system_prompt): + if system_prompt: + return [{"role": "system", "content": system_prompt}] + return [] + + +def load_model(model_path, trust_remote_code=False): print(Fore.BLUE + f"Loading model from `{model_path}` ...\n") - model = GPTQModel.load(model_path) + model = GPTQModel.load(model_path, trust_remote_code=trust_remote_code) return model -def chat_prompt_progress(user_input, tokenizer): +def load_tokenizer(tokenizer_path, trust_remote_code=False): + print(Fore.BLUE + f"Loading tokenizer from `{tokenizer_path}` ...\n") + return AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=trust_remote_code) + + +def chat_prompt_progress(user_input, tokenizer, messages): user_message = {"role": KEY_USER, "content": user_input} - MESSAGES.append(user_message) - input_tensor = tokenizer.apply_chat_template(MESSAGES, add_generation_prompt=True, return_tensors="pt") + messages.append(user_message) + input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt") if DEBUG: - debug(tokenizer) + debug(tokenizer, messages) return input_tensor -def debug(tokenizer): +def debug(tokenizer, messages): print("********* DEBUG START *********") print("********* Chat Template info *********") - print(tokenizer.apply_chat_template(MESSAGES, return_dict=False, tokenize=False, add_generation_prompt=True)) + print(tokenizer.apply_chat_template(messages, return_dict=False, tokenize=False, add_generation_prompt=True)) print("********* DEBUG END *********") @@ -73,7 +82,14 @@ def save_chat_history(chat_history, save_path): if __name__ == "__main__": parser = argparse.ArgumentParser(description="Chat with a GPT model.") parser.add_argument('--model_path', type=str, help="Path to the model.") + parser.add_argument('--tokenizer_path', type=str, help="Optional tokenizer path. Useful when a quantized checkpoint should reuse the source tokenizer.") parser.add_argument('--save_chat_path', type=str, help="Path to save the chat history.") + parser.add_argument('--system_prompt', type=str, default=DEFAULT_SYSTEM_PROMPT, + help='Optional system prompt. Pass an empty string to disable the system message.') + parser.add_argument('--max_new_tokens', type=int, default=4096, + help='Maximum number of new tokens to generate per assistant turn.') + parser.add_argument('--trust_remote_code', action='store_true', default=False, + help='Allow custom model/tokenizer code from Hugging Face repos.') parser.add_argument('--debug', action='store_true', default=False, help='Print Debug Info') args = parser.parse_args() @@ -81,15 +97,20 @@ def save_chat_history(chat_history, save_path): raise ValueError("Model path is None, Please Set `--model_path`") DEBUG = args.debug - model = load_model(args.model_path) + model = load_model(args.model_path, trust_remote_code=args.trust_remote_code) + messages = build_messages(args.system_prompt) print(Fore.CYAN + "Welcome to GPTQModel Chat Assistant!\n") print(Fore.YELLOW + "You can enter questions or commands as follows:\n") print(Fore.YELLOW + "1. Type your question for the model.\n") print(Fore.YELLOW + "2. Type 'exit' to quit the program.\n") print(Fore.YELLOW + "3. Type 'save' to save the chat history.\n") + print(Fore.YELLOW + f"4. Current max_new_tokens per reply: {args.max_new_tokens}\n") - tokenizer = model.tokenizer + tokenizer = model.tokenizer if args.tokenizer_path is None else load_tokenizer( + args.tokenizer_path, + trust_remote_code=args.trust_remote_code, + ) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id @@ -106,15 +127,15 @@ def save_chat_history(chat_history, save_path): elif user_input.lower() == 'save': save_chat_history(chat_history, args.save_chat_path) else: - input_tensor = chat_prompt_progress(user_input, tokenizer) + input_tensor = chat_prompt_progress(user_input, tokenizer, messages) outputs = model.generate( input_ids=input_tensor.to(model.device), - max_new_tokens=4096, + max_new_tokens=args.max_new_tokens, pad_token_id=tokenizer.pad_token_id ) assistant_response = tokenizer.decode(outputs[0][input_tensor.shape[1]:], skip_special_tokens=True) - MESSAGES.append({"role": KEY_ASSISTANT, "content": assistant_response}) + messages.append({"role": KEY_ASSISTANT, "content": assistant_response}) chat_history.append({KEY_USER: user_input, KEY_ASSISTANT: assistant_response}) print_model_message(assistant_response) diff --git a/chat/run.sh b/chat/run.sh index 2b5567060..11dcc02f9 100755 --- a/chat/run.sh +++ b/chat/run.sh @@ -1,10 +1,34 @@ #!/bin/bash +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + GPU_ID=0 MODEL_PATH="" +TOKENIZER_PATH="" +SYSTEM_PROMPT="" +MAX_NEW_TOKENS=4096 +TRUST_REMOTE_CODE=0 + +print_help() { + cat <<'EOF' +Usage: + ./run.sh --gpu_id --model_path [options] + +Options: + --tokenizer_path Optional tokenizer path. + --system_prompt Optional system prompt. + --max_new_tokens Maximum generated tokens per reply. + --trust_remote_code Allow custom Hugging Face model/tokenizer code. + --help Show this help message. +EOF +} while [[ $# -gt 0 ]]; do case $1 in + --help|-h) + print_help + exit 0 + ;; --gpu_id) GPU_ID="$2" shift @@ -15,8 +39,27 @@ while [[ $# -gt 0 ]]; do shift shift ;; + --tokenizer_path) + TOKENIZER_PATH="$2" + shift + shift + ;; + --system_prompt) + SYSTEM_PROMPT="$2" + shift + shift + ;; + --max_new_tokens) + MAX_NEW_TOKENS="$2" + shift + shift + ;; + --trust_remote_code) + TRUST_REMOTE_CODE=1 + shift + ;; *) - echo "Unknow $1" + echo "Unknown $1" exit 1 ;; esac @@ -27,4 +70,18 @@ if [[ -z "$MODEL_PATH" ]]; then exit 1 fi -env CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES="$GPU_ID" python chat.py --model_path "$MODEL_PATH" +CMD=(python "$SCRIPT_DIR/chat.py" --model_path "$MODEL_PATH" --max_new_tokens "$MAX_NEW_TOKENS") + +if [[ -n "$TOKENIZER_PATH" ]]; then + CMD+=(--tokenizer_path "$TOKENIZER_PATH") +fi + +if [[ -n "$SYSTEM_PROMPT" ]]; then + CMD+=(--system_prompt "$SYSTEM_PROMPT") +fi + +if [[ "$TRUST_REMOTE_CODE" -eq 1 ]]; then + CMD+=(--trust_remote_code) +fi + +env CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES="$GPU_ID" "${CMD[@]}" diff --git a/docs/eora/README.md b/docs/eora/README.md new file mode 100644 index 000000000..3dda19c52 --- /dev/null +++ b/docs/eora/README.md @@ -0,0 +1,97 @@ +# EoRA: Training-free Compensation for Compressed LLM with Eigenspace Low-Rank Approximation +EoRA is a training-free method that uses a calibration dataset to build low-rank matrices aimed at mitigating quantization errors and enhancing the performance of quantized models. Its generation takes about the same time as GPTQ. +For more details, please refer to the paper: https://arxiv.org/abs/2410.21271. + + +## Calibration data +EoRA’s major advantage is that it can enhance the accuracy of quantized models on various downstream tasks without training, simply by using a small amount of task-specific data as a calibration set. For instance, to improve performance on MMLU, you can employ the MMLU validation set as calibration data when generating EoRA. Additionally, EoRA can boost a quantized model’s overall quality by using the same calibration data as GPTQ. + +For examples of how to create these calibration sets, see `construct_c4` in `docs/eora/eora_calibration_data_construction.py` for a general-purpose setup using the C4 dataset, and `construct_mmlu` in the same file for task-specific calibration data. + +## EoRA generation +There are two ways to produce EoRA. The first is to generate it simultaneously with GPTQ during the quantization process. The second is to take an already GPTQ-quantized model and apply EoRA generation on top of it. + +### First option: Generate EoRA and the GPTQ model together during quantization. +Below is an example of using C4 as calibration data for generating EoRA of rank 64 alongside 4-bits GPTQ quantization of meta-llama/Llama-3.2-3B. To further improve the accuracy on MMLU, set mmlu for eora_dataset. +```shell +python docs/eora/eora_generation.py meta-llama/Llama-3.2-3B --bits 4 \ + --quant_save_path docs/eora/Llama-3.2-3B-4bits \ + --eora_dataset c4 \ + --eora_save_path docs/eora/Llama-3.2-3B-4bits-eora_rank64_c4 \ + --eora_rank 64 +``` + +### Second option: If a GPTQ model is already available, run EoRA generation directly on the quantized model. +Below is an example of using C4 as calibration data for generating EoRA of rank 64 given a 4-bits GPTQ quantized meta-llama/Llama-3.2-3B. To further improve the accuracy on MMLU, set mmlu for eora_dataset. +```shell +python docs/eora/post_quant_eora_generation.py meta-llama/Llama-3.2-3B c4 \ + --quantized_model sliuau/Llama-3.2-3B_4bits_128group_size \ + --eora_save_path docs/eora/Llama-3.2-3B-4bits-eora_rank64_c4 \ + --eora_rank 64 +``` + +## EoRA Evaluation +To evaluate the GPTQ quantized model and the corresponding EoRA on ARC-C and MMLU run: +```shell +python docs/eora/evaluation.py --quantized_model sliuau/Llama-3.2-3B_4bits_128group_size \ + --eora_save_path docs/eora/Llama-3.2-3B-4bits-eora_rank64_c4 \ + --eora_rank 64 +``` + +## EoRA Inference +Please refer to `docs/eora/eora_load_and_inference.py` for how to load EoRA and the corresponding GPTQ quantized model for inference. +A simple example: +```python + +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.adapter.adapter import Lora + +eora = Lora( + # for eora generation, path is adapter save path; for load, it is loading path + path='docs/eora/Llama-3.2-3B-4bits-eora_rank64_c4 ', + rank=64, +) + +model = GPTQModel.load( + model_id_or_path='sliuau/Llama-3.2-3B_4bits_128group_size', + adapter=eora, +) + +tokens = model.generate("Capital of France is")[0] +result = model.tokenizer.decode(tokens) +print(f"Result: {result}") +``` + +## EoRA Kernel +We are working on improving the numerical stability of our EoRA kernel which can further speedup the EoRA + GPTQ inference up to 2.5x. Stay tuned! + + +## EoRA results +We ran a series of experiments on meta-llama/Llama-3.2-3B. From the results, we see that EoRA substantially improves the accuracy of 3/4-bit quantized models on MMLU, and using the MMLU validation set as calibration data compared to using C4 can further increase MMLU accuracy. +|Model| Bit-width | EoRA Calibration Dataset | EoRA Rank | MMLU | MMLU Accuracy Boost(%) | +|---| ---| ---| ---| ---| ---| +|meta-llama/Llama-3.2-3B | Full-Precision (FP16) | - | - | 54.19 | - | +|meta-llama/Llama-3.2-3B | 4 | - | - | 24.16 | - | +|meta-llama/Llama-3.2-3B | 4 | C4 | 32 | 52.53 | 217.43%| +|meta-llama/Llama-3.2-3B | 4 | C4 | 64 | 52.49 | 217.26%| +|meta-llama/Llama-3.2-3B | 4 | C4 | 128 | 52.93 | 219.08% | +|meta-llama/Llama-3.2-3B | 4 | MMLU | 32 | 53.43 | 221.15% | +|meta-llama/Llama-3.2-3B | 4 | MMLU | 64 | 53.32 | 220.70%| +|meta-llama/Llama-3.2-3B | 4 | MMLU | 128 | 53.42 | 221.11% | +|meta-llama/Llama-3.2-3B | 3| - | - | 22.89 | - | +|meta-llama/Llama-3.2-3B | 3 | C4 | 32 | 39.08 | 170.73% | +|meta-llama/Llama-3.2-3B | 3 | C4 | 64 | 38.83 |169.64% | +|meta-llama/Llama-3.2-3B | 3 | C4 | 128 | 39.68 | 173.35%| + +In general, setting rank to 32 and use C4 as calibration data could be a good starting point when applying EoRA to improve the quantized model accuracy. + +## Citation +If you find our code useful for your research, please consider citing: +```bibtex +@article{liu2024eora, + title={EoRA: Training-free Compensation for Compressed LLM with Eigenspace Low-Rank Approximation}, + author={Liu, Shih-Yang and Yang, Huck and Wang, Chein-Yi and Fung, Nai Chit and Yin, Hongxu and Sakr, Charbel and Muralidharan, Saurav and Cheng, Kwang-Ting and Kautz, Jan and Wang, Yu-Chiang Frank and others}, + journal={arXiv preprint arXiv:2410.21271}, + year={2024} +} +``` diff --git a/docs/eora/eora_calibration_data_construction.py b/docs/eora/eora_calibration_data_construction.py new file mode 100644 index 000000000..9e189c371 --- /dev/null +++ b/docs/eora/eora_calibration_data_construction.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from datasets import load_dataset + + +def question_answering_format(question, answer): + + return f"Question: {question}\nAnswer: {answer}" + +def multiple_choices_question_answering_format(question, choices, answer): + return f"{question.strip()}\nA. {choices[0]}\nB. {choices[1]}\nC. {choices[2]}\nD. {choices[3]}\nAnswer: {answer}" + +## An example of using ARC for construting the EoRA calibration set + +def construct_c4(): + calibration_dataset = load_dataset( + "allenai/c4", + data_files="en/c4-train.00001-of-01024.json.gz", + split="train", download_mode="force_redownload" + ).select(range(1024))["text"] + return calibration_dataset + +def construct_ARC(): + nsamples = 1024 + arc_easy_calibration_dataset = load_dataset('ai2_arc', 'ARC-Easy', split='train').select(range(nsamples)) + arc_challenge_calibration_dataset = load_dataset('ai2_arc', 'ARC-Challenge', split='train').select(range(nsamples)) + dataset = [] + + for example in arc_easy_calibration_dataset: + answer = example['choices']['text'][example['choices']['label'].index(example['answerKey'])] + question = example['question'] + dataset.append(question_answering_format(question=question,answer=answer)) + + for example in arc_challenge_calibration_dataset: + answer = example['choices']['text'][example['choices']['label'].index(example['answerKey'])] + question = example['question'] + dataset.append(question_answering_format(question=question,answer=answer)) + + ## we recommend also include some examples from C4 to avoid overfitting to the downstream data + c4_dataset = load_dataset( + "allenai/c4", + data_files="en/c4-train.00001-of-01024.json.gz", + split="train" + ).select(range(nsamples))["text"] + + return dataset + c4_dataset + +def construct_mmlu(): + + mmlu_calibration_dataset = load_dataset('cais/mmlu', 'all', split='validation') + dataset = [] + for example in mmlu_calibration_dataset: + question = example['question'] + choices = example['choices'] + answer = ['A','B','C','D'][example['answer']] + dataset.append(multiple_choices_question_answering_format(question, choices, answer)) + + ## we recommend also include some examples from C4 to avoid overfitting to the downstream data + c4_dataset = load_dataset( + "allenai/c4", + data_files="en/c4-train.00001-of-01024.json.gz", + split="train" + ).select(range(1024))["text"] + + return dataset + c4_dataset diff --git a/docs/eora/eora_generation.py b/docs/eora/eora_generation.py new file mode 100644 index 000000000..eecf2df9f --- /dev/null +++ b/docs/eora/eora_generation.py @@ -0,0 +1,138 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +# -- end do not touch + + +# from models.model_test import ModelTest # noqa: E402 +from eora_calibration_data_construction import construct_c4, construct_mmlu + +from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel.adapter.adapter import Lora +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 + + +## meta-llama/Llama-3.2-1B +## meta-llama/Llama-3.2-3B +## meta-llama/Meta-Llama-3-8B +## meta-llama/Llama-3.1-8B +## meta-llama/Meta-Llama-3-70B + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, + help='Full-preicision model to load; pass `facebook/opt-X`.' + ) + parser.add_argument( + '--bits', type=int, default=4 + ) + parser.add_argument( + '--group_size', type=int, default=128 + ) + parser.add_argument( + '--quant_save_path', type=str, default=None + ) + parser.add_argument( + '--eora_dataset', type=str, choices=['c4','mmlu'], + help='calibration set for eora' + ) + parser.add_argument( + '--eora_save_path',type=str, default=None + ) + parser.add_argument( + '--eora_rank', type=int, default=64 + ) + + args = parser.parse_args() + + NATIVE_MODEL_ID = args.model + + bits = args.bits + group_size = args.group_size + desc_act = True + rank = args.eora_rank + batch_size = 1 + calibration_dataset_concat_size = 0 # disable + + if args.quant_save_path is not None: + quant_path = args.quant_save_path + else: + raise AssertionError('Please provide a save path for the quantized model') + + if args.eora_save_path is not None: + eora_path = args.eora_save_path + else: + raise AssertionError('Please provide a save path for EoRA') + + + + ## C4 for quant + calibration_dataset = construct_c4() + + eora = Lora( + # for quant, path is save path. for load, it is loading path + path=os.path.join(quant_path, eora_path), + rank=rank, + ) + + quant_config = QuantizeConfig( + bits=bits, + group_size=group_size, + desc_act=desc_act, # bitblas only supports DESC_ACT=False + adapter=eora + ) + + model = GPTQModel.load( + model_id_or_path=NATIVE_MODEL_ID, + quantize_config=quant_config, + ) + + if args.eora_dataset == "c4": + model.quantize( + calibration=calibration_dataset, + batch_size=batch_size, + calibration_concat_size=calibration_dataset_concat_size, + ) # + else: + + eora_calibration_dataset = construct_mmlu() + + model.quantize( + calibration=calibration_dataset, + batch_size=batch_size, + calibration_concat_size=calibration_dataset_concat_size, + adapter_calibration_dataset=eora_calibration_dataset + ) # + + + # EoRA adapter is saved according to Lora.path property + # if Lora.path is not set, we will save the lora as "lora.safetensors" in the same path as quant model + # You can also pass `eora_path` to `model.save()` to override this save path + model.save(quant_path) + + del model + torch_empty_cache() + + diff --git a/docs/eora/eora_load_and_inference.py b/docs/eora/eora_load_and_inference.py new file mode 100644 index 000000000..f3c2d4e24 --- /dev/null +++ b/docs/eora/eora_load_and_inference.py @@ -0,0 +1,67 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +# -- end do not touch + +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.adapter.adapter import Lora # noqa: E402 + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + '--quantized_model', type=str, + help='Quantized model to load; pass' + ) + parser.add_argument( + '--eora',type=str, default=None + ) + parser.add_argument( + '--eora_rank', type=int + ) + + + args = parser.parse_args() + + if args.eora: + eora = Lora( + # for eora generation, path is adapter save path; for load, it is loading path + path=args.eora, + rank=args.eora_rank, + ) + else: + raise AssertionError("Please provide EoRA weight") + + + model = GPTQModel.load( + model_id_or_path=args.quantized_model, + backend=BACKEND.TORCH, + adapter=eora, + ) + + tokens = model.generate("Capital of France is")[0] + result = model.tokenizer.decode(tokens) + print(f"Result: {result}") + + diff --git a/docs/eora/evaluation.py b/docs/eora/evaluation.py new file mode 100644 index 000000000..4c318f377 --- /dev/null +++ b/docs/eora/evaluation.py @@ -0,0 +1,103 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +# -- end do not touch + +from typing import Optional # noqa: E402 + +from tests.eval import evaluate, format_eval_result_table # noqa: E402 + +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.adapter.adapter import Lora # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 + + +def bench(path: str, backend: BACKEND, adapter: Optional[Lora], task): + # test post-quant inference + model = GPTQModel.load( + model_id_or_path=path, + backend=backend, + adapter=adapter, + ) + + # torch can benefit from optimization + if backend == BACKEND.TORCH: + model.optimize() + + if task == "all": + bench_result = evaluate( + model_or_id_or_path=model, + tasks=["arc_challenge", "mmlu"] + ) + elif task == "arc": + bench_result = evaluate( + model_or_id_or_path=model, + tasks=["arc_challenge"] + ) + elif task == "mmlu": + bench_result = evaluate( + model_or_id_or_path=model, + tasks=["mmlu"] + ) + + del model + torch_empty_cache() + + return bench_result + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + '--quantized_model', type=str, + help='Quantized model to load; pass' + ) + parser.add_argument( + '--eora_save_path',type=str, default=None + ) + parser.add_argument( + '--eora_rank', type=int + ) + parser.add_argument( + '--eval_task', type=str, default='all', choices=['mmlu','arc','all'] + ) + + args = parser.parse_args() + + if args.eora_save_path: + eora = Lora( + # for eora generation, path is adapter save path; for load, it is loading path + path=args.eora_save_path, + rank=args.eora_rank, + ) + + if args.eora_save_path: + eora_bench = bench(path=args.quantized_model, backend=BACKEND.TORCH, adapter=eora, task=args.eval_task) # inference using eora (lora) + print('--------Eval EoRA Result---------') + print(format_eval_result_table(eora_bench)) + + else: + base_bench = bench(path=args.quantized_model, backend=BACKEND.TORCH, adapter=None, task=args.eval_task) # inference using qweights only + + print('--------Eval Base Result---------') + print(format_eval_result_table(base_bench)) diff --git a/docs/eora/post_quant_eora_generation.py b/docs/eora/post_quant_eora_generation.py new file mode 100644 index 000000000..eaad35219 --- /dev/null +++ b/docs/eora/post_quant_eora_generation.py @@ -0,0 +1,81 @@ +# Copyright 2025 ModelCloud +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" +# -- end do not touch + + +from eora_calibration_data_construction import construct_ARC, construct_c4, construct_mmlu + +from gptqmodel import GPTQModel # noqa: E402 +from gptqmodel.adapter.adapter import Lora # noqa: E402 + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + + parser.add_argument( + 'model', type=str, + help='Full-preicision model to load; pass `facebook/opt-X`.' + ) + parser.add_argument( + 'dataset', type=str, choices=['c4','arc','mmlu'], + help='Where to extract calibration data from.' + ) + parser.add_argument( + '--quantized_model',type=str, + help='Quantized model to load' + ) + parser.add_argument( + '--eora_save_path',type=str + ) + parser.add_argument( + '--eora_rank', type=int + ) + + args = parser.parse_args() + + eora = Lora( + # for eora generation, path is adapter save path; for load, it is loading path + path=os.path.join(args.eora_save_path), + rank=args.eora_rank, + ) + + if args.dataset == "c4": + calibration_dataset = construct_c4() + elif args.dataset == "arc": + calibration_dataset = construct_ARC() + elif args.dataset == "mmlu": + calibration_dataset = construct_mmlu() + else: + raise NotImplementedError + + + # eora generation and save in one step + GPTQModel.adapter.generate( + adapter=eora, + model_id_or_path=args.model, + quantized_model_id_or_path=args.quantized_model, + calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=0, + ) + diff --git a/docs/quantization_protocol.md b/docs/quantization_protocol.md new file mode 100644 index 000000000..064ecdf2e --- /dev/null +++ b/docs/quantization_protocol.md @@ -0,0 +1,1894 @@ +# Quantization Protocol + +## Overview + +This document proposes a next-generation quantization configuration protocol for `gptqmodel`. + +The protocol is designed to be: + +- clean and concise for humans +- pipeline / stage based +- explicit about matching and override behavior +- flexible about quantization method vs exported representation +- future-proof for weight, activation, output, and KV-cache quantization + +The user-facing protocol root is intentionally shallow. It consists of: + +- `version` +- `stages` + +It may be authored through: + +- a Python DSL +- YAML / JSON serialization of the same protocol + +The Python and YAML forms below describe the same protocol. +Python is the ergonomic builder API. +YAML is the portable serialized form. + + +## Design Goals + +1. One matching system only. + Rules match model objects. Stages do not rematch. Actions do not rematch the whole model in normal use. + +2. Keep the common case short. + The common case should need only: + - `match` + - `weight` / `input` / `output` / `kv_cache` + - `prepare` + - `quantize` + - `export` + +3. Make overrides readable. + A narrower rule should be able to skip quantization, replace defaults, or stop later rules without confusing `+` / `-` syntax. + +4. Make partial overrides cheap. + A narrower rule should be able to override only `bits`, `group_size`, or another single leaf field without restating the full quantizer configuration. + +5. Separate quantization from representation. + `quantize` answers how quantized values are produced. + `export` answers how those values are encoded into final tensors and metadata. + +6. Keep backend-specific terms internal. + Terms such as `*input_quantizer`, `*weight_quantizer`, or packer-specific tensor names should not be the primary user-facing API. + + +## Protocol Root + +Python: + +```python +version = 2 + +stages = [ + Stage( + name="ptq", + rules=[ + Rule( + match="*", + aliases=None, + actions=[], + stop=False, + weight=None, + input=None, + output=None, + kv_cache=None, + ), + ], + ), +] +``` + +YAML: + +```yaml +version: 2 +stages: + - name: ptq + rules: + - match: "*" + aliases: null + actions: [] + stop: false + weight: null + input: null + output: null + kv_cache: null +``` + +A stage is an ordered execution boundary. +A rule is the only normal matcher. +Each rule may configure one or more tensor targets. + + +## Match Selectors + +`Rule.match` may be either: + +- a single selector string +- a list of selector strings + +Selector prefixes: + +- no prefix or `+:` means positive/include +- `-:` means negative/exclude + +This lets one rule express "match everything except ..." without adding a second skip rule. + +Python: + +```python +Rule( + match=["*", "-:.*layer2.*"], + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq"}, + }, +) +``` + +YAML: + +```yaml +- match: + - "*" + - "-:.*layer2.*" + weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq +``` + +Recommended semantics: + +- a rule matches if at least one positive selector matches +- any matching negative selector removes that module from the rule +- `*` is a special match-all shorthand +- every other selector string is interpreted as regex by default +- for exact module-name matches, use an anchored escaped regex such as `^model\.layers\.0\.self_attn\.q_proj$` + + +## Internal Implementation + +An implementation may compile the user-facing protocol into an internal typed object such as: + +```python +Plan(version=2, stages=[...]) +``` + +That internal root object is for parser/runtime organization. +It should not be required in user-facing examples or config files. +Normal user configs have one protocol root for one quantization run or artifact, not multiple user-facing plans. + + +## Authoring Surfaces + +The protocol should support both of these as first-class authoring surfaces: + +- Python DSL for programmatic authoring, helpers, and composition +- YAML for checked-in configs, serialization, export metadata, and non-Python tooling + +Example equivalence: + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq"}, + }, +) +``` + +YAML: + +```yaml +match: "*" +weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq +``` + + +## Stages + +A `Stage` exists to define: + +- execution order +- calibration / replay boundary +- save / emit boundary + +A stage does not introduce a second targeting system. + +Example: + +Python: + +```python +stages = [ + Stage( + name="balance", + rules=[ + Rule( + match=".*self_attn$", + actions=[smoothquant(alpha=0.5)], + ), + ], + ), + Stage( + name="ptq", + rules=[ + Rule( + match="*", + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq"}, + }, + ), + ], + ), +] +``` + +YAML: + +```yaml +stages: + - name: balance + rules: + - match: ".*self_attn$" + actions: + - method: smoothquant + alpha: 0.5 + - name: ptq + rules: + - match: "*" + weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq +``` + + +## Rules + +A `Rule` contains: + +- `match`: the only normal matcher +- optional `aliases`: named references relative to the matched object +- optional `actions`: rule-scoped operations +- optional tensor-target sections: + - `weight` + - `input` + - `output` + - `kv_cache` +- optional `stop`: stop applying later rules to the same matched object + +Fields omitted by a narrower rule inherit from earlier matching rules unless an explicit replace mode is used. + +Recommended shape: + +Python: + +```python +Rule( + match="*", + aliases=None, + actions=[], + stop=False, + weight={...}, + input={...}, + output={...}, + kv_cache={...}, +) +``` + +YAML: + +```yaml +match: "*" +aliases: null +actions: [] +stop: false +weight: {} +input: {} +output: {} +kv_cache: {} +``` + + +## Matching + +Recommended match forms: + +- exact module path: `"model.layers.0.self_attn.q_proj"` +- wildcard / glob: `"*"` or `"*.q_proj"` +- regex: `".*self_attn$"` + +Rules are evaluated top-to-bottom inside a stage. + +There is no stage-level matcher and no normal action-level global matcher. + + +## Aliases + +`aliases` is optional. + +It exists only to name reusable relative references under the matched object. +It is not a second model-wide matching language. + +Example: + +Python: + +```python +Rule( + match=".*self_attn$", + aliases={"proj": ["q_proj", "k_proj", "v_proj", "o_proj"]}, + actions=[ + record_stats(targets="@proj"), + inspect_outliers(targets="@proj"), + ], +) +``` + +YAML: + +```yaml +match: ".*self_attn$" +aliases: + proj: + - q_proj + - k_proj + - v_proj + - o_proj +actions: + - method: record_stats + targets: "@proj" + - method: inspect_outliers + targets: "@proj" +``` + +Use `aliases` only when the same relative subset must be reused. +If the action naturally operates on the matched object, omit `aliases`. + + +## Actions + +`actions` is a rule-scoped list of operations that run in the context of the rule match. + +Examples: + +- `smoothquant(alpha=0.5)` +- `awq_balance(ratio=0.7)` +- `calibrate_router(...)` + +Default behavior: + +- an action operates on the current rule match +- an action does not rematch the whole model +- if an action needs a reusable relative subset, it may use rule `aliases` +- actions should prefer the canonical structure of the matched object over user-written sub-matching in the common case + +Important: + +- `actions` are for rule-scoped or cross-target behavior +- `prepare` is for target-local pre-quant behavior + +This keeps placement clear: + +- SmoothQuant or AWQ-like balancing: `actions` +- local weight clip / pad / smoother: `weight.prepare` + + +## Tensor Targets + +The protocol supports these first-class tensor targets: + +- `weight` +- `input` +- `output` +- `kv_cache` + +This makes the protocol future-proof for: + +- weight-only quantization +- activation quantization +- output quantization +- cache quantization + + +## Target Sections + +Each tensor target may define: + +- `prepare` +- `quantize` +- `export` + +Recommended shape: + +Python: + +```python +weight={ + "prepare": [...], # optional + "quantize": ..., # optional + "export": ..., # optional +} +``` + +YAML: + +```yaml +weight: + prepa [] + quantize: null + export: null +``` + + +### `prepare` + +`prepare` is for target-local pre-quant transformations. + +Examples that belong in `weight.prepare`: + +- `clip.mad(k=2.75)` +- `clip.percentile(percentile=99.5)` +- `pad.columns(multiple=4, semantic=True)` + +Examples that belong in `input.prepare`: + +- `clamp.range(min=-6, max=6)` +- `normalize.rms(eps=1e-5)` + +Placement rule: + +- local target-only modification -> `prepare` +- cross-target or rule-context modification -> `actions` + + +### `quantize` + +`quantize` defines how the target is quantized. + +Examples: + +- `gptq(bits=4, sym=True, group_size=128)` +- `rtn(bits=4, sym=True)` +- `mxfp4(mode="dynamic", block_size=32, scale_bits=8)` +- `int8(calibration=observer("max"))` +- `skip()` + +`skip()` is target-scoped. +It does not remove `prepare` and it does not disable other targets. + +Important: + +- `quantize` is a structured object, not a replace-only scalar +- later rules may patch only specific quantizer fields +- omitted fields inherit from earlier matching rules + +Example base quantizer: + +```yaml +weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 +``` + +Example narrower override: + +```yaml +weight: + quantize: + bits: 8 +``` + +The override changes only `bits`. +It does not require restating `method`, `sym`, or `group_size`. + + +### `quantize.fallback` + +`fallback` belongs inside `quantize`. + +It is not: + +- an `action` +- a `prepare` step +- an `export` setting + +It is a quantizer-local fallback policy for methods that depend on calibration or activation evidence and may not have enough usable samples for every matched unit. + +This is the right place because fallback changes how the quantizer solves a target when evidence is insufficient. +It does not change which modules are matched, and it does not change the exported format family. + +Primary use cases: + +- GPTQ with too few Hessian / activation samples for a module +- AWQ with missing or too-sparse captured activations for a layer group +- future activation-aware weight quantizers +- MoE routing cases where some experts receive little or no calibration traffic + +Recommended protocol shape: + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": { + "method": "gptq", + "bits": 4, + "group_size": 128, + "sym": True, + "fallback": { + "strategy": "rtn", + "threshold": "0.5%", + }, + }, + "export": { + "format": "gptq", + }, + }, +) +``` + +YAML: + +```yaml +- match: "*" + weight: + quantize: + method: gptq + bits: 4 + group_size: 128 + sym: true + fallback: + strategy: rtn + threshold: 0.5% + export: + format: gptq +``` + +Recommended fields: + +- `strategy`: fallback quantization strategy +- `threshold`: minimum evidence threshold before fallback triggers +- `smooth`: optional smoothing used only inside the fallback path + +Example: + +```yaml +weight: + quantize: + method: awq + bits: 4 + group_size: 128 + fallback: + strategy: rtn + threshold: 1.0% + smooth: + type: mad + k: 2.75 +``` + +Recommended threshold semantics: + +- integer / float: absolute minimum observed samples or tokens +- percent string such as `"0.5%"`: minimum observed coverage relative to expected calibration traffic +- `true`: enable quantizer default threshold +- `false` or `null`: disable fallback + +Initial runtime contract: + +- GPTQ: evaluate fallback per matched module +- AWQ: evaluate fallback at the quantizer's natural scaling group or layer subgroup +- future methods: evaluate fallback at the quantizer's natural solve unit + +The protocol should not force one global fallback scope. +Fallback should use the quantizer's native solve scope. + +Important separation: + +- `quantize.method = gptq` with `fallback.strategy = rtn` means: + GPTQ is still the primary method +- if the module or group is under-sampled, fallback quantization uses RTN-like weight-only solving +- the rule's `export` still controls the final encoded representation + +Example: + +```yaml +weight: + quantize: + method: gptq + bits: 4 + fallback: + strategy: rtn + threshold: 0.5% + export: + format: gptq +``` + +This does not mean "export as RTN". +It means: + +- primary solve path: GPTQ +- low-evidence fallback solve path: RTN +- final export family: GPTQ + +That matches current `gptqmodel` behavior more closely than treating fallback as a separate stage or a weight-only top-level config. + +Patch-first override behavior should apply here too. + +Example base rule: + +```yaml +- match: "*" + weight: + quantize: + method: gptq + bits: 4 + group_size: 128 + fallback: + strategy: rtn + threshold: 0.5% +``` + +Example narrower MoE override: + +```yaml +- match: ".*experts\\.[0-9]+\\..*" + weight: + quantize: + fallback: + threshold: 2.0% +``` + +Effective result for expert modules: + +- `method = gptq` +- `bits = 4` +- `group_size = 128` +- `fallback.strategy = rtn` +- `fallback.threshold = 2.0%` + +This is how fallback should fit into the new protocol: + +- nested under `target.quantize` +- inherited and patchable like other quantizer fields +- supported only for quantizers that actually depend on calibration / activations +- independent from `export` + + +### `export` + +`export` defines the final encoded representation for the target. + +The canonical form of `export` should be a structured object. +String export should be treated only as shorthand for very simple cases. + +Canonical fields: + +- `format`: logical exported family such as `gptq`, `awq`, `fp8`, `fp4`, `gguf`, `native` +- `variant`: family-specific subtype such as `gemm`, `gemv`, `e4m3fn`, `e5m2`, `nvfp4`, `mxfp4`, `q4_k_m` +- `impl`: concrete exporter or runtime implementation such as `default`, `llm_awq`, `marlin`, `transformer_engine`, `modelopt` +- `version`: exporter layout / schema version +- `options`: exporter-specific knobs that should not be promoted to top-level DSL fields + +Python: + +```python +weight={ + "export": { + "format": "awq", + "variant": "gemm", + "impl": "llm_awq", + "version": 2, + }, +} +``` + +YAML: + +```yaml +weight: + export: + format: awq + variant: gemm + impl: llm_awq + version: 2 +``` + +Examples: + +- `{"format": "gptq"}` +- `{"format": "awq", "variant": "gemm"}` +- `{"format": "awq", "variant": "gemv"}` +- `{"format": "fp8", "variant": "e4m3fn", "impl": "transformer_engine"}` +- `{"format": "fp4", "variant": "nvfp4", "impl": "modelopt"}` +- `{"format": "gguf", "variant": "q4_k_m"}` + +Shorthand: + +- `"gptq"` == `{"format": "gptq"}` +- `"native"` == `{"format": "native"}` + +If omitted, the engine may use the quantizer's native export. + +Like `quantize`, `export` should be patchable by narrower rules. + +Example: + +```yaml +- match: "*" + weight: + export: + format: awq + variant: gemm + impl: llm_awq + version: 2 + +- match: ".*small_proj$" + weight: + export: + variant: gemv +``` + +Effective result for `small_proj`: + +- `format = awq` +- `variant = gemv` +- `impl = llm_awq` +- `version = 2` + + +## Patch-First Override Model + +Rules should be treated as patches over an accumulated effective configuration. + +This is the main simplification for dynamic overrides. + +Common case: + +- a broad rule defines defaults +- a narrower rule patches only the fields it wants to change +- unchanged fields inherit automatically + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq"}, + }, +) + +Rule( + match=".*down_proj$", + weight={ + "quantize": {"bits": 3}, + }, +) +``` + +YAML: + +```yaml +- match: "*" + weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + +- match: ".*down_proj$" + weight: + quantize: + bits: 3 +``` + +Effective result for `down_proj`: + +- `method = gptq` +- `bits = 3` +- `sym = true` +- `group_size = 128` +- `export.format = gptq` + +This is the intended replacement for the current `gptqmodel` dynamic override style where a base rule applies to all modules and narrower matches override only selected fields. + + +## Advanced Replace Mode + +Patch merging should be the default. +Explicit replacement should be available only as an escape hatch. + +Recommended advanced control: + +- `mode: replace` + +Python: + +```python +Rule( + match="layer0.qkv", + weight={ + "mode": "replace", + "prepare": [pad.columns(multiple=4, semantic=True)], + "quantize": skip(), + }, +) +``` + +YAML: + +```yaml +match: "layer0.qkv" +weight: + mode: replace + prepa + - method: pad.columns + multiple: 4 + semantic: true + quantize: + method: skip +``` + +`mode: replace` is advanced. +Users should not need it for normal per-layer overrides like changing only `bits`. + + +## Why `export` Is Separate From `quantize` + +These are different questions: + +1. How is the quantized state computed? +2. How is that state emitted as final tensors and metadata? + +Example: + +Python: + +```python +weight={ + "quantize": rtn(bits=4, sym=True), + "export": {"format": "gptq", "impl": "default"}, +} +``` + +YAML: + +```yaml +weight: + quantize: + method: rtn + bits: 4 + sym: true + export: + format: gptq + impl: default +``` + +This means: + +- use RTN to produce the quantized weight state +- encode that state in GPTQ-style exported tensors + +For GPTQ itself: + +Python: + +```python +weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq", "impl": "default"}, +} +``` + +YAML: + +```yaml +weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + impl: default +``` + +Here, GPTQ packing is part of export realization. +It does not need to be a separate first-class user concept. + +Conceptually: + +```text +W + -> quantize(method=GPTQ) + -> logical quantized state + -> export("gptq") + -> qweight + scales + qzeros + g_idx + metadata +``` + +So `export` is the correct user-facing property, while internal packing details remain backend implementation details. +This is also why the canonical `export` form should be an object rather than a string-only enum. + + +## Activation Quantization + +The protocol should expose activation quantization through tensor targets, not backend-internal names. + +Example: + +Python: + +```python +Rule( + match="*", + input={ + "quantize": mxfp4(mode="dynamic", block_size=32, scale_bits=8), + "export": { + "format": "fp4", + "variant": "mxfp4", + "impl": "modelopt", + }, + }, +) +``` + +YAML: + +```yaml +match: "*" +input: + quantize: + method: mxfp4 + mode: dynamic + block_size: 32 + scale_bits: 8 + export: + format: fp4 + variant: mxfp4 + impl: modelopt +``` + +This corresponds conceptually to installing an input activation quantizer on matched modules. + +In NVIDIA Model Optimizer terms, this is cleaner than exposing `*input_quantizer` directly. + +Important: + +- `input` means the activation entering the matched module +- `output` means the activation leaving the matched module +- these are tensor-target concepts, not inserted-submodule names + + +## Activation-aware GPTQ + +If `weight.quantize = gptq(...)` and `input.quantize = ...` coexist, the weight quantizer may need to know whether it should optimize using full-precision or quantized inputs. + +Recommended future-proof parameter: + +Python: + +```python +gptq( + bits=4, + sym=True, + group_size=128, + activation_mode="ignore", # or "fake", later possibly "real" +) +``` + +YAML: + +```yaml +method: gptq +bits: 4 +sym: true +group_size: 128 +activation_mode: ignore +``` + +Meaning: + +- `"ignore"`: classic weight-only GPTQ +- `"fake"`: optimize with fake-quantized inputs active +- `"real"`: reserved for future real low-bit activation flow + + +## Merge And Override Semantics + +Rules compose top-to-bottom within a stage. + +Recommended semantics: + +- broader rules define defaults +- narrower rules refine or override +- target sections merge recursively by default +- target-local lists such as `prepare` append by default +- quantizer leaf fields such as `bits`, `sym`, and `group_size` are last-match-wins +- exporter leaf fields such as `format`, `variant`, `impl`, and `version` are last-match-wins +- `skip()` is target-scoped +- `stop=True` prevents later rules from changing the same matched object +- if `quantize.method` changes, previous quantizer-specific fields should be discarded unless explicitly repeated +- if `export.format` changes, previous format-specific export fields should be discarded unless explicitly repeated + +Example: global default plus narrow skip + +Python: + +```python +Rule( + match="*", + weight={ + "prepare": [pad.columns(multiple=4, semantic=True)], + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq", "impl": "default"}, + }, + input={ + "quantize": mxfp4(mode="dynamic", block_size=32, scale_bits=8), + "export": { + "format": "fp4", + "variant": "mxfp4", + "impl": "modelopt", + }, + }, +) + +Rule( + match="layer0.qkv", + weight={ + "quantize": skip(), + }, +) +``` + +YAML: + +```yaml +- match: "*" + weight: + prepa + - method: pad.columns + multiple: 4 + semantic: true + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + impl: default + input: + quantize: + method: mxfp4 + mode: dynamic + block_size: 32 + scale_bits: 8 + export: + format: fp4 + variant: mxfp4 + impl: modelopt + +- match: "layer0.qkv" + weight: + quantize: + method: skip +``` + +Effective result for `layer0.qkv`: + +- keep `weight.prepare = [pad.columns(...)]` +- skip `weight.quantize` +- keep default `input.quantize = mxfp4(...)` + +Example: base config plus bits-only override + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq", "impl": "default"}, + }, +) + +Rule( + match=".*(q_proj|k_proj)$", + weight={ + "quantize": {"bits": 8}, + }, +) +``` + +YAML: + +```yaml +- match: "*" + weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + impl: default + +- match: ".*(q_proj|k_proj)$" + weight: + quantize: + bits: 8 +``` + +Effective result for `q_proj` and `k_proj`: + +- `method = gptq` +- `bits = 8` +- `sym = true` +- `group_size = 128` + +Example: explicit replacement plus stop + +Python: + +```python +Rule( + match="layer0.qkv", + stop=True, + weight={ + "mode": "replace", + "prepare": [pad.columns(multiple=4, semantic=True)], + "quantize": skip(), + }, +) +``` + +YAML: + +```yaml +match: "layer0.qkv" +stop: true +weight: + mode: replace + prepa + - method: pad.columns + multiple: 4 + semantic: true + quantize: + method: skip +``` + +This means: + +- replace inherited weight config with only the fields given here +- do not let later rules change `layer0.qkv` + + +## Execution Semantics + +Within a stage, recommended engine order is: + +1. evaluate rules in order +2. resolve matches +3. resolve optional aliases +4. run rule `actions` +5. run target `prepare` +6. collect calibration / replay data required by quantizers +7. run target `quantize` +8. run target `export` +9. emit stage outputs + +Stage order then defines the full pipeline order. + + +## How Rule Actions And Target Config Work Together + +This is the intended pattern: + +Python: + +```python +Rule( + match=".*self_attn$", + actions=[smoothquant(alpha=0.5)], +) + +Rule( + match="*", + weight={ + "prepare": [clip.mad(k=2.75)], + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq", "impl": "default"}, + }, + input={ + "quantize": mxfp4(mode="dynamic", block_size=32, scale_bits=8), + "export": { + "format": "fp4", + "variant": "mxfp4", + "impl": "modelopt", + }, + }, +) +``` + +YAML: + +```yaml +- match: ".*self_attn$" + actions: + - method: smoothquant + alpha: 0.5 + +- match: "*" + weight: + prepa + - method: clip.mad + k: 2.75 + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + impl: default + input: + quantize: + method: mxfp4 + mode: dynamic + block_size: 32 + scale_bits: 8 + export: + format: fp4 + variant: mxfp4 + impl: modelopt +``` + +Meaning: + +- each matched attention block first receives the `smoothquant(...)` action in its own rule context +- later, the global rule supplies default weight and input quantization policy +- submodules that the action touches inside those attention blocks still receive the global defaults from the later rule + +The user does not need to add action-level rematching such as `on=[...]` in the normal case. +The rule already defines the scope. + + +## Recommended Authoring Patterns + +### 1. Global weight default + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq", "impl": "default"}, + }, +) +``` + +YAML: + +```yaml +match: "*" +weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + impl: default +``` + +### 2. Weight and input together + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq", "impl": "default"}, + }, + input={ + "quantize": mxfp4(mode="dynamic", block_size=32, scale_bits=8), + "export": { + "format": "fp4", + "variant": "mxfp4", + "impl": "modelopt", + }, + }, +) +``` + +YAML: + +```yaml +match: "*" +weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + impl: default +input: + quantize: + method: mxfp4 + mode: dynamic + block_size: 32 + scale_bits: 8 + export: + format: fp4 + variant: mxfp4 + impl: modelopt +``` + +### 3. Rule-scoped balancing action + +Python: + +```python +Rule( + match=".*self_attn$", + actions=[smoothquant(alpha=0.5)], +) +``` + +YAML: + +```yaml +match: ".*self_attn$" +actions: + - method: smoothquant + alpha: 0.5 +``` + +### 4. Skip weight quantization but keep other defaults + +Python: + +```python +Rule( + match="layer0.qkv", + weight={ + "quantize": skip(), + }, +) +``` + +YAML: + +```yaml +match: "layer0.qkv" +weight: + quantize: + method: skip +``` + +### 5. Quantize with one method, export as another format + +Python: + +```python +Rule( + match=".*down_proj$", + weight={ + "quantize": rtn(bits=4, sym=True), + "export": {"format": "gptq", "impl": "default"}, + }, +) +``` + +YAML: + +```yaml +match: ".*down_proj$" +weight: + quantize: + method: rtn + bits: 4 + sym: true + export: + format: gptq + impl: default +``` + +### 6. Override only one quantizer field + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": gptq(bits=4, sym=True, group_size=128), + "export": {"format": "gptq", "impl": "default"}, + }, +) + +Rule( + match="model.layers.0.self_attn.q_proj", + weight={ + "quantize": {"bits": 2}, + }, +) +``` + +YAML: + +```yaml +- match: "*" + weight: + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + export: + format: gptq + impl: default + +- match: "model.layers.0.self_attn.q_proj" + weight: + quantize: + bits: 2 +``` + +### 7. Override only one exporter field + +Python: + +```python +Rule( + match="*", + weight={ + "quantize": awq(bits=4, sym=True, group_size=128), + "export": { + "format": "awq", + "variant": "gemm", + "impl": "llm_awq", + "version": 2, + }, + }, +) + +Rule( + match=".*small_proj$", + weight={ + "export": {"variant": "gemv"}, + }, +) +``` + +YAML: + +```yaml +- match: "*" + weight: + quantize: + method: awq + bits: 4 + sym: true + group_size: 128 + export: + format: awq + variant: gemm + impl: llm_awq + version: 2 + +- match: ".*small_proj$" + weight: + export: + variant: gemv +``` + + +## Real Test-Derived Examples + +The examples below are translations of real repo tests into the proposed protocol. +They preserve the tested quantization intent, but they do not try to mirror every harness detail such as evaluation tasks, prompt text, or temporary save paths. + + +### 1. GPTQ with per-module overrides + +Source tests: + +- `tests/test_dynamic.py` + +Current tested behavior: + +- base quantization is 4-bit GPTQ +- base group size is 128 +- `up_proj` and `gate_proj` are overridden to 8-bit +- `down_proj` keeps 4-bit but overrides group size to 32 + +Python: + +```python +Stage( + name="ptq", + rules=[ + Rule( + match="*", + weight={ + "quantize": { + "method": "gptq", + "bits": 4, + "group_size": 128, + }, + "export": { + "format": "gptq", + "impl": "default", + }, + }, + ), + Rule( + match=".*\\.up_proj.*", + weight={ + "quantize": {"bits": 8}, + }, + ), + Rule( + match=".*\\.gate_proj.*", + weight={ + "quantize": {"bits": 8}, + }, + ), + Rule( + match=".*\\.down_proj.*", + weight={ + "quantize": {"bits": 4, "group_size": 32}, + }, + ), + ], +) +``` + +YAML: + +```yaml +stages: + - name: ptq + rules: + - match: "*" + weight: + quantize: + method: gptq + bits: 4 + group_size: 128 + export: + format: gptq + impl: default + - match: ".*\\.up_proj.*" + weight: + quantize: + bits: 8 + - match: ".*\\.gate_proj.*" + weight: + quantize: + bits: 8 + - match: ".*\\.down_proj.*" + weight: + quantize: + bits: 4 + group_size: 32 +``` + +This is the clearest example of why the protocol uses patch-first override semantics. +The narrower rules change only the leaf fields they care about. + + +### 2. AWQ GEMM full-model quantization + +Source tests: + +- `tests/test_awq.py` +- `tests/models/awq/test_llama3_2.py` +- `tests/models/model_test.py` + +Current tested behavior: + +- method is AWQ +- bits are 4 +- group size is 128 +- one tested export target is AWQ GEMM +- one tested runtime backend is `BACKEND.TORCH_AWQ` + +Python: + +```python +Stage( + name="ptq", + rules=[ + Rule( + match="*", + weight={ + "quantize": { + "method": "awq", + "bits": 4, + "group_size": 128, + "sym": True, + }, + "export": { + "format": "awq", + "variant": "gemm", + }, + }, + ), + ], +) +``` + +YAML: + +```yaml +stages: + - name: ptq + rules: + - match: "*" + weight: + quantize: + method: awq + bits: 4 + group_size: 128 + sym: true + export: + format: awq + variant: gemm +``` + +In `tests/test_awq.py`, the same pattern is also exercised with other export variants such as: + +- `gemv` +- `gemv_fast` +- `llm_awq` + +That is exactly why `export` needs to be an object and not just a single string token. + + +### 3. RTN with weight smoothing and AWQ GEMM export + +Source tests: + +- `tests/test_weight_only_config.py` + +Note: + +- the repo currently exercises RTN primarily through weight-only/config and format-conversion tests rather than a `tests/models/*` full-model test case + +Current tested behavior: + +- method is RTN +- bits are 4 +- group size is 128 +- smoothing uses `SmoothMAD(k=1.5)` +- export target is AWQ GEMM + +Python: + +```python +Stage( + name="weight_only", + rules=[ + Rule( + match="*", + weight={ + "prepare": [ + {"method": "smooth.mad", "k": 1.5}, + ], + "quantize": { + "method": "rtn", + "bits": 4, + "group_size": 128, + }, + "export": { + "format": "awq", + "variant": "gemm", + }, + }, + ), + ], +) +``` + +YAML: + +```yaml +stages: + - name: weight_only + rules: + - match: "*" + weight: + prepa + - method: smooth.mad + k: 1.5 + quantize: + method: rtn + bits: 4 + group_size: 128 + export: + format: awq + variant: gemm +``` + +Related repo test: + +- `tests/test_format_conversion_flow.py` also verifies that RTN can export to GPTQ format with `RTNConfig(bits=4, format=FORMAT.GPTQ, offload_to_disk=False)` + +That is a concrete existing example of the protocol's `quantize != export` split. + + +## Migration From Current `gptqmodel` + +Current `gptqmodel` configuration is primarily weight-centric. + +A straightforward migration path is: + +- base `QuantizeConfig` -> broad default rule +- `dynamic` positive override -> narrower rule +- partial `dynamic` override fields -> quantizer patch fields such as `weight.quantize.bits` +- `-:` negative skip -> target-scoped `skip()` +- smoothers / preprocessors -> `weight.prepare` +- method vs output representation split -> `quantize` vs `export` + +This keeps current intent while making the protocol ready for activation and cache quantization. + + +## Non-Goals + +The protocol should not make these primary user concepts: + +- inserted quantizer submodule names +- internal packer tensor names +- stage-level rematching +- action-level global rematching in normal usage + + +## Final Shape + +The recommended protocol shape is: + +Python: + +```python +version = 2 + +stages = [ + Stage( + name="ptq", + rules=[ + Rule( + match=".*self_attn$", + actions=[smoothquant(alpha=0.5)], + ), + Rule( + match="*", + weight={ + "prepare": [clip.mad(k=2.75)], + "quantize": gptq( + bits=4, + sym=True, + group_size=128, + activation_mode="fake", + ), + "export": {"format": "gptq", "impl": "default"}, + }, + input={ + "quantize": mxfp4( + mode="dynamic", + block_size=32, + scale_bits=8, + ), + "export": { + "format": "fp4", + "variant": "mxfp4", + "impl": "modelopt", + }, + }, + ), + Rule( + match="layer0.qkv", + weight={ + "quantize": skip(), + }, + ), + ], + ), +] +``` + +YAML: + +```yaml +version: 2 +stages: + - name: ptq + rules: + - match: ".*self_attn$" + actions: + - method: smoothquant + alpha: 0.5 + - match: "*" + weight: + prepa + - method: clip.mad + k: 2.75 + quantize: + method: gptq + bits: 4 + sym: true + group_size: 128 + activation_mode: fake + export: + format: gptq + impl: default + input: + quantize: + method: mxfp4 + mode: dynamic + block_size: 32 + scale_bits: 8 + export: + format: fp4 + variant: mxfp4 + impl: modelopt + - match: "layer0.qkv" + weight: + quantize: + method: skip +``` + +This keeps the model concise: + +- one matcher +- optional aliases +- rule-scoped actions +- first-class tensor targets +- local `prepare` +- explicit `quantize` +- explicit `export` +- readable override and stop semantics diff --git a/docs/qwen35_vllm_comparison.md b/docs/qwen35_vllm_comparison.md new file mode 100644 index 000000000..b597dc950 --- /dev/null +++ b/docs/qwen35_vllm_comparison.md @@ -0,0 +1,209 @@ +# Qwen3.5 vLLM Comparison: Patched GPTQ-Pro vs Vanilla vLLM + +## Scope + +This note compares the previously validated GPTQ-Pro Qwen3.5 deployment path against +**vanilla vLLM 0.17.0** on the same `wangzhang/Qwen3.5-4B-abliterated` model family: + +- original BF16 model: `wangzhang/Qwen3.5-4B-abliterated` +- plain GPTQ 4-bit g128: `/tmp/Qwen3.5-4B-abliterated-GPTQ-4bit` +- GPTQ-Pro 4-bit g128: `/tmp/Qwen3.5-4B-abliterated-GPTQ-Pro-4bit` + +The goal was to answer a simple question: **does unmodified vLLM deploy these exact same +Qwen3.5 checkpoints, and how does that compare with the patched GPTQ-Pro path already benchmarked?** + +## Environment + +- vLLM: `0.17.0` +- Transformers: `5.3.0` +- PyTorch: `2.10.0` +- GPU test target: 1x RTX 3090 for smoke deployment, 1x/2x RTX 3090 for the prior GPTQ-Pro benchmark + +For the vanilla vLLM smoke tests, I set `VLLM_PLUGINS=""` only to suppress an unrelated +environment-level plugin load failure (`reap`). No Qwen-specific monkeypatches were applied. + +## Quantization Time Summary + +| Artifact | Quantization method | Quantization time | +|----------|---------------------|-------------------| +| Plain GPTQ 4-bit g128 | `QuantizeConfig(bits=4, group_size=128)` | `181.4s` | +| GPTQ-Pro 4-bit g128 | `QuantizeConfig.gptq_pro(bits=4, group_size=128)` | `324.9s` | + +GPTQ-Pro took `143.5s` longer than the plain GPTQ run, a `1.79x` quantization-time increase. + +## Quality Snapshot + +All perplexity numbers below were measured on WikiText-2 test with the same sliding-window setup: +`max_length=2048`, `stride=512`, `578` windows, `297,053` tokens total. + +| Artifact | Perplexity | +|----------|------------| +| Original BF16 | `8.3116` | +| Plain GPTQ 4-bit g128 | `8.6759` | +| GPTQ-Pro 4-bit g128 | `8.6314` | + +GPTQ-Pro recovered `0.0445` PPL relative to plain GPTQ while remaining in a GPTQ-compatible +checkpoint format. + +## Vanilla vLLM Deployment Result + +I attempted to deploy all three artifacts with **unmodified vLLM 0.17.0** by initializing +`vllm.LLM(..., language_model_only=True)` and requesting a short greedy generation. + +| Artifact | Vanilla vLLM result | Time to failure | First blocking error | +|----------|---------------------|-----------------|----------------------| +| Original BF16 | Failed before first token | `11.16s` | `TypeError: Invalid type of HuggingFace config ... expected Qwen3_5Config, found Qwen3_5TextConfig` | +| Plain GPTQ 4-bit g128 | Failed before first token | `6.20s` | same config-type mismatch | +| GPTQ-Pro 4-bit g128 | Failed before first token | `6.29s` | same config-type mismatch | + +### Key finding + +For this environment and vLLM version, **vanilla vLLM does not deploy any of the tested +Qwen3.5 text-only checkpoints**, regardless of whether the model is original BF16, plain GPTQ, +or GPTQ-Pro. The failure occurs before model execution and is therefore **not** caused by the +quantization format itself. + +## Patched GPTQ-Pro vLLM Result + +The previously validated GPTQ-Pro benchmark used a temporary runtime patch that: + +- wraps the HF `qwen3_5_text` config in vLLM's `Qwen3_5Config` +- forces `language_model_only=True` +- skips vision / multimodal initialization +- remaps `model.*` checkpoint prefixes to `language_model.model.*` + +With that patch, vLLM selected `gptq_marlin` and completed generation successfully. + +| GPU config | max_new_tokens | Tokens/sec | Engine init | +|------------|----------------|------------|-------------| +| 1x RTX 3090 | `128` | `175.21` | `37.03s` | +| 1x RTX 3090 | `256` | `178.14` | `37.03s` | +| 2x RTX 3090 | `128` | `194.20` | `56.53s` | +| 2x RTX 3090 | `256` | `206.53` | `56.53s` | + +## Comparison Summary + +| Dimension | Vanilla vLLM 0.17.0 | Patched GPTQ-Pro vLLM | +|-----------|----------------------|------------------------| +| Original BF16 deploys | No | not benchmarked in the validated harness | +| Plain GPTQ deploys | No | not re-validated in the benchmark harness | +| GPTQ-Pro deploys | No | Yes | +| GPTQ-Pro throughput on 1x 3090 | N/A | `175.21-178.14 tok/s` | +| GPTQ-Pro throughput on 2x 3090 | N/A | `194.20-206.53 tok/s` | +| Main blocker | `Qwen3_5TextConfig` vs `Qwen3_5Config` mismatch | patched around | +| Readiness | blocked upstream for this model family | usable for GPTQ-Pro benchmarking, but still hacky | + +## Interpretation + +1. **The limiting factor is upstream Qwen3.5 text-only support in vLLM, not GPTQ-Pro.** + Vanilla vLLM fails on the original model and on both quantized checkpoints with the same + config-type error. + +2. **GPTQ-Pro remains the best quantized artifact tested here.** + It improved PPL over plain GPTQ (`8.6314` vs `8.6759`) while keeping GPTQ-compatible output + that can be consumed by the patched Marlin path. + +3. **Quantization quality costs extra offline time.** + GPTQ-Pro took `324.9s` to quantize versus `181.4s` for plain GPTQ, but that extra cost bought + better perplexity retention. + +4. **Vanilla-vLLM benchmarking is currently impossible for these exact Qwen3.5 checkpoints.** + Since vanilla vLLM never reaches first token, there is no apples-to-apples throughput number to + compare directly against the patched GPTQ-Pro benchmark. + +## Bottom Line + +For `wangzhang/Qwen3.5-4B-abliterated` and its two local 4-bit derivatives, the comparison is: + +- **vanilla vLLM 0.17.0:** cannot deploy any of them in this environment +- **patched vLLM GPTQ-Pro path:** deploys and benchmarks successfully, reaching up to + `206.53 tok/s` on `2x RTX 3090` + +If the goal is production deployment without custom runtime patching, the blocker is still +upstream vLLM support for `qwen3_5_text`. + +## Follow-up: local `lukey03/Qwen3.5-9B-abliterated` GPTQ-Pro serve path + +The later local validation on `lukey03/Qwen3.5-9B-abliterated` answered the +more operational question for this repository: **why did the first 9B benchmark +look slow, and what does the corrected vLLM path actually do?** + +### Quantization / quality snapshot + +| Artifact | Quantization time | Perplexity | +|----------|-------------------|------------| +| Original BF16 | n/a | `8.8980` | +| GPTQ-Pro 4-bit g128 | `415.2s` | `9.2119` | + +### Why the earlier speed result looked bad + +The original throughput test for this 9B checkpoint used +`GPTQModel.load(...).generate()` on the Triton / Transformers path, not the +intended `vLLM` + `gptq_marlin` runtime. That is why the first numbers were: + +| Runtime path | 1 GPU | 2 GPU | +|--------------|-------|-------| +| `GPTQModel.generate()` diagnostic path | `15.61 tok/s` | `10.44 tok/s` | + +Those numbers were useful for diagnosis, but they were **not** measurements of +the optimized serving path for GPTQ-Pro artifacts in this repository. + +### Corrected vLLM deployment result + +The local wrapper / patch stack now does all of the following for +`qwen3_5_text` checkpoints: + +- forces text-only `Qwen3.5` serving settings +- keeps the checkpoint on `Qwen3_5ForCausalLM` +- restores the hybrid + M-RoPE interfaces required for Qwen3.5's hybrid + attention / linear-attention cache layout +- patches vLLM's Python-side NVML enumeration to the selected visible GPUs +- installs a small `LD_PRELOAD` NVML remap shim so NCCL tensor-parallel startup + can ignore a broken physical GPU during topology discovery + +With that corrected path, vLLM resolves to `Qwen3_5ForCausalLM`, converts the +checkpoint to `gptq_marlin`, and uses `MarlinLinearKernel`. + +| Runtime path | Notes | Tokens/sec | +|--------------|-------|------------| +| vLLM + `gptq_marlin` on `1x RTX 3090` | warmed 64-token completion | `104.96 tok/s` | +| vLLM + `gptq_marlin` on `2x RTX 3090` | warmed 64-token completion, `tensor_parallel_size=2`, `gpu_memory_utilization=0.4` on the shared host | `154.26 tok/s` | + +### Operational caveats observed on the 9B run + +1. **Cold-start requests are much slower than steady-state requests.** + The first request after startup pays for `torch.compile` and CUDA-graph + capture. On this host, that made the first completion look dramatically + slower than the second warmed run. + +2. **The 2-GPU tensor-parallel path is now functionally fixed, but host + conditions still matter.** + The local NCCL/NVML shim was required because one physical GPU on the host + returns `NVMLError_Unknown` during topology discovery. After that was fixed, + a separate shared-host VRAM constraint still required lowering + `gpu_memory_utilization` from the default `0.9` to `0.4`. + +3. **TP=2 scaling is still below ideal on this host.** + vLLM reported that custom all-reduce was disabled because GPU P2P capability + was unavailable or the P2P test failed on the selected GPU pair. That + explains why `2x GPU` improved throughput versus `1x GPU`, but not by a full + `2x`. + +4. **The standalone `gptq_pro` CUDA scaffold is still not the serving kernel.** + The production-fast runtime for these GPTQ-Pro checkpoints remains + `vLLM` + `gptq_marlin`. The separate `gptqmodel_ext/gptq_pro/` scaffold is + validated for Ampere correctness, but it is not yet wired into Python + inference dispatch. + +### Remaining Ampere work after the runtime fixes + +The remaining Ampere-specific kernel work is still performance engineering, not +a correctness blocker: + +- swizzled `ldmatrix` path +- real `cp.async` pipelining +- larger multi-warp tiles +- coalesced epilogue / transpose store path +- Paro metadata integration +- INT8 sibling / rescue path +- Nsight-guided tuning diff --git a/docs/qwen35_vllm_launch.md b/docs/qwen35_vllm_launch.md new file mode 100644 index 000000000..88a9fb7ff --- /dev/null +++ b/docs/qwen35_vllm_launch.md @@ -0,0 +1,291 @@ +# Launching Qwen 3.5 GPTQ / GPTQ-Pro checkpoints with the same vLLM path that hit ~100+ tok/s + +This document shows the **actual serving path** used for the `qwen3_5_text` checkpoints that reached: + +- **104.96 tok/s** on **1x RTX 3090** +- **154.26 tok/s** on **2x RTX 3090** (`tensor_parallel_size=2`, `gpu_memory_utilization=0.4` on the shared host) + +These numbers were **not** from `GPTQModel.generate()`. +They came from **vLLM + `gptq_marlin`**, launched through this repo's wrapper: + +- `scripts/serve_vllm_qwen35.py` + +## Why this wrapper exists + +For the `qwen3_5_text` checkpoints used in this repo, calling plain `vllm serve ...` directly was not the reliable path. +The wrapper exists because it: + +1. forces the required **text-only Qwen 3.5** serving settings +2. patches vLLM's startup so `qwen3_5_text` stays on the **causal-LM** path +3. restores the hybrid / M-RoPE interfaces expected by Qwen 3.5 text checkpoints +4. remaps vLLM's NVML scan to the GPUs already selected by `CUDA_VISIBLE_DEVICES` +5. installs a small `LD_PRELOAD` NVML shim so NCCL startup can ignore a broken physical GPU on shared hosts + +If you want the same runtime path that produced the good numbers, **use the wrapper, not raw `vllm serve`**. + +--- + +## Prerequisites + +From the repo root: + +```bash +cd /home/op/GPTQ-Pro +``` + +Activate the environment used for the repo's vLLM work: + +```bash +conda activate gptq-pro-vllm +``` + +Or, if you built the repo another way, make sure the environment contains: + +- editable install of this repo +- `vllm` +- CUDA working normally +- `nvidia-smi` available + +## Supported checkpoint type for this path + +This launch flow is for **local Hugging Face / Safetensors-style Qwen 3.5 text checkpoints**, especially: + +- original local Qwen 3.5 text checkpoints +- GPTQ checkpoints +- GPTQ-Pro checkpoints that still export standard GPTQ-compatible weights for vLLM / Marlin + +A quick sanity check is: + +```bash +jq -r '.model_type' /path/to/model/config.json +``` + +For the problematic local family documented in this repo, that typically returns: + +```bash +qwen3_5_text +``` + +--- + +## The known-good launch pattern + +## 1) Single-GPU launch + +This is the path corresponding to the warmed **~104.96 tok/s** result on **1x RTX 3090**. + +```bash +cd /home/op/GPTQ-Pro +conda activate gptq-pro-vllm + +CUDA_VISIBLE_DEVICES=0 \ +python scripts/serve_vllm_qwen35.py \ + --model /home/op/outputs/lukey03-Qwen3.5-9B-abliterated-gptq-pro-w4g128 \ + --served-model-name qwen35-9b-gptq-pro \ + --host 0.0.0.0 \ + --port 8011 \ + --tensor-parallel-size 1 +``` + +### Notes + +- `CUDA_VISIBLE_DEVICES=0` selects the single physical GPU to use. +- The wrapper auto-enables the text-only Qwen 3.5 settings for local folders and Hub repo IDs whose config resolves to `qwen3_5_text`. +- You do **not** need to manually set `language_model_only=True` through the CLI; the wrapper handles it. + +## 2) Two-GPU launch + +This is the path corresponding to the warmed **~154.26 tok/s** result on **2x RTX 3090** on the shared host. + +```bash +cd /home/op/GPTQ-Pro +conda activate gptq-pro-vllm + +CUDA_VISIBLE_DEVICES=0,1 \ +python scripts/serve_vllm_qwen35.py \ + --model /home/op/outputs/lukey03-Qwen3.5-9B-abliterated-gptq-pro-w4g128 \ + --served-model-name qwen35-9b-gptq-pro \ + --host 0.0.0.0 \ + --port 8012 \ + --tensor-parallel-size 2 \ + --gpu-memory-utilization 0.4 +``` + +### Notes + +- On the shared host used for the documented run, `--gpu-memory-utilization 0.4` was necessary. +- If your machine is cleaner / less VRAM-constrained, you may be able to raise that value. +- If P2P is unavailable or fails on the selected GPU pair, TP=2 can still work but scaling will be worse than ideal. + +--- + +## How to verify you are on the fast path + +When startup is successful, the logs should show the wrapper behavior and the Marlin backend selection. + +Look for lines like: + +```text +Auto-configured qwen3_5_text serving: language_model_only=True ... +Patched vLLM NVML enumeration to visible physical GPU ids: ... +Enabled NVML LD_PRELOAD shim for NCCL-visible physical GPU remapping. +Using MarlinLinearKernel for GPTQMarlinLinearMethod +``` + +That last line is the important one: + +```text +Using MarlinLinearKernel for GPTQMarlinLinearMethod +``` + +If you do **not** see `MarlinLinearKernel`, you are probably not on the same runtime path that produced the ~100+ tok/s numbers. + +--- + +## Warmup and test requests + +After the server is up, confirm it answers: + +```bash +curl -s http://127.0.0.1:8011/v1/models | jq +``` + +Then send a small completion request: + +```bash +curl -s http://127.0.0.1:8011/v1/chat/completions \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "qwen35-9b-gptq-pro", + "messages": [{"role": "user", "content": "Say hello in one sentence."}], + "temperature": 0, + "max_tokens": 64 + }' | jq +``` + +### Important: first request is not representative + +The first request after startup can be **much slower** than steady state because vLLM may still be paying for: + +- `torch.compile` +- CUDA graph capture +- internal warmup / graph specialization + +If you care about throughput, ignore the first hit and measure again after one or two warm requests. + +--- + +## A minimal warmed benchmark loop + +This is a simple way to reproduce the same kind of warmed measurement used in the notes. + +```bash +python - <<'PY' +import requests, time + +url = 'http://127.0.0.1:8011/v1/chat/completions' +payload = { + 'model': 'qwen35-9b-gptq-pro', + 'messages': [{'role': 'user', 'content': 'Explain general relativity briefly.'}], + 'temperature': 0, + 'max_tokens': 64, +} + +for i in range(3): + t0 = time.perf_counter() + r = requests.post(url, json=payload, timeout=300) + r.raise_for_status() + dt = time.perf_counter() - t0 + data = r.json() + usage = data.get('usage', {}) + out_toks = usage.get('completion_tokens') + print(f'run={i+1} seconds={dt:.3f} completion_tokens={out_toks}') + if out_toks: + print(f'tok/s={out_toks/dt:.2f}') +PY +``` + +Use the **later warmed runs**, not the first one, as the meaningful throughput number. + +--- + +## Troubleshooting + +## 1) vLLM starts but you do not see `MarlinLinearKernel` + +Check: + +- that the model is a GPTQ-compatible checkpoint +- that you launched through `scripts/serve_vllm_qwen35.py` +- that the checkpoint is not forcing some fallback path +- that the model shape is compatible with the selected tensor-parallel setting + +Useful grep: + +```bash +rg -n "MarlinLinearKernel|gptq_marlin|Qwen3_5" server.log +``` + +## 2) Plain `vllm serve` fails on `qwen3_5_text` + +That is exactly why this wrapper exists. +Use: + +```bash +python scripts/serve_vllm_qwen35.py ... +``` + +instead of: + +```bash +vllm serve ... +``` + +## 3) NCCL / NVML crashes on a host with a bad physical GPU + +Use `CUDA_VISIBLE_DEVICES=...` and launch through the wrapper. +The wrapper patches Python-side NVML enumeration and also sets up the local `LD_PRELOAD` shim for NCCL-visible GPU remapping. + +## 4) TP=2 is slower than expected + +That can still happen even when the path is working. +In the documented run, TP=2 helped, but not by a full 2x, because: + +- P2P / custom all-reduce was not fully available on the selected pair +- the host was shared and VRAM-constrained +- `--gpu-memory-utilization` had to be lowered to `0.4` + +## 5) First token is slow but later throughput is fine + +Normal for this setup. +Measure **steady state**, not cold start. + +--- + +## Operational recommendations + +- Prefer **1 GPU** first to verify the model loads and selects `MarlinLinearKernel`. +- Only then move to **TP=2**. +- Keep a copy of the server log when testing new checkpoints. +- Treat the wrapper as the canonical launch path for local `qwen3_5_text` models in this repo. + +--- + +## Related files + +- wrapper server entrypoint: `scripts/serve_vllm_qwen35.py` +- vLLM Qwen 3.5 shim: `scripts/vllm_qwen35_shim.py` +- startup patch hooks: `scripts/sitecustomize.py` +- benchmark / comparison notes: `docs/qwen35_vllm_comparison.md` + +## Summary + +If you want the **same vLLM path that achieved ~100+ tok/s**, launch like this: + +- **use `python scripts/serve_vllm_qwen35.py`** +- **set `CUDA_VISIBLE_DEVICES` explicitly** +- **use `--tensor-parallel-size 1` or `2` as needed** +- on the shared 2-GPU host, use **`--gpu-memory-utilization 0.4`** +- confirm the logs show **`Using MarlinLinearKernel for GPTQMarlinLinearMethod`** + +That is the serving path to reproduce, not `GPTQModel.generate()` and not raw `vllm serve`. diff --git a/docs/torch_fused_int4_transformations.md b/docs/torch_fused_int4_transformations.md index 13636ff5d..22193de7b 100644 --- a/docs/torch_fused_int4_transformations.md +++ b/docs/torch_fused_int4_transformations.md @@ -7,7 +7,7 @@ # Torch Fused INT4 Transformations -This note explains what `TorchFusedQuantLinear.transform_xpu` and `transform_cpu` +This note explains what `TorchFusedLinear.transform_xpu` and `transform_cpu` do to GPTQ-format tensors before calling the fused `torch.ops.aten` kernels. The goal is to document the exact tensor shapes, the axis permutations, and the bit packing order expected by `aten._weight_int4pack_mm_*` so you do not need to @@ -231,7 +231,7 @@ in the final storage type, accompanying metadata, and fused operator ABI. ## AWQ compatibility (`torch_fused_awq.py`) -`TorchFusedAwqQuantLinear` (`gptqmodel/nn_modules/qlinear/torch_fused_awq.py`) +`TorchFusedAwqLinear` (`gptqmodel/nn_modules/qlinear/torch_fused_awq.py`) reuses the CPU fused kernel while accepting checkpoints emitted by the AWQ tooling. The module always expects `qweight` to be stored in the AWQ layout `[in_features, out_features / pack_factor]`, meaning each row corresponds to a @@ -263,7 +263,7 @@ the standard CPU packing runs: `_weight_int4pack_mm_for_cpu` receives the same affine parameters the AWQ calibration solved for. -Because the shim runs entirely on the CPU path, `TorchFusedAwqQuantLinear` +Because the shim runs entirely on the CPU path, `TorchFusedAwqLinear` currently raises `NotImplementedError` when asked to run the fused transform on `xpu` devices. If the module has not been transformed yet (or fused ops are unavailable), inference falls back to the dense AWQ matmul computed by diff --git a/environment.yml b/environment.yml new file mode 100644 index 000000000..7ae467565 --- /dev/null +++ b/environment.yml @@ -0,0 +1,11 @@ +name: gptq-pro-vllm +channels: + - conda-forge +dependencies: + - python=3.13 + - pip + - git + - git-lfs + - ninja + - cmake + - libstdcxx-ng diff --git a/examples/benchmark/perplexity.py b/examples/benchmark/perplexity.py index 4c0978702..7980383ac 100644 --- a/examples/benchmark/perplexity.py +++ b/examples/benchmark/perplexity.py @@ -22,14 +22,26 @@ --model ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5 \ --is_quantized + Use a separate tokenizer path when the quantized checkpoint should reuse the source tokenizer: + python examples/benchmark/perplexity.py \ + --model /path/to/quantized-model \ + --tokenizer /path/to/source-model \ + --is_quantized + Change your dataset: python examples/benchmark/perplexity.py --dataset_path tiny_shakespeare """ parser = argparse.ArgumentParser(description="Calculate Perplexity for a model.") parser.add_argument("--model", type=str, default="ModelCloud/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v2.5", help="Model name.") + parser.add_argument("--tokenizer", type=str, default=None, help="Optional tokenizer path. Defaults to --model.") parser.add_argument("--n_ctx", type=int, default=1024, help="Context size.") - parser.add_argument("--n_batch", type=int, default=1024, help="Batch size.") + parser.add_argument( + "--n_batch", + type=int, + default=1024, + help="Approximate token budget per forward pass used to batch multiple context windows.", + ) parser.add_argument("--dataset_path", type=str, default="wikitext", help="Path to the dataset.") parser.add_argument("--dataset_name", type=str, default=None, help="Name of the dataset.") parser.add_argument("--split", type=str, default="test", help="Dataset split to use.") @@ -45,7 +57,12 @@ parser.add_argument("--backend", choices=['auto', 'marlin', 'exllama_v1', 'exllama_v2', 'triton', 'cuda', 'torch', 'ipex', 'bitblas'], default='auto', help="Whether to use BACKEND format") args = parser.parse_args() - tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=args.use_fast_tokenizer) + tokenizer_path = args.tokenizer or args.model + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + use_fast=args.use_fast_tokenizer, + trust_remote_code=args.trust_remote_code, + ) if not tokenizer.pad_token_id: tokenizer.pad_token_id = tokenizer.eos_token_id diff --git a/format/format.sh b/format/format.sh index 5f9e35118..61d644d32 100755 --- a/format/format.sh +++ b/format/format.sh @@ -6,11 +6,11 @@ cd "$(dirname "$0")" || exit pip install -U ruff==0.14.2 #isort==6.0.1 -ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes +ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../docs/eora ../tests ../setup.py --fix --unsafe-fixes ruff_status=$? # isort is too slow # isort -l 119 -e ../ # Exit with the status code of ruff check -exit $ruff_status \ No newline at end of file +exit $ruff_status diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index 0da3fa820..1765ea3d6 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -7,9 +7,13 @@ # isort: off +from ._banner import get_startup_banner # noqa: E402 +from .utils import _MONKEY_PATCH_LOCK # noqa: E402 from .utils.nogil_patcher import TritonPatch, patch_safetensors_save_file # noqa: E402 + # isort: on + patch_safetensors_save_file() # TODO: waiting for official fix from triton @@ -19,6 +23,81 @@ except Exception: pass + +def _patch_transformers_gptq_device_map_compat(): + """Preserve concrete single-device GPTQ maps for Optimum's later packing step.""" + try: + from functools import wraps + + from transformers.quantizers.quantizer_gptq import GptqHfQuantizer + except Exception: + return + + with _MONKEY_PATCH_LOCK: + original_process = GptqHfQuantizer._process_model_before_weight_loading + if getattr(original_process, "_gptqmodel_device_map_compat", False): + return + + @wraps(original_process) + def _process_model_before_weight_loading_with_device_map(self, model, **kwargs): + """Backfill `hf_device_map` when GPTQ uses a single concrete device.""" + device_map = kwargs.get("device_map") + if ( + isinstance(device_map, dict) + and device_map + and len(set(device_map.values())) == 1 + and not hasattr(model, "hf_device_map") + ): + model.hf_device_map = dict(device_map) + return original_process(self, model, **kwargs) + + _process_model_before_weight_loading_with_device_map._gptqmodel_device_map_compat = True + GptqHfQuantizer._process_model_before_weight_loading = _process_model_before_weight_loading_with_device_map + + +def _patch_transformers_paroquant_quantizer_compat(): + """Teach transformers to treat ParoQuant checkpoints as GPTQ-backed configs. + + Upstream transformers currently rejects `quant_method="paroquant"` before + the GPT-QModel loader path gets a chance to handle the checkpoint. ParoQuant + artifacts reuse GPT-QModel/Optimum loading semantics, so register the method + alongside GPTQ only when upstream has not provided native support yet. + """ + try: + from transformers.quantizers import auto as hf_quant_auto + from transformers.quantizers.quantizer_gptq import GptqHfQuantizer + from transformers.utils.quantization_config import GPTQConfig + except Exception: + return + + with _MONKEY_PATCH_LOCK: + if getattr(hf_quant_auto, "_gptqmodel_paroquant_quantizer_compat", False): + return + + hf_quant_auto.AUTO_QUANTIZATION_CONFIG_MAPPING.setdefault("paroquant", GPTQConfig) + hf_quant_auto.AUTO_QUANTIZER_MAPPING.setdefault("paroquant", GptqHfQuantizer) + hf_quant_auto._gptqmodel_paroquant_quantizer_compat = True + + +def _patch_openvino_gptqmodel_compat(): + """Extend OpenVINO's GPTQ patcher to understand GPTQModel new kernels.""" + try: + from openvino.frontend.pytorch import gptq as ov_gptq + except Exception: + return + + with _MONKEY_PATCH_LOCK: + if getattr(ov_gptq, "_gptqmodel_torch_quant_compat", False): + return + + class MatchAll(list): + def __contains__(self, item): + return True + + ov_gptq.supported_quant_types = MatchAll() + ov_gptq._gptqmodel_torch_quant_compat = True + + from .utils.env import env_flag from .utils.logger import setup_logger from .utils.modelscope import ensure_modelscope_available @@ -30,38 +109,146 @@ from .utils.threadx import DeviceThreadPool -DEVICE_THREAD_POOL = DeviceThreadPool( - inference_mode=True, - warmups={ - "cuda": run_torch_linalg_warmup, - "xpu": run_torch_linalg_warmup, - "mps": run_torch_linalg_warmup, - "cpu": run_torch_linalg_warmup, - }, - workers={ - "cuda:per": 4, - "xpu:per": 1, - "mps": 8, - "cpu": min(12, max(1, (os.cpu_count() or 1) + 1 // 2)), # count + 1, fixed pool size > 1 check when count=3 - "model_loader:cpu": 2, - }, - empty_cache_every_n=512, -) +_DEVICE_THREAD_POOL = None + + +def _build_device_thread_pool(): + return DeviceThreadPool( + inference_mode=True, + warmups={ + "cuda": run_torch_linalg_warmup, + "xpu": run_torch_linalg_warmup, + "mps": run_torch_linalg_warmup, + "cpu": run_torch_linalg_warmup, + }, + workers={ + "cuda:per": 4, + "xpu:per": 1, + "mps": 8, + "cpu": min(12, max(1, (os.cpu_count() or 1) + 1 // 2)), # count + 1, fixed pool size > 1 check when count=3 + "model_loader:cpu": 2, + }, + empty_cache_every_n=512, + ) + +def get_device_thread_pool(): + global _DEVICE_THREAD_POOL + if _DEVICE_THREAD_POOL is None: + _DEVICE_THREAD_POOL = _build_device_thread_pool() + return _DEVICE_THREAD_POOL + + +class _LazyDeviceThreadPoolProxy: + def __init__(self): + object.__setattr__(self, "_overrides", {}) + + def __getattribute__(self, name): + if name in { + "_overrides", + "__class__", + "__dict__", + "__getattribute__", + "__setattr__", + "__delattr__", + "__repr__", + "__dir__", + "_get_pool", + }: + return object.__getattribute__(self, name) + + overrides = object.__getattribute__(self, "_overrides") + if name in overrides: + return overrides[name] + + return getattr(self._get_pool(), name) + + def __setattr__(self, name, value): + if name == "_overrides": + object.__setattr__(self, name, value) + return + self._overrides[name] = value + + def __delattr__(self, name): + overrides = self._overrides + if name in overrides: + del overrides[name] + return + delattr(self._get_pool(), name) + + def __repr__(self): + pool = _DEVICE_THREAD_POOL + if pool is None: + return "" + return repr(pool) + + def __dir__(self): + attrs = set(self._overrides.keys()) + pool = _DEVICE_THREAD_POOL + if pool is not None: + attrs.update(dir(pool)) + return sorted(attrs) + + @staticmethod + def _get_pool(): + return get_device_thread_pool() + + +DEVICE_THREAD_POOL = _LazyDeviceThreadPoolProxy() + +_patch_transformers_gptq_device_map_compat() +_patch_transformers_paroquant_quantizer_compat() +_patch_openvino_gptqmodel_compat() + + +def exllama_set_max_input_length(model, max_input_length: int): + """Resize exllama scratch buffers through the legacy package-root API.""" + from .utils.model import hf_gptqmodel_post_init + + quantize_config = getattr(model, "quantize_config", None) + use_act_order = bool(getattr(quantize_config, "desc_act", False)) + return hf_gptqmodel_post_init( + model, + use_act_order=use_act_order, + quantize_config=quantize_config, + max_input_length=max_input_length, + ) + + +import torch + +from . import extension from .models import GPTQModel, get_best_device -from .models.auto import ASCII_LOGO -from .quantization import BaseQuantizeConfig, GPTAQConfig, QuantizeConfig -from .utils import BACKEND -from .utils.exllama import exllama_set_max_input_length +from .models.auto import ASCII_LOGO, TRANSFORMERS_VERSION +from .quantization import ( + AWQConfig, + BaseQuantizeConfig, + FOEMConfig, + GGUFConfig, + GPTAQConfig, + GPTQConfig, + QuantizeConfig, + RTNConfig, + WeightOnlyConfig, +) +from .utils import BACKEND, PROFILE from .version import __version__ -setup_logger().info("\n%s", ASCII_LOGO) - +setup_logger().info( + "\n%s", + get_startup_banner( + ASCII_LOGO, + gptqmodel_version=__version__, + transformers_version=TRANSFORMERS_VERSION, + torch_version=torch.__version__, + ), +) if ensure_modelscope_available(): try: from modelscope.utils.hf_util.patcher import patch_hub + patch_hub() except Exception as exc: raise ModuleNotFoundError( diff --git a/gptqmodel/_banner.py b/gptqmodel/_banner.py new file mode 100644 index 000000000..991318a64 --- /dev/null +++ b/gptqmodel/_banner.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from importlib.metadata import PackageNotFoundError, version as package_version +from typing import Iterable + +TRITON_PACKAGE_CANDIDATES = ( + "triton", + "triton-windows", + "pytorch_triton_xpu", + "pytorch-triton-xpu", +) + + +def resolve_installed_package_version(package_names: Iterable[str]) -> str | None: + for package_name in package_names: + try: + resolved_version = package_version(package_name) + except PackageNotFoundError: + continue + + if resolved_version: + return resolved_version + + return None + + +def build_startup_banner( + ascii_logo: str, + *, + gptqmodel_version: str, + transformers_version: str, + torch_version: str, + triton_version: str | None = None, +) -> str: + version_rows = [ + ("GPT-QModel", gptqmodel_version), + ("Transformers", transformers_version), + ("Torch", torch_version), + ] + + if triton_version: + version_rows.append(("Triton", triton_version)) + + label_width = max(len(label) for label, _ in version_rows) + formatted_rows = [ + f"{label:<{label_width}} : {value}" for label, value in version_rows + ] + + return "\n".join([ascii_logo.rstrip("\n"), *formatted_rows]) + + +def _get_git_commit(): + import subprocess + try: + hash = subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"] + ).decode().strip() + return f"+{hash}" + except Exception: + return "" + +def get_startup_banner( + ascii_logo: str, + *, + gptqmodel_version: str, + transformers_version: str, + torch_version: str, +) -> str: + return build_startup_banner( + ascii_logo, + gptqmodel_version=f"{gptqmodel_version}{_get_git_commit()}", + transformers_version=transformers_version, + torch_version=torch_version, + triton_version=resolve_installed_package_version(TRITON_PACKAGE_CANDIDATES), + ) diff --git a/gptqmodel/adapter/adapter.py b/gptqmodel/adapter/adapter.py index 7a0545085..fb50f3d4c 100644 --- a/gptqmodel/adapter/adapter.py +++ b/gptqmodel/adapter/adapter.py @@ -3,11 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium - -import pcre as re from dataclasses import dataclass from typing import Dict, List, Optional, Tuple, Union +import pcre import safetensors import torch @@ -23,10 +22,14 @@ class AdapterCache(): + """Caches loaded adapter configs and tensors by source path.""" + cache: Dict[str, Dict[str, Union[LoraConfig, torch.Tensor]]] = {} # first level key is `path`, second level keys [ `config` = LoraConfig, `weights` = Dict[str, Tensors] @classmethod def get(cls, path: str) -> Optional[Tuple[LoraConfig, Dict[str, torch.Tensor]]]: + """Returns cached adapter config and weights for a path, if present.""" + data = cls.cache.get(path) if not data: return None @@ -35,24 +38,36 @@ def get(cls, path: str) -> Optional[Tuple[LoraConfig, Dict[str, torch.Tensor]]]: @classmethod def reset(cls): + """Clears the global adapter cache.""" + log.info("Adapter Cache: Resetting cache") cls.cache = {} @classmethod def add(cls, path: str, config: LoraConfig, weights: Dict[str, torch.Tensor]): + """Stores adapter config and weight tensors under the source path.""" + cls.cache[path] = {"config": config, "weights": weights} @classmethod def remove(cls, path): + """Drops cached adapter state for a path if it exists.""" + cls.cache.pop(path, None) class Adapter(): + """Base interface for runtime adapters applied on top of quantized layers.""" + def __init__(self, rank: int = None, path: str = None): + """Initializes adapter identity and optional source location.""" + self.rank = rank # rank may be zero, when loading, and rank will be re-populated by loading saved LoraConfig file self.path = path.strip() if isinstance(path, str) else path def validate_path(self, local=False): + """Validates that the configured adapter path matches the expected source type.""" + if not self.path or not isinstance(self.path, str): raise ValueError("Adapter: `path` str is required.") @@ -68,30 +83,44 @@ def validate_path(self, local=False): # override me def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + """Applies the adapter contribution to a layer output tensor.""" + pass # override me def post_init(self, weight_key: str, device: torch.device, **kwargs): + """Loads or finalizes adapter tensors for a specific weight key.""" + pass # override me def optimize(self): + """Applies optional backend-specific optimizations after loading.""" + pass # override me @classmethod def name(cls) -> List[str]: + """Returns the serialized adapter type name.""" + pass # override me @classmethod def parameter_keys(cls) -> [str]: # name of tensors/parameters in attribute key name + """Lists tensor attribute names expected on the adapter instance.""" + pass @dataclass class Lora(Adapter): + """LoRA adapter implementation backed by A/B projection matrices.""" + def __init__(self, rank: int, path: str = None, lora_A: torch.Tensor = None, lora_B: torch.Tensor = None): + """Initializes the adapter with optional preloaded LoRA matrices.""" + super().__init__(rank, path) self.lora_A = lora_A @@ -99,24 +128,35 @@ def __init__(self, rank: int, path: str = None, lora_A: torch.Tensor = None, lor @classmethod def name(cls) -> str: + """Returns the canonical adapter type used in serialized configs.""" + return "lora" @classmethod def parameter_keys(cls) -> List[str]: + """Lists the tensor attributes that store the LoRA projection weights.""" + return ["lora_A", "lora_B"] def optimize(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): + """Reserved hook for compiling the adapter path when enabled.""" + pass #logger.info("Adapter: optimize (compile)") #self.apply = torch_compile(self.apply, backend=backend, mode=mode, fullgraph=fullgraph) def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: + """Adds the LoRA update to the kernel output, reshaping batched outputs when needed.""" + # original code # out = out + ((x @ self.lora_A) @ self.lora_B) # native quantized model/eora is float16 for gptq but for training, we may load the model as bfloat16 for accuracy - if x.dtype != self.lora_A.dtype: - log.info.once(f"Adapter: Lora A/B auto changed from `{self.lora_A.dtype}` to `{x.dtype}` to match forward input dtype.") + if x.dtype != self.lora_A.dtype or x.device != self.lora_A.device: + log.info.once( + f"Adapter: Lora A/B auto changed from `{self.lora_A.dtype}` on `{self.lora_A.device}` " + f"to `{x.dtype}` on `{x.device}` to match forward input." + ) self.lora_A = self.lora_A.to(device=x.device, dtype=x.dtype) self.lora_B = self.lora_B.to(device=x.device, dtype=x.dtype) @@ -133,6 +173,8 @@ def apply(self, x: torch.Tensor, out: torch.Tensor) -> torch.Tensor: return out.add_((x @ self.lora_A) @ self.lora_B) def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=None, lora_B: torch.Tensor=None): + """Loads, caches, and materializes LoRA tensors for the target module.""" + # self.register_buffer("lora_A", lora_A) # self.register_buffer("lora_B", lora_B) @@ -174,7 +216,7 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N lora_A_weight_key = f"{weight_key}.lora_A.weight" lora_B_weight_key = f"{weight_key}.lora_B.weight" - + pop_keys = [] for k, v in lora_weights.items(): if k.endswith(lora_A_weight_key): @@ -219,6 +261,8 @@ def post_init(self, weight_key: str, device:torch.device, lora_A: torch.Tensor=N #print(f"Adapter: lora_B {lora_B.shape}: `{lora_B}`") def dynamic_rank_override(self, lora_cfg: LoraConfig, weight_key: str) -> bool: + """Overrides the adapter rank when the config defines a matching rank pattern.""" + assert lora_cfg.rank_pattern is not None and weight_key is not None if lora_cfg.rank_pattern: for k, v in lora_cfg.rank_pattern.items(): @@ -226,7 +270,7 @@ def dynamic_rank_override(self, lora_cfg: LoraConfig, weight_key: str) -> bool: k = k.lower() assert v > 0 # check for invalid rank range # first do string full match, then suffix match, then regex match - if weight_key == k or k.endswith(weight_key) or re.match(k, weight_key): + if weight_key == k or k.endswith(weight_key) or pcre.compile(k).match(weight_key): self.rank = v log.info(f"Adapter: Base Lora `rank` = `{self.rank}` has been overridden by `{k}` due to dynamic `LoraConfig.rank_pattern` control.") return True @@ -236,6 +280,8 @@ def dynamic_rank_override(self, lora_cfg: LoraConfig, weight_key: str) -> bool: def to_dict(self): + """Serializes the minimal adapter descriptor used by GPT-QModel.""" + return { "name": self.name(), "path": self.path, @@ -246,6 +292,8 @@ def to_dict(self): # accept both Adapter cls instance or Dict() def normalize_adapter(adapter: Union[Dict, Adapter]): + """Normalizes serialized adapter metadata into a concrete adapter instance.""" + if adapter is None: return None diff --git a/gptqmodel/adapter/peft.py b/gptqmodel/adapter/peft.py index d64438f03..5f4fa0c46 100644 --- a/gptqmodel/adapter/peft.py +++ b/gptqmodel/adapter/peft.py @@ -326,6 +326,8 @@ def to_dict(self): return kv def save_pretrained(self, save_dir: str): + """Writes the LoRA config in PEFT-compatible JSON form.""" + from ..adapter.adapter import HF_ADAPTER_CONFIG_FILE_NAME log.info(f"Adapter: Saving EoRA/Lora config to -> `{save_dir}`") @@ -338,6 +340,8 @@ def save_pretrained(self, save_dir: str): f.write(json_str) def __post_init__(self): + """Normalizes list inputs and validates the subset of PEFT options supported here.""" + self.peft_type = "LORA" self.target_modules = ( set(self.target_modules) if isinstance(self.target_modules, list) else self.target_modules @@ -390,6 +394,8 @@ def __post_init__(self): @classmethod def from_pretrained(cls, path: str, filename: str): + """Loads a saved LoRA config and filters out unsupported empty fields.""" + resolved_path = resolve_path(path=path, filename=filename) with open(resolved_path, "r") as file: config_dict = json.load(file) @@ -398,4 +404,4 @@ def from_pretrained(cls, path: str, filename: str): valid_fields = {field.name for field in fields(cls)} config_dict = {k: v for k, v in config_dict.items() if k in valid_fields and v not in [None, "", [], {}, ()]} - return cls(**config_dict) \ No newline at end of file + return cls(**config_dict) diff --git a/gptqmodel/adapter/remote.py b/gptqmodel/adapter/remote.py index 1de578387..a3913a984 100644 --- a/gptqmodel/adapter/remote.py +++ b/gptqmodel/adapter/remote.py @@ -6,11 +6,14 @@ import os from urllib.parse import urlparse +from ..utils.hub import hf_hub_download from ..utils.logger import setup_logger log = setup_logger() def parse_url(url: str): + """Extracts Hugging Face repo, revision, and filename from a blob URL when possible.""" + parsed_url = urlparse(url) if parsed_url.netloc.endswith("huggingface.co") or parsed_url.netloc.endswith("hf.co"): @@ -27,6 +30,8 @@ def parse_url(url: str): return [] def resolve_path(path: str, filename: str) -> str: # return a valid file path to read + """Resolves an adapter file from a local directory, HF URL, or HF repo id.""" + if os.path.isdir(path): resolved_path = f"{path.removesuffix('/')}/{filename}" log.info(f"Resolver: Local path: `{resolved_path}`") @@ -35,8 +40,6 @@ def resolve_path(path: str, filename: str) -> str: # return a valid file path to return resolved_path elif path.startswith("http"): - from huggingface_hub import hf_hub_download - result = parse_url(path) if len(result) == 3: log.info( @@ -62,8 +65,6 @@ def resolve_path(path: str, filename: str) -> str: # return a valid file path to path = f"{path_split[0]}/{path_split[1]}" subfolder = "/".join(path_split[2:]) - from huggingface_hub import hf_hub_download - # _ = HfApi().list_repo_files(path) resolved_path = hf_hub_download(repo_id=path, filename=filename, subfolder=subfolder) @@ -71,4 +72,4 @@ def resolve_path(path: str, filename: str) -> str: # return a valid file path to # print(f"Adapter tensors loaded from `{self.path}`") else: raise ValueError( - f"Resolver: We only support local file path or HF repo id; actual = path: `{path}`, filename = `{filename}`") \ No newline at end of file + f"Resolver: We only support local file path or HF repo id; actual = path: `{path}`, filename = `{filename}`") diff --git a/gptqmodel/eora/eora.py b/gptqmodel/eora/eora.py index 9fd9f0bf8..96ed329a1 100644 --- a/gptqmodel/eora/eora.py +++ b/gptqmodel/eora/eora.py @@ -20,6 +20,7 @@ from ..utils.logger import setup_logger from ..utils.rocm import IS_ROCM +from ..utils.torch import TORCH_GTE_210 log = setup_logger() @@ -91,7 +92,7 @@ def eora_compute_lora( # save this later for SVD raw_scaling_diag_matrix = eigen_scaling_diag_matrix.to(device=device, dtype=torch.float64) - if IS_ROCM: + if IS_ROCM and not TORCH_GTE_210: # hip cannot resolve linalg ops original_backend = torch.backends.cuda.preferred_linalg_library() torch.backends.cuda.preferred_linalg_library(backend="magma") @@ -131,7 +132,7 @@ def eora_compute_lora( del truc_s, truc_u, truc_v, truc_sigma, sqrtS # revert linalg backend - if IS_ROCM: + if IS_ROCM and not TORCH_GTE_210: torch.backends.cuda.preferred_linalg_library(original_backend) return A, B diff --git a/gptqmodel/exllamav3/CREDITS.md b/gptqmodel/exllamav3/CREDITS.md new file mode 100644 index 000000000..b6be689c6 --- /dev/null +++ b/gptqmodel/exllamav3/CREDITS.md @@ -0,0 +1,13 @@ +This directory vendors the EXL3 kernel and quantizer pieces adapted from `turboderp-org/exllamav3`. + +Primary upstream source: +- https://github.com/turboderp-org/exllamav3 + +Ported components in this repo: +- `gptqmodel/exllamav3/ext.py` +- `gptqmodel/exllamav3/modules/quant/exl3.py` +- `gptqmodel/exllamav3/modules/quant/exl3_lib/quantize.py` +- `gptqmodel/exllamav3/util/*` +- `gptqmodel_ext/exllamav3/*` + +The code remains self-contained inside GPT-QModel and does not depend on the external `exllamav3` Python package. diff --git a/gptqmodel/exllamav3/__init__.py b/gptqmodel/exllamav3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/exllamav3/ext.py b/gptqmodel/exllamav3/ext.py new file mode 100644 index 000000000..7cebe44c9 --- /dev/null +++ b/gptqmodel/exllamav3/ext.py @@ -0,0 +1,338 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Portions of this file are adapted from turboderp-org/exllamav3. +# Credits: TurboDerp / ExLlamaV3 contributors. + +from __future__ import annotations + +import os +import sys +from pathlib import Path +from typing import Optional + +import torch + +from ..utils.cpp import ( + TorchOpsJitExtension, + cuda_include_paths_with_fallback, + default_jit_cflags, + default_jit_cuda_cflags, + default_torch_ops_build_root, +) +from .util.arch_list import maybe_set_arch_list_env + + +extension_name = "gptqmodel_exllamav3_ops" +verbose = str(os.environ.get("GPTQMODEL_EXT_VERBOSE", "")).strip().lower() not in {"", "0", "false", "off", "no"} +ext_debug = str(os.environ.get("GPTQMODEL_EXT_DEBUG", "")).strip().lower() in {"1", "true", "on", "yes"} +windows = os.name == "nt" + + +def _find_msvc() -> str | None: + program_files_x64 = os.environ["ProgramW6432"] + program_files_x86 = os.environ["ProgramFiles(x86)"] + msvc_dirs = [ + a + "\\Microsoft Visual Studio\\" + b + "\\" + c + "\\VC\\Tools\\MSVC\\" + for b in ["2022", "2019", "2017"] + for a in [program_files_x64, program_files_x86] + for c in ["BuildTools", "Community", "Professional", "Enterprise", "Preview"] + ] + + for msvc_dir in msvc_dirs: + if not os.path.exists(msvc_dir): + continue + versions = sorted(os.listdir(msvc_dir), reverse=True) + for version in versions: + compiler_dir = msvc_dir + version + "\\bin\\Hostx64\\x64" + if os.path.exists(compiler_dir) and os.path.exists(compiler_dir + "\\cl.exe"): + return compiler_dir + return None + + +def _ensure_windows_compiler() -> None: + if not windows: + return + + import subprocess + + try: + subprocess.check_output(["where", "/Q", "cl"]) + except subprocess.CalledProcessError: + cl_path = _find_msvc() + if cl_path: + if verbose: + print(" -- Injected compiler path:", cl_path) + os.environ["path"] += ";" + cl_path + else: + print(" !! Unable to find cl.exe; EXL3 compilation will probably fail", file=sys.stderr) + + +def _source_root() -> Path: + return Path(__file__).resolve().parents[2] / "gptqmodel_ext" / "exllamav3" + + +def _source_files() -> list[str]: + sources_dir = _source_root() + source_files = [ + "bindings.cpp", + "hadamard.cpp", + "hgemm.cu", + "libtorch/linear.cpp", + "quant/comp_units/exl3_comp_unit_1.cu", + "quant/comp_units/exl3_comp_unit_2.cu", + "quant/comp_units/exl3_comp_unit_3.cu", + "quant/comp_units/exl3_comp_unit_4.cu", + "quant/comp_units/exl3_comp_unit_5.cu", + "quant/comp_units/exl3_comp_unit_6.cu", + "quant/comp_units/exl3_comp_unit_7.cu", + "quant/comp_units/exl3_comp_unit_8.cu", + "quant/exl3_devctx.cu", + "quant/exl3_gemm.cu", + "quant/exl3_kernel_map.cu", + "quant/hadamard.cu", + "quant/pack.cu", + "quant/quantize.cu", + "quant/reconstruct.cu", + "quant/util.cu", + ] + return [str((sources_dir / path).resolve()) for path in source_files] + + +def _exllamav3_required_cuda_headers() -> tuple[str, ...]: + return ("cusparse.h",) + + +def _exllamav3_include_paths() -> list[str]: + return cuda_include_paths_with_fallback( + [str(_source_root())], + required_header_names=_exllamav3_required_cuda_headers(), + ) + + +def _legacy_build_root() -> Optional[Path]: + build_root = os.environ.get("GPTQMODEL_EXT_BUILD") + if not build_root: + return None + return Path(build_root) / extension_name + + +def _default_build_root() -> Path: + legacy_root = _legacy_build_root() + if legacy_root is not None: + return legacy_root + return default_torch_ops_build_root("exllamav3") + + +def _extra_cflags() -> list[str]: + if windows: + flags = ["/O2", "/std:c++17"] + else: + flags = default_jit_cflags(opt_level="O2") + + if ext_debug: + if windows: + flags.append("/Zi") + else: + flags.extend(["-ftime-report", "-DTORCH_USE_CUDA_DSA"]) + return flags + + +def _extra_cuda_cflags() -> list[str]: + flags = default_jit_cuda_cflags( + opt_level="O2", + include_abi=not windows, + include_lineinfo=True, + ) + if torch.version.hip: + flags.append("-DHIPBLAS_USE_HIP_HALF") + return flags + + +def _extra_ldflags() -> list[str]: + if not windows: + return [] + flags = ["cublas.lib"] + if sys.base_prefix != sys.prefix: + flags.append(f"/LIBPATH:{os.path.join(sys.base_prefix, 'libs')}") + return flags + + +def _prepare_build_env() -> None: + _ensure_windows_compiler() + maybe_set_arch_list_env() + + +# Shared singleton so EXL3 quantization and inference both reuse the same +# torch.ops cache and first-use build policy. +_EXLLAMAV3_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name=extension_name, + namespace="gptqmodel_exllamav3", + required_ops=( + "had_paley", + "had_paley2", + "quantize_tiles", + "pack_trellis", + "unpack_trellis", + "pack_signs", + "reconstruct", + "had_r_128", + "hgemm", + "bc_linear_exl3_run", + ), + sources=_source_files, + build_root_env="GPTQMODEL_EXLLAMAV3_BUILD_ROOT", + default_build_root=_default_build_root, + display_name="ExLlamaV3", + extra_cflags=_extra_cflags, + extra_cuda_cflags=_extra_cuda_cflags, + extra_include_paths=_exllamav3_include_paths, + extra_ldflags=_extra_ldflags, + force_rebuild_env="GPTQMODEL_EXLLAMAV3_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +def _extension_api(): + from gptqmodel import extension as extension_api + + return extension_api + + +def exllamav3_runtime_available() -> bool: + _prepare_build_env() + return _extension_api().is_available("exllamav3") + + +def exllamav3_runtime_error() -> str: + extension_api = _extension_api() + _prepare_build_env() + if extension_api.is_available("exllamav3"): + return "" + return ( + extension_api.error("exllamav3") + or "ExLlamaV3 CUDA runtime unavailable." + ) + + +def prewarm_exllamav3_extension() -> bool: + _prepare_build_env() + return _extension_api().load(name="exllamav3")["exllamav3"] + + +def _runtime_op(name: str): + _prepare_build_env() + return _extension_api().op("exllamav3", name) + + +class _BCLinearEXL3: + """Preserve the old binding surface while dispatching through torch.ops.""" + + def __init__( + self, + trellis: torch.Tensor, + suh: torch.Tensor, + svh: torch.Tensor, + K: int, + bias: Optional[torch.Tensor], + mcg: bool, + mul1: bool, + xh: torch.Tensor, + ): + if not exllamav3_runtime_available(): + raise ModuleNotFoundError("ExLlamaV3 torch.ops kernels are not properly installed. Error: " + exllamav3_runtime_error()) + self.trellis = trellis + self.suh = suh + self.svh = svh + self.K = int(K) + self.bias = bias + self.mcg = bool(mcg) + self.mul1 = bool(mul1) + self.xh = xh + + def run(self, x: torch.Tensor, y: torch.Tensor) -> None: + _runtime_op("bc_linear_exl3_run")( + self.trellis, + self.suh, + self.svh, + self.K, + self.bias, + self.mcg, + self.mul1, + self.xh, + x, + y, + ) + + +class _ExllamaV3TorchOpsFacade: + """Facade that mirrors the old pybind module API over torch.ops.""" + + BC_LinearEXL3 = _BCLinearEXL3 + + def had_paley(self, h: torch.Tensor) -> None: + _runtime_op("had_paley")(h) + + def had_paley2(self, h: torch.Tensor) -> None: + _runtime_op("had_paley2")(h) + + def quantize_tiles( + self, + input_tiles: torch.Tensor, + output_tiles: torch.Tensor, + output_indices: torch.Tensor, + temp_costs: torch.Tensor, + temp_edges: torch.Tensor, + K: int, + mcg: bool, + mul1: bool, + ) -> None: + _runtime_op("quantize_tiles")( + input_tiles, + output_tiles, + output_indices, + temp_costs, + temp_edges, + int(K), + bool(mcg), + bool(mul1), + ) + + def pack_trellis(self, packed: torch.Tensor, unpacked: torch.Tensor, K: int) -> None: + _runtime_op("pack_trellis")(packed, unpacked, int(K)) + + def unpack_trellis(self, unpacked: torch.Tensor, packed: torch.Tensor, K: int) -> None: + _runtime_op("unpack_trellis")(unpacked, packed, int(K)) + + def pack_signs(self, packed: torch.Tensor, unpacked: torch.Tensor) -> None: + _runtime_op("pack_signs")(packed, unpacked) + + def reconstruct(self, unpacked: torch.Tensor, packed: torch.Tensor, K: int, mcg: bool, mul1: bool) -> None: + _runtime_op("reconstruct")(unpacked, packed, int(K), bool(mcg), bool(mul1)) + + def had_r_128( + self, + input: torch.Tensor, + output: torch.Tensor, + pre_scale: Optional[torch.Tensor], + post_scale: Optional[torch.Tensor], + scale: float, + ) -> None: + _runtime_op("had_r_128")(input, output, pre_scale, post_scale, float(scale)) + + def hgemm(self, a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> None: + _runtime_op("hgemm")(a, b, c) + + +exllamav3_ext = _ExllamaV3TorchOpsFacade() + + +__all__ = [ + "exllamav3_ext", + "exllamav3_runtime_available", + "exllamav3_runtime_error", + "prewarm_exllamav3_extension", +] diff --git a/gptqmodel/exllamav3/modules/__init__.py b/gptqmodel/exllamav3/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/exllamav3/modules/quant/__init__.py b/gptqmodel/exllamav3/modules/quant/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/exllamav3/modules/quant/exl3.py b/gptqmodel/exllamav3/modules/quant/exl3.py new file mode 100644 index 000000000..6a74a295b --- /dev/null +++ b/gptqmodel/exllamav3/modules/quant/exl3.py @@ -0,0 +1,265 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Portions of this file are adapted from turboderp-org/exllamav3. +# Credits: TurboDerp / ExLlamaV3 contributors. + +from __future__ import annotations + +import torch + +from .exl3_lib.quantize import preapply_had_l, preapply_had_r, had_k, had_n +from ...ext import exllamav3_ext as ext +from ...util.tensor import g_tensor_cache + +class LinearEXL3: + + quant_type: str = "exl3" + + def __init__( + self, + config: object | None, + in_features: int, + out_features: int, + scale: torch.Tensor | None = None, + su: torch.Tensor | None = None, + sv: torch.Tensor | None = None, + suh: torch.Tensor | None = None, + svh: torch.Tensor | None = None, + trellis: torch.Tensor | None = None, + mcg: torch.Tensor | None = None, + mul1: torch.Tensor | None = None, + bias: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + transformers_fix: bool = False, + key: str | None = None + ): + assert scale is None, "scale is no longer used" + assert su is not None or suh is not None, "either su (packed) or suh (unpacked) is required" + assert sv is not None or svh is not None, "either sv (packed) or svh (unpacked) is required" + assert trellis is not None, "trellis is required" + if su is not None: assert su.dtype == torch.int16, "su is wrong datatype" + if sv is not None: assert sv.dtype == torch.int16, "sv is wrong datatype" + if suh is not None: assert suh.dtype == torch.half, "suh is wrong datatype" + if svh is not None: assert svh.dtype == torch.half, "svh is wrong datatype" + assert trellis.dtype == torch.int16, "trellis is wrong datatype" + assert len(trellis.shape) == 3, "trellis must have dim = 3" + + if bias is not None and bias.dtype == torch.float: bias = bias.to(torch.half) + + self.transformers_fix = transformers_fix + self.key = key + + # self.scale = scale.item() + self.su = None + self.sv = None + self.suh = suh if suh is not None else self.unpack_bf(su) + self.svh = svh if svh is not None else self.unpack_bf(sv) + self.trellis = trellis + self.K = trellis.shape[-1] // 16 + self.in_features = in_features + self.out_features = out_features + self.bias = bias + self.swap_device = None + self.out_dtype = out_dtype + + self.mcg_tensor = mcg + self.mul1_tensor = mul1 + self.mcg = self.mcg_tensor is not None + self.mul1 = self.mul1_tensor is not None + + self.bsz1_xh_args = (self.trellis.device, (1, self.in_features), self.out_dtype) + self.bc = ext.BC_LinearEXL3( + self.trellis, + self.suh, + self.svh, + self.K, + self.bias, + self.mcg, + self.mul1, + g_tensor_cache.get(*self.bsz1_xh_args) + ) + + + def unload(self): + g_tensor_cache.drop(*self.bsz1_xh_args) + + + def get_tensors(self, key: str): + return { + f"{key}.{subkey}": tensor.contiguous() + for subkey, tensor in [ + ("su", self.su), + ("sv", self.sv), + ("suh", self.suh), + ("svh", self.svh), + ("trellis", self.trellis), + ("bias", self.bias), + ("mcg", self.mcg_tensor), + ("mul1", self.mul1_tensor), + ] if tensor is not None + } + + + def forward( + self, + x: torch.Tensor, + params: dict, + out_dtype: torch.dtype | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + + if "ovr" in params: + ovr = params["ovr"] + if self.key in ovr and ovr[self.key].inner is not self: + return ovr[self.key].forward(x, params, out_dtype) + + bsz = x.numel() // x.shape[-1] + torch_mode = params.get("reconstruct", bsz > 32) + + out_shape = x.shape[:-1] + (self.out_features,) + x = x.view(-1, self.in_features) + y = torch.empty(out_shape, dtype = out_dtype or self.out_dtype or torch.half, device = x.device) + + if torch_mode: + y_ = y.view(x.shape[0], self.out_features) + xh = torch.empty_like(x) + ext.had_r_128(x, xh, self.suh, None, 1.0) + w = self.get_inner_weight_tensor() + ext.hgemm(xh, w, y_) + ext.had_r_128(y_, y_, None, self.svh, 1.0) + if self.bias is not None: + y += self.bias + y = y.view(out_shape) + + else: + self.bc.run(x, y) + + return y + + + def unpack_bf(self, bitfield: torch.Tensor): + # For some reason this operation causes a GPU assert on Transformers. Running on CPU seems to fix it + device = bitfield.device + if self.transformers_fix: + bitfield = bitfield.cpu() + + # TODO: Maybe custom kernel for this. Only used for full reconstruct and loading old models, not during inference + bitfield = bitfield.view(torch.uint16).to(torch.int) + masks = (1 << torch.arange(16)).to(bitfield.device) + expanded = (bitfield.unsqueeze(-1) & masks) > 0 + expanded = expanded.flatten() + expanded = torch.where(expanded, torch.tensor(-1.0, dtype = torch.float16), torch.tensor(1.0, dtype = torch.float16)) + return expanded.contiguous().to(device) + + + def get_weight_tensor(self): + # suh = self.unpack_bf(self.su).unsqueeze(1) + suh = self.unpack_bf(self.su).unsqueeze(1) if self.su else self.suh.unsqueeze(1) + svh = self.unpack_bf(self.sv).unsqueeze(0) if self.sv else self.svh.unsqueeze(0) + w = self.get_inner_weight_tensor() + w = preapply_had_l(w, had_k) + w *= suh + w = preapply_had_r(w, had_n) + w *= svh + # w *= self.scale + return w + + + def get_inner_weight_tensor(self): + w = torch.empty((self.in_features, self.out_features), dtype = torch.half, device = self.trellis.device) + ext.reconstruct(w, self.trellis, self.K, self.mcg, self.mul1) + return w + + + def get_bias_tensor(self) -> torch.Tensor | None: + return self.bias + + + # Swap tensors to CPU (to free some space while quantizing) + def swap_cpu(self): + if self.swap_device is not None: + return + self.swap_device = self.trellis.device + if self.su is not None: self.su = self.su.cpu() + if self.sv is not None: self.sv = self.sv.cpu() + if self.suh is not None: self.suh = self.suh.cpu() + if self.svh is not None: self.svh = self.svh.cpu() + if self.trellis is not None: self.trellis = self.trellis.cpu() + if self.bias is not None: self.bias = self.bias.cpu() + + + def unswap_cpu(self): + if self.swap_device is None: + return + if self.su is not None: self.su = self.su.to(self.swap_device) + if self.sv is not None: self.sv = self.sv.to(self.swap_device) + if self.suh is not None: self.suh = self.suh.to(self.swap_device) + if self.svh is not None: self.svh = self.svh.to(self.swap_device) + if self.trellis is not None: self.trellis = self.trellis.to(self.swap_device) + if self.bias is not None: self.bias = self.bias.to(self.swap_device) + self.swap_device = None + + + def tp_export(self, plan, producer): + return { + "cls": LinearEXL3, + "in_features": self.in_features, + "out_features": self.out_features, + "suh": producer.send(self.suh), + "svh": producer.send(self.svh), + "trellis": producer.send(self.trellis), + "bias": producer.send(self.bias), + "mcg": producer.send(self.mcg_tensor), + "mul1": producer.send(self.mul1_tensor), + "out_dtype": self.out_dtype, + } + + + @staticmethod + def tp_import_split(local_context, exported, plan, split): + consumer = local_context["consumer"] + id_suh = exported["suh"] + id_svh = exported["svh"] + id_trellis = exported["trellis"] + id_bias = exported["bias"] + mcg = consumer.recv(exported["mcg"], cuda = True) + mul1 = consumer.recv(exported["mul1"], cuda = True) + + if split is not None: + split_out, first, last = split + else: + split_out, first, last = True, 0, exported["out_features"] + + if split_out: + suh = consumer.recv(id_suh, cuda = True) + svh = consumer.recv(id_svh, cuda = True, slice_dim = 0, first = first, last = last) + trellis = consumer.recv(id_trellis, cuda = True, slice_dim = 1, first = first // 16, last = last // 16) + bias = consumer.recv(id_bias, cuda = True, slice_dim = 0, first = first, last = last) + in_features = exported["in_features"] + out_features = last - first + else: + suh = consumer.recv(id_suh, cuda = True, slice_dim = 0, first = first, last = last) + svh = consumer.recv(id_svh, cuda = True) + trellis = consumer.recv(id_trellis, cuda = True, slice_dim = 0, first = first // 16, last = last // 16) + bias = consumer.recv(id_bias, cuda = True) + in_features = last - first + out_features = exported["out_features"] + + module = LinearEXL3( + config = None, + in_features = in_features, + out_features = out_features, + scale = None, + su = None, + sv = None, + suh = suh, + svh = svh, + trellis = trellis, + mcg = mcg, + mul1 = mul1, + bias = bias, + out_dtype = exported["out_dtype"], + ) + return module diff --git a/gptqmodel/exllamav3/modules/quant/exl3_lib/__init__.py b/gptqmodel/exllamav3/modules/quant/exl3_lib/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/gptqmodel/exllamav3/modules/quant/exl3_lib/quantize.py b/gptqmodel/exllamav3/modules/quant/exl3_lib/quantize.py new file mode 100644 index 000000000..4aa017e64 --- /dev/null +++ b/gptqmodel/exllamav3/modules/quant/exl3_lib/quantize.py @@ -0,0 +1,1070 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Portions of this file are adapted from turboderp-org/exllamav3. +# Credits: TurboDerp / ExLlamaV3 contributors. + +import math +import threading +from functools import lru_cache + +import torch + +from ....ext import exllamav3_ext as ext +from ....util.progress import ProgressBar +from ....util.memory import free_mem +from ....util.hadamard import get_hadamard_dt +from ....util.tensor import save_tensor_image + +# Constant +had_k, had_n = 128, 128 +codebook_scale = 1.24371088 + +codebook_mcg_mult = 0xCBAC1FED +codebook_mul1_mult = 0x83DCD12D + +@lru_cache +def tensor_core_perm(device): + perm_a = [0] * 256 + for t in range(32): + r0 = (t % 4) * 2 + r1 = r0 + 1 + r2 = r0 + 8 + r3 = r0 + 9 + c0 = t // 4 + c1 = c0 + 8 + perm_a[t * 8 + 0] = r0 * 16 + c0 + perm_a[t * 8 + 1] = r1 * 16 + c0 + perm_a[t * 8 + 2] = r2 * 16 + c0 + perm_a[t * 8 + 3] = r3 * 16 + c0 + perm_a[t * 8 + 4] = r0 * 16 + c1 + perm_a[t * 8 + 5] = r1 * 16 + c1 + perm_a[t * 8 + 6] = r2 * 16 + c1 + perm_a[t * 8 + 7] = r3 * 16 + c1 + return torch.tensor(perm_a, dtype = torch.int, device = device) + + +@lru_cache +def tensor_core_perm_i(device): + return torch.argsort(tensor_core_perm(device)) + + +@lru_cache +def get_temp_buffers(device, K: int): + max_batch_size = 256 + temp_costs = torch.zeros((max_batch_size, 2, 65536 >> K), dtype = torch.half, device = device) + temp_edges = torch.zeros((max_batch_size, 256, 65536 >> K), dtype = torch.short, device = device) + return temp_costs, temp_edges + + +def quantize_tiles(tiles, quant_args: dict): + tiles = tiles.contiguous() + assert tiles.shape[1] == 256 + assert tiles.dtype == torch.float + + K = quant_args["K"] + mcg = "mcg" in quant_args + mul1 = "mul1" in quant_args + quantized_tiles = torch.zeros_like(tiles) + quantized_idx = torch.zeros_like(tiles, dtype = torch.short) + temp_costs, temp_edges = get_temp_buffers(tiles.device, K) + ext.quantize_tiles( + tiles, + quantized_tiles, + quantized_idx, + temp_costs, + temp_edges, + K, + mcg, + mul1, + ) + return quantized_tiles, quantized_idx + + +@lru_cache +def get_quant_stream(device): + return torch.cuda.Stream(device = device) + + +pinned_tiles: torch.Tensor | None = None +pinned_q_tiles: torch.Tensor | None = None +pinned_q_idx: torch.Tensor | None = None +def get_pinned(num_tiles: int): + global pinned_tiles, pinned_q_tiles, pinned_q_idx + if pinned_tiles is None or pinned_tiles.shape[0] < num_tiles: + pinned_tiles = torch.empty((num_tiles, 256), device = "cpu", dtype = torch.float, pin_memory = True) + pinned_q_tiles = torch.empty((num_tiles, 256), device = "cpu", dtype = torch.float, pin_memory = True) + pinned_q_idx = torch.empty((num_tiles, 256), device = "cpu", dtype = torch.int16, pin_memory = True) + return pinned_tiles[:num_tiles, :], pinned_q_tiles[:num_tiles, :], pinned_q_idx[:num_tiles, :] + + +def quantize_tiles_multigpu(tiles, quant_args: dict): + devices = quant_args["devices"] + if len(devices) == 1: + return quantize_tiles(tiles, quant_args) + + # Get pinned buffers + pin_tiles, pin_q_tiles, pin_q_idx = get_pinned(tiles.shape[0]) + + # Copy input tiles to pinned memory. Input is always on the first device in the split + copy_input_event = torch.cuda.Event(blocking = False) + main_stream = get_quant_stream(devices[0]) + with torch.cuda.stream(main_stream): + tiles = tiles.contiguous() + pin_tiles.copy_(tiles, non_blocking = True) + copy_input_event.record(main_stream) + + # Create split slices for input tiles, output tiles and output indices + ratios = quant_args.get("device_ratios") + if ratios: + s = sum(ratios) + split_sizes = [tiles.shape[0] * r / s for r in ratios] + split_sizes = [round(s / 16) * 16 for s in split_sizes] + split_sizes[-1] += tiles.shape[0] - sum(split_sizes) + else: + split_sizes = [tiles.shape[0] // len(devices)] * len(devices) + split_sizes[-1] += tiles.shape[0] - sum(split_sizes) + + # Account for negative splits (edge case if too many GPUs and/or tensor too small) + for i in range(len(split_sizes) - 2, -1, -1): + if split_sizes[i + 1] < 0: + split_sizes[i] += split_sizes[i + 1] + split_sizes[i + 1] = 0 + + pin_split_tiles = torch.split(pin_tiles, split_sizes) + pin_split_q_tiles = torch.split(pin_q_tiles, split_sizes) + pin_split_q_idx = torch.split(pin_q_idx, split_sizes) + + slice_done_events = [] + for i, device in enumerate(devices): + + stream = get_quant_stream(device) + with torch.cuda.stream(stream): + + # Wait for input in host memory + if i > 0: + stream.wait_event(copy_input_event) + + if split_sizes[i] > 0: + + # Asynchronously copy the slice from the pinned buffer to device memory + dev_tiles = pin_split_tiles[i].to(device, non_blocking = True) + + # Preallocate output tensors on the device. + dev_q_tiles = torch.empty_like(dev_tiles, device = device) + dev_q_idx = torch.empty_like(dev_tiles, dtype = torch.short, device = device) + + # Work buffers + K = quant_args["K"] + mcg = "mcg" in quant_args + mul1 = "mul1" in quant_args + temp_costs, temp_edges = get_temp_buffers(device, K) + + ext.quantize_tiles( + dev_tiles, + dev_q_tiles, + dev_q_idx, + temp_costs, + temp_edges, + K, + mcg, + mul1 + ) + + # Async copy back to pinned memory + pin_split_q_tiles[i].copy_(dev_q_tiles, non_blocking = True) + pin_split_q_idx[i].copy_(dev_q_idx, non_blocking = True) + + # Finished slice + evt = torch.cuda.Event(blocking = False) + slice_done_events.append(evt) + evt.record(stream) + + # Copy pinned buffers to original device + with torch.cuda.stream(main_stream): + for evt in slice_done_events: + main_stream.wait_event(evt) + q_tiles = torch.empty_like(tiles, device = devices[0]) + q_idx = torch.empty_like(tiles, dtype = torch.short, device = devices[0]) + q_tiles.copy_(pin_q_tiles, non_blocking = True) + q_idx.copy_(pin_q_idx, non_blocking = True) + + return q_tiles, q_idx + + +def quantize_tiles_multigpu_sync(tiles, quant_args: dict): + devices = quant_args["devices"] + if len(devices) == 1: + return quantize_tiles(tiles, quant_args) + + tiles = tiles.contiguous() + + split_sizes = [tiles.shape[0] // len(devices)] * len(devices) + split_sizes[-1] += tiles.shape[0] - sum(split_sizes) + split_tiles = torch.split(tiles, split_sizes) + tiles_per_device = [chunk.to(device) for chunk, device in zip(split_tiles, devices)] + torch.cuda.synchronize() + + q_tiles_per_device = [] + q_idx_per_device = [] + for dev_tiles, device in zip(tiles_per_device, devices): + with torch.cuda.stream(get_quant_stream(device)): + dev_q_tiles, dev_q_idx = quantize_tiles(dev_tiles, quant_args) + q_tiles_per_device.append(dev_q_tiles) + q_idx_per_device.append(dev_q_idx) + + for device in devices: + torch.cuda.synchronize(device) + + q_tiles_per_device = [x.to(devices[0]) for x in q_tiles_per_device] + q_idx_per_device = [x.to(devices[0]) for x in q_idx_per_device] + quantized_tiles = torch.cat(q_tiles_per_device, dim = 0) + quantized_idx = torch.cat(q_idx_per_device, dim = 0) + return quantized_tiles, quantized_idx + + +def preapply_had_l(x: torch.Tensor, had_dim): + k, n = x.shape + x_dtype = x.dtype + x = x.to(torch.float) + had = get_hadamard_dt(had_dim, x.device, x.dtype, 1 / math.sqrt(had_dim)) + x = (had @ x.view(-1, had_dim, n)).view(k, n) + x = x.to(x_dtype) + return x + + +def preapply_had_r(x: torch.Tensor, had_dim): + k, n = x.shape + x_dtype = x.dtype + x = x.to(torch.float) + had = get_hadamard_dt(had_dim, x.device, x.dtype, 1 / math.sqrt(had_dim)) + x = (x.view(k, -1, had_dim) @ had).view(k, n) + x = x.to(x_dtype) + return x + + +def blockwise_preapply_had_l_(x: torch.Tensor, had_dim): + k, n = x.shape + assert k % had_dim == 0 + assert x.dtype == torch.float + had = get_hadamard_dt(had_dim, x.device, x.dtype, 1 / math.sqrt(had_dim)) + num_blocks = k // had_dim + for i in range(num_blocks): + start = i * had_dim + end = start + had_dim + block = x[start:end, :] # shape (had_dim, n) + block_transformed = had @ block.view(had_dim, n) + x[start:end, :] = block_transformed + + +def blockwise_preapply_had_r_(x: torch.Tensor, had_dim): + k, n = x.shape + assert n % had_dim == 0 + assert x.dtype == torch.float + had = get_hadamard_dt(had_dim, x.device, x.dtype, 1 / math.sqrt(had_dim)) + num_blocks = n // had_dim + for i in range(num_blocks): + start = i * had_dim + end = start + had_dim + block = x[:, start:end] # shape (k, had_dim) + block_transformed = block @ had + x[:, start:end] = block_transformed + + +def block_ldl(H: torch.Tensor, b: int, verbose: bool): + + n, _ = H.shape + assert (n % b == 0) + m = n // b + + # Cholesky factorization: H = L @ L.T + # Try on GPU first + try: + retry_cpu = False + L = torch.linalg.cholesky(H) + # H is not needed after this, move to CPU. Then overwrite H's GPU storage with L, since we can't otherwise + # free up that VRAM as the tensor is referenced by the parent frame + H_cpu = H.cpu() + H.copy_(L) # VRAM copy, tiny overhead + L = H + H = H_cpu + + # Fall back on CPU factorization + except Exception as e: + if e.__class__.__name__ == "OutOfMemoryError" or "CUDA out of memory" in str(e) or "HIP out of memory" in str(e): + retry_cpu = True + else: + raise e + if retry_cpu: + print(f" !! Out of memory on {str(H.device)}, trying CPU fallback") + free_mem() + H_cpu = H.cpu() + L_cpu = torch.linalg.cholesky(H_cpu) + # This is ugly, but overwrite H in VRAM to avoid allocating a new tensor, then replace reference with CPU copy + H.copy_(L_cpu) + del L_cpu + L = H + H = H_cpu + + # Get blocks along diagonal of L: DL.shape = (m, b, b) + DL = torch.diagonal(L.reshape(m, b, m, b), dim1 = 0, dim2 = 2).permute(2, 0, 1) + + # Compute D as D[i] = DL[i] @ DL[i].T for each diagonal block i (don't actually end up needing this) + # D = DL @ DL.transpose(1, 2) + + # Invert each diagonal block + DL = torch.linalg.inv(DL) + + # Multiply each block's column with its inverse + L = L.view(n, m, b) + for i in range(m): + L[:, i, :] = L[:, i, :] @ DL[i, :, :] # TODO: Could maybe be L[m * b:, i, :]? + L = L.reshape(n, n).contiguous() + + # Insert block identity matrices along the diagonal. + # TODO: Figure out if this is necessary. Diagonal blocks should already be identities after previous step + L_block = L.view(m, b, m, b).permute(0, 2, 1,3) + dr = torch.arange(m) + L_block[dr, dr] = torch.stack([torch.eye(b, device = L.device, dtype = H.dtype)] * m) + + return L, H # , D.to(DL.device) + + +def ldlq( + weight: torch.Tensor, + L: torch.Tensor, + quant_args: dict, + pb: ProgressBar | None = None +): + """ + :param weight: + Input weights, shape (k, n). If device is "cpu", result is collected on CPU as well, saving a bunch of + VRAM but adding a little PCIe overhead and many sync points + + :param L: + LDL decomposition of regularized H + + :param quant_args: + dict: + - K: bitrate + - buf_size_k: buffer size for LDLQ, along k + + :param pb: + Optional ProgressPar to update, k // 16 steps + + :return: + tuple: + - quantized weight, shape (k, n) + - indices (unpacked), shape (k // 16, n // 16, 256), uint16_t + """ + + devices = quant_args["devices"] + for device in devices: + torch.cuda.synchronize(device) + main_stream = get_quant_stream(devices[0]) + with torch.cuda.stream(main_stream): + + devices = quant_args["devices"] + device = L.device + assert device == torch.device(devices[0]) + + buffer_device = weight.device + size_k, size_n = weight.shape # Row-major + assert size_k % 16 == 0 + assert size_n % 128 == 0 + tiles_k = size_k // 16 + tiles_n = size_n // 16 + + buf_size_k = max(quant_args.get("buf_size_k", 128), 16) + assert buf_size_k % 16 == 0 + assert size_n % buf_size_k == 0 + + p_row = 0 + + # Work buffers + prod_cache = torch.zeros((size_k, size_n), dtype = torch.float, device = device) + weight_q = torch.zeros((size_k, size_n), dtype = torch.float, device = buffer_device) + encoded = torch.zeros((tiles_k, tiles_n, 256), dtype = torch.short, device = buffer_device) + + for j in range(size_k, 0, -buf_size_k): + i = j - buf_size_k + + # Current span is rows i:j + b_weight = weight[i:j].to(device) + b_weight_q = weight_q[i:j] if device == buffer_device else \ + torch.zeros_like(weight_q[i:j], device = device) + b_encoded = encoded[i // 16 : j // 16] if device == buffer_device else \ + torch.zeros_like(encoded[i // 16 : j // 16], device = device) + b_prod_cache = prod_cache[i:j] + b_L = L[i:j] + + # Iterate over rows of blocks in current span + for bj in range(buf_size_k, 0, -16): + bi = bj - 16 + + # Error so far for the current span + bb_err = b_weight[bj:] - b_weight_q[bj:] + + # Corresponding slice of LDL decomposition of H + bb_L = b_L[bj:, i + bi:i + bj] + + # Input tiles for quantization + compensation_term = b_prod_cache[bi:bj] + compensation_term.addmm_(bb_L.T, bb_err, alpha = 1.0, beta = 1.0) + rows = b_weight[bi:bj] + compensation_term + + tiles = rows.reshape(16, tiles_n, 16).permute(1, 0, 2).reshape(tiles_n, 256) + + # Pre-permute to tensor core layout + tiles = tiles[:, tensor_core_perm(device)] + + # Quantize + quant_w, quant_i = quantize_tiles_multigpu(tiles, quant_args) + + # Undo permutation on reconstructed tiles, but keep indices in tensor core layout + quant_w = quant_w[:, tensor_core_perm_i(device)] + + # Store result + quant_w = quant_w.reshape(tiles_n, 16, 16).permute(1, 0, 2).reshape(16, size_n) + b_weight_q[bi:bj] = quant_w + b_encoded[bi // 16 : bj // 16] = quant_i.unsqueeze(0) + + # Update progress + if pb: + p_row += 1 + pb.update(p_row) + + # Collect output + if device != buffer_device: + weight_q[i:j] = b_weight_q.to(buffer_device) + encoded[i // 16 : j // 16] = b_encoded.to(buffer_device) + + # Cache error term for the rest of the matrix + b_err = b_weight - b_weight_q + prod_cache.addmm_(b_L.T, b_err, alpha = 1.0, beta = 1.0) + + for device in devices: + torch.cuda.synchronize(device) + + return weight_q, encoded + + +def fallback_quant( + weight: torch.Tensor, + q_device: torch.Tensor, + quant_args: dict, + pb: ProgressBar | None = None +): + """ + Perform the same quantization as ldlq() but without an LDL decomposition + + :param weight: + Input weights, shape (k, n). If device is "cpu", result is collected on CPU as well, saving a bunch of + VRAM but adding a little PCIe overhead and many sync points + + :param q_device: + Target device + + :param quant_args: + dict: + - K: bitrate + - buf_size_k: buffer size for faux-LDLQ, along k + + :param pb: + Optional ProgressPar to update, k // 16 steps + + :return: + tuple: + - quantized weight, shape (k, n) + - indices (unpacked), shape (k // 16, n // 16, 256), uint16_t + """ + + devices = quant_args["devices"] + for device in devices: + torch.cuda.synchronize(device) + main_stream = get_quant_stream(devices[0]) + with torch.cuda.stream(main_stream): + + devices = quant_args["devices"] + device = weight.device + assert device == torch.device(devices[0]) + + buffer_device = weight.device + size_k, size_n = weight.shape # Row-major + assert size_k % 16 == 0 + assert size_n % 128 == 0 + tiles_k = size_k // 16 + tiles_n = size_n // 16 + + buf_size_k = max(quant_args.get("buf_size_k", 128), 16) + assert buf_size_k % 16 == 0 + assert size_n % buf_size_k == 0 + + p_row = 0 + + # Work buffers + weight_q = torch.zeros((size_k, size_n), dtype = torch.float, device = buffer_device) + encoded = torch.zeros((tiles_k, tiles_n, 256), dtype = torch.short, device = buffer_device) + + for j in range(size_k, 0, -buf_size_k): + i = j - buf_size_k + + # Current span is rows i:j + b_weight = weight[i:j].to(device) + b_weight_q = weight_q[i:j] if device == buffer_device else \ + torch.zeros_like(weight_q[i:j], device = device) + b_encoded = encoded[i // 16 : j // 16] if device == buffer_device else \ + torch.zeros_like(encoded[i // 16 : j // 16], device = device) + + # Iterate over rows of blocks in current span + for bj in range(buf_size_k, 0, -16): + bi = bj - 16 + + # Input tiles for quantization + rows = b_weight[bi:bj] + tiles = rows.reshape(16, tiles_n, 16).permute(1, 0, 2).reshape(tiles_n, 256) + + # Pre-permute to tensor core layout + tiles = tiles[:, tensor_core_perm(device)] + + # Quantize + quant_w, quant_i = quantize_tiles_multigpu(tiles, quant_args) + + # Undo permutation on reconstructed tiles, but keep indices in tensor core layout + quant_w = quant_w[:, tensor_core_perm_i(device)] + + # Store result + quant_w = quant_w.reshape(tiles_n, 16, 16).permute(1, 0, 2).reshape(16, size_n) + b_weight_q[bi:bj] = quant_w + b_encoded[bi // 16 : bj // 16] = quant_i.unsqueeze(0) + + # Update progress + if pb: + p_row += 1 + pb.update(p_row) + + # Collect output + if device != buffer_device: + weight_q[i:j] = b_weight_q.to(buffer_device) + encoded[i // 16 : j // 16] = b_encoded.to(buffer_device) + + for device in devices: + torch.cuda.synchronize(device) + + return weight_q, encoded + + +finalize_capture_H_mutex = threading.Lock() + +def finalize_capture_H(H_data: dict, quant_args: dict, verbose: bool): + with finalize_capture_H_mutex: + + # Unswap H + if "H_swap_device" in H_data: + H_data["H"] = H_data["H"].to(H_data["H_swap_device"]) + del H_data["H_swap_device"] + + H = H_data["H"] + if H_data["finalized"]: + return H_data["q_fallback"], H, H_data["L"], H_data["su"], H_data["diag"] + + # Mean of samples summed up during forward pass + # Switch to uncalibrated fallback if no input activations or diagonal is too small (few activations) + count = H_data["count"] + if count == 0: + q_fallback = True + else: + H /= count + diag_mean = torch.diag(H).mean() + q_fallback = diag_mean.item() < 1e-20 + + # Regularize diagonal + idx = torch.arange(H.shape[0]) + H[idx, idx] += quant_args.get("sigma_reg", 0.025) * diag_mean + + # Some tests + diag = H[idx, idx].clone() + + if verbose: + print(f" - H min/max: {H.min().item():.6f} {H.max().item():.6f}") + print(f" - H mean/std: {H.mean().item():.6f} {H.std().item():.6f}") + print(f" - H diag min/max: {diag.min():.6f} {diag.max():.6f} ") + + # Random sign flips for input channel, fixed for the first linear layer to quantize with this H + k = H.shape[0] + su = (torch.randn(k, device = H.device).sign() + 1e-5).sign().to(torch.float).unsqueeze(1) + H_data["su"] = su + + # Input had + H *= su.T + blockwise_preapply_had_r_(H, had_k) + H *= su + blockwise_preapply_had_l_(H, had_k) + + # Get block LDL decomposition of H, zero diagonal + if q_fallback: + L = None + else: + L, H = block_ldl(H, 16, verbose) + dr = torch.arange(k) + L[dr, dr] = 0 + + H_data["L"] = L + + # H is no longer needed except to compute proxy error so move to CPU + H = H.cpu() + H_data["H"] = H.cpu() + + H_data["finalized"] = True + H_data["diag"] = diag + H_data["q_fallback"] = q_fallback + return q_fallback, H, L, su, diag + + +def pack_trellis(encoded: torch.Tensor, quant_args: dict) -> torch.Tensor: + K = quant_args["K"] + shape = encoded.shape + assert len(shape) == 3 and shape[2] == 256 + assert encoded.dtype == torch.int16 + packed_shape = (shape[0], shape[1], 256 * K // 16) + packed = torch.zeros(packed_shape, dtype = torch.int16, device = encoded.device) + ext.pack_trellis(packed, encoded.contiguous(), K) + # unpacked = torch.zeros_like(encoded) + # ext.unpack_trellis(unpacked, packed, K) + # assert torch.equal(unpacked, encoded) + return packed + + +def pack_signs(signs: torch.Tensor, quant_args: dict) -> torch.Tensor: + signs = signs.half().flatten().contiguous() + assert signs.shape[0] % 16 == 0 + packed = torch.zeros(signs.shape[0] // 16, dtype = torch.int16, device = signs.device) + ext.pack_signs(packed, signs) + return packed + + +def g_scale_gss( + weight_r: torch.Tensor, + verbose: bool, + quant_args: dict, + width: int = 3, + pb: ProgressBar = None +): + # Select a sample of tiles along a wrapped diagonal (sampling from every row and column of tiles, hopefully + # representative) and search for the global scale within given range that minimizes the direct quantization + # error + tiles = [] + tiles_k = weight_r.shape[0] // 16 + tiles_n = weight_r.shape[1] // 16 + for i in range(max(tiles_k, tiles_n)): + for w in range(width): + k = (i % tiles_k) * 16 + n = ((i + w) % tiles_n) * 16 + tile = weight_r[k : k + 16, n : n + 16].clone() + tile = tile.view(256) + tile = tile[tensor_core_perm(weight_r.device)] + tiles.append(tile) + tiles = torch.stack(tiles) + + devices = quant_args["devices"] + for device in devices: + torch.cuda.synchronize(device) + + main_stream = get_quant_stream(devices[0]) + # TODO: Figure out why Torch always initializes cuda:0 when exiting this CM, even when it's not used + with torch.cuda.stream(main_stream): + + def test_scale(scale: float): + quant_w, quant_i = quantize_tiles_multigpu(tiles * scale, quant_args) + mse = ((quant_w / scale - tiles) ** 2).mean() + return mse + + # Assume quantization error is a unimodal function of scale, golden section search to find minimum + phi = (1 + math.sqrt(5)) / 2 + resphi = 2 - phi + + a, b = 0.1, 1.9 + tol = 0.01 + delta1 = abs(b - a) + + x1 = a + resphi * (b - a) + x2 = b - resphi * (b - a) + f1 = test_scale(x1) + f2 = test_scale(x2) + while abs(b - a) > tol: + # if verbose: + # print(f" - gss: a = {a:.6f}, b = {b:.6f}") + if f1 < f2: + b = x2 + x2 = x1 + f2 = f1 + x1 = a + resphi * (b - a) + f1 = test_scale(x1) + else: + a = x1 + x1 = x2 + f1 = f2 + x2 = b - resphi * (b - a) + f2 = test_scale(x2) + delta2 = abs(b - a) + if pb: + pb.update(100 - 100 * int(delta2 / delta1)) + + best_scale = (a + b) / 2 + if verbose: + print(f" - gss: min = {best_scale:.6f}, mse: {(f1 + f2) / 2:.6f}") + + devices = quant_args["devices"] + for device in devices: + torch.cuda.synchronize(device) + + return best_scale, (f1 + f2) / 2 + + +def block_rms(x: torch.Tensor, dim: int, keepdim: bool = False, blocksize: int = 32): + # Compute blockwise x.square().mean(dim, keepdim).sqrt() + n = x.size(dim) + sq = None + for block in torch.split(x, blocksize, dim = dim): + block_sq = block.square().sum(dim = dim, keepdim = keepdim) + if sq is None: + sq = block_sq + else: + sq += block_sq + mean_sq = sq / n + return mean_sq.sqrt() + + +def block_rms_n(x: torch.Tensor, dim: int = 0, blocksize: int = 32): + # Compute blockwise x.square().mean().sqrt() + n = 0 + sq = None + for block in torch.split(x, blocksize, dim = dim): + block_sq = block.square().sum() + n += block.numel() + if sq is None: + sq = block_sq + else: + sq += block_sq + mean_sq = sq / n + return mean_sq.sqrt() + + +def block_nmse(x: torch.Tensor, y: torch.Tensor, dim: int = 0, blocksize: int = 32): + # Compute blockwise (x - y).square().mean().item() / y.square().mean().item() + sq = None + diff_sq = None + for block_x, block_y in zip(torch.split(x, blocksize, dim = dim), torch.split(y, blocksize, dim = dim)): + block_sq = block_y.square().sum() + block_diff_sq = (block_x - block_y).square().sum() + if sq is None: + sq = block_sq + diff_sq = block_diff_sq + else: + sq += block_sq + diff_sq += block_diff_sq + return diff_sq.item() / (sq.item() + 1e-20) + + +def regularize( + weight: torch.Tensor, + su: torch.Tensor, + sv: torch.Tensor, + quant_args: dict, + verbose: bool, + H_diag: torch.Tensor | None, + pb: ProgressBar | None, + skip_g_scale: bool = False, + q_fallback: bool = False +): + force_out_scales = quant_args["apply_out_scales"] + + # From experiments, it seems the deciding factor in when scaling output channels is beneficial is when + # the input to the linear layer is very irregular. After some testing, set the cutoff at 15% of the RMS sum + # on 2% of the channels + # TODO: More science + if not q_fallback and H_diag is not None: + diag = H_diag.sqrt() + diag, _ = torch.sort(diag, descending = True) + cutoff = diag.shape[0] // 50 + skew_factor = diag[:cutoff].sum() / diag.sum() + if verbose: + print(f" - input state skew: {skew_factor.item():.6f}") + + if force_out_scales is None: + apply_out_scales = skew_factor.item() < 0.15 + else: + apply_out_scales = force_out_scales + + else: + apply_out_scales = True if force_out_scales is None else force_out_scales + + if q_fallback: + apply_out_scales = force_out_scales + + # Apply output scales + out_channel_scales = block_rms(weight, dim = 0, keepdim = True) + mean = out_channel_scales.mean().item() + if mean > 1e-30: + out_channel_scales /= mean + quant_args["zeros"] = False + else: + quant_args["zeros"] = True + if force_out_scales is not None: + apply_out_scales = True + zero_out_scales = out_channel_scales.abs() < 1e-30 + + if apply_out_scales: + out_channel_scales[zero_out_scales] = 0.1 + sv = (sv * out_channel_scales + 1e-10).float() + if verbose: + out_channel_std = out_channel_scales.std().item() + out_channel_mean = out_channel_scales.mean().item() + print(f" - out ch scales std/mean: {out_channel_std:.6f} {out_channel_mean:.6f}") + + # Output sign flips (and scales) + weight /= sv + + # Force zero output channels to zero + sv[zero_out_scales] = 0.0 + + # Output hadamard transform + blockwise_preapply_had_r_(weight, had_n) + + # Input sign flips and scales + in_channel_scales = block_rms(weight, dim = 1, keepdim = True) + in_channel_scales[in_channel_scales.abs() < 1e-30] = 0.1 + su = (su * in_channel_scales / (-codebook_scale) + 1e-10).float() # mustn't be inplace + weight /= su + blockwise_preapply_had_l_(weight, had_k) + + # Determine best scale for matrix by test quantizing a sample of tiles along a wrapped diagonal + if not skip_g_scale: + g_scale, mse_scale = g_scale_gss(weight, False, quant_args, pb = pb) + else: + g_scale = 1.0 + weight *= g_scale + su /= g_scale + + # ext.test_distribution(weight_os, dist_r, dist_ref, -3.8, 3.8) + # js_os = jsd(dist_r, dist_ref) + + if verbose: + print(f" - su/sv std: {su.std().item():.6f} {sv.std().item():.6f}") + print(f" - global scale: {g_scale:.6f}") + print(f" - sample mse: {mse_scale.item():.6f}") + print(f" - apply_out_scales: {str(apply_out_scales)}") + + return apply_out_scales, weight, g_scale, su, sv + + +def quantize_exl3( + weight: torch.Tensor, + H_data: dict, + quant_args: dict, + return_weight_q: bool, + progress_str: str | None = None, + verbose: bool = False, + swap_to_device: torch.device | None = None, + save_reg: str = None +): + """ + :param weight: + Input tensor, row major shape (in_features, out_features) + + :param H_data: + Dictionary of hessian tensor and related data, as collected by Linear wrapper class. May be reused between + linear layers with the same input (e.g. Q, K and V projections) + + :param quant_args: + dict: + - K: bitrate + - seed: integer seed for random sign flips etc. + - sigma_reg: regularization factor + + :param return_weight_q: + Return quantized weight + + :param progress_str: + Show progress bar during quantization + + :param verbose: + Dump extra stats + + :param swap_to_device: + If input tensor is on CPU, move to this device before quantization + + :param save_reg: + Save regularized tensor as image to the provided path + + :return: + tuple: + - quantized weight + - proxy error: trace(err @ H @ err.T) / (W @ H @ W.T) + - quantized and packed tensors + """ + + progress_text = None if not progress_str else progress_str.replace("", "Preparing") + with (ProgressBar(progress_text, 100) as pb): + + assert weight.dtype == torch.float + tiles_k = weight.shape[0] // 16 + + if "seed" in quant_args: + torch.manual_seed(quant_args["seed"]) + + devices = quant_args["devices"] + if weight.device != torch.device(devices[0]): + weight = weight.to(devices[0]) + + device = weight.device if swap_to_device is None else swap_to_device + k, n = weight.shape + + # Get H, LDL decomp. and input/output sign flips + q_fallback, H, L, su, H_diag = finalize_capture_H(H_data, quant_args, verbose) + if H.is_cuda: + H = H.to(device) + if L is not None and L.is_cuda: + L = L.to(device) + if su.is_cuda: + su = su.to(device) + if H_diag.is_cuda: + H_diag = H_diag.to(device) + sv = (torch.randn(n, device = device).sign() + 1e-5).sign().to(torch.float).unsqueeze(0) + + # Move stored L to CPU (if not already), move working L to device + if H_data["L"] is not None: + H_data["L"] = H_data["L"].cpu() + if L is not None: + L = L.to(device) + + if swap_to_device is not None: + weight = weight.to(swap_to_device) + if verbose: + weight_copy = weight.cpu() + weight_r = weight + del weight + + if verbose: + rms = block_rms_n(weight_r, dim = 0) + print(f" - input tensor rms: {rms:.6f}") + + # Regularization + apply_out_scales, weight_r, g_scale, su, sv = regularize( + weight_r, + su, + sv, + quant_args, + verbose, + H_diag, + pb, + q_fallback = q_fallback + ) + + if save_reg: + save_tensor_image(weight_r, save_reg) + + if verbose: + rms = weight_r.square().mean().sqrt() + print(f" - regularized rms: {rms:.6f}") + + progress_text = None if not progress_str else progress_str.replace("", "Quantizing") + pb.update(0) + pb.new_task(progress_text, tiles_k) + + # Select device for work buffers (CPU is slower for small tensors but saves a lot of VRAM on big ones) + # TODO: Use pynvml or mem_get_info to predict whether CPU buffer is needed + if weight_r.numel() > 5e8: + weight_r = weight_r.cpu() + + # Quantize + if not q_fallback: + weight_q, encoded_q = ldlq(weight_r, L, quant_args, pb) #zxc + del L + else: + weight_q, encoded_q = fallback_quant(weight_r, device, quant_args, pb) # zxc + + pb.update(tiles_k) + + # Metrics + if not q_fallback: + try: + def block_trace(A, B, block_size = 1024): + total = 0.0 + for j_start in range(0, B.shape[1], block_size): + j_end = min(j_start + block_size, B.shape[1]) + B_block = B[:, j_start:j_end] + A_j_block = A[j_start:j_end, :] + partial = torch.einsum("ik,ij,jk->", A, B_block, A_j_block) + total += partial.item() + return total + E = None + W = None + Hd = None + E = weight_r - weight_q # may run on CPU + W = weight_r + Hd = H.to(device) + weight_r = None + E = E.to(device) + num = block_trace(E, Hd) + E = None + W = W.to(device) + den = block_trace(W, Hd) + W = None + Hd = None + proxy_err = num / max(den, 1e-8) + except torch.OutOfMemoryError: + del weight_r, E, W, Hd + proxy_err = -1.0 + else: + proxy_err = 0.0 + + # free_mem() + + if return_weight_q or verbose: + weight_q = weight_q.to(device) + weight_q = preapply_had_l(weight_q, had_k) + weight_q *= su + weight_q = preapply_had_r(weight_q, had_n) + weight_q *= sv + + if verbose: + weight = weight_copy.to(device) + nmse = block_nmse(weight_q, weight) + print(f" - quant nmse: {nmse:.6f}") + + # Compile packed tensor + suh = su.flatten().contiguous().to(dtype = torch.half, copy = True) + svh = sv.flatten().contiguous().to(dtype = torch.half, copy = True) + trellis = pack_trellis(encoded_q.to(device), quant_args) + + out_tensors = { + # "scale": weight_scale.to(dtype = torch.float, copy = True), + # "su": pack_signs(su, quant_args), + "suh": suh, + # "sv": pack_signs(sv, quant_args), + "svh": svh, + "trellis": trellis, + } + + # Safetensors doesn't know what to do with a torch.uint32 tensor. Anyway, since the multipliers are now + # locked, the values in these tensors are never read, but they need to be present in the model files to + # indicate which codebook to use during inference, per individual tensor. + if quant_args.get("mcg"): + out_tensors.update({ + "mcg": torch.tensor(codebook_mcg_mult, dtype = torch.uint32).view(torch.int) + }) + if quant_args.get("mul1"): + out_tensors.update({ + "mul1": torch.tensor(codebook_mul1_mult, dtype = torch.uint32).view(torch.int) + }) + + quant_args.update({ + "apply_out_scales": apply_out_scales, + "g_scale": g_scale, + "q_fallback": q_fallback, + }) + + return weight_q, proxy_err, out_tensors diff --git a/gptqmodel/exllamav3/util/__init__.py b/gptqmodel/exllamav3/util/__init__.py new file mode 100644 index 000000000..670066470 --- /dev/null +++ b/gptqmodel/exllamav3/util/__init__.py @@ -0,0 +1,3 @@ +from .misc import cuda_sync_active + +__all__ = ["cuda_sync_active"] diff --git a/gptqmodel/exllamav3/util/arch_list.py b/gptqmodel/exllamav3/util/arch_list.py new file mode 100644 index 000000000..991f13130 --- /dev/null +++ b/gptqmodel/exllamav3/util/arch_list.py @@ -0,0 +1,40 @@ +import os +import torch + +# Since Torch 2.3.0 an annoying warning is printed every time the C++ extension is loaded, unless the +# TORCH_CUDA_ARCH_LIST variable is set. The default behavior from pytorch/torch/utils/cpp_extension.py +# is copied in the function below, but without the warning. + +def maybe_set_arch_list_env(): + + if os.environ.get('TORCH_CUDA_ARCH_LIST', None): + return + + if not torch.version.cuda: + return + + arch_list = [] + for i in range(torch.cuda.device_count()): + capability = torch.cuda.get_device_capability(i) + # Strip known NVIDIA suffixes: 'a' (accelerated) or 'f' (family) + supported_sm = [int(arch.split('_')[1].rstrip('af')) + for arch in torch.cuda.get_arch_list() if 'sm_' in arch] + if not supported_sm: + continue + max_supported_sm = max((sm // 10, sm % 10) for sm in supported_sm) + # Capability of the device may be higher than what's supported by the user's + # NVCC, causing compilation error. User's NVCC is expected to match the one + # used to build pytorch, so we use the maximum supported capability of pytorch + # to clamp the capability. + capability = min(max_supported_sm, capability) + arch = f'{capability[0]}.{capability[1]}' + if arch not in arch_list: + arch_list.append(arch) + if not arch_list: + return + arch_list = sorted(arch_list) + arch_list[-1] += '+PTX' + + os.environ["TORCH_CUDA_ARCH_LIST"] = ";".join(arch_list) + +maybe_set_arch_list_env() \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard.py b/gptqmodel/exllamav3/util/hadamard.py new file mode 100644 index 000000000..a7dbed449 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard.py @@ -0,0 +1,160 @@ +from __future__ import annotations + +import glob +import math +import os +from functools import lru_cache + +import torch + +from ..ext import exllamav3_ext as ext + +had_dict: dict[int, torch.Tensor] = {} +primes: set[int] = set() +prime_limit = 1 + +def load_constants(): + global had_dict + + module_dir = os.path.dirname(os.path.abspath(__file__)) + had_dir = os.path.join(module_dir, "hadamard_data") + file_pattern = os.path.join(had_dir, "hadamard_*.txt") + files = glob.glob(file_pattern) + had_dict = {} + + for file_path in files: + with open(file_path, 'r') as file: + lines = file.readlines() + lines = [line.strip() for line in lines if line.strip()] + dim = len(lines) + assert all(len(line) == dim for line in lines), "Non-square matrix in " + file_path + matrix = [[1 if char == '+' else -1 for char in line] for line in lines] + tensor = torch.tensor(matrix, dtype = torch.float16) + had_dict[dim] = tensor + +def ensure_primes(limit: int): + global prime_limit, primes + + if limit <= prime_limit: + return + + # Rebuild the cache when the requested Hadamard order needs a larger prime range. + sieve = bytearray(b"\x01") * (limit + 1) + sieve[:2] = b"\x00\x00" + max_factor = math.isqrt(limit) + for candidate in range(2, max_factor + 1): + if not sieve[candidate]: + continue + start = candidate * candidate + sieve[start:limit + 1:candidate] = b"\x00" * (((limit - start) // candidate) + 1) + + primes = {value for value, is_prime in enumerate(sieve) if is_prime} + prime_limit = limit + +def sylvester(h: torch.Tensor): + d = h.shape[0] + assert d == h.shape[1], "h not square" + s = torch.empty((d * 2, d * 2), dtype = h.dtype, device = h.device) + s[:d, :d] = h + s[:d, d:] = h + s[d:, :d] = h + s[d:, d:] = -h + return s + +def is_quadratic_residue(a: int, p: int): + return pow(a, (p - 1) // 2, p) == 1 + +def paley_torch(n: int): + h = torch.empty((n, n), dtype = torch.half) + p = n - 1 + for i in range(p): + for j in range(p): + if i == j: + h[i + 1][j + 1] = 1 + else: + residue = (i - j) % p + if is_quadratic_residue(residue, p): + h[i + 1][j + 1] = 1 + else: + h[i + 1][j + 1] = -1 + h[0, :] = 1 + h[:, 0] = -1 + h[0, 0] = 1 + return h + +def paley(n: int): + h = torch.empty((n, n), dtype = torch.half) + ext.had_paley(h) + # ref = paley_torch(n) + # assert torch.all(h == ref) + return h + +def paley2_torch(n: int): + h = torch.empty((n, n), dtype = torch.half) + p = n // 2 - 1 + for i in range(n // 2): + i0 = 2 * i + 0 + i1 = 2 * i + 1 + for j in range(n // 2): + j0 = 2 * j + 0 + j1 = 2 * j + 1 + if j == i: + h[i0, j0] = 1 + h[i0, j1] = -1 + h[i1, j0] = -1 + h[i1, j1] = -1 + else: + residue = (i - j) % p + if i == 0 or j == 0 or is_quadratic_residue(residue, p): + h[i0, j0] = 1 + h[i0, j1] = 1 + h[i1, j0] = 1 + h[i1, j1] = -1 + else: + h[i0, j0] = -1 + h[i0, j1] = -1 + h[i1, j0] = -1 + h[i1, j1] = 1 + return h + +def paley2(n: int): + h = torch.empty((n, n), dtype = torch.half) + ext.had_paley2(h) + # ref = paley2_torch(n) + # assert torch.all(h == ref) + return h + +@lru_cache(maxsize = 100) +def get_hadamard(n: int): + global had_dict + + if not had_dict: + load_constants() + + if n in had_dict: return had_dict[n] + + # Sylvester's construction + if n % 2 == 0: + s = get_hadamard(n // 2) + if s is not None: + s = sylvester(s) + return s + + if n % 4 == 0: + ensure_primes(max(n - 1, (n // 2) - 1)) + + # Paley construction + if n % 4 == 0 and (n - 1) % 4 == 3 and (n - 1) in primes: + return paley(n) + + # Other Paley construction + if n % 4 == 0 and (n // 2) - 1 in primes: + return paley2(n) + + return None + +@lru_cache(maxsize = 100) +def get_hadamard_dt(n: int, device: torch.device | str, dtype: torch.dtype, scale = 1.0): + had = get_hadamard(n).to(device = device, dtype = dtype, copy = True) + had *= scale + return had diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_1.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_1.txt new file mode 100644 index 000000000..9b26e9b10 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_1.txt @@ -0,0 +1 @@ ++ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_100.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_100.txt new file mode 100644 index 000000000..78f3a4ae5 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_100.txt @@ -0,0 +1,100 @@ ++-++++----++--++----++++-+++-+++---+-++-+---+++-+++-++-+++-++----++-+++-++-+--++-+-+--------+-+-++-- +-+-++++----++--++----++++++++-+++---+-++-+---+++-+-+-++-+++-++----++-+++-++-+--++-+-+--------+-+-++- ++-+-++++----++--++----++++++++-+++---+-++-+---+++-+-+-++-+++-++----++-+++-+--+--++-+-+--------+-+-++ +++-+-++++----++--++----++-+++++-+++---+-++-+---+++++-+-++-+++-++----++-+++-+--+--++-+-+--------+-+-+ ++++-+-++++----++--++----++-+++++-+++---+-++-+---++-++-+-++-+++-++----++-+++++--+--++-+-+--------+-+- +++++-+-++++----++--++----++-+++++-+++---+-++-+---++-++-+-++-+++-++----++-++-++--+--++-+-+--------+-+ +-++++-+-++++----++--++---+++-+++++-+++---+-++-+---++-++-+-++-+++-++----++-++-++--+--++-+-+--------+- +--++++-+-++++----++--++---+++-+++++-+++---+-++-+--+++-++-+-++-+++-++----++--+-++--+--++-+-+--------+ +---++++-+-++++----++--++---+++-+++++-+++---+-++-+--+++-++-+-++-+++-++----+++-+-++--+--++-+-+-------- +----++++-+-++++----++--++---+++-+++++-+++---+-++-++-+++-++-+-++-+++-++----+-+-+-++--+--++-+-+------- ++----++++-+-++++----++--++---+++-+++++-+++---+-++-++-+++-++-+-++-+++-++------+-+-++--+--++-+-+------ +++----++++-+-++++----++---+---+++-+++++-+++---+-++-++-+++-++-+-++-+++-++------+-+-++--+--++-+-+----- +-++----++++-+-++++----++-+-+---+++-+++++-+++---+-+--++-+++-++-+-++-+++-++------+-+-++--+--++-+-+---- +--++----++++-+-++++----++++-+---+++-+++++-+++---+----++-+++-++-+-++-+++-++------+-+-++--+--++-+-+--- ++--++----++++-+-++++----+-++-+---+++-+++++-+++---+----++-+++-++-+-++-+++-++------+-+-++--+--++-+-+-- +++--++----++++-+-++++----+-++-+---+++-+++++-+++---+----++-+++-++-+-++-+++-+-------+-+-++--+--++-+-+- +-++--++----++++-+-++++----+-++-+---+++-+++++-+++--++----++-+++-++-+-++-+++---------+-+-++--+--++-+-+ +--++--++----++++-+-++++----+-++-+---+++-+++++-+++--++----++-+++-++-+-++-++++--------+-+-++--+--++-+- +---++--++----++++-+-++++----+-++-+---+++-+++++-++++-++----++-+++-++-+-++-++-+--------+-+-++--+--++-+ +----++--++----++++-+-+++++---+-++-+---+++-+++++-++++-++----++-+++-++-+-++-++-+--------+-+-++--+--++- ++----++--++----++++-+-+++++---+-++-+---+++-+++++-++++-++----++-+++-++-+-++--+-+--------+-+-++--+--++ +++----++--++----++++-+-+++++---+-++-+---+++-+++++--+++-++----++-+++-++-+-+++-+-+--------+-+-++--+--+ ++++----++--++----++++-+-+-+++---+-++-+---+++-++++++-+++-++----++-+++-++-+-+++-+-+--------+-+-++--+-- +++++----++--++----++++-+-+-+++---+-++-+---+++-++++++-+++-++----++-+++-++-+--++-+-+--------+-+-++--+- +-++++----++--++----++++-+++-+++---+-++-+---+++-+++-++-+++-++----++-+++-++-+--++-+-+--------+-+-++--+ +---+---+++-+--+-+++---+--+-++++----++--++----++++--++--+-+-++++++++-+-+--+++-++-+++-++----++-+++-++- +----+---+++-+--+-+++---+--+-++++----++--++----+++++-++--+-+-++++++++-+-+--+-+-++-+++-++----++-+++-++ +-----+---+++-+--+-+++---++-+-++++----++--++----+++++-++--+-+-++++++++-+-+--+-+-++-+++-++----++-+++-+ ++-----+---+++-+--+-+++---++-+-++++----++--++----++-++-++--+-+-++++++++-+-+-++-+-++-+++-++----++-+++- +-+-----+---+++-+--+-+++--+++-+-++++----++--++----+--++-++--+-+-++++++++-+-+-++-+-++-+++-++----++-+++ +--+-----+---+++-+--+-+++-++++-+-++++----++--++----+--++-++--+-+-++++++++-+-+-++-+-++-+++-++----++-++ +---+-----+---+++-+--+-+++-++++-+-++++----++--++----+--++-++--+-+-++++++++-+++-++-+-++-+++-++----++-+ ++---+-----+---+++-+--+-++--++++-+-++++----++--++--+-+--++-++--+-+-++++++++-+++-++-+-++-+++-++----++- +++---+-----+---+++-+--+-+---++++-+-++++----++--++--+-+--++-++--+-+-++++++++-+++-++-+-++-+++-++----++ ++++---+-----+---+++-+--+-----++++-+-++++----++--+++-+-+--++-++--+-+-++++++++-+++-++-+-++-+++-++----+ +-+++---+-----+---+++-+--++----++++-+-++++----++--+++-+-+--++-++--+-+-++++++++-+++-++-+-++-+++-++---- ++-+++---+-----+---+++-+--++----++++-+-++++----++--+++-+-+--++-++--+-+-+++++-++-+++-++-+-++-+++-++--- +-+-+++---+-----+---+++-+--++----++++-+-++++----++-++++-+-+--++-++--+-+-++++--++-+++-++-+-++-+++-++-- +--+-+++---+-----+---+++-+--++----++++-+-++++----+++++++-+-+--++-++--+-+-+++---++-+++-++-+-++-+++-++- ++--+-+++---+-----+---+++-+--++----++++-+-++++----+++++++-+-+--++-++--+-+-++----++-+++-++-+-++-+++-++ +-+--+-+++---+-----+---+++++--++----++++-+-++++----+++++++-+-+--++-++--+-+-++----++-+++-++-+-++-+++-+ ++-+--+-+++---+-----+---++-++--++----++++-+-++++---++++++++-+-+--++-++--+-+-++----++-+++-++-+-++-+++- +++-+--+-+++---+-----+---+--++--++----++++-+-++++---++++++++-+-+--++-++--+-+-++----++-+++-++-+-++-+++ ++++-+--+-+++---+-----+------++--++----++++-+-++++-+-++++++++-+-+--++-++--+-+-++----++-+++-++-+-++-++ +-+++-+--+-+++---+-----+------++--++----++++-+-++++-+-++++++++-+-+--++-++--+++-++----++-+++-++-+-++-+ +--+++-+--+-+++---+-----+-+----++--++----++++-+-++++-+-++++++++-+-+--++-++--+++-++----++-+++-++-+-++- +---+++-+--+-+++---+-----+++----++--++----++++-+-++-+-+-++++++++-+-+--++-++--+++-++----++-+++-++-+-++ ++---+++-+--+-+++---+-----+++----++--++----++++-+-+--+-+-++++++++-+-+--++-+++-+++-++----++-+++-++-+-+ +-+---+++-+--+-+++---+----++++----++--++----++++-+-+--+-+-++++++++-+-+--++-+++-+++-++----++-+++-++-+- +--+---+++-+--+-+++---+----++++----++--++----++++-+++--+-+-++++++++-+-+--++--++-+++-++----++-+++-++-+ +-+--+---+--++++--+---+--++--++-+-+--------+-+-++--+-++++----++--++----++++----+---+++-+--+-+++---+-- ++-+--+---+--++++--+---+---+--++-+-+--------+-+-++--+-++++----++--++----++++----+---+++-+--+-+++---+- +-+-+--+---+--++++--+---+---+--++-+-+--------+-+-+++-+-++++----++--++----+++-----+---+++-+--+-+++---+ +--+-+--+---+--++++--+---++--+--++-+-+--------+-+-+++-+-++++----++--++----+++-----+---+++-+--+-+++--- ++--+-+--+---+--++++--+---++--+--++-+-+--------+-+-+++-+-++++----++--++----+-+-----+---+++-+--+-+++-- +-+--+-+--+---+--++++--+---++--+--++-+-+--------+-+++++-+-++++----++--++------+-----+---+++-+--+-+++- +--+--+-+--+---+--++++--+-+-++--+--++-+-+--------+--++++-+-++++----++--++------+-----+---+++-+--+-+++ +---+--+-+--+---+--++++--+-+-++--+--++-+-+--------+--++++-+-++++----++--++--+---+-----+---+++-+--+-++ ++---+--+-+--+---+--++++--+-+-++--+--++-+-+-----------++++-+-++++----++--++-++---+-----+---+++-+--+-+ +-+---+--+-+--+---+--++++--+-+-++--+--++-+-+-----------++++-+-++++----++--+++++---+-----+---+++-+--+- +--+---+--+-+--+---+--++++--+-+-++--+--++-+-+------+----++++-+-++++----++--+-+++---+-----+---+++-+--+ ++--+---+--+-+--+---+--+++---+-+-++--+--++-+-+-----++----++++-+-++++----++--+-+++---+-----+---+++-+-- +++--+---+--+-+--+---+--++----+-+-++--+--++-+-+-----++----++++-+-++++----++--+-+++---+-----+---+++-+- ++++--+---+--+-+--+---+--+-----+-+-++--+--++-+-+-----++----++++-+-++++----++--+-+++---+-----+---+++-+ +++++--+---+--+-+--+---+--------+-+-++--+--++-+-+--+--++----++++-+-++++----++--+-+++---+-----+---+++- +-++++--+---+--+-+--+---+--------+-+-++--+--++-+-+-++--++----++++-+-++++-----+--+-+++---+-----+---+++ +--++++--+---+--+-+--+---+--------+-+-++--+--++-+-+-++--++----++++-+-++++---+-+--+-+++---+-----+---++ ++--++++--+---+--+-+--+---+--------+-+-++--+--++-+---++--++----++++-+-++++--++-+--+-+++---+-----+---+ +-+--++++--+---+--+-+--+---+--------+-+-++--+--++-+---++--++----++++-+-++++-+++-+--+-+++---+-----+--- +--+--++++--+---+--+-+--+-+-+--------+-+-++--+--++-----++--++----++++-+-++++-+++-+--+-+++---+-----+-- +---+--++++--+---+--+-+--+-+-+--------+-+-++--+--+++----++--++----++++-+-+++--+++-+--+-+++---+-----+- ++---+--++++--+---+--+-+--+-+-+--------+-+-++--+--+++----++--++----++++-+-++---+++-+--+-+++---+-----+ +-+---+--++++--+---+--+-+-++-+-+--------+-+-++--+--+++----++--++----++++-+-++---+++-+--+-+++---+----- +--+---+--++++--+---+--+-+-++-+-+--------+-+-++--+-++++----++--++----++++-+--+---+++-+--+-+++---+---- ++--+---+--++++--+---+--+---++-+-+--------+-+-++--+-++++----++--++----++++-+--+---+++-+--+-+++---+--- +-++--+-+-++++++++-+-+--++-+--+---+--++++--+---+--++++-+++---+-++-+---+++-+++-++++----++--++----++++- ++-++--+-+-++++++++-+-+--++-+--+---+--++++--+---+--++++-+++---+-++-+---+++-+-+-++++----++--++----++++ +++-++--+-+-++++++++-+-+---+-+--+---+--++++--+---+-+++++-+++---+-++-+---+++-+-+-++++----++--++----+++ +-++-++--+-+-++++++++-+-+---+-+--+---+--++++--+---+-+++++-+++---+-++-+---+++++-+-++++----++--++----++ +--++-++--+-+-++++++++-+-++--+-+--+---+--++++--+---+-+++++-+++---+-++-+---+++++-+-++++----++--++----+ ++--++-++--+-+-++++++++-+--+--+-+--+---+--++++--+--++-+++++-+++---+-++-+---+++++-+-++++----++--++---- +-+--++-++--+-+-++++++++-+--+--+-+--+---+--++++--+-+++-+++++-+++---+-++-+----++++-+-++++----++--++--- ++-+--++-++--+-+-++++++++----+--+-+--+---+--++++--+-+++-+++++-+++---+-++-+----++++-+-++++----++--++-- +-+-+--++-++--+-+-+++++++++---+--+-+--+---+--++++----+++-+++++-+++---+-++-+----++++-+-++++----++--++- ++-+-+--++-++--+-+-+++++++-+---+--+-+--+---+--++++----+++-+++++-+++---+-++-+----++++-+-++++----++--++ +++-+-+--++-++--+-+-++++++--+---+--+-+--+---+--+++++---+++-+++++-+++---+-++-+----++++-+-++++----++--+ ++++-+-+--++-++--+-+-++++++--+---+--+-+--+---+--+++-+---+++-+++++-+++---+-++++----++++-+-++++----++-- +++++-+-+--++-++--+-+-++++++--+---+--+-+--+---+--+++-+---+++-+++++-+++---+-+-++----++++-+-++++----++- ++++++-+-+--++-++--+-+-++++++--+---+--+-+--+---+--+++-+---+++-+++++-+++---+---++----++++-+-++++----++ +++++++-+-+--++-++--+-+-++++++--+---+--+-+--+---+---++-+---+++-+++++-+++---++--++----++++-+-++++----+ ++++++++-+-+--++-++--+-+-+-++++--+---+--+-+--+---+-+-++-+---+++-+++++-+++---++--++----++++-+-++++---- +++++++++-+-+--++-++--+-+---++++--+---+--+-+--+---+-+-++-+---+++-+++++-+++---++--++----++++-+-++++--- +-++++++++-+-+--++-++--+-++--++++--+---+--+-+--+-----+-++-+---+++-+++++-+++---++--++----++++-+-++++-- ++-++++++++-+-+--++-++--+--+--++++--+---+--+-+--+-----+-++-+---+++-+++++-+++---++--++----++++-+-++++- +-+-++++++++-+-+--++-++--+--+--++++--+---+--+-+--+-+---+-++-+---+++-+++++-++----++--++----++++-+-++++ ++-+-++++++++-+-+--++-++-----+--++++--+---+--+-+--+++---+-++-+---+++-+++++-++----++--++----++++-+-+++ +-+-+-++++++++-+-+--++-++-+---+--++++--+---+--+-+--+++---+-++-+---+++-+++++-++----++--++----++++-+-++ +--+-+-++++++++-+-+--++-++-+---+--++++--+---+--+-+--+++---+-++-+---+++-++++++++----++--++----++++-+-+ ++--+-+-++++++++-+-+--++-+--+---+--++++--+---+--+-++-+++---+-++-+---+++-++++++++----++--++----++++-+- +++--+-+-++++++++-+-+--++-+--+---+--++++--+---+--+-++-+++---+-++-+---+++-+++-++++----++--++----++++-+ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_116.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_116.txt new file mode 100644 index 000000000..5cea1241e --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_116.txt @@ -0,0 +1,116 @@ +++--+--+-+++-++++-+++-+--+--+++++-++-+---++++++---+-++-++++-+---++--+-++++++-+--++---+-+++---++--+-+----+-+--++---++ ++++--+--+-+++-++++-+++-+--+--+++++-++-+---++++++---+-++-++-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+--++---+ +-+++--+--+-+++-++++-+++-+--+-++++++-++-+---++++++---+-++-++-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+--++--- +--+++--+--+-+++-++++-+++-+--++++++++-++-+---++++++---+-++--+-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+--++-- ++--+++--+--+-+++-++++-+++-+---+++++++-++-+---++++++---+-++--+-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+--++- +-+--+++--+--+-+++-++++-+++-+-+-+++++++-++-+---++++++---+-+---+-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+--++ +--+--+++--+--+-+++-++++-+++-+++-+++++++-++-+---++++++---+-+---+-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+--+ ++--+--+++--+--+-+++-++++-+++--++-+++++++-++-+---++++++---+++---+-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+-- +-+--+--+++--+--+-+++-++++-++++-++-+++++++-++-+---++++++----++---+-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+- ++-+--+--+++--+--+-+++-++++-++-+-++-+++++++-++-+---++++++----++---+-+-+---++--+-++++++-+--++---+++++---++--+-+----+-+ +++-+--+--+++--+--+-+++-++++-+--+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-++++++-+--++---+++++---++--+-+----+- ++++-+--+--+++--+--+-+++-++++----+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-++++++-+--++---+++++---++--+-+----+ +-+++-+--+--+++--+--+-+++-+++++---+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-++++++-+--++---+++++---++--+-+---- ++-+++-+--+--+++--+--+-+++-+++++---+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-++++-+-+--++---+++++---++--+-+--- +++-+++-+--+--+++--+--+-+++-+++++---+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-+++--+-+--++---+++++---++--+-+-- ++++-+++-+--+--+++--+--+-+++-+++++---+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-++---+-+--++---+++++---++--+-+- +++++-+++-+--+--+++--+--+-+++-+++++---+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-+----+-+--++---+++++---++--+-+ +-++++-+++-+--+--+++--+--+-+++++++++---+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-+----+-+--++---+++++---++--+- ++-++++-+++-+--+--+++--+--+-++-++++++---+-++-+++++++-++-+---++++++-+--++---+-+-+---++--+-+----+-+--++---+++++---++--+ +++-++++-+++-+--+--+++--+--+-+--++++++---+-++-+++++++-++-+-+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++---++-- ++++-++++-+++-+--+--+++--+--+----++++++---+-++-+++++++-++-+-+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++---++- +-+++-++++-+++-+--+--+++--+--++---++++++---+-++-+++++++-++---+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++---++ ++-+++-++++-+++-+--+--+++--+---+---++++++---+-++-+++++++-+++--+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++---+ +-+-+++-++++-+++-+--+--+++--+-+-+---++++++---+-++-+++++++-+++--+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++--- +--+-+++-++++-+++-+--+--+++--+++-+---++++++---+-++-+++++++--++--+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++-- ++--+-+++-++++-+++-+--+--+++---++-+---++++++---+-++-+++++++--++--+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++- +-+--+-+++-++++-+++-+--+--+++-+-++-+---++++++---+-++-++++++---++--+-++++++-+--++---+-+-+---++--+-+----+-+--++---+++++ +--+--+-+++-++++-+++-+--+--+++++-++-+---++++++---+-++-++++++---++--+-++++++-+--++---+-+-+---++--+-+----+-+--++---++++ ++--+--+-+++-++++-+++-+--+--+++++-++-+---++++++---+-++-++++-+---++--+-++++++-+--++---+-+++---++--+-+----+-+--++---+++ +----+--+-+++------+++-+--+---++--+--+-+++-++++-+++-+--+--+---+++--++-+-++++-+-++--+++--+-+---++--+-++++++-+--++---+- +-----+--+-+++------+++-+--+--+++--+--+-+++-++++-+++-+--+------+++--++-+-++++-+-++--+++--+-+---++--+-++++++-+--++---+ +------+--+-+++------+++-+--+--+++--+--+-+++-++++-+++-+--+------+++--++-+-++++-+-++--++++-+-+---++--+-++++++-+--++--- +-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+++-+--++-----+++--++-+-++++-+-++--++-+-+-+---++--+-++++++-+--++-- ++-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+++-+--++-----+++--++-+-++++-+-++--+--+-+-+---++--+-++++++-+--++- +-+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+++-+-+++-----+++--++-+-++++-+-++-----+-+-+---++--+-++++++-+--++ +--+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+++-+-+++-----+++--++-+-++++-+-++-+---+-+-+---++--+-++++++-+--+ ++--+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+++---+++-----+++--++-+-++++-+-++++---+-+-+---++--+-++++++-+-- +-+--+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-++++--+++-----+++--++-+-++++-+-+-++---+-+-+---++--+-++++++-+- ++-+--+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-++++--+++-----+++--++-+-++++-+---++---+-+-+---++--+-++++++-+ +++-+--+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+-++++-++--++---+-+-+---++--+-++++++- ++++-+--+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+-++++--+--++---+-+-+---++--+-++++++ +-+++-+--+-------+--+-+++------+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+-+++++-+--++---+-+-+---++--+-+++++ +--+++-+--+-------+--+-+++----+-+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+-+++++-+--++---+-+-+---++--+-++++ +---+++-+--+-------+--+-+++---++-+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+-+++++-+--++---+-+-+---++--+-+++ +----+++-+--+-------+--+-+++--+++-+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+-+++++-+--++---+-+-+---++--+-++ +-----+++-+--+-------+--+-+++-++++-+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+-+++++-+--++---+-+-+---++--+-+ +------+++-+--+-------+--+-+++-++++-+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++-+++++++-+--++---+-+-+---++--+- ++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--+--+-+++-++++-+-++--+++-----+++--++--++++++-+--++---+-+-+---++--+ +++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--+--+-+-+-++++-+-++--+++-----+++--+++-++++++-+--++---+-+-+---++-- ++++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--+--+-+-+-++++-+-++--+++-----+++--+-+-++++++-+--++---+-+-+---++- +-+++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--+--+++-+-++++-+-++--+++-----+++----+-++++++-+--++---+-+-+---++ ++-+++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--+---++-+-++++-+-++--+++-----+++-+--+-++++++-+--++---+-+-+---+ +-+-+++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--+---++-+-++++-+-++--+++-----+++++--+-++++++-+--++---+-+-+--- +--+-+++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--++--++-+-++++-+-++--+++-----++-++--+-++++++-+--++---+-+-+-- ++--+-+++------+++-+--+-------+--+-+++-++++-+++-+--+--+++--++--++-+-++++-+-++--+++-----+--++--+-++++++-+--++---+-+-+- +-+--+-+++------+++-+--+-------+--+-+++-++++-+++-+--+--+++-+++--++-+-++++-+-++--+++--------++--+-++++++-+--++---+-+-+ +--+--+-+++------+++-+--+-------+--+-+++-++++-+++-+--+--+++-+++--++-+-++++-+-++--+++----+---++--+-++++++-+--++---+-+- +---+--+-+++------+++-+--+----+--+--+-+++-++++-+++-+--+--++--+++--++-+-++++-+-++--+++----+---++--+-++++++-+--++---+-+ +-+-+++--++-+------+-++--+++-++++---++--+-+----+-+--++---++++--+--+-+++-++++-+++-+--+--+----+--+-+++------+++-+--+--- ++-+-+++--++-+------+-++--+++-++++---++--+-+----+-+--++---++++--+--+-+++-++++-+++-+--+-------+--+-+++------+++-+--+-- +-+-+-+++--++-+------+-++--++++++++---++--+-+----+-+--++----+++--+--+-+++-++++-+++-+--+-------+--+-+++------+++-+--+- ++-+-+-+++--++-+------+-++--++-+++++---++--+-+----+-+--++----+++--+--+-+++-++++-+++-+--+-------+--+-+++------+++-+--+ +++-+-+-+++--++-+------+-++--+--+++++---++--+-+----+-+--++-+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------+++-+-- ++++-+-+-+++--++-+------+-++-----+++++---++--+-+----+-+--++-+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------+++-+- +-+++-+-+-+++--++-+------+-++-+---+++++---++--+-+----+-+--+--+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------+++-+ +--+++-+-+-+++--++-+------+-++++---+++++---++--+-+----+-+--+--+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------+++- ++--+++-+-+-+++--++-+------+-+-++---+++++---++--+-+----+-+--+--+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------+++ +++--+++-+-+-+++--++-+------+---++---+++++---++--+-+----+-++-+--+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------++ +-++--+++-+-+-+++--++-+------++--++---+++++---++--+-+----+-++-+--+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------+ ++-++--+++-+-+-+++--++-+-------+--++---+++++---++--+-+----++++-+--+--+++--+--+-+++-++++-+++-+--+-------+--+-+++------ +-+-++--+++-+-+-+++--++-+-----+-+--++---+++++---++--+-+-----+++-+--+--+++--+--+-+++-++++-+++-+--+-------+--+-+++----- +--+-++--+++-+-+-+++--++-+-----+-+--++---+++++---++--+-+---+-+++-+--+--+++--+--+-+++-+++--+++-+--+-------+--+-+++---- +---+-++--+++-+-+-+++--++-+-----+-+--++---+++++---++--+-+--++-+++-+--+--+++--+--+-+++-++---+++-+--+-------+--+-+++--- +----+-++--+++-+-+-+++--++-+-----+-+--++---+++++---++--+-+-+++-+++-+--+--+++--+--+-+++-+----+++-+--+-------+--+-+++-- +-----+-++--+++-+-+-+++--++-+-----+-+--++---+++++---++--+-+++++-+++-+--+--+++--+--+-+++------+++-+--+-------+--+-+++- +------+-++--+++-+-+-+++--++-++----+-+--++---+++++---++--+--++++-+++-+--+--+++--+--+-+++------+++-+--+-------+--+-+++ ++------+-++--+++-+-+-+++--++--+----+-+--++---+++++---++--++-++++-+++-+--+--+++--+--+-+++------+++-+--+-------+--+-++ +-+------+-++--+++-+-+-+++--+++-+----+-+--++---+++++---++--++-++++-+++-+--+--+++--+--+-+++------+++-+--+-------+--+-+ ++-+------+-++--+++-+-+-+++--+-+-+----+-+--++---+++++---++-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+-------+--+- +++-+------+-++--+++-+-+-+++----+-+----+-+--++---+++++---++-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+-------+--+ +-++-+------+-++--+++-+-+-+++-+--+-+----+-+--++---+++++---++-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+-------+-- +--++-+------+-++--+++-+-+-+++++--+-+----+-+--++---+++++----+-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+-------+- ++--++-+------+-++--+++-+-+-++-++--+-+----+-+--++---+++++----+-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+-------+ +++--++-+------+-++--+++-+-+-+--++--+-+----+-+--++---+++++-+--+-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+------- ++++--++-+------+-++--+++-+-+----++--+-+----+-+--++---+++++-+--+-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+------ +-+++--++-+------+-++--+++-+-++---++--+-+----+-+--++---++++--+--+-+++-++++-+++-+--+--+++--+--+-+++------+++-+--+----- ++-+++--++-+------+-++--+++-+-++---++--+-+----+-+--++---++++--+--+-+++-++++-+++-+--+--++---+--+-+++------+++-+--+---- +---+++--++-+-++++-+-++--+++---+-+++--++-+------+-++--+++-+++++-++-+---++++++---+-++-+++++--+--+-+++-++++-+++-+--+--+ +----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++--+++-+++++-++-+---++++++---+-++-+++++--+--+-+++-++++-+++-+--+-- +-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++--+++++++++-++-+---++++++---+-++-+-+++--+--+-+++-++++-+++-+--+- ++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++--+++++++++-++-+---++++++---+-++---+++--+--+-+++-++++-+++-+--+ +++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++--+-+++++++-++-+---++++++---+-+++--+++--+--+-+++-++++-+++-+-- ++++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++--+-+++++++-++-+---++++++---+-+-+--+++--+--+-+++-++++-+++-+- +-+++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++-++-+++++++-++-+---++++++---+---+--+++--+--+-+++-++++-+++-+ +--+++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++-++-+++++++-++-+---++++++---++--+--+++--+--+-+++-++++-+++- ++--+++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+-++-++-+++++++-++-+---++++++----+--+--+++--+--+-+++-++++-+++ +++--+++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+--+-++-+++++++-++-+---++++++--+-+--+--+++--+--+-+++-++++-++ +-++--+++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+------+--+-++-+++++++-++-+---++++++-++-+--+--+++--+--+-+++-++++-+ ++-++--+++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+---------+-++-+++++++-++-+---+++++++++-+--+--+++--+--+-+++-++++- +-+-++--+++-----+++--++-+-++++-+-++--+++-+-+-+++--++-+-----+---+-++-+++++++-++-+---+++++-+++-+--+--+++--+--+-+++-++++ ++-+-++--+++-----+++--++-+-+++--+-++--+++-+-+-+++--++-+----++---+-++-+++++++-++-+---+++++-+++-+--+--+++--+--+-+++-+++ +++-+-++--+++-----+++--++-+-++---+-++--+++-+-+-+++--++-+---+++---+-++-+++++++-++-+---+++++-+++-+--+--+++--+--+-+++-++ ++++-+-++--+++-----+++--++-+-+----+-++--+++-+-+-+++--++-+--++++---+-++-+++++++-++-+---+++++-+++-+--+--+++--+--+-+++-+ +++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+++--++-+-+++++---+-++-+++++++-++-+---+++++-+++-+--+--+++--+--+-+++- +-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+++--++-+++++++---+-++-+++++++-++-+----++++-+++-+--+--+++--+--+-+++ ++-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+++--++--++++++---+-++-+++++++-++-+--+-++++-+++-+--+--+++--+--+-++ +-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+++--++--++++++---+-++-+++++++-++-+-++-++++-+++-+--+--+++--+--+-+ ++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+++--+---++++++---+-++-+++++++-++-++++-++++-+++-+--+--+++--+--+- +++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+++--+---++++++---+-++-+++++++-++--+++-++++-+++-+--+--+++--+--+ +-++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+++--+---++++++---+-++-+++++++-+++-+++-++++-+++-+--+--+++--+-- +--++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-++++-+---++++++---+-++-+++++++-+-+-+++-++++-+++-+--+--+++--+- ++--++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-++++-+---++++++---+-++-+++++++---+-+++-++++-+++-+--+--+++--+ +++--++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+-++-+---++++++---+-++-++++++++--+-+++-++++-+++-+--+--+++-- ++++--++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+-+-++-+---++++++---+-++-++++++-+--+-+++-++++-+++-+--+--+++- +-+++--++-+-++++-+-++--+++-----+++--++-+------+-++--+++-+-+++-++-+---++++++---+-++-+++++--+--+-+++-++++-+++-+--+--+++ +--+++--++-+-++++-+-++--+++---+-+++--++-+------+-++--+++-+-+++-++-+---++++++---+-++-+++++--+--+-+++-++++-+++-+--+--++ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_156.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_156.txt new file mode 100644 index 000000000..b52b200b8 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_156.txt @@ -0,0 +1,156 @@ ++++--+-+-----+--++----++--+-----+-+--++++++---+--++----+-+--+-+----++--+---++++++--++-+---+-+--+----+--+-+---+-++--+++---++-+-+-----+++-++-+++-----+-+-++--- +++++--+-+-----+--++----++--+-----+-+--++++++---+--++----+-+--+-+----++--+---++++++--++-+---+-+--+----+--+-+---+-++--+-+---++-+-+-----+++-++-+++-----+-+-++-- ++++++--+-+-----+--++----++--+-----+-+--++++++---+--++----+-+--+-+----++--+---++++++--++-+---+-+--+----+--+-+---+-++----+---++-+-+-----+++-++-+++-----+-+-++- +-+++++--+-+-----+--++----++--+-----+-+-+++++++---+--++----+-+--+-+----++--+----+++++--++-+---+-+--+----+--+-+---+-++----+---++-+-+-----+++-++-+++-----+-+-++ +--+++++--+-+-----+--++----++--+-----+-+-+++++++---+--++----+-+--+-+----++--+----+++++--++-+---+-+--+----+--+-+---+-+++---+---++-+-+-----+++-++-+++-----+-+-+ ++--+++++--+-+-----+--++----++--+-----+---+++++++---+--++----+-+--+-+----++--+-+--+++++--++-+---+-+--+----+--+-+---+-+++---+---++-+-+-----+++-++-+++-----+-+- +-+--+++++--+-+-----+--++----++--+-----+---+++++++---+--++----+-+--+-+----++--+++--+++++--++-+---+-+--+----+--+-+---+--++---+---++-+-+-----+++-++-+++-----+-+ ++-+--+++++--+-+-----+--++----++--+-----+---+++++++---+--++----+-+--+-+----++---++--+++++--++-+---+-+--+----+--+-+---++-++---+---++-+-+-----+++-++-+++-----+- +-+-+--+++++--+-+-----+--++----++--+-----+---+++++++---+--++----+-+--+-+----++-+-++--+++++--++-+---+-+--+----+--+-+----+-++---+---++-+-+-----+++-++-+++-----+ +--+-+--+++++--+-+-----+--++----++--+-----+---+++++++---+--++----+-+--+-+----++-+-++--+++++--++-+---+-+--+----+--+-+--+-+-++---+---++-+-+-----+++-++-+++----- +---+-+--+++++--+-+-----+--++----++--+--+--+---+++++++---+--++----+-+--+-+----+--+-++--+++++--++-+---+-+--+----+--+-+--+-+-++---+---++-+-+-----+++-++-+++---- +----+-+--+++++--+-+-----+--++----++--+-++--+---+++++++---+--++----+-+--+-+-------+-++--+++++--++-+---+-+--+----+--+-+--+-+-++---+---++-+-+-----+++-++-+++--- +-----+-+--+++++--+-+-----+--++----++--+-++--+---+++++++---+--++----+-+--+-+---+---+-++--+++++--++-+---+-+--+----+--+----+-+-++---+---++-+-+-----+++-++-+++-- ++-----+-+--+++++--+-+-----+--++----++----++--+---+++++++---+--++----+-+--+-+---+---+-++--+++++--++-+---+-+--+----+--+----+-+-++---+---++-+-+-----+++-++-+++- +-+-----+-+--+++++--+-+-----+--++----++----++--+---+++++++---+--++----+-+--+-+-+-+---+-++--+++++--++-+---+-+--+----+-------+-+-++---+---++-+-+-----+++-++-+++ +--+-----+-+--+++++--+-+-----+--++----++----++--+---+++++++---+--++----+-+--+-+-+-+---+-++--+++++--++-+---+-+--+----+-+-----+-+-++---+---++-+-+-----+++-++-++ ++--+-----+-+--+++++--+-+-----+--++----++----++--+---+++++++---+--++----+-+--+---+-+---+-++--+++++--++-+---+-+--+----+++-----+-+-++---+---++-+-+-----+++-++-+ +++--+-----+-+--+++++--+-+-----+--++-----+----++--+---+++++++---+--++----+-+--++--+-+---+-++--+++++--++-+---+-+--+----+++-----+-+-++---+---++-+-+-----+++-++- +-++--+-----+-+--+++++--+-+-----+--++---+-+----++--+---+++++++---+--++----+-+---+--+-+---+-++--+++++--++-+---+-+--+----+++-----+-+-++---+---++-+-+-----+++-++ +--++--+-----+-+--+++++--+-+-----+--++---+-+----++--+---+++++++---+--++----+-+---+--+-+---+-++--+++++--++-+---+-+--+--+-+++-----+-+-++---+---++-+-+-----+++-+ +---++--+-----+-+--+++++--+-+-----+--++---+-+----++--+---+++++++---+--++----+-+---+--+-+---+-++--+++++--++-+---+-+--+-++-+++-----+-+-++---+---++-+-+-----+++- +----++--+-----+-+--+++++--+-+-----+--+++--+-+----++--+---+++++++---+--++----+-----+--+-+---+-++--+++++--++-+---+-+--+-++-+++-----+-+-++---+---++-+-+-----+++ ++----++--+-----+-+--+++++--+-+-----+--+-+--+-+----++--+---+++++++---+--++----++----+--+-+---+-++--+++++--++-+---+-+--+-++-+++-----+-+-++---+---++-+-+-----++ +++----++--+-----+-+--+++++--+-+-----+--+-+--+-+----++--+---+++++++---+--++-----+----+--+-+---+-++--+++++--++-+---+-+-++-++-+++-----+-+-++---+---++-+-+-----+ +-++----++--+-----+-+--+++++--+-+-----+--+-+--+-+----++--+---+++++++---+--++-----+----+--+-+---+-++--+++++--++-+---+-++++-++-+++-----+-+-++---+---++-+-+----- +--++----++--+-----+-+--+++++--+-+-----+--+-+--+-+----++--+---+++++++---+--++--+--+----+--+-+---+-++--+++++--++-+---+--+++-++-+++-----+-+-++---+---++-+-+---- ++--++----++--+-----+-+--+++++--+-+--------+-+--+-+----++--+---+++++++---+--++--+--+----+--+-+---+-++--+++++--++-+---+--+++-++-+++-----+-+-++---+---++-+-+--- +-+--++----++--+-----+-+--+++++--+-+--------+-+--+-+----++--+---+++++++---+--+++-+--+----+--+-+---+-++--+++++--++-+------+++-++-+++-----+-+-++---+---++-+-+-- +--+--++----++--+-----+-+--+++++--+-+---+----+-+--+-+----++--+---+++++++---+--+-+-+--+----+--+-+---+-++--+++++--++-+------+++-++-+++-----+-+-++---+---++-+-+- +---+--++----++--+-----+-+--+++++--+-+--++----+-+--+-+----++--+---+++++++---+----+-+--+----+--+-+---+-++--+++++--++-+------+++-++-+++-----+-+-++---+---++-+-+ +----+--++----++--+-----+-+--+++++--+-+--++----+-+--+-+----++--+---+++++++---+----+-+--+----+--+-+---+-++--+++++--++-++-----+++-++-+++-----+-+-++---+---++-+- +-----+--++----++--+-----+-+--+++++--+-+--++----+-+--+-+----++--+---+++++++---++---+-+--+----+--+-+---+-++--+++++--++--+-----+++-++-+++-----+-+-++---+---++-+ ++-----+--++----++--+-----+-+--+++++--+-+--++----+-+--+-+----++--+---+++++++----+---+-+--+----+--+-+---+-++--+++++--+++-+-----+++-++-+++-----+-+-++---+---++- +-+-----+--++----++--+-----+-+--+++++--+-+--++----+-+--+-+----++--+---+++++++--+-+---+-+--+----+--+-+---+-++--+++++--+-+-+-----+++-++-+++-----+-+-++---+---++ ++-+-----+--++----++--+-----+-+--+++++----+--++----+-+--+-+----++--+---+++++++-++-+---+-+--+----+--+-+---+-++--+++++--+-+-+-----+++-++-+++-----+-+-++---+---+ +-+-+-----+--++----++--+-----+-+--+++++----+--++----+-+--+-+----++--+---+++++++-++-+---+-+--+----+--+-+---+-++--+++++-++-+-+-----+++-++-+++-----+-+-++---+--- +--+-+-----+--++----++--+-----+-+--++++++---+--++----+-+--+-+----++--+---++++++--++-+---+-+--+----+--+-+---+-++--+++++-++-+-+-----+++-++-+++-----+-+-++---+-- ++--+-+-----+--++----++--+-----+-+--++++++---+--++----+-+--+-+----++--+---++++++--++-+---+-+--+----+--+-+---+-++--++++--++-+-+-----+++-++-+++-----+-+-++---+- +++--+-+-----+--++----++--+-----+-+--++++++---+--++----+-+--+-+----++--+---++++++--++-+---+-+--+----+--+-+---+-++--+++---++-+-+-----+++-++-+++-----+-+-++---+ +----+++-++--++++-+-++-+-++++--++-+++---+++--+-+-----+--++----++--+-----+-+--++-+++--+-+-+++++---+--+---+++++-+-+--++++++--++-+---+-+--+----+--+-+---+-++--++ +-----+++-++--++++-+-++-+-++++--++-+++--++++--+-+-----+--++----++--+-----+-+--++-+++--+-+-+++++---+--+---+++++-+-+--++++++--++-+---+-+--+----+--+-+---+-++--+ +------+++-++--++++-+-++-+-++++--++-+++-+++++--+-+-----+--++----++--+-----+-+--++-+++--+-+-+++++---+--+---+++++-+-+--++++++--++-+---+-+--+----+--+-+---+-++-- +-------+++-++--++++-+-++-+-++++--++-+++-+++++--+-+-----+--++----++--+-----+-+-+++-+++--+-+-+++++---+--+---+++++-+-+---+++++--++-+---+-+--+----+--+-+---+-++- ++-------+++-++--++++-+-++-+-++++--++-++--+++++--+-+-----+--++----++--+-----+-+-+++-+++--+-+-+++++---+--+---+++++-+-+---+++++--++-+---+-+--+----+--+-+---+-++ +++-------+++-++--++++-+-++-+-++++--++-++--+++++--+-+-----+--++----++--+-----+---+++-+++--+-+-+++++---+--+---+++++-+-++--+++++--++-+---+-+--+----+--+-+---+-+ ++++-------+++-++--++++-+-++-+-++++--++--+--+++++--+-+-----+--++----++--+-----++--+++-+++--+-+-+++++---+--+---+++++-+-++--+++++--++-+---+-+--+----+--+-+---+- +-+++-------+++-++--++++-+-++-+-++++--+++-+--+++++--+-+-----+--++----++--+------+--+++-+++--+-+-+++++---+--+---+++++-+-++--+++++--++-+---+-+--+----+--+-+---+ ++-+++-------+++-++--++++-+-++-+-++++--+-+-+--+++++--+-+-----+--++----++--+----+-+--+++-+++--+-+-+++++---+--+---+++++-+-++--+++++--++-+---+-+--+----+--+-+--- +++-+++-------+++-++--++++-+-++-+-++++----+-+--+++++--+-+-----+--++----++--+----+-+--+++-+++--+-+-+++++---+--+---+++++-+-++--+++++--++-+---+-+--+----+--+-+-- +-++-+++-------+++-++--++++-+-++-+-++++----+-+--+++++--+-+-----+--++----++--+--+-+-+--+++-+++--+-+-+++++---+--+---++++--+-++--+++++--++-+---+-+--+----+--+-+- +--++-+++-------+++-++--++++-+-++-+-++++----+-+--+++++--+-+-----+--++----++--+-++-+-+--+++-+++--+-+-+++++---+--+---+++---+-++--+++++--++-+---+-+--+----+--+-+ ++--++-+++-------+++-++--++++-+-++-+-+++-----+-+--+++++--+-+-----+--++----++--++++-+-+--+++-+++--+-+-+++++---+--+---+++---+-++--+++++--++-+---+-+--+----+--+- +++--++-+++-------+++-++--++++-+-++-+-+++-----+-+--+++++--+-+-----+--++----++--++++-+-+--+++-+++--+-+-+++++---+--+---+-+---+-++--+++++--++-+---+-+--+----+--+ ++++--++-+++-------+++-++--++++-+-++-+-+-+-----+-+--+++++--+-+-----+--++----++-+++++-+-+--+++-+++--+-+-+++++---+--+---+-+---+-++--+++++--++-+---+-+--+----+-- +++++--++-+++-------+++-++--++++-+-++-+---+-----+-+--+++++--+-+-----+--++----++-+++++-+-+--+++-+++--+-+-+++++---+--+---+-+---+-++--+++++--++-+---+-+--+----+- +-++++--++-+++-------+++-++--++++-+-++-++--+-----+-+--+++++--+-+-----+--++----+--+++++-+-+--+++-+++--+-+-+++++---+--+---+-+---+-++--+++++--++-+---+-+--+----+ ++-++++--++-+++-------+++-++--++++-+-++-++--+-----+-+--+++++--+-+-----+--++-------+++++-+-+--+++-+++--+-+-+++++---+--++--+-+---+-++--+++++--++-+---+-+--+---- +-+-++++--++-+++-------+++-++--++++-+-++-++--+-----+-+--+++++--+-+-----+--++---+---+++++-+-+--+++-+++--+-+-+++++---+---+--+-+---+-++--+++++--++-+---+-+--+--- ++-+-++++--++-+++-------+++-++--++++-+-+--++--+-----+-+--+++++--+-+-----+--++---+---+++++-+-+--+++-+++--+-+-+++++---+---+--+-+---+-++--+++++--++-+---+-+--+-- +++-+-++++--++-+++-------+++-++--++++-+----++--+-----+-+--+++++--+-+-----+--++---+---+++++-+-+--+++-+++--+-+-+++++---+---+--+-+---+-++--+++++--++-+---+-+--+- +-++-+-++++--++-+++-------+++-++--++++-+----++--+-----+-+--+++++--+-+-----+--+++--+---+++++-+-+--+++-+++--+-+-+++++-------+--+-+---+-++--+++++--++-+---+-+--+ ++-++-+-++++--++-+++-------+++-++--++++-+----++--+-----+-+--+++++--+-+-----+--+-+--+---+++++-+-+--+++-+++--+-+-+++++--+----+--+-+---+-++--+++++--++-+---+-+-- +-+-++-+-++++--++-+++-------+++-++--++++++----++--+-----+-+--+++++--+-+-----+----+--+---+++++-+-+--+++-+++--+-+-+++++--+----+--+-+---+-++--+++++--++-+---+-+- ++-+-++-+-++++--++-+++-------+++-++--+++-++----++--+-----+-+--+++++--+-+-----+----+--+---+++++-+-+--+++-+++--+-+-+++++--+----+--+-+---+-++--+++++--++-+---+-+ +++-+-++-+-++++--++-+++-------+++-++--++--++----++--+-----+-+--+++++--+-+-----++---+--+---+++++-+-+--+++-+++--+-+-+++++--+----+--+-+---+-++--+++++--++-+---+- ++++-+-++-+-++++--++-+++-------+++-++--++--++----++--+-----+-+--+++++--+-+-----++---+--+---+++++-+-+--+++-+++--+-+-+++-+--+----+--+-+---+-++--+++++--++-+---+ +++++-+-++-+-++++--++-+++-------+++-++---+--++----++--+-----+-+--+++++--+-+----+++---+--+---+++++-+-+--+++-+++--+-+-+++-+--+----+--+-+---+-++--+++++--++-+--- +-++++-+-++-+-++++--++-+++-------+++-++---+--++----++--+-----+-+--+++++--+-+---++++---+--+---+++++-+-+--+++-+++--+-+-+-+-+--+----+--+-+---+-++--+++++--++-+-- +--++++-+-++-+-++++--++-+++-------+++-++---+--++----++--+-----+-+--+++++--+-+--+++++---+--+---+++++-+-+--+++-+++--+-+---+-+--+----+--+-+---+-++--+++++--++-+- ++--++++-+-++-+-++++--++-+++-------+++-+----+--++----++--+-----+-+--+++++--+-+--+++++---+--+---+++++-+-+--+++-+++--+-+---+-+--+----+--+-+---+-++--+++++--++-+ +++--++++-+-++-+-++++--++-+++-------+++------+--++----++--+-----+-+--+++++--+-++-+++++---+--+---+++++-+-+--+++-+++--+-+---+-+--+----+--+-+---+-++--+++++--++- +-++--++++-+-++-+-++++--++-+++-------++++-----+--++----++--+-----+-+--+++++--+--+-+++++---+--+---+++++-+-+--+++-+++--+-+---+-+--+----+--+-+---+-++--+++++--++ ++-++--++++-+-++-+-++++--++-+++-------++-+-----+--++----++--+-----+-+--+++++--++-+-+++++---+--+---+++++-+-+--+++-+++--+-+---+-+--+----+--+-+---+-++--+++++--+ +++-++--++++-+-++-+-++++--++-+++-------++-+-----+--++----++--+-----+-+--+++++---+-+-+++++---+--+---+++++-+-+--+++-+++-++-+---+-+--+----+--+-+---+-++--+++++-- ++++-++--++++-+-++-+-++++--++-+++--------+-+-----+--++----++--+-----+-+--+++++---+-+-+++++---+--+---+++++-+-+--+++-+++-++-+---+-+--+----+--+-+---+-++--+++++- +-+++-++--++++-+-++-+-++++--++-+++--------+-+-----+--++----++--+-----+-+--++++++--+-+-+++++---+--+---+++++-+-+--+++-++--++-+---+-+--+----+--+-+---+-++--+++++ +--+++-++--++++-+-++-+-++++--++-+++-----+--+-+-----+--++----++--+-----+-+--++++++--+-+-+++++---+--+---+++++-+-+--+++-++--++-+---+-+--+----+--+-+---+-++--++++ +---+++-++--++++-+-++-+-++++--++-+++----++--+-+-----+--++----++--+-----+-+--++++++--+-+-+++++---+--+---+++++-+-+--+++-++--++-+---+-+--+----+--+-+---+-++--+++ +---++--+-+++-+-++-++++-++-+-+++-+--++--+---++-+-+-----+++-++-+++-----+-+-++---+++--+-+-----+--++----++--+-----+-+--++----+++-++--++++-+-++-+-++++--++-+++--- +----++--+-+++-+-++-++++-++-+-+++-+--++--+---++-+-+-----+++-++-+++-----+-+-++--++++--+-+-----+--++----++--+-----+-+--+-----+++-++--++++-+-++-+-++++--++-+++-- +-----++--+-+++-+-++-++++-++-+-+++-+--++--+---++-+-+-----+++-++-+++-----+-+-++-+++++--+-+-----+--++----++--+-----+-+--------+++-++--++++-+-++-+-++++--++-+++- ++-----++--+-+++-+-++-++++-++-+-+++-+--+---+---++-+-+-----+++-++-+++-----+-+-++-+++++--+-+-----+--++----++--+-----+-+--------+++-++--++++-+-++-+-++++--++-+++ +++-----++--+-+++-+-++-++++-++-+-+++-+--+---+---++-+-+-----+++-++-+++-----+-+-+--+++++--+-+-----+--++----++--+-----+-++-------+++-++--++++-+-++-+-++++--++-++ +-++-----++--+-+++-+-++-++++-++-+-+++-+-++---+---++-+-+-----+++-++-+++-----+-+-+--+++++--+-+-----+--++----++--+-----+-++-------+++-++--++++-+-++-+-++++--++-+ +--++-----++--+-+++-+-++-++++-++-+-+++-+-++---+---++-+-+-----+++-++-+++-----+-+-+--+++++--+-+-----+--++----++--+-----++++-------+++-++--++++-+-++-+-++++--++- ++--++-----++--+-+++-+-++-++++-++-+-+++-+-++---+---++-+-+-----+++-++-+++-----+-+-+--+++++--+-+-----+--++----++--+------+++-------+++-++--++++-+-++-+-++++--++ +-+--++-----++--+-+++-+-++-++++-++-+-+++-+-++---+---++-+-+-----+++-++-+++-----+-+-+--+++++--+-+-----+--++----++--+----+-+++-------+++-++--++++-+-++-+-++++--+ ++-+--++-----++--+-+++-+-++-++++-++-+-+++-+-++---+---++-+-+-----+++-++-+++-------+-+--+++++--+-+-----+--++----++--+---++-+++-------+++-++--++++-+-++-+-++++-- +++-+--++-----++--+-+++-+-++-++++-++-+-+-+-+-++---+---++-+-+-----+++-++-+++-------+-+--+++++--+-+-----+--++----++--+---++-+++-------+++-++--++++-+-++-+-++++- ++++-+--++-----++--+-+++-+-++-++++-++-+---+-+-++---+---++-+-+-----+++-++-+++-------+-+--+++++--+-+-----+--++----++--+---++-+++-------+++-++--++++-+-++-+-++++ +-+++-+--++-----++--+-+++-+-++-++++-++-+---+-+-++---+---++-+-+-----+++-++-+++-------+-+--+++++--+-+-----+--++----++--++--++-+++-------+++-++--++++-+-++-+-+++ ++-+++-+--++-----++--+-+++-+-++-++++-++-----+-+-++---+---++-+-+-----+++-++-+++-+-----+-+--+++++--+-+-----+--++----++--++--++-+++-------+++-++--++++-+-++-+-++ +-+-+++-+--++-----++--+-+++-+-++-++++-++-----+-+-++---+---++-+-+-----+++-++-+++-+-----+-+--+++++--+-+-----+--++----++-+++--++-+++-------+++-++--++++-+-++-+-+ ++-+-+++-+--++-----++--+-+++-+-++-++++-++-----+-+-++---+---++-+-+-----+++-++-++--+-----+-+--+++++--+-+-----+--++----++++++--++-+++-------+++-++--++++-+-++-+- +++-+-+++-+--++-----++--+-+++-+-++-++++-++-----+-+-++---+---++-+-+-----+++-++-++--+-----+-+--+++++--+-+-----+--++----+-++++--++-+++-------+++-++--++++-+-++-+ +-++-+-+++-+--++-----++--+-+++-+-++-+++++++-----+-+-++---+---++-+-+-----+++-++-++--+-----+-+--+++++--+-+-----+--++----+-++++--++-+++-------+++-++--++++-+-++- ++-++-+-+++-+--++-----++--+-+++-+-++-+++-+++-----+-+-++---+---++-+-+-----+++-++-++--+-----+-+--+++++--+-+-----+--++----+-++++--++-+++-------+++-++--++++-+-++ +++-++-+-+++-+--++-----++--+-+++-+-++-+++-+++-----+-+-++---+---++-+-+-----+++-+--++--+-----+-+--+++++--+-+-----+--++--+-+-++++--++-+++-------+++-++--++++-+-+ ++++-++-+-+++-+--++-----++--+-+++-+-++-+++-+++-----+-+-++---+---++-+-+-----+++----++--+-----+-+--+++++--+-+-----+--++-++-+-++++--++-+++-------+++-++--++++-+- +++++-++-+-+++-+--++-----++--+-+++-+-++--++-+++-----+-+-++---+---++-+-+-----+++----++--+-----+-+--+++++--+-+-----+--++-++-+-++++--++-+++-------+++-++--++++-+ +-++++-++-+-+++-+--++-----++--+-+++-+-+++-++-+++-----+-+-++---+---++-+-+-----+++----++--+-----+-+--+++++--+-+-----+--++-++-+-++++--++-+++-------+++-++--++++- ++-++++-++-+-+++-+--++-----++--+-+++-+-+++-++-+++-----+-+-++---+---++-+-+-----+++----++--+-----+-+--+++++--+-+-----+---+-++-+-++++--++-+++-------+++-++--++++ +++-++++-++-+-+++-+--++-----++--+-+++-+-+++-++-+++-----+-+-++---+---++-+-+------++----++--+-----+-+--+++++--+-+-----+-+-+-++-+-++++--++-+++-------+++-++--+++ +-++-++++-++-+-+++-+--++-----++--+-+++-+-+++-++-+++-----+-+-++---+---++-+-+------++----++--+-----+-+--+++++--+-+-----+++-+-++-+-++++--++-+++-------+++-++--++ ++-++-++++-++-+-+++-+--++-----++--+-+++---+++-++-+++-----+-+-++---+---++-+-+---+--++----++--+-----+-+--+++++--+-+-----+++-+-++-+-++++--++-+++-------+++-++--+ +-+-++-++++-++-+-+++-+--++-----++--+-+++---+++-++-+++-----+-+-++---+---++-+-+---+--++----++--+-----+-+--+++++--+-+----++++-+-++-+-++++--++-+++-------+++-++-- ++-+-++-++++-++-+-+++-+--++-----++--+-++----+++-++-+++-----+-+-++---+---++-+-+---+--++----++--+-----+-+--+++++--+-+----++++-+-++-+-++++--++-+++-------+++-++- +++-+-++-++++-++-+-+++-+--++-----++--+-+-----+++-++-+++-----+-+-++---+---++-+-+---+--++----++--+-----+-+--+++++--+-+----++++-+-++-+-++++--++-+++-------+++-++ ++++-+-++-++++-++-+-+++-+--++-----++--+-+-----+++-++-+++-----+-+-++---+---++-+-----+--++----++--+-----+-+--+++++--+-+-+--++++-+-++-+-++++--++-+++-------+++-+ +-+++-+-++-++++-++-+-+++-+--++-----++--+-+-----+++-++-+++-----+-+-++---+---++-+-----+--++----++--+-----+-+--+++++--+-+++--++++-+-++-+-++++--++-+++-------+++- ++-+++-+-++-++++-++-+-+++-+--++-----++--+-+-----+++-++-+++-----+-+-++---+---++-+-----+--++----++--+-----+-+--+++++--+--++--++++-+-++-+-++++--++-+++-------+++ +-+-+++-+-++-++++-++-+-+++-+--++-----++--+-+-----+++-++-+++-----+-+-++---+---++-+-----+--++----++--+-----+-+--+++++--++-++--++++-+-++-+-++++--++-+++-------++ +--+-+++-+-++-++++-++-+-+++-+--++-----+++-+-+-----+++-++-+++-----+-+-++---+---++-+-----+--++----++--+-----+-+--+++++--++-++--++++-+-++-+-++++--++-+++-------+ ++--+-+++-+-++-++++-++-+-+++-+--++-----+++-+-+-----+++-++-+++-----+-+-++---+----+-+-----+--++----++--+-----+-+--+++++-+++-++--++++-+-++-+-++++--++-+++------- +++--+-+++-+-++-++++-++-+-+++-+--++------++-+-+-----+++-++-+++-----+-+-++---+----+-+-----+--++----++--+-----+-+--+++++-+++-++--++++-+-++-+-++++--++-+++------ +-++--+-+++-+-++-++++-++-+-+++-+--++------++-+-+-----+++-++-+++-----+-+-++---+-+--+-+-----+--++----++--+-----+-+--++++--+++-++--++++-+-++-+-++++--++-+++----- +--++--+-+++-+-++-++++-++-+-+++-+--++------++-+-+-----+++-++-+++-----+-+-++---+++--+-+-----+--++----++--+-----+-+--+++---+++-++--++++-+-++-+-++++--++-+++---- +-+++--+-+-+++++---+--+---+++++-+-+--+++---++--+-+++-+-++-++++-++-+-+++-+--++--++++---+--++----+-+--+-+----++--+---++++++--+-+-----+--++----++--+-----+-+--++ ++-+++--+-+-+++++---+--+---+++++-+-+--++----++--+-+++-+-++-++++-++-+-+++-+--++-+++++---+--++----+-+--+-+----++--+---++++++--+-+-----+--++----++--+-----+-+--+ +++-+++--+-+-+++++---+--+---+++++-+-+--+-----++--+-+++-+-++-++++-++-+-+++-+--++++++++---+--++----+-+--+-+----++--+---++++++--+-+-----+--++----++--+-----+-+-- ++++-+++--+-+-+++++---+--+---+++++-+-+--+-----++--+-+++-+-++-++++-++-+-+++-+--++++++++---+--++----+-+--+-+----++--+----+++++--+-+-----+--++----++--+-----+-+- +-+++-+++--+-+-+++++---+--+---+++++-+-+-++-----++--+-+++-+-++-++++-++-+-+++-+---+++++++---+--++----+-+--+-+----++--+----+++++--+-+-----+--++----++--+-----+-+ +--+++-+++--+-+-+++++---+--+---+++++-+-+-++-----++--+-+++-+-++-++++-++-+-+++-+---+++++++---+--++----+-+--+-+----++--+-+--+++++--+-+-----+--++----++--+-----+- ++--+++-+++--+-+-+++++---+--+---+++++-+---++-----++--+-+++-+-++-++++-++-+-+++-+---+++++++---+--++----+-+--+-+----++--+-+--+++++--+-+-----+--++----++--+-----+ +-+--+++-+++--+-+-+++++---+--+---+++++-++--++-----++--+-+++-+-++-++++-++-+-+++-+---+++++++---+--++----+-+--+-+----++--+-+--+++++--+-+-----+--++----++--+----- ++-+--+++-+++--+-+-+++++---+--+---+++++--+--++-----++--+-+++-+-++-++++-++-+-+++-+---+++++++---+--++----+-+--+-+----++--+-+--+++++--+-+-----+--++----++--+---- +-+-+--+++-+++--+-+-+++++---+--+---++++++-+--++-----++--+-+++-+-++-++++-++-+-++--+---+++++++---+--++----+-+--+-+----++--+-+--+++++--+-+-----+--++----++--+--- ++-+-+--+++-+++--+-+-+++++---+--+---++++++-+--++-----++--+-+++-+-++-++++-++-+-++--+---+++++++---+--++----+-+--+-+----+---+-+--+++++--+-+-----+--++----++--+-- +++-+-+--+++-+++--+-+-+++++---+--+---++++++-+--++-----++--+-+++-+-++-++++-++-+-++--+---+++++++---+--++----+-+--+-+--------+-+--+++++--+-+-----+--++----++--+- ++++-+-+--+++-+++--+-+-+++++---+--+---++-+++-+--++-----++--+-+++-+-++-++++-++-+-++--+---+++++++---+--++----+-+--+-+--------+-+--+++++--+-+-----+--++----++--+ +++++-+-+--+++-+++--+-+-+++++---+--+---++-+++-+--++-----++--+-+++-+-++-++++-++---++--+---+++++++---+--++----+-+--+-+--+-----+-+--+++++--+-+-----+--++----++-- ++++++-+-+--+++-+++--+-+-+++++---+--+----+-+++-+--++-----++--+-+++-+-++-++++-++---++--+---+++++++---+--++----+-+--+-+--+-----+-+--+++++--+-+-----+--++----++- +-+++++-+-+--+++-+++--+-+-+++++---+--+--+-+-+++-+--++-----++--+-+++-+-++-++++-+----++--+---+++++++---+--++----+-+--+-+--+-----+-+--+++++--+-+-----+--++----++ +--+++++-+-+--+++-+++--+-+-+++++---+--+-++-+-+++-+--++-----++--+-+++-+-++-++++-+----++--+---+++++++---+--++----+-+--+-+--+-----+-+--+++++--+-+-----+--++----+ +---+++++-+-+--+++-+++--+-+-+++++---+--+-++-+-+++-+--++-----++--+-+++-+-++-++++-+----++--+---+++++++---+--++----+-+--+++--+-----+-+--+++++--+-+-----+--++---- ++---+++++-+-+--+++-+++--+-+-+++++---+--+-++-+-+++-+--++-----++--+-+++-+-++-++++-+----++--+---+++++++---+--++----+-+---++--+-----+-+--+++++--+-+-----+--++--- +-+---+++++-+-+--+++-+++--+-+-+++++---+-++-++-+-+++-+--++-----++--+-+++-+-++-++-+-+----++--+---+++++++---+--++----+-+---++--+-----+-+--+++++--+-+-----+--++-- +--+---+++++-+-+--+++-+++--+-+-+++++---++++-++-+-+++-+--++-----++--+-+++-+-++-+--+-+----++--+---+++++++---+--++----+-+---++--+-----+-+--+++++--+-+-----+--++- ++--+---+++++-+-+--+++-+++--+-+-+++++---++++-++-+-+++-+--++-----++--+-+++-+-++-+--+-+----++--+---+++++++---+--++----+-----++--+-----+-+--+++++--+-+-----+--++ +-+--+---+++++-+-+--+++-+++--+-+-+++++---++++-++-+-+++-+--++-----++--+-+++-+-++-+--+-+----++--+---+++++++---+--++----++----++--+-----+-+--+++++--+-+-----+--+ +--+--+---+++++-+-+--+++-+++--+-+-+++++-+-++++-++-+-+++-+--++-----++--+-+++-+-++-+--+-+----++--+---+++++++---+--++----++----++--+-----+-+--+++++--+-+-----+-- +---+--+---+++++-+-+--+++-+++--+-+-+++++++-++++-++-+-+++-+--++-----++--+-+++-+--+-+--+-+----++--+---+++++++---+--++----++----++--+-----+-+--+++++--+-+-----+- ++---+--+---+++++-+-+--+++-+++--+-+-++++-++-++++-++-+-+++-+--++-----++--+-+++-+--+-+--+-+----++--+---+++++++---+--++----++----++--+-----+-+--+++++--+-+-----+ +++---+--+---+++++-+-+--+++-+++--+-+-++++-++-++++-++-+-+++-+--++-----++--+-+++----+-+--+-+----++--+---+++++++---+--++-+--++----++--+-----+-+--+++++--+-+----- ++++---+--+---+++++-+-+--+++-+++--+-+-++-+-++-++++-++-+-+++-+--++-----++--+-+++----+-+--+-+----++--+---+++++++---+--++-+--++----++--+-----+-+--+++++--+-+---- +++++---+--+---+++++-+-+--+++-+++--+-+-++-+-++-++++-++-+-+++-+--++-----++--+-+++----+-+--+-+----++--+---+++++++---+--+--+--++----++--+-----+-+--+++++--+-+--- ++++++---+--+---+++++-+-+--+++-+++--+-+-++-+-++-++++-++-+-+++-+--++-----++--+-+++----+-+--+-+----++--+---+++++++---+-----+--++----++--+-----+-+--+++++--+-+-- +-+++++---+--+---+++++-+-+--+++-+++--+-++++-+-++-++++-++-+-+++-+--++-----++--+--++----+-+--+-+----++--+---+++++++---+-----+--++----++--+-----+-+--+++++--+-+- ++-+++++---+--+---+++++-+-+--+++-+++--+--+++-+-++-++++-++-+-+++-+--++-----++--+--++----+-+--+-+----++--+---+++++++---+-----+--++----++--+-----+-+--+++++--+-+ +-+-+++++---+--+---+++++-+-+--+++-+++--++-+++-+-++-++++-++-+-+++-+--++-----++--+--++----+-+--+-+----++--+---+++++++---+-----+--++----++--+-----+-+--+++++--+- ++-+-+++++---+--+---+++++-+-+--+++-+++---+-+++-+-++-++++-++-+-+++-+--++-----++--+--++----+-+--+-+----++--+---+++++++---+-----+--++----++--+-----+-+--+++++--+ +-+-+-+++++---+--+---+++++-+-+--+++-+++---+-+++-+-++-++++-++-+-+++-+--++-----++--+--++----+-+--+-+----++--+---+++++++-+-+-----+--++----++--+-----+-+--+++++-- +--+-+-+++++---+--+---+++++-+-+--+++-++++--+-+++-+-++-++++-++-+-+++-+--++-----+---+--++----+-+--+-+----++--+---+++++++-+-+-----+--++----++--+-----+-+--+++++- ++--+-+-+++++---+--+---+++++-+-+--+++-++++--+-+++-+-++-++++-++-+-+++-+--++-----+---+--++----+-+--+-+----++--+---++++++--+-+-----+--++----++--+-----+-+--+++++ +++--+-+-+++++---+--+---+++++-+-+--+++-+-++--+-+++-+-++-++++-++-+-+++-+--++----++---+--++----+-+--+-+----++--+---++++++--+-+-----+--++----++--+-----+-+--++++ ++++--+-+-+++++---+--+---+++++-+-+--+++---++--+-+++-+-++-++++-++-+-+++-+--++---+++---+--++----+-+--+-+----++--+---++++++--+-+-----+--++----++--+-----+-+--+++ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_172.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_172.txt new file mode 100644 index 000000000..bb2447efe --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_172.txt @@ -0,0 +1,172 @@ ++---++--++++-+-+++-++--++-+++-+-++++--++---++-++++++----+-+--++-++-++--+-+----++++++-++++-+-++--+-+-++++-+----+-++++-+-+--++-+-++++---++++-+--+--++--------++--+--+-++++---+ +-+---++--++++-+-+++-++--++-+++-+-++++--++--+++-++++++----+-+--++-++-++--+-+----++++++-++++-+-++--+-+-++++-+----+-++++-+-+--++-+-++++---++++-+--+--++--------++--+--+-++++--- +--+---++--++++-+-+++-++--++-+++-+-++++--++--+++-++++++----+-+--++-++-++--+-+----+++++++++++-+-++--+-+-++++-+----+-++++-+-+--++-+--+++---++++-+--+--++--------++--+--+-++++-- +---+---++--++++-+-+++-++--++-+++-+-++++--+++-+++-++++++----+-+--++-++-++--+-+----+++++-+++++-+-++--+-+-++++-+----+-++++-+-+--++-+--+++---++++-+--+--++--------++--+--+-++++- ++---+---++--++++-+-+++-++--++-+++-+-++++--+++-+++-++++++----+-+--++-++-++--+-+----+++++-+++++-+-++--+-+-++++-+----+-++++-+-+--++----+++---++++-+--+--++--------++--+--+-++++ +++---+---++--++++-+-+++-++--++-+++-+-++++--+++-+++-++++++----+-+--++-++-++--+-+----+++-+-+++++-+-++--+-+-++++-+----+-++++-+-+--+++---+++---++++-+--+--++--------++--+--+-+++ +-++---+---++--++++-+-+++-++--++-+++-+-++++-++++-+++-++++++----+-+--++-++-++--+-+----+++-+-+++++-+-++--+-+-++++-+----+-++++-+-+--+++---+++---++++-+--+--++--------++--+--+-++ +--++---+---++--++++-+-+++-++--++-+++-+-+++++++++-+++-++++++----+-+--++-++-++--+-+----+++-+-+++++-+-++--+-+-++++-+----+-++++-+-+--+++---+++---++++-+--+--++--------++--+--+-+ ++--++---+---++--++++-+-+++-++--++-+++-+-+++++++++-+++-++++++----+-+--++-++-++--+-+-----++-+-+++++-+-++--+-+-++++-+----+-++++-+-+-++++---+++---++++-+--+--++--------++--+--+- +++--++---+---++--++++-+-+++-++--++-+++-+-++-++++++-+++-++++++----+-+--++-++-++--+-+-----++-+-+++++-+-++--+-+-++++-+----+-++++-+-+-++++---+++---++++-+--+--++--------++--+--+ ++++--++---+---++--++++-+-+++-++--++-+++-+-+--++++++-+++-++++++----+-+--++-++-++--+-+--+--++-+-+++++-+-++--+-+-++++-+----+-++++-+-+-++++---+++---++++-+--+--++--------++--+-- +++++--++---+---++--++++-+-+++-++--++-+++-+----++++++-+++-++++++----+-+--++-++-++--+-+--+--++-+-+++++-+-++--+-+-++++-+----+-++++-+-+-++++---+++---++++-+--+--++--------++--+- +-++++--++---+---++--++++-+-+++-++--++-+++-+----++++++-+++-++++++----+-+--++-++-++--+-++-+--++-+-+++++-+-++--+-+-++++-+----+-++++---+-++++---+++---++++-+--+--++--------++--+ ++-++++--++---+---++--++++-+-+++-++--++-+++-+----++++++-+++-++++++----+-+--++-++-++--+--+-+--++-+-+++++-+-++--+-+-++++-+----+-+++++--+-++++---+++---++++-+--+--++--------++-- +-+-++++--++---+---++--++++-+-+++-++--++-+++-+----++++++-+++-++++++----+-+--++-++-++--++-+-+--++-+-+++++-+-++--+-+-++++-+----+-+++-+--+-++++---+++---++++-+--+--++--------++- ++-+-++++--++---+---++--++++-+-+++-++--++-+++-+----++++++-+++-++++++----+-+--++-++-++--++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++--+--+-++++---+++---++++-+--+--++--------++ +++-+-++++--++---+---++--++++-+-+++-++--++-+-+-+----++++++-+++-++++++----+-+--++-++-++-+++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++--+--+-++++---+++---++++-+--+--++--------+ ++++-+-++++--++---+---++--++++-+-+++-++--++---+-+----++++++-+++-++++++----+-+--++-++-++++++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++--+--+-++++---+++---++++-+--+--++-------- +-+++-+-++++--++---+---++--++++-+-+++-++--+++--+-+----++++++-+++-++++++----+-+--++-++-+-++++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++--+--+-++++---+++---++++-+--+--++------- ++-+++-+-++++--++---+---++--++++-+-+++-++--+++--+-+----++++++-+++-++++++----+-+--++-++-+-++++-+-+--++-+-+++++-+-++--+-+-++++-+------++--+--+-++++---+++---++++-+--+--++------ +++-+++-+-++++--++---+---++--++++-+-+++-++---++--+-+----++++++-+++-++++++----+-+--++-++-+-++++-+-+--++-+-+++++-+-++--+-+-++++-+------++--+--+-++++---+++---++++-+--+--++----- +-++-+++-+-++++--++---+---++--++++-+-+++-++-+-++--+-+----++++++-+++-++++++----+-+--++-+--+-++++-+-+--++-+-+++++-+-++--+-+-++++-+------++--+--+-++++---+++---++++-+--+--++---- +--++-+++-+-++++--++---+---++--++++-+-+++-++++-++--+-+----++++++-+++-++++++----+-+--++----+-++++-+-+--++-+-+++++-+-++--+-+-++++-+------++--+--+-++++---+++---++++-+--+--++--- ++--++-+++-+-++++--++---+---++--++++-+-+++-+-++-++--+-+----++++++-+++-++++++----+-+--++----+-++++-+-+--++-+-+++++-+-++--+-+-++++-+------++--+--+-++++---+++---++++-+--+--++-- +++--++-+++-+-++++--++---+---++--++++-+-+++-+-++-++--+-+----++++++-+++-++++++----+-+--++----+-++++-+-+--++-+-+++++-+-++--+-+-++++--------++--+--+-++++---+++---++++-+--+--++- +-++--++-+++-+-++++--++---+---++--++++-+-+++++-++-++--+-+----++++++-+++-++++++----+-+---+----+-++++-+-+--++-+-+++++-+-++--+-+-++++--------++--+--+-++++---+++---++++-+--+--++ ++-++--++-+++-+-++++--++---+---++--++++-+-++-++-++-++--+-+----++++++-+++-++++++----+-+-+-+----+-++++-+-+--++-+-+++++-+-++--+-+-++++--------++--+--+-++++---+++---++++-+--+--+ +++-++--++-+++-+-++++--++---+---++--++++-+-+--++-++-++--+-+----++++++-+++-++++++----+-+++-+----+-++++-+-+--++-+-+++++-+-++--+-+-++++--------++--+--+-++++---+++---++++-+--+-- ++++-++--++-+++-+-++++--++---+---++--++++-+-+--++-++-++--+-+----++++++-+++-++++++----+-+++-+----+-++++-+-+--++-+-+++++-+-++--+-+-+-++--------++--+--+-++++---+++---++++-+--+- +-+++-++--++-+++-+-++++--++---+---++--++++-+-+--++-++-++--+-+----++++++-+++-++++++----+++++-+----+-++++-+-+--++-+-+++++-+-++--+-+---++--------++--+--+-++++---+++---++++-+--+ ++-+++-++--++-+++-+-++++--++---+---++--++++-+-+--++-++-++--+-+----++++++-+++-++++++-----++++-+----+-++++-+-+--++-+-+++++-+-++--+-++--++--------++--+--+-++++---+++---++++-+-- +-+-+++-++--++-+++-+-++++--++---+---++--++++-+-+--++-++-++--+-+----++++++-+++-++++++---+-++++-+----+-++++-+-+--++-+-+++++-+-++--+--+--++--------++--+--+-++++---+++---++++-+- ++-+-+++-++--++-+++-+-++++--++---+---++--+++--+-+--++-++-++--+-+----++++++-+++-++++++---+-++++-+----+-++++-+-+--++-+-+++++-+-++--+--+--++--------++--+--+-++++---+++---++++-+ +++-+-+++-++--++-+++-+-++++--++---+---++--++---+-+--++-++-++--+-+----++++++-+++-++++++-+-+-++++-+----+-++++-+-+--++-+-+++++-+-++--+--+--++--------++--+--+-++++---+++---++++- ++++-+-+++-++--++-+++-+-++++--++---+---++--+----+-+--++-++-++--+-+----++++++-+++-++++++-+-+-++++-+----+-++++-+-+--++-+-+++++-+-++--+--+--++--------++--+--+-++++---+++---++++ +++++-+-+++-++--++-+++-+-++++--++---+---++--+----+-+--++-++-++--+-+----++++++-+++-+++++--+-+-++++-+----+-++++-+-+--++-+-+++++-+-+++-+--+--++--------++--+--+-++++---+++---+++ +-++++-+-+++-++--++-+++-+-++++--++---+---++-++----+-+--++-++-++--+-+----++++++-+++-+++++--+-+-++++-+----+-++++-+-+--++-+-+++++-+-+++-+--+--++--------++--+--+-++++---+++---++ +--++++-+-+++-++--++-+++-+-++++--++---+---+++++----+-+--++-++-++--+-+----++++++-+++-+++++--+-+-++++-+----+-++++-+-+--++-+-+++++-+-+++-+--+--++--------++--+--+-++++---+++---+ ++--++++-+-+++-++--++-+++-+-++++--++---+---+++++----+-+--++-++-++--+-+----++++++-+++-++-++--+-+-++++-+----+-++++-+-+--++-+-+++++-+++++-+--+--++--------++--+--+-++++---+++--- +++--++++-+-+++-++--++-+++-+-++++--++---+---+++++----+-+--++-++-++--+-+----++++++-+++-++-++--+-+-++++-+----+-++++-+-+--++-+-+++++--++++-+--+--++--------++--+--+-++++---+++-- +-++--++++-+-+++-++--++-+++-+-++++--++---+--++++++----+-+--++-++-++--+-+----++++++-+++--+-++--+-+-++++-+----+-++++-+-+--++-+-+++++--++++-+--+--++--------++--+--+-++++---+++- +--++--++++-+-+++-++--++-+++-+-++++--++---+--++++++----+-+--++-++-++--+-+----++++++-++++-+-++--+-+-++++-+----+-++++-+-+--++-+-++++---++++-+--+--++--------++--+--+-++++---+++ +---++--++++-+-+++-++--++-+++-+-++++--++---++-++++++----+-+--++-++-++--+-+----++++++-++++-+-++--+-+-++++-+----+-++++-+-+--++-+-++++---++++-+--+--++--------++--+--+-++++---++ +--+------++++-+-++--+--+--++-+-++++------+-+---++--++++-+-+++-++--++-+++-+-++++--++-----+++----+-++-++--++++++++--++-++-+----+++-+++-+-++--+-+-++++-+----+-++++-+-+--++-+-++ +---+------++++-+-++--+--+--++-+-++++------+-+---++--++++-+-+++-++--++-+++-+-++++--++-----+++----+-++-++--++++++++--++-++-+----+++++++-+-++--+-+-++++-+----+-++++-+-+--++-+-+ ++---+------++++-+-++--+--+--++-+-++++--------+---++--++++-+-+++-++--++-+++-+-++++--++-+---+++----+-++-++--++++++++--++-++-+----+++++++-+-++--+-+-++++-+----+-++++-+-+--++-+- +-+---+------++++-+-++--+--+--++-+-++++--------+---++--++++-+-+++-++--++-+++-+-++++--++++---+++----+-++-++--++++++++--++-++-+----+-+++++-+-++--+-+-++++-+----+-++++-+-+--++-+ +--+---+------++++-+-++--+--+--++-+-++++----+---+---++--++++-+-+++-++--++-+++-+-++++--++++---+++----+-++-++--++++++++--++-++-+----+-+++++-+-++--+-+-++++-+----+-++++-+-+--++- +---+---+------++++-+-++--+--+--++-+-++++---++---+---++--++++-+-+++-++--++-+++-+-++++---+++---+++----+-++-++--++++++++--++-++-+----+-+++++-+-++--+-+-++++-+----+-++++-+-+--++ +----+---+------++++-+-++--+--+--++-+-++++---++---+---++--++++-+-+++-++--++-+++-+-++++---+++---+++----+-++-++--++++++++--++-++-+--+-+-+++++-+-++--+-+-++++-+----+-++++-+-+--+ +-----+---+------++++-+-++--+--+--++-+-++++---++---+---++--++++-+-+++-++--++-+++-+-++++---+++---+++----+-++-++--++++++++--++-++-+-++-+-+++++-+-++--+-+-++++-+----+-++++-+-+-- +------+---+------++++-+-++--+--+--++-+-+++++--++---+---++--++++-+-+++-++--++-+++-+-+++----+++---+++----+-++-++--++++++++--++-++-+-++-+-+++++-+-++--+-+-++++-+----+-++++-+-+- ++------+---+------++++-+-++--+--+--++-+-+++++--++---+---++--++++-+-+++-++--++-+++-+-+++----+++---+++----+-++-++--++++++++--++-++---++-+-+++++-+-++--+-+-++++-+----+-++++-+-+ +++------+---+------++++-+-++--+--+--++-+-+++++--++---+---++--++++-+-+++-++--++-+++-+-+-+----+++---+++----+-++-++--++++++++--++-+++--++-+-+++++-+-++--+-+-++++-+----+-++++-+- ++++------+---+------++++-+-++--+--+--++-+-+++++--++---+---++--++++-+-+++-++--++-+++-+-+-+----+++---+++----+-++-++--++++++++--++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++++-+ +++++------+---+------++++-+-++--+--+--++-+--++++--++---+---++--++++-+-+++-++--++-+++-+++-+----+++---+++----+-++-++--++++++++--++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++++- +-++++------+---+------++++-+-++--+--+--++-++-++++--++---+---++--++++-+-+++-++--++-+++--++-+----+++---+++----+-++-++--++++++++--++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++++ ++-++++------+---+------++++-+-++--+--+--++--+-++++--++---+---++--++++-+-+++-++--++-++++-++-+----+++---+++----+-++-++--++++++++--++-+-+--++-+-+++++-+-++--+-+-++++-+----+-+++ +-+-++++------+---+------++++-+-++--+--+--+++-+-++++--++---+---++--++++-+-+++-++--++-++++-++-+----+++---+++----+-++-++--++++++++--++-+-+--++-+-+++++-+-++--+-+-++++-+----+-++ ++-+-++++------+---+------++++-+-++--+--+--+++-+-++++--++---+---++--++++-+-+++-++--++-+-++-++-+----+++---+++----+-++-++--++++++++-+++-+-+--++-+-+++++-+-++--+-+-++++-+----+-+ +++-+-++++------+---+------++++-+-++--+--+--+++-+-++++--++---+---++--++++-+-+++-++--++---++-++-+----+++---+++----+-++-++--++++++++++++-+-+--++-+-+++++-+-++--+-+-++++-+----+- +-++-+-++++------+---+------++++-+-++--+--+--+++-+-++++--++---+---++--++++-+-+++-++--+++--++-++-+----+++---+++----+-++-++--+++++++-++++-+-+--++-+-+++++-+-++--+-+-++++-+----+ +--++-+-++++------+---+------++++-+-++--+--++-+++-+-++++--++---+---++--++++-+-+++-++--+++--++-++-+----+++---+++----+-++-++--+++++++-++++-+-+--++-+-+++++-+-++--+-+-++++-+---- ++--++-+-++++------+---+------++++-+-++--+--++-+++-+-++++--++---+---++--++++-+-+++-++--+++--++-++-+----+++---+++----+-++-++--+++++-+-++++-+-+--++-+-+++++-+-++--+-+-++++-+--- +-+--++-+-++++------+---+------++++-+-++--+--++-+++-+-++++--++---+---++--++++-+-+++-++-++++--++-++-+----+++---+++----+-++-++--++++--+-++++-+-+--++-+-+++++-+-++--+-+-++++-+-- +--+--++-+-++++------+---+------++++-+-++--+--++-+++-+-++++--++---+---++--++++-+-+++-+++++++--++-++-+----+++---+++----+-++-++--+++---+-++++-+-+--++-+-+++++-+-++--+-+-++++-+- ++--+--++-+-++++------+---+------++++-+-++--+--++-+++-+-++++--++---+---++--++++-+-+++-+++++++--++-++-+----+++---+++----+-++-++--++----+-++++-+-+--++-+-+++++-+-++--+-+-++++-+ +-+--+--++-+-++++------+---+------++++-+-++-++--++-+++-+-++++--++---+---++--++++-+-+++-+++++++--++-++-+----+++---+++----+-++-++--++----+-++++-+-+--++-+-+++++-+-++--+-+-++++- +--+--+--++-+-++++------+---+------++++-+-++-++--++-+++-+-++++--++---+---++--++++-+-+++++++++++--++-++-+----+++---+++----+-++-++---+----+-++++-+-+--++-+-+++++-+-++--+-+-++++ ++--+--+--++-+-++++------+---+------++++-+-++-++--++-+++-+-++++--++---+---++--++++-+-++-++++++++--++-++-+----+++---+++----+-++-++-+-+----+-++++-+-+--++-+-+++++-+-++--+-+-+++ +++--+--+--++-+-++++------+---+------++++-+-++-++--++-+++-+-++++--++---+---++--++++-+-+--++++++++--++-++-+----+++---+++----+-++-++++-+----+-++++-+-+--++-+-+++++-+-++--+-+-++ +-++--+--+--++-+-++++------+---+------++++-++++-++--++-+++-+-++++--++---+---++--++++-+-+--++++++++--++-++-+----+++---+++----+-++-++++-+----+-++++-+-+--++-+-+++++-+-++--+-+-+ ++-++--+--+--++-+-++++------+---+------++++--+++-++--++-+++-+-++++--++---+---++--++++-+++--++++++++--++-++-+----+++---+++----+-++-++++-+----+-++++-+-+--++-+-+++++-+-++--+-+- +-+-++--+--+--++-+-++++------+---+------+++++-+++-++--++-+++-+-++++--++---+---++--++++--++--++++++++--++-++-+----+++---+++----+-++-++++-+----+-++++-+-+--++-+-+++++-+-++--+-+ ++-+-++--+--+--++-+-++++------+---+------+++-+-+++-++--++-+++-+-++++--++---+---++--+++++-++--++++++++--++-++-+----+++---+++----+-++-++++-+----+-++++-+-+--++-+-+++++-+-++--+- +++-+-++--+--+--++-+-++++------+---+------+++-+-+++-++--++-+++-+-++++--++---+---++--+++++-++--++++++++--++-++-+----+++---+++----+--+-++++-+----+-++++-+-+--++-+-+++++-+-++--+ ++++-+-++--+--+--++-+-++++------+---+------+++-+-+++-++--++-+++-+-++++--++---+---++--++-++-++--++++++++--++-++-+----+++---+++----++-+-++++-+----+-++++-+-+--++-+-+++++-+-++-- +++++-+-++--+--+--++-+-++++------+---+------+++-+-+++-++--++-+++-+-++++--++---+---++--++-++-++--++++++++--++-++-+----+++---+++-----+-+-++++-+----+-++++-+-+--++-+-+++++-+-++- +-++++-+-++--+--+--++-+-++++------+---+-----++++-+-+++-++--++-+++-+-++++--++---+---++---+-++-++--++++++++--++-++-+----+++---+++-----+-+-++++-+----+-++++-+-+--++-+-+++++-+-++ +--++++-+-++--+--+--++-+-++++------+---+-----++++-+-+++-++--++-+++-+-++++--++---+---++---+-++-++--++++++++--++-++-+----+++---+++--+--+-+-++++-+----+-++++-+-+--++-+-+++++-+-+ +---++++-+-++--+--+--++-+-++++------+---+-----++++-+-+++-++--++-+++-+-++++--++---+---++---+-++-++--++++++++--++-++-+----+++---+++-++--+-+-++++-+----+-++++-+-+--++-+-+++++-+- +----++++-+-++--+--+--++-+-++++------+---+--+--++++-+-+++-++--++-+++-+-++++--++---+---+----+-++-++--++++++++--++-++-+----+++---+++-++--+-+-++++-+----+-++++-+-+--++-+-+++++-+ +-----++++-+-++--+--+--++-+-++++------+---+-++--++++-+-+++-++--++-+++-+-++++--++---+---+----+-++-++--++++++++--++-++-+----+++---+++-++--+-+-++++-+----+-++++-+-+--++-+-+++++- +------++++-+-++--+--+--++-+-++++------+---+-++--++++-+-+++-++--++-+++-+-++++--++---+--++----+-++-++--++++++++--++-++-+----+++---+-+-++--+-+-++++-+----+-++++-+-+--++-+-+++++ ++------++++-+-++--+--+--++-+-++++------+-----++--++++-+-+++-++--++-+++-+-++++--++---+-+++----+-++-++--++++++++--++-++-+----+++---+-+-++--+-+-++++-+----+-++++-+-+--++-+-++++ +-+------++++-+-++--+--+--++-+-++++------+-----++--++++-+-+++-++--++-+++-+-++++--++---+-+++----+-++-++--++++++++--++-++-+----+++--++-+-++--+-+-++++-+----+-++++-+-+--++-+-+++ +---+-+--++-+-+----+-++++-+----+-+-++--+-+--++---++++-+--+--++--------++--+--+-++++---++---++--++++-+-+++-++--++-+++-+-++++--++-----+------++++-+-++--+--+--++-+-++++------+- +----+-+--++-+-+----+-++++-+----+-+-++--+-+-+++---++++-+--+--++--------++--+--+-++++----+---++--++++-+-+++-++--++-+++-+-++++--++-----+------++++-+-++--+--+--++-+-++++------+ +-----+-+--++-+-+----+-++++-+----+-+-++--+-+-+++---++++-+--+--++--------++--+--+-++++----+---++--++++-+-+++-++--++-+++-+-++++--++-+---+------++++-+-++--+--+--++-+-++++------ ++-----+-+--++-+-+----+-++++-+----+-+-++--+---+++---++++-+--+--++--------++--+--+-++++----+---++--++++-+-+++-++--++-+++-+-++++--++-+---+------++++-+-++--+--+--++-+-++++----- +-+-----+-+--++-+-+----+-++++-+----+-+-++--+---+++---++++-+--+--++--------++--+--+-+++++---+---++--++++-+-+++-++--++-+++-+-++++--+--+---+------++++-+-++--+--+--++-+-++++---- ++-+-----+-+--++-+-+----+-++++-+----+-+-++--+---+++---++++-+--+--++--------++--+--+-+++++---+---++--++++-+-+++-++--++-+++-+-++++-----+---+------++++-+-++--+--+--++-+-++++--- +-+-+-----+-+--++-+-+----+-++++-+----+-+-++-++---+++---++++-+--+--++--------++--+--+-++-++---+---++--++++-+-+++-++--++-+++-+-++++-----+---+------++++-+-++--+--+--++-+-++++-- +--+-+-----+-+--++-+-+----+-++++-+----+-+-+++++---+++---++++-+--+--++--------++--+--+-+--++---+---++--++++-+-+++-++--++-+++-+-++++-----+---+------++++-+-++--+--+--++-+-++++- ++--+-+-----+-+--++-+-+----+-++++-+----+-+-+++++---+++---++++-+--+--++--------++--+--+-+--++---+---++--++++-+-+++-++--++-+++-+-+++------+---+------++++-+-++--+--+--++-+-++++ +++--+-+-----+-+--++-+-+----+-++++-+----+-+--++++---+++---++++-+--+--++--------++--+--+++--++---+---++--++++-+-+++-++--++-+++-+-+++------+---+------++++-+-++--+--+--++-+-+++ +-++--+-+-----+-+--++-+-+----+-++++-+----+-++-++++---+++---++++-+--+--++--------++--+--+++--++---+---++--++++-+-+++-++--++-+++-+-+++------+---+------++++-+-++--+--+--++-+-++ ++-++--+-+-----+-+--++-+-+----+-++++-+----+--+-++++---+++---++++-+--+--++--------++--+-++++--++---+---++--++++-+-+++-++--++-+++-+-+++------+---+------++++-+-++--+--+--++-+-+ +-+-++--+-+-----+-+--++-+-+----+-++++-+----+--+-++++---+++---++++-+--+--++--------++--+-++++--++---+---++--++++-+-+++-++--++-+++-+++++------+---+------++++-+-++--+--+--++-+- ++-+-++--+-+-----+-+--++-+-+----+-++++-+----+--+-++++---+++---++++-+--+--++--------++--+-++++--++---+---++--++++-+-+++-++--++-+++--++++------+---+------++++-+-++--+--+--++-+ +-+-+-++--+-+-----+-+--++-+-+----+-++++-+----+--+-++++---+++---++++-+--+--++--------++--+-++++--++---+---++--++++-+-+++-++--++-++++-++++------+---+------++++-+-++--+--+--++- +--+-+-++--+-+-----+-+--++-+-+----+-++++-+----+--+-++++---+++---++++-+--+--++--------+++-+-++++--++---+---++--++++-+-+++-++--++-++-+-++++------+---+------++++-+-++--+--+--++ +---+-+-++--+-+-----+-+--++-+-+----+-++++-+-+--+--+-++++---+++---++++-+--+--++--------+++-+-++++--++---+---++--++++-+-+++-++--++-++-+-++++------+---+------++++-+-++--+--+--+ +----+-+-++--+-+-----+-+--++-+-+----+-++++-+++--+--+-++++---+++---++++-+--+--++--------+++-+-++++--++---+---++--++++-+-+++-++--++-++-+-++++------+---+------++++-+-++--+--+-- ++----+-+-++--+-+-----+-+--++-+-+----+-++++--++--+--+-++++---+++---++++-+--+--++--------+++-+-++++--++---+---++--++++-+-+++-++--++-++-+-++++------+---+------++++-+-++--+--+- +-+----+-+-++--+-+-----+-+--++-+-+----+-++++--++--+--+-++++---+++---++++-+--+--++------+-+++-+-++++--++---+---++--++++-+-+++-++--+--++-+-++++------+---+------++++-+-++--+--+ ++-+----+-+-++--+-+-----+-+--++-+-+----+-+++---++--+--+-++++---+++---++++-+--+--++-----++-+++-+-++++--++---+---++--++++-+-+++-++--+--++-+-++++------+---+------++++-+-++--+-- +++-+----+-+-++--+-+-----+-+--++-+-+----+-++----++--+--+-++++---+++---++++-+--+--++-----++-+++-+-++++--++---+---++--++++-+-+++-++--+--++-+-++++------+---+------++++-+-++--+- ++++-+----+-+-++--+-+-----+-+--++-+-+----+-+-----++--+--+-++++---+++---++++-+--+--++-----++-+++-+-++++--++---+---++--++++-+-+++-++--+--++-+-++++------+---+------++++-+-++--+ +++++-+----+-+-++--+-+-----+-+--++-+-+----+-------++--+--+-++++---+++---++++-+--+--++--+--++-+++-+-++++--++---+---++--++++-+-+++-++--+--++-+-++++------+---+------++++-+-++-- +-++++-+----+-+-++--+-+-----+-+--++-+-+----+-------++--+--+-++++---+++---++++-+--+--++-++--++-+++-+-++++--++---+---++--++++-+-+++--+--+--++-+-++++------+---+------++++-+-++- ++-++++-+----+-+-++--+-+-----+-+--++-+-+------------++--+--+-++++---+++---++++-+--+--++-++--++-+++-+-++++--++---+---++--++++-+-+++--+--+--++-+-++++------+---+------++++-+-++ +-+-++++-+----+-+-++--+-+-----+-+--++-+-+---+--------++--+--+-++++---+++---++++-+--+--++-++--++-+++-+-++++--++---+---++--++++-+-+++--+--+--++-+-++++------+---+------++++-+-+ +--+-++++-+----+-+-++--+-+-----+-+--++-+-+--++--------++--+--+-++++---+++---++++-+--+--++-++--++-+++-+-++++--++---+---++--++++-+-+++--+--+--++-+-++++------+---+------++++-+- +---+-++++-+----+-+-++--+-+-----+-+--++-+-+--++--------++--+--+-++++---+++---++++-+--+-+++-++--++-+++-+-++++--++---+---++--++++-+--++--+--+--++-+-++++------+---+------++++-+ +----+-++++-+----+-+-++--+-+-----+-+--++-+-+--++--------++--+--+-++++---+++---++++-+--+-+++-++--++-+++-+-++++--++---+---++--++++-++-++--+--+--++-+-++++------+---+------++++- ++----+-++++-+----+-+-++--+-+-----+-+--++-+-+--++--------++--+--+-++++---+++---++++-+--+-+++-++--++-+++-+-++++--++---+---++--++++--+-++--+--+--++-+-++++------+---+------++++ +-+----+-++++-+----+-+-++--+-+-----+-+--++-+-+--++--------++--+--+-++++---+++---++++-+--+-+++-++--++-+++-+-++++--++---+---++--+++++-+-++--+--+--++-+-++++------+---+------+++ ++-+----+-++++-+----+-+-++--+-+-----+-+--++---+--++--------++--+--+-++++---+++---++++-++-+-+++-++--++-+++-+-++++--++---+---++--+++++-+-++--+--+--++-+-++++------+---+------++ +-+-+----+-++++-+----+-+-++--+-+-----+-+--+++--+--++--------++--+--+-++++---+++---++++-++-+-+++-++--++-+++-+-++++--++---+---++--+++++-+-++--+--+--++-+-++++------+---+------+ ++-+-+----+-++++-+----+-+-++--+-+-----+-+--+-+--+--++--------++--+--+-++++---+++---+++++++-+-+++-++--++-+++-+-++++--++---+---++--+++++-+-++--+--+--++-+-++++------+---+------ +++-+-+----+-++++-+----+-+-++--+-+-----+-+--+-+--+--++--------++--+--+-++++---+++---+++++++-+-+++-++--++-+++-+-++++--++---+---++---++++-+-++--+--+--++-+-++++------+---+----- +-++-+-+----+-++++-+----+-+-++--+-+-----+-+-++-+--+--++--------++--+--+-++++---+++---++-++++-+-+++-++--++-+++-+-++++--++---+---++---++++-+-++--+--+--++-+-++++------+---+---- +--++-+-+----+-++++-+----+-+-++--+-+-----+-++++-+--+--++--------++--+--+-++++---+++---+--++++-+-+++-++--++-+++-+-++++--++---+---++---++++-+-++--+--+--++-+-++++------+---+--- ++--++-+-+----+-++++-+----+-+-++--+-+-----+-++++-+--+--++--------++--+--+-++++---+++---+--++++-+-+++-++--++-+++-+-++++--++---+---+----++++-+-++--+--+--++-+-++++------+---+-- +-+--++-+-+----+-++++-+----+-+-++--+-+-----+-++++-+--+--++--------++--+--+-++++---+++--++--++++-+-+++-++--++-+++-+-++++--++---+--------++++-+-++--+--+--++-+-++++------+---+- ++-+--++-+-+----+-++++-+----+-+-++--+-+-------++++-+--+--++--------++--+--+-++++---+++--++--++++-+-+++-++--++-+++-+-++++--++---+--------++++-+-++--+--+--++-+-++++------+---+ +-+-+--++-+-+----+-++++-+----+-+-++--+-+-------++++-+--+--++--------++--+--+-++++---+++--++--++++-+-+++-++--++-+++-+-++++--++---+-+------++++-+-++--+--+--++-+-++++------+--- +--+-+--++-+-+----+-++++-+----+-+-++--+-+---+---++++-+--+--++--------++--+--+-++++---++---++--++++-+-+++-++--++-+++-+-++++--++---+-+------++++-+-++--+--+--++-+-++++------+-- +--+++----+-++-++--++++++++--++-++-+----+++----+-+--++-+-+----+-++++-+----+-+-++--+-+--++-++++++----+-+--++-++-++--+-+----++++++-++---++--++++-+-+++-++--++-+++-+-++++--++--- +---+++----+-++-++--++++++++--++-++-+----+++----+-+--++-+-+----+-++++-+----+-+-++--+-+-+++-++++++----+-+--++-++-++--+-+----++++++--+---++--++++-+-+++-++--++-+++-+-++++--++-- ++---+++----+-++-++--++++++++--++-++-+----++-----+-+--++-+-+----+-++++-+----+-+-++--+-+-+++-++++++----+-+--++-++-++--+-+----++++++--+---++--++++-+-+++-++--++-+++-+-++++--++- +++---+++----+-++-++--++++++++--++-++-+----++-----+-+--++-+-+----+-++++-+----+-+-++--+-+-+++-++++++----+-+--++-++-++--+-+----+++++---+---++--++++-+-+++-++--++-+++-+-++++--++ ++++---+++----+-++-++--++++++++--++-++-+-----+-----+-+--++-+-+----+-++++-+----+-+-++--+++-+++-++++++----+-+--++-++-++--+-+----+++++---+---++--++++-+-+++-++--++-+++-+-++++--+ +-+++---+++----+-++-++--++++++++--++-++-+---+-+-----+-+--++-+-+----+-++++-+----+-+-++--+++-+++-++++++----+-+--++-++-++--+-+----+++++---+---++--++++-+-+++-++--++-+++-+-++++-- +--+++---+++----+-++-++--++++++++--++-++-+---+-+-----+-+--++-+-+----+-++++-+----+-+-++-++++-+++-++++++----+-+--++-++-++--+-+----++-++---+---++--++++-+-+++-++--++-+++-+-++++- +---+++---+++----+-++-++--++++++++--++-++-+---+-+-----+-+--++-+-+----+-++++-+----+-+-+++++++-+++-++++++----+-+--++-++-++--+-+----+--++---+---++--++++-+-+++-++--++-+++-+-++++ +----+++---+++----+-++-++--++++++++--++-++-++--+-+-----+-+--++-+-+----+-++++-+----+-+-+++++++-+++-++++++----+-+--++-++-++--+-+----+--++---+---++--++++-+-+++-++--++-+++-+-+++ ++----+++---+++----+-++-++--++++++++--++-++-++--+-+-----+-+--++-+-+----+-++++-+----+-+--++++++-+++-++++++----+-+--++-++-++--+-+---++--++---+---++--++++-+-+++-++--++-+++-+-++ +-+----+++---+++----+-++-++--++++++++--++-++-++--+-+-----+-+--++-+-+----+-++++-+----+-+--++++++-+++-++++++----+-+--++-++-++--+-+--+++--++---+---++--++++-+-+++-++--++-+++-+-+ ++-+----+++---+++----+-++-++--++++++++--++-++-++--+-+-----+-+--++-+-+----+-++++-+----+----++++++-+++-++++++----+-+--++-++-++--+-+-++++--++---+---++--++++-+-+++-++--++-+++-+- +++-+----+++---+++----+-++-++--++++++++--++--+-++--+-+-----+-+--++-+-+----+-++++-+----+----++++++-+++-++++++----+-+--++-++-++--+-+-++++--++---+---++--++++-+-+++-++--++-+++-+ +-++-+----+++---+++----+-++-++--++++++++--+++-+-++--+-+-----+-+--++-+-+----+-++++-+----+----++++++-+++-++++++----+-+--++-++-++--+-+-++++--++---+---++--++++-+-+++-++--++-+++- ++-++-+----+++---+++----+-++-++--++++++++--+-+-+-++--+-+-----+-+--++-+-+----+-++++-+----+----++++++-+++-++++++----+-+--++-++-++--+-+-++++--++---+---++--++++-+-+++-++--++-+++ +++-++-+----+++---+++----+-++-++--++++++++----+-+-++--+-+-----+-+--++-+-+----+-++++-+--+-+----++++++-+++-++++++----+-+--++-++-++--+-+-++++--++---+---++--++++-+-+++-++--++-++ +-++-++-+----+++---+++----+-++-++--++++++++----+-+-++--+-+-----+-+--++-+-+----+-++++-+--+-+----++++++-+++-++++++----+-+--++-++-++-++-+-++++--++---+---++--++++-+-+++-++--++-+ +--++-++-+----+++---+++----+-++-++--++++++++----+-+-++--+-+-----+-+--++-+-+----+-++++-+--+-+----++++++-+++-++++++----+-+--++-++-+++++-+-++++--++---+---++--++++-+-+++-++--++- ++--++-++-+----+++---+++----+-++-++--++++++++----+-+-++--+-+-----+-+--++-+-+----+-++++-+--+-+----++++++-+++-++++++----+-+--++-++-+-+++-+-++++--++---+---++--++++-+-+++-++--++ +++--++-++-+----+++---+++----+-++-++--++++++-+----+-+-++--+-+-----+-+--++-+-+----+-++++++--+-+----++++++-+++-++++++----+-+--++-++-+-+++-+-++++--++---+---++--++++-+-+++-++--+ ++++--++-++-+----+++---+++----+-++-++--++++++-+----+-+-++--+-+-----+-+--++-+-+----+-+++-++--+-+----++++++-+++-++++++----+-+--++-++++-+++-+-++++--++---+---++--++++-+-+++-++-- +++++--++-++-+----+++---+++----+-++-++--++++++-+----+-+-++--+-+-----+-+--++-+-+----+-+++-++--+-+----++++++-+++-++++++----+-+--++-+-++-+++-+-++++--++---+---++--++++-+-+++-++- ++++++--++-++-+----+++---+++----+-++-++--++++++-+----+-+-++--+-+-----+-+--++-+-+----+-+++-++--+-+----++++++-+++-++++++----+-+--++---++-+++-+-++++--++---+---++--++++-+-+++-++ +++++++--++-++-+----+++---+++----+-++-++--++++++-+----+-+-++--+-+-----+-+--++-+-+----+--++-++--+-+----++++++-+++-++++++----+-+--+++--++-+++-+-++++--++---+---++--++++-+-+++-+ ++++++++--++-++-+----+++---+++----+-++-++--+-++++-+----+-+-++--+-+-----+-+--++-+-+----++-++-++--+-+----++++++-+++-++++++----+-+--+++--++-+++-+-++++--++---+---++--++++-+-+++- +++++++++--++-++-+----+++---+++----+-++-++--+-++++-+----+-+-++--+-+-----+-+--++-+-+----++-++-++--+-+----++++++-+++-++++++----+-+---++--++-+++-+-++++--++---+---++--++++-+-+++ +-++++++++--++-++-+----+++---+++----+-++-++--+-++++-+----+-+-++--+-+-----+-+--++-+-+----++-++-++--+-+----++++++-+++-++++++----+-+-+-++--++-+++-+-++++--++---+---++--++++-+-++ +--++++++++--++-++-+----+++---+++----+-++-++--+-++++-+----+-+-++--+-+-----+-+--++-+-+----++-++-++--+-+----++++++-+++-++++++----+-+++-++--++-+++-+-++++--++---+---++--++++-+-+ ++--++++++++--++-++-+----+++---+++----+-++-+---+-++++-+----+-+-++--+-+-----+-+--++-+-+-+--++-++-++--+-+----++++++-+++-++++++----+-+++-++--++-+++-+-++++--++---+---++--++++-+- +++--++++++++--++-++-+----+++---+++----+-++-----+-++++-+----+-+-++--+-+-----+-+--++-+-+-+--++-++-++--+-+----++++++-+++-++++++----+-+++-++--++-+++-+-++++--++---+---++--++++-+ +-++--++++++++--++-++-+----+++---+++----+-+++----+-++++-+----+-+-++--+-+-----+-+--++-+-+-+--++-++-++--+-+----++++++-+++-++++++----+-+++-++--++-+++-+-++++--++---+---++--++++- ++-++--++++++++--++-++-+----+++---+++----+-+-+----+-++++-+----+-+-++--+-+-----+-+--++-+-+-+--++-++-++--+-+----++++++-+++-++++++----+-+++-++--++-+++-+-++++--++---+---++--++++ +++-++--++++++++--++-++-+----+++---+++----+-+-+----+-++++-+----+-+-++--+-+-----+-+--++---+-+--++-++-++--+-+----++++++-+++-++++++--+-+-+++-++--++-+++-+-++++--++---+---++--+++ +-++-++--++++++++--++-++-+----+++---+++----+-+-+----+-++++-+----+-+-++--+-+-----+-+--++---+-+--++-++-++--+-+----++++++-+++-++++++-++-+-+++-++--++-+++-+-++++--++---+---++--++ ++-++-++--++++++++--++-++-+----+++---+++----+-+-+----+-++++-+----+-+-++--+-+-----+-+--+----+-+--++-++-++--+-+----++++++-+++-+++++++++-+-+++-++--++-+++-+-++++--++---+---++--+ +-+-++-++--++++++++--++-++-+----+++---+++---++-+-+----+-++++-+----+-+-++--+-+-----+-+--+----+-+--++-++-++--+-+----++++++-+++-+++++++++-+-+++-++--++-+++-+-++++--++---+---++-- +--+-++-++--++++++++--++-++-+----+++---+++---++-+-+----+-++++-+----+-+-++--+-+-----+-+-++----+-+--++-++-++--+-+----++++++-+++-++++-++++-+-+++-++--++-+++-+-++++--++---+---++- +---+-++-++--++++++++--++-++-+----+++---+++---++-+-+----+-++++-+----+-+-++--+-+-----+-++++----+-+--++-++-++--+-+----++++++-+++-+++--++++-+-+++-++--++-+++-+-++++--++---+---++ +----+-++-++--++++++++--++-++-+----+++---++++--++-+-+----+-++++-+----+-+-++--+-+-----+-++++----+-+--++-++-++--+-+----++++++-+++-+++--++++-+-+++-++--++-+++-+-++++--++---+---+ ++----+-++-++--++++++++--++-++-+----+++---++-+--++-+-+----+-++++-+----+-+-++--+-+-----++++++----+-+--++-++-++--+-+----++++++-+++-+++--++++-+-+++-++--++-+++-+-++++--++---+--- +++----+-++-++--++++++++--++-++-+----+++---++-+--++-+-+----+-++++-+----+-+-++--+-+-----++++++----+-+--++-++-++--+-+----++++++-+++--++--++++-+-+++-++--++-+++-+-++++--++---+-- ++++----+-++-++--++++++++--++-++-+----+++----+-+--++-+-+----+-++++-+----+-+-++--+-+-----++++++----+-+--++-++-++--+-+----++++++-+++--++--++++-+-+++-++--++-+++-+-++++--++---+- +-+++----+-++-++--++++++++--++-++-+----+++----+-+--++-+-+----+-++++-+----+-+-++--+-+---+-++++++----+-+--++-++-++--+-+----++++++-++---++--++++-+-+++-++--++-+++-+-++++--++---+ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_188.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_188.txt new file mode 100644 index 000000000..09126ef5c --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_188.txt @@ -0,0 +1,188 @@ ++-+----+--+----++---++++---+-+----++++++--+---++---+--++++++----+-+---+---+++--++++-++-++++-+-+--+--+-+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++- +++-+----+--+----++---++++---+-+----++++++--+------+--++++++----+-+---+---+++--++++-++-++++-+-+--+--+-+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+ +-++-+----+--+----++---++++---+-+----++++++--+----+--++++++----+-+---+---+++--++++-++-++++-+-+--+--+-+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+- +--++-+----+--+----++---++++---+-+----++++++--+--+--++++++----+-+---+---+++--++++-++-++++-+-+--+--+-+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+-- +---++-+----+--+----++---++++---+-+----++++++--++--++++++----+-+---+---+++--++++-++-++++-+-+-----+-+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+ ++---++-+----+--+----++---++++---+-+----++++++----++++++----+-+---+---+++--++++-++-++++-+-+---+-+-+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+- +-+---++-+----+--+----++---++++---+-+----++++++--++++++----+-+---+---+++--++++-++-++++-+-+---+-+-+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+-- +--+---++-+----+--+----++---++++---+-+----++++++++++++----+-+---+---+++--++++-++-++++-+-+---+---+++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+ ++--+---++-+----+--+----++---++++---+-+----++++++++++----+-+---+---+++--++++-++-++++-+-+---+--++++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+- +++--+---++-+----+--+----++---++++---+-+----++++++++----+-+---+---+++--++++-++-++++-+-+---+--++++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+ ++++--+---++-+----+--+----++---++++---+-+----++++++----+-+---+---+++--++++-++-++++-+-+---+--++++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-++ +++++--+---++-+----+--+----++---++++---+-+----++++----+-+---+---+++--++++-++-++++-+-+---+--++++-----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++ ++++++--+---++-+----+--+----++---++++---+-+----++----+-+---+---+++--++++-++-++++-+-+---+--+++++----+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++- +++++++--+---++-+----+--+----++---++++---+-+--------+-+---+---+++--++++-++-++++-+-+---+--++++++---+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-- +-++++++--+---++-+----+--+----++---++++---+-+------+-+---+---+++--++++-++-++++-+-+---+--++++++---+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++--- +--++++++--+---++-+----+--+----++---++++---+-+----+-+---+---+++--++++-++-++++-+-+---+--++++++---+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++---- +---++++++--+---++-+----+--+----++---++++---+-+--+-+---+---+++--++++-++-++++-+-+---+--++++++---+---+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++----- +----++++++--+---++-+----+--+----++---++++---+-++-+---+---+++--++++-++-++++-+-+---+--++++++-------+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+ ++----++++++--+---++-+----+--+----++---++++---+--+---+---+++--++++-++-++++-+-+---+--++++++----+--+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+- +-+----++++++--+---++-+----+--+----++---++++---++---+---+++--++++-++-++++-+-+---+--++++++----+--+---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+-- ++-+----++++++--+---++-+----+--+----++---++++------+---+++--++++-++-++++-+-+---+--++++++----+-++---++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+--- +-+-+----++++++--+---++-+----+--+----++---++++----+---+++--++++-++-++++-+-+---+--++++++----+-+----++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+---+ +--+-+----++++++--+---++-+----+--+----++---++++--+---+++--++++-++-++++-+-+---+--++++++----+-+----++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+---+- +---+-+----++++++--+---++-+----+--+----++---+++++---+++--++++-++-++++-+-+---+--++++++----+-+----++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+---+-- ++---+-+----++++++--+---++-+----+--+----++---+++---+++--++++-++-++++-+-+---+--++++++----+-+---+++--+-++-+++-+-+--+---++--+--+-+++-----+---+-----++-+--+---+-+-++-+++-+--+--+-+++-----+---+--- +++---+-+----++++++--+---++-+----+--+----++---++--+++--++++-++-++++-+-+---+--++++++----+-+---+-+--+-++-+++-+-+--+---++--+--+-+++-----+---+---+-++-+--+---+-+-++-+++-+--+--+-+++-----+---+---- ++++---+-+----++++++--+---++-+----+--+----++---+-+++--++++-++-++++-+-+---+--++++++----+-+---+----+-++-+++-+-+--+---++--+--+-+++-----+---+---++++-+--+---+-+-++-+++-+--+--+-+++-----+---+----- +++++---+-+----++++++--+---++-+----+--+----++---+++--++++-++-++++-+-+---+--++++++----+-+---+----+-++-+++-+-+--+---++--+--+-+++-----+---+---++-+-+--+---+-+-++-+++-+--+--+-+++-----+---+-----+ +-++++---+-+----++++++--+---++-+----+--+----++--++--++++-++-++++-+-+---+--++++++----+-+---+---++-++-+++-+-+--+---++--+--+-+++-----+---+---++---+--+---+-+-++-+++-+--+--+-+++-----+---+-----++ +--++++---+-+----++++++--+---++-+----+--+----++-+--++++-++-++++-+-+---+--++++++----+-+---+---++-++-+++-+-+--+---++--+--+-+++-----+---+---++--++--+---+-+-++-+++-+--+--+-+++-----+---+-----++- +---++++---+-+----++++++--+---++-+----+--+----++--++++-++-++++-+-+---+--++++++----+-+---+---+++++-+++-+-+--+---++--+--+-+++-----+---+---++--+---+---+-+-++-+++-+--+--+-+++-----+---+-----++-+ ++---++++---+-+----++++++--+---++-+----+--+----+-++++-++-++++-+-+---+--++++++----+-+---+---+++-+-+++-+-+--+---++--+--+-+++-----+---+---++--+-+-+---+-+-++-+++-+--+--+-+++-----+---+-----++-+- +++---++++---+-+----++++++--+---++-+----+--+----++++-++-++++-+-+---+--++++++----+-+---+---+++---+++-+-+--+---++--+--+-+++-----+---+---++--+-+++---+-+-++-+++-+--+--+-+++-----+---+-----++-+-- +-++---++++---+-+----++++++--+---++-+----+--+---+++-++-++++-+-+---+--++++++----+-+---+---+++--++++-+-+--+---++--+--+-+++-----+---+---++--+-++----+-+-++-+++-+--+--+-+++-----+---+-----++-+--+ +--++---++++---+-+----++++++--+---++-+----+--+--++-++-++++-+-+---+--++++++----+-+---+---+++--++++-+-+--+---++--+--+-+++-----+---+---++--+-++-+--+-+-++-+++-+--+--+-+++-----+---+-----++-+--+- +---++---++++---+-+----++++++--+---++-+----+--+-+-++-++++-+-+---+--++++++----+-+---+---+++--++++-+-+--+---++--+--+-+++-----+---+---++--+-++-++-+-+-++-+++-+--+--+-+++-----+---+-----++-+--+-- +----++---++++---+-+----++++++--+---++-+----+--+-++-++++-+-+---+--++++++----+-+---+---+++--++++-+-+--+---++--+--+-+++-----+---+---++--+-++-++++-+-++-+++-+--+--+-+++-----+---+-----++-+--+--- ++----++---++++---+-+----++++++--+---++-+----+--++-++++-+-+---+--++++++----+-+---+---+++--++++-+-+--+---++--+--+-+++-----+---+---++--+-++-+++--+-++-+++-+--+--+-+++-----+---+-----++-+--+---+ +-+----++---++++---+-+----++++++--+---++-+----+-+-++++-+-+---+--++++++----+-+---+---+++--++++-+-+--+---++--+--+-+++-----+---+---++--+-++-+++-++-++-+++-+--+--+-+++-----+---+-----++-+--+---+- +--+----++---++++---+-+----++++++--+---++-+----+-++++-+-+---+--++++++----+-+---+---+++--++++-+++--+---++--+--+-+++-----+---+---++--+-++-+++-+--++-+++-+--+--+-+++-----+---+-----++-+--+---+-+ ++--+----++---++++---+-+----++++++--+---++-+----++++-+-+---+--++++++----+-+---+---+++--++++-++---+---++--+--+-+++-----+---+---++--+-++-+++-+-+++-+++-+--+--+-+++-----+---+-----++-+--+---+-+- +-+--+----++---++++---+-+----++++++--+---++-+---+++-+-+---+--++++++----+-+---+---+++--++++-++-+-+---++--+--+-+++-----+---+---++--+-++-+++-+-+-+-+++-+--+--+-+++-----+---+-----++-+--+---+-+-+ +--+--+----++---++++---+-+----++++++--+---++-+--++-+-+---+--++++++----+-+---+---+++--++++-++-+++---++--+--+-+++-----+---+---++--+-++-+++-+-+---+++-+--+--+-+++-----+---+-----++-+--+---+-+-++ +---+--+----++---++++---+-+----++++++--+---++-+-+-+-+---+--++++++----+-+---+---+++--++++-++-+++---++--+--+-+++-----+---+---++--+-++-+++-+-+--++++-+--+--+-+++-----+---+-----++-+--+---+-+-++- +----+--+----++---++++---+-+----++++++--+---++-+-+-+---+--++++++----+-+---+---+++--++++-++-++++--++--+--+-+++-----+---+---++--+-++-+++-+-+--+-++-+--+--+-+++-----+---+-----++-+--+---+-+-++-+ ++----+--+----++---++++---+-+----++++++--+---++-+-+---+--++++++----+-+---+---+++--++++-++-++++--++--+--+-+++-----+---+---++--+-++-+++-+-+--+--+-+--+--+-+++-----+---+-----++-+--+---+-+-++-++ +-+----+--+----++---++++---+-+----++++++--+---++-+---+--++++++----+-+---+---+++--++++-++-++++-+++--+--+-+++-----+---+---++--+-++-+++-+-+--+----+--+--+-+++-----+---+-----++-+--+---+-+-++-+++ +-+++-++------++++-+-+++-+++---++----+--+----+-++-+----+--+----++---++++---+-+----++++++--+---+---+--+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++ ++++-++------++++-+-+++-+++---++----+--+----+-+-++-+----+--+----++---++++---+-+----++++++--+-----+--+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++- +++-++------++++-+-+++-+++---++----+--+----+-+-+-++-+----+--+----++---++++---+-+----++++++--+---+--+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++-- ++-++------++++-+-+++-+++---++----+--+----+-+-++--++-+----+--+----++---++++---+-+----++++++--+-+--+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++--- +-++------++++-+-+++-+++---++----+--+----+-+-+++---++-+----+--+----++---++++---+-+----++++++--+--+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+ +++------++++-+-+++-+++---++----+--+----+-+-+++-+---++-+----+--+----++---++++---+-+----++++++---+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+- ++------++++-+-+++-+++---++----+--+----+-+-+++-+-+---++-+----+--+----++---++++---+-+----++++++-+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+-- +------++++-+-+++-+++---++----+--+----+-+-+++-++--+---++-+----+--+----++---++++---+-+----++++++-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+ +-----++++-+-+++-+++---++----+--+----+-+-+++-++-+--+---++-+----+--+----++---++++---+-+----++++++-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+- +----++++-+-+++-+++---++----+--+----+-+-+++-++--++--+---++-+----+--+----++---++++---+-+----++++-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+ +---++++-+-+++-+++---++----+--+----+-+-+++-++---+++--+---++-+----+--+----++---++++---+-+----++++++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+- +--++++-+-+++-+++---++----+--+----+-+-+++-++----++++--+---++-+----+--+----++---++++---+-+----++++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+ +-++++-+-+++-+++---++----+--+----+-+-+++-++-----+++++--+---++-+----+--+----++---++++---+-+----++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-++ +++++-+-+++-+++---++----+--+----+-+-+++-++------++++++--+---++-+----+--+----++---++++---+-+-----++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++ ++++-+-+++-+++---++----+--+----+-+-+++-++------+-++++++--+---++-+----+--+----++---++++---+-+---++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++- +++-+-+++-+++---++----+--+----+-+-+++-++------++--++++++--+---++-+----+--+----++---++++---+-+--+-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-+ ++-+-+++-+++---++----+--+----+-+-+++-++------+++---++++++--+---++-+----+--+----++---++++---+-+--+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-++ +-+-+++-+++---++----+--+----+-+-+++-++------++++----++++++--+---++-+----+--+----++---++++---+-++--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-++- ++-+++-+++---++----+--+----+-+-+++-++------++++-+----++++++--+---++-+----+--+----++---++++---+---+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-++-+ +-+++-+++---++----+--+----+-+-+++-++------++++-+-+----++++++--+---++-+----+--+----++---++++---+-+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-++-+- ++++-+++---++----+--+----+-+-+++-++------++++-+-+-+----++++++--+---++-+----+--+----++---++++---+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-++-+-- +++-+++---++----+--+----+-+-+++-++------++++-+-+-+-+----++++++--+---++-+----+--+----++---++++--++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-++-+--+ ++-+++---++----+--+----+-+-+++-++------++++-+-++--+-+----++++++--+---++-+----+--+----++---++++-+++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--++---+--+-+-+++-++-+--++ +-+++---++----+--+----+-+-+++-++------++++-+-+++---+-+----++++++--+---++-+----+--+----++---++++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--+++--+---+-----+++-+--+--++---+--+-+-+++-++-+--++- ++++---++----+--+----+-+-+++-++------++++-+-+++-+---+-+----++++++--+---++-+----+--+----++---++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++++-+---+-----+++-+--+--++---+--+-+-+++-++-+--++-- +++---++----+--+----+-+-+++-++------++++-+-+++-+++---+-+----++++++--+---++-+----+--+----++---++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++++++---+-----+++-+--+--++---+--+-+-+++-++-+--++--- ++---++----+--+----+-+-+++-++------++++-+-+++-+++++---+-+----++++++--+---++-+----+--+----++---++++-+++++---+-++-++-+---+--+-+-+++-++-+--+++++----+-----+++-+--+--++---+--+-+-+++-++-+--++---+ +---++----+--+----+-+-+++-++------++++-+-+++-+++++++---+-+----++++++--+---++-+----+--+----++---++-+++++---+-++-++-+---+--+-+-+++-++-+--+++++-+--+-----+++-+--+--++---+--+-+-+++-++-+--++---+- +--++----+--+----+-+-+++-++------++++-+-+++-+++--++++---+-+----++++++--+---++-+----+--+----++--+-+++++---+-++-++-+---+--+-+-+++-++-+--+++++-++-+-----+++-+--+--++---+--+-+-+++-++-+--++---+-- +-++----+--+----+-+-+++-++------++++-+-+++-+++----++++---+-+----++++++--+---++-+----+--+----++--+++++---+-++-++-+---+--+-+-+++-++-+--+++++-++++-----+++-+--+--++---+--+-+-+++-++-+--++---+--- +++----+--+----+-+-+++-++------++++-+-+++-+++------++++---+-+----++++++--+---++-+----+--+----+++++++---+-++-++-+---+--+-+-+++-++-+--+++++-+++------+++-+--+--++---+--+-+-+++-++-+--++---+---+ ++----+--+----+-+-+++-++------++++-+-+++-+++---++---++++---+-+----++++++--+---++-+----+--+----+++++---+-++-++-+---+--+-+-+++-++-+--+++++-+++-+----+++-+--+--++---+--+-+-+++-++-+--++---+---+- +----+--+----+-+-+++-++------++++-+-+++-+++---++++---++++---+-+----++++++--+---++-+----+--+----+++---+-++-++-+---+--+-+-+++-++-+--+++++-+++-++---+++-+--+--++---+--+-+-+++-++-+--++---+---+-- +---+--+----+-+-+++-++------++++-+-+++-+++---++--++---++++---+-+----++++++--+---++-+----+--+---++---+-++-++-+---+--+-+-+++-++-+--+++++-+++-+++--+++-+--+--++---+--+-+-+++-++-+--++---+---+--- +--+--+----+-+-+++-++------++++-+-+++-+++---++----++---++++---+-+----++++++--+---++-+----+--+--+---+-++-++-+---+--+-+-+++-++-+--+++++-+++-++++-+++-+--+--++---+--+-+-+++-++-+--++---+---+---- +-+--+----+-+-+++-++------++++-+-+++-+++---++------++---++++---+-+----++++++--+---++-+----+--+----+-++-++-+---+--+-+-+++-++-+--+++++-+++-++++++++-+--+--++---+--+-+-+++-++-+--++---+---+----- ++--+----+-+-+++-++------++++-+-+++-+++---++--------++---++++---+-+----++++++--+---++-+----+--+--+-++-++-+---+--+-+-+++-++-+--+++++-+++-+++++-++-+--+--++---+--+-+-+++-++-+--++---+---+-----+ +--+----+-+-+++-++------++++-+-+++-+++---++----++----++---++++---+-+----++++++--+---++-+----+---+-++-++-+---+--+-+-+++-++-+--+++++-+++-+++++--+-+--+--++---+--+-+-+++-++-+--++---+---+-----++ +-+----+-+-+++-++------++++-+-+++-+++---++----+--+----++---++++---+-+----++++++--+---++-+----+-+-++-++-+---+--+-+-+++-++-+--+++++-+++-+++++----+--+--++---+--+-+-+++-++-+--++---+---+-----+++ ++----+-+-+++-++------++++-+-+++-+++---++----+----+----++---++++---+-+----++++++--+---++-+----+-++-++-+---+--+-+-+++-++-+--+++++-+++-+++++---++--+--++---+--+-+-+++-++-+--++---+---+-----+++- +----+-+-+++-++------++++-+-+++-+++---++----+--++--+----++---++++---+-+----++++++--+---++-+----++-++-+---+--+-+-+++-++-+--+++++-+++-+++++---+---+--++---+--+-+-+++-++-+--++---+---+-----+++-+ +---+-+-+++-++------++++-+-+++-+++---++----+--+--+--+----++---++++---+-+----++++++--+---++-+---+-++-+---+--+-+-+++-++-+--+++++-+++-+++++---+-+-+--++---+--+-+-+++-++-+--++---+---+-----+++-+- +--+-+-+++-++------++++-+-+++-+++---++----+--+----+--+----++---++++---+-+----++++++--+---++-+---++-+---+--+-+-+++-++-+--+++++-+++-+++++---+-+++--++---+--+-+-+++-++-+--++---+---+-----+++-+-- +-+-+-+++-++------++++-+-+++-+++---++----+--+------+--+----++---++++---+-+----++++++--+---++-+-++-+---+--+-+-+++-++-+--+++++-+++-+++++---+-++---++---+--+-+-+++-++-+--++---+---+-----+++-+--+ ++-+-+++-++------++++-+-+++-+++---++----+--+--------+--+----++---++++---+-+----++++++--+---++-++-+---+--+-+-+++-++-+--+++++-+++-+++++---+-++-+-++---+--+-+-+++-++-+--++---+---+-----+++-+--+- +-+-+++-++------++++-+-+++-+++---++----+--+----++----+--+----++---++++---+-+----++++++--+---++--+---+--+-+-+++-++-+--+++++-+++-+++++---+-++-++++---+--+-+-+++-++-+--++---+---+-----+++-+--+-- ++-+++-++------++++-+-+++-+++---++----+--+----+--+----+--+----++---++++---+-+----++++++--+---+++---+--+-+-+++-++-+--+++++-+++-+++++---+-++-++-+---+--+-+-+++-++-+--++---+---+-----+++-+--+--+ +-++-++-+---+++++-+++-+++--++-+--+---+-+-++-+++-+++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+-+----+--+----++---++++---+-+----++++++--+---+-+----+--+----++---+++-+++-+-++++------++-+++-+ +++-++-+---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++-+----+--+----++---++++---+-+----++++++--+---+----+--+----++---+++-+++-+-++++------++-+++-+- ++-++-+---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-++-++-+----+--+----++---++++---+-+----++++++--+------+--+----++---+++-+++-+-++++------++-+++-+-+ +-++-+---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++--++-+----+--+----++---++++---+-+----++++++--+----+--+----++---+++-+++-+-++++------++-+++-+-+- +++-+---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++----++-+----+--+----++---++++---+-+----++++++--+--+--+----++---+++-+++-+-++++------++-+++-+-+-- ++-+---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++-++---++-+----+--+----++---++++---+-+----++++++---+--+----++---+++-+++-+-++++------++-+++-+-+--- +-+---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++-++-+---++-+----+--+----++---++++---+-+----++++++-+--+----++---+++-+++-+-++++------++-+++-+-+---- ++---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++-++---+---++-+----+--+----++---++++---+-+----++++++--+----++---+++-+++-+-++++------++-+++-+-+----+ +---+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++-++-++--+---++-+----+--+----++---++++---+-+----+++++-+----++---+++-+++-+-++++------++-+++-+-+----+- +--+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+++-++-+-++--+---++-+----+--+----++---++++---+-+----+++++----++---+++-+++-+-++++------++-+++-+-+----+-- +-+++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-----+--+-++-----+---+-----+++-+--+--+-+++-++-+-++++--+---++-+----+--+----++---++++---+-+----+++----++---+++-+++-+-++++------++-+++-+-+----+--+ ++++++-+++-+++--++-+--+---+-+-++-+++--++-++-+-----+--+-++-----+---+-----+++-+--+--+-+++-++-+-+-++++--+---++-+----+--+----++---++++---+-+----++---++---+++-+++-+-++++------++-+++-+-+----+--+- +++++-+++-+++--++-+--+---+-+-++-+++--++-++-+---+-+--+-++-----+---+-----+++-+--+--+-+++-++-+-+--+++++--+---++-+----+--+----++---++++---+-+----+--++---+++-+++-+-++++------++-+++-+-+----+--+-- ++++-+++-+++--++-+--+---+-+-++-+++--++-++-+---+++--+-++-----+---+-----+++-+--+--+-+++-++-+-+---++++++--+---++-+----+--+----++---++++---+-+-----++---+++-+++-+-++++------++-+++-+-+----+--+--- +++-+++-+++--++-+--+---+-+-++-+++--++-++-+---+++--+-++-----+---+-----+++-+--+--+-+++-++-+-+---+-++++++--+---++-+----+--+----++---++++---+-+---++---+++-+++-+-++++------++-+++-+-+----+--+---- ++-+++-+++--++-+--+---+-+-++-+++--++-++-+---++++-+-++-----+---+-----+++-+--+--+-+++-++-+-+---+---++++++--+---++-+----+--+----++---++++---+-+--+---+++-+++-+-++++------++-+++-+-+----+--+----+ +-+++-+++--++-+--+---+-+-++-+++--++-++-+---++++++-++-----+---+-----+++-+--+--+-+++-++-+-+---+-----++++++--+---++-+----+--+----++---++++---+-+----+++-+++-+-++++------++-+++-+-+----+--+----++ ++++-+++--++-+--+---+-+-++-+++--++-++-+---+++++--++-----+---+-----+++-+--+--+-+++-++-+-+---+--+----++++++--+---++-+----+--+----++---++++---+-+--+++-+++-+-++++------++-+++-+-+----+--+----++- +++-+++--++-+--+---+-+-++-+++--++-++-+---+++++-+++-----+---+-----+++-+--+--+-+++-++-+-+---+--+-+----++++++--+---++-+----+--+----++---++++---+--+++-+++-+-++++------++-+++-+-+----+--+----++-- ++-+++--++-+--+---+-+-++-+++--++-++-+---+++++-+++-----+---+-----+++-+--+--+-+++-++-+-+---+--+-+-+----++++++--+---++-+----+--+----++---++++---++++-+++-+-++++------++-+++-+-+----+--+----++--- +-+++--++-+--+---+-+-++-+++--++-++-+---+++++-+++-----+---+-----+++-+--+--+-+++-++-+-+---+--+-+++-+----++++++--+---++-+----+--+----++---++++---++-+++-+-++++------++-+++-+-+----+--+----++---+ ++++--++-+--+---+-+-++-+++--++-++-+---+++++-+++-----+---+-----+++-+--+--+-+++-++-+-+---+--+-++--+-+----++++++--+---++-+----+--+----++---++++--+-+++-+-++++------++-+++-+-+----+--+----++---++ +++--++-+--+---+-+-++-+++--++-++-+---+++++-+++-+---+---+-----+++-+--+--+-+++-++-+-+---+--+-++----+-+----++++++--+---++-+----+--+----++---++++--+++-+-++++------++-+++-+-+----+--+----++---+++ ++--++-+--+---+-+-++-+++--++-++-+---+++++-+++-++--+---+-----+++-+--+--+-+++-++-+-+---+--+-++------+-+----++++++--+---++-+----+--+----++---+++++++-+-++++------++-+++-+-+----+--+----++---+++- +--++-+--+---+-+-++-+++--++-++-+---+++++-+++-+++-+---+-----+++-+--+--+-+++-++-+-+---+--+-++----+---+-+----++++++--+---++-+----+--+----++---+++++-+-++++------++-+++-+-+----+--+----++---+++-+ +-++-+--+---+-+-++-+++--++-++-+---+++++-+++-+++-+---+-----+++-+--+--+-+++-++-+-+---+--+-++-----++---+-+----++++++--+---++-+----+--+----++---+++-+-++++------++-+++-+-+----+--+----++---+++-++ +++-+--+---+-+-++-+++--++-++-+---+++++-+++-+++-----+-----+++-+--+--+-+++-++-+-+---+--+-++-----++++---+-+----++++++--+---++-+----+--+----++---+-+-++++------++-+++-+-+----+--+----++---+++-+++ ++-+--+---+-+-++-+++--++-++-+---+++++-+++-+++--+--+-----+++-+--+--+-+++-++-+-+---+--+-++-----+-++++---+-+----++++++--+---++-+----+--+----++---+-++++------++-+++-+-+----+--+----++---+++-+++- +-+--+---+-+-++-+++--++-++-+---+++++-+++-+++--++-+-----+++-+--+--+-+++-++-+-+---+--+-++-----+---++++---+-+----++++++--+---++-+----+--+----++---++++------++-+++-+-+----+--+----++---+++-+++-+ ++--+---+-+-++-+++--++-++-+---+++++-+++-+++--++-+-----+++-+--+--+-+++-++-+-+---+--+-++-----+-----++++---+-+----++++++--+---++-+----+--+----++-++++------++-+++-+-+----+--+----++---+++-+++-+- +--+---+-+-++-+++--++-++-+---+++++-+++-+++--++-+-----+++-+--+--+-+++-++-+-+---+--+-++-----+---+---++++---+-+----++++++--+---++-+----+--+----+++++------++-+++-+-+----+--+----++---+++-+++-+-+ +-+---+-+-++-+++--++-++-+---+++++-+++-+++--++-+-----+++-+--+--+-+++-++-+-+---+--+-++-----+---+-+---++++---+-+----++++++--+---++-+----+--+----+++------++-+++-+-+----+--+----++---+++-+++-+-++ ++---+-+-++-+++--++-++-+---+++++-+++-+++--++-+-----+++-+--+--+-+++-++-+-+---+--+-++-----+---+--++---++++---+-+----++++++--+---++-+----+--+----+------++-+++-+-+----+--+----++---+++-+++-+-+++ +---+-+-++-+++--++-++-+---+++++-+++-+++--++-+--+--+++-+--+--+-+++-++-+-+---+--+-++-----+---+----++---++++---+-+----++++++--+---++-+----+--+---------++-+++-+-+----+--+----++---+++-+++-+-++++ +--+-+-++-+++--++-++-+---+++++-+++-+++--++-+--+--+++-+--+--+-+++-++-+-+---+--+-++-----+---+------++---++++---+-+----++++++--+---++-+----+--+-------++-+++-+-+----+--+----++---+++-+++-+-++++- +-+-+-++-+++--++-++-+---+++++-+++-+++--++-+--+--+++-+--+--+-+++-++-+-+---+--+-++-----+---+--------++---++++---+-+----++++++--+---++-+----+--+-----++-+++-+-+----+--+----++---+++-+++-+-++++-- ++-+-++-+++--++-++-+---+++++-+++-+++--++-+--+---++-+--+--+-+++-++-+-+---+--+-++-----+---+-----+----++---++++---+-+----++++++--+---++-+----+--+---++-+++-+-+----+--+----++---+++-+++-+-++++--- +-+-++-+++--++-++-+---+++++-+++-+++--++-+--+---++-+--+--+-+++-++-+-+---+--+-++-----+---+-----+++----++---++++---+-+----++++++--+---++-+----+----++-+++-+-+----+--+----++---+++-+++-+-++++---- ++-++-+++--++-++-+---+++++-+++-+++--++-+--+---+--+--+--+-+++-++-+-+---+--+-++-----+---+-----+++-+----++---++++---+-+----++++++--+---++-+----+--++-+++-+-+----+--+----++---+++-+++-+-++++----- +-++-+++--++-++-+---+++++-+++-+++--++-+--+---+-++--+--+-+++-++-+-+---+--+-++-----+---+-----+++---+----++---++++---+-+----++++++--+---++-+----+++-+++-+-+----+--+----++---+++-+++-+-++++------ +++-+++--++-++-+---+++++-+++-+++--++-+--+---+-+---+--+-+++-++-+-+---+--+-++-----+---+-----+++-++--+----++---++++---+-+----++++++--+---++-+----+-+++-+-+----+--+----++---+++-+++-+-++++------+ ++-+++--++-++-+---+++++-+++-+++--++-+--+---+-+-+-+--+-+++-++-+-+---+--+-++-----+---+-----+++-+--+--+----++---++++---+-+----++++++--+---++-+----+++-+-+----+--+----++---+++-+++-+-++++------++ +-+++--++-++-+---+++++-+++-+++--++-+--+---+-+-+++--+-+++-++-+-+---+--+-++-----+---+-----+++-+----+--+----++---++++---+-+----++++++--+---++-+--+++-+-+----+--+----++---+++-+++-+-++++------++- ++++--++-++-+---+++++-+++-+++--++-+--+---+-+-++---+-+++-++-+-+---+--+-++-----+---+-----+++-+--+---+--+----++---++++---+-+----++++++--+---++-+-++-+-+----+--+----++---+++-+++-+-++++------++-+ +++--++-++-+---+++++-+++-+++--++-+--+---+-+-++-+-+-+++-++-+-+---+--+-++-----+---+-----+++-+--+-----+--+----++---++++---+-+----++++++--+---++-++-+-+----+--+----++---+++-+++-+-++++------++-++ ++--++-++-+---+++++-+++-+++--++-+--+---+-+-++-+++-+++-++-+-+---+--+-++-----+---+-----+++-+--+--+----+--+----++---++++---+-+----++++++--+---++--+-+----+--+----++---+++-+++-+-++++------++-+++ +--++-++-+---+++++-+++-+++--++-+--+---+-+-++-+++-+++-++-+-+---+--+-++-----+---+-----+++-+--+--+-+----+--+----++---++++---+-+----++++++--+---+++-+----+--+----++---+++-+++-+-++++------++-+++- +-++-++-+---+++++-+++-+++++--+-++-+++-+-+--+---++++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+-++++-++-++++--+++---+---+-+----++++++--+---+-+-+----+--+----++---++++---+-+----++++++--+---+ +++-++-+---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+-++++-++-++++--+++---+---+-+----++++++--+---+-+++-+----+--+----++---++++---+-+----++++++--+--- ++-++-+---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--++++++-++-++++--+++---+---+-+----++++++--+---+-+--++-+----+--+----++---++++---+-+----++++++--+-- +-++-+---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--++++++-++-++++--+++---+---+-+----++++++--+---+-+-+--++-+----+--+----++---++++---+-+----++++++--+- +++-+---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+++-++-++-++++--+++---+---+-+----++++++--+---+-+-++---++-+----+--+----++---++++---+-+----++++++--+ ++-+---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+++-++-++-++++--+++---+---+-+----++++++--+---+-+-++++---++-+----+--+----++---++++---+-+----++++++-- +-+---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+++-++-++-++++--+++---+---+-+----++++++--+---+-+-++++-+---++-+----+--+----++---++++---+-+----++++++- ++---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+++-++-++-++++--+++---+---+-+----++++++--+---+-+-++++---+---++-+----+--+----++---++++---+-+----++++++ +---+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+++-++-++-++++--+++---+---+-+----++++++--+---+-+-++++-++--+---++-+----+--+----++---++++---+-+----+++++ +--+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+++-++-+--++++--+++---+---+-+----++++++--+---+-+-++++-++++--+---++-+----+--+----++---++++---+-+----++++ +-+++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-----+--+-++--+++-+++-+++++---+-++-++--+++-++-+-+++++--+++---+---+-+----++++++--+---+-+-++++-++-+++--+---++-+----+--+----++---++++---+-+----+++ ++++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+-----+--+-++--+++-+++-+++++---+-++-++--+++-++-+-+-+++--+++---+---+-+----++++++--+---+-+-++++-++-+++++--+---++-+----+--+----++---++++---+-+----++ +++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+---+-+--+-++--+++-+++-+++++---+-++-++--+++-++-+-+--++--+++---+---+-+----++++++--+---+-+-++++-++-+++++++--+---++-+----+--+----++---++++---+-+----+ ++++-+++-+++++--+-++-+++-+-+--+---+-++-++-+---+++--+-++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+++---+---+-+----++++++--+---+-+-++++-++-+++++++++--+---++-+----+--+----++---++++---+-+---- +++-+++-+++++--+-++-+++-+-+--+---+-++-++-+---+++--+-++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+++---+---+-+----++++++--+---+-+-++++-++-++++-++++++--+---++-+----+--+----++---++++---+-+--- ++-+++-+++++--+-++-+++-+-+--+---+-++-++-+---++++-+-++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+++---+---+-+----++++++--+---+-+-++++-++-++++---++++++--+---++-+----+--+----++---++++---+-+-- +-+++-+++++--+-++-+++-+-+--+---+-++-++-+---++++++-++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+++---+---+-+----++++++--+---+-+-++++-++-++++-----++++++--+---++-+----+--+----++---++++---+-+- ++++-+++++--+-++-+++-+-+--+---+-++-++-+---+++++--++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+++---+---+-+----++++++--+---+-+-++++-++-++++--+----++++++--+---++-+----+--+----++---++++---+-+ +++-+++++--+-++-+++-+-+--+---+-++-++-+---+++++-+++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+-+---+---+-+----++++++--+---+-+-++++-++-++++--+++----++++++--+---++-+----+--+----++---++++---+- ++-+++++--+-++-+++-+-+--+---+-++-++-+---+++++-+++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+-+---+---+-+----++++++--+---+-+-++++-++-++++--+++-+----++++++--+---++-+----+--+----++---++++---+ +-+++++--+-++-+++-+-+--+---+-++-++-+---+++++-+++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+-++--+---+-+----++++++--+---+-+-++++-++-++++--+++-+-+----++++++--+---++-+----+--+----++---++++--- ++++++--+-++-+++-+-+--+---+-++-++-+---+++++-+++--+++-+++-+++++---+-++-++--+++-++-+-+---+--+-++--+---+-+----++++++--+---+-+-++++-++-++++--+++---+-+----++++++--+---++-+----+--+----++---++++-- +++++--+-++-+++-+-+--+---+-++-++-+---+++++-+++-++++-+++-+++++---+-++-++--+++-++-+-+---+--+-++--+---+-+----++++++--+---+-+-++++-++-++++--+++-----+-+----++++++--+---++-+----+--+----++---++++- ++++--+-++-+++-+-+--+---+-++-++-+---+++++-+++-++++-+++-+++++---+-++-++--+++-++-+-+---+--+-++--+---+-+----++++++--+---+-+-++++-++-++++--+++---+---+-+----++++++--+---++-+----+--+----++---++++ +++--+-++-+++-+-+--+---+-++-++-+---+++++-+++-++++-+++-+++++---+-++-++--+++-++-+-+---+--+-++--++--+-+----++++++--+---+-+-++++-++-++++--+++---+-+---+-+----++++++--+---++-+----+--+----++---+++ ++--+-++-+++-+-+--+---+-++-++-+---+++++-+++-++++-+++-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+-+----++++++--+---+-+-++++-++-++++--+++---+--++---+-+----++++++--+---++-+----+--+----++---++ +--+-++-+++-+-+--+---+-++-++-+---+++++-+++-++++++++-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+-+----++++++--+---+-+-++++-++-++++--+++---+---+++---+-+----++++++--+---++-+----+--+----++---+ +-+-++-+++-+-+--+---+-++-++-+---+++++-+++-+++++-++-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+-+----++++++--+---+-+-++++-++-++++--+++---+---+++++---+-+----++++++--+---++-+----+--+----++--- ++-++-+++-+-+--+---+-++-++-+---+++++-+++-+++++--+-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+++----++++++--+---+-+-++++-++-++++--+++---+---+--++++---+-+----++++++--+---++-+----+--+----++-- +-++-+++-+-+--+---+-++-++-+---+++++-+++-+++++--+-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+++----++++++--+---+-+-++++-++-++++--+++---+---+-+--++++---+-+----++++++--+---++-+----+--+----++- +++-+++-+-+--+---+-++-++-+---+++++-+++-+++++--+-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+++----++++++--+---+-+-++++-++-++++--+++---+---+-+----++++---+-+----++++++--+---++-+----+--+----++ ++-+++-+-+--+---+-++-++-+---+++++-+++-+++++--+-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+++-+--++++++--+---+-+-++++-++-++++--+++---+---+-+--+---++++---+-+----++++++--+---++-+----+--+----+ +-+++-+-+--+---+-++-++-+---+++++-+++-+++++--+-+++++---+-++-++--+++-++-+-+---+--+-++--+++-+++-++-++++++--+---+-+-++++-++-++++--+++---+---+-+---++---++++---+-+----++++++--+---++-+----+--+---- ++++-+-+--+---+-++-++-+---+++++-+++-+++++--+-++-++---+-++-++--+++-++-+-+---+--+-++--+++-+++-+++++++++--+---+-+-++++-++-++++--+++---+---+-+-----++---++++---+-+----++++++--+---++-+----+--+--- +++-+-+--+---+-++-++-+---+++++-+++-+++++--+-++-++---+-++-++--+++-++-+-+---+--+-++--+++-+++-+++++++++--+---+-+-++++-++-++++--+++---+---+-+----+--++---++++---+-+----++++++--+---++-+----+--+-- ++-+-+--+---+-++-++-+---+++++-+++-+++++--+-++-++---+-++-++--+++-++-+-+---+--+-++--+++-+++-+++++++++--+---+-+-++++-++-++++--+++---+---+-+----++---++---++++---+-+----++++++--+---++-+----+--+- +-+-+--+---+-++-++-+---+++++-+++-+++++--+-++-+++--+-++-++--+++-++-+-+---+--+-++--+++-+++-+++++-+++--+---+-+-++++-++-++++--+++---+---+-+----+++----++---++++---+-+----++++++--+---++-+----+--+ ++-+--+---+-++-++-+---+++++-+++-+++++--+-++-+++--+-++-++--+++-++-+-+---+--+-++--+++-+++-+++++--++--+---+-+-++++-++-++++--+++---+---+-+----+++++----++---++++---+-+----++++++--+---++-+----+-- +-+--+---+-++-++-+---+++++-+++-+++++--+-++-+++-++-++-++--+++-++-+-+---+--+-++--+++-+++-+++++---+--+---+-+-++++-++-++++--+++---+---+-+----+++++-+----++---++++---+-+----++++++--+---++-+----+- ++--+---+-++-++-+---+++++-+++-+++++--+-++-+++-+--++-++--+++-++-+-+---+--+-++--+++-+++-+++++---+--+---+-+-++++-++-++++--+++---+---+-+----++++++--+----++---++++---+-+----++++++--+---++-+----+ +--+---+-++-++-+---+++++-+++-+++++--+-++-+++-+-+++-++--+++-++-+-+---+--+-++--+++-+++-+++++---+--+---+-+-++++-++-++++--+++---+---+-+----++++++-+--+----++---++++---+-+----++++++--+---++-+---- +-+---+-++-++-+---+++++-+++-+++++--+-++-+++-+-+-+-++--+++-++-+-+---+--+-++--+++-+++-+++++---+-++---+-+-++++-++-++++--+++---+---+-+----++++++---+--+----++---++++---+-+----++++++--+---++-+--- ++---+-++-++-+---+++++-+++-+++++--+-++-+++-+-+---++--+++-++-+-+---+--+-++--+++-+++-+++++---+-++---+-+-++++-++-++++--+++---+---+-+----++++++--+--+--+----++---++++---+-+----++++++--+---++-+-- +---+-++-++-+---+++++-+++-+++++--+-++-+++-+-+--+++--+++-++-+-+---+--+-++--+++-+++-+++++---+-++---+-+-++++-++-++++--+++---+---+-+----++++++--+----+--+----++---++++---+-+----++++++--+---++-+- +--+-++-++-+---+++++-+++-+++++--+-++-+++-+-+--+-+--+++-++-+-+---+--+-++--+++-+++-+++++---+-++-+-+-+-++++-++-++++--+++---+---+-+----++++++--+------+--+----++---++++---+-+----++++++--+---++-+ +-+-++-++-+---+++++-+++-+++++--+-++-+++-+-+--+----+++-++-+-+---+--+-++--+++-+++-+++++---+-++-+++-+-++++-++-++++--+++---+---+-+----++++++--+---+----+--+----++---++++---+-+----++++++--+---++- ++-++-++-+---+++++-+++-+++++--+-++-+++-+-+--+----+++-++-+-+---+--+-++--+++-+++-+++++---+-++-++--+-++++-++-++++--+++---+---+-+----++++++--+---+-+----+--+----++---++++---+-+----++++++--+---++ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_236.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_236.txt new file mode 100644 index 000000000..e065f73a5 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_236.txt @@ -0,0 +1,236 @@ ++++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++---++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++---++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+-----+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++- +-+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++-++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++---++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+--------+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+ +--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--+++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+--------+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+- ++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--+++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+-------++-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+-- +++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+-------++-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+ ++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+-------++++++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+- +++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+--++++--+---++-+---+-+++-+-+++-+----++--++-+----+-------++++++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+ +-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+--++++--+---++-+---+-+++-+-+++-+----++--++-+----+-------++++-+-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-++ ++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+++--+---++-+---+-+++-+-+++-+----++--++-+----+-------++++-+-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++ +++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++-++---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+++--+---++-+---+-+++-+-+++-+----++--++-+----+-------++++-+++--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++- ++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+++---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-+--+---++-+---+-+++-+-+++-+----++--++-+----+-------++++-+++--+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+ +++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+---+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-+--+---++-+---+-+++-+-+++-+----++--++-+----+-------++++-++++-+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+- +-++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+----+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++-+---++-+---+-+++-+-+++-+----++--++-+----+-------++++-++++-+++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+-- +--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++-+---++-+---+-+++-+-+++-+----++--++-+----+-------++++-++++--++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+ ++--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++----+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++-----++-+---+-+++-+-+++-+----++--++-+----+-------++++-++++--++-++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--++ +-+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++--+--++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++-----++-+---+-+++-+-+++-+----++--++-+----+-------++++-++++--+--++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++ +--+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+-++-+---+-+++-+-+++-+----++--++-+----+-------++++-++++--+--++----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++- +---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++-++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+-++-+---+-+++-+-+++-+----++--++-+----+-------++++-++++--+---+----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-+ ++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--+-+---+-+++-+-+++-+----++--++-+----+-------++++-++++--+---+----+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++ +++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--+-+---+-+++-+-+++-+----++--++-+----+-------++++-++++--+---++---+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++- +-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--+++---+-+++-+-+++-+----++--++-+----+-------++++-++++--+---++---+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++-- ++-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--+++---+-+++-+-+++-+----++--++-+----+-------++++-++++--+---++-+-+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++--- +-+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++--+-+++-+-+++-+----++--++-+----+-------++++-++++--+---++-+-+----+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++---- +--+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++--+-+++-+-+++-+----++--++-+----+-------++++-++++--+---++-+------+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+ +---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++-+++-+-+++-+----++--++-+----+-------++++-++++--+---++-+------+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+- ++---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++-+++-+-+++-+----++--++-+----+-------++++-++++--+---++-+---+--+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+-- +-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-+++++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++++-+-+++-+----++--++-+----+-------++++-++++--+---++-+---+--+-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+--- ++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-+++-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++++-+-+++-+----++--++-+----+-------++++-++++--+---++-+---+-++-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+---- +++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+-------++++-++++--+---++-+---+-++-+-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+ ++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++-+---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+-------++++-++++--+---++-+---+-++++-+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+- +++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+--+---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-+++-+++-+----++--++-+----+-------++++-++++--+---++-+---+-+++--+---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+ +-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---++---+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++--+++-+----++--++-+----+-------++++-++++--+---++-+---+-+++-++---+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+- ++-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+------+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-++++-+----++--++-+----+-------++++-++++--+---++-+---+-+++-+----+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+ +-+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+----+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+-++-+----++--++-+----+-------++++-++++--+---++-+---+-+++-+-+--+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+- +--+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+--+-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+--+-+----++--++-+----+-------++++-++++--+---++-+---+-+++-+-++-+-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+-- +---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-++-++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+----+----++--++-+----+-------++++-++++--+---++-+---+-+++-+-++++-++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+--- ++---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++--++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---++----++--++-+----+-------++++-++++--+---++-+---+-+++-+-+++--++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+ +-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-----++--++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+++++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+- ++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-+---++--++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+-+++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-+ +++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++--++--++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+--++--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++ ++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-+++-++--++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+---+--++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-+++ +++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++----++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++++--++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+------++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++ +-++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++-+--++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+----+-++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++- +--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++----++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+----++++--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++-- ++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--+-++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+----++-+--+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--+ +++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+----+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++++-+----+-------++++-++++--+---++-+---+-+++-+-+++-+----++----+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++ +-++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--+-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++-+-+----+-------++++-++++--+---++-+---+-+++-+-+++-+----++--+-+-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++- +--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-++-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++---+----+-------++++-++++--+---++-+---+-+++-+-+++-+----++--+++-++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++-- ++--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++--++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--++----+-------++++-++++--+---++-+---+-+++-+-+++-+----++--++--++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+ +-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-----+-------++++-++++--+---++-+---+-+++-+-+++-+----++--++-+++++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+- ++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-+---+-------++++-++++--+---++-+---+-+++-+-+++-+----++--++-+-+++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-+ +++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++--+-------++++-++++--+---++-+---+-+++-+-+++-+----++--++-+--++-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++ ++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++-++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-+++-+-------++++-++++--+---++-+---+-+++-+-+++-+----++--++-+---+-++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-+++ +++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++--++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-+++++-------++++-++++--+---++-+---+-+++-+-+++-+----++--++-+-----++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++ +-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++--------++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+++++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++- ++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-+------++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+-+++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-+ +++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++-----++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+--++-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++ ++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-+++----++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+---+-+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-+++ +++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+--++-+---+-++---+--++++-++++-++-+---+-++++--++--+-++++-++++---++++-++++--+---++-+---+-+++-+-+++-+----++--++-+----+-----+--+-+++-+--+++-++----+----+-+-+---+-++++--++--+-++++-++++ ++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+----++++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++------+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++--- +--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+----++-+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++----+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++---- +-+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+----++---+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++----- ++-+++-+--+++-++----+----+--+-+++-+----++--++-+----+----++--+--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++-+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++------ +-+++-+--+++-++----+----+--+-+++-+----++--++-+----+----++--+++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++------- ++++-+--+++-++----+----+--+-+++-+----++--++-+----+----++--+-+++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+ +++-+--+++-++----+----+--+-+++-+----++--++-+----+----++--+-+++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+- ++-+--+++-++----+----+--+-+++-+----++--++-+----+----++--+-++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++--+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+-- +-+--+++-++----+----+--+-+++-+----++--++-+----+----++--+-++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+++-+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+--- ++--+++-++----+----+--+-+++-+----++--++-+----+----++--+-+++-++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+++-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+---- +--+++-++----+----+--+-+++-+----++--++-+----+----++--+-+++-++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+ +-+++-++----+----+--+-+++-+----++--++-+----+----++--+-+++-+-++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+- ++++-++----+----+--+-+++-+----++--++-+----+----++--+-+++-+---++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+-+--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-+ +++-++----+----+--+-+++-+----++--++-+----+----++--+-+++-+--+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++ ++-++----+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++----++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++- +-++----+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++-- +++----+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++-+----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--+ ++----+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-+---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++----+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++ +----+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-+++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+---+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++- +---+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+---+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++-- +--+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++---++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++--- +-+----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++---- ++----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++-----+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+ +----+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++----+--+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+- +---+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++----+----+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+-+ +--+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++----+--+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+-++ +-+--+-+++-+----++--++-+----+----++--+-+++-+--+++-++----+----+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+-+++ ++--+-+++-+----++--++-+----+----++--+-+++-+--+++-++----+----+-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+-+++- +--+-+++-+----++--++-+----+----++--+-+++-+--+++-++----+----+++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++-+-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+-+++-+ +-+-+++-+----++--++-+----+----++--+-+++-+--+++-++----+----+-+++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+-+++-+- ++-+++-+----++--++-+----+----++--+-+++-+--+++-++----+----+--++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+--++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-------+----+-++--++----+-+++-+-+ +-+++-+----++--++-+----+----++--+-+++-+--+++-++----+----+--+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-+-+---+-++---+--++++-++++-------+----+-++--++----+-+++-+-++ ++++-+----++--++-+----+----++--+-+++-+--+++-++----+----+--+-+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-+-+---+-++---+--++++-++++-------+----+-++--++----+-+++-+-+++ +++-+----++--++-+----+----++--+-+++-+--+++-++----+----+--+-+-+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+--++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-+++---+-++---+--++++-++++-------+----+-++--++----+-+++-+-+++- ++-+----++--++-+----+----++--+-+++-+--+++-++----+----+--+-++--+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+-+-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-+++---+-++---+--++++-++++-------+----+-++--++----+-+++-+-+++-+ +-+----++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++--+-++---+--++++-++++-------+----+-++--++----+-+++-+-+++-+- ++----++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++--+-++---+--++++-++++-------+----+-++--++----+-+++-+-+++-+-- +----++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+++++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++-++---+--++++-++++-------+----+-++--++----+-+++-+-+++-+--- +---++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+-+-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+++++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++-++---+--++++-++++-------+----+-++--++----+-+++-+-+++-+---+ +--++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+--++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+++--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-+++++---+--++++-++++-------+----+-++--++----+-+++-+-+++-+---+- +-++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+---+++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+--+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-+++++---+--++++-++++-------+----+-++--++----+-+++-+-+++-+---+-+ +++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+----++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++---+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++----+--++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++ ++--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+----+-++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++-+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++----+--++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++- +--++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+-+--++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++-- +-++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+----++-+--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--+--++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+-+--++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++--- +++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+---++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+----++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++---+ ++-+----+----++--+-+++-+--+++-++----+----+--+-+++-+----++--+-++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+-++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+----++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++---+- +-+----+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-++-+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---+++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++---+-- ++----+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++--+---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---+++++-++++-------+----+-++--++----+-+++-+-+++-+---+-++---+--+ +----+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-+++++---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-++-++++-------+----+-++--++----+-+++-+-+++-+---+-++---+--++ +---+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+-+-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-+++---+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-++-++++-------+----+-++--++----+-+++-+-+++-+---+-++---+--+++ +--+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+--++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++--+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-+--++++-------+----+-++--++----+-+++-+-+++-+---+-++---+--++++ +-+----++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+---+++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-+-+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-+--++++-------+----+-++--++----+-+++-+-+++-+---+-++---+--++++- ++----++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-+-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+++-------+----+-++--++----+-+++-+-+++-+---+-++---+--++++-+ +----++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+++-------+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++ +---++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+-+-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++++++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-+-------+----+-++--++----+-+++-+-+++-+---+-++---+--++++-+++ +--++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+--++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-+-------+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++ +-++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+---+++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++-+----+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++------+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++- +++--+-+++-+--+++-++----+----+--+-+++-+----++--++-+----+----++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++----+----+-++--++----+-+++-+-+-++++-++++--+---++-+---+-++------+----+-++--++----+-+++-+-+++-+---+-++---+--++++-++++-- +++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++------+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++ ++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-+-+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++----+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++- +----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++-- +---+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++-+-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++-+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++--- +--+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++---++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++---- +-+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++---++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-+++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+ ++----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-+++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+- +----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++--+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+-- +---++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+-+-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+++-+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+--- +--++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+---+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+++-++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+---- +-++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+---+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+----+ +++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+------++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+----+- ++-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+----+-++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+-+--++----+-+++-+--+----+----++-+++--+-+++-+--++----+----+-+ +-+++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++++--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+----++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++----+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++ ++++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++----++----+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++- +++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+--++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++-+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++--++----+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++-- ++--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++-++-++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++-+----+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--+ +--+-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++++++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++-----+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++----+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++ +-+-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++-+++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+---+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++- ++-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++--++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+---+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++-- +-+++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++--++-+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--+++-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++--- ++++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++--+--+---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--+++++-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++---- +++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++--+-++---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++--+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++----+ ++-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++--+-++---+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+--+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++----+- +-+---+-+---+-++++--++--+-++++-+++++++----+----++-+++--+-+++--+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+----+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+++-+--+----+----++-+++--+-+++-+--++----+----+-++--++----+-+ ++---+-+---+-++++--++--+-++++-+++++++----+----++-+++--+-+++--+-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+--+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-+--+----+----++-+++--+-+++-+--++----+----+-++--++----+-++ +---+-+---+-++++--++--+-++++-+++++++----+----++-+++--+-+++-++-+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+----+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+--+----+----++-+++--+-+++-+--++----+----+-++--++----+-+++ +--+-+---+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+--+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++--+----+----++-+++--+-+++-+--++----+----+-++--++----+-+++- +-+-+---+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+--+-+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++--+----+----++-+++--+-+++-+--++----+----+-++--++----+-+++-+ ++-+---+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+----+----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--++----+----+-++--++----+-+++-+- +-+---+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+---++----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-+----+----++-+++--+-+++-+--++----+----+-++--++----+-+++-+-- ++---+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-----+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+----+----++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+ +---+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+-+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+------+----++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+- +--+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+----+----++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+-- +-+-++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+-----+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+--+----++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+--- ++-++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+-------+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-++----++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+---- +-++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----++---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-----++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+ +++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+----++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+--+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++---++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+- ++++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-+--++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+--+-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+++--++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+-- +++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++-++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+---++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++-++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+--- ++--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-+++++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----+++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+++-+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+---- +--++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-+++++-+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----+++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+-+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+----+ +-++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+----++ +++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++---++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--+++++--+-+++-+--++----+----+-++--++----+-+++-+--+----+----++- ++--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--+++--+-+++-+--++----+----+-++--++----+-+++-+--+----+----++-+ +--+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--+++--+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-++++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--+--+-+++-+--++----+----+-++--++----+-+++-+--+----+----++-++ +-+-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++---+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++-++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+---+-+++-+--++----+----+-++--++----+-+++-+--+----+----++-+++ ++-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++---+-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++---++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+-+-+++-+--++----+----+-++--++----+-+++-+--+----+----++-+++- +-++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++--++-+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+-+++-+--++----+----+-++--++----+-+++-+--+----+----++-+++-- +++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+--+++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+++-+--++----+----+-++--++----+-+++-+--+----+----++-+++--+ ++++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++++-+--++----+----+-++--++----+-+++-+--+----+----++-+++--+- +++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++----+----+-++--++----+-+++-+--+----+----++-+++--+-+ ++-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++-+--++----+----+-++--++----+-+++-+--+----+----++-+++--+-++ +-+++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++--+--++----+----+-++--++----+-+++-+--+----+----++-+++--+-+++ ++++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++--++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++++++--++----+----+-++--++----+-+++-+--+----+----++-+++--+-+++- +++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-+--+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++++--++----+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+ ++++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-++-+-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+-++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++-++----+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+- +++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-++++-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++++----+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+-- ++++----+----++-+++--+-+++-+---+-+---+-++++--++--+-++++-++++-++++-++++-+--++--++++-+---+-+-+----+----++-+++--+-+++-+--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++----+----+-++--++----+-+++-+--+----+----++-+++--+-+++-+--+ +-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++-- +++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-+++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--+-+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++- ++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-+++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++ +-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-+++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+++ ++---+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++--++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++ +---+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-+++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-+ +--+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+-+++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-+++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++- +-+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+--++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++ ++-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+++ +-++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++ +++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+--+ ++---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-+--++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+-- +---+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++-++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+- +--+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++-++--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+----++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++---+ +-+--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++--+--++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++--- ++--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++-----++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++-+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++-- +--++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+-++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++- +-++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+-++++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++-----+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-++ +++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--+++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+-+ ++++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--+++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+- +++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--+++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--+++-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+---+ ++-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--+++-+---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--+++++-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+--- +-++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--+++++---+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++--+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+-- +++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--++++----+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+--+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+- ++++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--++++-+--+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+----+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++-+ +++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--++++-++-+-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+--+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++- ++-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+----+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++++ +-+-+-+++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+---++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-+++ ++-+-+++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-++ +-+-+++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+-+ ++-+++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+---+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+--++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+- +-+++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+--++++-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++--++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+---+ ++++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+---+++-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+--- +++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+-+++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+---+++-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++-+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+-- ++-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-+++--+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+- +-+----++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++++-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+-++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++-+ ++----++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-++++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++- +----++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-++++++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-+-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++++ +---++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+-+--+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-+++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+++ +--++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-++--+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-+++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--++ +-++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+---+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-+++++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++--+ +++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+---+---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++-++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++-- ++--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--++---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++---++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++- +--++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----+++-++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--++---++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--++ +-++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++--++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+-+--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+--+ +++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++--++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+-- ++-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-+++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+----++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+- +-+----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-+++-+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+---+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++-+ ++----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++----+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++--+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+---+++--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++- +----+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+--+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++--+---+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+---++--+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++++ +---+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+--+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++------+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-+++ +--+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+--+----+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++------+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+-++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-++ +-+----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+-------+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+-+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+--+++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++-+ ++----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+-------+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+-+-++--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++- +----+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+--+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+---++--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++++ +---+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+--+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+---++--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-+-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++++ +--+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+--+++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-+++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+++ +-+-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+---++++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----+--++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-+++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--++ ++-++-+---+-++---+--++++-++++-+-+-+++-+----++--++-+----+----+++++-++++-+--++--++++-+---+-+---+-+++-+--+++-++----+----++-++++-++++-+--++--++++-+---+-++-++++-++++--+---++-+---+-++-++++-++++-+--++--++++-+---+-++++-+---+-++---+--++++-++++--+ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_244.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_244.txt new file mode 100644 index 000000000..933b8bd60 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_244.txt @@ -0,0 +1,244 @@ +++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--++---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++ ++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+--+++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-++ +-+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+ +--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-++++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+- ++--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+--+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+ +-+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-++-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+- +--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+++-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++--+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+ ++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---+++++-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++---- +++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---+++-+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++--- +-++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+-++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++--+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++-- +--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---+---+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++- ++--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++-------+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++ +-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++--+----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--++ ++-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++-++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+ +-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--+++++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++-- ++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--+-+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++- +++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+----+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++ ++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--+-++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+-+--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-+ +++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+----++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+- +-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++--++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+ +--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-+++-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+-- ++--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-+-+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+- +-+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+----++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+---+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+ +--+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+----++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---++--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++- +---+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++----+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++ +----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++--+-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------++ +-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++-++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+ ++-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+---------+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---+++++++++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------ +-+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-----+---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---+++++-+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++----- +--+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+----++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++--+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++---- +---+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+---+++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---+++---+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++--- +----+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+--++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++----+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++-- +-----+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-+++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---+-----+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++- +------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++ ++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+------++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+--+------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-++ +-+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+------++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+-++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+ +--+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+------++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-++++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+- +---+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+--+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++--+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+ +----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+--+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-+++-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+-- +-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-+-+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+- ++-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+---+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+ +-+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--++--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++- +--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+++++-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++---+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++ ++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+++-+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++-+-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--+ +++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++-- ++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---+-++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++- +++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++-----++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++ +-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++--+--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----++ ++-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++-++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+ +-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-+++++++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+---- ++-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-+++-+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+--- +-+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++-++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++--+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+-- +--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--+++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-+---+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+- ++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--+++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+ +++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+---++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-++----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+- +-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+--+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+ +--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---++-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++- ++--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+----+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++ +-+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+--+-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-++++++ +--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+-++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++ ++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--++---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-++++ +--++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++-++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+----+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+--- +---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-+++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-- ++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-+-+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+- +++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++---+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+ +-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--+++--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+- ++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--+-+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+ +++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++----+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+++-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++- +-++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++-+--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++ +--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---+++ ++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-+-++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+-++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---++ +++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+---++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++---+ +-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-++--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++--- ++-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+--+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++-- +-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----++-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++- ++-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++-----+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--++ +-+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++---+-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+--+ +--+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++--++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+-- +---+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++-+++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--+-++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+- +----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+----++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++-+ ++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++- +++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++---++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-++ +-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-++++++--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+-+ ++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-++++-+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+----++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+- +++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++--+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+----++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++---+ ++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-++---+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++--- +++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++-- ++++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++------+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++- +-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-+++++++-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+---------+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++++ ++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-+++++-+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-----+---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---+++++ +++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++--+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+----++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++++ ++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-+++---+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+---+++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---+++ +++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++----+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+--++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---++ ++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-+-----+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-+++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+---+ +++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+--- +-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-++++++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+------++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+-- ++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-++++-+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+------++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+- +++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++--+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+------++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++-+ ++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-++---+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+--+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++- +++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+--+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-++ ++++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++------+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+-+ +-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----+++-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+- ++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----+-+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++--+ +++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+------+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+++++-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++-- +-++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+---+--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+++-+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++- +--++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---++ +---++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+-+++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++---+ +----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++--- ++----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+--++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++-- +-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-++-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++- ++-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++--+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++++ +-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--+++-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-+++ ++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--+-+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++-++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-++ +++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++----+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--+++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+-+ +-++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++-+--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--+++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+- +--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+---++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+-+ ++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-+-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+- +++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++---++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+---+ +-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---+++--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+--- ++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---+-+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+-- +++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++-----+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+- +-++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++--+--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--++---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+---+ +-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+--++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++- ++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-----++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++ +++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+-+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-+---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-+ ++++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+---+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++- +-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-++--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++---++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++ ++-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++-+-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--+ +-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++-- ++-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+-++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++- +-+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++---++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+----++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++ +--+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++---++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+-+--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-+ +---+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++---++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+- +----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--++++--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+--++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+-+ ++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--++-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-++-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+- +++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++--++-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++--+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++----+ ++++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++---+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--+++++-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++---- +-+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++-+-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--+++-+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++--- +--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++--+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++-- ++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--+---+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++- +++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+------+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-++ +-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-+----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++-+ ++-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+----++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++- +-+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+-+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+------++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++++ +--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++-+-+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+----+-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-++++ ++--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++---+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+---++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+++ +-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+++---+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--+++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-++ ++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------++----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++-+ +++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------++++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++- ++++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-------+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++++ +-+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-----+-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-+++++ +--+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+----++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++++ +---+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+---+++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-+++ +----+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+--++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-++ +-----+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-+++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++-+ +------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++- ++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+------++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++++ +++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+-+-+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+----+-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-++++ ++++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+---+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+---++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+++ +-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+--+---+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+--+++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-++ ++-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+------+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++-+ +-+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-+------+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--++++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++- +--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++-++-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++---+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----++ ++--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++-+-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+----+ +-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+---- ++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+++-++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+--- +++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++--++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+-- +-++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++-+++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+---++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+- +--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----+++++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+-+ ++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----++-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-++----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+- +++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+----++-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+--+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++-+ ++++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+-----+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--++-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++- +-+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+---+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++---+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--++ +--+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+---+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++-+-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++--+ +---+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-+---+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++-- +----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--+-++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++- ++----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+----++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-++ +-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++-+-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-+--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++-+ ++-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++++---++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++- +-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-++++++++--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++---++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---++ ++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-++++++-+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++-+-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++---+ +++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++--- ++++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--++-++-++--++-+-+----++-+++++-++++++-+++++-++----+-+-++--++-++-- +----+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+----+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+ +-----+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+--+-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+-- +------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+-+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+- +-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-++++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+---+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+ ++-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+--+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-++--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++-- +-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-++-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++- ++-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++++--+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++ +-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---+++++-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--+ ++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---+++-+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++---++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+-- +++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---++--+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++---++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+- ++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++---+---+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++---++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+ +++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++-------+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--++++--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+- +-++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++--+----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--++-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+ +--++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--++-++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++--++-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++- +---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--+++++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++---+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++ ++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+--+-+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++-+-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--+++ +++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+----+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++ +-++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+-+--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--+ +--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++-+++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+-- ++--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-++--++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+--+-++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+- +-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-+++-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+----++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+ ++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+-+-+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+-+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+----- +++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---+---+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++-+-+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+---- +-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++---++--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++---+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+--- ++-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++----+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+++---+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-- +-+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++--+-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------++----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+- +--+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++++-++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+ +---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---+++++++++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------ ++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---+++++-+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+----- +++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++++--+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+---- ++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---+++---+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+--- +++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---++----+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+-- ++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---+-----+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+- +++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+---------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+ +-++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+--+------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+++------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+----- +--++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-+-++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+-+-+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+---- +---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++-++++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+---+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+--- ++---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-++--+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+--+---+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-- +-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-+++-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+------+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+- ++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+-+-+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-+------+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+ +++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--+---+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++-++-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++-- +-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++--++--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++- ++-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++---+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++ +-+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++-+-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+++ +--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---++++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++--++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++ ++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++---+-++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++-+++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-+ +++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++-----++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----+++++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+- +-++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++--+--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----++-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+ +--++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++++-++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+----++-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+- +---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-+++++++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+-----+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+ ++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-+++-+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+---+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++-- +++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-++--+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+---+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++- ++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-+---+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-+---+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++ +++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-+-----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--+ +-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+-++----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+-- ++-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------+--+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++-+-++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+- +-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-------++-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+++---++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+ ++-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+--------+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-++++--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++-- +-+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+------+-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-++-+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++- +--+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+-----++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++ +---+-+-++++---++--+-++-+---++++++---+-++-+--++---++++-+-+----+++-+-+----+++--++-+--+-+++------+++-+--+-++--+++----+-+-+++-+--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--+++--+--++--+-+-++++--+-----+------+-----+--++++-+-+--++--+--++ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_428.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_428.txt new file mode 100644 index 000000000..c31041318 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_428.txt @@ -0,0 +1,428 @@ ++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+- +-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+----+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+ ++-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+-+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+- +-+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++------++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+-+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+-- +--+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+--++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+ +---+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++--++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+---+----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--++ +----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+--------+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++ ++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----+---+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++- +++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++--+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++-- ++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----+++-+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++--- +++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++--++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----+++++----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++---- +-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+ ++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-+---++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+- +++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++--++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+-- +-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----+++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-+++++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++--++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+--- ++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----+++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-+++++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-+++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+---- +++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++-++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++-++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-+++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----+ ++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----+++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--+-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-+++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++ +++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----+--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--+-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++- +-++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+------+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++-- +--++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+----+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++--+-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--+ +---++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++----++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++ +----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+---+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+-+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++----++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++- +-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+-+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-+ ++-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+---+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++ +-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--++-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++- ++-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+--++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+ +-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+- ++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----+-+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+-+----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-+ +++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+------+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++ +-++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+---+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++- +--++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++--+++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--++--+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-- +---++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++-+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++--- +----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++---- ++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----++-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++----- +++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----++-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+ ++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----++++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+- +++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++--++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-+++----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++--++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+ +-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+- ++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-+---+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+-+++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-+ +++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++--+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+--++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++ ++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-+++-+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+---+-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-+++ +++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++--+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-+++++---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+-----+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++ +-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-++++++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++----+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----++++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++- ++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-++++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+--+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+-++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+ +++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-++-+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+--+-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-++ ++++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+--+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-++++-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+----+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++ +-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-++-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++--+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---++-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++- ++-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+--+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-++-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+--+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+ +-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--++--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+--++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-++--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+- ++-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++----+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+---+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+ +-+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+-+---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-+-+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+- +--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-++++++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+-----+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-+++++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+-- ++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-++++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+--+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++-++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+ +++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--++-+-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++--+-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--++ ++++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+--+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--++++-++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++----+--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++ +-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--++--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++--++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---++--++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++- ++-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++----++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+++--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+---++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+ +-+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+-+--+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-+-++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+- +--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+----+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++++-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+-- ++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++-----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--+-+++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++-+-----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--+ +++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+----------+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--+++++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++-------+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++ +-++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+--------+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-++++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+----+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++- +--++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+------+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++--+++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--++---+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-- +---++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+----+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++---++-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++--+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++--- +----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+--+-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++----+-+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--++++-+-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++---- +-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-++-+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++------+----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--++++++-+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++----- ++-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++--+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----++----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++--+++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+ +-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++++++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-----+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-++++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+- ++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+---+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+-++-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+ +++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-++--+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+--+-+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-++ ++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+-+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++ +++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++- +-++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+--++-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+ +--++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+-+-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-++ +---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+-++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++ ++---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++- +-+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+-+-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-+ +--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++ ++--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++- +-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++---++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+ ++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++--++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+- +++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+-++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+-- ++++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+--- +-+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++-+--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---+ +--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++ ++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-+-+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++- +++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-+---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++-- +-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++---+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+ ++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-++--+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+- +++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+-+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+-- ++++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+--- +-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++----++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+ ++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+++---++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+- +++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++--++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+-- ++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+-++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+--- +++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+---- +-++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+-+-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----+ +--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++ ++--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++- +-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+ ++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++++----+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+- +++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++---+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-- ++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++--+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+--- +++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+-+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+---- ++++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+----- +-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+--++---++++-++++--++--+--+-+--+++++-++-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+--+----++++-++-++++-----+-+-++----+++-+----+---+-+-++---+-++--+++++-+----+++-++-+---++--+---+----++-+-----+-+--+++----+----++--++-++-+-++-----+-+-++++-+++-+-+--+++-+--++-----+-+++-+++-++-+---++--+---+----++-+-----+ ++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+--+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-++-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+-- +--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----++++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+------+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+ +-+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++-+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+- ++++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++---+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++---+++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-++---+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-- +++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+--+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++--++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++--+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+--- ++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--++---+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-+-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-++++-+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+---- +----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-++++++-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+----- +---+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++-+----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++--++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+ +--+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++--++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+- +-+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++---+++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-+-++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+-+----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-+ ++----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+------+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++ +----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--+---+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++- +---++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+-+-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++--+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++-- +--++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+--++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--+++-+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++--- +-++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--+++++---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++---- +++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----+-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++----+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+ ++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----+++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+--+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+- +--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----+++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-++-+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+-- +-++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++-++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+------++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-++++--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+--- +++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++---++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+----++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++---++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+ ++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--+--++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+---+--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-+-++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+- +-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++---++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+----+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+-- +++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++-+---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--+ ++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-++++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++-----+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++ +-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-+++-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+--+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++- ++-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++--+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--++-+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++-- +-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-++-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++--+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--++++-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++--- +++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+--+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----+++--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++--++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+ ++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+- +-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-----+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+-+-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-+ +----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++--++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+---+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++ +---+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-----+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--++++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++- +--+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++------++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+-++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+ +-+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++--------++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+--+----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-++ ++--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+-------+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++ +--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++- +-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+-+++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+--+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++-- ++----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++--+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++--- +----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++++-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++---- +---+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+-+-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+ +--+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+--++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+- +-+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+---+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+ ++---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-++ +---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++ +--+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+-+-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-++++ +-+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+--++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++ ++-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++- +-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++-- ++-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--+ +-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++ +++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++- ++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-+-+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++----++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+ +---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+- +--+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++-+--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-++-++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+-- +-+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++--++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+--- ++-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---+ +-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++ +++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++- ++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-+-+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+ +--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----+++---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+- +-+++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++-+--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+ ++++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-------+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+- +++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+-++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+-- ++++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--++--++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+---+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+--- +++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++---++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+------+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+ ++-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--++++----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+----+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+- +-+---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+--+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+-- ++---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++--+-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+--- +---+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++++-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+---- +--+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+-+-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++-+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+ +-+---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+--++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+- ++---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+-+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+ +---+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-++ +--+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+--++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+--++++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+----++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++ +-+--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+----++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+-+++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-++---++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++- ++--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+------++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+++--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++--++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++-- +--+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---++---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+--+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-++++-++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++--- +-+-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+---+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++++-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++---- ++-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+----+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+-+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++-+-+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----+ +-+++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--++--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++---+-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++ ++++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+--+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--++-+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++- +++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+--+-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+ ++--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-++++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+--+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-++-----++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+- +--++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-++++++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++---+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+------++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+ +-++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++-+--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+------++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+- +++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++----+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+---++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-- ++-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-+-++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+---++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+--- +-+++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+---++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+---- ++++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++--++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-++++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--+++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+----- +++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-++--++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--+++++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----+ ++-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-++++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+-++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++-++-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++ +-++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-++++++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--+-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----+++ +++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++--+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+++++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--+-++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++ ++++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+++++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++- +++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+++-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-+ ++--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+-++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++ +--+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++- +-+-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+-+++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-+ ++-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++----++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++ +-+++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--++--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-+++ ++++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+--+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++---++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++ +++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++++--++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++- ++++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++-++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++-- +++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++----+--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++--- ++-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++-+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++---- +-++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++++++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+--+-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+ +++--+++----+----++--++-++-+-++-----+--+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++--+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-+-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++----+-+++++--++-+---++-+-+---+----+-+++----++-+-+-----++++-++-++++----+- ++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+--+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++ +-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----++++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++- +++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+ ++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-+---+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+---+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++---+++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-++ +++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++--+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++--++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++ ++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-+++-+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-------+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-+-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-++++ +----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-+++++-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+---------+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++ +---+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++--++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++- +--+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++--++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+ +-+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++---+----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-+-++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+- ++--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++--------+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+-- +--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+---+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++--++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--+ +-+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+---+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++--+-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++ ++----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+---+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++---++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--+++ +----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--++---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++-----++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++ +---+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++- +--+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+-++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+ +-+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+--+++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-++ ++++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+------++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++ +++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+-++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+----++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++- ++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----++-++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+---++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+---+--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-+ +++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+----+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++ ++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--+----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++- +-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-++++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++-- ++-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+ +-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--++ ++--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++--+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++ +--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----+++--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++- +-++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-+-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+ +++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--+-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-+++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-----+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+- ++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--+-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+-- +++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--+++++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++---++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-----+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+ ++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--+++++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+---++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+- +---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--+++++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+-- +--+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++--+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+--- +-+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++--+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+ ++-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+--+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+- +-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-++++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++--+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+-- +++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+--+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++++-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+--- ++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+ +++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+- ++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-+++----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+ +-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++---++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+--++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-++ ++++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++---++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+----++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++ +++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+-++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+----+-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-++++ ++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-++++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++ +-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-++++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++- ++-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++---+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++-- +-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--+ ++--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++---+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++ +--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--++-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++- +-+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++----++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+ ++++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+- +++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-++-++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+-- ++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+--- +-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---+ ++--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++---+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++ +--++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--++-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++- +-++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+ +++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+---+++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----+++---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+- ++-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+ +-----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++++-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-------+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+- +----+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-+-++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-++-++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+-- +---+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++---++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++--++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+---+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+--- +--+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++---++++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+------+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+ +-+-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++----+++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-+----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+----+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+- ++-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+--+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+-- +-++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----++-+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++--+--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+--- +++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+--+-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++++--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+---- ++++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++--+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+ +++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++-+-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++-+-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+- ++---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++-----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+-- +---+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+--------++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+ +--+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+--++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+------++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+- +-+--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+----++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+----++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-- ++--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++-----++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+------++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+--++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+--- +--+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+-++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+----+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+---++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+---- +-+-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+-++-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+------+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+-++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+----- ++-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-++-+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----+ +-+++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----+++--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++--+-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++ ++++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++--+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++++-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++- +++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+-++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++-++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+ ++--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-++++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+- +--++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-++++-++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-+ +-++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++ +++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++---+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++- ++-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-+ +-+++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++---++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++ ++++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++--++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++- +++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--+-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-++++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++-- ++-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-++----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--+ +-++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--+++++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-----+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++ +++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++----+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++--+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++---+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++- ++++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-+--+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++--+-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+++--+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++-- +++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++-+----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++---++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++-+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++--- ++--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++---- +--+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+------+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+ +-+-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+----+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+- ++-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++----+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+-- +-+++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+---+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++--+++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+--- ++++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+-----+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++++++++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+---- +++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++++++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+ ++++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----++ +++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++--+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----++++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++--++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++ ++-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-++++-+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++-++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+-++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++- +-++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-++++++-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++-- +++-++++----+--+----+++++-+-+--++++---+-++++-+++-+-+--+++-+--++-----+-++++---+--+-+++--++-+++-++++--+-+++++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++-+-++-++--++----+----+++--+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-++-+++++-+--++++-+++-++--+++-+--+---+---+-+++++--++-+---++-+-+---+----+--+-----++-+-++-++--++----+----+++--+ +-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-+++-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+- +++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++------+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+ ++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++---- +---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-++---+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+---+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++--- +--++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++-++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++--+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++-- +-++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++--+-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-++++-+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-------+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++- +++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++----+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-++++++-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+---------+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++ ++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++--++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-+++ +++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++ ++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---+++-++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+-+----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-+ +-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+------+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++- +++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-+++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--+---+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++--++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++ ++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-+++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++--+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++--+-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-+ +++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-+++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--+++-+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++---++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++- ++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-+++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--+++++---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++-----++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++ +--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-+++++++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++----+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----+++ +-++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++-++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+--+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+-++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++ +++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--+-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-++-+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+--+++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----+ ++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--+-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-++++--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+----- +--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++---++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+-++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+---- +-+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++-+--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-+-++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+---++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+--- ++--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++----+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-- +--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+-+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++-+---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--+----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+- +-+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+-+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++-----+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+ ++-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+--+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+- +-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--++-+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+ ++--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+--+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--++++-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++- +--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-++--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++--++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++ +-+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+---+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-+-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----+ ++++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+---+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+-+-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-+++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++---- +++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--++---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+---+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++--- ++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--++---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--++++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++---++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++-- +++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++--++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+-++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+---++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++- ++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--++++-++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+--+-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++ +-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+----+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+++ ++-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++ +-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+ ++----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-++++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++- +----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++ +---+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+++ +--+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+-------++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++ +-+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+-------++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-+++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-+ ++---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+-------++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+--++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++- +---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+--++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+----++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++ +--+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+--++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+----+-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-++ +-+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+--++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+ ++-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+- +-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+ ++-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+- +-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-++-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++---+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+ +++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+--+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--++-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++-- ++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++- +---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++ +--+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++-+--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-++ +-+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++----+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+ ++-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++----+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+- +-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---++-+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++---+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+ +++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+--+-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--++-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++-- ++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++- +--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++ +-+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++-+++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----+ ++++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+----- +++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-++-++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+---- ++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--++-++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++--++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+--- +++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+-- ++-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-+----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+- +-+---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++++-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++-+ ++---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++- +---+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++++ +--+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+-+---++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+++ +-+---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+-----++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---++ ++---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+-----++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+--+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-+++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+---+ +---+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+-++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+--+++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+--- +--+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+-++++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+-++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+-- +-+--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+--+++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-++--++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+- ++--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+--+ +--+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---++--+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+-- +-+-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+---+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++-+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+- ++-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+---+-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++---+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++-+ +-+++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--++-+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++- ++++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+--+-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+++ +++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-++-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--++ ++--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-++-+++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++--+ +--++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-++++++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+--+--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+-+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++-- +-++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++-++++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-++--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+---+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++- +++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--+++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-++--++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-++ ++-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--+++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++-++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+-+--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++-+ +-+++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--+++----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-++++++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++- ++++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-----+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-++++++--++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--+-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+++ +++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+---+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++---++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-++ ++-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-++--+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++---++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++-++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++-+ +-++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++---++++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++- +++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--+-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++++ ++++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-+--+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+++-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--+++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+++ +++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++-+----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+-+-++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--+++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--++ ++--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+---++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--+++++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+--+ +--+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+++++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+-- +-+-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+-+++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-+-++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+- ++-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++----++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+--++---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++-+ +-+++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+---+---++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++- ++++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+-------++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++++ +++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----+--++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++-+-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++++ ++++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++-++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++--++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+++ +++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----+++++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---+++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-++ ++-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-++++-++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----+++++--+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+-+ +-+-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++++-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++---+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---+++++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+- ++-++---++++-++++--++--+--+-+--+++++-+-+----+---+-+-++---+-++--+++++-+---+---+--+-+++--++-+++-++++--+-+++++-+-+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-+---++++--+-+-+++++----+--+----++++-+-+-----+-++----+---+--++---+-++-+++-+++-+-----++--+-+++--+-+-+++-++++-++-+++++--+-+--+--++--++++-++++---++--+++++-+--++++-+++-++--+++-+--+---++++-+-----++--+-+++--+-+-+++-++++-++++----++-+-+-----++++-++-++++----+-+ diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_52.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_52.txt new file mode 100644 index 000000000..72b32aa83 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_52.txt @@ -0,0 +1,52 @@ ++-+--++++--+-+---++++++---++-+--++--+-++----+--+---- +-+-+--++++--+-+---++++++--+++-+--++--+--+----+--+--- ++-+-+--++++----+---++++++--+++-+--++--+--+----+--+-- +-+-+-+--++++----+---+++++++-+++-+--++-----+----+--+- +--+-+-+--+++++---+---+++++-+-+++-+--++-----+----+--+ ++--+-+-+--+++++---+---++++--+-+++-+--+++----+----+-- +++--+-+-+--+++++---+---++++--+-+++-+--+-+----+----+- ++++--+-+-+--+++++---+---++++--+-+++-+----+----+----+ +++++--+-+-+--+++++---+---+-++--+-+++-+-+--+----+---- +-++++--+-+-+-++++++---+-----++--+-+++-+-+--+----+--- +--++++--+-+-+-++++++---+--+--++--+-+++---+--+----+-- ++--++++--+-+---++++++---+--+--++--+-+++---+--+----+- +-+--++++--+-+---++++++---++-+--++--+-++----+--+----+ +-+++------++++-+--++++--+--++++-++-++++++-+--++--+-+ ++-+++------++-+-+--++++--++-++++-++-++++++-+--++--+- +++-+++------++-+-+--++++--++-++++-++-++-+++-+--++--+ ++++-+++-------+-+-+--++++-+++-++++-++-++-+++-+--++-- +-+++-+++-------+-+-+--++++++++-++++-++--+-+++-+--++- +--+++-+++----+--+-+-+--+++-++++-++++-++--+-+++-+--++ +---+++-+++---++--+-+-+--+++-++++-++++-++--+-+++-+--+ +----+++-+++--+++--+-+-+--+++-++++-++++-++--+-+++-+-- +-----+++-+++-++++--+-+-+---++-++++-++++-++--+-+++-+- +------+++-+++-++++--+-+-+-+-++-++++-+++--++--+-+++-+ ++------+++-++--++++--+-+-+++-++-++++-+++--++--+-+++- +++------+++-++--++++--+-+-+++-++-++++-+-+--++--+-+++ ++++------+++--+--++++--+-+++++-++-++++-+-+--++--+-++ +--+-++--++-+-+----+--+----+-+--++++--+--+++------+++ +---+-++--++-+-+----+--+----+-+--++++--++-+++------++ ++---+-++--++---+----+--+--+-+-+--++++--++-+++------+ +-+---+-++--++---+----+--+--+-+-+--++++-+++-+++------ ++-+---+-++--+----+----+--+--+-+-+--++++-+++-+++----- +++-+---+-++--+----+----+--+--+-+-+--+++--+++-+++---- +-++-+---+-++--+----+----+-++--+-+-+--++---+++-+++--- +--++-+---+-++--+----+----++++--+-+-+--+----+++-+++-- ++--++-+---+-++--+----+----++++--+-+-+-------+++-+++- +++--++-+---+--+--+----+----++++--+-+-+-------+++-+++ +-++--++-+---+--+--+----+----++++--+-+-++------+++-++ ++-++--++-+------+--+----+-+--++++--+-+-++------+++-+ +-+-++--++-+------+--+----+-+--++++--+-++++------+++- +-++++-++-++++--+-++--++-+-+---++++++---+-+--++++--+- ++-++++-++-+++---+-++--++-+-+---++++++---+-+--++++--+ +++-++++-++-+++---+-++--++---+---++++++-+-+-+--++++-- ++++-++++-++-+-+---+-++--++---+---++++++-+-+-+--++++- +++++-++++-++-+-+---+-++--++---+---+++++--+-+-+--++++ +-++++-++++-++++-+---+-++--++---+---+++++--+-+-+--+++ ++-++++-++++-+-++-+---+-++-+++---+---+++++--+-+-+--++ +++-++++-++++---++-+---+-++++++---+---+++++--+-+-+--+ +-++-++++-+++++--++-+---+-++++++---+---+++++--+-+-+-- ++-++-++++-+++++--++-+---+-++++++---+----++++--+-+-+- +++-++-++++-++-++--++-+---+-++++++---+----++++--+-+-+ ++++-++-++++-++-++--++-+-----++++++---+-+--++++--+-+- +++++-++-++++--+-++--++-+-----++++++---+-+--++++--+-+ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/hadamard_data/hadamard_92.txt b/gptqmodel/exllamav3/util/hadamard_data/hadamard_92.txt new file mode 100644 index 000000000..34b9695c6 --- /dev/null +++ b/gptqmodel/exllamav3/util/hadamard_data/hadamard_92.txt @@ -0,0 +1,92 @@ ++++-+++-+------+-+++-+++++---++-+-++-+-++---+++-++-++--++++++--++-++-++---+---+-++-+---+---+ +++++-+++-+------+-+++-+++++---++-+-++-+-++---+-+-++-++--++++++--++-+++++---+---+-++-+---+--- ++++++-+++-+------+-+++-+++++---++-+-++-+-++---+-+-++-++--++++++--++-+-+++---+---+-++-+---+-- +-+++++-+++-+------+-+++-+++++---++-+-++-+-++--++-+-++-++--++++++--++---+++---+---+-++-+---+- ++-+++++-+++-+------+-++--+++++---++-+-++-+-++--++-+-++-++--++++++--++---+++---+---+-++-+---+ +++-+++++-+++-+------+-+---+++++---++-+-++-+-+++-++-+-++-++--++++++--++---+++---+---+-++-+--- ++++-+++++-+++-+------+-+---+++++---++-+-++-+-+++-++-+-++-++--++++++---+---+++---+---+-++-+-- +-+++-+++++-+++-+------+++---+++++---++-+-++-+--++-++-+-++-++--++++++---+---+++---+---+-++-+- ++-+++-+++++-+++-+-------++---+++++---++-+-++-+--++-++-+-++-++--++++++---+---+++---+---+-++-+ +-+-+++-+++++-+++-+-----+-++---+++++---++-+-++-+--++-++-+-++-++--++++++---+---+++---+---+-++- +--+-+++-+++++-+++-+-----+-++---+++++---++-+-++++--++-++-+-++-++--++++-+---+---+++---+---+-++ +---+-+++-+++++-+++-+---+-+-++---+++++---++-+-++++--++-++-+-++-++--++++-+---+---+++---+---+-+ +----+-+++-+++++-+++-+--++-+-++---+++++---++-+-++++--++-++-+-++-++--++++-+---+---+++---+---+- +-----+-+++-+++++-+++-+--++-+-++---+++++---++-++++++--++-++-+-++-++--+-++-+---+---+++---+---+ +------+-+++-+++++-+++-++-++-+-++---+++++---++-++++++--++-++-+-++-++--+-++-+---+---+++---+--- ++------+-+++-+++++-+++--+-++-+-++---+++++---++-++++++--++-++-+-++-++--+-++-+---+---+++---+-- +-+------+-+++-+++++-++++-+-++-+-++---+++++---+--++++++--++-++-+-++-++--+-++-+---+---+++---+- ++-+------+-+++-+++++-++++-+-++-+-++---+++++---+--++++++--++-++-+-++-+---+-++-+---+---+++---+ +++-+------+-+++-+++++-+-++-+-++-+-++---+++++--++--++++++--++-++-+-++-+---+-++-+---+---+++--- ++++-+------+-+++-+++++---++-+-++-+-++---+++++--++--++++++--++-++-+-++-+---+-++-+---+---+++-- +-+++-+------+-+++-+++++---++-+-++-+-++---++++++-++--++++++--++-++-+-+--+---+-++-+---+---+++- ++-+++-+------+-+++-+++++---++-+-++-+-++---++++++-++--++++++--++-++-+----+---+-++-+---+---+++ +++-+++-+------+-+++-+++++---++-+-++-+-++---+++-++-++--++++++--++-++-++---+---+-++-+---+---++ +---+++--+-+--+-+--+++--+++-+++-+------+-+++-++--+++-+++-+--+-+++-+++-+-++-++--++++++--++-++- +----+++--+-+--+-+--+++-++++-+++-+------+-+++-+---+++-+++-+--+-+++-+++-+-++-++--++++++--++-++ +-----+++--+-+--+-+--++++++++-+++-+------+-+++-+---+++-+++-+--+-+++-+++-+-++-++--++++++--++-+ ++-----+++--+-+--+-+--++-+++++-+++-+------+-+++++---+++-+++-+--+-+++-+++-+-++-++--++++++--++- +++-----+++--+-+--+-+--++-+++++-+++-+------+-+++++---+++-+++-+--+-+++--++-+-++-++--++++++--++ ++++-----+++--+-+--+-+--++-+++++-+++-+------+-+-+++---+++-+++-+--+-++++-++-+-++-++--++++++--+ +-+++-----+++--+-+--+-+-+++-+++++-+++-+------+-+-+++---+++-+++-+--+-++++-++-+-++-++--++++++-- +--+++-----+++--+-+--+-+-+++-+++++-+++-+------+++-+++---+++-+++-+--+-+-++-++-+-++-++--++++++- ++--+++-----+++--+-+--+-+-+++-+++++-+++-+------+++-+++---+++-+++-+--+---++-++-+-++-++--++++++ +-+--+++-----+++--+-+--+-+-+++-+++++-+++-+------+++-+++---+++-+++-+--++--++-++-+-++-++--+++++ ++-+--+++-----+++--+-+----+-+++-+++++-+++-+----+-+++-+++---+++-+++-+--++--++-++-+-++-++--++++ +-+-+--+++-----+++--+-+----+-+++-+++++-+++-+----+-+++-+++---+++-+++-+-+++--++-++-+-++-++--+++ +--+-+--+++-----+++--+-+----+-+++-+++++-+++-+----+-+++-+++---+++-+++-+++++--++-++-+-++-++--++ ++--+-+--+++-----+++--+------+-+++-+++++-+++-+-+--+-+++-+++---+++-+++-+++++--++-++-+-++-++--+ +-+--+-+--+++-----+++--+------+-+++-+++++-+++-+-+--+-+++-+++---+++-+++++++++--++-++-+-++-++-- ++-+--+-+--+++-----+++--+------+-+++-+++++-+++-+-+--+-+++-+++---+++-++-++++++--++-++-+-++-++- +-+-+--+-+--+++-----+++--+------+-+++-+++++-+++++-+--+-+++-+++---+++-+--++++++--++-++-+-++-++ +--+-+--+-+--+++-----++++-+------+-+++-+++++-+++++-+--+-+++-+++---+++-+--++++++--++-++-+-++-+ ++--+-+--+-+--+++-----++++-+------+-+++-+++++-+-+++-+--+-+++-+++---+++++--++++++--++-++-+-++- +++--+-+--+-+--+++-----++++-+------+-+++-+++++-+-+++-+--+-+++-+++---++-++--++++++--++-++-+-++ ++++--+-+--+-+--+++------+++-+------+-+++-+++++++-+++-+--+-+++-+++---++-++--++++++--++-++-+-+ +-+++--+-+--+-+--+++----+-+++-+------+-+++-+++++++-+++-+--+-+++-+++---++-++--++++++--++-++-+- +--+++--+-+--+-+--+++---++-+++-+------+-+++-+++-+++-+++-+--+-+++-+++---++-++--++++++--++-++-+ +-+--+--++------++--+--+++---+---+-++-+---+---++++-+++-+------+-+++-++---+++--+-+--+-+--+++-- ++-+--+--++------++--+--+++---+---+-++-+---+---++++-+++-+------+-+++-+----+++--+-+--+-+--+++- +-+-+--+--++------++--+--+++---+---+-++-+---+--+++++-+++-+------+-+++------+++--+-+--+-+--+++ +--+-+--+--++------++--+--+++---+---+-++-+---+--+++++-+++-+------+-++++-----+++--+-+--+-+--++ ++--+-+--+--++------++-----+++---+---+-++-+---++-+++++-+++-+------+-++++-----+++--+-+--+-+--+ +-+--+-+--+--++------++-+---+++---+---+-++-+---++-+++++-+++-+------+-++++-----+++--+-+--+-+-- +--+--+-+--+--++------++-+---+++---+---+-++-+--+++-+++++-+++-+------+--+++-----+++--+-+--+-+- ++--+--+-+--+--++------+--+---+++---+---+-++-+--+++-+++++-+++-+------+--+++-----+++--+-+--+-+ +++--+--+-+--+--++---------+---+++---+---+-++-++-+++-+++++-+++-+------+--+++-----+++--+-+--+- +-++--+--+-+--+--++-----+---+---+++---+---+-++--+-+++-+++++-+++-+------+--+++-----+++--+-+--+ +--++--+--+-+--+--++-----+---+---+++---+---+-++--+-+++-+++++-+++-+----+-+--+++-----+++--+-+-- +---++--+--+-+--+--++---+-+---+---+++---+---+-+---+-+++-+++++-+++-+----+-+--+++-----+++--+-+- +----++--+--+-+--+--++--++-+---+---+++---+---+-----+-+++-+++++-+++-+----+-+--+++-----+++--+-+ +-----++--+--+-+--+--++--++-+---+---+++---+---+-----+-+++-+++++-+++-+-+--+-+--+++-----+++--+- +------++--+--+-+--+--+++-++-+---+---+++---+---------+-+++-+++++-+++-+-+--+-+--+++-----+++--+ ++------++--+--+-+--+--+-+-++-+---+---+++---+--+------+-+++-+++++-+++-+-+--+-+--+++-----+++-- +++------++--+--+-+--+----+-++-+---+---+++---+--+------+-+++-+++++-+++-+-+--+-+--+++-----+++- +-++------++--+--+-+--+----+-++-+---+---+++---++-+------+-+++-+++++-++--+-+--+-+--+++-----+++ +--++------++--+--+-+--++---+-++-+---+---+++---++-+------+-+++-+++++-++--+-+--+-+--+++-----++ ++--++------++--+--+-+---+---+-++-+---+---+++--+++-+------+-+++-+++++-++--+-+--+-+--+++-----+ +-+--++------++--+--+-+---+---+-++-+---+---+++--+++-+------+-+++-++++++++--+-+--+-+--+++----- +--+--++------++--+--+-+---+---+-++-+---+---++++-+++-+------+-+++-++++-+++--+-+--+-+--+++---- ++--+--++------++--+--+-+---+---+-++-+---+---++++-+++-+------+-+++-+++--+++--+-+--+-+--+++--- +--+++-+++-+--+-+++-+++--+--+--++------++--+--++++---++-+-++-+-++---+++++-+++-+------+-+++-++ +---+++-+++-+--+-+++-++++-+--+--++------++--+--++++---++-+-++-+-++---+++++-+++-+------+-+++-+ ++---+++-+++-+--+-+++-++-+-+--+--++------++--+-+++++---++-+-++-+-++---+++++-+++-+------+-+++- +++---+++-+++-+--+-+++-+--+-+--+--++------++--+-+++++---++-+-++-+-++---+++++-+++-+------+-+++ ++++---+++-+++-+--+-+++-+--+-+--+--++------++----+++++---++-+-++-+-++-+-+++++-+++-+------+-++ +-+++---+++-+++-+--+-+++-+--+-+--+--++------++----+++++---++-+-++-+-++++-+++++-+++-+------+-+ ++-+++---+++-+++-+--+-++--+--+-+--+--++------+++---+++++---++-+-++-+-++++-+++++-+++-+------+- +++-+++---+++-+++-+--+-++--+--+-+--+--++------+++---+++++---++-+-++-+--+++-+++++-+++-+------+ ++++-+++---+++-+++-+--+-++--+--+-+--+--++-------++---+++++---++-+-++-++-+++-+++++-+++-+------ +-+++-+++---+++-+++-+--+-++--+--+-+--+--++-----+-++---+++++---++-+-++--+-+++-+++++-+++-+----- ++-+++-+++---+++-+++-+----++--+--+-+--+--++-----+-++---+++++---++-+-++--+-+++-+++++-+++-+---- +-+-+++-+++---+++-+++-+----++--+--+-+--+--++---+-+-++---+++++---++-+-+---+-+++-+++++-+++-+--- +--+-+++-+++---+++-+++-+----++--+--+-+--+--++--++-+-++---+++++---++-+-----+-+++-+++++-+++-+-- ++--+-+++-+++---+++-+++------++--+--+-+--+--++--++-+-++---+++++---++-+-----+-+++-+++++-+++-+- +-+--+-+++-+++---+++-+++------++--+--+-+--+--+++-++-+-++---+++++---++-------+-+++-+++++-+++-+ ++-+--+-+++-+++---+++-+++------++--+--+-+--+--+-+-++-+-++---+++++---+++------+-+++-+++++-+++- +++-+--+-+++-+++---+++-+++------++--+--+-+--+--+-+-++-+-++---+++++---+-+------+-+++-+++++-+++ ++++-+--+-+++-+++---+++--++------++--+--+-+--+-++-+-++-+-++---+++++---+-+------+-+++-+++++-++ +-+++-+--+-+++-+++---+++--++------++--+--+-+--+-++-+-++-+-++---+++++--++-+------+-+++-+++++-+ ++-+++-+--+-+++-+++---+++--++------++--+--+-+----++-+-++-+-++---+++++-+++-+------+-+++-+++++- +++-+++-+--+-+++-+++---+-+--++------++--+--+-+----++-+-++-+-++---+++++-+++-+------+-+++-+++++ ++++-+++-+--+-+++-+++-----+--++------++--+--+-++---++-+-++-+-++---+++++-+++-+------+-+++-++++ +-+++-+++-+--+-+++-+++--+--+--++------++--+--+-++---++-+-++-+-++---+++++-+++-+------+-+++-+++ \ No newline at end of file diff --git a/gptqmodel/exllamav3/util/memory.py b/gptqmodel/exllamav3/util/memory.py new file mode 100644 index 000000000..4ac6aa4ae --- /dev/null +++ b/gptqmodel/exllamav3/util/memory.py @@ -0,0 +1,228 @@ +from dataclasses import dataclass +from collections import deque +import torch +import gc +import sys +from pydantic import PydanticUserError + +# @lru_cache +# def init_pynvml(): +# pynvml.nvmlInit() + +# Try to make sure device is live for correct measurement of free VRAM +def touch_device(device: int): + d = torch.empty((32, 32), device = device, dtype = torch.float) + d = d @ d + d.add_(d) + + +# Touch device and measure VRAM (child process) +def touch_device_measure_vram(local_context: dict): + device = local_context["device"] + touch_device(device) + return torch.cuda.mem_get_info(device) + + +# Reserve byte amount on device +def set_memory_fraction_reserve( + reserve: int, + device: int +): + touch_device(device) + free, total = torch.cuda.mem_get_info(device) + fraction = (free - reserve) / total + fraction = max(0.01, fraction) + torch.cuda.set_per_process_memory_fraction(fraction, device = device) + + +# Reserve all but byte amount on device +def set_memory_fraction_use( + use: int, + device: int +): + touch_device(device) + free, total = torch.cuda.mem_get_info(device) + baseline = torch.cuda.memory_allocated(device) + fraction = min((baseline + use) / total, 1.0) + torch.cuda.set_per_process_memory_fraction(fraction, device = device) + + +# Un-reserve VRAM +def unset_memory_fraction(active_devices: list[int]): + for i in active_devices: + torch.cuda.set_per_process_memory_fraction(1.0, device = i) + + +# Free unused VRAM +def free_mem(): + gc.collect() + torch.cuda.empty_cache() + + +def list_gpu_tensors(min_size: int = 1, cuda_only: bool = True): + """ + Search the current process for referenced CUDA tensors and list them. + + :param min_size: + Ignore tensors smaller than this size, in megabytes + + :param cuda_only: + Only list CUDA tensors + """ + + import threading + import warnings + from tabulate import tabulate + + # Suppress FutureWarning from Torch every time we try to access certain objects + warnings.simplefilter(action = 'ignore', category = FutureWarning) + + @dataclass + class Result: + paths: list[str] + shape: tuple + dtype: torch.dtype + device: str + size: int + + results = {} + visited = set() + + # Helper function to filter and collect items + def collect(path, item): + nonlocal results + + # Only collect CUDA tensors + if not isinstance(item, torch.Tensor) or (cuda_only and not item.is_cuda): + return + + # Tensor size in MB, filter anything smaller than the minimum size + size = item.nelement() * item.element_size() // (1024**2) + if size < min_size: + return + + # Skip tensors in paths containing specific debug substrings + if any(x in path for x in [ + ".stderr.dbg.", + "dbg.value_resolve_thread_list", + "global_vars[", + "local_vars[", + "updated_globals[", + ]): + return + + # Adjust the path display for objects defined in __main__ + if ".__main__." in path: + path = path[path.find(".__main__.") + 10:] + + # If tensor is already recorded, just record the additional path + obj_id = id(item) + if obj_id in results and path not in results[obj_id].paths: + results[obj_id].paths.append(path) + else: + results[obj_id] = Result( + paths = [path], + shape = item.shape, + dtype = item.dtype, + device = str(item.device), + size = size + ) + + # Queue of items to scan recursively + queue = deque() + + # Collect items that are global variables, and add to the queue + for name, obj in globals().items(): + collect(name, obj) + queue.append((name, obj)) + + # Traverse each thread's frame stack, collecting items and queueing items + for thread_id, frame in sys._current_frames().items(): + prefix = "" + + # Skip the current frame for the current thread to avoid recursion issues + if thread_id == threading.get_ident(): + frame = frame.f_back + + # Collect/queue each local variable in the frame, extend the relative path prefix + # and walk the stack + while frame: + for name, obj in frame.f_locals.items(): + # We actually start three levels deep but want variables in the "current" frame + # (i.e. the frame of the function calling list_gpu_tensors) to have a prefix of "." + new_path = f"{prefix[2:]}.{name}" + collect(new_path, obj) + queue.append((name, obj)) + frame = frame.f_back + prefix += "." + + # Process the queue by examining attributes, dictionary entries, and sequence items + while queue: + path, obj = queue.popleft() + + # Iterate over entries in object with __dict__ attribute + try: + if hasattr(obj, '__dict__'): + for attr, value in obj.__dict__.items(): + new_path = f"{path}.{attr}" + collect(new_path, value) + if id(value) not in visited: + visited.add(id(value)) + queue.append((new_path, value)) + except PydanticUserError: + pass + + # If object is a dictionary, iterate through all its items + if isinstance(obj, dict): + try: + for key, value in obj.items(): + new_path = f"{path}['{key}']" + collect(new_path, value) + if id(value) not in visited: + visited.add(id(value)) + queue.append((new_path, value)) + except Exception: + pass + + # Same for list, tuple, set + if isinstance(obj, (list, tuple, set)): + for idx, item in enumerate(obj): + new_path = f"{path}[{idx}]" + collect(new_path, item) + if id(item) not in visited: + visited.add(id(item)) + queue.append((new_path, item)) + + # Sort tensors by descending size + items = list(results.values()) + items.sort(key = lambda x: -x.size) + + # Build output table, grouped by device + devices: dict[str, list] = {} + for v in items: + if v.device not in devices: + devices[v.device] = [] + dev = devices[v.device] + dev.append([ + v.size, + v.paths[0], + tuple(v.shape), + str(v.dtype).replace("torch.", "") + ]) + for p in v.paths[1:]: + dev.append([ + None, + " + " + p, + None, + None + ]) + + # Print tables to console + for k in sorted(devices.keys()): + print() + print("--------------") + print(f"| {k:10} |") + print("--------------") + print() + headers = ["size // MB", "path", "shape", "dtype"] + print(tabulate(devices[k], headers = headers, tablefmt = "github", intfmt=",")) diff --git a/gptqmodel/exllamav3/util/misc.py b/gptqmodel/exllamav3/util/misc.py new file mode 100644 index 000000000..91e51d6a2 --- /dev/null +++ b/gptqmodel/exllamav3/util/misc.py @@ -0,0 +1,137 @@ +import math +import threading +import time +import torch +import socket, contextlib +import weakref + +lock = threading.RLock() + +def synchronized(func): + def wrapper(*args, **kwargs): + with lock: + return func(*args, **kwargs) + return wrapper + +def align_to(value, alignment): + return int(math.ceil(value / alignment) * alignment) + + +class Timer: + """ + Context manager to record duration + """ + + def __enter__(self): + self.start_time = time.time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end_time = time.time() + self.interval = self.end_time - self.start_time + + +def cuda_sync_active(): + """ + Calling torch.cuda.synchronize() will create a CUDA context on CUDA:0 even if that device is not being used. + This function synchronizes only devices actively used by Torch in the current process. + """ + for device_id in range(torch.cuda.device_count()): + device = torch.device(f'cuda:{device_id}') + if torch.cuda.memory_allocated(device) > 0: + torch.cuda.synchronize(device) + + +def next_power_of_2(x): + return 1 if x == 0 else 2**(x - 1).bit_length() + + +def human_time(seconds: float) -> str: + seconds = round(seconds) + minutes = seconds // 60 + hours = minutes // 60 + minutes -= hours * 60 + if hours: + if minutes: + hs = "s" if hours > 1 else "" + ms = "s" if minutes > 1 else "" + return f"{hours} hour{hs}, {minutes} minute{ms}" + else: + hs = "s" if hours > 1 else "" + return f"{hours} hour{hs}" + elif minutes: + ms = "s" if minutes > 1 else "" + return f"{minutes} minute{ms}" + else: + return f"< 1 minute" + + +def first_not_none(*values): + return next((v for v in values if v is not None), None) + + +def ratio_split(d, weights, chunk_size = 128): + assert d % chunk_size == 0, "Total must be divisible by chunk size" + total_chunks = d // chunk_size + total_weight = sum(weights) + ideal_chunks = [total_chunks * w / total_weight for w in weights] + base_chunks = [int(c) for c in ideal_chunks] + remainder = total_chunks - sum(base_chunks) + residuals = [c - int(c) for c in ideal_chunks] + for i in sorted(range(len(residuals)), key = lambda i: -residuals[i])[:remainder]: + base_chunks[i] += 1 + final_alloc = [c * chunk_size for c in base_chunks] + assert sum(final_alloc) == d + return final_alloc + + +def find_free_port() -> int: + with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +class Cleanupper: + """ + Utility class to call cleanup functions at the end of the __main__ scope. Similar functionality to + atexit but called before Python starts tearing down objects/threads. + """ + + def __init__(self): + self.atexit_fns = [] + weakref.finalize(self, self._shutdown) + + def register_atexit(self, fn): + self.atexit_fns.append(fn) + + def unregister_atexit(self, fn): + if fn in self.atexit_fns: + self.atexit_fns.remove(fn) + + def _shutdown(self): + for fn in self.atexit_fns: + fn() + self.atexit_fns = [] + + +def set_process_priority_and_affinity(): + import psutil, os + import multiprocessing as mp + + p = psutil.Process(os.getpid()) + # Try to bump priority slightly. May need sudo (?) + try: + p.nice(psutil.ABOVE_NORMAL_PRIORITY_CLASS if os.name == "nt" else -5) + except PermissionError: + pass + except Exception as e: + pass + + # Pin to a core + # TODO: Pick an idle core automatically? + try: + p.cpu_affinity([0]) # pick an isolated/quiet core if possible + except AttributeError: + pass + except Exception as e: + pass diff --git a/gptqmodel/exllamav3/util/progress.py b/gptqmodel/exllamav3/util/progress.py new file mode 100644 index 000000000..c4b855e5c --- /dev/null +++ b/gptqmodel/exllamav3/util/progress.py @@ -0,0 +1,45 @@ +import sys +from rich.progress import Progress, BarColumn, TextColumn, TimeElapsedColumn, TimeRemainingColumn + +class ProgressBar: + + def __init__(self, text: str, count: int, transient: bool = True): + self.text = text + self.count = count + self.transient = transient + if self.text: + self.progress = Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(bar_width = None), + "[progress.percentage]{task.percentage:>3.0f}%", + TimeElapsedColumn(), + TimeRemainingColumn(), + transient = transient, + speed_estimate_period = 600.0, + ) + self.task_id = self.progress.add_task(text, total = count) + + def __enter__(self): + if self.text: + self.progress.start() + sys.stdout.flush() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.text: + if not self.transient: + self.progress.update(self.task_id, completed = self.count) + self.progress.stop() + + def update(self, value: int): + if self.text: + self.progress.update(self.task_id, completed = value) + sys.stdout.flush() + + def new_task(self, text: str, count: int): + self.text = text + self.count = count + if self.text: + self.progress.update(self.task_id, description = self.text, total = count, progress = 0) + + diff --git a/gptqmodel/exllamav3/util/tensor.py b/gptqmodel/exllamav3/util/tensor.py new file mode 100644 index 000000000..c5013e738 --- /dev/null +++ b/gptqmodel/exllamav3/util/tensor.py @@ -0,0 +1,210 @@ +from __future__ import annotations +import torch + +class SeqTensor: + + PAGE_SIZE = 256 + + tensor: torch.Tensor | None + seq_dim: int + seq_len: int + seq_cap: int + + def __init__( + self, + shape: tuple, + dtype: torch.dtype, + seq_dim: int, + device: torch.device = "cpu", + init_cap: int = -1 + ): + if seq_dim < 0: seq_dim = len(shape) + seq_dim + self.seq_dim = seq_dim + self.seq_len = 0 + if init_cap == -1: + init_cap = self.PAGE_SIZE + else: + init_cap = (init_cap // self.PAGE_SIZE + 1) * self.PAGE_SIZE + shape = list(shape) + shape[seq_dim] = self.seq_cap = init_cap + shape = tuple(shape) + # Lazily allocate inner Tensor object to avoid committing too much virtual memory, which can crash the + # process on Windows + # self.tensor = torch.empty(shape, dtype = dtype, device = device) + self.init_shape = shape + self.dtype = dtype + self.device = device + self.tensor = None + + def __len__(self): + return self.seq_len + + def __bool__(self): + return self.seq_len > 0 + + def _ensure_init(self): + if self.tensor is None: + self.tensor = torch.empty(self.init_shape, dtype = self.dtype, device = self.device) + + @staticmethod + def from_tensor(tensor: torch.Tensor, seq_dim: int): + s = SeqTensor(tensor.shape, tensor.dtype, seq_dim, tensor.device, init_cap = tensor.shape[seq_dim]) + s.append(tensor) + return s + + def clone(self, drop: int | None = None): + if drop and drop <= self.seq_len: + return SeqTensor.from_tensor(self.torch_slice(None, self.seq_len - drop), self.seq_dim) + else: + return SeqTensor.from_tensor(self.torch(), self.seq_dim) + + def clear(self): + self.seq_len = 0 + + def set(self, new_data: SeqTensor | torch.tensor | None = None): + self.clear() + self.append(new_data) + + def append(self, new_data: SeqTensor | torch.tensor | None): + self._ensure_init() + if new_data is None: return + if isinstance(new_data, SeqTensor): + new_data = new_data.torch() + new_len = new_data.shape[self.seq_dim] + end_pos = self.seq_len + new_len + if end_pos >= self.seq_cap: + new_cap = (end_pos // self.PAGE_SIZE + 1) * self.PAGE_SIZE + grow_shape = list(new_data.shape) + grow_shape[self.seq_dim] = new_cap - self.seq_cap + grow_shape = tuple(grow_shape) + grow_tensor = torch.empty(grow_shape, dtype = self.tensor.dtype, device = self.tensor.device) + self.tensor = torch.cat((self.tensor, grow_tensor), dim = self.seq_dim) + self.seq_cap = new_cap + s = self.tensor.narrow(self.seq_dim, self.seq_len, end_pos - self.seq_len) + s.copy_(new_data) + self.seq_len += new_len + + def truncate(self, new_len: int): + assert new_len <= self.seq_len + self.seq_len = new_len + + def torch(self): + self._ensure_init() + s = self.tensor.narrow(self.seq_dim, 0, self.seq_len) + return s + + def slice(self, a: int | None, b: int | None): + return SeqTensor.from_tensor(self.torch_slice(a, b), self.seq_dim) + + def torch_slice(self, a: int | None, b: int | None): + self._ensure_init() + if a is None and b is None: + return self.torch() + elif b is None: + s = self.tensor.narrow(self.seq_dim, a, self.seq_len - a) + elif a is None: + s = self.tensor.narrow(self.seq_dim, 0, b) + else: + s = self.tensor.narrow(self.seq_dim, a, b - a) + return s + + +no_default = object() + +def get_for_device( + input_dict: dict, + key: str | int, + device: torch.device, + default = no_default, +) -> torch.Tensor | None: + """ + Read a tensor from a dict and ensure it is available on the specified device. Caches access per device and may + break if the tensor is updated after being accessed in this way. Intended for tensors that are read-only for the + lifetime of the dict, such as RoPE coefficients during a single forward pass. + """ + if key not in input_dict and default is not no_default: + return default + + if "dev_cache" not in input_dict: + cache = {} + input_dict["dev_cache"] = cache + else: + cache = input_dict["dev_cache"] + + cache_key = f"{key}[{str(device)}]" + if cache_key in cache: + return cache[cache_key] + + v = input_dict[key] + dv = None if v is None else input_dict[key].to(device) + cache[cache_key] = dv + return dv + + +buffered_aranges = {} +def buffered_arange(r: int, device: torch.device): + if r not in buffered_aranges: + buffered_aranges[r] = torch.arange(r) + return get_for_device(buffered_aranges, r, device) + + +def to2( + x: torch.Tensor, + dtype1: torch.dtype | None, + dtype2: torch.dtype | None = None +): + if dtype1 is not None: + x = x.to(dtype1) + elif dtype2 is not None: + x = x.to(dtype2) + return x + + +def save_tensor_image( + t: torch.Tensor, + path: str, +): + import matplotlib.cm as cm + from PIL import Image + + t = t.detach().to("cpu", copy = True).float() + + k = 3 + _, sigma = t.mean(), t.std() + lo, hi = -k * sigma, k * sigma + t.clamp_(lo, hi) + t -= lo + t /= (hi - lo + 1e-8) + + rgba = cm.get_cmap("berlin")(t.numpy()) + rgb8 = (rgba[..., :3] * 255).astype("uint8") + im = Image.fromarray(rgb8) + im.save(path) + + +class GTensorCache: + def __init__(self): + self.cache = {} + + def make_key(self, device, shape, dtype, x): + device = torch.device(device) + return f"{device}/{str(shape)}/{str(dtype)}/{x}" + + def get(self, device, shape, dtype, x = ""): + key = self.make_key(device, shape, dtype, x) + if key not in self.cache: + refc, v = (0, torch.empty(shape, dtype = dtype, device = device)) + else: + refc, v = self.cache[key] + self.cache[key] = (refc + 1, v) + return v + + def drop(self, device, shape, dtype, x = ""): + key = self.make_key(device, shape, dtype, x) + refc, v = self.cache[key] + if refc == 1: + del self.cache[key] + else: + self.cache[key] = (refc - 1, v) + +g_tensor_cache = GTensorCache() diff --git a/gptqmodel/extension.py b/gptqmodel/extension.py new file mode 100644 index 000000000..b3f40cd57 --- /dev/null +++ b/gptqmodel/extension.py @@ -0,0 +1,292 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from importlib import import_module +from dataclasses import dataclass +import threading +from typing import TYPE_CHECKING, Callable + +from .utils.logger import setup_logger + +if TYPE_CHECKING: + from .utils.cpp import TorchOpsJitExtension + + +log = setup_logger() +# Serialize same-extension API calls so Python 3.13t no-GIL callers do not +# race clear/load cycles for the same JIT target. +_EXTENSION_API_LOCKS: dict[str, threading.Lock] = {} +_EXTENSION_API_LOCKS_GUARD = threading.Lock() + + +@dataclass(frozen=True) +class _ExtensionSpec: + name: str + aliases: tuple[str, ...] + resolve: Callable[[], TorchOpsJitExtension] + supported: Callable[[], bool] | None = None + unsupported_error: Callable[[], str] | None = None + + +def _resolve_attr(module_name: str, attr_name: str): + return getattr(import_module(module_name), attr_name) + + +def _resolve_extension_attr(module_name: str, attr_name: str) -> TorchOpsJitExtension: + return _resolve_attr(module_name, attr_name) + + +def _resolve_extension_factory(module_name: str, attr_name: str) -> TorchOpsJitExtension: + return _resolve_attr(module_name, attr_name)() + + +_EXTENSION_SPECS = ( + _ExtensionSpec( + name="pack_block_cpu", + aliases=("pack_block", "pack"), + resolve=lambda: _resolve_extension_factory("gptqmodel.utils.cpp", "_pack_block_extension"), + ), + _ExtensionSpec( + name="floatx_cpu", + aliases=("floatx", "quant_dtype_cpu"), + resolve=lambda: _resolve_extension_factory("gptqmodel.utils.cpp", "_floatx_cpu_extension"), + ), + _ExtensionSpec( + name="awq", + aliases=(), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.awq", "_AWQ_TORCH_OPS_EXTENSION"), + ), + _ExtensionSpec( + name="qqq", + aliases=(), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.qqq", "_QQQ_TORCH_OPS_EXTENSION"), + ), + _ExtensionSpec( + name="exllamav2", + aliases=("exllama2", "exllama_v2", "exllamav2_gptq"), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.exllamav2", "_EXLLAMAV2_GPTQ_TORCH_OPS_EXTENSION"), + ), + _ExtensionSpec( + name="exllamav2_awq", + aliases=("exllama2_awq", "exllama_v2_awq"), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.exllamav2", "_EXLLAMAV2_AWQ_TORCH_OPS_EXTENSION"), + ), + _ExtensionSpec( + name="exllamav3", + aliases=("exllama3", "exllama_v3"), + resolve=lambda: _resolve_extension_attr("gptqmodel.exllamav3.ext", "_EXLLAMAV3_TORCH_OPS_EXTENSION"), + ), + _ExtensionSpec( + name="machete", + aliases=("gptq_machete", "awq_machete"), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.machete", "_MACHETE_TORCH_OPS_EXTENSION"), + supported=lambda: _resolve_attr("gptqmodel.utils.machete", "_validate_machete_device_support")(), + unsupported_error=lambda: _resolve_attr("gptqmodel.utils.machete", "machete_runtime_error")(), + ), + _ExtensionSpec( + name="marlin_fp16", + aliases=(), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.marlin", "_MARLIN_FP16_TORCH_OPS_EXTENSION"), + ), + _ExtensionSpec( + name="marlin_bf16", + aliases=(), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.marlin", "_MARLIN_BF16_TORCH_OPS_EXTENSION"), + ), + _ExtensionSpec( + name="paroquant", + aliases=("paroquant_rotation",), + resolve=lambda: _resolve_extension_attr("gptqmodel.utils.paroquant", "_PAROQUANT_ROTATION_EXTENSION"), + ), +) + +_EXTENSION_SPECS_BY_NAME = {spec.name: spec for spec in _EXTENSION_SPECS} + +_EXTENSION_GROUPS = { + "marlin": ("marlin_fp16", "marlin_bf16"), +} + +_EXTENSION_ALIASES = { + alias: spec.name + for spec in _EXTENSION_SPECS + for alias in (spec.name, *spec.aliases) +} + + +def _normalize_extension_name(name: str) -> str: + return "_".join(str(name).strip().lower().replace("-", "_").split()) + + +def available_extensions() -> tuple[str, ...]: + """Return the concrete extension names accepted by `load()`.""" + + return tuple(spec.name for spec in _EXTENSION_SPECS) + + +def _spec_supported(spec: _ExtensionSpec) -> bool: + if spec.supported is None: + return True + try: + return bool(spec.supported()) + except Exception: + return False + + +def _resolve_requested_extensions(name: str) -> tuple[str, ...]: + normalized = _normalize_extension_name(name or "all") + if normalized == "all": + return tuple(spec.name for spec in _EXTENSION_SPECS if _spec_supported(spec)) + if normalized in _EXTENSION_GROUPS: + return _EXTENSION_GROUPS[normalized] + concrete = _EXTENSION_ALIASES.get(normalized) + if concrete is not None: + return (concrete,) + + allowed = sorted((*available_extensions(), *_EXTENSION_GROUPS.keys())) + raise ValueError( + f"Unknown extension `{name}`. Expected one of: {', '.join(allowed)}." + ) + + +def _spec_unsupported_error(spec: _ExtensionSpec) -> str: + if spec.unsupported_error is None: + return f"{spec.name} is not supported on this host." + try: + return spec.unsupported_error() or f"{spec.name} is not supported on this host." + except Exception: + return f"{spec.name} is not supported on this host." + + +def _process_loaded(extension: TorchOpsJitExtension) -> bool: + return extension._ops_available() + + +def _extension_api_lock(name: str) -> threading.Lock: + with _EXTENSION_API_LOCKS_GUARD: + lock = _EXTENSION_API_LOCKS.get(name) + if lock is None: + lock = threading.Lock() + _EXTENSION_API_LOCKS[name] = lock + return lock + + +def _resolve_single_extension_name(name: str) -> str: + resolved = _resolve_requested_extensions(name) + if len(resolved) != 1: + raise ValueError( + f"Extension `{name}` resolves to multiple extensions: {', '.join(resolved)}. " + "Use one concrete extension name for this operation." + ) + return resolved[0] + + +def _extension_for_name(name: str) -> TorchOpsJitExtension: + return _EXTENSION_SPECS_BY_NAME[_resolve_single_extension_name(name)].resolve() + + +def _load_one(name: str, *, use_cache: bool) -> TorchOpsJitExtension: + extension_name = _resolve_single_extension_name(name) + spec = _EXTENSION_SPECS_BY_NAME[extension_name] + if not _spec_supported(spec): + raise RuntimeError(_spec_unsupported_error(spec)) + with _extension_api_lock(extension_name): + extension = spec.resolve() + + if not use_cache: + if _process_loaded(extension): + raise RuntimeError( + f"{extension.display_name}: already loaded in this Python process. " + "Restart Python to force recompilation." + ) + extension.clear_cache() + + if not extension.load(): + raise RuntimeError( + extension.last_error_message() + or f"{extension.display_name}: failed to compile torch.ops JIT extension." + ) + return extension + + +def is_available(name: str, *, use_cache: bool = True) -> bool: + """Return whether one concrete extension can be loaded through the shared API.""" + + try: + _load_one(name, use_cache=use_cache) + return True + except RuntimeError: + return False + + +def error(name: str) -> str: + """Return the last human-readable error captured for one concrete extension.""" + + return _extension_for_name(name).last_error_message() + + +def op(name: str, op_name: str, *, use_cache: bool = True) -> object: + """Return one torch.ops handle after ensuring the selected extension is loaded.""" + + return _load_one(name, use_cache=use_cache).op(op_name) + + +def namespace(name: str, *, use_cache: bool = True) -> object: + """Return the torch.ops namespace object after ensuring the extension is loaded.""" + + return _load_one(name, use_cache=use_cache).namespace_object() + + +def load(name: str = "all", *, use_cache: bool = True) -> dict[str, bool]: + """Build one or more GPT-QModel torch.ops JIT extensions ahead of first use. + + Args: + name: One concrete extension name, `marlin` for both Marlin dtype + variants, or `all` to build every managed JIT extension. Default: + `all`. + use_cache: Reuse any compatible cached build artifact when available. + Set `False` to clear cached on-disk build artifacts before + compiling. This only works before the selected extension has been + loaded in the current Python process. Once a torch.ops library is + registered, a fresh interpreter is required for a true rebuild. + + Returns: + A mapping of concrete extension names to their build/load result. + + Raises: + ValueError: The requested extension name is unknown. + RuntimeError: One or more selected extensions failed to compile or + `use_cache=False` was requested after an extension was already loaded + in the current process. + """ + + selected = _resolve_requested_extensions(name) + results: dict[str, bool] = {} + errors: dict[str, str] = {} + + log.info( + "Extension load requested for `%s`: %s%s", + name, + ", ".join(selected), + "" if use_cache else " (use_cache=False)", + ) + + for extension_name in selected: + try: + _load_one(extension_name, use_cache=use_cache) + results[extension_name] = True + except RuntimeError as exc: + results[extension_name] = False + extension = _EXTENSION_SPECS_BY_NAME[extension_name].resolve() + errors[extension_name] = extension.last_error_message() or str(exc) + + if errors: + summary = "\n".join(f"- {name}: {message}" for name, message in errors.items()) + raise RuntimeError(f"Extension load failed:\n{summary}") + + log.info("Extension load finished successfully: %s", ", ".join(results)) + return results + + +__all__ = ["available_extensions", "error", "is_available", "load", "namespace", "op"] diff --git a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py index b4bb0407e..ef999326f 100644 --- a/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py +++ b/gptqmodel/hf_minimax_m2/modeling_minimax_m2.py @@ -383,7 +383,7 @@ def forward( # Chunked attention computation to reduce peak memory usage out_parts = [] attn_parts = [] if output_attentions else None - + # A smaller chunk size reduces memory but may be slightly slower chunk_size = 1024 for i in range(0, q.size(1), chunk_size): @@ -399,12 +399,12 @@ def forward( if window_mask is not None: attn_chunk.masked_fill_(window_mask[:, i:i + chunk_size, :], float("-inf")) - + attn_chunk = torch.softmax(attn_chunk, dim=-1, dtype=torch.float32).to(query_dtype) if self.training and self.attention_dropout > 0: attn_chunk = F.dropout(attn_chunk, p=self.attention_dropout, training=True) - + if output_attentions: attn_parts.append(attn_chunk) @@ -413,7 +413,7 @@ def forward( out_parts.append(out_chunk) del q_chunk, attn_chunk, out_chunk - + out = torch.cat(out_parts, dim=1) attn_output_parts.append(out) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index da56a35b0..e42fc0bcc 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -15,22 +15,22 @@ from torch import nn from torch.nn import Module -from ..looper.loop_processor import DTYPE_SIZE_COLUMN, MODULE_FEATURE_COLUMN, LoopProcessor +from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel from ..models._const import SUPPORTS_MODULE_TYPES from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES) -from ..nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear -from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear -from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear, LLMAwqQuantLinear +from ..nn_modules.qlinear.gemm_awq import AwqGEMMLinear +from ..nn_modules.qlinear.gemv_awq import AwqGEMVLinear +from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastLinear, LLMAwqLinear from ..quantization.awq.quantize.scale import apply_clip, apply_scale from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name from ..quantization.awq.utils.utils import get_best_device -from ..quantization.config import FORMAT, METHOD, QuantizeConfig +from ..quantization.config import FORMAT, METHOD, QuantizeConfig, resolve_quant_format from ..utils.ctx import ctx from ..utils.device import get_device -from ..utils.failsafe import normalize_failsafe +from ..utils.fallback import normalize_fallback from ..utils.logger import setup_logger, log_time_block from ..utils.model import find_modules, get_module_by_name_prefix, move_to, create_quant_module, pack_module from ..utils.module_locks import parent_module_lock @@ -41,6 +41,8 @@ @dataclass class _AWQLayerState: + """Tracks subset progress and cached state needed to quantize one AWQ layer.""" + modules: Dict[str, NamedModule] = field(default_factory=dict) subset_total: Optional[int] = None processed_subsets: Set[int] = field(default_factory=set) @@ -50,7 +52,90 @@ class _AWQLayerState: pending_modules: Set[str] = field(default_factory=set) lock: threading.Lock = field(default_factory=threading.Lock) + +def _accumulate_awq_weight_mean( + layers: List[nn.Linear], + group_size: int, +) -> Tuple[torch.Tensor, int]: + """Accumulates normalized per-channel weight sums across a group of linears.""" + + if not layers: + raise ValueError("Expected at least one linear layer to compute the AWQ weight mean.") + + first_weight = layers[0].weight.detach() + num_channels = first_weight.shape[1] + effective_group_size = group_size if group_size > 0 else num_channels + if num_channels % effective_group_size != 0: + raise ValueError( + f"Expected in_features ({num_channels}) to be divisible by group_size ({effective_group_size})." + ) + + if first_weight.device.type == CPU: + weights = [] + row_count = 0 + for layer in layers: + weight = layer.weight.detach() + if weight.shape[1] != num_channels: + raise ValueError( + f"Expected consistent in_features across layers ({num_channels}), " + f"got {weight.shape[1]} for layer {layer}." + ) + weights.append(weight.to(dtype=torch.float32)) + row_count += weight.shape[0] + + weight = torch.cat(weights, dim=0) + weight = weight.abs().reshape(row_count, -1, effective_group_size) + group_scale = weight.amax(dim=2, keepdim=True) + weight.div_(group_scale.add_(1e-6)) + return weight.reshape(row_count, num_channels).sum(dim=0), row_count + + w_sum = torch.zeros(num_channels, dtype=torch.float32, device=first_weight.device) + row_count = 0 + + for layer in layers: + weight = layer.weight.detach() + if weight.shape[1] != num_channels: + raise ValueError( + f"Expected consistent in_features across layers ({num_channels}), " + f"got {weight.shape[1]} for layer {layer}." + ) + + rows = weight.shape[0] + weight_abs = weight.abs() + weight_view = weight_abs.reshape(rows, -1, effective_group_size) + group_scale = weight_view.amax(dim=2, keepdim=True) + weight_view.div_(group_scale.add_(1e-6)) + w_sum += weight_abs.sum(dim=0, dtype=torch.float32) + row_count += rows + + return w_sum, row_count + + +def _compute_awq_weight_mean( + layers: List[nn.Linear], + group_size: int, +) -> torch.Tensor: + """Returns the average normalized per-channel weight magnitude for AWQ scaling.""" + + w_sum, row_count = _accumulate_awq_weight_mean(layers, group_size) + first_weight = layers[0].weight.detach() + if row_count == 0: + return torch.zeros(first_weight.shape[1], dtype=first_weight.dtype, device=first_weight.device) + return (w_sum / row_count).to(first_weight.dtype) + + class AWQProcessor(LoopProcessor): + """Captures activations and quantizes layers with the AWQ scaling workflow.""" + + @staticmethod + def resolve_quant_source_module(named_module: NamedModule) -> nn.Module: + """Return the dense module view AWQ should use for weight quantization.""" + + quant_source = named_module.state.get("quant_source_module") + if isinstance(quant_source, nn.Module): + return quant_source + return named_module.module + def __init__( self, tokenizer, @@ -66,6 +151,7 @@ def __init__( calculate_w_wq_diff: bool = False, calibration_concat_separator: Optional[str] = None, ): + """Initializes AWQ processing, layer tracking, and kernel selection.""" super().__init__( tokenizer=tokenizer, @@ -76,10 +162,12 @@ def __init__( calibration_concat_separator=calibration_concat_separator, prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd, - fwd_after_process=True, - subset_forward_early_stop=True, - enable_activation_capture_flag=True, + execution_config=ExecutionConfig( + require_fwd=require_fwd, + fwd_replay_after_process=True, + subset_forward_early_stop=True, + enable_activation_capture=True, + ), ) self.calculate_w_wq_diff = calculate_w_wq_diff @@ -94,7 +182,8 @@ def __init__( self.gptq_model = gptq_model model_kernel = getattr(self.gptq_model, "qlinear_kernel", None) - self.qlinear_kernel = model_kernel or self._select_qlinear_kernel_for_format(qcfg.format) + self.format = resolve_quant_format(qcfg.format, qcfg.method) + self.qlinear_kernel = model_kernel or self._select_qlinear_kernel_for_format(self.format) self.model = model # Whether to apply clipping to the model during quantization. Some models may perform better with this set to False. @@ -104,8 +193,6 @@ def __init__( # " Default is 1GB (1024 * 1024 * 1024)." self.max_chunk_memory = 1024 * 1024 * 1024 - self.format = qcfg.format - # Whether to scale using both w/x or just x. self.duo_scaling = True @@ -115,13 +202,20 @@ def __init__( self._rotary_source_id: Optional[int] = None self._initialize_sample_counts() self._module_forward_kwargs.setdefault("attention_mask", None) - # Preserve failsafe preference so AWQ can optionally fall back when no calibration data or activations are available. - self.failsafe = qcfg.failsafe + # Preserve fallback preference so AWQ can optionally fall back when no calibration data or activations are available. + self.fallback = qcfg.fallback def _get_root_rotary(self) -> Optional[nn.Module]: + """Returns the model rotary module used to refresh position embeddings.""" + + if self.gptq_model.rotary_embedding: + rotary, _ = get_module_by_name_prefix(self.model, [self.gptq_model.rotary_embedding]) + return rotary return getattr(getattr(self.model, "model", self.model), "rotary_emb", None) def _get_rotary_device(self, rotary: Optional[nn.Module], fallback: Optional[torch.device] = None) -> Optional[torch.device]: + """Resolves the effective device for a rotary module or falls back safely.""" + if rotary is None: return fallback @@ -135,6 +229,8 @@ def _get_rotary_device(self, rotary: Optional[nn.Module], fallback: Optional[tor return fallback def _get_rotary_for_device(self, target_device: Optional[torch.device]) -> Optional[nn.Module]: + """Returns a rotary module copy materialized on the requested device.""" + rotary = self._get_root_rotary() if rotary is None or target_device is None: return rotary @@ -171,34 +267,42 @@ def _get_rotary_for_device(self, target_device: Optional[torch.device]) -> Optio return cached def set_calibration_dataset(self, calibration_dataset): + """Rejects dataset replacement because AWQ capture is fixed at construction.""" + raise NotImplementedError("AWQProcessor's calibration_dataset cannot be modified") def _select_qlinear_kernel_for_format(self, format_value: FORMAT): + """Maps the resolved AWQ format to its concrete quantized linear kernel.""" + fmt = FORMAT(format_value) if not isinstance(format_value, FORMAT) else format_value if fmt == FORMAT.GEMM: - return AwqGEMMQuantLinear + return AwqGEMMLinear if fmt == FORMAT.GEMV: - return AwqGEMVQuantLinear + return AwqGEMVLinear if fmt == FORMAT.GEMV_FAST: - return AwqGEMVFastQuantLinear + return AwqGEMVFastLinear if fmt == FORMAT.LLM_AWQ: - return LLMAwqQuantLinear + return LLMAwqLinear # We do not allow saving to marlin format # if fmt == FORMAT.MARLIN: - # return AwqMarlinQuantLinear + # return AwqMarlinLinear raise ValueError(f"METHOD.AWQ does not support this FORMAT: {format_value}") def _resolve_qlinear_kernel(self, module_name: Optional[str] = None): + """Resolves the AWQ kernel after applying any dynamic format override.""" + # Honor per-module dynamic format overrides when present. format_override = self.qcfg.dynamic_get(module_name, "format", None) if module_name else None - target_format = format_override or self.qcfg.format - if target_format == self.qcfg.format: + target_format = resolve_quant_format(format_override or self.qcfg.format, self.qcfg.method) + if target_format == self.format: model_kernel = getattr(self.gptq_model, "qlinear_kernel", None) if model_kernel is not None: return model_kernel return self._select_qlinear_kernel_for_format(target_format) def _get_layer_state(self, layer_index: int) -> _AWQLayerState: + """Returns the mutable tracking state for a specific transformer layer.""" + with self._layer_states_lock: state = self._layer_states.get(layer_index) if state is None: @@ -207,6 +311,8 @@ def _get_layer_state(self, layer_index: int) -> _AWQLayerState: return state def _initialize_sample_counts(self) -> None: + """Computes sample and token totals from the calibration dataset.""" + total = 0 dataset = getattr(self, "calibration_dataset", None) if dataset is None: @@ -236,6 +342,8 @@ def _initialize_sample_counts(self) -> None: self.nsamples = total_tokens def _record_input_feature(self, module_name: str, feature: torch.Tensor) -> None: + """Caches one captured input feature tensor for a named module.""" + # Preserve a leading sample axis for flattened [seq, hidden] captures so later # concatenation produces [samples, seq, hidden] instead of collapsing into one giant sequence. if feature.dim() <= 2: @@ -255,6 +363,8 @@ def _record_input_feature(self, module_name: str, feature: torch.Tensor) -> None inputs_list.append(feature) def _capture_previous_subset_scale(self, previous_subset: Optional[Dict[str, NamedModule]]) -> Optional[float]: + """Estimates the average weight scale of the previous subset for reuse heuristics.""" + if not previous_subset: return None @@ -272,6 +382,8 @@ def _capture_previous_subset_scale(self, previous_subset: Optional[Dict[str, Nam return float(sum(values) / len(values)) def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor]: + """Collapses per-batch cached inputs into one feature tensor per module.""" + features: Dict[str, torch.Tensor] = {} root_buckets: Dict[str, List[torch.Tensor]] = {} # Iterate over a snapshot since quantization may mutate state.modules concurrently @@ -305,13 +417,17 @@ def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor # features[root] = tensors[0] return features - def _quantize_layer_failsafe( + def _quantize_layer_fallback( self, layer_index: int, state: _AWQLayerState, reason: str, ) -> None: + """Falls back to direct quantization when AWQ scaling cannot proceed safely.""" + def unwrap(mod): + """Returns the underlying module when wrapped in `NamedModule`.""" + return mod.module if isinstance(mod, NamedModule) else mod named_childs = { @@ -351,6 +467,8 @@ def unwrap(mod): delattr(self._scale_context, "prev_scale") def _refresh_forward_kwargs_from_cache(self) -> None: + """Refreshes cached kwargs such as masks and rotary embeddings for AWQ search.""" + cache = getattr(self, "inputs_cache", None) if cache is None: return @@ -413,12 +531,14 @@ def _refresh_forward_kwargs_from_cache(self) -> None: self._module_forward_kwargs = refreshed - def _should_failsafe_group( + def _should_fallback_group( self, layer_names: List[str], input_feat: Dict[str, torch.Tensor], ) -> bool: - from ..utils.failsafe import should_use_failsafe + """Returns whether a scaling group lacks enough captured activations for AWQ.""" + + from ..utils.fallback import should_use_fallback captured_tokens = 0 for name in layer_names: @@ -430,13 +550,15 @@ def _should_failsafe_group( captured_tokens += feat.numel() // max(hidden, 1) expected_tokens = getattr(self, "total_calibration_tokens", None) or self._nsamples_total - return should_use_failsafe( - self.failsafe, + return should_use_fallback( + self.fallback, float(captured_tokens), float(expected_tokens) if expected_tokens else None, ) def _quantize_layer(self, layer_index: int, state: _AWQLayerState) -> None: + """Runs the AWQ scaling, clipping, and quantization flow for one layer.""" + if state.quantized: return @@ -463,14 +585,16 @@ def _quantize_layer(self, layer_index: int, state: _AWQLayerState) -> None: input_feat = self._layer_input_features(state) missing = [name for name, tensor in input_feat.items() if tensor.numel() == 0] - if missing and not self.failsafe: + if missing and not self.fallback: raise RuntimeError( f"AWQProcessor error: missing activation features for modules {missing} " - f"with failsafe disabled." + f"with fallback disabled." ) # Filtering MLP modules like Qwen3MoeSparseMoeBlock def unwrap(m): + """Returns the underlying module when wrapped in `NamedModule`.""" + return m.module if isinstance(m, NamedModule) else m named_childs = { @@ -518,7 +642,7 @@ def unwrap(m): filtered_module_config: List[Dict] = [] skipped_groups: List[Tuple[List[str], List[str]]] = [] - failsafe_names = set() + fallback_names = set() for cfg in sanitized_module_config: layers_sample = cfg.get("layers") or [] prev_module = cfg.get("prev_op") @@ -527,8 +651,8 @@ def unwrap(m): get_op_name(layer_module_ref, layer) if isinstance(layer, torch.nn.Module) else str(layer) for layer in layers_sample ] - if self.failsafe and self._should_failsafe_group(layer_names, input_feat): - failsafe_names.update(layer_names) + if self.fallback and self._should_fallback_group(layer_names, input_feat): + fallback_names.update(layer_names) continue first_layer_module = layers_sample[0] if layers_sample else None @@ -573,7 +697,7 @@ def unwrap(m): ) sanitized_module_config = filtered_module_config - if not sanitized_module_config and not failsafe_names: + if not sanitized_module_config and not fallback_names: log.warning( "AWQProcessor: no valid scaling groups for layer %s after filtering; marking layer as quantized.", layer_index, @@ -659,7 +783,7 @@ def unwrap(m): if self.apply_clip: clip_list = self._search_best_clip( layer_module_ref, - {name: named.module for name, named in named_childs.items() if name not in failsafe_names}, + {name: named.module for name, named in named_childs.items() if name not in fallback_names}, input_feat, ) apply_clip(layer_module_ref, clip_list) @@ -668,24 +792,24 @@ def unwrap(m): get_op_name(self.model, layer_module_ref) + ".", ) - failsafe_named_childs = { + fallback_named_childs = { n: named_childs[n] - for n in failsafe_names + for n in fallback_names if n in named_childs } - named_childs = {name: named for name, named in named_childs.items() if name in input_feat and name not in failsafe_names} + named_childs = {name: named for name, named in named_childs.items() if name in input_feat and name not in fallback_names} self.apply_quant(named_childs, scales_list) - if failsafe_named_childs: + if fallback_named_childs: log.warning( "AWQProcessor: layer %s fallback quant %d modules: %s", layer_index, - len(failsafe_named_childs), - list(failsafe_named_childs)[:6], + len(fallback_named_childs), + list(fallback_named_childs)[:6], ) - self.apply_quant(failsafe_named_childs, scales_list=[]) + self.apply_quant(fallback_named_childs, scales_list=[]) state.quantized = True state.modules.clear() @@ -716,6 +840,8 @@ def _search_best_scale( module2inspect=None, kwargs={}, ): + """Searches the best per-channel AWQ scale for a module group.""" + if module2inspect is None: assert len(layers) == 1 module2inspect = layers[0] @@ -731,33 +857,7 @@ def _search_best_scale( # previous cat()+view pipeline while keeping peak memory low: for every group we normalise # |w| by its per-group max so the values land on a [0, 1] scale, then accumulate the totals # per output channel so the mean can be computed without allocating the combined tensor. - first_weight = layers[0].weight - weight_dtype = first_weight.dtype - weight_device = first_weight.device - num_channels = first_weight.shape[1] - w_sum = torch.zeros(num_channels, dtype=torch.float32, device=weight_device) - row_count = 0 - - for layer in layers: - weight = layer.weight - if weight.shape[1] != num_channels: - raise ValueError( - f"Expected consistent in_features across layers ({num_channels}), " - f"got {weight.shape[1]} for layer {layer}." - ) - org_shape = weight.shape - weight_abs = weight.abs() - weight_group = weight_abs.view(-1, self.qcfg.group_size) - group_scale = weight_group.amax(dim=1, keepdim=True) + 1e-6 - normalized = weight_group / group_scale - normalized = normalized.view(org_shape) - w_sum += normalized.sum(dim=0, dtype=torch.float32) - row_count += org_shape[0] - - if row_count == 0: - w_mean = torch.zeros(num_channels, dtype=weight_dtype, device=weight_device) - else: - w_mean = (w_sum / row_count).to(weight_dtype) + w_mean = _compute_awq_weight_mean(layers, self.qcfg.group_size) # [STEP 2]: Compute per-channel mean of the input activation with chunking # Stream directly on the source device to avoid creating full CPU copies while still enforcing @@ -812,6 +912,8 @@ def _search_best_scale( @torch.inference_mode() def _search_best_clip(self, layer, named_linears, input_feat): + """Searches per-layer clipping thresholds for AWQ-eligible linears.""" + clip_list = [] avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"] @@ -839,6 +941,8 @@ def _compute_best_clip( max_shrink=0.5, n_sample_token=512, ): + """Finds the clipping bound that minimizes reconstruction error for a weight tensor.""" + assert w.dim() == 2 org_w_shape = w.shape # w [co, ci] -> [co, 1, n_group, group size] @@ -896,6 +1000,8 @@ def _compute_best_clip( return best_max_val.squeeze(1) def pseudo_quantize_tensor(self, w: torch.Tensor): + """Simulates AWQ quantization and returns dequantized weights plus scales/zeros.""" + org_w_shape = w.shape if self.qcfg.group_size > 0: assert org_w_shape[-1] % self.qcfg.group_size == 0, f"org_w_shape ({org_w_shape[-1]}) must be a multiple of group_size ({self.qcfg.group_size})!" @@ -942,6 +1048,8 @@ def pseudo_quantize_tensor(self, w: torch.Tensor): @torch.inference_mode() def _pseudo_quantize_tensor_into(self, src: torch.Tensor, dst: torch.Tensor) -> None: + """Writes pseudo-quantized values into a destination tensor without reallocating.""" + # Quantize `src` into `dst` without allocating a new tensor (mirrors pseudo_quantize_tensor) org_shape = src.shape if self.qcfg.group_size > 0: @@ -1078,6 +1186,8 @@ def _compute_loss( int_w_output: torch.Tensor, device: torch.device, ): + """Computes chunked mean-squared reconstruction loss under a memory cap.""" + loss = 0.0 fp16_output_flat = fp16_output.view(-1) int_w_output_flat = int_w_output.view(-1) @@ -1107,6 +1217,8 @@ def _compute_loss( def _module_forward( self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict ) -> torch.Tensor: + """Runs a module forward with sanitized kwargs and optional micro-batching.""" + target_device = None try: target_device = next(module.parameters()).device @@ -1182,9 +1294,9 @@ def _module_forward( effective_quant_batch_size = self._quant_batch_size if self._quant_batch_size and self._quant_batch_size > 0 else None if ( - effective_quant_batch_size is None - or x.dim() == 0 - or x.shape[0] <= effective_quant_batch_size + effective_quant_batch_size is None + or x.dim() == 0 + or x.shape[0] <= effective_quant_batch_size ): module_output = module(x, **module_kwargs) if isinstance(module_output, tuple): @@ -1192,6 +1304,8 @@ def _module_forward( return module_output def _slice_value(val, length): + """Slices batch-shaped kwargs to match a micro-batched forward chunk.""" + if isinstance(val, torch.Tensor) and val.shape[0] == module_kwargs.get("position_ids", val).shape[0]: return val[:length] if isinstance(val, torch.Tensor) and val.shape[0] != length: @@ -1219,12 +1333,14 @@ def _slice_value(val, length): return module_output def apply_quant(self, named_linears: Dict[str, NamedModule], scales_list): + """Pseudo-quantizes selected linears and stages AWQ tensors for packing.""" + start_time = time.time() for name, named_module in named_linears.items(): base_title = f"Quantizing {named_module.name} in layer" - self._pause_controller.register_and_draw_progress_bar(self.pb, title=base_title, subtitle="") - - linear_layer = named_module.module + self.draw_progress(base_title) + + linear_layer = self.resolve_quant_source_module(named_module) linear_layer = linear_layer.to(get_best_device()) tp_info = named_module.state.get("tp_pad_info") @@ -1320,17 +1436,6 @@ def apply_quant(self, named_linears: Dict[str, NamedModule], scales_list): # Log the new row self.log_new_row(stat) - # Mirror GPTQ-style visibility in the CLI so awq modules show up - # even when the table view is busy with progress updates. - log.info( - "awq | layer=%s module=%s loss=%s samples=%s time=%ss", - named_module.layer_index, - named_module.name, - loss_summary, - self._nsamples_total, - f"{duration:.3f}", - ) - def _sanitize_kwargs(self, inputs_kwargs, module): """ Remove the arguments that are not supported in the module's @@ -1350,10 +1455,16 @@ def _sanitize_kwargs(self, inputs_kwargs, module): sanitized_kwargs[k] = v return sanitized_kwargs - def preprocess(self, module: NamedModule, failsafe=None, **kwargs): + def preprocess(self, module: NamedModule, fallback=None, **kwargs): + """Registers a module with its layer state and initializes input capture buckets.""" + + # entire module is skipped + if self.qcfg.dynamic_get(layer_name=module.full_name) is False: + return + # Track the most recent preference so the processor can decide whether # to fall back to simple quantization when activations are missing. - self.failsafe = normalize_failsafe(failsafe, self.qcfg.failsafe) + self.fallback = normalize_fallback(fallback, self.qcfg.fallback) layer_state = self._get_layer_state(module.layer_index) with layer_state.lock: layer_state.modules[module.name] = module @@ -1376,10 +1487,20 @@ def preprocess(self, module: NamedModule, failsafe=None, **kwargs): entry.setdefault("inputs", []) def is_skipped(self, module: NamedModule) -> bool: - return False + """Reports whether preprocessing excluded this module from AWQ work.""" + + t = self.tasks.get(module.name, False) + if t == False: + return True + else: + return False def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Returns the forward hook that caches module input activations for AWQ.""" + def hook(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): + """Records the module input tensor for later AWQ scale and clip search.""" + if not inp: return feature = inp @@ -1397,6 +1518,8 @@ def process( subset_index: Optional[int] = None, subset_total: Optional[int] = None, ): + """Accumulates subset progress and triggers layer quantization when ready.""" + self._refresh_forward_kwargs_from_cache() layer_index = module.layer_index state = self._get_layer_state(layer_index) @@ -1443,9 +1566,13 @@ def process( # submodule_finalized is called in reverse after all next sequential processes are called def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Delegates AWQ module packing to the shared pack helper.""" + self.pack_module(module) def pack_module(self, module): + """Creates the AWQ quantized module and packs saved scales/zero-points into it.""" + # generate complete, safe to move to cpu # cleanup all memory or states vars persistently added by this processor module.stream_sync() @@ -1480,7 +1607,7 @@ def pack_module(self, module): create_quant_module( name=module.full_name, linear_cls=quant_linear_cls, - bits=self.qcfg.bits, + bits=self.qcfg.runtime_bits, desc_act=self.qcfg.desc_act, dynamic=self.qcfg.dynamic, group_size=self.qcfg.group_size, @@ -1490,6 +1617,7 @@ def pack_module(self, module): device=self.qcfg.device, lm_head_name=self.gptq_model.lm_head, pack_dtype=self.qcfg.pack_dtype, + format=self.format, register_buffers=False, ) if timer is not None and create_start is not None: @@ -1530,14 +1658,18 @@ def pack_module(self, module): ) def finalize(self, model: BaseQModel, **kwargs): + """Marks the model as AWQ-quantized and runs shared finalization logic.""" + # set quantized state model.quantized = True - model.quantize_config.quant_method = METHOD.AWQ + model.quantize_config.method = METHOD.AWQ super().finalize(model=model, **kwargs) def verify_calibration_dataset(self, processor_index: int) -> bool: + """Ensures AWQ received calibration data before the quantization loop starts.""" + if self.calibration_dataset is None: raise ValueError("GPTQProcessor's calibration_dataset must be provided.") else: @@ -1545,9 +1677,13 @@ def verify_calibration_dataset(self, processor_index: int) -> bool: @classmethod def name(cls) -> str: + """Returns the processor label used in logs and lifecycle reporting.""" + return "awq" def has_captured_input_ids(self, name: str) -> bool: + """Reports whether a module has any non-empty captured AWQ activations.""" + entry = self.tasks.get(name) or {} tensors: List[torch.Tensor] = entry.get("inputs", []) return tensors is not None and len(tensors) > 0 and all(t.numel() > 0 for t in tensors) diff --git a/gptqmodel/looper/dequantize_processor.py b/gptqmodel/looper/dequantize_processor.py index 138c70833..4fdba3fed 100644 --- a/gptqmodel/looper/dequantize_processor.py +++ b/gptqmodel/looper/dequantize_processor.py @@ -7,23 +7,38 @@ import torch -from ..looper.loop_processor import LoopProcessor +from ..looper.loop_processor import ExecutionConfig, LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel -from ..nn_modules.qlinear.torch import TorchQuantLinear +from ..nn_modules.qlinear.torch import TorchLinear from ..utils.logger import setup_logger log = setup_logger() class DequantizeProcessor(LoopProcessor): - def __init__(self, quantized_modules: Dict[str, TorchQuantLinear]): - super().__init__(tokenizer=None, qcfg=None, calibration=None, calibration_concat_size=None, - prepare_dataset_func=None, batch_size=1, - require_fwd=False) + """Restores quantized weights to dense tensors for comparison or recovery flows.""" + + def __init__(self, quantized_modules: Dict[str, TorchLinear]): + """Initializes the processor with the quantized modules to dequantize.""" + + super().__init__( + tokenizer=None, + qcfg=None, + calibration=None, + calibration_concat_size=None, + prepare_dataset_func=None, + batch_size=1, + execution_config=ExecutionConfig( + require_fwd=False, + fwd_replay_after_process=False, + ), + ) self.quantized_modules = quantized_modules def set_calibration_dataset(self, calibration_dataset): + """Disables calibration inputs because dequantization is weight-only.""" + self.calibration_dataset = None self.num_batches = 0 @@ -37,6 +52,8 @@ def process( subset_index: Optional[int] = None, subset_total: Optional[int] = None, ): + """Dequantizes a module, preserving tensor-parallel padding when needed.""" + device = module.weight.device # TODO fix num_itr param..need to calculate this before dequant @@ -70,11 +87,17 @@ def process( }) def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Drops temporary dequantization tensors after downstream consumers finish.""" + module.state.pop("w", None) # no need for these weights now module.state.pop("wq", None) # no need for these weights now def verify_calibration_dataset(self, processor_index: int) -> bool: + """Reports that no calibration dataset is required for this processor.""" + return False def name(self) -> str: + """Returns the processor label used in logs and lifecycle reporting.""" + return "de-quantize" diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py index 46aafe616..72d09d131 100644 --- a/gptqmodel/looper/eora_processor.py +++ b/gptqmodel/looper/eora_processor.py @@ -12,7 +12,7 @@ from ..adapter.adapter import Lora from ..eora.eora import eora_compute_lora, eora_process_input, merge_eora_segments -from ..looper.loop_processor import DTYPE_SIZE_COLUMN, MODULE_FEATURE_COLUMN, LoopProcessor +from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, @@ -28,6 +28,8 @@ class EoraProcessor(LoopProcessor): + """Builds LoRA-style error adapters from dequantization residuals and activations.""" + def __init__( self, tokenizer, @@ -40,6 +42,8 @@ def __init__( require_fwd: bool = True, calibration_concat_separator: Optional[str] = None, ): + """Initializes EoRA processing and per-module segment accumulation state.""" + super().__init__( tokenizer=tokenizer, qcfg=qcfg, @@ -49,7 +53,7 @@ def __init__( calibration_concat_separator=calibration_concat_separator, prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd, + execution_config=ExecutionConfig(require_fwd=require_fwd), ) # Track per-module segment accumulators keyed by device so we can merge @@ -71,10 +75,14 @@ def __init__( self.eora_process_input = eora_process_input def set_calibration_dataset(self, calibration_dataset): + """Stores the calibration dataset because EoRA depends on batch counts.""" + self.calibration_dataset = calibration_dataset self.num_batches = len(calibration_dataset) def preprocess(self, module: NamedModule, **kwargs): + """Clones adapter config, applies rank overrides, and initializes accumulators.""" + # entire module is skipped if self.qcfg.dynamic_get(layer_name=module.full_name) == False: module.adapter_cfg = None # hack @@ -102,11 +110,17 @@ def preprocess(self, module: NamedModule, **kwargs): return def is_skipped(self, module: NamedModule) -> bool: + """Reports whether EoRA was disabled for this module by dynamic config.""" + # dynamic override removed eora processing for this module return module.adapter_cfg in [None, {}] def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Returns the forward hook that accumulates EoRA activation statistics.""" + def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor): + """Processes one batch of inputs into an EoRA contribution segment.""" + batch_index = self.current_batch_index() batch, contribution, scale = self.eora_process_input( input=input, @@ -133,6 +147,8 @@ def _accumulate_eora_contribution( contribution: torch.Tensor, scale: float, ) -> None: + """Merges one EoRA contribution segment into the per-device accumulator.""" + if batch <= 0: return @@ -182,6 +198,8 @@ def _accumulate_eora_contribution( del contribution def _finalize_eigen_scaling_matrix(self, name: str) -> torch.Tensor: + """Merges accumulated EoRA segments into the final scaling matrix.""" + with self.lock: segments = self._segment_accumulators.pop(name, {}) target_device = self._module_target_devices.pop(name, None) @@ -218,6 +236,8 @@ def process( subset_index: Optional[int] = None, subset_total: Optional[int] = None, ): + """Computes and installs the LoRA correction for one quantized module.""" + assert isinstance(module.adapter_cfg, Lora) self.pb.title(f"EoRA: Processing {module.name} ({module.module_dtype}) in layer").draw() @@ -344,10 +364,14 @@ def process( module.state.pop("tp_pad_info", None) def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Stores the finalized adapter object in the processor result map.""" + # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}") self.result_save(module.full_name, module.state.pop("adapter")) def finalize(self, model: BaseQModel, **kwargs): + """Releases accumulators and attaches the collected adapters to the model.""" + del self._segment_accumulators del self._module_target_devices @@ -357,6 +381,8 @@ def finalize(self, model: BaseQModel, **kwargs): super().finalize(model=model, **kwargs) def verify_calibration_dataset(self, processor_index: int) -> bool: + """Requires calibration on the first EoRA stage and reuses later caches thereafter.""" + if self.calibration_dataset is None: if processor_index == 0: raise ValueError("EoraProcessor's calibration_dataset must be provided.") @@ -365,4 +391,6 @@ def verify_calibration_dataset(self, processor_index: int) -> bool: return True def name(self) -> str: + """Returns the processor label used in logs and lifecycle reporting.""" + return "eora" diff --git a/gptqmodel/looper/exllamav3_processor.py b/gptqmodel/looper/exllamav3_processor.py new file mode 100644 index 000000000..e71e46295 --- /dev/null +++ b/gptqmodel/looper/exllamav3_processor.py @@ -0,0 +1,393 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import copy +import threading +import time +from typing import Callable, Dict, Optional, Tuple + +import torch +import transformers +from torch.nn import Module +from torch.nn.modules.conv import _ConvNd + +from ..exllamav3.modules.quant.exl3_lib.quantize import quantize_exl3 +from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor +from ..looper.named_module import NamedModule +from ..models import BaseQModel +from ..models.writer import ( + PROCESS_LOG_FWD_TIME, + PROCESS_LOG_LAYER, + PROCESS_LOG_MODULE, + PROCESS_LOG_NAME, + PROCESS_LOG_TIME, + PROCESS_USED_MEMORY, + QUANT_LOG_DAMP, + QUANT_LOG_LOSS, + QUANT_LOG_NSAMPLES, +) +from ..nn_modules.exllamav3 import ExllamaV3Linear +from ..quantization import QuantizeConfig +from ..quantization.config import EXL3Config, FORMAT, GPTQConfig, METHOD +from ..quantization.gptq import GPTQ +from ..utils.device import get_device +from ..utils.exllamav3 import create_exllamav3_module +from ..utils.logger import setup_logger +from ..utils.module_locks import parent_module_lock + + +setup_logger() + +_EXL3_SIGMA_REG = 0.025 +_OUT_SCALES_TO_ARG = { + "always": True, + "never": False, + "auto": None, + None: None, +} + + +def clone_exllamav3_config_for_module( + qcfg: EXL3Config, + module_full_name: str, +) -> Optional[EXL3Config]: + """Clones and applies per-module EXL3 dynamic overrides, or skips the module.""" + + if qcfg.dynamic_get(layer_name=module_full_name) == False: + return None + + qcfg_clone = copy.deepcopy(qcfg) + + if qcfg.dynamic is not None: + qcfg_clone.bits = qcfg.dynamic_get(module_full_name, "bits", qcfg_clone.bits) + qcfg_clone.head_bits = qcfg.dynamic_get(module_full_name, "head_bits", qcfg_clone.head_bits) + + out_scales_override = qcfg.dynamic_get(module_full_name, "out_scales", None) + if out_scales_override is not None: + qcfg_clone.out_scales = out_scales_override + + codebook_override = qcfg.dynamic_get(module_full_name, "codebook", None) + if codebook_override is not None: + qcfg_clone.codebook = codebook_override + + calibration_override = qcfg.dynamic_get(module_full_name, "calibration", None) + if calibration_override is not None: + qcfg_clone.calibration = calibration_override + + qcfg_clone.__post_init__() + return qcfg_clone + + +class EXL3Processor(LoopProcessor): + """Captures activations and repacks modules into ExLlamaV3 format.""" + + def __init__( + self, + tokenizer, + qcfg: QuantizeConfig, + calibration, + prepare_dataset_func, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + require_fwd: bool = True, + calibration_concat_separator: Optional[str] = None, + lm_head_name: str = "lm_head", + ): + """Initializes EXL3 processing and tracks the lm_head naming convention.""" + + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, + prepare_dataset_func=prepare_dataset_func, + batch_size=batch_size, + execution_config=ExecutionConfig( + require_fwd=require_fwd, + fwd_replay_after_process=True, + subset_forward_early_stop=True, + ), + ) + + self.avg_losses = [] + self.lm_head_name = lm_head_name + self._stats_lock = threading.Lock() + + def set_calibration_dataset(self, calibration_dataset): + """Rejects dataset replacement because EXL3 capture is fixed at construction.""" + + raise NotImplementedError("EXL3Processor's calibration_dataset cannot be modified") + + def preprocess(self, module: NamedModule, fallback=None, **kwargs): + """Builds the capture task and effective EXL3 config for one module.""" + + del fallback, kwargs + + module_qcfg = clone_exllamav3_config_for_module(self.qcfg, module.full_name) + if module_qcfg is None: + return + + capture_qcfg = GPTQConfig( + bits=max(1, module_qcfg.runtime_bits), + group_size=-1, + desc_act=False, + sym=True, + device=module_qcfg.device, + pack_dtype=module_qcfg.pack_dtype, + ) + + task = GPTQ(module=module, qcfg=capture_qcfg) + task.expected_nsamples = getattr(self, "total_calibration_tokens", None) + task.quantizer.configure(perchannel=True) + + self.tasks[module.name] = { + "capture": task, + "qcfg": module_qcfg, + } + + def is_skipped(self, module: NamedModule) -> bool: + """Reports whether preprocessing omitted this module from EXL3 work.""" + + return self.tasks.get(module.name, False) is False + + def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Returns the forward hook that feeds captured batches into the EXL3 task.""" + + def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): + """Records one activation batch for the EXL3 capture task.""" + + capture = self.tasks[name]["capture"] + batch_idx = self.current_batch_index() + capture.add_batch(inp[0].data, out.data, batch_index=batch_idx) + del inp, out + + return tmp + + def _is_lm_head(self, module: NamedModule) -> bool: + """Returns whether the named module corresponds to the model lm_head.""" + + if module.full_name == self.lm_head_name: + return True + return module.full_name.endswith(f".{self.lm_head_name}") + + def _target_bits(self, module: NamedModule, module_qcfg: EXL3Config) -> int: + """Chooses lm_head-specific bitwidth overrides when configured.""" + + if self._is_lm_head(module) and module_qcfg.head_bits is not None: + return max(1, int(module_qcfg.head_bits)) + return max(1, module_qcfg.runtime_bits) + + def _build_quant_args( + self, + module: NamedModule, + module_qcfg: EXL3Config, + device: torch.device, + ) -> Dict[str, object]: + """Builds the argument bundle passed into the EXL3 quantizer.""" + + quant_args: Dict[str, object] = { + "K": self._target_bits(module, module_qcfg), + "devices": [device], + "apply_out_scales": _OUT_SCALES_TO_ARG.get(module_qcfg.out_scales, None), + "sigma_reg": _EXL3_SIGMA_REG, + "seed": 787, + } + + if module_qcfg.codebook == "mcg": + quant_args["mcg"] = True + elif module_qcfg.codebook == "mul1": + quant_args["mul1"] = True + + return quant_args + + def _quant_input_weight(self, capture: GPTQ, device: torch.device) -> torch.Tensor: + """Exports the captured dense weight matrix in EXL3 quantizer layout.""" + + normalized = capture.clone_module(copy=True, device=device) + return normalized.t().contiguous().to(torch.float32) + + def _restore_module_weight(self, module: NamedModule, quantized_weight: torch.Tensor) -> torch.Tensor: + """Reshapes the EXL3 output weight back into the wrapped module layout.""" + + target = module.module if isinstance(module, NamedModule) else module + + if isinstance(target, transformers.Conv1D): + return quantized_weight.contiguous().view_as(target.weight.data) + + if isinstance(target, (torch.nn.Linear, _ConvNd)): + return quantized_weight.t().contiguous().view_as(target.weight.data) + + raise NotImplementedError(f"Unsupported EXL3 module type: {target.__class__.__name__}") + + def process( + self, + module: NamedModule, + device: torch.device = None, + subset: Optional[Dict[str, NamedModule]] = None, + previous_subset: Optional[Dict[str, NamedModule]] = None, + subset_index: Optional[int] = None, + subset_total: Optional[int] = None, + ): + """Runs EXL3 quantization for one module and stages its packed tensors.""" + + del subset, previous_subset, subset_index, subset_total + + base_title = f"Quantizing {module.name} in layer" + self.draw_progress(base_title) + + task_entry = self.tasks[module.name] + capture: GPTQ = task_entry["capture"] + module_qcfg: EXL3Config = task_entry["qcfg"] + + target_device = device or get_device(module.module) + target_device = torch.device(target_device) + if target_device.type != "cuda": + raise ValueError("EXL3 quantization requires CUDA/HIP execution.") + + start_time = time.perf_counter() + capture.finalize_hessian(target_device=target_device) + hessian = capture.H + if hessian is None: + raise RuntimeError(f"EXL3 failed to capture Hessian for module `{module.full_name}`.") + if capture.nsamples <= 0: + raise RuntimeError(f"EXL3 captured no calibration activations for module `{module.full_name}`.") + + h_data = { + "H": hessian, + "count": capture.nsamples, + "finalized": False, + } + + quant_args = self._build_quant_args(module, module_qcfg, target_device) + input_weight = self._quant_input_weight(capture, target_device) + weight_q, proxy_err, out_tensors = quantize_exl3( + weight=input_weight, + H_data=h_data, + quant_args=quant_args, + return_weight_q=True, + ) + duration = time.perf_counter() - start_time + + stream_payload = dict(out_tensors) + if module.bias is not None: + stream_payload["bias"] = module.bias.detach() + module.stream_state_payload_to_cpu(stream_payload) + + restored_weight = self._restore_module_weight(module, weight_q) + module.weight.data = restored_weight.to(dtype=module.weight.dtype) + + workspace_summary = getattr(capture, "_borrow_workspace_last_summary", None) + workspace_totals = getattr(capture, "_borrow_workspace_totals", None) + + if isinstance(proxy_err, str): + loss_display = proxy_err + else: + loss_display = f"{proxy_err:.10f}" if isinstance(proxy_err, (int, float)) else "unknown" + + stat = { + PROCESS_LOG_NAME: self.name(), + PROCESS_LOG_LAYER: module.layer_index, + PROCESS_LOG_MODULE: module.name, + MODULE_FEATURE_COLUMN: self.module_feature_summary(module), + DTYPE_SIZE_COLUMN: self.module_dtype_size_summary(module), + QUANT_LOG_LOSS: loss_display, + QUANT_LOG_NSAMPLES: f"{capture.nsamples}", + QUANT_LOG_DAMP: f"{_EXL3_SIGMA_REG:.5f}", + PROCESS_LOG_TIME: f"{duration:.3f}", + PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(), + PROCESS_USED_MEMORY: self.device_memory_report(), + } + + if workspace_summary: + requests = int(workspace_summary.get("requests", 0) or 0) + if requests: + hit_rate = float(workspace_summary.get("hit_rate", 0.0) or 0.0) + chunk_rows = workspace_summary.get("chunk_rows") + stat["workspace_cache_requests"] = str(requests) + stat["workspace_cache_hit_rate"] = f"{hit_rate:.1%}" + stat["workspace_stage_dtype"] = workspace_summary.get("staging_dtype", "") + if chunk_rows is not None: + stat["workspace_chunk_rows"] = str(chunk_rows) + if workspace_totals: + total_requests = int(workspace_totals.get("requests", 0) or 0) + if total_requests: + cumulative_hit_rate = ( + float(workspace_totals.get("materialized_hits", 0) or 0.0) / total_requests + ) + stat["workspace_total_requests"] = str(total_requests) + stat["workspace_total_hit_rate"] = f"{cumulative_hit_rate:.1%}" + + if self.qcfg.dynamic is not None: + stat["dynamic"] = self.qcfg.dynamic_get(layer_name=module.full_name) + + with self._stats_lock: + self.durations.append(duration) + if isinstance(proxy_err, (int, float)): + self.avg_losses.append(proxy_err) + self.module_names.append(f"layer-{module.layer_index}-{module.name}") + self.log.append(stat) + + self.log_new_row(stat) + + capture.free() + del input_weight, restored_weight, weight_q, out_tensors, stream_payload + + def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Builds and installs the ExLlamaV3 module from the staged tensors.""" + + del kwargs + + module.stream_sync() + + tensors: Dict[str, torch.Tensor] = {} + with self._stats_lock: + module.state.pop("w", None) + for tensor_name in ("trellis", "suh", "svh", "su", "sv", "bias", "mcg", "mul1"): + tensor = module.state.pop(tensor_name, None) + if tensor is not None: + tensors[tensor_name] = tensor.clone() + + parent_key = getattr(module, "full_name", getattr(module, "name", None)) + with parent_module_lock(parent_key): + create_exllamav3_module( + module_root=model.model, + name=module.full_name, + submodule=module, + tensors=tensors, + ) + + module.unregister_parameter("weight") + if getattr(module, "bias", None) is not None: + module.unregister_parameter("bias") + + def finalize(self, model: BaseQModel, **kwargs): + """Marks the model as EXL3-quantized and runs shared finalization logic.""" + + model.quantized = True + model.quantize_config.method = METHOD.EXL3 + model.quantize_config.format = FORMAT.EXL3 + model.qlinear_kernel = ExllamaV3Linear + super().finalize(model=model, **kwargs) + + def verify_calibration_dataset(self, processor_index: int) -> bool: + """Ensures EXL3 received calibration data before the quantization loop starts.""" + + del processor_index + if self.calibration_dataset is None: + raise ValueError("EXL3Processor's calibration_dataset must be provided.") + return True + + def name(self) -> str: + """Returns the processor label used in logs and lifecycle reporting.""" + + return "exl3" + + +__all__ = ["EXL3Processor", "clone_exllamav3_config_for_module"] diff --git a/gptqmodel/looper/forward_executor.py b/gptqmodel/looper/forward_executor.py new file mode 100644 index 000000000..2e37f1530 --- /dev/null +++ b/gptqmodel/looper/forward_executor.py @@ -0,0 +1,584 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Forward execution logic for cached layer and subset batches.""" + +from __future__ import annotations + +from contextlib import nullcontext +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple + +import torch + +from .. import DEVICE_THREAD_POOL +from ..nn_modules.hooked_linear import StopForward +from ..utils.attn_mask import normalize_seq_mask +from ..utils.logger import setup_logger +from ..utils.looper_helpers import ( + clone_module_for_devices, + forward_batch_worker, + rehome_module_to_device, + select_forward_devices, +) +from ..utils.model import move_to, nested_move_to +from ..utils.torch import torch_sync + +if TYPE_CHECKING: # pragma: no cover - imports for typing only + from logbar.progress import ProgressBar + + from .loop_processor import LoopProcessor + from .module_looper import ModuleLooper + + +class ForwardExecutor: + """Own the layer/subset forward execution logic for ModuleLooper.""" + + def __init__(self, looper: "ModuleLooper", logger=None) -> None: + """Bind the executor to the looper that owns device state and helpers.""" + + self.looper = looper + self.log = logger or setup_logger() + + def _resolve_batch_progress( + self, + processor: "LoopProcessor", + layer_inputs: List[List[torch.Tensor]], + progress_rows_per_batch: Optional[List[int]] = None, + progress_total_rows: Optional[int] = None, + ) -> Tuple[int, List[int], int]: + """Normalize batch and row progress accounting for a forward pass.""" + + total_batches = self.looper._resolve_batch_total(processor.num_batches, layer_inputs) + batch_row_counts = progress_rows_per_batch or self.looper._collect_row_counts(layer_inputs) + batch_row_counts = list(batch_row_counts) + if len(batch_row_counts) > total_batches: + batch_row_counts = batch_row_counts[:total_batches] + elif len(batch_row_counts) < total_batches: + batch_row_counts.extend([0] * (total_batches - len(batch_row_counts))) + + total_rows = progress_total_rows if progress_total_rows is not None else sum(batch_row_counts) + if total_rows <= 0 and total_batches > 0: + total_rows = total_batches + total_rows = max(total_rows, 1) + return total_batches, batch_row_counts, total_rows + + def _moe_forward_context( + self, + *, + module: torch.nn.Module, + processor: "LoopProcessor", + apply_moe_config: bool, + ): + """Pick the MoE routing context for a forward pass.""" + + if not apply_moe_config: + # Replay forwards opt out of quant-time MoE overrides and bypass hooks. + return nullcontext() + if self.looper.moe_routing_override: + return self.looper.MoERoutingOverrideContext(module, self.looper.moe_routing_override) + if not getattr(self.looper, "moe_routing_bypass", False): + return nullcontext() + + should_use_lifecycle = getattr(self.looper, "_should_use_moe_lifecycle", None) + if callable(should_use_lifecycle) and not should_use_lifecycle(module, processor): + return nullcontext() + + return self.looper.MoELifecycleContext( + self.looper, + module, + processor, + self.looper._current_subset, + ) + + def run( + self, + *, + module: torch.nn.Module, + processor: "LoopProcessor", + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + need_outputs: bool, + reuse_kv: bool, + progress_pb: "ProgressBar" | None = None, + progress_title: Optional[str] = None, + progress_stage: Optional[str] = None, + progress_rows_per_batch: Optional[List[int]] = None, + progress_total_rows: Optional[int] = None, + force_serial: bool = False, + preserve_module_devices: bool = False, + apply_moe_config: bool = True, + select_forward_devices_fn: Callable[[Optional[torch.device]], List[torch.device]] = select_forward_devices, + ) -> List[List[torch.Tensor]]: + """Dispatch the cached batches through the most appropriate forward path.""" + + if not force_serial: + quant_config = getattr(self.looper.gptq_model, "quantize_config", None) + if quant_config is not None and not getattr(quant_config, "auto_forward_data_parallel", True): + force_serial = True + + if force_serial: + return self.run_single( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + progress_pb=progress_pb, + progress_title=progress_title, + progress_stage=progress_stage, + progress_rows_per_batch=progress_rows_per_batch, + progress_total_rows=progress_total_rows, + preserve_module_devices=preserve_module_devices, + apply_moe_config=apply_moe_config, + ) + + devices = select_forward_devices_fn(cur_layer_device) + if len(devices) <= 1: + return self.run_single( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + progress_pb=progress_pb, + progress_title=progress_title, + progress_stage=progress_stage, + progress_rows_per_batch=progress_rows_per_batch, + progress_total_rows=progress_total_rows, + preserve_module_devices=preserve_module_devices, + apply_moe_config=apply_moe_config, + ) + + return self.run_parallel( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + devices=devices, + progress_pb=progress_pb, + progress_title=progress_title, + progress_stage=progress_stage, + progress_rows_per_batch=progress_rows_per_batch, + progress_total_rows=progress_total_rows, + apply_moe_config=apply_moe_config, + ) + + def run_single( + self, + *, + module: torch.nn.Module, + processor: "LoopProcessor", + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + need_outputs: bool, + reuse_kv: bool, + progress_pb: "ProgressBar" | None = None, + progress_title: Optional[str] = None, + progress_stage: Optional[str] = None, + progress_rows_per_batch: Optional[List[int]] = None, + progress_total_rows: Optional[int] = None, + preserve_module_devices: bool = False, + apply_moe_config: bool = True, + ) -> List[List[torch.Tensor]]: + """Run the forward pass sequentially on the current device.""" + + outputs: List[List[torch.Tensor]] = [] + prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None + total_batches, batch_row_counts, total_rows = self._resolve_batch_progress( + processor, + layer_inputs, + progress_rows_per_batch=progress_rows_per_batch, + progress_total_rows=progress_total_rows, + ) + processed_rows = 0 + stage_label = progress_stage or "Forward" + + for batch_idx in range(total_batches): + processor._set_current_batch_index(batch_idx) + try: + exec_device = cur_layer_device + if preserve_module_devices: + module_target = getattr(module, "target_device", None) + if module_target is not None: + exec_device = module_target + + # Capture input device before moving - used for output placement + input_device = layer_inputs[batch_idx][0].device if layer_inputs[batch_idx] else cur_layer_device + + layer_input = [move_to(inp, device=exec_device) for inp in layer_inputs[batch_idx]] + + raw_mask = attention_masks[batch_idx] + attn_tensor = raw_mask if raw_mask is None else move_to(raw_mask, device=exec_device) + + keep_mask = None + if attn_tensor is not None: + seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None + keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len) + + self.looper._set_processor_mask(processor, keep_mask) + additional_inputs: Dict[str, Optional[torch.Tensor]] = {} + if self.looper.support_batch_quantize and attn_tensor is not None: + additional_inputs["attention_mask"] = attn_tensor + else: + additional_inputs["attention_mask"] = None + + if position_ids: + pos = position_ids[batch_idx] + if pos is not None: + additional_inputs["position_ids"] = move_to(pos, device=exec_device) + + for key, value in layer_input_kwargs[batch_idx].items(): + if key in ["past_key_values", "past_key_value"]: + continue + additional_inputs[key] = nested_move_to(value, device=exec_device) + + if reuse_kv and prev_kv is not None: + additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=exec_device) + + additional_inputs["use_cache"] = False + additional_inputs = self.looper.gptq_model.prepare_layer_replay_kwargs( + layer=module, + layer_input=layer_input, + additional_inputs=additional_inputs, + target_device=exec_device, + ) + + if not preserve_module_devices: + rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True) + + with self._moe_forward_context( + module=module, + processor=processor, + apply_moe_config=apply_moe_config, + ): + module_output = None + try: + if is_lm_head_module: + module_output = module(*layer_input) + else: + module_output = module(*layer_input, **additional_inputs) + except StopForward: + module_output = None + finally: + self.looper._set_processor_mask(processor, None) + + del layer_input + del attn_tensor + del keep_mask + del additional_inputs + + if ( + reuse_kv + and module_output is not None + and isinstance(module_output, tuple) + and len(module_output) > 0 + and shared_kv_cache_dict.get(layer_index) is None + ): + shared_kv_cache_dict[layer_index] = module_output[-1] + + if need_outputs and module_output is not None: + primary = module_output[0] if isinstance(module_output, tuple) else module_output + # Move output back to the same device where input was stored + # This preserves calibration data placement when calibration_data_device is set + calib_device_cfg = self.looper.gptq_model.quantize_config.calibration_data_device + target_device = input_device if calib_device_cfg is not None else cur_layer_device + primary = move_to(primary, device=target_device) + outputs.append([primary]) + + if module_output is not None: + del module_output + + rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0 + if rows_for_batch <= 0: + rows_for_batch = self.looper._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1 + rows_for_batch = max(rows_for_batch, 1) + + processed_rows = min(processed_rows + rows_for_batch, total_rows) + if progress_pb is not None: + if progress_title: + progress_pb.title(progress_title) + progress_pb.current_iter_step = processed_rows + progress_pb.subtitle(f"{stage_label} rows {processed_rows}/{total_rows}").draw() + finally: + processor._set_current_batch_index(None) + + return outputs + + def run_parallel( + self, + *, + module: torch.nn.Module, + processor: "LoopProcessor", + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + need_outputs: bool, + reuse_kv: bool, + devices: List[torch.device], + progress_pb: "ProgressBar" | None = None, + progress_title: Optional[str] = None, + progress_stage: Optional[str] = None, + progress_rows_per_batch: Optional[List[int]] = None, + progress_total_rows: Optional[int] = None, + apply_moe_config: bool = True, + clone_module_for_devices_fn=clone_module_for_devices, + forward_batch_worker_fn=forward_batch_worker, + device_thread_pool=DEVICE_THREAD_POOL, + ) -> List[List[torch.Tensor]]: + """Fan batches across device replicas and preserve result ordering.""" + + effective_title = progress_title or (progress_stage or "Forward") + total_batches, batch_row_counts, total_rows = self._resolve_batch_progress( + processor, + layer_inputs, + progress_rows_per_batch=progress_rows_per_batch, + progress_total_rows=progress_total_rows, + ) + stage_label = progress_stage or "Forward" + + replica_pb: "ProgressBar" | None = None + replica_title = "" + replica_completed = 0 + + if progress_pb is not None: + progress_pb.title(effective_title) + if len(devices) > 1: + replica_title = f"{stage_label}: replicate to {len(devices)} devices" + replica_pb = self.log.pb(range(len(devices))).manual().set(show_left_steps=False) + replica_pb.title(replica_title).subtitle("Staging module...").draw() + else: + device_label = str(devices[0]) if devices else "" + progress_pb.subtitle(f"{stage_label}: staging on {device_label}").draw() + + def _replica_progress(idx: int, total: int, device: torch.device, step: str) -> None: + """Update the progress bar while replicas are materialized.""" + + nonlocal replica_completed + device_label = str(device) + if replica_pb is not None: + if step == "stage": + replica_pb.title(replica_title).subtitle(f"Stage {device_label}").draw() + return + if idx > replica_completed: + replica_completed = idx + replica_pb.title(replica_title).subtitle(f"{device_label} {idx}/{total}").next().draw() + else: + replica_pb.title(replica_title).subtitle(f"{device_label} {idx}/{total}").draw() + elif progress_pb is not None: + stage_msg = ( + f"{stage_label}: staging on {device_label}" + if step == "stage" + else f"{stage_label}: {step} {idx}/{total} on {device_label}" + ) + progress_pb.title(effective_title).subtitle(stage_msg).draw() + + progress_cb = _replica_progress if progress_pb is not None else None + + torch_sync() + + try: + module_replicas = clone_module_for_devices_fn( + module, + devices, + progress_callback=progress_cb, + ) + finally: + if replica_pb is not None: + replica_pb.close() + if progress_pb is not None: + progress_pb.title(effective_title).subtitle(f"{stage_label} rows 0/{total_rows}").draw() + + moe_contexts = [] + try: + for _device, replica in module_replicas.items(): + ctx = self._moe_forward_context( + module=replica, + processor=processor, + apply_moe_config=apply_moe_config, + ) + + if not isinstance(ctx, nullcontext): + ctx.__enter__() + moe_contexts.append(ctx) + + prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None + results: Dict[int, torch.Tensor | None] = {} + processed_rows = 0 + + if self.looper.gptq_model.quantize_config.compute_device_filter is not None: + forward_devices = self.looper.gptq_model.quantize_config.compute_device_filter(devices) + if len(forward_devices) < 1: + self.log.warn( + "compute_device_filter returned empty device list. " + "Using all devices for forward execution." + ) + forward_devices = devices + else: + forward_devices = devices + + device_segments: Dict[torch.device, List[int]] = {} + segment_start = 0 + num_devices = len(forward_devices) + + # Check if balanced mode is active - if so, assign batches where data already resides + calib_device_cfg = self.looper.gptq_model.quantize_config.calibration_data_device + is_balanced_mode = calib_device_cfg == "balanced" + + if is_balanced_mode: + # In balanced mode, assign each batch to the device where its input resides + for device in forward_devices: + device_segments[device] = [] + for batch_idx in range(total_batches): + if layer_inputs[batch_idx]: + batch_device = layer_inputs[batch_idx][0].device + # Check if this device is in our forward_devices, otherwise use first one + if batch_device in device_segments: + device_segments[batch_device].append(batch_idx) + else: + # Fallback: data is on a device not in forward_devices, use round-robin + fallback_device = forward_devices[batch_idx % num_devices] + device_segments[fallback_device].append(batch_idx) + else: + # Default behavior: split batches contiguously across devices + for index, device in enumerate(forward_devices): + remaining_batches = max(total_batches - segment_start, 0) + remaining_devices = max(num_devices - index, 1) + segment_length = remaining_batches // remaining_devices + remainder = remaining_batches % remaining_devices + if remainder > 0: + segment_length += 1 + + if segment_length <= 0: + device_segments[device] = [] + continue + + segment_end = min(segment_start + segment_length, total_batches) + device_segments[device] = list(range(segment_start, segment_end)) + segment_start = segment_end + + max_segment_length = 0 + for indices in device_segments.values(): + if len(indices) > max_segment_length: + max_segment_length = len(indices) + + for position in range(max_segment_length): + futures = [] + for device in forward_devices: + segment_indices = device_segments.get(device, []) + if position >= len(segment_indices): + continue + batch_idx = segment_indices[position] + replica = module_replicas[device] + submitter = ( + device_thread_pool.submit_serial + if device.type in ("cuda", "xpu", "mps") + else device_thread_pool.submit + ) + + futures.append( + submitter( + device, + forward_batch_worker_fn, + replica, + processor, + batch_idx, + layer_inputs[batch_idx], + layer_input_kwargs[batch_idx], + attention_masks[batch_idx], + position_ids[batch_idx] if position_ids else None, + gptq_model=self.looper.gptq_model, + support_batch_quantize=self.looper.support_batch_quantize, + is_lm_head_module=is_lm_head_module, + need_output=need_outputs, + reuse_kv=reuse_kv, + prev_kv=prev_kv, + ) + ) + + for fut in futures: + batch_idx, module_output, kv_next = fut.result() + if need_outputs and module_output is not None: + input_device = layer_inputs[batch_idx][0].device if layer_inputs[batch_idx] else cur_layer_device + target_device = input_device if calib_device_cfg is not None else cur_layer_device + # Move each batch result to its final target device as + # soon as the worker finishes. + primary = module_output[0] if isinstance(module_output, tuple) else module_output + results[batch_idx] = move_to(primary, device=target_device) + del module_output + if reuse_kv and kv_next is not None and shared_kv_cache_dict.get(layer_index) is None: + shared_kv_cache_dict[layer_index] = nested_move_to(kv_next, device=cur_layer_device) + + rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0 + if rows_for_batch <= 0: + rows_for_batch = self.looper._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1 + rows_for_batch = max(rows_for_batch, 1) + + processed_rows = min(processed_rows + rows_for_batch, total_rows) + if progress_pb is not None: + if progress_title: + progress_pb.title(progress_title) + progress_pb.current_iter_step = processed_rows + progress_pb.subtitle(f"{stage_label} rows {processed_rows}/{total_rows}").draw() + finally: + for ctx in moe_contexts: + try: + ctx.__exit__(None, None, None) + except Exception: + pass + moe_contexts.clear() + + for dev in list(module_replicas.keys()): + del module_replicas[dev] + + if not need_outputs: + return [] + + ordered_outputs: List[List[torch.Tensor]] = [] + for idx in range(total_batches): + primary = results.get(idx) + if primary is None: + raise RuntimeError("Forward batch returned no output; data-parallel execution produced empty result.") + ordered_outputs.append([primary]) + + return ordered_outputs diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index f393ff3bc..a035f31da 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -11,15 +11,15 @@ import torch from torch.nn import Module -from ..looper.loop_processor import DTYPE_SIZE_COLUMN, MODULE_FEATURE_COLUMN, LoopProcessor +from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel from ..models._const import CPU from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES) -from ..quantization import GPTAQ, GPTQ -from ..quantization.config import GPTAQConfig, HessianConfig, METHOD, QuantizeConfig -from ..utils.failsafe import normalize_failsafe +from ..quantization import GPTAQ, GPTQ, FOEM +from ..quantization.config import GPTAQConfig, FOEMConfig, HessianConfig, METHOD, QuantizeConfig, resolve_quant_format +from ..utils.fallback import normalize_fallback from ..utils.logger import setup_logger, log_time_block from ..utils.device import get_device from ..utils.model import create_quant_module, find_modules, pack_module @@ -28,7 +28,75 @@ log = setup_logger() lock = threading.Lock() + +def clone_gptq_config_for_module( + qcfg: QuantizeConfig, + module_full_name: str, + *, + fallback=None, +) -> Optional[QuantizeConfig]: + """Clones and applies per-module GPTQ dynamic overrides, or skips the module.""" + + # entire module is skipped + if qcfg.dynamic_get(layer_name=module_full_name) == False: + return None + + qcfg_clone = copy.deepcopy(qcfg) + + # dynamic overrides + if qcfg.dynamic is not None: + qcfg_clone.bits = qcfg.dynamic_get(module_full_name, "bits", qcfg_clone.bits) + qcfg_clone.sym = qcfg.dynamic_get(module_full_name, "sym", qcfg_clone.sym) + qcfg_clone.mse = qcfg.dynamic_get(module_full_name, "mse", qcfg_clone.mse) + qcfg_clone.activation_weighted_mse = qcfg.dynamic_get( + module_full_name, "activation_weighted_mse", qcfg_clone.activation_weighted_mse + ) + + qcfg_clone.group_size = qcfg.dynamic_get(module_full_name, "group_size", qcfg_clone.group_size) + desc_act_override = qcfg.dynamic_get(module_full_name, "desc_act", None) + if desc_act_override is not None: + qcfg_clone.desc_act = desc_act_override + act_group_aware_override = qcfg.dynamic_get(module_full_name, "act_group_aware", None) + if act_group_aware_override is not None: + qcfg_clone.act_group_aware = act_group_aware_override + qcfg_clone.damp_percent = qcfg.dynamic_get(module_full_name, "damp_percent", qcfg_clone.damp_percent) + qcfg_clone.static_groups = qcfg.dynamic_get(module_full_name, "static_groups", qcfg_clone.static_groups) + fallback_override = qcfg.dynamic_get(module_full_name, "fallback", None) + if fallback_override is not None: + qcfg_clone.fallback = normalize_fallback(fallback_override, qcfg_clone.fallback) + hessian_override = qcfg.dynamic_get(module_full_name, "hessian", None) + if hessian_override is not None: + if isinstance(hessian_override, dict): + qcfg_clone.hessian = HessianConfig(**hessian_override) + elif isinstance(hessian_override, HessianConfig): + qcfg_clone.hessian = hessian_override + else: + raise ValueError("QuantizeConfig: dynamic `hessian` must be a HessianConfig or dict.") + gptaq_override = qcfg.dynamic_get(module_full_name, "gptaq", None) + foem_override = qcfg.dynamic_get(module_full_name, "foem", None) + if gptaq_override is not None: + if isinstance(gptaq_override, dict): + qcfg_clone.gptaq = GPTAQConfig(**gptaq_override) + elif isinstance(gptaq_override, GPTAQConfig): + qcfg_clone.gptaq = gptaq_override + else: + raise ValueError("QuantizeConfig: dynamic `gptaq` must be a GPTAQConfig or dict.") + if foem_override is not None: + if isinstance(foem_override, dict): + qcfg_clone.foem = FOEMConfig(**foem_override) + elif isinstance(foem_override, FOEMConfig): + qcfg_clone.foem = foem_override + else: + raise ValueError("QuantizeConfig: dynamic `foem` must be a FOEMConfig or dict.") + + qcfg_clone._resolve_activation_ordering(desc_act_override, act_group_aware_override) + + qcfg_clone.fallback = normalize_fallback(fallback, qcfg_clone.fallback) + return qcfg_clone + class GPTQProcessor(LoopProcessor): + """Captures activations and quantizes modules with GPTQ or GPTAQ/FOEM.""" + def __init__( self, tokenizer, @@ -42,6 +110,7 @@ def __init__( calculate_w_wq_diff: bool = False, calibration_concat_separator: Optional[str] = None, ): + """Initializes GPTQ processing and optional weight-delta tracking.""" super().__init__( tokenizer=tokenizer, @@ -52,74 +121,42 @@ def __init__( calibration_concat_separator=calibration_concat_separator, prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd, - fwd_after_process=True, - subset_forward_early_stop=True, + execution_config=ExecutionConfig( + require_fwd=require_fwd, + fwd_replay_after_process=True, + subset_forward_early_stop=True, + ), ) self.calculate_w_wq_diff = calculate_w_wq_diff self.avg_losses = [] def set_calibration_dataset(self, calibration_dataset): - raise NotImplementedError("GPTQProcessor's calibration_dataset cannot be modified") - - def preprocess(self, module: NamedModule, failsafe=None, **kwargs): - # entire module is skipped - if self.qcfg.dynamic_get(layer_name=module.full_name) == False: - return + """Rejects dataset replacement because GPTQ capture is fixed at construction.""" - qcfg_clone = copy.deepcopy(self.qcfg) + raise NotImplementedError("GPTQProcessor's calibration_dataset cannot be modified") - # dynamic overrides - if self.qcfg.dynamic is not None: - qcfg_clone.bits = self.qcfg.dynamic_get(module.full_name, "bits", qcfg_clone.bits) - qcfg_clone.sym = self.qcfg.dynamic_get(module.full_name, "sym", qcfg_clone.sym) - qcfg_clone.mse = self.qcfg.dynamic_get(module.full_name, "mse", qcfg_clone.mse) - qcfg_clone.activation_weighted_mse = self.qcfg.dynamic_get( - module.full_name, - "activation_weighted_mse", - qcfg_clone.activation_weighted_mse, - ) + def preprocess(self, module: NamedModule, fallback=None, **kwargs): + """Builds the per-module GPTQ/GPTAQ/FOEM task after applying dynamic overrides.""" - qcfg_clone.group_size = self.qcfg.dynamic_get(module.full_name, "group_size", qcfg_clone.group_size) - desc_act_override = self.qcfg.dynamic_get(module.full_name, "desc_act", None) - if desc_act_override is not None: - qcfg_clone.desc_act = desc_act_override - act_group_aware_override = self.qcfg.dynamic_get(module.full_name, "act_group_aware", None) - if act_group_aware_override is not None: - qcfg_clone.act_group_aware = act_group_aware_override - qcfg_clone.damp_percent = self.qcfg.dynamic_get(module.full_name, "damp_percent", qcfg_clone.damp_percent) - qcfg_clone.static_groups = self.qcfg.dynamic_get(module.full_name, "static_groups", qcfg_clone.static_groups) - failsafe_override = self.qcfg.dynamic_get(module.full_name, "failsafe", None) - if failsafe_override is not None: - qcfg_clone.failsafe = normalize_failsafe(failsafe_override, qcfg_clone.failsafe) - hessian_override = self.qcfg.dynamic_get(module.full_name, "hessian", None) - if hessian_override is not None: - if isinstance(hessian_override, dict): - qcfg_clone.hessian = HessianConfig(**hessian_override) - elif isinstance(hessian_override, HessianConfig): - qcfg_clone.hessian = hessian_override - else: - raise ValueError("QuantizeConfig: dynamic `hessian` must be a HessianConfig or dict.") - gptaq_override = self.qcfg.dynamic_get(module.full_name, "gptaq", None) - if gptaq_override is not None: - if isinstance(gptaq_override, dict): - qcfg_clone.gptaq = GPTAQConfig(**gptaq_override) - elif isinstance(gptaq_override, GPTAQConfig): - qcfg_clone.gptaq = gptaq_override - else: - raise ValueError("QuantizeConfig: dynamic `gptaq` must be a GPTAQConfig or dict.") - - qcfg_clone._resolve_activation_ordering(desc_act_override, act_group_aware_override) + qcfg_clone = clone_gptq_config_for_module( + self.qcfg, + module.full_name, + fallback=fallback, + ) + if qcfg_clone is None: + return # store last used qcfg_dynamic self.qcfg_dynamic = qcfg_clone if qcfg_clone.gptaq is not None: tmp = GPTAQ(module=module, qcfg=qcfg_clone) + elif qcfg_clone.foem is not None: + tmp = FOEM(module=module, qcfg=qcfg_clone) else: tmp = GPTQ(module=module, qcfg=qcfg_clone) - tmp.failsafe = normalize_failsafe(failsafe, qcfg_clone.failsafe) + tmp.fallback = qcfg_clone.fallback tmp.expected_nsamples = getattr(self, "total_calibration_tokens", None) tmp.quantizer.configure( @@ -128,6 +165,8 @@ def preprocess(self, module: NamedModule, failsafe=None, **kwargs): self.tasks[module.name] = tmp def is_skipped(self, module: NamedModule) -> bool: + """Reports whether preprocessing omitted this module from GPTQ work.""" + # gptq has no dynamic method of full override (removal) t = self.tasks.get(module.name, False) if t == False: @@ -136,7 +175,11 @@ def is_skipped(self, module: NamedModule) -> bool: return False def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Returns the forward hook that feeds captured batches into the GPTQ task.""" + def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): + """Records one activation batch for GPTQ Hessian/statistics accumulation.""" + g = self.tasks[name] # noqa: F821 batch_idx = self.current_batch_index() g.add_batch(inp[0].data, out.data, batch_index=batch_idx) # noqa: F821 @@ -152,10 +195,12 @@ def process( subset_index: Optional[int] = None, subset_total: Optional[int] = None, ): + """Runs GPTQ quantization for one module and stores pack-ready tensors.""" + # Reset peak memory stats #torch.cuda.reset_peak_memory_stats() base_title = f"Quantizing {module.name} in layer" - self._pause_controller.register_and_draw_progress_bar(self.pb, title=base_title, subtitle="") + self.draw_progress(base_title) # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") ## Need to return the quantized_weight for offloading @@ -322,6 +367,8 @@ def process( # submodule_finalized is called in reverse after all next sequential processes are called def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Creates the quantized module and packs the saved GPTQ tensors into it.""" + # generate complete, safe to move to cpu # module.weight.data = move_to(module.state.pop("wq"), device=CPU) # large weights is slow to init on cpu @@ -361,7 +408,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): create_quant_module( name=module.full_name, linear_cls=model.qlinear_kernel, - bits=self.qcfg.bits, + bits=self.qcfg.runtime_bits, desc_act=self.qcfg.desc_act, dynamic=self.qcfg.dynamic, group_size=self.qcfg.group_size, @@ -371,6 +418,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): device=self.qcfg.device, lm_head_name=model.lm_head, pack_dtype=self.qcfg.pack_dtype, + format=resolve_quant_format(self.qcfg.format, self.qcfg.method), register_buffers=False, ) if timer is not None and create_start is not None: @@ -419,25 +467,46 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): module.unregister_parameter("weight") def finalize(self, model: BaseQModel, **kwargs): + """Marks the model as GPTQ-quantized and runs shared finalization logic.""" + # print("finalize") # print_module_tree(model.model) # set quantized state model.quantized = True - model.quantize_config.quant_method = METHOD.GPTQ + model.quantize_config.method = METHOD.GPTQ super().finalize(model=model, **kwargs) def verify_calibration_dataset(self, processor_index: int) -> bool: + """Ensures GPTQ received calibration data before the quantization loop starts.""" + if self.calibration_dataset is None: raise ValueError("GPTQProcessor's calibration_dataset must be provided.") else: return True def name(self) -> str: + """Returns `gptaq` when GPTAQ overrides are active, otherwise `gptq`.""" + # TODO fix me..this hacks inherited base class logic, why not override name in gptaq? qcfg = self.qcfg_dynamic if self.qcfg_dynamic is not None else self.qcfg - return "gptaq" if qcfg.gptaq is not None else "gptq" + if qcfg.gptaq is not None: + return "gptaq" + if qcfg.foem is not None: + return "foem" + else: + return "gptq" + + def _release_host_buffers(self, *tensors: torch.Tensor) -> None: + """Retain the old cleanup hook for streaming tests and external callers. + + Host buffers are now owned by the stream ticket lifecycle instead of a + dedicated GPTQProcessor pool, so release is intentionally a no-op. + """ + _ = tensors def has_captured_input_ids(self, name: str) -> bool: + """Reports whether the module saw at least one captured forward batch.""" + return self.tasks[name].fwd_counter > 0 diff --git a/gptqmodel/looper/input_cache.py b/gptqmodel/looper/input_cache.py index 09ae37b57..9b23dc745 100644 --- a/gptqmodel/looper/input_cache.py +++ b/gptqmodel/looper/input_cache.py @@ -11,12 +11,16 @@ @dataclass class InputCache: + """Stores captured layer inputs and per-batch kwargs for replayed forwards.""" + layer_inputs: List[List[torch.Tensor]] layer_input_kwargs: List[Dict[str, torch.Tensor]] position_ids: List[torch.Tensor] attention_masks: List[torch.Tensor] def module_kwargs(self): + """Returns the replay kwargs that are shared across cached module calls.""" + result = dict() result["position_ids"] = self.position_ids result["attention_masks"] = self.attention_masks diff --git a/gptqmodel/looper/linear_mode.py b/gptqmodel/looper/linear_mode.py index 027f3783b..ce00a7180 100644 --- a/gptqmodel/looper/linear_mode.py +++ b/gptqmodel/looper/linear_mode.py @@ -8,5 +8,7 @@ class LinearMode(str, Enum): + """Selects whether wrapped linear modules run inference or training behavior.""" + INFERENCE = "inference" TRAIN = "train" diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py index f7698d3fc..0daf66cd5 100644 --- a/gptqmodel/looper/loop_processor.py +++ b/gptqmodel/looper/loop_processor.py @@ -3,7 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium import json +import os import threading +from dataclasses import dataclass from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Set, Tuple @@ -49,8 +51,45 @@ "dynamic", ] + +@dataclass +class ExecutionConfig: + """Describe how a processor participates in forward replay and activation capture. + + Processor defaults: + + +--------------+------+------------------------+---------------------------+----------------+--------------------+ + | Processor | fwd? | replay_after_process? | single_pass_all_modules? | early_stop? | activation_capture?| + +--------------+------+------------------------+---------------------------+----------------+--------------------+ + | GPTQ | yes | yes | no | yes | no | + | AWQ | yes | yes | no | yes | yes | + | ParoQuant | yes | yes | no | yes | yes | + | Native | yes | no | yes | no | no | + | WeightOnly | no | no | no/default | no/default | no | + +--------------+------+------------------------+---------------------------+----------------+--------------------+ + """ + + # Whether the processor needs forward replay at all. + require_fwd: bool = True + # Whether the layer outputs become authoritative only after a replay that + # runs after process(), rather than from the subset forward itself. + fwd_replay_after_process: bool = True + # Whether layer modules are replayed in one combined pass instead of per subset. + fwd_all_modules_in_single_pass: bool = False + # Whether a subset forward can stop as soon as the last subset hook fires. + subset_forward_early_stop: bool = False + # Whether capture-only modules/hooks (for example ':?') should be enabled. + enable_activation_capture: bool = False + # Whether the processor needs the original layer IO from the pre-process + # forward, even when the authoritative next-layer cache is produced by a + # later replay-after-process step. + capture_layer_forward_context: bool = False + + # LoopProcessor is a singleton(), not per module instance class LoopProcessor: + """Base lifecycle coordinator shared by all quantization processors.""" + def __init__( self, tokenizer, qcfg: QuantizeConfig, @@ -60,12 +99,10 @@ def __init__( calibration_sort: Optional[str] = None, calibration_concat_separator: Optional[str] = None, batch_size: int = 1, - require_fwd: bool = True, - fwd_after_process: bool = True, - fwd_all_modules_in_single_pass: bool = False, - subset_forward_early_stop: bool = False, - enable_activation_capture_flag: bool = False, + execution_config: Optional[ExecutionConfig] = None, ): + """Initializes shared processor state, logging, and calibration bookkeeping.""" + # process level lock self.lock = threading.Lock() @@ -77,24 +114,9 @@ def __init__( self.qcfg = qcfg self.qcfg_dynamic = None # cloned and dynamic filtered - # TODO FIX ME: dequantize processor sets this to False but it is nver acted on! - # if processor require fwd generate and hooks, set this to true - # looper should bypass generate + hooks if this is false - self.require_fwd = require_fwd # default True - - # after process(), do we need to forward again? paired with require_fwd == True - # if true, forward output is captured post process() and saved for next loop as input - # if false, forward output before process() call is saved for next loop as input - self.fwd_after_process = fwd_after_process # default True - - # native processor does not need to forward N times due to module depend segmentation - # if true, fwd is repeated based on module dep sub-groups - # if false, sub-module groups are merged as one and fwd happens in one pass - self.fwd_all_modules_in_single_pass = fwd_all_modules_in_single_pass # default False - # when True, stop the layer forward immediately after the final module in a subset fires - self.subset_forward_early_stop = subset_forward_early_stop - # enable capture-only hooks (e.g. ':?') for processors that require activations - self.enable_activation_capture = enable_activation_capture_flag + # Keep lifecycle and replay policy in one object so the stages consume + # one execution mode instead of a scattered set of booleans. + self.execution_config = execution_config or ExecutionConfig() self.inputs_cache: InputCache = InputCache(None, None, None, None) self.tasks = {} @@ -181,8 +203,17 @@ def __init__( # GPTQ's Hessian updates) even when forwards run on multiple threads. self._batch_tls = threading.local() + def draw_progress(self, title: str, subtitle: str = "") -> None: + """Best-effort progress-bar redraw for processors with an attached progress handle.""" + + if self.pb is None: + return + self.pb.title(title).subtitle(subtitle).draw() + @staticmethod def _compute_total_tokens(calibration_dataset) -> int: + """Counts total calibration tokens using masks when available.""" + if not calibration_dataset: return 0 total = 0 @@ -212,6 +243,8 @@ def _compute_total_tokens(calibration_dataset) -> int: return total def _set_current_batch_index(self, batch_index: Optional[int]) -> None: + """Stores the active calibration batch index in thread-local state.""" + if batch_index is None: if hasattr(self._batch_tls, "index"): delattr(self._batch_tls, "index") @@ -219,18 +252,26 @@ def _set_current_batch_index(self, batch_index: Optional[int]) -> None: self._batch_tls.index = int(batch_index) def current_batch_index(self) -> Optional[int]: + """Returns the thread-local calibration batch index, if one is set.""" + return getattr(self._batch_tls, "index", None) def _async_log_writer(self, stat): + """Appends one serialized log record to the processor temp log file.""" + with open(self.log_tmp_log_file_name, 'a') as f: json.dump(stat, f, indent=4) f.write("\n") def log_save_async(self, stat): + """Schedules asynchronous log persistence on the serial CPU worker.""" + # Serialize writes on the CPU-bound worker to avoid interleaved JSON output. DEVICE_THREAD_POOL.submit_serial(CPU, self._async_log_writer, stat) def log_new_row(self, stat): + """Prints a formatted log row and queues it for async persistence.""" + with self.lock: self.log_call_count += 1 columns_rebuilt = self._ensure_log_columns(stat) @@ -247,9 +288,23 @@ def log_new_row(self, stat): ] self._log_columns.info(*row_values) + # Emit a plain-text summary when debugging quantization quality in test runs. + if os.getenv("GPTQMODEL_LOG_QUANT_STATS", "0") not in ("", "0", "false", "False"): + log.info( + "Quant stat: method=%s layer=%s module=%s samples=%s loss=%s time=%s", + stat.get(PROCESS_LOG_NAME, ""), + stat.get(PROCESS_LOG_LAYER, ""), + stat.get(PROCESS_LOG_MODULE, ""), + stat.get(QUANT_LOG_NSAMPLES, ""), + stat.get(QUANT_LOG_LOSS, ""), + stat.get(PROCESS_LOG_TIME, ""), + ) + self.log_save_async(stat) def loss_color(self, loss_value: float) -> ANSIColor: + """Maps a quantization loss value to a terminal highlight color.""" + if loss_value <= 0.1: return ANSIColor.GREEN elif loss_value <= 1: @@ -262,6 +317,8 @@ def loss_color(self, loss_value: float) -> ANSIColor: return ANSIColor.BRIGHT_RED def _ensure_log_columns(self, stat: Dict[str, Any]) -> bool: + """Expands the CLI log table to include any new stat keys.""" + desired_labels = list(DEFAULT_LOG_COLUMNS) for key in stat.keys(): if key not in desired_labels: @@ -281,6 +338,8 @@ def _ensure_log_columns(self, stat: Dict[str, Any]) -> bool: return True def _format_log_value(self, key: str, value: Any, stat: Dict[str, Any]) -> str: + """Formats one log cell, applying colors to loss and sample counts.""" + text = "" if value is None else str(value) if key == QUANT_LOG_LOSS and text: @@ -288,7 +347,7 @@ def _format_log_value(self, key: str, value: Any, stat: Dict[str, Any]) -> str: try: color_code = self.loss_color(float(text)) except (TypeError, ValueError): - if cleaned.endswith("failsafe") or cleaned.startswith("failsafe("): + if cleaned.endswith("fallback") or cleaned.startswith("fallback("): return color_text(text, ANSIColor.ORANGE) return text return color_text(text, color_code) @@ -306,6 +365,8 @@ def _format_log_value(self, key: str, value: Any, stat: Dict[str, Any]) -> str: return text def _samples_color(self, samples_value: float, stat: Dict[str, Any]) -> Optional[ANSIColor]: + """Colors sample counts relative to method-specific adequacy thresholds.""" + quant_method = str(stat.get(PROCESS_LOG_NAME, "")).lower() divisor = 10.0 if quant_method.startswith("awq") else 1.0 @@ -326,6 +387,8 @@ def _samples_color(self, samples_value: float, stat: Dict[str, Any]) -> Optional return ANSIColor.RED def module_feature_summary(self, module: NamedModule) -> str: + """Formats cached input/output feature sizes for log display.""" + in_features = module.state.get("in_features") out_features = module.state.get("out_features") @@ -334,6 +397,8 @@ def module_feature_summary(self, module: NamedModule) -> str: return "" def module_dtype_size_summary(self, module: NamedModule) -> str: + """Formats dtype and total persistent tensor footprint for a module.""" + weight = getattr(module.module, "weight", None) dtype = getattr(weight, "dtype", None) total_bytes = 0 @@ -361,6 +426,8 @@ def module_dtype_size_summary(self, module: NamedModule) -> str: return f"{dtype_label}: {size_mb:.1f}MB" def _state_tensor_bytes(self, module: NamedModule) -> int: + """Counts bytes held by tensor-like entries in `module.state`.""" + seen: Set[int] = set() total = 0 for key, value in module.state.items(): @@ -370,6 +437,8 @@ def _state_tensor_bytes(self, module: NamedModule) -> int: return total def _collect_tensor_bytes(self, obj: Any, seen: Set[int]) -> int: + """Recursively sums tensor storage while avoiding double-counting aliases.""" + if isinstance(obj, torch.Tensor): obj_id = id(obj) if obj_id in seen: @@ -393,6 +462,8 @@ def _collect_tensor_bytes(self, obj: Any, seen: Set[int]) -> int: return 0 def _format_dtype(self, dtype: Optional[torch.dtype]) -> str: + """Shortens dtype names for compact table output.""" + if dtype is None: return "n/a" @@ -409,6 +480,8 @@ def _format_dtype(self, dtype: Optional[torch.dtype]) -> str: return dtype_alias.get(dtype_str, dtype_str) def _init_device_smi_handles(self) -> Dict[str, Device]: + """Creates Device-SMI handles for all discovered accelerator devices.""" + handles: Dict[str, Device] = {} for device_id in self._discover_accelerator_devices(): @@ -420,6 +493,8 @@ def _init_device_smi_handles(self) -> Dict[str, Device]: return handles def _init_cpu_device_handle(self) -> Optional[Device]: + """Creates the optional Device-SMI handle used for CPU memory tracking.""" + try: return Device("cpu") except Exception as exc: # pragma: no cover - defensive, external tool @@ -427,6 +502,8 @@ def _init_cpu_device_handle(self) -> Optional[Device]: return None def _discover_accelerator_devices(self) -> List[str]: + """Lists CUDA/ROCm/XPU device identifiers visible to the runtime.""" + devices: List[str] = [] if hasattr(torch, "cuda"): @@ -450,6 +527,8 @@ def _discover_accelerator_devices(self) -> List[str]: return devices def _safe_query_metric(self, device_key: str, handle: Device): + """Queries Device-SMI metrics once per device, suppressing repeated failures.""" + try: return handle.metrics(fast=True) except Exception as exc: # pragma: no cover - defensive, external tool @@ -459,6 +538,8 @@ def _safe_query_metric(self, device_key: str, handle: Device): return None def _snapshot_device_memory_gib(self) -> Dict[str, float]: + """Captures current accelerator memory usage in GiB per device.""" + snapshot: Dict[str, float] = {} for device_id, handle in self._device_smi_handles.items(): metrics = self._safe_query_metric(device_id, handle) @@ -468,6 +549,8 @@ def _snapshot_device_memory_gib(self) -> Dict[str, float]: return snapshot def _snapshot_cpu_memory_gib(self) -> Optional[float]: + """Captures current CPU memory usage in GiB when supported.""" + if self._cpu_device_smi is None: return None metrics = self._safe_query_metric("cpu", self._cpu_device_smi) @@ -476,11 +559,15 @@ def _snapshot_cpu_memory_gib(self) -> Optional[float]: return metrics.memory_used / (1024 ** 3) def device_memory_report(self) -> str: + """Formats current accelerator memory usage for processor log rows.""" + snapshot = self._snapshot_device_memory_gib() if not snapshot: return "n/a" def _format_gib(value: float) -> str: + """Formats a GiB value without unnecessary trailing zeros.""" + text = f"{value:.2f}" if text.endswith("00"): text = text[:-2] @@ -499,6 +586,8 @@ def _format_gib(value: float) -> str: continue def sort_key(item: Tuple[str, float, int]) -> Tuple[int, int]: + """Sorts indexed devices numerically while preserving fallback order.""" + index, _, order = item try: return 0, int(index) @@ -512,6 +601,8 @@ def sort_key(item: Tuple[str, float, int]) -> Tuple[int, int]: return " | ".join(segments) def _close_device_smi_handles(self) -> None: + """Closes all Device-SMI handles owned by this processor.""" + for handle in self._device_smi_handles.values(): try: handle.close() @@ -528,25 +619,35 @@ def _close_device_smi_handles(self) -> None: # Loop Procssor level scoped state data def result_save(self, key: str, value: Any): + """Stores a processor-scoped result by module key.""" + with self._results_lock: #assert self.result_get(key) is None, f"key: {key} already exists in `self.result`" self._results[key] = value # Loop Procssor level scoped state data def result_get(self, key: str, default: Any = None) -> Any: + """Fetches a processor-scoped result by key.""" + with self._results_lock: return self._results.get(key, default) # Loop Procssor level scoped state data def result_pop(self, key: str, default: Any = None): + """Removes and returns a processor-scoped result by key.""" + with self._results_lock: return self._results.pop(key, default) # Loop Procssor level scoped state data def results(self): + """Returns the full processor result mapping.""" + return self._results def collect_memory_info(self, layer_index: int): + """Records current accelerator and CPU memory snapshots for diagnostics.""" + device_snapshot = self._snapshot_device_memory_gib() if device_snapshot: total_gpu_memory = sum(device_snapshot.values()) @@ -557,39 +658,80 @@ def collect_memory_info(self, layer_index: int): self.cpu_memorys.append(cpu_memory) def log_plotly(self): + """Placeholder for future Plotly-based processor log visualizations.""" + pass def set_calibration_dataset(self, calibration_dataset): + """Override point for processors that need to replace calibration data.""" + pass def set_fwd_time(self, fwd_time: float): + """Stores the latest forward-pass duration for logging.""" + self.fwd_time = fwd_time def formatted_fwd_time(self) -> str: + """Returns the stored forward time as a fixed-width string.""" + fwd_time = self.fwd_time if self.fwd_time is not None else 0.0 return f"{fwd_time:.3f}" # called first def preprocess(self, module: NamedModule, **kwargs): + """Override point for per-module setup before forward capture/processing.""" + pass # after preproces, this process may be skipped due to dynamic override (lora adapter = None) def is_skipped(self, module: NamedModule) -> bool: + """Override point for dynamic per-module skip decisions.""" + pass def receive_input_cache(self, input_cache: InputCache): + """Injects the shared input cache for the current processor stage.""" + self.inputs_cache = input_cache # called after every module generate # may be called multiple times due to batch def receive_layer_inputs(self, layer_inputs: List[List[Tensor]]): + """Replaces cached layer outputs that feed the next loop stage.""" + self.inputs_cache.layer_inputs = layer_inputs + def receive_layer_forward_context( + self, + *, + layer_index: int, + layer_inputs: List[List[Tensor]], + layer_input_kwargs: List[Dict[str, Tensor]], + layer_outputs: List[List[Tensor]], + subset_index: Optional[int] = None, + subset_total: Optional[int] = None, + ) -> None: + """Override point for processors that need original layer IO snapshots.""" + + del ( + layer_index, + layer_inputs, + layer_input_kwargs, + layer_outputs, + subset_index, + subset_total, + ) + def clear_cache_data(self): + """Drops transient task data and cached layer inputs after replay.""" + self.tasks = {} self.inputs_cache.layer_inputs = [] def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Override point for per-module forward hooks used during capture.""" + pass # do work and return processor.self state which will updated/merged @@ -602,17 +744,23 @@ def process( subset_index: Optional[int] = None, subset_total: Optional[int] = None, ): + """Override point for the main per-module quantization or capture step.""" + pass # last step, after all loop processor is called # submodule_finalize is called in reverse after all next sequential processes are called def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Override point for per-module packing/finalization after processing.""" + pass #self.offload_to_disk(module=module) # last step, after all loop processor is called # finalize is called in reverse after all next sequential processes are called def finalize(self, model: BaseQModel, **kwargs): + """Releases shared processor resources after the full quantization loop.""" + self._close_device_smi_handles() del self.inputs_cache del self._results @@ -623,18 +771,28 @@ def finalize(self, model: BaseQModel, **kwargs): # os.remove(file_path) def release_calibration_dataset(self): + """Drops the retained calibration dataset to free host memory.""" + del self.calibration_dataset def number_batches(self) -> int: + """Returns the number of prepared calibration batches.""" + return self.num_batches def verify_calibration_dataset(self, processor_index: int) -> bool: + """Override point for validating or reusing calibration datasets.""" + pass def name(self) -> str: + """Override point for the processor name shown in logs and reports.""" + pass def get_max_memory() -> str: + """Returns current CUDA memory usage for the first one or two devices.""" + stats_0 = torch.cuda.memory_stats(DEVICE_0) active_0 = stats_0.get("active_bytes.all.current", 0) / 1024 ** 2 peak_active_0 = stats_0.get("active_bytes.all.peak", 0) / 1024 ** 2 diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 0a28737cc..12bca08b8 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -19,8 +19,8 @@ import threading import time import logging +import os from concurrent.futures import as_completed -from contextlib import nullcontext from typing import Dict, List, NamedTuple, Optional, TYPE_CHECKING, Any import torch @@ -35,15 +35,14 @@ from ..models import BaseQModel from ..models._const import SUPPORTS_MODULE_TYPES from ..models.base import CAPTURE_ONLY_FLAG -from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, - StopForward, replace_module_with_hooked_legacy) -from ..quantization.config import VramStrategy -from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask +from ..nn_modules.hooked_linear import HookedLinear, replace_module_with_hooked_legacy +from ..quantization.config import METHOD, VramStrategy +from ..utils.attn_mask import apply_keep_mask_bt from ..utils.ctx import ctx +from ..utils.device_telemetry import emit_device_telemetry from ..utils.device import get_device, get_device_new from ..utils.disk import estimate_disk_io_speed from ..utils.logger import setup_logger, log_time_block -from ..utils.pause_resume import PauseResumeController, PauseResumeState from ..utils.looper_helpers import ( clone_module_for_devices, device_ctx, @@ -52,12 +51,14 @@ rehome_module_to_device, select_forward_devices, ) -from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to, nested_move_to, \ - MoETopKState, set_moe_topk, restore_moe_topk +from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to, MoETopKState, set_moe_topk, restore_moe_topk from ..utils.offload import offload_to_disk +from ..utils.python import has_gil_control, has_gil_disabled from ..utils.torch import (CPU, META, timed_gc_collect, torch_sync, tf32_high_precision_guard) from .. import DEVICE_THREAD_POOL from .awq_processor import AWQProcessor +from .forward_executor import ForwardExecutor +from .paroquant_processor import ParoQuantProcessor from .qqq_processor import QQQProcessor from .stage_inputs_capture import StageInputsCapture from .stage_layer import run_layer_stage @@ -72,16 +73,74 @@ class FinalizeProgressInfo(NamedTuple): + """Progress payload for processor finalization reporting.""" + module_label: Optional[str] process_name: str layer_idx: Optional[int] +def _restrict_quant_devices_for_method(method: Any, quant_devices: List[torch.device]) -> List[torch.device]: + """Apply method-specific device constraints for quantization workers.""" + + try: + normalized_method = METHOD(method) if method is not None else None + except (TypeError, ValueError): + normalized_method = None + + if normalized_method != METHOD.PARO or not quant_devices: + return quant_devices + + non_cpu_devices = [device for device in quant_devices if getattr(device, "type", None) != "cpu"] + if non_cpu_devices: + return [non_cpu_devices[0]] + + return quant_devices[:1] + + +def _resolve_strategy_device_pool( + configured_devices: Optional[List[str]], + available_devices: List[torch.device], + *, + label: str, +) -> List[torch.device]: + """Resolve one strategy device pool as a validated subset of available devices.""" + + if not configured_devices: + return list(available_devices) + + available_by_name = { + str(normalize_device_like(device)): normalize_device_like(device) + for device in available_devices + if normalize_device_like(device) is not None + } + resolved: List[torch.device] = [] + for device_name in configured_devices: + normalized = normalize_device_like(device_name) + if normalized is None: + raise ValueError(f"ModuleLooper: {label} device pool contains an unsupported device value: {device_name!r}.") + matched = available_by_name.get(str(normalized)) + if matched is None: + raise ValueError( + f"ModuleLooper: {label} device pool {configured_devices} must be a subset of the visible compute devices " + f"{list(available_by_name.keys())}." + ) + if matched not in resolved: + resolved.append(matched) + + if not resolved: + raise ValueError(f"ModuleLooper: {label} device pool is empty after normalization.") + + return resolved + + class StopMainLoop(Exception): """Signal that the module loop should abort immediately.""" def io_write_performance() -> Optional[float]: + """Estimate and cache sustained disk write throughput in MB/s.""" + global _IO_WRITE_SPEED_MB if _IO_WRITE_SPEED_MB is not None: return _IO_WRITE_SPEED_MB @@ -101,17 +160,14 @@ class ModuleLooper(): reuse the same worker threads. """ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): + """Initialize loop state, device policy, and callback wiring.""" + self.processors = processors self.gptq_model = model - # Initialize pause/resume controller first - self.pause_controller = PauseResumeController() - - # Give processors access to pause controller for status - for processor in self.processors: - processor._pause_controller = self.pause_controller self.support_batch_quantize = model.support_batch_quantize self.lock = threading.Lock() + self._forward_executor = ForwardExecutor(self, logger=log) self._layer_callback = getattr(model, "layer_callback", None) self._loop_stop_event = threading.Event() self._loop_stop_exc: Optional[BaseException] = None @@ -155,27 +211,96 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): "Using all devices for quantization." ) + restricted_quant_devices = _restrict_quant_devices_for_method( + getattr(self.gptq_model.quantize_config, "method", None), + quant_devices, + ) + if restricted_quant_devices != quant_devices: + log.warn( + "ModuleLooper: METHOD.PARO forcing single-device quantization on `%s`; " + "ignoring additional devices %s to avoid multi-GPU sync issues.", + restricted_quant_devices[0], + [str(device) for device in quant_devices if device != restricted_quant_devices[0]], + ) + quant_devices = restricted_quant_devices + self._quant_devices = quant_devices self._quant_device_rr = 0 self._module_device_map: Dict[str, torch.device] = {} self._quant_device_lock = threading.Lock() - vram_strategy = getattr(self.gptq_model.quantize_config, "vram_strategy", VramStrategy.EXCLUSIVE) - if isinstance(vram_strategy, str): + + # Resolve the user-facing split dense/MoE placement settings once at + # looper construction time so subset planning can reuse stable pools. + dense_vram_strategy = getattr(self.gptq_model.quantize_config, "dense_vram_strategy", VramStrategy.EXCLUSIVE) + if isinstance(dense_vram_strategy, str): + try: + dense_vram_strategy = VramStrategy(dense_vram_strategy.lower()) + except ValueError: + dense_vram_strategy = VramStrategy.EXCLUSIVE + supported_dense_strategies = getattr( + self.gptq_model, + "supported_dense_vram_strategies", + [ + VramStrategy.EXCLUSIVE, + VramStrategy.BALANCED, + ], + ) + if isinstance(supported_dense_strategies, VramStrategy): + supported_dense_strategies = [supported_dense_strategies] + if dense_vram_strategy not in supported_dense_strategies: + log.debug( + "ModuleLooper: Model %s does not support dense VRAM strategy %s; falling back to exclusive.", + getattr(self.gptq_model, "__class__", type(self.gptq_model)).__name__, + dense_vram_strategy, + ) + dense_vram_strategy = VramStrategy.EXCLUSIVE + + moe_vram_strategy = getattr( + self.gptq_model.quantize_config, + "moe_vram_strategy", + VramStrategy.EXCLUSIVE, + ) + if isinstance(moe_vram_strategy, str): try: - vram_strategy = VramStrategy(vram_strategy.lower()) + moe_vram_strategy = VramStrategy(moe_vram_strategy.lower()) except ValueError: - vram_strategy = VramStrategy.EXCLUSIVE - supported_strategies = getattr(self.gptq_model, "supported_vram_strategies", [VramStrategy.EXCLUSIVE, VramStrategy.BALANCED]) - if isinstance(supported_strategies, VramStrategy): - supported_strategies = [supported_strategies] - if vram_strategy not in supported_strategies: + moe_vram_strategy = VramStrategy.EXCLUSIVE + supported_moe_strategies = getattr( + self.gptq_model, + "supported_moe_vram_strategies", + [ + VramStrategy.EXCLUSIVE, + VramStrategy.BALANCED, + ], + ) + if isinstance(supported_moe_strategies, VramStrategy): + supported_moe_strategies = [supported_moe_strategies] + if moe_vram_strategy not in supported_moe_strategies: log.debug( - "ModuleLooper: Model %s does not support VRAM strategy %s; falling back to exclusive.", + "ModuleLooper: Model %s does not support MoE VRAM strategy %s; falling back to exclusive.", getattr(self.gptq_model, "__class__", type(self.gptq_model)).__name__, - vram_strategy, + moe_vram_strategy, ) - vram_strategy = VramStrategy.EXCLUSIVE - self._vram_strategy = vram_strategy + moe_vram_strategy = VramStrategy.EXCLUSIVE + + self._dense_vram_strategy = dense_vram_strategy + self._moe_vram_strategy = moe_vram_strategy + dense_strategy_devices = getattr(self.gptq_model.quantize_config, "dense_vram_strategy_devices", None) + moe_strategy_devices = getattr(self.gptq_model.quantize_config, "moe_vram_strategy_devices", None) + self._dense_quant_devices = _resolve_strategy_device_pool( + dense_strategy_devices, + quant_devices, + label="dense_vram_strategy_devices", + ) + self._moe_quant_devices = _resolve_strategy_device_pool( + moe_strategy_devices, + quant_devices, + label="moe_vram_strategy_devices", + ) + # Keep a cheap flag so the planner can skip split-pool logic entirely + # when the user leaves a pool on the default exclusive behavior. + self._dense_vram_strategy_explicit = bool(dense_strategy_devices) or self._dense_vram_strategy != VramStrategy.EXCLUSIVE + self._moe_vram_strategy_explicit = bool(moe_strategy_devices) or self._moe_vram_strategy != VramStrategy.EXCLUSIVE self._moe_subset_threshold = 16 self._subset_callback = getattr(self.gptq_model, "subset_callback", None) @@ -190,10 +315,64 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): else: self.moe_routing_override = None self.moe_routing_bypass = self.gptq_model.quantize_config.moe_routing_bypass() + self._emit_moe_parallel_quant_runtime() for processor in self.processors: self._processor_mask_tls(processor) + def _emit_moe_parallel_quant_runtime(self) -> None: + """Log the runtime knobs that decide whether MoE quant can fan out efficiently.""" + + if not getattr(self.gptq_model, "dynamic_expert_index", None): + return + + dense_devices = [str(device) for device in self._dense_quant_devices] + moe_devices = [str(device) for device in self._moe_quant_devices] + gil_env = os.environ.get("PYTHON_GIL") + gil_disabled = has_gil_disabled() + gil_controllable = has_gil_control() + routing_mode = ( + "override" + if self.moe_routing_override is not None + else "bypass" + if self.moe_routing_bypass + else "native" + ) + free_threaded_parallel_quant_eligible = bool(gil_disabled and len(self._moe_quant_devices) > 0) + + log.info( + "ModuleLooper: MoE quant runtime dense_pool=%s moe_pool=%s routing_mode=%s routing_override=%s " + "PYTHON_GIL=%s gil_disabled=%s free_threaded_parallel_quant_eligible=%s", + dense_devices, + moe_devices, + routing_mode, + self.moe_routing_override, + gil_env, + gil_disabled, + free_threaded_parallel_quant_eligible, + ) + if moe_devices and gil_controllable and not gil_disabled: + log.warn( + "ModuleLooper: MoE quant is configured for device fan-out across %s but Python GIL is still enabled; " + "rerun with PYTHON_GIL=0 to unlock the free-threaded parallel quant path.", + moe_devices, + ) + + emit_device_telemetry( + "moe_parallel_quant_runtime", + dense_devices=dense_devices, + moe_devices=moe_devices, + dense_strategy=self._dense_vram_strategy, + moe_strategy=self._moe_vram_strategy, + routing_mode=routing_mode, + routing_override=self.moe_routing_override, + routing_bypass=self.moe_routing_bypass, + python_gil_env=gil_env, + python_gil_controllable=gil_controllable, + python_gil_disabled=gil_disabled, + free_threaded_parallel_quant_eligible=free_threaded_parallel_quant_eligible, + ) + class MoERoutingOverrideContext: """ Context manager that temporarily overrides MoE routing top-k. @@ -203,6 +382,8 @@ class MoERoutingOverrideContext: """ def __init__(self, model, moe_routing_override: int): + """Capture the model and temporary top-k override to apply.""" + # Model containing MoE routing modules self.model = model # Target top-k value for per-token expert routing @@ -211,12 +392,16 @@ def __init__(self, model, moe_routing_override: int): self._state: MoETopKState | None = None def __enter__(self): + """Apply the temporary routing override before the forward pass.""" + # Apply routing override if specified if self.moe_routing_override: self._state = set_moe_topk(self.model, self.moe_routing_override) return self def __exit__(self, exc_type, exc, tb): + """Restore the original routing state when leaving the context.""" + # Restore original routing configuration if self.moe_routing_override: restore_moe_topk(self._state) @@ -226,6 +411,8 @@ class MoELifecycleContext: """Context manager for MoE lifecycle hooks integration.""" def __init__(self, module_looper, module, processor, current_subset): + """Capture the replica state needed to patch the MoE block.""" + self.module_looper = module_looper self.module = module self.processor = processor @@ -248,6 +435,8 @@ def __enter__(self): moe_block_prefix = hooks._extract_moe_block_prefix(self.current_subset, self.moe_block) def moe_forward_wrapper(hidden_states, **kwargs): + """Route the replica forward through the all-experts hook.""" + return hooks.forward_to_all_experts( moe_block=self.moe_block, hidden_states=hidden_states, @@ -282,6 +471,8 @@ def register_subset_callback(self, callback) -> None: self._subset_callback = callback def register_dangling_thread(self, watcher: threading.Thread) -> None: + """Track a watcher thread that should be joined before exit.""" + with self._dangling_threads_lock: if self._dangling_threads: self._dangling_threads = [ @@ -290,6 +481,8 @@ def register_dangling_thread(self, watcher: threading.Thread) -> None: self._dangling_threads.append(watcher) def wait_dangling_threads(self) -> None: + """Join any still-running watcher threads and clear the registry.""" + with self._dangling_threads_lock: threads = list(self._dangling_threads) self._dangling_threads.clear() @@ -298,6 +491,8 @@ def wait_dangling_threads(self) -> None: thread.join() def _resolve_layer_callback(self): + """Resolve the active layer-complete callback using legacy fallbacks.""" + for candidate in ( getattr(self, "_layer_callback", None), getattr(self, "layer_callback", None), @@ -310,6 +505,8 @@ def _resolve_layer_callback(self): return None def _resolve_subset_callback(self): + """Resolve the active subset callback from looper or model state.""" + for candidate in ( getattr(self, "_subset_callback", None), getattr(self, "subset_callback", None), @@ -320,6 +517,8 @@ def _resolve_subset_callback(self): return None def callbackup(self, layer_idx: int, submodule_finalized: bool): + """Invoke the layer callback and normalize stop-loop responses.""" + callback = self._resolve_layer_callback() if callback is None: return None @@ -350,6 +549,8 @@ def _subset_event_dispatch( module_names: List[str], processor: str, ) -> None: + """Emit a subset event immediately and surface callback failures.""" + self._emit_subset_event( stage=stage, layer_idx=layer_idx, @@ -361,12 +562,16 @@ def _subset_event_dispatch( ) def _request_loop_stop(self, exc: Optional[BaseException]) -> None: + """Record the first stop reason and signal loop shutdown.""" + with self.lock: if self._loop_stop_exc is None and exc is not None: self._loop_stop_exc = exc self._loop_stop_event.set() def _check_loop_stop(self) -> bool: + """Drain outstanding work and re-raise any recorded stop signal.""" + if not self._loop_stop_event.is_set(): return False if not self._loop_stop_waited: @@ -387,6 +592,8 @@ def _emit_subset_event( processor: str, raise_in_place: bool, ) -> None: + """Forward a subset lifecycle event to the configured callback.""" + callback = self._resolve_subset_callback() if callback is None: return @@ -437,6 +644,8 @@ def _emit_layer_complete( *, raise_in_place: bool, ) -> None: + """Notify listeners that a layer finished and handle stop requests.""" + try: self.callbackup(layer_idx=layer_idx, submodule_finalized=submodule_finalized) except StopMainLoop: @@ -455,6 +664,8 @@ def _emit_layer_complete( # Processors capture activations through hooks that need thread-local state # so masks survive the roundtrip to worker threads. def _processor_mask_tls(self, processor: LoopProcessor) -> threading.local: + """Get or create thread-local storage for the active keep mask.""" + tls = getattr(processor, "_mask_tls", None) if tls is None: tls = threading.local() @@ -478,14 +689,20 @@ def _get_processor_hooks_paused(self, processor: LoopProcessor) -> bool: return getattr(tls, "value", False) if tls else False def _set_processor_mask(self, processor: LoopProcessor, mask): + """Store the active sequence mask for the current worker thread.""" + tls = self._processor_mask_tls(processor) tls.value = mask def _get_processor_mask(self, processor: LoopProcessor): + """Return the sequence mask bound to the current worker thread.""" + tls = getattr(processor, "_mask_tls", None) return getattr(tls, "value", None) if tls else None def _safe_len(self, sequence) -> Optional[int]: + """Return ``len(sequence)`` when the object exposes a safe length.""" + if sequence is None: return None try: @@ -494,6 +711,8 @@ def _safe_len(self, sequence) -> Optional[int]: return None def _coerce_to_int(self, value) -> Optional[int]: + """Best-effort conversion for scalar-like values used in counters.""" + if value is None: return None if isinstance(value, bool): @@ -522,6 +741,8 @@ def _coerce_to_int(self, value) -> Optional[int]: return None def _resolve_batch_total(self, raw_count, fallback_sequence) -> int: + """Resolve a non-negative batch count from explicit or inferred input.""" + count = self._coerce_to_int(raw_count) fallback_len = self._safe_len(fallback_sequence) fallback = self._coerce_to_int(fallback_len) @@ -540,6 +761,8 @@ def _resolve_batch_total(self, raw_count, fallback_sequence) -> int: return 0 def _batch_row_count(self, batch_inputs: Optional[List[torch.Tensor]]) -> int: + """Infer how many rows a cached batch contributes to progress.""" + if not batch_inputs: return 0 @@ -558,6 +781,8 @@ def _batch_row_count(self, batch_inputs: Optional[List[torch.Tensor]]) -> int: return 0 def _collect_row_counts(self, layer_inputs: Optional[List[List[torch.Tensor]]]) -> List[int]: + """Collect per-batch row counts for progress tracking.""" + if not layer_inputs: return [] @@ -568,6 +793,8 @@ def _collect_row_counts(self, layer_inputs: Optional[List[List[torch.Tensor]]]) return counts def _extract_moe_group_key(self, module_name: Optional[str]) -> Optional[str]: + """Collapse expert module names into a stable MoE routing group key.""" + if not module_name: return None @@ -589,6 +816,8 @@ def _extract_moe_group_key(self, module_name: Optional[str]) -> Optional[str]: return None def _is_attention_module_name(self, module_name: str) -> bool: + """Heuristically detect attention modules from their qualified name.""" + if not module_name: return False @@ -642,13 +871,28 @@ def _assign_quant_device_for_module( named_module: NamedModule, fallback_device: torch.device, ) -> torch.device: + """Pick and memoize the quantization device for one named module.""" + key = getattr(named_module, "full_name", None) or named_module.name with self._quant_device_lock: cached = self._module_device_map.get(key) if cached is not None: + emit_device_telemetry( + "quant_device_cache_hit", + module=key, + target_device=cached, + ) return cached device: Optional[torch.device] - if len(self._quant_devices) <= 1: + preferred_device = normalize_device_like(named_module.state.get("preferred_quant_device")) + if preferred_device is not None and any(dev == preferred_device for dev in self._quant_devices if dev is not None): + device = preferred_device + emit_device_telemetry( + "quant_device_preferred_hint", + module=key, + target_device=device, + ) + elif len(self._quant_devices) <= 1: device = self._quant_devices[0] else: device = self._quant_devices[self._quant_device_rr % len(self._quant_devices)] @@ -658,6 +902,12 @@ def _assign_quant_device_for_module( device = fallback_device self._module_device_map[key] = device + emit_device_telemetry( + "quant_device_assign", + module=key, + target_device=device, + fallback_device=fallback_device, + ) return device def _apply_forward_device_overrides( @@ -667,6 +917,8 @@ def _apply_forward_device_overrides( *, fallback_modules: Optional[Dict[str, torch.nn.Module]] = None, ) -> Dict[str, torch.device]: + """Move selected modules to temporary forward devices and record prior placement.""" + previous: Dict[str, torch.device] = {} if not device_map: return previous @@ -692,6 +944,12 @@ def _apply_forward_device_overrides( if current is not None: previous[name] = current + emit_device_telemetry( + "forward_override_apply", + module=getattr(named_module, "full_name", name) if named_module is not None else name, + current_device=current, + target_device=target, + ) move_to(module_ref, device=target) rehome_module_to_device(module_ref, target, move_parameters=True, move_buffers=True) if isinstance(named_module, NamedModule): @@ -707,6 +965,8 @@ def _restore_forward_device_overrides( *, fallback_modules: Optional[Dict[str, torch.nn.Module]] = None, ) -> None: + """Restore module placements saved by ``_apply_forward_device_overrides``.""" + if not previous_devices: return @@ -720,6 +980,11 @@ def _restore_forward_device_overrides( module_ref = named_module.module if isinstance(named_module, NamedModule) else named_module if module_ref is None: continue + emit_device_telemetry( + "forward_override_restore", + module=getattr(named_module, "full_name", name) if named_module is not None else name, + target_device=revert_device, + ) move_to(module_ref, device=revert_device) rehome_module_to_device(module_ref, revert_device, move_parameters=True, move_buffers=True) if isinstance(named_module, NamedModule): @@ -732,6 +997,8 @@ def _rehome_processor_task( named_module: NamedModule, target_device: torch.device, ) -> None: + """Move processor-owned task state alongside the module it quantizes.""" + task_map = getattr(processor, "tasks", None) if not task_map: return @@ -740,6 +1007,10 @@ def _rehome_processor_task( if task is None: return + quant_source = named_module.state.get("quant_source_module") + if isinstance(quant_source, torch.nn.Module) and hasattr(task, "module"): + task.module = quant_source + to_device_fn = getattr(task, "to_device", None) if callable(to_device_fn): to_device_fn(target_device) @@ -773,22 +1044,69 @@ def _rehome_processor_task( if hasattr(task, "dev"): task.dev = target_device + def _prepare_named_module_for_forward( + self, + named_module: NamedModule, + fallback_device: torch.device, + ) -> torch.nn.Module: + """Prepare one named module for the forward role before replay starts.""" + + target_device = get_device(named_module.module) + if target_device == META: + target_device = fallback_device + + prepared = self.gptq_model.shell_module_materialize( + target_submodule=named_module.module, + device=target_device, + role="forward", + named_module=named_module, + ) + if prepared is not named_module.module: + named_module.module = prepared + + setattr(named_module, "target_device", target_device) + setattr(named_module.module, "target_device", target_device) + return prepared + def _prepare_named_module_for_quantization( self, processor: LoopProcessor, named_module: NamedModule, fallback_device: torch.device, ) -> torch.device: + """Place a named module and its processor task on the chosen device.""" + + try: + previous_device = get_device(named_module.module) + except Exception: + previous_device = None + target_device = self._assign_quant_device_for_module( named_module, fallback_device=fallback_device, ) - move_to(named_module.module, device=target_device) + if isinstance(named_module.state.get("quant_source_module"), torch.nn.Module): + prepared = self.gptq_model.shell_module_materialize( + target_submodule=named_module.module, + device=target_device, + role="quant_source", + named_module=named_module, + ) + if prepared is not named_module.module: + named_module.module = prepared + else: + move_to(named_module.module, device=target_device) rehome_module_to_device(named_module.module, target_device, move_parameters=True, move_buffers=True) setattr(named_module, "target_device", target_device) setattr(named_module.module, "target_device", target_device) + emit_device_telemetry( + "quant_prepare", + module=getattr(named_module, "full_name", named_module.name), + previous_device=previous_device, + target_device=target_device, + ) self._rehome_processor_task(processor, named_module, target_device) @@ -816,65 +1134,11 @@ def _run_forward_batches( progress_total_rows: Optional[int] = None, force_serial: bool = False, preserve_module_devices: bool = False, + apply_moe_config: bool = True, ) -> List[List[torch.Tensor]]: - """Dispatch the captured layer inputs through the module. - - When multiple accelerators of the same type are available we clone the - module and execute batches in parallel, otherwise we fall back to a - single threaded path. The helper returns the ordered outputs that feed - the next processor stage when ``need_outputs`` is set. - """ - if not force_serial: - quant_config = getattr(self.gptq_model, "quantize_config", None) - if quant_config is not None and not getattr(quant_config, "auto_forward_data_parallel", True): - force_serial = True - if force_serial: - return self._run_forward_batches_single( - module=module, - processor=processor, - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - shared_kv_cache_dict=shared_kv_cache_dict, - layer_index=layer_index, - need_outputs=need_outputs, - reuse_kv=reuse_kv, - progress_pb=progress_pb, - progress_title=progress_title, - progress_stage=progress_stage, - progress_rows_per_batch=progress_rows_per_batch, - progress_total_rows=progress_total_rows, - preserve_module_devices=preserve_module_devices, - ) - - devices = select_forward_devices(cur_layer_device) - - if len(devices) <= 1: - return self._run_forward_batches_single( - module=module, - processor=processor, - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - shared_kv_cache_dict=shared_kv_cache_dict, - layer_index=layer_index, - need_outputs=need_outputs, - reuse_kv=reuse_kv, - progress_pb=progress_pb, - progress_title=progress_title, - progress_stage=progress_stage, - progress_rows_per_batch=progress_rows_per_batch, - progress_total_rows=progress_total_rows, - preserve_module_devices=preserve_module_devices, - ) + """Run cached batches through the module using serial or parallel execution.""" - return self._run_forward_batches_parallel( + return self._forward_executor.run( module=module, processor=processor, layer_inputs=layer_inputs, @@ -887,12 +1151,15 @@ def _run_forward_batches( layer_index=layer_index, need_outputs=need_outputs, reuse_kv=reuse_kv, - devices=devices, progress_pb=progress_pb, progress_title=progress_title, progress_stage=progress_stage, progress_rows_per_batch=progress_rows_per_batch, progress_total_rows=progress_total_rows, + force_serial=force_serial, + preserve_module_devices=preserve_module_devices, + apply_moe_config=apply_moe_config, + select_forward_devices_fn=select_forward_devices, ) def _run_forward_batches_single( @@ -916,125 +1183,31 @@ def _run_forward_batches_single( progress_rows_per_batch: Optional[List[int]] = None, progress_total_rows: Optional[int] = None, preserve_module_devices: bool = False, + apply_moe_config: bool = True, ) -> List[List[torch.Tensor]]: - """Sequential fallback when only one forward device is in use.""" - outputs: List[List[torch.Tensor]] = [] - prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None - total_batches = self._resolve_batch_total(processor.num_batches, layer_inputs) - batch_row_counts = progress_rows_per_batch or self._collect_row_counts(layer_inputs) - batch_row_counts = list(batch_row_counts) - if len(batch_row_counts) > total_batches: - batch_row_counts = batch_row_counts[:total_batches] - elif len(batch_row_counts) < total_batches: - batch_row_counts.extend([0] * (total_batches - len(batch_row_counts))) - total_rows = progress_total_rows if progress_total_rows is not None else sum(batch_row_counts) - if total_rows <= 0 and total_batches > 0: - total_rows = total_batches - total_rows = max(total_rows, 1) - processed_rows = 0 - stage_label = progress_stage or "Forward" - - for batch_idx in range(total_batches): - processor._set_current_batch_index(batch_idx) - try: - exec_device = cur_layer_device - if preserve_module_devices: - module_target = getattr(module, "target_device", None) - if module_target is not None: - exec_device = module_target - - layer_input = [move_to(inp, device=exec_device) for inp in layer_inputs[batch_idx]] - - raw_mask = attention_masks[batch_idx] - attn_tensor = raw_mask if raw_mask is None else move_to(raw_mask, device=exec_device) - - keep_mask = None - if attn_tensor is not None: - seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None - keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len) - - # Set mask using TLS (thread-safe) - self._set_processor_mask(processor, keep_mask) - additional_inputs: Dict[str, Optional[torch.Tensor]] = {} - if self.support_batch_quantize and attn_tensor is not None: - additional_inputs["attention_mask"] = attn_tensor - else: - additional_inputs["attention_mask"] = None - - if position_ids: - pos = position_ids[batch_idx] - if pos is not None: - additional_inputs["position_ids"] = move_to(pos, device=exec_device) - - for key, value in layer_input_kwargs[batch_idx].items(): - # past_key_values will triggers the cache logic. we need disable cache when layer forward. - if key in ["past_key_values", "past_key_value"]: - continue - additional_inputs[key] = nested_move_to(value, device=exec_device) - - if reuse_kv and prev_kv is not None: - additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=exec_device) - - # TODO: some models does not honor generate config.use_cache property so we are forced to hack this to false - additional_inputs["use_cache"] = False - - if not preserve_module_devices: - rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True) - - # MoE lifecycle hooks integration - using context manager - with self.MoERoutingOverrideContext(module, self.moe_routing_override) if self.moe_routing_override else self.MoELifecycleContext(self, module, processor, self._current_subset) if self.moe_routing_bypass else nullcontext(): - module_output = None - try: - if is_lm_head_module: - module_output = module(*layer_input) - else: - module_output = module(*layer_input, **additional_inputs) - except StopForward: - module_output = None - finally: - self._set_processor_mask(processor, None) - - # Release intermediate tensors promptly after they are no longer needed - del layer_input - del attn_tensor - del keep_mask - del additional_inputs - - if ( - reuse_kv - and module_output is not None - and isinstance(module_output, tuple) - and len(module_output) > 0 - and shared_kv_cache_dict.get(layer_index) is None - ): - shared_kv_cache_dict[layer_index] = module_output[-1] - - if need_outputs and module_output is not None: - primary = module_output[0] if isinstance(module_output, tuple) else module_output - primary = move_to(primary, device=cur_layer_device) - outputs.append([primary]) - - # Release module_output promptly after extracting what we need - if module_output is not None: - del module_output - - rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0 - if rows_for_batch <= 0: - rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1 - rows_for_batch = max(rows_for_batch, 1) - - processed_rows = min(processed_rows + rows_for_batch, total_rows) - if progress_pb is not None: - if progress_title: - progress_pb.title(progress_title) - progress_pb.current_iter_step = processed_rows - progress_pb.subtitle( - f"{stage_label} rows {processed_rows}/{total_rows}" - ).draw() - finally: - processor._set_current_batch_index(None) + """Run cached batches on a single device and return ordered outputs when requested.""" - return outputs + return self._forward_executor.run_single( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + progress_pb=progress_pb, + progress_title=progress_title, + progress_stage=progress_stage, + progress_rows_per_batch=progress_rows_per_batch, + progress_total_rows=progress_total_rows, + preserve_module_devices=preserve_module_devices, + apply_moe_config=apply_moe_config, + ) def _run_forward_batches_parallel( self, @@ -1057,235 +1230,41 @@ def _run_forward_batches_parallel( progress_stage: Optional[str] = None, progress_rows_per_batch: Optional[List[int]] = None, progress_total_rows: Optional[int] = None, + apply_moe_config: bool = True, ) -> List[List[torch.Tensor]]: - """Fan batches across device clones and preserve result ordering.""" - effective_title = progress_title or (progress_stage or "Forward") - - total_batches = self._resolve_batch_total(processor.num_batches, layer_inputs) - batch_row_counts = progress_rows_per_batch or self._collect_row_counts(layer_inputs) - batch_row_counts = list(batch_row_counts) - if len(batch_row_counts) > total_batches: - batch_row_counts = batch_row_counts[:total_batches] - elif len(batch_row_counts) < total_batches: - batch_row_counts.extend([0] * (total_batches - len(batch_row_counts))) - total_rows = progress_total_rows if progress_total_rows is not None else sum(batch_row_counts) - if total_rows <= 0 and total_batches > 0: - total_rows = total_batches - total_rows = max(total_rows, 1) - stage_label = progress_stage or "Forward" - - replica_pb: "ProgressBar" | None = None - replica_title = "" - replica_completed = 0 - - if progress_pb is not None: - progress_pb.title(effective_title) - if len(devices) > 1: - replica_title = f"{stage_label}: replicate to {len(devices)} devices" - replica_pb = ( - log.pb(range(len(devices))) - .manual() - .set(show_left_steps=False) - ) - replica_pb.title(replica_title).subtitle("Staging module...").draw() - else: - device_label = str(devices[0]) if devices else "" - progress_pb.subtitle(f"{stage_label}: staging on {device_label}").draw() - - def _replica_progress(idx: int, total: int, device: torch.device, step: str) -> None: - nonlocal replica_completed - device_label = str(device) - if replica_pb is not None: - if step == "stage": - replica_pb.title(replica_title).subtitle(f"Stage {device_label}").draw() - return - if idx > replica_completed: - replica_completed = idx - replica_pb.title(replica_title).subtitle( - f"{device_label} {idx}/{total}" - ).next().draw() - else: - replica_pb.title(replica_title).subtitle( - f"{device_label} {idx}/{total}" - ).draw() - elif progress_pb is not None: - stage_msg = ( - f"{stage_label}: staging on {device_label}" - if step == "stage" - else f"{stage_label}: {step} {idx}/{total} on {device_label}" - ) - progress_pb.title(effective_title).subtitle(stage_msg).draw() + """Run cached batches across device replicas and preserve batch ordering in the result.""" - progress_cb = _replica_progress if progress_pb is not None else None - - # Ensure any async replication/memcpy ops are complete before threads start fanning out. - torch_sync() - - # Clone modules FIRST, then apply MoE lifecycle hooks to all replicas - try: - module_replicas = clone_module_for_devices( - module, - devices, - progress_callback=progress_cb, - ) - finally: - if replica_pb is not None: - replica_pb.close() - if progress_pb is not None: - progress_pb.title(effective_title).subtitle( - f"{stage_label} rows 0/{total_rows}" - ).draw() - - # Apply MoE lifecycle hooks to ALL replicas (not just the original module) - moe_contexts = [] - try: - for device, replica in module_replicas.items(): - # Create and activate context for each replica - ctx = None - if self.moe_routing_override: - ctx = self.MoERoutingOverrideContext(replica, self.moe_routing_override) - elif self._should_use_moe_lifecycle(module, processor): - ctx = self.MoELifecycleContext(self, replica, processor, self._current_subset) - - if ctx: - ctx.__enter__() - moe_contexts.append(ctx) - - prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None - - results: Dict[int, torch.Tensor | tuple | None] = {} - - processed_rows = 0 - - # Apply compute device filter if provided to determine which devices to use for forward execution - if self.gptq_model.quantize_config.compute_device_filter is not None: - forward_devices = self.gptq_model.quantize_config.compute_device_filter(devices) - if len(forward_devices) < 1: - log.warn( - "compute_device_filter returned empty device list. " - "Using all devices for forward execution." - ) - forward_devices = devices - else: - # If no filter is provided, use all devices (default behavior) - forward_devices = devices - - device_segments: Dict[torch.device, List[int]] = {} - segment_start = 0 - num_devices = len(forward_devices) - - for index, device in enumerate(forward_devices): - # Split the outstanding batches across forward_devices so that each accelerator - # receives a contiguous slice. - remaining_batches = max(total_batches - segment_start, 0) - remaining_devices = max(num_devices - index, 1) - segment_length = remaining_batches // remaining_devices - remainder = remaining_batches % remaining_devices - if remainder > 0: - segment_length += 1 - - if segment_length <= 0: - device_segments[device] = [] - continue - - segment_end = min(segment_start + segment_length, total_batches) - device_segments[device] = list(range(segment_start, segment_end)) - segment_start = segment_end - - max_segment_length = 0 - for indices in device_segments.values(): - if len(indices) > max_segment_length: - max_segment_length = len(indices) - - for position in range(max_segment_length): - # Submit one batch per device - futures = [] - for device in forward_devices: - segment_indices = device_segments.get(device, []) - if position >= len(segment_indices): - continue - batch_idx = segment_indices[position] - replica = module_replicas[device] - submitter = ( - DEVICE_THREAD_POOL.submit_serial - if device.type in ("cuda", "xpu", "mps") - else DEVICE_THREAD_POOL.submit - ) - - futures.append( - submitter( - device, - forward_batch_worker, - replica, - processor, - batch_idx, - layer_inputs[batch_idx], - layer_input_kwargs[batch_idx], - attention_masks[batch_idx], - position_ids[batch_idx] if position_ids else None, - support_batch_quantize=self.support_batch_quantize, - is_lm_head_module=is_lm_head_module, - need_output=need_outputs, - reuse_kv=reuse_kv, - prev_kv=prev_kv, - ) - ) - - for fut in futures: - # Preserve the original batch order - batch_idx, module_output, kv_next = fut.result() - if need_outputs and module_output is not None: - results[batch_idx] = module_output - if reuse_kv and kv_next is not None and shared_kv_cache_dict.get(layer_index) is None: - shared_kv_cache_dict[layer_index] = nested_move_to(kv_next, device=cur_layer_device) - - rows_for_batch = batch_row_counts[batch_idx] if batch_idx < len(batch_row_counts) else 0 - if rows_for_batch <= 0: - rows_for_batch = self._batch_row_count(layer_inputs[batch_idx]) if layer_inputs and batch_idx < len(layer_inputs) else 1 - rows_for_batch = max(rows_for_batch, 1) - - processed_rows = min(processed_rows + rows_for_batch, total_rows) - if progress_pb is not None: - if progress_title: - progress_pb.title(progress_title) - progress_pb.current_iter_step = processed_rows - progress_pb.subtitle( - f"{stage_label} rows {processed_rows}/{total_rows}" - ).draw() - finally: - # Clean up MoE lifecycle hooks from all replicas - for ctx in moe_contexts: - try: - ctx.__exit__(None, None, None) - except Exception: - pass - moe_contexts.clear() - - # ensure replicas release promptly and free GPU memory - for dev in list(module_replicas.keys()): - del module_replicas[dev] - - if not need_outputs: - return [] - - ordered_outputs: List[List[torch.Tensor]] = [] - for idx in range(total_batches): - # Rebuild the ordered list of batch outputs expected by the next - # stage. - module_output = results.get(idx) - if module_output is None: - raise RuntimeError("Forward batch returned no output; data-parallel execution produced empty result.") - if isinstance(module_output, tuple): - primary = module_output[0] - else: - primary = module_output - primary = move_to(primary, device=cur_layer_device) - ordered_outputs.append([primary]) - - return ordered_outputs + return self._forward_executor.run_parallel( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + devices=devices, + progress_pb=progress_pb, + progress_title=progress_title, + progress_stage=progress_stage, + progress_rows_per_batch=progress_rows_per_batch, + progress_total_rows=progress_total_rows, + apply_moe_config=apply_moe_config, + clone_module_for_devices_fn=clone_module_for_devices, + forward_batch_worker_fn=forward_batch_worker, + device_thread_pool=DEVICE_THREAD_POOL, + ) def _masked_hook_wrapper(self, processor: LoopProcessor, inner_hook, hook_source: str): + """Wrap a forward hook so it sees masked activations for the current batch.""" + def hook(module, inputs, output): + """Apply the thread-local keep mask before delegating to ``inner_hook``.""" + # Thread-safe check if hooks are paused (TLS-based, per-thread) if self._get_processor_hooks_paused(processor): return @@ -1343,6 +1322,8 @@ def _masked_pre_hook_wrapper(self, processor: LoopProcessor, inner_hook, hook_so Respects hooks_paused state to avoid double-counting during intermediate calculations. """ def pre_hook(module, inputs, output): + """Apply the current keep mask before invoking the wrapped pre-hook.""" + # Thread-safe check if hooks are paused (TLS-based, per-thread) if self._get_processor_hooks_paused(processor): return @@ -1382,6 +1363,8 @@ def pre_hook(module, inputs, output): return pre_hook def cache_inputs(self, layers, calibration_data, use_cache): + """Capture and cache per-layer calibration inputs for later replay.""" + capture_stage = StageInputsCapture(self, logger=log) return capture_stage.cache_inputs( layers=layers, @@ -1389,15 +1372,18 @@ def cache_inputs(self, layers, calibration_data, use_cache): use_cache=use_cache, ) - def loop(self, failsafe=None, **kwargs): + def loop(self, fallback=None, **kwargs): + """Run the quantization loop under the TF32 guard.""" + with tf32_high_precision_guard(): - with self.pause_controller.lifecycle(): - return self._loop_impl(failsafe=failsafe, **kwargs) + return self._loop_impl(fallback=fallback, **kwargs) @torch.inference_mode() - def _loop_impl(self, failsafe=None, **kwargs): - if failsafe is None: - failsafe = getattr(self.gptq_model.quantize_config, "failsafe", None) + def _loop_impl(self, fallback=None, **kwargs): + """Execute the full layer-by-layer quantization workflow.""" + + if fallback is None: + fallback = getattr(self.gptq_model.quantize_config, "fallback", None) if self.gptq_model.quantize_config.lm_head: if self.gptq_model.model.config.tie_word_embeddings and hasattr(self.gptq_model.model.model, "_tied_weights_keys"): @@ -1428,7 +1414,8 @@ def _loop_impl(self, failsafe=None, **kwargs): for p_index, processor in enumerate(self.processors): if not processor.verify_calibration_dataset(p_index): if isinstance(processor, EoraProcessor) or\ - (isinstance(processor, GPTQProcessor) and self.gptq_model.quantize_config.gptaq is not None): + (isinstance(processor, GPTQProcessor) and getattr(self.gptq_model.quantize_config, "gptaq", None) is not None) or\ + (isinstance(processor, GPTQProcessor) and getattr(self.gptq_model.quantize_config, "foem", None) is not None): prev_processor = self.processors[p_index - 1] processor.set_calibration_dataset(prev_processor.calibration_dataset) # If calibration_dataset is None or Empty, the input_cache of the previous processor is used. @@ -1457,12 +1444,21 @@ def _loop_impl(self, failsafe=None, **kwargs): disk_path=self.gptq_model.quantize_config.offload_to_disk_path ) + for processor in self.processors: + # Pre-build ParoQuant's optional fused rotation extension before the + # first timed layer so layer 0 does not absorb a one-time JIT cost. + if isinstance(processor, ParoQuantProcessor): + processor.prewarm_runtime() + if region_timer is not None: region_timer.flush() - is_awq_quantize = any(isinstance(proc, AWQProcessor) for proc in self.processors) + is_awq_quantize = any(isinstance(proc, (AWQProcessor, ParoQuantProcessor)) for proc in self.processors) + # Capture-only layer groups are driven by processor execution config, + # not by ad-hoc processor attributes. requires_activation_capture = any( - getattr(proc, "enable_activation_capture", False) for proc in self.processors + getattr(getattr(proc, "execution_config", None), "enable_activation_capture", False) + for proc in self.processors ) layer_modules = self.gptq_model.simple_layer_modules( model_config=self.gptq_model.model.config, @@ -1470,6 +1466,11 @@ def _loop_impl(self, failsafe=None, **kwargs): is_awq_quantize=is_awq_quantize, include_capture_only=requires_activation_capture, ) + planning_layer_modules = self.gptq_model.full_layer_modules( + model_config=self.gptq_model.model.config, + is_awq_quantize=is_awq_quantize, + include_capture_only=requires_activation_capture, + ) # true-sequential will replay the quantized activations after each subset has been quantized to be used for next subset quantization # this should always be true for gptq unless you want lower but misleading error_loss that is misleading and will lead to lower post-quantized model @@ -1501,8 +1502,9 @@ def _loop_impl(self, failsafe=None, **kwargs): self, layers=layers, layer_modules=layer_modules, + planning_layer_modules=planning_layer_modules, layers_prefix=layers_prefix, - failsafe=failsafe, + fallback=fallback, shared_kv_cache_dict=shared_kv_cache_dict, pb=pb, layer_count=layer_count, @@ -1555,7 +1557,7 @@ def _loop_impl(self, failsafe=None, **kwargs): processor_name = reverse_p.name() total_log[processor_name] = reverse_p.log - if processor_name in ["gptq", "gptq v2"]: + if processor_name in ["gptq", "gptq v2", "awq"]: self.gptq_model.quant_log = reverse_p.log for module_log in reverse_p.log: @@ -1586,7 +1588,9 @@ def _loop_impl(self, failsafe=None, **kwargs): return total_log - def create_named_modules(self, module, full, is_lm_head_module, layer_index, layers_prefix, names, processor, failsafe, layer_module=None) -> Dict[str, NamedModule]: + def create_named_modules(self, module, full, is_lm_head_module, layer_index, layers_prefix, names, processor, fallback, layer_module=None) -> Dict[str, NamedModule]: + """Build the named-module subset a processor will quantize for one layer.""" + subset = {} capture_only_flags: Dict[str, bool] = {} for n in names: @@ -1628,7 +1632,7 @@ def create_named_modules(self, module, full, is_lm_head_module, layer_index, lay subset[name].state["capture_only"] = True if isinstance(processor, GPTQProcessor): - processor.preprocess(subset[name], failsafe=failsafe) + processor.preprocess(subset[name], fallback=fallback) else: processor.preprocess(subset[name]) # some modules are skipped diff --git a/gptqmodel/looper/module_preprocessor.py b/gptqmodel/looper/module_preprocessor.py new file mode 100644 index 000000000..11796443f --- /dev/null +++ b/gptqmodel/looper/module_preprocessor.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math +from typing import Any, Dict, List, Optional + +import torch +import transformers + +from ..looper.loop_processor import ExecutionConfig, LoopProcessor +from ..looper.named_module import NamedModule +from ..quantization.config import AutoModuleDecoderConfig, SmootherConfig, TensorParallelPadderConfig + + +def _get_number_of_rows_and_cols(layer: torch.nn.Module) -> tuple[int, int]: + if isinstance(layer, NamedModule): + layer = layer.module + + if isinstance(layer, transformers.Conv1D): + return layer.weight.shape[1], layer.weight.shape[0] + + return layer.weight.shape[0], math.prod(layer.weight.shape[1:]) + + +class ModulePreProcessor(LoopProcessor): + """Annotate modules with an ordered preprocessor plan before quantization.""" + + _TP_TARGETS = (2, 4, 8) + + def __init__(self, *args, **kwargs): + """Initialize a no-forward planning processor for module preprocessors.""" + + kwargs = dict(kwargs) + kwargs.pop("calculate_w_wq_diff", None) + qcfg = kwargs.pop("qcfg") + tokenizer = kwargs.pop("tokenizer", None) + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=None, + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + calibration_concat_separator=None, + batch_size=1, + execution_config=ExecutionConfig( + require_fwd=False, + fwd_replay_after_process=False, + ), + ) + self.qcfg = qcfg + + def preprocess(self, module: NamedModule, **kwargs): + """Normalize configured preprocessors into a stable per-module plan.""" + + del kwargs + + pipeline: List[Dict[str, Any]] = [] + auto_module_decoder_plan = None + tp_pad_info = None + for preprocessor in getattr(self.qcfg, "preprocessors", []) or []: + if isinstance(preprocessor, AutoModuleDecoderConfig): + auto_module_decoder_plan = { + "code": preprocessor.code, + "source_dtype": preprocessor.source_dtype, + "target_dtype": preprocessor.target_dtype, + } + pipeline.append(auto_module_decoder_plan) + continue + if isinstance(preprocessor, TensorParallelPadderConfig): + tp_pad_info = self._compute_tp_pad_info(module) + pipeline.append( + { + "code": preprocessor.code, + **tp_pad_info, + } + ) + continue + if isinstance(preprocessor, SmootherConfig): + pipeline.append(preprocessor.to_dict()) + + if pipeline: + module.state["preprocessor_pipeline"] = pipeline + else: + module.state.pop("preprocessor_pipeline", None) + + if auto_module_decoder_plan is not None: + module.state["auto_module_decoder"] = auto_module_decoder_plan + else: + module.state.pop("auto_module_decoder", None) + module.state.pop("quant_source_module", None) + module.state.pop("auto_module_decoder_forward_mode", None) + module.state.pop("_auto_module_decoder_event_recorded", None) + + if tp_pad_info is not None and tp_pad_info["pad_cols"] > 0: + module.state["tp_pad_info"] = tp_pad_info + else: + module.state.pop("tp_pad_info", None) + + def is_skipped(self, module: NamedModule) -> bool: + """Report that every candidate module passes through preprocessor planning.""" + + del module + return False + + def pre_process_fwd_hook(self, name: str): + """Return a no-op hook because planning does not inspect activations.""" + + del name + + def _noop(module, inputs, output): + del module, inputs, output + return None + + return _noop + + def process( + self, + module: NamedModule, + device: torch.device = None, + subset: Optional[Dict[str, NamedModule]] = None, + previous_subset: Optional[Dict[str, NamedModule]] = None, + subset_index: Optional[int] = None, + subset_total: Optional[int] = None, + ): + """Keep the planning stage side-effect free after preprocess.""" + + del module, device, subset, previous_subset, subset_index, subset_total + + def verify_calibration_dataset(self, processor_index: int) -> bool: + """Report that this planning stage does not require calibration inputs.""" + + del processor_index + return False + + def name(self) -> str: + """Return the processor label used in logs and lifecycle reporting.""" + + return "module-preprocessor" + + def _compute_tp_pad_info(self, module: NamedModule) -> Dict[str, int]: + """Calculate tensor-parallel padding metadata for one module.""" + + target_multiple = math.lcm(*self._TP_TARGETS) + group_size = getattr(self.qcfg, "group_size", -1) + if group_size > 0: + target_multiple = math.lcm(target_multiple, group_size) + + _, columns = _get_number_of_rows_and_cols(module) + pad_cols = (target_multiple - (columns % target_multiple)) % target_multiple + return { + "pad_cols": pad_cols, + "target_multiple": target_multiple, + "original_columns": columns, + } + +__all__ = ["ModulePreProcessor"] diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py index 42a74fd40..fea761a97 100644 --- a/gptqmodel/looper/named_module.py +++ b/gptqmodel/looper/named_module.py @@ -20,8 +20,11 @@ log = setup_logger() class NamedModule(torch.nn.Module): + """Thread-safe wrapper that adds stable names and scratch state to a module.""" def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_index: int) -> None: + """Wraps a submodule with naming metadata used by the looper pipeline.""" + super().__init__() self.module = module # wrapped module @@ -31,6 +34,9 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde self.layer_index = layer_index # layer index for repeated blocks self._parent_lock = get_parent_lock(full_name) + if hasattr(module, "module_name") and module.module_name is None: + module.module_name = full_name + # persistent work state for named module (used by some LoopProcessors) # store all `processed()` work state/data/result here self.state = {} @@ -64,30 +70,44 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde }) def parameters(self, recurse: bool = True): + """Delegates parameter iteration to the wrapped module.""" + return self.module.parameters(recurse=recurse) def named_parameters(self, prefix: str = "", recurse: bool = True): + """Delegates named parameter iteration to the wrapped module.""" + return self.module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True): + """Delegates buffer iteration to the wrapped module.""" + return self.module.buffers(recurse=recurse) def named_buffers(self, prefix: str = "", recurse: bool = True): + """Delegates named buffer iteration to the wrapped module.""" + return self.module.named_buffers(prefix=prefix, recurse=recurse) def register_forward_hook( self, *args, **kwargs ): + """Registers a forward hook while holding the parent module lock.""" + with self._parent_lock: return self.module.register_forward_hook(*args, **kwargs) def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: + """Registers a buffer on the wrapped module under the parent lock.""" + with self._parent_lock: return self.module.register_buffer(name, tensor, persistent) def unregister_buffer(self, name: str): + """Removes a buffer from the wrapped module if it exists.""" + with self._parent_lock: if name in self.module._buffers: del self.module._buffers[name] @@ -95,10 +115,14 @@ def unregister_buffer(self, name: str): delattr(self.module, name) def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + """Registers a parameter on the wrapped module under the parent lock.""" + with self._parent_lock: return self.module.register_parameter(name, param) def unregister_parameter(self, name: str) -> None: + """Removes a parameter from the wrapped module if it exists.""" + with self._parent_lock: if name in self.module._parameters: del self.module._parameters[name] @@ -116,6 +140,8 @@ def unregister_parameter(self, name: str) -> None: # getattr is only called if python cannot find attr for `self` def __getattr__(self, name: str): + """Falls back to the wrapped module while preserving lock discipline.""" + try: lock = object.__getattribute__(self, "_parent_lock") except AttributeError: @@ -128,6 +154,8 @@ def __getattr__(self, name: str): # setattr is always called by python even if attr exists in `self` def __setattr__(self, name: str, value: Any) -> None: + """Routes non-wrapper attributes to the wrapped module under lock.""" + if name in [ "module", "module_dtype", @@ -160,6 +188,8 @@ def stream_state_payload_to_cpu( self, tensors: Dict[str, torch.Tensor], ) -> Dict[str, torch.Tensor]: + """Streams arbitrary state tensors to CPU and records them in `state`.""" + state_lock = self._parent_lock return stream_tensor_dict_to_cpu( tensors, @@ -169,6 +199,8 @@ def stream_state_payload_to_cpu( ) def stream_parameters_to_cpu(self) -> Dict[str, torch.Tensor]: + """Streams this module's direct parameters to CPU-backed state storage.""" + state_lock = self._parent_lock tensor_map = {name: param for name, param in self.module.named_parameters(recurse=False)} return stream_tensor_dict_to_cpu( @@ -179,6 +211,8 @@ def stream_parameters_to_cpu(self) -> Dict[str, torch.Tensor]: ) def stream_buffers_to_cpu(self) -> Dict[str, torch.Tensor]: + """Streams this module's direct buffers to CPU-backed state storage.""" + state_lock = self._parent_lock tensor_map = {name: buf for name, buf in self.module.named_buffers(recurse=False)} return stream_tensor_dict_to_cpu( @@ -189,9 +223,13 @@ def stream_buffers_to_cpu(self) -> Dict[str, torch.Tensor]: ) def stream_all_to_cpu(self) -> Dict[str, Dict[str, torch.Tensor]]: + """Streams direct parameters and buffers to CPU in one call.""" + params = self.stream_parameters_to_cpu() buffers = self.stream_buffers_to_cpu() return {"parameters": params, "buffers": buffers} def stream_sync(self) -> None: + """Waits for any outstanding asynchronous stream-to-CPU transfers.""" + stream_sync_events(self.state, self._parent_lock) diff --git a/gptqmodel/looper/native_processor.py b/gptqmodel/looper/native_processor.py index 726afc57d..2b5ac5f61 100644 --- a/gptqmodel/looper/native_processor.py +++ b/gptqmodel/looper/native_processor.py @@ -8,7 +8,7 @@ import torch from torch.nn import Module -from ..looper.loop_processor import LoopProcessor +from ..looper.loop_processor import ExecutionConfig, LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel from ..quantization.config import QuantizeConfig @@ -21,6 +21,8 @@ # v2 requires that we also need to capture/store non-quantized inputs class NativeProcessor(LoopProcessor): + """Caches raw module inputs for later native or GPTAQ-style processing.""" + def __init__( self, tokenizer, @@ -33,6 +35,7 @@ def __init__( require_fwd: bool = True, calibration_concat_separator: Optional[str] = None, ): + """Initializes the processor and enables single-pass input capture.""" super().__init__( tokenizer=tokenizer, @@ -43,29 +46,46 @@ def __init__( calibration_concat_separator=calibration_concat_separator, prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd, - fwd_after_process=False, - fwd_all_modules_in_single_pass=True, + execution_config=ExecutionConfig( + require_fwd=require_fwd, + fwd_replay_after_process=False, + fwd_all_modules_in_single_pass=True, + ), ) self.native_inp_caches = {} def set_calibration_dataset(self, calibration_dataset): + """Rejects dataset replacement because capture setup is fixed at construction.""" + raise NotImplementedError("NativeProcessor's calibration_dataset cannot be modified") def preprocess(self, module: NamedModule): + """Allocates the per-module cache used by the forward hook.""" + self.native_inp_caches[module.name] = [] def is_skipped(self, module: NamedModule) -> bool: + """Reports that native input capture currently runs for every eligible module.""" + # TODO: Add skipping certain modules return False def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Builds the forward hook that captures detached native inputs for a module.""" + def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): + """Copies the module input to the configured GPTAQ staging device.""" + # gptq is mutable. inp = inp[0].detach() - gptaq_device = self.qcfg.gptaq.device if self.qcfg.gptaq is not None else "auto" + if self.qcfg.gptaq is not None: + gptaq_device = self.qcfg.gptaq.device + elif self.qcfg.foem is not None: + gptaq_device = self.qcfg.foem.device + else: + gptaq_device = "auto" if gptaq_device == "auto": target_device = DEVICE_1 elif gptaq_device == "cpu": @@ -92,19 +112,29 @@ def process( subset_index: Optional[int] = None, subset_total: Optional[int] = None, ): + """Moves the captured input list into module state for downstream use.""" + module.state[NATIVE_INPUTS_STATE_KEY] = self.native_inp_caches.pop(module.name) def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Clears cached native inputs after module finalization.""" + module.state.pop(NATIVE_INPUTS_STATE_KEY, None) def finalize(self, model: BaseQModel, **kwargs): + """Releases processor-level caches once the full loop has completed.""" + del self.native_inp_caches def verify_calibration_dataset(self, processor_index: int) -> bool: + """Ensures a calibration dataset was provided before running capture.""" + if self.calibration_dataset is None: raise ValueError("NativeProcessor's calibration_dataset must be provided.") else: return True def name(self) -> str: + """Returns the processor label used in logs and lifecycle reporting.""" + return "native" diff --git a/gptqmodel/looper/paroquant_processor.py b/gptqmodel/looper/paroquant_processor.py new file mode 100644 index 000000000..deab5b881 --- /dev/null +++ b/gptqmodel/looper/paroquant_processor.py @@ -0,0 +1,2638 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# ParoQuant processor implementation adapted from the ParoQuant paper and public +# project: +# https://arxiv.org/html/2511.10645v2 +# https://github.com/z-lab/paroquant + +"""ParoQuant looper integration. + +This processor keeps ParoQuant separate from the AWQ lifecycle: +1. capture calibration activations for each module +2. run ParoQuant's transformed-domain optimization per layer +3. export packed runtime tensors plus learned rotation state +4. replace the float modules with ParoQuant runtime kernels +""" + +from __future__ import annotations + +import copy +import hashlib +import inspect +import math +import threading +import time +from contextlib import nullcontext +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn import Module +from torch.utils.checkpoint import checkpoint as torch_checkpoint + +from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor +from ..looper.named_module import NamedModule +from ..models import BaseQModel +from ..models.writer import ( + PROCESS_LOG_FWD_TIME, + PROCESS_LOG_LAYER, + PROCESS_LOG_MODULE, + PROCESS_LOG_NAME, + PROCESS_LOG_TIME, + PROCESS_USED_MEMORY, + QUANT_LOG_LOSS, + QUANT_LOG_NSAMPLES, + QUANT_LOG_DAMP, +) +from ..nn_modules.hooked_linear import HookedLinear +from ..nn_modules.qlinear.paroquant import ParoLinear +from ..quantization.config import FORMAT, METHOD, QuantizeConfig, resolve_quant_format +from ..quantization.paroquant.optimization import ( + _ParoQuantOptimLinear, + _activate_stage_params, + _normalize_group_size, + _normalize_opt_impl, + _normalize_opt_optimizer, + _normalize_quantizer_impl, + _quantizer_sym_for_impl, + _resolve_best_state_snapshot_dtype, + _result_from_model, + build_paroquant_optimizer, + build_random_rotation_buffers, + build_random_rotation_buffers_reference, + optimize_paroquant_linear, +) +from ..utils.fallback import normalize_fallback +from ..utils.logger import log_time_block, setup_logger +from ..utils.model import ( + create_quant_module, + find_modules, + get_module_by_name_prefix, + move_to, + nested_move_to, + pack_module, + recurse_getattr, + recurse_setattr, +) +from ..utils.module_locks import parent_module_lock +from ..utils.paroquant import prewarm_paroquant_rotation_extension +from ..utils.torch import CPU, torch_empty_cache + +log = setup_logger() + + +@dataclass +class _ParoQuantLayerState: + """Per-layer bookkeeping for activation capture and deferred quantization.""" + + modules: Dict[str, NamedModule] = field(default_factory=dict) + layer_module: Optional[torch.nn.Module] = None + pristine_layer_module: Optional[torch.nn.Module] = None + prepared_group_source_module: Optional[torch.nn.Module] = None + prepared_group_source_module_by_device: Optional[Dict[str, torch.nn.Module]] = None + layer_inputs: Optional[List[List[torch.Tensor]]] = None + layer_input_kwargs: Optional[List[Dict[str, torch.Tensor]]] = None + layer_outputs: Optional[List[List[torch.Tensor]]] = None + grouped_dataset: Optional[Any] = None + grouped_dataset_by_device: Optional[Dict[str, Any]] = None + replay_batches: Optional[Any] = None + subset_total: Optional[int] = None + processed_subsets: Set[int] = field(default_factory=set) + pending_modules: Set[str] = field(default_factory=set) + quantized: bool = False + lock: threading.Lock = field(default_factory=threading.Lock) + + +@dataclass +class _ParoQuantReplayBatch: + """One CPU-owned layer replay batch used by streamed grouped optimization.""" + + inputs: List[torch.Tensor] + input_kwargs: Dict[str, Any] + target: torch.Tensor + position_ids: Optional[torch.Tensor] + attention_mask: Optional[torch.Tensor] + row_count: int + + +def _value_has_inference_tensor(value: Any) -> bool: + """Detect nested inference-mode tensors so caches can rebuild autograd-safe values.""" + if isinstance(value, torch.Tensor): + return value.is_inference() + if isinstance(value, dict): + return any(_value_has_inference_tensor(inner) for inner in value.values()) + if isinstance(value, (list, tuple)): + return any(_value_has_inference_tensor(inner) for inner in value) + return False + + +class _LayerShardLoader: + """Stream replay batches from CPU to one device shard at a time.""" + + def __init__( + self, + batches: list[_ParoQuantReplayBatch], + *, + target_device: torch.device, + shard_batches: int, + metadata_cache: Optional[dict[tuple[int, str], torch.Tensor]] = None, + ) -> None: + self.batches = batches + self.target_device = torch.device(target_device) + self.shard_batches = max(1, int(shard_batches)) + self.metadata_cache = metadata_cache + + @staticmethod + def _tensor_to_device(value: torch.Tensor, device: torch.device) -> torch.Tensor: + non_blocking = value.device.type == CPU.type and value.is_pinned() and device.type == "cuda" + if value.is_inference(): + # Replay tensors may be captured under worker inference-mode; moving + # them with copy=True recreates normal tensors that autograd can use. + with torch.inference_mode(False): + return value.to(device=device, non_blocking=non_blocking, copy=True) + if value.device == device: + return value + return value.to(device=device, non_blocking=non_blocking) + + def _metadata_tensor_to_device(self, value: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + if value is None: + return None + if value.device == self.target_device and not value.is_inference(): + return value + cache = self.metadata_cache + if cache is None: + return self._tensor_to_device(value, self.target_device) + cache_key = (id(value), str(self.target_device)) + cached = cache.get(cache_key) + if cached is None: + cached = self._tensor_to_device(value, self.target_device) + cache[cache_key] = cached + return cached + + @staticmethod + def _move_value_to_device(value: Any, device: torch.device) -> Any: + if isinstance(value, torch.Tensor): + return _LayerShardLoader._tensor_to_device(value, device) + if isinstance(value, dict): + return {key: _LayerShardLoader._move_value_to_device(inner, device) for key, inner in value.items()} + if isinstance(value, list): + return [_LayerShardLoader._move_value_to_device(inner, device) for inner in value] + if isinstance(value, tuple): + return tuple(_LayerShardLoader._move_value_to_device(inner, device) for inner in value) + return value + + def _materialize_batch(self, batch: _ParoQuantReplayBatch) -> _ParoQuantReplayBatch: + return _ParoQuantReplayBatch( + inputs=[self._tensor_to_device(tensor, self.target_device) for tensor in batch.inputs], + input_kwargs={ + key: self._move_value_to_device(value, self.target_device) + for key, value in batch.input_kwargs.items() + }, + target=self._tensor_to_device(batch.target, self.target_device), + position_ids=self._metadata_tensor_to_device(batch.position_ids), + attention_mask=self._metadata_tensor_to_device(batch.attention_mask), + row_count=batch.row_count, + ) + + def iter_shards(self) -> Iterator[list[_ParoQuantReplayBatch]]: + for start in range(0, len(self.batches), self.shard_batches): + end = min(len(self.batches), start + self.shard_batches) + shard = [self._materialize_batch(batch) for batch in self.batches[start:end]] + yield shard + +class ParoQuantProcessor(LoopProcessor): + """Standalone ParoQuant lifecycle: capture, optimize, export, then pack.""" + + def __init__( + self, + tokenizer, + qcfg: QuantizeConfig, + calibration, + prepare_dataset_func, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + gptq_model, + model, + require_fwd: bool = True, + calculate_w_wq_diff: bool = False, + calibration_concat_separator: Optional[str] = None, + ): + """Configure a looper that captures activations and quantizes after replay.""" + capture_group_layer_context = str(getattr(qcfg, "opt_scope", "module")).strip().lower() != "module" + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, + prepare_dataset_func=prepare_dataset_func, + batch_size=batch_size, + execution_config=ExecutionConfig( + require_fwd=require_fwd, + fwd_replay_after_process=True, + subset_forward_early_stop=True, + enable_activation_capture=True, + capture_layer_forward_context=capture_group_layer_context, + ), + ) + + self.calculate_w_wq_diff = calculate_w_wq_diff + self.avg_losses: list[float] = [] + self.gptq_model = gptq_model + self.model = model + self.format = resolve_quant_format(qcfg.format, qcfg.method) + self.qlinear_kernel = self._select_qlinear_kernel_for_format(self.format) + self._layer_states: Dict[int, _ParoQuantLayerState] = {} + self._layer_states_lock = threading.Lock() + self._rotary_lock = threading.Lock() + self._rotary_cache: Dict[str, nn.Module] = {} + self._rotary_source_id: Optional[int] = None + self._clean_group_layer_inputs: Optional[List[List[torch.Tensor]]] = None + self._runtime_prewarmed = False + self.fallback = qcfg.fallback + + def set_calibration_dataset(self, calibration_dataset): + """Reject runtime dataset swaps because capture state is tied to the processor.""" + raise NotImplementedError("ParoQuantProcessor's calibration_dataset cannot be modified") + + def receive_input_cache(self, input_cache): + """Seed grouped calibration with the clean first-layer input stream.""" + super().receive_input_cache(input_cache) + self._clean_group_layer_inputs = input_cache.layer_inputs if self._train_on_noisy_inputs_enabled() else None + + def _select_qlinear_kernel_for_format(self, format_value: FORMAT): + """Resolve the only supported runtime kernel class for ParoQuant.""" + fmt = FORMAT(format_value) if not isinstance(format_value, FORMAT) else format_value + if fmt != FORMAT.PAROQUANT: + raise ValueError(f"METHOD.PARO does not support this FORMAT: {format_value}") + return ParoLinear + + def prewarm_runtime(self) -> None: + """Build optional fused ParoQuant runtime pieces before timed layer work starts.""" + if getattr(self, "_runtime_prewarmed", False): + return + + self._runtime_prewarmed = True + fused_rotation = bool(getattr(self.qcfg, "opt_fused_rotation", False)) + group_size = int(getattr(self.qcfg, "group_size", 128)) + krot = int(getattr(self.qcfg, "krot", 8)) + if fused_rotation and group_size in {128} and krot in {1, 8}: + log.info("ParoQuant: prewarming fused rotation extension...") + if not prewarm_paroquant_rotation_extension( + fused_rotation=fused_rotation, + group_size=group_size, + krot=krot, + ): + return + log.info("ParoQuant: prewarmed fused rotation extension.") + + def _resolve_qlinear_kernel(self, module_name: Optional[str] = None): + """Resolve per-module dynamic overrides while enforcing ParoQuant format.""" + format_override = self.qcfg.dynamic_get(module_name, "format", None) if module_name else None + target_format = resolve_quant_format(format_override or self.qcfg.format, self.qcfg.method) + if target_format != FORMAT.PAROQUANT: + raise ValueError(f"METHOD.PARO does not support dynamic format override `{target_format}`.") + return ParoLinear + + def _get_layer_state(self, layer_index: int) -> _ParoQuantLayerState: + """Fetch or create the shared state bucket for one transformer layer.""" + with self._layer_states_lock: + state = self._layer_states.get(layer_index) + if state is None: + state = _ParoQuantLayerState() + self._layer_states[layer_index] = state + return state + + def _record_input_feature(self, module_name: str, feature: torch.Tensor) -> None: + """Store one batch of calibration activations for a named module.""" + if feature.dim() <= 2: + feature = feature.unsqueeze(0) + + if feature.device.type != "cpu": + feature = feature.detach().cpu() + else: + feature = feature.detach() + + with self.lock: + entry = self.tasks.get(module_name) + if entry is None: + entry = {"inputs": []} + self.tasks[module_name] = entry + entry.setdefault("inputs", []).append(feature) + + def _ensure_task_bucket(self, module_name: str, layer_index: int) -> None: + """Reset repeated relative module names when quantization advances to a new layer.""" + with self.lock: + entry = self.tasks.get(module_name) + if entry is None or entry.get("layer_index") != layer_index: + self.tasks[module_name] = { + "inputs": [], + "layer_index": layer_index, + } + return + entry.setdefault("inputs", []) + + def _layer_input_features(self, state: _ParoQuantLayerState) -> Dict[str, torch.Tensor]: + """Materialize concatenated calibration features for all modules in a layer.""" + features: Dict[str, torch.Tensor] = {} + for name in list(state.modules): + entry = self.tasks.get(name) or {} + tensors: List[torch.Tensor] = entry.get("inputs", []) # type: ignore[arg-type] + if not tensors: + features[name] = torch.empty(0) + continue + try: + features[name] = torch.cat(tensors, dim=0) + entry["inputs"] = [features[name]] + except RuntimeError: + features[name] = tensors[0] + return features + + def _module_quant_params(self, module_name: str) -> tuple[int, int, bool]: + """Read effective bit-width, group size, and symmetry for one module.""" + bits = int(self.qcfg.dynamic_get(module_name, "bits", self.qcfg.runtime_bits)) + group_size = int(self.qcfg.dynamic_get(module_name, "group_size", self.qcfg.group_size)) + sym = bool(self.qcfg.dynamic_get(module_name, "sym", self.qcfg.sym)) + return bits, group_size, sym + + @staticmethod + def _module_weight_matrix(module: NamedModule) -> torch.Tensor: + """Return the 2D weight matrix expected by the ParoQuant optimizer.""" + weight = module.weight.data + if weight.dim() != 2: + raise ValueError( + f"ParoQuant currently expects rank-2 module weights, got {tuple(weight.shape)} for `{module.full_name}`." + ) + return weight + + def _apply_optimization_result(self, module: NamedModule, result, original_weight: torch.Tensor) -> None: + """Store one optimization result into the wrapped module and its scratch state.""" + weight = self._module_weight_matrix(module) + pseudo_weight = result.pseudo_weight.to(device=weight.device, dtype=weight.dtype) + pack_weight = result.pack_weight.to(dtype=weight.dtype).cpu() + q_scales = result.q_scales.to(dtype=weight.dtype).cpu() + q_zeros = result.q_zeros.cpu() + pairs = result.pairs.to(dtype=torch.int16).cpu() + theta = result.theta.to(dtype=weight.dtype).cpu() + channel_scales = result.channel_scales.to(dtype=weight.dtype).cpu() + + with self.lock: + module.state.update( + { + "pack_weight": pack_weight, + "q_scales": q_scales, + "q_zeros": q_zeros, + "pairs": pairs, + "theta": theta, + "channel_scales": channel_scales, + } + ) + + if self.calculate_w_wq_diff: + if original_weight.dtype == torch.float16: + w_wq_diff = original_weight - pseudo_weight + else: + w_wq_diff = original_weight.to(dtype=torch.float32) - pseudo_weight.to(dtype=torch.float32) + with self.lock: + module.state["w_wq_diff"] = w_wq_diff + + module.weight.data = pseudo_weight + + def _log_quant_result(self, module: NamedModule, feat: torch.Tensor, val_loss: float, duration: float) -> None: + """Append one quantization log row using the same format as other processors.""" + n_samples = 0 if feat.numel() == 0 else feat.reshape(-1, feat.shape[-1]).shape[0] + + stat = { + PROCESS_LOG_NAME: self.name(), + PROCESS_LOG_LAYER: module.layer_index, + PROCESS_LOG_MODULE: module.name, + MODULE_FEATURE_COLUMN: self.module_feature_summary(module), + DTYPE_SIZE_COLUMN: self.module_dtype_size_summary(module), + QUANT_LOG_LOSS: f"{val_loss:.10f}", + QUANT_LOG_NSAMPLES: f"{n_samples}", + QUANT_LOG_DAMP: "", + PROCESS_LOG_TIME: f"{duration:.3f}", + PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(), + PROCESS_USED_MEMORY: self.device_memory_report(), + } + + with self.lock: + self.durations.append(duration) + self.avg_losses.append(val_loss) + self.module_names.append(f"layer-{module.layer_index}-{module.name}") + self.log.append(stat) + + self.log_new_row(stat) + + def _quantize_one_module( + self, + module: NamedModule, + inputs: torch.Tensor, + ) -> tuple[float, float]: + """Optimize one module and stash its packed runtime tensors in `module.state`.""" + bits, group_size, sym = self._module_quant_params(module.full_name) + weight = self._module_weight_matrix(module) + bias = module.bias.data if getattr(module, "bias", None) is not None else None + original_weight = weight.detach().clone() + if inputs.numel() == 0: + inputs = torch.empty((0, weight.shape[1]), dtype=weight.dtype, device=weight.device) + module_seed = self._module_seed(module.layer_index, module.full_name) + + with torch.inference_mode(False), torch.enable_grad(): + result = optimize_paroquant_linear( + weight=weight, + bias=bias, + inputs=inputs, + bits=bits, + group_size=group_size, + sym=sym, + krot=self.qcfg.krot, + pair_ratio=self.qcfg.opt_pair_ratio, + train_rows=self.qcfg.opt_train_samples, + val_rows=self.qcfg.opt_validation_samples, + batch_size=self.qcfg.opt_batch_size, + rotation_epochs=self.qcfg.opt_rotation_epochs, + finetune_epochs=self.qcfg.opt_finetune_epochs, + rotation_lr=self.qcfg.opt_rotation_lr, + weight_lr=self.qcfg.opt_weight_lr, + quantizer_lr=self.qcfg.opt_quantizer_lr, + seed=module_seed, + optimizer_name=getattr(self.qcfg, "opt_optimizer", "adamw"), + optimizer_weight_decay=float(getattr(self.qcfg, "opt_weight_decay", 0.01)), + optimizer_betas=tuple(getattr(self.qcfg, "opt_betas", (0.9, 0.95))), + optimizer_eps=float(getattr(self.qcfg, "opt_eps", 1e-10)), + optimizer_amsgrad=bool(getattr(self.qcfg, "opt_amsgrad", False)), + sgd_momentum=float(getattr(self.qcfg, "opt_sgd_momentum", 0.0)), + sgd_dampening=float(getattr(self.qcfg, "opt_sgd_dampening", 0.0)), + sgd_nesterov=bool(getattr(self.qcfg, "opt_sgd_nesterov", False)), + fused_rotation=self.qcfg.opt_fused_rotation, + gradient_checkpointing=bool(getattr(self.qcfg, "opt_gradient_checkpointing", False)), + stage_cudagraph=self._module_scope_stage_cudagraph_enabled(), + best_state_dtype=getattr(self.qcfg, "opt_best_state_dtype", "fp32"), + stage_impl=self.qcfg.opt_stage_impl, + pair_impl=self.qcfg.opt_pair_impl, + quantizer_impl=self.qcfg.opt_quantizer_impl, + scale_clamp_min=self.qcfg.opt_channel_scale_clamp_min, + scale_clamp_max=self.qcfg.opt_channel_scale_clamp_max, + ) + + self._apply_optimization_result(module, result, original_weight) + return result.train_loss, result.val_loss + + @staticmethod + def _module_archetype(full_name: str) -> str: + """Use the terminal module name as the shared seed key across layers.""" + return full_name.rsplit(".", 1)[-1] + + def _module_seed_key(self, full_name: str) -> str: + """Keep module-scope seeds unique per full module while preserving grouped-scope behavior. + + Module scope optimizes every linear independently across the full model. + Reusing only the terminal archetype name correlates rotations across layers + (`...layers.0.self_attn.q_proj`, `...layers.1.self_attn.q_proj`, etc.), + which materially hurts end-to-end recovery on full-model module runs. + Grouped scopes keep the existing archetype behavior to avoid disturbing the + separately tuned layer/compute-block path. + """ + if self._opt_scope_mode() == "module": + return full_name + return self._module_archetype(full_name) + + def _module_seed(self, layer_index: int, full_name: str) -> int: + """Derive a deterministic per-module seed from base seed, layer index, and module name.""" + module_name = self._module_seed_key(full_name) + seed_material = f"{int(self.qcfg.opt_seed)}:{int(layer_index)}:{module_name}".encode("utf-8") + digest = hashlib.blake2b(seed_material, digest_size=8).digest() + return int.from_bytes(digest, byteorder="big", signed=False) + + def _opt_scope_mode(self) -> str: + """Normalize the configured ParoQuant optimization scope.""" + return str(getattr(self.qcfg, "opt_scope", "module")).strip().lower() + + def _gradient_checkpointing_enabled(self) -> bool: + """Resolve grouped-stage checkpointing from config, defaulting to layer scope only.""" + configured = getattr(self.qcfg, "opt_gradient_checkpointing", None) + if configured is None: + return self._opt_scope_mode() == "layer" + return bool(configured) + + def uses_grouped_optimization(self) -> bool: + """Return whether this layer should optimize compute_block/layer scopes instead of one linear at a time.""" + return self._opt_scope_mode() != "module" + + def needs_pristine_layer_clone(self) -> bool: + """Whether stage-layer orchestration must build a separate pristine layer copy. + + Whole-layer scope replays clean targets before mutating the live layer, so + it can use the untouched live layer directly. ComputeBlock scope still needs + a dedicated pristine copy because later grouped clone construction may + happen after subset-time hooks and wrapper replacement. + """ + return self._opt_scope_mode() != "layer" + + def capture_layer_forward_context_during_subset(self) -> bool: + """ParoQuant captures grouped pristine layer IO outside subset forwards only.""" + return False + + def _train_on_noisy_inputs_enabled(self) -> bool: + """Enable official clean-target / noisy-input calibration only when explicitly requested.""" + qcfg = getattr(self, "qcfg", None) + return bool(getattr(qcfg, "opt_train_on_noisy_inputs", False)) and self.uses_grouped_optimization() + + def _module_scope_stage_cudagraph_enabled(self) -> bool: + """Disable per-linear stage CUDA graphs in the full quantization loop. + + Module scope calls the ParoQuant optimizer once per linear across the + whole model. The captured train-step graph keeps CUDA graph private-pool + allocations alive across calls, which makes active VRAM ratchet upward + layer by layer. Grouped/layer scope does not use this path. + """ + if self._opt_scope_mode() == "module": + return False + return bool(getattr(self.qcfg, "opt_stage_cudagraph", False)) + + def _optimizer_param_group_kwargs(self) -> Dict[str, object]: + """Return shared optimizer hyperparameters for ParoQuant stage param groups.""" + qcfg = getattr(self, "qcfg", None) + return { + "weight_decay": float(getattr(qcfg, "opt_weight_decay", 0.01)), + "betas": tuple(getattr(qcfg, "opt_betas", (0.9, 0.95))), + "eps": float(getattr(qcfg, "opt_eps", 1e-10)), + "amsgrad": bool(getattr(qcfg, "opt_amsgrad", False)), + "momentum": float(getattr(qcfg, "opt_sgd_momentum", 0.0)), + "dampening": float(getattr(qcfg, "opt_sgd_dampening", 0.0)), + "nesterov": bool(getattr(qcfg, "opt_sgd_nesterov", False)), + } + + def clean_group_layer_inputs( + self, + *, + layer_index: int, + layer_inputs: List[List[torch.Tensor]], + ) -> List[List[torch.Tensor]]: + """Return the clean calibration stream used to build grouped layer targets.""" + del layer_index + if not self._train_on_noisy_inputs_enabled(): + return layer_inputs + return getattr(self, "_clean_group_layer_inputs", None) or layer_inputs + + def receive_clean_layer_inputs( + self, + *, + layer_index: int, + layer_inputs: List[List[torch.Tensor]], + ) -> None: + """Advance the clean calibration stream for later-layer train-on-noisy-inputs replay.""" + del layer_index + if self._train_on_noisy_inputs_enabled(): + self._clean_group_layer_inputs = layer_inputs + + def _build_group_optim_linear(self, module: NamedModule) -> _ParoQuantOptimLinear: + """Materialize one ParoQuant optimizer wrapper from the current live module state.""" + bits, group_size, sym = self._module_quant_params(module.full_name) + weight = self._module_weight_matrix(module) + bias = module.bias.data if getattr(module, "bias", None) is not None else None + normalized_group_size = _normalize_group_size(group_size, weight.shape[1]) + normalized_pair_impl = _normalize_opt_impl(self.qcfg.opt_pair_impl, field="pair_impl") + quantizer_sym = _quantizer_sym_for_impl(sym, self.qcfg.opt_quantizer_impl) + _normalize_quantizer_impl(self.qcfg.opt_quantizer_impl) + module_seed = self._module_seed(module.layer_index, module.full_name) + + if normalized_pair_impl == "reference": + pairs, theta_mask = build_random_rotation_buffers_reference( + in_features=weight.shape[1], + group_size=normalized_group_size, + krot=self.qcfg.krot, + pair_ratio=self.qcfg.opt_pair_ratio, + seed=module_seed, + device=weight.device, + ) + else: + pairs, theta_mask = build_random_rotation_buffers( + in_features=weight.shape[1], + group_size=normalized_group_size, + krot=self.qcfg.krot, + pair_ratio=self.qcfg.opt_pair_ratio, + seed=module_seed, + device=weight.device, + ) + + return _ParoQuantOptimLinear( + weight.detach().to(device=weight.device, dtype=torch.float32), + None if bias is None else bias.detach().to(device=weight.device, dtype=torch.float32), + bits=bits, + group_size=normalized_group_size, + quantizer_sym=quantizer_sym, + pairs=pairs, + theta_mask=theta_mask, + scale_clamp_min=self.qcfg.opt_channel_scale_clamp_min, + scale_clamp_max=self.qcfg.opt_channel_scale_clamp_max, + # Layer-scope live optimization must stay on the portable PyTorch + # rotation path. The fused autograd kernel can fail to load on some + # fleet GPUs and the error may surface asynchronously well after the + # original rotation call. + fused_rotation=False if self._opt_scope_mode() == "layer" else self.qcfg.opt_fused_rotation, + ).to(device=weight.device, dtype=torch.float32) + + @staticmethod + def _restore_linear_from_hooked(module: HookedLinear) -> torch.nn.Linear: + """Drop HookedLinear's inference-only forward while preserving the shared parameter storage.""" + restored = torch.nn.Linear.__new__(torch.nn.Linear) + torch.nn.Module.__init__(restored) + restored.in_features = module.in_features + restored.out_features = module.out_features + restored.weight = module.weight + restored.bias = module.bias + return restored + + def _strip_hooked_linear_wrappers(self, module: torch.nn.Module) -> int: + """Layer-scope training must not run through HookedLinear, which always forwards in inference mode.""" + replaced = 0 + for child_name, child in list(module.named_children()): + if isinstance(child, HookedLinear): + setattr(module, child_name, self._restore_linear_from_hooked(child)) + replaced += 1 + continue + replaced += self._strip_hooked_linear_wrappers(child) + return replaced + + @staticmethod + def _sync_named_modules_to_live_layer( + layer: torch.nn.Module, + named_modules: list[NamedModule], + ) -> None: + """Retarget NamedModule handles after live-layer wrapper surgery. + + Layer scope temporarily unwraps HookedLinear modules into plain + ``nn.Linear`` so autograd can flow through the real decoder layer. The + corresponding ``NamedModule`` wrappers must follow those in-place + replacements; otherwise later result application and packing will update + detached wrapper objects instead of the real live layer modules. + """ + for named_module in named_modules: + live_module = recurse_getattr(layer, named_module.name) + named_module.module = live_module + try: + named_module.module_dtype = next(live_module.parameters()).dtype + except (StopIteration, AttributeError): + pass + + @staticmethod + def _force_layer_eager_attention(layer: torch.nn.Module) -> list[tuple[object, str, object]]: + """Temporarily force live-layer grouped optimization onto eager attention kernels.""" + overrides: list[tuple[object, str, object]] = [] + seen_configs: set[int] = set() + candidate_configs = [ + getattr(layer, "config", None), + getattr(getattr(layer, "self_attn", None), "config", None), + ] + for config in candidate_configs: + if config is None: + continue + config_id = id(config) + if config_id in seen_configs: + continue + seen_configs.add(config_id) + for attr in ("_attn_implementation", "attn_implementation"): + if not hasattr(config, attr): + continue + original_value = getattr(config, attr) + if original_value == "eager": + continue + setattr(config, attr, "eager") + overrides.append((config, attr, original_value)) + return overrides + + @staticmethod + def _restore_layer_attention_impl(overrides: list[tuple[object, str, object]]) -> None: + """Restore any attention implementation overrides after live-layer grouped optimization.""" + for config, attr, original_value in reversed(overrides): + setattr(config, attr, original_value) + + @staticmethod + def _materialize_live_layer_autograd_tensors( + layer: torch.nn.Module, + ) -> tuple[int, int]: + """Replace inference-born params/buffers with fresh normal tensors for live autograd. + + The layer object itself stays in place, but worker-side model loading can + create parameters and buffers under ``torch.inference_mode()``. Those + tensors keep inference-only bookkeeping even after ``module.to(...)``, + which breaks backward in the in-place layer optimizer. Rebuilding them + in place gives the live layer normal autograd-safe storage without + falling back to a full cloned layer path. + """ + replaced_params = 0 + replaced_buffers = 0 + with torch.inference_mode(False): + for module in layer.modules(): + for name, param in list(module._parameters.items()): + if param is None: + continue + rebuilt = nn.Parameter(param.detach().clone(), requires_grad=param.requires_grad) + if rebuilt is not param: + module._parameters[name] = rebuilt + replaced_params += 1 + for name, buffer in list(module._buffers.items()): + if buffer is None: + continue + rebuilt = buffer.detach().clone() + if rebuilt is not buffer: + module._buffers[name] = rebuilt + replaced_buffers += 1 + return replaced_params, replaced_buffers + + def _build_group_optim_layer( + self, + state: _ParoQuantLayerState, + group_modules: list[NamedModule], + ) -> tuple[torch.nn.Module, dict[str, _ParoQuantOptimLinear]]: + """Clone the layer and swap one selected group to ParoQuant optimizer wrappers.""" + if not group_modules: + raise ValueError("ParoQuantProcessor grouped optimization requires at least one module.") + + prepared_source = getattr(state, "prepared_group_source_module", None) + if prepared_source is None: + source_layer = getattr(state, "pristine_layer_module", None) or state.layer_module + if source_layer is None: + raise RuntimeError("ParoQuantProcessor grouped optimization requires the source layer module.") + + prepared_source = copy.deepcopy(source_layer).to(device=CPU, dtype=torch.float32) + # ComputeBlock/layer clone optimization cannot train through HookedLinear + # because its forward is permanently wrapped in inference mode. + self._strip_hooked_linear_wrappers(prepared_source) + # Grouped optimization needs stable, differentiable attention semantics. + # Keep the cloned calibration-time layer on eager attention even if the + # live model prefers SDPA/flash kernels for inference throughput. + layer_attn = getattr(prepared_source, "self_attn", None) + layer_attn_config = getattr(layer_attn, "config", None) + if layer_attn_config is not None and hasattr(layer_attn_config, "_attn_implementation"): + layer_attn_config._attn_implementation = "eager" + + state.prepared_group_source_module = prepared_source + + target_device = group_modules[0].weight.device + if self._opt_scope_mode() == "compute_block": + device_cache = getattr(state, "prepared_group_source_module_by_device", None) + if device_cache is None: + device_cache = {} + state.prepared_group_source_module_by_device = device_cache + + device_key = str(target_device) + prepared_source_for_device = device_cache.get(device_key) + if prepared_source_for_device is None: + if device_key == str(CPU): + prepared_source_for_device = prepared_source + else: + prepared_source_for_device = copy.deepcopy(prepared_source).to(device=target_device, dtype=torch.float32) + device_cache[device_key] = prepared_source_for_device + + layer_clone = copy.deepcopy(prepared_source_for_device) + if next(layer_clone.parameters()).device != target_device: + layer_clone = layer_clone.to(device=target_device, dtype=torch.float32) + else: + layer_clone = copy.deepcopy(prepared_source) + layer_clone = layer_clone.to(device=target_device, dtype=torch.float32) + + optim_modules: dict[str, _ParoQuantOptimLinear] = {} + for named_module in group_modules: + optim_module = self._build_group_optim_linear(named_module) + recurse_setattr(layer_clone, named_module.name, optim_module) + optim_modules[named_module.name] = optim_module + + return layer_clone, optim_modules + + def _get_root_rotary(self) -> Optional[nn.Module]: + """Return the model rotary module used to refresh grouped layer replay kwargs.""" + model = getattr(self, "model", None) + if self.gptq_model is not None and model is not None and getattr(self.gptq_model, "rotary_embedding", None): + rotary, _ = get_module_by_name_prefix(model, [self.gptq_model.rotary_embedding]) + return rotary + if model is None: + return None + return getattr(getattr(model, "model", model), "rotary_emb", None) + + @staticmethod + def _get_rotary_device(rotary: Optional[nn.Module], fallback: Optional[torch.device] = None) -> Optional[torch.device]: + """Resolve the active device for a rotary module, falling back safely.""" + if rotary is None: + return fallback + + rotary_device = getattr(getattr(rotary, "inv_freq", None), "device", None) + if rotary_device is not None: + return rotary_device + + try: + return next(rotary.parameters()).device + except (StopIteration, AttributeError, RuntimeError): + return fallback + + def _get_rotary_for_device(self, target_device: Optional[torch.device]) -> Optional[nn.Module]: + """Return a rotary module materialized on the requested device when needed.""" + rotary = self._get_root_rotary() + if rotary is None or target_device is None: + return rotary + + target_device = torch.device(target_device) + if self._get_rotary_device(rotary) == target_device: + return rotary + + cache_key = str(target_device) + with self._rotary_lock: + rotary = self._get_root_rotary() + if rotary is None: + return None + + source_id = id(rotary) + if self._rotary_source_id != source_id: + self._rotary_cache.clear() + self._rotary_source_id = source_id + + if self._get_rotary_device(rotary) == target_device: + return rotary + + cached = self._rotary_cache.get(cache_key) + if cached is None: + try: + cached = copy.deepcopy(rotary) + except Exception: + cached = rotary + + move_to(cached, device=target_device) + if cached is not rotary: + self._rotary_cache[cache_key] = cached + + return cached + + @staticmethod + def _can_cache_rotary_position_embeddings(rotary: Optional[nn.Module]) -> bool: + """Allow memoized rotary embeddings only for known HF rotary classes that depend on ids, device, and dtype.""" + if rotary is None: + return False + rotary_type = type(rotary) + return rotary_type.__module__.startswith("transformers.models.") and rotary_type.__name__.endswith("RotaryEmbedding") + + def _cached_group_position_ids( + self, + *, + device: torch.device, + batch_dim: int, + seq_len: int, + ) -> torch.Tensor: + """Reuse deterministic generated position ids for repeated grouped layer replay batches.""" + position_ids_cache = getattr(self, "_group_position_ids_cache", None) + if position_ids_cache is None: + position_ids_cache = {} + self._group_position_ids_cache = position_ids_cache + + cache_key = (str(device), int(batch_dim), int(seq_len)) + cached_position_ids = position_ids_cache.get(cache_key) + if cached_position_ids is None or cached_position_ids.is_inference(): + with torch.inference_mode(False): + cached_position_ids = torch.arange(seq_len, device=device, dtype=torch.long).unsqueeze(0).expand(batch_dim, -1) + position_ids_cache[cache_key] = cached_position_ids + return cached_position_ids + + def _cached_group_rotary_position_embeddings( + self, + *, + rotary: nn.Module, + x: torch.Tensor, + position_ids: torch.Tensor, + rotary_device: Optional[torch.device], + ) -> Any: + """Reuse rotary outputs when ids, dtype, and device are unchanged across replay batches/epochs.""" + target_rotary_device = rotary_device or x.device + rotary_cache = getattr(self, "_group_rotary_position_embeddings_cache", None) + if rotary_cache is None: + rotary_cache = {} + self._group_rotary_position_embeddings_cache = rotary_cache + + cache_key = ( + id(rotary), + id(position_ids), + str(target_rotary_device), + str(x.dtype), + ) + cached_position_embeddings = rotary_cache.get(cache_key) + if cached_position_embeddings is None or _value_has_inference_tensor(cached_position_embeddings): + x_for_rotary = _LayerShardLoader._tensor_to_device(x, target_rotary_device) + pos_for_rotary = _LayerShardLoader._tensor_to_device(position_ids, target_rotary_device) + with torch.inference_mode(False): + cached_position_embeddings = rotary(x_for_rotary, pos_for_rotary) + rotary_cache[cache_key] = cached_position_embeddings + return cached_position_embeddings + + def _prepare_group_forward_kwargs( + self, + layer: torch.nn.Module, + *, + x: torch.Tensor, + input_kwargs: Dict[str, Any], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor], + cache: bool = True, + ) -> Dict[str, Any]: + """Refresh grouped replay kwargs so real decoder layers can be replayed safely.""" + target_device = x.device + + prepared_cache_key = None + prepared_cache = None + if cache and not x.requires_grad: + prepared_cache = getattr(self, "_group_forward_prepared_cache", None) + if prepared_cache is None: + prepared_cache = {} + self._group_forward_prepared_cache = prepared_cache + prepared_cache_key = ( + id(layer), + id(x), + id(input_kwargs), + id(attention_mask) if attention_mask is not None else 0, + id(position_ids) if position_ids is not None else 0, + str(target_device), + ) + cached_prepared = prepared_cache.get(prepared_cache_key) + if cached_prepared is not None: + return dict(cached_prepared) + + skip_kwargs = {"past_key_values", "past_key_value"} + + if cache: + kwargs_cache = getattr(self, "_group_forward_kwargs_cache", None) + if kwargs_cache is None: + kwargs_cache = {} + self._group_forward_kwargs_cache = kwargs_cache + kwargs_cache_key = (id(input_kwargs), str(target_device)) + cached_module_kwargs = kwargs_cache.get(kwargs_cache_key) + if cached_module_kwargs is None: + cached_module_kwargs = { + key: nested_move_to(value, device=target_device) + for key, value in input_kwargs.items() + if key not in skip_kwargs + } + kwargs_cache[kwargs_cache_key] = cached_module_kwargs + module_kwargs = dict(cached_module_kwargs) + else: + module_kwargs = { + key: _LayerShardLoader._move_value_to_device(value, target_device) + for key, value in input_kwargs.items() + if key not in skip_kwargs + } + + signature_cache = getattr(self, "_group_forward_signature_cache", None) + if signature_cache is None: + signature_cache = {} + self._group_forward_signature_cache = signature_cache + layer_type = type(layer) + cached_signature = signature_cache.get(layer_type) + if cached_signature is None: + supports_position_ids = False + supports_position_embeddings = False + supports_attention_mask = False + try: + signature = inspect.signature(layer.forward).parameters + supports_position_ids = "position_ids" in signature + supports_position_embeddings = "position_embeddings" in signature + supports_attention_mask = "attention_mask" in signature + except (ValueError, TypeError): + supports_attention_mask = True + cached_signature = (supports_position_ids, supports_position_embeddings, supports_attention_mask) + signature_cache[layer_type] = cached_signature + supports_position_ids, supports_position_embeddings, supports_attention_mask = cached_signature + + if supports_attention_mask: + module_kwargs["attention_mask"] = None if attention_mask is None else ( + _LayerShardLoader._tensor_to_device(attention_mask, target_device) + ) + + if x.dim() == 2 and (supports_position_ids or supports_position_embeddings): + x = x.unsqueeze(0) + + seq_len: Optional[int] + batch_dim: int + if x.dim() >= 2: + batch_dim = x.shape[0] + seq_len = x.shape[1] if x.dim() >= 3 else x.shape[0] + else: + batch_dim = 1 + seq_len = x.shape[0] + + rotary = self._get_root_rotary() + if seq_len is not None and rotary is not None and supports_position_embeddings: + rotary = self._get_rotary_for_device(target_device or x.device) + rotary_device = self._get_rotary_device(rotary, target_device or x.device) + pos_for_rotary = position_ids if supports_position_ids else None + if pos_for_rotary is None or pos_for_rotary.shape[-1] != seq_len: + pos_values = self._cached_group_position_ids( + device=rotary_device or x.device, + batch_dim=batch_dim, + seq_len=seq_len, + ) + if supports_position_ids: + module_kwargs["position_ids"] = pos_values + pos_for_rotary = pos_values + else: + if rotary_device is not None and pos_for_rotary.device != rotary_device: + pos_for_rotary = _LayerShardLoader._tensor_to_device(pos_for_rotary, rotary_device) + if supports_position_ids: + module_kwargs["position_ids"] = pos_for_rotary + + if self._can_cache_rotary_position_embeddings(rotary): + module_kwargs["position_embeddings"] = self._cached_group_rotary_position_embeddings( + rotary=rotary, + x=x, + position_ids=pos_for_rotary, + rotary_device=rotary_device, + ) + else: + x_for_rotary = x if rotary_device is None or x.device == rotary_device else x.to(rotary_device) + module_kwargs["position_embeddings"] = rotary(x_for_rotary, pos_for_rotary) + elif supports_position_ids: + if position_ids is None or position_ids.shape[-1] != seq_len: + pos_values = self._cached_group_position_ids( + device=target_device or x.device, + batch_dim=batch_dim, + seq_len=seq_len, + ) + module_kwargs["position_ids"] = pos_values + else: + module_kwargs["position_ids"] = _LayerShardLoader._tensor_to_device(position_ids, target_device) + + module_kwargs["use_cache"] = False + module_kwargs = self._normalize_group_runtime_metadata(module_kwargs) + if prepared_cache_key is not None and prepared_cache is not None: + prepared_cache[prepared_cache_key] = dict(module_kwargs) + return module_kwargs + + @staticmethod + def _clone_group_runtime_metadata(value: Any) -> Any: + """Rebuild replay/runtime tensor trees with fresh normal tensors. + + This is stricter than the inference-detection path. It protects the + live layer optimizer from tensors that still alias inference-backed + storage even when `is_inference()` does not fire on the exact view that + reaches the layer entry point. + """ + if isinstance(value, torch.Tensor): + with torch.inference_mode(False): + return value.clone() + if isinstance(value, dict): + return {key: ParoQuantProcessor._clone_group_runtime_metadata(inner) for key, inner in value.items()} + if isinstance(value, list): + return [ParoQuantProcessor._clone_group_runtime_metadata(inner) for inner in value] + if isinstance(value, tuple): + return tuple(ParoQuantProcessor._clone_group_runtime_metadata(inner) for inner in value) + return value + + @staticmethod + def _layer_batch_row_count(input_batch: List[torch.Tensor]) -> int: + """Count flattened token rows for one cached layer-input batch.""" + if not input_batch: + return 0 + primary = input_batch[0] + if not isinstance(primary, torch.Tensor) or primary.numel() == 0: + return 0 + if primary.dim() == 0: + return 1 + return int(primary.numel() // max(1, primary.shape[-1])) + + def _prefix_batch_count_for_rows( + self, + input_batches: List[List[torch.Tensor]], + row_budget: int, + ) -> int: + """Choose the smallest non-empty prefix whose cached rows meet the requested budget.""" + if not input_batches: + return 0 + if row_budget <= 0: + return 1 + total_rows = 0 + for index, batch in enumerate(input_batches, start=1): + total_rows += self._layer_batch_row_count(batch) + if total_rows >= row_budget: + return index + return len(input_batches) + + def _suffix_batch_count_for_rows( + self, + input_batches: List[List[torch.Tensor]], + row_budget: int, + ) -> int: + """Choose the smallest non-empty suffix whose cached rows meet the requested budget.""" + if not input_batches: + return 0 + if row_budget <= 0: + return 1 + total_rows = 0 + count = 0 + for batch in reversed(input_batches): + total_rows += self._layer_batch_row_count(batch) + count += 1 + if total_rows >= row_budget: + return count + return len(input_batches) + + @staticmethod + def _target_primary(target_batch: Any) -> torch.Tensor: + """Normalize cached layer outputs to the single tensor used for the loss target.""" + if isinstance(target_batch, (list, tuple)): + if not target_batch: + raise ValueError("ParoQuant grouped optimization received an empty target batch.") + target_batch = target_batch[0] + if not isinstance(target_batch, torch.Tensor): + raise TypeError(f"ParoQuant grouped optimization expected tensor targets, got `{type(target_batch).__name__}`.") + return target_batch + + def _prepare_group_target( + self, + target_batch: Any, + *, + device: torch.device, + dtype: torch.dtype, + cache: bool = True, + ) -> torch.Tensor: + """Cache grouped loss targets after primary-tensor normalization and dtype/device conversion.""" + target = self._target_primary(target_batch) + if target.is_inference(): + with torch.inference_mode(False): + target = target.to(device=device, dtype=dtype, copy=True) + if not cache: + return target + if not cache: + return target.to(device=device, dtype=dtype) + target_cache = getattr(self, "_group_target_cache", None) + if target_cache is None: + target_cache = {} + self._group_target_cache = target_cache + cache_key = (id(target), str(device), str(dtype)) + cached_target = target_cache.get(cache_key) + if cached_target is None: + cached_target = target.to(device=device, dtype=dtype) + target_cache[cache_key] = cached_target + return cached_target + + @staticmethod + def _move_group_value_to_cpu(value: Any) -> Any: + """Recursively normalize replay metadata onto CPU-owned tensors.""" + if isinstance(value, torch.Tensor): + if value.is_inference(): + with torch.inference_mode(False): + tensor = value.detach().to(device=CPU, copy=True) + else: + tensor = value.detach().cpu() if value.device.type != CPU.type else value.detach() + if torch.cuda.is_available() and not tensor.is_pinned(): + tensor = tensor.pin_memory() + return tensor + if isinstance(value, dict): + return {key: ParoQuantProcessor._move_group_value_to_cpu(inner) for key, inner in value.items()} + if isinstance(value, list): + return [ParoQuantProcessor._move_group_value_to_cpu(inner) for inner in value] + if isinstance(value, tuple): + return tuple(ParoQuantProcessor._move_group_value_to_cpu(inner) for inner in value) + return value + + @staticmethod + def _normalize_group_runtime_metadata(value: Any) -> Any: + """Rebuild inference-mode replay metadata once kwargs are fully assembled on the live device.""" + if isinstance(value, torch.Tensor): + return _LayerShardLoader._tensor_to_device(value, value.device) + if isinstance(value, dict): + return { + key: ParoQuantProcessor._normalize_group_runtime_metadata(inner) + for key, inner in value.items() + } + if isinstance(value, list): + return [ParoQuantProcessor._normalize_group_runtime_metadata(inner) for inner in value] + if isinstance(value, tuple): + return tuple(ParoQuantProcessor._normalize_group_runtime_metadata(inner) for inner in value) + return value + + def _replay_batches_from_state( + self, + state: _ParoQuantLayerState, + ) -> tuple[list[_ParoQuantReplayBatch], list[_ParoQuantReplayBatch]]: + """Build CPU-owned train/validation replay batches for in-place layer optimization.""" + cached_batches = getattr(state, "replay_batches", None) + if cached_batches is not None: + return cached_batches + + input_batches = state.layer_inputs or [] + input_kwargs_batches = state.layer_input_kwargs or [] + output_batches = state.layer_outputs or [] + if not input_batches or not output_batches: + raise RuntimeError("ParoQuant layer optimization requires captured layer inputs and outputs.") + if len(input_batches) != len(output_batches): + raise RuntimeError("ParoQuant layer optimization requires aligned input/output batch counts.") + + if not input_kwargs_batches: + input_kwargs_batches = [{} for _ in range(len(input_batches))] + elif len(input_kwargs_batches) != len(input_batches): + raise RuntimeError("ParoQuant layer optimization requires aligned layer-input kwargs.") + + position_ids = list(self.inputs_cache.position_ids or []) + attention_masks = list(self.inputs_cache.attention_masks or []) + if len(position_ids) < len(input_batches): + position_ids.extend([None] * (len(input_batches) - len(position_ids))) + if len(attention_masks) < len(input_batches): + attention_masks.extend([None] * (len(input_batches) - len(attention_masks))) + + replay_batches: list[_ParoQuantReplayBatch] = [] + for input_batch, input_kwargs, output_batch, pos_ids, attn_mask in zip( + input_batches, + input_kwargs_batches, + output_batches, + position_ids, + attention_masks, + ): + cpu_inputs = [ + self._move_group_value_to_cpu(tensor) + for tensor in input_batch + ] + replay_batches.append( + _ParoQuantReplayBatch( + inputs=cpu_inputs, + input_kwargs={ + key: self._move_group_value_to_cpu(value) + for key, value in input_kwargs.items() + }, + target=self._move_group_value_to_cpu(self._target_primary(output_batch)), + position_ids=None if pos_ids is None else self._move_group_value_to_cpu(pos_ids), + attention_mask=None if attn_mask is None else self._move_group_value_to_cpu(attn_mask), + row_count=self._layer_batch_row_count(cpu_inputs), + ) + ) + + train_batch_count = self._prefix_batch_count_for_rows( + [batch.inputs for batch in replay_batches], + int(self.qcfg.opt_train_samples), + ) + val_batch_count = self._suffix_batch_count_for_rows( + [batch.inputs for batch in replay_batches], + int(self.qcfg.opt_validation_samples), + ) + val_start = max(0, len(replay_batches) - val_batch_count) + replay_split = (replay_batches[:train_batch_count], replay_batches[val_start:]) + state.replay_batches = replay_split + return replay_split + + @staticmethod + def _tensor_bytes(value: torch.Tensor) -> int: + """Measure one tensor's storage footprint.""" + return int(value.numel() * value.element_size()) + + def _nested_tensor_bytes(self, value: Any) -> int: + """Measure nested replay metadata footprint.""" + if isinstance(value, torch.Tensor): + return self._tensor_bytes(value) + if isinstance(value, dict): + return sum(self._nested_tensor_bytes(inner) for inner in value.values()) + if isinstance(value, (list, tuple)): + return sum(self._nested_tensor_bytes(inner) for inner in value) + return 0 + + def _replay_batch_bytes(self, batch: _ParoQuantReplayBatch) -> int: + """Estimate one replay batch's device footprint.""" + return ( + sum(self._tensor_bytes(tensor) for tensor in batch.inputs) + + self._tensor_bytes(batch.target) + + (0 if batch.position_ids is None else self._tensor_bytes(batch.position_ids)) + + (0 if batch.attention_mask is None else self._tensor_bytes(batch.attention_mask)) + + self._nested_tensor_bytes(batch.input_kwargs) + ) + + def _layer_train_shard_batches( + self, + layer: torch.nn.Module, + *, + param_groups: Sequence[dict[str, object]], + replay_batches: list[_ParoQuantReplayBatch], + ) -> int: + """Choose a conservative train-shard size from current free GPU memory.""" + if not replay_batches: + return 1 + + try: + target_device = next(layer.parameters()).device + except (StopIteration, RuntimeError): + target_device = CPU + + if target_device.type != "cuda": + return len(replay_batches) + + try: + free_bytes, _total_bytes = torch.cuda.mem_get_info(target_device) + except RuntimeError: + return 1 + + layer_bytes = sum(self._tensor_bytes(param.detach()) for param in layer.parameters()) + layer_bytes += sum(self._tensor_bytes(buffer.detach()) for buffer in layer.buffers()) + active_params = { + id(param): param + for group in param_groups + for param in group.get("params", []) + if isinstance(param, nn.Parameter) + } + optim_bytes = sum(int(param.numel()) * 16 for param in active_params.values()) + sample_count = min(4, len(replay_batches)) + avg_batch_bytes = max( + 1, + sum(self._replay_batch_bytes(batch) for batch in replay_batches[:sample_count]) // sample_count, + ) + activation_margin = 256 * 1024 * 1024 + headroom = max(avg_batch_bytes, int(free_bytes) - layer_bytes - optim_bytes - activation_margin) + return max(1, min(len(replay_batches), headroom // avg_batch_bytes)) + + def _group_dataset_from_state( + self, + state: _ParoQuantLayerState, + ) -> tuple[ + list[List[torch.Tensor]], + list[Dict[str, Any]], + list[Any], + list[Optional[torch.Tensor]], + list[Optional[torch.Tensor]], + list[List[torch.Tensor]], + list[Dict[str, Any]], + list[Any], + list[Optional[torch.Tensor]], + list[Optional[torch.Tensor]], + ]: + """Slice the preserved layer IO into train/validation batch lists for grouped optimization.""" + cached_dataset = getattr(state, "grouped_dataset", None) + if cached_dataset is not None: + return cached_dataset + + input_batches = state.layer_inputs or [] + input_kwargs_batches = state.layer_input_kwargs or [] + output_batches = state.layer_outputs or [] + if not input_batches or not output_batches: + raise RuntimeError("ParoQuant grouped optimization requires captured layer inputs and outputs.") + if len(input_batches) != len(output_batches): + raise RuntimeError("ParoQuant grouped optimization requires aligned input/output batch counts.") + + if not input_kwargs_batches: + input_kwargs_batches = [{} for _ in range(len(input_batches))] + elif len(input_kwargs_batches) != len(input_batches): + raise RuntimeError("ParoQuant grouped optimization requires aligned layer-input kwargs.") + + position_ids = list(self.inputs_cache.position_ids or []) + attention_masks = list(self.inputs_cache.attention_masks or []) + if len(position_ids) < len(input_batches): + position_ids.extend([None] * (len(input_batches) - len(position_ids))) + if len(attention_masks) < len(input_batches): + attention_masks.extend([None] * (len(input_batches) - len(attention_masks))) + + train_batch_count = self._prefix_batch_count_for_rows(input_batches, int(self.qcfg.opt_train_samples)) + val_batch_count = self._suffix_batch_count_for_rows(input_batches, int(self.qcfg.opt_validation_samples)) + val_start = max(0, len(input_batches) - val_batch_count) + + grouped_dataset = ( + input_batches[:train_batch_count], + input_kwargs_batches[:train_batch_count], + output_batches[:train_batch_count], + position_ids[:train_batch_count], + attention_masks[:train_batch_count], + input_batches[val_start:], + input_kwargs_batches[val_start:], + output_batches[val_start:], + position_ids[val_start:], + attention_masks[val_start:], + ) + state.grouped_dataset = grouped_dataset + return grouped_dataset + + def _group_dataset_for_device( + self, + state: _ParoQuantLayerState, + target_device: Optional[torch.device], + ) -> tuple[ + list[List[torch.Tensor]], + list[Dict[str, Any]], + list[Any], + list[Optional[torch.Tensor]], + list[Optional[torch.Tensor]], + list[List[torch.Tensor]], + list[Dict[str, Any]], + list[Any], + list[Optional[torch.Tensor]], + list[Optional[torch.Tensor]], + ]: + """Materialize the cached grouped dataset on the target device once per layer/device pair.""" + grouped_dataset = self._group_dataset_from_state(state) + device = torch.device(target_device) if target_device is not None else CPU + cache_key = str(device) + grouped_dataset_by_device = getattr(state, "grouped_dataset_by_device", None) + if grouped_dataset_by_device is None: + grouped_dataset_by_device = {} + state.grouped_dataset_by_device = grouped_dataset_by_device + + cached_dataset = grouped_dataset_by_device.get(cache_key) + if cached_dataset is not None: + return cached_dataset + + def _move_tensor_batches(batches: list[List[torch.Tensor]]) -> list[List[torch.Tensor]]: + return [[move_to(tensor, device=device) for tensor in batch] for batch in batches] + + def _move_kwargs_batches(kwargs_batches: list[Dict[str, Any]]) -> list[Dict[str, Any]]: + return [nested_move_to(kwargs, device=device) for kwargs in kwargs_batches] + + def _move_target_batches(target_batches: list[Any]) -> list[Any]: + return [nested_move_to(batch, device=device) for batch in target_batches] + + def _move_optional_tensors(optional_tensors: list[Optional[torch.Tensor]]) -> list[Optional[torch.Tensor]]: + return [None if tensor is None else move_to(tensor, device=device) for tensor in optional_tensors] + + device_dataset = ( + _move_tensor_batches(grouped_dataset[0]), + _move_kwargs_batches(grouped_dataset[1]), + _move_target_batches(grouped_dataset[2]), + _move_optional_tensors(grouped_dataset[3]), + _move_optional_tensors(grouped_dataset[4]), + _move_tensor_batches(grouped_dataset[5]), + _move_kwargs_batches(grouped_dataset[6]), + _move_target_batches(grouped_dataset[7]), + _move_optional_tensors(grouped_dataset[8]), + _move_optional_tensors(grouped_dataset[9]), + ) + grouped_dataset_by_device[cache_key] = device_dataset + return device_dataset + + def _forward_group_batch( + self, + layer: torch.nn.Module, + *, + batch_index: int, + input_batch: List[torch.Tensor], + input_kwargs: Dict[str, Any], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor], + ) -> torch.Tensor: + """Run one cached batch through a grouped layer clone and return its primary output.""" + del batch_index + try: + layer_device = next(layer.parameters()).device + except (StopIteration, RuntimeError): + layer_device = CPU + + inputs = [_LayerShardLoader._tensor_to_device(inp, layer_device) for inp in input_batch] + if not inputs: + raise RuntimeError("ParoQuant grouped optimization forward requires at least one input tensor.") + + additional_inputs = self._prepare_group_forward_kwargs( + layer, + x=inputs[0], + input_kwargs=input_kwargs, + attention_mask=attention_mask, + position_ids=position_ids, + ) + module_output = layer(*inputs, **additional_inputs) + if module_output is None: + raise RuntimeError("ParoQuant grouped optimization forward returned no output.") + if isinstance(module_output, tuple): + module_output = module_output[0] + if not isinstance(module_output, torch.Tensor): + raise TypeError( + "ParoQuant grouped optimization expected tensor layer outputs, " + f"got `{type(module_output).__name__}`." + ) + return module_output + + def _forward_replay_batch( + self, + layer: torch.nn.Module, + *, + replay_batch: _ParoQuantReplayBatch, + cache_kwargs: bool, + ) -> torch.Tensor: + """Run one streamed replay batch through the live grouped layer.""" + layer_device = replay_batch.inputs[0].device if replay_batch.inputs else CPU + + inputs = [ + self._clone_group_runtime_metadata(_LayerShardLoader._tensor_to_device(inp, layer_device)) + for inp in replay_batch.inputs + ] + if not inputs: + raise RuntimeError("ParoQuant layer replay requires at least one input tensor.") + + additional_inputs = self._prepare_group_forward_kwargs( + layer, + x=inputs[0], + input_kwargs=replay_batch.input_kwargs, + attention_mask=replay_batch.attention_mask, + position_ids=replay_batch.position_ids, + cache=cache_kwargs, + ) + additional_inputs = self._clone_group_runtime_metadata(additional_inputs) + if any(inp.is_inference() for inp in inputs) or _value_has_inference_tensor(additional_inputs): + raise RuntimeError( + "ParoQuant layer replay assembled inference-mode live inputs. " + f"inputs_inference={[inp.is_inference() for inp in inputs]} " + f"kwargs_inference={_value_has_inference_tensor(additional_inputs)} " + f"kwargs_keys={sorted(additional_inputs.keys())}" + ) + module_output = layer(*inputs, **additional_inputs) + if module_output is None: + raise RuntimeError("ParoQuant grouped optimization forward returned no output.") + if isinstance(module_output, tuple): + module_output = module_output[0] + if not isinstance(module_output, torch.Tensor): + raise TypeError( + "ParoQuant grouped optimization expected tensor layer outputs, " + f"got `{type(module_output).__name__}`." + ) + return module_output + + @staticmethod + def _reset_group_angles(optim_modules: dict[str, _ParoQuantOptimLinear]) -> None: + """Clamp masked dummy rotations back to zero after each grouped optimizer step.""" + for optim_module in optim_modules.values(): + optim_module.reset_masked_angles() + + def _evaluate_group_layer( + self, + layer: torch.nn.Module, + *, + input_batches: list[List[torch.Tensor]], + input_kwargs_batches: list[Dict[str, Any]], + target_batches: list[Any], + position_ids: list[Optional[torch.Tensor]], + attention_masks: list[Optional[torch.Tensor]], + use_amp: bool, + ) -> float: + """Measure full-layer reconstruction error for one grouped optimization stage.""" + if not input_batches: + return 0.0 + + total_loss = 0.0 + with torch.inference_mode(): + for batch_index, (input_batch, input_kwargs, target_batch, pos_ids, attn_mask) in enumerate( + zip(input_batches, input_kwargs_batches, target_batches, position_ids, attention_masks) + ): + autocast_ctx = torch.amp.autocast("cuda") if use_amp else nullcontext() + with autocast_ctx: + preds = self._forward_group_batch( + layer, + batch_index=batch_index, + input_batch=input_batch, + input_kwargs=input_kwargs, + attention_mask=attn_mask, + position_ids=pos_ids, + ) + target = self._prepare_group_target(target_batch, device=preds.device, dtype=preds.dtype) + total_loss += float(F.smooth_l1_loss(preds, target).item()) + return total_loss / max(1, len(input_batches)) + + def _forward_group_batch_train( + self, + layer: torch.nn.Module, + *, + batch_index: int, + input_batch: list[torch.Tensor], + input_kwargs: dict[str, Any], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor], + ) -> torch.Tensor: + """Optionally checkpoint grouped train forwards to reduce activation residency.""" + if not self._gradient_checkpointing_enabled() or not input_batch: + return self._forward_group_batch( + layer, + batch_index=batch_index, + input_batch=input_batch, + input_kwargs=input_kwargs, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + runtime_inputs = [ + self._clone_group_runtime_metadata(_LayerShardLoader._tensor_to_device(inp, inp.device)) + for inp in input_batch + ] + runtime_kwargs = self._clone_group_runtime_metadata(input_kwargs) + runtime_attention_mask = self._clone_group_runtime_metadata(attention_mask) + runtime_position_ids = self._clone_group_runtime_metadata(position_ids) + + def _forward(*runtime_inputs: torch.Tensor) -> torch.Tensor: + return self._forward_group_batch( + layer, + batch_index=batch_index, + input_batch=list(runtime_inputs), + input_kwargs=runtime_kwargs, + attention_mask=runtime_attention_mask, + position_ids=runtime_position_ids, + ) + + return torch_checkpoint(_forward, *tuple(runtime_inputs), use_reentrant=False) + + def _forward_replay_batch_train( + self, + layer: torch.nn.Module, + *, + replay_batch: _ParoQuantReplayBatch, + cache_kwargs: bool, + ) -> torch.Tensor: + """Optionally checkpoint streamed grouped train forwards to reduce activation residency.""" + if not self._gradient_checkpointing_enabled() or not replay_batch.inputs: + return self._forward_replay_batch( + layer, + replay_batch=replay_batch, + cache_kwargs=cache_kwargs, + ) + + def _forward(*runtime_inputs: torch.Tensor) -> torch.Tensor: + return self._forward_replay_batch( + layer, + replay_batch=_ParoQuantReplayBatch( + inputs=list(runtime_inputs), + input_kwargs=replay_batch.input_kwargs, + target=replay_batch.target, + position_ids=replay_batch.position_ids, + attention_mask=replay_batch.attention_mask, + row_count=replay_batch.row_count, + ), + cache_kwargs=cache_kwargs, + ) + + return torch_checkpoint(_forward, *tuple(replay_batch.inputs), use_reentrant=False) + + @staticmethod + def _normalize_group_optimizer_param_groups( + param_groups: List[dict[str, object]], + ) -> List[dict[str, object]]: + """Merge equivalent optimizer groups so grouped optimization pays less optimizer overhead.""" + normalized_groups: List[dict[str, object]] = [] + group_index_by_key: Dict[tuple[float, float, tuple[float, float], float, bool, float, float, bool], int] = {} + seen_param_ids: set[int] = set() + + for param_group in param_groups: + raw_params = param_group.get("params", []) + if isinstance(raw_params, nn.Parameter): + raw_params = [raw_params] + + params: list[nn.Parameter] = [] + for param in raw_params: + if not isinstance(param, nn.Parameter): + continue + param_id = id(param) + if param_id in seen_param_ids: + continue + seen_param_ids.add(param_id) + params.append(param) + + if not params: + continue + + lr = float(param_group["lr"]) + weight_decay = float(param_group.get("weight_decay", 0.01)) + betas_obj = tuple(float(beta) for beta in param_group.get("betas", (0.9, 0.95))) + betas = (betas_obj[0], betas_obj[1]) + eps = float(param_group.get("eps", 1e-10)) + amsgrad = bool(param_group.get("amsgrad", False)) + momentum = float(param_group.get("momentum", 0.0)) + dampening = float(param_group.get("dampening", 0.0)) + nesterov = bool(param_group.get("nesterov", False)) + key = (lr, weight_decay, betas, eps, amsgrad, momentum, dampening, nesterov) + + bucket_index = group_index_by_key.get(key) + if bucket_index is None: + group_index_by_key[key] = len(normalized_groups) + normalized_groups.append( + { + "params": list(params), + "lr": lr, + "weight_decay": weight_decay, + "betas": betas, + "eps": eps, + "amsgrad": amsgrad, + "momentum": momentum, + "dampening": dampening, + "nesterov": nesterov, + } + ) + else: + normalized_groups[bucket_index]["params"].extend(params) + + return normalized_groups + + @staticmethod + def _build_group_optimizer( + normalized_groups: List[dict[str, object]], + *, + device: torch.device, + optimizer_name: str = "adamw", + ) -> torch.optim.Optimizer: + """Construct the grouped stage optimizer after redundant groups are merged away.""" + return build_paroquant_optimizer( + normalized_groups, + device=device, + optimizer_name=optimizer_name, + graph_capture=False, + ) + + @staticmethod + def _build_group_adamw( + normalized_groups: List[dict[str, object]], + *, + device: torch.device, + ) -> torch.optim.Optimizer: + """Backward-compatible wrapper for tests that still exercise the AdamW path directly.""" + return ParoQuantProcessor._build_group_optimizer( + normalized_groups, + device=device, + optimizer_name="adamw", + ) + + @staticmethod + def _group_state_key_matches_prefixes(state_key: str, active_prefixes: tuple[str, ...]) -> bool: + """Match state keys that belong to the active grouped modules only.""" + return any(state_key == prefix or state_key.startswith(f"{prefix}.") for prefix in active_prefixes) + + def _snapshot_group_best_state( + self, + layer: torch.nn.Module, + *, + active_prefixes: tuple[str, ...], + target_device: Optional[torch.device] = None, + target_dtype: Optional[torch.dtype] = None, + ) -> dict[str, torch.Tensor]: + """Capture only the mutable grouped module state instead of the whole layer clone.""" + return { + key: ( + tensor.detach().clone() + if target_device is None and (target_dtype is None or not tensor.is_floating_point() or tensor.dtype == target_dtype) + else tensor.detach().to( + device=target_device if target_device is not None else tensor.device, + dtype=target_dtype if target_dtype is not None and tensor.is_floating_point() else tensor.dtype, + ).clone() + ) + for key, tensor in layer.state_dict().items() + if self._group_state_key_matches_prefixes(key, active_prefixes) + } + + @staticmethod + def _restore_group_best_state( + layer: torch.nn.Module, + *, + best_state: dict[str, torch.Tensor], + ) -> None: + """Restore the captured grouped module state in-place.""" + if not best_state: + return + live_state = layer.state_dict(keep_vars=True) + with torch.no_grad(): + for key, tensor in best_state.items(): + live_state[key].copy_(tensor) + + def _run_group_stage( + self, + layer: torch.nn.Module, + *, + optim_modules: dict[str, _ParoQuantOptimLinear], + input_batches_train: list[List[torch.Tensor]], + input_kwargs_train: list[Dict[str, Any]], + target_batches_train: list[Any], + position_ids_train: list[Optional[torch.Tensor]], + attention_masks_train: list[Optional[torch.Tensor]], + input_batches_val: list[List[torch.Tensor]], + input_kwargs_val: list[Dict[str, Any]], + target_batches_val: list[Any], + position_ids_val: list[Optional[torch.Tensor]], + attention_masks_val: list[Optional[torch.Tensor]], + param_groups: List[dict[str, object]], + epochs: int, + ) -> tuple[float, float]: + """Run one grouped optimization stage against preserved full-layer outputs.""" + _normalize_opt_impl(self.qcfg.opt_stage_impl, field="stage_impl") + optimizer_name = _normalize_opt_optimizer(getattr(self.qcfg, "opt_optimizer", "adamw")) + normalized_groups = self._normalize_group_optimizer_param_groups(param_groups) + + opt_device = next(layer.parameters()).device + use_amp = opt_device.type == "cuda" + with _activate_stage_params(layer, normalized_groups): + if epochs <= 0 or not normalized_groups: + train_loss = self._evaluate_group_layer( + layer, + input_batches=input_batches_train, + input_kwargs_batches=input_kwargs_train, + target_batches=target_batches_train, + position_ids=position_ids_train, + attention_masks=attention_masks_train, + use_amp=use_amp, + ) + val_loss = self._evaluate_group_layer( + layer, + input_batches=input_batches_val, + input_kwargs_batches=input_kwargs_val, + target_batches=target_batches_val, + position_ids=position_ids_val, + attention_masks=attention_masks_val, + use_amp=use_amp, + ) + return train_loss, val_loss + + optimizer = self._build_group_optimizer( + normalized_groups, + device=opt_device, + optimizer_name=optimizer_name, + ) + total_steps = max(1, epochs * max(1, len(input_batches_train))) + base_lrs = [float(group["lr"]) for group in optimizer.param_groups] + scaler = torch.amp.GradScaler(enabled=use_amp) + active_prefixes = tuple(optim_modules.keys()) + needs_angle_reset = any(optim_module.theta.requires_grad for optim_module in optim_modules.values()) + best_state_dtype = _resolve_best_state_snapshot_dtype( + best_state_dtype=getattr(self.qcfg, "opt_best_state_dtype", "fp32"), + device=opt_device, + ) + best_state: Optional[dict[str, torch.Tensor]] = None + best_val_loss = float("inf") + last_train_loss = 0.0 + global_step = 0 + + for _epoch in range(epochs): + epoch_loss = 0.0 + batch_count = 0 + optimizer.zero_grad(set_to_none=True) + + for batch_index, (input_batch, input_kwargs, target_batch, pos_ids, attn_mask) in enumerate( + zip( + input_batches_train, + input_kwargs_train, + target_batches_train, + position_ids_train, + attention_masks_train, + ) + ): + autocast_ctx = torch.amp.autocast("cuda") if use_amp else nullcontext() + with autocast_ctx: + preds = self._forward_group_batch_train( + layer, + batch_index=batch_index, + input_batch=input_batch, + input_kwargs=input_kwargs, + attention_mask=attn_mask, + position_ids=pos_ids, + ) + target = self._prepare_group_target(target_batch, device=preds.device, dtype=preds.dtype) + loss = F.smooth_l1_loss(preds, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + global_step += 1 + cosine_ratio = 0.5 * (1.0 + math.cos(math.pi * min(global_step, total_steps) / total_steps)) + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = (base_lr / 20.0) + ((base_lr - (base_lr / 20.0)) * cosine_ratio) + + self._reset_group_angles(optim_modules) + epoch_loss += float(loss.item()) + batch_count += 1 + + last_train_loss = epoch_loss / max(1, batch_count) + val_loss = self._evaluate_group_layer( + layer, + input_batches=input_batches_val, + input_kwargs_batches=input_kwargs_val, + target_batches=target_batches_val, + position_ids=position_ids_val, + attention_masks=attention_masks_val, + use_amp=use_amp, + ) + if best_state is None or val_loss < best_val_loss: + best_val_loss = val_loss + best_state = self._snapshot_group_best_state( + layer, + active_prefixes=active_prefixes, + target_dtype=best_state_dtype, + ) + + if best_state is not None: + self._restore_group_best_state(layer, best_state=best_state) + self._reset_group_angles(optim_modules) + return last_train_loss, best_val_loss + + def _evaluate_group_layer_streamed( + self, + layer: torch.nn.Module, + *, + replay_batches: list[_ParoQuantReplayBatch], + use_amp: bool, + target_device: torch.device, + metadata_cache: Optional[dict[tuple[int, str], torch.Tensor]] = None, + ) -> float: + """Measure full-layer reconstruction error while streaming one validation batch at a time.""" + if not replay_batches: + return 0.0 + + total_loss = 0.0 + loader = _LayerShardLoader( + replay_batches, + target_device=target_device, + shard_batches=1, + metadata_cache=metadata_cache, + ) + autocast_ctx = torch.amp.autocast("cuda") if use_amp else nullcontext() + with torch.inference_mode(), autocast_ctx: + for shard in loader.iter_shards(): + replay_batch = shard[0] + preds = self._forward_replay_batch( + layer, + replay_batch=replay_batch, + cache_kwargs=False, + ) + target = self._prepare_group_target( + replay_batch.target, + device=preds.device, + dtype=preds.dtype, + cache=False, + ) + total_loss += float(F.smooth_l1_loss(preds, target).item()) + return total_loss / max(1, len(replay_batches)) + + def _run_group_stage_streamed( + self, + layer: torch.nn.Module, + *, + optim_modules: dict[str, _ParoQuantOptimLinear], + replay_batches_train: list[_ParoQuantReplayBatch], + replay_batches_val: list[_ParoQuantReplayBatch], + param_groups: List[dict[str, object]], + epochs: int, + metadata_cache: Optional[dict[tuple[int, str], torch.Tensor]] = None, + ) -> tuple[float, float]: + """Run one grouped layer stage while streaming train shards and validation batches from CPU.""" + _normalize_opt_impl(self.qcfg.opt_stage_impl, field="stage_impl") + optimizer_name = _normalize_opt_optimizer(getattr(self.qcfg, "opt_optimizer", "adamw")) + normalized_groups = self._normalize_group_optimizer_param_groups(param_groups) + + opt_device = next(layer.parameters()).device + use_amp = opt_device.type == "cuda" + with _activate_stage_params(layer, normalized_groups): + if epochs <= 0 or not normalized_groups: + train_loss = self._evaluate_group_layer_streamed( + layer, + replay_batches=replay_batches_train, + use_amp=use_amp, + target_device=opt_device, + metadata_cache=metadata_cache, + ) + val_loss = self._evaluate_group_layer_streamed( + layer, + replay_batches=replay_batches_val, + use_amp=use_amp, + target_device=opt_device, + metadata_cache=metadata_cache, + ) + return train_loss, val_loss + + optimizer = self._build_group_optimizer( + normalized_groups, + device=opt_device, + optimizer_name=optimizer_name, + ) + total_steps = max(1, epochs * max(1, len(replay_batches_train))) + base_lrs = [float(group["lr"]) for group in optimizer.param_groups] + scaler = torch.amp.GradScaler(enabled=use_amp) + active_prefixes = tuple(optim_modules.keys()) + needs_angle_reset = any(optim_module.theta.requires_grad for optim_module in optim_modules.values()) + best_state_dtype = _resolve_best_state_snapshot_dtype( + best_state_dtype=getattr(self.qcfg, "opt_best_state_dtype", "fp32"), + device=CPU, + ) + best_state: Optional[dict[str, torch.Tensor]] = None + best_val_loss = float("inf") + last_train_loss = 0.0 + global_step = 0 + shard_batches = self._layer_train_shard_batches( + layer, + param_groups=normalized_groups, + replay_batches=replay_batches_train, + ) + + for _epoch in range(epochs): + epoch_loss = 0.0 + batch_count = 0 + optimizer.zero_grad(set_to_none=True) + train_loader = _LayerShardLoader( + replay_batches_train, + target_device=opt_device, + shard_batches=shard_batches, + metadata_cache=metadata_cache, + ) + + autocast_ctx = torch.amp.autocast("cuda") if use_amp else nullcontext() + with autocast_ctx: + for shard in train_loader.iter_shards(): + for replay_batch in shard: + preds = self._forward_replay_batch_train( + layer, + replay_batch=replay_batch, + cache_kwargs=False, + ) + target = self._prepare_group_target( + replay_batch.target, + device=preds.device, + dtype=preds.dtype, + cache=False, + ) + loss = F.smooth_l1_loss(preds, target) + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + global_step += 1 + cosine_ratio = 0.5 * (1.0 + math.cos(math.pi * min(global_step, total_steps) / total_steps)) + for group, base_lr in zip(optimizer.param_groups, base_lrs): + group["lr"] = (base_lr / 20.0) + ((base_lr - (base_lr / 20.0)) * cosine_ratio) + + if needs_angle_reset: + self._reset_group_angles(optim_modules) + epoch_loss += float(loss.item()) + batch_count += 1 + + last_train_loss = epoch_loss / max(1, batch_count) + val_loss = self._evaluate_group_layer_streamed( + layer, + replay_batches=replay_batches_val, + use_amp=use_amp, + target_device=opt_device, + metadata_cache=metadata_cache, + ) + if best_state is None or val_loss < best_val_loss: + best_val_loss = val_loss + best_state = self._snapshot_group_best_state( + layer, + active_prefixes=active_prefixes, + target_device=CPU, + target_dtype=best_state_dtype, + ) + + if best_state is not None: + self._restore_group_best_state(layer, best_state=best_state) + if needs_angle_reset: + self._reset_group_angles(optim_modules) + return last_train_loss, best_val_loss + + def _optimize_live_layer( + self, + state: _ParoQuantLayerState, + group_modules: list[NamedModule], + ) -> tuple[dict[str, object], float]: + """Optimize the live layer in place for full-layer ParoQuant scope.""" + if not group_modules: + raise ValueError("ParoQuantProcessor grouped optimization requires at least one module.") + if state.layer_module is None: + raise RuntimeError("ParoQuantProcessor layer-scope optimization requires the live layer module.") + + layer = state.layer_module + target_device = group_modules[0].weight.device + original_layer_dtype = group_modules[0].weight.dtype + layer = layer.to(device=target_device, dtype=torch.float32) + self._strip_hooked_linear_wrappers(layer) + self._sync_named_modules_to_live_layer(layer, group_modules) + attn_impl_overrides = self._force_layer_eager_attention(layer) + replay_batches_train, replay_batches_val = self._replay_batches_from_state(state) + metadata_cache: dict[tuple[int, str], torch.Tensor] = {} + optim_modules: dict[str, _ParoQuantOptimLinear] = {} + original_modules: dict[str, torch.nn.Module] = {} + optimizer_group_kwargs = self._optimizer_param_group_kwargs() + try: + for param in layer.parameters(): + param.requires_grad_(False) + + for named_module in group_modules: + original_modules[named_module.name] = recurse_getattr(layer, named_module.name) + optim_module = self._build_group_optim_linear(named_module) + recurse_setattr(layer, named_module.name, optim_module) + optim_modules[named_module.name] = optim_module + + self._materialize_live_layer_autograd_tensors(layer) + + self._run_group_stage_streamed( + layer, + optim_modules=optim_modules, + replay_batches_train=replay_batches_train, + replay_batches_val=replay_batches_val, + param_groups=[ + {"params": [optim_module.channel_scales_opt], "lr": self.qcfg.opt_rotation_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + ] + [ + {"params": [optim_module.theta], "lr": self.qcfg.opt_rotation_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + ], + epochs=int(self.qcfg.opt_rotation_epochs), + metadata_cache=metadata_cache, + ) + + for optim_module in optim_modules.values(): + optim_module.init_quantizer() + + train_loss, val_loss = self._run_group_stage_streamed( + layer, + optim_modules=optim_modules, + replay_batches_train=replay_batches_train, + replay_batches_val=replay_batches_val, + param_groups=[ + {"params": [optim_module.weight], "lr": self.qcfg.opt_weight_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + ] + [ + {"params": optim_module.quantizer.optim_params(), "lr": self.qcfg.opt_quantizer_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + if optim_module.quantizer is not None + ], + epochs=int(self.qcfg.opt_finetune_epochs), + metadata_cache=metadata_cache, + ) + + metadata_cache.clear() + + results: dict[str, object] = {} + for named_module in group_modules: + optim_module = optim_modules[named_module.name] + results[named_module.name] = _result_from_model( + optim_module, + train_loss=train_loss, + val_loss=val_loss, + used_identity=False, + ) + return results, val_loss + finally: + for named_module in reversed(group_modules): + original_module = original_modules.get(named_module.name) + if original_module is not None: + recurse_setattr(layer, named_module.name, original_module) + # Live-layer training happens in fp32, but replay and inference should + # return to the layer's native dtype so flash-attn and downstream + # kernels see the original half/bfloat activations again. + layer.to(device=target_device, dtype=original_layer_dtype) + self._restore_layer_attention_impl(attn_impl_overrides) + + @staticmethod + def _supports_live_layer_scope(group_modules: list[NamedModule]) -> bool: + """Restrict in-place layer optimization to dense decoder blocks. + + The official apples-to-apples path is one full dense decoder layer at a + time. Expert-heavy layers can contain hundreds or thousands of modules, + so keep those on the cloned grouped path until a dedicated MoE/shared + rotation implementation lands. + """ + if not group_modules: + return False + expert_markers = ( + "expert", + "experts", + "shared_expert", + "gate_up_proj", + ) + expert_prefixes = ("experts.", "mlp.experts.", "moe.") + expert_like_modules = 0 + dense_modules = 0 + for module in group_modules: + module_name = getattr(module, "name", "") + leaf = module_name.rsplit(".", 1)[-1] + if any(marker in module_name for marker in expert_markers) or module_name.startswith(expert_prefixes): + expert_like_modules += 1 + continue + dense_modules += 1 + if leaf not in {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}: + return False + return dense_modules > 0 and expert_like_modules == 0 + + def _optimize_group( + self, + state: _ParoQuantLayerState, + group_modules: list[NamedModule], + ) -> tuple[dict[str, object], float]: + """Optimize one compute_block or whole-layer group against the preserved full-layer target.""" + if self._opt_scope_mode() == "layer" and self._supports_live_layer_scope(group_modules): + with torch.inference_mode(False), torch.enable_grad(): + return self._optimize_live_layer(state, group_modules) + + with torch.inference_mode(False), torch.enable_grad(): + layer_clone, optim_modules = self._build_group_optim_layer(state, group_modules) + for param in layer_clone.parameters(): + param.requires_grad_(False) + + ( + input_batches_train, + input_kwargs_train, + target_batches_train, + position_ids_train, + attention_masks_train, + input_batches_val, + input_kwargs_val, + target_batches_val, + position_ids_val, + attention_masks_val, + ) = self._group_dataset_for_device(state, next(layer_clone.parameters()).device) + optimizer_group_kwargs = self._optimizer_param_group_kwargs() + + self._run_group_stage( + layer_clone, + optim_modules=optim_modules, + input_batches_train=input_batches_train, + input_kwargs_train=input_kwargs_train, + target_batches_train=target_batches_train, + position_ids_train=position_ids_train, + attention_masks_train=attention_masks_train, + input_batches_val=input_batches_val, + input_kwargs_val=input_kwargs_val, + target_batches_val=target_batches_val, + position_ids_val=position_ids_val, + attention_masks_val=attention_masks_val, + param_groups=[ + {"params": [optim_module.channel_scales_opt], "lr": self.qcfg.opt_rotation_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + ] + [ + {"params": [optim_module.theta], "lr": self.qcfg.opt_rotation_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + ], + epochs=int(self.qcfg.opt_rotation_epochs), + ) + + for optim_module in optim_modules.values(): + optim_module.init_quantizer() + + train_loss, val_loss = self._run_group_stage( + layer_clone, + optim_modules=optim_modules, + input_batches_train=input_batches_train, + input_kwargs_train=input_kwargs_train, + target_batches_train=target_batches_train, + position_ids_train=position_ids_train, + attention_masks_train=attention_masks_train, + input_batches_val=input_batches_val, + input_kwargs_val=input_kwargs_val, + target_batches_val=target_batches_val, + position_ids_val=position_ids_val, + attention_masks_val=attention_masks_val, + param_groups=[ + {"params": [optim_module.weight], "lr": self.qcfg.opt_weight_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + ] + [ + {"params": optim_module.quantizer.optim_params(), "lr": self.qcfg.opt_quantizer_lr, **optimizer_group_kwargs} + for optim_module in optim_modules.values() + if optim_module.quantizer is not None + ], + epochs=int(self.qcfg.opt_finetune_epochs), + ) + + results: dict[str, object] = {} + for named_module in group_modules: + optim_module = optim_modules[named_module.name] + results[named_module.name] = _result_from_model( + optim_module, + train_loss=train_loss, + val_loss=val_loss, + used_identity=False, + ) + return results, val_loss + + @staticmethod + def _module_compute_block_label(module_name: str) -> str: + """Map common projection archetypes to compute_block optimization buckets.""" + leaf = module_name.rsplit(".", 1)[-1] + if leaf in {"q_proj", "k_proj", "v_proj"}: + return "attn_qkv" + if leaf == "o_proj": + return "attn_o" + if leaf in {"gate_proj", "up_proj"}: + return "mlp_gate_up" + if leaf == "down_proj": + return "mlp_down" + return f"single:{module_name}" + + @staticmethod + def _module_compute_block_order(module_name: str) -> tuple[int, str]: + """Keep compute_block members in canonical architectural order.""" + leaf = module_name.rsplit(".", 1)[-1] + order = { + "q_proj": 0, + "k_proj": 1, + "v_proj": 2, + "o_proj": 3, + "gate_proj": 4, + "up_proj": 5, + "down_proj": 6, + } + return (order.get(leaf, 100), module_name) + + def _optimization_groups_for_layer( + self, + state: _ParoQuantLayerState, + ) -> list[tuple[str, list[NamedModule]]]: + """Resolve the optimization scope for the current layer. + + `module` keeps today's per-linear behavior. `compute_block` and `layer` + are scaffolded here so the lifecycle can switch scopes explicitly once + their execution paths land. + """ + mode = self._opt_scope_mode() + named_modules = [state.modules[name] for name in sorted(state.modules)] + if mode == "module": + return [(module.name, [module]) for module in named_modules] + if mode == "compute_block": + grouped: Dict[str, list[NamedModule]] = {} + for module in named_modules: + grouped.setdefault(self._module_compute_block_label(module.name), []).append(module) + for label in grouped: + grouped[label].sort(key=lambda module: self._module_compute_block_order(module.name)) + return [(label, grouped[label]) for label in sorted(grouped)] + if mode == "layer": + return [("layer", named_modules)] + raise ValueError(f"ParoQuantProcessor: unsupported optimize scope `{self.qcfg.opt_scope}`.") + + def _quantize_layer(self, layer_index: int, state: _ParoQuantLayerState) -> None: + """Quantize every captured module in a layer once all subsets are ready.""" + if state.quantized: + return + + optimization_groups = self._optimization_groups_for_layer(state) + mode = self._opt_scope_mode() + + input_feat = self._layer_input_features(state) + for module_name, tensor in input_feat.items(): + if tensor.numel() == 0 and not self.fallback: + raise RuntimeError( + f"ParoQuantProcessor error: missing activation features for `{module_name}` with fallback disabled." + ) + + if mode == "module": + for module_name, named_module in list(state.modules.items()): + feat = input_feat.get(module_name) + if feat is None: + feat = torch.empty(0) + + start = time.perf_counter() + _train_loss, val_loss = self._quantize_one_module(named_module, feat) + duration = time.perf_counter() - start + self._log_quant_result(named_module, feat, val_loss, duration) + else: + if state.layer_inputs is None or state.layer_outputs is None: + raise RuntimeError( + "ParoQuantProcessor grouped optimization requires preserved layer inputs and outputs. " + f"Resolved groups for layer {layer_index}: {[label for label, _modules in optimization_groups]}" + ) + + for _group_label, group_modules in optimization_groups: + start = time.perf_counter() + group_results, group_val_loss = self._optimize_group(state, group_modules) + duration = time.perf_counter() - start + duration_per_module = duration / max(1, len(group_modules)) + + for named_module in group_modules: + original_weight = self._module_weight_matrix(named_module).detach().clone() + result = group_results[named_module.name] + self._apply_optimization_result(named_module, result, original_weight) + if mode == "layer": + move_to(named_module.module, device=CPU) + feat = input_feat.get(named_module.name) + if feat is None: + feat = torch.empty(0) + self._log_quant_result(named_module, feat, group_val_loss, duration_per_module) + + if mode == "compute_block" and getattr(self.qcfg, "offload_to_disk", False): + flush_device = self._module_weight_matrix(group_modules[0]).device if group_modules else None + torch_empty_cache(device=flush_device, gc=False, sync=True) + + state.quantized = True + with self.lock: + for module_name in list(state.modules): + entry = self.tasks.get(module_name) + if entry is not None and entry.get("layer_index") == layer_index: + entry["inputs"] = [] + state.modules.clear() + state.pending_modules.clear() + state.processed_subsets.clear() + state.layer_inputs = None + state.layer_input_kwargs = None + state.layer_outputs = None + state.pristine_layer_module = None + state.prepared_group_source_module = None + state.prepared_group_source_module_by_device = None + state.grouped_dataset = None + state.grouped_dataset_by_device = None + state.replay_batches = None + state.subset_total = None + if hasattr(self, "_group_forward_kwargs_cache"): + self._group_forward_kwargs_cache.clear() + if hasattr(self, "_group_forward_prepared_cache"): + self._group_forward_prepared_cache.clear() + if hasattr(self, "_group_target_cache"): + self._group_target_cache.clear() + if hasattr(self, "_group_position_ids_cache"): + self._group_position_ids_cache.clear() + if hasattr(self, "_group_rotary_position_embeddings_cache"): + self._group_rotary_position_embeddings_cache.clear() + + def preprocess(self, module: NamedModule, fallback=None, **kwargs): + """Register a module for later activation capture and deferred quantization.""" + if self.qcfg.dynamic_get(layer_name=module.full_name) is False: + return + + self.fallback = normalize_fallback(fallback, self.qcfg.fallback) + layer_state = self._get_layer_state(module.layer_index) + with layer_state.lock: + layer_state.modules[module.name] = module + layer_state.layer_module = module.state.get("layer_module", layer_state.layer_module) + layer_state.pending_modules.add(module.name) + + self._ensure_task_bucket(module.name, module.layer_index) + + def is_skipped(self, module: NamedModule) -> bool: + """Report whether a module has been excluded from ParoQuant processing.""" + return self.tasks.get(module.name, False) is False + + def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Capture input activations during the calibration forward pass.""" + def hook(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): + del module, out + if not inp: + return + feature = inp[0] if isinstance(inp, (tuple, list)) else inp + self._record_input_feature(name, feature) + + return hook + + def process( + self, + module: NamedModule, + device: torch.device = None, + subset: Optional[Dict[str, NamedModule]] = None, + previous_subset: Optional[Dict[str, NamedModule]] = None, + subset_index: Optional[int] = None, + subset_total: Optional[int] = None, + ): + """Mark subset progress and quantize once the whole layer is capture-complete.""" + del device, previous_subset + layer_index = module.layer_index + state = self._get_layer_state(layer_index) + + with state.lock: + if subset is not None: + state.modules.update(subset) + if subset_total is not None: + state.subset_total = subset_total + if subset_index is not None: + state.processed_subsets.add(subset_index) + + state.pending_modules.discard(module.name) + + should_quantize = ( + not state.quantized + and bool(state.modules) + and not state.pending_modules.intersection(state.modules.keys()) + and (state.subset_total is None or len(state.processed_subsets) >= state.subset_total) + ) + if should_quantize: + self._quantize_layer(layer_index, state) + + def receive_layer_forward_context( + self, + *, + layer_index: int, + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + layer_outputs: List[List[torch.Tensor]], + subset_index: Optional[int] = None, + subset_total: Optional[int] = None, + ) -> None: + """Preserve noisy grouped inputs plus clean float layer targets.""" + del subset_index + state = self._get_layer_state(layer_index) + with state.lock: + if state.layer_inputs is None: + state.layer_inputs = layer_inputs + if state.layer_input_kwargs is None: + state.layer_input_kwargs = layer_input_kwargs + if state.layer_outputs is None: + state.layer_outputs = layer_outputs + if subset_total is not None and state.subset_total is None: + state.subset_total = subset_total + + def receive_pristine_layer_module( + self, + *, + layer_index: int, + layer_module: torch.nn.Module, + ) -> None: + """Preserve an untouched float layer snapshot for grouped optimization clones.""" + if self._opt_scope_mode() == "layer": + return + state = self._get_layer_state(layer_index) + with state.lock: + if state.pristine_layer_module is None: + state.pristine_layer_module = copy.deepcopy(layer_module).to(device=CPU) + + def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Pack one optimized float module into its ParoQuant runtime form.""" + self.pack_module(module, model=model) + + def pack_module(self, module: NamedModule, model: BaseQModel): + """Replace a float module with a packed ParoQuant quantized module.""" + module.stream_sync() + with self.lock: + module.state.pop("w_wq_diff", None) + pack_weight = module.state.pop("pack_weight").clone() + q_zeros = module.state.pop("q_zeros").clone() + q_scales = module.state.pop("q_scales").clone() + pairs = module.state.pop("pairs").clone() + theta = module.state.pop("theta").clone() + channel_scales = module.state.pop("channel_scales").clone() + + module.weight.data = move_to(pack_weight, device=CPU) + quant_linear_cls = self._resolve_qlinear_kernel(module.full_name) + layers = find_modules(model.model) + module_label = getattr(module, "full_name", getattr(module, "name", "")) + + with log_time_block( + "create_quant_module", + logger=log, + module_name=module_label, + ): + with parent_module_lock(module.full_name): + create_quant_module( + name=module.full_name, + linear_cls=quant_linear_cls, + bits=self.qcfg.runtime_bits, + desc_act=self.qcfg.desc_act, + dynamic=self.qcfg.dynamic, + group_size=self.qcfg.group_size, + module=model.model, + submodule=module, + sym=self.qcfg.sym, + device=self.qcfg.device, + lm_head_name=model.lm_head, + pack_dtype=self.qcfg.pack_dtype, + format=self.format, + register_buffers=False, + init_kwargs=self.qcfg.quant_linear_init_kwargs(), + ) + + qmodules = { + name: submodule + for name, submodule in find_modules(model.model, [quant_linear_cls]).items() + if name == module.full_name + } + with log_time_block( + "pack", + logger=log, + module_name=module_label, + ): + with parent_module_lock(module.full_name): + pack_module( + name=module.full_name, + qModules=qmodules, + q_scales=q_scales, + q_zeros=q_zeros, + q_g_idx=None, + layers=layers, + quant_linear_cls=quant_linear_cls, + lock=self.lock, + quantize_config=self.qcfg, + ) + + qmodule = qmodules[module.full_name] + if not isinstance(qmodule, ParoLinear): + raise TypeError( + f"Expected `{module.full_name}` to be packed as ParoLinear, got `{type(qmodule).__name__}`." + ) + + qmodule.pairs.copy_(pairs.to(device=qmodule.pairs.device, dtype=qmodule.pairs.dtype)) + qmodule.theta.copy_(theta.to(device=qmodule.theta.device, dtype=qmodule.theta.dtype)) + qmodule.channel_scales.copy_( + channel_scales.to(device=qmodule.channel_scales.device, dtype=qmodule.channel_scales.dtype) + ) + qmodule.post_init() + + def finalize(self, model: BaseQModel, **kwargs): + """Mark the model as ParoQuant-quantized before shared finalization work.""" + model.quantized = True + model.quantize_config.method = METHOD.PARO + super().finalize(model=model, **kwargs) + + def verify_calibration_dataset(self, processor_index: int) -> bool: + """Require calibration data because ParoQuant always needs activation replay.""" + del processor_index + if self.calibration_dataset is None: + raise ValueError("ParoQuantProcessor's calibration_dataset must be provided.") + return True + + @classmethod + def name(cls) -> str: + """Return the processor registry name.""" + return "paroquant" + + def has_captured_input_ids(self, name: str) -> bool: + """Report whether non-empty activation batches were captured for a module.""" + entry = self.tasks.get(name) or {} + tensors: List[torch.Tensor] = entry.get("inputs", []) + return tensors is not None and len(tensors) > 0 and all(t.numel() > 0 for t in tensors) diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py index 1a9e78735..480792aa9 100644 --- a/gptqmodel/looper/qqq_processor.py +++ b/gptqmodel/looper/qqq_processor.py @@ -9,14 +9,14 @@ import torch from torch.nn import Module -from ..looper.loop_processor import DTYPE_SIZE_COLUMN, MODULE_FEATURE_COLUMN, LoopProcessor +from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME, PROCESS_LOG_TIME, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES) -from ..nn_modules.qlinear.qqq import QQQQuantLinear -from ..quantization.config import METHOD, QuantizeConfig -from ..utils.failsafe import normalize_failsafe +from ..nn_modules.qlinear.qqq import QQQLinear +from ..quantization.config import METHOD, QuantizeConfig, resolve_quant_format +from ..utils.fallback import normalize_fallback from ..quantization.qqq import QQQ from ..utils.logger import setup_logger, log_time_block from ..utils.model import create_quant_module, find_modules, move_to, pack_module @@ -25,6 +25,8 @@ log = setup_logger() class QQQProcessor(LoopProcessor): + """Captures activations and quantizes modules with the QQQ workflow.""" + def __init__( self, tokenizer, @@ -38,6 +40,7 @@ def __init__( calculate_w_wq_diff: bool = False, calibration_concat_separator: Optional[str] = None, ): + """Initializes QQQ processing and optional weight-delta tracking.""" super().__init__( tokenizer=tokenizer, @@ -48,16 +51,20 @@ def __init__( calibration_concat_separator=calibration_concat_separator, prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd, + execution_config=ExecutionConfig(require_fwd=require_fwd), ) self.calculate_w_wq_diff = calculate_w_wq_diff self.avg_losses = [] def set_calibration_dataset(self, calibration_dataset): + """Rejects dataset replacement because QQQ capture is fixed at construction.""" + raise NotImplementedError("QQQProcessor's calibration_dataset cannot be modified") - def preprocess(self, module: NamedModule, failsafe=None, **kwargs): + def preprocess(self, module: NamedModule, fallback=None, **kwargs): + """Builds the per-module QQQ task after applying dynamic overrides.""" + # entire module is skipped if self.qcfg.dynamic_get(layer_name=module.full_name) == False: return @@ -84,7 +91,7 @@ def preprocess(self, module: NamedModule, failsafe=None, **kwargs): tmp = QQQ(module=module, qcfg=qcfg_clone) - tmp.failsafe = normalize_failsafe(failsafe, qcfg_clone.failsafe) + tmp.fallback = normalize_fallback(fallback, qcfg_clone.fallback) tmp.expected_nsamples = getattr(self, "total_calibration_tokens", None) if self.qcfg.mse > 0.0: @@ -105,6 +112,8 @@ def preprocess(self, module: NamedModule, failsafe=None, **kwargs): self.tasks[module.name] = tmp def is_skipped(self, module: NamedModule) -> bool: + """Reports whether preprocessing omitted this module from QQQ work.""" + # gptq has no dynamic method of full override (removal) t = self.tasks.get(module.name, False) if t == False: @@ -113,7 +122,11 @@ def is_skipped(self, module: NamedModule) -> bool: return False def pre_process_fwd_hook(self, name: str) -> Callable[[Module, Tuple[torch.Tensor, ...], torch.Tensor], None]: + """Returns the forward hook that feeds captured batches into the QQQ task.""" + def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor): + """Records one activation batch for QQQ statistics accumulation.""" + # gptq is mutable. q = self.tasks[name] # noqa: F821 q.add_batch(inp[0].data, out.data) # noqa: F821 @@ -128,8 +141,10 @@ def process( subset_index: Optional[int] = None, subset_total: Optional[int] = None, ): + """Runs QQQ quantization for one module and stores pack-ready tensors.""" + base_title = f"Quantizing {module.name} in layer" - self._pause_controller.register_and_draw_progress_bar(self.pb, title=base_title, subtitle="") + self.draw_progress(base_title) qqq = self.tasks # logger.info(f"Quantizing module START: {name}, {gptq[name].shape()}") @@ -219,6 +234,8 @@ def process( # submodule_finalized is called in reverse after all next sequential processes are called def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): + """Creates the quantized module and packs the saved QQQ tensors into it.""" + # generate complete, safe to move to cpu module.weight.data = move_to(module.weight.data, device=CPU) # large weights is slow to init on cpu module.state.pop("w", None) # no need for original weights now @@ -244,8 +261,8 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): ): create_quant_module( name=module.full_name, - linear_cls=QQQQuantLinear, - bits=self.qcfg.bits, + linear_cls=QQQLinear, + bits=self.qcfg.runtime_bits, desc_act=self.qcfg.desc_act, dynamic=self.qcfg.dynamic, group_size=self.qcfg.group_size, @@ -255,13 +272,14 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): device=self.qcfg.device, lm_head_name=model.lm_head, pack_dtype=self.qcfg.pack_dtype, + format=resolve_quant_format(self.qcfg.format, self.qcfg.method), register_buffers=False, ) # pack module qModules = { name: submodule - for name, submodule in find_modules(model.model, [QQQQuantLinear]).items() + for name, submodule in find_modules(model.model, [QQQLinear]).items() if name == module.full_name } with log_time_block( @@ -276,7 +294,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): q_zeros=q_zeros, q_g_idx=q_g_idx, layers=layers, - quant_linear_cls=QQQQuantLinear, + quant_linear_cls=QQQLinear, lock=self.lock, q_scales_extra=q_scales_extra, quantize_config=self.qcfg, @@ -289,14 +307,18 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): module.unregister_parameter("weight") def finalize(self, model: BaseQModel, **kwargs): + """Marks the model as QQQ-quantized and runs shared finalization logic.""" + # set quantized state model.quantized = True - model.quantize_config.quant_method = METHOD.QQQ + model.quantize_config.method = METHOD.QQQ super().finalize(model=model, **kwargs) def verify_calibration_dataset(self, processor_index: int) -> bool: + """Ensures QQQ received calibration data before the quantization loop starts.""" + if self.calibration_dataset is None: raise ValueError("GPTQProcessor's calibration_dataset must be provided.") else: @@ -304,7 +326,11 @@ def verify_calibration_dataset(self, processor_index: int) -> bool: @classmethod def name(cls) -> str: + """Returns the processor label used in logs and lifecycle reporting.""" + return "qqq" def has_captured_input_ids(self, name: str) -> bool: + """Reports whether the module saw at least one captured forward batch.""" + return self.tasks[name].fwd_counter > 0 diff --git a/gptqmodel/looper/stage_inputs_capture.py b/gptqmodel/looper/stage_inputs_capture.py index 791b8d373..f910a90bd 100644 --- a/gptqmodel/looper/stage_inputs_capture.py +++ b/gptqmodel/looper/stage_inputs_capture.py @@ -17,7 +17,7 @@ from ..nn_modules.hooked_linear import STOP_FORWARD_EXCEPTION, StopForward from ..utils.ctx import ctx from ..utils.device import get_device -from ..utils.looper_helpers import device_ctx +from ..utils.looper_helpers import device_ctx, select_forward_devices from ..utils.logger import setup_logger from ..utils.model import get_module_by_name_prefix, move_to, nested_move_to from ..utils.torch import CPU, META @@ -30,6 +30,8 @@ class StageInputsCapture: """Capture layer inputs so processors can reuse cached activations.""" def __init__(self, looper: ModuleLooper, logger=None) -> None: + """Binds the capture stage to a looper instance and logger.""" + self.looper = looper self.gptq_model = looper.gptq_model self.logger = logger or setup_logger() @@ -40,6 +42,8 @@ def cache_inputs( calibration_data: Iterable[Dict[str, torch.Tensor]], use_cache: bool, ) -> InputCache: + """Runs a short forward over calibration data and caches first-layer inputs.""" + layer_inputs: List[List[torch.Tensor]] = [] attention_masks: List[torch.Tensor | None] = [] position_ids: List[torch.Tensor] = [] @@ -74,15 +78,39 @@ def cache_inputs( # materialize / move.to CPU for initial input capture and for first layer to minimize VRAM usage, inputs will be stored on CPU # and to mimic behavior of offload_to_disk=False for offload_to_disk=True - # TODO: move back outputs to CPU after forward pass to minimize VRAM usage for other layers - # or wait till calibration_data_device feature merge, when we can specify device for calibration data (or balanced) - # (and in case calibration data device will be the same as forward pass device save some ticks) + # Use calibration_data_device to specify device for calibration data (or "balanced" for round-robin across GPUs) layers[0] = self.gptq_model.shell_module_materialize( target_submodule=layers[0], device=CPU, ) cur_layer_device = CPU - data_device = cur_layer_device + + # Use calibration_data_device if specified, otherwise use cur_layer_device + calib_device_cfg = self.gptq_model.quantize_config.calibration_data_device + + # Prepare devices for balanced mode + balanced_devices: List[torch.device] = [] + balanced_mode = False + if calib_device_cfg == "balanced": + balanced_mode = True + # Get all available devices of same type as the quantization device + all_devices = select_forward_devices(self.gptq_model.quantize_config.device) + # Apply compute_device_filter if set + compute_device_filter = self.gptq_model.quantize_config.compute_device_filter + if compute_device_filter is not None: + balanced_devices = compute_device_filter(all_devices) + if not balanced_devices: + balanced_devices = all_devices + else: + balanced_devices = all_devices + data_device = balanced_devices[0] if balanced_devices else cur_layer_device + elif calib_device_cfg is not None: + data_device = calib_device_cfg + else: + data_device = cur_layer_device + + # Round-robin counter for balanced mode + balanced_rr_counter = [0] # Use list to allow modification in nested function cache_forward_pb = None processed_rows = 0 @@ -104,26 +132,41 @@ def cache_inputs( ).draw() def store_input_hook(module, args, kwargs): - layer_input: List[torch.Tensor] = [] - if kwargs.get("hidden_states") is not None: - layer_input.append(move_to(kwargs["hidden_states"], device=data_device)) + """Captures the incoming batch for the first layer and aborts the forward.""" + + # Select device for this batch (round-robin for balanced mode) + if balanced_mode and balanced_devices: + batch_device = balanced_devices[balanced_rr_counter[0] % len(balanced_devices)] + balanced_rr_counter[0] += 1 else: - layer_input.append(move_to(args[0], device=data_device)) + batch_device = data_device + + layer_input = self.gptq_model.capture_first_layer_positional_inputs( + args=args, + kwargs=kwargs, + batch_device=batch_device, + ) layer_inputs.append(layer_input) if kwargs.get("attention_mask") is not None: - attention_masks.append(kwargs["attention_mask"].to(device=data_device)) + attention_masks.append(kwargs["attention_mask"].to(device=batch_device)) else: attention_masks.append(None) pos_ids = kwargs.get("position_ids", None) if pos_ids is not None: - position_ids.append(move_to(pos_ids, device=data_device)) + position_ids.append(move_to(pos_ids, device=batch_device)) one_kwargs: Dict[str, Any] = {} for (k, v) in kwargs.items(): if k not in ["hidden_states", "attention_mask", "position_ids"]: - one_kwargs[k] = nested_move_to(v, device=data_device) + one_kwargs[k] = nested_move_to(v, device=batch_device) + one_kwargs = self.gptq_model.capture_first_layer_input_kwargs( + args=args, + kwargs=kwargs, + batch_device=batch_device, + layer_input_kwargs=one_kwargs, + ) layer_input_kwargs.append(one_kwargs) # In normal repeating layer/sbuset early stop happens on the last module forward @@ -147,8 +190,6 @@ def store_input_hook(module, args, kwargs): handle = layers[0].register_forward_pre_hook(store_input_hook, with_kwargs=True) - is_ovis = self.gptq_model.model.config.model_type == "ovis" - self.gptq_model.pre_quantize_generate_hook_start() # TODO: why data_device sometimes set to cuda (self.gptq_model.quantize_config.device) and sometimes to CPU (cur_layer_device)? @@ -162,38 +203,17 @@ def store_input_hook(module, args, kwargs): if "pixel_values" in example.keys() else cur_layer_device ) - for k, v in example.items(): - if isinstance(v, list): - for index in range(len(v)): - if len(v[index].shape) == 1: - v[index] = v[index].unsqueeze(0) - v[index] = move_to( - v[index].to(self.gptq_model.model.visual_tokenizer.dtype) - if is_ovis - else v[index], - device=data_device, - ) - else: - if len(v.shape) == 1: - v = v.unsqueeze(0) - example[k] = move_to(v, device=data_device) + example = self.gptq_model.move_input_capture_example(example, data_device) try: - if self.gptq_model.ATTENTION_MASKS_DTYPE is torch.long: - example["attention_mask"] = example["attention_mask"].long() - with ctx( DEVICE_THREAD_POOL.read_lock(self.gptq_model.quantize_config.device), device_ctx(self.gptq_model.quantize_config.device), ): - if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS: - self.gptq_model.model.generate( - **example, - **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS, - ) - elif is_ovis: - self.gptq_model.model.generate(inputs=example.pop("input_ids"), **example) - else: - self.gptq_model.model(**example, use_cache=use_cache) + self.gptq_model.run_input_capture( + example, + use_cache=use_cache, + data_device=data_device, + ) except StopForward: pass finally: diff --git a/gptqmodel/looper/stage_layer.py b/gptqmodel/looper/stage_layer.py index 9e316c019..98fb64c49 100644 --- a/gptqmodel/looper/stage_layer.py +++ b/gptqmodel/looper/stage_layer.py @@ -3,10 +3,18 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -"""Layer execution stage extracted from ModuleLooper.""" +"""Layer-level orchestration for subset execution, replay, and finalization. + +For each processor and layer, this stage: +- builds all subset plans up front +- executes subsets using those plans +- replays forward once when the processor needs post-process outputs +- finalizes processed modules after the processor pipeline completes +""" from __future__ import annotations +import copy import logging import threading import time @@ -23,25 +31,335 @@ from ..looper.awq_processor import AWQProcessor from ..looper.gptq_processor import GPTQProcessor from ..looper.named_module import NamedModule +from ..looper.paroquant_processor import ParoQuantProcessor from ..looper.qqq_processor import QQQProcessor from ..utils.device import get_device, get_device_new -from ..utils.logger import log_time_block, setup_logger +from ..utils.looper_helpers import normalize_device_like +from ..utils.logger import live_renderables_suppressed, log_time_block, setup_logger from ..utils.model import find_modules, get_module from ..utils.offload import offload_to_disk from ..utils.torch import CPU, torch_empty_cache, torch_sync -from .stage_subset import SubsetForwardContext, run_subset_stage +from .stage_subset import SubsetPlan, build_layer_subset_plans, run_subset_stage if TYPE_CHECKING: # pragma: no cover - type hints only from .module_looper import ModuleLooper +def _find_last_quantized_layer_index( + looper: "ModuleLooper", + *, + layer_modules: List[List[str]], + layers_prefix: Optional[str], + layer_count: int, +) -> Optional[int]: + """Return the highest layer index whose tracked modules are not all dynamically skipped.""" + if looper.gptq_model.quantize_config.lm_head or not layers_prefix: + return None + + layer_module_names = { + name.split("#", 1)[0] + for module_group in layer_modules + for name in module_group + if name + } + if not layer_module_names: + return None + + last_quantized_layer_index = -1 + for candidate_layer_index in range(layer_count): + for module_name in layer_module_names: + module_full_name = f"{layers_prefix}.{candidate_layer_index}.{module_name}" + # If at least one module in this layer is not dynamically excluded, + # the layer still needs forward/quantization work. + if looper.gptq_model.quantize_config.dynamic_get(layer_name=module_full_name) != False: + last_quantized_layer_index = candidate_layer_index + break + + return last_quantized_layer_index + + +def _should_drain_finalize_futures_synchronously( + looper: "ModuleLooper", + *, + finalize_tasks, +) -> bool: + """Decide whether one layer must finish finalization before the next begins. + + ParoQuant layer/group optimization holds substantially more live CUDA state + than the weight-only paths. Letting its finalizers overlap the next layer + can visibly ratchet active VRAM upward from layer N to N+1, so ParoQuant + always drains per-layer finalizers synchronously. + """ + if looper.gptq_model.quantize_config.wait_for_submodule_finalizers: + return True + return any(isinstance(process, ParoQuantProcessor) for process, *_ in finalize_tasks) + + +def _should_empty_cache_after_sync_finalize( + looper: "ModuleLooper", + *, + finalize_tasks, +) -> bool: + """Release CUDA cache after synchronous ParoQuant finalization when offload is active. + + Disk offload correctly moves finalized modules out of the live model path, + but CUDA's allocator can still hold onto the just-freed pools across layer + boundaries. That shows up as a steady nvidia-smi climb even though the + previous layer no longer needs those weights on device. A cache release at + the synchronous boundary keeps layer-scope memory flat without changing the + quantization objective. + """ + if not getattr(looper.gptq_model.quantize_config, "offload_to_disk", False): + return False + return any(isinstance(process, ParoQuantProcessor) for process, *_ in finalize_tasks) + + +def _processor_needs_pristine_group_clone(processor) -> bool: + """Whether grouped capture needs a dedicated pristine layer clone for this processor.""" + needs_clone = getattr(processor, "needs_pristine_layer_clone", None) + if callable(needs_clone): + return bool(needs_clone()) + uses_grouped_optimization = getattr(processor, "uses_grouped_optimization", None) + return callable(uses_grouped_optimization) and bool(uses_grouped_optimization()) + + +def _collect_layer_forward_progress( + looper: "ModuleLooper", + *, + processor, + layer_inputs: List[List[torch.Tensor]], +) -> tuple[int, List[int], int]: + """Compute replay progress metadata for a whole-layer lifecycle forward. + + Subset-driven replay normally reuses progress data that was already planned + inside :class:`SubsetPlan`. When an entire layer is dynamically excluded, + no subset plan exists, but the layer stage may still need one untouched + forward pass so the next layer receives the correct activations. + + This helper mirrors the subset planner's batch/row normalization so the + fallback layer replay uses the same progress accounting contract: + - `batch_count`: number of cached calibration batches to replay + - `forward_row_counts`: per-batch row counts for progress updates + - `forward_total_rows`: normalized total rows shown by the replay progress + """ + + batch_count = looper._resolve_batch_total( + getattr(processor, "num_batches", None), + layer_inputs, + ) + forward_row_counts = list(looper._collect_row_counts(layer_inputs)) + if not forward_row_counts and batch_count > 0: + forward_row_counts = [1] * batch_count + if len(forward_row_counts) > batch_count: + forward_row_counts = forward_row_counts[:batch_count] + + forward_total_rows = sum(forward_row_counts) if forward_row_counts else batch_count + forward_total_rows = max(forward_total_rows, 1) + + if len(forward_row_counts) < batch_count: + forward_row_counts.extend([1] * (batch_count - len(forward_row_counts))) + + return batch_count, forward_row_counts, forward_total_rows + + +def _replay_layer_outputs( + looper: "ModuleLooper", + *, + module: torch.nn.Module, + processor, + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + layer_descriptor: str, + full, + log, + region_timer, + replay_plan: Optional[SubsetPlan] = None, +) -> List[List[torch.Tensor]]: + """Replay one layer forward to materialize outputs for the next layer.""" + + if replay_plan is None: + replay_batch_count, replay_row_counts, replay_total_rows = _collect_layer_forward_progress( + looper, + processor=processor, + layer_inputs=layer_inputs, + ) + replay_source = f"{layer_descriptor}:untouched" + replay_modules = None + replay_forward_device_map: Dict[str, torch.device] = {} + replay_force_serial = False + replay_preserve_module_devices = False + else: + replay_batch_count = replay_plan.batch_count + replay_row_counts = replay_plan.forward_row_counts + replay_total_rows = replay_plan.forward_total_rows + replay_source = ( + f"{layer_descriptor}:subset" + f"{replay_plan.subset_index + 1}/{replay_plan.subset_total}" + ) + replay_modules = replay_plan.modules + replay_forward_device_map = replay_plan.forward_device_map + replay_force_serial = replay_plan.subset_forward_serial + replay_preserve_module_devices = replay_plan.preserve_module_devices + + replay_msg = ( + "Forward replay " + f"(layer=`{layer_descriptor}`, batches={replay_batch_count}, rows={replay_total_rows})" + ) + replay_pb = ( + log.pb(range(replay_total_rows)) + .manual() + .set(show_left_steps=False) + ) + replay_pb.title(replay_msg).subtitle( + f"Forward replay Row 0/{replay_total_rows}" + ).draw() + + replay_prev_devices: Dict[str, torch.device] = {} + if replay_modules is not None and replay_forward_device_map: + replay_prev_devices = looper._apply_forward_device_overrides( + replay_modules, + replay_forward_device_map, + fallback_modules=full, + ) + + replay_start = time.perf_counter() + try: + looper._current_subset = None + layer_outputs = looper._run_forward_batches( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=True, + reuse_kv=False, + progress_pb=replay_pb, + progress_title=replay_msg, + progress_stage="Forward replay", + progress_rows_per_batch=replay_row_counts, + progress_total_rows=replay_total_rows, + force_serial=replay_force_serial, + preserve_module_devices=replay_preserve_module_devices, + # Replay should emit next-layer activations under the model's native router. + # And reduce the execution time of `forward()`. + apply_moe_config=False, + ) + finally: + if ( + replay_modules is not None + and replay_forward_device_map + and (replay_plan is None or replay_plan.restore_forward_device_overrides) + ): + looper._restore_forward_device_overrides( + replay_modules, + replay_prev_devices, + fallback_modules=full, + ) + replay_pb.close() + + if region_timer is not None: + region_timer.record( + "post_quant_forward", + time.perf_counter() - replay_start, + source=replay_source, + ) + + return layer_outputs + + +def _capture_pristine_group_context( + looper: "ModuleLooper", + *, + processor, + module: torch.nn.Module, + pristine_module: Optional[torch.nn.Module], + subset_plans: List[SubsetPlan], + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + layer_descriptor: str, + full, + log, + region_timer, +) -> None: + """Capture clean grouped targets while the main layer cache keeps the noisy stream.""" + uses_grouped_optimization = getattr(processor, "uses_grouped_optimization", None) + if not callable(uses_grouped_optimization) or not uses_grouped_optimization(): + return + clean_layer_inputs = layer_inputs + resolve_clean_inputs = getattr(processor, "clean_group_layer_inputs", None) + if callable(resolve_clean_inputs): + clean_layer_inputs = resolve_clean_inputs( + layer_index=layer_index, + layer_inputs=layer_inputs, + ) + capture_pristine_layer_module = getattr(processor, "receive_pristine_layer_module", None) + if subset_plans and callable(capture_pristine_layer_module): + capture_pristine_layer_module( + layer_index=layer_index, + layer_module=pristine_module if pristine_module is not None else module, + ) + + pristine_replay_module = pristine_module if pristine_module is not None else module + pristine_outputs = _replay_layer_outputs( + looper, + module=pristine_replay_module, + processor=processor, + layer_inputs=clean_layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + layer_descriptor=layer_descriptor, + full=full, + log=log, + region_timer=region_timer, + replay_plan=None, + ) + receive_clean_layer_inputs = getattr(processor, "receive_clean_layer_inputs", None) + if callable(receive_clean_layer_inputs): + receive_clean_layer_inputs( + layer_index=layer_index, + layer_inputs=pristine_outputs, + ) + if subset_plans: + processor.receive_layer_forward_context( + layer_index=layer_index, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + layer_outputs=pristine_outputs, + subset_index=None, + subset_total=len(subset_plans), + ) + + def run_layer_stage( looper: 'ModuleLooper', *, layers: List[torch.nn.Module], layer_modules: List[List[str]], + planning_layer_modules: List[List[str]], layers_prefix: Optional[str], - failsafe, + fallback, shared_kv_cache_dict: Dict[int, torch.Tensor], pb, layer_count: int, @@ -50,7 +368,18 @@ def run_layer_stage( logger=None, ) -> None: """Execute the main per-layer quantization loop.""" + # Trailing layers whose tracked modules are all dynamically excluded never + # need another forward or finalize pass, so the loop can stop once the + # final eligible layer has been processed. + last_quantized_layer_index = _find_last_quantized_layer_index( + looper, + layer_modules=layer_modules, + layers_prefix=layers_prefix, + layer_count=layer_count, + ) + log = logger or setup_logger() + durable_progress_logs = live_renderables_suppressed() for layer_index in pb: # Iterate over every transformer layer (plus lm_head when enabled) as # progress-bar controlled units of work. @@ -58,14 +387,38 @@ def run_layer_stage( break is_lm_head_module = layer_index >= layer_count + if ( + not is_lm_head_module + and last_quantized_layer_index is not None + and layer_index > last_quantized_layer_index + ): + # The remaining layers are fully skipped by dynamic config, so + # avoid entering another layer-level quantization cycle. + log.debug( + "StageLayer: early stop at layer=%s, last_quantized_layer=%s", + layer_index, + last_quantized_layer_index, + ) + pb.close() + break + if is_lm_head_module: layer_title = "Quantizing lm_head" module = get_module(looper.gptq_model.model, key=looper.gptq_model.lm_head) + pristine_group_module = None else: layer_title = f"Quantizing layer {layer_index} of {layer_count - 1}" module = layers[layer_index] + pristine_group_module = None - looper.pause_controller.register_and_draw_progress_bar(pb, title=layer_title, subtitle="") + pb.title(layer_title).subtitle("").draw() + if durable_progress_logs: + log.info( + "StageLayer: start layer=%s/%s title=`%s`", + layer_index if not is_lm_head_module else "lm_head", + layer_count - 1 if not is_lm_head_module else "lm_head", + layer_title, + ) if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower(): # TODO FIXME: currently we not support quantizing cross attention layer (pixel_values) @@ -81,6 +434,17 @@ def run_layer_stage( converter = MODULE_CONVERTER_MAP[model_type] module = converter(module, looper.gptq_model.model.config) + needs_group_pristine = any( + callable(getattr(processor, "uses_grouped_optimization", None)) and processor.uses_grouped_optimization() + for processor in looper.processors + ) + needs_pristine_group_clone = any( + _processor_needs_pristine_group_clone(processor) + for processor in looper.processors + ) + if needs_group_pristine: + pristine_group_module = copy.deepcopy(module) if needs_pristine_group_clone else None + replace_module_with_hooked_legacy(module, quant_lm_head=looper.gptq_model.quantize_config.lm_head) layers[layer_index] = module @@ -93,6 +457,9 @@ def run_layer_stage( materialize_model(module) cur_layer_device = get_device(module) + if getattr(cur_layer_device, "type", None) == "meta": + # Lazy shell layers can stay meta until a later subset stage materializes them. + cur_layer_device = normalize_device_like(looper.gptq_model.quantize_config.device) or CPU full = find_modules(module, name=looper.gptq_model.lm_head if is_lm_head_module else "") for p_index, processor in enumerate(looper.processors): @@ -100,13 +467,9 @@ def run_layer_stage( # order so their caches and side effects line up with the pipeline. processor.log_call_count = 0 # reset processor.collect_memory_info(layer_index) - - modules = [[looper.gptq_model.lm_head]] if is_lm_head_module else layer_modules - - # for NativeProcessor we process one time forward on all grouped module subsets - if processor.fwd_all_modules_in_single_pass: - # merge all subsets into one - modules = [sum(modules, [])] + # Read the replay policy once per processor so the layer stage uses + # one execution config instead of a group of unrelated flags. + execution_config = processor.execution_config layer_inputs = processor.inputs_cache.layer_inputs if is_lm_head_module and layer_inputs: @@ -116,53 +479,83 @@ def run_layer_stage( attention_masks = processor.inputs_cache.attention_masks processed_subset: Dict[str, NamedModule] = {} - last_subset_context: Optional[SubsetForwardContext] = None + last_subset_plan: Optional[SubsetPlan] = None previous_subset_processed: Optional[Dict[str, NamedModule]] = None - subsets = [] - for names in modules: - subset = looper.create_named_modules( - module=module, - full=full, - is_lm_head_module=is_lm_head_module, - layer_index=layer_index, - layers_prefix=layers_prefix, - names=names, - processor=processor, - failsafe=failsafe, - layer_module=module, + # Freeze all subset-level execution decisions before the processor + # starts running this layer. The rest of the layer stage can then + # iterate plans instead of repeatedly re-deriving replay, batching, + # and device-routing state inside the execution loop. + subset_plans = build_layer_subset_plans( + looper, + processor=processor, + module=module, + layer_modules=layer_modules, + planning_layer_modules=planning_layer_modules, + layer_inputs=layer_inputs, + full=full, + is_lm_head_module=is_lm_head_module, + layer_index=layer_index, + layers_prefix=layers_prefix, + fallback=fallback, + ) + if durable_progress_logs: + log.info( + "StageLayer: layer=%s processor=%s begin subsets=%s", + layer_index if not is_lm_head_module else "lm_head", + processor.name(), + len(subset_plans), ) - # Skip empty subsets caused by per-layer structure differences or dynamic config exclusions; - # otherwise awq_processor may fail to quantize - if subset: - subsets.append(subset) - subset_total = len(subsets) - for index, subset in enumerate(subsets): + _capture_pristine_group_context( + looper, + processor=processor, + module=module, + pristine_module=pristine_group_module, + subset_plans=subset_plans, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + layer_descriptor=layer_descriptor, + full=full, + log=log, + region_timer=region_timer, + ) + pristine_group_module = None + + is_last_module = layer_index == len(pb) - 1 + for subset_plan in subset_plans: # Process the layer in smaller subsets so attention groups or # MoE experts can be quantized independently within a layer. if DEBUG_ON and log.isEnabledFor(logging.DEBUG): - if isinstance(processor, AWQProcessor): + if isinstance(processor, (AWQProcessor, ParoQuantProcessor)): log.debug( - "StageLayer[awq]: layer=%s subset=%s/%s size=%s names=%s", + "StageLayer[%s]: layer=%s subset=%s/%s size=%s names=%s", + processor.name(), layer_index, - index + 1, - subset_total, - len(subset), - subset[:5], + subset_plan.subset_index + 1, + subset_plan.subset_total, + len(subset_plan.modules), + list(subset_plan.modules.keys())[:5], ) else: log.debug( "StageLayer: layer=%s subset=%s/%s processor=%s size=%s names=%s", layer_index, - index + 1, - subset_total, + subset_plan.subset_index + 1, + subset_plan.subset_total, processor.name(), - len(subset), - subset[:8], + len(subset_plan.modules), + list(subset_plan.modules.keys())[:8], ) subset_result = run_subset_stage( looper=looper, + plan=subset_plan, processor=processor, module=module, layer_inputs=layer_inputs, @@ -174,12 +567,8 @@ def run_layer_stage( layer_descriptor=layer_descriptor, layer_title=layer_title, layer_index=layer_index, - layers_prefix=layers_prefix, - subset=subset, - subset_index=index, - subset_total=subset_total, full=full, - failsafe=failsafe, + fallback=fallback, shared_kv_cache_dict=shared_kv_cache_dict, pb=pb, log=log, @@ -191,123 +580,66 @@ def run_layer_stage( layer_inputs = subset_result.layer_inputs processed_subset.update(subset_result.processed_subset) previous_subset_processed = subset_result.processed_subset - if subset_result.forward_context is not None: - last_subset_context = subset_result.forward_context + if subset_result.plan is not None: + # The most recent subset plan defines the replay contract + # for the outputs that flow into the next layer. + last_subset_plan = subset_result.plan + if durable_progress_logs: + log.info( + "StageLayer: layer=%s processor=%s subset=%s/%s complete modules=%s", + layer_index if not is_lm_head_module else "lm_head", + processor.name(), + subset_plan.subset_index + 1, + subset_plan.subset_total, + len(subset_plan.modules), + ) - is_last_module = layer_index == len(pb) - 1 layer_outputs: List[List[torch.Tensor]] = [] - subset_context = last_subset_context - forward_device_map = subset_context.forward_device_map if subset_context else {} - subset_forward_serial = subset_context.subset_forward_serial if subset_context else False - subset_reference_total = subset_context.subset_total if subset_context else subset_total - subset_reference_index = subset_context.subset_index if subset_context else max(subset_total - 1, 0) - subset_for_overrides = subset_context.subset if subset_context else {} - preserve_devices = bool(forward_device_map) - - # second forward after process() - if not is_last_module and processor.fwd_after_process and subset_context is not None: - replay_batch_count = looper._resolve_batch_total( - getattr(processor, "num_batches", None), - layer_inputs, - ) - replay_row_counts = list(looper._collect_row_counts(layer_inputs)) - if not replay_row_counts and replay_batch_count > 0: - replay_row_counts = [1] * replay_batch_count - if len(replay_row_counts) > replay_batch_count: - replay_row_counts = replay_row_counts[:replay_batch_count] - replay_total_rows = sum(replay_row_counts) if replay_row_counts else replay_batch_count - replay_total_rows = max(replay_total_rows, 1) - if len(replay_row_counts) < replay_batch_count: - replay_row_counts.extend([1] * (replay_batch_count - len(replay_row_counts))) - replay_msg = ( - "Forward replay " - f"(layer=`{layer_descriptor}`, batches={replay_batch_count}, rows={replay_total_rows})" - ) - replay_pb = ( - log.pb(range(replay_total_rows)) - .manual() - .set(show_left_steps=False) + replay_plan = last_subset_plan + + # When dynamic exclusions remove every tracked module from a layer, + # no subset stage runs, so nothing materializes that layer's + # outputs. Processors that enable post-process forward replay + # (`fwd_replay_after_process`) still need one forward of the untouched + # layer so the next layer receives the correct activations. + replay_skipped_layer = ( + not is_last_module + and not subset_plans + and execution_config.require_fwd + and execution_config.fwd_replay_after_process + ) + + # Some processors consume outputs only after `process()` updates the + # current layer. In that case, replay the layer once using the + # metadata already computed by the final subset plan. + replay_after_process = ( + not is_last_module + and replay_plan is not None + and replay_plan.replay_after_process + ) + + if replay_skipped_layer or replay_after_process: + # Pass `replay_plan` through unconditionally: the helper uses + # subset metadata when available and falls back to generic + # untouched-layer replay when it is `None`. + layer_outputs = _replay_layer_outputs( + looper, + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + layer_descriptor=layer_descriptor, + full=full, + log=log, + region_timer=region_timer, + replay_plan=replay_plan, ) - replay_pb.title(replay_msg).subtitle( - f"Forward replay Row 0/{replay_total_rows}" - ).draw() - # Forward replay shares the same VRAM spike; block until the pool drains first. - # DEVICE_THREAD_POOL.wait() - # try to cleanup recent objects before forward - #timed_gc_collect(1) - - replay_start = time.perf_counter() - replay_source = f"{layer_descriptor}:subset{subset_reference_index + 1}/{subset_reference_total}" - - replay_prev_devices: Dict[str, torch.device] = {} - if forward_device_map: - replay_prev_devices = looper._apply_forward_device_overrides( - subset_for_overrides, - forward_device_map, - fallback_modules=full, - ) - - # if log.isEnabledFor(logging.DEBUG): - # replay_snapshot = [] - # for name, named_module in subset.items(): - # target_device = getattr(named_module, "target_device", None) - # if target_device is None: - # try: - # target_device = get_device(named_module.module) - # except Exception: - # target_device = None - # target_device_str = str(target_device) if target_device is not None else "unknown" - # replay_snapshot.append(f"{name}:{target_device_str}") - # log.debug( - # "ModuleLooper: Forward replay device snapshot (layer=`%s`, subset=%d/%d, serial=%s) %s", - # layer_descriptor, - # index + 1, - # subset_total, - # subset_forward_serial, - # ", ".join(replay_snapshot), - # ) - - try: - # Reset current subset for MoE lifecycle hooks as we do not need to collect activation at this stage, - # and need to collect only outputs produced by original forward - looper._current_subset = None - - layer_outputs = looper._run_forward_batches( - module=module, - processor=processor, - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - shared_kv_cache_dict=shared_kv_cache_dict, - layer_index=layer_index, - need_outputs=True, - reuse_kv=False, - progress_pb=replay_pb, - progress_title=replay_msg, - progress_stage="Forward replay", - progress_rows_per_batch=replay_row_counts, - progress_total_rows=replay_total_rows, - force_serial=subset_forward_serial, - preserve_module_devices=preserve_devices, - ) - finally: - if forward_device_map: - looper._restore_forward_device_overrides( - subset_for_overrides, - replay_prev_devices, - fallback_modules=full, - ) - if replay_pb is not None: - replay_pb.close() - if region_timer is not None: - region_timer.record( - "post_quant_forward", - time.perf_counter() - replay_start, - source=replay_source, - ) # Finalize module after last processor if p_index == len(looper.processors) - 1: @@ -333,11 +665,11 @@ def run_layer_stage( if region_timer is not None: region_timer.flush() - if processor.fwd_after_process: + if execution_config.fwd_replay_after_process: processor.clear_cache_data() processor.receive_layer_inputs(layer_outputs) layer_inputs = processor.inputs_cache.layer_inputs - looper.pause_controller.register_and_draw_progress_bar(pb, title=layer_title, subtitle="") + pb.title(layer_title).subtitle("").draw() if p_index == len(looper.processors) - 1: torch_sync() @@ -373,6 +705,8 @@ def run_layer_stage( @torch.inference_mode() def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): + """Runs processor finalization and optional disk offload for one module.""" + resolved_label = module_label or getattr(module, "full_name", getattr(module, "name", "")) start = time.perf_counter() if region_timer is not None else None try: @@ -384,7 +718,7 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): process.submodule_finalize(module, looper.gptq_model) # Disk offload (lifecycle TODO note preserved) - if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)): + if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor, ParoQuantProcessor)): quant_config = getattr(looper.gptq_model, "quantize_config", None) if quant_config and getattr(quant_config, "offload_to_disk", False): offload_path = getattr(quant_config, "offload_to_disk_path", None) @@ -488,6 +822,8 @@ def _drain_finalize_futures( finalize_count_local, layer_idx_for_callback, ): + """Consumes finalize futures, updating progress and surfacing errors.""" + completed_local = 0 try: for future in as_completed(futures): @@ -529,7 +865,18 @@ def _drain_finalize_futures( ) if finalize_futures_snapshot: - if looper.gptq_model.quantize_config.wait_for_submodule_finalizers: + drain_sync = _should_drain_finalize_futures_synchronously( + looper, + finalize_tasks=finalize_tasks, + ) + if durable_progress_logs: + log.info( + "StageLayer: layer=%s finalize queued modules=%s mode=%s", + layer_index if not is_lm_head_module else "lm_head", + finalize_count, + "sync" if drain_sync else "async", + ) + if drain_sync: # Synchronous: wait for all finalization to complete before proceeding to next layer # This ensures all packing and writing tasks are done _drain_finalize_futures( @@ -539,7 +886,12 @@ def _drain_finalize_futures( layer_index, ) if looper.gptq_model.quantize_config.gc_mode == GcMode.ON_STAGE_END: - torch_empty_cache() + torch_empty_cache(device=cur_layer_device, sync=True) + elif _should_empty_cache_after_sync_finalize( + looper, + finalize_tasks=finalize_tasks, + ): + torch_empty_cache(device=cur_layer_device, gc=False, sync=True) else: # Asynchronous (current/default behavior): drain in background thread # This allows next layer to start while current layer finalizes @@ -562,10 +914,14 @@ def _drain_finalize_futures( submodule_finalized=True, raise_in_place=True, ) + if durable_progress_logs: + log.info( + "StageLayer: layer=%s complete (no finalize tasks)", + layer_index if not is_lm_head_module else "lm_head", + ) - # Check for pause after completing each layer - layer_info = f"layer {layer_index}" if not is_lm_head_module else "lm_head" - looper.pause_controller.check_pause_point(f"after {layer_info}") - - # Unregister progress bar when moving to next layer - looper.pause_controller.unregister_progress_bar(pb) + if durable_progress_logs: + log.info( + "StageLayer: handoff complete for layer=%s", + layer_index if not is_lm_head_module else "lm_head", + ) diff --git a/gptqmodel/looper/stage_subset.py b/gptqmodel/looper/stage_subset.py index a7b56f9ad..e5a2ca098 100644 --- a/gptqmodel/looper/stage_subset.py +++ b/gptqmodel/looper/stage_subset.py @@ -3,55 +3,657 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -"""Subset-level processing stage extracted from ModuleLooper.""" +"""Subset-level planning and execution for forward replay and quantization. + +This module is intentionally split into two responsibilities: +- `build_subset_plan()` / `build_layer_subset_plans()` decide what should happen +- `run_subset_stage()` / `_run_single_subset_pass()` execute that decision + +The goal is to keep planning branches out of the hot execution path so replay, +MoE chunking, coverage handling, and device routing are easier to reason about. +""" from __future__ import annotations import logging -import math -import re +import os import time -from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple +from dataclasses import dataclass, replace +from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple +import pcre import torch from .awq_processor import AWQProcessor +from .paroquant_processor import ParoQuantProcessor from .qqq_processor import QQQProcessor from .. import DEBUG_ON, DEVICE_THREAD_POOL from ..looper.gptq_processor import GPTQProcessor from ..looper.loop_processor import LoopProcessor from ..looper.named_module import NamedModule -from ..quantization.config import VramStrategy, GcMode, ExpertsRoutingBypass +from ..models._const import META +from ..quantization.config import GcMode, ExpertsRoutingBypass, VramStrategy +from ..utils.device_telemetry import emit_device_telemetry from ..utils.device import get_device from ..utils.logger import setup_logger +from ..utils.looper_helpers import normalize_device_like, select_forward_devices +from ..utils.python import has_gil_control, has_gil_disabled from ..utils.torch import torch_empty_cache, torch_sync if TYPE_CHECKING: # pragma: no cover - typing only from .module_looper import ModuleLooper +ForwardMode = Literal["parallel", "serial"] + + @dataclass -class SubsetForwardContext: - subset: Dict[str, NamedModule] - forward_device_map: Dict[str, torch.device] - subset_forward_serial: bool - subset_total: int +class CalibrationCoveragePolicy: + """Describe how calibration coverage gaps are handled for one subset. + + Some quantizers need every module to observe routed calibration traffic. + When a MoE expert never fires, we either keep it alive because fallback is + enabled, or prune it from quantization and record a dynamic exclusion. + """ + + # Whether the processor should validate that calibration reached each module. + validate_input_coverage: bool + # Whether uncovered modules are allowed to remain because fallback can take over. + fallback_enabled: bool + # Whether uncovered modules are removed from the quantization worklist. + prune_uncovered_modules: bool + # Whether uncovered modules should be recorded in qcfg.dynamic. + record_dynamic_exclusions: bool + + +@dataclass +class SubsetPlan: + """Freeze all subset execution decisions before forward or quant work begins. + + The plan answers: + - which modules belong to this subset + - whether forward runs at all + - whether the layer needs a post-process replay + - whether forward is serial or parallel + - how MoE groups are arranged for scheduling + - which modules are pinned to specific forward devices + - how uncovered calibration modules are handled + """ + + modules: Dict[str, NamedModule] subset_index: int + subset_total: int + execute_forward: bool + replay_after_process: bool + forward_mode: ForwardMode + batch_count: int + forward_row_counts: List[int] + forward_total_rows: int + moe_groups: Dict[str, List[str]] + forward_device_map: Dict[str, torch.device] + calibration_coverage_policy: CalibrationCoveragePolicy + module_chunks: List[Dict[str, NamedModule]] + restore_forward_device_overrides: bool = True + + @property + def subset_forward_serial(self) -> bool: + """Whether the forward executor should stay on one device.""" + + return self.forward_mode == "serial" + + @property + def need_forward_outputs(self) -> bool: + """Whether this subset consumes forward outputs before process().""" + + return self.execute_forward and not self.replay_after_process + + @property + def batching_enabled(self) -> bool: + """Whether the subset will execute as multiple forward/quant chunks.""" + + return len(self.module_chunks) > 1 + + @property + def preserve_module_devices(self) -> bool: + """Whether per-module forward device overrides are active.""" + + return bool(self.forward_device_map) + + def for_modules(self, modules: Dict[str, NamedModule]) -> "SubsetPlan": + """Reuse the same execution policy for one chunk or replay-only subset.""" + + return replace(self, modules=modules, module_chunks=[modules]) @dataclass class SubsetStageResult: + """Returns processed modules plus the updated layer input cache for a subset.""" + processed_subset: Dict[str, NamedModule] layer_inputs: List[List[torch.Tensor]] - forward_context: Optional[SubsetForwardContext] + plan: Optional[SubsetPlan] + + +def _resolve_cache_flush_device( + cur_layer_device: Optional[torch.device], + used_devices, +) -> Optional[torch.device]: + """Keep cache flush local unless the preceding work fanned out across devices.""" + + current = normalize_device_like(cur_layer_device) + if current is None: + return None + + accelerator_devices = set() + for device in used_devices: + normalized = normalize_device_like(device) + if normalized is None or normalized.type == "cpu": + continue + accelerator_devices.add(str(normalized)) + + if not accelerator_devices: + return current + if accelerator_devices == {str(current)}: + return current + return None + + +def _resolve_forward_flush_device( + plan: SubsetPlan, + cur_layer_device: Optional[torch.device], +) -> Optional[torch.device]: + used_devices = list(plan.forward_device_map.values()) + + if not plan.subset_forward_serial: + selected_devices = select_forward_devices(cur_layer_device) + active_forward_devices = { + str(normalize_device_like(device)) + for device in selected_devices + if normalize_device_like(device) is not None and normalize_device_like(device).type != "cpu" + } + if len(active_forward_devices) > 1: + used_devices.extend(selected_devices) + + return _resolve_cache_flush_device(cur_layer_device, used_devices) + + +def _resolve_quant_flush_device( + cur_layer_device: Optional[torch.device], + quant_target_devices: Dict[str, torch.device], +) -> Optional[torch.device]: + return _resolve_cache_flush_device(cur_layer_device, quant_target_devices.values()) + + +def _resolve_subset_calibration_coverage_policy( + processor: LoopProcessor, + fallback, +) -> CalibrationCoveragePolicy: + """Resolve how this subset handles modules that never receive calibration traffic.""" + + validate_input_coverage = isinstance(processor, (GPTQProcessor, QQQProcessor, AWQProcessor, ParoQuantProcessor)) + fallback_enabled = fallback is not None + prune_uncovered_modules = validate_input_coverage and not fallback_enabled + + return CalibrationCoveragePolicy( + validate_input_coverage=validate_input_coverage, + fallback_enabled=fallback_enabled, + prune_uncovered_modules=prune_uncovered_modules, + record_dynamic_exclusions=prune_uncovered_modules, + ) + + +def _collect_subset_forward_progress( + looper: "ModuleLooper", + processor: LoopProcessor, + layer_inputs: List[List[torch.Tensor]], + *, + execute_forward: bool, +) -> Tuple[int, List[int], int]: + """Normalize batch and row progress for the subset's forward execution.""" + + if not execute_forward: + return 0, [], 1 + + batch_count = looper._resolve_batch_total( + getattr(processor, "num_batches", None), + layer_inputs, + ) + forward_row_counts = list(looper._collect_row_counts(layer_inputs)) + if not forward_row_counts and batch_count > 0: + forward_row_counts = [1] * batch_count + if len(forward_row_counts) > batch_count: + forward_row_counts = forward_row_counts[:batch_count] + + forward_total_rows = sum(forward_row_counts) if forward_row_counts else batch_count + forward_total_rows = max(forward_total_rows, 1) + + if len(forward_row_counts) < batch_count: + forward_row_counts.extend([1] * (batch_count - len(forward_row_counts))) + + return batch_count, forward_row_counts, forward_total_rows + + +def _resolve_forward_baseline_devices( + subset: Dict[str, NamedModule], + full, +) -> Dict[str, torch.device]: + """Capture the baseline device for each layer module before subset execution mutates it.""" + + candidates: Dict[str, object] = {} + if full: + candidates.update(full) + for name, named_module in subset.items(): + candidates.setdefault(name, named_module) + + baseline: Dict[str, torch.device] = {} + fallback_device: Optional[torch.device] = None + for name, module_ref in candidates.items(): + actual_module = module_ref.module if isinstance(module_ref, NamedModule) else module_ref + try: + device = get_device(actual_module) + except Exception: + device = None + if device is not None and device != META and fallback_device is None: + fallback_device = device + baseline[name] = device + + if fallback_device is None: + return {} + + resolved: Dict[str, torch.device] = {} + for name, device in baseline.items(): + if device is None or device == META: + device = fallback_device + if device is not None and device != META: + resolved[name] = device + + return resolved + + +def _collect_layer_candidate_names( + subset: Dict[str, NamedModule], + full, +) -> List[str]: + """Return the stable layer module order used for device planning.""" + + names: List[str] = list(subset.keys()) + if full is None: + return names + + for candidate in full.keys(): + if candidate not in subset: + names.append(candidate) + return names + + +def _collect_assignable_moe_group_keys( + moe_groups: Dict[str, List[str]], +) -> List[str]: + """Return expert families that should stay co-located on one device.""" + + assignable_group_keys: List[str] = [] + for group_key, module_names in moe_groups.items(): + suffixes = {name.rsplit(".", 1)[-1] for name in module_names} + # Some MoE families route pairs like gate/up or w1/w3 together. + if {"gate_proj", "up_proj"}.issubset(suffixes) or {"w1", "w3"}.issubset(suffixes): + assignable_group_keys.append(group_key) + return assignable_group_keys + + +def _normalize_planning_module_name(module_name: str) -> str: + """Strip model-tree annotations so planning blocks map back to live module names.""" + + return module_name.split(":", 1)[0] + + +def _collect_dense_groups( + layer_candidate_names: List[str], + layer_moe_group_key_by_name: Dict[str, Optional[str]], + planning_layer_modules: Optional[List[List[str]]], +) -> Dict[str, List[str]]: + """Collect dense modules into model-tree-defined calculation groups.""" + + remaining_dense_names = [ + module_name + for module_name in layer_candidate_names + if layer_moe_group_key_by_name.get(module_name) is None + ] + if not remaining_dense_names: + return {} + + dense_groups: Dict[str, List[str]] = {} + remaining_dense_set = set(remaining_dense_names) + + if planning_layer_modules: + # Reuse the model-tree execution blocks for dense placement so the + # planner follows the same grouping definitions users already maintain. + for block_index, block in enumerate(planning_layer_modules): + block_dense_names: List[str] = [] + block_seen = set() + for block_entry in block: + module_name = _normalize_planning_module_name(block_entry) + if module_name in block_seen: + continue + block_seen.add(module_name) + if module_name not in remaining_dense_set: + continue + if layer_moe_group_key_by_name.get(module_name) is not None: + continue + block_dense_names.append(module_name) + + if block_dense_names: + dense_groups[f"planning:{block_index}"] = block_dense_names + for module_name in block_dense_names: + remaining_dense_set.discard(module_name) + + for module_name in remaining_dense_names: + if module_name not in remaining_dense_set: + continue + dense_groups[module_name] = [module_name] + remaining_dense_set.discard(module_name) + + return dense_groups + + +def build_subset_plan( + looper: "ModuleLooper", + *, + processor: LoopProcessor, + subset: Dict[str, NamedModule], + subset_index: int, + subset_total: int, + full, + fallback, + layer_inputs: List[List[torch.Tensor]], + planning_layer_modules: Optional[List[List[str]]] = None, +) -> SubsetPlan: + """Plan subset execution before any hooks, forwards, or quant work begin. + + The returned plan is the single source of truth for: + - whether this subset runs forward at all + - whether replay happens later in the layer stage + - whether forward stays serial or can fan out + - whether modules are chunked for staged MoE execution + - how uncovered calibration modules are handled + """ + + execution_config = processor.execution_config + calibration_coverage_policy = _resolve_subset_calibration_coverage_policy(processor, fallback) + + moe_groups: Dict[str, List[str]] = {} + forward_device_map: Dict[str, torch.device] = {} + subset_forward_serial = False + restore_forward_device_overrides = True + + layer_candidate_names = _collect_layer_candidate_names(subset=subset, full=full) + subset_moe_group_key_by_name: Dict[str, Optional[str]] = { + name: looper._extract_moe_group_key(name) + for name in subset + } + layer_moe_group_key_by_name: Dict[str, Optional[str]] = { + name: looper._extract_moe_group_key(name) + for name in layer_candidate_names + } + subset_moe_module_names = [ + name for name, group_key in subset_moe_group_key_by_name.items() + if group_key is not None + ] + layer_moe_module_names = [ + name for name, group_key in layer_moe_group_key_by_name.items() + if group_key is not None + ] + is_moe_subset = len(subset_moe_module_names) >= looper._moe_subset_threshold + layer_has_moe = len(layer_moe_module_names) >= looper._moe_subset_threshold + moe_modules_set = set(subset_moe_module_names) + + if layer_has_moe: + for module_name in layer_candidate_names: + # Group experts across the full MoE family so device placement is + # consistent even when the current subset only contains one slice. + group_key = layer_moe_group_key_by_name[module_name] + if group_key is None: + continue + moe_groups.setdefault(group_key, []).append(module_name) + dense_groups = _collect_dense_groups( + layer_candidate_names, + layer_moe_group_key_by_name, + planning_layer_modules, + ) + + for name, named_module in subset.items(): + setattr(named_module, "moe_enabled", name in moe_modules_set) + + dense_strategy_active = bool(getattr(looper, "_dense_vram_strategy_explicit", False)) + moe_strategy_active = bool(getattr(looper, "_moe_vram_strategy_explicit", False)) and bool(moe_groups) + + if dense_strategy_active or moe_strategy_active: + dense_devices = [ + dev for dev in getattr(looper, "_dense_quant_devices", []) + if dev is not None and getattr(dev, "type", None) != "cpu" + ] or list(getattr(looper, "_dense_quant_devices", [])) + moe_devices = [ + dev for dev in getattr(looper, "_moe_quant_devices", []) + if dev is not None and getattr(dev, "type", None) != "cpu" + ] or list(getattr(looper, "_moe_quant_devices", [])) + + if dense_strategy_active and dense_groups and dense_devices: + dense_group_keys = list(dense_groups.keys()) + if looper._dense_vram_strategy == VramStrategy.BALANCED and len(dense_devices) > 1: + for group_index, group_key in enumerate(dense_group_keys): + target_device = dense_devices[group_index % len(dense_devices)] + for module_name in dense_groups[group_key]: + forward_device_map[module_name] = target_device + else: + target_device = dense_devices[0] + for group_key in dense_group_keys: + for module_name in dense_groups[group_key]: + forward_device_map[module_name] = target_device + + if moe_strategy_active and moe_groups and moe_devices: + assignable_group_keys = _collect_assignable_moe_group_keys(moe_groups) + if assignable_group_keys: + if looper._moe_vram_strategy == VramStrategy.BALANCED and len(moe_devices) > 1: + for group_index, group_key in enumerate(assignable_group_keys): + target_device = moe_devices[group_index % len(moe_devices)] + for module_name in moe_groups[group_key]: + forward_device_map[module_name] = target_device + else: + target_device = moe_devices[0] + for group_key in assignable_group_keys: + for module_name in moe_groups[group_key]: + forward_device_map[module_name] = target_device + + if forward_device_map: + # Once either dense or expert placement is explicit, anchor every + # untouched module back to its baseline placement so stale quant + # devices never leak into a later subset forward. + baseline_devices = _resolve_forward_baseline_devices( + subset=subset, + full=full, + ) + for module_name, baseline_device in baseline_devices.items(): + forward_device_map.setdefault(module_name, baseline_device) + + for module_name, named_module in subset.items(): + preferred_device = forward_device_map.get(module_name) + if preferred_device is not None: + named_module.state["preferred_quant_device"] = preferred_device + + restore_forward_device_overrides = False + subset_forward_serial = True + + auto_forward_data_parallel = getattr( + looper.gptq_model.quantize_config, + "auto_forward_data_parallel", + True, + ) + subset_forward_serial = subset_forward_serial or not auto_forward_data_parallel + + # Forward progress is normalized here so the executor and any later replay + # reuse the same batch and row accounting instead of recomputing it. + execute_forward = execution_config.require_fwd + batch_count, forward_row_counts, forward_total_rows = _collect_subset_forward_progress( + looper, + processor, + layer_inputs, + execute_forward=execute_forward, + ) + + # ExpertsRoutingBypass is the only routing mode that exposes a deterministic + # module chunk size for staged MoE execution. + moe_routing = looper.gptq_model.quantize_config.moe + batch_size = None + if moe_routing is not None and isinstance(moe_routing.routing, ExpertsRoutingBypass): + batch_size = moe_routing.routing.batch_size + + module_chunks = [subset] + if is_moe_subset and batch_size is not None and batch_size > 0 and execute_forward: + sorted_module_names = sorted(subset.keys()) + module_chunks = [ + {name: subset[name] for name in sorted_module_names[start:start + batch_size]} + for start in range(0, len(sorted_module_names), batch_size) + ] + + return SubsetPlan( + modules=subset, + subset_index=subset_index, + subset_total=subset_total, + execute_forward=execute_forward, + replay_after_process=execute_forward and execution_config.fwd_replay_after_process, + forward_mode="serial" if subset_forward_serial else "parallel", + batch_count=batch_count, + forward_row_counts=forward_row_counts, + forward_total_rows=forward_total_rows, + moe_groups=moe_groups, + forward_device_map=forward_device_map, + calibration_coverage_policy=calibration_coverage_policy, + module_chunks=module_chunks, + restore_forward_device_overrides=restore_forward_device_overrides, + ) + + +def build_layer_subset_plans( + looper: "ModuleLooper", + *, + processor: LoopProcessor, + module: torch.nn.Module, + layer_modules: List[List[str]], + planning_layer_modules: Optional[List[List[str]]], + layer_inputs: List[List[torch.Tensor]], + full, + is_lm_head_module: bool, + layer_index: int, + layers_prefix: Optional[str], + fallback, +) -> List[SubsetPlan]: + """Build every subset plan for one processor before layer execution starts.""" + + execution_config = processor.execution_config + module_name_groups = [[looper.gptq_model.lm_head]] if is_lm_head_module else layer_modules + + if execution_config.fwd_all_modules_in_single_pass: + # Native-style processors consume one merged replay over the whole layer. + # Build one plan up front so the layer stage does not keep re-deriving + # merged subset state while it is also coordinating execution. + module_name_groups = [sum(module_name_groups, [])] + + subsets: List[Dict[str, NamedModule]] = [] + for names in module_name_groups: + subset = looper.create_named_modules( + module, + full, + is_lm_head_module, + layer_index, + layers_prefix, + names, + processor, + fallback, + layer_module=module, + ) + # Skip empty subsets caused by per-layer structure differences or dynamic + # exclusions so execution only sees real work. + if subset: + subsets.append(subset) + + subset_total = len(subsets) + return [ + build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=index, + subset_total=subset_total, + full=full, + fallback=fallback, + layer_inputs=layer_inputs, + planning_layer_modules=planning_layer_modules, + ) + for index, subset in enumerate(subsets) + ] + + +def _emit_moe_parallel_quant_subset_telemetry( + *, + plan: SubsetPlan, + quant_target_devices: Dict[str, torch.device], + futures_count: int, + layer_index: int, +) -> None: + """Capture the worker fan-out used for one MoE quant subset when telemetry is enabled.""" + + if not plan.moe_groups or futures_count <= 0: + return + + unique_devices = sorted({str(device) for device in quant_target_devices.values() if device is not None}) + thread_pool_workers: Dict[str, int] = {} + thread_pool_total_workers: Optional[int] = None + thread_pool_total_inflight: Optional[int] = None + + collect_snapshot = getattr(DEVICE_THREAD_POOL, "_collect_state_snapshot", None) + if callable(collect_snapshot): + try: + snapshot = collect_snapshot() + except Exception: + snapshot = None + if isinstance(snapshot, dict): + workers = snapshot.get("workers") or {} + thread_pool_workers = { + device_name: int(workers.get(device_name, 0)) + for device_name in unique_devices + } + total_workers = snapshot.get("total_workers") + total_inflight = snapshot.get("total_inflight") + if total_workers is not None: + thread_pool_total_workers = int(total_workers) + if total_inflight is not None: + thread_pool_total_inflight = int(total_inflight) + + total_parallel_workers = sum(thread_pool_workers.values()) + emit_device_telemetry( + "moe_parallel_quant_subset", + layer_index=layer_index, + subset_index=plan.subset_index + 1, + subset_total=plan.subset_total, + module_count=len(plan.modules), + moe_group_count=len(plan.moe_groups), + submitted_tasks=futures_count, + quant_devices=unique_devices, + thread_pool_workers=thread_pool_workers, + thread_pool_total_workers=thread_pool_total_workers, + thread_pool_total_inflight=thread_pool_total_inflight, + python_gil_env=os.environ.get("PYTHON_GIL"), + python_gil_controllable=has_gil_control(), + python_gil_disabled=has_gil_disabled(), + free_threaded_parallel_quant_eligible=bool(has_gil_disabled() and futures_count > 1), + free_threaded_parallel_quant_active=bool(total_parallel_workers > 1 and futures_count > 1), + ) def _run_single_subset_pass( looper: 'ModuleLooper', processor: LoopProcessor, module: torch.nn.Module, - subset: Dict[str, NamedModule], + plan: SubsetPlan, layer_inputs: List[List[torch.Tensor]], layer_input_kwargs: List[Dict[str, torch.Tensor]], position_ids: List[torch.Tensor], @@ -61,30 +663,49 @@ def _run_single_subset_pass( layer_descriptor: str, layer_title: str, layer_index: int, - subset_index: int, - subset_total: int, full, - failsafe, + fallback, shared_kv_cache_dict: Dict[int, torch.Tensor], pb, logger, is_awq_processor: bool, - forward_total_rows: int, - forward_row_counts: List[int], - batch_count: int, - forward_device_map: Dict[str, torch.device], - subset_forward_serial: bool, region_timer=None, previous_processed_subset: Optional[Dict[str, NamedModule]] = None, subset_event_cb: Optional[Callable[..., None]] = None, return_outputs: bool = False, disable_moe_hooks: bool = False, -) -> Tuple[Dict[str, NamedModule], Optional[List[List[torch.Tensor]]]]: - """Execute forward and quantization for a specific subset/chunk.""" - + execute_forward: Optional[bool] = None, +) -> Tuple[Dict[str, NamedModule], Optional[List[List[torch.Tensor]]], bool]: + """Execute forward and quantization for a specific subset/chunk. + + This function assumes planning is already done. Apart from the optional + `execute_forward` override used by replay-only and quant-only paths, it + should consume the plan rather than re-derive execution mode. + """ + + # Pull frequently used plan fields into locals so the execution flow below + # reads linearly without re-deriving policy from processor state. + subset = plan.modules + subset_index = plan.subset_index + subset_total = plan.subset_total + execution_config = processor.execution_config + calibration_coverage_policy = plan.calibration_coverage_policy + forward_row_counts = plan.forward_row_counts + batch_count = plan.batch_count + forward_device_map = plan.forward_device_map + execute_forward = plan.execute_forward if execute_forward is None else execute_forward + handle = [] subset_size = len(subset) - + + if execute_forward: + for named_module in subset.values(): + if isinstance(named_module, NamedModule): + looper._prepare_named_module_for_forward( + named_module=named_module, + fallback_device=cur_layer_device, + ) + # Determine MoE block name for hook selection moe_block_name = None if looper.gptq_model and hasattr(looper.gptq_model, 'moe_lifecycle_hooks'): @@ -98,41 +719,45 @@ def _run_single_subset_pass( moe_block_name = mod_name break - for idx, (name, m) in enumerate(subset.items()): - # Register the forward hook that captures activations for quantization. - # The final module optionally flips a flag so processors can trigger - # once-per-subset logic after the forward pass. - is_last = (idx == subset_size - 1) - hook_source = getattr(m, "full_name", None) - if hook_source is None: - hook_source = getattr(m, "name", name) - if hook_source is None: - hook_source = str(name) - - # Determine if this module is part of MoE block (needs pre-hook to avoid StopForward) - is_moe_module = moe_block_name and name.startswith(moe_block_name + ".") - - if hasattr(subset[name], 'forward_hook'): - original_hook = processor.pre_process_fwd_hook(name) - # Use pre-hook for MoE modules to fire before StopForward - if is_moe_module: - subset[name].forward_hook = looper._masked_pre_hook_wrapper(processor, original_hook, hook_source) - else: - subset[name].forward_hook = looper._masked_hook_wrapper(processor, original_hook, hook_source) - enable_stop = processor.fwd_after_process or getattr(processor, "subset_forward_early_stop", False) - if is_last and enable_stop: - subset[name].forward_hook_last = True - else: - original_hook = processor.pre_process_fwd_hook(name) - # Use pre-hook registration for MoE modules - if is_moe_module: - handle.append(subset[name].register_forward_hook( - looper._masked_pre_hook_wrapper(processor, original_hook, hook_source) - )) + if execute_forward: + for idx, (name, m) in enumerate(subset.items()): + # Register the forward hook that captures activations for quantization. + # The final module optionally flips a flag so processors can trigger + # once-per-subset logic after the forward pass. + is_last = (idx == subset_size - 1) + hook_source = getattr(m, "full_name", None) + if hook_source is None: + hook_source = getattr(m, "name", name) + if hook_source is None: + hook_source = str(name) + + # Determine if this module is part of MoE block (needs pre-hook to avoid StopForward) + is_moe_module = moe_block_name and name.startswith(moe_block_name + ".") + + if hasattr(subset[name], 'forward_hook'): + original_hook = processor.pre_process_fwd_hook(name) + # Use pre-hook for MoE modules to fire before StopForward + if is_moe_module: + subset[name].forward_hook = looper._masked_pre_hook_wrapper(processor, original_hook, hook_source) + else: + subset[name].forward_hook = looper._masked_hook_wrapper(processor, original_hook, hook_source) + enable_stop = ( + execution_config.fwd_replay_after_process + or execution_config.subset_forward_early_stop + ) + if is_last and enable_stop: + subset[name].forward_hook_last = True else: - handle.append(subset[name].register_forward_hook( - looper._masked_hook_wrapper(processor, original_hook, hook_source) - )) + original_hook = processor.pre_process_fwd_hook(name) + # Use pre-hook registration for MoE modules + if is_moe_module: + handle.append(subset[name].register_forward_hook( + looper._masked_pre_hook_wrapper(processor, original_hook, hook_source) + )) + else: + handle.append(subset[name].register_forward_hook( + looper._masked_hook_wrapper(processor, original_hook, hook_source) + )) if DEBUG_ON and logger.isEnabledFor(logging.DEBUG): if is_awq_processor: @@ -153,30 +778,36 @@ def _run_single_subset_pass( len(subset), ) - if subset_event_cb: - subset_event_cb(stage="forward_start", layer_idx=layer_index, subset_index=subset_index, subset_total=subset_total, module_names=list(subset.keys()), processor=getattr(processor, "name", type(processor).__name__)) - - fwd_start = time.perf_counter() + capture_layer_forward_context = execute_forward and execution_config.capture_layer_forward_context + if capture_layer_forward_context: + subset_capture_override = getattr(processor, "capture_layer_forward_context_during_subset", None) + if callable(subset_capture_override): + capture_layer_forward_context = bool(subset_capture_override()) + need_outputs = execute_forward and (plan.need_forward_outputs or capture_layer_forward_context) + fwd_start = None forward_source = f"{layer_descriptor}:subset{subset_index + 1}/{subset_total}" - - need_outputs = not processor.fwd_after_process - reuse_kv = bool(getattr(module, "reuse_kv", False)) - forward_msg = ( - "Forward: " - f"Layer=`{layer_descriptor}`, subset={subset_index + 1}/{subset_total}, " - f"batches={batch_count}" - ) - forward_pb = ( - logger.pb(range(forward_total_rows)) - .manual() - .set(show_left_steps=False) - ) - forward_pb.title(forward_msg).subtitle( - f"Row 0/{forward_total_rows}" - ).draw() + if execute_forward: + if subset_event_cb: + subset_event_cb(stage="forward_start", layer_idx=layer_index, subset_index=subset_index, subset_total=subset_total, module_names=list(subset.keys()), processor=getattr(processor, "name", type(processor).__name__)) + + fwd_start = time.perf_counter() + reuse_kv = bool(getattr(module, "reuse_kv", False)) + forward_msg = ( + "Forward: " + f"Layer=`{layer_descriptor}`, subset={subset_index + 1}/{subset_total}, " + f"batches={batch_count}" + ) + forward_pb = ( + logger.pb(range(plan.forward_total_rows)) + .manual() + .set(show_left_steps=False) + ) + forward_pb.title(forward_msg).subtitle( + f"Row 0/{plan.forward_total_rows}" + ).draw() previous_forward_devices: Dict[str, torch.device] = {} - preserve_devices = bool(forward_device_map) + preserve_devices = plan.preserve_module_devices if forward_device_map: previous_forward_devices = looper._apply_forward_device_overrides( subset, @@ -184,55 +815,69 @@ def _run_single_subset_pass( fallback_modules=full, ) - try: - # Set the current subset for MoE lifecycle hooks - if disable_moe_hooks: - looper._current_subset = None - else: - looper._current_subset = subset - forward_outputs = looper._run_forward_batches( - module=module, - processor=processor, + forward_outputs = None + if execute_forward: + try: + # MoE lifecycle hooks need to know which subset is currently active. + # Replay-only passes can disable that when they only need outputs. + if disable_moe_hooks: + looper._current_subset = None + else: + looper._current_subset = subset + forward_outputs = looper._run_forward_batches( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + progress_pb=forward_pb, + progress_title=forward_msg, + progress_stage="Forward", + progress_rows_per_batch=forward_row_counts, + progress_total_rows=plan.forward_total_rows, + force_serial=plan.subset_forward_serial, + preserve_module_devices=preserve_devices, + ) + finally: + if forward_device_map and plan.restore_forward_device_overrides: + looper._restore_forward_device_overrides( + subset, + previous_forward_devices, + fallback_modules=full, + ) + if forward_pb is not None: + forward_pb.close() + + returned_outputs = None + if execute_forward and capture_layer_forward_context: + processor.receive_layer_forward_context( + layer_index=layer_index, layer_inputs=layer_inputs, layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - shared_kv_cache_dict=shared_kv_cache_dict, - layer_index=layer_index, - need_outputs=need_outputs, - reuse_kv=reuse_kv, - progress_pb=forward_pb, - progress_title=forward_msg, - progress_stage="Forward", - progress_rows_per_batch=forward_row_counts, - progress_total_rows=forward_total_rows, - force_serial=subset_forward_serial, - preserve_module_devices=preserve_devices, + layer_outputs=forward_outputs, + subset_index=subset_index, + subset_total=subset_total, ) - finally: - if forward_device_map: - looper._restore_forward_device_overrides( - subset, - previous_forward_devices, - fallback_modules=full, - ) - if forward_pb is not None: - forward_pb.close() - - returned_outputs = None - if need_outputs: + + if execute_forward and plan.need_forward_outputs: + # For pre-process consumers, the next stage needs the forward outputs + # immediately rather than after the later layer replay step. processor.receive_layer_inputs(forward_outputs) if return_outputs: returned_outputs = processor.inputs_cache.layer_inputs del forward_outputs - - if subset_event_cb: - subset_event_cb(stage="forward_end", layer_idx=layer_index, subset_index=subset_index, subset_total=subset_total, module_names=list(subset.keys()), processor=getattr(processor, "name", type(processor).__name__)) + if execute_forward and subset_event_cb: + subset_event_cb(stage="forward_end", layer_idx=layer_index, subset_index=subset_index, subset_total=subset_total, module_names=list(subset.keys()), processor=getattr(processor, "name", type(processor).__name__)) - fwd_time = time.perf_counter() - fwd_start + fwd_time = (time.perf_counter() - fwd_start) if fwd_start is not None else 0.0 processor.set_fwd_time(fwd_time) if region_timer is not None: region_timer.record( @@ -245,32 +890,35 @@ def _run_single_subset_pass( # Detach temporary hooks to avoid leaking state into future passes. h.remove() - for name in subset: - # Reset inline hook attributes on NamedModule wrappers so future passes - # do not reuse state from this subset run. - if hasattr(subset[name], 'forward_hook'): - subset[name].forward_hook = None - subset[name].forward_hook_last = False + if execute_forward: + for name in subset: + # Reset inline hook attributes on NamedModule wrappers so future passes + # do not reuse state from this subset run. + if hasattr(subset[name], 'forward_hook'): + subset[name].forward_hook = None + subset[name].forward_hook_last = False + forward_flush_device = _resolve_forward_flush_device(plan, cur_layer_device) if looper.gptq_model.quantize_config.gc_mode == GcMode.ON_STAGE_END: - torch_sync() - torch_empty_cache() + torch_empty_cache(device=forward_flush_device, sync=True) moe_skip_modules = [] - failsafe_enabled = failsafe is not None - if isinstance(processor, GPTQProcessor) or isinstance(processor, QQQProcessor) or isinstance(processor, AWQProcessor): + if calibration_coverage_policy.validate_input_coverage: + # Coverage validation is a policy decision captured by the plan. + # The executor only applies that policy; it does not decide when the + # processor should tolerate or prune never-invoked modules. for name in subset: # Skip MoE experts that never fired; they likely lacked calibration # traffic and would produce invalid statistics. if not processor.has_captured_input_ids(name): - # only log for moe if `failsafe` is not enabled - if not failsafe_enabled: + # only log for moe if `fallback` is not enabled + if not calibration_coverage_policy.fallback_enabled: logger.error( f"`{name}` was not invoked, if it is a MoE module, it may lack sufficient calibration data routed to it. " - f"Please enable and use `failsafe` config option." + f"Please enable and use `fallback` config option." ) moe_skip_modules.append(name) - if not failsafe_enabled: + if calibration_coverage_policy.prune_uncovered_modules: for name in moe_skip_modules: skipped_module = subset.pop(name) task_map = getattr(processor, "tasks", None) @@ -279,9 +927,12 @@ def _run_single_subset_pass( # No calibration data was routed to these MoE expert modules. # We skip quantization them and record them in `qcfg.dynamic` as dynamically excluded modules. - if processor.qcfg.dynamic is None: - processor.qcfg.dynamic = {} - processor.qcfg.dynamic[f"-:{re.escape(skipped_module.full_name)}"] = {} + if calibration_coverage_policy.record_dynamic_exclusions: + if processor.qcfg.dynamic is None: + processor.qcfg.dynamic = {} + processor.qcfg.dynamic[ + f"-:{pcre.escape(skipped_module.full_name)}" + ] = {} quant_target_devices: Dict[str, torch.device] = {} for name, named_module in subset.items(): @@ -298,11 +949,15 @@ def _run_single_subset_pass( ) else: target_device = get_device(named_module.module) + if target_device == META: + target_device = cur_layer_device setattr(named_module, "target_device", target_device) setattr(named_module.module, "target_device", target_device) quant_target_devices[name] = target_device + quant_flush_device = _resolve_quant_flush_device(cur_layer_device, quant_target_devices) + processed_subset: Dict[str, NamedModule] = {} futures = [] @@ -320,6 +975,8 @@ def _process_on_worker( subset_idx: int, subset_total_count: int, ): + """Runs `processor.process()` for one module on the device worker pool.""" + module_label = getattr(nm, "full_name", getattr(nm, "name", repr(nm))) proc_name = proc.name() if hasattr(proc, "name") else type(proc).__name__ module_ref = nm.module if isinstance(nm, NamedModule) else nm @@ -327,7 +984,7 @@ def _process_on_worker( if module_weight is not None and expected_device is not None: target_device = expected_device if isinstance(expected_device, torch.device) else torch.device(expected_device) actual_device = get_device(module_weight) - assert actual_device == target_device, ( + assert actual_device == META or actual_device == target_device, ( f"Device mismatch for '{module_label}' process task: " f"module weight on {actual_device}, thread target {target_device}." ) @@ -373,11 +1030,17 @@ def _process_on_worker( return nm.name, nm for name, named_module in subset.items(): - # Launch processing for every module in the subset; tasks may run in - # parallel as allowed by the device thread pool. + # Use submit_serial for CUDA to avoid glibc pthread priority-inheritance + # assertion (Ubuntu 24.04 glibc 2.39) when concurrent CUDA threads run + # simultaneous Cholesky decompositions. tgt_dev = quant_target_devices.get(name, cur_layer_device) + _submitter = ( + DEVICE_THREAD_POOL.submit_serial + if torch.device(tgt_dev).type in ("cuda", "xpu", "mps") + else DEVICE_THREAD_POOL.submit + ) futures.append( - DEVICE_THREAD_POOL.submit( + _submitter( tgt_dev, _process_on_worker, processor, @@ -390,6 +1053,13 @@ def _process_on_worker( ) ) + _emit_moe_parallel_quant_subset_telemetry( + plan=plan, + quant_target_devices=quant_target_devices, + futures_count=len(futures), + layer_index=layer_index, + ) + for fut in futures: # Collect results in submission order so the final subset map preserves # deterministic iteration for downstream consumers. @@ -398,20 +1068,27 @@ def _process_on_worker( # Capture-only modules should not be finalized or offloaded. continue processed_subset[name] = named_module - torch_sync() - if looper.gptq_model.quantize_config.gc_mode == GcMode.ON_STAGE_END: - torch_empty_cache() + torch_empty_cache(device=quant_flush_device, sync=True) + else: + torch_sync() if subset_event_cb: subset_event_cb(stage="quant_complete", layer_idx=layer_index, subset_index=subset_index, subset_total=subset_total, module_names=list(subset.keys()), processor=getattr(processor, "name", type(processor).__name__)) - return processed_subset, returned_outputs + used_data_parallel = False + if execute_forward and forward_flush_device is None: + used_data_parallel = True + if quant_target_devices and quant_flush_device is None: + used_data_parallel = True + + return processed_subset, returned_outputs, used_data_parallel def run_subset_stage( looper: 'ModuleLooper', *, + plan: SubsetPlan, processor: LoopProcessor, module: torch.nn.Module, layer_inputs: List[List[torch.Tensor]], @@ -423,12 +1100,8 @@ def run_subset_stage( layer_descriptor: str, layer_title: str, layer_index: int, - layers_prefix: Optional[str], - subset: Dict[str, NamedModule], - subset_index: int, - subset_total: int, full, - failsafe: bool, + fallback: bool, shared_kv_cache_dict: Dict[int, torch.Tensor], pb, log=None, @@ -436,7 +1109,13 @@ def run_subset_stage( previous_processed_subset: Optional[Dict[str, NamedModule]] = None, subset_event_cb: Optional[Callable[..., None]] = None, ) -> SubsetStageResult: - """Process a single subset of modules within the layer quantization loop.""" + """Process one subset using a precomputed plan. + + The stage has three execution shapes: + - chunked MoE execution driven by `plan.module_chunks` + - one forward + quant pass for normal subsets + - quant-only execution for processors that do not need forward replay + """ logger = log or setup_logger() processor_name = processor.name() if hasattr(processor, "name") else type(processor).__name__ @@ -444,14 +1123,16 @@ def run_subset_stage( is_awq_processor = processor_name_lower.startswith("awq") def emit_subset_event(stage: str) -> None: + """Emits a normalized subset lifecycle callback when one is registered.""" + if subset_event_cb is None: return subset_event_cb( stage=stage, layer_idx=layer_index, - subset_index=subset_index, - subset_total=subset_total, - module_names=list(subset.keys()), + subset_index=plan.subset_index, + subset_total=plan.subset_total, + module_names=list(plan.modules.keys()), processor=processor_name, ) @@ -460,304 +1141,138 @@ def emit_subset_event(stage: str) -> None: logger.debug( "StageSubset[awq]: layer=%s subset=%s/%s modules=%s sample=%s", layer_index, - subset_index + 1, - subset_total, - len(subset), - list(subset.keys())[:8], + plan.subset_index + 1, + plan.subset_total, + len(plan.modules), + list(plan.modules.keys())[:8], ) else: logger.debug( "StageSubset: layer=%s subset=%s/%s processor=%s created %s modules (sample=%s)", layer_index, - subset_index + 1, - subset_total, + plan.subset_index + 1, + plan.subset_total, processor_name, - len(subset), - list(subset.keys())[:8], + len(plan.modules), + list(plan.modules.keys())[:8], ) - - moe_group_keys_all: List[str] = [] - forward_device_map: Dict[str, torch.device] = {} - subset_forward_serial = False - - attention_subset = bool(subset) and all( - looper._is_attention_module_name(name) for name in subset - ) - - moe_group_key_by_name: Dict[str, Optional[str]] = { - name: looper._extract_moe_group_key(name) - for name in subset - } - moe_module_names = [ - name for name, group_key in moe_group_key_by_name.items() - if group_key is not None - ] - moe_modules_set = set(moe_module_names) - is_moe_subset = len(moe_module_names) >= looper._moe_subset_threshold - - if is_moe_subset: - expert_groups: Dict[str, List[str]] = {} - combined_names: List[str] = list(subset.keys()) - if full is not None: - for candidate in full.keys(): - if candidate not in subset: - combined_names.append(candidate) - - for sub_name in combined_names: - # Group every expert (including ones outside the current subset) so - # load balancing decisions can span the full MoE family. - group_key = looper._extract_moe_group_key(sub_name) - if group_key is None: - continue - expert_groups.setdefault(group_key, []).append(sub_name) - - moe_group_keys_all = list(expert_groups.keys()) - - for name, named_module in subset.items(): - setattr(named_module, "moe_enabled", name in moe_modules_set) - - if looper._vram_strategy == VramStrategy.BALANCED: - devices = [ - dev for dev in looper._quant_devices - if dev is not None and getattr(dev, "type", None) != "cpu" - ] - if len(devices) > 1 and expert_groups: - assignable_group_keys: List[str] = [] - for group_key, module_names in expert_groups.items(): - suffixes = {name.rsplit(".", 1)[-1] for name in module_names} - # TODO: Need to make this configuratble and not static string based. Some moe use wN naming. - if {"gate_proj", "up_proj"}.issubset(suffixes) or {"w1", "w3"}.issubset(suffixes): - assignable_group_keys.append(group_key) - - if assignable_group_keys: - groups_per_device = max( - math.ceil(len(assignable_group_keys) / len(devices)), 1 - ) - for group_index, group_key in enumerate(assignable_group_keys): - device_idx = min(group_index // groups_per_device, len(devices) - 1) - target_device = devices[device_idx] - for module_name in expert_groups[group_key]: - forward_device_map[module_name] = target_device - - subset_forward_serial = looper._vram_strategy == VramStrategy.BALANCED - if subset_forward_serial: - active_group_count = len(moe_group_keys_all) - if active_group_count == 0: - subset_forward_serial = False - elif attention_subset and active_group_count <= looper._moe_subset_threshold: - subset_forward_serial = False - else: - for named_module in subset.values(): - setattr(named_module, "moe_enabled", False) - - auto_forward_data_parallel = getattr( - looper.gptq_model.quantize_config, - "auto_forward_data_parallel", - True, - ) - subset_forward_serial = subset_forward_serial or not auto_forward_data_parallel - - # Prepare Loop Parameters - - forward_total_rows = 1 - forward_row_counts = [] - batch_count = 0 - if processor.require_fwd: - batch_count = looper._resolve_batch_total( - getattr(processor, "num_batches", None), - layer_inputs, - ) - forward_row_counts = list(looper._collect_row_counts(layer_inputs)) - if not forward_row_counts and batch_count > 0: - forward_row_counts = [1] * batch_count - if len(forward_row_counts) > batch_count: - forward_row_counts = forward_row_counts[:batch_count] - forward_total_rows = sum(forward_row_counts) if forward_row_counts else batch_count - forward_total_rows = max(forward_total_rows, 1) - if len(forward_row_counts) < batch_count: - forward_row_counts.extend([1] * (batch_count - len(forward_row_counts))) - - # Check for MoE batching - # batch_size is only available when using ExpertsRoutingBypass routing strategy - moe_routing = looper.gptq_model.quantize_config.moe - batch_size = None - if moe_routing is not None and isinstance(moe_routing.routing, ExpertsRoutingBypass): - batch_size = moe_routing.routing.batch_size - batching_enabled = is_moe_subset and batch_size is not None and batch_size > 0 - processed_results = {} - - if batching_enabled and processor.require_fwd: - # Simply sort all module names and chunk them by batch_size - # This processes exactly batch_size MODULES per batch, not batch_size experts - sorted_module_names = sorted(subset.keys()) - # Chunk module names directly by batch_size - module_chunks = [sorted_module_names[i:i + batch_size] for i in range(0, len(sorted_module_names), batch_size)] + # Keep the helper callsite compact while still passing the fully resolved + # execution context into every chunk or single-pass invocation. + common_args = dict( + looper=looper, + processor=processor, + module=module, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + layer_descriptor=layer_descriptor, + layer_title=layer_title, + layer_index=layer_index, + full=full, + fallback=fallback, + shared_kv_cache_dict=shared_kv_cache_dict, + pb=pb, + logger=logger, + is_awq_processor=is_awq_processor, + region_timer=region_timer, + previous_processed_subset=previous_processed_subset, + ) + # Once a plan exists, subset execution is just a dispatch over the plan's + # shape rather than another round of subset analysis. + if plan.batching_enabled: if DEBUG_ON and logger.isEnabledFor(logging.DEBUG): logger.debug( - f"MoE Expert Batching Enabled: Processing {len(sorted_module_names)} modules in {len(module_chunks)} batches " - f"(batch_size={batch_size} modules per batch)." + "MoE Expert Batching Enabled: Processing %s modules in %s batches.", + len(plan.modules), + len(plan.module_chunks), ) # Create progress bar for MOE chunks - moe_chunk_pb = logger.pb(range(len(module_chunks))).manual() + moe_chunk_pb = logger.pb(range(len(plan.module_chunks))).manual() moe_chunk_pb.title(f"MoE Chunk") for chunk_idx in moe_chunk_pb: - chunk_keys = module_chunks[chunk_idx] - # Create subset for this chunk - chunk_subset = {k: subset[k] for k in chunk_keys} + chunk_plan = plan.for_modules(plan.module_chunks[chunk_idx]) - moe_chunk_pb.subtitle(f"({len(chunk_subset)} modules)").draw() + moe_chunk_pb.subtitle(f"({len(chunk_plan.modules)} modules)").draw() if DEBUG_ON and logger.isEnabledFor(logging.DEBUG): - logger.debug(f"Processing MoE Chunk {chunk_idx+1}/{len(module_chunks)} ({len(chunk_subset)} modules)...") + logger.debug( + "Processing MoE Chunk %s/%s (%s modules)...", + chunk_idx + 1, + len(plan.module_chunks), + len(chunk_plan.modules), + ) - # Run pass - chunk_result, _ = _run_single_subset_pass( - looper=looper, - processor=processor, - module=module, - subset=chunk_subset, - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - layer_descriptor=layer_descriptor, - layer_title=layer_title, - layer_index=layer_index, - subset_index=subset_index, - subset_total=subset_total, - full=full, - failsafe=failsafe, - shared_kv_cache_dict=shared_kv_cache_dict, - pb=pb, - logger=logger, - is_awq_processor=is_awq_processor, - forward_total_rows=forward_total_rows, - forward_row_counts=forward_row_counts, - batch_count=batch_count, - forward_device_map=forward_device_map, - subset_forward_serial=subset_forward_serial, - region_timer=region_timer, - previous_processed_subset=previous_processed_subset, + chunk_result, _, chunk_used_data_parallel = _run_single_subset_pass( + **common_args, + plan=chunk_plan, subset_event_cb=None, - return_outputs=False, + return_outputs=False, ) processed_results.update(chunk_result) - + # Force cleanup between chunks if looper.gptq_model.quantize_config.gc_mode == GcMode.ON_STAGE_END: - torch_empty_cache() + flush_device = None if chunk_used_data_parallel else cur_layer_device + torch_empty_cache(device=flush_device) # Close MOE chunks progress bar moe_chunk_pb.close() - # If processor.fwd_after_process is False, stage_layer won't run replay. - # But we haven't collected proper full outputs yet (we ignored them or they were partial). - # So we MUST run a replay here to get valid layer_inputs for the next layer. - if not processor.fwd_after_process: - # Final Replay to collect layer outputs - _, new_layer_inputs = _run_single_subset_pass( - looper=looper, - processor=processor, - module=module, - subset={}, # Empty subset prevents quantization/hooks - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - layer_descriptor=layer_descriptor, - layer_title=layer_title, - layer_index=layer_index, - subset_index=subset_index, - subset_total=subset_total, - full=full, - failsafe=failsafe, - shared_kv_cache_dict=shared_kv_cache_dict, - pb=pb, - logger=logger, - is_awq_processor=is_awq_processor, - forward_total_rows=forward_total_rows, - forward_row_counts=forward_row_counts, - batch_count=batch_count, - forward_device_map=forward_device_map, - subset_forward_serial=subset_forward_serial, - region_timer=region_timer, - previous_processed_subset=previous_processed_subset, - subset_event_cb=None, - return_outputs=True, - disable_moe_hooks=True, - ) + # Chunked execution does not produce a single coherent next-layer input + # stream while chunks are being processed. When replay is not deferred + # to the layer stage, the subset stage must do one final replay here to + # rebuild the real layer outputs. + if not plan.replay_after_process: + replay_plan = plan.for_modules({}) + _, new_layer_inputs, _ = _run_single_subset_pass( + **common_args, + plan=replay_plan, # Empty modules prevent quant hooks during replay. + subset_event_cb=None, + return_outputs=True, + disable_moe_hooks=True, + ) if new_layer_inputs is not None: layer_inputs = new_layer_inputs - - elif processor.require_fwd: + + elif plan.execute_forward: # Single pass - processed_results, new_layer_inputs = _run_single_subset_pass( - looper=looper, - processor=processor, - module=module, - subset=subset, - layer_inputs=layer_inputs, - layer_input_kwargs=layer_input_kwargs, - position_ids=position_ids, - attention_masks=attention_masks, - cur_layer_device=cur_layer_device, - is_lm_head_module=is_lm_head_module, - layer_descriptor=layer_descriptor, - layer_title=layer_title, - layer_index=layer_index, - subset_index=subset_index, - subset_total=subset_total, - full=full, - failsafe=failsafe, - shared_kv_cache_dict=shared_kv_cache_dict, - pb=pb, - logger=logger, - is_awq_processor=is_awq_processor, - forward_total_rows=forward_total_rows, - forward_row_counts=forward_row_counts, - batch_count=batch_count, - forward_device_map=forward_device_map, - subset_forward_serial=subset_forward_serial, - region_timer=region_timer, - previous_processed_subset=previous_processed_subset, + processed_results, new_layer_inputs, _ = _run_single_subset_pass( + **common_args, + plan=plan, subset_event_cb=subset_event_cb, return_outputs=True, ) if new_layer_inputs is not None: layer_inputs = new_layer_inputs else: - # No forward required + # No forward required; still run process() for each module. if DEBUG_ON: logger.debug( "StageSubset: processor=%s layer=%s subset=%s/%s skipping forward (require_fwd=False)", processor_name, layer_index, - subset_index + 1, - subset_total, + plan.subset_index + 1, + plan.subset_total, ) emit_subset_event("forward_start") emit_subset_event("forward_end") - emit_subset_event("quant_start") - emit_subset_event("quant_complete") - - context = SubsetForwardContext( - subset=subset, - forward_device_map=forward_device_map, - subset_forward_serial=subset_forward_serial, - subset_total=subset_total, - subset_index=subset_index, - ) + processed_results, _, _ = _run_single_subset_pass( + **common_args, + plan=plan, + subset_event_cb=subset_event_cb, + return_outputs=False, + execute_forward=False, + ) return SubsetStageResult( processed_subset=processed_results, layer_inputs=layer_inputs, - forward_context=context, + plan=plan, ) diff --git a/gptqmodel/looper/weight_only_looper.py b/gptqmodel/looper/weight_only_looper.py new file mode 100644 index 000000000..5a93278ca --- /dev/null +++ b/gptqmodel/looper/weight_only_looper.py @@ -0,0 +1,263 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Weight-only quantization loop for methods that do not capture activations. + +This looper intentionally does not share the activation-capture lifecycle used +by GPTQ/AWQ calibration flows. Weight-only methods such as RTN, FP8, NVFP4, or +GGUF can usually process each linear layer directly, so the control flow here +stays narrow: iterate quantizable modules, quantize weights, finalize, and +optionally offload. +""" + +from __future__ import annotations + +from typing import Dict, Optional + +import torch +from defuser.modeling.replace_modules import materialize_model + +from ..looper.module_preprocessor import ModulePreProcessor +from ..looper.weight_only_processor import WeightOnlyProcessor +from ..looper.named_module import NamedModule +from ..models import BaseQModel +from ..models._const import CPU, SUPPORTS_MODULE_TYPES +from ..nn_modules.converter import MODULE_CONVERTER_MAP +from ..quantization.config import BitsAndBytesConfig, FP8Config, GGUFConfig, RTNConfig +from ..utils.logger import setup_logger +from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to +from ..utils.offload import offload_to_disk + + +log = setup_logger() + + +class WeightOnlyLooper: + """Run the simplified per-layer lifecycle for weight-only quantization.""" + + def __init__(self, model: BaseQModel, processor: WeightOnlyProcessor): + """Initializes the looper with the model being quantized and its processor.""" + + self.gptq_model = model + self.processor = processor + + def _resolve_named_module( + self, + *, + layer_module: torch.nn.Module, + full: Dict[str, torch.nn.Module], + layer_index: int, + layers_prefix: Optional[str], + module_name: str, + is_lm_head_module: bool, + ) -> Optional[NamedModule]: + """Resolve a quantizable submodule and normalize it into a NamedModule.""" + resolved = full.get(module_name) + if resolved is None: + resolved, _ = get_module_by_name_prefix(layer_module, module_name) + if resolved is None: + if self.gptq_model.layer_modules_strict: + raise ValueError(f"layer module item `{module_name}` not found in model, please check your model config.") + return None + + if isinstance(resolved, NamedModule): + return resolved + + layer_name = self.gptq_model.lm_head if is_lm_head_module else f"{layers_prefix}.{layer_index}.{module_name}" + named = NamedModule( + resolved, + name=module_name, + full_name=layer_name, + layer_index=layer_index, + ) + full[module_name] = named + return named + + def _offload_quantized_module(self, module: NamedModule) -> None: + """Persist an already-quantized module to disk when offload is enabled.""" + quant_config = getattr(self.gptq_model, "quantize_config", None) + if not quant_config or not getattr(quant_config, "offload_to_disk", False): + return + offload_path = getattr(quant_config, "offload_to_disk_path", None) + if not offload_path: + return + + module_full_name = getattr(module, "full_name", None) + target_module = ( + self.gptq_model.model.get_submodule(module_full_name) + if module_full_name + else module + ) + offload_to_disk( + model=self.gptq_model.model, + module=target_module, + disk_path=offload_path, + ) + + def loop(self, **kwargs): + """Quantize layers directly from weights without calibration forwards.""" + quant_config = self.gptq_model.quantize_config + if not isinstance(quant_config, (RTNConfig, GGUFConfig, FP8Config, BitsAndBytesConfig)): + raise NotImplementedError( + "Weight-only looper only supports `RTNConfig`, `GGUFConfig`, " + "`FP8Config`, and `BitsAndBytesConfig` today." + ) + + if quant_config.lm_head: + if self.gptq_model.model.config.tie_word_embeddings and hasattr(self.gptq_model.model.model, "_tied_weights_keys"): + tied_keys = self.gptq_model.model._tied_weights_keys + for item in tied_keys: + if self.gptq_model.lm_head in item: + raise NotImplementedError( + "quantization of `lm_head` layer with `tied_weights=True` model state is not supported. Please check model has `tied_weights=False`." + ) + + lm_head_module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head) + if lm_head_module is None: + raise ValueError(f"could not find layer {self.gptq_model.lm_head} in the model, exit...") + if not isinstance(lm_head_module, tuple(SUPPORTS_MODULE_TYPES)): + raise NotImplementedError( + f"This type({type(lm_head_module)}) of lm_head quantization is currently not supported. SUPPORTS_MODULE_TYPES is {SUPPORTS_MODULE_TYPES}" + ) + + forward_pass_use_cache = ( + self.gptq_model.model.config.use_cache + if hasattr(self.gptq_model.model.config, "use_cache") + else False + ) + # No calibration forwards are executed here, but disabling cache keeps + # behavior aligned with the standard quantization path and avoids stale + # decoder-cache state while layers are being replaced. + self.gptq_model.model.config.use_cache = False + + layers, layers_prefix = get_module_by_name_prefix( + self.gptq_model.model, + self.gptq_model.extract_layers_node(), + ) + + if quant_config.offload_to_disk: + log.info("Offloading base modules to disk...") + offload_to_disk( + model=self.gptq_model.model, + module=self.gptq_model.get_base_modules(model=self.gptq_model.model), + disk_path=quant_config.offload_to_disk_path, + ) + + layer_modules = self.gptq_model.simple_layer_modules( + model_config=self.gptq_model.model.config, + quantize_config=quant_config, + is_awq_quantize=False, + include_capture_only=False, + ) + if not quant_config.true_sequential: + layer_modules = [sum(layer_modules, [])] + + layer_count = len(layers) + total_layers = layer_count + (1 if quant_config.lm_head else 0) + preprocessor = None + if getattr(quant_config, "preprocessors", None): + preprocessor = ModulePreProcessor( + tokenizer=self.gptq_model.tokenizer, + qcfg=quant_config, + calibration=None, + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + calibration_concat_separator=None, + batch_size=1, + ) + + try: + for layer_index in range(total_layers): + is_lm_head_module = layer_index >= layer_count + + # Transformer blocks and lm_head follow the same weight-only + # lifecycle, but lm_head is resolved from the root model. + if is_lm_head_module: + module = get_module(self.gptq_model.model, key=self.gptq_model.lm_head) + subsets = [[self.gptq_model.lm_head]] + else: + module = layers[layer_index] + subsets = layer_modules + + module = self.gptq_model.pre_quantize(module) + if not is_lm_head_module: + # Preserve existing module conversion behavior so the new + # lifecycle stays compatible with model-specific wrappers. + model_type = self.gptq_model.model.config.model_type + if model_type in MODULE_CONVERTER_MAP: + converter = MODULE_CONVERTER_MAP[model_type] + module = converter(module, self.gptq_model.model.config) + layers[layer_index] = module + + # Resolve concrete submodules after any pre-quantization + # transforms so quantization targets the final layer layout. + materialize_model(module) + full = find_modules(module, name=self.gptq_model.lm_head if is_lm_head_module else "") + + self.processor.collect_memory_info(layer_index) + for subset_names in subsets: + for module_name in subset_names: + named = self._resolve_named_module( + layer_module=module, + full=full, + layer_index=layer_index, + layers_prefix=layers_prefix, + module_name=module_name, + is_lm_head_module=is_lm_head_module, + ) + if named is None: + continue + + if preprocessor is not None: + preprocessor.preprocess(named) + if isinstance(named.state.get("auto_module_decoder"), dict): + prepared = self.gptq_model.shell_module_materialize( + target_submodule=named.module, + device=CPU, + role="quant_source", + named_module=named, + ) + if prepared is not named.module: + named.module = prepared + + # Weight-only quantization happens entirely within the + # processor; no captured activations are needed. + active_qcfg = self.processor.quantize_module(named) + if active_qcfg is None: + continue + + # Finalization and optional disk offload expect the + # packed module to be back on CPU memory. + move_to(named.module, device=CPU) + named.target_device = CPU + named.module.target_device = CPU + + self.processor.submodule_finalize( + named, + self.gptq_model, + qcfg=active_qcfg, + ) + self._offload_quantized_module(named) + + # Submodule-level offload may swap packed tensors to meta/disk placeholders. + # Skip the layer-wide CPU move in that case to avoid `.to()` on meta buffers. + if getattr(self.gptq_model.quantize_config, "offload_to_disk", False): + if not is_lm_head_module: + layers[layer_index] = module + elif is_lm_head_module: + self.gptq_model.post_quantize(module) + else: + layers[layer_index] = self.gptq_model.post_quantize(module) + finally: + self.gptq_model.model.config.use_cache = forward_pass_use_cache + + total_log = {self.processor.name(): self.processor.log} + self.gptq_model.quant_log = self.processor.log + self.processor.finalize(model=self.gptq_model) + return total_log + + +__all__ = ["WeightOnlyLooper"] diff --git a/gptqmodel/looper/weight_only_processor.py b/gptqmodel/looper/weight_only_processor.py new file mode 100644 index 000000000..b0faf125e --- /dev/null +++ b/gptqmodel/looper/weight_only_processor.py @@ -0,0 +1,266 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import threading +import time +from typing import Optional + +import torch + +from ..looper.loop_processor import DTYPE_SIZE_COLUMN, ExecutionConfig, MODULE_FEATURE_COLUMN, LoopProcessor +from ..looper.named_module import NamedModule +from ..models import BaseQModel +from ..models._const import CPU +from ..models.writer import ( + PROCESS_LOG_FWD_TIME, + PROCESS_LOG_LAYER, + PROCESS_LOG_MODULE, + PROCESS_LOG_NAME, + PROCESS_LOG_TIME, + PROCESS_USED_MEMORY, + QUANT_LOG_DAMP, + QUANT_LOG_LOSS, + QUANT_LOG_NSAMPLES, +) +from ..quantization.config import ( + BitsAndBytesConfig, + FP8Config, + GGUFConfig, + METHOD, + RTNConfig, + clone_weight_only_config_for_module, + resolve_quant_format, +) +from ..quantization.rtn import RTN +from ..utils.logger import log_time_block, setup_logger +from ..utils.model import create_quant_module, find_modules, pack_module +from ..utils.module_locks import parent_module_lock + + +log = setup_logger() + + +class WeightOnlyProcessor(LoopProcessor): + """Process weight-only modules without entering activation-based quantization flows.""" + + def __init__( + self, + tokenizer, + qcfg: RTNConfig | GGUFConfig | FP8Config | BitsAndBytesConfig, + ): + """Initializes a weight-only processor for RTN, GGUF, FP8, or BitsAndBytes.""" + + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=None, + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + calibration_concat_separator=None, + batch_size=1, + execution_config=ExecutionConfig( + require_fwd=False, + fwd_replay_after_process=False, + ), + ) + self.lock = threading.Lock() + + @staticmethod + def _uses_direct_pack(qcfg: RTNConfig | GGUFConfig | FP8Config | BitsAndBytesConfig) -> bool: + """Returns whether the method packs directly from the original dense weights.""" + + return qcfg.method in {METHOD.GGUF, METHOD.FP8, METHOD.BITSANDBYTES} + + def _update_logged_loss(self, module: NamedModule, avg_loss: str) -> None: + """Backfills the logged loss field after late dequant-error measurement.""" + + with self.lock: + for entry in reversed(self.log): + if entry.get(PROCESS_LOG_LAYER) == module.layer_index and entry.get(PROCESS_LOG_MODULE) == module.name: + entry[QUANT_LOG_LOSS] = avg_loss + return + + def quantize_module( + self, + module: NamedModule, + ) -> Optional[RTNConfig | GGUFConfig | FP8Config | BitsAndBytesConfig]: + """Clones per-module config, quantizes weights, and logs the result.""" + + qcfg_clone = clone_weight_only_config_for_module(self.qcfg, module.full_name) + if qcfg_clone is None: + return None + + if self._uses_direct_pack(qcfg_clone): + start_time = time.time() + duration = time.time() - start_time + avg_loss = f"{qcfg_clone.method.value}: pending" + damp_percent = 0.0 + nsamples = 0 + else: + task = RTN(module=module, qcfg=qcfg_clone) + wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = task.quantize() + + module.stream_state_payload_to_cpu( + { + "q_scales": q_scales, + "q_zeros": q_zeros, + "q_g_idx": q_g_idx, + }, + ) + del q_scales, q_zeros, q_g_idx + + stat = { + PROCESS_LOG_NAME: self.name(), + PROCESS_LOG_LAYER: module.layer_index, + PROCESS_LOG_MODULE: module.name, + MODULE_FEATURE_COLUMN: self.module_feature_summary(module), + DTYPE_SIZE_COLUMN: self.module_dtype_size_summary(module), + QUANT_LOG_LOSS: avg_loss if isinstance(avg_loss, str) else f"{avg_loss:.10f}", + QUANT_LOG_NSAMPLES: f"{nsamples}", + QUANT_LOG_DAMP: f"{damp_percent:.5f}", + PROCESS_LOG_TIME: f"{duration:.3f}", + PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(), + PROCESS_USED_MEMORY: self.device_memory_report(), + "lifecycle": "weight_only", + } + + with self.lock: + self.log.append(stat) + self.log_new_row(stat) + + if not self._uses_direct_pack(qcfg_clone): + module.weight.data = wq + return qcfg_clone + + def submodule_finalize( + self, + module: NamedModule, + model: BaseQModel, + *, + qcfg: Optional[RTNConfig | GGUFConfig | FP8Config | BitsAndBytesConfig] = None, + **kwargs, + ): + """Creates and packs the final quantized module into the model graph.""" + + active_qcfg = qcfg or self.qcfg + if not self._uses_direct_pack(active_qcfg): + module.stream_sync() + with self.lock: + q_zeros = module.state.pop("q_zeros").clone() + q_scales = module.state.pop("q_scales").clone() + q_g_idx = module.state.pop("q_g_idx").clone() + + assert q_zeros.device == CPU + assert q_scales.device == CPU + assert q_g_idx.device == CPU + + layers = find_modules(model.model) + module_label = getattr(module, "full_name", getattr(module, "name", "")) + parent_key = getattr(module, "full_name", getattr(module, "name", None)) + original_layer = layers.get(module.full_name) + timer = getattr(model, "quant_region_timer", None) + + create_start = time.perf_counter() if timer is not None else None + with log_time_block("create_quant_module", logger=log, module_name=module_label): + with parent_module_lock(parent_key): + create_quant_module( + name=module.full_name, + linear_cls=model.qlinear_kernel, + bits=active_qcfg.runtime_bits, + desc_act=active_qcfg.desc_act, + dynamic=active_qcfg.dynamic, + group_size=active_qcfg.group_size, + module=model.model, + submodule=module, + sym=active_qcfg.sym, + device=active_qcfg.device, + lm_head_name=model.lm_head, + pack_dtype=active_qcfg.pack_dtype, + format=resolve_quant_format(active_qcfg.format, active_qcfg.method), + register_buffers=False, + init_kwargs=active_qcfg.quant_linear_init_kwargs(), + ) + if timer is not None and create_start is not None: + timer.record("submodule_finalize_create", time.perf_counter() - create_start, source=module_label) + + qmodules = { + name: submodule + for name, submodule in find_modules(model.model, [model.qlinear_kernel]).items() + if name == module.full_name + } + + if self._uses_direct_pack(active_qcfg): + pack_start = time.perf_counter() if timer is not None else None + with log_time_block("module.pack_original", logger=log, module_name=module_label): + with parent_module_lock(parent_key): + qmodule = qmodules[module.full_name] + qmodule.pack_original( + linear=original_layer, + scales=None, + zeros=None, + g_idx=None, + smooth=active_qcfg.smooth, + ) + if timer is not None and pack_start is not None: + timer.record( + "submodule_finalize_pack", + time.perf_counter() - pack_start, + source=f"{module_label} [module.pack_original]", + ) + + reference_weight = qmodule._weight_to_matrix(original_layer).detach().cpu().to(torch.float32) + dequant_weight = qmodule.dequantize_weight().T.detach().cpu().to(torch.float32) + mean_abs_err = (dequant_weight - reference_weight).abs().mean().item() + self._update_logged_loss(module, f"{active_qcfg.method.value}: {mean_abs_err:.7f}") + module.state.pop("tp_pad_info", None) + module.state.pop("quant_source_module", None) + module.unregister_parameter("weight") + return + + pack_start = time.perf_counter() if timer is not None else None + with log_time_block("pack", logger=log, module_name=module_label): + with parent_module_lock(parent_key): + packer_label = pack_module( + name=module.full_name, + qModules=qmodules, + q_scales=q_scales, + q_zeros=q_zeros, + q_g_idx=q_g_idx, + layers=layers, + quant_linear_cls=model.qlinear_kernel, + lock=self.lock, + quantize_config=active_qcfg, + ) + if timer is not None and pack_start is not None: + timer.record( + "submodule_finalize_pack", + time.perf_counter() - pack_start, + source=f"{module_label} [{packer_label or 'module.pack_original'}]", + ) + + del q_scales, q_zeros, q_g_idx + module.state.pop("tp_pad_info", None) + module.state.pop("quant_source_module", None) + module.unregister_parameter("weight") + + def finalize(self, model: BaseQModel, **kwargs): + """Marks the model quantized and runs shared processor finalization.""" + + model.quantized = True + super().finalize(model=model, **kwargs) + + def name(self) -> str: + """Returns the method-specific processor label used in logs.""" + + if self.qcfg.method == METHOD.GGUF: + return "weight_only_gguf" + if self.qcfg.method == METHOD.FP8: + return "weight_only_fp8" + if self.qcfg.method == METHOD.BITSANDBYTES: + return "weight_only_bitsandbytes" + return "weight_only_rtn" + +__all__ = ["WeightOnlyProcessor"] diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py index 2085d44bb..fc7915d88 100644 --- a/gptqmodel/models/_const.py +++ b/gptqmodel/models/_const.py @@ -83,7 +83,7 @@ def validate_cuda_support(raise_exception: bool = False): if not at_least_one_cuda_v6: if raise_exception: raise EnvironmentError( - "GPTQModel cuda requires Pascal or later gpu with compute capability >= `6.0`.") + "GPT-QModel cuda requires Pascal or later gpu with compute capability >= `6.0`.") else: got_cuda = False @@ -128,6 +128,4 @@ def get_best_device(backend: BACKEND = BACKEND.AUTO) -> torch.device: else: return CPU -EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048 - EXPERT_INDEX_PLACEHOLDER = "{expert_index}" diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index b53183490..cc0941705 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -6,6 +6,7 @@ from __future__ import annotations import os +from contextlib import contextmanager from ..utils.logger import setup_logger @@ -13,15 +14,9 @@ log = setup_logger() ASCII_LOGO = r""" -_____/\\\\\\\\\\\\__/\\\\\\\\\\\\\____/\\\\\\\\\\\\\\\______________________/\\\________/\\\\____________/\\\\_______________________/\\\__________________/\\\\\\____ - ___/\\\//////////__\/\\\/////////\\\_\///////\\\/////____________________/\\\\/\\\\____\/\\\\\\________/\\\\\\______________________\/\\\_________________\////\\\____ - __/\\\_____________\/\\\_______\/\\\_______\/\\\_______________________/\\\//\////\\\__\/\\\//\\\____/\\\//\\\______________________\/\\\____________________\/\\\____ - _\/\\\____/\\\\\\\_\/\\\\\\\\\\\\\/________\/\\\________/\\\\\\\\\\\__/\\\______\//\\\_\/\\\\///\\\/\\\/_\/\\\_____/\\\\\___________\/\\\______/\\\\\\\\_____\/\\\____ - _\/\\\___\/////\\\_\/\\\/////////__________\/\\\_______\///////////__\//\\\______/\\\__\/\\\__\///\\\/___\/\\\___/\\\///\\\____/\\\\\\\\\____/\\\/////\\\____\/\\\____ - _\/\\\_______\/\\\_\/\\\___________________\/\\\______________________\///\\\\/\\\\/___\/\\\____\///_____\/\\\__/\\\__\//\\\__/\\\////\\\___/\\\\\\\\\\\_____\/\\\____ - _\/\\\_______\/\\\_\/\\\___________________\/\\\________________________\////\\\//_____\/\\\_____________\/\\\_\//\\\__/\\\__\/\\\__\/\\\__\//\\///////______\/\\\____ - _\//\\\\\\\\\\\\/__\/\\\___________________\/\\\___________________________\///\\\\\\__\/\\\_____________\/\\\__\///\\\\\/___\//\\\\\\\/\\__\//\\\\\\\\\\__/\\\\\\\\\_ - __\////////////____\///____________________\///______________________________\//////___\///______________\///_____\/////______\///////\//____\//////////__\/////////__ +┌─────────────┐ ┌────────────────────────┐ ┌────────────┐ ┌─────────┐ +│ GPT-QModel │ -> │ ▓▓▓▓▓▓▓▓▓▓▓▓ 16bit │ -> │ ▒▒▒▒ 8bit │ -> │ ░░ 4bit │ +└─────────────┘ └────────────────────────┘ └────────────┘ └─────────┘ """ # if not os.environ.get("PYTHON_GIL", None): @@ -52,26 +47,29 @@ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import os.path # noqa: E402 -import random # noqa: E402 from os.path import isdir, join # noqa: E402 -from typing import Any, Dict, List, Optional, Type, Union # noqa: E402 +from typing import Dict, List, Optional, Union # noqa: E402 -import numpy # noqa: E402 import torch # noqa: E402 -from huggingface_hub import list_repo_files # noqa: E402 from packaging.version import Version # noqa: E402 -from tokenicer import Tokenicer # noqa: E402 -from transformers import AutoConfig, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402 +from transformers import AutoConfig, PreTrainedTokenizerBase # noqa: E402 from transformers import __version__ as TRANSFORMERS_VERSION from ..adapter.adapter import Adapter, Lora, normalize_adapter # noqa: E402 -from ..nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 -from ..quantization import METHOD, QUANT_CONFIG_FILENAME # noqa: E402 -from ..utils import BACKEND # noqa: E402 -from ..utils.eval import EVAL # noqa: E402 +from ..nn_modules.qlinear.torch import TorchLinear # noqa: E402 +from ..quantization import METHOD, QUANT_CONFIG_FILENAME, QuantizeConfig # noqa: E402 +from ..utils import BACKEND, PROFILE # noqa: E402 +from ..utils.backend import normalize_backend, normalize_profile # noqa: E402 +from ..utils.hf import ( # noqa: E402 + get_hf_gguf_load_kwargs, + normalize_model_id_or_path_for_hf_gguf, + normalize_torch_dtype_kwarg, + resolve_trust_remote_code, +) +from ..utils.hub import list_repo_files # noqa: E402 from ..utils.model import find_modules # noqa: E402 from ..utils.torch import CPU, torch_empty_cache # noqa: E402 -from .base import BaseQModel, QuantizeConfig # noqa: E402 +from .base import BaseQModel # noqa: E402 from .definitions.afmoe import AfMoeQModel # noqa: E402 from .definitions.apertus import ApertusQModel # noqa: E402 from .definitions.baichuan import BaiChuanQModel # noqa: E402 @@ -94,9 +92,12 @@ from .definitions.falcon_h1 import FalconH1QModel # noqa: E402 from .definitions.gemma2 import Gemma2QModel # noqa: E402 from .definitions.gemma3 import Gemma3ForConditionalGenerationGPTQ, Gemma3QModel # noqa: E402 +from .definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel # noqa: E402 from .definitions.glm import GlmQModel # noqa: E402 from .definitions.glm4_moe import GLM4MoEGPTQ # noqa: E402 +from .definitions.glm4_moe_lite import Glm4MoeLiteQModel # noqa: E402 from .definitions.glm4v import Glm4vGPTQ # noqa: E402 +from .definitions.glm_moe_dsa import GlmMoeDsaQModel # noqa: E402 from .definitions.gpt2 import GPT2QModel # noqa: E402 from .definitions.gpt_bigcode import GptBigCodeQModel # noqa: E402 from .definitions.gpt_neo import GptNeoQModel # noqa: E402 @@ -119,6 +120,8 @@ from .definitions.mimo import MimoQModel # noqa: E402 from .definitions.minicpm import MiniCPMGPTQ # noqa: E402 from .definitions.minicpm3 import MiniCpm3QModel # noqa: E402 +from .definitions.minicpm_o import MiniCPMOQModel # noqa: E402 +from .definitions.minicpm_v import MiniCPMVQModel # noqa: E402 from .definitions.minimax_m2 import MiniMaxM2GPTQ # noqa: E402 from .definitions.mistral3 import Mistral3GPTQ from .definitions.mixtral import MixtralQModel # noqa: E402 @@ -161,11 +164,6 @@ Qwen3_5_MoeQModel = None -# make quants and inference more determinisitc -torch.manual_seed(787) -random.seed(787) -numpy.random.seed(787) - MODEL_MAP = { "apertus": ApertusQModel, "dream": DreamQModel, @@ -186,6 +184,8 @@ "glm4": GlmQModel, "glm4v": Glm4vGPTQ, "glm4_moe": GLM4MoEGPTQ, + "glm4_moe_lite": Glm4MoeLiteQModel, + "glm_moe_dsa": GlmMoeDsaQModel, "gpt_bigcode": GptBigCodeQModel, "codegen": CodeGenQModel, "cohere": LlamaQModel, # 100% llama clone @@ -213,6 +213,8 @@ "gemma2": Gemma2QModel, "gemma3_text": Gemma3QModel, "gemma3": Gemma3ForConditionalGenerationGPTQ, + "gemma4_text": Gemma4TextQModel, + "gemma4": Gemma4ForConditionalGenerationGPTQ, "phi": PhiQModel, "phi3": Phi3QModel, "phi4mm": Phi4MMGPTQ, @@ -220,6 +222,8 @@ "mpt": MptQModel, "minicpm": MiniCPMGPTQ, "minicpm3": MiniCpm3QModel, + "minicpmo": MiniCPMOQModel, + "minicpmv": MiniCPMVQModel, "minimax": MiniMaxM2GPTQ, "minimax_m2": MiniMaxM2GPTQ, "qwen2_moe": Qwen2MoeQModel, @@ -279,32 +283,127 @@ SUPPORTED_MODELS = list(MODEL_MAP.keys()) +def _activation_quantization_mode(quantization_config: dict) -> Optional[str]: + """Return the first activation-quantization field that makes this config unsupported. + + GPT-QModel can load weight-only quantized checkpoints through the Transformers + surface, but it does not currently implement activation-quantized runtime + semantics. This helper keeps the rejection logic in one place for both + ModelOpt-style grouped configs and flatter HF quantization payloads. + """ + + config_groups = quantization_config.get("config_groups") + if isinstance(config_groups, dict): + for group_cfg in config_groups.values(): + if not isinstance(group_cfg, dict): + continue + input_activations = group_cfg.get("input_activations") + if isinstance(input_activations, dict) and input_activations: + return "input_activations" + + kv_cache_scheme = quantization_config.get("kv_cache_scheme") + if isinstance(kv_cache_scheme, dict) and kv_cache_scheme: + return "kv_cache_scheme" + + for key in ("input_activations", "activation_quantization", "activations"): + value = quantization_config.get(key) + if isinstance(value, dict) and value: + return key + return None + + def _is_supported_quantization_config(config: AutoConfig) -> bool: quantization_config = getattr(config, "quantization_config", None) if not isinstance(quantization_config, dict): return False + # Fail fast before model selection so activation-quantized checkpoints do + # not accidentally proceed down a weight-only loader path. + unsupported_mode = _activation_quantization_mode(quantization_config) + if unsupported_mode is not None: + log.error("GPT-QModel currently does not support loading of activation quantized models") + raise ValueError( + "GPT-QModel currently does not support loading of activation quantized models. " + f"Detected unsupported metadata: {unsupported_mode}." + ) + quant_format = quantization_config.get("quant_format") if isinstance(quant_format, str) and quant_format.lower() in ( METHOD.GPTQ, + METHOD.GGUF, + METHOD.FP8, + METHOD.BITSANDBYTES, METHOD.AWQ, + METHOD.PARO, METHOD.QQQ, + METHOD.EXL3, ): return True - quant_method = quantization_config.get("quant_method") - if isinstance(quant_method, str) and quant_method.lower() in ( + method = quantization_config.get("method", quantization_config.get("quant_method")) + if isinstance(method, str) and method.lower() in ( METHOD.GPTQ, + METHOD.GGUF, + METHOD.FP8, + METHOD.BITSANDBYTES, METHOD.AWQ, + METHOD.PARO, METHOD.QQQ, + METHOD.EXL3, ): return True return False -def check_and_get_model_definition(model_dir, trust_remote_code=False): - config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code) +@contextmanager +def _hide_unsupported_quantization_config_for_eval(model): + config = getattr(model, "config", None) + if config is None: + yield + return + + quantization_config = getattr(config, "quantization_config", None) + if not isinstance(quantization_config, dict): + yield + return + + try: + from transformers.quantizers import AutoQuantizationConfig + + AutoQuantizationConfig.from_dict(dict(quantization_config)) + except Exception: + pass + else: + yield + return + + setattr(config, "quantization_config", None) + try: + yield + finally: + setattr(config, "quantization_config", quantization_config) + + +@contextmanager +def _hide_unsupported_quantization_config_for_lm_eval(model): + with _hide_unsupported_quantization_config_for_eval(model): + yield + + +def _get_config_load_kwargs(kwargs: dict) -> dict: + return get_hf_gguf_load_kwargs(kwargs) + + +def check_and_get_model_definition(model_dir, trust_remote_code=False, **config_load_kwargs): + if "gguf_file" not in config_load_kwargs: + model_dir = normalize_model_id_or_path_for_hf_gguf( + model_dir, + config_load_kwargs, + api_name="check_and_get_model_definition", + ) + trust_remote_code = resolve_trust_remote_code(model_dir, trust_remote_code=trust_remote_code) + config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code, **config_load_kwargs) model_type = config.model_type.lower() # if model_type is not supported, use BaseQModel, will use auto_detect_module_tree to generate module tree @@ -313,12 +412,13 @@ def check_and_get_model_definition(model_dir, trust_remote_code=False): return MODEL_MAP[model_type] + class GPTQModel: def __init__(self): raise EnvironmentError( - "GPTQModel is not designed to be instantiated\n" - "use `GPTQModel.from_pretrained` to load pretrained model and prepare for quantization via `.quantize()`.\n" - "use `GPTQModel.from_quantized` to inference with post-quantized model." + "GPT-QModel is not designed to be instantiated\n" + "use `from_pretrained()` to load a pretrained model and prepare for quantization via `.quantize()`.\n" + "use `from_quantized()` for inference with a post-quantized model." ) @classmethod @@ -329,32 +429,50 @@ def load( device_map: Optional[Union[str, Dict[str, Union[str, int]]]] = None, device: Optional[Union[str, torch.device]] = None, backend: Union[str, BACKEND] = BACKEND.AUTO, + profile: Union[str, int, PROFILE] = PROFILE.AUTO, trust_remote_code: bool = False, **kwargs, ): + model_id_or_path = normalize_model_id_or_path_for_hf_gguf( + model_id_or_path, + kwargs, + api_name="GPTQModel.load", + ) if isinstance(model_id_or_path, str): model_id_or_path = model_id_or_path.strip() + requested_trust_remote_code = trust_remote_code + trust_remote_code = resolve_trust_remote_code(model_id_or_path, trust_remote_code=trust_remote_code) # normalize config to cfg instance if isinstance(quantize_config, Dict): quantize_config = QuantizeConfig(**quantize_config) - if isinstance(backend, str): - backend = BACKEND(backend) + backend = normalize_backend(backend) + profile = normalize_profile(profile) is_gptqmodel_quantized = False - model_cfg = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) - if _is_supported_quantization_config(model_cfg): + treat_as_local_path = isinstance(model_id_or_path, str) and ( + isdir(model_id_or_path) or os.path.isabs(model_id_or_path) + ) + + model_cfg = None + if not (treat_as_local_path and not isdir(model_id_or_path)): + model_cfg = AutoConfig.from_pretrained( + model_id_or_path, + trust_remote_code=trust_remote_code, + **_get_config_load_kwargs(kwargs), + ) + + if model_cfg is not None and _is_supported_quantization_config(model_cfg): # only if the model is quantized or compatible with gptqmodel should we set is_quantized to true is_gptqmodel_quantized = True else: # TODO FIX ME...not decoded to check if quant method is compatible or quantized by gptqmodel for name in [QUANT_CONFIG_FILENAME, "quant_config.json"]: - if isdir(model_id_or_path): # Local - if os.path.exists(join(model_id_or_path, name)): + if treat_as_local_path: # Local paths should never trigger remote Hub lookups + if isdir(model_id_or_path) and os.path.exists(join(model_id_or_path, name)): is_gptqmodel_quantized = True break - else: # Remote files = list_repo_files(repo_id=model_id_or_path) for f in files: @@ -369,6 +487,7 @@ def load( device=device, backend=backend, trust_remote_code=trust_remote_code, + tokenizer_trust_remote_code=requested_trust_remote_code, **kwargs, ) else: @@ -377,7 +496,10 @@ def load( quantize_config=quantize_config, device_map=device_map, device=device, + backend=backend, + profile=profile, trust_remote_code=trust_remote_code, + tokenizer_trust_remote_code=requested_trust_remote_code, **kwargs, ) @@ -393,10 +515,33 @@ def from_pretrained( cls, model_id_or_path: str, quantize_config: QuantizeConfig, + backend: Union[str, BACKEND] = BACKEND.AUTO, + profile: Union[str, int, PROFILE] = PROFILE.AUTO, trust_remote_code: bool = False, **model_init_kwargs, ) -> BaseQModel: - config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + model_id_or_path = normalize_model_id_or_path_for_hf_gguf( + model_id_or_path, + model_init_kwargs, + api_name="GPTQModel.from_pretrained", + ) + normalize_torch_dtype_kwarg( + model_init_kwargs, + api_name="GPTQModel.from_pretrained", + ) + backend = normalize_backend(backend) + profile = normalize_profile(profile) + requested_trust_remote_code = trust_remote_code + tokenizer_trust_remote_code = model_init_kwargs.pop( + "tokenizer_trust_remote_code", + requested_trust_remote_code, + ) + trust_remote_code = resolve_trust_remote_code(model_id_or_path, trust_remote_code=trust_remote_code) + config = AutoConfig.from_pretrained( + model_id_or_path, + trust_remote_code=trust_remote_code, + **_get_config_load_kwargs(model_init_kwargs), + ) if _is_supported_quantization_config(config): log.warn("Model is already quantized, will use `from_quantized` to load quantized model.\n" "If you want to quantize the model, please pass un_quantized model path or id, and use " @@ -405,14 +550,21 @@ def from_pretrained( if quantize_config and quantize_config.dynamic: log.warn( - "GPTQModel's per-module `dynamic` quantization feature is fully supported in latest vLLM and SGLang but not yet available in hf transformers.") + "GPT-QModel's per-module `dynamic` quantization feature is fully supported in latest vLLM and SGLang but not yet available in hf transformers.") - model_definition = check_and_get_model_definition(model_id_or_path, trust_remote_code) + model_definition = check_and_get_model_definition( + model_id_or_path, + trust_remote_code, + **_get_config_load_kwargs(model_init_kwargs), + ) return model_definition.from_pretrained( pretrained_model_id_or_path=model_id_or_path, quantize_config=quantize_config, + backend=backend, + profile=profile, trust_remote_code=trust_remote_code, + tokenizer_trust_remote_code=tokenizer_trust_remote_code, **model_init_kwargs, ) @@ -427,14 +579,29 @@ def from_quantized( trust_remote_code: bool = False, **kwargs, ) -> BaseQModel: + model_id_or_path = normalize_model_id_or_path_for_hf_gguf( + model_id_or_path, + kwargs, + api_name="GPTQModel.from_quantized", + ) + normalize_torch_dtype_kwarg( + kwargs, + api_name="GPTQModel.from_quantized", + ) + requested_trust_remote_code = trust_remote_code + tokenizer_trust_remote_code = kwargs.pop("tokenizer_trust_remote_code", requested_trust_remote_code) + trust_remote_code = resolve_trust_remote_code(model_id_or_path, trust_remote_code=trust_remote_code) # normalize adapter to instance adapter = normalize_adapter(adapter) print(f"from_quantized: adapter: {adapter}") - model_definition = check_and_get_model_definition(model_id_or_path, trust_remote_code) + model_definition = check_and_get_model_definition( + model_id_or_path, + trust_remote_code, + **_get_config_load_kwargs(kwargs), + ) - if isinstance(backend, str): - backend = BACKEND(backend) + backend = normalize_backend(backend) return model_definition.from_quantized( model_id_or_path=model_id_or_path, @@ -442,245 +609,14 @@ def from_quantized( device=device, backend=backend, trust_remote_code=trust_remote_code, + tokenizer_trust_remote_code=tokenizer_trust_remote_code, adapter=adapter, **kwargs, ) - @classmethod - def eval( - cls, - model_or_id_or_path: str=None, - tokenizer: Union[PreTrainedTokenizerBase, Tokenicer]=None, - tasks: Union[EVAL.LM_EVAL, EVAL.EVALPLUS, List[EVAL.LM_EVAL], List[EVAL.EVALPLUS], EVAL.MMLU_PRO, List[EVAL.MMLU_PRO]] = None, # set to None to fix mutable warning - framework: Union[Type[EVAL.LM_EVAL],Type[EVAL.EVALPLUS],Type[EVAL.MMLU_PRO]] = EVAL.LM_EVAL, - batch_size: Union[int, str] = 1, - trust_remote_code: bool = False, - output_path: Optional[str] = None, - llm_backend: str = 'gptqmodel', - backend: BACKEND = BACKEND.AUTO, # gptqmodel arg only - random_seed: int = 1234, # only for framework=EVAL.LM_EVAL backend=vllm - model_args: Dict[str, Any] = None, # only for framework=EVAL.LM_EVAL backend=vllm - ntrain: int = 1, # only for framework=EVAL.MMLUPRO - **args - ): - from peft import PeftModel - if model_args is None: - model_args = {} - if tasks is None: - if framework == EVAL.LM_EVAL: - tasks = [EVAL.LM_EVAL.ARC_CHALLENGE] - elif framework == EVAL.MMLU_PRO: - tasks = [EVAL.MMLU_PRO.MATH] - else: - tasks = [EVAL.EVALPLUS.HUMAN] - - elif not isinstance(tasks, List): - tasks = [tasks] - - if framework is None: - raise ValueError("Eval parameter: `framework` cannot be set to None") - - if not isinstance(tasks, list): - raise ValueError("Eval parameter: `tasks` must be of List type") - - if llm_backend not in ['gptqmodel', 'vllm']: - raise ValueError('Eval framework support llm_backend: [gptqmodel, vllm]') - - if llm_backend == "vllm": - if "tensor_parallel_size" not in model_args: - try: - cuda_devices = torch.cuda.device_count() if torch.cuda.is_available() else 0 - except Exception: - cuda_devices = 0 - if cuda_devices: - model_args["tensor_parallel_size"] = cuda_devices - if "gpu_memory_utilization" not in model_args: - model_args["gpu_memory_utilization"] = 0.90 - - if isinstance(model_or_id_or_path, str): - load_backend = backend - if llm_backend == "vllm": - disallowed_keys = {"pretrained", "tokenizer", "gptqmodel", "trust_remote_code", "backend", "model_id_or_path"} - load_kwargs = {k: v for k, v in model_args.items() if k not in disallowed_keys} - else: - load_kwargs = model_args - - backend_name = load_backend.value if isinstance(load_backend, BACKEND) else str(load_backend) - log.info(f"Eval: loading using backend = `{backend_name}`") - model = GPTQModel.load( - model_id_or_path=model_or_id_or_path, - backend=load_backend, - trust_remote_code=trust_remote_code, - **load_kwargs, - ) - model_id_or_path = model_or_id_or_path - elif isinstance(model_or_id_or_path, BaseQModel) or isinstance(model_or_id_or_path, (PreTrainedModel, PeftModel)): - model = model_or_id_or_path - model_id_or_path = model.config.name_or_path # - else: - raise ValueError(f"`model_or_id_or_path` is invalid. expected: `model instance or str` actual: `{model_or_id_or_path}`") - - if tokenizer is None: - if isinstance(model, BaseQModel): - tokenizer = model.tokenizer - elif isinstance(model, PreTrainedModel) or model_id_or_path.strip(): - tokenizer = Tokenicer.load(model_id_or_path.strip()) - - if tokenizer is None: - raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.") - - if llm_backend == "gptqmodel": # vllm loads tokenizer - model_args["tokenizer"] = tokenizer - - if framework == EVAL.LM_EVAL: - from lm_eval.utils import make_table # hack: circular import - - for task in tasks: - if task not in EVAL.get_task_enums(): - raise ValueError(f"Eval.lm_eval supported `tasks`: `{EVAL.get_all_tasks_string()}`, actual = `{task}`") - - model_name = "hf" if llm_backend == "gptqmodel" else llm_backend - - if llm_backend == "gptqmodel" and isinstance(model, BaseQModel) and model.quantized: - model_args["gptqmodel"] = True - model_args["pretrained"] = model_id_or_path - - # TODO FIXME lm-eval latest broken imports - # lm_eval indirectly imports sglang, which expects newer Triton helpers. - # Provide shims so import succeeds on older triton/runtime builds. - try: - from triton.runtime import cache as triton_cache - - if not hasattr(triton_cache, "default_cache_dir"): - triton_cache.default_cache_dir = lambda: getattr(triton_cache.knobs.cache, "dir", None) - if not hasattr(triton_cache, "default_override_dir"): - triton_cache.default_override_dir = lambda: getattr(triton_cache.knobs.cache, "override_dir", None) - if not hasattr(triton_cache, "default_dump_dir"): - triton_cache.default_dump_dir = lambda: getattr(triton_cache.knobs.cache, "dump_dir", None) - except Exception as exc: - log.warning("Triton cache shim failed; lm_eval import may fail: %s", exc) - - try: - from lm_eval import simple_evaluate - from lm_eval.models.huggingface import HFLM - except BaseException as e: - raise ValueError(f"lm_eval import failed: {e}. Please install via `pip install gptqmodel[eval]`.") from e - - if llm_backend == "gptqmodel" and model is not None: - model_name = HFLM( - pretrained=model, - batch_size=batch_size, - trust_remote_code=trust_remote_code, - ) - - gen_kwargs = args.pop("gen_kwargs", None) - - # use model.generation_config whenever possible - if gen_kwargs is None: - # TODO: move to utils - if hasattr(model, "generation_config") and isinstance(model.generation_config, GenerationConfig): - gen_dict = { - "do_sample": model.generation_config.do_sample, - "temperature": model.generation_config.temperature, - "top_k": model.generation_config.top_k, - "top_p": model.generation_config.top_p, - "min_p": model.generation_config.min_p, - - } - gen_kwargs = ','.join(f"{key}={value}" for key, value in gen_dict.items() if value not in ["", {}, None, []]) - else: - gen_kwargs = "temperature=0.0,top_k=50" # default - - log.info(f"LM-EVAL: `gen_kwargs` = `{gen_kwargs}`") - - # lm-eval has very low scores if apply_chat_template is enabled - apply_chat_template = args.pop("apply_chat_template", False) # args.pop("apply_chat_template", True if tokenizer.chat_template is not None else False) - log.info(f"LM-EVAL: `apply_chat_template` = `{apply_chat_template}`") - - # TODO FIXME lm-eval latest broken imports - # lm_eval pretty prints task yaml paths using Path.relative_to; when custom tasks live outside the - # installed lm_eval package tree this raises ValueError. Monkeypatch to fall back to absolute paths. - from pathlib import Path - original_relative_to = Path.relative_to - - def _relative_to_noerror(self, other, *extra_args, **kwargs): - try: - return original_relative_to(self, other, *extra_args, **kwargs) - except ValueError: - return self - - Path.relative_to = _relative_to_noerror - try: - results = simple_evaluate( - model=model_name, - model_args=model_args, - tasks=[task.value for task in tasks], - batch_size=batch_size, - apply_chat_template=apply_chat_template, - gen_kwargs=gen_kwargs, - random_seed=random_seed, - numpy_random_seed=random_seed, - torch_random_seed=random_seed, - fewshot_random_seed=random_seed, - **args, - ) - finally: - Path.relative_to = original_relative_to - - if results is None: - raise ValueError('lm_eval run fail, check your code!!!') - - print('--------lm_eval Eval Result---------') - print(make_table(results)) - if "groups" in results: - print(make_table(results, "groups")) - print('--------lm_eval Result End---------') - return results - elif framework == EVAL.EVALPLUS: - for task in tasks: - if task not in EVAL.get_task_enums(): - raise ValueError(f"evalplus support tasks: {EVAL.get_all_tasks_string()}") - from ..utils.eval import evalplus, evalplus_make_table - - results = {} - for task in tasks: - base_formatted, plus_formatted, result_path = evalplus( - model=model_id_or_path, - dataset=task.value, - batch=batch_size, - trust_remote_code=trust_remote_code, - output_file=output_path, - backend=llm_backend - ) - results[task.value] = {"base tests": base_formatted, "base + extra tests": plus_formatted, - "results_path": result_path} - print('--------evalplus Eval Result---------') - evalplus_make_table(results) - print('--------evalplus Result End---------') - return results - elif framework == EVAL.MMLU_PRO: - for task in tasks: - if task not in EVAL.get_task_enums(): - raise ValueError(f"eval support tasks: {EVAL.get_all_tasks_string()}") - from ..utils.mmlupro import mmlupro - selected_subjects = ",".join(tasks) - results = mmlupro(model, - tokenizer, - save_dir=output_path, - seed=random_seed, - selected_subjects=selected_subjects, - ntrain=ntrain, - batch_size=batch_size) - - print('--------MMLUPro Eval Result---------') - print(results) - print('--------MMLUPro Result End---------') - return results - else: - raise ValueError("Eval framework support: EVAL.LM_EVAL, EVAL.EVALPLUS, EVAL.MMLUPRO") - @staticmethod def export(model_id_or_path: str, target_path: str, format: str, trust_remote_code: bool = False): + trust_remote_code = resolve_trust_remote_code(model_id_or_path, trust_remote_code=trust_remote_code) # load config config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) @@ -689,8 +625,25 @@ def export(model_id_or_path: str, target_path: str, format: str, trust_remote_co gptq_config = config.quantization_config + method = gptq_config.get("method", gptq_config.get("quant_method", "")) + normalized_method = str(method).lower() + if normalized_method == METHOD.GGUF.value: + backend = BACKEND.GGUF_TORCH + elif normalized_method == METHOD.BITSANDBYTES.value: + backend = BACKEND.BITSANDBYTES + elif normalized_method == METHOD.AWQ.value: + backend = BACKEND.AWQ_TORCH + elif normalized_method == METHOD.PARO.value: + backend = BACKEND.PAROQUANT_CUDA + elif normalized_method == METHOD.FP8.value: + backend = BACKEND.FP8_TORCH + elif normalized_method == METHOD.EXL3.value: + backend = BACKEND.EXL3_TORCH + else: + backend = BACKEND.GPTQ_TORCH + # load gptq model - gptq_model = GPTQModel.load(model_id_or_path, backend=BACKEND.TORCH) + gptq_model = GPTQModel.load(model_id_or_path, backend=backend) if format == "mlx": try: @@ -716,38 +669,6 @@ def export(model_id_or_path: str, target_path: str, format: str, trust_remote_co # save tokenizer to target path gptq_model.tokenizer.save_pretrained(target_path) - # Use HfAPI and not Transformers to do upload - @staticmethod - def push_to_hub(repo_id: str, - quantized_path: str, # saved local directory path - private: bool = False, - exists_ok: bool = False, # set to true if repo already exists - token: Optional[str] = None, - ): - - if not quantized_path: - raise RuntimeError("You must pass quantized model path as str to push_to_hub.") - - if not repo_id: - raise RuntimeError("You must pass repo_id as str to push_to_hub.") - - from huggingface_hub import HfApi - repo_type = "model" - - api = HfApi() - # if repo does not exist, create it - try: - api.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) - except Exception: - api.create_repo(repo_id=repo_id, repo_type=repo_type, token=token, private=private, exist_ok=exists_ok) - - # upload the quantized save folder - api.upload_large_folder( - folder_path=quantized_path, - repo_id=repo_id, - repo_type=repo_type, - ) - class adapter: @classmethod def generate( @@ -774,14 +695,14 @@ def generate( log.info("Model: Quant Model Loading...") quantized_model = GPTQModel.load( model_id_or_path=quantized_model_id_or_path, - backend=BACKEND.TORCH, + backend=BACKEND.GPTQ_TORCH, device=CPU, trust_remote_code=trust_remote_code, dtype=dtype, ) qcfg = quantized_model.quantize_config - qModules: Dict[str, TorchQuantLinear] = find_modules(module=quantized_model.model, layers=[TorchQuantLinear]) + qModules: Dict[str, TorchLinear] = find_modules(module=quantized_model.model, layers=[TorchLinear]) # for name, module in qModules.items(): # quantized_weights[name] = module.dequantize_weight() del quantized_model @@ -791,7 +712,7 @@ def generate( model = GPTQModel.load( model_id_or_path=model_id_or_path, quantize_config=qcfg, - backend=BACKEND.TORCH, + backend=BACKEND.GPTQ_TORCH, trust_remote_code=trust_remote_code, dtype=dtype, device=CPU, diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 70fd4a5b9..53a9d4133 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -7,7 +7,6 @@ import copy import json import os -import re import threading import time from collections import defaultdict @@ -18,7 +17,6 @@ import torch import torch._dynamo import torch.nn as nn -from tokenicer import Tokenicer from transformers import ( AutoModelForCausalLM, AutoProcessor, @@ -38,20 +36,47 @@ from .. import DEVICE_THREAD_POOL from ..adapter.adapter import Adapter +from ..nn_modules.exllamav3 import ExllamaV3Linear from ..nn_modules.qlinear import BaseQuantLinear +from ..nn_modules.qlinear.fp4 import TorchFP4Linear +from ..nn_modules.qlinear.fp8 import TorchFP8Linear from ..nn_modules.qlinear.lookahead import configure_default_lookahead -from ..nn_modules.qlinear.torch import TorchQuantLinear -from ..quantization import QuantizeConfig -from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, GcMode, VramStrategy, dynamic_get +from ..nn_modules.qlinear.torch import TorchLinear +from ..quantization.config import ( + FORMAT, + METHOD, + QUANTIZE_BLACK_LIST, + AutoModuleDecoderConfig, + BaseQuantizeConfig, + GcMode, + VramStrategy, + dynamic_get, + resolve_quant_format, +) +from ..quantization.dtype import ( + available_float8_dtypes, + dequantize_f4_e2m1, + dequantize_fp8, + device_supports_dtype, + device_supports_native_fp4, + is_fp4_packed_dtype, +) from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model -from ..utils.backend import BACKEND +from ..utils.attn_mask import normalize_seq_mask +from ..utils.backend import BACKEND, normalize_backend from ..utils.calibration import prepare_calibration_dataset from ..utils.device import get_device -from ..utils.hf import autofix_hf_model_config +from ..utils.hf import autofix_hf_model_config, ensure_hf_model_config_token_ids, load_tokenizer_with_model_config from ..utils.importer import select_quant_linear from ..utils.logger import QuantizationRegionTimer, setup_logger -from ..utils.model import MODALITY, find_modules, get_module_by_name_prefix, move_to -from ..utils.structure import alias_from_turtle_for_submodule +from ..utils.model import MODALITY, _module_has_meta_tensors, find_modules, get_module_by_name_prefix, move_to +from ..utils.model_dequant import infer_block_shape +from ..utils.structure import ( + LazyTurtle, + _get_parent_and_leaf_by_path, + _get_qualified_name, + alias_from_turtle_for_submodule, +) from ..utils.torch import TORCH_HAS_COMPILE, torch_compile from ._const import ( CPU, @@ -71,6 +96,8 @@ except Exception: # pragma: no cover - optional dependency HFDatasetType = HFIterableDatasetType = object + from ..looper.named_module import NamedModule + class _ClassPropertyDescriptor: def __init__(self, fget, fset=None): @@ -142,6 +169,9 @@ class BaseQModel(nn.Module): # name of lm_head lm_head: str = "lm_head" + # Special rotary_emb path + rotary_embedding: str | None = None + # a tree node of all the roots that contain quantizable modules module_tree: List[str] = None # Override module_tree according to different QUANT_METHOD @@ -179,15 +209,20 @@ class BaseQModel(nn.Module): # some models require a different model loader, such as mllama which uses AutoModelForPreTraining loader = AutoModelForCausalLM - # Some models have multiple configurations. - # For example, in llama4 and qwen3_5, model_class.form_config requires TextConfig. - config_class = None - # monkey patch api for trust_remote_code=True models that have broken transformer compat require_monkeypatch = False - # VRAM strategy support list - supported_vram_strategies: List[VramStrategy] = [VramStrategy.EXCLUSIVE, VramStrategy.BALANCED] + # Dense-pool strategy support list + supported_dense_vram_strategies: List[VramStrategy] = [ + VramStrategy.EXCLUSIVE, + VramStrategy.BALANCED, + ] + + # MoE expert-pool strategy support list + supported_moe_vram_strategies: List[VramStrategy] = [ + VramStrategy.EXCLUSIVE, + VramStrategy.BALANCED, + ] # some models have broken attention mask codes so we need to only use batch 1 with no masks support_batch_quantize = True @@ -198,8 +233,9 @@ class BaseQModel(nn.Module): # Some models have optional layers that are not loaded or supported by HF so even when they exist in the original # model, they are not properly saved on save(). GLM 4.5/4.6 (air) with MTP layers is such example. - # List the `dangling` optional tensor files here, and we will merge them in on model.save() - out_of_model_tensor_files: Optional[List[str]] = None + # Provide either a safetensors filename (the file is copied through if present) or a prefix (all `prefix.` tensors + # are merged into the main state dict so they end up in model.safetensors). + out_of_model_tensors: Optional[Dict[str, Union[str | List[str]]]] = None supports_desc_act = [True, False] @@ -223,20 +259,19 @@ def __init__( self, model: PreTrainedModel, quantized: bool, - quantize_config: Optional[QuantizeConfig], + quantize_config: Optional[BaseQuantizeConfig], tokenizer: Optional[PreTrainedTokenizerBase] = None, qlinear_kernel: nn.Module = None, load_quantized_model: bool = False, trust_remote_code: bool = False, model_local_path: str = None, - # turtle model is a sympathetic model used to reduce cpu ram usage - # during quantization stage. - turtle_model: Optional[PreTrainedModel] = None, + # Lazy turtle is the checkpoint-backed source used to materialize shell modules on demand. + turtle_model: Optional[LazyTurtle] = None, ): super().__init__() if quantize_config: - quant_method = quantize_config.quant_method + quant_method = quantize_config.method # override module_tree if need if self.module_tree_overrides is not None and self.module_tree_overrides.get(quant_method) is not None: log.info(f'Module Tree: overridden by METHOD.{quant_method.upper()}') @@ -260,18 +295,21 @@ def __init__( self.model_local_path = model_local_path self.quantize_config = quantize_config self.quant_region_timer = QuantizationRegionTimer(logger=log) - self._turtle_reload_threshold_bytes = self._resolve_turtle_reload_threshold() - self._turtle_reload_accum_bytes = 0 - self._turtle_materialized_ids: Set[int] = set() + self._runtime_generate = None self.processor: ProcessorMixin = None self.model = self.after_model_load(model, load_quantized_model=load_quantized_model) self.turtle_model = turtle_model + # Captures forward-role auto-decoder choices for regression tests and debug logs. + self.auto_module_decoder_events: List[Dict[str, Any]] = [] + + if isinstance(self.model, PreTrainedModel): + ensure_hf_model_config_token_ids(self.model.config, tokenizer=tokenizer) if tokenizer is not None: if isinstance(tokenizer, PreTrainedTokenizerBase): - self.tokenizer = Tokenicer.load(tokenizer, trust_remote_code=trust_remote_code) + self.tokenizer = load_tokenizer_with_model_config(tokenizer, self.model.config) else: raise ValueError( f"Unsupported `tokenizer` type: Expected `PreTrainedTokenizerBase`, actual = `{type(tokenizer)}`.") @@ -283,6 +321,9 @@ def __init__( # auto-fix model config erors if isinstance(self.model, PreTrainedModel): autofix_hf_model_config(self.model, path=model_local_path) + # Reject activation-quantized checkpoints at load time so the rest of + # the floatx decoder stack can continue assuming dense activations. + self._configure_modelopt_runtime() self._turtle_lock = threading.RLock() @@ -291,7 +332,7 @@ def __init__( self.quant_log = [] if self.require_load_processor: - self.processor = AutoProcessor.from_pretrained(model_local_path) + self.processor = AutoProcessor.from_pretrained(model_local_path, trust_remote_code=self.require_trust_remote_code) # apply patching of broken trust_remote_code models here if self.require_monkeypatch: @@ -329,12 +370,28 @@ def extract_layers_node(cls): if node == "#": break if isinstance(node, str): - prefix_parts.append(node) + module_name, _ = cls._parse_module_flags(node) + prefix_parts.append(module_name) else: break # stop if unexpected nested structure return [".".join(prefix_parts)] if prefix_parts else [] + @classmethod + def _parse_module_aliases(cls, module_spec: str) -> List[str]: + """ + Parse a module specification into its ordered runtime/checkpoint aliases. + + The first alias is the runtime shell name. Any later aliases are + alternate checkpoint names declared directly in the model definition. + """ + parts = module_spec.split(":") if isinstance(module_spec, str) else [] + name = parts[0] if parts else module_spec + if not isinstance(name, str): + return [name] + aliases = [alias for alias in name.split("|") if alias] + return aliases or [name] + @classmethod def _parse_module_flags(cls, module_spec: str) -> tuple[str, List[str]]: """ @@ -342,7 +399,8 @@ def _parse_module_flags(cls, module_spec: str) -> tuple[str, List[str]]: Example: "gate:moe:!" -> ("gate", ["moe", "!"]) """ parts = module_spec.split(":") if isinstance(module_spec, str) else [] - name = parts[0] if parts else module_spec + aliases = cls._parse_module_aliases(module_spec) if isinstance(module_spec, str) else [module_spec] + name = aliases[0] if aliases else module_spec flags = [p for p in parts[1:] if p] return name, flags @@ -544,11 +602,17 @@ def get_num_experts(cls, model_config): @classmethod def filter_not_quantize_module(cls, layer_modules, quantize_config): - layer_modules = [ - [name for name in block if NOT_QUANTIZE_FLAG not in name] - for block in layer_modules - ] - layer_modules = [block for block in layer_modules if block] # 去掉空 block + def should_quantize(name: str) -> bool: + # Check if the module name contains any NON_QUANTIZE_FLAGS that indicates it should NOT be quantized + return not any(flag in name for flag in NON_QUANTIZE_FLAGS) + + filtered_layer_modules = [] + for block in layer_modules: + filtered_block = [name for name in block if should_quantize(name)] + filtered_layer_modules.append(filtered_block) + layer_modules = filtered_layer_modules + + layer_modules = [block for block in layer_modules if block] # Remove empty blocks if getattr(quantize_config, "dynamic", None): new_layer_modules = [] @@ -612,7 +676,7 @@ def prepare_dataset( def quantize( self, - calibration: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]], + calibration: Optional[Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]]] = None, # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model. calibration_concat_size: Optional[int] = None, calibration_sort: Optional[str] = "desc", # valid values are asc, desc, shuffle @@ -626,7 +690,7 @@ def quantize( calibration_data_min_length: int = 10, calibration_concat_separator: Optional[str] = None, ) -> Dict[str, List[Dict[str, str]]]: - if self.quantize_config is None or not isinstance(self.quantize_config, QuantizeConfig): + if self.quantize_config is None or not isinstance(self.quantize_config, BaseQuantizeConfig): raise AttributeError("`quantize_config` must be not None") if self.quantized: @@ -636,25 +700,26 @@ def quantize( if timer is not None: timer.reset() - self._turtle_reload_accum_bytes = 0 - self._turtle_materialized_ids = set() - - if self.quantize_config.quant_method in QUANTIZE_BLACK_LIST: + if self.quantize_config.method in QUANTIZE_BLACK_LIST: raise ValueError( - f"Unsupported quantization operation for quant method: {self.quantize_config.quant_method}" + f"Unsupported quantization operation for quant method: {self.quantize_config.method}" ) if not self.support_batch_quantize: log.warn("Quantize: batch_size overridden by model class definition to `disabled`") batch_size = 1 # but actually disabled - if self.quantize_config.format == FORMAT.MARLIN: + format_code = resolve_quant_format(self.quantize_config.format, self.quantize_config.method) + + if format_code == FORMAT.MARLIN: raise ValueError( "FORMAT.MARLIN is deprecated for quantization. Please switch to FORMAT.GPTQ. GPTQMOdel will auto-use Marlin kernel for accelerated inference for FORMAT.GPTQ." ) - if self.quantize_config.quant_method == METHOD.AWQ: - if self.quantize_config.format in [FORMAT.GEMV_FAST, FORMAT.LLM_AWQ]: + export_quant_method = self.quantize_config.export_quant_method() + + if export_quant_method == METHOD.AWQ: + if format_code in [FORMAT.GEMV_FAST, FORMAT.LLM_AWQ]: # AWQ GEMV_FAST / LLM_AWQ only supports pack_dtype is torch.int16 log.info("Quantize Model: Auto fix `pack_dtype` to `torch.int16`") self.quantize_config.pack_dtype = torch.int16 @@ -664,51 +729,83 @@ def quantize( log.warn("Batch quantization is not supported for this model. Setting batch_size to 1.") requested_backend = backend - if isinstance(requested_backend, str): - requested_backend = BACKEND(requested_backend.lower()) + requested_backend = normalize_backend(requested_backend, quant_method=export_quant_method) preferred_backend = requested_backend if preferred_backend in (None, BACKEND.AUTO): - if self.quantize_config.quant_method == METHOD.AWQ: - if self.quantize_config.format == FORMAT.GEMM: - preferred_backend = BACKEND.GEMM - elif self.quantize_config.format == FORMAT.GEMV: - preferred_backend = BACKEND.GEMV - elif self.quantize_config.format in [FORMAT.GEMV_FAST, FORMAT.LLM_AWQ]: - preferred_backend = BACKEND.GEMV_FAST + if export_quant_method == METHOD.AWQ: + if format_code == FORMAT.GEMM: + # Weight-only RTN->AWQ export should stay on the portable torch kernel. + preferred_backend = ( + BACKEND.AWQ_TORCH + if self.quantize_config.uses_weight_only_lifecycle() + else BACKEND.AWQ_GEMM + ) + elif format_code == FORMAT.BITBLAS: + preferred_backend = BACKEND.AWQ_BITBLAS + elif format_code == FORMAT.GEMV: + preferred_backend = BACKEND.AWQ_GEMV + elif format_code in [FORMAT.GEMV_FAST, FORMAT.LLM_AWQ]: + preferred_backend = BACKEND.AWQ_GEMV_FAST else: raise ValueError(f"Unsupported FORMAT: `{self.quantize_config.format}` with `METHOD.AWQ`") - elif self.quantize_config.quant_method == METHOD.QQQ: + elif self.quantize_config.method == METHOD.QQQ: preferred_backend = BACKEND.QQQ + elif self.quantize_config.method == METHOD.PARO: + preferred_backend = BACKEND.PAROQUANT_CUDA + elif self.quantize_config.method == METHOD.EXL3: + preferred_backend = BACKEND.EXL3_EXLLAMA_V3 + elif self.quantize_config.method == METHOD.GGUF: + preferred_backend = BACKEND.AUTO + elif self.quantize_config.method == METHOD.FP8: + preferred_backend = BACKEND.FP8_TORCH + elif self.quantize_config.method == METHOD.BITSANDBYTES: + preferred_backend = BACKEND.BITSANDBYTES else: - preferred_backend = BACKEND.TORCH - - # Validate quant linear before quantization starts - _ = select_quant_linear( - bits=self.quantize_config.bits, - dynamic=self.quantize_config.dynamic, - group_size=self.quantize_config.group_size, - desc_act=self.quantize_config.desc_act, - sym=self.quantize_config.sym, - backend=preferred_backend, - format=self.quantize_config.format, - quant_method=self.quantize_config.quant_method, - device=DEVICE(self.quantize_config.device), - pack=True, - pack_dtype=self.quantize_config.pack_dtype, + preferred_backend = BACKEND.GPTQ_TORCH - ) + if self.quantize_config.method == METHOD.EXL3: + if preferred_backend not in (BACKEND.AUTO, BACKEND.EXL3_EXLLAMA_V3): + raise ValueError("EXL3 quantization only supports BACKEND.AUTO or BACKEND.EXL3_EXLLAMA_V3.") + + if not torch.cuda.is_available(): + raise ValueError("EXL3 quantization requires CUDA/HIP.") + + quant_device = self.quantize_config.device + if isinstance(quant_device, DEVICE): + quant_device_type = quant_device.type + elif isinstance(quant_device, torch.device): + quant_device_type = quant_device.type + else: + quant_device_type = str(quant_device).split(":")[0].lower() + + if quant_device_type != "cuda": + raise ValueError("EXL3 quantization requires a CUDA/HIP quantization device.") + else: + # Validate quant linear before quantization starts + _ = select_quant_linear( + bits=self.quantize_config.runtime_bits, + dynamic=self.quantize_config.dynamic, + group_size=self.quantize_config.group_size, + desc_act=self.quantize_config.desc_act, + sym=self.quantize_config.sym, + backend=preferred_backend, + format=format_code, + quant_method=export_quant_method, + device=DEVICE(self.quantize_config.device), + pack=True, + pack_dtype=self.quantize_config.pack_dtype, + ) # Use the provided tokenizer if one is passed to quantize() if tokenizer is not None: if isinstance(tokenizer, PreTrainedTokenizerBase): - # TODO FIX ME...this is a bug - self.tokenizer = Tokenicer.load(tokenizer, trust_remote_code=self.trust_remote_code) + self.tokenizer = load_tokenizer_with_model_config(tokenizer, self.model.config) else: raise ValueError( f"Unsupported `tokenizer` type: Expected `PreTrainedTokenizerBase`, actual = `{type(tokenizer)}`.") - if self.quantize_config.format == FORMAT.BITBLAS: + if format_code == FORMAT.BITBLAS: from ..nn_modules.qlinear.bitblas import BITBLAS_AVAILABLE, BITBLAS_INSTALL_HINT if BITBLAS_AVAILABLE is False: raise ValueError(BITBLAS_INSTALL_HINT) @@ -717,39 +814,23 @@ def quantize( if adapter is not None: self.quantize_config.adapter = adapter - from ..adapter.adapter import Lora - from ..looper.eora_processor import EoraProcessor - from ..looper.module_looper import ModuleLooper - - # has lora process - needs_lora = isinstance(self.quantize_config.adapter, Lora) - - args = { - "tokenizer": self.tokenizer, - "qcfg": self.quantize_config, - "calibration": calibration, - "prepare_dataset_func": self.prepare_dataset, - "calibration_concat_size": calibration_concat_size, - "calibration_sort": calibration_sort, - "calibration_concat_separator": calibration_concat_separator, - "batch_size": batch_size, - "calculate_w_wq_diff": needs_lora, # lora needs original w - wq delta - } - - self.qlinear_kernel = select_quant_linear( - bits=self.quantize_config.bits, - group_size=self.quantize_config.group_size, - desc_act=self.quantize_config.desc_act, - sym=self.quantize_config.sym, - pack=True, - dynamic=self.quantize_config.dynamic, - device=self.quantize_config.device, - pack_dtype=self.quantize_config.pack_dtype, - multi_select=False, - backend=preferred_backend, - format=self.quantize_config.format, - quant_method=self.quantize_config.quant_method, - ) + if self.quantize_config.method == METHOD.EXL3: + self.qlinear_kernel = ExllamaV3Linear + else: + self.qlinear_kernel = select_quant_linear( + bits=self.quantize_config.runtime_bits, + group_size=self.quantize_config.group_size, + desc_act=self.quantize_config.desc_act, + sym=self.quantize_config.sym, + pack=True, + dynamic=self.quantize_config.dynamic, + device=DEVICE(self.quantize_config.device), + pack_dtype=self.quantize_config.pack_dtype, + multi_select=False, + backend=preferred_backend, + format=format_code, + quant_method=export_quant_method, + ) # rotate model if self.quantize_config.rotation: @@ -778,46 +859,138 @@ def quantize( self.model, _ = rotate_model(model=self.model, rotate_mode=self.quantize_config.rotation, device=rotation_device, **module_name_args) - # init processor with default GPTQ processor - from ..looper.tensorparallel_weight_processor import TensorParallelWeightProcessor + if self.quantize_config.uses_weight_only_lifecycle(): + result = self._quantize_weight_only( + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + batch_size=batch_size, + backend=backend, + calibration_concat_separator=calibration_concat_separator, + ) + else: + if calibration is None: + raise ValueError( + "Calibration dataset is required unless a weight-only quantize config is configured." + ) + result = self._quantize_with_calibration( + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + batch_size=batch_size, + backend=backend, + adapter_calibration_dataset=adapter_calibration_dataset, + calibration_concat_separator=calibration_concat_separator, + ) + + timer = getattr(self, "quant_region_timer", None) + if timer is not None: + timer.flush() + + return result - if self.quantize_config.quant_method == METHOD.QQQ: + def _quantize_with_calibration( + self, + *, + calibration, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + backend: Optional[BACKEND], + adapter_calibration_dataset, + calibration_concat_separator: Optional[str], + ): + from ..adapter.adapter import Lora + from ..looper.eora_processor import EoraProcessor + from ..looper.module_looper import ModuleLooper + from ..looper.module_preprocessor import ModulePreProcessor + + needs_lora = isinstance(self.quantize_config.adapter, Lora) + + args = { + "tokenizer": self.tokenizer, + "qcfg": self.quantize_config, + "calibration": calibration, + "prepare_dataset_func": self.prepare_dataset, + "calibration_concat_size": calibration_concat_size, + "calibration_sort": calibration_sort, + "calibration_concat_separator": calibration_concat_separator, + "batch_size": batch_size, + "calculate_w_wq_diff": needs_lora, + } + + preprocessors = [] + if getattr(self.quantize_config, "preprocessors", None): + preprocessors.append(ModulePreProcessor(**args)) + + if self.quantize_config.method == METHOD.EXL3: + from ..looper.exllamav3_processor import EXL3Processor + + if needs_lora: + raise NotImplementedError("EXL3 quantization does not support adapter/EoRA generation.") + + if getattr(self.quantize_config, "gptaq", None) is not None: + raise NotImplementedError("EXL3 quantization does not support GPTAQ/native activation capture.") + + if getattr(self.quantize_config, "foem", None) is not None: + raise NotImplementedError("EXL3 quantization does not support FOEM/native activation capture.") + + exl3_args = { + "tokenizer": self.tokenizer, + "qcfg": self.quantize_config, + "calibration": calibration, + "prepare_dataset_func": self.prepare_dataset, + "calibration_concat_size": calibration_concat_size, + "calibration_sort": calibration_sort, + "calibration_concat_separator": calibration_concat_separator, + "batch_size": batch_size, + "lm_head_name": self.lm_head, + } + quantize_processor = preprocessors + [ + EXL3Processor(**exl3_args), + ] + elif self.quantize_config.method == METHOD.QQQ: from ..looper.qqq_processor import QQQProcessor - quantize_processor = [ - TensorParallelWeightProcessor(**args), + quantize_processor = preprocessors + [ QQQProcessor(**args), ] - elif self.quantize_config.quant_method == METHOD.AWQ: + elif self.quantize_config.method == METHOD.AWQ: from ..looper.awq_processor import AWQProcessor os.environ["AWQ_BATCH_SIZE"] = str(batch_size) - # if self.model.config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys(): - # raise TypeError(f"{self.model.config.model_type} isn't supported yet.") - awq_args = dict(args) awq_args["gptq_model"] = self awq_args["model"] = self.model awq_args["batch_size"] = batch_size - quantize_processor = [ - TensorParallelWeightProcessor(**args), + quantize_processor = preprocessors + [ AWQProcessor(**awq_args), ] + elif self.quantize_config.method == METHOD.PARO: + from ..looper.paroquant_processor import ParoQuantProcessor + + os.environ["AWQ_BATCH_SIZE"] = str(batch_size) + + paro_args = dict(args) + paro_args["gptq_model"] = self + paro_args["model"] = self.model + paro_args["batch_size"] = batch_size + + quantize_processor = preprocessors + [ + ParoQuantProcessor(**paro_args), + ] else: from ..looper.gptq_processor import GPTQProcessor - quantize_processor = [ - TensorParallelWeightProcessor(**args), + quantize_processor = preprocessors + [ GPTQProcessor(**args), ] - if self.quantize_config.gptaq is not None: + if getattr(self.quantize_config, "gptaq", None) is not None: from ..looper.native_processor import NativeProcessor - # During the deepcopy process, self.prepare_dataset will be deeply copied along with self. However, - # self has a threading.RLock() , which is not serializable. args_to_copy = {k: v for k, v in args.items() if k != "prepare_dataset_func"} args_clone = copy.deepcopy(args_to_copy) args_clone["prepare_dataset_func"] = args["prepare_dataset_func"] @@ -825,8 +998,18 @@ def quantize( args_clone.pop("calculate_w_wq_diff", None) quantize_processor.insert(0, NativeProcessor(**args_clone)) + if getattr(self.quantize_config, "foem", None) is not None: + if self.quantize_config.foem.alpha > 0: + from ..looper.native_processor import NativeProcessor + + args_to_copy = {k: v for k, v in args.items() if k != "prepare_dataset_func"} + args_clone = copy.deepcopy(args_to_copy) + args_clone["prepare_dataset_func"] = args["prepare_dataset_func"] + + args_clone.pop("calculate_w_wq_diff", None) + quantize_processor.insert(0, NativeProcessor(**args_clone)) + processors = quantize_processor - # Append EoRA processor for lora adapter if needs_lora: processors.append( EoraProcessor( @@ -841,11 +1024,8 @@ def quantize( ) ) - # prepare processor worker (looper) module_looper = ModuleLooper(self, processors=processors) - # When gc_mode=ON_STAGE_END, disable auto-gc for the whole quantization process - # to prevent interference with manual cleanups performed at stage ends gc_context = ( DEVICE_THREAD_POOL.no_auto_gc() if self.quantize_config.gc_mode == GcMode.ON_STAGE_END @@ -853,22 +1033,65 @@ def quantize( ) with gc_context: - result = module_looper.loop( + return module_looper.loop( backend=backend, - failsafe=self.quantize_config.failsafe, + fallback=self.quantize_config.fallback, ) - timer = getattr(self, "quant_region_timer", None) - if timer is not None: - timer.flush() + def _quantize_weight_only( + self, + *, + calibration, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + backend: Optional[BACKEND], + calibration_concat_separator: Optional[str], + ): + del calibration_concat_size, calibration_sort, batch_size, calibration_concat_separator - return result + from ..adapter.adapter import Lora + from ..looper.weight_only_looper import WeightOnlyLooper + from ..looper.weight_only_processor import WeightOnlyProcessor + + if calibration is not None: + log.info("Weight-only quantization selected; ignoring provided calibration dataset.") + + if isinstance(self.quantize_config.adapter, Lora): + raise NotImplementedError( + "Weight-only quantization does not support adapter/EoRA generation." + ) + + if getattr(self.quantize_config, "gptaq", None) is not None: + raise NotImplementedError( + "Weight-only quantization does not support GPTAQ/native activation capture." + ) + + if getattr(self.quantize_config, "foem", None) is not None: + raise NotImplementedError( + "Weight-only quantization does not support FOEM/native activation capture." + ) + + processor = WeightOnlyProcessor( + tokenizer=self.tokenizer, + qcfg=self.quantize_config, + ) + module_looper = WeightOnlyLooper(model=self, processor=processor) + + gc_context = ( + DEVICE_THREAD_POOL.no_auto_gc() + if self.quantize_config.gc_mode == GcMode.ON_STAGE_END + else nullcontext() + ) + + with gc_context: + return module_looper.loop(backend=backend) def _eora_generate( self, # eora adapter generation needs config Lora(rank=1, path='lora.safetensors') adapter: Adapter, - quantized_modules: Dict[str, TorchQuantLinear], + quantized_modules: Dict[str, TorchLinear], calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]], calibration_dataset_concat_size: Optional[int] = None, calibration_dataset_sort: Optional[str] = None, @@ -882,8 +1105,7 @@ def _eora_generate( # Use the provided tokenizer if one is passed to quantize() if tokenizer is not None: if isinstance(tokenizer, PreTrainedTokenizerBase): - # TODO FIX ME...this is a bug - self.tokenizer = Tokenicer.load(tokenizer, trust_remote_code=self.trust_remote_code) + self.tokenizer = load_tokenizer_with_model_config(tokenizer, self.model.config) else: raise ValueError( f"Unsupported `tokenizer` type: Expected `PreTrainedTokenizerBase`, actual = `{type(tokenizer)}`.") @@ -892,38 +1114,44 @@ def _eora_generate( from ..looper.dequantize_processor import DequantizeProcessor from ..looper.eora_processor import EoraProcessor from ..looper.module_looper import ModuleLooper - from ..looper.tensorparallel_weight_processor import TensorParallelWeightProcessor + from ..looper.module_preprocessor import ModulePreProcessor self.quantize_config.adapter = adapter assert isinstance(self.quantize_config.adapter, Lora) # init processor with EoRA processor - processors = [ - TensorParallelWeightProcessor( - tokenizer=self.tokenizer, - qcfg=self.quantize_config, - calibration=calibration_dataset, - prepare_dataset_func=self.prepare_dataset, - calibration_concat_size=calibration_dataset_concat_size, - calibration_sort=calibration_dataset_sort, - calibration_concat_separator=calibration_concat_separator, - batch_size=batch_size, - ), - DequantizeProcessor( - quantized_modules=quantized_modules, - ), - EoraProcessor( - tokenizer=self.tokenizer, - qcfg=self.quantize_config, - calibration=calibration_dataset, - prepare_dataset_func=self.prepare_dataset, - calibration_concat_size=calibration_dataset_concat_size, - calibration_sort=calibration_dataset_sort, - calibration_concat_separator=calibration_concat_separator, - batch_size=batch_size, - ), - ] + processors = [] + if getattr(self.quantize_config, "preprocessors", None): + processors.append( + ModulePreProcessor( + tokenizer=self.tokenizer, + qcfg=self.quantize_config, + calibration=calibration_dataset, + prepare_dataset_func=self.prepare_dataset, + calibration_concat_size=calibration_dataset_concat_size, + calibration_sort=calibration_dataset_sort, + calibration_concat_separator=calibration_concat_separator, + batch_size=batch_size, + ), + ) + processors.extend( + [ + DequantizeProcessor( + quantized_modules=quantized_modules, + ), + EoraProcessor( + tokenizer=self.tokenizer, + qcfg=self.quantize_config, + calibration=calibration_dataset, + prepare_dataset_func=self.prepare_dataset, + calibration_concat_size=calibration_dataset_concat_size, + calibration_sort=calibration_dataset_sort, + calibration_concat_separator=calibration_concat_separator, + batch_size=batch_size, + ), + ] + ) # prepare processor worker (looper) module_looper = ModuleLooper(model=self, processors=processors) @@ -943,6 +1171,80 @@ def to(self, device: Union[str, torch.device]): def forward(self, *args, **kwargs): return self.model(*args, **kwargs) + def move_input_capture_example( + self, + example: Dict[str, Any], + data_device: torch.device, + ) -> Dict[str, Any]: + for key, value in example.items(): + if isinstance(value, list): + for index, item in enumerate(value): + if not torch.is_tensor(item): + continue + + if item.ndim == 1: + item = item.unsqueeze(0) + + value[index] = move_to(item, device=data_device) + elif torch.is_tensor(value): + if value.ndim == 1: + value = value.unsqueeze(0) + + example[key] = move_to(value, device=data_device) + + return self.finalize_input_capture_example(example) + + def finalize_input_capture_example( + self, + example: Dict[str, Any], + ) -> Dict[str, Any]: + if self.ATTENTION_MASKS_DTYPE is torch.long and "attention_mask" in example: + example["attention_mask"] = example["attention_mask"].long() + + return example + + def run_input_capture( + self, + example: Dict[str, Any], + use_cache: bool, + data_device: torch.device, + ): + if self.INPUT_EMBEDDING_EXTRA_ARGS: + return self.model.generate( + **example, + **self.INPUT_EMBEDDING_EXTRA_ARGS, + ) + + return self.model(**example, use_cache=use_cache) + + def _generate_with_runtime(self, runtime_generate, inputs=None, **kwargs): + def _normalize_generate_attention_mask(input_ids, attention_mask): + if not torch.is_tensor(attention_mask) or attention_mask.ndim <= 2: + return attention_mask + + seq_len = None + if torch.is_tensor(input_ids) and input_ids.ndim >= 2: + seq_len = input_ids.shape[-1] + + return normalize_seq_mask(attention_mask, seq_len=seq_len) + + if isinstance(inputs, str) or (isinstance(inputs, list) and all(isinstance(x, str) for x in inputs)): + kwargs.setdefault("prompts", inputs) + elif hasattr(inputs, "get") and not torch.is_tensor(inputs): + merged_kwargs = dict(inputs) + merged_kwargs.update(kwargs) + kwargs = merged_kwargs + elif inputs is not None: + kwargs.setdefault("input_ids", inputs) + + if "attention_mask" in kwargs: + kwargs["attention_mask"] = _normalize_generate_attention_mask( + kwargs.get("input_ids"), + kwargs["attention_mask"], + ) + + return runtime_generate(self.model, **kwargs) + def generate(self, inputs=None, **kwargs): with torch.inference_mode(): # fix hf generate not applying correct pad token @@ -950,28 +1252,52 @@ def generate(self, inputs=None, **kwargs): if pad_token_id is None and self.tokenizer: kwargs["pad_token_id"] = self.tokenizer.pad_token_id + runtime_generate = getattr(self, "_runtime_generate", None) + if runtime_generate is not None: + return self._generate_with_runtime(runtime_generate, inputs=inputs, **kwargs) + + def _normalize_generate_attention_mask(input_ids, attention_mask): + if not torch.is_tensor(attention_mask) or attention_mask.ndim <= 2: + return attention_mask + + seq_len = None + if torch.is_tensor(input_ids) and input_ids.ndim >= 2: + seq_len = input_ids.shape[-1] + + return normalize_seq_mask(attention_mask, seq_len=seq_len) + if isinstance(inputs, str) or (isinstance(inputs, list) and all(isinstance(x, str) for x in inputs)): if self.tokenizer is None: raise ValueError("You passed in an `input` to `generate()` of type `str` but model is missing `model.tokenizer`. Please set `model.tokenizer = my_tokenizer`.") - inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, padding_side="left").to(self.model.device) + inputs = self.tokenizer(inputs, return_tensors="pt", padding=True, padding_side="left") + if "attention_mask" in inputs: + inputs["attention_mask"] = _normalize_generate_attention_mask( + inputs.get("input_ids"), + inputs["attention_mask"], + ) + inputs = inputs.to(self.model.device) + return self.model.generate(**inputs, **kwargs) + + if hasattr(inputs, "get") and not torch.is_tensor(inputs): + if "attention_mask" in inputs: + inputs["attention_mask"] = _normalize_generate_attention_mask( + inputs.get("input_ids"), + inputs["attention_mask"], + ) return self.model.generate(**inputs, **kwargs) + if "attention_mask" in kwargs: + kwargs["attention_mask"] = _normalize_generate_attention_mask( + kwargs.get("input_ids", inputs), + kwargs["attention_mask"], + ) + return self.model.generate(inputs=inputs, **kwargs) def prepare_inputs_for_generation(self, *args, **kwargs): """shortcut for model.prepare_inputs_for_generation""" return self.model.prepare_inputs_for_generation(*args, **kwargs) - # placeholder, noop, and alert users to correct static api - def push_to_hub(self, - repo_id: str, - quantized_path: str, # saved local directory path - private: bool = False, - exists_ok: bool = False, # set to true if repo already exists - token: Optional[str] = None): - - log.error("`push_to_hub()` api cannot be used on the model instance. Please use `GPTQModel.push_to_hub()` static api instead.") - def save( self, save_dir: str, @@ -979,6 +1305,7 @@ def save( max_shard_size: Optional[Union[int, str]] = DEFAULT_MAX_SHARD_SIZE, meta_quantizer: Optional[str] = None, eora_path: Optional[str] = None, + split_by: Optional[str] = None, **kwargs, ): timer = getattr(self, "quant_region_timer", None) @@ -994,7 +1321,8 @@ def save( safetensors_metadata=safetensors_metadata, max_shard_size=max_shard_size, meta_quantizer=meta_quantizer, - eora_path=eora_path) + eora_path=eora_path, + split_by=split_by) # overwrite quant_override_files for name, value in self.quant_override_files.items(): @@ -1019,6 +1347,171 @@ def save( ) timer.flush() + def _active_auto_module_decoder_config(self) -> Optional[AutoModuleDecoderConfig]: + """Return the active auto-decoder preprocessor config, if any.""" + + preprocessors = getattr(self.quantize_config, "preprocessors", None) or [] + for preprocessor in reversed(preprocessors): + if isinstance(preprocessor, AutoModuleDecoderConfig): + return preprocessor + return None + + def materialize_passthrough_modules_for_save(self) -> int: + """Decode passthrough floatx modules in-place before saving when configured.""" + + decoder_cfg = self._active_auto_module_decoder_config() + if decoder_cfg is None or decoder_cfg.passthrough_save_policy != "decode": + return 0 + + decoded_count = 0 + for _, module in list(self.model.named_modules()): + if isinstance(module, BaseQuantLinear) or not hasattr(module, "weight"): + continue + + checkpoint_tensors = None + if isinstance(self.turtle_model, LazyTurtle): + checkpoint_tensors = self.turtle_model.checkpoint_tensors_for_submodule( + target_model=self.model, + target_submodule=module, + recurse=False, + ) + if not checkpoint_tensors: + checkpoint_tensors = dict(module.state_dict(keep_vars=True)) + weight = checkpoint_tensors.get("weight") + if not isinstance(weight, torch.Tensor): + continue + + decoder_kind = self._decoder_weight_format( + weight=weight, + checkpoint_tensors=checkpoint_tensors, + ) + if decoder_kind is None: + continue + + decoded_module = self._build_decoder_quant_source_module( + module, + checkpoint_tensors=checkpoint_tensors, + target_dtype=decoder_cfg.target_dtype, + ) + self._replace_live_submodule(module, decoded_module) + decoded_count += 1 + + return decoded_count + + def materialize_passthrough_modules_for_eval( + self, + device: torch.device, + *, + respect_forward_policy: bool = False, + ) -> int: + """Materialize passthrough floatx modules into live evaluation modules.""" + + decoder_cfg = self._active_auto_module_decoder_config() + if decoder_cfg is None: + return 0 + + target_device = torch.device(device) + decoded_count = 0 + for _, module in list(self.model.named_modules()): + if isinstance(module, BaseQuantLinear) or not hasattr(module, "weight"): + continue + + checkpoint_tensors = None + if isinstance(self.turtle_model, LazyTurtle): + checkpoint_tensors = self.turtle_model.checkpoint_tensors_for_submodule( + target_model=self.model, + target_submodule=module, + recurse=False, + ) + if not checkpoint_tensors: + checkpoint_tensors = dict(module.state_dict(keep_vars=True)) + weight = checkpoint_tensors.get("weight") + if not isinstance(weight, torch.Tensor): + continue + + decoder_kind = self._decoder_weight_format( + weight=weight, + checkpoint_tensors=checkpoint_tensors, + ) + if decoder_kind is None: + continue + + forward_module = None + if respect_forward_policy and decoder_cfg.passthrough_forward_policy != "decode": + if decoder_kind == "fp8" and device_supports_dtype(target_device, weight.dtype, require_validation=False): + forward_module = self._build_fp8_forward_module( + target_submodule=module, + checkpoint_tensors=checkpoint_tensors, + device=target_device, + target_dtype=decoder_cfg.target_dtype, + ) + elif decoder_kind == "fp4" and device_supports_native_fp4(target_device, require_validation=False): + forward_module = self._build_fp4_forward_module( + target_submodule=module, + checkpoint_tensors=checkpoint_tensors, + device=target_device, + target_dtype=decoder_cfg.target_dtype, + ) + + if forward_module is None: + decoded_module = self._build_decoder_quant_source_module( + module, + checkpoint_tensors=checkpoint_tensors, + target_dtype=decoder_cfg.target_dtype, + ) + forward_module = self._build_decoder_forward_module( + quant_source=decoded_module, + device=target_device, + ) + self._replace_live_submodule(module, forward_module) + decoded_count += 1 + + return decoded_count + + def decoded_passthrough_state_dict_entries_for_save(self) -> tuple[Dict[str, torch.Tensor], List[str]]: + """Return dense state-dict entries that should replace native passthrough tensors on save.""" + + decoder_cfg = self._active_auto_module_decoder_config() + if decoder_cfg is None or decoder_cfg.passthrough_save_policy != "decode": + return {}, [] + + decoded_entries: Dict[str, torch.Tensor] = {} + decoded_prefixes: List[str] = [] + for module_name, module in list(self.model.named_modules()): + if not module_name or isinstance(module, BaseQuantLinear) or not hasattr(module, "weight"): + continue + + checkpoint_tensors = None + if isinstance(self.turtle_model, LazyTurtle): + checkpoint_tensors = self.turtle_model.checkpoint_tensors_for_submodule( + target_model=self.model, + target_submodule=module, + recurse=False, + ) + if not checkpoint_tensors: + checkpoint_tensors = dict(module.state_dict(keep_vars=True)) + weight = checkpoint_tensors.get("weight") + if not isinstance(weight, torch.Tensor): + continue + + decoder_kind = self._decoder_weight_format( + weight=weight, + checkpoint_tensors=checkpoint_tensors, + ) + if decoder_kind is None: + continue + + decoded_module = self._build_decoder_quant_source_module( + module, + checkpoint_tensors=checkpoint_tensors, + target_dtype=decoder_cfg.target_dtype, + ) + decoded_prefixes.append(module_name) + for key, tensor in decoded_module.state_dict().items(): + decoded_entries[f"{module_name}.{key}"] = tensor.detach().cpu() + + return decoded_entries, decoded_prefixes + # returns all the loaded qlinear types, returns empty [] if non-found def kernels(self) -> List[Type[BaseQuantLinear]]: @@ -1035,7 +1528,7 @@ def _auto_configure_lookahead(self) -> None: if not isinstance(self.model, nn.Module): return - quant_modules = [module for module in self.model.modules() if isinstance(module, TorchQuantLinear)] + quant_modules = [module for module in self.model.modules() if isinstance(module, TorchLinear)] if not quant_modules: return @@ -1102,7 +1595,7 @@ def serve_wait_until_ready(self, timeout: int = 30, check_interval: float = 0.1) if self.server is not None: self.server.wait_until_ready(timeout=timeout, check_interval=check_interval) - def before_model_load(self, load_quantized_model): + def before_model_load(self, model_local_path: str, load_quantized_model: bool): pass def after_model_load(self, model, load_quantized_model): @@ -1117,6 +1610,42 @@ def pre_quantize_generate_hook_end(self): # offload_to_disk(model=self.model, module=self.get_base_modules(model=self.model), disk_path=self.quantize_config.offload_to_disk_path) pass + def capture_first_layer_positional_inputs( + self, + args: tuple[Any, ...], + kwargs: Dict[str, Any], + batch_device: torch.device, + ) -> List[torch.Tensor]: + """Normalize first-layer positional inputs so cached forwards can replay decoder layers directly.""" + + if kwargs.get("hidden_states") is not None: + return [move_to(kwargs["hidden_states"], device=batch_device)] + if args: + return [move_to(args[0], device=batch_device)] + return [] + + def capture_first_layer_input_kwargs( + self, + args: tuple[Any, ...], + kwargs: Dict[str, Any], + batch_device: torch.device, + layer_input_kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + """Allow model definitions to persist extra first-layer replay metadata during calibration capture.""" + + return layer_input_kwargs + + def prepare_layer_replay_kwargs( + self, + layer: nn.Module, + layer_input: List[torch.Tensor], + additional_inputs: Dict[str, Any], + target_device: torch.device, + ) -> Dict[str, Any]: + """Allow model definitions to refresh layer-specific kwargs before cached layer replay.""" + + return additional_inputs + def lm_head_pre_quantize_generate_hook(self, inputs: List[List[torch.tensor]]) -> List[List[torch.tensor]]: if self.pre_lm_head_norm_module: norm, _ = get_module_by_name_prefix(self.model, [self.pre_lm_head_norm_module]) @@ -1130,7 +1659,7 @@ def lm_head_pre_quantize_generate_hook(self, inputs: List[List[torch.tensor]]) - return inputs def pre_quantize(self, module: nn.Module) -> nn.Module: - if get_device(module) == META: + if get_device(module) == META or _module_has_meta_tensors(module): return self.shell_module_materialize( target_submodule=module, device=self.quantize_config.device, @@ -1144,6 +1673,522 @@ def post_quantize(self, module: nn.Module) -> nn.Module: #return self.offload_to_disk(module=module) return move_to(module, device=CPU) + def _replace_live_submodule( + self, + current_submodule: nn.Module, + replacement: nn.Module, + ) -> nn.Module: + """Replace one live model submodule in place and return the replacement.""" + + module_path = _get_qualified_name(self.model, current_submodule) + parent, leaf = _get_parent_and_leaf_by_path(self.model, module_path) + setattr(parent, leaf, replacement) + return replacement + + def _build_decoder_quant_source_module( + self, + target_submodule: nn.Module, + *, + checkpoint_tensors: Optional[Dict[str, torch.Tensor]] = None, + target_dtype: torch.dtype, + ) -> nn.Module: + """Build a dense CPU source module from checkpoint tensors for quantization.""" + + quant_source = copy.deepcopy(target_submodule) + if _module_has_meta_tensors(quant_source): + quant_source = quant_source.to_empty(device=CPU) + else: + quant_source = quant_source.to(device=CPU) + weight = None if checkpoint_tensors is None else checkpoint_tensors.get("weight") + if isinstance(weight, torch.Tensor) and hasattr(quant_source, "weight"): + decoder_kind = self._decoder_weight_format( + weight=weight, + checkpoint_tensors=checkpoint_tensors, + ) + result_shape = tuple(getattr(quant_source.weight, "shape", weight.shape)) + scale = ( + self._decoder_fp4_effective_scale( + checkpoint_tensors=checkpoint_tensors, + result_shape=result_shape, + ) + if decoder_kind == "fp4" + else self._decoder_scale_tensor( + scale_tensor=checkpoint_tensors.get("weight_scale"), + result_shape=result_shape, + ) + ) + scale_inv = None + if not isinstance(scale, torch.Tensor): + scale_inv = self._decoder_scale_tensor( + scale_tensor=checkpoint_tensors.get("weight_scale_inv"), + result_shape=result_shape, + ) + if decoder_kind == "fp8": + decoded_weight = dequantize_fp8( + weight, + scale=scale if isinstance(scale, torch.Tensor) else None, + scale_inv=scale_inv if isinstance(scale_inv, torch.Tensor) else None, + axis=None, + target_dtype=target_dtype, + ) + elif decoder_kind == "fp4": + decoded_weight = dequantize_f4_e2m1( + weight, + scale=scale if isinstance(scale, torch.Tensor) else None, + scale_inv=scale_inv if isinstance(scale_inv, torch.Tensor) else None, + axis=None, + target_dtype=target_dtype, + ) + else: + decoded_weight = weight.to(dtype=target_dtype) + + existing_weight = getattr(quant_source, "weight") + quant_source.weight = nn.Parameter( + decoded_weight.to(device=CPU, dtype=target_dtype), + requires_grad=getattr(existing_weight, "requires_grad", False), + ) + + bias = None if checkpoint_tensors is None else checkpoint_tensors.get("bias") + if isinstance(bias, torch.Tensor) and getattr(quant_source, "bias", None) is not None: + existing_bias = quant_source.bias + quant_source.bias = nn.Parameter( + bias.to(device=CPU, dtype=target_dtype), + requires_grad=getattr(existing_bias, "requires_grad", False), + ) + + quant_source = quant_source.to(dtype=target_dtype) + quant_source.eval() + setattr(quant_source, "target_device", torch.device(CPU)) + return quant_source + + def _decoder_block_size(self) -> Optional[tuple[int, int]]: + """Read the checkpoint's floatx block size metadata when present.""" + + quant_config = getattr(getattr(self.model, "config", None), "quantization_config", None) + if isinstance(quant_config, dict): + block_size = quant_config.get("weight_block_size") + else: + block_size = getattr(quant_config, "weight_block_size", None) + if isinstance(block_size, (list, tuple)) and len(block_size) == 2: + return int(block_size[0]), int(block_size[1]) + return None + + def _decoder_quant_method_name(self) -> str: + """Return the checkpoint quantizer family declared in model config.""" + + quant_config = getattr(getattr(self.model, "config", None), "quantization_config", None) + if isinstance(quant_config, dict): + value = quant_config.get("quant_method") + else: + value = getattr(quant_config, "quant_method", None) + return str(value or "").strip().lower() + + def _uses_modelopt_runtime(self) -> bool: + """Return ``True`` when the checkpoint declares ModelOpt runtime semantics.""" + + return self._decoder_quant_method_name() == "modelopt" + + def _modelopt_activation_quantization_mode(self) -> Optional[str]: + """Describe unsupported ModelOpt activation quantization metadata when present.""" + + quant_config = getattr(getattr(self.model, "config", None), "quantization_config", None) + if not isinstance(quant_config, dict): + return None + + config_groups = quant_config.get("config_groups") + if isinstance(config_groups, dict): + for group_cfg in config_groups.values(): + if not isinstance(group_cfg, dict): + continue + input_activations = group_cfg.get("input_activations") + if isinstance(input_activations, dict): + num_bits = input_activations.get("num_bits") + if isinstance(num_bits, (int, float)) and int(num_bits) < 16: + return "input_activations" + + kv_cache_scheme = quant_config.get("kv_cache_scheme") + if isinstance(kv_cache_scheme, dict): + num_bits = kv_cache_scheme.get("num_bits") + if isinstance(num_bits, (int, float)) and int(num_bits) < 16: + return "kv_cache_scheme" + + if isinstance(self.turtle_model, LazyTurtle): + keys = self.turtle_model._weight_map.keys() + else: + keys = self.model.state_dict().keys() + if any(str(name).endswith((".input_scale", ".k_scale", ".v_scale")) for name in keys): + return "checkpoint_scales" + return None + + def _configure_modelopt_runtime(self) -> None: + """Reject unsupported ModelOpt activation quantization at load time.""" + + if not self._uses_modelopt_runtime(): + return + + unsupported_mode = self._modelopt_activation_quantization_mode() + if unsupported_mode is not None: + log.error("GPT-QModel currently does not support loading of activation quantized models") + raise ValueError( + "GPT-QModel currently does not support loading of activation quantized models. " + "GPTQModel does not support loading ModelOpt checkpoints with activation quantization. " + "Only dense-activation weight-only variants such as W8A16/FP8 and W4A16/FP4 are supported. " + f"Detected unsupported metadata: {unsupported_mode}." + ) + + def _decoder_scale_tensor( + self, + *, + scale_tensor: Optional[torch.Tensor], + result_shape: tuple[int, ...], + ) -> Optional[torch.Tensor]: + """Expand padded floatx block grids to the dense weight shape when needed.""" + + if not isinstance(scale_tensor, torch.Tensor): + return None + if scale_tensor.ndim != 2 or len(result_shape) != 2: + return scale_tensor + + rows, cols = result_shape + blocks_r, blocks_c = scale_tensor.shape + if rows % blocks_r == 0 and cols % blocks_c == 0: + return scale_tensor + + block_size = self._decoder_block_size() + if block_size is None: + return scale_tensor + + block_rows, block_cols = block_size + if blocks_r * block_rows < rows or blocks_c * block_cols < cols: + return scale_tensor + + expanded = scale_tensor.repeat_interleave(block_rows, dim=0) + expanded = expanded.repeat_interleave(block_cols, dim=1) + return expanded[:rows, :cols].contiguous() + + def _decoder_fp4_effective_scale( + self, + *, + checkpoint_tensors: Dict[str, torch.Tensor], + result_shape: tuple[int, ...], + ) -> Optional[torch.Tensor]: + """Resolve NVFP4 weight scales, including ModelOpt's secondary global scale.""" + + scale = self._decoder_scale_tensor( + scale_tensor=checkpoint_tensors.get("weight_scale"), + result_shape=result_shape, + ) + if not isinstance(scale, torch.Tensor): + return None + scale_2 = checkpoint_tensors.get("weight_scale_2") + if isinstance(scale_2, torch.Tensor): + scale = scale.to(torch.float32) * scale_2.to(torch.float32) + return scale + + def _decoder_weight_format( + self, + *, + weight: torch.Tensor, + checkpoint_tensors: Dict[str, torch.Tensor], + ) -> Optional[str]: + """Infer which floatx decoder matches one checkpoint weight tensor.""" + + if weight.dtype in available_float8_dtypes(): + return "fp8" + if is_fp4_packed_dtype(weight.dtype): + return "fp4" + if weight.dtype is not torch.uint8 or not isinstance(checkpoint_tensors.get("weight_scale"), torch.Tensor): + return None + if isinstance(checkpoint_tensors.get("weight_scale_2"), torch.Tensor): + return "fp4" + + quant_config = getattr(getattr(self.model, "config", None), "quantization_config", None) + if isinstance(quant_config, dict): + format_name = quant_config.get("format") or quant_config.get("quant_method") + else: + format_name = getattr(quant_config, "format", None) or getattr(quant_config, "quant_method", None) + if str(format_name or "").strip().lower() in {"nvfp4", "fp4"}: + return "fp4" + return None + + def _build_decoder_forward_module( + self, + *, + quant_source: nn.Module, + device: torch.device, + ) -> nn.Module: + """Clone the decoded quant source into a live forward module on ``device``.""" + + forward_module = copy.deepcopy(quant_source) + forward_module = forward_module.to(device=device) + forward_module.eval() + setattr(forward_module, "target_device", torch.device(device)) + return forward_module + + def _infer_fp8_forward_layout( + self, + *, + weight: torch.Tensor, + scale_inv: torch.Tensor, + ) -> tuple[str, Optional[tuple[int, int]]]: + """Infer the FP8 scale layout needed to rebuild a TorchFP8Linear wrapper.""" + + if scale_inv.numel() == 1: + return "tensor", None + if scale_inv.ndim == 1 and scale_inv.shape[0] == weight.shape[0]: + return "row", None + return "block", infer_block_shape(tuple(weight.shape), scale_inv) + + def _infer_fp4_forward_block_size( + self, + *, + target_submodule: nn.Module, + scale: torch.Tensor, + ) -> int: + """Infer the NVFP4 block size used along the input-feature axis.""" + + block_size = self._decoder_block_size() + if block_size is not None and target_submodule.in_features % block_size[1] == 0: + return int(block_size[1]) + + if scale.ndim >= 1 and scale.shape[-1] > 0 and target_submodule.in_features % scale.shape[-1] == 0: + return int(target_submodule.in_features // scale.shape[-1]) + + raise ValueError( + f"Cannot infer FP4 block size for in_features={target_submodule.in_features} " + f"and scale shape={tuple(scale.shape)}." + ) + + def _build_fp8_forward_module( + self, + *, + target_submodule: nn.Module, + checkpoint_tensors: Dict[str, torch.Tensor], + device: torch.device, + target_dtype: torch.dtype, + ) -> Optional[nn.Module]: + """Rebuild one linear submodule as a TorchFP8Linear forward wrapper.""" + + if not isinstance(target_submodule, nn.Linear): + return None + + weight = checkpoint_tensors.get("weight") + if not isinstance(weight, torch.Tensor): + return None + + scale_inv = self._decoder_scale_tensor( + scale_tensor=checkpoint_tensors.get("weight_scale_inv"), + result_shape=tuple(weight.shape), + ) + if not isinstance(scale_inv, torch.Tensor): + # ModelOpt-style FP8 checkpoints store direct scales instead of inverse scales; + # normalize them here so TorchFP8Linear can use one consistent metadata form. + scale = self._decoder_scale_tensor( + scale_tensor=checkpoint_tensors.get("weight_scale"), + result_shape=tuple(weight.shape), + ) + if not isinstance(scale, torch.Tensor): + return None + scale = scale.to(torch.float32) + tiny = torch.finfo(torch.float32).tiny + scale_inv = torch.where( + scale != 0, + torch.reciprocal(scale), + torch.full_like(scale, 1.0 / tiny), + ) + + format_name = str(weight.dtype).split(".")[-1] + try: + # Infer the wrapper layout from the normalized inverse-scale tensor so native + # FP8 execution works for either checkpoint convention. + weight_scale_method, weight_block_size = self._infer_fp8_forward_layout( + weight=weight, + scale_inv=scale_inv, + ) + forward_module = TorchFP8Linear( + bits=8, + group_size=-1, + desc_act=False, + sym=True, + in_features=target_submodule.in_features, + out_features=target_submodule.out_features, + bias=target_submodule.bias is not None, + pack_dtype=torch.int32, + format=format_name, + weight_scale_method=weight_scale_method, + weight_block_size=weight_block_size, + register_buffers=False, + ).to(device=device) + except Exception: + # Some checkpoints use padded or otherwise non-TorchFP8Linear layouts and must + # fall back to the decoded dense path even on native-FP8-capable GPUs. + return None + forward_module.register_buffer("weight", weight.to(device=device)) + forward_module.register_buffer( + "weight_scale_inv", + scale_inv.to(device=device, dtype=torch.float32), + ) + + bias = checkpoint_tensors.get("bias") + if isinstance(bias, torch.Tensor): + forward_module.register_buffer( + "bias", + bias.to(device=device, dtype=target_dtype), + ) + else: + forward_module.bias = None + + forward_module.eval() + setattr(forward_module, "target_device", torch.device(device)) + return forward_module + + def _build_fp4_forward_module( + self, + *, + target_submodule: nn.Module, + checkpoint_tensors: Dict[str, torch.Tensor], + device: torch.device, + target_dtype: torch.dtype, + ) -> Optional[nn.Module]: + """Rebuild one linear submodule as a native NVFP4 forward wrapper.""" + + if not isinstance(target_submodule, nn.Linear): + return None + + weight = checkpoint_tensors.get("weight") + scale = self._decoder_fp4_effective_scale( + checkpoint_tensors=checkpoint_tensors, + result_shape=(target_submodule.out_features, target_submodule.in_features), + ) + if not isinstance(weight, torch.Tensor) or not isinstance(scale, torch.Tensor): + return None + + try: + block_size = self._infer_fp4_forward_block_size( + target_submodule=target_submodule, + scale=scale, + ) + forward_module = TorchFP4Linear( + in_features=target_submodule.in_features, + out_features=target_submodule.out_features, + weight=weight.to(device=device), + weight_scale=scale.to(device=device), + weight_block_size=block_size, + orig_dtype=target_dtype, + bias=checkpoint_tensors.get("bias").to(device=device, dtype=target_dtype) + if isinstance(checkpoint_tensors.get("bias"), torch.Tensor) + else None, + ) + except Exception: + return None + + forward_module.eval() + setattr(forward_module, "target_device", torch.device(device)) + return forward_module + + def _record_auto_module_decoder_event( + self, + *, + named_module: "NamedModule", + device: torch.device, + forward_mode: str, + source_dtype: torch.dtype, + target_dtype: torch.dtype, + ) -> None: + """Store one auto-decoder decision so tests can assert the chosen path.""" + + if named_module.state.get("_auto_module_decoder_event_recorded"): + return + + self.auto_module_decoder_events.append( + { + "module": named_module.full_name, + "device": str(device), + "forward_mode": forward_mode, + "source_dtype": str(source_dtype).split(".")[-1], + "target_dtype": str(target_dtype).split(".")[-1], + } + ) + named_module.state["_auto_module_decoder_event_recorded"] = True + + def _prepare_auto_decoder_forward_module( + self, + *, + target_submodule: nn.Module, + device: torch.device, + named_module: "NamedModule", + ) -> nn.Module: + """Swap one decoded shell module to an FP8 forward view when supported.""" + + decoder_plan = named_module.state.get("auto_module_decoder") + turtle_model = self.turtle_model + if not isinstance(decoder_plan, dict) or turtle_model is None: + return target_submodule + + checkpoint_tensors = turtle_model.checkpoint_tensors_for_submodule( + target_model=self.model, + target_submodule=target_submodule, + recurse=False, + ) + weight = checkpoint_tensors.get("weight") + if not isinstance(weight, torch.Tensor): + return target_submodule + + decoder_kind = self._decoder_weight_format( + weight=weight, + checkpoint_tensors=checkpoint_tensors, + ) + if decoder_kind is None: + return target_submodule + + target_dtype = decoder_plan.get("target_dtype", target_submodule.weight.dtype) + forward_policy = str(decoder_plan.get("passthrough_forward_policy", "native")).strip().lower() + if not isinstance(named_module.state.get("quant_source_module"), nn.Module): + named_module.state["quant_source_module"] = self._build_decoder_quant_source_module( + target_submodule, + checkpoint_tensors=checkpoint_tensors, + target_dtype=target_dtype, + ) + + forward_mode = "decode" + replacement = target_submodule + if forward_policy != "decode" and decoder_kind == "fp8" and device_supports_dtype(device, weight.dtype, require_validation=False): + fp8_module = self._build_fp8_forward_module( + target_submodule=target_submodule, + checkpoint_tensors=checkpoint_tensors, + device=device, + target_dtype=target_dtype, + ) + if fp8_module is not None: + replacement = self._replace_live_submodule(target_submodule, fp8_module) + forward_mode = "native" + elif forward_policy != "decode" and decoder_kind == "fp4" and device_supports_native_fp4(device, require_validation=False): + fp4_module = self._build_fp4_forward_module( + target_submodule=target_submodule, + checkpoint_tensors=checkpoint_tensors, + device=device, + target_dtype=target_dtype, + ) + if fp4_module is not None: + replacement = self._replace_live_submodule(target_submodule, fp4_module) + forward_mode = "native" + if forward_mode == "decode": + decoded_forward = self._build_decoder_forward_module( + quant_source=named_module.state["quant_source_module"], + device=device, + ) + replacement = self._replace_live_submodule(target_submodule, decoded_forward) + + named_module.state["auto_module_decoder_forward_mode"] = forward_mode + self._record_auto_module_decoder_event( + named_module=named_module, + device=torch.device(device), + forward_mode=forward_mode, + source_dtype=weight.dtype, + target_dtype=target_dtype, + ) + return replacement + def move_embed(self, device: str): for embed_module_name in self.get_base_modules(self.model): embed_module, _ = get_module_by_name_prefix(self.model, embed_module_name) @@ -1346,172 +2391,62 @@ def format_nodes(nodes): # print("DEBUG AWQ NODES:", format_nodes(nodes)) return nodes - def _clone_model_init_kwargs(self, source: PreTrainedModel) -> Dict[str, Any]: - kwargs = getattr(source, "_model_init_kwargs", {}) or {} - if isinstance(kwargs, dict): - return dict(kwargs) - return copy.deepcopy(kwargs) - - def _resolve_turtle_reload_threshold(self) -> int: - if not getattr(self.quantize_config, "offload_to_disk", False): - return 0 - - default_bytes = 512 * 1024 ** 2 #512MB - raw = os.getenv("GPTQMODEL_RELOAD_THRESHOLD") - if raw is None or raw.strip() == "": - return default_bytes - - value = raw.strip().lower() - if value in {"0", "off", "disable", "disabled", "none"}: - return 0 - - units = { - "b": 1, - "kb": 1024, - "mb": 1024 ** 2, - "gb": 1024 ** 3, - "tb": 1024 ** 4, - } - - match = re.match(r"^([0-9]*\.?[0-9]+)\s*([a-z]*)$", value) - if match is None: - log.warn( - "GPTQMODEL_RELOAD_THRESHOLD value `%s` is invalid; defaulting to 512MB.", - raw, - ) - return default_bytes - - amount = float(match.group(1)) - unit = match.group(2) or "b" - multiplier = units.get(unit, None) - if multiplier is None: - log.warn( - "GPTQMODEL_RELOAD_THRESHOLD unit `%s` is unsupported; defaulting to bytes.", - unit, - ) - multiplier = 1 - - threshold = int(amount * multiplier) - if threshold < 0: - threshold = 0 - return threshold - - def _estimate_module_bytes(self, module: nn.Module) -> int: - if module is None: - return 0 - - total = 0 - seen: Set[int] = set() - tensors = list(module.parameters(recurse=True)) + list(module.buffers(recurse=True)) - for tensor in tensors: - if not isinstance(tensor, torch.Tensor): - continue - if tensor.device.type == "meta": - continue - try: - ptr = tensor.data_ptr() - except (RuntimeError, AssertionError): - ptr = None - if ptr is not None: - if ptr in seen: - continue - seen.add(ptr) - total += tensor.numel() * tensor.element_size() - return total - - def _maybe_auto_reload_after_alias( - self, - module: nn.Module, - target_submodule: nn.Module, - ) -> None: - if self.turtle_model is None: - return - - threshold = self._turtle_reload_threshold_bytes - if threshold <= 0: - return - - module_id = id(module) - if module_id in self._turtle_materialized_ids: - return - - bytes_added = self._estimate_module_bytes(module) - self._turtle_materialized_ids.add(module_id) - - if bytes_added <= 0: - return - - self._turtle_reload_accum_bytes += bytes_added - - if self._turtle_reload_accum_bytes >= threshold: - label = ( - getattr(target_submodule, "full_name", None) - or getattr(target_submodule, "name", None) - or getattr(module, "full_name", None) - or module.__class__.__name__ - ) - self.reload_turtle_model(source=f"auto:{label}") - self._turtle_reload_accum_bytes = 0 - - def reload_turtle_model(self, *, source: Optional[str] = None) -> None: - if self.quantize_config.offload_to_disk is False: - return - - timer = getattr(self, "quant_region_timer", None) - timing_ctx = timer.measure("model_reload", source=source) if timer else nullcontext() - - with timing_ctx: - def _do_reload(): - with self._turtle_lock: - turtle_model = self.turtle_model - model_local_path = self.model_local_path - loader = self.loader - - assert turtle_model is not None and model_local_path is not None - - reload_kwargs = self._clone_model_init_kwargs(turtle_model) - config = turtle_model.config - del turtle_model - - new_model = loader.from_pretrained( - model_local_path, - config=config, - low_cpu_mem_usage=True, - **reload_kwargs, - ) - new_model._model_init_kwargs = reload_kwargs - new_model.eval() - self.turtle_model = new_model - self._turtle_reload_accum_bytes = 0 - reload_spinner = log.spinner(title="Turtle model reloading...", interval=0.1) - try: - DEVICE_THREAD_POOL.submit("model_loader:cpu", _do_reload).result() - finally: - reload_spinner.close() - - # transfer actually materizlied module from turtle (real) to shell + # Materialize the target shell module from the lazy turtle source on the requested device. def shell_module_materialize( self, target_submodule: torch.nn.Module, device: torch.device, non_blocking: bool = False, + role: str = "default", + named_module: Optional["NamedModule"] = None, ) -> torch.nn.Module: with self._turtle_lock: - turtle_model = self.turtle_model + if role == "quant_source" and named_module is not None: + quant_source = named_module.state.get("quant_source_module") + if not isinstance(quant_source, nn.Module): + decoder_plan = named_module.state.get("auto_module_decoder") or {} + target_dtype = decoder_plan.get( + "target_dtype", + getattr(getattr(target_submodule, "weight", None), "dtype", torch.float16), + ) + checkpoint_tensors = None + if isinstance(self.turtle_model, LazyTurtle): + checkpoint_tensors = self.turtle_model.checkpoint_tensors_for_submodule( + target_model=self.model, + target_submodule=target_submodule, + recurse=False, + ) + quant_source = self._build_decoder_quant_source_module( + target_submodule, + checkpoint_tensors=checkpoint_tensors, + target_dtype=target_dtype, + ) + named_module.state["quant_source_module"] = quant_source + + module = self._replace_live_submodule(target_submodule, quant_source) + if get_device(module) != device: + module.to(device) + return module + turtle_model = self.turtle_model if turtle_model is None: if get_device(target_submodule) != device: target_submodule.to(device) + module = target_submodule + else: + module = alias_from_turtle_for_submodule( + target_model=self.model, + turtle_model=turtle_model, + target_submodule=target_submodule, + device=device, + ) - return target_submodule - - module = alias_from_turtle_for_submodule( - target_model=self.model, - turtle_model=turtle_model, - target_submodule=target_submodule, - device=device, - ) - self._maybe_auto_reload_after_alias(module, target_submodule) + if role == "forward" and named_module is not None: + module = self._prepare_auto_decoder_forward_module( + target_submodule=module, + device=torch.device(device), + named_module=named_module, + ) return module ## overrides nn.module.train() @@ -1564,10 +2499,7 @@ def build_layer_modules(cls, tree, include_capture_only: bool = False): group_seq = count() def _parse_token(token: str) -> tuple[str, List[str]]: - parts = token.split(":") - name = parts[0] - flags = [p for p in parts[1:] if p] - return name, flags + return cls._parse_module_flags(token) def _group_from_flags(flags: List[str]) -> int: for flag in flags: @@ -1775,14 +2707,14 @@ def get_base_modules(cls, model): assert sharp_idx > 0, "failed to get_base_modules" # root_path = ["model"] or ["model", "language_model"] - root_path = tree[:sharp_idx-1] + root_path = [cls._parse_module_flags(node)[0] if isinstance(node, str) else node for node in tree[:sharp_idx-1]] out = [] # Traverse each layer in root_path for i in range(len(root_path)): path = root_path[:i + 1] base = model - exclude = tree[len(path)] + exclude = cls._parse_module_flags(tree[len(path)])[0] if isinstance(tree[len(path)], str) else tree[len(path)] for node in path: base = getattr(base, node) @@ -1809,6 +2741,7 @@ def generate_layers_modules_tree_simple(self, node): if isinstance(node, dict): new_dict = {} for k, v in node.items(): + clean_key = self._parse_module_flags(k)[0] if isinstance(k, str) else k # Expand tuple-of-strings blocks (special handling) if isinstance(v, (tuple, list)) and all(isinstance(x, str) for x in v): # Rule 1: check if ALL entries are :! @@ -1816,16 +2749,16 @@ def generate_layers_modules_tree_simple(self, node): continue # skip this parent entirely # Rule 2: strip :! and :digit markers - cleaned = tuple(x.split(":")[0] for x in v) - new_dict[k] = cleaned + cleaned = tuple(self._parse_module_flags(x)[0] for x in v) + new_dict[clean_key] = cleaned else: # Recurse deeper - new_dict[k] = self.generate_layers_modules_tree_simple(v) + new_dict[clean_key] = self.generate_layers_modules_tree_simple(v) return new_dict # If it's a plain string (unlikely here), strip markers if isinstance(node, str): - return node.split(":")[0] + return self._parse_module_flags(node)[0] # For other types, return as-is return node @@ -1847,8 +2780,11 @@ def __getattr__(self, item): def _auto_detect_module_tree(self, model: PreTrainedModel, quant_method: METHOD): log.warn("Model not yet support, attempting Module Tree AutoCompat...") - if quant_method != METHOD.GPTQ: - log.warn(f"Module Tree AutoCompat: Failed, quant_method={quant_method}, only support GPTQ") + if quant_method not in {METHOD.GPTQ, METHOD.GGUF, METHOD.FP8, METHOD.BITSANDBYTES, METHOD.EXL3, METHOD.PARO}: + log.warn( + f"Module Tree AutoCompat: Failed, quant_method={quant_method}, " + "only support GPTQ/GGUF/FP8/BITSANDBYTES/EXL3/PAROQUANT" + ) return None def _get(path): diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index b4dd74a88..1102b6152 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -4,11 +4,17 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from packaging.version import Version +from transformers import __version__ as TRANSFORMERS_VERSION + # Many model architectures inherit from LlamaGPTQ, so it’s necessary to import llama first to avoid circular imports. from .llama import LlamaQModel # other model +from .afmoe import AfMoeQModel +from .apertus import ApertusQModel from .baichuan import BaiChuanQModel +from .bailing_moe import BailingMoeQModel from .bloom import BloomQModel from .brumby import BrumbyQModel from .chatglm import ChatGLMQModel @@ -24,53 +30,75 @@ from .exaone4 import Exaone4QModel from .ernie4_5 import Ernie4_5QModel from .ernie4_5_moe import Ernie4_5_MoeQModel +from .falcon_h1 import FalconH1QModel from .gemma2 import Gemma2QModel -from .gemma3 import Gemma3QModel +from .gemma3 import Gemma3ForConditionalGenerationGPTQ, Gemma3QModel +from .gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel from .glm import GlmQModel +from .glm4_moe import GLM4MoEGPTQ +from .glm4_moe_lite import Glm4MoeLiteQModel +from .glm4v import Glm4vGPTQ +from .glm_moe_dsa import GlmMoeDsaQModel from .gpt2 import GPT2QModel from .gpt_bigcode import GptBigCodeQModel from .gpt_neo import GptNeoQModel from .gpt_neox import GPTNeoXQModel +from .gpt_oss import GPTOSSGPTQ from .gptj import GptJQModel +from .granitemoehybrid import GraniteMoeHybridQModel from .grinmoe import GrinMoeQModel from .hymba import HymbaQModel from .instella import InstellaQModel from .internlm import InternLMQModel from .internlm2 import InternLM2QModel +from .klear import KlearQModel +from .lfm2_moe import LFM2MoeQModel +from .llada2 import LLaDA2MoeQModel from .llama4 import Llama4QModel +from .llava_qwen2 import LlavaQwen2QModel +from .longcat_flash import LongCatFlashQModel from .mimo import MimoQModel +from .minicpm import MiniCPMGPTQ from .minicpm3 import MiniCpm3QModel +from .minicpm_o import MiniCPMOQModel +from .minicpm_v import MiniCPMVQModel from .minimax_m2 import MiniMaxM2GPTQ +from .mistral3 import Mistral3GPTQ from .mixtral import MixtralQModel from .mllama import MLlamaQModel from .mobilellm import MobileLLMQModel from .moss import MossQModel from .mpt import MptQModel +from .nemotron_h import NemotronHQModel +from .olmoe import OlmoeGPTQ from .opt import OptQModel from .ovis import OvisQModel +from .ovis2 import Ovis2QModel +from .pangu_alpha import PanguAlphaQModel from .phi import PhiQModel -from .phi3 import Phi3QModel +from .phi3 import Phi3QModel, PhiMoEGPTQForCausalLM +from .phi4 import Phi4MMGPTQ from .qwen import QwenQModel from .qwen2 import Qwen2QModel +from .qwen2_5_omni import Qwen2_5_OmniGPTQ from .qwen2_5_vl import Qwen2_5_VLQModel from .qwen2_moe import Qwen2MoeQModel from .qwen2_vl import Qwen2VLQModel from .qwen3 import Qwen3QModel from .qwen3_moe import Qwen3MoeQModel +from .qwen3_next import Qwen3NextGPTQ +from .qwen3_omni_moe import Qwen3OmniMoeGPTQ from .qwen3_vl import Qwen3_VLQModel from .rw import RwgQModel from .starcoder2 import Starcoder2QModel from .telechat2 import TeleChat2QModel -from .xverse import XverseQModel -from .falcon_h1 import FalconH1QModel -from .pangu_alpha import PanguAlphaQModel -from .longcat_flash import LongCatFlashQModel -from .apertus import ApertusQModel -from .klear import KlearQModel -from .llava_qwen2 import LlavaQwen2QModel -from .nemotron_h import NemotronHQModel -from .qwen3_omni_moe import Qwen3OmniMoeGPTQ -from .mistral3 import Mistral3GPTQ -from .afmoe import AfMoeQModel -from .glm4v import Glm4vGPTQ from .voxtral import VoxtralGPTQ +from .xverse import XverseQModel + +TRANSFORMERS_SUPPORTS_QWEN3_5 = Version(TRANSFORMERS_VERSION) >= Version("5.2.0") +if TRANSFORMERS_SUPPORTS_QWEN3_5: + from .qwen3_5 import Qwen3_5QModel + from .qwen3_5_moe import Qwen3_5_MoeQModel +else: + Qwen3_5QModel = None + Qwen3_5_MoeQModel = None diff --git a/gptqmodel/models/definitions/baichuan.py b/gptqmodel/models/definitions/baichuan.py index 8d6b4a7c7..8dcf075ce 100644 --- a/gptqmodel/models/definitions/baichuan.py +++ b/gptqmodel/models/definitions/baichuan.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import torch + from ..base import BaseQModel @@ -20,3 +22,79 @@ class BaiChuanQModel(BaseQModel): "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), } ] + + @staticmethod + def _set_non_persistent_buffer(module, name, tensor): + if not isinstance(tensor, torch.Tensor): + return + + if name not in getattr(module, "_buffers", {}) and hasattr(module, name): + delattr(module, name) + + if name in getattr(module, "_buffers", {}): + module._buffers[name] = tensor + non_persistent = getattr(module, "_non_persistent_buffers_set", None) + if isinstance(non_persistent, set): + non_persistent.add(name) + return + + module.register_buffer(name, tensor, persistent=False) + + @staticmethod + def _build_rotary_cache(inv_freq, max_seq_len): + inv_freq = inv_freq.to(dtype=torch.float32) + t = torch.arange(max_seq_len, device=inv_freq.device, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + return ( + emb.cos()[None, None, :, :].to(torch.float32), + emb.sin()[None, None, :, :].to(torch.float32), + ) + + def after_model_load(self, model, load_quantized_model=False): + model = super().after_model_load(model, load_quantized_model=load_quantized_model) + + layers = getattr(getattr(model, "model", None), "layers", None) + if layers is None: + return model + + for layer in layers: + rotary = getattr(getattr(layer, "self_attn", None), "rotary_emb", None) + if rotary is None: + continue + + inv_freq = getattr(rotary, "inv_freq", None) + max_seq_len = getattr(rotary, "max_seq_len_cached", None) + if max_seq_len is None: + max_seq_len = getattr(rotary, "max_position_embeddings", 2048) + + if not isinstance(inv_freq, torch.Tensor) or inv_freq.device.type == "meta": + if not isinstance(inv_freq, torch.Tensor): + continue + dim = inv_freq.numel() * 2 + base = getattr(rotary, "base", 10000) + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim) + ) + cos_cached, sin_cached = self._build_rotary_cache(inv_freq, max_seq_len) + else: + inv_freq = inv_freq.to(dtype=torch.float32) + cos_cached = getattr(rotary, "cos_cached", None) + sin_cached = getattr(rotary, "sin_cached", None) + if ( + not isinstance(cos_cached, torch.Tensor) + or not isinstance(sin_cached, torch.Tensor) + or cos_cached.device.type == "meta" + or sin_cached.device.type == "meta" + ): + cos_cached, sin_cached = self._build_rotary_cache(inv_freq, max_seq_len) + else: + cos_cached = cos_cached.to(dtype=torch.float32) + sin_cached = sin_cached.to(dtype=torch.float32) + + rotary.max_seq_len_cached = max_seq_len + self._set_non_persistent_buffer(rotary, "inv_freq", inv_freq) + self._set_non_persistent_buffer(rotary, "cos_cached", cos_cached) + self._set_non_persistent_buffer(rotary, "sin_cached", sin_cached) + + return model diff --git a/gptqmodel/models/definitions/base_qwen2_5_omni.py b/gptqmodel/models/definitions/base_qwen2_5_omni.py index c90db85c5..639b81490 100644 --- a/gptqmodel/models/definitions/base_qwen2_5_omni.py +++ b/gptqmodel/models/definitions/base_qwen2_5_omni.py @@ -19,9 +19,60 @@ from ..base import BaseQModel +def _patch_qwen2_5_omni_talker_prepare_inputs_for_generation(talker_cls): + original = getattr(talker_cls, "prepare_inputs_for_generation", None) + if original is None or getattr(original, "_gptqmodel_qwen2_5_omni_compat", False): + return + + def prepare_inputs_for_generation( + self, + input_ids, + input_text_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + thinker_reply_part=None, + cache_position=None, + position_ids=None, + use_cache=True, + pixel_values=None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + input_audio_features=None, + audio_feature_attention_mask=None, + audio_feature_lengths=None, + use_audio_in_video=False, + video_second_per_grid=None, + **kwargs, + ): + model_inputs = super(talker_cls, self).prepare_inputs_for_generation( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + use_cache=use_cache, + thinker_reply_part=thinker_reply_part, + input_text_ids=input_text_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + use_audio_in_video=use_audio_in_video, + audio_feature_lengths=audio_feature_lengths, + video_second_per_grid=video_second_per_grid, + **kwargs, + ) + model_inputs["position_ids"] = None + return model_inputs + + prepare_inputs_for_generation._gptqmodel_qwen2_5_omni_compat = True + talker_cls.prepare_inputs_for_generation = prepare_inputs_for_generation + + class BaseQwen2_5_OmniGPTQ(BaseQModel): ATTENTION_MASKS_REQUIRED_FOR_INPUT = True ATTENTION_MASKS_DTYPE = torch.long + require_monkeypatch = True INPUT_EMBEDDING_EXTRA_ARGS = { "return_audio": False, @@ -31,7 +82,7 @@ class BaseQwen2_5_OmniGPTQ(BaseQModel): pre_lm_head_norm_module = "thinker.model.norm" - require_pkgs = ["audioread>=3.1.0", "librosa>0.11.0", "av>=16.0.1"] + require_pkgs = ["audioread>=3.1.0", "librosa>=0.11.0", "av>=16.0.1"] module_tree = [ "thinker", @@ -50,6 +101,16 @@ class BaseQwen2_5_OmniGPTQ(BaseQModel): require_load_processor = True + def monkey_patch(self): + talker = getattr(getattr(self, "model", None), "talker", None) + if talker is None: + return + + if not getattr(type(talker), "__module__", "").startswith("transformers.models.qwen2_5_omni."): + return + + _patch_qwen2_5_omni_talker_prepare_inputs_for_generation(type(talker)) + def pre_quantize_generate_hook_start(self): # load speaker spk_path = os.path.join(self.model_local_path, "spk_dict.pt") diff --git a/gptqmodel/models/definitions/base_qwen2_vl.py b/gptqmodel/models/definitions/base_qwen2_vl.py index a97831746..443d2ed22 100644 --- a/gptqmodel/models/definitions/base_qwen2_vl.py +++ b/gptqmodel/models/definitions/base_qwen2_vl.py @@ -10,7 +10,7 @@ from ...utils.calibration import batched from ...utils.image import extract_vision_info, fetch_image -from ...utils.model import MODALITY, move_to +from ...utils.model import MODALITY, get_module, move_to from ...utils.offload import offload_to_disk from .._const import CPU from ..base import BaseQModel @@ -19,7 +19,7 @@ class BaseQwen2VLGPTQ(BaseQModel): loader = AutoModelForImageTextToText - pre_lm_head_norm_module = "model.norm" + pre_lm_head_norm_module = ["model.language_model.norm", "language_model.norm"] module_tree = [ "model", @@ -38,30 +38,73 @@ class BaseQwen2VLGPTQ(BaseQModel): require_load_processor = True + @classmethod + def extract_layers_node(cls): + return ["model.language_model.layers", "language_model.layers"] + + @classmethod + def get_base_modules(cls, model): + prefix, core_model = cls._resolve_multimodal_layout(model) + base_modules = [] + for name, _ in core_model.named_children(): + if name != "language_model": + base_modules.append(f"{prefix}.{name}" if prefix else name) + return base_modules + + @classmethod + def _resolve_multimodal_layout(cls, model): + for prefix in ("model", ""): + core_model = get_module(model, prefix) if prefix else model + if core_model is None: + continue + if hasattr(core_model, "language_model") and hasattr(core_model, "visual"): + return prefix, core_model + + raise AttributeError("Unable to resolve a Qwen VL core model with `language_model` and `visual` modules.") + + def _core_multimodal_model(self): + _, core_model = self._resolve_multimodal_layout(self.model) + return core_model + + def _materialize_core_module(self, parent, attr_name: str): + module = getattr(parent, attr_name) + if "_turtle_lock" not in self.__dict__ and "shell_module_materialize" not in self.__dict__: + setattr(parent, attr_name, move_to(module, device=self.quantize_config.device)) + return + setattr( + parent, + attr_name, + self.shell_module_materialize(module, self.quantize_config.device), + ) + def pre_quantize_generate_hook_start(self): - self.model.language_model.embed_tokens = move_to(self.model.language_model.embed_tokens, device=self.quantize_config.device) - self.model.language_model.rotary_emb = move_to(self.model.language_model.rotary_emb, device=self.quantize_config.device) - self.model.visual = move_to(self.model.visual, device=self.quantize_config.device) + core_model = self._core_multimodal_model() + language_model = core_model.language_model + self._materialize_core_module(core_model, "visual") + self._materialize_core_module(language_model, "embed_tokens") + self._materialize_core_module(language_model, "rotary_emb") def pre_quantize_generate_hook_end(self): + core_model = self._core_multimodal_model() + language_model = core_model.language_model if self.quantize_config.offload_to_disk: - offload_to_disk(model=self.model.language_model, - module=self.model.language_model.embed_tokens, + offload_to_disk(model=language_model, + module=language_model.embed_tokens, disk_path=self.quantize_config.offload_to_disk_path, ) - offload_to_disk(model=self.model.language_model, - module=self.model.language_model.rotary_emb, + offload_to_disk(model=language_model, + module=language_model.rotary_emb, disk_path=self.quantize_config.offload_to_disk_path, ) - offload_to_disk(model=self.model, - module=self.model.visual, + offload_to_disk(model=core_model, + module=core_model.visual, disk_path=self.quantize_config.offload_to_disk_path, ) return - self.model.language_model.embed_tokens = move_to(self.model.language_model.embed_tokens, device=CPU) - self.model.language_model.rotary_emb = move_to(self.model.language_model.rotary_emb, device=CPU) - self.model.visual = move_to(self.model.visual, device=CPU) + language_model.embed_tokens = move_to(language_model.embed_tokens, device=CPU) + language_model.rotary_emb = move_to(language_model.rotary_emb, device=CPU) + core_model.visual = move_to(core_model.visual, device=CPU) @staticmethod def process_vision_info( diff --git a/gptqmodel/models/definitions/base_qwen3_vl.py b/gptqmodel/models/definitions/base_qwen3_vl.py index d469cf495..8525328d6 100644 --- a/gptqmodel/models/definitions/base_qwen3_vl.py +++ b/gptqmodel/models/definitions/base_qwen3_vl.py @@ -11,7 +11,7 @@ from ...utils.calibration import batched from ...utils.image import extract_vision_info, fetch_image -from ...utils.model import MODALITY, move_to +from ...utils.model import MODALITY, get_module, move_to from ...utils.offload import offload_to_disk from .._const import CPU from ..base import BaseQModel @@ -35,7 +35,7 @@ def _load_fetch_video(): class BaseQwen3VLGPTQ(BaseQModel): loader = AutoModelForImageTextToText - pre_lm_head_norm_module = "model.norm" + pre_lm_head_norm_module = ["model.norm", "norm"] module_tree = [ "model", @@ -54,30 +54,71 @@ class BaseQwen3VLGPTQ(BaseQModel): require_load_processor = True + @classmethod + def extract_layers_node(cls): + return ["model.language_model.layers", "language_model.layers"] + + @classmethod + def get_base_modules(cls, model): + prefix, core_model = cls._resolve_multimodal_layout(model) + base_modules = [] + for name, _ in core_model.named_children(): + if name != "language_model": + base_modules.append(f"{prefix}.{name}" if prefix else name) + return base_modules + + @classmethod + def _resolve_multimodal_layout(cls, model): + for prefix in ("model", ""): + core_model = get_module(model, prefix) if prefix else model + if core_model is None: + continue + if hasattr(core_model, "language_model") and hasattr(core_model, "visual"): + return prefix, core_model + + raise AttributeError("Unable to resolve a Qwen VL core model with `language_model` and `visual` modules.") + + def _core_multimodal_model(self): + _, core_model = self._resolve_multimodal_layout(self.model) + return core_model + + def _materialize_core_module(self, parent, attr_name: str): + module = getattr(parent, attr_name) + if "_turtle_lock" not in self.__dict__ and "shell_module_materialize" not in self.__dict__: + setattr(parent, attr_name, move_to(module, device=self.quantize_config.device)) + return + setattr( + parent, + attr_name, + self.shell_module_materialize(module, self.quantize_config.device), + ) + def pre_quantize_generate_hook_start(self): - self.model.language_model.embed_tokens = move_to(self.model.language_model.embed_tokens, device=self.quantize_config.device) - self.model.language_model.rotary_emb = move_to(self.model.language_model.rotary_emb, device=self.quantize_config.device) - self.model.visual = move_to(self.model.visual, device=self.quantize_config.device) + core_model = self._core_multimodal_model() + self._materialize_core_module(core_model.language_model, "embed_tokens") + self._materialize_core_module(core_model.language_model, "rotary_emb") + self._materialize_core_module(core_model, "visual") def pre_quantize_generate_hook_end(self): + core_model = self._core_multimodal_model() if self.quantize_config.offload_to_disk: - offload_to_disk(model=self.model.language_model, - module=self.model.language_model.embed_tokens, + offload_to_disk(model=core_model.language_model, + module=core_model.language_model.embed_tokens, disk_path=self.quantize_config.offload_to_disk_path, ) - offload_to_disk(model=self.model.language_model, - module=self.model.language_model.rotary_emb, + offload_to_disk(model=core_model.language_model, + module=core_model.language_model.rotary_emb, disk_path=self.quantize_config.offload_to_disk_path, ) - offload_to_disk(model=self.model, - module=self.model.visual, + offload_to_disk(model=core_model, + module=core_model.visual, disk_path=self.quantize_config.offload_to_disk_path, ) return - self.model.language_model.embed_tokens = move_to(self.model.language_model.embed_tokens, device=CPU) - self.model.language_model.rotary_emb = move_to(self.model.language_model.rotary_emb, device=CPU) - self.model.visual = move_to(self.model.visual, device=CPU) + core_model.language_model.embed_tokens = move_to(core_model.language_model.embed_tokens, device=CPU) + core_model.language_model.rotary_emb = move_to(core_model.language_model.rotary_emb, device=CPU) + core_model.visual = move_to(core_model.visual, device=CPU) @staticmethod def process_vision_info( @@ -126,6 +167,10 @@ def process_vision_info( if return_video_kwargs: return image_inputs, video_inputs, video_kwargs + if video_inputs is None and not return_video_metadata: + # Keep the image-only call contract aligned with the earlier VL + # adapters so processor(images=...) can use the return value directly. + return image_inputs return image_inputs, video_inputs def preprocess_dataset(self, sample: Dict) -> Dict: @@ -141,7 +186,11 @@ def prepare_dataset(self, calibration_dataset, batch_size: int = 1, **kwargs): text = processor.apply_chat_template( batch, tokenize=False, add_generation_prompt=True ) - image_inputs, video_inputs = self.process_vision_info(batch) + vision_inputs = self.process_vision_info(batch) + if isinstance(vision_inputs, tuple): + image_inputs, video_inputs = vision_inputs + else: + image_inputs, video_inputs = vision_inputs, None inputs = processor( text=text, images=image_inputs, diff --git a/gptqmodel/models/definitions/brumby.py b/gptqmodel/models/definitions/brumby.py index 32e37e8c0..d532b9c61 100644 --- a/gptqmodel/models/definitions/brumby.py +++ b/gptqmodel/models/definitions/brumby.py @@ -36,6 +36,19 @@ class BrumbyQModel(BaseQModel): }, ] + def before_model_load(self, model_local_path: str, load_quantized_model: bool): + from transformers.dynamic_module_utils import get_class_from_dynamic_module + cls = get_class_from_dynamic_module( + "modeling_brumby.BrumbyRotaryEmbedding", + model_local_path, + ) + if not hasattr(cls, "compute_default_rope_parameters"): + from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS + def compute_default_rope_parameters(self, config): + return ROPE_INIT_FUNCTIONS["linear"](config) + + cls.compute_default_rope_parameters = compute_default_rope_parameters + def after_model_load(self, model, load_quantized_model=False): if hasattr(model, "config") and hasattr(model.config, "use_cache"): model.config.use_cache = False diff --git a/gptqmodel/models/definitions/ernie4_5.py b/gptqmodel/models/definitions/ernie4_5.py index ed161d7e6..6893bd3a0 100644 --- a/gptqmodel/models/definitions/ernie4_5.py +++ b/gptqmodel/models/definitions/ernie4_5.py @@ -3,173 +3,12 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from ...utils.logger import setup_logger from . import LlamaQModel -class Ernie4_5QModel(LlamaQModel): - require_trust_remote_code = True - support_batch_quantize = False - require_monkeypatch = True - - def monkey_patch(self): - from typing import Optional, Tuple - - import torch - from transformers.modeling_outputs import BaseModelOutputWithPast - - def ernie4_5_decode_layer_forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - attn_mask_start_row_indices: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - token_type_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - use_cache: Optional[bool] = False, - ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: - residual = hidden_states - - hidden_states = self.input_layernorm(hidden_states) - - # Self Attention - (hidden_states, self_attn_weights, present_key_value) = self.self_attn( - hidden_states=hidden_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - attn_mask_start_row_indices=attn_mask_start_row_indices, - position_ids=position_ids, - output_attentions=output_attentions, - use_cache=use_cache, - token_type_ids=token_type_ids, - ) - hidden_states = self.residual_add1(hidden_states, residual) - - # Fully Connected - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - - hidden_states = self.residual_add2(hidden_states, residual) - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - def ernie4_5_model_forward( - self, - input_ids=None, - position_ids=None, - token_type_ids=None, - attention_mask=None, - attn_mask_start_row_indices=None, - inputs_embeds=None, - use_cache=None, - past_key_values=None, - output_attentions=False, - output_hidden_states=None, - return_dict=False, - ): - use_cache = use_cache if use_cache is not None else self.config.use_cache - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - _, seq_length = input_ids.shape - elif inputs_embeds is not None: - _, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - if past_key_values is None: - past_key_values = tuple([None] * len(self.layers)) - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = inputs_embeds.to(self.embed_tokens.weight.dtype) - - hidden_states = inputs_embeds +log = setup_logger() - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - for idx, (decoder_layer) in enumerate(self.layers): - - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[idx] if past_key_values is not None else None - ) - - layer_outputs = decoder_layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - attn_mask_start_row_indices=attn_mask_start_row_indices, - position_ids=position_ids, - token_type_ids=token_type_ids, - output_attentions=output_attentions, - past_key_value=past_key_value, - use_cache=use_cache, - ) - - if isinstance(layer_outputs, (tuple, list)): - hidden_states = layer_outputs[0] - else: - hidden_states = layer_outputs - - if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - # apply kv cache - if past_key_value is not None: - hidden_states = hidden_states[:, -1:, :] - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) - - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - if not self.load_quantized_model: - ernie4_5_model = type(self.model.model) - ernie4_5_model.forward = ernie4_5_model_forward - - ernie4_5_layer = type(self.model.model.layers[0]) - ernie4_5_layer.forward = ernie4_5_decode_layer_forward +class Ernie4_5QModel(LlamaQModel): + pass diff --git a/gptqmodel/models/definitions/gemma4.py b/gptqmodel/models/definitions/gemma4.py new file mode 100644 index 000000000..138816203 --- /dev/null +++ b/gptqmodel/models/definitions/gemma4.py @@ -0,0 +1,248 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from types import MethodType + +import torch + +from ...utils.device import get_device +from ...utils.model import get_module_by_name_prefix, move_to, nested_move_to +from ..base import BaseQModel +from . import LlamaQModel + + +_GEMMA4_ALL_PER_LAYER_INPUTS = "__gptqmodel_gemma4_all_per_layer_inputs" + + +def _gemma4_module_tree(): + """Return the Gemma 4 decoder traversal with optional attention and per-layer input modules.""" + + return [ + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ( + "q_norm:!", + "q_proj:0", + "k_norm:!", + "k_proj:0", + "v_norm:!", + "v_proj:0", + "o_proj:1", + ), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",), + "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), + "post_feedforward_layernorm": ("post_feedforward_layernorm:!",), + "per_layer_input_gate": ("per_layer_input_gate:0",), + "post_per_layer_input_norm": ("post_per_layer_input_norm:!",), + "per_layer_projection": ("per_layer_projection:1",), + }, + ] + + +def _capture_gemma4_positional_inputs(model_def, args, kwargs, batch_device): + """Preserve Gemma 4 per-layer adapter inputs that flow through decoder layers positionally.""" + + layer_input = super(type(model_def), model_def).capture_first_layer_positional_inputs(args, kwargs, batch_device) + per_layer_input = args[1] if len(args) > 1 else kwargs.get("per_layer_input") + if per_layer_input is not None: + layer_input.append(move_to(per_layer_input, device=batch_device)) + return layer_input + + +def _prepare_gemma4_replay_kwargs(model_def, layer, layer_input, additional_inputs, target_device): + """Refresh Gemma 4 rotary kwargs per layer so replay follows sliding/full attention boundaries.""" + + rotary_path = getattr(model_def, "rotary_embedding", None) + if not rotary_path or not layer_input: + return additional_inputs + + rotary, _ = get_module_by_name_prefix(model_def.model, [rotary_path]) + if rotary is None: + return additional_inputs + + layer_type = getattr(getattr(layer, "self_attn", None), "layer_type", None) + if layer_type is None: + return additional_inputs + + hidden_states = layer_input[0] + seq_len = hidden_states.shape[1] if hidden_states.dim() >= 2 else hidden_states.shape[0] + batch_dim = hidden_states.shape[0] if hidden_states.dim() >= 2 else 1 + + position_ids = additional_inputs.get("position_ids") + if position_ids is None or position_ids.shape[-1] != seq_len: + position_ids = torch.arange(seq_len, device=target_device, dtype=torch.long).unsqueeze(0).expand(batch_dim, -1) + additional_inputs["position_ids"] = position_ids + + try: + rotary_device = get_device(rotary) + except Exception: + rotary_device = position_ids.device + + rotary_position_ids = move_to(position_ids, device=rotary_device) + rotary_input = torch.empty(1, device=rotary_device, dtype=hidden_states.dtype) + additional_inputs["position_embeddings"] = nested_move_to( + rotary(rotary_input, rotary_position_ids, layer_type), + device=target_device, + ) + + if len(layer_input) == 1: + all_per_layer_inputs = additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None) + layer_index = getattr(getattr(layer, "self_attn", None), "layer_idx", None) + if all_per_layer_inputs is not None and layer_index is not None: + additional_inputs["per_layer_input"] = move_to( + all_per_layer_inputs[:, :, layer_index, :], + device=target_device, + ) + else: + additional_inputs.pop(_GEMMA4_ALL_PER_LAYER_INPUTS, None) + + return additional_inputs + + +def _resolve_gemma4_language_model(model_def): + """Return the Gemma 4 text stack that owns per-layer input projection state.""" + + if hasattr(model_def.model, "model") and hasattr(model_def.model.model, "language_model"): + return model_def.model.model.language_model + return model_def.model.model + + +def _patch_gemma4_per_layer_input_capture(model_def): + """Capture projected per-layer inputs during calibration so later decoder replays can slice them by layer.""" + + language_model = _resolve_gemma4_language_model(model_def) + if getattr(language_model, "_gptqmodel_project_per_layer_inputs_patched", False): + return + + original = language_model.project_per_layer_inputs + + def patched(self, inputs_embeds, per_layer_inputs=None): + result = original(inputs_embeds, per_layer_inputs) + setattr(self, "_gptqmodel_cached_all_per_layer_inputs", result) + return result + + language_model._gptqmodel_original_project_per_layer_inputs = original + language_model.project_per_layer_inputs = MethodType(patched, language_model) + language_model._gptqmodel_project_per_layer_inputs_patched = True + + +def _restore_gemma4_per_layer_input_capture(model_def): + """Restore Gemma 4 per-layer input helpers after calibration capture completes.""" + + language_model = _resolve_gemma4_language_model(model_def) + original = getattr(language_model, "_gptqmodel_original_project_per_layer_inputs", None) + if original is not None: + language_model.project_per_layer_inputs = original + delattr(language_model, "_gptqmodel_original_project_per_layer_inputs") + if hasattr(language_model, "_gptqmodel_project_per_layer_inputs_patched"): + delattr(language_model, "_gptqmodel_project_per_layer_inputs_patched") + if hasattr(language_model, "_gptqmodel_cached_all_per_layer_inputs"): + delattr(language_model, "_gptqmodel_cached_all_per_layer_inputs") + + +class Gemma4TextQModel(LlamaQModel): + """Quantization definition for text-only Gemma 4 checkpoints.""" + + # Gemma 4 mixes optional KV projections and per-layer residual adapters across variants. + layer_modules_strict = False + # Gemma 4 input preparation uses per-layer embeddings, so batch quantization stays conservative. + support_batch_quantize = False + pre_lm_head_norm_module = "model.norm" + rotary_embedding = "model.rotary_emb" + module_tree = _gemma4_module_tree() + + def capture_first_layer_positional_inputs(self, args, kwargs, batch_device): + """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation.""" + + return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device) + + def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs): + """Persist Gemma 4 per-layer adapter tensors for later decoder replays.""" + + layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs) + language_model = _resolve_gemma4_language_model(self) + all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None) + if all_per_layer_inputs is not None: + layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device) + return layer_input_kwargs + + def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device): + """Refresh Gemma 4 layer kwargs during cached replay.""" + + return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device) + + def pre_quantize_generate_hook_start(self): + _patch_gemma4_per_layer_input_capture(self) + + def pre_quantize_generate_hook_end(self): + _restore_gemma4_per_layer_input_capture(self) + super().pre_quantize_generate_hook_end() + + +class Gemma4ForConditionalGenerationGPTQ(BaseQModel): + """Quantization definition for composite Gemma 4 checkpoints.""" + + # Gemma 4 composite checkpoints share the same decoder quirks as the text-only model. + layer_modules_strict = False + support_batch_quantize = False + pre_lm_head_norm_module = "model.language_model.norm" + rotary_embedding = "model.language_model.rotary_emb" + + module_tree = [ + "model", + "language_model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ( + "q_norm:!", + "q_proj:0", + "k_norm:!", + "k_proj:0", + "v_norm:!", + "v_proj:0", + "o_proj:1", + ), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "pre_feedforward_layernorm": ("pre_feedforward_layernorm:!",), + "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), + "post_feedforward_layernorm": ("post_feedforward_layernorm:!",), + "per_layer_input_gate": ("per_layer_input_gate:0",), + "post_per_layer_input_norm": ("post_per_layer_input_norm:!",), + "per_layer_projection": ("per_layer_projection:1",), + }, + ] + + def capture_first_layer_positional_inputs(self, args, kwargs, batch_device): + """Keep Gemma 4 per-layer adapter inputs when decoder layers are replayed in isolation.""" + + return _capture_gemma4_positional_inputs(self, args, kwargs, batch_device) + + def capture_first_layer_input_kwargs(self, args, kwargs, batch_device, layer_input_kwargs): + """Persist Gemma 4 per-layer adapter tensors for later decoder replays.""" + + layer_input_kwargs = super().capture_first_layer_input_kwargs(args, kwargs, batch_device, layer_input_kwargs) + language_model = _resolve_gemma4_language_model(self) + all_per_layer_inputs = getattr(language_model, "_gptqmodel_cached_all_per_layer_inputs", None) + if all_per_layer_inputs is not None: + layer_input_kwargs[_GEMMA4_ALL_PER_LAYER_INPUTS] = move_to(all_per_layer_inputs, device=batch_device) + return layer_input_kwargs + + def prepare_layer_replay_kwargs(self, layer, layer_input, additional_inputs, target_device): + """Refresh Gemma 4 layer kwargs during cached replay.""" + + return _prepare_gemma4_replay_kwargs(self, layer, layer_input, additional_inputs, target_device) + + def pre_quantize_generate_hook_start(self): + _patch_gemma4_per_layer_input_capture(self) + + def pre_quantize_generate_hook_end(self): + _restore_gemma4_per_layer_input_capture(self) + super().pre_quantize_generate_hook_end() diff --git a/gptqmodel/models/definitions/glm.py b/gptqmodel/models/definitions/glm.py index 2309efa9d..3f09fc26c 100644 --- a/gptqmodel/models/definitions/glm.py +++ b/gptqmodel/models/definitions/glm.py @@ -18,6 +18,6 @@ class GlmQModel(BaseQModel): "input_layernorm": ("input_layernorm:!",), "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), - "mlp": ("gate_up_proj:0", "down_proj:1"), + "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), } ] diff --git a/gptqmodel/models/definitions/glm4_moe.py b/gptqmodel/models/definitions/glm4_moe.py index 3b26390f5..8a4001347 100644 --- a/gptqmodel/models/definitions/glm4_moe.py +++ b/gptqmodel/models/definitions/glm4_moe.py @@ -23,7 +23,7 @@ class GLM4MoEGPTQ(BaseQModel): # Set to False since GLM-4.5-Air may have dynamic module structures layer_modules_strict = False - out_of_model_tensor_files = ["mtp.safetensors"] + out_of_model_tensors = {"files": ["mtp.safetensors"]} # MoE lifecycle hooks for gate_proj/up_proj/down_proj pattern moe_lifecycle_hooks = GateUpDownMoELifecycleHooks() diff --git a/gptqmodel/models/definitions/glm4_moe_lite.py b/gptqmodel/models/definitions/glm4_moe_lite.py new file mode 100644 index 000000000..c3b23e039 --- /dev/null +++ b/gptqmodel/models/definitions/glm4_moe_lite.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from ..base import BaseQModel +from ..moe_lifecycle import GateUpDownMoELifecycleHooks + + +class Glm4MoeLiteQModel(BaseQModel): + dynamic_expert_index = "n_routed_experts" + + pre_lm_head_norm_module = "model.norm" + + layer_modules_strict = False + + moe_lifecycle_hooks = GateUpDownMoELifecycleHooks() + + module_tree = [ + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ("q_proj:0", "q_a_proj:0", "kv_a_proj_with_mqa:0", "q_b_proj:1", "kv_b_proj:1", "o_proj:2"), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "mlp:moe:?": { + "gate": ("gate:!",), + "experts:0": { + "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), + }, + "shared_experts:0": ("gate_proj:0", "up_proj:0", "down_proj:1"), + "": ("gate_proj:0", "up_proj:0", "down_proj:1"), + }, + } + ] + + +__all__ = ["Glm4MoeLiteQModel"] diff --git a/gptqmodel/models/definitions/glm4v.py b/gptqmodel/models/definitions/glm4v.py index 322b558dd..63d3b4fe1 100644 --- a/gptqmodel/models/definitions/glm4v.py +++ b/gptqmodel/models/definitions/glm4v.py @@ -1,38 +1,10 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium -from torch import nn from transformers import AutoModelForImageTextToText -from transformers.activations import ACT2FN from ..base import BaseQModel -class Glm4vTextMLPNew(nn.Module): - def __init__(self, config, ori_mlp=None): - super().__init__() - self.config = config - dtype = None - device = None - if ori_mlp is not None: - dtype = ori_mlp.gate_up_proj.weight.dtype - device = ori_mlp.gate_up_proj.weight.device - - self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=dtype, device=device) - self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False, dtype=dtype, device=device) - self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False, dtype=dtype, device=device) - self.activation_fn = ACT2FN[config.hidden_act] - - if ori_mlp is not None: - gate_w, up_w = ori_mlp.gate_up_proj.weight.data.split(config.intermediate_size, dim=0) - self.gate_proj.weight.data.copy_(gate_w) - self.up_proj.weight.data.copy_(up_w) - self.down_proj.weight.data.copy_(ori_mlp.down_proj.weight.data) - - def forward(self, hidden_states): - gate = self.gate_proj(hidden_states) - up = self.up_proj(hidden_states) - return self.down_proj(up * self.activation_fn(gate)) - class Glm4vGPTQ(BaseQModel): loader = AutoModelForImageTextToText @@ -50,9 +22,3 @@ class Glm4vGPTQ(BaseQModel): "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), } ] - - def before_model_load(self, load_quantized_model=False): - if load_quantized_model: - import transformers.models.glm4v.modeling_glm4v as glm4v_modeling - - glm4v_modeling.Glm4vTextMLP= Glm4vTextMLPNew diff --git a/gptqmodel/models/definitions/glm_moe_dsa.py b/gptqmodel/models/definitions/glm_moe_dsa.py new file mode 100644 index 000000000..ba8c92dcc --- /dev/null +++ b/gptqmodel/models/definitions/glm_moe_dsa.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from ..base import BaseQModel +from ..moe_lifecycle import GateUpDownMoELifecycleHooks + + +class GlmMoeDsaQModel(BaseQModel): + # GLM-5 and GLM-5.1 currently share the same modeling config and both resolve + # to transformers model_type `glm_moe_dsa`. + # The first three decoder blocks are dense MLPs, with later blocks switching + # to routed experts plus a shared-expert branch. + layer_modules_strict = False + + dynamic_expert_index = "n_routed_experts" + + pre_lm_head_norm_module = "model.norm" + rotary_embedding = "model.rotary_emb" + + moe_lifecycle_hooks = GateUpDownMoELifecycleHooks() + + module_tree = [ + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ( + # GLM-5 / GLM-5.1 use MLA attention plus a DSA indexer. `q_proj` + # is an optional fallback path; current public configs use q_a/q_b. + "q_proj:0", + "q_a_proj:0", + "kv_a_proj_with_mqa:0", + "indexer.wk:0", + "q_b_proj:1", + "kv_b_proj:1", + "indexer.wq_b:1", + "o_proj:2", + ), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "mlp:moe": { + "gate": ("gate:!",), + "experts": { + "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), + }, + "shared_experts": ("gate_proj:0", "up_proj:0", "down_proj:1"), + # Dense fallback for the first `mlp_layer_types == "dense"` blocks. + "": ("gate_proj:0", "up_proj:0", "down_proj:1"), + }, + }, + ] + +__all__ = ["GlmMoeDsaQModel"] diff --git a/gptqmodel/models/definitions/gpt_oss.py b/gptqmodel/models/definitions/gpt_oss.py index a8061dab1..6cf424a23 100644 --- a/gptqmodel/models/definitions/gpt_oss.py +++ b/gptqmodel/models/definitions/gpt_oss.py @@ -2,129 +2,9 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import torch -import torch.nn.functional as F -from torch import nn - from ..base import BaseQModel -class GptOssExpertsNew(nn.Module): - def __init__(self, config, ori_experts=None): - super().__init__() - self.intermediate_size = config.intermediate_size - self.num_experts = config.num_local_experts - self.hidden_size = config.hidden_size - self.expert_dim = self.intermediate_size - self.alpha = 1.702 - self.limit = 7.0 - self.quantizing = False - - self.gate_up = nn.ModuleList([ - nn.Linear(self.hidden_size, 2 * self.expert_dim, dtype=config.dtype) - for _ in range(self.num_experts) - ]) - - self.down = nn.ModuleList([ - nn.Linear(self.expert_dim, self.hidden_size, dtype=config.dtype) - for _ in range(self.num_experts) - ]) - - if ori_experts is not None: - self.quantizing = True - for i in range(self.num_experts): - tgt_gu_w = self.gate_up[i].weight # [2E, H] - tgt_gu_b = self.gate_up[i].bias # [2E] - tgt_d_w = self.down[i].weight # [H, E] - tgt_d_b = self.down[i].bias # [H] - - gu_w_src = ori_experts.gate_up_proj[i].detach().t().contiguous() - gu_b_src = ori_experts.gate_up_proj_bias[i].detach() - d_w_src = ori_experts.down_proj[i].detach().t().contiguous() - d_b_src = ori_experts.down_proj_bias[i].detach() - - with torch.inference_mode(): - tgt_gu_w.copy_(gu_w_src) - tgt_gu_b.copy_(gu_b_src) - tgt_d_w.copy_(d_w_src) - tgt_d_b.copy_(d_b_src) - - def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None) -> torch.Tensor: - if self.quantizing: - # For quantization, we need to trigger computation of all experts - batch_size = hidden_states.shape[0] - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - num_experts = routing_weights.shape[1] - - hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) - gate_up = torch.stack([proj(hidden_states[i]) for i, proj in enumerate(self.gate_up)]) - gate, up = gate_up[..., ::2], gate_up[..., 1::2] - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - next_states = torch.stack([proj((up[i] + 1) * glu[i]) for i, proj in enumerate(self.down)]) - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) - next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] - next_states = next_states.sum(dim=0) - - return next_states - - # For non-quantization forward pass, reduce forward pass time by only computing active experts - batch_size, seq_len = hidden_states.shape[0], hidden_states.shape[1] if len(hidden_states.shape) > 2 else 1 - hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - - active_experts = torch.unique(router_indices.flatten()) - final_output = torch.zeros_like(hidden_states) - for expert_idx in active_experts: - expert_mask = (router_indices == expert_idx).any(dim=-1) # (num_tokens,) - if not expert_mask.any(): - continue - - expert_tokens = hidden_states[expert_mask] # (selected_tokens, hidden_size) - - gate_up_output = self.gate_up[expert_idx](expert_tokens) # (selected_tokens, 2*expert_dim) - gate, up = gate_up_output[..., ::2], gate_up_output[..., 1::2] - - gate = gate.clamp(min=None, max=self.limit) - up = up.clamp(min=-self.limit, max=self.limit) - glu = gate * torch.sigmoid(gate * self.alpha) - - expert_output = self.down[expert_idx]((up + 1) * glu) # (selected_tokens, hidden_size) - - expert_weights = routing_weights[expert_mask, expert_idx].unsqueeze(-1) # (selected_tokens, 1) - - final_output[expert_mask] += expert_output * expert_weights - - if seq_len > 1: - final_output = final_output.view(batch_size, seq_len, self.hidden_size) - else: - final_output = final_output.view(batch_size, self.hidden_size) - - return final_output - -class GptOssTopKRouterNew(nn.Module): - def __init__(self, config, ori_router=None): - super().__init__() - self.top_k = config.num_experts_per_tok - self.num_experts = config.num_local_experts - self.hidden_dim = config.hidden_size - self.weight = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim)) - self.bias = nn.Parameter(torch.empty(self.num_experts)) - - if ori_router is not None: - with torch.inference_mode(): - self.weight.copy_(ori_router.weight.detach()) - self.bias.copy_(ori_router.bias.detach()) - - def forward(self, hidden_states): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = F.linear(hidden_states, self.weight.to(hidden_states.dtype), self.bias.to(hidden_states.dtype)) # (seq_len, num_experts) - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) - router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices - class GPTOSSGPTQ(BaseQModel): dynamic_expert_index = "num_local_experts" @@ -140,16 +20,8 @@ class GPTOSSGPTQ(BaseQModel): "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp": { "experts": { - "gate_up": {"#": ("#")}, - "down": {"#": ("#")}, - } + "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), + }, } } ] - - def before_model_load(self, load_quantized_model=False): - if load_quantized_model: - import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling - - gpt_oss_modeling.GptOssExperts = GptOssExpertsNew - gpt_oss_modeling.GptOssTopKRouter = GptOssTopKRouterNew diff --git a/gptqmodel/models/definitions/granitemoehybrid.py b/gptqmodel/models/definitions/granitemoehybrid.py index 94440c4b4..a5f080b8f 100644 --- a/gptqmodel/models/definitions/granitemoehybrid.py +++ b/gptqmodel/models/definitions/granitemoehybrid.py @@ -8,6 +8,7 @@ class GraniteMoeHybridQModel(BaseQModel): pre_lm_head_norm_module = "model.norm" + require_monkeypatch = True layer_modules_strict = False @@ -23,3 +24,40 @@ class GraniteMoeHybridQModel(BaseQModel): "post_attention_layernorm": ("post_attention_layernorm:!",), } ] + + def monkey_patch(self): + from gptqmodel.nn_modules.qlinear import BaseQuantLinear + + mamba_layer_cls = type(self.model.model.layers[0].mamba) + original_forward = mamba_layer_cls.forward + + def granitemoehybrid_mamba_forward( + layer_self, + hidden_states, + cache_params=None, + cache_position=None, + attention_mask=None, + seq_idx=None, + **kwargs, + ): + if isinstance(layer_self.in_proj, BaseQuantLinear) or isinstance(layer_self.out_proj, BaseQuantLinear): + if seq_idx is not None: + raise NotImplementedError( + "`seq_idx` support requires fast path support. Please install `mamba_ssm` and `causal_conv1d`" + ) + dtype = hidden_states.dtype + if attention_mask is not None and attention_mask.shape[1] > 1 and attention_mask.shape[0] > 1: + hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype) + return layer_self.torch_forward(hidden_states, cache_params, cache_position, attention_mask) + + return original_forward( + layer_self, + hidden_states, + cache_params=cache_params, + cache_position=cache_position, + attention_mask=attention_mask, + seq_idx=seq_idx, + **kwargs, + ) + + mamba_layer_cls.forward = granitemoehybrid_mamba_forward diff --git a/gptqmodel/models/definitions/llama4.py b/gptqmodel/models/definitions/llama4.py index 1cee5a177..e5f7b0afb 100644 --- a/gptqmodel/models/definitions/llama4.py +++ b/gptqmodel/models/definitions/llama4.py @@ -32,56 +32,10 @@ class Llama4QModel(BaseQModel): "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), "feed_forward:moe": { - "experts": { + "experts:0": { "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, - "shared_expert": ("gate_proj:0", "up_proj:0", "down_proj:1"), + "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, } ] - - def before_model_load(self, load_quantized_model=False): - if load_quantized_model: - import torch - import torch.nn as nn - import transformers.models.llama4.modeling_llama4 as llama4_modeling - from transformers.integrations.hub_kernels import use_kernel_forward_from_hub - - @use_kernel_forward_from_hub("Llama4TextMoe") - class SequentialLlama4TextMoe(torch.nn.Module): - def __init__(self, config): - super().__init__() - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - print(config) - self.num_experts = 16 - self.experts = nn.ModuleList( - [llama4_modeling.Llama4TextMLP(config) for _ in range(self.num_experts)] - ) - self.router = llama4_modeling.Llama4Router(config) - self.shared_expert = llama4_modeling.Llama4TextMLP(config) - - def forward(self, hidden_states: torch.Tensor): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - if isinstance(router_logits, tuple): - router_scores, router_logits = router_logits - router_scores = router_scores.t() - else: - # transformers < 4.54.0 only returns router_logits - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) - - router_scores = ( - torch.full_like(router_logits, float("-inf")) - .scatter_(1, router_indices, router_top_value) - .transpose(0, 1) - ) - router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - - out = self.shared_expert(hidden_states) - for i in range(self.num_experts): - out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1) - - return out, router_logits - - llama4_modeling.Llama4TextMoe = SequentialLlama4TextMoe diff --git a/gptqmodel/models/definitions/minicpm_o.py b/gptqmodel/models/definitions/minicpm_o.py new file mode 100644 index 000000000..7fb2c6246 --- /dev/null +++ b/gptqmodel/models/definitions/minicpm_o.py @@ -0,0 +1,417 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from copy import deepcopy +from importlib import import_module +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from transformers import AutoModel, AutoProcessor, ProcessorMixin +from transformers.generation.utils import GenerationMixin + +from ...utils.audio import process_audio_info +from ...utils.calibration import batched +from ...utils.image import fetch_image +from ...utils.model import MODALITY, move_to, nested_move_to +from ...utils.offload import offload_to_disk +from .._const import CPU +from ..base import BaseQModel + + +class Cache: + is_compileable = False + + def __init__(self): + super().__init__() + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + raise NotImplementedError("Make sure to implement `update` in a subclass.") + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_max_cache_shape(self) -> Optional[int]: + raise NotImplementedError("Make sure to implement `get_max_cache_shape` in a subclass.") + + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + max_length = self.get_max_cache_shape() + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + def reorder_cache(self, beam_idx: torch.LongTensor): + for layer_idx in range(len(self.key_cache)): + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.value_cache[layer_idx].numel(): + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + @property + def seen_tokens(self): + if hasattr(self, "_seen_tokens"): + return self._seen_tokens + else: + return None + + +class DynamicCache(Cache): + def __init__(self, config=None, _distributed_cache_data: Iterable = None, offloading: bool = False, **kwargs) -> None: + super().__init__() + self._seen_tokens = 0 + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + self.offloading = offloading + + if _distributed_cache_data is not None: + for key_states, value_states in _distributed_cache_data: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + + def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]: + if layer_idx < len(self): + return (self.key_cache[layer_idx], self.value_cache[layer_idx]) + else: + raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}") + + def __iter__(self): + for layer_idx in range(len(self)): + yield (self.key_cache[layer_idx], self.value_cache[layer_idx]) + + def __len__(self): + return len(self.key_cache) + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + if key_states is not None: + if len(self.key_cache) <= layer_idx: + for _ in range(len(self.key_cache), layer_idx): + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) + self.key_cache.append(key_states) + self.value_cache.append(value_states) + elif ( + not self.key_cache[layer_idx].numel() + ): + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: + is_empty_layer = ( + len(self.key_cache) == 0 + or len(self.key_cache) <= layer_idx + or not self.key_cache[layer_idx].numel() + ) + layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0 + return layer_seq_length + + def get_mask_sizes(self, query_length: int, layer_idx: int = 0) -> Tuple[int, int]: + return self.get_seq_length(layer_idx) + query_length, 0 + + def get_max_cache_shape(self) -> Optional[int]: + return None + + @property + def is_sliding(self) -> List[bool]: + return [False] * len(self.key_cache) + + def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: + legacy_cache = () + for layer_idx in range(len(self)): + legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) + return legacy_cache + + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + cache = cls() + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def crop(self, max_length: int): + if max_length < 0: + max_length = self.get_seq_length() - abs(max_length) + + if self.get_seq_length() <= max_length: + return + + self._seen_tokens = max_length + for idx in range(len(self.key_cache)): + if self.key_cache[idx].numel(): + self.key_cache[idx] = self.key_cache[idx][..., :max_length, :] + self.value_cache[idx] = self.value_cache[idx][..., :max_length, :] + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicCache() + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache": + cache = cls() + for idx in range(len(splits[0])): + key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx].numel()] + value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx].numel()] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + def batch_repeat_interleave(self, repeats: int): + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0) + self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0) + + def batch_select_indices(self, indices: torch.Tensor): + for layer_idx in range(len(self)): + self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...] + self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] + + +def _patch_minicpmo_remote_prepare_inputs_for_generation(remote_module) -> None: + remote_module.prepare_inputs_for_generation = GenerationMixin.prepare_inputs_for_generation + +class MiniCPMOQModel(BaseQModel): + loader = AutoModel + + pre_lm_head_norm_module = "llm.model.norm" + + require_trust_remote_code = True + + module_tree = [ + "llm", + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ("q_proj:0", "k_proj:1", "v_proj:2", "o_proj:3"), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "mlp": ("gate_proj", "up_proj", "down_proj"), + } + ] + + modality = [MODALITY.TEXT, MODALITY.IMAGE_TO_TEXT] + require_load_processor = True + + def before_model_load(self, model_local_path: str, load_quantized_model: bool): + from transformers import cache_utils + from transformers.dynamic_module_utils import get_class_from_dynamic_module + from transformers.generation import utils + cache_utils.Cache = Cache + cache_utils.DynamicCache = DynamicCache + utils.DynamicCache = DynamicCache + + tts_config_cls = get_class_from_dynamic_module( + "configuration_minicpmo.MiniCPMTTSConfig", + model_local_path, + ) + + print("tts_config_cls", tts_config_cls) + + if not hasattr(tts_config_cls, "top_p"): + tts_config_cls.top_p = 1.0 + + if not hasattr(tts_config_cls, "top_k"): + tts_config_cls.top_k = 50 + + if not hasattr(tts_config_cls, "repetition_penalty"): + tts_config_cls.repetition_penalty = 1.0 + + # MiniCPM-o remote code binds `DynamicCache` into module globals at import + # time (`from transformers.cache_utils import DynamicCache`). Rebind those + # globals as well so the compat cache is used even if the dynamic modules + # were imported before this hook ran. + remote_model_cls = get_class_from_dynamic_module( + "modeling_minicpmo.MiniCPMO", + model_local_path, + ) + remote_module = import_module(remote_model_cls.__module__) + remote_module.Cache = Cache + remote_module.DynamicCache = DynamicCache + _patch_minicpmo_remote_prepare_inputs_for_generation(remote_module) + + def pre_quantize_generate_hook_start(self): + self.shell_module_materialize(self.model.llm.model.embed_tokens, self.quantize_config.device) + self.shell_module_materialize(self.model.llm.model.rotary_emb, self.quantize_config.device) + self.shell_module_materialize(self.model.vpm, self.quantize_config.device) + self.shell_module_materialize(self.model.resampler, self.quantize_config.device) + self.shell_module_materialize(self.model.apm, self.quantize_config.device) + self.shell_module_materialize(self.model.audio_avg_pooler, self.quantize_config.device) + self.shell_module_materialize(self.model.audio_projection_layer, self.quantize_config.device) + + def pre_quantize_generate_hook_end(self): + if self.quantize_config.offload_to_disk: + for module in ( + self.model.llm.model.embed_tokens, + self.model.llm.model.rotary_emb, + ): + offload_to_disk( + model=self.model.llm.model, + module=module, + disk_path=self.quantize_config.offload_to_disk_path, + ) + + for module in ( + self.model.vpm, + self.model.resampler, + self.model.apm, + self.model.audio_avg_pooler, + self.model.audio_projection_layer, + ): + offload_to_disk( + model=self.model, + module=module, + disk_path=self.quantize_config.offload_to_disk_path, + ) + return + + self.model.llm.model.embed_tokens = move_to(self.model.llm.model.embed_tokens, device=CPU) + self.model.llm.model.rotary_emb = move_to(self.model.llm.model.rotary_emb, device=CPU) + self.model.vpm = move_to(self.model.vpm, device=CPU) + self.model.resampler = move_to(self.model.resampler, device=CPU) + self.model.apm = move_to(self.model.apm, device=CPU) + self.model.audio_avg_pooler = move_to(self.model.audio_avg_pooler, device=CPU) + self.model.audio_projection_layer = move_to(self.model.audio_projection_layer, device=CPU) + if hasattr(self.model, "tts"): + self.model.tts = move_to(self.model.tts, device=CPU) + + def preprocess_dataset(self, sample: Dict) -> Dict: + return sample + + def load_processor(self) -> ProcessorMixin: + return AutoProcessor.from_pretrained(self.model_local_path, trust_remote_code=True) + + @staticmethod + def _normalize_conversation( + conversation: list[dict], + ) -> tuple[list[dict], list, list[int]]: + normalized = [] + images = [] + audio_parts = [] + + for index, message in enumerate(deepcopy(conversation)): + content = message.get("content", "") + if isinstance(content, str): + normalized.append(message) + continue + + text_parts = [] + for item in content: + if isinstance(item, str): + text_parts.append(item) + continue + + item_type = item.get("type") + if item_type == "image": + images.append(fetch_image(item)) + text_parts.append("./") + elif item_type == "audio": + audio_parts.append(index) + text_parts.append("") + elif item_type == "text": + text_parts.append(item.get("text", "")) + else: + raise ValueError(f"Unsupported MiniCPM-o content type: {item_type}") + + message["content"] = "\n".join(part for part in text_parts if part) + normalized.append(message) + + return normalized, images, audio_parts + + @classmethod + def prepare_inputs_for_conversations( + cls, + processor: ProcessorMixin, + conversations: list[dict] | list[list[dict]], + ): + if conversations and isinstance(conversations[0], dict): + conversations = [conversations] + + prompts = [] + images = [] + audios = [] + audio_parts = [] + + for conversation in conversations: + normalized, image_inputs, audio_part_inputs = cls._normalize_conversation(conversation) + audio_inputs = process_audio_info(conversation, use_audio_in_video=False) or [] + + prompts.append( + processor.tokenizer.apply_chat_template( + normalized, + tokenize=False, + ) + ) + images.append(image_inputs) + audios.append(audio_inputs) + audio_parts.append(audio_part_inputs) + + inputs = processor( + prompts, + images, + audios, + audio_parts, + return_tensors="pt", + ) + inputs.pop("image_sizes") + return inputs + + def prepare_dataset(self, calibration_dataset, batch_size: int = 1, **kwargs): + processor = self.processor or self.load_processor() + calib_data = [] + for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset): + calib_data.append( + self.prepare_inputs_for_conversations( + processor, + batch, + ) + ) + return calib_data + + def move_input_capture_example(self, example, data_device): + for key, value in example.items(): + example[key] = nested_move_to(value, device=data_device) + + return self.finalize_input_capture_example(example) + + def run_input_capture(self, example, use_cache: bool, data_device): + generation_config = self.model.prepare_generation_config(do_sample=True) + generation_config["use_cache"] = use_cache + + return self.model.generate( + **example, + tokenizer=self.model.tokenizer, + **generation_config, + ) diff --git a/gptqmodel/models/definitions/minicpm_v.py b/gptqmodel/models/definitions/minicpm_v.py new file mode 100644 index 000000000..1ec451d4f --- /dev/null +++ b/gptqmodel/models/definitions/minicpm_v.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from copy import deepcopy +from typing import Dict + +from PIL import Image +from transformers import AutoModel, AutoProcessor, ProcessorMixin + +from ...utils.calibration import batched +from ...utils.image import fetch_image +from ...utils.model import MODALITY, move_to, nested_move_to +from ...utils.offload import offload_to_disk +from .._const import CPU +from ..base import BaseQModel + + +def _allow_minicpmv_remote_tokenizer() -> None: + try: + from transformers.models.auto import tokenization_auto + except Exception: + return + + incompatible = getattr(tokenization_auto, "MODELS_WITH_INCORRECT_HUB_TOKENIZER_CLASS", None) + if isinstance(incompatible, set): + incompatible.discard("minicpmv") + +class MiniCPMVQModel(BaseQModel): + loader = AutoModel + + pre_lm_head_norm_module = "llm.model.norm" + + module_tree = [ + "llm", + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "self_attn": ("q_proj:0", "k_proj:1", "v_proj:2", "o_proj:3"), + "post_attention_layernorm": ("post_attention_layernorm:!",), + "mlp": ("gate_proj", "up_proj", "down_proj"), + } + ] + + modality = [MODALITY.TEXT, MODALITY.IMAGE_TO_TEXT] + require_load_processor = True + require_trust_remote_code = True + + def before_model_load(self, model_local_path: str, load_quantized_model: bool): + _allow_minicpmv_remote_tokenizer() + + def pre_quantize_generate_hook_start(self): + self.shell_module_materialize(self.model.llm.model.embed_tokens, self.quantize_config.device) + self.shell_module_materialize(self.model.llm.model.rotary_emb, self.quantize_config.device) + self.shell_module_materialize(self.model.vpm, self.quantize_config.device) + self.shell_module_materialize(self.model.resampler, self.quantize_config.device) + + def pre_quantize_generate_hook_end(self): + if self.quantize_config.offload_to_disk: + offload_to_disk( + model=self.model.llm.model, + module=self.model.llm.model.embed_tokens, + disk_path=self.quantize_config.offload_to_disk_path, + ) + offload_to_disk( + model=self.model.llm.model, + module=self.model.llm.model.rotary_emb, + disk_path=self.quantize_config.offload_to_disk_path, + ) + offload_to_disk( + model=self.model, + module=self.model.vpm, + disk_path=self.quantize_config.offload_to_disk_path, + ) + offload_to_disk( + model=self.model, + module=self.model.resampler, + disk_path=self.quantize_config.offload_to_disk_path, + ) + return + + self.model.llm.model.embed_tokens = move_to(self.model.llm.model.embed_tokens, device=CPU) + self.model.llm.model.rotary_emb = move_to(self.model.llm.model.rotary_emb, device=CPU) + self.model.vpm = move_to(self.model.vpm, device=CPU) + self.model.resampler = move_to(self.model.resampler, device=CPU) + + def preprocess_dataset(self, sample: Dict) -> Dict: + return sample + + def load_processor(self) -> ProcessorMixin: + return AutoProcessor.from_pretrained(self.model_local_path, trust_remote_code=True) + + @staticmethod + def _normalize_conversation( + conversation: list[dict], + ) -> tuple[list[dict], list[Image.Image]]: + normalized = [] + images = [] + + for message in deepcopy(conversation): + content = message.get("content", "") + if isinstance(content, str): + normalized.append(message) + continue + + cur_msgs = [] + for item in content: + if isinstance(item, str): + cur_msgs.append(item) + continue + + item_type = item.get("type") + if item_type == "image": + images.append(fetch_image(item)) + cur_msgs.append("(./)") + elif item_type == "text": + cur_msgs.append(item.get("text", "")) + else: + raise ValueError(f"Unsupported MiniCPM-V content type: {item_type}") + + message["content"] = "\n".join(cur_msgs) + normalized.append(message) + + return normalized, images + + @classmethod + def prepare_inputs_for_conversations( + cls, + processor: ProcessorMixin, + conversations: list[dict] | list[list[dict]], + ): + if conversations and isinstance(conversations[0], dict): + conversations = [conversations] + + prompts = [] + images = [] + for conversation in conversations: + normalized, image_inputs = cls._normalize_conversation(conversation) + prompts.append( + processor.tokenizer.apply_chat_template( + normalized, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ) + images.append(image_inputs) + + inputs = processor( + prompts, + images, + return_tensors="pt", + ) + inputs.pop("image_sizes") + return inputs + + def prepare_dataset(self, calibration_dataset, batch_size: int = 1, **kwargs): + processor = self.load_processor() + calib_data = [] + for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset): + calib_data.append( + self.prepare_inputs_for_conversations( + processor, + batch, + ) + ) + del processor + return calib_data + + def move_input_capture_example(self, example, data_device): + for key, value in example.items(): + example[key] = nested_move_to(value, device=data_device) + + return self.finalize_input_capture_example(example) + + def run_input_capture(self, example, use_cache: bool, data_device): + generation_config = { + "temperature": 0.7, + "do_sample": True, + "top_p": 0.8, + "top_k": 100, + "repetition_penalty": 1.03, + "use_cache": use_cache, + } + + return self.model.generate( + **example, + tokenizer=self.model.tokenizer, + **generation_config, + ) diff --git a/gptqmodel/models/definitions/mixtral.py b/gptqmodel/models/definitions/mixtral.py index de4498c37..3172572bf 100644 --- a/gptqmodel/models/definitions/mixtral.py +++ b/gptqmodel/models/definitions/mixtral.py @@ -4,15 +4,19 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from ..base import BaseQModel -from ..moe_lifecycle import W1W3W2MoELifecycleHooks +from ..moe_lifecycle import GateUpDownMoELifecycleHooks class MixtralQModel(BaseQModel): pre_lm_head_norm_module = "model.norm" - # MoE lifecycle hooks for w1/w3/w2 pattern - moe_lifecycle_hooks = W1W3W2MoELifecycleHooks() + dynamic_expert_index = "num_local_experts" + # MoE lifecycle hooks for gate_proj/up_proj/down_proj pattern + moe_lifecycle_hooks = GateUpDownMoELifecycleHooks() + + # The first alias in each token is the runtime shell name. Later aliases are + # checkpoint-side names that LazyTurtle may resolve directly from module_tree. module_tree = [ "model", "layers", @@ -21,9 +25,10 @@ class MixtralQModel(BaseQModel): "input_layernorm": ("input_layernorm:!",), "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), - "block_sparse_moe:moe": { + "mlp|block_sparse_moe:moe:?": { + "gate": ("gate:!",), "experts": { - "#": ("w1:0", "w3:0", "w2:1"), + "#": ("gate_proj|w1:0", "up_proj|w3:0", "down_proj|w2:1"), } } } diff --git a/gptqmodel/models/definitions/olmoe.py b/gptqmodel/models/definitions/olmoe.py index ba37008cb..210863e3c 100644 --- a/gptqmodel/models/definitions/olmoe.py +++ b/gptqmodel/models/definitions/olmoe.py @@ -4,11 +4,11 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from .._const import EXPERT_INDEX_PLACEHOLDER -from ..base import BaseGPTQModel +from ..base import BaseQModel # Both DeepSeek-v2 and DeepSeek-v2-lite are supported in this model def -class OlmoeGPTQ(BaseGPTQModel): +class OlmoeGPTQ(BaseQModel): dynamic_expert_index = "num_experts" diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py index f0aba0519..0d064fe2e 100644 --- a/gptqmodel/models/definitions/ovis.py +++ b/gptqmodel/models/definitions/ovis.py @@ -114,13 +114,7 @@ def preprocess_dataset(self, sample: Dict) -> Dict: "labels": labels, } - def prepare_dataset( - self, - calibration_dataset, - calibration_dataset_concat_size, - batch_size: int = 1, - tokenizer=None, - **kwargs): + def prepare_dataset(self,calibration_dataset,batch_size: int = 1, **kwargs): calib_data = [] for batch in batched(calibration_dataset, batch_size, self.preprocess_dataset): pixel_values, input_ids, labels = tuple([instance[key] for instance in batch] @@ -148,7 +142,34 @@ def prepare_dataset( return calib_data - def generate(self, inputs, **kwargs): + def generate(self, inputs=None, **kwargs): """shortcut for model.generate""" + if inputs is None: + inputs = kwargs.pop("input_ids", None) with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type): return self.model.generate(inputs, **kwargs) + + def move_input_capture_example(self, example, data_device): + for key, value in example.items(): + if isinstance(value, list): + for index, item in enumerate(value): + if not torch.is_tensor(item): + continue + + if item.ndim == 1: + item = item.unsqueeze(0) + + value[index] = move_to( + item.to(self.model.visual_tokenizer.dtype), + device=data_device, + ) + elif torch.is_tensor(value): + if value.ndim == 1: + value = value.unsqueeze(0) + + example[key] = move_to(value, device=data_device) + + return self.finalize_input_capture_example(example) + + def run_input_capture(self, example, use_cache: bool, data_device): + return self.model.generate(inputs=example.pop("input_ids"), **example) diff --git a/gptqmodel/models/definitions/phi3.py b/gptqmodel/models/definitions/phi3.py index 84ab4683d..e40467aea 100644 --- a/gptqmodel/models/definitions/phi3.py +++ b/gptqmodel/models/definitions/phi3.py @@ -19,7 +19,7 @@ class Phi3QModel(BaseQModel): ] class PhiMoEGPTQForCausalLM(BaseQModel): - require_pkgs = ["transformers<=4.44.2"] + dynamic_expert_index = "num_local_experts" module_tree = [ "model", @@ -29,9 +29,9 @@ class PhiMoEGPTQForCausalLM(BaseQModel): "input_layernorm": ("input_layernorm:!",), "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), - "block_sparse_moe:moe": { + "mlp:moe:?": { "experts": { - "#": ("w1:0", "w2:1"), + "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, }, } diff --git a/gptqmodel/models/definitions/qwen2_moe.py b/gptqmodel/models/definitions/qwen2_moe.py index f634f10d7..6298676b2 100644 --- a/gptqmodel/models/definitions/qwen2_moe.py +++ b/gptqmodel/models/definitions/qwen2_moe.py @@ -27,6 +27,7 @@ class Qwen2MoeQModel(BaseQModel): "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp:moe:?": { "gate": ("gate:!",), + "shared_expert_gate": ("shared_expert_gate:!",), "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"), "experts:0": { "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), @@ -34,17 +35,3 @@ class Qwen2MoeQModel(BaseQModel): }, } ] - - # module_tree_overrides = { - # METHOD.AWQ: [ - # { - # "mlp:moe:?": { - # "gate": ("gate:!",), - # "shared_expert": None, - # "experts": { - # "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), - # }, - # }, - # } - # ] - # } diff --git a/gptqmodel/models/definitions/qwen3_5.py b/gptqmodel/models/definitions/qwen3_5.py index 71a782e44..81af8b29c 100644 --- a/gptqmodel/models/definitions/qwen3_5.py +++ b/gptqmodel/models/definitions/qwen3_5.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from transformers import AutoModelForImageTextToText from transformers.models.qwen3_5 import Qwen3_5TextConfig from . import LlamaQModel @@ -15,11 +16,22 @@ class Qwen3_5QModel(LlamaQModel): """ config_class = Qwen3_5TextConfig + loader = AutoModelForImageTextToText + require_load_processor = True + + # Transformers' Qwen3.5 SDPA path currently errors when calibration batches + # contain multiple padded samples, so quantization must stay single-sample. + support_batch_quantize = False layer_modules_strict = False + pre_lm_head_norm_module = "model.language_model.norm" + + rotary_embedding = "model.language_model.rotary_emb" + module_tree = [ "model", + "language_model", "layers", "#", { @@ -27,8 +39,11 @@ class Qwen3_5QModel(LlamaQModel): "self_attn": ("q_norm:!", "q_proj:0", "k_norm:!", "k_proj:0", "v_proj:0", "o_proj:1"), "linear_attn": ( "norm:!", + "conv1d:!", "in_proj_qkv:0", "in_proj_z:1", + "in_proj_b:!:1", + "in_proj_a:!:1", "out_proj:2", ), "post_attention_layernorm": ("post_attention_layernorm:!",), diff --git a/gptqmodel/models/definitions/qwen3_5_moe.py b/gptqmodel/models/definitions/qwen3_5_moe.py index 9d9b7ed8e..5f5fabd0f 100644 --- a/gptqmodel/models/definitions/qwen3_5_moe.py +++ b/gptqmodel/models/definitions/qwen3_5_moe.py @@ -12,6 +12,8 @@ class Qwen3_5_MoeQModel(BaseQModel): loader = AutoModelForImageTextToText + require_load_processor = True + layer_modules_strict = False require_monkeypatch = False @@ -20,7 +22,11 @@ class Qwen3_5_MoeQModel(BaseQModel): # config.num_experts contains the actual expert count used for index dynamic_expert_index = "num_experts" - pre_lm_head_norm_module = "model.norm" + pre_lm_head_norm_module = "model.language_model.norm" + + rotary_embedding = "model.language_model.rotary_emb" + + out_of_model_tensors = {"prefixes": ["mtp"]} # awq scaling optimizations requires some modules within same subset to strictly match the shape of previous module # the o_proj must match v_proj or else scaling optimizations are skipped (GQA vs MHA) @@ -37,20 +43,24 @@ class Qwen3_5_MoeQModel(BaseQModel): "#", { "input_layernorm": ("input_layernorm:!",), - "self_attn": ("norm:!", "q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), + "self_attn": ("q_norm:!", "q_proj:0", "k_norm:!", "k_proj:0", "v_proj:0", "o_proj:1"), "linear_attn": ( "norm:!", + "conv1d:!", "in_proj_qkv:0", "in_proj_z:1", + "in_proj_b:!:1", + "in_proj_a:!:1", "out_proj:2", ), "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp:moe:?": { "gate": ("gate:!",), # <-- 0.5MB per layer. Not worth quantizing - "experts": { + "shared_expert_gate": ("shared_expert_gate:!",), + "experts:0": { "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, - "shared_experts": ("gate_proj:0", "up_proj:0", "down_proj:1"), + "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, } ] diff --git a/gptqmodel/models/definitions/qwen3_moe.py b/gptqmodel/models/definitions/qwen3_moe.py index e8a90c080..d9e6cabc1 100644 --- a/gptqmodel/models/definitions/qwen3_moe.py +++ b/gptqmodel/models/definitions/qwen3_moe.py @@ -30,7 +30,7 @@ class Qwen3MoeQModel(BaseQModel): "#", { "input_layernorm": ("input_layernorm:!",), - "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), + "self_attn": ("q_norm:!", "k_norm:!", "q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp:moe:?": { "gate": ("gate:!",), # <-- 0.5MB per layer. Not worth quantizing diff --git a/gptqmodel/models/definitions/qwen3_next.py b/gptqmodel/models/definitions/qwen3_next.py index a895fa388..c63012d1d 100644 --- a/gptqmodel/models/definitions/qwen3_next.py +++ b/gptqmodel/models/definitions/qwen3_next.py @@ -36,19 +36,19 @@ class Qwen3NextGPTQ(BaseQModel): { "input_layernorm": ("input_layernorm:!",), # Token mixers - #"self_attn": ("k_proj", "v_proj", "q_proj", "o_proj"), - "linear_attn": ("in_proj_qkvz", "in_proj_ba:!", "out_proj"), # conv1d intentionally excluded + "self_attn": ("q_norm:!", "k_norm:!", "q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), + "linear_attn": ("norm:!", "conv1d:!", "in_proj_qkvz:0", "in_proj_ba:!:0", "out_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), # MLP / MoE "mlp:moe": { # MoE router + shared expert (Qwen3NextSparseMoeBlock) "gate": ("gate:!",), # router gate linear "shared_expert_gate": ("shared_expert_gate:!",), # <-- single (1, N) logic projections should not be quantized - "shared_expert": ("gate_proj", "up_proj", "down_proj"), + "shared_expert:0": ("gate_proj:0", "up_proj:0", "down_proj:1"), # Experts list with dynamic index - "experts": { - "#": ("gate_proj", "up_proj", "down_proj"), + "experts:0": { + "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, }, }, diff --git a/gptqmodel/models/definitions/qwen3_omni_moe.py b/gptqmodel/models/definitions/qwen3_omni_moe.py index 4e540775f..7e01da0dc 100644 --- a/gptqmodel/models/definitions/qwen3_omni_moe.py +++ b/gptqmodel/models/definitions/qwen3_omni_moe.py @@ -38,10 +38,10 @@ class Qwen3OmniMoeGPTQ(BaseQModel): "#", { "input_layernorm": ("input_layernorm:!",), - "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), + "self_attn": ("q_norm:!", "k_norm:!", "q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp:moe": { - "gate": ("gate",), + "gate": ("gate:!",), # router gate is tiny and accuracy-sensitive "experts": { "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 18e8c4c88..3b20d7595 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -5,17 +5,19 @@ from __future__ import annotations +import copy import os import time from importlib.metadata import PackageNotFoundError, version from itertools import chain from typing import Dict, List, Optional, Union +import numpy as np import torch import transformers from ..utils.modelscope import ensure_modelscope_available -from ..utils.structure import print_module_tree +from ..utils.structure import LazyTurtle, print_module_tree if ensure_modelscope_available(): @@ -27,19 +29,40 @@ from packaging.version import InvalidVersion, Version from transformers import AutoConfig, AutoTokenizer, PretrainedConfig from transformers.utils import is_flash_attn_2_available -from transformers.utils.generic import ContextManagers from ..adapter.adapter import Adapter -from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear +from ..nn_modules.exllamav3 import ExllamaV3Linear +from ..nn_modules.exllamav3_torch import ExllamaV3TorchLinear +from ..nn_modules.qlinear.exllamav2 import ExllamaV2Linear +from ..nn_modules.qlinear.gguf import GGUFTorchLinear from ..quantization import QuantizeConfig -from ..quantization.config import FORMAT, METHOD, MIN_VERSION_WITH_V2 -from ..utils.backend import BACKEND -from ..utils.hf import no_init_weights -from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear +from ..quantization.config import FORMAT, METHOD, MIN_VERSION_WITH_V2, BaseQuantizeConfig, resolve_quant_format +from ..utils import internal_gguf +from ..utils.backend import BACKEND, PROFILE, normalize_backend, normalize_profile +from ..utils.exllamav3 import replace_exllamav3_placeholders +from ..utils.hf import ( + INTERNAL_HF_GGUF_FILE_KWARG, + get_hf_config_dtype, + get_hf_gguf_load_kwargs, + has_native_transformers_causallm_support, + normalize_hf_config_compat, + normalize_model_id_or_path_for_hf_gguf, + normalize_torch_dtype_kwarg, + prepare_remote_model_init_compat, + resolve_trust_remote_code, + set_hf_config_dtype, + suspend_hf_weight_init, +) +from ..utils.importer import ( + auto_select_device, + get_kernel_for_backend, + normalize_device_device_map, + select_quant_linear, +) from ..utils.inspect import safe_kwargs_call from ..utils.logger import setup_logger from ..utils.machete import _validate_machete_device_support -from ..utils.marlin import _validate_marlin_device_support +from ..utils.marlin import _marlin_capability_supported, _validate_marlin_device_support from ..utils.model import ( auto_dtype, convert_gptq_v1_to_v2_format, @@ -60,6 +83,82 @@ ATTN_IMPLEMENTATION = "attn_implementation" +def _should_print_module_tree() -> bool: + """Keep expensive module-tree dumps opt-in during model loading.""" + + raw = os.environ.get("GPTQMODEL_PRINT_MODULE_TREE") + if raw is None: + return False + return raw.strip().lower() in {"1", "true", "yes", "on", "y", "t"} + + +def _maybe_print_module_tree(model) -> None: + """Print the module tree only when explicitly requested for debugging.""" + + if _should_print_module_tree(): + print_module_tree(model=model) + + +def _supports_flash_attn_2(config: PretrainedConfig) -> bool: + """Detect whether the resolved HF architecture exposes FA2 kernels.""" + + if not getattr(config, "architectures", None): + return False + + model_class = getattr(transformers, config.architectures[0], None) + if model_class is None: + return False + + if hasattr(model_class, "_supports_flash_attn_2"): + return bool(getattr(model_class, "_supports_flash_attn_2")) + if hasattr(model_class, "_supports_flash_attn"): + return bool(getattr(model_class, "_supports_flash_attn")) + return False + + +def _is_accelerated_attention_device(device: object) -> bool: + """Return True when the selected device can run CUDA/ROCm flash attention.""" + + if isinstance(device, torch.device): + return device.type in {"cuda", "hip"} + if isinstance(device, DEVICE): + return device in {DEVICE.CUDA, DEVICE.ROCM} + if isinstance(device, str): + return device in {"cuda", "rocm", "hip"} + return False + + +def _resolve_native_gguf_profile( + *, + native_gguf_qspec: Optional["internal_gguf.GGUFQuantizedCheckpointSpec"], + profile: PROFILE, +) -> PROFILE: + """Resolve user profile intent for native GGUF checkpoints.""" + + if ( + native_gguf_qspec is not None + and native_gguf_qspec.tensor_qtype == internal_gguf.GGMLQuantizationType.Q1_0_g128 + and profile == PROFILE.AUTO + ): + log.info("Loader: Bonsai/Prism Q1_0_g128 PROFILE.AUTO resolved to PROFILE.FAST.") + return PROFILE.FAST + return profile + + +def _should_use_dense_native_gguf_path( + *, + native_gguf_qspec: Optional["internal_gguf.GGUFQuantizedCheckpointSpec"], + profile: PROFILE, +) -> bool: + """Fast Bonsai mode stays on the dense HF GGUF import path.""" + + return ( + native_gguf_qspec is not None + and native_gguf_qspec.tensor_qtype == internal_gguf.GGMLQuantizationType.Q1_0_g128 + and profile == PROFILE.FAST + ) + + def parse_version_string(version_str: str): try: return Version(version_str) @@ -92,6 +191,34 @@ def compare_versions(installed_version, required_version, operator): raise ValueError(f"Unsupported operator: {operator}") +def _is_meta_shell_build_error(exc: Exception) -> bool: + # Some trust_remote_code model constructors call int()/item() on tensors + # during __init__, which breaks when the shell is built on the meta device. + message = str(exc) + return "cannot be called on meta tensors" in message and ".item()" in message + + +def _coerce_quantized_awq_dtype(*, backend: BACKEND, qcfg: QuantizeConfig, dtype): + if qcfg.quant_method not in (METHOD.AWQ, METHOD.PARO): + return dtype + if backend in (None, BACKEND.AUTO, BACKEND.AUTO_TRAINABLE): + return dtype + if not isinstance(dtype, torch.dtype): + return dtype + + try: + qlinear = get_kernel_for_backend(backend, qcfg.quant_method, qcfg.format) + except ValueError: + return dtype + + supported_dtypes = getattr(qlinear, "SUPPORTS_DTYPES", None) or [] + if dtype in supported_dtypes or torch.float16 not in supported_dtypes: + return dtype + + log.info(f"Loading Quantized Model: Auto fix `dtype` to `torch.float16` for `{qlinear.__name__}`") + return torch.float16 + + def check_versions(model_class, requirements: List[str]): if requirements is None: return @@ -105,10 +232,25 @@ def check_versions(model_class, requirements: List[str]): raise ValueError(f"{model_class} requires version {req}, but {pkg} not installed.") +def set_dtype_compat(model_init_kwargs: dict, torch_dtype): + """ + Set dtype argument in a version-compatible way for Transformers. + See: https://github.com/huggingface/transformers/releases/tag/v4.56.0 + + Args: + model_init_kwargs (dict): kwargs used to initialize model + torch_dtype: torch dtype (e.g. torch.float16) + """ + if Version(transformers.__version__) >= Version("4.56.0"): + model_init_kwargs["dtype"] = torch_dtype + else: + model_init_kwargs["torch_dtype"] = torch_dtype + def get_model_local_path(pretrained_model_id_or_path, **kwargs): is_local = os.path.isdir(pretrained_model_id_or_path) - if is_local: + if is_local or os.path.isabs(pretrained_model_id_or_path): return os.path.normpath(pretrained_model_id_or_path) + kwargs.pop(INTERNAL_HF_GGUF_FILE_KWARG, None) def _log_removed(removed: list[str]): log.debug("Loader: dropping unsupported snapshot_download kwargs: %s", ", ".join(removed)) @@ -120,12 +262,159 @@ def _log_removed(removed: list[str]): ) +def _get_tokenizer_load_kwargs(model_init_kwargs: Dict) -> Dict: + return get_hf_gguf_load_kwargs(model_init_kwargs) + + +def _resolve_local_gguf_checkpoint_path(model_local_path: str, hf_gguf_load_kwargs: Dict[str, str]) -> Optional[str]: + gguf_file = hf_gguf_load_kwargs.get("gguf_file") + if not gguf_file: + return None + + checkpoint_path = os.path.join(str(model_local_path), gguf_file) + if not os.path.isfile(checkpoint_path): + return None + return checkpoint_path + + +def _resolve_native_quantized_gguf_checkpoint( + model_local_path: str, + hf_gguf_load_kwargs: Dict[str, str], +) -> tuple[Optional[str], Optional[internal_gguf.GGUFQuantizedCheckpointSpec]]: + if not internal_gguf.native_quantized_loader_enabled(): + return None, None + + gguf_checkpoint_path = _resolve_local_gguf_checkpoint_path(model_local_path, hf_gguf_load_kwargs) + if gguf_checkpoint_path is None: + return None, None + + try: + spec = internal_gguf.inspect_quantized_checkpoint(gguf_checkpoint_path) + except Exception as exc: + log.debug("Loader: failed to inspect GGUF checkpoint `%s`: %s", gguf_checkpoint_path, exc) + return None, None + + if spec is None: + return None, None + return gguf_checkpoint_path, spec + + +def _resolve_model_slot(model: torch.nn.Module, name: str) -> tuple[torch.nn.Module, str]: + module_name, _, attr_name = name.rpartition(".") + module = model.get_submodule(module_name) if module_name else model + return module, attr_name + + +def _lookup_model_slot_tensor(model: torch.nn.Module, name: str) -> torch.Tensor: + module, attr_name = _resolve_model_slot(model, name) + if attr_name in module._parameters: + return module._parameters[attr_name] + if attr_name in module._buffers: + return module._buffers[attr_name] + raise KeyError(f"Loader: model slot `{name}` does not exist.") + + +def _assign_model_slot_tensor(model: torch.nn.Module, name: str, tensor: torch.Tensor) -> None: + module, attr_name = _resolve_model_slot(model, name) + tensor = tensor.contiguous() + + if attr_name in module._parameters: + current = module._parameters[attr_name] + if current is not None and (tensor.device != current.device or tensor.dtype != current.dtype): + tensor = tensor.to(device=current.device, dtype=current.dtype) + requires_grad = current.requires_grad if isinstance(current, torch.nn.Parameter) else False + module._parameters[attr_name] = torch.nn.Parameter(tensor, requires_grad=requires_grad) + return + + if attr_name in module._buffers: + current = module._buffers[attr_name] + if current is not None and (tensor.device != current.device or tensor.dtype != current.dtype): + tensor = tensor.to(device=current.device, dtype=current.dtype) + module._buffers[attr_name] = tensor + return + + raise KeyError(f"Loader: model slot `{name}` does not exist.") + + +def _build_gguf_tensor_key_mapping(model: torch.nn.Module, config: PretrainedConfig) -> dict[str, str]: + import transformers.modeling_gguf_pytorch_utils as gguf_utils + + processor_cls = gguf_utils.TENSOR_PROCESSORS.get(config.model_type, gguf_utils.TensorProcessor) + if processor_cls is not gguf_utils.TensorProcessor: + raise NotImplementedError( + f"Loader: native quantized GGUF loading only supports the default tensor processor. " + f"Actual processor for `{config.model_type}`: `{processor_cls.__name__}`." + ) + + processor = processor_cls(config=config.to_dict()) + return gguf_utils.get_gguf_hf_weights_map(model, processor) + + +def _load_quantized_gguf_checkpoint_into_model( + *, + model: torch.nn.Module, + gguf_checkpoint_path: str, + tensor_key_mapping: dict[str, str], +) -> None: + reader = internal_gguf.GGUFReader(gguf_checkpoint_path) + loaded: set[str] = set() + + for tensor in reader.tensors: + target_name = tensor_key_mapping.get(tensor.name) + if target_name is None: + continue + + module_name, _, attr_name = target_name.rpartition(".") + target_module = model.get_submodule(module_name) if module_name else model + resolved_target_name = target_name + + if isinstance(target_module, GGUFTorchLinear) and attr_name == "weight": + resolved_target_name = f"{module_name}.qweight" if module_name else "qweight" + packed = torch.from_numpy(np.array(tensor.data, dtype=np.uint8, copy=True, order="C")) + expected = _lookup_model_slot_tensor(model, resolved_target_name) + if tuple(packed.shape) != tuple(expected.shape): + raise RuntimeError( + f"Loader: GGUF qweight shape mismatch for `{resolved_target_name}`. " + f"Expected {tuple(expected.shape)}, got {tuple(packed.shape)}." + ) + _assign_model_slot_tensor(model, resolved_target_name, packed) + loaded.add(resolved_target_name) + continue + + reference = _lookup_model_slot_tensor(model, resolved_target_name) + weights = internal_gguf.dequantize_to_torch( + tensor.data, + tensor.tensor_type, + device=reference.device, + dtype=reference.dtype, + ) + _assign_model_slot_tensor(model, resolved_target_name, weights) + loaded.add(resolved_target_name) + + missing_qweights = [] + for module_name, module in model.named_modules(): + if not isinstance(module, GGUFTorchLinear): + continue + qweight_name = f"{module_name}.qweight" if module_name else "qweight" + if qweight_name not in loaded: + missing_qweights.append(qweight_name) + if missing_qweights: + raise RuntimeError( + "Loader: GGUF checkpoint did not populate required quantized weights: " + + ", ".join(sorted(missing_qweights)) + ) + + model.tie_weights() + + def ModelLoader(cls): @classmethod def from_pretrained( cls, pretrained_model_id_or_path: str, - quantize_config: QuantizeConfig, + quantize_config: BaseQuantizeConfig, + backend: Union[str, BACKEND] = BACKEND.AUTO, + profile: Union[str, int, PROFILE] = PROFILE.AUTO, trust_remote_code: bool = False, dtype: [str | torch.dtype] = "auto", device_map: Optional[Union[str, Dict[str, Union[int, str]]]] = None, @@ -136,11 +425,35 @@ def from_pretrained( import torch._dynamo torch._dynamo.disable() - model_local_path = get_model_local_path(pretrained_model_id_or_path, **model_init_kwargs) + pretrained_model_id_or_path = normalize_model_id_or_path_for_hf_gguf( + pretrained_model_id_or_path, + model_init_kwargs, + api_name=f"{cls.__name__}.from_pretrained", + ) + + dtype = normalize_torch_dtype_kwarg( + model_init_kwargs, + api_name=f"{cls.__name__}.from_pretrained", + explicit_dtype=dtype, + ) + backend = normalize_backend(backend) + profile = normalize_profile(profile) + hf_gguf_load_kwargs = get_hf_gguf_load_kwargs(model_init_kwargs) + model_init_kwargs_without_internal = dict(model_init_kwargs) + model_init_kwargs_without_internal.pop(INTERNAL_HF_GGUF_FILE_KWARG, None) + + tokenizer_trust_remote_code = model_init_kwargs_without_internal.pop("tokenizer_trust_remote_code", trust_remote_code) + model_local_path = get_model_local_path(pretrained_model_id_or_path, **model_init_kwargs_without_internal) + trust_remote_code = resolve_trust_remote_code(model_local_path, trust_remote_code=trust_remote_code) + + model_init_kwargs_without_internal["trust_remote_code"] = trust_remote_code + + config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs_without_internal, **hf_gguf_load_kwargs) - model_init_kwargs["trust_remote_code"] = trust_remote_code + defuser.replace_fused_blocks(config.model_type) - config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs) + normalize_hf_config_compat(config, trust_remote_code=trust_remote_code) + prepare_remote_model_init_compat(model_local_path, config) atten_impl = model_init_kwargs.get("attn_implementation", None) @@ -148,27 +461,84 @@ def from_pretrained( log.info(f"Loader: overriding attn_implementation in config to `{atten_impl}`") config._attn_implementation = atten_impl + resolved_device = normalize_device_device_map(device, device_map) + resolved_device = auto_select_device(resolved_device, backend) + if cls.require_dtype: dtype = cls.require_dtype + elif dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): + dtype = auto_dtype(config=config, device=resolved_device, quant_inference=False) - if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype: + if isinstance(dtype, torch.dtype) and get_hf_config_dtype(config) != dtype: # Align config metadata with the dtype we will materialize weights in. - config.torch_dtype = dtype + set_hf_config_dtype(config, dtype) - tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id_or_path, trust_remote_code=trust_remote_code) + tokenizer = AutoTokenizer.from_pretrained( + model_local_path, + trust_remote_code=tokenizer_trust_remote_code, + **_get_tokenizer_load_kwargs(model_init_kwargs), + ) - # Some models have multiple configurations. - # For example, in llama4 and qwen3_5, model_class.form_config requires TextConfig. - if cls.config_class is not None and cls.config_class == config.sub_configs.get("text_config", None): - config = config.get_text_config() + gguf_checkpoint_path, native_gguf_qspec = _resolve_native_quantized_gguf_checkpoint( + model_local_path, + hf_gguf_load_kwargs, + ) + effective_profile = _resolve_native_gguf_profile( + native_gguf_qspec=native_gguf_qspec, + profile=profile, + ) if quantize_config is None: - model_init_kwargs["device_map"] =device_map if device_map else "auto" - model_init_kwargs["dtype"] = dtype + if native_gguf_qspec is not None: + if _should_use_dense_native_gguf_path( + native_gguf_qspec=native_gguf_qspec, + profile=effective_profile, + ): + if backend != BACKEND.AUTO: + log.info( + "Loader: PROFILE.%s uses dense GGUF import for `%s`; backend `%s` is ignored.", + effective_profile.name, + gguf_checkpoint_path, + backend.value, + ) + else: + redirect_kwargs = dict(model_init_kwargs) + redirect_kwargs.pop("tokenizer_trust_remote_code", None) + log.info( + "Loader: detected native quantized GGUF checkpoint `%s`; redirecting `%s` to from_quantized() with PROFILE.%s.", + gguf_checkpoint_path, + cls.__name__, + effective_profile.name, + ) + return cls.from_quantized( + model_id_or_path=pretrained_model_id_or_path, + device_map=device_map, + device=device, + backend=backend, + dtype=dtype, + trust_remote_code=trust_remote_code, + tokenizer_trust_remote_code=tokenizer_trust_remote_code, + **redirect_kwargs, + ) + + hf_model_init_kwargs = dict(model_init_kwargs_without_internal) + hf_model_init_kwargs["device_map"] = device_map if device_map else "auto" + set_dtype_compat(hf_model_init_kwargs, dtype) + hf_model_init_kwargs.update(hf_gguf_load_kwargs) + if ( + native_gguf_qspec is not None + and native_gguf_qspec.tensor_qtype == internal_gguf.GGMLQuantizationType.Q1_0_g128 + and atten_impl in {None, "auto"} + and _is_accelerated_attention_device(resolved_device) + and (config.model_type == "qwen3" or _supports_flash_attn_2(config)) + and is_flash_attn_2_available() + ): + hf_model_init_kwargs[ATTN_IMPLEMENTATION] = "flash_attention_2" + log.info("Loader: Auto enabling flash_attention_2 for dense Bonsai PROFILE.%s.", effective_profile.name) # Load a non-quantized model, but do not perform quantization. For example, for evaluation. - model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs) - model._model_init_kwargs = model_init_kwargs - print_module_tree(model=model) + model = cls.loader.from_pretrained(model_local_path, config=config, **hf_model_init_kwargs) + model._model_init_kwargs = hf_model_init_kwargs + _maybe_print_module_tree(model=model) turtle_model = None @@ -189,8 +559,8 @@ def from_pretrained( # non-quantized models are always loaded into cpu cpu_device_map = {"": "cpu"} - if quantize_config is None or not isinstance(quantize_config, QuantizeConfig): - raise AttributeError("`quantize_config` must be passed and be an instance of QuantizeConfig.") + if quantize_config is None or not isinstance(quantize_config, BaseQuantizeConfig): + raise AttributeError("`quantize_config` must be passed and be an instance of BaseQuantizeConfig.") quantize_config.calculate_bits_per_weight() @@ -202,7 +572,8 @@ def from_pretrained( raise ValueError(f"{cls} only supports desc_act={cls.supports_desc_act}, " f"but quantize_config.desc_act is {quantize_config.desc_act}.") - if cls.require_trust_remote_code and not trust_remote_code: + native_support = has_native_transformers_causallm_support(model_local_path) + if cls.require_trust_remote_code and not trust_remote_code and not native_support: raise ValueError( f"{pretrained_model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model." ) @@ -228,12 +599,12 @@ def skip(*args, **kwargs): # enforce some values despite user specified # non-quantized models are always loaded into cpu - model_init_kwargs["device_map"] = cpu_device_map - model_init_kwargs["dtype"] = dtype - model_init_kwargs["_fast_init"] = cls.require_fast_init + model_init_kwargs_without_internal["device_map"] = cpu_device_map + set_dtype_compat(model_init_kwargs_without_internal, dtype) + model_init_kwargs_without_internal["_fast_init"] = cls.require_fast_init #model_init_kwargs["low_cpu_mem_usage"] = True - cls.before_model_load(cls, load_quantized_model=False) + cls.before_model_load(cls, model_local_path=model_local_path, load_quantized_model=False) from ..utils.hf import build_shell_model # XIELUActivation will use some weights when activation init, so can't use init_empty_weights @@ -247,33 +618,71 @@ def skip(*args, **kwargs): log.warn(f"{cls} doesn't support offload_to_disk, set quantize_config.offload_to_disk to False.") if quantize_config.offload_to_disk: - model = build_shell_model(cls.loader, config=config, **model_init_kwargs) - defuser.convert_model(model, cleanup_original=False) - model._model_init_kwargs = model_init_kwargs - print_module_tree(model=model) - - # enable mmap with low_cpu_mem_usage - turtle_spinner = log.spinner(title="Turtle model loading...", interval=0.1) + shell_config = copy.deepcopy(config) try: - turtle_model = cls.loader.from_pretrained( + model = build_shell_model(cls.loader, config=shell_config, **model_init_kwargs_without_internal) + except RuntimeError as exc: + if not _is_meta_shell_build_error(exc): + raise + + log.warn( + "Loader: meta-device shell build failed for `%s`; falling back to direct CPU load without turtle_model: %s", + model_local_path, + exc, + ) + log.info("Loader: loading model directly to CPU (meta shell unsupported; turtle_model disabled)") + fallback_init_kwargs = model_init_kwargs_without_internal.copy() + fallback_init_kwargs.pop("device_map", None) + fallback_init_kwargs["low_cpu_mem_usage"] = False + model = cls.loader.from_pretrained( model_local_path, config=config, - low_cpu_mem_usage=True, - **model_init_kwargs, + **fallback_init_kwargs, + **hf_gguf_load_kwargs, + ) + if getattr(model, "config", None) is config: + model.config = copy.deepcopy(config) + defuser.convert_model(model, cleanup_original=False) + model._model_init_kwargs = fallback_init_kwargs + _maybe_print_module_tree(model=model) + turtle_model = None + else: + defuser.convert_model(model, cleanup_original=False) + shell_model_init_kwargs = dict(model_init_kwargs_without_internal) + shell_model_init_kwargs.update(hf_gguf_load_kwargs) + model._model_init_kwargs = shell_model_init_kwargs + _maybe_print_module_tree(model=model) + turtle_model = LazyTurtle.maybe_create( + model_local_path=model_local_path, + config=model.config, + model_init_kwargs=shell_model_init_kwargs, + module_tree=copy.deepcopy(getattr(cls, "module_tree", None)), ) - finally: - turtle_spinner.close() - # TODO FIX ME...temp store model_init args - turtle_model._model_init_kwargs = model_init_kwargs - # print("actual turtle model-----------") - # print_module_tree(model=turtle_model) + if turtle_model is None: + raise RuntimeError( + f"Loader: can't open model path `{model_local_path}` for offload_to_disk." + ) + + log.info( + "Loader: using checkpoint-backed lazy turtle source for `%s`", + model_local_path, + ) else: - print("loading model directly to CPU (not using meta device or turtle_model)-----------") - model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs) + log.info("Loader: loading model directly to CPU (not using meta device or turtle_model)") + model = cls.loader.from_pretrained( + model_local_path, + config=config, + **model_init_kwargs_without_internal, + **hf_gguf_load_kwargs, + ) + if getattr(model, "config", None) is config: + model.config = copy.deepcopy(config) defuser.convert_model(model, cleanup_original=False) - model._model_init_kwargs = model_init_kwargs - print_module_tree(model=model) + direct_model_init_kwargs = dict(model_init_kwargs_without_internal) + direct_model_init_kwargs.update(hf_gguf_load_kwargs) + model._model_init_kwargs = direct_model_init_kwargs + _maybe_print_module_tree(model=model) turtle_model = None @@ -325,14 +734,40 @@ def from_quantized( import torch._dynamo torch._dynamo.reset() + model_id_or_path = normalize_model_id_or_path_for_hf_gguf( + model_id_or_path, + kwargs, + api_name=f"{cls.__name__}.from_quantized", + ) + dtype = normalize_torch_dtype_kwarg( + kwargs, + api_name=f"{cls.__name__}.from_quantized", + explicit_dtype=dtype, + ) + hf_gguf_load_kwargs = get_hf_gguf_load_kwargs(kwargs) + kwargs_without_internal = dict(kwargs) + kwargs_without_internal.pop(INTERNAL_HF_GGUF_FILE_KWARG, None) + tokenizer_trust_remote_code = kwargs_without_internal.pop("tokenizer_trust_remote_code", trust_remote_code) + requested_device_map = device_map + explicit_device_map = requested_device_map if isinstance(requested_device_map, dict) else None + + if requested_device_map is None: + explicit_device = None + if isinstance(device, str) and ":" in device: + explicit_device = device + elif isinstance(device, torch.device) and device.index is not None: + explicit_device = str(device) + + if explicit_device is not None: + explicit_device_map = {"": explicit_device} + requested_device_map = explicit_device_map # normalized device + device_map into single device - normalized_device = device if device_map is None else None # let device_map dictate placement when present - device = normalize_device_device_map(normalized_device, device_map) + normalized_device = device if requested_device_map is None else None # let device_map dictate placement when present + device = normalize_device_device_map(normalized_device, requested_device_map) - # TODO need to normalize backend and others in a unified api - if isinstance(backend, str): - backend = BACKEND(backend) + # Keep string inputs compatible while allowing canonical method-prefixed names. + backend = normalize_backend(backend) device = auto_select_device(device, backend) if backend == BACKEND.VLLM: @@ -341,27 +776,29 @@ def from_quantized( # to optimize vllm inference, set an environment variable 'VLLM_ATTENTION_BACKEND' to 'FLASHINFER'. os.environ['VLLM_ATTENTION_BACKEND'] = 'FLASHINFER' + model_local_path = get_model_local_path(model_id_or_path, **kwargs_without_internal) + trust_remote_code = resolve_trust_remote_code(model_local_path, trust_remote_code=trust_remote_code) + native_support = has_native_transformers_causallm_support(model_local_path) + """load quantized model from local disk""" - if cls.require_trust_remote_code and not trust_remote_code: + if cls.require_trust_remote_code and not trust_remote_code and not native_support: raise ValueError( f"{model_id_or_path} requires trust_remote_code=True. Please set trust_remote_code=True to load this model." ) check_versions(cls, cls.require_pkgs) - model_local_path = get_model_local_path(model_id_or_path, **kwargs) - # Parameters related to loading from Hugging Face Hub - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", "") - commit_hash = kwargs.pop("_commit_hash", None) - attn_implementation = kwargs.pop("attn_implementation", None) + cache_dir = kwargs_without_internal.pop("cache_dir", None) + force_download = kwargs_without_internal.pop("force_download", False) + resume_download = kwargs_without_internal.pop("resume_download", False) + proxies = kwargs_without_internal.pop("proxies", None) + local_files_only = kwargs_without_internal.pop("local_files_only", False) + use_auth_token = kwargs_without_internal.pop("use_auth_token", None) + revision = kwargs_without_internal.pop("revision", None) + subfolder = kwargs_without_internal.pop("subfolder", "") + commit_hash = kwargs_without_internal.pop("_commit_hash", None) + attn_implementation = kwargs_without_internal.pop("attn_implementation", None) cached_file_kwargs = { "cache_dir": cache_dir, @@ -382,8 +819,14 @@ def from_quantized( model_local_path, trust_remote_code=trust_remote_code, **cached_file_kwargs, + **hf_gguf_load_kwargs, ) + defuser.replace_fused_blocks(config.model_type) + + normalize_hf_config_compat(config, trust_remote_code=trust_remote_code) + prepare_remote_model_init_compat(model_local_path, config) + if cls.require_dtype: dtype = cls.require_dtype @@ -391,19 +834,76 @@ def from_quantized( # TODO FIX ME for `dynamic`, non-quantized modules should be in native type dtype = auto_dtype(config=config, device=device, quant_inference=True) - if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype: + if isinstance(dtype, torch.dtype) and get_hf_config_dtype(config) != dtype: # Ensure flash attention kernels see an explicit dtype instead of relying on defaults. - config.torch_dtype = dtype + set_hf_config_dtype(config, dtype) - qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs) + gguf_checkpoint_path, native_gguf_qspec = _resolve_native_quantized_gguf_checkpoint( + model_local_path, + hf_gguf_load_kwargs, + ) + if native_gguf_qspec is not None: + qcfg = QuantizeConfig( + bits=native_gguf_qspec.bits_alias, + method=METHOD.GGUF, + lm_head=native_gguf_qspec.lm_head_quantized, + ) + else: + qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs_without_internal) + export_quant_method = qcfg.export_quant_method() + format_code = resolve_quant_format(qcfg.format, qcfg.method) + backend = normalize_backend(backend, quant_method=export_quant_method) + + # Prism/Bonsai sign-only GGUF tensors only have a torch runtime today. + # Bypass higher-priority GGUF backends that either do not support 1-bit + # formats or depend on optional external runtimes. + if ( + native_gguf_qspec is not None + and native_gguf_qspec.tensor_qtype == internal_gguf.GGMLQuantizationType.Q1_0 + ): + if backend == BACKEND.AUTO: + backend = BACKEND.GGUF_TORCH + elif backend != BACKEND.GGUF_TORCH: + raise ValueError( + "Native Q1_0 GGUF checkpoints currently require BACKEND.GGUF_TORCH. " + f"Actual backend: `{backend}`." + ) + elif ( + native_gguf_qspec is not None + and native_gguf_qspec.tensor_qtype == internal_gguf.GGMLQuantizationType.Q1_0_g128 + and backend not in {BACKEND.AUTO, BACKEND.GGUF_TORCH, BACKEND.GGUF_TRITON} + ): + raise ValueError( + "Native Q1_0_g128 GGUF checkpoints support BACKEND.AUTO, BACKEND.GGUF_TORCH, or BACKEND.GGUF_TRITON. " + f"Actual backend: `{backend}`." + ) - if qcfg.quant_method == METHOD.AWQ and qcfg.format in [FORMAT.GEMV_FAST, FORMAT.LLM_AWQ]: + if format_code == FORMAT.EXL3: + if backend not in (BACKEND.AUTO, BACKEND.EXL3_EXLLAMA_V3, BACKEND.EXL3_TORCH): + raise TypeError("FORMAT.EXL3 requires BACKEND.AUTO, BACKEND.EXL3_EXLLAMA_V3, or BACKEND.EXL3_TORCH.") + if backend == BACKEND.AUTO: + if torch.cuda.is_available() and device in (DEVICE.CUDA, DEVICE.ROCM): + backend = BACKEND.EXL3_EXLLAMA_V3 + else: + backend = BACKEND.EXL3_TORCH + if backend == BACKEND.EXL3_EXLLAMA_V3: + if not torch.cuda.is_available(): + raise ValueError("EXL3 CUDA loading requires CUDA/HIP.") + if device not in (DEVICE.CUDA, DEVICE.ROCM): + raise ValueError("EXL3 CUDA loading requires a CUDA/HIP device.") + elif format_code == FORMAT.BITSANDBYTES: + if backend not in (BACKEND.AUTO, BACKEND.BITSANDBYTES): + raise TypeError("FORMAT.BITSANDBYTES requires BACKEND.AUTO or BACKEND.BITSANDBYTES.") + backend = BACKEND.BITSANDBYTES + + if export_quant_method == METHOD.AWQ and format_code in [FORMAT.GEMV_FAST, FORMAT.LLM_AWQ]: # GEMV_FAST and LLM_AWQ only supports torch.float16 log.info("Loading Quantized Model: Auto fix `dtype` to `torch.float16`") dtype = torch.float16 - if backend == BACKEND.EXLLAMA_EORA: - # EXLLAMA_EORA only supports torch.float16 + dtype = _coerce_quantized_awq_dtype(backend=backend, qcfg=qcfg, dtype=dtype) + + if backend == BACKEND.GPTQ_PRO: log.info("Loading Quantized Model: Auto fix `dtype` to `torch.float16`") dtype = torch.float16 @@ -413,14 +913,19 @@ def from_quantized( qcfg.calculate_bits_per_weight() - tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + tokenizer = AutoTokenizer.from_pretrained( + model_local_path, + trust_remote_code=tokenizer_trust_remote_code, + **hf_gguf_load_kwargs, + ) if backend == BACKEND.VLLM or backend == BACKEND.SGLANG: + runtime_generate = None if backend == BACKEND.VLLM: - if qcfg.format != FORMAT.GPTQ and qcfg.format != FORMAT.GEMM: + if format_code not in [FORMAT.GPTQ, FORMAT.GEMM]: raise ValueError(f"{backend} backend only supports FORMAT.GPTQ or FORMAT.GEMM: actual = {qcfg.format}") elif backend == BACKEND.SGLANG: - if qcfg.format != FORMAT.GPTQ: + if format_code != FORMAT.GPTQ: raise ValueError(f"{backend} backend only supports FORMAT.GPTQ: actual = {qcfg.format}") if backend == BACKEND.VLLM: @@ -429,13 +934,12 @@ def from_quantized( model = load_model_by_vllm( model=model_local_path, trust_remote_code=trust_remote_code, - **kwargs, + **kwargs_without_internal, ) model.config = model.llm_engine.model_config model.device = model.llm_engine.vllm_config.device_config.device - - cls.generate = lambda self, **kwargs: vllm_generate(self.model, **kwargs) + runtime_generate = vllm_generate elif backend == BACKEND.SGLANG: from ..utils.sglang import load_model_by_sglang, sglang_generate @@ -444,11 +948,11 @@ def from_quantized( model=model_local_path, trust_remote_code=trust_remote_code, dtype=torch.float16, - **kwargs, + **kwargs_without_internal, ) model.config = hf_config - cls.generate = lambda self, **kwargs: sglang_generate(self.model, **kwargs) - return cls( + runtime_generate = sglang_generate + instance = cls( model, quantized=True, quantize_config=qcfg, @@ -458,79 +962,88 @@ def from_quantized( trust_remote_code=trust_remote_code, model_local_path=model_local_path, ) + instance._runtime_generate = runtime_generate + return instance - if qcfg.format == FORMAT.MARLIN: + if format_code == FORMAT.MARLIN: # format marlin requires marlin kernel - if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and backend != BACKEND.AUTO: - raise TypeError(f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.MARLIN: actual = `{backend}`.") - backend = BACKEND.MARLIN + expected_marlin_backend = BACKEND.AWQ_MARLIN if qcfg.quant_method == METHOD.AWQ else BACKEND.GPTQ_MARLIN + expected_marlin_backends = [expected_marlin_backend] + if backend not in expected_marlin_backends and backend != BACKEND.AUTO: + raise TypeError( + f"FORMAT.MARLIN requires BACKEND.AUTO or BACKEND.{expected_marlin_backend.name}: actual = `{backend}`." + ) + backend = expected_marlin_backend # marlin_compatible = False if backend == BACKEND.IPEX else _validate_marlin_device_support() # check for marlin compat for cuda device only - # if backend not in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and device == DEVICE.CUDA: + # if backend not in [BACKEND.GPTQ_MARLIN, BACKEND.AWQ_MARLIN] and device == DEVICE.CUDA: # unsupported = _validate_marlin_compatibility(qcfg) # if unsupported is None and marlin_compatible: # logger.info( - # "Hint: Model is compatible with the Marlin kernel. Marlin is optimized for batched inference on Nvidia GPU: `model = GPTQModel.load(..., backend=BACKEND.MARLIN)`." + # "Hint: Model is compatible with the Marlin kernel. Use the canonical Marlin BACKEND enum." # ) - if qcfg.format == FORMAT.BITBLAS: + if format_code == FORMAT.BITBLAS: # format bitblas requires bitblas kernel - if backend != BACKEND.BITBLAS and backend != BACKEND.AUTO: - raise TypeError(f"FORMAT.BITBLAS requires BACKEND.AUTO or BACKEND.BITBLAS: actual = `{backend}`.") - backend = BACKEND.BITBLAS + expected_backend = BACKEND.AWQ_BITBLAS if qcfg.quant_method == METHOD.AWQ else BACKEND.GPTQ_BITBLAS + if backend != expected_backend and backend != BACKEND.AUTO: + raise TypeError( + f"FORMAT.BITBLAS requires BACKEND.AUTO or BACKEND.{expected_backend.name}: actual = `{backend}`." + ) + backend = expected_backend - if backend == BACKEND.BITBLAS: + if backend in [BACKEND.GPTQ_BITBLAS, BACKEND.AWQ_BITBLAS]: from ..nn_modules.qlinear.bitblas import BITBLAS_AVAILABLE, BITBLAS_INSTALL_HINT if BITBLAS_AVAILABLE is False: raise ValueError(BITBLAS_INSTALL_HINT) - possible_model_basenames = [ - f"gptq_model-{qcfg.bits}bit-{qcfg.group_size}g", - "model", - ] - - extensions = [".safetensors"] - model_local_path = str(model_local_path) - - # Retrieve (and if necessary download) the quantized checkpoint(s). - is_sharded, resolved_archive_file, true_model_basename = get_checkpoints( - model_id_or_path=model_local_path, - extensions=extensions, - possible_model_basenames=possible_model_basenames, - **cached_file_kwargs, - ) - - # bin files have security issues: disable loading by default - if ".bin" in resolved_archive_file: - raise ValueError( - "Loading of .bin files are not allowed due to safety. Please convert your model to safetensor or pytorch format." + if native_gguf_qspec is not None: + is_sharded = False + model_save_name = gguf_checkpoint_path + else: + if format_code == FORMAT.EXL3: + possible_model_basenames = ["model"] + else: + possible_model_basenames = [ + f"gptq_model-{qcfg.bits}bit-{qcfg.group_size}g", + "model", + ] + + extensions = [".safetensors"] + + # Retrieve (and if necessary download) the quantized checkpoint(s). + is_sharded, resolved_archive_file, true_model_basename = get_checkpoints( + model_id_or_path=model_local_path, + extensions=extensions, + possible_model_basenames=possible_model_basenames, + **cached_file_kwargs, ) - qcfg.runtime_format = qcfg.format - - model_save_name = resolved_archive_file # In case a model is sharded, this would be `model.safetensors.index.json` which may later break. - - # == step2: convert model to gptq-model (replace Linear with QuantLinear) == # - def skip(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip + # bin files have security issues: disable loading by default + if ".bin" in resolved_archive_file: + raise ValueError( + "Loading of .bin files are not allowed due to safety. Please convert your model to safetensor or pytorch format." + ) - transformers.modeling_utils._init_weights = False + model_save_name = resolved_archive_file # In case a model is sharded, this would be `model.safetensors.index.json` which may later break. - init_contexts = [no_init_weights()] + qcfg.runtime_format = format_code - with (ContextManagers(init_contexts)): - cls.before_model_load(cls, load_quantized_model=True) + # == step2: convert model to gptq-model (replace Linear with QuantLinear) == # + gguf_tensor_key_mapping = None + with suspend_hf_weight_init(): + cls.before_model_load(cls, model_local_path=model_local_path, load_quantized_model=True) if config.architectures: model_class = getattr(transformers, config.architectures[0], None) - if model_class is not None and hasattr(model_class, "_supports_flash_attn_2"): - supports_flash_attn = model_class._supports_flash_attn_2 + if model_class is not None: + # backward-compatible fallback for "_supports_flash_attn" field + if hasattr(model_class, "_supports_flash_attn_2"): + supports_flash_attn = getattr(model_class, "_supports_flash_attn_2") + elif hasattr(model_class, "_supports_flash_attn"): + supports_flash_attn = getattr(model_class, "_supports_flash_attn") else: supports_flash_attn = None else: @@ -538,22 +1051,20 @@ def skip(*args, **kwargs): args = {} if supports_flash_attn and device in [DEVICE.CUDA, DEVICE.ROCM]: - if ATTN_IMPLEMENTATION in kwargs: - args[ATTN_IMPLEMENTATION] = kwargs.pop(ATTN_IMPLEMENTATION, None) + if attn_implementation is not None: + args[ATTN_IMPLEMENTATION] = attn_implementation elif is_flash_attn_2_available(): args = {ATTN_IMPLEMENTATION: "flash_attention_2"} log.info("Loader: Auto enabling flash attention2") - - # Some models have multiple configurations. - # For example, in llama4 and qwen3_5, model_class.form_config requires TextConfig. - if cls.config_class == config.sub_configs.get("text_config", None): - config = config.get_text_config() + set_dtype_compat(args, dtype) model = cls.loader.from_config( - config, trust_remote_code=trust_remote_code, dtype=dtype, **args + config, trust_remote_code=trust_remote_code, **args ) defuser.convert_model(model, cleanup_original=True) model.checkpoint_file_name = model_save_name + if native_gguf_qspec is not None: + gguf_tensor_key_mapping = _build_gguf_tensor_key_mapping(model, config) extract_layers_node = cls.extract_layers_node() # Get the first layer to determine layer type @@ -576,16 +1087,30 @@ def skip(*args, **kwargs): log.info(f"The layer {name} is not quantized.") del modules[name] - preload_qlinear_kernel = make_quant( - model, - qcfg=qcfg, - quant_result=modules, - backend=backend, - lm_head_name=cls.lm_head, - device=device, - ) + if format_code == FORMAT.EXL3: + if not isinstance(qcfg.tensor_storage, dict) or not qcfg.tensor_storage: + raise ValueError("EXL3 checkpoints require `quantization_config.tensor_storage` metadata.") - if isinstance(device_map, str) and device_map not in [ + exl3_module_cls = ExllamaV3TorchLinear if backend == BACKEND.EXL3_TORCH else ExllamaV3Linear + replace_exllamav3_placeholders( + model=model, + module_names=list(qcfg.tensor_storage.keys()), + tensor_storage=qcfg.tensor_storage, + module_cls=exl3_module_cls, + ) + preload_qlinear_kernel = exl3_module_cls + else: + preload_qlinear_kernel = make_quant( + model, + qcfg=qcfg, + quant_result=modules, + backend=backend, + lm_head_name=cls.lm_head, + device=device, + dtype=dtype, + ) + + if isinstance(requested_device_map, str) and requested_device_map not in [ "auto", "balanced", "balanced_low_0", @@ -779,18 +1304,22 @@ def assign(mod, device_id): return device_map log.info(f"Loader: device = {device}") - layers, _ = get_module_by_name_prefix(model, extract_layers_node) - num_gpus = 1 - if device is DEVICE.CUDA: - num_gpus = torch.cuda.device_count() - elif device is DEVICE.XPU: - num_gpus = torch.xpu.device_count() - device_map = build_layerwise_device_map(model, device, layers, ignore_modules, num_gpus) + if explicit_device_map is None: + layers, _ = get_module_by_name_prefix(model, extract_layers_node) + num_gpus = 1 + if device is DEVICE.CUDA: + num_gpus = torch.cuda.device_count() + elif device is DEVICE.XPU: + num_gpus = torch.xpu.device_count() + device_map = build_layerwise_device_map(model, device, layers, ignore_modules, num_gpus) + else: + device_map = dict(explicit_device_map) + log.info(f"Loader: honoring explicit device_map request: {device_map}") log.info(f"Loader: device_map = {device_map}") - load_checkpoint_in_model = True + load_checkpoint_in_model = native_gguf_qspec is None # compat: runtime convert checkpoint gptq(v1) to gptq_v2 format - if qcfg.format in [FORMAT.GPTQ, FORMAT.GEMM]: + if format_code in [FORMAT.GPTQ, FORMAT.GEMM, FORMAT.PAROQUANT]: load_checkpoint_in_model_then_tie_weights( model, dtype=dtype, @@ -803,9 +1332,9 @@ def assign(mod, device_id): load_checkpoint_in_model = False - if qcfg.format == FORMAT.GPTQ: + if format_code == FORMAT.GPTQ: # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase - if not qcfg.sym and not qcfg.is_quantized_by_gptaq(): + if not qcfg.sym and not qcfg.is_quantized_by_gptaq() and not qcfg.is_quantized_by_foem(): raise ValueError( f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}" ) @@ -819,7 +1348,7 @@ def assign(mod, device_id): qcfg.runtime_format = FORMAT.GPTQ_V2 - if backend == BACKEND.MACHETE: + if backend in (BACKEND.GPTQ_MACHETE, BACKEND.AWQ_MACHETE): if is_sharded: raise ValueError( "Format: The loading of sharded checkpoints with Machete is currently not supported." @@ -829,23 +1358,39 @@ def assign(mod, device_id): f"Kernel: Machete kernel requires compute capability >= 9.0. Detected capability: {torch.cuda.get_device_capability()}" ) - if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and ( - preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN): + if backend in [BACKEND.GPTQ_MARLIN, BACKEND.AWQ_MARLIN] and ( + preload_qlinear_kernel == ExllamaV2Linear or format_code == FORMAT.MARLIN): if is_sharded: raise ValueError( "Format: The loading of sharded checkpoints with Marlin is currently not supported." ) - if not _validate_marlin_device_support(): - raise ValueError( - f'Kernel: Marlin kernel does not support this gpu with compute capability of `{torch.cuda.get_device_capability()}`. Please do not use `back=BACKEND.MARLIN`.' - ) + device_capability = torch.cuda.get_device_capability() + if backend == BACKEND.GPTQ_MARLIN: + if not _validate_marlin_device_support(): + raise ValueError( + "Kernel: Marlin kernel requires compute capability >= 7.5 for the " + f"GPTQ Marlin backend. Detected capability: `{device_capability}`." + ) + if device_capability == (7, 5) and dtype == torch.bfloat16: + raise ValueError( + "Kernel: GPTQ Marlin on Turing (compute capability 7.5) supports " + "dtype=torch.float16 only." + ) + elif backend == BACKEND.AWQ_MARLIN: + if not _marlin_capability_supported(*device_capability) or device_capability[0] < 8: + raise ValueError( + "Kernel: AWQ Marlin requires compute capability >= 8.0. " + f"Detected capability: `{device_capability}`." + ) - # Validate the model can run in Marlin. - if dtype != torch.float16: - raise ValueError("Marlin kernel requires dtype=torch.float16.") + # GPTQ Marlin and AWQ Marlin support fp16 and bf16 compute on Ampere+. + if backend == BACKEND.GPTQ_MARLIN and dtype not in (torch.float16, torch.bfloat16): + raise ValueError("Marlin kernel requires dtype=torch.float16 or dtype=torch.bfloat16.") + if backend == BACKEND.AWQ_MARLIN and dtype not in (torch.float16, torch.bfloat16): + raise ValueError("AWQ Marlin kernel requires dtype=torch.float16 or dtype=torch.bfloat16.") - if backend == BACKEND.BITBLAS: + if backend in [BACKEND.GPTQ_BITBLAS, BACKEND.AWQ_BITBLAS]: from ..utils.bitblas import prepare_model_for_bitblas_load # Prepare model for bitblas load. @@ -864,7 +1409,14 @@ def assign(mod, device_id): # If we use marlin or bitblas to load the quantized model, the model is already a converted model, # and we no longer need to call load_checkpoint_in_model() - if load_checkpoint_in_model and backend not in [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.MARLIN_FP16, BACKEND.BITBLAS]: + if load_checkpoint_in_model and backend not in [ + BACKEND.GPTQ_MACHETE, + BACKEND.AWQ_MACHETE, + BACKEND.GPTQ_MARLIN, + BACKEND.AWQ_MARLIN, + BACKEND.GPTQ_BITBLAS, + BACKEND.AWQ_BITBLAS, + ]: load_checkpoint_in_model_then_tie_weights( model, dtype=dtype, @@ -875,21 +1427,32 @@ def assign(mod, device_id): # offload_buffers=True, ) - # TODO: Why are we using this custom function and not dispatch_model? - model = simple_dispatch_model(model, device_map) - - qlinear_kernel = select_quant_linear( - bits=qcfg.bits, - dynamic=qcfg.dynamic, - group_size=qcfg.group_size, - desc_act=qcfg.desc_act, - sym=qcfg.sym, - backend=backend, - format=qcfg.format, - quant_method=qcfg.quant_method, - device=device, - pack_dtype=qcfg.pack_dtype, - ) + if native_gguf_qspec is not None: + model = simple_dispatch_model(model, device_map) + _load_quantized_gguf_checkpoint_into_model( + model=model, + gguf_checkpoint_path=gguf_checkpoint_path, + tensor_key_mapping=gguf_tensor_key_mapping, + ) + else: + # TODO: Why are we using this custom function and not dispatch_model? + model = simple_dispatch_model(model, device_map) + + if format_code == FORMAT.EXL3: + qlinear_kernel = ExllamaV3TorchLinear if backend == BACKEND.EXL3_TORCH else ExllamaV3Linear + else: + qlinear_kernel = select_quant_linear( + bits=qcfg.runtime_bits, + dynamic=qcfg.dynamic, + group_size=qcfg.group_size, + desc_act=qcfg.desc_act, + sym=qcfg.sym, + backend=backend, + format=format_code, + quant_method=export_quant_method, + device=device, + pack_dtype=qcfg.pack_dtype, + ) # == step4: set seqlen == # model_config = model.config.to_dict() @@ -901,8 +1464,9 @@ def assign(mod, device_id): log.warn("can't get model's sequence length from model config, will set to 4096.") model.seqlen = 4096 - # Any post-initialization that require device information, for example buffers initialization on device. - model = gptqmodel_post_init(model, use_act_order=qcfg.desc_act, quantize_config=qcfg) + if format_code != FORMAT.EXL3: + # Any post-initialization that require device information, for example buffers initialization on device. + model = gptqmodel_post_init(model, use_act_order=qcfg.desc_act, quantize_config=qcfg) model.eval() @@ -915,7 +1479,7 @@ def assign(mod, device_id): from ..utils.mlx import convert_gptq_to_mlx_weights, mlx_generate except ModuleNotFoundError as exception: raise type(exception)( - "GPTQModel load mlx model required dependencies are not installed.", + "GPT-QModel load mlx model required dependencies are not installed.", "Please install via `pip install gptqmodel[mlx] --no-build-isolation`.", ) diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 0767fcc2e..4f2d6c9fb 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -11,16 +11,14 @@ import os import shutil from os.path import isfile, join -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union -import pcre as re +import pcre import torch -import transformers from safetensors import safe_open from safetensors.torch import save_file from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.models.auto.tokenization_auto import get_tokenizer_config -from transformers.utils.generic import ContextManagers from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora from ..adapter.peft import LoraConfig @@ -29,6 +27,7 @@ META_FIELD_ACT_GROUP_AWARE, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT, + META_FIELD_FOEM_ENABLED, META_FIELD_GPTAQ_ENABLED, META_FIELD_MSE, META_FIELD_QUANTIZER, @@ -38,14 +37,23 @@ META_QUANTIZER_GPTQMODEL, META_VALUE_URI, MIN_VERSION_WITH_V2, + resolve_quant_format, ) from ..utils.backend import BACKEND -from ..utils.hf import no_init_weights, sanitize_generation_config_file +from ..utils.exllamav3 import build_exllamav3_tensor_storage +from ..utils.hf import ( + prepare_remote_code_compat, + sanitize_generation_config_file, + sanitize_model_config, + suspend_hf_weight_init, +) from ..utils.logger import setup_logger from ..utils.model import ( + TensorSource, copy_py_files, find_modules, get_model_files_size, + get_module_by_name, get_state_dict_for_save, load_checkpoint_in_model_then_tie_weights, make_quant, @@ -71,6 +79,324 @@ EORA_DEFAULT_FILE = "eora.safetensors" +# disable gptqmodel split_by layer feature (until sglang pr is merged since our dir struct is not compatible) +# SUPPORTED_SPLIT_BY = {None, "layer"} +SUPPORTED_SPLIT_BY = {None} +_MAX_SHARD_SIZE_RE = pcre.compile( + r"\s*(\d+)([KMGTP]?B?)\s*", + flags=pcre.Flag.CASELESS, +) + + +def _parse_split_by(value: Optional[str]) -> Optional[str]: + if value is None: + return None + if not isinstance(value, str): + raise TypeError("split_by must be a string or None.") + + normalized = value.strip().lower() + if normalized in ("", "none"): + return None + if normalized not in SUPPORTED_SPLIT_BY: + raise ValueError(f"Unsupported split_by value: {value}. Supported values: None, 'layer'.") + return normalized + + +def _cleanup_saved_weight_files( + save_dir: str, + expected_files: List[str], + model_base_name: str, + model_save_name: str, +) -> None: + expected = set(expected_files) + shard_pattern = pcre.compile( + rf"{pcre.escape(model_base_name)}-\d{{5}}-of-\d{{5}}\.safetensors" + ) + + for filename in os.listdir(save_dir): + full_filename = join(save_dir, filename) + if not isfile(full_filename): + continue + if filename == model_save_name and filename not in expected: + os.remove(full_filename) + continue + if filename == model_save_name + ".index.json" and filename not in expected: + os.remove(full_filename) + continue + if shard_pattern.fullmatch(filename) and filename not in expected: + os.remove(full_filename) + + + +def _resolve_out_of_model_source_files( + model_local_path: str, + source_files: Optional[List[str]] = None, +) -> List[str]: + if source_files: + return sorted(dict.fromkeys(source_files)) + + index_path = join(model_local_path, "model.safetensors.index.json") + if os.path.exists(index_path): + try: + with open(index_path, "r", encoding="utf-8") as handle: + index_data = json.load(handle) + weight_map = index_data.get("weight_map", {}) + if isinstance(weight_map, dict): + return sorted( + { + filename + for filename in weight_map.values() + if isinstance(filename, str) and filename.endswith(".safetensors") + } + ) + except Exception as exc: + log.warn(f"Model: Failed to inspect original safetensors index at '{index_path}': {exc}") + + return sorted( + filename + for filename in os.listdir(model_local_path) + if filename.endswith(".safetensors") and isfile(join(model_local_path, filename)) + ) + + +def _load_tensors_by_prefixes( + model_local_path: str, + prefixes: List[str], + source_files: Optional[List[str]] = None, +) -> Dict[str, torch.Tensor]: + # Gather tensors whose names match any of the requested prefixes. + # Gather tensors whose names match any of the requested prefixes from all available shards. + tensors: Dict[str, torch.Tensor] = {} + source_file_names = _resolve_out_of_model_source_files(model_local_path, source_files) + for source_file_name in source_file_names: + source_tensor_path = os.path.join(model_local_path, source_file_name) + if not os.path.exists(source_tensor_path): + continue + try: + with safe_open(source_tensor_path, framework="pt", device="cpu") as f: + for tensor_name in f.keys(): + if any(tensor_name.startswith(prefix) for prefix in prefixes): + if tensor_name not in tensors: + tensors[tensor_name] = f.get_tensor(tensor_name) + except Exception as exc: + log.warn( + f"Model: Failed to read tensors from {source_file_name} while scanning for prefixes " + f"{prefixes}: {exc}" + ) + return tensors + + +def _tensor_source_from_tensor(name: str, tensor: torch.Tensor) -> TensorSource: + # Create a TensorSource wrapper so the merged tensor behaves like original state_dict entries. + # Wrap a raw tensor into a TensorSource so it can be merged into state_dict. + return TensorSource( + name=name, + torch_dtype=tensor.dtype, + shape=tuple(tensor.shape), + source=tensor, + ) + + +def _merge_prefix_tensors_into_state_dict( + prefixes: List[str], model_local_path: str, state_dict: Dict[str, TensorSource] +) -> None: + # Inject matched tensors into the ongoing state_dict before sharding. + merged = 0 + normalized_prefixes = [prefix if prefix.endswith(".") else f"{prefix}." for prefix in prefixes] + tensors = _load_tensors_by_prefixes(model_local_path, normalized_prefixes) + for name, tensor in tensors.items(): + state_dict[name] = _tensor_source_from_tensor(name, tensor) + merged += 1 + if merged: + log.info(f"Model: Merged {merged} tensors with prefixes {normalized_prefixes} into the state dict") + else: + log.warn(f"Model: No tensors matched prefixes {normalized_prefixes} while merging into the state dict") + + +def _normalize_out_of_model_tensors_entries( + entries: Optional[List[Union[str, Dict[str, Any]]]] +) -> tuple[List[str], List[str]]: + # Normalize configured files/prefixes into explicit lists. + copy_files: List[str] = [] + prefixes: List[str] = [] + if not entries: + return copy_files, prefixes + + raw_entries = list(entries) if isinstance(entries, (list, tuple)) else [entries] + for entry in raw_entries: + if isinstance(entry, str): + copy_files.append(entry) + continue + if not isinstance(entry, dict): + raise TypeError("out_of_model_tensors entries must be dict.") + + files_value = entry.get("files") + if files_value is not None: + files = [files_value] if isinstance(files_value, str) else list(files_value) + for file in files: + if not isinstance(file, str) or not file: + raise ValueError("`files` entries must be non-empty strings.") + copy_files.append(file) + + prefixes_value = entry.get("prefixes") + if prefixes_value is not None: + prefix_list = [prefixes_value] if isinstance(prefixes_value, str) else list(prefixes_value) + for prefix in prefix_list: + if not isinstance(prefix, str) or not prefix: + raise ValueError("`prefixes` entries must be non-empty strings.") + prefixes.append(prefix) + + return copy_files, prefixes + + +def _resolve_layer_split_group(tensor_name: str, layer_prefixes: List[str]) -> tuple[str, bool]: + for prefix in sorted((prefix for prefix in layer_prefixes if prefix), key=len, reverse=True): + expected_prefix = f"{prefix}." + if not tensor_name.startswith(expected_prefix): + continue + remainder = tensor_name[len(expected_prefix):] + layer_idx, dot, _ = remainder.partition(".") + if layer_idx.isdigit() and dot: + return f"{prefix}.{layer_idx}", True + + if "." in tensor_name: + return tensor_name.rsplit(".", 1)[0], False + return "", False + + +def _module_is_leaf(model, module_name: str) -> bool: + if not module_name: + return False + try: + module = get_module_by_name(model, module_name) + except Exception: + return False + return not any(True for _ in module.named_children()) + + +def _cleanup_legacy_leaf_group_dir(save_dir: str, group_name: str) -> None: + legacy_dir = join(save_dir, group_name) + if not os.path.isdir(legacy_dir): + return + + for cleanup_base_name, cleanup_save_name in { + ("layer", "layer.safetensors"), + ("model", "model.safetensors"), + }: + _cleanup_saved_weight_files( + save_dir=legacy_dir, + expected_files=[], + model_base_name=cleanup_base_name, + model_save_name=cleanup_save_name, + ) + + try: + if not os.listdir(legacy_dir): + os.rmdir(legacy_dir) + except OSError: + pass + + +def _stream_state_dict_to_layer_dirs( + state_dict: Dict[str, Any], + save_dir: str, + model_base_name: str, + model_save_name: str, + metadata: Dict[str, str], + max_shard_size: Optional[int], + layer_prefixes: List[str], + model, +) -> tuple[List[str], Dict[str, str], int]: + grouped_state_dict: Dict[str, Dict[str, Any]] = {} + layer_groups: Dict[str, bool] = {} + for tensor_name, tensor_source in state_dict.items(): + group_name, is_layer_group = _resolve_layer_split_group(tensor_name, layer_prefixes) + group = grouped_state_dict.setdefault(group_name, {}) + group[tensor_name] = tensor_source + layer_groups[group_name] = is_layer_group + + expected_files: List[str] = [] + tensor_to_filename: Dict[str, str] = {} + total_size = 0 + root_expected_files: List[str] = [] + cleanup_specs = {(model_base_name, model_save_name)} + if model_base_name != "model" or model_save_name != "model.safetensors": + cleanup_specs.add(("model", "model.safetensors")) + + for group_dir_name, group_state_dict in grouped_state_dict.items(): + is_layer_group = layer_groups.get(group_dir_name, False) + is_leaf_group = (not is_layer_group) and _module_is_leaf(model, group_dir_name) + + if is_layer_group: + group_dir = join(save_dir, group_dir_name) + group_model_base_name = model_base_name + group_model_save_name = model_save_name + relative_prefix = f"{group_dir_name}/" + group_cleanup_specs = cleanup_specs + elif is_leaf_group and group_dir_name: + group_dir = save_dir + group_model_base_name = group_dir_name + group_model_save_name = f"{group_dir_name}.safetensors" + relative_prefix = "" + group_cleanup_specs = {(group_model_base_name, group_model_save_name)} + else: + group_dir = save_dir if not group_dir_name else join(save_dir, group_dir_name) + group_model_base_name = model_base_name + group_model_save_name = model_save_name + relative_prefix = "" if not group_dir_name else f"{group_dir_name}/" + group_cleanup_specs = cleanup_specs + + os.makedirs(group_dir, exist_ok=True) + + group_expected_files, group_tensor_to_filename, group_total_size = streaming_state_dict_to_shards( + group_state_dict, + save_dir=group_dir, + model_base_name=group_model_base_name, + single_file_name=group_model_save_name, + metadata=metadata, + max_shard_size=max_shard_size, + ) + total_size += group_total_size + + for cleanup_base_name, cleanup_save_name in group_cleanup_specs: + _cleanup_saved_weight_files( + save_dir=group_dir, + expected_files=group_expected_files, + model_base_name=cleanup_base_name, + model_save_name=cleanup_save_name, + ) + + if is_leaf_group and group_dir_name: + _cleanup_legacy_leaf_group_dir(save_dir=save_dir, group_name=group_dir_name) + elif group_dir_name: + _cleanup_saved_weight_files( + save_dir=save_dir, + expected_files=[], + model_base_name=group_dir_name, + model_save_name=f"{group_dir_name}.safetensors", + ) + + if not group_dir_name and not is_leaf_group: + root_expected_files.extend(group_expected_files) + + for filename in group_expected_files: + relative_filename = f"{relative_prefix}{filename}" if relative_prefix else filename + expected_files.append(relative_filename) + + for tensor_name, filename in group_tensor_to_filename.items(): + relative_filename = f"{relative_prefix}{filename}" if relative_prefix else filename + tensor_to_filename[tensor_name] = relative_filename + + for cleanup_base_name, cleanup_save_name in cleanup_specs: + _cleanup_saved_weight_files( + save_dir=save_dir, + expected_files=root_expected_files, + model_base_name=cleanup_base_name, + model_save_name=cleanup_save_name, + ) + + return expected_files, tensor_to_filename, total_size + def ModelWriter(cls): def save_pretrained( self, @@ -137,6 +463,7 @@ def save_quantized( max_shard_size: Optional[Union[int, str]] = DEFAULT_MAX_SHARD_SIZE, meta_quantizer: Optional[str] = None, eora_path: Optional[str] = None, + split_by: Optional[str] = None, ): """save quantized model and configs to local disk""" os.makedirs(save_dir, exist_ok=True) @@ -171,19 +498,23 @@ def save_quantized( ) # meta: write config fields to meta if they doe not participate in inference + gptaq_cfg = getattr(self.quantize_config, "gptaq", None) + + foem_cfg = getattr(self.quantize_config, "foem", None) + self.quantize_config.meta_set( key=META_FIELD_DAMP_PERCENT, - value=self.quantize_config.damp_percent + value=getattr(self.quantize_config, "damp_percent", None) ) self.quantize_config.meta_set( key=META_FIELD_DAMP_AUTO_INCREMENT, - value=self.quantize_config.damp_auto_increment + value=getattr(self.quantize_config, "damp_auto_increment", None) ) self.quantize_config.meta_set( key=META_FIELD_STATIC_GROUPS, - value=self.quantize_config.static_groups + value=getattr(self.quantize_config, "static_groups", None) ) self.quantize_config.meta_set( @@ -193,39 +524,61 @@ def save_quantized( self.quantize_config.meta_set( key=META_FIELD_MSE, - value=self.quantize_config.mse + value=getattr(self.quantize_config, "mse", None) ) self.quantize_config.meta_set( key=META_FIELD_GPTAQ_ENABLED, - value=None if self.quantize_config.gptaq is None else { - "alpha": self.quantize_config.gptaq.alpha, + value=None if gptaq_cfg is None else { + "alpha": gptaq_cfg.alpha, "device": ( - self.quantize_config.gptaq.device - if isinstance(self.quantize_config.gptaq.device, str) - else str(self.quantize_config.gptaq.device) + gptaq_cfg.device + if isinstance(gptaq_cfg.device, str) + else str(gptaq_cfg.device) + ), + } + ) + + self.quantize_config.meta_set( + key=META_FIELD_FOEM_ENABLED, + value=None if foem_cfg is None else { + "alpha": foem_cfg.alpha, + "beta": foem_cfg.beta, + "device": ( + foem_cfg.device + if isinstance(foem_cfg.device, str) + else str(foem_cfg.device) ), } ) self.quantize_config.meta_set( key=META_FIELD_ACT_GROUP_AWARE, - value=self.quantize_config.act_group_aware + value=getattr(self.quantize_config, "act_group_aware", None) ) # The config, quantize_config and model may be edited in place in save_quantized. + sanitize_model_config(self.model.config) config = copy.deepcopy(self.model.config) + quantize_config = copy.deepcopy(self.quantize_config) if not self.quantized: raise ValueError("Save aborted as model is not quantized. Please call `quantize()` first.") - if quantize_config.format == FORMAT.GPTQ_V2: + runtime_format = resolve_quant_format(quantize_config.format, quantize_config.method) + + if runtime_format == FORMAT.GPTQ_V2: log.warn( - f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPTQModel version >= {MIN_VERSION_WITH_V2}." + f"Using 'format = {FORMAT.GPTQ_V2}': the serialized model is only supported by GPT-QModel version >= {MIN_VERSION_WITH_V2}." ) - if self.load_quantized_model: + if runtime_format == FORMAT.EXL3: + tensor_storage = build_exllamav3_tensor_storage(self.model) + quantize_config.tensor_storage = tensor_storage + self.quantize_config.tensor_storage = copy.deepcopy(tensor_storage) + + if self.load_quantized_model and runtime_format != FORMAT.EXL3: self.model = self.get_model_with_quantize( qcfg=quantize_config, model_id_or_path=self.model_local_path, @@ -305,6 +658,11 @@ def debug_saved_config(path): offload_root = self.quantize_config.offload_to_disk_path if getattr(self.quantize_config, "offload_to_disk", False) else None state_dict = get_state_dict_for_save(self.model, offload_root=offload_root) + copy_tensor_files, prefix_entries = _normalize_out_of_model_tensors_entries( + getattr(self, "out_of_model_tensors", None) + ) + if prefix_entries: + _merge_prefix_tensors_into_state_dict(prefix_entries, self.model_local_path, state_dict) model_base_name = "model" model_save_name = model_base_name + ".safetensors" @@ -318,7 +676,7 @@ def _parse_max_shard_size(value: Optional[Union[int, str]]) -> Optional[int]: return None if isinstance(value, int): return value - match = re.fullmatch(r"\s*(\d+)([KMGTP]?B?)\s*", value, re.IGNORECASE) + match = _MAX_SHARD_SIZE_RE.fullmatch(value) if not match: raise ValueError(f"Invalid max_shard_size value: {value}") base = int(match.group(1)) @@ -360,30 +718,38 @@ def _normalize_metadata(meta: Optional[Dict[str, Any]]) -> Dict[str, str]: max_shard_size_bytes = _parse_max_shard_size(max_shard_size) metadata_dict = _normalize_metadata(safetensors_metadata) metadata_dict["format"] = "pt" - - expected_files, tensor_to_filename, total_size_bytes = streaming_state_dict_to_shards( - state_dict, - save_dir=save_dir, - model_base_name=model_base_name, - single_file_name=model_save_name, - metadata=metadata_dict, - max_shard_size=max_shard_size_bytes, - ) - - pattern = re.compile(rf"{re.escape(model_base_name)}-\d{{5}}-of-\d{{5}}\.safetensors") - for filename in os.listdir(save_dir): - full_filename = join(save_dir, filename) - if not isfile(full_filename): - continue - if filename == model_save_name and filename not in expected_files: - os.remove(full_filename) - continue - if pattern.fullmatch(filename) and filename not in expected_files: - os.remove(full_filename) + split_by_mode = _parse_split_by(split_by) + + if split_by_mode == "layer": + expected_files, tensor_to_filename, total_size_bytes = _stream_state_dict_to_layer_dirs( + state_dict, + save_dir=save_dir, + model_base_name="layer", + model_save_name="layer.safetensors", + metadata=metadata_dict, + max_shard_size=max_shard_size_bytes, + layer_prefixes=self.extract_layers_node(), + model=self.model, + ) + else: + expected_files, tensor_to_filename, total_size_bytes = streaming_state_dict_to_shards( + state_dict, + save_dir=save_dir, + model_base_name=model_base_name, + single_file_name=model_save_name, + metadata=metadata_dict, + max_shard_size=max_shard_size_bytes, + ) + _cleanup_saved_weight_files( + save_dir=save_dir, + expected_files=expected_files, + model_base_name=model_base_name, + model_save_name=model_save_name, + ) total_size_mb = total_size_bytes / (1024 * 1024) - if len(expected_files) > 1: + if split_by_mode == "layer" or len(expected_files) > 1: index = { "metadata": {"total_size": total_size_bytes}, "weight_map": tensor_to_filename, @@ -404,82 +770,21 @@ def _normalize_metadata(meta: Optional[Dict[str, Any]]) -> Dict[str, str]: if self.quantize_config.adapter: _eora_save(self, save_dir=eora_path if eora_path else self.quantize_config.adapter.path, model_save_dir=save_dir) - # Handle `dangling` tensor files that HF doesn't support (optional) but very useful - extra_tensor_files = getattr(self, "out_of_model_tensor_files", None) - if extra_tensor_files: - if isinstance(extra_tensor_files, str): - extra_tensor_files = [extra_tensor_files] - else: - extra_tensor_files = list(extra_tensor_files) - - index_save_name = model_save_name + ".index.json" - index_save_path = join(save_dir, index_save_name) - - if os.path.exists(index_save_path): - with open(index_save_path, "r", encoding="utf-8") as f: - index_data = json.load(f) - else: - index_data = { - "metadata": {"total_size": total_size_bytes}, - "weight_map": dict(tensor_to_filename), - } - - if "metadata" not in index_data: - index_data["metadata"] = {} - if "weight_map" not in index_data: - index_data["weight_map"] = {} - - total_size_value = index_data["metadata"].get("total_size", total_size_bytes) - index_updated = False - - for tensor_file_name in extra_tensor_files: - original_tensor_path = os.path.join(self.model_local_path, tensor_file_name) - if not os.path.exists(original_tensor_path): - log.warn( - f"Model: out_of_model_tensor_files configured with '{tensor_file_name}', " - f"but the file was not found at '{original_tensor_path}'" - ) - continue - - target_tensor_path = os.path.join(save_dir, tensor_file_name) - shutil.copy2(original_tensor_path, target_tensor_path) - log.info( - f"Model: Copied {tensor_file_name} from original model directory to quantized model directory" + # Copy any requested safetensors files without modifying the index + for tensor_file_name in copy_tensor_files: + original_tensor_path = os.path.join(self.model_local_path, tensor_file_name) + if not os.path.exists(original_tensor_path): + log.warn( + f"Model: out_of_model_tensors configured with '{tensor_file_name}', " + f"but the file was not found at '{original_tensor_path}'" ) + continue - tensor_names = [] - try: - with safe_open(original_tensor_path, framework="pt", device="cpu") as f: - tensor_names = list(f.keys()) - except Exception as exc: - log.warn( - f"Model: Failed to read tensor names from {tensor_file_name}: {exc}" - ) - - for tensor_name in tensor_names: - index_data["weight_map"][tensor_name] = tensor_file_name - - if tensor_names: - log.info( - f"Model: Added {len(tensor_names)} tensors from {tensor_file_name} to weight_map" - ) - - try: - tensor_file_size = os.path.getsize(target_tensor_path) - except OSError: - tensor_file_size = 0 - - total_size_value += tensor_file_size - index_updated = True - - if index_updated: - index_data["metadata"]["total_size"] = total_size_value - with open(index_save_path, "w", encoding="utf-8") as f: - content = json.dumps(index_data, indent=2, sort_keys=True) + "\n" - f.write(content) - log.info( - f"Model: Updated {index_save_name} to include `out_of_model_tensor_files`" - ) + target_tensor_path = os.path.join(save_dir, tensor_file_name) + shutil.copy2(original_tensor_path, target_tensor_path) + log.info( + f"Model: Copied {tensor_file_name} from original model directory to quantized model directory" + ) # If the saved model is a loaded quantized model, do not calculate the size diff. if not self.load_quantized_model: @@ -524,16 +829,9 @@ def get_model_with_quantize(self, qcfg, model_id_or_path): model_id_or_path, trust_remote_code=True, ) + prepare_remote_code_compat(config) - def skip(*args, **kwargs): - pass - - torch.nn.init.kaiming_uniform_ = skip - torch.nn.init.uniform_ = skip - torch.nn.init.normal_ = skip - transformers.modeling_utils._init_weights = False - init_contexts = [no_init_weights()] - with ContextManagers(init_contexts): + with suspend_hf_weight_init(): model = cls.loader.from_config( config, dtype=torch.float16 ) diff --git a/gptqmodel/nn_modules/converter.py b/gptqmodel/nn_modules/converter.py index b1c952f6a..0e2c2f185 100644 --- a/gptqmodel/nn_modules/converter.py +++ b/gptqmodel/nn_modules/converter.py @@ -4,117 +4,21 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium -def convert_gpt_oss_expert_converter(module, config): - import torch.nn as nn - import transformers.models.gpt_oss.modeling_gpt_oss as gpt_oss_modeling - from transformers.integrations.hub_kernels import use_kernel_forward_from_hub +def _resolve_text_decoder_config(config): + text_config = getattr(config, "text_config", None) + if text_config is not None: + return text_config - from ..models.definitions.gpt_oss import GptOssExpertsNew + get_text_config = getattr(config, "get_text_config", None) + if callable(get_text_config): + resolved = get_text_config() + if resolved is not None: + return resolved - @use_kernel_forward_from_hub("MegaBlocksMoeMLP") - class GptOssMLPNew(nn.Module): - def __init__(self, config, ori_mlp=None): - super().__init__() - self.router = ori_mlp.router - experts_new = GptOssExpertsNew(config, ori_mlp.experts) - self.experts = experts_new + return config - def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) # (num_experts, seq_len) - routed_out = self.experts(hidden_states, router_indices=router_indices, routing_weights=router_scores) - return routed_out, router_scores - - # loop sub module to replace GptOssMLP with GptOssMLPNew - for name, sub_module in module.named_modules(): - if isinstance(sub_module, gpt_oss_modeling.GptOssMLP): - new_module = GptOssMLPNew(config=config, ori_mlp=sub_module) - setattr(module, name, new_module) - - return module - -def convert_llama4_expert_converter(module, config): - import torch - from transformers.models.llama4.modeling_llama4 import Llama4TextMLP, Llama4TextMoe - - from ..utils.hf import no_init_weights - - # adapted/modified from https://github.com/vllm-project/llm-compressor/blob/main/src/llmcompressor/modeling/llama4.py - class SequentialLlama4TextExperts(torch.nn.ModuleList): - def __init__(self, config, original): - self.num_experts = original.gate_up_proj.shape[0] - with no_init_weights(): - super().__init__([Llama4TextMLP(config) for _ in range(self.num_experts)]) - intermediate_size = original.down_proj.shape[1] - - with torch.inference_mode(): - # Batch process all expert parameters to avoid loops - gate_up_batch = torch.stack([original.gate_up_proj[i] for i in range(self.num_experts)]) - down_batch = torch.stack([original.down_proj[i] for i in range(self.num_experts)]) - - # Batch split and transpose - gate_batch = gate_up_batch[:, :, :intermediate_size].transpose(-2, -1).contiguous() - up_batch = gate_up_batch[:, :, intermediate_size:].transpose(-2, -1).contiguous() - down_batch = down_batch.transpose(-2, -1).contiguous() - - # Batch assignment - for i in range(self.num_experts): - self[i].gate_proj.weight.data = gate_batch[i] - self[i].up_proj.weight.data = up_batch[i] - self[i].down_proj.weight.data = down_batch[i] - - class SequentialLlama4TextMoe(torch.nn.Module): - def __init__(self, config, original): - super().__init__() - self.top_k = config.num_experts_per_tok - self.hidden_dim = config.hidden_size - self.num_experts = config.num_local_experts - self.experts = SequentialLlama4TextExperts(config, original.experts) - self.router = original.router - self.shared_expert = original.shared_expert - - def forward(self, hidden_states: torch.Tensor): - hidden_states = hidden_states.reshape(-1, self.hidden_dim) - router_logits = self.router(hidden_states) - if isinstance(router_logits, tuple): - router_scores, router_logits = router_logits - router_scores = router_scores.t() - else: - # transformers < 4.54.0 only returns router_logits - router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=1) - - router_scores = ( - torch.full_like(router_logits, float("-inf")) - .scatter_(1, router_indices, router_top_value) - .transpose(0, 1) - ) - router_scores = torch.sigmoid(router_scores.float()).to(hidden_states.dtype) - - out = self.shared_expert(hidden_states) - for i in range(self.num_experts): - out += self.experts[i](hidden_states) * router_scores[i].reshape(-1, 1) - - return out, router_logits - - for name, sub_module in module.named_modules(): - if isinstance(sub_module, Llama4TextMoe): - new_module = SequentialLlama4TextMoe(config=config.get_text_config(), original=sub_module) - setattr(module, name, new_module) - - return module - -def convert_glm4v_mlp_converter(module, config): - import transformers.models.glm4v.modeling_glm4v as glm4v_modeling - - from ..models.definitions.glm4v import Glm4vTextMLPNew - - for name, sub_module in module.named_modules(): - if isinstance(sub_module, glm4v_modeling.Glm4vTextMLP): - new_module = Glm4vTextMLPNew(config=config.get_text_config(), ori_mlp=sub_module) - setattr(module, name, new_module) - return module MODULE_CONVERTER_MAP = { - "llama4": convert_llama4_expert_converter, - "gpt_oss": convert_gpt_oss_expert_converter, - "glm4v": convert_glm4v_mlp_converter, + # llama4/gpt_oss are handled by Defuser>=0.0.10 during model load. + # qwen2_moe/qwen3_moe/qwen3_next/qwen3_omni_moe are handled by Defuser>=0.0.10 during model load. } diff --git a/gptqmodel/nn_modules/exllamav3.py b/gptqmodel/nn_modules/exllamav3.py new file mode 100644 index 000000000..39ecc40ae --- /dev/null +++ b/gptqmodel/nn_modules/exllamav3.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Portions of this file are adapted from turboderp-org/exllamav3. +# Credits: TurboDerp / ExLlamaV3 contributors. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Dict, Optional + +import torch +import torch.nn as nn + + +if TYPE_CHECKING: + from ..exllamav3.modules.quant.exl3 import LinearEXL3 + + +_EXL3_BUFFER_NAMES = ("trellis", "suh", "svh", "su", "sv", "bias", "mcg", "mul1") + + +def _torch_dtype(value: Any) -> torch.dtype: + if isinstance(value, torch.dtype): + return value + if isinstance(value, str): + return getattr(torch, value) + raise TypeError(f"Unsupported torch dtype value: {value!r}") + + +class ExllamaV3Linear(nn.Module): + QUANT_TYPE = "exl3" + SUPPORTS_SHARDS = True + + def __init__( + self, + *, + in_features: int, + out_features: int, + name: str, + tensor_storage: Optional[Dict[str, Any]] = None, + tensors: Optional[Dict[str, torch.Tensor]] = None, + out_dtype: torch.dtype = torch.float16, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.name = name + self.out_dtype = out_dtype + self.tensor_storage = tensor_storage or {} + + self.weight = torch.zeros((1,), dtype=torch.float16, device="meta") + self._inner: Optional["LinearEXL3"] = None + self._inner_signature: Optional[tuple[Any, ...]] = None + + if tensors is not None: + for buffer_name in _EXL3_BUFFER_NAMES: + tensor = tensors.get(buffer_name) + if tensor is None: + setattr(self, buffer_name, None) + else: + self.register_buffer(buffer_name, tensor) + return + + stored_tensors = (self.tensor_storage or {}).get("stored_tensors", {}) + for buffer_name in _EXL3_BUFFER_NAMES: + metadata = stored_tensors.get(f"{name}.{buffer_name}") + if metadata is None: + setattr(self, buffer_name, None) + continue + + shape = tuple(metadata["shape"]) + dtype = _torch_dtype(metadata["torch_dtype"]) + self.register_buffer(buffer_name, torch.empty(shape, dtype=dtype, device="meta")) + + @classmethod + def from_tensors( + cls, + *, + in_features: int, + out_features: int, + name: str, + tensors: Dict[str, torch.Tensor], + ) -> "ExllamaV3Linear": + return cls( + in_features=in_features, + out_features=out_features, + name=name, + tensors=tensors, + ) + + def _current_signature(self) -> tuple[Any, ...]: + trellis = getattr(self, "trellis", None) + if trellis is None or trellis.device.type == "meta": + return ("meta",) + + signature: list[Any] = [str(trellis.device)] + for buffer_name in _EXL3_BUFFER_NAMES: + tensor = getattr(self, buffer_name, None) + if tensor is None: + signature.append(None) + continue + signature.append((tensor.data_ptr(), tuple(tensor.shape), str(tensor.dtype))) + return tuple(signature) + + def _drop_inner(self) -> None: + if self._inner is not None: + try: + self._inner.unload() + except Exception: + # `_drop_inner` runs during teardown and `_apply`; cleanup must stay best-effort. + pass + self._inner = None + self._inner_signature = None + + def _ensure_inner(self) -> "LinearEXL3": + from ..exllamav3.modules.quant.exl3 import LinearEXL3 + + trellis = getattr(self, "trellis", None) + if trellis is None: + raise RuntimeError(f"EXL3 module `{self.name}` is missing `trellis`.") + if trellis.device.type == "meta": + raise RuntimeError(f"EXL3 module `{self.name}` has not been materialized from checkpoint tensors yet.") + if trellis.device.type != "cuda": + raise RuntimeError("EXL3 inference requires CUDA/HIP tensors.") + + signature = self._current_signature() + if self._inner is not None and signature == self._inner_signature: + return self._inner + + self._drop_inner() + self._inner = LinearEXL3( + config=None, + in_features=self.in_features, + out_features=self.out_features, + scale=None, + su=getattr(self, "su", None), + sv=getattr(self, "sv", None), + suh=getattr(self, "suh", None), + svh=getattr(self, "svh", None), + trellis=trellis, + mcg=getattr(self, "mcg", None), + mul1=getattr(self, "mul1", None), + bias=getattr(self, "bias", None), + out_dtype=self.out_dtype, + transformers_fix=True, + key=self.name, + ) + self._inner_signature = signature + return self._inner + + def post_init(self) -> None: + self._drop_inner() + if getattr(self, "trellis", None) is not None and self.trellis.device.type != "meta": + self._ensure_inner() + + def _apply(self, fn): + self._drop_inner() + return super()._apply(fn) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + inner = self._ensure_inner() + input_dtype = x.dtype + return inner.forward(x.half(), {}).to(input_dtype) + + def _multiplier_value(self, name: str) -> Optional[int]: + tensor = getattr(self, name, None) + if tensor is None: + return None + return int(tensor.view(torch.uint32).item()) + + def tensor_storage_entry(self) -> Dict[str, Any]: + stored_tensors: Dict[str, Dict[str, Any]] = {} + for buffer_name in _EXL3_BUFFER_NAMES: + tensor = getattr(self, buffer_name, None) + if tensor is None: + continue + stored_tensors[f"{self.name}.{buffer_name}"] = { + "shape": list(tensor.shape), + "torch_dtype": str(tensor.dtype).split(".")[-1], + } + + entry: Dict[str, Any] = { + "stored_tensors": stored_tensors, + "quant_format": "exl3", + } + trellis = getattr(self, "trellis", None) + if trellis is not None: + entry["bits_per_weight"] = int(trellis.shape[-1] // 16) + + mcg_multiplier = self._multiplier_value("mcg") + if mcg_multiplier is not None: + entry["mcg_multiplier"] = mcg_multiplier + + mul1_multiplier = self._multiplier_value("mul1") + if mul1_multiplier is not None: + entry["mul1_multiplier"] = mul1_multiplier + + return entry diff --git a/gptqmodel/nn_modules/exllamav3_torch.py b/gptqmodel/nn_modules/exllamav3_torch.py new file mode 100644 index 000000000..f22b2dcfa --- /dev/null +++ b/gptqmodel/nn_modules/exllamav3_torch.py @@ -0,0 +1,400 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# +# Clean-room EXL3 torch reference kernel derived from the public EXL3 tensor +# format and documented runtime layout. + +from __future__ import annotations + +import math +from functools import lru_cache +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from .exllamav3 import _EXL3_BUFFER_NAMES, _torch_dtype + + +_EXL3_3INST_MULT = 89226354 +_EXL3_3INST_ADD = 64248484 +_EXL3_MCG_MULT = 0xCBAC1FED +_EXL3_MUL1_MULT = 0x83DCD12D +_EXL3_MUL1_ACC = 0x6400 +_EXL3_LOP3_MASK = 0x8FFF8FFF +_EXL3_LOP3_BIAS = 0x3B603B60 + + +def _half_scalar_from_bits(bits: int) -> float: + return float(torch.tensor([bits], dtype=torch.uint16).view(torch.float16).item()) + + +_EXL3_MUL1_INV = _half_scalar_from_bits(0x1EEE) +_EXL3_MUL1_BIAS = _half_scalar_from_bits(0xC931) + + +@lru_cache(maxsize=None) +def _tensor_core_perm(device_type: str, device_index: int | None) -> torch.Tensor: + device = torch.device(device_type, device_index) + perm = [0] * 256 + for t in range(32): + r0 = (t % 4) * 2 + r1 = r0 + 1 + r2 = r0 + 8 + r3 = r0 + 9 + c0 = t // 4 + c1 = c0 + 8 + perm[t * 8 + 0] = r0 * 16 + c0 + perm[t * 8 + 1] = r1 * 16 + c0 + perm[t * 8 + 2] = r2 * 16 + c0 + perm[t * 8 + 3] = r3 * 16 + c0 + perm[t * 8 + 4] = r0 * 16 + c1 + perm[t * 8 + 5] = r1 * 16 + c1 + perm[t * 8 + 6] = r2 * 16 + c1 + perm[t * 8 + 7] = r3 * 16 + c1 + return torch.tensor(perm, dtype=torch.long, device=device) + + +@lru_cache(maxsize=None) +def _tensor_core_perm_i(device_type: str, device_index: int | None) -> torch.Tensor: + perm = _tensor_core_perm(device_type, device_index) + return torch.argsort(perm) + + +@lru_cache(maxsize=None) +def _hadamard_128(device_type: str, device_index: int | None) -> torch.Tensor: + device = torch.device(device_type, device_index) + had = torch.tensor([[1.0]], dtype=torch.float32, device=device) + while had.shape[0] < 128: + had = torch.cat( + ( + torch.cat((had, had), dim=1), + torch.cat((had, -had), dim=1), + ), + dim=0, + ) + had *= 1.0 / math.sqrt(128.0) + return had.contiguous() + + +@lru_cache(maxsize=None) +def _codebook_lut( + codebook: str, + device_type: str, + device_index: int | None, +) -> torch.Tensor: + device = torch.device(device_type, device_index) + values = torch.arange(1 << 16, dtype=torch.int64, device=device) + + if codebook == "3inst": + raw = (values * _EXL3_3INST_MULT + _EXL3_3INST_ADD) & 0xFFFFFFFF + raw = _EXL3_LOP3_BIAS ^ (raw & _EXL3_LOP3_MASK) + halves = torch.stack( + ( + (raw & 0xFFFF).to(torch.uint16), + ((raw >> 16) & 0xFFFF).to(torch.uint16), + ), + dim=-1, + ).contiguous() + floats = halves.view(torch.float16).to(torch.float32) + return (floats[..., 0] + floats[..., 1]).contiguous() + + if codebook == "mcg": + raw = (values * _EXL3_MCG_MULT) & 0xFFFFFFFF + raw = _EXL3_LOP3_BIAS ^ (raw & _EXL3_LOP3_MASK) + halves = torch.stack( + ( + (raw & 0xFFFF).to(torch.uint16), + ((raw >> 16) & 0xFFFF).to(torch.uint16), + ), + dim=-1, + ).contiguous() + floats = halves.view(torch.float16).to(torch.float32) + return (floats[..., 0] + floats[..., 1]).contiguous() + + if codebook == "mul1": + raw = (values * _EXL3_MUL1_MULT) & 0xFFFFFFFF + byte_sum = ( + (raw & 0xFF) + + ((raw >> 8) & 0xFF) + + ((raw >> 16) & 0xFF) + + ((raw >> 24) & 0xFF) + ) + accum = (byte_sum + _EXL3_MUL1_ACC).to(torch.uint16).contiguous() + floats = accum.view(torch.float16).to(torch.float32) + return (floats * _EXL3_MUL1_INV + _EXL3_MUL1_BIAS).contiguous() + + raise ValueError(f"Unsupported EXL3 codebook: {codebook}") + + +def _apply_hadamard_left(x: torch.Tensor) -> torch.Tensor: + if x.shape[0] % 128 != 0: + raise ValueError(f"EXL3 expects in_features to be divisible by 128, got {x.shape[0]}.") + had = _hadamard_128(x.device.type, x.device.index) + return (had @ x.view(-1, 128, x.shape[1])).view_as(x) + + +def _apply_hadamard_right(x: torch.Tensor) -> torch.Tensor: + if x.shape[1] % 128 != 0: + raise ValueError(f"EXL3 expects out_features to be divisible by 128, got {x.shape[1]}.") + had = _hadamard_128(x.device.type, x.device.index) + return (x.view(x.shape[0], -1, 128) @ had).view_as(x) + + +class ExllamaV3TorchLinear(nn.Module): + QUANT_TYPE = "exl3" + SUPPORTS_SHARDS = True + + def __init__( + self, + *, + in_features: int, + out_features: int, + name: str, + tensor_storage: Optional[Dict[str, Any]] = None, + tensors: Optional[Dict[str, torch.Tensor]] = None, + out_dtype: torch.dtype = torch.float16, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.name = name + self.out_dtype = out_dtype + self.tensor_storage = tensor_storage or {} + + self.weight = torch.zeros((1,), dtype=torch.float16, device="meta") + self._cache_signature: Optional[tuple[Any, ...]] = None + self._inner_weight_fp32: Optional[torch.Tensor] = None + self._weight_fp32: Optional[torch.Tensor] = None + + if tensors is not None: + for buffer_name in _EXL3_BUFFER_NAMES: + tensor = tensors.get(buffer_name) + if tensor is None: + setattr(self, buffer_name, None) + else: + self.register_buffer(buffer_name, tensor) + return + + stored_tensors = (self.tensor_storage or {}).get("stored_tensors", {}) + for buffer_name in _EXL3_BUFFER_NAMES: + metadata = stored_tensors.get(f"{name}.{buffer_name}") + if metadata is None: + setattr(self, buffer_name, None) + continue + + shape = tuple(metadata["shape"]) + dtype = _torch_dtype(metadata["torch_dtype"]) + self.register_buffer(buffer_name, torch.empty(shape, dtype=dtype, device="meta")) + + @classmethod + def from_tensors( + cls, + *, + in_features: int, + out_features: int, + name: str, + tensors: Dict[str, torch.Tensor], + ) -> "ExllamaV3TorchLinear": + return cls( + in_features=in_features, + out_features=out_features, + name=name, + tensors=tensors, + ) + + def _current_signature(self) -> tuple[Any, ...]: + trellis = getattr(self, "trellis", None) + if trellis is None or trellis.device.type == "meta": + return ("meta",) + + signature: list[Any] = [str(trellis.device)] + for buffer_name in _EXL3_BUFFER_NAMES: + tensor = getattr(self, buffer_name, None) + if tensor is None: + signature.append(None) + continue + signature.append((tensor.data_ptr(), tuple(tensor.shape), str(tensor.dtype))) + return tuple(signature) + + def _drop_cache(self) -> None: + self._cache_signature = None + self._inner_weight_fp32 = None + self._weight_fp32 = None + + def _apply(self, fn): + self._drop_cache() + return super()._apply(fn) + + def post_init(self) -> None: + self._drop_cache() + + def _codebook_name(self) -> str: + if getattr(self, "mcg", None) is not None: + return "mcg" + if getattr(self, "mul1", None) is not None: + return "mul1" + return "3inst" + + def _bits_per_weight(self) -> int: + trellis = getattr(self, "trellis", None) + if trellis is None: + raise RuntimeError(f"EXL3 module `{self.name}` is missing `trellis`.") + return int(trellis.shape[-1] // 16) + + def _runtime_weight_dtype(self) -> torch.dtype: + trellis = getattr(self, "trellis", None) + if trellis is None or trellis.device.type == "cpu": + return torch.float32 + return torch.float16 + + def _unpack_indices(self) -> torch.Tensor: + trellis = getattr(self, "trellis", None) + if trellis is None: + raise RuntimeError(f"EXL3 module `{self.name}` is missing `trellis`.") + if trellis.device.type == "meta": + raise RuntimeError(f"EXL3 module `{self.name}` has not been materialized from checkpoint tensors yet.") + + bits = self._bits_per_weight() + mask = (1 << bits) - 1 + words = (trellis.to(torch.int32) & 0xFFFF).contiguous() + words = words.view(*words.shape[:-1], -1, 2).flip(-1).reshape(*words.shape) + words = words.view(*words.shape[:-1], 16, bits) + + symbols = torch.empty( + (*words.shape[:-2], 256), + dtype=torch.long, + device=words.device, + ) + for pos in range(16): + bit_offset = pos * bits + word_idx = bit_offset // 16 + bit_in_word = bit_offset % 16 + if bit_in_word + bits <= 16: + shift = 16 - bit_in_word - bits + value = (words[..., word_idx] >> shift) & mask + else: + bits_first = 16 - bit_in_word + bits_second = bits - bits_first + high = (words[..., word_idx] & ((1 << bits_first) - 1)) << bits_second + low = words[..., word_idx + 1] >> (16 - bits_second) + value = (high | low) & mask + symbols[..., pos::16] = value.to(torch.long) + + warmup = (16 + bits - 1) // bits - 1 + state = torch.zeros_like(symbols[..., 0], dtype=torch.long) + for idx in range(256 - warmup, 256): + state = ((state << bits) | symbols[..., idx]) & 0xFFFF + + encoded = torch.empty_like(symbols) + for idx in range(256): + state = ((state << bits) | symbols[..., idx]) & 0xFFFF + encoded[..., idx] = state + + return encoded + + def _ensure_inner_weight_fp32(self) -> torch.Tensor: + trellis = getattr(self, "trellis", None) + if trellis is None: + raise RuntimeError(f"EXL3 module `{self.name}` is missing `trellis`.") + if trellis.device.type == "meta": + raise RuntimeError(f"EXL3 module `{self.name}` has not been materialized from checkpoint tensors yet.") + + signature = self._current_signature() + if self._inner_weight_fp32 is not None and self._cache_signature == signature: + return self._inner_weight_fp32 + + encoded = self._unpack_indices() + lut = _codebook_lut(self._codebook_name(), trellis.device.type, trellis.device.index) + decoded = lut[encoded] + + perm_i = _tensor_core_perm_i(trellis.device.type, trellis.device.index) + decoded = decoded[..., perm_i] + tiles_k, tiles_n = decoded.shape[:2] + inner = decoded.view(tiles_k, tiles_n, 16, 16).permute(0, 2, 1, 3).reshape( + tiles_k * 16, tiles_n * 16 + ) + + self._cache_signature = signature + self._inner_weight_fp32 = inner.contiguous().to(torch.float32) + self._weight_fp32 = None + return self._inner_weight_fp32 + + def get_inner_weight_tensor(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + inner = self._ensure_inner_weight_fp32() + target_dtype = dtype or self._runtime_weight_dtype() + if inner.dtype == target_dtype: + return inner + return inner.to(dtype=target_dtype) + + def _ensure_weight_fp32(self) -> torch.Tensor: + signature = self._current_signature() + if self._weight_fp32 is not None and self._cache_signature == signature: + return self._weight_fp32 + + inner = self._ensure_inner_weight_fp32().clone() + inner = _apply_hadamard_left(inner) + inner *= getattr(self, "suh").to(dtype=torch.float32).unsqueeze(1) + inner = _apply_hadamard_right(inner) + inner *= getattr(self, "svh").to(dtype=torch.float32).unsqueeze(0) + + self._weight_fp32 = inner.contiguous() + return self._weight_fp32 + + def get_weight_tensor(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + weight = self._ensure_weight_fp32() + target_dtype = dtype or self._runtime_weight_dtype() + if weight.dtype == target_dtype: + return weight + return weight.to(dtype=target_dtype) + + def get_bias_tensor(self) -> torch.Tensor | None: + return getattr(self, "bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + input_dtype = x.dtype + compute_dtype = torch.float32 if x.device.type == "cpu" else torch.float16 + x_2d = x.view(-1, self.in_features).to(compute_dtype) + weight = self.get_weight_tensor(dtype=compute_dtype) + y = x_2d @ weight + bias = getattr(self, "bias", None) + if bias is not None: + y = y + bias.to(dtype=compute_dtype) + y = y.view(*x.shape[:-1], self.out_features) + return y.to(input_dtype) + + def _multiplier_value(self, name: str) -> Optional[int]: + tensor = getattr(self, name, None) + if tensor is None: + return None + return int(tensor.view(torch.uint32).item()) + + def tensor_storage_entry(self) -> Dict[str, Any]: + stored_tensors: Dict[str, Dict[str, Any]] = {} + for buffer_name in _EXL3_BUFFER_NAMES: + tensor = getattr(self, buffer_name, None) + if tensor is None: + continue + stored_tensors[f"{self.name}.{buffer_name}"] = { + "shape": list(tensor.shape), + "torch_dtype": str(tensor.dtype).split(".")[-1], + } + + entry: Dict[str, Any] = { + "stored_tensors": stored_tensors, + "quant_format": "exl3", + "bits_per_weight": self._bits_per_weight(), + } + + mcg_multiplier = self._multiplier_value("mcg") + if mcg_multiplier is not None: + entry["mcg_multiplier"] = mcg_multiplier + + mul1_multiplier = self._multiplier_value("mul1") + if mul1_multiplier is not None: + entry["mul1_multiplier"] = mul1_multiplier + + return entry + + +__all__ = ["ExllamaV3TorchLinear"] diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py index a55a3c2ef..69223a89f 100644 --- a/gptqmodel/nn_modules/hooked_linear.py +++ b/gptqmodel/nn_modules/hooked_linear.py @@ -9,6 +9,7 @@ import transformers from torch import nn +from ..utils.device_telemetry import emit_device_telemetry from ..utils.logger import setup_logger @@ -220,6 +221,7 @@ def __init__(self, in_features: int, out_features: int) -> None: self.forward_hook = None self.forward_hook_last = False + self.module_name = None @staticmethod def from_linear(linear: torch.nn.Linear): @@ -232,6 +234,13 @@ def from_linear(linear: torch.nn.Linear): def forward(self, input: torch.Tensor) -> torch.Tensor: original_device = input.device target_device = self.weight.data.device + module_name = getattr(self, "module_name", None) or getattr(self, "full_name", None) or getattr(self, "name", None) or "unknown" + emit_device_telemetry( + "hooked_linear_forward", + module=module_name, + weight_device=target_device, + input_device=original_device, + ) if original_device != target_device: input = input.to(device=target_device) output = super().forward(input) @@ -250,7 +259,8 @@ def _replace_module(module, child, name, level: int = 0, debug: bool = False) -> if debug: log.info(f"{level_indent} Hook: {instance_type.__name__}: {name}") - if isinstance(child, torch.nn.Linear): + # Replace nn.Linear with HookedLinear, except PhimoeTopKRouter which returns a tuple in forward() + if isinstance(child, torch.nn.Linear) and child.__class__.__name__ != "PhimoeTopKRouter": setattr(module, name, HookedLinear.from_linear(child)) elif isinstance(child, transformers.Conv1D): setattr(module, name, HookedTransformerConv1D.from_conv1d(child)) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 6ae833c65..25876dcbe 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -8,7 +8,7 @@ import sys from concurrent.futures import ThreadPoolExecutor from functools import lru_cache -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple import numpy as np import torch as t # conflict with torch.py @@ -32,9 +32,6 @@ class BaseQuantLinear(nn.Module): SUPPORTS_METHODS: List[METHOD] = None SUPPORTS_FORMATS: Dict[FORMAT, int] = None SUPPORTS_BITS: List[int] = None - SUPPORTS_GROUP_SIZE: List[int] = None - SUPPORTS_DESC_ACT: List[bool] = None - SUPPORTS_SYM: List[bool] = None SUPPORTS_SHARDS: bool = None SUPPORTS_TRAINING: bool = None @@ -53,22 +50,21 @@ class BaseQuantLinear(nn.Module): SUPPORTS_DTYPES: List[t.dtype] = None REQUIRES_FORMAT_V2: bool = False + AUTOTUNE: bool = False def __init__(self, bits: int, - group_size: int, - desc_act: bool, - sym: bool, in_features: int, out_features: int, bias: bool, - pack_dtype: t.dtype, backend: BACKEND, adapter: Adapter, name: str = None, register_buffers: bool = False, register_buffers_in_features: int = None, register_buffers_out_features: int = None, + dtype: Optional[t.dtype] = None, + validate_kwargs: Optional[Dict[str, Any]] = None, **kwargs): super().__init__() if name is None: @@ -76,83 +72,37 @@ def __init__(self, self.name = name # full path module name in model weights self.in_features = in_features self.out_features = out_features - self.group_size = group_size if group_size != -1 else in_features self.bits = bits - self.desc_act = desc_act - self.sym = sym - self.pack_dtype = pack_dtype self.backend = backend - self.maxq = 2 ** self.bits - 1 - self.pack_dtype = pack_dtype # we need to clone the adapter since passed in adapter may be shared # adapter tensors are lodaed inside adapter so they must be unique per module self.adapter = copy.deepcopy(adapter) self.optimized = False - - if self.pack_dtype == t.int8: - self.pack_dtype_bits = 8 - self.pack_np_dtype = np.int8 # qweight saved dtype - self.pack_np_math_dtype = np.uint8 # pre-save math dtype - elif self.pack_dtype == t.int16: - self.pack_dtype_bits = 16 - self.pack_np_dtype = np.int16 - self.pack_np_math_dtype = np.uint16 - elif self.pack_dtype == t.int32: - self.pack_dtype_bits = 32 - self.pack_np_dtype = np.int32 - self.pack_np_math_dtype = np.uint32 - elif self.pack_dtype == t.int64: - self.pack_dtype_bits = 64 - self.pack_np_dtype = np.int64 - self.pack_np_math_dtype = np.uint64 - else: - raise ValueError("Unsupported weight_dtype. Only int16 and int32 are supported.") - - # pack_factor is only used for bits 2, 4, and 8. bit3 3 does not use this variable. - self.pack_factor = self.pack_dtype_bits // self.bits - _, err = self.validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, in_features=in_features, out_features=out_features, pack_dtype=pack_dtype) + self.autotune_enabled = self.AUTOTUNE + self._autotune_complete = False + self._autotune_result: Any = None + + validate_args = { + "bits": bits, + "in_features": in_features, + "out_features": out_features, + "dtype": dtype, + "adapter": adapter, + } + if validate_kwargs: + validate_args.update(validate_kwargs) + + _, err = self.validate( + **validate_args, + ) if err: raise err - # store qzero format - self._qzeros_format = 1 # only valid values are 1 and 2 for GPTQ v1 GPTQ v2 - - # most kernels share same buffers so they can share same register buffer code - if register_buffers: - # some kernels auto-pads in/out features - in_features = self.in_features if not register_buffers_in_features else register_buffers_in_features - out_features = self.out_features if not register_buffers_out_features else register_buffers_out_features - - self.register_buffer( - "qweight", - t.zeros((in_features // self.pack_dtype_bits * self.bits, out_features), dtype=self.pack_dtype), - ) - self.register_buffer( - "qzeros", - t.zeros( - ( - math.ceil(in_features / self.group_size), - out_features // self.pack_dtype_bits * self.bits, - ), - dtype=self.pack_dtype, - ), - ) - self.register_buffer( - "scales", - t.zeros( - (math.ceil(in_features / self.group_size), out_features), - dtype=t.float16, - ), - ) - self.register_buffer( - "g_idx", - t.tensor([i // self.group_size for i in range(in_features)], dtype=t.int32), - ) - if bias: - self.register_buffer("bias", t.zeros(out_features, dtype=t.float16)) - else: - self.bias = None + # The root base only owns fields shared across all quantization methods. + if register_buffers and bias: + bias_shape = self.out_features if register_buffers_out_features is None else register_buffers_out_features + self.register_buffer("bias", t.zeros(bias_shape, dtype=t.float16)) # load adapter if any if adapter is not None: @@ -173,7 +123,7 @@ def __init__(self, pass # print(f"Adapter lazy init: {self.adapter.name()}: {self.adapter}, module: {self.name}") - # TDOO: allow merged lora weights exist in gptq model safetensor file for direct loading + # TODO: allow merged lora weights exist in gptq model safetensor file for direct loading # EoRA need to preallocate buffers for Lora_A and B weights so HF can load # self.register_buffer( # "lora_A", @@ -187,41 +137,66 @@ def __init__(self, # ) def list_buffers(self) -> List: - buf = [] - if hasattr(self, "qweight") and self.qweight is not None: - buf.append(self.qweight) - if hasattr(self, "qzeros") and self.qzeros is not None: - buf.append(self.qzeros) - if hasattr(self, "scales") and self.scales is not None: - buf.append(self.scales) - if hasattr(self, "g_idx") and self.g_idx is not None: - buf.append(self.g_idx) - if hasattr(self, "bias") and self.bias is not None: - buf.append(self.bias) - - return buf - - def qzero_format(self, format: int = None) -> int: - # get - if format is None: - return self._qzeros_format - - # set - if format not in [1, 2]: - raise ValueError("Unsupported qzero format. Only 1 and 2 are supported.") - - self._qzeros_format = format - return self._qzeros_format + tensors = [] + seen = set() + for state in (self._parameters, self._buffers): + for tensor in state.values(): + if tensor is None or not isinstance(tensor, t.Tensor): + continue + tensor_id = id(tensor) + if tensor_id in seen: + continue + seen.add(tensor_id) + tensors.append(tensor) + return tensors + + def runtime_device(self) -> Optional[t.device]: + # Prefer the real quantized storage tensor over adapter scratch buffers. + for name in ("qweight", "weight", "B", "qzeros", "scales", "g_idx", "bias"): + tensor = getattr(self, name, None) + if isinstance(tensor, t.Tensor): + return tensor.device + + buffers = self.list_buffers() + return buffers[0].device if buffers else None + + def smooth_block_size(self) -> int: + return -1 # override me, to perform post-weight load to device init def post_init(self): + self.clear_autotune() if self.adapter is not None: + device = self.runtime_device() + if device is None: + raise RuntimeError(f"{self.__class__.__name__} cannot initialize adapters without any tensors.") self.adapter.post_init( weight_key=self.name, - device=self.list_buffers()[0].device, + device=device, lora_A=getattr(self, "lora_A", None), lora_B=getattr(self, "lora_B", None)) + def clear_autotune(self): + self._autotune_complete = False + self._autotune_result = None + + def get_autotune_result(self): + return self._autotune_result + + def _autotune(self, *args, **kwargs): + raise NotImplementedError(f"{self.__class__.__name__} does not implement `_autotune()`.") + + def maybe_autotune(self, *args, **kwargs): + if not self.autotune_enabled or self.training: + return self._autotune_result + + if self._autotune_complete: + return self._autotune_result + + self._autotune_result = self._autotune(*args, **kwargs) + self._autotune_complete = True + return self._autotune_result + @classmethod @lru_cache(maxsize=1024) def cached_validate_once(cls) -> Tuple[bool, Optional[Exception]]: @@ -242,9 +217,9 @@ def validate_once(cls) -> Tuple[bool, Optional[Exception]]: def validate( cls, bits: int, - group_size: int, - desc_act: bool, - sym: bool, + group_size: int = -1, + desc_act: bool = False, + sym: bool = True, in_features:int=None, out_features:int=None, pack_dtype:t.dtype=None, @@ -300,8 +275,15 @@ def verify_supports_params(cls): # raise ValueError(f"{cls.__name__}.{name} cannot be an empty list.") @classmethod - def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dtype: Optional[t.dtype]=None, dynamic:Optional[dict]=None, in_features:int=None, - out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, adapter:Optional[Adapter]=None) -> Tuple[bool, Optional[Exception]]: + def _validate_shared( + cls, + *, + pack_dtype:t.dtype=None, + dtype: Optional[t.dtype]=None, + device:Optional[DEVICE]=None, + trainable:Optional[bool]=None, + adapter:Optional[Adapter]=None, + ) -> Tuple[bool, Optional[Exception]]: cls.verify_supports_params() if adapter is not None and adapter.__class__ not in cls.SUPPORTS_ADAPTERS: @@ -330,70 +312,49 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: if trainable and not cls.SUPPORTS_TRAINING: err = f"{cls} does not support training." return False, NotImplementedError(err) + return True, None - if bits not in cls.SUPPORTS_BITS: - err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual bits = `{bits}`" - return False, NotImplementedError(err) - # valid group size is set of cls.SUPPORTS_GROUP_SIZE + in_features; group_size = -1 is alias for group_size == in_features - if group_size not in cls.SUPPORTS_GROUP_SIZE and group_size != in_features: - err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}`" - return False, NotImplementedError(err) - - # validate symmetric/asym quantization support - if sym not in cls.SUPPORTS_SYM: - err = f"{cls} only supports symmetric `{cls.SUPPORTS_SYM}` quantization: actual sym = `{sym}`" - return False, NotImplementedError(err) - - if desc_act not in cls.SUPPORTS_DESC_ACT: - err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{desc_act}`" - return False, NotImplementedError(err) - if dynamic is not None: - dynamic_bits = {} - for pattern, pattern_dict in dynamic.items(): - dynamic_bits[pattern] = pattern_dict.get("bits", bits) - if len(cls.SUPPORTS_BITS) == 1: + @classmethod + def _validate_dynamic_bits( + cls, + *, + bits: int, + dynamic: Optional[dict], + ) -> Tuple[bool, Optional[Exception]]: + if dynamic is None: + return True, None + + dynamic_bits = {} + for pattern, pattern_dict in dynamic.items(): + dynamic_bits[pattern] = pattern_dict.get("bits", bits) + if len(cls.SUPPORTS_BITS) == 1: + unsupported_dynamic_bits = { + layer: dynamic_bits_value + for layer, dynamic_bits_value in dynamic_bits.items() + if dynamic_bits_value != bits + } + if unsupported_dynamic_bits: err = f"{cls} not supported dynamic_bits, only support `{cls.SUPPORTS_BITS}` bits" return False, NotImplementedError(err) - else: - for layer, bits in dynamic_bits.items(): - if bits not in cls.SUPPORTS_BITS: - err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual dynamic_bits = `{bits}` for layer `{layer}`" - return False, NotImplementedError(err) - - dynamic_group_size = {} - for pattern, pattern_dict in dynamic.items(): - dynamic_group_size[pattern] = pattern_dict.get("group_size", group_size) - for layer, group_size in dynamic_group_size.items(): - if group_size not in cls.SUPPORTS_GROUP_SIZE: - err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}` for layer `{layer}`" - return False, NotImplementedError(err) - - dynamic_sym = {} - for pattern, pattern_dict in dynamic.items(): - dynamic_sym[pattern] = pattern_dict.get("sym", sym) - for layer, sym in dynamic_sym.items(): - if sym not in cls.SUPPORTS_SYM: - err = f"{cls} only supports `{cls.SUPPORTS_SYM}` bits: actual sym = `{sym}` for layer `{layer}`" - return False, NotImplementedError(err) - - dynamic_desc_act = {} - for pattern, pattern_dict in dynamic.items(): - dynamic_desc_act[pattern] = pattern_dict.get("desc_act", desc_act) - for layer, desc_act in dynamic_desc_act.items(): - if desc_act not in cls.SUPPORTS_DESC_ACT: - err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{desc_act}` for layer `{layer}`" + else: + for layer, dynamic_bits_value in dynamic_bits.items(): + if dynamic_bits_value not in cls.SUPPORTS_BITS: + err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual dynamic_bits = `{dynamic_bits_value}` for layer `{layer}`" return False, NotImplementedError(err) + return True, None + @classmethod + def _validate_shape_constraints( + cls, + *, + in_features:int=None, + out_features:int=None, + ) -> Tuple[bool, Optional[Exception]]: if in_features is not None: validate = all(in_features % in_fea == 0 for in_fea in cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY) if not validate: err = f"{cls}: `in_features`: {in_features} must be divisible by {cls.SUPPORTS_IN_FEATURES_DIVISIBLE_BY}." return False, NotImplementedError(err) - - validate = in_features % group_size == 0 or cls.SUPPORTS_AUTO_PADDING - if not validate: - err = f"{cls}: `in_features`: {in_features} must be divisible by `group_size: {group_size}`." - return False, NotImplementedError(err) if out_features is not None: validate = all(out_features % out_fea == 0 for out_fea in cls.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY) if not validate: @@ -401,6 +362,32 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: return False, NotImplementedError(err) return True, None + @classmethod + def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dtype: Optional[t.dtype]=None, dynamic:Optional[dict]=None, in_features:int=None, + out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, adapter:Optional[Adapter]=None) -> Tuple[bool, Optional[Exception]]: + ok, err = cls._validate_shared( + pack_dtype=pack_dtype, + dtype=dtype, + device=device, + trainable=trainable, + adapter=adapter, + ) + if not ok: + return ok, err + + if bits not in cls.SUPPORTS_BITS: + err = f"{cls} only supports `{cls.SUPPORTS_BITS}` bits: actual bits = `{bits}`" + return False, NotImplementedError(err) + + ok, err = cls._validate_dynamic_bits(bits=bits, dynamic=dynamic) + if not ok: + return ok, err + + return cls._validate_shape_constraints( + in_features=in_features, + out_features=out_features, + ) + @classmethod def validate_device(cls, device: DEVICE): assert isinstance(device, DEVICE) @@ -435,9 +422,316 @@ def train(self, mode=True): pass # log.info(f"{self.__class__.__name__}: `{self.name}` switching to eval mode.") + self.clear_autotune() return super().train(mode) -class PackableQuantLinear(BaseQuantLinear): + +class GroupedQuantLinear(BaseQuantLinear): + SUPPORTS_GROUP_SIZE: List[int] = None + SUPPORTS_DESC_ACT: List[bool] = None + SUPPORTS_SYM: List[bool] = None + + @classmethod + def verify_supports_params(cls): + super().verify_supports_params() + + grouped_supports_variables = [ + (name, value) for name, value in GroupedQuantLinear.__dict__.items() + if name.startswith("SUPPORTS") and not callable(value) and value is None + ] + child_supports_variables = [ + (name, value) for name, value in cls.__dict__.items() + if name.startswith("SUPPORTS") and not callable(value) + ] + + grouped_variable_names = {name for name, _ in grouped_supports_variables} + child_variable_names = {name for name, _ in child_supports_variables} + missing_variables = grouped_variable_names - child_variable_names + + if missing_variables: + raise ValueError( + f"{cls.__name__} these grouped SUPPORTS variables are not overridden: " + f"{', '.join(sorted(missing_variables))}" + ) + + for name, value in child_supports_variables: + if name in grouped_variable_names and value is None: + raise ValueError(f"{cls.__name__}.{name} cannot be None.") + + def __init__(self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool, + pack_dtype: t.dtype, + backend: BACKEND, + adapter: Adapter, + name: str = None, + register_buffers: bool = False, + register_buffers_in_features: int = None, + register_buffers_out_features: int = None, + dtype: Optional[t.dtype] = None, + **kwargs): + super().__init__( + bits=bits, + in_features=in_features, + out_features=out_features, + bias=bias, + backend=backend, + adapter=adapter, + name=name, + register_buffers=False, + register_buffers_in_features=register_buffers_in_features, + register_buffers_out_features=register_buffers_out_features, + dtype=dtype, + validate_kwargs={ + "group_size": group_size, + "desc_act": desc_act, + "sym": sym, + "pack_dtype": pack_dtype, + }, + **kwargs, + ) + + self.group_size = group_size if group_size != -1 else in_features + self.requested_group_size = group_size + self.desc_act = desc_act + self.sym = sym + + def smooth_block_size(self) -> int: + return -1 if self.requested_group_size == -1 else self.group_size + + @classmethod + def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dtype: Optional[t.dtype]=None, dynamic:Optional[dict]=None, in_features:int=None, + out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, adapter:Optional[Adapter]=None) -> Tuple[bool, Optional[Exception]]: + ok, err = super()._validate( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + pack_dtype=pack_dtype, + dtype=dtype, + dynamic=dynamic, + in_features=in_features, + out_features=out_features, + device=device, + trainable=trainable, + adapter=adapter, + ) + if not ok: + return ok, err + + if group_size not in cls.SUPPORTS_GROUP_SIZE and group_size != in_features: + err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{group_size}`" + return False, NotImplementedError(err) + + if sym not in cls.SUPPORTS_SYM: + err = f"{cls} only supports symmetric `{cls.SUPPORTS_SYM}` quantization: actual sym = `{sym}`" + return False, NotImplementedError(err) + + if desc_act not in cls.SUPPORTS_DESC_ACT: + err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{desc_act}`" + return False, NotImplementedError(err) + + if dynamic is not None: + dynamic_group_size = {} + for pattern, pattern_dict in dynamic.items(): + dynamic_group_size[pattern] = pattern_dict.get("group_size", group_size) + for layer, layer_group_size in dynamic_group_size.items(): + if layer_group_size not in cls.SUPPORTS_GROUP_SIZE: + err = f"{cls} only supports `{cls.SUPPORTS_GROUP_SIZE}` group_size: actual group_size = `{layer_group_size}` for layer `{layer}`" + return False, NotImplementedError(err) + + dynamic_sym = {} + for pattern, pattern_dict in dynamic.items(): + dynamic_sym[pattern] = pattern_dict.get("sym", sym) + for layer, layer_sym in dynamic_sym.items(): + if layer_sym not in cls.SUPPORTS_SYM: + err = f"{cls} only supports `{cls.SUPPORTS_SYM}` bits: actual sym = `{layer_sym}` for layer `{layer}`" + return False, NotImplementedError(err) + + dynamic_desc_act = {} + for pattern, pattern_dict in dynamic.items(): + dynamic_desc_act[pattern] = pattern_dict.get("desc_act", desc_act) + for layer, layer_desc_act in dynamic_desc_act.items(): + if layer_desc_act not in cls.SUPPORTS_DESC_ACT: + err = f"{cls} only supports `{cls.SUPPORTS_DESC_ACT}` bits: actual desc_act = `{layer_desc_act}` for layer `{layer}`" + return False, NotImplementedError(err) + + if in_features is not None: + validate = in_features % group_size == 0 or cls.SUPPORTS_AUTO_PADDING + if not validate: + err = f"{cls}: `in_features`: {in_features} must be divisible by `group_size: {group_size}`." + return False, NotImplementedError(err) + + return True, None + + +class PackedGroupedQuantLinear(GroupedQuantLinear): + def __init__(self, *args, pack_dtype: t.dtype, **kwargs): + super().__init__(*args, pack_dtype=pack_dtype, **kwargs) + + # Packed GPTQ/AWQ layouts need explicit storage-word metadata. + self.pack_dtype = pack_dtype + self.maxq = 2 ** self.bits - 1 + + if self.pack_dtype == t.int8: + self.pack_dtype_bits = 8 + self.pack_np_dtype = np.int8 + self.pack_np_math_dtype = np.uint8 + elif self.pack_dtype == t.int16: + self.pack_dtype_bits = 16 + self.pack_np_dtype = np.int16 + self.pack_np_math_dtype = np.uint16 + elif self.pack_dtype == t.int32: + self.pack_dtype_bits = 32 + self.pack_np_dtype = np.int32 + self.pack_np_math_dtype = np.uint32 + elif self.pack_dtype == t.int64: + self.pack_dtype_bits = 64 + self.pack_np_dtype = np.int64 + self.pack_np_math_dtype = np.uint64 + else: + raise ValueError(f"Unsupported pack_dtype: {self.pack_dtype}") + + # pack_factor is only meaningful for packed low-bit code storage. + self.pack_factor = self.pack_dtype_bits // self.bits + + +class WeightOnlyQuantLinear(BaseQuantLinear): + def __init__(self, + bits: int, + in_features: int, + out_features: int, + bias: bool, + backend: BACKEND, + adapter: Adapter, + name: str = None, + register_buffers: bool = False, + register_buffers_in_features: int = None, + register_buffers_out_features: int = None, + dtype: Optional[t.dtype] = None, + pack_dtype: Optional[t.dtype] = None, + **kwargs): + super().__init__( + bits=bits, + in_features=in_features, + out_features=out_features, + bias=bias, + backend=backend, + adapter=adapter, + name=name, + register_buffers=register_buffers, + register_buffers_in_features=register_buffers_in_features, + register_buffers_out_features=register_buffers_out_features, + dtype=dtype, + validate_kwargs={"pack_dtype": pack_dtype}, + **kwargs, + ) + + +class GPTQQuantLinear(PackedGroupedQuantLinear): + def __init__(self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool, + pack_dtype: t.dtype, + backend: BACKEND, + adapter: Adapter, + name: str = None, + register_buffers: bool = False, + register_buffers_in_features: int = None, + register_buffers_out_features: int = None, + dtype: Optional[t.dtype] = None, + **kwargs): + super().__init__( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=backend, + adapter=adapter, + name=name, + register_buffers=False, + register_buffers_in_features=register_buffers_in_features, + register_buffers_out_features=register_buffers_out_features, + dtype=dtype, + **kwargs, + ) + + # GPTQ v1/v2 conversions only apply to GPTQ-style qzero storage. + self._qzeros_format = 1 + + if register_buffers: + self._register_gptq_buffers( + bias=bias, + register_buffers_in_features=register_buffers_in_features, + register_buffers_out_features=register_buffers_out_features, + ) + + def _register_gptq_buffers( + self, + *, + bias: bool, + register_buffers_in_features: int = None, + register_buffers_out_features: int = None, + ) -> None: + in_features = self.in_features if register_buffers_in_features is None else register_buffers_in_features + out_features = self.out_features if register_buffers_out_features is None else register_buffers_out_features + + self.register_buffer( + "qweight", + t.zeros((in_features // self.pack_dtype_bits * self.bits, out_features), dtype=self.pack_dtype), + ) + self.register_buffer( + "qzeros", + t.zeros( + ( + math.ceil(in_features / self.group_size), + out_features // self.pack_dtype_bits * self.bits, + ), + dtype=self.pack_dtype, + ), + ) + self.register_buffer( + "scales", + t.zeros( + (math.ceil(in_features / self.group_size), out_features), + dtype=t.float16, + ), + ) + self.register_buffer( + "g_idx", + t.tensor([i // self.group_size for i in range(in_features)], dtype=t.int32), + ) + if bias: + self.register_buffer("bias", t.zeros(out_features, dtype=t.float16)) + else: + self.bias = None + + def qzero_format(self, format: int = None) -> int: + if format is None: + return self._qzeros_format + + if format not in [1, 2]: + raise ValueError("Unsupported qzero format. Only 1 and 2 are supported.") + + self._qzeros_format = format + return self._qzeros_format + + +class PackableQuantLinear(GPTQQuantLinear): def __init__(self, *args, enable_wf_unsqueeze: bool = False, **kwargs): self.enable_wf_unsqueeze = enable_wf_unsqueeze super().__init__(*args, **kwargs) @@ -1102,7 +1396,7 @@ def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_ self.register_buffer("qzeros", t.from_numpy(qzeros.astype(self.pack_np_dtype))) # assert - # assert isinstance(self, TorchQuantLinear), f"type: {self.__class_}" + # assert isinstance(self, TorchLinear), f"type: {self.__class_}" # wq = linear.weight.data # wq_dequantized = self.dequantize_weight().T # print(f"------ WQ -----") @@ -1113,7 +1407,7 @@ def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_ # print("self qw", self.qweight, self.scales, self.qzeros) -class AWQuantLinear(BaseQuantLinear): +class AWQuantLinear(PackedGroupedQuantLinear): def __init__(self, bias: bool = False, register_buffers: bool = False, @@ -1151,13 +1445,5 @@ def __init__(self, # TODO FIX ME. this hack was needed because other part of code forgot to call nn.module register_buffer()! def list_buffers(self) -> List: - buf = [] - if hasattr(self, "qweight") and self.qweight is not None: - buf.append(self.qweight) - if hasattr(self, "qzeros") and self.qzeros is not None: - buf.append(self.qzeros) - if hasattr(self, "scales") and self.scales is not None: - buf.append(self.scales) - if hasattr(self, "bias") and self.bias is not None: - buf.append(self.bias) + buf = super().list_buffers() return buf diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index 61cef8708..9600b3a1f 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -11,12 +11,13 @@ from pathlib import Path from typing import List, Optional, Tuple, Union +import pcre import torch from packaging import version from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import BaseQuantLinear +from ...nn_modules.qlinear import BaseQuantLinear, GroupedQuantLinear from ...quantization import FORMAT, METHOD from ...utils import BACKEND from ...utils.env import env_flag @@ -30,8 +31,15 @@ BITBLAS_SUPPORTED_GROUP_SIZES: List[int] = [-1, 32, 64, 128] BITBLAS_SUPPORTED_BITS: List[int] = [1, 2, 4, 8] BITBLAS_SUPPORTED_SYM: List[bool] = [False, True] +# Keep bf16 exposed overall: upstream BitBLAS can successfully compile bf16 for some dtype/shape +# combinations (for example unsigned low-bit paths used by AWQ). The specific incompatibility we +# reproduced is the signed low-bit GPTQ dequant path, which can emit CUDA that tries to construct +# `cutlass::bfloat16_t(int)` and fails during BitBLAS runtime compilation. +BITBLAS_BF16_UNSUPPORTED_SIGNED_BITS = frozenset({2, 4, 8}) BITBLAS_DEFAULT_ZEROS_MODE = "quantized" BITBLAS_PROPAGATE_WEIGHTS = False +BITBLAS_MAX_SUPPORTED_SM = 90 +BITBLAS_FALLBACK_TARGET = f"cuda -arch=sm_{BITBLAS_MAX_SUPPORTED_SM}" BITBLAS_TARGET = None BITBLAS_DATABASE_PATH = None @@ -130,6 +138,8 @@ def _is_bitblas_available() -> bool: BITBLAS_AVAILABLE = _is_bitblas_available() +_BITBLAS_TARGET_ARCH_RE = pcre.compile(r"\bsm_(\d+)[a-z]*\b") +_BITBLAS_TARGET_SM_RE = pcre.compile(r"sm_(\d+)") BITBLAS_INSTALL_HINT = ( @@ -155,7 +165,7 @@ def import_bitblas(): bitblas.auto_detect_nvidia_target = patched_auto_detect_nvidia_target visible = int(os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]) - BITBLAS_TARGET = patched_auto_detect_nvidia_target(visible) + BITBLAS_TARGET = _normalize_bitblas_target(patched_auto_detect_nvidia_target(visible)) os.environ["TVM_TARGET"] = f"{BITBLAS_TARGET}" log.debug("BITBLAS_TARGET %s", BITBLAS_TARGET) @@ -166,6 +176,84 @@ def import_bitblas(): log.debug("BITBLAS_DATABASE_PATH %s", BITBLAS_DATABASE_PATH) +def _bitblas_target_arch(target) -> Optional[str]: + if target is None: + return None + + target_text = str(target) + + try: + from bitblas import tvm + + return str(tvm.target.Target(target_text).arch) + except Exception: + match = _BITBLAS_TARGET_ARCH_RE.search(target_text) + if match: + return f"sm_{match.group(1)}" + return None + + +def _bitblas_target_sm(target) -> Optional[int]: + arch = _bitblas_target_arch(target) + if arch is None: + return None + + match = _BITBLAS_TARGET_SM_RE.search(arch) + if match: + return int(match.group(1)) + return None + + +def _current_cuda_sm() -> Optional[int]: + if not torch.cuda.is_available(): + return None + + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + return props.major * 10 + props.minor + + +def _bitblas_fallback_target() -> str: + current_sm = _current_cuda_sm() + if current_sm is None: + return BITBLAS_FALLBACK_TARGET + + fallback_sm = min(current_sm, BITBLAS_MAX_SUPPORTED_SM) + return f"cuda -arch=sm_{fallback_sm}" + + +def _normalize_bitblas_target(target): + if target is None: + return None + + arch = _bitblas_target_arch(target) + sm_version = _bitblas_target_sm(target) + if sm_version is None: + return target + + canonical_target = f"cuda -arch=sm_{sm_version}" + + if sm_version > BITBLAS_MAX_SUPPORTED_SM: + fallback_target = _bitblas_fallback_target() + log.warning( + "BitBLAS target %s resolves to unsupported CUDA arch %s; falling back to %s.", + target, + arch, + fallback_target, + ) + return fallback_target + + if arch != f"sm_{sm_version}": + log.info( + "Canonicalizing BitBLAS target %s (%s) to %s.", + target, + arch, + canonical_target, + ) + return canonical_target + + return target + + def unpack_gptq_qzeros(qzeros: torch.Tensor, bits: int, is_gptq_v2: bool = False) -> torch.Tensor: qzeros = qzeros.view(torch.int32) elems_per_int32 = 32 // bits @@ -202,6 +290,36 @@ def unpack_gptq_qweight(qweight: torch.Tensor, bits: int) -> torch.Tensor: return torch.bitwise_and(unpacked_weight, 2**bits - 1) +def remap_gptq_symmetric_codes_to_bitblas(qweight_codes: torch.Tensor, bits: int) -> torch.Tensor: + # Some in-memory TorchLinear symmetric pack paths still encode qweight as GPTQ-style + # two's-complement nibbles while leaving qzeros packed as all zeros. BitBLAS' signed intN path + # expects the corresponding biased code range instead, so flip the sign bit for that narrow + # producer case before loading the quant state. + sign_bit = 1 << (bits - 1) + remapped = torch.bitwise_xor(qweight_codes.to(torch.int16), sign_bit) + return remapped.to(torch.int8).contiguous() + + +def _should_remap_symmetric_gptq_codes(gptq_module: BaseQuantLinear) -> bool: + qzeros = getattr(gptq_module, "qzeros", None) + if qzeros is None or qzeros.numel() == 0: + return False + + qzero_format = getattr(gptq_module, "qzero_format", None) + if callable(qzero_format): + format_id = qzero_format() + if format_id == 2: + return False + if format_id != 1: + return False + + # The remap is only needed for the pre-v2 TorchLinear symmetric pack path. That producer + # keeps qzeros packed as all zeros while storing qweight in GPTQ's two's-complement nibble + # layout. GPT-QModel converts external checkpoints to qzero_format=2 before BitBLAS repacking, + # and those tensors must be left untouched. + return qzeros.count_nonzero().item() == 0 + + def _num_groups(group_size: int, in_features: int) -> int: if group_size in (-1, in_features): return 1 @@ -214,6 +332,7 @@ class BitblasQuantizationConfig: group_size: int desc_act: bool is_sym: bool + torch_dtype: torch.dtype = torch.float16 zeros_mode: str = BITBLAS_DEFAULT_ZEROS_MODE storage_dtype: str = "int8" quant_method: str = "gptq" @@ -233,41 +352,49 @@ def __post_init__(self) -> None: ) if 32 % self.weight_bits != 0: raise ValueError("weight_bits must divide 32 for GPTQ packing") + if self.torch_dtype not in (torch.float16, torch.bfloat16): + raise ValueError("BitBLAS only supports torch.float16 and torch.bfloat16 compute dtypes") self.pack_factor = 32 // self.weight_bits self.torch_storage_dtype = getattr(torch, self.storage_dtype) - self.torch_dtype = torch.float16 @property def with_zeros(self) -> bool: return not self.is_sym and self.zeros_mode == "quantized" -class BitblasQuantLinear(BaseQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.BITBLAS] - SUPPORTS_FORMATS = {FORMAT.BITBLAS: 30, FORMAT.GPTQ: 30} - SUPPORTS_BITS = BITBLAS_SUPPORTED_BITS - SUPPORTS_GROUP_SIZE = BITBLAS_SUPPORTED_GROUP_SIZES - SUPPORTS_DESC_ACT = [False, True] - SUPPORTS_SYM = BITBLAS_SUPPORTED_SYM - SUPPORTS_SHARDS = True - SUPPORTS_TRAINING = False - SUPPORTS_AUTO_PADDING = False - SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [16] - SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [16] - - SUPPORTS_DEVICES = [DEVICE.CUDA] - SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] - SUPPORTS_PACK_DTYPES = [torch.int32] - SUPPORTS_ADAPTERS = [Lora] - - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] - - QUANT_TYPE = "gptq_bitblas" - SUPPORTS_METHODS = [METHOD.GPTQ] - +class BitblasBaseQuantLinear(GroupedQuantLinear): OPT_FEATURES = BITBLAS_OPTIMIZE_FEATURES TORCH_DTYPE = torch.float16 + def _build_quant_config( + self, + *, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + dtype: torch.dtype, + ) -> BitblasQuantizationConfig: + return BitblasQuantizationConfig( + weight_bits=bits, + group_size=group_size, + desc_act=desc_act, + is_sym=sym, + torch_dtype=dtype, + ) + + @classmethod + def _validate_kernel_combo( + cls, + *, + bits: int, + sym: bool, + dtype: Optional[torch.dtype], + dynamic: Optional[dict] = None, + ) -> Tuple[bool, Optional[Exception]]: + del bits, sym, dtype, dynamic + return True, None + def __init__( self, bits: int, @@ -278,6 +405,7 @@ def __init__( out_features: int, bias: bool = False, pack_dtype: torch.dtype = torch.int32, + dtype: torch.dtype = torch.float16, adapter: Adapter = None, enable_tuning: bool = False, fast_decoding: bool = True, # kept for API compatibility @@ -287,6 +415,13 @@ def __init__( register_buffers: bool = False, **kwargs, ) -> None: + if dtype not in self.SUPPORTS_DTYPES: + raise ValueError(f"{self.__class__.__name__} only supports dtypes {self.SUPPORTS_DTYPES}: actual dtype = {dtype}") + + ok, err = self.__class__._validate_kernel_combo(bits=bits, sym=sym, dtype=dtype) + if not ok: + raise err + super().__init__( bits=bits, group_size=group_size, @@ -296,7 +431,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.BITBLAS), + backend=kwargs.pop("backend", BACKEND.GPTQ_BITBLAS), adapter=adapter, register_buffers=False, **kwargs, @@ -307,11 +442,13 @@ def __init__( if not BITBLAS_AVAILABLE: raise ImportError(BITBLAS_INSTALL_HINT) - self.quant_config = BitblasQuantizationConfig( - weight_bits=bits, + self.TORCH_DTYPE = dtype + self.quant_config = self._build_quant_config( + bits=bits, group_size=group_size, desc_act=desc_act, - is_sym=sym, + sym=sym, + dtype=dtype, ) self.enable_tuning = enable_tuning self.layout = layout @@ -343,11 +480,56 @@ def validate_once(cls) -> Tuple[bool, Optional[Exception]]: return False, exc return True, None + @classmethod + def validate( + cls, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int = None, + out_features: int = None, + pack_dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, + dynamic: Optional[dict] = None, + device: Optional[DEVICE] = None, + trainable: Optional[bool] = None, + adapter: Optional[Adapter] = None, + ) -> Tuple[bool, Optional[Exception]]: + ok, err = cls._validate_kernel_combo(bits=bits, sym=sym, dtype=dtype, dynamic=dynamic) + if not ok: + return False, err + + return super().validate( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + in_features=in_features, + out_features=out_features, + pack_dtype=pack_dtype, + dtype=dtype, + dynamic=dynamic, + device=device, + trainable=trainable, + adapter=adapter, + ) + def _validate_parameters(self, in_features: int, out_features: int) -> None: - if in_features % 16 != 0: - raise ValueError("`in_features` must be divisible by 16 for BitBLAS") - if out_features % 16 != 0: - raise ValueError("`out_features` must be divisible by 16 for BitBLAS") + # This wrapper keeps a conservative 16-wide gate because the current GPTQ/AWQ BitBLAS + # integration is built around TensorCore-oriented 16x16 micro-kernel paths. BitBLAS itself + # is looser in some cases: local probes showed odd N/out_features can still build, while + # the hard upstream packing constraint we confirmed is K/in_features divisible by 8 / bits + # for quant-compressed weights. We keep the stricter gate here until the relaxed shapes are + # regression-tested across the full load/repack/forward path. + if any(in_features % divisor != 0 for divisor in self.SUPPORTS_IN_FEATURES_DIVISIBLE_BY): + raise ValueError( + f"`in_features` must be divisible by {self.SUPPORTS_IN_FEATURES_DIVISIBLE_BY} for BitBLAS" + ) + if any(out_features % divisor != 0 for divisor in self.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY): + raise ValueError( + f"`out_features` must be divisible by {self.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY} for BitBLAS" + ) if self.group_size not in (-1, in_features) and in_features % self.group_size != 0: raise ValueError("`in_features` must be divisible by `group_size`.") @@ -409,6 +591,8 @@ def _configure_bitblas_matmul( from bitblas import MatmulConfig bitblas_dtype = "float16" if params_dtype == torch.float16 else "bfloat16" + # FP16/BF16 accumulation drifted enough to derail autoregressive decoding. + accum_dtype = "float32" W_dtype = f"uint{bits}" if self.quant_config.is_sym is False else f"int{bits}" matmul_config = MatmulConfig( M=self.opt_features, @@ -417,7 +601,7 @@ def _configure_bitblas_matmul( A_dtype=bitblas_dtype, W_dtype=W_dtype, out_dtype=bitblas_dtype, - accum_dtype="int32" if bitblas_dtype == "int8" else bitblas_dtype, + accum_dtype="int32" if bitblas_dtype == "int8" else accum_dtype, storage_dtype=self.quant_config.storage_dtype, with_scaling=True, with_zeros=self.quant_config.with_zeros, @@ -429,24 +613,37 @@ def _configure_bitblas_matmul( self.bitblas_matmul = self._get_or_create_bitblas_operator( matmul_config, enable_tuning ) + self._ensure_runnable_bitblas_operator(self.bitblas_matmul, matmul_config) + + def _ensure_runnable_bitblas_operator(self, bitblas_matmul, config) -> None: + if getattr(bitblas_matmul, "lib", None) is not None: + return + if callable(getattr(bitblas_matmul, "torch_func", None)): + return + raise NotImplementedError( + "BitBLAS could not build a runnable matmul for " + f"A_dtype={config.A_dtype}, W_dtype={config.W_dtype}, out_dtype={config.out_dtype}, " + f"accum_dtype={config.accum_dtype}, group_size={config.group_size}." + ) def _get_or_create_bitblas_operator(self, config, enable_tuning): from bitblas import Matmul from bitblas.cache import global_operator_cache + target = _normalize_bitblas_target(BITBLAS_TARGET) if global_operator_cache.size() == 0: global_operator_cache.load_from_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET + BITBLAS_DATABASE_PATH, target ) bitblas_matmul = global_operator_cache.get(config) if bitblas_matmul is None: - bitblas_matmul = Matmul(config, target=BITBLAS_TARGET, enable_tuning=enable_tuning) + bitblas_matmul = Matmul(config, target=target, enable_tuning=enable_tuning) if enable_tuning: bitblas_matmul.hardware_aware_finetune(topk=20) global_operator_cache.add(config, bitblas_matmul) global_operator_cache.save_into_database( - BITBLAS_DATABASE_PATH, BITBLAS_TARGET + BITBLAS_DATABASE_PATH, target ) log.info( "BitBLAS operator tuned and added to cache for %s", config @@ -468,8 +665,56 @@ def reset_parameters(self) -> None: def post_init(self) -> None: super().post_init() + def _transform_bitblas_weight(self, intweight_out_in: torch.Tensor, device: torch.device) -> torch.Tensor: + from bitblas.quantization.utils import general_compress + + if self.bitblas_matmul.weight_transform is not None: + qweight = self.bitblas_matmul.weight_transform(intweight_out_in.cpu()).to(device) + else: + compressed = general_compress(intweight_out_in.cpu().numpy(), self.bits) + qweight = torch.from_numpy(compressed).to( + device=device, + dtype=self.quant_config.torch_storage_dtype, + ) + return qweight.contiguous() + + def _compress_bitblas_zeros( + self, + intzeros_group_out: Optional[torch.Tensor], + device: torch.device, + ) -> torch.Tensor: + from bitblas.quantization.utils import general_compress + + if not self.quant_config.with_zeros or intzeros_group_out is None: + return torch.empty(0, dtype=self.quant_config.torch_storage_dtype, device=device) + + compressed = general_compress(intzeros_group_out.contiguous().cpu().numpy(), self.bits) + return torch.from_numpy(compressed).to( + device=device, + dtype=self.quant_config.torch_storage_dtype, + ).contiguous() + + def _load_bitblas_quant_state( + self, + intweight_out_in: torch.Tensor, + scales_out_group: torch.Tensor, + intzeros_group_out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + ) -> None: + device = self._buffer_device() + + self._buffers["qweight"] = self._transform_bitblas_weight(intweight_out_in, device) + self._buffers["scales"] = scales_out_group.to(device=device, dtype=self.TORCH_DTYPE).contiguous() + self._buffers["qzeros"] = self._compress_bitblas_zeros(intzeros_group_out, device) + + if self.bias is not None and bias is not None: + self._buffers["bias"] = bias.detach().to(device=device, dtype=self.TORCH_DTYPE) + + self.zeros = self.qzeros + def forward(self, x: torch.Tensor) -> torch.Tensor: - if x.dtype not in (torch.float16, torch.bfloat16): + input_dtype = x.dtype + if input_dtype != self.TORCH_DTYPE: x = x.to(self.TORCH_DTYPE) orig_shape = x.shape[:-1] @@ -488,49 +733,97 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.adapter: out = self.adapter.apply(x=x, out=out) + if input_dtype in self.SUPPORTS_DTYPES and out.dtype != input_dtype: + out = out.to(dtype=input_dtype) + return out - def repack_from_gptq(self, gptq_module: BaseQuantLinear) -> None: - from bitblas.quantization.utils import general_compress - device = self._buffer_device() +# BitBLAS repacks incoming GPTQ/AWQ tensors into its own operator layout, so the +# destination module only needs grouped quantization state, not GPTQ qzero-format state. +class BitblasLinear(BitblasBaseQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_BITBLAS] + SUPPORTS_FORMATS = {FORMAT.BITBLAS: 30, FORMAT.GPTQ: 30, FORMAT.GPTQ_V2: 30} + SUPPORTS_BITS = BITBLAS_SUPPORTED_BITS + SUPPORTS_GROUP_SIZE = BITBLAS_SUPPORTED_GROUP_SIZES + # BitBLAS' public matmul API does not expose GPTQ activation-order metadata (`g_idx` / + # permutation tensors). Keep desc_act disabled until upstream adds a supported act-order path. + SUPPORTS_DESC_ACT = [False] + SUPPORTS_SYM = BITBLAS_SUPPORTED_SYM + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [16] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [16] - bits = self.bits - packed_weight = ( - gptq_module.qweight.detach().T.contiguous().view(self.quant_config.torch_storage_dtype) + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + QUANT_TYPE = "gptq_bitblas" + SUPPORTS_METHODS = [METHOD.GPTQ] + + @classmethod + def _bf16_signed_weight_error(cls, bits: int, layer: Optional[str] = None) -> NotImplementedError: + location = f" for layer pattern `{layer}`" if layer is not None else "" + return NotImplementedError( + f"{cls.__name__} does not support `torch.bfloat16` with symmetric `{bits}`-bit GPTQ weights{location}. " + "This is blocked by an upstream BitBLAS CUDA codegen failure for signed low-bit dequantization. " + "Use `torch.float16`, asymmetric GPTQ/unsigned weights, or a different backend." ) - intweight = unpack_gptq_qweight(packed_weight, bits).contiguous() - if self.bitblas_matmul.weight_transform is not None: - qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).to(device) - else: - from bitblas.quantization.utils import general_compress + @classmethod + def _validate_kernel_combo( + cls, + *, + bits: int, + sym: bool, + dtype: Optional[torch.dtype], + dynamic: Optional[dict] = None, + ) -> Tuple[bool, Optional[Exception]]: + if dtype != torch.bfloat16: + return True, None - compressed = general_compress(intweight.cpu().numpy(), bits) - qweight = torch.from_numpy(compressed).to( - device=device, dtype=self.quant_config.torch_storage_dtype - ) + if sym and bits in BITBLAS_BF16_UNSUPPORTED_SIGNED_BITS: + return False, cls._bf16_signed_weight_error(bits) + + if dynamic is None: + return True, None - self._buffers["qweight"] = qweight.contiguous() + for layer, overrides in dynamic.items(): + layer_bits = overrides.get("bits", bits) + layer_sym = overrides.get("sym", sym) + if layer_sym and layer_bits in BITBLAS_BF16_UNSUPPORTED_SIGNED_BITS: + return False, cls._bf16_signed_weight_error(layer_bits, layer=layer) - scales = gptq_module.scales.detach().T.contiguous().to(self.TORCH_DTYPE) - self._buffers["scales"] = scales.to(device) + return True, None + + def repack_from_gptq(self, gptq_module: BaseQuantLinear) -> None: + bits = self.bits + packed_weight = ( + gptq_module.qweight.detach().T.contiguous().view(self.quant_config.torch_storage_dtype) + ) + intweight = unpack_gptq_qweight(packed_weight, bits).contiguous() + if self.quant_config.is_sym and _should_remap_symmetric_gptq_codes(gptq_module): + intweight = remap_gptq_symmetric_codes_to_bitblas(intweight, bits) + intzeros = None if self.quant_config.with_zeros and hasattr(gptq_module, "qzeros") and gptq_module.qzeros is not None: - intzeros = unpack_gptq_qzeros(gptq_module.qzeros.detach(), bits).T.contiguous() + intzeros = unpack_gptq_qzeros(gptq_module.qzeros.detach(), bits).contiguous() intzeros = intzeros - 1 # GPTQ stores qzeros offset by +1 - compressed = general_compress(intzeros.T.contiguous().cpu().numpy(), bits) - zeros = torch.from_numpy(compressed).to(device=device, dtype=self.quant_config.torch_storage_dtype) - self._buffers["qzeros"] = zeros.contiguous() - else: - self._buffers["qzeros"] = torch.empty(0, dtype=self.quant_config.torch_storage_dtype, device=device) - - if self.bias is not None and hasattr(gptq_module, "bias") and gptq_module.bias is not None: - self._buffers["bias"] = gptq_module.bias.detach().to(device=device, dtype=self.TORCH_DTYPE) - self.zeros = self.qzeros + bias = gptq_module.bias.detach() if self.bias is not None and getattr(gptq_module, "bias", None) is not None else None + self._load_bitblas_quant_state( + intweight_out_in=intweight, + scales_out_group=gptq_module.scales.detach().T.contiguous(), + intzeros_group_out=intzeros, + bias=bias, + ) -BitBLASQuantLinear = BitblasQuantLinear +BitBLASLinear = BitblasLinear -__all__ = ["BitblasQuantLinear", "BitBLASQuantLinear"] +__all__ = ["BitblasLinear", "BitBLASLinear"] diff --git a/gptqmodel/nn_modules/qlinear/bitblas_awq.py b/gptqmodel/nn_modules/qlinear/bitblas_awq.py new file mode 100644 index 000000000..9fbca0d6c --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/bitblas_awq.py @@ -0,0 +1,170 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from typing import Optional, Union + +import torch +from torch import nn + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...quantization import FORMAT, METHOD +from ...quantization.awq.utils.packing_utils import reverse_awq_order, unpack_awq +from ...utils.backend import BACKEND +from .bitblas import ( + BITBLAS_OPTIMIZE_FEATURES, + BITBLAS_PROPAGATE_WEIGHTS, + BitblasBaseQuantLinear, + BitblasQuantizationConfig, +) + + +class AWQBitBlasKernel(BitblasBaseQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_BITBLAS] + SUPPORTS_METHODS = [METHOD.AWQ] + SUPPORTS_FORMATS = {FORMAT.GEMM: 0, FORMAT.BITBLAS: 30} + SUPPORTS_BITS = [4] + SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128] + SUPPORTS_DESC_ACT = [False, True] + SUPPORTS_SYM = [False, True] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [16] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [16] + + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + QUANT_TYPE = "awq_bitblas" + + def __init__( + self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + dtype: torch.dtype = torch.float16, + adapter: Adapter = None, + enable_tuning: bool = False, + fast_decoding: bool = True, + propagate_b: bool = BITBLAS_PROPAGATE_WEIGHTS, + opt_features: Union[int, list[int]] = BITBLAS_OPTIMIZE_FEATURES, + layout: str = "nt", + register_buffers: bool = False, + **kwargs, + ) -> None: + super().__init__( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + dtype=dtype, + adapter=adapter, + enable_tuning=enable_tuning, + fast_decoding=fast_decoding, + propagate_b=propagate_b, + opt_features=opt_features, + layout=layout, + register_buffers=register_buffers, + backend=kwargs.pop("backend", BACKEND.AWQ_BITBLAS), + **kwargs, + ) + + def _build_quant_config( + self, + *, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + dtype: torch.dtype, + ) -> BitblasQuantizationConfig: + del sym + + # AWQ stores unsigned weight codes plus zero-points, even for symmetric checkpoints. + return BitblasQuantizationConfig( + weight_bits=bits, + group_size=group_size, + desc_act=desc_act, + is_sym=False, + torch_dtype=dtype, + quant_method="awq", + ) + + @torch.inference_mode() + def pack( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: Optional[torch.Tensor] = None, + ) -> None: + del g_idx + + if scales is None or zeros is None: + raise ValueError("AWQBitBlasKernel.pack requires both scales and zeros.") + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + + weight = linear.weight.detach().contiguous() + group_idx = torch.arange(self.in_features, device=weight.device, dtype=torch.int64) // self.group_size + maxq = (1 << self.bits) - 1 + + # Broadcast per-group affine parameters across K so we can recover the stored AWQ integer codes. + scale_zeros = (zeros * scales).index_select(0, group_idx).t() + scales_by_input = scales.index_select(0, group_idx).t() + intweight = torch.round((weight + scale_zeros) / scales_by_input).clamp_(0, maxq).to(torch.int8) + + self._load_bitblas_quant_state( + intweight_out_in=intweight, + scales_out_group=scales.t().contiguous(), + intzeros_group_out=torch.round(zeros).to(torch.int8).contiguous(), + bias=linear.bias.detach() if linear.bias is not None else None, + ) + + @torch.inference_mode() + def repack_from_awq(self, awq_module) -> None: + qzeros = getattr(awq_module, "qzeros", None) + if qzeros is None: + raise ValueError("AWQBitBlasKernel requires qzeros to repack AWQ checkpoints.") + + intweight, intzeros = unpack_awq( + awq_module.qweight.detach(), + qzeros.detach(), + self.bits, + ) + intweight, intzeros = reverse_awq_order(intweight, intzeros, self.bits) + + maxq = (1 << self.bits) - 1 + intweight = torch.bitwise_and(intweight, maxq).to(torch.int8).t().contiguous() + intzeros = torch.bitwise_and(intzeros, maxq).to(torch.int8).contiguous() + + self._load_bitblas_quant_state( + intweight_out_in=intweight, + scales_out_group=awq_module.scales.detach().t().contiguous(), + intzeros_group_out=intzeros, + bias=awq_module.bias.detach() if getattr(awq_module, "bias", None) is not None else None, + ) + + +AwqBitBLASLinear = AWQBitBlasKernel + +__all__ = ["AWQBitBlasKernel", "AwqBitBLASLinear"] diff --git a/gptqmodel/nn_modules/qlinear/bitsandbytes.py b/gptqmodel/nn_modules/qlinear/bitsandbytes.py new file mode 100644 index 000000000..12ffd1022 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/bitsandbytes.py @@ -0,0 +1,411 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from functools import lru_cache +from typing import Dict, Optional, Tuple + +import torch +import torch.nn as nn +import transformers +from packaging import version +from torch.nn.modules.conv import _ConvNd + +from ...models._const import DEVICE, PLATFORM +from ...quantization import FORMAT, METHOD +from ...quantization.config import ( + _normalize_bitsandbytes_block_size, + _normalize_bitsandbytes_format, +) +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from . import WeightOnlyQuantLinear +from .gguf import _apply_optional_smoother + + +log = setup_logger() + +MINIMUM_BITSANDBYTES_VERSION = "0.49.0" +BITSANDBYTES_INSTALL_HINT = ( + "bitsandbytes is not installed or is too old. " + "Install a recent 0.49.x build, for example `pip install bitsandbytes>=0.49.3`." +) + +_BITSANDBYTES_4BIT_STATE_BUFFER_NAMES = ( + "weight_absmax", + "weight_quant_map", + "weight_nested_absmax", + "weight_nested_quant_map", + "weight_quant_state", +) +_BITSANDBYTES_8BIT_STATE_BUFFER_NAMES = ("weight_scb",) + + +def _is_bitsandbytes_available() -> bool: + try: + import bitsandbytes as bnb + except Exception as exc: # pragma: no cover - optional dependency + log.debug("bitsandbytes import failed: %s", exc) + return False + + return version.parse(bnb.__version__) >= version.parse(MINIMUM_BITSANDBYTES_VERSION) + + +def import_bitsandbytes(): + import bitsandbytes as bnb + + if version.parse(bnb.__version__) < version.parse(MINIMUM_BITSANDBYTES_VERSION): + raise ImportError(BITSANDBYTES_INSTALL_HINT) + return bnb + + +BITSANDBYTES_AVAILABLE = _is_bitsandbytes_available() + + +def _weight_to_matrix(linear: nn.Module) -> torch.Tensor: + weight = linear.weight.detach() + if isinstance(linear, _ConvNd): + weight = weight.flatten(1) + if isinstance(linear, transformers.pytorch_utils.Conv1D): + weight = weight.T + return weight + + +def _packed_state_key_to_buffer_name(key: str) -> str: + if key.startswith("quant_state."): + return "weight_quant_state" + return f"weight_{key}" + + +@lru_cache(maxsize=256) +def _buffer_spec_4bit( + *, + in_features: int, + out_features: int, + quant_type: str, + block_size: int, + compress_statistics: bool, +) -> Tuple[Tuple[str, Tuple[int, ...], torch.dtype], ...]: + bnb = import_bitsandbytes() + + # Quantize a template once so the registered buffers match the exact + # packed layout bitsandbytes expects during checkpoint load. + template = torch.zeros((out_features, in_features), dtype=torch.float16) + qweight, quant_state = bnb.functional.quantize_4bit( + template, + blocksize=block_size, + compress_statistics=compress_statistics, + quant_type=quant_type, + quant_storage=torch.uint8, + ) + + spec = [("weight", tuple(qweight.shape), qweight.dtype)] + for key, tensor in quant_state.as_dict(packed=True).items(): + spec.append((_packed_state_key_to_buffer_name(key), tuple(tensor.shape), tensor.dtype)) + return tuple(spec) + + +class BitsAndBytesLinear(WeightOnlyQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.BITSANDBYTES] + SUPPORTS_METHODS = [METHOD.BITSANDBYTES] + SUPPORTS_FORMATS = {FORMAT.BITSANDBYTES: 40} + SUPPORTS_BITS = [4, 8] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_DEVICES = [DEVICE.CPU, DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32, torch.int64] + SUPPORTS_ADAPTERS = [] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + + QUANT_TYPE = "bitsandbytes" + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + dtype: torch.dtype = torch.float16, + register_buffers: bool = True, + format: Optional[str] = None, + block_size: Optional[int] = None, + compress_statistics: Optional[bool] = None, + bnb_quant_type: Optional[str] = None, + bnb_block_size: Optional[int] = None, + bnb_compress_statistics: Optional[bool] = None, + **kwargs, + ): + raw_format = format if format is not None else bnb_quant_type + raw_block_size = block_size if block_size is not None else bnb_block_size + raw_compress_statistics = ( + compress_statistics if compress_statistics is not None else bnb_compress_statistics + ) + + self.bnb_format = _normalize_bitsandbytes_format(raw_format, bits=bits) + self.bnb_block_size = _normalize_bitsandbytes_block_size(raw_block_size) + self.bnb_compress_statistics = True if raw_compress_statistics is None else bool(raw_compress_statistics) + self.compute_dtype = dtype + self.quant_state = None + self._quant_state_signature = None + + super().__init__( + bits=bits, + in_features=in_features, + out_features=out_features, + bias=bias, + backend=kwargs.pop("backend", BACKEND.BITSANDBYTES), + adapter=kwargs.pop("adapter", None), + register_buffers=False, + dtype=dtype, + pack_dtype=pack_dtype, + **kwargs, + ) + + if register_buffers: + self._allocate_buffers(bias=bias) + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + if not BITSANDBYTES_AVAILABLE: + return False, ImportError(BITSANDBYTES_INSTALL_HINT) + + try: + import_bitsandbytes() + except Exception as exc: + return False, exc + return True, None + + @property + def is_4bit(self) -> bool: + return self.bits == 4 + + def smooth_block_size(self) -> int: + return self.bnb_block_size + + def _allocate_buffers(self, *, bias: bool) -> None: + if self.is_4bit: + buffer_spec = _buffer_spec_4bit( + in_features=self.in_features, + out_features=self.out_features, + quant_type=self.bnb_format, + block_size=self.bnb_block_size, + compress_statistics=self.bnb_compress_statistics, + ) + else: + buffer_spec = ( + ("weight", (self.out_features, self.in_features), torch.int8), + ("weight_scb", (self.out_features,), torch.float32), + ) + + for buffer_name, shape, dtype in buffer_spec: + value = torch.zeros(shape, dtype=dtype) + if buffer_name in self._buffers: + self._buffers[buffer_name] = value + else: + self.register_buffer(buffer_name, value) + + if bias: + bias_tensor = torch.zeros(self.out_features, dtype=self.compute_dtype) + if "bias" in self._buffers: + self._buffers["bias"] = bias_tensor + else: + self.register_buffer("bias", bias_tensor) + else: + self.bias = None + + def list_buffers(self): + buffers = [] + if hasattr(self, "weight") and self.weight is not None: + buffers.append(self.weight) + + state_buffer_names = ( + _BITSANDBYTES_4BIT_STATE_BUFFER_NAMES if self.is_4bit else _BITSANDBYTES_8BIT_STATE_BUFFER_NAMES + ) + for buffer_name in state_buffer_names: + tensor = getattr(self, buffer_name, None) + if tensor is not None: + buffers.append(tensor) + + if hasattr(self, "bias") and self.bias is not None: + buffers.append(self.bias) + return buffers + + def extra_repr(self) -> str: + extra = ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"bias={self.bias is not None}, bits={self.bits}" + ) + if self.is_4bit: + return ( + f"{extra}, format={self.bnb_format}, " + f"block_size={self.bnb_block_size}, " + f"compress_statistics={self.bnb_compress_statistics}" + ) + return f"{extra}, format={self.bnb_format}" + + def _quant_state_payload(self) -> Dict[str, torch.Tensor]: + payload: Dict[str, torch.Tensor] = { + "absmax": self.weight_absmax, + "quant_map": self.weight_quant_map, + f"quant_state.bitsandbytes__{self.bnb_format}": self.weight_quant_state, + } + if hasattr(self, "weight_nested_absmax"): + payload["nested_absmax"] = self.weight_nested_absmax + if hasattr(self, "weight_nested_quant_map"): + payload["nested_quant_map"] = self.weight_nested_quant_map + return payload + + def _refresh_quant_state(self, *, force: bool = False): + if not self.is_4bit: + self.quant_state = None + self._quant_state_signature = None + return None + + bnb = import_bitsandbytes() + signature = tuple( + ( + name, + tuple(getattr(self, name).shape), + str(getattr(self, name).device), + getattr(self, name).dtype, + ) + for name in ("weight", "weight_absmax", "weight_quant_map", "weight_quant_state") + ) + if not force and self.quant_state is not None and signature == self._quant_state_signature: + return self.quant_state + + self.quant_state = bnb.functional.QuantState.from_dict( + self._quant_state_payload(), + device=self.weight.device, + ) + self._quant_state_signature = signature + return self.quant_state + + def post_init(self): + super().post_init() + if self.is_4bit: + self._refresh_quant_state(force=True) + + def dequantize_weight(self) -> torch.Tensor: + bnb = import_bitsandbytes() + if self.is_4bit: + quant_state = self._refresh_quant_state() + return bnb.functional.dequantize_4bit(self.weight, quant_state=quant_state).contiguous() + return bnb.functional.int8_vectorwise_dequant(self.weight, self.weight_scb).contiguous() + + def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor = None): + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + def pack_block( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + block_in: int = 8192, + workers: int = 1, + ): + del block_in, workers + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + @torch.inference_mode() + def pack_original( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + smooth=None, + ): + del scales, zeros, g_idx + + bnb = import_bitsandbytes() + weight = _apply_optional_smoother( + _weight_to_matrix(linear).to(device="cpu"), + smooth=smooth, + group_size=self.smooth_block_size(), + ).contiguous() + + if self.is_4bit: + weight = weight.to(torch.float32 if self.compute_dtype == torch.float32 else self.compute_dtype) + qweight, quant_state = bnb.functional.quantize_4bit( + weight, + blocksize=self.bnb_block_size, + compress_statistics=self.bnb_compress_statistics, + quant_type=self.bnb_format, + quant_storage=torch.uint8, + ) + self._buffers["weight"] = qweight.contiguous() + for key, tensor in quant_state.as_dict(packed=True).items(): + self._buffers[_packed_state_key_to_buffer_name(key)] = tensor.contiguous() + self._refresh_quant_state(force=True) + else: + qweight, scales, outlier_cols = bnb.functional.int8_vectorwise_quant( + weight.to(torch.float16), + threshold=0.0, + ) + if outlier_cols is not None and outlier_cols.numel() > 0: + raise NotImplementedError( + "BitsAndBytesLinear only supports the direct int8 vectorwise path without outlier routing." + ) + self._buffers["weight"] = qweight.contiguous() + self._buffers["weight_scb"] = scales.contiguous() + + if linear.bias is not None: + bias = linear.bias.detach().to(device="cpu", dtype=self.compute_dtype).contiguous() + if "bias" in self._buffers: + self._buffers["bias"] = bias + else: + self.register_buffer("bias", bias) + else: + self.bias = None + + def forward(self, x: torch.Tensor): + bnb = import_bitsandbytes() + + input_dtype = x.dtype + compute_dtype = self.compute_dtype if self.compute_dtype is not None else input_dtype + bias = None if self.bias is None else self.bias.to(compute_dtype) + + if self.is_4bit: + quant_state = self._refresh_quant_state() + out = bnb.matmul_4bit( + x.to(compute_dtype), + self.weight.t(), + quant_state=quant_state, + bias=bias, + ) + return out.to(input_dtype) + + x_int8, x_stats, outlier_cols = bnb.functional.int8_vectorwise_quant( + x.to(torch.float16), + threshold=0.0, + ) + if outlier_cols is not None and outlier_cols.numel() > 0: + raise NotImplementedError( + "BitsAndBytesLinear only supports the direct int8 vectorwise path without outlier routing." + ) + + mm_out = bnb.functional.int8_linear_matmul(x_int8, self.weight) + out = bnb.functional.int8_mm_dequant(mm_out, x_stats, self.weight_scb, bias=bias) + return out.to(input_dtype) + + +BitsAndBytes4bitLinear = BitsAndBytesLinear + +__all__ = [ + "BITSANDBYTES_AVAILABLE", + "BITSANDBYTES_INSTALL_HINT", + "BitsAndBytes4bitLinear", + "BitsAndBytesLinear", + "import_bitsandbytes", +] diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py index 08cda6a02..5fcbe9aa4 100644 --- a/gptqmodel/nn_modules/qlinear/exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/exllamav2.py @@ -11,27 +11,29 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import BaseQuantLinear +from ...nn_modules.qlinear import GPTQQuantLinear from ...quantization import FORMAT, METHOD from ...utils.backend import BACKEND -from ...utils.exllamav2 import ScratchSpace +from ...utils.exllamav2 import ( + ScratchSpace, + exllamav2_gemm_half_q_half, + exllamav2_gptq_runtime_available, + exllamav2_gptq_runtime_error, + exllamav2_make_q_matrix, +) from ...utils.logger import setup_logger log = setup_logger() -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -NONE_TENSOR = torch.empty((1, 1), device="meta") - - def _torch_device(idx): if idx == -1: return "cpu" return f"cuda:{idx}" -class ExllamaV2QuantLinear(BaseQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.EXLLAMA_V2] +class ExllamaV2Linear(GPTQQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_EXLLAMA_V2] SUPPORTS_METHODS = [METHOD.GPTQ] SUPPORTS_FORMATS = {FORMAT.GPTQ: 80, FORMAT.GPTQ_V2: 80} SUPPORTS_BITS = [4] @@ -60,8 +62,6 @@ class ExllamaV2QuantLinear(BaseQuantLinear): """Linear layer implementation with per-group 4-bit quantization of the weights""" - gptqmodel_exllamav2_kernels = None - def __init__( self, bits: int, @@ -95,7 +95,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.EXLLAMA_V2), + backend=kwargs.pop("backend", BACKEND.GPTQ_EXLLAMA_V2), adapter=adapter, register_buffers=register_buffers, register_buffers_in_features=in_features, @@ -107,12 +107,9 @@ def __init__( @classmethod def validate_once(cls) -> Tuple[bool, Optional[Exception]]: - try: - import gptqmodel_exllamav2_kernels - cls.gptqmodel_exllamav2_kernels = gptqmodel_exllamav2_kernels - return True, None - except ImportError as e: - return False, e + if not exllamav2_gptq_runtime_available(): + return False, ImportError(exllamav2_gptq_runtime_error()) + return True, None def post_init(self, scratch_space: ScratchSpace): # resize due to padding after model weights have been loaded @@ -191,7 +188,7 @@ def ext_gemm_half_q_half(self, x, q_handle, q4_width, force_cuda): output_shape = x.shape[:-1] + (q4_width,) x = x.view(-1, x.shape[-1]) output = torch.empty((x.shape[0], q4_width), dtype=torch.half, device=x.device) - self.gptqmodel_exllamav2_kernels.gemm_half_q_half(x, q_handle, output, force_cuda) + exllamav2_gemm_half_q_half(x, q_handle, output, force_cuda) return output.view(output_shape) def ext_make_q_matrix(self, w: dict, temp_dq, key: str = None): @@ -204,18 +201,18 @@ def ext_make_q_matrix(self, w: dict, temp_dq, key: str = None): w["q_scale_max"] /= 256 w["q_perm"] = w["q_perm"].short() w["q_invperm"] = w["q_invperm"].short() - return self.gptqmodel_exllamav2_kernels.make_q_matrix( + return exllamav2_make_q_matrix( w["q_weight"], w["q_perm"], w["q_invperm"], w["q_scale"], w["q_scale_max"], w["q_groups"], - NONE_TENSOR, - NONE_TENSOR, - NONE_TENSOR, - temp_dq, - ) + None, + None, + None, + temp_dq, + ) # GPTQ elif "qweight" in w: if w["scales"].dtype == torch.float: @@ -230,13 +227,13 @@ def ext_make_q_matrix(self, w: dict, temp_dq, key: str = None): ) w["q_invperm"] = torch.empty_like(w["q_perm"]) # make_q4 segfaults if g_idx is not on cpu in the act-order case. In the non act-order case, None needs to be passed for g_idx. - return self.gptqmodel_exllamav2_kernels.make_q_matrix( + return exllamav2_make_q_matrix( w["qweight"], w["q_perm"], w["q_invperm"], - NONE_TENSOR, - NONE_TENSOR, - NONE_TENSOR, + None, + None, + None, w["qzeros"], w["scales"], w["g_idx"].cpu(), @@ -244,16 +241,16 @@ def ext_make_q_matrix(self, w: dict, temp_dq, key: str = None): ) # GPTQ without g_idx else: - return self.gptqmodel_exllamav2_kernels.make_q_matrix( + return exllamav2_make_q_matrix( w["qweight"], - NONE_TENSOR, - NONE_TENSOR, - NONE_TENSOR, - NONE_TENSOR, - NONE_TENSOR, + None, + None, + None, + None, + None, w["qzeros"], w["scales"], - NONE_TENSOR, + None, temp_dq, ) else: diff --git a/gptqmodel/nn_modules/qlinear/exllamav2_awq.py b/gptqmodel/nn_modules/qlinear/exllamav2_awq.py index 90c2ad882..e419543a8 100644 --- a/gptqmodel/nn_modules/qlinear/exllamav2_awq.py +++ b/gptqmodel/nn_modules/qlinear/exllamav2_awq.py @@ -9,23 +9,23 @@ from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear from ...quantization import FORMAT, METHOD -from ...quantization.awq.utils.module import try_import from ...quantization.awq.utils.packing_utils import unpack_reorder_pack from ...utils.backend import BACKEND -from ...utils.exllamav2 import ScratchSpace +from ...utils.exllamav2 import ( + ScratchSpace, + exllamav2_awq_gemm_half_q_half, + exllamav2_awq_make_q_matrix, + exllamav2_awq_runtime_available, + exllamav2_awq_runtime_error, +) from ...utils.logger import setup_logger log = setup_logger() -exlv2_ext, msg = try_import("gptqmodel_exllamav2_awq_kernels") -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -class AwqExllamaV2QuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.EXLLAMA_V2] +class AwqExllamaV2Linear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_EXLLAMA_V2] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMM: 80} SUPPORTS_BITS = [4] @@ -74,10 +74,33 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.EXLLAMA_V2), + backend=kwargs.pop("backend", BACKEND.AWQ_EXLLAMA_V2), adapter=adapter, **kwargs) + @classmethod + def validate_once(cls): + # Validate through the shared JIT loader so import-time checks match the + # runtime build path used on first kernel access. + if not exllamav2_awq_runtime_available(): + return False, ImportError(exllamav2_awq_runtime_error()) + return True, None + + def ext_make_q_matrix_awq(self, qweight, qzeros, scales, temp_dq) -> int: + runtime_scales = scales if scales.dtype == torch.float16 else scales.to(dtype=torch.float16) + return exllamav2_awq_make_q_matrix( + qweight, + None, + None, + None, + None, + None, + qzeros, + runtime_scales, + None, + temp_dq, + ) + def post_init(self, scratch_space: ScratchSpace): # if self.padded_infeatures != self.in_features: # self.qweight.resize_(self.padded_infeatures // self.pack_dtype_bits * self.bits, self.out_features) @@ -89,8 +112,8 @@ def post_init(self, scratch_space: ScratchSpace): # self.g_idx = torch.tensor([i // self.group_size for i in range(self.padded_infeatures)], dtype=torch.int32, # device=self.g_idx.device) - if exlv2_ext is None: - raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg) + if not exllamav2_awq_runtime_available(): + raise ModuleNotFoundError("ExLlamaV2 AWQ torch.ops kernels are not properly installed. Error: " + exllamav2_awq_runtime_error()) # awq only accepts float16 self.scales = self.scales.to(dtype=torch.float16) @@ -104,18 +127,7 @@ def post_init(self, scratch_space: ScratchSpace): temp_dq_size = self.temp_dq_size() temp_dq = scratch_space.get_slice(temp_dq_size) - self.q_handle = exlv2_ext.make_q_matrix_awq( - self.qweight, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - self.qzeros, - self.scales, - none_tensor, - temp_dq, - ) + self.q_handle = self.ext_make_q_matrix_awq(self.qweight, self.qzeros, self.scales, temp_dq) super().post_init() @@ -123,8 +135,8 @@ def forward(self, x: torch.Tensor): assert self.q_handle is not None, ( "module.post_init() must be called before module.forward(). " ) - if exlv2_ext is None: - raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." + msg) + if not exllamav2_awq_runtime_available(): + raise ModuleNotFoundError("ExLlamaV2 AWQ torch.ops kernels are not properly installed. Error: " + exllamav2_awq_runtime_error()) input_dtype = x.dtype out_shape = x.shape[:-1] + (self.out_features,) @@ -139,7 +151,7 @@ def forward(self, x: torch.Tensor): dtype=torch.float16, device=x.device, ) - exlv2_ext.gemm_half_q_half_awq(x, self.q_handle, out, False) + exllamav2_awq_gemm_half_q_half(x, self.q_handle, out, False) if input_dtype != torch.float16: out = out.to(dtype=input_dtype) @@ -175,4 +187,4 @@ def next_multiple(x, multiple): return ((x + multiple - 1) // multiple) * multiple -__all__ = ["AwqExllamaV2QuantLinear"] +__all__ = ["AwqExllamaV2Linear"] diff --git a/gptqmodel/nn_modules/qlinear/fp4.py b/gptqmodel/nn_modules/qlinear/fp4.py new file mode 100644 index 000000000..0eb6a0df8 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/fp4.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +try: + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor +except Exception: + NVFP4Tensor = None + + +class TorchFP4Linear(nn.Module): + """Execute one linear layer directly from NVFP4-packed checkpoint weights.""" + + def __init__( + self, + *, + in_features: int, + out_features: int, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_block_size: int, + orig_dtype: torch.dtype, + bias: Optional[torch.Tensor] = None, + ) -> None: + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight_block_size = int(weight_block_size) + self.orig_dtype = orig_dtype + self.register_buffer("weight", weight) + self.register_buffer("weight_scale", weight_scale) + if isinstance(bias, torch.Tensor): + self.register_buffer("bias", bias) + else: + self.bias = None + + def _native_weight(self) -> "NVFP4Tensor": + """Wrap the stored packed weight buffers into the torchao NVFP4 tensor view.""" + + if NVFP4Tensor is None: + raise RuntimeError("TorchFP4Linear requires torchao NVFP4Tensor support.") + return NVFP4Tensor( + self.weight, + self.weight_scale, + block_size=self.weight_block_size, + orig_dtype=self.orig_dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Dispatch one dense linear projection through the NVFP4 tensor wrapper.""" + + bias = self.bias + if isinstance(bias, torch.Tensor) and bias.dtype != x.dtype: + bias = bias.to(dtype=x.dtype) + return F.linear(x, self._native_weight(), bias) diff --git a/gptqmodel/nn_modules/qlinear/fp8.py b/gptqmodel/nn_modules/qlinear/fp8.py new file mode 100644 index 000000000..f5ce05fbb --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/fp8.py @@ -0,0 +1,462 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import transformers +from torch.nn.modules.conv import _ConvNd + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...quantization import FORMAT, METHOD +from ...quantization.config import ( + _normalize_fp8_fmt, + _normalize_fp8_scale_semantics, + _normalize_fp8_weight_block_size, + _normalize_fp8_weight_scale_method, +) +from ...quantization.dtype import ( + available_float8_dtype_names, + dequantize_fp8, + device_supports_native_fp8, +) +from ...utils.backend import BACKEND +from . import WeightOnlyQuantLinear +from .gguf import _apply_optional_smoother + + +def _fp8_dtype_from_name(fmt: str) -> torch.dtype: + return getattr(torch, _normalize_fp8_fmt(fmt)) + + +def _weight_to_matrix(linear: nn.Module) -> torch.Tensor: + weight = linear.weight.detach() + if isinstance(linear, _ConvNd): + weight = weight.flatten(1) + if isinstance(linear, transformers.pytorch_utils.Conv1D): + weight = weight.T + return weight + + +def _compute_scale_inv(abs_max: torch.Tensor, fp8_max: float) -> torch.Tensor: + abs_max = abs_max.to(torch.float32) + eps = torch.finfo(torch.float32).tiny + return torch.where( + abs_max > 0, + torch.full_like(abs_max, float(fp8_max)) / abs_max.clamp_min(eps), + torch.ones_like(abs_max), + ) + + +def quantize_fp8_weight( + weight: torch.Tensor, + *, + format: str = "float8_e4m3fn", + weight_scale_method: str = "row", + weight_block_size: Optional[Tuple[int, int]] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + if weight.ndim != 2: + raise ValueError(f"FP8 quantization expects a 2D weight matrix, got shape {tuple(weight.shape)}.") + + format = _normalize_fp8_fmt(format) + if format == "float8_e8m0fnu": + raise ValueError( + "TorchFP8Linear does not quantize dense weights to float8_e8m0fnu; " + "use that format only for dequantization of existing checkpoints." + ) + block_size = _normalize_fp8_weight_block_size(weight_block_size) + weight_scale_method = _normalize_fp8_weight_scale_method( + weight_scale_method, + weight_block_size=block_size, + ) + fp8_dtype = _fp8_dtype_from_name(format) + fp8_max = torch.finfo(fp8_dtype).max + + weight = weight.to(device="cpu", dtype=torch.float32).contiguous() + + if weight_scale_method == "tensor": + scale_inv = _compute_scale_inv(weight.abs().amax(), fp8_max) + quantized = torch.clamp(weight * scale_inv, min=-fp8_max, max=fp8_max).to(fp8_dtype) + return quantized.contiguous(), scale_inv.to(torch.float32) + + if weight_scale_method == "row": + scale_inv = _compute_scale_inv(weight.abs().amax(dim=1), fp8_max) + quantized = torch.clamp( + weight * scale_inv.unsqueeze(1), + min=-fp8_max, + max=fp8_max, + ).to(fp8_dtype) + return quantized.contiguous(), scale_inv.to(torch.float32).contiguous() + + if block_size is None: + raise ValueError("FP8 block quantization requires `weight_block_size`.") + + block_rows, block_cols = block_size + rows, cols = weight.shape + if rows % block_rows != 0 or cols % block_cols != 0: + raise ValueError( + f"FP8 block quantization expects shape {tuple(weight.shape)} to be divisible by block size " + f"{block_size}." + ) + + row_blocks = rows // block_rows + col_blocks = cols // block_cols + blocks = weight.reshape(row_blocks, block_rows, col_blocks, block_cols) + scale_inv = _compute_scale_inv(blocks.abs().amax(dim=(1, 3)), fp8_max) + scaled = blocks * scale_inv.unsqueeze(1).unsqueeze(3) + quantized = torch.clamp(scaled, min=-fp8_max, max=fp8_max).to(fp8_dtype).reshape(rows, cols) + return quantized.contiguous(), scale_inv.to(torch.float32).contiguous() + + +class TorchFP8Linear(WeightOnlyQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.FP8_TORCH] + SUPPORTS_METHODS = [METHOD.FP8] + SUPPORTS_FORMATS = {FORMAT.FP8: 15} + SUPPORTS_BITS = [8] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = True + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_DEVICES = [DEVICE.ALL] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32, torch.int64] + SUPPORTS_ADAPTERS = [Lora] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + + QUANT_TYPE = "fp8" + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + format: str = "float8_e4m3fn", + weight_scale_method: str = "row", + weight_block_size: Optional[Tuple[int, int]] = None, + weight_scale_semantics: str = "inverse", + **kwargs, + ): + self.fp8_format = _normalize_fp8_fmt(format) + self.fp8_dtype = _fp8_dtype_from_name(self.fp8_format) + block_size = _normalize_fp8_weight_block_size(weight_block_size) + self.weight_scale_method = _normalize_fp8_weight_scale_method( + weight_scale_method, + weight_block_size=block_size, + ) + self.weight_block_size = block_size + self.weight_scale_semantics = _normalize_fp8_scale_semantics(weight_scale_semantics) + self._scaled_mm_hard_disabled = False + + if self.weight_scale_method == "block" and self.weight_block_size is not None: + block_rows, block_cols = self.weight_block_size + if out_features % block_rows != 0 or in_features % block_cols != 0: + raise ValueError( + f"TorchFP8Linear block scaling requires out_features/in_features " + f"to be divisible by `weight_block_size={self.weight_block_size}`." + ) + + super().__init__( + bits=bits, + in_features=in_features, + out_features=out_features, + bias=bias, + backend=kwargs.pop("backend", BACKEND.FP8_TORCH), + adapter=adapter, + register_buffers=False, + pack_dtype=pack_dtype, + **kwargs, + ) + + if register_buffers: + self._allocate_buffers(bias=bias) + + @classmethod + def validate_once(cls): + if not available_float8_dtype_names(): + return False, RuntimeError("TorchFP8Linear requires a PyTorch build with FP8 dtypes.") + return True, None + + def smooth_block_size(self) -> int: + if self.weight_scale_method == "block" and self.weight_block_size is not None: + return self.weight_block_size[1] + return -1 + + def _scale_shape(self) -> tuple[int, ...]: + if self.weight_scale_method == "tensor": + return () + if self.weight_scale_method == "row": + return (self.out_features,) + if self.weight_block_size is None: + raise ValueError("TorchFP8Linear block scaling requires `weight_block_size`.") + block_rows, block_cols = self.weight_block_size + return ( + self.out_features // block_rows, + self.in_features // block_cols, + ) + + def _allocate_buffers(self, *, bias: bool) -> None: + weight = torch.zeros((self.out_features, self.in_features), dtype=self.fp8_dtype) + scale = torch.ones(self._scale_shape(), dtype=torch.float32) + + if "weight" in self._buffers: + self.weight = weight + else: + self.register_buffer("weight", weight) + + if "weight_scale_inv" in self._buffers: + self.weight_scale_inv = scale + else: + self.register_buffer("weight_scale_inv", scale) + + if bias: + bias_tensor = torch.zeros(self.out_features, dtype=torch.float16) + if "bias" in self._buffers: + self.bias = bias_tensor + else: + self.register_buffer("bias", bias_tensor) + else: + self.bias = None + + def list_buffers(self): + buffers = [] + if hasattr(self, "weight") and self.weight is not None: + buffers.append(self.weight) + if hasattr(self, "weight_scale_inv") and self.weight_scale_inv is not None: + buffers.append(self.weight_scale_inv) + if hasattr(self, "bias") and self.bias is not None: + buffers.append(self.bias) + return buffers + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"bias={self.bias is not None}, format={self.fp8_format}, " + f"weight_scale_method={self.weight_scale_method}" + ) + + def _weight_to_matrix(self, linear: nn.Module) -> torch.Tensor: + return _weight_to_matrix(linear) + + def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor = None): + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + def pack_block( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + block_in: int = 8192, + workers: int = 1, + ): + del block_in, workers + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + def pack_gpu( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + *, + block_in: int = 8192, + device: torch.device | None = None, + ): + del block_in, device + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + @torch.inference_mode() + def pack_original( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + *, + smooth=None, + ): + del scales, zeros, g_idx + + weight = self._weight_to_matrix(linear).to(device="cpu", dtype=torch.float32) + weight = _apply_optional_smoother( + weight, + smooth=smooth, + group_size=self.smooth_block_size(), + ) + qweight, weight_scale_inv = quantize_fp8_weight( + weight, + format=self.fp8_format, + weight_scale_method=self.weight_scale_method, + weight_block_size=self.weight_block_size, + ) + + if "weight" in self._buffers: + self.weight = qweight + else: + self.register_buffer("weight", qweight) + + if "weight_scale_inv" in self._buffers: + self.weight_scale_inv = weight_scale_inv + else: + self.register_buffer("weight_scale_inv", weight_scale_inv) + + if linear.bias is not None: + bias = linear.bias.detach().to(device="cpu", dtype=torch.float16) + if "bias" in self._buffers: + self.bias = bias + else: + self.register_buffer("bias", bias) + else: + self.bias = None + + self._scaled_mm_hard_disabled = False + + def _resolve_target(self, device=None, dtype=None) -> tuple[torch.device, torch.dtype]: + target_device = self.weight.device if device is None else torch.device(device) + target_dtype = torch.float32 if dtype is None else dtype + return target_device, target_dtype + + def _expanded_scale_inv(self, *, target_device: torch.device, target_dtype: torch.dtype) -> torch.Tensor: + scale_inv = self.weight_scale_inv + if scale_inv.device != target_device or scale_inv.dtype != target_dtype: + scale_inv = scale_inv.to(device=target_device, dtype=target_dtype) + + if self.weight_scale_method == "tensor": + return scale_inv + if self.weight_scale_method == "row": + return scale_inv.view(-1, 1) + if self.weight_block_size is None: + raise ValueError("TorchFP8Linear block scaling requires `weight_block_size`.") + + block_rows, block_cols = self.weight_block_size + expanded = scale_inv.repeat_interleave(block_rows, dim=0) + expanded = expanded.repeat_interleave(block_cols, dim=1) + return expanded + + def dequantize_weight( + self, + *, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> torch.Tensor: + target_device, target_dtype = self._resolve_target(device=device, dtype=dtype) + + if self.weight_scale_semantics != "inverse": + raise NotImplementedError( + f"Unsupported FP8 scale semantics `{self.weight_scale_semantics}`." + ) + + # Older GPUs can store the module on CUDA but still miss native FP8 math. + # In that case, dequantize on CPU directly to fp16/bf16 and only then move. + prefer_cpu_dequant = target_device.type == "cpu" + # PyTorch's CUDA cast path for E8M0 currently produces NaNs when used as a + # dense weight matrix, so force the validated CPU LUT path for that format. + if self.fp8_format == "float8_e8m0fnu": + prefer_cpu_dequant = True + if target_device.type == "cuda" and not device_supports_native_fp8(target_device): + prefer_cpu_dequant = True + + if prefer_cpu_dequant: + weight = dequantize_fp8( + self.weight.to(device="cpu"), + scale_inv=self.weight_scale_inv.to(device="cpu"), + axis=None if self.weight_scale_method == "block" else (0 if self.weight_scale_method == "row" else 0), + target_dtype=target_dtype, + ) + if target_device.type != "cpu": + weight = weight.to(device=target_device) + return weight.transpose(0, 1).contiguous() + + weight = self.weight if self.weight.device == target_device else self.weight.to(device=target_device) + weight = weight.to(target_dtype) + scale_inv = self._expanded_scale_inv(target_device=target_device, target_dtype=target_dtype) + return (weight / scale_inv).transpose(0, 1).contiguous() + + def _scaled_mm_weight_scale(self, *, device: torch.device) -> torch.Tensor: + scale_inv = self.weight_scale_inv + if scale_inv.device != device: + scale_inv = scale_inv.to(device=device) + scale = torch.reciprocal(scale_inv.to(torch.float32)) + if self.weight_scale_method == "tensor": + return scale + if self.weight_scale_method == "row": + return scale.view(1, -1) + raise NotImplementedError("scaled_mm is only used for tensorwise or rowwise FP8 scales.") + + def _quantize_input_for_scaled_mm(self, x_flat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + fp8_max = torch.finfo(self.fp8_dtype).max + x_work = x_flat.to(torch.float32) + if self.weight_scale_method == "tensor": + scale_inv = _compute_scale_inv(x_work.abs().amax(), fp8_max) + else: + scale_inv = _compute_scale_inv(x_work.abs().amax(dim=1, keepdim=True), fp8_max) + x_q = torch.clamp(x_work * scale_inv, min=-fp8_max, max=fp8_max).to(self.fp8_dtype) + return x_q, torch.reciprocal(scale_inv).to(torch.float32) + + def _can_use_scaled_mm(self, x_flat: torch.Tensor) -> bool: + return ( + not self._scaled_mm_hard_disabled + and hasattr(torch, "_scaled_mm") + and x_flat.device.type == "cuda" + and self.weight_scale_method == "tensor" + and self.fp8_format != "float8_e8m0fnu" + and x_flat.dtype in {torch.float16, torch.bfloat16} + and x_flat.shape[-1] == self.in_features + and self.in_features % 16 == 0 + and self.out_features % 16 == 0 + ) + + def _forward_dequant_matmul(self, x_flat: torch.Tensor) -> torch.Tensor: + weight = self.dequantize_weight(device=x_flat.device, dtype=x_flat.dtype) + return torch.matmul(x_flat, weight) + + def _forward_scaled_mm(self, x_flat: torch.Tensor) -> torch.Tensor: + weight = self.weight if self.weight.device == x_flat.device else self.weight.to(device=x_flat.device) + x_q, scale_a = self._quantize_input_for_scaled_mm(x_flat) + scale_b = self._scaled_mm_weight_scale(device=x_flat.device) + return torch._scaled_mm( + x_q, + weight.t(), + scale_a=scale_a, + scale_b=scale_b, + out_dtype=x_flat.dtype, + ) + + def forward(self, x: torch.Tensor): + original_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + + if self._can_use_scaled_mm(x_flat): + try: + output = self._forward_scaled_mm(x_flat) + except Exception: + self._scaled_mm_hard_disabled = True + output = self._forward_dequant_matmul(x_flat) + else: + output = self._forward_dequant_matmul(x_flat) + + if self.bias is not None: + bias = self.bias + if bias.device != output.device or bias.dtype != output.dtype: + bias = bias.to(device=output.device, dtype=output.dtype) + output = output + bias + + if self.adapter: + output = self.adapter.apply(x=x_flat, out=output) + + return output.reshape(original_shape) + + +__all__ = ["TorchFP8Linear", "quantize_fp8_weight"] diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 4b40d4125..4957a8e72 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -13,12 +13,17 @@ from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear from ...quantization import FORMAT, METHOD -from ...quantization.awq.utils.module import try_import from ...quantization.awq.utils.utils import get_best_device +from ...utils.awq import awq_dequantize_weights, awq_gemm_forward, awq_runtime_available, awq_runtime_error from ...utils.backend import BACKEND +from ...utils.env import env_flag -awq_ext, msg = try_import("gptqmodel_awq_kernels") +# Shared runtime default: prefer accuracy first unless the user explicitly opts out. +FP32_ACCUM = env_flag("GPTQMODEL_FP32_ACCUM", default=True) + +def _awq_cuda_gemm_forward(input, qweight, scales, qzeros, split_k_iters, fp32_accum: bool = FP32_ACCUM): + return awq_gemm_forward(input, qweight, scales, qzeros, split_k_iters, fp32_accum) class AwqGemmFn(torch.autograd.Function): @@ -34,26 +39,28 @@ def forward( bias=None, out_features=0, prefer_backend=None, + fp32_accum=FP32_ACCUM, ): - if awq_ext is None: - raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") - ctx.save_for_backward(x, qweight, qzeros, scales, bias) ctx.out_features = out_features out_shape = x.shape[:-1] + (out_features,) - x = x.to(torch.float16) if x.shape[0] == 0: return torch.zeros(out_shape, dtype=x.dtype, device=x.device) # Above compute density threshold it is faster to just dequantize the whole thing and do simple matmul FULL_DEQUANT_MATMUL_THRESHOLD = x.shape[0] * x.shape[1] > 1024 if FULL_DEQUANT_MATMUL_THRESHOLD: - out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) - out = torch.matmul(x, out) + out = awq_dequantize_weights(qweight, scales, qzeros, 0, 0, 0, False) + out = torch.matmul(x, out.to(dtype=x.dtype)) else: - out = awq_ext.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 + out = _awq_cuda_gemm_forward( + x.reshape(-1, x.shape[-1]), + qweight, + scales, + qzeros, + 8, + fp32_accum=fp32_accum, ) out = out + bias if bias is not None else out @@ -68,10 +75,7 @@ def forward( def backward(ctx, grad_output): input, qweight, qzeros, scales, bias = ctx.saved_tensors - if awq_ext is None: - raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") - - weights = awq_ext.dequantize_weights_cuda( + weights = awq_dequantize_weights( qweight, scales, qzeros, 1, 0, 0, False ).to(grad_output.dtype) @@ -80,11 +84,11 @@ def backward(ctx, grad_output): batch_size = grad_output.shape[0] grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - return grad_input, None, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None, None -class AwqGEMMQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.GEMM] +class AwqGEMMLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_GEMM] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMM: 60} SUPPORTS_BITS = [4] @@ -102,7 +106,7 @@ class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] REQUIRES_FORMAT_V2 = False @@ -111,8 +115,8 @@ class AwqGEMMQuantLinear(AWQuantLinear): @classmethod def validate_once(cls) -> Tuple[bool, Optional[Exception]]: - if awq_ext is None: - return False, ValueError(msg or "CUDA AWQ extension not available; cannot select AwqGEMMQuantLinear") + if not awq_runtime_available(): + return False, ValueError(awq_runtime_error() or "CUDA AWQ extension not available; cannot select AwqGEMMLinear") else: return True, None @@ -128,6 +132,7 @@ def __init__( pack_dtype: torch.dtype = torch.int32, adapter: Adapter = None, register_buffers: bool = False, + fp32_accum: bool = FP32_ACCUM, **kwargs, ): @@ -140,15 +145,23 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.GEMM), + backend=kwargs.pop("backend", BACKEND.AWQ_GEMM), adapter=adapter, register_buffers=register_buffers, **kwargs) + self.fp32_accum = bool(fp32_accum) + + def _ensure_runtime_dtype(self, dtype: torch.dtype): + if self.scales is not None and (self.scales.dtype != dtype or not self.scales.is_contiguous()): + self.scales = self.scales.to(dtype=dtype).contiguous() + if self.bias is not None and (self.bias.dtype != dtype or not self.bias.is_contiguous()): + self.bias = self.bias.to(dtype=dtype).contiguous() def post_init(self): - # awq only accepts float16 - if self.scales is not None: + if self.scales is not None and self.scales.dtype not in (torch.float16, torch.bfloat16): self.scales = self.scales.to(dtype=torch.float16) + if self.bias is not None and self.bias.dtype not in (torch.float16, torch.bfloat16): + self.bias = self.bias.to(dtype=torch.float16) super().post_init() @@ -156,8 +169,13 @@ def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) input_dtype = x.dtype - if input_dtype != torch.float16: - x = x.half() + compute_dtype = input_dtype if input_dtype in (torch.float16, torch.bfloat16) else torch.float16 + if input_dtype != compute_dtype: + x = x.to(compute_dtype) + elif not x.is_contiguous(): + x = x.contiguous() + + self._ensure_runtime_dtype(compute_dtype) ctx = nullcontext() if self.training else torch.inference_mode() with ctx: @@ -171,9 +189,10 @@ def forward(self, x: torch.Tensor): self.bias, self.out_features, "cuda", + self.fp32_accum, ) - if input_dtype != torch.float16: + if out.dtype != input_dtype: out = out.to(dtype=input_dtype) if self.adapter: @@ -188,9 +207,11 @@ def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_i zeros = zeros.t().contiguous() scale_zeros = zeros * scales - self.register_buffer("scales", scales.clone().half()) + scale_dtype = scales.dtype if scales.dtype in (torch.float16, torch.bfloat16) else torch.float16 + self.register_buffer("scales", scales.clone().to(scale_dtype)) if linear.bias is not None: - self.register_buffer("bias", linear.bias.clone().half()) + bias_dtype = linear.bias.dtype if linear.bias.dtype in (torch.float16, torch.bfloat16) else scale_dtype + self.register_buffer("bias", linear.bias.clone().to(bias_dtype)) else: self.bias = None @@ -255,5 +276,5 @@ def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_i __all__ = [ "AwqGemmFn", - "AwqGEMMQuantLinear", + "AwqGEMMLinear", ] diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py index c9d02a0a6..63b555af6 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py @@ -14,9 +14,14 @@ from ...quantization import FORMAT, METHOD from ...utils import has_gil_disabled from ...utils.backend import BACKEND +from ...utils.env import env_flag from ...utils.torch import HAS_XPU +# Shared runtime default: prefer accuracy first unless the user explicitly opts out. +FP32_ACCUM = env_flag("GPTQMODEL_FP32_ACCUM", default=True) + + class AwqGemmTritonFn(torch.autograd.Function): @staticmethod def forward( @@ -48,7 +53,13 @@ def forward( out = torch.matmul(x, out.to(x.dtype)) else: out = awq_gemm_triton( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, + x.reshape(-1, x.shape[-1]), + qweight, + scales, + qzeros, + split_k_iters=8, + fp32_accum=FP32_ACCUM, + output_dtype=x.dtype, ) out = out + bias if bias is not None else out @@ -74,8 +85,8 @@ def backward(ctx, grad_output): return grad_input, None, None, None, None, None, None, None, None -class AwqGEMMTritonQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.GEMM_TRITON] +class AwqGEMMTritonLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_GEMM_TRITON] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMM: 50} SUPPORTS_BITS = [4] @@ -138,7 +149,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.TRITON), + backend=kwargs.pop("backend", BACKEND.AWQ_GEMM_TRITON), adapter=adapter, register_buffers=register_buffers, **kwargs) @@ -180,5 +191,5 @@ def forward(self, x: torch.Tensor): __all__ = [ "AwqGemmTritonFn", - "AwqGEMMTritonQuantLinear", + "AwqGEMMTritonLinear", ] diff --git a/gptqmodel/nn_modules/qlinear/gemv_awq.py b/gptqmodel/nn_modules/qlinear/gemv_awq.py index c80672cfb..aa34c7a36 100644 --- a/gptqmodel/nn_modules/qlinear/gemv_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemv_awq.py @@ -10,7 +10,7 @@ from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear from ...quantization import FORMAT, METHOD -from ...quantization.awq.utils.module import try_import +from ...utils.awq import awq_gemmv2_forward, awq_gemv_forward from ...utils.backend import BACKEND from ...utils.gemv import calculate_zeros_width from ...utils.logger import setup_logger @@ -18,10 +18,8 @@ log = setup_logger() -awq_ext, msg = try_import("gptqmodel_awq_kernels") - -class AwqGEMVQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.GEMV] +class AwqGEMVLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_GEMV] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMV: 40} SUPPORTS_BITS = [4] @@ -58,7 +56,7 @@ def __init__( register_buffers: bool = False, **kwargs, ): - backend = kwargs.pop("backend", BACKEND.GEMV) + backend = kwargs.pop("backend", BACKEND.AWQ_GEMV) super().__init__( bits=bits, group_size=group_size, @@ -119,9 +117,6 @@ def post_init(self): super().post_init() def forward(self, x: torch.Tensor): - if awq_ext is None: - raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg) - out_shape = x.shape[:-1] + (self.out_features,) inputs = x.reshape(-1, x.shape[-1]) @@ -130,7 +125,7 @@ def forward(self, x: torch.Tensor): inputs = inputs.half() if inputs.shape[0] > 8: - out = awq_ext.gemmv2_forward_cuda( + out = awq_gemmv2_forward( inputs, self.qweight, self.scales, @@ -139,7 +134,7 @@ def forward(self, x: torch.Tensor): self.split_k_iters, ) else: - out = awq_ext.gemv_forward_cuda( + out = awq_gemv_forward( inputs, self.qweight, self.scales, self.qzeros, self.group_size ) @@ -230,4 +225,4 @@ def extra_repr(self) -> str: ) ) -__all__ = ["AwqGEMVQuantLinear"] +__all__ = ["AwqGEMVLinear"] diff --git a/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py b/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py index efb82a3ce..97561b826 100644 --- a/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py @@ -10,15 +10,14 @@ from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear from ...quantization import FORMAT, METHOD -from ...quantization.awq.utils.module import try_import +from ...utils.awq import ( + awq_fast_gemm_forward_prefill, + awq_fast_gemv_forward_decode, + awq_runtime_available, + awq_runtime_error, +) from ...utils.backend import BACKEND from ...utils.gemv import calculate_zeros_width -from ...utils.logger import setup_logger - - -log = setup_logger() - -awq_v2_ext, msg = try_import("gptqmodel_awq_v2_kernels") def pack_intweight(unpacked_qweight, interleave, kstride): @@ -63,8 +62,8 @@ def pack_intweight(unpacked_qweight, interleave, kstride): return qweight -class AwqGEMVFastQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.GEMV_FAST] +class AwqGEMVFastLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_GEMV_FAST] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMV_FAST: 30} SUPPORTS_BITS = [4] @@ -103,7 +102,7 @@ def __init__( register_buffers: bool = False, **kwargs, ): - backend = kwargs.pop("backend", BACKEND.GEMV_FAST) + backend = kwargs.pop("backend", BACKEND.AWQ_GEMV_FAST) super().__init__( bits=bits, group_size=group_size, @@ -155,8 +154,8 @@ def __init__( self.bias = None def forward(self, x: torch.Tensor): - if awq_v2_ext is None: - raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed. Error: " + msg) + if not awq_runtime_available(): + raise ModuleNotFoundError("AWQ torch.ops kernels are not properly installed. Error: " + awq_runtime_error()) inputs = x inputs_dim = inputs.dim() @@ -170,11 +169,15 @@ def forward(self, x: torch.Tensor): input_dtype = inputs.dtype if input_dtype != torch.float16: - inputs = inputs.half() + inputs = inputs.to(dtype=torch.float16) + if not inputs.is_contiguous(): + inputs = inputs.contiguous() - zeros = getattr(self, self.zeros_name) + self._ensure_runtime_buffers(device=inputs.device, dtype=inputs.dtype) + + zeros = self._runtime_zeros() if inputs_dim == 3 and batch_size < 8 and n_tokens == 1: - out = awq_v2_ext.gemv_forward_cuda_decode( + out = awq_fast_gemv_forward_decode( inputs, self.qweight, self.scales, @@ -185,7 +188,7 @@ def forward(self, x: torch.Tensor): self.group_size, ) else: - out = awq_v2_ext.gemm_forward_cuda_prefill( + out = awq_fast_gemm_forward_prefill( inputs, self.qweight, self.scales, zeros ) @@ -199,6 +202,35 @@ def forward(self, x: torch.Tensor): return out + def _ensure_runtime_buffers(self, *, device: torch.device, dtype: torch.dtype): + if self.qweight.device != device or not self.qweight.is_contiguous(): + self.qweight = self.qweight.to(device=device).contiguous() + + zeros = self._runtime_zeros() + if zeros.device != device or zeros.dtype != dtype or not zeros.is_contiguous(): + zeros = zeros.to(device=device, dtype=dtype).contiguous() + if self.zeros_name == "qzeros": + self.qzeros = zeros + elif self.zeros_name == "scaled_zeros": + self.scaled_zeros = zeros + else: + raise ValueError(f"Unsupported zeros buffer: {self.zeros_name}") + + if self.scales.device != device or self.scales.dtype != dtype or not self.scales.is_contiguous(): + self.scales = self.scales.to(device=device, dtype=dtype).contiguous() + + if self.bias is not None and ( + self.bias.device != device or self.bias.dtype != dtype or not self.bias.is_contiguous() + ): + self.bias = self.bias.to(device=device, dtype=dtype).contiguous() + + def _runtime_zeros(self) -> torch.Tensor: + if self.zeros_name == "qzeros": + return self.qzeros + if self.zeros_name == "scaled_zeros": + return self.scaled_zeros + raise ValueError(f"Unsupported zeros buffer: {self.zeros_name}") + def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor = None): # need scales and zeros info for real quantization assert scales is not None and zeros is not None @@ -254,8 +286,8 @@ def extra_repr(self) -> str: ) -class LLMAwqQuantLinear(AwqGEMVFastQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.GEMV_FAST] +class LLMAwqLinear(AwqGEMVFastLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_GEMV_FAST] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.LLM_AWQ: 100} SUPPORTS_BITS = [4] @@ -281,4 +313,4 @@ class LLMAwqQuantLinear(AwqGEMVFastQuantLinear): zeros_name = "scaled_zeros" -__all__ = ["AwqGEMVFastQuantLinear", "LLMAwqQuantLinear"] +__all__ = ["AwqGEMVFastLinear", "LLMAwqLinear"] diff --git a/gptqmodel/nn_modules/qlinear/gguf.py b/gptqmodel/nn_modules/qlinear/gguf.py new file mode 100644 index 000000000..a0d930f68 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/gguf.py @@ -0,0 +1,1194 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math +import os +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import transformers +from torch.nn.modules.conv import _ConvNd + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...quantization.config import ( + FORMAT, + METHOD, + Fallback, + FallbackStrategy, + GGUFBits, + SmoothMethod, + _normalize_quant_bits, +) +from ...quantization.fallback_smooth import smooth_block +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from . import WeightOnlyQuantLinear + + +try: + import gguf as gguf_lib + + _GGUF_AVAILABLE = True +except Exception: # pragma: no cover - optional dependency + gguf_lib = None + _GGUF_AVAILABLE = False + + +setup_logger() + +_GGUF_TYPE_INFO = { + "Q1_0": {"bits": 1, "block_size": 32, "type_size": 6}, + "Q1_0_g128": {"bits": 1, "block_size": 128, "type_size": 18}, + "Q4_0": {"bits": 4, "block_size": 32, "type_size": 18}, + "Q8_0": {"bits": 8, "block_size": 32, "type_size": 34}, + "Q4_K": {"bits": 4, "block_size": 256, "type_size": 144}, + "Q5_K": {"bits": 5, "block_size": 256, "type_size": 176}, + "Q6_K": {"bits": 6, "block_size": 256, "type_size": 210}, +} +_GGUF_BITS_ALIAS_TO_TENSOR_QTYPE = { + "q1_0": "Q1_0", + "q1_0_g128": "Q1_0_g128", + "q4_0": "Q4_0", + "q8_0": "Q8_0", + "q4_k": "Q4_K", + "q4_k_s": "Q4_K", + "q4_k_m": "Q4_K", + "q5_k": "Q5_K", + "q5_k_s": "Q5_K", + "q5_k_m": "Q5_K", + "q6_k": "Q6_K", +} +_GGUF_SCALE_QUANT_MAX = 63 +_GGUF_Q6_SCALE_QUANT_MAX = 127 +_GGUF_K_QTYPES = {"Q4_K", "Q5_K", "Q6_K"} +PRISM_Q1_0_G128_NAME = "Q1_0_g128" +PRISM_Q1_0_G128_VALUE = 41 +PRISM_Q1_0_G128_BLOCK_SIZE = 128 +PRISM_Q1_0_G128_TYPE_SIZE = 18 +_GGUF_SIGN_ONLY_TYPE_INFO = { + "Q1_0": {"block_size": 32, "type_size": 6}, + PRISM_Q1_0_G128_NAME: { + "block_size": PRISM_Q1_0_G128_BLOCK_SIZE, + "type_size": PRISM_Q1_0_G128_TYPE_SIZE, + }, +} +_GGUF_TENSOR_QTYPE_BY_VALUE = { + 0: "F32", + 1: "F16", + 2: "Q4_0", + 8: "Q8_0", + 12: "Q4_K", + 13: "Q5_K", + 14: "Q6_K", + 30: "BF16", + 40: "Q1_0", + PRISM_Q1_0_G128_VALUE: PRISM_Q1_0_G128_NAME, +} +_GGUF_SIGN_ONLY_LUT = ( + np.unpackbits(np.arange(256, dtype=np.uint8)[:, None], axis=1, bitorder="little").astype(np.int8) * 2 - 1 +) +_GGUF_SIGN_ONLY_TORCH_LUT: dict[str, torch.Tensor] = {} + + +def _normalize_gguf_bits(bits) -> tuple[GGUFBits, str]: + bits_spec = _normalize_quant_bits(bits, format_value=FORMAT.GGUF) + tensor_qtype = _GGUF_BITS_ALIAS_TO_TENSOR_QTYPE.get(bits_spec.name) + if tensor_qtype is None: + supported = ", ".join(sorted(_GGUF_BITS_ALIAS_TO_TENSOR_QTYPE)) + raise ValueError(f"Unsupported GGUF bits `{bits}`. Supported values: {supported}.") + + qtype_info = _GGUF_TYPE_INFO[tensor_qtype] + if qtype_info["bits"] != bits_spec.width: + raise ValueError( + f"GGUF bits `{bits_spec.name}` require {qtype_info['bits']}-bit GGUF packing, but got bits={bits_spec.width}." + ) + + return bits_spec, tensor_qtype + + +def _apply_optional_smoother( + weight: torch.Tensor, + *, + smooth: SmoothMethod | None, + group_size: int, +) -> torch.Tensor: + if smooth is None: + return weight + + effective_group_size = weight.shape[1] if group_size == -1 else group_size + if effective_group_size <= 0: + effective_group_size = weight.shape[1] + + fallback = Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=smooth, + ) + smoothed = weight.clone() + + for start in range(0, weight.shape[1], effective_group_size): + end = min(start + effective_group_size, weight.shape[1]) + block, scale_factor = smooth_block( + smoothed[:, start:end], + fallback, + group_size=effective_group_size, + ) + if scale_factor is not None: + raise ValueError( + "GGUF direct packing does not support smoothers that require post-quant rescaling." + ) + smoothed[:, start:end] = block + + return smoothed + + +def _gguf_quantize_q4_0(blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + block_size = _GGUF_TYPE_INFO["Q4_0"]["block_size"] + + imax = np.abs(blocks).argmax(axis=-1, keepdims=True) + max_vals = np.take_along_axis(blocks, imax, axis=-1) + + d = max_vals / -8.0 + with np.errstate(divide="ignore"): + inv_d = np.where(d == 0, 0, 1.0 / d) + + # Match ggml's q4_0 reference path by truncating after the +8.5 offset. + qs = np.trunc((blocks.astype(np.float64) * inv_d.astype(np.float64)) + 8.5).astype(np.uint8) + qs = np.clip(qs, 0, 15) + qs = qs.reshape((n_blocks, 2, block_size // 2)) + qs = qs[:, 0, :] | (qs[:, 1, :] << np.uint8(4)) + + d = d.astype(np.float16).view(np.uint8) + return np.concatenate([d, qs], axis=-1) + + +def _gguf_quantize_q8_0(blocks: np.ndarray) -> np.ndarray: + d = np.abs(blocks).max(axis=-1, keepdims=True) / 127.0 + with np.errstate(divide="ignore"): + inv_d = np.where(d == 0, 0, 1.0 / d) + qs = np.round(blocks * inv_d).astype(np.int8).view(np.uint8) + d = d.astype(np.float16).view(np.uint8) + return np.concatenate([d, qs], axis=-1) + + +def _gguf_quantize_sign_only(blocks: np.ndarray, *, block_size: int) -> np.ndarray: + scales = np.mean(np.abs(blocks), axis=-1).astype(np.float16, copy=False) + sign_bits = np.packbits((blocks >= 0).astype(np.uint8, copy=False), axis=-1, bitorder="little") + + packed = np.empty((blocks.shape[0], 2 + (block_size // 8)), dtype=np.uint8) + packed[:, :2] = scales.view(np.uint8).reshape(-1, 2) + packed[:, 2:] = sign_bits + return packed + + +def _pack_q4_k_scale_min(scales: np.ndarray, mins: np.ndarray) -> np.ndarray: + scales = scales.astype(np.uint8, copy=False) + mins = mins.astype(np.uint8, copy=False) + + d = (scales[:, :4] & np.uint8(0x3F)) | ((scales[:, 4:] & np.uint8(0x30)) << np.uint8(2)) + m = (mins[:, :4] & np.uint8(0x3F)) | ((mins[:, 4:] & np.uint8(0x30)) << np.uint8(2)) + md = (scales[:, 4:] & np.uint8(0x0F)) | ((mins[:, 4:] & np.uint8(0x0F)) << np.uint8(4)) + + return np.concatenate([d, m, md], axis=-1) + + +def _unpack_q4_k_scale_min(scales: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + packed = scales.astype(np.uint8, copy=False).reshape((-1, 3, 4)) + d, m, md = np.split(packed, 3, axis=-2) + + sc = np.concatenate([d & 0x3F, (md & 0x0F) | ((d >> 2) & 0x30)], axis=-1) + mins = np.concatenate([m & 0x3F, (md >> 4) | ((m >> 2) & 0x30)], axis=-1) + return sc.reshape((-1, 8)), mins.reshape((-1, 8)) + + +def _unpack_q4_k_scale_min_torch(scales: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + packed = scales.reshape(*scales.shape[:-1], 3, 4) + d = packed[..., 0, :] + m = packed[..., 1, :] + md = packed[..., 2, :] + + sc = torch.cat((d & 0x3F, (md & 0x0F) | ((d >> 2) & 0x30)), dim=-1) + mins = torch.cat((m & 0x3F, (md >> 4) | ((m >> 2) & 0x30)), dim=-1) + return sc, mins + + +def _quantize_k_subblocks( + subblocks: np.ndarray, + *, + maxq: int, + scale_quant_max: int, + signed: bool, +) -> tuple[np.ndarray, ...]: + if signed: + scale = np.abs(subblocks).max(axis=-1) / maxq + base = scale.max(axis=-1, keepdims=True) / scale_quant_max + with np.errstate(divide="ignore", invalid="ignore"): + quant_scales = np.where(base > 0, np.rint(scale / base), 0.0) + quant_scales = np.clip(quant_scales, 0, scale_quant_max).astype(np.int32) + eff_scale = base * quant_scales.astype(np.float32) + with np.errstate(divide="ignore", invalid="ignore"): + q = np.where(eff_scale[..., None] > 0, np.rint(subblocks / eff_scale[..., None]), 0.0) + q = np.clip(q, -32, 31).astype(np.int8) + return base.astype(np.float16), quant_scales.astype(np.int8), q + + mins = np.maximum(-subblocks.min(axis=-1), 0.0) + scale = (subblocks.max(axis=-1) + mins) / maxq + + base = scale.max(axis=-1, keepdims=True) / scale_quant_max + min_base = mins.max(axis=-1, keepdims=True) / scale_quant_max + + with np.errstate(divide="ignore", invalid="ignore"): + quant_scales = np.where(base > 0, np.rint(scale / base), 0.0) + quant_mins = np.where(min_base > 0, np.rint(mins / min_base), 0.0) + + quant_scales = np.clip(quant_scales, 0, scale_quant_max).astype(np.uint8) + quant_mins = np.clip(quant_mins, 0, scale_quant_max).astype(np.uint8) + + eff_scale = base * quant_scales.astype(np.float32) + eff_min = min_base * quant_mins.astype(np.float32) + shifted = subblocks + eff_min[..., None] + with np.errstate(divide="ignore", invalid="ignore"): + q = np.where(eff_scale[..., None] > 0, np.rint(shifted / eff_scale[..., None]), 0.0) + q = np.clip(q, 0, maxq).astype(np.uint8) + + return base.astype(np.float16), min_base.astype(np.float16), quant_scales, quant_mins, q + + +def _gguf_quantize_q4_k(blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + subblocks = blocks.reshape(n_blocks, 8, 32) + d, dmin, sc, mins, q = _quantize_k_subblocks( + subblocks, + maxq=15, + scale_quant_max=_GGUF_SCALE_QUANT_MAX, + signed=False, + ) + scales = _pack_q4_k_scale_min(sc, mins) + q_pairs = q.reshape(n_blocks, 4, 2, 32) + qs = q_pairs[:, :, 0, :] | (q_pairs[:, :, 1, :] << np.uint8(4)) + return np.concatenate( + [ + d.view(np.uint8), + dmin.view(np.uint8), + scales, + qs.reshape(n_blocks, 128), + ], + axis=-1, + ) + + +def _gguf_quantize_q5_k(blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + subblocks = blocks.reshape(n_blocks, 8, 32) + d, dmin, sc, mins, q = _quantize_k_subblocks( + subblocks, + maxq=31, + scale_quant_max=_GGUF_SCALE_QUANT_MAX, + signed=False, + ) + scales = _pack_q4_k_scale_min(sc, mins) + q_pairs = q.reshape(n_blocks, 4, 2, 32) + qs = (q_pairs[:, :, 0, :] & np.uint8(0x0F)) | ((q_pairs[:, :, 1, :] & np.uint8(0x0F)) << np.uint8(4)) + qh = np.sum( + (((q >> np.uint8(4)) & np.uint8(0x01)).astype(np.uint16) << np.arange(8, dtype=np.uint16).reshape(1, 8, 1)), + axis=1, + dtype=np.uint16, + ).astype(np.uint8) + return np.concatenate( + [ + d.view(np.uint8), + dmin.view(np.uint8), + scales, + qh, + qs.reshape(n_blocks, 128), + ], + axis=-1, + ) + + +def _gguf_quantize_q6_k(blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + subblocks = blocks.reshape(n_blocks, 16, 16) + d, scales, q = _quantize_k_subblocks( + subblocks, + maxq=31, + scale_quant_max=_GGUF_Q6_SCALE_QUANT_MAX, + signed=True, + ) + q_raw = (q.astype(np.int16) + 32).astype(np.uint8).reshape(n_blocks, 8, 32) + + ql = np.empty((n_blocks, 128), dtype=np.uint8) + qh = np.empty((n_blocks, 64), dtype=np.uint8) + + for group in range(2): + base_q = group * 4 + base_ql = group * 64 + base_qh = group * 32 + + ql[:, base_ql : base_ql + 32] = ( + (q_raw[:, base_q + 0, :] & np.uint8(0x0F)) + | ((q_raw[:, base_q + 2, :] & np.uint8(0x0F)) << np.uint8(4)) + ) + ql[:, base_ql + 32 : base_ql + 64] = ( + (q_raw[:, base_q + 1, :] & np.uint8(0x0F)) + | ((q_raw[:, base_q + 3, :] & np.uint8(0x0F)) << np.uint8(4)) + ) + + qh[:, base_qh : base_qh + 32] = ( + ((q_raw[:, base_q + 0, :] >> np.uint8(4)) & np.uint8(0x03)) + | (((q_raw[:, base_q + 1, :] >> np.uint8(4)) & np.uint8(0x03)) << np.uint8(2)) + | (((q_raw[:, base_q + 2, :] >> np.uint8(4)) & np.uint8(0x03)) << np.uint8(4)) + | (((q_raw[:, base_q + 3, :] >> np.uint8(4)) & np.uint8(0x03)) << np.uint8(6)) + ) + + return np.concatenate( + [ + ql, + qh, + scales.view(np.uint8), + d.view(np.uint8), + ], + axis=-1, + ) + + +def _fallback_gguf_quantize(weight: np.ndarray, tensor_qtype: str) -> np.ndarray: + if weight.ndim != 2: + raise ValueError(f"GGUF quantization expects a 2D weight matrix, got shape {weight.shape}.") + qtype_info = _GGUF_TYPE_INFO[tensor_qtype] + block_size = qtype_info["block_size"] + if weight.shape[1] % block_size != 0: + raise ValueError( + f"GGUF quantization expects the input dimension to be divisible by {block_size}, got {weight.shape[1]}." + ) + + blocks = weight.reshape(-1, block_size) + if tensor_qtype in _GGUF_SIGN_ONLY_TYPE_INFO: + quantized_blocks = _gguf_quantize_sign_only(blocks, block_size=block_size) + elif tensor_qtype == "Q4_0": + quantized_blocks = _gguf_quantize_q4_0(blocks) + elif tensor_qtype == "Q8_0": + quantized_blocks = _gguf_quantize_q8_0(blocks) + elif tensor_qtype == "Q4_K": + quantized_blocks = _gguf_quantize_q4_k(blocks) + elif tensor_qtype == "Q5_K": + quantized_blocks = _gguf_quantize_q5_k(blocks) + elif tensor_qtype == "Q6_K": + quantized_blocks = _gguf_quantize_q6_k(blocks) + else: # pragma: no cover - guarded by class SUPPORTS_BITS + raise NotImplementedError(f"Unsupported GGUF qtype: {tensor_qtype}") + + bytes_per_block = qtype_info["type_size"] + rows = weight.shape[0] + return quantized_blocks.reshape(rows, (weight.shape[1] // block_size) * bytes_per_block) + + +def _gguf_quantize(weight: np.ndarray, tensor_qtype: str) -> np.ndarray: + return _quantize_gguf_tensor_numpy(weight, tensor_qtype) + + +def _resolve_gguf_tensor_qtype(tensor_type) -> str: + if isinstance(tensor_type, str): + normalized = tensor_type.strip() + if normalized in _GGUF_TYPE_INFO or normalized in _GGUF_SIGN_ONLY_TYPE_INFO: + return normalized + raise NotImplementedError(f"Unsupported GGUF qtype: {tensor_type}") + + tensor_name = getattr(tensor_type, "name", None) + if tensor_name in _GGUF_TYPE_INFO or tensor_name in _GGUF_SIGN_ONLY_TYPE_INFO: + return tensor_name + + try: + tensor_value = int(tensor_type) + except (TypeError, ValueError): + tensor_value = None + + if tensor_value is None: + raise NotImplementedError(f"Unsupported GGUF qtype: {tensor_type}") + + resolved = _GGUF_TENSOR_QTYPE_BY_VALUE.get(tensor_value) + if resolved is None: + raise NotImplementedError(f"Unsupported GGUF qtype value: {tensor_value}") + return resolved + + +def _is_prism_q1_0_g128(tensor_type) -> bool: + return _resolve_gguf_tensor_qtype(tensor_type) == PRISM_Q1_0_G128_NAME + + +def _dequantize_sign_only_numpy( + data: np.ndarray, + *, + block_size: int, + type_size: int, +) -> np.ndarray: + rows = np.asarray(data, dtype=np.uint8) + if rows.shape[-1] % type_size != 0: + raise ValueError( + f"GGUF sign-only row byte width must be divisible by {type_size}, got " + f"{rows.shape[-1]} for shape {rows.shape}." + ) + + n_blocks = rows.shape[-1] // type_size + blocks = rows.reshape(*rows.shape[:-1], n_blocks, type_size) + scales = np.ascontiguousarray(blocks[..., :2]).view(np.float16).astype(np.float32)[..., 0] + sign_bits = np.unpackbits(blocks[..., 2:], axis=-1, bitorder="little") + weights = np.where(sign_bits == 1, scales[..., None], -scales[..., None]).astype(np.float32, copy=False) + return weights.reshape(*rows.shape[:-1], n_blocks * block_size) + + +def _get_sign_only_torch_lut(device: torch.device) -> torch.Tensor: + key = str(device) + lut = _GGUF_SIGN_ONLY_TORCH_LUT.get(key) + if lut is None or lut.device != device: + lut = torch.from_numpy(_GGUF_SIGN_ONLY_LUT).to(device=device, dtype=torch.int8) + _GGUF_SIGN_ONLY_TORCH_LUT[key] = lut + return lut + + +def _dequantize_sign_only_torch( + data: np.ndarray | torch.Tensor, + *, + block_size: int, + type_size: int, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + if torch.is_tensor(data): + rows = data + if rows.dtype != torch.uint8: + rows = rows.to(dtype=torch.uint8) + else: + rows = torch.from_numpy(np.array(data, dtype=np.uint8, copy=True, order="C")) + + target_device = rows.device if device is None else torch.device(device) + if rows.device != target_device: + rows = rows.to(device=target_device, non_blocking=rows.device.type == "cpu" and target_device.type == "cuda") + if not rows.is_contiguous(): + rows = rows.contiguous() + + if rows.shape[-1] % type_size != 0: + raise ValueError( + f"GGUF sign-only row byte width must be divisible by {type_size}, got " + f"{rows.shape[-1]} for shape {tuple(rows.shape)}." + ) + + n_blocks = rows.shape[-1] // type_size + blocks = rows.reshape(*rows.shape[:-1], n_blocks, type_size) + scales = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1) + if scales.dtype != dtype: + scales = scales.to(dtype) + + sign_bytes = blocks[..., 2:].to(dtype=torch.long) + sign_lut = _get_sign_only_torch_lut(target_device) + signs = sign_lut[sign_bytes].reshape(*rows.shape[:-1], n_blocks, block_size) + if signs.dtype != dtype: + signs = signs.to(dtype) + + weights = scales.unsqueeze(-1) * signs + return weights.reshape(*rows.shape[:-1], n_blocks * block_size) + + +def _dequantize_prism_q1_0_g128(data: np.ndarray) -> np.ndarray: + return _dequantize_sign_only_numpy( + data, + block_size=PRISM_Q1_0_G128_BLOCK_SIZE, + type_size=PRISM_Q1_0_G128_TYPE_SIZE, + ) + + +def _dequantize_prism_q1_0_g128_torch( + data: np.ndarray | torch.Tensor, + *, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + return _dequantize_sign_only_torch( + data, + block_size=PRISM_Q1_0_G128_BLOCK_SIZE, + type_size=PRISM_Q1_0_G128_TYPE_SIZE, + device=device, + dtype=dtype, + ) + + +def _quantize_gguf_tensor_numpy(weight: np.ndarray, tensor_qtype) -> np.ndarray: + resolved_qtype = _resolve_gguf_tensor_qtype(tensor_qtype) + if _GGUF_AVAILABLE: + qtype = getattr(gguf_lib.GGMLQuantizationType, resolved_qtype, None) + try: + if qtype is not None: + return gguf_lib.quantize(weight, qtype) + except NotImplementedError: + pass + return _fallback_gguf_quantize(weight, resolved_qtype) + + +def _dequantize_q4_k_numpy(qweight: np.ndarray) -> np.ndarray: + rows = qweight.shape[0] + type_size = _GGUF_TYPE_INFO["Q4_K"]["type_size"] + blocks = qweight.reshape(-1, type_size) + + d = blocks[:, :2].view(np.float16).astype(np.float32) + dmin = blocks[:, 2:4].view(np.float16).astype(np.float32) + scales = blocks[:, 4:16] + qs = blocks[:, 16:] + + sc, mins = _unpack_q4_k_scale_min(scales) + d = (d * sc.astype(np.float32)).reshape((-1, 8, 1)) + dm = (dmin * mins.astype(np.float32)).reshape((-1, 8, 1)) + + q = qs.reshape((-1, 4, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + q = (q & np.uint8(0x0F)).reshape((-1, 8, 32)).astype(np.float32) + + return (d * q - dm).reshape(rows, -1) + + +def _dequantize_q5_k_numpy(qweight: np.ndarray) -> np.ndarray: + rows = qweight.shape[0] + type_size = _GGUF_TYPE_INFO["Q5_K"]["type_size"] + blocks = qweight.reshape(-1, type_size) + + d = blocks[:, :2].view(np.float16).astype(np.float32) + dmin = blocks[:, 2:4].view(np.float16).astype(np.float32) + scales = blocks[:, 4:16] + qh = blocks[:, 16:48] + qs = blocks[:, 48:] + + sc, mins = _unpack_q4_k_scale_min(scales) + d = (d * sc.astype(np.float32)).reshape((-1, 8, 1)) + dm = (dmin * mins.astype(np.float32)).reshape((-1, 8, 1)) + + ql = qs.reshape((-1, 4, 1, 32)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + qh = qh.reshape((-1, 1, 1, 32)) >> np.arange(8, dtype=np.uint8).reshape((1, 1, 8, 1)) + + ql = (ql & np.uint8(0x0F)).reshape((-1, 8, 32)) + qh = (qh & np.uint8(0x01)).reshape((-1, 8, 32)) + q = (ql | (qh << np.uint8(4))).astype(np.float32) + + return (d * q - dm).reshape(rows, -1) + + +def _dequantize_q6_k_numpy(qweight: np.ndarray) -> np.ndarray: + rows = qweight.shape[0] + type_size = _GGUF_TYPE_INFO["Q6_K"]["type_size"] + blocks = qweight.reshape(-1, type_size) + + ql = blocks[:, :128] + qh = blocks[:, 128:192] + scales = blocks[:, 192:208].view(np.int8).astype(np.float32) + d = blocks[:, 208:210].view(np.float16).astype(np.float32) + d = (d * scales).reshape((-1, 16, 1)) + + ql = ql.reshape((-1, 2, 1, 64)) >> np.array([0, 4], dtype=np.uint8).reshape((1, 1, 2, 1)) + ql = (ql & np.uint8(0x0F)).reshape((-1, 8, 32)) + qh = qh.reshape((-1, 2, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = (qh & np.uint8(0x03)).reshape((-1, 8, 32)) + + q = (ql | (qh << np.uint8(4))).astype(np.int16) - 32 + q = q.reshape((-1, 16, 16)).astype(np.float32) + + return (d * q).reshape(rows, -1) + + +def _dequantize_gguf_tensor_numpy(data: np.ndarray, tensor_type) -> np.ndarray: + resolved_qtype = _resolve_gguf_tensor_qtype(tensor_type) + + if resolved_qtype == "F32": + return np.asarray(data, dtype=np.float32) + if resolved_qtype == "F16": + return np.asarray(data, dtype=np.float16).astype(np.float32) + if resolved_qtype == "BF16": + rows = np.asarray(data, dtype=np.uint16).astype(np.uint32) + return np.left_shift(rows, np.uint32(16)).view(np.float32) + if resolved_qtype == "Q4_0": + rows = np.asarray(data, dtype=np.uint8) + type_size = _GGUF_TYPE_INFO["Q4_0"]["type_size"] + blocks = rows.reshape(-1, type_size) + d = blocks[:, :2].view(np.float16).astype(np.float32) + qs = blocks[:, 2:].reshape((-1, 1, 16)) + low = qs & np.uint8(0x0F) + high = qs >> np.uint8(4) + q = np.concatenate([low, high], axis=1).reshape((-1, 32)).astype(np.int16) - 8 + return (d * q.astype(np.float32)).reshape(rows.shape[0], -1) + if resolved_qtype == "Q8_0": + rows = np.asarray(data, dtype=np.uint8) + type_size = _GGUF_TYPE_INFO["Q8_0"]["type_size"] + blocks = rows.reshape(-1, type_size) + d = blocks[:, :2].view(np.float16).astype(np.float32) + q = blocks[:, 2:].view(np.int8).astype(np.float32) + return (d * q).reshape(rows.shape[0], -1) + if resolved_qtype == "Q4_K": + return _dequantize_q4_k_numpy(np.asarray(data, dtype=np.uint8)) + if resolved_qtype == "Q5_K": + return _dequantize_q5_k_numpy(np.asarray(data, dtype=np.uint8)) + if resolved_qtype == "Q6_K": + return _dequantize_q6_k_numpy(np.asarray(data, dtype=np.uint8)) + if resolved_qtype == "Q1_0": + return _dequantize_sign_only_numpy( + data, + block_size=_GGUF_SIGN_ONLY_TYPE_INFO["Q1_0"]["block_size"], + type_size=_GGUF_SIGN_ONLY_TYPE_INFO["Q1_0"]["type_size"], + ) + if resolved_qtype == PRISM_Q1_0_G128_NAME: + return _dequantize_prism_q1_0_g128(np.asarray(data, dtype=np.uint8)) + + raise NotImplementedError(f"Unsupported GGUF qtype: {resolved_qtype}") + + +class GGUFTorchLinear(WeightOnlyQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GGUF_TORCH] + SUPPORTS_METHODS = [METHOD.GGUF] + SUPPORTS_FORMATS = {FORMAT.GGUF: 15} + SUPPORTS_BITS = [1, 4, 5, 6, 8] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = True + SUPPORTS_AUTO_PADDING = True + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + + SUPPORTS_DEVICES = [DEVICE.ALL] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + + REQUIRES_FORMAT_V2 = False + AUTOTUNE = True + + QUANT_TYPE = "gguf" + GGUF_FUSED_CUDA_MAX_ROWS = max(0, int(os.environ.get("GPTQMODEL_GGUF_FUSED_CUDA_MAX_ROWS", "32"))) + GGUF_FUSED_CUDA_MIN_MATRIX_ELEMENTS = max( + 0, + int(os.environ.get("GPTQMODEL_GGUF_FUSED_CUDA_MIN_MATRIX_ELEMENTS", "8388608")), + ) + GGUF_FUSED_CPU_MAX_ROWS = max(0, int(os.environ.get("GPTQMODEL_GGUF_FUSED_CPU_MAX_ROWS", "64"))) + GGUF_FUSED_CPU_MIN_MATRIX_ELEMENTS = max( + 0, + int(os.environ.get("GPTQMODEL_GGUF_FUSED_CPU_MIN_MATRIX_ELEMENTS", "0")), + ) + GGUF_FUSED_CHUNK_BLOCKS = max(1, int(os.environ.get("GPTQMODEL_GGUF_FUSED_CHUNK_BLOCKS", "8"))) + GGUF_FUSED_AUTOTUNE_WARMUP = max(0, int(os.environ.get("GPTQMODEL_GGUF_FUSED_AUTOTUNE_WARMUP", "1"))) + GGUF_FUSED_AUTOTUNE_ITERS = max(1, int(os.environ.get("GPTQMODEL_GGUF_FUSED_AUTOTUNE_ITERS", "2"))) + GGUF_FUSED_AUTOTUNE_MARGIN = max(0.0, float(os.environ.get("GPTQMODEL_GGUF_FUSED_AUTOTUNE_MARGIN", "0.05"))) + + def __init__( + self, + bits, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + bits_spec, self.gguf_tensor_qtype = _normalize_gguf_bits(bits) + qtype_info = _GGUF_TYPE_INFO[self.gguf_tensor_qtype] + self.gguf_block_size = qtype_info["block_size"] + self.gguf_type_size = qtype_info["type_size"] + self.padded_in_features = math.ceil(in_features / self.gguf_block_size) * self.gguf_block_size + self.gguf_fused_cuda_max_rows = self.GGUF_FUSED_CUDA_MAX_ROWS + self.gguf_fused_cuda_min_matrix_elements = self.GGUF_FUSED_CUDA_MIN_MATRIX_ELEMENTS + self.gguf_fused_cpu_max_rows = self.GGUF_FUSED_CPU_MAX_ROWS + self.gguf_fused_cpu_min_matrix_elements = self.GGUF_FUSED_CPU_MIN_MATRIX_ELEMENTS + self.gguf_fused_chunk_blocks = self.GGUF_FUSED_CHUNK_BLOCKS + self.gguf_fused_autotune_warmup = self.GGUF_FUSED_AUTOTUNE_WARMUP + self.gguf_fused_autotune_iters = self.GGUF_FUSED_AUTOTUNE_ITERS + self.gguf_fused_autotune_margin = self.GGUF_FUSED_AUTOTUNE_MARGIN + + super().__init__( + bits=int(bits_spec), + in_features=in_features, + out_features=out_features, + bias=bias, + backend=kwargs.pop("backend", BACKEND.GGUF_TORCH), + adapter=adapter, + register_buffers=False, + pack_dtype=pack_dtype, + **kwargs, + ) + + self.bits = bits_spec + + if register_buffers: + self._allocate_buffers(bias=bias) + + def _bytes_per_row(self) -> int: + return (self.padded_in_features // self.gguf_block_size) * self.gguf_type_size + + def smooth_block_size(self) -> int: + return self.gguf_block_size + + def _allocate_buffers(self, *, bias: bool) -> None: + bytes_per_row = self._bytes_per_row() + qweight = torch.zeros((self.out_features, bytes_per_row), dtype=torch.uint8) + if "qweight" in self._buffers: + self.qweight = qweight + else: + self.register_buffer("qweight", qweight) + + if bias: + bias_tensor = torch.zeros(self.out_features, dtype=torch.float16) + if "bias" in self._buffers: + self.bias = bias_tensor + else: + self.register_buffer("bias", bias_tensor) + else: + self.bias = None + + def clear_weight_cache(self) -> None: + return None + + def post_init(self): + self.clear_weight_cache() + super().post_init() + + def train(self, mode: bool = True): + self.clear_weight_cache() + return super().train(mode=mode) + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"bias={self.bias is not None}, bits={self.bits}" + ) + + def _weight_to_matrix(self, linear: nn.Module) -> torch.Tensor: + weight = linear.weight.detach() + if isinstance(linear, _ConvNd): + weight = weight.flatten(1) + if isinstance(linear, transformers.pytorch_utils.Conv1D): + weight = weight.T + return weight + + def _pack_weight_tensor( + self, + linear: nn.Module, + *, + smooth: SmoothMethod | None = None, + ) -> torch.Tensor: + weight = self._weight_to_matrix(linear).to(device="cpu", dtype=torch.float32) + weight = _apply_optional_smoother( + weight, + smooth=smooth, + group_size=self.smooth_block_size(), + ) + if weight.shape[1] != self.padded_in_features: + weight = torch.nn.functional.pad(weight, (0, self.padded_in_features - weight.shape[1])) + + quantized = _gguf_quantize(weight.contiguous().numpy(), self.gguf_tensor_qtype) + return torch.from_numpy(np.ascontiguousarray(quantized)).to(torch.uint8) + + def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor = None): + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + def pack_block( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + block_in: int = 8192, + workers: int = 1, + ): + del block_in, workers + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + def pack_gpu( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + *, + block_in: int = 8192, + device: torch.device | None = None, + ): + del block_in, device + self.pack_original(linear=linear, scales=scales, zeros=zeros, g_idx=g_idx) + + @torch.inference_mode() + def pack_original( + self, + linear: nn.Module, + scales: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor = None, + *, + smooth: SmoothMethod | None = None, + ): + del scales, zeros, g_idx + + qweight = self._pack_weight_tensor(linear, smooth=smooth) + expected_shape = (self.out_features, self._bytes_per_row()) + if tuple(qweight.shape) != expected_shape: + raise RuntimeError( + f"{self.__class__.__name__} produced an invalid GGUF packed shape {tuple(qweight.shape)}; " + f"expected {expected_shape} for padded_in_features={self.padded_in_features}." + ) + if "qweight" in self._buffers: + self.qweight = qweight + else: + self.register_buffer("qweight", qweight) + + if linear.bias is not None: + bias = linear.bias.detach().to(device="cpu", dtype=torch.float16) + if "bias" in self._buffers: + self.bias = bias + else: + self.register_buffer("bias", bias) + else: + self.bias = None + + self.clear_autotune() + self.clear_weight_cache() + + def _resolve_dequant_target( + self, + *, + device: torch.device | str | None, + dtype: torch.dtype | None, + ) -> tuple[torch.device, torch.dtype]: + target_device = self.qweight.device if device is None else torch.device(device) + target_dtype = torch.float32 if dtype is None else dtype + if target_dtype not in self.SUPPORTS_DTYPES: + supported = ", ".join(str(dt).removeprefix("torch.") for dt in self.SUPPORTS_DTYPES) + raise ValueError( + f"{self.__class__.__name__} only supports GGUF dequantization dtypes {{{supported}}}, got `{target_dtype}`." + ) + return target_device, target_dtype + + def _reshape_blocks( + self, + *, + device: torch.device | str | None = None, + ) -> tuple[torch.Tensor, int, int]: + target_device = self.qweight.device if device is None else torch.device(device) + qweight = self.qweight if self.qweight.device == target_device else self.qweight.to(device=target_device) + rows = qweight.shape[0] + num_blocks = qweight.shape[1] // self.gguf_type_size + blocks = qweight.contiguous().view(rows, num_blocks, self.gguf_type_size) + return blocks, rows, num_blocks + + @staticmethod + def _u8_shift(values: tuple[int, ...], device: torch.device) -> torch.Tensor: + return torch.tensor(values, dtype=torch.uint8, device=device) + + def _dequantize_q4_0( + self, + *, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> torch.Tensor: + target_device, target_dtype = self._resolve_dequant_target(device=device, dtype=dtype) + blocks, rows, _ = self._reshape_blocks(device=target_device) + + d = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1) + if d.dtype != target_dtype: + d = d.to(target_dtype) + + qs = blocks[..., 2:] + low = torch.bitwise_and(qs, 0x0F) + high = torch.bitwise_right_shift(qs, 4) + values = torch.cat((low, high), dim=-1).to(torch.int16) - 8 + + weight = d.unsqueeze(-1) * values.to(target_dtype) + return weight.reshape(rows, self.padded_in_features) + + def _dequantize_q8_0( + self, + *, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> torch.Tensor: + target_device, target_dtype = self._resolve_dequant_target(device=device, dtype=dtype) + blocks, rows, _ = self._reshape_blocks(device=target_device) + + d = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1) + if d.dtype != target_dtype: + d = d.to(target_dtype) + + x = blocks[..., 2:].contiguous().view(torch.int8).to(target_dtype) + + weight = d.unsqueeze(-1) * x + return weight.reshape(rows, self.padded_in_features) + + def _dequantize_sign_only( + self, + *, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> torch.Tensor: + target_device, target_dtype = self._resolve_dequant_target(device=device, dtype=dtype) + weight = _dequantize_sign_only_torch( + self.qweight, + block_size=self.gguf_block_size, + type_size=self.gguf_type_size, + device=target_device, + dtype=target_dtype, + ) + return weight.reshape(self.out_features, self.padded_in_features) + + def _dequantize_numpy( + self, + fn, + *, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> torch.Tensor: + target_device, target_dtype = self._resolve_dequant_target(device=device, dtype=dtype) + qweight = self.qweight.detach().cpu().numpy() + weight = fn(qweight) + tensor = torch.from_numpy(np.ascontiguousarray(weight)) + if tensor.device != target_device or tensor.dtype != target_dtype: + tensor = tensor.to(device=target_device, dtype=target_dtype) + return tensor + + def _dequantize_q4_k_blocks(self, blocks: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + rows, num_blocks = blocks.shape[0], blocks.shape[1] + + d = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1).to(target_dtype) + dmin = blocks[..., 2:4].contiguous().view(torch.float16).squeeze(-1).to(target_dtype) + scales = blocks[..., 4:16] + qs = blocks[..., 16:] + + sc, mins = _unpack_q4_k_scale_min_torch(scales) + d = d.unsqueeze(-1) * sc.to(target_dtype) + dm = dmin.unsqueeze(-1) * mins.to(target_dtype) + + q = qs.reshape(rows, num_blocks, 4, 1, 32) + q = torch.bitwise_right_shift( + q, + self._u8_shift((0, 4), device=blocks.device).view(1, 1, 1, 2, 1), + ) + q = torch.bitwise_and(q, 0x0F).reshape(rows, num_blocks, 8, 32) + + return (d.unsqueeze(-1) * q.to(target_dtype) - dm.unsqueeze(-1)).reshape(rows, num_blocks * 256) + + def _dequantize_q5_k_blocks(self, blocks: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + rows, num_blocks = blocks.shape[0], blocks.shape[1] + + d = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1).to(target_dtype) + dmin = blocks[..., 2:4].contiguous().view(torch.float16).squeeze(-1).to(target_dtype) + scales = blocks[..., 4:16] + qh = blocks[..., 16:48] + qs = blocks[..., 48:] + + sc, mins = _unpack_q4_k_scale_min_torch(scales) + d = d.unsqueeze(-1) * sc.to(target_dtype) + dm = dmin.unsqueeze(-1) * mins.to(target_dtype) + + ql = qs.reshape(rows, num_blocks, 4, 1, 32) + ql = torch.bitwise_right_shift( + ql, + self._u8_shift((0, 4), device=blocks.device).view(1, 1, 1, 2, 1), + ) + ql = torch.bitwise_and(ql, 0x0F).reshape(rows, num_blocks, 8, 32) + + qh = torch.bitwise_right_shift( + qh.unsqueeze(-2), + self._u8_shift(tuple(range(8)), device=blocks.device).view(1, 1, 8, 1), + ) + qh = torch.bitwise_and(qh, 0x01).reshape(rows, num_blocks, 8, 32) + q = torch.bitwise_or(ql, torch.bitwise_left_shift(qh, 4)) + + return (d.unsqueeze(-1) * q.to(target_dtype) - dm.unsqueeze(-1)).reshape(rows, num_blocks * 256) + + def _dequantize_q6_k_blocks(self, blocks: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: + rows, num_blocks = blocks.shape[0], blocks.shape[1] + + ql = blocks[..., :128] + qh = blocks[..., 128:192] + scales = blocks[..., 192:208].contiguous().view(torch.int8).to(target_dtype) + d = blocks[..., 208:210].contiguous().view(torch.float16).squeeze(-1).to(target_dtype) + d = (d.unsqueeze(-1) * scales).reshape(rows, num_blocks, 16, 1) + + ql = ql.reshape(rows, num_blocks, 2, 1, 64) + ql = torch.bitwise_right_shift( + ql, + self._u8_shift((0, 4), device=blocks.device).view(1, 1, 1, 2, 1), + ) + ql = torch.bitwise_and(ql, 0x0F).reshape(rows, num_blocks, 8, 32) + + qh = qh.reshape(rows, num_blocks, 2, 1, 32) + qh = torch.bitwise_right_shift( + qh, + self._u8_shift((0, 2, 4, 6), device=blocks.device).view(1, 1, 1, 4, 1), + ) + qh = torch.bitwise_and(qh, 0x03).reshape(rows, num_blocks, 8, 32) + + q = torch.bitwise_or(ql, torch.bitwise_left_shift(qh, 4)).to(torch.int16) - 32 + q = q.reshape(rows, num_blocks, 16, 16).to(target_dtype) + + return (d * q).reshape(rows, num_blocks * 256) + + def dequantize_weight( + self, + *, + device: torch.device | str | None = None, + dtype: torch.dtype | None = None, + ) -> torch.Tensor: + if self.gguf_tensor_qtype == "Q4_0": + weight = self._dequantize_q4_0(device=device, dtype=dtype) + elif self.gguf_tensor_qtype == "Q8_0": + weight = self._dequantize_q8_0(device=device, dtype=dtype) + elif self.gguf_tensor_qtype in _GGUF_SIGN_ONLY_TYPE_INFO: + weight = self._dequantize_sign_only(device=device, dtype=dtype) + elif self.gguf_tensor_qtype == "Q4_K": + target_device, target_dtype = self._resolve_dequant_target(device=device, dtype=dtype) + blocks, _, _ = self._reshape_blocks(device=target_device) + weight = self._dequantize_q4_k_blocks(blocks, target_dtype) + elif self.gguf_tensor_qtype == "Q5_K": + target_device, target_dtype = self._resolve_dequant_target(device=device, dtype=dtype) + blocks, _, _ = self._reshape_blocks(device=target_device) + weight = self._dequantize_q5_k_blocks(blocks, target_dtype) + elif self.gguf_tensor_qtype == "Q6_K": + target_device, target_dtype = self._resolve_dequant_target(device=device, dtype=dtype) + blocks, _, _ = self._reshape_blocks(device=target_device) + weight = self._dequantize_q6_k_blocks(blocks, target_dtype) + else: # pragma: no cover - guarded by class SUPPORTS_BITS + raise NotImplementedError(f"Unsupported GGUF qtype: {self.gguf_tensor_qtype}") + + return weight[:, : self.in_features].transpose(0, 1).contiguous() + + def _is_fused_k_forward_candidate(self, x_flat: torch.Tensor) -> bool: + if x_flat.device.type == "cuda": + max_rows = self.gguf_fused_cuda_max_rows + min_matrix_elements = self.gguf_fused_cuda_min_matrix_elements + elif x_flat.device.type == "cpu": + max_rows = self.gguf_fused_cpu_max_rows + min_matrix_elements = self.gguf_fused_cpu_min_matrix_elements + else: + return False + + return ( + self.gguf_tensor_qtype in (_GGUF_K_QTYPES | set(_GGUF_SIGN_ONLY_TYPE_INFO)) + and self.adapter is None + and not self.training + and max_rows > 0 + and (self.in_features * self.out_features) >= min_matrix_elements + and x_flat.shape[0] <= max_rows + ) + + @staticmethod + def _sync_benchmark_device(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device=device) + + def _benchmark_forward_runner(self, fn, *, device: torch.device) -> float: + with torch.inference_mode(): + for _ in range(self.gguf_fused_autotune_warmup): + fn() + self._sync_benchmark_device(device) + + start = time.perf_counter() + for _ in range(self.gguf_fused_autotune_iters): + fn() + self._sync_benchmark_device(device) + + return (time.perf_counter() - start) / self.gguf_fused_autotune_iters + + def _benchmark_dense_forward(self, x_flat: torch.Tensor) -> float: + return self._benchmark_forward_runner( + lambda: self._forward_dequant_matmul(x_flat), + device=x_flat.device, + ) + + def _benchmark_fused_forward(self, x_flat: torch.Tensor) -> float: + return self._benchmark_forward_runner( + lambda: self._forward_fused_k(x_flat), + device=x_flat.device, + ) + + def _autotune(self, x_flat: torch.Tensor) -> bool: + try: + fused_time = self._benchmark_fused_forward(x_flat) + dense_time = self._benchmark_dense_forward(x_flat) + return fused_time <= dense_time * (1.0 - self.gguf_fused_autotune_margin) + except Exception: + return False + + def _should_use_fused_k_forward(self, x_flat: torch.Tensor) -> bool: + if not self._is_fused_k_forward_candidate(x_flat): + return False + + if not self.autotune_enabled: + return True + + return bool(self.maybe_autotune(x_flat)) + + def _forward_dequant_matmul(self, x_flat: torch.Tensor) -> torch.Tensor: + weight = self.dequantize_weight(device=x_flat.device, dtype=x_flat.dtype) + return torch.matmul(x_flat, weight) + + def _forward_fused_k(self, x_flat: torch.Tensor) -> torch.Tensor: + target_dtype = x_flat.dtype + blocks, _, num_blocks = self._reshape_blocks(device=x_flat.device) + + if x_flat.shape[-1] != self.padded_in_features: + x_work = F.pad(x_flat, (0, self.padded_in_features - x_flat.shape[-1])) + else: + x_work = x_flat + + output = torch.zeros((x_flat.shape[0], self.out_features), device=x_flat.device, dtype=target_dtype) + + for start in range(0, num_blocks, self.gguf_fused_chunk_blocks): + end = min(start + self.gguf_fused_chunk_blocks, num_blocks) + block_chunk = blocks[:, start:end, :] + + if self.gguf_tensor_qtype == "Q4_K": + weight_chunk = self._dequantize_q4_k_blocks(block_chunk, target_dtype) + elif self.gguf_tensor_qtype == "Q5_K": + weight_chunk = self._dequantize_q5_k_blocks(block_chunk, target_dtype) + elif self.gguf_tensor_qtype == "Q6_K": + weight_chunk = self._dequantize_q6_k_blocks(block_chunk, target_dtype) + elif self.gguf_tensor_qtype in _GGUF_SIGN_ONLY_TYPE_INFO: + weight_chunk = _dequantize_sign_only_torch( + block_chunk.reshape(block_chunk.shape[0], -1), + block_size=self.gguf_block_size, + type_size=self.gguf_type_size, + device=x_flat.device, + dtype=target_dtype, + ) + else: # pragma: no cover - guarded by _should_use_fused_k_forward + raise NotImplementedError(f"Unsupported GGUF fused qtype: {self.gguf_tensor_qtype}") + + x_chunk = x_work[:, start * self.gguf_block_size : end * self.gguf_block_size] + output = output + torch.matmul(x_chunk, weight_chunk.transpose(0, 1)) + + return output + + def forward(self, x: torch.Tensor): + original_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + + if self._should_use_fused_k_forward(x_flat): + output = self._forward_fused_k(x_flat) + else: + output = self._forward_dequant_matmul(x_flat) + + if self.bias is not None: + bias = self.bias + if bias.device != output.device or bias.dtype != output.dtype: + bias = bias.to(device=output.device, dtype=output.dtype) + output = output + bias + + if self.adapter: + output = self.adapter.apply(x=x_flat, out=output) + + return output.reshape(original_shape) + + +__all__ = ["GGUFTorchLinear"] diff --git a/gptqmodel/nn_modules/qlinear/gguf_cpp.py b/gptqmodel/nn_modules/qlinear/gguf_cpp.py new file mode 100644 index 000000000..8620a0036 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/gguf_cpp.py @@ -0,0 +1,773 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import ctypes +import os +from pathlib import Path +from typing import Dict, Optional, Tuple + +import torch +import torch.nn.functional as F + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...quantization.config import FORMAT, METHOD +from ...utils.backend import BACKEND +from .gguf import GGUFTorchLinear + + +try: + import llama_cpp as _llama_cpp_pkg + from llama_cpp import llama_cpp as _llama_cpp_lib + from llama_cpp._ctypes_extensions import load_shared_library as _llama_cpp_load_shared_library + + _LLAMA_CPP_IMPORT_ERROR = None +except Exception as exc: # pragma: no cover - optional dependency + _llama_cpp_pkg = None + _llama_cpp_lib = None + _llama_cpp_load_shared_library = None + _LLAMA_CPP_IMPORT_ERROR = exc + + +class _GGMLInitParams(ctypes.Structure): + _fields_ = [ + ("mem_size", ctypes.c_size_t), + ("mem_buffer", ctypes.c_void_p), + ("no_alloc", ctypes.c_bool), + ] + + +class _GGMLMatmulPlan: + def __init__( + self, + *, + ctx, + buffer, + weight_tensor, + input_tensor, + output_tensor, + graph, + rows: int, + out_features: int, + output_nbytes: int, + output_dtype: torch.dtype, + backend_buffer_free, + ctx_free, + ) -> None: + self.ctx = ctx + self.buffer = buffer + self.weight_tensor = weight_tensor + self.input_tensor = input_tensor + self.output_tensor = output_tensor + self.graph = graph + self.rows = rows + self.out_features = out_features + self.output_nbytes = output_nbytes + self.output_dtype = output_dtype + self._backend_buffer_free = backend_buffer_free + self._ctx_free = ctx_free + + def close(self) -> None: + if self.buffer: + self._backend_buffer_free(self.buffer) + self.buffer = None + if self.ctx: + self._ctx_free(self.ctx) + self.ctx = None + + def __del__(self) -> None: # pragma: no cover - best effort cleanup + try: + self.close() + except Exception: + # Destructors must not raise during GC or interpreter shutdown. + pass + + +class _GGMLBridge: + GGML_METADATA_BYTES = 1 << 20 + + def __init__(self) -> None: + if _LLAMA_CPP_IMPORT_ERROR is not None: + raise ModuleNotFoundError( + "GGUFCppKernel requires `llama-cpp-python` to be installed." + ) from _LLAMA_CPP_IMPORT_ERROR + + lib_dir = Path(_llama_cpp_pkg.__file__).resolve().parent / "lib" + self._ggml_base = _llama_cpp_load_shared_library("ggml-base", lib_dir) + self._ggml_cpu = _llama_cpp_load_shared_library("ggml-cpu", lib_dir) + self._ggml_cuda = None + self._ggml_cuda_error: Optional[Exception] = None + try: + self._ggml_cuda = _llama_cpp_load_shared_library("ggml-cuda", lib_dir) + except Exception as exc: # pragma: no cover - optional shared library + self._ggml_cuda_error = exc + self._bind_functions() + + self.ggml_type_f32 = int(_llama_cpp_lib.GGML_TYPE_F32) + self.ggml_type_f16 = int(_llama_cpp_lib.GGML_TYPE_F16) + self.ggml_qtypes = { + "Q4_0": int(_llama_cpp_lib.GGML_TYPE_Q4_0), + "Q8_0": int(_llama_cpp_lib.GGML_TYPE_Q8_0), + "Q4_K": int(_llama_cpp_lib.GGML_TYPE_Q4_K), + "Q5_K": int(_llama_cpp_lib.GGML_TYPE_Q5_K), + "Q6_K": int(_llama_cpp_lib.GGML_TYPE_Q6_K), + } + self._cpu_backend: Optional[int] = None + self._cuda_backends: Dict[int, int] = {} + + def _bind(self, lib, name: str, argtypes, restype) -> None: + fn = getattr(lib, name) + fn.argtypes = argtypes + fn.restype = restype + + def _bind_functions(self) -> None: + self._bind(self._ggml_base, "ggml_init", [_GGMLInitParams], ctypes.c_void_p) + self._bind(self._ggml_base, "ggml_free", [ctypes.c_void_p], None) + self._bind( + self._ggml_base, + "ggml_new_tensor_2d", + [ctypes.c_void_p, ctypes.c_int, ctypes.c_int64, ctypes.c_int64], + ctypes.c_void_p, + ) + self._bind(self._ggml_base, "ggml_mul_mat", [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p], ctypes.c_void_p) + self._bind(self._ggml_base, "ggml_set_input", [ctypes.c_void_p], None) + self._bind(self._ggml_base, "ggml_new_graph", [ctypes.c_void_p], ctypes.c_void_p) + self._bind(self._ggml_base, "ggml_build_forward_expand", [ctypes.c_void_p, ctypes.c_void_p], None) + self._bind(self._ggml_base, "ggml_nbytes", [ctypes.c_void_p], ctypes.c_size_t) + self._bind(self._ggml_base, "ggml_element_size", [ctypes.c_void_p], ctypes.c_size_t) + self._bind(self._ggml_base, "ggml_backend_alloc_ctx_tensors", [ctypes.c_void_p, ctypes.c_void_p], ctypes.c_void_p) + self._bind(self._ggml_base, "ggml_backend_buffer_free", [ctypes.c_void_p], None) + self._bind(self._ggml_base, "ggml_backend_free", [ctypes.c_void_p], None) + self._bind( + self._ggml_base, + "ggml_backend_tensor_set", + [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t], + None, + ) + self._bind( + self._ggml_base, + "ggml_backend_tensor_get", + [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_size_t], + None, + ) + self._bind(self._ggml_base, "ggml_backend_graph_compute", [ctypes.c_void_p, ctypes.c_void_p], ctypes.c_int) + self._bind(self._ggml_base, "ggml_backend_synchronize", [ctypes.c_void_p], None) + self._bind(self._ggml_cpu, "ggml_backend_cpu_init", [], ctypes.c_void_p) + self._bind(self._ggml_cpu, "ggml_backend_cpu_set_n_threads", [ctypes.c_void_p, ctypes.c_int], None) + if self._ggml_cuda is not None: + self._bind(self._ggml_cuda, "ggml_backend_cuda_init", [ctypes.c_int], ctypes.c_void_p) + self._bind(self._ggml_cuda, "ggml_backend_cuda_get_device_count", [], ctypes.c_int) + + @staticmethod + def _cpu_threads() -> int: + override = os.environ.get("GPTQMODEL_GGUF_CPP_THREADS") + if override is not None: + try: + return max(1, int(override)) + except ValueError: + pass + return max(1, torch.get_num_threads()) + + def cpu_available(self) -> Tuple[bool, Optional[Exception]]: + try: + self._get_cpu_backend() + except Exception as exc: + return False, exc + return True, None + + def cuda_available(self) -> Tuple[bool, Optional[Exception]]: + try: + if self._ggml_cuda is None: + raise RuntimeError("llama-cpp-python was built without GGML CUDA support.") + if not torch.cuda.is_available(): + raise RuntimeError("Torch CUDA is unavailable.") + device_count = int(self._ggml_cuda.ggml_backend_cuda_get_device_count()) + if device_count <= 0: + raise RuntimeError("GGML CUDA backend found no CUDA devices.") + self._get_cuda_backend(0) + except Exception as exc: + return False, exc + return True, None + + def _get_cpu_backend(self) -> int: + if self._cpu_backend is None: + backend = self._ggml_cpu.ggml_backend_cpu_init() + if not backend: + raise RuntimeError("GGUFCppKernel failed to initialize GGML CPU backend.") + self._ggml_cpu.ggml_backend_cpu_set_n_threads(backend, self._cpu_threads()) + self._cpu_backend = backend + return self._cpu_backend + + def _get_cuda_backend(self, device_index: int) -> int: + if self._ggml_cuda is None: + raise RuntimeError("llama-cpp-python was built without GGML CUDA support.") from self._ggml_cuda_error + if device_index < 0: + device_index = 0 + device_count = int(self._ggml_cuda.ggml_backend_cuda_get_device_count()) + if device_index >= device_count: + raise RuntimeError( + f"GGML CUDA backend device index `{device_index}` is out of range for `{device_count}` devices." + ) + if device_index not in self._cuda_backends: + backend = self._ggml_cuda.ggml_backend_cuda_init(device_index) + if not backend: + raise RuntimeError(f"GGUFCudaKernel failed to initialize GGML CUDA backend for device `{device_index}`.") + self._cuda_backends[device_index] = backend + return self._cuda_backends[device_index] + + @staticmethod + def _normalize_qweight_cpu(qweight: torch.Tensor) -> torch.Tensor: + qweight_cpu = qweight.detach() + if qweight_cpu.device.type != "cpu": + qweight_cpu = qweight_cpu.to(device="cpu") + if not qweight_cpu.is_contiguous(): + qweight_cpu = qweight_cpu.contiguous() + return qweight_cpu + + @staticmethod + def _normalize_input_cpu(x: torch.Tensor, padded_in_features: int) -> torch.Tensor: + x_cpu = x.detach().to(device="cpu", dtype=torch.float32) + if x_cpu.shape[-1] != padded_in_features: + x_cpu = F.pad(x_cpu, (0, padded_in_features - x_cpu.shape[-1])) + if not x_cpu.is_contiguous(): + x_cpu = x_cpu.contiguous() + return x_cpu + + def _normalize_input_cuda( + self, + x: torch.Tensor, + padded_in_features: int, + ) -> tuple[torch.Tensor, int]: + x_cuda = x.detach() + if x_cuda.dtype == torch.float16: + ggml_input_type = self.ggml_type_f16 + elif x_cuda.dtype == torch.float32: + ggml_input_type = self.ggml_type_f32 + elif x_cuda.dtype == torch.bfloat16: + # ggml in llama-cpp-python does not expose BF16 here, so use native CUDA fp16. + x_cuda = x_cuda.to(dtype=torch.float16) + ggml_input_type = self.ggml_type_f16 + else: + raise RuntimeError( + "GGUFCudaKernel only supports float16, bfloat16, or float32 inputs." + ) + + if x_cuda.shape[-1] != padded_in_features: + x_cuda = F.pad(x_cuda, (0, padded_in_features - x_cuda.shape[-1])) + if not x_cuda.is_contiguous(): + x_cuda = x_cuda.contiguous() + return x_cuda, ggml_input_type + + @staticmethod + def _normalize_qweight_cuda(qweight: torch.Tensor, device: torch.device) -> torch.Tensor: + qweight_cuda = qweight.detach() + if qweight_cuda.device != device: + qweight_cuda = qweight_cuda.to(device=device) + if not qweight_cuda.is_contiguous(): + qweight_cuda = qweight_cuda.contiguous() + return qweight_cuda + + @staticmethod + def _torch_dtype_from_ggml_element_size(output_element_size: int, *, kernel_name: str) -> torch.dtype: + if output_element_size == 4: + return torch.float32 + if output_element_size == 2: + return torch.float16 + raise RuntimeError( + f"{kernel_name} received unsupported GGML output element size `{output_element_size}`." + ) + + def _run_quantized_matmul( + self, + *, + backend: int, + qweight_cpu: torch.Tensor, + x_cpu: torch.Tensor, + gguf_tensor_qtype: str, + padded_in_features: int, + out_features: int, + kernel_name: str, + ) -> torch.Tensor: + + ctx = self._ggml_base.ggml_init( + _GGMLInitParams( + mem_size=self.GGML_METADATA_BYTES, + mem_buffer=None, + no_alloc=True, + ) + ) + if not ctx: + raise RuntimeError(f"{kernel_name} failed to initialize GGML metadata context.") + + buffer = None + try: + weight_tensor = self._ggml_base.ggml_new_tensor_2d( + ctx, + self.ggml_qtypes[gguf_tensor_qtype], + padded_in_features, + out_features, + ) + input_tensor = self._ggml_base.ggml_new_tensor_2d( + ctx, + self.ggml_type_f32, + padded_in_features, + x_cpu.shape[0], + ) + if not weight_tensor or not input_tensor: + raise RuntimeError(f"{kernel_name} failed to create GGML tensors.") + + self._ggml_base.ggml_set_input(input_tensor) + output_tensor = self._ggml_base.ggml_mul_mat(ctx, weight_tensor, input_tensor) + if not output_tensor: + raise RuntimeError(f"{kernel_name} failed to create GGML matmul node.") + + graph = self._ggml_base.ggml_new_graph(ctx) + if not graph: + raise RuntimeError(f"{kernel_name} failed to allocate GGML graph.") + self._ggml_base.ggml_build_forward_expand(graph, output_tensor) + + buffer = self._ggml_base.ggml_backend_alloc_ctx_tensors(ctx, backend) + if not buffer: + raise RuntimeError(f"{kernel_name} failed to allocate GGML backend tensors.") + + self._ggml_base.ggml_backend_tensor_set( + weight_tensor, + ctypes.c_void_p(qweight_cpu.data_ptr()), + 0, + qweight_cpu.numel() * qweight_cpu.element_size(), + ) + self._ggml_base.ggml_backend_tensor_set( + input_tensor, + ctypes.c_void_p(x_cpu.data_ptr()), + 0, + x_cpu.numel() * x_cpu.element_size(), + ) + + status = self._ggml_base.ggml_backend_graph_compute(backend, graph) + if status != 0: + raise RuntimeError(f"{kernel_name} GGML graph compute failed with status={status}.") + self._ggml_base.ggml_backend_synchronize(backend) + + output_nbytes = self._ggml_base.ggml_nbytes(output_tensor) + output_element_size = self._ggml_base.ggml_element_size(output_tensor) + if output_element_size == 4: + output_dtype = torch.float32 + elif output_element_size == 2: + output_dtype = torch.float16 + else: + raise RuntimeError( + f"{kernel_name} received unsupported GGML output element size `{output_element_size}`." + ) + + output = torch.empty((x_cpu.shape[0], out_features), dtype=output_dtype, device="cpu") + expected_nbytes = output.numel() * output.element_size() + if expected_nbytes != output_nbytes: + raise RuntimeError( + f"{kernel_name} GGML output size mismatch: expected {expected_nbytes}, got {output_nbytes}." + ) + + self._ggml_base.ggml_backend_tensor_get( + output_tensor, + ctypes.c_void_p(output.data_ptr()), + 0, + output_nbytes, + ) + return output + finally: + if buffer: + self._ggml_base.ggml_backend_buffer_free(buffer) + self._ggml_base.ggml_free(ctx) + + def build_quantized_matmul_cuda_plan( + self, + *, + backend: int, + qweight: torch.Tensor, + gguf_tensor_qtype: str, + padded_in_features: int, + out_features: int, + rows: int, + input_ggml_type: int, + kernel_name: str, + ) -> _GGMLMatmulPlan: + ctx = self._ggml_base.ggml_init( + _GGMLInitParams( + mem_size=self.GGML_METADATA_BYTES, + mem_buffer=None, + no_alloc=True, + ) + ) + if not ctx: + raise RuntimeError(f"{kernel_name} failed to initialize GGML metadata context.") + + buffer = None + try: + weight_tensor = self._ggml_base.ggml_new_tensor_2d( + ctx, + self.ggml_qtypes[gguf_tensor_qtype], + padded_in_features, + out_features, + ) + input_tensor = self._ggml_base.ggml_new_tensor_2d( + ctx, + input_ggml_type, + padded_in_features, + rows, + ) + if not weight_tensor or not input_tensor: + raise RuntimeError(f"{kernel_name} failed to create GGML tensors.") + + self._ggml_base.ggml_set_input(input_tensor) + output_tensor = self._ggml_base.ggml_mul_mat(ctx, weight_tensor, input_tensor) + if not output_tensor: + raise RuntimeError(f"{kernel_name} failed to create GGML matmul node.") + + graph = self._ggml_base.ggml_new_graph(ctx) + if not graph: + raise RuntimeError(f"{kernel_name} failed to allocate GGML graph.") + self._ggml_base.ggml_build_forward_expand(graph, output_tensor) + + buffer = self._ggml_base.ggml_backend_alloc_ctx_tensors(ctx, backend) + if not buffer: + raise RuntimeError(f"{kernel_name} failed to allocate GGML backend tensors.") + + self._ggml_base.ggml_backend_tensor_set( + weight_tensor, + ctypes.c_void_p(qweight.data_ptr()), + 0, + qweight.numel() * qweight.element_size(), + ) + output_nbytes = self._ggml_base.ggml_nbytes(output_tensor) + output_dtype = self._torch_dtype_from_ggml_element_size( + int(self._ggml_base.ggml_element_size(output_tensor)), + kernel_name=kernel_name, + ) + return _GGMLMatmulPlan( + ctx=ctx, + buffer=buffer, + weight_tensor=weight_tensor, + input_tensor=input_tensor, + output_tensor=output_tensor, + graph=graph, + rows=rows, + out_features=out_features, + output_nbytes=output_nbytes, + output_dtype=output_dtype, + backend_buffer_free=self._ggml_base.ggml_backend_buffer_free, + ctx_free=self._ggml_base.ggml_free, + ) + except Exception: + if buffer: + self._ggml_base.ggml_backend_buffer_free(buffer) + self._ggml_base.ggml_free(ctx) + raise + + def run_quantized_matmul_cuda_plan( + self, + *, + backend: int, + plan: _GGMLMatmulPlan, + x: torch.Tensor, + kernel_name: str, + ) -> torch.Tensor: + self._ggml_base.ggml_backend_tensor_set( + plan.input_tensor, + ctypes.c_void_p(x.data_ptr()), + 0, + x.numel() * x.element_size(), + ) + status = self._ggml_base.ggml_backend_graph_compute(backend, plan.graph) + if status != 0: + raise RuntimeError(f"{kernel_name} GGML graph compute failed with status={status}.") + self._ggml_base.ggml_backend_synchronize(backend) + + output = torch.empty((plan.rows, plan.out_features), dtype=plan.output_dtype, device=x.device) + expected_nbytes = output.numel() * output.element_size() + if expected_nbytes != plan.output_nbytes: + raise RuntimeError( + f"{kernel_name} GGML output size mismatch: expected {expected_nbytes}, got {plan.output_nbytes}." + ) + self._ggml_base.ggml_backend_tensor_get( + plan.output_tensor, + ctypes.c_void_p(output.data_ptr()), + 0, + plan.output_nbytes, + ) + return output + + def quantized_matmul_cpu( + self, + *, + qweight: torch.Tensor, + x: torch.Tensor, + gguf_tensor_qtype: str, + padded_in_features: int, + out_features: int, + ) -> torch.Tensor: + if x.device.type != "cpu": + raise RuntimeError("GGUFCppKernel only supports CPU input tensors.") + + return self._run_quantized_matmul( + backend=self._get_cpu_backend(), + qweight_cpu=self._normalize_qweight_cpu(qweight), + x_cpu=self._normalize_input_cpu(x, padded_in_features), + gguf_tensor_qtype=gguf_tensor_qtype, + padded_in_features=padded_in_features, + out_features=out_features, + kernel_name="GGUFCppKernel", + ) + + def quantized_matmul_cuda( + self, + *, + qweight: torch.Tensor, + x: torch.Tensor, + gguf_tensor_qtype: str, + padded_in_features: int, + out_features: int, + plan: _GGMLMatmulPlan | None = None, + ) -> tuple[torch.Tensor, _GGMLMatmulPlan]: + if x.device.type != "cuda": + raise RuntimeError("GGUFCudaKernel only supports CUDA input tensors.") + + device = x.device + backend = self._get_cuda_backend(0 if device.index is None else device.index) + x_cuda, input_ggml_type = self._normalize_input_cuda(x, padded_in_features) + if plan is None: + qweight_cuda = self._normalize_qweight_cuda(qweight, device=device) + plan = self.build_quantized_matmul_cuda_plan( + backend=backend, + qweight=qweight_cuda, + gguf_tensor_qtype=gguf_tensor_qtype, + padded_in_features=padded_in_features, + out_features=out_features, + rows=x_cuda.shape[0], + input_ggml_type=input_ggml_type, + kernel_name="GGUFCudaKernel", + ) + output = self.run_quantized_matmul_cuda_plan( + backend=backend, + plan=plan, + x=x_cuda, + kernel_name="GGUFCudaKernel", + ) + return output, plan + + +_GGML_BRIDGE: Optional[_GGMLBridge] = None + + +def _get_ggml_bridge() -> _GGMLBridge: + global _GGML_BRIDGE + if _GGML_BRIDGE is None: + _GGML_BRIDGE = _GGMLBridge() + return _GGML_BRIDGE + + +class GGUFCppKernel(GGUFTorchLinear): + SUPPORTS_BACKENDS = [BACKEND.GGUF_CPP_CPU] + SUPPORTS_METHODS = [METHOD.GGUF] + SUPPORTS_FORMATS = {FORMAT.GGUF: 25} + SUPPORTS_BITS = [4, 5, 6, 8] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = True + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_DEVICES = [DEVICE.CPU] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32] + SUPPORTS_ADAPTERS = [Lora] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + + REQUIRES_FORMAT_V2 = False + AUTOTUNE = False + + QUANT_TYPE = "gguf" + + pack = None + pack_block = None + pack_gpu = None + pack_original = None + + def __init__( + self, + bits, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + kwargs.setdefault("backend", BACKEND.GGUF_CPP_CPU) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + register_buffers=register_buffers, + **kwargs, + ) + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + try: + return _get_ggml_bridge().cpu_available() + except Exception as exc: + return False, exc + + def forward(self, x: torch.Tensor): + original_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + if x_flat.device.type != "cpu": + raise RuntimeError( + f"{self.__class__.__name__} only supports CPU inference. " + "Load GGUF models on CPU or use BACKEND.GGUF_CPP_CUDA or BACKEND.GGUF_TORCH for CUDA inference." + ) + + output = _get_ggml_bridge().quantized_matmul_cpu( + qweight=self.qweight, + x=x_flat, + gguf_tensor_qtype=self.gguf_tensor_qtype, + padded_in_features=self.padded_in_features, + out_features=self.out_features, + ) + if output.dtype != x_flat.dtype: + output = output.to(dtype=x_flat.dtype) + + if self.bias is not None: + bias = self.bias + if bias.device != output.device or bias.dtype != output.dtype: + bias = bias.to(device=output.device, dtype=output.dtype) + output = output + bias + + if self.adapter: + output = self.adapter.apply(x=x_flat, out=output) + + return output.reshape(original_shape) + +class GGUFCudaKernel(GGUFTorchLinear): + SUPPORTS_BACKENDS = [BACKEND.GGUF_CPP_CUDA] + SUPPORTS_METHODS = [METHOD.GGUF] + SUPPORTS_FORMATS = {FORMAT.GGUF: 35} + SUPPORTS_BITS = [4, 5, 6, 8] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = True + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32] + SUPPORTS_ADAPTERS = [Lora] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + + REQUIRES_FORMAT_V2 = False + AUTOTUNE = False + + QUANT_TYPE = "gguf" + + pack = None + pack_block = None + pack_gpu = None + pack_original = None + + def __init__( + self, + bits, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + kwargs.setdefault("backend", BACKEND.GGUF_CPP_CUDA) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + register_buffers=register_buffers, + **kwargs, + ) + self._ggml_cuda_plans: Dict[tuple[int, int, torch.dtype, int], _GGMLMatmulPlan] = {} + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + try: + return _get_ggml_bridge().cuda_available() + except Exception as exc: + return False, exc + + def clear_weight_cache(self) -> None: + for plan in self._ggml_cuda_plans.values(): + plan.close() + self._ggml_cuda_plans.clear() + return super().clear_weight_cache() + + def _cuda_plan_key(self, x_flat: torch.Tensor) -> tuple[int, int, torch.dtype, int]: + device_index = 0 if x_flat.device.index is None else x_flat.device.index + return ( + device_index, + x_flat.shape[0], + x_flat.dtype, + self.qweight.data_ptr(), + ) + + def forward(self, x: torch.Tensor): + original_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + if x_flat.device.type != "cuda": + raise RuntimeError( + f"{self.__class__.__name__} only supports CUDA inference. " + "Load GGUF models on CUDA or use BACKEND.GGUF_CPP_CPU or BACKEND.GGUF_TORCH for CPU inference." + ) + + plan_key = self._cuda_plan_key(x_flat) + output, plan = _get_ggml_bridge().quantized_matmul_cuda( + qweight=self.qweight, + x=x_flat, + gguf_tensor_qtype=self.gguf_tensor_qtype, + padded_in_features=self.padded_in_features, + out_features=self.out_features, + plan=self._ggml_cuda_plans.get(plan_key), + ) + self._ggml_cuda_plans.setdefault(plan_key, plan) + if output.device != x_flat.device or output.dtype != x_flat.dtype: + output = output.to(device=x_flat.device, dtype=x_flat.dtype) + + if self.bias is not None: + bias = self.bias + if bias.device != output.device or bias.dtype != output.dtype: + bias = bias.to(device=output.device, dtype=output.dtype) + output = output + bias + + if self.adapter: + output = self.adapter.apply(x=x_flat, out=output) + + return output.reshape(original_shape) + + +__all__ = ["GGUFCppKernel", "GGUFCudaKernel"] diff --git a/gptqmodel/nn_modules/qlinear/gguf_triton.py b/gptqmodel/nn_modules/qlinear/gguf_triton.py new file mode 100644 index 000000000..bfe37c068 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/gguf_triton.py @@ -0,0 +1,1646 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any, Callable, Optional, Tuple + +import torch + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...quantization import FORMAT, METHOD +from ...utils.backend import BACKEND +from ...utils.python import has_gil_disabled +from .gguf import GGUFTorchLinear, _unpack_q4_k_scale_min_torch + + +try: + import triton + import triton.language as tl + from packaging import version + + from ..triton_utils import custom_autotune + + _TRITON_AVAILABLE = True +except Exception: # pragma: no cover - optional dependency + triton = None + tl = None + custom_autotune = None + _TRITON_AVAILABLE = False + +_CUDA_DEVICE_CAPABILITY_CACHE: dict[int, tuple[int, int]] = {} + + +def triton_available() -> bool: + return _TRITON_AVAILABLE + + +if _TRITON_AVAILABLE: + _GGUF_TRITON_SMALL_CONFIGS = [ + triton.Config( + { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + }, + num_stages=2, + num_warps=4, + ), + ] + + _GGUF_TRITON_LARGE_CONFIGS = [ + triton.Config( + { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 16, + }, + num_stages=2, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + }, + num_stages=2, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + }, + num_stages=2, + num_warps=2, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + }, + num_stages=4, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + }, + num_stages=2, + num_warps=4, + ), + triton.Config( + { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + }, + num_stages=2, + num_warps=8, + ), + ] + _GGUF_TRITON_LARGE_NUM_BLOCKS = 16 + _Q1_0_G128_SM80_DECODE_NARROW_CONFIG = { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + _Q1_0_G128_SM80_DECODE_WIDE_CONFIG = { + "BLOCK_SIZE_M": 2, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + _Q1_0_G128_SM89_DECODE_NARROW_CONFIG = { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "num_warps": 8, + "num_stages": 2, + } + _Q1_0_G128_SM89_DECODE_2048_TO_2048_CONFIG = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + _Q1_0_G128_SM89_DECODE_2048_TO_6144_CONFIG = { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "num_warps": 8, + "num_stages": 2, + } + _Q1_0_G128_SM89_DECODE_WIDE_CONFIG = { + "BLOCK_SIZE_M": 2, + "BLOCK_SIZE_N": 32, + "num_warps": 8, + "num_stages": 4, + } + _Q1_0_G128_SM80_PREFILL_NARROW_CONFIG = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "num_warps": 4, + "num_stages": 4, + } + _Q1_0_G128_SM80_PREFILL_WIDE_CONFIG = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + _Q1_0_G128_SM89_PREFILL_NARROW_CONFIG = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "num_warps": 8, + "num_stages": 4, + } + _Q1_0_G128_SM89_PREFILL_WIDE_CONFIG = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "num_warps": 4, + "num_stages": 4, + } + _Q1_0_G128_SM80_U32_DECODE_CONFIG = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "num_warps": 4, + "num_stages": 4, + } + _Q1_0_G128_SM80_U32_PREFILL_CONFIG = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + # These exact-shape decode configs target the dominant post-down-proj A100 decode hotspots. + _Q1_0_G128_SM80_U32_DECODE_2048_TO_1024_CONFIG = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "num_warps": 2, + "num_stages": 2, + } + _Q1_0_G128_SM80_U32_DECODE_2048_TO_2048_CONFIG = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + _Q1_0_G128_SM80_U32_DECODE_2048_TO_6144_CONFIG = { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + + @triton.jit + def _gguf_q1_0_g128_fused_matmul_kernel_impl( + x_ptr, + sign_ptr, + scale_ptr, + out_ptr, + M, + N, + NUM_BLOCKS, + stride_xm, + stride_xk, + stride_qb, + stride_qq, + stride_qn, + stride_sb, + stride_sn, + stride_om, + stride_on, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_mask = offs_m < M + n_mask = offs_n < N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + bit_shifts = tl.arange(0, 8) + + for block_idx in range(0, NUM_BLOCKS): + scale = tl.load( + scale_ptr + block_idx * stride_sb + offs_n * stride_sn, + mask=n_mask, + other=0.0, + ) + + for sign_group in range(0, 4): + offs_k = block_idx * 128 + sign_group * 32 + tl.arange(0, 32) + a = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=m_mask[:, None], + other=0.0, + ) + + packed = tl.load( + sign_ptr + + block_idx * stride_qb + + (sign_group * 4 + tl.arange(0, 4))[:, None] * stride_qq + + offs_n[None, :] * stride_qn, + mask=n_mask[None, :], + other=0, + ) + sign_bits = (packed[:, None, :] >> bit_shifts[None, :, None]) & 0x01 + signs = tl.reshape(sign_bits, (32, BLOCK_SIZE_N)) + weight = (tl.cast(signs, tl.float16) * 2.0 - 1.0) * scale[None, :] + accumulator += tl.dot(a, weight) + + tl.store( + out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on, + tl.cast(accumulator, tl.float16), + mask=m_mask[:, None] & n_mask[None, :], + ) + + @triton.jit + def _gguf_q1_0_g128_u32_fused_matmul_kernel_impl( + x_ptr, + sign_ptr, + scale_ptr, + out_ptr, + M, + N, + NUM_BLOCKS, + stride_xm, + stride_xk, + stride_qb, + stride_qg, + stride_qn, + stride_sb, + stride_sn, + stride_om, + stride_on, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_mask = offs_m < M + n_mask = offs_n < N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + bit_shifts = tl.arange(0, 32) + + for block_idx in range(0, NUM_BLOCKS): + scale = tl.load( + scale_ptr + block_idx * stride_sb + offs_n * stride_sn, + mask=n_mask, + other=0.0, + ) + + for sign_group in range(0, 4): + offs_k = block_idx * 128 + sign_group * 32 + tl.arange(0, 32) + a = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=m_mask[:, None], + other=0.0, + ) + + packed = tl.load( + sign_ptr + block_idx * stride_qb + sign_group * stride_qg + offs_n * stride_qn, + mask=n_mask, + other=0, + ) + packed = tl.cast(packed, tl.uint32) + sign_bits = (packed[None, :] >> bit_shifts[:, None]) & 0x01 + weight = (tl.cast(sign_bits, tl.float16) * 2.0 - 1.0) * scale[None, :] + accumulator += tl.dot(a, weight) + + tl.store( + out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on, + tl.cast(accumulator, tl.float16), + mask=m_mask[:, None] & n_mask[None, :], + ) + + @triton.jit + def _gguf_q1_0_g128_k2048_fused_matmul_kernel_impl( + x_ptr, + sign_ptr, + scale_ptr, + out_ptr, + M, + N, + stride_xm, + stride_xk, + stride_qb, + stride_qq, + stride_qn, + stride_sb, + stride_sn, + stride_om, + stride_on, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_mask = offs_m < M + n_mask = offs_n < N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + bit_shifts = tl.arange(0, 8) + + for block_idx in range(0, 16): + scale = tl.load( + scale_ptr + block_idx * stride_sb + offs_n * stride_sn, + mask=n_mask, + other=0.0, + ) + + for sign_group in range(0, 4): + offs_k = block_idx * 128 + sign_group * 32 + tl.arange(0, 32) + a = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=m_mask[:, None], + other=0.0, + ) + + packed = tl.load( + sign_ptr + + block_idx * stride_qb + + (sign_group * 4 + tl.arange(0, 4))[:, None] * stride_qq + + offs_n[None, :] * stride_qn, + mask=n_mask[None, :], + other=0, + ) + sign_bits = (packed[:, None, :] >> bit_shifts[None, :, None]) & 0x01 + signs = tl.reshape(sign_bits, (32, BLOCK_SIZE_N)) + weight = (tl.cast(signs, tl.float16) * 2.0 - 1.0) * scale[None, :] + accumulator += tl.dot(a, weight) + + tl.store( + out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on, + tl.cast(accumulator, tl.float16), + mask=m_mask[:, None] & n_mask[None, :], + ) + + @triton.jit + def _gguf_q1_0_g128_u32_k2048_fused_matmul_kernel_impl( + x_ptr, + sign_ptr, + scale_ptr, + out_ptr, + M, + N, + stride_xm, + stride_xk, + stride_qb, + stride_qg, + stride_qn, + stride_sb, + stride_sn, + stride_om, + stride_on, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_mask = offs_m < M + n_mask = offs_n < N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + bit_shifts = tl.arange(0, 32) + + for block_idx in range(0, 16): + scale = tl.load( + scale_ptr + block_idx * stride_sb + offs_n * stride_sn, + mask=n_mask, + other=0.0, + ) + + for sign_group in range(0, 4): + offs_k = block_idx * 128 + sign_group * 32 + tl.arange(0, 32) + a = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=m_mask[:, None], + other=0.0, + ) + + packed = tl.load( + sign_ptr + block_idx * stride_qb + sign_group * stride_qg + offs_n * stride_qn, + mask=n_mask, + other=0, + ) + packed = tl.cast(packed, tl.uint32) + sign_bits = (packed[None, :] >> bit_shifts[:, None]) & 0x01 + weight = (tl.cast(sign_bits, tl.float16) * 2.0 - 1.0) * scale[None, :] + accumulator += tl.dot(a, weight) + + tl.store( + out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on, + tl.cast(accumulator, tl.float16), + mask=m_mask[:, None] & n_mask[None, :], + ) + + _gguf_q1_0_g128_fused_matmul_kernel_small = custom_autotune.autotune( + configs=_GGUF_TRITON_SMALL_CONFIGS, + key=["M", "N"], + nearest_power_of_two=True, + )(_gguf_q1_0_g128_fused_matmul_kernel_impl) + _gguf_q1_0_g128_fused_matmul_kernel_large = custom_autotune.autotune( + configs=_GGUF_TRITON_LARGE_CONFIGS, + key=["M", "N", "NUM_BLOCKS"], + nearest_power_of_two=True, + )(_gguf_q1_0_g128_fused_matmul_kernel_impl) + _gguf_q1_0_g128_u32_fused_matmul_kernel_small = custom_autotune.autotune( + configs=_GGUF_TRITON_SMALL_CONFIGS, + key=["M", "N"], + nearest_power_of_two=True, + )(_gguf_q1_0_g128_u32_fused_matmul_kernel_impl) + _gguf_q1_0_g128_u32_fused_matmul_kernel_large = custom_autotune.autotune( + configs=_GGUF_TRITON_LARGE_CONFIGS, + key=["M", "N", "NUM_BLOCKS"], + nearest_power_of_two=True, + )(_gguf_q1_0_g128_u32_fused_matmul_kernel_impl) + _gguf_q1_0_g128_k2048_fused_matmul_kernel = custom_autotune.autotune( + configs=_GGUF_TRITON_LARGE_CONFIGS, + key=["M", "N"], + nearest_power_of_two=True, + )(_gguf_q1_0_g128_k2048_fused_matmul_kernel_impl) + _gguf_q1_0_g128_u32_k2048_fused_matmul_kernel = custom_autotune.autotune( + configs=_GGUF_TRITON_LARGE_CONFIGS, + key=["M", "N"], + nearest_power_of_two=True, + )(_gguf_q1_0_g128_u32_k2048_fused_matmul_kernel_impl) + + @triton.jit + def _gguf_q4_k_fused_matmul_kernel_impl( + x_ptr, + qs_ptr, + scale_ptr, + min_ptr, + out_ptr, + M, + N, + NUM_BLOCKS, + stride_xm, + stride_xk, + stride_qb, + stride_qq, + stride_qn, + stride_sb, + stride_ss, + stride_sn, + stride_mb, + stride_ms, + stride_mn, + stride_om, + stride_on, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_mask = offs_m < M + n_mask = offs_n < N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for block_idx in range(0, NUM_BLOCKS): + for subblock in range(0, 8): + offs_k = block_idx * 256 + subblock * 32 + tl.arange(0, 32) + a = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=m_mask[:, None], + other=0.0, + ) + + byte_idx = (subblock // 2) * 32 + tl.arange(0, 32) + packed = tl.load( + qs_ptr + block_idx * stride_qb + byte_idx[:, None] * stride_qq + offs_n[None, :] * stride_qn, + mask=n_mask[None, :], + other=0, + ) + if subblock % 2 == 0: + q = packed & 0x0F + else: + q = packed >> 4 + + scale = tl.load( + scale_ptr + block_idx * stride_sb + subblock * stride_ss + offs_n * stride_sn, + mask=n_mask, + other=0.0, + ) + min_value = tl.load( + min_ptr + block_idx * stride_mb + subblock * stride_ms + offs_n * stride_mn, + mask=n_mask, + other=0.0, + ) + + weight = tl.cast(q, tl.float16) * scale[None, :] - min_value[None, :] + accumulator += tl.dot(a, weight) + + tl.store( + out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on, + tl.cast(accumulator, tl.float16), + mask=m_mask[:, None] & n_mask[None, :], + ) + + _gguf_q4_k_fused_matmul_kernel_small = custom_autotune.autotune( + configs=_GGUF_TRITON_SMALL_CONFIGS, + key=["M", "N"], + nearest_power_of_two=True, + )(_gguf_q4_k_fused_matmul_kernel_impl) + _gguf_q4_k_fused_matmul_kernel_large = custom_autotune.autotune( + configs=_GGUF_TRITON_LARGE_CONFIGS, + key=["M", "N", "NUM_BLOCKS"], + nearest_power_of_two=True, + )(_gguf_q4_k_fused_matmul_kernel_impl) + + @triton.jit + def _gguf_q5_k_fused_matmul_kernel_impl( + x_ptr, + qs_ptr, + qh_ptr, + scale_ptr, + min_ptr, + out_ptr, + M, + N, + NUM_BLOCKS, + stride_xm, + stride_xk, + stride_qsb, + stride_qsq, + stride_qsn, + stride_qhb, + stride_qhq, + stride_qhn, + stride_sb, + stride_ss, + stride_sn, + stride_mb, + stride_ms, + stride_mn, + stride_om, + stride_on, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_mask = offs_m < M + n_mask = offs_n < N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for block_idx in range(0, NUM_BLOCKS): + for subblock in range(0, 8): + offs_k = block_idx * 256 + subblock * 32 + tl.arange(0, 32) + a = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=m_mask[:, None], + other=0.0, + ) + + byte_idx = (subblock // 2) * 32 + tl.arange(0, 32) + packed = tl.load( + qs_ptr + block_idx * stride_qsb + byte_idx[:, None] * stride_qsq + offs_n[None, :] * stride_qsn, + mask=n_mask[None, :], + other=0, + ) + if subblock % 2 == 0: + ql = packed & 0x0F + else: + ql = packed >> 4 + + qh = tl.load( + qh_ptr + block_idx * stride_qhb + tl.arange(0, 32)[:, None] * stride_qhq + offs_n[None, :] * stride_qhn, + mask=n_mask[None, :], + other=0, + ) + qh = (qh >> subblock) & 0x01 + q = ql | (qh << 4) + + scale = tl.load( + scale_ptr + block_idx * stride_sb + subblock * stride_ss + offs_n * stride_sn, + mask=n_mask, + other=0.0, + ) + min_value = tl.load( + min_ptr + block_idx * stride_mb + subblock * stride_ms + offs_n * stride_mn, + mask=n_mask, + other=0.0, + ) + + weight = tl.cast(q, tl.float16) * scale[None, :] - min_value[None, :] + accumulator += tl.dot(a, weight) + + tl.store( + out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on, + tl.cast(accumulator, tl.float16), + mask=m_mask[:, None] & n_mask[None, :], + ) + + _gguf_q5_k_fused_matmul_kernel_small = custom_autotune.autotune( + configs=_GGUF_TRITON_SMALL_CONFIGS, + key=["M", "N"], + nearest_power_of_two=True, + )(_gguf_q5_k_fused_matmul_kernel_impl) + _gguf_q5_k_fused_matmul_kernel_large = custom_autotune.autotune( + configs=_GGUF_TRITON_LARGE_CONFIGS, + key=["M", "N", "NUM_BLOCKS"], + nearest_power_of_two=True, + )(_gguf_q5_k_fused_matmul_kernel_impl) + + @triton.jit + def _gguf_q6_k_fused_matmul_kernel_impl( + x_ptr, + ql_ptr, + qh_ptr, + scale_ptr, + out_ptr, + M, + N, + NUM_BLOCKS, + stride_xm, + stride_xk, + stride_qlb, + stride_qlq, + stride_qln, + stride_qhb, + stride_qhq, + stride_qhn, + stride_sb, + stride_ss, + stride_sn, + stride_om, + stride_on, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + ): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + m_mask = offs_m < M + n_mask = offs_n < N + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for block_idx in range(0, NUM_BLOCKS): + for subblock in range(0, 16): + offs_k = block_idx * 256 + subblock * 16 + tl.arange(0, 16) + a = tl.load( + x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk, + mask=m_mask[:, None], + other=0.0, + ) + + pair_row = subblock // 2 + half = subblock % 2 + row_group = pair_row // 4 + row_in_group = pair_row % 4 + pos32 = half * 16 + tl.arange(0, 16) + ql_base = row_group * 64 + (32 if row_in_group == 1 or row_in_group == 3 else 0) + + packed_ql = tl.load( + ql_ptr + block_idx * stride_qlb + (ql_base + pos32)[:, None] * stride_qlq + offs_n[None, :] * stride_qln, + mask=n_mask[None, :], + other=0, + ) + if row_in_group == 0 or row_in_group == 1: + low = packed_ql & 0x0F + else: + low = packed_ql >> 4 + + packed_qh = tl.load( + qh_ptr + block_idx * stride_qhb + (row_group * 32 + pos32)[:, None] * stride_qhq + offs_n[None, :] * stride_qhn, + mask=n_mask[None, :], + other=0, + ) + high = (packed_qh >> (row_in_group * 2)) & 0x03 + q = tl.cast(low | (high << 4), tl.int16) - 32 + + scale = tl.load( + scale_ptr + block_idx * stride_sb + subblock * stride_ss + offs_n * stride_sn, + mask=n_mask, + other=0.0, + ) + + weight = tl.cast(q, tl.float16) * scale[None, :] + accumulator += tl.dot(a, weight) + + tl.store( + out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on, + tl.cast(accumulator, tl.float16), + mask=m_mask[:, None] & n_mask[None, :], + ) + + _gguf_q6_k_fused_matmul_kernel_small = custom_autotune.autotune( + configs=_GGUF_TRITON_SMALL_CONFIGS, + key=["M", "N"], + nearest_power_of_two=True, + )(_gguf_q6_k_fused_matmul_kernel_impl) + _gguf_q6_k_fused_matmul_kernel_large = custom_autotune.autotune( + configs=_GGUF_TRITON_LARGE_CONFIGS, + key=["M", "N", "NUM_BLOCKS"], + nearest_power_of_two=True, + )(_gguf_q6_k_fused_matmul_kernel_impl) + + +def _launch( + kernel: Callable, + x: torch.Tensor, + output: torch.Tensor, + *args, +) -> torch.Tensor: + def grid(meta): + return ( + triton.cdiv(x.shape[0], meta["BLOCK_SIZE_M"]), + triton.cdiv(output.shape[1], meta["BLOCK_SIZE_N"]), + ) + + kernel[grid](*args) + return output + + +def _launch_with_meta( + kernel: Callable, + x: torch.Tensor, + output: torch.Tensor, + *args, + **meta, +) -> torch.Tensor: + def grid(meta_args): + return ( + triton.cdiv(x.shape[0], meta_args["BLOCK_SIZE_M"]), + triton.cdiv(output.shape[1], meta_args["BLOCK_SIZE_N"]), + ) + + kernel[grid](*args, **meta) + return output + + +def _select_triton_kernel( + small_kernel: Callable, + large_kernel: Callable, + *, + num_blocks: int, +) -> Callable: + if num_blocks >= _GGUF_TRITON_LARGE_NUM_BLOCKS: + return large_kernel + return small_kernel + + +def _select_q1_0_g128_fixed_launch_config( + *, + capability: tuple[int, int] | None, + rows: int, + in_features: int | None = None, + cols: int, +) -> dict[str, int] | None: + if capability == (8, 0): + if rows == 1: + if cols <= 2048: + return dict(_Q1_0_G128_SM80_DECODE_NARROW_CONFIG) + return dict(_Q1_0_G128_SM80_DECODE_WIDE_CONFIG) + if 8 <= rows <= 128: + if cols == 2048: + return dict(_Q1_0_G128_SM80_PREFILL_NARROW_CONFIG) + if cols == 6144: + return dict(_Q1_0_G128_SM80_PREFILL_WIDE_CONFIG) + if capability == (8, 9): + if rows == 1: + if cols == 1024: + return dict(_Q1_0_G128_SM89_DECODE_NARROW_CONFIG) + if in_features == 2048 and cols == 2048: + return dict(_Q1_0_G128_SM89_DECODE_2048_TO_2048_CONFIG) + if in_features == 2048 and cols == 6144: + return dict(_Q1_0_G128_SM89_DECODE_2048_TO_6144_CONFIG) + if cols == 2048: + return dict(_Q1_0_G128_SM89_DECODE_NARROW_CONFIG) + if cols == 6144: + return dict(_Q1_0_G128_SM89_DECODE_WIDE_CONFIG) + if 8 <= rows <= 128: + if cols == 2048: + return dict(_Q1_0_G128_SM89_PREFILL_NARROW_CONFIG) + if cols == 6144: + return dict(_Q1_0_G128_SM89_PREFILL_WIDE_CONFIG) + return None + + +def _select_q1_0_g128_u32_layout( + *, + capability: tuple[int, int] | None, + in_features: int, + out_features: int, +) -> bool: + if capability != (8, 0): + return False + return (in_features, out_features) in { + (2048, 1024), + (2048, 2048), + (2048, 6144), + (6144, 2048), + } + + +def _select_q1_0_g128_u32_fixed_launch_config( + *, + capability: tuple[int, int] | None, + rows: int, + in_features: int, + out_features: int, +) -> dict[str, int] | None: + if capability != (8, 0): + return None + if rows == 1: + if in_features == 6144 and out_features == 2048: + return dict(_Q1_0_G128_SM80_U32_DECODE_CONFIG) + if in_features == 2048 and out_features == 1024: + return dict(_Q1_0_G128_SM80_U32_DECODE_2048_TO_1024_CONFIG) + if in_features == 2048 and out_features == 2048: + return dict(_Q1_0_G128_SM80_U32_DECODE_2048_TO_2048_CONFIG) + if in_features == 2048 and out_features == 6144: + return dict(_Q1_0_G128_SM80_U32_DECODE_2048_TO_6144_CONFIG) + return None + if 8 <= rows <= 128 and in_features == 6144 and out_features == 2048: + return dict(_Q1_0_G128_SM80_U32_PREFILL_CONFIG) + return None + + +def _use_q1_0_g128_k2048_decode_specialization( + *, + rows: int, + in_features: int, +) -> bool: + return rows == 1 and in_features == 2048 + + +def _cuda_device_capability(device: torch.device) -> tuple[int, int] | None: + if device.type != "cuda": + return None + + index = device.index if device.index is not None else torch.cuda.current_device() + capability = _CUDA_DEVICE_CAPABILITY_CACHE.get(index) + if capability is None: + capability = torch.cuda.get_device_capability(index) + _CUDA_DEVICE_CAPABILITY_CACHE[index] = capability + return capability + + +def fused_q4_k_matmul( + x: torch.Tensor, + qs: torch.Tensor, + scale: torch.Tensor, + min_value: torch.Tensor, +) -> torch.Tensor: + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton is not available for GGUF Q4_K fused matmul.") + + output = torch.empty((x.shape[0], scale.shape[2]), device=x.device, dtype=x.dtype) + kernel = _select_triton_kernel( + _gguf_q4_k_fused_matmul_kernel_small, + _gguf_q4_k_fused_matmul_kernel_large, + num_blocks=qs.shape[0], + ) + return _launch( + kernel, + x, + output, + x, + qs, + scale, + min_value, + output, + x.shape[0], + output.shape[1], + qs.shape[0], + x.stride(0), + x.stride(1), + qs.stride(0), + qs.stride(1), + qs.stride(2), + scale.stride(0), + scale.stride(1), + scale.stride(2), + min_value.stride(0), + min_value.stride(1), + min_value.stride(2), + output.stride(0), + output.stride(1), + ) + + +def fused_q1_0_g128_matmul( + x: torch.Tensor, + sign_bytes: torch.Tensor, + scale: torch.Tensor, +) -> torch.Tensor: + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton is not available for GGUF Q1_0_g128 fused matmul.") + + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + kernel = _select_triton_kernel( + _gguf_q1_0_g128_fused_matmul_kernel_small, + _gguf_q1_0_g128_fused_matmul_kernel_large, + num_blocks=sign_bytes.shape[0], + ) + return _launch( + kernel, + x, + output, + x, + sign_bytes, + scale, + output, + x.shape[0], + output.shape[1], + sign_bytes.shape[0], + x.stride(0), + x.stride(1), + sign_bytes.stride(0), + sign_bytes.stride(1), + sign_bytes.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + ) + + +def fused_q1_0_g128_k2048_matmul( + x: torch.Tensor, + sign_bytes: torch.Tensor, + scale: torch.Tensor, +) -> torch.Tensor: + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton is not available for GGUF Q1_0_g128 fused matmul.") + + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + return _launch( + _gguf_q1_0_g128_k2048_fused_matmul_kernel, + x, + output, + x, + sign_bytes, + scale, + output, + x.shape[0], + output.shape[1], + x.stride(0), + x.stride(1), + sign_bytes.stride(0), + sign_bytes.stride(1), + sign_bytes.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + ) + + +def fused_q1_0_g128_u32_matmul( + x: torch.Tensor, + sign_words: torch.Tensor, + scale: torch.Tensor, +) -> torch.Tensor: + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton is not available for GGUF Q1_0_g128 fused matmul.") + + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + kernel = _select_triton_kernel( + _gguf_q1_0_g128_u32_fused_matmul_kernel_small, + _gguf_q1_0_g128_u32_fused_matmul_kernel_large, + num_blocks=sign_words.shape[0], + ) + return _launch( + kernel, + x, + output, + x, + sign_words, + scale, + output, + x.shape[0], + output.shape[1], + sign_words.shape[0], + x.stride(0), + x.stride(1), + sign_words.stride(0), + sign_words.stride(1), + sign_words.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + ) + + +def fused_q1_0_g128_u32_k2048_matmul( + x: torch.Tensor, + sign_words: torch.Tensor, + scale: torch.Tensor, +) -> torch.Tensor: + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton is not available for GGUF Q1_0_g128 fused matmul.") + + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + return _launch( + _gguf_q1_0_g128_u32_k2048_fused_matmul_kernel, + x, + output, + x, + sign_words, + scale, + output, + x.shape[0], + output.shape[1], + x.stride(0), + x.stride(1), + sign_words.stride(0), + sign_words.stride(1), + sign_words.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + ) + + +def _launch_q1_0_g128_u32_fixed_matmul( + x: torch.Tensor, + sign_words: torch.Tensor, + scale: torch.Tensor, + *, + fixed_config: dict[str, int], +) -> torch.Tensor: + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + return _launch_with_meta( + _gguf_q1_0_g128_u32_fused_matmul_kernel_impl, + x, + output, + x, + sign_words, + scale, + output, + x.shape[0], + output.shape[1], + sign_words.shape[0], + x.stride(0), + x.stride(1), + sign_words.stride(0), + sign_words.stride(1), + sign_words.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + **fixed_config, + ) + + +def _launch_q1_0_g128_u32_k2048_fixed_matmul( + x: torch.Tensor, + sign_words: torch.Tensor, + scale: torch.Tensor, + *, + fixed_config: dict[str, int], +) -> torch.Tensor: + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + return _launch_with_meta( + _gguf_q1_0_g128_u32_k2048_fused_matmul_kernel_impl, + x, + output, + x, + sign_words, + scale, + output, + x.shape[0], + output.shape[1], + x.stride(0), + x.stride(1), + sign_words.stride(0), + sign_words.stride(1), + sign_words.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + **fixed_config, + ) + + +def _launch_q1_0_g128_fixed_matmul( + x: torch.Tensor, + sign_bytes: torch.Tensor, + scale: torch.Tensor, + *, + fixed_config: dict[str, int], +) -> torch.Tensor: + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + return _launch_with_meta( + _gguf_q1_0_g128_fused_matmul_kernel_impl, + x, + output, + x, + sign_bytes, + scale, + output, + x.shape[0], + output.shape[1], + sign_bytes.shape[0], + x.stride(0), + x.stride(1), + sign_bytes.stride(0), + sign_bytes.stride(1), + sign_bytes.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + **fixed_config, + ) + + +def _launch_q1_0_g128_k2048_fixed_matmul( + x: torch.Tensor, + sign_bytes: torch.Tensor, + scale: torch.Tensor, + *, + fixed_config: dict[str, int], +) -> torch.Tensor: + output = torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + return _launch_with_meta( + _gguf_q1_0_g128_k2048_fused_matmul_kernel_impl, + x, + output, + x, + sign_bytes, + scale, + output, + x.shape[0], + output.shape[1], + x.stride(0), + x.stride(1), + sign_bytes.stride(0), + sign_bytes.stride(1), + sign_bytes.stride(2), + scale.stride(0), + scale.stride(1), + output.stride(0), + output.stride(1), + **fixed_config, + ) + + +def fused_q5_k_matmul( + x: torch.Tensor, + qs: torch.Tensor, + qh: torch.Tensor, + scale: torch.Tensor, + min_value: torch.Tensor, +) -> torch.Tensor: + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton is not available for GGUF Q5_K fused matmul.") + + output = torch.empty((x.shape[0], scale.shape[2]), device=x.device, dtype=x.dtype) + kernel = _select_triton_kernel( + _gguf_q5_k_fused_matmul_kernel_small, + _gguf_q5_k_fused_matmul_kernel_large, + num_blocks=qs.shape[0], + ) + return _launch( + kernel, + x, + output, + x, + qs, + qh, + scale, + min_value, + output, + x.shape[0], + output.shape[1], + qs.shape[0], + x.stride(0), + x.stride(1), + qs.stride(0), + qs.stride(1), + qs.stride(2), + qh.stride(0), + qh.stride(1), + qh.stride(2), + scale.stride(0), + scale.stride(1), + scale.stride(2), + min_value.stride(0), + min_value.stride(1), + min_value.stride(2), + output.stride(0), + output.stride(1), + ) + + +def fused_q6_k_matmul( + x: torch.Tensor, + ql: torch.Tensor, + qh: torch.Tensor, + scale: torch.Tensor, +) -> torch.Tensor: + if not _TRITON_AVAILABLE: + raise RuntimeError("Triton is not available for GGUF Q6_K fused matmul.") + + output = torch.empty((x.shape[0], scale.shape[2]), device=x.device, dtype=x.dtype) + kernel = _select_triton_kernel( + _gguf_q6_k_fused_matmul_kernel_small, + _gguf_q6_k_fused_matmul_kernel_large, + num_blocks=ql.shape[0], + ) + return _launch( + kernel, + x, + output, + x, + ql, + qh, + scale, + output, + x.shape[0], + output.shape[1], + ql.shape[0], + x.stride(0), + x.stride(1), + ql.stride(0), + ql.stride(1), + ql.stride(2), + qh.stride(0), + qh.stride(1), + qh.stride(2), + scale.stride(0), + scale.stride(1), + scale.stride(2), + output.stride(0), + output.stride(1), + ) + + +class GGUFTritonKernel(GGUFTorchLinear): + SUPPORTS_BACKENDS = [BACKEND.GGUF_TRITON] + SUPPORTS_METHODS = [METHOD.GGUF] + SUPPORTS_FORMATS = {FORMAT.GGUF: 45} + SUPPORTS_BITS = [1, 4, 5, 6] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = True + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] + SUPPORTS_PACK_DTYPES = [torch.int8, torch.int16, torch.int32] + SUPPORTS_ADAPTERS = [Lora] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16, torch.float32] + + REQUIRES_FORMAT_V2 = False + AUTOTUNE = False + + QUANT_TYPE = "gguf" + + def __init__( + self, + bits, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + kwargs.setdefault("backend", BACKEND.GGUF_TRITON) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + register_buffers=register_buffers, + **kwargs, + ) + if self.gguf_tensor_qtype not in {"Q1_0_g128", "Q4_K", "Q5_K", "Q6_K"}: + raise NotImplementedError( + f"{self.__class__.__name__} only supports fused GGUF Triton formats " + f"(Q1_0_g128, Q4_K, Q5_K, Q6_K). Actual GGUF qtype: {self.gguf_tensor_qtype}. " + "Use BACKEND.GGUF_TORCH for unsupported GGUF formats." + ) + self._gguf_triton_cache: dict[tuple[int, str], dict[str, Any]] = {} + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + if not _TRITON_AVAILABLE: + return False, ModuleNotFoundError("GGUFTritonKernel requires `triton` to be installed.") + + triton_v = version.parse(triton.__version__) + if triton_v < version.parse("2.0.0"): + return False, ImportError(f"triton version must be >= 2.0.0: actual = {triton.__version__}") + + if has_gil_disabled() and triton_v < version.parse("3.4.0"): + return False, Exception("GIL is disabled and not compatible with current Triton. Please upgrade to Triton >= 3.4.0") + + return True, None + + def clear_weight_cache(self) -> None: + self._gguf_triton_cache.clear() + return super().clear_weight_cache() + + def _triton_cache_key(self, device: torch.device) -> tuple[int, str]: + return (device.index if device.index is not None else -1, self.gguf_tensor_qtype) + + def _build_triton_cache(self, device: torch.device) -> dict[str, Any]: + blocks, _, _ = self._reshape_blocks(device=device) + + if self.gguf_tensor_qtype == "Q1_0_g128": + scale = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1).permute(1, 0).contiguous() + capability = _cuda_device_capability(device) + if _select_q1_0_g128_u32_layout( + capability=capability, + in_features=self.padded_in_features, + out_features=scale.shape[1], + ): + sign_bytes = blocks[..., 2:].permute(1, 2, 0).contiguous() + sign_words_src = sign_bytes.to(torch.int32) + sign_words = ( + sign_words_src[:, 0::4, :] + | torch.bitwise_left_shift(sign_words_src[:, 1::4, :], 8) + | torch.bitwise_left_shift(sign_words_src[:, 2::4, :], 16) + | torch.bitwise_left_shift(sign_words_src[:, 3::4, :], 24) + ).contiguous() + return { + "fixed_decode_config": None, + "qweight_ptr": self.qweight.data_ptr(), + "sign_bytes": sign_bytes, + "sign_words": sign_words, + "scale": scale, + "use_u32": True, + } + + sign_bytes = blocks[..., 2:].permute(1, 2, 0).contiguous() + return { + "fixed_decode_config": _select_q1_0_g128_fixed_launch_config( + capability=capability, + rows=1, + in_features=self.padded_in_features, + cols=scale.shape[1], + ), + "qweight_ptr": self.qweight.data_ptr(), + "sign_bytes": sign_bytes, + "scale": scale, + "use_u32": False, + } + + if self.gguf_tensor_qtype == "Q4_K": + d = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1) + dmin = blocks[..., 2:4].contiguous().view(torch.float16).squeeze(-1) + sc, mins = _unpack_q4_k_scale_min_torch(blocks[..., 4:16]) + scale = (d.unsqueeze(-1) * sc.to(torch.float16)).permute(1, 2, 0).contiguous() + min_value = (dmin.unsqueeze(-1) * mins.to(torch.float16)).permute(1, 2, 0).contiguous() + qs = blocks[..., 16:].permute(1, 2, 0).contiguous() + return { + "qweight_ptr": self.qweight.data_ptr(), + "qs": qs, + "scale": scale, + "min": min_value, + } + + if self.gguf_tensor_qtype == "Q5_K": + d = blocks[..., :2].contiguous().view(torch.float16).squeeze(-1) + dmin = blocks[..., 2:4].contiguous().view(torch.float16).squeeze(-1) + sc, mins = _unpack_q4_k_scale_min_torch(blocks[..., 4:16]) + scale = (d.unsqueeze(-1) * sc.to(torch.float16)).permute(1, 2, 0).contiguous() + min_value = (dmin.unsqueeze(-1) * mins.to(torch.float16)).permute(1, 2, 0).contiguous() + qh = blocks[..., 16:48].permute(1, 2, 0).contiguous() + qs = blocks[..., 48:].permute(1, 2, 0).contiguous() + return { + "qweight_ptr": self.qweight.data_ptr(), + "qs": qs, + "qh": qh, + "scale": scale, + "min": min_value, + } + + if self.gguf_tensor_qtype == "Q6_K": + ql = blocks[..., :128].permute(1, 2, 0).contiguous() + qh = blocks[..., 128:192].permute(1, 2, 0).contiguous() + scale = blocks[..., 192:208].contiguous().view(torch.int8).to(torch.float16) + d = blocks[..., 208:210].contiguous().view(torch.float16).squeeze(-1) + scale = (d.unsqueeze(-1) * scale).permute(1, 2, 0).contiguous() + return { + "qweight_ptr": self.qweight.data_ptr(), + "ql": ql, + "qh": qh, + "scale": scale, + } + + raise NotImplementedError(f"Unsupported GGUF Triton qtype: {self.gguf_tensor_qtype}") + + def _get_triton_cache(self, device: torch.device) -> dict[str, Any]: + key = self._triton_cache_key(device) + cached = self._gguf_triton_cache.get(key) + if cached is not None and cached.get("qweight_ptr") == self.qweight.data_ptr(): + return cached + + cached = self._build_triton_cache(device) + self._gguf_triton_cache[key] = cached + return cached + + def _forward_triton(self, x_flat: torch.Tensor) -> torch.Tensor: + if x_flat.device.type != "cuda": + raise RuntimeError( + f"{self.__class__.__name__} only supports CUDA inference. " + "Load GGUF models on CUDA or use BACKEND.GGUF_TORCH for the torch fallback." + ) + + if x_flat.shape[-1] != self.padded_in_features: + x_work = torch.nn.functional.pad(x_flat, (0, self.padded_in_features - x_flat.shape[-1])).contiguous() + else: + x_work = x_flat.contiguous() + + cache = self._get_triton_cache(x_work.device) + + if self.gguf_tensor_qtype == "Q1_0_g128": + if cache.get("use_u32"): + fixed_u32_config = _select_q1_0_g128_u32_fixed_launch_config( + capability=_cuda_device_capability(x_work.device), + rows=x_work.shape[0], + in_features=self.padded_in_features, + out_features=cache["scale"].shape[1], + ) + if fixed_u32_config is not None: + if _use_q1_0_g128_k2048_decode_specialization( + rows=x_work.shape[0], + in_features=self.padded_in_features, + ): + return _launch_q1_0_g128_u32_k2048_fixed_matmul( + x_work, + cache["sign_words"], + cache["scale"], + fixed_config=fixed_u32_config, + ) + return _launch_q1_0_g128_u32_fixed_matmul( + x_work, + cache["sign_words"], + cache["scale"], + fixed_config=fixed_u32_config, + ) + return fused_q1_0_g128_matmul(x_work, cache["sign_bytes"], cache["scale"]) + fixed_decode_config = cache.get("fixed_decode_config") + if fixed_decode_config is not None and x_work.shape[0] == 1: + if _use_q1_0_g128_k2048_decode_specialization( + rows=x_work.shape[0], + in_features=self.padded_in_features, + ): + return _launch_q1_0_g128_k2048_fixed_matmul( + x_work, + cache["sign_bytes"], + cache["scale"], + fixed_config=fixed_decode_config, + ) + return _launch_q1_0_g128_fixed_matmul( + x_work, + cache["sign_bytes"], + cache["scale"], + fixed_config=fixed_decode_config, + ) + if _use_q1_0_g128_k2048_decode_specialization( + rows=x_work.shape[0], + in_features=self.padded_in_features, + ): + return fused_q1_0_g128_k2048_matmul(x_work, cache["sign_bytes"], cache["scale"]) + return fused_q1_0_g128_matmul(x_work, cache["sign_bytes"], cache["scale"]) + if self.gguf_tensor_qtype == "Q4_K": + return fused_q4_k_matmul(x_work, cache["qs"], cache["scale"], cache["min"]) + if self.gguf_tensor_qtype == "Q5_K": + return fused_q5_k_matmul(x_work, cache["qs"], cache["qh"], cache["scale"], cache["min"]) + if self.gguf_tensor_qtype == "Q6_K": + return fused_q6_k_matmul(x_work, cache["ql"], cache["qh"], cache["scale"]) + + raise NotImplementedError(f"Unsupported GGUF Triton qtype: {self.gguf_tensor_qtype}") + + def forward(self, x: torch.Tensor): + original_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + + input_dtype = x_flat.dtype + if input_dtype != torch.float16: + x_work = x_flat.to(dtype=torch.float16) + else: + x_work = x_flat + + output = self._forward_triton(x_work) + + if self.bias is not None: + bias = self.bias + if bias.device != output.device or bias.dtype != output.dtype: + bias = bias.to(device=output.device, dtype=output.dtype) + output = output + bias + + if input_dtype != output.dtype: + output = output.to(dtype=input_dtype) + + if self.adapter: + output = self.adapter.apply(x=x_flat, out=output) + + return output.reshape(original_shape) + + +__all__ = [ + "GGUFTritonKernel", + "_select_q1_0_g128_fixed_launch_config", + "_select_q1_0_g128_u32_fixed_launch_config", + "_select_q1_0_g128_u32_layout", + "_use_q1_0_g128_k2048_decode_specialization", + "fused_q1_0_g128_matmul", + "fused_q1_0_g128_k2048_matmul", + "fused_q1_0_g128_u32_matmul", + "fused_q1_0_g128_u32_k2048_matmul", + "fused_q4_k_matmul", + "fused_q5_k_matmul", + "fused_q6_k_matmul", + "triton_available", +] diff --git a/gptqmodel/nn_modules/qlinear/gptq_pro.py b/gptqmodel/nn_modules/qlinear/gptq_pro.py new file mode 100644 index 000000000..bbea5a7f7 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/gptq_pro.py @@ -0,0 +1,193 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from typing import List, Optional, Tuple + +import torch + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...nn_modules.qlinear import PackableQuantLinear +from ...quantization import FORMAT, METHOD +from ...utils.backend import BACKEND +from ...utils.gptq_pro import ( + _validate_gptq_pro_device_support, + apply_gptq_pro_linear, + ensure_gptq_pro_loaded, + gptq_pro_qweight_to_b_packed, +) +from ...utils.rocm import IS_ROCM + + +class GptqProQuantLinear(PackableQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_PRO] + SUPPORTS_METHODS = [METHOD.GPTQ] + SUPPORTS_FORMATS = {FORMAT.GPTQ: 0, FORMAT.GPTQ_V2: 0} + SUPPORTS_BITS = [4] + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024] + SUPPORTS_DESC_ACT = [False] + SUPPORTS_SYM = [True] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [16] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [8] + + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16] + + REQUIRES_FORMAT_V2 = True + QUANT_TYPE = "gptq_pro" + + def __init__( + self, + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.GPTQ_PRO), + adapter=adapter, + register_buffers=register_buffers, + enable_wf_unsqueeze=False, + **kwargs, + ) + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + try: + ensure_gptq_pro_loaded() + except ImportError as exc: + return False, ImportError(str(exc)) + return True, None + + @classmethod + def validate_device(cls, device: DEVICE): + super().validate_device(device) + if device == DEVICE.CUDA: + if IS_ROCM: + raise NotImplementedError("GPTQ-Pro kernel is not supported on ROCm.") + if not _validate_gptq_pro_device_support(): + raise NotImplementedError("GPTQ-Pro kernel requires compute capability >= 8.0.") + + @classmethod + def _validate( + cls, + bits: int = 4, + group_size: int = 128, + desc_act: bool = False, + sym: bool = True, + pack_dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, + dynamic: Optional[dict] = None, + in_features: int = None, + out_features: int = None, + device: Optional[DEVICE] = None, + trainable: Optional[bool] = None, + adapter: Optional[Adapter] = None, + ) -> Tuple[bool, Optional[Exception]]: + ok, err = super()._validate( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + pack_dtype=pack_dtype, + dtype=dtype, + dynamic=dynamic, + in_features=in_features, + out_features=out_features, + device=device, + trainable=trainable, + adapter=adapter, + ) + if not ok: + return ok, err + + effective_group_size = in_features if (group_size == -1 and in_features is not None) else group_size + if effective_group_size is not None and effective_group_size > 0 and (effective_group_size % 16) != 0: + return False, NotImplementedError( + f"{cls} requires group_size to be a positive multiple of 16: actual group_size = `{effective_group_size}`" + ) + return True, None + + def post_init(self): + ensure_gptq_pro_loaded() + + if self.qweight.device.type != "cuda": + raise ValueError("GPTQ-Pro backend requires CUDA-resident packed weights before post_init().") + + expected_g_idx = torch.arange( + self.in_features, + device=self.g_idx.device, + dtype=self.g_idx.dtype, + ) // self.group_size + if not torch.equal(self.g_idx, expected_g_idx): + raise ValueError("GPTQ-Pro backend only supports sequential g_idx / desc_act=False checkpoints.") + + b_packed = gptq_pro_qweight_to_b_packed(self.qweight) + if "b_packed" not in self._buffers: + self.register_buffer("b_packed", b_packed, persistent=False) + else: + self.b_packed = b_packed + + super().post_init() + + def list_buffers(self) -> List: + buf = super().list_buffers() + if hasattr(self, "b_packed") and self.b_packed is not None: + buf.append(self.b_packed) + return buf + + def forward(self, x: torch.Tensor): + if x.shape[0] == 0: + return torch.empty(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device) + + out_shape = x.shape[:-1] + (self.out_features,) + x = x.reshape(-1, x.shape[-1]) + if x.shape[-1] != self.in_features: + raise ValueError( + f"GPTQ-Pro backend expected input dim {self.in_features}, got {x.shape[-1]}." + ) + + if x.dtype != torch.float16: + x = x.to(torch.float16) + + out = apply_gptq_pro_linear( + input=x.contiguous(), + b_packed=self.b_packed, + scales=self.scales, + group_size=self.group_size, + ) + + if self.bias is not None: + out.add_(self.bias) + + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out.reshape(out_shape) + + +__all__ = ["GptqProQuantLinear"] diff --git a/gptqmodel/nn_modules/qlinear/lookahead.py b/gptqmodel/nn_modules/qlinear/lookahead.py index 1cdb222da..aeb84d36d 100644 --- a/gptqmodel/nn_modules/qlinear/lookahead.py +++ b/gptqmodel/nn_modules/qlinear/lookahead.py @@ -6,11 +6,11 @@ from collections import defaultdict from typing import Iterable, List, Tuple -from .torch import TorchQuantLinear +from .torch import TorchLinear -def configure_lookahead_chain(modules: Iterable[TorchQuantLinear]): - """Wire a sequence of TorchQuantLinear modules for one-step lookahead. +def configure_lookahead_chain(modules: Iterable[TorchLinear]): + """Wire a sequence of TorchLinear modules for one-step lookahead. Each module in *modules* (except the last) will prefetch the next module's dequantized weights the moment it finishes its own forward call. The last @@ -27,7 +27,7 @@ def configure_lookahead_chain(modules: Iterable[TorchQuantLinear]): last.set_lookahead_next(None) -def _clear_existing_links(modules: Iterable[TorchQuantLinear]): +def _clear_existing_links(modules: Iterable[TorchLinear]): for module in modules: module.set_lookahead_next(None) @@ -41,9 +41,9 @@ def configure_default_lookahead(model) -> None: are skipped. """ - ordered_modules: List[Tuple[str, TorchQuantLinear]] = [] + ordered_modules: List[Tuple[str, TorchLinear]] = [] for name, module in model.named_modules(): - if isinstance(module, TorchQuantLinear): + if isinstance(module, TorchLinear): ordered_modules.append((name, module)) if not ordered_modules: diff --git a/gptqmodel/nn_modules/qlinear/machete.py b/gptqmodel/nn_modules/qlinear/machete.py index 94df1ffaa..caa782b41 100644 --- a/gptqmodel/nn_modules/qlinear/machete.py +++ b/gptqmodel/nn_modules/qlinear/machete.py @@ -5,25 +5,27 @@ from __future__ import annotations +import math from typing import List, Optional, Tuple import torch from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import BaseQuantLinear +from ...nn_modules.qlinear import GPTQQuantLinear from ...quantization import FORMAT, METHOD from ...utils.backend import BACKEND from ...utils.logger import setup_logger from ...utils.machete import ( _validate_machete_device_support, check_machete_supports_shape, - gptqmodel_machete_kernels, - machete_import_exception, machete_mm, machete_prepack_B, + machete_runtime_available, + machete_runtime_error, pack_quantized_values_into_int32, query_machete_supported_group_sizes, + query_machete_supported_quant_types, unpack_quantized_values_into_int32, ) from ...utils.marlin import replace_parameter @@ -34,14 +36,14 @@ log = setup_logger() -class MacheteQuantLinear(BaseQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.MACHETE] +class MacheteLinear(GPTQQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_MACHETE] SUPPORTS_METHODS = [METHOD.GPTQ] SUPPORTS_FORMATS = {FORMAT.GPTQ: 100} SUPPORTS_BITS = [4, 8] SUPPORTS_GROUP_SIZE = [-1, 64, 128] SUPPORTS_DESC_ACT = [True, False] - SUPPORTS_SYM = [True] + SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True SUPPORTS_TRAINING = False SUPPORTS_AUTO_PADDING = False @@ -54,13 +56,15 @@ class MacheteQuantLinear(BaseQuantLinear): SUPPORTS_ADAPTERS = [Lora] SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] - REQUIRES_FORMAT_V2 = False + REQUIRES_FORMAT_V2 = True QUANT_TYPE = "machete" TYPE_MAP = { (4, True): scalar_types.uint4b8, (8, True): scalar_types.uint8b128, + (4, False): scalar_types.uint4, + (8, False): scalar_types.uint8, } def __init__( @@ -76,12 +80,6 @@ def __init__( register_buffers: bool = False, adapter: Adapter = None, **kwargs): - if machete_import_exception is not None: - raise ValueError( - "Trying to use the machete backend, but could not import the " - f"C++/CUDA dependencies with the following error: {machete_import_exception}" - ) - if (bits, sym) not in self.TYPE_MAP: raise ValueError(f"Unsupported quantization config: bits={bits}, sym={sym}") @@ -94,7 +92,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.MACHETE), + backend=kwargs.pop("backend", BACKEND.GPTQ_MACHETE), adapter=adapter, register_buffers=False, **kwargs) @@ -122,12 +120,12 @@ def __init__( ) # Scales - scales_rows = self.in_features if self.group_size == -1 else self.in_features // self.group_size + grouped_rows = math.ceil(self.in_features / self.group_size) self.register_parameter( "scales", torch.nn.Parameter( torch.empty( - scales_rows, + grouped_rows, self.out_features, dtype=torch.float16, ), @@ -135,11 +133,16 @@ def __init__( ), ) - # Zero points unused for symmetric GPTQ + # Register GPTQ checkpoint-compatible qzero storage even for symmetric + # configs so Accelerate can bind tensors before post_init(). self.register_parameter( "qzeros", torch.nn.Parameter( - torch.empty(0, dtype=torch.float16), + torch.empty( + grouped_rows, + self.out_features // self.pack_factor, + dtype=self.pack_dtype, + ), requires_grad=False, ), ) @@ -149,7 +152,7 @@ def __init__( else: self.bias = None - self.weight_type = self.TYPE_MAP[(self.bits, sym)] + self.weight_type = self.TYPE_MAP[(self.bits, self.sym)] self.has_zero_points = False # Buffer storing permutation applied to activations (empty when unused) @@ -157,16 +160,12 @@ def __init__( @classmethod def validate_once(cls) -> Tuple[bool, Optional[Exception]]: - if gptqmodel_machete_kernels is None: - return False, ImportError(machete_import_exception) - else: - return True, None + if not machete_runtime_available(): + return False, ImportError(machete_runtime_error()) + return True, None @classmethod def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: - if machete_import_exception is not None: - return False, ImportError(machete_import_exception) - ok, err = cls._validate(**args) if not ok: return ok, err @@ -183,9 +182,11 @@ def validate(cls, **args) -> Tuple[bool, Optional[Exception]]: quant_type = cls.TYPE_MAP.get((bits, sym)) if quant_type is None: return False, ValueError(f"Machete does not support bits={bits}, sym={sym}") + if quant_type not in query_machete_supported_quant_types(zero_points=not sym): + return False, ValueError(f"Machete does not support bits={bits}, sym={sym}") group_size = args.get("group_size") - dtype = args.get("dtype", torch.float16) + dtype = args.get("dtype") or torch.float16 if group_size not in query_machete_supported_group_sizes(dtype): return False, ValueError( f"Machete does not support group_size={group_size} for dtype={dtype}" @@ -200,7 +201,7 @@ def validate_device(cls, device: DEVICE): if IS_ROCM: raise NotImplementedError("Machete kernel is not supported on ROCm.") if not _validate_machete_device_support(): - raise NotImplementedError("Machete kernel requires compute capability >= 9.0.") + raise NotImplementedError(machete_runtime_error()) def post_init(self): device = self.qweight.device @@ -238,18 +239,36 @@ def post_init(self): torch.nn.Parameter(prepacked.contiguous(), requires_grad=False), ) + scales = self.scales.data.contiguous() replace_parameter( self, "scales", - torch.nn.Parameter(self.scales.data.contiguous(), requires_grad=False), + torch.nn.Parameter(scales, requires_grad=False), ) - replace_parameter( - self, - "qzeros", - torch.nn.Parameter(torch.empty(0, dtype=self.scales.dtype, device=device), requires_grad=False), - ) - self.has_zero_points = False + if self.sym: + replace_parameter( + self, + "qzeros", + torch.nn.Parameter( + torch.empty(0, dtype=self.pack_dtype, device=device), + requires_grad=False, + ), + ) + self.has_zero_points = False + else: + qzeros_unpacked = unpack_quantized_values_into_int32( + self.qzeros.data, + self.weight_type, + packed_dim=1, + ) + qzeros_fp = (-1.0 * scales * qzeros_unpacked.to(scales.dtype)).contiguous() + replace_parameter( + self, + "qzeros", + torch.nn.Parameter(qzeros_fp, requires_grad=False), + ) + self.has_zero_points = True if self.bias is not None: self.bias = self.bias.to(device=device) @@ -278,7 +297,13 @@ def forward(self, x: torch.Tensor): if group_scales.dtype != input_2d.dtype: group_scales = group_scales.to(dtype=input_2d.dtype) - group_zeros = self.qzeros if self.has_zero_points and self.qzeros.numel() > 0 else None + if self.has_zero_points: + assert self.qzeros is not None and self.qzeros.numel() > 0, ( + "Asymmetric MacheteLinear requires non-empty qzeros after post_init()." + ) + group_zeros = self.qzeros + else: + group_zeros = None output = machete_mm( a=input_2d, @@ -300,4 +325,4 @@ def forward(self, x: torch.Tensor): return result -__all__ = ["MacheteQuantLinear"] +__all__ = ["MacheteLinear"] diff --git a/gptqmodel/nn_modules/qlinear/machete_awq.py b/gptqmodel/nn_modules/qlinear/machete_awq.py index ecc4e5cdf..1769632a4 100644 --- a/gptqmodel/nn_modules/qlinear/machete_awq.py +++ b/gptqmodel/nn_modules/qlinear/machete_awq.py @@ -16,13 +16,13 @@ from ...utils.logger import setup_logger from ...utils.machete import ( _validate_machete_device_support, - gptqmodel_machete_kernels, - machete_import_exception, machete_mm, machete_prepack_B, + machete_runtime_available, + machete_runtime_error, pack_quantized_values_into_int32, ) -from ...utils.marlin import replace_parameter, unpack_cols +from ...utils.marlin import replace_parameter, replace_tensor, unpack_cols from ...utils.marlin_scalar_type import scalar_types from ...utils.rocm import IS_ROCM @@ -30,8 +30,50 @@ log = setup_logger() -class AwqMacheteQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.MACHETE] +def _undo_awq_interleave(values: torch.Tensor, num_bits: int) -> torch.Tensor: + if num_bits == 4: + undo_interleave = [0, 4, 1, 5, 2, 6, 3, 7] + elif num_bits == 8: + undo_interleave = [0, 2, 1, 3] + else: + raise ValueError(f"Unsupported AWQ num_bits={num_bits}") + + return ( + values.reshape(-1, len(undo_interleave))[:, undo_interleave] + .reshape(values.shape) + .contiguous() + ) + + +def _replace_registered_tensor( + module: torch.nn.Module, + name: str, + new_tensor: torch.Tensor, +) -> None: + if name in module._parameters: + replace_parameter( + module, + name, + torch.nn.Parameter(new_tensor, requires_grad=False), + ) + return + + if name in module._buffers: + current = getattr(module, name) + if ( + current.dtype == new_tensor.dtype + and current.untyped_storage().nbytes() == new_tensor.untyped_storage().nbytes() + ): + replace_tensor(module, name, new_tensor) + else: + module._buffers[name] = new_tensor + return + + raise KeyError(f"{module.__class__.__name__}.{name} is not a registered tensor") + + +class AwqMacheteLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_MACHETE] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMM: 100, FORMAT.MARLIN: 100} SUPPORTS_BITS = [4, 8] @@ -73,12 +115,6 @@ def __init__( adapter: Adapter = None, register_buffers: bool = False, **kwargs): - if machete_import_exception is not None: - raise ValueError( - "Trying to use the machete backend, but could not import the " - f"C++/CUDA dependencies with the following error: {machete_import_exception}" - ) - if bits not in self.TYPE_MAP: raise ValueError(f"Unsupported num_bits = {bits}. Supported: {list(self.TYPE_MAP.keys())}") @@ -91,7 +127,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.MACHETE), + backend=kwargs.pop("backend", BACKEND.AWQ_MACHETE), adapter=adapter, register_buffers=register_buffers, **kwargs) @@ -101,10 +137,9 @@ def __init__( @classmethod def validate_once(cls) -> Tuple[bool, Optional[Exception]]: - if gptqmodel_machete_kernels is None: - return False, ImportError(machete_import_exception) - else: - return True, None + if not machete_runtime_available(): + return False, ImportError(machete_runtime_error()) + return True, None @classmethod def validate_device(cls, device: DEVICE): @@ -125,6 +160,7 @@ def post_init(self): self.in_features, self.out_features, ).to(device=device) + qweight_int = _undo_awq_interleave(qweight_int, self.bits) packed = pack_quantized_values_into_int32( qweight_int, @@ -138,18 +174,10 @@ def post_init(self): b_type=self.weight_type, group_scales_type=self.scales.dtype, ) - replace_parameter( - self, - "qweight", - torch.nn.Parameter(prepacked.contiguous(), requires_grad=False), - ) + _replace_registered_tensor(self, "qweight", prepacked.contiguous()) # Ensure scales are contiguous and resident on the correct device. - replace_parameter( - self, - "scales", - torch.nn.Parameter(self.scales.contiguous(), requires_grad=False), - ) + _replace_registered_tensor(self, "scales", self.scales.contiguous()) # Convert zero-points: unpack columns, then pre-apply scales as expected by machete_mm effective_group_size = self.in_features if self.group_size == -1 else self.group_size @@ -161,14 +189,11 @@ def post_init(self): num_groups, self.out_features, ).to(device=device) + qzeros_unpacked = _undo_awq_interleave(qzeros_unpacked, self.bits) scales = self.scales qzeros_fp = (-1.0 * scales.to(dtype=scales.dtype) * qzeros_unpacked.to(scales.dtype)).contiguous() - replace_parameter( - self, - "qzeros", - torch.nn.Parameter(qzeros_fp, requires_grad=False), - ) + _replace_registered_tensor(self, "qzeros", qzeros_fp) if self.bias is not None: self.bias = self.bias.to(device=device) @@ -203,4 +228,4 @@ def forward(self, x: torch.Tensor): return result -__all__ = ["AwqMacheteQuantLinear"] +__all__ = ["AwqMacheteLinear"] diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py index 730c7ce5c..a6fd5b9e6 100644 --- a/gptqmodel/nn_modules/qlinear/marlin.py +++ b/gptqmodel/nn_modules/qlinear/marlin.py @@ -23,11 +23,13 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import BaseQuantLinear +from ...nn_modules.qlinear import GPTQQuantLinear from ...quantization import FORMAT, METHOD from ...utils.backend import BACKEND +from ...utils.env import env_flag from ...utils.logger import setup_logger from ...utils.marlin import ( + _marlin_capability_supported, _transform_param, apply_gptq_marlin_linear, gptq_marlin_repack, @@ -38,6 +40,8 @@ marlin_permute_bias, marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_runtime_available, + marlin_runtime_error, marlin_sort_g_idx, replace_parameter, ) @@ -48,8 +52,8 @@ log = setup_logger() -class MarlinQuantLinear(BaseQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.MARLIN, BACKEND.MARLIN_FP16] +class MarlinLinear(GPTQQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_MARLIN] SUPPORTS_METHODS = [METHOD.GPTQ] SUPPORTS_FORMATS = {FORMAT.GPTQ: 90, FORMAT.MARLIN: 90} SUPPORTS_BITS = [4, 8] @@ -94,7 +98,8 @@ def __init__( **kwargs): if marlin_import_exception is not None: raise ValueError( - f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}" + "Trying to use the marlin backend, but the runtime requirements were not met: " + f"{marlin_import_exception}" ) # self.original_in_features = in_features @@ -105,6 +110,9 @@ def __init__( # (since we have only one group per output channel) desc_act = False + self.compute_dtype = kwargs.get("dtype") or torch.float16 + self.fp32 = env_flag("GPTQMODEL_MARLIN_USE_FP32", default=True) + super().__init__( bits=bits, group_size=group_size, @@ -114,17 +122,14 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.MARLIN), + backend=kwargs.pop("backend", BACKEND.GPTQ_MARLIN), adapter=adapter, register_buffers=False, # do not register buffers in super() **kwargs) - # toggle fp32 mode depending on MARLIN or MARLIN_FP16 backend - self.fp32 = True if self.backend in [BACKEND.MARLIN, BACKEND.AUTO] else False - if not self.fp32: log.warn.once( - "Kernel: Marlin FP16 mode is activated with reduced accuracy. Use default Marlin model for improved inference quality.") + "Kernel: GPTQMODEL_MARLIN_USE_FP32 is disabled. Marlin will use reduced-precision reduction.") # Determine sharding if marlin_repeat_scales_on_all_ranks(desc_act, @@ -167,7 +172,7 @@ def __init__( torch.empty( scales_and_zp_size, self.out_features, - dtype=torch.float16, + dtype=self.compute_dtype, ), requires_grad=False ), @@ -187,7 +192,7 @@ def __init__( ) if bias: - self.register_buffer("bias", torch.zeros((self.out_features), dtype=torch.float16)) + self.register_buffer("bias", torch.zeros((self.out_features), dtype=self.compute_dtype)) else: self.bias = None @@ -229,16 +234,24 @@ def validate_device(cls, device: DEVICE): raise NotImplementedError("Marlin kernel is not supported on ROCm.") # Directly check capabilities of all currently visible CUDA devices - has_cuda_v8 = all( - torch.cuda.get_device_capability(i)[0] >= 8 + has_supported_cuda = all( + _marlin_capability_supported(*torch.cuda.get_device_capability(i)) for i in range(torch.cuda.device_count()) ) - if not has_cuda_v8: - raise NotImplementedError("Marlin kernel only supports compute capability >= 8.0.") + if not has_supported_cuda: + raise NotImplementedError( + "Marlin kernel only supports compute capability >= 7.5." + ) def post_init(self): device = self.qweight.device + if not marlin_runtime_available(self.compute_dtype): + raise ModuleNotFoundError( + "Marlin torch.ops kernels are not properly installed. Error: " + + marlin_runtime_error(self.compute_dtype) + ) + self.is_k_full = marlin_is_k_full(self.desc_act, is_row_parallel=False) # Allocate marlin workspace. @@ -249,7 +262,8 @@ def transform_w_q(x): perm=self.g_idx_sort_indices, size_k=self.in_features, size_n=self.out_features, - num_bits=self.bits) + num_bits=self.bits, + dtype=self.compute_dtype) return x def transform_w_s(x): @@ -297,6 +311,8 @@ def forward(self, x: torch.Tensor): # make sure scales is synced with x/input if x.dtype != self.scales.dtype: replace_parameter(self, "scales", self.scales.to(dtype=x.dtype)) + if self.bias is not None and self.bias.dtype != x.dtype: + self.bias.data = self.bias.data.to(dtype=x.dtype) out = apply_gptq_marlin_linear( input=x.contiguous() if self.is_lm_head else x, @@ -375,4 +391,4 @@ def dequantize_qzeros(layer): return unpacked_qzeros -__all__ = ["MarlinQuantLinear"] +__all__ = ["MarlinLinear"] diff --git a/gptqmodel/nn_modules/qlinear/marlin_awq.py b/gptqmodel/nn_modules/qlinear/marlin_awq.py index 48c032c38..b8385d63d 100644 --- a/gptqmodel/nn_modules/qlinear/marlin_awq.py +++ b/gptqmodel/nn_modules/qlinear/marlin_awq.py @@ -18,13 +18,15 @@ from ...utils.logger import setup_logger from ...utils.marlin import ( apply_awq_marlin_linear, + awq_marlin_repack, awq_to_marlin_zero_points, - gptqmodel_marlin_kernels, marlin_import_exception, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, + marlin_runtime_available, + marlin_runtime_error, replace_parameter, ) from ...utils.marlin_scalar_type import scalar_types @@ -34,14 +36,14 @@ log = setup_logger() -class AwqMarlinQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.MARLIN] +class AwqMarlinLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_MARLIN] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMM: 90, FORMAT.MARLIN: 90} SUPPORTS_BITS = [4, 8] SUPPORTS_GROUP_SIZE = [-1, 32, 64, 128] SUPPORTS_DESC_ACT = [True, False] - SUPPORTS_SYM = [True] + SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True SUPPORTS_TRAINING = False SUPPORTS_AUTO_PADDING = False @@ -53,7 +55,7 @@ class AwqMarlinQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] REQUIRES_FORMAT_V2 = False @@ -79,6 +81,7 @@ def __init__( register_buffers=False, **kwargs): self.max_par = 8 # partitioning for large inputs + self.compute_dtype = kwargs.get("dtype") or torch.float16 super().__init__( bits=bits, @@ -89,7 +92,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.MARLIN), + backend=kwargs.pop("backend", BACKEND.AWQ_MARLIN), adapter=adapter, register_buffers=False, **kwargs) @@ -124,7 +127,7 @@ def __init__( torch.empty( self.in_features // self.group_size, self.out_features, - dtype=torch.float16, + dtype=self.compute_dtype, ), requires_grad=False ) @@ -135,7 +138,7 @@ def __init__( "bias", torch.zeros( (out_features), - dtype=torch.float16, + dtype=self.compute_dtype, ), ) else: @@ -185,15 +188,22 @@ def validate_device(cls, device: DEVICE): def post_init(self): device = self.qweight.device + if not marlin_runtime_available(self.compute_dtype): + raise ModuleNotFoundError( + "Marlin torch.ops kernels are not properly installed. Error: " + + marlin_runtime_error(self.compute_dtype) + ) + # Allocate marlin workspace self.workspace = marlin_make_workspace_new(device) # Repack weights from AWQ format to marlin format. - marlin_qweight = gptqmodel_marlin_kernels.awq_marlin_repack( + marlin_qweight = awq_marlin_repack( self.qweight, self.in_features, self.out_features, - self.bits) + self.bits, + dtype=self.compute_dtype) replace_parameter(self, "qweight", marlin_qweight) # Permute scales from AWQ format to marlin format. @@ -265,4 +275,4 @@ def forward(self, x: torch.Tensor): return out -__all__ = ["AwqMarlinQuantLinear"] +__all__ = ["AwqMarlinLinear"] diff --git a/gptqmodel/nn_modules/qlinear/paroquant.py b/gptqmodel/nn_modules/qlinear/paroquant.py new file mode 100644 index 000000000..fffb23f6e --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/paroquant.py @@ -0,0 +1,297 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# ParoQuant runtime implementation adapted from the ParoQuant paper and public +# project: +# https://arxiv.org/html/2511.10645v2 +# https://github.com/z-lab/paroquant + +"""ParoQuant CUDA-backed quantized linear layer.""" + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...quantization import FORMAT, METHOD +from ...quantization.awq.utils.packing_utils import dequantize_gemm +from ...utils.backend import BACKEND +from ...utils.env import env_flag +from ...utils.paroquant import apply_paroquant_rotation, build_identity_rotation_buffers, is_identity_rotation +from .gemm_awq import FP32_ACCUM, _awq_cuda_gemm_forward +from .torch_awq import AwqTorchLinear + + +# Rotated activations benchmark faster with a shallower K split than generic AWQ. +_PAROQUANT_AWQ_SPLIT_K = 4 +_PAROQUANT_CACHE_RUNTIME_DTYPE = env_flag("GPTQMODEL_PAROQUANT_CACHE_RUNTIME_DTYPE", default=False) +_PAROQUANT_AUTO_CACHE_BF16_RUNTIME_DTYPE = env_flag( + "GPTQMODEL_PAROQUANT_AUTO_CACHE_BF16_RUNTIME_DTYPE", default=True +) +# Cache typed rotation metadata so BF16 runs do not re-cast theta/scales every call. +_PAROQUANT_CACHE_ROTATION_DTYPE = env_flag("GPTQMODEL_PAROQUANT_CACHE_ROTATION_DTYPE", default=False) +_PAROQUANT_AUTO_CACHE_BF16_ROTATION_DTYPE = env_flag( + "GPTQMODEL_PAROQUANT_AUTO_CACHE_BF16_ROTATION_DTYPE", default=True +) + + +class ParoLinear(AwqTorchLinear): + """Run ParoQuant inference by rotating inputs and reusing AWQ packed GEMM.""" + + SUPPORTS_BACKENDS = [BACKEND.PAROQUANT_CUDA] + SUPPORTS_METHODS = [METHOD.PARO] + SUPPORTS_FORMATS = {FORMAT.PAROQUANT: 55} + SUPPORTS_BITS = [4] + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = True + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + + SUPPORTS_DEVICES = [DEVICE.ALL] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + REQUIRES_FORMAT_V2 = False + QUANT_TYPE = "awq_paroquant" + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = False, + krot: int = 8, + fp32_accum: bool = FP32_ACCUM, + cache_runtime_dtype: bool = _PAROQUANT_CACHE_RUNTIME_DTYPE, + auto_cache_bf16_runtime_dtype: bool = _PAROQUANT_AUTO_CACHE_BF16_RUNTIME_DTYPE, + cache_rotation_dtype: bool = _PAROQUANT_CACHE_ROTATION_DTYPE, + auto_cache_bf16_rotation_dtype: bool = _PAROQUANT_AUTO_CACHE_BF16_ROTATION_DTYPE, + **kwargs, + ): + """Initialize AWQ buffers plus the extra ParoQuant rotation state.""" + self.krot = int(krot) + if self.krot <= 0: + raise ValueError(f"ParoLinear: `krot` must be positive, got {krot}.") + self.fp32_accum = bool(fp32_accum) + self.cache_runtime_dtype = bool(cache_runtime_dtype) + self.auto_cache_bf16_runtime_dtype = bool(auto_cache_bf16_runtime_dtype) + self.cache_rotation_dtype = bool(cache_rotation_dtype) + self.auto_cache_bf16_rotation_dtype = bool(auto_cache_bf16_rotation_dtype) + self._rotation_runtime_dtype: Optional[torch.dtype] = None + self._rotation_runtime_device: Optional[torch.device] = None + self._runtime_theta: Optional[torch.Tensor] = None + self._runtime_channel_scales: Optional[torch.Tensor] = None + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + register_buffers=register_buffers, + backend=kwargs.pop("backend", BACKEND.PAROQUANT_CUDA), + **kwargs, + ) + self._register_rotation_buffers() + self._rotation_identity = True + + def _register_rotation_buffers(self) -> None: + """Allocate the per-layer buffers that encode runtime rotations.""" + # Fresh runtime modules must start from a valid identity matching so the + # fused kernel never sees duplicate pair indices before optimized + # buffers are loaded from quantization or checkpoints. + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=self.in_features, + group_size=self.group_size, + krot=self.krot, + dtype=torch.float16, + ) + + if "theta" not in self._buffers: + self.register_buffer("theta", theta) + else: + self.theta = theta + + if "pairs" not in self._buffers: + self.register_buffer("pairs", pairs) + else: + self.pairs = pairs + + if "channel_scales" not in self._buffers: + self.register_buffer("channel_scales", channel_scales) + else: + self.channel_scales = channel_scales + + def post_init(self): + """Refresh cached runtime state after weights or rotation buffers change.""" + super().post_init() + self._clear_rotation_runtime_cache() + self._rotation_identity = is_identity_rotation(self.theta, self.channel_scales) + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + """ParoQuant relies on AWQ validation and needs no extra one-time checks here.""" + return True, None + + def extra_repr(self) -> str: + """Expose ParoQuant-specific fields in `repr(module)` for debugging.""" + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"bias={self.bias is not None}, bits={self.bits}, group_size={self.group_size}, " + f"krot={self.krot}, awq_split_k={_PAROQUANT_AWQ_SPLIT_K}, " + f"cache_runtime_dtype={self.cache_runtime_dtype}, " + f"auto_cache_bf16={self.auto_cache_bf16_runtime_dtype}, " + f"cache_rotation_dtype={self.cache_rotation_dtype}, " + f"auto_cache_bf16_rotation={self.auto_cache_bf16_rotation_dtype}, " + f"fp32_accum={self.fp32_accum}" + ) + + def _clear_rotation_runtime_cache(self) -> None: + self._rotation_runtime_dtype = None + self._rotation_runtime_device = None + self._runtime_theta = None + self._runtime_channel_scales = None + + def _ensure_runtime_dtype(self, device: torch.device, dtype: torch.dtype) -> None: + if self.scales is not None and (self.scales.device != device or self.scales.dtype != dtype or not self.scales.is_contiguous()): + self.scales = self.scales.to(device=device, dtype=dtype).contiguous() + if self.bias is not None and (self.bias.device != device or self.bias.dtype != dtype or not self.bias.is_contiguous()): + self.bias = self.bias.to(device=device, dtype=dtype).contiguous() + + def _ensure_rotation_runtime_dtype( + self, + device: torch.device, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + if ( + self._rotation_runtime_device != device + or self._rotation_runtime_dtype != dtype + or self._runtime_theta is None + or self._runtime_channel_scales is None + or not self._runtime_theta.is_contiguous() + or not self._runtime_channel_scales.is_contiguous() + ): + self._runtime_theta = self.theta.to(device=device, dtype=dtype).contiguous() + self._runtime_channel_scales = self.channel_scales.to(device=device, dtype=dtype).contiguous() + self._rotation_runtime_device = device + self._rotation_runtime_dtype = dtype + return self._runtime_theta, self._runtime_channel_scales + + def _rotate_inputs(self, x_flat: torch.Tensor) -> torch.Tensor: + """Apply the learned input transform before quantized matmul.""" + if self._rotation_identity: + return x_flat + use_cached_rotation_dtype = self.cache_rotation_dtype or ( + self.auto_cache_bf16_rotation_dtype and x_flat.dtype == torch.bfloat16 + ) + theta = self.theta + channel_scales = self.channel_scales + if use_cached_rotation_dtype: + theta, channel_scales = self._ensure_rotation_runtime_dtype(x_flat.device, x_flat.dtype) + return apply_paroquant_rotation( + x_flat, + self.pairs, + theta, + scales=channel_scales, + group_size=self.group_size, + ) + + def _forward_dense(self, x_flat: torch.Tensor) -> torch.Tensor: + """Fallback reference path: dequantize AWQ weights and run dense matmul.""" + weight = dequantize_gemm( + qweight=self.qweight, + qzeros=self.qzeros, + scales=self.scales, + bits=self.bits, + group_size=self.group_size, + ) + if weight.dtype != x_flat.dtype or weight.device != x_flat.device: + weight = weight.to(device=x_flat.device, dtype=x_flat.dtype) + + out = torch.matmul(x_flat, weight) + if self.bias is not None: + out = out + self.bias.to(device=x_flat.device, dtype=x_flat.dtype) + return out + + def _forward_cuda_awq_kernel(self, x_flat: torch.Tensor) -> Optional[torch.Tensor]: + """Fast path that feeds rotated activations into the AWQ CUDA GEMM kernel.""" + if x_flat.device.type != "cuda": + return None + + compute_dtype = x_flat.dtype if x_flat.dtype in (torch.float16, torch.bfloat16) else torch.float16 + kernel_input = ( + x_flat + if x_flat.dtype == compute_dtype and x_flat.is_contiguous() + else x_flat.to(device=x_flat.device, dtype=compute_dtype).contiguous() + ) + use_cached_runtime_dtype = self.cache_runtime_dtype or ( + self.auto_cache_bf16_runtime_dtype and compute_dtype == torch.bfloat16 + ) + if use_cached_runtime_dtype: + self._ensure_runtime_dtype(kernel_input.device, compute_dtype) + kernel_scales = self.scales + kernel_bias = self.bias + else: + kernel_scales = self.scales + if ( + kernel_scales.device != kernel_input.device + or kernel_scales.dtype != compute_dtype + or not kernel_scales.is_contiguous() + ): + kernel_scales = kernel_scales.to(device=kernel_input.device, dtype=compute_dtype).contiguous() + kernel_bias = self.bias + if ( + kernel_bias is not None + and (kernel_bias.device != kernel_input.device or kernel_bias.dtype != compute_dtype or not kernel_bias.is_contiguous()) + ): + kernel_bias = kernel_bias.to(device=kernel_input.device, dtype=compute_dtype).contiguous() + out = _awq_cuda_gemm_forward( + kernel_input.reshape(-1, kernel_input.shape[-1]), + self.qweight, + kernel_scales, + self.qzeros, + _PAROQUANT_AWQ_SPLIT_K, + fp32_accum=self.fp32_accum, + ) + if kernel_bias is not None: + out = out + kernel_bias + if out.dtype != x_flat.dtype: + out = out.to(dtype=x_flat.dtype) + return out + + def forward(self, x: torch.Tensor): + """Rotate inputs, run quantized matmul, then apply adapters in input space.""" + original_shape = x.shape[:-1] + (self.out_features,) + x_flat = x.reshape(-1, x.shape[-1]) + rotated = self._rotate_inputs(x_flat) + + out = self._forward_cuda_awq_kernel(rotated) + if out is None: + out = self._forward_dense(rotated) + + if self.adapter: + out = self.adapter.apply(x=x_flat, out=out) + + return out.reshape(original_shape) + + +__all__ = ["ParoLinear"] diff --git a/gptqmodel/nn_modules/qlinear/paroquant_triton.py b/gptqmodel/nn_modules/qlinear/paroquant_triton.py new file mode 100644 index 000000000..19ad9c753 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/paroquant_triton.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# ParoQuant Triton runtime implementation adapted from the ParoQuant paper and +# public project: +# https://arxiv.org/html/2511.10645v2 +# https://github.com/z-lab/paroquant + +"""ParoQuant Triton-backed quantized linear layer.""" + +from __future__ import annotations + +import os +import time +from typing import Optional, Tuple + +import torch + +from ...models._const import DEVICE, PLATFORM +from ...quantization import FORMAT +from ...quantization.paroquant.modules.triton.gemm import ( + paroquant_dequantize_triton, + paroquant_gemm_triton_decode, + paroquant_gemm_triton_prefill, +) +from ...utils import has_gil_disabled +from ...utils.backend import BACKEND +from .paroquant import ParoLinear + + +class ParoQuantTritonLinear(ParoLinear): + """Use Triton fused kernels for ParoQuant prefill/decode execution.""" + + SUPPORTS_BACKENDS = [BACKEND.PAROQUANT_TRITON] + SUPPORTS_METHODS = ParoLinear.SUPPORTS_METHODS + SUPPORTS_FORMATS = {FORMAT.PAROQUANT: 0} + SUPPORTS_BITS = ParoLinear.SUPPORTS_BITS + SUPPORTS_GROUP_SIZE = ParoLinear.SUPPORTS_GROUP_SIZE + SUPPORTS_DESC_ACT = ParoLinear.SUPPORTS_DESC_ACT + SUPPORTS_SYM = ParoLinear.SUPPORTS_SYM + SUPPORTS_SHARDS = ParoLinear.SUPPORTS_SHARDS + SUPPORTS_TRAINING = ParoLinear.SUPPORTS_TRAINING + SUPPORTS_AUTO_PADDING = ParoLinear.SUPPORTS_AUTO_PADDING + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = ParoLinear.SUPPORTS_IN_FEATURES_DIVISIBLE_BY + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = ParoLinear.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY + SUPPORTS_DEVICES = [DEVICE.CUDA] + SUPPORTS_PLATFORM = [PLATFORM.LINUX, PLATFORM.WIN32] + SUPPORTS_PACK_DTYPES = ParoLinear.SUPPORTS_PACK_DTYPES + SUPPORTS_ADAPTERS = ParoLinear.SUPPORTS_ADAPTERS + SUPPORTS_DTYPES = ParoLinear.SUPPORTS_DTYPES + QUANT_TYPE = "awq_paroquant_triton" + PAROQUANT_TRITON_AUTOTUNE = os.environ.get("GPTQMODEL_PAROQUANT_TRITON_AUTOTUNE", "1") != "0" + PAROQUANT_TRITON_AUTOTUNE_WARMUP = max(0, int(os.environ.get("GPTQMODEL_PAROQUANT_TRITON_AUTOTUNE_WARMUP", "2"))) + PAROQUANT_TRITON_AUTOTUNE_ITERS = max(1, int(os.environ.get("GPTQMODEL_PAROQUANT_TRITON_AUTOTUNE_ITERS", "4"))) + PAROQUANT_TRITON_AUTOTUNE_MARGIN = max(0.0, float(os.environ.get("GPTQMODEL_PAROQUANT_TRITON_AUTOTUNE_MARGIN", "0.05"))) + PAROQUANT_TRITON_DECODE_MAX_ROWS = max(1, int(os.environ.get("GPTQMODEL_PAROQUANT_TRITON_DECODE_MAX_ROWS", "8"))) + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + """Validate the Triton and CUDA runtime prerequisites once per process.""" + from packaging import version + from triton import __version__ as triton_version + + triton_v = version.parse(triton_version) + + if triton_v < version.parse("2.0.0"): + raise ImportError(f"triton version must be >= 2.0.0: actual = {triton_version}") + + if has_gil_disabled() and triton_v < version.parse("3.4.0"): + raise Exception("GIL is disabled and not compatible with current Triton. Please upgrade to Triton >= 3.4.0") + + if not torch.cuda.is_available(): + raise RuntimeError("ParoQuant Triton requires CUDA.") + + return True, None + + def __init__(self, *args, **kwargs): + """Initialize Triton autotune settings and the per-shape plan cache.""" + kwargs.setdefault("backend", BACKEND.PAROQUANT_TRITON) + super().__init__(*args, **kwargs) + self.paroquant_triton_autotune_enabled = self.PAROQUANT_TRITON_AUTOTUNE + self.paroquant_triton_autotune_warmup = self.PAROQUANT_TRITON_AUTOTUNE_WARMUP + self.paroquant_triton_autotune_iters = self.PAROQUANT_TRITON_AUTOTUNE_ITERS + self.paroquant_triton_autotune_margin = self.PAROQUANT_TRITON_AUTOTUNE_MARGIN + self.paroquant_triton_decode_max_rows = self.PAROQUANT_TRITON_DECODE_MAX_ROWS + self._plan_cache: dict[str, str] = {} + + def post_init(self): + """Defer to the shared ParoQuant setup without forcing compute dtype.""" + super().post_init() + + def clear_autotune(self): + """Drop cached plan decisions after major module state changes.""" + super().clear_autotune() + self._plan_cache = {} + + @staticmethod + def _sync_benchmark_device(device: torch.device) -> None: + """Synchronize CUDA timing measurements used by autotune.""" + if device.type == "cuda": + torch.cuda.synchronize(device=device) + + def _classify_forward_kind(self, x: torch.Tensor, x_flat: torch.Tensor) -> str: + """Classify the workload as decode or prefill for plan selection.""" + if x.dim() >= 3 and x.shape[-2] == 1 and x_flat.shape[0] <= self.paroquant_triton_decode_max_rows: + return "decode" + if x_flat.shape[0] <= self.paroquant_triton_decode_max_rows: + return "decode" + return "prefill" + + def _forward_triton_dense(self, rotated: torch.Tensor) -> torch.Tensor: + """Dense fallback used when fused Triton plans are unavailable or slower.""" + weight = paroquant_dequantize_triton(self.qweight, self.scales, self.qzeros) + if weight.dtype != rotated.dtype or weight.device != rotated.device: + weight = weight.to(device=rotated.device, dtype=rotated.dtype) + + out = torch.matmul(rotated, weight) + if self.bias is not None: + out = out + self.bias.to(device=rotated.device, dtype=rotated.dtype) + return out + + def _forward_triton_decode(self, rotated: torch.Tensor) -> torch.Tensor: + """Run the fused Triton kernel optimized for small-row decode workloads.""" + out = paroquant_gemm_triton_decode(rotated, self.qweight, self.scales, self.qzeros) + if self.bias is not None: + out = out + self.bias + return out + + def _forward_triton_prefill(self, rotated: torch.Tensor) -> torch.Tensor: + """Run the fused Triton kernel optimized for larger prefill batches.""" + out = paroquant_gemm_triton_prefill(rotated, self.qweight, self.scales, self.qzeros) + if self.bias is not None: + out = out + self.bias + return out + + def _run_plan(self, plan: str, rotated: torch.Tensor) -> torch.Tensor: + """Dispatch one named execution plan.""" + if plan == "dense": + return self._forward_triton_dense(rotated) + if plan == "decode_fused": + return self._forward_triton_decode(rotated) + if plan == "prefill_fused": + return self._forward_triton_prefill(rotated) + raise ValueError(f"Unknown ParoQuant Triton plan: {plan}") + + def _candidate_plans(self, kind: str) -> list[str]: + """Return the execution plans worth benchmarking for a workload class.""" + if kind == "decode": + return ["decode_fused", "dense", "prefill_fused"] + return ["prefill_fused", "dense", "decode_fused"] + + def _benchmark_plan(self, plan: str, rotated: torch.Tensor) -> float: + """Measure mean latency for one candidate plan.""" + with torch.inference_mode(): + for _ in range(self.paroquant_triton_autotune_warmup): + self._run_plan(plan, rotated) + self._sync_benchmark_device(rotated.device) + + start = time.perf_counter() + for _ in range(self.paroquant_triton_autotune_iters): + self._run_plan(plan, rotated) + self._sync_benchmark_device(rotated.device) + + return (time.perf_counter() - start) / self.paroquant_triton_autotune_iters + + def _select_plan(self, kind: str, rotated: torch.Tensor) -> str: + """Pick the best plan, bounded by a bias toward the default fused path.""" + default_plan = "decode_fused" if kind == "decode" else "prefill_fused" + if self.training or not self.paroquant_triton_autotune_enabled: + return default_plan + + cached = self._plan_cache.get(kind) + if cached is not None: + return cached + + try: + timings = {plan: self._benchmark_plan(plan, rotated) for plan in self._candidate_plans(kind)} + best_plan = min(timings, key=timings.get) + best_time = timings[best_plan] + default_time = timings[default_plan] + + if best_plan != default_plan and best_time > default_time * (1.0 - self.paroquant_triton_autotune_margin): + best_plan = default_plan + except Exception: + best_plan = default_plan + + self._plan_cache[kind] = best_plan + return best_plan + + def forward(self, x: torch.Tensor): + """Rotate inputs, pick a Triton plan, and preserve adapter semantics.""" + original_shape = x.shape[:-1] + (self.out_features,) + adapter_input = x.reshape(-1, x.shape[-1]) + x_flat = x.reshape(-1, x.shape[-1]) + rotated = self._rotate_inputs(x_flat) + + plan = self._select_plan(self._classify_forward_kind(x, x_flat), rotated) + out = self._run_plan(plan, rotated) + + if self.adapter: + out = self.adapter.apply(x=adapter_input, out=out) + + return out.reshape(original_shape) + + +__all__ = ["ParoQuantTritonLinear"] diff --git a/gptqmodel/nn_modules/qlinear/qqq.py b/gptqmodel/nn_modules/qlinear/qqq.py index fb56f86e9..714a3bffc 100644 --- a/gptqmodel/nn_modules/qlinear/qqq.py +++ b/gptqmodel/nn_modules/qlinear/qqq.py @@ -13,19 +13,14 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import CPU, DEVICE, PLATFORM -from ...nn_modules.qlinear import BaseQuantLinear +from ...nn_modules.qlinear import GroupedQuantLinear from ...quantization import FORMAT, METHOD from ...utils.backend import BACKEND from ...utils.logger import setup_logger +from ...utils.qqq import qqq_gemm, qqq_runtime_available, qqq_runtime_error from ...utils.rocm import IS_ROCM -qqq_import_exception = None -try: - import gptqmodel_qqq_kernels -except ImportError as e: - qqq_import_exception = str(e) - log = setup_logger() @@ -46,10 +41,12 @@ def mul( @sms: number of SMs to use for the kernel (can usually be left as auto -1) @max_par: maximum number of batch 64 problems to solve in parallel for large input sizes """ - gptqmodel_qqq_kernels.qqq_gemm(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, sms, max_par) + if not qqq_runtime_available(): + raise ModuleNotFoundError("QQQ torch.ops kernels are not properly installed. Error: " + qqq_runtime_error()) + qqq_gemm(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, sms, max_par) -class QQQQuantLinear(BaseQuantLinear): +class QQQLinear(GroupedQuantLinear): SUPPORTS_BACKENDS = [BACKEND.QQQ] SUPPORTS_METHODS = [METHOD.QQQ] SUPPORTS_FORMATS = {FORMAT.QQQ: 100} @@ -115,6 +112,9 @@ def __init__( register_buffers=False, **kwargs) + # QQQ only needs the code range, not packed GPTQ/AWQ storage metadata. + self.maxq = (1 << self.bits) - 1 + # during quantization, we do are not loading tensors from disk so no need to preallocate buffers if register_buffers: self.register_buffer( @@ -154,7 +154,6 @@ def __init__( torch.zeros((self.max_par * 16 * 4, self.out_features), dtype=torch.int), persistent=False, ) - self.wf = torch.tensor(list(range(0, 32, 4)), dtype=torch.int32).unsqueeze(0) if bias: self.register_buffer("bias", torch.zeros((self.out_features), dtype=torch.float16)) else: @@ -208,8 +207,8 @@ def _get_perms(self): @classmethod def validate_once(cls) -> Tuple[bool, Optional[Exception]]: - if qqq_import_exception is not None: - return False, ImportError(qqq_import_exception) + if not qqq_runtime_available(): + return False, ImportError(qqq_runtime_error()) return True, None @classmethod @@ -401,4 +400,4 @@ def forward(self, A): return D.to(dtype=A_dtype) -__all__ = ["QQQQuantLinear"] +__all__ = ["QQQLinear"] diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index 74cb532b9..c53116484 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -32,8 +32,73 @@ log = setup_logger() -class TorchQuantLinear(PackableQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.TORCH] + +class _LinearWeightMetadata: + """Tensor-like metadata shim for integrations that only inspect `weight` attrs.""" + + def __init__(self, module: "TorchLinear", transposed: bool = False): + self._module = module + self._transposed = transposed + + def _shape(self) -> torch.Size: + shape = (self._module.out_features, self._module.in_features) + if self._transposed: + shape = (shape[1], shape[0]) + return torch.Size(shape) + + def _first_tensor(self) -> torch.Tensor | None: + for name in ("qweight", "scales", "bias", "qzeros", "g_idx"): + tensor = getattr(self._module, name, None) + if tensor is not None: + return tensor + return None + + @property + def device(self) -> torch.device: + tensor = self._first_tensor() + return tensor.device if tensor is not None else torch.device("cpu") + + @property + def dtype(self) -> torch.dtype: + for name in ("bias", "scales", "qweight"): + tensor = getattr(self._module, name, None) + if tensor is not None: + return tensor.dtype + return torch.float16 + + @property + def is_cuda(self) -> bool: + return self.device.type == "cuda" + + @property + def ndim(self) -> int: + return 2 + + @property + def shape(self) -> torch.Size: + return self._shape() + + @property + def requires_grad(self) -> bool: + return False + + @property + def T(self) -> "_LinearWeightMetadata": + return _LinearWeightMetadata(self._module, transposed=not self._transposed) + + def size(self, dim: int | None = None): + shape = self._shape() + return shape if dim is None else shape[dim] + + def __repr__(self) -> str: + return ( + f"_LinearWeightMetadata(device={self.device}, dtype={self.dtype}, " + f"shape={tuple(self.shape)})" + ) + + +class TorchLinear(PackableQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_TORCH] SUPPORTS_METHODS = [METHOD.GPTQ] SUPPORTS_FORMATS = {FORMAT.GPTQ: 20, FORMAT.GPTQ_V2: 20} SUPPORTS_BITS = [2, 3, 4, 8] @@ -81,7 +146,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.TORCH), + backend=kwargs.pop("backend", BACKEND.GPTQ_TORCH), adapter=adapter, register_buffers=register_buffers, enable_wf_unsqueeze=kwargs.pop("enable_wf_unsqueeze", True), @@ -111,6 +176,7 @@ def __init__( self._prefetched_weights = {} self._prefetch_events = {} self._prefetch_streams = {} + self._weight_metadata = _LinearWeightMetadata(self) # if self.group_size != self.in_features: # self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) @@ -136,6 +202,10 @@ def post_init(self): self.clear_weight_cache() self._reset_prefetch_state() + @property + def weight(self): + return self._weight_metadata + def dequantize_weight(self, num_itr: int = 1): # Triton dequant currently handles the common single-iteration layout. # Multi-iteration requests (num_itr > 1) are routed to the torch path below. @@ -230,13 +300,21 @@ def _forward(self, x, out_shape): def _forward_eager(self, x: torch.Tensor, out_shape): num_itr = self.g_idx.shape[0] // x.shape[-1] - weights = self._consume_prefetched_weights(x.dtype) + weights = self._consume_prefetched_weights(x.dtype, device=x.device) if weights is None: - weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) + weights = self.dequantize_weight(num_itr=num_itr) + if weights.device != x.device or weights.dtype != x.dtype: + # Quantized modules can be staged on a different accelerator than the + # caller tensor during multi-device kernel validation; matmul still + # needs both operands on the same device and dtype. + weights = weights.to(device=x.device, dtype=x.dtype) self._update_cached_weights(weights) out = torch.matmul(x, weights).reshape(out_shape) if self.bias is not None: - out.add_(self.bias) + bias = self.bias + if bias.device != out.device or bias.dtype != out.dtype: + bias = bias.to(device=out.device, dtype=out.dtype) + out.add_(bias) if self.adapter: out = self.adapter.apply(x=x, out=out) @@ -320,13 +398,15 @@ def _update_cached_weights(self, weights: torch.Tensor): return self._cached_weights[weights.dtype] = weights.detach() - def _consume_prefetched_weights(self, dtype: torch.dtype): + def _consume_prefetched_weights(self, dtype: torch.dtype, device: torch.device = None): if not self._lookahead_enabled or self.training: return None tensor = self._prefetched_weights.pop(dtype, None) if tensor is None: return None event = self._prefetch_events.pop(dtype, None) + if device is not None and tensor.device != device: + return None if event is not None and HAS_CUDA and tensor.device.type == "cuda": torch.cuda.current_stream(device=tensor.device).wait_event(event) return tensor @@ -473,13 +553,13 @@ def enable_lookahead(self, enabled: bool = True): self._reset_prefetch_state() return self - def set_lookahead_next(self, module: "TorchQuantLinear"): + def set_lookahead_next(self, module: "TorchLinear"): if module is None: self._lookahead_next = None self._reset_prefetch_state() return self - if isinstance(module, TorchQuantLinear): + if isinstance(module, TorchLinear): self._lookahead_next = module return self @@ -490,12 +570,12 @@ def set_lookahead_next(self, module: "TorchQuantLinear"): self._reset_prefetch_state() return self for target in targets: - if not isinstance(target, TorchQuantLinear): - raise TypeError("lookahead targets must be TorchQuantLinear modules or None") + if not isinstance(target, TorchLinear): + raise TypeError("lookahead targets must be TorchLinear modules or None") self._lookahead_next = targets return self - raise TypeError("lookahead target must be TorchQuantLinear, iterable of TorchQuantLinear, or None") + raise TypeError("lookahead target must be TorchLinear, iterable of TorchLinear, or None") def _reset_prefetch_state(self): for event in self._prefetch_events.values(): @@ -619,13 +699,13 @@ def _dequantize_weight_cached_248(self, num_itr: int = 1) -> torch.Tensor: def dequantize_model(model: PreTrainedModel): for name, module in model.named_modules(): - if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchQuantLinear): + if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchLinear): raise ValueError( - "Only models loaded using TorchQuantLinear are supported for dequantization. " - "Please load model using backend=BACKEND.TORCH." + "Only models loaded using TorchLinear are supported for dequantization. " + "Please load model using backend=BACKEND.GPTQ_TORCH." ) - if isinstance(module, TorchQuantLinear): + if isinstance(module, TorchLinear): # Create a new Linear layer with dequantized weights new_module = nn.Linear(module.in_features, module.out_features) new_module.weight = nn.Parameter(module.dequantize_weight().T.detach().to("cpu", torch.float16)) @@ -645,4 +725,4 @@ def dequantize_model(model: PreTrainedModel): return model -__all__ = ["TorchQuantLinear", "dequantize_model"] +__all__ = ["TorchLinear", "dequantize_model"] diff --git a/gptqmodel/nn_modules/qlinear/torch_aten_kernel.py b/gptqmodel/nn_modules/qlinear/torch_aten_kernel.py new file mode 100644 index 000000000..acd5d4648 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/torch_aten_kernel.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + + +from typing import Optional, Tuple + +import torch +import torch.nn as nn +from transformers import PreTrainedModel + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear +from ...quantization import FORMAT, METHOD +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from .torch_fused import pack_scales_and_zeros + + +log = setup_logger() + + +def _has_local_int4pack_cpu_ops() -> bool: + return ( + hasattr(torch.ops.aten, "_convert_weight_to_int4pack_for_cpu") + and hasattr(torch.ops.aten, "_weight_int4pack_mm_for_cpu") + ) + + +def _cpu_int4pack_zero_offsets( + zero_codes: torch.Tensor, + scales: torch.Tensor, + bits: int, +) -> torch.Tensor: + # aten::_weight_int4pack_mm_for_cpu dequantizes as: + # scale * (signed_code - 2^(bits-1)) + zero_offset + # Convert stored GPTQ zero codes so the fused kernel reproduces: + # scale * (code - zero_code) + zero_center = 1 << (bits - 1) + return (zero_center - zero_codes.to(dtype=scales.dtype)) * scales + + +class TorchAtenLinear(PackableQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_TORCH_ATEN] + SUPPORTS_METHODS = [METHOD.GPTQ] + SUPPORTS_FORMATS = {FORMAT.GPTQ: 110, FORMAT.GPTQ_V2: 110} + SUPPORTS_BITS = [4] + SUPPORTS_GROUP_SIZE = [16, 32, 64, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True, False] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = True + SUPPORTS_AUTO_PADDING = True + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_DEVICES = [DEVICE.CPU] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + REQUIRES_FORMAT_V2 = True + + QUANT_TYPE = "torch_aten_kernel" + + gemm_int4_forward_kernel = None + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.GPTQ_TORCH_ATEN), + adapter=adapter, + register_buffers=register_buffers, + enable_wf_unsqueeze=kwargs.pop("enable_wf_unsqueeze", True), + **kwargs, + ) + + self.linear_mode = None + self.dequant_dtype = torch.int8 + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + if not _has_local_int4pack_cpu_ops(): + cls.gemm_int4_forward_kernel = None + err = ImportError( + "TorchAtenLinear requires aten::_convert_weight_to_int4pack_for_cpu and " + "aten::_weight_int4pack_mm_for_cpu in this PyTorch build." + ) + log.warning(str(err)) + return False, err + + cls.gemm_int4_forward_kernel = staticmethod(torch.ops.aten._weight_int4pack_mm_for_cpu) + return True, None + + def post_init(self): + super().post_init() + self.optimize() + + def optimize(self): + if self.optimized: + return + + super().optimize() + + def _build_ret_idx(self) -> torch.Tensor: + existing = getattr(self, "ret_idx", None) + total = self.g_idx.shape[0] + if isinstance(existing, torch.Tensor) and existing.numel() == total: + return existing + + device = self.g_idx.device + ret_idx = torch.zeros(total, dtype=torch.int32, device=device) + group_size = max(int(self.group_size), 1) + groups = total // group_size + remainder = total % group_size + g_idx = self.g_idx.to(torch.int32) + g_idx_2 = g_idx * group_size + + if remainder > 0: + mask = g_idx == groups + if mask.any(): + g_idx_2[mask] += torch.arange(remainder, device=device, dtype=torch.int32) + + if groups > 0: + base = torch.arange(group_size, device=device, dtype=torch.int32) + for i in range(groups): + mask = g_idx == i + if not mask.any(): + continue + count = int(mask.sum().item()) + g_idx_2[mask] += base[:count] + + ret_idx[g_idx_2] = torch.arange(total, device=device, dtype=torch.int32) + self.ret_idx = ret_idx + return ret_idx + + def train(self, mode: bool = True): + old_train = self.training + if mode == old_train: + return self + + from ...utils.model import convert_gptq_v1_to_v2_format_module + + if self.SUPPORTS_TRAINING_USE_TORCH_KERNEL: + if mode: + if self.qzero_format() == 1: + if not hasattr(self, "qzeros_data_v1"): + self.qzeros_data_v1 = self.qzeros.data.clone() + convert_gptq_v1_to_v2_format_module(self, bits=self.bits, pack_dtype=self.pack_dtype) + self.qzeros_data_v2 = self.qzeros.data + else: + self.qzeros.data = self.qzeros_data_v2 + self.qzero_format(format=2) + else: + if hasattr(self, "qzeros_data_v1"): + self.qzeros.data = self.qzeros_data_v1 + self.qzero_format(format=1) + + return super().train(mode=mode) + + def transform_cpu(self): + self.scales = self.scales.to(torch.bfloat16).contiguous() + + weight = torch.bitwise_and( + torch.bitwise_right_shift( + torch.unsqueeze(self.qweight, 1).expand(-1, self.pack_factor, -1), + self.wf_unsqueeze_neg_one, + ).to(torch.uint8), + self.maxq, + ) + ret_idx = self._build_ret_idx() + weight = ( + weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + .index_select(0, ret_idx) + .t() + .contiguous() + ) + self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(weight.int(), 1).contiguous() + + zero_codes = torch.bitwise_right_shift( + torch.unsqueeze(self.qzeros, 2).expand(-1, -1, self.pack_factor), + self.wf_unsqueeze_zero, + ).to(torch.uint8) + zero_codes = torch.bitwise_and(zero_codes, self.maxq).reshape(self.scales.shape) + self.qzeros = _cpu_int4pack_zero_offsets(zero_codes, self.scales, self.bits).contiguous() + self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + + def transform(self, device): + if device == "cpu": + self.transform_cpu() + else: + raise NotImplementedError + + def forward(self, x: torch.Tensor): + out_shape = x.shape[:-1] + (self.out_features,) + x = x.reshape(-1, x.shape[-1]) + if ( + not self.training + and not x.requires_grad + and self.linear_mode is None + and _has_local_int4pack_cpu_ops() + and x.device.type == "cpu" + ): + self.transform(x.device.type) + self.linear_mode = "inference" + elif self.linear_mode is None: + self.linear_mode = "train" + + if self.linear_mode == "inference": + out = self._fused_op_forward(x).reshape(out_shape) + else: + num_itr = self.g_idx.shape[0] // x.shape[-1] + weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) + out = torch.matmul(x, weights).reshape(out_shape) + + if self.bias is not None: + out.add_(self.bias) + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out + + @torch.no_grad + def _fused_op_forward(self, x): + x = x[:, self.ret_idx].contiguous() + if x.device.type != "cpu": + raise NotImplementedError + + original_dtype = x.dtype + if original_dtype != torch.bfloat16: + x = x.to(torch.bfloat16) + out = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + self.qweight, + self.group_size, + self.scales_and_zeros, + ) + if original_dtype != torch.bfloat16: + out = out.to(original_dtype) + return out + + def _empty_gptq_only_weights(self): + self.qzeros = None + self.qweight = None + self.g_idx = None + self.scales = None + + +def dequantize_model(model: PreTrainedModel): + for name, module in model.named_modules(): + if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchAtenLinear): + raise ValueError( + "Only models loaded using TorchAtenLinear are supported for dequantization. " + "Please load model using backend=BACKEND.GPTQ_TORCH_ATEN" + ) + + if isinstance(module, TorchAtenLinear): + new_module = nn.Linear(module.in_features, module.out_features) + new_module.weight = nn.Parameter(module.dequantize_weight().T.detach().to("cpu", torch.float16)) + new_module.bias = torch.nn.Parameter(module.bias) + + parent = model + if "." in name: + parent_name, module_name = name.rsplit(".", 1) + parent = dict(model.named_modules())[parent_name] + else: + module_name = name + + setattr(parent, module_name, new_module) + + del model.config.quantization_config + return model + + +__all__ = ["TorchAtenLinear", "dequantize_model"] diff --git a/gptqmodel/nn_modules/qlinear/torch_aten_kernel_awq.py b/gptqmodel/nn_modules/qlinear/torch_aten_kernel_awq.py new file mode 100644 index 000000000..bd44c3a38 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/torch_aten_kernel_awq.py @@ -0,0 +1,208 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import math +from typing import Optional, Tuple + +import torch + +from ...adapter.adapter import Adapter +from ...quantization import FORMAT, METHOD +from ...quantization.awq.utils.packing_utils import ( + dequantize_gemm, + reverse_awq_order, + unpack_awq, +) +from ...utils.backend import BACKEND +from ...utils.logger import setup_logger +from . import AWQuantLinear +from .torch_aten_kernel import TorchAtenLinear, _cpu_int4pack_zero_offsets, _has_local_int4pack_cpu_ops +from .torch_fused import pack_scales_and_zeros + + +log = setup_logger() + + +class TorchAtenAwqLinear(AWQuantLinear): + """AWQ CPU int4pack backend implemented with local ATen ops.""" + + QUANT_TYPE = "awq_torch_aten_kernel" + + SUPPORTS_BACKENDS = [BACKEND.AWQ_TORCH_ATEN] + SUPPORTS_METHODS = [METHOD.AWQ] + SUPPORTS_FORMATS = {FORMAT.GEMM: 110} + + SUPPORTS_BITS = TorchAtenLinear.SUPPORTS_BITS + SUPPORTS_GROUP_SIZE = TorchAtenLinear.SUPPORTS_GROUP_SIZE + SUPPORTS_DESC_ACT = TorchAtenLinear.SUPPORTS_DESC_ACT + SUPPORTS_SYM = TorchAtenLinear.SUPPORTS_SYM + SUPPORTS_SHARDS = TorchAtenLinear.SUPPORTS_SHARDS + SUPPORTS_TRAINING = TorchAtenLinear.SUPPORTS_TRAINING + SUPPORTS_AUTO_PADDING = TorchAtenLinear.SUPPORTS_AUTO_PADDING + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = TorchAtenLinear.SUPPORTS_IN_FEATURES_DIVISIBLE_BY + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = TorchAtenLinear.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY + SUPPORTS_DEVICES = TorchAtenLinear.SUPPORTS_DEVICES + SUPPORTS_PLATFORM = TorchAtenLinear.SUPPORTS_PLATFORM + SUPPORTS_PACK_DTYPES = TorchAtenLinear.SUPPORTS_PACK_DTYPES + SUPPORTS_ADAPTERS = TorchAtenLinear.SUPPORTS_ADAPTERS + REQUIRES_FORMAT_V2 = False + + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + + gemm_int4_forward_kernel = None + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = True, + **kwargs, + ): + kwargs.setdefault("backend", BACKEND.AWQ_TORCH_ATEN) + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + adapter=adapter, + register_buffers=False, + **kwargs, + ) + + self.linear_mode = None + + if register_buffers: + pack_cols = max(1, self.out_features // self.pack_factor) + qweight_shape = (self.in_features, pack_cols) + group_size = max(int(self.group_size), 1) + group_rows = max(1, math.ceil(self.in_features / group_size)) + + self.register_buffer( + "qweight", + torch.zeros(qweight_shape, dtype=self.pack_dtype), + ) + + self.register_buffer( + "qzeros", + torch.zeros((group_rows, pack_cols), dtype=self.pack_dtype), + ) + + self.register_buffer( + "scales", + torch.zeros((group_rows, self.out_features), dtype=torch.float16), + ) + + if bias: + self.register_buffer("bias", torch.zeros(self.out_features, dtype=torch.float16)) + else: + self.bias = None + + @classmethod + def validate_once(cls) -> Tuple[bool, Optional[Exception]]: + ok, err = TorchAtenLinear.validate_once() + if ok: + cls.gemm_int4_forward_kernel = TorchAtenLinear.gemm_int4_forward_kernel + else: + cls.gemm_int4_forward_kernel = None + return ok, err + + def post_init(self): + super().post_init() + self.optimize() + + def optimize(self): + if self.optimized: + return + + super().optimize() + + def transform_cpu(self): + iweight, izeros = unpack_awq(self.qweight, self.qzeros, self.bits) + iweight, izeros = reverse_awq_order(iweight, izeros, self.bits) + max_val = (1 << self.bits) - 1 + iweight = torch.bitwise_and(iweight, max_val).to(torch.uint8) + izeros = torch.bitwise_and(izeros, max_val).reshape(self.scales.shape).to(torch.uint8) + + self.scales = self.scales.to(torch.bfloat16).contiguous() + self.qweight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(iweight.t().int(), 1).contiguous() + self.qzeros = _cpu_int4pack_zero_offsets(izeros, self.scales, self.bits).contiguous() + self.scales_and_zeros = pack_scales_and_zeros(self.scales, self.qzeros) + + def transform(self, device): + if device == "cpu": + self.transform_cpu() + else: + raise NotImplementedError( + "TorchAtenAwqLinear only supports fused transforms on CPU devices." + ) + + def awq_weight_dequantize(self, device, dtype): + return dequantize_gemm( + qweight=self.qweight, + qzeros=self.qzeros, + scales=self.scales, + bits=self.bits, + group_size=self.group_size, + ).to(device=device, dtype=dtype) + + @torch.no_grad() + def _fused_op_forward(self, x): + if x.device.type != "cpu": + raise NotImplementedError + + original_dtype = x.dtype + if original_dtype != torch.bfloat16: + x = x.to(torch.bfloat16) + out = torch.ops.aten._weight_int4pack_mm_for_cpu( + x, + self.qweight, + self.group_size, + self.scales_and_zeros, + ) + if original_dtype != torch.bfloat16: + out = out.to(original_dtype) + return out + + def forward(self, x: torch.Tensor): + out_shape = x.shape[:-1] + (self.out_features,) + x = x.reshape(-1, x.shape[-1]) + if ( + not self.training + and not x.requires_grad + and self.linear_mode is None + and _has_local_int4pack_cpu_ops() + and x.device.type == "cpu" + ): + self.transform(x.device.type) + self.linear_mode = "inference" + elif self.linear_mode is None: + self.linear_mode = "train" + + if self.linear_mode == "inference": + out = self._fused_op_forward(x).reshape(out_shape) + else: + weight = self.awq_weight_dequantize(device=x.device, dtype=x.dtype) + out = torch.matmul(x, weight).reshape(out_shape) + + if self.bias is not None: + out.add_(self.bias) + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out + + +__all__ = ["TorchAtenAwqLinear"] diff --git a/gptqmodel/nn_modules/qlinear/torch_awq.py b/gptqmodel/nn_modules/qlinear/torch_awq.py index 15748c777..52511f9c9 100644 --- a/gptqmodel/nn_modules/qlinear/torch_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_awq.py @@ -17,8 +17,8 @@ log = setup_logger() -class AwqTorchQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.TORCH_AWQ] +class AwqTorchLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_TORCH] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMM: 10} SUPPORTS_BITS = [4] @@ -36,7 +36,7 @@ class AwqTorchQuantLinear(AWQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16] + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] REQUIRES_FORMAT_V2 = False @@ -56,6 +56,7 @@ def __init__( register_buffers: bool = False, **kwargs, ): + self.compute_dtype = kwargs.get("dtype") or torch.float16 super().__init__( bits=bits, group_size=group_size, @@ -65,13 +66,22 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.TORCH_AWQ), + backend=kwargs.pop("backend", BACKEND.AWQ_TORCH), adapter=adapter, register_buffers=register_buffers, **kwargs, ) + if register_buffers: + if self.scales is not None and self.scales.dtype != self.compute_dtype: + self.scales = self.scales.to(dtype=self.compute_dtype) + if self.bias is not None and self.bias.dtype != self.compute_dtype: + self.bias = self.bias.to(dtype=self.compute_dtype) def post_init(self): + if self.scales is not None and self.scales.dtype not in (torch.float16, torch.bfloat16): + self.scales = self.scales.to(dtype=torch.float16) + if self.bias is not None and self.bias.dtype not in (torch.float16, torch.bfloat16): + self.bias = self.bias.to(dtype=torch.float16) super().post_init() def extra_repr(self) -> str: @@ -80,10 +90,82 @@ def extra_repr(self) -> str: f"bias={self.bias is not None}, bits={self.bits}, group_size={self.group_size}" ) + def pack(self, linear: torch.nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor = None): + del g_idx + assert scales is not None and zeros is not None + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + + scale_dtype = scales.dtype if scales.dtype in (torch.float16, torch.bfloat16) else torch.float16 + self.register_buffer("scales", scales.clone().to(scale_dtype)) + if linear.bias is not None: + bias_dtype = linear.bias.dtype if linear.bias.dtype in (torch.float16, torch.bfloat16) else scale_dtype + self.register_buffer("bias", linear.bias.clone().to(bias_dtype)) + else: + self.bias = None + + pack_num = 32 // self.bits + + intweight = [] + for idx in range(self.in_features): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[idx // self.group_size]) + / self.scales[idx // self.group_size] + ).to(torch.int32)[:, None] + ) + intweight = torch.cat(intweight, dim=1).t().contiguous() + + qweight = torch.zeros( + (intweight.shape[0], intweight.shape[1] // 32 * self.bits), + dtype=torch.int32, + device=intweight.device, + ) + qzeros = torch.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), + dtype=torch.int32, + device=zeros.device, + ) + + if self.bits != 4: + raise NotImplementedError("Only 4-bit are supported for now.") + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + + for col in range(intweight.shape[1] // pack_num): + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * self.bits) + + for col in range(zeros.shape[1] // pack_num): + for i in range(pack_num): + qzero_col = zeros[:, col * pack_num + order_map[i]].to(torch.int32) + qzeros[:, col] |= qzero_col << (i * self.bits) + + self.register_buffer("qweight", qweight) + self.register_buffer("qzeros", qzeros) + + def _ensure_runtime_dtype(self, *, device: torch.device, dtype: torch.dtype) -> None: + if self.scales.device != device or self.scales.dtype != dtype or not self.scales.is_contiguous(): + self.scales = self.scales.to(device=device, dtype=dtype).contiguous() + if self.bias is not None and ( + self.bias.device != device or self.bias.dtype != dtype or not self.bias.is_contiguous() + ): + self.bias = self.bias.to(device=device, dtype=dtype).contiguous() + def forward(self, x: torch.Tensor): + input_dtype = x.dtype + compute_dtype = input_dtype if input_dtype in (torch.float16, torch.bfloat16) else torch.float16 original_shape = x.shape[:-1] + (self.out_features,) device = x.device x_flat = x.reshape(-1, x.shape[-1]) + if x_flat.dtype != compute_dtype or x_flat.device != device: + x_flat = x_flat.to(device=device, dtype=compute_dtype) + elif not x_flat.is_contiguous(): + x_flat = x_flat.contiguous() + + self._ensure_runtime_dtype(device=device, dtype=compute_dtype) weight = dequantize_gemm( qweight=self.qweight, @@ -92,9 +174,10 @@ def forward(self, x: torch.Tensor): bits=self.bits, group_size=self.group_size, ) - assert weight.dtype == torch.float16, f"weight {weight.dtype} is not float16" - if weight.dtype != x_flat.dtype or weight.device != device: - weight = weight.to(device=device, dtype=x_flat.dtype) + if weight.dtype not in (torch.float16, torch.bfloat16): + raise AssertionError(f"weight {weight.dtype} is not float16 or bfloat16") + if weight.dtype != compute_dtype or weight.device != device or not weight.is_contiguous(): + weight = weight.to(device=device, dtype=compute_dtype).contiguous() output = torch.matmul(x_flat, weight) @@ -104,8 +187,11 @@ def forward(self, x: torch.Tensor): if self.adapter: output = self.adapter.apply(x=x_flat, out=output) + if output.dtype != input_dtype: + output = output.to(dtype=input_dtype) + output = output.reshape(original_shape) return output -__all__ = ["AwqTorchQuantLinear"] +__all__ = ["AwqTorchLinear"] diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index 36bb0e19a..d191e375f 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -51,8 +51,8 @@ def forward(self, x): return out -class TorchFusedQuantLinear(PackableQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.TORCH_FUSED] +class TorchFusedLinear(PackableQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_TORCH_FUSED] SUPPORTS_METHODS = [METHOD.GPTQ] SUPPORTS_FORMATS = {FORMAT.GPTQ: 50, FORMAT.GPTQ_V2: 50} SUPPORTS_BITS = [4] @@ -69,7 +69,7 @@ class TorchFusedQuantLinear(PackableQuantLinear): SUPPORTS_PACK_DTYPES = [torch.int32] SUPPORTS_ADAPTERS = [Lora] - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + SUPPORTS_DTYPES = [torch.float32, torch.float16, torch.bfloat16] REQUIRES_FORMAT_V2 = True @@ -99,7 +99,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.TORCH), + backend=kwargs.pop("backend", BACKEND.GPTQ_TORCH_FUSED), adapter=adapter, register_buffers=register_buffers, enable_wf_unsqueeze=kwargs.pop("enable_wf_unsqueeze", True), @@ -236,6 +236,9 @@ def transform(self, dtype, device): def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x = x.reshape(-1, x.shape[-1]) + input_dtype = x.dtype + if input_dtype == torch.float32: + x = x.to(torch.bfloat16) if not self.training and not x.requires_grad and self.linear_mode is None and TORCH_HAS_FUSED_OPS: # one-time transform per module for xpu aten fused ops self.transform(x.dtype, x.device.type) @@ -268,6 +271,9 @@ def forward(self, x: torch.Tensor): if self.adapter: out = self.adapter.apply(x=x, out=out) + if input_dtype == torch.float32: + out = out.to(torch.float32) + return out @torch.no_grad @@ -295,13 +301,13 @@ def _empty_gptq_only_weights(self): def dequantize_model(model: PreTrainedModel): for name, module in model.named_modules(): - if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchFusedQuantLinear): + if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchFusedLinear): raise ValueError( - "Only models loaded using TorchFusedQuantLinear are supported for dequantization. " - "Please load model using backend=BACKEND.TORCH_FUSED" + "Only models loaded using TorchFusedLinear are supported for dequantization. " + "Please load model using backend=BACKEND.GPTQ_TORCH_FUSED" ) - if isinstance(module, TorchFusedQuantLinear): + if isinstance(module, TorchFusedLinear): # Create a new Linear layer with dequantized weights new_module = nn.Linear(module.in_features, module.out_features) new_module.weight = nn.Parameter(module.dequantize_weight().T.detach().to("cpu", torch.float16)) @@ -321,4 +327,4 @@ def dequantize_model(model: PreTrainedModel): return model -__all__ = ["TorchFusedQuantLinear", "dequantize_model"] +__all__ = ["TorchFusedLinear", "dequantize_model"] diff --git a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py index d76584646..28dad3076 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused_awq.py @@ -18,38 +18,39 @@ from ...utils.backend import BACKEND from ...utils.logger import setup_logger from ...utils.torch import TORCH_HAS_FUSED_OPS -from .torch_fused import Int4PackedOp, TorchFusedQuantLinear, pack_scales_and_zeros +from . import AWQuantLinear +from .torch_fused import Int4PackedOp, TorchFusedLinear, pack_scales_and_zeros log = setup_logger() -class TorchFusedAwqQuantLinear(TorchFusedQuantLinear): +class TorchFusedAwqLinear(AWQuantLinear): """Torch fused AWQ variant based on GPTQ fused kernels via CPU int4 packing.""" QUANT_TYPE = "torch_fused_awq" - SUPPORTS_BACKENDS = [BACKEND.TORCH_FUSED_AWQ] + SUPPORTS_BACKENDS = [BACKEND.AWQ_TORCH_FUSED] SUPPORTS_METHODS = [METHOD.AWQ] SUPPORTS_FORMATS = {FORMAT.GEMM: 20} # inherit from torch fused - SUPPORTS_BITS = TorchFusedQuantLinear.SUPPORTS_BITS - SUPPORTS_GROUP_SIZE = TorchFusedQuantLinear.SUPPORTS_GROUP_SIZE - SUPPORTS_DESC_ACT = TorchFusedQuantLinear.SUPPORTS_DESC_ACT - SUPPORTS_SYM = TorchFusedQuantLinear.SUPPORTS_SYM - SUPPORTS_SHARDS = TorchFusedQuantLinear.SUPPORTS_SHARDS - SUPPORTS_TRAINING = TorchFusedQuantLinear.SUPPORTS_TRAINING - SUPPORTS_AUTO_PADDING = TorchFusedQuantLinear.SUPPORTS_AUTO_PADDING - SUPPORTS_IN_FEATURES_DIVISIBLE_BY = TorchFusedQuantLinear.SUPPORTS_IN_FEATURES_DIVISIBLE_BY - SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = TorchFusedQuantLinear.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY - SUPPORTS_DEVICES = TorchFusedQuantLinear.SUPPORTS_DEVICES - SUPPORTS_PLATFORM = TorchFusedQuantLinear.SUPPORTS_PLATFORM - SUPPORTS_PACK_DTYPES = TorchFusedQuantLinear.SUPPORTS_PACK_DTYPES - SUPPORTS_ADAPTERS = TorchFusedQuantLinear.SUPPORTS_ADAPTERS - REQUIRES_FORMAT_V2 = TorchFusedQuantLinear.REQUIRES_FORMAT_V2 - - SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + SUPPORTS_BITS = TorchFusedLinear.SUPPORTS_BITS + SUPPORTS_GROUP_SIZE = TorchFusedLinear.SUPPORTS_GROUP_SIZE + SUPPORTS_DESC_ACT = TorchFusedLinear.SUPPORTS_DESC_ACT + SUPPORTS_SYM = TorchFusedLinear.SUPPORTS_SYM + SUPPORTS_SHARDS = TorchFusedLinear.SUPPORTS_SHARDS + SUPPORTS_TRAINING = TorchFusedLinear.SUPPORTS_TRAINING + SUPPORTS_AUTO_PADDING = TorchFusedLinear.SUPPORTS_AUTO_PADDING + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = TorchFusedLinear.SUPPORTS_IN_FEATURES_DIVISIBLE_BY + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = TorchFusedLinear.SUPPORTS_OUT_FEATURES_DIVISIBLE_BY + SUPPORTS_DEVICES = TorchFusedLinear.SUPPORTS_DEVICES + SUPPORTS_PLATFORM = TorchFusedLinear.SUPPORTS_PLATFORM + SUPPORTS_PACK_DTYPES = TorchFusedLinear.SUPPORTS_PACK_DTYPES + SUPPORTS_ADAPTERS = TorchFusedLinear.SUPPORTS_ADAPTERS + REQUIRES_FORMAT_V2 = False + + SUPPORTS_DTYPES = [torch.float32, torch.float16, torch.bfloat16] def __init__( self, @@ -65,7 +66,7 @@ def __init__( register_buffers: bool = True, **kwargs, ): - kwargs.setdefault("backend", BACKEND.TORCH_FUSED_AWQ) + kwargs.setdefault("backend", BACKEND.AWQ_TORCH_FUSED) super().__init__( bits=bits, group_size=group_size, @@ -78,10 +79,11 @@ def __init__( adapter=adapter, # Skip base buffer init, we need to manually init buffers for awq register_buffers=False, - enable_wf_unsqueeze=kwargs.pop("enable_wf_unsqueeze", False), **kwargs, ) + self.linear_mode = None + # Create awq buffers if register_buffers: # AWQ packs each input row into pack_factor-wide columns for int4 lanes. @@ -113,6 +115,13 @@ def __init__( def post_init(self): super().post_init() + self.optimize() + + def optimize(self): + if self.optimized: + return + + super().optimize() def prepare_awq_fused_tensors(self, need_zeros: bool = True): self.scales.to(torch.float16).contiguous() @@ -233,7 +242,6 @@ def awq_weight_dequantize(self, device, dtype): scales=self.scales, bits=self.bits, group_size=self.group_size, - sym=self.sym, ).to(device=device, dtype=dtype) def transform(self, dtype, device): @@ -243,13 +251,16 @@ def transform(self, dtype, device): self.transform_xpu_awq(dtype) else: raise NotImplementedError( - "TorchFusedAwqQuantLinear only supports fused transforms on CPU or XPU devices." + "TorchFusedAwqLinear only supports fused transforms on CPU or XPU devices." ) def forward(self, x: torch.Tensor): out_shape = x.shape[:-1] + (self.out_features,) x_flat = x.reshape(-1, x.shape[-1]) - self.assert_supported_dtype(x_flat.dtype) + input_dtype = x_flat.dtype + self.assert_supported_dtype(input_dtype) + if input_dtype == torch.float32: + x_flat = x_flat.to(torch.bfloat16) if ( not self.training and not x_flat.requires_grad @@ -281,6 +292,9 @@ def forward(self, x: torch.Tensor): if self.adapter: out = self.adapter.apply(x=x_flat, out=out) + if input_dtype == torch.float32: + out = out.to(torch.float32) + return out.reshape(out_shape) def assert_supported_dtype(self, dtype: torch.dtype): @@ -291,4 +305,4 @@ def assert_supported_dtype(self, dtype: torch.dtype): ) -__all__ = ["TorchFusedAwqQuantLinear"] +__all__ = ["TorchFusedAwqLinear"] diff --git a/gptqmodel/nn_modules/qlinear/torch_int8.py b/gptqmodel/nn_modules/qlinear/torch_int8.py index b3aea8a46..685bcec9d 100644 --- a/gptqmodel/nn_modules/qlinear/torch_int8.py +++ b/gptqmodel/nn_modules/qlinear/torch_int8.py @@ -13,7 +13,7 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import BaseQuantLinear +from ...nn_modules.qlinear import BaseQuantLinear, GPTQQuantLinear from ...quantization import FORMAT, METHOD from ...utils.backend import BACKEND @@ -84,8 +84,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return torch.ops.aten._weight_int8pack_mm(x, self.int8_weight_nk, self.int8_channel_scale) -class TorchInt8QuantLinear(BaseQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.TORCH_INT8] +class TorchInt8Linear(GPTQQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_TORCH_INT8] SUPPORTS_METHODS = [METHOD.GPTQ] # Keep auto-selection unchanged; this kernel is enabled via explicit backend selection. SUPPORTS_FORMATS = {FORMAT.GPTQ: 0, FORMAT.GPTQ_V2: 0} @@ -134,7 +134,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.TORCH_INT8), + backend=kwargs.pop("backend", BACKEND.GPTQ_TORCH_INT8), adapter=adapter, register_buffers=register_buffers, **kwargs, @@ -180,7 +180,7 @@ def _ensure_unpack_buffers(self): return if self.bits not in [2, 4, 8]: - raise NotImplementedError("TorchInt8QuantLinear unpack only supports bits in [2, 4, 8].") + raise NotImplementedError("TorchInt8Linear unpack only supports bits in [2, 4, 8].") wf = torch.tensor(list(range(0, self.pack_dtype_bits, self.bits)), dtype=torch.int32).unsqueeze(0) device = self.qweight.device @@ -208,10 +208,10 @@ def dequantize_weight(self, num_itr: int = 1): return dequantized if num_itr != 1: - raise NotImplementedError("TorchInt8QuantLinear dequantize_weight only supports num_itr == 1.") + raise NotImplementedError("TorchInt8Linear dequantize_weight only supports num_itr == 1.") if not self._has_all_gptq_buffers(): - raise RuntimeError("TorchInt8QuantLinear missing GPTQ buffers for dequantization.") + raise RuntimeError("TorchInt8Linear missing GPTQ buffers for dequantization.") self._ensure_unpack_buffers() @@ -243,7 +243,7 @@ def pack_block( workers: int = 1, ): raise NotImplementedError( - "TorchInt8QuantLinear is not packable. Load GPTQ int4 tensors and let post_init() convert to int8." + "TorchInt8Linear is not packable. Load GPTQ int4 tensors and let post_init() convert to int8." ) def pack( @@ -256,7 +256,7 @@ def pack( workers: int = 1, ): raise NotImplementedError( - "TorchInt8QuantLinear is not packable. Load GPTQ int4 tensors and let post_init() convert to int8." + "TorchInt8Linear is not packable. Load GPTQ int4 tensors and let post_init() convert to int8." ) def transform_cpu(self, dtype: torch.dtype): @@ -266,14 +266,14 @@ def transform_cpu(self, dtype: torch.dtype): def transform(self, dtype: torch.dtype, device: str): if device != "cpu": - raise NotImplementedError("TorchInt8QuantLinear only supports CPU.") + raise NotImplementedError("TorchInt8Linear only supports CPU.") self.transform_cpu(dtype) def forward(self, x: torch.Tensor): if self.training: - raise NotImplementedError("TorchInt8QuantLinear does not support training mode.") + raise NotImplementedError("TorchInt8Linear does not support training mode.") if self.int8_module is None: - raise RuntimeError("TorchInt8QuantLinear int8 module is not initialized. Ensure post_init() has been called.") + raise RuntimeError("TorchInt8Linear int8 module is not initialized. Ensure post_init() has been called.") # Common decode path is 2D [M, K]. Skip reshape/out-shape overhead on this hot path. if x.dim() == 2: @@ -299,9 +299,9 @@ def forward(self, x: torch.Tensor): @torch.no_grad def _fused_op_forward(self, x: torch.Tensor) -> torch.Tensor: if x.device.type != "cpu": - raise NotImplementedError("TorchInt8QuantLinear fused path is CPU-only.") + raise NotImplementedError("TorchInt8Linear fused path is CPU-only.") if self.int8_module is None: - raise RuntimeError("TorchInt8QuantLinear int8 module is not initialized.") + raise RuntimeError("TorchInt8Linear int8 module is not initialized.") return self.int8_module(x.contiguous()) def _empty_gptq_only_weights(self): @@ -310,16 +310,16 @@ def _empty_gptq_only_weights(self): def dequantize_model(model: PreTrainedModel): - from .torch_int8_awq import TorchInt8AwqQuantLinear + from .torch_int8_awq import TorchInt8AwqLinear - supported_int8_qlinears = (TorchInt8QuantLinear, TorchInt8AwqQuantLinear) + supported_int8_qlinears = (TorchInt8Linear, TorchInt8AwqLinear) for name, module in model.named_modules(): if isinstance(module, BaseQuantLinear) and not isinstance(module, supported_int8_qlinears): raise ValueError( - "Only models loaded using TorchInt8QuantLinear or TorchInt8AwqQuantLinear are supported " - "for dequantization. Please load model using backend=BACKEND.TORCH_INT8 or " - "backend=BACKEND.TORCH_INT8_AWQ" + "Only models loaded using TorchInt8Linear or TorchInt8AwqLinear are supported " + "for dequantization. Please load model using backend=BACKEND.GPTQ_TORCH_INT8 or " + "backend=BACKEND.AWQ_TORCH_INT8" ) if isinstance(module, supported_int8_qlinears): @@ -345,7 +345,7 @@ def dequantize_model(model: PreTrainedModel): "INT8_SCALE_BUFFER_NAME", "INT8_WEIGHT_BUFFER_NAME", "Int8PackedModule", - "TorchInt8QuantLinear", + "TorchInt8Linear", "_cached_int8_dequantize", "_has_int8_mm_op", "_requantize_to_int8", diff --git a/gptqmodel/nn_modules/qlinear/torch_int8_awq.py b/gptqmodel/nn_modules/qlinear/torch_int8_awq.py index 70256041e..973602d3d 100644 --- a/gptqmodel/nn_modules/qlinear/torch_int8_awq.py +++ b/gptqmodel/nn_modules/qlinear/torch_int8_awq.py @@ -25,8 +25,8 @@ ) -class TorchInt8AwqQuantLinear(AWQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.TORCH_INT8_AWQ] +class TorchInt8AwqLinear(AWQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.AWQ_TORCH_INT8] SUPPORTS_METHODS = [METHOD.AWQ] # Keep auto-selection unchanged; this kernel is enabled via explicit backend selection. SUPPORTS_FORMATS = {FORMAT.GEMM: 0} @@ -74,7 +74,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.TORCH_INT8_AWQ), + backend=kwargs.pop("backend", BACKEND.AWQ_TORCH_INT8), adapter=adapter, register_buffers=False, **kwargs, @@ -141,7 +141,7 @@ def dequantize_weight(self): return dequantized if not self._has_all_awq_buffers(): - raise RuntimeError("TorchInt8AwqQuantLinear missing AWQ buffers for dequantization.") + raise RuntimeError("TorchInt8AwqLinear missing AWQ buffers for dequantization.") return dequantize_gemm( qweight=self.qweight, @@ -162,7 +162,7 @@ def pack_block( workers: int = 1, ): raise NotImplementedError( - "TorchInt8AwqQuantLinear is not packable. Load AWQ int4 tensors and let post_init() convert to int8." + "TorchInt8AwqLinear is not packable. Load AWQ int4 tensors and let post_init() convert to int8." ) def pack( @@ -175,7 +175,7 @@ def pack( workers: int = 1, ): raise NotImplementedError( - "TorchInt8AwqQuantLinear is not packable. Load AWQ int4 tensors and let post_init() convert to int8." + "TorchInt8AwqLinear is not packable. Load AWQ int4 tensors and let post_init() convert to int8." ) def transform_cpu(self, dtype: torch.dtype): @@ -184,15 +184,15 @@ def transform_cpu(self, dtype: torch.dtype): def transform(self, dtype: torch.dtype, device: str): if device != "cpu": - raise NotImplementedError("TorchInt8AwqQuantLinear only supports CPU.") + raise NotImplementedError("TorchInt8AwqLinear only supports CPU.") self.transform_cpu(dtype) def forward(self, x: torch.Tensor): if self.training: - raise NotImplementedError("TorchInt8AwqQuantLinear does not support training mode.") + raise NotImplementedError("TorchInt8AwqLinear does not support training mode.") if self.int8_module is None: raise RuntimeError( - "TorchInt8AwqQuantLinear int8 module is not initialized. Ensure post_init() has been called." + "TorchInt8AwqLinear int8 module is not initialized. Ensure post_init() has been called." ) if x.dim() == 2: @@ -218,10 +218,10 @@ def forward(self, x: torch.Tensor): @torch.no_grad def _fused_op_forward(self, x: torch.Tensor) -> torch.Tensor: if x.device.type != "cpu": - raise NotImplementedError("TorchInt8AwqQuantLinear fused path is CPU-only.") + raise NotImplementedError("TorchInt8AwqLinear fused path is CPU-only.") if self.int8_module is None: - raise RuntimeError("TorchInt8AwqQuantLinear int8 module is not initialized.") + raise RuntimeError("TorchInt8AwqLinear int8 module is not initialized.") return self.int8_module(x.contiguous()) -__all__ = ["TorchInt8AwqQuantLinear"] +__all__ = ["TorchInt8AwqLinear"] diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index da4c0fae3..7fa92ab7f 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -13,14 +13,14 @@ from ...utils.backend import BACKEND from ...utils.logger import setup_logger from ...utils.python import has_gil_disabled -from .torch import TorchQuantLinear +from .torch import TorchLinear log = setup_logger() -class TritonV2QuantLinear(TorchQuantLinear): - SUPPORTS_BACKENDS = [BACKEND.TRITON] +class TritonV2Linear(TorchLinear): + SUPPORTS_BACKENDS = [BACKEND.GPTQ_TRITON] SUPPORTS_METHODS = [METHOD.GPTQ] SUPPORTS_FORMATS = {FORMAT.GPTQ: 40, FORMAT.GPTQ_V2: 40} SUPPORTS_BITS = [2, 4, 8] @@ -78,7 +78,7 @@ def __init__( out_features=out_features, bias=bias, pack_dtype=pack_dtype, - backend=kwargs.pop("backend", BACKEND.TRITON), + backend=kwargs.pop("backend", BACKEND.GPTQ_TRITON), adapter=adapter, register_buffers=register_buffers, **kwargs) @@ -164,7 +164,7 @@ def forward(self, x): return out.to(dtype=x.dtype) -__all__ = ["TritonV2QuantLinear"] +__all__ = ["TritonV2Linear"] # test triton on XPU to ensure special Intel/Triton is installed as we cannot check based on triton package meta data diff --git a/gptqmodel/quantization/__init__.py b/gptqmodel/quantization/__init__.py index 6cb608ca3..c2261e779 100644 --- a/gptqmodel/quantization/__init__.py +++ b/gptqmodel/quantization/__init__.py @@ -8,14 +8,32 @@ FORMAT_FIELD_CHECKPOINT, FORMAT_FIELD_CODE, METHOD, + METHOD_FIELD_CODE, QUANT_CONFIG_FILENAME, QUANT_METHOD_FIELD, + AutoModuleDecoderConfig, + AWQConfig, + BaseComplexBits, + BasePreProcessorConfig, BaseQuantizeConfig, - FailSafe, - FailSafeStrategy, + BitsAndBytesConfig, + EXL3Config, + Fallback, + FallbackStrategy, + FOEMConfig, + FP8Config, + GGUFBits, + GGUFConfig, GPTAQConfig, + GPTQConfig, HessianConfig, + ParoConfig, + PreProcessorCode, + PreProcessorConfig, + QuantBits, QuantizeConfig, + RTNConfig, + SmootherConfig, SmoothLog, SmoothMAD, SmoothMethod, @@ -25,7 +43,29 @@ SmoothPercentileAsymmetric, SmoothRowCol, SmoothSoftNorm, + TensorParallelPadderConfig, + WeightOnlyConfig, + WeightOnlyMethod, ) +from .foem import FOEM from .gptaq import GPTAQ from .gptq import GPTQ +from .protocol import ( + ExecutionPlan, + ExportSpec, + MatchSpec, + OperationSpec, + QuantizeSpec, + Rule, + Stage, + TargetSpec, + compile_plan_to_quantize_config, + compile_protocol, + compile_protocol_to_quantize_config, + compile_protocol_yaml_file, + compile_protocol_yaml_text, + compile_protocol_yaml_to_quantize_config, + skip, +) from .quantizer import Quantizer, quantize +from .rtn import RTN diff --git a/gptqmodel/quantization/awq/modules/triton/gemm.py b/gptqmodel/quantization/awq/modules/triton/gemm.py index c027e195b..97999f157 100644 --- a/gptqmodel/quantization/awq/modules/triton/gemm.py +++ b/gptqmodel/quantization/awq/modules/triton/gemm.py @@ -18,8 +18,12 @@ import triton import triton.language as tl +from gptqmodel.utils.env import env_flag + AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +# Shared runtime default: fp32 accumulation trades a little speed for lower numerical drift. +FP32_ACCUM = env_flag("GPTQMODEL_FP32_ACCUM", default=True) def get_same_device_cm(t): if t.device.type == 'xpu': @@ -140,6 +144,7 @@ def awq_gemm_kernel( BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, SPLIT_K: tl.constexpr, + USE_FP32_ACCUM: tl.constexpr, ): pid = tl.program_id(axis=0) pid_z = tl.program_id(1) @@ -151,15 +156,10 @@ def awq_gemm_kernel( pid_m = pid // num_pid_n pid_n = pid % num_pid_n - accumulator_dtype = c_ptr.type.element_ty - - # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. - # accumulator = tl.arange(0, BLOCK_SIZE_N) - # accumulator = tl.broadcast_to(accumulator[None, :], - # (BLOCK_SIZE_M, BLOCK_SIZE_N)) - # accumulator = accumulator & 0x0 - # accumulator = accumulator.to(accumulator_dtype) - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype) + if USE_FP32_ACCUM: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.type.element_ty) # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] # that will map given indices to the correct order. @@ -198,10 +198,10 @@ def awq_gemm_kernel( for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): masks_k = offsets_k < K masks_a = masks_am[:, None] & masks_k[None, :] - a = tl.load(a_ptrs, mask=masks_a) + a = tl.load(a_ptrs, mask=masks_a, other=0.0) masks_b = masks_k[:, None] & masks_bn[None, :] - b = tl.load(b_ptrs, mask=masks_b) + b = tl.load(b_ptrs, mask=masks_b, other=0) b = tl.interleave(b, b) b = tl.interleave(b, b) b = tl.interleave(b, b) @@ -214,7 +214,7 @@ def awq_gemm_kernel( masks_zk = offsets_szk < K // group_size masks_z = masks_zk[:, None] & masks_zn[None, :] zeros_ptrs = zeros_ptr + offsets_z - zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.load(zeros_ptrs, mask=masks_z, other=0) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) zeros = tl.interleave(zeros, zeros) @@ -224,16 +224,19 @@ def awq_gemm_kernel( masks_sk = offsets_szk < K // group_size masks_s = masks_sk[:, None] & masks_sn[None, :] scales_ptrs = scales_ptr + offsets_s - scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.load(scales_ptrs, mask=masks_s, other=0.0) scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) b = (b >> shifts) & 0xF zeros = (zeros >> shifts) & 0xF b = (b - zeros) * scales - b = b.to(c_ptr.type.element_ty) + b = b.to(a.dtype) # Accumulate results. - accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + if USE_FP32_ACCUM: + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) + else: + accumulator = tl.dot(a, b, accumulator, out_dtype=c_ptr.type.element_ty) offsets_k += BLOCK_SIZE_K * SPLIT_K a_ptrs += BLOCK_SIZE_K * SPLIT_K @@ -318,6 +321,8 @@ def awq_gemm_triton( block_size_m: int = 32, block_size_n: int = 32, block_size_k: int = 32, + fp32_accum: bool = FP32_ACCUM, + output_dtype: torch.dtype | None = None, ) -> torch.Tensor: M, K = input.shape N = qweight.shape[1] * 8 @@ -338,7 +343,11 @@ def grid(META): split_k_iters, ) - result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) + if output_dtype is None: + output_dtype = scales.dtype + + accum_dtype = torch.float32 if fp32_accum else output_dtype + result = torch.zeros((M, N), dtype=accum_dtype, device=input.device) # A = input, B = qweight, C = result # A = M x K, B = K x N, C = M x N @@ -357,6 +366,10 @@ def grid(META): BLOCK_SIZE_N=block_size_n, BLOCK_SIZE_K=block_size_k, SPLIT_K=split_k_iters, + USE_FP32_ACCUM=fp32_accum, ) + if result.dtype != output_dtype: + return result.to(output_dtype) + return result diff --git a/gptqmodel/quantization/awq/quantize/scale.py b/gptqmodel/quantization/awq/quantize/scale.py index 9163e55c9..06e4ccf66 100644 --- a/gptqmodel/quantization/awq/quantize/scale.py +++ b/gptqmodel/quantization/awq/quantize/scale.py @@ -27,7 +27,15 @@ from gptqmodel.quantization.awq.utils.utils import get_best_device +try: + from transformers.models.gemma4.modeling_gemma4 import Gemma4RMSNorm +except Exception: # pragma: no cover - older transformers builds do not expose Gemma 4 yet + Gemma4RMSNorm = None + + allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, Gemma2RMSNorm, CohereLayerNorm] +if Gemma4RMSNorm is not None: + allowed_norms.append(Gemma4RMSNorm) allowed_act_fns = [ nn.GELU, BloomGelu, diff --git a/gptqmodel/quantization/awq/utils/module.py b/gptqmodel/quantization/awq/utils/module.py index 816069bc9..2f145fe46 100644 --- a/gptqmodel/quantization/awq/utils/module.py +++ b/gptqmodel/quantization/awq/utils/module.py @@ -3,18 +3,9 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import importlib - import torch.nn as nn -def try_import(module_name): - try: - module = importlib.import_module(module_name) - return module, "" - except Exception as ex: - return None, str(ex) - def get_named_linears(module): return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)} diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index 368c619ec..0a2091fea 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -3,14 +3,18 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import copy import json +import math import os.path +from abc import ABC, abstractmethod from dataclasses import asdict, dataclass, field, fields from enum import Enum +from functools import total_ordering from os.path import join -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Union -import pcre as re +import pcre import torch from packaging import version @@ -21,17 +25,29 @@ log = setup_logger() +_DECODER_TARGET_DTYPE_MAP = { + "float16": torch.float16, + "half": torch.float16, + "fp16": torch.float16, + "bfloat16": torch.bfloat16, + "bf16": torch.bfloat16, +} + BITS_FIELD_CODE = "bits" GROUP_SIZE_FIELD_CODE = "group_size" FORMAT_FIELD_CODE = "format" SYMMETRIC_FIELD_CODE = "sym" +# Deprecated JSON alias retained for backward compatibility. FORMAT_FIELD_CHECKPOINT = "checkpoint_format" +# Hard-deprecated legacy alias. Presence should fail fast. FORMAT_FIELD_COMPAT_MARLIN = "is_marlin_format" +# Canonical method field; `quant_method` is a deprecated JSON alias. +METHOD_FIELD_CODE = "method" QUANT_METHOD_FIELD = "quant_method" PACK_DTYPE_FIELD = "pack_dtype" QUANT_CONFIG_FILENAME = "quantize_config.json" QUANT_CONFIG_FILENAME_COMPAT = [QUANT_CONFIG_FILENAME, "quant_config.json", "config.json"] -# This is AwqBackendPackingMethod, not GPTQModel.BACKEND. +# This is AwqBackendPackingMethod, not the GPT-QModel backend enum. # It's used to distinguish between quantization by llm-awq and autoawq; llm-awq actually uses GEMV_FAST for packing. AWQ_PACKING_BACKEND_FIELD = "backend" @@ -57,37 +73,54 @@ META_FIELD_GPTAQ_ENABLED = "gptaq" -ADAPTER_FIELD = "adapter" +META_FIELD_FOEM_ENABLED = "foem" +ADAPTER_FIELD = "adapter" # saved formats class FORMAT(str, Enum): + """Checkpoint and runtime tensor layout identifiers.""" + GPTQ = "gptq" # v2 format fixed sym = False quantization GPTQ_V2 = "gptq_v2" + GGUF = "gguf" + FP8 = "fp8" + BITSANDBYTES = "bitsandbytes" MARLIN = "marlin" BITBLAS = "bitblas" QQQ = "qqq" + EXL3 = "exl3" GEMM = "gemm" GEMV = "gemv" GEMV_FAST = "gemv_fast" LLM_AWQ = "llm-awq" + PAROQUANT = "paroquant" # quant methods class METHOD(str, Enum): + """Supported quantization algorithms exposed by config payloads.""" + GPTQ = "gptq" + GGUF = "gguf" + FP8 = "fp8" + BITSANDBYTES = "bitsandbytes" QQQ = "qqq" AWQ = "awq" + EXL3 = "exl3" + PARO = "paroquant" class VramStrategy(str, Enum): + """Placement strategies shared by dense and MoE device pools.""" + EXCLUSIVE = "exclusive" BALANCED = "balanced" -class FailSafeStrategy(str, Enum): +class FallbackStrategy(str, Enum): """ +-----------+----------------------+---------------------------+------------------------------+ | strategy | center | scale | strengths / weaknesses | @@ -99,16 +132,728 @@ class FailSafeStrategy(str, Enum): | stdclip | mean(w) | 2*sigma*std | tames tails, may clip signal | +-----------+----------------------+---------------------------+------------------------------+ """ - - RTN = "rtn" # round to nearest + RTN = "rtn" # round to nearest MIDPOINT = "midpoint" MEAN = "mean" MEDIAN = "median" STDCLIP = "stdclip" +class WeightOnlyMethod(str, Enum): + """Weight-only quantization backends available to fallback flows.""" + + RTN = "rtn" + GGUF = "gguf" + FP8 = "fp8" + BITSANDBYTES = "bitsandbytes" + NVFP4 = "nvfp4" + + +class PreProcessorCode(str, Enum): + """Identifiers for preprocessing passes that run before quantization.""" + + SMOOTHER = "smoother" + AUTO_MODULE_DECODER = "auto_module_decoder" + TENSOR_PARALLEL_PADDER = "tensor_parallel_padder" + + +_GGUF_BITS_ALIAS_INFO = { + "q1_0": {"bits": 1, "version": "q", "variant": "0", "quality": None}, + "q1_0_g128": {"bits": 1, "version": "q", "variant": "0", "quality": "g128"}, + "q4_0": {"bits": 4, "version": "q", "variant": "0", "quality": None}, + "q8_0": {"bits": 8, "version": "q", "variant": "0", "quality": None}, + "q4_k": {"bits": 4, "version": "q", "variant": "k", "quality": None}, + "q4_k_s": {"bits": 4, "version": "q", "variant": "k", "quality": "s"}, + "q4_k_m": {"bits": 4, "version": "q", "variant": "k", "quality": "m"}, + "q5_k": {"bits": 5, "version": "q", "variant": "k", "quality": None}, + "q5_k_s": {"bits": 5, "version": "q", "variant": "k", "quality": "s"}, + "q5_k_m": {"bits": 5, "version": "q", "variant": "k", "quality": "m"}, + "q6_k": {"bits": 6, "version": "q", "variant": "k", "quality": None}, +} +_GGUF_DEFAULT_BITS_ALIAS_BY_WIDTH = { + 1: "q1_0", + 4: "q4_0", + 5: "q5_k_m", + 6: "q6_k", + 8: "q8_0", +} +_GGUF_APPROX_BITS_PER_WEIGHT_BY_ALIAS = { + "q1_0": 1.5, + "q1_0_g128": 1.125, + "q4_0": 4.5, + "q8_0": 8.5, + "q4_k": 4.5, + "q4_k_s": 4.5, + "q4_k_m": 4.5, + "q5_k": 5.5, + "q5_k_s": 5.0, + "q5_k_m": 5.5, + "q6_k": 6.0, +} + + +@total_ordering +class BaseComplexBits(ABC): + """Comparable bit-spec base class for non-scalar bit encodings.""" + + @classmethod + @abstractmethod + def from_string(cls, value: str) -> "BaseComplexBits": + """Parse a serialized bit specification into an instance.""" + + raise NotImplementedError + + @abstractmethod + def to_string(self) -> str: + """Serialize the bit specification into its canonical string form.""" + + raise NotImplementedError + + @property + def width(self) -> int: + """Return the integer width represented by this bit encoding.""" + + return self.bits + + @property + def name(self) -> str: + """Return the canonical string name for this bit encoding.""" + + return self.to_string() + + def _coerce_bits(self, other: Any) -> Any: + """Convert compatible operands into raw bit widths for arithmetic.""" + + if isinstance(other, BaseComplexBits): + return other.bits + if isinstance(other, int): + return other + if isinstance(other, str) and other.strip().isdigit(): + return int(other.strip()) + return NotImplemented + + def __str__(self) -> str: + """Render the canonical string form for logging and serialization.""" + + return self.to_string() + + def __hash__(self) -> int: + """Hash bit encodings by their integer width.""" + + return hash(self.bits) + + def __int__(self) -> int: + """Expose the bit width as an integer.""" + + return self.bits + + def __index__(self) -> int: + """Allow the bit width to participate in index-style conversions.""" + + return self.bits + + def __float__(self) -> float: + """Expose the bit width as a float.""" + + return float(self.bits) + + def __eq__(self, other: Any) -> bool: + """Compare complex bit encodings against strings, ints, or peers.""" + + if isinstance(other, BaseComplexBits): + return self.to_string() == other.to_string() + if isinstance(other, int): + return self.bits == other + if isinstance(other, str): + normalized = other.strip().lower().replace("-", "_") + if normalized.isdigit(): + return self.bits == int(normalized) + return self.to_string() == normalized + return False + + def __lt__(self, other: Any) -> bool: + """Order bit encodings by their effective width.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits < coerced + + def __add__(self, other: Any) -> int: + """Add the effective bit width to another scalar-like operand.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits + coerced + + def __radd__(self, other: Any) -> int: + """Support right-hand addition with scalar-like operands.""" + + return self.__add__(other) + + def __sub__(self, other: Any) -> int: + """Subtract another scalar-like operand from this bit width.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits - coerced + + def __rsub__(self, other: Any) -> int: + """Support right-hand subtraction against this bit width.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return coerced - self.bits + + def __mul__(self, other: Any) -> int: + """Multiply the bit width by another scalar-like operand.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits * coerced + + def __rmul__(self, other: Any) -> int: + """Support right-hand multiplication with scalar-like operands.""" + + return self.__mul__(other) + + def __floordiv__(self, other: Any) -> int: + """Floor-divide the bit width by another scalar-like operand.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits // coerced + + def __rfloordiv__(self, other: Any) -> int: + """Support right-hand floor division against this bit width.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return coerced // self.bits + + def __truediv__(self, other: Any) -> float: + """True-divide the bit width by another scalar-like operand.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits / coerced + + def __rtruediv__(self, other: Any) -> float: + """Support right-hand true division against this bit width.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return coerced / self.bits + + def __mod__(self, other: Any) -> int: + """Take the modulo of the bit width with another scalar-like operand.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits % coerced + + def __rmod__(self, other: Any) -> int: + """Support right-hand modulo against this bit width.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return coerced % self.bits + + def __pow__(self, other: Any) -> int: + """Raise the bit width to another scalar-like operand.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return self.bits ** coerced + + def __rpow__(self, other: Any) -> int: + """Support right-hand exponentiation against this bit width.""" + + coerced = self._coerce_bits(other) + if coerced is NotImplemented: + return NotImplemented + return coerced ** self.bits + + +@dataclass(frozen=True, eq=False) +class GGUFBits(BaseComplexBits): + """Structured GGUF bit specification with version and variant tags.""" + + bits: int + version: str + variant: str + quality: Optional[str] = None + + def __post_init__(self): + """Validate the GGUF bit-spec components after construction.""" + + if self.bits <= 0: + raise ValueError("GGUFBits: `bits` must be a positive integer.") + if self.version not in {"q", "iq"}: + raise ValueError("GGUFBits: `version` must be `q` or `iq`.") + if self.variant not in {"0", "k"}: + raise ValueError("GGUFBits: `variant` must be `0` or `k`.") + if self.quality not in {None, "xs", "s", "m", "l", "g128"}: + raise ValueError("GGUFBits: `quality` must be one of `[None, xs, s, m, l, g128]`.") + + @classmethod + def from_string(cls, value: str) -> "GGUFBits": + """Parse a GGUF alias such as ``q4_k_m`` into a typed bit spec.""" + + normalized = str(value).strip().lower().replace("-", "_") + info = _GGUF_BITS_ALIAS_INFO.get(normalized) + if info is None: + supported = ", ".join(sorted(_GGUF_BITS_ALIAS_INFO)) + raise ValueError(f"Unsupported GGUF bits `{value}`. Supported values: {supported}.") + return cls( + bits=info["bits"], + version=info["version"], + variant=info["variant"], + quality=info["quality"], + ) + + def to_string(self) -> str: + """Serialize this GGUF bit spec back to its alias form.""" + + alias = f"{self.version}{self.bits}_{self.variant}" + if self.quality is not None: + alias = f"{alias}_{self.quality}" + return alias + + @classmethod + def from_alias(cls, value: str) -> "GGUFBits": + """Backward-compatible alias parser for GGUF bit specs.""" + + return cls.from_string(value) + + def serialize(self) -> str: + """Return the canonical serialized form used in config payloads.""" + + return self.to_string() + + def __repr__(self) -> str: + """Return a debug-friendly constructor-style representation.""" + + return f"GGUFBits({self.to_string()!r})" + + def to_public_format(self) -> str: + """Return the GGUF public subtype string without the width prefix.""" + + public_format = f"{self.version}_{self.variant}" + if self.quality is not None: + public_format = f"{public_format}_{self.quality}" + return public_format + + +# Backward-compatible alias for the earlier wrapper-based refactor. +QuantBits = GGUFBits + + +_GGUF_PUBLIC_FORMAT_RE = pcre.compile(r"^(q|iq)_(0|k)(?:_(xs|s|m|l|g128))?$") + + +def _gguf_public_format_from_bits(bits: GGUFBits) -> str: + """Project a full GGUF bit spec into its public subtype token.""" + + return bits.to_public_format() + + +def _normalize_gguf_public_format(value: Any) -> Optional[str]: + """Normalize GGUF subtype aliases into their public format string.""" + + if value is None: + return None + + if isinstance(value, GGUFBits): + return _gguf_public_format_from_bits(value) + + if isinstance(value, FORMAT): + value = value.value + + normalized = str(value).strip().lower().replace("-", "_") + if normalized in {"", FORMAT.GGUF.value}: + return None + if normalized in _GGUF_BITS_ALIAS_INFO: + return _gguf_public_format_from_bits(GGUFBits.from_alias(normalized)) + if _GGUF_PUBLIC_FORMAT_RE.fullmatch(normalized): + return normalized + + raise ValueError( + "GGUFConfig: `format` must be a GGUF subtype like `q_0`, `q_k`, `q_k_s`, or `q_k_m`." + ) + + +def _default_gguf_public_format(bits: int) -> str: + """Return the default GGUF subtype for a supported bit width.""" + + alias = _GGUF_DEFAULT_BITS_ALIAS_BY_WIDTH.get(bits) + if alias is None: + raise ValueError(f"GGUFConfig: no default GGUF format exists for `{bits}`-bit quantization.") + return _gguf_public_format_from_bits(GGUFBits.from_alias(alias)) + + +def _gguf_bits_from_components(bits: int, public_format: str) -> GGUFBits: + """Build a validated ``GGUFBits`` object from width and subtype parts.""" + + match = _GGUF_PUBLIC_FORMAT_RE.fullmatch(public_format) + if match is None: + raise ValueError( + "GGUFConfig: `format` must be a GGUF subtype like `q_0`, `q_k`, `q_k_s`, or `q_k_m`." + ) + + version_name, variant, quality = match.groups() + bits_spec = GGUFBits(bits=bits, version=version_name, variant=variant, quality=quality) + if bits_spec.to_string() not in _GGUF_BITS_ALIAS_INFO: + raise ValueError( + f"Unsupported GGUF combination: bits={bits}, format={public_format}." + ) + return bits_spec + + +def _normalize_gguf_config_spec( + bits: Union[int, str, GGUFBits], + format_value: Optional[Union[str, FORMAT, GGUFBits]], +) -> Tuple[int, str, GGUFBits]: + """Resolve GGUF bits and format inputs into a consistent typed triple.""" + + bits_spec_from_bits: Optional[GGUFBits] = None + normalized_bits = bits + + if isinstance(bits, GGUFBits): + bits_spec_from_bits = bits + normalized_bits = bits.bits + elif isinstance(bits, str): + raw_bits = bits.strip().lower().replace("-", "_") + if raw_bits.isdigit(): + normalized_bits = int(raw_bits) + else: + bits_spec_from_bits = GGUFBits.from_alias(raw_bits) + normalized_bits = bits_spec_from_bits.bits + elif not isinstance(bits, int): + raise ValueError(f"GGUFConfig: unsupported bits specification `{bits}`.") + + normalized_bits = int(normalized_bits) + if normalized_bits not in [1, 2, 3, 4, 5, 6, 8]: + raise ValueError("GGUFConfig: `bits` must resolve to one of `[1, 2, 3, 4, 5, 6, 8]`.") + + normalized_format = _normalize_gguf_public_format(format_value) + if normalized_format is None: + if bits_spec_from_bits is not None: + bits_spec = bits_spec_from_bits + normalized_format = _gguf_public_format_from_bits(bits_spec) + else: + normalized_format = _default_gguf_public_format(normalized_bits) + bits_spec = _gguf_bits_from_components(normalized_bits, normalized_format) + else: + bits_spec = _gguf_bits_from_components(normalized_bits, normalized_format) + if bits_spec_from_bits is not None and bits_spec_from_bits != bits_spec: + raise ValueError( + f"GGUFConfig: incompatible GGUF bits/format combination: bits={bits}, format={format_value}." + ) + + return normalized_bits, normalized_format, bits_spec + + +def _normalize_quant_bits(bits: Union[int, float, str, GGUFBits], format_value: Optional[Union[str, FORMAT]] = None) -> Union[int, GGUFBits]: + """Normalize generic bit fields into ints or structured GGUF specs.""" + + if isinstance(format_value, str): + format_value = _normalize_format(format_value) + + if isinstance(bits, GGUFBits): + normalized = bits + elif isinstance(bits, float): + if format_value == FORMAT.EXL3: + normalized = bits + elif bits.is_integer(): + normalized = int(bits) + else: + raise ValueError(f"QuantizeConfig: unsupported bits specification `{bits}`.") + elif isinstance(bits, int): + normalized = bits + elif isinstance(bits, str): + raw = bits.strip().lower().replace("-", "_") + normalized = int(raw) if raw.isdigit() else GGUFBits.from_alias(raw) + else: + raise ValueError(f"QuantizeConfig: unsupported bits specification `{bits}`.") + + normalized_width = normalized.bits if isinstance(normalized, GGUFBits) else normalized + valid_bit_widths = [1, 2, 3, 4, 5, 6, 8] + if normalized_width not in valid_bit_widths: + raise ValueError(f"QuantizeConfig: `bits` must resolve to one of `{valid_bit_widths}`.") + + if format_value == FORMAT.GGUF and not isinstance(normalized, GGUFBits): + default_alias = _GGUF_DEFAULT_BITS_ALIAS_BY_WIDTH.get(normalized_width) + if default_alias is None: + raise ValueError( + f"QuantizeConfig: no default GGUF bits alias exists for `{normalized_width}`-bit quantization." + ) + normalized = GGUFBits.from_alias(default_alias) + + if isinstance(normalized, GGUFBits) and format_value is not None and format_value != FORMAT.GGUF: + raise ValueError("QuantizeConfig: GGUF bit encodings require `format=gguf`.") + + return normalized + + +def resolve_quant_format( + format_value: Optional[Union[str, FORMAT]], + method: Optional[Union[str, METHOD]] = None, + quant_method: Optional[Union[str, METHOD]] = None, +) -> FORMAT: + """Infer the effective quantization format from method and format hints.""" + + if method is None: + method = quant_method + + if isinstance(method, str): + method = _normalize_quant_method(method) + + if method == METHOD.GGUF: + return FORMAT.GGUF + if method == METHOD.FP8: + return FORMAT.FP8 + if method == METHOD.BITSANDBYTES: + return FORMAT.BITSANDBYTES + if method == METHOD.EXL3: + return FORMAT.EXL3 + if method == METHOD.PARO: + return FORMAT.PAROQUANT + + if isinstance(format_value, FORMAT): + return format_value + + try: + if _normalize_gguf_public_format(format_value) is not None: + return FORMAT.GGUF + except ValueError: + pass + + if _looks_like_fp8_fmt(format_value): + return FORMAT.FP8 + if _looks_like_bitsandbytes_format(format_value): + return FORMAT.BITSANDBYTES + + if format_value is None: + return FORMAT.GPTQ + + return _normalize_format(format_value) + + +def _looks_like_gguf_bits(bits: Any) -> bool: + """Return ``True`` when a value resembles a GGUF alias or bit spec.""" + + if isinstance(bits, GGUFBits): + return True + if not isinstance(bits, str): + return False + normalized = bits.strip().lower().replace("-", "_") + return normalized in _GGUF_BITS_ALIAS_INFO + + +def quant_bits_width(bits: Union[int, str, GGUFBits]) -> int: + """Return the integer width represented by a quant bits field.""" + + if isinstance(bits, float): + if bits <= 0: + raise ValueError("QuantizeConfig: EXL3 bits per weight must be greater than 0.") + return max(1, int(math.floor(bits))) + normalized = _normalize_quant_bits(bits) + return normalized.bits if isinstance(normalized, GGUFBits) else normalized + + +def serialize_quant_bits(bits: Union[int, str, GGUFBits]) -> Union[int, str]: + """Serialize a quant bits field for JSON-compatible output payloads.""" + + if isinstance(bits, float): + return float(bits) + normalized = _normalize_quant_bits(bits) + return normalized.serialize() if isinstance(normalized, GGUFBits) else normalized + + +def _normalize_exl3_bits(bits: Union[int, float, str]) -> float: + """Normalize EXL3 fractional bits-per-weight values.""" + + if isinstance(bits, str): + bits = float(bits.strip()) + elif isinstance(bits, int): + bits = float(bits) + elif not isinstance(bits, float): + raise ValueError(f"EXL3Config: unsupported bits specification `{bits}`.") + + if not math.isfinite(bits): + raise ValueError("EXL3Config: `bits` must be finite.") + if bits < 1.0 or bits > 8.0: + raise ValueError("EXL3Config: `bits` must be between 1.0 and 8.0.") + return float(bits) + + +# Canonical FP8 aliases are normalized here before validating torch runtime +# support so config payloads can use either shorthand or exact dtype names. +_FP8_FMT_ALIASES = { + "e4m3": "float8_e4m3fn", + "float8_e4m3": "float8_e4m3fn", + "float8_e4m3fn": "float8_e4m3fn", + "e5m2": "float8_e5m2", + "float8_e5m2": "float8_e5m2", + "e4m3fnuz": "float8_e4m3fnuz", + "float8_e4m3fnuz": "float8_e4m3fnuz", + "e5m2fnuz": "float8_e5m2fnuz", + "float8_e5m2fnuz": "float8_e5m2fnuz", + "e8m0": "float8_e8m0fnu", + "e8m0fnu": "float8_e8m0fnu", + "float8_e8m0": "float8_e8m0fnu", + "float8_e8m0fnu": "float8_e8m0fnu", +} +_FP8_WEIGHT_SCALE_METHODS = {"tensor", "row", "block"} +_FP8_SCALE_SEMANTICS = {"inverse"} +_BITSANDBYTES_4BIT_FORMATS = {"fp4", "nf4"} +_BITSANDBYTES_8BIT_FORMATS = {"int8"} +_BITSANDBYTES_FORMATS = _BITSANDBYTES_4BIT_FORMATS | _BITSANDBYTES_8BIT_FORMATS +_BITSANDBYTES_BLOCK_SIZES = {32, 64, 128, 256, 512, 1024, 2048, 4096} + + +def _looks_like_fp8_fmt(value: Any) -> bool: + """Return ``True`` when a value matches a supported FP8 format alias.""" + + if value is None: + return False + normalized = str(value).strip().lower() + return normalized in _FP8_FMT_ALIASES + + +def _normalize_fp8_fmt(value: Optional[str]) -> str: + """Resolve FP8 format aliases to the canonical PyTorch dtype name.""" + + if isinstance(value, FORMAT): + if value != FORMAT.FP8: + raise ValueError(f"FP8Config: unsupported `format` `{value}`.") + value = None + + normalized = "float8_e4m3fn" if value is None else str(value).strip().lower() + if normalized in {"", FORMAT.FP8.value}: + normalized = "float8_e4m3fn" + resolved = _FP8_FMT_ALIASES.get(normalized) + if resolved is None: + supported = ", ".join(sorted(_FP8_FMT_ALIASES)) + raise ValueError(f"FP8Config: unsupported `format` `{value}`. Supported values: {supported}.") + if not hasattr(torch, resolved): + raise ValueError(f"FP8Config: current PyTorch build does not provide `{resolved}`.") + return resolved + + +def _normalize_fp8_weight_block_size(value: Optional[Union[List[int], Tuple[int, int]]]) -> Optional[Tuple[int, int]]: + """Validate and normalize FP8 block-scale dimensions.""" + + if value is None: + return None + if not isinstance(value, (list, tuple)) or len(value) != 2: + raise ValueError("FP8Config: `weight_block_size` must be a 2-item list/tuple or None.") + rows, cols = int(value[0]), int(value[1]) + if rows <= 0 or cols <= 0: + raise ValueError("FP8Config: `weight_block_size` entries must be positive integers.") + return rows, cols + + +def _normalize_fp8_weight_scale_method( + value: Optional[str], + *, + weight_block_size: Optional[Tuple[int, int]], +) -> str: + """Resolve the FP8 weight scaling strategy from config inputs.""" + + normalized = "block" if weight_block_size is not None and value is None else (value or "row") + normalized = str(normalized).strip().lower() + if normalized not in _FP8_WEIGHT_SCALE_METHODS: + supported = ", ".join(sorted(_FP8_WEIGHT_SCALE_METHODS)) + raise ValueError( + f"FP8Config: `weight_scale_method` must be one of {{{supported}}}, got `{value}`." + ) + if normalized == "block" and weight_block_size is None: + raise ValueError("FP8Config: `weight_scale_method='block'` requires `weight_block_size`.") + if normalized != "block" and weight_block_size is not None: + raise ValueError( + "FP8Config: `weight_block_size` is only valid when `weight_scale_method='block'`." + ) + return normalized + + +def _normalize_fp8_scale_semantics(value: Optional[str]) -> str: + """Normalize FP8 scale semantics to the supported enum-like string.""" + + normalized = "inverse" if value is None else str(value).strip().lower() + if normalized not in _FP8_SCALE_SEMANTICS: + supported = ", ".join(sorted(_FP8_SCALE_SEMANTICS)) + raise ValueError( + f"FP8Config: `weight_scale_semantics` must be one of {{{supported}}}, got `{value}`." + ) + return normalized + + +def _looks_like_bitsandbytes_format(value: Any) -> bool: + """Return ``True`` when a value matches a bitsandbytes format alias.""" + + if value is None: + return False + normalized = str(value).strip().lower().replace("-", "_") + return normalized in _BITSANDBYTES_FORMATS + + +def _normalize_bitsandbytes_format(value: Optional[str], *, bits: Optional[int] = None) -> str: + """Normalize bitsandbytes format aliases for the requested bit width.""" + + default_format = "int8" if bits == 8 else "fp4" + normalized = default_format if value is None else str(value).strip().lower().replace("-", "_") + if normalized in {"", FORMAT.BITSANDBYTES.value}: + normalized = default_format + + if bits == 4: + allowed_formats = _BITSANDBYTES_4BIT_FORMATS + elif bits == 8: + allowed_formats = _BITSANDBYTES_8BIT_FORMATS + else: + allowed_formats = _BITSANDBYTES_FORMATS + + if normalized not in allowed_formats: + supported = ", ".join(sorted(allowed_formats)) + raise ValueError( + f"BitsAndBytesConfig: `format` must be one of {{{supported}}}, got `{value}`." + ) + return normalized + + +def _normalize_bitsandbytes_quant_type(value: Optional[str]) -> str: + """Normalize the legacy 4-bit bitsandbytes quant type field.""" + + return _normalize_bitsandbytes_format(value, bits=4) + + +def _normalize_bitsandbytes_block_size(value: Optional[int]) -> int: + """Validate and normalize the bitsandbytes block size setting.""" + + normalized = 64 if value is None else int(value) + if normalized not in _BITSANDBYTES_BLOCK_SIZES: + supported = ", ".join(str(item) for item in sorted(_BITSANDBYTES_BLOCK_SIZES)) + raise ValueError( + f"BitsAndBytesConfig: `block_size` must be one of {{{supported}}}, got `{value}`." + ) + return normalized + @dataclass class SmoothMethod: + """Base smoother descriptor shared by all smoothing strategies.""" + name: str # Apply the smoother only when group size >= this threshold. group_size_threshold: int = 128 @@ -126,10 +871,11 @@ class SmoothPercentile(SmoothMethod): | effect | higher p = less clipping | +----------------+-------------------------------------------+ """ - percentile: float = 99.0 def __init__(self, percentile: float = 99.0, group_size_threshold: int = 128): + """Configure percentile clipping with an optional group-size floor.""" + super().__init__(name="percentile", group_size_threshold=group_size_threshold) self.percentile = percentile @@ -146,11 +892,12 @@ class SmoothPercentileAsymmetric(SmoothMethod): | effect | asymmetric clipping of tails | +-------------------+-------------------------------------------+ """ - low: float = 0.5 high: float = 99.5 def __init__(self, low: float = 0.5, high: float = 99.5, group_size_threshold: int = 128): + """Configure asymmetric percentile clipping bounds.""" + super().__init__(name="percentile_asym", group_size_threshold=group_size_threshold) self.low = low self.high = high @@ -168,10 +915,11 @@ class SmoothMAD(SmoothMethod): | effect | higher K = less clipping | +----------------+-------------------------------------------+ """ - k: float = 2.75 def __init__(self, k: float = 2.75, group_size_threshold: int = 128): + """Configure MAD-based clipping width and activation threshold.""" + super().__init__(name="mad", group_size_threshold=group_size_threshold) self.k = k @@ -189,11 +937,12 @@ class SmoothMSE(SmoothMethod): | effect | more steps = better fit, slower | +----------------+-------------------------------------------+ """ - steps: int = 32 maxshrink: float = 0.8 def __init__(self, steps: int = 32, maxshrink: float = 0.8, group_size_threshold: int = 128): + """Configure search granularity for MSE-based shrinking.""" + super().__init__(name="mse", group_size_threshold=group_size_threshold) self.steps = steps self.maxshrink = maxshrink @@ -211,10 +960,11 @@ class SmoothOutlier(SmoothMethod): | effect | higher p = more clipping | +----------------+-------------------------------------------+ """ - pct: float = 1.0 def __init__(self, pct: float = 1.0, group_size_threshold: int = 128): + """Configure top-percent outlier clipping behavior.""" + super().__init__(name="outlier", group_size_threshold=group_size_threshold) self.pct = pct @@ -231,10 +981,11 @@ class SmoothSoftNorm(SmoothMethod): | effect | higher K = less clipping | +----------------+-------------------------------------------+ """ - k: float = 3.0 def __init__(self, k: float = 3.0, group_size_threshold: int = 128): + """Configure z-score clipping strength for soft normalization.""" + super().__init__(name="softnorm", group_size_threshold=group_size_threshold) self.k = k @@ -252,11 +1003,12 @@ class SmoothLog(SmoothMethod): | effect | higher mu compresses outliers more | +----------------+-------------------------------------------+ """ - percentile: float = 99.0 mu: float = 8.0 def __init__(self, percentile: float = 99.0, mu: float = 8.0, group_size_threshold: int = 128): + """Configure log-domain smoothing with percentile and companding strength.""" + super().__init__(name="log", group_size_threshold=group_size_threshold) self.percentile = percentile self.mu = mu @@ -274,30 +1026,31 @@ class SmoothRowCol(SmoothMethod): | effect | normalizes dynamic range before quant | +----------------+-------------------------------------------+ """ - axis: str = "row" def __init__(self, axis: str = "row", group_size_threshold: int = 128): + """Configure RMS normalization over rows or columns.""" + super().__init__(name="rowcol", group_size_threshold=group_size_threshold) self.axis = axis class GcMode(str, Enum): + """Policies for when staged garbage collection should run.""" + INTERVAL = "interval" ON_STAGE_END = "on_stage_end" @dataclass -class FailSafe: - strategy: FailSafeStrategy = ( - FailSafeStrategy.RTN - ) # enable failsafe by default due to moe routing behavior breaking calibration based quantization +class Fallback: + """Low-sample fallback strategy for modules with weak calibration coverage.""" + + strategy: FallbackStrategy = FallbackStrategy.RTN # enable fallback by default due to moe routing behavior breaking calibration based quantization # int/float = if captured module fwd tokens is less than value, trigger strategy # string = if string is int/float followed by %, then if captured module fwd tokens is less than value in percentage relative to calibration, trigger strategy - threshold: int | float | str = ( - "0.5%" # if less than 0.5% of calibration reaches module (think moe) then we trigger per-module failsafe quantization - ) + threshold: int | float | str = "0.5%" # if less than 0.5% of calibration reaches module (think moe) then we trigger per-module fallback quantization # Smoothers can help some low-sample fallback cases, but a static default can # hurt whole-model RTN quality. Leave smoothing opt-in. @@ -305,55 +1058,203 @@ class FailSafe: @dataclass -class HessianConfig: - # Hessian accumulation controls (GPTQ only) - chunk_size: Optional[int] = field(default=None, metadata={"help": "Maximum rows per Hessian chunk"}) - chunk_bytes: Optional[int] = field( - default=None, metadata={"help": "Memory budget (in bytes) for Hessian chunk staging"} - ) - staging_dtype: Union[str, torch.dtype] = field( - default=torch.float32, - metadata={"help": "Stage Hessian chunks in a lower precision dtype when supported"}, - ) +class WeightOnlyConfig: + """Configuration for weight-only fallback quantization flows.""" + + method: WeightOnlyMethod = WeightOnlyMethod.RTN + # Whole-model RTN is noticeably more stable without a smoother by default. + smooth: Optional[SmoothMethod] = None def __post_init__(self): - if self.chunk_size is not None: - if not isinstance(self.chunk_size, int): - raise ValueError("HessianConfig: `chunk_size` must be an integer or None.") - if self.chunk_size <= 0: - raise ValueError("HessianConfig: `chunk_size` must be a positive integer.") + """Normalize the weight-only method and optional smoother settings.""" - if self.chunk_bytes is not None: - if not isinstance(self.chunk_bytes, int): - raise ValueError("HessianConfig: `chunk_bytes` must be an integer or None.") - if self.chunk_bytes <= 0: - raise ValueError("HessianConfig: `chunk_bytes` must be a positive integer amount of bytes.") + if isinstance(self.method, str): + try: + self.method = WeightOnlyMethod(self.method.lower()) + except ValueError as exc: + raise ValueError( + f"WeightOnlyConfig: `method` must be one of {[v.value for v in WeightOnlyMethod]}." + ) from exc + elif not isinstance(self.method, WeightOnlyMethod): + raise ValueError( + f"WeightOnlyConfig: `method` must be one of {[v.value for v in WeightOnlyMethod]}." + ) - if isinstance(self.staging_dtype, str): - self.staging_dtype = self.staging_dtype.lower() - if self.staging_dtype not in ["float32", "float16", "bfloat16"]: - raise ValueError("HessianConfig: `staging_dtype` must be float32, float16, or bfloat16.") - self.staging_dtype = getattr(torch, self.staging_dtype) - elif isinstance(self.staging_dtype, torch.dtype): - if self.staging_dtype not in [torch.float32, torch.float16, torch.bfloat16]: - raise ValueError("HessianConfig: `staging_dtype` must be float32, float16, or bfloat16.") - else: - raise ValueError("HessianConfig: `staging_dtype` must be a torch.dtype or string.") + self.smooth = _parse_smooth_method(self.smooth) @dataclass -class GPTAQConfig: - alpha: float = field(default=0.25) - device: Union[str, torch.device] = field(default="auto") +class BasePreProcessorConfig: + """Base payload for preprocessing stages emitted into config JSON.""" - def __post_init__(self): - if not isinstance(self.alpha, (int, float)): - raise ValueError("GPTAQConfig: `alpha` must be a numeric value.") - if isinstance(self.device, str): - if not self.device: - raise ValueError("GPTAQConfig: `device` must be a non-empty string or torch.device.") - elif not isinstance(self.device, torch.device): - raise ValueError("GPTAQConfig: `device` must be a string or torch.device.") + code: ClassVar[str] = "" + + def to_dict(self) -> Dict[str, Any]: + """Serialize the preprocessor config into a minimal dictionary.""" + + return {"code": self.code} + + +@dataclass +class SmootherConfig(BasePreProcessorConfig): + """Serialized wrapper for a configured smoothing preprocessor.""" + + code: ClassVar[str] = PreProcessorCode.SMOOTHER.value + smooth: Optional[SmoothMethod] = None + + def __post_init__(self): + """Normalize the smoother payload into a typed smoother instance.""" + + self.smooth = _parse_smooth_method(self.smooth) + + def to_dict(self) -> Dict[str, Any]: + """Serialize the smoother config, including the smoother payload.""" + + payload = super().to_dict() + payload["smooth"] = _serialize_smooth_method(self.smooth) + return payload + + +@dataclass +class AutoModuleDecoderConfig(BasePreProcessorConfig): + """Configure automatic module-local decode behavior for checkpoint dtypes such as FP8.""" + + code: ClassVar[str] = PreProcessorCode.AUTO_MODULE_DECODER.value + source_dtype: str = "auto" + target_dtype: Union[str, torch.dtype] = torch.bfloat16 + + def __post_init__(self): + """Normalize the decoder payload into canonical string and dtype values.""" + + source_dtype = str(self.source_dtype).strip().lower() + if source_dtype != "auto": + raise ValueError( + f"AutoModuleDecoderConfig: unsupported `source_dtype` `{self.source_dtype}`." + ) + self.source_dtype = source_dtype + + target_dtype = self.target_dtype + if isinstance(target_dtype, torch.dtype): + normalized_dtype = target_dtype + else: + normalized_dtype = _DECODER_TARGET_DTYPE_MAP.get(str(target_dtype).strip().lower()) + if normalized_dtype not in {torch.float16, torch.bfloat16}: + raise ValueError( + "AutoModuleDecoderConfig: `target_dtype` must be `torch.float16` or `torch.bfloat16`." + ) + self.target_dtype = normalized_dtype + + def to_dict(self) -> Dict[str, Any]: + """Serialize the decoder config with a stable dtype string payload.""" + + payload = super().to_dict() + payload["source_dtype"] = self.source_dtype + payload["target_dtype"] = str(self.target_dtype).split(".")[-1] + return payload + + +@dataclass +class TensorParallelPadderConfig(BasePreProcessorConfig): + """Configure tensor-parallel-safe column padding derived from module weight shapes.""" + + code: ClassVar[str] = PreProcessorCode.TENSOR_PARALLEL_PADDER.value + + +@dataclass +class HessianConfig: + """Controls for chunked Hessian accumulation during GPTQ calibration.""" + + # Hessian accumulation controls (GPTQ only) + chunk_size: Optional[int] = field(default=None, metadata={"help": "Maximum rows per Hessian chunk"}) + chunk_bytes: Optional[int] = field(default=None, metadata={"help": "Memory budget (in bytes) for Hessian chunk staging"}) + staging_dtype: Union[str, torch.dtype] = field( + default=torch.float32, + metadata={"help": "Stage Hessian chunks in a lower precision dtype when supported"}, + ) + + def __post_init__(self): + """Validate Hessian chunking and staging dtype settings.""" + + if self.chunk_size is not None: + if not isinstance(self.chunk_size, int): + raise ValueError("HessianConfig: `chunk_size` must be an integer or None.") + if self.chunk_size <= 0: + raise ValueError("HessianConfig: `chunk_size` must be a positive integer.") + + if self.chunk_bytes is not None: + if not isinstance(self.chunk_bytes, int): + raise ValueError("HessianConfig: `chunk_bytes` must be an integer or None.") + if self.chunk_bytes <= 0: + raise ValueError("HessianConfig: `chunk_bytes` must be a positive integer amount of bytes.") + + if isinstance(self.staging_dtype, str): + self.staging_dtype = self.staging_dtype.lower() + if self.staging_dtype not in ["float32", "float16", "bfloat16"]: + raise ValueError("HessianConfig: `staging_dtype` must be float32, float16, or bfloat16.") + self.staging_dtype = getattr(torch, self.staging_dtype) + elif isinstance(self.staging_dtype, torch.dtype): + if self.staging_dtype not in [torch.float32, torch.float16, torch.bfloat16]: + raise ValueError("HessianConfig: `staging_dtype` must be float32, float16, or bfloat16.") + else: + raise ValueError("HessianConfig: `staging_dtype` must be a torch.dtype or string.") + + +@dataclass +class GPTAQConfig: + alpha: float = field(default=0.25) + device: Union[str, torch.device] = field(default="auto") + + def __post_init__(self): + if not isinstance(self.alpha, (int, float)): + raise ValueError("GPTAQConfig: `alpha` must be a numeric value.") + if isinstance(self.device, str): + if not self.device: + raise ValueError("GPTAQConfig: `device` must be a non-empty string or torch.device.") + elif not isinstance(self.device, torch.device): + raise ValueError("GPTAQConfig: `device` must be a string or torch.device.") + + +@dataclass +class FOEMConfig: + r"""Configuration parameters for the FOEM calibration process, including `alpha` and `beta`. + + The parameter `alpha` follows the same definition and role as in GPTAQ. + Note: although GPTAQ does not explicitly mention this coefficient in the paper, + its official implementation applies it to the rightmost term of Eq.18. + + The parameter `beta` is introduced by FOEM. Please refer to the paper for details: + https://ojs.aaai.org/index.php/AAAI/article/view/40123. + + Special cases: + - alpha = 0, beta = 0: + Equivalent to GPTQ. + - alpha > 0, beta = 0: + Equivalent to GPTAQ. The recommended value for `alpha` is 0.25. + - alpha = 0, beta > 0: + Equivalent to FOEM. Empirically, setting `beta` in the range [0.1, 0.25] yields good performance. + - alpha > 0, beta > 0: + Equivalent to FOEM + GPTAQ. Using the default best settings + (alpha = 0.25, beta = 0.2) generally produces strong results, + although it is not consistently superior to using FOEM alone. + + Args: + alpha (float, optional): Default is 0. + beta (float, optional): Default is 0.2. + """ + alpha: float = field(default=0) + beta: float = field(default=0.2) + device: Union[str, torch.device] = field(default="auto") + + def __post_init__(self): + if not isinstance(self.alpha, (int, float)): + raise ValueError("FOEMConfig: `alpha` must be a numeric value.") + if not isinstance(self.beta, (int, float)): + raise ValueError("FOEMConfig: `beta` must be a numeric value.") + if isinstance(self.device, str): + if not self.device: + raise ValueError("FOEMConfig: `device` must be a non-empty string or torch.device.") + elif not isinstance(self.device, torch.device): + raise ValueError("FOEMConfig: `device` must be a string or torch.device.") @dataclass @@ -398,7 +1299,8 @@ def __post_init__(self): # Validate integer values if not isinstance(self.num_experts_per_tok, int) or self.num_experts_per_tok <= 0: raise ValueError( - f"num_experts_per_tok must be a positive int or '{MOE_ALL_EXPERTS}', got {self.num_experts_per_tok}" + f"num_experts_per_tok must be a positive int or '{MOE_ALL_EXPERTS}', " + f"got {self.num_experts_per_tok}" ) @@ -411,7 +1313,8 @@ class ExpertsRoutingBypass(BaseMoERouting): # - First batch processes 10 modules (could be gate_proj for experts 0-9, or a mix depending on sorting) # - Second batch processes remaining 10 modules batch_size: Optional[int] = field( - default=None, metadata={"help": "Number of modules to process in a single batch during MoE quantization"} + default=None, + metadata={"help": "Number of modules to process in a single batch during MoE quantization"} ) @@ -421,7 +1324,10 @@ class MoEConfig: def __post_init__(self): if not isinstance(self.routing, BaseMoERouting): - raise ValueError(f"routing must be an instance of BaseMoERouting, got {type(self.routing).__name__}") + raise ValueError( + f"routing must be an instance of BaseMoERouting, " + f"got {type(self.routing).__name__}" + ) def routing_bypass(self) -> bool: return isinstance(self.routing, ExpertsRoutingBypass) @@ -438,10 +1344,7 @@ def routing_override(self, num_experts: int) -> Union[int, None]: """ if isinstance(self.routing, ExpertsRoutingOverride): # Resolve "all" to full expert count - if ( - isinstance(self.routing.num_experts_per_tok, str) - and self.routing.num_experts_per_tok.lower().strip() == MOE_ALL_EXPERTS - ): + if isinstance(self.routing.num_experts_per_tok, str) and self.routing.num_experts_per_tok.lower().strip() == MOE_ALL_EXPERTS: return num_experts assert isinstance(self.routing.num_experts_per_tok, int) @@ -449,10 +1352,8 @@ def routing_override(self, num_experts: int) -> Union[int, None]: # Clamp to valid range and warn user if needed if top_k > num_experts: - log.info( - f"MoEConfig: MoE routing override num_experts_per_tok ({top_k}) exceeds " - f"num_experts ({num_experts}); clamping to {num_experts}.", - ) + log.info(f"MoEConfig: MoE routing override num_experts_per_tok ({top_k}) exceeds " + f"num_experts ({num_experts}); clamping to {num_experts}.",) top_k = num_experts return top_k @@ -475,6 +1376,18 @@ def to_dict(self) -> Dict[str, Any]: FORMAT.MARLIN, FORMAT.BITBLAS, }, + METHOD.FP8: { + FORMAT.FP8, + }, + METHOD.BITSANDBYTES: { + FORMAT.BITSANDBYTES, + }, + METHOD.EXL3: { + FORMAT.EXL3, + }, + METHOD.GGUF: { + FORMAT.GGUF, + }, METHOD.QQQ: { FORMAT.QQQ, }, @@ -483,8 +1396,73 @@ def to_dict(self) -> Dict[str, Any]: FORMAT.GEMV, FORMAT.GEMV_FAST, FORMAT.MARLIN, + FORMAT.BITBLAS, FORMAT.LLM_AWQ, }, + METHOD.PARO: { + FORMAT.PAROQUANT, + }, +} + +GPTQ_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.GPTQ, + FORMAT.GPTQ_V2, + FORMAT.MARLIN, + FORMAT.BITBLAS, +) +AWQ_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.GEMM, + FORMAT.GEMV, + FORMAT.GEMV_FAST, + FORMAT.MARLIN, + FORMAT.BITBLAS, + FORMAT.LLM_AWQ, +) +PAROQUANT_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.PAROQUANT, +) +# Keep ParoQuant channel-scale clamps configurable so users can relax or +# tighten the safeguard without patching the optimizer code. +PAROQUANT_OPT_SCALE_CLAMP_MIN_DEFAULT = 1e-2 +PAROQUANT_OPT_SCALE_CLAMP_MAX_DEFAULT = 1e2 +QQQ_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.QQQ, +) +FP8_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.FP8, +) +BITSANDBYTES_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.BITSANDBYTES, +) +EXL3_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.EXL3, +) +RTN_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.GPTQ, + FORMAT.GPTQ_V2, + FORMAT.GEMM, + FORMAT.GEMV, + FORMAT.GEMV_FAST, + FORMAT.LLM_AWQ, +) +GGUF_EXPORT_FORMATS: Tuple[FORMAT, ...] = ( + FORMAT.GGUF, +) + +_UNAMBIGUOUS_EXPORT_METHOD_BY_FORMAT = { + FORMAT.GPTQ: METHOD.GPTQ, + FORMAT.GPTQ_V2: METHOD.GPTQ, + FORMAT.FP8: METHOD.FP8, + FORMAT.BITSANDBYTES: METHOD.BITSANDBYTES, + FORMAT.EXL3: METHOD.EXL3, + FORMAT.GGUF: METHOD.GGUF, + FORMAT.BITBLAS: METHOD.GPTQ, + FORMAT.GEMM: METHOD.AWQ, + FORMAT.GEMV: METHOD.AWQ, + FORMAT.GEMV_FAST: METHOD.AWQ, + FORMAT.LLM_AWQ: METHOD.AWQ, + FORMAT.PAROQUANT: METHOD.PARO, + FORMAT.QQQ: METHOD.QQQ, } # inference only methods should go here @@ -493,13 +1471,20 @@ def to_dict(self) -> Dict[str, Any]: # compat QUANT_CONFIG_ARG_SYNONYMS = { "w_bit": BITS_FIELD_CODE, + # QQQ compat "wbits": BITS_FIELD_CODE, "q_group_size": GROUP_SIZE_FIELD_CODE, + # AWQ compat - "version": FORMAT_FIELD_CODE, - # map format field (checkpoint_format) to class/code (format) + "version" : FORMAT_FIELD_CODE, + + # map deprecated aliases to canonical fields FORMAT_FIELD_CHECKPOINT: FORMAT_FIELD_CODE, + QUANT_METHOD_FIELD: METHOD_FIELD_CODE, + "bnb_quant_type": FORMAT_FIELD_CODE, + "bnb_block_size": "block_size", + "bnb_compress_statistics": "compress_statistics", } # compat (values are negated) @@ -509,7 +1494,6 @@ def to_dict(self) -> Dict[str, Any]: } DYNAMIC_FIELD_SYNONYMS = {} - def dict_scale_dtype_to_str(d: Dict[str, Any]) -> None: """ Checks whether the passed dictionary and its nested dicts have a *scale_dtype* key and if it's not None, @@ -587,25 +1571,108 @@ def _parse_smooth_method(setting: Any) -> Optional[SmoothMethod]: return _build_smooth_method_from_dict({"type": setting}) if isinstance(setting, dict): return _build_smooth_method_from_dict(setting) - raise ValueError("QuantizeConfig: `failsafe.smooth` must be a SmoothMethod, string, or dict.") + raise ValueError("QuantizeConfig: `fallback.smooth` must be a SmoothMethod, string, or dict.") -def dynamic_get( - dynamic: Dict[str, Dict[str, Union[int, bool]]], - module_name: str, - key: str = None, - default: Union[int, bool] = None, - sub_key: str = None, -) -> Union[Dict, int, bool]: +def _serialize_smooth_method(method: Optional[SmoothMethod]) -> Optional[Dict[str, Any]]: + if method is None: + return None + + payload = {"type": method.name, "group_size_threshold": method.group_size_threshold} + if isinstance(method, SmoothPercentile): + payload["percentile"] = method.percentile + elif isinstance(method, SmoothPercentileAsymmetric): + payload["low"] = method.low + payload["high"] = method.high + elif isinstance(method, SmoothMAD): + payload["k"] = method.k + elif isinstance(method, SmoothMSE): + payload["steps"] = method.steps + payload["maxshrink"] = method.maxshrink + elif isinstance(method, SmoothOutlier): + payload["pct"] = method.pct + elif isinstance(method, SmoothSoftNorm): + payload["k"] = method.k + elif isinstance(method, SmoothLog): + payload["percentile"] = method.percentile + payload["mu"] = method.mu + elif isinstance(method, SmoothRowCol): + payload["axis"] = method.axis + return payload + + +def _normalize_smoother_config( + payload: Optional[Union[SmootherConfig, SmoothMethod, Dict[str, Any], str]] +) -> Optional[SmootherConfig]: + if payload is None: + return None + if isinstance(payload, SmootherConfig): + return payload + if isinstance(payload, dict) and "smooth" in payload and "type" not in payload: + return SmootherConfig(smooth=payload.get("smooth")) + return SmootherConfig(smooth=payload) + + +def _normalize_preprocessor_config(payload: Any) -> BasePreProcessorConfig: + if isinstance(payload, BasePreProcessorConfig): + return payload + if isinstance(payload, SmoothMethod): + return SmootherConfig(smooth=payload) + if isinstance(payload, str): + normalized = payload.strip().lower() + if normalized == PreProcessorCode.SMOOTHER.value: + return SmootherConfig(smooth=None) + if normalized == PreProcessorCode.AUTO_MODULE_DECODER.value: + return AutoModuleDecoderConfig() + if normalized == PreProcessorCode.TENSOR_PARALLEL_PADDER.value: + return TensorParallelPadderConfig() + return SmootherConfig(smooth=payload) + if isinstance(payload, dict): + code = str(payload.get("code", "")).strip().lower() + if code == PreProcessorCode.AUTO_MODULE_DECODER.value: + return AutoModuleDecoderConfig( + source_dtype=payload.get("source_dtype", "auto"), + target_dtype=payload.get("target_dtype", torch.bfloat16), + ) + if code == PreProcessorCode.TENSOR_PARALLEL_PADDER.value: + return TensorParallelPadderConfig() + if code and code != PreProcessorCode.SMOOTHER.value: + raise ValueError(f"QuantizeConfig: unsupported preprocessor code `{code}`.") + if "smooth" in payload: + return SmootherConfig(smooth=payload.get("smooth")) + if "type" in payload: + return SmootherConfig(smooth=payload) + return SmootherConfig(smooth=None) + raise ValueError("QuantizeConfig: `preprocessors` entries must be preprocessor configs, smooth configs, dicts, or strings.") + + +def _normalize_preprocessors(payload: Optional[List[Any]]) -> List[BasePreProcessorConfig]: + if payload is None: + return [] + if not isinstance(payload, list): + raise ValueError("QuantizeConfig: `preprocessors` must be a list or None.") + return [_normalize_preprocessor_config(item) for item in payload] + + +def _validate_unique_preprocessors(preprocessors: List[BasePreProcessorConfig]) -> None: + codes_seen = set() + for preprocessor in preprocessors: + if preprocessor.code in codes_seen: + raise ValueError(f"QuantizeConfig: duplicate preprocessor `{preprocessor.code}` is not allowed.") + codes_seen.add(preprocessor.code) + + +def dynamic_get(dynamic: Dict[str, Dict[str, Union[int, bool]]], module_name: str, key: str = None, + default: Union[int, bool] = None, sub_key: str = None) -> Union[Dict, int, bool]: if dynamic is None: return default for pattern, overrides in dynamic.items(): if pattern.startswith("-:"): - if re.match(pattern.removeprefix("-:"), module_name): + if pcre.compile(pattern.removeprefix("-:")).match(module_name): return False - elif re.match(pattern.removeprefix("+:"), module_name): + elif pcre.compile(pattern.removeprefix("+:")).match(module_name): if key is None: return overrides else: @@ -620,9 +1687,7 @@ def dynamic_get( if isinstance(sub_value, Dict): return sub_value.get(sub_key, default) else: - log.info( - f"QuantConfig: Dynamic `sub_key`: `{sub_key}` failed extraction from `sub_value`: `{sub_value}`" - ) + log.info(f"QuantConfig: Dynamic `sub_key`: `{sub_key}` failed extraction from `sub_value`: `{sub_value}`") else: if key in overrides: return overrides[key] @@ -633,29 +1698,597 @@ def dynamic_get( return default return default +def _normalize_quant_method(value: Union[str, METHOD]) -> METHOD: + if isinstance(value, str): + value = value.lower() + if value == FORMAT.MARLIN: + return METHOD.GPTQ + if value == FORMAT.BITBLAS: + return METHOD.GPTQ + if value == FORMAT.FP8: + return METHOD.FP8 + if value == FORMAT.BITSANDBYTES: + return METHOD.BITSANDBYTES + if value == FORMAT.EXL3: + return METHOD.EXL3 + if value == FORMAT.PAROQUANT: + return METHOD.PARO + try: + return METHOD(value) + except ValueError as exc: + raise ValueError(f"QuantizeConfig: Unknown quantization method: `{value}`.") from exc + if not isinstance(value, METHOD): + raise ValueError(f"QuantizeConfig: Unsupported `method`: {value}") + return value + + +def _normalize_format(value: Union[str, FORMAT]) -> FORMAT: + if isinstance(value, str): + try: + return FORMAT(value.lower()) + except ValueError as exc: + raise ValueError(f"QuantizeConfig: Unknown quantization format: `{value}`.") from exc + if not isinstance(value, FORMAT): + raise ValueError(f"QuantizeConfig: Unknown quantization format: `{value}`.") + return value + + +def _normalize_pack_dtype(pack_dtype: Optional[Union[str, torch.dtype]]) -> torch.dtype: + if pack_dtype is None: + return torch.int32 + if isinstance(pack_dtype, str): + pack_dtype = pack_dtype.lower() + if pack_dtype not in ["int64", "int32", "int16", "int8"]: + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {pack_dtype}") + return getattr(torch, pack_dtype) + if isinstance(pack_dtype, torch.dtype): + if pack_dtype not in [torch.int64, torch.int32, torch.int16, torch.int8]: + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {pack_dtype}") + return pack_dtype + raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {pack_dtype}") + + +def _normalize_paroquant_best_state_dtype(best_state_dtype: Optional[Union[str, torch.dtype]]) -> str: + """Canonicalize the ParoQuant best-state snapshot dtype into a serialized string.""" + if best_state_dtype is None: + return "fp32" + if isinstance(best_state_dtype, str): + normalized = best_state_dtype.strip().lower() + if normalized in {"fp16", "float16"}: + return "fp16" + if normalized in {"bf16", "bfloat16"}: + return "bf16" + if normalized in {"fp32", "float32"}: + return "fp32" + elif isinstance(best_state_dtype, torch.dtype): + if best_state_dtype == torch.float16: + return "fp16" + if best_state_dtype == torch.bfloat16: + return "bf16" + if best_state_dtype == torch.float32: + return "fp32" + raise ValueError( + "ParoConfig: `opt_best_state_dtype` must be one of {'fp16', 'bf16', 'fp32'} " + "or torch.float16/torch.bfloat16/torch.float32." + ) + + +def _normalize_fallback(fallback: Optional[Union[Fallback, Dict[str, Any], str, int, float]]) -> Optional[Fallback]: + if fallback is None: + return None + if isinstance(fallback, dict): + strategy = fallback.get("strategy", FallbackStrategy.RTN) + threshold = fallback.get("threshold", "1.0%") + smooth = fallback.get("smooth") + if smooth is None: + smooth = fallback.get("smooth_method") + if smooth is None and "clip_method" in fallback: + smooth = fallback.get("clip_method") + smooth = _parse_smooth_method(smooth) + if smooth is None: + if "smooth_percentile" in fallback: + smooth = SmoothPercentile(percentile=float(fallback.get("smooth_percentile", 99.0))) + elif "smooth_mad_k" in fallback: + smooth = SmoothMAD(k=float(fallback.get("smooth_mad_k", 3.0))) + elif "smooth_mse_steps" in fallback or "smooth_mse_maxshrink" in fallback: + smooth = SmoothMSE( + steps=int(fallback.get("smooth_mse_steps", 32)), + maxshrink=float(fallback.get("smooth_mse_maxshrink", 0.8)), + ) + elif "smooth_outlier_pct" in fallback: + smooth = SmoothOutlier(pct=float(fallback.get("smooth_outlier_pct", 1.0))) + elif "smooth_rms_k" in fallback: + smooth = SmoothSoftNorm(k=float(fallback.get("smooth_rms_k", 3.0))) + elif "smooth_log_mu" in fallback: + smooth = SmoothLog( + percentile=float(fallback.get("smooth_percentile", 99.0)), + mu=float(fallback.get("smooth_log_mu", 8.0)), + ) + elif "smooth_axis" in fallback: + smooth = SmoothRowCol(axis=str(fallback.get("smooth_axis", "row"))) + fallback = Fallback(strategy=strategy, threshold=threshold, smooth=smooth) + elif isinstance(fallback, (str, int, float)): + fallback = Fallback(strategy=FallbackStrategy.RTN, threshold=fallback) + elif not isinstance(fallback, Fallback): + raise ValueError("QuantizeConfig: `fallback` must be a Fallback config, dict, string, int, float, or None.") + + if isinstance(fallback.strategy, str): + try: + fallback.strategy = FallbackStrategy(fallback.strategy.lower()) + except ValueError as exc: + raise ValueError( + f"QuantizeConfig: `fallback.strategy` must be one of {[v.value for v in FallbackStrategy]}." + ) from exc + elif not isinstance(fallback.strategy, FallbackStrategy): + raise ValueError( + f"QuantizeConfig: `fallback.strategy` must be one of {[v.value for v in FallbackStrategy]}." + ) + + fallback.smooth = _parse_smooth_method(fallback.smooth) + return fallback + + +def _normalize_weight_only( + weight_only: Optional[Union[WeightOnlyConfig, Dict[str, Any], str]] +) -> Optional[WeightOnlyConfig]: + if weight_only is None: + return None + if isinstance(weight_only, dict): + method = weight_only.get("method", WeightOnlyMethod.RTN) + smooth = weight_only.get("smooth") + if smooth is None: + smooth = weight_only.get("smooth_method") + return WeightOnlyConfig(method=method, smooth=smooth) + if isinstance(weight_only, str): + return WeightOnlyConfig(method=weight_only) + if not isinstance(weight_only, WeightOnlyConfig): + raise ValueError( + "QuantizeConfig: `weight_only` must be a WeightOnlyConfig, dict, string, or None." + ) + return weight_only + + +def _normalize_hessian(hessian: Optional[Union[HessianConfig, Dict[str, Any]]]) -> HessianConfig: + if hessian is None: + return HessianConfig() + if isinstance(hessian, dict): + return HessianConfig(**hessian) + if not isinstance(hessian, HessianConfig): + raise ValueError("QuantizeConfig: `hessian` must be a HessianConfig, dict, or None.") + return hessian + + +def _normalize_gptaq(gptaq: Optional[Union[GPTAQConfig, Dict[str, Any]]]) -> Optional[GPTAQConfig]: + if gptaq is None: + return None + if isinstance(gptaq, dict): + return GPTAQConfig(**gptaq) + if not isinstance(gptaq, GPTAQConfig): + raise ValueError("QuantizeConfig: `gptaq` must be a GPTAQConfig, dict, or None.") + return gptaq + + +def _normalize_foem(foem: Optional[Union[FOEMConfig, Dict[str, Any]]]) -> Optional[FOEMConfig]: + if foem is None: + return None + if isinstance(foem, dict): + return FOEMConfig(**foem) + if not isinstance(foem, FOEMConfig): + raise ValueError("QuantizeConfig: `foem` must be a FOEMConfig, dict, or None.") + return foem + + +def _normalize_dense_vram_strategy(value: Union[str, VramStrategy]) -> VramStrategy: + """Validate one user-supplied dense-pool placement strategy value.""" + + if isinstance(value, str): + try: + return VramStrategy(value.lower()) + except ValueError as exc: + raise ValueError( + f"QuantizeConfig: `dense_vram_strategy` must be one of {[v.value for v in VramStrategy]}." + ) from exc + if not isinstance(value, VramStrategy): + raise ValueError( + f"QuantizeConfig: `dense_vram_strategy` must be one of {[v.value for v in VramStrategy]}." + ) + return value + + +def _normalize_moe_vram_strategy(value: Union[str, VramStrategy]) -> VramStrategy: + """Validate one user-supplied MoE expert-pool placement strategy value.""" + + if isinstance(value, str): + try: + return VramStrategy(value.lower()) + except ValueError as exc: + raise ValueError( + f"QuantizeConfig: `moe_vram_strategy` must be one of {[v.value for v in VramStrategy]}." + ) from exc + if not isinstance(value, VramStrategy): + raise ValueError( + f"QuantizeConfig: `moe_vram_strategy` must be one of {[v.value for v in VramStrategy]}." + ) + return value + + +def _normalize_strategy_devices( + value: Optional[List[Union[str, torch.device]]], + *, + field_name: str, +) -> Optional[List[str]]: + """Normalize one user-facing strategy device pool to stable device strings.""" + + if value is None: + return None + if not isinstance(value, list): + raise ValueError(f"QuantizeConfig: `{field_name}` must be a list of device strings or torch.device values.") + if not value: + raise ValueError(f"QuantizeConfig: `{field_name}` must not be empty when provided.") + + # Import lazily to keep config parsing light and avoid depending on looper + # modules unless the caller actually configures explicit device pools. + from ..utils.looper_helpers import normalize_device_like + + normalized_devices: List[str] = [] + seen = set() + for raw_device in value: + normalized = normalize_device_like(raw_device) + if normalized is None: + raise ValueError(f"QuantizeConfig: `{field_name}` contains an unsupported device value: {raw_device!r}.") + key = str(normalized) + if key in seen: + continue + seen.add(key) + normalized_devices.append(key) + return normalized_devices + + +def _normalize_gc_mode(value: Union[str, GcMode]) -> GcMode: + if isinstance(value, str): + try: + return GcMode(value.lower()) + except ValueError as exc: + raise ValueError( + f"QuantizeConfig: `gc_mode` must be one of {[v.value for v in GcMode]}." + ) from exc + if not isinstance(value, GcMode): + raise ValueError( + f"QuantizeConfig: `gc_mode` must be one of {[v.value for v in GcMode]}." + ) + return value + + +def _normalize_moe_config(value: Optional[Union[MoEConfig, Dict[str, Any]]]) -> Optional[MoEConfig]: + if value is None: + return None + if isinstance(value, MoEConfig): + return value + if not isinstance(value, dict): + raise ValueError("QuantizeConfig: `moe` must be a MoEConfig, dict, or None.") + + routing = value.get("routing") + if isinstance(routing, BaseMoERouting): + return MoEConfig(routing=routing) + if not isinstance(routing, dict): + raise ValueError("QuantizeConfig: `moe.routing` must be a BaseMoERouting, dict, or None.") + + routing_class = routing.get("class") + if routing_class == ExpertsRoutingOverride.__name__: + routing_obj = ExpertsRoutingOverride( + num_experts_per_tok=routing.get("num_experts_per_tok", MOE_ALL_EXPERTS) + ) + elif routing_class == ExpertsRoutingBypass.__name__: + routing_obj = ExpertsRoutingBypass(batch_size=routing.get("batch_size")) + else: + raise ValueError(f"QuantizeConfig: Unknown `moe.routing.class`: `{routing_class}`.") + + return MoEConfig(routing=routing_obj) + + +def _resolve_dynamic_group_size_error() -> str: + return "QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`." + + +def _default_damp_percent(method: METHOD) -> float: + return 0.005 if method == METHOD.QQQ else 0.05 + + +def _default_damp_auto_increment(method: METHOD) -> float: + return 0.001 if method == METHOD.QQQ else 0.01 + + +def _peek_weight_only_method(payload: Any) -> Optional[WeightOnlyMethod]: + if payload is None: + return None + if isinstance(payload, WeightOnlyConfig): + return payload.method + if isinstance(payload, str): + try: + return WeightOnlyMethod(payload.lower()) + except ValueError: + return None + if isinstance(payload, dict): + method = payload.get("method", WeightOnlyMethod.RTN) + try: + return WeightOnlyMethod(str(method).lower()) + except ValueError: + return None + return None + + +def _extract_weight_only_smooth(payload: Any) -> Any: + if payload is None: + return None + if isinstance(payload, WeightOnlyConfig): + return payload.smooth + if isinstance(payload, dict): + smooth = payload.get("smooth") + if smooth is None: + smooth = payload.get("smooth_method") + return smooth + if isinstance(payload, str): + return None + raise ValueError("QuantizeConfig: `weight_only` must be a WeightOnlyConfig, dict, string, or None.") + + +def _extract_weight_only_legacy_gguf_bits(payload: Any) -> Any: + if payload is None: + return None + if isinstance(payload, WeightOnlyConfig): + return getattr(payload, "gguf_qtype", None) + if isinstance(payload, dict): + return payload.get("gguf_qtype") + if isinstance(payload, str): + return None + raise ValueError("QuantizeConfig: `weight_only` must be a WeightOnlyConfig, dict, string, or None.") + + +def _normalize_rtn_kwargs(payload: Dict[str, Any]) -> Dict[str, Any]: + normalized = dict(payload) + legacy_gguf_bits = normalized.pop("gguf_qtype", None) + weight_only = normalized.pop("weight_only", None) + weight_only_method = _peek_weight_only_method(weight_only) + + # `weight_only.method="gguf"` is a backward-compatible shorthand for the direct GGUF weight-only lifecycle. + if weight_only_method == WeightOnlyMethod.GGUF and FORMAT_FIELD_CODE not in normalized: + normalized[FORMAT_FIELD_CODE] = FORMAT.GGUF + + if "smooth" not in normalized: + normalized["smooth"] = _extract_weight_only_smooth(weight_only) + if legacy_gguf_bits is None: + legacy_gguf_bits = _extract_weight_only_legacy_gguf_bits(weight_only) + if legacy_gguf_bits is not None and BITS_FIELD_CODE not in normalized: + normalized[BITS_FIELD_CODE] = legacy_gguf_bits + return normalized + + +def _normalize_gguf_kwargs(payload: Dict[str, Any]) -> Dict[str, Any]: + normalized = dict(payload) + legacy_gguf_bits = normalized.pop("gguf_qtype", None) + weight_only = normalized.pop("weight_only", None) + + if "smoother" not in normalized and "smooth" not in normalized: + normalized["smoother"] = _extract_weight_only_smooth(weight_only) + if legacy_gguf_bits is None: + legacy_gguf_bits = _extract_weight_only_legacy_gguf_bits(weight_only) + if legacy_gguf_bits is not None and BITS_FIELD_CODE not in normalized: + normalized[BITS_FIELD_CODE] = legacy_gguf_bits + normalized[BITS_FIELD_CODE], normalized[FORMAT_FIELD_CODE], _ = _normalize_gguf_config_spec( + normalized.get(BITS_FIELD_CODE, 4), + normalized.get(FORMAT_FIELD_CODE), + ) + return normalized + + +def _normalize_fp8_kwargs(payload: Dict[str, Any]) -> Dict[str, Any]: + normalized = dict(payload) + weight_only = normalized.pop("weight_only", None) + legacy_fmt = normalized.pop("fmt", None) + + if "smoother" not in normalized and "smooth" not in normalized: + normalized["smoother"] = _extract_weight_only_smooth(weight_only) + + normalized[FORMAT_FIELD_CODE] = _normalize_fp8_fmt( + normalized.get(FORMAT_FIELD_CODE, legacy_fmt) + ) + + weight_block_size = _normalize_fp8_weight_block_size(normalized.get("weight_block_size")) + normalized["weight_block_size"] = list(weight_block_size) if weight_block_size is not None else None + + normalized["weight_scale_method"] = _normalize_fp8_weight_scale_method( + normalized.get("weight_scale_method"), + weight_block_size=weight_block_size, + ) + return normalized + + +def _normalize_bitsandbytes_kwargs(payload: Dict[str, Any]) -> Dict[str, Any]: + normalized = dict(payload) + weight_only = normalized.pop("weight_only", None) + + if "smoother" not in normalized and "smooth" not in normalized: + normalized["smoother"] = _extract_weight_only_smooth(weight_only) + + legacy_format = normalized.pop("bnb_quant_type", None) + legacy_block_size = normalized.pop("bnb_block_size", None) + legacy_compress_statistics = normalized.pop("bnb_compress_statistics", None) + + normalized[FORMAT_FIELD_CODE] = _normalize_bitsandbytes_format( + normalized.get(FORMAT_FIELD_CODE, legacy_format), + bits=int(normalized.get(BITS_FIELD_CODE, 4)), + ) + normalized["block_size"] = _normalize_bitsandbytes_block_size( + normalized.get("block_size", legacy_block_size) + ) + normalized["compress_statistics"] = bool( + normalized.get("compress_statistics", legacy_compress_statistics if legacy_compress_statistics is not None else True) + ) + return normalized + + +def _resolve_export_quant_method(format_value: FORMAT, fallback_method: Optional[METHOD] = None) -> METHOD: + if format_value == FORMAT.MARLIN: + if fallback_method is None: + raise ValueError("QuantizeConfig: FORMAT.MARLIN requires an explicit quantization method family.") + return fallback_method + + method = _UNAMBIGUOUS_EXPORT_METHOD_BY_FORMAT.get(format_value) + if method is None: + if fallback_method is not None: + return fallback_method + raise ValueError(f"QuantizeConfig: Unable to resolve export method for format `{format_value}`.") + return method + + +def _normalize_quantize_config_payload_for_target_cls(target_cls, payload: Dict[str, Any]) -> Dict[str, Any]: + normalized = dict(payload) + + if target_cls is AWQConfig: + expected_method = METHOD.AWQ + elif target_cls is FP8Config: + expected_method = METHOD.FP8 + elif target_cls is BitsAndBytesConfig: + expected_method = METHOD.BITSANDBYTES + elif target_cls is EXL3Config: + expected_method = METHOD.EXL3 + format_value = normalized.get(FORMAT_FIELD_CODE) + normalized_format = None + if format_value is not None: + try: + normalized_format = _normalize_format(format_value) + normalized[FORMAT_FIELD_CODE] = normalized_format + except ValueError: + normalized_format = None + if normalized_format is not None and normalized_format != FORMAT.EXL3: + log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.EXL3}`") + normalized[FORMAT_FIELD_CODE] = FORMAT.EXL3 + elif target_cls is ParoConfig: + expected_method = METHOD.PARO + format_value = normalized.get(FORMAT_FIELD_CODE) + normalized_format = None + if format_value is not None: + try: + normalized_format = _normalize_format(format_value) + normalized[FORMAT_FIELD_CODE] = normalized_format + except ValueError: + normalized_format = None + if normalized_format is not None and normalized_format != FORMAT.PAROQUANT: + log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.PAROQUANT}`") + normalized[FORMAT_FIELD_CODE] = FORMAT.PAROQUANT + elif target_cls is GGUFConfig: + expected_method = METHOD.GGUF + elif target_cls is QQQConfig: + expected_method = METHOD.QQQ + format_value = normalized.get(FORMAT_FIELD_CODE) + normalized_format = None + if format_value is not None: + try: + normalized_format = _normalize_format(format_value) + normalized[FORMAT_FIELD_CODE] = normalized_format + except ValueError: + normalized_format = None + if normalized_format is not None and normalized_format != FORMAT.QQQ: + log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.QQQ}`") + normalized[FORMAT_FIELD_CODE] = FORMAT.QQQ + else: + expected_method = METHOD.GPTQ + + method = normalized.get(METHOD_FIELD_CODE) + normalized_method = None + if method is not None: + try: + normalized_method = _normalize_quant_method(method) + normalized[METHOD_FIELD_CODE] = normalized_method + except ValueError: + normalized_method = None + + if normalized_method is not None and normalized_method != expected_method: + if target_cls is GGUFConfig and normalized_method == METHOD.GPTQ: + pass + else: + log.warn( + f"QuantizeConfig: `{METHOD_FIELD_CODE}`=`{normalized_method}` is incompatible with `{target_cls.__name__}`. " + f"Auto-fix method to `{expected_method}`." + ) + normalized[METHOD_FIELD_CODE] = expected_method + + return normalized + + +def _filter_quantize_config_payload_for_target_cls(target_cls, payload: Dict[str, Any]) -> Dict[str, Any]: + target_field_names = {field.name for field in fields(target_cls) if field.init} + return {key: value for key, value in payload.items() if key in target_field_names} + + +def _prepare_target_quantize_config_kwargs(target_cls, payload: Dict[str, Any]) -> Dict[str, Any]: + normalized = _normalize_quantize_config_payload_for_target_cls(target_cls, payload) + if target_cls is RTNConfig: + normalized = _normalize_rtn_kwargs(normalized) + elif target_cls is GGUFConfig: + normalized = _normalize_gguf_kwargs(normalized) + elif target_cls is FP8Config: + normalized = _normalize_fp8_kwargs(normalized) + elif target_cls is BitsAndBytesConfig: + normalized = _normalize_bitsandbytes_kwargs(normalized) + return _filter_quantize_config_payload_for_target_cls(target_cls, normalized) + + +class QuantizeConfigMeta(type): + def __instancecheck__(cls, instance): + if cls is QuantizeConfig: + return isinstance(instance, BaseQuantizeConfig) + return super().__instancecheck__(instance) + + def __subclasscheck__(cls, subclass): + if cls is QuantizeConfig: + try: + return issubclass(subclass, BaseQuantizeConfig) + except TypeError: + return False + return super().__subclasscheck__(subclass) + + def __call__(cls, *args, **kwargs): + kwargs = _normalize_quantize_config_constructor_kwargs(kwargs) + if cls is QuantizeConfig: + target_cls = _resolve_quantize_config_class(kwargs) + target_kwargs = _prepare_target_quantize_config_kwargs(target_cls, kwargs) + return type.__call__(target_cls, *args, **target_kwargs) + return super().__call__(*args, **kwargs) + + +def _normalize_quantize_config_constructor_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + if not kwargs: + return kwargs + + normalized = dict(kwargs) + if FORMAT_FIELD_COMPAT_MARLIN in normalized: + raise ValueError( + "QuantizeConfig: `is_marlin_format` has been removed. Use `format=\"marlin\"` only for legacy checkpoint inspection, " + "or `format=\"gptq\"` for new GPTQ quantization." + ) + if METHOD_FIELD_CODE not in normalized and QUANT_METHOD_FIELD in normalized: + normalized[METHOD_FIELD_CODE] = normalized[QUANT_METHOD_FIELD] + normalized.pop(QUANT_METHOD_FIELD, None) + + if FORMAT_FIELD_CODE not in normalized and FORMAT_FIELD_CHECKPOINT in normalized: + normalized[FORMAT_FIELD_CODE] = normalized[FORMAT_FIELD_CHECKPOINT] + normalized.pop(FORMAT_FIELD_CHECKPOINT, None) + return normalized + @dataclass -class QuantizeConfig: - bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]}) +class BaseQuantizeConfig(metaclass=QuantizeConfigMeta): + bits: Union[int, str, GGUFBits] = field(default=4, metadata={"choices": [2, 3, 4, 5, 6, 8]}) # allow dynamic bitsize per layer, if None or some layer not set, use bits - dynamic: Optional[Dict[str, Dict[str, Union[int, bool]]]] = field(default=None) + dynamic: Optional[Dict[str, Dict[str, Union[int, str, bool, GGUFBits]]]] = field(default=None) - # GPTQ only - # 128 offer good balance between inference speed, vram usage (bpw), and quality - # use 32 for highest quality with slower inference and higher vram usage + # 128 offers a good balance between inference speed, VRAM usage, and quality. group_size: int = field(default=128) - # increase damp if NaN is encountered during `.quantize()` and/or increase calib dataset size - damp_percent: float = field(default=None) - damp_auto_increment: float = field(default=None) - desc_act: Optional[bool] = field(default=None) - # GPTQ only - act_group_aware: Optional[bool] = field(default=None) - static_groups: bool = field(default=False) - # symmetric quantization toggle (True=symmetric, False=asymmetric). sym: bool = field(default=True) @@ -663,27 +2296,12 @@ class QuantizeConfig: lm_head: bool = field(default=False) - quant_method: METHOD = field(default=METHOD.GPTQ) + method: METHOD = field(default=METHOD.GPTQ) - # default to gptq v1 format for maximum compat with 3rd party inference libs with minimal loss vs v2 - # if you inference with gptqmodel, save to gptq_v2 format for best result + # Serialized/exported checkpoint layout. This is the authoritative post-quantization format. format: FORMAT = field(default=FORMAT.GPTQ) - # quantization_order: str = "activate", - # quantization_scale: str = "mse", # or absmax - # is_distributed: bool = False, - # tied_gptq_handle: Optional["GPTQ"] = None - - # GPTQ only - # mean square error calculation: may reduce error loss for some models - mse: float = field(default=0.0) - - # GPTQ only - # use Hessian-diagonal activation importance to weight MSE grid search offline - activation_weighted_mse: bool = field(default=False) - - # properties that do not directly contributes to quantization or quant inference should be placed in meta - # i.e. quantizer tool (producer) + version, timestamp, entity who made the quant, etc + # properties that do not directly contribute to quantization or inference should be placed in meta meta: Optional[Dict] = field(default=None) # normalized to DEVICE after passing to load() @@ -694,345 +2312,256 @@ class QuantizeConfig: # affects [`qweights`, `qzeros`] pack_dtype: Optional[Union[str, torch.dtype]] = field(default=torch.int32) - # packing implementation hinpt (`original` = legacy CPU pack, `gpu` enables CUDA pack, `cpu` forces block CPU pack). + # packing implementation hint (`original` = legacy CPU pack, `gpu` enables CUDA pack, `cpu` forces block CPU pack). pack_impl: str = field(default="cpu") - # pending used field adapter: Optional[Union[Dict[str, Any], Lora]] = field(default=None) - # quantization only: # controls cpu memory saving by offloading layers/modules to disk in the slow quantization process - # default to true as the benefit of ~73.5% cpu memory saving is tremendous offload_to_disk: bool = field( - default=True, metadata={"help": "Offload completed module memory to disk during quantization loop"} + default=True, + metadata={"help": "Offload completed module memory to disk during quantization loop"}, ) offload_to_disk_path: str = field( - default=None, metadata={"help": "Offload disk path. Only applicable if Offload to disk is enabled"} + default=None, + metadata={"help": "Offload disk path. Only applicable if Offload to disk is enabled"}, ) rotation: Optional[str] = field(default=None, metadata={"choices": ["hadamard", "random"]}) - # GPTQ only - # deprecated: only used for compat - is_marlin_format: bool = False - - # gptq only: - # if calibration is insufficient, fallback to a simple quantization strategy; encapsulated in FailSafe config - failsafe: Optional[FailSafe] = field(default_factory=FailSafe) - - # GPTQ only - # gptaq only: - gptaq: Optional[GPTAQConfig] = field(default=None) - - # gptq only: - # skip all heavy computations for testing model loading - mock_quantization: bool = field( - default=False, metadata={"help": "Skip heavy computations for fast model loading validation"} - ) - - # GPTQ only - # Hessian accumulation controls (GPTQ only) - hessian: Optional[HessianConfig] = field(default_factory=HessianConfig) + # if calibration is insufficient, fallback to a simple quantization strategy + fallback: Optional[Fallback] = field(default_factory=Fallback) # Callback function to filter devices for compute-intensive stages (quantization and forwarding) - # Takes a list of devices and returns either the original list or a filtered subset compute_device_filter: Optional[callable] = field( default=None, - metadata={ - "help": "Callback function to filter devices for compute-intensive stages. Function signature: fn(devices: List) -> List. " - "Example to exclude device 0: compute_device_filter=lambda devices: [d for d in devices if d.index != 0]" - }, + metadata={"help": "Callback function to filter devices for compute-intensive stages. Function signature: fn(devices: List) -> List. " + "Example to exclude device 0: compute_device_filter=lambda devices: [d for d in devices if d.index != 0]"} + ) + + # Device for storing calibration data during input capture + calibration_data_device: Optional[Union[str, torch.device]] = field( + default=None, + metadata={"help": "Device for storing calibration data. 'balanced' = round-robin across GPUs, or specify device like 'cuda:1'."} ) - # Works faster than data parallel with some configurations auto_forward_data_parallel: bool = field( default=True, - metadata={ - "help": "When multi-gpu is detected, we may data clone modules to each gpu for data parallelism " - "to speed up quantization forwarding. This causes extra time spent (especially for MoE layers) and vram pressure, " - "leading in some cases to slower forwarding or vram OOM" - }, + metadata={"help": "When multi-gpu is detected, we may data clone modules to each gpu for data parallelism " + "to speed up quantization forwarding. This causes extra time spent (especially for MoE layers) and vram pressure, " + "leading in some cases to slower forwarding or vram OOM"} ) - # VRAM allocation strategy for MoE-heavy subsets - vram_strategy: VramStrategy = field(default=VramStrategy.EXCLUSIVE) + # User-facing dense-pool strategy. The dense pool owns the serial path: + # qkv, z, out_proj, norms, router, shared expert, and dense MLP modules. + dense_vram_strategy: VramStrategy = field( + default=VramStrategy.EXCLUSIVE, + metadata={"help": "Dense pool placement strategy. The dense pool owns qkv, z, out_proj, norms, router, shared expert, and dense MLP modules."}, + ) + # Optional dense-pool device list, relative to CUDA_VISIBLE_DEVICES. In + # BALANCED mode, model-tree calculation groups stay together, so qkv is not split. + dense_vram_strategy_devices: Optional[List[Union[str, torch.device]]] = field( + default=None, + metadata={"help": "Explicit device pool for dense modules. In dense BALANCED mode, modules are assigned by calculation groups, so qkv stays co-located."}, + ) + # User-facing expert-pool strategy. Expert families are placed as whole + # units so gate/up/down for one expert stay on the same device. + moe_vram_strategy: VramStrategy = field( + default=VramStrategy.EXCLUSIVE, + metadata={"help": "MoE expert-pool placement strategy. Expert families stay co-located and can be balanced across this pool."}, + ) + # Optional expert-pool device list, relative to CUDA_VISIBLE_DEVICES. + moe_vram_strategy_devices: Optional[List[Union[str, torch.device]]] = field( + default=None, + metadata={"help": "Explicit device pool for MoE expert modules. Each expert family (gate/up/down) stays on one device."}, + ) gc_mode: GcMode = field( default=GcMode.INTERVAL, - metadata={ - "help": "Garbage collection mode: 'interval' for regular GC or 'on_stage_end' for GC after stage end (after forward pass, quantize, layer finilization)." - }, + metadata={"help": "Garbage collection mode: 'interval' for regular GC or 'on_stage_end' for GC after stage end (after forward pass, quantize, layer finilization)."} ) - # Control whether to wait for layer finalization (packing, writing) before proceeding to next layer - # Default False preserves current behavior (async finalization in background while next layer starts) wait_for_submodule_finalizers: bool = field( default=False, - metadata={ - "help": "Wait for all layer finalization tasks (packing, offloading to disk, etc) to complete before proceeding to next layer. May reduce vram pressure for some env." - }, + metadata={"help": "Wait for all layer finalization tasks (packing, offloading to disk, etc) to complete before proceeding to next layer. May reduce vram pressure for some env."} ) - moe: MoEConfig = field( + moe: Optional[MoEConfig] = field( default=None, - metadata={ - "help": "Mixture-of-Experts (MoE) configuration for routing strategy and expert batching. " - "Example with bypass routing (forward all data to each expert): " - "moe={'routing': {'class': 'ExpertsRoutingBypass', 'batch_size': None}} - processes all experts in one batch (default). " - "moe={'routing': {'class': 'ExpertsRoutingBypass', 'batch_size': 4}} - processes 4 experts at a time to reduce VRAM pressure. " - "Example with routing override (limit experts per token): " - "moe={'routing': {'class': 'ExpertsRoutingOverride', 'num_experts_per_tok': 2}}. " - "Example to forward to all experts: " - "moe={'routing': {'class': 'ExpertsRoutingOverride', 'num_experts_per_tok': 'all'}}" - }, + metadata={"help": "Mixture-of-Experts (MoE) configuration for routing strategy and expert batching. " + "Requires import: from gptqmodel.quantization.config import MoEConfig, ExpertsRoutingBypass, ExpertsRoutingOverride. " + "Example with bypass routing (forward all data to each expert): " + "moe=MoEConfig(routing=ExpertsRoutingBypass()) - processes all experts in one batch (default). " + "moe=MoEConfig(routing=ExpertsRoutingBypass(batch_size=4)) - processes 4 modules at a time to reduce VRAM pressure. " + "Example with routing override (limit experts per token): " + "moe=MoEConfig(routing=ExpertsRoutingOverride(num_experts_per_tok=2)). " + "Example to forward to all experts: " + "moe=MoEConfig(routing=ExpertsRoutingOverride(num_experts_per_tok='all'))"} ) - def __post_init__(self): - fields_info = fields(self) + @property + def quant_method(self) -> METHOD: + return self.method - # validate/normalizes pack_dtype from string and dtype to valid dtype - if self.pack_dtype is None: - self.pack_dtype = torch.int32 - else: - if isinstance(self.pack_dtype, str): - self.pack_dtype = self.pack_dtype.lower() - if self.pack_dtype not in ["int64", "int32", "int16", "int8"]: - raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") - self.pack_dtype = getattr(torch, self.pack_dtype) - elif isinstance(self.pack_dtype, torch.dtype): - if self.pack_dtype not in [torch.int64, torch.int32, torch.int16, torch.int8]: - raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") - else: - raise ValueError(f"QuantizeConfig: Unsupported `pack_dtype`: {self.pack_dtype}") + @quant_method.setter + def quant_method(self, value: Union[str, METHOD]) -> None: + self.method = value - # validate quant method and format is matched - valid_formats = QUANT_METHOD_FORMAT_MAPPING.get(self.quant_method, None) + @property + def checkpoint_format(self): + return self.format + + @checkpoint_format.setter + def checkpoint_format(self, value) -> None: + self.format = value + + @property + def runtime_bits(self): + return self.bits + + def _resolve_checkpoint_format(self) -> FORMAT: + self.format = _normalize_format(self.format) + return self.format + + def _normalize_bits_field(self, bits_value, checkpoint_format: FORMAT): + return _normalize_quant_bits(bits_value, format_value=checkpoint_format) + + def _normalize_dynamic_layer_config( + self, + layer_name: str, + layer_dict: Dict[str, Any], + *, + valid_bit_widths: List[int], + checkpoint_format: FORMAT, + ) -> None: + for key, value in layer_dict.items(): + if key == "bits": + normalized_bits = self._normalize_bits_field(value, checkpoint_format=checkpoint_format) + layer_dict[key] = normalized_bits + if quant_bits_width(normalized_bits) not in valid_bit_widths: + raise ValueError( + f"QuantizeConfig: Layer `{layer_name}` only support quantization of `{valid_bit_widths}` bits." + ) + if key == "group_size" and value != -1 and value <= 0: + raise ValueError(_resolve_dynamic_group_size_error()) + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return tuple(METHOD) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + valid_formats = QUANT_METHOD_FORMAT_MAPPING.get(self.method, None) if valid_formats is None: - raise ValueError(f"QuantizeConfig: Unsupported `quant_method`: {self.quant_method}") + raise ValueError(f"QuantizeConfig: Unsupported `method`: {self.method}") + return tuple(valid_formats) - # If the user does not pass it, the default value will be set according to quant_method - if self.damp_percent is None: - if self.quant_method == METHOD.QQQ: - self.damp_percent = 0.005 - else: - self.damp_percent = 0.05 - if self.damp_auto_increment is None: - if self.quant_method == METHOD.QQQ: - self.damp_auto_increment = 0.001 - else: - self.damp_auto_increment = 0.01 + def export_quant_method(self) -> METHOD: + return _resolve_export_quant_method(resolve_quant_format(self.format, self.method), fallback_method=self.method) + + def default_desc_act(self) -> bool: + return True + + def __post_init__(self): + fields_info = fields(self) + + self.method = _normalize_quant_method(self.method) + format_family = self._resolve_checkpoint_format() + self.pack_dtype = _normalize_pack_dtype(self.pack_dtype) + self.bits = self._normalize_bits_field(self.bits, checkpoint_format=format_family) + + allowed_methods = self.allowed_quant_methods() + if allowed_methods and self.method not in allowed_methods: + raise ValueError( + f"{self.__class__.__name__}: `method` must be one of {[v.value for v in allowed_methods]}." + ) # TODO FIXME awq compat which didn't have checkpoint_format before merging to gptqmodel - if self.quant_method == METHOD.AWQ and self.format not in [ - FORMAT.MARLIN, - FORMAT.GEMV, - FORMAT.GEMV_FAST, - FORMAT.GEMM, - FORMAT.LLM_AWQ, - ]: + if self.quant_method == METHOD.AWQ and self.format not in [FORMAT.MARLIN, FORMAT.GEMV, FORMAT.GEMV_FAST, FORMAT.GEMM, FORMAT.BITBLAS, FORMAT.LLM_AWQ]: log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.GEMM}`") self.format = FORMAT.GEMM + format_family = self._resolve_checkpoint_format() - if self.format not in valid_formats: + valid_formats = self.supported_export_formats() + if format_family not in valid_formats: raise ValueError( - f"QuantizeConfig: checkpoint `format` used is {self.format}, and the quantization method is {self.quant_method}. " + f"{self.__class__.__name__}: unsupported export `format` `{format_family}`." ) - # normalize failsafe config - if self.failsafe is None: - pass - elif isinstance(self.failsafe, dict): - strategy = self.failsafe.get("strategy", FailSafeStrategy.RTN) - threshold = self.failsafe.get("threshold", "1.0%") - smooth = self.failsafe.get("smooth") - if smooth is None: - smooth = self.failsafe.get("smooth_method") - if smooth is None and "clip_method" in self.failsafe: - smooth = self.failsafe.get("clip_method") - smooth = _parse_smooth_method(smooth) - if smooth is None: - if "smooth_percentile" in self.failsafe: - smooth = SmoothPercentile(percentile=float(self.failsafe.get("smooth_percentile", 99.0))) - elif "smooth_mad_k" in self.failsafe: - smooth = SmoothMAD(k=float(self.failsafe.get("smooth_mad_k", 3.0))) - elif "smooth_mse_steps" in self.failsafe or "smooth_mse_maxshrink" in self.failsafe: - smooth = SmoothMSE( - steps=int(self.failsafe.get("smooth_mse_steps", 32)), - maxshrink=float(self.failsafe.get("smooth_mse_maxshrink", 0.8)), - ) - elif "smooth_outlier_pct" in self.failsafe: - smooth = SmoothOutlier(pct=float(self.failsafe.get("smooth_outlier_pct", 1.0))) - elif "smooth_rms_k" in self.failsafe: - smooth = SmoothSoftNorm(k=float(self.failsafe.get("smooth_rms_k", 3.0))) - elif "smooth_log_mu" in self.failsafe: - smooth = SmoothLog( - percentile=float(self.failsafe.get("smooth_percentile", 99.0)), - mu=float(self.failsafe.get("smooth_log_mu", 8.0)), - ) - elif "smooth_axis" in self.failsafe: - smooth = SmoothRowCol(axis=str(self.failsafe.get("smooth_axis", "row"))) - self.failsafe = FailSafe( - strategy=strategy, - threshold=threshold, - smooth=smooth, - ) - elif isinstance(self.failsafe, (str, int, float)): - self.failsafe = FailSafe(strategy=FailSafeStrategy.RTN, threshold=self.failsafe) - elif not isinstance(self.failsafe, FailSafe): - raise ValueError("QuantizeConfig: `failsafe` must be a FailSafe config, dict, string, int, float, or None.") - - if self.failsafe is not None: - if isinstance(self.failsafe.strategy, str): - try: - self.failsafe.strategy = FailSafeStrategy(self.failsafe.strategy.lower()) - except ValueError as exc: - raise ValueError( - f"QuantizeConfig: `failsafe.strategy` must be one of {[v.value for v in FailSafeStrategy]}." - ) from exc - elif not isinstance(self.failsafe.strategy, FailSafeStrategy): - raise ValueError( - f"QuantizeConfig: `failsafe.strategy` must be one of {[v.value for v in FailSafeStrategy]}." - ) - - self.failsafe.smooth = _parse_smooth_method(self.failsafe.smooth) + self.fallback = _normalize_fallback(self.fallback) - if self.bits not in fields_info[0].metadata["choices"]: + valid_bit_widths = fields_info[0].metadata["choices"] + if quant_bits_width(self.bits) not in valid_bit_widths: raise ValueError(f"QuantizeConfig: `bits` must be in the set of `{fields_info[0].metadata['choices']}`.") if self.dynamic is not None: self.dynamic = { - **{k: v for k, v in self.dynamic.items() if k.startswith("-")}, # 先添加以 "-" 开头的键 - **{k: v for k, v in self.dynamic.items() if not k.startswith("-")}, # 然后添加其他键 + **{k: v for k, v in self.dynamic.items() if k.startswith('-')}, + **{k: v for k, v in self.dynamic.items() if not k.startswith('-')}, } for layer, layer_dict in self.dynamic.items(): - for key, value in layer_dict.items(): - if key == "bits" and value not in fields_info[0].metadata["choices"]: - raise ValueError( - f"QuantizeConfig: Layer `{layer}` only support quantization of `{fields_info[0].metadata['choices']}` bits." - ) - elif key == "group_size" and value != -1 and value <= 0: - raise ValueError( - "QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`." - ) + self._normalize_dynamic_layer_config( + layer, + layer_dict, + valid_bit_widths=valid_bit_widths, + checkpoint_format=format_family, + ) if self.group_size != -1 and self.group_size <= 0: - raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.") - - if not (0 < self.damp_percent < 1): - raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.") - - if self.damp_auto_increment < 0: - raise ValueError("QuantizeConfig:: `damp_auto_increment` must greater than 0.") - - if self.hessian is None: - self.hessian = HessianConfig() - elif isinstance(self.hessian, dict): - self.hessian = HessianConfig(**self.hessian) - elif not isinstance(self.hessian, HessianConfig): - raise ValueError("QuantizeConfig: `hessian` must be a HessianConfig, dict, or None.") - - if self.gptaq is None: - pass - elif isinstance(self.gptaq, dict): - self.gptaq = GPTAQConfig(**self.gptaq) - elif not isinstance(self.gptaq, GPTAQConfig): - raise ValueError("QuantizeConfig: `gptaq` must be a GPTAQConfig, dict, or None.") - - # resolve activation ordering compatibility and defaults - desc_act_user_value = self.desc_act - act_group_aware_user_value = self.act_group_aware - - if desc_act_user_value is None: - # GPTQ defaults to higher quality ordering disabled, others retain legacy default - self.desc_act = False if self.quant_method == METHOD.GPTQ else True - elif isinstance(desc_act_user_value, bool): - self.desc_act = desc_act_user_value - else: - self.desc_act = bool(desc_act_user_value) + raise ValueError(_resolve_dynamic_group_size_error()) - if act_group_aware_user_value is None: - # auto-enable for GPTQ unless user explicitly disables it - self.act_group_aware = self.quant_method == METHOD.GPTQ - elif isinstance(act_group_aware_user_value, bool): - self.act_group_aware = act_group_aware_user_value - else: - self.act_group_aware = bool(act_group_aware_user_value) - - self._resolve_activation_ordering(desc_act_user_value, act_group_aware_user_value) - - # validate hybrid act order - if self.act_group_aware and self.desc_act: - raise ValueError("QuantizeConfig:: `act_group_aware` == `True` requires `desc_act` == `False`.") + if self.desc_act is None: + self.desc_act = self.default_desc_act() + elif not isinstance(self.desc_act, bool): + self.desc_act = bool(self.desc_act) - # validate meta if self.meta is not None: if not isinstance(self.meta, dict): raise ValueError("QuantizeConfig: `meta` must be a dictionary") - for key, value in self.meta.items(): + for key in self.meta: if not isinstance(key, str): raise ValueError("QuantizeConfig: `meta` keys must be strings") else: self.meta = {} - # adapter normalize self.adapter = normalize_adapter(self.adapter) - # print(f"adapter: {self.adapter}") - if self.offload_to_disk and not self.offload_to_disk_path: path_key = f"{get_random_string()}-{get_random_string()}" self.offload_to_disk_path = f"./gptqmodel_offload/{path_key}/" log.info(f"QuantizeConfig: offload_to_disk_path auto set to `{self.offload_to_disk_path}`") - if isinstance(self.vram_strategy, str): - try: - self.vram_strategy = VramStrategy(self.vram_strategy.lower()) - except ValueError as exc: - raise ValueError( - f"QuantizeConfig: `vram_strategy` must be one of {[v.value for v in VramStrategy]}." - ) from exc - elif not isinstance(self.vram_strategy, VramStrategy): - raise ValueError(f"QuantizeConfig: `vram_strategy` must be one of {[v.value for v in VramStrategy]}.") - - if isinstance(self.gc_mode, str): - try: - self.gc_mode = GcMode(self.gc_mode.lower()) - except ValueError as exc: - raise ValueError(f"QuantizeConfig: `gc_mode` must be one of {[v.value for v in GcMode]}.") from exc - elif not isinstance(self.gc_mode, GcMode): - raise ValueError(f"QuantizeConfig: `gc_mode` must be one of {[v.value for v in GcMode]}.") + self.dense_vram_strategy = _normalize_dense_vram_strategy(self.dense_vram_strategy) + self.dense_vram_strategy_devices = _normalize_strategy_devices( + self.dense_vram_strategy_devices, + field_name="dense_vram_strategy_devices", + ) + self.moe_vram_strategy = _normalize_moe_vram_strategy(self.moe_vram_strategy) + self.moe_vram_strategy_devices = _normalize_strategy_devices( + self.moe_vram_strategy_devices, + field_name="moe_vram_strategy_devices", + ) + self.gc_mode = _normalize_gc_mode(self.gc_mode) + self.moe = _normalize_moe_config(self.moe) + + # Normalize calibration_data_device to canonical form if it's a specific device (not "balanced") + if self.calibration_data_device is not None: + if isinstance(self.calibration_data_device, str): + if self.calibration_data_device.lower() == "balanced": + self.calibration_data_device = "balanced" + else: + # Import here to avoid circular import + from ..utils.looper_helpers import _canonical_device + self.calibration_data_device = _canonical_device(torch.device(self.calibration_data_device)) + elif isinstance(self.calibration_data_device, torch.device): + # Also normalize when passed as torch.device object + from ..utils.looper_helpers import _canonical_device + self.calibration_data_device = _canonical_device(self.calibration_data_device) def extension_set(self, key: str, value: Any): if self.adapter is None: self.adapter = {} - self.adapter[key.lower()] = value - def _resolve_activation_ordering( - self, - desc_act_user_value: Optional[bool], - act_group_aware_user_value: Optional[bool], - ) -> None: - """Normalize defaults and enforce compatibility between desc_act and act_group_aware.""" - - desc_act_enabled_by_user = bool(desc_act_user_value) if desc_act_user_value is not None else False - act_group_aware_enabled_by_user = ( - bool(act_group_aware_user_value) if act_group_aware_user_value is not None else False - ) - - if desc_act_enabled_by_user and act_group_aware_user_value is not None and act_group_aware_enabled_by_user: - raise ValueError( - "QuantizeConfig:: `act_group_aware` == `True` requires `desc_act` == `False` when both are explicitly set." - ) - - if desc_act_enabled_by_user and act_group_aware_user_value is None and self.act_group_aware: - log.warn( - "QuantizeConfig: `desc_act=True` automatically disables `act_group_aware`. " - "Set `act_group_aware=False` explicitly to silence this warning." - ) - self.act_group_aware = False - def extension_get(self, key: str) -> Any: return self.adapter.get(key.lower()) if self.adapter else None @@ -1043,15 +2572,17 @@ def meta_get(self, key: str) -> Any: return self.meta.get(key) def dynamic_get( - self, layer_name: str, key: str = None, default: Union[int, bool, float] = None, sub_key: str = None + self, + layer_name: str, + key: str = None, + default: Union[int, bool, float] = None, + sub_key: str = None, ) -> Union[Dict, int, bool, float]: return dynamic_get(self.dynamic, layer_name, key, default, sub_key) - # versionable is a meta.property that pairs value with version i.e "value:1.0.0" def meta_set_versionable(self, key: str, value: List[str]): self.meta_set(key, value) - # versionable is a meta.property that pairs value with version i.e "value:1.0.0" def meta_get_versionable(self, key: str) -> List[Tuple[str, str]]: values = self.meta_get(key) if values is None: @@ -1065,39 +2596,40 @@ def meta_get_versionable(self, key: str) -> List[Tuple[str, str]]: result.append((parts[0].lower(), parts[1].lower())) return result - # is quantized model quantized or packed by gptqmodel version with gptaq format code def is_quantized_by_gptaq(self) -> bool: - # check meta.quantizer result = self.meta_get_versionable(META_FIELD_QUANTIZER) if len(result) > 0: for producer, _version in result: if producer == META_QUANTIZER_GPTQMODEL: return version.parse(_version) >= version.parse(MIN_VERSION_WITH_V2) + return False + def is_quantized_by_foem(self) -> bool: + result = self.meta_get_versionable(META_FIELD_QUANTIZER) + if len(result) > 0: + for producer, _version in result: + if producer == META_QUANTIZER_GPTQMODEL: + return version.parse(_version) >= version.parse(MIN_VERSION_WITH_V2) return False def extract_adapter_rank_patterns(self) -> Optional[Dict[str, int]]: adapter_rank_patterns = {} - - # no rank can be had if there is no dynamic or adapter if not self.dynamic or not self.adapter: return adapter_rank_patterns - # override format: `{ "adapter": { "rank": 512 } }` for k, v in self.dynamic.items(): - adapter_override = v.get("adapter", None) # TODO use const, not str + adapter_override = v.get("adapter", None) if adapter_override and isinstance(adapter_override, Dict): rank = adapter_override.get("rank", None) if rank and isinstance(rank, int): - # need to strip `+:` positive prefix - adapter_rank_patterns[k.lstrip("+:")] = rank # TODO use const, not str + adapter_rank_patterns[k.lstrip("+:")] = rank return adapter_rank_patterns def save_pretrained(self, save_dir: str, **kwargs): with open(join(save_dir, QUANT_CONFIG_FILENAME), "w", encoding="utf-8") as f: - d = self.to_dict() - json_str = json.dumps(d, indent=2) + payload = self.to_dict() + json_str = json.dumps(payload, indent=2) log.info(f"Saved Quantize Config: \n{json_str}") f.write(json_str) @@ -1113,7 +2645,7 @@ def gptq_pro( damp_auto_increment: float = 0.01, gptaq_alpha: Optional[float] = None, gptaq_device: Union[str, torch.device] = "auto", - failsafe: Optional[Union[FailSafe, Dict[str, Any], str, int, float]] = None, + failsafe: Optional[Union[Fallback, Dict[str, Any], str, int, float]] = None, **kwargs, ) -> "QuantizeConfig": """ @@ -1127,13 +2659,21 @@ def gptq_pro( """ if "quant_method" in kwargs and kwargs["quant_method"] != METHOD.GPTQ: raise ValueError("QuantizeConfig.gptq_pro() only supports `quant_method=METHOD.GPTQ`.") + if METHOD_FIELD_CODE in kwargs and kwargs[METHOD_FIELD_CODE] != METHOD.GPTQ: + raise ValueError("QuantizeConfig.gptq_pro() only supports `method=METHOD.GPTQ`.") if "format" in kwargs and kwargs["format"] not in QUANT_METHOD_FORMAT_MAPPING[METHOD.GPTQ]: raise ValueError("QuantizeConfig.gptq_pro() only supports GPTQ-compatible output formats.") + fallback = kwargs.pop("fallback", None) + if fallback is None and "failsafe" in kwargs: + fallback = kwargs.pop("failsafe") + if fallback is None: + fallback = failsafe + if failsafe is None: - failsafe = FailSafe( - strategy=FailSafeStrategy.RTN, + fallback = Fallback( + strategy=FallbackStrategy.RTN, threshold="0.5%", smooth=SmoothMSE(steps=32, maxshrink=0.9), ) @@ -1146,7 +2686,7 @@ def gptq_pro( "bits": bits, "group_size": group_size, "sym": sym, - "quant_method": METHOD.GPTQ, + METHOD_FIELD_CODE: METHOD.GPTQ, "format": FORMAT.GPTQ, "desc_act": False, "act_group_aware": True, @@ -1154,167 +2694,209 @@ def gptq_pro( "activation_weighted_mse": True, "damp_percent": damp_percent, "damp_auto_increment": damp_auto_increment, - "failsafe": failsafe, + "fallback": fallback, "gptaq": gptaq, } defaults.update(kwargs) return cls(**defaults) @classmethod - # normalize quant config for compat and also performs validation def from_quant_config(cls, quantize_cfg, format: str = None): - valid_formats = {FORMAT.GPTQ, FORMAT.GPTQ_V2, FORMAT.MARLIN, FORMAT.BITBLAS} + valid_formats = set(FORMAT) format_auto_inferred = False - # compat: format can be passed in via from_quantized() if field missing from json + checkpoint_format_hint = quantize_cfg.get(FORMAT_FIELD_CHECKPOINT) if isinstance(quantize_cfg, dict) else None + serialized_format = quantize_cfg.get(FORMAT_FIELD_CODE) if isinstance(quantize_cfg, dict) else None if format: - if format not in valid_formats: - raise ValueError(f"QuantizeConfig: Unknown quantization checkpoint format: {format}.") - if quantize_cfg.get(FORMAT_FIELD_CHECKPOINT): - raise ValueError( - "QuantizeConfig: Conflicting quantization format passed in manually and also exists in model config." - ) - # compat: warn if checkpoint_format is missing - elif quantize_cfg.get(FORMAT_FIELD_CHECKPOINT) is None: + if _looks_like_fp8_fmt(format): + format = _normalize_fp8_fmt(format) + elif _looks_like_bitsandbytes_format(format): + format = _normalize_bitsandbytes_format(format) + else: + format = _normalize_format(format) + if format not in valid_formats: + raise ValueError(f"QuantizeConfig: Unknown quantization checkpoint format: {format}.") + if checkpoint_format_hint is not None or serialized_format is not None: + raise ValueError("QuantizeConfig: Conflicting quantization format passed in manually and also exists in model config.") + elif checkpoint_format_hint is None and serialized_format is None: format_auto_inferred = True - field_names = [field.name for field in fields(cls)] + field_names = _known_quantize_config_field_names() - # FIXME convert awg quantize_config to gptq quantize_config normalized = { - QUANT_METHOD_FIELD: METHOD.GPTQ, - # compat: default to gptq(v1) when loading models + METHOD_FIELD_CODE: METHOD.GPTQ, FORMAT_FIELD_CODE: format if format else FORMAT.GPTQ, } + format_field_present = format is not None + legacy_checkpoint_format = None for key, val in quantize_cfg.items(): key = key.lower() - # remap keys according to compat map + if key == FORMAT_FIELD_COMPAT_MARLIN: + raise ValueError( + "QuantizeConfig: `is_marlin_format` is no longer supported. Replace it with an explicit `format` field." + ) + + if key == FORMAT_FIELD_CHECKPOINT: + if _looks_like_fp8_fmt(val): + legacy_checkpoint_format = _normalize_fp8_fmt(val) + elif _looks_like_bitsandbytes_format(val): + legacy_checkpoint_format = _normalize_bitsandbytes_format(val) + else: + try: + legacy_checkpoint_format = _normalize_gguf_public_format(val) + except ValueError: + legacy_checkpoint_format = None + if legacy_checkpoint_format is None: + legacy_checkpoint_format = _normalize_format(val) + if legacy_checkpoint_format is not None: + checkpoint_format_hint = legacy_checkpoint_format + continue + if key in QUANT_CONFIG_ARG_SYNONYMS and QUANT_CONFIG_ARG_SYNONYMS[key] in field_names: key = QUANT_CONFIG_ARG_SYNONYMS[key] elif key in QUANT_CONFIG_ARG_SYNONYMS_NEGATED and QUANT_CONFIG_ARG_SYNONYMS_NEGATED[key] in field_names: key = QUANT_CONFIG_ARG_SYNONYMS_NEGATED[key] val = not bool(val) - if key == FORMAT_FIELD_CHECKPOINT: - val = val.lower() - - if val in {FORMAT.GPTQ, FORMAT.GPTQ_V2, FORMAT.MARLIN, FORMAT.BITBLAS}: - normalized[key] = val - else: - raise ValueError(f"QuantizeConfig: Unknown quantization format: `{val}`.") - elif key == QUANT_METHOD_FIELD: - val = val.lower() - # compat: some hf models use quant_method=marlin or bitblas - if val == FORMAT.MARLIN: + if key == METHOD_FIELD_CODE: + if isinstance(val, str) and val.lower() == FORMAT.MARLIN: normalized[FORMAT_FIELD_CODE] = FORMAT.MARLIN - elif val == FORMAT.BITBLAS: + elif isinstance(val, str) and val.lower() == FORMAT.BITBLAS: normalized[FORMAT_FIELD_CODE] = FORMAT.BITBLAS - elif val not in {METHOD.GPTQ, METHOD.QQQ, METHOD.AWQ}: - raise ValueError(f"QuantizeConfig: Unknown quantization method: `{val}`.") else: - normalized[QUANT_METHOD_FIELD] = val + normalized[METHOD_FIELD_CODE] = _normalize_quant_method(val) elif key == FORMAT_FIELD_CODE: - normalized[key] = val.lower() if isinstance(val, str) else val - elif key == "failsafe": - normalized[key] = val + format_field_present = True + serialized_format_hint = None + try: + serialized_format_hint = resolve_quant_format( + val, + normalized.get(METHOD_FIELD_CODE), + ) + except ValueError: + serialized_format_hint = None + + format_hint = format or legacy_checkpoint_format or checkpoint_format_hint + if format_hint is not None: + try: + format_hint = resolve_quant_format( + format_hint, + normalized.get(METHOD_FIELD_CODE), + ) + except ValueError: + format_hint = None + if serialized_format_hint in {FORMAT.GGUF, FORMAT.FP8, FORMAT.BITSANDBYTES} or format_hint in { + FORMAT.GGUF, + FORMAT.FP8, + FORMAT.BITSANDBYTES, + }: + normalized[key] = val + else: + normalized[key] = _normalize_format(val) elif key in field_names: normalized[key] = val else: log.info(f"QuantizeConfig: Ignoring unknown parameter in the quantization configuration: {key}.") - # fix method if format is not allowed for the method - fmt = normalized.get(FORMAT_FIELD_CODE) - method = normalized.get(QUANT_METHOD_FIELD) + if not format_field_present and legacy_checkpoint_format is not None: + normalized[FORMAT_FIELD_CODE] = legacy_checkpoint_format - # TODO FIXME qqq compat which didn't have checkpoint_format before merging to gptqmodel - if method == METHOD.QQQ and fmt != FORMAT.QQQ: - log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.QQQ}`") - normalized[FORMAT_FIELD_CODE] = FORMAT.QQQ - fmt = FORMAT.QQQ - - if fmt is not None: - allowed_methods = [m for m, fmts in QUANT_METHOD_FORMAT_MAPPING.items() if fmt in fmts] - if method not in allowed_methods: - if fmt in {FORMAT.GEMM, FORMAT.GEMV, FORMAT.GEMV_FAST}: - new_method = METHOD.AWQ - elif fmt in {FORMAT.GPTQ, FORMAT.GPTQ_V2, FORMAT.BITBLAS}: - new_method = METHOD.GPTQ - elif fmt == FORMAT.QQQ: - new_method = METHOD.QQQ - elif fmt == FORMAT.MARLIN: - new_method = method if method in {METHOD.GPTQ, METHOD.AWQ} else METHOD.GPTQ - else: - new_method = allowed_methods[0] if allowed_methods else METHOD.GPTQ - if new_method != method: - log.warn( - f"QuantizeConfig: `{FORMAT_FIELD_CODE}`=`{fmt}` is incompatible with `{QUANT_METHOD_FIELD}`=`{method}`. Auto-fix method to `{new_method}`." - ) - normalized[QUANT_METHOD_FIELD] = new_method + if quantize_cfg.get(AWQ_PACKING_BACKEND_FIELD) == "llm-awq": + normalized[METHOD_FIELD_CODE] = METHOD.AWQ + normalized[FORMAT_FIELD_CODE] = FORMAT.LLM_AWQ + normalized[PACK_DTYPE_FIELD] = torch.int16 + log.info("Detected llm-awq quantization format; FORMAT automatically set to FORMAT.LLM_AWQ.") + + meta_payload = normalized.get(META_FIELD) + meta_field_map = { + "fallback": "fallback", + "hessian": "hessian", + "gptaq": "gptaq", + "foem": "foem", + "weight_only": "weight_only", + "preprocessors": "preprocessors", + "gc_mode": "gc_mode", + "wait_for_submodule_finalizers": "wait_for_submodule_finalizers", + "auto_forward_data_parallel": "auto_forward_data_parallel", + "dense_vram_strategy": "dense_vram_strategy", + "dense_vram_strategy_devices": "dense_vram_strategy_devices", + "moe_vram_strategy": "moe_vram_strategy", + "moe_vram_strategy_devices": "moe_vram_strategy_devices", + "moe": "moe", + "offload_to_disk": "offload_to_disk", + "offload_to_disk_path": "offload_to_disk_path", + "pack_impl": "pack_impl", + "mse": "mse", + "activation_weighted_mse": "activation_weighted_mse", + "mock_quantization": "mock_quantization", + "act_group_aware": "act_group_aware", + "true_sequential": "true_sequential", + "damp_percent": "damp_percent", + "damp_auto_increment": "damp_auto_increment", + "opt_rotation_epochs": "opt_rotation_epochs", + "opt_finetune_epochs": "opt_finetune_epochs", + "opt_train_samples": "opt_train_samples", + "opt_validation_samples": "opt_validation_samples", + "opt_batch_size": "opt_batch_size", + "opt_rotation_lr": "opt_rotation_lr", + "opt_weight_lr": "opt_weight_lr", + "opt_quantizer_lr": "opt_quantizer_lr", + "opt_pair_ratio": "opt_pair_ratio", + "opt_seed": "opt_seed", + "opt_optimizer": "opt_optimizer", + "opt_weight_decay": "opt_weight_decay", + "opt_betas": "opt_betas", + "opt_eps": "opt_eps", + "opt_amsgrad": "opt_amsgrad", + "opt_sgd_momentum": "opt_sgd_momentum", + "opt_sgd_dampening": "opt_sgd_dampening", + "opt_sgd_nesterov": "opt_sgd_nesterov", + "opt_fused_rotation": "opt_fused_rotation", + "opt_gradient_checkpointing": "opt_gradient_checkpointing", + "opt_stage_cudagraph": "opt_stage_cudagraph", + "opt_best_state_dtype": "opt_best_state_dtype", + "opt_train_on_noisy_inputs": "opt_train_on_noisy_inputs", + "opt_scope": "opt_scope", + "opt_stage_impl": "opt_stage_impl", + "opt_pair_impl": "opt_pair_impl", + "opt_quantizer_impl": "opt_quantizer_impl", + "opt_channel_scale_clamp_min": "opt_channel_scale_clamp_min", + "opt_channel_scale_clamp_max": "opt_channel_scale_clamp_max", + } + if isinstance(meta_payload, dict): + for normalized_key, meta_key in meta_field_map.items(): + if normalized_key not in normalized and meta_key in meta_payload: + normalized[normalized_key] = meta_payload.get(meta_key) + + target_cls = cls if cls not in {BaseQuantizeConfig, QuantizeConfig} else _resolve_quantize_config_class(normalized) + normalized = _normalize_quantize_config_payload_for_target_cls(target_cls, normalized) + if target_cls is RTNConfig: + normalized = _normalize_rtn_kwargs(normalized) + elif target_cls is GGUFConfig: + normalized = _normalize_gguf_kwargs(normalized) + elif target_cls is FP8Config: + normalized = _normalize_fp8_kwargs(normalized) + elif target_cls is BitsAndBytesConfig: + normalized = _normalize_bitsandbytes_kwargs(normalized) if format_auto_inferred: log.info( - f"QuantizeConfig: `{FORMAT_FIELD_CHECKPOINT}` is missing from the quantization configuration and is automatically inferred to {normalized[FORMAT_FIELD_CODE]}" + f"QuantizeConfig: `{FORMAT_FIELD_CODE}` is missing from the quantization configuration and is automatically inferred to {normalized[FORMAT_FIELD_CODE]}" ) - if normalized[FORMAT_FIELD_CODE] in {FORMAT.BITBLAS}: - # AWQ and Marlin do not reorder the rows. + resolved_format_family = resolve_quant_format( + normalized[FORMAT_FIELD_CODE], + normalized.get(METHOD_FIELD_CODE), + ) + if resolved_format_family in {FORMAT.BITBLAS, FORMAT.BITSANDBYTES}: normalized["desc_act"] = False - if "sym" not in normalized: + if "sym" not in normalized and target_cls not in {GGUFConfig, FP8Config, BitsAndBytesConfig, EXL3Config}: log.warn( "QuantizeConfig: config does not contain `sym` (symmetric quantization). This may result in silent errors. Defaulting to `sym=True`." ) - - meta_payload = normalized.get(META_FIELD) - if "failsafe" not in normalized and isinstance(meta_payload, dict) and "failsafe" in meta_payload: - normalized["failsafe"] = meta_payload.get("failsafe") - if "hessian" not in normalized and isinstance(meta_payload, dict) and "hessian" in meta_payload: - normalized["hessian"] = meta_payload.get("hessian") - if "gptaq" not in normalized and isinstance(meta_payload, dict) and "gptaq" in meta_payload: - normalized["gptaq"] = meta_payload.get("gptaq") - if "mse" not in normalized and isinstance(meta_payload, dict) and "mse" in meta_payload: - normalized["mse"] = meta_payload.get("mse") - if ( - "activation_weighted_mse" not in normalized - and isinstance(meta_payload, dict) - and "activation_weighted_mse" in meta_payload - ): - normalized["activation_weighted_mse"] = meta_payload.get("activation_weighted_mse") - if "act_group_aware" not in normalized and isinstance(meta_payload, dict) and "act_group_aware" in meta_payload: - normalized["act_group_aware"] = meta_payload.get("act_group_aware") - if ( - "mock_quantization" not in normalized - and isinstance(meta_payload, dict) - and "mock_quantization" in meta_payload - ): - normalized["mock_quantization"] = meta_payload.get("mock_quantization") - if "vram_strategy" not in normalized and isinstance(meta_payload, dict) and "vram_strategy" in meta_payload: - normalized["vram_strategy"] = meta_payload.get("vram_strategy") - if "gc_mode" not in normalized and isinstance(meta_payload, dict) and "gc_mode" in meta_payload: - normalized["gc_mode"] = meta_payload.get("gc_mode") - if ( - "wait_for_submodule_finalizers" not in normalized - and isinstance(meta_payload, dict) - and "wait_for_submodule_finalizers" in meta_payload - ): - normalized["wait_for_submodule_finalizers"] = meta_payload.get("wait_for_submodule_finalizers") - if ( - "auto_forward_data_parallel" not in normalized - and isinstance(meta_payload, dict) - and "auto_forward_data_parallel" in meta_payload - ): - normalized["auto_forward_data_parallel"] = meta_payload.get("auto_forward_data_parallel") - - cfg = cls(**normalized) - - if quantize_cfg.get(AWQ_PACKING_BACKEND_FIELD) and quantize_cfg[AWQ_PACKING_BACKEND_FIELD] == "llm-awq": - cfg.quant_method = METHOD.AWQ - cfg.format = FORMAT.LLM_AWQ - cfg.pack_dtype = torch.int16 - log.info("Detected llm-awq quantization format; FORMAT automatically set to FORMAT.LLM_AWQ.") - - return cfg + return target_cls(**_filter_quantize_config_payload_for_target_cls(target_cls, normalized)) @classmethod def from_pretrained(cls, save_dir: str, **kwargs): @@ -1336,143 +2918,94 @@ def from_pretrained(cls, save_dir: str, **kwargs): with open(resolved_config_file, "r", encoding="utf-8") as f: args_from_json = json.load(f) - if transformers_config: args_from_json = args_from_json["quantization_config"] - return cls.from_quant_config(args_from_json, format) + def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None: + return None + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + return None + def to_dict(self): - smooth = None - if self.failsafe is not None and self.failsafe.smooth is not None: - payload = {"type": self.failsafe.smooth.name} - payload["group_size_threshold"] = self.failsafe.smooth.group_size_threshold - if isinstance(self.failsafe.smooth, SmoothPercentile): - payload["percentile"] = self.failsafe.smooth.percentile - elif isinstance(self.failsafe.smooth, SmoothPercentileAsymmetric): - payload["low"] = self.failsafe.smooth.low - payload["high"] = self.failsafe.smooth.high - elif isinstance(self.failsafe.smooth, SmoothMAD): - payload["k"] = self.failsafe.smooth.k - elif isinstance(self.failsafe.smooth, SmoothMSE): - payload["steps"] = self.failsafe.smooth.steps - payload["maxshrink"] = self.failsafe.smooth.maxshrink - elif isinstance(self.failsafe.smooth, SmoothOutlier): - payload["pct"] = self.failsafe.smooth.pct - elif isinstance(self.failsafe.smooth, SmoothSoftNorm): - payload["k"] = self.failsafe.smooth.k - elif isinstance(self.failsafe.smooth, SmoothLog): - payload["percentile"] = self.failsafe.smooth.percentile - payload["mu"] = self.failsafe.smooth.mu - elif isinstance(self.failsafe.smooth, SmoothRowCol): - payload["axis"] = self.failsafe.smooth.axis - smooth = payload + smooth = _serialize_smooth_method(self.fallback.smooth if self.fallback is not None else None) meta_payload = dict(self.meta) if self.meta else {} if self.moe: meta_payload["moe"] = self.moe.to_dict() - if self.failsafe is None: - meta_payload["failsafe"] = None + if self.fallback is None: + meta_payload["fallback"] = None else: - meta_payload["failsafe"] = { - "strategy": self.failsafe.strategy.value - if isinstance(self.failsafe.strategy, FailSafeStrategy) - else self.failsafe.strategy, - "threshold": self.failsafe.threshold, + meta_payload["fallback"] = { + "strategy": ( + self.fallback.strategy.value + if isinstance(self.fallback.strategy, FallbackStrategy) + else self.fallback.strategy + ), + "threshold": self.fallback.threshold, "smooth": smooth, } - if self.gptaq is None: - meta_payload["gptaq"] = None - else: - device = self.gptaq.device - device_value = device if isinstance(device, str) else str(device) - meta_payload["gptaq"] = { - "alpha": self.gptaq.alpha, - "device": device_value, - } meta_payload["offload_to_disk"] = self.offload_to_disk meta_payload["offload_to_disk_path"] = self.offload_to_disk_path meta_payload["pack_impl"] = self.pack_impl - meta_payload["mse"] = self.mse - meta_payload["activation_weighted_mse"] = self.activation_weighted_mse - meta_payload["mock_quantization"] = self.mock_quantization - meta_payload["act_group_aware"] = self.act_group_aware - meta_payload["gc_mode"] = self.gc_mode + meta_payload["gc_mode"] = self.gc_mode.value if isinstance(self.gc_mode, GcMode) else self.gc_mode meta_payload["wait_for_submodule_finalizers"] = self.wait_for_submodule_finalizers meta_payload["auto_forward_data_parallel"] = self.auto_forward_data_parallel - meta_payload["hessian"] = { - "chunk_size": self.hessian.chunk_size, - "chunk_bytes": self.hessian.chunk_bytes, - "staging_dtype": str(self.hessian.staging_dtype).split(".")[-1], - } - meta_payload["vram_strategy"] = ( - self.vram_strategy.value if isinstance(self.vram_strategy, VramStrategy) else self.vram_strategy + meta_payload["dense_vram_strategy"] = ( + self.dense_vram_strategy.value + if isinstance(self.dense_vram_strategy, VramStrategy) + else self.dense_vram_strategy ) + meta_payload["dense_vram_strategy_devices"] = self.dense_vram_strategy_devices + meta_payload["moe_vram_strategy"] = ( + self.moe_vram_strategy.value + if isinstance(self.moe_vram_strategy, VramStrategy) + else self.moe_vram_strategy + ) + meta_payload["moe_vram_strategy_devices"] = self.moe_vram_strategy_devices + self._update_meta_payload(meta_payload) out = { - "bits": self.bits, + "bits": serialize_quant_bits(self.bits), "dynamic": self.dynamic, "group_size": self.group_size, "desc_act": self.desc_act, "lm_head": self.lm_head, - QUANT_METHOD_FIELD: self.quant_method, + METHOD_FIELD_CODE: self.method, + QUANT_METHOD_FIELD: self.method, + FORMAT_FIELD_CODE: self.format, FORMAT_FIELD_CHECKPOINT: self.format, - # torch.dtype convert to string PACK_DTYPE_FIELD: str(self.pack_dtype).split(".")[-1], META_FIELD: meta_payload, - # DO NOT EXPORT Adapter to config/json since adapter can be swapped out/in - # ADAPTER_FIELD: self.adapter.to_dict() if self.adapter else None, - # DO NOT EXPORT compute_device_filter since functions are not serializable } - - if self.quant_method == METHOD.AWQ: - out["zero_point"] = not self.sym - # awq compat with vllm/sglang/transformers loaders - out["version"] = self.format - out[FORMAT_FIELD_CODE] = self.format - else: - out["sym"] = self.sym - if self.quant_method == METHOD.GPTQ: - out[FORMAT_FIELD_CODE] = self.format + self._update_output_payload(out) dynamic = out["dynamic"] if dynamic: - # dynamic adapter config is only used in the quantize phase and is deleted when saving. for _, v in dynamic.items(): v.pop("adapter", None) + if "bits" in v: + v["bits"] = serialize_quant_bits(v["bits"]) - # simplify: clean keys where the value is None or empty [list, dict] out = {k: v for k, v in out.items() if v is not None and (v not in [None, {}])} - dict_scale_dtype_to_str(out) return out - # TODO FIX ME, g_idx int32 per infeature but infeature count is per module def calculate_bits_per_weight(self): + bit_width = quant_bits_width(self.bits) if self.group_size != -1: - # naive bits is - # mlp.down_proj.g_idx: I32 - # mlp.down_proj.qweight: I32 - # mlp.down_proj.qzeros: I32 - # mlp.down_proj.scales: F16 - per_group_bits = self.group_size * self.bits # qweight: packed by group_size - per_group_bits += 16 # scales fp16: one per group - per_group_bits += self.bits # qzeros: one per group - # FIX ME: g_idx is I32, one per infeature - per_group_bits += 4 # ESTIMATE for g_idx int32: one per features/group_size item + per_group_bits = self.group_size * bit_width + per_group_bits += 16 + per_group_bits += bit_width + per_group_bits += 4 bpw = per_group_bits / self.group_size - - # normally g_idx (int32 allocated one per in_feature) is allocated in device memory - # but each module may have different infeatures we don't have enouch ctx here, use estimated `0.1` for now bpw += 0.1 else: - # there is only one scale int32 + one qzero int32 per entire module so overall it contributes to close to 0 bpw - bpw = self.bits - log.info( - f"Estimated Quantization BPW (bits per weight): {bpw} bpw, based on [bits: {self.bits}, group_size: {self.group_size}]" - ) + bpw = bit_width + log.info(f"Estimated Quantization BPW (bits per weight): {bpw} bpw, based on [bits: {self.bits}, group_size: {self.group_size}]") def moe_routing_override(self, num_experts: int) -> Union[int, None]: if self.moe is None: @@ -1484,12 +3017,1137 @@ def moe_routing_bypass(self) -> bool: return False return self.moe.routing_bypass() + def uses_weight_only_lifecycle(self) -> bool: + return False + + def requires_calibration_dataset(self) -> bool: + return not self.uses_weight_only_lifecycle() + + def quant_linear_init_kwargs(self) -> Dict[str, Any]: + return {} + + +@dataclass +class PreProcessorConfig(BaseQuantizeConfig): + preprocessors: Optional[List[Union[BasePreProcessorConfig, Dict[str, Any], str]]] = field(default_factory=list) + smoother: Optional[Union[SmootherConfig, SmoothMethod, Dict[str, Any], str]] = field(default=None) + # Backward-compatible alias. New code should use `smoother`. + smooth: Optional[Union[SmoothMethod, Dict[str, Any], str]] = field(default=None, repr=False) + + def _normalize_preprocessor_state(self) -> None: + self.preprocessors = _normalize_preprocessors(self.preprocessors) + + smoother_payload = self.smoother if self.smoother is not None else self.smooth + self.smoother = _normalize_smoother_config(smoother_payload) + + if self.smoother is None: + for preprocessor in self.preprocessors: + if isinstance(preprocessor, SmootherConfig): + self.smoother = preprocessor + break + + non_smoother_preprocessors = [ + preprocessor for preprocessor in self.preprocessors if not isinstance(preprocessor, SmootherConfig) + ] + if self.smoother is not None: + non_smoother_preprocessors.append(self.smoother) + self.preprocessors = non_smoother_preprocessors + _validate_unique_preprocessors(self.preprocessors) + self.smooth = self.resolve_smooth_method() + + def __post_init__(self): + self._normalize_preprocessor_state() + super().__post_init__() + + def resolve_smooth_method(self) -> Optional[SmoothMethod]: + if self.smoother is None: + return None + return self.smoother.smooth + + def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None: + if self.preprocessors: + meta_payload["preprocessors"] = [preprocessor.to_dict() for preprocessor in self.preprocessors] + + +@dataclass +class QuantizeConfig(BaseQuantizeConfig, metaclass=QuantizeConfigMeta): + """Backward-compatible quantization config factory. + + Direct construction dispatches to a concrete method-specific config class. + """ + + +@dataclass +class GPTQConfig(PreProcessorConfig): + damp_percent: Optional[float] = field(default=None) + damp_auto_increment: Optional[float] = field(default=None) + act_group_aware: Optional[bool] = field(default=None) + static_groups: bool = field(default=False) + mse: float = field(default=0.0) + activation_weighted_mse: bool = field(default=False) + gptaq: Optional[GPTAQConfig] = field(default=None) + foem: Optional[FOEMConfig] = field(default=None) + mock_quantization: bool = field( + default=False, + metadata={"help": "Skip heavy computations for fast model loading validation"}, + ) + hessian: Optional[HessianConfig] = field(default_factory=HessianConfig) + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.GPTQ,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return GPTQ_EXPORT_FORMATS + + def default_desc_act(self) -> bool: + return False + + def __post_init__(self): + desc_act_user_value = self.desc_act + act_group_aware_user_value = self.act_group_aware + super().__post_init__() + + if self.damp_percent is None: + self.damp_percent = _default_damp_percent(self.method) + if self.damp_auto_increment is None: + self.damp_auto_increment = _default_damp_auto_increment(self.method) + if not (0 < self.damp_percent < 1): + raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.") + if self.damp_auto_increment < 0: + raise ValueError("QuantizeConfig:: `damp_auto_increment` must greater than 0.") + + self.hessian = _normalize_hessian(self.hessian) + self.gptaq = _normalize_gptaq(self.gptaq) + self.foem = _normalize_foem(self.foem) + + if act_group_aware_user_value is None: + self.act_group_aware = self.method == METHOD.GPTQ + elif not isinstance(act_group_aware_user_value, bool): + self.act_group_aware = bool(act_group_aware_user_value) + + self._resolve_activation_ordering(desc_act_user_value, act_group_aware_user_value) + if self.act_group_aware and self.desc_act: + raise ValueError("QuantizeConfig:: `act_group_aware` == `True` requires `desc_act` == `False`.") + + def _resolve_activation_ordering( + self, + desc_act_user_value: Optional[bool], + act_group_aware_user_value: Optional[bool], + ) -> None: + desc_act_enabled_by_user = bool(desc_act_user_value) if desc_act_user_value is not None else False + act_group_aware_enabled_by_user = ( + bool(act_group_aware_user_value) if act_group_aware_user_value is not None else False + ) + + if desc_act_enabled_by_user and act_group_aware_user_value is not None and act_group_aware_enabled_by_user: + raise ValueError( + "QuantizeConfig:: `act_group_aware` == `True` requires `desc_act` == `False` when both are explicitly set." + ) + + if desc_act_enabled_by_user and act_group_aware_user_value is None and self.act_group_aware: + log.warn( + "QuantizeConfig: `desc_act=True` automatically disables `act_group_aware`. " + "Set `act_group_aware=False` explicitly to silence this warning." + ) + self.act_group_aware = False + + def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None: + if self.gptaq is None: + meta_payload["gptaq"] = None + elif self.foem is None: + device = self.gptaq.device + meta_payload["gptaq"] = { + "alpha": self.gptaq.alpha, + "device": device if isinstance(device, str) else str(device), + } + else: + device = self.foem.device + meta_payload["foem"] = { + "alpha": self.foem.alpha, + "beta": self.foem.beta, + "device": device if isinstance(device, str) else str(device), + } + + meta_payload["mse"] = self.mse + meta_payload["activation_weighted_mse"] = self.activation_weighted_mse + meta_payload["mock_quantization"] = self.mock_quantization + meta_payload["act_group_aware"] = self.act_group_aware + meta_payload["hessian"] = { + "chunk_size": self.hessian.chunk_size, + "chunk_bytes": self.hessian.chunk_bytes, + "staging_dtype": str(self.hessian.staging_dtype).split(".")[-1], + } + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out["sym"] = self.sym + out[FORMAT_FIELD_CODE] = self.format + + +@dataclass +class AWQConfig(PreProcessorConfig): + method: METHOD = field(default=METHOD.AWQ) + format: FORMAT = field(default=FORMAT.GEMM) + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.AWQ,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return AWQ_EXPORT_FORMATS + + def default_desc_act(self) -> bool: + # AWQ runtimes do not use GPTQ-style activation reordering unless the + # checkpoint explicitly asks for it. + return False + + def __post_init__(self): + self.method = _normalize_quant_method(self.method) + self.format = _normalize_format(self.format) + if self.format not in self.supported_export_formats(): + log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.GEMM}`") + self.format = FORMAT.GEMM + super().__post_init__() + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out["zero_point"] = not self.sym + out["version"] = self.format + out[FORMAT_FIELD_CODE] = self.format + + +@dataclass +class ParoConfig(PreProcessorConfig): + method: METHOD = field(default=METHOD.PARO) + format: FORMAT = field(default=FORMAT.PAROQUANT) + krot: int = field(default=8) + opt_rotation_epochs: int = field(default=10) + opt_finetune_epochs: int = field(default=10) + opt_train_samples: int = field(default=2048) + opt_validation_samples: int = field(default=64) + opt_batch_size: int = field(default=64) + opt_rotation_lr: float = field(default=0.05) + opt_weight_lr: float = field(default=1e-5) + opt_quantizer_lr: float = field(default=1e-6) + opt_pair_ratio: float = field(default=0.5) + opt_seed: int = field(default=0) + opt_optimizer: str = field(default="adamw") + opt_weight_decay: float = field(default=0.01) + opt_betas: Tuple[float, float] = field(default=(0.9, 0.95)) + opt_eps: float = field(default=1e-10) + opt_amsgrad: bool = field(default=False) + opt_sgd_momentum: float = field(default=0.0) + opt_sgd_dampening: float = field(default=0.0) + opt_sgd_nesterov: bool = field(default=False) + opt_fused_rotation: bool = field(default=True) + opt_gradient_checkpointing: Optional[bool] = field(default=None) + opt_stage_cudagraph: bool = field(default=True) + opt_best_state_dtype: Union[str, torch.dtype] = field(default="fp32") + opt_train_on_noisy_inputs: bool = field(default=False) + opt_scope: str = field(default="module") + opt_stage_impl: str = field(default="fast") + opt_pair_impl: str = field(default="fast") + opt_quantizer_impl: str = field(default="reference") + opt_channel_scale_clamp_min: float = field(default=PAROQUANT_OPT_SCALE_CLAMP_MIN_DEFAULT) + opt_channel_scale_clamp_max: float = field(default=PAROQUANT_OPT_SCALE_CLAMP_MAX_DEFAULT) + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.PARO,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return PAROQUANT_EXPORT_FORMATS + + @staticmethod + def default_opt_gradient_checkpointing_for_scope(opt_scope: str) -> bool: + """Enable activation checkpointing by default only for whole-layer optimization.""" + return str(opt_scope).strip().lower() == "layer" + + def __post_init__(self): + self.method = _normalize_quant_method(self.method) + self.format = _normalize_format(self.format) + if self.format != FORMAT.PAROQUANT: + log.info(f"QuantizeConfig: Auto fix `format` to `{FORMAT.PAROQUANT}`") + self.format = FORMAT.PAROQUANT + super().__post_init__() + self.krot = int(self.krot) + if self.krot <= 0: + raise ValueError("ParoConfig: `krot` must be a positive integer.") + self.opt_rotation_epochs = int(self.opt_rotation_epochs) + self.opt_finetune_epochs = int(self.opt_finetune_epochs) + self.opt_train_samples = int(self.opt_train_samples) + self.opt_validation_samples = int(self.opt_validation_samples) + self.opt_batch_size = int(self.opt_batch_size) + self.opt_rotation_lr = float(self.opt_rotation_lr) + self.opt_weight_lr = float(self.opt_weight_lr) + self.opt_quantizer_lr = float(self.opt_quantizer_lr) + self.opt_pair_ratio = float(self.opt_pair_ratio) + self.opt_seed = int(self.opt_seed) + self.opt_optimizer = str(self.opt_optimizer).strip().lower() + self.opt_weight_decay = float(self.opt_weight_decay) + if not isinstance(self.opt_betas, (list, tuple)) or len(self.opt_betas) != 2: + raise ValueError("ParoConfig: `opt_betas` must be a 2-tuple/list of floats.") + self.opt_betas = (float(self.opt_betas[0]), float(self.opt_betas[1])) + self.opt_eps = float(self.opt_eps) + self.opt_amsgrad = bool(self.opt_amsgrad) + self.opt_sgd_momentum = float(self.opt_sgd_momentum) + self.opt_sgd_dampening = float(self.opt_sgd_dampening) + self.opt_sgd_nesterov = bool(self.opt_sgd_nesterov) + self.opt_fused_rotation = bool(self.opt_fused_rotation) + self.opt_scope = str(self.opt_scope).strip().lower() + checkpointing = self.opt_gradient_checkpointing + if isinstance(checkpointing, str): + normalized_checkpointing = checkpointing.strip().lower() + if normalized_checkpointing in {"1", "true", "yes", "on", "y", "t"}: + checkpointing = True + elif normalized_checkpointing in {"0", "false", "no", "off", "n", "f"}: + checkpointing = False + else: + raise ValueError( + "ParoConfig: `opt_gradient_checkpointing` string values must be one of " + "{'1','0','true','false','yes','no','on','off','y','n','t','f'}." + ) + if checkpointing is None: + checkpointing = self.default_opt_gradient_checkpointing_for_scope(self.opt_scope) + self.opt_gradient_checkpointing = bool(checkpointing) + self.opt_stage_cudagraph = bool(self.opt_stage_cudagraph) + self.opt_best_state_dtype = _normalize_paroquant_best_state_dtype(self.opt_best_state_dtype) + self.opt_train_on_noisy_inputs = bool(self.opt_train_on_noisy_inputs) + self.opt_stage_impl = str(self.opt_stage_impl).strip().lower() + self.opt_pair_impl = str(self.opt_pair_impl).strip().lower() + self.opt_quantizer_impl = str(self.opt_quantizer_impl).strip().lower() + self.opt_channel_scale_clamp_min = float(self.opt_channel_scale_clamp_min) + self.opt_channel_scale_clamp_max = float(self.opt_channel_scale_clamp_max) + if self.opt_rotation_epochs < 0 or self.opt_finetune_epochs < 0: + raise ValueError("ParoConfig: optimization epochs must be non-negative.") + if self.opt_train_samples <= 0 or self.opt_validation_samples <= 0: + raise ValueError("ParoConfig: optimization sample counts must be positive.") + if self.opt_batch_size <= 0: + raise ValueError("ParoConfig: `opt_batch_size` must be positive.") + if self.opt_rotation_lr <= 0 or self.opt_weight_lr <= 0 or self.opt_quantizer_lr <= 0: + raise ValueError("ParoConfig: optimization learning rates must be positive.") + if not (0.0 < self.opt_pair_ratio <= 0.5): + raise ValueError("ParoConfig: `opt_pair_ratio` must be in the interval (0, 0.5].") + if self.opt_optimizer not in {"adamw", "adam", "sgd"}: + raise ValueError("ParoConfig: `opt_optimizer` must be one of {'adamw', 'adam', 'sgd'}.") + if self.opt_weight_decay < 0: + raise ValueError("ParoConfig: `opt_weight_decay` must be non-negative.") + if self.opt_eps <= 0: + raise ValueError("ParoConfig: `opt_eps` must be positive.") + if not all(0.0 <= beta < 1.0 for beta in self.opt_betas): + raise ValueError("ParoConfig: `opt_betas` values must be in the interval [0, 1).") + if self.opt_sgd_momentum < 0: + raise ValueError("ParoConfig: `opt_sgd_momentum` must be non-negative.") + if self.opt_sgd_dampening < 0: + raise ValueError("ParoConfig: `opt_sgd_dampening` must be non-negative.") + if self.opt_sgd_nesterov and self.opt_sgd_momentum <= 0: + raise ValueError("ParoConfig: `opt_sgd_nesterov=True` requires `opt_sgd_momentum > 0`.") + if self.opt_sgd_nesterov and self.opt_sgd_dampening != 0: + raise ValueError("ParoConfig: `opt_sgd_nesterov=True` requires `opt_sgd_dampening == 0`.") + if self.opt_scope not in {"module", "compute_block", "layer"}: + raise ValueError("ParoConfig: `opt_scope` must be one of {'module', 'compute_block', 'layer'}.") + if self.opt_stage_impl not in {"fast", "reference"}: + raise ValueError("ParoConfig: `opt_stage_impl` must be one of {'fast', 'reference'}.") + if self.opt_pair_impl not in {"fast", "reference"}: + raise ValueError("ParoConfig: `opt_pair_impl` must be one of {'fast', 'reference'}.") + if self.opt_quantizer_impl not in {"fast", "reference"}: + raise ValueError("ParoConfig: `opt_quantizer_impl` must be one of {'fast', 'reference'}.") + if self.opt_channel_scale_clamp_min <= 0 or self.opt_channel_scale_clamp_max <= 0: + raise ValueError("ParoConfig: scale clamp bounds must be positive.") + if self.opt_channel_scale_clamp_min >= self.opt_channel_scale_clamp_max: + raise ValueError( + "ParoConfig: `opt_channel_scale_clamp_min` must be smaller than " + "`opt_channel_scale_clamp_max`." + ) + + def quant_linear_init_kwargs(self) -> Dict[str, Any]: + return { + "krot": self.krot, + } + + def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None: + meta_payload["opt_rotation_epochs"] = self.opt_rotation_epochs + meta_payload["opt_finetune_epochs"] = self.opt_finetune_epochs + meta_payload["opt_train_samples"] = self.opt_train_samples + meta_payload["opt_validation_samples"] = self.opt_validation_samples + meta_payload["opt_batch_size"] = self.opt_batch_size + meta_payload["opt_rotation_lr"] = self.opt_rotation_lr + meta_payload["opt_weight_lr"] = self.opt_weight_lr + meta_payload["opt_quantizer_lr"] = self.opt_quantizer_lr + meta_payload["opt_pair_ratio"] = self.opt_pair_ratio + meta_payload["opt_seed"] = self.opt_seed + meta_payload["opt_optimizer"] = self.opt_optimizer + meta_payload["opt_weight_decay"] = self.opt_weight_decay + meta_payload["opt_betas"] = list(self.opt_betas) + meta_payload["opt_eps"] = self.opt_eps + meta_payload["opt_amsgrad"] = self.opt_amsgrad + meta_payload["opt_sgd_momentum"] = self.opt_sgd_momentum + meta_payload["opt_sgd_dampening"] = self.opt_sgd_dampening + meta_payload["opt_sgd_nesterov"] = self.opt_sgd_nesterov + meta_payload["opt_fused_rotation"] = self.opt_fused_rotation + meta_payload["opt_gradient_checkpointing"] = self.opt_gradient_checkpointing + meta_payload["opt_stage_cudagraph"] = self.opt_stage_cudagraph + meta_payload["opt_best_state_dtype"] = self.opt_best_state_dtype + meta_payload["opt_train_on_noisy_inputs"] = self.opt_train_on_noisy_inputs + meta_payload["opt_scope"] = self.opt_scope + meta_payload["opt_stage_impl"] = self.opt_stage_impl + meta_payload["opt_pair_impl"] = self.opt_pair_impl + meta_payload["opt_quantizer_impl"] = self.opt_quantizer_impl + meta_payload["opt_channel_scale_clamp_min"] = self.opt_channel_scale_clamp_min + meta_payload["opt_channel_scale_clamp_max"] = self.opt_channel_scale_clamp_max + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out["zero_point"] = not self.sym + out["krot"] = self.krot + out[FORMAT_FIELD_CODE] = self.format + + +@dataclass +class QQQConfig(GPTQConfig): + method: METHOD = field(default=METHOD.QQQ) + format: FORMAT = field(default=FORMAT.QQQ) + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.QQQ,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return QQQ_EXPORT_FORMATS + + def default_desc_act(self) -> bool: + return True + + +@dataclass +class FP8Config(PreProcessorConfig): + bits: int = field(default=8, metadata={"choices": [8]}) + method: METHOD = field(default=METHOD.FP8) + format: Optional[str] = field(default="float8_e4m3fn") + group_size: int = field(default=-1) + desc_act: Optional[bool] = field(default=False) + sym: bool = field(default=True) + weight_scale_method: str = field(default="row") + weight_block_size: Optional[Union[List[int], Tuple[int, int]]] = field(default=None) + weight_scale_semantics: str = field(default="inverse") + + def _resolve_checkpoint_format(self) -> FORMAT: + self.format = _normalize_fp8_fmt(self.format) + return FORMAT.FP8 + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.FP8,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return FP8_EXPORT_FORMATS + + def default_desc_act(self) -> bool: + return False + + def __post_init__(self): + self._normalize_preprocessor_state() + super().__post_init__() + + if self.bits != 8: + raise ValueError("FP8Config: `bits` must be `8`.") + + if self.method != METHOD.FP8: + raise ValueError("FP8Config: `method` must be `fp8`.") + + self.group_size = -1 + self.desc_act = False + self.sym = True + + self.format = _normalize_fp8_fmt(self.format) + block_size = _normalize_fp8_weight_block_size(self.weight_block_size) + self.weight_scale_method = _normalize_fp8_weight_scale_method( + self.weight_scale_method, + weight_block_size=block_size, + ) + self.weight_block_size = list(block_size) if block_size is not None else None + self.weight_scale_semantics = _normalize_fp8_scale_semantics(self.weight_scale_semantics) + + if self.dynamic is not None: + self.dynamic = { + **{k: v for k, v in self.dynamic.items() if k.startswith('-')}, + **{k: v for k, v in self.dynamic.items() if not k.startswith('-')}, + } + for layer, layer_dict in self.dynamic.items(): + self._normalize_dynamic_layer_config( + layer, + layer_dict, + valid_bit_widths=[8], + checkpoint_format=FORMAT.FP8, + ) + + def _normalize_dynamic_layer_config( + self, + layer_name: str, + layer_dict: Dict[str, Any], + *, + valid_bit_widths: List[int], + checkpoint_format: FORMAT, + ) -> None: + del valid_bit_widths, checkpoint_format + if "bits" in layer_dict and int(layer_dict["bits"]) != 8: + raise ValueError(f"FP8Config: layer `{layer_name}` only supports 8-bit FP8 weights.") + if "group_size" in layer_dict and layer_dict["group_size"] not in (-1, None): + raise ValueError("FP8Config: `group_size` is not used; keep it at `-1`.") + + block_size = _normalize_fp8_weight_block_size(layer_dict.get("weight_block_size")) + raw_format = layer_dict.get(FORMAT_FIELD_CODE, layer_dict.get("fmt")) + if raw_format is not None: + layer_dict[FORMAT_FIELD_CODE] = _normalize_fp8_fmt(raw_format) + layer_dict.pop("fmt", None) + if "weight_scale_method" in layer_dict or block_size is not None: + layer_dict["weight_scale_method"] = _normalize_fp8_weight_scale_method( + layer_dict.get("weight_scale_method"), + weight_block_size=block_size, + ) + if "weight_scale_semantics" in layer_dict: + layer_dict["weight_scale_semantics"] = _normalize_fp8_scale_semantics( + layer_dict["weight_scale_semantics"] + ) + if "weight_block_size" in layer_dict: + layer_dict["weight_block_size"] = list(block_size) if block_size is not None else None + + def quant_linear_init_kwargs(self) -> Dict[str, Any]: + return { + "format": self.format, + "weight_scale_method": self.weight_scale_method, + "weight_block_size": self.weight_block_size, + "weight_scale_semantics": self.weight_scale_semantics, + } + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out[FORMAT_FIELD_CODE] = self.format + out["weight_scale_method"] = self.weight_scale_method + out["weight_block_size"] = self.weight_block_size + out["weight_scale_semantics"] = self.weight_scale_semantics + + def uses_weight_only_lifecycle(self) -> bool: + return True + +@dataclass +class BitsAndBytesConfig(PreProcessorConfig): + bits: int = field(default=4, metadata={"choices": [4, 8]}) + method: METHOD = field(default=METHOD.BITSANDBYTES) + format: Optional[str] = field(default=None) + group_size: int = field(default=-1) + desc_act: Optional[bool] = field(default=False) + sym: bool = field(default=True) + block_size: int = field(default=64) + compress_statistics: bool = field(default=True) + + def _resolve_checkpoint_format(self) -> FORMAT: + self.format = _normalize_bitsandbytes_format(self.format, bits=int(self.bits)) + return FORMAT.BITSANDBYTES + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.BITSANDBYTES,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return BITSANDBYTES_EXPORT_FORMATS + + def default_desc_act(self) -> bool: + return False + + def __post_init__(self): + self._normalize_preprocessor_state() + super().__post_init__() + + if self.bits not in {4, 8}: + raise ValueError("BitsAndBytesConfig: `bits` must be `4` or `8`.") + if self.method != METHOD.BITSANDBYTES: + raise ValueError("BitsAndBytesConfig: `method` must be `bitsandbytes`.") + + self.group_size = -1 + self.desc_act = False + self.sym = True + + self.format = _normalize_bitsandbytes_format(self.format, bits=int(self.bits)) + self.block_size = _normalize_bitsandbytes_block_size(self.block_size) + self.compress_statistics = bool(self.compress_statistics) + + if self.dynamic is not None: + self.dynamic = { + **{k: v for k, v in self.dynamic.items() if k.startswith('-')}, + **{k: v for k, v in self.dynamic.items() if not k.startswith('-')}, + } + for layer, layer_dict in self.dynamic.items(): + self._normalize_dynamic_layer_config( + layer, + layer_dict, + valid_bit_widths=[4, 8], + checkpoint_format=FORMAT.BITSANDBYTES, + ) + + def _normalize_dynamic_layer_config( + self, + layer_name: str, + layer_dict: Dict[str, Any], + *, + valid_bit_widths: List[int], + checkpoint_format: FORMAT, + ) -> None: + del valid_bit_widths, checkpoint_format + if "bits" in layer_dict and int(layer_dict["bits"]) not in {4, 8}: + raise ValueError(f"BitsAndBytesConfig: layer `{layer_name}` only supports 4-bit or 8-bit weights.") + if "group_size" in layer_dict and layer_dict["group_size"] not in (-1, None): + raise ValueError("BitsAndBytesConfig: `group_size` is not used; keep it at `-1`.") + if "desc_act" in layer_dict and bool(layer_dict["desc_act"]): + raise ValueError("BitsAndBytesConfig: `desc_act` is not supported.") + if "sym" in layer_dict and layer_dict["sym"] is not True: + raise ValueError("BitsAndBytesConfig: `sym` must stay `True`.") + dynamic_bits = int(layer_dict.get("bits", self.bits)) + raw_format = layer_dict.get(FORMAT_FIELD_CODE, layer_dict.get("bnb_quant_type")) + if raw_format is not None: + layer_dict[FORMAT_FIELD_CODE] = _normalize_bitsandbytes_format(raw_format, bits=dynamic_bits) + if "block_size" in layer_dict or "bnb_block_size" in layer_dict: + layer_dict["block_size"] = _normalize_bitsandbytes_block_size( + layer_dict.get("block_size", layer_dict.get("bnb_block_size")) + ) + layer_dict.pop("bnb_block_size", None) + if "compress_statistics" in layer_dict or "bnb_compress_statistics" in layer_dict: + layer_dict["compress_statistics"] = bool( + layer_dict.get("compress_statistics", layer_dict.get("bnb_compress_statistics")) + ) + layer_dict.pop("bnb_compress_statistics", None) + + def quant_linear_init_kwargs(self) -> Dict[str, Any]: + return { + "format": self.format, + "block_size": self.block_size, + "compress_statistics": self.compress_statistics, + } + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out[FORMAT_FIELD_CODE] = self.format + out["block_size"] = self.block_size + out["compress_statistics"] = self.compress_statistics + + def uses_weight_only_lifecycle(self) -> bool: + return True + + @property + def bnb_quant_type(self) -> str: + return self.format + + @bnb_quant_type.setter + def bnb_quant_type(self, value: str) -> None: + self.format = _normalize_bitsandbytes_format(value, bits=int(self.bits)) + + @property + def bnb_block_size(self) -> int: + return self.block_size + + @bnb_block_size.setter + def bnb_block_size(self, value: int) -> None: + self.block_size = _normalize_bitsandbytes_block_size(value) + + @property + def bnb_compress_statistics(self) -> bool: + return self.compress_statistics + + @bnb_compress_statistics.setter + def bnb_compress_statistics(self, value: bool) -> None: + self.compress_statistics = bool(value) -# deprecated: will be removed in future update @dataclass -class BaseQuantizeConfig(QuantizeConfig): - def __init__(self, **kwargs): - super().__init__(**kwargs) - log.warn( - "QuantizeConfig: BaseQuantizeConfig is re-named and pending deprecation. Please use `QuantizeConfig` instead." +class EXL3Config(BaseQuantizeConfig): + bits: float = field(default=3.0) + method: METHOD = field(default=METHOD.EXL3) + format: FORMAT = field(default=FORMAT.EXL3) + group_size: int = field(default=-1) + desc_act: Optional[bool] = field(default=False) + sym: bool = field(default=True) + head_bits: Optional[float] = field(default=None) + out_scales: Optional[str] = field(default="auto") + codebook: str = field(default="mcg") + tensor_storage: Optional[Dict[str, Any]] = field(default=None) + calibration: Optional[Dict[str, int]] = field(default=None) + + @property + def runtime_bits(self) -> int: + return quant_bits_width(self.bits) + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.EXL3,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return EXL3_EXPORT_FORMATS + + def default_desc_act(self) -> bool: + return False + + def _normalize_bits_field(self, bits_value, checkpoint_format: FORMAT): + return _normalize_exl3_bits(bits_value) + + def _normalize_dynamic_layer_config( + self, + layer_name: str, + layer_dict: Dict[str, Any], + *, + valid_bit_widths: List[int], + checkpoint_format: FORMAT, + ) -> None: + del valid_bit_widths, checkpoint_format + for key, value in layer_dict.items(): + if key == "bits": + layer_dict[key] = _normalize_exl3_bits(value) + elif key == "head_bits": + layer_dict[key] = None if value is None else _normalize_exl3_bits(value) + elif key == "group_size" and value not in (-1, None): + raise ValueError("EXL3Config: `group_size` is not used; keep it at `-1`.") + + def __post_init__(self): + self.method = _normalize_quant_method(self.method) + self.format = _normalize_format(self.format) + self.pack_dtype = _normalize_pack_dtype(self.pack_dtype) + self.bits = _normalize_exl3_bits(self.bits) + self.head_bits = None if self.head_bits is None else _normalize_exl3_bits(self.head_bits) + + if self.method != METHOD.EXL3: + raise ValueError("EXL3Config: `method` must be `exl3`.") + if self.format != FORMAT.EXL3: + raise ValueError("EXL3Config: `format` must be `exl3`.") + + self.group_size = -1 + self.desc_act = False + self.sym = True + + self.fallback = _normalize_fallback(self.fallback) + + if self.dynamic is not None: + self.dynamic = { + **{k: v for k, v in self.dynamic.items() if k.startswith('-')}, + **{k: v for k, v in self.dynamic.items() if not k.startswith('-')}, + } + for layer, layer_dict in self.dynamic.items(): + self._normalize_dynamic_layer_config( + layer, + layer_dict, + valid_bit_widths=[], + checkpoint_format=FORMAT.EXL3, + ) + + if self.out_scales is not None: + normalized_out_scales = str(self.out_scales).strip().lower() + out_scale_aliases = { + "always": "always", + "true": "always", + "never": "never", + "false": "never", + "auto": "auto", + "none": "auto", + } + if normalized_out_scales not in out_scale_aliases: + raise ValueError("EXL3Config: `out_scales` must be one of `always`, `never`, or `auto`.") + self.out_scales = out_scale_aliases[normalized_out_scales] + + self.codebook = str(self.codebook).strip().lower() + if self.codebook not in {"mcg", "mul1", "3inst"}: + raise ValueError("EXL3Config: `codebook` must be one of `mcg`, `mul1`, or `3inst`.") + + if self.tensor_storage is not None and not isinstance(self.tensor_storage, dict): + raise ValueError("EXL3Config: `tensor_storage` must be a dictionary when provided.") + if self.calibration is not None: + if not isinstance(self.calibration, dict): + raise ValueError("EXL3Config: `calibration` must be a dictionary when provided.") + self.calibration = { + str(key): int(value) + for key, value in self.calibration.items() + } + + if self.meta is not None: + if not isinstance(self.meta, dict): + raise ValueError("QuantizeConfig: `meta` must be a dictionary") + for key in self.meta: + if not isinstance(key, str): + raise ValueError("QuantizeConfig: `meta` keys must be strings") + else: + self.meta = {} + + self.adapter = normalize_adapter(self.adapter) + + if self.offload_to_disk and not self.offload_to_disk_path: + path_key = f"{get_random_string()}-{get_random_string()}" + self.offload_to_disk_path = f"./gptqmodel_offload/{path_key}/" + log.info(f"QuantizeConfig: offload_to_disk_path auto set to `{self.offload_to_disk_path}`") + + self.dense_vram_strategy = _normalize_dense_vram_strategy(self.dense_vram_strategy) + self.dense_vram_strategy_devices = _normalize_strategy_devices( + self.dense_vram_strategy_devices, + field_name="dense_vram_strategy_devices", + ) + self.moe_vram_strategy = _normalize_moe_vram_strategy(self.moe_vram_strategy) + self.moe_vram_strategy_devices = _normalize_strategy_devices( + self.moe_vram_strategy_devices, + field_name="moe_vram_strategy_devices", + ) + self.gc_mode = _normalize_gc_mode(self.gc_mode) + self.moe = _normalize_moe_config(self.moe) + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out["bits"] = float(self.bits) + out["head_bits"] = None if self.head_bits is None else float(self.head_bits) + out["out_scales"] = self.out_scales + out["codebook"] = self.codebook + out["tensor_storage"] = self.tensor_storage + out["calibration"] = self.calibration + + def calculate_bits_per_weight(self): + head_bits = self.head_bits if self.head_bits is not None else self.bits + log.info( + "Estimated Quantization BPW (bits per weight): %s bpw, based on [bits: %s, head_bits: %s]", + self.bits, + self.bits, + head_bits, + ) + +@dataclass +class RTNConfig(PreProcessorConfig): + method: METHOD = field(default=METHOD.GPTQ) + format: FORMAT = field(default=FORMAT.GPTQ) + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.GPTQ,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return RTN_EXPORT_FORMATS + + def default_desc_act(self) -> bool: + return False + + def __post_init__(self): + super().__post_init__() + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out["sym"] = self.sym + out[FORMAT_FIELD_CODE] = self.format + + def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None: + super()._update_meta_payload(meta_payload) + meta_payload["weight_only"] = { + "smooth": _serialize_smooth_method(self.smooth), + } + + def uses_weight_only_lifecycle(self) -> bool: + return True + + +@dataclass +class GGUFConfig(PreProcessorConfig): + bits: Union[int, str, GGUFBits] = field(default=4, metadata={"choices": [1, 2, 3, 4, 5, 6, 8]}) + format: Optional[str] = field(default=None) + method: METHOD = field(default=METHOD.GGUF, init=False) + group_size: int = field(default=-1, init=False, repr=False) + desc_act: Optional[bool] = field(default=False, init=False, repr=False) + sym: bool = field(default=True, init=False, repr=False) + _gguf_bits: GGUFBits = field(init=False, repr=False, compare=False) + + @property + def runtime_bits(self) -> GGUFBits: + return self._gguf_bits + + def allowed_quant_methods(self) -> Tuple[METHOD, ...]: + return (METHOD.GGUF,) + + def supported_export_formats(self) -> Tuple[FORMAT, ...]: + return (FORMAT.GGUF,) + + def default_desc_act(self) -> bool: + return False + + def _resolve_checkpoint_format(self) -> FORMAT: + self.bits, self.format, self._gguf_bits = _normalize_gguf_config_spec(self.bits, self.format) + return FORMAT.GGUF + + def _normalize_bits_field(self, bits_value, checkpoint_format: FORMAT): + normalized = _normalize_quant_bits(bits_value, format_value=FORMAT.GGUF) + return normalized.bits if isinstance(normalized, GGUFBits) else normalized + + def _normalize_dynamic_layer_config( + self, + layer_name: str, + layer_dict: Dict[str, Any], + *, + valid_bit_widths: List[int], + checkpoint_format: FORMAT, + ) -> None: + bits_override_present = "bits" in layer_dict + format_override_present = FORMAT_FIELD_CODE in layer_dict + + if bits_override_present or format_override_present: + raw_bits = layer_dict.get("bits", self.bits) + raw_format = layer_dict.get(FORMAT_FIELD_CODE, self.format) + normalized_bits, normalized_format, normalized_runtime_bits = _normalize_gguf_config_spec(raw_bits, raw_format) + + layer_dict["bits"] = normalized_bits + + bits_implied_format = ( + isinstance(raw_bits, GGUFBits) + or (isinstance(raw_bits, str) and not raw_bits.strip().isdigit()) + ) + if format_override_present or bits_implied_format: + layer_dict[FORMAT_FIELD_CODE] = normalized_format + + if quant_bits_width(normalized_runtime_bits) not in valid_bit_widths: + raise ValueError( + f"QuantizeConfig: Layer `{layer_name}` only support quantization of `{valid_bit_widths}` bits." + ) + + if "group_size" in layer_dict and layer_dict["group_size"] != -1 and layer_dict["group_size"] <= 0: + raise ValueError(_resolve_dynamic_group_size_error()) + + def __post_init__(self): + self._normalize_preprocessor_state() + # GGUFConfig already normalized preprocessors above; skip the parent hook to + # avoid running that normalization twice. + BaseQuantizeConfig.__post_init__(self) + self._gguf_bits = _gguf_bits_from_components(self.bits, self.format) + + def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None: + super()._update_meta_payload(meta_payload) + + def _update_output_payload(self, out: Dict[str, Any]) -> None: + out[FORMAT_FIELD_CODE] = self.format + + def to_dict(self): + out = super().to_dict() + out.pop(GROUP_SIZE_FIELD_CODE, None) + out.pop("desc_act", None) + out.pop(PACK_DTYPE_FIELD, None) + + meta_payload = out.get(META_FIELD) + if isinstance(meta_payload, dict): + for key in ( + "fallback", + "offload_to_disk", + "offload_to_disk_path", + "pack_impl", + "gc_mode", + "wait_for_submodule_finalizers", + "auto_forward_data_parallel", + "dense_vram_strategy", + "dense_vram_strategy_devices", + "moe_vram_strategy", + "moe_vram_strategy_devices", + "weight_only", + ): + meta_payload.pop(key, None) + if not meta_payload: + out.pop(META_FIELD, None) + + return out + + def calculate_bits_per_weight(self): + bits_name = self.runtime_bits.to_string() + bpw = _GGUF_APPROX_BITS_PER_WEIGHT_BY_ALIAS.get(bits_name, float(quant_bits_width(self.runtime_bits))) + log.info( + f"Estimated Quantization BPW (bits per weight): {bpw} bpw, based on [bits: {self.bits}, format: {self.format}]" + ) + + def uses_weight_only_lifecycle(self) -> bool: + return True + +def clone_weight_only_config_for_module( + qcfg: Union[RTNConfig, GGUFConfig, FP8Config, BitsAndBytesConfig], + module_full_name: str, +) -> Optional[Union[RTNConfig, GGUFConfig, FP8Config, BitsAndBytesConfig]]: + if qcfg.dynamic_get(layer_name=module_full_name) is False: + return None + + qcfg_clone = copy.deepcopy(qcfg) + + if qcfg.dynamic is not None: + smooth_override = qcfg.dynamic_get(module_full_name, "smoother", None) + if smooth_override is None: + smooth_override = qcfg.dynamic_get(module_full_name, "smooth", None) + if smooth_override is not None: + qcfg_clone.smoother = _normalize_smoother_config(smooth_override) + qcfg_clone.smooth = qcfg_clone.resolve_smooth_method() + + if isinstance(qcfg_clone, GGUFConfig): + dynamic_bits = qcfg.dynamic_get(module_full_name, "bits", qcfg_clone.bits) + dynamic_format = qcfg.dynamic_get(module_full_name, FORMAT_FIELD_CODE, qcfg_clone.format) + qcfg_clone.bits, qcfg_clone.format, qcfg_clone._gguf_bits = _normalize_gguf_config_spec( + dynamic_bits, + dynamic_format, + ) + elif isinstance(qcfg_clone, FP8Config): + dynamic_format = qcfg.dynamic_get(module_full_name, FORMAT_FIELD_CODE, None) + if dynamic_format is None: + dynamic_format = qcfg.dynamic_get(module_full_name, "fmt", qcfg_clone.format) + dynamic_block_size = qcfg.dynamic_get( + module_full_name, + "weight_block_size", + qcfg_clone.weight_block_size, + ) + block_size = _normalize_fp8_weight_block_size(dynamic_block_size) + qcfg_clone.format = _normalize_fp8_fmt(dynamic_format) + qcfg_clone.weight_scale_method = _normalize_fp8_weight_scale_method( + qcfg.dynamic_get( + module_full_name, + "weight_scale_method", + qcfg_clone.weight_scale_method, + ), + weight_block_size=block_size, + ) + qcfg_clone.weight_block_size = list(block_size) if block_size is not None else None + qcfg_clone.weight_scale_semantics = _normalize_fp8_scale_semantics( + qcfg.dynamic_get( + module_full_name, + "weight_scale_semantics", + qcfg_clone.weight_scale_semantics, + ) + ) + elif isinstance(qcfg_clone, BitsAndBytesConfig): + qcfg_clone.bits = _normalize_quant_bits( + qcfg.dynamic_get(module_full_name, "bits", qcfg_clone.bits), + format_value=FORMAT.BITSANDBYTES, + ) + qcfg_clone.format = _normalize_bitsandbytes_format( + qcfg.dynamic_get( + module_full_name, + FORMAT_FIELD_CODE, + qcfg.dynamic_get( + module_full_name, + "bnb_quant_type", + qcfg_clone.format, + ), + ), + bits=int(qcfg_clone.bits), + ) + qcfg_clone.block_size = _normalize_bitsandbytes_block_size( + qcfg.dynamic_get( + module_full_name, + "block_size", + qcfg.dynamic_get( + module_full_name, + "bnb_block_size", + qcfg_clone.block_size, + ), + ) + ) + qcfg_clone.compress_statistics = bool( + qcfg.dynamic_get( + module_full_name, + "compress_statistics", + qcfg.dynamic_get( + module_full_name, + "bnb_compress_statistics", + qcfg_clone.compress_statistics, + ), + ) + ) + else: + qcfg_clone.bits = _normalize_quant_bits( + qcfg.dynamic_get(module_full_name, "bits", qcfg_clone.bits), + format_value=resolve_quant_format(qcfg_clone.format, qcfg_clone.method), + ) + + if isinstance(qcfg_clone, RTNConfig): + qcfg_clone.sym = qcfg.dynamic_get(module_full_name, "sym", qcfg_clone.sym) + qcfg_clone.group_size = qcfg.dynamic_get(module_full_name, "group_size", qcfg_clone.group_size) + + desc_act_override = qcfg.dynamic_get(module_full_name, "desc_act", None) + if desc_act_override is not None: + qcfg_clone.desc_act = desc_act_override + + return qcfg_clone + + +clone_rtn_config_for_module = clone_weight_only_config_for_module + + +def _resolve_quantize_config_class(payload: Dict[str, Any]) -> type[BaseQuantizeConfig]: + method = payload.get(METHOD_FIELD_CODE, payload.get(QUANT_METHOD_FIELD, METHOD.GPTQ)) + raw_format_value = payload.get(FORMAT_FIELD_CODE, payload.get(FORMAT_FIELD_CHECKPOINT, FORMAT.GPTQ)) + weight_only = payload.get("weight_only") + bits = payload.get(BITS_FIELD_CODE) + gguf_public_format = payload.get(FORMAT_FIELD_CODE) + + try: + method = _normalize_quant_method(method) + except Exception: + method = METHOD.GPTQ + + if _looks_like_fp8_fmt(raw_format_value): + format_value = FORMAT.FP8 + else: + try: + format_value = _normalize_format(raw_format_value) + except Exception: + try: + gguf_public_format = _normalize_gguf_public_format(raw_format_value) + except ValueError: + gguf_public_format = payload.get(FORMAT_FIELD_CODE) + format_value = FORMAT.GPTQ + + gguf_format_detected = False + if gguf_public_format is not None: + try: + gguf_format_detected = _normalize_gguf_public_format(gguf_public_format) is not None + except ValueError: + gguf_format_detected = False + + weight_only_method = _peek_weight_only_method(weight_only) + fp8_storage_fmt = payload.get(FORMAT_FIELD_CODE, payload.get("fmt")) + if weight_only is not None and weight_only_method not in { + None, + WeightOnlyMethod.RTN, + WeightOnlyMethod.GGUF, + WeightOnlyMethod.FP8, + WeightOnlyMethod.BITSANDBYTES, + }: + raise ValueError( + "QuantizeConfig: unsupported weight-only config. Weight-only export currently supports " + "`rtn`, `gguf`, `fp8`, and `bitsandbytes`." ) + if ( + format_value == FORMAT.GGUF + or weight_only_method == WeightOnlyMethod.GGUF + or _looks_like_gguf_bits(bits) + or gguf_format_detected + ): + return GGUFConfig + if weight_only_method == WeightOnlyMethod.FP8: + return FP8Config + if weight_only_method == WeightOnlyMethod.BITSANDBYTES: + return BitsAndBytesConfig + if weight_only_method == WeightOnlyMethod.RTN: + return RTNConfig + if weight_only is not None: + return RTNConfig + if method == METHOD.FP8 or format_value == FORMAT.FP8 or _looks_like_fp8_fmt(fp8_storage_fmt): + return FP8Config + if method == METHOD.BITSANDBYTES or format_value == FORMAT.BITSANDBYTES or _looks_like_bitsandbytes_format(raw_format_value): + return BitsAndBytesConfig + if method == METHOD.EXL3 or format_value == FORMAT.EXL3: + return EXL3Config + if method == METHOD.PARO or format_value == FORMAT.PAROQUANT: + return ParoConfig + if method == METHOD.QQQ or format_value == FORMAT.QQQ: + return QQQConfig + if method == METHOD.AWQ: + return AWQConfig + if format_value in {FORMAT.GEMM, FORMAT.GEMV, FORMAT.GEMV_FAST, FORMAT.LLM_AWQ}: + return AWQConfig + if format_value == FORMAT.MARLIN: + return AWQConfig if method == METHOD.AWQ else GPTQConfig + return GPTQConfig + + +def _known_quantize_config_field_names() -> set[str]: + field_names: set[str] = set() + for cls in ( + BaseQuantizeConfig, + PreProcessorConfig, + QuantizeConfig, + GPTQConfig, + AWQConfig, + ParoConfig, + QQQConfig, + FP8Config, + BitsAndBytesConfig, + EXL3Config, + RTNConfig, + GGUFConfig, + ): + field_names.update(field.name for field in fields(cls)) + return field_names diff --git a/gptqmodel/quantization/dtype.py b/gptqmodel/quantization/dtype.py index f24fe5818..699bf1548 100644 --- a/gptqmodel/quantization/dtype.py +++ b/gptqmodel/quantization/dtype.py @@ -7,9 +7,12 @@ from __future__ import annotations +import os +from dataclasses import dataclass from typing import Optional import torch +import torch.nn.functional as F try: @@ -18,32 +21,538 @@ unpack_uint4 = None f4_unpacked_to_f32 = None +try: + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor, nvfp4_quantize +except Exception: + NVFP4Tensor = None + nvfp4_quantize = None + __all__ = [ + "DeviceDTypeSupport", + "get_device_dtype_support", + "device_supports_dtype", + "available_float4_packed_dtype_names", + "available_float4_packed_dtypes", + "available_float8_dtype_names", + "available_float8_dtypes", "device_supports_native_fp8", + "device_supports_native_fp4", + "dequantize_fp8", "dequantize_f8_e4m3", "dequantize_f4_e2m1", + "is_fp4_packed_dtype", ] -def device_supports_native_fp8(device: Optional[torch.device] = None) -> bool: - """Return ``True`` when the target CUDA device supports native FP8 (E4M3). +# Keep the canonical floatx registries in one place so CPU dequant, config +# normalization, model conversion, and tests all follow the same torch surface. +_FLOAT8_CANDIDATE_NAMES = ( + "float8_e4m3fn", + "float8_e5m2", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + "float8_e8m0fnu", +) +_FLOAT8_DTYPE_NAMES = tuple(name for name in _FLOAT8_CANDIDATE_NAMES if hasattr(torch, name)) +_FLOAT8_DTYPES = tuple(getattr(torch, name) for name in _FLOAT8_DTYPE_NAMES) - Hopper-class GPUs (SM >= 9.0) expose hardware accelerated FP8 kernels while - earlier generations such as the A100 (SM 8.x) do not. When CUDA is - unavailable this helper always returns ``False``. - """ +_FLOAT4_PACKED_CANDIDATE_NAMES = ("float4_e2m1fn_x2",) +_FLOAT4_PACKED_DTYPE_NAMES = tuple(name for name in _FLOAT4_PACKED_CANDIDATE_NAMES if hasattr(torch, name)) +_FLOAT4_PACKED_DTYPES = tuple(getattr(torch, name) for name in _FLOAT4_PACKED_DTYPE_NAMES) + +_TARGET_DTYPE_CODES = { + torch.bfloat16: 0, + torch.float16: 1, +} +_FP8_FORMAT_CODES = { + getattr(torch, "float8_e4m3fn", None): 0, + getattr(torch, "float8_e5m2", None): 1, + getattr(torch, "float8_e4m3fnuz", None): 2, + getattr(torch, "float8_e5m2fnuz", None): 3, + getattr(torch, "float8_e8m0fnu", None): 4, +} +def available_float8_dtype_names() -> tuple[str, ...]: + return _FLOAT8_DTYPE_NAMES + + +def available_float8_dtypes() -> tuple[torch.dtype, ...]: + return _FLOAT8_DTYPES + + +def available_float4_packed_dtype_names() -> tuple[str, ...]: + return _FLOAT4_PACKED_DTYPE_NAMES + + +def available_float4_packed_dtypes() -> tuple[torch.dtype, ...]: + return _FLOAT4_PACKED_DTYPES + + +def is_fp4_packed_dtype(dtype: torch.dtype) -> bool: + return dtype in _FLOAT4_PACKED_DTYPES + + +def _cpu_floatx_threads(numel: Optional[int] = None, *, enable_large_threads: bool = False) -> int: + raw = os.environ.get("GPTQMODEL_FLOATX_CPU_THREADS", "").strip() + default_value = 32 if enable_large_threads and numel is not None and numel >= 64 * 1024 * 1024 else 8 + try: + value = int(raw) if raw else default_value + except ValueError: + value = default_value + return max(1, min(value, 32, os.cpu_count() or 1)) + + +def _can_use_fast_path( + tensor: torch.Tensor, + scale_tensor: Optional[torch.Tensor], + *, + target_dtype: torch.dtype, + allow_float4_storage: bool = False, +) -> bool: + if target_dtype not in _TARGET_DTYPE_CODES: + return False + if tensor.device.type != "cpu": + return False + if tensor.ndim not in (1, 2): + return False + if tensor.dtype not in _FLOAT8_DTYPES: + if allow_float4_storage: + if tensor.dtype != torch.uint8 and tensor.dtype not in _FLOAT4_PACKED_DTYPES: + return False + else: + return False + if scale_tensor is None: + return True + if scale_tensor.device.type != "cpu" or scale_tensor.ndim > 2: + return False + return True + + +def _prefer_reference_fp8_cpu( + tensor: torch.Tensor, + scale_tensor: Optional[torch.Tensor], + *, + target_dtype: torch.dtype, + axis: Optional[int] = 0, +) -> bool: + if scale_tensor is None: + return False + if tensor.device.type != "cpu": + return False + if target_dtype not in _TARGET_DTYPE_CODES: + return False + if os.environ.get("GPTQMODEL_FLOATX_CPU_FORCE_NATIVE_FP8", "").strip().lower() in {"1", "true", "yes", "on"}: + return False - if not torch.cuda.is_available(): + standard_fp8_dtypes = tuple( + dtype for dtype in ( + getattr(torch, "float8_e4m3fn", None), + getattr(torch, "float8_e5m2", None), + ) if dtype is not None + ) + if tensor.dtype not in standard_fp8_dtypes: return False + # Standard torch FP8 dtypes already have a strong ATen path on CPU, but the + # native extension now wins on a few layout/target pairs on this host. Keep + # the default on the reference path unless the scale layout matches one of + # the native wins we have benchmarked and validated. + if tensor.ndim != 2: + return True + rows, cols = tensor.shape + if scale_tensor.ndim == 0: + return False + if scale_tensor.ndim == 1: + resolved_axis = 0 if axis is None else axis + if resolved_axis == 1 and cols % scale_tensor.numel() == 0: + return target_dtype is torch.float16 + return True + if scale_tensor.ndim != 2: + return True + if scale_tensor.shape == tensor.shape: + return True + scale_rows, scale_cols = scale_tensor.shape + if scale_rows == rows and cols % scale_cols == 0: + block_width = cols // scale_cols + if block_width in (16, 64): + return False + return True + + +def _load_floatx_cpu_ops(): + try: + from ..utils.cpp import load_floatx_cpu_extension + except Exception: + return None + + ext = load_floatx_cpu_extension() + if not ext: + return None + + namespace = getattr(torch.ops, "gptqmodel_floatx", None) + if namespace is None: + return None + if not hasattr(namespace, "dequantize_fp8_cpu") or not hasattr(namespace, "dequantize_fp4_cpu"): + return None + return namespace + + +def _fast_scale_arg( + *, + scale: Optional[torch.Tensor], + scale_inv: Optional[torch.Tensor], +) -> tuple[Optional[torch.Tensor], int]: + if scale is not None: + return scale.to(device="cpu", dtype=torch.float32).contiguous(), 1 + if scale_inv is None: + return None, 0 + + scale_tensor = scale_inv.to(device="cpu", dtype=torch.float32).contiguous() + max_abs = float(torch.max(torch.abs(scale_tensor)).item()) if scale_tensor.numel() else 0.0 + return scale_tensor, 1 if max_abs <= 1.0 else 2 + + +def _expand_scale( + scale_tensor: torch.Tensor, + result: torch.Tensor, + *, + axis_hint: Optional[int], +) -> torch.Tensor: + if scale_tensor.ndim == 0: + return scale_tensor + + target_shape = result.shape + if scale_tensor.shape == target_shape: + return scale_tensor + + if scale_tensor.ndim == 2 and len(target_shape) == 2: + blocks_r, blocks_c = scale_tensor.shape + rows, cols = target_shape + if rows % blocks_r == 0 and cols % blocks_c == 0: + repeat_r = rows // blocks_r + repeat_c = cols // blocks_c + expanded = scale_tensor.repeat_interleave(repeat_r, dim=0) + expanded = expanded.repeat_interleave(repeat_c, dim=1) + return expanded + + if scale_tensor.ndim == 1 and len(target_shape) == 2: + rows, cols = target_shape + count = scale_tensor.shape[0] + axis = axis_hint if axis_hint is not None else 0 + axis = axis if axis >= 0 else axis + len(target_shape) + if axis == 0 and rows % count == 0: + repeat = rows // count + expanded = scale_tensor.repeat_interleave(repeat, dim=0).view(rows, 1) + return expanded.expand(rows, cols) + if axis == 1 and cols % count == 0: + repeat = cols // count + expanded = scale_tensor.repeat_interleave(repeat, dim=0).view(1, cols) + return expanded.expand(rows, cols) + + if scale_tensor.ndim == result.ndim: + expanded = scale_tensor + for dim, (target_size, current_size) in enumerate(zip(result.shape, expanded.shape)): + if target_size == current_size: + continue + if current_size == 1: + expanded = expanded.expand(*[ + target_size if i == dim else expanded.shape[i] + for i in range(expanded.ndim) + ]) + continue + if target_size % current_size != 0: + raise ValueError( + f"Cannot broadcast scale dimension {current_size} to target {target_size}" + ) + repeat = target_size // current_size + expanded = expanded.repeat_interleave(repeat, dim=dim) + return expanded + + reshaped = _reshape_for_axis(scale_tensor, axis_hint, result.ndim) + return reshaped.expand(result.shape) + + +def _dequantize_f8_reference( + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + axis: Optional[int] = 0, + target_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + if not _FLOAT8_DTYPES: + raise RuntimeError("Current PyTorch build does not provide FP8 tensors") + + if scale is not None and scale_inv is not None: + raise ValueError("Provide either scale or scale_inv, not both") + + result = tensor.to(target_dtype) + if scale is not None: + scale_tensor = _expand_scale(scale.to(result.dtype), result, axis_hint=axis) + result = result * scale_tensor + elif scale_inv is not None: + scale_tensor = _expand_scale(scale_inv.to(result.dtype), result, axis_hint=axis) + if torch.max(torch.abs(scale_tensor)) <= 1: + result = result * scale_tensor + else: + result = result / scale_tensor + return result + + +def _dequantize_f4_reference( + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + axis: Optional[int] = 0, + target_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + if unpack_uint4 is None or f4_unpacked_to_f32 is None: + raise RuntimeError("torchao with nvfp4 support is required for FP4 dequantization") + + if scale is not None and scale_inv is not None: + raise ValueError("Provide either scale or scale_inv, not both") + + if is_fp4_packed_dtype(tensor.dtype): + tensor = tensor.view(torch.uint8) + elif tensor.dtype is not torch.uint8: + raise ValueError("FP4 packed tensors must use torch.uint8 storage") + + orig_shape = list(tensor.shape) + if not orig_shape: + raise ValueError("Tensor must have at least one dimension") + + unpacked = unpack_uint4(tensor.reshape(-1)) + expanded_shape = orig_shape[:-1] + [orig_shape[-1] * 2] + unpacked = unpacked.view(*expanded_shape) + result = f4_unpacked_to_f32(unpacked).to(target_dtype) + + if scale is not None: + scale_tensor = _expand_scale(scale.to(result.dtype), result, axis_hint=axis) + result = result * scale_tensor + elif scale_inv is not None: + scale_tensor = _expand_scale(scale_inv.to(result.dtype), result, axis_hint=axis) + if torch.max(torch.abs(scale_tensor)) <= 1: + result = result * scale_tensor + else: + result = result / scale_tensor + return result + +_DTYPE_SUPPORT_CACHE: dict[tuple[str, Optional[int], bool], "DeviceDTypeSupport"] = {} + + +@dataclass(frozen=True) +class DeviceDTypeSupport: + """Describe which execution dtypes a device advertises and validates. + + ``advertised_linear_dtypes`` is architecture-based. It answers what the + device family is expected to support in native matmul / linear kernels. + ``validated_linear_dtypes`` is runtime-probed and therefore stricter. + """ + + device: torch.device + capability: Optional[tuple[int, int]] + advertised_linear_dtypes: frozenset[torch.dtype] + validated_linear_dtypes: frozenset[torch.dtype] + + def supports(self, dtype: torch.dtype, *, require_validation: bool = False) -> bool: + """Return whether ``dtype`` is supported for linear-style execution.""" + + supported = ( + self.validated_linear_dtypes + if require_validation + else self.advertised_linear_dtypes + ) + return dtype in supported + + +def _normalize_device(device: Optional[torch.device]) -> torch.device: + """Resolve ``device`` into a concrete torch device.""" + if device is None: - device = torch.device("cuda", torch.cuda.current_device()) + if torch.cuda.is_available(): + return torch.device("cuda", torch.cuda.current_device()) + return torch.device("cpu") + return torch.device(device) + + +def _cuda_capability(device: torch.device) -> Optional[tuple[int, int]]: + """Return CUDA compute capability for a concrete device.""" + + if device.type != "cuda" or not torch.cuda.is_available(): + return None + return tuple(int(v) for v in torch.cuda.get_device_capability(device)) + + +def _advertised_cuda_linear_dtypes(capability: tuple[int, int]) -> frozenset[torch.dtype]: + """Map CUDA architecture families to expected fast linear dtypes.""" + + major, minor = capability + supported = { + torch.float16, + torch.float32, + } + + if major >= 8: + supported.add(torch.bfloat16) + + if _FLOAT8_DTYPES and (major >= 9 or (major, minor) == (8, 9)): + supported.update(_FLOAT8_DTYPES) + if _FLOAT4_PACKED_DTYPES and major >= 10: + supported.update(_FLOAT4_PACKED_DTYPES) + + return frozenset(supported) + + +def _advertised_linear_dtypes_for_device(device: torch.device) -> tuple[Optional[tuple[int, int]], frozenset[torch.dtype]]: + """Return the architecture-level linear dtype support set for ``device``.""" + + if device.type == "cuda" and torch.cuda.is_available(): + capability = _cuda_capability(device) + assert capability is not None + return capability, _advertised_cuda_linear_dtypes(capability) + + # CPU / other devices fall back to portable dtypes only. This helper is + # used for accelerator routing decisions, so keep the default conservative. + return None, frozenset({torch.float16, torch.float32, torch.bfloat16}) + + +def _validate_linear_dtype_support(device: torch.device, dtype: torch.dtype) -> bool: + """Runtime-probe whether one small linear/matmul path works for ``dtype``.""" + + if device.type != "cuda" or not torch.cuda.is_available(): + return dtype in {torch.float16, torch.float32, torch.bfloat16} + + try: + if dtype in _FLOAT4_PACKED_DTYPES: + if NVFP4Tensor is None or nvfp4_quantize is None: + return False + weight = torch.randn(16, 16, dtype=torch.float32) + scales, packed = nvfp4_quantize(weight, block_size=16) + packed_weight = packed.view(dtype) if packed.dtype is not dtype else packed + packed_weight = packed_weight.to(device) + scales = scales.to(device) + x = torch.randn(4, 16, device=device, dtype=torch.bfloat16) + result = F.linear( + x, + NVFP4Tensor( + packed_weight, + scales, + block_size=16, + orig_dtype=torch.bfloat16, + ), + None, + ) + return isinstance(result, torch.Tensor) + + if dtype in _FLOAT8_DTYPES: + if not hasattr(torch, "_scaled_mm"): + return False + compute_dtype = torch.bfloat16 if device_supports_dtype(device, torch.bfloat16) else torch.float16 + a = torch.randn(16, 16, device=device, dtype=compute_dtype).to(dtype) + b = torch.randn(16, 16, device=device, dtype=compute_dtype).to(dtype) + result = torch._scaled_mm( + a, + b, + scale_a=torch.tensor(1.0, device=device), + scale_b=torch.tensor(1.0, device=device), + out_dtype=compute_dtype, + ) + if isinstance(result, tuple): + result = result[0] + return isinstance(result, torch.Tensor) + + a = torch.randn(16, 16, device=device, dtype=dtype) + b = torch.randn(16, 16, device=device, dtype=dtype) + result = torch.matmul(a, b) + return isinstance(result, torch.Tensor) + except Exception: + return False + + +def get_device_dtype_support( + device: Optional[torch.device] = None, + *, + validate: bool = False, +) -> DeviceDTypeSupport: + """Return linear-dtype support metadata for ``device``. + + ``validate=False`` is architecture-based and cheap. + ``validate=True`` probes one small kernel per advertised dtype and caches + the results for subsequent callers. + """ + + resolved_device = _normalize_device(device) + cache_key = (resolved_device.type, resolved_device.index, bool(validate)) + cached = _DTYPE_SUPPORT_CACHE.get(cache_key) + if cached is not None: + return cached + + capability, advertised = _advertised_linear_dtypes_for_device(resolved_device) + validated = advertised + if validate: + validated = frozenset( + dtype + for dtype in advertised + if _validate_linear_dtype_support(resolved_device, dtype) + ) + + support = DeviceDTypeSupport( + device=resolved_device, + capability=capability, + advertised_linear_dtypes=advertised, + validated_linear_dtypes=validated, + ) + _DTYPE_SUPPORT_CACHE[cache_key] = support + return support + + +def device_supports_dtype( + device: Optional[torch.device], + dtype: torch.dtype, + *, + require_validation: bool = False, +) -> bool: + """Return whether ``device`` supports ``dtype`` for linear-style kernels.""" - if device.type != "cuda": + support = get_device_dtype_support(device, validate=require_validation) + return support.supports(dtype, require_validation=require_validation) + + +def device_supports_native_fp8( + device: Optional[torch.device] = None, + *, + require_validation: bool = False, +) -> bool: + """Return ``True`` when the target CUDA device supports native FP8 (E4M3). + + This compatibility wrapper now forwards to the generic device/dtype support + map. By default it returns architecture-advertised support; callers that + need a stricter answer may pass ``require_validation=True``. + """ + + if not _FLOAT8_DTYPES: return False + return device_supports_dtype( + device, + torch.float8_e4m3fn, + require_validation=require_validation, + ) + - major, _ = torch.cuda.get_device_capability(device) - return major >= 9 +def device_supports_native_fp4( + device: Optional[torch.device] = None, + *, + require_validation: bool = False, +) -> bool: + """Return ``True`` when the target device advertises native NVFP4 linear execution.""" + + if not _FLOAT4_PACKED_DTYPES: + return False + return device_supports_dtype( + device, + _FLOAT4_PACKED_DTYPES[0], + require_validation=require_validation, + ) def _reshape_for_axis(tensor: torch.Tensor, axis: Optional[int], target_ndim: int) -> torch.Tensor: @@ -92,84 +601,78 @@ def dequantize_f8_e4m3( omitted the helper falls back to a plain dtype conversion. """ - if not hasattr(torch, "float8_e4m3fn"): - raise RuntimeError("Current PyTorch build does not provide float8_e4m3fn tensors") + if not _FLOAT8_DTYPES: + raise RuntimeError("Current PyTorch build does not provide FP8 tensors") if scale is not None and scale_inv is not None: raise ValueError("Provide either scale or scale_inv, not both") + if tensor.dtype in _FLOAT8_DTYPES and _can_use_fast_path( + tensor, + scale if scale is not None else scale_inv, + target_dtype=target_dtype, + ): + if scale is None and scale_inv is None: + return tensor.to(target_dtype) + if _prefer_reference_fp8_cpu( + tensor, + scale if scale is not None else scale_inv, + target_dtype=target_dtype, + axis=axis, + ): + return _dequantize_f8_reference( + tensor, + scale=scale, + scale_inv=scale_inv, + axis=axis, + target_dtype=target_dtype, + ) + + ops = _load_floatx_cpu_ops() + if ops is not None: + fast_scale, scale_mode = _fast_scale_arg(scale=scale, scale_inv=scale_inv) + format_code = _FP8_FORMAT_CODES.get(tensor.dtype) + if format_code is not None: + enable_large_threads = ( + target_dtype is torch.bfloat16 and + hasattr(torch, "float8_e4m3fn") and + tensor.dtype is torch.float8_e4m3fn + ) + source = tensor.contiguous().view(torch.uint8) + return ops.dequantize_fp8_cpu( + source, + fast_scale, + scale_mode, + 0 if axis is None else int(axis), + axis is None, + _TARGET_DTYPE_CODES[target_dtype], + int(format_code), + _cpu_floatx_threads(tensor.numel(), enable_large_threads=enable_large_threads), + ) + + return _dequantize_f8_reference( + tensor, + scale=scale, + scale_inv=scale_inv, + axis=axis, + target_dtype=target_dtype, + ) - if tensor.dtype is not torch.float8_e4m3fn: - result = tensor.to(target_dtype) - else: - result = tensor.to(target_dtype) - - def _expand_scale(scale_tensor: torch.Tensor, *, axis_hint: Optional[int]) -> torch.Tensor: - if scale_tensor.ndim == 0: - return scale_tensor - - target_shape = result.shape - - if scale_tensor.shape == target_shape: - return scale_tensor - - # Block-wise expansion (e.g. [num_row_blocks, num_col_blocks]) - if scale_tensor.ndim == 2 and len(target_shape) == 2: - blocks_r, blocks_c = scale_tensor.shape - rows, cols = target_shape - if rows % blocks_r == 0 and cols % blocks_c == 0: - repeat_r = rows // blocks_r - repeat_c = cols // blocks_c - expanded = scale_tensor.repeat_interleave(repeat_r, dim=0) - expanded = expanded.repeat_interleave(repeat_c, dim=1) - return expanded - - if scale_tensor.ndim == 1 and len(target_shape) == 2: - rows, cols = target_shape - count = scale_tensor.shape[0] - axis = axis_hint if axis_hint is not None else 0 - axis = axis if axis >= 0 else axis + len(target_shape) - if axis == 0 and rows % count == 0: - repeat = rows // count - expanded = scale_tensor.repeat_interleave(repeat, dim=0).view(rows, 1) - return expanded.expand(rows, cols) - if axis == 1 and cols % count == 0: - repeat = cols // count - expanded = scale_tensor.repeat_interleave(repeat, dim=0).view(1, cols) - return expanded.expand(rows, cols) - - if scale_tensor.ndim == result.ndim: - expanded = scale_tensor - for dim, (target_size, current_size) in enumerate(zip(result.shape, expanded.shape)): - if target_size == current_size: - continue - if current_size == 1: - expanded = expanded.expand(*[ - target_size if i == dim else expanded.shape[i] - for i in range(expanded.ndim) - ]) - continue - if target_size % current_size != 0: - raise ValueError( - f"Cannot broadcast scale dimension {current_size} to target {target_size}" - ) - repeat = target_size // current_size - expanded = expanded.repeat_interleave(repeat, dim=dim) - return expanded - - reshaped = _reshape_for_axis(scale_tensor, axis_hint, result.ndim) - return reshaped.expand(result.shape) - if scale is not None: - scale_tensor = _expand_scale(scale.to(result.dtype), axis_hint=axis) - result = result * scale_tensor - elif scale_inv is not None: - scale_tensor = _expand_scale(scale_inv.to(result.dtype), axis_hint=axis) - if torch.max(torch.abs(scale_tensor)) <= 1: - result = result * scale_tensor - else: - result = result / scale_tensor - - return result +def dequantize_fp8( + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + axis: Optional[int] = 0, + target_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + return dequantize_f8_e4m3( + tensor, + scale=scale, + scale_inv=scale_inv, + axis=axis, + target_dtype=target_dtype, + ) def dequantize_f4_e2m1( @@ -182,74 +685,34 @@ def dequantize_f4_e2m1( ) -> torch.Tensor: """Dequantize FP4 (E2M1) values packed as two nibbles per byte.""" - if unpack_uint4 is None or f4_unpacked_to_f32 is None: - raise RuntimeError("torchao with nvfp4 support is required for FP4 dequantization") - if scale is not None and scale_inv is not None: raise ValueError("Provide either scale or scale_inv, not both") - - if tensor.dtype is not torch.uint8: - raise ValueError("FP4 packed tensors must use torch.uint8 storage") - - orig_shape = list(tensor.shape) - if not orig_shape: - raise ValueError("Tensor must have at least one dimension") - - unpacked = unpack_uint4(tensor.reshape(-1)) - expanded_shape = orig_shape[:-1] + [orig_shape[-1] * 2] - unpacked = unpacked.view(*expanded_shape) - - result = f4_unpacked_to_f32(unpacked).to(target_dtype) - - def _expand_scale_fp4(scale_tensor: torch.Tensor, *, axis_hint: Optional[int]) -> torch.Tensor: - if scale_tensor.ndim == 0: - return scale_tensor - - target_shape = result.shape - - if scale_tensor.shape == target_shape: - return scale_tensor - - if scale_tensor.ndim == 2 and len(target_shape) == 2: - blocks_r, blocks_c = scale_tensor.shape - rows, cols = target_shape - if rows % blocks_r == 0 and cols % blocks_c == 0: - repeat_r = rows // blocks_r - repeat_c = cols // blocks_c - expanded = scale_tensor.repeat_interleave(repeat_r, dim=0) - expanded = expanded.repeat_interleave(repeat_c, dim=1) - return expanded - - if scale_tensor.ndim == result.ndim: - expanded = scale_tensor - for dim, (target_size, current_size) in enumerate(zip(result.shape, expanded.shape)): - if target_size == current_size: - continue - if current_size == 1: - expanded = expanded.expand(*[ - target_size if i == dim else expanded.shape[i] - for i in range(expanded.ndim) - ]) - continue - if target_size % current_size != 0: - raise ValueError( - f"Cannot broadcast scale dimension {current_size} to target {target_size}" - ) - repeat = target_size // current_size - expanded = expanded.repeat_interleave(repeat, dim=dim) - return expanded - - reshaped = _reshape_for_axis(scale_tensor, axis_hint, result.ndim) - return reshaped.expand(result.shape) - - if scale is not None: - scale_tensor = _expand_scale_fp4(scale.to(result.dtype), axis_hint=axis) - result = result * scale_tensor - elif scale_inv is not None: - scale_tensor = _expand_scale_fp4(scale_inv.to(result.dtype), axis_hint=axis) - if torch.max(torch.abs(scale_tensor)) <= 1: - result = result * scale_tensor - else: - result = result / scale_tensor - - return result + if _can_use_fast_path( + tensor, + scale if scale is not None else scale_inv, + target_dtype=target_dtype, + allow_float4_storage=True, + ): + ops = _load_floatx_cpu_ops() + if ops is not None: + fast_scale, scale_mode = _fast_scale_arg(scale=scale, scale_inv=scale_inv) + source = tensor.contiguous() + if source.dtype is not torch.uint8: + source = source.view(torch.uint8) + return ops.dequantize_fp4_cpu( + source, + fast_scale, + scale_mode, + 0 if axis is None else int(axis), + axis is None, + _TARGET_DTYPE_CODES[target_dtype], + _cpu_floatx_threads(source.numel() * 2), + ) + + return _dequantize_f4_reference( + tensor, + scale=scale, + scale_inv=scale_inv, + axis=axis, + target_dtype=target_dtype, + ) diff --git a/gptqmodel/quantization/fallback_smooth.py b/gptqmodel/quantization/fallback_smooth.py new file mode 100644 index 000000000..12e6ec330 --- /dev/null +++ b/gptqmodel/quantization/fallback_smooth.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import math +from typing import Optional, Tuple + +import torch + +from .config import ( + Fallback, + QuantizeConfig, + SmoothLog, + SmoothMAD, + SmoothOutlier, + SmoothPercentile, + SmoothPercentileAsymmetric, + SmoothRowCol, + SmoothSoftNorm, +) + + +# For Gaussian-like rows, raw MAD ~= 0.67449 * sigma. Normalize MAD so the +# configured `k` behaves like a sigma-width window instead of clipping far more +# aggressively than intended. +MAD_TO_STD_SCALE = 1.4826 + + +def _quantile(block: torch.Tensor, percentile: float) -> torch.Tensor: + if percentile <= 0.0: + return block.min(dim=1, keepdim=True).values + if percentile >= 100.0: + return block.max(dim=1, keepdim=True).values + + q = max(0.0, min(percentile / 100.0, 1.0)) + position = q * (block.shape[1] - 1) + lower_index = int(math.floor(position)) + upper_index = int(math.ceil(position)) + lower = block.kthvalue(lower_index + 1, dim=1, keepdim=True).values + + if upper_index == lower_index: + return lower + + upper = block.kthvalue(upper_index + 1, dim=1, keepdim=True).values + return lower + (upper - lower) * (position - lower_index) + + +def _clamp_block(block: torch.Tensor, lo: torch.Tensor, hi: torch.Tensor) -> torch.Tensor: + return torch.max(torch.min(block, hi), lo) + + +def smooth_block( + block: torch.Tensor, + fallback: Fallback, + *, + eps: float = 1e-8, + group_size: Optional[int] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + method = getattr(fallback, "smooth", None) + if method is None: + return block, None + if group_size is not None and group_size < 0: + group_size = block.shape[1] + if group_size is not None and group_size < getattr(method, "group_size_threshold", 0): + return block, None + + if isinstance(method, SmoothRowCol): + axis = (method.axis or "row").lower() + if axis == "col": + col_rms = block.pow(2).mean(dim=0, keepdim=True).sqrt().clamp(min=eps) + scale_factor = col_rms.mean().view(1, 1) + else: + scale_factor = block.pow(2).mean(dim=1, keepdim=True).sqrt().clamp(min=eps) + return block / scale_factor, scale_factor + + block_f = block.float() + + if isinstance(method, SmoothSoftNorm): + mean = block_f.mean(dim=1, keepdim=True) + rms = (block_f - mean).pow(2).mean(dim=1, keepdim=True).sqrt().clamp(min=eps) + k = float(method.k) + block_norm = (block_f - mean) / rms + block_norm = torch.clamp(block_norm, -k, k) + return (block_norm * rms + mean).to(block.dtype), None + + if isinstance(method, SmoothPercentile): + abs_block = block_f.abs() + threshold = _quantile(abs_block, float(method.percentile)) + return torch.clamp(block_f, -threshold, threshold).to(block.dtype), None + + if isinstance(method, SmoothPercentileAsymmetric): + low = float(method.low) + high = float(method.high) + lo = _quantile(block_f, low) + hi = _quantile(block_f, high) + return _clamp_block(block_f, lo, hi).to(block.dtype), None + + if isinstance(method, SmoothMAD): + median = block_f.median(dim=1, keepdim=True).values + mad = (block_f - median).abs().median(dim=1, keepdim=True).values * MAD_TO_STD_SCALE + k = float(method.k) + lo = median - k * mad + hi = median + k * mad + return _clamp_block(block_f, lo, hi).to(block.dtype), None + + if isinstance(method, SmoothOutlier): + pct = float(method.pct) + if pct <= 0.0: + return block, None + abs_block = block_f.abs() + k = max(1, int(round(abs_block.shape[1] * (1.0 - pct / 100.0)))) + k = min(k, abs_block.shape[1]) + threshold = torch.kthvalue(abs_block, k, dim=1, keepdim=True).values + return torch.clamp(block_f, -threshold, threshold).to(block.dtype), None + + if isinstance(method, SmoothLog): + mu = max(float(method.mu), eps) + abs_block = block_f.abs() + log_mu = math.log1p(mu) + log_vals = torch.log1p(abs_block * mu) / log_mu + threshold = _quantile(log_vals, float(method.percentile)) + lin_threshold = (torch.exp(threshold * log_mu) - 1.0) / mu + return torch.clamp(block_f, -lin_threshold, lin_threshold).to(block.dtype), None + + return block, None + + +def _eval_mse(block_f, min_val, max_val, base_zero, qcfg, maxq, shrink, eps): + """Compute MSE for shrinkage factors [rows, n, 1].""" + scale = torch.clamp((max_val.unsqueeze(1) * shrink - min_val.unsqueeze(1) * shrink) / maxq, min=eps) + zero = base_zero.unsqueeze(1).expand_as(scale) if qcfg.sym else torch.round(-min_val.unsqueeze(1) * shrink / scale) + q = torch.clamp(torch.round(block_f.unsqueeze(1) / scale + zero), 0, maxq) + + return ((q - zero) * scale - block_f.unsqueeze(1)).pow(2).mean(dim=2) + + +def mse_optimal_quant( + block: torch.Tensor, + qcfg: QuantizeConfig, + maxq: int, + *, + steps: int, + maxshrink: float, + eps: float = 1e-8, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Ternary search: O(log steps)""" + block_f = block.float() + rows = block_f.shape[0] + + if qcfg.sym: + max_abs = block_f.abs().max(dim=1, keepdim=True).values + min_val = -max_abs + max_val = max_abs + base_zero = torch.full_like(max_abs, (maxq + 1) / 2.0) + else: + min_val = block_f.min(dim=1, keepdim=True).values + max_val = block_f.max(dim=1, keepdim=True).values + base_zero = None + + steps = max(int(math.log(max(steps, 2)) / math.log(1.5)) + 1, 3) + shrink = max(min(maxshrink, 1.0), 1e-3) + + # Left, Right pointer + l, r = torch.full((rows, 1), shrink, device=block_f.device, dtype=block_f.dtype), torch.ones((rows, 1), device=block_f.device, dtype=block_f.dtype) + best_err, best_p = torch.full((rows,), float('inf'), device=block_f.device), r.clone() + + for _ in range(steps): + mid1, mid2 = l + (r - l) / 3.0, r - (r - l) / 3.0 + err = _eval_mse(block_f, min_val, max_val, base_zero, qcfg, maxq, torch.stack([mid1, mid2], dim=1).view(rows, 2, 1), eps) + + for i, p in enumerate([mid1, mid2]): + better = err[:, i] < best_err + best_err, best_p = torch.where(better, err[:, i], best_err), torch.where(better.unsqueeze(1), p, best_p) + + move_r = err[:, 0] < err[:, 1] + r, l = torch.where(move_r.unsqueeze(1), mid2, r), torch.where(move_r.unsqueeze(1), l, mid1) + + # Refine + delta = (r - l) * 0.1 + refinement = torch.stack([torch.clamp(best_p - delta, shrink, 1.0), best_p, torch.clamp(best_p + delta, shrink, 1.0)], dim=1).view(rows, 3, 1) + best_p = torch.gather(refinement.squeeze(2), 1, _eval_mse(block_f, min_val, max_val, base_zero, qcfg, maxq, refinement, eps).argmin(dim=1).unsqueeze(1)) + + # Final quantization + scale_best = torch.clamp((max_val - min_val) * best_p / maxq, min=eps) + zero_best = base_zero if qcfg.sym else torch.round(-min_val * best_p / scale_best) + q = torch.clamp(torch.round(block_f / scale_best + zero_best), 0, maxq) + dequant_best = (q - zero_best) * scale_best + + return dequant_best, scale_best, zero_best diff --git a/gptqmodel/quantization/foem.py b/gptqmodel/quantization/foem.py new file mode 100644 index 000000000..b17002068 --- /dev/null +++ b/gptqmodel/quantization/foem.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +# adapted from https://github.com/Intelligent-Computing-Lab-Yale/GPTQv2 +# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq) + +import math +import os +import sys +import time +from typing import Optional + +import torch +import torch.nn as nn +import transformers + +from ..looper.named_module import NamedModule +from ..quantization import QuantizeConfig +from ..utils.torch import TORCH_GTE_28, torch_compile, torch_sync +from .gptq import GPTQ + + +class FOEM(GPTQ): + def __init__(self, module: NamedModule, qcfg: Optional[QuantizeConfig] = None): + from ..looper.native_processor import NATIVE_INPUTS_STATE_KEY # avoid import loop + + super().__init__(module, qcfg) + + self.H = None + self.dXXT = None + + + if self.qcfg.foem.alpha == 0: + self.gptaq = False + else: + self.gptaq =True + + if self.gptaq: + self.native_inps = module.state.pop(NATIVE_INPUTS_STATE_KEY) + + def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[int] = None): + with self.lock: + self.fwd_counter += 1 + self.process_batch(inp) + + # TODO FIXME: using v1 new process_batch kills v2 quantization quality, use original process_batch + # sample counter based on batch request # instead of batched token #. + # def process_batch(self, inp): + # batch_token_size, reshaped_inp, alpha, beta = super().process_batch(inp) + # del inp + # + # native_inp = self.native_inps.pop(0).to(device=DEVICE_1, dtype=torch.float32) + # + # # input reshaping + # if isinstance(self.module, (nn.Linear, transformers.Conv1D)): + # native_inp = native_inp.reshape(-1, native_inp.shape[-1]) + # else: + # unfold = nn.Unfold( + # self.module.kernel_size, + # dilation=self.module.dilation, + # padding=self.module.padding, + # stride=self.module.stride, + # ) + # # output size (batch_size, channels * \prod kernel_size, num_patches) + # native_inp = unfold(native_inp).transpose(1, 2).flatten(0, 1) + # + # if self.dXXT is None: + # self.dXXT = torch.zeros((self.columns, self.columns), + # dtype=torch.float32, + # device=DEVICE_1) + # + # self.dXXT.addmm_((native_inp.T-reshaped_inp.T), reshaped_inp, beta=beta, alpha=alpha) + # del native_inp, reshaped_inp + + def process_batch(self, inp): + inp = inp.to(dtype=torch.float32) + if self.gptaq: + native_inp = self.native_inps.pop(0).to(device=inp.device, dtype=torch.float32) + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + if self.gptaq: + native_inp = native_inp.unsqueeze(0) + + batch_size = inp.shape[0] + + if isinstance(self.module, (nn.Linear, transformers.Conv1D)): + if len(inp.shape) == 3: + inp = inp.reshape((-1, inp.shape[-1])) + if self.gptaq: + native_inp = native_inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if self.gptaq: + native_inp = native_inp.t() + + if isinstance(self.module, nn.Conv2d): + unfold = nn.Unfold( + self.module.kernel_size, + dilation=self.module.dilation, + padding=self.module.padding, + stride=self.module.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + if self.gptaq: + native_inp = unfold(native_inp) + native_inp = native_inp.permute([1, 0, 2]).flatten(1) + + if self.H is None: + self.H = torch.zeros((self.columns, self.columns), + dtype=torch.float32, + device=inp.device) + if self.gptaq: + self.dXXT = self.H.clone() + else: + self.H *= self.nsamples / (self.nsamples + batch_size) + if self.gptaq: + self.dXXT *= self.nsamples / (self.nsamples + batch_size) + + self.nsamples += batch_size + inp = math.sqrt(2 / self.nsamples) * inp.float() + + self.H += inp.matmul(inp.t()) + if self.gptaq: + native_inp = math.sqrt(2 / self.nsamples) * native_inp + self.dXXT += (native_inp - inp).matmul(inp.t()) + + @torch.inference_mode() + def quantize( + self, + blocksize=128, + ): + # self.H = self.H.to(device=CUDA_0) + # log.info(f"Quantization `{self.name}` using samples: `{self.nsamples}`") + start = time.time() + + # TODO compilation failure for Torch >= 2.8 + if not TORCH_GTE_28: + self.hessian_inverse = torch_compile(self.hessian_inverse) + + # if self.device.type not in ["mps", "cpu"]: + # self.module.weight.data = self.module.weight.data.cpu() + + # TODO: waiting for pytorch implementation of ops for MPS + if sys.platform == "darwin" and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "1": + raise RuntimeError( + "For MacOS you must set env `PYTORCH_ENABLE_MPS_FALLBACK=1` before running quantization.") + + if self.module_copy is None: + # log.info("copy W to cuda_1") + W = self.clone_module(device=self.H.device) + else: + W = self.module_copy + self.module_copy = None + + self.quantizer.find_params(W, weight=True) + + H = self.H + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if self.gptaq: + self.dXXT[:, dead] = 0 + + # g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + if self.qcfg.static_groups: + import copy + + groups = [] + for i in range(0, self.columns, self.qcfg.group_size): + quantizer = copy.deepcopy(self.quantizer) + quantizer.find_params(W[:, i: (i + self.qcfg.group_size)], weight=True) + + scale.append(quantizer.scale) + zero.append(quantizer.zero) + groups.append(quantizer) + + if self.qcfg.desc_act: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + if self.gptaq: + self.dXXT = self.dXXT[perm][:, perm] + invperm = torch.argsort(perm) + + W_raw = W.detach().clone() + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + Hinv, damp = self.hessian_inverse(H) + if self.qcfg.foem is None: + raise ValueError("FOEM requires `foem` configuration.") + if self.gptaq: + P = self.qcfg.foem.alpha * ((self.dXXT @ Hinv.T).triu(diagonal=1)) @ Hinv + del self.dXXT + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + W1_raw = W_raw[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + if self.gptaq: + P1 = P[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + w_raw = W1_raw[:, i] + d = Hinv1[i, i] + + if self.qcfg.group_size != -1: + if not self.qcfg.static_groups: + if (i1 + i) % self.qcfg.group_size == 0: + self.quantizer.find_params(W[:, (i1 + i): (i1 + i + self.qcfg.group_size)], weight=True) + + if ((i1 + i) // self.qcfg.group_size) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + else: + idx = i1 + i + if self.qcfg.desc_act: + idx = perm[idx] + + self.quantizer = groups[idx // self.qcfg.group_size] + + q = self.quantizer.quantize(w.unsqueeze(1)).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = ((w - q) - (w - w_raw) * self.qcfg.foem.beta) / d + + if self.gptaq: + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) - w.unsqueeze(1).matmul(P1[i, i:].unsqueeze(0)) + else: + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + if i+1 < count: + W1[:, i+1] -= self.qcfg.foem.beta * (W1[:, i+1]-W1_raw[:, i+1]) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + if self.gptaq: + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) - W1.matmul(P[i1:i2, i2:]) + else: + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + if i+1 < count: + W[:, i2+1] -= self.qcfg.foem.beta * (W[:, i2+1]-W_raw[:, i2+1]) + + del Hinv + if self.gptaq: + del P + del W_raw, W1_raw, w_raw + + torch_sync() + + avg_loss = torch.sum(Losses).item() / self.nsamples + + if math.isnan(avg_loss): + print("Losses sum item:", torch.sum(Losses).item()) + raise ValueError(f"Quantization: Failed due to `NaN` loss for `{self.name}`") + + del Losses + + group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns + + if self.qcfg.static_groups and self.qcfg.desc_act: + g_idx = [perm[i] // group_size for i in range(self.columns)] + else: + g_idx = [i // group_size for i in range(self.columns)] + + g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) + + if self.qcfg.desc_act: + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.module, transformers.Conv1D): + Q = Q.t() + + if Q.shape != self.module.weight.shape: + Q = Q.reshape(self.module.weight.shape).type_as(self.module.weight.data) + else: + Q = Q.type_as(self.module.weight.data) + + # Q = Q.to(device=DEVICE_1) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + + duration = time.time() - start + + return Q, scale, zero, g_idx, duration, avg_loss, damp, self.nsamples + + def free(self): + super().free() + + if hasattr(self, 'dXXT'): + del self.dXXT + + +__all__ = ["FOEM"] diff --git a/gptqmodel/quantization/gar.py b/gptqmodel/quantization/gar.py index a09572c14..3756c2984 100644 --- a/gptqmodel/quantization/gar.py +++ b/gptqmodel/quantization/gar.py @@ -58,7 +58,7 @@ def compute_local_perms( H = diag_H[: num_groups * groupsize].view(num_groups, groupsize) # CUDA `topk` outperforms `argsort`/`sort` for the typical - # group sizes (<=192) used by GPTQModel while keeping identical ordering. + # group sizes (<=192) used by GPT-QModel while keeping identical ordering. use_topk = diag_H.is_cuda and groupsize <= 192 and groupsize > 0 if use_topk: values, indices = torch.topk(H, k=groupsize, dim=1, largest=True, sorted=True) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index fa40c03b7..4c4d59db1 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -21,11 +21,11 @@ from ..looper.named_module import NamedModule from ..quantization import QuantizeConfig -from ..quantization.config import FailSafeStrategy, SmoothMSE +from ..quantization.config import FallbackStrategy, SmoothMSE from ..utils.device import get_device from ..utils.logger import setup_logger from ..utils.torch import torch_sync -from .failsafe_smooth import mse_optimal_quant, smooth_block +from .fallback_smooth import mse_optimal_quant, smooth_block from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm from .quantizer import HF_OPTIMUM, Quantizer @@ -136,6 +136,17 @@ def get_number_of_rows_and_cols(layer: nn.Module): class GPTQ: + @staticmethod + def resolve_module_source(module: nn.Module) -> nn.Module: + """Resolve the dense module view GPTQ should quantize for one wrapper.""" + + if isinstance(module, NamedModule): + quant_source = module.state.get("quant_source_module") + if isinstance(quant_source, nn.Module): + return quant_source + return module.module + return module + def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): self.lock = threading.Lock() @@ -149,14 +160,15 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): # self.issue_non_invertible = False # self.W = module.weight - self.rows, self.columns = get_number_of_rows_and_cols(module) + resolved_module = self.resolve_module_source(module) + self.rows, self.columns = get_number_of_rows_and_cols(resolved_module) if isinstance(module, NamedModule): - self.module = module.module + self.module = resolved_module self.name = module.name self._named_module = module else: self.name = HF_OPTIMUM - self.module = module + self.module = resolved_module self._named_module = None self._original_rows = self.rows @@ -199,7 +211,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): # fwd counter self.fwd_counter = 0 - self.failsafe = self.qcfg.failsafe + self.fallback = self.qcfg.fallback self.expected_nsamples: Optional[float] = None self.H: Optional[torch.Tensor] = None @@ -254,6 +266,17 @@ def mock_hessian_inverse(self, H: torch.Tensor): identity = torch.eye(H.shape[0], dtype=torch.float32, device=H.device) return identity, damp + def log_cpu_fallback(self, stage: str, source_device: torch.device) -> None: + """Explain when a memory-heavy GPTQ step moves from CUDA to CPU.""" + + log.warn( + "Quantization: Module `%s` -> CUDA OOM during %s on %s; falling back to CPU. " + "Due to this fallback, the calculation may take much longer than normal.", + self.name, + stage, + source_device, + ) + def clone_module(self, copy=True, device: torch.device = None): if not device: device = self.module.weight.data.device @@ -618,13 +641,13 @@ def create_H(self, target_device): return torch.zeros((self.columns, self.columns), dtype=torch.float32, device=self._select_hessian_target_device(target_device)) - def _failsafe_quantize(self, strategy: FailSafeStrategy, blocksize: int): + def _fallback_quantize(self, strategy: FallbackStrategy, blocksize: int): """Apply a lightweight quantization fallback using the requested strategy.""" maxq = 2 ** self.qcfg.bits - 1 sigma = 3.0 effective_group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns start_time = time.time() - smooth_method = getattr(self.failsafe, "smooth", None) + smooth_method = getattr(self.fallback, "smooth", None) mse_steps = 32 mse_maxshrink = 0.8 if isinstance(smooth_method, SmoothMSE): @@ -652,10 +675,10 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy, blocksize: int): else: block_mod, scale_factor = smooth_block( block, - self.failsafe, + self.fallback, group_size=effective_group_size, ) - if strategy == FailSafeStrategy.MIDPOINT: + if strategy == FallbackStrategy.MIDPOINT: w_min = block_mod.min(dim=1, keepdim=True).values w_max = block_mod.max(dim=1, keepdim=True).values mid = (w_max + w_min) / 2.0 @@ -666,7 +689,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy, blocksize: int): zero = torch.round(zero_mid - (mid / scale)) zero = torch.clamp(zero, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.MEAN: + elif strategy == FallbackStrategy.MEAN: mean = block_mod.mean(dim=1, keepdim=True) max_dev = torch.max((block_mod - mean).abs(), dim=1, keepdim=True).values max_dev = torch.clamp(max_dev, min=1e-8) @@ -677,7 +700,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy, blocksize: int): zero = torch.round(zero_mid - (mean / scale)) zero = torch.clamp(zero, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.MEDIAN: + elif strategy == FallbackStrategy.MEDIAN: median = block_mod.median(dim=1, keepdim=True).values max_dev = torch.max((block_mod - median).abs(), dim=1, keepdim=True).values max_dev = torch.clamp(max_dev, min=1e-8) @@ -688,7 +711,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy, blocksize: int): zero = torch.round(zero_mid - (median / scale)) zero = torch.clamp(zero, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.STDCLIP: + elif strategy == FallbackStrategy.STDCLIP: mean = block_mod.mean(dim=1, keepdim=True) std = block_mod.std(dim=1, keepdim=True, unbiased=False) std = torch.clamp(std, min=1e-8) @@ -700,13 +723,13 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy, blocksize: int): q = torch.round(block_mod / scale + zero) q = torch.clamp(q, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.RTN: + elif strategy == FallbackStrategy.RTN: self.quantizer.find_params(block_mod, weight=True) dequant = self.quantizer.quantize(block_mod) scale = self.quantizer.scale zero = self.quantizer.zero else: - raise ValueError(f"Unsupported failsafe strategy: {strategy}") + raise ValueError(f"Unsupported fallback strategy: {strategy}") if scale_factor is not None: scale = scale * scale_factor @@ -748,7 +771,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy, blocksize: int): Q = Q.to(device=self.module.weight.data.device, non_blocking=False) mean_abs_err = (Q - self.module.weight.data).abs().mean().item() duration = time.time() - start_time - avg_loss = f"failsafe({strategy.value}): {mean_abs_err:.7f}" + avg_loss = f"fallback({strategy.value}): {mean_abs_err:.7f}" damp = 0.0 self.H = None @@ -844,7 +867,7 @@ def hessian_inverse(self, H: torch.Tensor): f"(started at {recovery_initial_damp:.5f})." ) return Hinv_result, used_damp - except torch._C._LinAlgError as e: + except (torch._C._LinAlgError, RuntimeError) as e: last_error = e diag_view.copy_(current_diag) if self.qcfg.damp_auto_increment != 0: @@ -886,28 +909,30 @@ def quantize( start = time.time() target_device = getattr(self.module, "target_device", None) - from ..utils.failsafe import resolve_failsafe_strategy, resolve_threshold, should_use_failsafe + result_device = torch.device(self.module.weight.data.device) + cpu_fallback_used = False + from ..utils.fallback import resolve_fallback_strategy, resolve_threshold, should_use_fallback - resolved_strategy = resolve_failsafe_strategy(self.failsafe) - fallback_requested = should_use_failsafe( - self.failsafe, + resolved_strategy = resolve_fallback_strategy(self.fallback) + fallback_requested = should_use_fallback( + self.fallback, float(self.nsamples), self.expected_nsamples, ) - threshold_raw, is_percent = resolve_threshold(self.failsafe, self.expected_nsamples) - failsafe_configured = threshold_raw is not None + threshold_raw, is_percent = resolve_threshold(self.fallback, self.expected_nsamples) + fallback_configured = threshold_raw is not None if fallback_requested: use_hessian = False - threshold_text = str(getattr(self.failsafe, "threshold", None)) + threshold_text = str(getattr(self.fallback, "threshold", None)) threshold_info = f", threshold_raw={threshold_raw}" if threshold_raw is not None and is_percent else "" log.warn( f"Quantization: Module `{self.name}` -> " - f"Using `{resolved_strategy.value}` failsafe quantization (observed {self.nsamples} samples, threshold={threshold_text}{threshold_info}, max_total={self.expected_nsamples})." + f"Using `{resolved_strategy.value}` fallback quantization (observed {self.nsamples} samples, threshold={threshold_text}{threshold_info}, max_total={self.expected_nsamples})." ) self.H = self.create_H(target_device=target_device) - return self._failsafe_quantize(resolved_strategy, blocksize) + return self._fallback_quantize(resolved_strategy, blocksize) else: use_hessian = True self.finalize_hessian(target_device=target_device) @@ -940,6 +965,8 @@ def quantize( activation_importance = None if use_hessian: + # Replace NaN/Inf in H before processing (can occur with some model architectures) + self.H.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0) dead = torch.diag(self.H) == 0 self.H[dead, dead] = 1 W[:, dead] = 0 @@ -977,10 +1004,23 @@ def quantize( if self.qcfg.desc_act and use_hessian: perm = torch.argsort(torch.diag(self.H), descending=True) - W = W[:, perm] - self.H = self.H[perm][:, perm] - if activation_importance is not None: - activation_importance = activation_importance[perm] + try: + W = W[:, perm] + self.H = self.H[perm][:, perm] + if activation_importance is not None: + activation_importance = activation_importance[perm] + except RuntimeError as exc: + if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower(): + raise + + self.log_cpu_fallback("Hessian permutation", self.H.device) + cpu_fallback_used = True + cpu_device = torch.device("cpu") + perm = perm.to(device=cpu_device) + W = W.to(device=cpu_device)[:, perm] + self.H = self.H.to(device=cpu_device)[perm][:, perm] + if activation_importance is not None: + activation_importance = activation_importance.to(device=cpu_device)[perm] invperm = torch.argsort(perm) elif self.qcfg.act_group_aware and use_hessian: @@ -995,21 +1035,50 @@ def quantize( ) del local_values final_perm = compose_final_perm(local_perms, global_perm, self.qcfg.group_size) - W = W[:, final_perm] - self.H = self.H[final_perm][:, final_perm] - if activation_importance is not None: - activation_importance = activation_importance[final_perm] + try: + W = W[:, final_perm] + self.H = self.H[final_perm][:, final_perm] + if activation_importance is not None: + activation_importance = activation_importance[final_perm] + except RuntimeError as exc: + if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower(): + raise + + self.log_cpu_fallback("act-group Hessian permutation", self.H.device) + cpu_fallback_used = True + cpu_device = torch.device("cpu") + final_perm = final_perm.to(device=cpu_device) + W = W.to(device=cpu_device)[:, final_perm] + self.H = self.H.to(device=cpu_device)[final_perm][:, final_perm] + if activation_importance is not None: + activation_importance = activation_importance.to(device=cpu_device)[final_perm] self.quantizer.find_params(W, weight=True, importance=activation_importance) - Losses = torch.zeros_like(W) - Q = torch.zeros_like(W) - if use_hessian: - Hinv, damp = self.hessian_inverse(self.H) + try: + Hinv, damp = self.hessian_inverse(self.H) + except RuntimeError as exc: + if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower(): + raise + + # Full-attention blocks on very large models can exceed GPU memory during the + # dense Hessian inverse; finish that module on CPU instead of aborting the run. + self.log_cpu_fallback("Hessian inverse", self.H.device) + cpu_fallback_used = True + cpu_device = torch.device("cpu") + self.H = self.H.to(device=cpu_device) + W = W.to(device=cpu_device) + if activation_importance is not None: + activation_importance = activation_importance.to(device=cpu_device) + self.quantizer.find_params(W, weight=True, importance=activation_importance) + Hinv, damp = self.hessian_inverse(self.H) else: Hinv, damp = None, 0.0 + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + # Use simplified loop when mock_quantization is active if self.qcfg.mock_quantization: for i1 in range(0, self.columns, blocksize): @@ -1192,20 +1261,20 @@ def quantize( if math.isnan(avg_loss): print("Losses sum item:", torch.sum(Losses).item()) - if failsafe_configured: + if fallback_configured: log.info(f"Quantization: Failed due to `NaN` loss for `{self.name}`, use mock quantization retry for `{self.name}`") self.qcfg.mock_quantization = True return self.quantize(blocksize=blocksize) else: - raise ValueError(f"Quantization: Failed due to `NaN` loss for `{self.name}`, please try increasing calibration data samples or enable failsafe=True") + raise ValueError(f"Quantization: Failed due to `NaN` loss for `{self.name}`, please try increasing calibration data samples or enable fallback=True") else: - if failsafe_configured: + if fallback_configured: log.warn(f"Quantization: Module `{self.name}` -> using fail safe mode. Please check if calibration data is sufficient.") else: log.warn(f"Quantization: `{self.name}` is not activated due to model inference logic (MoE)") - avg_loss = f"{resolved_strategy.value} failsafe" if failsafe_configured else 999999999 + avg_loss = f"{resolved_strategy.value} fallback" if fallback_configured else 999999999 else: - avg_loss = f"{resolved_strategy.value} failsafe" if failsafe_configured else 999999999 + avg_loss = f"{resolved_strategy.value} fallback" if fallback_configured else 999999999 del Losses del self.H @@ -1221,12 +1290,13 @@ def quantize( g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device) if self.qcfg.desc_act and use_hessian: + invperm = invperm.to(device=Q.device) Q = Q[:, invperm] g_idx = g_idx[invperm] del perm, invperm elif self.qcfg.act_group_aware and use_hessian: - inv_final = invert_perm(final_perm) + inv_final = invert_perm(final_perm).to(device=Q.device) Q = Q[:, inv_final] inv_global_perm = invert_perm(global_perm) inv_global_perm_list = inv_global_perm.tolist() @@ -1261,7 +1331,14 @@ def quantize( scale = self.truncate_last_dim(scale, valid_cols) zero = self.truncate_last_dim(zero, valid_cols) - Q = Q.to(device=self.module.weight.data.device, non_blocking=False) + if cpu_fallback_used and Q.device != result_device: + log.info( + "Quantization: Module `%s` -> CPU fallback complete; moving final quantized weights back to %s.", + self.name, + result_device, + ) + + Q = Q.to(device=result_device, non_blocking=False) duration = time.time() - start diff --git a/gptqmodel/quantization/paroquant/__init__.py b/gptqmodel/quantization/paroquant/__init__.py new file mode 100644 index 000000000..057fc176f --- /dev/null +++ b/gptqmodel/quantization/paroquant/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 diff --git a/gptqmodel/quantization/paroquant/modules/__init__.py b/gptqmodel/quantization/paroquant/modules/__init__.py new file mode 100644 index 000000000..057fc176f --- /dev/null +++ b/gptqmodel/quantization/paroquant/modules/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 diff --git a/gptqmodel/quantization/paroquant/modules/triton/__init__.py b/gptqmodel/quantization/paroquant/modules/triton/__init__.py new file mode 100644 index 000000000..057fc176f --- /dev/null +++ b/gptqmodel/quantization/paroquant/modules/triton/__init__.py @@ -0,0 +1,2 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 diff --git a/gptqmodel/quantization/paroquant/modules/triton/gemm.py b/gptqmodel/quantization/paroquant/modules/triton/gemm.py new file mode 100644 index 000000000..583cc459b --- /dev/null +++ b/gptqmodel/quantization/paroquant/modules/triton/gemm.py @@ -0,0 +1,359 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Copied/adapted from the AWQ Triton kernels used in vLLM and GPT-QModel. +# +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +from gptqmodel.utils.env import env_flag + + +PAROQUANT_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +# Shared runtime default: fp32 accumulation trades a little speed for lower numerical drift. +FP32_ACCUM = env_flag("GPTQMODEL_FP32_ACCUM", default=True) + + +def get_same_device_cm(t): + if t.device.type == "xpu": + return torch.xpu.device(t.device.index) + return torch.cuda.device(t.device.index) + + +@triton.jit +def paroquant_dequantize_kernel( + qweight_ptr, + scales_ptr, + zeros_ptr, + group_size, + result_ptr, + num_cols, + num_rows, + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr, +): + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + masks = masks_y[:, None] & masks_x[None, :] + + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + result_offsets = 8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :] + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + + reverse_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]).reshape(8) + shifts = reverse_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + iweights = (iweights >> shifts) & 0xF + + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + zeros = (zeros >> shifts) & 0xF + + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8) + scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :] + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def paroquant_gemm_kernel( + a_ptr, + b_ptr, + c_ptr, + zeros_ptr, + scales_ptr, + M, + N, + K, + group_size, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + USE_FP32_ACCUM: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + if USE_FP32_ACCUM: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + else: + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=c_ptr.type.element_ty) + + reverse_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]).reshape(8) + shifts = reverse_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offsets_k = tl.arange(0, BLOCK_SIZE_K) + + masks_am = offsets_am < M + masks_bn = offsets_bn < N // 8 + masks_zn = offsets_zn < N // 8 + masks_sn = offsets_sn < N + + a_ptrs = a_ptr + K * offsets_am[:, None] + offsets_k[None, :] + b_ptrs = b_ptr + (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a, other=0.0) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b, other=0) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + + offsets_szk = k * BLOCK_SIZE_K // group_size + tl.arange(0, 1) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros = tl.load(zeros_ptr + offsets_z, mask=masks_z, other=0) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_s = masks_zk[:, None] & masks_sn[None, :] + scales = tl.load(scales_ptr + offsets_s, mask=masks_s, other=0.0) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = ((b - zeros) * scales).to(a.dtype) + + if USE_FP32_ACCUM: + accumulator = tl.dot(a, b, accumulator, out_dtype=tl.float32) + else: + accumulator = tl.dot(a, b, accumulator, out_dtype=c_ptr.type.element_ty) + + offsets_k += BLOCK_SIZE_K + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def _validate_shapes( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, +): + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert group_size <= K + assert group_size in PAROQUANT_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + return M, N, K, group_size + + +def paroquant_dequantize_triton( + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32, +) -> torch.Tensor: + K = qweight.shape[0] + N = scales.shape[1] + group_size = qweight.shape[0] // qzeros.shape[0] + + assert K > 0 and N > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert group_size <= K + assert group_size in PAROQUANT_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + result = torch.empty( + qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype, + ) + + y = qweight.shape[0] + x = qweight.shape[1] + + def grid(meta): + return ( + triton.cdiv(x, meta["BLOCK_SIZE_X"]), + triton.cdiv(y, meta["BLOCK_SIZE_Y"]), + ) + + with get_same_device_cm(qweight): + paroquant_dequantize_kernel[grid]( + qweight, + scales, + qzeros, + group_size, + result, + x, + y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y, + ) + + return result + + +def _paroquant_gemm_triton( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + *, + block_size_m: int, + block_size_n: int, + block_size_k: int, + num_warps: int, + num_stages: int, + fp32_accum: bool = FP32_ACCUM, +) -> torch.Tensor: + M, N, K, group_size = _validate_shapes(input, qweight, scales, qzeros) + + def grid(meta): + return (triton.cdiv(M, meta["BLOCK_SIZE_M"]) * triton.cdiv(N, meta["BLOCK_SIZE_N"]),) + + result = torch.empty((M, N), dtype=input.dtype, device=input.device) + + with get_same_device_cm(qweight): + paroquant_gemm_kernel[grid]( + input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + USE_FP32_ACCUM=fp32_accum, + num_warps=num_warps, + num_stages=num_stages, + ) + + return result + + +def paroquant_gemm_triton_decode( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, +) -> torch.Tensor: + return _paroquant_gemm_triton( + input, + qweight, + scales, + qzeros, + block_size_m=4, + block_size_n=128, + block_size_k=32, + num_warps=4, + num_stages=2, + fp32_accum=FP32_ACCUM, + ) + + +def paroquant_gemm_triton_prefill( + input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, +) -> torch.Tensor: + return _paroquant_gemm_triton( + input, + qweight, + scales, + qzeros, + block_size_m=32, + block_size_n=128, + block_size_k=32, + num_warps=8, + num_stages=4, + fp32_accum=FP32_ACCUM, + ) + + +__all__ = [ + "paroquant_dequantize_triton", + "paroquant_gemm_triton_decode", + "paroquant_gemm_triton_prefill", +] diff --git a/gptqmodel/quantization/paroquant/optimization.py b/gptqmodel/quantization/paroquant/optimization.py new file mode 100644 index 000000000..c9f3b00ed --- /dev/null +++ b/gptqmodel/quantization/paroquant/optimization.py @@ -0,0 +1,1910 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# ParoQuant optimization implementation adapted from the ParoQuant paper and +# public project: +# https://arxiv.org/html/2511.10645v2 +# https://github.com/z-lab/paroquant + +"""ParoQuant calibration-time optimization utilities. + +This module implements the paper's transformed-domain PTQ lifecycle in a +direct way: +1. learn channel scales and Givens-rotation angles on calibration activations +2. initialize and optimize quantization parameters in the transformed domain +3. export packed runtime tensors that reproduce the pseudo-quantized layer +""" + +from __future__ import annotations + +import math +import random +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from functools import lru_cache +from typing import Iterable, Literal, Optional, Sequence + +import torch +import torch.nn.functional as F +from torch import nn +from torch.utils.checkpoint import checkpoint as torch_checkpoint + +from ...utils.env import env_flag +from ...utils.paroquant import ( + apply_paroquant_rotation, + apply_paroquant_rotation_autograd, + build_identity_rotation_buffers, +) +from ..config import PAROQUANT_OPT_SCALE_CLAMP_MAX_DEFAULT, PAROQUANT_OPT_SCALE_CLAMP_MIN_DEFAULT + + +_PAROQUANT_STAGE_PAIR_IMPLS: tuple[str, ...] = ("fast", "reference") +_PAROQUANT_QUANTIZER_IMPLS: tuple[str, ...] = ("fast", "reference") +_PAROQUANT_OPTIMIZERS: tuple[str, ...] = ("adamw", "adam", "sgd") +# Best-state snapshots are a separate memory policy from the always-fp32 live optimization path. +_PAROQUANT_BEST_STATE_DTYPES: tuple[str, ...] = ("fp16", "bf16", "fp32") +_PAROQUANT_LARGE_TRAIN_QUANT_COMPILE_MIN_NUMEL = 8_000_000 + + +def _normalize_opt_impl(name: str, *, field: str) -> str: + normalized = str(name).strip().lower() + if normalized not in _PAROQUANT_STAGE_PAIR_IMPLS: + raise ValueError( + f"ParoQuant optimization: `{field}` must be one of {_PAROQUANT_STAGE_PAIR_IMPLS}, got `{name}`." + ) + return normalized + + +def _normalize_quantizer_impl(name: str) -> str: + normalized = str(name).strip().lower() + if normalized not in _PAROQUANT_QUANTIZER_IMPLS: + raise ValueError( + "ParoQuant optimization: `quantizer_impl` must be one of " + f"{_PAROQUANT_QUANTIZER_IMPLS}, got `{name}`." + ) + return normalized + + +def _normalize_opt_optimizer(name: str) -> str: + normalized = str(name).strip().lower() + if normalized not in _PAROQUANT_OPTIMIZERS: + raise ValueError( + "ParoQuant optimization: `optimizer_name` must be one of " + f"{_PAROQUANT_OPTIMIZERS}, got `{name}`." + ) + return normalized + + +def _normalize_best_state_dtype_name(best_state_dtype: Optional[str | torch.dtype]) -> str: + """Normalize the requested best-state snapshot dtype into one of the supported policy names.""" + if best_state_dtype is None: + return "fp32" + if isinstance(best_state_dtype, str): + normalized = best_state_dtype.strip().lower() + if normalized in {"fp16", "float16"}: + return "fp16" + if normalized in {"bf16", "bfloat16"}: + return "bf16" + if normalized in {"fp32", "float32"}: + return "fp32" + elif isinstance(best_state_dtype, torch.dtype): + if best_state_dtype == torch.float16: + return "fp16" + if best_state_dtype == torch.bfloat16: + return "bf16" + if best_state_dtype == torch.float32: + return "fp32" + raise ValueError( + "ParoQuant optimization: `best_state_dtype` must be one of " + f"{_PAROQUANT_BEST_STATE_DTYPES} or torch.float16/torch.bfloat16/torch.float32." + ) + + +def _resolve_best_state_snapshot_dtype( + *, + best_state_dtype: Optional[str | torch.dtype], + device: torch.device, +) -> torch.dtype: + """Resolve the best-state snapshot dtype policy for the target snapshot device.""" + del device + normalized = _normalize_best_state_dtype_name(best_state_dtype) + if normalized == "fp16": + return torch.float16 + if normalized == "bf16": + return torch.bfloat16 + return torch.float32 + + +def _snapshot_state_tensor( + tensor: torch.Tensor, + *, + target_device: Optional[torch.device] = None, + target_dtype: Optional[torch.dtype] = None, +) -> torch.Tensor: + """Clone a state tensor, optionally moving it and casting floating tensors for compact snapshots.""" + snapshot = tensor.detach() + cast_dtype = target_dtype if target_dtype is not None and snapshot.is_floating_point() else None + needs_move = target_device is not None and snapshot.device != target_device + needs_cast = cast_dtype is not None and snapshot.dtype != cast_dtype + if needs_move or needs_cast: + snapshot = snapshot.to( + device=target_device if target_device is not None else snapshot.device, + dtype=cast_dtype if cast_dtype is not None else snapshot.dtype, + ) + return snapshot.clone() + + +def _snapshot_model_state( + model: nn.Module, + *, + target_device: Optional[torch.device] = None, + target_dtype: Optional[torch.dtype] = None, +) -> dict[str, torch.Tensor]: + """Capture a model state dict with optional float-only dtype compression for best-state snapshots.""" + return { + key: _snapshot_state_tensor(tensor, target_device=target_device, target_dtype=target_dtype) + for key, tensor in model.state_dict().items() + } + + +def _quantizer_sym_for_impl(sym: bool, quantizer_impl: str) -> bool: + impl = _normalize_quantizer_impl(quantizer_impl) + if impl == "reference": + return False + return bool(sym) + + +def _round_ste(x: torch.Tensor) -> torch.Tensor: + """Apply a straight-through round so gradients flow through quantization.""" + return (x.round() - x).detach() + x + + +def _checkpointed_forward(function, *args: torch.Tensor, enabled: bool = False) -> torch.Tensor: + """Recompute the train forward during backward when the stage opts into checkpointing.""" + if not enabled: + return function(*args) + return torch_checkpoint(function, *args, use_reentrant=False) + + +def _clamp_ste( + x: torch.Tensor, + min_value: float | int | None = None, + max_value: float | int | None = None, +) -> torch.Tensor: + """Clamp with a straight-through estimator to stabilize learned qparams.""" + return (x.clamp(min_value, max_value) - x).detach() + x + + +def _normalize_group_size(group_size: int, in_features: int) -> int: + """Validate and normalize a ParoQuant group size for a given hidden width.""" + normalized = in_features if group_size == -1 else int(group_size) + if normalized <= 0: + raise ValueError(f"ParoQuant optimization: invalid group_size `{group_size}` for in_features={in_features}.") + if in_features % normalized != 0: + raise ValueError( + f"ParoQuant optimization: in_features ({in_features}) must be divisible by group_size ({normalized})." + ) + if normalized % 2 != 0: + raise ValueError(f"ParoQuant optimization: group_size ({normalized}) must be even.") + return normalized + + +def _require_paroquant_sym(sym: bool) -> None: + """Reject asymmetric ParoQuant configurations in this implementation.""" + if sym is not True: + raise ValueError("ParoQuant optimization: `sym=False` is disabled; use `sym=True`.") + + +def _select_independent_pairs( + all_pairs: Sequence[tuple[int, int]], + dim: int, + num_rotations: int, + num_pairs_each: int, +) -> list[list[tuple[int, int]]]: + """Choose non-overlapping channel pairs for each rotation step.""" + available = torch.ones(dim, dim, dtype=torch.bool) + available.fill_diagonal_(False) + rotations: list[list[tuple[int, int]]] = [] + + for _ in range(num_rotations): + available_in_rotation = available.clone() + selected: list[tuple[int, int]] = [] + + for i, j in all_pairs: + if len(selected) >= num_pairs_each: + break + if not bool(available_in_rotation[i, j]): + continue + + selected.append((i, j)) + available_in_rotation[i, :] = False + available_in_rotation[j, :] = False + available_in_rotation[:, i] = False + available_in_rotation[:, j] = False + available[i, j] = False + available[j, i] = False + + rotations.append(selected) + + return rotations + + +def _pad_rotation_group( + selected_pairs: Sequence[tuple[int, int]], + group_size: int, + *, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad a sparse rotation schedule with dummy identity pairs.""" + half_group = group_size // 2 + pairs = torch.zeros((half_group, 2), dtype=torch.int16, device=device) + mask = torch.zeros((half_group,), dtype=torch.bool, device=device) + used = torch.zeros((group_size,), dtype=torch.bool, device=device) + + count = 0 + for i, j in selected_pairs: + if count >= half_group: + break + pairs[count, 0] = int(i) + pairs[count, 1] = int(j) + used[i] = True + used[j] = True + count += 1 + + if count == half_group: + return pairs, mask + + remaining = [idx for idx in range(group_size) if not bool(used[idx])] + if len(remaining) % 2 != 0: + raise ValueError(f"ParoQuant optimization: unable to pad group of size {group_size}.") + + remaining_iter = iter(remaining) + while count < half_group: + try: + i = next(remaining_iter) + j = next(remaining_iter) + except StopIteration as exc: + raise ValueError(f"ParoQuant optimization: incomplete dummy-pair padding for group size {group_size}.") from exc + pairs[count, 0] = int(i) + pairs[count, 1] = int(j) + mask[count] = True + count += 1 + + return pairs, mask + + +def _build_random_rotation_buffers_cpu_legacy( + *, + in_features: int, + group_size: int, + krot: int, + pair_ratio: float, + seed: int, +) -> tuple[torch.Tensor, torch.Tensor]: + normalized_group_size = _normalize_group_size(group_size, in_features) + if krot <= 0: + raise ValueError(f"ParoQuant optimization: `krot` must be positive, got {krot}.") + if not (0.0 < float(pair_ratio) <= 0.5): + raise ValueError("ParoQuant optimization: `pair_ratio` must be in the interval (0, 0.5].") + + rng = random.Random(int(seed)) + num_groups = in_features // normalized_group_size + num_pairs_each = max(1, int(normalized_group_size * float(pair_ratio))) + num_pairs_each = min(num_pairs_each, normalized_group_size // 2) + + rotation_rows: list[torch.Tensor] = [] + mask_rows: list[torch.Tensor] = [] + + for _ in range(krot): + rotation_rows.append(torch.empty(0, dtype=torch.int16, device=torch.device("cpu"))) + mask_rows.append(torch.empty(0, dtype=torch.bool, device=torch.device("cpu"))) + + for _ in range(num_groups): + group_pairs = [(i, j) for i in range(normalized_group_size) for j in range(i + 1, normalized_group_size)] + rng.shuffle(group_pairs) + selected_per_rotation = _select_independent_pairs( + group_pairs, + normalized_group_size, + krot, + num_pairs_each, + ) + + for rot_idx in range(krot): + padded_pairs, mask = _pad_rotation_group( + selected_per_rotation[rot_idx], + normalized_group_size, + device=torch.device("cpu"), + ) + rotation_rows[rot_idx] = torch.cat((rotation_rows[rot_idx], padded_pairs.reshape(-1)), dim=0) + mask_rows[rot_idx] = torch.cat((mask_rows[rot_idx], mask), dim=0) + + pairs = torch.stack(rotation_rows, dim=0).contiguous() + masks = torch.stack(mask_rows, dim=0).contiguous() + return pairs, masks + + +@lru_cache(maxsize=32) +def _round_robin_pair_template(group_size: int) -> torch.Tensor: + """Cache one-factorized full matchings for an even group size.""" + if group_size <= 0 or group_size % 2 != 0: + raise ValueError(f"ParoQuant optimization: group_size ({group_size}) must be a positive even integer.") + + players = list(range(group_size)) + half_group = group_size // 2 + rounds: list[torch.Tensor] = [] + for _ in range(group_size - 1): + round_pairs = torch.empty((half_group, 2), dtype=torch.long) + for pair_idx in range(half_group): + round_pairs[pair_idx, 0] = players[pair_idx] + round_pairs[pair_idx, 1] = players[group_size - 1 - pair_idx] + rounds.append(round_pairs) + players = [players[0], players[-1], *players[1:-1]] + return torch.stack(rounds, dim=0).contiguous() + + +def _build_random_rotation_buffers_cpu( + *, + in_features: int, + group_size: int, + krot: int, + pair_ratio: float, + seed: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build randomized pair schedules using cached round-robin matchings plus per-group permutations.""" + normalized_group_size = _normalize_group_size(group_size, in_features) + if krot <= 0: + raise ValueError(f"ParoQuant optimization: `krot` must be positive, got {krot}.") + if not (0.0 < float(pair_ratio) <= 0.5): + raise ValueError("ParoQuant optimization: `pair_ratio` must be in the interval (0, 0.5].") + if krot > normalized_group_size - 1: + return _build_random_rotation_buffers_cpu_legacy( + in_features=in_features, + group_size=normalized_group_size, + krot=krot, + pair_ratio=pair_ratio, + seed=seed, + ) + + half_group = normalized_group_size // 2 + num_groups = in_features // normalized_group_size + num_pairs_each = max(1, int(normalized_group_size * float(pair_ratio))) + num_pairs_each = min(num_pairs_each, half_group) + template = _round_robin_pair_template(normalized_group_size) + generator = torch.Generator(device="cpu") + generator.manual_seed(int(seed)) + + pair_rows = torch.empty((krot, in_features), dtype=torch.int16) + mask_rows = torch.zeros((krot, in_features // 2), dtype=torch.bool) + if num_pairs_each < half_group: + for group_idx in range(num_groups): + start = group_idx * half_group + mask_rows[:, start + num_pairs_each:start + half_group] = True + + for group_idx in range(num_groups): + round_order = torch.randperm(template.shape[0], generator=generator)[:krot] + local_template = template.index_select(0, round_order) + perm = torch.randperm(normalized_group_size, generator=generator) + group_pairs = perm[local_template].to(dtype=torch.int16).reshape(krot, normalized_group_size) + start = group_idx * normalized_group_size + pair_rows[:, start:start + normalized_group_size] = group_pairs + + return pair_rows.contiguous(), mask_rows.contiguous() + + +def build_random_rotation_buffers( + *, + in_features: int, + group_size: int, + krot: int, + pair_ratio: float, + seed: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """Build randomized pair schedules and masks for ParoQuant angle learning.""" + if krot <= 0: + raise ValueError(f"ParoQuant optimization: `krot` must be positive, got {krot}.") + if not (0.0 < float(pair_ratio) <= 0.5): + raise ValueError("ParoQuant optimization: `pair_ratio` must be in the interval (0, 0.5].") + + pairs_cpu, masks_cpu = _build_random_rotation_buffers_cpu( + in_features=in_features, + group_size=group_size, + krot=krot, + pair_ratio=float(pair_ratio), + seed=int(seed), + ) + + if torch.device(device).type == "cuda": + return pairs_cpu.to(device=device), masks_cpu.to(device=device) + return pairs_cpu, masks_cpu + + +def _get_independent_channel_pairs_reference( + pairs: torch.Tensor, + dim: int, + num_rotations: int, + num_pairs_each: int, +) -> list[list[tuple[int, int]]]: + pairs_cpu = pairs.cpu().tolist() + rotations_pairs: list[list[tuple[int, int]]] = [] + available = torch.ones(dim, dim) + available.fill_diagonal_(0) + + for _ in range(num_rotations): + independent_pairs: list[tuple[int, int]] = [] + available_in_rotation = available.clone() + for i, j in pairs_cpu: + if len(independent_pairs) == num_pairs_each: + break + if available_in_rotation[i, j] == 0: + continue + independent_pairs.append((i, j)) + available_in_rotation[i, :] = 0 + available_in_rotation[j, :] = 0 + available_in_rotation[:, i] = 0 + available_in_rotation[:, j] = 0 + available[i, j] = 0 + available[j, i] = 0 + rotations_pairs.append(independent_pairs) + return rotations_pairs + + +def _align_pairs_to_kernel_shape_reference( + pair: torch.Tensor, + angle: torch.Tensor, + *, + group_size: int, + include_mask: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]: + if pair.size(0) != angle.size(0) or pair.size(1) != 2: + raise ValueError("ParoQuant optimization(reference): pair/angle shape mismatch.") + + group_idx = 0 + pair_ptr = 0 + pair_groups: list[torch.Tensor] = [] + angle_groups: list[torch.Tensor] = [] + mask_groups: list[torch.Tensor] = [] + + while True: + if pair_ptr >= pair.size(0): + break + occupied = torch.zeros((group_size), dtype=torch.int32) + count = 0 + temp_pairs = torch.zeros((group_size // 2, 2), dtype=torch.int32, device=pair.device) + temp_angle = torch.zeros((group_size // 2), dtype=torch.float, device=angle.device) + temp_mask = torch.zeros((group_size // 2), dtype=torch.int32, device=angle.device) + while count < group_size // 2: + if ( + pair_ptr < pair.size(0) + and pair[pair_ptr, 0] - group_idx * group_size < group_size + and pair[pair_ptr, 1] - group_idx * group_size < group_size + ): + temp_pairs[count, :] = pair[pair_ptr, :] + temp_angle[count] = angle[pair_ptr] + if occupied[pair[pair_ptr, 0] % group_size] == 1 or occupied[pair[pair_ptr, 1] % group_size] == 1: + raise ValueError("ParoQuant optimization(reference): illegal pair.") + occupied[pair[pair_ptr, :] % group_size] = 1 + pair_ptr += 1 + else: + t_pair = torch.tensor([-1, -1]) + for i in range(group_size): + if occupied[i] == 0: + t_pair[0] = i + occupied[i] = 1 + break + for i in range(group_size): + if occupied[i] == 0: + t_pair[1] = i + occupied[i] = 1 + break + if t_pair[0] == -1 or t_pair[1] == -1: + raise ValueError("ParoQuant optimization(reference): unable to find dummy pair.") + temp_pairs[count, :] = t_pair + temp_angle[count] = float(0) + temp_mask[count] = 1 + count += 1 + group_idx += 1 + pair_groups.append(temp_pairs) + angle_groups.append(temp_angle) + mask_groups.append(temp_mask) + + rotation_pairs = torch.cat(pair_groups, dim=0).view(-1).contiguous() % group_size + angles = torch.cat(angle_groups, dim=0) + masks = torch.cat(mask_groups, dim=0) + if include_mask: + return rotation_pairs, angles, masks + return rotation_pairs, angles, None + + +def build_random_rotation_buffers_reference( + *, + in_features: int, + group_size: int, + krot: int, + pair_ratio: float, + seed: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + normalized_group_size = _normalize_group_size(group_size, in_features) + if krot <= 0: + raise ValueError(f"ParoQuant optimization(reference): `krot` must be positive, got {krot}.") + if not (0.0 < float(pair_ratio) <= 0.5): + raise ValueError("ParoQuant optimization(reference): `pair_ratio` must be in the interval (0, 0.5].") + + rng = random.Random(int(seed)) + group_num = in_features // normalized_group_size + num_pairs_per_group = int(normalized_group_size * float(pair_ratio)) + pairs_by_rotation: list[list[tuple[int, int]]] = [[] for _ in range(krot)] + + for group_idx in range(group_num): + all_pairs = [(i, j) for i in range(normalized_group_size) for j in range(i + 1, normalized_group_size)] + rng.shuffle(all_pairs) + selected_by_rotation = _get_independent_channel_pairs_reference( + torch.tensor(all_pairs), + normalized_group_size, + krot, + num_pairs_per_group, + ) + offset = group_idx * normalized_group_size + for rotation_idx in range(krot): + for col1, col2 in selected_by_rotation[rotation_idx]: + pairs_by_rotation[rotation_idx].append((col1 + offset, col2 + offset)) + + pair_tensors = [torch.tensor(pairs, dtype=torch.int32, device=device) for pairs in pairs_by_rotation] + angle_tensors = [torch.zeros((pairs.shape[0],), dtype=torch.float32, device=device) for pairs in pair_tensors] + + aligned_pairs: list[torch.Tensor] = [] + aligned_masks: list[torch.Tensor] = [] + for pair_tensor, angle_tensor in zip(pair_tensors, angle_tensors): + pair, _angle, mask = _align_pairs_to_kernel_shape_reference( + pair_tensor, + angle_tensor, + group_size=normalized_group_size, + include_mask=True, + ) + aligned_pairs.append(pair.to(dtype=torch.int16)) + aligned_masks.append(mask.to(dtype=torch.bool)) + + return torch.stack(aligned_pairs, dim=0).contiguous(), torch.stack(aligned_masks, dim=0).contiguous() + + +def _sample_activation_rows(inputs: torch.Tensor, max_rows: int) -> torch.Tensor: + """Downsample calibration activations to a bounded replay set.""" + rows = inputs.reshape(-1, inputs.shape[-1]) + if rows.shape[0] <= max_rows: + return rows + indices = torch.linspace(0, rows.shape[0] - 1, steps=max_rows, device=rows.device) + indices = torch.round(indices).to(dtype=torch.long) + return rows.index_select(0, indices) + + +def _apply_rotation( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + *, + scales: Optional[torch.Tensor], + group_size: int, + fused_rotation: Optional[bool] = None, +) -> torch.Tensor: + """Apply the forward ParoQuant transform in the optimization domain.""" + if x.dim() != 2: + raise ValueError(f"ParoQuant optimization expects a rank-2 tensor, got {tuple(x.shape)}.") + + use_fused_rotation = ( + env_flag("GPTQMODEL_PAROQUANT_OPT_FUSED_ROTATION", default=True) + if fused_rotation is None + else bool(fused_rotation) + ) + + if use_fused_rotation: + scale_tensor = None if scales is None else scales.view(1, -1) + if not ( + x.requires_grad + or theta.requires_grad + or (scale_tensor is not None and scale_tensor.requires_grad) + ): + return apply_paroquant_rotation( + x, + pairs, + theta, + scales=scale_tensor, + group_size=group_size, + ) + return apply_paroquant_rotation_autograd( + x, + pairs, + theta, + scales=scale_tensor, + group_size=group_size, + ) + + out = x + if scales is not None: + out = out * scales.view(1, -1) + + hidden = out.shape[-1] + normalized_group_size = _normalize_group_size(group_size, hidden) + num_groups = hidden // normalized_group_size + half_group = normalized_group_size // 2 + offsets = torch.arange(num_groups, device=out.device, dtype=torch.long).unsqueeze(1) * normalized_group_size + pair_view = pairs.to(device=out.device, dtype=torch.long).view(pairs.shape[0], num_groups, half_group, 2) + theta_view = theta.to(device=out.device, dtype=out.dtype).view(theta.shape[0], num_groups, half_group) + + for rot_idx in range(pair_view.shape[0]): + idx_i = (pair_view[rot_idx, :, :, 0] + offsets).reshape(-1) + idx_j = (pair_view[rot_idx, :, :, 1] + offsets).reshape(-1) + + xi = out[:, idx_i].view(out.shape[0], num_groups, half_group) + xj = out[:, idx_j].view(out.shape[0], num_groups, half_group) + cos_t = torch.cos(theta_view[rot_idx]).unsqueeze(0) + sin_t = torch.sin(theta_view[rot_idx]).unsqueeze(0) + + next_out = out.clone() + next_out[:, idx_i] = (xi * cos_t + xj * sin_t).reshape(out.shape[0], -1) + next_out[:, idx_j] = (-xi * sin_t + xj * cos_t).reshape(out.shape[0], -1) + out = next_out + + return out + + +def _apply_inverse_rotation( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + *, + group_size: int, + fused_rotation: Optional[bool] = None, +) -> torch.Tensor: + """Apply the inverse transform that maps export-domain weights back to input space.""" + if pairs.shape[0] == 0: + return x + return _apply_rotation( + x, + pairs.flip(0), + -theta.flip(0), + scales=None, + group_size=group_size, + fused_rotation=fused_rotation, + ) + + +def _reshape_group_params(weight: torch.Tensor, group_size: int, values: torch.Tensor) -> torch.Tensor: + """Restore flat per-group quantizer values to the packed weight layout.""" + groups = weight.shape[1] // group_size + return values.view(weight.shape[0], groups) + + +def _calc_affine_qparams( + weight: torch.Tensor, + *, + group_size: int, + bits: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Compute affine quantizer parameters from transformed weights.""" + view = weight.reshape(-1, group_size) + min_val = view.amin(dim=1, keepdim=True) + max_val = view.amax(dim=1, keepdim=True) + qmax = 2**bits - 1 + scale = (max_val - min_val).clamp(min=1e-5) / qmax + zero_point_float = min_val / scale + return scale, zero_point_float + + +class GroupLinearQuantizer(nn.Module): + """Learnable per-group quantizer matching ParoQuant's transformed-domain packing.""" + + def __init__( + self, + weight: torch.Tensor, + *, + bits: int, + group_size: int, + sym: bool, + ) -> None: + """Initialize either symmetric or affine groupwise quantization parameters.""" + super().__init__() + self.bits = int(bits) + self.group_size = int(group_size) + self.sym = bool(sym) + + if self.sym: + view = weight.reshape(-1, self.group_size) + qmax = 2 ** (self.bits - 1) - 1 + scale = view.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax + self.scale = nn.Parameter(scale) + self.zero_point_float = None + else: + scale, zero_point_float = _calc_affine_qparams(weight, group_size=self.group_size, bits=self.bits) + self.scale = nn.Parameter(scale) + self.zero_point_float = nn.Parameter(zero_point_float) + + def forward(self, weight: torch.Tensor) -> torch.Tensor: + """Pseudo-quantize a transformed weight tensor with STE-enabled qparams.""" + return _maybe_compile_large_train_quant( + weight, + bits=self.bits, + group_size=self.group_size, + sym=self.sym, + scale=self.scale, + zero_point_float=self.zero_point_float, + use_ste=True, + ) + + def optim_params(self) -> list[nn.Parameter]: + """Return only the quantizer parameters that should receive optimizer updates.""" + params = [self.scale] + if self.zero_point_float is not None: + params.append(self.zero_point_float) + return params + + def pack_params(self, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Convert learned qparams into the runtime scale/zero-point tensors.""" + scales = _reshape_group_params(weight, self.group_size, self.scale.detach()) + if self.sym: + zeros = torch.full_like(scales, 2 ** (self.bits - 1)) + else: + qmax = 2**self.bits - 1 + zeros = _clamp_ste(-self.zero_point_float.detach().round(), 0, qmax) + zeros = _reshape_group_params(weight, self.group_size, zeros) + return scales, zeros + + +def pseudo_quantize_dequant( + weight: torch.Tensor, + *, + bits: int, + group_size: int, + sym: bool, + scale: Optional[torch.Tensor] = None, + zero_point_float: Optional[torch.Tensor] = None, + use_ste: bool, +) -> torch.Tensor: + """Reference pseudo-quantization path shared by optimization and export tests.""" + dtype = weight.dtype + weight_view = weight.reshape(-1, group_size) + + if sym: + qmin = -(2 ** (bits - 1)) + qmax = 2 ** (bits - 1) - 1 + if scale is None: + scale = weight_view.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / qmax + if use_ste: + scale = _clamp_ste(scale, min_value=1e-5, max_value=1e5) + quant = _clamp_ste(_round_ste(weight_view / scale), qmin, qmax) + dequant = quant * scale + else: + scale = scale.clamp(min=1e-5, max=1e5) + quant = torch.clamp(torch.round(weight_view / scale), qmin, qmax) + dequant = quant * scale + else: + qmin = 0 + qmax = 2**bits - 1 + if scale is None or zero_point_float is None: + scale, zero_point_float = _calc_affine_qparams(weight, group_size=group_size, bits=bits) + if use_ste: + scale = _clamp_ste(scale, min_value=1e-5, max_value=1e5) + round_zero_point = _clamp_ste(-_round_ste(zero_point_float), qmin, qmax) + quant = _round_ste(weight_view / scale) + round_zero_point + quant = _clamp_ste(quant, qmin, qmax) + else: + scale = scale.clamp(min=1e-5, max=1e5) + round_zero_point = torch.clamp(-torch.round(zero_point_float), qmin, qmax) + quant = torch.round(weight_view / scale) + round_zero_point + quant = torch.clamp(quant, qmin, qmax) + dequant = (quant - round_zero_point) * scale + + return dequant.reshape_as(weight).to(dtype) + + +@lru_cache(maxsize=1) +def _get_large_train_quant_compile(): + """Lazily compile the large training-time quant path once per process.""" + compile_fn = getattr(torch, "compile", None) + if compile_fn is None: + return None + try: + return compile_fn(pseudo_quantize_dequant, mode="default", fullgraph=True) + except Exception: + return None + + +def _maybe_compile_large_train_quant( + weight: torch.Tensor, + *, + bits: int, + group_size: int, + sym: bool, + scale: Optional[torch.Tensor] = None, + zero_point_float: Optional[torch.Tensor] = None, + use_ste: bool = True, +) -> torch.Tensor: + """Compile only large training-time quant calls where the compile tax amortizes.""" + should_compile = ( + bool(use_ste) + and weight.device.type == "cuda" + and weight.numel() >= _PAROQUANT_LARGE_TRAIN_QUANT_COMPILE_MIN_NUMEL + and env_flag("GPTQMODEL_PAROQUANT_OPT_LARGE_TRAIN_QUANT_COMPILE", default=True) + ) + if should_compile: + compiled = _get_large_train_quant_compile() + if compiled is not None: + return compiled( + weight, + bits=bits, + group_size=group_size, + sym=sym, + scale=scale, + zero_point_float=zero_point_float, + use_ste=use_ste, + ) + return pseudo_quantize_dequant( + weight, + bits=bits, + group_size=group_size, + sym=sym, + scale=scale, + zero_point_float=zero_point_float, + use_ste=use_ste, + ) + + +class _ParoQuantOptimLinear(nn.Module): + """Minimal layer wrapper used during ParoQuant calibration optimization.""" + + def __init__( + self, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + *, + bits: int, + group_size: int, + quantizer_sym: bool, + pairs: torch.Tensor, + theta_mask: torch.Tensor, + scale_clamp_min: float = PAROQUANT_OPT_SCALE_CLAMP_MIN_DEFAULT, + scale_clamp_max: float = PAROQUANT_OPT_SCALE_CLAMP_MAX_DEFAULT, + fused_rotation: Optional[bool] = None, + ) -> None: + """Materialize a replayable linear layer in the original input domain.""" + super().__init__() + self.bits = int(bits) + self.group_size = int(group_size) + self.quantizer_sym = bool(quantizer_sym) + self.scale_clamp_min = float(scale_clamp_min) + self.scale_clamp_max = float(scale_clamp_max) + if self.scale_clamp_min <= 0 or self.scale_clamp_max <= 0: + raise ValueError("ParoQuant optimization: scale clamp bounds must be positive.") + if self.scale_clamp_min >= self.scale_clamp_max: + raise ValueError( + "ParoQuant optimization: `scale_clamp_min` must be smaller than `scale_clamp_max`." + ) + self.register_buffer("pairs", pairs) + self.register_buffer("theta_mask", theta_mask) + self.weight = nn.Parameter(weight.clone()) + self.bias = None if bias is None else nn.Parameter(bias.clone()) + self.theta = nn.Parameter(torch.zeros((pairs.shape[0], weight.shape[1] // 2), device=weight.device, dtype=weight.dtype)) + self.channel_scales_opt = nn.Parameter(torch.ones((weight.shape[1],), device=weight.device, dtype=weight.dtype)) + self.quantizer: Optional[GroupLinearQuantizer] = None + self.fused_rotation = fused_rotation + + def _safe_channel_scales(self, use_ste: bool) -> torch.Tensor: + """Keep learned channel scales in a numerically safe range. + + During optimization we use an STE clamp so forward values stay bounded + while gradients still flow to ``channel_scales_opt``. During export we + switch to a hard clamp so the saved runtime tensors reflect the exact + bounded values. + """ + scales = self.channel_scales_opt.view(1, -1) + if use_ste: + return _clamp_ste( + scales, + min_value=self.scale_clamp_min, + max_value=self.scale_clamp_max, + ) + return scales.clamp(min=self.scale_clamp_min, max=self.scale_clamp_max) + + def transformed_weight(self, *, use_ste: bool = True) -> torch.Tensor: + """Project the learnable weight into ParoQuant's transformed domain.""" + # Keep training-time scale updates differentiable while preventing + # near-zero / exploding channel scales from destabilizing the transform. + scaled_weight = self.weight * self._safe_channel_scales(use_ste=use_ste) + return _apply_rotation( + scaled_weight, + self.pairs, + self.theta, + scales=None, + group_size=self.group_size, + fused_rotation=self.fused_rotation, + ) + + def quantized_transformed_weight(self) -> torch.Tensor: + """Pseudo-quantize the transformed weight using current learned qparams.""" + transformed = self.transformed_weight(use_ste=True) + if self.quantizer is None: + return _maybe_compile_large_train_quant( + transformed, + bits=self.bits, + group_size=self.group_size, + sym=self.quantizer_sym, + ) + return self.quantizer(transformed) + + def pseudo_weight(self) -> torch.Tensor: + """Map the transformed-domain quantized weight back to runtime input space.""" + quantized = self.quantized_transformed_weight() + quantized = _apply_inverse_rotation( + quantized, + self.pairs, + self.theta, + group_size=self.group_size, + fused_rotation=self.fused_rotation, + ) + channel_scales = self._safe_channel_scales(use_ste=True) + return quantized / channel_scales + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Replay calibration activations through the runtime-equivalent transformed-domain path.""" + if x.dim() < 2: + raise ValueError(f"ParoQuant optimization expects rank-2+ inputs, got {tuple(x.shape)}.") + original_shape = x.shape + x_2d = x.reshape(-1, original_shape[-1]) if x.dim() > 2 else x + # Rotate the minibatch instead of reconstructing full pseudo-weights on + # every step. This preserves the runtime contract while shrinking the + # per-step rotation work from weight-sized tensors to batch-sized ones. + runtime_scales = self._safe_channel_scales(use_ste=True).reciprocal() + rotated_inputs = _apply_rotation( + x_2d, + self.pairs, + self.theta, + scales=runtime_scales, + group_size=self.group_size, + fused_rotation=self.fused_rotation, + ) + outputs = F.linear(rotated_inputs, self.quantized_transformed_weight(), self.bias) + if x.dim() > 2: + return outputs.view(*original_shape[:-1], outputs.shape[-1]) + return outputs + + def reset_masked_angles(self) -> None: + """Force dummy padded pairs to stay at zero angle during optimization.""" + with torch.no_grad(): + self.theta.masked_fill_(self.theta_mask, 0) + + def init_quantizer(self) -> None: + """Bootstrap the transformed-domain quantizer from the current rotated weight.""" + transformed = self.transformed_weight(use_ste=False).detach() + self.quantizer = GroupLinearQuantizer( + transformed, + bits=self.bits, + group_size=self.group_size, + sym=self.quantizer_sym, + ) + + def _export_runtime_state( + self, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Export quantized transformed weights and the exact runtime-equivalent pseudo weight in one pass.""" + transformed = self.transformed_weight(use_ste=False).detach() + if self.quantizer is None: + quantizer = GroupLinearQuantizer( + transformed, + bits=self.bits, + group_size=self.group_size, + sym=self.quantizer_sym, + ) + pack_scales, pack_zeros = quantizer.pack_params(transformed) + quantized = pseudo_quantize_dequant( + transformed, + bits=self.bits, + group_size=self.group_size, + sym=self.quantizer_sym, + scale=quantizer.scale.detach(), + zero_point_float=None if quantizer.zero_point_float is None else quantizer.zero_point_float.detach(), + use_ste=False, + ) + else: + pack_scales, pack_zeros = self.quantizer.pack_params(transformed) + quantized = pseudo_quantize_dequant( + transformed, + bits=self.bits, + group_size=self.group_size, + sym=self.quantizer_sym, + scale=self.quantizer.scale.detach(), + zero_point_float=( + None if self.quantizer.zero_point_float is None else self.quantizer.zero_point_float.detach() + ), + use_ste=False, + ) + + theta = self.theta.detach().masked_fill(self.theta_mask, 0) + # Export the exact bounded inverse scales that runtime input rotation + # multiplies into activations. + runtime_channel_scales = self._safe_channel_scales(use_ste=False).detach().reciprocal() + pseudo_weight = _apply_inverse_rotation( + quantized, + self.pairs, + self.theta, + group_size=self.group_size, + fused_rotation=self.fused_rotation, + ) * runtime_channel_scales + return quantized, pack_scales, pack_zeros, theta, runtime_channel_scales, pseudo_weight + + def export_pack_state(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Export runtime tensors that should match the pseudo-quantized layer exactly.""" + quantized, pack_scales, pack_zeros, theta, runtime_channel_scales, _pseudo_weight = self._export_runtime_state() + return quantized, pack_scales, pack_zeros, theta, runtime_channel_scales + + +@dataclass +class ParoQuantOptimizationResult: + """All tensors and diagnostics produced by one ParoQuant layer optimization run.""" + + pseudo_weight: torch.Tensor + pack_weight: torch.Tensor + q_scales: torch.Tensor + q_zeros: torch.Tensor + pairs: torch.Tensor + theta: torch.Tensor + channel_scales: torch.Tensor + train_loss: float + val_loss: float + used_identity: bool + + +def _chunk_rows(rows: torch.Tensor, batch_size: int) -> Iterable[torch.Tensor]: + """Yield contiguous mini-batches from flattened calibration activations.""" + for start in range(0, rows.shape[0], batch_size): + yield rows[start:start + batch_size] + + +@contextmanager +def _activate_stage_params( + model: nn.Module, + param_groups: Sequence[dict[str, object]], +) -> Iterable[None]: + """Temporarily disable gradients for parameters that are inactive in the current stage.""" + active_param_ids = { + id(param) + for param_group in param_groups + for param in param_group.get("params", []) + if isinstance(param, nn.Parameter) + } + original_flags = [(param, param.requires_grad) for param in model.parameters()] + for param, was_enabled in original_flags: + should_enable = id(param) in active_param_ids + if was_enabled != should_enable: + param.requires_grad_(should_enable) + try: + yield + finally: + for param, was_enabled in original_flags: + if param.requires_grad != was_enabled: + param.requires_grad_(was_enabled) + + +def _evaluate_model( + model: nn.Module, + inputs: torch.Tensor, + targets: torch.Tensor, + *, + use_amp: bool = False, +) -> float: + """Measure replay error for early stopping and stage selection.""" + if inputs.numel() == 0: + return 0.0 + with torch.no_grad(): + autocast_ctx = torch.amp.autocast("cuda") if use_amp and inputs.device.type == "cuda" else nullcontext() + with autocast_ctx: + preds = model(inputs) + loss = F.smooth_l1_loss(preds, targets) + return float(loss.item()) + + +def _normalize_optimizer_param_groups( + param_groups: Sequence[dict[str, object]], +) -> list[dict[str, object]]: + normalized_groups: list[dict[str, object]] = [] + for param_group in param_groups: + params = [param for param in param_group.get("params", []) if isinstance(param, nn.Parameter) and param.requires_grad] + if not params: + continue + betas_obj = tuple(float(beta) for beta in param_group.get("betas", (0.9, 0.95))) + normalized_groups.append( + { + "params": params, + "lr": float(param_group["lr"]), + "weight_decay": float(param_group.get("weight_decay", 0.01)), + "betas": (betas_obj[0], betas_obj[1]), + "eps": float(param_group.get("eps", 1e-10)), + "amsgrad": bool(param_group.get("amsgrad", False)), + "momentum": float(param_group.get("momentum", 0.0)), + "dampening": float(param_group.get("dampening", 0.0)), + "nesterov": bool(param_group.get("nesterov", False)), + } + ) + return normalized_groups + + +def _optimizer_param_groups_support_fused( + normalized_groups: Sequence[dict[str, object]], + *, + device: torch.device, +) -> bool: + return device.type == "cuda" and all( + isinstance(param, nn.Parameter) and param.device.type == "cuda" and torch.is_floating_point(param) + for group in normalized_groups + for param in group.get("params", []) + ) + + +def _optimizer_lr_value( + lr: float, + *, + device: torch.device, + graph_capture: bool, +) -> float | torch.Tensor: + if not graph_capture: + return float(lr) + return torch.tensor(float(lr), device=device, dtype=torch.float32) + + +def _set_optimizer_group_lr(param_group: dict[str, object], value: float) -> None: + current = param_group.get("lr") + if isinstance(current, torch.Tensor): + current.fill_(float(value)) + param_group["lr"] = current + return + param_group["lr"] = float(value) + + +def build_paroquant_optimizer( + normalized_groups: Sequence[dict[str, object]], + *, + device: torch.device, + optimizer_name: str, + graph_capture: bool = False, +) -> torch.optim.Optimizer: + normalized_name = _normalize_opt_optimizer(optimizer_name) + use_fused = normalized_name in {"adamw", "adam", "sgd"} and _optimizer_param_groups_support_fused( + normalized_groups, + device=device, + ) + + def _base_groups() -> list[dict[str, object]]: + groups: list[dict[str, object]] = [] + for group in normalized_groups: + groups.append( + { + "params": group["params"], + "lr": _optimizer_lr_value(float(group.get("lr", 0.0)), device=device, graph_capture=graph_capture), + "weight_decay": float(group.get("weight_decay", 0.01)), + } + ) + return groups + + if normalized_name in {"adamw", "adam"}: + groups = _base_groups() + for built_group, source_group in zip(groups, normalized_groups): + betas_obj = tuple(float(beta) for beta in source_group.get("betas", (0.9, 0.95))) + built_group["betas"] = (betas_obj[0], betas_obj[1]) + built_group["eps"] = float(source_group.get("eps", 1e-10)) + built_group["amsgrad"] = bool(source_group.get("amsgrad", False)) + + optimizer_cls = torch.optim.AdamW if normalized_name == "adamw" else torch.optim.Adam + optimizer_kwargs: dict[str, object] = {} + if graph_capture: + optimizer_kwargs["capturable"] = True + if use_fused: + optimizer_kwargs["fused"] = True + try: + return optimizer_cls(groups, **optimizer_kwargs) + except (RuntimeError, TypeError, ValueError): + if use_fused: + optimizer_kwargs.pop("fused", None) + return optimizer_cls(groups, **optimizer_kwargs) + raise + + if normalized_name == "sgd": + groups = _base_groups() + for built_group, source_group in zip(groups, normalized_groups): + built_group["momentum"] = float(source_group.get("momentum", 0.0)) + built_group["dampening"] = float(source_group.get("dampening", 0.0)) + built_group["nesterov"] = bool(source_group.get("nesterov", False)) + + optimizer_kwargs = {"fused": True} if use_fused else {} + try: + return torch.optim.SGD(groups, **optimizer_kwargs) + except (RuntimeError, TypeError, ValueError): + if use_fused: + return torch.optim.SGD(groups) + raise + + raise AssertionError(f"Unhandled ParoQuant optimizer `{normalized_name}`.") + + +def _run_stage_gptqmodel_impl( + *, + model: nn.Module, + inputs_train: torch.Tensor, + targets_train: torch.Tensor, + inputs_val: torch.Tensor, + targets_val: torch.Tensor, + param_groups: Sequence[dict[str, object]], + epochs: int, + batch_size: int, + optimizer_name: str, + gradient_checkpointing: bool = False, + best_state_dtype: Optional[str | torch.dtype] = "fp32", +) -> tuple[float, float]: + """Run one optimization stage with validation-based best-state selection.""" + normalized_groups = _normalize_optimizer_param_groups(param_groups) + + use_amp = inputs_train.device.type == "cuda" + if epochs <= 0 or not normalized_groups: + train_loss = _evaluate_model(model, inputs_train, targets_train, use_amp=use_amp) + val_loss = _evaluate_model(model, inputs_val, targets_val, use_amp=use_amp) + return train_loss, val_loss + + optimizer = build_paroquant_optimizer( + normalized_groups, + device=inputs_train.device, + optimizer_name=optimizer_name, + graph_capture=False, + ) + steps_per_epoch = max(1, math.ceil(max(1, inputs_train.shape[0]) / max(1, batch_size))) + total_steps = max(1, epochs * steps_per_epoch) + base_lrs = [float(group["lr"]) for group in normalized_groups] + scaler = torch.amp.GradScaler(enabled=use_amp) + global_step = 0 + best_state_snapshot_dtype = _resolve_best_state_snapshot_dtype(best_state_dtype=best_state_dtype, device=inputs_train.device) + + best_state = _snapshot_model_state(model, target_dtype=best_state_snapshot_dtype) + best_val_loss = float("inf") + last_train_loss = _evaluate_model(model, inputs_train, targets_train, use_amp=use_amp) + + for _epoch in range(epochs): + epoch_loss = 0.0 + batch_count = 0 + + for input_batch, target_batch in zip(_chunk_rows(inputs_train, batch_size), _chunk_rows(targets_train, batch_size)): + optimizer.zero_grad(set_to_none=True) + autocast_ctx = torch.amp.autocast("cuda") if use_amp else nullcontext() + with autocast_ctx: + preds = _checkpointed_forward(model, input_batch, enabled=gradient_checkpointing) + loss = F.smooth_l1_loss(preds, target_batch) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + global_step += 1 + cosine_ratio = 0.5 * (1.0 + math.cos(math.pi * min(global_step, total_steps) / total_steps)) + for group, base_lr in zip(optimizer.param_groups, base_lrs): + _set_optimizer_group_lr( + group, + (base_lr / 20.0) + ((base_lr - (base_lr / 20.0)) * cosine_ratio), + ) + + model.reset_masked_angles() + epoch_loss += float(loss.item()) + batch_count += 1 + + last_train_loss = epoch_loss / max(1, batch_count) + val_loss = _evaluate_model(model, inputs_val, targets_val, use_amp=use_amp) + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = _snapshot_model_state(model, target_dtype=best_state_snapshot_dtype) + + model.load_state_dict(best_state, strict=True) + model.reset_masked_angles() + return last_train_loss, best_val_loss + + +def _should_use_paroquant_stage_cudagraph( + model: nn.Module, + *, + inputs_train: torch.Tensor, + batch_size: int, + stage_cudagraph: Optional[bool] = None, +) -> bool: + """Use CUDA graphs only for real CUDA tensor stages where launch overhead is worth amortizing.""" + if stage_cudagraph is None: + stage_cudagraph = env_flag("GPTQMODEL_PAROQUANT_OPT_STAGE_CUDAGRAPH", default=True) + if not bool(stage_cudagraph): + return False + if not isinstance(inputs_train, torch.Tensor): + return False + if inputs_train.device.type != "cuda": + return False + if not bool(getattr(model, "fused_rotation", False)): + return False + return inputs_train.shape[0] >= max(1, int(batch_size)) + + +def _run_stage_gptqmodel_cudagraph( + *, + model: nn.Module, + inputs_train: torch.Tensor, + targets_train: torch.Tensor, + inputs_val: torch.Tensor, + targets_val: torch.Tensor, + param_groups: Sequence[dict[str, object]], + epochs: int, + batch_size: int, + optimizer_name: str, + best_state_dtype: Optional[str | torch.dtype] = "fp32", +) -> tuple[float, float]: + """Replay fixed-size CUDA mini-batches through one captured train-step graph with eager tail fallback.""" + normalized_groups = _normalize_optimizer_param_groups(param_groups) + + if epochs <= 0 or not normalized_groups: + train_loss = _evaluate_model(model, inputs_train, targets_train, use_amp=False) + val_loss = _evaluate_model(model, inputs_val, targets_val, use_amp=False) + return train_loss, val_loss + + optimizer = build_paroquant_optimizer( + normalized_groups, + device=inputs_train.device, + optimizer_name=optimizer_name, + graph_capture=True, + ) + steps_per_epoch = max(1, math.ceil(max(1, inputs_train.shape[0]) / max(1, batch_size))) + total_steps = max(1, epochs * steps_per_epoch) + base_lrs = [float(group["lr"]) for group in normalized_groups] + global_step = 0 + best_state_snapshot_dtype = _resolve_best_state_snapshot_dtype(best_state_dtype=best_state_dtype, device=inputs_train.device) + + best_state = _snapshot_model_state(model, target_dtype=best_state_snapshot_dtype) + best_val_loss = float("inf") + last_train_loss = _evaluate_model(model, inputs_train, targets_train, use_amp=False) + + static_input = torch.empty( + (batch_size, inputs_train.shape[1]), + device=inputs_train.device, + dtype=inputs_train.dtype, + ) + static_target = torch.empty( + (batch_size, targets_train.shape[1]), + device=targets_train.device, + dtype=targets_train.dtype, + ) + + warmup_stream = torch.cuda.Stream(device=inputs_train.device) + warmup_stream.wait_stream(torch.cuda.current_stream(inputs_train.device)) + with torch.cuda.stream(warmup_stream): + warm_input = inputs_train[:batch_size] + warm_target = targets_train[:batch_size] + for _ in range(3): + static_input.copy_(warm_input) + static_target.copy_(warm_target) + optimizer.zero_grad(set_to_none=True) + preds = model(static_input) + loss = F.smooth_l1_loss(preds, static_target) + loss.backward() + optimizer.step() + model.reset_masked_angles() + del preds, loss + + torch.cuda.current_stream(inputs_train.device).wait_stream(warmup_stream) + torch.cuda.synchronize(inputs_train.device) + + graph = torch.cuda.CUDAGraph() + optimizer.zero_grad(set_to_none=True) + with torch.cuda.graph(graph): + preds = model(static_input) + static_loss = F.smooth_l1_loss(preds, static_target) + static_loss.backward() + optimizer.step() + model.reset_masked_angles() + + for _epoch in range(epochs): + epoch_loss = 0.0 + batch_count = 0 + + for input_batch, target_batch in zip(_chunk_rows(inputs_train, batch_size), _chunk_rows(targets_train, batch_size)): + global_step += 1 + cosine_ratio = 0.5 * (1.0 + math.cos(math.pi * min(global_step, total_steps) / total_steps)) + for group, base_lr in zip(optimizer.param_groups, base_lrs): + _set_optimizer_group_lr( + group, + (base_lr / 20.0) + ((base_lr - (base_lr / 20.0)) * cosine_ratio), + ) + + if input_batch.shape[0] == batch_size: + static_input.copy_(input_batch) + static_target.copy_(target_batch) + graph.replay() + loss_value = float(static_loss.item()) + else: + optimizer.zero_grad(set_to_none=True) + preds = model(input_batch) + loss = F.smooth_l1_loss(preds, target_batch) + loss.backward() + optimizer.step() + model.reset_masked_angles() + loss_value = float(loss.item()) + + epoch_loss += loss_value + batch_count += 1 + + last_train_loss = epoch_loss / max(1, batch_count) + val_loss = _evaluate_model(model, inputs_val, targets_val, use_amp=False) + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = _snapshot_model_state(model, target_dtype=best_state_snapshot_dtype) + + model.load_state_dict(best_state, strict=True) + model.reset_masked_angles() + return last_train_loss, best_val_loss + + +def _run_stage_gptqmodel( + *, + model: nn.Module, + inputs_train: torch.Tensor, + targets_train: torch.Tensor, + inputs_val: torch.Tensor, + targets_val: torch.Tensor, + param_groups: Sequence[dict[str, object]], + epochs: int, + batch_size: int, + stage_cudagraph: Optional[bool] = None, + optimizer_name: str = "adamw", + gradient_checkpointing: bool = False, + best_state_dtype: Optional[str | torch.dtype] = "fp32", +) -> tuple[float, float]: + """Run the fast stage, preferring CUDA-graph replay on fused CUDA paths and falling back to eager.""" + if not _should_use_paroquant_stage_cudagraph( + model, + inputs_train=inputs_train, + batch_size=batch_size, + stage_cudagraph=False if gradient_checkpointing else stage_cudagraph, + ): + return _run_stage_gptqmodel_impl( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=param_groups, + epochs=epochs, + batch_size=batch_size, + optimizer_name=optimizer_name, + gradient_checkpointing=gradient_checkpointing, + best_state_dtype=best_state_dtype, + ) + + initial_state = {key: tensor.detach().clone() for key, tensor in model.state_dict().items()} + + try: + return _run_stage_gptqmodel_cudagraph( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=param_groups, + epochs=epochs, + batch_size=batch_size, + optimizer_name=optimizer_name, + best_state_dtype=best_state_dtype, + ) + except Exception: + model.load_state_dict(initial_state, strict=True) + if hasattr(model, "reset_masked_angles"): + model.reset_masked_angles() + return _run_stage_gptqmodel_impl( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=param_groups, + epochs=epochs, + batch_size=batch_size, + optimizer_name=optimizer_name, + gradient_checkpointing=gradient_checkpointing, + best_state_dtype=best_state_dtype, + ) + + +def _run_stage_reference( + *, + model: nn.Module, + inputs_train: torch.Tensor, + targets_train: torch.Tensor, + inputs_val: torch.Tensor, + targets_val: torch.Tensor, + param_groups: Sequence[dict[str, object]], + epochs: int, + batch_size: int, + optimizer_name: str, + gradient_checkpointing: bool = False, + best_state_dtype: Optional[str | torch.dtype] = "fp32", +) -> tuple[float, float]: + """Official-parity stage runner: AMP + GradScaler + cosine LR update.""" + normalized_groups = _normalize_optimizer_param_groups(param_groups) + + use_amp = inputs_train.device.type == "cuda" + if epochs <= 0 or not normalized_groups: + train_loss = _evaluate_model(model, inputs_train, targets_train, use_amp=use_amp) + val_loss = _evaluate_model(model, inputs_val, targets_val, use_amp=use_amp) + return train_loss, val_loss + + optimizer = build_paroquant_optimizer( + normalized_groups, + device=inputs_train.device, + optimizer_name=optimizer_name, + graph_capture=False, + ) + steps_per_epoch = max(1, math.ceil(max(1, inputs_train.shape[0]) / max(1, batch_size))) + total_steps = max(1, epochs * steps_per_epoch) + base_lrs = [float(group["lr"]) for group in normalized_groups] + scaler = torch.amp.GradScaler(enabled=use_amp) + global_step = 0 + best_state_snapshot_dtype = _resolve_best_state_snapshot_dtype(best_state_dtype=best_state_dtype, device=inputs_train.device) + + best_state = _snapshot_model_state(model, target_dtype=best_state_snapshot_dtype) + best_val_loss = _evaluate_model(model, inputs_val, targets_val, use_amp=use_amp) + last_train_loss = _evaluate_model(model, inputs_train, targets_train, use_amp=use_amp) + + for _ in range(epochs): + epoch_loss = 0.0 + batch_count = 0 + optimizer.zero_grad(set_to_none=True) + + for input_batch, target_batch in zip(_chunk_rows(inputs_train, batch_size), _chunk_rows(targets_train, batch_size)): + autocast_ctx = torch.amp.autocast("cuda") if use_amp else nullcontext() + with autocast_ctx: + preds = _checkpointed_forward(model, input_batch, enabled=gradient_checkpointing) + loss = F.smooth_l1_loss(preds, target_batch) + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad(set_to_none=True) + + global_step += 1 + cosine_ratio = 0.5 * (1.0 + math.cos(math.pi * min(global_step, total_steps) / total_steps)) + for group, base_lr in zip(optimizer.param_groups, base_lrs): + _set_optimizer_group_lr( + group, + (base_lr / 20.0) + ((base_lr - (base_lr / 20.0)) * cosine_ratio), + ) + + model.reset_masked_angles() + epoch_loss += float(loss.item()) + batch_count += 1 + + last_train_loss = epoch_loss / max(1, batch_count) + val_loss = _evaluate_model(model, inputs_val, targets_val, use_amp=use_amp) + if val_loss < best_val_loss: + best_val_loss = val_loss + best_state = _snapshot_model_state(model, target_dtype=best_state_snapshot_dtype) + + model.load_state_dict(best_state, strict=True) + model.reset_masked_angles() + return last_train_loss, best_val_loss + + +def _run_stage( + *, + model: nn.Module, + inputs_train: torch.Tensor, + targets_train: torch.Tensor, + inputs_val: torch.Tensor, + targets_val: torch.Tensor, + param_groups: Sequence[dict[str, object]], + epochs: int, + batch_size: int, + stage_impl: str, + stage_cudagraph: Optional[bool] = None, + optimizer_name: str = "adamw", + gradient_checkpointing: bool = False, + best_state_dtype: Optional[str | torch.dtype] = "fp32", +) -> tuple[float, float]: + impl = _normalize_opt_impl(stage_impl, field="stage_impl") + with _activate_stage_params(model, param_groups): + if impl == "reference": + return _run_stage_reference( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=param_groups, + epochs=epochs, + batch_size=batch_size, + optimizer_name=optimizer_name, + gradient_checkpointing=gradient_checkpointing, + best_state_dtype=best_state_dtype, + ) + return _run_stage_gptqmodel( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=param_groups, + epochs=epochs, + batch_size=batch_size, + stage_cudagraph=stage_cudagraph, + optimizer_name=optimizer_name, + gradient_checkpointing=gradient_checkpointing, + best_state_dtype=best_state_dtype, + ) + + +def _result_from_model( + model: _ParoQuantOptimLinear, + *, + train_loss: float, + val_loss: float, + used_identity: bool, +) -> ParoQuantOptimizationResult: + """Export one optimized linear replay module into the runtime tensor contract.""" + pack_weight, q_scales, q_zeros, theta, channel_scales, pseudo_weight = model._export_runtime_state() + return ParoQuantOptimizationResult( + pseudo_weight=pseudo_weight.detach(), + pack_weight=pack_weight.detach(), + q_scales=q_scales.detach(), + q_zeros=q_zeros.detach(), + pairs=model.pairs.detach(), + theta=theta.detach(), + channel_scales=channel_scales.detach(), + train_loss=float(train_loss), + val_loss=float(val_loss), + used_identity=used_identity, + ) + + +def _identity_result( + *, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + bits: int, + group_size: int, + quantizer_sym: bool, + krot: int, +) -> ParoQuantOptimizationResult: + """Return the no-optimization fallback used when calibration activations are missing.""" + del bias + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=weight.shape[1], + group_size=group_size, + krot=krot, + device=weight.device, + dtype=weight.dtype, + ) + quantizer = GroupLinearQuantizer( + weight, + bits=bits, + group_size=_normalize_group_size(group_size, weight.shape[1]), + sym=quantizer_sym, + ) + q_scales, q_zeros = quantizer.pack_params(weight) + pack_weight = pseudo_quantize_dequant( + weight, + bits=bits, + group_size=_normalize_group_size(group_size, weight.shape[1]), + sym=quantizer_sym, + scale=quantizer.scale.detach(), + zero_point_float=None if quantizer.zero_point_float is None else quantizer.zero_point_float.detach(), + use_ste=False, + ) + return ParoQuantOptimizationResult( + pseudo_weight=pack_weight.detach(), + pack_weight=pack_weight.detach(), + q_scales=q_scales.detach(), + q_zeros=q_zeros.detach(), + pairs=pairs.detach(), + theta=theta.detach(), + channel_scales=channel_scales.detach(), + train_loss=0.0, + val_loss=0.0, + used_identity=True, + ) + + +def optimize_paroquant_linear( + *, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + inputs: torch.Tensor, + bits: int, + group_size: int, + sym: bool, + krot: int, + pair_ratio: float, + train_rows: int, + val_rows: int, + batch_size: int, + rotation_epochs: int, + finetune_epochs: int, + rotation_lr: float, + weight_lr: float, + quantizer_lr: float, + seed: int, + optimizer_name: str = "adamw", + optimizer_weight_decay: float = 0.01, + optimizer_betas: tuple[float, float] = (0.9, 0.95), + optimizer_eps: float = 1e-10, + optimizer_amsgrad: bool = False, + sgd_momentum: float = 0.0, + sgd_dampening: float = 0.0, + sgd_nesterov: bool = False, + fused_rotation: Optional[bool] = None, + stage_cudagraph: Optional[bool] = None, + stage_impl: Literal["fast", "reference"] = "fast", + pair_impl: Literal["fast", "reference"] = "fast", + quantizer_impl: Literal["fast", "reference"] = "fast", + gradient_checkpointing: bool = False, + best_state_dtype: Optional[str | torch.dtype] = "fp32", + scale_clamp_min: float = PAROQUANT_OPT_SCALE_CLAMP_MIN_DEFAULT, + scale_clamp_max: float = PAROQUANT_OPT_SCALE_CLAMP_MAX_DEFAULT, +) -> ParoQuantOptimizationResult: + """Optimize one linear layer following the paper's two-stage PTQ schedule.""" + _require_paroquant_sym(sym) + if weight.dim() != 2: + raise ValueError(f"ParoQuant optimization expects rank-2 weights, got {tuple(weight.shape)}.") + + normalized_group_size = _normalize_group_size(group_size, weight.shape[1]) + quantizer_sym = _quantizer_sym_for_impl(sym, quantizer_impl) + normalized_optimizer_name = _normalize_opt_optimizer(optimizer_name) + normalized_optimizer_betas = (float(optimizer_betas[0]), float(optimizer_betas[1])) + rows = _sample_activation_rows(inputs, max_rows=max(1, int(train_rows) + int(val_rows))) + if rows.numel() == 0: + return _identity_result( + weight=weight, + bias=bias, + bits=bits, + group_size=normalized_group_size, + quantizer_sym=quantizer_sym, + krot=krot, + ) + + opt_device = weight.device + opt_dtype = torch.float32 + weight_opt = weight.detach().to(device=opt_device, dtype=opt_dtype) + bias_opt = None if bias is None else bias.detach().to(device=opt_device, dtype=opt_dtype) + rows = rows.to(device=opt_device, dtype=opt_dtype) + + targets = F.linear(rows, weight_opt, bias_opt) + train_count = min(rows.shape[0], max(1, int(train_rows))) + val_count = min(max(1, int(val_rows)), max(1, rows.shape[0] - train_count)) + inputs_train = rows[:train_count] + targets_train = targets[:train_count] + inputs_val = rows[-val_count:] + targets_val = targets[-val_count:] + + if inputs_train.numel() == 0 or targets_train.numel() == 0: + raise ValueError("ParoQuant optimization requires non-empty training activations.") + + normalized_pair_impl = _normalize_opt_impl(pair_impl, field="pair_impl") + normalized_stage_impl = _normalize_opt_impl(stage_impl, field="stage_impl") + _normalize_quantizer_impl(quantizer_impl) + if normalized_pair_impl == "reference": + pairs, theta_mask = build_random_rotation_buffers_reference( + in_features=weight_opt.shape[1], + group_size=normalized_group_size, + krot=krot, + pair_ratio=pair_ratio, + seed=seed, + device=opt_device, + ) + else: + pairs, theta_mask = build_random_rotation_buffers( + in_features=weight_opt.shape[1], + group_size=normalized_group_size, + krot=krot, + pair_ratio=pair_ratio, + seed=seed, + device=opt_device, + ) + model = _ParoQuantOptimLinear( + weight_opt, + bias_opt, + bits=bits, + group_size=normalized_group_size, + quantizer_sym=quantizer_sym, + pairs=pairs, + theta_mask=theta_mask, + scale_clamp_min=scale_clamp_min, + scale_clamp_max=scale_clamp_max, + fused_rotation=fused_rotation, + ).to(device=opt_device, dtype=opt_dtype) + model.reset_masked_angles() + + _run_stage( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=[ + { + "params": [model.channel_scales_opt], + "lr": rotation_lr, + "weight_decay": optimizer_weight_decay, + "betas": normalized_optimizer_betas, + "eps": optimizer_eps, + "amsgrad": optimizer_amsgrad, + "momentum": sgd_momentum, + "dampening": sgd_dampening, + "nesterov": sgd_nesterov, + }, + { + "params": [model.theta], + "lr": rotation_lr, + "weight_decay": optimizer_weight_decay, + "betas": normalized_optimizer_betas, + "eps": optimizer_eps, + "amsgrad": optimizer_amsgrad, + "momentum": sgd_momentum, + "dampening": sgd_dampening, + "nesterov": sgd_nesterov, + }, + ], + epochs=rotation_epochs, + batch_size=batch_size, + stage_impl=normalized_stage_impl, + stage_cudagraph=stage_cudagraph, + optimizer_name=normalized_optimizer_name, + gradient_checkpointing=gradient_checkpointing, + best_state_dtype=best_state_dtype, + ) + + model.init_quantizer() + train_loss, val_loss = _run_stage( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=[ + { + "params": [model.weight], + "lr": weight_lr, + "weight_decay": optimizer_weight_decay, + "betas": normalized_optimizer_betas, + "eps": optimizer_eps, + "amsgrad": optimizer_amsgrad, + "momentum": sgd_momentum, + "dampening": sgd_dampening, + "nesterov": sgd_nesterov, + }, + { + "params": model.quantizer.optim_params(), + "lr": quantizer_lr, + "weight_decay": optimizer_weight_decay, + "betas": normalized_optimizer_betas, + "eps": optimizer_eps, + "amsgrad": optimizer_amsgrad, + "momentum": sgd_momentum, + "dampening": sgd_dampening, + "nesterov": sgd_nesterov, + }, + ], + epochs=finetune_epochs, + batch_size=batch_size, + stage_impl=normalized_stage_impl, + stage_cudagraph=stage_cudagraph, + optimizer_name=normalized_optimizer_name, + gradient_checkpointing=gradient_checkpointing, + best_state_dtype=best_state_dtype, + ) + + return _result_from_model( + model, + train_loss=train_loss, + val_loss=val_loss, + used_identity=False, + ) + + +__all__ = [ + "ParoQuantOptimizationResult", + "build_paroquant_optimizer", + "build_random_rotation_buffers", + "optimize_paroquant_linear", + "pseudo_quantize_dequant", +] diff --git a/gptqmodel/quantization/protocol.py b/gptqmodel/quantization/protocol.py new file mode 100644 index 000000000..6f8987562 --- /dev/null +++ b/gptqmodel/quantization/protocol.py @@ -0,0 +1,527 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field, is_dataclass +from pathlib import Path +from typing import Any, Mapping, Optional + +import pcre + +from .config import FORMAT, METHOD, GGUFBits, GGUFConfig, QuantizeConfig, SmoothMAD + + +@dataclass(frozen=True) +class OperationSpec: + method: str + args: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class QuantizeSpec: + method: Optional[str] = None + args: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class ExportSpec: + format: Optional[str] = None + variant: Optional[str] = None + impl: Optional[str] = None + version: Optional[int | str] = None + options: dict[str, Any] = field(default_factory=dict) + + +@dataclass(frozen=True) +class TargetSpec: + mode: Optional[str] = None + prepare: tuple[OperationSpec, ...] = () + quantize: Optional[QuantizeSpec] = None + export: Optional[ExportSpec] = None + + +@dataclass(frozen=True) +class MatchSpec: + pattern: str + include: bool = True + + @property + def modifier(self) -> str: + return "+" if self.include else "-" + + def matches(self, module_name: str) -> bool: + return _pattern_matches(self.pattern, module_name) + + +@dataclass(frozen=True) +class Rule: + match: tuple[MatchSpec, ...] + aliases: dict[str, Any] | None = None + actions: tuple[OperationSpec, ...] = () + stop: bool = False + weight: Optional[TargetSpec] = None + input: Optional[TargetSpec] = None + output: Optional[TargetSpec] = None + kv_cache: Optional[TargetSpec] = None + + def matches(self, module_name: str) -> bool: + includes = tuple(selector for selector in self.match if selector.include) + excludes = tuple(selector for selector in self.match if not selector.include) + if not includes: + return False + if not any(selector.matches(module_name) for selector in includes): + return False + return not any(selector.matches(module_name) for selector in excludes) + + +@dataclass(frozen=True) +class Stage: + name: str + rules: tuple[Rule, ...] = () + + +@dataclass(frozen=True) +class ExecutionPlan: + version: int + stages: tuple[Stage, ...] + + +def skip() -> dict[str, str]: + return {"method": "skip"} + + +def compile_protocol(source: Any) -> ExecutionPlan: + payload = _normalize_root(source) + version = int(payload.get("version", 2)) + if version != 2: + raise ValueError(f"Unsupported quantization protocol version: {version}.") + + stages = tuple(_normalize_stage(stage) for stage in payload.get("stages", ())) + if not stages: + raise ValueError("Quantization protocol must define at least one stage.") + + return ExecutionPlan(version=version, stages=stages) + + +def compile_protocol_yaml_text(text: str) -> ExecutionPlan: + try: + import yaml + except Exception as exc: # pragma: no cover - dependency/runtime guard + raise ModuleNotFoundError("PyYAML is required to parse protocol YAML.") from exc + + payload = yaml.safe_load(text) + return compile_protocol(payload) + + +def compile_protocol_yaml_file(path: str | Path) -> ExecutionPlan: + protocol_path = Path(path) + return compile_protocol_yaml_text(protocol_path.read_text()) + + +def compile_plan_to_quantize_config(plan: ExecutionPlan): + if len(plan.stages) != 1: + raise NotImplementedError("Initial protocol implementation supports exactly one stage for config compilation.") + + stage = plan.stages[0] + if len(stage.rules) != 1: + raise NotImplementedError("Initial protocol implementation supports exactly one rule for config compilation.") + + rule = stage.rules[0] + if rule.aliases: + raise NotImplementedError("Initial protocol implementation does not support aliases during config compilation.") + if rule.actions: + raise NotImplementedError("Initial protocol implementation does not support actions during config compilation.") + if rule.stop: + raise NotImplementedError("Initial protocol implementation does not support stop during config compilation.") + if rule.input is not None or rule.output is not None or rule.kv_cache is not None: + raise NotImplementedError("Initial protocol implementation only supports weight-target compilation.") + if rule.weight is None: + raise ValueError("Initial protocol implementation requires a `weight` target.") + + return _compile_weight_target(rule.weight, matchers=rule.match) + + +def compile_protocol_to_quantize_config(source: Any): + return compile_plan_to_quantize_config(compile_protocol(source)) + + +def compile_protocol_yaml_to_quantize_config(text: str): + return compile_plan_to_quantize_config(compile_protocol_yaml_text(text)) + + +def _normalize_root(source: Any) -> dict[str, Any]: + if isinstance(source, ExecutionPlan): + return { + "version": source.version, + "stages": list(source.stages), + } + if isinstance(source, Mapping): + return dict(source) + if is_dataclass(source): + return { + "version": getattr(source, "version"), + "stages": getattr(source, "stages"), + } + raise TypeError( + "Quantization protocol root must be a mapping or dataclass-like object with `version` and `stages`." + ) + + +def _normalize_stage(source: Any) -> Stage: + if isinstance(source, Stage): + return Stage( + name=source.name, + rules=tuple(_normalize_rule(rule) for rule in source.rules), + ) + data = _coerce_mapping(source, context="stage") + name = data.get("name") + if not name: + raise ValueError("Stage requires a non-empty `name`.") + rules = tuple(_normalize_rule(rule) for rule in data.get("rules", ())) + if not rules: + raise ValueError(f"Stage `{name}` must define at least one rule.") + return Stage(name=str(name), rules=rules) + + +def _normalize_rule(source: Any) -> Rule: + if isinstance(source, Rule): + return Rule( + match=_normalize_match(source.match), + aliases=_copy_optional_mapping(source.aliases), + actions=tuple(_normalize_operation(action) for action in source.actions), + stop=bool(source.stop), + weight=_normalize_target(source.weight), + input=_normalize_target(source.input), + output=_normalize_target(source.output), + kv_cache=_normalize_target(source.kv_cache), + ) + + data = _coerce_mapping(source, context="rule") + match = data.get("match") + if not match: + raise ValueError("Rule requires a non-empty `match`.") + + return Rule( + match=_normalize_match(match), + aliases=_copy_optional_mapping(data.get("aliases")), + actions=tuple(_normalize_operation(action) for action in data.get("actions", ())), + stop=bool(data.get("stop", False)), + weight=_normalize_target(data.get("weight")), + input=_normalize_target(data.get("input")), + output=_normalize_target(data.get("output")), + kv_cache=_normalize_target(data.get("kv_cache")), + ) + + +def _normalize_target(source: Any) -> Optional[TargetSpec]: + if source is None: + return None + if isinstance(source, TargetSpec): + return TargetSpec( + mode=source.mode, + prepare=tuple(_normalize_operation(op) for op in source.prepare), + quantize=_normalize_quantize(source.quantize), + export=_normalize_export(source.export), + ) + + data = _coerce_mapping(source, context="target") + return TargetSpec( + mode=data.get("mode"), + prepare=tuple(_normalize_operation(op) for op in data.get("prepare", ()) or ()), + quantize=_normalize_quantize(data.get("quantize")), + export=_normalize_export(data.get("export")), + ) + + +def _normalize_match(source: Any) -> tuple[MatchSpec, ...]: + if isinstance(source, MatchSpec): + return (MatchSpec(pattern=source.pattern, include=bool(source.include)),) + if isinstance(source, str): + return (_normalize_match_selector(source),) + if isinstance(source, (tuple, list)): + selectors = tuple(_normalize_match_selector(item) for item in source) + if not selectors: + raise ValueError("Rule `match` list must not be empty.") + return selectors + raise TypeError("Rule `match` must be a string, MatchSpec, or a list/tuple of selector strings.") + + +def _normalize_match_selector(source: Any) -> MatchSpec: + if isinstance(source, MatchSpec): + return MatchSpec(pattern=source.pattern, include=bool(source.include)) + if not isinstance(source, str): + raise TypeError("Match selector must be a string or MatchSpec.") + + selector = source.strip() + if not selector: + raise ValueError("Match selector must not be empty.") + + include = True + if selector.startswith("+:"): + selector = selector[2:].strip() + elif selector.startswith("-:"): + include = False + selector = selector[2:].strip() + + if not selector: + raise ValueError("Match selector pattern must not be empty.") + + return MatchSpec(pattern=selector, include=include) + + +def _normalize_operation(source: Any) -> OperationSpec: + if isinstance(source, OperationSpec): + return OperationSpec(method=source.method, args=dict(source.args)) + if isinstance(source, str): + return OperationSpec(method=source) + data = _coerce_mapping(source, context="operation") + method = data.get("method") + if not method: + raise ValueError("Operation requires a non-empty `method`.") + args = {key: value for key, value in data.items() if key != "method"} + return OperationSpec(method=str(method), args=args) + + +def _normalize_quantize(source: Any) -> Optional[QuantizeSpec]: + if source is None: + return None + if isinstance(source, QuantizeSpec): + return QuantizeSpec(method=source.method, args=dict(source.args)) + if isinstance(source, str): + return QuantizeSpec(method=source) + data = _coerce_mapping(source, context="quantize") + method = data.get("method") + args = {key: value for key, value in data.items() if key != "method"} + return QuantizeSpec(method=str(method) if method is not None else None, args=args) + + +def _normalize_export(source: Any) -> Optional[ExportSpec]: + if source is None: + return None + if isinstance(source, ExportSpec): + return ExportSpec( + format=source.format, + variant=source.variant, + impl=source.impl, + version=source.version, + options=dict(source.options), + ) + if isinstance(source, str): + return ExportSpec(format=source) + data = _coerce_mapping(source, context="export") + options = dict(data.get("options", {}) or {}) + return ExportSpec( + format=data.get("format"), + variant=data.get("variant"), + impl=data.get("impl"), + version=data.get("version"), + options=options, + ) + + +def _coerce_mapping(source: Any, *, context: str) -> dict[str, Any]: + if isinstance(source, Mapping): + return dict(source) + if is_dataclass(source): + return {field_name: getattr(source, field_name) for field_name in source.__dataclass_fields__} + raise TypeError(f"Quantization protocol {context} must be provided as a mapping or dataclass.") + + +def _copy_optional_mapping(source: Any) -> dict[str, Any] | None: + if source is None: + return None + if not isinstance(source, Mapping): + raise TypeError("Rule `aliases` must be a mapping when provided.") + return dict(source) + + +def _compile_weight_target(weight: TargetSpec, *, matchers: tuple[MatchSpec, ...]): + if weight.mode not in {None, "merge"}: + raise NotImplementedError("Initial protocol compiler supports only the default target merge mode.") + + quantize = weight.quantize + if quantize is None or not quantize.method: + raise ValueError("Weight target requires `weight.quantize.method`.") + + method = str(quantize.method).strip().lower() + if method == METHOD.GGUF.value: + return _compile_gguf_weight_target(weight, matchers=matchers) + if method in {METHOD.GPTQ.value, METHOD.AWQ.value}: + return _compile_quantize_config_weight_target(weight, matchers=matchers, method=METHOD(method)) + raise NotImplementedError( + "Initial protocol compiler supports only `weight.quantize.method` in {\"gguf\", \"gptq\", \"awq\"}." + ) + + +def _compile_gguf_weight_target(weight: TargetSpec, *, matchers: tuple[MatchSpec, ...]) -> GGUFConfig: + if not _supports_initial_weight_match_compilation(matchers): + raise NotImplementedError( + "Initial GGUF protocol compiler supports only `match=\"*\"` or `match=[\"*\", \"-:...\"]`." + ) + + quantize = weight.quantize + if quantize is None: + raise ValueError("GGUF weight target requires `weight.quantize`.") + if quantize.method != "gguf": + raise NotImplementedError( + "Initial GGUF compiler supports only `weight.quantize.method = \"gguf\"`." + ) + + bits = quantize.args.get("bits") + if bits is None: + raise ValueError("GGUF weight target requires `weight.quantize.bits`.") + + export = weight.export + if export is not None and export.format not in {None, "gguf"}: + raise NotImplementedError("Initial GGUF compiler supports only `weight.export.format = \"gguf\"`.") + + smoother = _compile_supported_smoother(weight.prepare) + gguf_format = _resolve_gguf_public_format(bits=bits, export=export) + dynamic = _compile_negative_match_dynamic(matchers) + return GGUFConfig(bits=bits, format=gguf_format, smoother=smoother, dynamic=dynamic) + + +def _compile_quantize_config_weight_target(weight: TargetSpec, *, matchers: tuple[MatchSpec, ...], method: METHOD): + if not _supports_initial_weight_match_compilation(matchers): + raise NotImplementedError( + f"Initial {method.value.upper()} protocol compiler supports only `match=\"*\"` or `match=[\"*\", \"-:...\"]`." + ) + if weight.prepare: + raise NotImplementedError( + f"Initial {method.value.upper()} protocol compiler does not yet support `weight.prepare`." + ) + + quantize = weight.quantize + if quantize is None: + raise ValueError(f"{method.value.upper()} weight target requires `weight.quantize`.") + if quantize.method != method.value: + raise NotImplementedError( + f"Initial {method.value.upper()} compiler supports only `weight.quantize.method = \"{method.value}\"`." + ) + + bits = quantize.args.get("bits") + if bits is None: + raise ValueError(f"{method.value.upper()} weight target requires `weight.quantize.bits`.") + + export_format = _resolve_export_format(method=method, export=weight.export) + dynamic = _compile_negative_match_dynamic(matchers) + group_size = quantize.args.get("group_size", 128) + sym = bool(quantize.args.get("sym", True)) + + kwargs = { + "method": method, + "format": export_format, + "bits": bits, + "group_size": group_size, + "sym": sym, + "dynamic": dynamic, + } + + if "desc_act" in quantize.args or method == METHOD.GPTQ: + kwargs["desc_act"] = bool(quantize.args.get("desc_act", False)) + + if method == METHOD.GPTQ: + if "act_group_aware" in quantize.args: + kwargs["act_group_aware"] = bool(quantize.args["act_group_aware"]) + + return QuantizeConfig(**kwargs) + + +def _compile_supported_smoother(prepare: tuple[OperationSpec, ...]) -> Optional[SmoothMAD]: + if not prepare: + return None + if len(prepare) != 1: + raise NotImplementedError("Initial GGUF compiler supports at most one weight.prepare operation.") + + op = prepare[0] + if op.method not in {"smooth.mad", "smoother"}: + raise NotImplementedError( + "Initial GGUF compiler supports only `smooth.mad` in `weight.prepare`." + ) + k = op.args.get("k") + if k is None: + smooth_payload = op.args.get("smooth") + if isinstance(smooth_payload, Mapping): + if smooth_payload.get("type") not in {None, "mad"}: + raise NotImplementedError("Initial GGUF compiler supports only MAD smoothers.") + k = smooth_payload.get("k") + return SmoothMAD(k=2.75 if k is None else float(k)) + + +def _resolve_gguf_public_format(bits: Any, export: Optional[ExportSpec]) -> Optional[str]: + variant = export.variant if export is not None else None + + if isinstance(bits, str): + normalized = bits.strip().lower().replace("-", "_") + if normalized and not normalized.isdigit(): + bits_spec = GGUFBits.from_string(normalized) + public_format = bits_spec.to_public_format() + if variant is not None and variant != public_format: + raise ValueError( + f"GGUF protocol uses incompatible bits/export variant combination: bits={bits}, export.variant={variant}." + ) + return variant or public_format + + return variant + + +def _is_global_match(matchers: tuple[MatchSpec, ...]) -> bool: + return len(matchers) == 1 and matchers[0].include and matchers[0].pattern == "*" + + +def _supports_initial_weight_match_compilation(matchers: tuple[MatchSpec, ...]) -> bool: + includes = tuple(selector for selector in matchers if selector.include) + return bool(includes) and all(selector.pattern == "*" for selector in includes) + + +def _compile_negative_match_dynamic(matchers: tuple[MatchSpec, ...]) -> Optional[dict[str, dict[str, Any]]]: + excludes = tuple(selector for selector in matchers if not selector.include) + if not excludes: + return None + return {f"-:{selector.pattern}": {} for selector in excludes} + + +def _resolve_export_format(method: METHOD, export: Optional[ExportSpec]) -> FORMAT: + if method == METHOD.GPTQ: + if export is None: + return FORMAT.GPTQ + if export.format not in {None, METHOD.GPTQ.value}: + raise NotImplementedError("Initial GPTQ compiler supports only `weight.export.format = \"gptq\"`.") + variant = str(export.variant or FORMAT.GPTQ.value).strip().lower().replace("-", "_") + mapping = { + FORMAT.GPTQ.value: FORMAT.GPTQ, + FORMAT.GPTQ_V2.value: FORMAT.GPTQ_V2, + FORMAT.MARLIN.value: FORMAT.MARLIN, + FORMAT.BITBLAS.value: FORMAT.BITBLAS, + } + if variant not in mapping: + raise NotImplementedError(f"Unsupported GPTQ export variant: `{variant}`.") + return mapping[variant] + + if method == METHOD.AWQ: + if export is None: + return FORMAT.GEMM + if export.format not in {None, METHOD.AWQ.value}: + raise NotImplementedError("Initial AWQ compiler supports only `weight.export.format = \"awq\"`.") + variant = str(export.variant or FORMAT.GEMM.value).strip().lower().replace("-", "_") + mapping = { + FORMAT.GEMM.value: FORMAT.GEMM, + FORMAT.GEMV.value: FORMAT.GEMV, + FORMAT.GEMV_FAST.value: FORMAT.GEMV_FAST, + "gemvfast": FORMAT.GEMV_FAST, + FORMAT.LLM_AWQ.value.replace("-", "_"): FORMAT.LLM_AWQ, + FORMAT.LLM_AWQ.value: FORMAT.LLM_AWQ, + FORMAT.MARLIN.value: FORMAT.MARLIN, + } + if variant not in mapping: + raise NotImplementedError(f"Unsupported AWQ export variant: `{variant}`.") + return mapping[variant] + + raise NotImplementedError(f"Unsupported export method resolution for `{method}`.") + + +def _pattern_matches(pattern: str, module_name: str) -> bool: + if pattern == "*": + return True + return pcre.compile(pattern).search(module_name) is not None diff --git a/gptqmodel/quantization/qqq.py b/gptqmodel/quantization/qqq.py index c7bb32d5b..7952c27eb 100644 --- a/gptqmodel/quantization/qqq.py +++ b/gptqmodel/quantization/qqq.py @@ -13,10 +13,10 @@ from .. import QuantizeConfig from ..looper.named_module import NamedModule -from ..quantization.config import FailSafeStrategy, SmoothMSE +from ..quantization.config import FallbackStrategy, SmoothMSE from ..quantization.quantizer import HF_OPTIMUM from ..utils import setup_logger -from .failsafe_smooth import mse_optimal_quant, smooth_block +from .fallback_smooth import mse_optimal_quant, smooth_block from .gptq import get_number_of_rows_and_cols @@ -238,7 +238,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): # fwd counter self.fwd_counter = 0 - self.failsafe = self.qcfg.failsafe + self.fallback = self.qcfg.fallback self.expected_nsamples: Optional[float] = None self.H = torch.zeros((self.columns, self.columns), @@ -261,12 +261,12 @@ def _truncate_last_dim(tensor: torch.Tensor, length: int) -> torch.Tensor: return tensor.narrow(tensor.dim() - 1, 0, trim).contiguous() - def _failsafe_quantize(self, strategy: FailSafeStrategy): + def _fallback_quantize(self, strategy: FallbackStrategy): maxq = 2 ** self.qcfg.bits - 1 sigma = 3.0 group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns start_time = time.time() - smooth_method = getattr(self.failsafe, "smooth", None) + smooth_method = getattr(self.fallback, "smooth", None) mse_steps = 32 mse_maxshrink = 0.8 if isinstance(smooth_method, SmoothMSE): @@ -306,10 +306,10 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy): else: block_mod, scale_factor = smooth_block( block, - self.failsafe, + self.fallback, group_size=self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns, ) - if strategy == FailSafeStrategy.MIDPOINT: + if strategy == FallbackStrategy.MIDPOINT: w_min = block_mod.min(dim=1, keepdim=True).values w_max = block_mod.max(dim=1, keepdim=True).values mid = (w_max + w_min) / 2.0 @@ -320,7 +320,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy): zero = torch.round(zero_mid - (mid / scale)) zero = torch.clamp(zero, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.MEAN: + elif strategy == FallbackStrategy.MEAN: mean = block_mod.mean(dim=1, keepdim=True) max_dev = torch.max((block_mod - mean).abs(), dim=1, keepdim=True).values max_dev = torch.clamp(max_dev, min=1e-8) @@ -331,7 +331,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy): zero = torch.round(zero_mid - (mean / scale)) zero = torch.clamp(zero, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.MEDIAN: + elif strategy == FallbackStrategy.MEDIAN: median = block_mod.median(dim=1, keepdim=True).values max_dev = torch.max((block_mod - median).abs(), dim=1, keepdim=True).values max_dev = torch.clamp(max_dev, min=1e-8) @@ -342,7 +342,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy): zero = torch.round(zero_mid - (median / scale)) zero = torch.clamp(zero, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.STDCLIP: + elif strategy == FallbackStrategy.STDCLIP: mean = block_mod.mean(dim=1, keepdim=True) std = block_mod.std(dim=1, keepdim=True, unbiased=False) std = torch.clamp(std, min=1e-8) @@ -354,13 +354,13 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy): q = torch.round(block_mod / scale + zero) q = torch.clamp(q, 0, maxq) dequant = (q - zero) * scale - elif strategy == FailSafeStrategy.RTN: + elif strategy == FallbackStrategy.RTN: self.quantizer.find_params(block_mod, weight=True) dequant = self.quantizer.quantize(block_mod) scale = self.quantizer.scale zero = self.quantizer.zero else: - raise ValueError(f"Unsupported failsafe strategy: {strategy}") + raise ValueError(f"Unsupported fallback strategy: {strategy}") if scale_factor is not None: scale = scale * scale_factor @@ -412,7 +412,7 @@ def _failsafe_quantize(self, strategy: FailSafeStrategy): duration = time.time() - start_time mean_abs_err = (Q - self.layer.weight.data).abs().mean().item() - avg_loss = f"failsafe({strategy.value}): {mean_abs_err:.7f}" + avg_loss = f"fallback({strategy.value}): {mean_abs_err:.7f}" damp_percent = 0.0 self.H = None return Q, scale, zero, g_idx, duration, avg_loss, damp_percent, scale_extra, self.nsamples @@ -454,9 +454,9 @@ def quantize( blocksize=128, ): start = time.time() - from ..utils.failsafe import resolve_failsafe_strategy, resolve_threshold, should_use_failsafe + from ..utils.fallback import resolve_fallback_strategy, resolve_threshold, should_use_fallback - resolved_strategy = resolve_failsafe_strategy(self.failsafe) + resolved_strategy = resolve_fallback_strategy(self.fallback) percdamp = self.qcfg.damp_percent groupsize = self.qcfg.group_size @@ -516,8 +516,8 @@ def quantize( damp = percdamp * torch.mean(torch.diag(H)) diag = torch.arange(self.columns, device=self.dev) H[diag, diag] += damp - threshold_raw, is_percent = resolve_threshold(self.failsafe, self.expected_nsamples) - failsafe_configured = threshold_raw is not None + threshold_raw, is_percent = resolve_threshold(self.fallback, self.expected_nsamples) + fallback_configured = threshold_raw is not None try: H = torch.linalg.cholesky(H) @@ -525,24 +525,24 @@ def quantize( H = torch.linalg.cholesky(H, upper=True) Hinv = H except Exception: - fallback_requested = should_use_failsafe( - self.failsafe, + fallback_requested = should_use_fallback( + self.fallback, float(self.nsamples), self.expected_nsamples, ) if fallback_requested: extra = f", threshold_raw={threshold_raw}" if threshold_raw is not None and is_percent else "" log.warn( - "Quantization: Module `%s` -> Using `%s` failsafe quantization (observed %s samples, threshold=%s%s, max_total=%s).", + "Quantization: Module `%s` -> Using `%s` fallback quantization (observed %s samples, threshold=%s%s, max_total=%s).", self.name, resolved_strategy.value, self.nsamples, - self.failsafe, + self.fallback, extra, self.expected_nsamples, ) - if resolved_strategy != FailSafeStrategy.RTN: - return self._failsafe_quantize(resolved_strategy) + if resolved_strategy != FallbackStrategy.RTN: + return self._fallback_quantize(resolved_strategy) Hinv = None else: raise @@ -615,20 +615,20 @@ def quantize( if math.isnan(avg_loss): print("Losses sum item:", torch.sum(Losses).item()) - if failsafe_configured: + if fallback_configured: log.info(f"Quantization: Failed due to `NaN` loss for `{self.name}`, use mock quantization retry for `{self.name}`") self.qcfg.mock_quantization = True return self.quantize(blocksize=blocksize) else: - raise ValueError(f"Quantization: Failed due to `NaN` loss for `{self.name}`, please try increasing calibration data samples or enable failsafe=True") + raise ValueError(f"Quantization: Failed due to `NaN` loss for `{self.name}`, please try increasing calibration data samples or enable fallback=True") else: - if failsafe_configured: + if fallback_configured: log.warn(f"Quantization: Module `{self.name}` -> using fail safe mode. Please check if calibration data is sufficient.") else: log.warn(f"Quantization: `{self.name}` is not activated due to model inference logic (MoE)") - avg_loss = f"{resolved_strategy.value} failsafe" if failsafe_configured else 999999999 + avg_loss = f"{resolved_strategy.value} fallback" if fallback_configured else 999999999 else: - avg_loss = f"{resolved_strategy.value} failsafe" if failsafe_configured else 999999999 + avg_loss = f"{resolved_strategy.value} fallback" if fallback_configured else 999999999 del Losses diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py index c661c8c83..07d4dca04 100644 --- a/gptqmodel/quantization/quantizer.py +++ b/gptqmodel/quantization/quantizer.py @@ -8,8 +8,8 @@ import torch import torch.nn as nn -from ..quantization import QuantizeConfig from ..utils.logger import setup_logger +from .config import BaseQuantizeConfig, _normalize_quant_bits, resolve_quant_format log = setup_logger() @@ -28,7 +28,7 @@ def quantize(x, scale, zero, maxq, requires_groupwise_processing: bool): class Quantizer(nn.Module): - def __init__(self, qcfg: QuantizeConfig, shape=1, name: str=None): + def __init__(self, qcfg: BaseQuantizeConfig, shape=1, name: str=None): super(Quantizer, self).__init__() self.qcfg = qcfg @@ -48,12 +48,14 @@ def configure( grid=100, maxshrink=0.8, trits=False, - bits:int=4, # for hf compat - sym:bool=False, # for hf compat + bits: int | str | None = None, # for hf compat + sym: bool | None = None, # for hf compat ): if self.name == HF_OPTIMUM: - self.qcfg.bits = bits - self.qcfg.sym = sym + if bits is not None: + self.qcfg.bits = _normalize_quant_bits(bits, format_value=resolve_quant_format(self.qcfg.format, self.qcfg.method)) + if sym is not None: + self.qcfg.sym = sym if self.requires_groupwise_processing(): self.maxq = torch.tensor(2 ** (self.qcfg.bits - 1) - 1) @@ -112,9 +114,10 @@ def find_params(self, x, weight=False, importance: torch.Tensor = None): else: self.zero = torch.round(-xmin / self.scale) - if self.qcfg.mse > 0.0: + mse = float(getattr(self.qcfg, "mse", 0.0) or 0.0) + if mse > 0.0: importance_weights = None - if self.qcfg.activation_weighted_mse and importance is not None: + if getattr(self.qcfg, "activation_weighted_mse", False) and importance is not None: importance_weights = torch.nan_to_num( importance.to(device=dev, dtype=x.dtype), nan=0.0, @@ -158,7 +161,7 @@ def find_params(self, x, weight=False, importance: torch.Tensor = None): q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq, self.requires_groupwise_processing()) q -= x q.abs_() - q.pow_(self.qcfg.mse) + q.pow_(mse) if importance_weights is not None: q.mul_(importance_weights) err = torch.sum(q, 1) diff --git a/gptqmodel/quantization/rotation/__init__.py b/gptqmodel/quantization/rotation/__init__.py index e69de29bb..98989a602 100644 --- a/gptqmodel/quantization/rotation/__init__.py +++ b/gptqmodel/quantization/rotation/__init__.py @@ -0,0 +1,4 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium diff --git a/gptqmodel/quantization/rtn.py b/gptqmodel/quantization/rtn.py new file mode 100644 index 000000000..cb66265e5 --- /dev/null +++ b/gptqmodel/quantization/rtn.py @@ -0,0 +1,199 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import math +import time +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import transformers +from torch.nn.modules.conv import _ConvNd + +from ..looper.named_module import NamedModule +from .config import Fallback, FallbackStrategy, RTNConfig, SmoothMSE +from .fallback_smooth import mse_optimal_quant, smooth_block +from .quantizer import HF_OPTIMUM, Quantizer + + +def get_number_of_rows_and_cols(layer: nn.Module) -> Tuple[int, int]: + if isinstance(layer, NamedModule): + layer = layer.module + + if isinstance(layer, transformers.Conv1D): + return layer.weight.shape[1], layer.weight.shape[0] + + return layer.weight.shape[0], math.prod(layer.weight.shape[1:]) + + +class RTN: + """Native weight-only RTN quantizer with optional smoothing. + + This path never enters GPTQ's activation/Hessian lifecycle. It quantizes the + module weights directly, then returns tensors that can be packed into GPTQ, + AWQ, or future export layouts by the existing packing stage. + """ + + def __init__(self, module: nn.Module, qcfg: RTNConfig): + self.rows, self.columns = get_number_of_rows_and_cols(module) + if isinstance(module, NamedModule): + self.module = module.module + self.name = module.name + self._named_module = module + else: + self.module = module + self.name = HF_OPTIMUM + self._named_module = None + + self.validate_module(self.module) + self.qcfg = qcfg + self.quantizer = Quantizer(qcfg=qcfg, name=self.name) + self.quantizer.configure(perchannel=True) + self.nsamples = 0 + self._primary = Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=qcfg.smooth, + ) + + self._original_columns = self.columns + if self._named_module is not None: + pad_info = self._named_module.state.get("tp_pad_info") + else: + pad_info = getattr(self.module, "_tp_pad_info", None) + + if isinstance(pad_info, dict): + pad_cols = int(pad_info.get("pad_cols", 0) or 0) + pad_cols = max(pad_cols, 0) + else: + pad_cols = 0 + + self._tp_pad_cols = pad_cols + if self._tp_pad_cols: + self.columns += self._tp_pad_cols + + @staticmethod + def validate_module(module: nn.Module) -> None: + assert isinstance( + module, + (nn.Linear, nn.Conv1d, nn.Conv2d, transformers.Conv1D), + ), f"We supports only linear and convolutional layers. actual = `{module}`" + + def clone_module(self, device: Optional[torch.device] = None) -> torch.Tensor: + if device is None: + device = self.module.weight.data.device + + clone = self.module.weight.data.to(copy=True, device=device) + if isinstance(self.module, _ConvNd): + clone = clone.flatten(1) + if isinstance(self.module, transformers.pytorch_utils.Conv1D): + clone = clone.t() + if self._tp_pad_cols: + pad = torch.zeros( + (clone.shape[0], self._tp_pad_cols), + dtype=clone.dtype, + device=clone.device, + ) + clone = torch.cat((clone, pad), dim=1) + return clone.float() + + @staticmethod + def truncate_last_dim(tensor: torch.Tensor, length: int) -> torch.Tensor: + if tensor.dim() == 0: + return tensor + + trim = min(length, tensor.shape[-1]) + if trim == tensor.shape[-1]: + return tensor + + return tensor.narrow(tensor.dim() - 1, 0, trim).contiguous() + + @staticmethod + def _collapse_group_param(param: torch.Tensor) -> torch.Tensor: + collapsed = param if param.dim() > 1 else param.unsqueeze(1) + if collapsed.shape[1] > 1: + collapsed = collapsed.mean(dim=1, keepdim=True) + return collapsed + + @torch.inference_mode() + def quantize(self): + maxq = 2 ** self.qcfg.bits - 1 + effective_group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns + smooth_method = self.qcfg.smooth + mse_steps = 32 + mse_maxshrink = 0.8 + if isinstance(smooth_method, SmoothMSE): + mse_steps = smooth_method.steps + mse_maxshrink = smooth_method.maxshrink + + start_time = time.time() + target_device = self.module.weight.device + weights = self.clone_module(device=target_device) + quantized = torch.empty_like(weights) + scale_chunks = [] + zero_chunks = [] + + for start in range(0, self.columns, effective_group_size): + end = min(start + effective_group_size, self.columns) + block = weights[:, start:end] + + if isinstance(smooth_method, SmoothMSE): + dequant, scale, zero = mse_optimal_quant( + block, + self.qcfg, + maxq, + steps=mse_steps, + maxshrink=mse_maxshrink, + ) + else: + block_mod, scale_factor = smooth_block( + block, + self._primary, + group_size=effective_group_size, + ) + self.quantizer.find_params(block_mod, weight=True) + dequant = self.quantizer.quantize(block_mod) + scale = self.quantizer.scale + zero = self.quantizer.zero + + if scale_factor is not None: + scale = scale * scale_factor + dequant = dequant * scale_factor + + quantized[:, start:end] = dequant + scale_chunks.append(self._collapse_group_param(scale)) + zero_chunks.append(self._collapse_group_param(zero)) + + scale = torch.cat(scale_chunks, dim=1) + zero = torch.cat(zero_chunks, dim=1) + + if self._tp_pad_cols: + valid_cols = self._original_columns + quantized = quantized[:, :valid_cols] + scale = self.truncate_last_dim(scale, valid_cols) + zero = self.truncate_last_dim(zero, valid_cols) + else: + valid_cols = self.columns + + g_idx = torch.arange(valid_cols, device=quantized.device, dtype=torch.int32) // effective_group_size + + if isinstance(self.module, transformers.Conv1D): + quantized = quantized.t() + + if quantized.shape != self.module.weight.shape: + quantized = quantized.reshape(self.module.weight.shape).to(self.module.weight.dtype) + else: + quantized = quantized.to(self.module.weight.dtype) + + quantized = quantized.to(device=self.module.weight.data.device, non_blocking=False) + mean_abs_err = (quantized - self.module.weight.data).abs().mean().item() + duration = time.time() - start_time + avg_loss = f"rtn: {mean_abs_err:.7f}" + damp = 0.0 + + return quantized, scale, zero, g_idx, duration, avg_loss, damp, self.nsamples + + +__all__ = ["RTN", "get_number_of_rows_and_cols"] diff --git a/gptqmodel/utils/__init__.py b/gptqmodel/utils/__init__.py index abb86b4e5..20f621a02 100644 --- a/gptqmodel/utils/__init__.py +++ b/gptqmodel/utils/__init__.py @@ -3,7 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from .backend import BACKEND +import os +import threading + +from .backend import BACKEND, PROFILE from .logger import setup_logger from .python import gte_python_3_13_3, gte_python_3_14, has_gil_control, has_gil_disabled, log_gil_requirements_for from .threads import AsyncManager, SerialWorker @@ -11,14 +14,17 @@ log = setup_logger() +# Shared across runtime monkeypatch entrypoints; some patch helpers nest. +_MONKEY_PATCH_LOCK = threading.RLock() ASYNC_BG_QUEUE = AsyncManager(threads=4) SERIAL_BG_QUEUE = SerialWorker() # TODO: datasets is not compatible with free threading -if has_gil_disabled(): - log.info("Python GIL is disabled and GPTQModel will auto enable multi-gpu quant acceleration for MoE models plus multi-cpu accelerated packing.") - from .perplexity import Perplexity +if os.environ.get("GPTQMODEL_DISABLE_GIL_WARNING") == "1": + pass +elif has_gil_disabled(): + log.info("Python GIL is disabled and GPT-QModel will auto enable multi-gpu quant acceleration for MoE models plus multi-cpu accelerated packing.") else: if has_gil_control(): log.warn( @@ -26,5 +32,3 @@ log.warn( "Python GIL is enabled: Multi-gpu quant acceleration for MoE models is sub-optimal and multi-core accelerated cpu packing is also disabled. We recommend Python >= 3.13.3t with Pytorch > 2.8 for mult-gpu quantization and multi-cpu packing with env `PYTHON_GIL=0`.") - - log_gil_requirements_for("utils/Perplexity") diff --git a/gptqmodel/utils/awq.py b/gptqmodel/utils/awq.py new file mode 100644 index 000000000..d476ed617 --- /dev/null +++ b/gptqmodel/utils/awq.py @@ -0,0 +1,154 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pathlib import Path + +from .cpp import ( + TorchOpsJitExtension, + cuda_include_paths_with_fallback, + default_jit_cflags, + default_jit_cuda_cflags, + default_torch_ops_build_root, +) + + +_AWQ_OPS_NAME = "gptqmodel_awq_ops" +_AWQ_OPS_NAMESPACE = "gptqmodel_awq" + + +def _awq_root() -> Path: + return Path(__file__).resolve().parents[2] / "gptqmodel_ext" / "awq" + + +def _awq_sources() -> list[str]: + root = _awq_root() + return [ + str(root / "torch_bind.cpp"), + str(root / "quantization" / "gemm_cuda_gen.cu"), + str(root / "quantization" / "gemv_cuda.cu"), + str(root / "gemm_fast_cuda_entry.cu"), + str(root / "gemv_fast_cuda_entry.cu"), + ] + + +def _awq_required_cuda_headers() -> tuple[str, ...]: + return ("cusparse.h",) + + +def _awq_include_paths() -> list[str]: + return cuda_include_paths_with_fallback( + [str(_awq_root())], + required_header_names=_awq_required_cuda_headers(), + ) + + +def _awq_extra_cflags() -> list[str]: + return default_jit_cflags(enable_bf16=True) + + +def _awq_extra_cuda_cflags() -> list[str]: + return default_jit_cuda_cflags( + enable_bf16=True, + include_lineinfo=True, + include_nvcc_threads=True, + include_ptxas_optimizations=True, + include_ptxas_verbosity=False, + include_fatbin_compression=True, + include_diag_suppress=True, + ) + + +# Shared singleton so every AWQ/ParoQuant caller uses the same torch.ops JIT +# cache, force-rebuild controls, and user-facing compile spinner. +_AWQ_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name=_AWQ_OPS_NAME, + namespace=_AWQ_OPS_NAMESPACE, + required_ops=( + "gemm_forward", + "gemm_forward_fp32_reduce", + "gemmv2_forward", + "gemv_forward", + "gemm_fast_forward_prefill", + "gemv_fast_forward_decode", + "dequantize_weights", + ), + sources=_awq_sources, + build_root_env="GPTQMODEL_AWQ_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("awq"), + display_name="AWQ", + extra_cflags=_awq_extra_cflags, + extra_cuda_cflags=_awq_extra_cuda_cflags, + extra_include_paths=_awq_include_paths, + force_rebuild_env="GPTQMODEL_AWQ_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +def _extension_api(): + from gptqmodel import extension as extension_api + + return extension_api + + +def clear_awq_extension_cache() -> None: + _AWQ_TORCH_OPS_EXTENSION.clear_cache() + + +def awq_runtime_available() -> bool: + return _extension_api().is_available("awq") + + +def awq_runtime_error() -> str: + extension_api = _extension_api() + if extension_api.is_available("awq"): + return "" + return extension_api.error("awq") or "CUDA AWQ runtime unavailable." + + +def prewarm_awq_extension() -> bool: + return _extension_api().load(name="awq")["awq"] + + +def _awq_runtime_namespace(): + return _extension_api().namespace("awq") + + +def awq_gemm_forward(input, qweight, scales, qzeros, split_k_iters, fp32_accum: bool): + return _extension_api().op("awq", "gemm_forward")(input, qweight, scales, qzeros, split_k_iters, fp32_accum) + + +def awq_dequantize_weights(qweight, scales, qzeros, split_k_iters, thx, thy, dbg): + return _extension_api().op("awq", "dequantize_weights")(qweight, scales, qzeros, split_k_iters, thx, thy, dbg) + + +def awq_gemmv2_forward(input, qweight, scales, qzeros, group_size, split_k_iters): + return _extension_api().op("awq", "gemmv2_forward")(input, qweight, scales, qzeros, group_size, split_k_iters) + + +def awq_gemv_forward(input, qweight, scales, qzeros, group_size): + return _extension_api().op("awq", "gemv_forward")(input, qweight, scales, qzeros, group_size) + + +def awq_fast_gemm_forward_prefill(input, qweight, scales, qzeros): + return _extension_api().op("awq", "gemm_fast_forward_prefill")(input, qweight, scales, qzeros) + + +def awq_fast_gemv_forward_decode(input, qweight, scales, qzeros, m, n, k, group_size): + return _extension_api().op("awq", "gemv_fast_forward_decode")(input, qweight, scales, qzeros, m, n, k, group_size) + + +__all__ = [ + "awq_dequantize_weights", + "awq_fast_gemm_forward_prefill", + "awq_fast_gemv_forward_decode", + "awq_gemm_forward", + "awq_gemmv2_forward", + "awq_gemv_forward", + "awq_runtime_available", + "awq_runtime_error", + "clear_awq_extension_cache", + "prewarm_awq_extension", +] diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index b715e035e..831ebc702 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -4,40 +4,199 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from enum import Enum +from typing import Any, Optional, Union class BACKEND(str, Enum): AUTO = "auto" # choose the optimal local kernel based on quant_config compatibility - AUTO_TRAINABLE = "auto_trainable" # choose the optimal trainable local kernel for post-quant training - - # gptq - TORCH_FUSED = "torch_fused" # optimized for Intel XPU - TORCH_INT8 = "torch_int8" # optimized CPU int8 fused kernel - TORCH = "torch" # GOOD: about 80% of triton - TRITON = "triton" # VERY GOOD: all-around kernel - EXLLAMA_V1 = "exllama_v1" # FAST: optimized for batching == 1 - EXLLAMA_V2 = "exllama_v2" # FASTER: optimized for batching > 1 - EXLLAMA_EORA = "exllama_eora" - MACHETE = "machete" # CUTLASS-based kernel optimized for Hopper (SM90+) - MARLIN = "marlin" # FASTEST: marlin reduce ops in fp32 (higher precision -> more accurate, slightly slower) - MARLIN_FP16 = "marlin_fp16" # FASTEST and then some: marlin reduce ops in fp16 (lower precision -> less accurate, slightly faster) - BITBLAS = "bitblas" # EXTREMELY FAST: speed at the cost of 10+ minutes of AOT (ahead of time compilation with disk cache) - HF_KERNEL = "hf_kernel" # FAST: optimized from HuggingFace kernels-community + AUTO_TRAINABLE = "auto_trainable" # choose the optimal trainable local kernel for post-quant training + + # GPTQ kernels + GPTQ_TORCH_FUSED = "gptq_torch_fused" # optimized for Intel XPU + GPTQ_TORCH_INT8 = "gptq_torch_int8" # optimized CPU int8 fused kernel + GPTQ_TORCH = "gptq_torch" # GOOD: about 80% of triton + GPTQ_TRITON = "gptq_triton" # VERY GOOD: all-around kernel + BITSANDBYTES = "bitsandbytes" # bitsandbytes 4-bit/8-bit kernel with optional CPU/CUDA support + GPTQ_EXLLAMA_V2 = "gptq_exllama_v2" # FASTER: optimized for batching > 1 + GPTQ_MACHETE = "gptq_machete" # CUTLASS-based kernel optimized for Hopper (SM90+) + GPTQ_MARLIN = "gptq_marlin" # marlin reduce ops, fp32 by default; controlled by GPTQMODEL_MARLIN_USE_FP32 + GPTQ_BITBLAS = "gptq_bitblas" # BitBLAS AOT-compiled GPTQ kernel + GPTQ_TORCH_ATEN = "gptq_torch_aten" # CPU int4pack ATen kernel folded into GPT-QModel + GPTQ_PRO = "gptq_pro" # experimental Ampere-only local kernel path for symmetric GPTQ INT4 + HF_KERNEL = "hf_kernel" # Hugging Face GPTQ kernels-community path + + # QQQ kernels + QQQ = "qqq" # marlin-based qqq kernel + + # AWQ kernels + AWQ_GEMM = "awq_gemm" + AWQ_GEMM_TRITON = "awq_gemm_triton" + AWQ_GEMV = "awq_gemv" + AWQ_GEMV_FAST = "awq_gemv_fast" + AWQ_TORCH_INT8 = "awq_torch_int8" + AWQ_TORCH_FUSED = "awq_torch_fused" + AWQ_TORCH_ATEN = "awq_torch_aten" # CPU int4pack ATen kernel folded into GPT-QModel + AWQ_TORCH = "awq_torch" + AWQ_BITBLAS = "awq_bitblas" + AWQ_MACHETE = "awq_machete" + AWQ_MARLIN = "awq_marlin" + AWQ_EXLLAMA_V2 = "awq_exllama_v2" + HF_KERNEL_AWQ = "hf_kernel_awq" + + # ParoQuant kernels + PAROQUANT_CUDA = "paroquant_cuda" + PAROQUANT_TRITON = "paroquant_triton" + + # FP8 kernels + FP8_TORCH = "fp8_torch" + + # GGUF kernels / engines + GGUF_TORCH = "gguf_torch" + GGUF_TRITON = "gguf_triton" + GGUF_CPP_CPU = "gguf_cpp_cpu" + GGUF_CPP_CUDA = "gguf_cpp_cuda" - # qqq - QQQ = "qqq" # marlin based qqq kernel + # EXL3 engines + EXL3_EXLLAMA_V3 = "exl3_exllama_v3" + EXL3_TORCH = "exl3_torch" - # awq + # external engines + VLLM = "vllm" + SGLANG = "sglang" + MLX = "mlx" + + # Legacy generic names kept for compatibility with older call sites and saved args. + TORCH_FUSED = "torch_fused" + TORCH_INT8 = "torch_int8" + TORCH = "torch" + TRITON = "triton" + EXLLAMA_V1 = "exllama_v1" + EXLLAMA_V2 = "exllama_v2" + EXLLAMA_V3 = "exllama_v3" + EXLLAMA_EORA = "exllama_eora" + MACHETE = "machete" + MARLIN = "marlin" + MARLIN_FP16 = "marlin_fp16" + BITBLAS = "bitblas" GEMM = "gemm" GEMM_TRITON = "gemm_triton" GEMV = "gemv" GEMV_FAST = "gemv_fast" - TORCH_INT8_AWQ = "torch_int8_awq" # optimized CPU int8 fused kernel for AWQ - TORCH_FUSED_AWQ = "torch_fused_awq" # AWQ variant of torch fused kernel - HF_KERNEL_AWQ = "hf_kernel_awq" # AWQ variant of HF kernel + TORCH_INT8_AWQ = "torch_int8_awq" + TORCH_FUSED_AWQ = "torch_fused_awq" TORCH_AWQ = "torch_awq" + BITBLAS_AWQ = "bitblas_awq" + PARO = "paroquant" + + +class PROFILE(str, Enum): + # Inference profile selects between alternative runtime/load strategies. + AUTO = "auto" + FAST = "fast" + LOW_MEMORY = "low_memory" + + +_LEGACY_BACKEND_BY_METHOD = { + "gptq": { + BACKEND.TORCH_FUSED: BACKEND.GPTQ_TORCH_FUSED, + BACKEND.TORCH_INT8: BACKEND.GPTQ_TORCH_INT8, + BACKEND.TORCH: BACKEND.GPTQ_TORCH, + BACKEND.TRITON: BACKEND.GPTQ_TRITON, + BACKEND.EXLLAMA_V2: BACKEND.GPTQ_EXLLAMA_V2, + BACKEND.MACHETE: BACKEND.GPTQ_MACHETE, + BACKEND.MARLIN: BACKEND.GPTQ_MARLIN, + BACKEND.BITBLAS: BACKEND.GPTQ_BITBLAS, + }, + "awq": { + BACKEND.GEMM: BACKEND.AWQ_GEMM, + BACKEND.GEMM_TRITON: BACKEND.AWQ_GEMM_TRITON, + BACKEND.GEMV: BACKEND.AWQ_GEMV, + BACKEND.GEMV_FAST: BACKEND.AWQ_GEMV_FAST, + BACKEND.TORCH: BACKEND.AWQ_TORCH, + BACKEND.TORCH_AWQ: BACKEND.AWQ_TORCH, + BACKEND.TORCH_INT8: BACKEND.AWQ_TORCH_INT8, + BACKEND.TORCH_INT8_AWQ: BACKEND.AWQ_TORCH_INT8, + BACKEND.TORCH_FUSED: BACKEND.AWQ_TORCH_FUSED, + BACKEND.TORCH_FUSED_AWQ: BACKEND.AWQ_TORCH_FUSED, + BACKEND.BITBLAS: BACKEND.AWQ_BITBLAS, + BACKEND.BITBLAS_AWQ: BACKEND.AWQ_BITBLAS, + BACKEND.MACHETE: BACKEND.AWQ_MACHETE, + BACKEND.MARLIN: BACKEND.AWQ_MARLIN, + BACKEND.EXLLAMA_V2: BACKEND.AWQ_EXLLAMA_V2, + }, + "paroquant": { + BACKEND.PARO: BACKEND.PAROQUANT_CUDA, + }, + "fp8": { + BACKEND.TORCH: BACKEND.FP8_TORCH, + }, + "exl3": { + BACKEND.EXLLAMA_V3: BACKEND.EXL3_EXLLAMA_V3, + BACKEND.TORCH: BACKEND.EXL3_TORCH, + }, +} + +_PROFILE_BY_INDEX = { + 0: PROFILE.AUTO, + 1: PROFILE.FAST, + 2: PROFILE.LOW_MEMORY, +} + + +def _normalize_method(method: Optional[Union[str, Any]]) -> Optional[str]: + if method is None: + return None + value = getattr(method, "value", method) + return str(value).lower() + + +def normalize_backend( + backend: Optional[Union[str, BACKEND]], + *, + quant_method: Optional[Union[str, Any]] = None, +) -> Optional[BACKEND]: + if backend is None: + return None + + if isinstance(backend, BACKEND): + resolved = backend + elif isinstance(backend, str): + normalized = backend.strip() + if not normalized: + return None + resolved = BACKEND.__members__.get(normalized.upper()) + if resolved is None: + resolved = BACKEND(normalized.lower()) + else: + raise TypeError(f"backend must be a string or BACKEND, got `{type(backend)}`") + + method = _normalize_method(quant_method) + if method is None: + return resolved + return _LEGACY_BACKEND_BY_METHOD.get(method, {}).get(resolved, resolved) + + +def normalize_profile(profile: Optional[Union[str, int, PROFILE]]) -> PROFILE: + if profile is None: + return PROFILE.AUTO + + if isinstance(profile, PROFILE): + return profile + + if isinstance(profile, int): + if profile in _PROFILE_BY_INDEX: + return _PROFILE_BY_INDEX[profile] + raise ValueError(f"Unknown profile index `{profile}`. Expected one of {sorted(_PROFILE_BY_INDEX)}.") + + if not isinstance(profile, str): + raise TypeError(f"profile must be a string, int, or PROFILE, got `{type(profile)}`") + + normalized = profile.strip() + if not normalized: + return PROFILE.AUTO - # external - VLLM = "vllm" # External inference engine: CUDA + ROCm + IPEX - SGLANG = "sglang" # External inference engine: CUDA + ROCm - MLX = "mlx" # External inference engine: Apple MLX on M1+ (Apple Silicon) + alias = normalized.replace("-", "_").replace(" ", "_") + resolved = PROFILE.__members__.get(alias.upper()) + if resolved is not None: + return resolved + return PROFILE(alias.lower()) diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index 71523dd1c..e3573add1 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -8,8 +8,10 @@ import torch -from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear -from ..quantization import FORMAT, QuantizeConfig +from ..nn_modules.qlinear.bitblas import BitBLASLinear +from ..nn_modules.qlinear.bitblas_awq import AWQBitBlasKernel +from ..quantization import FORMAT, METHOD, QuantizeConfig +from ..quantization.config import resolve_quant_format from ..utils.logger import setup_logger from .model import load_checkpoint_in_model_then_tie_weights from .safe import THREADPOOLCTL @@ -18,6 +20,21 @@ log = setup_logger() + +def _select_bitblas_kernel_class(qcfg: QuantizeConfig): + if qcfg.quant_method == METHOD.AWQ: + return AWQBitBlasKernel + return BitBLASLinear + + +def _should_enable_bitblas_tuning(repack: bool) -> bool: + """Keep GPTQ repacks responsive unless tuning is explicitly requested.""" + raw = os.getenv("BITBLAS_ENABLE_TUNING") + if raw is not None: + return raw.strip().lower() not in {"0", "false", "no", "off"} + return not repack + + def prepare_model_for_bitblas_load( model, qcfg: QuantizeConfig, @@ -30,10 +47,10 @@ def prepare_model_for_bitblas_load( load_checkpoint_in_model: bool, ): # The model (e.g. model.safetensors) is already serialized in the BitBLAS format, load it directly. - if qcfg.format == FORMAT.BITBLAS: + if resolve_quant_format(qcfg.format, qcfg.method) == FORMAT.BITBLAS: # if the checkpoint is already in bitblas format, we can load it directly. - log.info(f"Loading a GPTQ model, detected BitBLAS serialized format at {model_save_name}.") - model = convert_to_bitblas(model, quant_linear_class, qcfg, sym, desc_act, repack=False) + log.info(f"Loading a {qcfg.quant_method.upper()} model, detected BitBLAS serialized format at {model_save_name}.") + model = convert_to_bitblas(model, quant_linear_class, qcfg, sym, desc_act, repack=False, dtype=dtype) load_checkpoint_in_model_then_tie_weights( model, dtype=dtype, @@ -57,12 +74,20 @@ def prepare_model_for_bitblas_load( offload_buffers=True, ) # Convert model to bitblas, repacking weights into BitBLAS format. - model = convert_to_bitblas(model, quant_linear_class, qcfg, sym, desc_act, repack=True) + model = convert_to_bitblas(model, quant_linear_class, qcfg, sym, desc_act, repack=True, dtype=dtype) return model @torch.inference_mode() -def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool, desc_act: bool, repack: bool): +def convert_to_bitblas( + model, + model_quantlinear, + qcfg: QuantizeConfig, + sym: bool, + desc_act: bool, + repack: bool, + dtype: torch.dtype = torch.float16, +): """ Converts GPTQ-packed weights to the Bitblas format. @@ -76,6 +101,8 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool # TODO: load directly BitBLAS QuantLinear. message = "Overriding QuantLinear layers to use BitBLAS's QuantLinear..." + bitblas_quantlinear = _select_bitblas_kernel_class(qcfg) + # TODO: need to benchmark to see multiple threads help with bitblas/tvm compilation and runtime threadpool_limits = ( THREADPOOLCTL.threadpool_limits @@ -83,6 +110,8 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool else (lambda *args, **kwargs: nullcontext()) ) + enable_tuning = _should_enable_bitblas_tuning(repack) + with threadpool_limits(limits=1): os.environ["NUMEXPR_MAX_THREADS"] = "1" @@ -92,14 +121,13 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool if not isinstance(module, model_quantlinear): continue - parent_name = ".".join(name.split(".")[:-1]) - layer_name = name[len(parent_name) + 1:] + parent_name, _, layer_name = name.rpartition(".") # We could use `torch.count_nonzero(module.bias) > 0` here to discard zero bias, but this has issues when loading weights # from checkpoints holding zero bias. with torch.device("meta"): - bitblas_module = BitBLASQuantLinear( - bits=qcfg.bits, + bitblas_module = bitblas_quantlinear( + bits=qcfg.runtime_bits, group_size=qcfg.group_size, sym=sym, desc_act=desc_act, @@ -107,16 +135,21 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool out_features=module.out_features, pack_dtype=qcfg.pack_dtype, bias=module.bias is not None, - enable_tuning=True, + dtype=dtype, + enable_tuning=enable_tuning, adapter=qcfg.adapter, + name=name, ) # convert to bitblas format if repack: - bitblas_module.repack_from_gptq(module) + if qcfg.quant_method == METHOD.AWQ: + bitblas_module.repack_from_awq(module) + else: + bitblas_module.repack_from_gptq(module) # Save to parent. - parent_module = model.get_submodule(parent_name) + parent_module = model if parent_name == "" else model.get_submodule(parent_name) setattr(parent_module, layer_name, bitblas_module) # Free cuda memory. diff --git a/gptqmodel/utils/calibration.py b/gptqmodel/utils/calibration.py index de7fcb4c7..82a8660a3 100644 --- a/gptqmodel/utils/calibration.py +++ b/gptqmodel/utils/calibration.py @@ -7,11 +7,13 @@ from __future__ import annotations +import os import random from typing import Any, Dict, List, Optional, Sequence, Union import torch +from .attn_mask import normalize_seq_mask from .data import collate_data from .logger import setup_logger @@ -92,10 +94,56 @@ def prepare_calibration_dataset( if len(raw_examples) == 0: raise ValueError("Quantize: calibration dataset is empty.") + message_examples = 0 + message_template_name = None + message_text_fallback_examples = 0 + def _require_tokenizer(reason: str) -> None: if tokenizer is None: raise ValueError(f"tokenizer must be provided when {reason}.") + message_apply_fn = None + message_apply_name = None + message_template_checked = False + + def _get_message_template(): + # Prefer the model's native chat formatter when calibration rows carry + # `messages`, but only when the tokenizer has an actual template to use. + # Some HF tokenizers expose `apply_chat_template()` while leaving + # `chat_template=None`, which raises at runtime. + nonlocal message_apply_fn, message_apply_name, message_template_checked + if message_template_checked: + return message_apply_fn, message_apply_name + + message_template_checked = True + + if tokenizer is None: + return None, None + + apply_fn = getattr(tokenizer, "apply_template", None) + if callable(apply_fn): + message_apply_fn = apply_fn + message_apply_name = "apply_template" + return message_apply_fn, message_apply_name + + apply_chat_fn = getattr(tokenizer, "apply_chat_template", None) + if callable(apply_chat_fn): + chat_template = getattr(tokenizer, "chat_template", None) + if chat_template is None: + get_chat_template = getattr(tokenizer, "get_chat_template", None) + if callable(get_chat_template): + try: + chat_template = get_chat_template(None, None) + except Exception: + chat_template = None + + if chat_template is not None: + message_apply_fn = apply_chat_fn + message_apply_name = "apply_chat_template" + return message_apply_fn, message_apply_name + + return None, None + def _to_2d_long_tensor(value: Any, name: str, idx: int) -> torch.Tensor: try: tensor = torch.as_tensor(value, dtype=torch.long) @@ -106,27 +154,59 @@ def _to_2d_long_tensor(value: Any, name: str, idx: int) -> torch.Tensor: raise ValueError(f"Quantize: `{name}` for calibration item {idx} must be 1D or 2D, got scalar.") if tensor.ndim == 1: tensor = tensor.unsqueeze(0) + elif tensor.ndim > 2 and name == "attention_mask": + # Some tokenizers emit causal masks shaped like [B, 1, T, T] or [B, T, T]. + # Collapse those higher-rank masks back to the token presence mask expected here. + tensor = tensor.ne(0) + for dim in range(tensor.ndim - 2, 0, -1): + tensor = tensor.any(dim=dim) + tensor = tensor.to(torch.long) elif tensor.ndim != 2: raise ValueError( f"Quantize: `{name}` for calibration item {idx} must be rank 1 or 2, got rank {tensor.ndim}." ) return tensor + def _normalize_attention_mask(mask_value: Any, ids_tensor: torch.Tensor, idx: int) -> torch.Tensor: + try: + mask_tensor = torch.as_tensor(mask_value) + except Exception as exc: # pragma: no cover - defensive + raise ValueError( + f"Quantize: failed to convert `attention_mask` to tensor for calibration item {idx}." + ) from exc + + if mask_tensor.ndim == 0: + raise ValueError( + f"Quantize: `attention_mask` for calibration item {idx} must be rank 1 or higher, got scalar." + ) + if mask_tensor.ndim == 1: + mask_tensor = mask_tensor.unsqueeze(0) + + try: + keep_mask = normalize_seq_mask(mask_tensor, seq_len=ids_tensor.shape[-1]) + except ValueError as exc: + raise ValueError( + f"Quantize: failed to normalize `attention_mask` for calibration item {idx}: {exc}" + ) from exc + + mask_tensor = keep_mask.to(dtype=torch.long) + if mask_tensor.shape != ids_tensor.shape: + if mask_tensor.numel() == ids_tensor.numel(): + mask_tensor = mask_tensor.reshape(ids_tensor.shape) + else: + raise ValueError( + f"Quantize: attention_mask shape {tuple(mask_tensor.shape)} does not match input_ids shape " + f"{tuple(ids_tensor.shape)} for calibration item {idx}." + ) + return mask_tensor + def _pack_ids(ids_value: Any, mask_value: Any, idx: int) -> Dict[str, torch.Tensor]: ids_tensor = _to_2d_long_tensor(ids_value, "input_ids", idx) if mask_value is None: mask_tensor = torch.ones_like(ids_tensor, dtype=torch.long) else: - mask_tensor = _to_2d_long_tensor(mask_value, "attention_mask", idx) - if mask_tensor.shape != ids_tensor.shape: - if mask_tensor.numel() == ids_tensor.numel(): - mask_tensor = mask_tensor.reshape(ids_tensor.shape) - else: - raise ValueError( - f"Quantize: attention_mask shape {tuple(mask_tensor.shape)} does not match input_ids shape " - f"{tuple(ids_tensor.shape)} for calibration item {idx}." - ) + mask_tensor = _normalize_attention_mask(mask_value, ids_tensor, idx) return { "input_ids": ids_tensor.detach(), @@ -145,18 +225,27 @@ def _tokenize_text_value(text_value: Any, idx: int) -> Dict[str, torch.Tensor]: return _pack_ids(input_ids, attention_mask, idx) def _tokenize_messages_value(messages_value: Any, idx: int) -> Dict[str, torch.Tensor]: + nonlocal message_examples, message_template_name _require_tokenizer("calibration data uses the `messages` feature") - apply_fn = getattr(tokenizer, "apply_template", None) + apply_fn, template_name = _get_message_template() if apply_fn is None: - raise ValueError("tokenizer must expose `apply_template` to handle `messages` calibration data.") + raise ValueError( + "tokenizer must expose `apply_template` or `apply_chat_template` to handle `messages` calibration data." + ) try: - templated = apply_fn(messages_value, tokenize=False) + if template_name == "apply_chat_template": + templated = apply_fn(messages_value, tokenize=False, add_generation_prompt=False) + else: + templated = apply_fn(messages_value, tokenize=False) except TypeError: templated = apply_fn(messages_value) if templated is None: raise ValueError(f"tokenizer.apply_template returned None for calibration item {idx}.") + message_examples += 1 + message_template_name = template_name + if hasattr(templated, "get"): ids_value = templated.get("input_ids") mask_value = templated.get("attention_mask") @@ -187,13 +276,15 @@ def _tokenize_messages_value(messages_value: Any, idx: int) -> Dict[str, torch.T for idx, example in enumerate(raw_examples): if isinstance(example, dict): if "messages" in example: - apply_fn = getattr(tokenizer, "apply_template", None) if tokenizer else None + apply_fn, _ = _get_message_template() if apply_fn is None: if "text" in example: + message_text_fallback_examples += 1 processed_examples.append(_tokenize_text_value(example["text"], idx)) continue raise ValueError( - "tokenizer must expose `apply_template` or calibration data must provide `text` when using `messages`." + "tokenizer must expose `apply_template` or `apply_chat_template`, or calibration data must " + "provide `text` when using `messages`." ) processed_examples.append(_tokenize_messages_value(example["messages"], idx)) continue @@ -281,6 +372,12 @@ def _maybe_resolve_length(value, source_name): if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name): break + padding_side = getattr(tokenizer, "padding_side", "right") if tokenizer is not None else "right" + if padding_side not in ("left", "right"): + raise ValueError( + f"Unsupported tokenizer.padding_side `{padding_side}`. Expected `left` or `right`." + ) + for example in calibration_dataset: input_ids = _convert_tensor_to_list(example["input_ids"]) attention_mask = _convert_tensor_to_list(example["attention_mask"]) @@ -296,8 +393,12 @@ def _maybe_resolve_length(value, source_name): trimmed = True trimmed_row_count += 1 longest_trimmed_row = max(longest_trimmed_row, row_len) - trimmed_input_ids.append(row_ids[:max_positions]) - trimmed_attention_mask.append(row_mask[:max_positions]) + if padding_side == "left": + trimmed_input_ids.append(row_ids[-max_positions:]) + trimmed_attention_mask.append(row_mask[-max_positions:]) + else: + trimmed_input_ids.append(row_ids[:max_positions]) + trimmed_attention_mask.append(row_mask[:max_positions]) else: trimmed_input_ids.append(row_ids) trimmed_attention_mask.append(row_mask) @@ -323,6 +424,19 @@ def _maybe_resolve_length(value, source_name): f"Use quantize(calibration_data_min_length={calibration_data_min_length}) to set a custom minimum length." ) + if message_examples > 0 and message_template_name is not None: + log.info( + "Calibration: tokenized %s `messages` examples via tokenizer.%s", + message_examples, + message_template_name, + ) + if message_text_fallback_examples > 0: + log.warn( + "Calibration: fell back to raw `text` for %s `messages` examples because the tokenizer has no message " + "template configured.", + message_text_fallback_examples, + ) + if trimmed_row_count > 0: log.info( "Quantize: trimmed %s calibration rows above %s=%s (longest original length=%s)", @@ -400,8 +514,12 @@ def flush_buffer(): padding_length = calibration_dataset_concat_size - len(input_ids_buff) if padding_length > 0: pad_id = getattr(tokenizer, "pad_token_id", 0) - input_ids_buff.extend([pad_id] * padding_length) - attention_mask_buff.extend([0] * padding_length) + if padding_side == "left": + input_ids_buff = ([pad_id] * padding_length) + input_ids_buff + attention_mask_buff = ([0] * padding_length) + attention_mask_buff + else: + input_ids_buff.extend([pad_id] * padding_length) + attention_mask_buff.extend([0] * padding_length) concatenated_data.append( { "input_ids": [input_ids_buff], @@ -432,10 +550,33 @@ def flush_buffer(): log.info("Calibration: Native order") sorted_dataset = new_calibration_dataset + preview_count = max(0, int(os.getenv("GPTQMODEL_LOG_CALIBRATION_SAMPLES", "0") or 0)) + if preview_count > 0: + # Preview the exact token rows that will be batched for quantization. + for idx, example in enumerate(sorted_dataset[:preview_count], start=1): + row_ids = example["input_ids"][0] + preview = "" + if tokenizer is not None: + try: + preview = tokenizer.decode(row_ids[:128], skip_special_tokens=False).replace("\n", " ") + except Exception: + preview = "" + log.info( + "Calibration sample %s/%s: tokens=%s preview=%r", + idx, + min(preview_count, len(sorted_dataset)), + len(row_ids), + preview[:240], + ) + if support_batch_quantize: pad_token_id = getattr(tokenizer, "pad_token_id", 0) if tokenizer is not None else 0 new_calibration_dataset_batched = [ - collate_data(sorted_dataset[start : start + batch_size], pad_token_id) + collate_data( + sorted_dataset[start : start + batch_size], + pad_token_id, + padding_side=getattr(tokenizer, "padding_side", "right"), + ) for start in range(0, len(sorted_dataset), batch_size) ] diff --git a/gptqmodel/utils/cpp.py b/gptqmodel/utils/cpp.py index 57cde1ea2..3653ea6cd 100644 --- a/gptqmodel/utils/cpp.py +++ b/gptqmodel/utils/cpp.py @@ -5,30 +5,847 @@ from __future__ import annotations -import importlib.util +import hashlib import logging +import math import os +import re import shutil +import subprocess +import sys import threading +import time from pathlib import Path -from typing import Optional +from typing import Callable, Optional, Sequence +import pcre import torch -from torch.utils.cpp_extension import _get_build_directory, load +from torch.utils.cpp_extension import CUDA_HOME, _get_build_directory, _get_cuda_arch_flags, load from .env import env_flag +from .jit_compile_baselines import get_jit_compile_baseline_seconds +from .logger import setup_logger log = logging.getLogger(__name__) +# One process-local lock serializes every torch.ops JIT cache mutation and +# compile so concurrent startup paths never overlap toolchain work across +# different extensions. +_TORCH_OPS_JIT_LOCK = threading.Lock() + _PACK_BLOCK_EXTENSION: Optional[bool] = None _PACK_BLOCK_EXTENSION_INITIALISED = False -_cpp_ext_lock = threading.Lock() +_PACK_BLOCK_TORCH_OPS_EXTENSION = None +_FLOATX_CPU_TORCH_OPS_EXTENSION = None + +_cpp_ext_lock = _TORCH_OPS_JIT_LOCK # Used to track whether cleanup has been done already _cpp_ext_initialized = False +_SHARED_LIBRARY_SUFFIXES = (".so", ".pyd", ".dylib", ".dll") +_COMPILE_PROGRESS_TOTAL_STEPS = 100 +_COMPILE_PROGRESS_INTERVAL_SECONDS = 1.0 +_LOCAL_INCLUDE_PATTERN = pcre.compile( + r'^\s*#\s*include\s+"([^"]+)"', + flags=pcre.Flag.MULTILINE, +) +# Default NVCC internal threading for JIT builds. This is based on clean-build +# timings collected on an AMD Zen 3 CPU running at 2.2 GHz, where 8 threads was +# the best overall tradeoff across Marlin, AWQ, QQQ, ExLlama, and ParoQuant. +_DEFAULT_NVCC_THREADS = "8" + + +def local_nvcc_version_at_least(major: int, minor: int) -> bool: + nvcc_path = shutil.which("nvcc") + if not nvcc_path: + return False + + try: + result = subprocess.run( + [nvcc_path, "--version"], + capture_output=True, + text=True, + check=False, + ) + except OSError: + return False + + version_text = (result.stdout or "") + "\n" + (result.stderr or "") + match = re.search(r"release\s+(\d+)\.(\d+)", version_text) + if not match: + return False + + return (int(match.group(1)), int(match.group(2))) >= (major, minor) + + +def _format_compile_duration_seconds(seconds: float) -> str: + """Format one duration compactly for user-facing compile progress text.""" + + seconds_value = max(0.0, float(seconds)) + if seconds_value < 10.0: + return f"{seconds_value:.1f}s" + return f"{seconds_value:.0f}s" + + +def _compile_progress_ratio(elapsed_seconds: float, baseline_seconds: float) -> float: + """Map elapsed compile time onto a progress ratio that never reaches 100% early.""" + + baseline = max(float(baseline_seconds), 0.0) + elapsed = max(float(elapsed_seconds), 0.0) + if baseline <= 0.0 or elapsed <= 0.0: + return 0.0 + if elapsed <= baseline: + return min(0.95 * (elapsed / baseline), 0.95) + + overrun = elapsed - baseline + tail_ratio = 1.0 - math.exp(-overrun / max(baseline, 1.0)) + return min(0.95 + (0.04 * tail_ratio), 0.99) + + +def _compile_progress_step( + elapsed_seconds: float, + baseline_seconds: float, + *, + total_steps: int = _COMPILE_PROGRESS_TOTAL_STEPS, +) -> int: + """Convert one elapsed/baseline pair into a bounded manual progress step.""" + + if total_steps <= 1: + return 0 + ratio = _compile_progress_ratio(elapsed_seconds, baseline_seconds) + return max(0, min(total_steps - 1, int(math.floor(ratio * (total_steps - 1))))) + + +def _compile_progress_subtitle(elapsed_seconds: float, baseline_seconds: float) -> str: + """Describe compile elapsed time against the recorded reference baseline.""" + + elapsed = max(float(elapsed_seconds), 0.0) + baseline = max(float(baseline_seconds), 0.0) + if baseline <= 0.0: + return f"elapsed {_format_compile_duration_seconds(elapsed)}" + if elapsed <= baseline: + return ( + f"elapsed {_format_compile_duration_seconds(elapsed)} / " + f"estimated ~{_format_compile_duration_seconds(baseline)}" + ) + return ( + f"elapsed {_format_compile_duration_seconds(elapsed)} / " + f"estimated ~{_format_compile_duration_seconds(baseline)} " + f"(+{_format_compile_duration_seconds(elapsed - baseline)})" + ) + + +def _compile_baseline_summary(elapsed_seconds: float, baseline_seconds: Optional[float]) -> str: + """Format a concise compile-vs-baseline summary for durable log lines.""" + + elapsed = _format_compile_duration_seconds(elapsed_seconds) + if baseline_seconds is None or baseline_seconds <= 0: + return f"in {elapsed}" + + baseline = _format_compile_duration_seconds(baseline_seconds) + delta = elapsed_seconds - baseline_seconds + delta_text = _format_compile_duration_seconds(abs(delta)) + if abs(delta) < 0.05: + return f"in {elapsed} (estimated ~{baseline})" + sign = "+" if delta >= 0 else "-" + return f"in {elapsed} (estimated ~{baseline}, {sign}{delta_text})" + + +class _CompileProgressDisplay: + """Render either a baseline-backed progress bar or a fallback spinner.""" + + def __init__( + self, + *, + logger, + title: str, + baseline_seconds: Optional[float], + ) -> None: + self._logger = logger + self._title = title + self._baseline_seconds = ( + None if baseline_seconds is None or baseline_seconds <= 0 else float(baseline_seconds) + ) + self._started = time.perf_counter() + self._stop_event: Optional[threading.Event] = None + self._thread: Optional[threading.Thread] = None + self._progress = None + self._spinner = None + self._render_lock = threading.Lock() + self._closed = False + + if self._baseline_seconds is None: + self._spinner = logger.spinner(title=title, interval=_COMPILE_PROGRESS_INTERVAL_SECONDS) + return + + progress = logger.pb(range(_COMPILE_PROGRESS_TOTAL_STEPS)).manual().set(show_left_steps=False) + progress.title(title) + progress.subtitle(_compile_progress_subtitle(0.0, self._baseline_seconds)) + progress.draw(force=True) + self._progress = progress + self._stop_event = threading.Event() + self._thread = threading.Thread( + target=self._refresh_loop, + name=f"jit-compile-progress-{title}", + daemon=True, + ) + self._thread.start() + + def elapsed_seconds(self) -> float: + return max(0.0, time.perf_counter() - self._started) + + def _refresh_loop(self) -> None: + assert self._stop_event is not None + while not self._stop_event.wait(_COMPILE_PROGRESS_INTERVAL_SECONDS): + self._draw_current(force=False) + + def _draw_current(self, *, force: bool) -> None: + if self._progress is None or self._baseline_seconds is None: + return + if self._closed: + return + with self._render_lock: + if self._closed: + return + elapsed = self.elapsed_seconds() + self._progress.current_iter_step = _compile_progress_step(elapsed, self._baseline_seconds) + self._progress.subtitle(_compile_progress_subtitle(elapsed, self._baseline_seconds)) + self._progress.draw(force=force) + + def close(self, *, succeeded: bool, elapsed_seconds: Optional[float] = None) -> None: + elapsed = self.elapsed_seconds() if elapsed_seconds is None else max(0.0, float(elapsed_seconds)) + if self._spinner is not None: + self._spinner.close() + return + if self._stop_event is not None: + self._stop_event.set() + if self._progress is None or self._baseline_seconds is None: + return + # Completion is driven by the actual build result and elapsed time, not + # by the estimated baseline. A faster-than-expected compile should exit + # immediately and force the bar to its final state. + self._closed = True + with self._render_lock: + self._progress.current_iter_step = ( + _COMPILE_PROGRESS_TOTAL_STEPS if succeeded else _compile_progress_step(elapsed, self._baseline_seconds) + ) + self._progress.subtitle(_compile_progress_subtitle(elapsed, self._baseline_seconds)) + self._progress.draw(force=True) + self._progress.close() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=0.05) + + +def default_torch_ops_build_root(subdir: str) -> Path: + """Return the default on-disk cache root for torch.ops JIT extensions.""" + + return Path.home() / ".cache" / "gptqmodel" / "torch_extensions" / subdir + + +def _dedupe_path_strings(paths: Sequence[str]) -> list[str]: + """Normalize and deduplicate include/library path strings while preserving order.""" + + deduped: list[str] = [] + seen: set[str] = set() + for raw_path in paths: + normalized = str(Path(raw_path).expanduser()) + if normalized in seen: + continue + seen.add(normalized) + deduped.append(normalized) + return deduped + + +def detected_cuda_wheel_include_paths() -> list[str]: + """Discover CUDA developer headers shipped via NVIDIA Python wheels.""" + + try: + import nvidia # type: ignore + except Exception: + return [] + + include_paths: list[str] = [] + for base_text in getattr(nvidia, "__path__", []): + base_path = Path(base_text) + candidate_paths = list(base_path.glob("cu*/include")) + candidate_paths.extend(base_path.glob("*/include")) + for candidate in sorted(candidate_paths): + if candidate.is_dir(): + include_paths.append(str(candidate)) + return _dedupe_path_strings(include_paths) + + +def _resolve_local_include_path( + include_name: str, + *, + including_path: Path, + include_search_roots: Sequence[Path], +) -> Optional[Path]: + """Resolve one quoted local include against the current file and explicit include roots.""" + + include_path = Path(include_name).expanduser() + if include_path.is_absolute(): + resolved = include_path.resolve(strict=False) + return resolved if resolved.exists() else None + + search_roots = [including_path.parent, *include_search_roots] + for root in search_roots: + candidate = (root / include_path).resolve(strict=False) + if candidate.exists(): + return candidate + return None + + +def detected_local_cuda_include_paths() -> list[str]: + """Discover CUDA developer headers from the active local CUDA toolkit.""" + + include_paths: list[str] = [] + + if CUDA_HOME: + candidate = Path(CUDA_HOME).expanduser() / "include" + if candidate.is_dir(): + include_paths.append(str(candidate)) + + cuda_path = os.getenv("CUDA_PATH") + if cuda_path: + candidate = Path(cuda_path).expanduser() / "include" + if candidate.is_dir(): + include_paths.append(str(candidate)) + + return _dedupe_path_strings(include_paths) + + +def _detected_local_cuda_has_required_headers(required_header_names: Sequence[str]) -> bool: + """Return whether the detected local CUDA toolkit exposes every required header.""" + + local_cuda_include_paths = detected_local_cuda_include_paths() + if not local_cuda_include_paths: + return False + return all( + any((Path(include_path) / header_name).is_file() for include_path in local_cuda_include_paths) + for header_name in required_header_names + ) + + +def cuda_include_paths_with_fallback( + include_paths: Sequence[str], + *, + required_header_names: Sequence[str] = (), +) -> list[str]: + """Append NVIDIA wheel headers when the local CUDA toolkit is absent or incomplete.""" + + resolved_include_paths = _dedupe_path_strings(include_paths) + if not _detected_local_cuda_has_required_headers(required_header_names): + resolved_include_paths.extend(detected_cuda_wheel_include_paths()) + return _dedupe_path_strings(resolved_include_paths) + + +def resolved_cuda_arch_flags() -> list[str]: + """Return the effective NVCC arch flags Torch will emit for this host.""" + + try: + return list(_get_cuda_arch_flags()) + except Exception: + return [] + + +def torch_cxx11_abi_flag() -> int: + """Return the ABI mode local JIT extensions must match for this torch build.""" + + return int(getattr(torch._C, "_GLIBCXX_USE_CXX11_ABI", 1)) + + +def torch_cxx11_abi_define() -> str: + """Return the compiler define that keeps local extensions ABI-compatible.""" + + return f"-D_GLIBCXX_USE_CXX11_ABI={torch_cxx11_abi_flag()}" + + +def resolved_jit_opt_level(opt_level: str | None = "O3") -> str | None: + """Resolve the effective JIT optimization level, honoring the global env override.""" + + override = os.getenv("GPTQMODEL_NVCC_COMPILE_LEVEL") + raw_level = override if override is not None else opt_level + if raw_level is None: + return None + + normalized = str(raw_level).strip() + if not normalized: + return None + if normalized.startswith("-"): + normalized = normalized[1:] + normalized = normalized.upper() + + if normalized in {"NONE", "NOOPT", "NO_OPT", "OFF", "DISABLE", "0"}: + return None + if normalized in {"O0", "O1", "O2", "O3"}: + return normalized + raise ValueError( + "GPTQMODEL_NVCC_COMPILE_LEVEL must be one of O0/O1/O2/O3 or NONE/NOOPT/OFF." + ) + + +def default_jit_cflags( + *, + opt_level: str | None = "O3", + enable_bf16: bool = False, + include_abi: bool = True, +) -> list[str]: + """Return the common C++ compiler flags for torch.ops JIT extensions.""" + + resolved_opt_level = resolved_jit_opt_level(opt_level) + flags = ["-std=c++17"] + if resolved_opt_level is not None: + flags.insert(0, f"-{resolved_opt_level}") + if enable_bf16: + flags.append("-DENABLE_BF16") + if include_abi: + flags.append(torch_cxx11_abi_define()) + return flags + + +def default_jit_cuda_cflags( + *, + opt_level: str | None = "O3", + enable_bf16: bool = False, + include_abi: bool = True, + include_lineinfo: bool = False, + include_nvcc_threads: bool = True, + include_ptxas_optimizations: bool = False, + include_ptxas_verbosity: bool = True, + include_fatbin_compression: bool = False, + include_diag_suppress: bool = False, + nvcc_threads: str | int | None = None, +) -> list[str]: + """Return the common NVCC flags for torch.ops JIT CUDA extensions.""" + + resolved_opt_level = resolved_jit_opt_level(opt_level) + flags = default_jit_cflags( + opt_level=resolved_opt_level, + enable_bf16=enable_bf16, + include_abi=include_abi, + ) + + if include_nvcc_threads: + resolved_nvcc_threads = str(nvcc_threads) if nvcc_threads is not None else os.getenv("NVCC_THREADS", _DEFAULT_NVCC_THREADS) + flags.extend(["--threads", resolved_nvcc_threads]) + if resolved_opt_level is not None: + optimization_level = ( + resolved_opt_level[1:] if resolved_opt_level.startswith("O") else resolved_opt_level + ) + flags.append(f"--optimize={optimization_level}") + if include_ptxas_optimizations: + ptxas_flags = ["-v"] if include_ptxas_verbosity else [] + if resolved_opt_level is not None: + ptxas_flags.append(f"-{resolved_opt_level}") + ptxas_flags.append("-dlcm=ca") + flags.extend(["-Xptxas", ",".join(ptxas_flags)]) + if include_lineinfo: + flags.append("-lineinfo") + if include_fatbin_compression: + flags.extend(["-Xfatbin", "-compress-all"]) + if include_diag_suppress: + flags.append("-diag-suppress=179,39,177") + return flags + + +class TorchOpsJitExtension: + """Manage one torch.ops JIT extension with shared cache and rebuild policy.""" + + def __init__( + self, + *, + name: str, + namespace: str, + required_ops: Sequence[str], + sources: Sequence[str] | Callable[[], Sequence[str]], + build_root_env: Optional[str], + default_build_root: Path | str | Callable[[], Path | str], + display_name: str, + extra_cflags: Optional[Sequence[str] | Callable[[], Sequence[str]]] = None, + extra_cuda_cflags: Optional[Sequence[str] | Callable[[], Sequence[str]]] = None, + extra_include_paths: Optional[Sequence[str] | Callable[[], Sequence[str]]] = None, + extra_ldflags: Optional[Sequence[str] | Callable[[], Sequence[str]]] = None, + force_rebuild_env: Optional[str] = None, + verbose_env: Optional[str] = None, + requires_cuda: bool = False, + binary_names: Optional[Sequence[str]] = None, + ) -> None: + self.name = name + self.namespace = namespace + self.required_ops = tuple(required_ops) + self.sources = sources + self.build_root_env = build_root_env + self.default_build_root = default_build_root + self.display_name = display_name + self.extra_cflags = extra_cflags + self.extra_cuda_cflags = extra_cuda_cflags + self.extra_include_paths = extra_include_paths + self.extra_ldflags = extra_ldflags + self.force_rebuild_env = force_rebuild_env + self.verbose_env = verbose_env + self.requires_cuda = bool(requires_cuda) + self.binary_names = tuple(binary_names or (name,)) + self.compile_baseline_seconds = get_jit_compile_baseline_seconds(name) + self._load_attempted = False + self._load_result = False + self._last_error = "" + self._namespace_cache: Optional[object] = None + self._op_cache: dict[str, object] = {} + self._lock = self._get_shared_lock() + + @classmethod + def _get_shared_lock(cls) -> threading.Lock: + """Reuse the single process-local lock for every JIT extension.""" + return _TORCH_OPS_JIT_LOCK + + def _resolve_path(self, value: Path | str | Callable[[], Path | str]) -> Path: + resolved = value() if callable(value) else value + return Path(resolved).expanduser() + + def _resolve_sequence( + self, + value: Optional[Sequence[str] | Callable[[], Sequence[str]]], + ) -> list[str]: + if value is None: + return [] + resolved = value() if callable(value) else value + return [str(item) for item in resolved] + + def _resolved_extra_include_paths(self) -> list[str]: + """Resolve explicit include paths and append CUDA wheel headers when needed.""" + + include_paths = self._resolve_sequence(self.extra_include_paths) + if not self.requires_cuda: + return _dedupe_path_strings(include_paths) + return cuda_include_paths_with_fallback(include_paths) + + def base_build_root(self) -> Path: + """Return the user-visible cache root before applying the loader fingerprint.""" + + override = os.getenv(self.build_root_env) if self.build_root_env else None + if override: + return Path(override).expanduser() + return self._resolve_path(self.default_build_root) + + def _source_cache_fingerprint_payload(self, source: str, include_paths: Sequence[str]) -> list[str]: + """Hash one source file plus recursively discovered quoted local includes.""" + + payload: list[str] = [] + visited: set[Path] = set() + include_search_roots = [Path(path).expanduser().resolve(strict=False) for path in include_paths] + + def visit(path: Path) -> None: + normalized = path.expanduser().resolve(strict=False) + if normalized in visited: + return + visited.add(normalized) + payload.append(str(normalized)) + + if not normalized.exists(): + payload.append("missing") + return + + try: + source_bytes = normalized.read_bytes() + except OSError as exc: + payload.append(f"read_error={type(exc).__name__}") + return + + payload.append(hashlib.sha256(source_bytes).hexdigest()) + + source_text = source_bytes.decode("utf-8", errors="ignore") + for include_name in _LOCAL_INCLUDE_PATTERN.findall(source_text): + included_path = _resolve_local_include_path( + include_name, + including_path=normalized, + include_search_roots=include_search_roots, + ) + if included_path is None: + payload.append(f"missing_include={normalized}:{include_name}") + continue + visit(included_path) + + visit(Path(source)) + return payload + + def _cache_fingerprint(self) -> str: + """Hash the effective op surface and source metadata to avoid stale cache reuse.""" + + payload: list[str] = [self.name, self.namespace, *self.required_ops] + payload.append(f"python={sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}") + payload.append(f"torch={torch.__version__}") + payload.append(f"torch_cuda={torch.version.cuda or 'none'}") + payload.extend(self._cuda_cache_fingerprint_payload()) + include_paths = self._resolved_extra_include_paths() + for source in self._resolve_sequence(self.sources): + payload.extend(self._source_cache_fingerprint_payload(source, include_paths)) + + payload.extend(self._resolve_sequence(self.extra_cflags)) + payload.extend(self._resolve_sequence(self.extra_cuda_cflags)) + payload.extend(include_paths) + payload.extend(self._resolve_sequence(self.extra_ldflags)) + digest = hashlib.sha256("\0".join(payload).encode("utf-8")).hexdigest() + return digest[:16] + + def _cuda_cache_fingerprint_payload(self) -> list[str]: + """Capture the effective CUDA target set so cached binaries stay device-compatible.""" + + if not self.requires_cuda: + return ["cuda_ext=0"] + + payload = ["cuda_ext=1"] + override = os.getenv("TORCH_CUDA_ARCH_LIST") + if override: + payload.append(f"arch_list={override}") + arch_flags = resolved_cuda_arch_flags() + if arch_flags: + payload.append(f"resolved_arch_flags={','.join(arch_flags)}") + return payload + + if not torch.cuda.is_available(): + payload.append("cuda_available=0") + return payload + + capabilities: set[str] = set() + for device_index in range(torch.cuda.device_count()): + major, minor = torch.cuda.get_device_capability(device_index) + capabilities.add(f"{major}.{minor}") + + if not capabilities: + payload.append("visible_caps=none") + else: + payload.append(f"visible_caps={','.join(sorted(capabilities))}") + + arch_flags = resolved_cuda_arch_flags() + if arch_flags: + payload.append(f"resolved_arch_flags={','.join(arch_flags)}") + return payload + + def build_root(self) -> Path: + """Return the fingerprinted filesystem directory that caches this JIT extension.""" + + return self.base_build_root() / self._cache_fingerprint() + + def force_rebuild_enabled(self) -> bool: + """Check whether this extension should ignore and replace cached binaries.""" + + if not self.force_rebuild_env: + return False + return env_flag(self.force_rebuild_env, default=False) + + def _ops_available(self) -> bool: + namespace = getattr(torch.ops, self.namespace, None) + return namespace is not None and all(hasattr(namespace, op_name) for op_name in self.required_ops) + + def _refresh_runtime_cache(self) -> bool: + namespace = getattr(torch.ops, self.namespace, None) + if namespace is None: + self._namespace_cache = None + self._op_cache = {} + return False + missing = [op_name for op_name in self.required_ops if not hasattr(namespace, op_name)] + if missing: + self._namespace_cache = None + self._op_cache = {} + return False + self._namespace_cache = namespace + self._op_cache = {op_name: getattr(namespace, op_name) for op_name in self.required_ops} + return True + + def _candidate_binary_paths(self, build_root: Path) -> list[Path]: + seen: set[Path] = set() + candidates: list[Path] = [] + for binary_name in self.binary_names: + for suffix in _SHARED_LIBRARY_SUFFIXES: + exact = build_root / f"{binary_name}{suffix}" + if exact not in seen: + seen.add(exact) + candidates.append(exact) + for match in sorted(build_root.glob(f"{binary_name}*{suffix}")): + if match not in seen: + seen.add(match) + candidates.append(match) + return candidates + + def _try_load_prebuilt_library(self, build_root: Path) -> bool: + for library_path in self._candidate_binary_paths(build_root): + if not library_path.is_file(): + continue + try: + torch.ops.load_library(str(library_path)) + if self._ops_available(): + return True + except Exception as exc: # pragma: no cover - binary/runtime mismatch depends on host + log.debug("%s: failed to load cached torch.ops library %s: %s", self.display_name, library_path, exc) + return False + + def clear_cache(self) -> None: + """Best-effort cache clear for the next process-local JIT load attempt.""" + + with self._lock: + self._load_attempted = False + self._load_result = False + self._last_error = "" + self._namespace_cache = None + self._op_cache = {} + build_root = self.base_build_root() + if build_root.exists(): + shutil.rmtree(build_root, ignore_errors=True) + + def last_error_message(self) -> str: + """Return the most recent human-readable load failure.""" + + return self._last_error + + def load(self) -> bool: + """Load the extension from cache or JIT-compile it on first use.""" + + if self._load_attempted and self._load_result and not self.force_rebuild_enabled(): + return True + + if self._namespace_cache is not None and not self.force_rebuild_enabled(): + self._load_attempted = True + self._load_result = True + self._last_error = "" + return True + + if self._ops_available(): + self._refresh_runtime_cache() + self._load_attempted = True + self._load_result = True + self._last_error = "" + return True + + if self.requires_cuda and not torch.cuda.is_available(): + self._load_attempted = True + self._load_result = False + self._last_error = f"{self.display_name}: CUDA is not available." + return False + + with self._lock: + force_rebuild = self.force_rebuild_enabled() + if self._load_attempted and self._load_result and not force_rebuild: + return True + if self._namespace_cache is not None and not force_rebuild: + self._load_attempted = True + self._load_result = True + self._last_error = "" + return True + if self._ops_available(): + self._refresh_runtime_cache() + self._load_attempted = True + self._load_result = True + self._last_error = "" + return True + if self._load_attempted and not force_rebuild: + return self._load_result + build_root = self.build_root() + base_build_root = self.base_build_root() + + if force_rebuild and base_build_root.exists(): + setup_logger().info(f"{self.display_name}: clearing cached JIT extension at `{base_build_root}`.") + shutil.rmtree(base_build_root, ignore_errors=True) + + build_root.mkdir(parents=True, exist_ok=True) + + if not force_rebuild and self._try_load_prebuilt_library(build_root): + self._load_attempted = True + self._load_result = True + self._last_error = "" + return True + + logger = setup_logger() + logger.info(f"{self.display_name}: compiling torch.ops JIT extension in `{build_root}`.") + progress_display = _CompileProgressDisplay( + logger=logger, + title=f"Compiling extension: {self.display_name}...", + baseline_seconds=self.compile_baseline_seconds, + ) + started = time.perf_counter() + build_invocation_succeeded = False + try: + resolved_sources = self._resolve_sequence(self.sources) + extra_include_paths = self._resolved_extra_include_paths() + kwargs = { + "name": self.name, + "sources": resolved_sources, + "build_directory": str(build_root), + "is_python_module": False, + "verbose": env_flag(self.verbose_env, default=False) if self.verbose_env else False, + } + extra_cflags = self._resolve_sequence(self.extra_cflags) + if extra_cflags: + kwargs["extra_cflags"] = extra_cflags + extra_cuda_cflags = self._resolve_sequence(self.extra_cuda_cflags) + if extra_cuda_cflags: + kwargs["extra_cuda_cflags"] = extra_cuda_cflags + if extra_include_paths: + kwargs["extra_include_paths"] = extra_include_paths + extra_ldflags = self._resolve_sequence(self.extra_ldflags) + if extra_ldflags: + kwargs["extra_ldflags"] = extra_ldflags + load(**kwargs) + build_invocation_succeeded = True + except Exception as exc: # pragma: no cover - build depends on host toolchain + elapsed = time.perf_counter() - started + self._load_attempted = True + self._load_result = False + self._last_error = f"{self.display_name}: failed to build torch.ops JIT extension: {exc}" + log.debug("%s", self._last_error, exc_info=True) + logger.info( + f"{self.display_name}: torch.ops JIT compilation failed " + f"{_compile_baseline_summary(elapsed, self.compile_baseline_seconds)}; using fallback path." + ) + return False + finally: + elapsed = time.perf_counter() - started + progress_display.close(succeeded=build_invocation_succeeded, elapsed_seconds=elapsed) + + elapsed = time.perf_counter() - started + ready = self._refresh_runtime_cache() or self._try_load_prebuilt_library(build_root) + self._load_attempted = True + self._load_result = ready + if ready: + self._refresh_runtime_cache() + self._last_error = "" + logger.info( + f"{self.display_name}: torch.ops JIT extension ready " + f"{_compile_baseline_summary(elapsed, self.compile_baseline_seconds)}." + ) + return True + + self._last_error = f"{self.display_name}: build completed but required torch.ops were not registered." + logger.info(f"{self.display_name}: torch.ops JIT build finished without registering required ops.") + return False + + def namespace_object(self) -> object: + """Return the cached torch.ops namespace after loading this extension.""" + + if self._namespace_cache is not None: + return self._namespace_cache + if not self.load(): + raise RuntimeError(self.last_error_message() or f"{self.display_name}: runtime unavailable.") + if self._refresh_runtime_cache(): + return self._namespace_cache + raise RuntimeError(f"{self.display_name}: required torch.ops namespace `{self.namespace}` is unavailable.") + + def op(self, op_name: str) -> object: + """Return a cached torch.ops handle for one registered op.""" + + cached = self._op_cache.get(op_name) + if cached is not None: + return cached + namespace = self.namespace_object() + if not hasattr(namespace, op_name): + raise AttributeError(f"{self.display_name}: torch.ops `{self.namespace}.{op_name}` is unavailable.") + op = getattr(namespace, op_name) + self._op_cache[op_name] = op + return op + def safe_load_cpp_ext( name, @@ -72,6 +889,68 @@ def safe_load_cpp_ext( return +def _pack_block_source_path() -> Path: + """Resolve the pack_block custom-op source file from source or editable installs.""" + + project_root = Path(__file__).resolve().parents[2] + source_path = project_root / "pack_block_cpu.cpp" + if source_path.exists(): + return source_path + return project_root / "gptqmodel_ext" / "pack_block_cpu.cpp" + + +def _floatx_cpu_source_path() -> Path: + """Resolve the floatx CPU custom-op source file from source or editable installs.""" + + project_root = Path(__file__).resolve().parents[2] + source_path = project_root / "floatx_cpu.cpp" + if source_path.exists(): + return source_path + return project_root / "gptqmodel_ext" / "floatx_cpu.cpp" + + +def _pack_block_extension() -> TorchOpsJitExtension: + """Build the shared pack_block torch.ops loader on first use.""" + + global _PACK_BLOCK_TORCH_OPS_EXTENSION + if _PACK_BLOCK_TORCH_OPS_EXTENSION is None: + _PACK_BLOCK_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name="gptqmodel_pack_block_cpu", + namespace="gptqmodel", + required_ops=("pack_block_cpu",), + sources=lambda: [str(_pack_block_source_path())], + build_root_env="GPTQMODEL_EXT_BUILD", + default_build_root=lambda: default_torch_ops_build_root("pack_block_cpu"), + display_name="pack_block_cpu", + extra_cflags=["-O3", "-std=c++17"], + extra_ldflags=[], + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=False, + ) + return _PACK_BLOCK_TORCH_OPS_EXTENSION + + +def _floatx_cpu_extension() -> TorchOpsJitExtension: + """Build the shared floatx CPU torch.ops loader on first use.""" + + global _FLOATX_CPU_TORCH_OPS_EXTENSION + if _FLOATX_CPU_TORCH_OPS_EXTENSION is None: + _FLOATX_CPU_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name="gptqmodel_floatx_cpu", + namespace="gptqmodel_floatx", + required_ops=("dequantize_fp8_cpu", "dequantize_fp4_cpu"), + sources=lambda: [str(_floatx_cpu_source_path())], + build_root_env="GPTQMODEL_EXT_BUILD", + default_build_root=lambda: default_torch_ops_build_root("floatx_cpu"), + display_name="floatx_cpu", + extra_cflags=["-O3", "-std=c++17"], + extra_ldflags=[], + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=False, + ) + return _FLOATX_CPU_TORCH_OPS_EXTENSION + + def load_pack_block_extension(*, verbose: bool = False) -> Optional[object]: """Ensure the pack_block CPU extension is built and loaded.""" @@ -85,54 +964,63 @@ def load_pack_block_extension(*, verbose: bool = False) -> Optional[object]: if _PACK_BLOCK_EXTENSION_INITIALISED and _PACK_BLOCK_EXTENSION: return _PACK_BLOCK_EXTENSION - try: - spec = importlib.util.find_spec("gptqmodel_pack_block_cpu") - except (ModuleNotFoundError, AttributeError): - spec = None - - if spec and spec.origin: - try: - torch.ops.load_library(spec.origin) - if hasattr(torch.ops.gptqmodel, "pack_block_cpu"): - log.debug("pack_block_cpu extension loaded from %s", spec.origin) - _PACK_BLOCK_EXTENSION = True - _PACK_BLOCK_EXTENSION_INITIALISED = True - return _PACK_BLOCK_EXTENSION - except Exception as exc: # pragma: no cover - environment-specific - log.debug("pack_block_cpu prebuilt load failed: %s", exc) - - project_root = Path(__file__).resolve().parents[2] - source_path = project_root / "pack_block_cpu.cpp" - if not source_path.exists(): - source_path = project_root / "gptqmodel_ext" / "pack_block_cpu.cpp" + source_path = _pack_block_source_path() if not source_path.exists(): log.debug("pack_block_cpu extension source not found at %s", source_path) _PACK_BLOCK_EXTENSION = None _PACK_BLOCK_EXTENSION_INITIALISED = True return None - extra_cflags = ["-O3", "-std=c++17"] - extra_ldflags: list[str] = [] - - build_dir = os.getenv("GPTQMODEL_EXT_BUILD") - - if not verbose: - verbose = env_flag("GPTQMODEL_EXT_VERBOSE", True) - try: - safe_load_cpp_ext( - name="gptqmodel_pack_block_cpu", - sources=[str(source_path)], - extra_cflags=extra_cflags, - extra_ldflags=extra_ldflags, - build_directory=build_dir, - verbose=verbose, - is_python_module=False, - ) + from gptqmodel import extension as extension_api + + previous_verbose = os.environ.get("GPTQMODEL_EXT_VERBOSE") + if verbose: + os.environ["GPTQMODEL_EXT_VERBOSE"] = "1" + try: + _PACK_BLOCK_EXTENSION = extension_api.load(name="pack_block_cpu")["pack_block_cpu"] + finally: + if verbose: + if previous_verbose is None: + os.environ.pop("GPTQMODEL_EXT_VERBOSE", None) + else: + os.environ["GPTQMODEL_EXT_VERBOSE"] = previous_verbose log.debug("pack_block_cpu extension loaded from %s", source_path) - _PACK_BLOCK_EXTENSION = True except Exception as exc: # pragma: no cover - environment-specific log.debug("pack_block_cpu extension build failed: %s", exc) _PACK_BLOCK_EXTENSION = None _PACK_BLOCK_EXTENSION_INITIALISED = True return _PACK_BLOCK_EXTENSION + + +def load_floatx_cpu_extension(*, verbose: bool = False) -> Optional[object]: + """Ensure the floatx CPU extension is built and loaded.""" + + namespace = getattr(torch.ops, "gptqmodel_floatx", None) + if namespace is not None and hasattr(namespace, "dequantize_fp8_cpu") and hasattr(namespace, "dequantize_fp4_cpu"): + return namespace + + source_path = _floatx_cpu_source_path() + if not source_path.exists(): + log.debug("floatx_cpu extension source not found at %s", source_path) + return None + + try: + from gptqmodel import extension as extension_api + + previous_verbose = os.environ.get("GPTQMODEL_EXT_VERBOSE") + if verbose: + os.environ["GPTQMODEL_EXT_VERBOSE"] = "1" + try: + extension = extension_api.load(name="floatx_cpu")["floatx_cpu"] + finally: + if verbose: + if previous_verbose is None: + os.environ.pop("GPTQMODEL_EXT_VERBOSE", None) + else: + os.environ["GPTQMODEL_EXT_VERBOSE"] = previous_verbose + log.debug("floatx_cpu extension loaded from %s", source_path) + return extension + except Exception as exc: # pragma: no cover - environment-specific + log.debug("floatx_cpu extension build failed: %s", exc) + return None diff --git a/gptqmodel/utils/data.py b/gptqmodel/utils/data.py index 7e30f0ae4..8b6bc9351 100644 --- a/gptqmodel/utils/data.py +++ b/gptqmodel/utils/data.py @@ -141,7 +141,11 @@ def make_data_block( return new_samples -def collate_data(batch: List[Dict[str, List[List[int]]]], pad_token_id: int) -> Dict[str, Tensor]: +def collate_data( + batch: List[Dict[str, List[List[int]]]], + pad_token_id: int, + padding_side: str = "right", +) -> Dict[str, Tensor]: """ Collate an outer batch (size B) of items, where each item holds multiple rows. We flatten the rows across items, pad to a global max length, and stack into @@ -178,7 +182,7 @@ def collate_data(batch: List[Dict[str, List[List[int]]]], pad_token_id: int) -> # Compute global max length max_len = max(t.numel() for t in rows_ids) if rows_ids else 0 - # Right-pad each row to global max_len + # Right- or left-pad each row to global max_len def right_pad(row: torch.Tensor, pad_value, dtype=None) -> torch.Tensor: pad_len = max_len - row.numel() if pad_len <= 0: @@ -191,9 +195,28 @@ def right_pad(row: torch.Tensor, pad_value, dtype=None) -> torch.Tensor: dim=0, ) - padded_ids = [right_pad(t, pad_token_id, dtype=torch.long) for t in rows_ids] + def left_pad(row: torch.Tensor, pad_value, dtype=None) -> torch.Tensor: + pad_len = max_len - row.numel() + if pad_len <= 0: + return row + return torch.cat( + [ + torch.full((pad_len,), pad_value, dtype=dtype or row.dtype, device=row.device), + row, + ], + dim=0, + ) + + if padding_side == "left": + pad_fn = left_pad + elif padding_side == "right": + pad_fn = right_pad + else: + raise ValueError(f"Unsupported padding_side `{padding_side}`. Expected `left` or `right`.") + + padded_ids = [pad_fn(t, pad_token_id, dtype=torch.long) for t in rows_ids] # pad masks with False, not 0 - padded_msk = [right_pad(t, False, dtype=torch.bool) for t in rows_mask] + padded_msk = [pad_fn(t, False, dtype=torch.bool) for t in rows_mask] # Stack into [total_rows_in_batch, max_len] input_ids = torch.stack(padded_ids, dim=0) if padded_ids else torch.empty((0, 0), dtype=torch.long) diff --git a/gptqmodel/utils/device_telemetry.py b/gptqmodel/utils/device_telemetry.py new file mode 100644 index 000000000..e52ff84e0 --- /dev/null +++ b/gptqmodel/utils/device_telemetry.py @@ -0,0 +1,85 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Structured, env-gated device telemetry for placement debugging.""" + +from __future__ import annotations + +import threading +import time +from typing import Any, Dict, List + +import torch + +from .env import env_flag +from .logger import setup_logger + + +log = setup_logger() + +_DEVICE_TELEMETRY_ENV = "GPTQMODEL_DEVICE_TELEMETRY" +_records_lock = threading.Lock() +_records: List[Dict[str, Any]] = [] + + +def device_telemetry_enabled() -> bool: + """Return ``True`` when device telemetry should be emitted.""" + + return env_flag(_DEVICE_TELEMETRY_ENV, default="0") + + +def _normalize_field(value: Any) -> Any: + """Convert telemetry values into log-friendly primitives.""" + + if isinstance(value, torch.device): + return str(value) + if isinstance(value, torch.Tensor): + return str(value.device) + if isinstance(value, (list, tuple)): + return [_normalize_field(v) for v in value] + if isinstance(value, dict): + return {str(k): _normalize_field(v) for k, v in value.items()} + return value + + +def emit_device_telemetry(event: str, **fields: Any) -> None: + """Record and log one structured telemetry event when enabled.""" + + if not device_telemetry_enabled(): + return + + record = { + "event": event, + "ts": round(time.time(), 6), + } + for key, value in fields.items(): + record[key] = _normalize_field(value) + + with _records_lock: + _records.append(record) + + log.info(f"DeviceTelemetry: {record}") + + +def clear_device_telemetry_records() -> None: + """Discard previously captured telemetry records.""" + + with _records_lock: + _records.clear() + + +def get_device_telemetry_records() -> List[Dict[str, Any]]: + """Return a copy of the captured telemetry records.""" + + with _records_lock: + return [dict(record) for record in _records] + + +__all__ = [ + "clear_device_telemetry_records", + "device_telemetry_enabled", + "emit_device_telemetry", + "get_device_telemetry_records", +] diff --git a/gptqmodel/utils/env.py b/gptqmodel/utils/env.py index 50f8a7064..76eaf063d 100644 --- a/gptqmodel/utils/env.py +++ b/gptqmodel/utils/env.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -"""Environment variable helpers used across GPTQModel.""" +"""Environment variable helpers used across GPT-QModel.""" from __future__ import annotations diff --git a/gptqmodel/utils/exllamav2.py b/gptqmodel/utils/exllamav2.py index 05288d6f0..e5078fd49 100644 --- a/gptqmodel/utils/exllamav2.py +++ b/gptqmodel/utils/exllamav2.py @@ -3,8 +3,20 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from __future__ import annotations + +from pathlib import Path + import torch +from .cpp import ( + TorchOpsJitExtension, + cuda_include_paths_with_fallback, + default_jit_cflags, + default_jit_cuda_cflags, + default_torch_ops_build_root, +) + class ScratchSpace: def __init__(self, scratch_bytes, dev): @@ -24,3 +36,227 @@ def get_slice(self, size_bytes): def next_multiple(x, multiple): return ((x + multiple - 1) // multiple) * multiple + + +def _exllamav2_root() -> Path: + return Path(__file__).resolve().parents[2] / "gptqmodel_ext" / "exllamav2" + + +def _exllamav2_gptq_sources() -> list[str]: + root = _exllamav2_root() + return [ + str(root / "ext_gptq.cpp"), + str(root / "cuda" / "q_matrix.cu"), + str(root / "cuda" / "q_gemm.cu"), + ] + + +def _exllamav2_required_cuda_headers() -> tuple[str, ...]: + return ("cusparse.h",) + + +def _exllamav2_include_paths() -> list[str]: + return cuda_include_paths_with_fallback( + [str(_exllamav2_root())], + required_header_names=_exllamav2_required_cuda_headers(), + ) + + +def _exllamav2_gptq_extra_cflags() -> list[str]: + return default_jit_cflags(opt_level="O2", enable_bf16=True) + + +def _exllamav2_gptq_extra_cuda_cflags() -> list[str]: + return default_jit_cuda_cflags( + opt_level="O2", + enable_bf16=True, + include_lineinfo=True, + include_nvcc_threads=True, + include_ptxas_optimizations=True, + include_fatbin_compression=True, + include_diag_suppress=True, + ) + + +def _exllamav2_extra_cflags() -> list[str]: + return default_jit_cflags(enable_bf16=True) + + +def _exllamav2_extra_cuda_cflags() -> list[str]: + return default_jit_cuda_cflags( + enable_bf16=True, + include_lineinfo=True, + include_nvcc_threads=True, + include_ptxas_optimizations=True, + include_fatbin_compression=True, + include_diag_suppress=True, + ) + + +def _exllamav2_awq_extra_cflags() -> list[str]: + return default_jit_cflags(opt_level=None, enable_bf16=True) + + +def _exllamav2_awq_extra_cuda_cflags() -> list[str]: + return default_jit_cuda_cflags( + opt_level=None, + enable_bf16=True, + include_lineinfo=True, + include_nvcc_threads=True, + include_ptxas_optimizations=True, + include_fatbin_compression=True, + include_diag_suppress=True, + ) + + +_EXLLAMAV2_GPTQ_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name="gptqmodel_exllamav2_ops", + namespace="gptqmodel_exllamav2", + required_ops=("make_q_matrix", "gemm_half_q_half"), + sources=_exllamav2_gptq_sources, + build_root_env="GPTQMODEL_EXLLAMAV2_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("exllamav2"), + display_name="ExLlamaV2 GPTQ", + extra_cflags=_exllamav2_gptq_extra_cflags, + extra_cuda_cflags=_exllamav2_gptq_extra_cuda_cflags, + extra_include_paths=_exllamav2_include_paths, + force_rebuild_env="GPTQMODEL_EXLLAMAV2_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + +# Shared AWQ singleton so every caller reuses the same torch.ops cache and +# first-use build policy instead of depending on setup-time wheels. +_EXLLAMAV2_AWQ_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name="gptqmodel_exllamav2_awq_ops", + namespace="gptqmodel_exllamav2_awq", + required_ops=("make_q_matrix_awq", "gemm_half_q_half_awq"), + sources=lambda: [ + str(_exllamav2_root() / "ext_awq.cpp"), + str(_exllamav2_root() / "cuda" / "q_matrix_awq.cu"), + str(_exllamav2_root() / "cuda" / "q_gemm_awq.cu"), + ], + build_root_env="GPTQMODEL_EXLLAMAV2_AWQ_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("exllamav2_awq"), + display_name="ExLlamaV2 AWQ", + extra_cflags=_exllamav2_awq_extra_cflags, + extra_cuda_cflags=_exllamav2_awq_extra_cuda_cflags, + extra_include_paths=_exllamav2_include_paths, + force_rebuild_env="GPTQMODEL_EXLLAMAV2_AWQ_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +def _extension_api(): + from gptqmodel import extension as extension_api + + return extension_api + + +def clear_exllamav2_gptq_extension_cache() -> None: + _EXLLAMAV2_GPTQ_TORCH_OPS_EXTENSION.clear_cache() + + +def exllamav2_gptq_runtime_available() -> bool: + return _extension_api().is_available("exllamav2") + + +def exllamav2_gptq_runtime_error() -> str: + extension_api = _extension_api() + if extension_api.is_available("exllamav2"): + return "" + return ( + extension_api.error("exllamav2") + or "ExLlamaV2 GPTQ CUDA runtime unavailable." + ) + + +def prewarm_exllamav2_gptq_extension() -> bool: + return _extension_api().load(name="exllamav2")["exllamav2"] + + +def exllamav2_make_q_matrix( + q_weight, + q_perm, + q_invperm, + q_scale, + q_scale_max, + q_groups, + gptq_qzeros, + gptq_scales, + gptq_g_idx, + temp_dq, +) -> int: + return int( + _extension_api().op("exllamav2", "make_q_matrix")( + q_weight, + q_perm, + q_invperm, + q_scale, + q_scale_max, + q_groups, + gptq_qzeros, + gptq_scales, + gptq_g_idx, + temp_dq, + ) + ) + + +def exllamav2_gemm_half_q_half(a, q_handle: int, c, force_cuda: bool = False) -> None: + _extension_api().op("exllamav2", "gemm_half_q_half")(a, int(q_handle), c, bool(force_cuda)) + + +def clear_exllamav2_awq_extension_cache() -> None: + _EXLLAMAV2_AWQ_TORCH_OPS_EXTENSION.clear_cache() + + +def exllamav2_awq_runtime_available() -> bool: + return _extension_api().is_available("exllamav2_awq") + + +def exllamav2_awq_runtime_error() -> str: + extension_api = _extension_api() + if extension_api.is_available("exllamav2_awq"): + return "" + return ( + extension_api.error("exllamav2_awq") + or "ExLlamaV2 AWQ CUDA runtime unavailable." + ) + + +def prewarm_exllamav2_awq_extension() -> bool: + return _extension_api().load(name="exllamav2_awq")["exllamav2_awq"] + + +def exllamav2_awq_make_q_matrix( + q_weight, + q_perm, + q_invperm, + q_scale, + q_scale_max, + q_groups, + gptq_qzeros, + gptq_scales, + gptq_g_idx, + temp_dq, +) -> int: + return int( + _extension_api().op("exllamav2_awq", "make_q_matrix_awq")( + q_weight, + q_perm, + q_invperm, + q_scale, + q_scale_max, + q_groups, + gptq_qzeros, + gptq_scales, + gptq_g_idx, + temp_dq, + ) + ) + + +def exllamav2_awq_gemm_half_q_half(a, q_handle: int, c, force_cuda: bool = False) -> None: + _extension_api().op("exllamav2_awq", "gemm_half_q_half_awq")(a, int(q_handle), c, bool(force_cuda)) diff --git a/gptqmodel/utils/exllamav3.py b/gptqmodel/utils/exllamav3.py new file mode 100644 index 000000000..775e1e376 --- /dev/null +++ b/gptqmodel/utils/exllamav3.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +# +# Portions of this file are adapted from turboderp-org/exllamav3. +# Credits: TurboDerp / ExLlamaV3 contributors. + +from __future__ import annotations + +from typing import Any, Dict, Iterable, Optional, Type + +import torch +import torch.nn as nn +import transformers +from torch.nn.modules.conv import _ConvNd + +from ..looper.named_module import NamedModule +from ..nn_modules.exllamav3 import ExllamaV3Linear +from .model import recurse_setattr + + +def _resolve_linear_shape(submodule: nn.Module) -> tuple[int, int]: + named = submodule if isinstance(submodule, NamedModule) else None + target = named.module if named is not None else submodule + + if named is not None: + in_features = named.state.get("in_features") + out_features = named.state.get("out_features") + if in_features is not None and out_features is not None: + return int(in_features), int(out_features) + + if isinstance(target, nn.Linear): + return target.in_features, target.out_features + if isinstance(target, _ConvNd): + return target.in_channels, target.out_channels + if isinstance(target, transformers.Conv1D): + return target.weight.shape[0], target.weight.shape[1] + + in_features = getattr(target, "in_features", None) + out_features = getattr(target, "out_features", None) + if in_features is not None and out_features is not None: + return int(in_features), int(out_features) + + raise NotImplementedError(f"Unsupported EXL3 module type: {target.__class__.__name__}") + + +def create_exllamav3_module( + *, + module_root: nn.Module, + name: str, + submodule: nn.Module, + tensors: Dict[str, torch.Tensor], + module_cls: Type[nn.Module] = ExllamaV3Linear, +) -> nn.Module: + in_features, out_features = _resolve_linear_shape(submodule) + new_module = module_cls.from_tensors( + in_features=in_features, + out_features=out_features, + name=name, + tensors=tensors, + ) + recurse_setattr(module_root, name, new_module) + return new_module + + +def build_exllamav3_tensor_storage(model: nn.Module) -> Dict[str, Dict[str, Any]]: + storage: Dict[str, Dict[str, Any]] = {} + for name, module in model.named_modules(): + if getattr(module, "QUANT_TYPE", None) == "exl3" and hasattr(module, "tensor_storage_entry"): + storage[name] = module.tensor_storage_entry() + return storage + + +def replace_exllamav3_placeholders( + *, + model: nn.Module, + module_names: Iterable[str], + tensor_storage: Optional[Dict[str, Dict[str, Any]]] = None, + module_cls: Type[nn.Module] = ExllamaV3Linear, +) -> None: + module_lookup = dict(model.named_modules()) + storage_map = tensor_storage or {} + + for module_name in module_names: + submodule = module_lookup.get(module_name) + if submodule is None: + continue + + if not isinstance(submodule, (nn.Linear, transformers.Conv1D, _ConvNd)): + continue + + in_features, out_features = _resolve_linear_shape(submodule) + new_module = module_cls( + in_features=in_features, + out_features=out_features, + name=module_name, + tensor_storage=storage_map.get(module_name), + ) + recurse_setattr(model, module_name, new_module) diff --git a/gptqmodel/utils/fallback.py b/gptqmodel/utils/fallback.py new file mode 100644 index 000000000..2f2ac1037 --- /dev/null +++ b/gptqmodel/utils/fallback.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from typing import Any, Optional, Tuple + +from gptqmodel.quantization.config import Fallback, FallbackStrategy + + +def normalize_fallback( + value: Any, + default: Optional[Fallback] = None, +) -> Optional[Fallback]: + if value is None: + return default + if isinstance(value, Fallback): + return value + if isinstance(value, dict): + fallback = default if isinstance(default, Fallback) else Fallback() + return Fallback( + strategy=value.get("strategy", fallback.strategy), + threshold=value.get("threshold", fallback.threshold), + ) + raise ValueError( + "normalize_fallback: expected Fallback, dict, or None. " + ) + + +def _parse_threshold(setting: Any) -> Tuple[Optional[float], bool]: + """ + Returns (threshold_value, is_percent). + Percent values are returned as the fractional percentage (e.g., "5%" -> 5.0). + """ + if isinstance(setting, str): + stripped = setting.strip() + if stripped.endswith("%"): + try: + val = float(stripped[:-1]) + return val, True + except ValueError: + return None, False + try: + return float(stripped), False + except ValueError: + return None, False + + if isinstance(setting, (int, float)): + return float(setting), False + + return None, False + + +def resolve_fallback_strategy(strategy: Any) -> FallbackStrategy: + """ + Normalize a fallback strategy. + """ + if isinstance(strategy, Fallback): + strategy = strategy.strategy + if isinstance(strategy, dict): + strategy = strategy.get("strategy", FallbackStrategy.RTN) + if strategy is None: + resolved = FallbackStrategy.RTN + elif isinstance(strategy, FallbackStrategy): + resolved = strategy + elif isinstance(strategy, str): + normalized = strategy.strip().lower() + try: + resolved = FallbackStrategy(normalized) + except ValueError: + resolved = FallbackStrategy.RTN + else: + resolved = FallbackStrategy.RTN + + return resolved + + +def should_use_fallback( + setting: Any, + observed_samples: float, + expected_total_samples: Optional[float] = None, +) -> bool: + if isinstance(setting, Fallback): + setting = setting.threshold + if isinstance(setting, dict): + setting = setting.get("threshold", None) + threshold_value, _ = resolve_threshold(setting, expected_total_samples) + if threshold_value is None: + return False + return observed_samples < threshold_value + + +def resolve_threshold( + setting: Any, + expected_total_samples: Optional[float] = None, +) -> Tuple[Optional[float], bool]: + """ + Resolve a threshold into a raw numeric value and whether it was percent-based. + """ + if isinstance(setting, Fallback): + setting = setting.threshold + if isinstance(setting, dict): + setting = setting.get("threshold", None) + if not setting: + return None, False + + if setting is True: + # Tiny positive epsilon so 0 triggers but positive counts do not when using `<`. + return 1e-9, False + + threshold, is_percent = _parse_threshold(setting) + if threshold is None: + return None, False + + if is_percent: + base = expected_total_samples if expected_total_samples else 1.0 + threshold_value = (threshold / 100.0) * float(base) + else: + threshold_value = threshold + + return threshold_value, is_percent diff --git a/gptqmodel/utils/gptq_pro.py b/gptqmodel/utils/gptq_pro.py new file mode 100644 index 000000000..7debeb3c4 --- /dev/null +++ b/gptqmodel/utils/gptq_pro.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import os +import shutil +import threading +from pathlib import Path +from typing import Optional + +import torch +from torch.utils.cpp_extension import _get_build_directory, load + +from ._extension_loader import load_extension_module +from .env import env_flag +from .logger import setup_logger +from .rocm import IS_ROCM + + +log = setup_logger() + +_GPTQ_PRO_LOCK = threading.Lock() +_GPTQ_PRO_MODULE = None +_GPTQ_PRO_INITIALISED = False +_GPTQ_PRO_BUILD_PREPARED = False +gptq_pro_import_exception: Optional[str] = None + + +def _validate_gptq_pro_device_support() -> bool: + return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 and not IS_ROCM + + +def _gptq_pro_sources() -> tuple[Path, Path]: + project_root = Path(__file__).resolve().parents[2] + ext_dir = project_root / "gptqmodel_ext" / "gptq_pro" + return ext_dir / "gptq_pro_torch.cpp", ext_dir / "gptq_pro_kernel.cu" + + +def _prepare_build_directory(verbose: bool) -> str: + global _GPTQ_PRO_BUILD_PREPARED + + build_dir_env = os.getenv("GPTQMODEL_EXT_BUILD") + if build_dir_env: + build_directory = Path(build_dir_env) / "gptqmodel_gptq_pro_kernels" + else: + build_directory = Path(_get_build_directory("gptqmodel_gptq_pro_kernels", verbose=verbose)) + + if not _GPTQ_PRO_BUILD_PREPARED and build_directory.exists(): + shutil.rmtree(build_directory, ignore_errors=True) + + build_directory.mkdir(parents=True, exist_ok=True) + _GPTQ_PRO_BUILD_PREPARED = True + return str(build_directory) + + +def _build_gptq_pro_extension(verbose: bool): + source_cpp, source_cu = _gptq_pro_sources() + if not source_cpp.is_file() or not source_cu.is_file(): + raise ImportError("gptq_pro extension sources are missing from the checkout.") + + build_directory = _prepare_build_directory(verbose=verbose) + return load( + name="gptqmodel_gptq_pro_kernels", + sources=[str(source_cpp), str(source_cu)], + extra_cflags=["-O3", "-std=c++17"], + extra_cuda_cflags=[ + "-O3", + "-std=c++17", + "-lineinfo", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + ], + build_directory=build_directory, + verbose=verbose, + ) + + +def ensure_gptq_pro_loaded(*, verbose: Optional[bool] = None): + global _GPTQ_PRO_MODULE, _GPTQ_PRO_INITIALISED, gptq_pro_import_exception + + if _GPTQ_PRO_MODULE is not None: + return _GPTQ_PRO_MODULE + + if verbose is None: + verbose = env_flag("GPTQMODEL_EXT_VERBOSE", False) + + with _GPTQ_PRO_LOCK: + if _GPTQ_PRO_MODULE is not None: + return _GPTQ_PRO_MODULE + if _GPTQ_PRO_INITIALISED and gptq_pro_import_exception is not None: + raise ImportError(gptq_pro_import_exception) + + errors = [] + try: + _GPTQ_PRO_MODULE = load_extension_module("gptqmodel_gptq_pro_kernels") + gptq_pro_import_exception = None + _GPTQ_PRO_INITIALISED = True + return _GPTQ_PRO_MODULE + except ImportError as exc: + errors.append(f"prebuilt load failed: {exc}") + + if not _validate_gptq_pro_device_support(): + gptq_pro_import_exception = ( + "GPTQ-Pro kernel requires Linux CUDA with compute capability >= 8.0 and does not support ROCm." + ) + _GPTQ_PRO_INITIALISED = True + raise ImportError(gptq_pro_import_exception) + + try: + _GPTQ_PRO_MODULE = _build_gptq_pro_extension(verbose=bool(verbose)) + gptq_pro_import_exception = None + _GPTQ_PRO_INITIALISED = True + return _GPTQ_PRO_MODULE + except Exception as exc: # pragma: no cover - environment-specific + errors.append(f"jit build failed: {exc}") + gptq_pro_import_exception = " | ".join(errors) + _GPTQ_PRO_INITIALISED = True + raise ImportError(gptq_pro_import_exception) from exc + + +def gptq_pro_qweight_to_b_packed(qweight: torch.Tensor) -> torch.Tensor: + if qweight.dtype != torch.int32: + raise ValueError(f"Expected int32 qweight tensor, got `{qweight.dtype}`.") + if qweight.dim() != 2: + raise ValueError(f"Expected 2D qweight tensor, got shape `{tuple(qweight.shape)}`.") + + qweight = qweight.contiguous() + shifts = torch.arange(0, 32, 4, device=qweight.device, dtype=qweight.dtype).view(1, 8, 1) + unpacked = torch.bitwise_and(torch.bitwise_right_shift(qweight.unsqueeze(1), shifts), 0xF).to(torch.uint8) + unpacked = unpacked.reshape(-1, qweight.shape[1]) + return (unpacked[0::2] | (unpacked[1::2] << 4)).contiguous() + + +def apply_gptq_pro_linear( + input: torch.Tensor, + b_packed: torch.Tensor, + scales: torch.Tensor, + group_size: int, +) -> torch.Tensor: + module = ensure_gptq_pro_loaded() + return module.gptq_pro_gemm(input, b_packed, scales, int(group_size)) + + +__all__ = [ + "_validate_gptq_pro_device_support", + "apply_gptq_pro_linear", + "ensure_gptq_pro_loaded", + "gptq_pro_import_exception", + "gptq_pro_qweight_to_b_packed", +] diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py index a695bbc60..0ecaa5b1c 100644 --- a/gptqmodel/utils/hf.py +++ b/gptqmodel/utils/hf.py @@ -3,12 +3,31 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import inspect import json -from typing import Any, Optional +import os +import sys +import warnings +from contextlib import contextmanager +from functools import lru_cache +from typing import Any, List, Optional, Union +import numpy as np import torch +import transformers from accelerate import init_empty_weights -from transformers import GenerationConfig, PreTrainedModel +from tokenicer import Tokenicer +from transformers import AutoConfig, AutoModelForCausalLM, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase + +from ..nn_modules.qlinear.gguf import ( + PRISM_Q1_0_G128_BLOCK_SIZE, + PRISM_Q1_0_G128_NAME, + PRISM_Q1_0_G128_TYPE_SIZE, + PRISM_Q1_0_G128_VALUE, + _dequantize_prism_q1_0_g128, + _is_prism_q1_0_g128, +) +from ..utils import _MONKEY_PATCH_LOCK, internal_gguf # Compatibility wrapper for no_init_weights across different transformers versions @@ -22,9 +41,1260 @@ from ..utils.logger import setup_logger -__all__ = ["no_init_weights"] +__all__ = [ + "no_init_weights", + "suspend_hf_weight_init", + "get_hf_config_dtype", + "normalize_torch_dtype_kwarg", + "normalize_hf_config_compat", + "prepare_remote_code_compat", + "prepare_remote_model_init_compat", + "has_native_transformers_causallm_support", + "get_hf_gguf_load_kwargs", + "normalize_model_id_or_path_for_hf_gguf", + "resolve_trust_remote_code", + "set_hf_config_dtype", + "load_tokenizer", +] log = setup_logger() +_TRUST_REMOTE_CODE_OVERRIDE_WARNED: set[tuple[str, str, str]] = set() +_MISSING = object() +INTERNAL_HF_GGUF_FILE_KWARG = "_gptqmodel_hf_gguf_file" +_DENSE_MODEL_FILE_EXTENSIONS = (".safetensors", ".bin", ".pt", ".pth", ".ckpt") +_INTERNAL_GGUF_TORCH_LOADER_ENV = "GPTQMODEL_INTERNAL_GGUF_TORCH_LOADER" +_FALSEY_ENV_VALUES = {"", "0", "false", "off", "no"} + + +def get_hf_config_dtype(config: Any) -> Optional[torch.dtype]: + dtype = getattr(config, "dtype", None) + if dtype is None: + dtype = getattr(config, "torch_dtype", None) + + if dtype is None: + return None + + if isinstance(dtype, torch.dtype): + return dtype + + # If provided as a string (e.g., "float16"), resolve it via torch namespace + if isinstance(dtype, str): + try: + return getattr(torch, dtype) + except AttributeError: + raise ValueError(f"Invalid dtype string: {dtype}") + + raise ValueError(f"dtype must be torch.dtype or str, but got {dtype} (type={type(dtype)})") + + +def set_hf_config_dtype(config: Any, dtype: torch.dtype) -> None: + current_dtype = get_hf_config_dtype(config) + if current_dtype == dtype: + return + + try: + setattr(config, "dtype", dtype) + except Exception: + if getattr(config, "torch_dtype", None) != dtype: + setattr(config, "torch_dtype", dtype) + + +def normalize_torch_dtype_kwarg( + kwargs: dict[str, Any], + *, + api_name: str, + explicit_dtype: Any = _MISSING, +) -> Any: + current_dtype = explicit_dtype + if explicit_dtype is _MISSING: + current_dtype = kwargs.pop("dtype", _MISSING) + + torch_dtype = kwargs.pop("torch_dtype", _MISSING) + if torch_dtype is _MISSING: + if explicit_dtype is _MISSING and current_dtype is not _MISSING: + kwargs["dtype"] = current_dtype + return current_dtype + + log.warn.once(f"{api_name}: `torch_dtype` is deprecated; use `dtype` instead.") + + if current_dtype is _MISSING or current_dtype is None or current_dtype == "auto": + current_dtype = torch_dtype + elif current_dtype != torch_dtype: + raise ValueError( + f"{api_name}: received both `dtype` and deprecated `torch_dtype` with different values. " + "Please pass only `dtype`." + ) + + if explicit_dtype is _MISSING: + kwargs["dtype"] = current_dtype + return current_dtype + + +@contextmanager +def suspend_hf_weight_init(): + """Disable HF/torch parameter init temporarily and always restore globals.""" + + def _skip_init(*args, **kwargs): + return None + + original_kaiming_uniform = torch.nn.init.kaiming_uniform_ + original_uniform = torch.nn.init.uniform_ + original_normal = torch.nn.init.normal_ + + modeling_utils = transformers.modeling_utils + had_init_flag = hasattr(modeling_utils, "_init_weights") + original_init_flag = getattr(modeling_utils, "_init_weights", None) + + torch.nn.init.kaiming_uniform_ = _skip_init + torch.nn.init.uniform_ = _skip_init + torch.nn.init.normal_ = _skip_init + modeling_utils._init_weights = False + + try: + with no_init_weights(): + yield + finally: + torch.nn.init.kaiming_uniform_ = original_kaiming_uniform + torch.nn.init.uniform_ = original_uniform + torch.nn.init.normal_ = original_normal + if had_init_flag: + modeling_utils._init_weights = original_init_flag + elif hasattr(modeling_utils, "_init_weights"): + delattr(modeling_utils, "_init_weights") + + +def _raise_public_gguf_file_arg_error(api_name: str) -> None: + raise TypeError( + f"{api_name} does not accept `gguf_file`. Pass the GGUF checkpoint as `model_id_or_path`, " + "or pass a model directory / repo containing a single GGUF file." + ) + + +def get_hf_gguf_load_kwargs(kwargs: dict[str, Any]) -> dict[str, str]: + gguf_file = kwargs.get(INTERNAL_HF_GGUF_FILE_KWARG) + if gguf_file is None: + return {} + return {"gguf_file": gguf_file} + + +def _normalize_repo_file_paths(file_names) -> list[str]: + return [str(file_name).replace("\\", "/") for file_name in file_names] + + +def _infer_single_gguf_file(file_names) -> Optional[str]: + normalized_files = _normalize_repo_file_paths(file_names) + gguf_files = sorted(file_name for file_name in normalized_files if file_name.lower().endswith(".gguf")) + if len(gguf_files) != 1: + return None + + dense_files = [ + file_name + for file_name in normalized_files + if file_name.lower().endswith(_DENSE_MODEL_FILE_EXTENSIONS) + ] + if dense_files: + return None + + return gguf_files[0] + + +def _iter_local_repo_files(root_dir: str) -> list[str]: + repo_files = [] + for current_root, _dirs, files in os.walk(root_dir): + for file_name in files: + full_path = os.path.join(current_root, file_name) + repo_files.append(os.path.relpath(full_path, root_dir).replace(os.sep, "/")) + return repo_files + + +@lru_cache(maxsize=None) +def _resolve_hf_gguf_artifact(model_id_or_path: str) -> Optional[tuple[str, str]]: + if os.path.isfile(model_id_or_path) and model_id_or_path.lower().endswith(".gguf"): + model_root = os.path.dirname(os.path.abspath(model_id_or_path)) or "." + return model_root, os.path.basename(model_id_or_path) + + if os.path.isdir(model_id_or_path): + inferred_gguf_file = _infer_single_gguf_file(_iter_local_repo_files(model_id_or_path)) + if inferred_gguf_file is not None: + return os.path.normpath(model_id_or_path), inferred_gguf_file + return None + + try: + from .hub import list_repo_files + except Exception: + return None + + try: + repo_files = list_repo_files(repo_id=model_id_or_path) + except Exception: + return None + + inferred_gguf_file = _infer_single_gguf_file(repo_files) + if inferred_gguf_file is None: + return None + return model_id_or_path, inferred_gguf_file + + +def _transformers_has_native_prism_gguf_support() -> bool: + try: + import transformers.modeling_gguf_pytorch_utils as gguf_utils + except Exception: + return False + + return hasattr(gguf_utils, "_dequantize_prism_q1_0_g128") + + +def _internal_gguf_torch_loader_enabled() -> bool: + raw = os.getenv(_INTERNAL_GGUF_TORCH_LOADER_ENV) + if raw is not None: + return str(raw).strip().lower() not in _FALSEY_ENV_VALUES + return bool(os.getenv("GPTQMODEL_INTERNAL_GGUF_DEQUANT_DEVICE", "").strip()) + + +def _load_gguf_checkpoint_torch_direct( + *, + gguf_utils, + original_load_gguf_checkpoint, + gguf_checkpoint_path, + return_tensors: bool = False, + model_to_load=None, +): + if not return_tensors or model_to_load is None or not _internal_gguf_torch_loader_enabled(): + return original_load_gguf_checkpoint( + gguf_checkpoint_path, + return_tensors=return_tensors, + model_to_load=model_to_load, + ) + + parsed_parameters = original_load_gguf_checkpoint( + gguf_checkpoint_path, + return_tensors=False, + model_to_load=model_to_load, + ) + config = parsed_parameters.get("config", {}) + model_type = config.get("model_type") + if model_type != internal_gguf.MODEL_ARCH_QWEN3: + return original_load_gguf_checkpoint( + gguf_checkpoint_path, + return_tensors=True, + model_to_load=model_to_load, + ) + + processor_cls = gguf_utils.TENSOR_PROCESSORS.get(model_type, gguf_utils.TensorProcessor) + if processor_cls is not gguf_utils.TensorProcessor: + return original_load_gguf_checkpoint( + gguf_checkpoint_path, + return_tensors=True, + model_to_load=model_to_load, + ) + + processor = processor_cls(config=config) + tensor_key_mapping = gguf_utils.get_gguf_hf_weights_map(model_to_load, processor) + target_device = internal_gguf._resolve_torch_dequant_device() + reader = internal_gguf.GGUFReader(gguf_checkpoint_path) + parsed_parameters["tensors"] = {} + + for tensor in gguf_utils.tqdm(reader.tensors, desc="Converting GGUF tensors directly to torch..."): + name = tensor.name + weights = internal_gguf.dequantize_to_torch( + tensor.data, + tensor.tensor_type, + device=target_device, + ) + + result = processor.process( + weights=weights, + name=name, + tensor_key_mapping=tensor_key_mapping, + parsed_parameters=parsed_parameters, + ) + + weights = result.weights + name = result.name + if name not in tensor_key_mapping: + continue + + if not torch.is_tensor(weights): + weights = torch.from_numpy(np.array(weights, copy=True, order="C")) + if target_device is not None: + weights = weights.to(device=target_device) + + parsed_parameters["tensors"][tensor_key_mapping[name]] = weights.contiguous() + + return parsed_parameters + + +def _patch_transformers_internal_gguf_torch_loader(gguf_utils) -> None: + with _MONKEY_PATCH_LOCK: + if getattr(gguf_utils, "_GPTQMODEL_INTERNAL_GGUF_TORCH_LOADER_PATCHED", False): + return + + original_load_gguf_checkpoint = gguf_utils.load_gguf_checkpoint + + def _wrapped_load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False, model_to_load=None): + try: + return _load_gguf_checkpoint_torch_direct( + gguf_utils=gguf_utils, + original_load_gguf_checkpoint=original_load_gguf_checkpoint, + gguf_checkpoint_path=gguf_checkpoint_path, + return_tensors=return_tensors, + model_to_load=model_to_load, + ) + except Exception as exc: + log.debug( + "HF: internal torch GGUF loader fell back to the stock loader for `%s`: %s", + gguf_checkpoint_path, + exc, + ) + return original_load_gguf_checkpoint( + gguf_checkpoint_path, + return_tensors=return_tensors, + model_to_load=model_to_load, + ) + + gguf_utils._gptqmodel_original_load_gguf_checkpoint = original_load_gguf_checkpoint + gguf_utils.load_gguf_checkpoint = _wrapped_load_gguf_checkpoint + gguf_utils._GPTQMODEL_INTERNAL_GGUF_TORCH_LOADER_PATCHED = True + + +def _patch_transformers_prism_gguf_compat(*, api_name: str) -> None: + try: + import transformers.modeling_gguf_pytorch_utils as gguf_utils + from transformers.utils import import_utils as hf_import_utils + except Exception: + return + + with _MONKEY_PATCH_LOCK: + internal_gguf.install_runtime() + _patch_transformers_internal_gguf_torch_loader(gguf_utils) + + def _is_gguf_available(*_args, **_kwargs) -> bool: + return True + + if not gguf_utils.is_gguf_available(): + gguf_utils.is_gguf_available = _is_gguf_available + if hasattr(hf_import_utils, "is_gguf_available"): + hf_import_utils.is_gguf_available = _is_gguf_available + + if _transformers_has_native_prism_gguf_support(): + return + + gguf_utils.PRISM_Q1_0_G128_NAME = PRISM_Q1_0_G128_NAME + gguf_utils.PRISM_Q1_0_G128_VALUE = PRISM_Q1_0_G128_VALUE + gguf_utils.PRISM_Q1_0_G128_BLOCK_SIZE = PRISM_Q1_0_G128_BLOCK_SIZE + gguf_utils.PRISM_Q1_0_G128_TYPE_SIZE = PRISM_Q1_0_G128_TYPE_SIZE + gguf_utils._is_prism_q1_0_g128 = _is_prism_q1_0_g128 + gguf_utils._dequantize_prism_q1_0_g128 = _dequantize_prism_q1_0_g128 + + log.warning( + "HF: installed transformers lacks native Prism GGUF support; GPT-QModel registered its internal " + "GGUF runtime and local Q1_0_g128 compatibility patch for `%s`.", + api_name, + ) + + +def normalize_model_id_or_path_for_hf_gguf( + model_id_or_path: Optional[str], + kwargs: dict[str, Any], + *, + api_name: str, +) -> Optional[str]: + if INTERNAL_HF_GGUF_FILE_KWARG in kwargs: + return model_id_or_path + + if kwargs.pop("gguf_file", None) is not None: + _raise_public_gguf_file_arg_error(api_name) + + if model_id_or_path is None: + return None + + resolved = _resolve_hf_gguf_artifact(str(model_id_or_path)) + if resolved is None: + return model_id_or_path + + normalized_model_id_or_path, gguf_file = resolved + _patch_transformers_prism_gguf_compat(api_name=api_name) + kwargs[INTERNAL_HF_GGUF_FILE_KWARG] = gguf_file + return normalized_model_id_or_path + + +@lru_cache(maxsize=None) +def _detect_native_transformers_causallm_support(model_id_or_path: str) -> tuple[bool, Optional[str], Optional[str]]: + config_load_kwargs: dict[str, Any] = {} + normalized_model_id_or_path = normalize_model_id_or_path_for_hf_gguf( + model_id_or_path, + config_load_kwargs, + api_name="_detect_native_transformers_causallm_support", + ) + try: + config = AutoConfig.from_pretrained( + normalized_model_id_or_path, + trust_remote_code=False, + **get_hf_gguf_load_kwargs(config_load_kwargs), + ) + except Exception as exc: + log.debug("HF: native transformers support check failed for `%s`: %s", normalized_model_id_or_path, exc) + return False, None, None + + model_type = getattr(config, "model_type", None) + try: + model_cls = AutoModelForCausalLM._model_mapping[type(config)] + except Exception as exc: + log.debug( + "HF: config `%s` for `%s` has no native AutoModelForCausalLM mapping: %s", + type(config).__name__, + normalized_model_id_or_path, + exc, + ) + return False, model_type, None + + return True, model_type, getattr(model_cls, "__name__", str(model_cls)) + + +def resolve_trust_remote_code(model_id_or_path: Optional[str], *, trust_remote_code: bool) -> bool: + if not trust_remote_code or not model_id_or_path: + return trust_remote_code + + native_supported, model_type, model_cls_name = _detect_native_transformers_causallm_support(str(model_id_or_path)) + if not native_supported: + return True + + warning_key = (str(model_id_or_path), model_type or "unknown", model_cls_name or "unknown") + if warning_key not in _TRUST_REMOTE_CODE_OVERRIDE_WARNED: + log.warning( + "HF: overriding trust_remote_code=True to False for `%s` because model_type `%s` is integrated in installed transformers as `%s`.", + model_id_or_path, + model_type or "unknown", + model_cls_name or "unknown", + ) + _TRUST_REMOTE_CODE_OVERRIDE_WARNED.add(warning_key) + + return False + + +def has_native_transformers_causallm_support(model_id_or_path: Optional[str]) -> bool: + if not model_id_or_path: + return False + + native_supported, _, _ = _detect_native_transformers_causallm_support(str(model_id_or_path)) + return native_supported + +def _resolve_input_embedding_weight_name(model: PreTrainedModel) -> Optional[str]: + get_input_embeddings = getattr(model, "get_input_embeddings", None) + if not callable(get_input_embeddings): + return None + + try: + input_embeddings = get_input_embeddings() + except Exception: + return None + + if input_embeddings is None: + return None + + weight = getattr(input_embeddings, "weight", None) + if weight is None: + return None + + for name, param in model.named_parameters(remove_duplicate=False): + if param is weight: + return name + + for name, module in model.named_modules(remove_duplicate=False): + if module is input_embeddings: + return f"{name}.weight" if name else "weight" + + return None + + +# Older remote model files sometimes store `_tied_weights_keys` as a plain list +# like `["lm_head.weight"]`. transformers 5.x now expects `{target: source}`, +# and otherwise later save/load helpers fail with `'list' object has no attribute 'keys'`. +def _resolve_legacy_tied_weights_mapping(model: PreTrainedModel, tied_mapping) -> dict[str, str]: + if not isinstance(tied_mapping, (list, tuple, set)): + return {} + + if not getattr(getattr(model, "config", None), "tie_word_embeddings", False): + return {} + + source_name = _resolve_input_embedding_weight_name(model) + if source_name is None: + return {} + + return { + # Legacy list entries only name the tied target, so resolve them back + # to the input embedding weight name expected by transformers 5.x. + target_name: source_name + for target_name in tied_mapping + if isinstance(target_name, str) and target_name != source_name + } + + +# Rewrite legacy list-based `_tied_weights_keys` in-place so transformers 5.x +# save/load code stops crashing on older trust_remote_code models that still use +# the pre-5.x list format. +def _normalize_legacy_tied_weights_keys(model: PreTrainedModel) -> None: + for _name, submodule in model.named_modules(remove_duplicate=False): + tied_mapping = getattr(submodule, "_tied_weights_keys", None) + if not isinstance(tied_mapping, (list, tuple, set)): + continue + + if isinstance(submodule, PreTrainedModel): + submodule._tied_weights_keys = _resolve_legacy_tied_weights_mapping(submodule, tied_mapping) + else: + submodule._tied_weights_keys = {} + + +# Bridge a few transformers 5.x API changes so older trust_remote_code model +# files still import and initialize without editing the cached remote source. +def _patch_transformers_remote_code_compat() -> None: + try: + from transformers.utils import import_utils + except Exception: + return + + try: + from transformers import cache_utils + except Exception: + cache_utils = None + + try: + from transformers.generation import utils as generation_utils + except Exception: + generation_utils = None + + try: + from transformers import utils + except Exception: + utils = None + + import transformers.utils.generic as generic + with _MONKEY_PATCH_LOCK: + if not hasattr(import_utils, "is_torch_fx_available"): + # transformers 5.x removed `import_utils.is_torch_fx_available`, but + # older remote model files still import it during module import. + def is_torch_fx_available() -> bool: + return hasattr(torch, "fx") + + import_utils.is_torch_fx_available = is_torch_fx_available + + if utils is not None and not hasattr(utils, "is_flash_attn_greater_or_equal_2_10"): + legacy_flash_attn_probe = getattr(utils, "is_flash_attn_greater_or_equal", None) + + if legacy_flash_attn_probe: + # Older trust_remote_code model files import the removed + # `is_flash_attn_greater_or_equal_2_10` helper from + # `transformers.utils`; newer transformers only expose the generic + # comparator. + def is_flash_attn_greater_or_equal_2_10() -> bool: + return bool(legacy_flash_attn_probe("2.1.0")) + + utils.is_flash_attn_greater_or_equal_2_10 = is_flash_attn_greater_or_equal_2_10 + + if cache_utils is not None and not hasattr(cache_utils, "SlidingWindowCache") and hasattr(cache_utils, "StaticCache"): + # transformers 5.x folds sliding-window behavior into StaticCache + # layers, but older remote code still imports the legacy symbol. + cache_utils.SlidingWindowCache = cache_utils.StaticCache + + if not hasattr(generic, "check_model_inputs"): + # transformers 5.x removed `transformers.utils.generic.check_model_inputs`, but + # older remote model files still import it during module import. + def check_model_inputs(func=None, **kwargs): + def wrapper(fn): + def inner(self, *args, **kwargs): + return fn(self, *args, **kwargs) + + return inner + + return wrapper(func) if func else wrapper + + generic.check_model_inputs = check_model_inputs + + if cache_utils is not None and not hasattr(cache_utils, "HybridCache") and hasattr(cache_utils, "StaticCache"): + # transformers 5.x also collapsed the legacy HybridCache entrypoint + # into StaticCache, which already instantiates hybrid/sliding layers + # based on the model config. + cache_utils.HybridCache = cache_utils.StaticCache + + cache_base_cls = getattr(cache_utils, "Cache", None) if cache_utils is not None else None + if cache_base_cls is not None and not hasattr(cache_base_cls, "get_max_length") and hasattr(cache_base_cls, "get_max_cache_shape"): + # Older remote decoders expect `get_max_length()`, while newer + # transformers renamed that API to `get_max_cache_shape()`. + def get_max_length(self, layer_idx: int = 0) -> Optional[int]: + max_length = self.get_max_cache_shape(layer_idx) + return None if max_length is None or max_length < 0 else max_length + + cache_base_cls.get_max_length = get_max_length + + if cache_base_cls is not None and not hasattr(cache_base_cls, "get_usable_length") and hasattr(cache_base_cls, "get_seq_length"): + # Recreate the pre-5.x cache eviction helper so remote attention code + # can compute usable KV length against the newer cache interface. + def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) -> int: + max_length = self.get_max_length(layer_idx=layer_idx) + previous_seq_length = self.get_seq_length(layer_idx) + if max_length is not None and previous_seq_length + new_seq_length > max_length: + return max_length - new_seq_length + return previous_seq_length + + cache_base_cls.get_usable_length = get_usable_length + + dynamic_cache_cls = getattr(cache_utils, "DynamicCache", None) if cache_utils is not None else None + if dynamic_cache_cls is not None and not hasattr(dynamic_cache_cls, "to_legacy_cache"): + # Older remote generation code still serializes cache state as + # `Tuple[(key, value), ...]`; rebuild that view from layer storage. + def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor], ...]: + legacy_cache = () + for layer in self.layers: + if not getattr(layer, "is_initialized", False): + continue + legacy_cache += ((layer.keys, layer.values),) + return legacy_cache + + dynamic_cache_cls.to_legacy_cache = to_legacy_cache + + if dynamic_cache_cls is not None and not hasattr(dynamic_cache_cls, "from_legacy_cache"): + # Accept legacy tuple caches by replaying them into the current + # layer-based DynamicCache implementation. + @classmethod + def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.Tensor, torch.Tensor], ...]] = None): + cache = cls() + if past_key_values is not None: + for layer_idx, (key_states, value_states) in enumerate(past_key_values): + cache.update(key_states, value_states, layer_idx) + return cache + + dynamic_cache_cls.from_legacy_cache = from_legacy_cache + + if generation_utils is not None and not hasattr(generation_utils, "NEED_SETUP_CACHE_CLASSES_MAPPING"): + # Older remote generation code registers custom cache builders through + # this module-global dict during import. + generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING = {} + + generation_mixin_cls = getattr(generation_utils, "GenerationMixin", None) if generation_utils is not None else None + if generation_mixin_cls is not None and not getattr(generation_mixin_cls, "_gptqmodel_custom_cache_impl_patch", False): + original_prepare_cache_for_generation = generation_mixin_cls._prepare_cache_for_generation + + # transformers 5.x removed the custom cache registry path used by some + # trust_remote_code models, so recreate just enough of that setup here. + def _prepare_cache_for_generation_compat( + self, + generation_config: GenerationConfig, + model_kwargs: dict, + generation_mode, + batch_size: int, + max_cache_length: int, + ) -> None: + cache_mapping = getattr(generation_utils, "NEED_SETUP_CACHE_CLASSES_MAPPING", None) + cache_implementation = getattr(generation_config, "cache_implementation", None) + if not isinstance(cache_mapping, dict) or not isinstance(cache_implementation, str): + original_prepare_cache_for_generation( + self, + generation_config, + model_kwargs, + generation_mode, + batch_size, + max_cache_length, + ) + return None + + custom_cache_cls = cache_mapping.get(cache_implementation) + if custom_cache_cls is None: + original_prepare_cache_for_generation( + self, + generation_config, + model_kwargs, + generation_mode, + batch_size, + max_cache_length, + ) + return None + + is_linear_attn_cache = "mamba" in self.__class__.__name__.lower() + cache_name = "past_key_values" if not is_linear_attn_cache else "cache_params" + + user_defined_cache = model_kwargs.get(cache_name) + if user_defined_cache is not None: + if generation_config.cache_implementation is not None: + raise ValueError( + f"Passing both `cache_implementation` (used to initialize certain caches) and `{cache_name}` " + "(`Cache` object) is unsupported. Please use only one of the two." + ) + if isinstance(user_defined_cache, tuple): + raise ValueError( + "Passing a tuple of `past_key_values` is not supported anymore. Please use a `Cache` instance." + ) + return + + if generation_config.use_cache is False: + return + + cache_config = generation_config.cache_config + if cache_config is None: + cache_kwargs = {} + elif isinstance(cache_config, dict): + cache_kwargs = dict(cache_config) + elif hasattr(cache_config, "to_dict"): + cache_kwargs = dict(cache_config.to_dict()) + else: + cache_kwargs = dict(cache_config) + + text_config = self.config.get_text_config(decoder=True) if hasattr(self.config, "get_text_config") else self.config + full_batch_size = max(generation_config.num_beams, generation_config.num_return_sequences) * batch_size + cache_kwargs.setdefault("config", text_config) + cache_kwargs.setdefault("batch_size", full_batch_size) + cache_kwargs.setdefault("max_batch_size", full_batch_size) + cache_kwargs.setdefault("max_cache_len", max_cache_length) + + model_dtype = getattr(self, "dtype", None) or get_hf_config_dtype(self.config) + if model_dtype is not None: + cache_kwargs.setdefault("dtype", model_dtype) + + model_device = getattr(self, "device", None) + if model_device is not None: + cache_kwargs.setdefault("device", model_device) + + model_kwargs["past_key_values"] = custom_cache_cls(**cache_kwargs) + + encoder_decoder_cache_cls = getattr(cache_utils, "EncoderDecoderCache", None) if cache_utils is not None else None + if ( + getattr(self.config, "is_encoder_decoder", False) + and "past_key_values" in model_kwargs + and encoder_decoder_cache_cls is not None + and not isinstance(model_kwargs["past_key_values"], encoder_decoder_cache_cls) + and dynamic_cache_cls is not None + ): + model_kwargs["past_key_values"] = encoder_decoder_cache_cls( + model_kwargs["past_key_values"], + dynamic_cache_cls(config=text_config), + ) + + return None + + generation_mixin_cls._prepare_cache_for_generation = _prepare_cache_for_generation_compat + generation_mixin_cls._gptqmodel_custom_cache_impl_patch = True + + if not getattr(PreTrainedModel, "_gptqmodel_legacy_tied_weights_patch", False) and hasattr(PreTrainedModel, "get_expanded_tied_weights_keys"): + original_get_expanded_tied_weights_keys = PreTrainedModel.get_expanded_tied_weights_keys + + def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict: + # transformers 5.x expects `_tied_weights_keys` to be a dict, while + # older trust_remote_code models still declare `["lm_head.weight"]`. + # Handle the legacy form here so HF tied-weight expansion still works. + tied_mapping = getattr(self, "_tied_weights_keys", None) + if not isinstance(tied_mapping, (list, tuple, set)): + return original_get_expanded_tied_weights_keys(self, all_submodels=all_submodels) + + if all_submodels: + expanded_tied_weights = {} + for prefix, submodule in self.named_modules(remove_duplicate=False): + if isinstance(submodule, PreTrainedModel): + submodel_tied_weights = submodule.get_expanded_tied_weights_keys(all_submodels=False) + if prefix != "": + submodel_tied_weights = { + f"{prefix}.{k}": f"{prefix}.{v}" for k, v in submodel_tied_weights.items() + } + expanded_tied_weights.update(submodel_tied_weights) + return expanded_tied_weights + + if not getattr(getattr(self, "config", None), "tie_word_embeddings", False): + return {} + + resolved_mapping = _resolve_legacy_tied_weights_mapping(self, tied_mapping) + self._tied_weights_keys = resolved_mapping + return resolved_mapping + + PreTrainedModel.get_expanded_tied_weights_keys = get_expanded_tied_weights_keys + PreTrainedModel._gptqmodel_legacy_tied_weights_patch = True + + if not hasattr(PreTrainedModel, "is_parallelizable"): + # Older remote-code model wrappers read this legacy base-class flag + # during init, but newer transformers dropped the default attribute. + PreTrainedModel.is_parallelizable = False + + if not getattr(PreTrainedModel, "_gptqmodel_missing_all_tied_weights_patch", False): + original_getattr = PreTrainedModel.__getattr__ + + def __getattr__(self, name: str): + if name == "all_tied_weights_keys": + # Older remote-code models may skip `post_init()`, so lazily + # synthesize the tied-weight map the first time HF asks for it. + tied_keys = self.get_expanded_tied_weights_keys(all_submodels=True) + object.__setattr__(self, name, tied_keys) + return tied_keys + + return original_getattr(self, name) + + PreTrainedModel.__getattr__ = __getattr__ + PreTrainedModel._gptqmodel_missing_all_tied_weights_patch = True + + +def _normalize_chatglm_remote_code_config_compat(config: Any) -> None: + if getattr(config, "model_type", None) != "chatglm": + return + + if not hasattr(config, "seq_length") or hasattr(config, "max_length"): + return + + # Older ChatGLM remote model code still reads `config.max_length`, while + # newer transformers only preserves the serialized `seq_length` field. + config.attribute_map = dict(getattr(config, "attribute_map", {}) or {}) + config.attribute_map["max_length"] = "seq_length" + + if not hasattr(config, "use_cache"): + # transformers v5 removed `use_cache` from PretrainedConfig, + # but ChatGLM remote code still expects it. Add it back for compatibility. + config.use_cache = True + + +def _normalize_rope_parameters_config_compat(config: Any) -> None: + rope_parameters = getattr(config, "rope_parameters", None) + if ( + isinstance(rope_parameters, dict) + and rope_parameters.get("rope_type") is not None + and rope_parameters.get("rope_theta") is not None + ): + return + + convert_rope_params = getattr(config, "convert_rope_params_to_dict", None) + if callable(convert_rope_params): + try: + convert_rope_params() + except Exception as exc: + log.debug("Config: RoPE conversion fallback for %s failed: %s", type(config).__name__, exc) + else: + rope_parameters = getattr(config, "rope_parameters", None) + if ( + isinstance(rope_parameters, dict) + and rope_parameters.get("rope_type") is not None + and rope_parameters.get("rope_theta") is not None + ): + return + + legacy_rope_scaling = getattr(config, "rope_scaling", None) + rope_parameters = dict(legacy_rope_scaling) if isinstance(legacy_rope_scaling, dict) else dict(rope_parameters or {}) + + if not rope_parameters and getattr(config, "rope_theta", None) is None and getattr(config, "default_theta", None) is None: + return + + rope_parameters.setdefault("rope_type", rope_parameters.get("type", "default")) + if rope_parameters.get("rope_theta") is None: + rope_theta = getattr(config, "rope_theta", None) + if rope_theta is None: + rope_theta = getattr(config, "default_theta", 10_000.0) + rope_parameters["rope_theta"] = rope_theta + + partial_rotary_factor = getattr(config, "partial_rotary_factor", None) + if partial_rotary_factor is not None: + rope_parameters.setdefault("partial_rotary_factor", partial_rotary_factor) + + if rope_parameters["rope_type"] in {"llama3", "yarn", "longrope"}: + original_max_position_embeddings = getattr(config, "original_max_position_embeddings", None) + if original_max_position_embeddings is None: + original_max_position_embeddings = getattr(config, "max_position_embeddings", None) + if original_max_position_embeddings is not None: + rope_parameters.setdefault("original_max_position_embeddings", original_max_position_embeddings) + + config.rope_parameters = rope_parameters + + +# Restore config fields renamed by transformers 5.x before older trust_remote_code +# model files instantiate their architectures from the config object. +def _normalize_remote_code_config_compat(config: Any) -> None: + _normalize_chatglm_remote_code_config_compat(config) + model_type = getattr(config, "model_type", None) + model_type_lower = model_type.lower() if isinstance(model_type, str) else None + + if model_type_lower == "dream" or model_type == "brumby": + import transformers.modeling_rope_utils as rope_utils + # dream remote models expect "default" + if "default" not in rope_utils.ROPE_INIT_FUNCTIONS: + rope_utils.ROPE_INIT_FUNCTIONS["default"] = rope_utils.ROPE_INIT_FUNCTIONS["linear"] + + # transformers 5.x expects rope_parameters["factor"] for linear RoPE + if getattr(config, "rope_parameters", None): + config.rope_parameters.setdefault("factor", 1.0) + + # BrumbyConfig remote config may not define pad_token_id. + # Ensure the attribute exists to avoid AttributeError in transformers 5.x. + if model_type == "brumby": + rope_scaling = getattr(config, "rope_scaling", None) + + config.pad_token_id = getattr(config, "pad_token_id", None) + + # transformers 5.x normalizes RoPE config to `rope_type`, but older + # MiniCPM remote code still reads `rope_scaling["type"]` or expects `None`. + rope_scaling = getattr(config, "rope_scaling", None) + if not isinstance(rope_scaling, dict): + return + + if rope_scaling.get("rope_type") == "default" and set(rope_scaling).issubset({"rope_type", "rope_theta"}): + # transformers 5.x materializes default RoPE metadata into + # `rope_scaling`, but older remote MiniCPM code treats any dict here + # as a scaled-RoPE config and expects an explicit `factor`. + config.rope_scaling = None + return + + if "type" in rope_scaling: + return + + rope_type = rope_scaling.get("rope_type") + if rope_type is None: + return + + rope_scaling = dict(rope_scaling) + rope_scaling["type"] = rope_type + config.rope_scaling = rope_scaling + + +def deci_init_compat(config): + if config.model_type == "deci": + from transformers.models.auto import modeling_auto + with _MONKEY_PATCH_LOCK: + orig_register = modeling_auto.AutoModelForCausalLM.register + + def patched_register(cls, config_class, model_class, exist_ok=False): + # DeciLMForCausalLM inherits from LlamaForCausalLM, but does not override + # `config_class` (thus still pointing to LlamaConfig). However, the model's + # config.json declares its AutoConfig as DeciLMConfig. This leads to a mismatch + # during AutoModel registration (model_class.config_class != config_class), + # causing a ValueError. We patch this inconsistency at runtime. + if hasattr(model_class, "config_class"): + model_class.config_class = config_class + return orig_register(config_class, model_class, exist_ok=exist_ok) + + modeling_auto.AutoModelForCausalLM.register = classmethod(patched_register) + + +def normalize_hf_config_compat(config: Any, *, trust_remote_code: bool = False) -> None: + # Some transformers 5.x model classes now read `config.rope_parameters` + # directly during `from_config()`, but older local configs may only carry + # legacy RoPE fields or nothing but a default `rope_theta`. + _normalize_rope_parameters_config_compat(config) + + if not trust_remote_code: + return + + _patch_transformers_remote_code_compat() + _normalize_remote_code_config_compat(config) + # Some config classes synchronize `rope_scaling` and `rope_parameters`, so + # remote-code normalization that clears legacy default `rope_scaling` can + # also reset `rope_parameters` back to None. Re-apply the RoPE backfill + # after remote-code field cleanup so from_config() sees stable metadata. + _normalize_rope_parameters_config_compat(config) + + +def prepare_remote_code_compat(config: Any) -> None: + # Remote-code loads need both the transformers API shims and any config + # field migrations applied before instantiation happens. + normalize_hf_config_compat(config, trust_remote_code=True) + + +def prepare_remote_model_init_compat(model_id_or_path: Optional[str], config: Any) -> None: + if not model_id_or_path: + return + + deci_init_compat(config) + + auto_map = getattr(config, "auto_map", None) or {} + class_ref = auto_map.get("AutoModelForCausalLM") + if not isinstance(class_ref, str): + return + + try: + from transformers.dynamic_module_utils import get_class_from_dynamic_module + + model_cls = get_class_from_dynamic_module(class_ref, str(model_id_or_path)) + except Exception as exc: + log.debug("HF: remote model init compat pre-import failed for `%s`: %s", model_id_or_path, exc) + return + + module_root = model_cls.__module__.rsplit(".", maxsplit=1)[0] + speech_module = sys.modules.get(f"{module_root}.speech_conformer_encoder") + remote_module = sys.modules.get(model_cls.__module__) + ovis_config_module = sys.modules.get(f"{module_root}.configuration_ovis") + outer_model_cls = model_cls if isinstance(model_cls, type) else None + input_mode_enum = getattr(remote_module, "InputMode", None) if remote_module is not None else None + + with _MONKEY_PATCH_LOCK: + if config.model_type == "minicpm": + try_patch_legacy_flash_attn_flag(outer_model_cls) + base_model_cls = getattr( + remote_module, + "MiniCPMModel", + None, + ) + if base_model_cls: + try_patch_legacy_flash_attn_flag(base_model_cls) + + if config.model_type == "minicpmv" or config.model_type == "minicpmo": + vision_model_cls = getattr( + remote_module, + "SiglipVisionTransformer", + None, + ) + if vision_model_cls: + try_patch_legacy_flash_attn_flag(vision_model_cls) + + if ( + outer_model_cls is not None + and hasattr(outer_model_cls, "tie_weights") + and not getattr(outer_model_cls, "_gptqmodel_tie_weights_kwargs_patch", False) + ): + try: + tie_weights_sig = inspect.signature(outer_model_cls.tie_weights) + except (TypeError, ValueError): + tie_weights_sig = None + + if tie_weights_sig is not None: + tie_weight_params = tie_weights_sig.parameters.values() + accepts_kwargs = any(param.kind == inspect.Parameter.VAR_KEYWORD for param in tie_weight_params) + supports_missing_keys = "missing_keys" in tie_weights_sig.parameters + supports_recompute_mapping = "recompute_mapping" in tie_weights_sig.parameters + + if not accepts_kwargs and (not supports_missing_keys or not supports_recompute_mapping): + original_tie_weights = outer_model_cls.tie_weights + + # transformers 5.x passes `missing_keys=` and `recompute_mapping=` + # into tie_weights(); older remote-code models still declare + # `tie_weights(self)` and only need the original no-arg behavior. + def tie_weights_compat(self, *args, **kwargs): + return original_tie_weights(self) + + outer_model_cls.tie_weights = tie_weights_compat + outer_model_cls._gptqmodel_tie_weights_kwargs_patch = True + + if getattr(config, "model_type", None) == "ovis" and ovis_config_module is not None: + formatter_cls = getattr(ovis_config_module, "Llama3ConversationFormatter", None) + if formatter_cls is not None and not getattr(formatter_cls, "_gptqmodel_tokenizer_backend_patch", False): + support_tokenizer_types = list(getattr(formatter_cls, "support_tokenizer_types", None) or []) + if "TokenizersBackend" not in support_tokenizer_types: + # Current transformers/tokenicer fast-tokenizer backend exposes + # `TokenizersBackend` instead of the older + # `PreTrainedTokenizerFast` class name expected by Ovis remote code. + support_tokenizer_types.append("TokenizersBackend") + formatter_cls.support_tokenizer_types = support_tokenizer_types + formatter_cls._gptqmodel_tokenizer_backend_patch = True + + if getattr(config, "model_type", None) != "phi4mm": + return + + if speech_module is not None and not getattr(speech_module, "_gptqmodel_scalar_tensor_meta_patch", False): + speech_torch = getattr(speech_module, "torch", None) + original_tensor = getattr(speech_torch, "tensor", None) + if speech_torch is not None and original_tensor is not None: + def _is_phi4mm_subsampling_scalar_init() -> bool: + for frame_info in inspect.stack(context=0): + if frame_info.filename.endswith("speech_conformer_encoder.py") and frame_info.lineno == 1426: + return True + return False + + # Phi-4 MM remote audio init creates scalar tensors only to derive Python + # output sizes in NemoConvSubsampling; forcing just that scalar tensor onto + # CPU keeps meta init safe without perturbing other meta-only buffers. + def tensor_compat(data, *args, **kwargs): + current_device = getattr(torch.utils._device, "CURRENT_DEVICE", None) + if ( + kwargs.get("device") is None + and current_device == torch.device("meta") + and isinstance(data, (int, float, bool)) + and _is_phi4mm_subsampling_scalar_init() + ): + kwargs = dict(kwargs) + kwargs["device"] = "cpu" + return original_tensor(data, *args, **kwargs) + + speech_torch.tensor = tensor_compat + + positional_encoding_cls = getattr(speech_module, "AbsolutePositionalEncoding", None) + if positional_encoding_cls is not None and not getattr(positional_encoding_cls, "_gptqmodel_meta_extend_patch", False): + original_extend_pe = positional_encoding_cls.extend_pe + + def _is_phi4mm_positional_seed_call() -> bool: + for frame_info in inspect.stack(context=0): + if frame_info.filename.endswith("speech_conformer_encoder.py") and frame_info.lineno == 895: + return True + return False + + # The remote implementation seeds extend_pe() with a CPU scalar tensor. + # Under meta init, promote that seed tensor back to meta before the + # original method allocates its positional buffer. + def extend_pe_compat(self, x): + if isinstance(x, torch.Tensor) and x.device.type != "meta" and _is_phi4mm_positional_seed_call(): + x = x.to(device="meta") + return original_extend_pe(self, x) + + positional_encoding_cls.extend_pe = extend_pe_compat + positional_encoding_cls._gptqmodel_meta_extend_patch = True + + speech_module._gptqmodel_scalar_tensor_meta_patch = True + + if ( + outer_model_cls is not None + and hasattr(outer_model_cls, "forward") + and not getattr(outer_model_cls, "_gptqmodel_input_mode_patch", False) + ): + original_forward = outer_model_cls.forward + + # Text-only callers like lm_eval do not pass `input_mode`; infer the + # correct Phi-4 MM mode from the provided modality tensors instead. + def forward_compat(self, *args, **kwargs): + if kwargs.get("input_mode") is None: + kwargs = dict(kwargs) + has_vision = any( + kwargs.get(name) is not None + for name in ("input_image_embeds", "image_sizes", "image_attention_mask") + ) + has_audio = any( + kwargs.get(name) is not None + for name in ("input_audio_embeds", "audio_embed_sizes", "audio_attention_mask") + ) + + if has_vision and has_audio: + kwargs["input_mode"] = input_mode_enum.VISION_SPEECH if input_mode_enum is not None else 3 + elif has_vision: + kwargs["input_mode"] = input_mode_enum.VISION if input_mode_enum is not None else 1 + elif has_audio: + kwargs["input_mode"] = input_mode_enum.SPEECH if input_mode_enum is not None else 2 + else: + kwargs["input_mode"] = input_mode_enum.LANGUAGE if input_mode_enum is not None else 0 + return original_forward(self, *args, **kwargs) + + outer_model_cls.forward = forward_compat + outer_model_cls._gptqmodel_input_mode_patch = True + + inner_model_cls = getattr(remote_module, "Phi4MMModel", None) if remote_module is not None else None + if inner_model_cls is not None and not hasattr(inner_model_cls, "prepare_inputs_for_generation"): + # PEFT expects the inner model it wraps to expose this hook, even + # though Phi-4 MM only defines the full implementation on the outer + # CausalLM class. + def prepare_inputs_for_generation(self, input_ids=None, past_key_values=None, inputs_embeds=None, **kwargs): + model_inputs = dict(kwargs) + if inputs_embeds is not None and past_key_values is None: + model_inputs["inputs_embeds"] = inputs_embeds + else: + model_inputs["input_ids"] = input_ids + model_inputs["past_key_values"] = past_key_values + return model_inputs + + inner_model_cls.prepare_inputs_for_generation = prepare_inputs_for_generation + + try: + import importlib.util + + import peft.import_utils as peft_import_utils + import peft.tuners.lora.awq as peft_awq + except Exception: + pass + else: + if not getattr(peft_awq, "_gptqmodel_awq_probe_patch", False): + # PEFT later imports `awq.modules.linear`, so the availability + # probe must require that concrete submodule instead of top-level + # namespace packages that are missing the actual runtime. + @lru_cache(maxsize=None) + def is_auto_awq_available() -> bool: + try: + return importlib.util.find_spec("awq.modules.linear") is not None + except ModuleNotFoundError: + return False + + peft_import_utils.is_auto_awq_available = is_auto_awq_available + peft_awq.is_auto_awq_available = is_auto_awq_available + peft_awq._gptqmodel_awq_probe_patch = True + + +def try_patch_legacy_flash_attn_flag(model_cls): + with _MONKEY_PATCH_LOCK: + if ( + model_cls is not None + and getattr(model_cls, "_supports_flash_attn_2", None) is not None + and not bool(getattr(model_cls, "_supports_flash_attn", False)) + ): + # transformers 5.x checks `_supports_flash_attn`, while older + # trust_remote_code classes such as MiniCPM still expose only the + # legacy `_supports_flash_attn_2` capability flag. + model_cls._supports_flash_attn = bool(getattr(model_cls, "_supports_flash_attn_2")) + + +def load_tokenizer(tokenizer_or_path, *, model_config: Any = None, **kwargs): + from tokenicer import Tokenicer + + warnings.warn( + "gptqmodel.utils.hf.load_tokenizer() is deprecated; use Tokenicer.load(..., model_config=...) instead.", + DeprecationWarning, + stacklevel=2, + ) + return Tokenicer.load(tokenizer_or_path, model_config=model_config, **kwargs) + + + +_patch_transformers_remote_code_compat() + + +def _nested_text_config(model_config: Any) -> Optional[Any]: + if model_config is None: + return None + + get_text_config = getattr(model_config, "get_text_config", None) + if callable(get_text_config): + try: + text_config = get_text_config() + except Exception: + text_config = None + if text_config is not None and text_config is not model_config: + return text_config + + text_config = getattr(model_config, "text_config", None) + if text_config is not None and text_config is not model_config: + return text_config + + thinker_config = getattr(model_config, "thinker_config", None) + thinker_text_config = getattr(thinker_config, "text_config", None) + if thinker_text_config is not None and thinker_text_config is not model_config: + return thinker_text_config + + return None + + +def ensure_hf_model_config_token_ids(model_config: Any, tokenizer: Optional[Any] = None) -> bool: + changed = False + text_config = _nested_text_config(model_config) + + for field in ("bos_token_id", "eos_token_id", "pad_token_id"): + if not hasattr(model_config, field): + setattr(model_config, field, None) + changed = True + + if getattr(model_config, field, None) is not None: + continue + + value = getattr(text_config, field, None) if text_config is not None else None + if value is None and tokenizer is not None: + value = getattr(tokenizer, field, None) + + if value is not None: + setattr(model_config, field, value) + changed = True + + return changed + + +def load_tokenizer_with_model_config( + tokenizer: PreTrainedTokenizerBase, + model_config: Any, + *, + strict: bool = False, + pad_tokens: Optional[List[Union[str, int]]] = None, +): + ensure_hf_model_config_token_ids(model_config, tokenizer=tokenizer) + + tokenizer_cls = type(tokenizer) + tokenicer_cls_wrapper = type(f"{tokenizer_cls.__name__}", (Tokenicer, tokenizer_cls), {}) + + wrapped = tokenicer_cls_wrapper() + wrapped.tokenizer = tokenizer + wrapped.model_config = model_config + wrapped.auto_fix_pad_token(strict=strict, pad_tokens=pad_tokens) + return wrapped def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool: changed = False @@ -55,6 +1325,8 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]: # TODO FIXME! Pre-quantized use AutoModelForCausalLM.from_pretrained() but post-quantized use AutoModelForCausalLM.from_config() def autofix_hf_model_config(model: PreTrainedModel, path: str = None): + ensure_hf_model_config_token_ids(getattr(model, "config", None)) + if model.can_generate(): # sync config first if path: @@ -112,6 +1384,13 @@ def autofix_hf_generation_config(cfg: GenerationConfig): log.info("Model: Auto-Fixed `generation_config` by setting `do_sample=True`.") +def sanitize_model_config(config): + if config.model_type == "chatglm" and hasattr(config, "max_length"): + # max_length can only be stored in generation_config. + # see _normalize_chatglm_remote_code_config_compat() + del config.attribute_map["max_length"] + + def sanitize_generation_config_file(path: str) -> bool: try: with open(path, "r", encoding="utf-8") as fp: @@ -135,7 +1414,6 @@ def sanitize_generation_config_file(path: str) -> bool: def build_shell_model( loader, config: Any, - dtype: Optional[torch.dtype] = None, trust_remote_code: bool = True, **model_init_kwargs, ): @@ -145,7 +1423,6 @@ def build_shell_model( Args: model_id_or_path: Hugging Face model ID or local path. - dtype: Target dtype for model parameters (replaces `torch_dtype`). trust_remote_code: Allow loading custom model classes. """ init_kwargs = model_init_kwargs.copy() @@ -154,17 +1431,21 @@ def build_shell_model( del init_kwargs["_fast_init"] # All nn.Parameters and buffers are created + normalize_hf_config_compat(config, trust_remote_code=trust_remote_code) + # All nn.Parameters and buffers are created on 'meta' and initializers are skipped. pb = log.spinner(title="Model loading...", interval=0.1) try: with init_empty_weights(include_buffers=True): shell = loader.from_config( config, - dtype=dtype, trust_remote_code=trust_remote_code, **init_kwargs ) finally: pb.close() + if trust_remote_code and isinstance(shell, PreTrainedModel): + _normalize_legacy_tied_weights_keys(shell) + return shell diff --git a/gptqmodel/utils/hub.py b/gptqmodel/utils/hub.py new file mode 100644 index 000000000..37dab1c84 --- /dev/null +++ b/gptqmodel/utils/hub.py @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: 2024-2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from transformers.utils import hub as transformers_hub +from transformers.utils import logging as transformers_logging + + +cached_file = transformers_hub.cached_file +create_repo = transformers_hub.create_repo +has_file = transformers_hub.has_file +hf_hub_download = transformers_hub.hf_hub_download +list_repo_tree = transformers_hub.list_repo_tree +snapshot_download = transformers_hub.snapshot_download + +disable_progress_bar = transformers_logging.disable_progress_bar + +# Reuse the hub client instance that transformers already exposes so GPT-QModel +# does not need to import huggingface_hub directly. +_HF_API = list_repo_tree.__self__ + + +def list_repo_files(*args, **kwargs): + return _HF_API.list_repo_files(*args, **kwargs) + + +def model_info(*args, **kwargs): + return _HF_API.model_info(*args, **kwargs) + + +def repo_info(*args, **kwargs): + return _HF_API.repo_info(*args, **kwargs) diff --git a/gptqmodel/utils/image.py b/gptqmodel/utils/image.py index 615231871..03240fd78 100644 --- a/gptqmodel/utils/image.py +++ b/gptqmodel/utils/image.py @@ -28,7 +28,13 @@ def extract_vision_info(conversations: list[dict] | list[list[dict]]) -> list[di return vision_infos -def fetch_image(ele: dict[str, str | Image.Image]) -> Image.Image: +def fetch_image( + ele: dict[str, str | Image.Image], + image_patch_size: int | None = None, +) -> Image.Image: + # Some multimodal model adapters forward image-related kwargs here even + # though local image loading does not need them. + del image_patch_size if "image" in ele: image = ele["image"] else: diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index a35f25a09..4553532f9 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -16,9 +16,11 @@ from ..models._const import DEVICE, normalize_device from ..nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear from ..quantization import FORMAT, METHOD +from ..quantization.config import _normalize_quant_bits, quant_bits_width from ..utils.env import env_flag from ..utils.logger import setup_logger from . import BACKEND +from .backend import normalize_backend from .rocm import IS_ROCM from .torch import HAS_CUDA, HAS_MPS, HAS_XPU @@ -31,6 +33,15 @@ message_logged = False log = setup_logger() + +def _supports_pack_api(cls: Type[BaseQuantLinear]) -> bool: + return ( + issubclass(cls, PackableQuantLinear) + or (hasattr(cls, "pack") and callable(getattr(cls, "pack"))) + or (hasattr(cls, "pack_block") and callable(getattr(cls, "pack_block"))) + ) + + def iter_quant_linear_kernels() -> List[Type[BaseQuantLinear]]: kernels = [] seen = set() @@ -66,6 +77,7 @@ def get_kernel_backends(cls: Type[BaseQuantLinear]) -> List[BACKEND]: def get_kernel_for_backend(backend: BACKEND, quant_method: METHOD, fmt: FORMAT) -> Type[BaseQuantLinear]: + backend = normalize_backend(backend, quant_method=quant_method) matches = [] for cls in iter_quant_linear_kernels(): if backend not in get_kernel_backends(cls): @@ -95,7 +107,7 @@ def _import_all_qlinear_kernels() -> None: continue try: importlib.import_module(f"{qlinear_pkg.__name__}.{name}") - except ImportError as exc: + except (ImportError, OSError) as exc: log.debug(f"Skipping qlinear module import `{name}`: {exc}") @@ -277,6 +289,9 @@ def auto_select_device(device: Optional[DEVICE], backend: Optional[BACKEND]) -> assert backend is None or isinstance(backend, BACKEND) if device is None: + # Backend-specific kernels should default to a compatible device class. + if backend in (BACKEND.GPTQ_TORCH_FUSED, BACKEND.AWQ_TORCH_FUSED, BACKEND.TORCH_FUSED, BACKEND.TORCH_FUSED_AWQ): + return DEVICE.XPU if HAS_XPU else DEVICE.CPU if HAS_CUDA: device = DEVICE.CUDA elif HAS_XPU: @@ -287,6 +302,7 @@ def auto_select_device(device: Optional[DEVICE], backend: Optional[BACKEND]) -> device = DEVICE.CPU return device + # public/stable api exposed to transformer/optimum def hf_select_quant_linear( bits: int, @@ -300,8 +316,7 @@ def hf_select_quant_linear( backend: Optional[Union[str, BACKEND]] = None, ) -> Type[BaseQuantLinear]: # convert hf string backend to backend.enum - if isinstance(backend, str): - backend = BACKEND(backend.lower()) + backend = normalize_backend(backend, quant_method=METHOD.GPTQ) if device_map is not None: device = hf_normalize_device_device_map(None, device_map) @@ -340,8 +355,7 @@ def hf_select_quant_linear_v2( backend: Optional[Union[str, BACKEND]] = None, ) -> Type[BaseQuantLinear]: # convert hf string backend to backend.enum - if isinstance(backend, str): - backend = BACKEND(backend.lower()) + backend = normalize_backend(backend, quant_method=quant_method) def _normalize_enum(value, enum_cls, field: str): if isinstance(value, enum_cls): @@ -409,7 +423,7 @@ def _normalize_dtype(value: Optional[Union[str, torch.dtype]], field: str) -> Op # auto select the correct/optimal QuantLinear class def select_quant_linear( - bits: int, + bits, group_size: int, desc_act: bool, sym: bool, @@ -429,6 +443,9 @@ def select_quant_linear( format = FORMAT(format.lower()) if isinstance(quant_method, str): quant_method = METHOD(quant_method.lower()) + backend = normalize_backend(backend, quant_method=quant_method) + + bits = quant_bits_width(_normalize_quant_bits(bits, format_value=format)) supported_formats = BACKEND_TO_METHOD_FORMAT_MAPPING.get(quant_method) if supported_formats is None: @@ -467,10 +484,7 @@ def select_quant_linear( log.info(f"skip {k} for {str(err)}") if validate: if pack: - check_pack_func = issubclass(cls, PackableQuantLinear) or ( - hasattr(cls, "pack_block") and callable(getattr(cls, "pack_block")) - ) - if check_pack_func: + if _supports_pack_api(cls): #if not message_logged: # logger.info(f"Auto pick kernel based on compatibility: {cls}") # message_logged = True @@ -517,8 +531,13 @@ def select_quant_linear( log.info(f"{'Packing ' if pack else ''}Kernel: selected: `{qlinear.__name__}`") if not validate: raise ValueError(err) - else: - if multi_select: - return [qlinear] - else: - return qlinear + + if pack: + if not _supports_pack_api(qlinear): + raise ValueError( + f"Selected backend `{backend}` with kernel `{qlinear.__name__}` cannot pack quantized weights for format `{format}`." + ) + + if multi_select: + return [qlinear] + return qlinear diff --git a/gptqmodel/utils/internal_gguf.py b/gptqmodel/utils/internal_gguf.py new file mode 100644 index 000000000..ade647999 --- /dev/null +++ b/gptqmodel/utils/internal_gguf.py @@ -0,0 +1,715 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import sys +from collections import OrderedDict +from enum import IntEnum +from typing import Any, Literal, NamedTuple, TypeVar, Union + +import numpy as np +import numpy.typing as npt +import pcre +import torch + +from ..nn_modules.qlinear.gguf import ( + _dequantize_gguf_tensor_numpy, + _dequantize_sign_only_torch, + _quantize_gguf_tensor_numpy, +) + + +__version__ = "0.10.0" +_INTERNAL_GGUF_DEQUANT_DEVICE_ENV = "GPTQMODEL_INTERNAL_GGUF_DEQUANT_DEVICE" +_INTERNAL_GGUF_DEQUANT_MAX_BYTES_ENV = "GPTQMODEL_INTERNAL_GGUF_DEQUANT_MAX_BYTES" +_INTERNAL_GGUF_DEQUANT_DEFAULT_MAX_BYTES = 256 * 1024 * 1024 +_INTERNAL_GGUF_QUANTIZED_LOADER_ENV = "GPTQMODEL_INTERNAL_GGUF_QUANTIZED_LOADER" + +GGUF_MAGIC = 0x46554747 +GGUF_VERSION = 3 +GGUF_DEFAULT_ALIGNMENT = 32 +QK_K = 256 +READER_SUPPORTED_VERSIONS = (2, GGUF_VERSION) + + +class GGMLQuantizationType(IntEnum): + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + IQ2_XXS = 16 + IQ2_XS = 17 + IQ3_XXS = 18 + IQ1_S = 19 + IQ4_NL = 20 + IQ3_S = 21 + IQ2_S = 22 + IQ4_XS = 23 + I8 = 24 + I16 = 25 + I32 = 26 + I64 = 27 + F64 = 28 + IQ1_M = 29 + BF16 = 30 + TQ1_0 = 34 + TQ2_0 = 35 + MXFP4 = 39 + Q1_0 = 40 + Q1_0_g128 = 41 + + +class GGUFEndian(IntEnum): + LITTLE = 0 + BIG = 1 + + +class GGUFValueType(IntEnum): + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 + FLOAT32 = 6 + BOOL = 7 + STRING = 8 + ARRAY = 9 + UINT64 = 10 + INT64 = 11 + FLOAT64 = 12 + + +GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { + GGMLQuantizationType.F32: (1, 4), + GGMLQuantizationType.F16: (1, 2), + GGMLQuantizationType.Q4_0: (32, 2 + 16), + GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), + GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), + GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), + GGMLQuantizationType.Q8_0: (32, 2 + 32), + GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), + GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), + GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), + GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), + GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), + GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), + GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), + GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), + GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), + GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), + GGMLQuantizationType.IQ4_NL: (32, 2 + 16), + GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), + GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), + GGMLQuantizationType.I8: (1, 1), + GGMLQuantizationType.I16: (1, 2), + GGMLQuantizationType.I32: (1, 4), + GGMLQuantizationType.I64: (1, 8), + GGMLQuantizationType.F64: (1, 8), + GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), + GGMLQuantizationType.BF16: (1, 2), + GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), + GGMLQuantizationType.TQ2_0: (256, 2 + 64), + GGMLQuantizationType.MXFP4: (32, 1 + 16), + GGMLQuantizationType.Q1_0: (32, 2 + 4), + GGMLQuantizationType.Q1_0_g128: (128, 2 + 16), +} +_TORCH_SIGN_ONLY_QTYPES: dict[GGMLQuantizationType, tuple[int, int]] = { + GGMLQuantizationType.Q1_0: GGML_QUANT_SIZES[GGMLQuantizationType.Q1_0], + GGMLQuantizationType.Q1_0_g128: GGML_QUANT_SIZES[GGMLQuantizationType.Q1_0_g128], +} + +MODEL_ARCH_QWEN3 = "qwen3" +MODEL_ARCH_NAMES = { + MODEL_ARCH_QWEN3: "qwen3", +} + +_GGUF_SCALAR_TO_NP: dict[GGUFValueType, type[np.generic]] = { + GGUFValueType.UINT8: np.uint8, + GGUFValueType.INT8: np.int8, + GGUFValueType.UINT16: np.uint16, + GGUFValueType.INT16: np.int16, + GGUFValueType.UINT32: np.uint32, + GGUFValueType.INT32: np.int32, + GGUFValueType.FLOAT32: np.float32, + GGUFValueType.UINT64: np.uint64, + GGUFValueType.INT64: np.int64, + GGUFValueType.FLOAT64: np.float64, + GGUFValueType.BOOL: np.bool_, +} + +_QWEN3_DIRECT_NAME_MAP = { + "model.embed_tokens": "token_embd", + "model.norm": "output_norm", +} +_QWEN3_BLOCK_PATTERNS: tuple[tuple[pcre.Pattern, str], ...] = ( + (pcre.compile(r"^model\.layers\.(\d+)\.self_attn\.q_proj$"), "blk.{bid}.attn_q"), + (pcre.compile(r"^model\.layers\.(\d+)\.self_attn\.k_proj$"), "blk.{bid}.attn_k"), + (pcre.compile(r"^model\.layers\.(\d+)\.self_attn\.v_proj$"), "blk.{bid}.attn_v"), + (pcre.compile(r"^model\.layers\.(\d+)\.self_attn\.o_proj$"), "blk.{bid}.attn_output"), + (pcre.compile(r"^model\.layers\.(\d+)\.self_attn\.q_norm$"), "blk.{bid}.attn_q_norm"), + (pcre.compile(r"^model\.layers\.(\d+)\.self_attn\.k_norm$"), "blk.{bid}.attn_k_norm"), + (pcre.compile(r"^model\.layers\.(\d+)\.mlp\.gate_proj$"), "blk.{bid}.ffn_gate"), + (pcre.compile(r"^model\.layers\.(\d+)\.mlp\.up_proj$"), "blk.{bid}.ffn_up"), + (pcre.compile(r"^model\.layers\.(\d+)\.mlp\.down_proj$"), "blk.{bid}.ffn_down"), + (pcre.compile(r"^model\.layers\.(\d+)\.input_layernorm$"), "blk.{bid}.attn_norm"), + (pcre.compile(r"^model\.layers\.(\d+)\.post_attention_layernorm$"), "blk.{bid}.ffn_norm"), +) +_QWEN3_LINEAR_TENSOR_RE = pcre.compile( + r"^blk\.\d+\.(attn_q|attn_k|attn_v|attn_output|ffn_gate|ffn_up|ffn_down)\.weight$" +) +_GGUF_BITS_ALIAS_BY_QTYPE: dict[GGMLQuantizationType, str] = { + GGMLQuantizationType.Q1_0: "q1_0", + GGMLQuantizationType.Q1_0_g128: "q1_0_g128", + GGMLQuantizationType.Q4_0: "q4_0", + GGMLQuantizationType.Q8_0: "q8_0", + GGMLQuantizationType.Q4_K: "q4_k", + GGMLQuantizationType.Q5_K: "q5_k", + GGMLQuantizationType.Q6_K: "q6_k", +} + + +class GGUFQuantizedCheckpointSpec(NamedTuple): + model_type: str + bits_alias: str + tensor_qtype: GGMLQuantizationType + lm_head_quantized: bool + + +def quant_shape_to_byte_shape(shape: tuple[int, ...] | list[int], quant_type: GGMLQuantizationType) -> tuple[int, ...]: + block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType(int(quant_type))] + if shape[-1] % block_size != 0: + raise ValueError( + f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})" + ) + return (*shape[:-1], shape[-1] // block_size * type_size) + + +def quantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + return _quantize_gguf_tensor_numpy(np.asarray(data), GGMLQuantizationType(int(qtype))) + + +def native_quantized_loader_enabled() -> bool: + raw = os.getenv(_INTERNAL_GGUF_QUANTIZED_LOADER_ENV) + if raw is None: + return True + return str(raw).strip().lower() not in {"", "0", "false", "off", "no"} + + +def _reader_field_value(reader: "GGUFReader", key: str): + field = reader.get_field(key) + if field is None: + return None + return field.contents() + + +def inspect_quantized_checkpoint( + source: "GGUFReader | os.PathLike[str] | str", +) -> GGUFQuantizedCheckpointSpec | None: + if hasattr(source, "tensors") and hasattr(source, "get_field"): + reader = source + else: + reader = GGUFReader(source) + architecture = _reader_field_value(reader, "general.architecture") + if architecture != MODEL_ARCH_QWEN3: + return None + + linear_qtypes = { + GGMLQuantizationType(int(tensor.tensor_type)) + for tensor in reader.tensors + if _QWEN3_LINEAR_TENSOR_RE.fullmatch(tensor.name) + } + if len(linear_qtypes) != 1: + return None + + tensor_qtype = next(iter(linear_qtypes)) + bits_alias = _GGUF_BITS_ALIAS_BY_QTYPE.get(tensor_qtype) + if bits_alias is None: + return None + + lm_head_quantized = any( + tensor.name == "output.weight" and GGMLQuantizationType(int(tensor.tensor_type)) == tensor_qtype + for tensor in reader.tensors + ) + return GGUFQuantizedCheckpointSpec( + model_type=MODEL_ARCH_QWEN3, + bits_alias=bits_alias, + tensor_qtype=tensor_qtype, + lm_head_quantized=lm_head_quantized, + ) + + +def _resolve_torch_dequant_device() -> torch.device | None: + raw = os.getenv(_INTERNAL_GGUF_DEQUANT_DEVICE_ENV) + if raw is None or raw.strip() == "": + return None + + try: + device = torch.device(raw.strip()) + except Exception: + return None + + if device.type == "cuda": + if not torch.cuda.is_available(): + return None + if device.index is not None and device.index >= torch.cuda.device_count(): + return None + return device + + if device.type == "cpu": + return device + + return None + + +def _resolve_torch_dequant_chunk_rows( + *, + packed_row_bytes: int, + block_size: int, + type_size: int, +) -> int: + output_row_bytes = packed_row_bytes // type_size * block_size * np.dtype(np.float32).itemsize + if output_row_bytes <= 0: + return 1 + + raw_limit = os.getenv(_INTERNAL_GGUF_DEQUANT_MAX_BYTES_ENV) + try: + byte_limit = int(raw_limit) if raw_limit is not None else _INTERNAL_GGUF_DEQUANT_DEFAULT_MAX_BYTES + except ValueError: + byte_limit = _INTERNAL_GGUF_DEQUANT_DEFAULT_MAX_BYTES + + return max(1, byte_limit // output_row_bytes) + + +def _dequantize_sign_only_torch_to_numpy( + data: np.ndarray, + *, + block_size: int, + type_size: int, + device: torch.device, +) -> np.ndarray: + rows = np.asarray(data, dtype=np.uint8) + if rows.shape[-1] % type_size != 0: + raise ValueError( + f"GGUF sign-only row byte width must be divisible by {type_size}, got " + f"{rows.shape[-1]} for shape {rows.shape}." + ) + + packed_cols = rows.shape[-1] + output_cols = packed_cols // type_size * block_size + flat_rows = rows.reshape(-1, packed_cols) + flat_output = np.empty((flat_rows.shape[0], output_cols), dtype=np.float32) + chunk_rows = _resolve_torch_dequant_chunk_rows( + packed_row_bytes=packed_cols, + block_size=block_size, + type_size=type_size, + ) + + for start in range(0, flat_rows.shape[0], chunk_rows): + end = min(start + chunk_rows, flat_rows.shape[0]) + chunk = _dequantize_sign_only_torch( + flat_rows[start:end], + block_size=block_size, + type_size=type_size, + device=device, + dtype=torch.float32, + ) + if device.type == "cuda": + torch.cuda.synchronize(device=device) + flat_output[start:end] = chunk.cpu().numpy() + + return flat_output.reshape(*rows.shape[:-1], output_cols) + + +def dequantize_to_torch( + data: np.ndarray, + qtype: GGMLQuantizationType, + *, + device: torch.device | str | None = None, + dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + resolved_qtype = GGMLQuantizationType(int(qtype)) + target_device = torch.device("cpu") if device is None else torch.device(device) + + sign_only_info = _TORCH_SIGN_ONLY_QTYPES.get(resolved_qtype) + if sign_only_info is not None: + block_size, type_size = sign_only_info + return _dequantize_sign_only_torch( + np.asarray(data, dtype=np.uint8), + block_size=block_size, + type_size=type_size, + device=target_device, + dtype=dtype, + ).contiguous() + + if resolved_qtype == GGMLQuantizationType.F32: + tensor = torch.from_numpy(np.array(data, dtype=np.float32, copy=True, order="C")) + elif resolved_qtype == GGMLQuantizationType.F16: + tensor = torch.from_numpy(np.array(data, dtype=np.float16, copy=True, order="C")).to(torch.float32) + elif resolved_qtype == GGMLQuantizationType.BF16: + rows = np.asarray(data, dtype=np.uint16).astype(np.uint32) + tensor = torch.from_numpy(np.left_shift(rows, np.uint32(16)).view(np.float32).copy()) + else: + tensor = torch.from_numpy(np.ascontiguousarray(_dequantize_gguf_tensor_numpy(np.asarray(data), resolved_qtype))) + + if tensor.device != target_device or tensor.dtype != dtype: + tensor = tensor.to(device=target_device, dtype=dtype) + return tensor.contiguous() + + +def dequantize(data: np.ndarray, qtype: GGMLQuantizationType) -> np.ndarray: + resolved_qtype = GGMLQuantizationType(int(qtype)) + device = _resolve_torch_dequant_device() + sign_only_info = _TORCH_SIGN_ONLY_QTYPES.get(resolved_qtype) + if device is not None and sign_only_info is not None: + block_size, type_size = sign_only_info + return _dequantize_sign_only_torch_to_numpy( + np.asarray(data, dtype=np.uint8), + block_size=block_size, + type_size=type_size, + device=device, + ) + + return _dequantize_gguf_tensor_numpy(np.asarray(data), resolved_qtype) + + +class _MinimalTensorNameMap: + def __init__(self, arch: str, n_blocks: int): + self.arch = arch + self.n_blocks = n_blocks + + def get_name(self, hf_name: str): + if self.arch != MODEL_ARCH_QWEN3: + return None + + direct = _QWEN3_DIRECT_NAME_MAP.get(hf_name) + if direct is not None: + return direct + + for pattern, template in _QWEN3_BLOCK_PATTERNS: + match = pattern.fullmatch(hf_name) + if match is None: + continue + block_id = int(match.group(1)) + if block_id >= self.n_blocks: + return None + return template.format(bid=block_id) + + return None + + +def get_tensor_name_map(arch, n_blocks: int): + return _MinimalTensorNameMap(str(arch), int(n_blocks)) + + +class ReaderField(NamedTuple): + offset: int + name: str + parts: list[npt.NDArray[Any]] + data: list[int] + types: list[GGUFValueType] + + def contents(self, index_or_slice: int | slice = slice(None)) -> Any: + if not self.types: + return None + + def _to_string(part: npt.NDArray[Any]) -> str: + return part.tobytes().decode("utf-8") + + main_type = self.types[0] + if main_type == GGUFValueType.ARRAY: + sub_type = self.types[-1] + indices = self.data[index_or_slice] + if isinstance(index_or_slice, int): + index_items = [indices] + else: + index_items = list(indices) + + if sub_type == GGUFValueType.STRING: + values = [_to_string(self.parts[idx]) for idx in index_items] + else: + values = [self.parts[idx].tolist()[0] for idx in index_items] + return values[0] if isinstance(index_or_slice, int) else values + + if main_type == GGUFValueType.STRING: + return _to_string(self.parts[-1]) + return self.parts[-1].tolist()[0] + + +class ReaderTensor(NamedTuple): + name: str + tensor_type: GGMLQuantizationType + shape: npt.NDArray[np.uint32] + n_elements: int + n_bytes: int + data_offset: int + data: npt.NDArray[Any] + field: ReaderField + + +class GGUFReader: + byte_order: Literal["I", "S"] = "I" + alignment: int = GGUF_DEFAULT_ALIGNMENT + data_offset: int + + _DT = TypeVar("_DT", bound=npt.DTypeLike) + + def __init__(self, path: os.PathLike[str] | str, mode: Literal["r", "r+", "c"] = "r"): + self.data = np.memmap(path, mode=mode) + self.fields: OrderedDict[str, ReaderField] = OrderedDict() + self.tensors: list[ReaderTensor] = [] + + offset = 0 + if self._get(offset, np.uint32, override_order="<")[0] != GGUF_MAGIC: + raise ValueError("GGUF magic invalid") + offset += 4 + + version_array = self._get(offset, np.uint32) + if version_array[0] & 0xFFFF == 0: + self.byte_order = "S" + version_array = version_array.view(version_array.dtype.newbyteorder(self.byte_order)) + version = int(version_array[0]) + if version not in READER_SUPPORTED_VERSIONS: + raise ValueError(f"Unsupported GGUF version {version}") + self.endianess = GGUFEndian.BIG if self.byte_order == "S" else GGUFEndian.LITTLE + offset += self._push_field(ReaderField(offset, "GGUF.version", [version_array], [0], [GGUFValueType.UINT32])) + + counts = self._get(offset, np.uint64, 2) + offset += self._push_field(ReaderField(offset, "GGUF.tensor_count", [counts[:1]], [0], [GGUFValueType.UINT64])) + offset += self._push_field(ReaderField(offset, "GGUF.kv_count", [counts[1:]], [0], [GGUFValueType.UINT64])) + tensor_count, kv_count = (int(counts[0]), int(counts[1])) + + offset = self._read_metadata_fields(offset, kv_count) + offset, tensor_fields = self._read_tensor_fields(offset, tensor_count) + + alignment_field = self.fields.get("general.alignment") + if alignment_field is not None: + if alignment_field.types != [GGUFValueType.UINT32]: + raise ValueError("Bad type for general.alignment field") + self.alignment = int(alignment_field.parts[-1][0]) + if self.alignment == 0 or (self.alignment & (self.alignment - 1)) != 0: + raise ValueError("Invalid alignment: expected a non-zero power of two") + + padding = offset % self.alignment + if padding: + offset += self.alignment - padding + self.data_offset = offset + self._read_tensors(offset, tensor_fields) + + def get_field(self, key: str) -> Union[ReaderField, None]: + return self.fields.get(key) + + def get_tensor(self, idx: int) -> ReaderTensor: + return self.tensors[idx] + + def _get( + self, + offset: int, + dtype: npt.DTypeLike, + count: int = 1, + override_order: None | Literal["I", "S", "<"] = None, + ) -> npt.NDArray[Any]: + count = int(count) + itemsize = int(np.empty([], dtype=dtype).itemsize) + end_offset = offset + itemsize * count + array = self.data[offset:end_offset].view(dtype=dtype)[:count] + return array.view(array.dtype.newbyteorder(self.byte_order if override_order is None else override_order)) + + def _push_field(self, field: ReaderField, *, include_size: bool = True) -> int: + if field.name in self.fields: + raise KeyError(f"Duplicate GGUF field `{field.name}` at offset {field.offset}") + self.fields[field.name] = field + return sum(int(part.nbytes) for part in field.parts) if include_size else 0 + + def _read_string(self, offset: int) -> tuple[npt.NDArray[np.uint64], npt.NDArray[np.uint8]]: + length = self._get(offset, np.uint64) + return length, self._get(offset + 8, np.uint8, length[0]) + + def _parse_value( + self, + offset: int, + raw_type: int, + ) -> tuple[int, list[npt.NDArray[Any]], list[int], list[GGUFValueType]]: + value_type = GGUFValueType(raw_type) + if value_type == GGUFValueType.STRING: + parts = list(self._read_string(offset)) + return sum(int(part.nbytes) for part in parts), parts, [1], [value_type] + + scalar_dtype = _GGUF_SCALAR_TO_NP.get(value_type) + if scalar_dtype is not None: + value = self._get(offset, scalar_dtype) + return int(value.nbytes), [value], [0], [value_type] + + if value_type == GGUFValueType.ARRAY: + inner_type = self._get(offset, np.uint32) + inner_count = self._get(offset + int(inner_type.nbytes), np.uint64) + parts: list[npt.NDArray[Any]] = [inner_type, inner_count] + data_indices: list[int] = [] + types = [value_type] + cursor = offset + int(inner_type.nbytes + inner_count.nbytes) + + for index in range(int(inner_count[0])): + size, child_parts, child_data_indices, child_types = self._parse_value(cursor, int(inner_type[0])) + if index == 0: + types.extend(child_types) + base_index = len(parts) + parts.extend(child_parts) + data_indices.extend(base_index + child_index for child_index in child_data_indices) + cursor += size + + return cursor - offset, parts, data_indices, types + + raise ValueError(f"Unsupported GGUF field type {value_type}") + + def _read_metadata_fields(self, offset: int, count: int) -> int: + for _ in range(count): + field_offset = offset + key_length, key_data = self._read_string(offset) + offset += int(key_length.nbytes + key_data.nbytes) + raw_type = self._get(offset, np.uint32) + offset += int(raw_type.nbytes) + size, parts, data_indices, types = self._parse_value(offset, int(raw_type[0])) + field_parts = [key_length, key_data, raw_type, *parts] + shifted_indices = [3 + index for index in data_indices] + self._push_field( + ReaderField( + field_offset, + key_data.tobytes().decode("utf-8"), + field_parts, + shifted_indices, + types, + ), + include_size=False, + ) + offset += size + return offset + + def _read_tensor_field(self, offset: int) -> ReaderField: + field_offset = offset + name_length, name_data = self._read_string(offset) + offset += int(name_length.nbytes + name_data.nbytes) + n_dims = self._get(offset, np.uint32) + offset += int(n_dims.nbytes) + dims = self._get(offset, np.uint64, n_dims[0]) + offset += int(dims.nbytes) + raw_dtype = self._get(offset, np.uint32) + offset += int(raw_dtype.nbytes) + tensor_offset = self._get(offset, np.uint64) + return ReaderField( + field_offset, + name_data.tobytes().decode("utf-8"), + [name_length, name_data, n_dims, dims, raw_dtype, tensor_offset], + [1, 3, 4, 5], + [], + ) + + def _read_tensor_fields(self, offset: int, count: int) -> tuple[int, list[ReaderField]]: + fields: list[ReaderField] = [] + for _ in range(count): + field = self._read_tensor_field(offset) + offset += sum(int(part.nbytes) for part in field.parts) + fields.append(field) + return offset, fields + + def _read_tensors(self, data_start: int, fields: list[ReaderField]) -> None: + tensors: list[ReaderTensor] = [] + seen_names: set[str] = set() + + for field in fields: + _name_length, name_data, _n_dims, dims, raw_dtype, tensor_offset = field.parts + tensor_name = name_data.tobytes().decode("utf-8") + if tensor_name in seen_names: + raise ValueError(f"Duplicate GGUF tensor `{tensor_name}`") + seen_names.add(tensor_name) + + tensor_type = GGMLQuantizationType(int(raw_dtype[0])) + n_elements = int(np.prod(dims)) + logical_shape = tuple(reversed(dims.tolist())) + block_size, type_size = GGML_QUANT_SIZES[tensor_type] + n_bytes = n_elements * type_size // block_size + absolute_offset = int(data_start + tensor_offset[0]) + + if tensor_type in { + GGMLQuantizationType.F16, + GGMLQuantizationType.F32, + GGMLQuantizationType.F64, + GGMLQuantizationType.I8, + GGMLQuantizationType.I16, + GGMLQuantizationType.I32, + GGMLQuantizationType.I64, + }: + dtype_by_type = { + GGMLQuantizationType.F16: np.float16, + GGMLQuantizationType.F32: np.float32, + GGMLQuantizationType.F64: np.float64, + GGMLQuantizationType.I8: np.int8, + GGMLQuantizationType.I16: np.int16, + GGMLQuantizationType.I32: np.int32, + GGMLQuantizationType.I64: np.int64, + } + item_dtype = dtype_by_type[tensor_type] + item_count = n_elements + storage_shape = logical_shape + else: + item_dtype = np.uint8 + item_count = n_bytes + storage_shape = quant_shape_to_byte_shape(logical_shape, tensor_type) + + tensor_data = self._get(absolute_offset, item_dtype, item_count).reshape(storage_shape) + tensors.append( + ReaderTensor( + name=tensor_name, + tensor_type=tensor_type, + shape=dims, + n_elements=n_elements, + n_bytes=n_bytes, + data_offset=absolute_offset, + data=tensor_data, + field=field, + ) + ) + + self.tensors = tensors + + +def install_runtime(): + sys.modules["gguf"] = sys.modules[__name__] + return sys.modules["gguf"] + + +__all__ = [ + "GGML_QUANT_SIZES", + "GGMLQuantizationType", + "GGUFQuantizedCheckpointSpec", + "GGUF_DEFAULT_ALIGNMENT", + "GGUF_MAGIC", + "GGUF_VERSION", + "GGUFEndian", + "GGUFReader", + "GGUFValueType", + "MODEL_ARCH_NAMES", + "ReaderField", + "ReaderTensor", + "dequantize", + "dequantize_to_torch", + "get_tensor_name_map", + "inspect_quantized_checkpoint", + "install_runtime", + "native_quantized_loader_enabled", + "quant_shape_to_byte_shape", + "quantize", +] diff --git a/gptqmodel/utils/jit_compile_baselines.py b/gptqmodel/utils/jit_compile_baselines.py new file mode 100644 index 000000000..79edad497 --- /dev/null +++ b/gptqmodel/utils/jit_compile_baselines.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +"""Reference cold-build timings for torch.ops JIT extensions. + +These values were measured on 2026-04-04 from clean temporary build roots on +the reference CUDA development host with ``MAX_JOBS`` unset. They are not used +for correctness and only provide a more realistic first-use progress estimate +than an open-ended spinner. + +When a kernel changes materially, refresh the corresponding value by re-timing +its clean JIT build. +""" + +from __future__ import annotations + + +JIT_COMPILE_BASELINE_SECONDS: dict[str, float] = { + "gptqmodel_awq_ops": 61.640, + "gptqmodel_exllamav2_awq_ops": 35.421, + "gptqmodel_exllamav2_ops": 34.528, + "gptqmodel_exllamav3_ops": 61.871, + "gptqmodel_marlin_bf16_ops": 120.634, + "gptqmodel_marlin_fp16_ops": 116.863, + "gptqmodel_pack_block_cpu": 31.096, + "gptqmodel_paroquant_rotation": 78.430, + "gptqmodel_qqq_ops": 82.492, +} + + +def get_jit_compile_baseline_seconds(extension_name: str) -> float | None: + """Return the recorded reference build duration for one JIT extension.""" + + value = JIT_COMPILE_BASELINE_SECONDS.get(extension_name) + if value is None: + return None + return float(value) diff --git a/gptqmodel/utils/linalg_warmup.py b/gptqmodel/utils/linalg_warmup.py index 7ba368512..85ff7a80f 100644 --- a/gptqmodel/utils/linalg_warmup.py +++ b/gptqmodel/utils/linalg_warmup.py @@ -9,10 +9,32 @@ import torch +from .torch import TORCH_GTE_210 + _GLOBAL_WARMUP_LOCK = threading.Lock() +def _get_cuda_preferred_linalg_library(): + preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None) + if preferred is None: + return None + if callable(preferred): + return preferred() + return preferred + + +def _set_cuda_preferred_linalg_library(backend) -> bool: + preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None) + if preferred is None: + return False + if callable(preferred): + preferred(backend=backend) + return True + setattr(torch.backends.cuda, "preferred_linalg_library", backend) + return True + + def _make_spd(size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: """Generate a small symmetric positive definite matrix.""" base = torch.randn((size, size), device=device, dtype=dtype) @@ -64,19 +86,18 @@ def run_torch_linalg_warmup(device: torch.device) -> None: _run_qr(device, dtype) if device.type == "cuda" and hasattr(torch.backends, "cuda"): - preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None) - if callable(preferred): - current = preferred() + current = _get_cuda_preferred_linalg_library() + if current is not None and not TORCH_GTE_210: # Core warmup already ran using the currently preferred backend above. # Some installations fall back to MAGMA when the primary solver is unavailable, # so we pre-initialize MAGMA as well when it differs from the preferred backend. if current and current != "magma": with contextlib.suppress(Exception): - torch.backends.cuda.preferred_linalg_library(backend="magma") + _set_cuda_preferred_linalg_library("magma") _run_cholesky_and_eigh(device, torch.float32) if current: with contextlib.suppress(Exception): - torch.backends.cuda.preferred_linalg_library(backend=current) + _set_cuda_preferred_linalg_library(current) __all__ = ["run_torch_linalg_warmup"] diff --git a/gptqmodel/utils/logger.py b/gptqmodel/utils/logger.py index 4e8374e0e..70bf76328 100644 --- a/gptqmodel/utils/logger.py +++ b/gptqmodel/utils/logger.py @@ -4,6 +4,8 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import contextlib +import os +import sys import threading import time from collections import OrderedDict @@ -12,8 +14,92 @@ from logbar import LogBar +class _SilentProgress: + """Minimal no-op progress handle for non-interactive test sessions.""" + + def __init__(self, iterable=None): + self._iterable = iterable if iterable is not None else () + self.current_iter_step = 0 + + def __iter__(self): + if isinstance(self._iterable, int): + return iter(range(self._iterable)) + return iter(self._iterable) + + def __len__(self): + if isinstance(self._iterable, int): + return self._iterable + return len(self._iterable) + + def attach(self, *_args, **_kwargs): + return self + + def manual(self): + return self + + def set(self, **_kwargs): + return self + + def title(self, *_args, **_kwargs): + return self + + def subtitle(self, *_args, **_kwargs): + return self + + def draw(self, force: bool = False): + return self + + def refresh(self): + return self + + def next(self, step: int = 1): + self.current_iter_step += int(step) + return self + + def close(self): + return None + + +class _AdaptiveLoggerProxy: + """Proxy that keeps structured logs while adapting live rendering at call time.""" + + def __init__(self, logger: LogBar): + self._logger = logger + + def pb(self, iterable, *, output_interval: Optional[int] = None): + if _suppress_live_renderables(): + return _SilentProgress(iterable) + return self._logger.pb(iterable, output_interval=output_interval) + + def spinner(self, title: str = "", *, interval: float = 0.5, tail_length: int = 4): + if _suppress_live_renderables(): + return _SilentProgress() + return self._logger.spinner(title=title, interval=interval, tail_length=tail_length) + + def __getattr__(self, name): + return getattr(self._logger, name) + + +def _suppress_live_renderables() -> bool: + """Disable live progress redraws under non-interactive pytest capture.""" + + if "PYTEST_CURRENT_TEST" not in os.environ: + return False + + try: + return not sys.stdout.isatty() + except Exception: + return True + + +def live_renderables_suppressed() -> bool: + """Report whether redraw-based progress should be replaced by durable logs.""" + + return _suppress_live_renderables() + + def setup_logger(): - return LogBar.shared() + return _AdaptiveLoggerProxy(LogBar.shared()) class QuantizationRegionTimer: diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index 14ddfbef9..fdd63a901 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -18,6 +18,7 @@ from ..utils.attn_mask import normalize_seq_mask from ..utils.device import get_device from ..utils.env import env_flag +from ..utils.inspect import get_supported_kwargs from ..utils.logger import setup_logger from ..utils.model import move_to, nested_move_to from ..utils.safe import ThreadSafe @@ -359,6 +360,7 @@ def forward_batch_worker( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.Tensor], *, + gptq_model=None, support_batch_quantize: bool, is_lm_head_module: bool, need_output: bool, @@ -378,8 +380,12 @@ def forward_batch_worker( attn_tensor = move_to(attention_mask, device=module_device) additional_inputs: Dict[str, torch.Tensor] = {} - if support_batch_quantize and attn_tensor is not None: - additional_inputs["attention_mask"] = attn_tensor + accepts_var_kw, allowed_kwargs = get_supported_kwargs(module.forward) + supports_attention_mask = accepts_var_kw or allowed_kwargs is None or "attention_mask" in allowed_kwargs + if supports_attention_mask: + # Some layers, such as ChatGLM blocks, still require the kwarg even + # when the effective mask is `None`. + additional_inputs["attention_mask"] = attn_tensor if support_batch_quantize else None if position_ids is not None: additional_inputs["position_ids"] = move_to(position_ids, device=module_device) @@ -401,6 +407,13 @@ def forward_batch_worker( # TODO: some models does not honor generate config.use_cache property so we are forced to hack this to false additional_inputs["use_cache"] = False + if gptq_model is not None: + additional_inputs = gptq_model.prepare_layer_replay_kwargs( + layer=module, + layer_input=inputs, + additional_inputs=additional_inputs, + target_device=module_device, + ) module_output = None kv_next = None @@ -419,7 +432,11 @@ def forward_batch_worker( if reuse_kv and module_output is not None and isinstance(module_output, tuple) and len(module_output) > 0: kv_next = module_output[-1] - result_output = module_output if need_output else None + result_output = None + if need_output and module_output is not None: + # Replay only consumes the hidden-state tensor that feeds the next + # layer. + result_output = module_output[0] if isinstance(module_output, tuple) else module_output # Promptly release VRAM to reduce peak memory usage. del inputs diff --git a/gptqmodel/utils/machete.py b/gptqmodel/utils/machete.py index 57aaee535..77cce92cd 100644 --- a/gptqmodel/utils/machete.py +++ b/gptqmodel/utils/machete.py @@ -1,33 +1,473 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2026 ModelCloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium from __future__ import annotations +import json +import os +import re +import shutil +import subprocess +import sys +import tarfile +import tempfile +import urllib.request +from pathlib import Path from typing import List, Optional import torch -from ._extension_loader import load_extension_module +from .cpp import ( + TorchOpsJitExtension, + default_jit_cflags, + default_jit_cuda_cflags, + default_torch_ops_build_root, + local_nvcc_version_at_least, + resolved_cuda_arch_flags, +) from .logger import setup_logger from .marlin_scalar_type import ScalarType, scalar_types +from .rocm import IS_ROCM log = setup_logger() -machete_import_exception: Optional[str] = None -try: - gptqmodel_machete_kernels = load_extension_module("gptqmodel_machete_kernels") -except ImportError as e: # pragma: no cover - surfaced at runtime - machete_import_exception = str(e) - gptqmodel_machete_kernels = None +_MACHETE_OPS_NAME = "gptqmodel_machete_ops" +_MACHETE_OPS_NAMESPACE = "gptqmodel_machete" + +_CUTLASS_VERSION = "4.4.2" +_CUTLASS_RELEASE_URL = f"https://github.com/NVIDIA/cutlass/archive/refs/tags/v{_CUTLASS_VERSION}.tar.gz" +_CUTLASS_VERSION_MARKER = ".gptqmodel_cutlass_version" +_CUTLASS_VERSION_DEFINE_PATTERN = re.compile(r"^\s*#define\s+CUTLASS_(MAJOR|MINOR|PATCH)\s+(\d+)\s*$", re.MULTILINE) +_MACHETE_REQUIRED_COMPUTE_CAPABILITY = (9, 0) +_MACHETE_MIN_SHARED_MEMORY_PER_BLOCK_OPTIN = 204800 +_MACHETE_SM90A_ARCH_FLAGS = ( + "-gencode=arch=compute_90a,code=sm_90a", + "-gencode=arch=compute_90a,code=compute_90a", +) +_MACHETE_JIT_NVCC_THREADS = "16" +_MACHETE_REQUIRED_TORCH_NVCC_UNDEFINES = ( + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", +) MACHETE_PREPACKED_BLOCK_SHAPE = (64, 128) +def _machete_project_root() -> Path: + return Path(__file__).resolve().parents[2] + + +def _machete_source_root() -> Path: + return _machete_project_root() / "gptqmodel_ext" / "machete" + + +def _repo_local_cutlass_root() -> Path: + return _machete_project_root() / "cutlass" + + +def _cutlass_download_cache_dir() -> Path: + return _machete_project_root() / "build" / "_deps" + + +def _cutlass_python_bindings_present(cutlass_root: Path) -> bool: + python_dir = cutlass_root / "python" + return ( + (python_dir / "cutlass_library.py").is_file() + or (python_dir / "cutlass_library" / "__init__.py").is_file() + ) + + +def _cutlass_checkout_complete(cutlass_root: Path) -> bool: + common_include_dir = cutlass_root / "examples" / "common" / "include" + util_include_dir = cutlass_root / "tools" / "util" / "include" + return ( + (cutlass_root / "include" / "cutlass" / "cutlass.h").is_file() + and (cutlass_root / "tools" / "library" / "include").is_dir() + and (common_include_dir.is_dir() or util_include_dir.is_dir()) + and _cutlass_python_bindings_present(cutlass_root) + ) + + +def _repo_local_cutlass_version_marker(cutlass_root: Path) -> Path: + return cutlass_root / _CUTLASS_VERSION_MARKER + + +def _cutlass_checkout_version(cutlass_root: Path) -> Optional[str]: + version_header = cutlass_root / "include" / "cutlass" / "version.h" + if not version_header.is_file(): + return None + + macros = dict(_CUTLASS_VERSION_DEFINE_PATTERN.findall(version_header.read_text(encoding="utf-8"))) + required_macros = {"MAJOR", "MINOR", "PATCH"} + if macros.keys() < required_macros: + return None + + return f"{macros['MAJOR']}.{macros['MINOR']}.{macros['PATCH']}" + + +def _cutlass_checkout_version_error(cutlass_root: Path) -> Optional[str]: + version = _cutlass_checkout_version(cutlass_root) + if version is None: + return ( + f"`{cutlass_root}` is missing a readable `include/cutlass/version.h`; " + f"GPTQModel requires CUTLASS v{_CUTLASS_VERSION}." + ) + if version != _CUTLASS_VERSION: + return ( + f"`{cutlass_root}` contains CUTLASS v{version}, but GPTQModel requires v{_CUTLASS_VERSION}." + ) + return None + + +def _repo_local_cutlass_version_matches(cutlass_root: Path) -> bool: + marker = _repo_local_cutlass_version_marker(cutlass_root) + return ( + _cutlass_checkout_version(cutlass_root) == _CUTLASS_VERSION + and marker.is_file() + and marker.read_text(encoding="utf-8").strip() == _CUTLASS_VERSION + ) + + +def _mark_repo_local_cutlass_version(cutlass_root: Path) -> None: + _repo_local_cutlass_version_marker(cutlass_root).write_text(f"{_CUTLASS_VERSION}\n", encoding="utf-8") + + +def _use_repo_local_cutlass(cutlass_root: Path) -> Path: + if not _repo_local_cutlass_version_matches(cutlass_root): + _mark_repo_local_cutlass_version(cutlass_root) + os.environ["GPTQMODEL_CUTLASS_DIR"] = str(cutlass_root) + return cutlass_root + + +def _download_cutlass_archive(url: str, destination: Path) -> None: + destination.parent.mkdir(parents=True, exist_ok=True) + partial = destination.with_suffix(destination.suffix + ".part") + if partial.exists(): + partial.unlink() + + log.info("Machete: downloading CUTLASS v%s into `%s`.", _CUTLASS_VERSION, destination) + with urllib.request.urlopen(url) as response, partial.open("wb") as handle: + shutil.copyfileobj(response, handle) + partial.replace(destination) + + +def _extract_cutlass_archive(archive_path: Path, destination_parent: Path) -> None: + with tarfile.open(archive_path, "r:gz") as archive: + extract_kwargs = {"path": destination_parent} + if sys.version_info >= (3, 12): + extract_kwargs["filter"] = "data" + archive.extractall(**extract_kwargs) + + +def _ensure_cutlass_source() -> Path: + repo_local_root = _repo_local_cutlass_root().resolve() + configured_root = os.getenv("GPTQMODEL_CUTLASS_DIR") + if configured_root: + configured_path = Path(configured_root).expanduser().resolve() + if _cutlass_checkout_complete(configured_path): + version_error = _cutlass_checkout_version_error(configured_path) + if version_error is None: + if configured_path == repo_local_root: + return _use_repo_local_cutlass(configured_path) + return configured_path + if configured_path != repo_local_root: + raise RuntimeError( + "Machete: GPTQMODEL_CUTLASS_DIR points to an incompatible CUTLASS checkout. " + f"{version_error} Unset GPTQMODEL_CUTLASS_DIR to allow auto-download, or point it at a " + f"CUTLASS v{_CUTLASS_VERSION} checkout." + ) + log.info( + "Machete: GPTQMODEL_CUTLASS_DIR points to stale repo-local CUTLASS checkout `%s`; refreshing to v%s.", + configured_path, + _CUTLASS_VERSION, + ) + else: + log.info( + "Machete: GPTQMODEL_CUTLASS_DIR=`%s` is incomplete; falling back to repo-local CUTLASS checkout.", + configured_path, + ) + + if _cutlass_checkout_complete(repo_local_root): + if _cutlass_checkout_version_error(repo_local_root) is None: + return _use_repo_local_cutlass(repo_local_root) + if repo_local_root.exists(): + current_version = _cutlass_checkout_version(repo_local_root) + log.info( + "Machete: refreshing repo-local CUTLASS checkout at `%s`%s to v%s.", + repo_local_root, + f" from v{current_version}" if current_version else "", + _CUTLASS_VERSION, + ) + + archive_path = _cutlass_download_cache_dir() / f"cutlass-v{_CUTLASS_VERSION}.tar.gz" + archive_path.parent.mkdir(parents=True, exist_ok=True) + if not archive_path.exists(): + _download_cutlass_archive(_CUTLASS_RELEASE_URL, archive_path) + + parent = repo_local_root.parent + parent.mkdir(parents=True, exist_ok=True) + if repo_local_root.exists(): + shutil.rmtree(repo_local_root, ignore_errors=True) + + with tempfile.TemporaryDirectory(dir=parent, prefix="cutlass-unpack-") as temp_dir: + temp_root = Path(temp_dir) + _extract_cutlass_archive(archive_path, temp_root) + extracted_root = temp_root / f"cutlass-{_CUTLASS_VERSION}" + if not extracted_root.exists(): + raise RuntimeError(f"Machete: failed to extract CUTLASS archive `{archive_path}`.") + extracted_root.replace(repo_local_root) + _mark_repo_local_cutlass_version(repo_local_root) + + return _use_repo_local_cutlass(repo_local_root) + + +def _machete_generated_dir() -> Path: + return _machete_source_root() / "generated" + + +def _machete_generation_marker() -> Path: + return _machete_generated_dir() / ".gptqmodel_complete" + + +def _cutlass_python_binding_inputs(cutlass_root: Path) -> list[Path]: + python_dir = cutlass_root / "python" + candidates = [ + python_dir / "cutlass_library.py", + python_dir / "cutlass_library" / "__init__.py", + ] + return [candidate for candidate in candidates if candidate.exists()] + + +def _machete_generation_signature(cutlass_root: Path) -> str: + return json.dumps( + { + "cutlass_root": str(cutlass_root.resolve()), + "cutlass_version": _CUTLASS_VERSION, + }, + sort_keys=True, + ) + + +def _machete_generator_inputs(cutlass_root: Path) -> list[Path]: + project_root = _machete_project_root() + return [ + _machete_source_root() / "generate.py", + project_root / "gptqmodel_ext" / "cutlass_extensions" / "vllm_cutlass_library_extension.py", + *_cutlass_python_binding_inputs(cutlass_root), + ] + + +def _generated_machete_sources() -> list[Path]: + return sorted(_machete_generated_dir().glob("*.cu")) + + +def _generated_machete_sources_current(cutlass_root: Path) -> bool: + marker = _machete_generation_marker() + generated_sources = _generated_machete_sources() + if not marker.exists() or not generated_sources: + return False + if marker.read_text(encoding="utf-8").strip() != _machete_generation_signature(cutlass_root): + return False + marker_mtime_ns = marker.stat().st_mtime_ns + return not any(path.stat().st_mtime_ns > marker_mtime_ns for path in _machete_generator_inputs(cutlass_root)) + + +def _run_machete_generator(cutlass_root: Path) -> None: + generator = _machete_source_root() / "generate.py" + env = os.environ.copy() + env["GPTQMODEL_CUTLASS_DIR"] = str(cutlass_root) + + log.info("Machete: generating CUTLASS-backed kernel sources in `%s`.", _machete_generated_dir()) + result = subprocess.run( + [sys.executable, str(generator)], + cwd=str(_machete_project_root()), + env=env, + check=False, + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError( + "Machete: failed to generate kernel sources.\n" + f"Return code: {result.returncode}\n" + f"Stdout: {result.stdout}\n" + f"Stderr: {result.stderr}" + ) + + +def _ensure_generated_machete_sources() -> list[Path]: + cutlass_root = _ensure_cutlass_source() + if _generated_machete_sources_current(cutlass_root): + return _generated_machete_sources() + + generated_dir = _machete_generated_dir() + if generated_dir.exists(): + shutil.rmtree(generated_dir, ignore_errors=True) + + _run_machete_generator(cutlass_root) + + generated_sources = _generated_machete_sources() + if not generated_sources: + raise RuntimeError( + "Machete: generator completed without producing any CUDA sources." + ) + + _machete_generation_marker().write_text( + _machete_generation_signature(cutlass_root), + encoding="utf-8", + ) + return generated_sources + + +def _machete_sources() -> list[str]: + machete_root = _machete_source_root() + generated_sources = _ensure_generated_machete_sources() + return [str(machete_root / "machete_pytorch.cu"), *[str(path) for path in generated_sources]] + + +def _machete_include_paths() -> list[str]: + project_root = _machete_project_root() + cutlass_root = _ensure_cutlass_source() + include_paths = [ + str((project_root / "gptqmodel_ext").resolve()), + str((project_root / "gptqmodel_ext" / "cutlass_extensions").resolve()), + str((cutlass_root / "include").resolve()), + str((cutlass_root / "tools" / "library" / "include").resolve()), + ] + common_include_dir = cutlass_root / "examples" / "common" / "include" + util_include_dir = cutlass_root / "tools" / "util" / "include" + if common_include_dir.is_dir(): + include_paths.append(str(common_include_dir.resolve())) + if util_include_dir.is_dir(): + include_paths.append(str(util_include_dir.resolve())) + return include_paths + + +def _machete_extra_cflags() -> list[str]: + return default_jit_cflags(enable_bf16=True) + + +def _machete_hopper_arch_cuda_cflags() -> list[str]: + if _machete_static_runtime_error(): + return [] + + # vLLM builds Machete only for Hopper-compatible sm90a targets. Torch's + # default JIT arch detection resolves H100/H200 to sm_90, which compiles + # but triggers CUTLASS runtime abort spam for sm90a-only instructions. + if any("90a" in flag for flag in resolved_cuda_arch_flags()): + return [] + return list(_MACHETE_SM90A_ARCH_FLAGS) + + +def _machete_extra_cuda_cflags() -> list[str]: + flags = [ + *_MACHETE_REQUIRED_TORCH_NVCC_UNDEFINES, + *default_jit_cuda_cflags( + enable_bf16=True, + include_lineinfo=True, + include_nvcc_threads=True, + include_ptxas_optimizations=True, + include_ptxas_verbosity=False, + include_fatbin_compression=True, + include_diag_suppress=True, + nvcc_threads=_MACHETE_JIT_NVCC_THREADS, + ), + *_machete_hopper_arch_cuda_cflags(), + ] + if local_nvcc_version_at_least(12, 8): + flags.insert(0, "-static-global-template-stub=false") + return flags + + +def _machete_extra_ldflags() -> list[str]: + # Hopper tensor-map entry points such as cuTensorMapEncodeTiled live in the + # CUDA driver library, not libcudart. Link libcuda explicitly so the JIT + # extension remains loadable after a successful compile on non-SM90 hosts. + return ["-lcuda"] + + +_MACHETE_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name=_MACHETE_OPS_NAME, + namespace=_MACHETE_OPS_NAMESPACE, + required_ops=("machete_prepack_B", "machete_mm", "machete_supported_schedules"), + sources=_machete_sources, + build_root_env="GPTQMODEL_MACHETE_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("machete"), + display_name="Machete", + extra_cflags=_machete_extra_cflags, + extra_cuda_cflags=_machete_extra_cuda_cflags, + extra_include_paths=_machete_include_paths, + extra_ldflags=_machete_extra_ldflags, + force_rebuild_env="GPTQMODEL_MACHETE_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +def _extension_api(): + from gptqmodel import extension as extension_api + + return extension_api + + +def _machete_static_runtime_error() -> str: + if IS_ROCM: + return "Machete kernel is not supported on ROCm." + if not torch.cuda.is_available(): + return "Machete kernel requires CUDA." + capability = torch.cuda.get_device_capability() + if capability != _MACHETE_REQUIRED_COMPUTE_CAPABILITY: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + return ( + "Machete kernel currently supports Hopper-class SM90 GPUs only; " + f"found `{props.name}` with compute capability {capability[0]}.{capability[1]}." + ) + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + shared_memory_per_block_optin = getattr( + props, + "shared_memory_per_block_optin", + props.shared_memory_per_block, + ) + if shared_memory_per_block_optin < _MACHETE_MIN_SHARED_MEMORY_PER_BLOCK_OPTIN: + return ( + "Machete kernel requires at least " + f"{_MACHETE_MIN_SHARED_MEMORY_PER_BLOCK_OPTIN} bytes of opt-in shared memory per block; " + f"`{props.name}` exposes {shared_memory_per_block_optin}." + ) + return "" + + +def clear_machete_extension_cache() -> None: + _MACHETE_TORCH_OPS_EXTENSION.clear_cache() + + +def machete_runtime_available() -> bool: + static_error = _machete_static_runtime_error() + if static_error: + return False + return _extension_api().is_available("machete") + + +def machete_runtime_error() -> str: + static_error = _machete_static_runtime_error() + if static_error: + return static_error + + extension_api = _extension_api() + if extension_api.is_available("machete"): + return "" + return extension_api.error("machete") or "Machete runtime unavailable." + + +def prewarm_machete_extension() -> bool: + return _extension_api().load(name="machete")["machete"] + + def _validate_machete_device_support() -> bool: - return (torch.cuda.is_available() - and torch.cuda.get_device_capability()[0] >= 9) + return _machete_static_runtime_error() == "" def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: @@ -46,35 +486,30 @@ def query_machete_supported_group_sizes(act_type: torch.dtype) -> List[int]: return [-1, 128] -def check_machete_supports_shape(in_features: int, - out_features: int) -> tuple[bool, Optional[str]]: +def check_machete_supports_shape( + in_features: int, + out_features: int, +) -> tuple[bool, Optional[str]]: if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: - return (False, - f"Input features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[0]}") + return ( + False, + f"Input features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[0]}", + ) if out_features % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: - return (False, - f"Output features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[1]}") - return (True, None) - - -def _ensure_machete_loaded(): - if machete_import_exception is not None: - raise ImportError( - f"Trying to use the machete backend, but could not import the C++/CUDA dependencies: {machete_import_exception}" + return ( + False, + f"Output features size must be divisible by {MACHETE_PREPACKED_BLOCK_SHAPE[1]}", ) - - -def _maybe_scalar_type(t: Optional[torch.Tensor]) -> Optional[torch.dtype]: - return t.dtype if t is not None else None + return (True, None) def machete_prepack_B( - weight: torch.Tensor, - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: Optional[torch.dtype]) -> torch.Tensor: - _ensure_machete_loaded() - return gptqmodel_machete_kernels.machete_prepack_B( + weight: torch.Tensor, + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype], +) -> torch.Tensor: + return _extension_api().op("machete", "machete_prepack_B")( weight, a_type, b_type.id, @@ -83,15 +518,15 @@ def machete_prepack_B( def machete_supported_schedules( - a_type: torch.dtype, - b_type: ScalarType, - group_scales_type: Optional[torch.dtype] = None, - group_zeros_type: Optional[torch.dtype] = None, - channel_scales_type: Optional[torch.dtype] = None, - token_scales_type: Optional[torch.dtype] = None, - out_type: Optional[torch.dtype] = None) -> List[str]: - _ensure_machete_loaded() - return gptqmodel_machete_kernels.machete_supported_schedules( + a_type: torch.dtype, + b_type: ScalarType, + group_scales_type: Optional[torch.dtype] = None, + group_zeros_type: Optional[torch.dtype] = None, + channel_scales_type: Optional[torch.dtype] = None, + token_scales_type: Optional[torch.dtype] = None, + out_type: Optional[torch.dtype] = None, +) -> List[str]: + return _extension_api().op("machete", "machete_supported_schedules")( a_type, b_type.id, group_scales_type, @@ -103,19 +538,19 @@ def machete_supported_schedules( def machete_mm( - *, - a: torch.Tensor, - b_q: torch.Tensor, - b_type: ScalarType, - b_group_scales: Optional[torch.Tensor] = None, - b_group_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - b_channel_scales: Optional[torch.Tensor] = None, - a_token_scales: Optional[torch.Tensor] = None, - out_type: Optional[torch.dtype] = None, - schedule: Optional[str] = None) -> torch.Tensor: - _ensure_machete_loaded() - return gptqmodel_machete_kernels.machete_mm( + *, + a: torch.Tensor, + b_q: torch.Tensor, + b_type: ScalarType, + b_group_scales: Optional[torch.Tensor] = None, + b_group_zeros: Optional[torch.Tensor] = None, + b_group_size: Optional[int] = None, + b_channel_scales: Optional[torch.Tensor] = None, + a_token_scales: Optional[torch.Tensor] = None, + out_type: Optional[torch.dtype] = None, + schedule: Optional[str] = None, +) -> torch.Tensor: + return _extension_api().op("machete", "machete_mm")( a, b_q, b_type.id, @@ -130,9 +565,10 @@ def machete_mm( def pack_quantized_values_into_int32( - tensor: torch.Tensor, - qtype: ScalarType, - packed_dim: int = 0) -> torch.Tensor: + tensor: torch.Tensor, + qtype: ScalarType, + packed_dim: int = 0, +) -> torch.Tensor: perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,) inv_perm = tuple(perm.index(i) for i in range(len(perm))) temp = tensor.permute(perm) @@ -152,9 +588,10 @@ def pack_quantized_values_into_int32( def unpack_quantized_values_into_int32( - tensor: torch.Tensor, - qtype: ScalarType, - packed_dim: int = 0) -> torch.Tensor: + tensor: torch.Tensor, + qtype: ScalarType, + packed_dim: int = 0, +) -> torch.Tensor: perm = tuple(i for i in range(tensor.ndim) if i != packed_dim) + (packed_dim,) inv_perm = tuple(perm.index(i) for i in range(len(perm))) temp = tensor.permute(perm) @@ -173,13 +610,18 @@ def unpack_quantized_values_into_int32( __all__ = [ + "_ensure_cutlass_source", + "_ensure_generated_machete_sources", "_validate_machete_device_support", "check_machete_supports_shape", - "machete_import_exception", + "clear_machete_extension_cache", "machete_mm", "machete_prepack_B", + "machete_runtime_available", + "machete_runtime_error", "machete_supported_schedules", "pack_quantized_values_into_int32", + "prewarm_machete_extension", "query_machete_supported_act_types", "query_machete_supported_group_sizes", "query_machete_supported_quant_types", diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py index f5f829f98..fd3ab47df 100644 --- a/gptqmodel/utils/marlin.py +++ b/gptqmodel/utils/marlin.py @@ -1,26 +1,279 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2026 ModelCloud.ai # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path +from shutil import which from typing import Callable, List, Optional, Tuple, Union import numpy import torch from ..utils.logger import setup_logger -from ._extension_loader import load_extension_module +from .cpp import ( + TorchOpsJitExtension, + cuda_include_paths_with_fallback, + default_jit_cflags, + default_jit_cuda_cflags, + default_torch_ops_build_root, + detected_cuda_wheel_include_paths, + local_nvcc_version_at_least, +) from .marlin_scalar_type import ScalarType from .rocm import IS_ROCM log = setup_logger() -marlin_import_exception = None -gptqmodel_marlin_kernels = None -try: - gptqmodel_marlin_kernels = load_extension_module("gptqmodel_marlin_kernels") -except ImportError as e: - marlin_import_exception = str(e) +_MARLIN_FP16_OPS_NAME = "gptqmodel_marlin_fp16_ops" +_MARLIN_FP16_NAMESPACE = "gptqmodel_marlin_fp16" +_MARLIN_BF16_OPS_NAME = "gptqmodel_marlin_bf16_ops" +_MARLIN_BF16_NAMESPACE = "gptqmodel_marlin_bf16" +_MARLIN_REQUIRED_CUDA_HEADERS = ( + "cuda_runtime_api.h", + "cusparse.h", + "cublas_v2.h", + "cublasLt.h", + "cusolverDn.h", +) + + +def _marlin_capability_supported(major: int, minor: int) -> bool: + return major > 7 or (major == 7 and minor >= 5) + + +def _marlin_environment_error() -> str: + if IS_ROCM: + return "Marlin kernel is not supported on ROCm." + if not torch.cuda.is_available(): + return "Marlin kernel requires CUDA." + try: + major, minor = torch.cuda.get_device_capability() + except Exception as exc: # pragma: no cover - depends on host CUDA runtime + return f"Marlin kernel failed to query CUDA device capability: {exc}" + if not _marlin_capability_supported(major, minor): + return f"Marlin kernel requires compute capability >= 7.5, got {major}.{minor}." + return "" + + +marlin_import_exception = _marlin_environment_error() or None + + +def _marlin_root() -> Path: + return Path(__file__).resolve().parents[2] / "gptqmodel_ext" / "marlin" + + +def _marlin_cuda_extra_name() -> str | None: + raw = getattr(torch.version, "cuda", None) + if not raw: + return None + try: + major = int(str(raw).split(".", maxsplit=1)[0]) + except (TypeError, ValueError): # pragma: no cover - depends on torch build metadata + return None + if major >= 13: + return "marlin-cuda" + if major == 12: + return "marlin-cuda12" + return None + + +def _marlin_missing_header_names(error_text: str) -> list[str]: + return [ + header_name for header_name in _MARLIN_REQUIRED_CUDA_HEADERS + if f"{header_name}: No such file or directory" in error_text + ] + + +def _marlin_header_install_hint(error_text: str) -> str: + missing_headers = _marlin_missing_header_names(error_text) + if not missing_headers: + return "" + + if detected_cuda_wheel_include_paths(): + return "" + + extra_name = _marlin_cuda_extra_name() + missing_headers_text = ", ".join(missing_headers) + if extra_name is not None: + install_text = ( + f"Install the wheel-provided CUDA headers with " + f"`pip install \"gptqmodel[{extra_name}]\"`." + ) + else: + install_text = ( + "Install the CUDA runtime/developer headers that match your Torch CUDA build." + ) + + nvcc_text = ( + "A local `nvcc` on PATH is still required for Marlin JIT." + if which("nvcc") + else "Marlin JIT also requires a local `nvcc` on PATH." + ) + return ( + f"Missing CUDA developer headers for Marlin JIT ({missing_headers_text}). " + f"{install_text} {nvcc_text}" + ) + + +def _ensure_generated_marlin_kernels() -> Path: + root = _marlin_root() + generator = root / "generate_kernels.py" + check_result = subprocess.run( + [sys.executable, str(generator), "--check"], + cwd=str(root), + capture_output=True, + text=True, + check=False, + ) + if check_result.returncode == 0: + return root + + result = subprocess.run( + [sys.executable, str(generator)], + cwd=str(root), + capture_output=True, + text=True, + check=False, + ) + if result.returncode != 0: + details = (result.stderr or result.stdout or check_result.stderr or check_result.stdout or "").strip() + raise RuntimeError( + "Marlin kernel generation failed" + + (f": {details}" if details else ".") + ) + return root + + +def _marlin_sources(dtype_tag: str) -> list[str]: + root = _ensure_generated_marlin_kernels() + sources = [ + str(root / f"marlin_torch_{dtype_tag}.cpp"), + str(root / f"gptq_marlin_{dtype_tag}.cu"), + str(root / "gptq_marlin_repack.cu"), + str(root / "awq_marlin_repack.cu"), + ] + sources.extend(str(path) for path in sorted(root.glob(f"kernel_{dtype_tag}_*.cu"))) + if len(sources) <= 4: + raise RuntimeError(f"Marlin {dtype_tag} sources are incomplete under `{root}`.") + return sources + + +def _marlin_include_paths() -> list[str]: + return cuda_include_paths_with_fallback( + [str(_marlin_root())], + required_header_names=_MARLIN_REQUIRED_CUDA_HEADERS, + ) + + +def _marlin_extra_cflags() -> list[str]: + return default_jit_cflags(enable_bf16=True) + + +def _marlin_extra_cuda_cflags() -> list[str]: + flags = default_jit_cuda_cflags( + enable_bf16=True, + include_lineinfo=True, + include_nvcc_threads=True, + include_ptxas_optimizations=True, + include_ptxas_verbosity=False, + include_fatbin_compression=True, + include_diag_suppress=True, + ) + if local_nvcc_version_at_least(12, 8): + flags.insert(0, "-static-global-template-stub=false") + return flags + + +_MARLIN_FP16_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name=_MARLIN_FP16_OPS_NAME, + namespace=_MARLIN_FP16_NAMESPACE, + required_ops=("gptq_marlin_gemm_fp16", "gptq_marlin_repack", "awq_marlin_repack"), + sources=lambda: _marlin_sources("fp16"), + build_root_env="GPTQMODEL_MARLIN_FP16_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("marlin_fp16"), + display_name="Marlin fp16", + extra_cflags=_marlin_extra_cflags, + extra_cuda_cflags=_marlin_extra_cuda_cflags, + extra_include_paths=_marlin_include_paths, + force_rebuild_env="GPTQMODEL_MARLIN_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +_MARLIN_BF16_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name=_MARLIN_BF16_OPS_NAME, + namespace=_MARLIN_BF16_NAMESPACE, + required_ops=("gptq_marlin_gemm_bf16", "gptq_marlin_repack", "awq_marlin_repack"), + sources=lambda: _marlin_sources("bf16"), + build_root_env="GPTQMODEL_MARLIN_BF16_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("marlin_bf16"), + display_name="Marlin bf16", + extra_cflags=_marlin_extra_cflags, + extra_cuda_cflags=_marlin_extra_cuda_cflags, + extra_include_paths=_marlin_include_paths, + force_rebuild_env="GPTQMODEL_MARLIN_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +def _extension_api(): + from gptqmodel import extension as extension_api + + return extension_api + + +def _marlin_runtime_dtype(dtype: Optional[torch.dtype]) -> torch.dtype: + return torch.bfloat16 if dtype == torch.bfloat16 else torch.float16 + + +def _marlin_kernel_name_for_dtype(dtype: Optional[torch.dtype]) -> str: + return "marlin_bf16" if _marlin_runtime_dtype(dtype) == torch.bfloat16 else "marlin_fp16" + + +def clear_marlin_extension_cache() -> None: + _MARLIN_FP16_TORCH_OPS_EXTENSION.clear_cache() + _MARLIN_BF16_TORCH_OPS_EXTENSION.clear_cache() + + +def marlin_runtime_available(dtype: Optional[torch.dtype] = None) -> bool: + if marlin_import_exception is not None: + return False + return _extension_api().is_available(_marlin_kernel_name_for_dtype(dtype)) + + +def marlin_runtime_error(dtype: Optional[torch.dtype] = None) -> str: + if marlin_import_exception is not None: + return marlin_import_exception + + extension_name = _marlin_kernel_name_for_dtype(dtype) + extension_api = _extension_api() + if extension_api.is_available(extension_name): + return "" + error_text = extension_api.error(extension_name) or "Marlin runtime unavailable." + install_hint = _marlin_header_install_hint(error_text) + if install_hint: + return f"{error_text} {install_hint}" + return error_text + + +def prewarm_marlin_extension(dtype: Optional[torch.dtype]) -> bool: + extension_name = _marlin_kernel_name_for_dtype(dtype) + return _extension_api().load(name=extension_name)[extension_name] + + +def _marlin_resolve_op( + *, + dtype: Optional[torch.dtype], + op_name: str, +): + return _extension_api().op(_marlin_kernel_name_for_dtype(dtype), op_name) # Validate marlin support @@ -32,7 +285,10 @@ def _validate_marlin_device_support() -> bool: Returns: bool: indicates if CUDA device is compatible for Marlin """ - return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 and not IS_ROCM + if IS_ROCM or not torch.cuda.is_available(): + return False + major, minor = torch.cuda.get_device_capability() + return _marlin_capability_supported(major, minor) def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: @@ -51,8 +307,11 @@ def marlin_make_workspace_new(device: torch.device, max_blocks_per_sm: int = 1) -> torch.Tensor: # In the new marlin kernel, we use the num of threadblocks as workspace # size. The num of threadblocks is sms_count * max_blocks_per_sm. + # Some kernels require a larger fixed minimum than the SM count on + # lower-SM but still-supported GPUs, so clamp to that floor. sms = torch.cuda.get_device_properties(device).multi_processor_count - return torch.zeros(sms * max_blocks_per_sm, + workspace_blocks = max(sms * max_blocks_per_sm, 128) + return torch.zeros(workspace_blocks, dtype=torch.int, device=device, requires_grad=False) @@ -61,13 +320,14 @@ def marlin_make_workspace_new(device: torch.device, def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): assert dst.dtype == src.dtype, "Tensors must have the same dtype" - # update tensor shape and stride - dst.as_strided_(src.shape, src.stride()) + with torch.no_grad(): + # Mutating a registered Parameter must bypass autograd bookkeeping. + dst.as_strided_(src.shape, src.stride()) - # If not the same underlying storage move tensor data - if dst.data_ptr() != src.data_ptr(): - dst.copy_(src) - del src + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src # Newly generated tensors need to replace existing tensors that are @@ -292,20 +552,59 @@ def gptq_marlin_gemm(a: torch.Tensor, use_atomic_add: bool = False, use_fp32_reduce: bool = False, is_zp_float: bool = False) -> torch.Tensor: - return gptqmodel_marlin_kernels.gptq_marlin_gemm(a, c, b_q_weight, b_bias, b_scales, - global_scale, b_zeros, g_idx, perm, - workspace, b_q_type.id, size_m, - size_n, size_k, is_k_full, - use_atomic_add, use_fp32_reduce, - is_zp_float) + if _marlin_runtime_dtype(a.dtype) == torch.bfloat16: + op_name = "gptq_marlin_gemm_bf16" + else: + op_name = "gptq_marlin_gemm_fp16" + + op = _marlin_resolve_op( + dtype=a.dtype, + op_name=op_name, + ) + return op( + a, + c, + b_q_weight, + b_bias, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + workspace, + b_q_type.id, + size_m, + size_n, + size_k, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) # gptq_marlin def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, size_k: int, size_n: int, - num_bits: int) -> torch.Tensor: - return gptqmodel_marlin_kernels.gptq_marlin_repack(b_q_weight, perm, size_k, size_n, - num_bits) + num_bits: int, + dtype: Optional[torch.dtype] = None) -> torch.Tensor: + op = _marlin_resolve_op( + dtype=dtype, + op_name="gptq_marlin_repack", + ) + return op(b_q_weight, perm, size_k, size_n, num_bits) + + +def awq_marlin_repack(b_q_weight: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + dtype: Optional[torch.dtype] = None) -> torch.Tensor: + op = _marlin_resolve_op( + dtype=dtype, + op_name="awq_marlin_repack", + ) + return op(b_q_weight, size_k, size_n, num_bits) def get_pack_factor(num_bits): diff --git a/gptqmodel/utils/mlx.py b/gptqmodel/utils/mlx.py index 9586ea527..d95fa045c 100644 --- a/gptqmodel/utils/mlx.py +++ b/gptqmodel/utils/mlx.py @@ -9,8 +9,9 @@ from transformers import PreTrainedModel from ..models import BaseQModel -from ..nn_modules.qlinear.torch import TorchQuantLinear +from ..nn_modules.qlinear.torch import TorchLinear from ..quantization import FORMAT +from ..quantization.config import resolve_quant_format from .logger import setup_logger from .torch import torch_empty_cache @@ -35,8 +36,8 @@ def convert_gptq_to_mlx_weights(model_id_or_path: str, model: Union[PreTrainedMo if gptq_config["bits"] not in [2, 3, 4, 8]: raise ValueError("Model bits is not in [2,3,4,8]") - if gptq_config["checkpoint_format"] not in [FORMAT.GPTQ, FORMAT.GPTQ_V2]: - raise ValueError("Model checkpoint format is not gptq or gptq_v2") + if resolve_quant_format(gptq_config.get("format"), gptq_config.get("method", gptq_config.get("quant_method"))) not in [FORMAT.GPTQ, FORMAT.GPTQ_V2]: + raise ValueError("Model format is not gptq or gptq_v2") if gptq_config.get("dynamic") is not None: print(gptq_config["dynamic"]) @@ -61,7 +62,7 @@ def convert_gptq_to_mlx_weights(model_id_or_path: str, model: Union[PreTrainedMo pb = log.pb(list(model.named_modules())).title("Format: Converting to mlx ->").manual() for name, module in pb: pb.subtitle(f"{name}").draw() - if isinstance(module, TorchQuantLinear): + if isinstance(module, TorchLinear): weights[f"{name}.weight"] = mx.array( module.dequantize_weight().T.detach().to("cpu", torch.float16).numpy() ) diff --git a/gptqmodel/utils/mmlupro.py b/gptqmodel/utils/mmlupro.py index 656616428..aac5721a3 100644 --- a/gptqmodel/utils/mmlupro.py +++ b/gptqmodel/utils/mmlupro.py @@ -9,7 +9,7 @@ import random import time -import pcre as re +import pcre import torch from datasets import load_dataset from torch.utils.data import DataLoader @@ -24,6 +24,12 @@ stop_string = "Question:" log = setup_logger() +_ANSWER_IS_RE = pcre.compile(r"answer is \(?([A-J])\)?") +_ANSWER_LINE_RE = pcre.compile(r".*[aA]nswer:\s*([A-J])") +_FINAL_ANSWER_RE = pcre.compile( + r"\b[A-J]\b(?!.*\b[A-J]\b)", + flags=pcre.Flag.DOTALL, +) def load_mmlu_pro(): dataset = load_dataset("TIGER-Lab/MMLU-Pro") @@ -92,8 +98,7 @@ def generate_cot_prompt(val_df, curr, k): def extract_answer(text): - pattern = r"answer is \(?([A-J])\)?" - match = re.search(pattern, text) + match = _ANSWER_IS_RE.search(text) if match: return match.group(1) else: @@ -102,7 +107,7 @@ def extract_answer(text): def extract_again(text): - match = re.search(r'.*[aA]nswer:\s*([A-J])', text) + match = _ANSWER_LINE_RE.search(text) if match: return match.group(1) else: @@ -110,8 +115,7 @@ def extract_again(text): def extract_final(text): - pattern = r"\b[A-J]\b(?!.*\b[A-J]\b)" - match = re.search(pattern, text, re.DOTALL) + match = _FINAL_ANSWER_RE.search(text) if match: return match.group(0) else: @@ -212,7 +216,8 @@ def mmlupro(model: PreTrainedModel, save_dir: str = "results", global_record_file: str="eval_record_collection.csv", batch_size: int = 1, - seed: int = 12345): + seed: int = 12345, + max_samples: int | None = None): random.seed(seed) os.makedirs(save_dir, exist_ok=True) model_name = os.path.basename(model.config.name_or_path) @@ -255,6 +260,8 @@ def mmlupro(model: PreTrainedModel, if subject not in sta_dict: sta_dict[subject] = {"corr": 0.0, "wrong": 0.0, "accu": 0.0} test_df = select_by_category(full_test_df, subject) + if max_samples is not None: + test_df = test_df[:max_samples] val_df = select_by_category(full_val_df, subject) output_path = os.path.join(save_result_dir, "{}.json".format(subject)) @@ -294,4 +301,3 @@ def mmlupro(model: PreTrainedModel, summary = file.read() return summary - diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 3ba42dde7..c2a09579a 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -21,11 +21,10 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import accelerate -import pcre as re +import pcre import torch import torch.nn as nn import transformers -from huggingface_hub import HfApi, hf_hub_download from packaging import version from safetensors import safe_open from torch.nn.modules.conv import _ConvNd @@ -33,34 +32,47 @@ from transformers.pytorch_utils import id_tensor_storage from transformers.utils.hub import cached_file -from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear from ..adapter.adapter import Adapter from ..looper.named_module import NamedModule from ..models._const import ( CPU, DEVICE, - EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, EXPERT_INDEX_PLACEHOLDER, SUPPORTS_MODULE_TYPES, ) from ..nn_modules.qlinear import BaseQuantLinear -from ..nn_modules.qlinear.exllama import ExllamaQuantLinear -from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear -from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear +from ..nn_modules.qlinear.exllamav2 import ExllamaV2Linear +from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2Linear from ..quantization import FORMAT, QuantizeConfig -from ..quantization.config import FORMAT_FIELD_CHECKPOINT, METHOD, dynamic_get +from ..quantization.config import ( + FORMAT_FIELD_CODE, + METHOD, + _normalize_bitsandbytes_block_size, + _normalize_bitsandbytes_format, + _normalize_fp8_fmt, + _normalize_fp8_scale_semantics, + _normalize_fp8_weight_block_size, + _normalize_fp8_weight_scale_method, + _normalize_quant_bits, + dynamic_get, + quant_bits_width, + resolve_quant_format, +) from . import has_gil_disabled -from .backend import BACKEND +from .backend import BACKEND, normalize_backend from .ctx import ctx from .device import get_device +from .hf import get_hf_config_dtype +from .hub import hf_hub_download, model_info from .importer import select_quant_linear from .logger import log_time_block, setup_logger from .torch import HAS_CUDA, torch_empty_cache log = setup_logger() +_REQUIRES_VERSION_RE = pcre.compile(r"(<=|>=|==|<|>)\s*([\d\.]+)") _DTYPE_SAFE_MAP = { @@ -76,6 +88,30 @@ torch.bool: ("BOOL", 1), } +if hasattr(torch, "float8_e4m3fn"): + _DTYPE_SAFE_MAP[torch.float8_e4m3fn] = ("F8_E4M3", 1) +if hasattr(torch, "float8_e5m2"): + _DTYPE_SAFE_MAP[torch.float8_e5m2] = ("F8_E5M2", 1) + +_FLOAT8_DTYPE_NAMES = tuple( + name + for name in ( + "float8_e4m3fn", + "float8_e5m2", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + "float8_e8m0fnu", + ) + if hasattr(torch, name) +) +_FLOAT4_PACKED_DTYPE_NAMES = tuple( + name for name in ("float4_e2m1fn_x2",) if hasattr(torch, name) +) + +# Byte-size fallbacks keep ancillary metadata math working for torch floatx +# dtypes even when the current safetensors header schema cannot serialize them. +_DTYPE_NUM_BYTES = dict.fromkeys((*[getattr(torch, name) for name in _FLOAT8_DTYPE_NAMES], *[getattr(torch, name) for name in _FLOAT4_PACKED_DTYPE_NAMES]), 1) + _DTYPE_STR_MAP = { "float32": torch.float32, @@ -96,6 +132,24 @@ "bool": torch.bool, } +for name in _FLOAT8_DTYPE_NAMES: + dtype = getattr(torch, name) + _DTYPE_STR_MAP[name] = dtype + _DTYPE_STR_MAP[f"f8_{name.removeprefix('float8_')}"] = dtype + +if hasattr(torch, "float8_e4m3fn"): + _DTYPE_STR_MAP["f8_e4m3"] = torch.float8_e4m3fn +if hasattr(torch, "float8_e5m2"): + _DTYPE_STR_MAP["f8_e5m2"] = torch.float8_e5m2 +if hasattr(torch, "float8_e8m0fnu"): + _DTYPE_STR_MAP["float8_e8m0"] = torch.float8_e8m0fnu + _DTYPE_STR_MAP["f8_e8m0"] = torch.float8_e8m0fnu + +for name in _FLOAT4_PACKED_DTYPE_NAMES: + dtype = getattr(torch, name) + _DTYPE_STR_MAP[name] = dtype + _DTYPE_STR_MAP[f"f4_{name.removeprefix('float4_')}"] = dtype + MoETopKState = List[Tuple[nn.Module, str, int]] MOE_TOPK_FIELD_NAMES = [ @@ -110,9 +164,11 @@ def _torch_dtype_num_bytes(dtype: torch.dtype) -> int: - if dtype not in _DTYPE_SAFE_MAP: - raise NotImplementedError(f"Unsupported dtype for safetensors export: {dtype}") - return _DTYPE_SAFE_MAP[dtype][1] + if dtype in _DTYPE_SAFE_MAP: + return _DTYPE_SAFE_MAP[dtype][1] + if dtype in _DTYPE_NUM_BYTES: + return _DTYPE_NUM_BYTES[dtype] + raise NotImplementedError(f"Unsupported dtype for safetensors export: {dtype}") def _torch_dtype_to_safetensors(dtype: torch.dtype) -> str: @@ -179,7 +235,30 @@ def recurse_setattr(module, name, value): recurse_setattr(getattr(module, name), rest, value) +def _module_has_meta_tensors(module: nn.Module) -> bool: + for param in module.parameters(recurse=True): + if getattr(param, "is_meta", False) or param.device.type == "meta": + return True + for buf in module.buffers(recurse=True): + if getattr(buf, "is_meta", False) or buf.device.type == "meta": + return True + return False + + def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dtype = None): + if isinstance(obj, nn.Module) and _module_has_meta_tensors(obj): + if not accelerate.utils.has_offloaded_params(obj): + raise NotImplementedError( + "Cannot move a module that still contains meta tensors without offload hooks. " + "Materialize it first before calling move_to()." + ) + + # Accelerate disk-offloaded modules keep meta placeholders until they are + # explicitly restored, so materialize those leaves before the device move. + from .offload import undo_offload_to_disk + + return undo_offload_to_disk(obj, device=device, dtype=dtype) + if get_device(obj) != device or dtype is not None: obj = obj.to(device=device, dtype=dtype, non_blocking=False) @@ -252,20 +331,28 @@ def make_quant( pack: bool = False, device: DEVICE = None, from_quantized: bool = False, + dtype: Optional[torch.dtype] = None, ) -> Type[BaseQuantLinear]: - bits = qcfg.bits + bits = qcfg.runtime_bits group_size =qcfg.group_size extension = qcfg.adapter - format = qcfg.format + format = resolve_quant_format(qcfg.format, qcfg.method) desc_act = qcfg.desc_act sym = qcfg.sym dynamic = qcfg.dynamic pack_dtype = qcfg.pack_dtype + init_kwargs = qcfg.quant_linear_init_kwargs() - # Bitblas needs to be loaded as gptq's quant linear first, and then converted to bitblas format. - if not pack and format in (FORMAT.GPTQ, FORMAT.GPTQ_V2) and backend == BACKEND.BITBLAS: - backend = BACKEND.TORCH + export_quant_method = qcfg.export_quant_method() + backend = normalize_backend(backend, quant_method=export_quant_method) + + # BitBLAS-native checkpoints can load directly. Other formats need a compatible preload kernel first. + if not pack and backend in [BACKEND.GPTQ_BITBLAS, BACKEND.AWQ_BITBLAS]: + if format in (FORMAT.GPTQ, FORMAT.GPTQ_V2): + backend = BACKEND.GPTQ_TORCH + elif qcfg.quant_method == METHOD.AWQ and format == FORMAT.GEMM: + backend = BACKEND.AWQ_TORCH # returns multiple validated kernels quant_linear_candidates = select_quant_linear( @@ -275,11 +362,12 @@ def make_quant( sym=sym, backend=backend, format=format, - quant_method=qcfg.quant_method, + quant_method=export_quant_method, pack=pack, dynamic=dynamic, device=device, pack_dtype=pack_dtype, + dtype=dtype, multi_select=True, adapter=extension, ) @@ -308,6 +396,9 @@ def make_quant( pack_dtype=pack_dtype, backend=backend, adapter=qcfg.adapter, + format=format, + init_kwargs=init_kwargs, + dtype=dtype, ) log.info(f"Kernel: selected -> `{linear_cls.__name__}`.") return linear_cls @@ -323,7 +414,7 @@ def make_quant( def create_quant_module( name: str, linear_cls: Type[BaseQuantLinear], - bits: int, + bits, desc_act: bool, dynamic, group_size: int, @@ -333,9 +424,12 @@ def create_quant_module( device: DEVICE, lm_head_name: str, pack_dtype: torch.dtype, + format: FORMAT = FORMAT.GPTQ, backend: BACKEND = BACKEND.AUTO, register_buffers: bool = True, adapter: Optional[Adapter] = None, + init_kwargs: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, ): # unwrap named module @@ -379,11 +473,12 @@ def create_quant_module( bias = submodule.bias is not None # need copies as dynamic config may override these in for loop - tmp_bits = bits + tmp_bits = _normalize_quant_bits(bits, format_value=format) tmp_group_size = group_size tmp_desc_act = desc_act tmp_sym = sym tmp_pack_dtype = pack_dtype + tmp_init_kwargs = dict(init_kwargs or {}) # dynamic bits, group_size, sym, pack_dtype for each layer/module if dynamic is not None: @@ -395,30 +490,73 @@ def create_quant_module( # positive module match if overrides: # override base QuantizeConfig for every quant config key/value - tmp_bits = overrides.get("bits", bits) + tmp_bits = _normalize_quant_bits(overrides.get("bits", bits), format_value=format) tmp_group_size = overrides.get("group_size", group_size) tmp_desc_act = overrides.get("desc_act", desc_act) tmp_sym = overrides.get("sym", sym) tmp_pack_dtype = overrides.get("pack_dtype", pack_dtype) - # when loading a quantized model, device is target device passed in GPTQModel.load() + if format == FORMAT.FP8: + fp8_format_override = overrides.get(FORMAT_FIELD_CODE, overrides.get("fmt")) + if fp8_format_override is not None: + tmp_init_kwargs["format"] = _normalize_fp8_fmt(fp8_format_override) + block_size_override = overrides.get( + "weight_block_size", + tmp_init_kwargs.get("weight_block_size"), + ) + normalized_block_size = _normalize_fp8_weight_block_size(block_size_override) + if "weight_scale_method" in overrides or block_size_override is not None: + tmp_init_kwargs["weight_scale_method"] = _normalize_fp8_weight_scale_method( + overrides.get( + "weight_scale_method", + tmp_init_kwargs.get("weight_scale_method"), + ), + weight_block_size=normalized_block_size, + ) + if "weight_scale_semantics" in overrides: + tmp_init_kwargs["weight_scale_semantics"] = _normalize_fp8_scale_semantics( + overrides["weight_scale_semantics"] + ) + if "weight_block_size" in overrides: + tmp_init_kwargs["weight_block_size"] = normalized_block_size + elif format == FORMAT.BITSANDBYTES: + raw_format = overrides.get(FORMAT_FIELD_CODE, overrides.get("bnb_quant_type")) + if raw_format is not None: + tmp_init_kwargs["format"] = _normalize_bitsandbytes_format( + raw_format, + bits=quant_bits_width(tmp_bits), + ) + if "block_size" in overrides or "bnb_block_size" in overrides: + tmp_init_kwargs["block_size"] = _normalize_bitsandbytes_block_size( + overrides.get("block_size", overrides.get("bnb_block_size")) + ) + if "compress_statistics" in overrides or "bnb_compress_statistics" in overrides: + tmp_init_kwargs["compress_statistics"] = bool( + overrides.get("compress_statistics", overrides.get("bnb_compress_statistics")) + ) + + validate_bits = quant_bits_width(tmp_bits) + constructor_bits = tmp_bits if getattr(linear_cls, "QUANT_TYPE", None) == "gguf" else validate_bits + + # when loading a quantized model, device is the target passed through the GPT-QModel load path # check in_features and out_features validate _, err = linear_cls.validate( - bits=tmp_bits, + bits=validate_bits, group_size=tmp_group_size, desc_act=tmp_desc_act, sym=tmp_sym, pack_dtype=tmp_pack_dtype, + dtype=dtype, in_features=in_features, out_features=out_features, - device=device, + device=DEVICE(device) if isinstance(device, str) else device, adapter=adapter, # TODO FIX ME..need to pass Eora if loaded ) if err is not None: raise err new_layer = linear_cls( - bits=tmp_bits, + bits=constructor_bits, group_size=tmp_group_size, desc_act=tmp_desc_act, sym=tmp_sym, @@ -426,19 +564,21 @@ def create_quant_module( out_features=out_features, pack_dtype=tmp_pack_dtype, bias=bias, + dtype=dtype, #weight_dtype=submodule.qweight.dtype if isinstance(submodule, BaseQuantLinear) else submodule.weight.dtype, name=name, lm_head_name=lm_head_name, backend=backend, register_buffers=register_buffers, adapter=adapter, + **tmp_init_kwargs, ) new_layer.device = ori_layer_device recurse_setattr(module, name, new_layer.to(ori_layer_device)) def create_quant_layer( linear_cls: Type[BaseQuantLinear], - bits: int, + bits, desc_act: bool, dynamic, group_size: int, @@ -450,6 +590,9 @@ def create_quant_layer( pack_dtype: torch.dtype, backend: BACKEND, adapter: Optional[Adapter] = None, + format: FORMAT = FORMAT.GPTQ, + init_kwargs: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, ) -> Type[BaseQuantLinear]: if isinstance(module, linear_cls): @@ -472,8 +615,11 @@ def create_quant_layer( device=device, lm_head_name=lm_head_name, pack_dtype=pack_dtype, + format=format, backend=backend, adapter=adapter, + init_kwargs=init_kwargs, + dtype=dtype, ) return linear_cls @@ -488,7 +634,7 @@ def hf_convert_gptq_v1_to_v2_format( ) -> Tuple[nn.Module, bool]: if checkpoint_format == "gptq": # skip v1 to v2 conversion for kernels that can only operate on sym=True (gptq_v1) - if qlinear_kernel in [MarlinQuantLinear, ExllamaEoraQuantLinear]: + if qlinear_kernel is MarlinLinear: return model, False cfg = QuantizeConfig(bits=bits) @@ -588,7 +734,7 @@ def convert_gptq_v1_to_v2_format( qlinear_kernel: Type[BaseQuantLinear], ): # skip v2 to v1 conversion for gptq_v1 kernels - if cfg.quant_method in [METHOD.GPTQ] and not qlinear_kernel.REQUIRES_FORMAT_V2: + if cfg.export_quant_method() == METHOD.GPTQ and not qlinear_kernel.REQUIRES_FORMAT_V2: log.info( f"Format: Skipped v1 to v2 conversion due to Kernel `{qlinear_kernel}`.") return model @@ -597,7 +743,7 @@ def convert_gptq_v1_to_v2_format( # with tctl.threadpool_limits(limits=1): time.time() log.info( - f"Format: Converting `{FORMAT_FIELD_CHECKPOINT}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") + f"Format: Converting `{FORMAT_FIELD_CODE}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") for _, submodule in model.named_modules(): # v1 checkpoint format used to do `qzeros = qzeros -= 1` before serialization, thus the @@ -665,7 +811,7 @@ def convert_gptq_v2_to_v1_format( ): # skip v2 to v1 conversion for gptq_v1 kernels - if quantize_config.quant_method in [METHOD.GPTQ] and not qlinear_kernel.REQUIRES_FORMAT_V2: + if quantize_config.export_quant_method() == METHOD.GPTQ and not qlinear_kernel.REQUIRES_FORMAT_V2: return model # Limit thread usage to avoid auto-parallizataion regression @@ -819,8 +965,8 @@ def pack_module( if ( quantize_config is not None - and quantize_config.quant_method == METHOD.GPTQ - and quantize_config.format == FORMAT.GPTQ + and quantize_config.export_quant_method() == METHOD.GPTQ + and resolve_quant_format(quantize_config.format, quantize_config.method) == FORMAT.GPTQ and getattr(quant_linear_cls, "REQUIRES_FORMAT_V2", False) ): with log_time_block( @@ -975,24 +1121,18 @@ def hf_gptqmodel_post_init(model, use_act_order: bool, quantize_config: Quantize def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeConfig = None, max_input_length: Optional[int] = None): """ - The max_input_length argument is specific to the exllama backend, that requires to initialize a buffer temp_state. + Initialize model-persistent backend scratch buffers after quantized weights are loaded. """ - # post init for bitblas backend. - device_to_buffers_size = {} - # exllama - model_uses_exllama = False - - # exllamav2 fixed_bytes = {} model_uses_exllamav2 = False for name, submodule in model.named_modules(): - if isinstance(submodule, ExllamaV2QuantLinear): + if isinstance(submodule, ExllamaV2Linear): model_uses_exllamav2 = True device = submodule.qweight.device scratch_fixed = submodule.scratch_space_fixed() fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0)) - elif isinstance(submodule, AwqExllamaV2QuantLinear): + elif isinstance(submodule, AwqExllamaV2Linear): model_uses_exllamav2 = True device = submodule.qweight.device scratch_fixed = submodule.scratch_space_fixed( @@ -1000,86 +1140,6 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)) ) fixed_bytes[device] = max(scratch_fixed, fixed_bytes.get(device, 0)) - elif isinstance(submodule, ExllamaQuantLinear): - model_uses_exllama = True - device = submodule.qweight.device - if device not in device_to_buffers_size: - device_to_buffers_size[device] = { - "max_dq_buffer_size": 1, - "max_inner_outer_dim": 1, - } - submodule._use_act_order = True if use_act_order else False - - # Disable this heuristic for detecting act_order, but it could be used instead of the config. - """ - if submodule.g_idx is None: - submodule.act_order = False - elif submodule.g_idx is not None and ((submodule.g_idx == 0).all() or torch.equal(submodule.g_idx.cpu(), torch.tensor([i // submodule.group_size for i in range(submodule.g_idx.shape[0])], dtype=torch.int32))): - submodule.g_idx = None - submodule.act_order = False - else: - submodule.act_order = True - """ - - device_to_buffers_size[device]["max_dq_buffer_size"] = max( - device_to_buffers_size[device]["max_dq_buffer_size"], - submodule.qweight.numel() * 8, - ) - - if use_act_order: - device_to_buffers_size[device]["max_inner_outer_dim"] = max( - device_to_buffers_size[device]["max_inner_outer_dim"], - submodule.in_features, - submodule.out_features, - ) - - if model_uses_exllama: - # To be honest this is quite ugly, not proud of this. - from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params - - device_to_buffers = {} - - if use_act_order: - if max_input_length is None: - max_input_len = EXLLAMA_DEFAULT_MAX_INPUT_LENGTH - else: - max_input_len = max_input_length - else: - if max_input_length is not None: - log.info( - "Using exllama backend without act-order, the parameter max_input_length was set although not needed, it will be ignored." - ) - max_input_len = 1 - - for device, buffers_size in device_to_buffers_size.items(): - # The temp_state buffer is required to reorder X in the act-order case. - # The temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill. - device_to_buffers[device] = { - "temp_state": torch.zeros( - (max_input_len, buffers_size["max_inner_outer_dim"]), - dtype=torch.float16, - device=device, - ), - "temp_dq": torch.zeros( - (1, buffers_size["max_dq_buffer_size"]), - dtype=torch.float16, - device=device, - ), - "max_dq_buffer_size": buffers_size["max_dq_buffer_size"], - "max_inner_outer_dim": buffers_size["max_inner_outer_dim"], - } - - # Buffers need to be persistent to avoid any bug. - model.device_to_buffers = device_to_buffers - - for device, buffers in model.device_to_buffers.items(): - prepare_buffers(device, buffers["temp_state"], buffers["temp_dq"]) - - # Using the default from exllama repo here. - matmul_recons_thd = 16 - matmul_fused_remap = False - matmul_no_half2 = False - set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2) if model_uses_exllamav2: from ..utils.exllamav2 import ScratchSpace @@ -1094,7 +1154,7 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon # The buffers need to have been initialized first before calling make_q4. for _, submodule in model.named_modules(): - if isinstance(submodule, (ExllamaV2QuantLinear, AwqExllamaV2QuantLinear)): + if isinstance(submodule, (ExllamaV2Linear, AwqExllamaV2Linear)): device = submodule.qweight.device submodule.post_init(scratch_space=model.device_tensors[device]) elif isinstance(submodule, BaseQuantLinear): @@ -1102,9 +1162,6 @@ def gptqmodel_post_init(model, use_act_order: bool, quantize_config: QuantizeCon torch_empty_cache() - # if use_act_order and max_input_length and isinstance(submodule, ExllamaQuantLinear): - # model = exllama_set_max_input_length(model, max_input_length) - return model @@ -1198,8 +1255,9 @@ def auto_dtype(config: PretrainedConfig, return torch.bfloat16 # Update: latest kernel accuracies have shown, with multiple ranges of shapes - # There are no accuracy issues with bf16 vs fp16. The only kernel with severe - # regression in bf16 is MARLIN_FP16 (reduce math in fp16) which is not auto-selectable + # There are no accuracy issues with bf16 vs fp16. The Marlin reduce-accumulation + # path can use fp16, but it is disabled by default via + # GPTQMODEL_MARLIN_USE_FP32=True. # # for inference, always use FP16 for max accuracy # # check test_kernel_outputs for validation between fp16 and b16 in terms of kernel accuracy # if quant_inference: @@ -1207,7 +1265,7 @@ def auto_dtype(config: PretrainedConfig, # return torch.float16 # get dtype from config - dtype = getattr(config, "dtype") if hasattr(config, "dtype") else getattr(config, "torch_dtype") + dtype = get_hf_config_dtype(config) if dtype and not isinstance(dtype, torch.dtype): raise ValueError(f"dtype in config must be a torch.dtype, but got {dtype}") @@ -1261,12 +1319,15 @@ def copy_py_files(save_dir, file_extension=".py", model_id_or_path=""): for file in py_files: shutil.copy2(os.path.join(model_id_or_path, file), save_dir) else: - api = HfApi() - model_info = api.model_info(model_id_or_path) - for file in model_info.siblings: + remote_model_info = model_info(model_id_or_path) + for file in remote_model_info.siblings: if file.rfilename.endswith(file_extension): - _ = hf_hub_download(repo_id=model_id_or_path, filename=file.rfilename, - local_dir=save_dir) + _ = hf_hub_download( + repo_id=model_id_or_path, + filename=file.rfilename, + local_dir=save_dir, + ) + def get_model_files_size(pre_quantized_model_path, file_extension=['.bin', '.safetensors', '.pth', '.pt', '.ckpt', '.h5', '.pb', '.onnx']): if os.path.isdir(pre_quantized_model_path): @@ -1277,15 +1338,12 @@ def get_model_files_size(pre_quantized_model_path, file_extension=['.bin', '.saf 1] in file_extension ) else: - api = HfApi() - files_data = api.list_repo_files(pre_quantized_model_path) - pre_quantized_size_bytes = 0 - for file_info in files_data: - if any(file_info.endswith(ext) for ext in file_extension): - file_metadata = api.model_info(pre_quantized_model_path, files_metadata=True) - for file_data in file_metadata.siblings: - if file_data.rfilename == file_info: - pre_quantized_size_bytes += file_data.size + remote_model_info = model_info(pre_quantized_model_path, files_metadata=True) + pre_quantized_size_bytes = sum( + (file_data.size or 0) + for file_data in remote_model_info.siblings + if any(file_data.rfilename.endswith(ext) for ext in file_extension) + ) pre_quantized_size_mb = pre_quantized_size_bytes / (1024 * 1024) return pre_quantized_size_mb @@ -1297,7 +1355,7 @@ def check_requires_version(requires_version, current_version): "<": operator.lt, ">": operator.gt, } - match = re.match(r"(<=|>=|==|<|>)\s*([\d\.]+)", requires_version) + match = _REQUIRES_VERSION_RE.match(requires_version) if match: op_symbol, required_version = match.groups() current_version = version.parse(current_version) @@ -1568,7 +1626,7 @@ def get_state_dict_for_save(model: nn.Module, offload_root: Optional[str] = None if model._tied_weights_keys is not None: found = 0 for name in sorted(names): - matches_pattern = any(re.search(pat, name) for pat in model._tied_weights_keys) + matches_pattern = any(pcre.compile(pat).search(name) for pat in model._tied_weights_keys) if matches_pattern and name in state_dict: found += 1 if found < len(names): @@ -1641,6 +1699,11 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str offset += entry.num_bytes header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8") + # Safetensors pads the JSON header to an 8-byte boundary. + # Without that padding, some readers reject the file as malformed. + header_padding = (-len(header_bytes)) % 8 + if header_padding: + header_bytes += b" " * header_padding with open(path, "wb") as out: out.write(struct.pack(" dict: } +def _get_bitsandbytes_dependencies(): + try: + import bitsandbytes as bnb + except ImportError as exc: # pragma: no cover - exercised when dependency missing + raise RuntimeError( + "Support for bitsandbytes quantized checkpoints requires the " + "'bitsandbytes' package. Install it with 'pip install bitsandbytes>=0.49.3'." + ) from exc + + return bnb + + def _discover_compressed_tensors_module_schemes( model_path: Path, quant_config, @@ -138,8 +162,10 @@ def _discover_compressed_tensors_module_schemes( init_empty_weights = deps["init_empty_weights"] apply_quantization_config = deps["apply_quantization_config"] map_module_to_scheme = deps["map_module_to_scheme"] + from .hf import prepare_remote_code_compat config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + prepare_remote_code_compat(config) loader_candidates = ( deps["AutoModelForCausalLM"], @@ -158,7 +184,7 @@ def _discover_compressed_tensors_module_schemes( model = loader.from_config( config, trust_remote_code=True, - torch_dtype=torch.float32, + dtype=torch.float32, ) except Exception as exc: # pragma: no cover - depends on available loaders loader_errors.append((loader.__name__, exc)) @@ -315,8 +341,8 @@ def infer_block_shape(weight_shape: Tuple[int, int], scale_tensor: torch.Tensor) def detect_format(model_path: Path, config: dict) -> str: quant_cfg = config.get("quantization_config", {}) or {} - method = (quant_cfg.get("quant_method") or "").lower() - fmt = (quant_cfg.get("fmt") or "").lower() + method = (quant_cfg.get("method") or quant_cfg.get("quant_method") or "").lower() + format_name = (quant_cfg.get("format") or "").lower() files, _ = list_safetensor_files(model_path) if not files: @@ -328,10 +354,10 @@ def detect_format(model_path: Path, config: dict) -> str: for key in keys: if key.endswith(".weight"): tensor = reader.get_tensor(key) - if tensor.dtype == torch.float8_e4m3fn: + if tensor.dtype in _FLOAT8_DTYPES: LOG.debug("Detected FP8 weights via dtype on tensor '%s'", key) return "fp8" - if tensor.dtype == torch.uint8 and (key + "_scale") in keys: + if tensor.dtype in _NVFP4_STORAGE_DTYPES and (key + "_scale") in keys: LOG.debug("Detected NVFP4 weights via dtype on tensor '%s'", key) return "nvfp4" if any(k.endswith(".weight_packed") for k in keys): @@ -347,6 +373,14 @@ def detect_format(model_path: Path, config: dict) -> str: if any(k.endswith(".weight_scale_inv") for k in keys): LOG.debug("Detected FP8 format via '.weight_scale_inv' metadata in shard '%s'", files[0]) return "fp8" + if any(k == "weight_quant_state" or k.endswith(".weight_quant_state") for k in keys) or any( + k == "weight_scb" or k.endswith(".weight_scb") for k in keys + ): + LOG.debug("Detected bitsandbytes format via explicit state tensors in shard '%s'", files[0]) + return "bitsandbytes" + if any(k.endswith(".trellis") for k in keys): + LOG.debug("Detected EXL3 format via '.trellis' metadata in shard '%s'", files[0]) + return "exl3" if any(k.endswith(".qweight") for k in keys): has_g = any(k.endswith(".g_idx") for k in keys) LOG.debug( @@ -356,20 +390,32 @@ def detect_format(model_path: Path, config: dict) -> str: ) return "gptq" if has_g else "awq" - if fmt == "float8_e4m3fn": - LOG.debug("Detected FP8 format via config fmt=%s", fmt) + if format_name in _FLOAT8_FORMAT_NAMES: + LOG.debug("Detected FP8 format via config format=%s", format_name) + return "fp8" + if format_name in {"fp4", "nf4", "int8"}: + LOG.debug("Detected bitsandbytes format via config format=%s", format_name) + return "bitsandbytes" + if method == "fp8": + LOG.debug("Detected FP8 format via method=%s", method) return "fp8" + if method == "bitsandbytes": + LOG.debug("Detected bitsandbytes format via method=%s", method) + return "bitsandbytes" if method in ("gptq", "gptqmodel"): - LOG.debug("Detected GPTQ format via quant_method=%s", method) + LOG.debug("Detected GPTQ format via method=%s", method) return "gptq" if method == "awq": - LOG.debug("Detected AWQ format via quant_method=%s", method) + LOG.debug("Detected AWQ format via method=%s", method) return "awq" + if method == "exl3": + LOG.debug("Detected EXL3 format via method=%s", method) + return "exl3" if method == "compressed-tensors": fmt_name = (quant_cfg.get("format") or "").lower() if fmt_name == "pack-quantized": LOG.debug( - "Detected compressed-tensors format via quant_method=%s and format=%s", + "Detected compressed-tensors format via method=%s and format=%s", method, fmt_name, ) @@ -407,11 +453,12 @@ def convert_fp8_shard( target_dtype: torch.dtype, *, block_shape: Optional[Tuple[int, int]], + scale_semantics: str = "heuristic", ) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} for key in reader.keys(): tensor = reader.get_tensor(key) - if key.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn: + if key.endswith(".weight") and tensor.dtype in _FLOAT8_DTYPES: scale_key = key + "_scale_inv" if scale_key not in reader.keys(): raise KeyError(f"Missing scale inverse tensor for {key}") @@ -442,9 +489,15 @@ def convert_fp8_shard( f"Tensor {key} shape {tensor.shape} incompatible with block size {effective_block}" ) - deq = dequantize_f8_e4m3( + scale_arg = None + scale_inv_arg = scale_inv + if scale_semantics == "inverse": + scale_arg = torch.reciprocal(scale_inv.to(torch.float32)) + scale_inv_arg = None + deq = dequantize_fp8( tensor, - scale_inv=scale_inv, + scale=scale_arg, + scale_inv=scale_inv_arg, axis=None, target_dtype=target_dtype, ) @@ -461,7 +514,7 @@ def convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.Te tensors: Dict[str, torch.Tensor] = {} for key in reader.keys(): tensor = reader.get_tensor(key) - if key.endswith(".weight") and tensor.dtype == torch.uint8: + if key.endswith(".weight") and tensor.dtype in _NVFP4_STORAGE_DTYPES: scale_key = key + "_scale" if scale_key not in reader.keys(): raise KeyError(f"Missing scale tensor for {key}") @@ -474,7 +527,7 @@ def convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.Te target_dtype=target_dtype, ) tensors[key] = finalize_for_save(deq, target_dtype) - elif key.endswith("_weight_scale"): + elif key.endswith(".weight_scale"): LOG.debug("Dropping auxiliary NVFP4 tensor '%s' after dequantization", key) continue else: @@ -482,6 +535,74 @@ def convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.Te return tensors +def convert_bitsandbytes_shard( + reader, + target_dtype: torch.dtype, + *, + quant_cfg: dict, +) -> Dict[str, torch.Tensor]: + bnb = _get_bitsandbytes_dependencies() + + tensors: Dict[str, torch.Tensor] = {} + keys = list(reader.keys()) + key_set = set(keys) + bnb_quant_type = str( + quant_cfg.get("format") + or quant_cfg.get("bnb_quant_type") + or "fp4" + ).strip().lower() + if bnb_quant_type == "bitsandbytes": + bnb_quant_type = "fp4" + + skipped_suffixes = ( + ".weight_absmax", + ".weight_quant_map", + ".weight_nested_absmax", + ".weight_nested_quant_map", + ".weight_quant_state", + ".weight_scb", + ) + + for key in keys: + tensor = reader.get_tensor(key) + + if key.endswith(".weight") and (key[:-len(".weight")] + ".weight_quant_state") in key_set: + prefix = key[:-len(".weight")] + payload = { + "absmax": reader.get_tensor(prefix + ".weight_absmax"), + "quant_map": reader.get_tensor(prefix + ".weight_quant_map"), + f"quant_state.bitsandbytes__{bnb_quant_type}": reader.get_tensor(prefix + ".weight_quant_state"), + } + if prefix + ".weight_nested_absmax" in key_set: + payload["nested_absmax"] = reader.get_tensor(prefix + ".weight_nested_absmax") + if prefix + ".weight_nested_quant_map" in key_set: + payload["nested_quant_map"] = reader.get_tensor(prefix + ".weight_nested_quant_map") + + quant_state = bnb.functional.QuantState.from_dict(payload, device=tensor.device) + deq = bnb.functional.dequantize_4bit(tensor, quant_state=quant_state) + tensors[key] = finalize_for_save(deq, target_dtype) + LOG.debug("Dequantized bitsandbytes 4-bit module '%s' to dtype %s", prefix, target_dtype) + continue + + if key.endswith(".weight") and (key[:-len(".weight")] + ".weight_scb") in key_set: + prefix = key[:-len(".weight")] + deq = bnb.functional.int8_vectorwise_dequant( + tensor, + reader.get_tensor(prefix + ".weight_scb"), + ) + tensors[key] = finalize_for_save(deq, target_dtype) + LOG.debug("Dequantized bitsandbytes 8-bit module '%s' to dtype %s", prefix, target_dtype) + continue + + if key.endswith(skipped_suffixes): + LOG.debug("Dropping auxiliary bitsandbytes tensor '%s' after dequantization", key) + continue + + tensors[key] = finalize_for_save(tensor, target_dtype) + + return tensors + + def convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} module_buffers: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict) @@ -699,6 +820,7 @@ def dequantize_model( open_device = device_str or "cpu" block_shape = resolve_block_size(config) if fmt == "fp8" else None + fp8_scale_semantics = str(quant_cfg.get("weight_scale_semantics") or "heuristic").strip().lower() if block_shape is not None: LOG.debug("Configured FP8 block size %s found in quantization_config", block_shape) @@ -740,7 +862,19 @@ def dequantize_model( LOG.debug("Processing shard '%s' for format %s on device %s", filename, fmt, open_device) if fmt == "fp8": with safe_open(path, framework="pt", device=open_device) as reader: - tensors = convert_fp8_shard(reader, target_dtype, block_shape=block_shape) + tensors = convert_fp8_shard( + reader, + target_dtype, + block_shape=block_shape, + scale_semantics=fp8_scale_semantics, + ) + elif fmt == "bitsandbytes": + with safe_open(path, framework="pt", device=open_device) as reader: + tensors = convert_bitsandbytes_shard( + reader, + target_dtype, + quant_cfg=quant_cfg, + ) elif fmt == "nvfp4": with safe_open(path, framework="pt", device=open_device) as reader: tensors = convert_nvfp4_shard(reader, target_dtype) @@ -780,7 +914,8 @@ def dequantize_model( new_config = dict(config) new_config.pop("quantization_config", None) - new_config["torch_dtype"] = str(target_dtype).split(".")[-1] + new_config.pop("torch_dtype", None) + new_config["dtype"] = str(target_dtype).split(".")[-1] write_json(output_path / "config.json", new_config) skip_files = set(files) | {"config.json", "model.safetensors.index.json"} diff --git a/gptqmodel/utils/openai_server.py b/gptqmodel/utils/openai_server.py index f2a52e4ee..d394357ba 100644 --- a/gptqmodel/utils/openai_server.py +++ b/gptqmodel/utils/openai_server.py @@ -16,7 +16,7 @@ from pydantic import BaseModel except ModuleNotFoundError as exception: raise type(exception)( - "GPTQModel OpenAi serve required dependencies are not installed.", + "GPT-QModel OpenAi serve required dependencies are not installed.", "Please install via `pip install gptqmodel[openai] --no-build-isolation`.", ) @@ -53,26 +53,37 @@ class OpenAiResponse(BaseModel): @self.app.post("/v1/chat/completions", response_model=OpenAiResponse) async def create_completion(request: OpenAiRequest): try: - inputs_tensor = self.tokenizer.apply_chat_template( + model_inputs = self.tokenizer.apply_chat_template( request.messages, add_generation_prompt=True, - return_tensors='pt').to(self.model.device) + return_tensors='pt', + ) + + if isinstance(model_inputs, torch.Tensor): + model_inputs = model_inputs.to(self.model.device) + generate_inputs = {"inputs": model_inputs} + prompt_length = model_inputs.size(-1) + else: + model_inputs = model_inputs.to(self.model.device) + input_ids = model_inputs["input_ids"] + generate_inputs = dict(model_inputs) + prompt_length = input_ids.size(-1) do_sample = True if request.temperature != 0.0 else False with torch.inference_mode(): outputs = self.model.generate( - inputs_tensor, - max_length=inputs_tensor.shape[0] + request.max_tokens, + max_length=prompt_length + request.max_tokens, temperature=request.temperature, top_p=request.top_p, num_return_sequences=request.n, eos_token_id=self.tokenizer.eos_token_id, stop_strings=request.stop, - do_sample=do_sample + do_sample=do_sample, + **generate_inputs, ) generated_texts = self.tokenizer.batch_decode( - outputs[:, inputs_tensor.size(-1):], + outputs[:, prompt_length:], skip_special_tokens=True, ) @@ -96,7 +107,7 @@ async def create_completion(request: OpenAiRequest): @self.app.get("/") def read_root(): - return {"message": "GPTQModel OpenAI Compatible Server is running."} + return {"message": "GPT-QModel OpenAI Compatible Server is running."} @self.app.get("/shutdown") def shutdown(): @@ -113,19 +124,19 @@ def run_server(): if async_mode: thread = threading.Thread(target=run_server, daemon=False) thread.start() - print(f"GPTQModel OpenAi Server has started asynchronously at http://{host}:{port}.") + print(f"GPT-QModel OpenAi Server has started asynchronously at http://{host}:{port}.") else: run_server() - print(f"GPTQModel OpenAi Server has started synchronously at http://{host}:{port}.") + print(f"GPT-QModel OpenAi Server has started synchronously at http://{host}:{port}.") def shutdown(self): if self.uvicorn_server is not None: self.uvicorn_server.should_exit = True - print("GPTQModel OpenAi Server is shutting down...") + print("GPT-QModel OpenAi Server is shutting down...") def wait_until_ready(self, timeout: int = 30, check_interval: float = 0.1): start_time = time.time() while not self.uvicorn_server.started: if time.time() - start_time > timeout: - raise TimeoutError("GPTQModel OpenAi server failed to start within the specified time.") + raise TimeoutError("GPT-QModel OpenAi server failed to start within the specified time.") time.sleep(check_interval) diff --git a/gptqmodel/utils/paroquant.py b/gptqmodel/utils/paroquant.py new file mode 100644 index 000000000..33629e310 --- /dev/null +++ b/gptqmodel/utils/paroquant.py @@ -0,0 +1,627 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# ParoQuant rotation helpers adapted from the ParoQuant paper and public +# project: +# https://arxiv.org/html/2511.10645v2 +# https://github.com/z-lab/paroquant + +"""Utility helpers for ParoQuant rotations and extension loading.""" + +from __future__ import annotations + +import os +import threading +from pathlib import Path +from typing import Optional, Tuple + +import torch + +from .cpp import ( + TorchOpsJitExtension, + default_jit_cflags, + default_jit_cuda_cflags, + default_torch_ops_build_root, +) + + +_SUPPORTED_ROTATION_KERNEL_DTYPES = { + torch.float16, + torch.bfloat16, + torch.float32, +} + +# Process-local cache for resolved fused launch configs. The steady-state +# rotation path uses this to avoid sending autotune sentinels back into the +# native op on every invocation once a shape has already been measured. +_ROTATION_LAUNCH_CONFIG_CACHE: dict[tuple[object, ...], tuple[int, int]] = {} +# The cache lock protects the Python-side resolved launch map for free-threaded +# runtimes, while the serialize lock prevents duplicate native autotune passes +# from multiple threads racing the same cold shape. +_ROTATION_LAUNCH_CONFIG_CACHE_LOCK = threading.Lock() +_ROTATION_AUTOTUNE_SERIALIZE_LOCK = threading.Lock() + + +def _normalize_group_size(group_size: int, in_features: int) -> int: + """Validate and normalize a ParoQuant group size.""" + normalized = in_features if group_size == -1 else int(group_size) + if normalized <= 0: + raise ValueError(f"ParoQuant: invalid group_size `{group_size}` for in_features={in_features}.") + if in_features % normalized != 0: + raise ValueError( + f"ParoQuant: in_features ({in_features}) must be divisible by group_size ({normalized})." + ) + if normalized % 2 != 0: + raise ValueError(f"ParoQuant: group_size ({normalized}) must be even.") + return normalized + + +def build_identity_rotation_buffers( + *, + in_features: int, + group_size: int, + krot: int, + device: Optional[torch.device | str] = None, + dtype: torch.dtype = torch.float16, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Build the identity rotation buffers used as the default runtime state.""" + normalized_group_size = _normalize_group_size(group_size, in_features) + if krot <= 0: + raise ValueError(f"ParoQuant: `krot` must be positive, got {krot}.") + + pairs_row = [] + local_pairs = torch.arange(normalized_group_size, dtype=torch.int16) + num_groups = in_features // normalized_group_size + for _ in range(num_groups): + pairs_row.append(local_pairs) + pairs_single = torch.cat(pairs_row, dim=0) + pairs = pairs_single.unsqueeze(0).repeat(krot, 1) + + theta = torch.zeros((krot, in_features // 2), dtype=dtype) + channel_scales = torch.ones((1, in_features), dtype=dtype) + + if device is not None: + pairs = pairs.to(device=device) + theta = theta.to(device=device) + channel_scales = channel_scales.to(device=device) + + return pairs.contiguous(), theta.contiguous(), channel_scales.contiguous() + + +def is_identity_rotation(theta: torch.Tensor, channel_scales: Optional[torch.Tensor]) -> bool: + """Check whether a ParoQuant rotation reduces to a no-op.""" + if theta is None: + return True + if torch.count_nonzero(theta).item() != 0: + return False + if channel_scales is None: + return True + return bool(torch.all(channel_scales == 1)) + + +def apply_paroquant_rotation_reference( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + scales: Optional[torch.Tensor] = None, + group_size: int = 128, +) -> torch.Tensor: + """Pure PyTorch reference implementation of the ParoQuant input rotation.""" + orig_shape = x.shape + if x.dim() < 2: + raise ValueError(f"ParoQuant rotation expects rank >= 2, got shape {tuple(orig_shape)}.") + + x2d = x.reshape(-1, orig_shape[-1]) + hidden = x2d.shape[-1] + group_size = _normalize_group_size(group_size, hidden) + + if scales is not None: + scale_tensor = scales.reshape(1, -1).to(device=x2d.device, dtype=x2d.dtype) + out = x2d * scale_tensor + else: + out = x2d.clone() + + num_groups = hidden // group_size + half_group = group_size // 2 + pairs = pairs.to(device=x2d.device, dtype=torch.int64) + theta = theta.to(device=x2d.device, dtype=torch.float32) + + for rot_idx in range(pairs.shape[0]): + next_out = out.clone() + pair_row = pairs[rot_idx].reshape(num_groups, half_group, 2) + theta_row = theta[rot_idx].reshape(num_groups, half_group) + for group_idx in range(num_groups): + group_offset = group_idx * group_size + idx_i = group_offset + pair_row[group_idx, :, 0] + idx_j = group_offset + pair_row[group_idx, :, 1] + cos_t = torch.cos(theta_row[group_idx]).to(dtype=out.dtype) + sin_t = torch.sin(theta_row[group_idx]).to(dtype=out.dtype) + xi = out[:, idx_i] + xj = out[:, idx_j] + next_out[:, idx_i] = xi * cos_t + xj * sin_t + next_out[:, idx_j] = -xi * sin_t + xj * cos_t + out = next_out + + return out.reshape(orig_shape) + + +def _rotation_sources() -> list[str]: + """Return the native extension sources for the fused CUDA rotation op. + + Build this as a plain custom-op library instead of a Python extension. + The Python-module path pulls in pybind11 initialization that segfaults on + this host during ``PyInit_*`` even though the CUDA op itself is fine. + """ + root = Path(__file__).resolve().parents[2] / "gptqmodel_ext" / "paroquant" + return [ + str(root / "rotation.cu"), + ] + + +def _rotation_kernel_ready( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + scales: Optional[torch.Tensor], + group_size: int, +) -> bool: + """Check whether the fused CUDA rotation kernel can service this call.""" + return ( + x.device.type == "cuda" + and x.dtype in _SUPPORTED_ROTATION_KERNEL_DTYPES + and pairs.device.type == "cuda" + and theta.device.type == "cuda" + and (scales is None or scales.device.type == "cuda") + and int(group_size) in {128} + and int(theta.shape[0]) in {1, 8} + ) + + +def _rotation_visible_cuda_capabilities() -> set[tuple[int, int]]: + """Return the visible CUDA SM versions for this process.""" + + if not torch.cuda.is_available(): + return set() + return { + tuple(map(int, torch.cuda.get_device_capability(device_index))) + for device_index in range(torch.cuda.device_count()) + } + + +def _rotation_extra_cuda_cflags() -> list[str]: + flags = default_jit_cuda_cflags() + # On measured SM80/A100 builds, relaxed constexpr consistently improves the + # fused rotation kernel. The same flag regresses SM89/4090 here, so only + # enable it when the visible target set is pure SM80. + if _rotation_visible_cuda_capabilities() == {(8, 0)}: + flags.append("--expt-relaxed-constexpr") + return flags + + +# Shared singleton so ParoQuant uses the same torch.ops JIT lifecycle helpers as +# AWQ and other custom-op extensions. +_PAROQUANT_ROTATION_EXTENSION = TorchOpsJitExtension( + name="gptqmodel_paroquant_rotation", + namespace="gptqmodel_paroquant", + required_ops=("rotate", "launch_config", "clear_autotune_cache", "autotune_cache_size"), + sources=_rotation_sources, + build_root_env="GPTQMODEL_PAROQUANT_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("paroquant"), + display_name="ParoQuant rotation", + extra_cuda_cflags=_rotation_extra_cuda_cflags, + extra_cflags=lambda: default_jit_cflags(opt_level="O2"), + force_rebuild_env="GPTQMODEL_PAROQUANT_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +def _extension_api(): + from gptqmodel import extension as extension_api + + return extension_api + + +def clear_paroquant_rotation_extension_cache() -> None: + """Delete cached ParoQuant rotation JIT artifacts before the next load attempt.""" + + clear_paroquant_rotation_autotune_cache() + _PAROQUANT_ROTATION_EXTENSION.clear_cache() + + +def clear_paroquant_rotation_autotune_cache() -> None: + """Drop the native launch-plan cache used by fused rotation autotune.""" + + with _ROTATION_LAUNCH_CONFIG_CACHE_LOCK: + _ROTATION_LAUNCH_CONFIG_CACHE.clear() + clear_cache = _rotation_native_op_if_loaded("clear_autotune_cache") + if clear_cache is not None: + clear_cache() + + +def _load_rotation_extension() -> bool: + """JIT-build and load the optional fused CUDA rotation extension once.""" + + return _extension_api().is_available("paroquant") + + +def _rotation_env_int(name: str, allowed: set[int]) -> Optional[int]: + """Parse a bounded integer environment override.""" + raw = os.getenv(name) + if raw is None: + return None + try: + value = int(raw) + except ValueError: + return None + return value if value in allowed else None + + +def _rotation_env_flag(name: str, default: bool) -> bool: + """Parse a boolean environment flag used by rotation autotune.""" + raw = os.getenv(name) + if raw is None: + return default + return raw.strip().lower() not in {"0", "false", "no", "off"} + + +def _rotation_autotune_enabled() -> bool: + """Report whether fused rotation launch autotune is enabled for this process.""" + return _rotation_env_flag("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE", True) + + +def _rotation_native_op_if_loaded(op_name: str): + """Return a loaded torch.ops handle without triggering JIT compilation.""" + namespace = _PAROQUANT_ROTATION_EXTENSION._namespace_cache + if namespace is None or not hasattr(namespace, op_name): + return None + cached = _PAROQUANT_ROTATION_EXTENSION._op_cache.get(op_name) + if cached is not None: + return cached + op = getattr(namespace, op_name) + _PAROQUANT_ROTATION_EXTENSION._op_cache[op_name] = op + return op + + +def _rotation_requested_launch() -> tuple[int, int]: + """Resolve the requested fused launch policy from env overrides or autotune.""" + cta_m_override = _rotation_env_int("GPTQMODEL_PAROQUANT_ROTATE_CTA_M", {4, 8, 16}) + row_pad_override = _rotation_env_int("GPTQMODEL_PAROQUANT_ROTATE_ROW_PAD", {0, 2}) + if cta_m_override is not None or row_pad_override is not None: + return ( + -1 if cta_m_override is None else int(cta_m_override), + -1 if row_pad_override is None else int(row_pad_override), + ) + if _rotation_autotune_enabled(): + return (-2, -2) + return (-1, -1) + + +def _rotation_autotune_cache_size() -> int: + """Return the native fused-rotation autotune cache size for tests and benchmarks.""" + cache_size = _rotation_native_op_if_loaded("autotune_cache_size") + if cache_size is not None: + return int(cache_size()) + if not _load_rotation_extension(): + return 0 + return int(_extension_api().op("paroquant", "autotune_cache_size")()) + + +def _run_rotation_op( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + scales: Optional[torch.Tensor], + group_size: int, + *, + cta_m: int, + row_pad: int, +) -> torch.Tensor: + """Execute the fused rotation op with one requested launch policy.""" + return _extension_api().op("paroquant", "rotate")( + x, + pairs, + theta, + scales, + int(group_size), + int(cta_m), + int(row_pad), + ) + + +def _rotation_launch_cache_key( + x: torch.Tensor, + *, + krot: int, + has_scale: bool, + group_size: int, + requested_cta_m: int, + requested_row_pad: int, +) -> tuple[object, ...]: + """Build the process-local cache key for one resolved launch shape.""" + hidden = int(x.shape[-1]) + batch_rows = int(x.numel() // hidden) + return ( + int(x.device.index if x.device.index is not None else -1), + x.dtype, + batch_rows, + hidden, + int(group_size), + int(krot), + bool(has_scale), + int(requested_cta_m), + int(requested_row_pad), + ) + + +def _resolve_rotation_launch( + x: torch.Tensor, + *, + scales: Optional[torch.Tensor], + group_size: int, + krot: int, + requested_cta_m: int, + requested_row_pad: int, +) -> tuple[int, int]: + """Resolve one fused launch config and memoize autotuned shape choices.""" + if requested_cta_m != -2 and requested_row_pad != -2: + return int(requested_cta_m), int(requested_row_pad) + + cache_key = _rotation_launch_cache_key( + x, + krot=krot, + has_scale=scales is not None, + group_size=group_size, + requested_cta_m=requested_cta_m, + requested_row_pad=requested_row_pad, + ) + with _ROTATION_LAUNCH_CONFIG_CACHE_LOCK: + cached = _ROTATION_LAUNCH_CONFIG_CACHE.get(cache_key) + if cached is not None: + return cached + + with _ROTATION_AUTOTUNE_SERIALIZE_LOCK: + with _ROTATION_LAUNCH_CONFIG_CACHE_LOCK: + cached = _ROTATION_LAUNCH_CONFIG_CACHE.get(cache_key) + if cached is not None: + return cached + + cta_m, row_pad = _extension_api().op("paroquant", "launch_config")( + x, + int(krot), + scales is not None, + int(group_size), + int(requested_cta_m), + int(requested_row_pad), + ) + resolved = (int(cta_m), int(row_pad)) + with _ROTATION_LAUNCH_CONFIG_CACHE_LOCK: + cached = _ROTATION_LAUNCH_CONFIG_CACHE.setdefault(cache_key, resolved) + return cached + + +def _rotation_launch_config( + x: torch.Tensor, + pairs: Optional[torch.Tensor] = None, + theta: Optional[torch.Tensor] = None, + scales: Optional[torch.Tensor] = None, + group_size: int = 128, + *, + extension_loaded: bool = False, + kernel_ready: bool = False, +) -> tuple[int, int]: + """Query the native fused-kernel launch shape for one CUDA tensor.""" + del pairs, kernel_ready + if x.device.type != "cuda": + raise ValueError("ParoQuant launch config requires a CUDA tensor.") + if not extension_loaded and not _load_rotation_extension(): + raise RuntimeError("ParoQuant launch config requires the fused rotation extension.") + requested_cta_m, requested_row_pad = _rotation_requested_launch() + krot = 8 if theta is None else int(theta.shape[0]) + if requested_cta_m == -2 or requested_row_pad == -2: + cta_m, row_pad = _resolve_rotation_launch( + x, + scales=scales, + group_size=int(group_size), + krot=krot, + requested_cta_m=requested_cta_m, + requested_row_pad=requested_row_pad, + ) + else: + cta_m, row_pad = _extension_api().op("paroquant", "launch_config")( + x, + int(krot), + scales is not None, + int(group_size), + int(requested_cta_m), + int(requested_row_pad), + ) + return int(cta_m), int(row_pad) + + +def _apply_fused_rotation( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + scales: Optional[torch.Tensor], + group_size: int, +) -> torch.Tensor: + """Run the fused rotation op with explicit overrides or native autotune sentinels.""" + requested_cta_m, requested_row_pad = _rotation_requested_launch() + cta_m, row_pad = _resolve_rotation_launch( + x, + scales=scales, + group_size=int(group_size), + krot=int(theta.shape[0]), + requested_cta_m=requested_cta_m, + requested_row_pad=requested_row_pad, + ) + return _run_rotation_op( + x, + pairs, + theta, + scales, + group_size, + cta_m=cta_m, + row_pad=row_pad, + ) + + +def prewarm_paroquant_rotation_extension( + *, + fused_rotation: bool, + group_size: int, + krot: int, + device: Optional[torch.device | str] = None, +) -> bool: + """Eagerly build the fused rotation extension before timed quantization starts.""" + if not fused_rotation: + return False + if int(group_size) not in {128}: + return False + if int(krot) not in {1, 8}: + return False + + if device is not None and torch.device(device).type != "cuda": + return False + if device is None and not torch.cuda.is_available(): + return False + + return _load_rotation_extension() + + +def apply_paroquant_rotation( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + scales: Optional[torch.Tensor] = None, + group_size: int = 128, +) -> torch.Tensor: + """Apply the fused rotation when available, else fall back to the reference path.""" + if _rotation_kernel_ready(x, pairs, theta, scales, group_size) and _load_rotation_extension(): + return _apply_fused_rotation(x, pairs, theta, scales, int(group_size)) + return apply_paroquant_rotation_reference(x, pairs, theta, scales=scales, group_size=group_size) + + +class _ParoQuantRotateTensorFunc(torch.autograd.Function): + """Autograd wrapper around the fused ParoQuant CUDA rotation kernel.""" + + @staticmethod + def forward( + ctx, + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + scales: Optional[torch.Tensor] = None, + group_size: int = 128, + ) -> torch.Tensor: + scale_tensor = None if scales is None else scales.contiguous() + ctx.orig_shape = x.shape + ctx.orig_dtype = x.dtype + ctx.has_scale = scale_tensor is not None + ctx.group_size = int(group_size) + ctx.needs_x_grad = bool(x.requires_grad) + ctx.needs_theta_grad = bool(theta.requires_grad) + ctx.needs_scale_grad = bool(scale_tensor is not None and scale_tensor.requires_grad) + + y = _apply_fused_rotation(x, pairs, theta, scale_tensor, int(group_size)) + saved = (x, pairs, theta, y, scale_tensor) if ctx.has_scale else (x, pairs, theta, y) + ctx.save_for_backward(*saved) + return y + + @staticmethod + def backward(ctx, grad_out: torch.Tensor): + saved_tensors = ctx.saved_tensors + x, pairs, theta, y = saved_tensors[:4] + scale_tensor = saved_tensors[4] if ctx.has_scale else None + group_size = ctx.group_size + + _, hidden = pairs.shape + num_groups = hidden // group_size + half_group = group_size // 2 + batch_rows = y.numel() // hidden + grad = grad_out.reshape(batch_rows, hidden) + grad_theta = None + + if ctx.needs_theta_grad: + rotated = y.reshape(batch_rows, hidden) + grad_theta = torch.zeros_like(theta) + offsets = ( + torch.arange(num_groups, device=pairs.device, dtype=torch.long).unsqueeze(1) * group_size + ) + + for rot_idx in range(pairs.shape[0] - 1, -1, -1): + pair_row = pairs.narrow(0, rot_idx, 1) + neg_theta = theta.narrow(0, rot_idx, 1).neg() + rotated = _apply_fused_rotation(rotated, pair_row, neg_theta, None, group_size) + grad = _apply_fused_rotation(grad, pair_row, neg_theta, None, group_size) + + pair_view = pair_row.reshape(num_groups, group_size) + idx_i = (pair_view[:, 0::2] + offsets).reshape(-1) + idx_j = (pair_view[:, 1::2] + offsets).reshape(-1) + + xi = rotated[:, idx_i].reshape(batch_rows, num_groups, half_group) + xj = rotated[:, idx_j].reshape(batch_rows, num_groups, half_group) + grad_i = grad[:, idx_i].reshape(batch_rows, num_groups, half_group) + grad_j = grad[:, idx_j].reshape(batch_rows, num_groups, half_group) + + theta_view = theta.narrow(0, rot_idx, 1).reshape(num_groups, half_group) + sin_t = theta_view.sin() + cos_t = theta_view.cos() + grad_theta[rot_idx] = ( + ( + (grad_i * xj - grad_j * xi).sum(0) * cos_t + - (grad_i * xi + grad_j * xj).sum(0) * sin_t + ) + .reshape(-1) + .to(theta.dtype) + ) + else: + for rot_idx in range(pairs.shape[0] - 1, -1, -1): + pair_row = pairs.narrow(0, rot_idx, 1) + neg_theta = theta.narrow(0, rot_idx, 1).neg() + grad = _apply_fused_rotation(grad, pair_row, neg_theta, None, group_size) + + if ctx.has_scale: + scale_flat = scale_tensor.reshape(-1) + grad_x = None + if ctx.needs_x_grad: + grad_x = (grad * scale_flat.unsqueeze(0)).reshape(ctx.orig_shape).to(ctx.orig_dtype) + grad_scale = None + if ctx.needs_scale_grad: + grad_scale = (x.reshape(batch_rows, hidden) * grad).sum(0).to( + dtype=scale_tensor.dtype, + device=scale_tensor.device, + ) + grad_scale = grad_scale.reshape_as(scale_tensor) + else: + grad_x = grad.reshape(ctx.orig_shape).to(ctx.orig_dtype) if ctx.needs_x_grad else None + grad_scale = None + + return grad_x, None, grad_theta, grad_scale, None + + +def apply_paroquant_rotation_autograd( + x: torch.Tensor, + pairs: torch.Tensor, + theta: torch.Tensor, + scales: Optional[torch.Tensor] = None, + group_size: int = 128, +) -> torch.Tensor: + """Apply the fused rotation with custom backward support when available.""" + if _rotation_kernel_ready(x, pairs, theta, scales, group_size) and _load_rotation_extension(): + return _ParoQuantRotateTensorFunc.apply(x, pairs, theta, scales, int(group_size)) + return apply_paroquant_rotation_reference(x, pairs, theta, scales=scales, group_size=group_size) + + +__all__ = [ + "apply_paroquant_rotation", + "apply_paroquant_rotation_autograd", + "apply_paroquant_rotation_reference", + "_rotation_launch_config", + "build_identity_rotation_buffers", + "clear_paroquant_rotation_autotune_cache", + "clear_paroquant_rotation_extension_cache", + "is_identity_rotation", + "prewarm_paroquant_rotation_extension", +] diff --git a/gptqmodel/utils/paroquant_benchmark.py b/gptqmodel/utils/paroquant_benchmark.py new file mode 100644 index 000000000..2ad817a45 --- /dev/null +++ b/gptqmodel/utils/paroquant_benchmark.py @@ -0,0 +1,910 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +import os +import tempfile +import time +from pathlib import Path +from typing import Any, Optional + +import torch +from tabulate import tabulate + +from gptqmodel import GPTQModel +from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear +from gptqmodel.nn_modules.qlinear.paroquant_triton import ParoQuantTritonLinear +from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.utils.backend import BACKEND + + +_NM_CALIBRATION_PATH = "/monster/data/model/dataset/nm-calibration" +_NM_CALIBRATION_PARQUET = Path("/monster/data/model/dataset/nm-calibration/llm.parquet") +_DEFAULT_MODEL = "/monster/data/model/Llama-3.2-1B-Instruct" + + +def _normalize_model_dtype(model_dtype: Any) -> torch.dtype | str | None: + if model_dtype is None: + return None + if isinstance(model_dtype, str): + normalized = model_dtype.strip().lower() + mapping = { + "fp16": torch.float16, + "float16": torch.float16, + "half": torch.float16, + "bf16": torch.bfloat16, + "bfloat16": torch.bfloat16, + "fp32": torch.float32, + "float32": torch.float32, + } + return mapping.get(normalized, model_dtype) + return model_dtype + + +def _dtype_label(model_dtype: Any) -> str: + normalized = _normalize_model_dtype(model_dtype) + mapping = { + torch.float16: "fp16", + torch.bfloat16: "bf16", + torch.float32: "fp32", + } + return mapping.get(normalized, str(normalized)) + + +def _visible_cuda_device_name() -> str: + if not torch.cuda.is_available(): + return "cpu" + return torch.cuda.get_device_name(torch.device("cuda:0")) + + +def _single_gpu_device_map() -> dict[str, str] | str: + return {"": "cuda:0"} if torch.cuda.is_available() else "cpu" + + +def _load_local_calibration_parquet(): + if not _NM_CALIBRATION_PARQUET.exists(): + raise FileNotFoundError(f"Calibration parquet not found at {_NM_CALIBRATION_PARQUET}") + + try: + import pandas as pd + except ImportError: + pd = None + + if pd is not None: + records = pd.read_parquet(_NM_CALIBRATION_PARQUET).to_dict(orient="records") + else: + import pyarrow.parquet as pq + + records = pq.read_table(_NM_CALIBRATION_PARQUET).to_pylist() + + normalized = [] + for record in records: + item = {} + for key, value in dict(record).items(): + if hasattr(value, "tolist") and not isinstance(value, (str, bytes)): + value = value.tolist() + item[key] = value + normalized.append(item) + return normalized + + +def load_nm_calibration(rows: int) -> list[dict[str, Any]]: + try: + from datasets import load_dataset + except Exception: + dataset = _load_local_calibration_parquet() + else: + try: + dataset = load_dataset(path=_NM_CALIBRATION_PATH, name="LLM", split="train") + except Exception: + dataset = _load_local_calibration_parquet() + + if rows <= 0: + return list(dataset) + if hasattr(dataset, "select"): + dataset = dataset.select(range(min(rows, len(dataset)))) + else: + dataset = list(dataset)[:rows] + return list(dataset) + + +def _layers_node_info(model) -> tuple[str, int]: + layers_node = model.extract_layers_node() + if isinstance(layers_node, (list, tuple)): + if not layers_node: + raise ValueError("Model did not expose a layers node for ParoQuant benchmarking.") + layers_node = layers_node[0] + layers_node = str(layers_node).strip() + if not layers_node: + raise ValueError("Model layers node resolved to an empty string.") + escaped_layers_node = layers_node.replace(".", r"\.") + layer_count = len(getattr(model.model, "model", model.model).layers) + return escaped_layers_node, layer_count + + +def build_prefix_layer_dynamic(model, num_quant_layers: int) -> dict[str, dict[str, Any]]: + escaped_layers_node, layer_count = _layers_node_info(model) + num_quant_layers = int(num_quant_layers) + if num_quant_layers <= 0: + raise ValueError("ParoQuant benchmark: `num_quant_layers` must be positive.") + if num_quant_layers > layer_count: + raise ValueError( + f"ParoQuant benchmark: `num_quant_layers` ({num_quant_layers}) exceeds model layer count ({layer_count})." + ) + return { + f"-:^{escaped_layers_node}\\.{layer_idx}\\.": {} + for layer_idx in range(num_quant_layers, layer_count) + } + + +def build_first_layer_only_dynamic(model) -> dict[str, dict[str, Any]]: + return build_prefix_layer_dynamic(model, num_quant_layers=1) + + +def build_single_module_dynamic(model, *, layer_idx: int, module_name: str) -> dict[str, dict[str, Any]]: + return build_selected_modules_dynamic(model, layer_idx=layer_idx, module_names=[module_name]) + + +def build_selected_modules_dynamic( + model, + *, + layer_idx: int, + module_names: list[str] | tuple[str, ...], +) -> dict[str, dict[str, Any]]: + escaped_layers_node, layer_count = _layers_node_info(model) + layer_idx = int(layer_idx) + if layer_idx < 0 or layer_idx >= layer_count: + raise ValueError(f"ParoQuant benchmark: `layer_idx` ({layer_idx}) is outside [0, {layer_count - 1}].") + + layers = getattr(model.model, "model", model.model).layers + layer = layers[layer_idx] + normalized_module_names: list[str] = [] + escaped_module_names: list[str] = [] + for module_name in module_names: + current = layer + parts = [part for part in str(module_name).strip().split(".") if part] + if not parts: + raise ValueError("ParoQuant benchmark: `module_names` must contain non-empty relative module paths.") + for part in parts: + if not hasattr(current, part): + raise ValueError(f"ParoQuant benchmark: layer {layer_idx} does not expose module path `{module_name}`.") + current = getattr(current, part) + normalized_module_name = ".".join(parts) + normalized_module_names.append(normalized_module_name) + escaped_module_names.append(normalized_module_name.replace(".", r"\.")) + + if not escaped_module_names: + raise ValueError("ParoQuant benchmark: at least one module must be selected.") + + selected_pattern = "|".join(sorted(set(escaped_module_names))) + dynamic: dict[str, dict[str, Any]] = {} + for idx in range(layer_count): + if idx == layer_idx: + # Quantize only the selected leaf modules inside the target layer. + dynamic[f"-:^{escaped_layers_node}\\.{idx}\\.(?!(?:{selected_pattern})$)"] = {} + else: + dynamic[f"-:^{escaped_layers_node}\\.{idx}\\."] = {} + return dynamic + + +def make_paroquant_config( + *, + dynamic: dict[str, dict[str, Any]], + sym: bool = True, + bits: int = 4, + group_size: int = 128, + krot: int = 8, + opt_scope: str = "module", + opt_rotation_epochs: int = 10, + opt_finetune_epochs: int = 10, + opt_train_samples: int = 2048, + opt_validation_samples: int = 64, + opt_batch_size: int = 64, + opt_optimizer: str = "adamw", + opt_weight_decay: float = 0.01, + opt_betas: tuple[float, float] = (0.9, 0.95), + opt_eps: float = 1e-10, + opt_amsgrad: bool = False, + opt_sgd_momentum: float = 0.0, + opt_sgd_dampening: float = 0.0, + opt_sgd_nesterov: bool = False, + opt_gradient_checkpointing: Optional[bool] = None, + offload_to_disk: bool = False, +) -> QuantizeConfig: + if sym is not True: + raise ValueError("ParoQuant benchmark: `sym=False` is disabled; use `sym=True`.") + return QuantizeConfig( + method=METHOD.PARO, + format=FORMAT.PAROQUANT, + bits=bits, + group_size=group_size, + sym=sym, + desc_act=False, + krot=krot, + dynamic=dynamic, + offload_to_disk=offload_to_disk, + device="cuda:0" if torch.cuda.is_available() else "cpu", + opt_scope=opt_scope, + opt_rotation_epochs=opt_rotation_epochs, + opt_finetune_epochs=opt_finetune_epochs, + opt_train_samples=opt_train_samples, + opt_validation_samples=opt_validation_samples, + opt_batch_size=opt_batch_size, + opt_optimizer=opt_optimizer, + opt_weight_decay=opt_weight_decay, + opt_betas=opt_betas, + opt_eps=opt_eps, + opt_amsgrad=opt_amsgrad, + opt_sgd_momentum=opt_sgd_momentum, + opt_sgd_dampening=opt_sgd_dampening, + opt_sgd_nesterov=opt_sgd_nesterov, + opt_gradient_checkpointing=opt_gradient_checkpointing, + ) + + +def _cleanup_model(model) -> None: + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _prepare_eval_tokenizer(model) -> None: + tokenizer = getattr(model, "tokenizer", None) + if tokenizer is None: + return + if getattr(tokenizer, "padding_side", None) != "left": + tokenizer.padding_side = "left" + if getattr(tokenizer, "pad_token_id", None) is None: + eos_token_id = getattr(tokenizer, "eos_token_id", None) + if eos_token_id is not None: + tokenizer.pad_token_id = eos_token_id + + +def _suite_kwargs( + max_rows: Optional[int], + suite_kwargs: Optional[dict[str, Any]] = None, +) -> dict[str, Any] | None: + merged = dict(suite_kwargs or {}) + if max_rows is not None: + merged["max_rows"] = int(max_rows) + return merged or None + + +def _load_eval_helpers(): + try: + from tests.eval import evaluate, format_eval_result_table, get_eval_task_results + except Exception as exc: # pragma: no cover - depends on test environment layout + raise ModuleNotFoundError( + "Evaluation helpers are now located in `tests/eval.py`. " + "Import this module from a test checkout when running ParoQuant benchmarks." + ) from exc + + return evaluate, format_eval_result_table, get_eval_task_results + + +def _run_evalution_path_eval( + *, + model_or_id_or_path: Any, + eval_batch_size: int | str, + eval_max_rows: Optional[int], + model_dtype: Any = None, + backend: BACKEND | str | None = None, + eval_model_args: Optional[dict[str, Any]] = None, + eval_suite_kwargs: Optional[dict[str, Any]] = None, +) -> tuple[dict[str, Any], float]: + model_args = {"padding_side": "left"} + normalized_dtype = _normalize_model_dtype(model_dtype) + if normalized_dtype is not None: + model_args["dtype"] = normalized_dtype + if eval_model_args: + model_args.update(eval_model_args) + evaluate, _, _ = _load_eval_helpers() + wall_start = time.perf_counter() + eval_result = evaluate( + model_or_id_or_path=model_or_id_or_path, + tasks=["gsm8k_platinum_cot"], + backend=backend, + batch_size=eval_batch_size, + model_args=model_args, + apply_chat_template=True, + suite_kwargs=_suite_kwargs(eval_max_rows, eval_suite_kwargs), + ) + return eval_result, time.perf_counter() - wall_start + + +def run_fp16_eval( + *, + model_path: str = _DEFAULT_MODEL, + eval_batch_size: int = 64, + eval_max_rows: Optional[int] = None, +) -> dict[str, Any]: + return run_dense_eval( + model_path=model_path, + model_dtype=torch.float16, + eval_batch_size=eval_batch_size, + eval_max_rows=eval_max_rows, + ) + + +def run_dense_eval( + *, + model_path: str = _DEFAULT_MODEL, + model_dtype: Any = torch.float16, + eval_batch_size: int = 64, + eval_max_rows: Optional[int] = None, +) -> dict[str, Any]: + normalized_dtype = _normalize_model_dtype(model_dtype) + eval_result, wall_s = _run_evalution_path_eval( + model_or_id_or_path=model_path, + eval_batch_size=eval_batch_size, + eval_max_rows=eval_max_rows, + model_dtype=normalized_dtype, + ) + _, format_eval_result_table, get_eval_task_results = _load_eval_helpers() + metrics = get_eval_task_results(eval_result) + formatted = format_eval_result_table(eval_result) + return { + "mode": f"dense_{_dtype_label(normalized_dtype)}", + "dtype": _dtype_label(normalized_dtype), + "eval_wall_s": wall_s, + "metrics": metrics, + "eval_table": formatted, + } + + +def _prehook_capture_inputs(module_names: set[str], max_rows: int = 256): + captured: dict[str, torch.Tensor] = {} + hooks = [] + + def make_hook(name: str): + def hook(_module, inputs): + if name in captured or not inputs: + return + x = inputs[0].detach() + x = x.reshape(-1, x.shape[-1])[:max_rows].contiguous() + captured[name] = x + + return hook + + return captured, hooks, make_hook + + +def _tokenize_calibration_sample(model, sample: dict[str, Any]) -> dict[str, torch.Tensor]: + if "input_ids" in sample: + input_ids = torch.as_tensor(sample["input_ids"], dtype=torch.long) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + attention_mask = sample.get("attention_mask") + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + else: + attention_mask = torch.as_tensor(attention_mask, dtype=torch.long) + if attention_mask.ndim == 1: + attention_mask = attention_mask.unsqueeze(0) + return { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + + tokenizer = model.tokenizer + if "messages" in sample: + rendered = tokenizer.apply_chat_template( + sample["messages"], + tokenize=False, + add_generation_prompt=False, + ) + tokenized = tokenizer(rendered, add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + if "text" in sample: + tokenized = tokenizer(sample["text"], add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + raise ValueError(f"Unsupported calibration sample keys for ParoQuant kernel benchmark: {sorted(sample.keys())}") + + +def _clone_triton_module(module: ParoLinear) -> ParoQuantTritonLinear: + cloned = ParoQuantTritonLinear( + bits=module.bits, + group_size=module.group_size, + sym=module.sym, + desc_act=module.desc_act, + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + register_buffers=True, + krot=module.krot, + fp32_accum=module.fp32_accum, + ).to(device=module.qweight.device) + cloned.qweight.copy_(module.qweight) + cloned.qzeros.copy_(module.qzeros) + cloned.scales.copy_(module.scales) + if module.bias is not None: + cloned.bias.copy_(module.bias) + cloned.pairs.copy_(module.pairs) + cloned.theta.copy_(module.theta) + cloned.channel_scales.copy_(module.channel_scales) + cloned.post_init() + cloned.eval() + return cloned + + +def _dense_forward(module: ParoLinear, x: torch.Tensor) -> torch.Tensor: + with torch.inference_mode(): + x_flat = x.reshape(-1, x.shape[-1]) + rotated = module._rotate_inputs(x_flat) + return module._forward_dense(rotated) + + +def _benchmark_module_ms(module, x: torch.Tensor, warmup: int = 5, iters: int = 20) -> float: + with torch.inference_mode(): + for _ in range(warmup): + module(x) + if x.device.type == "cuda": + torch.cuda.synchronize(x.device) + start = time.perf_counter() + for _ in range(iters): + module(x) + if x.device.type == "cuda": + torch.cuda.synchronize(x.device) + return (time.perf_counter() - start) * 1e3 / iters + + +def benchmark_quantized_first_layer_kernels( + model, + calibration_dataset: list[dict[str, Any]], + *, + capture_rows: int = 256, + warmup: int = 5, + iters: int = 20, +) -> list[list[str]]: + qmodules = { + name: module + for name, module in model.named_modules() + if isinstance(module, ParoLinear) + } + if not qmodules: + return [] + + captured, hooks, make_hook = _prehook_capture_inputs(set(qmodules), max_rows=capture_rows) + for name, module in qmodules.items(): + hooks.append(module.register_forward_pre_hook(make_hook(name))) + + try: + sample = calibration_dataset[0] + tokenized = _tokenize_calibration_sample(model, sample) + model_device = next(model.model.parameters()).device + with torch.inference_mode(): + model.model( + input_ids=tokenized["input_ids"].to(device=model_device), + attention_mask=tokenized["attention_mask"].to(device=model_device), + ) + finally: + for hook in hooks: + hook.remove() + + rows = [] + for name, module in qmodules.items(): + x = captured.get(name) + if x is None or x.numel() == 0: + continue + x = x.to(device=module.qweight.device, dtype=x.dtype) + triton_module = _clone_triton_module(module) + with torch.inference_mode(): + dense = _dense_forward(module, x) + cuda_out = module(x).reshape_as(dense) + triton_out = triton_module(x).reshape_as(dense) + + cuda_diff = (cuda_out - dense).abs() + triton_diff = (triton_out - dense).abs() + cuda_vs_triton = (cuda_out - triton_out).abs() + + dense_ms = _benchmark_module_ms(lambda inp: _dense_forward(module, inp), x, warmup=warmup, iters=iters) + cuda_ms = _benchmark_module_ms(module, x, warmup=warmup, iters=iters) + triton_ms = _benchmark_module_ms(triton_module, x, warmup=warmup, iters=iters) + + rows.append( + [ + name, + str(tuple(x.shape)), + f"{cuda_diff.max().item():.6f}", + f"{cuda_diff.mean().item():.6f}", + f"{triton_diff.max().item():.6f}", + f"{triton_diff.mean().item():.6f}", + f"{cuda_vs_triton.max().item():.6f}", + f"{dense_ms:.3f}", + f"{cuda_ms:.3f}", + f"{triton_ms:.3f}", + ] + ) + del triton_module + + return rows + + +def _region_rows(snapshot: dict[str, dict[str, Any]]) -> list[list[str]]: + populated = [ + (name, stat) + for name, stat in snapshot.items() + if int(stat.get("count", 0)) + ] + total = sum(float(stat.get("total", 0.0)) for _, stat in populated) or 1.0 + populated.sort(key=lambda item: float(item[1].get("total", 0.0)), reverse=True) + return [ + [ + name, + str(int(stat.get("count", 0))), + f"{float(stat.get('last', 0.0)):.3f}", + f"{float(stat.get('total', 0.0)) / max(int(stat.get('count', 0)), 1):.3f}", + f"{float(stat.get('total', 0.0)):.3f}", + f"{100.0 * float(stat.get('total', 0.0)) / total:.1f}%", + str(stat.get("source") or ""), + ] + for name, stat in populated + ] + + +def _module_time_rows(quant_logs: dict[str, list[dict[str, Any]]]) -> list[list[str]]: + rows = [] + for entry in quant_logs.get("paroquant", []): + rows.append( + [ + str(entry.get("layer", "")), + str(entry.get("module", "")), + str(entry.get("feat: in, out", "")), + str(entry.get("samples", "")), + str(entry.get("loss", "")), + str(entry.get("time", "")), + ] + ) + rows.sort(key=lambda item: float(item[-1] or 0.0), reverse=True) + return rows + + +def _run_paroquant_case( + *, + model_path: str, + dynamic: dict[str, dict[str, Any]], + model_dtype: Any = torch.float16, + calibration_rows: int, + calibration_concat_size: int, + quant_batch_size: int, + eval_batch_size: int | str, + eval_max_rows: Optional[int], + eval_model_args: Optional[dict[str, Any]], + eval_suite_kwargs: Optional[dict[str, Any]], + sym: bool, + fused_opt_rotation: bool, + opt_scope: str, + opt_rotation_epochs: int, + opt_finetune_epochs: int, + opt_train_samples: int, + opt_validation_samples: int, + opt_batch_size: int, + result_meta: Optional[dict[str, Any]] = None, + eval_backend: BACKEND | str | None = None, + run_kernel_bench: bool = True, +) -> dict[str, Any]: + if sym is not True: + raise ValueError("ParoQuant benchmark: `sym=False` is disabled; use `sym=True`.") + os.environ["GPTQMODEL_PAROQUANT_OPT_FUSED_ROTATION"] = "1" if fused_opt_rotation else "0" + normalized_dtype = _normalize_model_dtype(model_dtype) + + calibration_dataset = load_nm_calibration(calibration_rows) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + qcfg = make_paroquant_config( + dynamic=dynamic, + sym=sym, + opt_scope=opt_scope, + opt_rotation_epochs=opt_rotation_epochs, + opt_finetune_epochs=opt_finetune_epochs, + opt_train_samples=opt_train_samples, + opt_validation_samples=opt_validation_samples, + opt_batch_size=opt_batch_size, + ) + model = GPTQModel.load( + model_path, + quantize_config=qcfg, + trust_remote_code=False, + dtype=normalized_dtype, + ) + _prepare_eval_tokenizer(model) + try: + quant_start = time.perf_counter() + quant_logs = model.quantize( + calibration_dataset, + calibration_concat_size=calibration_concat_size, + calibration_sort="desc", + batch_size=quant_batch_size, + ) + quant_wall_s = time.perf_counter() - quant_start + + if torch.cuda.is_available(): + model.model.to("cuda:0") + + kernel_rows = benchmark_quantized_first_layer_kernels(model, calibration_dataset) if run_kernel_bench else [] + with tempfile.TemporaryDirectory(prefix="paroquant_evalution_") as temp_dir: + save_start = time.perf_counter() + model.save(temp_dir) + save_wall_s = time.perf_counter() - save_start + eval_result, eval_wall_s = _run_evalution_path_eval( + model_or_id_or_path=temp_dir, + eval_batch_size=eval_batch_size, + eval_max_rows=eval_max_rows, + model_dtype=normalized_dtype, + backend=eval_backend, + eval_model_args=eval_model_args, + eval_suite_kwargs=eval_suite_kwargs, + ) + _, format_eval_result_table, get_eval_task_results = _load_eval_helpers() + + result = { + "mode": "paroquant_prefix_layers", + "device": _visible_cuda_device_name(), + "dtype": _dtype_label(normalized_dtype), + "fused_opt_rotation": fused_opt_rotation, + "opt_scope": opt_scope, + "sym": sym, + "quant_wall_s": quant_wall_s, + "save_wall_s": save_wall_s, + "eval_wall_s": eval_wall_s, + "quant_logs": quant_logs, + "quant_region_snapshot": model.quant_region_timer.snapshot(), + "module_time_rows": _module_time_rows(quant_logs), + "region_rows": _region_rows(model.quant_region_timer.snapshot()), + "eval_metrics": get_eval_task_results(eval_result), + "eval_table": format_eval_result_table(eval_result), + "kernel_rows": kernel_rows, + } + if result_meta: + result.update(result_meta) + return result + finally: + _cleanup_model(model) + + +def run_paroquant_first_layer_case( + *, + model_path: str = _DEFAULT_MODEL, + num_quant_layers: int = 1, + model_dtype: Any = torch.float16, + calibration_rows: int = 64, + calibration_concat_size: int = 2048, + quant_batch_size: int = 1, + eval_batch_size: int | str = 64, + eval_max_rows: Optional[int] = None, + eval_model_args: Optional[dict[str, Any]] = None, + eval_suite_kwargs: Optional[dict[str, Any]] = None, + sym: bool = True, + fused_opt_rotation: bool = True, + opt_scope: str = "module", + opt_rotation_epochs: int = 10, + opt_finetune_epochs: int = 10, + opt_train_samples: int = 2048, + opt_validation_samples: int = 64, + opt_batch_size: int = 64, +) -> dict[str, Any]: + normalized_dtype = _normalize_model_dtype(model_dtype) + probe_model = GPTQModel.load( + model_path, + quantize_config=QuantizeConfig(method=METHOD.PARO, format=FORMAT.PAROQUANT), + trust_remote_code=False, + dtype=normalized_dtype, + device_map=_single_gpu_device_map(), + ) + dynamic = build_prefix_layer_dynamic(probe_model, num_quant_layers=num_quant_layers) + _cleanup_model(probe_model) + + return _run_paroquant_case( + model_path=model_path, + dynamic=dynamic, + model_dtype=normalized_dtype, + calibration_rows=calibration_rows, + calibration_concat_size=calibration_concat_size, + quant_batch_size=quant_batch_size, + eval_batch_size=eval_batch_size, + eval_max_rows=eval_max_rows, + eval_model_args=eval_model_args, + eval_suite_kwargs=eval_suite_kwargs, + sym=sym, + fused_opt_rotation=fused_opt_rotation, + opt_scope=opt_scope, + opt_rotation_epochs=opt_rotation_epochs, + opt_finetune_epochs=opt_finetune_epochs, + opt_train_samples=opt_train_samples, + opt_validation_samples=opt_validation_samples, + opt_batch_size=opt_batch_size, + eval_backend=BACKEND.PAROQUANT_TRITON if normalized_dtype == torch.bfloat16 else None, + run_kernel_bench=normalized_dtype != torch.bfloat16, + result_meta={ + "mode": "paroquant_prefix_layers", + "num_quant_layers": int(num_quant_layers), + "opt_scope": opt_scope, + }, + ) + + +def run_paroquant_single_module_case( + *, + model_path: str = _DEFAULT_MODEL, + layer_idx: int, + module_name: str, + model_dtype: Any = torch.float16, + calibration_rows: int = 64, + calibration_concat_size: int = 2048, + quant_batch_size: int = 1, + eval_batch_size: int | str = 64, + eval_max_rows: Optional[int] = None, + eval_model_args: Optional[dict[str, Any]] = None, + eval_suite_kwargs: Optional[dict[str, Any]] = None, + sym: bool = True, + fused_opt_rotation: bool = True, + opt_scope: str = "module", + opt_rotation_epochs: int = 10, + opt_finetune_epochs: int = 10, + opt_train_samples: int = 2048, + opt_validation_samples: int = 64, + opt_batch_size: int = 64, +) -> dict[str, Any]: + return run_paroquant_selected_modules_case( + model_path=model_path, + layer_idx=layer_idx, + module_names=[module_name], + model_dtype=model_dtype, + calibration_rows=calibration_rows, + calibration_concat_size=calibration_concat_size, + quant_batch_size=quant_batch_size, + eval_batch_size=eval_batch_size, + eval_max_rows=eval_max_rows, + eval_model_args=eval_model_args, + eval_suite_kwargs=eval_suite_kwargs, + sym=sym, + fused_opt_rotation=fused_opt_rotation, + opt_scope=opt_scope, + opt_rotation_epochs=opt_rotation_epochs, + opt_finetune_epochs=opt_finetune_epochs, + opt_train_samples=opt_train_samples, + opt_validation_samples=opt_validation_samples, + opt_batch_size=opt_batch_size, + ) + + +def run_paroquant_selected_modules_case( + *, + model_path: str = _DEFAULT_MODEL, + layer_idx: int, + module_names: list[str] | tuple[str, ...], + model_dtype: Any = torch.float16, + calibration_rows: int = 64, + calibration_concat_size: int = 2048, + quant_batch_size: int = 1, + eval_batch_size: int | str = 64, + eval_max_rows: Optional[int] = None, + eval_model_args: Optional[dict[str, Any]] = None, + eval_suite_kwargs: Optional[dict[str, Any]] = None, + sym: bool = True, + fused_opt_rotation: bool = True, + opt_scope: str = "module", + opt_rotation_epochs: int = 10, + opt_finetune_epochs: int = 10, + opt_train_samples: int = 2048, + opt_validation_samples: int = 64, + opt_batch_size: int = 64, +) -> dict[str, Any]: + normalized_dtype = _normalize_model_dtype(model_dtype) + probe_model = GPTQModel.load( + model_path, + quantize_config=QuantizeConfig(method=METHOD.PARO, format=FORMAT.PAROQUANT), + trust_remote_code=False, + dtype=normalized_dtype, + device_map=_single_gpu_device_map(), + ) + dynamic = build_selected_modules_dynamic(probe_model, layer_idx=layer_idx, module_names=module_names) + _cleanup_model(probe_model) + + return _run_paroquant_case( + model_path=model_path, + dynamic=dynamic, + model_dtype=normalized_dtype, + calibration_rows=calibration_rows, + calibration_concat_size=calibration_concat_size, + quant_batch_size=quant_batch_size, + eval_batch_size=eval_batch_size, + eval_max_rows=eval_max_rows, + eval_model_args=eval_model_args, + eval_suite_kwargs=eval_suite_kwargs, + sym=sym, + fused_opt_rotation=fused_opt_rotation, + opt_scope=opt_scope, + opt_rotation_epochs=opt_rotation_epochs, + opt_finetune_epochs=opt_finetune_epochs, + opt_train_samples=opt_train_samples, + opt_validation_samples=opt_validation_samples, + opt_batch_size=opt_batch_size, + eval_backend=BACKEND.PAROQUANT_TRITON if normalized_dtype == torch.bfloat16 else None, + run_kernel_bench=normalized_dtype != torch.bfloat16, + result_meta={ + "mode": "paroquant_selected_modules", + "layer_idx": int(layer_idx), + "module_name": ",".join(str(name) for name in module_names), + "module_names": [str(name) for name in module_names], + "opt_scope": opt_scope, + }, + ) + + +def comparison_rows(*cases: dict[str, Any]) -> list[list[str]]: + rows = [] + for case in cases: + metric = case.get("metrics") or case.get("eval_metrics") or {} + gsm8k = metric.get("gsm8k_platinum_cot", {}) + score = gsm8k.get("acc,num") + label = case.get("label") or case.get("mode", "") + rows.append( + [ + label, + str(case.get("opt_scope", "")), + str(case.get("sym", "")), + str(case.get("fused_opt_rotation", "")), + "" if score is None else f"{float(score):.6f}", + "" if "quant_wall_s" not in case else f"{float(case['quant_wall_s']):.3f}", + "" if "eval_wall_s" not in case else f"{float(case['eval_wall_s']):.3f}", + ] + ) + return rows + + +def render_case_tables(case: dict[str, Any]) -> dict[str, str]: + return { + "comparison": tabulate( + comparison_rows(case), + headers=["case", "opt_scope", "sym", "fused_opt", "gsm8k_platinum_cot", "quant_wall_s", "eval_wall_s"], + tablefmt="grid", + ), + "module_times": tabulate( + case.get("module_time_rows", []), + headers=["layer", "module", "feat", "samples", "loss", "time_s"], + tablefmt="grid", + ), + "regions": tabulate( + case.get("region_rows", []), + headers=["region", "count", "last_s", "avg_s", "total_s", "pct", "source"], + tablefmt="grid", + ), + "kernels": tabulate( + case.get("kernel_rows", []), + headers=[ + "module", + "input_shape", + "cuda_max_abs", + "cuda_mean_abs", + "triton_max_abs", + "triton_mean_abs", + "cuda_vs_triton_max_abs", + "dense_ms", + "cuda_ms", + "triton_ms", + ], + tablefmt="grid", + ), + "eval": case.get("eval_table", ""), + } + + +def write_case_json(case: dict[str, Any], output_path: str | os.PathLike[str]) -> None: + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + json.dump(case, handle, indent=2, sort_keys=True) diff --git a/gptqmodel/utils/perplexity.py b/gptqmodel/utils/perplexity.py index 806ce66ac..00b5339f5 100644 --- a/gptqmodel/utils/perplexity.py +++ b/gptqmodel/utils/perplexity.py @@ -3,51 +3,33 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import math import sys +from typing import Iterable, List, Tuple -import numpy as np import torch +import torch.nn.functional as F from datasets import load_dataset, load_from_disk from logbar import LogBar logger = LogBar.shared() + class Perplexity: """ - A class for calculating the perplexity of a language model. + A helper for calculating next-token perplexity over a text corpus. """ def __init__( - self, - model, - tokenizer, - dataset_path="wikitext", - dataset_name=None, - split="test", - text_column="text", + self, + model, + tokenizer, + dataset_path: str = "wikitext", + dataset_name: str | None = None, + split: str = "test", + text_column: str = "text", ): - """ - Calculate perplexity using the same method as seen in llama.cpp. - - Parameters - ---------- - model : AutoModelForCausalLM - The language model for which the perplexity is calculated. - tokenizer : AutoTokenizer - The tokenizer corresponding to the model. - device : str, optional - The device to run the calculations on. If auto, the device that your model uses - will be the device used for these calculations. Default is 'auto'. - dataset_path : str, optional - The path to the dataset on the Hugging Face dataset hub. Default is 'wikitext'. - dataset_name : str, optional - The name of the dataset. Default is None. - split : str, optional - The split of the dataset to use. Default is 'test'. - text_column : str, optional - The name of the column in the dataset that contains the text data. Default is 'text'. - """ self._model = model self._tokenizer = tokenizer self._dataset_path = dataset_path @@ -56,27 +38,13 @@ def __init__( self._text_column = text_column self._text = self._prepare_data() - def _get_device(self): - if torch.backends.mps.is_available(): - return "mps" - elif torch.cuda.is_available(): - return "cuda:0" - else: - return "cpu" - - def _prepare_data(self): + def _prepare_data(self) -> str: """ - Prepares the dataset by loading and formatting. - - Returns - ------- - str - The formatted dataset as a single string. + Load the requested dataset and concatenate a bounded number of samples. """ if self._dataset_path == "wikitext": self._dataset_name = "wikitext-2-raw-v1" - # Load the dataset length = 512 if self._dataset_path == "wikitext" else 2048 if self._dataset_path.startswith("/") or self._dataset_path.startswith("./"): if self._dataset_path.endswith(".gz"): @@ -86,163 +54,132 @@ def _prepare_data(self): else: data = load_dataset(self._dataset_path, self._dataset_name, split=self._split) - datas = [] - for index, sample in enumerate(data): + datas: List[str] = [] + for sample in data: text = sample[self._text_column] if len(text) >= length: - # Format the text column of the dataset datas.append(" \n" if text == "" else text) if len(datas) >= 1024: break return "".join(datas) - @staticmethod - def softmax(logits): - """ - Static method for applying the softmax function. + def _model_device(self) -> torch.device: + model_device = getattr(self._model, "device", None) + if model_device is not None: + try: + resolved = torch.device(model_device) + if resolved.type != "meta": + return resolved + except (RuntimeError, TypeError, ValueError): + pass - Parameters - ---------- - logits : torch.Tensor - The input to the softmax function. + try: + first_param = next(self._model.parameters()) + except (AttributeError, StopIteration): + return torch.device("cpu") - Returns - ------- - np.ndarray - The output of the softmax function. - """ - e_x = torch.exp(logits - torch.max(logits)) - return e_x / torch.sum(e_x, dim=0) + return first_param.device if first_param.device.type != "meta" else torch.device("cpu") - def calculate(self, n_ctx=512, n_batch=512): - """ - Calculates the perplexity of the language model. - - Parameters - ---------- - n_ctx : int - The context size. - n_batch : int - The batch size. - - Returns - ------- - list - The list of perplexity scores calculated. - """ - # Tokenize the text + def _tokenize(self) -> torch.Tensor: self._tokenizer.model_max_length = sys.maxsize - tokens = self._tokenizer(self._text, truncation=False, return_tensors="pt").input_ids.to(self._model.device) - - nll = 0.0 # Negative log likelihood - count = 0 # Counter for processed tokens - curr_ppl = 0 - all_perplexity = [] - - with logger.pb(range(len(tokens[0]) // n_ctx)).title("Perplexity: - ").manual() as pb: - for i in pb: - # Process each batch of tokens - nll, count = self._process_batch(i, n_ctx, n_batch, tokens, nll, count) + return self._tokenizer(self._text, truncation=False, return_tensors="pt").input_ids + + def _pad_token_id(self) -> int: + pad_token_id = getattr(self._tokenizer, "pad_token_id", None) + if pad_token_id is None: + pad_token_id = getattr(self._tokenizer, "eos_token_id", None) + if pad_token_id is None: + pad_token_id = 0 + return int(pad_token_id) + + def _build_windows(self, tokens: torch.Tensor, starts: Iterable[int], n_ctx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + windows = [] + for start in starts: + window = tokens[start : min(start + n_ctx + 1, tokens.shape[0])] + if window.numel() >= 2: + windows.append(window) + + if not windows: + empty = torch.empty(0, dtype=tokens.dtype) + return empty, empty, empty + + max_len = max(window.numel() - 1 for window in windows) + batch_size = len(windows) + pad_token_id = self._pad_token_id() + + input_ids = torch.full((batch_size, max_len), pad_token_id, dtype=tokens.dtype) + attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long) + labels = torch.full((batch_size, max_len), -100, dtype=torch.long) + + for row, window in enumerate(windows): + inputs = window[:-1] + targets = window[1:] + seq_len = inputs.numel() + input_ids[row, :seq_len] = inputs + labels[row, :seq_len] = targets.long() + attention_mask[row, :seq_len] = 1 + + return input_ids, attention_mask, labels + + def calculate(self, n_ctx: int = 512, n_batch: int = 512) -> List[float]: + """ + Calculate cumulative perplexity values across the evaluation corpus. - # Calculate and display the current perplexity - curr_ppl = np.exp(nll / count) + `n_ctx` is the maximum context length per sequence window. `n_batch` + acts as an approximate token budget per forward pass; windows are batched + together in groups of `max(1, n_batch // n_ctx)`. + """ + if n_ctx < 2: + raise ValueError("Perplexity.calculate: `n_ctx` must be >= 2.") + if n_batch <= 0: + raise ValueError("Perplexity.calculate: `n_batch` must be > 0.") + + flat_tokens = self._tokenize()[0].cpu() + if flat_tokens.numel() < 2: + return [] + + device = self._model_device() + windows_per_forward = max(1, n_batch // n_ctx) + starts = list(range(0, flat_tokens.numel() - 1, n_ctx)) + + nll = 0.0 + count = 0 + all_perplexity: List[float] = [] + + with logger.pb(range(0, len(starts), windows_per_forward)).title("Perplexity: - ").manual() as pb: + for offset in pb: + input_ids, attention_mask, labels = self._build_windows( + flat_tokens, + starts[offset : offset + windows_per_forward], + n_ctx, + ) + if input_ids.numel() == 0: + continue + + input_ids = input_ids.to(device) + attention_mask = attention_mask.to(device) + labels = labels.to(device) + + with torch.inference_mode(): + outputs = self._model(input_ids=input_ids, attention_mask=attention_mask) + + logits = outputs.logits.float() + loss = F.cross_entropy( + logits.reshape(-1, logits.shape[-1]), + labels.reshape(-1), + ignore_index=-100, + reduction="sum", + ) + valid_tokens = int(torch.count_nonzero(labels != -100).item()) + if valid_tokens == 0: + continue + + nll += float(loss.item()) + count += valid_tokens + + curr_ppl = math.exp(nll / count) all_perplexity.append(curr_ppl) pb.title(f"Perplexity: {curr_ppl:.4f}").draw() return all_perplexity - - def _process_batch(self, i, n_ctx, n_batch, tokens, nll, count): - """ - Processes each batch of tokens. - - Parameters - ---------- - i : int - The batch index. - n_ctx : int - The context size. - n_batch : int - The batch size. - tokens : torch.Tensor - The tokenized text. - nll : float - The current negative log likelihood. - count : int - The current count of processed tokens. - - Returns - ------- - float - The updated negative log likelihood. - int - The updated count of processed tokens. - """ - start = i * n_ctx - end = start + n_ctx - - num_batches = (n_ctx + n_batch - 1) // n_batch - - logits = [] - - for j in range(num_batches): - batch_start = start + j * n_batch - batch_size = min(end - batch_start, n_batch) - - token_org = tokens[0][batch_start].item() - - if j == 0: - # some models do not set/use bos_token - if self._tokenizer.bos_token_id is not None: - # Replace the first token with the BOS token - tokens[0][batch_start] = self._tokenizer.bos_token_id - - # Compute the logits for the current batch of tokens - batch_logits = self._compute_batch_logits(tokens, batch_start, batch_size) - - tokens[0][batch_start] = token_org - - logits.append(batch_logits) - - # We rely on the fact that attention in the forward pass only looks at previous - # tokens here, so the logits returned for each token are an accurate representation - # of what the model would have predicted at that point. - # - # Example, we have a context window of 512, we will compute perplexity for each of the - # last 256 tokens. Then, we split the input up into context window size chunks to - # process the entire prompt. - - for j in range(min(512, n_ctx // 2), n_ctx - 1): - tok_logits = logits[0][0][j] - - # Compute the probability of the next token - prob = self.softmax(tok_logits)[tokens[0][start + j + 1]] - - # Update the negative log likelihood and the count of processed tokens - nll += -torch.log(torch.where(prob > 0, prob, torch.tensor(1e-8))).item() - count += 1 - - return nll, count - - def _compute_batch_logits(self, tokens, batch_start, batch_size): - """ - Computes the logits for a batch of tokens. - - Parameters - ---------- - tokens : torch.Tensor - The tokenized text. - batch_start : int - The start index of the batch. - batch_size : int - The size of the batch. - - Returns - ------- - torch.Tensor - The logits for the batch of tokens. - """ - # Compute the logits without keeping track of gradients - with torch.inference_mode(): - outputs = self._model(tokens[:, batch_start: batch_start + batch_size]) - return outputs.logits.detach() diff --git a/gptqmodel/utils/python.py b/gptqmodel/utils/python.py index 9cfd494ee..b509a713b 100644 --- a/gptqmodel/utils/python.py +++ b/gptqmodel/utils/python.py @@ -5,6 +5,7 @@ import platform import sys +import sysconfig from packaging.version import Version @@ -13,15 +14,27 @@ log = setup_logger() +# Check if this Python build supports free-threading / GIL control. +# Starting from python 3.13 it is possible to disable GIL at build time. +def is_free_threading_build(): + """Return True when Python was built with free-threading support.""" + py_gil_disabled = sysconfig.get_config_var("Py_GIL_DISABLED") + try: + return int(py_gil_disabled or 0) == 1 + except (TypeError, ValueError): + return False + + # Check if GIL (global interpreter lock) is controllable in this Python build. -# Starting from python 3.13 it is possible to disable GIL +# Starting from python 3.13 it is possible to disable GIL. def has_gil_control(): - return hasattr(sys, '_is_gil_enabled') + return is_free_threading_build() -# Check if GIL (global interpreter lock) is enabled. -# Starting from python 3.13 it is possible to disable GIL +# Check if GIL (global interpreter lock) is enabled at runtime. +# Starting from python 3.13 it is possible to disable GIL. def has_gil_disabled(): - return has_gil_control() and not sys._is_gil_enabled() + gil_enabled = getattr(sys, "_is_gil_enabled", None) + return has_gil_control() and callable(gil_enabled) and not gil_enabled() # Check For Python >= 3.13.3 def gte_python_3_13_3(): diff --git a/gptqmodel/utils/qqq.py b/gptqmodel/utils/qqq.py new file mode 100644 index 000000000..96c183380 --- /dev/null +++ b/gptqmodel/utils/qqq.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pathlib import Path + +from .cpp import ( + TorchOpsJitExtension, + default_jit_cflags, + default_jit_cuda_cflags, + default_torch_ops_build_root, +) + + +_QQQ_OPS_NAME = "gptqmodel_qqq_ops" +_QQQ_OPS_NAMESPACE = "gptqmodel_qqq" + + +def _qqq_sources() -> list[str]: + root = Path(__file__).resolve().parents[2] / "gptqmodel_ext" / "qqq" + return [ + str(root / "qqq.cpp"), + str(root / "qqq_gemm.cu"), + ] + + +def _qqq_include_paths() -> list[str]: + root = Path(__file__).resolve().parents[2] / "gptqmodel_ext" / "qqq" + return [str(root)] + + +def _qqq_extra_cflags() -> list[str]: + return default_jit_cflags(enable_bf16=True) + + +def _qqq_extra_cuda_cflags() -> list[str]: + return default_jit_cuda_cflags( + enable_bf16=True, + include_lineinfo=True, + include_nvcc_threads=True, + include_ptxas_optimizations=True, + include_diag_suppress=True, + ) + + +_QQQ_TORCH_OPS_EXTENSION = TorchOpsJitExtension( + name=_QQQ_OPS_NAME, + namespace=_QQQ_OPS_NAMESPACE, + required_ops=("qqq_gemm",), + sources=_qqq_sources, + build_root_env="GPTQMODEL_QQQ_BUILD_ROOT", + default_build_root=lambda: default_torch_ops_build_root("qqq"), + display_name="QQQ", + extra_cflags=_qqq_extra_cflags, + extra_cuda_cflags=_qqq_extra_cuda_cflags, + extra_include_paths=_qqq_include_paths, + force_rebuild_env="GPTQMODEL_QQQ_FORCE_REBUILD", + verbose_env="GPTQMODEL_EXT_VERBOSE", + requires_cuda=True, +) + + +def _extension_api(): + from gptqmodel import extension as extension_api + + return extension_api + + +def clear_qqq_extension_cache() -> None: + _QQQ_TORCH_OPS_EXTENSION.clear_cache() + + +def qqq_runtime_available() -> bool: + return _extension_api().is_available("qqq") + + +def qqq_runtime_error() -> str: + extension_api = _extension_api() + if extension_api.is_available("qqq"): + return "" + return extension_api.error("qqq") or "QQQ CUDA runtime unavailable." + + +def prewarm_qqq_extension() -> bool: + return _extension_api().load(name="qqq")["qqq"] + + +def qqq_gemm( + A, + B, + C, + D, + s1, + s2, + s3, + workspace, + thread_k=-1, + thread_n=-1, + sms=-1, + max_par=16, +): + return _extension_api().op("qqq", "qqq_gemm")( + A, + B, + C, + D, + s1, + s2, + s3, + workspace, + thread_k, + thread_n, + sms, + max_par, + ) + + +__all__ = [ + "clear_qqq_extension_cache", + "prewarm_qqq_extension", + "qqq_gemm", + "qqq_runtime_available", + "qqq_runtime_error", +] diff --git a/gptqmodel/utils/random_str.py b/gptqmodel/utils/random_str.py index 9c35db3d0..c368d3bcb 100644 --- a/gptqmodel/utils/random_str.py +++ b/gptqmodel/utils/random_str.py @@ -1,7 +1,8 @@ -import random +import secrets import string def get_random_string(length: int = 8) -> str: """Generate a random string of fixed length with lowercase English letters.""" - return ''.join(random.choices(string.ascii_lowercase, k=length)) + alphabet = string.ascii_lowercase + return ''.join(secrets.choice(alphabet) for _ in range(length)) diff --git a/gptqmodel/utils/stream.py b/gptqmodel/utils/stream.py index 38a48bc58..04f8322f6 100644 --- a/gptqmodel/utils/stream.py +++ b/gptqmodel/utils/stream.py @@ -195,9 +195,9 @@ def stream_tensor_dict_to_cpu( # store_callback(host_map) # return host_map - first = next(iter(filtered.values())) + first_cuda = next((tensor for tensor in filtered.values() if tensor.device.type == "cuda"), None) - if first.device.type != "cuda" or not torch.cuda.is_available(): + if first_cuda is None or not torch.cuda.is_available(): host_map = {name: tensor.detach().to("cpu") for name, tensor in filtered.items()} with state_lock: store_callback(host_map) @@ -205,18 +205,25 @@ def stream_tensor_dict_to_cpu( host_map: Dict[str, torch.Tensor] = {} - copy_device = first.device + copy_device = first_cuda.device compute_stream = torch.cuda.current_stream(device=copy_device) copy_stream = _get_cached_copy_stream(copy_device) - done_event = torch.cuda.Event(enable_timing=False, blocking=False) pending_sources: List[torch.Tensor] = [] + pending_keys: List[str] = [] with torch.cuda.stream(copy_stream): copy_stream.wait_stream(compute_stream) for name, tensor in filtered.items(): src = tensor.detach() + if src.device.type != "cuda": + host_map[name] = src.to("cpu") + continue + if src.device != copy_device: + host_map[name] = src.to("cpu") + continue src.record_stream(copy_stream) pending_sources.append(src) + pending_keys.append(name) host = torch.empty( src.shape, dtype=src.dtype, @@ -226,12 +233,19 @@ def stream_tensor_dict_to_cpu( ) host.copy_(src, non_blocking=True) host_map[name] = host + + if not pending_sources: + with state_lock: + store_callback(host_map) + return host_map + + done_event = torch.cuda.Event(enable_timing=False, blocking=False) done_event.record(copy_stream) ticket = StreamCopyTicket( event=done_event, device=copy_device, - keys=tuple(host_map.keys()), + keys=tuple(pending_keys), sources=pending_sources, stream=copy_stream, ) diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py index 6889c7f3b..083372b14 100644 --- a/gptqmodel/utils/structure.py +++ b/gptqmodel/utils/structure.py @@ -22,12 +22,20 @@ Notes: - Detects shared submodules and avoids re-printing them. - Collapsing is generic: any numeric-indexed ModuleList whose qualified name matches `experts-regex`. +- Large layer stacks are capped to the first 4 children by default. """ -from typing import Dict, Iterable, Optional, Set, Tuple +import copy +import inspect +import json +import os +import threading +from dataclasses import dataclass +from typing import Any, Dict, Iterable, Optional, Set, Tuple -import pcre as re +import pcre import torch +from safetensors import safe_open from torch import nn from ..utils.logger import setup_logger @@ -59,10 +67,21 @@ def _maybe(s: str, code: str, *, color: bool) -> str: torch.int32: 4, torch.int: 4, torch.bool: 1, } -if hasattr(torch, "float8_e4m3fn"): - _DTYPE_BYTES[torch.float8_e4m3fn] = 1 -if hasattr(torch, "float8_e5m2"): - _DTYPE_BYTES[torch.float8_e5m2] = 1 +for _dtype_name in ( + *[ + name + for name in ( + "float8_e4m3fn", + "float8_e5m2", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + "float8_e8m0fnu", + ) + if hasattr(torch, name) + ], + *[name for name in ("float4_e2m1fn_x2",) if hasattr(torch, name)], +): + _DTYPE_BYTES[getattr(torch, _dtype_name)] = 1 class _FakeDType: """Sentinel dtype for experimental 4-bit formats.""" @@ -212,6 +231,8 @@ def print_module_tree( collapse_experts: bool = True, experts_regex: str = r"(^|\.)experts($|\.)", experts_show: int = 1, + layers_regex: str = r"(^|\.)((model_)?layers|layer|h|blocks|block)($|\.)", + layers_show: Optional[int] = 4, ): """ Pretty-print a module tree with sizes, devices, dtypes, and optional param/buffer details. @@ -244,8 +265,6 @@ def depth_color(depth: int) -> str: "bfloat16":"\033[35m", # magenta "float16": "\033[33m", # yellow "half": "\033[33m", # yellow (alias) - "float8_e4m3fn": "\033[34m", # blue - "float8_e5m2": "\033[34m", # blue "MXFP4": "\033[36m", # cyan (sentinel 4-bit) "NVFP4": "\033[36m", # cyan (sentinel 4-bit) "int8": "\033[31m", # red @@ -257,6 +276,21 @@ def depth_color(depth: int) -> str: "bool": "\033[37m", # white/gray "-": "\033[37m", # white/gray (unknown) } + for _dtype_name in ( + *[ + name + for name in ( + "float8_e4m3fn", + "float8_e5m2", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + "float8_e8m0fnu", + ) + if hasattr(torch, name) + ], + *[name for name in ("float4_e2m1fn_x2",) if hasattr(torch, name)], + ): + DTYPE_COLOR[_dtype_name] = "\033[34m" if _dtype_name.startswith("float8_") else "\033[36m" DEVICE_COLOR = { "cpu": "\033[37m", # white/gray "cuda": "\033[32m", # green @@ -357,24 +391,46 @@ def _line(kind: str, name: str, t: torch.Tensor) -> str: # ------------------------------------------------------------------ # Setup + utilities # ------------------------------------------------------------------ - _ = re.compile(filter_regex) if filter_regex else None # reserved for future - experts_name_re = re.compile(experts_regex) if collapse_experts else None + _ = pcre.compile(filter_regex) if filter_regex else None # reserved for future + experts_path_re = pcre.compile(experts_regex) + layers_name_re = pcre.compile(layers_regex) if layers_show is not None else None seen: Set[int] = set() total_p = sum(p.numel() for p in model.parameters()) total_b = sum(b.numel() for b in model.buffers()) # fixed loop variable - def should_collapse(qual_name: str, container: nn.Module) -> bool: - if not experts_name_re: - return False - if not experts_name_re.search(qual_name): - return False + def numeric_children(container: nn.Module): if not isinstance(container, (nn.ModuleList, nn.Sequential)): - return False - names = [n for n, _ in container.named_children()] - if not names: - return False - return all(n.isdigit() for n in names) and len(names) > max(0, experts_show) + return None + children = list(container.named_children()) + if not children: + return None + if not all(name.isdigit() for name, _ in children): + return None + return children + + def collapse_spec(qual_name: str, container: nn.Module): + children = numeric_children(container) + if children is None: + return None + + total_children = len(children) + if experts_path_re.search(qual_name): + if not collapse_experts: + return None + + show_count = max(0, experts_show) + if total_children > show_count: + return children, show_count, "expert" + return None + + if layers_name_re is None or not layers_name_re.search(qual_name): + return None + + show_count = max(0, layers_show) + if total_children > show_count: + return children, show_count, "layer" + return None def _format_line(prefix: str, trunk: str, qual_name: str, mod: nn.Module, show_counts: bool, depth: int) -> str: @@ -416,50 +472,28 @@ def rec(mod: nn.Module, name: str, depth: int, prefix: str, is_last: bool): elif show_params or show_buffers: print_params_with_colors(param_indent, mod, include_buffers=show_buffers) + collapse = collapse_spec(name, mod) + if collapse is not None: + children, show_count, item_label = collapse + shown_children = children[:max(0, min(show_count, len(children)))] + for i, (child_name, child) in enumerate(shown_children): + child_is_last = (i == len(shown_children) - 1) and (len(shown_children) == len(children)) + rec(child, f"{name}.{child_name}", depth + 1, indent, child_is_last) + + if len(shown_children) < len(children) and len(children) > 0: + p_one, b_one = _param_summary(children[0][1], recurse=True) + collapsed = ( + f"• … collapsed (repeats {len(shown_children)}..{len(children)-1}, " + f"per-{item_label} P={human_count(p_one)} B={human_count(b_one)})" + ) + print(_maybe(indent + collapsed, DIM, color=color)) + return + children = list(mod.named_children()) n = len(children) for i, (child_name, child) in enumerate(children): last = (i == n - 1) - child_prefix = prefix + (" " if is_last else "│ ") - display_name = f"{name}.{child_name}" if name else child_name - - if should_collapse(display_name, child): - line2 = _format_line(child_prefix, "└─ " if last else "├─ ", - display_name, child, True, depth+1) - annot2 = colorize_annotation(_annotate(child, color=color)) - print(line2 + " " + annot2) - - sub_children = list(child.named_children()) - total_k = len(sub_children) - k_show = max(0, min(experts_show, total_k)) - - for j, (sub_name, sub_mod) in enumerate(sub_children[:k_show]): - sub_last = (j == k_show - 1) and (k_show == total_k) - sub_prefix = child_prefix + (" " if last else "│ ") - sub_trunk = "└─ " if sub_last else "├─ " - line3 = _format_line(sub_prefix, sub_trunk, - f"{display_name}.{sub_name}", - sub_mod, True, depth+2) - annot3 = colorize_annotation(_annotate(sub_mod, color=color)) - print(line3 + " " + annot3) - rec( - sub_mod, - f"{display_name}.{sub_name}", - depth + 2, - child_prefix + (" " if last else "│ "), - sub_last, - ) - - if k_show < total_k and total_k > 0: - p_one, b_one = _param_summary(sub_children[0][1], recurse=True) - collapsed = ( - f"• … collapsed (repeats {k_show}..{total_k-1}, " - f"per-expert P={human_count(p_one)} B={human_count(b_one)})" - ) - print(_maybe(child_prefix + (" " if last else "│ ") + collapsed, DIM, color=color)) - continue - - rec(child, display_name, depth + 1, child_prefix, last) + rec(child, f"{name}.{child_name}" if name else child_name, depth + 1, indent, last) # ------------------------------------------------------------------ # Root print + recursion @@ -517,172 +551,1223 @@ def _ensure_target_storage_on_device_(param: torch.nn.Parameter, device: torch.d param.data = param.data.to(device, copy=True) # alloc new storage on device; keeps Parameter identity return param + +@dataclass(frozen=True) +class _MoEAliasSpec: + """MoE alias groups derived entirely from the model definition's `module_tree`.""" + + runtime_root_path: tuple[str, ...] + root_alias_paths: tuple[tuple[str, ...], ...] + runtime_experts_path: tuple[str, ...] + expert_alias_paths: tuple[tuple[str, ...], ...] + runtime_leaf_groups: tuple[tuple[str, ...], ...] + leaf_alias_groups: tuple[tuple[tuple[str, ...], ...], ...] + + +class LazyTurtle: + """Checkpoint-backed shell materializer for local safetensors models. + + The traditional offload path builds a meta shell model and then instantiates + a full CPU "turtle" model from `from_pretrained()` so submodules can be + copied over on demand. For very large local sharded checkpoints this upfront + load is dominated by walking every shard. + + This source keeps only the checkpoint index in memory and materializes the + requested shell submodule directly from the relevant safetensors shards. + """ + + supports_reload = False + is_lazy_checkpoint_source = True + def __init__( + self, + *, + model_local_path: str, + config: Any, + model_init_kwargs: Optional[Dict[str, Any]] = None, + module_tree: Optional[Any] = None, + ) -> None: + self.model_local_path = model_local_path + self.config = copy.deepcopy(config) + self._model_init_kwargs = dict(model_init_kwargs or {}) + self._weight_map = self._load_weight_map(model_local_path) + # Lazy checkpoint name resolution must come from model-definition truth. + self._module_tree = copy.deepcopy(module_tree) + self._module_tree_layer_prefix, self._moe_alias_specs = self._build_moe_alias_specs(self._module_tree) + self._lock = threading.RLock() + + @classmethod + def maybe_create( + cls, + *, + model_local_path: Optional[str], + config: Any, + model_init_kwargs: Optional[Dict[str, Any]] = None, + module_tree: Optional[Any] = None, + ) -> Optional["LazyTurtle"]: + if not model_local_path or not os.path.isdir(model_local_path): + return None + + try: + return cls( + model_local_path=model_local_path, + config=config, + model_init_kwargs=model_init_kwargs, + module_tree=module_tree, + ) + except Exception as exc: + log.debug( + "LazyTurtle: disabled for `%s`: %s", + model_local_path, + exc, + ) + return None + + def eval(self) -> "LazyTurtle": + return self + + def materialize_submodule( + self, + *, + target_model: torch.nn.Module, + target_submodule: torch.nn.Module, + device: torch.device, + non_blocking: bool = False, + ) -> torch.nn.Module: + path = _get_qualified_name(target_model, target_submodule) + with self._lock: + self._copy_checkpoint_tensors_into_submodule( + target_model=target_model, + target_submodule=target_submodule, + module_path=path, + device=device, + recurse=True, + non_blocking=non_blocking, + ) + if hasattr(target_model, "tie_weights"): + target_model.tie_weights() + return target_submodule + + def checkpoint_tensors_for_submodule( + self, + *, + target_model: nn.Module, + target_submodule: nn.Module, + recurse: bool = False, + ) -> Dict[str, torch.Tensor]: + """Load checkpoint tensors for one shell submodule without mutating it.""" + + path = _get_qualified_name(target_model, target_submodule) + with self._lock: + return self._load_checkpoint_tensors_for_module_path( + module_path=path, + recurse=recurse, + ) + + def sync_all_meta( + self, + *, + shell_model: nn.Module, + require_class_match: bool = True, + verify_shapes: bool = True, + tie_after: bool = True, + ) -> int: + del require_class_match, verify_shapes + + materialized = 0 + param_cache: Dict[tuple[str, torch.dtype, bool], nn.Parameter] = {} + buffer_cache: Dict[tuple[str, torch.dtype], torch.Tensor] = {} + + with self._lock, torch.inference_mode(): + for qname, shell_sub in list(shell_model.named_modules()): + materialized += self._materialize_direct_meta_tensors( + shell_sub=shell_sub, + module_path=qname, + param_cache=param_cache, + buffer_cache=buffer_cache, + ) + + if tie_after and hasattr(shell_model, "tie_weights") and getattr(shell_model.config, "tie_word_embeddings", False): + try: + shell_model.tie_weights() + log.info("Module: Re-tied embedding weights on shell model after lazy sync") + except Exception as exc: + log.info(f"Module: tie_weights failed: {exc}") + + log.info("Module: Total direct tensors materialized from lazy checkpoint source: %s", materialized) + return materialized + + def _load_weight_map(self, model_local_path: str) -> Dict[str, str]: + from .model import get_checkpoints + + is_sharded, resolved_archive_file, _ = get_checkpoints( + model_local_path, + extensions=[".safetensors"], + possible_model_basenames=["model", "pytorch_model"], + ) + + if is_sharded: + with open(resolved_archive_file, encoding="utf-8") as fp: + index = json.load(fp) + weight_map = index.get("weight_map", {}) + if not isinstance(weight_map, dict) or not weight_map: + raise ValueError(f"Invalid safetensors index: {resolved_archive_file}") + return {str(name): str(filename) for name, filename in weight_map.items()} + + shard_name = os.path.basename(resolved_archive_file) + with safe_open(resolved_archive_file, framework="pt", device="cpu") as handler: + keys = list(handler.keys()) + if not keys: + raise ValueError(f"No tensors found in safetensors file: {resolved_archive_file}") + return {str(name): shard_name for name in keys} + + @staticmethod + def _join_tensor_name(module_path: str, rel_name: str) -> str: + if not module_path: + return rel_name + if not rel_name: + return module_path + return f"{module_path}.{rel_name}" + + @staticmethod + def _parse_module_spec(module_spec: str) -> tuple[tuple[str, ...], tuple[str, ...]]: + """Split one module-tree token into ordered aliases and `:flag` suffixes.""" + + parts = module_spec.split(":") if isinstance(module_spec, str) else [str(module_spec)] + aliases = tuple(alias for alias in parts[0].split("|") if alias) if parts else (str(module_spec),) + if not aliases: + aliases = (str(module_spec),) + flags = tuple(part for part in parts[1:] if part) + return aliases, flags + + @staticmethod + def _expand_path_aliases(path_aliases: tuple[tuple[str, ...], ...]) -> tuple[tuple[str, ...], ...]: + """Expand a sequence of aliased path segments into every concrete path variant.""" + + paths: list[tuple[str, ...]] = [()] + for segment_aliases in path_aliases: + next_paths: list[tuple[str, ...]] = [] + for prefix in paths: + for alias in segment_aliases: + candidate = prefix + (alias,) + if candidate not in next_paths: + next_paths.append(candidate) + paths = next_paths + return tuple(paths) + + @classmethod + def _build_moe_alias_specs(cls, module_tree: Optional[Any]) -> tuple[tuple[str, ...], tuple[_MoEAliasSpec, ...]]: + """Extract runtime/checkpoint MoE aliases directly from the model definition's `module_tree`.""" + + if not isinstance(module_tree, list): + return (), () + + layer_prefix: list[str] = [] + specs: list[_MoEAliasSpec] = [] + seen_specs: set[ + tuple[ + tuple[str, ...], + tuple[tuple[str, ...], ...], + tuple[str, ...], + tuple[tuple[str, ...], ...], + tuple[tuple[tuple[str, ...], ...], ...], + ] + ] = set() + + for item in module_tree: + if item == "#": + break + if isinstance(item, str): + aliases, _flags = cls._parse_module_spec(item) + layer_prefix.append(aliases[0]) + + def walk(node: Any, path: tuple[Any, ...], moe_root: Optional[tuple[tuple[str, ...], ...]]) -> None: + if isinstance(node, dict): + for raw_key, value in node.items(): + if raw_key == "#": + walk(value, path + ("#",), moe_root) + continue + if not isinstance(raw_key, str): + continue + aliases, flags = cls._parse_module_spec(raw_key) + next_path = path + (aliases,) + next_moe_root = next_path if "moe" in flags and moe_root is None else moe_root + walk(value, next_path, next_moe_root) + return + + if isinstance(node, (tuple, list)) and all(isinstance(item, str) for item in node): + if moe_root is None or "#" not in path: + return + + placeholder_index = path.index("#") + experts_path_aliases = tuple(path[:placeholder_index]) + grouped: Dict[int, list[tuple[str, ...]]] = {} + for raw_leaf in node: + leaf_aliases, flags = cls._parse_module_spec(raw_leaf) + group_index = 0 + for flag in flags: + if flag.isdigit(): + group_index = int(flag) + break + grouped.setdefault(group_index, []).append(leaf_aliases) + + if not grouped: + return + + leaf_alias_groups = tuple(tuple(grouped[group]) for group in sorted(grouped)) + runtime_root_path = tuple(segment[0] for segment in moe_root) + runtime_experts_path = tuple(segment[0] for segment in experts_path_aliases) + runtime_leaf_groups = tuple( + tuple(leaf_aliases[0] for leaf_aliases in group) + for group in leaf_alias_groups + ) + root_alias_paths = cls._expand_path_aliases(moe_root) + expert_alias_paths = cls._expand_path_aliases(experts_path_aliases) + spec_key = ( + runtime_root_path, + root_alias_paths, + runtime_experts_path, + expert_alias_paths, + leaf_alias_groups, + ) + if spec_key in seen_specs: + return + seen_specs.add(spec_key) + specs.append( + _MoEAliasSpec( + runtime_root_path=runtime_root_path, + root_alias_paths=root_alias_paths, + runtime_experts_path=runtime_experts_path, + expert_alias_paths=expert_alias_paths, + runtime_leaf_groups=runtime_leaf_groups, + leaf_alias_groups=leaf_alias_groups, + ) + ) + return + + if isinstance(node, (tuple, list)): + for item in node: + walk(item, path, moe_root) + + found_hash = False + for item in module_tree: + if item == "#": + found_hash = True + continue + if not found_hash: + continue + walk(item, (), None) + + return tuple(layer_prefix), tuple(specs) + + def _split_layer_relative_path(self, name: str) -> tuple[tuple[str, ...], tuple[str, ...]]: + """Return `(layer_prefix_with_index, relative_parts)` for a runtime or checkpoint tensor path.""" + + parts = tuple(part for part in name.split(".") if part) + prefix = self._module_tree_layer_prefix + if prefix: + max_start = len(parts) - len(prefix) + for start in range(max_start + 1): + end = start + len(prefix) + if parts[start:end] != prefix: + continue + if end >= len(parts) or not parts[end].isdigit(): + continue + return parts[start : end + 1], parts[end + 1 :] + return (), parts + + def _module_tree_name_aliases(self, name: str) -> list[str]: + """Generate checkpoint-name candidates from MoE aliases declared in `module_tree`.""" + + if not self._moe_alias_specs or not name: + return [] + + layer_head, rel_parts = self._split_layer_relative_path(name) + if not rel_parts: + return [] + + aliases: list[str] = [] + seen = {name} + + for spec in self._moe_alias_specs: + if tuple(rel_parts[: len(spec.runtime_root_path)]) == spec.runtime_root_path: + tail = rel_parts[len(spec.runtime_root_path) :] + for root_alias in spec.root_alias_paths: + alias = ".".join(layer_head + root_alias + tail) + if alias not in seen: + seen.add(alias) + aliases.append(alias) + + if len(rel_parts) < len(spec.runtime_experts_path) + 2: + continue + if tuple(rel_parts[: len(spec.runtime_experts_path)]) != spec.runtime_experts_path: + continue + + expert_index = rel_parts[len(spec.runtime_experts_path)] + runtime_leaf = rel_parts[len(spec.runtime_experts_path) + 1] + if not expert_index.isdigit(): + continue + + for group_index, runtime_group in enumerate(spec.runtime_leaf_groups): + if runtime_leaf not in runtime_group: + continue + leaf_index = runtime_group.index(runtime_leaf) + tail = rel_parts[len(spec.runtime_experts_path) + 2 :] + for expert_alias_path in spec.expert_alias_paths: + for leaf_alias in spec.leaf_alias_groups[group_index][leaf_index]: + alias = ".".join(layer_head + expert_alias_path + (expert_index, leaf_alias) + tail) + if alias not in seen: + seen.add(alias) + aliases.append(alias) + break + + return aliases + + @staticmethod + def _candidate_module_paths(module_path: str, *, allow_empty: bool = False) -> list[str]: + """Return progressively stripped module path aliases for checkpoint lookup.""" + + if not module_path: + return [""] + + parts = module_path.split(".") + candidates: list[str] = [] + for drop_count in range(len(parts) + 1): + candidate = ".".join(parts[drop_count:]) + if not candidate and not allow_empty: + continue + if candidate in candidates: + continue + candidates.append(candidate) + return candidates + + def _resolve_checkpoint_module_path(self, module_path: str) -> str: + """Resolve a shell module path to the checkpoint path when wrappers add extra roots.""" + + candidates = [module_path] + candidates.extend(self._module_tree_name_aliases(module_path)) + candidates.extend(self._candidate_module_paths(module_path)) + + seen: set[str] = set() + for candidate in candidates: + if candidate in seen: + continue + seen.add(candidate) + prefix = f"{candidate}." + if any(full_name.startswith(prefix) for full_name in self._weight_map): + return candidate + return module_path + + def _resolve_checkpoint_tensor_name(self, module_path: str, rel_name: str) -> str: + """Resolve a tensor name against checkpoint paths declared by `module_tree` and shell path prefixes.""" + + full_name = self._join_tensor_name(module_path, rel_name) + candidates: list[str] = [] + seen: set[str] = set() + for candidate_path in self._candidate_module_paths(module_path, allow_empty=True): + candidate_name = self._join_tensor_name(candidate_path, rel_name) + if candidate_name not in seen: + seen.add(candidate_name) + candidates.append(candidate_name) + for alias in self._module_tree_name_aliases(candidate_name): + if alias in seen: + continue + seen.add(alias) + candidates.append(alias) + + for candidate in candidates: + if candidate in self._weight_map: + return candidate + return full_name + + def _resolve_split_gate_up_tensor_name( + self, + module_path: str, + rel_name: str, + ) -> tuple[Optional[str], Optional[int], Optional[int], Optional[int]]: + """Resolve split gate/up projection tensors against fused `gate_up_proj` checkpoint entries.""" + + parts = rel_name.split(".") + if len(parts) < 2: + return None, None, None, None + + proj_name = parts[-2] + tensor_name = parts[-1] + if proj_name not in {"gate_proj", "up_proj"} or tensor_name not in {"weight", "bias"}: + return None, None, None, None + + fused_parts = list(parts) + fused_parts[-2] = "gate_up_proj" + fused_rel_name = ".".join(fused_parts) + split_index = 0 if proj_name == "gate_proj" else 1 + + for candidate_path in self._candidate_module_paths(module_path, allow_empty=True): + candidate_name = self._join_tensor_name(candidate_path, fused_rel_name) + if candidate_name in self._weight_map: + return candidate_name, None, split_index, 0 + + return None, None, None, None + + def _resolve_fused_expert_tensor_name( + self, + module_path: str, + rel_name: str, + ) -> tuple[Optional[str], Optional[int], Optional[int], Optional[int]]: + """Resolve defused expert leaf tensors against fused per-expert checkpoint tensors.""" + + parts = rel_name.split(".") + for expert_pos, part in enumerate(parts): + if part != "experts" or expert_pos + 3 >= len(parts): + continue + if not parts[expert_pos + 1].isdigit(): + continue + + expert_index = int(parts[expert_pos + 1]) + proj_name = parts[expert_pos + 2] + tensor_name = parts[expert_pos + 3] + + fused_leaf = None + split_index = None + split_dim = None + + if proj_name in {"gate_proj", "up_proj"}: + split_index = 0 if proj_name == "gate_proj" else 1 + if tensor_name == "weight": + fused_leaf = "gate_up_proj" + split_dim = 1 + elif tensor_name == "bias": + fused_leaf = "gate_up_proj_bias" + split_dim = 0 + elif proj_name == "down_proj": + if tensor_name == "weight": + fused_leaf = "down_proj" + elif tensor_name == "bias": + fused_leaf = "down_proj_bias" + + if fused_leaf is None: + return None, None, None, None + + fused_parts = parts[: expert_pos + 1] + [fused_leaf] + fused_rel_name = ".".join(fused_parts) + for candidate_path in self._candidate_module_paths(module_path, allow_empty=True): + candidate_name = self._join_tensor_name(candidate_path, fused_rel_name) + if candidate_name in self._weight_map: + return candidate_name, expert_index, split_index, split_dim + + return None, None, None, None + + return None, None, None, None + + @staticmethod + def _transform_checkpoint_tensor( + tensor: torch.Tensor, + *, + expert_index: Optional[int], + split_index: Optional[int], + split_dim: Optional[int], + expected_shape: Optional[tuple[int, ...]] = None, + prefer_transposed: Optional[bool] = None, + ) -> Optional[torch.Tensor]: + """Slice fused checkpoint tensors into the tensor layout expected by the shell module.""" + + if expert_index is not None: + if tensor.shape[0] <= expert_index: + return None + # Fused expert checkpoints store the expert axis first; peel it off before + # reasoning about split dimensions or transpose decisions. + tensor = tensor[expert_index].contiguous() + + if expected_shape is None: + if split_index is not None: + if split_dim is None or tensor.shape[split_dim] % 2 != 0: + return None + tensor = tensor.chunk(2, dim=split_dim)[split_index].contiguous() + return tensor + + expected_shape = tuple(expected_shape) + + # Some checkpoints store expert projections as (out, in) while others store + # them as (in, out). Keep both candidates and let the defused leaf shape be + # the final arbiter instead of hard-coding one model family's layout. + candidates: list[tuple[torch.Tensor, bool]] = [(tensor, False)] + if tensor.ndim == 2: + transposed = tensor.transpose(0, 1).contiguous() + if prefer_transposed is True: + candidates = [(transposed, True), (tensor, False)] + elif prefer_transposed is None and transposed.shape != tensor.shape: + candidates.append((transposed, True)) + elif prefer_transposed is False and transposed.shape != tensor.shape: + candidates.append((transposed, True)) + + for candidate, used_transpose in candidates: + if split_index is None: + if tuple(candidate.shape) == expected_shape: + return candidate.contiguous() + continue + + preferred_dims: list[int] = [] + mapped_split_dim = split_dim + if ( + used_transpose + and candidate.ndim == 2 + and split_dim is not None + and 0 <= split_dim < 2 + ): + # The resolver hint is expressed in the checkpoint's native layout. + # Once we transpose a 2D candidate, the split dimension flips too. + mapped_split_dim = 1 - split_dim + if mapped_split_dim is not None and 0 <= mapped_split_dim < candidate.ndim: + preferred_dims.append(mapped_split_dim) + preferred_dims.extend(dim for dim in range(candidate.ndim) if dim not in preferred_dims) + + for dim in preferred_dims: + if candidate.shape[dim] % 2 != 0: + continue + split_tensor = candidate.chunk(2, dim=dim)[split_index].contiguous() + if tuple(split_tensor.shape) == expected_shape: + return split_tensor + + return None + + @staticmethod + def _resolve_prefer_transposed_hint( + *, + target_model: nn.Module, + module_path: str, + rel_name: str, + modules_by_name: Dict[str, nn.Module], + ) -> Optional[bool]: + rel_parent, _, _leaf = rel_name.rpartition(".") + current_path = module_path + if rel_parent: + current_path = LazyTurtle._join_tensor_name(module_path, rel_parent) + + # Expert containers usually expose `is_transposed`; leaf Linear modules do not. + # Walk upward until we find the nearest owner that carries the layout hint. + while True: + owner = target_model if not current_path else modules_by_name.get(current_path) + if owner is not None and hasattr(owner, "is_transposed"): + value = getattr(owner, "is_transposed") + if isinstance(value, bool): + return value + + if not current_path: + break + current_path = current_path.rpartition(".")[0] + + return None + + def _resolve_checkpoint_tensor_source( + self, + module_path: str, + rel_name: str, + ) -> tuple[Optional[str], Optional[int], Optional[int], Optional[int]]: + """Resolve a target tensor name to its checkpoint source and optional fused split index.""" + + full_name = self._resolve_checkpoint_tensor_name(module_path, rel_name) + if full_name in self._weight_map: + return full_name, None, None, None + + resolved = self._resolve_split_gate_up_tensor_name(module_path, rel_name) + if resolved[0] is not None: + return resolved + + resolved = self._resolve_fused_expert_tensor_name(module_path, rel_name) + if resolved[0] is not None: + return resolved + + # Direct-meta rematerialization often visits a leaf Linear whose relative name is + # just `weight` / `bias`. Retry resolution with the full module path so leaf-only + # materialization can still map back to fused expert checkpoint tensors. + combined_name = self._join_tensor_name(module_path, rel_name) + resolved = self._resolve_split_gate_up_tensor_name("", combined_name) + if resolved[0] is not None: + return resolved + + return self._resolve_fused_expert_tensor_name("", combined_name) + + @staticmethod + def _materialization_issue_message( + *, + phase: str, + kind: str, + module_path: str, + rel_name: str, + reason: str, + full_name: Optional[str] = None, + source_shape: Optional[tuple[int, ...]] = None, + target_shape: Optional[tuple[int, ...]] = None, + expert_index: Optional[int] = None, + split_index: Optional[int] = None, + split_dim: Optional[int] = None, + ) -> str: + """Build a consistent error message for checkpoint-backed materialization failures.""" + + details = [] + if full_name is not None: + details.append(f"checkpoint={full_name}") + if source_shape is not None: + details.append(f"source_shape={source_shape}") + if target_shape is not None: + details.append(f"target_shape={target_shape}") + if expert_index is not None: + details.append(f"expert_index={expert_index}") + if split_index is not None: + details.append(f"split_index={split_index}") + if split_dim is not None: + details.append(f"split_dim={split_dim}") + + suffix = f" ({', '.join(details)})" if details else "" + return ( + f"LazyTurtle: {phase} {kind} `{rel_name}` under `{module_path or ''}`: " + f"{reason}{suffix}" + ) + + def _load_checkpoint_tensors_for_module_path( + self, + *, + module_path: str, + recurse: bool, + ) -> Dict[str, torch.Tensor]: + """Return raw checkpoint tensors keyed by submodule-relative names.""" + + resolved_module_path = self._resolve_checkpoint_module_path(module_path) + prefix = f"{resolved_module_path}." + grouped_names: Dict[str, list[tuple[str, str]]] = {} + for full_name, shard in self._weight_map.items(): + if not full_name.startswith(prefix): + continue + + rel_name = full_name[len(prefix):] + if not rel_name: + continue + if not recurse and "." in rel_name: + continue + + grouped_names.setdefault(shard, []).append((rel_name, full_name)) + + tensors: Dict[str, torch.Tensor] = {} + for shard, names in grouped_names.items(): + shard_path = os.path.join(self.model_local_path, shard) + with safe_open(shard_path, framework="pt", device="cpu") as handler: + for rel_name, full_name in names: + tensors[rel_name] = handler.get_tensor(full_name) + return tensors + + def _copy_checkpoint_tensors_into_submodule( + self, + *, + target_model: nn.Module, + target_submodule: nn.Module, + module_path: str, + device: torch.device, + recurse: bool, + non_blocking: bool, + ) -> None: + """Materialize checkpoint tensors into a shell submodule and rebuild missing init-only buffers.""" + + t_params = dict(target_submodule.named_parameters(recurse=recurse)) + t_bufs = dict(target_submodule.named_buffers(recurse=recurse)) + modules_by_name = dict(target_model.named_modules()) + missing_nonpersistent_buffers: list[tuple[str, str]] = [] + + grouped_names: Dict[str, list[tuple[str, str, str, Optional[int], Optional[int], Optional[int]]]] = {} + for rel_name in t_params: + full_name, expert_index, split_index, split_dim = self._resolve_checkpoint_tensor_source(module_path, rel_name) + if full_name is None: + continue + shard = self._weight_map.get(full_name) + if shard is None: + raise RuntimeError( + self._materialization_issue_message( + phase="submodule materialization", + kind="param", + module_path=module_path, + rel_name=rel_name, + reason="checkpoint tensor mapping resolved to a missing shard", + full_name=full_name, + target_shape=tuple(t_params[rel_name].shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + ) + ) + grouped_names.setdefault(shard, []).append(("param", rel_name, full_name, expert_index, split_index, split_dim)) + + for rel_name, target_buffer in list(t_bufs.items()): + full_name, expert_index, split_index, split_dim = self._resolve_checkpoint_tensor_source(module_path, rel_name) + if full_name is None: + full_name = self._resolve_checkpoint_tensor_name(module_path, rel_name) + expert_index = None + split_index = None + split_dim = None + shard = self._weight_map.get(full_name) + if shard is None: + t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, rel_name) + non_persistent = leaf in getattr(t_parent, "_non_persistent_buffers_set", set()) + if non_persistent: + if ( + getattr(target_buffer, "is_meta", False) + or target_buffer.device.type == "meta" + or target_buffer.device != device + ): + missing_nonpersistent_buffers.append((rel_name, leaf)) + continue + if getattr(target_buffer, "is_meta", False) or target_buffer.device.type == "meta": + if leaf in getattr(t_parent, "_buffers", {}): + del t_parent._buffers[leaf] + continue + grouped_names.setdefault(shard, []).append(("buffer", rel_name, full_name, expert_index, split_index, split_dim)) + + with torch.inference_mode(): + for shard, entries in grouped_names.items(): + shard_path = os.path.join(self.model_local_path, shard) + with safe_open(shard_path, framework="pt", device="cpu") as handler: + for kind, rel_name, full_name, expert_index, split_index, split_dim in entries: + target_tensor = t_params.get(rel_name) if kind == "param" else t_bufs.get(rel_name) + expected_shape = tuple(target_tensor.shape) if target_tensor is not None else None + prefer_transposed = self._resolve_prefer_transposed_hint( + target_model=target_model, + module_path=module_path, + rel_name=rel_name, + modules_by_name=modules_by_name, + ) + checkpoint_tensor = handler.get_tensor(full_name) + tensor = self._transform_checkpoint_tensor( + checkpoint_tensor, + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + expected_shape=expected_shape, + prefer_transposed=prefer_transposed, + ) + if tensor is None: + raise RuntimeError(self._materialization_issue_message( + phase="submodule materialization", + kind=kind, + module_path=module_path, + rel_name=rel_name, + reason="checkpoint tensor could not be reshaped into the target layout", + full_name=full_name, + source_shape=tuple(checkpoint_tensor.shape), + target_shape=expected_shape, + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + if kind == "param": + target_param = t_params.get(rel_name) + if target_param is None: + raise RuntimeError(self._materialization_issue_message( + phase="submodule materialization", + kind=kind, + module_path=module_path, + rel_name=rel_name, + reason="target tensor disappeared before materialization", + full_name=full_name, + source_shape=tuple(tensor.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + if target_param.shape != tensor.shape: + raise RuntimeError(self._materialization_issue_message( + phase="submodule materialization", + kind=kind, + module_path=module_path, + rel_name=rel_name, + reason="target tensor shape does not match the transformed checkpoint tensor", + full_name=full_name, + source_shape=tuple(tensor.shape), + target_shape=tuple(target_param.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + target_param_new = _ensure_target_storage_on_device_(target_param, device) + if target_param_new is not target_param: + t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, rel_name) + setattr(t_parent, leaf, target_param_new) + target_param = target_param_new + source = tensor.detach() + if source.dtype != target_param.dtype: + source = source.to(dtype=target_param.dtype) + target_param.detach().copy_(source, non_blocking=(non_blocking and source.is_pinned())) + continue + + target_buffer = t_bufs.get(rel_name) + t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, rel_name) + persistent = leaf not in getattr(t_parent, "_non_persistent_buffers_set", set()) + + source = tensor.detach() + if target_buffer is None: + new_buffer = source.to(device=device) + t_parent.register_buffer(leaf, new_buffer, persistent=persistent) + t_bufs[rel_name] = new_buffer + continue + + if tuple(target_buffer.shape) != tuple(source.shape): + raise RuntimeError(self._materialization_issue_message( + phase="submodule materialization", + kind=kind, + module_path=module_path, + rel_name=rel_name, + reason="target tensor shape does not match the transformed checkpoint tensor", + full_name=full_name, + source_shape=tuple(source.shape), + target_shape=tuple(target_buffer.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + + if getattr(target_buffer, "is_meta", False) or target_buffer.device.type == "meta": + new_buffer = torch.empty_like(target_buffer, device=device) + new_buffer.copy_(source.to(dtype=new_buffer.dtype), non_blocking=(non_blocking and source.is_pinned())) + t_parent.register_buffer(leaf, new_buffer, persistent=persistent) + t_bufs[rel_name] = new_buffer + continue + + if target_buffer.device != device: + new_buffer = torch.empty_like(target_buffer, device=device) + new_buffer.copy_(source.to(dtype=new_buffer.dtype), non_blocking=(non_blocking and source.is_pinned())) + t_parent.register_buffer(leaf, new_buffer, persistent=persistent) + t_bufs[rel_name] = new_buffer + else: + if source.dtype != target_buffer.dtype: + source = source.to(dtype=target_buffer.dtype) + target_buffer.copy_(source, non_blocking=(non_blocking and source.is_pinned())) + + self._restore_missing_nonpersistent_buffers( + target_model=target_model, + target_submodule=target_submodule, + t_bufs=t_bufs, + missing_nonpersistent_buffers=missing_nonpersistent_buffers, + device=device, + ) + + def _build_nonpersistent_buffer_template( + self, + *, + owner_module: nn.Module, + target_model: nn.Module, + ) -> Optional[nn.Module]: + """Construct a CPU template module for init-only buffers missing from checkpoint shards.""" + + config_source = getattr(owner_module, "config", None) + if config_source is None: + config_source = getattr(target_model, "config", None) + if config_source is None: + config_source = self.config + + module_type = type(owner_module) + try: + signature = inspect.signature(module_type) + except (TypeError, ValueError): + return None + + params = list(signature.parameters.values()) + if not params: + return None + + args = [] + kwargs = {} + + for param in params: + if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): + continue + + if param.name == "config": + if config_source is None: + return None + value = copy.deepcopy(config_source) + elif param.name == "device": + value = torch.device("cpu") + elif hasattr(owner_module, param.name): + # Some remote-code modules rebuild buffers from constructor attributes instead of config. + raw_value = getattr(owner_module, param.name) + if isinstance(raw_value, torch.Tensor) and raw_value.device.type == "meta": + scalar_attr_name = f"scalar_{param.name}" + if hasattr(owner_module, scalar_attr_name): + raw_value = getattr(owner_module, scalar_attr_name) + elif param.default is not inspect.Parameter.empty: + continue + else: + return None + value = copy.deepcopy(raw_value) + elif param.default is not inspect.Parameter.empty: + continue + else: + return None + + if param.kind is inspect.Parameter.POSITIONAL_ONLY: + args.append(value) + else: + kwargs[param.name] = value + + try: + return module_type(*args, **kwargs) + except Exception as exc: + log.debug( + "LazyTurtle: failed to build template for `%s`: %s", + module_type.__name__, + exc, + ) + return None + + def _restore_missing_nonpersistent_buffers( + self, + *, + target_model: nn.Module, + target_submodule: nn.Module, + t_bufs: Dict[str, torch.Tensor], + missing_nonpersistent_buffers: list[tuple[str, str]], + device: torch.device, + ) -> None: + """Restore constructor-owned buffers that are intentionally absent from checkpoints.""" + + owner_templates: Dict[str, Optional[nn.Module]] = {} + for rel_name, leaf in missing_nonpersistent_buffers: + parent_rel_path, _, _ = rel_name.rpartition(".") + owner_module = target_submodule if not parent_rel_path else dict(target_submodule.named_modules()).get(parent_rel_path) + if owner_module is None: + continue + + current_buffer = t_bufs.get(rel_name) + if ( + current_buffer is not None + and not getattr(current_buffer, "is_meta", False) + and current_buffer.device.type != "meta" + ): + source_buffer = current_buffer.detach() + else: + if parent_rel_path not in owner_templates: + owner_templates[parent_rel_path] = self._build_nonpersistent_buffer_template( + owner_module=owner_module, + target_model=target_model, + ) + template = owner_templates[parent_rel_path] + if template is None: + continue + source_buffer = dict(template.named_buffers(recurse=False)).get(leaf) + if source_buffer is None: + continue + source_buffer = source_buffer.detach() + + target_dtype = source_buffer.dtype if current_buffer is None else current_buffer.dtype + materialized = source_buffer.to(device=device, dtype=target_dtype) + owner_module.register_buffer(leaf, materialized, persistent=False) + t_bufs[rel_name] = materialized + + def _materialize_direct_meta_tensors( + self, + *, + shell_sub: nn.Module, + module_path: str, + param_cache: Dict[tuple[str, Optional[int], Optional[int], Optional[int], torch.dtype, bool], nn.Parameter], + buffer_cache: Dict[tuple[str, Optional[int], Optional[int], Optional[int], torch.dtype], torch.Tensor], + ) -> int: + synced = 0 + + with torch.inference_mode(): + for name, shell_param in dict(shell_sub.named_parameters(recurse=False)).items(): + if not _is_meta_tensor(shell_param): + continue + + full_name, expert_index, split_index, split_dim = self._resolve_checkpoint_tensor_source(module_path, name) + if full_name is None: + continue + shard = self._weight_map.get(full_name) + if shard is None: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="param", + module_path=module_path, + rel_name=name, + reason="checkpoint tensor mapping resolved to a missing shard", + full_name=full_name, + target_shape=tuple(shell_param.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + + source_path = os.path.join(self.model_local_path, shard) + with safe_open(source_path, framework="pt", device="cpu") as handler: + checkpoint_param = handler.get_tensor(full_name) + source_param = self._transform_checkpoint_tensor( + checkpoint_param, + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + expected_shape=tuple(shell_param.shape), + prefer_transposed=getattr(shell_sub, "is_transposed", None), + ) + if source_param is None: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="param", + module_path=module_path, + rel_name=name, + reason="checkpoint tensor could not be reshaped into the target layout", + full_name=full_name, + source_shape=tuple(checkpoint_param.shape), + target_shape=tuple(shell_param.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + + if shell_param.shape != source_param.shape: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="param", + module_path=module_path, + rel_name=name, + reason="target tensor shape does not match the transformed checkpoint tensor", + full_name=full_name, + source_shape=tuple(source_param.shape), + target_shape=tuple(shell_param.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + + cache_key = (full_name, expert_index, split_index, split_dim, shell_param.dtype, shell_param.requires_grad) + new_param = param_cache.get(cache_key) + if new_param is None: + if source_param.dtype != shell_param.dtype: + source_param = source_param.to(dtype=shell_param.dtype) + new_param = nn.Parameter( + source_param.clone(), + requires_grad=shell_param.requires_grad, + ) + param_cache[cache_key] = new_param + + shell_sub.register_parameter(name, new_param) + synced += 1 + + for name, shell_buffer in list(dict(shell_sub.named_buffers(recurse=False)).items()): + if not _is_meta_tensor(shell_buffer): + continue + + full_name, expert_index, split_index, split_dim = self._resolve_checkpoint_tensor_source(module_path, name) + if full_name is None: + continue + shard = self._weight_map.get(full_name) + if shard is None: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="buffer", + module_path=module_path, + rel_name=name, + reason="checkpoint tensor mapping resolved to a missing shard", + full_name=full_name, + target_shape=tuple(shell_buffer.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + + source_path = os.path.join(self.model_local_path, shard) + with safe_open(source_path, framework="pt", device="cpu") as handler: + checkpoint_buffer = handler.get_tensor(full_name) + source_buffer = self._transform_checkpoint_tensor( + checkpoint_buffer, + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + expected_shape=tuple(shell_buffer.shape), + prefer_transposed=getattr(shell_sub, "is_transposed", None), + ) + if source_buffer is None: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="buffer", + module_path=module_path, + rel_name=name, + reason="checkpoint tensor could not be reshaped into the target layout", + full_name=full_name, + source_shape=tuple(checkpoint_buffer.shape), + target_shape=tuple(shell_buffer.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + + if shell_buffer.shape != source_buffer.shape: + raise RuntimeError(self._materialization_issue_message( + phase="direct-meta sync", + kind="buffer", + module_path=module_path, + rel_name=name, + reason="target tensor shape does not match the transformed checkpoint tensor", + full_name=full_name, + source_shape=tuple(source_buffer.shape), + target_shape=tuple(shell_buffer.shape), + expert_index=expert_index, + split_index=split_index, + split_dim=split_dim, + )) + + persistent = name not in getattr(shell_sub, "_non_persistent_buffers_set", set()) + cache_key = (full_name, expert_index, split_index, split_dim, shell_buffer.dtype) + new_buffer = buffer_cache.get(cache_key) + if new_buffer is None: + if source_buffer.dtype != shell_buffer.dtype: + source_buffer = source_buffer.to(dtype=shell_buffer.dtype) + new_buffer = source_buffer.clone() + buffer_cache[cache_key] = new_buffer + + shell_sub.register_buffer(name, new_buffer, persistent=persistent) + synced += 1 + + return synced + def alias_from_turtle_for_submodule( target_model: torch.nn.Module, - turtle_model: torch.nn.Module, + turtle_model: "LazyTurtle", target_submodule: torch.nn.Module, device: torch.device, non_blocking: bool = False, ) -> torch.nn.Module: - # removed cpu from list to allow materialize from meta to cpu + # Lazy turtle supports materialization from checkpoint storage into CPU or accelerator devices. assert device not in [None, torch.device("meta")] - # print(f"alias device = {device}") - - # Resolve path & source submodule (on CPU/mmap) - path = _get_qualified_name(target_model, target_submodule) - src_map: Dict[str, nn.Module] = dict(turtle_model.named_modules()) - if path not in src_map: - raise KeyError(f"Path '{path}' not found in turtle_model.") - src_sub = src_map[path] - - # ---- copy params/buffers CPU->GPU into target_submodule (your existing code) ---- - t_params = dict(target_submodule.named_parameters(recurse=True)) - s_params = dict(src_sub.named_parameters(recurse=True)) - with torch.inference_mode(): - for name, s_p in s_params.items(): - t_p = t_params.get(name) - if t_p is None or t_p.shape != s_p.shape: - continue - t_p_new = _ensure_target_storage_on_device_(t_p, device) - if t_p_new is not t_p: - t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, name) - setattr(t_parent, leaf, t_p_new) - t_p = t_p_new - t_p.detach().copy_(s_p.detach(), non_blocking=(non_blocking and s_p.is_pinned())) - - t_bufs = dict(target_submodule.named_buffers(recurse=True)) - s_bufs = dict(src_sub.named_buffers(recurse=True)) - for name, s_b in s_bufs.items(): - tb = t_bufs.get(name) - t_parent, leaf = _get_parent_and_leaf_by_path(target_submodule, name) - s_parent, _ = _get_parent_and_leaf_by_path(src_sub, name) - - # nn.Module decides buffer persistence using `_non_persistent_buffers_set`: - # the buffer is persistent unless its name is in this set. - persistent = True - if hasattr(s_parent, "_non_persistent_buffers_set"): - persistent = leaf not in s_parent._non_persistent_buffers_set - - if tb is None or getattr(tb, "is_meta", False) or tb.device.type == "meta": - new_b = torch.empty_like(s_b, device=device) - new_b.copy_(s_b.detach(), non_blocking=(non_blocking and s_b.is_pinned())) - t_parent.register_buffer(leaf, new_b, persistent=persistent) - else: - if tb.device != device: - new_tb = torch.empty_like(s_b, device=device) - new_tb.copy_(s_b.detach(), non_blocking=(non_blocking and s_b.is_pinned())) - t_parent.register_buffer(leaf, new_tb, persistent=persistent) - else: - tb.copy_(s_b.detach(), non_blocking=(non_blocking and s_b.is_pinned())) - - if hasattr(target_model, "tie_weights"): - target_model.tie_weights() - - #print("Post alias: target_submodule device summary:") - # for n, p in target_submodule.named_parameters(recurse=True): - # print(f" {n}: {p.device}") - - # return the *target* submodule, which is the injected result - return target_submodule + if not hasattr(turtle_model, "materialize_submodule"): + raise TypeError( + f"Expected LazyTurtle-compatible source, got `{type(turtle_model).__name__}`." + ) + + return turtle_model.materialize_submodule( + target_model=target_model, + target_submodule=target_submodule, + device=device, + non_blocking=non_blocking, + ) def _is_meta_tensor(t: torch.Tensor) -> bool: return bool(getattr(t, "is_meta", False)) or (hasattr(t, "device") and t.device.type == "meta") -def _module_all_meta(mod: nn.Module) -> bool: - """True if the module has at least one tensor and *all* its params/buffers are meta.""" - saw_any = False - for _, p in mod.named_parameters(recurse=False): - saw_any = True - if not _is_meta_tensor(p): - return False - for _, b in mod.named_buffers(recurse=False): - saw_any = True - if not _is_meta_tensor(b): - return False - return saw_any # modules with no tensors aren't considered 'meta' targets - -def _is_leaf(mod: nn.Module) -> bool: - return next(mod.named_children(), None) is None - def alias_all_from_turtle_if_meta( shell_model: nn.Module, - turtle_model: nn.Module, + turtle_model: Optional["LazyTurtle"], *, require_class_match: bool = True, verify_shapes: bool = True, tie_after: bool = True, ) -> int: """ - Replace (alias) leaf submodules in `shell_model` with the corresponding submodules - from `turtle_model` when the shell submodule's tensors are on meta. - - Logs each swap via log.info(). + Materialize any remaining direct meta tensors in `shell_model` from the lazy turtle source. """ if turtle_model is None: return 0 - turtle_map = dict(turtle_model.named_modules()) - swapped = 0 - - for qname, shell_sub in list(shell_model.named_modules()): - if not qname: # skip root - continue - if not _is_leaf(shell_sub): - continue - if not _module_all_meta(shell_sub): - continue - - turtle_sub = turtle_map.get(qname, None) - if turtle_sub is None: - # log.info(f"Module: Skipped {qname}: not found in turtle model") - continue - - if require_class_match and (shell_sub.__class__ is not turtle_sub.__class__): - # log.info( - # f"Module: Skipped {qname}: class mismatch " - # f"(shell={shell_sub.__class__.__name__}, turtle={turtle_sub.__class__.__name__})" - # ) - continue - - if verify_shapes: - shell_ps = dict(shell_sub.named_parameters(recurse=False)) - turtle_ps = dict(turtle_sub.named_parameters(recurse=False)) - for n in set(shell_ps.keys()) & set(turtle_ps.keys()): - if shell_ps[n].shape != turtle_ps[n].shape: - # log.info( - # f"Module: Skipped {qname}: parameter shape mismatch at '{n}' " - # f"(shell={tuple(shell_ps[n].shape)}, turtle={tuple(turtle_ps[n].shape)})" - # ) - break - else: - shell_bs = dict(shell_sub.named_buffers(recurse=False)) - turtle_bs = dict(turtle_sub.named_buffers(recurse=False)) - for n in set(shell_bs.keys()) & set(turtle_bs.keys()): - if shell_bs[n].shape != turtle_bs[n].shape: - # log.info( - # f"Module: Skipped {qname}: buffer shape mismatch at '{n}' " - # f"(shell={tuple(shell_bs[n].shape)}, turtle={tuple(turtle_bs[n].shape)})" - # ) - break - else: - parent, leaf = _get_parent_and_leaf_by_path(shell_model, qname) - setattr(parent, leaf, turtle_sub) - swapped += 1 - log.info(f"Module: Sync {qname} <- from turtle ({turtle_sub.__class__.__name__})") - continue - continue - - parent, leaf = _get_parent_and_leaf_by_path(shell_model, qname) - setattr(parent, leaf, turtle_sub) - swapped += 1 - log.info(f"Module:: Sync {qname} <- from turtle ({turtle_sub.__class__.__name__})") - - if tie_after and hasattr(shell_model, "tie_weights") and getattr(shell_model.config, "tie_word_embeddings", False): - try: - shell_model.tie_weights() - log.info("Module: Re-tied embedding weights on shell model after full sync") - except Exception as e: - log.info(f"Module: tie_weights failed: {e}") - - log.info(f"Module: Total synced modules: {swapped}") - return swapped + if not hasattr(turtle_model, "sync_all_meta"): + raise TypeError( + f"Expected LazyTurtle-compatible source, got `{type(turtle_model).__name__}`." + ) + return turtle_model.sync_all_meta( + shell_model=shell_model, + require_class_match=require_class_match, + verify_shapes=verify_shapes, + tie_after=tie_after, + ) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 2220862cb..3b0b51e62 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -18,8 +18,6 @@ import torch -from .pause_resume import _restore_terminal_settings_on_exit - try: from device_smi import Device # type: ignore @@ -27,6 +25,7 @@ Device = None from .. import DEBUG_ON +from ..utils import torch as torch_utils from ..utils.ctx import ctx from ..utils.logger import setup_logger from ..utils.torch import torch_empty_cache_any @@ -34,6 +33,13 @@ log = setup_logger() + +def _best_effort_stderr_write(message: str) -> None: + with contextlib.suppress(Exception): + sys.stderr.write(f"{message}\n") + sys.stderr.flush() + + # Debug logging is very chatty and can alter timings subtly in tests. # We gate all extra diagnostics behind the DEBUG env (1/true/yes/on). @@ -389,11 +395,6 @@ def _run(self): """ self._apply_cpu_affinity() _activate_thread_device(self.device) - try: - self._run_warmup() - except BaseException as exc: - self._abort_process(exc) - return while not self._stop.is_set(): is_task, fn, args, kwargs, fut = self._q.get() try: @@ -402,6 +403,12 @@ def _run(self): break if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") + try: + self._run_warmup() + except BaseException as exc: + self._abort_process(exc) + return + event = kwargs.pop("cuda_event", None) override_inference = _pop_public_kwarg( kwargs, "inference_mode", "_threadx_inference_mode" @@ -443,7 +450,6 @@ def _abort_process(self, exc: BaseException) -> None: except Exception: pass try: - _restore_terminal_settings_on_exit() os._exit(1) except Exception: # Last resort if os._exit is unavailable for some reason. @@ -536,7 +542,7 @@ def __init__( empty_cache_every_n: int = 50, # <=0 disables janitor workers: Optional[Dict[str, int]] = None, # e.g. {'cpu':4, 'cuda:per':1, 'cuda:0':3} gc_debounce_seconds: float = 0.02, # absorb bursty triggers before GC - gc_min_interval_seconds: float = 1.0, # throttle janitor passes + gc_min_interval_seconds: float = 0.0, # throttle janitor passes pin_cpu_workers: bool = False, pin_accelerator_workers: bool = False, ): @@ -1393,6 +1399,8 @@ def _run_empty_cache_for_device(self, key: str, dev: torch.device) -> Optional[f """ start = time.time() try: + # Tests and runtime hooks may monkeypatch empty_cache between janitor passes. + torch_utils.resolve_empty_cache_callable.cache_clear() success = torch_empty_cache_any(device=dev, gc=False) except Exception as exc: if DEBUG_ON: @@ -1537,10 +1545,14 @@ def _on_task_finished(self, key: str) -> None: physical_key = self._physical_key(key) current = self._gc_done_physical.get(physical_key, 0) + 1 self._gc_done_physical[physical_key] = current - if current % self._empty_cache_every_n == 0: - pending_map = self._gc_pending_physical - pending_map[physical_key] = pending_map.get(physical_key, 0) + 1 - self._gc_generation += 1 + last_done = self._last_gc_done_physical.get(physical_key, 0) + completed_chunks = max(0, (current - last_done) // self._empty_cache_every_n) + pending_map = self._gc_pending_physical + already_pending = pending_map.get(physical_key, 0) + new_triggers = completed_chunks - already_pending + if new_triggers > 0: + pending_map[physical_key] = already_pending + new_triggers + self._gc_generation += new_triggers trigger_gc = True if DEBUG_ON: log.debug( @@ -1739,10 +1751,10 @@ def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None: Record 'done' counters as of a GC pass to require fresh progress before a subsequent pass is allowed. """ - threshold = int(self._empty_cache_every_n) + int(self._empty_cache_every_n) per_done_physical = snap_after.get("per_done_physical") or {} per_done = snap_after.get("per_done") or {} - meta = snap_after.get("meta") or {} + snap_after.get("meta") or {} processed = snap_after.get("_gc_processed_devices") if processed is None: processed_iter = per_done_physical.keys() @@ -1753,21 +1765,14 @@ def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None: done_phys = per_done_physical.get(phys_key) if done_phys is None: continue - if threshold <= 0: - self._last_gc_done_physical[phys_key] = done_phys - else: - self._last_gc_done_physical[phys_key] = done_phys - (done_phys % threshold) + self._last_gc_done_physical[phys_key] = done_phys members = self._physical_children.get(phys_key, {phys_key}) for member in members: done_member = per_done.get(member) if done_member is None: continue - dev_type = meta.get(member, {}).get("type") - if threshold <= 0 or dev_type not in ("cuda", "xpu", "mps"): - self._last_gc_done_per_device[member] = done_member - else: - self._last_gc_done_per_device[member] = done_member - (done_member % threshold) + self._last_gc_done_per_device[member] = done_member def _janitor_loop(self): """ @@ -1912,11 +1917,30 @@ def _janitor_loop(self): self._last_consumed_gc_generation = max(self._last_consumed_gc_generation, current_generation) continue + # A single trigger runs a full accelerator sweep so memory pressure is + # normalized across all active devices under the janitor's exclusive pass. + sweep_targets: List[str] = [] + seen_targets: Set[str] = set() + for candidate in self._ordered_keys: + physical = self._physical_key(candidate) + if physical in seen_targets: + continue + dev = self._devices_by_key.get(physical) + if dev is None or dev.type not in ("cuda", "xpu", "mps"): + continue + if physical not in self._locks: + continue + seen_targets.add(physical) + sweep_targets.append(physical) + + if not sweep_targets: + sweep_targets = pending_targets + processed_devices: List[str] = [] skipped_devices: List[str] = [] per_device_durations: Dict[str, float] = {} - for key in pending_targets: + for key in sweep_targets: dev = self._devices_by_key.get(key) if dev is None or dev.type not in ("cuda", "xpu", "mps"): skipped_devices.append(key) @@ -1950,29 +1974,19 @@ def _janitor_loop(self): delta_s = t1 - prev_gc_ts since_last_gc = f"since last GC: {delta_s:.3f}s ({delta_s * 1000:.1f}ms)" + post = None if processed_devices: - vram_summary = self._format_vram_summary(processed_devices) try: post = self._collect_state_snapshot() post["_gc_processed_devices"] = processed_devices self._update_gc_watermarks(post) - devices_clause = ", ".join(processed_devices) - log.info( - f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}) at {gc_timestamp}; devices={devices_clause}; VRAM {vram_summary}; {since_last_gc}." - ) - if DEBUG_ON: - log.debug( - "DP-Janitor: post-snapshot inflight=%s per_done=%s per_done_physical=%s durations=%s", - post["inflight"], - post["per_done"], - post.get("per_done_physical"), - per_device_durations, - ) except Exception as e: try: - log.warn(f"Failed to render GC post-snapshot: {e!r}") - except Exception: - pass + log.warn(f"Failed to update GC watermarks: {e!r}") + except Exception as log_exc: + _best_effort_stderr_write( + f"Failed to update GC watermarks: {e!r}; secondary logging failed: {log_exc!r}" + ) with self._stats_lock: pending_map = self._gc_pending_physical @@ -1981,12 +1995,34 @@ def _janitor_loop(self): self._gc_pending_physical = pending_map for key in processed_devices: pending_map.pop(key, None) - self._last_gc_done_physical[key] = self._gc_done_physical.get(key, 0) for key in skipped_devices: pending_map.pop(key, None) self._last_consumed_gc_generation = max(self._last_consumed_gc_generation, current_generation) if any(count > 0 for count in pending_map.values()): self._gc_event.set() + if processed_devices: + vram_summary = self._format_vram_summary(processed_devices) + try: + devices_clause = ", ".join(processed_devices) + log.info( + f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}) at {gc_timestamp}; devices={devices_clause}; VRAM {vram_summary}; {since_last_gc}." + ) + if DEBUG_ON: + log.debug( + "DP-Janitor: post-snapshot inflight=%s per_done=%s per_done_physical=%s durations=%s", + post["inflight"], + post["per_done"], + post.get("per_done_physical"), + per_device_durations, + ) + except Exception as e: + try: + log.warn(f"Failed to render GC post-snapshot: {e!r}") + except Exception as log_exc: + _best_effort_stderr_write( + f"Failed to render GC post-snapshot: {e!r}; secondary logging failed: {log_exc!r}" + ) + def _empty_all_caches(self): torch_empty_cache_any(gc=False) diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index bec08f300..06967efe0 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -313,12 +313,19 @@ def empty_cache_for_device(device: torch.device) -> bool: return False -def torch_empty_cache_any(device: Union[torch.device, str, int, None] = None, gc: bool = True) -> bool: +def torch_empty_cache_any( + device: Union[torch.device, str, int, None] = None, + gc: bool = True, + sync: bool = False, +) -> bool: normalized = _normalize_device(device) if gc: timed_gc_collect() + if sync: + torch_sync(device=normalized) + success = False if normalized is None: @@ -350,8 +357,8 @@ def torch_empty_cache_any(device: Union[torch.device, str, int, None] = None, gc return empty_cache_for_device(normalized) -def torch_empty_cache(device: torch.device = None, gc: bool = True) -> bool: - return torch_empty_cache_any(device=device, gc=gc) +def torch_empty_cache(device: torch.device = None, gc: bool = True, sync: bool = False) -> bool: + return torch_empty_cache_any(device=device, gc=gc, sync=sync) def auto_select_torch_device(index: int = 0): assert index >= 0, f"device index should be a positive integer: actual = `{index}`" diff --git a/gptqmodel/utils/vram.py b/gptqmodel/utils/vram.py index 7d24445ee..04e531b04 100644 --- a/gptqmodel/utils/vram.py +++ b/gptqmodel/utils/vram.py @@ -11,6 +11,20 @@ from accelerate.utils import convert_bytes +_ONE_BYTE_FLOATX_DTYPES = frozenset( + getattr(torch, name) + for name in ( + "float8_e4m3fn", + "float8_e5m2", + "float8_e4m3fnuz", + "float8_e5m2fnuz", + "float8_e8m0fnu", + "float4_e2m1fn_x2", + ) + if hasattr(torch, name) +) + + def dtype_byte_size(dtype: torch.dtype): """ Returns the size (in bytes) occupied by one parameter of type `dtype`. @@ -30,7 +44,7 @@ def dtype_byte_size(dtype: torch.dtype): return 1 / 2 elif dtype == "fp8": return 1 - elif dtype == torch.float8_e4m3fn: + elif dtype in _ONE_BYTE_FLOATX_DTYPES: return 1 elif dtype == torch.float16 or dtype == torch.bfloat16: return 2 diff --git a/gptqmodel/version.py b/gptqmodel/version.py index d7457a2f2..e3dae98b5 100644 --- a/gptqmodel/version.py +++ b/gptqmodel/version.py @@ -7,4 +7,4 @@ # even minor versions are release # 5.2.0 => release, 5.1.0 => devel # micro version (5.2.x) denotes patch fix, i.e. 5.2.1 is a patch fix release -__version__ = "5.8.0" +__version__ = "6.1.0-dev" diff --git a/gptqmodel_ext/__init__.py b/gptqmodel_ext/__init__.py index 2a40b7225..b152a6919 100644 --- a/gptqmodel_ext/__init__.py +++ b/gptqmodel_ext/__init__.py @@ -2,6 +2,6 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -"""Support files for GPTQModel native extensions.""" +"""Support files for GPT-QModel native extensions.""" __all__ = [] diff --git a/gptqmodel_ext/awq/gemm_fast_cuda_entry.cu b/gptqmodel_ext/awq/gemm_fast_cuda_entry.cu new file mode 100644 index 000000000..55887c19b --- /dev/null +++ b/gptqmodel_ext/awq/gemm_fast_cuda_entry.cu @@ -0,0 +1 @@ +#include "quantization_new/gemm/gemm_cuda.cu" diff --git a/gptqmodel_ext/awq/gemv_fast_cuda_entry.cu b/gptqmodel_ext/awq/gemv_fast_cuda_entry.cu new file mode 100644 index 000000000..4b648b406 --- /dev/null +++ b/gptqmodel_ext/awq/gemv_fast_cuda_entry.cu @@ -0,0 +1 @@ +#include "quantization_new/gemv/gemv_cuda.cu" diff --git a/gptqmodel_ext/awq/quantization/dequantize.cuh b/gptqmodel_ext/awq/quantization/dequantize.cuh index 5d333b35c..82f49f57c 100644 --- a/gptqmodel_ext/awq/quantization/dequantize.cuh +++ b/gptqmodel_ext/awq/quantization/dequantize.cuh @@ -11,6 +11,7 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor #pragma once +#include __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) { @@ -77,3 +78,24 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) return result; } +__device__ uint4 dequantize_s4_to_bf16x2(uint32_t const& source) +{ + uint4 result; + + uint32_t* h = reinterpret_cast(&result); + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint32_t LOW_NIBBLE_MASK = 0x000f000f; + static constexpr uint32_t OR_MASK = 0x43004300; + static constexpr uint32_t BF16_BIAS = 0x43004300; + for (int i = 0; i < 4; ++i) + { + h[i] = source >> (4 * i); + asm volatile("lop3.b32 %0, %0, %1, %2, %3;\n" + : "+r"(h[i]) + : "n"(LOW_NIBBLE_MASK), "n"(OR_MASK), "n"(immLut)); + const __nv_bfloat162 bias = *reinterpret_cast(&BF16_BIAS); + reinterpret_cast<__nv_bfloat162*>(h)[i] = __hsub2(reinterpret_cast<__nv_bfloat162*>(h)[i], bias); + } + + return result; +} diff --git a/gptqmodel_ext/awq/quantization/gemm_cuda.h b/gptqmodel_ext/awq/quantization/gemm_cuda.h index afc816515..1ea364d19 100644 --- a/gptqmodel_ext/awq/quantization/gemm_cuda.h +++ b/gptqmodel_ext/awq/quantization/gemm_cuda.h @@ -1,6 +1,9 @@ #include torch::Tensor gemm_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, + torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters, bool fp32_accum=false); + +torch::Tensor gemm_forward_cuda_fp32_reduce(torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); torch::Tensor grouped_gemm_forward( @@ -20,4 +23,4 @@ torch::Tensor gemmv2_forward_cuda(torch::Tensor _in_feats, torch::Tensor _kernel // Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda.h#L9C1-L10C106 torch::Tensor dequantize_weights_cuda(torch::Tensor _kernel, - torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters, int thx, int thy, bool dbg); \ No newline at end of file + torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters, int thx, int thy, bool dbg); diff --git a/gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu b/gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu index 98f49efac..dc17c0642 100644 --- a/gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu +++ b/gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu @@ -13,9 +13,13 @@ #include #include "gemm_cuda.h" #include "dequantize.cuh" +#include +#include #include #include #include +#include +#include // Pack two half values. @@ -30,20 +34,221 @@ __device__ __forceinline__ int make_divisible(int c, int divisor){ return (c + divisor - 1) / divisor; } -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +template +__device__ __forceinline__ void store_accum_value(output_t* ptr, float value) +{ + *ptr = static_cast(value); +} + +template <> +__device__ __forceinline__ void store_accum_value(half* ptr, float value) +{ + *ptr = __float2half(value); +} + +template <> +__device__ __forceinline__ void store_accum_value(nv_bfloat16* ptr, float value) +{ + *ptr = __float2bfloat16(value); +} + +template +struct vec2_type; + +template <> +struct vec2_type +{ + using type = half2; +}; + +template <> +struct vec2_type +{ + using type = nv_bfloat162; +}; + +template +using vec2_t = typename vec2_type::type; + +template +__device__ __forceinline__ uint4 dequantize_s4_to_x2(uint32_t const& source); + +template <> +__device__ __forceinline__ uint4 dequantize_s4_to_x2(uint32_t const& source) +{ + return dequantize_s4_to_fp16x2(source); +} + +template <> +__device__ __forceinline__ uint4 dequantize_s4_to_x2(uint32_t const& source) +{ + return dequantize_s4_to_bf16x2(source); +} + +template +__device__ __forceinline__ void apply_zero_and_scale(uint4& values_raw, const uint4& scales_raw, const uint4& zeros_raw); + +template <> +__device__ __forceinline__ void apply_zero_and_scale(uint4& values_raw, const uint4& scales_raw, const uint4& zeros_raw) { static constexpr uint32_t ZERO = 0x0; + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(values_raw.x) : "r"(values_raw.x), "r"(zeros_raw.x)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(values_raw.x) : "r"(values_raw.x), "r"(scales_raw.x), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(values_raw.y) : "r"(values_raw.y), "r"(zeros_raw.y)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(values_raw.y) : "r"(values_raw.y), "r"(scales_raw.y), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(values_raw.z) : "r"(values_raw.z), "r"(zeros_raw.z)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(values_raw.z) : "r"(values_raw.z), "r"(scales_raw.z), "r"(ZERO)); + asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(values_raw.w) : "r"(values_raw.w), "r"(zeros_raw.w)); + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(values_raw.w) : "r"(values_raw.w), "r"(scales_raw.w), "r"(ZERO)); +} + +template <> +__device__ __forceinline__ void apply_zero_and_scale(uint4& values_raw, const uint4& scales_raw, const uint4& zeros_raw) +{ + auto* values = reinterpret_cast(&values_raw); + auto* scales = reinterpret_cast(&scales_raw); + auto* zeros = reinterpret_cast(&zeros_raw); +#pragma unroll + for (int i = 0; i < 4; ++i) + { + values[i] = __hmul2(__hsub2(values[i], zeros[i]), scales[i]); + } +} + +template +__device__ __forceinline__ void mma_m16n8(float* C_warp, scalar_t* A_shared_warp, scalar_t* B_shared_warp) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + static_assert(std::is_same_v, "BFloat16 AWQ GEMM requires CUDA_ARCH >= 800."); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp))[0]), "=f"(((float *)(C_warp))[1]), "=f"(((float *)(C_warp))[2]), "=f"(((float *)(C_warp))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + 0))[0]), "f"(((float *)(C_warp))[0]), "f"(((float *)(C_warp))[1]), "f"(((float *)(C_warp))[2]), "f"(((float *)(C_warp))[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=f"(((float *)(C_warp))[0]), "=f"(((float *)(C_warp))[1]), "=f"(((float *)(C_warp))[2]), "=f"(((float *)(C_warp))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + 0))[1]), "f"(((float *)(C_warp))[0]), "f"(((float *)(C_warp))[1]), "f"(((float *)(C_warp))[2]), "f"(((float *)(C_warp))[3])); +#else + if constexpr (std::is_same_v) + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp))[0]), "=f"(((float *)(C_warp))[1]), "=f"(((float *)(C_warp))[2]), "=f"(((float *)(C_warp))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + 0))[0]), "r"(((unsigned *)(B_shared_warp + 0))[1]), "f"(((float *)(C_warp))[0]), "f"(((float *)(C_warp))[1]), "f"(((float *)(C_warp))[2]), "f"(((float *)(C_warp))[3])); + } + else + { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32" + "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" + : "=f"(((float *)(C_warp))[0]), "=f"(((float *)(C_warp))[1]), "=f"(((float *)(C_warp))[2]), "=f"(((float *)(C_warp))[3]) + : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + 0))[0]), "r"(((unsigned *)(B_shared_warp + 0))[1]), "f"(((float *)(C_warp))[0]), "f"(((float *)(C_warp))[1]), "f"(((float *)(C_warp))[2]), "f"(((float *)(C_warp))[3])); + } +#endif +} + +static void validate_bf16_device(torch::Device device); + +namespace { + +bool fused_splitk_reduce_enabled() +{ + const char* disable_value = std::getenv("GPTQMODEL_AWQ_DISABLE_FUSED_SPLITK_REDUCE"); + return disable_value == nullptr || std::atoi(disable_value) == 0; +} + +template +__global__ void reduce_splitk_fp32_to_output_kernel( + const float* __restrict__ partials, + output_t* __restrict__ out, + int split_k_iters, + int total_elements) +{ + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + for (; idx < total_elements; idx += stride) + { + float acc = 0.0f; +#pragma unroll 8 + for (int split_idx = 0; split_idx < split_k_iters; ++split_idx) + { + acc += partials[split_idx * total_elements + idx]; + } + store_accum_value(out + idx, acc); + } +} + +template +torch::Tensor fused_reduce_splitk_fp32_to_output(torch::Tensor out_feats_tensor, at::ScalarType output_dtype) +{ + auto options = torch::TensorOptions().dtype(output_dtype).device(out_feats_tensor.device()); + at::Tensor result = torch::empty({out_feats_tensor.size(1), out_feats_tensor.size(2)}, options); + const int total_elements = static_cast(result.numel()); + const int split_k_iters = static_cast(out_feats_tensor.size(0)); + const int threads = 256; + const int blocks = std::max(1, std::min((total_elements + threads - 1) / threads, 4096)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + auto partials = reinterpret_cast(out_feats_tensor.data_ptr()); + output_t* out = nullptr; + if constexpr (std::is_same_v) + { + out = reinterpret_cast(result.data_ptr()); + } + else if constexpr (std::is_same_v) + { + out = reinterpret_cast(result.data_ptr()); + } + else + { + out = reinterpret_cast(result.data_ptr()); + } + + reduce_splitk_fp32_to_output_kernel<<>>( + partials, out, split_k_iters, total_elements); + return result; +} + +torch::Tensor maybe_fused_reduce_splitk_fp32_to_output(torch::Tensor out_feats_tensor, at::ScalarType output_dtype) +{ + if (!fused_splitk_reduce_enabled()) + { + return out_feats_tensor.sum(0).to(output_dtype); + } + + switch (output_dtype) + { + case at::kHalf: + return fused_reduce_splitk_fp32_to_output(out_feats_tensor, output_dtype); + case at::kBFloat16: + return fused_reduce_splitk_fp32_to_output(out_feats_tensor, output_dtype); + case at::kFloat: + return fused_reduce_splitk_fp32_to_output(out_feats_tensor, output_dtype); + default: + return out_feats_tensor.sum(0).to(output_dtype); + } +} + +} // namespace + +template +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, scalar_t* __restrict__ A, int* __restrict__ B, scalar_t* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, output_t* __restrict__ C) +{ float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (128 + 8)]; + __shared__ scalar_t A_shared[16 * (32 + 8)]; + __shared__ scalar_t B_shared[32 * (128 + 8)]; int j_factors1 = ((OC + 128 - 1) / 128); int blockIdx_x = 0; int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); - half A_shared_warp[8]; - half B_shared_warp[32]; + scalar_t A_shared_warp[8]; + scalar_t B_shared_warp[32]; for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) { for (int i = 0; i < 8; ++i) { C_warp[(j_0_4_init * 8) + i] = 0.0; @@ -57,7 +262,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id // bool wb_C_flag = (threadIdx.x / 4) < M; - half* A_ptr = A + scalar_t* A_ptr = A + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + (((int)threadIdx.x) % (32 / 8)) * 8; @@ -68,12 +273,12 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i + (((int)threadIdx.x) % (128 / 8)) * 1; // Why * 1 in the above line? - half* A_shared_ptr = A_shared + scalar_t* A_shared_ptr = A_shared + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + (((int)threadIdx.x) % (32 / 8) ) * 8; - half* B_shared_ptr = B_shared + scalar_t* B_shared_ptr = B_shared + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8) + (((int)threadIdx.x) / (128 / 8)) * (128 + 8) + (((int)threadIdx.x) % (128 / 8)) * 8; @@ -82,11 +287,11 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i + (((int)blockIdx_y) % j_factors1) * (128 / 8) + ((int)threadIdx.x) % (128 / 8); - half* scaling_factors_ptr = scaling_factors + scalar_t* scaling_factors_ptr = scaling_factors + (((int)blockIdx_y) % j_factors1) * (128) + (((int)threadIdx.x) % (128 / 8)) * 8; - half* C_ptr = C + output_t* C_ptr = C + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + (((int)blockIdx_y) % j_factors1) * 128 + ((int)threadIdx.y) * 64 @@ -110,7 +315,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_zero = dequantize_s4_to_x2(zeros_loaded); uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); /* if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ @@ -128,20 +333,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); - // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + uint4 B_loaded_values = dequantize_s4_to_x2(B_loaded); + apply_zero_and_scale(B_loaded_values, B_loaded_scale, B_loaded_zero); /* if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); @@ -149,7 +342,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i */ // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16; + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_values; } __syncthreads(); @@ -188,55 +381,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i } } for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#endif + mma_m16n8(C_warp + (j_0_4 * 8), A_shared_warp, B_shared_warp + (j_0_4 * 8)); + mma_m16n8(C_warp + ((j_0_4 * 8) + 4), A_shared_warp, B_shared_warp + ((j_0_4 * 8) + 4)); } } } @@ -247,22 +393,19 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; if (row_offset < M) { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + store_accum_value(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2, C_warp[(ax1_0_1 * 8) + local_id]); } } } } -__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) +template +__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, scalar_t* __restrict__ A, int* __restrict__ B, scalar_t* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, output_t* __restrict__ C) { - static constexpr uint32_t ZERO = 0x0; float C_warp[32]; - __shared__ half A_shared[16 * (32 + 8)]; - __shared__ half B_shared[32 * (64 + 8)]; - - __shared__ half scaling_factors_shared[64]; - __shared__ half zeros_shared[64]; + __shared__ scalar_t A_shared[16 * (32 + 8)]; + __shared__ scalar_t B_shared[32 * (64 + 8)]; int j_factors1 = ((OC + 64 - 1) / 64); @@ -270,8 +413,8 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1); int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1); - half A_shared_warp[8]; - half B_shared_warp[16]; + scalar_t A_shared_warp[8]; + scalar_t B_shared_warp[16]; for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) { for (int i = 0; i < 8; ++i) { C_warp[(j_0_4_init * 8) + i] = 0.0; @@ -281,52 +424,47 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in static constexpr int row_stride_warp = 32 * 8 / 32; static constexpr int row_stride = 2 * 32 * 8 / 64; bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64; - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 - bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; // threadIdx.y is warp_id - // bool wb_C_flag = (threadIdx.x / 4) < M; + bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M; - half* A_ptr = A + scalar_t* A_ptr = A + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC + (((int)threadIdx.x) % (32 / 8)) * 8; - + int* B_ptr = B + ((int)threadIdx.y) * (OC / 8) * 4 + (((int)threadIdx.x) / (64 / 8)) * (OC / 8) + (((int)blockIdx_y) % j_factors1) * (64 / 8) + (((int)threadIdx.x) % (64 / 8)) * 1; -// Why * 1 in the above line? - - half* A_shared_ptr = A_shared - + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + + scalar_t* A_shared_ptr = A_shared + + ((int)threadIdx.y) * row_stride_warp * (32 + 8) + (((int)threadIdx.x) / (32 / 8)) * (32 + 8) + (((int)threadIdx.x) % (32 / 8) ) * 8; - half* B_shared_ptr = B_shared + scalar_t* B_shared_ptr = B_shared + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8) + (((int)threadIdx.x) / (64 / 8)) * (64 + 8) + (((int)threadIdx.x) % (64 / 8)) * 8; - + int* zeros_ptr = zeros + (((int)blockIdx_y) % j_factors1) * (64 / 8) + ((int)threadIdx.x) % (64 / 8); - - half* scaling_factors_ptr = scaling_factors - + (((int)blockIdx_y) % j_factors1) * (64) + + scalar_t* scaling_factors_ptr = scaling_factors + + (((int)blockIdx_y) % j_factors1) * (64) + (((int)threadIdx.x) % (64 / 8)) * 8; - half* C_ptr = C - + static_cast(blockIdx_z) * M * OC // blockIdz.x -> split_k dim + output_t* C_ptr = C + + static_cast(blockIdx_z) * M * OC + (((int)blockIdx_y) % j_factors1) * 64 + ((int)threadIdx.y) * 32 + (((int)threadIdx.x) % 4) * 2; - // preload s.f. and zeros int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters; if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1; for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) { int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z; __syncthreads(); - // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16 if (ld_A_flag) { *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32)); @@ -336,52 +474,20 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0); } - // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) { uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8)); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_zero = dequantize_s4_to_x2(zeros_loaded); uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC)); - /* - if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){ - printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w); - } - */ - // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0); int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8); for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) { - - // B: 32 x 136 (128+8) float16 - // each warp: 32 x 4 - // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4 - // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8))); - // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8)); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8); - - // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8); - // - zero and * scale - // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale. - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); - /* - if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){ - printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w); - } - */ - - // write back - *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16; + uint4 B_loaded_values = dequantize_s4_to_x2(B_loaded); + apply_zero_and_scale(B_loaded_values, B_loaded_scale, B_loaded_zero); + *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_values; } __syncthreads(); - for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) + for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) { { unsigned int addr; @@ -397,88 +503,37 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in : "r"(addr) ); } - - for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) + for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) { - { - unsigned int addr; - asm volatile( - "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" - : "=r"(addr) - : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) - ); - asm volatile( - "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" - "{%0, %1, %2, %3}, [%4];\n" - : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) - : "r"(addr) - ); - } + unsigned int addr; + asm volatile( + "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n" + : "=r"(addr) + : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8)))) + ); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + "{%0, %1, %2, %3}, [%4];\n" + : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3]) + : "r"(addr) + ); } - - for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) - { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - - { - asm volatile( - "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#else - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3])); - } - { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" - "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n" - : "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]) - : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])); - } -#endif + for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) + { + mma_m16n8(C_warp + (j_0_4 * 8), A_shared_warp, B_shared_warp + (j_0_4 * 8)); + mma_m16n8(C_warp + ((j_0_4 * 8) + 4), A_shared_warp, B_shared_warp + ((j_0_4 * 8) + 4)); } } } -// TODO: Shang: Hoist loop invariance. for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) { for (int local_id = 0; local_id < 8; ++local_id) { int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8; if (row_offset < M) { - *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]); + store_accum_value(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2, C_warp[(ax1_0_1 * 8) + local_id]); } } } @@ -725,13 +780,14 @@ __global__ void __launch_bounds__(128) gemmv2_forward_4bit_cuda_m128n64k32(int s } } -// Dequantization to fp16 +// Dequantization to fp16/bf16 // kernel // Source - https://github.com/compressa-ai/AutoAWQ/blob/6673333456b8871522b11a7fb110de612edfdf95/awq_cuda/quantization/gemm_cuda_gen.cu#L32C1-L32C1 +template __global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, // 4096x64 4096 rows 64 cols - half* __restrict__ scaling_factors, // 32x512 32 rows 512 cols + scalar_t* __restrict__ scaling_factors, // 32x512 32 rows 512 cols int* __restrict__ zeros, // 32x64 32 rows 64 cols - half* __restrict__ C, // 4096x512 4096 rows 512 cols + scalar_t* __restrict__ C, // 4096x512 4096 rows 512 cols int G, int in_c, int out_c) @@ -745,19 +801,18 @@ __global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, // int j_factors1 = 4; int row_stride2 = 4; int split_k_iters = 1; - static constexpr uint32_t ZERO = 0x0; - half B_shared[32 * (128 + 8)]; + scalar_t B_shared[32 * (128 + 8)]; - half* B_shared_ptr2 = B_shared; + scalar_t* B_shared_ptr2 = B_shared; - half B_shared_warp[32]; + scalar_t B_shared_warp[32]; int OC = 512; int N = blockDim.x * gridDim.x; // 2 int col = (blockIdx.x * blockDim.x + threadIdx.x); int row = blockIdx.y * blockDim.y + threadIdx.y; int index1 = 8 * col + 8 * row * N; // + i (<8) - half* C_ptr2 = C + index1; + scalar_t* C_ptr2 = C + index1; int index2 = col + row * N; int* B_ptr2 = B + index2; @@ -765,26 +820,19 @@ __global__ void __launch_bounds__(64) dequantize_weights(int* __restrict__ B, // int index3 = col + (int)(row / G) * N; int* zeros_ptr2 = zeros + index3; int index4 = 8 * col + (int)(row / G) * N * 8; // + i (<8) - half* scaling_factors_ptr2 = scaling_factors + index4; + scalar_t* scaling_factors_ptr2 = scaling_factors + index4; uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); - uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); + uint4 B_loaded_zero = dequantize_s4_to_x2(zeros_loaded); uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); int j=0; uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); - uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); - asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); - asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); + uint4 B_loaded_values = dequantize_s4_to_x2(B_loaded); + apply_zero_and_scale(B_loaded_values, B_loaded_scale, B_loaded_zero); - *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; + *(uint4*)(B_shared_ptr2 + j) = B_loaded_values; for (int i=0; i<8; ++i) { *(C_ptr2 + i) = B_shared[i]; @@ -1119,6 +1167,15 @@ torch::Tensor dequantize_weights_cuda( int num_experts = _kernel.dim() == 2 ? 1 : _kernel.size(0); int out_c = qout_c * 8; int G = in_c / (_kernel.dim() == 2 ? _scaling_factors.size(0) : _scaling_factors.size(1)); + TORCH_CHECK(_kernel.is_cuda(), "AWQ dequantize expects CUDA packed weights."); + TORCH_CHECK(_scaling_factors.is_cuda(), "AWQ dequantize expects CUDA scales."); + TORCH_CHECK(_zeros.is_cuda(), "AWQ dequantize expects CUDA zero-points."); + TORCH_CHECK(_kernel.scalar_type() == at::kInt, "AWQ dequantize packed weights must be int32."); + TORCH_CHECK(_zeros.scalar_type() == at::kInt, "AWQ dequantize packed zero-points must be int32."); + TORCH_CHECK( + _scaling_factors.scalar_type() == at::kHalf || _scaling_factors.scalar_type() == at::kBFloat16, + "AWQ dequantize only supports float16 and bfloat16 scales." + ); int x_thread = thx; int y_thread = thy; @@ -1143,25 +1200,34 @@ torch::Tensor dequantize_weights_cuda( const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); - auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); - at::Tensor _de_kernel; + auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); + at::Tensor _de_kernel; if (num_experts == 1) { _de_kernel = torch::empty({in_c, out_c}, options); } else { _de_kernel = torch::empty({num_experts, in_c, out_c}, options); } - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - - dim3 num_blocks(x_blocks, y_blocks, num_experts); - dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096 + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); - dequantize_weights<<>>(kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); + dim3 num_blocks(x_blocks, y_blocks, num_experts); + dim3 threads_per_block(x_thread, y_thread); // col, row 64x4096 + if (_scaling_factors.scalar_type() == at::kBFloat16) + { + validate_bf16_device(_scaling_factors.device()); + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + dequantize_weights<<>>(kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); + } + else + { + auto de_kernel = reinterpret_cast(_de_kernel.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + dequantize_weights<<>>(kernel, scaling_factors, zeros, de_kernel, G, in_c, out_c); + } - return _de_kernel; + return _de_kernel; } // in_feats: M, IC [float16] @@ -1230,59 +1296,209 @@ torch::Tensor gemmv2_forward_cuda( // zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b] // assume that batch_size < 16 for now -torch::Tensor gemm_forward_cuda( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters) +template +scalar_t* tensor_ptr(torch::Tensor& tensor); + +template <> +half* tensor_ptr(torch::Tensor& tensor) { - int num_in_feats = _in_feats.size(0); - int num_in_channels = _in_feats.size(1); - const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + return reinterpret_cast(tensor.data_ptr()); +} - auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); - at::Tensor _out_feats = torch::empty({split_k_iters, num_in_feats, _kernel.size(1) * 8}, options); - int num_out_feats = _out_feats.size(-2); - int num_out_channels = _out_feats.size(-1); +template <> +nv_bfloat16* tensor_ptr(torch::Tensor& tensor) +{ + return reinterpret_cast(tensor.data_ptr()); +} - auto in_feats = reinterpret_cast(_in_feats.data_ptr()); - auto kernel = reinterpret_cast(_kernel.data_ptr()); - auto out_feats = reinterpret_cast(_out_feats.data_ptr()); - auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); - auto zeros = reinterpret_cast(_zeros.data_ptr()); - int group_size = num_in_channels / _scaling_factors.size(0); +static void validate_awq_gemm_args( + const torch::Tensor& in_feats, + const torch::Tensor& kernel, + const torch::Tensor& scaling_factors, + const torch::Tensor& zeros, + int group_size, + int num_out_channels) +{ + TORCH_CHECK(in_feats.is_cuda(), "AWQ GEMM expects CUDA input activations."); + TORCH_CHECK(kernel.is_cuda(), "AWQ GEMM expects CUDA packed weights."); + TORCH_CHECK(scaling_factors.is_cuda(), "AWQ GEMM expects CUDA scales."); + TORCH_CHECK(zeros.is_cuda(), "AWQ GEMM expects CUDA zero-points."); + TORCH_CHECK(in_feats.is_contiguous(), "AWQ GEMM expects contiguous activations."); + TORCH_CHECK(kernel.is_contiguous(), "AWQ GEMM expects contiguous packed weights."); + TORCH_CHECK(scaling_factors.is_contiguous(), "AWQ GEMM expects contiguous scales."); + TORCH_CHECK(zeros.is_contiguous(), "AWQ GEMM expects contiguous zero-points."); + TORCH_CHECK(kernel.scalar_type() == at::kInt, "AWQ GEMM packed weights must be int32."); + TORCH_CHECK(zeros.scalar_type() == at::kInt, "AWQ GEMM packed zero-points must be int32."); + TORCH_CHECK( + in_feats.scalar_type() == at::kHalf || in_feats.scalar_type() == at::kBFloat16, + "AWQ GEMM only supports float16 and bfloat16 activations." + ); + TORCH_CHECK( + scaling_factors.scalar_type() == at::kHalf || scaling_factors.scalar_type() == at::kBFloat16, + "AWQ GEMM only supports float16 and bfloat16 scales." + ); + TORCH_CHECK( + in_feats.scalar_type() == scaling_factors.scalar_type(), + "AWQ GEMM expects activations and scales to share the same dtype before launch." + ); + TORCH_CHECK(num_out_channels % 64 == 0, "OC is not multiple of cta_N = 64"); + TORCH_CHECK(num_out_channels % 8 == 0, "OC is not multiple of pack_num = 8"); + TORCH_CHECK(group_size % 32 == 0, "Group size should be a multiple of 32"); + TORCH_CHECK(num_out_channels % group_size == 0, "OC is not multiple of Group size"); +} + +static void validate_bf16_device(torch::Device device) +{ + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, device.index()); + TORCH_CHECK(prop.major >= 8, "BFloat16 AWQ GEMM requires compute capability >= 8.0."); +} + +template +torch::Tensor launch_gemm_forward_cuda( + torch::Tensor in_feats_tensor, + torch::Tensor kernel_tensor, + torch::Tensor scaling_factors_tensor, + torch::Tensor zeros_tensor, + int split_k_iters, + at::ScalarType output_dtype) +{ + int num_in_feats = in_feats_tensor.size(0); + int num_in_channels = in_feats_tensor.size(1); + auto options = torch::TensorOptions().dtype(in_feats_tensor.scalar_type()).device(in_feats_tensor.device()); + at::Tensor out_feats_tensor = torch::empty({split_k_iters, num_in_feats, kernel_tensor.size(1) * 8}, options); + int num_out_feats = out_feats_tensor.size(-2); + int num_out_channels = out_feats_tensor.size(-1); + int group_size = num_in_channels / scaling_factors_tensor.size(0); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if (num_out_channels % 64 != 0) - throw std::invalid_argument("OC is not multiple of cta_N = 64"); - if (num_out_channels % 8 != 0) - throw std::invalid_argument("OC is not multiple of pack_num = 8"); - if (group_size % 32 != 0) - throw std::invalid_argument("Group size should be a multiple of 32"); - if (num_out_channels % group_size != 0) - throw std::invalid_argument("OC is not multiple of Group size"); + validate_awq_gemm_args(in_feats_tensor, kernel_tensor, scaling_factors_tensor, zeros_tensor, group_size, num_out_channels); + if constexpr (std::is_same_v) + { + validate_bf16_device(in_feats_tensor.device()); + } + + auto in_feats = tensor_ptr(in_feats_tensor); + auto kernel = reinterpret_cast(kernel_tensor.data_ptr()); + auto out_feats = tensor_ptr(out_feats_tensor); + auto scaling_factors = tensor_ptr(scaling_factors_tensor); + auto zeros = reinterpret_cast(zeros_tensor.data_ptr()); if (num_out_channels % 128 == 0) { int j_factors1 = num_out_channels / 128 / 1; dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] dim3 threads_per_block(32, 2); - gemm_forward_4bit_cuda_m16n128k32<<>>( + gemm_forward_4bit_cuda_m16n128k32<<>>( group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); } else if (num_out_channels % 64 == 0) { - int j_factors1 = num_out_channels / 64 / 1; + int j_factors1 = num_out_channels / 64 / 1; dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); - - // threadIdx.x: 32 - // threadIdx.y: i_factors[2] * j_factors[2] dim3 threads_per_block(32, 2); - gemm_forward_4bit_cuda_m16n64k32<<>>( + gemm_forward_4bit_cuda_m16n64k32<<>>( group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); } - return _out_feats.sum(0); + + auto result = out_feats_tensor.sum(0); + return result.scalar_type() == output_dtype ? result : result.to(output_dtype); +} + +template +torch::Tensor launch_gemm_forward_cuda_fp32_reduce( + torch::Tensor in_feats_tensor, + torch::Tensor kernel_tensor, + torch::Tensor scaling_factors_tensor, + torch::Tensor zeros_tensor, + int split_k_iters, + at::ScalarType output_dtype) +{ + int num_in_feats = in_feats_tensor.size(0); + int num_in_channels = in_feats_tensor.size(1); + auto options = torch::TensorOptions().dtype(torch::kFloat32).device(in_feats_tensor.device()); + at::Tensor out_feats_tensor = torch::empty({split_k_iters, num_in_feats, kernel_tensor.size(1) * 8}, options); + int num_out_feats = out_feats_tensor.size(-2); + int num_out_channels = out_feats_tensor.size(-1); + int group_size = num_in_channels / scaling_factors_tensor.size(0); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + validate_awq_gemm_args(in_feats_tensor, kernel_tensor, scaling_factors_tensor, zeros_tensor, group_size, num_out_channels); + if constexpr (std::is_same_v) + { + validate_bf16_device(in_feats_tensor.device()); + } + + auto in_feats = tensor_ptr(in_feats_tensor); + auto kernel = reinterpret_cast(kernel_tensor.data_ptr()); + auto out_feats = reinterpret_cast(out_feats_tensor.data_ptr()); + auto scaling_factors = tensor_ptr(scaling_factors_tensor); + auto zeros = reinterpret_cast(zeros_tensor.data_ptr()); + + if (num_out_channels % 128 == 0) + { + int j_factors1 = num_out_channels / 128 / 1; + dim3 num_blocks((num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n128k32<<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + } + else if (num_out_channels % 64 == 0) + { + int j_factors1 = num_out_channels / 64 / 1; + dim3 num_blocks(1 * (num_out_feats + 16 - 1) / 16 * j_factors1 * split_k_iters); + dim3 threads_per_block(32, 2); + gemm_forward_4bit_cuda_m16n64k32<<>>( + group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats); + } + + return maybe_fused_reduce_splitk_fp32_to_output(out_feats_tensor, output_dtype); +} + +torch::Tensor gemm_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters, + bool fp32_accum) +{ + if (fp32_accum) + { + return gemm_forward_cuda_fp32_reduce(_in_feats, _kernel, _scaling_factors, _zeros, split_k_iters); + } + + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + at::ScalarType output_dtype = _in_feats.scalar_type(); + torch::Tensor in_feats = _in_feats; + torch::Tensor scaling_factors = _scaling_factors; + torch::Tensor kernel = _kernel; + torch::Tensor zeros = _zeros; + + if (in_feats.scalar_type() == at::kBFloat16) + { + return launch_gemm_forward_cuda(in_feats, kernel, scaling_factors, zeros, split_k_iters, output_dtype); + } + return launch_gemm_forward_cuda(in_feats, kernel, scaling_factors, zeros, split_k_iters, output_dtype); +} + +torch::Tensor gemm_forward_cuda_fp32_reduce( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(_in_feats)); + at::ScalarType output_dtype = _in_feats.scalar_type(); + torch::Tensor in_feats = _in_feats; + torch::Tensor scaling_factors = _scaling_factors; + torch::Tensor kernel = _kernel; + torch::Tensor zeros = _zeros; + + if (in_feats.scalar_type() == at::kBFloat16) + { + return launch_gemm_forward_cuda_fp32_reduce(in_feats, kernel, scaling_factors, zeros, split_k_iters, output_dtype); + } + return launch_gemm_forward_cuda_fp32_reduce(in_feats, kernel, scaling_factors, zeros, split_k_iters, output_dtype); } diff --git a/gptqmodel_ext/awq/torch_bind.cpp b/gptqmodel_ext/awq/torch_bind.cpp new file mode 100644 index 000000000..1ef91e191 --- /dev/null +++ b/gptqmodel_ext/awq/torch_bind.cpp @@ -0,0 +1,80 @@ +#include +#include + +#include "quantization/gemm_cuda.h" +#include "quantization/gemv_cuda.h" +#include "quantization_new/gemm/gemm_cuda.h" +#include "quantization_new/gemv/gemv_cuda.h" + +namespace { + +torch::Tensor gemm_forward_dispatch(torch::Tensor in_feats, torch::Tensor kernel, + torch::Tensor scaling_factors, torch::Tensor zeros, + int64_t split_k_iters, bool fp32_accum) { + return gemm_forward_cuda(in_feats, kernel, scaling_factors, zeros, static_cast(split_k_iters), + fp32_accum); +} + +torch::Tensor gemm_forward_fp32_reduce_dispatch(torch::Tensor in_feats, torch::Tensor kernel, + torch::Tensor scaling_factors, torch::Tensor zeros, + int64_t split_k_iters) { + return gemm_forward_cuda_fp32_reduce(in_feats, kernel, scaling_factors, zeros, + static_cast(split_k_iters)); +} + +torch::Tensor gemmv2_forward_dispatch(torch::Tensor in_feats, torch::Tensor kernel, + torch::Tensor scaling_factors, torch::Tensor zeros, + int64_t group_size, int64_t split_k_iters) { + return gemmv2_forward_cuda(in_feats, kernel, scaling_factors, zeros, static_cast(group_size), + static_cast(split_k_iters)); +} + +torch::Tensor gemv_forward_dispatch(torch::Tensor in_feats, torch::Tensor kernel, + torch::Tensor scaling_factors, torch::Tensor zeros, + int64_t group_size) { + return gemv_forward_cuda(in_feats, kernel, scaling_factors, zeros, static_cast(group_size)); +} + +torch::Tensor gemm_fast_forward_prefill_dispatch(torch::Tensor in_feats, torch::Tensor kernel, + torch::Tensor scaling_factors, + torch::Tensor zeros) { + return gemm_forward_cuda_prefill(in_feats, kernel, scaling_factors, zeros); +} + +torch::Tensor gemv_fast_forward_decode_dispatch(torch::Tensor in_feats, torch::Tensor kernel, + torch::Tensor scaling_factors, torch::Tensor zeros, + int64_t m, int64_t n, int64_t k, + int64_t group_size) { + return gemv_forward_cuda_decode(in_feats, kernel, scaling_factors, zeros, static_cast(m), + static_cast(n), static_cast(k), + static_cast(group_size)); +} + +torch::Tensor dequantize_weights_dispatch(torch::Tensor kernel, torch::Tensor scaling_factors, + torch::Tensor zeros, int64_t split_k_iters, int64_t thx, + int64_t thy, bool dbg) { + return dequantize_weights_cuda(kernel, scaling_factors, zeros, static_cast(split_k_iters), + static_cast(thx), static_cast(thy), dbg); +} + +} // namespace + +TORCH_LIBRARY(gptqmodel_awq, m) { + m.def("gemm_forward(Tensor in_feats, Tensor kernel, Tensor scaling_factors, Tensor zeros, int split_k_iters, bool fp32_accum=False) -> Tensor"); + m.def("gemm_forward_fp32_reduce(Tensor in_feats, Tensor kernel, Tensor scaling_factors, Tensor zeros, int split_k_iters) -> Tensor"); + m.def("gemmv2_forward(Tensor in_feats, Tensor kernel, Tensor scaling_factors, Tensor zeros, int group_size, int split_k_iters) -> Tensor"); + m.def("gemv_forward(Tensor in_feats, Tensor kernel, Tensor scaling_factors, Tensor zeros, int group_size) -> Tensor"); + m.def("gemm_fast_forward_prefill(Tensor in_feats, Tensor kernel, Tensor scaling_factors, Tensor zeros) -> Tensor"); + m.def("gemv_fast_forward_decode(Tensor in_feats, Tensor kernel, Tensor scaling_factors, Tensor zeros, int m, int n, int k, int group_size) -> Tensor"); + m.def("dequantize_weights(Tensor kernel, Tensor scaling_factors, Tensor zeros, int split_k_iters, int thx, int thy, bool dbg) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_awq, CUDA, m) { + m.impl("gemm_forward", &gemm_forward_dispatch); + m.impl("gemm_forward_fp32_reduce", &gemm_forward_fp32_reduce_dispatch); + m.impl("gemmv2_forward", &gemmv2_forward_dispatch); + m.impl("gemv_forward", &gemv_forward_dispatch); + m.impl("gemm_fast_forward_prefill", &gemm_fast_forward_prefill_dispatch); + m.impl("gemv_fast_forward_decode", &gemv_fast_forward_decode_dispatch); + m.impl("dequantize_weights", &dequantize_weights_dispatch); +} diff --git a/gptqmodel_ext/cutlass_extensions/__init__.py b/gptqmodel_ext/cutlass_extensions/__init__.py index a903f4cee..b86fd2361 100644 --- a/gptqmodel_ext/cutlass_extensions/__init__.py +++ b/gptqmodel_ext/cutlass_extensions/__init__.py @@ -1 +1 @@ -# Cutlass extension helpers for GPTQModel +# Cutlass extension helpers for GPT-QModel diff --git a/gptqmodel_ext/cutlass_extensions/common.hpp b/gptqmodel_ext/cutlass_extensions/common.hpp index f2c1dcf69..606e57597 100644 --- a/gptqmodel_ext/cutlass_extensions/common.hpp +++ b/gptqmodel_ext/cutlass_extensions/common.hpp @@ -3,16 +3,17 @@ #include "cutlass/cutlass.h" #include #include "cuda_runtime.h" -#include + +#include /** * Helper function for checking CUTLASS errors */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - TORCH_CHECK(error == cutlass::Status::kSuccess, \ - cutlassGetStatusString(error)); \ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ } inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { diff --git a/gptqmodel_ext/cutlass_extensions/cute_utils.cuh b/gptqmodel_ext/cutlass_extensions/cute_utils.cuh index f61fe3ceb..116ce854d 100644 --- a/gptqmodel_ext/cutlass_extensions/cute_utils.cuh +++ b/gptqmodel_ext/cutlass_extensions/cute_utils.cuh @@ -1,7 +1,6 @@ #pragma once #include -#include namespace cute { //////////////////////////////////////////////////////////////////// diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp index 5c1d6e3f4..0a3c9e9cc 100644 --- a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp +++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp @@ -189,9 +189,9 @@ struct Sm90RowOrScalarBroadcastArray { } auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); - Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); for (int i = 0; i < size(tGS_gRow_flt); ++i) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { @@ -211,8 +211,8 @@ struct Sm90RowOrScalarBroadcastArray { begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop if (!params.row_broadcast) return; // Do not issue LDS when row is scalar - Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -241,9 +241,9 @@ struct Sm90RowOrScalarBroadcastArray { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); - Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem), + cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row_array[l]), make_shape(M,N,1), params.dRow); + cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + cute::Tensor sRow = make_tensor(make_smem_ptr(smem), make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) //// G2S: Gmem to Smem auto tiled_g2s = make_tiled_copy(Copy_Atom{}, @@ -251,16 +251,16 @@ struct Sm90RowOrScalarBroadcastArray { Stride<_0, _1>>{}, Layout<_1>{}); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); - Tensor tGS_gRow = thr_g2s.partition_S(gRow); - Tensor tGS_sRow = thr_g2s.partition_D(sRow); + cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow); + cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow); //// G2S: Coord auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); - Tensor tGS_cRow = thr_g2s.partition_S(cRow); + cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow); //// S2R: Smem to Reg - Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + cute::Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) return ConsumerStoreCallbacks( tGS_gRow, @@ -389,27 +389,35 @@ struct Sm90ColOrScalarBroadcastArray { CUTLASS_DEVICE void begin() { - Tensor pred = make_tensor(shape(tCgCol)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(pred); ++i) { - pred(i) = get<0>(tCcCol(i)) < m; - } - if (!params.col_broadcast) { fill(tCrCol, *(params.ptr_col_array[group])); return; } - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_if(pred, filter(tCgCol), filter(tCrCol)); + // tCgCol has layout (CPY,CPY_M,CPY_N,EPI_M,EPI_N) where CPY_N and + // EPI_N are stride-0 for the column broadcast. Slice those modes at + // index 0 to avoid redundant copies AND ensure pred/data consistency + static_assert(decltype(stride<2>(tCgCol))::value == 0, "Expected stride-0 CPY_N for col broadcast"); + static_assert(decltype(stride<4>(tCgCol))::value == 0, "Expected stride-0 EPI_N for col broadcast"); + + auto tCgCol_s = tCgCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M) + auto tCrCol_s = tCrCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M) + auto tCcCol_s = tCcCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M) + + cute::Tensor pred = make_tensor(shape(tCgCol_s)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol_s(i)) < m; + } + + copy_if(pred, tCgCol_s, tCrCol_s); } template CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { @@ -431,16 +439,16 @@ struct Sm90ColOrScalarBroadcastArray { auto [M, N, K, L] = args.problem_shape_mnkl; auto [m, n, k, l] = args.tile_coord_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col_array[l]), make_shape(M,N,1), params.dCol); + cute::Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) // Generate an identity tensor matching the shape of the global tensor and // partition the same way, this will be used to generate the predicate // tensor for loading - Tensor cCol = make_identity_tensor(mCol.shape()); - Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor cCol = make_identity_tensor(mCol.shape()); + cute::Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks( diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp index 58b1e8ff1..29e6ec41e 100644 --- a/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp +++ b/gptqmodel_ext/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp @@ -186,9 +186,9 @@ struct Sm90RowOrScalarBroadcast { } auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; - Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); - Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); - Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); + cute::Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + cute::Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + cute::Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride())); for (int i = 0; i < size(tGS_gRow_flt); ++i) { if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { @@ -208,8 +208,8 @@ struct Sm90RowOrScalarBroadcast { begin_loop(int epi_m, int epi_n) { if (epi_m == 0) { // Assumes M-major subtile loop if (!params.row_broadcast) return; // Do not issue LDS when row is scalar - Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); - Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); + cute::Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + cute::Tensor tSR_rRow_flt = filter_zeros(tSR_rRow); copy(tSR_sRow_flt, tSR_rRow_flt); } } @@ -238,9 +238,9 @@ struct Sm90RowOrScalarBroadcast { auto [m, n, k, l] = args.tile_coord_mnkl; using ThreadCount = decltype(size(args.tiled_copy)); - Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) - Tensor sRow = make_tensor(make_smem_ptr(smem), + cute::Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + cute::Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + cute::Tensor sRow = make_tensor(make_smem_ptr(smem), make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) //// G2S: Gmem to Smem auto tiled_g2s = make_tiled_copy(Copy_Atom{}, @@ -248,16 +248,16 @@ struct Sm90RowOrScalarBroadcast { Stride<_0, _1>>{}, Layout<_1>{}); auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); - Tensor tGS_gRow = thr_g2s.partition_S(gRow); - Tensor tGS_sRow = thr_g2s.partition_D(sRow); + cute::Tensor tGS_gRow = thr_g2s.partition_S(gRow); + cute::Tensor tGS_sRow = thr_g2s.partition_D(sRow); //// G2S: Coord auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}))); - Tensor tGS_cRow = thr_g2s.partition_S(cRow); + cute::Tensor tGS_cRow = thr_g2s.partition_S(cRow); //// S2R: Smem to Reg - Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + cute::Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + cute::Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) return ConsumerStoreCallbacks( tGS_gRow, @@ -382,27 +382,35 @@ struct Sm90ColOrScalarBroadcast { CUTLASS_DEVICE void begin() { - Tensor pred = make_tensor(shape(tCgCol)); - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(pred); ++i) { - pred(i) = get<0>(tCcCol(i)) < m; - } - if (!params.col_broadcast) { fill(tCrCol, *(params.ptr_col)); return; } - // Filter so we don't issue redundant copies over stride-0 modes - // (only works if 0-strides are in same location, which is by construction) - copy_if(pred, filter(tCgCol), filter(tCrCol)); + // tCgCol has layout (CPY,CPY_M,CPY_N,EPI_M,EPI_N) where CPY_N and + // EPI_N are stride-0 for the column broadcast. Slice those modes at + // index 0 to avoid redundant copies AND ensure pred/data consistency + static_assert(decltype(stride<2>(tCgCol))::value == 0, "Expected stride-0 CPY_N for col broadcast"); + static_assert(decltype(stride<4>(tCgCol))::value == 0, "Expected stride-0 EPI_N for col broadcast"); + + auto tCgCol_s = tCgCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M) + auto tCrCol_s = tCrCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M) + auto tCcCol_s = tCcCol(_,_,0,_,0); // (CPY,CPY_M,EPI_M) + + cute::Tensor pred = make_tensor(shape(tCgCol_s)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(pred); ++i) { + pred(i) = get<0>(tCcCol_s(i)) < m; + } + + copy_if(pred, tCgCol_s, tCrCol_s); } template CUTLASS_DEVICE Array visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { Array frg_col; - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + cute::Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { @@ -422,16 +430,16 @@ struct Sm90ColOrScalarBroadcast { get_consumer_store_callbacks(ConsumerStoreArgs const& args) { auto [M, N, K, L] = args.problem_shape_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); - Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + cute::Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); - Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) // Generate an identity tensor matching the shape of the global tensor and // partition the same way, this will be used to generate the predicate // tensor for loading - Tensor cCol = make_identity_tensor(mCol.shape()); - Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + cute::Tensor cCol = make_identity_tensor(mCol.shape()); + cute::Tensor tCcCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); return ConsumerStoreCallbacks( diff --git a/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index c43eea0a0..c2ddcea6d 100644 --- a/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/gptqmodel_ext/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -3,6 +3,14 @@ #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" +// This header is shared by both _C (unstable ABI) and _C_stable_libtorch +// (stable ABI) targets. When compiled under the stable ABI target, +// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we +// use torch::stable::Tensor instead. +#ifdef TORCH_TARGET_VERSION + #include +#endif + /* This file defines custom epilogues for fusing channel scales, token scales, bias, and activation zero-points onto a GEMM operation using the @@ -15,6 +23,12 @@ namespace vllm::c3x { +#ifdef TORCH_TARGET_VERSION +using TensorType = torch::stable::Tensor; +#else +using TensorType = torch::Tensor; +#endif + using namespace cute; template @@ -61,13 +75,13 @@ struct ScaledEpilogueBase { // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>, + 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>, + 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; template @@ -84,7 +98,7 @@ struct ScaledEpilogueBase { // from a tensor. It can handle both row and column, as well as row/column or // scalar cases. template - static auto args_from_tensor(torch::Tensor const& tensor) { + static auto args_from_tensor(TensorType const& tensor) { using Arguments = typename Descriptor::Arguments; auto* data_ptr = static_cast(tensor.data_ptr()); if constexpr (std::is_same_v> || @@ -100,7 +114,7 @@ struct ScaledEpilogueBase { // This overload handles the case where there might not be a tensor, in which // case a nullptr is passed and a constant (0) is used. template - static auto args_from_tensor(std::optional const& tensor) { + static auto args_from_tensor(std::optional const& tensor) { using Arguments = typename Descriptor::Arguments; auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; static_assert(std::is_same_v> || @@ -158,8 +172,8 @@ struct ScaledEpilogue cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); @@ -203,9 +217,9 @@ struct ScaledEpilogueBias cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp EVTComputeScaleB, Bias>; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - std::optional const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& azp_adj, + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken EVTComputeScaleB, Bias>; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - std::optional const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& azp_adj, + TensorType const& azp, + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); diff --git a/gptqmodel_ext/cutlass_extensions/torch_utils.hpp b/gptqmodel_ext/cutlass_extensions/torch_utils.hpp index a1ff933cc..7644b3bbb 100644 --- a/gptqmodel_ext/cutlass_extensions/torch_utils.hpp +++ b/gptqmodel_ext/cutlass_extensions/torch_utils.hpp @@ -1,6 +1,27 @@ #pragma once -#include +// This header is shared between _C (unstable ABI, used by machete) and +// _C_stable_libtorch (stable ABI, used by W4A8/sparse). TORCH_TARGET_VERSION +// is defined only for the stable target, so we switch includes and types +// accordingly. TorchTensor (not Tensor) avoids ambiguity with cute::Tensor. +#ifdef TORCH_TARGET_VERSION + #include + #include + #include + #include // for STD_TORCH_CHECK +using TorchTensor = torch::stable::Tensor; +using TorchHalf = torch::headeronly::Half; +using TorchBFloat16 = torch::headeronly::BFloat16; +using TorchScalarType = torch::headeronly::ScalarType; + #define TORCH_UTILS_CHECK STD_TORCH_CHECK +#else + #include +using TorchTensor = torch::Tensor; +using TorchHalf = c10::Half; +using TorchBFloat16 = c10::BFloat16; +using TorchScalarType = c10::ScalarType; + #define TORCH_UTILS_CHECK TORCH_CHECK +#endif #include "cute/layout.hpp" #include "cutlass/layout/matrix.h" @@ -55,35 +76,35 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f) { // If `tensor.dim() < rank(Stride{})`, the shape is padded with 1s and the extra // strides are set to be 0 or 1. template -static inline auto make_cute_layout(torch::Tensor const& tensor, +static inline auto make_cute_layout(TorchTensor const& tensor, std::string_view name = "tensor") { - TORCH_CHECK(tensor.dim() <= rank(Stride{})); - auto stride = cute::transform_with_idx( - Stride{}, [&](auto const& stride_ele, auto const& idx) { - using StrideEle = std::decay_t; - - if (idx < tensor.dim()) { - if constexpr (cute::is_static_v) { - TORCH_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", - name, ".stride(", idx, ") to be ", StrideEle::value); - return StrideEle{}; - } else { - if (tensor.size(idx) == 1) { - // use 0 stride for dim with size 1, this is easier for - // cute/cutlass to optimize (helps the TMA code flatten dims) - return StrideEle{0}; - } else { - return tensor.stride(idx); - } - } + TORCH_UTILS_CHECK(tensor.dim() <= rank(Stride{})); + auto stride = cute::transform_with_idx(Stride{}, [&](auto const& stride_ele, + auto const& idx) { + using StrideEle = std::decay_t; + + if (idx < tensor.dim()) { + if constexpr (cute::is_static_v) { + TORCH_UTILS_CHECK(StrideEle::value == tensor.stride(idx), "Expected ", + name, ".stride(", idx, ") to be ", StrideEle::value); + return StrideEle{}; + } else { + if (tensor.size(idx) == 1) { + // use 0 stride for dim with size 1, this is easier for + // cute/cutlass to optimize (helps the TMA code flatten dims) + return StrideEle{0}; } else { - // Extra strides are assumed to be 0 or 1 - if constexpr (cute::is_static_v) { - static_assert(StrideEle::value == 0 || StrideEle::value == 1); - } - return StrideEle{}; + return tensor.stride(idx); } - }); + } + } else { + // Extra strides are assumed to be 0 or 1 + if constexpr (cute::is_static_v) { + static_assert(StrideEle::value == 0 || StrideEle::value == 1); + } + return StrideEle{}; + } + }); auto shape = cute::make_shape_from_idx([&](auto const& idx) { if (idx < tensor.dim()) @@ -97,7 +118,7 @@ static inline auto make_cute_layout(torch::Tensor const& tensor, template static inline auto maybe_make_cute_layout( - std::optional const& tensor, + std::optional const& tensor, std::string_view name = "tensor") { using Layout = decltype(make_cute_layout(*tensor)); @@ -121,12 +142,12 @@ template using equivalent_cutlass_type_t = typename equivalent_cutlass_type::type; template <> -struct equivalent_cutlass_type { +struct equivalent_cutlass_type { using type = cutlass::half_t; }; template <> -struct equivalent_cutlass_type { +struct equivalent_cutlass_type { using type = cutlass::bfloat16_t; }; @@ -134,8 +155,8 @@ struct equivalent_cutlass_type { // equivalent_scalar_t (basically inverse of equivalent_cutlass_type) // -// Return a `c10::CppTypeToScalarType` compatible type, i.e. get the C++ from -// c10 that is equivalent to T, e.g.: `cutlass::half_t -> c10::Half` +// Return a type equivalent to T in the current Torch ABI, e.g.: +// `cutlass::half_t -> Half` template struct equivalent_scalar_type { using type = T; @@ -146,15 +167,20 @@ using equivalent_scalar_type_t = typename equivalent_scalar_type::type; template <> struct equivalent_scalar_type { - using type = c10::Half; + using type = TorchHalf; }; template <> struct equivalent_scalar_type { - using type = c10::BFloat16; + using type = TorchBFloat16; }; -// get equivalent c10::ScalarType tag from compile time type +// get the equivalent scalar-type tag from the compile-time type template -static inline constexpr c10::ScalarType equivalent_scalar_type_v = - c10::CppTypeToScalarType>::value; \ No newline at end of file +#ifdef TORCH_TARGET_VERSION +static inline constexpr TorchScalarType equivalent_scalar_type_v = + torch::headeronly::CppTypeToScalarType>::value; +#else +static inline constexpr TorchScalarType equivalent_scalar_type_v = + c10::CppTypeToScalarType>::value; +#endif diff --git a/gptqmodel_ext/cutlass_extensions/vllm_cutlass_library_extension.py b/gptqmodel_ext/cutlass_extensions/vllm_cutlass_library_extension.py index 34fb64c41..a6b4a9af9 100644 --- a/gptqmodel_ext/cutlass_extensions/vllm_cutlass_library_extension.py +++ b/gptqmodel_ext/cutlass_extensions/vllm_cutlass_library_extension.py @@ -5,6 +5,9 @@ from cutlass_library import * +if "enum_auto" not in globals(): + from enum import auto as enum_auto + # # Extend cutlass library with custom types, and missing values # diff --git a/gptqmodel_ext/exllamav2/ext_awq.cpp b/gptqmodel_ext/exllamav2/ext_awq.cpp index fb1128bc4..86fe7ac4b 100644 --- a/gptqmodel_ext/exllamav2/ext_awq.cpp +++ b/gptqmodel_ext/exllamav2/ext_awq.cpp @@ -1,39 +1,59 @@ #include "ext_common.h" +#include +#include #include "cuda/q_matrix.cuh" #include "cuda/q_gemm.cuh" -uintptr_t make_q_matrix( +int64_t make_q_matrix( torch::Tensor q_weight, - torch::Tensor q_perm, - torch::Tensor q_invperm, - torch::Tensor q_scale, - torch::Tensor q_scale_max, - torch::Tensor q_groups, - torch::Tensor gptq_qzeros, - torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx, + std::optional q_perm_opt, + std::optional q_invperm_opt, + std::optional q_scale_opt, + std::optional q_scale_max_opt, + std::optional q_groups_opt, + std::optional gptq_qzeros_opt, + std::optional gptq_scales_opt, + std::optional gptq_g_idx_opt, torch::Tensor temp_dq ) { + torch::Tensor q_perm = q_perm_opt.value_or(torch::Tensor()); + torch::Tensor q_invperm = q_invperm_opt.value_or(torch::Tensor()); + torch::Tensor q_scale = q_scale_opt.value_or(torch::Tensor()); + torch::Tensor q_scale_max = q_scale_max_opt.value_or(torch::Tensor()); + torch::Tensor q_groups = q_groups_opt.value_or(torch::Tensor()); + torch::Tensor gptq_qzeros = gptq_qzeros_opt.value_or(torch::Tensor()); + torch::Tensor gptq_scales = gptq_scales_opt.value_or(torch::Tensor()); + torch::Tensor gptq_g_idx = gptq_g_idx_opt.value_or(torch::Tensor()); + TORCH_CHECK_DTYPE(q_weight, kInt); - TORCH_CHECK_DTYPE_OPT(q_perm, kShort); - TORCH_CHECK_DTYPE_OPT(q_invperm, kShort); - TORCH_CHECK_DTYPE_OPT(q_scale, kInt); - TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); - TORCH_CHECK_DTYPE_OPT(q_groups, kShort); - TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); - TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); - TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); + TORCH_CHECK(!q_perm.defined() || q_perm.device().is_meta() || q_perm.dtype() == torch::kShort, + "q_perm is incorrect datatype, must be kShort"); + TORCH_CHECK(!q_invperm.defined() || q_invperm.device().is_meta() || q_invperm.dtype() == torch::kShort, + "q_invperm is incorrect datatype, must be kShort"); + TORCH_CHECK(!q_scale.defined() || q_scale.device().is_meta() || q_scale.dtype() == torch::kInt, + "q_scale is incorrect datatype, must be kInt"); + TORCH_CHECK(!q_scale_max.defined() || q_scale_max.device().is_meta() || q_scale_max.dtype() == torch::kHalf, + "q_scale_max is incorrect datatype, must be kHalf"); + TORCH_CHECK(!q_groups.defined() || q_groups.device().is_meta() || q_groups.dtype() == torch::kShort, + "q_groups is incorrect datatype, must be kShort"); + TORCH_CHECK(!gptq_qzeros.defined() || gptq_qzeros.device().is_meta() || gptq_qzeros.dtype() == torch::kInt, + "gptq_qzeros is incorrect datatype, must be kInt"); + TORCH_CHECK(!gptq_scales.defined() || gptq_scales.device().is_meta() || gptq_scales.dtype() == torch::kHalf, + "gptq_scales is incorrect datatype, must be kHalf"); + TORCH_CHECK(!gptq_g_idx.defined() || gptq_g_idx.device().is_meta() || gptq_g_idx.dtype() == torch::kInt, + "gptq_g_idx is incorrect datatype, must be kInt"); - TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); + TORCH_CHECK(!q_perm.defined() || !q_invperm.defined() || q_perm.size(0) == q_invperm.size(0), + "q_perm and q_invperm have incompatible shapes"); int device = q_weight.device().index(); int width = q_weight.size(1); int groups; int height; - if (!q_scale.device().is_meta()) + if (q_scale.defined() && !q_scale.device().is_meta()) { TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8); TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1); @@ -56,28 +76,28 @@ uintptr_t make_q_matrix( width, groups, (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), - q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), - q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), - q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), - q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), - gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), - gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), - gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), + (!q_perm.defined() || q_perm.device().is_meta()) ? NULL : (uint16_t*) q_perm.data_ptr(), + (!q_invperm.defined() || q_invperm.device().is_meta()) ? NULL : (uint16_t*) q_invperm.data_ptr(), + (!q_scale.defined() || q_scale.device().is_meta()) ? NULL : (uint32_t*) q_scale.data_ptr(), + (!q_scale_max.defined() || q_scale_max.device().is_meta()) ? NULL : (half*) q_scale_max.data_ptr(), + (!q_groups.defined() || q_groups.device().is_meta()) ? NULL : (uint16_t*) q_groups.data_ptr(), + (!gptq_qzeros.defined() || gptq_qzeros.device().is_meta()) ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), + (!gptq_scales.defined() || gptq_scales.device().is_meta()) ? NULL : (half*) gptq_scales.data_ptr(), + (!gptq_g_idx.defined() || gptq_g_idx.device().is_meta()) ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), (half*) temp_dq.data_ptr() ); - return reinterpret_cast(m); + return static_cast(reinterpret_cast(m)); } void gemm_half_q_half( torch::Tensor a, - uintptr_t b, - torch::Tensor c, + int64_t b, + torch::Tensor& c, bool force_cuda ) { - QMatrix* qm = reinterpret_cast(b); + QMatrix* qm = reinterpret_cast(static_cast(b)); TORCH_CHECK_DTYPE(a, kHalf); TORCH_CHECK_DTYPE(c, kHalf); @@ -101,8 +121,14 @@ void gemm_half_q_half( ); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +TORCH_LIBRARY(gptqmodel_exllamav2_awq, m) { - m.def("make_q_matrix_awq", &make_q_matrix, "make_q_matrix_awq"); - m.def("gemm_half_q_half_awq", &gemm_half_q_half, "gemm_half_q_half_awq"); -} \ No newline at end of file + m.def("make_q_matrix_awq(Tensor q_weight, Tensor? q_perm, Tensor? q_invperm, Tensor? q_scale, Tensor? q_scale_max, Tensor? q_groups, Tensor? gptq_qzeros, Tensor? gptq_scales, Tensor? gptq_g_idx, Tensor temp_dq) -> int"); + m.def("gemm_half_q_half_awq(Tensor a, int b, Tensor(a!) c, bool force_cuda=False) -> ()"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_exllamav2_awq, CUDA, m) +{ + m.impl("make_q_matrix_awq", &make_q_matrix); + m.impl("gemm_half_q_half_awq", &gemm_half_q_half); +} diff --git a/gptqmodel_ext/exllamav2/ext_gptq.cpp b/gptqmodel_ext/exllamav2/ext_gptq.cpp index eec2d217c..f43ec2831 100644 --- a/gptqmodel_ext/exllamav2/ext_gptq.cpp +++ b/gptqmodel_ext/exllamav2/ext_gptq.cpp @@ -1,39 +1,59 @@ #include "ext_common.h" +#include +#include #include "cuda/q_matrix.cuh" #include "cuda/q_gemm.cuh" -uintptr_t make_q_matrix( +int64_t make_q_matrix( torch::Tensor q_weight, - torch::Tensor q_perm, - torch::Tensor q_invperm, - torch::Tensor q_scale, - torch::Tensor q_scale_max, - torch::Tensor q_groups, - torch::Tensor gptq_qzeros, - torch::Tensor gptq_scales, - torch::Tensor gptq_g_idx, + std::optional q_perm_opt, + std::optional q_invperm_opt, + std::optional q_scale_opt, + std::optional q_scale_max_opt, + std::optional q_groups_opt, + std::optional gptq_qzeros_opt, + std::optional gptq_scales_opt, + std::optional gptq_g_idx_opt, torch::Tensor temp_dq ) { + torch::Tensor q_perm = q_perm_opt.value_or(torch::Tensor()); + torch::Tensor q_invperm = q_invperm_opt.value_or(torch::Tensor()); + torch::Tensor q_scale = q_scale_opt.value_or(torch::Tensor()); + torch::Tensor q_scale_max = q_scale_max_opt.value_or(torch::Tensor()); + torch::Tensor q_groups = q_groups_opt.value_or(torch::Tensor()); + torch::Tensor gptq_qzeros = gptq_qzeros_opt.value_or(torch::Tensor()); + torch::Tensor gptq_scales = gptq_scales_opt.value_or(torch::Tensor()); + torch::Tensor gptq_g_idx = gptq_g_idx_opt.value_or(torch::Tensor()); + TORCH_CHECK_DTYPE(q_weight, kInt); - TORCH_CHECK_DTYPE_OPT(q_perm, kShort); - TORCH_CHECK_DTYPE_OPT(q_invperm, kShort); - TORCH_CHECK_DTYPE_OPT(q_scale, kInt); - TORCH_CHECK_DTYPE_OPT(q_scale_max, kHalf); - TORCH_CHECK_DTYPE_OPT(q_groups, kShort); - TORCH_CHECK_DTYPE_OPT(gptq_qzeros, kInt); - TORCH_CHECK_DTYPE_OPT(gptq_scales, kHalf); - TORCH_CHECK_DTYPE_OPT(gptq_g_idx, kInt); + TORCH_CHECK(!q_perm.defined() || q_perm.device().is_meta() || q_perm.dtype() == torch::kShort, + "q_perm is incorrect datatype, must be kShort"); + TORCH_CHECK(!q_invperm.defined() || q_invperm.device().is_meta() || q_invperm.dtype() == torch::kShort, + "q_invperm is incorrect datatype, must be kShort"); + TORCH_CHECK(!q_scale.defined() || q_scale.device().is_meta() || q_scale.dtype() == torch::kInt, + "q_scale is incorrect datatype, must be kInt"); + TORCH_CHECK(!q_scale_max.defined() || q_scale_max.device().is_meta() || q_scale_max.dtype() == torch::kHalf, + "q_scale_max is incorrect datatype, must be kHalf"); + TORCH_CHECK(!q_groups.defined() || q_groups.device().is_meta() || q_groups.dtype() == torch::kShort, + "q_groups is incorrect datatype, must be kShort"); + TORCH_CHECK(!gptq_qzeros.defined() || gptq_qzeros.device().is_meta() || gptq_qzeros.dtype() == torch::kInt, + "gptq_qzeros is incorrect datatype, must be kInt"); + TORCH_CHECK(!gptq_scales.defined() || gptq_scales.device().is_meta() || gptq_scales.dtype() == torch::kHalf, + "gptq_scales is incorrect datatype, must be kHalf"); + TORCH_CHECK(!gptq_g_idx.defined() || gptq_g_idx.device().is_meta() || gptq_g_idx.dtype() == torch::kInt, + "gptq_g_idx is incorrect datatype, must be kInt"); - TORCH_CHECK_SHAPES(q_perm, 0, q_invperm, 0, 1); + TORCH_CHECK(!q_perm.defined() || !q_invperm.defined() || q_perm.size(0) == q_invperm.size(0), + "q_perm and q_invperm have incompatible shapes"); int device = q_weight.device().index(); int width = q_weight.size(1); int groups; int height; - if (!q_scale.device().is_meta()) + if (q_scale.defined() && !q_scale.device().is_meta()) { TORCH_CHECK_SHAPES(q_weight, 1, q_scale, 1, 8); TORCH_CHECK_SHAPES(q_scale_max, 0, q_scale, 0, 1); @@ -56,28 +76,28 @@ uintptr_t make_q_matrix( width, groups, (uint32_t*) q_weight.data_ptr(), - q_perm.device().is_meta() ? NULL : (uint16_t*) q_perm.data_ptr(), - q_invperm.device().is_meta() ? NULL : (uint16_t*) q_invperm.data_ptr(), - q_scale.device().is_meta() ? NULL : (uint32_t*) q_scale.data_ptr(), - q_scale_max.device().is_meta() ? NULL : (half*) q_scale_max.data_ptr(), - q_groups.device().is_meta() ? NULL : (uint16_t*) q_groups.data_ptr(), - gptq_qzeros.device().is_meta() ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), - gptq_scales.device().is_meta() ? NULL : (half*) gptq_scales.data_ptr(), - gptq_g_idx.device().is_meta() ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), + (!q_perm.defined() || q_perm.device().is_meta()) ? NULL : (uint16_t*) q_perm.data_ptr(), + (!q_invperm.defined() || q_invperm.device().is_meta()) ? NULL : (uint16_t*) q_invperm.data_ptr(), + (!q_scale.defined() || q_scale.device().is_meta()) ? NULL : (uint32_t*) q_scale.data_ptr(), + (!q_scale_max.defined() || q_scale_max.device().is_meta()) ? NULL : (half*) q_scale_max.data_ptr(), + (!q_groups.defined() || q_groups.device().is_meta()) ? NULL : (uint16_t*) q_groups.data_ptr(), + (!gptq_qzeros.defined() || gptq_qzeros.device().is_meta()) ? NULL : (uint32_t*) gptq_qzeros.data_ptr(), + (!gptq_scales.defined() || gptq_scales.device().is_meta()) ? NULL : (half*) gptq_scales.data_ptr(), + (!gptq_g_idx.defined() || gptq_g_idx.device().is_meta()) ? NULL : (uint32_t*) gptq_g_idx.data_ptr(), (half*) temp_dq.data_ptr() ); - return reinterpret_cast(m); + return static_cast(reinterpret_cast(m)); } void gemm_half_q_half( torch::Tensor a, - uintptr_t b, - torch::Tensor c, + int64_t b, + torch::Tensor& c, bool force_cuda ) { - QMatrix* qm = reinterpret_cast(b); + QMatrix* qm = reinterpret_cast(static_cast(b)); TORCH_CHECK_DTYPE(a, kHalf); TORCH_CHECK_DTYPE(c, kHalf); @@ -101,8 +121,14 @@ void gemm_half_q_half( ); } -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +TORCH_LIBRARY(gptqmodel_exllamav2, m) { - m.def("make_q_matrix", &make_q_matrix, "make_q_matrix"); - m.def("gemm_half_q_half", &gemm_half_q_half, "gemm_half_q_half"); -} \ No newline at end of file + m.def("make_q_matrix(Tensor q_weight, Tensor? q_perm, Tensor? q_invperm, Tensor? q_scale, Tensor? q_scale_max, Tensor? q_groups, Tensor? gptq_qzeros, Tensor? gptq_scales, Tensor? gptq_g_idx, Tensor temp_dq) -> int"); + m.def("gemm_half_q_half(Tensor a, int b, Tensor(a!) c, bool force_cuda=False) -> ()"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_exllamav2, CUDA, m) +{ + m.impl("make_q_matrix", &make_q_matrix); + m.impl("gemm_half_q_half", &gemm_half_q_half); +} diff --git a/gptqmodel_ext/exllamav3/bindings.cpp b/gptqmodel_ext/exllamav3/bindings.cpp new file mode 100644 index 000000000..bbcdb9d10 --- /dev/null +++ b/gptqmodel_ext/exllamav3/bindings.cpp @@ -0,0 +1,131 @@ +#include + +#include +#include +#include + +#include "hadamard.h" +#include "hgemm.cuh" + +#include "quant/quantize.cuh" +#include "quant/pack.cuh" +#include "quant/reconstruct.cuh" +#include "quant/hadamard.cuh" + +#include "libtorch/linear.h" + +namespace +{ + +void had_paley_wrapper(torch::Tensor& h) +{ + had_paley(h); +} + +void had_paley2_wrapper(torch::Tensor& h) +{ + had_paley2(h); +} + +void quantize_tiles_wrapper( + torch::Tensor input_tiles, + torch::Tensor& output_tiles, + torch::Tensor& output_indices, + torch::Tensor& temp_costs, + torch::Tensor& temp_edges, + int64_t K, + bool mcg, + bool mul1 +) +{ + quantize_tiles(input_tiles, output_tiles, output_indices, temp_costs, temp_edges, static_cast(K), mcg, mul1); +} + +void pack_trellis_wrapper(torch::Tensor& packed, torch::Tensor unpacked, int64_t K) +{ + pack_trellis(packed, unpacked, static_cast(K)); +} + +void unpack_trellis_wrapper(torch::Tensor& unpacked, torch::Tensor packed, int64_t K) +{ + unpack_trellis(unpacked, packed, static_cast(K)); +} + +void pack_signs_wrapper(torch::Tensor& packed, torch::Tensor unpacked) +{ + pack_signs(packed, unpacked); +} + +void reconstruct_wrapper(torch::Tensor& unpacked, torch::Tensor packed, int64_t K, bool mcg, bool mul1) +{ + reconstruct(unpacked, packed, static_cast(K), mcg, mul1); +} + +void had_r_128_wrapper( + torch::Tensor input, + torch::Tensor& output, + std::optional pre_scale_opt, + std::optional post_scale_opt, + double scale +) +{ + const c10::optional pre_scale = pre_scale_opt.has_value() ? c10::optional(pre_scale_opt.value()) : c10::nullopt; + const c10::optional post_scale = post_scale_opt.has_value() ? c10::optional(post_scale_opt.value()) : c10::nullopt; + had_r_128(input, output, pre_scale, post_scale, static_cast(scale)); +} + +void hgemm_wrapper(torch::Tensor a, torch::Tensor b, torch::Tensor& c) +{ + hgemm(a, b, c); +} + +void bc_linear_exl3_run_wrapper( + torch::Tensor trellis, + torch::Tensor suh, + torch::Tensor svh, + int64_t K, + std::optional bias_opt, + bool mcg, + bool mul1, + torch::Tensor& xh, + torch::Tensor x, + torch::Tensor& y +) +{ + const c10::optional bias = bias_opt.has_value() ? c10::optional(bias_opt.value()) : c10::nullopt; + bc_linear_exl3_run(trellis, suh, svh, K, bias, mcg, mul1, xh, x, y); +} + +} // namespace + +TORCH_LIBRARY(gptqmodel_exllamav3, m) +{ + m.def("had_paley(Tensor(a!) h) -> ()"); + m.def("had_paley2(Tensor(a!) h) -> ()"); + m.def("quantize_tiles(Tensor input_tiles, Tensor(a!) output_tiles, Tensor(b!) output_indices, Tensor(c!) temp_costs, Tensor(d!) temp_edges, int K, bool mcg, bool mul1) -> ()"); + m.def("pack_trellis(Tensor(a!) packed, Tensor unpacked, int K) -> ()"); + m.def("unpack_trellis(Tensor(a!) unpacked, Tensor packed, int K) -> ()"); + m.def("pack_signs(Tensor(a!) packed, Tensor unpacked) -> ()"); + m.def("reconstruct(Tensor(a!) unpacked, Tensor packed, int K, bool mcg, bool mul1) -> ()"); + m.def("had_r_128(Tensor input, Tensor(a!) output, Tensor? pre_scale, Tensor? post_scale, float scale) -> ()"); + m.def("hgemm(Tensor a, Tensor b, Tensor(a!) c) -> ()"); + m.def("bc_linear_exl3_run(Tensor trellis, Tensor suh, Tensor svh, int K, Tensor? bias, bool mcg, bool mul1, Tensor(a!) xh, Tensor x, Tensor(b!) y) -> ()"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_exllamav3, CPU, m) +{ + m.impl("had_paley", &had_paley_wrapper); + m.impl("had_paley2", &had_paley2_wrapper); +} + +TORCH_LIBRARY_IMPL(gptqmodel_exllamav3, CUDA, m) +{ + m.impl("quantize_tiles", &quantize_tiles_wrapper); + m.impl("pack_trellis", &pack_trellis_wrapper); + m.impl("unpack_trellis", &unpack_trellis_wrapper); + m.impl("pack_signs", &pack_signs_wrapper); + m.impl("reconstruct", &reconstruct_wrapper); + m.impl("had_r_128", &had_r_128_wrapper); + m.impl("hgemm", &hgemm_wrapper); + m.impl("bc_linear_exl3_run", &bc_linear_exl3_run_wrapper); +} diff --git a/gptqmodel_ext/exllamav3/hadamard.cpp b/gptqmodel_ext/exllamav3/hadamard.cpp new file mode 100644 index 000000000..0a9cc2ec3 --- /dev/null +++ b/gptqmodel_ext/exllamav3/hadamard.cpp @@ -0,0 +1,112 @@ +#include "hadamard.h" +#include "util.h" + +#define HALF_P 0x3C00 +#define HALF_N 0xBC00 +#define HALF_PP 0x3C003C00 +#define HALF_PN 0xBC003C00 +#define HALF_NP 0x3C00BC00 +#define HALF_NN 0xBC00BC00 + +inline int pmod(int a, int b) +{ + int ret = a % b; + if (ret < 0 && b > 0) ret += b; + return ret; +} + +inline int modular_pow(int base, int exp, int mod) +{ + int result = 1; + base = pmod(base, mod); + while (exp > 0) + { + if (exp % 2 == 1) result = pmod((result * base), mod); + exp = exp >> 1; + base = pmod((base * base), mod); + } + return result; +} + +inline bool is_quadratic_residue(int a, int p) +{ + return modular_pow(a, (p - 1) / 2, p) == 1; +} + +// Paley construction + +void had_paley +( + at::Tensor h +) +{ + TORCH_CHECK_DTYPE(h, kHalf); + TORCH_CHECK_SHAPES(h, 0, h, 1, 1); + TORCH_CHECK(h.is_contiguous()); + int n = h.size(0); + int p = n - 1; + uint16_t* ptr = (uint16_t*) h.data_ptr(); + + for (int j = 0; j < n; ++j) + *ptr++ = HALF_P; + + for (int i = 0; i < p; ++i) + { + *ptr++ = HALF_N; + for (int j = 0; j < p; ++j) + { + if (i == j) *ptr++ = HALF_P; + else + { + int residue = pmod(i - j, p); + if (is_quadratic_residue(residue, p)) + *ptr++ = HALF_P; + else + *ptr++ = HALF_N; + } + } + } +} + +// Paley construction, type 2 + +void had_paley2 +( + at::Tensor h +) +{ + TORCH_CHECK_DTYPE(h, kHalf); + TORCH_CHECK_SHAPES(h, 0, h, 1, 1); + int n = h.size(0); + int p = n / 2 - 1; + uint32_t* ptr0 = (uint32_t*) h.data_ptr(); + uint32_t* ptr1 = ptr0 + n / 2; + + for (int i = 0; i < n / 2; ++i) + { + for (int j = 0; j < n / 2; ++j) + { + if (i == j) + { + *ptr0++ = HALF_PN; + *ptr1++ = HALF_NN; + } + else + { + int residue = pmod(i - j, p); + if (i == 0 || j == 0 || is_quadratic_residue(residue, p)) + { + *ptr0++ = HALF_PP; + *ptr1++ = HALF_PN; + } + else + { + *ptr0++ = HALF_NN; + *ptr1++ = HALF_NP; + } + } + } + ptr0 += n / 2; + ptr1 += n / 2; + } +} diff --git a/gptqmodel_ext/exllamav3/hadamard.h b/gptqmodel_ext/exllamav3/hadamard.h new file mode 100644 index 000000000..fe4f81f09 --- /dev/null +++ b/gptqmodel_ext/exllamav3/hadamard.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +void had_paley +( + at::Tensor h +); + +void had_paley2 +( + at::Tensor h +); \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/hgemm.cu b/gptqmodel_ext/exllamav3/hgemm.cu new file mode 100644 index 000000000..bd1b807de --- /dev/null +++ b/gptqmodel_ext/exllamav3/hgemm.cu @@ -0,0 +1,93 @@ +#include +#include "hgemm.cuh" +#include +#include +#include "util.h" +#include "util.cuh" +#include "quant/exl3_devctx.cuh" + +/* + +Row-major matmul using cuBLAS, a @ b -> c +- if c is float16, operation is float16 @ float16 -> float16 (float16 accumulate) +- if c is float32, operation is float16 @ float16 -> float32 (float32 accumulate) +*/ + +void hgemm +( + at::Tensor a, + at::Tensor b, + at::Tensor c +) +{ + const at::cuda::OptionalCUDAGuard device_guard(a.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + bool output_fp32 = c.dtype() == at::kFloat; + bool output_fp16 = c.dtype() == at::kHalf; + + TORCH_CHECK(output_fp32 || output_fp16, "c must be float32 or float16"); + + // Check shapes of a,b,c are compatible + TORCH_CHECK_DTYPE(a, kHalf); + TORCH_CHECK_DTYPE(b, kHalf); + TORCH_CHECK_DIM(b, 2); + TORCH_CHECK_SHAPES(a, -1, b, 0, 1); + TORCH_CHECK_SHAPES(b, 1, c, -1, 1); + + const half* a_ptr = (const half*) a.data_ptr(); + const half* b_ptr = (const half*) b.data_ptr(); + + int size_k = a.size(-1); + int size_m = a.numel() / size_k; + int size_n = b.size(-1); + + // Set cuBLAS modes and workspace + cublasHandle_t cublas_handle = at::cuda::getCurrentCUDABlasHandle(); + cublasSetStream(cublas_handle, stream); + cublasSetPointerMode(cublas_handle, CUBLAS_POINTER_MODE_HOST); + int device; + cudaGetDevice(&device); + void* ws = DevCtx::instance().get_ws(device); + cublasSetWorkspace(cublas_handle, ws, WORKSPACE_SIZE); + + if (output_fp16) + { + half alpha_ = __float2half(1.0f); + half beta_ = __float2half(0.0f); + + half* c_ptr = (half*) c.data_ptr(); + auto r = cublasHgemm + ( + cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha_, b_ptr, size_n, + a_ptr, size_k, + &beta_, c_ptr, size_n + ); + cublas_check(r); + cuda_check(cudaPeekAtLastError()); + } + if (output_fp32) + { + float alpha_ = 1.0f; + float beta_ = 0.0f; + + float* c_ptr = (float*) c.data_ptr(); + auto r = cublasGemmEx + ( + cublas_handle, + CUBLAS_OP_N, CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha_, b_ptr, CUDA_R_16F, size_n, + a_ptr, CUDA_R_16F, size_k, + &beta_, c_ptr, CUDA_R_32F, size_n, + CUBLAS_COMPUTE_32F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP + ); + cublas_check(r); + cuda_check(cudaPeekAtLastError()); + } +} diff --git a/gptqmodel_ext/exllamav3/hgemm.cuh b/gptqmodel_ext/exllamav3/hgemm.cuh new file mode 100644 index 000000000..9e3ee3578 --- /dev/null +++ b/gptqmodel_ext/exllamav3/hgemm.cuh @@ -0,0 +1,10 @@ +#pragma once + +#include + +void hgemm +( + at::Tensor a, + at::Tensor b, + at::Tensor c +); diff --git a/gptqmodel_ext/exllamav3/libtorch/linear.cpp b/gptqmodel_ext/exllamav3/libtorch/linear.cpp new file mode 100644 index 000000000..73cfdc77d --- /dev/null +++ b/gptqmodel_ext/exllamav3/libtorch/linear.cpp @@ -0,0 +1,41 @@ +#include "linear.h" +#include +#include +#include +#include "../util.h" +#include "../quant/exl3_gemm.cuh" + +void bc_linear_exl3_run +( + at::Tensor trellis, + at::Tensor suh, + at::Tensor svh, + int64_t K, + const c10::optional& bias, + bool mcg, + bool mul1, + at::Tensor& xh, + const at::Tensor& x, + at::Tensor& y +) +{ + TORCH_CHECK(K == trellis.size(-1) / 16, "K does not match packed trellis width"); + + if (x.numel() == x.size(-1)) + { + exl3_gemm(x, trellis, y, suh, xh, svh, -1, mcg, mul1, 0); + } + else + { + at::Tensor xh_ = at::empty_like(x); + exl3_gemm(x, trellis, y, suh, xh_, svh, -1, mcg, mul1, 0); + } + + if (bias) + y.add_(bias.value()); +} + +void BC_LinearEXL3::run(const at::Tensor& x, at::Tensor& y) +{ + bc_linear_exl3_run(trellis, suh, svh, K, bias, mcg, mul1, xh, x, y); +} diff --git a/gptqmodel_ext/exllamav3/libtorch/linear.h b/gptqmodel_ext/exllamav3/libtorch/linear.h new file mode 100644 index 000000000..f8fc0ef37 --- /dev/null +++ b/gptqmodel_ext/exllamav3/libtorch/linear.h @@ -0,0 +1,51 @@ +#pragma once + +#include +struct BC_LinearEXL3 +{ + at::Tensor trellis; + at::Tensor suh; + at::Tensor svh; + int K; + c10::optional bias; + bool mcg; + bool mul1; + at::Tensor xh; + + BC_LinearEXL3 + ( + at::Tensor _trellis, + at::Tensor _suh, + at::Tensor _svh, + int _K, + c10::optional _bias, + bool _mcg, + bool _mul1, + at::Tensor _xh + ) : + trellis(std::move(_trellis)), + suh(std::move(_suh)), + svh(std::move(_svh)), + K(_K), + bias(std::move(_bias)), + mcg(_mcg), + mul1(_mul1), + xh(std::move(_xh)) + {} + + void run(const at::Tensor& x, at::Tensor& y); +}; + +void bc_linear_exl3_run +( + at::Tensor trellis, + at::Tensor suh, + at::Tensor svh, + int64_t K, + const c10::optional& bias, + bool mcg, + bool mul1, + at::Tensor& xh, + const at::Tensor& x, + at::Tensor& y +); diff --git a/gptqmodel_ext/exllamav3/libtorch/linear_bc.h b/gptqmodel_ext/exllamav3/libtorch/linear_bc.h new file mode 100644 index 000000000..702a5e326 --- /dev/null +++ b/gptqmodel_ext/exllamav3/libtorch/linear_bc.h @@ -0,0 +1,22 @@ +py::class_>(m, "BC_LinearEXL3").def +( + py::init< + at::Tensor, + at::Tensor, + at::Tensor, + int, + c10::optional, + bool, + bool, + at::Tensor + >(), + py::arg("trellis"), + py::arg("suh"), + py::arg("svh"), + py::arg("K"), + py::arg("bias"), + py::arg("mcg"), + py::arg("mul1"), + py::arg("xh") +) +.def("run", &BC_LinearEXL3::run); diff --git a/gptqmodel_ext/exllamav3/ptx.cuh b/gptqmodel_ext/exllamav3/ptx.cuh new file mode 100644 index 000000000..8de8e65c1 --- /dev/null +++ b/gptqmodel_ext/exllamav3/ptx.cuh @@ -0,0 +1,314 @@ +#pragma once + +// Tensor core fragments + +template +struct Vec +{ + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragC_h = Vec; + +// m8n8k4 tensor core matmul (emulated on Ampere and later), don't use +// +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m8n8k4-with-f16-floating-point-type + +__device__ inline void ptx_mma_m8n8k4 +( + const Vec& frag_a, + const Vec& frag_b, + Vec& frag_c +) +{ + const uint32_t* a = reinterpret_cast(&frag_a); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + const float* d = reinterpret_cast(&frag_c); + + asm + ( + "mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3,%4,%5,%6,%7}, {%8,%9}, {%10,%11}, {%12,%13,%14,%15,%16,%17,%18,%19};\n" + + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]),"=f"(c[4]), "=f"(c[5]), "=f"(c[6]), "=f"(c[7]) + + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), "r"(b[1]), + "f"(d[0]), "f"(d[1]), "f"(d[2]), "f"(d[3]), "f"(d[4]), "f"(d[5]), "f"(d[6]), "f"(d[7]) + ); +} + +// m16n8k16 tensor core matmul +// +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + +// FP16 @ FP16 + FP32 -> FP32 +__device__ inline void ptx_mma_m16n8k16 +( + const FragA& frag_a, + const FragB& frag_b, + FragC& frag_c +) +{ + const uint32_t* a = reinterpret_cast(&frag_a); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + const float* d = reinterpret_cast(&frag_c); + + asm + ( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "f"(d[0]), "f"(d[1]), "f"(d[2]), "f"(d[3]) + ); +} + +// FP16 @ FP16 + FP16 -> FP16 +__device__ inline void ptx_mma_m16n8k16 +( + const FragA& frag_a, + const FragB& frag_b, + FragC_h& frag_c +) +{ + const uint32_t* a = reinterpret_cast(&frag_a); + const uint32_t* b = reinterpret_cast(&frag_b); + uint32_t* c = reinterpret_cast(&frag_c); + const uint32_t* d = reinterpret_cast(&frag_c); + + asm + ( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1]), + "r"(d[0]), "r"(d[1]) + ); +} + +// Global barrier + +__device__ inline void barrier_acquire +( + int* lock, + int stage +) +{ + if (threadIdx.x == 0) + { + volatile int state = -1; + do + { + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + } + while (state != stage); + } + __syncthreads(); +} + +__device__ inline void barrier_release +( + int* lock, + int val, + bool reset +) +{ + __syncthreads(); + if (threadIdx.x == 0) + { + if (reset) + { + *lock = 0; + return; + } + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + +// Load global to shared memory, predicated. Seems to produce incorrect code when compiling for Blackwell, but +// `if (...) cp_async(...)` compiles to a predicated instruction anyway + +__device__ inline void cp_async_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) +{ + const int bytes = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" :: "r"((int) pred), "r"(smem), "l"(glob_ptr), "n"(bytes) + ); +} + +// Load global to shared memory + +__device__ inline void cp_async(void* smem_ptr, const void* glob_ptr) +{ + const int bytes = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(bytes) + ); +} + +// Load global to shared memory with cache hint to evict data from L2 ASAP + +__device__ inline void cp_async_stream(void* smem_ptr, const void* glob_ptr) +{ + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + const int bytes = 16; + asm volatile + ( + "{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;\n" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" :: "r"(smem), "l"(glob_ptr), "n"(bytes) + ); +} + +// Async copy fence, commit all pending async copies + +__device__ inline void cp_async_fence() +{ + asm volatile("cp.async.commit_group;\n" ::); +} + +// Wait until at most n async groups are still pending. + +template +__device__ inline void cp_async_wait() +{ + asm volatile("cp.async.wait_group %0;\n" :: "n"(n)); +} + +// Load 16x16 matrix fragment from shared memory, directly in tensor core layout + +__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) +{ + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile + ( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) : "r"(smem) + ); +} + +__device__ inline uint32_t mul_lo_u32(uint32_t x, uint32_t y) +{ + uint32_t w; + asm volatile + ( + "mul.lo.u32 %0, %1, %2;" + : "=r"(w) + : "r"(x), "r"(y) + ); + return w; +} + +__device__ inline uint32_t mul_hi_u32(uint32_t x, uint32_t y) +{ + uint32_t w; + asm volatile + ( + "mul.hi.u32 %0, %1, %2;" + : "=r"(w) + : "r"(x), "r"(y) + ); + return w; +} + +// Memory ops + +__device__ __forceinline__ void stg_wt_u32(uint32_t* p, uint32_t v) +{ + asm volatile("st.global.wt.u32 [%0], %1;" :: "l"(p), "r"(v)); +} + +__device__ __forceinline__ void stg_wt_u128(uint4* p, const uint4 v) +{ + asm volatile ("st.global.wt.v4.u32 [%0], {%1,%2,%3,%4};" + :: "l"(p), + "r"(v.x), "r"(v.y), "r"(v.z), "r"(v.w)); +} + +__device__ __forceinline__ uint32_t ldg_cv_u32(const uint32_t* p) +{ + uint32_t v; + asm volatile("ld.global.cv.u32 %0, [%1];" : "=r"(v) : "l"(p)); + return v; +} + +__device__ __forceinline__ uint4 ldg_cv_u128(const uint4* p) +{ + uint4 v; + asm volatile ("ld.global.cv.v4.u32 {%0,%1,%2,%3}, [%4];" + : "=r"(v.x), "=r"(v.y), "=r"(v.z), "=r"(v.w) + : "l"(p)); + return v; +} + +__device__ __forceinline__ uint32_t ldg_acquire_sys_u32(const uint32_t* p) +{ + uint32_t v; + asm volatile("ld.global.acquire.sys.u32 %0, [%1];" + : "=r"(v) : "l"(p)); + return v; +} + +__device__ __forceinline__ uint64_t ldg_acquire_sys_u64(const uint64_t* p) +{ + uint64_t v; + asm volatile("ld.global.acquire.sys.u64 %0, [%1];" : "=l"(v) : "l"(p) : "memory"); + return v; +} + +__device__ __forceinline__ void stg_release_sys_u32(uint32_t* p, uint32_t v) +{ + asm volatile("st.global.release.sys.u32 [%0], %1;" :: "l"(p), "r"(v) : "memory"); +} + +__device__ __forceinline__ void stg_release_sys_u64(uint64_t* p, uint64_t v) +{ + asm volatile("st.global.release.sys.u64 [%0], %1;" :: "l"(p), "l"(v) : "memory"); +} + +// Global time in nanoseconds + +__device__ __forceinline__ uint64_t globaltimer_ns() +{ + uint64_t t; + asm volatile("mov.u64 %0, %%globaltimer;" : "=l"(t)); + return t; +} + +// Bitfield stuff + +static __forceinline__ __device__ uint32_t bfe64(uint32_t lo, uint32_t hi, int offset, int length) +{ + uint64_t value = (static_cast(hi) << 32) | static_cast(lo); + uint64_t result64; + asm ("bfe.u64 %0, %1, %2, %3;" + : "=l"(result64) + : "l"(value), "r"(offset), "r"(length)); + return static_cast(result64); +} + +#define FSHF_IMM(dst, lo, hi, imm) asm("shf.r.wrap.b32 %0, %1, %2, " #imm ";" : "=r"(dst) : "r"(lo), "r"(hi)) +#define BFE16_IMM(dst, src, imm) asm("bfe.u32 %0, %1, " #imm ", 16;" : "=r"(dst) : "r"(src)) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/codebook.cuh b/gptqmodel_ext/exllamav3/quant/codebook.cuh new file mode 100644 index 000000000..e8201937a --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/codebook.cuh @@ -0,0 +1,146 @@ +#pragma once + +// Force integer MAD on sm<=86. For some reason this performs better than letting the compiler emit IMUL +// TODO: Keep an eye on new behavior in future versions of nvcc. While this is faster on RTX 3090, it really shouldn't be. +template +__device__ __forceinline__ +uint32_t mul_const_u32(uint32_t x) +{ + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 860) + uint32_t r; + asm volatile ( + "{ .reg .u32 z,t;" + " mov.u32 t, %laneid;" // runtime SR + " sub.u32 z, t, t;" // z = 0 but data-dependent + " mad.lo.u32 %0, %1, %2, z;" + "}" + : "=r"(r) + : "r"(x), "n"(w)); + return r; + #else + return x * w; + #endif +} + +template +__device__ inline half decode_3inst(uint32_t x) +{ + if constexpr (cb == 0) + { + x *= 89226354u; + x += 64248484u; + asm ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x)); + half2_uint32 xu(x); + return __hadd(__low2half(xu.as_half2), __high2half(xu.as_half2)); + } + if constexpr (cb == 1) + { +// x *= 0xCBAC1FEDu; + x = mul_const_u32<0xCBAC1FEDu>(x); + + asm ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x)); + half2_uint32 xu(x); + return __hadd(__low2half(xu.as_half2), __high2half(xu.as_half2)); + } + if constexpr (cb == 2) + { + x *= 0x83DCD12Du; + uint32_t sum; + const uint32_t acc = 0x6400u; // 0x6400 -> 1024.0 .. 0x67FF -> 2047.0 + asm ("vabsdiff4.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(sum) : "r"(x), "r"(0), "r"(acc) : ); + const __half k_inv_h = __ushort_as_half(0x1eee); // 0.00677 = 1/147.7 + const __half k_bias_h = __ushort_as_half(0xc931); // -10.39 = (-1024.0 - 510.0) * k_inv_h + half_uint16 h((uint16_t) sum); + return __hfma(h.as_half, k_inv_h, k_bias_h); + } +} + +template +__device__ inline half2 decode_3inst_2(uint32_t x0, uint32_t x1) +{ + if constexpr (cb == 0) + { + x0 *= 89226354u; + x1 *= 89226354u; + x0 += 64248484u; + x1 += 64248484u; + asm ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x0)); + asm ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x1)); + half2_uint32 xu0(x0); + half2_uint32 xu1(x1); + half2 d0 = __lows2half2(xu0.as_half2, xu1.as_half2); + half2 d1 = __highs2half2(xu0.as_half2, xu1.as_half2); + return __hadd2(d0, d1); + } + if constexpr (cb == 1) + { +// x0 *= 0xCBAC1FEDu; +// x1 *= 0xCBAC1FEDu; + x0 = mul_const_u32<0xCBAC1FEDu>(x0); + x1 = mul_const_u32<0xCBAC1FEDu>(x1); + asm ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x0)); + asm ("lop3.b32 %0, %0, 0x8fff8fff, 0x3b603b60, 0x6a;" : "+r"(x1)); + half2_uint32 xu0(x0); + half2_uint32 xu1(x1); + half2 d0 = __lows2half2(xu0.as_half2, xu1.as_half2); + half2 d1 = __highs2half2(xu0.as_half2, xu1.as_half2); + return __hadd2(d0, d1); + } + if constexpr (cb == 2) + { + x0 *= 0x83DCD12Du; + x1 *= 0x83DCD12Du; + uint32_t sum0; + uint32_t sum1; + const uint32_t acc = 0x6400u; // 0x6400 -> 1024.0 .. 0x67FF -> 2047.0 + asm ("vabsdiff4.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(sum0) : "r"(x0), "r"(0), "r"(acc) : ); + asm ("vabsdiff4.u32.u32.u32.add %0, %1, %2, %3;" : "=r"(sum1) : "r"(x1), "r"(0), "r"(acc) : ); + half2 k_inv_h2 = __half2half2(__ushort_as_half(0x1eee)); // 0.00677 = 1/147.7 + half2 k_bias_h2 = __half2half2(__ushort_as_half(0xc931)); // -10.39 = (-1024.0 - 510.0) * k_inv_h + half_uint16 h0((uint16_t) sum0); + half_uint16 h1((uint16_t) sum1); + return __hfma2(__halves2half2(h0.as_half, h1.as_half), k_inv_h2, k_bias_h2); + } +} + +template +__device__ inline float decode_3inst_f(uint64_t x) +{ + return __half2float(decode_3inst(x)); +} + +template +__device__ inline float decode_3inst_f_diff(uint64_t x, float d) +{ + return __half2float(decode_3inst(x)) - d; +} + +// "2MAD" procedural codebook, much more overhead than 3INST, slightly better distribution at 2bpw +// Not used currently + +//__device__ inline half decode_2mad(uint64_t x) +//{ +// x = x * 264435761u + 1013904223u; +// x = ((x * 1664525u) >> 32) + x; +// int32_t c = (int32_t) __dp4a((uint32_t) x, 0x01010101u, 0xFFFFFE02u); +// half y = __hmul(__int2half_rn(c), __float2half_rn(0.008415)); +// return y; +//} +// +//__device__ inline float decode_2mad_f(uint64_t x) +//{ +// x = x * 264435761u + 1013904223u; +// x = ((x * 1664525u) >> 32) + x; +// int32_t c = (int32_t) __dp4a((uint32_t) x, 0x01010101u, 0xFFFFFE02u); +// float y = __int2float_rn(c) * 0.008415f; +// return y; +//} +// +//__device__ inline float decode_2mad_f_diff(uint64_t x, float d) +//{ +// x = x * 264435761u + 1013904223u; +// x = ((x * 1664525u) >> 32) + x; +// int32_t c = (int32_t) __dp4a((uint32_t) x, 0x01010101u, 0xFFFFFE02u); +// float y = fma(__int2float_rn(c), 0.008415f, -d); +// return y; +//} diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cu new file mode 100644 index 000000000..ca23e3db6 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_1.cuh" + +ALL_EXL3_KERNEL_INSTANCES(1) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cuh new file mode 100644 index 000000000..876a1dc82 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_1.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(1) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cu new file mode 100644 index 000000000..5ebb29de2 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_2.cuh" + +ALL_EXL3_KERNEL_INSTANCES(2) diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cuh new file mode 100644 index 000000000..e6d356463 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_2.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(2) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cu new file mode 100644 index 000000000..31efeed5e --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_3.cuh" + +ALL_EXL3_KERNEL_INSTANCES(3) diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cuh new file mode 100644 index 000000000..c1236e6bb --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_3.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(3) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cu new file mode 100644 index 000000000..a26b866ce --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_4.cuh" + +ALL_EXL3_KERNEL_INSTANCES(4) diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cuh new file mode 100644 index 000000000..3eb77cda0 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_4.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(4) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cu new file mode 100644 index 000000000..42d9e76b7 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_5.cuh" + +ALL_EXL3_KERNEL_INSTANCES(5) diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cuh new file mode 100644 index 000000000..52814b523 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_5.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(5) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cu new file mode 100644 index 000000000..57936aec0 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_6.cuh" + +ALL_EXL3_KERNEL_INSTANCES(6) diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cuh new file mode 100644 index 000000000..98bfb931a --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_6.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(6) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cu new file mode 100644 index 000000000..beb2b7f26 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_7.cuh" + +ALL_EXL3_KERNEL_INSTANCES(7) diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cuh new file mode 100644 index 000000000..49d5ba8f1 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_7.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(7) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cu b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cu new file mode 100644 index 000000000..890d1b6a7 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cu @@ -0,0 +1,12 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "../../util.h" +#include "../../util.cuh" +#include "../../ptx.cuh" +#include "../exl3_gemm_kernel.cuh" +#include "exl3_comp_unit_8.cuh" + +ALL_EXL3_KERNEL_INSTANCES(8) diff --git a/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cuh b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cuh new file mode 100644 index 000000000..951bc2ef2 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/comp_units/exl3_comp_unit_8.cuh @@ -0,0 +1,3 @@ +#pragma once + +ALL_EXL3_KERNEL_EXTERNS(8) \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/exl3_devctx.cu b/gptqmodel_ext/exllamav3/quant/exl3_devctx.cu new file mode 100644 index 000000000..620b23c6d --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_devctx.cu @@ -0,0 +1,86 @@ +#include +#include +#include +#include +namespace cg = cooperative_groups; +#include "exl3_devctx.cuh" +#include "../util.h" +#include "../util.cuh" + +//DevCtx::DevCtc() +//{ +// int num_sms[MAX_DEVICES] = {}; +// int cc[MAX_DEVICES] = {}; +// void* locks[MAX_DEVICES] = {}; +// std::mutex mtx; +//} + +DevCtx& DevCtx::instance() +{ + static DevCtx ctx; + return ctx; +} + +int DevCtx::get_num_sms(int device) +{ + std::lock_guard lock(mtx); + if (!num_sms[device]) + cuda_check(cudaDeviceGetAttribute(&num_sms[device], cudaDevAttrMultiProcessorCount, device)); + return num_sms[device]; +} + +int DevCtx::get_cc(int device) +{ + std::lock_guard lock(mtx); + if (!cc[device]) + { + cudaDeviceProp prop; + cuda_check(cudaGetDeviceProperties(&prop, device)); + if (prop.major >= 10) cc[device] = CC_BLACKWELL; + else if (prop.major >= 9) cc[device] = CC_HOPPER; + else if (prop.major >= 8 && prop.minor >= 9) cc[device] = CC_ADA; + else if (prop.major >= 8) cc[device] = CC_AMPERE; + else cc[device] = CC_OLD; + } + return cc[device]; +} + +void* DevCtx::get_ws(int device) +{ + std::lock_guard lock(mtx); + if (!ws[device]) + { + cudaSetDevice(device); + cudaMalloc(&ws[device], WORKSPACE_SIZE); + } + return ws[device]; +} + +int* DevCtx::get_locks(int device) +{ + std::lock_guard lock(mtx); + if (!locks[device]) + { + cudaSetDevice(device); + cudaMalloc(&locks[device], MAX_TILES_C * sizeof(int)); + cudaMemset(locks[device], 0, MAX_TILES_C * sizeof(int)); + } + return (int*) locks[device]; +} + +int g_get_cc(int device) +{ + return DevCtx::instance().get_cc(device); +} + +int g_get_num_sms(int device) +{ + return DevCtx::instance().get_num_sms(device); +} + +void prepare_ctx(int device) +{ + DevCtx::instance().get_num_sms(device); + DevCtx::instance().get_cc(device); + DevCtx::instance().get_locks(device); +} diff --git a/gptqmodel_ext/exllamav3/quant/exl3_devctx.cuh b/gptqmodel_ext/exllamav3/quant/exl3_devctx.cuh new file mode 100644 index 000000000..ac8efaea0 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_devctx.cuh @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +// Max allowable output size, in tiles. Used to allocate global lock buffer per device for sync across threadblocks +#define MAX_TILES_C (1024 * 1024) + +// Workspace size +#define WORKSPACE_SIZE (4*1024*1024) + +// Treat hopper and blackwell as same arch for now +#define MAX_DEVICES 16 +#define CC_OLD 1 +#define CC_AMPERE 2 +#define CC_ADA 3 +#define CC_HOPPER 4 +#define CC_BLACKWELL 4 + +// Singleton to manage context for each device. Stores device attributes and a large-enough lock buffer per device +class DevCtx +{ +private: + int num_sms[MAX_DEVICES] = {}; + int cc[MAX_DEVICES] = {}; + void* locks[MAX_DEVICES] = {}; + void* ws[MAX_DEVICES] = {}; + std::mutex mtx; + +public: + static DevCtx& instance(); + int get_num_sms(int device); + int get_cc(int device); + void* get_ws(int device); + int* get_locks(int device); + +private: + DevCtx() = default; + DevCtx(const DevCtx&) = delete; + DevCtx& operator=(const DevCtx&) = delete; +}; + +int g_get_cc(int device); +int g_get_num_sms(int device); + +void prepare_ctx(int device); \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/exl3_dq.cuh b/gptqmodel_ext/exllamav3/quant/exl3_dq.cuh new file mode 100644 index 000000000..2cf103654 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_dq.cuh @@ -0,0 +1,293 @@ +#pragma once + +#include "codebook.cuh" + +__device__ __forceinline__ uint32_t fshift(const uint32_t b, const uint32_t a, int shift) +{ + uint64_t merged = ((uint64_t)a << 32) | (uint64_t) b; + return (uint32_t)(merged >> shift); + + // Conditional funnel shift is somehow no longer faster + // if (shift < 32) return __funnelshift_r(b, a, shift); + // return a >> (shift - 32); +} + +template +__device__ __forceinline__ half dq(const uint32_t* ptr, int t_offset) +{ + int b0 = t_offset * bits + bits - 16 + 256 * bits; // bit index, start of word0 + int b1 = b0 + 16; // bit index, end of word0 + int i0 = b0 / 32; // uint32 containing first bit of word0 + int i1 = (b1 - 1) / 32; // uint32 containing last bit of word0, may be == i0 + int s0 = (i1 + 1) * 32 - b1; // shift value to align word1 to 32-bit boundary + + // Load 32 or 64 bits containing word0 + uint32_t a = ptr[i0 % (bits * 256 / 32)]; + uint32_t b = ptr[i1 % (bits * 256 / 32)]; + + // Shift into place + uint32_t w0 = __funnelshift_r(b, a, s0) & 0xffff; + return decode_3inst(w0); +} + +template +__device__ __forceinline__ half2 dq2(const uint32_t* ptr, int t_offset) +{ + int b0 = t_offset * bits + bits - 16 + 256 * bits; // bit index, start of word0 + int b1 = b0 + 16; // bit index, end of word0 + int i0 = b0 / 32; // uint32 containing first bit of word0 + int i1 = (b1 - 1) / 32; // uint32 containing last bit of word0, may be == i0 + int s0 = (i1 + 1) * 32 - b1; // shift value to align word1 to 32-bit boundary + + // Load 32 or 64 bits containing word0 + uint32_t a = ptr[i0 % (bits * 256 / 32)]; + uint32_t b = ptr[i1 % (bits * 256 / 32)]; + + // Shift into place + uint32_t w1 = __funnelshift_r(b, a, s0) & 0xffff; + uint32_t w0 = __funnelshift_r(b, a, s0 + bits) & 0xffff; + return decode_3inst_2(w0, w1); +} + +template +__device__ __forceinline__ void dq4(const uint32_t* ptr, int t_offset, FragB& frag) +{ + int b0 = (t_offset + 257) * bits - 16; // start of first word + int b1 = b0 + 3 * bits; // start of last word + int b2 = b1 + 16; // end of last word + int i0 = b0 / 32; // uint32 containing first bit of first word + int i2 = (b2 - 1) / 32; // uint32 containing last bit of last word, may be == i0 + int s2 = (i2 + 1) * 32 - b2; // shift value to align last word to 32-bit boundary + + uint32_t a = ptr[i0 % (bits * 256 / 32)]; + uint32_t b = ptr[i2 % (bits * 256 / 32)]; + uint32_t w3 = fshift(b, a, s2) & 0xffff; + uint32_t w2 = fshift(b, a, s2 + bits) & 0xffff; + uint32_t w1 = fshift(b, a, s2 + bits * 2) & 0xffff; + uint32_t w0 = fshift(b, a, s2 + bits * 3) & 0xffff; + half2 d0d1 = decode_3inst_2(w0, w1); + half2 d2d3 = decode_3inst_2(w2, w3); + frag[0] = d0d1; + frag[1] = d2d3; +} + +template +__device__ __forceinline__ void dq2x2(const uint32_t* ptr, int t_offset, FragB& frag) +{ + #pragma unroll + for (int i = 0; i < 2; ++i) + { + int b0 = (t_offset + 2 * i + 257) * bits - 16; // start of first word + int b1 = b0 + 1 * bits; // start of last word + int b2 = b1 + 16; // end of last word + int i0 = b0 / 32; // uint32 containing first bit of first word + int i2 = (b2 - 1) / 32; // uint32 containing last bit of last word, may be == i0 + int s2 = (i2 + 1) * 32 - b2; // shift value to align last word to 32-bit boundary + + uint32_t a = ptr[i0 % (bits * 256 / 32)]; + uint32_t b = ptr[i2 % (bits * 256 / 32)]; + uint32_t w1 = fshift(b, a, s2) & 0xffff; + uint32_t w0 = fshift(b, a, s2 + bits) & 0xffff; + half2 d0d1 = decode_3inst_2(w0, w1); + frag[i] = d0d1; + } +} + +template +__device__ __forceinline__ void dq8(const uint32_t* ptr, int t_offset, FragB& frag0, FragB& frag1) +{ + int b1 = (t_offset + 257) * bits; // end of first word + int b0 = b1 - 16; // start of first word + int b2 = b1 + bits * 7; + int i0 = b0 / 32; // uint32 containing first bit of word0 + int i2 = (b2 - 1) / 32; // uint32 containing last bit of word0, may be == i0 + int s2 = (i2 + 1) * 32 - b2; // shift value to align last word to 32-bit boundary + + uint32_t a = ptr[i0 % (bits * 256 / 32)]; + uint32_t b = ptr[i2 % (bits * 256 / 32)]; + uint32_t w0, w1, w2, w3, w4, w5, w6, w7; + if constexpr (align == 1) + { + w7 = fshift(b, a, s2); + w6 = fshift(b, a, s2 + bits); + w5 = fshift(b, a, s2 + bits * 2); + w4 = fshift(b, a, s2 + bits * 3); + w3 = fshift(b, a, s2 + bits * 4); + w2 = fshift(b, a, s2 + bits * 5); + w1 = fshift(b, a, s2 + bits * 6); + w0 = fshift(b, a, s2 + bits * 7); + } + if constexpr (align == 2) + { + w7 = fshift(b, a, s2); + w6 = w7 >> bits; + w5 = fshift(b, a, s2 + bits * 2); + w4 = w5 >> bits; + w3 = fshift(b, a, s2 + bits * 4); + w2 = w3 >> bits; + w1 = fshift(b, a, s2 + bits * 6); + w0 = w1 >> bits; + } + if constexpr (align == 4) + { + w7 = fshift(b, a, s2); + w6 = w7 >> bits; + w5 = w6 >> bits; + w4 = w5 >> bits; + w3 = fshift(b, a, s2 + bits * 4); + w2 = w3 >> bits; + w1 = w2 >> bits; + w0 = w1 >> bits; + } + if constexpr (align == 8) + { + w7 = fshift(b, a, s2); + w6 = w7 >> bits; + w5 = w6 >> bits; + w4 = w5 >> bits; + w3 = w4 >> bits; + w2 = w3 >> bits; + w1 = w2 >> bits; + w0 = w1 >> bits; + } + half2 d0d1 = decode_3inst_2(w0 & 0xffff, w1 & 0xffff); + half2 d2d3 = decode_3inst_2(w2 & 0xffff, w3 & 0xffff); + half2 d4d5 = decode_3inst_2(w4 & 0xffff, w5 & 0xffff); + half2 d6d7 = decode_3inst_2(w6 & 0xffff, w7 & 0xffff); + frag0[0] = d0d1; + frag0[1] = d2d3; + frag1[0] = d4d5; + frag1[1] = d6d7; +} + +template +__device__ __forceinline__ void dq8_aligned_4bits(const uint32_t* ptr, int t_offset, FragB& frag0, FragB& frag1) +{ + uint32_t i0, i1, a, b, s, w0, w1, w2, w3, w4, w5, w6, w7; + i1 = t_offset >> 3; + i0 = (i1 + 31) & 31; + a = ptr[i0]; + b = ptr[i1]; + FSHF_IMM(s, b, a, 20); + w7 = b & 0xffff; + BFE16_IMM(w6, b, 4); + BFE16_IMM(w5, b, 8); + BFE16_IMM(w4, b, 12); + BFE16_IMM(w3, b, 16); + w2 = s & 0xffff; + BFE16_IMM(w1, s, 4); + BFE16_IMM(w0, s, 8); + frag0[0] = decode_3inst_2(w0, w1); + frag0[1] = decode_3inst_2(w2, w3); + frag1[0] = decode_3inst_2(w4, w5); + frag1[1] = decode_3inst_2(w6, w7); +} + +template +__device__ __forceinline__ void dq8_aligned_2bits(const uint32_t* ptr, int t_offset, FragB& frag0, FragB& frag1) +{ + uint32_t i0, i1, a, b, w0, w1, w2, w3, w4, w5, w6, w7; + i1 = t_offset >> 4; + i0 = (i1 + 15) & 15; + a = ptr[i0]; + b = ptr[i1]; + b = fshift(b, a, ((~t_offset) & 8) << 1); + w7 = b & 0xffff; + BFE16_IMM(w6, b, 2); + BFE16_IMM(w5, b, 4); + BFE16_IMM(w4, b, 6); + BFE16_IMM(w3, b, 8); + BFE16_IMM(w2, b, 10); + BFE16_IMM(w1, b, 12); + BFE16_IMM(w0, b, 14); + frag0[0] = decode_3inst_2(w0, w1); + frag0[1] = decode_3inst_2(w2, w3); + frag1[0] = decode_3inst_2(w4, w5); + frag1[1] = decode_3inst_2(w6, w7); +} + +template +__device__ __forceinline__ void dq8_aligned_1bit(const uint32_t* ptr, int t_offset, FragB& frag0, FragB& frag1) +{ + uint32_t i0, i1, a, b, w0, w1, w2, w3, w4, w5, w6, w7; + i1 = t_offset >> 5; + i0 = (i1 + 7) & 7; + a = ptr[i0]; + b = ptr[i1]; + b = fshift(b, a, ((~t_offset) & 24)); + w7 = b & 0xffff; + BFE16_IMM(w6, b, 1); + BFE16_IMM(w5, b, 2); + BFE16_IMM(w4, b, 3); + BFE16_IMM(w3, b, 4); + BFE16_IMM(w2, b, 5); + BFE16_IMM(w1, b, 6); + BFE16_IMM(w0, b, 7); + frag0[0] = decode_3inst_2(w0, w1); + frag0[1] = decode_3inst_2(w2, w3); + frag1[0] = decode_3inst_2(w4, w5); + frag1[1] = decode_3inst_2(w6, w7); +} + + +template +__device__ __forceinline__ void dq8_aligned_4bits_bfe64(const uint32_t* ptr, int t_offset, FragB& frag0, FragB& frag1) +{ + int i1 = t_offset / 8; + int i0 = (i1 + 31) % 32; + uint32_t a = ptr[i0]; + uint32_t b = ptr[i1]; + uint32_t w7 = bfe64(b, a, 0, 16); + uint32_t w6 = bfe64(b, a, 4, 16); + uint32_t w5 = bfe64(b, a, 8, 16); + uint32_t w4 = bfe64(b, a, 12, 16); + uint32_t w3 = bfe64(b, a, 16, 16); + uint32_t w2 = bfe64(b, a, 20, 16); + uint32_t w1 = bfe64(b, a, 24, 16); + uint32_t w0 = bfe64(b, a, 28, 16); + frag0[0] = decode_3inst_2(w0, w1); + frag0[1] = decode_3inst_2(w2, w3); + frag1[0] = decode_3inst_2(w4, w5); + frag1[1] = decode_3inst_2(w6, w7); +} + +template +__device__ __forceinline__ void dq_dispatch(const uint32_t* ptr, int idx, FragB& frag0, FragB& frag1) +{ + if constexpr (bits == 1) + { + dq8_aligned_1bit(ptr, idx, frag0, frag1); + } + else if constexpr (bits == 2) + { + dq8_aligned_2bits(ptr, idx, frag0, frag1); + } + else if constexpr (bits == 3) + { + dq8(ptr, idx, frag0, frag1); + } + else if constexpr (bits == 4) + { + dq8_aligned_4bits(ptr, idx, frag0, frag1); + } + else if constexpr (bits == 5) + { + dq4(ptr, idx, frag0); + dq4(ptr, idx + 4, frag1); + } + else if constexpr (bits == 6) + { + dq4(ptr, idx, frag0); + dq4(ptr, idx + 4, frag1); + } + else if constexpr (bits == 7) + { + dq2x2(ptr, idx, frag0); + dq2x2(ptr, idx + 4, frag1); + } + else if constexpr (bits == 8) + { + dq4(ptr, idx, frag0); + dq4(ptr, idx + 4, frag1); + } +} \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/exl3_gemm.cu b/gptqmodel_ext/exllamav3/quant/exl3_gemm.cu new file mode 100644 index 000000000..67b22b0c5 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_gemm.cu @@ -0,0 +1,141 @@ +#include +#include "exl3_gemm.cuh" + +#include +#include +#include "../util.h" +#include "../util.cuh" +#include "exl3_kernel_map.cuh" +#include "exl3_devctx.cuh" +#include + +constexpr int EXL3_GEMM_SMEM_MAX = 90 * 1024; + +/* +EXL3 matmul, A @ B -> C + +- A: row-major A tensor, shape (m, k), dtype float16, contiguous +- B: EXL3-quantized B tensor, shape (k//16, n//16, 16*K), dtype uint16 +- C: empty row-major C tensor, shape (m, n), dtype float16 or float32, contiguous. Does not need to be zero-initialized +- suh: optional, packed input scales/flips, shape (k//16), dtype float16 +- A_had: required if suh given, may be reference to A, temporary storage for input transform, size and dtype as A +- svh: optional, packed output scales/flips, shape (n//16), dtype float16 + +limitations: +- k % 16 == 0 +- n % 128 == 0 +*/ + +std::set kernel_attr_set[MAX_DEVICES] = {}; + +int exl3_gemm +( + const at::Tensor& A, + const at::Tensor& B, + at::Tensor& C, + const c10::optional& suh, + const c10::optional& A_had, + const c10::optional& svh, + int force_shape_idx, + bool mcg, + bool mul1, + int force_num_sms +) +{ + const at::cuda::OptionalCUDAGuard device_guard(A.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_DIM(B, 3); + TORCH_CHECK_SHAPES(A, -1, B, 0, 16); + TORCH_CHECK_SHAPES(C, -1, B, 1, 16); + // TORCH_CHECK_SHAPES(A, 0, C, 0, 1); + TORCH_CHECK_DTYPE(A, kHalf); + TORCH_CHECK_DTYPE(B, kShort); + bool c_fp32 = C.dtype() == at::kFloat; + if (!c_fp32) TORCH_CHECK_DTYPE(C, kHalf); + + // Get SU, optionally + const half* suh_ptr = (const half*) OPTPTR(suh); + half* A_had_ptr = nullptr; + if (suh_ptr) + { + // TORCH_CHECK_SHAPES(suh.value(), 0, A, 1, 1); + A_had_ptr = (half*) OPTPTR(A_had); + // TORCH_CHECK(A_had_ptr, "Must supply A_had with suh"); + // TORCH_CHECK_SHAPES_FULL(A_had.value(), A); + } + + // Get SV, optionally + const half* svh_ptr = (const half*) OPTPTR(svh); + // if (svh_ptr) + // TORCH_CHECK_SHAPES(svh.value(), 0, B, 1, 16); + + // Device properties + int device; + cudaGetDevice(&device); + int num_sms = force_num_sms ? force_num_sms : DevCtx::instance().get_num_sms(device); + int cc = DevCtx::instance().get_cc(device); + int* locks = DevCtx::instance().get_locks(device); + + // Dispatch + int K = B.size(2) / 16; + const half* A_ptr = (const half*) A.data_ptr(); + const uint16_t* B_ptr = (const uint16_t*) B.data_ptr(); + void* C_ptr = (void*) C.data_ptr(); + + int size_m = 1; + int dim = A.dim(); + for (int d = 0; d < dim - 1; ++d) size_m *= A.size(d); + int size_k = A.size(-1); + int size_n = B.size(1) * 16; + + // Select kernel + TORCH_CHECK(!(mcg && mul1), "Specified both mcg and mul1") + int cb = 0; + if (mcg) cb = 1; + if (mul1) cb = 2; + + int block_dim; + int shape_idx; + fp_exl3_gemm_kernel kernel; + + TResult* tr = select_exl3_gemm_kernel_tuned(cc, size_k, size_n, K, c_fp32, force_shape_idx, force_num_sms, cb); + if (!tr) return 0; + num_sms = MIN(num_sms, tr->num_sms); + kernel = tr->kernel; + block_dim = tr->block_dim; + shape_idx = tr->shape_idx; + + // Launch + if (kernel_attr_set[device].find((void*) kernel) == kernel_attr_set[device].end()) + { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, EXL3_GEMM_SMEM_MAX); + kernel_attr_set[device].insert((void*) kernel); + cuda_check(cudaPeekAtLastError()); + } + void* kernelArgs[] = + { + (void*)& A_ptr, + (void*)& B_ptr, + (void*)& C_ptr, + (void*)& size_m, + (void*)& size_k, + (void*)& size_n, + (void*)& locks, + (void*)& suh_ptr, + (void*)& A_had_ptr, + (void*)& svh_ptr + }; + cudaLaunchCooperativeKernel + ( + (void*) kernel, + num_sms, + block_dim, + kernelArgs, + EXL3_GEMM_SMEM_MAX, + stream + ); + + cuda_check(cudaPeekAtLastError()); + return shape_idx; +} diff --git a/gptqmodel_ext/exllamav3/quant/exl3_gemm.cuh b/gptqmodel_ext/exllamav3/quant/exl3_gemm.cuh new file mode 100644 index 000000000..3cde308eb --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_gemm.cuh @@ -0,0 +1,17 @@ +#pragma once + +#include + +int exl3_gemm +( + const at::Tensor& A, + const at::Tensor& B, + at::Tensor& C, + const c10::optional& suh, + const c10::optional& A_had, + const c10::optional& svh, + int force_shape_idx, + bool mcg, + bool mul1, + int force_num_sms +); diff --git a/gptqmodel_ext/exllamav3/quant/exl3_gemm_inner.cuh b/gptqmodel_ext/exllamav3/quant/exl3_gemm_inner.cuh new file mode 100644 index 000000000..74cf66461 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_gemm_inner.cuh @@ -0,0 +1,610 @@ +#pragma once + +#include "../ptx.cuh" + +// Constants +#define EXL3_GEMM_BASE_THREADS 256 +#define SMEM_MAX (90 * 1024) // max shared memory on compute capability 8.6 + +#include "exl3_dq.cuh" + +template +inline __device__ +void exl3_gemm_kernel_inner +( + const half* __restrict__ A, + const uint16_t* __restrict__ B, + void* __restrict__ C, + const int size_m, + const int size_k, + const int size_n, + int* __restrict__ locks +) +{ + const int TILEBLOCKS_M = TILESIZE_M / 16; + const int TILEBLOCKS_K = TILESIZE_K / 16; + const int TILEBLOCKS_N = TILESIZE_N / 16; + const int FRAGS_M = TILEBLOCKS_M; + const int FRAGS_N_PER_WARP = 2 * TILEBLOCKS_N / (EXL3_GEMM_BASE_THREADS / 32); + + const int sh_a_stage_size = TILESIZE_M * TILESIZE_K; // in halfs + const int sh_b_stage_size = TILEBLOCKS_K * TILEBLOCKS_N * 256 / 16 * bits; // in uint16s + const int sh_c_size = 4 * EXL3_GEMM_BASE_THREADS; // in floats + + // Sanity checks + static_assert(EXL3_GEMM_BASE_THREADS == 256); + static_assert(TILESIZE_M % 16 == 0, "Invalid kernel params"); + static_assert(TILESIZE_K % 16 == 0, "Invalid kernel params"); + static_assert(TILESIZE_N % 128 == 0, "Invalid kernel params"); + static_assert + ( + SMEM_MAX >= SH_STAGES * (2 * sh_a_stage_size + 2 * sh_b_stage_size) + 4 * sh_c_size, + "Invalid kernel params (insufficient shared memory for shape)" + ); + + // Shared memory + extern __shared__ half shared[]; + half* sh_a = shared; + uint16_t* sh_b = (uint16_t*) (sh_a + SH_STAGES * sh_a_stage_size); + float* sh_c = (float*) (sh_b + sh_b_stage_size * SH_STAGES); + + // Thread index + int t = threadIdx.x % EXL3_GEMM_BASE_THREADS; + int sub_k = threadIdx.x / EXL3_GEMM_BASE_THREADS; + int warp_id = t / 32; + int lane_id = t % 32; + + // Dimensions + //int tiles_m = CEIL_DIVIDE(size_m, TILESIZE_M); + int tiles_k = size_k / TILESIZE_K; + int tiles_n = size_n / TILESIZE_N; + //int blocks_m = 1; + //int blocks_k = tiles_k * TILEBLOCKS_K; + int blocks_n = tiles_n * TILEBLOCKS_N; + + // Start and end index of current slice, must span at least one tile + int num_slices = gridDim.x; + int slice_beg = tiles_k * tiles_n * blockIdx.x / num_slices; + int slice_end = tiles_k * tiles_n * (blockIdx.x + 1) / num_slices; + int slice_len = slice_end - slice_beg; + if (slice_len < 1) return; + + auto index_m = [&] (int slice_i) { return 0; }; //blockIdx.y; }; + auto index_k = [&] (int slice_i) { return (slice_i % tiles_k); }; + auto index_n = [&] (int slice_i) { return (slice_i / tiles_k); }; + + // Batch dimension + int slice_m = index_m(slice_beg); + int max_m = MIN(size_m - slice_m * TILESIZE_M, TILESIZE_M); + + // Pipe 0, global A, B tile and shared A, B tile + int slice0_k = index_k(slice_beg); + int slice0_n = index_n(slice_beg); + int slice0_iters = slice_len; + + int gl_a_stride_m = TILESIZE_M * size_k; + const int gl_a_stride_k = TILESIZE_K; + const int sh0_a_stride_m = TILESIZE_M * TILESIZE_K; + const half* gl_a_ptr = A + slice_m * gl_a_stride_m + slice0_k * gl_a_stride_k; + half* sh0_a_ptr = sh_a + (slice0_iters % SH_STAGES) * sh_a_stage_size; + + const int load_a_iters = CEIL_DIVIDE(sh0_a_stride_m / 8, EXL3_GEMM_BASE_THREADS); + bool pred_a_gl[load_a_iters]; + int load_a_gl[load_a_iters]; + for (int i = 0; i < load_a_iters; ++i) + { + int k = (i * EXL3_GEMM_BASE_THREADS + t) % (gl_a_stride_k / 8); + int m = (i * EXL3_GEMM_BASE_THREADS + t) / (gl_a_stride_k / 8); + load_a_gl[i] = m * size_k / 8 + k; + pred_a_gl[i] = m < max_m; + } + + int gl_b_stride_k = blocks_n * TILEBLOCKS_K * 256 / 16 * bits; + const int gl_b_stride_n = TILEBLOCKS_N * 256 / 16 * bits; + const int sh0_b_stride_k = TILEBLOCKS_K * TILEBLOCKS_N * 256 / 16 * bits; + const uint16_t* gl_b_ptr = B + slice0_k * gl_b_stride_k + slice0_n * gl_b_stride_n; + uint16_t* sh0_b_ptr = sh_b + (slice0_iters % SH_STAGES) * sh_b_stage_size; + + const int load_b_iters = CEIL_DIVIDE(sh0_b_stride_k / 8, EXL3_GEMM_BASE_THREADS); + bool pred_b_gl[load_b_iters]; + int load_b_gl[load_b_iters]; + for (int i = 0; i < load_b_iters; ++i) + { + int n = (i * EXL3_GEMM_BASE_THREADS + t) % (gl_b_stride_n / 8); + int k = (i * EXL3_GEMM_BASE_THREADS + t) / (gl_b_stride_n / 8); + load_b_gl[i] = k * blocks_n * 256 / 16 * bits / 8 * k + n; + pred_b_gl[i] = i * EXL3_GEMM_BASE_THREADS + t < sh0_b_stride_k / 8; + } + + auto advance0 = [&] () + { + slice0_k++; + slice0_iters--; + + int stage = slice0_iters % SH_STAGES; + sh0_a_ptr = sh_a + stage * sh_a_stage_size; + sh0_b_ptr = sh_b + stage * sh_b_stage_size; + + if (slice0_k >= tiles_k) + { + slice0_k = 0; + slice0_n++; + gl_a_ptr = A + slice_m * gl_a_stride_m + slice0_k * gl_a_stride_k; + gl_b_ptr = B + slice0_k * gl_b_stride_k + slice0_n * gl_b_stride_n; + } + else + { + gl_a_ptr += gl_a_stride_k; + gl_b_ptr += gl_b_stride_k; + } + }; + + // Pipe 1, shared A, B tile and registers + int slice1_k = slice0_k; + int slice1_n = slice0_n; + int slice1_iters = slice0_iters; + + half* sh1_a_ptr = sh_a + (slice1_iters % SH_STAGES) * sh_a_stage_size; + uint16_t* sh1_b_ptr = sh_b + (slice1_iters % SH_STAGES) * sh_b_stage_size; + + auto advance1 = [&] () + { + slice1_k++; + slice1_iters--; + + int stage = slice1_iters % SH_STAGES; + sh1_a_ptr = sh_a + stage * sh_a_stage_size; + sh1_b_ptr = sh_b + stage * sh_b_stage_size; + + if (slice1_k >= tiles_k) + { + slice1_k = 0; + slice1_n++; + } + }; + + // Pipe 2 + int slice2_k = slice0_k; + int slice2_k0 = slice0_k; + int slice2_n = slice0_n; + int slice2_iters = slice0_iters; + + int gl_c_stride_n = TILESIZE_N; + int gl_c_stride_m = TILESIZE_M * size_n; + + half* gl_c_ptr_16 = ((half*) C) + slice_m * gl_c_stride_m + slice2_n * gl_c_stride_n; + float* gl_c_ptr_32 = ((float*) C) + slice_m * gl_c_stride_m + slice2_n * gl_c_stride_n; + + register FragA frag_a[FRAG_STAGES][FRAGS_M]; + register FragB frag_b[FRAG_STAGES][FRAGS_N_PER_WARP]; + register FragC frag_c[FRAGS_M][FRAGS_N_PER_WARP]; + + auto advance2 = [&] () + { + slice2_k++; + slice2_iters--; + + if (slice2_k >= tiles_k) + { + slice2_k = 0; + slice2_k0 = 0; + slice2_n++; + if constexpr (c_fp32) + gl_c_ptr_32 += gl_c_stride_n; + else + gl_c_ptr_16 += gl_c_stride_n; + } + }; + + // Schedule load of the next A, B tiles to shared memory and advance the pipeline + auto async_load_gl = [&] () + { + if (sub_k) + { + cp_async_fence(); + return; + } + + if (slice0_iters) + { + // Copy tile from row-major A matrix + { + const int4* gl = (const int4*) gl_a_ptr; + int4* sh = (int4*) sh0_a_ptr; + #pragma unroll + for (int i = 0; i < load_a_iters; ++i) + { + // TODO: Rearrange into ldmatrix friendly layout while loading? + // cp_async_pred(sh + EXL3_GEMM_BASE_THREADS * i + t, gl + load_a_gl[i], pred_a_gl[i]); + if (pred_a_gl[i]) cp_async(sh + EXL3_GEMM_BASE_THREADS * i + t, gl + load_a_gl[i]); + } + } + + // Copy tile of 256-element blocks from quantized B matrix + { + const int4* gl = (const int4*) gl_b_ptr; + int4* sh = (int4*) sh0_b_ptr; + #pragma unroll + for (int i = 0; i < load_b_iters; ++i) + { + // cp_async_pred(sh + EXL3_GEMM_BASE_THREADS * i + t, gl + load_b_gl[i], pred_b_gl[i]); + if (pred_b_gl[i]) cp_async(sh + EXL3_GEMM_BASE_THREADS * i + t, gl + load_b_gl[i]); + } + } + advance0(); + } + + // Sync and advance + cp_async_fence(); + }; + + // Load fragments + // Ref. for fragment layout: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + auto load_frags = [&] (int buf) + { + if (!slice1_iters) return; + + // A fragments + { + // TODO: Resolve bank conflicts + int r = (lane_id % 8) + 8 * ((lane_id / 8) % 2); + int c = lane_id / 16; + int4* sha = (int4*) sh1_a_ptr + r * TILESIZE_K / 8 + c; + #pragma unroll + for (int m = 0; m < TILEBLOCKS_M; ++m) + ldsm4(frag_a[buf][m], sha + (m * 16) * TILESIZE_K / 8 + sub_k * 16 / 8); + } + + // B fragments + #pragma unroll + for (int n2 = 0; n2 < FRAGS_N_PER_WARP; n2 += 2) + { + int sub_n2 = warp_id * FRAGS_N_PER_WARP / 2 + n2 / 2; + const uint32_t* shb = (const uint32_t*) (sh1_b_ptr + (sub_k * TILEBLOCKS_N + sub_n2) * 256 / 16 * bits); + + dq_dispatch(shb, lane_id << 3, frag_b[buf][n2], frag_b[buf][n2 + 1]); + } + + __syncthreads(); + advance1(); + }; + + // Clear C fragments + auto clear_frag_c = [&] () + { + #pragma unroll + for (int m = 0; m < FRAGS_M; ++m) + #pragma unroll + for (int n = 0; n < FRAGS_N_PER_WARP; ++n) + frag_c[m][n] = {}; + }; + + // Threadblock reduction + auto threadblock_reduce = [&] () + { + auto store = [&] (int i, int m, int n) + { + // TODO: Shuffle to avoid bank conflicts here? Doesn't seem to be a bottleneck + if (sub_k == i) + { + float* sh_red = sh_c + (FRAGS_N_PER_WARP * 4) * t; + if (size_m <= 8) + { + #pragma unroll + for (int i = 0; i < 2; ++i) *sh_red++ = frag_c[m][n][i]; + } + else + { + #pragma unroll + for (int i = 0; i < 4; ++i) *sh_red++ = frag_c[m][n][i]; + } + } + __syncthreads(); + }; + + auto add = [&] (int i, int m, int n) + { + if (sub_k == i) + { + float* sh_red = sh_c + (FRAGS_N_PER_WARP * 4) * t; + if (size_m <= 8) + { + #pragma unroll + for (int i = 0; i < 2; ++i) frag_c[m][n][i] += *sh_red++; + } + else + { + #pragma unroll + for (int i = 0; i < 4; ++i) frag_c[m][n][i] += *sh_red++; + } + } + __syncthreads(); + }; + + #pragma unroll + for (int m = 0; m < FRAGS_M; ++m) + { + #pragma unroll + for (int n = 0; n < FRAGS_N_PER_WARP; ++n) + { + if constexpr (TILEBLOCKS_K == 2) + { + store(1, m, n); + add(0, m, n); + } + if constexpr (TILEBLOCKS_K == 3) + { + store(1, m, n); + add(0, m, n); + store(2, m, n); + add(0, m, n); + } + if constexpr (TILEBLOCKS_K == 4) + { + store(3, m, n); + add(2, m, n); + store(1, m, n); + add(0, m, n); + store(2, m, n); + add(0, m, n); + } + } + } + }; + + // Output reduction + auto reduce = [&] () + { + // First reduce all partial sums along k for the current slice + threadblock_reduce(); + + // Process (partial) slices within column in reverse order so the threadblock doing the bottom slice is + // free to proceed to the next column right away + int lock_i = tiles_k - slice2_k - 1; + int lock_d = slice2_k - slice2_k0 + 1; + int* lock = &locks[slice_m * blocks_n + slice2_n]; + + barrier_acquire(lock, lock_i); + + bool first = lock_i == 0; + bool last = lock_i + lock_d == tiles_k; + + int n0 = warp_id * FRAGS_N_PER_WARP; + + // Second and subsequent threadblocks in column read back the intermediate sum from global memory + if (!sub_k && !first) + { + #pragma unroll + for (int n = 0; n < FRAGS_N_PER_WARP; ++n) + { + #pragma unroll + for (int m = 0; m < FRAGS_M; ++m) + { + int r0 = lane_id / 4 + 16 * m; + int r1 = r0 + 8; + int c = (lane_id % 4) * 2; + if (r0 < max_m) + { + if constexpr (c_fp32) + { + float* c_ptr = gl_c_ptr_32 + r0 * size_n + (n0 + n) * 8 + c; + frag_c[m][n][0] += *c_ptr++; + frag_c[m][n][1] += *c_ptr++; + } + else + { + half2* c_ptr = (half2*) (gl_c_ptr_16 + r0 * size_n + (n0 + n) * 8 + c); + float2 interm = __half22float2(*c_ptr); + frag_c[m][n][0] += interm.x; + frag_c[m][n][1] += interm.y; + } + } + if (r1 < max_m) + { + if constexpr (c_fp32) + { + float* c_ptr = gl_c_ptr_32 + r1 * size_n + (n0 + n) * 8 + c; + frag_c[m][n][2] += *c_ptr++; + frag_c[m][n][3] += *c_ptr++; + } + else + { + half2* c_ptr = (half2*) (gl_c_ptr_16 + r1 * size_n + (n0 + n) * 8 + c); + float2 interm = __half22float2(*c_ptr); + frag_c[m][n][2] += interm.x; + frag_c[m][n][3] += interm.y; + } + } + } + } + } + + // All but last threadblock in column write the intermediate result to global memory + if (!sub_k && !last) + { + #pragma unroll + for (int n = 0; n < FRAGS_N_PER_WARP; ++n) + { + #pragma unroll + for (int m = 0; m < FRAGS_M; ++m) + { + int r0 = lane_id / 4 + 16 * m; + int r1 = r0 + 8; + int c = (lane_id % 4) * 2; + if (r0 < max_m) + { + if constexpr (c_fp32) + { + float* c_ptr = gl_c_ptr_32 + r0 * size_n + (n0 + n) * 8 + c; + *c_ptr++ = frag_c[m][n][0]; + *c_ptr++ = frag_c[m][n][1]; + } + else + { + half2* c_ptr = (half2*) (gl_c_ptr_16 + r0 * size_n + (n0 + n) * 8 + c); + half2 sum = __floats2half2_rn(frag_c[m][n][0], frag_c[m][n][1]); + *c_ptr = sum; + } + } + if (r1 < max_m) + { + if constexpr (c_fp32) + { + float* c_ptr = gl_c_ptr_32 + r1 * size_n + (n0 + n) * 8 + c; + *c_ptr++ = frag_c[m][n][2]; + *c_ptr++ = frag_c[m][n][3]; + } + else + { + half2* c_ptr = (half2*) (gl_c_ptr_16 + r1 * size_n + (n0 + n) * 8 + c); + half2 sum = __floats2half2_rn(frag_c[m][n][2], frag_c[m][n][3]); + *c_ptr = sum; + } + } + } + } + } + + // Last block writes in row-major format + if (!sub_k && last) + { + #pragma unroll + for (int n = 0; n < FRAGS_N_PER_WARP; ++n) + { + #pragma unroll + for (int m = 0; m < FRAGS_M; ++m) + { + int r0 = lane_id / 4 + 16 * m; + int r1 = r0 + 8; + int c = (lane_id % 4) * 2; + if (r0 < max_m) + { + if constexpr (c_fp32) + { + float* c_ptr = gl_c_ptr_32 + r0 * size_n + (n0 + n) * 8 + c; + *c_ptr++ = frag_c[m][n][0]; + *c_ptr++ = frag_c[m][n][1]; + } + else + { + half2* c_ptr = (half2*) (gl_c_ptr_16 + r0 * size_n + (n0 + n) * 8 + c); + half2 sum = __floats2half2_rn(frag_c[m][n][0], frag_c[m][n][1]); + *c_ptr = sum; + } + } + if (r1 < max_m) + { + if constexpr (c_fp32) + { + float* c_ptr = gl_c_ptr_32 + r1 * size_n + (n0 + n) * 8 + c; + *c_ptr++ = frag_c[m][n][2]; + *c_ptr++ = frag_c[m][n][3]; + } + else + { + half2* c_ptr = (half2*) (gl_c_ptr_16 + r1 * size_n + (n0 + n) * 8 + c); + half2 sum = __floats2half2_rn(frag_c[m][n][2], frag_c[m][n][3]); + *c_ptr = sum; + } + } + } + } + } + + barrier_release(lock, lock_d, last); + + clear_frag_c(); + }; + + // Wait until there are at most SH_STAGES - 2 async copies pending, i.e. at least one stage has finished loading + auto wait_stage = [&] () + { + cp_async_wait(); + __syncthreads(); + }; + + // Perform tensor core matmul on current tile + auto matmul = [&] (int buf) + { + #pragma unroll + for (int m = 0; m < FRAGS_M; ++m) + #pragma unroll + for (int n = 0; n < FRAGS_N_PER_WARP; ++n) + ptx_mma_m16n8k16(frag_a[buf][m], frag_b[buf][n], frag_c[m][n]); + }; + + // Start global to shared pipeline + #pragma unroll + for (int i = 0; i < SH_STAGES - 1; ++i) + async_load_gl(); + wait_stage(); + + // Start shared to register pipeline. + clear_frag_c(); + if constexpr (FRAG_STAGES > 1) + load_frags(0); + + // Main loop. Fragments are double buffered to allow more interleaving. This is especially important to hide the + // dequantization overhead, but we need two different iterations of the main loop to avoid confusing the compiler + // and making it (sometimes) place the fragment arrays in local memory + + #define FSTAGE(_load, _mul) \ + async_load_gl(); \ + wait_stage(); \ + load_frags(_load); \ + matmul(_mul); \ + if (slice2_k == tiles_k - 1 || slice2_iters == 1) { reduce(); slice2_k0 = slice2_k + 1; } \ + advance2(); \ + if (!slice2_iters) break; \ + + if constexpr (FRAG_STAGES == 1) + { + while (true) + { + FSTAGE(0, 0); + } + } + + if constexpr (FRAG_STAGES == 2) + { + while (true) + { + FSTAGE(1, 0); + FSTAGE(0, 1); + } + } + + if constexpr (FRAG_STAGES == 3) + { + while (true) + { + FSTAGE(1, 0); + FSTAGE(2, 1); + FSTAGE(0, 2); + } + } + + if constexpr (FRAG_STAGES == 4) + { + while (true) + { + FSTAGE(1, 0); + FSTAGE(2, 1); + FSTAGE(3, 2); + FSTAGE(0, 3); + } + } + + if constexpr (FRAG_STAGES == 5) + { + while (true) + { + FSTAGE(1, 0); + FSTAGE(2, 1); + FSTAGE(3, 2); + FSTAGE(4, 3); + FSTAGE(0, 4); + } + } +} diff --git a/gptqmodel_ext/exllamav3/quant/exl3_gemm_kernel.cuh b/gptqmodel_ext/exllamav3/quant/exl3_gemm_kernel.cuh new file mode 100644 index 000000000..510d94519 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_gemm_kernel.cuh @@ -0,0 +1,80 @@ +#pragma once + +#include "exl3_kernel_map.cuh" +#include "hadamard_inner.cuh" +#include "exl3_gemm_inner.cuh" + +template +__global__ __launch_bounds__(EXL3_GEMM_BASE_THREADS * TILESIZE_K / 16) +void exl3_gemm_kernel(EXL3_GEMM_ARGS) +{ + auto grid = cg::this_grid(); + + if (suh) + { + int total_warps = size_m * size_k / 128; + int warps_grid = gridDim.x * blockDim.x / 32; + int this_warp = threadIdx.x / 32 + blockDim.x / 32 * blockIdx.x; + + for(; this_warp < total_warps; this_warp += warps_grid) + had_hf_r_128_inner + ( + A + this_warp * 128, + A_had + this_warp * 128, + suh + (this_warp * 128) % size_k, + nullptr, + 0.088388347648f // 1/sqrt(128) + ); + + grid.sync(); + A = A_had; + } + + int size_m_ = size_m; + const half* A_ = A; + void* C_ = C; + + while (size_m_ > 0) + { + exl3_gemm_kernel_inner + + (A_, B, C_, size_m_, size_k, size_n, locks); + + A_ += 16 * size_k; + if constexpr (c_fp32) C_ = (void*) (((float*) C_) + 16 * size_n); + else C_ = (void*) (((half*) C_) + 16 * size_n); + size_m_ -= 16; + + if (size_m_ > 0 || svh) + grid.sync(); + } + + if (svh) + { + int total_warps = size_m * size_n / 128; + int warps_grid = gridDim.x * blockDim.x / 32; + int this_warp = threadIdx.x / 32 + blockDim.x / 32 * blockIdx.x; + + for(; this_warp < total_warps; this_warp += warps_grid) + { + if constexpr (c_fp32) + had_ff_r_128_inner + ( + ((const float*) C) + this_warp * 128, + ((float*) C) + this_warp * 128, + nullptr, + svh + (this_warp * 128) % size_n, + 0.088388347648f // 1/sqrt(128) + ); + else + had_hf_r_128_inner + ( + ((const half*) C) + this_warp * 128, + ((half*) C) + this_warp * 128, + nullptr, + svh + (this_warp * 128) % size_n, + 0.088388347648f // 1/sqrt(128) + ); + } + } +} diff --git a/gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cu b/gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cu new file mode 100644 index 000000000..3ec02162b --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cu @@ -0,0 +1,203 @@ +#include + +#include +#include +#include + +#include "../util.h" +#include "exl3_devctx.cuh" +#include "exl3_kernel_map.cuh" +#include "exl3_kernel_map_packed.cuh" +#include "comp_units/exl3_comp_unit_1.cuh" +#include "comp_units/exl3_comp_unit_2.cuh" +#include "comp_units/exl3_comp_unit_3.cuh" +#include "comp_units/exl3_comp_unit_4.cuh" +#include "comp_units/exl3_comp_unit_5.cuh" +#include "comp_units/exl3_comp_unit_6.cuh" +#include "comp_units/exl3_comp_unit_7.cuh" +#include "comp_units/exl3_comp_unit_8.cuh" + +namespace { + +struct TPackedTable +{ + const int* n_axis; + int n_count; + const uint16_t* payload; +}; + +std::map tuning_cache = {}; +TResult forced_result; + +int exl3_gemm_tilesize_k[] = {EXL3_GEMM_TILESIZE_K}; +int exl3_gemm_tilesize_n[] = {EXL3_GEMM_TILESIZE_N}; +int exl3_gemm_blockdim[] = {EXL3_GEMM_BLOCKDIM}; + +constexpr TPackedTable packed_table_128 = { + exl3_packed::n_axis_128, + exl3_packed::n_axis_len_128, + exl3_packed::samples_128 +}; + +constexpr TPackedTable packed_table_256 = { + exl3_packed::n_axis_256, + exl3_packed::n_axis_len_256, + exl3_packed::samples_256 +}; + +constexpr TPackedTable packed_table_512 = { + exl3_packed::n_axis_512, + exl3_packed::n_axis_len_512, + exl3_packed::samples_512 +}; + +int map_cc_to_index(int cc) +{ + switch (cc) + { + case CC_AMPERE: return 0; + case CC_ADA: return 1; + case CC_HOPPER: return 2; + default: return -1; + } +} + +int nearest_axis_index(const int* axis, int axis_len, int value) +{ + int best_idx = 0; + int64_t best_dist = axis[0] > value ? axis[0] - (int64_t) value : (int64_t) value - axis[0]; + + for (int idx = 1; idx < axis_len; ++idx) + { + int64_t dist = axis[idx] > value ? axis[idx] - (int64_t) value : (int64_t) value - axis[idx]; + if (dist < best_dist) + { + best_dist = dist; + best_idx = idx; + } + } + + return best_idx; +} + +const TPackedTable& select_packed_table(int size_n) +{ + bool mod512 = (size_n % 512 == 0); + bool mod256 = (size_n % 256 == 0); + bool mod128 = (size_n % 128 == 0); + TORCH_CHECK(mod128, "size_n must be a multiple of 128"); + + if (mod512) return packed_table_512; + if (mod256) return packed_table_256; + return packed_table_128; +} + +uint16_t lookup_packed_sample(const TPackedTable& table, int cc_idx, int bits, int k_idx, int n_idx) +{ + int flat_idx = + ((((cc_idx * exl3_packed::bit_count) + (bits - 1)) * exl3_packed::k_axis_len) + k_idx) * table.n_count + n_idx; + return table.payload[flat_idx]; +} + +fp_exl3_gemm_kernel get_gemm_kernel_ptr(int K, int shape_idx, bool c_fp32, int cb) +{ + int kernel_idx = shape_idx + (EXL3_GEMM_NUM_SHAPES + 1) * cb; + + if (c_fp32) + { + switch (K) + { + case 1: return tfp_exl3_gemm_kernel_fp32_b1[kernel_idx]; + case 2: return tfp_exl3_gemm_kernel_fp32_b2[kernel_idx]; + case 3: return tfp_exl3_gemm_kernel_fp32_b3[kernel_idx]; + case 4: return tfp_exl3_gemm_kernel_fp32_b4[kernel_idx]; + case 5: return tfp_exl3_gemm_kernel_fp32_b5[kernel_idx]; + case 6: return tfp_exl3_gemm_kernel_fp32_b6[kernel_idx]; + case 7: return tfp_exl3_gemm_kernel_fp32_b7[kernel_idx]; + case 8: return tfp_exl3_gemm_kernel_fp32_b8[kernel_idx]; + default: TORCH_CHECK(false, "No kernel for GEMM shape"); + } + } + else + { + switch (K) + { + case 1: return tfp_exl3_gemm_kernel_fp16_b1[kernel_idx]; + case 2: return tfp_exl3_gemm_kernel_fp16_b2[kernel_idx]; + case 3: return tfp_exl3_gemm_kernel_fp16_b3[kernel_idx]; + case 4: return tfp_exl3_gemm_kernel_fp16_b4[kernel_idx]; + case 5: return tfp_exl3_gemm_kernel_fp16_b5[kernel_idx]; + case 6: return tfp_exl3_gemm_kernel_fp16_b6[kernel_idx]; + case 7: return tfp_exl3_gemm_kernel_fp16_b7[kernel_idx]; + case 8: return tfp_exl3_gemm_kernel_fp16_b8[kernel_idx]; + default: TORCH_CHECK(false, "No kernel for GEMM shape"); + } + } + + return nullptr; +} + +} // namespace + +TResult* select_exl3_gemm_kernel_tuned +( + int cc, + int size_k, + int size_n, + int K, + bool c_fp32, + int force_shape_idx, + int force_num_sms, + int cb +) +{ + if (force_shape_idx > 0) + { + TORCH_CHECK(force_num_sms, "Must supply force_shape_idx and force_num_sms together"); + forced_result.kernel = get_gemm_kernel_ptr(K, force_shape_idx, c_fp32, cb); + forced_result.shape_idx = force_shape_idx; + forced_result.num_sms = force_num_sms; + forced_result.block_dim = exl3_gemm_blockdim[force_shape_idx]; + return &forced_result; + } + TORCH_CHECK(!force_num_sms, "Must supply force_shape_idx and force_num_sms together."); + + // Cache by the dimensions that drive sample lookup plus cb/c_fp32 because they change kernel tables. + uint64_t key = (((uint64_t) size_k) << 40) | + (((uint64_t) size_n) << 16) | + (((uint64_t) cc) << 8) | + (((uint64_t) K) << 4) | + (((uint64_t) cb) << 1) | + (c_fp32 ? 0x01ull : 0x00ull); + + auto lookup = tuning_cache.find(key); + if (lookup == tuning_cache.end()) + { + TORCH_CHECK(K >= 1 && K <= exl3_packed::bit_count, "Failed to find valid kernel for shape"); + int cc_idx = map_cc_to_index(cc); + TORCH_CHECK(cc_idx >= 0, "Failed to find valid kernel for shape"); + + const TPackedTable& table = select_packed_table(size_n); + int k_idx = nearest_axis_index(exl3_packed::k_axis, exl3_packed::k_axis_len, size_k); + int n_idx = nearest_axis_index(table.n_axis, table.n_count, size_n); + uint16_t packed = lookup_packed_sample(table, cc_idx, K, k_idx, n_idx); + int shape_idx = packed >> 8; + int tuned_num_sms = packed & 0xff; + TORCH_CHECK(shape_idx, "Failed to find valid kernel for shape"); + + int tilesize_k = exl3_gemm_tilesize_k[shape_idx]; + int tilesize_n = exl3_gemm_tilesize_n[shape_idx]; + int max_slices = size_k / tilesize_k * size_n / tilesize_n; + int num_sms = MAX(MIN(max_slices, tuned_num_sms), 1); + + tuning_cache[key] = TResult { + get_gemm_kernel_ptr(K, shape_idx, c_fp32, cb), + shape_idx, + num_sms, + exl3_gemm_blockdim[shape_idx] + }; + } + + lookup = tuning_cache.find(key); + return &(lookup->second); +} diff --git a/gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cuh b/gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cuh new file mode 100644 index 000000000..ae9ca7ae6 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_kernel_map.cuh @@ -0,0 +1,80 @@ +#pragma once + +#define EXL3_GEMM_T_ARGS \ + const int bits, \ + const bool c_fp32, \ + const int cb, \ + const int TILESIZE_M, \ + const int TILESIZE_K, \ + const int TILESIZE_N, \ + const int SH_STAGES, \ + const int FRAG_STAGES + +#define EXL3_GEMM_ARGS \ + const half* __restrict__ A, \ + const uint16_t* __restrict__ B, \ + void* __restrict__ C, \ + const int size_m, \ + const int size_k, \ + const int size_n, \ + int* __restrict__ locks, \ + const half* __restrict__ suh, \ + half* __restrict__ A_had, \ + const half* __restrict__ svh + +typedef void (*fp_exl3_gemm_kernel) (EXL3_GEMM_ARGS); + +#define EXL3_GEMM_SHAPE_1 16, 16, 128, 6, 5 +#define EXL3_GEMM_SHAPE_2 16, 32, 128, 4, 3 +#define EXL3_GEMM_SHAPE_3 16, 32, 256, 4, 3 +#define EXL3_GEMM_SHAPE_4 16, 16, 512, 4, 3 + +#define EXL3_GEMM_TILESIZE_K 0, 16, 32, 32, 16 +#define EXL3_GEMM_TILESIZE_N 0, 128, 128, 256, 512 +#define EXL3_GEMM_BLOCKDIM 0, 256, 512, 512, 256 + +#define EXL3_GEMM_NUM_SHAPES 4 + +#define EXL3_GEMM_KERNEL_INSTANCES(_bits, _c_fp32, cb) \ + nullptr, \ + exl3_gemm_kernel<_bits, _c_fp32, cb, EXL3_GEMM_SHAPE_1>, \ + exl3_gemm_kernel<_bits, _c_fp32, cb, EXL3_GEMM_SHAPE_2>, \ + exl3_gemm_kernel<_bits, _c_fp32, cb, EXL3_GEMM_SHAPE_3>, \ + exl3_gemm_kernel<_bits, _c_fp32, cb, EXL3_GEMM_SHAPE_4> + +#define ALL_EXL3_KERNEL_EXTERNS(K) \ + extern fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b##K[]; \ + extern fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b##K[]; \ + +#define ALL_EXL3_KERNEL_INSTANCES(K) \ + fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp32_b##K[] = { \ + EXL3_GEMM_KERNEL_INSTANCES(K, true, 0), \ + EXL3_GEMM_KERNEL_INSTANCES(K, true, 1), \ + EXL3_GEMM_KERNEL_INSTANCES(K, true, 2) \ + }; \ + \ + fp_exl3_gemm_kernel tfp_exl3_gemm_kernel_fp16_b##K[] = { \ + EXL3_GEMM_KERNEL_INSTANCES(K, false, 0), \ + EXL3_GEMM_KERNEL_INSTANCES(K, false, 1), \ + EXL3_GEMM_KERNEL_INSTANCES(K, false, 2) \ + }; + +struct TResult +{ + fp_exl3_gemm_kernel kernel; + int shape_idx; + int num_sms; + int block_dim; +}; + +TResult* select_exl3_gemm_kernel_tuned +( + const int cc, + const int size_k, + const int size_n, + const int K, + const bool c_fp32, + const int force_shape_idx, + const int force_num_sms, + const int cb +); diff --git a/gptqmodel_ext/exllamav3/quant/exl3_kernel_map_packed.cuh b/gptqmodel_ext/exllamav3/quant/exl3_kernel_map_packed.cuh new file mode 100644 index 000000000..241e88d50 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/exl3_kernel_map_packed.cuh @@ -0,0 +1,600 @@ +#pragma once + +#include + +// Generated by scripts/generate_exl3_kernel_map_packed.py. +// Encodes the EXL3 tuning samples as dense [cc][bits][k][n] grids. + +namespace exl3_packed { + +constexpr int cc_count = 3; +constexpr int bit_count = 8; +constexpr int k_axis_len = 13; +constexpr int n_axis_len_128 = 15; +constexpr int n_axis_len_256 = 14; +constexpr int n_axis_len_512 = 13; + +constexpr int cc_values[] = { + 2, 3, 4, +}; + +constexpr int bit_values[] = { + 1, 2, 3, 4, 5, 6, 7, 8, +}; + +constexpr int k_axis[] = { + 128, 256, 512, 1024, 2048, 3072, 4096, 5120, + 8192, 12288, 14336, 16384, 24576, +}; + +constexpr int n_axis_128[] = { + 128, 256, 512, 1024, 2048, 3072, 4096, 5120, + 8192, 12288, 14336, 16384, 24576, 51200, 128000, +}; + +constexpr int n_axis_256[] = { + 256, 512, 1024, 2048, 3072, 4096, 5120, 8192, + 12288, 14336, 16384, 24576, 51200, 128000, +}; + +constexpr int n_axis_512[] = { + 512, 1024, 2048, 3072, 4096, 5120, 8192, 12288, + 14336, 16384, 24576, 51200, 128000, +}; + +constexpr uint16_t samples_128[] = { + 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0254, 0x0240, 0x0254, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0150, 0x0240, + 0x0254, 0x0254, 0x0252, 0x0254, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0218, 0x0220, 0x0230, 0x0240, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0204, 0x0208, 0x0210, + 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0206, 0x020a, 0x0214, 0x0228, 0x0250, 0x024e, 0x0254, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0218, 0x0238, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0212, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x020c, 0x0218, 0x0230, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x023c, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0226, 0x0240, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0214, 0x022a, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0240, 0x0254, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0130, + 0x0140, 0x0150, 0x0240, 0x0254, 0x0254, 0x0252, 0x0252, 0x0250, 0x0254, 0x0202, 0x0204, 0x0210, 0x0210, 0x0130, 0x0148, 0x0240, 0x0250, 0x0254, 0x0254, 0x0254, 0x0252, 0x0254, 0x0254, 0x0254, + 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0206, 0x020a, 0x0214, 0x0228, 0x0250, 0x0254, 0x0254, 0x0250, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0230, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, + 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0212, 0x0224, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x020c, 0x0218, 0x022a, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x021e, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0210, 0x0226, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0214, 0x022a, 0x0248, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0154, 0x0152, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, + 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0154, 0x0154, 0x0152, 0x0252, 0x0250, 0x0254, 0x0204, 0x0208, 0x0210, 0x0118, 0x0140, 0x0148, 0x0154, 0x0150, 0x0154, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0204, 0x0208, 0x0210, 0x0220, 0x0140, 0x0248, 0x0254, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0206, 0x020c, 0x021a, 0x0234, 0x0250, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0214, 0x0228, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020c, 0x021a, 0x0234, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x0240, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0212, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0210, 0x0226, 0x024c, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0218, 0x022a, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0240, 0x0254, 0x0250, 0x0254, + 0x0202, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0240, 0x0254, 0x0254, 0x0254, 0x0252, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0118, 0x0130, 0x0148, 0x0240, 0x0250, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0206, 0x020a, 0x0214, + 0x0228, 0x0250, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0218, 0x0230, 0x0252, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0212, 0x0224, 0x0248, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0218, 0x022a, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0210, 0x0218, 0x0234, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x021e, 0x023c, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0214, 0x0224, 0x0248, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0152, + 0x0254, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0154, 0x0254, 0x0254, 0x0252, 0x0250, 0x0254, 0x0204, 0x0208, 0x0210, 0x0118, 0x0140, 0x0148, + 0x0154, 0x0150, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0206, 0x020c, 0x0218, 0x0228, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0214, 0x0228, + 0x024e, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020c, 0x021a, 0x0234, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0212, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0226, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0216, 0x022c, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, + 0x0154, 0x0154, 0x0154, 0x0254, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0154, 0x0254, 0x0254, 0x0252, 0x0250, 0x0254, 0x0204, 0x0208, 0x0110, + 0x0120, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0206, 0x020c, 0x0218, 0x0228, 0x0250, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0240, 0x0254, 0x0252, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x020a, 0x0214, 0x0228, 0x024e, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020c, 0x021a, 0x0234, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0224, 0x0240, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0226, 0x024c, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0216, 0x022c, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, + 0x0140, 0x0140, 0x0140, 0x0154, 0x0154, 0x0154, 0x0254, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0254, 0x0254, 0x0254, 0x0252, 0x0250, 0x0254, + 0x0204, 0x0208, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x020c, 0x0218, 0x0230, 0x0254, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0210, 0x0220, + 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0214, 0x0220, 0x0246, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x020a, 0x0214, 0x0228, 0x024e, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x021a, 0x0234, 0x0252, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0222, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0212, 0x0224, 0x0248, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0214, 0x0226, 0x024c, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0218, 0x0230, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0102, 0x0104, 0x0108, + 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0150, 0x0154, 0x0140, 0x0254, 0x0250, 0x0254, 0x0202, 0x0204, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0154, 0x0154, 0x0154, + 0x0252, 0x0250, 0x0254, 0x0204, 0x0208, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0204, 0x0208, 0x0210, 0x0220, 0x0150, 0x0252, + 0x0254, 0x0250, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0206, 0x020c, 0x021a, 0x0230, 0x0250, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0208, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0208, 0x0214, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020a, 0x0214, 0x0228, 0x024e, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x020c, 0x021a, 0x0234, + 0x0252, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0210, 0x0220, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0210, 0x0224, 0x0240, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0214, 0x0226, 0x0248, 0x0254, 0x0254, 0x0254, + 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0216, 0x0230, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, 0x0254, + 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0128, 0x0180, 0x0160, 0x0170, 0x0180, 0x0280, 0x0280, 0x0280, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, + 0x0260, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0280, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, 0x0204, 0x0208, 0x0210, + 0x0220, 0x0240, 0x0260, 0x0280, 0x0278, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0206, 0x020a, 0x0214, 0x0228, 0x0250, 0x0278, 0x0280, 0x0278, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0206, 0x020c, 0x0218, 0x0230, 0x0260, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0208, 0x0210, 0x0220, 0x0240, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x020a, 0x0212, 0x0220, 0x0248, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x020a, 0x0214, 0x0230, 0x0256, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x020c, 0x021a, 0x0230, 0x0266, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0210, 0x021e, 0x023c, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0210, 0x0220, 0x023e, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0212, 0x0224, 0x0248, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0102, 0x0104, 0x0204, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0280, 0x0280, 0x0280, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, + 0x0140, 0x0150, 0x0280, 0x0280, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0280, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0280, 0x0278, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0206, 0x020a, 0x0214, 0x0228, 0x0250, 0x0278, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0206, 0x020c, 0x0218, 0x0230, 0x0260, 0x0274, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0208, 0x0210, 0x0220, + 0x0240, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0208, 0x0212, 0x0224, 0x0248, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x020a, 0x0214, 0x0228, 0x0248, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x020c, 0x021a, 0x0234, 0x0268, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x020e, 0x021e, 0x0238, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0210, 0x0220, 0x023e, 0x0268, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0212, 0x0224, 0x0248, 0x027e, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0202, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0180, 0x0180, 0x0280, 0x0202, 0x0204, 0x0208, + 0x0210, 0x0120, 0x0230, 0x0140, 0x0150, 0x0180, 0x0180, 0x0270, 0x0280, 0x0180, 0x0280, 0x0280, 0x0204, 0x0208, 0x0210, 0x0220, 0x0140, 0x0148, 0x0160, 0x0178, 0x0280, 0x0180, 0x0270, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0280, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0206, 0x020c, 0x0218, 0x0230, 0x0260, 0x0270, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0208, 0x0210, 0x0220, 0x0240, 0x0280, 0x0274, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0208, 0x0210, 0x0220, 0x0240, 0x0280, 0x0280, 0x027e, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x020a, 0x0214, 0x0228, 0x0248, 0x027c, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x020c, 0x0218, 0x0236, 0x0258, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0210, 0x0220, 0x023e, + 0x027c, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0210, 0x0220, 0x0240, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0210, 0x0226, 0x0240, 0x027e, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0216, 0x022c, 0x0254, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0280, 0x0280, 0x0280, + 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0180, 0x0260, 0x0270, 0x0180, 0x0280, 0x0280, 0x0280, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0274, + 0x0274, 0x0270, 0x0280, 0x0280, 0x0280, 0x0280, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0278, 0x0270, 0x0274, 0x0280, 0x0280, 0x0280, 0x0280, 0x0278, 0x0280, 0x0206, 0x020a, 0x0214, + 0x0228, 0x0250, 0x0270, 0x0274, 0x0274, 0x027e, 0x027c, 0x0270, 0x027e, 0x0280, 0x0278, 0x0280, 0x0206, 0x020c, 0x0218, 0x0230, 0x0258, 0x0274, 0x0280, 0x027c, 0x027a, 0x0280, 0x0270, 0x0280, + 0x027e, 0x0278, 0x0280, 0x0208, 0x0210, 0x0220, 0x0238, 0x0274, 0x026e, 0x0280, 0x0276, 0x0280, 0x0280, 0x027a, 0x027e, 0x027e, 0x0278, 0x0280, 0x0208, 0x0212, 0x0220, 0x0240, 0x0270, 0x0274, + 0x0280, 0x0276, 0x0280, 0x0280, 0x0270, 0x0280, 0x0280, 0x0278, 0x0280, 0x020a, 0x0214, 0x0220, 0x0248, 0x0270, 0x0276, 0x0280, 0x0276, 0x0280, 0x0280, 0x027c, 0x0280, 0x027a, 0x0278, 0x0280, + 0x020c, 0x0218, 0x0234, 0x0260, 0x0270, 0x0276, 0x0280, 0x0278, 0x0280, 0x0280, 0x0280, 0x0280, 0x0280, 0x0278, 0x0280, 0x020e, 0x021e, 0x0238, 0x0260, 0x0270, 0x0278, 0x0280, 0x0278, 0x0280, + 0x0280, 0x027a, 0x0280, 0x0280, 0x0278, 0x0280, 0x020e, 0x0220, 0x0238, 0x0260, 0x0280, 0x0280, 0x0280, 0x0278, 0x0280, 0x0278, 0x027c, 0x0280, 0x0280, 0x0278, 0x027e, 0x0212, 0x0224, 0x0248, + 0x0270, 0x0280, 0x0280, 0x0280, 0x0278, 0x0280, 0x0280, 0x027c, 0x0280, 0x0280, 0x0278, 0x027e, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, + 0x0180, 0x027e, 0x027e, 0x0202, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0180, 0x0170, 0x0180, 0x0180, 0x0280, 0x0280, 0x0204, 0x0208, 0x0210, 0x0220, 0x0140, 0x0160, + 0x0180, 0x0178, 0x0274, 0x0274, 0x0270, 0x0180, 0x0280, 0x0280, 0x027c, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0258, 0x0274, 0x0270, 0x027c, 0x0280, 0x0270, 0x0280, 0x027c, 0x0280, 0x0280, + 0x0206, 0x020c, 0x0218, 0x0230, 0x0250, 0x0270, 0x027c, 0x0276, 0x027c, 0x027a, 0x0270, 0x0280, 0x0278, 0x0280, 0x0280, 0x0208, 0x0210, 0x0220, 0x0238, 0x0260, 0x0276, 0x0280, 0x0278, 0x027e, + 0x0280, 0x0270, 0x0280, 0x027e, 0x0280, 0x027c, 0x0208, 0x0210, 0x0220, 0x0240, 0x0274, 0x026e, 0x027e, 0x0280, 0x0280, 0x0280, 0x0270, 0x0280, 0x027e, 0x0280, 0x0280, 0x020a, 0x0214, 0x0228, + 0x0248, 0x026e, 0x0278, 0x027e, 0x0274, 0x0280, 0x027a, 0x0270, 0x0280, 0x027e, 0x027e, 0x027c, 0x020c, 0x021a, 0x0228, 0x0256, 0x0272, 0x027a, 0x027e, 0x0278, 0x0280, 0x027e, 0x0270, 0x0280, + 0x027e, 0x0280, 0x027c, 0x0210, 0x0218, 0x0238, 0x0260, 0x0272, 0x0278, 0x0280, 0x0278, 0x0280, 0x027e, 0x0270, 0x0280, 0x0278, 0x027e, 0x0280, 0x0210, 0x0220, 0x0238, 0x026a, 0x0272, 0x0278, + 0x0280, 0x0278, 0x0280, 0x027e, 0x0270, 0x0280, 0x0278, 0x027e, 0x0280, 0x0210, 0x0220, 0x0240, 0x026c, 0x027e, 0x0278, 0x0280, 0x0278, 0x0280, 0x027e, 0x0270, 0x0280, 0x0278, 0x027e, 0x0280, + 0x0216, 0x022a, 0x0254, 0x026c, 0x027e, 0x0278, 0x0280, 0x0278, 0x0280, 0x027e, 0x0270, 0x0280, 0x0278, 0x027e, 0x0280, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, + 0x0160, 0x0170, 0x0180, 0x0180, 0x027e, 0x027e, 0x0202, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0274, 0x0170, 0x0180, 0x0180, 0x0280, 0x0278, 0x0204, 0x0208, 0x0210, + 0x0220, 0x0140, 0x0148, 0x0180, 0x0178, 0x0180, 0x0180, 0x0170, 0x0180, 0x0180, 0x0180, 0x0274, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0258, 0x017c, 0x0270, 0x0180, 0x0180, 0x0180, 0x0180, + 0x0180, 0x0180, 0x0274, 0x0206, 0x020c, 0x0218, 0x0228, 0x0258, 0x0170, 0x017e, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0180, 0x0274, 0x0208, 0x0210, 0x0220, 0x0230, 0x0260, 0x0176, + 0x017c, 0x017e, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0280, 0x0274, 0x0208, 0x0210, 0x0220, 0x023e, 0x025c, 0x0180, 0x017c, 0x017a, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0278, + 0x020a, 0x0214, 0x0220, 0x0248, 0x025a, 0x0180, 0x017e, 0x017a, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0280, 0x0274, 0x020c, 0x021a, 0x0228, 0x0256, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, + 0x0180, 0x0180, 0x0180, 0x0180, 0x0280, 0x0274, 0x0210, 0x0220, 0x0238, 0x0260, 0x025e, 0x0180, 0x0180, 0x0178, 0x0180, 0x0180, 0x026e, 0x0180, 0x0180, 0x0280, 0x0278, 0x0210, 0x0220, 0x023a, + 0x025c, 0x0260, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x026e, 0x0180, 0x0180, 0x0280, 0x0278, 0x0210, 0x0220, 0x0240, 0x0260, 0x0260, 0x0180, 0x017c, 0x0178, 0x0180, 0x0180, 0x026e, 0x0180, + 0x0180, 0x0280, 0x027a, 0x0216, 0x0228, 0x0250, 0x0260, 0x0260, 0x0180, 0x0180, 0x0278, 0x0180, 0x0180, 0x0276, 0x0180, 0x0180, 0x0280, 0x0280, 0x0102, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, + 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0180, 0x027c, 0x027a, 0x0202, 0x0204, 0x0208, 0x0210, 0x0220, 0x0130, 0x0140, 0x0150, 0x0180, 0x0274, 0x0170, 0x0180, 0x0180, 0x0278, 0x027a, + 0x0204, 0x0208, 0x0210, 0x0220, 0x0140, 0x0160, 0x0180, 0x0178, 0x0274, 0x0260, 0x0170, 0x0180, 0x0180, 0x0278, 0x027a, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0258, 0x0274, 0x0270, 0x0180, + 0x0260, 0x0180, 0x0180, 0x0180, 0x026e, 0x0268, 0x0206, 0x020c, 0x0218, 0x0230, 0x0258, 0x0260, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x017e, 0x026e, 0x0180, 0x0208, 0x0210, 0x0220, + 0x0238, 0x0260, 0x025e, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x017e, 0x0266, 0x027a, 0x0208, 0x0210, 0x0220, 0x023e, 0x025e, 0x0260, 0x017e, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, + 0x017e, 0x027e, 0x026c, 0x020a, 0x0214, 0x0220, 0x0248, 0x025a, 0x025e, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x017e, 0x0266, 0x027c, 0x020c, 0x021a, 0x0236, 0x0256, 0x0260, 0x0260, + 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0260, 0x027a, 0x0280, 0x0210, 0x0220, 0x0240, 0x0258, 0x0260, 0x0260, 0x0180, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0260, 0x027a, 0x0280, + 0x0212, 0x0220, 0x023e, 0x025c, 0x0260, 0x0260, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x0260, 0x027e, 0x027c, 0x0212, 0x0226, 0x0240, 0x0258, 0x0180, 0x0260, 0x0180, 0x0180, 0x0180, + 0x0260, 0x0180, 0x0180, 0x0260, 0x027e, 0x0280, 0x0216, 0x022c, 0x0254, 0x0260, 0x0260, 0x0260, 0x0180, 0x0278, 0x0180, 0x0260, 0x0180, 0x0180, 0x0260, 0x0278, 0x027c, 0x0102, 0x0104, 0x0208, + 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0174, 0x0180, 0x0180, 0x0202, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0174, 0x0170, 0x0274, + 0x0180, 0x0178, 0x017e, 0x0204, 0x0208, 0x0210, 0x0220, 0x0140, 0x0160, 0x015e, 0x0270, 0x0174, 0x0160, 0x0170, 0x0180, 0x0180, 0x0178, 0x017e, 0x0204, 0x0208, 0x0210, 0x0220, 0x014e, 0x0258, + 0x0174, 0x0168, 0x017c, 0x0160, 0x016e, 0x017e, 0x0180, 0x0178, 0x017e, 0x0206, 0x020c, 0x0218, 0x0228, 0x0258, 0x015e, 0x0162, 0x0170, 0x0174, 0x0160, 0x016e, 0x017c, 0x0160, 0x017c, 0x0276, + 0x0208, 0x0210, 0x0220, 0x0238, 0x0158, 0x0160, 0x0162, 0x0162, 0x0174, 0x0160, 0x016e, 0x017e, 0x0160, 0x017c, 0x0276, 0x0208, 0x0210, 0x0220, 0x023e, 0x0160, 0x0160, 0x0166, 0x0160, 0x017e, + 0x0160, 0x016e, 0x0280, 0x0160, 0x017c, 0x017e, 0x020a, 0x0214, 0x0220, 0x0248, 0x0250, 0x0160, 0x015e, 0x0166, 0x017e, 0x0160, 0x016e, 0x017e, 0x0160, 0x017c, 0x0276, 0x020c, 0x021a, 0x0236, + 0x024c, 0x0160, 0x0160, 0x0162, 0x0160, 0x017e, 0x0160, 0x0172, 0x0280, 0x0160, 0x0278, 0x027e, 0x0210, 0x0220, 0x023e, 0x024c, 0x0160, 0x0160, 0x0162, 0x0178, 0x0174, 0x0160, 0x0270, 0x0280, + 0x0160, 0x027c, 0x0276, 0x0210, 0x0220, 0x023e, 0x0248, 0x0160, 0x0160, 0x0160, 0x0160, 0x0280, 0x0160, 0x0270, 0x0280, 0x0160, 0x027c, 0x0274, 0x0210, 0x0220, 0x0242, 0x0250, 0x0160, 0x0160, + 0x0160, 0x0160, 0x0280, 0x0160, 0x0270, 0x0280, 0x0160, 0x027e, 0x0274, 0x0216, 0x022c, 0x0248, 0x0250, 0x0160, 0x0160, 0x0160, 0x0160, 0x0280, 0x0160, 0x0270, 0x0280, 0x0160, 0x0278, 0x027e, + 0x0102, 0x0104, 0x0104, 0x0110, 0x0210, 0x0120, 0x0140, 0x0150, 0x0140, 0x0260, 0x0270, 0x0180, 0x0292, 0x029c, 0x02a8, 0x0202, 0x0208, 0x020e, 0x020a, 0x0114, 0x014c, 0x0250, 0x0262, 0x0270, + 0x0298, 0x029c, 0x029a, 0x02a8, 0x029c, 0x02aa, 0x0204, 0x0206, 0x020c, 0x0118, 0x0130, 0x0230, 0x0240, 0x0178, 0x0280, 0x02aa, 0x0270, 0x02a6, 0x0298, 0x02a8, 0x02aa, 0x0202, 0x0208, 0x0110, + 0x0218, 0x022e, 0x025e, 0x0260, 0x0280, 0x029e, 0x02a2, 0x02a4, 0x02aa, 0x02a8, 0x02a8, 0x02aa, 0x0204, 0x010a, 0x0210, 0x021e, 0x024c, 0x0272, 0x027c, 0x02a0, 0x02a0, 0x02a8, 0x02a8, 0x02a6, + 0x02a6, 0x02aa, 0x02aa, 0x0206, 0x020c, 0x0118, 0x0244, 0x0256, 0x01a2, 0x0288, 0x02a0, 0x02a4, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02a8, 0x02aa, 0x0208, 0x0210, 0x0220, 0x0238, 0x0280, 0x02a0, + 0x0280, 0x02a0, 0x02a8, 0x02a2, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0206, 0x0210, 0x021e, 0x0230, 0x028a, 0x02a8, 0x02a2, 0x02a0, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, + 0x020a, 0x0216, 0x0220, 0x0240, 0x0290, 0x029c, 0x029c, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x020c, 0x0218, 0x0230, 0x0258, 0x029c, 0x02a8, 0x02aa, 0x02a4, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0210, 0x0220, 0x0230, 0x0268, 0x02a2, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x020e, 0x021a, 0x0238, + 0x026c, 0x029e, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0212, 0x0220, 0x0240, 0x0284, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x0102, 0x0104, 0x0108, 0x0110, 0x0110, 0x0130, 0x0220, 0x0128, 0x0240, 0x0160, 0x0170, 0x0180, 0x0260, 0x0298, 0x02a0, 0x0204, 0x0202, 0x0106, 0x0112, 0x011a, 0x0220, + 0x0152, 0x025c, 0x0260, 0x02a0, 0x0284, 0x029e, 0x02a8, 0x0298, 0x02a8, 0x0202, 0x0206, 0x020c, 0x0118, 0x0220, 0x0230, 0x0240, 0x0178, 0x0280, 0x02a0, 0x0198, 0x02a0, 0x02a8, 0x02a4, 0x02aa, + 0x0202, 0x0204, 0x0210, 0x021c, 0x0240, 0x024e, 0x0276, 0x01a0, 0x02a8, 0x029c, 0x029c, 0x029c, 0x02a8, 0x02aa, 0x02aa, 0x0104, 0x0110, 0x0118, 0x0224, 0x0260, 0x028c, 0x029e, 0x02a0, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0206, 0x010c, 0x0218, 0x0138, 0x0260, 0x0260, 0x02a0, 0x0298, 0x029e, 0x02a8, 0x02aa, 0x02a4, 0x02a8, 0x02aa, 0x02aa, 0x0208, 0x020e, 0x0220, + 0x0240, 0x026c, 0x028c, 0x0298, 0x02a0, 0x02aa, 0x02aa, 0x02aa, 0x02a6, 0x02aa, 0x02aa, 0x02aa, 0x0108, 0x0210, 0x0222, 0x0258, 0x0270, 0x02a8, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x02a8, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x020a, 0x0216, 0x0224, 0x0248, 0x0298, 0x02a8, 0x02aa, 0x02a6, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x020c, 0x0218, 0x0230, 0x025e, 0x0298, 0x02a8, + 0x02a8, 0x02a4, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0210, 0x021c, 0x0238, 0x0268, 0x02a4, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, + 0x020e, 0x021c, 0x0238, 0x0268, 0x02a4, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0210, 0x0220, 0x0246, 0x0290, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0102, 0x0104, 0x0204, 0x0208, 0x0110, 0x0118, 0x0220, 0x0128, 0x0140, 0x02a2, 0x0270, 0x0280, 0x0290, 0x02a0, 0x02aa, 0x0201, 0x0204, 0x0210, + 0x0208, 0x0218, 0x022a, 0x0160, 0x023e, 0x0298, 0x028a, 0x02a0, 0x0296, 0x02a8, 0x029c, 0x02aa, 0x0202, 0x0204, 0x0208, 0x0218, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0284, 0x0278, 0x019c, + 0x01a0, 0x02a0, 0x02aa, 0x0202, 0x0208, 0x020e, 0x0226, 0x024c, 0x0260, 0x0280, 0x0278, 0x02a6, 0x02aa, 0x02a8, 0x02a6, 0x029e, 0x02aa, 0x02aa, 0x0206, 0x0208, 0x021a, 0x0230, 0x024a, 0x0278, + 0x0278, 0x029e, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x0206, 0x0210, 0x0218, 0x0138, 0x025a, 0x0270, 0x02a0, 0x0298, 0x02a4, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, + 0x0208, 0x020e, 0x0220, 0x0240, 0x026c, 0x02a2, 0x029c, 0x029e, 0x02aa, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0206, 0x0216, 0x0228, 0x0248, 0x0280, 0x02a8, 0x02a6, 0x0292, 0x02a8, + 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x020a, 0x0214, 0x0228, 0x0250, 0x029c, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x020c, 0x0218, 0x022c, + 0x0258, 0x02a2, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x020e, 0x021a, 0x0230, 0x0260, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x020e, 0x021c, 0x0230, 0x0268, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0210, 0x0220, 0x0246, 0x0290, 0x029e, 0x02a8, + 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0201, 0x0202, 0x0204, 0x0208, 0x0210, 0x0118, 0x0140, 0x0228, 0x0240, 0x0160, 0x0170, 0x0280, 0x01a4, 0x0298, 0x02a6, + 0x0101, 0x0108, 0x0204, 0x0210, 0x0120, 0x0142, 0x0240, 0x0270, 0x02a0, 0x0192, 0x029e, 0x01aa, 0x02a8, 0x0298, 0x02aa, 0x0202, 0x0206, 0x020c, 0x0218, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, + 0x0298, 0x0270, 0x0288, 0x02a8, 0x02aa, 0x02aa, 0x0206, 0x0206, 0x020c, 0x0218, 0x024a, 0x0246, 0x0260, 0x0290, 0x029a, 0x02a8, 0x02a8, 0x02a8, 0x02a0, 0x02a8, 0x02aa, 0x0204, 0x020a, 0x011a, + 0x0236, 0x023e, 0x0270, 0x01a0, 0x0298, 0x02aa, 0x02a4, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x0206, 0x020c, 0x0218, 0x0242, 0x0280, 0x0288, 0x0298, 0x02a0, 0x02a4, 0x02a8, 0x02a8, 0x02aa, + 0x02a8, 0x02aa, 0x02aa, 0x0206, 0x0210, 0x0220, 0x0240, 0x0278, 0x0288, 0x02a0, 0x02a2, 0x02a4, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x0208, 0x0212, 0x0222, 0x0248, 0x0268, 0x02a8, + 0x02a8, 0x0298, 0x02a4, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02a8, 0x02aa, 0x020a, 0x0210, 0x0226, 0x0240, 0x0270, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02aa, + 0x020c, 0x0218, 0x0230, 0x0264, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x020e, 0x021c, 0x0230, 0x0268, 0x02a0, 0x02a8, 0x02a8, 0x02a0, 0x02a8, + 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x020c, 0x021a, 0x0234, 0x0264, 0x0290, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x0210, 0x0218, 0x0238, + 0x0278, 0x02a2, 0x02a8, 0x02aa, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x0102, 0x0102, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0298, 0x0292, 0x0186, 0x026c, + 0x01aa, 0x0298, 0x02a6, 0x0102, 0x0104, 0x010c, 0x020c, 0x022a, 0x0230, 0x0240, 0x0172, 0x0298, 0x029c, 0x01a8, 0x01a4, 0x0290, 0x02a0, 0x02aa, 0x0202, 0x0204, 0x020c, 0x0218, 0x0220, 0x0230, + 0x0240, 0x0250, 0x0298, 0x0298, 0x029c, 0x02a0, 0x02a8, 0x02aa, 0x02aa, 0x0204, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0178, 0x0270, 0x0288, 0x02a2, 0x02aa, 0x02a0, 0x02a8, 0x02aa, 0x02aa, + 0x0206, 0x020a, 0x0218, 0x0228, 0x024c, 0x0190, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x02a4, 0x02a4, 0x02a8, 0x02aa, 0x02aa, 0x0208, 0x020e, 0x021c, 0x0230, 0x027c, 0x0188, 0x02a0, 0x02a0, 0x02aa, + 0x02a8, 0x02aa, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x0108, 0x0212, 0x021a, 0x0230, 0x0280, 0x02aa, 0x029a, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x020a, 0x0210, 0x0220, + 0x024c, 0x0270, 0x02a8, 0x029c, 0x02a0, 0x02a4, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x0208, 0x0218, 0x0228, 0x0258, 0x0298, 0x02a8, 0x02a2, 0x02a6, 0x02aa, 0x02a8, 0x02a8, 0x02a8, + 0x02a8, 0x02a8, 0x02aa, 0x020c, 0x021a, 0x022a, 0x0270, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x020e, 0x021a, 0x0238, 0x0270, 0x02a8, 0x02a8, + 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x0210, 0x0218, 0x023e, 0x0280, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02aa, + 0x0212, 0x0222, 0x024a, 0x028c, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x0102, 0x0104, 0x0108, 0x0208, 0x0120, 0x0130, 0x0140, 0x013c, 0x0140, + 0x0198, 0x0190, 0x0176, 0x0298, 0x0298, 0x02a6, 0x0204, 0x0108, 0x0108, 0x0212, 0x012a, 0x022a, 0x0246, 0x0254, 0x0190, 0x019e, 0x01a8, 0x01a6, 0x0288, 0x02a0, 0x02a8, 0x0202, 0x0204, 0x0110, + 0x0120, 0x0230, 0x0230, 0x0240, 0x0178, 0x0180, 0x0280, 0x029c, 0x029c, 0x02a8, 0x02a4, 0x02aa, 0x0204, 0x020a, 0x0210, 0x021e, 0x0240, 0x0260, 0x017c, 0x027a, 0x02a8, 0x02aa, 0x0298, 0x02aa, + 0x02aa, 0x02a0, 0x02aa, 0x0206, 0x020a, 0x0218, 0x0230, 0x0150, 0x0288, 0x0294, 0x02a0, 0x02aa, 0x02a4, 0x02a8, 0x02aa, 0x02a0, 0x02a0, 0x02aa, 0x0108, 0x020c, 0x0218, 0x0238, 0x0262, 0x019c, + 0x02a0, 0x02a0, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x02aa, 0x02a0, 0x02a6, 0x020c, 0x0218, 0x0226, 0x0240, 0x026c, 0x029c, 0x02a0, 0x029e, 0x02aa, 0x02a0, 0x02a8, 0x02aa, 0x02a8, 0x02a0, 0x02aa, + 0x0208, 0x0210, 0x0220, 0x0238, 0x0278, 0x02a8, 0x02a0, 0x02a0, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x02a8, 0x02a0, 0x02aa, 0x0208, 0x0216, 0x022c, 0x0248, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02aa, + 0x02a0, 0x02a8, 0x02aa, 0x02a8, 0x02a0, 0x02aa, 0x020c, 0x0218, 0x0230, 0x026e, 0x02a8, 0x02a8, 0x02a0, 0x02a0, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x020c, 0x021e, 0x0238, + 0x026a, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02a0, 0x02a0, 0x02aa, 0x0210, 0x021e, 0x0238, 0x0280, 0x0298, 0x02a8, 0x02a0, 0x02a0, 0x02a8, 0x02a0, 0x02a8, 0x02aa, + 0x02a0, 0x02a0, 0x02aa, 0x0212, 0x022c, 0x0248, 0x0290, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x02aa, 0x02a0, 0x02aa, 0x0102, 0x0104, 0x0108, 0x0108, 0x0280, 0x0132, + 0x025c, 0x0230, 0x0174, 0x01a4, 0x01a6, 0x019e, 0x019e, 0x01a0, 0x02a8, 0x0102, 0x0104, 0x020c, 0x0210, 0x0220, 0x0236, 0x0140, 0x0150, 0x0280, 0x01aa, 0x0270, 0x01aa, 0x02a8, 0x02a8, 0x02aa, + 0x0106, 0x020a, 0x020e, 0x022a, 0x012e, 0x0260, 0x017e, 0x017c, 0x019c, 0x029c, 0x02a8, 0x02a2, 0x02a8, 0x02a8, 0x02aa, 0x0204, 0x0208, 0x0210, 0x0220, 0x0168, 0x0178, 0x0180, 0x02a0, 0x029e, + 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02a6, 0x0208, 0x0208, 0x0218, 0x0230, 0x0250, 0x0270, 0x027e, 0x02a0, 0x02aa, 0x02a4, 0x02a4, 0x02a0, 0x02aa, 0x02a6, 0x02a6, 0x0208, 0x0212, 0x0226, + 0x0240, 0x0270, 0x019e, 0x02a0, 0x02a2, 0x02a6, 0x02aa, 0x02a8, 0x02a6, 0x02a6, 0x02a6, 0x02a6, 0x020a, 0x0210, 0x0220, 0x023e, 0x0280, 0x02a8, 0x02a0, 0x02a0, 0x02a6, 0x02a6, 0x02a8, 0x02aa, + 0x02a8, 0x02a6, 0x02a6, 0x020a, 0x0214, 0x0228, 0x0248, 0x028e, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x02aa, 0x02a8, 0x02a6, 0x02a6, 0x02aa, 0x02aa, 0x0210, 0x0214, 0x0230, 0x0250, 0x029c, 0x02a8, + 0x02a0, 0x02a0, 0x02aa, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02a6, 0x0210, 0x021a, 0x0238, 0x0278, 0x02a0, 0x02a8, 0x02a6, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02a6, 0x02a8, 0x02aa, 0x02aa, + 0x0210, 0x0220, 0x0240, 0x0298, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02a6, 0x02a8, 0x02aa, 0x02aa, 0x0212, 0x0220, 0x0240, 0x0280, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02a8, + 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02a8, 0x02aa, 0x0218, 0x022c, 0x0258, 0x0290, 0x02a0, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x0102, 0x0104, 0x0204, + 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0160, 0x0170, 0x0180, 0x019c, 0x01a0, 0x0290, 0x0201, 0x0204, 0x0110, 0x0208, 0x011e, 0x0230, 0x0140, 0x0182, 0x017e, 0x0190, 0x019c, 0x0180, + 0x01a4, 0x02a8, 0x02a8, 0x0202, 0x0204, 0x0208, 0x0118, 0x0230, 0x0248, 0x0240, 0x0178, 0x0190, 0x0298, 0x01aa, 0x01a4, 0x01a8, 0x01a8, 0x029c, 0x0104, 0x010a, 0x020c, 0x0220, 0x0240, 0x0158, + 0x01a2, 0x01a0, 0x029c, 0x0288, 0x01a8, 0x01a8, 0x01a8, 0x01a6, 0x029c, 0x0206, 0x020a, 0x0218, 0x0230, 0x0248, 0x0268, 0x01a0, 0x01a0, 0x01a0, 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x0296, 0x02a6, + 0x0206, 0x020c, 0x0218, 0x0238, 0x017e, 0x019c, 0x01a0, 0x01a0, 0x01a0, 0x01a8, 0x01a8, 0x01a4, 0x01a8, 0x02a8, 0x02aa, 0x0208, 0x0210, 0x0220, 0x0230, 0x0268, 0x0278, 0x027e, 0x01a0, 0x01a0, + 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x02a4, 0x02a6, 0x0208, 0x0210, 0x0224, 0x0238, 0x0270, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x01a8, 0x01a8, 0x01a8, 0x01a8, 0x02a4, 0x02a6, 0x020a, 0x0214, 0x0220, + 0x0250, 0x0280, 0x01a4, 0x01a0, 0x01a0, 0x0280, 0x0290, 0x01a8, 0x01a8, 0x01a8, 0x02a0, 0x02a6, 0x020c, 0x0218, 0x0230, 0x0262, 0x0280, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x01a2, 0x01a8, 0x01a8, + 0x01a8, 0x02a6, 0x02a4, 0x0210, 0x021a, 0x0238, 0x0270, 0x0280, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x0290, 0x01a8, 0x01a8, 0x01a8, 0x02aa, 0x02aa, 0x0210, 0x021e, 0x0248, 0x0264, 0x01a0, 0x01a8, + 0x01a0, 0x01a0, 0x01a0, 0x0296, 0x01a8, 0x01a8, 0x01a8, 0x029c, 0x02aa, 0x0212, 0x0224, 0x0248, 0x0278, 0x01a0, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x029c, 0x01a8, 0x01a8, 0x02a8, 0x02a6, 0x02aa, +}; + +constexpr uint16_t samples_256[] = { + 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0254, 0x0340, 0x0354, 0x0350, 0x0354, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0150, 0x0240, 0x0254, 0x0254, + 0x0340, 0x0354, 0x0350, 0x0354, 0x0204, 0x0208, 0x0218, 0x0220, 0x0230, 0x0240, 0x0250, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0350, 0x0354, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, + 0x0250, 0x0354, 0x0352, 0x0354, 0x0352, 0x0354, 0x0354, 0x0354, 0x020a, 0x0214, 0x0228, 0x0250, 0x024e, 0x0254, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0218, + 0x0238, 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0212, 0x0220, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0218, 0x0230, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x023c, 0x0254, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0226, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x022a, 0x0250, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, + 0x0340, 0x0352, 0x0350, 0x0354, 0x0204, 0x0208, 0x0210, 0x0220, 0x0130, 0x0140, 0x0150, 0x0240, 0x0254, 0x0254, 0x0340, 0x0354, 0x0350, 0x0354, 0x0204, 0x0210, 0x0210, 0x0130, 0x0148, 0x0240, + 0x0250, 0x0254, 0x0354, 0x0352, 0x0352, 0x0354, 0x0350, 0x0354, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0354, 0x0354, 0x0354, 0x0352, 0x0354, 0x0354, 0x0354, 0x020a, 0x0214, + 0x0228, 0x0250, 0x0254, 0x0254, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0230, 0x0254, 0x0254, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0212, 0x0224, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0218, 0x022a, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x021e, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0226, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x022a, 0x0248, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0340, 0x0354, 0x0350, 0x0354, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, + 0x0150, 0x0154, 0x0154, 0x0154, 0x0340, 0x0354, 0x0350, 0x0354, 0x0208, 0x0210, 0x0118, 0x0140, 0x0148, 0x0154, 0x0150, 0x0340, 0x0354, 0x0354, 0x0354, 0x0354, 0x0350, 0x0354, 0x0208, 0x0210, + 0x0220, 0x0140, 0x0248, 0x0254, 0x0250, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x020c, 0x021a, 0x0234, 0x0250, 0x0254, 0x0350, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x034e, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x0352, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0214, 0x0228, 0x0250, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x021a, 0x0234, 0x0252, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0220, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0226, 0x024c, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x022a, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, + 0x0150, 0x0140, 0x0154, 0x0154, 0x0340, 0x0254, 0x0350, 0x0354, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0240, 0x0254, 0x0254, 0x0340, 0x0354, 0x0350, 0x0354, 0x0204, 0x0208, + 0x0118, 0x0130, 0x0148, 0x0240, 0x0250, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0350, 0x0354, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x020a, 0x0214, 0x0228, 0x0250, 0x0252, 0x0254, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0218, 0x0230, 0x0252, 0x0252, 0x0254, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x0252, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0212, 0x0224, 0x0248, 0x0254, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0218, 0x022a, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0218, 0x0234, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x021e, 0x023c, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0224, 0x0248, 0x0254, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0152, 0x0254, 0x0350, 0x0354, 0x0204, 0x0208, + 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0154, 0x0254, 0x0254, 0x0354, 0x0350, 0x0354, 0x0208, 0x0210, 0x0118, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x0254, 0x0254, 0x0254, 0x0354, + 0x0354, 0x0354, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x020c, 0x0218, 0x0228, 0x0250, 0x0254, 0x0254, 0x0254, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, + 0x0254, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0214, 0x0228, 0x024e, 0x0254, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x021a, 0x0234, 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0226, 0x0240, 0x0254, 0x0254, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x022c, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0104, 0x0108, + 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0154, 0x0254, 0x0250, 0x0354, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0154, 0x0254, 0x0254, 0x0354, + 0x0350, 0x0354, 0x0208, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, + 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x020c, 0x0218, 0x0228, 0x0250, 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0254, + 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0214, 0x0228, 0x024e, 0x0254, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x021a, 0x0234, 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0224, 0x0240, 0x0254, 0x0254, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0226, 0x024c, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x022c, 0x0254, + 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0140, 0x0140, 0x0154, 0x0154, 0x0154, 0x0254, + 0x0250, 0x0354, 0x0204, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0254, 0x0254, 0x0254, 0x0354, 0x0350, 0x0354, 0x0208, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0250, 0x0254, + 0x0254, 0x0254, 0x0354, 0x0354, 0x0350, 0x0354, 0x0208, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0254, 0x0354, 0x0354, 0x0352, 0x0354, 0x0354, 0x0354, 0x020c, 0x0218, 0x0230, 0x0254, + 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0214, 0x0220, 0x0246, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0214, 0x0228, 0x024e, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x021a, 0x0234, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0222, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0224, 0x0248, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0226, 0x024c, + 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0230, 0x0254, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0150, 0x0154, 0x0340, 0x0254, 0x0350, 0x0354, 0x0204, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, + 0x0154, 0x0154, 0x0340, 0x0354, 0x0350, 0x0350, 0x0208, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x034e, 0x0352, 0x0340, 0x0350, 0x0350, 0x0350, 0x0208, 0x0210, 0x0220, 0x0150, + 0x0252, 0x0254, 0x0250, 0x0340, 0x0350, 0x034e, 0x0340, 0x0354, 0x0350, 0x0346, 0x020c, 0x021a, 0x0230, 0x0250, 0x0252, 0x0254, 0x034c, 0x0340, 0x034e, 0x034e, 0x0340, 0x034e, 0x0350, 0x0346, + 0x0210, 0x0220, 0x0240, 0x0252, 0x0254, 0x0340, 0x034c, 0x0340, 0x0344, 0x0352, 0x0340, 0x0342, 0x0350, 0x0350, 0x0214, 0x0220, 0x0240, 0x0252, 0x0344, 0x0340, 0x0348, 0x0340, 0x034a, 0x0354, + 0x0340, 0x0342, 0x0350, 0x0350, 0x0214, 0x0228, 0x024e, 0x0254, 0x0344, 0x0340, 0x0344, 0x0340, 0x0348, 0x0354, 0x0340, 0x0342, 0x0350, 0x0350, 0x021a, 0x0234, 0x0252, 0x0340, 0x0344, 0x0340, + 0x0344, 0x0340, 0x0348, 0x0354, 0x0340, 0x034e, 0x0350, 0x0350, 0x0220, 0x0240, 0x0254, 0x0340, 0x0348, 0x0340, 0x0350, 0x0340, 0x0348, 0x0354, 0x0340, 0x0354, 0x0350, 0x0352, 0x0224, 0x0240, + 0x0254, 0x0340, 0x0348, 0x0340, 0x0344, 0x0340, 0x0348, 0x0354, 0x0340, 0x0342, 0x0350, 0x0354, 0x0226, 0x0248, 0x0254, 0x0340, 0x0344, 0x0342, 0x0344, 0x0340, 0x0348, 0x0348, 0x0340, 0x034e, + 0x0350, 0x0354, 0x0230, 0x0254, 0x0254, 0x0340, 0x0348, 0x0340, 0x0350, 0x0340, 0x0348, 0x0354, 0x0340, 0x0342, 0x0350, 0x0354, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0128, 0x0180, + 0x0160, 0x0170, 0x0180, 0x0360, 0x0280, 0x0380, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0260, 0x0270, 0x0280, 0x0360, 0x037e, 0x0380, 0x0204, 0x0208, 0x0210, 0x0220, + 0x0230, 0x0240, 0x0250, 0x0280, 0x0280, 0x0270, 0x0280, 0x0380, 0x0380, 0x0380, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0280, 0x0278, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, + 0x020a, 0x0214, 0x0228, 0x0250, 0x0278, 0x0280, 0x0278, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x020c, 0x0218, 0x0230, 0x0260, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0210, 0x0220, 0x0240, 0x0280, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0212, 0x0220, 0x0248, 0x0280, 0x0280, 0x0280, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0214, 0x0230, 0x0256, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x021a, 0x0230, + 0x0266, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x021e, 0x023c, 0x0270, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0220, 0x023e, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0224, 0x0248, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0104, 0x0204, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0360, 0x0280, 0x0380, 0x0204, 0x0208, 0x0210, 0x0220, + 0x0230, 0x0140, 0x0150, 0x0280, 0x0280, 0x0270, 0x0280, 0x0360, 0x0378, 0x0380, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0280, 0x0270, 0x0280, 0x0380, 0x0380, 0x0380, + 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0280, 0x0278, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x020a, 0x0214, 0x0228, 0x0250, 0x0278, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x020c, 0x0218, 0x0230, 0x0260, 0x0274, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0210, 0x0220, 0x0240, 0x0280, 0x0280, 0x0280, + 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0212, 0x0224, 0x0248, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0214, 0x0228, + 0x0248, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x021a, 0x0234, 0x0268, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x021e, 0x0238, 0x0270, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0220, 0x023e, 0x0268, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0224, 0x0248, 0x027e, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0104, 0x0108, 0x0110, 0x0120, + 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0360, 0x0370, 0x0380, 0x0204, 0x0208, 0x0210, 0x0120, 0x0230, 0x0140, 0x0150, 0x0180, 0x0180, 0x0270, 0x0280, 0x0360, 0x0378, 0x0380, + 0x0208, 0x0210, 0x0220, 0x0140, 0x0148, 0x0160, 0x0178, 0x0280, 0x0180, 0x0370, 0x0380, 0x0380, 0x0380, 0x0380, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0280, 0x0270, 0x0280, 0x0380, 0x0380, + 0x0380, 0x0380, 0x037e, 0x0380, 0x020c, 0x0218, 0x0230, 0x0260, 0x0270, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x037e, 0x0210, 0x0220, 0x0240, 0x0280, 0x0274, 0x0280, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0210, 0x0220, 0x0240, 0x0280, 0x0280, 0x027e, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0214, 0x0228, + 0x0248, 0x027c, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0380, 0x0218, 0x0236, 0x0258, 0x0280, 0x0380, 0x037e, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0220, 0x023e, 0x027c, 0x0280, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x0380, 0x0380, 0x0220, 0x0240, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0226, 0x0240, 0x027e, 0x037e, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x0380, 0x0380, 0x022c, 0x0254, 0x0280, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x037e, 0x0380, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0360, 0x0280, 0x0376, + 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0180, 0x0260, 0x0270, 0x0180, 0x0280, 0x0280, 0x037a, 0x0204, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0274, 0x0274, 0x0270, + 0x0280, 0x0280, 0x0280, 0x0378, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0278, 0x0270, 0x0274, 0x0280, 0x0280, 0x0280, 0x0280, 0x0378, 0x0378, 0x020a, 0x0214, 0x0228, 0x0250, 0x0270, 0x0274, + 0x0274, 0x027e, 0x027c, 0x0270, 0x027e, 0x0280, 0x0370, 0x037e, 0x020c, 0x0218, 0x0230, 0x0258, 0x0274, 0x0280, 0x027c, 0x027a, 0x0280, 0x0270, 0x0280, 0x027e, 0x0366, 0x037c, 0x0210, 0x0220, + 0x0238, 0x0274, 0x026e, 0x0280, 0x0276, 0x0280, 0x0280, 0x027a, 0x027e, 0x0360, 0x0378, 0x037c, 0x0212, 0x0220, 0x0240, 0x0270, 0x0274, 0x0280, 0x0276, 0x0280, 0x0280, 0x0370, 0x0280, 0x0360, + 0x0374, 0x037e, 0x0214, 0x0220, 0x0248, 0x0270, 0x0276, 0x0280, 0x0276, 0x0280, 0x0360, 0x0370, 0x037e, 0x0360, 0x0378, 0x037e, 0x0218, 0x0234, 0x0260, 0x0270, 0x0276, 0x0280, 0x0278, 0x0280, + 0x0360, 0x0370, 0x0380, 0x0360, 0x0370, 0x037e, 0x021e, 0x0238, 0x0260, 0x0270, 0x0278, 0x0280, 0x0278, 0x0280, 0x0360, 0x0370, 0x0380, 0x0360, 0x0364, 0x0374, 0x0220, 0x0238, 0x0260, 0x0280, + 0x0280, 0x0280, 0x0278, 0x0280, 0x0360, 0x0370, 0x0380, 0x0360, 0x0378, 0x037e, 0x0224, 0x0248, 0x0270, 0x0280, 0x0280, 0x0280, 0x0278, 0x0280, 0x0360, 0x0370, 0x0380, 0x0360, 0x0378, 0x037c, + 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0180, 0x027e, 0x037e, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0180, 0x0170, + 0x0180, 0x0180, 0x0280, 0x0280, 0x0208, 0x0210, 0x0220, 0x0140, 0x0160, 0x0180, 0x0178, 0x0274, 0x0274, 0x0270, 0x0180, 0x0280, 0x0280, 0x027c, 0x0208, 0x0210, 0x0220, 0x0240, 0x0258, 0x0274, + 0x0270, 0x027c, 0x0280, 0x0270, 0x0280, 0x027c, 0x0378, 0x0378, 0x020c, 0x0218, 0x0230, 0x0250, 0x0270, 0x027c, 0x0276, 0x027c, 0x027a, 0x0270, 0x0280, 0x0278, 0x037a, 0x037c, 0x0210, 0x0220, + 0x0238, 0x0260, 0x0276, 0x0280, 0x0278, 0x027e, 0x0280, 0x0270, 0x0280, 0x0378, 0x0378, 0x0380, 0x0210, 0x0220, 0x0240, 0x0274, 0x026e, 0x027e, 0x0280, 0x0280, 0x0280, 0x0270, 0x0280, 0x0368, + 0x0378, 0x0380, 0x0214, 0x0228, 0x0248, 0x026e, 0x0278, 0x027e, 0x0274, 0x0280, 0x027a, 0x0270, 0x0280, 0x0366, 0x0380, 0x037e, 0x021a, 0x0228, 0x0256, 0x0272, 0x027a, 0x027e, 0x0278, 0x0280, + 0x037e, 0x0270, 0x0380, 0x0378, 0x0374, 0x037c, 0x0218, 0x0238, 0x0260, 0x0272, 0x0278, 0x0280, 0x0278, 0x0280, 0x0372, 0x0370, 0x0380, 0x0378, 0x036e, 0x037e, 0x0220, 0x0238, 0x026a, 0x0272, + 0x0278, 0x0280, 0x0278, 0x0280, 0x0372, 0x0370, 0x0380, 0x0378, 0x0374, 0x037e, 0x0220, 0x0240, 0x026c, 0x027e, 0x0278, 0x0280, 0x0278, 0x0280, 0x0378, 0x0370, 0x0380, 0x0378, 0x0370, 0x037e, + 0x022a, 0x0254, 0x026c, 0x027e, 0x0278, 0x0280, 0x0278, 0x0280, 0x0372, 0x0370, 0x0380, 0x0378, 0x0378, 0x037e, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, + 0x0180, 0x0180, 0x027e, 0x027e, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0274, 0x0170, 0x0180, 0x0180, 0x0280, 0x0278, 0x0208, 0x0210, 0x0220, 0x0140, 0x0148, 0x0180, + 0x0178, 0x0180, 0x0180, 0x0170, 0x0180, 0x0180, 0x0180, 0x0274, 0x0208, 0x0210, 0x0220, 0x0240, 0x0258, 0x017c, 0x0270, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0274, 0x020c, 0x0218, + 0x0228, 0x0258, 0x0170, 0x017e, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0180, 0x0374, 0x0210, 0x0220, 0x0230, 0x0260, 0x0176, 0x017c, 0x017e, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, + 0x0280, 0x0374, 0x0210, 0x0220, 0x023e, 0x025c, 0x0180, 0x017c, 0x017a, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0366, 0x037e, 0x0214, 0x0220, 0x0248, 0x025a, 0x0180, 0x017e, 0x017a, 0x0180, + 0x0180, 0x0180, 0x0180, 0x0180, 0x036c, 0x0374, 0x021a, 0x0228, 0x0256, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0380, 0x0374, 0x0220, 0x0238, 0x0260, 0x025e, + 0x0180, 0x0180, 0x0178, 0x0180, 0x0180, 0x0370, 0x0380, 0x0180, 0x036c, 0x0374, 0x0220, 0x023a, 0x025c, 0x0260, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0370, 0x0380, 0x0180, 0x0380, 0x037e, + 0x0220, 0x0240, 0x0260, 0x0260, 0x0180, 0x017c, 0x0178, 0x0180, 0x0180, 0x0370, 0x0380, 0x0360, 0x0380, 0x037e, 0x0228, 0x0250, 0x0260, 0x0260, 0x0180, 0x0180, 0x0278, 0x0180, 0x0180, 0x0370, + 0x0380, 0x0360, 0x036c, 0x0380, 0x0104, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0180, 0x027c, 0x027a, 0x0204, 0x0208, 0x0210, 0x0220, 0x0130, 0x0140, + 0x0150, 0x0180, 0x0274, 0x0170, 0x0180, 0x0180, 0x0278, 0x027a, 0x0208, 0x0210, 0x0220, 0x0140, 0x0160, 0x0180, 0x0178, 0x0274, 0x0260, 0x0170, 0x0180, 0x0180, 0x0278, 0x027a, 0x0208, 0x0210, + 0x0220, 0x0240, 0x0258, 0x0274, 0x0270, 0x0180, 0x0260, 0x0180, 0x0180, 0x0180, 0x026e, 0x0268, 0x020c, 0x0218, 0x0230, 0x0258, 0x0260, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x017e, + 0x026e, 0x0180, 0x0210, 0x0220, 0x0238, 0x0260, 0x025e, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x017e, 0x0266, 0x027a, 0x0210, 0x0220, 0x023e, 0x025e, 0x0260, 0x017e, 0x0180, 0x0180, + 0x017e, 0x0180, 0x0180, 0x017e, 0x027e, 0x0368, 0x0214, 0x0220, 0x0248, 0x025a, 0x025e, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x017e, 0x0266, 0x0368, 0x021a, 0x0236, 0x0256, 0x0260, + 0x0260, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0260, 0x027a, 0x0368, 0x0220, 0x0240, 0x0258, 0x0260, 0x0260, 0x0180, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0260, 0x0350, 0x0364, + 0x0220, 0x023e, 0x025c, 0x0260, 0x0260, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x0260, 0x0356, 0x0364, 0x0226, 0x0240, 0x0258, 0x0180, 0x0260, 0x0180, 0x0180, 0x0180, 0x0260, 0x0180, + 0x0180, 0x0360, 0x036e, 0x0364, 0x022c, 0x0254, 0x0260, 0x0260, 0x0260, 0x0180, 0x0278, 0x0180, 0x0260, 0x0180, 0x0180, 0x0360, 0x0350, 0x0364, 0x0104, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, + 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0174, 0x0180, 0x0180, 0x0204, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0174, 0x0170, 0x0274, 0x0180, 0x0178, 0x017e, 0x0208, 0x0210, + 0x0220, 0x0140, 0x0160, 0x015e, 0x0270, 0x0174, 0x0160, 0x0170, 0x0180, 0x0180, 0x0178, 0x017e, 0x0208, 0x0210, 0x0220, 0x014e, 0x0258, 0x0174, 0x0168, 0x017c, 0x0160, 0x016e, 0x017e, 0x0180, + 0x0178, 0x0358, 0x020c, 0x0218, 0x0228, 0x0258, 0x015e, 0x0162, 0x0170, 0x0174, 0x0160, 0x016e, 0x017c, 0x0160, 0x017c, 0x0364, 0x0210, 0x0220, 0x0238, 0x0158, 0x0160, 0x0162, 0x0162, 0x0174, + 0x0160, 0x016e, 0x017e, 0x0160, 0x017c, 0x0356, 0x0210, 0x0220, 0x023e, 0x0160, 0x0160, 0x0166, 0x0160, 0x017e, 0x0160, 0x016e, 0x0280, 0x0160, 0x017c, 0x0356, 0x0214, 0x0220, 0x0248, 0x0250, + 0x0160, 0x015e, 0x0166, 0x017e, 0x0160, 0x016e, 0x017e, 0x0360, 0x017c, 0x0360, 0x021a, 0x0236, 0x024c, 0x0160, 0x0160, 0x0162, 0x0160, 0x017e, 0x0160, 0x0172, 0x0280, 0x0360, 0x0356, 0x0358, + 0x0220, 0x023e, 0x024c, 0x0160, 0x0160, 0x0162, 0x0178, 0x0174, 0x0160, 0x0270, 0x0280, 0x0360, 0x034c, 0x035e, 0x0220, 0x023e, 0x0248, 0x0160, 0x0160, 0x0160, 0x0160, 0x0280, 0x0160, 0x0370, + 0x0360, 0x0360, 0x0352, 0x035e, 0x0220, 0x0242, 0x0250, 0x0160, 0x0160, 0x0160, 0x0160, 0x0280, 0x0160, 0x0270, 0x0360, 0x0360, 0x0350, 0x0376, 0x022c, 0x0248, 0x0250, 0x0160, 0x0160, 0x0160, + 0x0160, 0x0280, 0x0160, 0x0370, 0x0360, 0x0360, 0x0376, 0x037e, 0x0104, 0x0104, 0x0110, 0x0210, 0x0120, 0x0140, 0x0150, 0x0140, 0x0260, 0x0270, 0x0180, 0x0292, 0x029c, 0x03aa, 0x0208, 0x020e, + 0x020a, 0x0114, 0x014c, 0x0250, 0x0262, 0x0270, 0x0298, 0x029c, 0x029a, 0x02a8, 0x029c, 0x03a8, 0x0206, 0x020c, 0x0118, 0x0130, 0x0230, 0x0240, 0x0178, 0x0280, 0x02aa, 0x0270, 0x02a6, 0x03a6, + 0x03aa, 0x03aa, 0x0208, 0x0110, 0x0218, 0x022e, 0x025e, 0x0260, 0x0280, 0x029e, 0x02a2, 0x02a4, 0x02aa, 0x03a0, 0x03a0, 0x03aa, 0x010a, 0x0210, 0x021e, 0x024c, 0x0272, 0x027c, 0x02a0, 0x02a0, + 0x02a8, 0x03a8, 0x02a6, 0x03a8, 0x03aa, 0x03aa, 0x020c, 0x0118, 0x0244, 0x0256, 0x01a2, 0x0288, 0x02a0, 0x02a4, 0x02aa, 0x03a8, 0x03a8, 0x03aa, 0x03aa, 0x03aa, 0x0210, 0x0220, 0x0238, 0x0280, + 0x02a0, 0x0280, 0x02a0, 0x02a8, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0210, 0x021e, 0x0230, 0x028a, 0x02a8, 0x02a2, 0x02a0, 0x02aa, 0x03a8, 0x03a8, 0x03aa, 0x03a8, 0x03aa, 0x03aa, + 0x0216, 0x0220, 0x0240, 0x0290, 0x029c, 0x029c, 0x02a8, 0x03aa, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0218, 0x0230, 0x0258, 0x029c, 0x02a8, 0x02aa, 0x03a4, 0x03aa, 0x03aa, 0x03aa, + 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0220, 0x0230, 0x0268, 0x02a2, 0x02a8, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x021a, 0x0238, 0x026c, 0x029e, 0x02aa, 0x02aa, + 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0220, 0x0240, 0x0284, 0x02aa, 0x02aa, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0104, 0x0108, + 0x0110, 0x0110, 0x0130, 0x0220, 0x0128, 0x0240, 0x0160, 0x0170, 0x0180, 0x0260, 0x0298, 0x02a0, 0x0202, 0x0106, 0x0112, 0x011a, 0x0220, 0x0152, 0x025c, 0x0260, 0x02a0, 0x0284, 0x029e, 0x02a8, + 0x0298, 0x02a8, 0x0206, 0x020c, 0x0118, 0x0220, 0x0230, 0x0240, 0x0178, 0x0280, 0x02a0, 0x0198, 0x02a0, 0x02a8, 0x02a4, 0x02aa, 0x0204, 0x0210, 0x021c, 0x0240, 0x024e, 0x0276, 0x01a0, 0x02a8, + 0x029c, 0x029c, 0x029c, 0x02a8, 0x02aa, 0x02aa, 0x0110, 0x0118, 0x0224, 0x0260, 0x028c, 0x029e, 0x02a0, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x010c, 0x0218, 0x0138, 0x0260, + 0x0260, 0x02a0, 0x0298, 0x029e, 0x02a8, 0x02aa, 0x02a4, 0x02a8, 0x02aa, 0x02aa, 0x020e, 0x0220, 0x0240, 0x026c, 0x028c, 0x0298, 0x02a0, 0x02aa, 0x02aa, 0x02aa, 0x02a6, 0x02aa, 0x02aa, 0x02aa, + 0x0210, 0x0222, 0x0258, 0x0270, 0x02a8, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0216, 0x0224, 0x0248, 0x0298, 0x02a8, 0x02aa, 0x02a6, 0x02aa, 0x02aa, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0218, 0x0230, 0x025e, 0x0298, 0x02a8, 0x02a8, 0x02a4, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x021c, 0x0238, 0x0268, 0x02a4, 0x02a8, 0x02aa, + 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x021c, 0x0238, 0x0268, 0x02a4, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0220, 0x0246, + 0x0290, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x0104, 0x0204, 0x0208, 0x0110, 0x0118, 0x0220, 0x0128, 0x0140, 0x02a2, 0x0270, 0x0280, 0x0290, + 0x02a0, 0x02aa, 0x0204, 0x0210, 0x0208, 0x0218, 0x022a, 0x0160, 0x023e, 0x0298, 0x028a, 0x02a0, 0x0296, 0x02a8, 0x029c, 0x03aa, 0x0204, 0x0208, 0x0218, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, + 0x0284, 0x0278, 0x019c, 0x0390, 0x02a0, 0x03aa, 0x0208, 0x020e, 0x0226, 0x024c, 0x0260, 0x0280, 0x0278, 0x02a6, 0x02aa, 0x02a8, 0x02a6, 0x029e, 0x03a8, 0x03aa, 0x0208, 0x021a, 0x0230, 0x024a, + 0x0278, 0x0278, 0x029e, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x03aa, 0x03aa, 0x03aa, 0x0210, 0x0218, 0x0138, 0x025a, 0x0270, 0x02a0, 0x0298, 0x02a4, 0x02a8, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, + 0x020e, 0x0220, 0x0240, 0x026c, 0x02a2, 0x029c, 0x029e, 0x02aa, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0216, 0x0228, 0x0248, 0x0280, 0x02a8, 0x02a6, 0x0292, 0x02a8, 0x03a8, 0x03aa, + 0x03aa, 0x03a8, 0x03aa, 0x03a8, 0x0214, 0x0228, 0x0250, 0x029c, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0218, 0x022c, 0x0258, 0x02a2, 0x02a8, 0x02aa, + 0x02aa, 0x03aa, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x021a, 0x0230, 0x0260, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x021c, 0x0230, + 0x0268, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0220, 0x0246, 0x0290, 0x029e, 0x02a8, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, + 0x03aa, 0x03aa, 0x0202, 0x0204, 0x0208, 0x0210, 0x0118, 0x0140, 0x0228, 0x0240, 0x0160, 0x0170, 0x0280, 0x01a4, 0x0298, 0x02a6, 0x0108, 0x0204, 0x0210, 0x0120, 0x0142, 0x0240, 0x0270, 0x02a0, + 0x0192, 0x029e, 0x01aa, 0x02a8, 0x0298, 0x02aa, 0x0206, 0x020c, 0x0218, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0298, 0x0270, 0x0288, 0x02a8, 0x02aa, 0x03aa, 0x0206, 0x020c, 0x0218, 0x024a, + 0x0246, 0x0260, 0x0290, 0x029a, 0x02a8, 0x02a8, 0x02a8, 0x02a0, 0x03aa, 0x03aa, 0x020a, 0x011a, 0x0236, 0x023e, 0x0270, 0x01a0, 0x0298, 0x02aa, 0x02a4, 0x02a8, 0x02a8, 0x02a8, 0x03aa, 0x03aa, + 0x020c, 0x0218, 0x0242, 0x0280, 0x0288, 0x0298, 0x02a0, 0x02a4, 0x02a8, 0x02a8, 0x02aa, 0x03a8, 0x03aa, 0x03aa, 0x0210, 0x0220, 0x0240, 0x0278, 0x0288, 0x02a0, 0x02a2, 0x02a4, 0x02a8, 0x02a8, + 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0212, 0x0222, 0x0248, 0x0268, 0x02a8, 0x02a8, 0x0298, 0x02a4, 0x02a8, 0x02a8, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x0210, 0x0226, 0x0240, 0x0270, 0x02a8, 0x02a0, + 0x02a0, 0x02aa, 0x03aa, 0x03a8, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x0218, 0x0230, 0x0264, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x03a8, 0x03a8, 0x03a8, 0x03a8, 0x03aa, 0x03aa, 0x021c, 0x0230, + 0x0268, 0x02a0, 0x02a8, 0x02a8, 0x02a0, 0x03aa, 0x03aa, 0x03a8, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x021a, 0x0234, 0x0264, 0x0290, 0x02aa, 0x02aa, 0x02aa, 0x03aa, 0x03a8, 0x03a8, 0x03aa, 0x03aa, + 0x03aa, 0x03aa, 0x0218, 0x0238, 0x0278, 0x02a2, 0x02a8, 0x02aa, 0x03aa, 0x03aa, 0x03a8, 0x03a8, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0102, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0298, + 0x0292, 0x0186, 0x026c, 0x01aa, 0x0298, 0x03a8, 0x0104, 0x010c, 0x020c, 0x022a, 0x0230, 0x0240, 0x0172, 0x0298, 0x029c, 0x01a8, 0x01a4, 0x0290, 0x02a0, 0x03a8, 0x0204, 0x020c, 0x0218, 0x0220, + 0x0230, 0x0240, 0x0250, 0x0298, 0x0298, 0x029c, 0x02a0, 0x02a8, 0x03a8, 0x03a6, 0x0208, 0x0210, 0x0220, 0x0240, 0x0260, 0x0178, 0x0270, 0x0288, 0x02a2, 0x02aa, 0x02a0, 0x039e, 0x03a0, 0x03a6, + 0x020a, 0x0218, 0x0228, 0x024c, 0x0190, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x03a4, 0x03a0, 0x03a2, 0x03a8, 0x03a4, 0x020e, 0x021c, 0x0230, 0x027c, 0x0188, 0x02a0, 0x02a0, 0x02aa, 0x02a8, 0x03a8, + 0x03a2, 0x03a6, 0x03a8, 0x03a0, 0x0212, 0x021a, 0x0230, 0x0280, 0x02aa, 0x029a, 0x02a0, 0x02a8, 0x03aa, 0x03a8, 0x03a8, 0x03a0, 0x03a8, 0x03aa, 0x0210, 0x0220, 0x024c, 0x0270, 0x02a8, 0x029c, + 0x02a0, 0x03a0, 0x039e, 0x03a6, 0x03a0, 0x03a0, 0x03a4, 0x03a6, 0x0218, 0x0228, 0x0258, 0x0298, 0x02a8, 0x02a2, 0x02a6, 0x03a0, 0x03a8, 0x03a8, 0x03a0, 0x03a0, 0x039c, 0x03a8, 0x021a, 0x022a, + 0x0270, 0x02a0, 0x02a8, 0x02a8, 0x03a0, 0x03a0, 0x03a8, 0x03a8, 0x03a0, 0x03a0, 0x03a6, 0x03a8, 0x021a, 0x0238, 0x0270, 0x02a8, 0x02a8, 0x02a8, 0x03a0, 0x03a0, 0x03a8, 0x03a8, 0x03a0, 0x03a0, + 0x03aa, 0x039e, 0x0218, 0x023e, 0x0280, 0x02a8, 0x02a8, 0x02aa, 0x03a0, 0x03a0, 0x03a8, 0x03a8, 0x03a0, 0x03a0, 0x03aa, 0x03a0, 0x0222, 0x024a, 0x028c, 0x02a8, 0x02a8, 0x03a0, 0x03a0, 0x03a0, + 0x03a8, 0x03a8, 0x03a0, 0x03a0, 0x03a6, 0x03a8, 0x0104, 0x0108, 0x0208, 0x0120, 0x0130, 0x0140, 0x013c, 0x0140, 0x0198, 0x0190, 0x0176, 0x0298, 0x0298, 0x02a6, 0x0108, 0x0108, 0x0212, 0x012a, + 0x022a, 0x0246, 0x0254, 0x0190, 0x019e, 0x01a8, 0x01a6, 0x0288, 0x02a0, 0x03a4, 0x0204, 0x0110, 0x0120, 0x0230, 0x0230, 0x0240, 0x0178, 0x0180, 0x0280, 0x029c, 0x029c, 0x02a8, 0x03a0, 0x03a4, + 0x020a, 0x0210, 0x021e, 0x0240, 0x0260, 0x017c, 0x027a, 0x02a8, 0x02aa, 0x0298, 0x02aa, 0x02aa, 0x03a0, 0x03a4, 0x020a, 0x0218, 0x0230, 0x0150, 0x0288, 0x0294, 0x02a0, 0x02aa, 0x02a4, 0x02a8, + 0x02aa, 0x02a0, 0x03a0, 0x03aa, 0x020c, 0x0218, 0x0238, 0x0262, 0x019c, 0x02a0, 0x02a0, 0x02a0, 0x0390, 0x02a8, 0x0380, 0x02aa, 0x03a0, 0x03a6, 0x0218, 0x0226, 0x0240, 0x026c, 0x029c, 0x02a0, + 0x029e, 0x02aa, 0x0390, 0x02a8, 0x0380, 0x0390, 0x03a0, 0x03a6, 0x0210, 0x0220, 0x0238, 0x0278, 0x02a8, 0x02a0, 0x02a0, 0x02a0, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a0, 0x03a6, 0x0216, 0x022c, + 0x0248, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a0, 0x03a8, 0x0218, 0x0230, 0x026e, 0x02a8, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0380, 0x03a0, + 0x03a0, 0x03a6, 0x021e, 0x0238, 0x026a, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a0, 0x03a6, 0x021e, 0x0238, 0x0280, 0x0298, 0x02a8, 0x02a0, 0x02a0, 0x0380, + 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a0, 0x03a6, 0x022c, 0x0248, 0x0290, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a0, 0x03a6, 0x0104, 0x0108, 0x0108, 0x0280, + 0x0132, 0x025c, 0x0230, 0x0174, 0x01a4, 0x01a6, 0x019e, 0x019e, 0x01a0, 0x03a8, 0x0104, 0x020c, 0x0210, 0x0220, 0x0236, 0x0140, 0x0150, 0x0280, 0x01aa, 0x0270, 0x01aa, 0x02a8, 0x02a8, 0x03a2, + 0x020a, 0x020e, 0x022a, 0x012e, 0x0260, 0x017e, 0x017c, 0x019c, 0x029c, 0x02a8, 0x02a2, 0x02a8, 0x03aa, 0x03aa, 0x0208, 0x0210, 0x0220, 0x0168, 0x0178, 0x0180, 0x02a0, 0x029e, 0x02aa, 0x03a8, + 0x02aa, 0x02aa, 0x03a2, 0x03a8, 0x0208, 0x0218, 0x0230, 0x0250, 0x0270, 0x027e, 0x02a0, 0x02aa, 0x0390, 0x02a4, 0x0380, 0x03a6, 0x03a2, 0x03a6, 0x0212, 0x0226, 0x0240, 0x0270, 0x019e, 0x02a0, + 0x02a2, 0x02a6, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a6, 0x03a8, 0x0210, 0x0220, 0x023e, 0x0280, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a6, 0x03a6, 0x0214, 0x0228, + 0x0248, 0x028e, 0x02a8, 0x02a0, 0x02a0, 0x03a0, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x039e, 0x03a6, 0x0214, 0x0230, 0x0250, 0x029c, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0380, 0x0380, + 0x039c, 0x03a6, 0x021a, 0x0238, 0x0278, 0x02a0, 0x02a8, 0x02a6, 0x02a0, 0x03a0, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x039c, 0x03a6, 0x0220, 0x0240, 0x0298, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x0380, + 0x0390, 0x03a8, 0x0380, 0x0380, 0x0392, 0x03a6, 0x0220, 0x0240, 0x0280, 0x02a8, 0x02a8, 0x02aa, 0x03a0, 0x03a0, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x03a2, 0x03a6, 0x022c, 0x0258, 0x0290, 0x02a0, + 0x02a8, 0x0390, 0x03a0, 0x0380, 0x0390, 0x03a8, 0x0380, 0x03a0, 0x039e, 0x03a6, 0x0104, 0x0204, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0160, 0x0170, 0x0180, 0x019c, 0x01a0, 0x0290, + 0x0204, 0x0110, 0x0208, 0x011e, 0x0230, 0x0140, 0x0182, 0x017e, 0x0190, 0x019c, 0x0180, 0x01a4, 0x02a8, 0x02a8, 0x0204, 0x0208, 0x0118, 0x0230, 0x0248, 0x0240, 0x0178, 0x0190, 0x0298, 0x01aa, + 0x01a4, 0x01a8, 0x01a8, 0x029c, 0x010a, 0x020c, 0x0220, 0x0240, 0x0158, 0x01a2, 0x01a0, 0x029c, 0x0288, 0x01a8, 0x01a8, 0x01a8, 0x01a6, 0x029c, 0x020a, 0x0218, 0x0230, 0x0248, 0x0268, 0x01a0, + 0x01a0, 0x01a0, 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x0296, 0x039c, 0x020c, 0x0218, 0x0238, 0x017e, 0x019c, 0x01a0, 0x01a0, 0x01a0, 0x01a8, 0x01a8, 0x01a4, 0x01a8, 0x0396, 0x03aa, 0x0210, 0x0220, + 0x0230, 0x0268, 0x0278, 0x027e, 0x01a0, 0x01a0, 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x0396, 0x039c, 0x0210, 0x0224, 0x0238, 0x0270, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x01a8, 0x01a8, 0x01a8, 0x01a8, + 0x0396, 0x03a6, 0x0214, 0x0220, 0x0250, 0x0280, 0x01a4, 0x01a0, 0x01a0, 0x0280, 0x0290, 0x01a8, 0x01a8, 0x01a8, 0x0396, 0x039c, 0x0218, 0x0230, 0x0262, 0x0280, 0x01a8, 0x01a0, 0x01a0, 0x01a0, + 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x0396, 0x039c, 0x021a, 0x0238, 0x0270, 0x0280, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x0290, 0x01a8, 0x01a8, 0x01a8, 0x0396, 0x03a6, 0x021e, 0x0248, 0x0264, 0x01a0, + 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x0296, 0x01a8, 0x01a8, 0x01a8, 0x0396, 0x0388, 0x0224, 0x0248, 0x0278, 0x01a0, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x029c, 0x01a8, 0x01a8, 0x02a8, 0x039c, 0x03a6, +}; + +constexpr uint16_t samples_512[] = { + 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0254, 0x0340, 0x0454, 0x0450, 0x0454, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0150, 0x0240, 0x0254, 0x0254, 0x0340, 0x0454, + 0x0454, 0x0454, 0x0208, 0x0218, 0x0220, 0x0230, 0x0240, 0x0250, 0x0254, 0x0354, 0x0354, 0x0354, 0x0454, 0x0454, 0x0354, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0354, 0x0352, 0x0354, + 0x0352, 0x0354, 0x0354, 0x0354, 0x0214, 0x0228, 0x0250, 0x024e, 0x0254, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0218, 0x0238, 0x0252, 0x0254, 0x0254, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0252, 0x0254, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0254, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0230, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0240, 0x0254, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x023c, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0240, + 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0250, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0340, 0x0454, 0x0450, 0x0454, 0x0208, 0x0210, 0x0220, 0x0130, 0x0140, 0x0150, 0x0240, 0x0254, 0x0254, 0x0340, + 0x0454, 0x0454, 0x0454, 0x0210, 0x0210, 0x0130, 0x0148, 0x0240, 0x0250, 0x0254, 0x0354, 0x0352, 0x0352, 0x0354, 0x0350, 0x0354, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, 0x0354, 0x0354, + 0x0354, 0x0352, 0x0354, 0x0354, 0x0354, 0x0214, 0x0228, 0x0250, 0x0254, 0x0254, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0230, 0x0254, 0x0254, 0x0352, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0252, 0x0254, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0224, 0x0240, 0x0254, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x022a, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0240, 0x0254, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0248, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0440, 0x0454, 0x0450, 0x0454, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, 0x0154, 0x0154, + 0x0340, 0x0454, 0x0454, 0x0454, 0x0210, 0x0118, 0x0140, 0x0148, 0x0154, 0x0150, 0x0340, 0x0354, 0x0354, 0x0454, 0x0454, 0x0454, 0x0454, 0x0210, 0x0220, 0x0140, 0x0248, 0x0254, 0x0250, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0454, 0x0354, 0x0354, 0x021a, 0x0234, 0x0250, 0x0254, 0x0350, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0252, 0x034e, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0252, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0228, 0x0250, 0x0254, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0234, 0x0252, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0240, + 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x024c, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0340, 0x0454, 0x0450, 0x0454, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0240, 0x0254, + 0x0254, 0x0340, 0x0450, 0x0454, 0x0454, 0x0208, 0x0118, 0x0130, 0x0148, 0x0240, 0x0250, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0454, 0x0354, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, 0x0250, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0214, 0x0228, 0x0250, 0x0252, 0x0254, 0x0350, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0218, 0x0230, 0x0252, 0x0252, + 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0220, 0x0240, 0x0252, 0x0252, 0x0352, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0224, 0x0248, + 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x022a, 0x0252, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0234, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x023c, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0240, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0248, 0x0254, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, 0x0354, + 0x0354, 0x0354, 0x0354, 0x0354, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0440, 0x0454, 0x0450, 0x0454, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0154, + 0x0154, 0x0254, 0x0440, 0x0454, 0x0454, 0x0454, 0x0210, 0x0118, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0210, 0x0220, 0x0240, 0x0248, 0x0254, + 0x0250, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0218, 0x0228, 0x0250, 0x0254, 0x0254, 0x0254, 0x0452, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0220, 0x0240, 0x0252, + 0x0254, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0228, + 0x024e, 0x0254, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0234, 0x0252, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, + 0x0454, 0x0240, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0240, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, + 0x0454, 0x0454, 0x0454, 0x0240, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, + 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0154, 0x0154, 0x0440, 0x0454, 0x0450, 0x0454, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, + 0x0154, 0x0154, 0x0254, 0x0440, 0x0454, 0x0454, 0x0454, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0210, 0x0220, 0x0240, 0x0248, + 0x0254, 0x0250, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0218, 0x0228, 0x0250, 0x0252, 0x0254, 0x0254, 0x0452, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0220, 0x0240, + 0x0254, 0x0252, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0220, 0x0240, 0x0252, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, + 0x0228, 0x024e, 0x0254, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0234, 0x0252, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, + 0x0454, 0x0454, 0x0240, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0240, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, + 0x0454, 0x0454, 0x0454, 0x0454, 0x024c, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0254, 0x0254, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, + 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0454, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0140, 0x0140, 0x0154, 0x0154, 0x0440, 0x0440, 0x0450, 0x0454, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, + 0x0150, 0x0154, 0x0254, 0x0438, 0x0440, 0x0454, 0x0454, 0x0454, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0250, 0x0254, 0x0448, 0x0452, 0x0454, 0x0454, 0x0452, 0x0450, 0x0210, 0x0220, 0x0240, + 0x0248, 0x0254, 0x0250, 0x0450, 0x0452, 0x0452, 0x0454, 0x0454, 0x0450, 0x0454, 0x0218, 0x0230, 0x0254, 0x0252, 0x0254, 0x0254, 0x0452, 0x0452, 0x0454, 0x0454, 0x0454, 0x0450, 0x0454, 0x0220, + 0x0240, 0x0252, 0x0254, 0x0254, 0x0450, 0x0452, 0x0452, 0x0454, 0x0454, 0x0454, 0x0450, 0x0450, 0x0220, 0x0246, 0x0252, 0x0254, 0x0354, 0x0450, 0x0452, 0x0450, 0x0454, 0x0454, 0x0454, 0x0450, + 0x0454, 0x0228, 0x024e, 0x0254, 0x0254, 0x0450, 0x0450, 0x0450, 0x0450, 0x0454, 0x0450, 0x0450, 0x0450, 0x0450, 0x0234, 0x0252, 0x0254, 0x0354, 0x0452, 0x0450, 0x0452, 0x0450, 0x0454, 0x0454, + 0x0450, 0x0450, 0x0450, 0x0240, 0x0254, 0x0354, 0x0354, 0x0450, 0x0450, 0x0450, 0x0450, 0x0454, 0x0450, 0x0450, 0x0450, 0x0450, 0x0248, 0x0254, 0x0354, 0x0454, 0x0450, 0x0450, 0x0450, 0x0450, + 0x0454, 0x0450, 0x0450, 0x0450, 0x0450, 0x024c, 0x0254, 0x0354, 0x0450, 0x0450, 0x0450, 0x0450, 0x0450, 0x0454, 0x0450, 0x0450, 0x0450, 0x0450, 0x0254, 0x0254, 0x0354, 0x0454, 0x0450, 0x0450, + 0x0450, 0x0450, 0x0454, 0x0450, 0x0450, 0x0450, 0x0450, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0150, 0x0154, 0x0440, 0x0454, 0x0450, 0x0450, 0x0208, 0x0110, 0x0120, 0x0130, + 0x0140, 0x0150, 0x0154, 0x0154, 0x0154, 0x0440, 0x0454, 0x0450, 0x0450, 0x0110, 0x0120, 0x0140, 0x0148, 0x0154, 0x0150, 0x0254, 0x034e, 0x044a, 0x0340, 0x0450, 0x0450, 0x0450, 0x0210, 0x0220, + 0x0150, 0x0252, 0x0254, 0x0250, 0x0340, 0x0350, 0x034e, 0x0340, 0x0446, 0x0450, 0x0450, 0x021a, 0x0230, 0x0250, 0x0252, 0x0254, 0x034c, 0x0340, 0x034e, 0x0448, 0x0448, 0x0454, 0x044c, 0x0446, + 0x0220, 0x0240, 0x0252, 0x0254, 0x0340, 0x034c, 0x0340, 0x0448, 0x044a, 0x0446, 0x0454, 0x0446, 0x0446, 0x0220, 0x0240, 0x0252, 0x0344, 0x0340, 0x0348, 0x0340, 0x0448, 0x044e, 0x0446, 0x044a, + 0x0446, 0x0446, 0x0228, 0x024e, 0x0254, 0x0344, 0x0340, 0x0344, 0x0340, 0x0448, 0x0454, 0x0446, 0x0454, 0x0450, 0x0450, 0x0234, 0x0252, 0x0340, 0x0344, 0x0340, 0x0344, 0x0340, 0x0448, 0x0454, + 0x0444, 0x044e, 0x0450, 0x0446, 0x0240, 0x0254, 0x0340, 0x0348, 0x0340, 0x0350, 0x0340, 0x0448, 0x0354, 0x0340, 0x044e, 0x0444, 0x0446, 0x0240, 0x0254, 0x0340, 0x0348, 0x0340, 0x0344, 0x0340, + 0x0348, 0x0354, 0x0444, 0x044e, 0x0450, 0x0446, 0x0248, 0x0254, 0x0340, 0x0344, 0x0342, 0x0344, 0x0340, 0x0348, 0x0454, 0x0444, 0x0454, 0x0450, 0x0450, 0x0254, 0x0254, 0x0340, 0x0348, 0x0340, + 0x0350, 0x0340, 0x0448, 0x0354, 0x0444, 0x0454, 0x0450, 0x0446, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0128, 0x0180, 0x0160, 0x0170, 0x0180, 0x0360, 0x0480, 0x0480, 0x0208, 0x0210, 0x0220, + 0x0230, 0x0240, 0x0250, 0x0280, 0x0260, 0x0270, 0x0280, 0x0360, 0x0480, 0x0480, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0280, 0x0270, 0x0280, 0x0380, 0x0480, 0x0480, 0x0210, + 0x0220, 0x0240, 0x0260, 0x0280, 0x0278, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0214, 0x0228, 0x0250, 0x0278, 0x0280, 0x0278, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0218, 0x0230, 0x0260, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0220, 0x0240, 0x0280, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0220, 0x0248, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0230, 0x0256, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0230, 0x0266, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x023c, 0x0270, 0x0280, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x023e, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0248, 0x0280, 0x0280, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0204, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0360, 0x0474, 0x0480, 0x0208, 0x0210, + 0x0220, 0x0230, 0x0140, 0x0150, 0x0280, 0x0280, 0x0270, 0x0280, 0x0360, 0x0480, 0x0480, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0280, 0x0280, 0x0270, 0x0280, 0x0380, 0x0480, 0x0380, + 0x0210, 0x0220, 0x0240, 0x0260, 0x0280, 0x0278, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0214, 0x0228, 0x0250, 0x0278, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0218, 0x0230, 0x0260, 0x0274, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0220, 0x0240, 0x0280, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0224, 0x0248, 0x0280, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0228, 0x0248, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0234, 0x0268, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0238, 0x0270, 0x0280, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x023e, 0x0268, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0248, 0x027e, 0x0280, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0360, 0x0474, 0x0480, 0x0208, + 0x0210, 0x0120, 0x0230, 0x0140, 0x0150, 0x0180, 0x0180, 0x0270, 0x0280, 0x0360, 0x0480, 0x0480, 0x0210, 0x0220, 0x0140, 0x0148, 0x0160, 0x0178, 0x0280, 0x0180, 0x0370, 0x0380, 0x0380, 0x0480, + 0x0480, 0x0210, 0x0220, 0x0240, 0x0260, 0x0280, 0x0270, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0480, 0x0380, 0x0218, 0x0230, 0x0260, 0x0270, 0x0280, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, + 0x0380, 0x037e, 0x037e, 0x0220, 0x0240, 0x0280, 0x0274, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0220, 0x0240, 0x0280, 0x0280, 0x027e, 0x0380, 0x0380, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0228, 0x0248, 0x027c, 0x0280, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0380, 0x0236, 0x0258, 0x0280, 0x0380, 0x037e, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x023e, 0x027c, 0x0280, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x0380, 0x0380, 0x0240, 0x0280, 0x0280, 0x0380, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0240, 0x027e, 0x037e, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x0380, 0x0380, 0x0254, 0x0280, + 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x0380, 0x037e, 0x0380, 0x0380, 0x037e, 0x0380, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0360, 0x0474, 0x0480, + 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0180, 0x0260, 0x0270, 0x0180, 0x0280, 0x0470, 0x0480, 0x0208, 0x0210, 0x0220, 0x0230, 0x0240, 0x0250, 0x0274, 0x0274, 0x0270, 0x0280, 0x0280, + 0x0464, 0x0480, 0x0210, 0x0220, 0x0240, 0x0260, 0x0278, 0x0270, 0x0274, 0x0280, 0x0280, 0x0280, 0x0280, 0x0464, 0x0480, 0x0214, 0x0228, 0x0250, 0x0270, 0x0274, 0x0274, 0x027e, 0x027c, 0x0270, + 0x027e, 0x0460, 0x047e, 0x047e, 0x0218, 0x0230, 0x0258, 0x0274, 0x0280, 0x027c, 0x027a, 0x0280, 0x0270, 0x0280, 0x0460, 0x0462, 0x047e, 0x0220, 0x0238, 0x0274, 0x026e, 0x0280, 0x0276, 0x0280, + 0x0280, 0x027a, 0x027e, 0x0460, 0x0462, 0x047e, 0x0220, 0x0240, 0x0270, 0x0274, 0x0280, 0x0276, 0x0280, 0x0280, 0x0370, 0x0280, 0x0460, 0x0470, 0x047e, 0x0220, 0x0248, 0x0270, 0x0276, 0x0280, + 0x0276, 0x0280, 0x0360, 0x0370, 0x037e, 0x0460, 0x0462, 0x047e, 0x0234, 0x0260, 0x0270, 0x0276, 0x0280, 0x0278, 0x0280, 0x0360, 0x0370, 0x0380, 0x0460, 0x0470, 0x047c, 0x0238, 0x0260, 0x0270, + 0x0278, 0x0280, 0x0278, 0x0280, 0x0360, 0x0370, 0x0380, 0x0460, 0x0470, 0x047c, 0x0238, 0x0260, 0x0280, 0x0280, 0x0280, 0x0278, 0x0280, 0x0360, 0x0370, 0x0380, 0x0460, 0x0464, 0x047e, 0x0248, + 0x0270, 0x0280, 0x0280, 0x0280, 0x0278, 0x0280, 0x0360, 0x0370, 0x0460, 0x0460, 0x0378, 0x037c, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0180, 0x0470, + 0x047e, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0180, 0x0170, 0x0180, 0x0458, 0x0464, 0x047e, 0x0210, 0x0220, 0x0140, 0x0160, 0x0180, 0x0178, 0x0274, 0x0274, 0x0270, 0x0180, + 0x0280, 0x0468, 0x0480, 0x0210, 0x0220, 0x0240, 0x0258, 0x0274, 0x0270, 0x027c, 0x0280, 0x0270, 0x0280, 0x0458, 0x0464, 0x047e, 0x0218, 0x0230, 0x0250, 0x0270, 0x027c, 0x0276, 0x027c, 0x027a, + 0x0270, 0x0280, 0x0462, 0x0464, 0x0480, 0x0220, 0x0238, 0x0260, 0x0276, 0x0280, 0x0278, 0x027e, 0x0280, 0x0270, 0x0280, 0x0460, 0x0468, 0x047e, 0x0220, 0x0240, 0x0274, 0x026e, 0x027e, 0x0280, + 0x0280, 0x0280, 0x0270, 0x0280, 0x0460, 0x0466, 0x047e, 0x0228, 0x0248, 0x026e, 0x0278, 0x027e, 0x0274, 0x0280, 0x027a, 0x0270, 0x0280, 0x0460, 0x0464, 0x047e, 0x0228, 0x0256, 0x0272, 0x027a, + 0x027e, 0x0278, 0x0280, 0x0460, 0x0270, 0x0380, 0x0460, 0x0464, 0x047c, 0x0238, 0x0260, 0x0272, 0x0278, 0x0280, 0x0278, 0x0280, 0x0460, 0x0370, 0x0380, 0x0460, 0x0472, 0x047e, 0x0238, 0x026a, + 0x0272, 0x0278, 0x0280, 0x0278, 0x0280, 0x0460, 0x0370, 0x0380, 0x0460, 0x0472, 0x047e, 0x0240, 0x026c, 0x027e, 0x0278, 0x0280, 0x0278, 0x0280, 0x0460, 0x0370, 0x0380, 0x0460, 0x0478, 0x047e, + 0x0254, 0x026c, 0x027e, 0x0278, 0x0280, 0x0278, 0x0280, 0x0478, 0x0470, 0x0380, 0x0460, 0x047e, 0x047e, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, 0x0180, + 0x0464, 0x047e, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0274, 0x0170, 0x0180, 0x0180, 0x0450, 0x047e, 0x0210, 0x0220, 0x0140, 0x0148, 0x0180, 0x0178, 0x0180, 0x0180, 0x0170, + 0x0180, 0x0180, 0x0450, 0x0480, 0x0210, 0x0220, 0x0240, 0x0258, 0x017c, 0x0270, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0464, 0x047e, 0x0218, 0x0228, 0x0258, 0x0170, 0x017e, 0x017e, 0x017e, + 0x0180, 0x017e, 0x0180, 0x0180, 0x0468, 0x047e, 0x0220, 0x0230, 0x0260, 0x0176, 0x017c, 0x017e, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0464, 0x047e, 0x0220, 0x023e, 0x025c, 0x0180, 0x017c, + 0x017a, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0464, 0x046e, 0x0220, 0x0248, 0x025a, 0x0180, 0x017e, 0x017a, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0464, 0x0480, 0x0228, 0x0256, 0x017e, + 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0180, 0x0180, 0x0180, 0x0472, 0x0464, 0x0238, 0x0260, 0x025e, 0x0180, 0x0180, 0x0178, 0x0180, 0x0180, 0x0470, 0x0380, 0x0460, 0x0480, 0x047c, 0x023a, + 0x025c, 0x0260, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0470, 0x0380, 0x0460, 0x0480, 0x0464, 0x0240, 0x0260, 0x0260, 0x0180, 0x017c, 0x0178, 0x0180, 0x0180, 0x0370, 0x0380, 0x0460, 0x0478, + 0x0474, 0x0250, 0x0260, 0x0260, 0x0180, 0x0180, 0x0278, 0x0180, 0x0448, 0x0470, 0x0480, 0x0460, 0x0480, 0x047e, 0x0108, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, 0x0180, + 0x0458, 0x0450, 0x0480, 0x0208, 0x0210, 0x0220, 0x0130, 0x0140, 0x0150, 0x0180, 0x0274, 0x0170, 0x0180, 0x0180, 0x0450, 0x0450, 0x0210, 0x0220, 0x0140, 0x0160, 0x0180, 0x0178, 0x0274, 0x0260, + 0x0170, 0x0180, 0x0180, 0x0464, 0x0450, 0x0210, 0x0220, 0x0240, 0x0258, 0x0274, 0x0270, 0x0180, 0x0260, 0x0180, 0x0180, 0x0440, 0x0464, 0x0464, 0x0218, 0x0230, 0x0258, 0x0260, 0x017e, 0x017e, + 0x0180, 0x017e, 0x0180, 0x0180, 0x0462, 0x0464, 0x0464, 0x0220, 0x0238, 0x0260, 0x025e, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x045e, 0x0464, 0x0464, 0x0220, 0x023e, 0x025e, 0x0260, + 0x017e, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x045e, 0x0464, 0x0464, 0x0220, 0x0248, 0x025a, 0x025e, 0x0180, 0x0180, 0x0180, 0x017e, 0x0180, 0x0180, 0x0460, 0x0464, 0x0464, 0x0236, 0x0256, + 0x0260, 0x0260, 0x017e, 0x017e, 0x0180, 0x017e, 0x0180, 0x0180, 0x0460, 0x0450, 0x0464, 0x0240, 0x0258, 0x0260, 0x0260, 0x0180, 0x017e, 0x0180, 0x017e, 0x0470, 0x0180, 0x0460, 0x0464, 0x0464, + 0x023e, 0x025c, 0x0260, 0x0260, 0x0180, 0x0180, 0x0180, 0x0460, 0x0470, 0x0180, 0x0460, 0x0464, 0x0464, 0x0240, 0x0258, 0x0180, 0x0260, 0x0180, 0x0180, 0x0180, 0x0460, 0x0470, 0x0180, 0x0460, + 0x0464, 0x0464, 0x0254, 0x0260, 0x0260, 0x0260, 0x0180, 0x0278, 0x0180, 0x0460, 0x0470, 0x0180, 0x0460, 0x0464, 0x0464, 0x0208, 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0160, 0x0170, + 0x0180, 0x0458, 0x0450, 0x0480, 0x0208, 0x0210, 0x0120, 0x0130, 0x0140, 0x0150, 0x0180, 0x0174, 0x0170, 0x0274, 0x0180, 0x0178, 0x017e, 0x0210, 0x0220, 0x0140, 0x0160, 0x015e, 0x0270, 0x0174, + 0x0160, 0x0170, 0x0180, 0x0180, 0x0178, 0x017e, 0x0210, 0x0220, 0x014e, 0x0258, 0x0174, 0x0168, 0x017c, 0x0160, 0x016e, 0x017e, 0x0180, 0x0464, 0x0458, 0x0218, 0x0228, 0x0258, 0x015e, 0x0162, + 0x0170, 0x0174, 0x0160, 0x016e, 0x017c, 0x0160, 0x0464, 0x0480, 0x0220, 0x0238, 0x0158, 0x0160, 0x0162, 0x0162, 0x0174, 0x0160, 0x016e, 0x017e, 0x0160, 0x044e, 0x045c, 0x0220, 0x023e, 0x0160, + 0x0160, 0x0166, 0x0160, 0x017e, 0x0160, 0x016e, 0x0280, 0x0160, 0x0460, 0x0456, 0x0220, 0x0248, 0x0250, 0x0160, 0x015e, 0x0166, 0x017e, 0x0160, 0x016e, 0x017e, 0x0360, 0x0464, 0x0450, 0x0236, + 0x024c, 0x0160, 0x0160, 0x0162, 0x0160, 0x017e, 0x0160, 0x0172, 0x0280, 0x0460, 0x0464, 0x0456, 0x023e, 0x024c, 0x0160, 0x0160, 0x0162, 0x0178, 0x0174, 0x0160, 0x0270, 0x0280, 0x0360, 0x0464, + 0x0476, 0x023e, 0x0248, 0x0160, 0x0160, 0x0160, 0x0160, 0x0280, 0x0160, 0x0370, 0x0360, 0x0360, 0x0464, 0x045e, 0x0242, 0x0250, 0x0160, 0x0160, 0x0160, 0x0160, 0x0280, 0x0160, 0x0270, 0x0360, + 0x0360, 0x0464, 0x0456, 0x0248, 0x0250, 0x0160, 0x0160, 0x0160, 0x0160, 0x0280, 0x0160, 0x0370, 0x0360, 0x0360, 0x0464, 0x047e, 0x0104, 0x0110, 0x0210, 0x0120, 0x0140, 0x0150, 0x0140, 0x0260, + 0x0270, 0x0180, 0x0292, 0x029c, 0x049c, 0x020e, 0x020a, 0x0114, 0x014c, 0x0250, 0x0262, 0x0270, 0x0298, 0x029c, 0x029a, 0x02a8, 0x04a8, 0x04aa, 0x020c, 0x0118, 0x0130, 0x0230, 0x0240, 0x0178, + 0x0280, 0x02aa, 0x0270, 0x02a6, 0x03a6, 0x04aa, 0x04a8, 0x0110, 0x0218, 0x022e, 0x025e, 0x0260, 0x0280, 0x029e, 0x02a2, 0x02a4, 0x02aa, 0x03a0, 0x04aa, 0x04aa, 0x0210, 0x021e, 0x024c, 0x0272, + 0x027c, 0x02a0, 0x02a0, 0x02a8, 0x03a8, 0x02a6, 0x03a8, 0x04aa, 0x04aa, 0x0118, 0x0244, 0x0256, 0x01a2, 0x0288, 0x02a0, 0x02a4, 0x02aa, 0x03a8, 0x03a8, 0x03aa, 0x04aa, 0x04aa, 0x0220, 0x0238, + 0x0280, 0x02a0, 0x0280, 0x02a0, 0x02a8, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x04aa, 0x03aa, 0x021e, 0x0230, 0x028a, 0x02a8, 0x02a2, 0x02a0, 0x02aa, 0x03a8, 0x03a8, 0x03aa, 0x03a8, 0x04aa, 0x04aa, + 0x0220, 0x0240, 0x0290, 0x029c, 0x029c, 0x02a8, 0x03aa, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x04aa, 0x03aa, 0x0230, 0x0258, 0x029c, 0x02a8, 0x02aa, 0x03a4, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, + 0x03aa, 0x03aa, 0x0230, 0x0268, 0x02a2, 0x02a8, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0238, 0x026c, 0x029e, 0x02aa, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, + 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0240, 0x0284, 0x02aa, 0x02aa, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0108, 0x0110, 0x0110, 0x0130, 0x0220, 0x0128, 0x0240, + 0x0160, 0x0170, 0x0180, 0x0260, 0x0464, 0x049c, 0x0106, 0x0112, 0x011a, 0x0220, 0x0152, 0x025c, 0x0260, 0x02a0, 0x0284, 0x029e, 0x02a8, 0x0496, 0x04aa, 0x020c, 0x0118, 0x0220, 0x0230, 0x0240, + 0x0178, 0x0280, 0x02a0, 0x0198, 0x02a0, 0x02a8, 0x04aa, 0x04aa, 0x0210, 0x021c, 0x0240, 0x024e, 0x0276, 0x01a0, 0x02a8, 0x029c, 0x029c, 0x029c, 0x02a8, 0x04a8, 0x04aa, 0x0118, 0x0224, 0x0260, + 0x028c, 0x029e, 0x02a0, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x04aa, 0x04aa, 0x04aa, 0x0218, 0x0138, 0x0260, 0x0260, 0x02a0, 0x0298, 0x029e, 0x02a8, 0x02aa, 0x02a4, 0x04a8, 0x04aa, 0x04aa, 0x0220, + 0x0240, 0x026c, 0x028c, 0x0298, 0x02a0, 0x02aa, 0x02aa, 0x02aa, 0x02a6, 0x04aa, 0x04aa, 0x04aa, 0x0222, 0x0258, 0x0270, 0x02a8, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x04aa, 0x04a8, 0x04aa, 0x04aa, + 0x04aa, 0x0224, 0x0248, 0x0298, 0x02a8, 0x02aa, 0x02a6, 0x02aa, 0x02aa, 0x04a8, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x0230, 0x025e, 0x0298, 0x02a8, 0x02a8, 0x02a4, 0x02aa, 0x04aa, 0x04aa, 0x04aa, + 0x04aa, 0x04aa, 0x04aa, 0x0238, 0x0268, 0x02a4, 0x02a8, 0x02aa, 0x02aa, 0x02aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x0238, 0x0268, 0x02a4, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x04aa, + 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x0246, 0x0290, 0x02aa, 0x02aa, 0x02aa, 0x02aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x04aa, 0x0204, 0x0208, 0x0110, 0x0118, 0x0220, 0x0128, + 0x0140, 0x02a2, 0x0270, 0x0280, 0x0290, 0x049c, 0x04aa, 0x0210, 0x0208, 0x0218, 0x022a, 0x0160, 0x023e, 0x0298, 0x028a, 0x02a0, 0x0296, 0x02a8, 0x029c, 0x04aa, 0x0208, 0x0218, 0x0220, 0x0230, + 0x0240, 0x0250, 0x0280, 0x0284, 0x0278, 0x019c, 0x0390, 0x04aa, 0x04aa, 0x020e, 0x0226, 0x024c, 0x0260, 0x0280, 0x0278, 0x02a6, 0x02aa, 0x02a8, 0x02a6, 0x029e, 0x04aa, 0x04aa, 0x021a, 0x0230, + 0x024a, 0x0278, 0x0278, 0x029e, 0x02aa, 0x02a8, 0x02a8, 0x02a8, 0x03aa, 0x04aa, 0x04a8, 0x0218, 0x0138, 0x025a, 0x0270, 0x02a0, 0x0298, 0x02a4, 0x02a8, 0x02aa, 0x03aa, 0x03aa, 0x04aa, 0x04aa, + 0x0220, 0x0240, 0x026c, 0x02a2, 0x029c, 0x029e, 0x02aa, 0x03aa, 0x03a8, 0x03aa, 0x03aa, 0x04aa, 0x04aa, 0x0228, 0x0248, 0x0280, 0x02a8, 0x02a6, 0x0292, 0x02a8, 0x03a8, 0x03aa, 0x03aa, 0x03a8, + 0x04aa, 0x03a8, 0x0228, 0x0250, 0x029c, 0x02a8, 0x02aa, 0x02a8, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x04aa, 0x03aa, 0x022c, 0x0258, 0x02a2, 0x02a8, 0x02aa, 0x02aa, 0x03aa, 0x03aa, 0x03a8, + 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0230, 0x0260, 0x02aa, 0x02a8, 0x02aa, 0x02aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0230, 0x0268, 0x02a8, 0x02a8, 0x02aa, 0x02aa, 0x03aa, + 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0246, 0x0290, 0x029e, 0x02a8, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x03aa, 0x0204, 0x0208, 0x0210, 0x0118, 0x0140, + 0x0228, 0x0240, 0x0160, 0x0170, 0x0280, 0x01a4, 0x0298, 0x049c, 0x0204, 0x0210, 0x0120, 0x0142, 0x0240, 0x0270, 0x02a0, 0x0192, 0x029e, 0x01aa, 0x02a8, 0x04a6, 0x04aa, 0x020c, 0x0218, 0x0220, + 0x0230, 0x0240, 0x0250, 0x0280, 0x0298, 0x0270, 0x0288, 0x02a8, 0x04a8, 0x04aa, 0x020c, 0x0218, 0x024a, 0x0246, 0x0260, 0x0290, 0x029a, 0x02a8, 0x02a8, 0x02a8, 0x02a0, 0x0496, 0x04a8, 0x011a, + 0x0236, 0x023e, 0x0270, 0x01a0, 0x0298, 0x02aa, 0x02a4, 0x02a8, 0x02a8, 0x0490, 0x04a6, 0x04a2, 0x0218, 0x0242, 0x0280, 0x0288, 0x0298, 0x02a0, 0x02a4, 0x02a8, 0x02a8, 0x02aa, 0x04a8, 0x04a6, + 0x04a8, 0x0220, 0x0240, 0x0278, 0x0288, 0x02a0, 0x02a2, 0x02a4, 0x02a8, 0x02a8, 0x04a0, 0x04a8, 0x04a6, 0x04a2, 0x0222, 0x0248, 0x0268, 0x02a8, 0x02a8, 0x0298, 0x02a4, 0x02a8, 0x04a8, 0x04a0, + 0x04a2, 0x04a2, 0x04a2, 0x0226, 0x0240, 0x0270, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x04a8, 0x04a8, 0x04a0, 0x04a2, 0x04a8, 0x04aa, 0x0230, 0x0264, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x04a8, + 0x04a8, 0x04a0, 0x04a2, 0x04a4, 0x04a0, 0x0230, 0x0268, 0x02a0, 0x02a8, 0x02a8, 0x02a0, 0x03aa, 0x04a8, 0x04a8, 0x04a0, 0x04a8, 0x04a2, 0x04a8, 0x0234, 0x0264, 0x0290, 0x02aa, 0x02aa, 0x02aa, + 0x03aa, 0x04a8, 0x04a8, 0x04a0, 0x04a8, 0x04a6, 0x04aa, 0x0238, 0x0278, 0x02a2, 0x02a8, 0x02aa, 0x03aa, 0x04a0, 0x04a8, 0x04a8, 0x04a6, 0x04a8, 0x04a6, 0x04aa, 0x0208, 0x0110, 0x0120, 0x0130, + 0x0140, 0x0150, 0x0298, 0x0292, 0x0186, 0x026c, 0x01aa, 0x0478, 0x04aa, 0x010c, 0x020c, 0x022a, 0x0230, 0x0240, 0x0172, 0x0298, 0x029c, 0x01a8, 0x01a4, 0x0290, 0x0498, 0x04aa, 0x020c, 0x0218, + 0x0220, 0x0230, 0x0240, 0x0250, 0x0298, 0x0298, 0x029c, 0x02a0, 0x02a8, 0x04a8, 0x0498, 0x0210, 0x0220, 0x0240, 0x0260, 0x0178, 0x0270, 0x0288, 0x02a2, 0x02aa, 0x02a0, 0x039e, 0x04a8, 0x0498, + 0x0218, 0x0228, 0x024c, 0x0190, 0x02a0, 0x02a0, 0x02a8, 0x02aa, 0x03a4, 0x03a0, 0x0490, 0x04a8, 0x04a6, 0x021c, 0x0230, 0x027c, 0x0188, 0x02a0, 0x02a0, 0x02aa, 0x02a8, 0x03a8, 0x03a2, 0x0490, + 0x04a8, 0x04a0, 0x021a, 0x0230, 0x0280, 0x02aa, 0x029a, 0x02a0, 0x02a8, 0x03aa, 0x03a8, 0x03a8, 0x0490, 0x04a8, 0x04a0, 0x0220, 0x024c, 0x0270, 0x02a8, 0x029c, 0x02a0, 0x03a0, 0x039e, 0x03a6, + 0x04a0, 0x0490, 0x04a8, 0x049a, 0x0228, 0x0258, 0x0298, 0x02a8, 0x02a2, 0x02a6, 0x03a0, 0x03a8, 0x03a8, 0x04a0, 0x0490, 0x04a0, 0x04a4, 0x022a, 0x0270, 0x02a0, 0x02a8, 0x02a8, 0x03a0, 0x03a0, + 0x03a8, 0x03a8, 0x04a0, 0x0490, 0x04a8, 0x04aa, 0x0238, 0x0270, 0x02a8, 0x02a8, 0x02a8, 0x03a0, 0x03a0, 0x04a8, 0x03a8, 0x04a0, 0x0490, 0x04a0, 0x04a0, 0x023e, 0x0280, 0x02a8, 0x02a8, 0x02aa, + 0x03a0, 0x03a0, 0x04a8, 0x03a8, 0x04a0, 0x0490, 0x0496, 0x04a0, 0x024a, 0x028c, 0x02a8, 0x02a8, 0x03a0, 0x03a0, 0x03a0, 0x04a8, 0x03a8, 0x04a0, 0x0490, 0x04a4, 0x04a4, 0x0108, 0x0208, 0x0120, + 0x0130, 0x0140, 0x013c, 0x0140, 0x0198, 0x0190, 0x0176, 0x0298, 0x0488, 0x04a0, 0x0108, 0x0212, 0x012a, 0x022a, 0x0246, 0x0254, 0x0190, 0x019e, 0x01a8, 0x01a6, 0x0288, 0x02a0, 0x04a6, 0x0110, + 0x0120, 0x0230, 0x0230, 0x0240, 0x0178, 0x0180, 0x0280, 0x029c, 0x029c, 0x02a8, 0x04a0, 0x04a4, 0x0210, 0x021e, 0x0240, 0x0260, 0x017c, 0x027a, 0x02a8, 0x02aa, 0x0298, 0x02aa, 0x02aa, 0x04a0, + 0x04a4, 0x0218, 0x0230, 0x0150, 0x0288, 0x0294, 0x02a0, 0x02aa, 0x02a4, 0x02a8, 0x02aa, 0x0490, 0x04a0, 0x04a4, 0x0218, 0x0238, 0x0262, 0x019c, 0x02a0, 0x02a0, 0x02a0, 0x0390, 0x02a8, 0x0380, + 0x0490, 0x04a0, 0x04a4, 0x0226, 0x0240, 0x026c, 0x029c, 0x02a0, 0x029e, 0x02aa, 0x0390, 0x02a8, 0x0380, 0x0490, 0x0478, 0x03a6, 0x0220, 0x0238, 0x0278, 0x02a8, 0x02a0, 0x02a0, 0x02a0, 0x0390, + 0x03a8, 0x0380, 0x0490, 0x04a0, 0x0496, 0x022c, 0x0248, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x02aa, 0x0390, 0x03a8, 0x0480, 0x0490, 0x04a0, 0x04a4, 0x0230, 0x026e, 0x02a8, 0x02a8, 0x02a0, 0x02a0, + 0x0380, 0x0390, 0x03a8, 0x0480, 0x0490, 0x04a0, 0x049e, 0x0238, 0x026a, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0480, 0x0490, 0x04a0, 0x04a4, 0x0238, 0x0280, 0x0298, 0x02a8, + 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0480, 0x0490, 0x04a0, 0x04a4, 0x0248, 0x0290, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x04a8, 0x0480, 0x0490, 0x04a0, 0x04aa, 0x0108, 0x0108, + 0x0280, 0x0132, 0x025c, 0x0230, 0x0174, 0x01a4, 0x01a6, 0x019e, 0x019e, 0x0498, 0x0480, 0x020c, 0x0210, 0x0220, 0x0236, 0x0140, 0x0150, 0x0280, 0x01aa, 0x0270, 0x01aa, 0x02a8, 0x0490, 0x0494, + 0x020e, 0x022a, 0x012e, 0x0260, 0x017e, 0x017c, 0x019c, 0x029c, 0x02a8, 0x02a2, 0x02a8, 0x0494, 0x0498, 0x0210, 0x0220, 0x0168, 0x0178, 0x0180, 0x02a0, 0x029e, 0x02aa, 0x03a8, 0x02aa, 0x02aa, + 0x0490, 0x04aa, 0x0218, 0x0230, 0x0250, 0x0270, 0x027e, 0x02a0, 0x02aa, 0x0390, 0x02a4, 0x0380, 0x0490, 0x04a8, 0x04a8, 0x0226, 0x0240, 0x0270, 0x019e, 0x02a0, 0x02a2, 0x02a6, 0x0390, 0x03a8, + 0x0380, 0x0490, 0x04a8, 0x04a2, 0x0220, 0x023e, 0x0280, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0380, 0x0490, 0x04a6, 0x04a6, 0x0228, 0x0248, 0x028e, 0x02a8, 0x02a0, 0x02a0, 0x03a0, + 0x0390, 0x03a8, 0x0480, 0x0490, 0x04a6, 0x03a6, 0x0230, 0x0250, 0x029c, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0390, 0x03a8, 0x0480, 0x0490, 0x049c, 0x04aa, 0x0238, 0x0278, 0x02a0, 0x02a8, 0x02a6, + 0x02a0, 0x03a0, 0x0390, 0x04a8, 0x0480, 0x0490, 0x049c, 0x04a6, 0x0240, 0x0298, 0x02a0, 0x02a8, 0x02a0, 0x02a0, 0x0380, 0x0478, 0x04a8, 0x0480, 0x0490, 0x049e, 0x03a6, 0x0240, 0x0280, 0x02a8, + 0x02a8, 0x02aa, 0x03a0, 0x03a0, 0x0478, 0x04a8, 0x0480, 0x0490, 0x049e, 0x04a0, 0x0258, 0x0290, 0x02a0, 0x02a8, 0x0390, 0x03a0, 0x0380, 0x0478, 0x04a8, 0x0480, 0x0490, 0x04a0, 0x04aa, 0x0204, + 0x0110, 0x0120, 0x0130, 0x0140, 0x0150, 0x0140, 0x0160, 0x0170, 0x0180, 0x019c, 0x01a0, 0x049c, 0x0110, 0x0208, 0x011e, 0x0230, 0x0140, 0x0182, 0x017e, 0x0190, 0x019c, 0x0180, 0x01a4, 0x02a8, + 0x02a8, 0x0208, 0x0118, 0x0230, 0x0248, 0x0240, 0x0178, 0x0190, 0x0298, 0x01aa, 0x01a4, 0x01a8, 0x01a8, 0x04a6, 0x020c, 0x0220, 0x0240, 0x0158, 0x01a2, 0x01a0, 0x029c, 0x0288, 0x01a8, 0x01a8, + 0x01a8, 0x01a6, 0x048c, 0x0218, 0x0230, 0x0248, 0x0268, 0x01a0, 0x01a0, 0x01a0, 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x0296, 0x049c, 0x0218, 0x0238, 0x017e, 0x019c, 0x01a0, 0x01a0, 0x01a0, 0x01a8, + 0x01a8, 0x01a4, 0x01a8, 0x0496, 0x03aa, 0x0220, 0x0230, 0x0268, 0x0278, 0x027e, 0x01a0, 0x01a0, 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x049c, 0x049c, 0x0224, 0x0238, 0x0270, 0x01a8, 0x01a0, 0x01a0, + 0x01a0, 0x01a8, 0x01a8, 0x01a8, 0x01a8, 0x04a8, 0x049c, 0x0220, 0x0250, 0x0280, 0x01a4, 0x01a0, 0x01a0, 0x0280, 0x0290, 0x01a8, 0x01a8, 0x01a8, 0x0496, 0x04a4, 0x0230, 0x0262, 0x0280, 0x01a8, + 0x01a0, 0x01a0, 0x01a0, 0x01a2, 0x01a8, 0x01a8, 0x01a8, 0x0496, 0x049c, 0x0238, 0x0270, 0x0280, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x0290, 0x01a8, 0x01a8, 0x01a8, 0x0496, 0x049c, 0x0248, 0x0264, + 0x01a0, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x0296, 0x01a8, 0x01a8, 0x01a8, 0x0496, 0x049c, 0x0248, 0x0278, 0x01a0, 0x01a8, 0x01a0, 0x01a0, 0x01a0, 0x029c, 0x01a8, 0x01a8, 0x02a8, 0x0496, 0x049c, +}; + +} // namespace exl3_packed diff --git a/gptqmodel_ext/exllamav3/quant/hadamard.cu b/gptqmodel_ext/exllamav3/quant/hadamard.cu new file mode 100644 index 000000000..fd65955b1 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/hadamard.cu @@ -0,0 +1,212 @@ +#include +#include "quantize.cuh" +#include +#include +#include "../util.h" +#include "../util.cuh" +#include "hadamard_inner.cuh" + +__global__ __launch_bounds__(32) +void had_hf_r_128_kernel +( + const half* __restrict__ input_ptr, + half* __restrict__ output_ptr, + const half* __restrict__ pre_scale, + const half* __restrict__ post_scale, + const float r_scale +) +{ + input_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_hf_r_128_inner(input_ptr, output_ptr, pre_scale, post_scale, r_scale); +} + +__global__ __launch_bounds__(32) +void had_ff_r_128_kernel +( + const float* __restrict__ input_ptr, + float* __restrict__ output_ptr, + const half* __restrict__ pre_scale, + const half* __restrict__ post_scale, + const float r_scale +) +{ + input_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_ff_r_128_inner(input_ptr, output_ptr, pre_scale, post_scale, r_scale); +} + +__global__ __launch_bounds__(32) +void had_hf_r_128_dual_kernel +( + const half* __restrict__ input1_ptr, + half* __restrict__ output1_ptr, + const half* __restrict__ pre1_scale, + const half* __restrict__ post1_scale, + const half* __restrict__ input2_ptr, + half* __restrict__ output2_ptr, + const half* __restrict__ pre2_scale, + const half* __restrict__ post2_scale, + const float r_scale +) +{ + input1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_hf_r_128_inner(input1_ptr, output1_ptr, pre1_scale, post1_scale, r_scale); + + input2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_hf_r_128_inner(input2_ptr, output2_ptr, pre2_scale, post2_scale, r_scale); +} + +__global__ __launch_bounds__(32) +void had_ff_r_128_dual_kernel +( + const float* __restrict__ input1_ptr, + float* __restrict__ output1_ptr, + const half* __restrict__ pre1_scale, + const half* __restrict__ post1_scale, + const float* __restrict__ input2_ptr, + float* __restrict__ output2_ptr, + const half* __restrict__ pre2_scale, + const half* __restrict__ post2_scale, + const float r_scale +) +{ + input1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output1_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_ff_r_128_inner(input1_ptr, output1_ptr, pre1_scale, post1_scale, r_scale); + + input2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + output2_ptr += gridDim.y * 128 * blockIdx.x + blockIdx.y * 128; + had_ff_r_128_inner(input2_ptr, output2_ptr, pre2_scale, post2_scale, r_scale); +} + +/* +Compute y = (x.view(-1, 128) @ had_128).view(x.shape) +Works inplace if y == x +x and y must be same dtype, either float16 or float32 +*/ +void had_r_128 +( + const at::Tensor& input, + const at::Tensor& output, + const c10::optional& pre_scale, + const c10::optional& post_scale, + const float scale +) +{ + const at::cuda::OptionalCUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_SHAPES_FULL(input, output); + TORCH_CHECK_DIM(input, 2); + TORCH_CHECK_DIV(input, 1, 128); + int rows = input.size(0); + int cols = input.size(1); + + int blocks = cols / 128; + float r_scale = scale * 0.088388347648f; // scale / sqrt(128) + + dim3 blockDim(32); + dim3 gridDim(rows, blocks); + + if (input.dtype() == at::kHalf) + { + TORCH_CHECK_DTYPE(output, kHalf); + had_hf_r_128_kernel<<>> + ( + (const half*) input.data_ptr(), + (half*) output.data_ptr(), + (const half*) OPTPTR(pre_scale), + (const half*) OPTPTR(post_scale), + r_scale + ); + cuda_check(cudaPeekAtLastError()); + } + + else if (input.dtype() == at::kFloat) + { + TORCH_CHECK_DTYPE(output, kFloat); + had_ff_r_128_kernel<<>> + ( + (const float*) input.data_ptr(), + (float*) output.data_ptr(), + (const half*) OPTPTR(pre_scale), + (const half*) OPTPTR(post_scale), + r_scale + ); + cuda_check(cudaPeekAtLastError()); + } + + else TORCH_CHECK(false, "unsupported datatype"); +} + +void had_r_128_dual +( + const at::Tensor& input1, + const at::Tensor& output1, + const c10::optional& pre_scale1, + const c10::optional& post_scale1, + const at::Tensor& input2, + const at::Tensor& output2, + const c10::optional& pre_scale2, + const c10::optional& post_scale2, + const float scale +) +{ + const at::cuda::OptionalCUDAGuard device_guard(input1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_SHAPES_FULL(input1, output1); + TORCH_CHECK_SHAPES_FULL(input1, input2); + TORCH_CHECK_SHAPES_FULL(output1, output2); + TORCH_CHECK_DIM(input1, 2); + TORCH_CHECK_DIV(input1, 1, 128); + int rows = input1.size(0); + int cols = input1.size(1); + + int blocks = cols / 128; + float r_scale = scale * 0.088388347648f; // scale / sqrt(128) + + dim3 blockDim(32); + dim3 gridDim(rows, blocks); + + if (input1.dtype() == at::kHalf) + { + TORCH_CHECK_DTYPE(output1, kHalf); + had_hf_r_128_dual_kernel<<>> + ( + (const half*) input1.data_ptr(), + (half*) output1.data_ptr(), + (const half*) OPTPTR(pre_scale1), + (const half*) OPTPTR(post_scale1), + (const half*) input2.data_ptr(), + (half*) output2.data_ptr(), + (const half*) OPTPTR(pre_scale2), + (const half*) OPTPTR(post_scale2), + r_scale + ); + cuda_check(cudaPeekAtLastError()); + } + + else if (input1.dtype() == at::kFloat) + { + TORCH_CHECK_DTYPE(output1, kFloat); + had_ff_r_128_dual_kernel<<>> + ( + (const float*) input1.data_ptr(), + (float*) output1.data_ptr(), + (const half*) OPTPTR(pre_scale1), + (const half*) OPTPTR(post_scale1), + (const float*) input2.data_ptr(), + (float*) output2.data_ptr(), + (const half*) OPTPTR(pre_scale2), + (const half*) OPTPTR(post_scale2), + r_scale + ); + cuda_check(cudaPeekAtLastError()); + } + + else TORCH_CHECK(false, "unsupported datatype"); +} diff --git a/gptqmodel_ext/exllamav3/quant/hadamard.cuh b/gptqmodel_ext/exllamav3/quant/hadamard.cuh new file mode 100644 index 000000000..5c25ff38a --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/hadamard.cuh @@ -0,0 +1,25 @@ +#pragma once + +#include + +void had_r_128 +( + const at::Tensor& input, + const at::Tensor& output, + const c10::optional& pre_scale, + const c10::optional& post_scale, + const float scale +); + +void had_r_128_dual +( + const at::Tensor& input1, + const at::Tensor& output1, + const c10::optional& pre_scale1, + const c10::optional& post_scale1, + const at::Tensor& input2, + const at::Tensor& output2, + const c10::optional& pre_scale2, + const c10::optional& post_scale2, + const float scale +); diff --git a/gptqmodel_ext/exllamav3/quant/hadamard_inner.cuh b/gptqmodel_ext/exllamav3/quant/hadamard_inner.cuh new file mode 100644 index 000000000..79132806d --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/hadamard_inner.cuh @@ -0,0 +1,205 @@ +#pragma once + +// Hadamard transform 128-element vector across one warp, with optional pre and post scales + +__device__ inline half hreduce(half2 x) +{ + return __hadd(__low2half(x), __high2half(x)); +} + +__device__ inline void shuffle_had_f4x32(float& h0, float& h1, float& h2, float& h3, const int lane_id) +{ + #pragma unroll + for (int i = 1; i < 32; i <<= 1) + { + uint32_t i0 = __float_as_uint(h0); + uint32_t i1 = __float_as_uint(h1); + uint32_t i2 = __float_as_uint(h2); + uint32_t i3 = __float_as_uint(h3); + uint64_t h01 = (uint64_t) i0 | (((uint64_t) i1) << 32); + uint64_t h23 = (uint64_t) i2 | (((uint64_t) i3) << 32); + uint64_t ph01 = __shfl_xor_sync(0xffffffff, h01, i); + uint64_t ph23 = __shfl_xor_sync(0xffffffff, h23, i); + float ph0 = __uint_as_float((uint32_t) (ph01 & 0xffffffff)); + float ph1 = __uint_as_float((uint32_t) (ph01 >> 32)); + float ph2 = __uint_as_float((uint32_t) (ph23 & 0xffffffff)); + float ph3 = __uint_as_float((uint32_t) (ph23 >> 32)); + int32_t sfm = -static_cast(lane_id & i) >> 31; + i0 ^= sfm & 0x80000000; + i1 ^= sfm & 0x80000000; + i2 ^= sfm & 0x80000000; + i3 ^= sfm & 0x80000000; + h0 = __uint_as_float(i0) + ph0; + h1 = __uint_as_float(i1) + ph1; + h2 = __uint_as_float(i2) + ph2; + h3 = __uint_as_float(i3) + ph3; + } +} + +__device__ inline void shuffle_had_f2x32(float& v, float& w, const int lane_id) +{ + #pragma unroll + for (int i = 1; i < 32; i <<= 1) + { + uint64_t vw = ((uint64_t) __float_as_uint(v)) | (((uint64_t) __float_as_uint(w)) << 32); + uint64_t pvw = __shfl_xor_sync(0xffffffff, vw, i); + float pv = __uint_as_float((uint32_t) (pvw & 0xffffffff)); + float pw = __uint_as_float((uint32_t) (pvw >> 32)); + uint32_t vi = __float_as_uint(v); + uint32_t wi = __float_as_uint(w); + int32_t sfm = -static_cast(lane_id & i) >> 31; + vi ^= (sfm & 0x80000000); + wi ^= (sfm & 0x80000000); + v = __uint_as_float(vi) + pv; + w = __uint_as_float(wi) + pw; + } +} + +__device__ inline float shuffle_had_fx32(float v, const int lane_id) +{ + for (int i = 1; i < 32; i <<= 1) + { + float pv = __shfl_xor_sync(0xffffffff, v, i); + uint32_t* vi = reinterpret_cast(&v); + int32_t sfm = -static_cast(lane_id & i) >> 31; + *vi ^= (sfm & 0x80000000); + v = v + pv; + } + return v; +} + +__device__ inline half2 shuffle_had_h2x32(half2 v, int lane_id) +{ + for (int i = 1; i < 32; i <<= 1) + { + half2 pv = __shfl_xor_sync(0xffffffff, v, i); + uint32_t* vi = reinterpret_cast(&v); + int32_t sfm = -static_cast(lane_id & i) >> 31; + *vi ^= (sfm & 0x80008000); + v = __hadd2(v, pv); + } + return v; +} + +// Half vector, half scales + +inline __device__ +void had_hf_r_128_inner +( + const half* __restrict__ input_ptr, + half* __restrict__ output_ptr, + const half* __restrict__ pre_scale, + const half* __restrict__ post_scale, + const float r_scale +) +{ + int t = threadIdx.x & 31; + + // Load + half4 v = ((half4*) input_ptr)[t]; + + // Pre scale + if (pre_scale) + { + int i = blockIdx.y * 32 + t; + half4 scales = ((half4*) pre_scale)[i]; + v.x = __hmul2(v.x, scales.x); + v.y = __hmul2(v.y, scales.y); + } + + // 4 element had + float v0 = __half2float(__low2half(v.x)); + float v1 = __half2float(__high2half(v.x)); + float v2 = __half2float(__low2half(v.y)); + float v3 = __half2float(__high2half(v.y)); + float s0 = v0 + v1; + float d0 = v0 - v1; + float s1 = v2 + v3; + float d1 = v2 - v3; + float h0 = s0 + s1; + float h1 = d0 + d1; + float h2 = s0 - s1; + float h3 = d0 - d1; + + // 32 element had, warp shuffle + shuffle_had_f4x32(h0, h1, h2, h3, t); + v.x = __floats2half2_rn(h0 * r_scale, h1 * r_scale); + v.y = __floats2half2_rn(h2 * r_scale, h3 * r_scale); + + // Post scale + if (post_scale) + { + int i = blockIdx.y * 32 + t; + half4 scales = ((half4*) post_scale)[i]; + v.x = __hmul2(v.x, scales.x); + v.y = __hmul2(v.y, scales.y); + } + + // Store + ((half4*) output_ptr)[t] = v; +} + +// Float vector, half scales + +inline __device__ +void had_ff_r_128_inner +( + const float* __restrict__ input_ptr, + float* __restrict__ output_ptr, + const half* __restrict__ pre_scale, + const half* __restrict__ post_scale, + const float r_scale +) +{ + int t = threadIdx.x & 31; + + // Load + float4 v = ((float4*) input_ptr)[t]; + + // Pre scale + if (pre_scale) + { + int i = blockIdx.y * 32 + t; + half4 scales = ((half4*) pre_scale)[i]; + v.x *= __low2float(scales.x); + v.y *= __high2float(scales.x); + v.z *= __low2float(scales.y); + v.w *= __high2float(scales.y); + } + + // 4 element had + float v0 = v.x; + float v1 = v.y; + float v2 = v.z; + float v3 = v.w; + float s0 = v0 + v1; + float d0 = v0 - v1; + float s1 = v2 + v3; + float d1 = v2 - v3; + v.x = s0 + s1; + v.y = d0 + d1; + v.z = s0 - s1; + v.w = d0 - d1; + + // 32 element had, warp shuffle + shuffle_had_f2x32(v.x, v.y, t); + shuffle_had_f2x32(v.z, v.w, t); + v.x *= r_scale; + v.y *= r_scale; + v.z *= r_scale; + v.w *= r_scale; + + // Post scale + if (post_scale) + { + int i = blockIdx.y * 32 + t; + half4 scales = ((half4*) post_scale)[i]; + v.x *= __low2float(scales.x); + v.y *= __high2float(scales.x); + v.z *= __low2float(scales.y); + v.w *= __high2float(scales.y); + } + + // Store + ((float4*) output_ptr)[t] = v; +} \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/pack.cu b/gptqmodel_ext/exllamav3/quant/pack.cu new file mode 100644 index 000000000..b91edf2d6 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/pack.cu @@ -0,0 +1,227 @@ +#include +#include "quantize.cuh" +#include +#include +#include "../util.h" +#include "../util.cuh" +#include "codebook.cuh" + +template +__global__ __launch_bounds__(128) +void pack_trellis_kernel +( + uint16_t* __restrict__ g_packed, + const uint16_t* __restrict__ g_unpacked +) +{ + constexpr int packed_size = 256 * K / 16; + __shared__ uint16_t s_unpacked[256]; + __shared__ uint16_t s_packed[packed_size]; + + int t = threadIdx.x; + g_packed += (gridDim.x * blockIdx.y + blockIdx.x) * packed_size; + g_unpacked += (gridDim.x * blockIdx.y + blockIdx.x) * 256; + + ((uint32_t*) s_unpacked)[t] = ((uint32_t*) g_unpacked)[t]; + __syncthreads(); + + // 16 spans of 16 weights to guarantee alignment for any K + const int spans = 16; + const int len = 256 / spans; + if (t < spans) + { + int i = len * t; + int j = K * t; + int k = 32; + uint32_t buf = 0; + for (int n = 0; n < len; ++n) + { + uint32_t v = (uint32_t) s_unpacked[i]; + v &= ((1 << K) - 1); + k -= K; + buf |= (v << k); + if (k <= 16) + { + s_packed[j] = (uint16_t) (buf >> 16); + buf <<= 16; + k += 16; + j++; + } + i++; + } + } + __syncthreads(); + + if (t < packed_size / 2) + ((uint32_t*) g_packed)[t] = SWAP16(((uint32_t*) s_packed)[t]);; +} + +#define __(i) pack_trellis_kernel +constexpr auto pack_trellis_kernel_instances = std::array +{ + __(1), __(2), __(3), __(4), __(5), __(6), __(7), __(8) +}; +#undef __ + +void pack_trellis +( + at::Tensor packed, + at::Tensor unpacked, + int K +) +{ + const at::cuda::OptionalCUDAGuard device_guard(unpacked.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_SHAPES(packed, 0, unpacked, 0, 1); + TORCH_CHECK_SHAPES(packed, 1, unpacked, 1, 1); + TORCH_CHECK_SIZE(unpacked, 2, 256); + TORCH_CHECK_SIZE(packed, 2, 256 * K / 16); + + int rows = packed.size(0); + int cols = packed.size(1); + + dim3 blockDim(128); + dim3 gridDim(rows, cols); + + pack_trellis_kernel_instances[K - 1]<<>> + ( + (uint16_t*) packed.data_ptr(), + (const uint16_t*) unpacked.data_ptr() + ); + cuda_check(cudaPeekAtLastError()); +} + +template +__global__ __launch_bounds__(128) +void unpack_trellis_kernel +( + uint16_t* __restrict__ g_unpacked, + const uint16_t* __restrict__ g_packed +) +{ + constexpr int packed_size = 256 * K / 16; + __shared__ uint16_t s_packed[packed_size]; + + int t = threadIdx.x; + g_packed += (gridDim.x * blockIdx.y + blockIdx.x) * packed_size; + g_unpacked += (gridDim.x * blockIdx.y + blockIdx.x) * 256; + + // Read packed tile + if (t < packed_size / 2) + ((uint32_t*) s_packed)[t] = ((uint32_t*) g_packed)[t]; + __syncthreads(); + + // Index two words + int b0 = t * 2 * K + K - 16 + 256 * K; // start of word0 + int b1 = b0 + K; // start of word1 + int b2 = b1 + 16; // end of word1 + int i0 = b0 / 32; // uint32 containing first bit of word0 + int i1 = (b2 - 1) / 32; // uint32 containing last bit of word1, may be == i0 + int s1 = (i1 + 1) * 32 - b2; // shift to align word1 to 32-bit boundary + + // Load 32-64 bits containing word0 and word1, overlapping by 16-K bits, correct for endianness + uint32_t a = ((uint32_t*) s_packed)[i0 % (K * 256 / 32)]; + uint32_t b = ((uint32_t*) s_packed)[i1 % (K * 256 / 32)]; +// a = SWAP16(a); +// b = SWAP16(b); + + // Shift into place + uint32_t w1 = __funnelshift_r(b, a, s1); + uint32_t w0 = w1 >> K; + w0 &= 0xffff; + w1 &= 0xffff; + + // Store + uint32_t word01 = (w1 << 16) | w0; + ((uint32_t*)g_unpacked)[t] = word01; +} + +#define __(i) unpack_trellis_kernel +constexpr auto unpack_trellis_kernel_instances = std::array +{ + __(1), __(2), __(3), __(4), __(5), __(6), __(7), __(8) +}; +#undef __ + +void unpack_trellis +( + at::Tensor unpacked, + at::Tensor packed, + int K +) +{ + const at::cuda::OptionalCUDAGuard device_guard(unpacked.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_SHAPES(packed, 0, unpacked, 0, 1); + TORCH_CHECK_SHAPES(packed, 1, unpacked, 1, 1); + TORCH_CHECK_SIZE(unpacked, 2, 256); + TORCH_CHECK_SIZE(packed, 2, 256 * K / 16); + + int rows = packed.size(0); + int cols = packed.size(1); + + dim3 blockDim(128); + dim3 gridDim(cols, rows); + + unpack_trellis_kernel_instances[K - 1]<<>> + ( + (uint16_t*) unpacked.data_ptr(), + (const uint16_t*) packed.data_ptr() + ); + cuda_check(cudaPeekAtLastError()); +} + +__global__ __launch_bounds__(32) +void pack_signs_kernel +( + uint16_t* __restrict__ g_packed, + const uint16_t* __restrict__ g_unpacked, + int cols +) +{ + int t = threadIdx.x; + int idx = 32 * blockIdx.x + t; + if (idx >= cols) return; + g_unpacked += 16 * idx; + g_packed += idx; + + // Not efficient but whatever + uint16_t out = 0; + for (int i = 0; i < 16; ++i) + { + uint16_t v = *g_unpacked++; + v &= 0x8000; + out >>= 1; + out |= v; + } + + *g_packed = out; +} + +void pack_signs +( + at::Tensor packed, + at::Tensor unpacked +) +{ + const at::cuda::OptionalCUDAGuard device_guard(unpacked.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_DTYPE(unpacked, kHalf); + TORCH_CHECK_DTYPE(packed, kShort); + + int cols = packed.size(0); + dim3 blockDim(32); + dim3 gridDim(CEIL_DIVIDE(cols, 32)); + + pack_signs_kernel<<>> + ( + (uint16_t*) packed.data_ptr(), + (const uint16_t*) unpacked.data_ptr(), + cols + ); + cuda_check(cudaPeekAtLastError()); +} + diff --git a/gptqmodel_ext/exllamav3/quant/pack.cuh b/gptqmodel_ext/exllamav3/quant/pack.cuh new file mode 100644 index 000000000..0b32db2e0 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/pack.cuh @@ -0,0 +1,23 @@ +#pragma once + +#include + +void pack_trellis +( + at::Tensor packed, + at::Tensor unpacked, + int K +); + +void unpack_trellis +( + at::Tensor unpacked, + at::Tensor packed, + int K +); + +void pack_signs +( + at::Tensor packed, + at::Tensor unpacked +); diff --git a/gptqmodel_ext/exllamav3/quant/quantize.cu b/gptqmodel_ext/exllamav3/quant/quantize.cu new file mode 100644 index 000000000..22f8ab556 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/quantize.cu @@ -0,0 +1,530 @@ +#include +#include "quantize.cuh" +#include +#include +#include "../util.h" +#include "../util.cuh" +#include "codebook.cuh" +#include "exl3_devctx.cuh" +#include + +#define NUM_THREADS 1024 +#define H_INF __ushort_as_half(0x7c00) +#define H_NINF __ushort_as_half(0xfc00) + +template +__global__ __launch_bounds__(MIN(NUM_THREADS, 65536 >> K)) +void quantize_tiles_kernel +( + const float* __restrict__ input_tiles_ptr, + float* __restrict__ output_tiles_ptr, + uint16_t* __restrict__ output_indices_ptr, + half* __restrict__ temp_costs_ptr, + uint16_t* __restrict__ temp_edges_ptr +) +{ + extern __shared__ uint8_t shbuf[]; + uint8_t* sh = shbuf; + + int tile_idx = blockIdx.x; + int thread = threadIdx.x; + + constexpr int Kr = 16 - K; + constexpr int max_q = 1 << K; + constexpr int edges = 65536 >> K; + + const float* input_tile = input_tiles_ptr + 256 * tile_idx; + float* output_tile = output_tiles_ptr + 256 * tile_idx; + uint16_t* output_indices = output_indices_ptr + 256 * tile_idx; + uint16_t* temp_edges = temp_edges_ptr + 256 * edges * tile_idx; + + // Tile buffer + half* sh_input_tile = (half*) sh; sh += 256 * sizeof(half); + + half* sh_min = (half*) sh; sh += 32 * sizeof(half); + int* sh_idx = (int*) sh; sh += 32 * sizeof(int); + + // K >= mshk lets temp_costs fit in shmem, otherwise fall back to global temp buffer + constexpr int mshk = 2; + half* sh_temp_costs = (half*) sh; + half* temp_costs = K >= mshk ? sh_temp_costs : temp_costs_ptr + 2 * edges * tile_idx; + half* temp_costs_inc = temp_costs + edges; + + // Fetch input tile to shmem + if (thread < 256) sh_input_tile[thread] = __float2half_rn(input_tile[thread]); + __syncthreads(); + + auto forward = [&](int roll, int pre_state) + { + int ri = roll % 256; + half dh, err, min_err, w; + + // temp_costs_inc[z] is the cost/cumulative error of an incoming edge from state (z & edge_mask) + half* t = temp_costs; + temp_costs = temp_costs_inc; + temp_costs_inc = t; + + for (int out_edge_idx = thread; out_edge_idx < edges; out_edge_idx += NUM_THREADS) + { + w = sh_input_tile[ri]; + + int state = out_edge_idx; + int in_edge_idx = state >> K; + dh = __hsub(decode_3inst(state), w); + err = __hmul(dh, dh); + if (pre_state >= 0 && in_edge_idx != pre_state) err = H_INF; + min_err = err; + int min_in_edge = in_edge_idx; + + #pragma unroll + for (int k = 1; k < max_q; ++k) + { + state = (k << Kr) | out_edge_idx; + in_edge_idx = state >> K; + dh = __hsub(decode_3inst(state), w); + err = __hmul(dh, dh); + if (pre_state >= 0 && in_edge_idx != pre_state) err = H_INF; + if (__hlt(err, min_err)) { min_err = err; min_in_edge = in_edge_idx; } + } + + temp_costs[out_edge_idx] = min_err; + temp_edges[edges * ri + out_edge_idx] = (uint16_t) min_in_edge; + } + + // Next iteration depends on costs computed by current iteration + __syncthreads(); + + // Each thread iterates over all weights in the tile + for (int i = 1; i < 256; ++i) + { + ri = (i + roll) % 256; + + // Swap buffers. + t = temp_costs; + temp_costs = temp_costs_inc; + temp_costs_inc = t; + + for (int out_edge_idx = thread; out_edge_idx < edges; out_edge_idx += NUM_THREADS) + { + w = sh_input_tile[ri]; + + int state = out_edge_idx; + int in_edge_idx = state >> K; + dh = __hsub(decode_3inst(state), w); + err = __hfma(dh, dh, temp_costs_inc[in_edge_idx]); + min_err = err; + int min_in_edge = in_edge_idx; + + #pragma unroll + for (int k = 1; k < max_q; ++k) + { + state = (k << Kr) | out_edge_idx; + in_edge_idx = state >> K; + dh = __hsub(decode_3inst(state), w); + err = __hfma(dh, dh, temp_costs_inc[in_edge_idx]); + if (__hlt(err, min_err)) { min_err = err; min_in_edge = in_edge_idx; } + } + + temp_costs[out_edge_idx] = min_err; + temp_edges[edges * ri + out_edge_idx] = (uint16_t) min_in_edge; + } + + // Next iteration depends on costs computed by current iteration + __syncthreads(); + } + }; + + auto argmin_cost = [&]() + { + // Find the final state with the lowest total cost. Return value is only valid in thread 0 + + half local_min = H_INF; + int local_idx = -1; + #pragma unroll + for (int e = threadIdx.x; e < edges; e += NUM_THREADS) + { + half v = temp_costs_inc[e]; + if (__hlt(v, local_min)) { local_min = v; local_idx = e; } + } + + // Shuffle reduction + int lane_id = thread % 32; + int warp_id = thread / 32; + + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + { + half other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32); + int other_idx = __shfl_down_sync(0xffffffff, local_idx, offset, 32); + if (__hlt(other_min, local_min)) + { + local_min = other_min; + local_idx = other_idx; + } + } + + sh_min[warp_id] = local_min; + sh_idx[warp_id] = local_idx; + __syncthreads(); + + if (warp_id == 0) + { + local_min = lane_id * 32 < edges && thread < NUM_THREADS / 32 ? sh_min[lane_id] : H_INF; + local_idx = thread < NUM_THREADS ? sh_idx[lane_id] : 0; + + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + { + half other_min = __shfl_down_sync(0xffffffff, local_min, offset, 32); + int other_idx = __shfl_down_sync(0xffffffff, local_idx, offset, 32); + if (__hlt(other_min, local_min)) + { + local_min = other_min; + local_idx = other_idx; + } + } + } + + return local_idx; + }; + + auto backward = [&](int roll, bool write, int edge) + { + // Construct output tile. Since the graph has to be walked, this will run in a single thread per block. + // Profiling says this is not a bottleneck + + if (thread == 0) + { + for (int i = 255; i >= 0; --i) + { + int ri = (i + roll) % 256; + + int prev_edge = (int) temp_edges[edges * ri + edge]; + int encoded = (prev_edge << K) | edge; + edge = prev_edge; + + if (write) + { + output_indices[ri] = (uint16_t) encoded; + output_tile[ri] = __half2float(decode_3inst(encoded)); + } + else if (ri == 0) break; + } + } + + // Broadcast to block + if (thread == 0) sh_idx[0] = edge; + __syncthreads(); + edge = sh_idx[0]; + + return edge; + }; + + // Solve starting at position 128 find initial state for second pass + forward(128, -1); + int end_state = argmin_cost(); + end_state = backward(128, false, end_state); + + // Solve again from position 0 with tail-biting constraint + forward(0, end_state); + backward(0, true, end_state); +} + +#define __(i, cb) quantize_tiles_kernel +constexpr auto quantize_tiles_kernel_instances = std::array +{ + __(1, 0), __(2, 0), __(3, 0), __(4, 0), __(5, 0), __(6, 0), __(7, 0), __(8, 0), + __(1, 1), __(2, 1), __(3, 1), __(4, 1), __(5, 1), __(6, 1), __(7, 1), __(8, 1), + __(1, 2), __(2, 2), __(3, 2), __(4, 2), __(5, 2), __(6, 2), __(7, 2), __(8, 2) +}; +#undef __ + +/* +Quantize batch of tiles + +input_tiles: shape (n, 256), float +output_tiles: shape (n, 256), float +output_indices: shape (n, 256), uint16_t (unpacked) +temp_costs: shape (max_bsz, 2, 65536 >> K), float (scratch space for Viterbi algorithm) +temp_edges: shape (max_bsz, 256, 65536 >> K), uint16_t (scratch space for Viterbi algorithm) +K: number of bits per weight (1..8) +*/ + +void quantize_tiles +( + at::Tensor input_tiles, + at::Tensor output_tiles, + at::Tensor output_indices, + at::Tensor temp_costs, + at::Tensor temp_edges, + int K, + bool mcg, + bool mul1 +) +{ + const at::cuda::OptionalCUDAGuard device_guard(input_tiles.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_DIM(input_tiles, 2); + TORCH_CHECK_SIZE(input_tiles, 1, 256); + TORCH_CHECK_SHAPES_FULL(input_tiles, output_indices); + TORCH_CHECK_DTYPE(input_tiles, kFloat); + TORCH_CHECK_DTYPE(output_tiles, kFloat); + TORCH_CHECK_DTYPE(output_indices, kShort); + + int edges = 65536 >> K; + int threads = MIN(NUM_THREADS, edges); + + int num_tiles = input_tiles.size(0); + if (!num_tiles) return; + + TORCH_CHECK_DTYPE(temp_costs, kHalf); + TORCH_CHECK_DIM(temp_costs, 3); + TORCH_CHECK_SIZE(temp_costs, 1, 2); + TORCH_CHECK_SIZE(temp_costs, 2, edges); + + TORCH_CHECK_DTYPE(temp_edges, kShort); + TORCH_CHECK_DIM(temp_edges, 3); + TORCH_CHECK_SIZE(temp_edges, 1, 256); + TORCH_CHECK_SIZE(temp_edges, 2, edges); + + int device; + cudaGetDevice(&device); + int num_sms = DevCtx::instance().get_num_sms(device); + int max_batch_size = MIN(temp_costs.size(0), num_sms); + + int cb = 0; + if (mcg) cb = 1; + if (mul1) cb = 2; + + int batch_i = 0; + do + { + int batch_j = MIN(batch_i + max_batch_size, num_tiles); + + const float* input_tiles_ptr = ((const float*) input_tiles.data_ptr()) + 256 * batch_i; + float* output_tiles_ptr = ((float*) output_tiles.data_ptr()) + 256 * batch_i; + uint16_t* output_indices_ptr = ((uint16_t*) output_indices.data_ptr()) + 256 * batch_i; + half* temp_costs_ptr = (half*) temp_costs.data_ptr(); + uint16_t* temp_edges_ptr = (uint16_t*) temp_edges.data_ptr(); + + int bsz = batch_j - batch_i; + int kernel_idx = K - 1 + 8 * cb; + int shmem = 2 * (65536 >> K) * sizeof(half) + 512 + 64 + 128; + + cudaFuncSetAttribute + ( + quantize_tiles_kernel_instances[kernel_idx], + cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem + ); + cuda_check(cudaPeekAtLastError()); + + quantize_tiles_kernel_instances[kernel_idx]<<>> + ( + input_tiles_ptr, + output_tiles_ptr, + output_indices_ptr, + temp_costs_ptr, + temp_edges_ptr + ); + cuda_check(cudaPeekAtLastError()); + + batch_i = batch_j; + } + while (batch_i < num_tiles); +} + +template +__global__ //__launch_bounds__(64) +void decode_kernel +( + const uint16_t* __restrict__ input_tiles_ptr, + T* __restrict__ output_tiles_ptr, + int cols, + bool mcg, + bool mul1 +) +{ + int col = threadIdx.x + blockIdx.x * 64; + if (col >= cols) return; + int row = blockIdx.y; + int idx = row * cols + col; + + uint32_t enc = (uint32_t) input_tiles_ptr[idx]; + half w; + if (mcg) + w = decode_3inst<1>(enc); + else if (mul1) + w = decode_3inst<2>(enc); + else + w = decode_3inst<0>(enc); + + if constexpr (std::is_same_v) + output_tiles_ptr[idx] = __half2float(w); + else + output_tiles_ptr[idx] = w; +} + +/* +Decode tensor + +input_indices: uint16_t +output_tiles: float or half +mcg: use mcg codebook +mul1: use mcg codebook +*/ + +void decode +( + at::Tensor input_indices, + at::Tensor output_tiles, + bool mcg, + bool mul1 +) +{ + const at::cuda::OptionalCUDAGuard device_guard(input_indices.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_DIM(input_indices, 2); + TORCH_CHECK_SHAPES_FULL(input_indices, output_tiles); + TORCH_CHECK_DTYPE(input_indices, kShort); + + int rows = input_indices.size(0); + int cols = input_indices.size(1); + + dim3 blockDim(64); + dim3 gridDim(cols / 64, rows); + + if (output_tiles.dtype() == at::kFloat) + decode_kernel<<>> + ( + (const uint16_t*) input_indices.data_ptr(), + (float*) output_tiles.data_ptr(), + cols, + mcg, + mul1 + ); + else if (output_tiles.dtype() == at::kHalf) + decode_kernel<<>> + ( + (const uint16_t*) input_indices.data_ptr(), + (half*) output_tiles.data_ptr(), + cols, + mcg, + mul1 + ); +} + + +#define NUM_THREADS_TD 1024 +#define MAX_BINS 1024 + +__global__ __launch_bounds__(NUM_THREADS_TD) +void test_distribution_kernel +( + const float* __restrict__ input_ptr, + float* __restrict__ dist_output_ptr, + float* __restrict__ ref_output_ptr, + uint64_t numel, + uint64_t num_bins, + float min_value, + float max_value, + bool mcg, + bool mul1 +) +{ + __shared__ int histogram[MAX_BINS]; + auto reset_histogram = [&]() + { + for (int i = threadIdx.x; i < num_bins; i += NUM_THREADS_TD) + histogram[i] = 0; + __syncthreads(); + }; + + auto write_histogram = [&](float* output_ptr, uint64_t sc) + { + float scf = (float) sc; + for (int i = threadIdx.x; i < num_bins; i += NUM_THREADS_TD) + output_ptr[i] = ((float) histogram[i]) / scf; + __syncthreads(); + }; + + auto count = [&](float val) + { + val -= min_value; + val /= (max_value - min_value); + val *= (float) num_bins; + int idx = (int) val; + if (idx < 0) idx = 0; + if (idx > num_bins - 1) idx = num_bins - 1; + atomicAdd(&histogram[idx], 1); + }; + + if (ref_output_ptr) + { + reset_histogram(); + for (uint64_t i = threadIdx.x; i < 65536; i += NUM_THREADS_TD) + { + if (mcg) + count(decode_3inst_f<1>((uint16_t) (i & 0xffff))); + else if (mul1) + count(decode_3inst_f<2>((uint16_t) (i & 0xffff))); + else + count(decode_3inst_f<0>((uint16_t) (i & 0xffff))); + } + __syncthreads(); + write_histogram(ref_output_ptr, 65536); + } + + reset_histogram(); + for (uint64_t i = threadIdx.x; i < numel; i += NUM_THREADS_TD) + count(input_ptr[i]); + __syncthreads(); + write_histogram(dist_output_ptr, numel); +} + +/* +Compare tensor distribution to codebook (not optimized) + +input: tensor, float, any shape +dist_output: (empty) output histogram, float, shape (num_bins,) +ref_output, optional: (empty) output codebook histogram, float, shape (num_bins,) +*/ + +void test_distribution +( + at::Tensor& input, + at::Tensor& dist_output, + const c10::optional& ref_output, + float min_value, + float max_value, + bool mcg, + bool mul1 +) +{ + const at::cuda::OptionalCUDAGuard device_guard(input.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_DTYPE(input, kFloat); + + uint64_t numel = input.numel(); + float* ref_output_ptr = (float*) OPTPTR(ref_output); + uint64_t num_bins = dist_output.numel(); + TORCH_CHECK(num_bins <= MAX_BINS, "Too many bins"); + if (ref_output_ptr) + TORCH_CHECK(num_bins == ref_output.value().numel()); + + test_distribution_kernel<<<1, NUM_THREADS_TD, 0, stream>>> + ( + (const float*) input.data_ptr(), + (float*) dist_output.data_ptr(), + (float*) ref_output_ptr, + numel, + num_bins, + min_value, + max_value, + mcg, + mul1 + ); +} \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/quantize.cuh b/gptqmodel_ext/exllamav3/quant/quantize.cuh new file mode 100644 index 000000000..728d4a21b --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/quantize.cuh @@ -0,0 +1,34 @@ +#pragma once + +#include + +void quantize_tiles +( + at::Tensor input_tiles, + at::Tensor output_tiles, + at::Tensor output_indices, + at::Tensor temp_costs, + at::Tensor temp_edges, + int K, + bool mcg, + bool mul1 +); + +void decode +( + at::Tensor input_indices, + at::Tensor output_tiles, + bool mcg, + bool mul1 +); + +void test_distribution +( + at::Tensor& input, + at::Tensor& dist_output, + const c10::optional& ref_output, + float min_value, + float max_value, + bool mcg, + bool mul1 +); \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/reconstruct.cu b/gptqmodel_ext/exllamav3/quant/reconstruct.cu new file mode 100644 index 000000000..c893dee12 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/reconstruct.cu @@ -0,0 +1,131 @@ +#include +#include "reconstruct.cuh" +#include +#include +#include "../util.h" +#include "../util.cuh" +#include "../ptx.cuh" +#include "exl3_dq.cuh" + +// TODO: Benchmark, profile, unit test + +template +__global__ __launch_bounds__(256) +void reconstruct_kernel +( + half* __restrict__ g_unpacked, + const uint16_t* __restrict__ g_packed +) +{ + constexpr int packed_size = 256 * K / 16; // in uint16s + + int t = threadIdx.x; + int lane_id = t % 32; + int warp_id = t / 32; + int k = blockIdx.y; + int n = blockIdx.x * 8; + int tiles_n = gridDim.x; + int blocks_n = tiles_n * 8; + + // Load packed 16*128 tile + __shared__ uint32_t s_packed[8][packed_size / 2]; + g_packed += (k * blocks_n + n) * packed_size; + for (int s = t; s < packed_size * 8 / 8; s += 256) + ((int4*) s_packed)[t] = ((int4*) g_packed)[t]; + __syncthreads(); + + // Dequant + register FragB frag[2]; + dq_dispatch(s_packed[warp_id], lane_id * 8, frag[0], frag[1]); + + // Shuffle from tensor core layout to row major tile +// __shared__ half tile[16 * 8 * 16]; + __shared__ half2 tile[16][8][8]; + + half2 n0 = __shfl_down_sync(0xFFFFFFFF, frag[0][0], 4, 32); + half2 n1 = __shfl_down_sync(0xFFFFFFFF, frag[0][1], 4, 32); + half2 n2 = __shfl_down_sync(0xFFFFFFFF, frag[1][0], 4, 32); + half2 n3 = __shfl_down_sync(0xFFFFFFFF, frag[1][1], 4, 32); + __syncwarp(); + + if (!(lane_id & 4)) + { + half2 m0 = __halves2half2(__low2half(frag[0][0]), __low2half(n0)); + half2 m1 = __halves2half2(__high2half(frag[0][0]), __high2half(n0)); + half2 m2 = __halves2half2(__low2half(frag[0][1]), __low2half(n1)); + half2 m3 = __halves2half2(__high2half(frag[0][1]), __high2half(n1)); + half2 m4 = __halves2half2(__low2half(frag[1][0]), __low2half(n2)); + half2 m5 = __halves2half2(__high2half(frag[1][0]), __high2half(n2)); + half2 m6 = __halves2half2(__low2half(frag[1][1]), __low2half(n3)); + half2 m7 = __halves2half2(__high2half(frag[1][1]), __high2half(n3)); + int r0 = (lane_id % 4) * 2; + int r1 = r0 + 1; + int r2 = r0 + 8; + int r3 = r0 + 9; + int c0 = lane_id / 8; + int c1 = c0 + 4; + tile[r0][warp_id][c0] = m0; + tile[r1][warp_id][c0] = m1; + tile[r2][warp_id][c0] = m2; + tile[r3][warp_id][c0] = m3; + tile[r0][warp_id][c1] = m4; + tile[r1][warp_id][c1] = m5; + tile[r2][warp_id][c1] = m6; + tile[r3][warp_id][c1] = m7; + } + __syncthreads(); + + // Store unpacked tile + int r = t / 16; + int c = t % 16; + int4* tile_int4 = (reinterpret_cast (tile)); + int4* out_int4 = ((int4*) g_unpacked) + (k * 16 + r) * 2 * blocks_n + n * 2 + c; + *out_int4 = tile_int4[t]; +} + +#define __(i, cb) reconstruct_kernel +constexpr auto reconstruct_kernel_instances = std::array +{ + __(1, 0), __(2, 0), __(3, 0), __(4, 0), __(5, 0), __(6, 0), __(7, 0), __(8, 0), + __(1, 1), __(2, 1), __(3, 1), __(4, 1), __(5, 1), __(6, 1), __(7, 1), __(8, 1), + __(1, 2), __(2, 2), __(3, 2), __(4, 2), __(5, 2), __(6, 2), __(7, 2), __(8, 2) +}; +#undef __ + +/* +Reconstruct encoded+packed tensor +*/ +void reconstruct +( + at::Tensor unpacked, + at::Tensor packed, + int K, + bool mcg, + bool mul1 +) +{ + const at::cuda::OptionalCUDAGuard device_guard(unpacked.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK_SHAPES(unpacked, 0, packed, 0, 16); + TORCH_CHECK_SHAPES(unpacked, 1, packed, 1, 16); + TORCH_CHECK_SIZE(packed, 2, 256 * K / 16); + TORCH_CHECK_DTYPE(unpacked, kHalf); + + int rows = packed.size(0); + int cols = packed.size(1); + + dim3 blockDim(256); + dim3 gridDim(cols / 8, rows); + + int cbi = K - 1; + if (mcg) cbi += 8; + else if (mul1) cbi += 16; + + reconstruct_kernel_instances[cbi]<<>> + ( + (half*) unpacked.data_ptr(), + (const uint16_t*) packed.data_ptr() + ); + cuda_check(cudaPeekAtLastError()); +} diff --git a/gptqmodel_ext/exllamav3/quant/reconstruct.cuh b/gptqmodel_ext/exllamav3/quant/reconstruct.cuh new file mode 100644 index 000000000..bf1c261e1 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/reconstruct.cuh @@ -0,0 +1,12 @@ +#pragma once + +#include + +void reconstruct +( + at::Tensor unpacked, + at::Tensor packed, + int K, + bool mcg, + bool mul1 +); diff --git a/gptqmodel_ext/exllamav3/quant/util.cu b/gptqmodel_ext/exllamav3/quant/util.cu new file mode 100644 index 000000000..23b500260 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/util.cu @@ -0,0 +1,121 @@ +#include +#include "util.cuh" +#include +#include +#include "../util.h" +#include "../util.cuh" + +#define NUM_THREADS 1024 +#define BLOCK_SIZE 32768 + +#define uint64_cu unsigned long long int + +__device__ inline uint64_cu warp_reduce_sum(uint64_cu v) +{ + for (int offset = 32 >> 1; offset > 0; offset >>= 1) + { + uint64_cu other_v = __shfl_down_sync(0xffffffff, v, offset); + v += other_v; + } + return v; +} + +__device__ inline uint64_cu block_reduce_sum(uint64_cu v) +{ + __shared__ uint64_cu shared[NUM_THREADS / 32]; + + int lane_id = threadIdx.x % 32; + int warp_id = threadIdx.x / 32; + + v = warp_reduce_sum(v); + + if (lane_id == 0) shared[warp_id] = v; + __syncthreads(); + + int max_warp_id = NUM_THREADS / 32; + if (warp_id == 0) + { + v = lane_id < max_warp_id ? shared[lane_id] : 0; + v = warp_reduce_sum(v); + } + __syncthreads(); + return v; +} + +__device__ inline bool isinf(half v) +{ + return isinf(__half2float(v)); +} + +__device__ inline bool isnan(half v) +{ + return isnan(__half2float(v)); +} + +template +__global__ __launch_bounds__(NUM_THREADS) +void count_inf_nan_kernel +( + const T* __restrict__ x, + uint64_cu* __restrict__ y, + uint64_cu numel +) +{ + uint64_cu idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; + uint64_cu max_idx = MIN(blockIdx.x * BLOCK_SIZE + BLOCK_SIZE, numel); + uint64_cu thread_inf = 0; + uint64_cu thread_nan = 0; + for (; idx < max_idx; idx += NUM_THREADS) + { + T val = x[idx]; + if (isinf(val)) thread_inf++; + if (isnan(val)) thread_nan++; + } + + thread_inf = block_reduce_sum(thread_inf); + thread_nan = block_reduce_sum(thread_nan); + + if (threadIdx.x == 0) + { + atomicAdd(y + 0, thread_inf); + atomicAdd(y + 1, thread_nan); + } +} + +/* +Count number of inf and NaN values in tensor + +x: Tensor to test +y: Output, dtype kLong, shape (2,) +*/ + +void count_inf_nan +( + at::Tensor x, + at::Tensor y +) +{ + const at::cuda::OptionalCUDAGuard device_guard(x.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK_DTYPE(y, kLong); + + uint64_cu numel = x.numel(); + uint64_cu num_blocks = CEIL_DIVIDE(numel, BLOCK_SIZE); + + if (x.dtype() == at::kHalf) + count_inf_nan_kernel<<>> + ( + (const half*) x.data_ptr(), + (uint64_cu*) y.data_ptr(), + numel + ); + else if (x.dtype() == at::kFloat) + count_inf_nan_kernel<<>> + ( + (const float*) x.data_ptr(), + (uint64_cu*) y.data_ptr(), + numel + ); + else + TORCH_CHECK(false, "Unsupported dtype"); +} \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/quant/util.cuh b/gptqmodel_ext/exllamav3/quant/util.cuh new file mode 100644 index 000000000..1e40575d7 --- /dev/null +++ b/gptqmodel_ext/exllamav3/quant/util.cuh @@ -0,0 +1,10 @@ +#pragma once + +#include +#include + +void count_inf_nan +( + at::Tensor x, + at::Tensor y +); \ No newline at end of file diff --git a/gptqmodel_ext/exllamav3/util.cuh b/gptqmodel_ext/exllamav3/util.cuh new file mode 100644 index 000000000..fa443ae1f --- /dev/null +++ b/gptqmodel_ext/exllamav3/util.cuh @@ -0,0 +1,139 @@ +#pragma once + +typedef struct __align__(8) half4 +{ + half2 x; + half2 y; + __device__ half4() = default; + __device__ half4(half2 x_, half2 y_) : x(x_), y(y_) {} + __device__ half4(half h0, half h1, half h2, half h3) : + x(__halves2half2(h0, h1)), + y(__halves2half2(h2, h3)) {} +} +half4; + +typedef struct __align__(8) bfloat164 +{ + __nv_bfloat162 x; + __nv_bfloat162 y; + __device__ bfloat164() = default; + __device__ bfloat164(__nv_bfloat162 x_, __nv_bfloat162 y_): x(x_), y(y_) {} + __device__ bfloat164(__nv_bfloat16 b0, __nv_bfloat16 b1, __nv_bfloat16 b2, __nv_bfloat16 b3) : + x(__halves2bfloat162(b0, b1)), + y(__halves2bfloat162(b2, b3)) {} +} +bfloat164; + +typedef struct __align__(16) half8 +{ + half2 x; + half2 y; + half2 z; + half2 w; + __device__ half8() = default; + __device__ half8(half2 x_, half2 y_, half2 z_, half2 w_) : x(x_), y(y_), z(z_), w(w_) {} + __device__ half8(half h0, half h1, half h2, half h3, half h4, half h5, half h6, half h7) : + x(__halves2half2(h0, h1)), + y(__halves2half2(h2, h3)), + z(__halves2half2(h4, h5)), + w(__halves2half2(h6, h7)) {} +} +half8; + +struct Dim3 +{ + int m; + int k; + int n; + inline __device__ int numel_a() { return m * k; } + inline __device__ int numel_b() { return k * n; } + inline __device__ int numel_c() { return m * n; } +}; + +#define READ128(__x, __y) ((uint4*)&__x)[0] = ((uint4*)(__y))[0]; +#define WRITE128(__x, __y) ((uint4*)__x)[0] = ((uint4*)(&__y))[0]; +#define READ64(__x, __y) ((uint2*)&__x)[0] = ((uint2*)(__y))[0]; +#define WRITE64(__x, __y) ((uint2*)__x)[0] = ((uint2*)(&__y))[0]; + +#define LOW_TO_FLOAT(__x) __half2float(__low2half(__x)) +#define HIGH_TO_FLOAT(__x) __half2float(__high2half(__x)) + +#define LOW_TO_FLOAT(__x) __half2float(__low2half(__x)) +#define HIGH_TO_FLOAT(__x) __half2float(__high2half(__x)) + +#define CLAMP(__x, __min, __max) fmaxf(__min, fminf(__x, __max)) +#define CLAMP_FP16(__x) CLAMP(__x, -65504.0f, 65504.0f) + +#define SWAP16(__x) __byte_perm(__x, 0, 0x1032) + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} +}; + +#define cuda_check(ans) { gpu_assert((ans), __FILE__, __LINE__); } +inline void gpu_assert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr,"GPU assert: %s %s %d\n", cudaGetErrorString(code), file, line); + if (abort) exit(code); + } +} + +inline const char* cublasGetErrorString(cublasStatus_t status) { + switch (status) { + case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: return "CUBLAS_STATUS_LICENSE_ERROR"; + default: return "Unknown cuBLAS status"; + } +} + +#define cublas_check(ans) { cublas_assert((ans), __FILE__, __LINE__); } +inline void cublas_assert(cublasStatus_t code, const char *file, int line, bool abort=true) +{ + if (code != CUBLAS_STATUS_SUCCESS) + { + fprintf(stderr, "cuBLAS assert: %s %s %d\n", + cublasGetErrorString(code), file, line); + if (abort) exit(static_cast(code)); + } +} + +__device__ inline float fxor(float v, uint32_t mask) +{ + uint32_t* vi = reinterpret_cast(&v); + *vi ^= mask; + return v; +} + +__device__ inline half2 h2xor(half2 v, uint32_t mask) +{ + uint32_t* vi = reinterpret_cast(&v); + *vi ^= mask; + return v; +} + +#define NEG_INF_F16 __ushort_as_half(0xFC00) +#define POS_INF_F16 __ushort_as_half(0x7C00) diff --git a/gptqmodel_ext/exllamav3/util.h b/gptqmodel_ext/exllamav3/util.h new file mode 100644 index 000000000..36434253f --- /dev/null +++ b/gptqmodel_ext/exllamav3/util.h @@ -0,0 +1,125 @@ +#pragma once + +#include + +#define CEIL_DIVIDE(x, size) (((x) + (size) - 1) / (size)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define MAX(x, y) ((x) > (y) ? (x) : (y)) + +// Some decluttering macros +// +// TORCH_CHECK_DTYPE(x, T): assert x is dtype T +// TORCH_CHECK_DTYPE_OPT(x, T): assert x is dtype T, unless x is None +// TORCH_CHECK_FLOAT_HALF(x): assert x is either kFloat or kHalf +// TORCH_CHECK_SHAPES(x, i, y, j, scale): assert x.size(i) == y.size(j) * scale +// TORCH_CHECK_SHAPES_OPT(x, i, y, j, scale): assert x.size(i) == y.size(j) * scale, unless x is None +// TORCH_CHECK_SHAPES_FULL(x, y): assert x and y are same shape +// TORCH_CHECK_NUMEL(x, y): assert x and y have same number of elements +// TORCH_CHECK_DIV(x, i, divisor): assert x.size(i) is divisible by divisor +// TORCH_CHECK_DIM(x, D): assert x has D dimensions +// TORCH_CHECK_DIM_OPT(x, D): assert x has D dimensions, unless x is None +// TORCH_CHECK_SIZE(x, i, s): assert x.size(i) == s +// OPTPTR(x): x.data_ptr() or nullptr if x is None + +#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == at::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((!__x.has_value()) || (__x).value().dtype() == at::__dtype, #__x " is incorrect datatype, must be " #__dtype) +#define TORCH_CHECK_FLOAT_HALF(__x) TORCH_CHECK((__x).dtype() == at::kHalf || (__x).dtype() == at::kFloat, #__x " is incorrect datatype, must be kHalf or kFloat") +#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((!(__x).has_value()) || (__x).value().size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_SHAPES_FULL(__x, __y) TORCH_CHECK((__x).sizes() == (__y).sizes(), #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_NUMEL(__x, __y) TORCH_CHECK((__x).numel() == (__y).numel(), #__x " and " #__y " have incompatible shapes") +#define TORCH_CHECK_DIV(__x, __dim_x, __div) TORCH_CHECK((__x).size(__dim_x) % __div == 0, #__x " dimension " #__dim_x " must be divisible by " #__div) +#define TORCH_CHECK_DIM(__x, __dims) TORCH_CHECK((__x).dim() == __dims, #__x " must have " #__dims " dimensions") +#define TORCH_CHECK_DIM_OPT(__x, __dims) TORCH_CHECK((!__x.has_value()) || (__x).value().dim() == __dims, #__x " must have " #__dims " dimensions") +#define TORCH_CHECK_SIZE(__x, __dim_x, __s) TORCH_CHECK((__x).size(__dim_x) == (__s), #__x " dimension " #__dim_x " is incorrect size") +#define OPTPTR(__x) (__x.has_value() ? __x.value().data_ptr() : nullptr) + +// Debug stuff + +#define DBGS(__x) printf("%s\n", __x) +#define DBGI(__x) \ + printf("%s: %i\n", #__x, __x) +#define DBGI2(__x, __y) \ + printf("%s, %s: %i, %i\n", #__x, #__y, __x, __y) +#define DBGI3(__x, __y, __z) \ + printf("%s, %s, %s: %i, %i, %i\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGI4(__x, __y, __z, __w) \ + printf("%s, %s, %s, %s: %i, %i, %i, %i\n", #__x, #__y, #__z, #__w, __x, __y, __z, __w) +#define DBGI5(__x, __y, __z, __w, __v) \ + printf("%s, %s, %s, %s, %s: %i, %i, %i, %i, %i\n", #__x, #__y, #__z, #__w, #__v, __x, __y, __z, __w, __v) +#define DBGI6(__x, __y, __z, __w, __v, __u) \ + printf("%s, %s, %s, %s, %s, %s: %i, %i, %i, %i, %i, %i\n", #__x, #__y, #__z, #__w, #__v, #__u, __x, __y, __z, __w, __v, __u) +#define DBGI7(__x, __y, __z, __w, __v, __u, __t) \ + printf("%s, %s, %s, %s, %s, %s, %s: %i, %i, %i, %i, %i, %i, %i\n", #__x, #__y, #__z, #__w, #__v, #__u, #__t, __x, __y, __z, __w, __v, __u, __t) +#define DBGI8(__x, __y, __z, __w, __v, __u, __t, __s) \ + printf("%s, %s, %s, %s, %s, %s, %s, %s: %i, %i, %i, %i, %i, %i, %i, %i\n", #__x, #__y, #__z, #__w, #__v, #__u, #__t, #__s, __x, __y, __z, __w, __v, __u, __t, __s) +#define DBGI9(__x, __y, __z, __w, __v, __u, __t, __s, __r) \ + printf("%s, %s, %s, %s, %s, %s, %s, %s, %s: %i, %i, %i, %i, %i, %i, %i, %i, %i\n", #__x, #__y, #__z, #__w, #__v, #__u, #__t, #__s, #__r, __x, __y, __z, __w, __v, __u, __t, __s, __r) +#define DBGI10(__x, __y, __z, __w, __v, __u, __t, __s, __r, __q) \ + printf("%s, %s, %s, %s, %s, %s, %s, %s, %s, %s: %i, %i, %i, %i, %i, %i, %i, %i, %i, %i\n", #__x, #__y, #__z, #__w, #__v, #__u, #__t, #__s, #__r, #__q, __x, __y, __z, __w, __v, __u, __t, __s, __r, __q) +#define DBGX(__x) printf("%s: %x\n", #__x, __x) +#define DBGX2(__x, __y) printf("%s, %s: %x, %x\n", #__x, #__y, __x, __y) +#define DBGX3(__x, __y, __z) printf("%s, %s, %s: %x, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGIX(__x, __y) printf("%s, %s: %i, %x\n", #__x, #__y, __x, __y) +#define DBGIX2(__x, __y, __z) printf("%s, %s, %s: %i, %x, %x\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGIF(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __y) +#define DBGIF2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF(__x) printf("%s: %f\n", #__x, __x) +#define DBGF2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __x, __y) +#define DBGF3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __x, __y, __z) +#define DBGF4(__x, __y, __z, __w) printf("%s, %s, %s, %s: %f, %f, %f, %f\n", #__x, #__y, #__z, #__w, __x, __y, __z, __w) +#define DBGH(__x) printf("%s: %f\n", #__x, __half2float(__x)) +#define DBGH2(__x, __y) printf("%s, %s: %f, %f\n", #__x, #__y, __half2float(__x), __half2float(__y)) +#define DBGH3(__x, __y, __z) printf("%s, %s, %s: %f, %f, %f\n", #__x, #__y, #__z, __half2float(__x), __half2float(__y), __half2float(__z)) +#define DBGIH(__x, __y) printf("%s, %s: %i, %f\n", #__x, #__y, __x, __half2float(__y)) +#define DBGIH2(__x, __y, __z) printf("%s, %s, %s: %i, %f, %f\n", #__x, #__y, #__z, __x, __half2float(__y), __half2float(__z)) +#define DBGI2H2(__x, __y, __z, __w) printf("%s, %s, %s, %s: %i, %i, %f, %f\n", #__x, #__y, #__z, #__w, __x, __y, __half2float(__z), __half2float(__w)) +#define DBGIH3(__x, __y, __z, __w) printf("%s, %s, %s, %s: %i, %f, %f, %f\n", #__x, #__y, #__z, #__w, __x, __half2float(__y), __half2float(__z), __half2float(__w)) +#define DBGIH4(__x, __y, __z, __w, __v) printf("%s, %s, %s, %s, %s: %i, %f, %f, %f, %f\n", #__x, #__y, #__z, #__w, #__v, __x, __half2float(__y), __half2float(__z), __half2float(__w), __half2float(__v)) +#define DBGA(__x) printf("%s: %016llx\n", #__x, __x) +#define DBGIA(__x, __y) printf("%s, %s: %i, %016llx\n", #__x, #__y, __x, __y) +#define DBGI2A(__x, __y, __z) printf("%s, %s, %s: %i, %i, %016llx\n", #__x, #__y, #__z, __x, __y, __z) + +#define TIME_START \ + auto start = std::chrono::high_resolution_clock::now() + +#define TIME_STOP \ + do { \ + auto stop = std::chrono::high_resolution_clock::now(); \ + auto duration_us = std::chrono::duration_cast(stop - start); \ + DBGI(duration_us); \ + } while (false) + +/* +Compile-time for loop. Supports template instancing. Example usage: + +int kernel_arg = select_kernel_somehow(); + +// Not nice +if (kernel_arg == 2) + launch_kernel_instance<2><<< ... >>>( ... ) +if (kernel_arg == 3) + launch_kernel_instance<3><<< ... >>>( ... ) +if (kernel_arg == 4) + launch_kernel_instance<4><<< ... >>>( ... ) +if (kernel_arg == 6) + launch_kernel_instance<6><<< ... >>>( ... ) +if (kernel_arg == 8) + launch_kernel_instance<8><<< ... >>>( ... ) + +// Nice? +static_for_pack<2, 3, 4, 6, 8>([&](auto ic) +{ + constexpr int i = decltype(ic)::value; + if (kernel_arg == i) + launch_kernel_instance<<< ... >>>( ... ) +}); + +*/ + +// This breaks with nesting on VC++ older than 17.13 (late 2024 preview) +template +constexpr void static_for_pack(F&& f) +{ + (f(std::integral_constant{}), ...); +} diff --git a/gptqmodel_ext/floatx_cpu.cpp b/gptqmodel_ext/floatx_cpu.cpp new file mode 100644 index 000000000..9d6b58577 --- /dev/null +++ b/gptqmodel_ext/floatx_cpu.cpp @@ -0,0 +1,2294 @@ +// SPDX-FileCopyrightText: 2026 ModelCloud.ai +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#if defined(__x86_64__) || defined(__i386__) || defined(_M_X64) || defined(_M_IX86) +#include +#if defined(__GNUC__) || defined(__clang__) +#include +#endif +#define GPTQMODEL_FLOATX_X86 1 +#else +#define GPTQMODEL_FLOATX_X86 0 +#endif + +namespace gptqmodel_floatx { + +enum class ScaleMode : int64_t { + kNone = 0, + kMultiply = 1, + kDivide = 2, +}; + +enum class TargetKind : int64_t { + kBFloat16 = 0, + kFloat16 = 1, +}; + +enum class Fp8Format : int64_t { + kE4M3Fn = 0, + kE5M2 = 1, + kE4M3FnUz = 2, + kE5M2FnUz = 3, + kE8M0Fnu = 4, +}; + +enum class ScaleLayout1D { + kNone, + kScalar, + kVector, + kRepeat, +}; + +enum class ScaleLayout2D { + kNone, + kScalar, + kFull, + kBlock, + kRowRepeat, + kColRepeat, +}; + +struct ScaleSpec1D { + ScaleLayout1D layout = ScaleLayout1D::kNone; + const float* ptr = nullptr; + int64_t length = 0; + int64_t repeat = 1; +}; + +struct ScaleSpec2D { + ScaleLayout2D layout = ScaleLayout2D::kNone; + const float* ptr = nullptr; + int64_t rows = 0; + int64_t cols = 0; + int64_t scale_rows = 0; + int64_t scale_cols = 0; + int64_t row_repeat = 1; + int64_t col_repeat = 1; +}; + +inline int64_t clamped_threads(int64_t requested) { + const int64_t limit = 32; + if (requested > 0) { + return std::max(1, std::min(requested, limit)); + } + return std::max(1, std::min(at::get_num_threads(), limit)); +} + +template +std::array build_fp8_table() { + std::array table{}; + for (int value = 0; value < 256; ++value) { + const float decoded = + static_cast(SrcT(static_cast(value), SrcT::from_bits())); + table[value] = static_cast(DstT(decoded)); + } + return table; +} + +const std::array& fp8_table(Fp8Format format, TargetKind target_kind) { + // Target-rounded tables keep the hot loops from re-quantizing decoded values. + static const auto e4m3fn_bf16 = build_fp8_table(); + static const auto e5m2_bf16 = build_fp8_table(); + static const auto e4m3fnuz_bf16 = build_fp8_table(); + static const auto e5m2fnuz_bf16 = build_fp8_table(); + static const auto e8m0fnu_bf16 = build_fp8_table(); + static const auto e4m3fn_fp16 = build_fp8_table(); + static const auto e5m2_fp16 = build_fp8_table(); + static const auto e4m3fnuz_fp16 = build_fp8_table(); + static const auto e5m2fnuz_fp16 = build_fp8_table(); + static const auto e8m0fnu_fp16 = build_fp8_table(); + switch (target_kind) { + case TargetKind::kBFloat16: + switch (format) { + case Fp8Format::kE4M3Fn: + return e4m3fn_bf16; + case Fp8Format::kE5M2: + return e5m2_bf16; + case Fp8Format::kE4M3FnUz: + return e4m3fnuz_bf16; + case Fp8Format::kE5M2FnUz: + return e5m2fnuz_bf16; + case Fp8Format::kE8M0Fnu: + return e8m0fnu_bf16; + } + break; + case TargetKind::kFloat16: + switch (format) { + case Fp8Format::kE4M3Fn: + return e4m3fn_fp16; + case Fp8Format::kE5M2: + return e5m2_fp16; + case Fp8Format::kE4M3FnUz: + return e4m3fnuz_fp16; + case Fp8Format::kE5M2FnUz: + return e5m2fnuz_fp16; + case Fp8Format::kE8M0Fnu: + return e8m0fnu_fp16; + } + break; + } + TORCH_CHECK(false, "Unsupported FP8 format code"); +} + +template +std::array build_fp4_table() { + static constexpr std::array kDecodedValues = { + 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f, + -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f, + }; + std::array table{}; + for (size_t idx = 0; idx < table.size(); ++idx) { + table[idx] = static_cast(DstT(kDecodedValues[idx])); + } + return table; +} + +const std::array& fp4_table(TargetKind target_kind) { + static const auto bf16 = build_fp4_table(); + static const auto fp16 = build_fp4_table(); + switch (target_kind) { + case TargetKind::kBFloat16: + return bf16; + case TargetKind::kFloat16: + return fp16; + } + TORCH_CHECK(false, "Unsupported target dtype for FP4 table"); +} + +ScaleSpec1D make_scale_spec_1d( + const c10::optional& scale_opt, + int64_t result_len, + int64_t axis, + bool axis_is_none) { + (void)axis; + (void)axis_is_none; + ScaleSpec1D spec; + if (!scale_opt.has_value() || !scale_opt->defined()) { + return spec; + } + + const at::Tensor& scale = *scale_opt; + TORCH_CHECK(scale.device().is_cpu(), "scale tensor must reside on CPU"); + TORCH_CHECK(scale.scalar_type() == at::kFloat, "scale tensor must be float32"); + TORCH_CHECK(scale.is_contiguous(), "scale tensor must be contiguous"); + + spec.ptr = scale.const_data_ptr(); + if (scale.ndimension() == 0) { + spec.layout = ScaleLayout1D::kScalar; + return spec; + } + + TORCH_CHECK(scale.ndimension() == 1, "1D dequantization only supports scalar or 1D scale tensors"); + spec.length = scale.numel(); + if (spec.length == result_len) { + spec.layout = ScaleLayout1D::kVector; + return spec; + } + TORCH_CHECK(result_len % spec.length == 0, "scale tensor shape incompatible with 1D output"); + spec.layout = ScaleLayout1D::kRepeat; + spec.repeat = result_len / spec.length; + return spec; +} + +ScaleSpec2D make_scale_spec_2d( + const c10::optional& scale_opt, + int64_t rows, + int64_t cols, + int64_t axis, + bool axis_is_none) { + ScaleSpec2D spec; + if (!scale_opt.has_value() || !scale_opt->defined()) { + return spec; + } + + const at::Tensor& scale = *scale_opt; + TORCH_CHECK(scale.device().is_cpu(), "scale tensor must reside on CPU"); + TORCH_CHECK(scale.scalar_type() == at::kFloat, "scale tensor must be float32"); + TORCH_CHECK(scale.is_contiguous(), "scale tensor must be contiguous"); + + spec.ptr = scale.const_data_ptr(); + if (scale.ndimension() == 0) { + spec.layout = ScaleLayout2D::kScalar; + return spec; + } + + if (scale.ndimension() == 1) { + const int64_t count = scale.size(0); + const int64_t resolved_axis = axis_is_none ? 0 : axis; + if (resolved_axis == 0) { + TORCH_CHECK(rows % count == 0, "row scale tensor shape incompatible with output"); + spec.layout = ScaleLayout2D::kRowRepeat; + spec.scale_rows = count; + spec.row_repeat = rows / count; + return spec; + } + TORCH_CHECK(resolved_axis == 1, "axis must be 0 or 1 for 2D scale tensors"); + TORCH_CHECK(cols % count == 0, "column scale tensor shape incompatible with output"); + spec.layout = ScaleLayout2D::kColRepeat; + spec.scale_cols = count; + spec.col_repeat = cols / count; + return spec; + } + + TORCH_CHECK(scale.ndimension() == 2, "2D dequantization only supports scale tensors up to rank 2"); + spec.scale_rows = scale.size(0); + spec.scale_cols = scale.size(1); + if (spec.scale_rows == rows && spec.scale_cols == cols) { + spec.layout = ScaleLayout2D::kFull; + spec.cols = cols; + return spec; + } + + TORCH_CHECK( + rows % spec.scale_rows == 0 && cols % spec.scale_cols == 0, + "block scale tensor shape incompatible with output"); + spec.layout = ScaleLayout2D::kBlock; + spec.row_repeat = rows / spec.scale_rows; + spec.col_repeat = cols / spec.scale_cols; + spec.cols = cols; + return spec; +} + +inline float scale_value_1d(const ScaleSpec1D& spec, int64_t idx) { + switch (spec.layout) { + case ScaleLayout1D::kNone: + return 1.0f; + case ScaleLayout1D::kScalar: + return spec.ptr[0]; + case ScaleLayout1D::kVector: + return spec.ptr[idx]; + case ScaleLayout1D::kRepeat: + return spec.ptr[idx / spec.repeat]; + } + return 1.0f; +} + +inline float scale_value_2d(const ScaleSpec2D& spec, int64_t row, int64_t col) { + switch (spec.layout) { + case ScaleLayout2D::kNone: + return 1.0f; + case ScaleLayout2D::kScalar: + return spec.ptr[0]; + case ScaleLayout2D::kFull: + return spec.ptr[row * spec.cols + col]; + case ScaleLayout2D::kBlock: + return spec.ptr[(row / spec.row_repeat) * spec.scale_cols + (col / spec.col_repeat)]; + case ScaleLayout2D::kRowRepeat: + return spec.ptr[row / spec.row_repeat]; + case ScaleLayout2D::kColRepeat: + return spec.ptr[col / spec.col_repeat]; + } + return 1.0f; +} + +template +inline void store_scalar(T* dst, float value); + +template <> +inline void store_scalar(c10::Half* dst, float value) { + *dst = c10::Half(value); +} + +template <> +inline void store_scalar(c10::BFloat16* dst, float value) { + *dst = c10::BFloat16(value); +} + +template +inline void apply_scale_and_store_scalar( + T* dst, + const float* values, + int64_t count, + ScaleMode scale_mode, + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + for (int64_t i = 0; i < count; ++i) { + const T rounded_value = T(values[i]); + float value = static_cast(rounded_value); + if (scale_mode != ScaleMode::kNone) { + const T rounded_scale = T(scale_value_2d(spec, row, col_base + i)); + const float scale = static_cast(rounded_scale); + value = scale_mode == ScaleMode::kMultiply ? value * scale : value / scale; + } + store_scalar(dst + i, value); + } +} + +template +inline void apply_scale_and_store_scalar_1d( + T* dst, + float value, + ScaleMode scale_mode, + const ScaleSpec1D& spec, + int64_t idx) { + const T rounded_value = T(value); + value = static_cast(rounded_value); + if (scale_mode != ScaleMode::kNone) { + const T rounded_scale = T(scale_value_1d(spec, idx)); + const float scale = static_cast(rounded_scale); + value = scale_mode == ScaleMode::kMultiply ? value * scale : value / scale; + } + store_scalar(dst + idx, value); +} + +#if GPTQMODEL_FLOATX_X86 && (defined(__GNUC__) || defined(__clang__)) +inline bool env_flag_enabled(const char* name) { + // Allow tests and compatibility debugging to force the scalar fallback. + const char* value = std::getenv(name); + if (value == nullptr || value[0] == '\0') { + return false; + } + switch (value[0]) { + case '1': + case 'Y': + case 'y': + case 'T': + case 't': + return true; + default: + return false; + } +} + +inline bool cpu_supports_avx2() { + if (env_flag_enabled("GPTQMODEL_FLOATX_CPU_DISABLE_AVX2")) { + return false; + } + return __builtin_cpu_supports("avx2"); +} + +inline bool cpu_supports_f16c() { + if (env_flag_enabled("GPTQMODEL_FLOATX_CPU_DISABLE_AVX2")) { + return false; + } + return __builtin_cpu_supports("f16c"); +} + +inline bool cpu_supports_avx512_core() { + if (env_flag_enabled("GPTQMODEL_FLOATX_CPU_DISABLE_AVX2") || + env_flag_enabled("GPTQMODEL_FLOATX_CPU_DISABLE_AVX512")) { + return false; + } + return __builtin_cpu_supports("avx512f") && + __builtin_cpu_supports("avx512bw") && + __builtin_cpu_supports("avx512vl"); +} + +inline bool cpu_supports_avx512bf16() { + return cpu_supports_avx512_core() && __builtin_cpu_supports("avx512bf16"); +} + +inline bool cpu_supports_avx512fp16() { + return cpu_supports_avx512_core() && __builtin_cpu_supports("avx512fp16"); +} + +inline bool cpu_prefers_avx2_fp8_bf16() { + // Zen 5 class EPYC parts expose AVX-512 BF16, but the 16-lane FP8 gather path + // still loses to AVX2 on this host because the wider gather path costs more than + // the extra SIMD width saves. Keep the narrower path as a targeted runtime quirk + // instead of changing the generic dispatch for every AVX-512 CPU. + unsigned int eax = 0; + unsigned int ebx = 0; + unsigned int ecx = 0; + unsigned int edx = 0; + if (__get_cpuid_max(0, nullptr) == 0 || !__get_cpuid(0, &eax, &ebx, &ecx, &edx)) { + return false; + } + if (ebx != 0x68747541u || edx != 0x69746e65u || ecx != 0x444d4163u) { + return false; + } + + if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) { + return false; + } + const unsigned int base_family = (eax >> 8) & 0x0F; + const unsigned int base_model = (eax >> 4) & 0x0F; + const unsigned int ext_family = (eax >> 20) & 0xFF; + const unsigned int ext_model = (eax >> 16) & 0x0F; + const unsigned int family = base_family == 0x0F ? base_family + ext_family : base_family; + const unsigned int model = base_model | (ext_model << 4); + return family == 26 && model == 2; +} + +__attribute__((target("avx512f,avx512bw,avx512vl"))) +inline __m512 load_fp8x16_to_ps_avx512(const uint8_t* src, const float* table) { + const __m128i raw = _mm_loadu_si128(reinterpret_cast(src)); + const __m512i indices = _mm512_cvtepu8_epi32(raw); + return _mm512_i32gather_ps(indices, table, 4); +} + +__attribute__((target("avx512f,avx512bw,avx512vl"))) +inline __m512 load_fp4x16_to_ps_avx512(const uint8_t* src, const float* table) { + const __m128i raw = _mm_loadl_epi64(reinterpret_cast(src)); + const __m128i lo_nibbles = _mm_and_si128(raw, _mm_set1_epi8(0x0F)); + const __m128i hi_nibbles = _mm_and_si128(_mm_srli_epi16(raw, 4), _mm_set1_epi8(0x0F)); + const __m128i interleaved = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles); + return _mm512_i32gather_ps(_mm512_cvtepu8_epi32(interleaved), table, 4); +} + +__attribute__((target("avx512f,avx512bw,avx512vl"))) +inline void fill_scale16( + float* dst, + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + for (int i = 0; i < 16; ++i) { + dst[i] = scale_value_2d(spec, row, col_base + i); + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +inline __m512 round_ps_to_bf16_ps_avx512(__m512 values) { + return _mm512_cvtpbh_ps(_mm512_cvtneps_pbh(values)); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +inline __m512 round_ps_to_fp16_ps_avx512(__m512 values) { + const __m256i fp16 = + (__m256i)_mm512_cvtps_ph(values, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + return _mm512_cvtph_ps(fp16); +} + +__attribute__((target("avx512f,avx512bw,avx512vl"))) +inline __m512 load_scale16_ps_avx512( + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + switch (spec.layout) { + case ScaleLayout2D::kNone: + return _mm512_set1_ps(1.0f); + case ScaleLayout2D::kScalar: + return _mm512_set1_ps(spec.ptr[0]); + case ScaleLayout2D::kFull: + return _mm512_loadu_ps(spec.ptr + row * spec.cols + col_base); + case ScaleLayout2D::kRowRepeat: + return _mm512_set1_ps(spec.ptr[row / spec.row_repeat]); + case ScaleLayout2D::kColRepeat: { + if (spec.col_repeat == 1) { + return _mm512_loadu_ps(spec.ptr + col_base); + } + if ((col_base / spec.col_repeat) == ((col_base + 15) / spec.col_repeat)) { + return _mm512_set1_ps(spec.ptr[col_base / spec.col_repeat]); + } + alignas(64) float scales[16]; + fill_scale16(scales, spec, row, col_base); + return _mm512_load_ps(scales); + } + case ScaleLayout2D::kBlock: { + const int64_t block_row = row / spec.row_repeat; + if (spec.col_repeat == 1) { + return _mm512_loadu_ps(spec.ptr + block_row * spec.scale_cols + col_base); + } + if ((col_base / spec.col_repeat) == ((col_base + 15) / spec.col_repeat)) { + return _mm512_set1_ps(spec.ptr[block_row * spec.scale_cols + (col_base / spec.col_repeat)]); + } + alignas(64) float scales[16]; + fill_scale16(scales, spec, row, col_base); + return _mm512_load_ps(scales); + } + } + return _mm512_set1_ps(1.0f); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +inline void store_bf16x16(c10::BFloat16* dst, __m512 values) { + const __m256bh packed = _mm512_cvtneps_pbh(values); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), (__m256i)packed); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +inline void apply_scale_and_store_bf16x16( + c10::BFloat16* dst, + __m512 values, + ScaleMode scale_mode, + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + if (scale_mode != ScaleMode::kNone) { + __m512 scales = round_ps_to_bf16_ps_avx512(load_scale16_ps_avx512(spec, row, col_base)); + values = scale_mode == ScaleMode::kMultiply ? _mm512_mul_ps(values, scales) : _mm512_div_ps(values, scales); + } + store_bf16x16(dst, values); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +inline void apply_scale_and_store_bf16x16_const( + c10::BFloat16* dst, + __m512 values, + ScaleMode scale_mode, + __m512 rounded_scale) { + if (scale_mode != ScaleMode::kNone) { + values = scale_mode == ScaleMode::kMultiply ? _mm512_mul_ps(values, rounded_scale) : _mm512_div_ps(values, rounded_scale); + } + store_bf16x16(dst, values); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +inline void store_fp16x16(c10::Half* dst, __m512 values) { + const __m256i packed = + (__m256i)_mm512_cvtps_ph(values, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), packed); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +inline void apply_scale_and_store_fp16x16( + c10::Half* dst, + __m512 values, + ScaleMode scale_mode, + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + if (scale_mode != ScaleMode::kNone) { + __m512 scales = round_ps_to_fp16_ps_avx512(load_scale16_ps_avx512(spec, row, col_base)); + values = scale_mode == ScaleMode::kMultiply ? _mm512_mul_ps(values, scales) : _mm512_div_ps(values, scales); + } + store_fp16x16(dst, values); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +inline void apply_scale_and_store_fp16x16_const( + c10::Half* dst, + __m512 values, + ScaleMode scale_mode, + __m512 rounded_scale) { + if (scale_mode != ScaleMode::kNone) { + values = scale_mode == ScaleMode::kMultiply ? _mm512_mul_ps(values, rounded_scale) : _mm512_div_ps(values, rounded_scale); + } + store_fp16x16(dst, values); +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +void dequantize_fp8_row_avx512_bf16( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + const __m512 values = load_fp8x16_to_ps_avx512(src_row + col, table.data()); + apply_scale_and_store_bf16x16(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + alignas(64) float tail[16] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = table[src_row[col + i]]; + } + apply_scale_and_store_scalar(dst_row + col, tail, tail_count, scale_mode, spec, row, col); + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +void dequantize_fp8_row_avx512_bf16_const_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + ScaleMode scale_mode, + float rounded_scale) { + // Pre-scale the tiny FP8 lookup table once per row so the hot loop only gathers and stores. + alignas(64) float scaled_table[256]; + for (int i = 0; i < 256; ++i) { + scaled_table[i] = scale_mode == ScaleMode::kMultiply ? table[i] * rounded_scale : table[i] / rounded_scale; + } + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + const __m512 values = load_fp8x16_to_ps_avx512(src_row + col, scaled_table); + store_bf16x16(dst_row + col, values); + } + if (col < cols) { + alignas(64) float tail[16] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = scaled_table[src_row[col + i]]; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +void dequantize_fp8_row_avx512_bf16_block_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::BFloat16(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m512 scale_vec = _mm512_set1_ps(rounded_scale); + + for (; col + 16 <= block_end; col += 16) { + const __m512 values = load_fp8x16_to_ps_avx512(src_row + col, table.data()); + apply_scale_and_store_bf16x16_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + alignas(64) float tail[16] = {}; + const int64_t tail_count = block_end - col; + for (int64_t i = 0; i < tail_count; ++i) { + float value = table[src_row[col + i]]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + tail[i] = value; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + col = block_end; + } + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +void dequantize_fp8_row_avx512_fp16( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + const __m512 values = load_fp8x16_to_ps_avx512(src_row + col, table.data()); + apply_scale_and_store_fp16x16(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + alignas(64) float tail[16] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = table[src_row[col + i]]; + } + apply_scale_and_store_scalar(dst_row + col, tail, tail_count, scale_mode, spec, row, col); + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +void dequantize_fp8_row_avx512_fp16_const_scale( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + ScaleMode scale_mode, + float rounded_scale) { + // Pre-scale the tiny FP8 lookup table once per row so the hot loop only gathers and stores. + alignas(64) float scaled_table[256]; + for (int i = 0; i < 256; ++i) { + scaled_table[i] = scale_mode == ScaleMode::kMultiply ? table[i] * rounded_scale : table[i] / rounded_scale; + } + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + const __m512 values = load_fp8x16_to_ps_avx512(src_row + col, scaled_table); + store_fp16x16(dst_row + col, values); + } + if (col < cols) { + alignas(64) float tail[16] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = scaled_table[src_row[col + i]]; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +void dequantize_fp8_row_avx512_fp16_block_scale( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::Half(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m512 scale_vec = _mm512_set1_ps(rounded_scale); + + for (; col + 16 <= block_end; col += 16) { + const __m512 values = load_fp8x16_to_ps_avx512(src_row + col, table.data()); + apply_scale_and_store_fp16x16_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + alignas(64) float tail[16] = {}; + const int64_t tail_count = block_end - col; + for (int64_t i = 0; i < tail_count; ++i) { + float value = table[src_row[col + i]]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + tail[i] = value; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + col = block_end; + } + } +} + +__attribute__((target("avx2"))) +inline __m256 load_fp8x8_to_ps(const uint8_t* src, const float* table) { + const __m128i raw = _mm_loadl_epi64(reinterpret_cast(src)); + const __m256i indices = _mm256_cvtepu8_epi32(raw); + return _mm256_i32gather_ps(table, indices, 4); +} + +__attribute__((target("avx2"))) +inline void load_fp8x16_to_ps( + const uint8_t* src, + const float* table, + __m256* values_lo, + __m256* values_hi) { + // Two gathers amortize loop overhead on the common 16-aligned benchmark shapes. + *values_lo = load_fp8x8_to_ps(src, table); + *values_hi = load_fp8x8_to_ps(src + 8, table); +} + +__attribute__((target("avx2"))) +inline void load_fp4x16_to_ps( + const uint8_t* src, + const float* table, + __m256* values_lo, + __m256* values_hi) { + // Decode 8 packed bytes into 16 logical FP4 values in column order. + const __m128i raw = _mm_loadl_epi64(reinterpret_cast(src)); + const __m128i lo_nibbles = _mm_and_si128(raw, _mm_set1_epi8(0x0F)); + const __m128i hi_nibbles = _mm_and_si128(_mm_srli_epi16(raw, 4), _mm_set1_epi8(0x0F)); + const __m128i interleaved = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles); + *values_lo = _mm256_i32gather_ps(table, _mm256_cvtepu8_epi32(interleaved), 4); + *values_hi = _mm256_i32gather_ps(table, _mm256_cvtepu8_epi32(_mm_srli_si128(interleaved, 8)), 4); +} + +__attribute__((target("avx2"))) +inline __m256 load_fp4x8_to_ps(const uint8_t* src, const float* table) { + const __m128i raw = _mm_cvtsi32_si128(*reinterpret_cast(src)); + const __m128i lo_nibbles = _mm_and_si128(raw, _mm_set1_epi8(0x0F)); + const __m128i hi_nibbles = _mm_and_si128(_mm_srli_epi16(raw, 4), _mm_set1_epi8(0x0F)); + const __m128i interleaved = _mm_unpacklo_epi8(lo_nibbles, hi_nibbles); + return _mm256_i32gather_ps(table, _mm256_cvtepu8_epi32(interleaved), 4); +} + +__attribute__((target("avx2"))) +inline void fill_scale8( + float* dst, + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + for (int i = 0; i < 8; ++i) { + dst[i] = scale_value_2d(spec, row, col_base + i); + } +} + +__attribute__((target("avx2"))) +inline __m256 round_ps_to_bf16_ps(__m256 values) { + // Match the Python reference by rounding operands to bf16 before scaling. + const __m256i bits = _mm256_castps_si256(values); + const __m256i lsb = _mm256_and_si256(_mm256_srli_epi32(bits, 16), _mm256_set1_epi32(1)); + const __m256i rounding = _mm256_add_epi32(_mm256_set1_epi32(0x7fff), lsb); + const __m256i rounded = _mm256_add_epi32(bits, rounding); + const __m256i bf16_bits = _mm256_slli_epi32(_mm256_srli_epi32(rounded, 16), 16); + return _mm256_castsi256_ps(bf16_bits); +} + +__attribute__((target("avx2,f16c"))) +inline __m256 round_ps_to_fp16_ps(__m256 values) { + // F16C lets us round to fp16 and back to float32 without leaving SIMD. + const __m128i fp16 = + _mm256_cvtps_ph(values, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + return _mm256_cvtph_ps(fp16); +} + +__attribute__((target("avx2"))) +inline __m256 load_scale8_ps( + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + // Most benchmarked scale layouts are constant or contiguous across eight lanes, + // so specialize those cases and only fall back to per-lane expansion when needed. + switch (spec.layout) { + case ScaleLayout2D::kNone: + return _mm256_set1_ps(1.0f); + case ScaleLayout2D::kScalar: + return _mm256_set1_ps(spec.ptr[0]); + case ScaleLayout2D::kFull: + return _mm256_loadu_ps(spec.ptr + row * spec.cols + col_base); + case ScaleLayout2D::kRowRepeat: + return _mm256_set1_ps(spec.ptr[row / spec.row_repeat]); + case ScaleLayout2D::kColRepeat: { + if (spec.col_repeat == 1) { + return _mm256_loadu_ps(spec.ptr + col_base); + } + if ((col_base / spec.col_repeat) == ((col_base + 7) / spec.col_repeat)) { + return _mm256_set1_ps(spec.ptr[col_base / spec.col_repeat]); + } + alignas(32) float scales[8]; + fill_scale8(scales, spec, row, col_base); + return _mm256_load_ps(scales); + } + case ScaleLayout2D::kBlock: { + const int64_t block_row = row / spec.row_repeat; + if (spec.col_repeat == 1) { + return _mm256_loadu_ps(spec.ptr + block_row * spec.scale_cols + col_base); + } + if ((col_base / spec.col_repeat) == ((col_base + 7) / spec.col_repeat)) { + return _mm256_set1_ps(spec.ptr[block_row * spec.scale_cols + (col_base / spec.col_repeat)]); + } + alignas(32) float scales[8]; + fill_scale8(scales, spec, row, col_base); + return _mm256_load_ps(scales); + } + } + return _mm256_set1_ps(1.0f); +} + +__attribute__((target("avx2"))) +inline void store_bf16x8(c10::BFloat16* dst, __m256 values); + +__attribute__((target("avx2,f16c"))) +inline void store_fp16x8(c10::Half* dst, __m256 values); + +__attribute__((target("avx2"))) +inline void apply_scale_and_store_bf16x8( + c10::BFloat16* dst, + __m256 values, + ScaleMode scale_mode, + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + // Values arrive pre-rounded from the lookup tables; only scale operands need rounding here. + if (scale_mode != ScaleMode::kNone) { + __m256 scales = round_ps_to_bf16_ps(load_scale8_ps(spec, row, col_base)); + values = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values, scales) : _mm256_div_ps(values, scales); + } + store_bf16x8(dst, values); +} + +__attribute__((target("avx2"))) +inline void apply_scale_and_store_bf16x8_const( + c10::BFloat16* dst, + __m256 values, + ScaleMode scale_mode, + __m256 rounded_scale) { + if (scale_mode != ScaleMode::kNone) { + values = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values, rounded_scale) : _mm256_div_ps(values, rounded_scale); + } + store_bf16x8(dst, values); +} + +__attribute__((target("avx2"))) +inline void store_bf16x8(c10::BFloat16* dst, __m256 values) { + const __m256i bits = _mm256_castps_si256(values); + const __m256i lsb = _mm256_and_si256(_mm256_srli_epi32(bits, 16), _mm256_set1_epi32(1)); + const __m256i rounding = _mm256_add_epi32(_mm256_set1_epi32(0x7fff), lsb); + const __m256i rounded = _mm256_add_epi32(bits, rounding); + const __m256i bf16 = _mm256_srli_epi32(rounded, 16); + const __m128i lo = _mm256_castsi256_si128(bf16); + const __m128i hi = _mm256_extracti128_si256(bf16, 1); + const __m128i packed = _mm_packus_epi32(lo, hi); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), packed); +} + +__attribute__((target("avx2,f16c"))) +inline void store_fp16x8(c10::Half* dst, __m256 values) { + const __m128i packed = _mm256_cvtps_ph(values, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC); + _mm_storeu_si128(reinterpret_cast<__m128i*>(dst), packed); +} + +__attribute__((target("avx2,f16c"))) +inline void apply_scale_and_store_fp16x8( + c10::Half* dst, + __m256 values, + ScaleMode scale_mode, + const ScaleSpec2D& spec, + int64_t row, + int64_t col_base) { + // Values arrive pre-rounded from the lookup tables; only scale operands need rounding here. + if (scale_mode != ScaleMode::kNone) { + __m256 scales = round_ps_to_fp16_ps(load_scale8_ps(spec, row, col_base)); + values = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values, scales) : _mm256_div_ps(values, scales); + } + store_fp16x8(dst, values); +} + +__attribute__((target("avx2,f16c"))) +inline void apply_scale_and_store_fp16x8_const( + c10::Half* dst, + __m256 values, + ScaleMode scale_mode, + __m256 rounded_scale) { + if (scale_mode != ScaleMode::kNone) { + values = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values, rounded_scale) : _mm256_div_ps(values, rounded_scale); + } + store_fp16x8(dst, values); +} + +__attribute__((target("avx2"))) +void dequantize_fp8_row_avx2_bf16( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_bf16x8(dst_row + col, values_lo, scale_mode, spec, row, col); + apply_scale_and_store_bf16x8(dst_row + col + 8, values_hi, scale_mode, spec, row, col + 8); + } + for (; col + 8 <= cols; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + apply_scale_and_store_bf16x8(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + alignas(32) float tail[8] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = table[src_row[col + i]]; + } + apply_scale_and_store_scalar(dst_row + col, tail, tail_count, scale_mode, spec, row, col); + } +} + +__attribute__((target("avx2"))) +void dequantize_fp8_row_avx2_bf16_full_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + const float* scale_row, + int64_t cols, + const std::array& table, + ScaleMode scale_mode) { + // Full-scale rows are contiguous in memory. Load scales directly from the row + // slice so the hot loop avoids the generic layout switch and per-vector helper. + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + __m256 scales_lo = round_ps_to_bf16_ps(_mm256_loadu_ps(scale_row + col)); + __m256 scales_hi = round_ps_to_bf16_ps(_mm256_loadu_ps(scale_row + col + 8)); + values_lo = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values_lo, scales_lo) : _mm256_div_ps(values_lo, scales_lo); + values_hi = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values_hi, scales_hi) : _mm256_div_ps(values_hi, scales_hi); + store_bf16x8(dst_row + col, values_lo); + store_bf16x8(dst_row + col + 8, values_hi); + } + for (; col + 8 <= cols; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + __m256 scales = round_ps_to_bf16_ps(_mm256_loadu_ps(scale_row + col)); + values = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values, scales) : _mm256_div_ps(values, scales); + store_bf16x8(dst_row + col, values); + } + if (col < cols) { + for (int64_t i = col; i < cols; ++i) { + float value = table[src_row[i]]; + const float rounded_scale = static_cast(c10::BFloat16(scale_row[i])); + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + store_scalar(dst_row + i, value); + } + } +} + +__attribute__((target("avx2"))) +void dequantize_fp8_row_avx2_bf16_col_repeat_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const float* col_scales, + int64_t col_repeat, + const std::array& table, + ScaleMode scale_mode) { + // Column-repeat scales are shared across every row, so walk the row in scale + // spans and keep one rounded scale hot for the whole span. + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * col_repeat); + const __m256 scale_vec = _mm256_set1_ps(static_cast(c10::BFloat16(col_scales[block_col]))); + for (; col + 16 <= block_end; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_bf16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_bf16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + for (; col + 8 <= block_end; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + apply_scale_and_store_bf16x8_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + const float rounded_scale = static_cast(c10::BFloat16(col_scales[block_col])); + for (int64_t i = col; i < block_end; ++i) { + float value = table[src_row[i]]; + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + store_scalar(dst_row + i, value); + } + col = block_end; + } + } +} + +__attribute__((target("avx2"))) +void dequantize_fp8_row_avx2_bf16_block16_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + const float* block_scales, + int64_t cols, + const std::array& table, + ScaleMode scale_mode) { + // Block-16 scaling is common in real checkpoints. Walk blocks linearly so the + // hot loop avoids the generic block bookkeeping and repeated divide/index work. + int64_t col = 0; + int64_t block_idx = 0; + for (; col + 16 <= cols; col += 16, ++block_idx) { + const __m256 scale_vec = _mm256_set1_ps(static_cast(c10::BFloat16(block_scales[block_idx]))); + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_bf16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_bf16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + if (col < cols) { + const float rounded_scale = static_cast(c10::BFloat16(block_scales[block_idx])); + for (int64_t i = col; i < cols; ++i) { + float value = table[src_row[i]]; + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + store_scalar(dst_row + i, value); + } + } +} + +__attribute__((target("avx2"))) +void dequantize_fp8_row_avx2_bf16_const_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + ScaleMode scale_mode, + float rounded_scale) { + // Pre-scale the tiny FP8 lookup table once per row so the hot loop only gathers and stores. + alignas(64) float scaled_table[256]; + for (int i = 0; i < 256; ++i) { + scaled_table[i] = scale_mode == ScaleMode::kMultiply ? table[i] * rounded_scale : table[i] / rounded_scale; + } + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, scaled_table, &values_lo, &values_hi); + store_bf16x8(dst_row + col, values_lo); + store_bf16x8(dst_row + col + 8, values_hi); + } + for (; col + 8 <= cols; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, scaled_table); + store_bf16x8(dst_row + col, values); + } + if (col < cols) { + alignas(32) float tail[8] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = scaled_table[src_row[col + i]]; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + } +} + +__attribute__((target("avx2"))) +void dequantize_fp8_row_avx2_bf16_block_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::BFloat16(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m256 scale_vec = _mm256_set1_ps(rounded_scale); + + // Real FP8 checkpoints reuse one scale across an entire block, so keep it hot for the whole span. + for (; col + 16 <= block_end; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_bf16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_bf16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + for (; col + 8 <= block_end; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + apply_scale_and_store_bf16x8_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + alignas(32) float tail[8] = {}; + const int64_t tail_count = block_end - col; + for (int64_t i = 0; i < tail_count; ++i) { + float value = table[src_row[col + i]]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + tail[i] = value; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + col = block_end; + } + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp8_row_avx2_fp16( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_fp16x8(dst_row + col, values_lo, scale_mode, spec, row, col); + apply_scale_and_store_fp16x8(dst_row + col + 8, values_hi, scale_mode, spec, row, col + 8); + } + for (; col + 8 <= cols; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + apply_scale_and_store_fp16x8(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + alignas(32) float tail[8] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = table[src_row[col + i]]; + } + apply_scale_and_store_scalar(dst_row + col, tail, tail_count, scale_mode, spec, row, col); + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp8_row_avx2_fp16_full_scale( + const uint8_t* src_row, + c10::Half* dst_row, + const float* scale_row, + int64_t cols, + const std::array& table, + ScaleMode scale_mode) { + // Full-scale rows are contiguous in memory. Load scales directly from the row + // slice so the hot loop avoids the generic layout switch and per-vector helper. + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + __m256 scales_lo = round_ps_to_fp16_ps(_mm256_loadu_ps(scale_row + col)); + __m256 scales_hi = round_ps_to_fp16_ps(_mm256_loadu_ps(scale_row + col + 8)); + values_lo = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values_lo, scales_lo) : _mm256_div_ps(values_lo, scales_lo); + values_hi = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values_hi, scales_hi) : _mm256_div_ps(values_hi, scales_hi); + store_fp16x8(dst_row + col, values_lo); + store_fp16x8(dst_row + col + 8, values_hi); + } + for (; col + 8 <= cols; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + __m256 scales = round_ps_to_fp16_ps(_mm256_loadu_ps(scale_row + col)); + values = scale_mode == ScaleMode::kMultiply ? _mm256_mul_ps(values, scales) : _mm256_div_ps(values, scales); + store_fp16x8(dst_row + col, values); + } + if (col < cols) { + for (int64_t i = col; i < cols; ++i) { + float value = table[src_row[i]]; + const float rounded_scale = static_cast(c10::Half(scale_row[i])); + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + store_scalar(dst_row + i, value); + } + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp8_row_avx2_fp16_col_repeat_scale( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const float* col_scales, + int64_t col_repeat, + const std::array& table, + ScaleMode scale_mode) { + // Column-repeat scales are shared across every row, so walk the row in scale + // spans and keep one rounded scale hot for the whole span. + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * col_repeat); + const __m256 scale_vec = _mm256_set1_ps(static_cast(c10::Half(col_scales[block_col]))); + for (; col + 16 <= block_end; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_fp16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_fp16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + for (; col + 8 <= block_end; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + apply_scale_and_store_fp16x8_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + const float rounded_scale = static_cast(c10::Half(col_scales[block_col])); + for (int64_t i = col; i < block_end; ++i) { + float value = table[src_row[i]]; + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + store_scalar(dst_row + i, value); + } + col = block_end; + } + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp8_row_avx2_fp16_block16_scale( + const uint8_t* src_row, + c10::Half* dst_row, + const float* block_scales, + int64_t cols, + const std::array& table, + ScaleMode scale_mode) { + // Block-16 scaling is common in real checkpoints. Walk blocks linearly so the + // hot loop avoids the generic block bookkeeping and repeated divide/index work. + int64_t col = 0; + int64_t block_idx = 0; + for (; col + 16 <= cols; col += 16, ++block_idx) { + const __m256 scale_vec = _mm256_set1_ps(static_cast(c10::Half(block_scales[block_idx]))); + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_fp16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_fp16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + if (col < cols) { + const float rounded_scale = static_cast(c10::Half(block_scales[block_idx])); + for (int64_t i = col; i < cols; ++i) { + float value = table[src_row[i]]; + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + store_scalar(dst_row + i, value); + } + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp8_row_avx2_fp16_const_scale( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + ScaleMode scale_mode, + float rounded_scale) { + // Pre-scale the tiny FP8 lookup table once per row so the hot loop only gathers and stores. + alignas(64) float scaled_table[256]; + for (int i = 0; i < 256; ++i) { + scaled_table[i] = scale_mode == ScaleMode::kMultiply ? table[i] * rounded_scale : table[i] / rounded_scale; + } + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, scaled_table, &values_lo, &values_hi); + store_fp16x8(dst_row + col, values_lo); + store_fp16x8(dst_row + col + 8, values_hi); + } + for (; col + 8 <= cols; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, scaled_table); + store_fp16x8(dst_row + col, values); + } + if (col < cols) { + alignas(32) float tail[8] = {}; + const int64_t tail_count = cols - col; + for (int64_t i = 0; i < tail_count; ++i) { + tail[i] = scaled_table[src_row[col + i]]; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp8_row_avx2_fp16_block_scale( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::Half(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m256 scale_vec = _mm256_set1_ps(rounded_scale); + + // Real FP8 checkpoints reuse one scale across an entire block, so keep it hot for the whole span. + for (; col + 16 <= block_end; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp8x16_to_ps(src_row + col, table.data(), &values_lo, &values_hi); + apply_scale_and_store_fp16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_fp16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + for (; col + 8 <= block_end; col += 8) { + __m256 values = load_fp8x8_to_ps(src_row + col, table.data()); + apply_scale_and_store_fp16x8_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + alignas(32) float tail[8] = {}; + const int64_t tail_count = block_end - col; + for (int64_t i = 0; i < tail_count; ++i) { + float value = table[src_row[col + i]]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + tail[i] = value; + } + for (int64_t i = 0; i < tail_count; ++i) { + store_scalar(dst_row + col + i, tail[i]); + } + col = block_end; + } + } +} + +__attribute__((target("avx2"))) +void dequantize_fp4_row_avx2_bf16( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp4x16_to_ps(src_row + (col / 2), table.data(), &values_lo, &values_hi); + apply_scale_and_store_bf16x8(dst_row + col, values_lo, scale_mode, spec, row, col); + apply_scale_and_store_bf16x8(dst_row + col + 8, values_hi, scale_mode, spec, row, col + 8); + } + for (; col + 8 <= cols; col += 8) { + const __m256 values = load_fp4x8_to_ps(src_row + (col / 2), table.data()); + apply_scale_and_store_bf16x8(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + for (int64_t logical_col = col; logical_col < cols; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + const float scale = scale_value_2d(spec, row, logical_col); + value = scale_mode == ScaleMode::kMultiply ? value * scale : value / scale; + } + store_scalar(dst_row + logical_col, value); + } + } +} + +__attribute__((target("avx2"))) +void dequantize_fp4_row_avx2_bf16_block_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::BFloat16(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m256 scale_vec = _mm256_set1_ps(rounded_scale); + + for (; col + 16 <= block_end; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp4x16_to_ps(src_row + (col / 2), table.data(), &values_lo, &values_hi); + apply_scale_and_store_bf16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_bf16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + for (; col + 8 <= block_end; col += 8) { + const __m256 values = load_fp4x8_to_ps(src_row + (col / 2), table.data()); + apply_scale_and_store_bf16x8_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + for (int64_t logical_col = col; logical_col < block_end; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + store_scalar(dst_row + logical_col, value); + } + col = block_end; + } + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp4_row_avx2_fp16( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp4x16_to_ps(src_row + (col / 2), table.data(), &values_lo, &values_hi); + apply_scale_and_store_fp16x8(dst_row + col, values_lo, scale_mode, spec, row, col); + apply_scale_and_store_fp16x8(dst_row + col + 8, values_hi, scale_mode, spec, row, col + 8); + } + for (; col + 8 <= cols; col += 8) { + const __m256 values = load_fp4x8_to_ps(src_row + (col / 2), table.data()); + apply_scale_and_store_fp16x8(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + for (int64_t logical_col = col; logical_col < cols; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + const float scale = scale_value_2d(spec, row, logical_col); + value = scale_mode == ScaleMode::kMultiply ? value * scale : value / scale; + } + store_scalar(dst_row + logical_col, value); + } + } +} + +__attribute__((target("avx2,f16c"))) +void dequantize_fp4_row_avx2_fp16_block_scale( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::Half(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m256 scale_vec = _mm256_set1_ps(rounded_scale); + + for (; col + 16 <= block_end; col += 16) { + __m256 values_lo; + __m256 values_hi; + load_fp4x16_to_ps(src_row + (col / 2), table.data(), &values_lo, &values_hi); + apply_scale_and_store_fp16x8_const(dst_row + col, values_lo, scale_mode, scale_vec); + apply_scale_and_store_fp16x8_const(dst_row + col + 8, values_hi, scale_mode, scale_vec); + } + for (; col + 8 <= block_end; col += 8) { + const __m256 values = load_fp4x8_to_ps(src_row + (col / 2), table.data()); + apply_scale_and_store_fp16x8_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + for (int64_t logical_col = col; logical_col < block_end; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + store_scalar(dst_row + logical_col, value); + } + col = block_end; + } + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +void dequantize_fp4_row_avx512_bf16( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + const __m512 values = load_fp4x16_to_ps_avx512(src_row + (col / 2), table.data()); + apply_scale_and_store_bf16x16(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + for (int64_t logical_col = col; logical_col < cols; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + const float scale = scale_value_2d(spec, row, logical_col); + value = scale_mode == ScaleMode::kMultiply ? value * scale : value / scale; + } + store_scalar(dst_row + logical_col, value); + } + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512bf16"))) +void dequantize_fp4_row_avx512_bf16_block_scale( + const uint8_t* src_row, + c10::BFloat16* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::BFloat16(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m512 scale_vec = _mm512_set1_ps(rounded_scale); + + for (; col + 16 <= block_end; col += 16) { + const __m512 values = load_fp4x16_to_ps_avx512(src_row + (col / 2), table.data()); + apply_scale_and_store_bf16x16_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + for (int64_t logical_col = col; logical_col < block_end; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + store_scalar(dst_row + logical_col, value); + } + col = block_end; + } + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +void dequantize_fp4_row_avx512_fp16( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + int64_t col = 0; + for (; col + 16 <= cols; col += 16) { + const __m512 values = load_fp4x16_to_ps_avx512(src_row + (col / 2), table.data()); + apply_scale_and_store_fp16x16(dst_row + col, values, scale_mode, spec, row, col); + } + if (col < cols) { + for (int64_t logical_col = col; logical_col < cols; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + const float scale = scale_value_2d(spec, row, logical_col); + value = scale_mode == ScaleMode::kMultiply ? value * scale : value / scale; + } + store_scalar(dst_row + logical_col, value); + } + } +} + +__attribute__((target("avx512f,avx512bw,avx512vl,avx512fp16"))) +void dequantize_fp4_row_avx512_fp16_block_scale( + const uint8_t* src_row, + c10::Half* dst_row, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t row) { + const int64_t block_row = row / spec.row_repeat; + int64_t col = 0; + while (col < cols) { + const int64_t block_col = col / spec.col_repeat; + const int64_t block_end = std::min(cols, (block_col + 1) * spec.col_repeat); + const float rounded_scale = static_cast(c10::Half(spec.ptr[block_row * spec.scale_cols + block_col])); + const __m512 scale_vec = _mm512_set1_ps(rounded_scale); + + for (; col + 16 <= block_end; col += 16) { + const __m512 values = load_fp4x16_to_ps_avx512(src_row + (col / 2), table.data()); + apply_scale_and_store_fp16x16_const(dst_row + col, values, scale_mode, scale_vec); + } + if (col < block_end) { + for (int64_t logical_col = col; logical_col < block_end; ++logical_col) { + const uint8_t byte = src_row[logical_col / 2]; + const uint8_t nibble = (logical_col & 1) + ? static_cast((byte >> 4) & 0x0F) + : static_cast(byte & 0x0F); + float value = table[nibble]; + if (scale_mode != ScaleMode::kNone) { + value = scale_mode == ScaleMode::kMultiply ? value * rounded_scale : value / rounded_scale; + } + store_scalar(dst_row + logical_col, value); + } + col = block_end; + } + } +} +#else +inline bool cpu_supports_avx2() { + return false; +} + +inline bool cpu_supports_f16c() { + return false; +} + +inline bool cpu_supports_avx512bf16() { + return false; +} + +inline bool cpu_supports_avx512fp16() { + return false; +} +#endif + +template +void dequantize_fp8_scalar( + const uint8_t* src, + T* dst, + int64_t rows, + int64_t cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t threads) { + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + const uint8_t* src_row = src + row * cols; + T* dst_row = dst + row * cols; + for (int64_t col = 0; col < cols; ++col) { + const float value[1] = {table[src_row[col]]}; + apply_scale_and_store_scalar(dst_row + col, value, 1, scale_mode, spec, row, col); + } + } + }); +} + +void dequantize_fp8_2d( + const at::Tensor& source, + at::Tensor& output, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t threads) { + const int64_t rows = source.size(0); + const int64_t cols = source.size(1); + const uint8_t* src = reinterpret_cast(source.const_data_ptr()); + + if (output.scalar_type() == at::kBFloat16) { + c10::BFloat16* dst = output.data_ptr(); +#if GPTQMODEL_FLOATX_X86 && (defined(__GNUC__) || defined(__clang__)) + const bool prefer_avx2_bf16_fp8 = cpu_prefers_avx2_fp8_bf16(); + if (cpu_supports_avx512bf16() && !prefer_avx2_bf16_fp8) { + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 16) { + dequantize_fp8_row_avx512_bf16_block_scale( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else if (scale_mode != ScaleMode::kNone && + (spec.layout == ScaleLayout2D::kScalar || spec.layout == ScaleLayout2D::kRowRepeat)) { + // Hoist constant-per-row scale rounding out of the inner SIMD loop. + const float rounded_scale = static_cast( + c10::BFloat16(spec.layout == ScaleLayout2D::kScalar ? spec.ptr[0] : spec.ptr[row / spec.row_repeat])); + dequantize_fp8_row_avx512_bf16_const_scale( + src + row * cols, + dst + row * cols, + cols, + table, + scale_mode, + rounded_scale); + } else { + dequantize_fp8_row_avx512_bf16( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } + if (cpu_supports_avx2()) { + // Some AMD AVX-512 parts still retire the FP8 BF16 gather-heavy loop faster + // with AVX2, so let the runtime quirk above steer this path on those hosts. + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kColRepeat && spec.col_repeat >= 16) { + dequantize_fp8_row_avx2_bf16_col_repeat_scale( + src + row * cols, + dst + row * cols, + cols, + spec.ptr, + spec.col_repeat, + table, + scale_mode); + } else if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kFull) { + dequantize_fp8_row_avx2_bf16_full_scale( + src + row * cols, + dst + row * cols, + spec.ptr + row * spec.cols, + cols, + table, + scale_mode); + } else if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && + spec.row_repeat == 1 && spec.col_repeat == 16) { + dequantize_fp8_row_avx2_bf16_block16_scale( + src + row * cols, + dst + row * cols, + spec.ptr + row * spec.scale_cols, + cols, + table, + scale_mode); + } else if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 16) { + dequantize_fp8_row_avx2_bf16_block_scale( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else if (scale_mode != ScaleMode::kNone && + (spec.layout == ScaleLayout2D::kScalar || spec.layout == ScaleLayout2D::kRowRepeat)) { + // Hoist constant-per-row scale rounding out of the inner SIMD loop. + const float rounded_scale = static_cast( + c10::BFloat16(spec.layout == ScaleLayout2D::kScalar ? spec.ptr[0] : spec.ptr[row / spec.row_repeat])); + dequantize_fp8_row_avx2_bf16_const_scale( + src + row * cols, + dst + row * cols, + cols, + table, + scale_mode, + rounded_scale); + } else { + dequantize_fp8_row_avx2_bf16( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } +#endif + dequantize_fp8_scalar(src, dst, rows, cols, table, spec, scale_mode, threads); + return; + } + + c10::Half* dst = output.data_ptr(); +#if GPTQMODEL_FLOATX_X86 && (defined(__GNUC__) || defined(__clang__)) + if (cpu_supports_avx512fp16()) { + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 16) { + dequantize_fp8_row_avx512_fp16_block_scale( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else if (scale_mode != ScaleMode::kNone && + (spec.layout == ScaleLayout2D::kScalar || spec.layout == ScaleLayout2D::kRowRepeat)) { + // Hoist constant-per-row scale rounding out of the inner SIMD loop. + const float rounded_scale = static_cast( + c10::Half(spec.layout == ScaleLayout2D::kScalar ? spec.ptr[0] : spec.ptr[row / spec.row_repeat])); + dequantize_fp8_row_avx512_fp16_const_scale( + src + row * cols, + dst + row * cols, + cols, + table, + scale_mode, + rounded_scale); + } else { + dequantize_fp8_row_avx512_fp16( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } + if (cpu_supports_avx2() && cpu_supports_f16c()) { + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kColRepeat && spec.col_repeat >= 16) { + dequantize_fp8_row_avx2_fp16_col_repeat_scale( + src + row * cols, + dst + row * cols, + cols, + spec.ptr, + spec.col_repeat, + table, + scale_mode); + } else if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kFull) { + dequantize_fp8_row_avx2_fp16_full_scale( + src + row * cols, + dst + row * cols, + spec.ptr + row * spec.cols, + cols, + table, + scale_mode); + } else if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && + spec.row_repeat == 1 && spec.col_repeat == 16) { + dequantize_fp8_row_avx2_fp16_block16_scale( + src + row * cols, + dst + row * cols, + spec.ptr + row * spec.scale_cols, + cols, + table, + scale_mode); + } else if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 16) { + dequantize_fp8_row_avx2_fp16_block_scale( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else if (scale_mode != ScaleMode::kNone && + (spec.layout == ScaleLayout2D::kScalar || spec.layout == ScaleLayout2D::kRowRepeat)) { + // Hoist constant-per-row scale rounding out of the inner SIMD loop. + const float rounded_scale = static_cast( + c10::Half(spec.layout == ScaleLayout2D::kScalar ? spec.ptr[0] : spec.ptr[row / spec.row_repeat])); + dequantize_fp8_row_avx2_fp16_const_scale( + src + row * cols, + dst + row * cols, + cols, + table, + scale_mode, + rounded_scale); + } else { + dequantize_fp8_row_avx2_fp16( + src + row * cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } +#endif + dequantize_fp8_scalar(src, dst, rows, cols, table, spec, scale_mode, threads); +} + +template +void dequantize_fp4_scalar( + const uint8_t* src, + T* dst, + int64_t rows, + int64_t packed_cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t threads) { + const int64_t cols = packed_cols * 2; + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + const uint8_t* src_row = src + row * packed_cols; + T* dst_row = dst + row * cols; + for (int64_t packed_col = 0; packed_col < packed_cols; ++packed_col) { + const uint8_t byte = src_row[packed_col]; + const int64_t col = packed_col * 2; + const float values[2] = { + table[byte & 0x0F], + table[(byte >> 4) & 0x0F], + }; + apply_scale_and_store_scalar(dst_row + col, values, 2, scale_mode, spec, row, col); + } + } + }); +} + +template +void dequantize_fp4_vectorized( + const uint8_t* src, + T* dst, + int64_t rows, + int64_t packed_cols, + const std::array& table, + const ScaleSpec2D& spec, + ScaleMode scale_mode, + int64_t threads) { +#if GPTQMODEL_FLOATX_X86 && (defined(__GNUC__) || defined(__clang__)) + if (cpu_supports_avx512bf16()) { + const int64_t cols = packed_cols * 2; + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + if constexpr (std::is_same_v) { + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 16) { + dequantize_fp4_row_avx512_bf16_block_scale( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else { + dequantize_fp4_row_avx512_bf16( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } + } + if (cpu_supports_avx512fp16()) { + const int64_t cols = packed_cols * 2; + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + if constexpr (std::is_same_v) { + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 16) { + dequantize_fp4_row_avx512_fp16_block_scale( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else { + dequantize_fp4_row_avx512_fp16( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } + } + if (cpu_supports_avx2()) { + const int64_t cols = packed_cols * 2; + const int64_t grain = std::max(1, rows / clamped_threads(threads)); + if constexpr (std::is_same_v) { + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 8) { + dequantize_fp4_row_avx2_bf16_block_scale( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else { + dequantize_fp4_row_avx2_bf16( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } + if constexpr (std::is_same_v) { + if (cpu_supports_f16c()) { + at::parallel_for(0, rows, grain, [&](int64_t begin, int64_t end) { + for (int64_t row = begin; row < end; ++row) { + if (scale_mode != ScaleMode::kNone && spec.layout == ScaleLayout2D::kBlock && spec.col_repeat >= 8) { + dequantize_fp4_row_avx2_fp16_block_scale( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } else { + dequantize_fp4_row_avx2_fp16( + src + row * packed_cols, + dst + row * cols, + cols, + table, + spec, + scale_mode, + row); + } + } + }); + return; + } + } + } +#endif + dequantize_fp4_scalar(src, dst, rows, packed_cols, table, spec, scale_mode, threads); +} + +at::Tensor empty_target_like(const at::Tensor& source, TargetKind target_kind) { + auto options = at::TensorOptions().device(at::kCPU); + return at::empty_like( + source, + options.dtype(target_kind == TargetKind::kBFloat16 ? at::kBFloat16 : at::kHalf)); +} + +at::Tensor empty_target_for_fp4(const at::Tensor& source, TargetKind target_kind) { + auto sizes = source.sizes().vec(); + TORCH_CHECK(!sizes.empty(), "FP4 source tensor must have at least one dimension"); + sizes.back() *= 2; + return at::empty( + sizes, + at::TensorOptions().device(at::kCPU).dtype( + target_kind == TargetKind::kBFloat16 ? at::kBFloat16 : at::kHalf)); +} + +at::Tensor dequantize_fp8_cpu( + const at::Tensor& source, + const c10::optional& scale, + int64_t scale_mode_value, + int64_t axis, + bool axis_is_none, + int64_t target_dtype_value, + int64_t format_value, + int64_t threads) { + TORCH_CHECK(source.device().is_cpu(), "FP8 source tensor must reside on CPU"); + TORCH_CHECK(source.ndimension() == 1 || source.ndimension() == 2, "FP8 fast path only supports 1D or 2D tensors"); + TORCH_CHECK(source.element_size() == 1, "FP8 fast path expects one byte per source element"); + + const ScaleMode scale_mode = static_cast(scale_mode_value); + const TargetKind target_kind = static_cast(target_dtype_value); + const Fp8Format format = static_cast(format_value); + + at::Tensor src = source.contiguous(); + at::Tensor output = empty_target_like(src, target_kind); + const auto& table = fp8_table(format, target_kind); + + if (src.ndimension() == 1) { + const int64_t length = src.size(0); + const ScaleSpec1D spec = make_scale_spec_1d(scale, length, axis, axis_is_none); + const uint8_t* src_ptr = reinterpret_cast(src.const_data_ptr()); + if (output.scalar_type() == at::kBFloat16) { + c10::BFloat16* dst = output.data_ptr(); + const int64_t grain = std::max(1, length / clamped_threads(threads)); + at::parallel_for(0, length, grain, [&](int64_t begin, int64_t end) { + for (int64_t idx = begin; idx < end; ++idx) { + apply_scale_and_store_scalar_1d( + dst, + table[src_ptr[idx]], + scale_mode, + spec, + idx); + } + }); + } else { + c10::Half* dst = output.data_ptr(); + const int64_t grain = std::max(1, length / clamped_threads(threads)); + at::parallel_for(0, length, grain, [&](int64_t begin, int64_t end) { + for (int64_t idx = begin; idx < end; ++idx) { + apply_scale_and_store_scalar_1d( + dst, + table[src_ptr[idx]], + scale_mode, + spec, + idx); + } + }); + } + return output; + } + + const ScaleSpec2D spec = make_scale_spec_2d(scale, src.size(0), src.size(1), axis, axis_is_none); + dequantize_fp8_2d(src, output, table, spec, scale_mode, threads); + return output; +} + +at::Tensor dequantize_fp4_cpu( + const at::Tensor& source, + const c10::optional& scale, + int64_t scale_mode_value, + int64_t axis, + bool axis_is_none, + int64_t target_dtype_value, + int64_t threads) { + TORCH_CHECK(source.device().is_cpu(), "FP4 source tensor must reside on CPU"); + TORCH_CHECK(source.ndimension() == 1 || source.ndimension() == 2, "FP4 fast path only supports 1D or 2D tensors"); + TORCH_CHECK(source.element_size() == 1, "FP4 fast path expects packed one-byte storage"); + + const ScaleMode scale_mode = static_cast(scale_mode_value); + const TargetKind target_kind = static_cast(target_dtype_value); + + at::Tensor src = source.contiguous(); + at::Tensor output = empty_target_for_fp4(src, target_kind); + const auto& table = fp4_table(target_kind); + + if (src.ndimension() == 1) { + const int64_t packed = src.size(0); + const int64_t length = packed * 2; + const ScaleSpec1D spec = make_scale_spec_1d(scale, length, axis, axis_is_none); + const uint8_t* src_ptr = reinterpret_cast(src.const_data_ptr()); + if (output.scalar_type() == at::kBFloat16) { + c10::BFloat16* dst = output.data_ptr(); + const int64_t grain = std::max(1, length / clamped_threads(threads)); + at::parallel_for(0, length, grain, [&](int64_t begin, int64_t end) { + for (int64_t idx = begin; idx < end; ++idx) { + const uint8_t byte = src_ptr[idx / 2]; + const uint8_t nibble = (idx & 1) ? static_cast((byte >> 4) & 0x0F) : static_cast(byte & 0x0F); + apply_scale_and_store_scalar_1d(dst, table[nibble], scale_mode, spec, idx); + } + }); + } else { + c10::Half* dst = output.data_ptr(); + const int64_t grain = std::max(1, length / clamped_threads(threads)); + at::parallel_for(0, length, grain, [&](int64_t begin, int64_t end) { + for (int64_t idx = begin; idx < end; ++idx) { + const uint8_t byte = src_ptr[idx / 2]; + const uint8_t nibble = (idx & 1) ? static_cast((byte >> 4) & 0x0F) : static_cast(byte & 0x0F); + apply_scale_and_store_scalar_1d(dst, table[nibble], scale_mode, spec, idx); + } + }); + } + return output; + } + + const int64_t rows = src.size(0); + const int64_t packed_cols = src.size(1); + const ScaleSpec2D spec = make_scale_spec_2d(scale, rows, packed_cols * 2, axis, axis_is_none); + const uint8_t* src_ptr = reinterpret_cast(src.const_data_ptr()); + if (output.scalar_type() == at::kBFloat16) { + c10::BFloat16* dst = output.data_ptr(); + dequantize_fp4_vectorized(src_ptr, dst, rows, packed_cols, table, spec, scale_mode, threads); + } else { + c10::Half* dst = output.data_ptr(); + dequantize_fp4_vectorized(src_ptr, dst, rows, packed_cols, table, spec, scale_mode, threads); + } + return output; +} + +} // namespace gptqmodel_floatx + +TORCH_LIBRARY(gptqmodel_floatx, m) { + m.def( + "dequantize_fp8_cpu(Tensor src, Tensor? scale, int scale_mode, int axis, bool axis_is_none, int target_dtype, int format_code, int threads) -> Tensor"); + m.def( + "dequantize_fp4_cpu(Tensor src, Tensor? scale, int scale_mode, int axis, bool axis_is_none, int target_dtype, int threads) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_floatx, CPU, m) { + m.impl("dequantize_fp8_cpu", TORCH_FN(gptqmodel_floatx::dequantize_fp8_cpu)); + m.impl("dequantize_fp4_cpu", TORCH_FN(gptqmodel_floatx::dequantize_fp4_cpu)); +} diff --git a/gptqmodel_ext/gptq_pro/gptq_pro_kernel.cu b/gptqmodel_ext/gptq_pro/gptq_pro_kernel.cu new file mode 100644 index 000000000..103242e8f --- /dev/null +++ b/gptqmodel_ext/gptq_pro/gptq_pro_kernel.cu @@ -0,0 +1,224 @@ +/* + * Standalone gptq_pro INT4 dequantized GEMM kernel + * + * This file implements the current end-to-end functional scaffold: + * - one warp per CTA + * - symmetric INT4 weights with implicit zero-point 8 + * - explicit shared-memory staging for A / scales / B fragments + * - Tensor Core math via mma.sync with FP32 accumulation + * + * It is intentionally smaller and more conservative than the future Marlin-like + * multi-warp cp.async/ldmatrix pipeline discussed in Project.md/progress.md. + * The validator now covers both fragment helpers and the full kernel result. + */ + +#include "gptq_pro_kernel.cuh" + +struct __align__(16) GptqProSmem { + half A[GPTQ_PRO_M_PER_WARP * GPTQ_PRO_K_PER_WARP]; + half S[GPTQ_PRO_N_PER_WARP]; + uint32_t Bfrag[GPTQ_PRO_BFRAG_WORDS_PER_BUF]; +}; + +__device__ __forceinline__ +half zero_half() { + return __float2half(0.0f); +} + +__device__ __forceinline__ +uint8_t load_b_pair_byte(const uint8_t* __restrict__ B_packed, + int K, int N, + int k_even, int n_col) { + if (n_col >= N || k_even >= K) { + return 0x88u; + } + + const int packed_row = k_even >> 1; + uint8_t byte = B_packed[packed_row * N + n_col]; + if (k_even + 1 >= K) { + byte = static_cast((byte & 0x0Fu) | 0x80u); + } + return byte; +} + +__device__ __forceinline__ +uint16_t pack_lane_bfrag(const uint8_t* __restrict__ B_packed, + int N, int K, + int k_base, int n_base, + int j, int lane) { + const int group_id = lane >> 2; + const int tid4 = lane & 3; + const int n_col = n_base + j * 8 + group_id; + if (n_col >= N) { + return 0x8888u; + } + + const int k01 = k_base + 2 * tid4; + const int k89 = k01 + 8; + const uint8_t byte01 = load_b_pair_byte(B_packed, K, N, k01, n_col); + const uint8_t byte89 = load_b_pair_byte(B_packed, K, N, k89, n_col); + return static_cast(byte01) + | (static_cast(byte89) << 8); +} + +__device__ __forceinline__ +void stage_a_tile(GptqProSmem* __restrict__ smem, + const half* __restrict__ A, + int M, int K, + int m_base, int k_base) { + for (int idx = threadIdx.x; idx < GPTQ_PRO_M_PER_WARP * GPTQ_PRO_K_PER_WARP; idx += blockDim.x) { + const int row = idx / GPTQ_PRO_K_PER_WARP; + const int col = idx % GPTQ_PRO_K_PER_WARP; + const int global_m = m_base + row; + const int global_k = k_base + col; + smem->A[idx] = (global_m < M && global_k < K) + ? A[global_m * K + global_k] + : zero_half(); + } +} + +__device__ __forceinline__ +void stage_scale_row(GptqProSmem* __restrict__ smem, + const half* __restrict__ S, + int N, + int k_base, + int group_size, + int n_base) { + const int group_idx = k_base / group_size; + for (int idx = threadIdx.x; idx < GPTQ_PRO_N_PER_WARP; idx += blockDim.x) { + const int global_n = n_base + idx; + smem->S[idx] = (global_n < N) + ? S[group_idx * N + global_n] + : zero_half(); + } +} + +__device__ __forceinline__ +void stage_bfrag_tiles(GptqProSmem* __restrict__ smem, + const uint8_t* __restrict__ B_packed, + int N, int K, + int k_base, int n_base) { + for (int idx = threadIdx.x; idx < GPTQ_PRO_BFRAG_WORDS_PER_BUF; idx += blockDim.x) { + const int j = idx / GPTQ_PRO_BFRAG_WORDS_PER_TILE; + const int lane_pair = idx % GPTQ_PRO_BFRAG_WORDS_PER_TILE; + const int even_lane = lane_pair * 2; + const uint16_t even_p16 = pack_lane_bfrag(B_packed, N, K, k_base, n_base, j, even_lane); + const uint16_t odd_p16 = pack_lane_bfrag(B_packed, N, K, k_base, n_base, j, even_lane + 1); + smem->Bfrag[idx] = static_cast(even_p16) + | (static_cast(odd_p16) << 16); + } +} + +__device__ __forceinline__ +void do_mma_inner_loop(const GptqProSmem* __restrict__ smem, + float RC[GPTQ_PRO_J_TILES][4]) { + const int lane = threadIdx.x & (GPTQ_PRO_WARP_SIZE - 1); + const int group_id = lane >> 2; + const half zero_point = __float2half(8.0f); + + uint32_t RA[4]; + load_a_fragment_rowmajor(smem->A, lane, RA); + + #pragma unroll + for (int j = 0; j < GPTQ_PRO_J_TILES; ++j) { + const half scale = smem->S[j * 8 + group_id]; + const uint16_t packed_16 = fetch_bfrag_packed16(smem->Bfrag, 0, 0, j, lane); + + uint32_t RB[2]; + decode_bfrag_to_rb(packed_16, scale, zero_point, RB); + mma_f32_m16n8k16(RA, RB, RC[j]); + } +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +__global__ void gptq_pro_gemm_kernel( + const half* __restrict__ A, + const uint8_t* __restrict__ B_packed, + const half* __restrict__ S, + half* __restrict__ C, + int M, int N, int K, int group_size) +{ + extern __shared__ uint8_t raw_smem[]; + GptqProSmem* smem = reinterpret_cast(raw_smem); + + const int warp_m = blockIdx.x; + const int warp_n = blockIdx.y; + + const int m_base = warp_m * GPTQ_PRO_M_PER_WARP; + const int n_base = warp_n * GPTQ_PRO_N_PER_WARP; + const int lane = threadIdx.x & (GPTQ_PRO_WARP_SIZE - 1); + + float RC[GPTQ_PRO_J_TILES][4]; + #pragma unroll + for (int j = 0; j < GPTQ_PRO_J_TILES; ++j) { + RC[j][0] = 0.0f; + RC[j][1] = 0.0f; + RC[j][2] = 0.0f; + RC[j][3] = 0.0f; + } + + const int num_k_tiles = (K + GPTQ_PRO_K_PER_WARP - 1) / GPTQ_PRO_K_PER_WARP; + for (int t = 0; t < num_k_tiles; ++t) { + const int k_base = t * GPTQ_PRO_K_PER_WARP; + + stage_a_tile(smem, A, M, K, m_base, k_base); + stage_scale_row(smem, S, N, k_base, group_size, n_base); + stage_bfrag_tiles(smem, B_packed, N, K, k_base, n_base); + __syncthreads(); + + do_mma_inner_loop(smem, RC); + __syncthreads(); + } + + const int row_base = lane >> 2; + const int col_pair = 2 * (lane & 3); + #pragma unroll + for (int j = 0; j < GPTQ_PRO_J_TILES; ++j) { + const int m0 = m_base + row_base; + const int m1 = m_base + row_base + 8; + const int n0 = n_base + j * 8 + col_pair + 0; + const int n1 = n_base + j * 8 + col_pair + 1; + + if (m0 < M) { + if (n0 < N) C[m0 * N + n0] = __float2half_rn(RC[j][0]); + if (n1 < N) C[m0 * N + n1] = __float2half_rn(RC[j][1]); + } + if (m1 < M) { + if (n0 < N) C[m1 * N + n0] = __float2half_rn(RC[j][2]); + if (n1 < N) C[m1 * N + n1] = __float2half_rn(RC[j][3]); + } + } +} + +#else // sm80 stub +__global__ void gptq_pro_gemm_kernel( + const half*, const uint8_t*, const half*, half*, int, int, int, int) {} +#define GPTQ_PRO_SM80_STUB 1 +#endif + +cudaError_t gptq_pro_gemm( + const half* A, + const uint8_t* B_packed, + const half* S, + half* C, + int M, int N, int K, int group_size, + cudaStream_t stream) +{ + if (group_size <= 0) { + group_size = K; + } + if ((group_size % GPTQ_PRO_K_PER_WARP) != 0) { + return cudaErrorInvalidValue; + } + + dim3 grid( + (M + GPTQ_PRO_M_PER_WARP - 1) / GPTQ_PRO_M_PER_WARP, + (N + GPTQ_PRO_N_PER_WARP - 1) / GPTQ_PRO_N_PER_WARP, + 1); + dim3 block(GPTQ_PRO_WARP_SIZE, 1, 1); + const size_t smem_bytes = sizeof(GptqProSmem); + + gptq_pro_gemm_kernel<<>>( + A, B_packed, S, C, M, N, K, group_size); + return cudaGetLastError(); +} diff --git a/gptqmodel_ext/gptq_pro/gptq_pro_kernel.cuh b/gptqmodel_ext/gptq_pro/gptq_pro_kernel.cuh new file mode 100644 index 000000000..5039401f6 --- /dev/null +++ b/gptqmodel_ext/gptq_pro/gptq_pro_kernel.cuh @@ -0,0 +1,183 @@ +/* + * Standalone gptq_pro Tensor Core scaffold for Ampere. + * + * Current scope: + * - one warp per CTA + * - symmetric INT4 weights packed as unsigned nibbles with implicit zero-point 8 + * - explicit shared-memory staging for the A tile, per-column scales, and B fragments + * - FP32 accumulation via mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + * + * This kernel is validator-backed and end-to-end functional, but it is still the + * compact standalone scaffold referenced in README/progress.md rather than the + * future multi-warp cp.async/ldmatrix pipeline discussed in Project.md. + */ + +#pragma once + +#include +#include +#include +#include + +// --------------------------------------------------------------------------- +// Tile dimensions +// --------------------------------------------------------------------------- +static constexpr int GPTQ_PRO_PIPE = 1; // current scaffold uses one staged tile +static constexpr int GPTQ_PRO_KS_TILES = 1; // one mma.sync k16 step per outer K tile +static constexpr int GPTQ_PRO_J_TILES = 8; // 8 x n8 slices -> 64 output cols / warp +static constexpr int GPTQ_PRO_WARP_SIZE = 32; + +static constexpr int GPTQ_PRO_M_PER_WARP = 16; +static constexpr int GPTQ_PRO_N_PER_WARP = GPTQ_PRO_J_TILES * 8; // 64 +static constexpr int GPTQ_PRO_K_PER_WARP = GPTQ_PRO_KS_TILES * 16; // 16 + +// Number of uint32_t words per (ks,j) tile in Bfrag smem (lane-pair packing). +static constexpr int GPTQ_PRO_BFRAG_WORDS_PER_TILE = + GPTQ_PRO_WARP_SIZE / 2; // 16 words + +// Total uint32_t words for all (ks,j) tiles in one smem buffer. +static constexpr int GPTQ_PRO_BFRAG_WORDS_PER_BUF = + GPTQ_PRO_KS_TILES * GPTQ_PRO_J_TILES * GPTQ_PRO_BFRAG_WORDS_PER_TILE; + +// --------------------------------------------------------------------------- +// Helper types +// --------------------------------------------------------------------------- +union Half2Reg { + half2 h2; + uint32_t u32; + uint16_t u16[2]; +}; + +__device__ __forceinline__ +uint32_t pack_half2_reg(half lo, half hi) { + Half2Reg reg; + reg.h2 = __halves2half2(lo, hi); + return reg.u32; +} + +// --------------------------------------------------------------------------- +// Shared-memory layout helpers for the B fragment. +// +// The current scaffold stages one k16 slice at a time, so only ks=0 is used in +// practice, but the helper keeps the (buf, ks, j, lane) contract so the decode +// validator continues to exercise the exact same lane-pair packing logic. +// --------------------------------------------------------------------------- +__device__ __forceinline__ +uint32_t bfrag_smem_addr(const uint32_t* __restrict__ smem_bfrag_base, + int buf, int ks, int j, int lane) { + const int tile_idx = ks * GPTQ_PRO_J_TILES + j; + const int buf_words = GPTQ_PRO_BFRAG_WORDS_PER_BUF; + const int word_idx = buf * buf_words + + tile_idx * GPTQ_PRO_BFRAG_WORDS_PER_TILE + + (lane >> 1); + return static_cast(__cvta_generic_to_shared(smem_bfrag_base)) + + static_cast(word_idx * sizeof(uint32_t)); +} + +__device__ __forceinline__ +uint32_t ld_shared_u32(uint32_t smem_addr) { + uint32_t val; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(val) : "r"(smem_addr)); + return val; +} + +__device__ __forceinline__ +uint16_t fetch_bfrag_packed16(const uint32_t* __restrict__ smem_bfrag, + int buf, int ks, int j, int lane) { + const uint32_t addr = bfrag_smem_addr(smem_bfrag, buf, ks, j, lane); + const uint32_t word = ld_shared_u32(addr); + return static_cast((lane & 1) ? (word >> 16) : (word & 0xFFFFu)); +} + +// --------------------------------------------------------------------------- +// INT4 nibble decode -> FP16 (scale * (w - 8)). +// +// The standalone scaffold currently models the symmetric GPTQ-style runtime +// where 4-bit weights are stored as unsigned nibbles with an implicit zero-point +// of 8 and a per-group/per-column FP16 scale. +// --------------------------------------------------------------------------- +__device__ __forceinline__ +void decode_bfrag_to_rb(uint16_t packed_16, + half scale, half zero_point, + uint32_t (&RB)[2]) { + const uint32_t p = static_cast(packed_16); + const uint32_t w0 = (p >> 0) & 0xFu; + const uint32_t w1 = (p >> 4) & 0xFu; + const uint32_t w2 = (p >> 8) & 0xFu; + const uint32_t w3 = (p >> 12) & 0xFu; + + const half2 vals01 = __halves2half2(__int2half_rn(static_cast(w0)), + __int2half_rn(static_cast(w1))); + const half2 vals23 = __halves2half2(__int2half_rn(static_cast(w2)), + __int2half_rn(static_cast(w3))); + const half2 zp_h2 = __halves2half2(zero_point, zero_point); + const half2 sc_h2 = __halves2half2(scale, scale); + + Half2Reg rb0, rb1; + rb0.h2 = __hmul2(sc_h2, __hsub2(vals01, zp_h2)); + rb1.h2 = __hmul2(sc_h2, __hsub2(vals23, zp_h2)); + RB[0] = rb0.u32; + RB[1] = rb1.u32; +} + +// --------------------------------------------------------------------------- +// A-fragment packing for mma.sync.aligned.m16n8k16.row.col +// +// This loader follows the same lane ownership used in the validator's scalar +// reference: +// groupID = lane >> 2 +// tid4 = lane & 3 +// rows = {groupID, groupID + 8} +// cols = {2*tid4, 2*tid4 + 1, 2*tid4 + 8, 2*tid4 + 9} +// +// Using explicit register packing avoids the invalid/misaligned ldmatrix path +// that the earlier scaffold emitted for this compact one-warp layout. +// --------------------------------------------------------------------------- +__device__ __forceinline__ +void load_a_fragment_rowmajor(const half* __restrict__ smem_a, + int lane, + uint32_t (&RA)[4]) { + const int group_id = lane >> 2; + const int thread_id = lane & 3; + const int a_col_lo = 2 * thread_id; + const int a_col_hi = a_col_lo + 8; + + RA[0] = pack_half2_reg( + smem_a[(group_id + 0) * GPTQ_PRO_K_PER_WARP + a_col_lo + 0], + smem_a[(group_id + 0) * GPTQ_PRO_K_PER_WARP + a_col_lo + 1]); + RA[1] = pack_half2_reg( + smem_a[(group_id + 8) * GPTQ_PRO_K_PER_WARP + a_col_lo + 0], + smem_a[(group_id + 8) * GPTQ_PRO_K_PER_WARP + a_col_lo + 1]); + RA[2] = pack_half2_reg( + smem_a[(group_id + 0) * GPTQ_PRO_K_PER_WARP + a_col_hi + 0], + smem_a[(group_id + 0) * GPTQ_PRO_K_PER_WARP + a_col_hi + 1]); + RA[3] = pack_half2_reg( + smem_a[(group_id + 8) * GPTQ_PRO_K_PER_WARP + a_col_hi + 0], + smem_a[(group_id + 8) * GPTQ_PRO_K_PER_WARP + a_col_hi + 1]); +} + +// --------------------------------------------------------------------------- +// FP32 accumulating MMA: RC += RA x RB +// --------------------------------------------------------------------------- +__device__ __forceinline__ +void mma_f32_m16n8k16(const uint32_t RA[4], + const uint32_t RB[2], + float RC[4]) { + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%0, %1, %2, %3};\n" + : "+f"(RC[0]), "+f"(RC[1]), "+f"(RC[2]), "+f"(RC[3]) + : "r"(RA[0]), "r"(RA[1]), "r"(RA[2]), "r"(RA[3]), + "r"(RB[0]), "r"(RB[1])); +} + +cudaError_t gptq_pro_gemm( + const half* A, + const uint8_t* B_packed, + const half* S, + half* C, + int M, int N, int K, int group_size, + cudaStream_t stream); diff --git a/gptqmodel_ext/gptq_pro/gptq_pro_torch.cpp b/gptqmodel_ext/gptq_pro/gptq_pro_torch.cpp new file mode 100644 index 000000000..e692195f1 --- /dev/null +++ b/gptqmodel_ext/gptq_pro/gptq_pro_torch.cpp @@ -0,0 +1,85 @@ +/* + * PyTorch extension entrypoint for the standalone gptq_pro CUDA scaffold. + */ + +#include + +#include +#include +#include +#include + +cudaError_t gptq_pro_gemm( + const half* A, + const uint8_t* B_packed, + const half* S, + half* C, + int M, int N, int K, int group_size, + cudaStream_t stream); + +namespace { + +void check_inputs(const torch::Tensor& a, + const torch::Tensor& b_packed, + const torch::Tensor& scales, + int64_t group_size) { + TORCH_CHECK(a.is_cuda(), "gptq_pro_gemm: activations must be CUDA tensors."); + TORCH_CHECK(b_packed.is_cuda(), "gptq_pro_gemm: packed weights must be CUDA tensors."); + TORCH_CHECK(scales.is_cuda(), "gptq_pro_gemm: scales must be CUDA tensors."); + TORCH_CHECK(a.scalar_type() == torch::kFloat16, "gptq_pro_gemm: activations must be float16."); + TORCH_CHECK(b_packed.scalar_type() == torch::kUInt8, "gptq_pro_gemm: packed weights must be uint8."); + TORCH_CHECK(scales.scalar_type() == torch::kFloat16, "gptq_pro_gemm: scales must be float16."); + TORCH_CHECK(a.dim() == 2, "gptq_pro_gemm: activations must be 2D [M, K]."); + TORCH_CHECK(b_packed.dim() == 2, "gptq_pro_gemm: packed weights must be 2D [(K+1)/2, N]."); + TORCH_CHECK(scales.dim() == 2, "gptq_pro_gemm: scales must be 2D [groups, N]."); + TORCH_CHECK(a.is_contiguous(), "gptq_pro_gemm: activations must be contiguous."); + TORCH_CHECK(b_packed.is_contiguous(), "gptq_pro_gemm: packed weights must be contiguous."); + TORCH_CHECK(scales.is_contiguous(), "gptq_pro_gemm: scales must be contiguous."); + TORCH_CHECK(a.device() == b_packed.device() && a.device() == scales.device(), + "gptq_pro_gemm: all tensors must live on the same CUDA device."); + TORCH_CHECK(group_size > 0 && (group_size % 16) == 0, + "gptq_pro_gemm: group_size must be a positive multiple of 16."); + + const auto k = a.size(1); + const auto packed_rows = b_packed.size(0); + TORCH_CHECK(packed_rows == (k + 1) / 2, + "gptq_pro_gemm: packed weights shape does not match activation K dimension."); + TORCH_CHECK(scales.size(1) == b_packed.size(1), + "gptq_pro_gemm: scales second dimension must equal packed weight N dimension."); + TORCH_CHECK(scales.size(0) == (k + group_size - 1) / group_size, + "gptq_pro_gemm: scales first dimension must equal ceil(K / group_size)."); +} + +} // namespace + +torch::Tensor gptq_pro_gemm_torch(torch::Tensor a, + torch::Tensor b_packed, + torch::Tensor scales, + int64_t group_size) { + check_inputs(a, b_packed, scales, group_size); + + auto out = torch::empty({a.size(0), b_packed.size(1)}, a.options()); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.device().index()); + + const auto status = gptq_pro_gemm( + reinterpret_cast(a.data_ptr()), + b_packed.data_ptr(), + reinterpret_cast(scales.data_ptr()), + reinterpret_cast(out.data_ptr()), + static_cast(a.size(0)), + static_cast(b_packed.size(1)), + static_cast(a.size(1)), + static_cast(group_size), + stream); + + TORCH_CHECK(status == cudaSuccess, + "gptq_pro_gemm launch failed: ", + cudaGetErrorString(status)); + return out; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("gptq_pro_gemm", &gptq_pro_gemm_torch, "GPTQ-Pro FP16xINT4 matmul."); +} diff --git a/gptqmodel_ext/gptq_pro/gptq_pro_validate.cu b/gptqmodel_ext/gptq_pro/gptq_pro_validate.cu new file mode 100644 index 000000000..25b26e8a3 --- /dev/null +++ b/gptqmodel_ext/gptq_pro/gptq_pro_validate.cu @@ -0,0 +1,595 @@ +/* + * GPTQ-Pro kernel validation harness + * + * Implements the two validation milestones from the design TODO list: + * + * TODO 1 — Validate decode-only against a scalar host/device reference for + * one warp fragment. + * + * TODO 2 — Validate one full ks/j MMA step against a reference with FP32 + * accumulation semantics and FP16 inputs. + * + * Each milestone produces a per-thread pass/fail flag stored in a result + * buffer that can be inspected from the host. Both kernels target sm80+. + * + * Build (standalone, no PyTorch): + * nvcc -arch=sm_80 -std=c++17 gptq_pro_validate.cu gptq_pro_kernel.cu -o gptq_pro_validate + */ + +#include "gptq_pro_kernel.cuh" + +#include +#include +#include +#include +#include + +cudaError_t gptq_pro_gemm( + const half* A, + const uint8_t* B_packed, + const half* S, + half* C, + int M, int N, int K, int group_size, + cudaStream_t stream); + +#define CHECK_CUDA(expr) \ + do { \ + cudaError_t _err = (expr); \ + if (_err != cudaSuccess) { \ + fprintf(stderr, "CUDA error at %s:%d: %s\n", \ + __FILE__, __LINE__, cudaGetErrorString(_err)); \ + return 1; \ + } \ + } while (0) + +// ============================================================ +// TODO 1 — Decode-only scalar reference +// ============================================================ + +/// Scalar (lane-independent) decode of one lane-local 4-nibble B fragment. +/// Mirrors exactly what decode_bfrag_to_rb() does on-device, using only +/// host-visible arithmetic so the result can be used as ground truth. +/// +/// @param w0 unsigned 4-bit integer [0, 15] for RB[0] low half +/// @param w1 unsigned 4-bit integer [0, 15] for RB[0] high half +/// @param w2 unsigned 4-bit integer [0, 15] for RB[1] low half +/// @param w3 unsigned 4-bit integer [0, 15] for RB[1] high half +/// @param scale_f FP32 version of the per-group scale +/// @param zp_f FP32 version of the per-group zero-point +/// @param out_rb_f [out] dequantized values in order {w0, w1, w2, w3} +inline void scalar_decode_bfrag(uint32_t w0, uint32_t w1, + uint32_t w2, uint32_t w3, + float scale_f, float zp_f, + float (&out_rb_f)[4]) { + out_rb_f[0] = scale_f * (static_cast(w0) - zp_f); + out_rb_f[1] = scale_f * (static_cast(w1) - zp_f); + out_rb_f[2] = scale_f * (static_cast(w2) - zp_f); + out_rb_f[3] = scale_f * (static_cast(w3) - zp_f); +} + +// --------------------------------------------------------------------------- +// Device-side TODO 1 validation kernel +// +// Each thread (= one warp lane): +// 1. Reads its packed INT4 data from a pre-filled Bfrag smem region. +// 2. Calls fetch_bfrag_packed16() (the real ld.shared.u32 path). +// 3. Calls decode_bfrag_to_rb() to obtain RB[0], RB[1]. +// 4. Compares against a pre-computed float reference stored in ref_rb +// (4 floats per lane, computed by scalar_decode_bfrag on host). +// 5. Sets result[lane] = 1 if all four decoded values match. +// --------------------------------------------------------------------------- +__global__ void validate_decode_kernel( + const uint32_t* __restrict__ bfrag_smem_src, // GPTQ_PRO_BFRAG_WORDS_PER_BUF words (global, will be copied to smem) + float scale_f, + float zp_f, + const float* ref_rb, // [WARP_SIZE * 4] ground-truth per lane + int* result) // [WARP_SIZE] 1=pass, 0=fail +{ + // One warp handles one tile at ks=0, j=0. + const int lane = threadIdx.x & (GPTQ_PRO_WARP_SIZE - 1); + + // ---- Stage Bfrag into shared memory ---- + extern __shared__ uint32_t smem_bfrag[]; + // Each thread copies one word. + if (lane < GPTQ_PRO_BFRAG_WORDS_PER_BUF) { + smem_bfrag[lane] = bfrag_smem_src[lane]; + } + __syncthreads(); + + // ---- TODO 1: fetch via ld.shared.u32 path ---- + const int ks = 0, j = 0, buf = 0; + uint16_t packed_16 = fetch_bfrag_packed16(smem_bfrag, buf, ks, j, lane); + + // ---- Decode nibbles ---- + const half scale = __float2half(scale_f); + const half zp = __float2half(zp_f); + uint32_t RB[2]; + decode_bfrag_to_rb(packed_16, scale, zp, RB); + + // ---- Extract decoded FP16 values and compare with reference ---- + Half2Reg rb0_r, rb1_r; + rb0_r.u32 = RB[0]; + rb1_r.u32 = RB[1]; + + float got_rb0_lo = __half2float(__low2half(rb0_r.h2)); + float got_rb0_hi = __half2float(__high2half(rb0_r.h2)); + float got_rb1_lo = __half2float(__low2half(rb1_r.h2)); + float got_rb1_hi = __half2float(__high2half(rb1_r.h2)); + + // FP16 has ~1e-3 relative error; use 2 ULP tolerance in FP16 space. + const float tol = 2.0f * __half2float(__float2half(1.0f)) * 1e-3f; + + bool ok = (fabsf(got_rb0_lo - ref_rb[lane * 4 + 0]) <= tol + 1e-5f) && + (fabsf(got_rb0_hi - ref_rb[lane * 4 + 1]) <= tol + 1e-5f) && + (fabsf(got_rb1_lo - ref_rb[lane * 4 + 2]) <= tol + 1e-5f) && + (fabsf(got_rb1_hi - ref_rb[lane * 4 + 3]) <= tol + 1e-5f); + result[lane] = ok ? 1 : 0; +} + +// ============================================================ +// TODO 2 — Scalar FP32-accumulating MMA reference +// ============================================================ + +/// Scalar reference for mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32. +/// +/// Interprets A and B as FP16 inputs and accumulates into FP32, matching the +/// hardware accumulation semantics of the tensor-core instruction. +/// +/// PTX fragment ownership for mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32: +/// A: row = groupID for ai in {0,1,4,5}, else groupID + 8 +/// col = 2 * threadID + (i & 1) [+8 for i >= 4] +/// B: row = 2 * threadID + (i & 1) [+8 for i >= 2] +/// col = groupID +/// C/D: row = groupID for ci in {0,1}, else groupID + 8 +/// col = 2 * threadID + (i & 1) +/// where groupID = lane >> 2 and threadID = lane & 3. +/// +/// This scalar function operates on the unpacked FP32 equivalents; the caller +/// is responsible for unpacking and repacking. +inline float fma_f32_from_f16_inputs_proxy(float a, float b, float c) { + const half ha = __float2half(a); + const half hb = __float2half(b); + return c + __half2float(ha) * __half2float(hb); +} + +inline float mma_ref_a_value(int m_row, int k_col) { + return 0.03125f * static_cast(m_row + 1) + + 0.001953125f * static_cast(k_col + 1); +} + +inline float mma_ref_b_value(int k_row, int group_id) { + return 0.0625f * static_cast(k_row + 1) + + 0.0078125f * static_cast(group_id); +} + +inline uint32_t pack_half2(float lo, float hi) { + Half2Reg reg; + const half hlo = __float2half(lo); + const half hhi = __float2half(hi); + memcpy(®.u16[0], &hlo, sizeof(uint16_t)); + memcpy(®.u16[1], &hhi, sizeof(uint16_t)); + return reg.u32; +} + +// --------------------------------------------------------------------------- +// Device-side TODO 2 validation kernel +// +// One thread block = one warp. Validates a single (ks=0, j=0) MMA step. +// Steps: +// 1. Loads pre-dequantized RA[4] and RB[2] from global memory. +// 2. Zeroes RC[4]. +// 3. Calls mma_f32_m16n8k16() (the real tensor-core MMA). +// 4. Computes a scalar reference using fma_f32_from_f16_inputs_proxy(). +// 5. Compares RC output against reference; writes 1=pass / 0=fail. +// --------------------------------------------------------------------------- +__global__ void validate_mma_step_kernel( + const uint32_t* __restrict__ ra_global, // [4] per lane + const uint32_t* __restrict__ rb_global, // [2] per lane + const float* __restrict__ ref_rc, // [4] per lane (float proxy) + int* result) // [WARP_SIZE] 1=pass, 0=fail +{ + const int lane = threadIdx.x & (GPTQ_PRO_WARP_SIZE - 1); + + // Load fragment registers for this lane. + uint32_t RA[4], RB[2]; + float RC[4]; + RA[0] = ra_global[lane * 4 + 0]; + RA[1] = ra_global[lane * 4 + 1]; + RA[2] = ra_global[lane * 4 + 2]; + RA[3] = ra_global[lane * 4 + 3]; + RB[0] = rb_global[lane * 2 + 0]; + RB[1] = rb_global[lane * 2 + 1]; + RC[0] = 0.0f; + RC[1] = 0.0f; + RC[2] = 0.0f; + RC[3] = 0.0f; + + // ---- TODO 2: real tensor-core MMA ---- + mma_f32_m16n8k16(RA, RB, RC); + + // FP32 accumulation with exact FP16 inputs should agree closely with the + // scalar proxy. Allow a tiny epsilon for instruction-order differences. + const float tol = 1e-6f; + bool ok = true; + for (int i = 0; i < 4; ++i) { + if (fabsf(RC[i] - ref_rc[lane * 4 + i]) > tol) { + ok = false; + } + } + result[lane] = ok ? 1 : 0; +} + +// ============================================================ +// Host-side driver — fills test data, launches kernels, checks results +// ============================================================ + +/// Fill Bfrag shared-memory image with deterministic lane-local INT4 payloads. +/// packed_16 for lane `l` in tile (0,0) packs four distinct nibbles: +/// bits [3:0] = (4*l + 0) & 0xF +/// bits [7:4] = (4*l + 1) & 0xF +/// bits [11:8] = (4*l + 2) & 0xF +/// bits [15:12] = (4*l + 3) & 0xF +static void fill_bfrag_test_image(uint32_t* words) { + for (int w = 0; w < GPTQ_PRO_BFRAG_WORDS_PER_BUF; ++w) { + words[w] = 0u; + } + // Tile ks=0, j=0, buf=0 starts at word offset 0. + for (int lane = 0; lane < GPTQ_PRO_WARP_SIZE; ++lane) { + uint32_t w0 = (4u * lane + 0u) & 0xFu; + uint32_t w1 = (4u * lane + 1u) & 0xFu; + uint32_t w2 = (4u * lane + 2u) & 0xFu; + uint32_t w3 = (4u * lane + 3u) & 0xFu; + uint16_t p16 = static_cast( + w0 | (w1 << 4) | (w2 << 8) | (w3 << 12)); + int word_idx = lane >> 1; // lane pair + if (lane & 1) { + words[word_idx] = (words[word_idx] & 0x0000FFFFu) + | (static_cast(p16) << 16); + } else { + words[word_idx] = (words[word_idx] & 0xFFFF0000u) + | static_cast(p16); + } + } +} + +static void fill_end_to_end_a(std::vector& a, int M, int K) { + for (int m = 0; m < M; ++m) { + for (int k = 0; k < K; ++k) { + const float value = 0.125f * static_cast(m + 1) + + 0.03125f * static_cast((k % 7) + 1); + a[m * K + k] = __float2half(value); + } + } +} + +static void fill_end_to_end_b(std::vector& b_packed, int K, int N) { + const int packed_rows = (K + 1) / 2; + for (int kp = 0; kp < packed_rows; ++kp) { + const int k0 = kp * 2; + for (int n = 0; n < N; ++n) { + const uint8_t lo = static_cast(8 + ((k0 + 2 * n) % 3)); + uint8_t hi = static_cast(8 + (((k0 + 1) + 2 * n) % 3)); + if (k0 + 1 >= K) { + hi = 8; + } + b_packed[kp * N + n] = static_cast(lo | (hi << 4)); + } + } +} + +static void fill_end_to_end_s(std::vector& scales, int groups, int N) { + for (int g = 0; g < groups; ++g) { + for (int n = 0; n < N; ++n) { + const float scale = 0.125f * static_cast(1 + ((g + n) % 4)); + scales[g * N + n] = __float2half(scale); + } + } +} + +static float dequant_weight_ref(const std::vector& b_packed, + const std::vector& scales, + int N, int group_size, + int k, int n) { + const uint8_t byte = b_packed[(k >> 1) * N + n]; + const uint32_t nibble = (k & 1) ? ((byte >> 4) & 0xFu) : (byte & 0xFu); + const float scale = __half2float(scales[(k / group_size) * N + n]); + return scale * (static_cast(nibble) - 8.0f); +} + +static bool run_end_to_end_case(int M, int N, int K, int group_size, const char* label) { + const int packed_rows = (K + 1) / 2; + const int groups = (K + group_size - 1) / group_size; + + std::vector h_a(M * K); + std::vector h_b(packed_rows * N); + std::vector h_s(groups * N); + std::vector h_c(M * N, __float2half(0.0f)); + + fill_end_to_end_a(h_a, M, K); + fill_end_to_end_b(h_b, K, N); + fill_end_to_end_s(h_s, groups, N); + + half* d_a = nullptr; + half* d_s = nullptr; + half* d_c = nullptr; + uint8_t* d_b = nullptr; + + auto fail = [&](const char* what, cudaError_t err) { + std::fprintf(stderr, " FAIL %s: %s\n", what, cudaGetErrorString(err)); + if (d_a) cudaFree(d_a); + if (d_s) cudaFree(d_s); + if (d_c) cudaFree(d_c); + if (d_b) cudaFree(d_b); + return false; + }; + + cudaError_t err = cudaMalloc(&d_a, h_a.size() * sizeof(half)); + if (err != cudaSuccess) return fail("cudaMalloc(d_a)", err); + err = cudaMalloc(&d_s, h_s.size() * sizeof(half)); + if (err != cudaSuccess) return fail("cudaMalloc(d_s)", err); + err = cudaMalloc(&d_c, h_c.size() * sizeof(half)); + if (err != cudaSuccess) return fail("cudaMalloc(d_c)", err); + err = cudaMalloc(&d_b, h_b.size() * sizeof(uint8_t)); + if (err != cudaSuccess) return fail("cudaMalloc(d_b)", err); + + err = cudaMemcpy(d_a, h_a.data(), h_a.size() * sizeof(half), cudaMemcpyHostToDevice); + if (err != cudaSuccess) return fail("cudaMemcpy(d_a)", err); + err = cudaMemcpy(d_s, h_s.data(), h_s.size() * sizeof(half), cudaMemcpyHostToDevice); + if (err != cudaSuccess) return fail("cudaMemcpy(d_s)", err); + err = cudaMemcpy(d_b, h_b.data(), h_b.size() * sizeof(uint8_t), cudaMemcpyHostToDevice); + if (err != cudaSuccess) return fail("cudaMemcpy(d_b)", err); + err = cudaMemset(d_c, 0, h_c.size() * sizeof(half)); + if (err != cudaSuccess) return fail("cudaMemset(d_c)", err); + + err = gptq_pro_gemm(d_a, d_b, d_s, d_c, M, N, K, group_size, 0); + if (err != cudaSuccess) return fail("gptq_pro_gemm launch", err); + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) return fail("cudaDeviceSynchronize()", err); + err = cudaMemcpy(h_c.data(), d_c, h_c.size() * sizeof(half), cudaMemcpyDeviceToHost); + if (err != cudaSuccess) return fail("cudaMemcpy(h_c)", err); + + int mismatches = 0; + for (int m = 0; m < M; ++m) { + for (int n = 0; n < N; ++n) { + float acc = 0.0f; + for (int k = 0; k < K; ++k) { + const float a = __half2float(h_a[m * K + k]); + const float w = dequant_weight_ref(h_b, h_s, N, group_size, k, n); + acc = fma_f32_from_f16_inputs_proxy(a, w, acc); + } + const float expect = __half2float(__float2half(acc)); + const float got = __half2float(h_c[m * N + n]); + if (fabsf(got - expect) > 1e-3f) { + if (mismatches < 8) { + std::fprintf(stderr, + " FAIL %s at (%d, %d): got=%f expect=%f\n", + label, m, n, got, expect); + } + ++mismatches; + } + } + } + + cudaFree(d_a); + cudaFree(d_s); + cudaFree(d_c); + cudaFree(d_b); + + if (mismatches != 0) { + std::fprintf(stderr, " %s mismatches: %d\n", label, mismatches); + return false; + } + + std::printf(" PASS %s\n", label); + return true; +} + +#ifndef GPTQ_PRO_VALIDATE_SKIP_MAIN + +int main() { + const float scale_f = 0.015625f; // 2^-6, exact in FP16 + const float zp_f = 8.0f; // unsigned-to-signed shift + + // ----------------------------------------------------------------------- + // TODO 1 — Decode validation + // ----------------------------------------------------------------------- + printf("=== TODO 1: decode-only validation ===\n"); + + // Build host Bfrag image. + const int bfrag_words = GPTQ_PRO_BFRAG_WORDS_PER_BUF; + uint32_t h_bfrag[bfrag_words]; + fill_bfrag_test_image(h_bfrag); + + // Compute scalar reference for every lane. + float h_ref_rb[GPTQ_PRO_WARP_SIZE * 4]; + for (int lane = 0; lane < GPTQ_PRO_WARP_SIZE; ++lane) { + uint32_t w0 = (4u * lane + 0u) & 0xFu; + uint32_t w1 = (4u * lane + 1u) & 0xFu; + uint32_t w2 = (4u * lane + 2u) & 0xFu; + uint32_t w3 = (4u * lane + 3u) & 0xFu; + float decoded[4]; + scalar_decode_bfrag(w0, w1, w2, w3, scale_f, zp_f, decoded); + for (int i = 0; i < 4; ++i) { + h_ref_rb[lane * 4 + i] = decoded[i]; + } + } + + // Allocate device memory. + uint32_t* d_bfrag; + float* d_ref_rb; + int* d_result1; + CHECK_CUDA(cudaMalloc(&d_bfrag, bfrag_words * sizeof(uint32_t))); + CHECK_CUDA(cudaMalloc(&d_ref_rb, GPTQ_PRO_WARP_SIZE * 4 * sizeof(float))); + CHECK_CUDA(cudaMalloc(&d_result1, GPTQ_PRO_WARP_SIZE * sizeof(int))); + + CHECK_CUDA(cudaMemcpy( + d_bfrag, h_bfrag, bfrag_words * sizeof(uint32_t), cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy( + d_ref_rb, h_ref_rb, GPTQ_PRO_WARP_SIZE * 4 * sizeof(float), + cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemset(d_result1, 0, GPTQ_PRO_WARP_SIZE * sizeof(int))); + + size_t smem_bytes = GPTQ_PRO_BFRAG_WORDS_PER_BUF * sizeof(uint32_t); + validate_decode_kernel<<<1, GPTQ_PRO_WARP_SIZE, smem_bytes>>>( + d_bfrag, scale_f, zp_f, d_ref_rb, d_result1); + CHECK_CUDA(cudaGetLastError()); + CHECK_CUDA(cudaDeviceSynchronize()); + + int h_result1[GPTQ_PRO_WARP_SIZE]; + CHECK_CUDA(cudaMemcpy( + h_result1, d_result1, GPTQ_PRO_WARP_SIZE * sizeof(int), + cudaMemcpyDeviceToHost)); + + int pass1 = 0; + for (int i = 0; i < GPTQ_PRO_WARP_SIZE; ++i) { + const bool ok = (h_result1[i] == 1); + pass1 += ok ? 1 : 0; + if (!ok) printf(" FAIL lane %d\n", i); + } + printf(" %d / %d lanes passed\n", pass1, GPTQ_PRO_WARP_SIZE); + + CHECK_CUDA(cudaFree(d_bfrag)); + CHECK_CUDA(cudaFree(d_ref_rb)); + CHECK_CUDA(cudaFree(d_result1)); + + // ----------------------------------------------------------------------- + // TODO 2 — MMA step validation + // ----------------------------------------------------------------------- + printf("=== TODO 2: ks/j MMA step validation ===\n"); + + // Build synthetic RA, RB fragments from the PTX-defined fragment ownership. + // Both A and B vary with their logical coordinates, so row/column mix-ups in + // either fragment contract now perturb the final D fragment immediately. + float h_a_tile[GPTQ_PRO_M_PER_WARP][GPTQ_PRO_K_PER_WARP]; + float h_b_tile[GPTQ_PRO_K_PER_WARP][8]; + for (int m = 0; m < GPTQ_PRO_M_PER_WARP; ++m) { + for (int k = 0; k < GPTQ_PRO_K_PER_WARP; ++k) { + h_a_tile[m][k] = mma_ref_a_value(m, k); + } + } + for (int k = 0; k < GPTQ_PRO_K_PER_WARP; ++k) { + for (int n = 0; n < 8; ++n) { + h_b_tile[k][n] = mma_ref_b_value(k, n); + } + } + + uint32_t h_ra[GPTQ_PRO_WARP_SIZE * 4]; + uint32_t h_rb[GPTQ_PRO_WARP_SIZE * 2]; + for (int lane = 0; lane < GPTQ_PRO_WARP_SIZE; ++lane) { + const int group_id = lane >> 2; + const int thread_id = lane & 3; + const int a_col_lo = 2 * thread_id; + const int a_col_hi = a_col_lo + 8; + const int row0 = 2 * thread_id + 0; + const int row1 = 2 * thread_id + 1; + + h_ra[lane * 4 + 0] = pack_half2( + h_a_tile[group_id + 0][a_col_lo + 0], + h_a_tile[group_id + 0][a_col_lo + 1]); + h_ra[lane * 4 + 1] = pack_half2( + h_a_tile[group_id + 8][a_col_lo + 0], + h_a_tile[group_id + 8][a_col_lo + 1]); + h_ra[lane * 4 + 2] = pack_half2( + h_a_tile[group_id + 0][a_col_hi + 0], + h_a_tile[group_id + 0][a_col_hi + 1]); + h_ra[lane * 4 + 3] = pack_half2( + h_a_tile[group_id + 8][a_col_hi + 0], + h_a_tile[group_id + 8][a_col_hi + 1]); + + h_rb[lane * 2 + 0] = pack_half2( + h_b_tile[row0 + 0][group_id], + h_b_tile[row1 + 0][group_id]); + h_rb[lane * 2 + 1] = pack_half2( + h_b_tile[row0 + 8][group_id], + h_b_tile[row1 + 8][group_id]); + } + + // Scalar reference: each D[m][n] = sum_{k=0}^{15} A[m][k] * B[k][n]. + // For m16n8k16.row.col.f32 the lane-local D fragment is a 2-row x 2-column + // tile, so compare against the full 16x8 GEMM reference rather than a + // collapsed per-column reduction. + // rows = {lane >> 2, lane >> 2 + 8} + // cols = {2 * (lane & 3), 2 * (lane & 3) + 1} + float h_ref_rc[GPTQ_PRO_WARP_SIZE * 4]; + { + float h_d_tile[GPTQ_PRO_M_PER_WARP][8]; + for (int m = 0; m < GPTQ_PRO_M_PER_WARP; ++m) { + for (int n = 0; n < 8; ++n) { + float acc = 0.0f; + for (int k = 0; k < GPTQ_PRO_K_PER_WARP; ++k) { + acc = fma_f32_from_f16_inputs_proxy( + h_a_tile[m][k], h_b_tile[k][n], acc); + } + h_d_tile[m][n] = acc; + } + } + for (int lane = 0; lane < GPTQ_PRO_WARP_SIZE; ++lane) { + const int row_base = lane >> 2; + const int col_base = 2 * (lane & 3); + h_ref_rc[lane * 4 + 0] = h_d_tile[row_base + 0][col_base + 0]; + h_ref_rc[lane * 4 + 1] = h_d_tile[row_base + 0][col_base + 1]; + h_ref_rc[lane * 4 + 2] = h_d_tile[row_base + 8][col_base + 0]; + h_ref_rc[lane * 4 + 3] = h_d_tile[row_base + 8][col_base + 1]; + } + } + + uint32_t* d_ra; + uint32_t* d_rb; + float* d_ref_rc; + int* d_result2; + CHECK_CUDA(cudaMalloc(&d_ra, GPTQ_PRO_WARP_SIZE * 4 * sizeof(uint32_t))); + CHECK_CUDA(cudaMalloc(&d_rb, GPTQ_PRO_WARP_SIZE * 2 * sizeof(uint32_t))); + CHECK_CUDA(cudaMalloc(&d_ref_rc, GPTQ_PRO_WARP_SIZE * 4 * sizeof(float))); + CHECK_CUDA(cudaMalloc(&d_result2, GPTQ_PRO_WARP_SIZE * sizeof(int))); + + CHECK_CUDA(cudaMemcpy( + d_ra, h_ra, GPTQ_PRO_WARP_SIZE * 4 * sizeof(uint32_t), + cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy( + d_rb, h_rb, GPTQ_PRO_WARP_SIZE * 2 * sizeof(uint32_t), + cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemcpy( + d_ref_rc, h_ref_rc, GPTQ_PRO_WARP_SIZE * 4 * sizeof(float), + cudaMemcpyHostToDevice)); + CHECK_CUDA(cudaMemset(d_result2, 0, GPTQ_PRO_WARP_SIZE * sizeof(int))); + + validate_mma_step_kernel<<<1, GPTQ_PRO_WARP_SIZE>>>( + d_ra, d_rb, d_ref_rc, d_result2); + CHECK_CUDA(cudaGetLastError()); + CHECK_CUDA(cudaDeviceSynchronize()); + + int h_result2[GPTQ_PRO_WARP_SIZE]; + CHECK_CUDA(cudaMemcpy( + h_result2, d_result2, GPTQ_PRO_WARP_SIZE * sizeof(int), + cudaMemcpyDeviceToHost)); + + int pass2 = 0; + for (int i = 0; i < GPTQ_PRO_WARP_SIZE; ++i) { + const bool ok = (h_result2[i] == 1); + pass2 += ok ? 1 : 0; + if (!ok) printf(" FAIL lane %d\n", i); + } + printf(" %d / %d lanes passed\n", pass2, GPTQ_PRO_WARP_SIZE); + + CHECK_CUDA(cudaFree(d_ra)); + CHECK_CUDA(cudaFree(d_rb)); + CHECK_CUDA(cudaFree(d_ref_rc)); + CHECK_CUDA(cudaFree(d_result2)); + + // ----------------------------------------------------------------------- + // TODO 3 — End-to-end kernel validation + // ----------------------------------------------------------------------- + printf("=== TODO 3: end-to-end kernel validation ===\n"); + int pass3 = 0; + pass3 += run_end_to_end_case(16, 64, 16, 16, "aligned-16x64x16") ? 1 : 0; + pass3 += run_end_to_end_case(13, 41, 29, 16, "edge-13x41x29") ? 1 : 0; + printf(" %d / %d cases passed\n", pass3, 2); + + // ----------------------------------------------------------------------- + int total = pass1 + pass2 + pass3; + int total_max = 2 * GPTQ_PRO_WARP_SIZE + 2; + printf("\n=== Overall: %d / %d checks passed ===\n", total, total_max); + return (total == total_max) ? 0 : 1; +} + +#endif // GPTQ_PRO_VALIDATE_SKIP_MAIN diff --git a/gptqmodel_ext/machete/generate.py b/gptqmodel_ext/machete/generate.py index 52bd805de..ffdff465c 100644 --- a/gptqmodel_ext/machete/generate.py +++ b/gptqmodel_ext/machete/generate.py @@ -26,15 +26,28 @@ _CUTLASS_PYTHON_DIR = _CUTLASS_ROOT / "python" -_CUTLASS_PYTHON_DIR.mkdir(parents=True, exist_ok=True) -if str(_CUTLASS_EXT_DIR) not in sys.path: - sys.path.append(str(_CUTLASS_EXT_DIR)) -if _CUTLASS_PYTHON_DIR.exists() and str(_CUTLASS_PYTHON_DIR) not in sys.path: - sys.path.append(str(_CUTLASS_PYTHON_DIR)) -if not _CUTLASS_PYTHON_DIR.exists(): +def _prepend_sys_path(path: Path) -> None: + path_text = str(path) + if path_text in sys.path: + sys.path.remove(path_text) + sys.path.insert(0, path_text) + + +def _cutlass_python_bindings_present(path: Path) -> bool: + return ( + (path / "cutlass_library.py").is_file() + or (path / "cutlass_library" / "__init__.py").is_file() + ) + + +_prepend_sys_path(_CUTLASS_EXT_DIR) +if _cutlass_python_bindings_present(_CUTLASS_PYTHON_DIR): + _prepend_sys_path(_CUTLASS_PYTHON_DIR) +else: raise RuntimeError( - "CUTLASS python bindings not found. Set GPTQMODEL_CUTLASS_DIR to a valid CUTLASS checkout." + "CUTLASS python bindings not found under " + f"`{_CUTLASS_PYTHON_DIR}`. Set GPTQMODEL_CUTLASS_DIR to a valid CUTLASS checkout." ) from vllm_cutlass_library_extension import ( diff --git a/gptqmodel_ext/machete/machete_mainloop.cuh b/gptqmodel_ext/machete/machete_mainloop.cuh index 2f52a6b7a..8019e2cf5 100644 --- a/gptqmodel_ext/machete/machete_mainloop.cuh +++ b/gptqmodel_ext/machete/machete_mainloop.cuh @@ -154,6 +154,7 @@ struct MacheteCollectiveMma { struct DispatchPolicy { constexpr static int Stages = PipelineStages; using ClusterShape = ClusterShape_MNK; + using ArchTag = arch::Sm90; using Schedule = KernelScheduleType; }; diff --git a/gptqmodel_ext/machete/machete_mm_kernel.cuh b/gptqmodel_ext/machete/machete_mm_kernel.cuh index cc50e68b0..b99bb5177 100644 --- a/gptqmodel_ext/machete/machete_mm_kernel.cuh +++ b/gptqmodel_ext/machete/machete_mm_kernel.cuh @@ -136,8 +136,9 @@ struct MacheteKernelTemplate { "Currently token and channel scales (if present) must be float " "(and if one is present the other must be too)"); - using StoreEpilogueCompute = typename cutlass::epilogue::fusion::Sm90EVT< - cutlass::epilogue::fusion::Sm90AccFetch>; + using StoreEpilogueCompute = + typename vllm::c3x::TrivialEpilogue::EVTCompute; using EVTCompute = std::conditional_t= (48 << 10)) { + cudaError_t cuda_status = cudaFuncSetAttribute( + cutlass::device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + TORCH_CHECK( + cuda_status == cudaSuccess, + "Machete kernel failed to set max dynamic shared memory size to ", + smem_size, ": ", cudaGetErrorString(cuda_status)); + } cutlass::Status status = gemm_op.initialize(args, workspace, stream); TORCH_CHECK(status == cutlass::Status::kSuccess, - "Machete kernel failed to initialize workspace"); + "Machete kernel failed to initialize workspace: ", + cutlassGetStatusString(status)); status = gemm_op.run(stream); - TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed"); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed: ", + cutlassGetStatusString(status)); } }; diff --git a/gptqmodel_ext/machete/machete_pytorch.cu b/gptqmodel_ext/machete/machete_pytorch.cu index 05a51ee21..89454b365 100644 --- a/gptqmodel_ext/machete/machete_pytorch.cu +++ b/gptqmodel_ext/machete/machete_pytorch.cu @@ -3,6 +3,7 @@ #include "core/scalar_type.hpp" #include "core/registration.h" +#include namespace machete { @@ -60,13 +61,19 @@ torch::Tensor prepack_B( .maybe_group_scales_type = maybe_group_scales_type}); } -TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { +TORCH_LIBRARY(gptqmodel_machete, m) { + m.def("machete_prepack_B(Tensor B, ScalarType a_type, int b_type_id, ScalarType? group_scales_type=None) -> Tensor"); + m.def("machete_mm(Tensor A, Tensor B, int b_type_id, ScalarType? out_type=None, Tensor? group_scales=None, Tensor? group_zeros=None, int? group_size=None, Tensor? channel_scales=None, Tensor? token_scales=None, str? schedule=None) -> Tensor"); + m.def("machete_supported_schedules(ScalarType a_type, int b_type_id, ScalarType? group_scales_type=None, ScalarType? group_zeros_type=None, ScalarType? channel_scales_type=None, ScalarType? token_scales_type=None, ScalarType? out_type=None) -> str[]"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_machete, CUDA, m) { m.impl("machete_prepack_B", &prepack_B); m.impl("machete_mm", &mm); } // use CatchAll since supported_schedules has no tensor arguments -TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) { +TORCH_LIBRARY_IMPL(gptqmodel_machete, CatchAll, m) { m.impl("machete_supported_schedules", &supported_schedules); } diff --git a/gptqmodel_ext/marlin/generate_kernels.py b/gptqmodel_ext/marlin/generate_kernels.py index f69b38db9..482b9fff4 100644 --- a/gptqmodel_ext/marlin/generate_kernels.py +++ b/gptqmodel_ext/marlin/generate_kernels.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import itertools +import sys from pathlib import Path import jinja2 @@ -49,27 +50,78 @@ DTYPES = ["fp16", "bf16"] +def _is_4bit_weight(scalar_type: str) -> bool: + return scalar_type in { + "vllm::kU4", + "vllm::kU4B8", + "vllm::kFE2M1f", + } + + def remove_old_kernels() -> None: root = Path(__file__).parent for path in root.glob("kernel_*.cu"): path.unlink(missing_ok=True) -def _write_kernel_file(scalar_type: str, dtype: str, templates: list[str]) -> Path: - root = Path(__file__).parent + +def _kernel_output_path(root: Path, scalar_type: str, dtype: str) -> Path: scalar_suffix = scalar_type.split("::", 1)[1].lower() if "::" in scalar_type else scalar_type.lower() - output_path = root / f"kernel_{dtype}_{scalar_suffix}.cu" + return root / f"kernel_{dtype}_{scalar_suffix}.cu" + +def _render_kernel_file_text(scalar_type: str, dtype: str, templates: list[str]) -> str: lines = [FILE_HEAD, "", f"// Instantiations for dtype={dtype}, weight={scalar_type}", ""] lines.append("\n".join(templates)) lines.append("") lines.append(FILE_TAIL) + return "\n".join(lines) + - output_path.write_text("\n".join(lines), encoding="utf-8") +def _write_kernel_file(root: Path, scalar_type: str, dtype: str, templates: list[str]) -> Path: + output_path = _kernel_output_path(root, scalar_type, dtype) + output_path.write_text(_render_kernel_file_text(scalar_type, dtype, templates), encoding="utf-8") return output_path +def build_expected_kernels(root: Path | None = None) -> dict[Path, str]: + root = root or Path(__file__).parent + expected: dict[Path, str] = {} + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + templates = render_templates_for_combo(scalar_type, dtype) + if not templates: + continue + output_path = _kernel_output_path(root, scalar_type, dtype) + expected[output_path] = _render_kernel_file_text(scalar_type, dtype, templates) + + if not expected: + raise RuntimeError("No marlin kernels were generated; check template configuration.") + return expected + + +def generated_kernels_are_current(root: Path | None = None) -> bool: + root = root or Path(__file__).parent + expected = build_expected_kernels(root) + expected_names = {path.name for path in expected} + existing_names = {path.name for path in root.glob("kernel_*.cu")} + if existing_names != expected_names: + return False + + for output_path, expected_text in expected.items(): + try: + current_text = output_path.read_text(encoding="utf-8") + except OSError: + return False + if current_text != expected_text: + return False + return True + + def render_templates_for_combo(scalar_type: str, dtype: str) -> list[str]: results: list[str] = [] + stage_values = ["pipe_stages"] + if dtype == "fp16": + # Turing uses a shorter pipeline depth than Ampere+. + stage_values.insert(0, 2) for group_blocks, m_blocks, thread_configs in itertools.product( GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): @@ -86,9 +138,9 @@ def render_templates_for_combo(scalar_type: str, dtype: str) -> list[str]: if m_blocks > 1 and thread_configs[0] != 64: continue - # we only support channelwise quantization and group_size == 128 - # for fp8 - if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: + # FP8 weights support channelwise/group128 in both fp16 and bf16, and + # group32 microscaling (MXFP8) in bf16. + if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 2, 8]: continue # nvfp4 only supports group_size == 16 # mxfp4 only supports group_size == 32 @@ -118,45 +170,57 @@ def render_templates_for_combo(scalar_type: str, dtype: str) -> list[str]: if dtype == "fp16": # we cannot safely dequantize e8m0 to fp16, so skip this continue + elif scalar_type == "vllm::kFE4M3fn" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # MXFP8 is only supported with bf16 compute. + continue elif dtype == "fp16": s_type = "vllm::kFloat16" elif dtype == "bf16": s_type = "vllm::kBFloat16" for is_zp_float in is_zp_float_list: - template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, - thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=is_zp_float, - ) - - results.append(template_str) + for stage_value in stage_values: + if ( + stage_value == 2 + and _is_4bit_weight(scalar_type) + and max(m_blocks, 1) * 2 > k_blocks + ): + # Our dense Turing kernels need enough B-stage capacity to + # cover the output tile. For 4-bit weights this rules out + # the larger M tiles when thread_k_blocks == 4. + continue + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages=stage_value, + group_blocks=group_blocks, + is_zp_float=is_zp_float, + ) + + results.append(template_str) return results -def generate_new_kernels() -> None: - emitted = False - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): - templates = render_templates_for_combo(scalar_type, dtype) - if not templates: - continue - - _write_kernel_file(scalar_type, dtype, templates) - emitted = True - - if not emitted: - raise RuntimeError("No marlin kernels were generated; check template configuration.") +def generate_new_kernels(root: Path | None = None) -> None: + root = root or Path(__file__).parent + expected = build_expected_kernels(root) + for path in root.glob("kernel_*.cu"): + if path not in expected: + path.unlink(missing_ok=True) + for output_path, rendered in expected.items(): + output_path.write_text(rendered, encoding="utf-8") if __name__ == "__main__": - remove_old_kernels() + if "--check" in sys.argv[1:]: + raise SystemExit(0 if generated_kernels_are_current() else 1) generate_new_kernels() diff --git a/gptqmodel_ext/marlin/gptq_marlin.cu b/gptqmodel_ext/marlin/gptq_marlin.cu index 702476a59..bd66f536a 100644 --- a/gptqmodel_ext/marlin/gptq_marlin.cu +++ b/gptqmodel_ext/marlin/gptq_marlin.cu @@ -23,7 +23,25 @@ #define MARLIN_NAMESPACE_NAME marlin #endif +#ifndef MARLIN_GEMM_EXPORT_NAME + #define MARLIN_GEMM_EXPORT_NAME gptq_marlin_gemm +#endif + +#ifndef MARLIN_ENABLE_FP16 + #define MARLIN_ENABLE_FP16 1 +#endif + +#ifndef MARLIN_ENABLE_BF16 + #define MARLIN_ENABLE_BF16 1 +#endif + +#if !MARLIN_ENABLE_FP16 && !MARLIN_ENABLE_BF16 + #error "At least one Marlin compute dtype must be enabled." +#endif + #include +#include +#include #ifndef MARLIN_SHARED_MEM_GUARD_BYTES # define MARLIN_SHARED_MEM_GUARD_BYTES 0 // 512 @@ -59,7 +77,7 @@ __global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, int const* __restrict__ perm_int_ptr, @@ -68,7 +86,7 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, } // namespace marlin -torch::Tensor gptq_marlin_gemm( +torch::Tensor MARLIN_GEMM_EXPORT_NAME( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, @@ -79,7 +97,7 @@ torch::Tensor gptq_marlin_gemm( int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, bool is_zp_float) { TORCH_CHECK_NOT_IMPLEMENTED(false, - "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + "marlin_gemm(..) requires CUDA_ARCH >= 7.5"); return torch::empty({1, 1}); } @@ -163,14 +181,80 @@ thread_config_t large_batch_thread_configs[] = { {64, 128, 128}, {128, 64, 128}}; +thread_config_t full_sm80_small_batch_thread_configs[] = { + // Near-full GA100 boards benefit from lighter tiles to create more work + // across the observed 124 SMs without relying on 2 blocks/SM. + {64, 128, 128}, + {128, 128, 256}, + {128, 64, 128}}; + +thread_config_t full_sm80_large_batch_thread_configs[] = { + {64, 128, 128}, + {64, 256, 256}, + {128, 64, 128}}; + +thread_config_t full_sm80_heavy_batch_thread_configs[] = { + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + typedef struct { int blocks_per_sm; thread_config_t tb_cfg; } exec_config_t; +bool marlin_prefers_full_sm80(int major_capability, int minor_capability, + int sms) { + return major_capability == 8 && minor_capability == 0 && sms >= 124; +} + +int marlin_full_sm80_exact_thread_m_blocks(int prob_m) { + switch (prob_m) { + case 96: + return 2; + case 160: + return 2; + default: + return -1; + } +} + +struct marlin_device_info_t { + int sms = 0; + int max_shared_mem = 0; + int major_capability = 0; + int minor_capability = 0; +}; + +marlin_device_info_t query_marlin_device_info(int device) { + marlin_device_info_t info; + cudaDeviceGetAttribute(&info.sms, cudaDevAttrMultiProcessorCount, device); + cudaDeviceGetAttribute(&info.max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, device); + cudaDeviceGetAttribute(&info.major_capability, + cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&info.minor_capability, + cudaDevAttrComputeCapabilityMinor, device); + return info; +} + +marlin_device_info_t get_marlin_device_info(int device) { + static std::mutex mutex; + static std::vector cache; + std::lock_guard lock(mutex); + if (device >= static_cast(cache.size())) { + cache.resize(device + 1); + } + marlin_device_info_t& info = cache[device]; + if (info.sms == 0) { + info = query_marlin_device_info(device); + } + return info; +} + int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, - bool has_act_order, bool is_k_full) { + bool has_act_order, bool is_k_full, int stages) { bool cache_scales_chunk = has_act_order && !is_k_full; int tb_n = th_config.thread_n; @@ -188,28 +272,28 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, if (cache_scales_chunk) { int load_groups = - tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + tb_groups * stages * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups return load_groups * tb_n * 2; } else { int tb_scales = tb_groups * tb_n * 2; - return tb_scales * pipe_stages; + return tb_scales * stages; } } int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float) { + int has_zp, int is_zp_float, int stages) { int pack_factor = 32 / num_bits; // Get B size int tb_k = th_config.thread_k; int tb_n = th_config.thread_n; int tb_m = thread_m_blocks * 16; - int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; - int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_a_size = stages * (tb_m * tb_k) * 2; + int sh_b_size = stages * (tb_k * tb_n / pack_factor) * 4; int sh_red_size = tb_m * (tb_n + 8) * 2; int sh_bias_size = tb_n * 2; int tmp_size = @@ -218,8 +302,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, int sh_s_size = get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full); - int sh_g_idx_size = has_act_order && !is_k_full ? pipe_stages * tb_k / 4 : 0; + group_size, has_act_order, is_k_full, stages); + int sh_g_idx_size = has_act_order && !is_k_full ? stages * tb_k / 4 : 0; int sh_zp_size = 0; if (has_zp) { if (is_zp_float) @@ -239,7 +323,8 @@ int get_kernel_cache_size(thread_config_t const& th_config, int thread_m_blocks, bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, int prob_m, int prob_n, int prob_k, int num_bits, int group_size, bool has_act_order, bool is_k_full, - int has_zp, int is_zp_float, int max_shared_mem) { + int has_zp, int is_zp_float, int stages, + int max_shared_mem) { // Sanity if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { @@ -261,29 +346,50 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, return false; } + // Our stage-2 dense 4-bit kernels need enough B-stage capacity to cover the + // output tile. Larger M tiles with thread_k == 64 are not emitted. + if (stages == 2 && num_bits == 4 && + thread_m_blocks * 2 > th_config.thread_k / 16) { + return false; + } + // Check that pipeline fits into cache int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, - has_act_order, is_k_full, has_zp, is_zp_float); + has_act_order, is_k_full, has_zp, is_zp_float, stages); return cache_size + MARLIN_SHARED_MEM_GUARD_BYTES <= max_shared_mem; } #define _GET_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ - M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + M_BLOCK_SIZE_8, STAGES, GROUP_BLOCKS, NUM_THREADS, \ + IS_ZP_FLOAT) \ else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \ thread_n_blocks == THREAD_N_BLOCKS && \ thread_k_blocks == THREAD_K_BLOCKS && \ m_block_size_8 == M_BLOCK_SIZE_8 && \ - group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS && \ + stages == STAGES && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS && \ is_zp_float == IS_ZP_FLOAT) { \ - constexpr auto S_TYPE = \ - W_TYPE == vllm::kFE2M1f \ - ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ - : (std::is_same::value ? vllm::kFloat16 \ - : vllm::kBFloat16); \ - kernel = Marlin; \ + constexpr bool kIsStage2FourBitTile = \ + STAGES == 2 && \ + (W_TYPE == vllm::kU4 || W_TYPE == vllm::kU4B8 || \ + W_TYPE == vllm::kFE2M1f); \ + constexpr bool kIsSupportedStage2Tile = \ + !kIsStage2FourBitTile || \ + THREAD_M_BLOCKS * 2 <= THREAD_K_BLOCKS; \ + if constexpr (kIsSupportedStage2Tile) { \ + constexpr auto S_TYPE = \ + W_TYPE == vllm::kFE2M1f \ + ? (GROUP_BLOCKS == 1 ? vllm::kFE4M3fn : vllm::kFE8M0fnu) \ + : ((W_TYPE == vllm::kFE4M3fn && GROUP_BLOCKS == 2) \ + ? vllm::kFE8M0fnu \ + : (std::is_same::value \ + ? vllm::kFloat16 \ + : vllm::kBFloat16)); \ + kernel = Marlin; \ + } \ } // COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) @@ -292,131 +398,209 @@ bool is_valid_config(thread_config_t const& th_config, int thread_m_blocks, // FZP: cases for float-zero-point (is_zp_float = true) // ACT: cases for act order case (group_blocks == 0) // FP4: cases for nvfp4(e2m1) (group_blocks == 1) - #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define COMMON_GET_IF(W_TYPE) \ - COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ - COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ - COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ - COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ - COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) - - #define BIGGROUP_GET_IF(W_TYPE) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) - - #define NVFP4_GET_IF(W_TYPE) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128) - - #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) - - #define MXFP4_GET_IF(W_TYPE) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256) \ - MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ - MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ - MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128) + #define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 4, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 8, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) + + #define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, \ + STAGES) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) + + #define COMMON_GET_IF(W_TYPE, STAGES) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256, STAGES) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128, STAGES) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128, STAGES) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256, STAGES) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128, STAGES) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128, STAGES) + + #define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, \ + STAGES) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 8, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) + + #define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, \ + STAGES) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, -1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 8, NUM_THREADS, \ + false) + + #define BIGGROUP_GET_IF(W_TYPE, STAGES) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256, STAGES) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128, STAGES) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128, STAGES) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256, STAGES) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128, STAGES) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128, STAGES) + + #define NVFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 1, NUM_THREADS, \ + false) + + #define NVFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, \ + STAGES) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 1, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 1, NUM_THREADS, \ + false) + + #define NVFP4_GET_IF(W_TYPE, STAGES) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 8, 256, STAGES) \ + NVFP4_GET_IF_M1(W_TYPE, 8, 4, 128, STAGES) \ + NVFP4_GET_IF_M1(W_TYPE, 4, 8, 128, STAGES) \ + NVFP4_GET_IF_M234(W_TYPE, 16, 4, 256, STAGES) \ + NVFP4_GET_IF_M234(W_TYPE, 8, 4, 128, STAGES) \ + NVFP4_GET_IF_M234(W_TYPE, 4, 8, 128, STAGES) + + #define MXFP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) + + #define MXFP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, \ + STAGES) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) + + #define MXFP4_GET_IF(W_TYPE, STAGES) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 8, 256, STAGES) \ + MXFP4_GET_IF_M1(W_TYPE, 8, 4, 128, STAGES) \ + MXFP4_GET_IF_M1(W_TYPE, 4, 8, 128, STAGES) \ + MXFP4_GET_IF_M234(W_TYPE, 16, 4, 256, STAGES) \ + MXFP4_GET_IF_M234(W_TYPE, 8, 4, 128, STAGES) \ + MXFP4_GET_IF_M234(W_TYPE, 4, 8, 128, STAGES) + + #define MXFP8_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) + + #define MXFP8_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, \ + STAGES) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 2, NUM_THREADS, \ + false) + + #define MXFP8_GET_IF(W_TYPE, STAGES) \ + MXFP8_GET_IF_M1(W_TYPE, 8, 8, 256, STAGES) \ + MXFP8_GET_IF_M1(W_TYPE, 8, 4, 128, STAGES) \ + MXFP8_GET_IF_M1(W_TYPE, 4, 8, 128, STAGES) \ + MXFP8_GET_IF_M234(W_TYPE, 16, 4, 256, STAGES) \ + MXFP8_GET_IF_M234(W_TYPE, 8, 4, 128, STAGES) \ + MXFP8_GET_IF_M234(W_TYPE, 4, 8, 128, STAGES) // We currently have 4-bit models only with group_blocks == 4 - #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) - - #define FZP_GET_IF(W_TYPE) \ - FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ - FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ - FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ - FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ - FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + #define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 4, NUM_THREADS, \ + true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + true) + + #define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 4, NUM_THREADS, \ + true) + + #define FZP_GET_IF(W_TYPE, STAGES) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256, STAGES) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128, STAGES) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128, STAGES) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256, STAGES) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128, STAGES) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128, STAGES) // We currently have 4-bit models only with group_blocks == 4 - #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ - _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ - _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) - - #define ACT_GET_IF(W_TYPE) \ - ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ - ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ - ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ - ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ - ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + #define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, STAGES, 0, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, STAGES, 0, NUM_THREADS, \ + false) + + #define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS, STAGES) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, STAGES, 0, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, STAGES, 0, NUM_THREADS, \ + false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, STAGES, 0, NUM_THREADS, \ + false) + + #define ACT_GET_IF(W_TYPE, STAGES) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256, STAGES) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128, STAGES) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128, STAGES) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256, STAGES) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128, STAGES) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128, STAGES) template MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, @@ -424,32 +608,60 @@ MarlinFuncPtr get_marlin_kernel(const vllm::ScalarType q_type, int thread_k_blocks, bool m_block_size_8, bool has_act_order, bool has_zp, int group_blocks, int num_threads, - bool is_zp_float) { + bool is_zp_float, int stages) { int num_bits = q_type.size_bits(); auto kernel = MarlinDefault; - if (false) { + if constexpr (std::is_same::value) { + if (stages == 2) { + if (false) { + } + COMMON_GET_IF(vllm::kU4, 2) + COMMON_GET_IF(vllm::kU4B8, 2) + COMMON_GET_IF(vllm::kU8B128, 2) + + NVFP4_GET_IF(vllm::kFE2M1f, 2) + + BIGGROUP_GET_IF(vllm::kFE4M3fn, 2) + + ACT_GET_IF(vllm::kU4B8, 2) + ACT_GET_IF(vllm::kU8B128, 2) + } } - COMMON_GET_IF(vllm::kU4) - COMMON_GET_IF(vllm::kU4B8) - COMMON_GET_IF(vllm::kU8B128) + if (stages == pipe_stages) { + if (false) { + } + COMMON_GET_IF(vllm::kU4, pipe_stages) + COMMON_GET_IF(vllm::kU4B8, pipe_stages) + COMMON_GET_IF(vllm::kU8B128, pipe_stages) - NVFP4_GET_IF(vllm::kFE2M1f) + NVFP4_GET_IF(vllm::kFE2M1f, pipe_stages) - BIGGROUP_GET_IF(vllm::kFE4M3fn) + BIGGROUP_GET_IF(vllm::kFE4M3fn, pipe_stages) - ACT_GET_IF(vllm::kU4B8) - ACT_GET_IF(vllm::kU8B128) + ACT_GET_IF(vllm::kU4B8, pipe_stages) + ACT_GET_IF(vllm::kU8B128, pipe_stages) + } - if (std::is_same::value) { - if (false) { + if constexpr (std::is_same::value) { + if (stages == 2) { + if (false) { + } + FZP_GET_IF(vllm::kU4, 2) + } + if (stages == pipe_stages) { + if (false) { + } + FZP_GET_IF(vllm::kU4, pipe_stages) } - FZP_GET_IF(vllm::kU4) } - if (std::is_same::value) { - if (false) { + if constexpr (std::is_same::value) { + if (stages == pipe_stages) { + if (false) { + } + MXFP8_GET_IF(vllm::kFE4M3fn, pipe_stages) + MXFP4_GET_IF(vllm::kFE2M1f, pipe_stages) } - MXFP4_GET_IF(vllm::kFE2M1f) } return kernel; @@ -461,29 +673,48 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, bool m_block_size_8, int num_bits, int group_size, bool has_act_order, bool is_k_full, bool has_zp, - bool is_zp_float, int max_shared_mem, - int sms) { + bool is_zp_float, int stages, + int max_shared_mem, int major_capability, + int minor_capability, int sms) { exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; - thread_config_t* thread_configs = thread_m_blocks > 1 - ? large_batch_thread_configs - : small_batch_thread_configs; + bool full_sm80 = + marlin_prefers_full_sm80(major_capability, minor_capability, sms); + bool heavy_batch = full_sm80 && prob_m >= 128; + thread_config_t const* thread_configs = + prob_m > 16 + ? (heavy_batch ? full_sm80_heavy_batch_thread_configs + : (full_sm80 ? full_sm80_large_batch_thread_configs + : large_batch_thread_configs)) + : (full_sm80 ? full_sm80_small_batch_thread_configs + : small_batch_thread_configs); int thread_configs_size = - thread_m_blocks > 1 - ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) - : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + prob_m > 16 + ? (heavy_batch + ? static_cast(sizeof(full_sm80_heavy_batch_thread_configs) / + sizeof(thread_config_t)) + : (full_sm80 + ? static_cast(sizeof(full_sm80_large_batch_thread_configs) / + sizeof(thread_config_t)) + : static_cast(sizeof(large_batch_thread_configs) / + sizeof(thread_config_t)))) + : (full_sm80 + ? static_cast(sizeof(full_sm80_small_batch_thread_configs) / + sizeof(thread_config_t)) + : static_cast(sizeof(small_batch_thread_configs) / + sizeof(thread_config_t))); for (int i = 0; i < thread_configs_size; i++) { thread_config_t th_config = thread_configs[i]; if (!is_valid_config(th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, has_zp, - is_zp_float, max_shared_mem)) { + is_zp_float, stages, max_shared_mem)) { continue; } int cache_size = get_kernel_cache_size( th_config, thread_m_blocks, prob_m, prob_n, prob_k, num_bits, - group_size, has_act_order, is_k_full, has_zp, is_zp_float); + group_size, has_act_order, is_k_full, has_zp, is_zp_float, stages); int group_blocks = 0; if (!has_act_order) { @@ -493,7 +724,7 @@ exec_config_t determine_exec_config(const vllm::ScalarType& q_type, int prob_m, auto kernel = get_marlin_kernel( q_type, thread_m_blocks, th_config.thread_n / 16, th_config.thread_k / 16, m_block_size_8, has_act_order, has_zp, - group_blocks, th_config.num_threads, is_zp_float); + group_blocks, th_config.num_threads, is_zp_float, stages); if (kernel == MarlinDefault) continue; @@ -560,7 +791,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, int4* C_tmp_ptr = (int4*)C_tmp; const int4* bias_ptr = (const int4*)b_bias; const int4* s_ptr = (const int4*)s; - const uint16_t* s2_ptr = (const uint16_t*)s2; + const float* s2_ptr = (const float*)s2; const int4* zp_ptr = (const int4*)zp; const int* g_idx_ptr = (const int*)g_idx; const int* perm_ptr = (const int*)perm; @@ -570,10 +801,11 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, if (has_act_order) { // Permute A columns - int block_rows = div_ceil(prob_m, sms); + int permute_blocks = min(sms, max(prob_m, 1)); + int block_rows = div_ceil(prob_m, permute_blocks); // avoid ">>>" being formatted to "> > >" // clang-format off - permute_cols_kernel<<>>( + permute_cols_kernel<<>>( A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); // clang-format on A_ptr = a_tmp_ptr; @@ -585,26 +817,59 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, if (is_k_full) has_act_order = false; } - int max_shared_mem = 0; - cudaDeviceGetAttribute(&max_shared_mem, - cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + marlin_device_info_t device_info = get_marlin_device_info(dev); + int max_shared_mem = device_info.max_shared_mem; TORCH_CHECK(max_shared_mem > 0); + int major_capability = device_info.major_capability; + int minor_capability = device_info.minor_capability; + TORCH_CHECK(major_capability > 7 || + (major_capability == 7 && minor_capability >= 5), + "marlin kernel only supports Turing or newer GPUs."); + + int stages = pipe_stages; + if (major_capability == 7 && minor_capability == 5) { + stages = 2; + if constexpr (!std::is_same::value) { + TORCH_CHECK(false, "Turing only supports float16 dense Marlin kernels."); + } + } + int max_par = 16; if (prob_n <= 4096) max_par = 16 * 8; int max_shared_mem_new = max_shared_mem; int rest_m = prob_m; int max_thread_m_blocks = 4; + bool full_sm80 = + marlin_prefers_full_sm80(major_capability, minor_capability, sms); + bool disable_full_sm80_exact_split = false; while (rest_m) { - int par_count = rest_m / (max_thread_m_blocks * 16); - if (par_count > max_par) par_count = max_par; - int prob_m_split = - par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; - int thread_k = thread_k_init; int thread_n = thread_n_init; + bool manual_override = thread_k != -1 && thread_n != -1; + int attempt_thread_m_blocks = max_thread_m_blocks; + bool force_exact_split = false; + // On the local 124-SM sm_80 boards, a few exact-M shapes consistently beat + // the generic split path when we keep all M tiles in one launch. This + // avoids tiny remainder launches such as 128+32. + if (!manual_override && full_sm80 && !disable_full_sm80_exact_split) { + int exact_thread_m_blocks = marlin_full_sm80_exact_thread_m_blocks(rest_m); + if (exact_thread_m_blocks > 0 && + rest_m % (exact_thread_m_blocks * 16) == 0 && + attempt_thread_m_blocks > exact_thread_m_blocks) { + attempt_thread_m_blocks = exact_thread_m_blocks; + force_exact_split = true; + } + } - int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int par_count = rest_m / (attempt_thread_m_blocks * 16); + if (par_count > max_par) par_count = max_par; + int prob_m_split = + par_count > 0 ? (par_count * (attempt_thread_m_blocks * 16)) : rest_m; + if (force_exact_split) prob_m_split = rest_m; + + int thread_m_blocks = + min(div_ceil(prob_m_split, 16), attempt_thread_m_blocks); int m_block_size_8 = prob_m_split <= 8; // Set thread config @@ -622,10 +887,15 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, exec_cfg = determine_exec_config( q_type, prob_m_split, prob_n, prob_k, thread_m_blocks, m_block_size_8, num_bits, group_size, has_act_order, is_k_full, has_zp, is_zp_float, - max_shared_mem, sms); + stages, max_shared_mem, major_capability, minor_capability, sms); thread_tfg = exec_cfg.tb_cfg; - if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { - max_thread_m_blocks--; + if (thread_tfg.thread_k == -1 && force_exact_split) { + disable_full_sm80_exact_split = true; + max_thread_m_blocks = 4; + continue; + } + if (thread_tfg.thread_k == -1 && attempt_thread_m_blocks > 1) { + max_thread_m_blocks = attempt_thread_m_blocks - 1; continue; } } @@ -645,7 +915,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, TORCH_CHECK( is_valid_config(thread_tfg, thread_m_blocks, prob_m_split, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full, - has_zp, is_zp_float, max_shared_mem_new), + has_zp, is_zp_float, stages, max_shared_mem_new), "Invalid thread config: thread_m_blocks = ", thread_m_blocks, ", thread_k = ", thread_tfg.thread_k, ", thread_n = ", thread_tfg.thread_n, @@ -654,12 +924,12 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, ", prob_m_split = ", prob_m_split, ", group_size = ", group_size, ", has_act_order = ", has_act_order, ", is_k_full = ", is_k_full, ", has_zp = ", has_zp, ", is_zp_float = ", is_zp_float, - ", max_shared_mem_new = ", max_shared_mem_new); + ", stages = ", stages, ", max_shared_mem_new = ", max_shared_mem_new); auto kernel = get_marlin_kernel( q_type, thread_m_blocks, thread_n_blocks, thread_k_blocks, m_block_size_8, has_act_order, has_zp, group_blocks, num_threads, - is_zp_float); + is_zp_float, stages); if (kernel == MarlinDefault) { TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n, @@ -695,7 +965,7 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* b_bias, } // namespace marlin -torch::Tensor gptq_marlin_gemm( +torch::Tensor MARLIN_GEMM_EXPORT_NAME( torch::Tensor& a, std::optional c_or_none, torch::Tensor& b_q_weight, std::optional const& b_bias_or_none, torch::Tensor& b_scales, @@ -752,8 +1022,7 @@ torch::Tensor gptq_marlin_gemm( // auto -1) int thread_n = -1; // sms: number of SMs to use for the kernel - int sms = -1; - cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); + int sms = marlin::get_marlin_device_info(a.get_device()).sms; // Alloc buffers const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); @@ -848,7 +1117,7 @@ torch::Tensor gptq_marlin_gemm( TORCH_CHECK(b_q_type == vllm::kFE2M1f && group_size == 16, "global_scale can only be used for nvfp4 format."); } else { - global_scale = torch::empty({0}, options); + global_scale = torch::empty({0}, options_fp32); TORCH_CHECK(!(b_q_type == vllm::kFE2M1f && group_size == 16), "the global_scale parameter must be passed for nvfp4 format."); } @@ -926,6 +1195,9 @@ torch::Tensor gptq_marlin_gemm( " is below min_workspace_size = ", min_workspace_size); int dev = a.get_device(); + TORCH_CHECK(global_scale.scalar_type() == at::ScalarType::Float, + "scalar type of global_scale must be float"); + #if MARLIN_ENABLE_FP16 if (a.scalar_type() == at::ScalarType::Half) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { @@ -942,6 +1214,12 @@ torch::Tensor gptq_marlin_gemm( TORCH_CHECK(false, "float4_e2m1f only supports group_size == 16 (NVFP4) ", "and group_size == 32 (MXFP4)"); +#if HAS_FLOAT8_E8M0FNU + } else if (b_q_type == vllm::kFE4M3fn && + b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + TORCH_CHECK(false, "float8_e4m3fn with float8_e8m0fnu scales requires " + "bfloat16 compute (MXFP8)."); +#endif } else { scales_ptr = b_scales.data_ptr(); } @@ -949,13 +1227,18 @@ torch::Tensor gptq_marlin_gemm( marlin::marlin_mm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); - } else if (a.scalar_type() == at::ScalarType::BFloat16) { + return c; + } + #endif + + #if MARLIN_ENABLE_BF16 + if (a.scalar_type() == at::ScalarType::BFloat16) { void* scales_ptr; if (b_q_type == vllm::kFE2M1f) { if (group_size == 16) @@ -971,6 +1254,14 @@ torch::Tensor gptq_marlin_gemm( TORCH_CHECK(false, "float4_e2m1f only supports group_size == 16 (NVFP4) ", "and group_size == 32 (MXFP4)"); +#if HAS_FLOAT8_E8M0FNU + } else if (b_q_type == vllm::kFE4M3fn && + b_scales.scalar_type() == at::ScalarType::Float8_e8m0fnu) { + TORCH_CHECK(group_size == 32, + "float8_e4m3fn only supports group_size == 32 (MXFP8) when " + "using float8_e8m0fnu scales."); + scales_ptr = b_scales.data_ptr(); +#endif } else { scales_ptr = b_scales.data_ptr(); } @@ -979,15 +1270,23 @@ torch::Tensor gptq_marlin_gemm( a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), c_tmp.data_ptr(), b_bias.data_ptr(), scales_ptr, - global_scale.data_ptr(), b_zeros.data_ptr(), + global_scale.data_ptr(), b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, size_k, a.stride(0), workspace.data_ptr(), b_q_type, has_bias, has_act_order, is_k_full, has_zp, num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, use_atomic_add, use_fp32_reduce, is_zp_float); - } else { - TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + return c; } + #endif + + #if MARLIN_ENABLE_FP16 && MARLIN_ENABLE_BF16 + TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16"); + #elif MARLIN_ENABLE_FP16 + TORCH_CHECK(false, "gpt_marlin_gemm_fp16 only supports float16"); + #else + TORCH_CHECK(false, "gpt_marlin_gemm_bf16 only supports bfloat16"); + #endif return c; } diff --git a/gptqmodel_ext/marlin/gptq_marlin_bf16.cu b/gptqmodel_ext/marlin/gptq_marlin_bf16.cu new file mode 100644 index 000000000..4bc4ff239 --- /dev/null +++ b/gptqmodel_ext/marlin/gptq_marlin_bf16.cu @@ -0,0 +1,5 @@ +#define MARLIN_GEMM_EXPORT_NAME gptq_marlin_gemm_bf16 +#define MARLIN_ENABLE_FP16 0 +#define MARLIN_ENABLE_BF16 1 + +#include "gptq_marlin.cu" diff --git a/gptqmodel_ext/marlin/gptq_marlin_fp16.cu b/gptqmodel_ext/marlin/gptq_marlin_fp16.cu new file mode 100644 index 000000000..46a0f35de --- /dev/null +++ b/gptqmodel_ext/marlin/gptq_marlin_fp16.cu @@ -0,0 +1,5 @@ +#define MARLIN_GEMM_EXPORT_NAME gptq_marlin_gemm_fp16 +#define MARLIN_ENABLE_FP16 1 +#define MARLIN_ENABLE_BF16 0 + +#include "gptq_marlin.cu" diff --git a/gptqmodel_ext/marlin/kernel.h b/gptqmodel_ext/marlin/kernel.h index 8357fa351..6fc854853 100644 --- a/gptqmodel_ext/marlin/kernel.h +++ b/gptqmodel_ext/marlin/kernel.h @@ -12,7 +12,7 @@ int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ const int4 *__restrict__ b_bias_ptr, \ const int4 *__restrict__ scales_ptr, \ - const uint16_t *__restrict__ scale2_ptr, \ + const float *__restrict__ global_scale_ptr, \ const int4 *__restrict__ zp_ptr, const int *__restrict__ g_idx, \ int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ bool has_bias, bool use_atomic_add, bool use_fp32_reduce, \ @@ -38,4 +38,4 @@ template __global__ void Marlin(MARLIN_KERNEL_PARAMS); -} \ No newline at end of file +} diff --git a/gptqmodel_ext/marlin/marlin.cuh b/gptqmodel_ext/marlin/marlin.cuh index 32487abe7..d0e471eaf 100644 --- a/gptqmodel_ext/marlin/marlin.cuh +++ b/gptqmodel_ext/marlin/marlin.cuh @@ -1,17 +1,19 @@ #pragma once -#include +#ifndef _marlin_cuh + #define _marlin_cuh + #include -#include -#include -#include -#include -#include -#include + #include + #include + #include + #include + #include + #include -#ifndef MARLIN_NAMESPACE_NAME - #define MARLIN_NAMESPACE_NAME marlin -#endif + #ifndef MARLIN_NAMESPACE_NAME + #define MARLIN_NAMESPACE_NAME marlin + #endif namespace MARLIN_NAMESPACE_NAME { @@ -51,9 +53,90 @@ using I4 = Vec; constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 -// No support for async -#else + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + if (pred) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; + } +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + reinterpret_cast(smem_ptr)[0] = + reinterpret_cast(glob_ptr)[0]; +} + +__device__ inline void cp_async_fence() {} + +template +__device__ inline void cp_async_wait() {} + + #else + +__device__ inline void cp_async1_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 4; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async2_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 8; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_ca_pred(void* smem_ptr, const void* glob_ptr, + bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.ca.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} __device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { @@ -87,6 +170,8 @@ __device__ inline void cp_async_wait() { asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); } -#endif + #endif + +} // namespace MARLIN_NAMESPACE_NAME -} // namespace MARLIN_NAMESPACE_NAME \ No newline at end of file +#endif diff --git a/gptqmodel_ext/marlin/marlin_mma.h b/gptqmodel_ext/marlin/marlin_mma.h new file mode 100644 index 000000000..ab5ad40c0 --- /dev/null +++ b/gptqmodel_ext/marlin/marlin_mma.h @@ -0,0 +1,155 @@ +#ifndef MARLIN_MMA_H_ +#define MARLIN_MMA_H_ + +#include "marlin_dtypes.cuh" + +namespace MARLIN_NAMESPACE_NAME { + +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + + if constexpr (!std::is_same::value) { + static_assert(!use_fp16_accum); + } + + if constexpr (std::is_same::value && !use_fp16_accum) { + float* c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[2]), "r"(a[3]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +#endif + } else if constexpr (std::is_same::value && + use_fp16_accum) { + uint32_t* c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(b[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[2]), "r"(a[3]), "r"(b[1]), "r"(c[0]), "r"(c[1])); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1])); +#endif + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + static_assert(std::is_same::value || + std::is_same::value, + "only float16 and bfloat16 is supported"); + } +} + +template +__device__ inline void mma_trans( + const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + + if constexpr (!std::is_same::value) { + static_assert(!use_fp16_accum); + } + + if constexpr (std::is_same::value && !use_fp16_accum) { + float* c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5}, {%6}, {%7,%8,%9,%10};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +#endif + } else if constexpr (std::is_same::value && + use_fp16_accum) { + uint32_t* c = reinterpret_cast(&frag_c); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(b2[0]), "r"(a[0]), "r"(c[0]), "r"(c[1])); + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3}, {%4}, {%5,%6};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[1]), "r"(b2[1]), "r"(a[1]), "r"(c[0]), "r"(c[1])); +#else + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(c[0]), "=r"(c[1]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "r"(c[0]), "r"(c[1])); +#endif + } else if constexpr (std::is_same::value) { + float* c = reinterpret_cast(&frag_c); + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), + "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); + } else { + static_assert(std::is_same::value || + std::is_same::value, + "only float16 and bfloat16 is supported"); + } +} + +} // namespace MARLIN_NAMESPACE_NAME + +#endif // MARLIN_MMA_H_ diff --git a/gptqmodel_ext/marlin/marlin_template.h b/gptqmodel_ext/marlin/marlin_template.h index f3b4e0242..fd40978cd 100644 --- a/gptqmodel_ext/marlin/marlin_template.h +++ b/gptqmodel_ext/marlin/marlin_template.h @@ -26,6 +26,7 @@ #include "marlin.cuh" #include "marlin_dtypes.cuh" #include "dequant.h" +#include "marlin_mma.h" #include "core/scalar_type.hpp" #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -35,7 +36,7 @@ namespace MARLIN_NAMESPACE_NAME { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750 template -__device__ inline void mma(const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - -template -__device__ inline void mma_trans( - const typename ScalarType::FragA& a_frag, - const typename ScalarType::FragB& frag_b, - const typename ScalarType::FragB& frag_b2, - typename ScalarType::FragC& frag_c) { - const uint32_t* a = reinterpret_cast(&a_frag); - const uint32_t* b = reinterpret_cast(&frag_b); - const uint32_t* b2 = reinterpret_cast(&frag_b2); - float* c = reinterpret_cast(&frag_c); - if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else if constexpr (std::is_same::value) { - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) - : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), - "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); - } else { - STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); - } -} - // Instruction for loading a full 16x16 matrix fragment of operand A from shared // memory, directly in tensor core layout. template @@ -295,8 +239,8 @@ __global__ void Marlin( const int4* __restrict__ b_bias_ptr, const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape // (k/groupsize)xn - const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 - // only) + const float* __restrict__ global_scale_ptr, // fp32 global scale (for + // nvfp4 only) const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape // (k/groupsize)x(n/pack_factor) const int* __restrict__ g_idx, // int32 group indices of shape k @@ -331,9 +275,26 @@ __global__ void Marlin( static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id); static constexpr auto s_type = vllm::ScalarType::from_id(s_type_id); +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + constexpr auto num_bits = vllm::ScalarType::from_id(w_type_id).size_bits(); + constexpr bool use_fp16_accum = + std::is_same::value && + (!(w_type_id == vllm::kFE2M1f.id() && s_type_id == vllm::kFE4M3fn.id()) && + !(group_blocks == -1 && num_bits == 4)); +#else + constexpr bool use_fp16_accum = false; +#endif + if constexpr (std::is_same::value) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750 + return; +#endif + } if constexpr (w_type == vllm::kFE2M1f) { static_assert(s_type == vllm::kFE4M3fn && group_blocks == 1 || s_type == vllm::kFE8M0fnu && group_blocks == 2); + } else if constexpr (s_type == vllm::kFE8M0fnu) { + // MXFP8: FP8 weights with e8m0 microscaling block scales. + static_assert(w_type == vllm::kFE4M3fn && group_blocks == 2); } else if constexpr (std::is_same::value) { static_assert(s_type == vllm::kBFloat16); } else if constexpr (std::is_same::value) { @@ -345,16 +306,15 @@ __global__ void Marlin( w_type == vllm::kU4B8 || w_type == vllm::kU8B128; // see comments of dequant.h for more details constexpr bool dequant_skip_flop = - w_type == vllm::kFE4M3fn || + w_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu) || w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn || has_zp && !is_zp_float && !std::is_same::value || has_zp && !is_zp_float && !(w_type == vllm::kU8); - scalar_t2 global_scale; + float global_scale_f32 = 1.0f; if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { // NVFP4 format requires global scale - uint16_t val = scale2_ptr[0]; - global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + global_scale_f32 = global_scale_ptr[0]; } constexpr bool has_act_order = group_blocks == 0; @@ -1177,7 +1137,7 @@ __global__ void Marlin( } } - if constexpr (w_type == vllm::kFE2M1f) { + if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu) { int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; @@ -1246,10 +1206,13 @@ __global__ void Marlin( #pragma unroll for (int i = 0; i < thread_m_blocks; i++) { if constexpr (m_block_size_8) { - mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + mma_trans( + frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); } else { - mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); - mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + mma( + frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma( + frag_a[k2][i], frag_b1, frag_c[i][j][1]); } } } @@ -1498,7 +1461,8 @@ __global__ void Marlin( } if constexpr (w_type == vllm::kFE2M1f && s_type == vllm::kFE4M3fn) { - res = __hmul2(res, global_scale); + c0 *= global_scale_f32; + c1 *= global_scale_f32; } if (has_bias && last) { scalar_t2 tmp_bias = b_bias[0]; @@ -1668,6 +1632,21 @@ __global__ void Marlin( // While this pattern may not be the most readable, other ways of writing // the loop seemed to noticeably worse performance after compilation. if (slice_iters == 0) { + if constexpr (use_fp16_accum) { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2; i++) { + float* frag_c_part_float = reinterpret_cast(frag_c) + i * 4; + scalar_t* frag_c_part_half = + reinterpret_cast(frag_c_part_float); + +#pragma unroll + for (int j = 3; j >= 0; j--) { + frag_c_part_float[j] = ScalarType::num2float( + frag_c_part_half[j]); + } + } + } + cp_async_wait<0>(); bool last = slice_idx == slice_count - 1; // For per-column scales, we only fetch them here in the final step before @@ -1805,4 +1784,4 @@ __global__ void Marlin( } // namespace MARLIN_NAMESPACE_NAME -#endif \ No newline at end of file +#endif diff --git a/gptqmodel_ext/marlin/marlin_torch_bf16.cpp b/gptqmodel_ext/marlin/marlin_torch_bf16.cpp new file mode 100644 index 000000000..4dd489d1c --- /dev/null +++ b/gptqmodel_ext/marlin/marlin_torch_bf16.cpp @@ -0,0 +1,67 @@ +#include + +#include "awq_marlin_repack.cuh" +#include "core/scalar_type.hpp" +#include "gptq_marlin_repack.cuh" + +torch::Tensor gptq_marlin_gemm_bf16( + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float); + +namespace { + +torch::Tensor gptq_marlin_gemm_bf16_dispatch( + torch::Tensor a, std::optional c_or_none, + torch::Tensor b_q_weight, std::optional const& b_bias_or_none, + torch::Tensor b_scales, + std::optional const& global_scale_or_none, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor workspace, + int64_t b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + return gptq_marlin_gemm_bf16( + a, c_or_none, b_q_weight, b_bias_or_none, b_scales, global_scale_or_none, + b_zeros_or_none, g_idx_or_none, perm_or_none, workspace, + static_cast(b_q_type_id), size_m, size_n, size_k, + is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float); +} + +torch::Tensor gptq_marlin_repack_dispatch(torch::Tensor b_q_weight, + torch::Tensor perm, int64_t size_k, + int64_t size_n, int64_t num_bits) { + return gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits); +} + +torch::Tensor awq_marlin_repack_dispatch(torch::Tensor b_q_weight, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + return awq_marlin_repack(b_q_weight, size_k, size_n, num_bits); +} + +} // namespace + +TORCH_LIBRARY(gptqmodel_marlin_bf16, m) { + m.def( + "gptq_marlin_gemm_bf16(Tensor a, Tensor? c, Tensor b_q_weight, Tensor? b_bias, Tensor b_scales, " + "Tensor? global_scale, Tensor? b_zeros, Tensor? g_idx, Tensor? perm, Tensor workspace, int b_q_type_id, " + "int size_m, int size_n, int size_k, bool is_k_full=True, bool use_atomic_add=False, " + "bool use_fp32_reduce=False, bool is_zp_float=False) -> Tensor"); + m.def("gptq_marlin_repack(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor"); + m.def("awq_marlin_repack(Tensor b_q_weight, int size_k, int size_n, int num_bits) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_marlin_bf16, CUDA, m) { + m.impl("gptq_marlin_gemm_bf16", &gptq_marlin_gemm_bf16_dispatch); + m.impl("gptq_marlin_repack", &gptq_marlin_repack_dispatch); + m.impl("awq_marlin_repack", &awq_marlin_repack_dispatch); +} diff --git a/gptqmodel_ext/marlin/marlin_torch_fp16.cpp b/gptqmodel_ext/marlin/marlin_torch_fp16.cpp new file mode 100644 index 000000000..70f448cb5 --- /dev/null +++ b/gptqmodel_ext/marlin/marlin_torch_fp16.cpp @@ -0,0 +1,67 @@ +#include + +#include "awq_marlin_repack.cuh" +#include "core/scalar_type.hpp" +#include "gptq_marlin_repack.cuh" + +torch::Tensor gptq_marlin_gemm_fp16( + torch::Tensor& a, std::optional c_or_none, + torch::Tensor& b_q_weight, + std::optional const& b_bias_or_none, torch::Tensor& b_scales, + std::optional const& global_scale_or_none, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor& workspace, + vllm::ScalarTypeId const& b_q_type_id, int64_t size_m, int64_t size_n, + int64_t size_k, bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float); + +namespace { + +torch::Tensor gptq_marlin_gemm_fp16_dispatch( + torch::Tensor a, std::optional c_or_none, + torch::Tensor b_q_weight, std::optional const& b_bias_or_none, + torch::Tensor b_scales, + std::optional const& global_scale_or_none, + std::optional const& b_zeros_or_none, + std::optional const& g_idx_or_none, + std::optional const& perm_or_none, torch::Tensor workspace, + int64_t b_q_type_id, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, bool use_atomic_add, bool use_fp32_reduce, + bool is_zp_float) { + return gptq_marlin_gemm_fp16( + a, c_or_none, b_q_weight, b_bias_or_none, b_scales, global_scale_or_none, + b_zeros_or_none, g_idx_or_none, perm_or_none, workspace, + static_cast(b_q_type_id), size_m, size_n, size_k, + is_k_full, use_atomic_add, use_fp32_reduce, is_zp_float); +} + +torch::Tensor gptq_marlin_repack_dispatch(torch::Tensor b_q_weight, + torch::Tensor perm, int64_t size_k, + int64_t size_n, int64_t num_bits) { + return gptq_marlin_repack(b_q_weight, perm, size_k, size_n, num_bits); +} + +torch::Tensor awq_marlin_repack_dispatch(torch::Tensor b_q_weight, + int64_t size_k, int64_t size_n, + int64_t num_bits) { + return awq_marlin_repack(b_q_weight, size_k, size_n, num_bits); +} + +} // namespace + +TORCH_LIBRARY(gptqmodel_marlin_fp16, m) { + m.def( + "gptq_marlin_gemm_fp16(Tensor a, Tensor? c, Tensor b_q_weight, Tensor? b_bias, Tensor b_scales, " + "Tensor? global_scale, Tensor? b_zeros, Tensor? g_idx, Tensor? perm, Tensor workspace, int b_q_type_id, " + "int size_m, int size_n, int size_k, bool is_k_full=True, bool use_atomic_add=False, " + "bool use_fp32_reduce=False, bool is_zp_float=False) -> Tensor"); + m.def("gptq_marlin_repack(Tensor b_q_weight, Tensor perm, int size_k, int size_n, int num_bits) -> Tensor"); + m.def("awq_marlin_repack(Tensor b_q_weight, int size_k, int size_n, int num_bits) -> Tensor"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_marlin_fp16, CUDA, m) { + m.impl("gptq_marlin_gemm_fp16", &gptq_marlin_gemm_fp16_dispatch); + m.impl("gptq_marlin_repack", &gptq_marlin_repack_dispatch); + m.impl("awq_marlin_repack", &awq_marlin_repack_dispatch); +} diff --git a/gptqmodel_ext/paroquant/rotation.cu b/gptqmodel_ext/paroquant/rotation.cu new file mode 100644 index 000000000..5ceba1dc3 --- /dev/null +++ b/gptqmodel_ext/paroquant/rotation.cu @@ -0,0 +1,664 @@ +// SPDX-FileCopyrightText: 2026 ModelCloud.ai +// SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +// SPDX-License-Identifier: Apache-2.0 + +/****************************************************************************** + * Adapted from https://github.com/z-lab/paroquant + ******************************************************************************/ + +#include "rotation.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +template +__global__ void rotate_kernel(const scalar_t *__restrict__ x, scalar_t *__restrict__ out, + const int16_t *__restrict__ idx_ij, const scalar_t *__restrict__ theta, + const scalar_t *__restrict__ scales, int s, int h) { + constexpr int ROW_STRIDE = CTA_M + ROW_PAD; + __shared__ scalar_t x_grp[ROW_STRIDE * GROUP_SIZE]; + + int j = blockIdx.x; + int g = blockIdx.y; + int t = threadIdx.x; + + RotateAccess::template load_group( + x_grp, x, scales, s, h, j, g, t); + + float reg_theta[KROT]; + int reg_idx[KROT]; + RotateAccess::template load_coeffs(reg_theta, reg_idx, idx_ij, theta, + h, g, t); + __syncthreads(); + +#pragma unroll + for (int r = 0; r < KROT; r++) { + RotateAccess::template apply_one(x_grp, reg_idx[r], reg_theta[r]); + __syncthreads(); + } + + RotateAccess::template store_group(out, x_grp, s, h, j, + g, t); +} + +template +__global__ void rotate_kernel_bf16_half_workspace(const __nv_bfloat16 *__restrict__ x, + __nv_bfloat16 *__restrict__ out, + const int16_t *__restrict__ idx_ij, + const __nv_bfloat16 *__restrict__ theta, + const __nv_bfloat16 *__restrict__ scales, int s, + int h) { + constexpr int ROW_STRIDE = CTA_M + ROW_PAD; + __shared__ __half x_grp[ROW_STRIDE * GROUP_SIZE]; + + int j = blockIdx.x; + int g = blockIdx.y; + int t = threadIdx.x; + + RotateAccessBFloat16HalfWorkspace::template load_group( + x_grp, x, scales, s, h, j, g, t); + + float reg_theta[KROT]; + int reg_idx[KROT]; + RotateAccessBFloat16HalfWorkspace::template load_coeffs(reg_theta, reg_idx, + idx_ij, theta, h, g, + t); + __syncthreads(); + +#pragma unroll + for (int r = 0; r < KROT; r++) { + RotateAccess<__half>::template apply_one(x_grp, reg_idx[r], reg_theta[r]); + __syncthreads(); + } + + RotateAccessBFloat16HalfWorkspace::template store_group( + out, x_grp, s, h, j, g, t); +} + +#define LAUNCH_ROTATE(CUDA_T, TORCH_T) \ + { \ + auto *x_p = reinterpret_cast(x.data_ptr()); \ + auto *o_p = reinterpret_cast(out.data_ptr()); \ + auto *t_p = reinterpret_cast(theta_cast.data_ptr()); \ + if (has_scale) { \ + auto *s_p = reinterpret_cast(scales_cast.data_ptr()); \ + rotate_kernel<<>>( \ + x_p, o_p, idx_ij.data_ptr(), t_p, s_p, seq_len, h); \ + } else { \ + rotate_kernel<<>>( \ + x_p, o_p, idx_ij.data_ptr(), t_p, nullptr, seq_len, h); \ + } \ + break; \ + } + +template +torch::Tensor rotate_launcher_bf16_half_workspace(at::Tensor x, at::Tensor idx_ij, + at::Tensor theta, at::Tensor scales) { + int h = x.size(-1); + TORCH_CHECK(h % GROUP_SIZE == 0, "h must be divisible by GROUP_SIZE"); + int groups_per_row = h / GROUP_SIZE; + constexpr int pn = GROUP_SIZE / 2; + int seq_len = x.numel() / x.size(-1); + auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device()); + at::Tensor out = torch::empty(x.sizes(), options); + bool has_scale = scales.defined() && scales.numel() > 0; + + auto theta_cast = theta.scalar_type() == at::kBFloat16 ? theta : theta.to(at::kBFloat16); + auto scales_cast = !has_scale ? at::Tensor() + : scales.scalar_type() == at::kBFloat16 ? scales + : scales.to(at::kBFloat16); + + dim3 grid((seq_len + CTA_M - 1) / CTA_M, groups_per_row); + dim3 block(pn); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto *x_p = reinterpret_cast<__nv_bfloat16 *>(x.data_ptr()); + auto *o_p = reinterpret_cast<__nv_bfloat16 *>(out.data_ptr()); + auto *t_p = reinterpret_cast<__nv_bfloat16 *>(theta_cast.data_ptr()); + if (has_scale) { + auto *s_p = reinterpret_cast<__nv_bfloat16 *>(scales_cast.data_ptr()); + rotate_kernel_bf16_half_workspace<<>>( + x_p, o_p, idx_ij.data_ptr(), t_p, s_p, seq_len, h); + } else { + rotate_kernel_bf16_half_workspace<<>>( + x_p, o_p, idx_ij.data_ptr(), t_p, nullptr, seq_len, h); + } + return out; +} + +template +torch::Tensor rotate_launcher(at::Tensor x, at::Tensor idx_ij, at::Tensor theta, + at::Tensor scales) { + int h = x.size(-1); + TORCH_CHECK(h % GROUP_SIZE == 0, "h must be divisible by GROUP_SIZE"); + int groups_per_row = h / GROUP_SIZE; + constexpr int pn = GROUP_SIZE / 2; + int seq_len = x.numel() / x.size(-1); + auto options = torch::TensorOptions().dtype(x.dtype()).device(x.device()); + at::Tensor out = torch::empty(x.sizes(), options); + bool has_scale = scales.defined() && scales.numel() > 0; + + auto dtype = x.scalar_type(); + auto theta_cast = theta.scalar_type() == dtype ? theta : theta.to(x.dtype()); + auto scales_cast = !has_scale ? at::Tensor() + : scales.scalar_type() == dtype ? scales + : scales.to(x.dtype()); + + dim3 grid((seq_len + CTA_M - 1) / CTA_M, groups_per_row); + dim3 block(pn); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (dtype) { + case at::kFloat: + LAUNCH_ROTATE(float, float) + case at::kHalf: + LAUNCH_ROTATE(__half, c10::Half) + case at::kBFloat16: + return rotate_launcher_bf16_half_workspace(x, idx_ij, theta, scales); + default: + TORCH_CHECK(false, "rotate supports Float, Half, and BFloat16, got ", x.scalar_type()); + } + return out; +} + +#undef LAUNCH_ROTATE + +#define DISPATCH_ROW_PAD(KROT, CTA, GS) \ + switch (row_pad) { \ + case 2: \ + return rotate_launcher(x, idx, theta, scales); \ + case 0: \ + return rotate_launcher(x, idx, theta, scales); \ + default: \ + TORCH_CHECK(false, "Unsupported ROW_PAD = ", row_pad, "; compiled variants: 0 and 2"); \ + } + +#define DISPATCH_CTA_M(KROT, GS) \ + switch (cta_m) { \ + case 16: \ + DISPATCH_ROW_PAD(KROT, 16, GS) \ + case 8: \ + DISPATCH_ROW_PAD(KROT, 8, GS) \ + case 4: \ + DISPATCH_ROW_PAD(KROT, 4, GS) \ + default: \ + TORCH_CHECK(false, "Unsupported CTA_M = ", cta_m, "; compiled variants: 4, 8, and 16"); \ + } + +#define DISPATCH_KROT(GS) \ + switch (krot) { \ + case 1: \ + DISPATCH_CTA_M(1, GS) \ + case 8: \ + DISPATCH_CTA_M(8, GS) \ + default: \ + TORCH_CHECK(false, "Unsupported KROT = ", krot, "; compiled variants: 1 and 8"); \ + } + +torch::Tensor dispatch_rotate_variant(at::Tensor x, at::Tensor idx, at::Tensor theta, at::Tensor scales, + int64_t group_size, int64_t krot, int cta_m, int row_pad) { + if (group_size == 128) { + DISPATCH_KROT(128) + } + TORCH_CHECK(false, "Unsupported group_size: ", group_size, "; expected 128"); +} + +namespace { + +// Resolve launch dimensions in one place so the runtime path can use either a +// fixed explicit launch shape, the legacy default, or a cached autotuned plan. +struct LaunchConfig { + int cta_m; + int row_pad; +}; + +constexpr int kLegacyLaunchSentinel = -1; +constexpr int kAutotuneLaunchSentinel = -2; +constexpr float kAutotuneMinRelativeSpeedup = 0.15f; +// Tiny decode kernels can show large relative swings from timer noise. Require +// a minimum absolute win before overriding the architecture baseline. +constexpr float kAutotuneMinAbsoluteSpeedupMs = 0.01f; +constexpr std::array kAutotuneCandidates = {{ + {4, 0}, + {4, 2}, + {8, 0}, + {8, 2}, + {16, 0}, + {16, 2}, +}}; + +struct AutotuneKey { + int device_index; + int scalar_type; + int seq_len; + int hidden; + int group_size; + int krot; + bool has_scale; + + bool operator==(const AutotuneKey &other) const { + return device_index == other.device_index && scalar_type == other.scalar_type && + seq_len == other.seq_len && hidden == other.hidden && + group_size == other.group_size && krot == other.krot && has_scale == other.has_scale; + } +}; + +struct AutotuneKeyHash { + size_t operator()(const AutotuneKey &key) const noexcept { + size_t hash = std::hash{}(key.device_index); + hash ^= std::hash{}(key.scalar_type) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + hash ^= std::hash{}(key.seq_len) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + hash ^= std::hash{}(key.hidden) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + hash ^= std::hash{}(key.group_size) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + hash ^= std::hash{}(key.krot) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + hash ^= std::hash{}(key.has_scale) + 0x9e3779b9 + (hash << 6) + (hash >> 2); + return hash; + } +}; + +// Process-local cache for the winning launch plan per runtime shape. This is +// intentionally native so the steady-state fused path stays inside one op call. +std::unordered_map &autotune_cache() { + static std::unordered_map cache; + return cache; +} + +std::mutex &autotune_cache_mutex() { + static std::mutex mutex; + return mutex; +} + +std::mutex &autotune_measurement_mutex() { + // Serialize cold-shape autotune so free-threaded callers do not benchmark + // the same or competing launch plans concurrently on cache misses. + static std::mutex mutex; + return mutex; +} + +int resolve_autotune_count(const char *name, int default_value) { + if (const char *raw = std::getenv(name)) { + int parsed = std::atoi(raw); + if (parsed > 0) { + return parsed; + } + } + return default_value; +} + +float resolve_autotune_target_ms() { + if (const char *raw = std::getenv("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_TARGET_MS")) { + const float parsed = std::atof(raw); + if (parsed > 0.0f) { + return parsed; + } + } + return 5.0f; +} + +float resolve_autotune_min_relative_speedup() { + if (const char *raw = std::getenv("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_MIN_SPEEDUP_PCT")) { + const float parsed = std::atof(raw); + if (parsed > 0.0f) { + return parsed / 100.0f; + } + } + return kAutotuneMinRelativeSpeedup; +} + +float resolve_autotune_min_absolute_speedup_seconds() { + if (const char *raw = std::getenv("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_MIN_SPEEDUP_US")) { + const float parsed = std::atof(raw); + if (parsed > 0.0f) { + return parsed / 1e6f; + } + } + return kAutotuneMinAbsoluteSpeedupMs / 1e3f; +} + +bool requests_autotune(int requested_cta_m, int requested_row_pad) { + return requested_cta_m == kAutotuneLaunchSentinel || requested_row_pad == kAutotuneLaunchSentinel; +} + +AutotuneKey make_autotune_key(const at::Tensor &x, int64_t group_size, int64_t krot, bool has_scale) { + return { + x.get_device(), + static_cast(x.scalar_type()), + static_cast(x.numel() / x.size(-1)), + static_cast(x.size(-1)), + static_cast(group_size), + static_cast(krot), + has_scale, + }; +} + +std::optional lookup_autotune_cache(const AutotuneKey &key) { + std::lock_guard guard(autotune_cache_mutex()); + auto &cache = autotune_cache(); + auto it = cache.find(key); + if (it == cache.end()) { + return std::nullopt; + } + return it->second; +} + +LaunchConfig store_autotune_cache(const AutotuneKey &key, LaunchConfig config) { + std::lock_guard guard(autotune_cache_mutex()); + auto &cache = autotune_cache(); + auto [it, inserted] = cache.emplace(key, config); + if (!inserted) { + return it->second; + } + return config; +} + +int current_sm_version() { + static thread_local c10::DeviceIndex cached_device = -1; + static thread_local int cached_sm = -1; + const c10::DeviceIndex current_device = c10::cuda::current_device(); + if (current_device != cached_device) { + cached_device = current_device; + const cudaDeviceProp *props = at::cuda::getDeviceProperties(current_device); + cached_sm = props == nullptr ? -1 : (props->major * 10 + props->minor); + } + return cached_sm; +} + +int resolve_cta_m(at::ScalarType dtype, int explicit_value) { + if (explicit_value == 4 || explicit_value == 8 || explicit_value == 16) { + return explicit_value; + } + + // These defaults are pinned to manual full-sweep measurements on the A100 + // (sm80) and RTX 4090 (sm89) available on this host. Other architectures may + // not benefit from these launch shapes and can regress in performance, so + // they stay on the legacy default until benchmarked explicitly. + const int sm_version = current_sm_version(); + if (dtype == at::kHalf) { + switch (sm_version) { + case 80: + case 89: + return 8; + default: + return 4; + } + } + if (dtype == at::kBFloat16) { + switch (sm_version) { + case 80: + return 4; + case 89: + return 16; + default: + return 4; + } + } + return 4; +} + +int resolve_row_pad(at::ScalarType dtype, int explicit_value) { + if (explicit_value == 0 || explicit_value == 2) { + return explicit_value; + } + return dtype == at::kFloat ? 0 : 2; +} + +LaunchConfig resolve_legacy_launch_config(at::ScalarType dtype, int explicit_cta_m, int explicit_row_pad) { + return { + resolve_cta_m(dtype, explicit_cta_m), + resolve_row_pad(dtype, explicit_row_pad), + }; +} + +float benchmark_launch_config(const at::Tensor &x, const at::Tensor &idx, const at::Tensor &theta, + const at::Tensor &scales, int64_t group_size, int64_t krot, + LaunchConfig config) { + const int warmup = resolve_autotune_count("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_WARMUP", 3); + const int base_iters = resolve_autotune_count("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_ITERS", 25); + const int repeats = resolve_autotune_count("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_REPEATS", 5); + const float target_ms = resolve_autotune_target_ms(); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + auto synchronize_stream = [stream]() { + const cudaError_t status = cudaStreamSynchronize(stream); + TORCH_CHECK(status == cudaSuccess, "ParoQuant rotation autotune failed to synchronize CUDA stream: ", + cudaGetErrorString(status)); + }; + auto create_event = []() { + cudaEvent_t event = nullptr; + const cudaError_t status = cudaEventCreate(&event); + TORCH_CHECK(status == cudaSuccess, "ParoQuant rotation autotune failed to create CUDA event: ", + cudaGetErrorString(status)); + return event; + }; + auto destroy_event = [](cudaEvent_t event) { + if (event == nullptr) { + return; + } + const cudaError_t status = cudaEventDestroy(event); + TORCH_CHECK(status == cudaSuccess, "ParoQuant rotation autotune failed to destroy CUDA event: ", + cudaGetErrorString(status)); + }; + auto elapsed_ms_for_iterations = [&](int iters) { + cudaEvent_t start = create_event(); + cudaEvent_t end = create_event(); + const auto cleanup = [&]() { + destroy_event(end); + destroy_event(start); + }; + + const cudaError_t start_status = cudaEventRecord(start, stream); + TORCH_CHECK(start_status == cudaSuccess, + "ParoQuant rotation autotune failed to record CUDA start event: ", + cudaGetErrorString(start_status)); + for (int i = 0; i < iters; ++i) { + auto out = dispatch_rotate_variant(x, idx, theta, scales, group_size, krot, config.cta_m, + config.row_pad); + (void)out; + } + const cudaError_t end_status = cudaEventRecord(end, stream); + TORCH_CHECK(end_status == cudaSuccess, "ParoQuant rotation autotune failed to record CUDA end event: ", + cudaGetErrorString(end_status)); + const cudaError_t sync_status = cudaEventSynchronize(end); + TORCH_CHECK(sync_status == cudaSuccess, + "ParoQuant rotation autotune failed to synchronize CUDA end event: ", + cudaGetErrorString(sync_status)); + float elapsed_ms = 0.0f; + const cudaError_t elapsed_status = cudaEventElapsedTime(&elapsed_ms, start, end); + TORCH_CHECK(elapsed_status == cudaSuccess, + "ParoQuant rotation autotune failed to read CUDA elapsed time: ", + cudaGetErrorString(elapsed_status)); + cleanup(); + return elapsed_ms; + }; + + for (int i = 0; i < warmup; ++i) { + auto out = dispatch_rotate_variant(x, idx, theta, scales, group_size, krot, config.cta_m, + config.row_pad); + (void)out; + } + synchronize_stream(); + + const float pilot_ms = elapsed_ms_for_iterations(base_iters); + const float min_total_ms = std::max(target_ms, 1.0f); + int iters = base_iters; + if (pilot_ms > 0.0f && pilot_ms < min_total_ms) { + const float scale = min_total_ms / pilot_ms; + iters = std::max(base_iters, static_cast(std::ceil(static_cast(base_iters) * scale))); + iters = std::min(iters, 5000); + } + + std::vector measurements; + measurements.reserve(repeats); + if (iters == base_iters) { + measurements.push_back(pilot_ms / static_cast(base_iters)); + } + while (static_cast(measurements.size()) < repeats) { + measurements.push_back(elapsed_ms_for_iterations(iters) / static_cast(iters)); + } + std::nth_element(measurements.begin(), measurements.begin() + (measurements.size() / 2), + measurements.end()); + return measurements[measurements.size() / 2] / 1e3f; +} + +LaunchConfig autotune_launch_config(const at::Tensor &x, const at::Tensor &idx, const at::Tensor &theta, + const at::Tensor &scales, int64_t group_size, int64_t krot) { + const LaunchConfig fallback = + resolve_legacy_launch_config(x.scalar_type(), kLegacyLaunchSentinel, kLegacyLaunchSentinel); + const float fallback_seconds = + benchmark_launch_config(x, idx, theta, scales, group_size, krot, fallback); + const float min_relative_speedup = resolve_autotune_min_relative_speedup(); + const float min_absolute_speedup_seconds = resolve_autotune_min_absolute_speedup_seconds(); + auto beats_fallback = [&](float candidate_seconds, float baseline_seconds) { + const float absolute_improvement = baseline_seconds - candidate_seconds; + return absolute_improvement >= min_absolute_speedup_seconds && + candidate_seconds <= baseline_seconds * (1.0f - min_relative_speedup); + }; + + LaunchConfig best = fallback; + float best_seconds = fallback_seconds; + for (const LaunchConfig candidate : kAutotuneCandidates) { + if (candidate.cta_m == fallback.cta_m && candidate.row_pad == fallback.row_pad) { + continue; + } + const float seconds = benchmark_launch_config(x, idx, theta, scales, group_size, krot, candidate); + if (seconds < best_seconds) { + best_seconds = seconds; + best = candidate; + } + } + if (best.cta_m == fallback.cta_m && best.row_pad == fallback.row_pad) { + return fallback; + } + if (!beats_fallback(best_seconds, fallback_seconds)) { + return fallback; + } + const float confirm_fallback_seconds = + benchmark_launch_config(x, idx, theta, scales, group_size, krot, fallback); + const float confirm_best_seconds = + benchmark_launch_config(x, idx, theta, scales, group_size, krot, best); + if (!beats_fallback(confirm_best_seconds, confirm_fallback_seconds)) { + return fallback; + } + return best; +} + +LaunchConfig resolve_cached_autotune_launch_config(const AutotuneKey &key, const at::Tensor &x, + const at::Tensor &idx, const at::Tensor &theta, + const at::Tensor &scales, int64_t group_size, + int64_t krot) { + if (auto cached = lookup_autotune_cache(key)) { + return *cached; + } + std::lock_guard guard(autotune_measurement_mutex()); + if (auto cached = lookup_autotune_cache(key)) { + return *cached; + } + LaunchConfig measured = autotune_launch_config(x, idx, theta, scales, group_size, krot); + return store_autotune_cache(key, measured); +} + +at::Tensor build_dummy_pairs(const at::Tensor &x, int64_t group_size, int64_t krot) { + TORCH_CHECK(group_size > 0, "group_size must be positive"); + TORCH_CHECK(x.size(-1) % group_size == 0, "hidden size must be divisible by group_size"); + auto options = torch::TensorOptions().dtype(at::kShort).device(x.device()); + const int64_t groups = x.size(-1) / group_size; + const auto local_pairs = at::arange(group_size, options); + return local_pairs.repeat({groups}).unsqueeze(0).repeat({krot, 1}).contiguous(); +} + +LaunchConfig resolve_runtime_launch_config(const at::Tensor &x, const at::Tensor &idx, + const at::Tensor &theta, const at::Tensor &scales, + int64_t group_size, int64_t krot, int requested_cta_m, + int requested_row_pad, bool has_scale) { + if (!requests_autotune(requested_cta_m, requested_row_pad)) { + return resolve_legacy_launch_config(x.scalar_type(), requested_cta_m, requested_row_pad); + } + const AutotuneKey key = make_autotune_key(x, group_size, krot, has_scale); + return resolve_cached_autotune_launch_config(key, x, idx, theta, scales, group_size, krot); +} + +LaunchConfig resolve_query_launch_config(const at::Tensor &x, int64_t krot, bool has_scale, + int64_t group_size, int requested_cta_m, + int requested_row_pad) { + if (!requests_autotune(requested_cta_m, requested_row_pad)) { + return resolve_legacy_launch_config(x.scalar_type(), requested_cta_m, requested_row_pad); + } + const AutotuneKey key = make_autotune_key(x, group_size, krot, has_scale); + if (auto cached = lookup_autotune_cache(key)) { + return *cached; + } + std::lock_guard guard(autotune_measurement_mutex()); + if (auto cached = lookup_autotune_cache(key)) { + return *cached; + } + at::Tensor dummy_idx = build_dummy_pairs(x, group_size, krot); + at::Tensor dummy_theta = at::zeros({krot, x.size(-1) / 2}, x.options()); + at::Tensor dummy_scales = has_scale ? at::ones({1, x.size(-1)}, x.options()) : at::Tensor(); + LaunchConfig measured = autotune_launch_config(x, dummy_idx, dummy_theta, dummy_scales, group_size, krot); + return store_autotune_cache(key, measured); +} + +void clear_rotation_autotune_cache() { + std::lock_guard guard(autotune_cache_mutex()); + autotune_cache().clear(); +} + +int64_t rotation_autotune_cache_size() { + std::lock_guard guard(autotune_cache_mutex()); + return static_cast(autotune_cache().size()); +} + +} // namespace + +torch::Tensor rotate_dynamic(at::Tensor x, at::Tensor idx, at::Tensor theta, + c10::optional scales_opt, int64_t group_size = 128, + int64_t requested_cta_m = -1, int64_t requested_row_pad = -1) { + int64_t krot = theta.size(0); + TORCH_CHECK(krot == idx.size(0), "theta.size(0) must equal idx_ij.size(0)"); + at::Tensor scales = scales_opt.value_or(at::Tensor()); + const bool has_scale = scales.defined() && scales.numel() > 0; + const LaunchConfig config = resolve_runtime_launch_config( + x, idx, theta, scales, group_size, krot, static_cast(requested_cta_m), + static_cast(requested_row_pad), has_scale); + return dispatch_rotate_variant(x, idx, theta, scales, group_size, krot, config.cta_m, + config.row_pad); +} + +std::vector rotate_launch_config(at::Tensor x, int64_t krot = 8, bool has_scale = true, + int64_t group_size = 128, int64_t cta_m = -1, + int64_t row_pad = -1) { + const LaunchConfig config = resolve_query_launch_config( + x, krot, has_scale, group_size, static_cast(cta_m), static_cast(row_pad)); + return { + static_cast(config.cta_m), + static_cast(config.row_pad), + }; +} + +#undef DISPATCH_ROW_PAD +#undef DISPATCH_CTA_M +#undef DISPATCH_KROT + +TORCH_LIBRARY(gptqmodel_paroquant, m) { + m.def("rotate(Tensor x, Tensor idx_ij, Tensor theta, Tensor? scales=None, int group_size=128, int cta_m=-1, int row_pad=-1) -> Tensor"); + m.def("launch_config(Tensor x, int krot=8, bool has_scale=True, int group_size=128, int cta_m=-1, int row_pad=-1) -> int[]"); + m.def("clear_autotune_cache() -> ()"); + m.def("autotune_cache_size() -> int"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_paroquant, CUDA, m) { + m.impl("rotate", &rotate_dynamic); + m.impl("launch_config", &rotate_launch_config); +} + +TORCH_LIBRARY_IMPL(gptqmodel_paroquant, CatchAll, m) { + m.impl("clear_autotune_cache", &clear_rotation_autotune_cache); + m.impl("autotune_cache_size", &rotation_autotune_cache_size); +} diff --git a/gptqmodel_ext/paroquant/rotation.cuh b/gptqmodel_ext/paroquant/rotation.cuh new file mode 100644 index 000000000..1ad29492b --- /dev/null +++ b/gptqmodel_ext/paroquant/rotation.cuh @@ -0,0 +1,257 @@ +// SPDX-FileCopyrightText: 2026 ModelCloud.ai +// SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +// SPDX-License-Identifier: Apache-2.0 + +/****************************************************************************** + * Adapted from https://github.com/z-lab/paroquant + ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +template struct RotateAccess; + +template <> struct RotateAccess { + template + __device__ static void load_group(float *__restrict__ x_grp, const float *__restrict__ x, + const float *__restrict__ scales, const int s, const int h, + const int j, const int g, const int t) { + const int base0 = g * GROUP_SIZE + t; + const int base1 = base0 + GROUP_SIZE / 2; + float scale0 = USE_SCALE ? scales[base0] : float(1); + float scale1 = USE_SCALE ? scales[base1] : float(1); +#pragma unroll + for (int i = 0; i < CTA_M; i++) { + int row = j * CTA_M + i; + if (row < s) { + x_grp[t * ROW_STRIDE + i] = x[row * h + base0] * scale0; + x_grp[(t + GROUP_SIZE / 2) * ROW_STRIDE + i] = x[row * h + base1] * scale1; + } + } + } + + template + __device__ static void + load_coeffs(float reg_theta[KROT], int reg_idx[KROT], const int16_t *__restrict__ idx_ij, + const float *__restrict__ theta, const int h, const int g, const int t) { +#pragma unroll + for (int r = 0; r < KROT; r++) { + reg_theta[r] = theta[r * h / 2 + g * GROUP_SIZE / 2 + t]; + reg_idx[r] = *reinterpret_cast(idx_ij + r * h + g * GROUP_SIZE + 2 * t); + } + } + + template + __device__ static void apply_one(float *__restrict__ x_grp, const int ij, const float theta) { + int16_t i = ij & 0xFFFF, j = ij >> 16; + float s_, c_; + __sincosf(theta, &s_, &c_); +#pragma unroll + for (int m = 0; m < CTA_M; m++) { + float xi = x_grp[i * ROW_STRIDE + m]; + float xj = x_grp[j * ROW_STRIDE + m]; + x_grp[i * ROW_STRIDE + m] = xi * c_ + xj * s_; + x_grp[j * ROW_STRIDE + m] = xi * (-s_) + xj * c_; + } + } + + template + __device__ static void store_group(float *__restrict__ out, const float *__restrict__ x_grp, + const int s, const int h, const int j, const int g, + const int t) { + const int base0 = g * GROUP_SIZE + t; + const int base1 = base0 + GROUP_SIZE / 2; +#pragma unroll + for (int i = 0; i < CTA_M; i++) { + int row = j * CTA_M + i; + if (row < s) { + out[row * h + base0] = x_grp[t * ROW_STRIDE + i]; + out[row * h + base1] = x_grp[(t + GROUP_SIZE / 2) * ROW_STRIDE + i]; + } + } + } +}; + +template struct RotateAccessHalf { + template + __device__ static void load_group(HalfT *__restrict__ x_grp, const HalfT *__restrict__ x, + const HalfT *__restrict__ scales, const int s, const int h, + const int j, const int g, const int t) { + static_assert((ROW_STRIDE % 2) == 0, "ROW_STRIDE must be even for vectorized half access"); + const int offset = GROUP_SIZE * g + 2 * t; + HalfT scale_i, scale_j; + if constexpr (USE_SCALE) { + Half2T scale_pair = *reinterpret_cast(scales + offset); + scale_i = Traits::low(scale_pair); + scale_j = Traits::high(scale_pair); + } else { + scale_i = Traits::from_float(1.0f); + scale_j = Traits::from_float(1.0f); + } + +#pragma unroll + for (int i = 0; i < CTA_M; i++) { + int row = j * CTA_M + i; + if (row < s) { + Half2T x2 = *reinterpret_cast(x + row * h + offset); + HalfT lo = Traits::hmul(Traits::low(x2), scale_i); + HalfT hi = Traits::hmul(Traits::high(x2), scale_j); + x_grp[(2 * t) * ROW_STRIDE + i] = lo; + x_grp[(2 * t + 1) * ROW_STRIDE + i] = hi; + } + } + } + + template + __device__ static void + load_coeffs(float reg_theta[KROT], int reg_idx[KROT], const int16_t *__restrict__ idx_ij, + const HalfT *__restrict__ theta, const int h, const int g, const int t) { +#pragma unroll + for (int r = 0; r < KROT; r++) { + reg_theta[r] = Traits::to_float(theta[r * h / 2 + g * GROUP_SIZE / 2 + t]); + reg_idx[r] = *reinterpret_cast(idx_ij + r * h + g * GROUP_SIZE + 2 * t); + } + } + + template + __device__ static void apply_one(HalfT *__restrict__ x_grp, const int ij, const float theta) { + static_assert((ROW_STRIDE % 2) == 0, "ROW_STRIDE must be even for vectorized half access"); + int16_t i = ij & 0xFFFF, j = ij >> 16; + float s_, c_; + __sincosf(theta, &s_, &c_); + +#pragma unroll + for (int m = 0; m < CTA_M / 2; ++m) { + Half2T *pi2 = reinterpret_cast(x_grp + i * ROW_STRIDE + m * 2); + Half2T *pj2 = reinterpret_cast(x_grp + j * ROW_STRIDE + m * 2); + + float2 xi = Traits::to_float2(*pi2); + float2 xj = Traits::to_float2(*pj2); + + float2 yi, yj; + yi.x = fmaf(c_, xi.x, s_ * xj.x); + yi.y = fmaf(c_, xi.y, s_ * xj.y); + yj.x = fmaf(c_, xj.x, -s_ * xi.x); + yj.y = fmaf(c_, xj.y, -s_ * xi.y); + + *pi2 = Traits::from_floats(yi.x, yi.y); + *pj2 = Traits::from_floats(yj.x, yj.y); + } + } + + template + __device__ static void store_group(HalfT *__restrict__ out, const HalfT *__restrict__ x_grp, + const int s, const int h, const int j, const int g, + const int t) { + static_assert((ROW_STRIDE % 2) == 0, "ROW_STRIDE must be even for vectorized half access"); + const int base = GROUP_SIZE * g + 2 * t; +#pragma unroll + for (int i = 0; i < CTA_M; i++) { + int row = j * CTA_M + i; + if (row < s) { + Half2T out2; + out2.x = x_grp[(2 * t) * ROW_STRIDE + i]; + out2.y = x_grp[(2 * t + 1) * ROW_STRIDE + i]; + *reinterpret_cast(out + row * h + base) = out2; + } + } + } +}; + +struct HalfTraits { + __device__ static float2 to_float2(__half2 v) { return __half22float2(v); } + __device__ static __half2 from_floats(float a, float b) { return __floats2half2_rn(a, b); } + __device__ static float to_float(__half v) { return __half2float(v); } + __device__ static __half from_float(float v) { return __float2half_rn(v); } + __device__ static __half low(__half2 v) { return __low2half(v); } + __device__ static __half high(__half2 v) { return __high2half(v); } + __device__ static __half hmul(__half a, __half b) { return __hmul(a, b); } +}; + +template <> struct RotateAccess<__half> : RotateAccessHalf<__half, __half2, HalfTraits> {}; + +struct BFloat16Traits { + __device__ static float2 to_float2(__nv_bfloat162 v) { return __bfloat1622float2(v); } + __device__ static __nv_bfloat162 from_floats(float a, float b) { + return __floats2bfloat162_rn(a, b); + } + __device__ static __half2 to_half2(__nv_bfloat162 v) { + float2 pair = __bfloat1622float2(v); + return __floats2half2_rn(pair.x, pair.y); + } + __device__ static __nv_bfloat162 from_half2(__half2 v) { + float2 pair = __half22float2(v); + return __floats2bfloat162_rn(pair.x, pair.y); + } + __device__ static float to_float(__nv_bfloat16 v) { return __bfloat162float(v); } + __device__ static __nv_bfloat16 from_float(float v) { return __float2bfloat16(v); } + __device__ static __nv_bfloat16 low(__nv_bfloat162 v) { return __low2bfloat16(v); } + __device__ static __nv_bfloat16 high(__nv_bfloat162 v) { return __high2bfloat16(v); } + __device__ static __nv_bfloat16 hmul(__nv_bfloat16 a, __nv_bfloat16 b) { return __hmul(a, b); } +}; + +// Keep BF16 inputs/outputs on the fused path while using the FP16 workspace +// update pattern internally to reduce BF16 round-trip loss across k-rot stages. +struct RotateAccessBFloat16HalfWorkspace { + template + __device__ static void load_group(__half *__restrict__ x_grp, const __nv_bfloat16 *__restrict__ x, + const __nv_bfloat16 *__restrict__ scales, const int s, + const int h, const int j, const int g, const int t) { + static_assert((ROW_STRIDE % 2) == 0, "ROW_STRIDE must be even for vectorized half access"); + const int offset = GROUP_SIZE * g + 2 * t; + __half2 scale_pair; + if constexpr (USE_SCALE) { + scale_pair = BFloat16Traits::to_half2(*reinterpret_cast(scales + offset)); + } else { + scale_pair = __floats2half2_rn(1.0f, 1.0f); + } + +#pragma unroll + for (int i = 0; i < CTA_M; i++) { + int row = j * CTA_M + i; + if (row < s) { + __half2 x_pair = + BFloat16Traits::to_half2(*reinterpret_cast(x + row * h + offset)); + __half2 prod = __hmul2(x_pair, scale_pair); + x_grp[(2 * t) * ROW_STRIDE + i] = __low2half(prod); + x_grp[(2 * t + 1) * ROW_STRIDE + i] = __high2half(prod); + } + } + } + + template + __device__ static void load_coeffs(float reg_theta[KROT], int reg_idx[KROT], + const int16_t *__restrict__ idx_ij, + const __nv_bfloat16 *__restrict__ theta, const int h, + const int g, const int t) { +#pragma unroll + for (int r = 0; r < KROT; r++) { + reg_theta[r] = BFloat16Traits::to_float(theta[r * h / 2 + g * GROUP_SIZE / 2 + t]); + reg_idx[r] = *reinterpret_cast(idx_ij + r * h + g * GROUP_SIZE + 2 * t); + } + } + + template + __device__ static void store_group(__nv_bfloat16 *__restrict__ out, const __half *__restrict__ x_grp, + const int s, const int h, const int j, const int g, + const int t) { + static_assert((ROW_STRIDE % 2) == 0, "ROW_STRIDE must be even for vectorized half access"); + const int base = GROUP_SIZE * g + 2 * t; +#pragma unroll + for (int i = 0; i < CTA_M; i++) { + int row = j * CTA_M + i; + if (row < s) { + __half2 out_pair = __halves2half2(x_grp[(2 * t) * ROW_STRIDE + i], x_grp[(2 * t + 1) * ROW_STRIDE + i]); + *reinterpret_cast<__nv_bfloat162 *>(out + row * h + base) = BFloat16Traits::from_half2(out_pair); + } + } + } +}; + +template <> +struct RotateAccess<__nv_bfloat16> + : RotateAccessHalf<__nv_bfloat16, __nv_bfloat162, BFloat16Traits> {}; diff --git a/gptqmodel_ext/qqq/qqq.cpp b/gptqmodel_ext/qqq/qqq.cpp index 5623320e8..f6d407e7b 100644 --- a/gptqmodel_ext/qqq/qqq.cpp +++ b/gptqmodel_ext/qqq/qqq.cpp @@ -1,7 +1,27 @@ // Adapted from https://github.com/HandH1998/QQQ +#include + #include "qqq_gemm.h" -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("qqq_gemm", &qqq_gemm, "INT8xINT4 matmul based marlin FP16xINT4 kernel."); +namespace { + +void qqq_gemm_dispatch(const torch::Tensor& A, const torch::Tensor& B, torch::Tensor& C, + torch::Tensor& D, const torch::Tensor& s1, const torch::Tensor& s2, + const torch::Tensor& s3, torch::Tensor& workspace, int64_t thread_k, + int64_t thread_n, int64_t sms, int64_t max_par) { + qqq_gemm(A, B, C, D, s1, s2, s3, workspace, static_cast(thread_k), + static_cast(thread_n), static_cast(sms), static_cast(max_par)); +} + +} // namespace + +TORCH_LIBRARY(gptqmodel_qqq, m) { + m.def( + "qqq_gemm(Tensor A, Tensor B, Tensor(a!) C, Tensor(a!) D, Tensor s1, Tensor s2, Tensor s3, " + "Tensor(a!) workspace, int thread_k=-1, int thread_n=-1, int sms=-1, int max_par=8) -> ()"); +} + +TORCH_LIBRARY_IMPL(gptqmodel_qqq, CUDA, m) { + m.impl("qqq_gemm", &qqq_gemm_dispatch); } diff --git a/progress.md b/progress.md new file mode 100644 index 000000000..d0f77d957 --- /dev/null +++ b/progress.md @@ -0,0 +1,1328 @@ +# TRAM-Quant: Mixed-Precision Ampere Kernel — Progress + +## Peer Review + +**Reviewer:** Copilot (Beta) +**Scope:** Full review of the TRAM-Quant kernel design as documented in `Project.md` lines 7393–9396. +**Verdict: REJECT — conditional on 5 must-fix items below. The architecture is strong but the current skeleton has a precision-killing accumulator choice, an unfinished transform path that will bottleneck Tensor Core throughput, and zero benchmarking infrastructure.** + +--- + +### 1. CRITICAL — FP16 Accumulation Will Ruin PPL + +The single most damaging decision in the current skeleton is this instruction: + +``` +mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 +``` + +The last `.f16` means the accumulator is FP16 (10-bit mantissa, ~3.3 decimal digits). For a model with K=4096, each output element accumulates over `K/16 = 256` MMA operations. FP16 accumulation over 256 steps introduces catastrophic precision loss — partial sums overflow and small contributions vanish entirely. This will measurably degrade PPL, especially on models with wide hidden dimensions (≥4096). + +**The fix:** Switch to the FP32 accumulator variant. Ampere Tensor Cores support this at identical throughput: + +```cpp +// CORRECTED: FP32 accumulator, FP16 inputs +#define MMA_SYNC_M16N8K16_F32(RC, RA, RB) \ + asm volatile( \ + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "{%8, %9}, " \ + "{%0, %1, %2, %3};\n" \ + : "+r"((RC)[0]), "+r"((RC)[1]), "+r"((RC)[2]), "+r"((RC)[3]) \ + : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ + "r"((RB)[0]), "r"((RB)[1]) \ + ) + +// Accumulators revert to 4 regs (4x FP32) per m16n8k16 slice: +uint32_t RC[8][4]; +``` + +This restores 4 accumulator registers per slice (32 total for 8 column slices), which was ironically the *original* layout before it was incorrectly "fixed" to 2 regs. The PTX `m16n8k16` with `.f32` C/D uses 4 × 32-bit registers holding 4 FP32 values. The register pressure increase from 16 → 32 regs for accumulators is well within Ampere's budget (see §5 below). + +Convert accumulators to FP16 only in the epilogue when writing to global memory. This is what Marlin does and is what you should do. + +--- + +### 2. CRITICAL — Paro Transform Is Underspecified and Will Stall the Pipeline + +The `apply_paro_transform()` is still a pseudocode placeholder. This is not just an implementation gap — it's a **design risk** because the transform happens on the critical path between Barrier A and the MMA loop, directly gating Tensor Core utilization. + +**Quantified cost per stage (INT4 kernel, 8 warps):** +- Channel scaling: 16 rows × 64 cols = 1024 FP16 multiplies per warp +- Givens rotations (8 pairs): 16 rows × 8 pairs × 4 FMAs = 512 FMAs per warp +- Total: ~1536 FP16 ops per warp per K_STAGE + +These operations hit shared memory for both reads and writes, creating a read-modify-write dependency chain. The `__syncthreads()` (Barrier B) after the transform further serializes this with the MMA phase. + +**Specific concerns:** + +(a) **Shared memory bank conflicts during transform:** Each warp's 16-row slab has 64 columns of FP16 = 128 bytes per row. When warp 0 and warp 4 both apply Givens rotations that touch the same column pair within their respective row slabs, the shared memory addresses differ by `16 * 128 = 2048 bytes = 64 banks`. On Ampere (32 banks, 4-byte banking), this wraps to `64 mod 32 = 0` — **same bank**. If the Givens pair indices `(a, b)` are identical across warps executing simultaneously, you get 2-way bank conflicts. + +**Fix:** Pad the A tile row stride from 128 bytes to 132 bytes (add 2 FP16 padding per row). This breaks the bank alignment pattern: + +```cpp +// Padded activation tile to avoid cross-warp bank conflicts during transform +constexpr int A_ROW_STRIDE_BYTES = 132; // 64 FP16 + 2 padding = 66 half values +constexpr int A_STAGE_BYTES_PADDED = M_TILE * A_ROW_STRIDE_BYTES; // 64 * 132 = 8448 +``` + +(b) **Transform as a throughput bottleneck:** With 8 Givens rotations per K_STAGE=64 block, the rotation pass consumes ~30 cycles per row (2 FP16 loads + 4 FMAs + 2 stores, assuming ~4-cycle FMA latency with ILP). Over 16 rows, that's ~480 cycles per warp. The MMA inner loop (4 ks × 8 j × ~8 cycles per mma.sync) is ~256 cycles. **The transform is nearly 2× the cost of the MMA phase.** This means Tensor Cores are idle nearly half the time waiting for the transform to complete. + +**Fix options:** +- Reduce rotation count. 4 rotations per block instead of 8 may be the quality-speed sweet spot. Benchmark quality with `ROT_COUNT ∈ {2, 4, 6, 8}`. +- Overlap transform with MMA from the *previous* stage by restructuring the pipeline: + ``` + Stage N: [MMA(N)] + [Transform(N+1)] // concurrent if on disjoint smem + ``` + This requires keeping 2 A tiles live simultaneously (doubling A-tile shared memory), but the total smem footprint stays under 64KB. + +--- + +### 3. MAJOR — Validation Layout Has Hidden Bank Conflicts + +The current `uint16_t Bfrag[(ks * 8 + j) * 32 + lane]` validation layout has a 2-way bank conflict pattern. + +Ampere shared memory has 32 banks with 4-byte stride. When 32 lanes read consecutive `uint16_t` addresses: +- Lane 0 reads byte offset `0` → bank 0 +- Lane 1 reads byte offset `2` → bank 0 (same bank!) +- Lane 2 reads byte offset `4` → bank 1 +- Lane 3 reads byte offset `6` → bank 1 + +This is a **2-way bank conflict on every B fragment load.** + +Yes, this is a "validation layout" meant to be replaced — but you should fix it now rather than baking in a performance anti-pattern that masks real bottlenecks during validation. The production `ld.shared.u32` path (4 bytes per lane) will resolve this naturally, but you should move to it immediately rather than treating it as a future optimization. + +**Immediate fix for validation:** + +```cpp +// Use u32 fetch even for validation — kills bank conflicts and matches production path +uint32_t packed_pair = *reinterpret_cast( + &Bfrag_u16[(ks * 8 + j) * 32 + (lane & ~1)]); +uint16_t packed_16 = (lane & 1) ? (packed_pair >> 16) : (packed_pair & 0xFFFF); +``` + +--- + +### 4. MAJOR — Missing Edge-Case Handling Throughout + +The skeleton has zero bounds checking. This will produce garbage results for real-world matrix dimensions: + +**(a) M-edge tiles:** When `M % 64 != 0`, the last CTA row-block reads out-of-bounds from the A matrix. cp.async will happily fetch garbage from global memory. Fix: add `cp.async` predication or zero-fill the tail. + +```cpp +// Predicated cp.async for M-edge tiles +int global_m = block_m + local_m; +if (global_m < M) { + // issue cp.async +} else { + // zero-fill this smem slot + *(uint4*)(smem + offset) = make_uint4(0, 0, 0, 0); +} +``` + +**(b) N-edge tiles:** When `N % 128 != 0` (INT4) or `N % 64 != 0` (INT8), the last CTA column-block's weight loads go out of bounds. Same fix needed. + +**(c) Short-K drain:** The prologue fix `min(PIPE - 1, num_k_stages)` was identified but never integrated into the actual skeleton code. The drain loop body after the main loop is still just a comment (`// ... wait_group(2) -> math ...`). This must be written out — it's not trivial with the transform + barrier sequence. + +**(d) Odd batch sizes at inference time:** For autoregressive decoding with batch size 1, `M_TILE = 64` wastes 63/64 rows. The kernel needs a separate small-M path or dynamic M_TILE selection. This is an inference-critical case. + +--- + +### 5. MAJOR — No Benchmarking Strategy Exists + +The entire design conversation contains zero profiling methodology. The proposal cannot be approved without a concrete benchmarking plan. + +**Required benchmarking deliverables:** + +**(a) Microbenchmark suite (before full kernel integration):** + +| Test | What it measures | Tool | +|------|-----------------|------| +| Decode-only kernel | INT4→FP16 throughput per SM, cycles/element | Nsight Compute, `sm__inst_executed` | +| Transform-only kernel | Paro transform throughput, bank conflict rate | Nsight Compute, `l1tex__data_bank_conflicts_pipe_lsu` | +| MMA-only kernel (no decode) | Theoretical TC peak for this tile size | Nsight Compute, `sm__pipe_tensor_op_hmma_cycles_active` | +| Full inner loop (1 CTA) | Combined throughput, stall reasons | Nsight Compute warp stall analysis | + +**(b) Matrix dimension sweep:** + +``` +M ∈ {1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024} # batch sizes +K ∈ {4096, 8192, 11008, 14336} # common hidden dims +N ∈ {4096, 8192, 11008, 14336} # common output dims +``` + +Include at least 3 non-power-of-2 edge cases: `M=7, K=5120, N=13824`. + +**(c) End-to-end perplexity comparison:** + +| Configuration | Model | Metric | +|--------------|-------|--------| +| FP16 baseline | Qwen3.5-4B / Llama-3-8B | WikiText-2 PPL | +| GPTQ 4-bit (Marlin kernel) | same | WikiText-2 PPL | +| TRAM-Quant INT4-only | same | WikiText-2 PPL | +| TRAM-Quant INT4+INT8 mixed | same | WikiText-2 PPL | + +**(d) Nsight Systems timeline:** +- Capture at least 100 token generation steps +- Verify no gaps between kernel launches (kernel fusion is working) +- Verify SM occupancy ≥ 50% sustained +- Verify memory throughput ≥ 70% of Ampere's theoretical bandwidth (A100: 2039 GB/s HBM) + +**(e) Roofline target:** +The INT4 kernel is memory-bound for small M. Compute the theoretical arithmetic intensity: +``` +AI = (2 * M * N * K) / (M*K*2 + K*N/2 + N*2 + transform_meta) FLOP/byte +``` +For M=1, K=4096, N=4096: +``` +AI = (2*1*4096*4096) / (1*4096*2 + 4096*4096/2 + 4096*2 + 256*64) ≈ 3.9 FLOP/byte +``` +At 3.9 FLOP/byte, you're memory-bound on A100 (312 TFLOPS / 2039 GB/s = 153 FLOP/byte crossover). Target: ≥ 80% of HBM bandwidth for M ≤ 16. + +--- + +### 6. MODERATE — Scale Recomputation in the Decode Double-Buffer + +Inside the `j` loop, the decode path recomputes `scale2_nxt`, `c2_nxt`, `s2_reg_nxt`, `c2_reg_nxt` for every column slice. Since `scale = S_smem[col_base + j*8 + groupID]`, and `groupID` is constant per thread, there are only 8 unique scales per thread across the 8 `j` iterations. + +**Fix:** Preload all 8 scales into registers before the `j` loop: + +```cpp +uint32_t s2_regs[8], c2_regs[8]; +#pragma unroll +for (int j = 0; j < 8; ++j) { + int global_col = col_base + j * 8 + groupID; + half scale = S_smem[global_col]; + Half2Reg s2, c2; + s2.h2 = __halves2half2(scale, scale); + c2.h2 = __hmul2(s2.h2, offset_base); + s2_regs[j] = s2.u32; + c2_regs[j] = c2.u32; +} +``` + +Cost: 16 extra registers. Saves: 8 shared memory loads + 8 half2 constructions + 8 multiplications per `ks` iteration. Net win. + +--- + +### 7. MODERATE — Epilogue Is Completely Absent + +The accumulator-to-global-memory store path is the most common place to introduce coalescing disasters. The current skeleton has only `// Write accum[] to global memory C`. + +For 8 warps each holding `RC[8][4]` (with FP32 accumulators as recommended in §1), the epilogue must: +1. Convert FP32 accumulators to FP16 +2. Write to global memory with 128-byte coalesced transactions + +**Skeleton for the epilogue:** + +```cpp +// Convert FP32 accums -> FP16 and store +// Each warp owns a 16x64 output tile +// Row = row_base + D-fragment row mapping +// Col = col_base + j*8 + D-fragment col mapping +// +// For m16n8k16 FP32 D-fragment: +// groupID = lane >> 2, tid4 = lane & 3 +// d0 -> row = 2*tid4+0, col = groupID (already known) +// d1 -> row = 2*tid4+1, col = groupID +// d2 -> row = 2*tid4+8, col = groupID +// d3 -> row = 2*tid4+9, col = groupID +// +// Coalesce by having adjacent lanes write adjacent columns. +// groupID = lane >> 2 means lanes 0-3 write col 0, lanes 4-7 write col 1, etc. +// This is NOT coalesced (4 threads per column). +// +// Fix: use shared memory as a transpose buffer. +// Write RC to smem in fragment order, __syncthreads(), read back in row-major +// coalesced order, write to global. +``` + +This is non-trivial and must be designed, not hand-waved. + +--- + +### 8. MINOR — cp.async Macro Needs Proper Shared Address Conversion + +Flagged in the conversation but never fixed in code. The `CP_ASYNC_CG_EVICT` macro takes a raw shared-memory address, but the skeleton never shows `__cvta_generic_to_shared()` being called. The `cvta.to.shared.u32` is shown only for `ldmatrix`. + +**Fix for the cp.async path:** + +```cpp +#define CP_ASYNC_CG_16B(dst_smem, src_global) \ + asm volatile( \ + "cp.async.cg.shared.global [%0], [%1], 16;\n" \ + :: "r"(dst_smem), "l"(src_global)) + +#define CP_ASYNC_CG_16B_EVICT(dst_smem, src_global) \ + asm volatile( \ + "cp.async.cg.shared.global.L2::evict_first [%0], [%1], 16;\n" \ + :: "r"(dst_smem), "l"(src_global)) +``` + +All shared-memory addresses must go through `cvta.to.shared.u32` before being passed to these macros. + +--- + +### 9. MINOR — INT8 Kernel Is Vapor + +The INT8 kernel is described as "structurally identical, just with N_TILE=64" but never written. The decode path is completely different: +- INT8 needs 8 values per lane fragment (not 4 INT4s in 16 bits) +- The magic decode trick doesn't apply to INT8 (values span a full byte, not a nibble) +- The B fragment packing contract is different + +The INT8 kernel needs its own decode path and validation. Don't assume it's trivial. + +--- + +### Summary of Required Actions + +| # | Severity | Item | Status | +|---|----------|------|--------| +| 1 | **CRITICAL** | Switch to FP32 accumulators (`f32.f16.f16.f32`) | Must fix | +| 2 | **CRITICAL** | Implement + cost-analyze Paro transform, benchmark rotation count | Must fix | +| 3 | **MAJOR** | Fix Bfrag bank conflicts, move to u32 fetch | Must fix | +| 4 | **MAJOR** | Add M/N/K edge-case handling and short-K drain | Must fix | +| 5 | **MAJOR** | Deliver complete benchmarking plan per §5 | Must fix | +| 6 | MODERATE | Preload scales outside j-loop | Should fix | +| 7 | MODERATE | Design and implement epilogue with coalesced stores | Should fix | +| 8 | MINOR | Fix cp.async address conversion | Fix before production | +| 9 | MINOR | Write INT8 decode path | Fix before production | + +--- + +### What Is Already Good + +To be clear: the iterative design conversation produced genuinely strong work in several areas: + +- **Two-kernel segregation** (pure INT4 / pure INT8 warp tiles) is the correct architecture. No branching in the inner loop. +- **XOR-swizzled activation layout** for conflict-free `ldmatrix` is textbook correct. +- **B-fragment lane mapping** is now correct after 3 iterations (`groupID = lane >> 2`, `tid4 = lane & 3`). +- **Magic FP16 decode** math is exact: `fma(1024+q, scale, -1032*scale) = (q-8)*scale` with single rounding. +- **Double-buffered decode pipeline** is the right structure for guaranteed MMA/decode overlap. +- **4-stage cp.async pipeline** with `evict_first` for weights matches Marlin's proven design. +- **3-barrier synchronization** (A: stage resident, B: transform visible, C: stage freed) is correct. +- **Single-scale-per-thread** insight from the PTX B-fragment column invariant is a real register pressure win. + +The architecture is the right one for Ampere. The execution needs the fixes above before it can ship. + +--- + +**Awaiting: `## Final Plan` from Alpha with resolutions for items 1–5.** + +*Do not proceed to implementation until the accumulator precision, transform cost model, and benchmarking plan are locked.* + +--- + +## Final Plan + +**Author:** Alpha +**Responding to:** Peer Review items 1–9. + +Every critical and major item is addressed below with concrete code, cost analysis, and benchmarking commitment. Minor items 8–9 are acknowledged and scheduled. + +--- + +### Fix 1 (CRITICAL) — FP32 Accumulators + +Switching the MMA instruction and accumulator array globally. No exceptions. + +```cpp +// ------------------------------------------------------- +// Corrected accumulator: FP32, 4 regs per m16n8k16 slice +// ------------------------------------------------------- +uint32_t RC[8][4]; +#pragma unroll +for (int j = 0; j < 8; ++j) { + RC[j][0] = 0; RC[j][1] = 0; RC[j][2] = 0; RC[j][3] = 0; +} + +// FP32 accumulator MMA macro +#define MMA_SYNC_M16N8K16_F32(RC, RA, RB) \ + asm volatile( \ + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "{%8, %9}, " \ + "{%0, %1, %2, %3};\n" \ + : "+r"((RC)[0]), "+r"((RC)[1]), "+r"((RC)[2]), "+r"((RC)[3]) \ + : "r"((RA)[0]), "r"((RA)[1]), "r"((RA)[2]), "r"((RA)[3]), \ + "r"((RB)[0]), "r"((RB)[1]) \ + ) +``` + +**Register budget impact:** +- `RC[8][4]` = 32 regs (was 16 with the incorrect f16 form) +- `RA[4]` = 4 regs +- `RB_fp16_cur/nxt[2]` = 4 regs +- 8× preloaded scale pairs `s2_regs[8]`, `c2_regs[8]` = 16 regs (see Fix 6) +- Address/loop vars, decode temporaries: ~12 regs + +Total estimated: ~68 regs. Ampere allows 255 regs per thread; 256 threads/CTA × 68 = 17,408 regs < 65,536 available per SM (A100). **No spill risk at this occupancy target.** + +**Epilogue: FP32 → FP16 conversion before global store:** + +```cpp +// Convert FP32 accumulator to FP16 pair for output +__device__ __forceinline__ +uint32_t f32x2_to_f16x2(float a, float b) { + uint32_t out; + asm volatile( + "cvt.rn.f16.f32 %0, %1;\n" + "cvt.rn.f16.f32 %2, %3;\n" + : "=h"(*((__half*)&out)), "=h"(*(((__half*)&out)+1)) + : "f"(a), "f"(b) + ); + return out; +} +// Or cleaner using __float22half2_rn if cuda_fp16.h is available +``` + +--- + +### Fix 2 (CRITICAL) — Paro Transform: Implementation + Cost Model + Rotation Budget + +**2a. Concrete implementation of `apply_paro_transform`:** + +```cpp +// apply_paro_transform: warp-local, modifies A_smem in-place +// Each warp owns rows [row_base .. row_base+15] of the 64×64 A tile. +// 'scale' and 'rot' arrays are from TransformStage in smem. +// +// A_smem layout: row stride = A_ROW_STRIDE_BYTES = 132 bytes (padded) +// rot_count is a compile-time constant (see budget below). + +template +__device__ __forceinline__ void apply_paro_transform( + uint8_t* __restrict__ stage_smem, + int stage_base, + int row_base, + int lane) +{ + half* A = (half*)(stage_smem + stage_base); + const half* scale = (const half*)(stage_smem + stage_base + + A_STAGE_BYTES_PADDED); + const RotationMeta* rot = (const RotationMeta*)( + stage_smem + stage_base + A_STAGE_BYTES_PADDED + 128); + + // Each thread handles columns in strides of 32 (warpSize). + // Phase 1: channel scaling — all 64 cols, all 16 rows of this warp's slab. + #pragma unroll 4 + for (int local_row = 0; local_row < 16; ++local_row) { + int row = row_base + local_row; + half* A_row = A + row * (A_ROW_STRIDE_BYTES / sizeof(half)); + + // Each lane scales 2 columns (lane covers cols lane*2 and lane*2+1). + // With warpSize=32 and K_STAGE=64: 64 cols / 32 lanes = 2 cols/lane. + int c0 = lane * 2; + int c1 = lane * 2 + 1; + A_row[c0] = __hmul(A_row[c0], scale[c0]); + A_row[c1] = __hmul(A_row[c1], scale[c1]); + } + + // Warp-level sync: scaling writes must be visible to rotation reads. + __syncwarp(); + + // Phase 2: sparse Givens rotations (ROT_COUNT iterations, unrolled). + #pragma unroll + for (int t = 0; t < ROT_COUNT; ++t) { + const int u = rot[t].a; // col index 0..63 + const int v = rot[t].b; // col index 0..63 + const half c = rot[t].c; + const half s = rot[t].s; + + #pragma unroll 4 + for (int local_row = 0; local_row < 16; ++local_row) { + int row = row_base + local_row; + half* A_row = A + row * (A_ROW_STRIDE_BYTES / sizeof(half)); + + // Only the lane that "owns" col u and col v executes this. + // lane_u = u / 2, lane_v = v / 2. + // To avoid divergence: ALL lanes compute, only correct lanes write. + // Alternatively: broadcast via warp shuffles. + // + // Broadcast-based implementation (zero-divergence): + half au = __shfl_sync(0xFFFFFFFF, A_row[u % 2], u / 2); + half av = __shfl_sync(0xFFFFFFFF, A_row[v % 2], v / 2); + half new_u = __hadd(__hmul(c, au), __hmul(s, av)); + half new_v = __hsub(__hmul(c, av), __hmul(s, au)); + + // Only the owning lanes write back. + if (lane == (u / 2)) A_row[u % 2] = new_u; + if (lane == (v / 2)) A_row[v % 2] = new_v; + } + __syncwarp(); // Rotation t must complete before rotation t+1 reads + } +} +``` + +**2b. Cost model (conservative, per-stage per-CTA):** + +| Phase | Ops/warp | Cycles (est.) | Bottleneck | +|-------|----------|--------------|------------| +| Scale (64 cols × 16 rows, FP16 mul) | 1024 FMAs | ~64 cycles (16 FP16 FMAs/cycle/warp) | FP16 pipe | +| Rotation loads (`__shfl_sync`) | 2 × ROT_COUNT × 16 rows | ~32 cycles (ROT_COUNT=4) | Register/shuffle | +| Rotation FMAs | 4 × ROT_COUNT × 16 rows | ~64 cycles (ROT_COUNT=4) | FP16 pipe | +| Rotation writes | 2 × ROT_COUNT × 16 rows | ~16 cycles | Shared mem | +| **Total (ROT_COUNT=4)** | | **~176 cycles/warp** | | +| MMA inner loop | 4 ks × 8 j | **~256 cycles/warp** | TC | + +With `ROT_COUNT=4`, transform cost is 176/256 ≈ **69% of MMA cost**, not 2× as it was with ROT_COUNT=8. This is acceptable. Warps are pipelined; while warp 0 transforms, warp 4 is already in its MMA loop. + +**2c. Committed rotation budget:** `ROT_COUNT = 4` for the first implementation. Quality ablation (§5b) will test `{2, 4, 6, 8}`. If ROT_COUNT=4 matches ROT_COUNT=8 quality within 0.05 PPL, we ship ROT_COUNT=4. + +**2d. A-tile padding to eliminate bank conflicts during transform:** + +```cpp +// Padded row stride: 64 FP16 = 128 bytes + 4 bytes padding = 132 bytes +// 132 / 4 = 33 banks touched per row → odd number breaks periodicity +constexpr int A_ROW_STRIDE_BYTES = 132; +constexpr int A_STAGE_BYTES_PADDED = M_TILE * A_ROW_STRIDE_BYTES; // 8448 B + +// Updated per-stage total for INT4: +constexpr int STAGE4_BYTES = A_STAGE_BYTES_PADDED + XFORM_STAGE_BYTES + + W4_STAGE_BYTES + S4_STAGE_BYTES; // 13056 B +constexpr int TOTAL_SMEM_BYTES_INT4 = PIPE * STAGE4_BYTES; // 52224 B + +// INT8 equivalent: +constexpr int STAGE8_BYTES = A_STAGE_BYTES_PADDED + XFORM_STAGE_BYTES + + W8_STAGE_BYTES + S8_STAGE_BYTES; // 12928 B +constexpr int TOTAL_SMEM_BYTES_INT8 = PIPE * STAGE8_BYTES; // 51712 B +``` + +Both remain under 64KB (Ampere's max per-SM shared memory that still allows 2 CTAs/SM). + +--- + +### Fix 3 (MAJOR) — B Fragment: Immediate Move to u32 Fetch + +The `uint16_t` validation layout is gone. The physical layout is lane-pair packed from day one. + +```cpp +// Physical Bfrag layout: [ks * N_FRAG + j][lane_pair] -> uint32_t +// lane_pair = lane >> 1: adjacent lane pairs share a 4-byte word. +// Low 16 bits = even lane's payload, high 16 bits = odd lane's payload. +// +// Alignment: bfrag_stage_offset must be 4-byte aligned. +// B smem region is already 4096B aligned by construction. + +__device__ __forceinline__ +uint16_t fetch_b_fragment(const uint8_t* B_smem_base, int ks, int j, int lane) { + // Each (ks, j) slice = 32 lanes × 2 bytes = 64 bytes = 16 × uint32_t + int pair_idx = lane >> 1; + int lane_parity = lane & 1; + const uint32_t* B_u32 = reinterpret_cast(B_smem_base) + + (ks * 8 + j) * 16 + pair_idx; + uint32_t word; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(word) + : "r"(__cvta_generic_to_shared(B_u32))); + return (lane_parity == 0) ? (uint16_t)(word & 0xFFFF) + : (uint16_t)(word >> 16); +} +``` + +Each `ld.shared.u32` fetches a 4-byte word containing two lanes' payloads. Lanes in a pair (0,1), (2,3), ... hit **different banks** because pair_idx increments by 1 for every 4 bytes → hits consecutive banks. No conflicts. + +--- + +### Fix 4 (MAJOR) — Edge Cases: Complete Handling + +**4a. Predicated cp.async for M/N/K boundaries:** + +```cpp +// Templated predicated loader — emits cp.async only if in-bounds +__device__ __forceinline__ void cp_async_pred( + uint32_t smem_addr, const void* gmem_addr, bool valid) +{ + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " @p cp.async.cg.shared.global [%0], [%1], 16;\n" + " @!p st.shared.v4.u32 [%0], {0,0,0,0};\n" // zero-fill if OOB + "}\n" + :: "r"(smem_addr), "l"(gmem_addr), "r"((int)valid) + ); +} + +// Usage in the A-tile loader: +for (int chunk = tid; chunk < M_TILE * A_CHUNKS; chunk += blockDim.x) { + int local_m = chunk / A_CHUNKS; + int local_k = chunk % A_CHUNKS; + int global_m = block_m + local_m; + int global_k = k_tile_base + local_k * CHUNK_K; + bool valid = (global_m < M) && (global_k < K); + uint32_t dst = smem_base_u32 + chunk * 16; + const void* src = A + global_m * K + global_k; // row-major + cp_async_pred(dst, src, valid); +} +``` + +**4b. Short-K drain — fully written out:** + +```cpp +// Drain: process stages that were prefetched but not yet consumed. +// After the main loop, 'pipe_in_flight' stages are pending. +// pipe_in_flight = min(PIPE - 1, num_k_stages) at loop exit. + +int drain_count = min(PIPE - 1, num_k_stages); +for (int d = 0; d < drain_count; ++d) { + // Wait for one more group to complete + CP_ASYNC_WAIT(drain_count - 1 - d); + __syncthreads(); // Barrier A + + int load_base = smem_load_idx * STAGE4_BYTES; + apply_paro_transform(smem, load_base, row_base, lane); + __syncthreads(); // Barrier B + + do_mma_inner_loop(smem, load_base, row_base, col_base, + warp_id, lane, groupID, tid4, RC, s2_regs, c2_regs); + __syncthreads(); // Barrier C + + smem_load_idx = (smem_load_idx + 1) % PIPE; +} +``` + +**4c. Small-M dispatch (batch size 1–16):** + +For autoregressive decode where M=1, a dedicated `M_TILE=16` kernel variant is registered. The kernel selector at the host dispatch site: + +```cpp +// Host dispatch +if (M <= 16) + launch_ampere_paro_int4_kernel<16, 128>(grid, block, smem_size, stream, ...); +else + launch_ampere_paro_int4_kernel<64, 128>(grid, block, smem_size, stream, ...); +``` + +The M=16 variant uses 2 warps (64 threads), halving the transform cost and fitting 4 CTAs/SM instead of 2. + +--- + +### Fix 5 (MAJOR) — Benchmarking Plan (Committed) + +**5a. Microbenchmark suite (CUDA kernels, runnable without full model):** + +``` +tests/kernels/ +├── bench_decode_only.cu # INT4→FP16 throughput, no MMA +├── bench_transform_only.cu # Paro transform throughput, bank conflict profiling +├── bench_mma_only.cu # Tensor Core saturation test (identity weights) +├── bench_inner_loop.cu # Combined decode+MMA, 1 CTA, sweep K +└── bench_full_kernel.cu # Full pipeline, sweep M×N×K +``` + +Each bench emits: +- `cycles_per_output_element` via CUDA events +- `effective_memory_bandwidth_GBs` +- `mma_active_fraction` (ratio of cycles with TC busy) + +**Profiling commands (required before any performance claim):** + +```bash +# TC utilization and stall analysis +ncu --metrics sm__pipe_tensor_op_hmma_cycles_active.avg.pct_of_peak_sustained_active,\ +smsp__warp_issue_stalled_wait_pct,\ +smsp__warp_issue_stalled_mio_throttle_pct \ + ./bench_inner_loop --M 64 --N 128 --K 4096 + +# Bank conflict verification (must show 0 after Fix 3) +ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld.sum \ + ./bench_inner_loop + +# Memory bandwidth (roofline) +ncu --metrics l1tex__t_bytes_pipe_lsu_mem_global_op_ld.sum,\ +dram__bytes.sum \ + ./bench_full_kernel --M 1 --N 4096 --K 4096 +``` + +**5b. Matrix dimension sweep (all must pass, including edge cases):** + +```python +# scripts/sweep_correctness.py +M_vals = [1, 2, 4, 7, 8, 16, 32, 64, 128, 256] # includes odd sizes +K_vals = [64, 128, 256, 512, 1024, 4096, 5120, 8192] # includes non-power-of-2 +N_vals = [128, 256, 512, 1024, 4096, 11008, 13824] # includes non-power-of-2 + +TOLERANCE_VS_FP32_REF = 1e-2 # max relative error on any output element +``` + +Pass criterion: max relative error < 1e-2 vs. FP32 cuBLAS reference for all (M, K, N) combinations above. + +**5c. End-to-end PPL targets:** + +| Configuration | Qwen3.5-4B WikiText-2 PPL | Target (vs FP16 baseline) | +|-------------|--------------------------|--------------------------| +| FP16 baseline | measured | reference | +| GPTQ 4-bit RTN | measured | ≤ +0.3 PPL | +| TRAM INT4-only (ROT_COUNT=0) | measured | ≤ +0.25 PPL | +| TRAM INT4 (ROT_COUNT=4) | measured | ≤ +0.15 PPL (must beat GPTQ) | +| TRAM INT4+INT8 mixed | measured | ≤ +0.10 PPL | + +If TRAM INT4 (ROT_COUNT=4) does not beat GPTQ 4-bit RTN PPL by at least 0.05 PPL at equal avg bits-per-weight, the transform is **not worth the runtime cost** and ROT_COUNT is set to 0. + +**5d. Nsight Systems timeline (required before claiming "Marlin-like" speed):** + +```bash +nsys profile --trace=cuda,nvtx \ + --stats=true \ + --output tram_quant_timeline \ + python run_inference.py --model Qwen3.5-4B-TRAM --tokens 200 --batch 1 +``` + +Pass criteria: +- Kernel launch overhead < 5 µs between consecutive GEMM calls +- SM active utilization ≥ 50% (batch=1) and ≥ 80% (batch=64) +- Memory bandwidth utilization ≥ 70% of theoretical HBM peak at batch=1 + +**5e. Rotation count ablation (runs in parallel with correctness sweep):** + +```bash +for rot in 0 2 4 6 8; do + python eval_ppl.py --model Qwen3.5-4B --rot-count $rot \ + --output results/rot_ablation_${rot}.json +done +``` + +Results tabulated in `results/rot_ablation_summary.md` before any implementation decisions. + +--- + +### Fix 6 (MODERATE) — Preload Scales Outside j-Loop + +Already reflected in Fix 1's register budget. Complete implementation: + +```cpp +// Preload all 8 column scales before the j-loop +uint32_t s2_regs[8], c2_regs[8]; +const half2 offset_base = __halves2half2( + __float2half(-1032.0f), __float2half(-1032.0f)); + +#pragma unroll +for (int j = 0; j < 8; ++j) { + const int global_col = col_base + j * 8 + groupID; + const half scale = S_smem[global_col]; + Half2Reg s2, c2; + s2.h2 = __halves2half2(scale, scale); + c2.h2 = __hmul2(s2.h2, offset_base); + s2_regs[j] = s2.u32; + c2_regs[j] = c2.u32; +} +``` + +--- + +### Fix 7 (MODERATE) — Epilogue: Coalesced Store via Transpose Buffer + +The D-fragment for `m16n8k16.row.col.f32.f16.f16.f32` maps per-lane as: +- `d0` → row `2*tid4+0`, col `groupID` (FP32 in RC[j][0] low half) +- `d1` → row `2*tid4+1`, col `groupID` (FP32 in RC[j][0] high half) +- `d2` → row `2*tid4+8`, col `groupID` (FP32 in RC[j][1] low half) +- `d3` → row `2*tid4+9`, col `groupID` (FP32 in RC[j][1] high half) + +This is column-major within a warp (8 lanes per column), so direct global stores are **not** coalesced. The epilogue must transpose through shared memory: + +```cpp +// Epilogue shared-memory transpose buffer +// Reuse weight smem (weights are consumed; smem is free after main loop). +half* epilogue_buf = reinterpret_cast(smem); +// epilogue_buf has M_TILE * N_TILE4 * 2 bytes = 64*128*2 = 16384 B available + +// Step 1: write FP32 accumulators to epilogue_buf in fragment order +// (column-major within a 16x8 warp tile) +#pragma unroll +for (int j = 0; j < 8; ++j) { + // D-fragment row/col for this warp+j + const int col = col_base + j * 8 + groupID; + const int r0 = row_base + 2 * tid4 + 0; + const int r1 = row_base + 2 * tid4 + 1; + const int r2 = row_base + 2 * tid4 + 8; + const int r3 = row_base + 2 * tid4 + 9; + + // Convert FP32 -> FP16 and write to smem + epilogue_buf[r0 * N_TILE4 + col] = __float2half(*(float*)&RC[j][0]); + epilogue_buf[r1 * N_TILE4 + col] = __float2half(*(float*)&RC[j][1] /* high */); + epilogue_buf[r2 * N_TILE4 + col] = __float2half(*(float*)&RC[j][2]); + epilogue_buf[r3 * N_TILE4 + col] = __float2half(*(float*)&RC[j][3] /* high */); +} + +__syncthreads(); // Barrier: all warps have written their tiles + +// Step 2: read smem in row-major order and write to global C +// Each thread writes one 128-bit (8 FP16) word = 8 output elements. +// 256 threads × 8 elements = 2048 elements per iteration. +// M_TILE * N_TILE4 = 64 * 128 = 8192 elements → 4 iterations. +for (int chunk = tid; chunk < (M_TILE * N_TILE4) / 8; chunk += blockDim.x) { + int out_m = (chunk * 8) / N_TILE4 + block_m; + int out_n = (chunk * 8) % N_TILE4 + block_n; + if (out_m < M && out_n < N) { + *reinterpret_cast(C + out_m * N + out_n) = + *reinterpret_cast(epilogue_buf + chunk * 8); + } +} +``` + +--- + +### Fix 8 (MINOR) — cp.async Address Conversion (Acknowledged) + +All shared-memory addresses in cp.async calls will use `__cvta_generic_to_shared()`. No raw pointer casts. + +```cpp +// Canonical cp.async helper +template +__device__ __forceinline__ void cp_async_cg( + uint8_t* dst_smem, const void* src_gmem) +{ + uint32_t smem_addr = __cvta_generic_to_shared(dst_smem); + asm volatile( + "cp.async.cg.shared.global [%0], [%1], %2;\n" + :: "r"(smem_addr), "l"(src_gmem), "n"(BYTES) + ); +} + +template +__device__ __forceinline__ void cp_async_cg_evict( + uint8_t* dst_smem, const void* src_gmem) +{ + uint32_t smem_addr = __cvta_generic_to_shared(dst_smem); + asm volatile( + "cp.async.cg.shared.global.L2::evict_first [%0], [%1], %2;\n" + :: "r"(smem_addr), "l"(src_gmem), "n"(BYTES) + ); +} +``` + +--- + +### Fix 9 (MINOR) — INT8 Kernel Decode Path (Acknowledged) + +The INT8 decode is distinct and scheduled as a separate deliverable. High-level contract: + +- B fragment: `uint8_t` per INT8 weight, not nibble-packed +- Each lane owns 4 INT8 values from one column (same PTX fragment layout as INT4) +- Payload is 4 bytes = one `uint32_t` per lane (no 16-bit tricks needed) +- Magic decode: `q ∈ [0,255]`, symmetric INT8 with zero-point 128: + ``` + deq = (q - 128) * scale + ``` + This uses FP32 conversion: `__int2float_rn(q) * scale - 128.0f * scale` + Or via PTX `cvt.rn.f16.u8` if available +- No cross-lane nibble alignment issues + +INT8 kernel is gated on INT4 kernel passing full validation. Timeline: INT4 validation → INT8 decode path → INT8 full kernel. + +--- + +### Consolidated Corrected Skeleton (Inner Loop, All Fixes Applied) + +```cpp +// ----------------------------------------------------------------------- +// TRAM-Quant INT4 Kernel — Corrected Inner Loop +// All Peer Review fixes 1–8 applied. +// ----------------------------------------------------------------------- + +// --- Constants (updated for padded A layout) --- +constexpr int PIPE = 4; +constexpr int M_TILE = 64; +constexpr int N_TILE4 = 128; +constexpr int K_STAGE = 64; +constexpr int ROT_COUNT = 4; +constexpr int A_ROW_STRIDE_HALF = 66; // 64 values + 2 padding +constexpr int A_ROW_STRIDE_BYTES= 132; +constexpr int A_STAGE_BYTES = M_TILE * A_ROW_STRIDE_BYTES; // 8448 +constexpr int XFORM_STAGE_BYTES = 256; +constexpr int W4_STAGE_BYTES = 4096; +constexpr int S4_STAGE_BYTES = 256; +constexpr int STAGE4_BYTES = A_STAGE_BYTES + XFORM_STAGE_BYTES + + W4_STAGE_BYTES + S4_STAGE_BYTES; // 13056 +constexpr int TOTAL_SMEM_BYTES = PIPE * STAGE4_BYTES; // 52224 + +// --- Warp/lane setup --- +const int warp_id = tid >> 5; +const int lane = tid & 31; +const int wm = warp_id & 3; +const int wn = warp_id >> 2; +const int row_base= wm * 16; +const int col_base= wn * 64; + +// --- PTX B-fragment lane ownership --- +const int groupID = lane >> 2; // output column within current n8 slice (0..7) +const int tid4 = lane & 3; // row-pair selector (0..3) + +// --- FP32 Accumulators --- +uint32_t RC[8][4]; +#pragma unroll +for (int j = 0; j < 8; ++j) { + RC[j][0] = 0; RC[j][1] = 0; RC[j][2] = 0; RC[j][3] = 0; +} + +// --- Double-buffer decode registers --- +uint32_t RB_fp16_cur[2], RB_fp16_nxt[2]; + +// --- Magic decode constant --- +constexpr uint32_t MAGIC_FP16 = 0x64006400u; +const half2 offset_base = __halves2half2( + __float2half(-1032.0f), __float2half(-1032.0f)); + +union Half2Reg { half2 h2; uint32_t u32; }; + +// --- A-fragment smem address (XOR-swizzled) --- +// Computed once here; updated inside ks loop. +uint32_t a_smem_base_u32 = __cvta_generic_to_shared( + smem + smem_load_idx * STAGE4_BYTES); + +// ----------------------------------------------------------------------- +// Inner K-stage loop (drops into the main pipeline loop body) +// ----------------------------------------------------------------------- + +// Preload all 8 scales before the j-loop (done once per ks, hoisted): +// NOTE: scales are per K-group; they stay constant across ks within one stage. +uint8_t* S_smem = smem + smem_load_idx * STAGE4_BYTES + + A_STAGE_BYTES + XFORM_STAGE_BYTES + W4_STAGE_BYTES; +uint32_t s2_regs[8], c2_regs[8]; +#pragma unroll +for (int j = 0; j < 8; ++j) { + const int global_col = col_base + j * 8 + groupID; + const half scale = ((const half*)S_smem)[global_col]; + Half2Reg s2, c2; + s2.h2 = __halves2half2(scale, scale); + c2.h2 = __hmul2(s2.h2, offset_base); + s2_regs[j] = s2.u32; + c2_regs[j] = c2.u32; +} + +// B smem pointer +const uint8_t* B_smem = smem + smem_load_idx * STAGE4_BYTES + + A_STAGE_BYTES + XFORM_STAGE_BYTES; + +#pragma unroll +for (int ks = 0; ks < 4; ++ks) { + uint32_t RA[4]; + + // --- A fragment: ldmatrix with XOR swizzle --- + { + int col = ks * 16 + (lane / 16) * 8; + int row = row_base + (lane % 16); + int vec = col / 8; + int phys_v = vec ^ (row & 7); + uint32_t a_addr = a_smem_base_u32 + + (row * A_ROW_STRIDE_HALF + phys_v * 8) * 2; + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(RA[0]), "=r"(RA[1]), "=r"(RA[2]), "=r"(RA[3]) + : "r"(a_addr) + ); + } + + // --- Decode j=0 preload --- + { + uint16_t p16 = fetch_b_fragment(B_smem, ks, 0, lane); + uint32_t p32 = p16; + uint32_t w0 = p32 & 0x00FFu, w1 = (p32 >> 8) & 0x00FFu; + uint32_t h0 = (w0 & 0x0Fu) | ((w0 & 0xF0u) << 12); + uint32_t h1 = (w1 & 0x0Fu) | ((w1 & 0xF0u) << 12); + asm volatile("fma.rn.f16x2 %0,%1,%2,%3;" : "=r"(RB_fp16_cur[0]) + : "r"(h0|MAGIC_FP16), "r"(s2_regs[0]), "r"(c2_regs[0])); + asm volatile("fma.rn.f16x2 %0,%1,%2,%3;" : "=r"(RB_fp16_cur[1]) + : "r"(h1|MAGIC_FP16), "r"(s2_regs[0]), "r"(c2_regs[0])); + } + + #pragma unroll + for (int j = 0; j < 8; ++j) { + if (j < 7) { + uint16_t p16 = fetch_b_fragment(B_smem, ks, j + 1, lane); + uint32_t p32 = p16; + uint32_t w0 = p32 & 0x00FFu, w1 = (p32 >> 8) & 0x00FFu; + uint32_t h0 = (w0 & 0x0Fu) | ((w0 & 0xF0u) << 12); + uint32_t h1 = (w1 & 0x0Fu) | ((w1 & 0xF0u) << 12); + const int jn = j + 1; + asm volatile("fma.rn.f16x2 %0,%1,%2,%3;" : "=r"(RB_fp16_nxt[0]) + : "r"(h0|MAGIC_FP16), "r"(s2_regs[jn]), "r"(c2_regs[jn])); + asm volatile("fma.rn.f16x2 %0,%1,%2,%3;" : "=r"(RB_fp16_nxt[1]) + : "r"(h1|MAGIC_FP16), "r"(s2_regs[jn]), "r"(c2_regs[jn])); + } + + MMA_SYNC_M16N8K16_F32(RC[j], RA, RB_fp16_cur); + + if (j < 7) { + RB_fp16_cur[0] = RB_fp16_nxt[0]; + RB_fp16_cur[1] = RB_fp16_nxt[1]; + } + } +} // end ks loop +// [Epilogue via smem transpose buffer follows — see Fix 7] +``` + +--- + +### What Remains Before First Execution + +| Step | Owner | Gate | +|------|-------|------| +| Write `fetch_b_fragment` into full kernel (Fix 3) | Alpha | — | +| Write `apply_paro_transform<4>` (Fix 2) | Alpha | — | +| Write predicated cp.async loaders (Fix 4a) | Alpha | — | +| Write short-K drain (Fix 4b) | Alpha | — | +| Write epilogue with smem transpose (Fix 7) | Alpha | — | +| Build and run decode-only unit test vs. scalar ref | Alpha | All above | +| Run `ncu` bank conflict metric → must show 0 | Alpha | Decode unit test | +| Run correctness sweep `M∈{1,7,64,128}, K∈{64,4096,5120}` | Alpha | Bank conflict clean | +| Run PPL on Qwen3.5-4B, ROT_COUNT ∈ {0,2,4} | Alpha | Correctness sweep | + +--- + +*Alpha out. Awaiting sign-off from Beta.* + +--- + +## Beta Sign-Off + +**Reviewer:** Copilot (Beta) +**Status: APPROVED — conditional implementation sign-off.** + +The Final Plan resolves all 5 must-fix items from the Peer Review. Going through them explicitly: + +**✅ Fix 1 — FP32 Accumulators:** `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32` is now the canonical instruction throughout. RC[8][4] at 32 regs is correct. Register budget math checks out at ~68 regs/thread — no spill risk on A100. FP16 epilogue conversion path is specified. This is the right call. + +**✅ Fix 2 — Transform:** The warp-shuffle based `apply_paro_transform` eliminates cross-warp divergence. The cost table (176 vs 256 cycles at ROT_COUNT=4) is honest and within acceptable range. Padding to `A_ROW_STRIDE_BYTES=132` kills the bank conflict I flagged. ROT_COUNT=4 as the default with a committed quality ablation is a rational starting point. + +**✅ Fix 3 — B Fragment Load:** The `fetch_b_fragment` u32 lane-pair trick is correct. Adjacent pairs hit different banks (pair_idx increments with 4B stride). This replaces the uint16_t layout immediately rather than deferring it. + +**✅ Fix 4 — Edge Cases:** Predicated cp.async covers M/N boundaries. Short-K drain is written out with the full 3-barrier sequence maintained. Small-M kernel variant (M_TILE=16 for batch≤16) is the right answer for autoregressive decode. + +**✅ Fix 5 — Benchmarking:** The plan is now specific, executable, and rigorous. `ncu` metrics are named. The PPL target (TRAM INT4 must beat GPTQ by ≥0.05 PPL or ROT_COUNT reverts to 0) is a real kill switch, not a vague aspiration. The roofline target (≥70% HBM BW at M=1) gives a concrete number to hit. + +**Three residual conditions for continued approval:** + +1. **The `__syncwarp()` calls inside `apply_paro_transform` are necessary.** The rotation pass reads back values written by the scaling pass. If scaling and rotation are interleaved across rows without `__syncwarp()`, lane A's rotation may read un-scaled values from lane B. The current implementation has `__syncwarp()` after scaling and between rotations — keep them. + +2. **The epilogue FP32→FP16 conversion has a latent bug.** The `*(float*)&RC[j][0]` extracts the full 32 bits as a float, but for the high half of a register pair, `*(float*)&RC[j][1]` still extracts the low float from RC[j][1], not the high float of RC[j][0]. The FP32 accumulator layout for `m16n8k16.row.col.f32` is **4 separate FP32 values in 4 separate 32-bit registers** (one float per register). The comment in the epilogue says "high half of RC[j][0]" which is wrong — RC[j][0] through RC[j][3] are each a single `float`. Correct the epilogue: + ```cpp + epilogue_buf[r0 * N_TILE4 + col] = __float2half(*(const float*)&RC[j][0]); + epilogue_buf[r1 * N_TILE4 + col] = __float2half(*(const float*)&RC[j][1]); + epilogue_buf[r2 * N_TILE4 + col] = __float2half(*(const float*)&RC[j][2]); + epilogue_buf[r3 * N_TILE4 + col] = __float2half(*(const float*)&RC[j][3]); + ``` + Fix this before the first execution. + +3. **The `a_smem_base_u32` must update each iteration of the outer K-stage loop** when `smem_load_idx` advances. Currently it is computed once before the ks loop. Either recompute it at the top of each pipeline iteration or pass it as a parameter to `do_mma_inner_loop`. + +These are correctness issues, not design issues. The architecture is sound. + +**Bottom line:** The plan is solid. Execute in the order Alpha's table specifies — unit test decode first, bank-conflict clean second, correctness sweep third, PPL last. Do not run the PPL sweep until `ncu` confirms zero bank conflicts and the correctness sweep is green. + +*Beta sign-off complete. Proceed to implementation.* + +--- + +## Implementation Status + +The approved structural fixes have now been applied to the standalone `gptq_pro` CUDA scaffold under `gptqmodel_ext/gptq_pro/`. + +- The MMA path was switched from FP16 accumulation to FP32 accumulation (`mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32`). +- The kernel accumulator storage was updated from `RC[J_TILES][2]` FP16-pair registers to `RC[J_TILES][4]` FP32 outputs. +- The epilogue store contract was corrected so each lane writes its 4 owned output elements using the PTX fragment map (`groupID = lane >> 2`, `tid4 = lane & 3`). +- The B-fragment contract was corrected to match the official PTX `m16n8k16.row.col` ownership for matrix B: each lane now owns one logical B column (`groupID = lane >> 2`) across rows `{2*tid4, 2*tid4+1, 2*tid4+8, 2*tid4+9}` rather than two columns from one row pair. +- The standalone validation harness was upgraded to validate the FP32-accumulator MMA path instead of the old FP16-accumulator path. +- The decode-only validator was tightened to check all 4 decoded FP16 values from each lane-local packed INT4 word, not just a subset of halves. +- The MMA-step validator now uses non-uniform, FP16-exact synthetic A and B tiles wired through the PTX-defined fragment ownership, so row/column transpositions fail immediately instead of slipping through a uniform-data smoke test. +- The validator now checks CUDA runtime calls explicitly, which prevents false positives or garbage summaries when a selected GPU is unavailable or out of memory on a shared machine. +- A repeatable helper script now exists at `scripts/run_gptq_pro_validate.sh`; it builds the standalone validator, targets GPU `2` by default, and falls back to GPU `3` when the primary 3060 is unavailable. +- Misleading comments about a completed production pipeline and a magic-decode fast path were corrected so the scaffold is honest about what is implemented versus still placeholder. + +### Validation Run + +The updated scaffold was verified locally with: + +- `nvcc -arch=sm_80 -std=c++17 -c gptq_pro_kernel.cu -o /tmp/gptq_pro_kernel.o` +- `nvcc -arch=sm_80 -std=c++17 gptq_pro_validate.cu -o /tmp/gptq_pro_validate_phase2` +- `CUDA_VISIBLE_DEVICES=3 /tmp/gptq_pro_validate_phase2` → `64 / 64 checks passed` +- `scripts/run_gptq_pro_validate.sh` → builds and runs the same standalone validator successfully on the default RTX 3060 (`2`) +- `python -m pytest tests/qcfg/test_failsafe_meta.py -q` → `14 passed` + +**Shared-machine note:** one intermediate retry on RTX 3060 index `2` returned `cudaMalloc(...): out of memory` during validator setup, but a later rerun via `scripts/run_gptq_pro_validate.sh` passed on the same board once contention cleared. In practice, both RTX 3060s (`2` and `3`) are now known-good validation targets for this standalone scaffold. + +### Remaining Known Scope + +This is still a **standalone scaffold**, not the full final kernel from Alpha's plan: + +- `gptq_pro_gemm_kernel` is now end-to-end functional on `sm80`, but it is a compact single-warp scaffold rather than the planned multi-warp Ampere kernel. +- The previous placeholder path was replaced with explicit shared-memory staging for activations, per-column scales, and packed INT4 B fragments; helper-level validation alone is no longer the only safety net. +- The invalid `ldmatrix` path that faulted with `cudaErrorMisalignedAddress` was removed in favor of validator-backed manual A-fragment packing for the current scaffold. +- Paro rotation metadata is still not wired into the runtime path. +- The direct global-store epilogue is correct but not yet coalesced via a shared-memory transpose buffer. +- The runtime currently models symmetric INT4 with implicit zero-point `8` and requires `group_size` to be a multiple of `16`; asymmetric qzero metadata still needs its own path if required. +- The INT8 sibling kernel, real `cp.async` pipeline, and benchmark suite remain future work. + +### Remaining Speed Headroom + +- Restore a validated XOR-swizzled `ldmatrix` A-load path once the shared-memory layout is finalized for this standalone kernel. +- Reintroduce real Ampere async staging (`cp.async` / deeper pipelining) after the global->shared contract is locked down. +- Expand from the current `1 warp x 16x64x16` scaffold to the planned larger CTA tiles (`64x128` INT4, separate INT8 rescue path) for better memory/TC overlap. +- Replace the direct epilogue stores with a coalesced shared-memory transpose buffer. +- Fuse real Paro metadata and benchmark transform cost / bank conflicts instead of skipping the runtime transform path. +- Add Nsight Compute / Systems microbenchmarks so future tuning is driven by measured stall reasons rather than source inspection alone. + +--- + +## Verification Run (Post-Implementation) + +**Date:** 2026-03-20 +**Environment:** Python 3.13.11, PyTorch 2.10.0, CUDA 12.x, 3× RTX 3090 + 2× RTX 3060 + +### CUDA Kernel Build +``` +$ nvcc -arch=sm_80 -std=c++17 -c gptq_pro_kernel.cu -o /tmp/gptq_pro_kernel.o +BUILD OK + +$ nvcc -arch=sm_80 -std=c++17 gptq_pro_validate.cu -o /tmp/gptq_pro_validate +VALIDATE BUILD OK +``` + +### Standalone Validator +``` +=== TODO 1: decode-only validation === + 32 / 32 lanes passed +=== TODO 2: ks/j MMA step validation === + 32 / 32 lanes passed + +=== Overall: 64 / 64 checks passed === +``` + +### RTX 3060 Validation Path +```bash +$ CUDA_VISIBLE_DEVICES=2 /tmp/gptq_pro_validate_phase2 +CUDA error at gptq_pro_validate.cu:265: out of memory + +$ CUDA_VISIBLE_DEVICES=3 /tmp/gptq_pro_validate_phase2 +=== TODO 1: decode-only validation === + 32 / 32 lanes passed +=== TODO 2: ks/j MMA step validation === + 32 / 32 lanes passed + +=== Overall: 64 / 64 checks passed === + +$ scripts/run_gptq_pro_validate.sh +==> Selected GPU 2 +=== TODO 1: decode-only validation === + 32 / 32 lanes passed +=== TODO 2: ks/j MMA step validation === + 32 / 32 lanes passed + +=== Overall: 64 / 64 checks passed === +``` + +### Python Test Suite +``` +$ python -m pytest tests/qcfg/test_failsafe_meta.py -q +14 passed in 4.65s +``` + +### Concurrent Fixes Included in This Commit +- **`gptqmodel/utils/hf.py`**: Added `load_tokenizer_with_model_config()` and `ensure_hf_model_config_token_ids()` to propagate model config token IDs (bos, eos, pad) to Tokenicer wrappers, fixing tokenizer/config mismatch bugs. +- **`gptqmodel/models/base.py`**: Replaced raw `Tokenicer.load()` calls with `load_tokenizer_with_model_config()` at all three entry points (init, quantize, load_quantized). +- **`tests/test_hf_config_autofix.py`**: Unit tests for the new HF config autofix helpers. + +--- + +## Qwen3.5-4B Quantization & Benchmark Results + +**Date:** 2026-03-20 +**Model:** [wangzhang/Qwen3.5-4B-abliterated](https://huggingface.co/wangzhang/Qwen3.5-4B-abliterated) +**Architecture:** Qwen3.5 (hybrid linear + full attention, 32 layers, hidden_size=2560) +**Quantization:** GPTQ 4-bit, group_size=128 + +### Environment +- Python 3.13.11, PyTorch 2.10.0, Transformers 5.3.0 +- GPTQModel 5.8.0 (dev), TritonV2 kernel backend +- GPUs: 3× RTX 3090 (24 GB) + 2× RTX 3060 (12 GB), Driver 570.211.01 + +### Quantization Summary + +| Metric | Value | +|--------|-------| +| FP16 model size | 7.83 GB | +| GPTQ 4-bit model size | 2.92 GB | +| Size reduction | 62.71% (4.91 GB saved) | +| Effective BPW | 4.29 bpw | +| Calibration samples | 128 (WikiText-2 train, min 512 chars) | +| Quantization time | 181.4s (1× RTX 3090) | + +### Perplexity (WikiText-2, test split) + +| Configuration | Perplexity | +|---------------|------------| +| GPTQ 4-bit g128 | **8.6759** | + +Sliding-window evaluation: max_length=2048, stride=512, 578 windows, 297,053 tokens total. + +### Baseline Generation Speed Benchmark (GPTQModel / Transformers runtime) + +Greedy decoding (do_sample=False), 10 prompts averaged per setting. +This first-pass benchmark is preserved for comparison, but it is **not** the final answer to the +corrected request because it used the GPTQModel / Transformers runtime rather than vLLM. + +| GPU Config | max_new_tokens | Tokens/sec | +|------------|---------------|------------| +| **1× RTX 3090** | 128 | **24.16** | +| 1× RTX 3090 | 256 | 24.31 | +| 1× RTX 3090 | 512 | 24.53 | +| **2× RTX 3090** | 128 | **17.62** | +| 2× RTX 3090 | 256 | 17.73 | +| 2× RTX 3090 | 512 | 17.67 | + +### Baseline Analysis + +- **1× RTX 3090** delivers ~24.3 tok/s sustained across all sequence lengths, consistent + with the model fitting entirely in 24 GB VRAM (2.92 GB quantized). +- **2× RTX 3090** (device_map="auto") is **~27% slower** than 1× due to pipeline-parallel + cross-GPU communication overhead. Since the model fits in a single GPU's memory, + splitting across 2 GPUs adds inter-GPU transfer latency without a memory benefit. + Multi-GPU shines for models that exceed single-GPU capacity. +- Throughput is stable across 128→512 token generation lengths, indicating the kernel + is compute-bound rather than launch-overhead-bound at these sequence lengths. +- The TritonV2 kernel backend was auto-selected for inference. + +### Baseline Notes +- Linear attention layers (conv1d, in_proj_a/b) were intentionally left unquantized + as they use different compute patterns from standard attention projections. +- The `flash-linear-attention` fast path was unavailable; torch fallback was used for + the linear attention layers. Installing `fla` + `causal-conv1d` would likely improve + throughput further. + +### Follow-up: Original PPL, GPTQ-Pro, and vLLM + +The corrected follow-up run compared against the original BF16 model, quantized a fresh +`QuantizeConfig.gptq_pro()` checkpoint, and benchmarked that checkpoint with vLLM rather than the +Transformers runtime. + +#### Original vs Quantized Perplexity (WikiText-2, test split) + +All three numbers below use the same sliding-window setup: `max_length=2048`, `stride=512`, +578 windows, and 297,053 tokens total. + +| Configuration | Perplexity | Delta vs original | +|---------------|------------|-------------------| +| **Original BF16** | **8.3116** | baseline | +| GPTQ 4-bit g128 | 8.6759 | +0.3643 | +| **GPTQ-Pro 4-bit g128** | **8.6314** | **+0.3198** | + +GPTQ-Pro recovered `0.0445` PPL versus the earlier plain GPTQ run under the same evaluation setup. + +#### GPTQ-Pro Quantization Summary + +| Metric | Value | +|--------|-------| +| Model load time | 4.9s | +| Quantization time | 324.9s | +| Calibration samples | 128 | +| Output format | GPTQ-compatible (`format=gptq`, `checkpoint_format=gptq`) | +| Key quality knobs | `act_group_aware=true`, `mse=2.0`, adaptive damping, `SmoothAuto` failsafe | + +Saved GPTQ-Pro metadata confirms GAR, MSE search, adaptive damping, and failsafe smoothing while +still producing a GPTQ-compatible checkpoint that vLLM can consume through Marlin. + +#### vLLM GPTQ-Pro Benchmark + +Stock vLLM 0.17.0 and the tested nightly build still do not load this `qwen3_5_text` checkpoint +cleanly out of the box in this environment. The benchmark therefore used a temporary runtime patch +that: + +- wraps the Hugging Face `qwen3_5_text` config in vLLM's `Qwen3_5Config` +- forces `language_model_only=True` +- skips multimodal / vision initialization +- remaps checkpoint weights from `model.*` to `language_model.model.*` + +Once patched, vLLM selected `gptq_marlin` automatically and reported +`Using MarlinLinearKernel for GPTQMarlinLinearMethod`. + +| GPU Config | TP Size | max_new_tokens | Tokens/sec | Engine init | +|------------|---------|----------------|------------|-------------| +| **1× RTX 3090** | 1 | 128 | **175.21** | 37.03s | +| 1× RTX 3090 | 1 | 256 | 178.14 | 37.03s | +| **2× RTX 3090** | 2 | 128 | **194.20** | 56.53s | +| 2× RTX 3090 | 2 | 256 | 206.53 | 56.53s | + +#### Why the Earlier Speed Was Slow + +- The original speed run above was measured with the GPTQModel / Transformers inference path, + not vLLM, so it never exercised vLLM scheduling or Marlin's end-to-end runtime path. +- Qwen3.5 linear-attention layers were on the torch fallback path because the + `flash-linear-attention` / `fla` and `causal-conv1d` fast-path dependencies were unavailable. +- The plain GPTQ checkpoint already fit comfortably on one 24 GB RTX 3090, so the earlier + `device_map="auto"` two-GPU split only added inter-GPU communication overhead. +- After switching to the GPTQ-Pro checkpoint and the vLLM `gptq_marlin` path, throughput improved + to roughly `7.3×` the earlier 1×-GPU Transformers result (`178.14 / 24.53`) and `11.7×` the + earlier 2×-GPU Transformers result (`206.53 / 17.67`). + +#### Follow-up Notes + +- Two-GPU tensor parallelism helped only modestly here: `+10.8%` at 128 tokens and `+15.9%` + at 256 tokens, which is expected for a 4-bit 4B model that already fits on one card. +- The patched vLLM path still emits noisy shutdown warnings (`destroy_process_group()` / + `Engine core proc ... died unexpectedly`) after writing the benchmark JSON, so the text-only + Qwen3.5 integration should still be treated as upstream-incomplete. +- vLLM also logged an unrelated plugin load error (`ModuleNotFoundError: No module named 'reap'`) + and FLA shape warnings during warmup/inference. These did not prevent successful runs, but they + are worth cleaning up before treating this path as production-ready. + +### Follow-up: 27B replacement for the GGUF-only HauhauCS request + +The requested repository +`HauhauCS/Qwen3.5-35B-A3B-Uncensored-HauhauCS-Aggressive` turned out to be GGUF-only, with the +model card explicitly saying `GPTQ — coming soon`. Because GPTQModel requires a Transformers / +Safetensors source checkpoint for GPTQ-Pro quantization, the replacement run used the +user-approved Transformers checkpoint `huihui-ai/Huihui-Qwen3.5-27B-abliterated`. + +#### 27B original vs GPTQ-Pro perplexity + +Shared-machine VRAM pressure from another long-lived process on the second RTX 3090 repeatedly +OOMed the original-model evaluation path, even after reducing context length. The stable fallback +was a fixed regression slice on one clean RTX 3090 plus CPU offload: + +- dataset: WikiText-2 raw test +- `max_length=256` +- `stride=256` +- `max_windows=16` +- `4096` scored tokens total + +| Configuration | Perplexity | Delta vs original | +|---------------|------------|-------------------| +| **Original BF16** | **11.6266** | baseline | +| **GPTQ-Pro 4-bit g32** | **12.0161** | **+0.3895** | + +These absolute numbers should not be compared directly against the earlier 4B full-dataset sweep: +they use much shorter context windows and fewer total tokens because the shared environment could +not sustain the full-length BF16 evaluation for this larger model. + +#### 27B GPTQ-Pro quantization summary + +| Metric | Value | +|--------|-------| +| Model | `huihui-ai/Huihui-Qwen3.5-27B-abliterated` | +| Load time | `5.3s` | +| Quantization time | `2273.6s` | +| Save time | `21.8s` | +| Calibration samples | `128` | +| Output size | ~`18G` | +| Output shards | `5` safetensors files | +| Key quality / stability knobs | `group_size=32`, `balanced` VRAM strategy, `gc_mode=on_stage_end`, `auto_forward_data_parallel=false`, `wait_for_submodule_finalizers=true`, `ExpertsRoutingBypass(batch_size=2)`, disk offload | + +The offload scratch directory peaked in the mid-teens of gigabytes and showed active per-module +staging throughout the run, which confirmed that the MoE-aware serial / offload path was doing the +heavy lifting rather than silently hanging. + +#### 27B vLLM smoke result + +The installed `vLLM 0.17.0` still does not cleanly deploy this Qwen 3.5 text-family checkpoint in +the tested environment. A one-shot offline `LLM.generate()` smoke test on the new quantized GPTQ-Pro +checkpoint did select `gptq_marlin`, but then failed before generation with: + +- `TypeError: Invalid type of HuggingFace config` +- expected `Qwen3_5Config` +- found `Qwen3_5TextConfig` + +For this replacement model, the mismatch surfaced through vLLM's multimodal renderer path rather +than the earlier text-only load path, but the root integration problem is the same: upstream vLLM +still does not fully normalize the Hugging Face Qwen 3.5 config family in this environment. diff --git a/pyproject.toml b/pyproject.toml index c047c022b..fb3367b8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,12 @@ [build-system] requires = [ - "setuptools>=80.9", - "ninja>=1.13.0", # required for faster compilataion + "setuptools>=77.0.1,<83", ] build-backend = "setuptools.build_meta" [project] name = "GPTQModel" -dynamic = ["version"] +dynamic = ["version", "dependencies"] description = "Production ready LLM model compression/quantization toolkit with hw accelerated inference support for both cpu/gpu via HF, vLLM, and SGLang." readme = "README.md" requires-python = ">=3.10" @@ -31,40 +30,14 @@ classifiers = [ "Topic :: Scientific/Engineering :: Artificial Intelligence", "Topic :: Scientific/Engineering :: Information Analysis", ] -dependencies = [ - "accelerate>=1.10.1", - "numpy==2.2.6; python_version < '3.14'", - "numpy>=2.3.0; python_version >= '3.14'", - "torch>=2.8.0", - "safetensors>=0.6.2", - "transformers>=4.57.1", - "threadpoolctl>=3.6.0", - "packaging>=24.2", - "device-smi>=0.5.3", - "protobuf>=6.32.0", - "pillow>=11.3.0", - "hf_transfer>=0.1.9; python_version < '3.14'", - "huggingface_hub>=0.34.4", - "tokenicer>=0.0.8", - "logbar>=0.2.1", - "maturin>=1.9.4", # required by safetensors and hf_transfer - "datasets>=3.6.0", - "pyarrow>=21.0", - "dill>=0.3.8", # datasets requirements - "pypcre>=0.2.12", - "torchao>=0.14.1", # fix bad transformers 4.57.1 breaking torchao compat - "kernels>=0.12.2", # For CPU kernels - "defuser>=0.0.6", - # "cython>=3.1.4", # required by hf-xet/hf-transfer -# "flash-attn>=2.8.3", <-- install for lower vram usage -] - -[project.scripts] -gptqmodel = "gptqmodel.cli.gptqmodel:main" [project.urls] Homepage = "https://github.com/ModelCloud/GPTQModel" +[tool.setuptools.dynamic] +dependencies = { file = ["requirements.txt"] } + + [project.optional-dependencies] test = [ "pytest>=8.3.5", @@ -86,16 +59,30 @@ sglang = [ bitblas = [ "bitblas==0.1.0.post1", ] +bitsandbytes = [ + "bitsandbytes>=0.49.3", +] hf = [ "optimum>=1.21.2", ] eval = [ - "lm_eval>=0.4.7", - "evalplus>=0.3.1", + "Evalution", ] triton = [ "triton>=3.4.0", ] +marlin-cuda12 = [ + "nvidia-cuda-runtime-cu12==12.9.79", + "nvidia-cublas-cu12==12.9.1.4", + "nvidia-cusparse-cu12==12.5.10.65", + "nvidia-cusolver-cu12==11.7.5.82", +] +marlin-cuda = [ + "nvidia-cuda-runtime>=13.0.96", + "nvidia-cublas>=13.1.0.3", + "nvidia-cusparse>=12.6.3.3", + "nvidia-cusolver>=12.0.4.66", +] openai = [ "uvicorn", "fastapi", @@ -104,3 +91,6 @@ openai = [ mlx = [ "mlx_lm>=0.24.0", ] + +[tool.uv] +torch-backend = "auto" diff --git a/requirements.txt b/requirements.txt index fa82639c5..dd89e918e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,23 +1,22 @@ -accelerate>=1.10.1 +accelerate>=1.13.0 numpy==2.2.6; python_version < "3.14" numpy>=2.3.0; python_version >= "3.14" torch>=2.8.0 -safetensors>=0.6.2 -transformers>=4.57.1 +safetensors>=0.7.0 +transformers>=5.4.0 threadpoolctl>=3.6.0 packaging>=24.2 device-smi>=0.5.2 -protobuf>=6.32.0 +protobuf>=7.34.0 pillow>=11.3.0 -pypcre>=0.2.12 -hf_transfer>=0.1.9; python_version < "3.14" -huggingface_hub>=0.34.4 -tokenicer>=0.0.8 -logbar>=0.2.1 +pypcre>=0.3.0 +tokenicer>=0.0.12 +logbar>=0.4.1 +jinja2>=3.1.0 +ninja>=1.13.0 maturin>=1.9.4 datasets>=3.6.0 pyarrow>=21.0 dill>=0.3.8 -torchao>=0.14.1 -kernels>=0.12.2 -defuser>=0.0.6 +torchao>=0.16.0 +defuser>=0.0.18 diff --git a/scripts/arch.md b/scripts/arch.md new file mode 100644 index 000000000..f4af337a2 --- /dev/null +++ b/scripts/arch.md @@ -0,0 +1,37 @@ +# CI Architecture + +## Naming + +- `.github/scripts/ci_*.py` are the only workflow entrypoints. +- Shared data stays in `.github/scripts/*.yaml`. +- Shared logic lives in reusable modules instead of one-off CLIs. + +## GPTQModel unit test flow + +1. `check-vm` +- `ci_workflow.py check-vm` computes `ip`, `run_id`, `install_ts`, and matrix parallelism, then writes them to `GITHUB_OUTPUT`. + +2. `list-test-files` +- `ci_workflow.py list-tests` scans `tests/`, filters ignored and regex-matched cases, splits them into torch/model/mlx buckets, and builds shared env matrices from `deps.yaml` and `test.yaml`. + +3. `prepare` +- `ci_workflow.py prepare-common-envs` deduplicates shared env rows, creates or refreshes those uv envs in parallel, installs base requirements, and syncs the repo-scoped git dependencies. + +4. `torch` and `torch-models` +- `ci_workflow.py activate-test-env` resolves `GPU_COUNT`, `HAS_SPECIFIC_DEPS`, `ENV_NAME`, and `UV_CACHE_DIR`. +- `ci_workflow.py setup-specific-env` applies per-test compiler/python settings and test-specific install/uninstall package rules from `deps.yaml` and `blacklist.yaml`. +- `ci_workflow.py install-package` serializes source installs with lock files so only one job populates a dedicated env at a time. +- `ci_gpu.py allocate` and `ci_gpu.py release` talk to the shared GPU allocator service. +- `ci_tests.py run` executes pytest, streams logs, keeps GPU leases alive, and logs VRAM usage. +- `ci_tests.py check-log` prints the same failure excerpts the old shell step grepped from the test log. + +## Config files + +- `.github/scripts/test.yaml`: per-group and per-test runtime config, mainly Python version and GPU count. +- `.github/scripts/deps.yaml`: extra packages needed by individual tests or directories. +- `.github/scripts/blacklist.yaml`: packages that must be removed for specific tests. + +## Maintenance rule + +- Add new workflow behavior by extending an existing `ci_*` entrypoint first. +- Only keep shell in workflow files for minimal glue such as checkout, `source /opt/uv/setup_uv_venv.sh ...`, and GitHub expression wiring. diff --git a/scripts/benchmark_awq_cuda_fp32_reduce_ab.py b/scripts/benchmark_awq_cuda_fp32_reduce_ab.py new file mode 100644 index 000000000..a4e62a7f3 --- /dev/null +++ b/scripts/benchmark_awq_cuda_fp32_reduce_ab.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import math +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from tabulate import tabulate + +@dataclass(frozen=True) +class BenchCase: + case_id: str + batch: int + seq: int + in_features: int + out_features: int + group_size: int = 128 + krot: int = 8 + + +DEFAULT_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_q_proj", batch=4, seq=128, in_features=2048, out_features=2048), + BenchCase("prefill_k_proj", batch=1, seq=128, in_features=2048, out_features=512), + BenchCase("prefill_gate_proj", batch=1, seq=128, in_features=2048, out_features=8192), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + +QUICK_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + + +def _resolve_dtype(name: str) -> torch.dtype: + if name == "fp16": + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype: {name}") + + +def _subset_cases(cases: list[BenchCase], shard_index: int, num_shards: int) -> list[BenchCase]: + if num_shards <= 0: + raise ValueError("`num_shards` must be positive.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError(f"`shard_index` must be in [0, {num_shards - 1}].") + return [case for index, case in enumerate(cases) if index % num_shards == shard_index] + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + packed = torch.zeros((unpacked.shape[0], unpacked.shape[1] // pack_factor), dtype=torch.int32) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_quant_buffers(case: BenchCase, dtype: torch.dtype, bits: int = 4) -> dict[str, torch.Tensor]: + from gptqmodel.utils.paroquant import build_identity_rotation_buffers + + groups = case.in_features // case.group_size + int_weight = torch.randint(0, 2**bits, size=(case.in_features, case.out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, case.out_features), dtype=torch.int32) + scales = (torch.rand(groups, case.out_features, dtype=torch.float32) * 0.5) + 0.75 + scales = scales.to(dtype=dtype) + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=case.in_features, + group_size=case.group_size, + krot=case.krot, + dtype=dtype, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + + return { + "qweight": _pack_awq_tensor(int_weight, bits), + "qzeros": _pack_awq_tensor(zero_points, bits), + "scales": scales, + "pairs": pairs, + "theta": theta, + "channel_scales": channel_scales, + } + + +def _dense_reference( + x: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + bits: int, + group_size: int, +) -> torch.Tensor: + from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm + + dense_weight = dequantize_gemm( + qweight=qweight, + qzeros=qzeros, + scales=scales, + bits=bits, + group_size=group_size, + ).to(device=x.device, dtype=x.dtype) + return torch.matmul(x, dense_weight) + + +def _benchmark_ms(fn, device: torch.device, warmup: int, iters: int) -> float: + with torch.inference_mode(): + for _ in range(warmup): + fn() + torch.cuda.synchronize(device) + start = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize(device) + return (time.perf_counter() - start) * 1e3 / iters + + +def _format_speedup(speedup: float) -> str: + return f"{speedup:.3f}x" + + +def _mode_label(mode: str, split_k: int) -> str: + return f"{mode}_k{split_k}" + + +def _path_label(path: str, mode: str, split_k: int) -> str: + if path == "dequant": + return "dequant" + return f"fused_{_mode_label(mode, split_k)}" + + +def _run_path( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + bits: int, + group_size: int, + path: str, + split_k: int, + mode: str, +) -> torch.Tensor: + if path == "dequant": + return _dense_reference( + x, + qweight, + qzeros, + scales, + bits=bits, + group_size=group_size, + ) + + from gptqmodel.utils.awq import awq_gemm_forward + return awq_gemm_forward( + x, + qweight, + scales, + qzeros, + split_k, + fp32_accum=(mode == "fp32_accum"), + ) + + +def _run_suite( + device: torch.device, + dtype: torch.dtype, + warmup: int, + iters: int, + quick: bool, + rotate_inputs: bool, + shard_index: int, + num_shards: int, + baseline_path: str, + candidate_path: str, + baseline_mode: str, + candidate_mode: str, + baseline_split_k: int, + candidate_split_k: int, +) -> dict[str, Any]: + cases = _subset_cases(QUICK_CASES if quick else DEFAULT_CASES, shard_index=shard_index, num_shards=num_shards) + rows = [] + speedups = [] + candidate_wins = 0 + baseline_label = _path_label(baseline_path, baseline_mode, baseline_split_k) + candidate_label = _path_label(candidate_path, candidate_mode, candidate_split_k) + + for index, case in enumerate(cases): + torch.manual_seed(1000 + index) + buffers = _make_quant_buffers(case, dtype=dtype) + qweight = buffers["qweight"].to(device) + qzeros = buffers["qzeros"].to(device) + scales = buffers["scales"].to(device) + x = torch.randn((case.batch * case.seq, case.in_features), device=device, dtype=dtype) + + kernel_input = x + if rotate_inputs: + from gptqmodel.utils.paroquant import apply_paroquant_rotation + + pairs = buffers["pairs"].to(device) + theta = buffers["theta"].to(device) + channel_scales = buffers["channel_scales"].to(device) + kernel_input = apply_paroquant_rotation( + x, + pairs, + theta, + scales=channel_scales, + group_size=case.group_size, + ) + + with torch.inference_mode(): + dense = _dense_reference(kernel_input, qweight, qzeros, scales, bits=4, group_size=case.group_size) + baseline = _run_path( + kernel_input, + qweight, + scales, + qzeros, + bits=4, + group_size=case.group_size, + path=baseline_path, + split_k=baseline_split_k, + mode=baseline_mode, + ) + candidate = _run_path( + kernel_input, + qweight, + scales, + qzeros, + bits=4, + group_size=case.group_size, + path=candidate_path, + split_k=candidate_split_k, + mode=candidate_mode, + ) + + baseline_dense = (baseline - dense).abs() + candidate_dense = (candidate - dense).abs() + baseline_candidate = (baseline - candidate).abs() + + baseline_ms = _benchmark_ms( + lambda: _run_path( + kernel_input, + qweight, + scales, + qzeros, + bits=4, + group_size=case.group_size, + path=baseline_path, + split_k=baseline_split_k, + mode=baseline_mode, + ), + device=device, + warmup=warmup, + iters=iters, + ) + candidate_ms = _benchmark_ms( + lambda: _run_path( + kernel_input, + qweight, + scales, + qzeros, + bits=4, + group_size=case.group_size, + path=candidate_path, + split_k=candidate_split_k, + mode=candidate_mode, + ), + device=device, + warmup=warmup, + iters=iters, + ) + speedup = baseline_ms / candidate_ms + speedups.append(speedup) + winner = candidate_label if candidate_ms < baseline_ms else baseline_label + candidate_wins += int(candidate_ms < baseline_ms) + + rows.append( + { + "case_id": case.case_id, + "batch": case.batch, + "seq": case.seq, + "in_features": case.in_features, + "out_features": case.out_features, + "dtype": str(dtype).replace("torch.", ""), + "baseline_path": baseline_path, + "candidate_path": candidate_path, + "baseline_mode": baseline_mode, + "candidate_mode": candidate_mode, + "baseline_split_k": baseline_split_k, + "candidate_split_k": candidate_split_k, + "baseline_ms": baseline_ms, + "candidate_ms": candidate_ms, + "speedup": speedup, + "winner": winner, + "baseline_dense_max_abs": baseline_dense.max().item(), + "candidate_dense_max_abs": candidate_dense.max().item(), + "baseline_candidate_max_abs": baseline_candidate.max().item(), + "baseline_candidate_mean_abs": baseline_candidate.mean().item(), + } + ) + + geo_mean_speedup = math.exp(sum(math.log(v) for v in speedups) / len(speedups)) if speedups else float("nan") + return { + "baseline_label": baseline_label, + "candidate_label": candidate_label, + "rows": rows, + "geo_mean_speedup": geo_mean_speedup, + "candidate_wins": candidate_wins, + "case_count": len(speedups), + } + + +def run( + device: torch.device, + dtype: torch.dtype, + warmup: int, + iters: int, + quick: bool, + shard_index: int, + num_shards: int, + baseline_path: str, + candidate_path: str, + baseline_mode: str, + candidate_mode: str, + baseline_split_k: int, + candidate_split_k: int, +) -> dict[str, Any]: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + return { + "device": torch.cuda.get_device_name(device), + "cuda_device": str(device), + "dtype": str(dtype).replace("torch.", ""), + "warmup": warmup, + "iters": iters, + "quick": quick, + "shard_index": shard_index, + "num_shards": num_shards, + "baseline_path": baseline_path, + "candidate_path": candidate_path, + "baseline_mode": baseline_mode, + "candidate_mode": candidate_mode, + "baseline_split_k": baseline_split_k, + "candidate_split_k": candidate_split_k, + "awq": _run_suite( + device=device, + dtype=dtype, + warmup=warmup, + iters=iters, + quick=quick, + rotate_inputs=False, + shard_index=shard_index, + num_shards=num_shards, + baseline_path=baseline_path, + candidate_path=candidate_path, + baseline_mode=baseline_mode, + candidate_mode=candidate_mode, + baseline_split_k=baseline_split_k, + candidate_split_k=candidate_split_k, + ), + "paroquant": _run_suite( + device=device, + dtype=dtype, + warmup=warmup, + iters=iters, + quick=quick, + rotate_inputs=True, + shard_index=shard_index, + num_shards=num_shards, + baseline_path=baseline_path, + candidate_path=candidate_path, + baseline_mode=baseline_mode, + candidate_mode=candidate_mode, + baseline_split_k=baseline_split_k, + candidate_split_k=candidate_split_k, + ), + } + + +def _print_suite(name: str, results: dict[str, Any]) -> None: + print(name) + print(f"Configs: baseline={results['baseline_label']} candidate={results['candidate_label']}") + print("Accuracy") + print( + tabulate( + [ + [ + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_dense_max_abs']:.6f}", + f"{row['candidate_dense_max_abs']:.6f}", + f"{row['baseline_candidate_max_abs']:.6f}", + f"{row['baseline_candidate_mean_abs']:.6f}", + ] + for row in results["rows"] + ], + headers=[ + "case", + "batch x seq", + "matmul", + "baseline vs dense max_abs", + "candidate vs dense max_abs", + "baseline vs candidate max_abs", + "baseline vs candidate mean_abs", + ], + tablefmt="plain", + ) + ) + print() + print("Benchmark") + print( + tabulate( + [ + [ + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_ms']:.3f}", + f"{row['candidate_ms']:.3f}", + _format_speedup(row["speedup"]), + row["winner"], + ] + for row in results["rows"] + ], + headers=[ + "case", + "batch x seq", + "matmul", + "baseline ms", + "candidate ms", + "speedup", + "winner", + ], + tablefmt="plain", + ) + ) + print() + print( + "Summary: " + f"candidate_wins={results['candidate_wins']}/{results['case_count']}, " + f"geo_mean_speedup={results['geo_mean_speedup']:.3f}x" + ) + print() + + +def _configure_awq_runtime(args: argparse.Namespace) -> None: + if args.force_rebuild_awq: + build_root = Path("/tmp") / ( + f"awq_jit_bench_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_AWQ_BUILD_ROOT"] = str(build_root) + os.environ["GPTQMODEL_AWQ_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_AWQ_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_AWQ_FORCE_REBUILD", None) + + from gptqmodel.utils.awq import awq_runtime_error, clear_awq_extension_cache, prewarm_awq_extension + + if args.force_rebuild_awq: + clear_awq_extension_cache() + + if not prewarm_awq_extension(): + raise RuntimeError(f"Failed to build/load the AWQ CUDA runtime: {awq_runtime_error()}") + + +def _configure_paroquant_runtime(args: argparse.Namespace, device: torch.device) -> None: + if args.force_rebuild_paroquant: + build_root = Path("/tmp") / ( + f"paroquant_ext_awqbench_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_PAROQUANT_BUILD_ROOT"] = str(build_root) + os.environ["GPTQMODEL_PAROQUANT_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_PAROQUANT_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_PAROQUANT_FORCE_REBUILD", None) + + from gptqmodel.utils.paroquant import ( + clear_paroquant_rotation_extension_cache, + prewarm_paroquant_rotation_extension, + ) + + if args.force_rebuild_paroquant: + clear_paroquant_rotation_extension_cache() + + if not prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=device, + ): + raise RuntimeError("Failed to build/load the fused ParoQuant CUDA rotation extension.") + + +def main() -> int: + parser = argparse.ArgumentParser( + description="A/B benchmark AWQ and ParoQuant CUDA GEMM configs with dense-reference accuracy reporting." + ) + parser.add_argument("--device", type=int, default=0, help="CUDA device index within the current visible set.") + parser.add_argument("--dtype", choices=("fp16", "bf16"), default="fp16") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations per case.") + parser.add_argument("--iters", type=int, default=20, help="Measured iterations per case.") + parser.add_argument("--quick", action="store_true", help="Run a smaller subset of benchmark cases.") + parser.add_argument("--shard-index", type=int, default=0) + parser.add_argument("--num-shards", type=int, default=1) + parser.add_argument("--baseline-path", choices=("fused", "dequant"), default="fused") + parser.add_argument("--candidate-path", choices=("fused", "dequant"), default="fused") + parser.add_argument("--baseline-mode", choices=("legacy", "fp32_accum"), default="legacy") + parser.add_argument("--candidate-mode", choices=("legacy", "fp32_accum"), default="fp32_accum") + parser.add_argument("--baseline-split-k", type=int, default=8) + parser.add_argument("--candidate-split-k", type=int, default=8) + parser.add_argument("--json-out", type=Path, default=None) + parser.add_argument("--force-rebuild-awq", action="store_true") + parser.add_argument("--force-rebuild-paroquant", action="store_true") + parser.add_argument("--json", action="store_true", help="Also emit the full result payload as JSON.") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the AWQ CUDA benchmark.") + if args.baseline_split_k <= 0 or args.candidate_split_k <= 0: + raise ValueError("split_k_iters must be positive.") + + device = torch.device(f"cuda:{args.device}") + _configure_awq_runtime(args) + _configure_paroquant_runtime(args, device) + + results = run( + device=device, + dtype=_resolve_dtype(args.dtype), + warmup=args.warmup, + iters=args.iters, + quick=args.quick, + shard_index=args.shard_index, + num_shards=args.num_shards, + baseline_path=args.baseline_path, + candidate_path=args.candidate_path, + baseline_mode=args.baseline_mode, + candidate_mode=args.candidate_mode, + baseline_split_k=args.baseline_split_k, + candidate_split_k=args.candidate_split_k, + ) + + print(f"Device: {results['device']} ({results['cuda_device']}, dtype={results['dtype']})") + print() + _print_suite("AWQ", results["awq"]) + _print_suite("ParoQuant", results["paroquant"]) + + if args.json_out is not None: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + args.json_out.write_text(json.dumps(results, indent=2), encoding="utf-8") + + if args.json: + print(json.dumps(results, indent=2)) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_awq_fused_reduce_ab.py b/scripts/benchmark_awq_fused_reduce_ab.py new file mode 100644 index 000000000..0de2de2c5 --- /dev/null +++ b/scripts/benchmark_awq_fused_reduce_ab.py @@ -0,0 +1,490 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import math +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from tabulate import tabulate + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + batch: int + seq: int + in_features: int + out_features: int + group_size: int = 128 + krot: int = 8 + + +DEFAULT_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_q_proj", batch=4, seq=128, in_features=2048, out_features=2048), + BenchCase("prefill_k_proj", batch=1, seq=128, in_features=2048, out_features=512), + BenchCase("prefill_gate_proj", batch=1, seq=128, in_features=2048, out_features=8192), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + +QUICK_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + + +def _resolve_dtype(name: str) -> torch.dtype: + if name == "fp16": + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype: {name}") + + +def _subset_cases(cases: list[BenchCase], shard_index: int, num_shards: int) -> list[BenchCase]: + if num_shards <= 0: + raise ValueError("`num_shards` must be positive.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError(f"`shard_index` must be in [0, {num_shards - 1}].") + return [case for index, case in enumerate(cases) if index % num_shards == shard_index] + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + packed = torch.zeros((unpacked.shape[0], unpacked.shape[1] // pack_factor), dtype=torch.int32) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_quant_buffers(case: BenchCase, dtype: torch.dtype, bits: int = 4) -> dict[str, torch.Tensor]: + from gptqmodel.utils.paroquant import build_identity_rotation_buffers + + groups = case.in_features // case.group_size + int_weight = torch.randint(0, 2**bits, size=(case.in_features, case.out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, case.out_features), dtype=torch.int32) + scales = (torch.rand(groups, case.out_features, dtype=torch.float32) * 0.5) + 0.75 + scales = scales.to(dtype=dtype) + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=case.in_features, + group_size=case.group_size, + krot=case.krot, + dtype=dtype, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + + return { + "qweight": _pack_awq_tensor(int_weight, bits), + "qzeros": _pack_awq_tensor(zero_points, bits), + "scales": scales, + "pairs": pairs, + "theta": theta, + "channel_scales": channel_scales, + } + + +def _dense_reference( + x: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + bits: int, + group_size: int, +) -> torch.Tensor: + from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm + + dense_weight = dequantize_gemm( + qweight=qweight, + qzeros=qzeros, + scales=scales, + bits=bits, + group_size=group_size, + ).to(device=x.device, dtype=x.dtype) + return torch.matmul(x, dense_weight) + + +def _set_fused_reduce_disabled(disabled: bool) -> None: + if disabled: + os.environ["GPTQMODEL_AWQ_DISABLE_FUSED_SPLITK_REDUCE"] = "1" + else: + os.environ.pop("GPTQMODEL_AWQ_DISABLE_FUSED_SPLITK_REDUCE", None) + + +def _benchmark_ms(fn, device: torch.device, warmup: int, iters: int, fused_reduce_disabled: bool) -> float: + _set_fused_reduce_disabled(fused_reduce_disabled) + with torch.inference_mode(): + for _ in range(warmup): + fn() + torch.cuda.synchronize(device) + start = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize(device) + return (time.perf_counter() - start) * 1e3 / iters + + +def _format_speedup(speedup: float) -> str: + return f"{speedup:.3f}x" + + +def _label(disabled: bool) -> str: + return "fused_reduce_off" if disabled else "fused_reduce_on" + + +def _run_gemm( + x: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, +) -> torch.Tensor: + from gptqmodel.utils.awq import awq_gemm_forward + + return awq_gemm_forward( + x, + qweight, + scales, + qzeros, + split_k_iters, + fp32_accum=True, + ) + + +def _run_suite( + device: torch.device, + dtype: torch.dtype, + warmup: int, + iters: int, + quick: bool, + rotate_inputs: bool, + shard_index: int, + num_shards: int, + split_k_iters: int, + baseline_disable_fused_reduce: bool, + candidate_disable_fused_reduce: bool, +) -> dict[str, Any]: + from gptqmodel.utils.paroquant import apply_paroquant_rotation + + cases = _subset_cases(QUICK_CASES if quick else DEFAULT_CASES, shard_index=shard_index, num_shards=num_shards) + rows = [] + speedups = [] + candidate_wins = 0 + baseline_label = _label(baseline_disable_fused_reduce) + candidate_label = _label(candidate_disable_fused_reduce) + + for index, case in enumerate(cases): + torch.manual_seed(1000 + index) + buffers = _make_quant_buffers(case, dtype=dtype) + qweight = buffers["qweight"].to(device) + qzeros = buffers["qzeros"].to(device) + scales = buffers["scales"].to(device) + x = torch.randn((case.batch * case.seq, case.in_features), device=device, dtype=dtype) + + kernel_input = x + if rotate_inputs: + kernel_input = apply_paroquant_rotation( + x, + buffers["pairs"].to(device), + buffers["theta"].to(device), + scales=buffers["channel_scales"].to(device), + group_size=case.group_size, + ) + + with torch.inference_mode(): + dense = _dense_reference(kernel_input, qweight, qzeros, scales, bits=4, group_size=case.group_size) + _set_fused_reduce_disabled(baseline_disable_fused_reduce) + baseline = _run_gemm(kernel_input, qweight, scales, qzeros, split_k_iters) + _set_fused_reduce_disabled(candidate_disable_fused_reduce) + candidate = _run_gemm(kernel_input, qweight, scales, qzeros, split_k_iters) + + baseline_dense = (baseline - dense).abs() + candidate_dense = (candidate - dense).abs() + baseline_candidate = (baseline - candidate).abs() + + baseline_ms = _benchmark_ms( + lambda: _run_gemm(kernel_input, qweight, scales, qzeros, split_k_iters), + device=device, + warmup=warmup, + iters=iters, + fused_reduce_disabled=baseline_disable_fused_reduce, + ) + candidate_ms = _benchmark_ms( + lambda: _run_gemm(kernel_input, qweight, scales, qzeros, split_k_iters), + device=device, + warmup=warmup, + iters=iters, + fused_reduce_disabled=candidate_disable_fused_reduce, + ) + speedup = baseline_ms / candidate_ms + speedups.append(speedup) + winner = candidate_label if candidate_ms < baseline_ms else baseline_label + candidate_wins += int(candidate_ms < baseline_ms) + + rows.append( + { + "case_id": case.case_id, + "batch": case.batch, + "seq": case.seq, + "in_features": case.in_features, + "out_features": case.out_features, + "dtype": str(dtype).replace("torch.", ""), + "baseline_ms": baseline_ms, + "candidate_ms": candidate_ms, + "speedup": speedup, + "winner": winner, + "baseline_dense_max_abs": baseline_dense.max().item(), + "candidate_dense_max_abs": candidate_dense.max().item(), + "baseline_candidate_max_abs": baseline_candidate.max().item(), + "baseline_candidate_mean_abs": baseline_candidate.mean().item(), + } + ) + + geo_mean_speedup = math.exp(sum(math.log(v) for v in speedups) / len(speedups)) if speedups else float("nan") + return { + "baseline_label": baseline_label, + "candidate_label": candidate_label, + "rows": rows, + "geo_mean_speedup": geo_mean_speedup, + "candidate_wins": candidate_wins, + "case_count": len(speedups), + } + + +def run( + device: torch.device, + dtype: torch.dtype, + warmup: int, + iters: int, + quick: bool, + shard_index: int, + num_shards: int, + split_k_iters: int, + baseline_disable_fused_reduce: bool, + candidate_disable_fused_reduce: bool, +) -> dict[str, Any]: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + return { + "device": torch.cuda.get_device_name(device), + "cuda_device": str(device), + "dtype": str(dtype).replace("torch.", ""), + "warmup": warmup, + "iters": iters, + "quick": quick, + "shard_index": shard_index, + "num_shards": num_shards, + "split_k_iters": split_k_iters, + "awq": _run_suite( + device=device, + dtype=dtype, + warmup=warmup, + iters=iters, + quick=quick, + rotate_inputs=False, + shard_index=shard_index, + num_shards=num_shards, + split_k_iters=split_k_iters, + baseline_disable_fused_reduce=baseline_disable_fused_reduce, + candidate_disable_fused_reduce=candidate_disable_fused_reduce, + ), + "paroquant": _run_suite( + device=device, + dtype=dtype, + warmup=warmup, + iters=iters, + quick=quick, + rotate_inputs=True, + shard_index=shard_index, + num_shards=num_shards, + split_k_iters=split_k_iters, + baseline_disable_fused_reduce=baseline_disable_fused_reduce, + candidate_disable_fused_reduce=candidate_disable_fused_reduce, + ), + } + + +def _print_suite(name: str, results: dict[str, Any]) -> None: + print(name) + print(f"Configs: baseline={results['baseline_label']} candidate={results['candidate_label']}") + print("Accuracy") + print( + tabulate( + [ + [ + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_dense_max_abs']:.6f}", + f"{row['candidate_dense_max_abs']:.6f}", + f"{row['baseline_candidate_max_abs']:.6f}", + f"{row['baseline_candidate_mean_abs']:.6f}", + ] + for row in results["rows"] + ], + headers=[ + "case", + "batch x seq", + "shape", + "baseline vs dense max_abs", + "candidate vs dense max_abs", + "baseline vs candidate max_abs", + "baseline vs candidate mean_abs", + ], + tablefmt="plain", + ) + ) + print() + print("Benchmark") + print( + tabulate( + [ + [ + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_ms']:.3f}", + f"{row['candidate_ms']:.3f}", + _format_speedup(row["speedup"]), + row["winner"], + ] + for row in results["rows"] + ], + headers=[ + "case", + "batch x seq", + "shape", + "baseline ms", + "candidate ms", + "speedup", + "winner", + ], + tablefmt="plain", + ) + ) + print() + print( + "Summary: " + f"candidate_wins={results['candidate_wins']}/{results['case_count']}, " + f"geo_mean_speedup={results['geo_mean_speedup']:.3f}x" + ) + print() + + +def _configure_awq_runtime(args: argparse.Namespace) -> None: + if args.force_rebuild_awq: + build_root = Path("/tmp") / ( + f"awq_jit_fusedreduce_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_AWQ_BUILD_ROOT"] = str(build_root) + os.environ["GPTQMODEL_AWQ_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_AWQ_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_AWQ_FORCE_REBUILD", None) + + from gptqmodel.utils.awq import awq_runtime_error, clear_awq_extension_cache, prewarm_awq_extension + + if args.force_rebuild_awq: + clear_awq_extension_cache() + + if not prewarm_awq_extension(): + raise RuntimeError(f"Failed to build/load the AWQ CUDA runtime: {awq_runtime_error()}") + + +def _configure_paroquant_runtime(args: argparse.Namespace, device: torch.device) -> None: + if args.force_rebuild_paroquant: + build_root = Path("/tmp") / ( + f"paroquant_ext_fusedreduce_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_PAROQUANT_BUILD_ROOT"] = str(build_root) + os.environ["GPTQMODEL_PAROQUANT_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_PAROQUANT_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_PAROQUANT_FORCE_REBUILD", None) + + from gptqmodel.utils.paroquant import clear_paroquant_rotation_extension_cache, prewarm_paroquant_rotation_extension + + if args.force_rebuild_paroquant: + clear_paroquant_rotation_extension_cache() + + if not prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=device, + ): + raise RuntimeError("Failed to build/load the fused ParoQuant CUDA rotation extension.") + + +def main() -> int: + parser = argparse.ArgumentParser(description="A/B benchmark AWQ fused split-K reduction in AWQ and ParoQuant paths.") + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--dtype", choices=("fp16", "bf16"), default="fp16") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--quick", action="store_true") + parser.add_argument("--shard-index", type=int, default=0) + parser.add_argument("--num-shards", type=int, default=1) + parser.add_argument("--split-k-iters", type=int, default=4) + parser.add_argument("--baseline-disable-fused-reduce", action="store_true") + parser.add_argument("--candidate-disable-fused-reduce", action="store_true") + parser.add_argument("--json-out", type=Path, default=None) + parser.add_argument("--json", action="store_true") + parser.add_argument("--force-rebuild-awq", action="store_true") + parser.add_argument("--force-rebuild-paroquant", action="store_true") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the AWQ fused-reduce benchmark.") + + device = torch.device(f"cuda:{args.device}") + _configure_awq_runtime(args) + _configure_paroquant_runtime(args, device) + results = run( + device=device, + dtype=_resolve_dtype(args.dtype), + warmup=args.warmup, + iters=args.iters, + quick=args.quick, + shard_index=args.shard_index, + num_shards=args.num_shards, + split_k_iters=args.split_k_iters, + baseline_disable_fused_reduce=args.baseline_disable_fused_reduce, + candidate_disable_fused_reduce=args.candidate_disable_fused_reduce, + ) + + print(f"Device: {results['device']} ({results['cuda_device']}, dtype={results['dtype']})") + print() + _print_suite("AWQ", results["awq"]) + _print_suite("ParoQuant", results["paroquant"]) + + if args.json_out is not None: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + args.json_out.write_text(json.dumps(results, indent=2), encoding="utf-8") + + if args.json: + print(json.dumps(results, indent=2)) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_awq_triton_fp32_ab.py b/scripts/benchmark_awq_triton_fp32_ab.py new file mode 100644 index 000000000..636801698 --- /dev/null +++ b/scripts/benchmark_awq_triton_fp32_ab.py @@ -0,0 +1,278 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import math +import time +from dataclasses import dataclass +from typing import Any + +import torch +from tabulate import tabulate + +from gptqmodel.quantization.awq.modules.triton.gemm import awq_gemm_triton +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + batch: int + seq: int + in_features: int + out_features: int + group_size: int = 128 + + +DEFAULT_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_q_proj", batch=4, seq=128, in_features=2048, out_features=2048), + BenchCase("prefill_k_proj", batch=1, seq=128, in_features=2048, out_features=512), + BenchCase("prefill_gate_proj", batch=1, seq=128, in_features=2048, out_features=8192), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + +QUICK_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_quant_buffers(case: BenchCase, bits: int = 4) -> dict[str, torch.Tensor]: + groups = case.in_features // case.group_size + int_weight = torch.randint(0, 2**bits, size=(case.in_features, case.out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, case.out_features), dtype=torch.int32) + scales = (torch.rand(groups, case.out_features, dtype=torch.float16) * 0.5) + 0.75 + + return { + "qweight": _pack_awq_tensor(int_weight, bits), + "qzeros": _pack_awq_tensor(zero_points, bits), + "scales": scales, + } + + +def _dense_reference(x: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, bits: int, group_size: int): + dense_weight = dequantize_gemm( + qweight=qweight, + qzeros=qzeros, + scales=scales, + bits=bits, + group_size=group_size, + ).to(device=x.device, dtype=x.dtype) + return torch.matmul(x, dense_weight) + + +def _benchmark_ms(fn, device: torch.device, warmup: int, iters: int) -> float: + with torch.inference_mode(): + for _ in range(warmup): + fn() + torch.cuda.synchronize(device) + start = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize(device) + return (time.perf_counter() - start) * 1e3 / iters + + +def _format_speedup(speedup: float) -> str: + return f"{speedup:.3f}x" + + +def run(device: torch.device, warmup: int, iters: int, quick: bool) -> dict[str, Any]: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + cases = QUICK_CASES if quick else DEFAULT_CASES + accuracy_rows = [] + benchmark_rows = [] + speedups = [] + + for index, case in enumerate(cases): + torch.manual_seed(1000 + index) + buffers = _make_quant_buffers(case) + qweight = buffers["qweight"].to(device) + qzeros = buffers["qzeros"].to(device) + scales = buffers["scales"].to(device) + x = torch.randn((case.batch * case.seq, case.in_features), device=device, dtype=torch.float16) + + with torch.inference_mode(): + dense = _dense_reference(x, qweight, qzeros, scales, bits=4, group_size=case.group_size) + legacy = awq_gemm_triton( + x, + qweight, + scales, + qzeros, + split_k_iters=8, + fp32_accum=False, + output_dtype=x.dtype, + ) + candidate = awq_gemm_triton( + x, + qweight, + scales, + qzeros, + split_k_iters=8, + fp32_accum=True, + output_dtype=x.dtype, + ) + + legacy_dense = (legacy - dense).abs() + candidate_dense = (candidate - dense).abs() + legacy_candidate = (legacy - candidate).abs() + + accuracy_rows.append( + [ + case.case_id, + f"{case.batch}x{case.seq}", + f"{case.in_features}->{case.out_features}", + f"{legacy_dense.max().item():.6f}", + f"{candidate_dense.max().item():.6f}", + f"{legacy_candidate.max().item():.6f}", + f"{legacy_candidate.mean().item():.6f}", + ] + ) + + legacy_ms = _benchmark_ms( + lambda: awq_gemm_triton( + x, + qweight, + scales, + qzeros, + split_k_iters=8, + fp32_accum=False, + output_dtype=x.dtype, + ), + device=device, + warmup=warmup, + iters=iters, + ) + candidate_ms = _benchmark_ms( + lambda: awq_gemm_triton( + x, + qweight, + scales, + qzeros, + split_k_iters=8, + fp32_accum=True, + output_dtype=x.dtype, + ), + device=device, + warmup=warmup, + iters=iters, + ) + speedup = legacy_ms / candidate_ms + speedups.append(speedup) + winner = "fp32" if candidate_ms < legacy_ms else "legacy" + + benchmark_rows.append( + [ + case.case_id, + f"{case.batch}x{case.seq}", + f"{case.in_features}->{case.out_features}", + f"{legacy_ms:.3f}", + f"{candidate_ms:.3f}", + _format_speedup(speedup), + winner, + ] + ) + + geo_mean_speedup = math.exp(sum(math.log(v) for v in speedups) / len(speedups)) + fp32_wins = sum(1 for value in speedups if value > 1.0) + + return { + "device": torch.cuda.get_device_name(device), + "cuda_device": str(device), + "warmup": warmup, + "iters": iters, + "quick": quick, + "accuracy_headers": [ + "case", + "batch x seq", + "matmul", + "legacy vs dense max_abs", + "fp32 vs dense max_abs", + "legacy vs fp32 max_abs", + "legacy vs fp32 mean_abs", + ], + "accuracy_rows": accuracy_rows, + "benchmark_headers": [ + "case", + "batch x seq", + "matmul", + "legacy ms", + "fp32 ms", + "speedup", + "winner", + ], + "benchmark_rows": benchmark_rows, + "geo_mean_speedup": geo_mean_speedup, + "fp32_wins": fp32_wins, + "case_count": len(speedups), + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="A/B benchmark legacy AWQ Triton accumulation against fp32 accumulation.") + parser.add_argument("--device", type=int, default=0, help="CUDA device index within the current visible set.") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations per case.") + parser.add_argument("--iters", type=int, default=20, help="Measured iterations per case.") + parser.add_argument("--quick", action="store_true", help="Run a smaller subset of benchmark cases.") + parser.add_argument("--json", action="store_true", help="Also emit the full result payload as JSON.") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the AWQ Triton benchmark.") + + try: + import triton # noqa: F401 + except Exception as exc: # pragma: no cover - environment dependent + raise RuntimeError(f"Triton is required for the AWQ Triton benchmark: {exc}") from exc + + device = torch.device(f"cuda:{args.device}") + results = run(device=device, warmup=args.warmup, iters=args.iters, quick=args.quick) + + print(f"Device: {results['device']} ({results['cuda_device']})") + print() + print("Accuracy") + print(tabulate(results["accuracy_rows"], headers=results["accuracy_headers"], tablefmt="grid")) + print() + print("Benchmark") + print(tabulate(results["benchmark_rows"], headers=results["benchmark_headers"], tablefmt="grid")) + print() + print( + "Summary: " + f"fp32_wins={results['fp32_wins']}/{results['case_count']}, " + f"geo_mean_speedup={results['geo_mean_speedup']:.3f}x" + ) + + if args.json: + print() + print(json.dumps(results, indent=2)) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_gguf_autotune_ab.py b/scripts/benchmark_gguf_autotune_ab.py new file mode 100644 index 000000000..f5204083a --- /dev/null +++ b/scripts/benchmark_gguf_autotune_ab.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import time +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear + + +@dataclass(frozen=True) +class BenchCase: + name: str + bits: str + in_features: int + out_features: int + rows: int + group_size: int = -1 + + +def _ascii_table(headers: list[str], rows: list[list[str]]) -> str: + widths = [len(h) for h in headers] + for row in rows: + for i, cell in enumerate(row): + widths[i] = max(widths[i], len(cell)) + + def fmt(row: list[str]) -> str: + return "| " + " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)) + " |" + + sep = "+-" + "-+-".join("-" * width for width in widths) + "-+" + out = [sep, fmt(headers), sep] + for row in rows: + out.append(fmt(row)) + out.append(sep) + return "\n".join(out) + + +def _build_module( + case: BenchCase, + *, + dtype: torch.dtype, + device: str, + autotune: bool, + force_candidate: bool, +) -> GGUFTorchLinear: + linear = nn.Linear(case.in_features, case.out_features, bias=False, dtype=dtype).cpu().eval() + torch.manual_seed(0) + with torch.no_grad(): + linear.weight.normal_(mean=0.0, std=0.02) + + module = GGUFTorchLinear( + bits=case.bits, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=False, + register_buffers=False, + ) + module.pack_original(linear, scales=torch.empty(0), zeros=torch.empty(0), g_idx=None) + module.post_init() + if force_candidate: + module.gguf_fused_cuda_max_rows = max(case.rows, 1) + module.gguf_fused_cuda_min_matrix_elements = 0 + module.gguf_fused_cpu_max_rows = max(case.rows, 1) + module.gguf_fused_cpu_min_matrix_elements = 0 + module.autotune_enabled = autotune + module.clear_autotune() + module = module.to(device).eval() + return module + + +def _bench_once(fn, *, sync_cuda: bool) -> float: + if sync_cuda: + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + if sync_cuda: + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1000.0 + + +def _plan_label(module: GGUFTorchLinear, x: torch.Tensor) -> str: + if not module.autotune_enabled: + return "fused" if module._is_fused_k_forward_candidate(x) else "none" + decision = module.get_autotune_result() + if decision is None: + return "none" + return "fused" if decision else "dense" + + +def _run_case( + case: BenchCase, + *, + dtype: torch.dtype, + device: str, + trials: int, + warmup: int, + force_candidate: bool, +) -> None: + sync_cuda = device == "cuda" + x = torch.randn(case.rows, case.in_features, device=device, dtype=dtype) + + static_module = _build_module(case, dtype=dtype, device=device, autotune=False, force_candidate=force_candidate) + autotune_module = _build_module(case, dtype=dtype, device=device, autotune=True, force_candidate=force_candidate) + + for _ in range(warmup): + static_module(x) + autotune_module(x) + + # Untimed warmup to settle dispatch decisions before measurement. + static_module(x) + autotune_module(x) + + static_plan = _plan_label(static_module, x) + autotune_plan = _plan_label(autotune_module, x) + + static_trials: list[float] = [] + autotune_trials: list[float] = [] + for _ in range(trials): + static_trials.append(_bench_once(lambda: static_module(x), sync_cuda=sync_cuda)) + autotune_trials.append(_bench_once(lambda: autotune_module(x), sync_cuda=sync_cuda)) + + static_out = static_module(x) + autotune_out = autotune_module(x) + diff = (static_out.to(torch.float32) - autotune_out.to(torch.float32)).abs() + mae = diff.mean().item() + max_abs = diff.max().item() + + trial_rows: list[list[str]] = [] + for idx, (static_ms, autotune_ms) in enumerate(zip(static_trials, autotune_trials), start=1): + speedup = static_ms / autotune_ms if autotune_ms > 0 else float("inf") + delta_pct = ((static_ms - autotune_ms) / static_ms * 100.0) if static_ms > 0 else 0.0 + trial_rows.append( + [ + str(idx), + f"{static_ms:.3f}", + f"{autotune_ms:.3f}", + f"{speedup:.2f}x", + f"{delta_pct:.1f}%", + ] + ) + + static_mean = sum(static_trials) / len(static_trials) + autotune_mean = sum(autotune_trials) / len(autotune_trials) + summary_rows = [ + [ + "static", + static_plan, + f"{static_mean:.3f}", + f"{min(static_trials):.3f}", + f"{max(static_trials):.3f}", + "-", + ], + [ + "autotune", + autotune_plan, + f"{autotune_mean:.3f}", + f"{min(autotune_trials):.3f}", + f"{max(autotune_trials):.3f}", + f"{static_mean / autotune_mean:.2f}x", + ], + ] + + print() + print( + f"CASE {case.name} device={device} bits={case.bits} rows={case.rows} " + f"shape={case.out_features}x{case.in_features} dtype={str(dtype).removeprefix('torch.')}" + ) + print(_ascii_table(["trial", "static_ms", "autotune_ms", "speedup", "delta_pct"], trial_rows)) + print(_ascii_table(["mode", "plan", "mean_ms", "min_ms", "max_ms", "speedup_vs_static"], summary_rows)) + print(f"correctness: mae={mae:.6f} max_abs={max_abs:.6f}") + + +def _parse_case(spec: str) -> BenchCase: + parts = spec.split(":") + if len(parts) not in (5, 6): + raise ValueError( + f"Invalid case `{spec}`. Expected name:bits:in_features:out_features:rows[:group_size]." + ) + name, bits, in_features, out_features, rows, *rest = parts + return BenchCase( + name=name, + bits=bits, + in_features=int(in_features), + out_features=int(out_features), + rows=int(rows), + group_size=int(rest[0]) if rest else -1, + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="A/B benchmark GGUF static dispatch vs autotuned dispatch, excluding autotune setup cost." + ) + parser.add_argument( + "--case", + action="append", + dest="cases", + default=[], + help="Benchmark case as name:bits:in_features:out_features:rows[:group_size].", + ) + parser.add_argument("--trials", type=int, default=5, help="Measured trials per case.") + parser.add_argument("--warmup", type=int, default=2, help="Untimed warmup forwards before the measured warmup call.") + parser.add_argument("--dtype", choices=("fp16", "bf16", "fp32"), default="fp16", help="Benchmark dtype.") + parser.add_argument( + "--force-candidate", + action="store_true", + help="Override thresholds so every case is eligible for fused-vs-dense dispatch tuning.", + ) + parser.add_argument( + "--device", + choices=("auto", "cpu", "cuda", "both"), + default="auto", + help="Benchmark device selection.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + if args.device == "auto": + devices = ["cuda"] if torch.cuda.is_available() else ["cpu"] + elif args.device == "both": + devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + else: + devices = [args.device] + + if "cuda" in devices and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but no CUDA device is available.") + if "cpu" in devices and dtype == torch.float16: + raise RuntimeError("CPU benchmarks should use --dtype bf16 or --dtype fp32.") + + cases = ( + [_parse_case(spec) for spec in args.cases] + if args.cases + else [ + BenchCase("attn_q4_k_m_r1", "q4_k_m", 2048, 2048, 1), + BenchCase("attn_q5_k_m_r1", "q5_k_m", 2048, 2048, 1), + BenchCase("attn_q6_k_r1", "q6_k", 2048, 2048, 1), + BenchCase("mlp_q4_k_m_r8", "q4_k_m", 2048, 8192, 8), + BenchCase("mlp_q5_k_m_r8", "q5_k_m", 2048, 8192, 8), + BenchCase("mlp_q6_k_r8", "q6_k", 2048, 8192, 8), + ] + ) + + print( + f"devices={','.join(devices)} dtype={str(dtype).removeprefix('torch.')} " + f"trials={args.trials} warmup={args.warmup} force_candidate={args.force_candidate}" + ) + for device in devices: + for case in cases: + _run_case( + case, + dtype=dtype, + device=device, + trials=args.trials, + warmup=args.warmup, + force_candidate=args.force_candidate, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_gguf_cpp_vs_torch.py b/scripts/benchmark_gguf_cpp_vs_torch.py new file mode 100644 index 000000000..c7003266d --- /dev/null +++ b/scripts/benchmark_gguf_cpp_vs_torch.py @@ -0,0 +1,340 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import time +from dataclasses import dataclass +from typing import Callable + +import torch +import torch.nn as nn + +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.nn_modules.qlinear.gguf_cpp import GGUFCppKernel, GGUFCudaKernel +from gptqmodel.nn_modules.qlinear.gguf_triton import GGUFTritonKernel, triton_available as gguf_triton_available + + +@dataclass(frozen=True) +class BenchCase: + name: str + bits: str + in_features: int + out_features: int + rows: int + bias: bool = False + + +def _ascii_table(headers: list[str], rows: list[list[str]]) -> str: + widths = [len(h) for h in headers] + for row in rows: + for i, cell in enumerate(row): + widths[i] = max(widths[i], len(cell)) + + def fmt(row: list[str]) -> str: + return "| " + " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)) + " |" + + sep = "+-" + "-+-".join("-" * width for width in widths) + "-+" + out = [sep, fmt(headers), sep] + for row in rows: + out.append(fmt(row)) + out.append(sep) + return "\n".join(out) + + +def _parse_case(spec: str) -> BenchCase: + parts = spec.split(":") + if len(parts) not in (5, 6): + raise ValueError(f"Invalid case `{spec}`. Expected name:bits:in_features:out_features:rows[:bias].") + name, bits, in_features, out_features, rows, *rest = parts + return BenchCase( + name=name, + bits=bits, + in_features=int(in_features), + out_features=int(out_features), + rows=int(rows), + bias=bool(int(rest[0])) if rest else False, + ) + + +def _sync(device: str) -> None: + if device == "cuda": + torch.cuda.synchronize() + + +def _bench(fn: Callable[[], torch.Tensor], *, device: str, warmup: int, trials: int) -> tuple[list[float], torch.Tensor]: + last = None + for _ in range(warmup): + last = fn() + _sync(device) + + samples = [] + for _ in range(trials): + _sync(device) + t0 = time.perf_counter() + last = fn() + _sync(device) + samples.append((time.perf_counter() - t0) * 1000.0) + assert last is not None + return samples, last + + +def _supports_triton(case: BenchCase) -> bool: + return gguf_triton_available() and case.bits in {"q4_k_s", "q4_k_m", "q5_k_s", "q5_k_m", "q6_k"} + + +def _build_modules( + case: BenchCase, + *, + dtype: torch.dtype, +) -> tuple[GGUFTorchLinear, GGUFCppKernel, GGUFCudaKernel, GGUFTritonKernel | None]: + torch.manual_seed(0) + linear = nn.Linear(case.in_features, case.out_features, bias=case.bias, dtype=torch.float16).cpu().eval() + with torch.no_grad(): + linear.weight.normal_(mean=0.0, std=0.02) + if linear.bias is not None: + linear.bias.normal_(mean=0.0, std=0.01) + + torch_kernel = GGUFTorchLinear( + bits=case.bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=case.bias, + register_buffers=True, + ).eval() + torch_kernel.pack_original(linear, scales=None, zeros=None) + + cpu_kernel = GGUFCppKernel( + bits=case.bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=case.bias, + register_buffers=True, + ).eval() + cpu_kernel.load_state_dict(torch_kernel.state_dict(), strict=True) + + cuda_kernel = GGUFCudaKernel( + bits=case.bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=case.bias, + register_buffers=True, + ).eval() + cuda_kernel.load_state_dict(torch_kernel.state_dict(), strict=True) + + triton_kernel = None + if _supports_triton(case): + triton_kernel = GGUFTritonKernel( + bits=case.bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=case.bias, + register_buffers=True, + ).eval() + triton_kernel.load_state_dict(torch_kernel.state_dict(), strict=True) + + return torch_kernel, cpu_kernel, cuda_kernel, triton_kernel + + +def _mean(values: list[float]) -> float: + return sum(values) / len(values) + + +def _run_cpu(case: BenchCase, *, dtype: torch.dtype, warmup: int, trials: int) -> tuple[list[list[str]], list[str]]: + torch_kernel, cpu_kernel, _, _ = _build_modules(case, dtype=dtype) + x = torch.randn(case.rows, case.in_features, dtype=dtype, device="cpu") + + torch_trials, torch_out = _bench(lambda: torch_kernel(x), device="cpu", warmup=warmup, trials=trials) + cpu_trials, cpu_out = _bench(lambda: cpu_kernel(x), device="cpu", warmup=warmup, trials=trials) + + diff = (torch_out.float() - cpu_out.float()).abs() + trial_rows = [] + for idx, (torch_ms, cpp_ms) in enumerate(zip(torch_trials, cpu_trials), start=1): + speedup = torch_ms / cpp_ms if cpp_ms > 0 else float("inf") + trial_rows.append([str(idx), f"{torch_ms:.3f}", f"{cpp_ms:.3f}", f"{speedup:.2f}x"]) + + summary = [ + case.name, + case.bits, + f"{case.rows}x{case.in_features}", + f"{case.out_features}x{case.in_features}", + f"{_mean(torch_trials):.3f}", + f"{_mean(cpu_trials):.3f}", + "n/a", + f"{(_mean(torch_trials) / _mean(cpu_trials)):.2f}x", + "n/a", + f"{diff.mean().item():.6f}", + f"{diff.max().item():.6f}", + "n/a", + "n/a", + ] + return trial_rows, summary + + +def _run_cuda(case: BenchCase, *, dtype: torch.dtype, warmup: int, trials: int) -> tuple[list[list[str]], list[str]]: + torch_kernel, _, cuda_kernel, triton_kernel = _build_modules(case, dtype=dtype) + torch_kernel = torch_kernel.to("cuda") + cuda_kernel = cuda_kernel.to("cuda") + if triton_kernel is not None: + triton_kernel = triton_kernel.to("cuda") + x = torch.randn(case.rows, case.in_features, dtype=dtype, device="cuda") + + torch_trials, torch_out = _bench(lambda: torch_kernel(x), device="cuda", warmup=warmup, trials=trials) + cuda_trials, cuda_out = _bench(lambda: cuda_kernel(x), device="cuda", warmup=warmup, trials=trials) + triton_trials = None + triton_out = None + if triton_kernel is not None: + triton_trials, triton_out = _bench(lambda: triton_kernel(x), device="cuda", warmup=warmup, trials=trials) + + diff = (torch_out.float() - cuda_out.float()).abs() + triton_diff = None if triton_out is None else (torch_out.float() - triton_out.float()).abs() + trial_rows = [] + for idx, (torch_ms, cpp_ms) in enumerate(zip(torch_trials, cuda_trials), start=1): + cpp_speedup = torch_ms / cpp_ms if cpp_ms > 0 else float("inf") + if triton_trials is None: + triton_ms = "n/a" + triton_speedup = "n/a" + else: + trial_triton_ms = triton_trials[idx - 1] + triton_ms = f"{trial_triton_ms:.3f}" + triton_speedup = f"{(torch_ms / trial_triton_ms):.2f}x" if trial_triton_ms > 0 else "inf" + trial_rows.append([str(idx), f"{torch_ms:.3f}", f"{cpp_ms:.3f}", triton_ms, f"{cpp_speedup:.2f}x", triton_speedup]) + + summary = [ + case.name, + case.bits, + f"{case.rows}x{case.in_features}", + f"{case.out_features}x{case.in_features}", + f"{_mean(torch_trials):.3f}", + f"{_mean(cuda_trials):.3f}", + "n/a" if triton_trials is None else f"{_mean(triton_trials):.3f}", + f"{(_mean(torch_trials) / _mean(cuda_trials)):.2f}x", + "n/a" if triton_trials is None else f"{(_mean(torch_trials) / _mean(triton_trials)):.2f}x", + f"{diff.mean().item():.6f}", + f"{diff.max().item():.6f}", + "n/a" if triton_diff is None else f"{triton_diff.mean().item():.6f}", + "n/a" if triton_diff is None else f"{triton_diff.max().item():.6f}", + ] + return trial_rows, summary + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark GGUF cpp kernels against GGUF torch.") + parser.add_argument( + "--case", + action="append", + dest="cases", + default=[], + help="Case as name:bits:in_features:out_features:rows[:bias].", + ) + parser.add_argument("--trials", type=int, default=5) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--dtype-cpu", choices=("fp32", "bf16"), default="fp32") + parser.add_argument("--dtype-cuda", choices=("fp16", "bf16", "fp32"), default="fp16") + parser.add_argument("--device", choices=("cpu", "cuda", "both"), default="both") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + cpu_dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}[args.dtype_cpu] + cuda_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}[args.dtype_cuda] + cases = ( + [_parse_case(spec) for spec in args.cases] + if args.cases + else [ + BenchCase("attn_q4_k_m_r1", "q4_k_m", 2048, 2048, 1), + BenchCase("attn_q4_k_m_r8", "q4_k_m", 2048, 2048, 8), + BenchCase("mlp_q4_k_m_r8", "q4_k_m", 2048, 8192, 8), + BenchCase("mlp_q5_k_m_r8", "q5_k_m", 2048, 8192, 8), + BenchCase("mlp_q6_k_r8", "q6_k", 2048, 8192, 8), + ] + ) + + print( + f"device={args.device} trials={args.trials} warmup={args.warmup} " + f"dtype_cpu={str(cpu_dtype).removeprefix('torch.')} dtype_cuda={str(cuda_dtype).removeprefix('torch.')}" + ) + + if args.device in {"cpu", "both"}: + cpu_summary = [] + print("\nCPU per-trial") + for case in cases: + trial_rows, summary = _run_cpu(case, dtype=cpu_dtype, warmup=args.warmup, trials=args.trials) + print(f"\nCASE {case.name} bits={case.bits} rows={case.rows}") + print(_ascii_table(["trial", "gguf_torch_ms", "gguf_cpp_cpu_ms", "speedup"], trial_rows)) + cpu_summary.append(summary) + print("\nCPU summary") + print( + _ascii_table( + [ + "case", + "bits", + "rowsxin", + "outxin", + "gguf_torch_ms", + "gguf_cpp_cpu_ms", + "gguf_triton_ms", + "speedup", + "triton_speedup", + "mae", + "max_abs", + "triton_mae", + "triton_max_abs", + ], + cpu_summary, + ) + ) + + if args.device in {"cuda", "both"}: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA benchmark requested but torch.cuda.is_available() is False.") + cuda_summary = [] + print("\nCUDA per-trial") + for case in cases: + trial_rows, summary = _run_cuda(case, dtype=cuda_dtype, warmup=args.warmup, trials=args.trials) + print(f"\nCASE {case.name} bits={case.bits} rows={case.rows}") + print( + _ascii_table( + ["trial", "gguf_torch_ms", "gguf_cpp_cuda_ms", "gguf_triton_ms", "cpp_speedup", "triton_speedup"], + trial_rows, + ) + ) + cuda_summary.append(summary) + print("\nCUDA summary") + print( + _ascii_table( + [ + "case", + "bits", + "rowsxin", + "outxin", + "gguf_torch_ms", + "gguf_cpp_cuda_ms", + "gguf_triton_ms", + "cpp_speedup", + "triton_speedup", + "cpp_mae", + "cpp_max_abs", + "triton_mae", + "triton_max_abs", + ], + cuda_summary, + ) + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_gguf_dequant.py b/scripts/benchmark_gguf_dequant.py new file mode 100644 index 000000000..d21f0a4f3 --- /dev/null +++ b/scripts/benchmark_gguf_dequant.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import time +from dataclasses import dataclass + +import gguf +import torch +import torch.nn as nn + +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear + + +@dataclass(frozen=True) +class BenchCase: + name: str + bits: str + in_features: int + out_features: int + group_size: int + + +def _build_module(case: BenchCase, dtype: torch.dtype) -> GGUFTorchLinear: + linear = nn.Linear(case.in_features, case.out_features, bias=False, dtype=dtype).cpu().eval() + torch.manual_seed(0) + with torch.no_grad(): + linear.weight.normal_(mean=0.0, std=0.02) + + module = GGUFTorchLinear( + bits=case.bits, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=False, + register_buffers=False, + ) + module.pack_original(linear, scales=torch.empty(0), zeros=torch.empty(0), g_idx=None) + module.post_init() + return module.cpu().eval() + + +def _gguf_qtype(bits: str) -> gguf.GGMLQuantizationType: + mapping = { + "q4_0": gguf.GGMLQuantizationType.Q4_0, + "q8_0": gguf.GGMLQuantizationType.Q8_0, + "q4_k": gguf.GGMLQuantizationType.Q4_K, + "q4_k_s": gguf.GGMLQuantizationType.Q4_K, + "q4_k_m": gguf.GGMLQuantizationType.Q4_K, + "q5_k": gguf.GGMLQuantizationType.Q5_K, + "q5_k_s": gguf.GGMLQuantizationType.Q5_K, + "q5_k_m": gguf.GGMLQuantizationType.Q5_K, + "q6_k": gguf.GGMLQuantizationType.Q6_K, + } + return mapping[bits] + + +def _bench(fn, *, iters: int, warmup: int, sync_cuda: bool) -> tuple[float, float, float]: + for _ in range(warmup): + fn() + if sync_cuda and torch.cuda.is_available(): + torch.cuda.synchronize() + + samples_ms: list[float] = [] + for _ in range(iters): + t0 = time.perf_counter() + fn() + if sync_cuda and torch.cuda.is_available(): + torch.cuda.synchronize() + samples_ms.append((time.perf_counter() - t0) * 1000.0) + + return sum(samples_ms) / len(samples_ms), min(samples_ms), max(samples_ms) + + +def _print_row(label: str, mean_ms: float, min_ms: float, max_ms: float) -> None: + print(f"{label:28s} mean={mean_ms:8.3f} ms min={min_ms:8.3f} max={max_ms:8.3f}") + + +def run_case(case: BenchCase, *, dtype: torch.dtype, device: str, iters: int, warmup: int) -> None: + module_cpu = _build_module(case, dtype=dtype) + qweight_np = module_cpu.qweight.detach().cpu().numpy() + qtype = _gguf_qtype(case.bits) + + print() + print( + f"CASE {case.name} bits={case.bits} " + f"shape={case.out_features}x{case.in_features} group_size={case.group_size}" + ) + + mean_ms, min_ms, max_ms = _bench( + lambda: gguf.dequantize(qweight_np, qtype), + iters=iters, + warmup=warmup, + sync_cuda=False, + ) + _print_row("gguf.dequantize cpu", mean_ms, min_ms, max_ms) + + mean_ms, min_ms, max_ms = _bench( + lambda: module_cpu.dequantize_weight(device="cpu", dtype=torch.float32), + iters=iters, + warmup=warmup, + sync_cuda=False, + ) + _print_row("gptqmodel dequant cpu fp32", mean_ms, min_ms, max_ms) + + if device == "cuda": + module_gpu = module_cpu.to("cuda").eval() + x = torch.randn(1, case.in_features, device="cuda", dtype=dtype) + + mean_ms, min_ms, max_ms = _bench( + lambda: module_gpu.dequantize_weight(device="cuda", dtype=dtype), + iters=iters, + warmup=warmup, + sync_cuda=True, + ) + _print_row(f"gptqmodel dequant cuda {str(dtype).removeprefix('torch.')}", mean_ms, min_ms, max_ms) + + def cold_forward(): + module_gpu.clear_weight_cache() + module_gpu(x) + + def hot_forward(): + module_gpu(x) + + mean_ms, min_ms, max_ms = _bench( + cold_forward, + iters=iters, + warmup=warmup, + sync_cuda=True, + ) + _print_row("gptqmodel forward cold", mean_ms, min_ms, max_ms) + + hot_forward() + torch.cuda.synchronize() + mean_ms, min_ms, max_ms = _bench( + hot_forward, + iters=max(iters * 2, 10), + warmup=warmup, + sync_cuda=True, + ) + _print_row("gptqmodel forward hot", mean_ms, min_ms, max_ms) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Micro-benchmark GGUF dequantization and forward paths.") + parser.add_argument( + "--case", + action="append", + dest="cases", + default=[], + help="Benchmark case as name:bits:in_features:out_features[:group_size]. " + "Example: attn:q4_k_m:2048:2048:128", + ) + parser.add_argument("--iters", type=int, default=20, help="Measured iterations per benchmark.") + parser.add_argument("--warmup", type=int, default=3, help="Warmup iterations per benchmark.") + parser.add_argument( + "--dtype", + choices=("fp16", "bf16"), + default="fp16", + help="Target GPU forward/dequant dtype when CUDA is available.", + ) + parser.add_argument( + "--device", + choices=("auto", "cpu", "cuda"), + default="auto", + help="Benchmark target for dequant/forward. `auto` chooses CUDA when available.", + ) + return parser.parse_args() + + +def _parse_case(spec: str) -> BenchCase: + parts = spec.split(":") + if len(parts) not in (4, 5): + raise ValueError( + f"Invalid case `{spec}`. Expected name:bits:in_features:out_features[:group_size]." + ) + name, bits, in_features, out_features, *rest = parts + group_size = int(rest[0]) if rest else -1 + return BenchCase( + name=name, + bits=bits, + in_features=int(in_features), + out_features=int(out_features), + group_size=group_size, + ) + + +def main() -> None: + args = parse_args() + dtype = torch.float16 if args.dtype == "fp16" else torch.bfloat16 + if args.device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + elif args.device == "cuda" and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but no CUDA device is available.") + else: + device = args.device + + cases = ( + [_parse_case(spec) for spec in args.cases] + if args.cases + else [ + BenchCase("attn_q4_0", "q4_0", 2048, 2048, 128), + BenchCase("attn_q4_k_m", "q4_k_m", 2048, 2048, 128), + BenchCase("mlp_q4_0", "q4_0", 2048, 8192, 128), + BenchCase("mlp_q4_k_m", "q4_k_m", 2048, 8192, 128), + ] + ) + + print(f"device={device} dtype={str(dtype).removeprefix('torch.')} cuda_available={torch.cuda.is_available()}") + for case in cases: + run_case(case, dtype=dtype, device=device, iters=args.iters, warmup=args.warmup) + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_gguf_fused_ab.py b/scripts/benchmark_gguf_fused_ab.py new file mode 100644 index 000000000..f29e4a774 --- /dev/null +++ b/scripts/benchmark_gguf_fused_ab.py @@ -0,0 +1,302 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import time +from dataclasses import dataclass + +import torch +import torch.nn as nn + +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.nn_modules.qlinear.gguf_triton import GGUFTritonKernel, triton_available as gguf_triton_triton_available + + +@dataclass(frozen=True) +class BenchCase: + name: str + bits: str + in_features: int + out_features: int + rows: int + group_size: int = -1 + + +def _ascii_table(headers: list[str], rows: list[list[str]]) -> str: + widths = [len(h) for h in headers] + for row in rows: + for i, cell in enumerate(row): + widths[i] = max(widths[i], len(cell)) + + def fmt(row: list[str]) -> str: + return "| " + " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)) + " |" + + sep = "+-" + "-+-".join("-" * width for width in widths) + "-+" + out = [sep, fmt(headers), sep] + for row in rows: + out.append(fmt(row)) + out.append(sep) + return "\n".join(out) + + +def _build_module(case: BenchCase, *, dtype: torch.dtype) -> GGUFTorchLinear: + linear = nn.Linear(case.in_features, case.out_features, bias=False, dtype=dtype).cpu().eval() + torch.manual_seed(0) + with torch.no_grad(): + linear.weight.normal_(mean=0.0, std=0.02) + + module = GGUFTorchLinear( + bits=case.bits, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=False, + register_buffers=False, + ) + module.pack_original(linear, scales=torch.empty(0), zeros=torch.empty(0), g_idx=None) + module.post_init() + module.gguf_fused_cuda_max_rows = max(case.rows, 1) + module.gguf_fused_cuda_min_matrix_elements = 0 + module.gguf_fused_cpu_max_rows = max(case.rows, 1) + module.gguf_fused_cpu_min_matrix_elements = 0 + return module + + +def _bench_once(fn, *, sync_cuda: bool) -> float: + if sync_cuda: + torch.cuda.synchronize() + t0 = time.perf_counter() + fn() + if sync_cuda: + torch.cuda.synchronize() + return (time.perf_counter() - t0) * 1000.0 + + +def _run_case( + case: BenchCase, + *, + dtype: torch.dtype, + device: str, + trials: int, + warmup: int, + include_triton: bool, +) -> None: + module = _build_module(case, dtype=dtype) + module = module.to(device).eval() + triton_module = None + x = torch.randn(case.rows, case.in_features, device=device, dtype=dtype) + sync_cuda = device == "cuda" + can_bench_triton = ( + include_triton + and device == "cuda" + and dtype == torch.float16 + and module.gguf_tensor_qtype in {"Q4_K", "Q5_K", "Q6_K"} + and gguf_triton_triton_available() + ) + if can_bench_triton: + triton_module = GGUFTritonKernel( + bits=case.bits, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=False, + register_buffers=True, + ).to(device).eval() + triton_module.load_state_dict(module.state_dict(), strict=True) + + for _ in range(warmup): + module._forward_dequant_matmul(x) + module._forward_fused_k(x) + if can_bench_triton: + triton_module(x) + + baseline_trials: list[float] = [] + fused_trials: list[float] = [] + triton_trials: list[float] = [] + for _ in range(trials): + baseline_trials.append(_bench_once(lambda: module._forward_dequant_matmul(x), sync_cuda=sync_cuda)) + fused_trials.append(_bench_once(lambda: module._forward_fused_k(x), sync_cuda=sync_cuda)) + if can_bench_triton: + triton_trials.append(_bench_once(lambda: triton_module(x), sync_cuda=sync_cuda)) + + baseline_out = module._forward_dequant_matmul(x) + fused_out = module._forward_fused_k(x) + if can_bench_triton: + triton_out = triton_module(x) + triton_diff = (baseline_out.to(torch.float32) - triton_out.to(torch.float32)).abs() + triton_mae = triton_diff.mean().item() + triton_max_abs = triton_diff.max().item() + diff = (baseline_out.to(torch.float32) - fused_out.to(torch.float32)).abs() + mae = diff.mean().item() + max_abs = diff.max().item() + + if can_bench_triton: + trial_rows = [] + for idx, (baseline_ms, fused_ms, triton_ms) in enumerate(zip(baseline_trials, fused_trials, triton_trials), start=1): + trial_rows.append( + [ + str(idx), + f"{baseline_ms:.3f}", + f"{fused_ms:.3f}", + f"{triton_ms:.3f}", + f"{baseline_ms / fused_ms:.2f}x" if fused_ms > 0 else "inf", + f"{baseline_ms / triton_ms:.2f}x" if triton_ms > 0 else "inf", + ] + ) + else: + trial_rows = [] + for idx, (baseline_ms, fused_ms) in enumerate(zip(baseline_trials, fused_trials), start=1): + speedup = baseline_ms / fused_ms if fused_ms > 0 else float("inf") + delta_pct = ((baseline_ms - fused_ms) / baseline_ms * 100.0) if baseline_ms > 0 else 0.0 + trial_rows.append( + [ + str(idx), + f"{baseline_ms:.3f}", + f"{fused_ms:.3f}", + f"{speedup:.2f}x", + f"{delta_pct:.1f}%", + ] + ) + + summary_rows = [ + [ + "baseline", + f"{sum(baseline_trials) / len(baseline_trials):.3f}", + f"{min(baseline_trials):.3f}", + f"{max(baseline_trials):.3f}", + "-", + ], + [ + "fused", + f"{sum(fused_trials) / len(fused_trials):.3f}", + f"{min(fused_trials):.3f}", + f"{max(fused_trials):.3f}", + f"{(sum(baseline_trials) / len(baseline_trials)) / (sum(fused_trials) / len(fused_trials)):.2f}x", + ], + ] + if can_bench_triton: + summary_rows.append( + [ + "triton", + f"{sum(triton_trials) / len(triton_trials):.3f}", + f"{min(triton_trials):.3f}", + f"{max(triton_trials):.3f}", + f"{(sum(baseline_trials) / len(baseline_trials)) / (sum(triton_trials) / len(triton_trials)):.2f}x", + ] + ) + + print() + print( + f"CASE {case.name} device={device} bits={case.bits} rows={case.rows} " + f"shape={case.out_features}x{case.in_features} dtype={str(dtype).removeprefix('torch.')}" + ) + if can_bench_triton: + print(_ascii_table(["trial", "baseline_ms", "fused_ms", "triton_ms", "torch_speedup", "triton_speedup"], trial_rows)) + else: + print(_ascii_table(["trial", "baseline_ms", "fused_ms", "speedup", "delta_pct"], trial_rows)) + print(_ascii_table(["path", "mean_ms", "min_ms", "max_ms", "speedup_vs_baseline"], summary_rows)) + print(f"correctness: mae={mae:.6f} max_abs={max_abs:.6f}") + if can_bench_triton: + print(f"triton_correctness: mae={triton_mae:.6f} max_abs={triton_max_abs:.6f}") + + +def _parse_case(spec: str) -> BenchCase: + parts = spec.split(":") + if len(parts) not in (5, 6): + raise ValueError( + f"Invalid case `{spec}`. Expected name:bits:in_features:out_features:rows[:group_size]." + ) + name, bits, in_features, out_features, rows, *rest = parts + return BenchCase( + name=name, + bits=bits, + in_features=int(in_features), + out_features=int(out_features), + rows=int(rows), + group_size=int(rest[0]) if rest else -1, + ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="A/B benchmark GGUF K-type dense vs fused forward.") + parser.add_argument( + "--case", + action="append", + dest="cases", + default=[], + help="Benchmark case as name:bits:in_features:out_features:rows[:group_size].", + ) + parser.add_argument("--trials", type=int, default=5, help="Measured trials per case.") + parser.add_argument("--warmup", type=int, default=2, help="Warmup iterations per path.") + parser.add_argument("--dtype", choices=("fp16", "bf16", "fp32"), default="fp16", help="Benchmark dtype.") + parser.add_argument( + "--include-triton", + action="store_true", + help="Also benchmark the experimental CUDA Triton fused GGUF path when available.", + ) + parser.add_argument( + "--device", + choices=("auto", "cpu", "cuda", "both"), + default="auto", + help="Benchmark device selection.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + if args.dtype == "fp16": + dtype = torch.float16 + elif args.dtype == "bf16": + dtype = torch.bfloat16 + else: + dtype = torch.float32 + + if args.device == "auto": + devices = ["cuda"] if torch.cuda.is_available() else ["cpu"] + elif args.device == "both": + devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) + else: + devices = [args.device] + + if "cuda" in devices and not torch.cuda.is_available(): + raise RuntimeError("CUDA requested but no CUDA device is available.") + if "cpu" in devices and dtype == torch.float16: + raise RuntimeError("CPU benchmarks should use --dtype bf16 or --dtype fp32.") + + cases = ( + [_parse_case(spec) for spec in args.cases] + if args.cases + else [ + BenchCase("attn_q4_k_m_r1", "q4_k_m", 2048, 2048, 1), + BenchCase("attn_q5_k_m_r1", "q5_k_m", 2048, 2048, 1), + BenchCase("attn_q6_k_r1", "q6_k", 2048, 2048, 1), + BenchCase("mlp_q4_k_m_r8", "q4_k_m", 2048, 8192, 8), + BenchCase("mlp_q5_k_m_r8", "q5_k_m", 2048, 8192, 8), + BenchCase("mlp_q6_k_r8", "q6_k", 2048, 8192, 8), + ] + ) + + print( + f"devices={','.join(devices)} dtype={str(dtype).removeprefix('torch.')} " + f"trials={args.trials} warmup={args.warmup}" + ) + for device in devices: + for case in cases: + _run_case( + case, + dtype=dtype, + device=device, + trials=args.trials, + warmup=args.warmup, + include_triton=args.include_triton, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_gptq_pro.py b/scripts/benchmark_gptq_pro.py new file mode 100644 index 000000000..0b0aeb854 --- /dev/null +++ b/scripts/benchmark_gptq_pro.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +""" +GPTQ-Pro Benchmark: Quantize Qwen3.5-9B-abliterated and measure quality + speed. + +Steps: + 1. Quantize with QuantizeConfig.gptq_pro() (high-quality 4-bit) + 2. Measure perplexity: original BF16 vs quantized + 3. Measure inference speed with the GPTQModel / Transformers runtime path + 4. Write findings to JSON for the report server + +Important: + The speed steps in this script do NOT exercise the standalone + `gptqmodel_ext/gptq_pro/` CUDA scaffold and do NOT use vLLM's Marlin/Machete + serving path. They are only a diagnostic for the current GPTQModel loader + backend selected at runtime. +""" +import gc +import json +import os +import sys +import time +from itertools import chain + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoTokenizer + +ORIG_MODEL = "/home/op/models/lukey03-Qwen3.5-9B-abliterated" +QUANT_OUTPUT = "/home/op/outputs/lukey03-Qwen3.5-9B-abliterated-gptq-pro-w4g128" +RESULTS_FILE = "/home/op/outputs/gptq_pro_benchmark_results.json" + +CALIB_NSAMPLES = 256 +CALIB_SEQLEN = 2048 +PPL_N_CTX = 2048 +PPL_N_BATCH = 512 +SPEED_PROMPT = "Explain the theory of general relativity in detail, covering spacetime curvature, the equivalence principle, and gravitational waves." +SPEED_MAX_NEW = 256 +SPEED_WARMUP = 3 +SPEED_RUNS = 5 + +results = {} + + +def log(msg): + print(f"\n{'='*60}\n {msg}\n{'='*60}", flush=True) + + +def get_calibration_data(tokenizer, nsamples, seqlen): + """Load WikiText-2 calibration data.""" + log(f"Loading calibration data: {nsamples} samples, seqlen={seqlen}") + traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + traindata = traindata.filter(lambda x: len(x["text"]) >= seqlen) + samples = [] + for example in traindata.select(range(min(nsamples, len(traindata)))): + tok = tokenizer(example["text"], truncation=True, max_length=seqlen, + return_tensors="pt") + samples.append({"input_ids": tok["input_ids"][0], "attention_mask": tok["attention_mask"][0]}) + print(f" Loaded {len(samples)} calibration samples") + return samples + + +def measure_perplexity(model, tokenizer, label): + """Measure WikiText-2 perplexity.""" + log(f"Measuring perplexity: {label}") + from gptqmodel.utils.perplexity import Perplexity + + ppl_calc = Perplexity( + model=model, + tokenizer=tokenizer, + dataset_path="wikitext", + dataset_name="wikitext-2-raw-v1", + split="test", + text_column="text", + ) + ppl_values = ppl_calc.calculate(n_ctx=PPL_N_CTX, n_batch=PPL_N_BATCH) + avg_ppl = sum(ppl_values) / len(ppl_values) + print(f" {label} PPL = {avg_ppl:.4f} (from {len(ppl_values)} windows)") + return avg_ppl + + +def measure_speed(model, tokenizer, label, device=None): + """Measure token generation speed.""" + log(f"Measuring speed: {label}") + runtime_devices = get_model_devices(model, fallback_device=device) + target_device = device or runtime_devices[0] + inputs = tokenizer(SPEED_PROMPT, return_tensors="pt") + inputs = {k: v.to(target_device) for k, v in inputs.items()} + + prompt_len = inputs["input_ids"].shape[1] + print(f" Prompt length: {prompt_len} tokens, generating {SPEED_MAX_NEW} new tokens") + print(f" Runtime devices: {', '.join(str(d) for d in runtime_devices)}") + + # Warmup + for i in range(SPEED_WARMUP): + with torch.no_grad(): + _ = model.generate(**inputs, max_new_tokens=16, do_sample=False) + print(f" Warmup {i+1}/{SPEED_WARMUP} done") + + synchronize_devices(runtime_devices) + + times = [] + tokens_generated = [] + for i in range(SPEED_RUNS): + synchronize_devices(runtime_devices) + t0 = time.perf_counter() + with torch.no_grad(): + out = model.generate(**inputs, max_new_tokens=SPEED_MAX_NEW, do_sample=False) + synchronize_devices(runtime_devices) + t1 = time.perf_counter() + + n_new = out.shape[1] - prompt_len + elapsed = t1 - t0 + times.append(elapsed) + tokens_generated.append(n_new) + tok_per_sec = n_new / elapsed + print(f" Run {i+1}/{SPEED_RUNS}: {n_new} tokens in {elapsed:.2f}s = {tok_per_sec:.1f} tok/s") + + avg_time = sum(times) / len(times) + avg_tokens = sum(tokens_generated) / len(tokens_generated) + avg_tok_s = avg_tokens / avg_time + + result = { + "label": label, + "avg_time_s": round(avg_time, 3), + "avg_tokens": round(avg_tokens, 1), + "avg_tok_per_s": round(avg_tok_s, 2), + "all_times": [round(t, 3) for t in times], + "all_tokens": tokens_generated, + "runtime_devices": [str(d) for d in runtime_devices], + } + print(f" Average: {avg_tok_s:.2f} tok/s ({avg_tokens:.0f} tokens in {avg_time:.2f}s)") + return result + + +def get_model_devices(model, fallback_device=None): + """Collect every CUDA device used by a (possibly sharded) model.""" + devices = set() + for tensor in chain(model.parameters(), model.buffers()): + if tensor is not None and tensor.is_cuda: + devices.add(tensor.device) + + if not devices: + if fallback_device is not None: + return [torch.device(fallback_device)] + model_device = getattr(model, "device", None) + if model_device is not None: + return [torch.device(model_device)] + return [torch.device("cuda:0")] + + return sorted(devices, key=lambda dev: (dev.type, dev.index if dev.index is not None else -1)) + + +def synchronize_devices(devices): + """Synchronize every CUDA device involved in the current runtime.""" + for device in devices: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def free_model(model): + """Free GPU memory.""" + del model + gc.collect() + torch.cuda.empty_cache() + + +# ============================================================ +# STEP 1: Quantize with gptq_pro +# ============================================================ +def step_quantize(): + log("STEP 1: Quantizing with gptq_pro (high quality)") + from gptqmodel import GPTQModel, QuantizeConfig + + tokenizer = AutoTokenizer.from_pretrained(ORIG_MODEL, trust_remote_code=True) + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id + + calib_data = get_calibration_data(tokenizer, CALIB_NSAMPLES, CALIB_SEQLEN) + + qcfg = QuantizeConfig.gptq_pro( + bits=4, + group_size=128, + sym=True, + mse=2.0, + damp_percent=0.05, + damp_auto_increment=0.01, + ) + print(f" QuantizeConfig: bits={qcfg.bits}, group_size={qcfg.group_size}, " + f"sym={qcfg.sym}, mse={qcfg.mse}, damp={qcfg.damp_percent}, " + f"act_group_aware={qcfg.act_group_aware}") + + results["quant_config"] = { + "bits": qcfg.bits, + "group_size": qcfg.group_size, + "sym": qcfg.sym, + "mse": qcfg.mse, + "damp_percent": qcfg.damp_percent, + "damp_auto_increment": qcfg.damp_auto_increment, + "act_group_aware": qcfg.act_group_aware, + "desc_act": qcfg.desc_act, + "format": str(qcfg.format), + } + + print(f" Loading model for quantization...") + t0 = time.perf_counter() + model = GPTQModel.load(ORIG_MODEL, qcfg, trust_remote_code=True) + print(f" Model loaded in {time.perf_counter()-t0:.1f}s") + + print(f" Starting quantization...") + t0 = time.perf_counter() + model.quantize(calib_data) + quant_time = time.perf_counter() - t0 + print(f" Quantization completed in {quant_time:.1f}s") + + os.makedirs(QUANT_OUTPUT, exist_ok=True) + model.save(QUANT_OUTPUT) + tokenizer.save_pretrained(QUANT_OUTPUT) + print(f" Saved to {QUANT_OUTPUT}") + + results["quantization"] = { + "time_s": round(quant_time, 1), + "output_path": QUANT_OUTPUT, + "calib_samples": len(calib_data), + "calib_seqlen": CALIB_SEQLEN, + } + free_model(model) + save_results() + + +# ============================================================ +# STEP 2: Perplexity — Original BF16 +# ============================================================ +def step_ppl_original(): + log("STEP 2: Perplexity of original BF16 model") + tokenizer = AutoTokenizer.from_pretrained(ORIG_MODEL, trust_remote_code=True) + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = AutoModelForCausalLM.from_pretrained( + ORIG_MODEL, + device_map="auto", + torch_dtype="auto", + trust_remote_code=True, + ) + + ppl = measure_perplexity(model, tokenizer, "Original BF16") + results["ppl_original"] = round(ppl, 4) + free_model(model) + save_results() + + +# ============================================================ +# STEP 3: Perplexity — Quantized +# ============================================================ +def step_ppl_quantized(): + log("STEP 3: Perplexity of quantized model") + from gptqmodel import GPTQModel + + tokenizer = AutoTokenizer.from_pretrained(QUANT_OUTPUT, trust_remote_code=True) + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = GPTQModel.load( + QUANT_OUTPUT, + device_map="auto", + trust_remote_code=True, + ) + + ppl = measure_perplexity(model, tokenizer, "GPTQ-Pro 4-bit") + results["ppl_quantized"] = round(ppl, 4) + free_model(model) + save_results() + + +# ============================================================ +# STEP 4: Inference Speed — 1 GPU +# ============================================================ +def step_speed_1gpu(): + log("STEP 4: Inference speed — GPTQModel / Transformers path on cuda:0") + from gptqmodel import GPTQModel + + tokenizer = AutoTokenizer.from_pretrained(QUANT_OUTPUT, trust_remote_code=True) + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = GPTQModel.load( + QUANT_OUTPUT, + device="cuda:0", + trust_remote_code=True, + ) + + speed = measure_speed( + model, + tokenizer, + "1×RTX3090 (GPTQModel/Transformers runtime, not vLLM / not standalone gptq_pro kernel)", + device="cuda:0", + ) + results["speed_1gpu"] = speed + results["speed_path"] = "gptqmodel_transformers_runtime" + free_model(model) + save_results() + + +# ============================================================ +# STEP 5: Inference Speed — 2 GPUs +# ============================================================ +def step_speed_2gpu(): + log("STEP 5: Inference speed — diagnostic GPTQModel / Transformers multi-GPU path") + from gptqmodel import GPTQModel + + tokenizer = AutoTokenizer.from_pretrained(QUANT_OUTPUT, trust_remote_code=True) + if not tokenizer.pad_token_id: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = GPTQModel.load( + QUANT_OUTPUT, + device_map={ + "model.embed_tokens": "cuda:0", + "model.norm": "cuda:0", + "model.rotary_emb": "cuda:0", + "lm_head": "cuda:0", + }, + max_memory={ + 0: "20GiB", + 3: "10GiB", + }, + trust_remote_code=True, + ) + + speed = measure_speed( + model, + tokenizer, + "Multi-GPU GPTQModel/Transformers sharding diagnostic (not vLLM tensor parallel)", + ) + results["speed_2gpu"] = speed + results["speed_2gpu_note"] = ( + "Diagnostic only: this path uses GPTQModel/Transformers sharding and can be " + "much slower than vLLM tensor parallel or single-GPU Marlin serving." + ) + free_model(model) + save_results() + + +def save_results(): + os.makedirs(os.path.dirname(RESULTS_FILE), exist_ok=True) + with open(RESULTS_FILE, "w") as f: + json.dump(results, f, indent=2) + print(f" Results saved to {RESULTS_FILE}") + + +if __name__ == "__main__": + results["model"] = { + "name": "lukey03/Qwen3.5-9B-abliterated", + "local_path": ORIG_MODEL, + "architecture": "Qwen3_5ForCausalLM", + "params": "~9B", + "dtype": "bfloat16", + "hidden_size": 4096, + "layers": 32, + } + results["gpu_info"] = { + "gpu0": "NVIDIA GeForce RTX 3090 (24 GB)", + "gpu3": "NVIDIA GeForce RTX 3060 (12 GB)", + } + + steps = sys.argv[1:] if len(sys.argv) > 1 else [ + "quantize", "ppl_original", "ppl_quantized", "speed_1gpu", "speed_2gpu" + ] + + # Load existing results if any + if os.path.exists(RESULTS_FILE): + with open(RESULTS_FILE) as f: + results.update(json.load(f)) + + for step in steps: + if step == "quantize": + step_quantize() + elif step == "ppl_original": + step_ppl_original() + elif step == "ppl_quantized": + step_ppl_quantized() + elif step == "speed_1gpu": + step_speed_1gpu() + elif step == "speed_2gpu": + step_speed_2gpu() + else: + print(f"Unknown step: {step}") + + save_results() + log("ALL STEPS COMPLETE") + print(json.dumps(results, indent=2)) diff --git a/scripts/benchmark_llama3_2_paged_attention.py b/scripts/benchmark_llama3_2_paged_attention.py new file mode 100644 index 000000000..15f2d1a6d --- /dev/null +++ b/scripts/benchmark_llama3_2_paged_attention.py @@ -0,0 +1,1320 @@ +#!/usr/bin/env python3 + +import argparse +import copy +import gc +import json +import math +import os +import subprocess +import sys +import tempfile +import time +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path + + +def _extract_requested_gpu(argv: list[str]) -> str | None: + for index, arg in enumerate(argv): + if arg == "--gpu" and index + 1 < len(argv): + return argv[index + 1] + if arg.startswith("--gpu="): + return arg.split("=", 1)[1] + return None + + +_requested_gpu = _extract_requested_gpu(sys.argv[1:]) +if _requested_gpu is not None: + os.environ.setdefault("CUDA_VISIBLE_DEVICES", str(_requested_gpu)) +os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") + +import torch +from tabulate import tabulate +from transformers import AutoTokenizer + +from gptqmodel import BACKEND, GPTQModel +from gptqmodel.utils.torch import torch_empty_cache + + +REPO_ROOT = Path(__file__).resolve().parents[1] +TESTS_MODELS_DIR = REPO_ROOT / "tests" / "models" +if str(TESTS_MODELS_DIR) not in sys.path: + sys.path.insert(0, str(TESTS_MODELS_DIR)) + +from test_llama3_2 import TestLlama3_2 # noqa: E402 + + +DEFAULT_PROMPT = ( + "Write a detailed but compact explanation of how attention works in a transformer model, " + "including self-attention, key/query/value projections, and why KV caching helps autoregressive decoding." +) + +DEFAULT_BATCH_PROMPTS = [ + ( + "Explain why grouped-query attention reduces KV-cache memory pressure during autoregressive decoding, " + "and describe the main tradeoff compared with full multi-head attention." + ), + ( + "Summarize how rotary position embeddings are applied inside a decoder-only transformer and why they can " + "generalize better to longer contexts than absolute learned positions." + ), + ( + "Describe the main differences between post-training quantization and quantization-aware training for large " + "language models, including the practical impact on deployment." + ), + ( + "Write a compact explanation of how FlashAttention reduces memory traffic, and clarify when it can still be " + "slower than an SDPA-based path in real inference workloads." + ), +] + +DEFAULT_STREAM_PROMPT_TARGETS = [256, 1024, 2048, 4096] +DEFAULT_STREAM_TOPICS = [ + "attention kernel dispatch for quantized decoder-only models", + "kv-cache memory pressure under long-context generation", + "continuous batching for mixed prompt-length inference traffic", + "prefix-sharing opportunities in repeated system prompts", + "prefill versus decode scheduling tradeoffs in online serving", + "how prompt padding hurts small-batch throughput", + "batch admission control when request arrivals are bursty", + "memory fragmentation and page allocation for kv caches", +] +STREAM_SHARED_PREFIX_SENTENCE = ( + "You are producing a technical analysis of transformer inference serving, with emphasis on quantized decoding, " + "attention kernels, kv-cache reuse, request scheduling, and latency-throughput tradeoffs. " +) +STREAM_FILLER_SENTENCE = ( + "Discuss practical implications for prefill, decode, memory bandwidth, batching policy, and cache reuse in detail. " +) + + +@dataclass +class BenchmarkResult: + mode: str + batch_size: int + requested_attn_impl: str + resolved_attn_impl: str + new_tokens_per_request: int + total_new_tokens: int + latency_s: float + total_toks_per_s: float + baseline_reserved_gib: float + peak_reserved_gib: float + delta_peak_reserved_gib: float + baseline_allocated_gib: float + peak_allocated_gib: float + delta_peak_allocated_gib: float + + +@dataclass +class StreamRequest: + request_id: str + prompt: str + arrival_s: float + prompt_tokens: int + + +@dataclass +class StreamBenchmarkResult: + mode: str + requested_attn_impl: str + resolved_attn_impl: str + request_count: int + max_new_tokens_per_request: int + arrival_gap_ms: float + prompt_targets: str + makespan_s: float + reqs_per_s: float + total_toks_per_s: float + latency_p50_s: float + latency_p95_s: float + queue_p50_s: float + queue_p95_s: float + ttft_p50_s: float | None + ttft_p95_s: float | None + peak_reserved_gib: float + peak_allocated_gib: float + + +def bytes_to_gib(value: int) -> float: + return value / (1024 ** 3) + + +def percentile(values: list[float], q: float) -> float: + if not values: + return 0.0 + ordered = sorted(values) + if len(ordered) == 1: + return float(ordered[0]) + rank = (len(ordered) - 1) * q + lower = int(math.floor(rank)) + upper = int(math.ceil(rank)) + if lower == upper: + return float(ordered[lower]) + weight = rank - lower + return float(ordered[lower] * (1 - weight) + ordered[upper] * weight) + + +def prompt_targets_str(values: list[int]) -> str: + return ",".join(str(value) for value in values) + + +def now_stamp() -> str: + return datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S") + + +def make_default_artifact_dir() -> Path: + return REPO_ROOT / "benchmark_artifacts" / f"llama3_2_1b_gptq_full_{now_stamp()}" + + +def ensure_empty_or_new_dir(path: Path) -> None: + path.mkdir(parents=True, exist_ok=False) + + +def cuda_sync(device_index: int) -> None: + if torch.cuda.is_available(): + torch.cuda.synchronize(device_index) + + +def cleanup_cuda() -> None: + gc.collect() + torch_empty_cache() + + +def quantize_once(artifact_dir: Path) -> None: + test = TestLlama3_2(methodName="test_llama3_2") + test.PIN_CUDA_DEVICE = 0 + test.SAVE_PATH = str(artifact_dir) + test.DELETE_QUANTIZED_MODEL = False + test.LOAD_BACKEND = BACKEND.MARLIN + test.USE_FLASH_ATTN = True + + model, tokenizer, _ = test.quantModel( + test.NATIVE_MODEL_ID, + batch_size=test.QUANT_BATCH_SIZE, + trust_remote_code=test.TRUST_REMOTE_CODE, + dtype=test.TORCH_DTYPE, + need_eval=False, + call_perform_post_quant_validation=False, + ) + + del tokenizer + del model + cleanup_cuda() + + +def load_quantized_model(artifact_dir: Path, attn_implementation: str): + model = GPTQModel.load( + str(artifact_dir), + trust_remote_code=False, + backend=BACKEND.MARLIN, + device_map={"": "cuda:0"}, + attn_implementation=attn_implementation, + ) + return model, model.tokenizer + + +def build_shared_prefix(decoder, target_tokens: int) -> str: + text = "" + while len(decoder(text, add_special_tokens=True)["input_ids"]) < target_tokens: + text += STREAM_SHARED_PREFIX_SENTENCE + return text + + +def build_prompt_to_target_tokens(decoder, shared_prefix: str, unique_suffix: str, target_tokens: int) -> str: + text = shared_prefix + unique_suffix + while len(decoder(text, add_special_tokens=True)["input_ids"]) < target_tokens: + text += STREAM_FILLER_SENTENCE + return text + + +def build_stream_workload( + artifact_dir: Path, + request_count: int, + arrival_gap_ms: float, + prompt_targets: list[int], + shared_prefix_tokens: int, +) -> list[StreamRequest]: + decoder = AutoTokenizer.from_pretrained(str(artifact_dir), trust_remote_code=False) + shared_prefix = build_shared_prefix(decoder, shared_prefix_tokens) + requests = [] + for index in range(request_count): + target_tokens = prompt_targets[index % len(prompt_targets)] + topic = DEFAULT_STREAM_TOPICS[index % len(DEFAULT_STREAM_TOPICS)] + unique_suffix = ( + f"Request {index}: provide a compact but precise explanation of {topic}. " + f"Include a comparison against request id {index} traffic behavior, and keep the answer deterministic. " + ) + prompt = build_prompt_to_target_tokens( + decoder, + shared_prefix, + unique_suffix, + max(target_tokens, shared_prefix_tokens + 32), + ) + prompt_tokens = len(decoder(prompt, add_special_tokens=True)["input_ids"]) + requests.append( + StreamRequest( + request_id=f"req_{index}", + prompt=prompt, + arrival_s=(arrival_gap_ms / 1000.0) * index, + prompt_tokens=prompt_tokens, + ) + ) + return requests + + +def base_tokenizer(tokenizer): + return getattr(tokenizer, "tokenizer", tokenizer) + + +def prepare_inputs(tokenizer, prompts: list[str], device) -> tuple[dict, int, int, int]: + decoder = base_tokenizer(tokenizer) + original_padding_side = getattr(decoder, "padding_side", "right") + decoder.padding_side = "left" + inputs = decoder(prompts, return_tensors="pt", padding=True).to(device) + decoder.padding_side = original_padding_side + padded_prompt_len = int(inputs["input_ids"].shape[-1]) + pad_token_id = decoder.pad_token_id if decoder.pad_token_id is not None else decoder.eos_token_id + eos_token_id = -1 + return inputs, padded_prompt_len, pad_token_id, eos_token_id + + +def prepare_cb_inputs(tokenizer, prompts: list[str]) -> tuple[list[list[int]], int, int]: + decoder = base_tokenizer(tokenizer) + encoded = decoder(prompts, add_special_tokens=True) + input_ids = encoded["input_ids"] + if not isinstance(input_ids, list) or not input_ids: + raise ValueError("Continuous batching inputs must be a non-empty list of token-id lists.") + pad_token_id = decoder.pad_token_id if decoder.pad_token_id is not None else decoder.eos_token_id + eos_token_id = -1 + return [list(ids) for ids in input_ids], pad_token_id, eos_token_id + + +def run_generate( + model, + inputs: dict, + *, + min_new_tokens: int, + max_new_tokens: int, + pad_token_id: int, + eos_token_id: int | None, + cache_implementation: str | None = None, +): + kwargs = dict(inputs) + kwargs.update( + { + "do_sample": False, + "num_beams": 1, + "min_new_tokens": min_new_tokens, + "max_new_tokens": max_new_tokens, + "pad_token_id": pad_token_id, + "eos_token_id": eos_token_id, + } + ) + if cache_implementation is not None: + kwargs["cache_implementation"] = cache_implementation + return model.generate(**kwargs) + + +def sequence_batch_from_generate_output(output: torch.Tensor) -> torch.Tensor: + if output.dim() == 3: + return output[0] + if output.dim() == 2: + return output + raise ValueError(f"Unexpected generate output shape: {tuple(output.shape)}") + + +def collect_cb_results(manager, request_count: int) -> list: + results = [] + while len(results) < request_count: + result = manager.get_result(timeout=1) + if result is not None: + results.append(result) + return results + + +def evict_cb_results(manager, results: list) -> None: + for item in results: + manager.evict_request_from_cache(item.request_id) + + +def benchmark_paged_mode( + artifact_dir: Path, + prompts: list[str], + *, + mode_name: str, + attn_implementation: str, + warmup_tokens: int, + max_new_tokens: int, +) -> BenchmarkResult: + model, tokenizer = load_quantized_model(artifact_dir, attn_implementation=attn_implementation) + model.eval() + device_index = 0 + + prompt_token_lists, pad_token_id, eos_token_id = prepare_cb_inputs(tokenizer, prompts) + generation_config = copy.deepcopy(model.model.generation_config) + generation_config.do_sample = False + generation_config.num_beams = 1 + generation_config.pad_token_id = pad_token_id + generation_config.eos_token_id = eos_token_id + + with model.model.continuous_batching_context_manager( + generation_config=generation_config, + manual_eviction=True, + block=True, + timeout=30, + use_async_batching=False, + ) as manager: + warmup_results = [] + if warmup_tokens > 0: + manager.add_requests(prompt_token_lists, max_new_tokens=warmup_tokens) + warmup_results = collect_cb_results(manager, len(prompt_token_lists)) + evict_cb_results(manager, warmup_results) + cuda_sync(device_index) + + baseline_reserved = torch.cuda.memory_reserved(device_index) + baseline_allocated = torch.cuda.memory_allocated(device_index) + torch.cuda.reset_peak_memory_stats(device_index) + + started = time.perf_counter() + manager.add_requests(prompt_token_lists, max_new_tokens=max_new_tokens) + measured_results = collect_cb_results(manager, len(prompt_token_lists)) + cuda_sync(device_index) + elapsed = time.perf_counter() - started + + peak_reserved = torch.cuda.max_memory_reserved(device_index) + peak_allocated = torch.cuda.max_memory_allocated(device_index) + + batch_size = len(measured_results) + if batch_size == 0: + raise ValueError("Continuous batching returned no results.") + new_tokens_per_request = len(measured_results[0].generated_tokens) + total_new_tokens = sum(len(item.generated_tokens) for item in measured_results) + resolved_attn_impl = str(getattr(model.config, "_attn_implementation", attn_implementation)) + + del measured_results + del warmup_results + del tokenizer + del model + cleanup_cuda() + + total_toks_per_s = float(total_new_tokens) / elapsed if elapsed > 0 else 0.0 + return BenchmarkResult( + mode=mode_name, + batch_size=batch_size, + requested_attn_impl=attn_implementation, + resolved_attn_impl=resolved_attn_impl, + new_tokens_per_request=new_tokens_per_request, + total_new_tokens=total_new_tokens, + latency_s=elapsed, + total_toks_per_s=total_toks_per_s, + baseline_reserved_gib=bytes_to_gib(baseline_reserved), + peak_reserved_gib=bytes_to_gib(peak_reserved), + delta_peak_reserved_gib=bytes_to_gib(max(peak_reserved - baseline_reserved, 0)), + baseline_allocated_gib=bytes_to_gib(baseline_allocated), + peak_allocated_gib=bytes_to_gib(peak_allocated), + delta_peak_allocated_gib=bytes_to_gib(max(peak_allocated - baseline_allocated, 0)), + ) + + +def benchmark_paged_mode_subprocess( + artifact_dir: Path, + prompts: list[str], + *, + mode_name: str, + attn_implementation: str, + warmup_tokens: int, + max_new_tokens: int, + gpu: int, +) -> BenchmarkResult: + with tempfile.NamedTemporaryFile(prefix="paged_benchmark_", suffix=".json", delete=False) as handle: + result_path = Path(handle.name) + + command = [ + sys.executable, + str(Path(__file__).resolve()), + "--gpu", + str(gpu), + "--internal-paged-result-path", + str(result_path), + "--artifact-dir", + str(artifact_dir), + "--internal-prompts-json", + json.dumps(prompts), + "--internal-mode-name", + mode_name, + "--internal-attn-implementation", + attn_implementation, + "--internal-warmup-tokens", + str(warmup_tokens), + "--internal-max-new-tokens", + str(max_new_tokens), + ] + completed = subprocess.run(command, check=False, capture_output=True, text=True) + if completed.returncode != 0: + raise RuntimeError( + "Paged benchmark subprocess failed.\n" + f"STDOUT:\n{completed.stdout}\n" + f"STDERR:\n{completed.stderr}" + ) + try: + payload = json.loads(result_path.read_text(encoding="utf-8")) + return BenchmarkResult(**payload) + finally: + result_path.unlink(missing_ok=True) + + +def benchmark_mode( + artifact_dir: Path, + prompts: list[str], + *, + mode_name: str, + attn_implementation: str, + warmup_tokens: int, + max_new_tokens: int, + use_paged_cache: bool, + gpu: int, +) -> BenchmarkResult: + if use_paged_cache: + return benchmark_paged_mode_subprocess( + artifact_dir, + prompts, + mode_name=mode_name, + attn_implementation=attn_implementation, + warmup_tokens=warmup_tokens, + max_new_tokens=max_new_tokens, + gpu=gpu, + ) + + model, tokenizer = load_quantized_model(artifact_dir, attn_implementation=attn_implementation) + model.eval() + device_index = 0 + + inputs, padded_prompt_len, pad_token_id, eos_token_id = prepare_inputs(tokenizer, prompts, model.device) + cache_impl = "paged" if use_paged_cache else None + + warmup_output = run_generate( + model, + inputs, + min_new_tokens=warmup_tokens, + max_new_tokens=warmup_tokens, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + cache_implementation=cache_impl, + ) + del warmup_output + cuda_sync(device_index) + + baseline_reserved = torch.cuda.memory_reserved(device_index) + baseline_allocated = torch.cuda.memory_allocated(device_index) + torch.cuda.reset_peak_memory_stats(device_index) + + started = time.perf_counter() + measured_output = run_generate( + model, + inputs, + min_new_tokens=max_new_tokens, + max_new_tokens=max_new_tokens, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + cache_implementation=cache_impl, + ) + cuda_sync(device_index) + elapsed = time.perf_counter() - started + + peak_reserved = torch.cuda.max_memory_reserved(device_index) + peak_allocated = torch.cuda.max_memory_allocated(device_index) + sequence_batch = sequence_batch_from_generate_output(measured_output) + batch_size = int(sequence_batch.shape[0]) + new_tokens_per_request = int(sequence_batch.shape[-1] - padded_prompt_len) + total_new_tokens = batch_size * new_tokens_per_request + + resolved_attn_impl = str(getattr(model.config, "_attn_implementation", attn_implementation)) + + del sequence_batch + del measured_output + del inputs + del tokenizer + del model + cleanup_cuda() + + total_toks_per_s = float(total_new_tokens) / elapsed if elapsed > 0 else 0.0 + return BenchmarkResult( + mode=mode_name, + batch_size=batch_size, + requested_attn_impl=attn_implementation, + resolved_attn_impl=resolved_attn_impl, + new_tokens_per_request=new_tokens_per_request, + total_new_tokens=total_new_tokens, + latency_s=elapsed, + total_toks_per_s=total_toks_per_s, + baseline_reserved_gib=bytes_to_gib(baseline_reserved), + peak_reserved_gib=bytes_to_gib(peak_reserved), + delta_peak_reserved_gib=bytes_to_gib(max(peak_reserved - baseline_reserved, 0)), + baseline_allocated_gib=bytes_to_gib(baseline_allocated), + peak_allocated_gib=bytes_to_gib(peak_allocated), + delta_peak_allocated_gib=bytes_to_gib(max(peak_allocated - baseline_allocated, 0)), + ) + + +def warmup_static_generate(model, tokenizer, prompt: str, warmup_tokens: int) -> None: + if warmup_tokens <= 0: + return + inputs, _, pad_token_id, eos_token_id = prepare_inputs(tokenizer, [prompt], model.device) + output = run_generate( + model, + inputs, + min_new_tokens=warmup_tokens, + max_new_tokens=warmup_tokens, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + del output + del inputs + cuda_sync(0) + + +def make_stream_result( + *, + mode_name: str, + requested_attn_impl: str, + resolved_attn_impl: str, + request_count: int, + max_new_tokens_per_request: int, + arrival_gap_ms: float, + prompt_targets: list[int], + makespan_s: float, + latencies: list[float], + queue_delays: list[float], + ttfts: list[float] | None, + peak_reserved: int, + peak_allocated: int, +) -> StreamBenchmarkResult: + total_new_tokens = request_count * max_new_tokens_per_request + return StreamBenchmarkResult( + mode=mode_name, + requested_attn_impl=requested_attn_impl, + resolved_attn_impl=resolved_attn_impl, + request_count=request_count, + max_new_tokens_per_request=max_new_tokens_per_request, + arrival_gap_ms=arrival_gap_ms, + prompt_targets=prompt_targets_str(prompt_targets), + makespan_s=makespan_s, + reqs_per_s=(request_count / makespan_s) if makespan_s > 0 else 0.0, + total_toks_per_s=(total_new_tokens / makespan_s) if makespan_s > 0 else 0.0, + latency_p50_s=percentile(latencies, 0.50), + latency_p95_s=percentile(latencies, 0.95), + queue_p50_s=percentile(queue_delays, 0.50), + queue_p95_s=percentile(queue_delays, 0.95), + ttft_p50_s=None if not ttfts else percentile(ttfts, 0.50), + ttft_p95_s=None if not ttfts else percentile(ttfts, 0.95), + peak_reserved_gib=bytes_to_gib(peak_reserved), + peak_allocated_gib=bytes_to_gib(peak_allocated), + ) + + +def benchmark_stream_static_mode( + artifact_dir: Path, + workload: list[StreamRequest], + *, + mode_name: str, + attn_implementation: str, + warmup_tokens: int, + max_new_tokens: int, + arrival_gap_ms: float, + prompt_targets: list[int], +) -> StreamBenchmarkResult: + model, tokenizer = load_quantized_model(artifact_dir, attn_implementation=attn_implementation) + model.eval() + device_index = 0 + + warmup_static_generate(model, tokenizer, workload[0].prompt, warmup_tokens) + torch.cuda.reset_peak_memory_stats(device_index) + + pending: list[StreamRequest] = [] + next_index = 0 + latencies: list[float] = [] + queue_delays: list[float] = [] + completion_times: list[float] = [] + + started_at = time.perf_counter() + while len(completion_times) < len(workload): + now_rel = time.perf_counter() - started_at + while next_index < len(workload) and workload[next_index].arrival_s <= now_rel: + pending.append(workload[next_index]) + next_index += 1 + + if not pending: + if next_index >= len(workload): + break + sleep_s = max(workload[next_index].arrival_s - now_rel, 0.0) + if sleep_s > 0: + time.sleep(min(sleep_s, 0.01)) + continue + + batch_requests = pending + pending = [] + batch_start_rel = time.perf_counter() - started_at + prompts = [request.prompt for request in batch_requests] + inputs, _, pad_token_id, eos_token_id = prepare_inputs(tokenizer, prompts, model.device) + output = run_generate( + model, + inputs, + min_new_tokens=max_new_tokens, + max_new_tokens=max_new_tokens, + pad_token_id=pad_token_id, + eos_token_id=eos_token_id, + ) + del output + del inputs + cuda_sync(device_index) + batch_end_rel = time.perf_counter() - started_at + + for request in batch_requests: + latencies.append(batch_end_rel - request.arrival_s) + queue_delays.append(max(batch_start_rel - request.arrival_s, 0.0)) + completion_times.append(batch_end_rel) + + peak_reserved = torch.cuda.max_memory_reserved(device_index) + peak_allocated = torch.cuda.max_memory_allocated(device_index) + resolved_attn_impl = str(getattr(model.config, "_attn_implementation", attn_implementation)) + + del tokenizer + del model + cleanup_cuda() + + makespan_s = max(completion_times) if completion_times else 0.0 + return make_stream_result( + mode_name=mode_name, + requested_attn_impl=attn_implementation, + resolved_attn_impl=resolved_attn_impl, + request_count=len(workload), + max_new_tokens_per_request=max_new_tokens, + arrival_gap_ms=arrival_gap_ms, + prompt_targets=prompt_targets, + makespan_s=makespan_s, + latencies=latencies, + queue_delays=queue_delays, + ttfts=None, + peak_reserved=peak_reserved, + peak_allocated=peak_allocated, + ) + + +def benchmark_stream_paged_mode( + artifact_dir: Path, + workload: list[StreamRequest], + *, + mode_name: str, + attn_implementation: str, + warmup_tokens: int, + max_new_tokens: int, + arrival_gap_ms: float, + prompt_targets: list[int], + scheduler_name: str, + use_async_batching: bool, +) -> StreamBenchmarkResult: + model, tokenizer = load_quantized_model(artifact_dir, attn_implementation=attn_implementation) + model.eval() + device_index = 0 + + prompt_token_lists, pad_token_id, eos_token_id = prepare_cb_inputs(tokenizer, [request.prompt for request in workload]) + generation_config = copy.deepcopy(model.model.generation_config) + generation_config.do_sample = False + generation_config.num_beams = 1 + generation_config.pad_token_id = pad_token_id + generation_config.eos_token_id = eos_token_id + generation_config.scheduler = scheduler_name + + with model.model.continuous_batching_context_manager( + generation_config=generation_config, + manual_eviction=True, + block=True, + timeout=30, + use_async_batching=use_async_batching, + allow_block_sharing=True, + ) as manager: + if warmup_tokens > 0: + manager.add_request( + prompt_token_lists[0], + request_id="warmup", + max_new_tokens=warmup_tokens, + record_timestamps=True, + ) + warmup_result = collect_cb_results(manager, 1)[0] + manager.evict_request_from_cache(warmup_result.request_id) + del warmup_result + cuda_sync(device_index) + + torch.cuda.reset_peak_memory_stats(device_index) + + next_index = 0 + latencies: list[float] = [] + queue_delays: list[float] = [] + ttfts: list[float] = [] + completion_times: list[float] = [] + started_at_abs = time.perf_counter() + + while len(completion_times) < len(workload): + now_rel = time.perf_counter() - started_at_abs + while next_index < len(workload) and workload[next_index].arrival_s <= now_rel: + request = workload[next_index] + manager.add_request( + prompt_token_lists[next_index], + request_id=request.request_id, + max_new_tokens=max_new_tokens, + record_timestamps=True, + ) + next_index += 1 + + result = manager.get_result(timeout=0.01) + if result is None: + if next_index < len(workload): + sleep_s = max(workload[next_index].arrival_s - (time.perf_counter() - started_at_abs), 0.0) + if sleep_s > 0: + time.sleep(min(sleep_s, 0.01)) + continue + + if result.error is not None: + raise RuntimeError(f"Continuous batching request failed: {result.request_id}: {result.error}") + + completion_times.append(result.lifespan[1] - started_at_abs) + latencies.append(result.lifespan[1] - result.created_time) + queue_delays.append(max(result.lifespan[0] - result.created_time, 0.0)) + if result.timestamps: + ttfts.append(result.timestamps[0] - result.created_time) + manager.evict_request_from_cache(result.request_id) + + peak_reserved = torch.cuda.max_memory_reserved(device_index) + peak_allocated = torch.cuda.max_memory_allocated(device_index) + + resolved_attn_impl = str(getattr(model.config, "_attn_implementation", attn_implementation)) + + del tokenizer + del model + cleanup_cuda() + + makespan_s = max(completion_times) if completion_times else 0.0 + return make_stream_result( + mode_name=mode_name, + requested_attn_impl=attn_implementation, + resolved_attn_impl=resolved_attn_impl, + request_count=len(workload), + max_new_tokens_per_request=max_new_tokens, + arrival_gap_ms=arrival_gap_ms, + prompt_targets=prompt_targets, + makespan_s=makespan_s, + latencies=latencies, + queue_delays=queue_delays, + ttfts=ttfts, + peak_reserved=peak_reserved, + peak_allocated=peak_allocated, + ) + + +def benchmark_stream_paged_mode_subprocess( + artifact_dir: Path, + workload: list[StreamRequest], + *, + mode_name: str, + attn_implementation: str, + warmup_tokens: int, + max_new_tokens: int, + arrival_gap_ms: float, + prompt_targets: list[int], + scheduler_name: str, + use_async_batching: bool, + gpu: int, +) -> StreamBenchmarkResult: + with tempfile.NamedTemporaryFile(prefix="stream_workload_", suffix=".json", delete=False) as workload_handle: + workload_path = Path(workload_handle.name) + with tempfile.NamedTemporaryFile(prefix="stream_result_", suffix=".json", delete=False) as result_handle: + result_path = Path(result_handle.name) + + workload_payload = { + "arrival_gap_ms": arrival_gap_ms, + "prompt_targets": prompt_targets, + "scheduler_name": scheduler_name, + "use_async_batching": use_async_batching, + "requests": [asdict(item) for item in workload], + } + workload_path.write_text(json.dumps(workload_payload), encoding="utf-8") + + command = [ + sys.executable, + str(Path(__file__).resolve()), + "--gpu", + str(gpu), + "--artifact-dir", + str(artifact_dir), + "--internal-stream-result-path", + str(result_path), + "--internal-stream-workload-path", + str(workload_path), + "--internal-mode-name", + mode_name, + "--internal-attn-implementation", + attn_implementation, + "--internal-warmup-tokens", + str(warmup_tokens), + "--internal-max-new-tokens", + str(max_new_tokens), + ] + completed = subprocess.run(command, check=False, capture_output=True, text=True) + try: + if completed.returncode != 0: + raise RuntimeError( + "Stream paged benchmark subprocess failed.\n" + f"STDOUT:\n{completed.stdout}\n" + f"STDERR:\n{completed.stderr}" + ) + payload = json.loads(result_path.read_text(encoding="utf-8")) + return StreamBenchmarkResult(**payload) + finally: + workload_path.unlink(missing_ok=True) + result_path.unlink(missing_ok=True) + + +def benchmark_stream_mode( + artifact_dir: Path, + workload: list[StreamRequest], + *, + mode_name: str, + attn_implementation: str, + warmup_tokens: int, + max_new_tokens: int, + use_paged_cache: bool, + arrival_gap_ms: float, + prompt_targets: list[int], + scheduler_name: str, + use_async_batching: bool, + gpu: int, +) -> StreamBenchmarkResult: + if use_paged_cache: + return benchmark_stream_paged_mode_subprocess( + artifact_dir, + workload, + mode_name=mode_name, + attn_implementation=attn_implementation, + warmup_tokens=warmup_tokens, + max_new_tokens=max_new_tokens, + arrival_gap_ms=arrival_gap_ms, + prompt_targets=prompt_targets, + scheduler_name=scheduler_name, + use_async_batching=use_async_batching, + gpu=gpu, + ) + + return benchmark_stream_static_mode( + artifact_dir, + workload, + mode_name=mode_name, + attn_implementation=attn_implementation, + warmup_tokens=warmup_tokens, + max_new_tokens=max_new_tokens, + arrival_gap_ms=arrival_gap_ms, + prompt_targets=prompt_targets, + ) + + +def render_ascii_table(results: list[BenchmarkResult]) -> str: + rows = [] + for item in results: + rows.append( + [ + item.mode, + item.batch_size, + item.requested_attn_impl, + item.resolved_attn_impl, + item.new_tokens_per_request, + item.total_new_tokens, + f"{item.latency_s:.3f}", + f"{item.total_toks_per_s:.2f}", + f"{item.peak_reserved_gib:.2f}", + f"{item.delta_peak_reserved_gib:.2f}", + f"{item.peak_allocated_gib:.2f}", + f"{item.delta_peak_allocated_gib:.2f}", + ] + ) + headers = [ + "mode", + "batch", + "requested_attn", + "resolved_attn", + "new_tokens_each", + "total_new_tokens", + "latency_s", + "total_tok_s", + "peak_reserved_gib", + "delta_reserved_gib", + "peak_alloc_gib", + "delta_alloc_gib", + ] + return tabulate(rows, headers=headers, tablefmt="grid") + + +def render_stream_ascii_table(results: list[StreamBenchmarkResult]) -> str: + rows = [] + for item in results: + rows.append( + [ + item.mode, + item.request_count, + item.max_new_tokens_per_request, + f"{item.arrival_gap_ms:.0f}", + item.prompt_targets, + f"{item.makespan_s:.3f}", + f"{item.reqs_per_s:.2f}", + f"{item.total_toks_per_s:.2f}", + f"{item.latency_p50_s:.3f}", + f"{item.latency_p95_s:.3f}", + f"{item.queue_p50_s:.3f}", + f"{item.queue_p95_s:.3f}", + "n/a" if item.ttft_p50_s is None else f"{item.ttft_p50_s:.3f}", + "n/a" if item.ttft_p95_s is None else f"{item.ttft_p95_s:.3f}", + f"{item.peak_allocated_gib:.2f}", + f"{item.peak_reserved_gib:.2f}", + ] + ) + headers = [ + "mode", + "requests", + "new_tok_each", + "arrival_ms", + "prompt_targets", + "makespan_s", + "req_s", + "tok_s", + "lat_p50_s", + "lat_p95_s", + "queue_p50_s", + "queue_p95_s", + "ttft_p50_s", + "ttft_p95_s", + "peak_alloc_gib", + "peak_reserved_gib", + ] + return tabulate(rows, headers=headers, tablefmt="grid") + + +def parse_batch_sizes(raw_value: str) -> list[int]: + values = [] + for piece in raw_value.split(","): + piece = piece.strip() + if not piece: + continue + value = int(piece) + if value <= 0: + raise ValueError(f"Batch sizes must be positive, got {value}") + values.append(value) + if not values: + raise ValueError("At least one batch size must be provided.") + return values + + +def parse_int_list(raw_value: str) -> list[int]: + values = [] + for piece in raw_value.split(","): + piece = piece.strip() + if not piece: + continue + value = int(piece) + if value <= 0: + raise ValueError(f"Values must be positive, got {value}") + values.append(value) + if not values: + raise ValueError("At least one positive integer must be provided.") + return values + + +def prompts_for_batch(batch_size: int, single_prompt: str | None = None) -> list[str]: + if batch_size == 1 and single_prompt: + return [single_prompt] + if batch_size > len(DEFAULT_BATCH_PROMPTS): + raise ValueError( + f"Batch size {batch_size} requested, but only {len(DEFAULT_BATCH_PROMPTS)} distinct default prompts are available." + ) + return DEFAULT_BATCH_PROMPTS[:batch_size] + + +def results_json_path(artifact_dir: Path, batch_sizes: list[int]) -> Path: + if batch_sizes == [1]: + return artifact_dir / "attention_benchmark_results.json" + return artifact_dir / "attention_benchmark_batch_results.json" + + +def stream_results_json_path(artifact_dir: Path) -> Path: + return artifact_dir / "attention_benchmark_stream_results.json" + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Quantize Llama-3.2-1B-Instruct and benchmark attention modes.") + parser.add_argument("--gpu", type=int, default=0, help="Physical GPU index to pin via CUDA_VISIBLE_DEVICES.") + parser.add_argument("--artifact-dir", type=Path, default=None, help="Directory to store quantized artifacts.") + parser.add_argument( + "--scenario", + choices=["closed_batch", "serve_stream"], + default="closed_batch", + help="Benchmark scenario to run.", + ) + parser.add_argument( + "--reuse-artifact", + action="store_true", + help="Reuse an existing artifact directory instead of requiring a new quantization output directory.", + ) + parser.add_argument( + "--batch-sizes", + type=str, + default="1", + help="Comma-separated batch sizes to benchmark, for example `1,2,4`.", + ) + parser.add_argument("--prompt", type=str, default=DEFAULT_PROMPT, help="Prompt used for the 512-token benchmark.") + parser.add_argument("--warmup-tokens", type=int, default=32, help="Warmup generation length before measuring.") + parser.add_argument("--max-new-tokens", type=int, default=512, help="Measured generation length.") + parser.add_argument( + "--stream-request-count", + type=int, + default=12, + help="Number of requests in serve_stream mode.", + ) + parser.add_argument( + "--stream-arrival-ms", + type=float, + default=75.0, + help="Inter-arrival gap in milliseconds for serve_stream mode.", + ) + parser.add_argument( + "--stream-prompt-targets", + type=str, + default=prompt_targets_str(DEFAULT_STREAM_PROMPT_TARGETS), + help="Comma-separated prompt token targets for serve_stream mode.", + ) + parser.add_argument( + "--stream-shared-prefix-tokens", + type=int, + default=192, + help="Approximate shared-prefix token count for serve_stream mode.", + ) + parser.add_argument( + "--stream-scheduler", + type=str, + default="prefill_first", + help="Continuous batching scheduler for paged serve_stream mode.", + ) + parser.add_argument( + "--stream-use-async-batching", + action="store_true", + help="Enable async batching for paged serve_stream mode.", + ) + parser.add_argument("--internal-paged-result-path", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--internal-stream-result-path", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--internal-stream-workload-path", type=Path, default=None, help=argparse.SUPPRESS) + parser.add_argument("--internal-prompts-json", type=str, default=None, help=argparse.SUPPRESS) + parser.add_argument("--internal-mode-name", type=str, default=None, help=argparse.SUPPRESS) + parser.add_argument("--internal-attn-implementation", type=str, default=None, help=argparse.SUPPRESS) + parser.add_argument("--internal-warmup-tokens", type=int, default=None, help=argparse.SUPPRESS) + parser.add_argument("--internal-max-new-tokens", type=int, default=None, help=argparse.SUPPRESS) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + if args.internal_paged_result_path is not None: + prompts = json.loads(args.internal_prompts_json) + result = benchmark_paged_mode( + args.artifact_dir.resolve(), + prompts, + mode_name=args.internal_mode_name, + attn_implementation=args.internal_attn_implementation, + warmup_tokens=args.internal_warmup_tokens, + max_new_tokens=args.internal_max_new_tokens, + ) + args.internal_paged_result_path.write_text(json.dumps(asdict(result)), encoding="utf-8") + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + if args.internal_stream_result_path is not None: + payload = json.loads(args.internal_stream_workload_path.read_text(encoding="utf-8")) + workload = [StreamRequest(**item) for item in payload["requests"]] + result = benchmark_stream_paged_mode( + args.artifact_dir.resolve(), + workload, + mode_name=args.internal_mode_name, + attn_implementation=args.internal_attn_implementation, + warmup_tokens=args.internal_warmup_tokens, + max_new_tokens=args.internal_max_new_tokens, + arrival_gap_ms=float(payload["arrival_gap_ms"]), + prompt_targets=[int(value) for value in payload["prompt_targets"]], + scheduler_name=str(payload["scheduler_name"]), + use_async_batching=bool(payload["use_async_batching"]), + ) + args.internal_stream_result_path.write_text(json.dumps(asdict(result)), encoding="utf-8") + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + batch_sizes = parse_batch_sizes(args.batch_sizes) + stream_prompt_targets = parse_int_list(args.stream_prompt_targets) + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark.") + + artifact_dir = args.artifact_dir or make_default_artifact_dir() + artifact_dir = artifact_dir.resolve() + if artifact_dir.exists(): + if not args.reuse_artifact: + raise FileExistsError( + f"Artifact directory already exists: {artifact_dir}. Pass --reuse-artifact to skip requantization." + ) + else: + ensure_empty_or_new_dir(artifact_dir) + + physical_gpu_index = args.gpu + logical_gpu_index = 0 + gpu_name = torch.cuda.get_device_name(logical_gpu_index) + + print(f"Using physical GPU {physical_gpu_index} as logical cuda:{logical_gpu_index}: {gpu_name}") + print(f"Quantized artifact directory: {artifact_dir}") + print(f"Scenario: {args.scenario}") + if args.scenario == "closed_batch": + print(f"Batch sizes: {batch_sizes}") + else: + print( + "Serve stream config: " + f"requests={args.stream_request_count}, arrival_ms={args.stream_arrival_ms}, " + f"prompt_targets={stream_prompt_targets}, shared_prefix_tokens={args.stream_shared_prefix_tokens}, " + f"scheduler={args.stream_scheduler}, async_batching={args.stream_use_async_batching}" + ) + + if args.reuse_artifact: + print("Reusing existing quantized artifact directory.") + else: + quantize_once(artifact_dir) + + modes = [ + { + "mode_name": "sdpa", + "attn_implementation": "sdpa", + "use_paged_cache": False, + }, + { + "mode_name": "flash_attention_2", + "attn_implementation": "flash_attention_2", + "use_paged_cache": False, + }, + { + "mode_name": "paged(sdpa)", + "attn_implementation": "sdpa", + "use_paged_cache": True, + }, + ] + if args.scenario == "serve_stream": + modes.append( + { + "mode_name": "paged(fa2)", + "attn_implementation": "flash_attention_2", + "use_paged_cache": True, + } + ) + + if args.scenario == "closed_batch": + results: list[BenchmarkResult] = [] + prompts_by_batch = {} + for batch_size in batch_sizes: + prompts = prompts_for_batch(batch_size, single_prompt=args.prompt if batch_size == 1 and batch_sizes == [1] else None) + prompts_by_batch[str(batch_size)] = prompts + for mode in modes: + print(f"\nBenchmarking {mode['mode_name']} with batch={batch_size}...") + result = benchmark_mode( + artifact_dir, + prompts, + mode_name=mode["mode_name"], + attn_implementation=mode["attn_implementation"], + warmup_tokens=args.warmup_tokens, + max_new_tokens=args.max_new_tokens, + use_paged_cache=mode["use_paged_cache"], + gpu=physical_gpu_index, + ) + results.append(result) + + metadata = { + "created_at_utc": datetime.now(timezone.utc).isoformat(), + "artifact_dir": str(artifact_dir), + "physical_gpu_index": physical_gpu_index, + "logical_gpu_index": logical_gpu_index, + "gpu_name": gpu_name, + "model_id": TestLlama3_2.NATIVE_MODEL_ID, + "scenario": args.scenario, + "batch_sizes": batch_sizes, + "prompts_by_batch": prompts_by_batch, + "warmup_tokens": args.warmup_tokens, + "max_new_tokens": args.max_new_tokens, + "results": [asdict(item) for item in results], + "ascii_table": render_ascii_table(results), + } + + json_path = results_json_path(artifact_dir, batch_sizes) + json_path.write_text(json.dumps(metadata, indent=2), encoding="utf-8") + + print("\n" + metadata["ascii_table"]) + print(f"\nResults JSON: {json_path}") + return 0 + + workload = build_stream_workload( + artifact_dir, + request_count=args.stream_request_count, + arrival_gap_ms=args.stream_arrival_ms, + prompt_targets=stream_prompt_targets, + shared_prefix_tokens=args.stream_shared_prefix_tokens, + ) + stream_results: list[StreamBenchmarkResult] = [] + for mode in modes: + print(f"\nBenchmarking {mode['mode_name']} with serve_stream...") + result = benchmark_stream_mode( + artifact_dir, + workload, + mode_name=mode["mode_name"], + attn_implementation=mode["attn_implementation"], + warmup_tokens=args.warmup_tokens, + max_new_tokens=args.max_new_tokens, + use_paged_cache=mode["use_paged_cache"], + arrival_gap_ms=args.stream_arrival_ms, + prompt_targets=stream_prompt_targets, + scheduler_name=args.stream_scheduler, + use_async_batching=args.stream_use_async_batching, + gpu=physical_gpu_index, + ) + stream_results.append(result) + + stream_metadata = { + "created_at_utc": datetime.now(timezone.utc).isoformat(), + "artifact_dir": str(artifact_dir), + "physical_gpu_index": physical_gpu_index, + "logical_gpu_index": logical_gpu_index, + "gpu_name": gpu_name, + "model_id": TestLlama3_2.NATIVE_MODEL_ID, + "scenario": args.scenario, + "stream_request_count": args.stream_request_count, + "stream_arrival_ms": args.stream_arrival_ms, + "stream_prompt_targets": stream_prompt_targets, + "stream_shared_prefix_tokens": args.stream_shared_prefix_tokens, + "stream_scheduler": args.stream_scheduler, + "stream_use_async_batching": args.stream_use_async_batching, + "warmup_tokens": args.warmup_tokens, + "max_new_tokens": args.max_new_tokens, + "workload": [asdict(item) for item in workload], + "results": [asdict(item) for item in stream_results], + "ascii_table": render_stream_ascii_table(stream_results), + } + + stream_json_path = stream_results_json_path(artifact_dir) + stream_json_path.write_text(json.dumps(stream_metadata, indent=2), encoding="utf-8") + + print("\n" + stream_metadata["ascii_table"]) + print(f"\nResults JSON: {stream_json_path}") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_llama_cpp_vs_gptqmodel_gguf.py b/scripts/benchmark_llama_cpp_vs_gptqmodel_gguf.py new file mode 100644 index 000000000..726666298 --- /dev/null +++ b/scripts/benchmark_llama_cpp_vs_gptqmodel_gguf.py @@ -0,0 +1,425 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import os +import site +import subprocess +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +import torch +from llama_cpp import Llama +from llama_cpp import llama_cpp as llama_low +from transformers import AutoTokenizer + +from gptqmodel import BACKEND, GGUFConfig, GPTQModel + + +os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") +os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:256") + + +DEFAULT_MODEL = "/monster/data/model/Llama-3.2-1B-Instruct" + + +@dataclass(frozen=True) +class TrialSummary: + framework: str + device: str + phase: str + token_count: int + samples_ms: list[float] + + @property + def mean_ms(self) -> float: + return sum(self.samples_ms) / len(self.samples_ms) + + @property + def min_ms(self) -> float: + return min(self.samples_ms) + + @property + def max_ms(self) -> float: + return max(self.samples_ms) + + @property + def toks_per_s(self) -> float: + return self.token_count / (self.mean_ms / 1000.0) + + +def _ascii_table(headers: list[str], rows: list[list[str]]) -> str: + widths = [len(h) for h in headers] + for row in rows: + for i, cell in enumerate(row): + widths[i] = max(widths[i], len(cell)) + + def fmt(row: list[str]) -> str: + return "| " + " | ".join(cell.ljust(widths[i]) for i, cell in enumerate(row)) + " |" + + sep = "+-" + "-+-".join("-" * width for width in widths) + "-+" + out = [sep, fmt(headers), sep] + for row in rows: + out.append(fmt(row)) + out.append(sep) + return "\n".join(out) + + +def _run(cmd: list[str]) -> None: + print(f"$ {' '.join(cmd)}") + subprocess.run(cmd, check=True) + + +def _sync_cuda(device: str) -> None: + if device == "cuda": + torch.cuda.synchronize() + + +def _bench(fn: Callable[[], None], *, device: str, warmup: int, trials: int) -> list[float]: + samples_ms: list[float] = [] + for _ in range(warmup): + fn() + _sync_cuda(device) + + for _ in range(trials): + _sync_cuda(device) + t0 = time.perf_counter() + fn() + _sync_cuda(device) + samples_ms.append((time.perf_counter() - t0) * 1000.0) + return samples_ms + + +def _find_convert_script() -> Path: + for root in (Path(p) for p in site.getsitepackages()): + candidate = root / "bin" / "convert_hf_to_gguf.py" + if candidate.exists(): + return candidate + raise FileNotFoundError("Could not locate convert_hf_to_gguf.py in site-packages.") + + +def _prepare_llama_cpp_monolithic(source_model: Path, f16_path: Path, q4_path: Path, threads: int) -> None: + f16_path.parent.mkdir(parents=True, exist_ok=True) + if not f16_path.exists(): + converter = _find_convert_script() + _run( + [ + "python", + str(converter), + str(source_model), + "--outfile", + str(f16_path), + "--outtype", + "f16", + ] + ) + + if not q4_path.exists(): + params = llama_low.llama_model_quantize_default_params() + params.ftype = llama_low.LLAMA_FTYPE_MOSTLY_Q4_K_M + params.nthread = threads + rc = llama_low.llama_model_quantize( + str(f16_path).encode("utf-8"), + str(q4_path).encode("utf-8"), + params, + ) + if rc != 0: + raise RuntimeError(f"llama_model_quantize failed with status code {rc}.") + + +def _prepare_gptqmodel_quantized(source_model: Path, output_dir: Path, offload_dir: Path) -> None: + if (output_dir / "quantize_config.json").exists(): + return + + output_dir.mkdir(parents=True, exist_ok=True) + offload_dir.mkdir(parents=True, exist_ok=True) + + tokenizer = AutoTokenizer.from_pretrained(str(source_model), use_fast=True) + qconfig = GGUFConfig( + bits=4, + format="q_k_m", + smoother=None, + offload_to_disk=True, + offload_to_disk_path=str(offload_dir), + ) + + model = GPTQModel.from_pretrained( + model_id_or_path=str(source_model), + quantize_config=qconfig, + trust_remote_code=False, + ) + model.quantize( + calibration=None, + tokenizer=tokenizer, + backend=BACKEND.GGUF_TORCH, + ) + model.save(str(output_dir)) + tokenizer.save_pretrained(str(output_dir)) + + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _build_prompt(tokenizer: AutoTokenizer, target_tokens: int) -> tuple[str, int]: + sentence = ( + "Summarize the scientific, historical, and economic significance of the Atlantic Ocean " + "for intercontinental trade, climate, and biodiversity. " + ) + prompt = sentence + token_count = len(tokenizer(prompt, return_tensors="pt").input_ids[0]) + while token_count < target_tokens: + prompt += sentence + token_count = len(tokenizer(prompt, return_tensors="pt").input_ids[0]) + return prompt, token_count + + +def _load_gptqmodel(model_dir: Path, *, device: str): + dtype = torch.float16 if device == "cuda" else torch.float32 + model = GPTQModel.from_quantized( + model_id_or_path=str(model_dir), + backend=BACKEND.GGUF_TORCH, + device="cuda:0" if device == "cuda" else "cpu", + dtype=dtype, + trust_remote_code=False, + ) + tokenizer = AutoTokenizer.from_pretrained(str(model_dir), use_fast=True) + return model, tokenizer + + +def _load_llama_cpp(model_path: Path, *, device: str, n_ctx: int, n_batch: int, threads: int) -> Llama: + kwargs = dict( + model_path=str(model_path), + n_ctx=n_ctx, + n_batch=n_batch, + n_ubatch=n_batch, + n_threads=threads, + n_threads_batch=threads, + verbose=False, + no_perf=True, + use_mmap=True, + ) + if device == "cuda": + kwargs.update( + { + "n_gpu_layers": -1, + "main_gpu": 0, + } + ) + else: + kwargs.update({"n_gpu_layers": 0}) + return Llama(**kwargs) + + +def _gptq_prefill(model, tokenizer, prompt: str, device: str) -> tuple[list[float], int]: + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + token_count = inputs["input_ids"].shape[1] + + def run_once() -> None: + with torch.inference_mode(): + model.model(**inputs, use_cache=True) + + return run_once, token_count + + +def _gptq_decode(model, tokenizer, prompt: str, decode_tokens: int, device: str) -> tuple[Callable[[], None], int]: + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + def run_once() -> None: + with torch.inference_mode(): + out = model.model(**inputs, use_cache=True) + past_key_values = out.past_key_values + next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True) + for _ in range(decode_tokens): + out = model.model( + input_ids=next_token, + past_key_values=past_key_values, + use_cache=True, + ) + past_key_values = out.past_key_values + next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True) + + return run_once, decode_tokens + + +def _llama_prefill(llm: Llama, prompt: str) -> tuple[Callable[[], None], int]: + tokens = llm.tokenize(prompt.encode("utf-8"), add_bos=True, special=False) + + def run_once() -> None: + llm.reset() + llm.eval(tokens) + + return run_once, len(tokens) + + +def _llama_decode(llm: Llama, prompt: str, decode_tokens: int) -> tuple[Callable[[], None], int]: + tokens = llm.tokenize(prompt.encode("utf-8"), add_bos=True, special=False) + + def run_once() -> None: + llm.reset() + llm.eval(tokens) + for _ in range(decode_tokens): + token = llm.sample(temp=0.0, top_k=1, top_p=1.0, min_p=0.0) + llm.eval([token]) + + return run_once, decode_tokens + + +def _summarize_trials(framework: str, device: str, phase: str, token_count: int, samples_ms: list[float]) -> TrialSummary: + return TrialSummary( + framework=framework, + device=device, + phase=phase, + token_count=token_count, + samples_ms=samples_ms, + ) + + +def _print_trial_table(results: list[TrialSummary]) -> None: + trial_count = max(len(r.samples_ms) for r in results) + headers = ["framework", "phase", "tokens"] + [f"trial_{i}_ms" for i in range(1, trial_count + 1)] + rows: list[list[str]] = [] + for result in results: + row = [result.framework, result.phase, str(result.token_count)] + row.extend(f"{sample:.2f}" for sample in result.samples_ms) + if len(result.samples_ms) < trial_count: + row.extend("-" for _ in range(trial_count - len(result.samples_ms))) + rows.append(row) + print(_ascii_table(headers, rows)) + + +def _print_summary_table(results: list[TrialSummary]) -> None: + headers = ["framework", "phase", "mean_ms", "min_ms", "max_ms", "tok_per_s"] + rows = [ + [ + result.framework, + result.phase, + f"{result.mean_ms:.2f}", + f"{result.min_ms:.2f}", + f"{result.max_ms:.2f}", + f"{result.toks_per_s:.2f}", + ] + for result in results + ] + print(_ascii_table(headers, rows)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark llama-cpp-python monolithic GGUF vs gptqmodel GGUF on the same model." + ) + parser.add_argument("--model", default=DEFAULT_MODEL, help="HF model directory to convert/quantize.") + parser.add_argument("--work-dir", default="/tmp/llama_cpp_vs_gptqmodel_gguf", help="Artifact cache directory.") + parser.add_argument("--prompt-tokens", type=int, default=512, help="Approximate prompt token length.") + parser.add_argument("--decode-tokens", type=int, default=64, help="Number of autoregressive decode steps.") + parser.add_argument("--warmup", type=int, default=1, help="Warmup iterations per benchmark.") + parser.add_argument("--trials", type=int, default=3, help="Measured trials per benchmark.") + parser.add_argument( + "--device", + choices=("cpu", "cuda", "both"), + default="both", + help="Run benchmarks on CPU, CUDA, or both.", + ) + parser.add_argument("--threads", type=int, default=min(os.cpu_count() or 1, 16), help="CPU threads for llama.cpp.") + parser.add_argument("--skip-prepare", action="store_true", help="Assume benchmark artifacts already exist.") + return parser.parse_args() + + +def main() -> None: + args = parse_args() + source_model = Path(args.model) + work_dir = Path(args.work_dir) + gptq_dir = work_dir / "gptqmodel_q4_k_m" + offload_dir = work_dir / "gptqmodel_offload" + llama_f16_path = work_dir / "llama3_2_1b_f16.gguf" + llama_q4_path = work_dir / "llama3_2_1b_q4_k_m.gguf" + + if not args.skip_prepare: + _prepare_llama_cpp_monolithic(source_model, llama_f16_path, llama_q4_path, args.threads) + _prepare_gptqmodel_quantized(source_model, gptq_dir, offload_dir) + + tokenizer = AutoTokenizer.from_pretrained(str(source_model), use_fast=True) + prompt, hf_token_count = _build_prompt(tokenizer, args.prompt_tokens) + n_ctx = hf_token_count + args.decode_tokens + 64 + n_batch = max(hf_token_count + 8, 512) + + devices = ["cpu", "cuda"] if args.device == "both" else [args.device] + if "cuda" in devices and not torch.cuda.is_available(): + raise RuntimeError("CUDA benchmarking requested but no CUDA device is available.") + + print(f"source_model={source_model}") + print(f"gptqmodel_dir={gptq_dir}") + print(f"llama_cpp_gguf={llama_q4_path}") + print(f"prompt_tokens_hf={hf_token_count} decode_tokens={args.decode_tokens} warmup={args.warmup} trials={args.trials}") + + for device in devices: + print() + print(f"DEVICE {device}") + + gptq_model, gptq_tokenizer = _load_gptqmodel(gptq_dir, device=device) + llama_model = _load_llama_cpp( + llama_q4_path, + device=device, + n_ctx=n_ctx, + n_batch=n_batch, + threads=args.threads, + ) + + device_results: list[TrialSummary] = [] + + gptq_prefill_fn, gptq_prefill_tokens = _gptq_prefill(gptq_model, gptq_tokenizer, prompt, device) + gptq_decode_fn, gptq_decode_tokens = _gptq_decode(gptq_model, gptq_tokenizer, prompt, args.decode_tokens, device) + llama_prefill_fn, llama_prefill_tokens = _llama_prefill(llama_model, prompt) + llama_decode_fn, llama_decode_tokens = _llama_decode(llama_model, prompt, args.decode_tokens) + + device_results.append( + _summarize_trials( + "gptqmodel", + device, + "prefill", + gptq_prefill_tokens, + _bench(gptq_prefill_fn, device=device, warmup=args.warmup, trials=args.trials), + ) + ) + device_results.append( + _summarize_trials( + "gptqmodel", + device, + "decode", + gptq_decode_tokens, + _bench(gptq_decode_fn, device=device, warmup=args.warmup, trials=args.trials), + ) + ) + device_results.append( + _summarize_trials( + "llama-cpp-python", + device, + "prefill", + llama_prefill_tokens, + _bench(llama_prefill_fn, device="cpu", warmup=args.warmup, trials=args.trials), + ) + ) + device_results.append( + _summarize_trials( + "llama-cpp-python", + device, + "decode", + llama_decode_tokens, + _bench(llama_decode_fn, device="cpu", warmup=args.warmup, trials=args.trials), + ) + ) + + _print_trial_table(device_results) + _print_summary_table(device_results) + + del llama_model + del gptq_model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_marlin_a100.py b/scripts/benchmark_marlin_a100.py new file mode 100644 index 000000000..db772fcb4 --- /dev/null +++ b/scripts/benchmark_marlin_a100.py @@ -0,0 +1,284 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import os +import re +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import torch +from tabulate import tabulate + +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear + + +os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + m: int + in_features: int + out_features: int + group_size: int = 128 + desc_act: bool = False + + +def _build_default_cases() -> list[BenchCase]: + cases: list[BenchCase] = [] + for m in (64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192): + cases.append(BenchCase(case_id=f"mlp_up_m{m}", m=m, in_features=4096, out_features=11008)) + for m in (64, 80, 96, 112, 128, 160, 192): + cases.append(BenchCase(case_id=f"mlp_down_m{m}", m=m, in_features=11008, out_features=4096)) + for m in (64, 96, 128, 192): + cases.append(BenchCase(case_id=f"attn_m{m}", m=m, in_features=4096, out_features=4096)) + return cases + + +DEFAULT_CASES = _build_default_cases() + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark synthetic Marlin GEMMs on A100-class PCI-ordered GPUs." + ) + parser.add_argument("--device", default="cuda:0", help="Torch device string to benchmark on.") + parser.add_argument("--dtype", default="fp16", choices=("fp16", "bf16"), help="Marlin compute dtype.") + parser.add_argument("--warmup", type=int, default=30, help="Warmup iterations per case.") + parser.add_argument("--iters", type=int, default=80, help="Measured iterations per case.") + parser.add_argument("--seed", type=int, default=1234, help="Base RNG seed.") + parser.add_argument("--shard-index", type=int, default=0, help="Zero-based shard index.") + parser.add_argument("--num-shards", type=int, default=1, help="Total number of case shards.") + parser.add_argument( + "--desc-act", + default="off", + choices=("off", "on", "both"), + help="Whether to benchmark non-act-order kernels, act-order kernels, or both.", + ) + parser.add_argument( + "--case-pattern", + default=None, + help="Optional regex filter applied to benchmark case ids.", + ) + parser.add_argument("--json-out", type=Path, default=None, help="Optional JSON output path.") + return parser.parse_args() + + +def _resolve_dtype(name: str) -> torch.dtype: + if name == "fp16": + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype: {name}") + + +def _subset_cases(cases: list[BenchCase], shard_index: int, num_shards: int) -> list[BenchCase]: + if num_shards <= 0: + raise ValueError("`num_shards` must be positive.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError(f"`shard_index` must be in [0, {num_shards - 1}].") + return [case for index, case in enumerate(cases) if index % num_shards == shard_index] + + +def _expand_desc_act_cases(cases: list[BenchCase], mode: str) -> list[BenchCase]: + expanded: list[BenchCase] = [] + desc_modes: tuple[bool, ...] + if mode == "off": + desc_modes = (False,) + elif mode == "on": + desc_modes = (True,) + elif mode == "both": + desc_modes = (False, True) + else: + raise ValueError(f"Unsupported desc_act mode: {mode}") + + for case in cases: + for desc_act in desc_modes: + case_id = case.case_id if not desc_act else f"{case.case_id}_act" + expanded.append( + BenchCase( + case_id=case_id, + m=case.m, + in_features=case.in_features, + out_features=case.out_features, + group_size=case.group_size, + desc_act=desc_act, + ) + ) + return expanded + + +def _filter_cases(cases: list[BenchCase], pattern: str | None) -> list[BenchCase]: + if pattern is None: + return cases + regex = re.compile(pattern) + return [case for case in cases if regex.search(case.case_id)] + + +def _build_module( + *, + device: torch.device, + dtype: torch.dtype, + seed: int, + in_features: int, + out_features: int, + group_size: int, + desc_act: bool, +) -> MarlinLinear: + generator = torch.Generator(device=device) + generator.manual_seed(seed) + module = MarlinLinear( + bits=4, + group_size=group_size, + desc_act=desc_act, + sym=True, + in_features=in_features, + out_features=out_features, + bias=False, + dtype=dtype, + ).to(device) + with torch.no_grad(): + module.qweight.copy_( + torch.randint( + -(2**31), + 2**31 - 1, + module.qweight.shape, + dtype=torch.int32, + device=device, + generator=generator, + ) + ) + module.scales.copy_( + torch.rand( + module.scales.shape, + dtype=dtype, + device=device, + generator=generator, + ) * 0.5 + 0.5 + ) + module.qzeros.zero_() + module.g_idx.zero_() + module.eval() + module.post_init() + return module + + +def _benchmark_case( + *, + module: MarlinLinear, + case: BenchCase, + dtype: torch.dtype, + device: torch.device, + seed: int, + warmup: int, + iters: int, +) -> dict[str, Any]: + generator = torch.Generator(device=device) + generator.manual_seed(seed) + x = torch.rand( + (case.m, case.in_features), + dtype=dtype, + device=device, + generator=generator, + ) + with torch.inference_mode(): + for _ in range(warmup): + module(x) + torch.cuda.synchronize(device) + start = time.perf_counter() + for _ in range(iters): + y = module(x) + torch.cuda.synchronize(device) + elapsed_s = time.perf_counter() - start + mean_ms = elapsed_s * 1e3 / iters + tflops = (2.0 * case.m * case.in_features * case.out_features) / (mean_ms * 1e9) + return { + "case_id": case.case_id, + "m": case.m, + "in_features": case.in_features, + "out_features": case.out_features, + "shape": list(y.shape), + "mean_ms": mean_ms, + "tflops": tflops, + } + + +def main() -> None: + args = _parse_args() + dtype = _resolve_dtype(args.dtype) + device = torch.device(args.device) + if device.type != "cuda": + raise ValueError("This benchmark only supports CUDA devices.") + + visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + cases = _expand_desc_act_cases(DEFAULT_CASES, args.desc_act) + cases = _filter_cases(cases, args.case_pattern) + cases = _subset_cases(cases, shard_index=args.shard_index, num_shards=args.num_shards) + if not cases: + raise ValueError("Shard selection produced no benchmark cases.") + + module_cache: dict[tuple[int, int, int, bool], MarlinLinear] = {} + rows: list[dict[str, Any]] = [] + for index, case in enumerate(cases): + cache_key = (case.in_features, case.out_features, case.group_size, case.desc_act) + module = module_cache.get(cache_key) + if module is None: + module = _build_module( + device=device, + dtype=dtype, + seed=args.seed + index, + in_features=case.in_features, + out_features=case.out_features, + group_size=case.group_size, + desc_act=case.desc_act, + ) + module_cache[cache_key] = module + rows.append( + _benchmark_case( + module=module, + case=case, + dtype=dtype, + device=device, + seed=args.seed + 1000 + index, + warmup=args.warmup, + iters=args.iters, + ) + ) + + table = [ + [row["case_id"], row["m"], row["in_features"], row["out_features"], f'{row["mean_ms"]:.6f}', f'{row["tflops"]:.2f}'] + for row in rows + ] + print( + tabulate( + table, + headers=("case", "m", "k", "n", "mean_ms", "tflops"), + tablefmt="github", + ) + ) + + payload = { + "dtype": args.dtype, + "device": args.device, + "cuda_visible_devices": visible, + "shard_index": args.shard_index, + "num_shards": args.num_shards, + "cases": [asdict(case) for case in cases], + "results": rows, + } + if args.json_out is not None: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + args.json_out.write_text(json.dumps(payload, indent=2), encoding="utf-8") + print(f"\njson_out={args.json_out}") + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_paroquant_official_vs_local.py b/scripts/benchmark_paroquant_official_vs_local.py new file mode 100644 index 000000000..323fa70bf --- /dev/null +++ b/scripts/benchmark_paroquant_official_vs_local.py @@ -0,0 +1,757 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import statistics +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +from safetensors.torch import load_file, save_file +from tabulate import tabulate +from torch import nn +from transformers import AutoModelForCausalLM, AutoTokenizer + +from gptqmodel.quantization.paroquant import optimization as local_opt +from gptqmodel.utils.paroquant import apply_paroquant_rotation, build_identity_rotation_buffers +from gptqmodel.utils.paroquant_benchmark import _normalize_model_dtype, load_nm_calibration + + +DEFAULT_MODEL = "/monster/data/model/Llama-3.2-1B-Instruct" +DEFAULT_OFFICIAL_REPO = "/root/official_paroquant" +DEFAULT_ASSET_DIR = "benchmark_assets" +DEFAULT_CAPTURE_ROWS = 4096 +DEFAULT_CALIBRATION_SAMPLES = 512 +DEFAULT_MODULES = ( + "model.layers.0.mlp.gate_proj", + "model.layers.0.mlp.up_proj", + "model.layers.0.mlp.down_proj", +) + + +@dataclass(frozen=True) +class LocalCase: + label: str + stage_impl: str + pair_impl: str + quantizer_impl: str + + +@dataclass +class BenchResult: + module: str + impl: str + pair_s: float + setup_s: float + stage1_s: float + stage2_s: float + export_s: float + opt_s: float + total_s: float + train_loss: float + val_loss: float + repeat: int + + +LOCAL_CASES: tuple[LocalCase, ...] = ( + LocalCase("local_rrr", "reference", "reference", "reference"), + LocalCase("local_rfr", "reference", "fast", "reference"), + LocalCase("local_ffr", "fast", "fast", "reference"), + LocalCase("local_fff", "fast", "fast", "fast"), +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark local ParoQuant implementations against the official PR reference on saved activations." + ) + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument("--official-repo", default=DEFAULT_OFFICIAL_REPO) + parser.add_argument("--asset-dir", default=DEFAULT_ASSET_DIR) + parser.add_argument("--asset-name", default="llama32_1b_layer0_mlp_gate_up_down") + parser.add_argument("--module", dest="modules", action="append", default=None) + parser.add_argument("--model-dtype", default="fp16") + parser.add_argument("--calibration-samples", type=int, default=DEFAULT_CALIBRATION_SAMPLES) + parser.add_argument("--capture-rows", type=int, default=DEFAULT_CAPTURE_ROWS) + parser.add_argument("--bits", type=int, default=4) + parser.add_argument("--group-size", type=int, default=128) + parser.add_argument("--krot", type=int, default=8) + parser.add_argument("--pair-ratio", type=float, default=0.5) + parser.add_argument("--train-rows", type=int, default=1024) + parser.add_argument("--val-rows", type=int, default=256) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--rotation-epochs", type=int, default=10) + parser.add_argument("--finetune-epochs", type=int, default=10) + parser.add_argument("--rotation-lr", type=float, default=0.05) + parser.add_argument("--weight-lr", type=float, default=1e-5) + parser.add_argument("--quantizer-lr", type=float, default=1e-6) + parser.add_argument("--repeats", type=int, default=3) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--force-recapture", action="store_true") + parser.add_argument("--capture-only", action="store_true") + parser.add_argument("--module-filter", default=None, help="Substring filter applied after loading the asset.") + return parser.parse_args() + + +def sync_cuda(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def cuda_timer(device: torch.device): + class _Timer: + def __enter__(self): + sync_cuda(device) + self.start = time.perf_counter() + return self + + def __exit__(self, exc_type, exc, tb): + sync_cuda(device) + self.elapsed = time.perf_counter() - self.start + return False + + return _Timer() + + +def _get_named_module(model, module_name: str): + module_map = dict(model.named_modules()) + if module_name not in module_map: + raise KeyError(f"Module `{module_name}` not found.") + return module_map[module_name] + + +def _tokenize_calibration_sample(tokenizer, sample: dict[str, Any]) -> dict[str, torch.Tensor]: + if "input_ids" in sample: + input_ids = torch.as_tensor(sample["input_ids"], dtype=torch.long) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + attention_mask = sample.get("attention_mask") + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + else: + attention_mask = torch.as_tensor(attention_mask, dtype=torch.long) + if attention_mask.ndim == 1: + attention_mask = attention_mask.unsqueeze(0) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + if "messages" in sample: + rendered = tokenizer.apply_chat_template( + sample["messages"], + tokenize=False, + add_generation_prompt=False, + ) + tokenized = tokenizer(rendered, add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + if "text" in sample: + tokenized = tokenizer(sample["text"], add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + raise ValueError(f"Unsupported calibration sample keys: {sorted(sample.keys())}") + + +def _capture_module_inputs( + model, + tokenizer, + module_names: list[str], + calibration_dataset: list[dict[str, Any]], + *, + max_rows: int, +) -> dict[str, torch.Tensor]: + captured: dict[str, list[torch.Tensor]] = {name: [] for name in module_names} + captured_rows = {name: 0 for name in module_names} + hooks = [] + + def make_hook(name: str): + def hook(_module, inputs): + if not inputs or captured_rows[name] >= max_rows: + return + x = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).cpu() + remaining = max_rows - captured_rows[name] + if remaining <= 0: + return + piece = x[:remaining].contiguous() + if piece.numel() == 0: + return + captured[name].append(piece) + captured_rows[name] += piece.shape[0] + + return hook + + for name in module_names: + module = _get_named_module(model, name) + hooks.append(module.register_forward_pre_hook(make_hook(name))) + + model_device = next(model.parameters()).device + try: + for sample in calibration_dataset: + if all(count >= max_rows for count in captured_rows.values()): + break + tokenized = _tokenize_calibration_sample(tokenizer, sample) + with torch.inference_mode(): + model( + input_ids=tokenized["input_ids"].to(device=model_device), + attention_mask=tokenized["attention_mask"].to(device=model_device), + ) + finally: + for hook in hooks: + hook.remove() + + flattened: dict[str, torch.Tensor] = {} + for name in module_names: + pieces = captured[name] + if not pieces: + raise RuntimeError(f"Failed to capture calibration activations for `{name}`.") + flattened[name] = torch.cat(pieces, dim=0)[:max_rows].contiguous() + return flattened + + +def _asset_paths(asset_dir: Path, asset_name: str) -> tuple[Path, Path]: + return asset_dir / f"{asset_name}.safetensors", asset_dir / f"{asset_name}.json" + + +def _module_key(module_name: str) -> str: + return module_name.replace(".", "__") + + +def _capture_asset(args: argparse.Namespace, asset_path: Path, metadata_path: Path) -> None: + modules = args.modules or list(DEFAULT_MODULES) + asset_path.parent.mkdir(parents=True, exist_ok=True) + + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=False) + if getattr(tokenizer, "padding_side", None) != "left": + tokenizer.padding_side = "left" + if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = AutoModelForCausalLM.from_pretrained( + args.model, + trust_remote_code=False, + torch_dtype=_normalize_model_dtype(args.model_dtype), + low_cpu_mem_usage=True, + ) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + + calibration_dataset = load_nm_calibration(args.calibration_samples) + captured = _capture_module_inputs( + model, + tokenizer, + modules, + calibration_dataset, + max_rows=args.capture_rows, + ) + + tensors: dict[str, torch.Tensor] = {} + metadata: dict[str, Any] = { + "model": args.model, + "model_dtype": str(args.model_dtype), + "calibration_samples": int(args.calibration_samples), + "capture_rows": int(args.capture_rows), + "modules": [], + } + for module_name in modules: + module = _get_named_module(model, module_name) + module_id = _module_key(module_name) + weight = module.weight.detach().to(dtype=torch.float32, device="cpu").contiguous() + tensors[f"{module_id}.weight"] = weight + if getattr(module, "bias", None) is not None: + tensors[f"{module_id}.bias"] = module.bias.detach().to(dtype=torch.float32, device="cpu").contiguous() + tensors[f"{module_id}.inputs"] = captured[module_name].to(dtype=torch.float32, device="cpu").contiguous() + metadata["modules"].append( + { + "name": module_name, + "key": module_id, + "rows": int(captured[module_name].shape[0]), + "in_features": int(weight.shape[1]), + "out_features": int(weight.shape[0]), + "has_bias": getattr(module, "bias", None) is not None, + } + ) + + save_file(tensors, str(asset_path)) + metadata_path.write_text(json.dumps(metadata, indent=2)) + + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _load_asset(asset_path: Path, metadata_path: Path, module_filter: str | None) -> tuple[dict[str, Any], dict[str, dict[str, torch.Tensor]]]: + metadata = json.loads(metadata_path.read_text()) + tensors = load_file(str(asset_path), device="cpu") + loaded: dict[str, dict[str, torch.Tensor]] = {} + for module_info in metadata["modules"]: + module_name = module_info["name"] + if module_filter and module_filter not in module_name: + continue + module_id = module_info["key"] + loaded[module_name] = { + "weight": tensors[f"{module_id}.weight"].contiguous(), + "inputs": tensors[f"{module_id}.inputs"].contiguous(), + } + bias_key = f"{module_id}.bias" + if bias_key in tensors: + loaded[module_name]["bias"] = tensors[bias_key].contiguous() + else: + loaded[module_name]["bias"] = None + if not loaded: + raise ValueError("No modules matched the requested asset filter.") + return metadata, loaded + + +def _prepare_rows( + weight: torch.Tensor, + bias: torch.Tensor | None, + inputs: torch.Tensor, + *, + train_rows: int, + val_rows: int, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + weight_opt = weight.to(device=device, dtype=torch.float32).contiguous() + bias_opt = None if bias is None else bias.to(device=device, dtype=torch.float32).contiguous() + rows = local_opt._sample_activation_rows(inputs, max_rows=max(1, int(train_rows) + int(val_rows))) + rows = rows.to(device=device, dtype=torch.float32).contiguous() + targets = F.linear(rows, weight_opt, bias_opt) + train_count = min(rows.shape[0], max(1, int(train_rows))) + val_count = min(max(1, int(val_rows)), max(1, rows.shape[0] - train_count)) + inputs_train = rows[:train_count].contiguous() + targets_train = targets[:train_count].contiguous() + inputs_val = rows[-val_count:].contiguous() + targets_val = targets[-val_count:].contiguous() + return weight_opt, bias_opt, inputs_train, targets_train, inputs_val, targets_val + + +def _batch_list(rows: torch.Tensor, batch_size: int) -> list[torch.Tensor]: + return [rows[start:start + batch_size].contiguous() for start in range(0, rows.shape[0], batch_size)] + + +def _common_loss(module_or_weight: nn.Module | torch.Tensor, bias: torch.Tensor | None, inputs: torch.Tensor, targets: torch.Tensor) -> float: + with torch.no_grad(): + if isinstance(module_or_weight, torch.Tensor): + preds = F.linear(inputs, module_or_weight, bias) + else: + preds = module_or_weight(inputs) + return float(F.smooth_l1_loss(preds, targets).item()) + + +def _warmup_local_kernel(device: torch.device, group_size: int, krot: int) -> None: + if device.type != "cuda": + return + x = torch.randn(32, group_size, device=device, dtype=torch.float32) + pairs, theta, scales = build_identity_rotation_buffers( + in_features=group_size, + group_size=group_size, + krot=krot, + device=device, + dtype=torch.float32, + ) + apply_paroquant_rotation(x, pairs, theta, scales, group_size) + sync_cuda(device) + + +def _import_official(repo_path: Path): + if str(repo_path) not in sys.path: + sys.path.insert(0, str(repo_path)) + from paroquant.kernels.cuda import scaled_pairwise_rotation # noqa: F401 + from paroquant.optim.qlinear import PseudoQuantizedLinear + from paroquant.optim.rotation import transform_to_kernel_data + from paroquant.optim.train import get_random_rotation_pairs, optimize_module + + return PseudoQuantizedLinear, transform_to_kernel_data, get_random_rotation_pairs, optimize_module + + +def _run_local_case( + *, + module_name: str, + case: LocalCase, + weight: torch.Tensor, + bias: torch.Tensor | None, + inputs: torch.Tensor, + args: argparse.Namespace, + device: torch.device, + repeat: int, +) -> BenchResult: + seed = int(args.seed) + (repeat * 1000) + torch.manual_seed(seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(seed) + torch.cuda.empty_cache() + + with cuda_timer(device) as total_timer: + with cuda_timer(device) as setup_timer: + weight_opt, bias_opt, inputs_train, targets_train, inputs_val, targets_val = _prepare_rows( + weight, + bias, + inputs, + train_rows=args.train_rows, + val_rows=args.val_rows, + device=device, + ) + + normalized_group_size = local_opt._normalize_group_size(args.group_size, weight_opt.shape[1]) + quantizer_sym = local_opt._quantizer_sym_for_impl(True, case.quantizer_impl) + with cuda_timer(device) as pair_timer: + if case.pair_impl == "reference": + pairs, theta_mask = local_opt.build_random_rotation_buffers_reference( + in_features=weight_opt.shape[1], + group_size=normalized_group_size, + krot=args.krot, + pair_ratio=args.pair_ratio, + seed=seed, + device=device, + ) + else: + pairs, theta_mask = local_opt.build_random_rotation_buffers( + in_features=weight_opt.shape[1], + group_size=normalized_group_size, + krot=args.krot, + pair_ratio=args.pair_ratio, + seed=seed, + device=device, + ) + + model = local_opt._ParoQuantOptimLinear( + weight_opt, + bias_opt, + bits=args.bits, + group_size=normalized_group_size, + quantizer_sym=quantizer_sym, + pairs=pairs, + theta_mask=theta_mask, + fused_rotation=True, + ).to(device=device, dtype=torch.float32) + model.reset_masked_angles() + + with cuda_timer(device) as stage1_timer: + local_opt._run_stage( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=[ + {"params": [model.channel_scales_opt], "lr": args.rotation_lr}, + {"params": [model.theta], "lr": args.rotation_lr}, + ], + epochs=args.rotation_epochs, + batch_size=args.batch_size, + stage_impl=case.stage_impl, + ) + + with cuda_timer(device) as stage2_timer: + model.init_quantizer() + train_loss, val_loss = local_opt._run_stage( + model=model, + inputs_train=inputs_train, + targets_train=targets_train, + inputs_val=inputs_val, + targets_val=targets_val, + param_groups=[ + {"params": [model.weight], "lr": args.weight_lr}, + {"params": model.quantizer.optim_params(), "lr": args.quantizer_lr}, + ], + epochs=args.finetune_epochs, + batch_size=args.batch_size, + stage_impl=case.stage_impl, + ) + + with cuda_timer(device) as export_timer: + result = local_opt._result_from_model( + model, + train_loss=train_loss, + val_loss=val_loss, + used_identity=False, + ) + + del model + if device.type == "cuda": + torch.cuda.empty_cache() + + return BenchResult( + module=module_name, + impl=case.label, + pair_s=pair_timer.elapsed, + setup_s=setup_timer.elapsed, + stage1_s=stage1_timer.elapsed, + stage2_s=stage2_timer.elapsed, + export_s=export_timer.elapsed, + opt_s=pair_timer.elapsed + setup_timer.elapsed + stage1_timer.elapsed + stage2_timer.elapsed, + total_s=total_timer.elapsed, + train_loss=float(result.train_loss), + val_loss=float(result.val_loss), + repeat=repeat, + ) + + +def _run_official_case( + *, + module_name: str, + weight: torch.Tensor, + bias: torch.Tensor | None, + inputs: torch.Tensor, + args: argparse.Namespace, + device: torch.device, + repeat: int, + official_repo: Path, +) -> BenchResult: + seed = int(args.seed) + (repeat * 1000) + torch.manual_seed(seed) + if device.type == "cuda": + torch.cuda.manual_seed_all(seed) + torch.cuda.empty_cache() + + PseudoQuantizedLinear, transform_to_kernel_data, get_random_rotation_pairs, optimize_module = _import_official( + official_repo + ) + + with cuda_timer(device) as total_timer: + with cuda_timer(device) as setup_timer: + weight_opt, bias_opt, inputs_train, targets_train, inputs_val, targets_val = _prepare_rows( + weight, + bias, + inputs, + train_rows=args.train_rows, + val_rows=args.val_rows, + device=device, + ) + train_input_batches = _batch_list(inputs_train, args.batch_size) + train_output_batches = _batch_list(targets_train, args.batch_size) + val_input_batches = _batch_list(inputs_val, args.batch_size) + val_output_batches = _batch_list(targets_val, args.batch_size) + + normalized_group_size = local_opt._normalize_group_size(args.group_size, weight_opt.shape[1]) + with cuda_timer(device) as pair_timer: + weight_grouped = weight_opt.view(weight_opt.shape[0], -1, normalized_group_size).permute(1, 0, 2) + all_pairs = get_random_rotation_pairs( + weight_grouped, + group_size=normalized_group_size, + num_rotations=args.krot, + num_pairs_factor=args.pair_ratio, + seed=seed, + ) + all_pairs = [torch.tensor(pairs, device="cpu", dtype=torch.int32) for pairs in all_pairs] + initial_angles = [torch.zeros(pair_tensor.shape[0], device="cpu") for pair_tensor in all_pairs] + npairs, angles, mask = transform_to_kernel_data( + all_pairs, + initial_angles, + group_size=normalized_group_size, + ) + linear = nn.Linear( + weight_opt.shape[1], + weight_opt.shape[0], + bias=bias_opt is not None, + device=device, + dtype=torch.float32, + ) + linear.weight.data.copy_(weight_opt) + if bias_opt is not None: + linear.bias.data.copy_(bias_opt) + channel_scales = torch.ones((1, weight_opt.shape[1]), dtype=weight_opt.dtype, device=device) + module = PseudoQuantizedLinear( + linear, + [npairs.to(device), angles.to(device), mask.to(device)], + channel_scales, + group_size=normalized_group_size, + n_bits=args.bits, + num_rotations=args.krot, + ) + + def _param_group(params: list[nn.Parameter], lr: float) -> dict[str, object]: + return { + "params": params, + "lr": lr, + "weight_decay": 0.01, + "betas": (0.9, 0.95), + "eps": 1e-10, + } + + with cuda_timer(device) as stage1_timer: + module.set_optim_enabled(channel_scales=True, angles=True) + optimize_module( + module, + (train_input_batches, train_output_batches), + (val_input_batches, val_output_batches), + {}, + [ + _param_group(module.get_optim_params("channel_scales"), args.rotation_lr), + _param_group(module.get_optim_params("angles"), args.rotation_lr), + ], + loss_fn="smooth_l1", + n_iter=args.rotation_epochs, + gradient_accumulation_steps=1, + early_stop=None, + post_optim_callback=lambda current: current.reset_angles_by_mask(), + ) + + with cuda_timer(device) as stage2_timer: + module.set_optim_enabled(weight=True, quantizer=True) + optimize_module( + module, + (train_input_batches, train_output_batches), + (val_input_batches, val_output_batches), + {}, + [ + _param_group(module.get_optim_params("weight"), args.weight_lr), + _param_group(module.get_optim_params("quantizer"), args.quantizer_lr), + ], + loss_fn="smooth_l1", + n_iter=args.finetune_epochs, + gradient_accumulation_steps=1, + early_stop=None, + post_optim_callback=lambda current: current.reset_angles_by_mask(), + ) + + with cuda_timer(device) as export_timer: + train_loss = _common_loss(module, None, inputs_train, targets_train) + val_loss = _common_loss(module, None, inputs_val, targets_val) + _ = module.pseudo_weight() + + del module + if device.type == "cuda": + torch.cuda.empty_cache() + + return BenchResult( + module=module_name, + impl="official_pr18", + pair_s=pair_timer.elapsed, + setup_s=setup_timer.elapsed, + stage1_s=stage1_timer.elapsed, + stage2_s=stage2_timer.elapsed, + export_s=export_timer.elapsed, + opt_s=pair_timer.elapsed + setup_timer.elapsed + stage1_timer.elapsed + stage2_timer.elapsed, + total_s=total_timer.elapsed, + train_loss=train_loss, + val_loss=val_loss, + repeat=repeat, + ) + + +def _summarize(results: list[BenchResult]) -> list[list[str]]: + rows: list[list[str]] = [] + by_module: dict[str, list[BenchResult]] = {} + for result in results: + by_module.setdefault(result.module, []).append(result) + + for module_name, module_results in by_module.items(): + official_total = statistics.median( + [result.total_s for result in module_results if result.impl == "official_pr18"] + ) + ordered_impls = ["official_pr18", *[case.label for case in LOCAL_CASES]] + for impl in ordered_impls: + if not any(result.impl == impl for result in module_results): + continue + selected = [result for result in module_results if result.impl == impl] + median_total = statistics.median(result.total_s for result in selected) + rows.append( + [ + module_name, + impl, + f"{statistics.median(result.pair_s for result in selected):.3f}", + f"{statistics.median(result.setup_s for result in selected):.3f}", + f"{statistics.median(result.stage1_s for result in selected):.3f}", + f"{statistics.median(result.stage2_s for result in selected):.3f}", + f"{statistics.median(result.export_s for result in selected):.3f}", + f"{statistics.median(result.opt_s for result in selected):.3f}", + f"{median_total:.3f}", + f"{official_total / median_total:.3f}x" if median_total > 0 else "", + f"{statistics.mean(result.train_loss for result in selected):.6f}", + f"{statistics.mean(result.val_loss for result in selected):.6f}", + ] + ) + return rows + + +def _print_table(title: str, rows: list[list[str]]) -> None: + print(title) + print( + tabulate( + rows, + headers=[ + "module", + "impl", + "pair_s", + "setup_s", + "stage1_s", + "stage2_s", + "export_s", + "opt_s", + "total_s", + "vs_official", + "train_l1", + "val_l1", + ], + tablefmt="grid", + ) + ) + + +def main() -> int: + args = parse_args() + asset_dir = Path(args.asset_dir) + asset_path, metadata_path = _asset_paths(asset_dir, args.asset_name) + if args.force_recapture or not asset_path.exists() or not metadata_path.exists(): + _capture_asset(args, asset_path, metadata_path) + + if args.capture_only: + print(f"Saved activation asset to {asset_path}") + print(f"Saved metadata to {metadata_path}") + return 0 + + _, modules = _load_asset(asset_path, metadata_path, args.module_filter) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + _warmup_local_kernel(device, args.group_size, args.krot) + _import_official(Path(args.official_repo)) + + results: list[BenchResult] = [] + for module_name, bundle in modules.items(): + for repeat in range(args.repeats): + results.append( + _run_official_case( + module_name=module_name, + weight=bundle["weight"], + bias=bundle["bias"], + inputs=bundle["inputs"], + args=args, + device=device, + repeat=repeat, + official_repo=Path(args.official_repo), + ) + ) + for case in LOCAL_CASES: + results.append( + _run_local_case( + module_name=module_name, + case=case, + weight=bundle["weight"], + bias=bundle["bias"], + inputs=bundle["inputs"], + args=args, + device=device, + repeat=repeat, + ) + ) + + _print_table("ParoQuant Official vs Local", _summarize(results)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_paroquant_opt_scope_compare.py b/scripts/benchmark_paroquant_opt_scope_compare.py new file mode 100644 index 000000000..713383355 --- /dev/null +++ b/scripts/benchmark_paroquant_opt_scope_compare.py @@ -0,0 +1,599 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import copy +import hashlib +import json +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Any + +import torch +import torch.nn.functional as F +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer + +REPO_ROOT = Path(__file__).resolve().parents[1] +OFFICIAL_REPO = Path("/root/official_paroquant") +if str(OFFICIAL_REPO) not in sys.path: + sys.path.insert(0, str(OFFICIAL_REPO)) + +from paroquant.optim.qlinear import PseudoQuantizedLinear +from paroquant.optim.rotation import transform_to_kernel_data +from paroquant.optim.train import get_random_rotation_pairs, optimize_module as official_optimize_module +from paroquant.optim.util import catch_first_layer_input, get_blocks, get_calib_dataset, get_named_linears, move_embed, set_module_by_name + +from gptqmodel.looper.input_cache import InputCache +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.looper.paroquant_processor import ParoQuantProcessor +from gptqmodel.quantization.paroquant.optimization import optimize_paroquant_linear + + +DEFAULT_MODEL = "/monster/data/model/Llama-3.2-1B-Instruct" +DEFAULT_MODULE_NAMES = ( + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.up_proj", + "mlp.down_proj", +) + + +@dataclass +class BenchRow: + case: str + total_s: float + val_smoothl1: float + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Compare ParoQuant module/compute_block/layer modes against official whole-layer optimization." + ) + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument("--official-repo", default=str(OFFICIAL_REPO)) + parser.add_argument("--dtype", default="fp16", choices=("fp16", "bf16")) + parser.add_argument("--layer-idx", type=int, default=0) + parser.add_argument("--train-batches", type=int, default=2, help="Calibration batches captured for train.") + parser.add_argument("--val-batches", type=int, default=1, help="Calibration batches captured for validation.") + parser.add_argument("--block-size", type=int, default=1024) + parser.add_argument("--train-rows", type=int, default=2048) + parser.add_argument("--val-rows", type=int, default=1024) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--bits", type=int, default=4) + parser.add_argument("--group-size", type=int, default=128) + parser.add_argument("--krot", type=int, default=8) + parser.add_argument("--pair-ratio", type=float, default=0.5) + parser.add_argument("--rotation-epochs", type=int, default=1) + parser.add_argument("--finetune-epochs", type=int, default=1) + parser.add_argument("--rotation-lr", type=float, default=0.05) + parser.add_argument("--weight-lr", type=float, default=1e-5) + parser.add_argument("--quantizer-lr", type=float, default=1e-6) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--case", + dest="cases", + action="append", + choices=("local_module", "local_compute_block", "local_layer", "official_layer"), + default=None, + help="Optional repeated case filter. By default all cases run.", + ) + parser.add_argument("--skip-official", action="store_true") + parser.add_argument("--output-json", type=Path, default=None) + return parser.parse_args() + + +def _dtype_from_label(label: str) -> torch.dtype: + return torch.bfloat16 if str(label).strip().lower() == "bf16" else torch.float16 + + +def _sync(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def _module_seed(base_seed: int, layer_index: int, full_name: str) -> int: + leaf = full_name.rsplit(".", 1)[-1] + seed_material = f"{base_seed}:{layer_index}:{leaf}".encode("utf-8") + digest = hashlib.blake2b(seed_material, digest_size=8).digest() + return int.from_bytes(digest, byteorder="big", signed=False) + + +def _init_rotation_data(weight: torch.Tensor, *, seed: int, group_size: int, num_rotations: int, device: torch.device): + grouped = weight.view(weight.shape[0], -1, group_size).permute(1, 0, 2) + all_pairs = get_random_rotation_pairs( + grouped, + group_size=group_size, + num_rotations=num_rotations, + num_pairs_factor=0.5, + seed=seed, + ) + pair_tensors = [torch.tensor(pairs, device="cpu", dtype=torch.int32) for pairs in all_pairs] + initial_angles = [torch.zeros(pairs.shape[0], device="cpu") for pairs in pair_tensors] + npairs, angles, mask = transform_to_kernel_data(pair_tensors, initial_angles, group_size=group_size) + return [npairs.to(device), angles.to(device), mask.to(device)] + + +def _capture_module_inputs( + layer: torch.nn.Module, + *, + input_batches: list[torch.Tensor], + kwargs: dict[str, Any], + module_names: tuple[str, ...], + device: torch.device, + dtype: torch.dtype, +) -> dict[str, torch.Tensor]: + captured = {name: [] for name in module_names} + handles = [] + + def make_hook(name: str): + def hook(_module, inputs): + if not inputs: + return + x = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).cpu() + captured[name].append(x) + + return hook + + for module_name in module_names: + target = dict(layer.named_modules())[module_name] + handles.append(target.register_forward_pre_hook(make_hook(module_name))) + + layer = layer.to(device) + try: + with torch.no_grad(): + for batch in input_batches: + _ = layer(batch.to(device=device, dtype=dtype), **kwargs) + finally: + for handle in handles: + handle.remove() + layer.cpu() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return {name: torch.cat(parts, dim=0) for name, parts in captured.items()} + + +def _build_processor( + *, + opt_scope: str, + kwargs: dict[str, Any], + all_input_batches: list[torch.Tensor], + bits: int, + group_size: int, + krot: int, + pair_ratio: float, + train_rows: int, + val_rows: int, + batch_size: int, + rotation_epochs: int, + finetune_epochs: int, + rotation_lr: float, + weight_lr: float, + quantizer_lr: float, + seed: int, +) -> ParoQuantProcessor: + processor = object.__new__(ParoQuantProcessor) + + def dynamic_get(_module_name=None, _key=None, default=None, **_kwargs): + return default + + sanitized_kwargs = {k: v for k, v in kwargs.items() if k not in ("attention_mask", "position_ids", "use_cache")} + processor.qcfg = SimpleNamespace( + opt_scope=opt_scope, + runtime_bits=bits, + group_size=group_size, + sym=True, + krot=krot, + opt_seed=seed, + opt_pair_ratio=pair_ratio, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=True, + opt_stage_cudagraph=True, + opt_stage_impl="fast", + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=train_rows, + opt_validation_samples=val_rows, + opt_batch_size=batch_size, + opt_rotation_lr=rotation_lr, + opt_weight_lr=weight_lr, + opt_quantizer_lr=quantizer_lr, + opt_rotation_epochs=rotation_epochs, + opt_finetune_epochs=finetune_epochs, + dynamic_get=dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True, rotary_embedding=None) + processor.model = None + processor._batch_tls = __import__("threading").local() + processor.lock = __import__("threading").Lock() + processor.tasks = {} + processor.calculate_w_wq_diff = False + processor.fallback = True + processor._rotary_cache = {} + processor._rotary_source_id = None + processor._rotary_lock = __import__("threading").Lock() + processor.inputs_cache = InputCache( + layer_inputs=[[batch] for batch in all_input_batches], + layer_input_kwargs=[sanitized_kwargs for _ in all_input_batches], + position_ids=[kwargs.get("position_ids")] * len(all_input_batches), + attention_masks=[kwargs.get("attention_mask")] * len(all_input_batches), + ) + return processor + + +def _evaluate_group_loss( + processor: ParoQuantProcessor, + layer: torch.nn.Module, + *, + input_batches: list[torch.Tensor], + output_batches: list[torch.Tensor], + kwargs: dict[str, Any], +) -> float: + sanitized_kwargs = [{k: v for k, v in kwargs.items() if k not in ("attention_mask", "position_ids", "use_cache")} for _ in input_batches] + positions = [kwargs.get("position_ids")] * len(input_batches) + masks = [kwargs.get("attention_mask")] * len(input_batches) + return processor._evaluate_group_layer( + layer, + input_batches=[[batch] for batch in input_batches], + input_kwargs_batches=sanitized_kwargs, + target_batches=[[batch] for batch in output_batches], + position_ids=positions, + attention_masks=masks, + use_amp=torch.cuda.is_available(), + ) + + +def _benchmark_local_module( + *, + base_layer: torch.nn.Module, + layer_index: int, + train_input_batches: list[torch.Tensor], + val_input_batches: list[torch.Tensor], + train_output_batches: list[torch.Tensor], + val_output_batches: list[torch.Tensor], + kwargs: dict[str, Any], + module_names: tuple[str, ...], + device: torch.device, + dtype: torch.dtype, + args: argparse.Namespace, +) -> BenchRow: + layer = copy.deepcopy(base_layer).to(dtype=dtype) + all_inputs = train_input_batches + val_input_batches + all_outputs = train_output_batches + val_output_batches + processor = _build_processor( + opt_scope="module", + kwargs=kwargs, + all_input_batches=all_inputs, + bits=args.bits, + group_size=args.group_size, + krot=args.krot, + pair_ratio=args.pair_ratio, + train_rows=args.train_rows, + val_rows=args.val_rows, + batch_size=args.batch_size, + rotation_epochs=args.rotation_epochs, + finetune_epochs=args.finetune_epochs, + rotation_lr=args.rotation_lr, + weight_lr=args.weight_lr, + quantizer_lr=args.quantizer_lr, + seed=args.seed, + ) + inputs = _capture_module_inputs( + layer, + input_batches=all_inputs, + kwargs=kwargs, + module_names=module_names, + device=device, + dtype=dtype, + ) + + _sync(device) + start = time.perf_counter() + for module_name in module_names: + module = dict(layer.named_modules())[module_name] + result = optimize_paroquant_linear( + weight=module.weight.data, + bias=module.bias.data if module.bias is not None else None, + inputs=inputs[module_name], + bits=args.bits, + group_size=args.group_size, + sym=True, + krot=args.krot, + pair_ratio=args.pair_ratio, + train_rows=args.train_rows, + val_rows=args.val_rows, + batch_size=args.batch_size, + rotation_epochs=args.rotation_epochs, + finetune_epochs=args.finetune_epochs, + rotation_lr=args.rotation_lr, + weight_lr=args.weight_lr, + quantizer_lr=args.quantizer_lr, + seed=_module_seed(args.seed, layer_index, f"model.layers.{layer_index}.{module_name}"), + fused_rotation=True, + stage_cudagraph=True, + stage_impl="fast", + pair_impl="fast", + quantizer_impl="reference", + scale_clamp_min=1e-2, + scale_clamp_max=1e2, + ) + module.weight.data = result.pseudo_weight.to(device=module.weight.device, dtype=module.weight.dtype) + _sync(device) + total_s = time.perf_counter() - start + + val_loss = _evaluate_group_loss( + processor, + layer.to(device), + input_batches=val_input_batches, + output_batches=val_output_batches, + kwargs=kwargs, + ) + return BenchRow(case="local_module", total_s=total_s, val_smoothl1=val_loss) + + +def _benchmark_local_group( + *, + opt_scope: str, + base_layer: torch.nn.Module, + layer_index: int, + train_input_batches: list[torch.Tensor], + val_input_batches: list[torch.Tensor], + train_output_batches: list[torch.Tensor], + val_output_batches: list[torch.Tensor], + kwargs: dict[str, Any], + module_names: tuple[str, ...], + device: torch.device, + dtype: torch.dtype, + args: argparse.Namespace, +) -> BenchRow: + del layer_index + layer = copy.deepcopy(base_layer).to(dtype=dtype) + all_inputs = train_input_batches + val_input_batches + all_outputs = train_output_batches + val_output_batches + processor = _build_processor( + opt_scope=opt_scope, + kwargs=kwargs, + all_input_batches=all_inputs, + bits=args.bits, + group_size=args.group_size, + krot=args.krot, + pair_ratio=args.pair_ratio, + train_rows=args.train_rows, + val_rows=args.val_rows, + batch_size=args.batch_size, + rotation_epochs=args.rotation_epochs, + finetune_epochs=args.finetune_epochs, + rotation_lr=args.rotation_lr, + weight_lr=args.weight_lr, + quantizer_lr=args.quantizer_lr, + seed=args.seed, + ) + state = SimpleNamespace( + layer_module=layer, + pristine_layer_module=copy.deepcopy(layer).cpu(), + layer_inputs=[[batch] for batch in all_inputs], + layer_input_kwargs=[{k: v for k, v in kwargs.items() if k not in ("attention_mask", "position_ids", "use_cache")} for _ in all_inputs], + layer_outputs=[[batch] for batch in all_outputs], + modules={ + name: NamedModule(dict(layer.named_modules())[name], name, f"model.layers.{args.layer_idx}.{name}", args.layer_idx) + for name in module_names + }, + ) + + _sync(device) + start = time.perf_counter() + groups = processor._optimization_groups_for_layer(state) + for _label, group_modules in groups: + results, _ = processor._optimize_group(state, group_modules) + for named_module in group_modules: + original_weight = named_module.weight.data.detach().clone() + processor._apply_optimization_result(named_module, results[named_module.name], original_weight) + _sync(device) + total_s = time.perf_counter() - start + + val_loss = _evaluate_group_loss( + processor, + layer.to(device), + input_batches=val_input_batches, + output_batches=val_output_batches, + kwargs=kwargs, + ) + return BenchRow(case=f"local_{opt_scope}", total_s=total_s, val_smoothl1=val_loss) + + +def _benchmark_official_layer( + *, + base_layer: torch.nn.Module, + train_input_batches: list[torch.Tensor], + val_input_batches: list[torch.Tensor], + train_output_batches: list[torch.Tensor], + val_output_batches: list[torch.Tensor], + kwargs: dict[str, Any], + device: torch.device, + dtype: torch.dtype, + args: argparse.Namespace, +) -> BenchRow: + layer = copy.deepcopy(base_layer).to(device=device, dtype=torch.float32) + named_pseudo_modules = {} + for name, old_module in get_named_linears(layer).items(): + weight = old_module.weight.float() + rotation_pairs = _init_rotation_data( + weight, + seed=args.seed, + group_size=args.group_size, + num_rotations=args.krot, + device=device, + ) + channel_scales = torch.ones(1, weight.shape[1], dtype=weight.dtype, device=device) + new_module = PseudoQuantizedLinear( + old_module, + rotation_pairs, + channel_scales, + group_size=args.group_size, + n_bits=args.bits, + num_rotations=args.krot, + ) + set_module_by_name(layer, name, new_module) + named_pseudo_modules[name] = new_module + + for param in layer.parameters(): + param.requires_grad = False + + _sync(device) + start = time.perf_counter() + for step_params, epochs in ( + ({"channel_scales": args.rotation_lr, "angles": args.rotation_lr}, args.rotation_epochs), + ({"weight": args.weight_lr, "quantizer": args.quantizer_lr}, args.finetune_epochs), + ): + optim_params = [] + for new_module in named_pseudo_modules.values(): + new_module.set_optim_enabled(**{name: True for name in step_params}) + for param_name, lr in step_params.items(): + optim_params.append( + dict( + params=new_module.get_optim_params(param_name), + lr=lr, + weight_decay=0.01, + betas=(0.9, 0.95), + eps=1e-10, + ) + ) + + official_optimize_module( + layer, + ([batch.to(device=device, dtype=dtype) for batch in train_input_batches], [batch.to(device=device, dtype=dtype) for batch in train_output_batches]), + ([batch.to(device=device, dtype=dtype) for batch in val_input_batches], [batch.to(device=device, dtype=dtype) for batch in val_output_batches]), + {k: (v.to(device=device) if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}, + optim_params, + loss_fn="smooth_l1", + n_iter=epochs, + gradient_accumulation_steps=1, + early_stop=None, + post_optim_callback=lambda _module: [pseudo_module.reset_angles_by_mask() for pseudo_module in named_pseudo_modules.values()], + ) + _sync(device) + total_s = time.perf_counter() - start + + layer = layer.to(device=device, dtype=dtype) + with torch.no_grad(): + total = 0.0 + for inp, target in zip(val_input_batches, val_output_batches): + preds = layer(inp.to(device=device, dtype=dtype), **{k: (v.to(device=device) if isinstance(v, torch.Tensor) else v) for k, v in kwargs.items()}) + if isinstance(preds, tuple): + preds = preds[0] + total += float(F.smooth_l1_loss(preds, target.to(device=device, dtype=preds.dtype)).item()) + return BenchRow(case="official_layer", total_s=total_s, val_smoothl1=total / max(1, len(val_input_batches))) + + +def _load_first_layer_io(args: argparse.Namespace, device: torch.device, dtype: torch.dtype): + model = AutoModelForCausalLM.from_pretrained(args.model, dtype=dtype, device_map="cpu") + tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer.pad_token = tokenizer.eos_token + move_embed(model, device) + blocks = get_blocks(model) + blocks[args.layer_idx].to(device) + + train_samples = torch.stack( + get_calib_dataset("pileval", tokenizer=tokenizer, n_samples=args.train_batches, block_size=args.block_size, seed=args.seed, split="train"), + dim=0, + ).to(device) + val_samples = torch.stack( + get_calib_dataset("pileval", tokenizer=tokenizer, n_samples=args.val_batches, block_size=args.block_size, seed=args.seed + 1, split="validation"), + dim=0, + ).to(device) + + train_input_batches, kwargs = catch_first_layer_input(model, blocks, train_samples, batch_size=1) + val_input_batches, _ = catch_first_layer_input(model, blocks, val_samples, batch_size=1) + layer = blocks[args.layer_idx].to(device) + + with torch.no_grad(): + train_output_batches = [] + for batch in train_input_batches: + out = layer(batch.to(device=device, dtype=dtype), **kwargs) + if isinstance(out, tuple): + out = out[0] + train_output_batches.append(out.detach().cpu()) + + val_output_batches = [] + for batch in val_input_batches: + out = layer(batch.to(device=device, dtype=dtype), **kwargs) + if isinstance(out, tuple): + out = out[0] + val_output_batches.append(out.detach().cpu()) + + base_layer = copy.deepcopy(layer).cpu() + kwargs = { + key: (value.detach().cpu() if isinstance(value, torch.Tensor) else value) + for key, value in kwargs.items() + if key not in ("past_key_value", "past_key_values") + } + kwargs["use_cache"] = False + + return { + "base_layer": base_layer, + "train_input_batches": [batch.detach().cpu() for batch in train_input_batches], + "val_input_batches": [batch.detach().cpu() for batch in val_input_batches], + "train_output_batches": train_output_batches, + "val_output_batches": val_output_batches, + "kwargs": kwargs, + } + + +def main() -> int: + args = parse_args() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + dtype = _dtype_from_label(args.dtype) + module_names = DEFAULT_MODULE_NAMES + + print(f"[setup] loading layer {args.layer_idx} IO from {args.model}", flush=True) + captured = _load_first_layer_io(args, device, dtype) + + rows: list[BenchRow] = [] + requested_cases = set(args.cases or []) + cases = [ + ("local_module", lambda: _benchmark_local_module(module_names=module_names, device=device, dtype=dtype, args=args, layer_index=args.layer_idx, **captured)), + ("local_compute_block", lambda: _benchmark_local_group(opt_scope="compute_block", module_names=module_names, device=device, dtype=dtype, args=args, layer_index=args.layer_idx, **captured)), + ("local_layer", lambda: _benchmark_local_group(opt_scope="layer", module_names=module_names, device=device, dtype=dtype, args=args, layer_index=args.layer_idx, **captured)), + ] + if not args.skip_official: + cases.append(("official_layer", lambda: _benchmark_official_layer(device=device, dtype=dtype, args=args, **captured))) + + for label, fn in cases: + if requested_cases and label not in requested_cases: + continue + print(f"[run] {label}", flush=True) + if torch.cuda.is_available(): + torch.cuda.empty_cache() + row = fn() + rows.append(row) + print(f"[done] {label} total_s={row.total_s:.3f} val_smoothl1={row.val_smoothl1:.6e}", flush=True) + + table_rows = [[row.case, f"{row.total_s:.3f}", f"{row.val_smoothl1:.6e}"] for row in rows] + print( + tabulate( + table_rows, + headers=["case", "total_s", "val_smoothl1"], + tablefmt="grid", + ) + ) + + if args.output_json is not None: + args.output_json.parent.mkdir(parents=True, exist_ok=True) + args.output_json.write_text(json.dumps([asdict(row) for row in rows], indent=2), encoding="utf-8") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_paroquant_optimizer_real_workload.py b/scripts/benchmark_paroquant_optimizer_real_workload.py new file mode 100644 index 000000000..da4b4d1e4 --- /dev/null +++ b/scripts/benchmark_paroquant_optimizer_real_workload.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import statistics +import time +from dataclasses import dataclass +from typing import Any + +import torch +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer + +from gptqmodel.quantization.paroquant.optimization import optimize_paroquant_linear +from gptqmodel.utils.paroquant_benchmark import ( + _normalize_model_dtype, + load_nm_calibration, +) + + +DEFAULT_MODEL = "/monster/data/model/Llama-3.2-1B-Instruct" +DEFAULT_MODULES = ("model.layers.0.self_attn.q_proj", "model.layers.0.mlp.down_proj") + + +@dataclass(frozen=True) +class BenchmarkCase: + label: str + stage_impl: str + pair_impl: str + quantizer_impl: str + + +CASES: tuple[BenchmarkCase, ...] = ( + BenchmarkCase("reference", "reference", "reference", "reference"), + BenchmarkCase("pair_fast", "reference", "fast", "reference"), + BenchmarkCase("stage_pair_fast", "fast", "fast", "reference"), + BenchmarkCase("stage_fast", "fast", "reference", "reference"), + BenchmarkCase("quant_fast", "reference", "reference", "fast"), + BenchmarkCase("all_fast", "fast", "fast", "fast"), +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark ParoQuant optimizer implementations on real calibration activations." + ) + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument( + "--module", + dest="modules", + action="append", + default=None, + help="Fully-qualified linear module path. Can be passed multiple times.", + ) + parser.add_argument("--model-dtype", default="fp16") + parser.add_argument("--calibration-rows", type=int, default=64) + parser.add_argument("--capture-rows", type=int, default=2048) + parser.add_argument("--bits", type=int, default=4) + parser.add_argument("--group-size", type=int, default=128) + parser.add_argument("--krot", type=int, default=8) + parser.add_argument("--pair-ratio", type=float, default=0.5) + parser.add_argument("--train-rows", type=int, default=2048) + parser.add_argument("--val-rows", type=int, default=64) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--rotation-epochs", type=int, default=10) + parser.add_argument("--finetune-epochs", type=int, default=10) + parser.add_argument("--rotation-lr", type=float, default=0.05) + parser.add_argument("--weight-lr", type=float, default=1e-5) + parser.add_argument("--quantizer-lr", type=float, default=1e-6) + parser.add_argument("--repeats", type=int, default=3) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--no-fused-rotation", action="store_true") + return parser.parse_args() + + +def _get_named_module(model, module_name: str): + module_map = dict(model.named_modules()) + if module_name not in module_map: + raise KeyError(f"Module `{module_name}` not found.") + return module_map[module_name] + + +def _tokenize_calibration_sample(tokenizer, sample: dict[str, Any]) -> dict[str, torch.Tensor]: + if "input_ids" in sample: + input_ids = torch.as_tensor(sample["input_ids"], dtype=torch.long) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + attention_mask = sample.get("attention_mask") + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + else: + attention_mask = torch.as_tensor(attention_mask, dtype=torch.long) + if attention_mask.ndim == 1: + attention_mask = attention_mask.unsqueeze(0) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + if "messages" in sample: + rendered = tokenizer.apply_chat_template( + sample["messages"], + tokenize=False, + add_generation_prompt=False, + ) + tokenized = tokenizer(rendered, add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + if "text" in sample: + tokenized = tokenizer(sample["text"], add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + raise ValueError(f"Unsupported calibration sample keys: {sorted(sample.keys())}") + + +def _capture_module_inputs( + model, + tokenizer, + module_names: list[str], + calibration_dataset: list[dict[str, Any]], + *, + max_rows: int, +) -> dict[str, torch.Tensor]: + module_names_set = set(module_names) + captured: dict[str, list[torch.Tensor]] = {name: [] for name in module_names} + captured_rows = {name: 0 for name in module_names} + hooks = [] + + def make_hook(name: str): + def hook(_module, inputs): + if not inputs or captured_rows[name] >= max_rows: + return + x = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).cpu() + remaining = max_rows - captured_rows[name] + if remaining <= 0: + return + piece = x[:remaining].contiguous() + if piece.numel() == 0: + return + captured[name].append(piece) + captured_rows[name] += piece.shape[0] + + return hook + + for name in module_names: + module = _get_named_module(model, name) + hooks.append(module.register_forward_pre_hook(make_hook(name))) + + model_device = next(model.parameters()).device + try: + for sample in calibration_dataset: + if all(count >= max_rows for count in captured_rows.values()): + break + tokenized = _tokenize_calibration_sample(tokenizer, sample) + with torch.inference_mode(): + model( + input_ids=tokenized["input_ids"].to(device=model_device), + attention_mask=tokenized["attention_mask"].to(device=model_device), + ) + finally: + for hook in hooks: + hook.remove() + + flattened: dict[str, torch.Tensor] = {} + for name in module_names_set: + pieces = captured[name] + if not pieces: + raise RuntimeError(f"Failed to capture calibration activations for `{name}`.") + flattened[name] = torch.cat(pieces, dim=0)[:max_rows].contiguous() + return flattened + + +def _bench_one_case( + *, + case: BenchmarkCase, + weight: torch.Tensor, + bias: torch.Tensor | None, + inputs: torch.Tensor, + args: argparse.Namespace, + run_idx: int, +) -> dict[str, float | str]: + seed = int(args.seed) + (run_idx * 1000) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats(weight.device) + + start = time.perf_counter() + result = optimize_paroquant_linear( + weight=weight, + bias=bias, + inputs=inputs, + bits=args.bits, + group_size=args.group_size, + sym=True, + krot=args.krot, + pair_ratio=args.pair_ratio, + train_rows=args.train_rows, + val_rows=args.val_rows, + batch_size=args.batch_size, + rotation_epochs=args.rotation_epochs, + finetune_epochs=args.finetune_epochs, + rotation_lr=args.rotation_lr, + weight_lr=args.weight_lr, + quantizer_lr=args.quantizer_lr, + seed=seed, + fused_rotation=not args.no_fused_rotation, + stage_impl=case.stage_impl, + pair_impl=case.pair_impl, + quantizer_impl=case.quantizer_impl, + ) + wall_s = time.perf_counter() - start + peak_bytes = 0 + if torch.cuda.is_available(): + peak_bytes = int(torch.cuda.max_memory_allocated(weight.device)) + return { + "label": case.label, + "wall_s": wall_s, + "train_loss": float(result.train_loss), + "val_loss": float(result.val_loss), + "peak_gb": peak_bytes / (1024**3), + } + + +def _summarize_runs(module_name: str, runs: list[dict[str, float | str]]) -> list[list[str]]: + baseline = next(run for run in runs if run["label"] == "reference") + baseline_wall = float(baseline["wall_s"]) + rows = [] + for label in [case.label for case in CASES]: + selected = [run for run in runs if run["label"] == label] + wall_values = [float(run["wall_s"]) for run in selected] + train_values = [float(run["train_loss"]) for run in selected] + val_values = [float(run["val_loss"]) for run in selected] + peak_values = [float(run["peak_gb"]) for run in selected] + median_wall = statistics.median(wall_values) + rows.append( + [ + module_name, + label, + f"{median_wall:.3f}", + f"{statistics.mean(wall_values):.3f}", + f"{baseline_wall / median_wall:.3f}x" if median_wall > 0 else "", + f"{statistics.mean(train_values):.6f}", + f"{statistics.mean(val_values):.6f}", + f"{statistics.mean(peak_values):.3f}", + ] + ) + return rows + + +def main() -> int: + args = parse_args() + modules = args.modules or list(DEFAULT_MODULES) + normalized_dtype = _normalize_model_dtype(args.model_dtype) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=False) + if getattr(tokenizer, "padding_side", None) != "left": + tokenizer.padding_side = "left" + if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = AutoModelForCausalLM.from_pretrained( + args.model, + trust_remote_code=False, + torch_dtype=normalized_dtype, + low_cpu_mem_usage=True, + ) + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + model.to(device) + model.eval() + calibration_dataset = load_nm_calibration(args.calibration_rows) + + try: + captured = _capture_module_inputs( + model, + tokenizer, + modules, + calibration_dataset, + max_rows=max(args.capture_rows, args.train_rows + args.val_rows), + ) + rows: list[list[str]] = [] + for module_name in modules: + module = _get_named_module(model, module_name) + weight = module.weight.detach().to(device=module.weight.device, dtype=torch.float32).contiguous() + bias = None + if getattr(module, "bias", None) is not None: + bias = module.bias.detach().to(device=module.weight.device, dtype=torch.float32).contiguous() + inputs = captured[module_name].to(device=module.weight.device, dtype=torch.float32).contiguous() + + runs: list[dict[str, float | str]] = [] + for run_idx in range(args.repeats): + for case in CASES: + runs.append( + _bench_one_case( + case=case, + weight=weight, + bias=bias, + inputs=inputs, + args=args, + run_idx=run_idx, + ) + ) + rows.extend(_summarize_runs(module_name, runs)) + + print( + tabulate( + rows, + headers=["module", "case", "median_s", "mean_s", "vs_ref", "train_loss", "val_loss", "peak_gb"], + tablefmt="grid", + ) + ) + finally: + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_paroquant_pair_cache_real_workload.py b/scripts/benchmark_paroquant_pair_cache_real_workload.py new file mode 100644 index 000000000..30eb76e70 --- /dev/null +++ b/scripts/benchmark_paroquant_pair_cache_real_workload.py @@ -0,0 +1,494 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import statistics +from dataclasses import dataclass +from typing import Any, Literal + +import torch +from tabulate import tabulate +from transformers import AutoModelForCausalLM, AutoTokenizer + +from gptqmodel.quantization.paroquant.optimization import ( + optimize_paroquant_linear, + _build_random_rotation_buffers_cached_cpu, + _clear_random_rotation_buffers_cache, + _warm_random_rotation_buffers_cache, +) +from gptqmodel.utils.paroquant_benchmark import ( + _normalize_model_dtype, + load_nm_calibration, +) + + +CacheStrategy = Literal["miss", "fixed", "fixed_preload"] + + +@dataclass(frozen=True) +class BenchmarkCase: + label: str + stage_impl: str + pair_impl: str + quantizer_impl: str + + +CASES: tuple[BenchmarkCase, ...] = ( + BenchmarkCase("reference", "reference", "reference", "reference"), + BenchmarkCase("pair_fast", "reference", "fast", "reference"), + BenchmarkCase("stage_pair_fast", "fast", "fast", "reference"), + BenchmarkCase("all_fast", "fast", "fast", "fast"), +) + +DEFAULT_MODEL = "/monster/data/model/Llama-3.2-1B-Instruct" +DEFAULT_MODULES = ( + "self_attn.q_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.down_proj", +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Benchmark ParoQuant pair-cache behavior on real calibration activations, " + "across explicit cache strategies." + ) + ) + parser.add_argument("--model", default=DEFAULT_MODEL) + parser.add_argument("--layers", type=int, default=3, help="Number of decoder layers to evaluate from layer 0.") + parser.add_argument("--model-dtype", default="fp16") + parser.add_argument("--calibration-rows", type=int, default=64) + parser.add_argument("--capture-rows", type=int, default=2048) + parser.add_argument("--bits", type=int, default=4) + parser.add_argument("--group-size", type=int, default=128) + parser.add_argument("--krot", type=int, default=8) + parser.add_argument("--pair-ratio", type=float, default=0.5) + parser.add_argument("--train-rows", type=int, default=2048) + parser.add_argument("--val-rows", type=int, default=64) + parser.add_argument("--batch-size", type=int, default=64) + parser.add_argument("--rotation-epochs", type=int, default=10) + parser.add_argument("--finetune-epochs", type=int, default=10) + parser.add_argument("--rotation-lr", type=float, default=0.05) + parser.add_argument("--weight-lr", type=float, default=1e-5) + parser.add_argument("--quantizer-lr", type=float, default=1e-6) + parser.add_argument("--repeats", type=int, default=1) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--cache-strategy", + choices=("miss", "fixed", "fixed_preload"), + default="fixed_preload", + help=( + "miss: force one pair-cache miss per module call; " + "fixed: keep seed fixed but do not preload schedule cache; " + "fixed_preload: keep seed fixed and preload per-config schedules once." + ), + ) + parser.add_argument( + "--no-fused-rotation", + action="store_true", + help="Pass fused_rotation=False to optimizer.", + ) + return parser.parse_args() + + +def _get_named_module(model, module_name: str): + module_map = dict(model.named_modules()) + if module_name not in module_map: + raise KeyError(f"Module `{module_name}` not found.") + return module_map[module_name] + + +def _module_list_for_layers(layer_count: int, module_suffixes: tuple[str, ...]) -> list[str]: + return [f"model.layers.{layer_idx}.{suffix}" for layer_idx in range(layer_count) for suffix in module_suffixes] + + +def _tokenize_calibration_sample(tokenizer, sample: dict[str, Any]) -> dict[str, torch.Tensor]: + if "input_ids" in sample: + input_ids = torch.as_tensor(sample["input_ids"], dtype=torch.long) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + attention_mask = sample.get("attention_mask") + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + else: + attention_mask = torch.as_tensor(attention_mask, dtype=torch.long) + if attention_mask.ndim == 1: + attention_mask = attention_mask.unsqueeze(0) + return {"input_ids": input_ids, "attention_mask": attention_mask} + + if "messages" in sample: + rendered = tokenizer.apply_chat_template( + sample["messages"], + tokenize=False, + add_generation_prompt=False, + ) + tokenized = tokenizer(rendered, add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + if "text" in sample: + tokenized = tokenizer(sample["text"], add_special_tokens=True, return_tensors="pt") + return { + "input_ids": tokenized["input_ids"].to(dtype=torch.long), + "attention_mask": tokenized.get("attention_mask", torch.ones_like(tokenized["input_ids"])).to(dtype=torch.long), + } + + raise ValueError(f"Unsupported calibration sample keys: {sorted(sample.keys())}") + + +def _capture_module_inputs( + model, + tokenizer, + module_names: list[str], + calibration_dataset: list[dict[str, Any]], + *, + max_rows: int, +) -> dict[str, torch.Tensor]: + module_names_set = set(module_names) + captured: dict[str, list[torch.Tensor]] = {name: [] for name in module_names} + captured_rows = {name: 0 for name in module_names} + hooks = [] + + def make_hook(name: str): + def hook(_module, inputs): + if not inputs or captured_rows[name] >= max_rows: + return + x = inputs[0].detach().reshape(-1, inputs[0].shape[-1]).cpu() + remaining = max_rows - captured_rows[name] + if remaining <= 0: + return + piece = x[:remaining].contiguous() + if piece.numel() == 0: + return + captured[name].append(piece) + captured_rows[name] += piece.shape[0] + + return hook + + for name in module_names: + module = _get_named_module(model, name) + hooks.append(module.register_forward_pre_hook(make_hook(name))) + + model_device = next(model.parameters()).device + try: + for sample in calibration_dataset: + if all(count >= max_rows for count in captured_rows.values()): + break + tokenized = _tokenize_calibration_sample(tokenizer, sample) + with torch.inference_mode(): + model( + input_ids=tokenized["input_ids"].to(device=model_device), + attention_mask=tokenized["attention_mask"].to(device=model_device), + ) + finally: + for hook in hooks: + hook.remove() + + flattened: dict[str, torch.Tensor] = {} + for name in module_names_set: + pieces = captured[name] + if not pieces: + raise RuntimeError(f"Failed to capture calibration activations for `{name}`.") + flattened[name] = torch.cat(pieces, dim=0)[:max_rows].contiguous() + return flattened + + +def _precompute_pair_cache( + *, + modules: list[str], + module_infos: dict[str, tuple[int, int, int, float, int]], + cache_keys: set[tuple[int, int, int, float, int]], + device: torch.device, +) -> None: + _clear_random_rotation_buffers_cache() + for in_features, group_size, krot, pair_ratio, seed in sorted(cache_keys): + _warm_random_rotation_buffers_cache( + in_features=in_features, + group_size=group_size, + krot=krot, + pair_ratio=pair_ratio, + seed=seed, + ) + + +def _run_case( + *, + case: BenchmarkCase, + weight: torch.Tensor, + bias: torch.Tensor | None, + inputs: torch.Tensor, + args: argparse.Namespace, + run_idx: int, + case_idx: int, + module_name: str, + module_seq: int, + cache_strategy: CacheStrategy, + cache_key: tuple[int, int, int, float, int], +) -> dict[str, Any]: + base_seed = int(args.seed) + seed = base_seed if cache_strategy != "miss" else base_seed + case_idx * 97 + run_idx * 1000 + module_seq * 17 + + if cache_strategy == "miss" and case.pair_impl == "fast": + _clear_random_rotation_buffers_cache() + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + before = None + if case.pair_impl == "fast": + before = _build_random_rotation_buffers_cached_cpu.cache_info() + + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + start = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None + end = torch.cuda.Event(enable_timing=True) if torch.cuda.is_available() else None + + if start is not None: + torch.cuda.synchronize() + start.record() + else: + t0 = torch.cuda.Event(enable_timing=False) + t0.record() if torch.cuda.is_available() else None + t0.synchronize() if torch.cuda.is_available() else None + + result = optimize_paroquant_linear( + weight=weight, + bias=bias, + inputs=inputs, + bits=args.bits, + group_size=args.group_size, + sym=True, + krot=args.krot, + pair_ratio=args.pair_ratio, + train_rows=args.train_rows, + val_rows=args.val_rows, + batch_size=args.batch_size, + rotation_epochs=args.rotation_epochs, + finetune_epochs=args.finetune_epochs, + rotation_lr=args.rotation_lr, + weight_lr=args.weight_lr, + quantizer_lr=args.quantizer_lr, + seed=seed, + fused_rotation=not args.no_fused_rotation, + stage_impl=case.stage_impl, + pair_impl=case.pair_impl, + quantizer_impl=case.quantizer_impl, + ) + + if torch.cuda.is_available(): + end.record() + torch.cuda.synchronize() + wall_s = float(start.elapsed_time(end)) / 1e3 + else: + end_fallback = torch.cuda.Event(enable_timing=False) if torch.cuda.is_available() else None + wall_s = float(end_fallback.elapsed_time(t0) / 1e3) if end_fallback is not None else 0.0 + + after = None + cache_hit = None + if case.pair_impl == "fast": + after = _build_random_rotation_buffers_cached_cpu.cache_info() + if before is not None and after is not None: + cache_hit = (after.hits - before.hits) > (after.misses - before.misses) + peak_bytes = int(torch.cuda.max_memory_allocated(weight.device)) if torch.cuda.is_available() else 0 + + return { + "label": case.label, + "module": module_name, + "wall_s": wall_s, + "train_loss": float(result.train_loss), + "val_loss": float(result.val_loss), + "peak_gb": peak_bytes / (1024**3), + "cache_hit": None if cache_hit is None else cache_hit, + "cache_miss": None if cache_hit is None else (not cache_hit), + "cache_key": cache_key, + } + + +def _collect_summary( + *, + module_label: str, + module_results: list[dict[str, Any]], +) -> list[str]: + """ + Return rows keyed by module-label (e.g. self_attn.q_proj at layer0..N). + """ + ref = [r for r in module_results if r["label"] == "reference"] + if not ref: + return [] + ref_wall = statistics.median([float(r["wall_s"]) for r in ref]) + rows: list[str] = [] + + for label in [case.label for case in CASES]: + selected = [r for r in module_results if r["label"] == label] + if not selected: + continue + wall_values = [float(r["wall_s"]) for r in selected] + cache_hits = [r["cache_hit"] for r in selected if r["cache_hit"] is not None] + median_wall = statistics.median(wall_values) + rows.append( + ( + module_label, + label, + f"{median_wall:.3f}", + f"{ref_wall / median_wall:.3f}x" if median_wall > 0 else "", + f"{statistics.mean(wall_values):.3f}", + f"{sum(1 for v in cache_hits if v)}/{len(cache_hits)}" + if cache_hits else "N/A", + ) + ) + return rows + + +def main() -> int: + args = parse_args() + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + normalized_dtype = _normalize_model_dtype(args.model_dtype) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=False) + if getattr(tokenizer, "padding_side", None) != "left": + tokenizer.padding_side = "left" + if getattr(tokenizer, "pad_token_id", None) is None and getattr(tokenizer, "eos_token_id", None) is not None: + tokenizer.pad_token_id = tokenizer.eos_token_id + + model = AutoModelForCausalLM.from_pretrained( + args.model, + trust_remote_code=False, + torch_dtype=normalized_dtype, + low_cpu_mem_usage=True, + ) + model.to(device) + model.eval() + + modules = _module_list_for_layers(args.layers, DEFAULT_MODULES) + calibration_dataset = load_nm_calibration(args.calibration_rows) + max_rows = max(args.capture_rows, args.train_rows + args.val_rows) + + try: + captured = _capture_module_inputs(model, tokenizer, modules, calibration_dataset, max_rows=max_rows) + cache_mode = args.cache_strategy + + if cache_mode == "fixed_preload": + unique_cache_keys = set() + for module_name in modules: + module = _get_named_module(model, module_name) + cache_seed = int(args.seed) + unique_cache_keys.add( + ( + module.weight.shape[1], + args.group_size, + args.krot, + float(args.pair_ratio), + cache_seed, + ) + ) + _precompute_pair_cache( + modules=modules, + module_infos={}, + cache_keys=unique_cache_keys, + device=device, + ) + elif cache_mode in {"fixed", "miss"}: + _clear_random_rotation_buffers_cache() + + runs: list[dict[str, Any]] = [] + for run_idx in range(args.repeats): + for module_idx, module_name in enumerate(modules): + module = _get_named_module(model, module_name) + weight = module.weight.detach().to(device=module.weight.device, dtype=torch.float32).contiguous() + bias = module.bias.detach().to(device=module.weight.device, dtype=torch.float32).contiguous() if getattr(module, "bias", None) is not None else None + inputs = captured[module_name].to(device=module.weight.device, dtype=torch.float32).contiguous() + cache_seed = int(args.seed) + cache_key = (weight.shape[1], args.group_size, args.krot, float(args.pair_ratio), cache_seed) + + for case_idx, case in enumerate(CASES): + runs.append( + _run_case( + case=case, + weight=weight, + bias=bias, + inputs=inputs, + args=args, + run_idx=run_idx, + case_idx=case_idx, + module_name=module_name, + module_seq=module_idx, + cache_strategy=cache_mode, + cache_key=cache_key, + ) + ) + + # summarize by short module name (relative to layer block) and full case + by_rel = {} + for run in runs: + rel = ".".join(run["module"].split(".")[-2:]) + by_rel.setdefault(rel, []).append(run) + + rows = [] + for rel, rel_runs in sorted(by_rel.items()): + rows.extend(_collect_summary(module_label=rel, module_results=rel_runs)) + + print(f"\ncache_strategy={cache_mode}, runs={args.repeats}, layers={args.layers}") + print( + tabulate( + rows, + headers=["module", "case", "median_s", "vs_ref", "mean_s", "pair_cache_hit_count"], + tablefmt="grid", + ) + ) + + print(f"\nCache Hits Detail (cache_strategy={cache_mode})") + details = [] + for rel, rel_runs in sorted(by_rel.items()): + rel_pairs = [r for r in rel_runs if r["label"] in {"pair_fast", "all_fast"}] + if not rel_pairs: + continue + cold = [float(r["wall_s"]) for r in rel_pairs if r["cache_hit"] is False] + warm = [float(r["wall_s"]) for r in rel_pairs if r["cache_hit"] is True] + total = [r["cache_hit"] for r in rel_pairs if r["cache_hit"] is not None] + if total: + hit_rate = 100.0 * sum(1 for v in total if v) / len(total) + else: + hit_rate = None + details.append( + [ + rel, + cache_mode, + f"{sum(1 for v in total if v == True)}/{len(total)}" if total else "N/A", + f"{(sum(1 for v in total if v is False) / max(1, len(total)) * 100.0):.1f}%" if total else "N/A", + f"{statistics.mean(cold):.3f}" if cold else "N/A", + f"{statistics.mean(warm):.3f}" if warm else "N/A", + ] + ) + + print( + tabulate( + details, + headers=[ + "module", + "cache_strategy", + "pair_cache_hits", + "miss_rate_%", + "cold_ms (miss)", + "warm_ms (hit)", + ], + tablefmt="grid", + ) + ) + finally: + del model + if torch.cuda.is_available(): + torch.cuda.empty_cache() + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_paroquant_rotation_ab.py b/scripts/benchmark_paroquant_rotation_ab.py new file mode 100644 index 000000000..5af92da68 --- /dev/null +++ b/scripts/benchmark_paroquant_rotation_ab.py @@ -0,0 +1,334 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import math +import os +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import torch +from tabulate import tabulate + +from gptqmodel.quantization.paroquant.optimization import build_random_rotation_buffers +from gptqmodel.utils.paroquant import ( + _rotation_launch_config, + apply_paroquant_rotation, + apply_paroquant_rotation_reference, + clear_paroquant_rotation_extension_cache, + prewarm_paroquant_rotation_extension, +) + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + rows: int + hidden: int + group_size: int = 128 + krot: int = 8 + pair_ratio: float = 0.5 + + +DEFAULT_CASES: tuple[BenchCase, ...] = ( + BenchCase("decode_h2048_r1", rows=1, hidden=2048), + BenchCase("decode_h2048_r4", rows=4, hidden=2048), + BenchCase("decode_h2048_r8", rows=8, hidden=2048), + BenchCase("prefill_h2048_r128", rows=128, hidden=2048), + BenchCase("batch_h2048_r512", rows=512, hidden=2048), + BenchCase("batch_h2048_r2048", rows=2048, hidden=2048), + BenchCase("decode_h4096_r1", rows=1, hidden=4096), + BenchCase("decode_h4096_r8", rows=8, hidden=4096), + BenchCase("prefill_h4096_r128", rows=128, hidden=4096), + BenchCase("batch_h4096_r512", rows=512, hidden=4096), + BenchCase("batch_h4096_r2048", rows=2048, hidden=4096), + BenchCase("decode_h8192_r1", rows=1, hidden=8192), + BenchCase("decode_h8192_r4", rows=4, hidden=8192), + BenchCase("decode_h8192_r8", rows=8, hidden=8192), + BenchCase("prefill_h8192_r128", rows=128, hidden=8192), + BenchCase("batch_h8192_r512", rows=512, hidden=8192), + BenchCase("batch_h8192_r1024", rows=1024, hidden=8192), + BenchCase("batch_h8192_r2048", rows=2048, hidden=8192), +) + +QUICK_CASES: tuple[BenchCase, ...] = ( + BenchCase("decode_h2048_r1", rows=1, hidden=2048), + BenchCase("prefill_h2048_r128", rows=128, hidden=2048), + BenchCase("batch_h4096_r512", rows=512, hidden=4096), + BenchCase("batch_h8192_r1024", rows=1024, hidden=8192), +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Benchmark the fused ParoQuant CUDA rotation kernel.") + parser.add_argument("--device", type=int, default=0, help="CUDA device index within the current visible set.") + parser.add_argument("--dtype", choices=("fp16", "bf16", "fp32"), default="fp16") + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--seed", type=int, default=0) + parser.add_argument("--quick", action="store_true") + parser.add_argument("--shard-index", type=int, default=0) + parser.add_argument("--num-shards", type=int, default=1) + parser.add_argument("--json-out", type=Path, default=None) + parser.add_argument("--json", action="store_true", help="Print the full payload as JSON.") + parser.add_argument("--force-rebuild-extension", action="store_true") + return parser.parse_args() + + +def _resolve_dtype(name: str) -> torch.dtype: + if name == "fp16": + return torch.float16 + if name == "bf16": + return torch.bfloat16 + if name == "fp32": + return torch.float32 + raise ValueError(f"Unsupported dtype: {name}") + + +def _torch_device_name(device: torch.device) -> str: + if device.type != "cuda": + return "cpu" + return torch.cuda.get_device_name(device) + + +def _subset_cases(cases: tuple[BenchCase, ...], shard_index: int, num_shards: int) -> list[BenchCase]: + if num_shards <= 0: + raise ValueError("`num_shards` must be positive.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError(f"`shard_index` must be in [0, {num_shards - 1}].") + return [case for idx, case in enumerate(cases) if idx % num_shards == shard_index] + + +def _rotation_bandwidth_gbps(case: BenchCase, dtype: torch.dtype, elapsed_ms: float) -> float: + element_size = torch.tensor([], dtype=dtype).element_size() + total_bytes = case.rows * case.hidden * element_size * 2 + return total_bytes / (elapsed_ms * 1e-3) / 1e9 + + +def _benchmark_ms(fn, device: torch.device, warmup: int, iters: int) -> float: + with torch.inference_mode(): + for _ in range(warmup): + fn() + torch.cuda.synchronize(device) + start = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize(device) + return (time.perf_counter() - start) * 1e3 / iters + + +def _make_case_inputs(case: BenchCase, dtype: torch.dtype, device: torch.device, seed: int) -> dict[str, torch.Tensor]: + generator = torch.Generator(device="cpu") + generator.manual_seed(seed) + + x = torch.randn((case.rows, case.hidden), generator=generator, dtype=torch.float32).to(device=device, dtype=dtype) + pairs, _mask = build_random_rotation_buffers( + in_features=case.hidden, + group_size=case.group_size, + krot=case.krot, + pair_ratio=case.pair_ratio, + seed=seed, + device=device, + ) + theta = torch.empty((case.krot, case.hidden // 2), dtype=torch.float32, device="cpu") + theta.uniform_(-0.25, 0.25, generator=generator) + theta = theta.to(device=device, dtype=dtype) + scales = torch.empty((1, case.hidden), dtype=torch.float32, device="cpu") + scales.uniform_(0.75, 1.25, generator=generator) + scales = scales.to(device=device, dtype=dtype) + return { + "x": x.contiguous(), + "pairs": pairs.contiguous(), + "theta": theta.contiguous(), + "scales": scales.contiguous(), + } + + +def run(device: torch.device, dtype: torch.dtype, warmup: int, iters: int, quick: bool, seed: int, shard_index: int, num_shards: int) -> dict[str, Any]: + if device.type != "cuda": + raise RuntimeError("CUDA is required for the ParoQuant rotation benchmark.") + + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + selected_cases = _subset_cases(QUICK_CASES if quick else DEFAULT_CASES, shard_index=shard_index, num_shards=num_shards) + rows: list[dict[str, Any]] = [] + + for case_index, case in enumerate(selected_cases): + case_seed = seed + (case_index * 17) + inputs = _make_case_inputs(case, dtype=dtype, device=device, seed=case_seed) + x = inputs["x"] + pairs = inputs["pairs"] + theta = inputs["theta"] + scales = inputs["scales"] + + with torch.inference_mode(): + fused = apply_paroquant_rotation(x, pairs, theta, scales=scales, group_size=case.group_size) + reference = apply_paroquant_rotation_reference(x, pairs, theta, scales=scales, group_size=case.group_size) + cta_m, row_pad = _rotation_launch_config(x, pairs, theta, scales=scales, group_size=case.group_size) + + diff = (fused - reference).abs() + fp32_metrics = { + "fused_fp32_max_abs": None, + "fused_fp32_mean_abs": None, + "reference_fp32_max_abs": None, + "reference_fp32_mean_abs": None, + } + if dtype != torch.float32: + fp32_inputs = _make_case_inputs(case, dtype=torch.float32, device=device, seed=case_seed) + with torch.inference_mode(): + reference_fp32 = apply_paroquant_rotation_reference( + fp32_inputs["x"], + fp32_inputs["pairs"], + fp32_inputs["theta"], + scales=fp32_inputs["scales"], + group_size=case.group_size, + ) + fused_fp32_diff = (fused.float() - reference_fp32).abs() + reference_fp32_diff = (reference.float() - reference_fp32).abs() + fp32_metrics = { + "fused_fp32_max_abs": fused_fp32_diff.max().item(), + "fused_fp32_mean_abs": fused_fp32_diff.mean().item(), + "reference_fp32_max_abs": reference_fp32_diff.max().item(), + "reference_fp32_mean_abs": reference_fp32_diff.mean().item(), + } + elapsed_ms = _benchmark_ms( + lambda: apply_paroquant_rotation(x, pairs, theta, scales=scales, group_size=case.group_size), + device=device, + warmup=warmup, + iters=iters, + ) + rows.append( + { + **asdict(case), + "dtype": str(dtype).replace("torch.", ""), + "cta_m": cta_m, + "row_pad": row_pad, + "latency_ms": elapsed_ms, + "gbps": _rotation_bandwidth_gbps(case, dtype, elapsed_ms), + "max_abs": diff.max().item(), + "mean_abs": diff.mean().item(), + **fp32_metrics, + } + ) + + geo_mean_ms = math.exp(sum(math.log(row["latency_ms"]) for row in rows) / len(rows)) if rows else float("nan") + geo_mean_gbps = math.exp(sum(math.log(row["gbps"]) for row in rows) / len(rows)) if rows else float("nan") + + return { + "device": _torch_device_name(device), + "cuda_device": str(device), + "dtype": str(dtype).replace("torch.", ""), + "warmup": warmup, + "iters": iters, + "seed": seed, + "quick": quick, + "shard_index": shard_index, + "num_shards": num_shards, + "geo_mean_ms": geo_mean_ms, + "geo_mean_gbps": geo_mean_gbps, + "rows": rows, + } + + +def _print_ascii(results: dict[str, Any]) -> None: + print(f"Device: {results['device']} ({results['cuda_device']})") + print( + tabulate( + [ + [ + row["case_id"], + row["rows"], + row["hidden"], + row["dtype"], + row["cta_m"], + row["row_pad"], + f"{row['latency_ms']:.3f}", + f"{row['gbps']:.1f}", + f"{row['max_abs']:.6f}", + f"{row['mean_abs']:.6f}", + "-" if row["fused_fp32_max_abs"] is None else f"{row['fused_fp32_max_abs']:.6f}", + "-" if row["fused_fp32_mean_abs"] is None else f"{row['fused_fp32_mean_abs']:.6f}", + "-" if row["reference_fp32_max_abs"] is None else f"{row['reference_fp32_max_abs']:.6f}", + "-" if row["reference_fp32_mean_abs"] is None else f"{row['reference_fp32_mean_abs']:.6f}", + ] + for row in results["rows"] + ], + headers=[ + "case", + "rows", + "hidden", + "dtype", + "cta_m", + "row_pad", + "latency_ms", + "gbps", + "fused vs ref max_abs", + "fused vs ref mean_abs", + "fused vs fp32 max_abs", + "fused vs fp32 mean_abs", + "ref vs fp32 max_abs", + "ref vs fp32 mean_abs", + ], + tablefmt="plain", + ) + ) + print(f"geo_mean_ms {results['geo_mean_ms']:.3f}") + print(f"geo_mean_gbps {results['geo_mean_gbps']:.1f}") + + +def main() -> int: + args = parse_args() + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the ParoQuant rotation benchmark.") + + device = torch.device(f"cuda:{args.device}") + if args.force_rebuild_extension: + build_root = Path("/tmp") / ( + f"paroquant_ext_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_PAROQUANT_BUILD_ROOT"] = str(build_root) + os.environ["GPTQMODEL_PAROQUANT_FORCE_REBUILD"] = "1" + clear_paroquant_rotation_extension_cache() + else: + os.environ.pop("GPTQMODEL_PAROQUANT_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_PAROQUANT_FORCE_REBUILD", None) + + if not prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=device, + ): + raise RuntimeError("Failed to build/load the fused ParoQuant CUDA rotation extension.") + + results = run( + device=device, + dtype=_resolve_dtype(args.dtype), + warmup=args.warmup, + iters=args.iters, + quick=args.quick, + seed=args.seed, + shard_index=args.shard_index, + num_shards=args.num_shards, + ) + + _print_ascii(results) + if args.json: + print(json.dumps(results, indent=2)) + if args.json_out is not None: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + args.json_out.write_text(json.dumps(results, indent=2), encoding="utf-8") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_paroquant_rotation_cache_ab.py b/scripts/benchmark_paroquant_rotation_cache_ab.py new file mode 100644 index 000000000..f85a8e399 --- /dev/null +++ b/scripts/benchmark_paroquant_rotation_cache_ab.py @@ -0,0 +1,417 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import math +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from tabulate import tabulate + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + batch: int + seq: int + in_features: int + out_features: int + group_size: int = 128 + krot: int = 8 + + +DEFAULT_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_q_proj", batch=4, seq=128, in_features=2048, out_features=2048), + BenchCase("prefill_k_proj", batch=1, seq=128, in_features=2048, out_features=512), + BenchCase("prefill_gate_proj", batch=1, seq=128, in_features=2048, out_features=8192), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + +QUICK_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + + +def _resolve_dtype(name: str) -> torch.dtype: + if name == "fp16": + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype: {name}") + + +def _subset_cases(cases: list[BenchCase], shard_index: int, num_shards: int) -> list[BenchCase]: + if num_shards <= 0: + raise ValueError("`num_shards` must be positive.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError(f"`shard_index` must be in [0, {num_shards - 1}].") + return [case for index, case in enumerate(cases) if index % num_shards == shard_index] + + +def _benchmark_ms(fn, device: torch.device, warmup: int, iters: int) -> float: + with torch.inference_mode(): + for _ in range(warmup): + fn() + torch.cuda.synchronize(device) + start = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize(device) + return (time.perf_counter() - start) * 1e3 / iters + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + packed = torch.zeros((unpacked.shape[0], unpacked.shape[1] // pack_factor), dtype=torch.int32) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_quant_buffers(case: BenchCase, dtype: torch.dtype, bits: int = 4) -> dict[str, torch.Tensor]: + from gptqmodel.utils.paroquant import build_identity_rotation_buffers + + groups = case.in_features // case.group_size + int_weight = torch.randint(0, 2**bits, size=(case.in_features, case.out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, case.out_features), dtype=torch.int32) + scales = (torch.rand(groups, case.out_features, dtype=torch.float32) * 0.5) + 0.75 + scales = scales.to(dtype=dtype) + bias = torch.randn(case.out_features, dtype=torch.float32).to(dtype=dtype) + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=case.in_features, + group_size=case.group_size, + krot=case.krot, + dtype=dtype, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + + return { + "qweight": _pack_awq_tensor(int_weight, bits), + "qzeros": _pack_awq_tensor(zero_points, bits), + "scales": scales, + "bias": bias, + "pairs": pairs, + "theta": theta, + "channel_scales": channel_scales, + } + + +def _make_module( + case: BenchCase, + dtype: torch.dtype, + device: torch.device, + buffers: dict[str, torch.Tensor], + auto_cache_bf16_rotation_dtype: bool, +): + from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear + + module = ParoLinear( + bits=4, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=True, + register_buffers=True, + krot=case.krot, + cache_runtime_dtype=False, + auto_cache_bf16_runtime_dtype=True, + cache_rotation_dtype=False, + auto_cache_bf16_rotation_dtype=auto_cache_bf16_rotation_dtype, + ).to(device) + module.qweight.copy_(buffers["qweight"].to(device)) + module.qzeros.copy_(buffers["qzeros"].to(device)) + module.scales.copy_(buffers["scales"].to(device)) + module.bias.copy_(buffers["bias"].to(device)) + module.pairs.copy_(buffers["pairs"].to(device)) + module.theta.copy_(buffers["theta"].to(device=device, dtype=module.theta.dtype)) + module.channel_scales.copy_(buffers["channel_scales"].to(device=device, dtype=module.channel_scales.dtype)) + module.post_init() + module.eval() + return module + + +def _dense_reference(module, x: torch.Tensor) -> torch.Tensor: + from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm + from gptqmodel.utils.paroquant import apply_paroquant_rotation_reference + + rotated = apply_paroquant_rotation_reference( + x, + module.pairs, + module.theta, + scales=module.channel_scales, + group_size=module.group_size, + ) + weight = dequantize_gemm( + qweight=module.qweight, + qzeros=module.qzeros, + scales=module.scales, + bits=module.bits, + group_size=module.group_size, + ).to(device=x.device, dtype=x.dtype) + out = torch.matmul(rotated, weight) + if module.bias is not None: + out = out + module.bias.to(device=x.device, dtype=x.dtype) + return out + + +def _format_speedup(speedup: float) -> str: + return f"{speedup:.3f}x" + + +def run( + device: torch.device, + dtype: torch.dtype, + warmup: int, + iters: int, + quick: bool, + shard_index: int, + num_shards: int, +) -> dict[str, Any]: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + rows = [] + speedups = [] + selected_cases = _subset_cases(QUICK_CASES if quick else DEFAULT_CASES, shard_index=shard_index, num_shards=num_shards) + + for index, case in enumerate(selected_cases): + torch.manual_seed(5000 + index) + buffers = _make_quant_buffers(case, dtype=dtype) + baseline = _make_module( + case, + dtype=dtype, + device=device, + buffers=buffers, + auto_cache_bf16_rotation_dtype=False, + ) + candidate = _make_module( + case, + dtype=dtype, + device=device, + buffers=buffers, + auto_cache_bf16_rotation_dtype=True, + ) + x = torch.randn((case.batch, case.seq, case.in_features), device=device, dtype=dtype) + + with torch.inference_mode(): + dense = _dense_reference(baseline, x.reshape(-1, x.shape[-1])).reshape(case.batch, case.seq, case.out_features) + baseline_out = baseline(x) + candidate_out = candidate(x) + + baseline_dense = (baseline_out - dense).abs() + candidate_dense = (candidate_out - dense).abs() + baseline_candidate = (baseline_out - candidate_out).abs() + + baseline_ms = _benchmark_ms(lambda: baseline(x), device=device, warmup=warmup, iters=iters) + candidate_ms = _benchmark_ms(lambda: candidate(x), device=device, warmup=warmup, iters=iters) + speedup = baseline_ms / candidate_ms + speedups.append(speedup) + + rows.append( + { + "case_id": case.case_id, + "batch": case.batch, + "seq": case.seq, + "in_features": case.in_features, + "out_features": case.out_features, + "dtype": str(dtype).replace("torch.", ""), + "baseline_ms": baseline_ms, + "candidate_ms": candidate_ms, + "speedup": speedup, + "winner": "rotation_cache_on" if candidate_ms < baseline_ms else "rotation_cache_off", + "baseline_dense_max_abs": baseline_dense.max().item(), + "candidate_dense_max_abs": candidate_dense.max().item(), + "baseline_candidate_max_abs": baseline_candidate.max().item(), + "baseline_candidate_mean_abs": baseline_candidate.mean().item(), + } + ) + + geo_mean_speedup = math.exp(sum(math.log(v) for v in speedups) / len(speedups)) if speedups else float("nan") + return { + "device": torch.cuda.get_device_name(device), + "cuda_device": str(device), + "dtype": str(dtype).replace("torch.", ""), + "warmup": warmup, + "iters": iters, + "quick": quick, + "shard_index": shard_index, + "num_shards": num_shards, + "rows": rows, + "geo_mean_speedup": geo_mean_speedup, + "candidate_wins": sum(1 for v in speedups if v > 1.0), + "case_count": len(speedups), + } + + +def _configure_runtime(args: argparse.Namespace, device: torch.device) -> None: + if args.force_rebuild_awq: + awq_build_root = Path("/tmp") / ( + f"awq_jit_rotcache_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_AWQ_BUILD_ROOT"] = str(awq_build_root) + os.environ["GPTQMODEL_AWQ_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_AWQ_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_AWQ_FORCE_REBUILD", None) + + if args.force_rebuild_paroquant: + paro_build_root = Path("/tmp") / ( + f"paroquant_ext_rotcache_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_PAROQUANT_BUILD_ROOT"] = str(paro_build_root) + os.environ["GPTQMODEL_PAROQUANT_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_PAROQUANT_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_PAROQUANT_FORCE_REBUILD", None) + + from gptqmodel.utils.awq import clear_awq_extension_cache, prewarm_awq_extension + from gptqmodel.utils.paroquant import clear_paroquant_rotation_extension_cache, prewarm_paroquant_rotation_extension + + if args.force_rebuild_awq: + clear_awq_extension_cache() + if args.force_rebuild_paroquant: + clear_paroquant_rotation_extension_cache() + + if not prewarm_awq_extension(): + raise RuntimeError("Failed to build/load the AWQ CUDA runtime.") + if not prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=device, + ): + raise RuntimeError("Failed to build/load the ParoQuant CUDA rotation runtime.") + + +def _print_ascii(results: dict[str, Any]) -> None: + print(f"Device: {results['device']} ({results['cuda_device']}, dtype={results['dtype']})") + print() + print("Accuracy") + print( + tabulate( + [ + [ + row["dtype"], + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_dense_max_abs']:.6f}", + f"{row['candidate_dense_max_abs']:.6f}", + f"{row['baseline_candidate_max_abs']:.6f}", + f"{row['baseline_candidate_mean_abs']:.6f}", + ] + for row in results["rows"] + ], + headers=[ + "dtype", + "case", + "batch x seq", + "shape", + "rotation_cache_off vs dense max_abs", + "rotation_cache_on vs dense max_abs", + "off vs on max_abs", + "off vs on mean_abs", + ], + tablefmt="plain", + ) + ) + print() + print("Benchmark") + print( + tabulate( + [ + [ + row["dtype"], + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_ms']:.3f}", + f"{row['candidate_ms']:.3f}", + _format_speedup(row["speedup"]), + row["winner"], + ] + for row in results["rows"] + ], + headers=[ + "dtype", + "case", + "batch x seq", + "shape", + "rotation_cache_off ms", + "rotation_cache_on ms", + "speedup", + "winner", + ], + tablefmt="plain", + ) + ) + print() + print( + "Summary: " + f"candidate_wins={results['candidate_wins']}/{results['case_count']}, " + f"geo_mean_speedup={results['geo_mean_speedup']:.3f}x" + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description="A/B benchmark ParoQuant BF16 rotation-metadata caching.") + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--dtype", choices=("fp16", "bf16"), default="fp16") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--quick", action="store_true") + parser.add_argument("--shard-index", type=int, default=0) + parser.add_argument("--num-shards", type=int, default=1) + parser.add_argument("--json-out", type=Path, default=None) + parser.add_argument("--json", action="store_true") + parser.add_argument("--force-rebuild-awq", action="store_true") + parser.add_argument("--force-rebuild-paroquant", action="store_true") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the ParoQuant rotation-cache benchmark.") + + device = torch.device(f"cuda:{args.device}") + _configure_runtime(args, device) + results = run( + device=device, + dtype=_resolve_dtype(args.dtype), + warmup=args.warmup, + iters=args.iters, + quick=args.quick, + shard_index=args.shard_index, + num_shards=args.num_shards, + ) + _print_ascii(results) + + if args.json_out is not None: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + args.json_out.write_text(json.dumps(results, indent=2), encoding="utf-8") + if args.json: + print(json.dumps(results, indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_paroquant_runtime_cache_ab.py b/scripts/benchmark_paroquant_runtime_cache_ab.py new file mode 100644 index 000000000..26e5c5a09 --- /dev/null +++ b/scripts/benchmark_paroquant_runtime_cache_ab.py @@ -0,0 +1,431 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import math +import os +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import torch +from tabulate import tabulate + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + batch: int + seq: int + in_features: int + out_features: int + group_size: int = 128 + krot: int = 8 + + +DEFAULT_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_q_proj", batch=4, seq=128, in_features=2048, out_features=2048), + BenchCase("prefill_k_proj", batch=1, seq=128, in_features=2048, out_features=512), + BenchCase("prefill_gate_proj", batch=1, seq=128, in_features=2048, out_features=8192), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + +QUICK_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + + +def _resolve_dtype(name: str) -> torch.dtype: + if name == "fp16": + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype: {name}") + + +def _subset_cases(cases: list[BenchCase], shard_index: int, num_shards: int) -> list[BenchCase]: + if num_shards <= 0: + raise ValueError("`num_shards` must be positive.") + if shard_index < 0 or shard_index >= num_shards: + raise ValueError(f"`shard_index` must be in [0, {num_shards - 1}].") + return [case for index, case in enumerate(cases) if index % num_shards == shard_index] + + +def _benchmark_ms(fn, device: torch.device, warmup: int, iters: int) -> float: + with torch.inference_mode(): + for _ in range(warmup): + fn() + torch.cuda.synchronize(device) + start = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize(device) + return (time.perf_counter() - start) * 1e3 / iters + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + packed = torch.zeros((unpacked.shape[0], unpacked.shape[1] // pack_factor), dtype=torch.int32) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_quant_buffers(case: BenchCase, dtype: torch.dtype, bits: int = 4) -> dict[str, torch.Tensor]: + from gptqmodel.utils.paroquant import build_identity_rotation_buffers + + groups = case.in_features // case.group_size + int_weight = torch.randint(0, 2**bits, size=(case.in_features, case.out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, case.out_features), dtype=torch.int32) + scales = (torch.rand(groups, case.out_features, dtype=torch.float32) * 0.5) + 0.75 + scales = scales.to(dtype=dtype) + bias = torch.randn(case.out_features, dtype=torch.float32).to(dtype=dtype) + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=case.in_features, + group_size=case.group_size, + krot=case.krot, + dtype=dtype, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + + return { + "qweight": _pack_awq_tensor(int_weight, bits), + "qzeros": _pack_awq_tensor(zero_points, bits), + "scales": scales, + "bias": bias, + "pairs": pairs, + "theta": theta, + "channel_scales": channel_scales, + } + + +def _clone_module(module): + cloned = type(module)( + bits=module.bits, + group_size=module.group_size, + sym=module.sym, + desc_act=module.desc_act, + in_features=module.in_features, + out_features=module.out_features, + bias=module.bias is not None, + pack_dtype=module.pack_dtype, + register_buffers=True, + krot=module.krot, + fp32_accum=module.fp32_accum, + cache_runtime_dtype=module.cache_runtime_dtype, + auto_cache_bf16_runtime_dtype=module.auto_cache_bf16_runtime_dtype, + cache_rotation_dtype=module.cache_rotation_dtype, + auto_cache_bf16_rotation_dtype=module.auto_cache_bf16_rotation_dtype, + ).to(module.qweight.device) + cloned.qweight.copy_(module.qweight) + cloned.qzeros.copy_(module.qzeros) + cloned.scales.copy_(module.scales) + if module.bias is not None: + cloned.bias.copy_(module.bias) + cloned.pairs.copy_(module.pairs) + cloned.theta.copy_(module.theta) + cloned.channel_scales.copy_(module.channel_scales) + cloned.post_init() + cloned.eval() + return cloned + + +def _make_module(case: BenchCase, dtype: torch.dtype, device: torch.device, cache_runtime_dtype: bool): + from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear + + buffers = _make_quant_buffers(case, dtype=dtype) + module = ParoLinear( + bits=4, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=True, + register_buffers=True, + krot=case.krot, + cache_runtime_dtype=cache_runtime_dtype, + auto_cache_bf16_runtime_dtype=cache_runtime_dtype, + cache_rotation_dtype=False, + auto_cache_bf16_rotation_dtype=False, + ).to(device) + module.qweight.copy_(buffers["qweight"].to(device)) + module.qzeros.copy_(buffers["qzeros"].to(device)) + module.scales.copy_(buffers["scales"].to(device)) + module.bias.copy_(buffers["bias"].to(device)) + module.pairs.copy_(buffers["pairs"].to(device)) + module.theta.copy_(buffers["theta"].to(device)) + module.channel_scales.copy_(buffers["channel_scales"].to(device)) + module.post_init() + module.eval() + return module + + +def _dense_reference(module, x: torch.Tensor) -> torch.Tensor: + from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm + from gptqmodel.utils.paroquant import apply_paroquant_rotation_reference + + rotated = apply_paroquant_rotation_reference( + x, + module.pairs, + module.theta, + scales=module.channel_scales, + group_size=module.group_size, + ) + weight = dequantize_gemm( + qweight=module.qweight, + qzeros=module.qzeros, + scales=module.scales, + bits=module.bits, + group_size=module.group_size, + ).to(device=x.device, dtype=x.dtype) + out = torch.matmul(rotated, weight) + if module.bias is not None: + out = out + module.bias.to(device=x.device, dtype=x.dtype) + return out + + +def _format_speedup(speedup: float) -> str: + return f"{speedup:.3f}x" + + +def run( + device: torch.device, + dtype: torch.dtype, + warmup: int, + iters: int, + quick: bool, + shard_index: int, + num_shards: int, +) -> dict[str, Any]: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + rows = [] + speedups = [] + selected_cases = _subset_cases(QUICK_CASES if quick else DEFAULT_CASES, shard_index=shard_index, num_shards=num_shards) + + for index, case in enumerate(selected_cases): + torch.manual_seed(4000 + index) + baseline = _make_module(case, dtype=dtype, device=device, cache_runtime_dtype=False) + candidate = _clone_module(baseline) + candidate.cache_runtime_dtype = True + x = torch.randn((case.batch, case.seq, case.in_features), device=device, dtype=dtype) + + with torch.inference_mode(): + dense = _dense_reference(baseline, x.reshape(-1, x.shape[-1])).reshape(case.batch, case.seq, case.out_features) + baseline_out = baseline(x) + candidate_out = candidate(x) + + baseline_dense = (baseline_out - dense).abs() + candidate_dense = (candidate_out - dense).abs() + baseline_candidate = (baseline_out - candidate_out).abs() + + baseline_ms = _benchmark_ms(lambda: baseline(x), device=device, warmup=warmup, iters=iters) + candidate_ms = _benchmark_ms(lambda: candidate(x), device=device, warmup=warmup, iters=iters) + speedup = baseline_ms / candidate_ms + speedups.append(speedup) + + rows.append( + { + "case_id": case.case_id, + "batch": case.batch, + "seq": case.seq, + "in_features": case.in_features, + "out_features": case.out_features, + "dtype": str(dtype).replace("torch.", ""), + "baseline_ms": baseline_ms, + "candidate_ms": candidate_ms, + "speedup": speedup, + "winner": "cache_on" if candidate_ms < baseline_ms else "cache_off", + "baseline_dense_max_abs": baseline_dense.max().item(), + "candidate_dense_max_abs": candidate_dense.max().item(), + "baseline_candidate_max_abs": baseline_candidate.max().item(), + "baseline_candidate_mean_abs": baseline_candidate.mean().item(), + } + ) + + geo_mean_speedup = math.exp(sum(math.log(v) for v in speedups) / len(speedups)) if speedups else float("nan") + return { + "device": torch.cuda.get_device_name(device), + "cuda_device": str(device), + "dtype": str(dtype).replace("torch.", ""), + "warmup": warmup, + "iters": iters, + "quick": quick, + "shard_index": shard_index, + "num_shards": num_shards, + "rows": rows, + "geo_mean_speedup": geo_mean_speedup, + "candidate_wins": sum(1 for v in speedups if v > 1.0), + "case_count": len(speedups), + } + + +def _configure_runtime(args: argparse.Namespace, device: torch.device) -> None: + if args.force_rebuild_awq: + awq_build_root = Path("/tmp") / ( + f"awq_jit_runtimecache_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_AWQ_BUILD_ROOT"] = str(awq_build_root) + os.environ["GPTQMODEL_AWQ_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_AWQ_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_AWQ_FORCE_REBUILD", None) + + if args.force_rebuild_paroquant: + paro_build_root = Path("/tmp") / ( + f"paroquant_ext_runtimecache_{os.getpid()}_dev{args.device}_shard{args.shard_index}_of_{args.num_shards}" + ) + os.environ["GPTQMODEL_PAROQUANT_BUILD_ROOT"] = str(paro_build_root) + os.environ["GPTQMODEL_PAROQUANT_FORCE_REBUILD"] = "1" + else: + os.environ.pop("GPTQMODEL_PAROQUANT_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_PAROQUANT_FORCE_REBUILD", None) + + from gptqmodel.utils.awq import clear_awq_extension_cache, prewarm_awq_extension + from gptqmodel.utils.paroquant import clear_paroquant_rotation_extension_cache, prewarm_paroquant_rotation_extension + + if args.force_rebuild_awq: + clear_awq_extension_cache() + if args.force_rebuild_paroquant: + clear_paroquant_rotation_extension_cache() + + if not prewarm_awq_extension(): + raise RuntimeError("Failed to build/load the AWQ CUDA runtime.") + if not prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=device, + ): + raise RuntimeError("Failed to build/load the ParoQuant CUDA rotation runtime.") + + +def _print_ascii(results: dict[str, Any]) -> None: + print(f"Device: {results['device']} ({results['cuda_device']}, dtype={results['dtype']})") + print() + print("Accuracy") + print( + tabulate( + [ + [ + row["dtype"], + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_dense_max_abs']:.6f}", + f"{row['candidate_dense_max_abs']:.6f}", + f"{row['baseline_candidate_max_abs']:.6f}", + f"{row['baseline_candidate_mean_abs']:.6f}", + ] + for row in results["rows"] + ], + headers=[ + "dtype", + "case", + "batch x seq", + "shape", + "cache_off vs dense max_abs", + "cache_on vs dense max_abs", + "off vs on max_abs", + "off vs on mean_abs", + ], + tablefmt="plain", + ) + ) + print() + print("Benchmark") + print( + tabulate( + [ + [ + row["dtype"], + row["case_id"], + f"{row['batch']}x{row['seq']}", + f"{row['in_features']}->{row['out_features']}", + f"{row['baseline_ms']:.3f}", + f"{row['candidate_ms']:.3f}", + _format_speedup(row["speedup"]), + row["winner"], + ] + for row in results["rows"] + ], + headers=[ + "dtype", + "case", + "batch x seq", + "shape", + "cache_off ms", + "cache_on ms", + "speedup", + "winner", + ], + tablefmt="plain", + ) + ) + print() + print( + "Summary: " + f"candidate_wins={results['candidate_wins']}/{results['case_count']}, " + f"geo_mean_speedup={results['geo_mean_speedup']:.3f}x" + ) + + +def main() -> int: + parser = argparse.ArgumentParser(description="A/B benchmark ParoQuant runtime-dtype caching.") + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--dtype", choices=("fp16", "bf16"), default="fp16") + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--iters", type=int, default=20) + parser.add_argument("--quick", action="store_true") + parser.add_argument("--shard-index", type=int, default=0) + parser.add_argument("--num-shards", type=int, default=1) + parser.add_argument("--json-out", type=Path, default=None) + parser.add_argument("--json", action="store_true") + parser.add_argument("--force-rebuild-awq", action="store_true") + parser.add_argument("--force-rebuild-paroquant", action="store_true") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the ParoQuant runtime cache benchmark.") + + device = torch.device(f"cuda:{args.device}") + _configure_runtime(args, device) + results = run( + device=device, + dtype=_resolve_dtype(args.dtype), + warmup=args.warmup, + iters=args.iters, + quick=args.quick, + shard_index=args.shard_index, + num_shards=args.num_shards, + ) + _print_ascii(results) + + if args.json_out is not None: + args.json_out.parent.mkdir(parents=True, exist_ok=True) + args.json_out.write_text(json.dumps(results, indent=2), encoding="utf-8") + if args.json: + print(json.dumps(results, indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_paroquant_triton_ab.py b/scripts/benchmark_paroquant_triton_ab.py new file mode 100644 index 000000000..54d45d410 --- /dev/null +++ b/scripts/benchmark_paroquant_triton_ab.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import math +import time +from dataclasses import asdict, dataclass +from typing import Any + +import torch +from tabulate import tabulate + +from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear +from gptqmodel.nn_modules.qlinear.paroquant_triton import ParoQuantTritonLinear +from gptqmodel.utils.paroquant import build_identity_rotation_buffers + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + batch: int + seq: int + in_features: int + out_features: int + group_size: int = 128 + krot: int = 8 + + +DEFAULT_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_q_proj", batch=4, seq=128, in_features=2048, out_features=2048), + BenchCase("prefill_k_proj", batch=1, seq=128, in_features=2048, out_features=512), + BenchCase("prefill_gate_proj", batch=1, seq=128, in_features=2048, out_features=8192), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + +QUICK_CASES = [ + BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +] + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_quant_buffers(case: BenchCase, bits: int = 4) -> dict[str, torch.Tensor]: + groups = case.in_features // case.group_size + int_weight = torch.randint(0, 2**bits, size=(case.in_features, case.out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, case.out_features), dtype=torch.int32) + scales = (torch.rand(groups, case.out_features, dtype=torch.float16) * 0.5) + 0.75 + bias = torch.randn(case.out_features, dtype=torch.float16) + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=case.in_features, + group_size=case.group_size, + krot=case.krot, + dtype=torch.float16, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + + return { + "qweight": _pack_awq_tensor(int_weight, bits), + "qzeros": _pack_awq_tensor(zero_points, bits), + "scales": scales, + "bias": bias, + "pairs": pairs, + "theta": theta, + "channel_scales": channel_scales, + } + + +def _build_module( + module_cls, + case: BenchCase, + buffers: dict[str, torch.Tensor], + device: torch.device, + bits: int = 4, +): + module = module_cls( + bits=bits, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=True, + register_buffers=True, + krot=case.krot, + ).to(device) + module.qweight.copy_(buffers["qweight"].to(device)) + module.qzeros.copy_(buffers["qzeros"].to(device)) + module.scales.copy_(buffers["scales"].to(device)) + module.bias.copy_(buffers["bias"].to(device)) + module.pairs.copy_(buffers["pairs"].to(device)) + module.theta.copy_(buffers["theta"].to(device)) + module.channel_scales.copy_(buffers["channel_scales"].to(device)) + module.post_init() + module.eval() + return module + + +def _dense_reference(module: ParoLinear, x: torch.Tensor) -> torch.Tensor: + with torch.inference_mode(): + x_flat = x.reshape(-1, x.shape[-1]) + rotated = module._rotate_inputs(x_flat) + out = module._forward_dense(rotated) + return out.reshape(x.shape[:-1] + (module.out_features,)) + + +def _benchmark_ms(module, x: torch.Tensor, warmup: int, iters: int) -> float: + with torch.inference_mode(): + for _ in range(warmup): + module(x) + torch.cuda.synchronize(x.device) + start = time.perf_counter() + for _ in range(iters): + module(x) + torch.cuda.synchronize(x.device) + return (time.perf_counter() - start) * 1e3 / iters + + +def _format_speedup(speedup: float) -> str: + return f"{speedup:.3f}x" + + +def run(device: torch.device, warmup: int, iters: int, quick: bool) -> dict[str, Any]: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + cases = QUICK_CASES if quick else DEFAULT_CASES + accuracy_rows = [] + benchmark_rows = [] + speedups = [] + + for index, case in enumerate(cases): + torch.manual_seed(1000 + index) + buffers = _make_quant_buffers(case) + baseline = _build_module(ParoLinear, case, buffers, device) + candidate = _build_module(ParoQuantTritonLinear, case, buffers, device) + + x = torch.randn((case.batch, case.seq, case.in_features), device=device, dtype=torch.float16) + + with torch.inference_mode(): + dense = _dense_reference(baseline, x) + baseline_out = baseline(x) + candidate_out = candidate(x) + + baseline_dense = (baseline_out - dense).abs() + candidate_dense = (candidate_out - dense).abs() + baseline_candidate = (baseline_out - candidate_out).abs() + + accuracy_rows.append( + [ + case.case_id, + f"{case.batch}x{case.seq}", + f"{case.in_features}->{case.out_features}", + f"{baseline_dense.max().item():.6f}", + f"{candidate_dense.max().item():.6f}", + f"{baseline_candidate.max().item():.6f}", + f"{baseline_candidate.mean().item():.6f}", + ] + ) + + baseline_ms = _benchmark_ms(baseline, x, warmup=warmup, iters=iters) + candidate_ms = _benchmark_ms(candidate, x, warmup=warmup, iters=iters) + speedup = baseline_ms / candidate_ms + speedups.append(speedup) + winner = "triton" if candidate_ms < baseline_ms else "existing" + + benchmark_rows.append( + [ + case.case_id, + f"{case.batch}x{case.seq}", + f"{case.in_features}->{case.out_features}", + f"{baseline_ms:.3f}", + f"{candidate_ms:.3f}", + _format_speedup(speedup), + winner, + ] + ) + + geo_mean_speedup = math.exp(sum(math.log(v) for v in speedups) / len(speedups)) + triton_wins = sum(1 for value in speedups if value > 1.0) + + return { + "device": torch.cuda.get_device_name(device), + "cuda_device": str(device), + "warmup": warmup, + "iters": iters, + "quick": quick, + "accuracy_headers": [ + "case", + "batch x seq", + "matmul", + "existing vs dense max_abs", + "triton vs dense max_abs", + "existing vs triton max_abs", + "existing vs triton mean_abs", + ], + "accuracy_rows": accuracy_rows, + "benchmark_headers": [ + "case", + "batch x seq", + "matmul", + "existing ms", + "triton ms", + "speedup", + "winner", + ], + "benchmark_rows": benchmark_rows, + "geo_mean_speedup": geo_mean_speedup, + "triton_wins": triton_wins, + "case_count": len(speedups), + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="A/B benchmark ParoQuant Triton kernel against the existing kernel.") + parser.add_argument("--device", type=int, default=0, help="CUDA device index within the current visible set.") + parser.add_argument("--warmup", type=int, default=5, help="Warmup iterations per case.") + parser.add_argument("--iters", type=int, default=20, help="Measured iterations per case.") + parser.add_argument("--quick", action="store_true", help="Run a smaller subset of benchmark cases.") + parser.add_argument("--json", action="store_true", help="Also emit the full result payload as JSON.") + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for the ParoQuant Triton benchmark.") + + try: + import triton # noqa: F401 + except Exception as exc: # pragma: no cover - environment dependent + raise RuntimeError(f"Triton is required for the ParoQuant Triton benchmark: {exc}") from exc + + device = torch.device(f"cuda:{args.device}") + results = run(device=device, warmup=args.warmup, iters=args.iters, quick=args.quick) + + print(f"Device: {results['device']} ({results['cuda_device']})") + print() + print("Accuracy") + print(tabulate(results["accuracy_rows"], headers=results["accuracy_headers"], tablefmt="grid")) + print() + print("Benchmark") + print(tabulate(results["benchmark_rows"], headers=results["benchmark_headers"], tablefmt="grid")) + print() + print( + "Summary: " + f"triton_wins={results['triton_wins']}/{results['case_count']}, " + f"geo_mean_speedup={results['geo_mean_speedup']:.3f}x" + ) + + if args.json: + print() + print(json.dumps(results, indent=2)) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/benchmark_qwen35_moe_ab.py b/scripts/benchmark_qwen35_moe_ab.py new file mode 100644 index 000000000..c05f2ef59 --- /dev/null +++ b/scripts/benchmark_qwen35_moe_ab.py @@ -0,0 +1,717 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import gc +import json +import os +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Any, Dict, List, Optional + + +DEFAULT_MODEL_PATH = "/monster/data/model/Qwen3.5-35B-A3B" +DEFAULT_BASELINE_ROOT = "/root/gptqmodel-main" +FAST_LAYER_COUNT_ENV = "GPTQMODEL_FAST_LAYER_COUNT" +FAST_LAYER_POSITION_ENV = "GPTQMODEL_FAST_LAYER_POSITION" +VRAM_STRATEGY_CHOICES = ( + "exclusive", + "balanced", + "dense_home_moe_balanced", +) +DENSE_VRAM_STRATEGY_CHOICES = ("exclusive", "balanced") +MOE_VRAM_STRATEGY_CHOICES = ("exclusive", "balanced") + + +def _csv_arg(value: Optional[str]) -> Optional[List[str]]: + """Parse a comma-separated CLI device list into a normalized list.""" + + if value is None: + return None + items = [item.strip() for item in value.split(",") if item.strip()] + return items or None + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Benchmark two MoE layers of Qwen 3.5 on the current repo and a baseline repo, " + "capturing per-GPU VRAM plus GPTQ timing regions." + ) + ) + parser.add_argument("--single", action="store_true", help="Run a single repo benchmark inside the current process.") + parser.add_argument("--repo-root", type=Path, help="Repo root to import in --single mode.") + parser.add_argument("--json-out", type=Path, help="JSON output path in --single mode.") + parser.add_argument("--label", default="", help="Friendly label for the current case.") + parser.add_argument("--model-path", default=DEFAULT_MODEL_PATH, help="Local Qwen 3.5 MoE model path.") + parser.add_argument("--baseline-root", type=Path, default=Path(DEFAULT_BASELINE_ROOT), help="Baseline repo root for A/B mode.") + parser.add_argument("--current-root", type=Path, default=Path(__file__).resolve().parents[1], help="Current repo root for A/B mode.") + parser.add_argument("--output-dir", type=Path, default=None, help="Directory for logs and JSON outputs.") + parser.add_argument("--dataset-size", type=int, default=16, help="Calibration rows to use.") + parser.add_argument("--batch-size", type=int, default=1, help="Quant batch size.") + parser.add_argument("--quant-layers", type=int, default=2, help="Number of prefix layers to keep in fast-mode quantization.") + parser.add_argument("--stop-after-layer", type=int, default=1, help="Stop once this zero-based layer index has fully finalized.") + parser.add_argument("--dtype", default="auto", help="Model dtype passed to GPTQModel.load().") + parser.add_argument("--attn-implementation", default="eager", choices=("eager", "flash_attention_2"), help="Attention implementation.") + parser.add_argument("--vram-strategy", default="balanced", choices=VRAM_STRATEGY_CHOICES, help="VRAM strategy to benchmark.") + parser.add_argument("--cuda-visible-devices", default=None, help="Comma-separated visible device set for one case.") + parser.add_argument("--dense-vram-strategy", default="balanced", choices=DENSE_VRAM_STRATEGY_CHOICES, help="Dense-pool strategy for repos that support split VRAM config.") + parser.add_argument("--dense-vram-strategy-devices", default=None, help="Comma-separated dense-pool devices relative to CUDA_VISIBLE_DEVICES, e.g. cuda:0.") + parser.add_argument("--moe-vram-strategy", default="balanced", choices=MOE_VRAM_STRATEGY_CHOICES, help="MoE-pool strategy for repos that support split VRAM config.") + parser.add_argument("--moe-vram-strategy-devices", default=None, help="Comma-separated MoE-pool devices relative to CUDA_VISIBLE_DEVICES, e.g. cuda:1,cuda:2.") + parser.add_argument( + "--current-vram-strategy", + default=None, + choices=VRAM_STRATEGY_CHOICES, + help="Optional VRAM strategy override for the current repo in A/B mode.", + ) + parser.add_argument( + "--baseline-vram-strategy", + default=None, + choices=("exclusive", "balanced"), + help="Optional VRAM strategy override for the baseline repo in A/B mode.", + ) + parser.add_argument("--current-cuda-visible-devices", default=None, help="Comma-separated CUDA_VISIBLE_DEVICES for the current repo in A/B mode.") + parser.add_argument("--baseline-cuda-visible-devices", default=None, help="Comma-separated CUDA_VISIBLE_DEVICES for the baseline repo in A/B mode.") + parser.add_argument("--current-dense-vram-strategy", default=None, choices=DENSE_VRAM_STRATEGY_CHOICES, help="Optional dense-pool strategy override for the current repo in A/B mode.") + parser.add_argument("--baseline-dense-vram-strategy", default=None, choices=DENSE_VRAM_STRATEGY_CHOICES, help="Optional dense-pool strategy override for the baseline repo in A/B mode.") + parser.add_argument("--current-dense-vram-strategy-devices", default=None, help="Optional dense-pool device list override for the current repo in A/B mode.") + parser.add_argument("--baseline-dense-vram-strategy-devices", default=None, help="Optional dense-pool device list override for the baseline repo in A/B mode.") + parser.add_argument("--current-moe-vram-strategy", default=None, choices=MOE_VRAM_STRATEGY_CHOICES, help="Optional MoE-pool strategy override for the current repo in A/B mode.") + parser.add_argument("--baseline-moe-vram-strategy", default=None, choices=MOE_VRAM_STRATEGY_CHOICES, help="Optional MoE-pool strategy override for the baseline repo in A/B mode.") + parser.add_argument("--current-moe-vram-strategy-devices", default=None, help="Optional MoE-pool device list override for the current repo in A/B mode.") + parser.add_argument("--baseline-moe-vram-strategy-devices", default=None, help="Optional MoE-pool device list override for the baseline repo in A/B mode.") + return parser.parse_args() + + +def _git_head(repo_root: Path) -> str: + try: + completed = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=repo_root, + check=True, + capture_output=True, + text=True, + ) + except Exception: + return "unknown" + return completed.stdout.strip() + + +def _extract_region_total(snapshot: Dict[str, Dict[str, Any]], region: str) -> float: + stat = snapshot.get(region) or {} + try: + return float(stat.get("total", 0.0)) + except (TypeError, ValueError): + return 0.0 + + +def _spread(values: List[float]) -> float: + """Measure device imbalance across all visible accelerators for one sample.""" + + if len(values) < 2: + return 0.0 + return max(values) - min(values) + + +def _summarize_case(case: Dict[str, Any]) -> Dict[str, Any]: + layer_records = case.get("layer_records") or [] + final_layer = layer_records[-1] if layer_records else None + final_devices = (final_layer or {}).get("devices") or [] + reserved_gib = [float(item.get("reserved_gib", 0.0)) for item in final_devices] + peak_reserved_gib = [float(item.get("max_reserved_gib", 0.0)) for item in final_devices] + + spread_reserved_gib = _spread(reserved_gib) + peak_spread_reserved_gib = _spread(peak_reserved_gib) + + quant_region_snapshot = case.get("quant_region_snapshot") or {} + return { + "label": case.get("label"), + "repo_root": case.get("repo_root"), + "git_head": case.get("git_head"), + "vram_strategy": case.get("vram_strategy"), + "dense_vram_strategy": case.get("dense_vram_strategy"), + "dense_vram_strategy_devices": case.get("dense_vram_strategy_devices"), + "moe_vram_strategy": case.get("moe_vram_strategy"), + "moe_vram_strategy_devices": case.get("moe_vram_strategy_devices"), + "split_vram_pools_applied": bool(case.get("split_vram_pools_applied")), + "cuda_visible_devices": case.get("cuda_visible_devices"), + "quant_wall_s": float(case.get("quant_wall_s", 0.0)), + "pre_quant_forward_s": _extract_region_total(quant_region_snapshot, "pre_quant_forward"), + "process_quant_s": _extract_region_total(quant_region_snapshot, "process_quant"), + "post_quant_forward_s": _extract_region_total(quant_region_snapshot, "post_quant_forward"), + "layer_count_observed": len(layer_records), + "final_layer_idx": final_layer.get("layer_idx") if final_layer else None, + "final_reserved_gib": reserved_gib, + "final_peak_reserved_gib": peak_reserved_gib, + "final_reserved_spread_gib": spread_reserved_gib, + "final_peak_reserved_spread_gib": peak_spread_reserved_gib, + } + + +def _compare_cases(current: Dict[str, Any], baseline: Dict[str, Any]) -> Dict[str, Any]: + current_summary = _summarize_case(current) + baseline_summary = _summarize_case(baseline) + fields = [ + "quant_wall_s", + "pre_quant_forward_s", + "process_quant_s", + "post_quant_forward_s", + "final_reserved_spread_gib", + "final_peak_reserved_spread_gib", + ] + deltas = {} + for field in fields: + cur = float(current_summary.get(field, 0.0)) + base = float(baseline_summary.get(field, 0.0)) + pct = None if base == 0.0 else ((cur - base) / base) * 100.0 + deltas[field] = { + "current": cur, + "baseline": base, + "delta": cur - base, + "delta_pct": pct, + } + + return { + "current": current_summary, + "baseline": baseline_summary, + "delta": deltas, + } + + +def _print_case_summary(case: Dict[str, Any]) -> None: + summary = _summarize_case(case) + print( + f"[{summary['label']}] head={summary['git_head']} " + f"vram_strategy={summary.get('vram_strategy')} " + f"dense={summary.get('dense_vram_strategy')}@{summary.get('dense_vram_strategy_devices')} " + f"moe={summary.get('moe_vram_strategy')}@{summary.get('moe_vram_strategy_devices')} " + f"split_applied={summary.get('split_vram_pools_applied')} " + f"visible={summary.get('cuda_visible_devices')}" + ) + print( + " quant_wall_s={quant_wall_s:.3f} pre={pre_quant_forward_s:.3f} " + "quant={process_quant_s:.3f} post={post_quant_forward_s:.3f}".format(**summary) + ) + if summary["final_reserved_gib"]: + rounded_reserved = [round(value, 3) for value in summary["final_reserved_gib"]] + rounded_peak = [round(value, 3) for value in summary["final_peak_reserved_gib"]] + print( + " final_layer_idx={} reserved_gib={} peak_reserved_gib={} spread_gib={:.3f}".format( + summary["final_layer_idx"], + rounded_reserved, + rounded_peak, + summary["final_reserved_spread_gib"], + ) + ) + + +def _print_compare(compare: Dict[str, Any]) -> None: + print("\n[A/B delta: current - baseline]") + for field, entry in compare["delta"].items(): + pct_text = "n/a" if entry["delta_pct"] is None else f"{entry['delta_pct']:+.2f}%" + print( + f" {field}: current={entry['current']:.3f} baseline={entry['baseline']:.3f} " + f"delta={entry['delta']:+.3f} ({pct_text})" + ) + + +def _run_subprocess_case( + *, + script_path: Path, + repo_root: Path, + label: str, + output_dir: Path, + model_path: str, + dataset_size: int, + batch_size: int, + quant_layers: int, + stop_after_layer: int, + dtype: str, + attn_implementation: str, + vram_strategy: str, + cuda_visible_devices: Optional[str], + dense_vram_strategy: str, + dense_vram_strategy_devices: Optional[str], + moe_vram_strategy: str, + moe_vram_strategy_devices: Optional[str], +) -> Dict[str, Any]: + json_out = output_dir / f"{label}.json" + log_out = output_dir / f"{label}.log" + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + if cuda_visible_devices is not None: + env["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + else: + env.setdefault("CUDA_VISIBLE_DEVICES", "0,1") + env["PYTHON_GIL"] = "0" + env["DEBUG"] = "1" + + cmd = [ + sys.executable, + str(script_path), + "--single", + "--repo-root", + str(repo_root), + "--json-out", + str(json_out), + "--label", + label, + "--model-path", + model_path, + "--dataset-size", + str(dataset_size), + "--batch-size", + str(batch_size), + "--quant-layers", + str(quant_layers), + "--stop-after-layer", + str(stop_after_layer), + "--dtype", + dtype, + "--attn-implementation", + attn_implementation, + "--vram-strategy", + vram_strategy, + "--dense-vram-strategy", + dense_vram_strategy, + "--moe-vram-strategy", + moe_vram_strategy, + ] + if cuda_visible_devices is not None: + cmd.extend(["--cuda-visible-devices", cuda_visible_devices]) + if dense_vram_strategy_devices is not None: + cmd.extend(["--dense-vram-strategy-devices", dense_vram_strategy_devices]) + if moe_vram_strategy_devices is not None: + cmd.extend(["--moe-vram-strategy-devices", moe_vram_strategy_devices]) + + with log_out.open("w", encoding="utf-8") as log_handle: + subprocess.run( + cmd, + check=True, + cwd=repo_root, + env=env, + stdout=log_handle, + stderr=subprocess.STDOUT, + text=True, + ) + + with json_out.open("r", encoding="utf-8") as handle: + return json.load(handle) + + +def _single_case_main(args: argparse.Namespace) -> int: + if args.repo_root is None or args.json_out is None: + raise SystemExit("--single requires --repo-root and --json-out") + if os.environ.get("PYTHON_GIL") != "0": + raise SystemExit("--single requires PYTHON_GIL=0") + if args.cuda_visible_devices is not None and os.environ.get("CUDA_VISIBLE_DEVICES") != args.cuda_visible_devices: + raise SystemExit("--single requires CUDA_VISIBLE_DEVICES to match --cuda-visible-devices") + + repo_root = args.repo_root.resolve() + os.chdir(repo_root) + sys.path.insert(0, str(repo_root)) + sys.path.insert(0, str(repo_root / "tests" / "models")) + + os.environ[FAST_LAYER_COUNT_ENV] = str(args.quant_layers) + os.environ[FAST_LAYER_POSITION_ENV] = "first" + os.environ["DEBUG"] = "1" + + import torch + from transformers.utils import is_flash_attn_2_available + + from gptqmodel import DEBUG_ON, GPTQModel + from gptqmodel.looper.module_looper import StopMainLoop + from gptqmodel.quantization.config import VramStrategy + from gptqmodel.utils.torch import torch_empty_cache + from model_test import BACKEND + from test_qwen3_5_moe import TestQwen3_5Moe + + resolved_vram_strategy = VramStrategy(args.vram_strategy) + resolved_dense_vram_strategy = VramStrategy(args.dense_vram_strategy) + dense_vram_strategy_devices = _csv_arg(args.dense_vram_strategy_devices) + moe_vram_strategy_devices = _csv_arg(args.moe_vram_strategy_devices) + resolved_moe_vram_strategy = VramStrategy(args.moe_vram_strategy) + + def _safe_sync() -> None: + if not torch.cuda.is_available(): + return + for idx in range(torch.cuda.device_count()): + try: + torch.cuda.synchronize(idx) + except Exception as exc: + # Snapshot collection is best-effort; keep the benchmark running. + print(f"Warning: failed to synchronize cuda:{idx} during benchmark snapshot: {exc}", file=sys.stderr) + + def _snapshot_cuda(label: str) -> Dict[str, Any]: + gc.collect() + _safe_sync() + snapshot: Dict[str, Any] = { + "label": label, + "monotonic_s": time.perf_counter(), + "devices": [], + } + if not torch.cuda.is_available(): + return snapshot + + for idx in range(torch.cuda.device_count()): + stats = torch.cuda.memory_stats(idx) + free_bytes, total_bytes = torch.cuda.mem_get_info(idx) + snapshot["devices"].append( + { + "index": idx, + "name": torch.cuda.get_device_name(idx), + "allocated_gib": torch.cuda.memory_allocated(idx) / (1024 ** 3), + "reserved_gib": torch.cuda.memory_reserved(idx) / (1024 ** 3), + "max_allocated_gib": torch.cuda.max_memory_allocated(idx) / (1024 ** 3), + "max_reserved_gib": torch.cuda.max_memory_reserved(idx) / (1024 ** 3), + "active_current_gib": stats.get("active_bytes.all.current", 0) / (1024 ** 3), + "active_peak_gib": stats.get("active_bytes.all.peak", 0) / (1024 ** 3), + "free_gib": free_bytes / (1024 ** 3), + "total_gib": total_bytes / (1024 ** 3), + } + ) + return snapshot + + class BenchmarkQwen35Moe(TestQwen3_5Moe): + DATASET_SIZE = args.dataset_size + QUANT_BATCH_SIZE = args.batch_size + STOP_AFTER_LAYER = args.stop_after_layer + EVAL_TASKS = {} + EVAL_TASKS_FAST = {} + EVAL_TASKS_SLOW = {} + USE_FLASH_ATTN = args.attn_implementation == "flash_attention_2" + VRAM_STRATEGY = resolved_vram_strategy + + def __init__(self, methodName: str = "test_qwen3_5_moe"): + super().__init__(methodName=methodName) + self.memory_snapshots: List[Dict[str, Any]] = [] + self.layer_records: List[Dict[str, Any]] = [] + + def _record_snapshot(self, label: str) -> Dict[str, Any]: + snapshot = _snapshot_cuda(label) + self.memory_snapshots.append(snapshot) + return snapshot + + def _build_layer_stop_callback(self, layer_idx: int): + outer = self + + class _Probe: + def __init__(self, target: int): + self._target = target + self._triggered = False + + def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): + if submodule_finalized: + outer.layer_records.append( + { + "layer_idx": layer_idx, + "submodule_finalized": True, + "devices": _snapshot_cuda(f"layer_{layer_idx}_finalized")["devices"], + } + ) + if self._triggered: + return None + if layer_idx > self._target or (submodule_finalized and layer_idx >= self._target): + self._triggered = True + raise StopMainLoop + return None + + return _Probe(layer_idx) + + def run_benchmark(self) -> Dict[str, Any]: + torch_empty_cache() + if torch.cuda.is_available(): + for idx in range(torch.cuda.device_count()): + torch.cuda.reset_peak_memory_stats(idx) + + quantize_config = self._build_quantize_config() + quantize_config.wait_for_submodule_finalizers = True + + load_kwargs: Dict[str, Any] = {} + if self.USE_FLASH_ATTN and is_flash_attn_2_available(): + load_kwargs["attn_implementation"] = "flash_attention_2" + else: + load_kwargs["attn_implementation"] = "eager" + + torch_fused_backend = self._torch_fused_backend() + device_map = {"": "cpu"} if self.LOAD_BACKEND == torch_fused_backend else "auto" + + model = None + dataset = None + stop_exception = False + try: + self._record_snapshot("before_load") + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + quantize_config=quantize_config, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=args.dtype, + device_map=device_map, + **load_kwargs, + ) + self._record_snapshot("after_load") + + self._layer_stop_callback = None + if DEBUG_ON and self.STOP_AFTER_LAYER is not None: + self._layer_stop_callback = self._build_layer_stop_callback(self.STOP_AFTER_LAYER) + model.layer_callback = self._layer_stop_callback + + self._apply_model_compat_quant_overrides(model) + dataset = self.load_dataset(model.tokenizer, rows=self.DATASET_SIZE) + self._record_snapshot("after_dataset") + + start = time.perf_counter() + try: + model.quantize( + dataset, + calibration_concat_size=self.DATASET_CONCAT_SIZE, + calibration_concat_separator=self.DATASET_CONCAT_SEPARATOR, + calibration_sort=self.DATASET_SORT, + backend=self.QUANT_BACKEND, + batch_size=self.QUANT_BATCH_SIZE, + ) + except StopMainLoop: + stop_exception = True + quant_wall_s = time.perf_counter() - start + self._record_snapshot("after_quant") + + quant_region_snapshot = model.quant_region_timer.snapshot() + hf_device_map = getattr(model.model, "hf_device_map", None) or getattr(model, "hf_device_map", None) + result = { + "label": args.label, + "repo_root": str(repo_root), + "git_head": _git_head(repo_root), + "model_path": self.NATIVE_MODEL_ID, + "dataset_size": self.DATASET_SIZE, + "batch_size": self.QUANT_BATCH_SIZE, + "quant_layers": args.quant_layers, + "stop_after_layer": self.STOP_AFTER_LAYER, + "dtype": str(args.dtype), + "attn_implementation": load_kwargs["attn_implementation"], + "vram_strategy": self.VRAM_STRATEGY.value, + "dense_vram_strategy": getattr(self, "DENSE_VRAM_STRATEGY", None).value if getattr(self, "DENSE_VRAM_STRATEGY", None) is not None else None, + "dense_vram_strategy_devices": getattr(self, "DENSE_VRAM_STRATEGY_DEVICES", None), + "moe_vram_strategy": getattr(self, "MOE_VRAM_STRATEGY", None).value if getattr(self, "MOE_VRAM_STRATEGY", None) is not None else None, + "moe_vram_strategy_devices": getattr(self, "MOE_VRAM_STRATEGY_DEVICES", None), + "split_vram_pools_supported": split_vram_pools_supported, + "split_vram_pools_applied": split_vram_pools_applied, + "python": sys.version, + "python_gil_disabled": bool(getattr(sys, "_is_gil_enabled", lambda: True)() is False), + "python_gil_env": os.environ.get("PYTHON_GIL"), + "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), + "cuda_device_order": os.environ.get("CUDA_DEVICE_ORDER"), + "debug_on": bool(DEBUG_ON), + "debug_short_circuit": bool(self._debug_layer_stop_triggered()), + "stop_exception_caught": stop_exception, + "quant_wall_s": quant_wall_s, + "quant_region_snapshot": quant_region_snapshot, + "memory_snapshots": self.memory_snapshots, + "layer_records": self.layer_records, + "visible_devices": [ + { + "index": idx, + "name": torch.cuda.get_device_name(idx), + } + for idx in range(torch.cuda.device_count()) + ] + if torch.cuda.is_available() + else [], + "hf_device_map": hf_device_map, + "load_backend": self.LOAD_BACKEND.name if isinstance(self.LOAD_BACKEND, BACKEND) else str(self.LOAD_BACKEND), + "quant_backend": self.QUANT_BACKEND.name if isinstance(self.QUANT_BACKEND, BACKEND) else str(self.QUANT_BACKEND), + } + return result + finally: + del dataset + del model + torch_empty_cache() + + # Only apply split pools when the imported repo exposes the newer test + # class knobs; older branches stay on their legacy single-strategy path. + split_vram_pools_supported = ( + hasattr(BenchmarkQwen35Moe, "DENSE_VRAM_STRATEGY") + and hasattr(BenchmarkQwen35Moe, "DENSE_VRAM_STRATEGY_DEVICES") + and hasattr(BenchmarkQwen35Moe, "MOE_VRAM_STRATEGY") + and hasattr(BenchmarkQwen35Moe, "MOE_VRAM_STRATEGY_DEVICES") + ) + split_vram_pools_applied = False + if split_vram_pools_supported: + BenchmarkQwen35Moe.DENSE_VRAM_STRATEGY = resolved_dense_vram_strategy + BenchmarkQwen35Moe.DENSE_VRAM_STRATEGY_DEVICES = dense_vram_strategy_devices + BenchmarkQwen35Moe.MOE_VRAM_STRATEGY = resolved_moe_vram_strategy + BenchmarkQwen35Moe.MOE_VRAM_STRATEGY_DEVICES = moe_vram_strategy_devices + split_vram_pools_applied = True + + case = BenchmarkQwen35Moe() + result = case.run_benchmark() + args.json_out.parent.mkdir(parents=True, exist_ok=True) + with args.json_out.open("w", encoding="utf-8") as handle: + json.dump(result, handle, indent=2, sort_keys=True) + return 0 + + +def _ab_main(args: argparse.Namespace) -> int: + script_path = Path(__file__).resolve() + output_dir = args.output_dir or Path(tempfile.mkdtemp(prefix="qwen35_moe_ab_")) + output_dir.mkdir(parents=True, exist_ok=True) + current_vram_strategy = args.current_vram_strategy or args.vram_strategy + baseline_vram_strategy = args.baseline_vram_strategy or args.vram_strategy + current_cuda_visible_devices = args.current_cuda_visible_devices or args.cuda_visible_devices + baseline_cuda_visible_devices = args.baseline_cuda_visible_devices or args.cuda_visible_devices + current_dense_vram_strategy = args.current_dense_vram_strategy or args.dense_vram_strategy + baseline_dense_vram_strategy = args.baseline_dense_vram_strategy or args.dense_vram_strategy + current_dense_vram_strategy_devices = args.current_dense_vram_strategy_devices or args.dense_vram_strategy_devices + baseline_dense_vram_strategy_devices = args.baseline_dense_vram_strategy_devices or args.dense_vram_strategy_devices + current_moe_vram_strategy = args.current_moe_vram_strategy or args.moe_vram_strategy + baseline_moe_vram_strategy = args.baseline_moe_vram_strategy or args.moe_vram_strategy + current_moe_vram_strategy_devices = args.current_moe_vram_strategy_devices or args.moe_vram_strategy_devices + baseline_moe_vram_strategy_devices = args.baseline_moe_vram_strategy_devices or args.moe_vram_strategy_devices + + current = baseline = None + + def _start_case( + *, + repo_root: Path, + label: str, + vram_strategy: str, + cuda_visible_devices: Optional[str], + dense_vram_strategy: str, + dense_vram_strategy_devices: Optional[str], + moe_vram_strategy: str, + moe_vram_strategy_devices: Optional[str], + ): + json_out = output_dir / f"{label}.json" + log_out = output_dir / f"{label}.log" + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + if cuda_visible_devices is not None: + env["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices + else: + env.setdefault("CUDA_VISIBLE_DEVICES", "0,1") + env["PYTHON_GIL"] = "0" + env["DEBUG"] = "1" + + cmd = [ + sys.executable, + str(script_path), + "--single", + "--repo-root", + str(repo_root), + "--json-out", + str(json_out), + "--label", + label, + "--model-path", + args.model_path, + "--dataset-size", + str(args.dataset_size), + "--batch-size", + str(args.batch_size), + "--quant-layers", + str(args.quant_layers), + "--stop-after-layer", + str(args.stop_after_layer), + "--dtype", + args.dtype, + "--attn-implementation", + args.attn_implementation, + "--vram-strategy", + vram_strategy, + "--dense-vram-strategy", + dense_vram_strategy, + "--moe-vram-strategy", + moe_vram_strategy, + ] + if cuda_visible_devices is not None: + cmd.extend(["--cuda-visible-devices", cuda_visible_devices]) + if dense_vram_strategy_devices is not None: + cmd.extend(["--dense-vram-strategy-devices", dense_vram_strategy_devices]) + if moe_vram_strategy_devices is not None: + cmd.extend(["--moe-vram-strategy-devices", moe_vram_strategy_devices]) + + log_handle = log_out.open("w", encoding="utf-8") + proc = subprocess.Popen( + cmd, + cwd=repo_root, + env=env, + stdout=log_handle, + stderr=subprocess.STDOUT, + text=True, + ) + return proc, log_handle, json_out, log_out + + current_proc, current_log_handle, current_json_out, _ = _start_case( + repo_root=args.current_root.resolve(), + label="current", + vram_strategy=current_vram_strategy, + cuda_visible_devices=current_cuda_visible_devices, + dense_vram_strategy=current_dense_vram_strategy, + dense_vram_strategy_devices=current_dense_vram_strategy_devices, + moe_vram_strategy=current_moe_vram_strategy, + moe_vram_strategy_devices=current_moe_vram_strategy_devices, + ) + baseline_proc, baseline_log_handle, baseline_json_out, _ = _start_case( + repo_root=args.baseline_root.resolve(), + label="baseline", + vram_strategy=baseline_vram_strategy, + cuda_visible_devices=baseline_cuda_visible_devices, + dense_vram_strategy=baseline_dense_vram_strategy, + dense_vram_strategy_devices=baseline_dense_vram_strategy_devices, + moe_vram_strategy=baseline_moe_vram_strategy, + moe_vram_strategy_devices=baseline_moe_vram_strategy_devices, + ) + + try: + current_returncode = current_proc.wait() + baseline_returncode = baseline_proc.wait() + finally: + current_log_handle.close() + baseline_log_handle.close() + + if current_returncode != 0: + raise subprocess.CalledProcessError(current_returncode, current_proc.args) + if baseline_returncode != 0: + raise subprocess.CalledProcessError(baseline_returncode, baseline_proc.args) + + with current_json_out.open("r", encoding="utf-8") as handle: + current = json.load(handle) + with baseline_json_out.open("r", encoding="utf-8") as handle: + baseline = json.load(handle) + + compare = _compare_cases(current=current, baseline=baseline) + compare_path = output_dir / "compare.json" + with compare_path.open("w", encoding="utf-8") as handle: + json.dump( + { + "current": current, + "baseline": baseline, + "compare": compare, + }, + handle, + indent=2, + sort_keys=True, + ) + + print(f"Results written to {output_dir}") + _print_case_summary(current) + _print_case_summary(baseline) + _print_compare(compare) + return 0 + + +def main() -> int: + args = _parse_args() + if args.single: + return _single_case_main(args) + return _ab_main(args) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/dequantize_model.py b/scripts/dequantize_model.py index 337149ba7..377a7fd53 100755 --- a/scripts/dequantize_model.py +++ b/scripts/dequantize_model.py @@ -3,7 +3,7 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 -"""CLI entry point for dequantizing GPTQModel safetensor shards.""" +"""CLI entry point for dequantizing GPT-QModel safetensor shards.""" from __future__ import annotations diff --git a/scripts/eval_model.py b/scripts/eval_model.py index 45a574815..94d49e4a0 100644 --- a/scripts/eval_model.py +++ b/scripts/eval_model.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai # SPDX-FileCopyrightText: 2024-2025 gptqmodel contributors # SPDX-License-Identifier: Apache-2.0 -"""CLI helper to run lm-eval tasks against a GPTQModel checkpoint.""" +"""CLI helper to run Evalution-backed tasks against a GPT-QModel checkpoint.""" import argparse import json @@ -15,7 +15,12 @@ from tabulate import tabulate from gptqmodel import GPTQModel from gptqmodel.models.base import BaseQModel -from gptqmodel.utils.eval import EVAL +from tests.eval import ( + evaluate, + get_eval_task_results, + list_supported_tasks, + normalize_eval_task_name, +) if sys.platform == "darwin": @@ -27,9 +32,8 @@ "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7", ) -DEFAULT_RESULTS_PATH = Path("lm_eval_results.json") -DEFAULT_TASKS = (EVAL.LM_EVAL.ARC_CHALLENGE,) -DEFAULT_TASK_MANAGER_PATH = Path(__file__).resolve().parent.parent / "tests" / "tasks" +DEFAULT_RESULTS_PATH = Path("evalution_results.json") +DEFAULT_TASKS = ("arc_challenge",) def _available_backends() -> Dict[str, gptqmodel.BACKEND]: @@ -81,31 +85,29 @@ def _parse_key_value_pairs(pairs: Iterable[str]) -> Dict[str, object]: return result -def _resolve_task(name: str) -> EVAL.LM_EVAL: - normalized = name.strip() - for task in EVAL.LM_EVAL: - if normalized.lower() in {task.value.lower(), task.name.lower()}: - return task - available = ", ".join(task.value for task in EVAL.LM_EVAL) - raise argparse.ArgumentTypeError(f"Unknown lm-eval task '{name}'. Expected one of: {available}") +def _resolve_task(name: str) -> str: + normalized = normalize_eval_task_name(name) + if normalized in list_supported_tasks(): + return normalized + available = ", ".join(list_supported_tasks()) + raise argparse.ArgumentTypeError(f"Unknown Evalution task '{name}'. Expected one of: {available}") def _list_tasks() -> None: - rows = [(task.name, task.value) for task in EVAL.LM_EVAL] + rows = [(task_name, task_name) for task_name in list_supported_tasks()] print(tabulate(rows, headers=["Name", "Identifier"])) def _extract_metrics(results: Dict) -> Dict[str, Dict[str, float]]: - aggregated: Dict[str, Dict[str, float]] = {} - task_results = results.get("results", {}) - for task_name, metrics in task_results.items(): - filtered = { + aggregated = get_eval_task_results(results) + return { + task_name: { metric: value for metric, value in metrics.items() if metric != "alias" and "stderr" not in metric } - aggregated[task_name] = filtered - return aggregated + for task_name, metrics in aggregated.items() + } def _print_metrics_table(metrics: Dict[str, Dict[str, float]], table_format: str) -> None: @@ -134,9 +136,9 @@ def _split_tasks(arg_value: str | None) -> List[str]: def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( - description="Run lm-eval tasks against a quantized model loaded via gptqmodel." + description="Run Evalution tasks against a quantized model loaded via gptqmodel." ) - parser.add_argument("--model", required=True, help="Model path or Hugging Face repo id.") + parser.add_argument("--model", help="Model path or Hugging Face repo id.") parser.add_argument( "--backend", default="auto", @@ -145,8 +147,8 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument( "--tasks", - default=",".join(task.value for task in DEFAULT_TASKS), - help="Comma-separated lm-eval task identifiers (see --list-tasks).", + default=",".join(DEFAULT_TASKS), + help="Comma-separated Evalution task identifiers (see --list-tasks).", ) parser.add_argument( "--chat-template-tasks", @@ -157,7 +159,7 @@ def parse_args() -> argparse.Namespace: "--batch-size", default="auto", type=_parse_batch_size, - help="Evaluation batch size passed to lm-eval (integer or 'auto').", + help="Evaluation batch size passed to Evalution (integer or 'auto').", ) parser.add_argument( "--dtype", @@ -167,14 +169,14 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--gen-kwargs", default=None, - help="Generation kwargs forwarded to lm-eval, e.g. 'temperature=0.0,top_k=50'.", + help="Generation kwargs forwarded to Evalution, e.g. 'temperature=0.0,top_k=50'.", ) parser.add_argument( "--model-arg", action="append", default=[], metavar="KEY=VALUE", - help="Extra model_args forwarded to GPTQModel.eval (repeatable).", + help="Extra model_args forwarded to Evalution (repeatable).", ) parser.add_argument( "--load-arg", @@ -188,34 +190,6 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Allow loading models that require remote code execution.", ) - parser.add_argument( - "--use-vllm", - action="store_true", - help="Run evaluation with the vLLM backend instead of the default gptqmodel harness.", - ) - parser.add_argument( - "--max-model-len", - type=int, - default=None, - help="Optional max_model_len passed to vLLM model args.", - ) - parser.add_argument( - "--random-seed", - type=int, - default=898, - help="Seed propagated to lm-eval for reproducibility.", - ) - parser.add_argument( - "--task-manager-path", - type=str, - default=str(DEFAULT_TASK_MANAGER_PATH) if DEFAULT_TASK_MANAGER_PATH.exists() else None, - help="Optional path containing custom lm-eval tasks.", - ) - parser.add_argument( - "--include-default-tasks", - action="store_true", - help="Include lm-eval's builtin task registry alongside the custom task directory.", - ) parser.add_argument( "--output", type=str, @@ -230,7 +204,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--list-tasks", action="store_true", - help="List supported lm-eval task identifiers and exit.", + help="List supported Evalution task identifiers and exit.", ) return parser.parse_args() @@ -241,13 +215,13 @@ def main() -> None: if args.list_tasks: _list_tasks() return + if not args.model: + raise ValueError("--model is required unless --list-tasks is used.") tasks = [_resolve_task(name) for name in _split_tasks(args.tasks)] if not tasks: - raise ValueError("No lm-eval tasks specified.") - chat_template_tasks = {_resolve_task(name).value for name in _split_tasks(args.chat_template_tasks)} - - llm_backend = "vllm" if args.use_vllm else "gptqmodel" + raise ValueError("No Evalution tasks specified.") + chat_template_tasks = {_resolve_task(name) for name in _split_tasks(args.chat_template_tasks)} backend: gptqmodel.BACKEND = args.backend load_kwargs = _parse_key_value_pairs(args.load_arg) @@ -260,54 +234,31 @@ def main() -> None: ) if not isinstance(model, BaseQModel): - raise RuntimeError("Failed to load GPTQModel; received unexpected object type.") + raise RuntimeError("Failed to load GPT-QModel; received unexpected object type.") model_args = _parse_key_value_pairs(args.model_arg) - if args.max_model_len is not None: - model_args.setdefault("max_model_len", args.max_model_len) - - if args.use_vllm: - model_args.setdefault("dtype", "auto") - model_args.setdefault("tensor_parallel_size", 1) - model_args.setdefault("gpu_memory_utilization", 0.8) - - task_manager = None - if args.task_manager_path: - task_manager_path = Path(args.task_manager_path).expanduser().resolve() - if not task_manager_path.exists(): - raise FileNotFoundError(f"Task manager path does not exist: {task_manager_path}") - from lm_eval.tasks import TaskManager - - task_manager = TaskManager( - include_path=str(task_manager_path), - include_defaults=args.include_default_tasks, - ) aggregated_metrics: Dict[str, Dict[str, float]] = {} - grouped_tasks: Dict[bool, List[EVAL.LM_EVAL]] = {} + grouped_tasks: Dict[bool, List[str]] = {} for task in tasks: - apply_chat = task.value in chat_template_tasks + apply_chat = task in chat_template_tasks grouped_tasks.setdefault(apply_chat, []).append(task) for apply_chat_template, grouped in grouped_tasks.items(): if not grouped: continue - result = gptqmodel.GPTQModel.eval( + result = evaluate( model_or_id_or_path=model, tasks=grouped, - framework=EVAL.LM_EVAL, batch_size=args.batch_size, trust_remote_code=args.trust_remote_code, output_path=None, - llm_backend=llm_backend, backend=backend, - random_seed=args.random_seed, model_args=model_args.copy(), gen_kwargs=args.gen_kwargs, apply_chat_template=apply_chat_template, - task_manager=task_manager, ) group_metrics = _extract_metrics(result) diff --git a/scripts/generate_exl3_kernel_map_packed.py b/scripts/generate_exl3_kernel_map_packed.py new file mode 100644 index 000000000..d552d7dd5 --- /dev/null +++ b/scripts/generate_exl3_kernel_map_packed.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import argparse +import bisect +from dataclasses import dataclass +from pathlib import Path +from urllib.request import urlopen + +import pcre + +BLOCK_RE = pcre.compile( + r"struct TSample samples_(\d+)\[\]\s*=\s*\{(.*?)\n\};", + flags=pcre.Flag.DOTALL, +) +ROW_RE = pcre.compile( + r"\{\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+),\s*(\d+)\s*\}" +) +EXLLAMAV3_ORIGINAL_COMMIT = "ba1ad9ac66670785c0ca95b0f1ab3ad044fda7c6" +EXLLAMAV3_ORIGINAL_LEGACY_HEADER_URL = ( + "https://raw.githubusercontent.com/turboderp-org/exllamav3/" + f"{EXLLAMAV3_ORIGINAL_COMMIT}/exllamav3/exllamav3_ext/quant/exl3_kernel_map_samples.cuh" +) + + +@dataclass(frozen=True) +class LegacyRow: + cc: int + bits: int + m: int + k: int + n: int + shape_idx: int + num_sms: int + + +@dataclass(frozen=True) +class PackedBlock: + mod: int + cc_values: tuple[int, ...] + bit_values: tuple[int, ...] + k_axis: tuple[int, ...] + n_axis: tuple[int, ...] + payload: tuple[int, ...] + row_lookup: dict[tuple[int, int, int, int], LegacyRow] + + +def parse_legacy_text(text: str) -> dict[int, list[LegacyRow]]: + blocks: dict[int, list[LegacyRow]] = {} + + for mod_text, body in BLOCK_RE.findall(text): + mod = int(mod_text) + rows: list[LegacyRow] = [] + for match in ROW_RE.findall(body): + row = LegacyRow(*(int(value) for value in match)) + if row.bits == 0: + continue + rows.append(row) + blocks[mod] = rows + + expected_blocks = {128, 256, 512} + if set(blocks) != expected_blocks: + raise ValueError(f"expected blocks {sorted(expected_blocks)}, found {sorted(blocks)}") + + return blocks + + +def parse_legacy_header(path: Path) -> dict[int, list[LegacyRow]]: + return parse_legacy_text(path.read_text()) + + +def download_legacy_header(url: str = EXLLAMAV3_ORIGINAL_LEGACY_HEADER_URL, timeout: int = 30) -> str: + with urlopen(url, timeout=timeout) as response: + return response.read().decode("utf-8") + + +def pack_row(shape_idx: int, num_sms: int) -> int: + if not (0 < shape_idx < 256): + raise ValueError(f"shape_idx out of range: {shape_idx}") + if not (0 < num_sms < 256): + raise ValueError(f"num_sms out of range: {num_sms}") + return (shape_idx << 8) | num_sms + + +def build_packed_block(mod: int, rows: list[LegacyRow]) -> PackedBlock: + cc_values = tuple(sorted({row.cc for row in rows})) + bit_values = tuple(sorted({row.bits for row in rows})) + k_axis = tuple(sorted({row.k for row in rows})) + n_axis = tuple(sorted({row.n for row in rows})) + + if cc_values != (2, 3, 4): + raise ValueError(f"unexpected cc axis for {mod}: {cc_values}") + if bit_values != (1, 2, 3, 4, 5, 6, 7, 8): + raise ValueError(f"unexpected bit axis for {mod}: {bit_values}") + if any(row.m != 1 for row in rows): + raise ValueError(f"unexpected m values in samples_{mod}") + + expected = len(cc_values) * len(bit_values) * len(k_axis) * len(n_axis) + if len(rows) != expected: + raise ValueError(f"samples_{mod} is not a full grid: got {len(rows)}, expected {expected}") + + row_lookup: dict[tuple[int, int, int, int], LegacyRow] = {} + for row in rows: + key = (row.cc, row.bits, row.k, row.n) + if key in row_lookup: + raise ValueError(f"duplicate row for {key} in samples_{mod}") + row_lookup[key] = row + + payload: list[int] = [] + for cc in cc_values: + for bits in bit_values: + for k in k_axis: + for n in n_axis: + key = (cc, bits, k, n) + row = row_lookup.get(key) + if row is None: + raise ValueError(f"missing row for {key} in samples_{mod}") + payload.append(pack_row(row.shape_idx, row.num_sms)) + + return PackedBlock( + mod=mod, + cc_values=cc_values, + bit_values=bit_values, + k_axis=k_axis, + n_axis=n_axis, + payload=tuple(payload), + row_lookup=row_lookup, + ) + + +def nearest_axis_index(axis: tuple[int, ...], value: int) -> int: + best_idx = 0 + best_dist = abs(value - axis[0]) + for idx in range(1, len(axis)): + dist = abs(value - axis[idx]) + if dist < best_dist: + best_dist = dist + best_idx = idx + return best_idx + + +def packed_lookup(block: PackedBlock, cc: int, bits: int, size_k: int, size_n: int) -> tuple[int, int]: + try: + cc_idx = block.cc_values.index(cc) + bit_idx = block.bit_values.index(bits) + except ValueError as exc: + raise ValueError(f"unsupported lookup key {(cc, bits)} for block {block.mod}") from exc + + k_idx = nearest_axis_index(block.k_axis, size_k) + n_idx = nearest_axis_index(block.n_axis, size_n) + flat_idx = ((((cc_idx * len(block.bit_values)) + bit_idx) * len(block.k_axis)) + k_idx) * len(block.n_axis) + n_idx + packed = block.payload[flat_idx] + return packed >> 8, packed & 0xFF + + +def legacy_lookup(rows: list[LegacyRow], cc: int, bits: int, size_k: int, size_n: int) -> tuple[int, int]: + best_row: LegacyRow | None = None + best_dist: int | None = None + for row in rows: + if row.cc != cc or row.bits != bits: + continue + distk = size_k - row.k + distn = size_n - row.n + dist = distk * distk + distn * distn + if best_dist is None or dist < best_dist: + best_dist = dist + best_row = row + + if best_row is None: + raise ValueError(f"no legacy row for {(cc, bits, size_k, size_n)}") + return best_row.shape_idx, best_row.num_sms + + +def transition_points(axis: tuple[int, ...]) -> list[int]: + points = {1, axis[-1] + 1} + for left, right in zip(axis, axis[1:]): + midpoint = (left + right) // 2 + points.add(midpoint) + points.add(midpoint + 1) + return sorted(points) + + +def block_domain_values(mod: int, max_value: int) -> list[int]: + values = [] + value = mod + while value <= max_value: + if mod == 512: + valid = value % 512 == 0 + elif mod == 256: + valid = value % 256 == 0 and value % 512 != 0 + elif mod == 128: + valid = value % 128 == 0 and value % 256 != 0 + else: + raise ValueError(f"unexpected block modulus: {mod}") + if valid: + values.append(value) + value += 128 + return values + + +def domain_transition_points(mod: int, axis: tuple[int, ...]) -> list[int]: + domain = block_domain_values(mod, axis[-1] + mod) + points = {domain[0], domain[-1]} + + for left, right in zip(axis, axis[1:]): + midpoint = (left + right) // 2 + lower_idx = bisect.bisect_right(domain, midpoint) - 1 + upper_idx = bisect.bisect_right(domain, midpoint) + if lower_idx >= 0: + points.add(domain[lower_idx]) + if upper_idx < len(domain): + points.add(domain[upper_idx]) + + return sorted(points) + + +def validate_row_exactness(legacy: dict[int, list[LegacyRow]], packed: dict[int, PackedBlock]) -> None: + for mod, block in packed.items(): + for row in legacy[mod]: + expected = pack_row(row.shape_idx, row.num_sms) + cc_idx = block.cc_values.index(row.cc) + bit_idx = block.bit_values.index(row.bits) + k_idx = block.k_axis.index(row.k) + n_idx = block.n_axis.index(row.n) + flat_idx = ((((cc_idx * len(block.bit_values)) + bit_idx) * len(block.k_axis)) + k_idx) * len(block.n_axis) + n_idx + actual = block.payload[flat_idx] + if actual != expected: + raise ValueError(f"packed row mismatch for samples_{mod}: {row} -> {actual:#06x} != {expected:#06x}") + + +def validate_lookup_equivalence(legacy: dict[int, list[LegacyRow]], packed: dict[int, PackedBlock]) -> None: + k_points = transition_points(packed[128].k_axis) + + for mod, block in packed.items(): + n_points = domain_transition_points(mod, block.n_axis) + grouped_rows: dict[tuple[int, int], list[LegacyRow]] = {} + for row in legacy[mod]: + grouped_rows.setdefault((row.cc, row.bits), []).append(row) + for cc in block.cc_values: + for bits in block.bit_values: + legacy_rows = grouped_rows[(cc, bits)] + for size_k in k_points: + for size_n in n_points: + legacy_result = legacy_lookup(legacy_rows, cc, bits, size_k, size_n) + packed_result = packed_lookup(block, cc, bits, size_k, size_n) + if packed_result != legacy_result: + raise ValueError( + "lookup mismatch for " + f"samples_{mod} cc={cc} bits={bits} size_k={size_k} size_n={size_n}: " + f"{packed_result} != {legacy_result}" + ) + + +def format_int_array(name: str, values: tuple[int, ...], values_per_line: int = 8) -> str: + lines = [f"constexpr int {name}[] = {{"] + for idx in range(0, len(values), values_per_line): + chunk = ", ".join(str(value) for value in values[idx:idx + values_per_line]) + lines.append(f" {chunk},") + lines.append("};") + return "\n".join(lines) + + +def format_u16_array(name: str, values: tuple[int, ...], values_per_line: int = 24) -> str: + lines = [f"constexpr uint16_t {name}[] = {{"] + for idx in range(0, len(values), values_per_line): + chunk = ", ".join(f"0x{value:04x}" for value in values[idx:idx + values_per_line]) + lines.append(f" {chunk},") + lines.append("};") + return "\n".join(lines) + + +def render_header(blocks: dict[int, PackedBlock]) -> str: + block128 = blocks[128] + block256 = blocks[256] + block512 = blocks[512] + + sections = [ + "#pragma once", + "", + "#include ", + "", + "// Generated by scripts/generate_exl3_kernel_map_packed.py.", + "// Encodes the EXL3 tuning samples as dense [cc][bits][k][n] grids.", + "", + "namespace exl3_packed {", + "", + f"constexpr int cc_count = {len(block128.cc_values)};", + f"constexpr int bit_count = {len(block128.bit_values)};", + f"constexpr int k_axis_len = {len(block128.k_axis)};", + f"constexpr int n_axis_len_128 = {len(block128.n_axis)};", + f"constexpr int n_axis_len_256 = {len(block256.n_axis)};", + f"constexpr int n_axis_len_512 = {len(block512.n_axis)};", + "", + format_int_array("cc_values", block128.cc_values), + "", + format_int_array("bit_values", block128.bit_values), + "", + format_int_array("k_axis", block128.k_axis), + "", + format_int_array("n_axis_128", block128.n_axis), + "", + format_int_array("n_axis_256", block256.n_axis), + "", + format_int_array("n_axis_512", block512.n_axis), + "", + format_u16_array("samples_128", block128.payload), + "", + format_u16_array("samples_256", block256.payload), + "", + format_u16_array("samples_512", block512.payload), + "", + "} // namespace exl3_packed", + "", + ] + return "\n".join(sections) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Generate the packed EXL3 kernel-map header from the legacy row table.") + parser.add_argument( + "--legacy", + type=Path, + help="Optional local path to the legacy row-wise samples header.", + ) + parser.add_argument( + "--url", + default=EXLLAMAV3_ORIGINAL_LEGACY_HEADER_URL, + help="Raw URL for the original ExLlamaV3 legacy row-wise samples header.", + ) + parser.add_argument( + "--output", + type=Path, + default=Path("gptqmodel_ext/exllamav3/quant/exl3_kernel_map_packed.cuh"), + help="Path to the generated packed-grid header.", + ) + args = parser.parse_args() + + if args.legacy is not None: + legacy = parse_legacy_header(args.legacy) + source = str(args.legacy) + else: + legacy = parse_legacy_text(download_legacy_header(args.url)) + source = args.url + + packed = {mod: build_packed_block(mod, rows) for mod, rows in sorted(legacy.items())} + validate_row_exactness(legacy, packed) + validate_lookup_equivalence(legacy, packed) + args.output.write_text(render_header(packed)) + + payload_entries = sum(len(block.payload) for block in packed.values()) + print( + f"validated legacy lookup from {source} and wrote {args.output} " + f"({payload_entries} packed entries across {len(packed)} blocks)" + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/nvml_visible_shim.c b/scripts/nvml_visible_shim.c new file mode 100644 index 000000000..4ad4e5f4b --- /dev/null +++ b/scripts/nvml_visible_shim.c @@ -0,0 +1,188 @@ +#define _GNU_SOURCE + +#include +#include +#include +#include +#include +#include +#include + +typedef int nvmlReturn_t; +typedef struct nvmlDevice_st *nvmlDevice_t; + +#define NVML_SUCCESS 0 +#define NVML_ERROR_INVALID_ARGUMENT 2 + +typedef nvmlReturn_t (*nvml_count_fn)(unsigned int *); +typedef nvmlReturn_t (*nvml_handle_fn)(unsigned int, nvmlDevice_t *); + +static int g_ids[64]; +static size_t g_id_count = 0; +static int g_initialized = 0; +static void *g_nvml_handle = NULL; +static int g_use_physical_span = 0; + +static void ensure_visible_ids(void) { + if (g_initialized) { + return; + } + + g_initialized = 1; + + const char *raw = getenv("GPTQMODEL_VLLM_HEALTHY_PHYSICAL_IDS"); + if (raw == NULL || *raw == '\0') { + raw = getenv("GPTQMODEL_VLLM_VISIBLE_PHYSICAL_IDS"); + } else { + g_use_physical_span = 1; + } + if (raw == NULL || *raw == '\0') { + return; + } + + char *copy = strdup(raw); + if (copy == NULL) { + return; + } + + char *cursor = copy; + while (cursor != NULL && *cursor != '\0' && g_id_count < (sizeof(g_ids) / sizeof(g_ids[0]))) { + char *next = strchr(cursor, ','); + if (next != NULL) { + *next = '\0'; + } + + char *end = NULL; + long value = strtol(cursor, &end, 10); + if (end != cursor && *end == '\0' && value >= 0 && value <= INT_MAX) { + g_ids[g_id_count++] = (int)value; + } + + cursor = next == NULL ? NULL : next + 1; + } + + free(copy); +} + +static unsigned int mapped_index(unsigned int index) { + if (g_id_count == 0) { + return index; + } + + if (!g_use_physical_span) { + if (index >= g_id_count) { + return UINT_MAX; + } + return (unsigned int)g_ids[index]; + } + + for (size_t i = 0; i < g_id_count; ++i) { + if ((unsigned int)g_ids[i] == index) { + return index; + } + if ((unsigned int)g_ids[i] > index) { + return (unsigned int)g_ids[i]; + } + } + + return (unsigned int)g_ids[g_id_count - 1]; +} + +static void *resolve_symbol(const char *name) { + void *symbol = dlsym(RTLD_NEXT, name); + if (symbol != NULL) { + return symbol; + } + + if (g_nvml_handle == NULL) { + g_nvml_handle = dlopen("libnvidia-ml.so.1", RTLD_LAZY | RTLD_LOCAL); + } + if (g_nvml_handle == NULL) { + return NULL; + } + return dlsym(g_nvml_handle, name); +} + +static nvmlReturn_t call_real_count(const char *name, unsigned int *count) { + nvml_count_fn fn = (nvml_count_fn)resolve_symbol(name); + if (fn == NULL) { + return NVML_ERROR_INVALID_ARGUMENT; + } + return fn(count); +} + +static nvmlReturn_t call_real_handle(const char *name, unsigned int index, nvmlDevice_t *device) { + nvml_handle_fn fn = (nvml_handle_fn)resolve_symbol(name); + if (fn == NULL) { + return NVML_ERROR_INVALID_ARGUMENT; + } + return fn(index, device); +} + +nvmlReturn_t nvmlDeviceGetCount_v2(unsigned int *count) { + if (count == NULL) { + return NVML_ERROR_INVALID_ARGUMENT; + } + + ensure_visible_ids(); + if (g_id_count > 0) { + if (g_use_physical_span) { + *count = (unsigned int)g_ids[g_id_count - 1] + 1U; + } else { + *count = (unsigned int)g_id_count; + } + return NVML_SUCCESS; + } + + return call_real_count("nvmlDeviceGetCount_v2", count); +} + +nvmlReturn_t nvmlDeviceGetCount(unsigned int *count) { + if (count == NULL) { + return NVML_ERROR_INVALID_ARGUMENT; + } + + ensure_visible_ids(); + if (g_id_count > 0) { + if (g_use_physical_span) { + *count = (unsigned int)g_ids[g_id_count - 1] + 1U; + } else { + *count = (unsigned int)g_id_count; + } + return NVML_SUCCESS; + } + + return call_real_count("nvmlDeviceGetCount", count); +} + +nvmlReturn_t nvmlDeviceGetHandleByIndex_v2(unsigned int index, nvmlDevice_t *device) { + if (device == NULL) { + return NVML_ERROR_INVALID_ARGUMENT; + } + + ensure_visible_ids(); + if (g_id_count > 0) { + index = mapped_index(index); + if (index == UINT_MAX) { + return NVML_ERROR_INVALID_ARGUMENT; + } + } + + return call_real_handle("nvmlDeviceGetHandleByIndex_v2", index, device); +} + +nvmlReturn_t nvmlDeviceGetHandleByIndex(unsigned int index, nvmlDevice_t *device) { + if (device == NULL) { + return NVML_ERROR_INVALID_ARGUMENT; + } + + ensure_visible_ids(); + if (g_id_count > 0) { + index = mapped_index(index); + if (index == UINT_MAX) { + return NVML_ERROR_INVALID_ARGUMENT; + } + } + + return call_real_handle("nvmlDeviceGetHandleByIndex", index, device); +} diff --git a/scripts/paroquant_first_layer_ab.py b/scripts/paroquant_first_layer_ab.py new file mode 100644 index 000000000..b9aad3a05 --- /dev/null +++ b/scripts/paroquant_first_layer_ab.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +from pathlib import Path + +from tabulate import tabulate + +from gptqmodel.utils.paroquant_benchmark import ( + comparison_rows, + render_case_tables, + run_fp16_eval, + run_paroquant_first_layer_case, + write_case_json, +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Benchmark first-N-layer ParoQuant on Llama-3.2-1B-Instruct and evaluate GSM8K Platinum." + ) + parser.add_argument("--model", default="/monster/data/model/Llama-3.2-1B-Instruct") + parser.add_argument("--quant-layers", type=int, default=1, help="Quantize the first N decoder layers.") + parser.add_argument("--calibration-rows", type=int, default=64) + parser.add_argument("--calibration-concat-size", type=int, default=2048) + parser.add_argument("--quant-batch-size", type=int, default=1) + parser.add_argument("--eval-batch-size", type=int, default=64) + parser.add_argument("--eval-max-rows", type=int, default=None) + parser.add_argument( + "--sym", + action="store_true", + default=True, + help="ParoQuant is sym-only; this flag is kept for compatibility and has no effect.", + ) + parser.add_argument( + "--no-fused-opt-rotation", + action="store_true", + help="Disable the fused CUDA rotation autograd path during ParoQuant optimization.", + ) + parser.add_argument( + "--opt-scope", + choices=("module", "compute_block", "layer"), + default="module", + help="ParoQuant optimization scope for the selected decoder layers.", + ) + parser.add_argument("--opt-rotation-epochs", type=int, default=10) + parser.add_argument("--opt-finetune-epochs", type=int, default=10) + parser.add_argument("--opt-train-samples", type=int, default=2048) + parser.add_argument("--opt-validation-samples", type=int, default=64) + parser.add_argument("--opt-batch-size", type=int, default=64) + parser.add_argument("--skip-fp16-baseline", action="store_true") + parser.add_argument( + "--output-json", + type=Path, + default=None, + help="Optional JSON path for the quantized-case result payload.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + + baseline = None + if not args.skip_fp16_baseline: + baseline = run_fp16_eval( + model_path=args.model, + eval_batch_size=args.eval_batch_size, + eval_max_rows=args.eval_max_rows, + ) + baseline["label"] = "fp16_baseline" + + quant_case = run_paroquant_first_layer_case( + model_path=args.model, + num_quant_layers=args.quant_layers, + calibration_rows=args.calibration_rows, + calibration_concat_size=args.calibration_concat_size, + quant_batch_size=args.quant_batch_size, + eval_batch_size=args.eval_batch_size, + eval_max_rows=args.eval_max_rows, + sym=args.sym, + fused_opt_rotation=not args.no_fused_opt_rotation, + opt_scope=args.opt_scope, + opt_rotation_epochs=args.opt_rotation_epochs, + opt_finetune_epochs=args.opt_finetune_epochs, + opt_train_samples=args.opt_train_samples, + opt_validation_samples=args.opt_validation_samples, + opt_batch_size=args.opt_batch_size, + ) + quant_case["label"] = ( + "paroquant_first_layer" if args.quant_layers == 1 else f"paroquant_first_{args.quant_layers}_layers" + ) + + if args.output_json is not None: + write_case_json(quant_case, args.output_json) + + cases = [case for case in (baseline, quant_case) if case is not None] + print("Summary") + print( + tabulate( + comparison_rows(*cases), + headers=["case", "opt_scope", "sym", "fused_opt", "gsm8k_platinum_cot", "quant_wall_s", "eval_wall_s"], + tablefmt="grid", + ) + ) + + tables = render_case_tables(quant_case) + print() + print("Quant Module Times") + print(tables["module_times"]) + print() + print("Quant Region Times") + print(tables["regions"]) + print() + print("Kernel Parity And Speed") + print(tables["kernels"]) + print() + print("Eval") + print(tables["eval"]) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/paroquant_module_set_scan.py b/scripts/paroquant_module_set_scan.py new file mode 100644 index 000000000..9350ec68d --- /dev/null +++ b/scripts/paroquant_module_set_scan.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from tabulate import tabulate + +from gptqmodel.utils.paroquant_benchmark import run_paroquant_selected_modules_case + + +DEFAULT_CASES = [ + ("attn_all", ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj"]), + ("mlp_all", ["mlp.gate_proj", "mlp.up_proj", "mlp.down_proj"]), + ("mlp_gate_up", ["mlp.gate_proj", "mlp.up_proj"]), + ("mlp_gate_down", ["mlp.gate_proj", "mlp.down_proj"]), + ("mlp_up_down", ["mlp.up_proj", "mlp.down_proj"]), +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Quantize selected module sets inside one decoder layer and evaluate GSM8K Platinum." + ) + parser.add_argument("--model", default="/monster/data/model/Llama-3.2-1B-Instruct") + parser.add_argument("--layer-index", type=int, required=True, help="Zero-based decoder layer index to probe.") + parser.add_argument( + "--case", + dest="cases", + action="append", + default=None, + help="Case in the form label=module_a,module_b . Can be passed multiple times.", + ) + parser.add_argument("--calibration-rows", type=int, default=64) + parser.add_argument("--calibration-concat-size", type=int, default=2048) + parser.add_argument("--quant-batch-size", type=int, default=1) + parser.add_argument("--eval-batch-size", type=int, default=64) + parser.add_argument("--eval-max-rows", type=int, default=None) + parser.add_argument("--opt-rotation-epochs", type=int, default=10) + parser.add_argument("--opt-finetune-epochs", type=int, default=10) + parser.add_argument("--opt-train-samples", type=int, default=2048) + parser.add_argument("--opt-validation-samples", type=int, default=64) + parser.add_argument("--opt-batch-size", type=int, default=64) + parser.add_argument("--output-json", type=Path, default=None) + return parser.parse_args() + + +def _parse_cases(raw_cases: list[str] | None) -> list[tuple[str, list[str]]]: + if not raw_cases: + return [(label, list(modules)) for label, modules in DEFAULT_CASES] + + parsed: list[tuple[str, list[str]]] = [] + for item in raw_cases: + label, sep, module_csv = str(item).partition("=") + if not sep: + raise ValueError(f"Invalid --case `{item}`. Expected label=module_a,module_b") + modules = [module.strip() for module in module_csv.split(",") if module.strip()] + if not modules: + raise ValueError(f"Invalid --case `{item}`. At least one module is required.") + parsed.append((label.strip(), modules)) + return parsed + + +def _score(case: dict) -> float | None: + metric = case.get("eval_metrics") or case.get("metrics") or {} + gsm = metric.get("gsm8k_platinum_cot", {}) + if not isinstance(gsm, dict): + return None + return float(gsm["acc,num"]) if "acc,num" in gsm else None + + +def main() -> int: + args = parse_args() + parsed_cases = _parse_cases(args.cases) + + results: list[dict] = [] + for label, module_names in parsed_cases: + case = run_paroquant_selected_modules_case( + model_path=args.model, + layer_idx=args.layer_index, + module_names=module_names, + calibration_rows=args.calibration_rows, + calibration_concat_size=args.calibration_concat_size, + quant_batch_size=args.quant_batch_size, + eval_batch_size=args.eval_batch_size, + eval_max_rows=args.eval_max_rows, + sym=True, + fused_opt_rotation=True, + opt_rotation_epochs=args.opt_rotation_epochs, + opt_finetune_epochs=args.opt_finetune_epochs, + opt_train_samples=args.opt_train_samples, + opt_validation_samples=args.opt_validation_samples, + opt_batch_size=args.opt_batch_size, + ) + case["label"] = label + results.append(case) + + rows = [] + for case in results: + rows.append( + [ + case["label"], + ",".join(case.get("module_names", [])), + "" if _score(case) is None else f"{_score(case):.6f}", + f"{float(case['quant_wall_s']):.3f}", + f"{float(case['eval_wall_s']):.3f}", + ] + ) + rows.sort(key=lambda item: float(item[2] or 0.0)) + + print("Summary") + print( + tabulate( + rows, + headers=["case", "module_names", "gsm8k_platinum_cot", "quant_wall_s", "eval_wall_s"], + tablefmt="grid", + ) + ) + + if args.output_json is not None: + args.output_json.parent.mkdir(parents=True, exist_ok=True) + with args.output_json.open("w", encoding="utf-8") as handle: + json.dump(results, handle, indent=2, sort_keys=True) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/paroquant_single_module_scan.py b/scripts/paroquant_single_module_scan.py new file mode 100644 index 000000000..7598b36b7 --- /dev/null +++ b/scripts/paroquant_single_module_scan.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +from pathlib import Path + +from tabulate import tabulate + +from gptqmodel.utils.paroquant_benchmark import run_paroquant_single_module_case + + +DEFAULT_MODULES = [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", + "mlp.gate_proj", + "mlp.up_proj", + "mlp.down_proj", +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Quantize one module at a time inside one decoder layer and evaluate GSM8K Platinum." + ) + parser.add_argument("--model", default="/monster/data/model/Llama-3.2-1B-Instruct") + parser.add_argument("--layer-index", type=int, required=True, help="Zero-based decoder layer index to probe.") + parser.add_argument( + "--module", + dest="modules", + action="append", + default=None, + help="Relative module path inside the target layer. Can be passed multiple times.", + ) + parser.add_argument("--calibration-rows", type=int, default=64) + parser.add_argument("--calibration-concat-size", type=int, default=2048) + parser.add_argument("--quant-batch-size", type=int, default=1) + parser.add_argument("--eval-batch-size", type=int, default=64) + parser.add_argument("--eval-max-rows", type=int, default=None) + parser.add_argument("--opt-rotation-epochs", type=int, default=10) + parser.add_argument("--opt-finetune-epochs", type=int, default=10) + parser.add_argument("--opt-train-samples", type=int, default=2048) + parser.add_argument("--opt-validation-samples", type=int, default=64) + parser.add_argument("--opt-batch-size", type=int, default=64) + parser.add_argument( + "--output-json", + type=Path, + default=None, + help="Optional aggregate JSON output path for all module scan results.", + ) + return parser.parse_args() + + +def _score(case: dict) -> float | None: + metric = case.get("eval_metrics") or case.get("metrics") or {} + gsm = metric.get("gsm8k_platinum_cot", {}) + if not isinstance(gsm, dict): + return None + return float(gsm["acc,num"]) if "acc,num" in gsm else None + + +def main() -> int: + args = parse_args() + modules = args.modules or list(DEFAULT_MODULES) + + cases: list[dict] = [] + for module_name in modules: + case = run_paroquant_single_module_case( + model_path=args.model, + layer_idx=args.layer_index, + module_name=module_name, + calibration_rows=args.calibration_rows, + calibration_concat_size=args.calibration_concat_size, + quant_batch_size=args.quant_batch_size, + eval_batch_size=args.eval_batch_size, + eval_max_rows=args.eval_max_rows, + sym=True, + fused_opt_rotation=True, + opt_rotation_epochs=args.opt_rotation_epochs, + opt_finetune_epochs=args.opt_finetune_epochs, + opt_train_samples=args.opt_train_samples, + opt_validation_samples=args.opt_validation_samples, + opt_batch_size=args.opt_batch_size, + ) + case["label"] = f"layer_{args.layer_index}:{module_name}" + cases.append(case) + + rows = [] + for case in cases: + rows.append( + [ + case["layer_idx"], + case["module_name"], + "" if _score(case) is None else f"{_score(case):.6f}", + f"{float(case['quant_wall_s']):.3f}", + f"{float(case['eval_wall_s']):.3f}", + ] + ) + rows.sort(key=lambda item: float(item[2] or 0.0)) + + print("Summary") + print( + tabulate( + rows, + headers=["layer_idx", "module_name", "gsm8k_platinum_cot", "quant_wall_s", "eval_wall_s"], + tablefmt="grid", + ) + ) + + if args.output_json is not None: + args.output_json.parent.mkdir(parents=True, exist_ok=True) + with args.output_json.open("w", encoding="utf-8") as handle: + json.dump(cases, handle, indent=2, sort_keys=True) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/profile_paroquant_runtime_cache_case.py b/scripts/profile_paroquant_runtime_cache_case.py new file mode 100644 index 000000000..670128050 --- /dev/null +++ b/scripts/profile_paroquant_runtime_cache_case.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import json +import os +from dataclasses import dataclass +from pathlib import Path + +import torch + + +@dataclass(frozen=True) +class BenchCase: + case_id: str + batch: int + seq: int + in_features: int + out_features: int + group_size: int = 128 + krot: int = 8 + + +CASES = { + "decode_q_proj": BenchCase("decode_q_proj", batch=1, seq=1, in_features=2048, out_features=2048), + "prefill_q_proj": BenchCase("prefill_q_proj", batch=1, seq=128, in_features=2048, out_features=2048), + "batched_down_proj": BenchCase("batched_down_proj", batch=4, seq=128, in_features=8192, out_features=2048), +} + + +def _resolve_dtype(name: str) -> torch.dtype: + if name == "fp16": + return torch.float16 + if name == "bf16": + return torch.bfloat16 + raise ValueError(f"Unsupported dtype: {name}") + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + packed = torch.zeros((unpacked.shape[0], unpacked.shape[1] // pack_factor), dtype=torch.int32) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_quant_buffers(case: BenchCase, dtype: torch.dtype, bits: int = 4) -> dict[str, torch.Tensor]: + from gptqmodel.utils.paroquant import build_identity_rotation_buffers + + groups = case.in_features // case.group_size + int_weight = torch.randint(0, 2**bits, size=(case.in_features, case.out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, case.out_features), dtype=torch.int32) + scales = (torch.rand(groups, case.out_features, dtype=torch.float32) * 0.5) + 0.75 + scales = scales.to(dtype=dtype) + bias = torch.randn(case.out_features, dtype=torch.float32).to(dtype=dtype) + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=case.in_features, + group_size=case.group_size, + krot=case.krot, + dtype=dtype, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + + return { + "qweight": _pack_awq_tensor(int_weight, bits), + "qzeros": _pack_awq_tensor(zero_points, bits), + "scales": scales, + "bias": bias, + "pairs": pairs, + "theta": theta, + "channel_scales": channel_scales, + } + + +def _configure_runtime(args: argparse.Namespace, device: torch.device) -> None: + if args.force_rebuild_awq: + awq_build_root = Path("/tmp") / ( + f"awq_jit_profile_{os.getpid()}_{args.case_id}_{args.dtype}_" + f"rt{int(args.cache_runtime_dtype)}_rot{int(args.cache_rotation_dtype)}" + ) + os.environ["GPTQMODEL_AWQ_BUILD_ROOT"] = str(awq_build_root) + os.environ["GPTQMODEL_AWQ_FORCE_REBUILD"] = "1" + elif "GPTQMODEL_AWQ_BUILD_ROOT" not in os.environ: + os.environ.pop("GPTQMODEL_AWQ_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_AWQ_FORCE_REBUILD", None) + + if args.force_rebuild_paroquant: + paro_build_root = Path("/tmp") / ( + f"paroquant_ext_profile_{os.getpid()}_{args.case_id}_{args.dtype}_" + f"rt{int(args.cache_runtime_dtype)}_rot{int(args.cache_rotation_dtype)}" + ) + os.environ["GPTQMODEL_PAROQUANT_BUILD_ROOT"] = str(paro_build_root) + os.environ["GPTQMODEL_PAROQUANT_FORCE_REBUILD"] = "1" + elif "GPTQMODEL_PAROQUANT_BUILD_ROOT" not in os.environ: + os.environ.pop("GPTQMODEL_PAROQUANT_BUILD_ROOT", None) + os.environ.pop("GPTQMODEL_PAROQUANT_FORCE_REBUILD", None) + + from gptqmodel.utils.awq import clear_awq_extension_cache, prewarm_awq_extension + from gptqmodel.utils.paroquant import clear_paroquant_rotation_extension_cache, prewarm_paroquant_rotation_extension + + if args.force_rebuild_awq: + clear_awq_extension_cache() + if args.force_rebuild_paroquant: + clear_paroquant_rotation_extension_cache() + + if not prewarm_awq_extension(): + raise RuntimeError("Failed to build/load AWQ runtime.") + if not prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=device, + ): + raise RuntimeError("Failed to build/load ParoQuant runtime.") + + +def _make_module( + case: BenchCase, + dtype: torch.dtype, + device: torch.device, + cache_runtime_dtype: bool, + cache_rotation_dtype: bool, +): + from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear + + buffers = _make_quant_buffers(case, dtype=dtype) + module = ParoLinear( + bits=4, + group_size=case.group_size, + sym=True, + desc_act=False, + in_features=case.in_features, + out_features=case.out_features, + bias=True, + register_buffers=True, + krot=case.krot, + cache_runtime_dtype=cache_runtime_dtype, + auto_cache_bf16_runtime_dtype=cache_runtime_dtype, + cache_rotation_dtype=cache_rotation_dtype, + auto_cache_bf16_rotation_dtype=cache_rotation_dtype, + ).to(device) + module.qweight.copy_(buffers["qweight"].to(device)) + module.qzeros.copy_(buffers["qzeros"].to(device)) + module.scales.copy_(buffers["scales"].to(device)) + module.bias.copy_(buffers["bias"].to(device)) + module.pairs.copy_(buffers["pairs"].to(device)) + module.theta.copy_(buffers["theta"].to(device)) + module.channel_scales.copy_(buffers["channel_scales"].to(device)) + module.post_init() + module.eval() + return module + + +def main() -> int: + parser = argparse.ArgumentParser(description="Profile one ParoQuant runtime-cache case.") + parser.add_argument("--case-id", choices=tuple(CASES.keys()), required=True) + parser.add_argument("--dtype", choices=("fp16", "bf16"), default="bf16") + parser.add_argument("--device", type=int, default=0) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--iters", type=int, default=50) + parser.add_argument("--cache-runtime-dtype", action="store_true") + parser.add_argument("--cache-rotation-dtype", action="store_true") + parser.add_argument("--force-rebuild-awq", action="store_true") + parser.add_argument("--force-rebuild-paroquant", action="store_true") + parser.add_argument("--torch-profiler-json-out", type=Path, default=None) + parser.add_argument("--torch-profiler-top-n", type=int, default=12) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required.") + + device = torch.device(f"cuda:{args.device}") + case = CASES[args.case_id] + dtype = _resolve_dtype(args.dtype) + _configure_runtime(args, device) + + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + + module = _make_module( + case, + dtype=dtype, + device=device, + cache_runtime_dtype=args.cache_runtime_dtype, + cache_rotation_dtype=args.cache_rotation_dtype, + ) + x = torch.randn((case.batch, case.seq, case.in_features), device=device, dtype=dtype) + + with torch.inference_mode(): + for _ in range(args.warmup): + module(x) + torch.cuda.synchronize(device) + + if args.torch_profiler_json_out is not None: + from torch.profiler import ProfilerActivity, profile + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=False) as prof: + for _ in range(args.iters): + module(x) + torch.cuda.synchronize(device) + + events = [] + for evt in prof.key_averages(): + cuda_time_total = getattr(evt, "cuda_time_total", getattr(evt, "device_time_total", 0.0)) + self_cuda_time_total = getattr(evt, "self_cuda_time_total", getattr(evt, "self_device_time_total", 0.0)) + events.append( + { + "key": evt.key, + "count": evt.count, + "cpu_time_total_us": evt.cpu_time_total, + "self_cpu_time_total_us": evt.self_cpu_time_total, + "cuda_time_total_us": cuda_time_total, + "self_cuda_time_total_us": self_cuda_time_total, + } + ) + events.sort(key=lambda row: row["cuda_time_total_us"], reverse=True) + args.torch_profiler_json_out.parent.mkdir(parents=True, exist_ok=True) + args.torch_profiler_json_out.write_text( + json.dumps( + { + "case_id": args.case_id, + "dtype": args.dtype, + "cache_runtime_dtype": args.cache_runtime_dtype, + "cache_rotation_dtype": args.cache_rotation_dtype, + "top_events": events[: args.torch_profiler_top_n], + }, + indent=2, + ), + encoding="utf-8", + ) + else: + for _ in range(args.iters): + module(x) + torch.cuda.synchronize(device) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/repro_issue_2326.py b/scripts/repro_issue_2326.py new file mode 100644 index 000000000..d8ea7ea22 --- /dev/null +++ b/scripts/repro_issue_2326.py @@ -0,0 +1,364 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Reproduce issue #2326 with a minimal BALANCED multi-GPU forward leak.""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +import threading +import types +from contextlib import nullcontext +from pathlib import Path + +os.environ.setdefault("GPTQMODEL_DEVICE_TELEMETRY", "1") + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +import torch + +from gptqmodel.looper.forward_executor import ForwardExecutor +from gptqmodel.looper.loop_processor import ExecutionConfig +from gptqmodel.looper.module_looper import ModuleLooper +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.looper.stage_subset import build_subset_plan +from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy +from gptqmodel.quantization.config import VramStrategy +from gptqmodel.utils.device_telemetry import ( + clear_device_telemetry_records, + get_device_telemetry_records, +) + + +class _Expert(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.gate_proj = torch.nn.Linear(4, 4, bias=False) + self.up_proj = torch.nn.Linear(4, 4, bias=False) + self.down_proj = torch.nn.Linear(4, 4, bias=False) + for mod in (self.gate_proj, self.up_proj, self.down_proj): + torch.nn.init.eye_(mod.weight) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.down_proj(self.gate_proj(hidden_states) + self.up_proj(hidden_states)) + + +class _SelfAttention(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.q_proj = torch.nn.Linear(4, 4, bias=False) + self.k_proj = torch.nn.Linear(4, 4, bias=False) + self.v_proj = torch.nn.Linear(4, 4, bias=False) + self.o_proj = torch.nn.Linear(4, 4, bias=False) + for mod in (self.q_proj, self.k_proj, self.v_proj, self.o_proj): + torch.nn.init.eye_(mod.weight) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.o_proj( + self.q_proj(hidden_states) + self.k_proj(hidden_states) + self.v_proj(hidden_states) + ) + + +class _ToyLayer(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.self_attn = _SelfAttention() + self.mlp = torch.nn.Module() + self.mlp.experts = torch.nn.ModuleList([_Expert(), _Expert()]) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask=None, + position_ids=None, + **_kwargs, + ) -> torch.Tensor: + hidden_states = self.self_attn(hidden_states) + output = 0 + for expert in self.mlp.experts: + output = output + expert(hidden_states) + return output + + +class _PlanLooper: + def __init__(self, devices: list[torch.device]) -> None: + self._quant_devices = devices + self._dense_quant_devices = [devices[0]] + self._moe_quant_devices = devices + self._dense_vram_strategy = VramStrategy.EXCLUSIVE + self._moe_vram_strategy = VramStrategy.BALANCED + self._dense_vram_strategy_explicit = False + self._moe_vram_strategy_explicit = True + self._moe_subset_threshold = 2 + self.gptq_model = types.SimpleNamespace( + lm_head=None, + quantize_config=types.SimpleNamespace( + auto_forward_data_parallel=True, + moe=None, + ), + ) + + @staticmethod + def _is_attention_module_name(name: str) -> bool: + return name.startswith("self_attn.") + + @staticmethod + def _extract_moe_group_key(name: str) -> str | None: + parts = name.split(".") + if "experts" not in parts: + return None + idx = parts.index("experts") + return ".".join(parts[: idx + 2]) + + @staticmethod + def _resolve_batch_total(_num_batches, layer_inputs) -> int: + return len(layer_inputs) + + @staticmethod + def _collect_row_counts(layer_inputs) -> list[int]: + return [int(batch[0].shape[0]) for batch in layer_inputs] + + +class _DummyGptqModel: + def __init__(self) -> None: + self.quantize_config = types.SimpleNamespace( + auto_forward_data_parallel=True, + calibration_data_device=None, + ) + + @staticmethod + def shell_module_materialize(target_submodule, device, role, named_module=None): + return target_submodule + + @staticmethod + def prepare_layer_replay_kwargs(layer, layer_input, additional_inputs, target_device): + return additional_inputs + + +class _ExecLooper: + support_batch_quantize = False + moe_routing_override = None + moe_routing_bypass = False + _current_subset = None + MoERoutingOverrideContext = staticmethod(lambda *args, **kwargs: nullcontext()) + MoELifecycleContext = staticmethod(lambda *args, **kwargs: nullcontext()) + _assign_quant_device_for_module = ModuleLooper._assign_quant_device_for_module + _rehome_processor_task = ModuleLooper._rehome_processor_task + _prepare_named_module_for_quantization = ModuleLooper._prepare_named_module_for_quantization + _apply_forward_device_overrides = ModuleLooper._apply_forward_device_overrides + _restore_forward_device_overrides = ModuleLooper._restore_forward_device_overrides + + def __init__(self, devices: list[torch.device]) -> None: + self.gptq_model = _DummyGptqModel() + self._quant_devices = devices + self._module_device_map = {} + self._quant_device_lock = threading.Lock() + self._quant_device_rr = 0 + + @staticmethod + def _resolve_batch_total(_num_batches, layer_inputs) -> int: + return len(layer_inputs) + + @staticmethod + def _collect_row_counts(layer_inputs) -> list[int]: + return [int(batch[0].shape[0]) for batch in layer_inputs] + + @staticmethod + def _batch_row_count(batch_inputs) -> int: + return int(batch_inputs[0].shape[0]) + + @staticmethod + def _set_processor_mask(processor, mask) -> None: + return None + + +def _nvidia_smi_snapshot() -> list[str]: + try: + output = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=index,pci.bus_id,name", + "--format=csv,noheader", + ], + text=True, + ) + except Exception as exc: + return [f"nvidia-smi unavailable: {exc}"] + return [line.strip() for line in output.splitlines() if line.strip()] + + +def _named_modules_for_layer(layer: _ToyLayer) -> dict[str, NamedModule]: + module_refs = { + "self_attn.q_proj": layer.self_attn.q_proj, + "self_attn.k_proj": layer.self_attn.k_proj, + "self_attn.v_proj": layer.self_attn.v_proj, + "self_attn.o_proj": layer.self_attn.o_proj, + "mlp.experts.0.gate_proj": layer.mlp.experts[0].gate_proj, + "mlp.experts.0.up_proj": layer.mlp.experts[0].up_proj, + "mlp.experts.0.down_proj": layer.mlp.experts[0].down_proj, + "mlp.experts.1.gate_proj": layer.mlp.experts[1].gate_proj, + "mlp.experts.1.up_proj": layer.mlp.experts[1].up_proj, + "mlp.experts.1.down_proj": layer.mlp.experts[1].down_proj, + } + return { + name: NamedModule( + mod, + name=name, + full_name=f"model.layers.0.{name}", + layer_index=0, + ) + for name, mod in module_refs.items() + } + + +def main() -> int: + if not torch.cuda.is_available() or torch.cuda.device_count() < 2: + print("Requires at least 2 visible CUDA devices.", file=sys.stderr) + return 2 + if os.environ.get("CUDA_DEVICE_ORDER") != "PCI_BUS_ID": + print("Set CUDA_DEVICE_ORDER=PCI_BUS_ID before running.", file=sys.stderr) + return 2 + if os.environ.get("CUDA_VISIBLE_DEVICES") != "0,1": + print("Set CUDA_VISIBLE_DEVICES=0,1 before running.", file=sys.stderr) + return 2 + if getattr(sys, "_is_gil_enabled", lambda: True)(): + print("Run with PYTHON_GIL=0 so the free-threaded build keeps the GIL disabled.", file=sys.stderr) + return 2 + + clear_device_telemetry_records() + + devices = [torch.device("cuda:0"), torch.device("cuda:1")] + layer = _ToyLayer() + replace_module_with_hooked_legacy(layer) + layer = layer.to(devices[0]) + layer.target_device = devices[0] + + full_named = _named_modules_for_layer(layer) + full = {name: named.module for name, named in full_named.items()} + subset_names = [ + "mlp.experts.0.gate_proj", + "mlp.experts.0.up_proj", + "mlp.experts.1.gate_proj", + "mlp.experts.1.up_proj", + ] + subset = {name: full_named[name] for name in subset_names} + layer_inputs = [[torch.ones(1, 2, 4, device=devices[0])]] + + plan = build_subset_plan( + _PlanLooper(devices), + processor=types.SimpleNamespace( + execution_config=ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + ) + ), + subset=subset, + subset_index=0, + subset_total=1, + full=full, + fallback=True, + layer_inputs=layer_inputs, + ) + + looper = _ExecLooper(devices) + empty_tasks = types.SimpleNamespace(tasks={}) + processor = types.SimpleNamespace( + num_batches=None, + _set_current_batch_index=lambda idx: None, + ) + + looper._prepare_named_module_for_quantization( + processor=empty_tasks, + named_module=full_named["self_attn.q_proj"], + fallback_device=devices[0], + ) + looper._prepare_named_module_for_quantization( + processor=empty_tasks, + named_module=full_named["self_attn.o_proj"], + fallback_device=devices[0], + ) + o_proj_device_after_quant_prepare = str(layer.self_attn.o_proj.weight.device) + + previous_devices = looper._apply_forward_device_overrides( + subset, + plan.forward_device_map, + fallback_modules=full, + ) + + executor = ForwardExecutor(looper) + executor.run_single( + module=layer, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[None], + cur_layer_device=devices[0], + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=True, + reuse_kv=False, + preserve_module_devices=plan.preserve_module_devices, + ) + + if plan.restore_forward_device_overrides: + looper._restore_forward_device_overrides( + subset, + previous_devices, + fallback_modules=full, + ) + + telemetry = get_device_telemetry_records() + o_proj_forward = [ + record + for record in telemetry + if record.get("event") == "hooked_linear_forward" + and record.get("module") == "model.layers.0.self_attn.o_proj" + ] + + summary = { + "python_gil_enabled": getattr(sys, "_is_gil_enabled", lambda: None)(), + "cuda_device_order": os.environ.get("CUDA_DEVICE_ORDER"), + "cuda_visible_devices": os.environ.get("CUDA_VISIBLE_DEVICES"), + "visible_cuda_devices": [ + { + "index": idx, + "name": torch.cuda.get_device_properties(idx).name, + } + for idx in range(torch.cuda.device_count()) + ], + "nvidia_smi": _nvidia_smi_snapshot(), + "plan_forward_device_map": {name: str(device) for name, device in plan.forward_device_map.items()}, + "o_proj_in_forward_device_map": "self_attn.o_proj" in plan.forward_device_map, + "restore_forward_device_overrides": plan.restore_forward_device_overrides, + "o_proj_device_after_quant_prepare": o_proj_device_after_quant_prepare, + "o_proj_device_after_forward_restore": str(layer.self_attn.o_proj.weight.device), + "o_proj_forward_records": o_proj_forward, + } + + reproduced = ( + not summary["o_proj_in_forward_device_map"] + and summary["o_proj_device_after_quant_prepare"] == "cuda:1" + and summary["o_proj_device_after_forward_restore"] == "cuda:1" + and any( + record.get("input_device") == "cuda:0" and record.get("weight_device") == "cuda:1" + for record in o_proj_forward + ) + ) + summary["issue_2326_reproduced"] = reproduced + + print(json.dumps(summary, indent=2, sort_keys=True)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/run_gptq_pro_validate.sh b/scripts/run_gptq_pro_validate.sh new file mode 100755 index 000000000..1447c59d7 --- /dev/null +++ b/scripts/run_gptq_pro_validate.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +set -euo pipefail + +SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd) +REPO_ROOT=$(cd -- "${SCRIPT_DIR}/.." && pwd) + +PRIMARY_GPU="${1:-${GPTQ_PRO_VALIDATE_GPU:-2}}" +FALLBACK_GPU="${GPTQ_PRO_VALIDATE_FALLBACK_GPU:-3}" +VALIDATE_SRC="${REPO_ROOT}/gptqmodel_ext/gptq_pro/gptq_pro_validate.cu" +KERNEL_SRC="${REPO_ROOT}/gptqmodel_ext/gptq_pro/gptq_pro_kernel.cu" +OUT_BIN="${TMPDIR:-/tmp}/gptq_pro_validate_$(date +%s)" + +cleanup() { + rm -f "${OUT_BIN}" +} +trap cleanup EXIT + +run_on_gpu() { + local gpu_index="$1" + local gpu_name + gpu_name=$(nvidia-smi -i "${gpu_index}" --query-gpu=name --format=csv,noheader | tr -d '[:space:]') + + echo "==> Building standalone validator" + nvcc -arch=sm_80 -std=c++17 "${VALIDATE_SRC}" "${KERNEL_SRC}" -o "${OUT_BIN}" + + echo "==> Selected GPU ${gpu_index}" + nvidia-smi -i "${gpu_index}" \ + --query-gpu=index,name,uuid,memory.total,memory.used,memory.free \ + --format=csv,noheader + + if [[ "${gpu_name}" != *"RTX3060"* && "${gpu_name}" != *"RTX 3060"* ]]; then + echo "warning: GPU ${gpu_index} is not an RTX 3060 (${gpu_name})" >&2 + fi + + echo "==> Running validator" + CUDA_VISIBLE_DEVICES="${gpu_index}" "${OUT_BIN}" +} + +if ! run_on_gpu "${PRIMARY_GPU}"; then + if [[ $# -eq 0 && "${PRIMARY_GPU}" != "${FALLBACK_GPU}" ]]; then + echo "==> Primary GPU ${PRIMARY_GPU} failed; retrying on fallback GPU ${FALLBACK_GPU}" >&2 + run_on_gpu "${FALLBACK_GPU}" + else + exit 1 + fi +fi diff --git a/scripts/serve_vllm_qwen35.py b/scripts/serve_vllm_qwen35.py new file mode 100644 index 000000000..19414b2e6 --- /dev/null +++ b/scripts/serve_vllm_qwen35.py @@ -0,0 +1,231 @@ +#!/usr/bin/env python3 +""" +Serve local Qwen 3.5 text checkpoints through vLLM's OpenAI-compatible server. + +Why this wrapper exists: + * qwen3_5_text checkpoints need a text-only vLLM setup in this environment. + * local machines with a broken physical GPU can crash vLLM's startup-time NVML + scan before the model ever loads. + * the fast inference path for GPTQ-Pro artifacts is vLLM's Marlin/Machete + runtime, not GPTQModel.generate(). +""" + +from __future__ import annotations + +import importlib +import importlib.util +import json +import multiprocessing as mp +import os +from pathlib import Path +import subprocess + + +DEFAULT_MM_LIMITS = {"image": 0, "video": 0} +SCRIPT_DIR = Path(__file__).resolve().parent +NVML_SHIM_SOURCE = SCRIPT_DIR / "nvml_visible_shim.c" +NVML_SHIM_OUTPUT = Path.home() / ".cache" / "gptqmodel" / "nvml_visible_shim.so" + + +def _apply_local_vllm_patches() -> None: + module_path = SCRIPT_DIR / "sitecustomize.py" + spec = importlib.util.spec_from_file_location("gptqmodel_local_sitecustomize", module_path) + if spec is None or spec.loader is None: + raise RuntimeError(f"Unable to load local sitecustomize module from {module_path}") + + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + module._patch_vllm_language_model_only_renderer() + module._register_vllm_qwen35_text_arches() + + +def _visible_physical_device_ids() -> list[int] | None: + raw = os.environ.get("CUDA_VISIBLE_DEVICES") + if not raw: + return None + + device_ids: list[int] = [] + for item in raw.split(","): + item = item.strip() + if not item: + continue + if not item.isdigit(): + return None + device_ids.append(int(item)) + return device_ids or None + + +def _patch_vllm_nvml_to_visible_devices() -> list[int] | None: + visible_ids = _visible_physical_device_ids() + if not visible_ids: + return None + + pynvml = importlib.import_module("vllm.third_party.pynvml") + original_get_handle = pynvml.nvmlDeviceGetHandleByIndex + + def mapped_get_count(): + return len(visible_ids) + + def mapped_get_handle(device_id: int): + if 0 <= device_id < len(visible_ids): + device_id = visible_ids[device_id] + return original_get_handle(device_id) + + pynvml.nvmlDeviceGetCount = mapped_get_count + pynvml.nvmlDeviceGetHandleByIndex = mapped_get_handle + return visible_ids + + +def _healthy_physical_device_ids() -> list[int] | None: + result = subprocess.run( + ["nvidia-smi", "--query-gpu=index", "--format=csv,noheader"], + capture_output=True, + text=True, + check=False, + ) + + device_ids: list[int] = [] + for line in result.stdout.splitlines(): + line = line.strip() + if line.isdigit(): + device_ids.append(int(line)) + return device_ids or None + + +def _ensure_nvml_preload_shim(visible_ids: list[int] | None) -> Path | None: + if not visible_ids: + return None + + NVML_SHIM_OUTPUT.parent.mkdir(parents=True, exist_ok=True) + needs_build = ( + not NVML_SHIM_OUTPUT.is_file() + or NVML_SHIM_OUTPUT.stat().st_mtime < NVML_SHIM_SOURCE.stat().st_mtime + ) + if needs_build: + subprocess.run( + [ + "cc", + "-shared", + "-fPIC", + "-O2", + "-o", + str(NVML_SHIM_OUTPUT), + str(NVML_SHIM_SOURCE), + "-ldl", + ], + check=True, + ) + + return NVML_SHIM_OUTPUT + + +def _propagate_sitecustomize(visible_ids: list[int] | None) -> None: + pythonpath = os.environ.get("PYTHONPATH") + entries = [str(SCRIPT_DIR)] + if pythonpath: + entries.append(pythonpath) + os.environ["PYTHONPATH"] = os.pathsep.join(entries) + + if visible_ids: + os.environ["GPTQMODEL_VLLM_VISIBLE_PHYSICAL_IDS"] = ",".join(str(idx) for idx in visible_ids) + healthy_ids = _healthy_physical_device_ids() + if healthy_ids: + os.environ["GPTQMODEL_VLLM_HEALTHY_PHYSICAL_IDS"] = ",".join( + str(idx) for idx in healthy_ids + ) + shim_path = _ensure_nvml_preload_shim(visible_ids) + if shim_path is not None: + preload = os.environ.get("LD_PRELOAD") + preload_entries = [str(shim_path)] + if preload: + preload_entries.append(preload) + os.environ["LD_PRELOAD"] = os.pathsep.join(preload_entries) + + +def _load_local_config(model_path: str) -> dict | None: + config_path = Path(model_path) / "config.json" + if not config_path.is_file(): + return None + with open(config_path, "r", encoding="utf-8") as handle: + return json.load(handle) + + +def _load_checkpoint_config(model_path: str) -> dict | None: + local_config = _load_local_config(model_path) + if local_config is not None: + return local_config + + try: + from transformers import AutoConfig + except ImportError: + return None + + try: + return AutoConfig.from_pretrained(model_path, trust_remote_code=True).to_dict() + except (OSError, ValueError): + return None + + +def _is_qwen35_text_checkpoint(model_path: str) -> bool: + config = _load_checkpoint_config(model_path) + return bool(config and config.get("model_type") == "qwen3_5_text") + + +def main() -> None: + os.environ.setdefault("VLLM_PLUGINS", "") + os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + try: + mp.set_start_method("spawn", force=True) + except RuntimeError: + pass + + visible_ids = _patch_vllm_nvml_to_visible_devices() + _propagate_sitecustomize(visible_ids) + _apply_local_vllm_patches() + + from vllm.entrypoints.openai.api_server import run_server + from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args + from vllm.utils.argparse_utils import FlexibleArgumentParser + import uvloop + + parser = FlexibleArgumentParser( + description="vLLM OpenAI-compatible server for qwen3_5_text GPTQ checkpoints." + ) + parser = make_arg_parser(parser) + args = parser.parse_args() + model_path = getattr(args, "model_tag", None) or getattr(args, "model", None) + if model_path: + args.model = model_path + + if model_path and _is_qwen35_text_checkpoint(model_path): + args.language_model_only = True + args.skip_mm_profiling = True + args.limit_mm_per_prompt = dict(DEFAULT_MM_LIMITS) + if getattr(args, "generation_config", None) in (None, "auto"): + args.generation_config = "vllm" + + print( + "Auto-configured qwen3_5_text serving:", + f"language_model_only={args.language_model_only},", + f"limit_mm_per_prompt={args.limit_mm_per_prompt},", + f"generation_config={args.generation_config}", + flush=True, + ) + + if visible_ids is not None: + print( + "Patched vLLM NVML enumeration to visible physical GPU ids:", + ",".join(str(idx) for idx in visible_ids), + flush=True, + ) + print( + "Enabled NVML LD_PRELOAD shim for NCCL-visible physical GPU remapping.", + flush=True, + ) + + validate_parsed_serve_args(args) + uvloop.run(run_server(args)) + + +if __name__ == "__main__": + main() diff --git a/scripts/sitecustomize.py b/scripts/sitecustomize.py new file mode 100644 index 000000000..6d13e7892 --- /dev/null +++ b/scripts/sitecustomize.py @@ -0,0 +1,177 @@ +"""Optional Python startup hooks used by local repo scripts. + +This module is loaded automatically when `scripts/` is on `PYTHONPATH`. +It currently patches two vLLM environment issues seen in this repo: + +* startup-time NVML enumeration can touch a broken physical GPU unless we + remap it to `CUDA_VISIBLE_DEVICES` +* `--language-model-only` still initializes multimodal renderers for some + text-only Qwen 3.5 checkpoints, which triggers a config-family mismatch +""" + +from __future__ import annotations + +from dataclasses import replace +import importlib +import math +import os +import sys + + +def _parse_visible_ids(raw: str | None) -> list[int]: + if not raw: + return [] + + visible_ids: list[int] = [] + for item in raw.split(","): + item = item.strip() + if not item: + continue + if not item.isdigit(): + return [] + visible_ids.append(int(item)) + return visible_ids + + +def _patch_vllm_nvml() -> None: + visible_ids = _parse_visible_ids(os.environ.get("GPTQMODEL_VLLM_VISIBLE_PHYSICAL_IDS")) + if not visible_ids: + return + + try: + pynvml = importlib.import_module("vllm.third_party.pynvml") + except Exception: + return + + original_get_handle = pynvml.nvmlDeviceGetHandleByIndex + + def mapped_get_count(): + return len(visible_ids) + + def mapped_get_handle(device_id: int): + if 0 <= device_id < len(visible_ids): + device_id = visible_ids[device_id] + return original_get_handle(device_id) + + pynvml.nvmlDeviceGetCount = mapped_get_count + pynvml.nvmlDeviceGetHandleByIndex = mapped_get_handle + + +def _patch_vllm_language_model_only_renderer() -> None: + try: + base = importlib.import_module("vllm.renderers.base") + except Exception: + return + + original_init = getattr(base.BaseRenderer, "__init__", None) + if original_init is None or getattr(original_init, "_gptqmodel_language_model_only_patch", False): + return + + def patched_init(self, config, tokenizer): + model_config = getattr(config, "model_config", None) + multimodal_config = getattr(model_config, "multimodal_config", None) + if multimodal_config is not None and getattr(multimodal_config, "language_model_only", False): + model_config.multimodal_config = None + try: + return original_init(self, config, tokenizer) + finally: + model_config.multimodal_config = multimodal_config + + return original_init(self, config, tokenizer) + + patched_init._gptqmodel_language_model_only_patch = True + base.BaseRenderer.__init__ = patched_init + + +def _register_vllm_qwen35_text_arches() -> None: + try: + registry = importlib.import_module("vllm.model_executor.models.registry") + except Exception: + return + + model_registry = getattr(registry, "ModelRegistry", None) + if model_registry is None: + return + + supported_arches = set(model_registry.get_supported_archs()) + lazy_targets = { + "Qwen3_5ForCausalLM": "vllm_qwen35_shim:Qwen3_5ForCausalLM", + "Qwen3_5MoeForCausalLM": "vllm_qwen35_shim:Qwen3_5MoeForCausalLM", + } + for arch, target in lazy_targets.items(): + if arch not in supported_arches: + model_registry.register_model(arch, target) + + +def _patch_vllm_kv_page_size_unifier() -> None: + try: + kv_cache_utils = importlib.import_module("vllm.v1.core.kv_cache_utils") + except Exception: + return + + original_unifier = getattr(kv_cache_utils, "unify_kv_cache_spec_page_size", None) + if original_unifier is None or getattr(original_unifier, "_gptqmodel_lcm_padding_patch", False): + return + + debug_pages = os.environ.get("GPTQMODEL_VLLM_DEBUG_KV_PAGES") == "1" + allow_lcm_padding = os.environ.get("GPTQMODEL_VLLM_ALLOW_KV_PAGE_LCM_PADDING") == "1" + + def _describe_specs(kv_cache_spec: dict[str, object]) -> str: + items: list[str] = [] + for layer_name, layer_spec in sorted(kv_cache_spec.items()): + items.append( + f"{layer_name}:{type(layer_spec).__name__}:" + f"block={getattr(layer_spec, 'block_size', '?')}:" + f"page={getattr(layer_spec, 'page_size_bytes', '?')}" + ) + return "; ".join(items) + + def patched_unifier(kv_cache_spec): + try: + return original_unifier(kv_cache_spec) + except NotImplementedError: + page_sizes = {layer.page_size_bytes for layer in kv_cache_spec.values()} + if debug_pages: + sys.stderr.write( + "GPTQModel local vLLM patch: non-divisible KV page sizes: " + f"{sorted(page_sizes)}\n{_describe_specs(kv_cache_spec)}\n" + ) + sys.stderr.flush() + + if not allow_lcm_padding or not page_sizes: + raise + + target_page_size = math.lcm(*page_sizes) + max_page_size = max(page_sizes) + if target_page_size > max_page_size * 8: + raise + + new_kv_cache_spec = {} + for layer_name, layer_spec in kv_cache_spec.items(): + if layer_spec.page_size_bytes == target_page_size: + new_kv_cache_spec[layer_name] = layer_spec + continue + + if not hasattr(layer_spec, "page_size_padded"): + raise + + new_kv_cache_spec[layer_name] = replace( + layer_spec, + page_size_padded=target_page_size, + ) + + sys.stderr.write( + "GPTQModel local vLLM patch: padded KV page sizes to " + f"common multiple {target_page_size} bytes.\n" + ) + sys.stderr.flush() + return new_kv_cache_spec + + patched_unifier._gptqmodel_lcm_padding_patch = True + kv_cache_utils.unify_kv_cache_spec_page_size = patched_unifier + + +_patch_vllm_nvml() +_patch_vllm_language_model_only_renderer() +_register_vllm_qwen35_text_arches() +_patch_vllm_kv_page_size_unifier() diff --git a/scripts/sync_cuda_toolkit_with_torch.sh b/scripts/sync_cuda_toolkit_with_torch.sh new file mode 100755 index 000000000..bee0b9a7d --- /dev/null +++ b/scripts/sync_cuda_toolkit_with_torch.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +set -euo pipefail + +# Switch CUDA toolkit to match the CUDA version PyTorch was built with on Debian/Ubuntu. + +if [[ "$(uname -s)" != "Linux" ]]; then + echo "auto_switch_cuda_toolkit.sh: This script only supports Linux." >&2 + exit 1 +fi + +if [[ ! -f /etc/debian_version ]]; then + echo "auto_switch_cuda_toolkit.sh: This script only supports Debian/Ubuntu systems." >&2 + exit 1 +fi + +if ! command -v update-alternatives >/dev/null 2>&1; then + echo "auto_switch_cuda_toolkit.sh: 'update-alternatives' command not found." >&2 + exit 1 +fi + +if [[ $EUID -ne 0 ]]; then + if command -v sudo >/dev/null 2>&1; then + SUDO="sudo" + else + echo "auto_switch_cuda_toolkit.sh: Run as root or install sudo." >&2 + exit 1 + fi +else + SUDO="" +fi + +python_cmd=$(command -v python3 || true) +if [[ -z "${python_cmd}" ]]; then + python_cmd=$(command -v python || true) +fi +if [[ -z "${python_cmd}" ]]; then + echo "auto_switch_cuda_toolkit.sh: Python interpreter not found." >&2 + exit 1 +fi + +torch_cuda_version=$("${python_cmd}" - <<'PY' +import sys +try: + import torch +except Exception as exc: # pragma: no cover - runtime check + print(f"Failed to import torch: {exc}", file=sys.stderr) + sys.exit(1) + +cuda_version = torch.version.cuda +if not cuda_version: + print("Torch is not compiled with CUDA support.", file=sys.stderr) + sys.exit(1) + +parts = cuda_version.split('.') +if len(parts) >= 2: + normalized = f"{parts[0]}.{parts[1]}" +else: + normalized = cuda_version + +print(normalized) +PY +) + +if [[ -z "${torch_cuda_version}" ]]; then + echo "auto_switch_cuda_toolkit.sh: Unable to determine torch CUDA version." >&2 + exit 1 +fi + +target_version=${torch_cuda_version//$'\r'/$''} +target_version=${target_version//$'\n'/$''} + +echo "Detected torch CUDA version: ${target_version}" + +config_output=$({ printf '\n'; } | ${SUDO} update-alternatives --config cuda 2>&1 || true) + +selection=$(CONFIG_OUTPUT="${config_output}" "${python_cmd}" - "${target_version}" <<'PY' +import os +import pcre +import sys + +target = sys.argv[1] +data = os.environ.get("CONFIG_OUTPUT", "") +lines = data.splitlines() +cuda_version_pattern = pcre.compile(r"cuda-([0-9.]+)") + +candidates = [] +for line in lines: + if "manual mode" not in line or "/cuda-" not in line: + continue + stripped = line.lstrip("*").strip() + if not stripped: + continue + parts = stripped.split() + if len(parts) < 2: + continue + sel = parts[0] + if not sel.isdigit(): + continue + path = parts[1] + match = cuda_version_pattern.search(path) + if not match: + continue + version = match.group(1) + candidates.append((sel, version, path)) + +if not candidates: + print("No CUDA alternatives found in update-alternatives output.", file=sys.stderr) + sys.exit(1) + +split = lambda text: text.split('.') +target_parts = split(target) + +chosen = None +for sel, version, _ in candidates: + if version == target: + chosen = sel + break + +if chosen is None and len(target_parts) >= 2: + for sel, version, _ in candidates: + parts = split(version) + if len(parts) >= 2 and parts[0] == target_parts[0] and parts[1] == target_parts[1]: + chosen = sel + break + +if chosen is None and target_parts: + for sel, version, _ in candidates: + parts = split(version) + if parts and parts[0] == target_parts[0]: + chosen = sel + break + +if chosen is None: + available = ", ".join(sorted({ver for _, ver, _ in candidates})) + print( + f"Could not find CUDA alternative matching torch CUDA {target}. Available versions: {available}", + file=sys.stderr, + ) + sys.exit(1) + +print(chosen) +PY +) + +if [[ -z "${selection}" ]]; then + echo "auto_switch_cuda_toolkit.sh: Failed to identify matching CUDA alternative." >&2 + exit 1 +fi + +echo "Selecting CUDA alternative entry: ${selection}" +printf '%s\n' "${selection}" | ${SUDO} update-alternatives --config cuda + +current_value=$(${SUDO} update-alternatives --query cuda 2>/dev/null | awk '/^Value:/ {print $2; exit}') +if [[ -n "${current_value}" ]]; then + echo "CUDA toolkit is now set to: ${current_value}" +else + echo "CUDA toolkit update complete." +fi diff --git a/scripts/vllm_qwen35_shim.py b/scripts/vllm_qwen35_shim.py new file mode 100644 index 000000000..291c08ea7 --- /dev/null +++ b/scripts/vllm_qwen35_shim.py @@ -0,0 +1,82 @@ +"""Local vLLM shims for qwen3_5_text checkpoints. + +These wrappers keep the text-only causal-LM path while restoring the hybrid +model marker vLLM expects for Qwen3.5's mixed full-attention / linear-attention +cache configuration. +""" + +from __future__ import annotations + +import torch + +from vllm.model_executor.models.interfaces import IsHybrid +from vllm.model_executor.models.interfaces import SupportsMRoPE +from vllm.model_executor.models.qwen3_5 import ( + Qwen3_5ForCausalLM as _Qwen3_5ForCausalLM, +) +from vllm.model_executor.models.qwen3_5 import ( + Qwen3_5ForConditionalGeneration as _Qwen3_5ForConditionalGeneration, +) +from vllm.model_executor.models.qwen3_5 import ( + Qwen3_5MoeForCausalLM as _Qwen3_5MoeForCausalLM, +) + + +class Qwen3_5ForCausalLM(_Qwen3_5ForCausalLM, IsHybrid, SupportsMRoPE): + is_hybrid = True + supports_mrope = True + + @classmethod + def get_mamba_state_dtype_from_config(cls, vllm_config): + return _Qwen3_5ForConditionalGeneration.get_mamba_state_dtype_from_config( + vllm_config + ) + + @classmethod + def get_mamba_state_shape_from_config(cls, vllm_config): + return _Qwen3_5ForConditionalGeneration.get_mamba_state_shape_from_config( + vllm_config + ) + + @classmethod + def get_mamba_state_copy_func(cls): + return _Qwen3_5ForConditionalGeneration.get_mamba_state_copy_func() + + def get_mrope_input_positions(self, input_tokens, mm_features): + if mm_features: + raise NotImplementedError( + "Text-only Qwen3.5 shim only supports empty multimodal features." + ) + + positions = torch.arange(len(input_tokens), dtype=torch.int64) + return positions.unsqueeze(0).expand(3, -1), 0 + + +class Qwen3_5MoeForCausalLM(_Qwen3_5MoeForCausalLM, IsHybrid, SupportsMRoPE): + is_hybrid = True + supports_mrope = True + + @classmethod + def get_mamba_state_dtype_from_config(cls, vllm_config): + return _Qwen3_5ForConditionalGeneration.get_mamba_state_dtype_from_config( + vllm_config + ) + + @classmethod + def get_mamba_state_shape_from_config(cls, vllm_config): + return _Qwen3_5ForConditionalGeneration.get_mamba_state_shape_from_config( + vllm_config + ) + + @classmethod + def get_mamba_state_copy_func(cls): + return _Qwen3_5ForConditionalGeneration.get_mamba_state_copy_func() + + def get_mrope_input_positions(self, input_tokens, mm_features): + if mm_features: + raise NotImplementedError( + "Text-only Qwen3.5 shim only supports empty multimodal features." + ) + + positions = torch.arange(len(input_tokens), dtype=torch.int64) + return positions.unsqueeze(0).expand(3, -1), 0 diff --git a/setup.py b/setup.py index e04437c71..ba7263da7 100644 --- a/setup.py +++ b/setup.py @@ -2,963 +2,26 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import os -import re -import subprocess -import sys -import tarfile + from pathlib import Path -from shutil import rmtree from setuptools import find_namespace_packages, find_packages, setup -from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel - - -CUTLASS_VERSION = "3.5.0" -CUTLASS_RELEASE_URL = f"https://github.com/NVIDIA/cutlass/archive/refs/tags/v{CUTLASS_VERSION}.tar.gz" - - -def _ensure_cutlass_source() -> Path: - deps_dir = Path("build") / "_deps" - deps_dir.mkdir(parents=True, exist_ok=True) - - cutlass_root = deps_dir / f"cutlass-v{CUTLASS_VERSION}" - marker = cutlass_root / ".gptqmodel_complete" - if marker.exists(): - return cutlass_root.resolve() - - archive_path = deps_dir / f"cutlass-v{CUTLASS_VERSION}.tar.gz" - if not archive_path.exists(): - _download_with_progress( - CUTLASS_RELEASE_URL, - str(archive_path), - title=f"Downloading CUTLASS v{CUTLASS_VERSION}", - ) - - if cutlass_root.exists(): - rmtree(cutlass_root) - - with tarfile.open(archive_path, "r:gz") as tar: - extract_kwargs = {"path": deps_dir} - if sys.version_info >= (3, 12): - extract_kwargs["filter"] = "data" - tar.extractall(**extract_kwargs) - - extracted_dir = deps_dir / f"cutlass-{CUTLASS_VERSION}" - if not extracted_dir.exists(): - raise RuntimeError("Failed to extract CUTLASS archive") - - extracted_dir.rename(cutlass_root) - marker.touch() - return cutlass_root.resolve() - - -# --------------------------- -# Helpers (no torch required) -# --------------------------- - -def _read_env(name, default=None): - v = os.environ.get(name) - return v if (v is not None and str(v).strip() != "") else default - - -def _probe_cmd(args): - try: - out = subprocess.check_output(args, stderr=subprocess.STDOUT, text=True, timeout=5) - return out.strip() - except Exception: - return None - - -def _bool_env(name, default=False): - v = _read_env(name) - if v is None: - return default - return str(v).lower() in ("1", "true", "yes", "y", "on") - - -def _detect_rocm_version(): - v = _read_env("ROCM_VERSION") - if v: - return v - hip = _probe_cmd(["hipcc", "--version"]) - if hip: - import re - m = re.search(r"\b([0-9]+\.[0-9]+)\b", hip) - if m: - return m.group(1) - try: - p = Path("/opt/rocm/.info/version") - if p.exists(): - return p.read_text(encoding="utf-8").strip() - except Exception: - pass - return None - - -def _detect_cuda_arch_list(): - """Return TORCH_CUDA_ARCH_LIST style string for the *installed* GPUs only. - Priority: - 1) CUDA_ARCH_LIST env override (verbatim) - 2) nvidia-smi compute_cap (actual devices) - """ - # 1) explicit override - env_arch = _read_env("CUDA_ARCH_LIST") - if env_arch: - return env_arch - - # 2) actual devices present - smi_out = _probe_cmd(["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"]) - if smi_out: - caps = [] - for line in smi_out.splitlines(): - cap = line.strip() - if not cap: - continue - # normalize like '8.0' - try: - major, minor = cap.split(".", 1) - caps.append(f"{int(major)}.{int(minor)}") - except Exception: - # some drivers return just '8' -> treat as '8.0' - if cap.isdigit(): - caps.append(f"{cap}.0") - caps = sorted(set(caps), key=lambda x: (int(x.split(".")[0]), int(x.split(".")[1]))) - if caps: - # PyTorch prefers ';' separators - return ";".join(caps) - - # 3) conservative default for modern datacenter GPUs (A100 et al.) - raise Exception("Could not get compute capability from nvidia-smi. Please check nvidia-utils package is installed.") - - -def _parse_arch_list(s: str): - # Accept semicolons, commas, and any whitespace as separators. - # Keep tokens like "8.0", "8.0+PTX" intact (we’ll strip suffixes later). - return [tok for tok in re.split(r"[;\s,]+", s) if tok.strip()] - - -def _has_cuda_v8_from_arch_list(arch_list): - try: - vals = [] - for a in arch_list: - # Handle things like "8.0+PTX" - base = a.split("+", 1)[0] - vals.append(float(base)) - return any(v >= 8.0 for v in vals) - except Exception: - return False - - -def _detect_cxx11_abi(): - v = _read_env("CXX11_ABI") - if v in ("0", "1"): - return int(v) - return 1 - - -def _torch_version_for_release(): - # No torch import; allow env override - v = _read_env("TORCH_VERSION") - if v: - parts = v.split(".") - return ".".join(parts[:2]) - else: - raise Exception("TORCH_VERSION not passed for wheel generation.") - return None - - -def _is_rocm_available(): - return _detect_rocm_version() is not None - - -# If you already have _probe_cmd elsewhere, you can delete this copy. -def _probe_cmd(args, timeout=6): - try: - return subprocess.check_output(args, stderr=subprocess.STDOUT, text=True, timeout=timeout) - except Exception: - return "" - - -def _first_token_line(s: str) -> str | None: - for line in (s or "").splitlines(): - t = line.strip() - if t: - return t - return None - - -def _detect_torch_version() -> str | None: - # 1) uv pip show torch - out = _probe_cmd(["uv", "pip", "show", "torch"]) - if out: - m = re.search(r"^Version:\s*([^\s]+)\s*$", out, flags=re.MULTILINE) - if m: - return m.group(1) - - # 2) pip show torch (both 'pip' and 'python -m pip') - for cmd in (["pip", "show", "torch"], [sys.executable, "-m", "pip", "show", "torch"]): - out = _probe_cmd(cmd) - if out: - m = re.search(r"^Version:\s*([^\s]+)\s*$", out, flags=re.MULTILINE) - if m: - return m.group(1) - - # 3) conda list torch - out = _probe_cmd(["conda", "list", "torch"]) - if out: - # Typical line starts with: torch 2.4.1 ... - for line in out.splitlines(): - if line.strip().startswith("torch"): - parts = re.split(r"\s+", line.strip()) - if len(parts) >= 2 and re.match(r"^\d+\.\d+(\.\d+)?", parts[1]): - return parts[1] - - # 4) Fallback: importlib.metadata (does not import torch package module) - try: - import importlib.metadata as im # py3.8+ - version = im.version("torch") - if not version: - raise Exception("torch not found") - except Exception: - raise Exception("Unable to detect torch version via uv/pip/conda/importlib. Please install torch >= 2.7.1") - - -def _detect_torch_accelerator_backend() -> str | None: - """Return accelerator backend for installed torch build. - - Returns: - - "cuda" when torch is built with CUDA support - - "rocm" when torch is built with ROCm support - - "cpu" when torch is CPU-only - - None when torch cannot be imported - """ - try: - import torch # type: ignore - except Exception: - return None - - version_info = getattr(torch, "version", None) - if version_info is None: - return None - - if getattr(version_info, "hip", None): - return "rocm" - if getattr(version_info, "cuda", None): - return "cuda" - return "cpu" - - -def _major_minor(v: str) -> str: - if v: - parts = v.split(".") - return ".".join(parts[:2]) if parts else v - return v - - -def _version_geq(version: str | None, major: int, minor: int = 0) -> bool: - if not version: - return False - try: - parts = re.split(r"[._-]", version) - ver_major = int(parts[0]) if parts else 0 - ver_minor = int(parts[1]) if len(parts) > 1 else 0 - return (ver_major, ver_minor) >= (major, minor) - except Exception: - return False - - -def _nvcc_release_version() -> str | None: - # Search for nvcc in common locations before giving up. - candidates: list[str] = [] - nvcc_env = _read_env("NVCC") - if nvcc_env: - candidates.append(nvcc_env) - - cuda_home = _read_env("CUDA_HOME") - cuda_path = _read_env("CUDA_PATH") - - candidates.extend( - [ - "nvcc", - str(Path(cuda_home).joinpath("bin", "nvcc")) if cuda_home else None, - str(Path(cuda_path).joinpath("bin", "nvcc")) if cuda_path else None, - "/usr/local/cuda/bin/nvcc", - ] - ) - - seen = set() - for cmd in candidates: - if not cmd or cmd in seen: - continue - seen.add(cmd) - out = _probe_cmd([cmd, "--version"]) - if not out: - continue - match = re.search(r"release\s+(\d+)\.(\d+)", out) - if match: - return f"{match.group(1)}.{match.group(2)}" - - print( - "NVCC not found (checked PATH, $CUDA_HOME/bin, $CUDA_PATH/bin, /usr/local/cuda/bin). " - "For Ubuntu, run `sudo update-alternatives --config cuda` to fix path for already installed Cuda." - ) - return None - - -def _detect_cuda_version() -> str | None: - # Priority: env → nvidia-smi → nvcc - v = os.environ.get("CUDA_VERSION") - if v and v.strip(): - return v.strip() - - # nvcc --version (parse 'release X.Y') - return _nvcc_release_version() - - -def _detect_nvcc_version() -> str | None: - return _nvcc_release_version() - - -def get_version_tag() -> str: - # TODO FIX ME: cpu wheels don't have torch version tags? - if BUILD_CUDA_EXT != "1": - return "cpu" - - # TODO FIX ME: rocm wheels don't have torch version tags? - if ROCM_VERSION: - return f"rocm{ROCM_VERSION}" - - if not CUDA_VERSION: - raise Exception("Trying to compile GPTQModel for CUDA/ROCm, but no cuda or rocm version was detected.") - - torch_suffix = f"torch{_major_minor(TORCH_VERSION)}" - - CUDA_VERSION_COMPACT = "".join(CUDA_VERSION.split(".")) - base = f"cu{CUDA_VERSION_COMPACT[:3]}" - return f"{base}{torch_suffix}" - - -# --------------------------- -# Env and versioning -# --------------------------- - -TORCH_VERSION = _read_env("TORCH_VERSION") -RELEASE_MODE = _read_env("RELEASE_MODE") -CUDA_VERSION = _read_env("CUDA_VERSION") -ROCM_VERSION = _read_env("ROCM_VERSION") -TORCH_CUDA_ARCH_LIST = _read_env("TORCH_CUDA_ARCH_LIST") -NVCC_VERSION = _read_env("NVCC_VERSION") -TORCH_ACCELERATOR_BACKEND = _read_env("TORCH_ACCELERATOR_BACKEND") - -# respect user env then detect -if not TORCH_VERSION: - TORCH_VERSION = _detect_torch_version() -if not CUDA_VERSION: - CUDA_VERSION = _detect_cuda_version() -if not ROCM_VERSION: - ROCM_VERSION = _detect_rocm_version() -if not NVCC_VERSION: - NVCC_VERSION = _detect_nvcc_version() -if not TORCH_ACCELERATOR_BACKEND: - TORCH_ACCELERATOR_BACKEND = _detect_torch_accelerator_backend() - -SKIP_ROCM_VERSION_CHECK = _read_env("SKIP_ROCM_VERSION_CHECK") -FORCE_BUILD = _bool_env("GPTQMODEL_FORCE_BUILD", False) - -# BUILD_CUDA_EXT: -# - If user sets explicitly, respect it. -# - Otherwise auto: enable only when torch backend and toolkit match. -BUILD_CUDA_EXT = _read_env("BUILD_CUDA_EXT") -if BUILD_CUDA_EXT is None: - if TORCH_ACCELERATOR_BACKEND == "cuda": - BUILD_CUDA_EXT = "1" if CUDA_VERSION else "0" - elif TORCH_ACCELERATOR_BACKEND == "rocm": - BUILD_CUDA_EXT = "1" if ROCM_VERSION else "0" - elif TORCH_ACCELERATOR_BACKEND == "cpu": - BUILD_CUDA_EXT = "0" - else: - BUILD_CUDA_EXT = "1" if (CUDA_VERSION or ROCM_VERSION) else "0" - -if ROCM_VERSION and not SKIP_ROCM_VERSION_CHECK: - try: - if float(ROCM_VERSION) < 6.2: - sys.exit( - "GPTQModel's compatibility with ROCm < 6.2 has not been verified. " - "Set SKIP_ROCM_VERSION_CHECK=1 to proceed." - ) - except Exception: - pass - -# Handle CUDA_ARCH_LIST (public) and set TORCH_CUDA_ARCH_LIST for build toolchains -CUDA_ARCH_LIST = _detect_cuda_arch_list() if (BUILD_CUDA_EXT == "1" and not ROCM_VERSION) else None - -if not TORCH_CUDA_ARCH_LIST and CUDA_ARCH_LIST: - archs = _parse_arch_list(CUDA_ARCH_LIST) - kept = [] - for arch in archs: - try: - base = arch.split("+", 1)[0] - if float(base) >= 6.0: - kept.append(arch) - else: - print(f"we do not support this compute arch: {arch}, skipped.") - except Exception: - kept.append(arch) - - # Use semicolons for TORCH_CUDA_ARCH_LIST (PyTorch likes this), - TORCH_CUDA_ARCH_LIST = ";".join(kept) - os.environ["TORCH_CUDA_ARCH_LIST"] = TORCH_CUDA_ARCH_LIST - - print(f"CUDA_ARCH_LIST: {CUDA_ARCH_LIST}") - print(f"TORCH_CUDA_ARCH_LIST: {TORCH_CUDA_ARCH_LIST}") - -version_vars = {} -exec("exec(open('gptqmodel/version.py').read()); version=__version__", {}, version_vars) -gptqmodel_version = version_vars["version"] - -# ----------------------------- -# Prebuilt wheel download config -# ----------------------------- -# Default template (GitHub Releases), can be overridden via env. -DEFAULT_WHEEL_URL_TEMPLATE = "https://github.com/ModelCloud/GPTQModel/releases/download/{tag_name}/{wheel_name}" -WHEEL_URL_TEMPLATE = os.environ.get("GPTQMODEL_WHEEL_URL_TEMPLATE") -WHEEL_BASE_URL = os.environ.get("GPTQMODEL_WHEEL_BASE_URL") -WHEEL_TAG = os.environ.get("GPTQMODEL_WHEEL_TAG") # Optional override of release tag - - -def _resolve_wheel_url(tag_name: str, wheel_name: str) -> str: - """ - Build the final wheel URL based on: - 1) GPTQMODEL_WHEEL_URL_TEMPLATE (highest priority) - 2) GPTQMODEL_WHEEL_BASE_URL (append /{wheel_name}) - 3) DEFAULT_WHEEL_URL_TEMPLATE (GitHub Releases) - """ - # Highest priority: explicit template - if WHEEL_URL_TEMPLATE: - tmpl = WHEEL_URL_TEMPLATE - # If {wheel_name} or {tag_name} not present, treat as base and append name. - if ("{wheel_name}" in tmpl) or ("{tag_name}" in tmpl): - return tmpl.format(tag_name=tag_name, wheel_name=wheel_name) - # Otherwise, join as base - if tmpl.endswith("/"): - return tmpl + wheel_name - return tmpl + "/" + wheel_name - - # Next priority: base URL - if WHEEL_BASE_URL: - base = WHEEL_BASE_URL - if base.endswith("/"): - return base + wheel_name - return base + "/" + wheel_name - - # Fallback: default GitHub template - return DEFAULT_WHEEL_URL_TEMPLATE.format(tag_name=tag_name, wheel_name=wheel_name) - - -def _download_with_progress(url: str, dest_path: str, title: str = "Downloading") -> None: - """Download url to dest_path with simple stdout progress updates.""" - import time - import urllib.request as req - - start_time = time.time() - last_draw_time = 0.0 - last_print_percent = -1 - - def _format_bytes(num_bytes: float) -> str: - units = ["B", "KiB", "MiB", "GiB", "TiB"] - value = float(max(num_bytes, 0.0)) - for unit in units: - if value < 1024.0 or unit == units[-1]: - return f"{value:0.1f}{unit}" if unit != "B" else f"{int(value)}B" - value /= 1024.0 - return f"{value:0.1f}TiB" - - def _reporthook(block_num: int, block_size: int, total_size: int) -> None: - nonlocal last_draw_time, last_print_percent - now = time.time() - downloaded = block_num * block_size - speed = downloaded / max(now - start_time, 1e-6) - - if total_size and total_size > 0: - percent = min(int(downloaded * 100 / total_size), 100) - if percent == last_print_percent and percent != 100: - return - subtitle = ( - f"{percent:3d}% ({_format_bytes(downloaded)}/{_format_bytes(total_size)}) " - f"{_format_bytes(speed)}/s" - ) - print(f"{title} {subtitle}", flush=True) - last_print_percent = percent - last_draw_time = now - else: - if (now - last_draw_time) < 1.0: - return - subtitle = f"{_format_bytes(downloaded)} {_format_bytes(speed)}/s" - print(f"{title} {subtitle}", flush=True) - last_draw_time = now - - req.urlretrieve(url, dest_path, reporthook=_reporthook) - - -# Decide HAS_CUDA_V8 / HAS_CUDA_V9 without torch -HAS_CUDA_V8 = False -HAS_CUDA_V9 = False -if CUDA_ARCH_LIST: - arch_list = _parse_arch_list(CUDA_ARCH_LIST) - try: - caps = [float(tok.split("+", 1)[0]) for tok in arch_list] - except Exception: - caps = [] - if not ROCM_VERSION: - HAS_CUDA_V8 = any(cap >= 8.0 for cap in caps) - HAS_CUDA_V9 = any(cap >= 9.0 for cap in caps) -else: - smi = _probe_cmd(["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"]) - if smi: - try: - caps = [float(x.strip()) for x in smi.splitlines() if x.strip()] - HAS_CUDA_V8 = any(cap >= 8.0 for cap in caps) - HAS_CUDA_V9 = any(cap >= 9.0 for cap in caps) - except Exception: - HAS_CUDA_V8 = False - HAS_CUDA_V9 = False - -if RELEASE_MODE == "1": - gptqmodel_version = f"{gptqmodel_version}+{get_version_tag()}" - -include_dirs = ["gptqmodel_ext"] - -extensions = [] -additional_setup_kwargs = {} - - -# --------------------------- -# Build CUDA/ROCm extensions (only when enabled) -# --------------------------- -# ----------------------------- -# Per-extension build toggles -# ----------------------------- -def _env_enabled(val: str) -> bool: - if val is None: - return True - return str(val).strip().lower() not in ("0", "false", "off", "no") - - -def _env_enabled_any(names, default="1") -> bool: - for n in names: - if n in os.environ: - return _env_enabled(os.environ.get(n)) - return _env_enabled(default) - - -BUILD_MARLIN = _env_enabled_any(os.environ.get("GPTQMODEL_BUILD_MARLIN", "1")) -BUILD_MACHETE = _env_enabled(os.environ.get("GPTQMODEL_BUILD_MACHETE", "0")) -BUILD_EXLLAMA_V2 = _env_enabled(os.environ.get("GPTQMODEL_BUILD_EXLLAMA_V2", "1")) -BUILD_QQQ = _env_enabled(os.environ.get("GPTQMODEL_BUILD_QQQ", "1")) -BUILD_AWQ = _env_enabled(os.environ.get("GPTQMODEL_BUILD_AWQ", "1")) - -# Optional kernels and not build by default. Enable compile with env flags -BUILD_EORA = _env_enabled(os.environ.get("GPTQMODEL_BUILD_EORA", "0")) -BUILD_EXLLAMA_V1 = _env_enabled(os.environ.get("GPTQMODEL_BUILD_EXLLAMA_V1", "0")) - -if BUILD_CUDA_EXT == "1": - # Import torch's cpp_extension only if we're truly building GPU extensions - try: - - from torch.utils import cpp_extension as cpp_ext # type: ignore - except Exception: - if FORCE_BUILD: - sys.exit( - "FORCE_BUILD is set but PyTorch C++ extension headers are unavailable. " - "Install torch build deps first (see https://pytorch.org/) or unset GPTQMODEL_FORCE_BUILD." - ) - # If we can't import cpp_extension, fall back to prebuilt wheel path - cpp_ext = None - - if cpp_ext is not None: - # Limit compile parallelism to avoid overwhelming nvcc/cicc invocations. - # Respect pre-set MAX_JOBS, otherwise fall back to CPU count minus two (min 1). - cpu_count = os.cpu_count() or 1 - default_max_jobs = max(1, cpu_count - 2) - max_jobs_raw = os.environ.get("MAX_JOBS") - if max_jobs_raw is None or max_jobs_raw.strip() == "": - effective_max_jobs = default_max_jobs - print(f"MAX_JOBS not set; defaulting to {effective_max_jobs} concurrent CUDA compilations.") - else: - try: - parsed_jobs = int(max_jobs_raw) - except ValueError: - effective_max_jobs = default_max_jobs - print(f"Ignoring invalid MAX_JOBS={max_jobs_raw!r}; using {effective_max_jobs}.") - else: - if parsed_jobs <= 0: - effective_max_jobs = default_max_jobs - print(f"MAX_JOBS={parsed_jobs} is non-positive; using {effective_max_jobs}.") - else: - effective_max_jobs = parsed_jobs - - os.environ["MAX_JOBS"] = str(effective_max_jobs) - os.environ["NINJA_NUM_JOBS"] = str(effective_max_jobs) - print(f"Using MAX_JOBS={effective_max_jobs} to cap concurrent CUDA compilations.") - - nvcc_threads = 2 - os.environ["NVCC_THREADS"] = str(nvcc_threads) - print(f"Using NVCC_THREADS={nvcc_threads} for per-invocation NVCC concurrency.") - - # Optional conda CUDA runtime headers - # conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include") - # if os.path.isdir(conda_cuda_include_dir): - # include_dirs.append(conda_cuda_include_dir) - # print(f"appending conda cuda include dir {conda_cuda_include_dir}") - - extra_link_args = [] - extra_compile_args = { - "cxx": ["-O3", "-std=c++17", "-DENABLE_BF16"], - "nvcc": [ - "-O3", - "-std=c++17", - "-DENABLE_BF16", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT16_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "-U__CUDA_NO_BFLOAT162_OPERATORS__", - "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", - ], - } - - cutlass_include_paths = [] - if BUILD_MACHETE: - cutlass_root = _ensure_cutlass_source() - cutlass_include_paths = [ - Path("gptqmodel_ext/cutlass_extensions").resolve(), - cutlass_root / "include", - cutlass_root / "examples/common/include", - cutlass_root / "tools/library/include", - ] - if "GPTQMODEL_CUTLASS_DIR" not in os.environ: - os.environ["GPTQMODEL_CUTLASS_DIR"] = str(cutlass_root) - cutlass_include_flags = [f"-I{path}" for path in cutlass_include_paths] - extra_compile_args["cxx"] += cutlass_include_flags - extra_compile_args["nvcc"] += cutlass_include_flags - - # Windows/OpenMP note: adjust flags as needed for MSVC if you add native Windows wheels - if sys.platform == "win32": - extra_compile_args["cxx"] = ["/O2", "/std:c++17", "/openmp", "/DNDEBUG", "/DENABLE_BF16"] - - CXX11_ABI = _detect_cxx11_abi() - extra_compile_args["cxx"] += [f"-D_GLIBCXX_USE_CXX11_ABI={CXX11_ABI}"] - extra_compile_args["nvcc"] += [f"-D_GLIBCXX_USE_CXX11_ABI={CXX11_ABI}"] - - if not ROCM_VERSION: - # if _version_geq(NVCC_VERSION, 13, 0): - # extra_compile_args["nvcc"].append("--device-entity-has-hidden-visibility=false") - nvcc_extra_flags = [ - "--threads", str(nvcc_threads), # NVCC parallelism - "--optimize=3", # alias for -O3 - # "-rdc=true", # enable relocatable device code, required for future cuda > 13.x <-- TODO FIX ME broken loading - # "-dlto", # compile and link <-- TODO FIX ME - # Print register/shared-memory usage per kernel (debug aid, no perf effect) - # Ensure PTXAS uses maximum optimization - # Cache global loads in both L1 and L2 (better for memory-bound kernels) - "-Xptxas", "-v,-O3,-dlcm=ca", - "-lineinfo", # keep source line info for profiling - # "--resource-usage", # show per-kernel register/SMEM usage - "-Xfatbin", "-compress-all", # compress fatbin - # "--expt-relaxed-constexpr", # relaxed constexpr rules <-- not used - # "--expt-extended-lambda", # allow device lambdas <-- not used - "-diag-suppress=179,39,177", # silence some template warnings - ] - if _version_geq(NVCC_VERSION, 12, 8): - # Allow instantiations of __global__ templates to live in different TUs; only supported in newer NVCC. - nvcc_extra_flags.insert(0, "-static-global-template-stub=false") - extra_compile_args["nvcc"] += nvcc_extra_flags - else: - # hipify CUDA-like flags - def _hipify_compile_flags(flags): - modified_flags = [] - for flag in flags: - if flag.startswith("-") and "CUDA" in flag and not flag.startswith("-I"): - parts = flag.split("=", 1) - if len(parts) == 2: - flag_part, value_part = parts - modified_flag_part = flag_part.replace("CUDA", "HIP", 1) - modified_flags.append(f"{modified_flag_part}={value_part}") - else: - modified_flags.append(flag.replace("CUDA", "HIP", 1)) - else: - modified_flags.append(flag) - return modified_flags - - - extra_compile_args["nvcc"] = _hipify_compile_flags(extra_compile_args["nvcc"]) - - # Extensions (gate marlin/qqq/eora/exllamav2 on CUDA sm_80+ and non-ROCm) - if sys.platform != "win32": - if not ROCM_VERSION and HAS_CUDA_V8: - if BUILD_MARLIN: - marlin_kernel_dir = Path("gptqmodel_ext/marlin") - marlin_kernel_files = sorted(marlin_kernel_dir.glob("kernel_*.cu")) - - if not marlin_kernel_files: - generator_script = marlin_kernel_dir / "generate_kernels.py" - if generator_script.exists(): - print("Regenerating marlin template instantiations for parallel compilation...") - subprocess.check_call([sys.executable, str(generator_script)]) - marlin_kernel_files = sorted(marlin_kernel_dir.glob("kernel_*.cu")) - - if not marlin_kernel_files: - raise RuntimeError( - "No generated marlin kernel templates detected. Run generate_kernels.py before building." - ) - - marlin_template_kernel_srcs = [str(path) for path in marlin_kernel_files] - extensions += [ - cpp_ext.CUDAExtension( - "gptqmodel_marlin_kernels", - [ - "gptqmodel_ext/marlin/marlin_cuda.cpp", - "gptqmodel_ext/marlin/gptq_marlin.cu", - "gptqmodel_ext/marlin/gptq_marlin_repack.cu", - "gptqmodel_ext/marlin/awq_marlin_repack.cu", - ] + marlin_template_kernel_srcs, - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ) - ] - - if BUILD_MACHETE and HAS_CUDA_V9 and _version_geq(NVCC_VERSION, 12, 0): - try: - result = subprocess.run( - [sys.executable, "gptqmodel_ext/machete/generate.py"], - check=True, - text=True, - capture_output=True - ) - except subprocess.CalledProcessError as e: - raise RuntimeError( - f"Error generating machete kernel templates:\n" - f"Return code: {e.returncode}\n" - f"Stderr: {e.stderr}\n" - f"Stdout: {e.stdout}" - ) - machete_dir = Path("gptqmodel_ext/machete") - machete_generated_dir = machete_dir / "generated" - - machete_sources = [str(machete_dir / "machete_pytorch.cu")] - machete_generated_sources = sorted(machete_generated_dir.glob("*.cu")) - - if not machete_generated_sources: - raise RuntimeError( - "No generated machete kernel templates detected. Run gptqmodel_ext/machete/generate.py" - " with CUTLASS checkout before building." - ) - - machete_sources += [str(path) for path in machete_generated_sources] - - machete_include_dirs = [str(Path("gptqmodel_ext").resolve())] + [str(path) for path in cutlass_include_paths] - - extensions += [ - cpp_ext.CUDAExtension( - "gptqmodel_machete_kernels", - machete_sources, - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - include_dirs=machete_include_dirs, - ) - ] - - if BUILD_QQQ: - extensions += [ - cpp_ext.CUDAExtension( - "gptqmodel_qqq_kernels", - [ - "gptqmodel_ext/qqq/qqq.cpp", - "gptqmodel_ext/qqq/qqq_gemm.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ) - ] - - if BUILD_EORA: - extensions += [ - cpp_ext.CUDAExtension( - "gptqmodel_exllama_eora", - [ - "gptqmodel_ext/exllama_eora/eora/q_gemm.cu", - "gptqmodel_ext/exllama_eora/eora/pybind.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ) - ] - if BUILD_EXLLAMA_V2: - extensions += [ - cpp_ext.CUDAExtension( - "gptqmodel_exllamav2_kernels", - [ - "gptqmodel_ext/exllamav2/ext_gptq.cpp", - "gptqmodel_ext/exllamav2/cuda/q_matrix.cu", - "gptqmodel_ext/exllamav2/cuda/q_gemm.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ) - ] - - # both CUDA and ROCm compatible - if BUILD_EXLLAMA_V1: - extensions += [ - cpp_ext.CUDAExtension( - "gptqmodel_exllama_kernels", - [ - "gptqmodel_ext/exllama/exllama_ext.cpp", - "gptqmodel_ext/exllama/cuda_buffers.cu", - "gptqmodel_ext/exllama/cuda_func/column_remap.cu", - "gptqmodel_ext/exllama/cuda_func/q4_matmul.cu", - "gptqmodel_ext/exllama/cuda_func/q4_matrix.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ) - ] - - if BUILD_AWQ: - if ROCM_VERSION: - print("Skipping AWQ kernels on ROCm: inline PTX is CUDA-only.") - else: - extensions += [ - # contain un-hipifiable inline PTX - cpp_ext.CUDAExtension( - "gptqmodel_awq_kernels", - [ - "gptqmodel_ext/awq/pybind_awq.cpp", - "gptqmodel_ext/awq/quantization/gemm_cuda_gen.cu", - "gptqmodel_ext/awq/quantization/gemv_cuda.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ), - # TODO only compatible with ampere? - # arch_flags = get_compute_capabilities({80, 86, 89, 90}) - # extra_compile_args_v2 = get_extra_compile_args(arch_flags, generator_flags) - cpp_ext.CUDAExtension( - "gptqmodel_awq_v2_kernels", - [ - "gptqmodel_ext/awq/pybind_awq_v2.cpp", - "gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu", - "gptqmodel_ext/awq/quantization_new/gemm/gemm_cuda.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ), - cpp_ext.CUDAExtension( - "gptqmodel_exllamav2_awq_kernels", - [ - "gptqmodel_ext/exllamav2/ext_awq.cpp", - "gptqmodel_ext/exllamav2/cuda/q_matrix_awq.cu", - "gptqmodel_ext/exllamav2/cuda/q_gemm_awq.cu", - ], - extra_link_args=extra_link_args, - extra_compile_args=extra_compile_args, - ) - ] - - # Ensure machete kernels are compiled before other extensions - machete_exts = [ext for ext in extensions if getattr(ext, "name", "") == "gptqmodel_machete_kernels"] - if machete_exts: - other_exts = [ext for ext in extensions if getattr(ext, "name", "") != "gptqmodel_machete_kernels"] - extensions[:] = machete_exts + other_exts - - additional_setup_kwargs = { - "ext_modules": extensions, - "cmdclass": {"build_ext": cpp_ext.BuildExtension}, - } - - # additional_setup_kwargs = { - # "ext_modules": extensions, - # # "include_package_data": True, - # # "package_data": {"": ["build/lib/*.so"]}, - # "cmdclass": {"build_ext": cpp_ext.BuildExtension.with_options( - # use_ninja=True, - # no_python_abi_suffix=True, - # build_temp="build/temp", - # # build_lib="build/lib", TODO FIX ME why package_data doesn't work.. - # clean_first=False # keep intermediates for reuse - # )}, - # } - - -# --------------------------- -# Cached wheel fetcher -# --------------------------- - -class CachedWheelsCommand(_bdist_wheel): - def run(self): - # No implicit torch checks; allow explicit override via env - xpu_avail = _bool_env("XPU_AVAILABLE", False) - if FORCE_BUILD or xpu_avail: - return super().run() - - python_version = f"cp{sys.version_info.major}{sys.version_info.minor}{sys.abiflags}" - - wheel_filename = f"gptqmodel-{gptqmodel_version}+{get_version_tag()}-{python_version}-{python_version}-linux_x86_64.whl" - - # Allow tag override via env; default to "v{gptqmodel_version}" - tag_name = WHEEL_TAG if WHEEL_TAG else f"v{gptqmodel_version}" - wheel_url = _resolve_wheel_url(tag_name=tag_name, wheel_name=wheel_filename) - - print(f"Resolved wheel URL: {wheel_url}\nwheel name={wheel_filename}") - - try: - if not os.path.exists(self.dist_dir): - os.makedirs(self.dist_dir) - - wheel_path = os.path.join(self.dist_dir, wheel_filename) - _download_with_progress(wheel_url, wheel_path, title="Downloading wheel") - print("Raw wheel path", wheel_filename) - except BaseException: - env_info = [f"python={python_version}", f"torch={TORCH_VERSION or 'unknown'}"] - if CUDA_VERSION: - env_info.append(f"cuda={CUDA_VERSION}") - if ROCM_VERSION: - env_info.append(f"rocm={ROCM_VERSION}") - print( - "Unable to match and download a precompiled wheel; entering slow manual build mode. " - f"Wheel match params: {', '.join(env_info)}. " - f"Fallback source build triggered for {wheel_url}" - ) - super().run() +def _package_version() -> str: + version_vars: dict[str, str] = {} + exec(Path("gptqmodel/version.py").read_text(encoding="utf-8"), {}, version_vars) + return version_vars["__version__"] -# --------------------------- -# setup() -# --------------------------- -print(f"CUDA {CUDA_ARCH_LIST}") -print(f"HAS_CUDA_V8 {HAS_CUDA_V8}") -print(f"HAS_CUDA_V9 {HAS_CUDA_V9}") -print(f"SETUP_KWARGS {additional_setup_kwargs}") -print(f"gptqmodel_version={gptqmodel_version}") +packages = find_packages(exclude=("tests", "tests.*")) +for package_name in find_namespace_packages(include=("gptqmodel_ext.*",)): + if package_name not in packages: + packages.append(package_name) -_namespace_packages = find_namespace_packages(include=["gptqmodel_ext.*"]) -_packages = find_packages() -for _pkg in _namespace_packages: - if _pkg not in _packages: - _packages.append(_pkg) setup( - version=gptqmodel_version, - packages=_packages, + version=_package_version(), + packages=packages, include_package_data=True, - include_dirs=include_dirs, - cmdclass=( - {"bdist_wheel": CachedWheelsCommand, "build_ext": additional_setup_kwargs.get("cmdclass", {}).get("build_ext")} - if (BUILD_CUDA_EXT == "1" and additional_setup_kwargs) - else {"bdist_wheel": CachedWheelsCommand} - ), - ext_modules=additional_setup_kwargs.get("ext_modules", []), ) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..a359c887b --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package namespace used for local test helpers.""" diff --git a/tests/awq_test_utils.py b/tests/awq_test_utils.py new file mode 100644 index 000000000..9c83a86fe --- /dev/null +++ b/tests/awq_test_utils.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch +import atexit # noqa: E402 +import json # noqa: E402 +import logging # noqa: E402 +import shutil # noqa: E402 +import tempfile # noqa: E402 +import threading # noqa: E402 +from dataclasses import dataclass # noqa: E402 +from pathlib import Path # noqa: E402 + +import torch # noqa: E402 +from datasets import load_dataset # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.gemv_awq import AwqGEMVLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.machete_awq import AwqMacheteLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.marlin_awq import AwqMarlinLinear # noqa: E402 +from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME # noqa: E402 + + +PROMPT = "The capital city of France is named" +PRETRAINED_MODEL_ID = "/monster/data/model/Llama-3.2-1B" +# Historical local alternative used during AWQ bring-up: +# "/monster/data/model/Qwen2.5-0.5B-Instruct/" +CALIBRATION_DATASET_PATH = "/monster/data/model/dataset/c4-train.00000-of-01024.json.gz" +AWQ_GROUP_SIZE = 128 + + +@dataclass(frozen=True) +class QuantizedAwqArtifact: + model_path: str + quantize_config_dict: dict + + +_TOKENIZER = None +_CALIBRATION_DATASETS: dict[int, object] = {} +_ARTIFACTS: dict[tuple[object, int, int], QuantizedAwqArtifact] = {} +_ARTIFACT_ROOTS: list[str] = [] +_ARTIFACT_LOCK = threading.Lock() + + +def _cleanup_artifacts() -> None: + for artifact_root in reversed(_ARTIFACT_ROOTS): + shutil.rmtree(artifact_root, ignore_errors=True) + + +atexit.register(_cleanup_artifacts) + + +def awq_sample_count() -> int: + requested_samples = os.getenv("GPTQMODEL_AWQ_CALIB_SAMPLES") + if requested_samples is not None: + return max(1, int(requested_samples)) + + if torch.cuda.is_available(): + try: + torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory / (1024 ** 3) + except Exception: + pass + + # if total_mem_gb >= 80: + # sample_count = 1024 + # elif total_mem_gb >= 48: + # sample_count = 512 + # else: + # sample_count = 192 + return 512 + + +def get_awq_tokenizer(): + global _TOKENIZER + if _TOKENIZER is None: + _TOKENIZER = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_ID, use_fast=True) + return _TOKENIZER + + +def get_awq_calibration_dataset(): + sample_count = awq_sample_count() + dataset = _CALIBRATION_DATASETS.get(sample_count) + if dataset is None: + traindata = load_dataset("json", data_files=CALIBRATION_DATASET_PATH, split="train") + dataset = traindata.select(range(sample_count)) + _CALIBRATION_DATASETS[sample_count] = dataset + return dataset + + +def _quantized_artifact_key(checkpoint_format: FORMAT, group_size: int) -> tuple[object, int, int]: + return checkpoint_format, group_size, awq_sample_count() + + +def get_quantized_awq_artifact(checkpoint_format: FORMAT, group_size: int = AWQ_GROUP_SIZE) -> QuantizedAwqArtifact: + artifact_key = _quantized_artifact_key(checkpoint_format, group_size) + cached_artifact = _ARTIFACTS.get(artifact_key) + if cached_artifact is not None: + return cached_artifact + + with _ARTIFACT_LOCK: + cached_artifact = _ARTIFACTS.get(artifact_key) + if cached_artifact is not None: + return cached_artifact + + quantize_config = QuantizeConfig( + bits=4, + group_size=group_size, + quant_method=METHOD.AWQ, + format=checkpoint_format, + ) + + model = GPTQModel.load( + PRETRAINED_MODEL_ID, + quantize_config=quantize_config, + ) + model.quantize(get_awq_calibration_dataset(), batch_size=1, calibration_concat_size=0) + + format_name = getattr(checkpoint_format, "value", str(checkpoint_format)).lower() + artifact_root = tempfile.mkdtemp(prefix=f"awq_{format_name}_") + model.save(artifact_root) + + with open(Path(artifact_root) / QUANT_CONFIG_FILENAME, "r", encoding="utf-8") as config_file: + file_dict = json.loads(config_file.read()) + assert model.quantize_config.to_dict() == file_dict + # Exclude `offload_to_disk_path`, which is a random value. + file_dict["meta"].pop("offload_to_disk_path") + logging.info("Saved config file: %s", file_dict) + + del model + + _ARTIFACT_ROOTS.append(artifact_root) + cached_artifact = QuantizedAwqArtifact( + model_path=artifact_root, + quantize_config_dict=file_dict, + ) + _ARTIFACTS[artifact_key] = cached_artifact + return cached_artifact + + +def assert_awq_linear_backend(model, backend: BACKEND) -> None: + if backend == BACKEND.GEMM: + linear_cls = AwqGEMMLinear + elif backend == BACKEND.MACHETE: + linear_cls = AwqMacheteLinear + elif backend == BACKEND.MARLIN: + linear_cls = AwqMarlinLinear + elif backend == BACKEND.GEMV: + linear_cls = AwqGEMVLinear + elif backend == BACKEND.GEMV_FAST: + linear_cls = AwqGEMVFastLinear + else: + raise ValueError(f"unknown backend: {backend}") + + assert any(isinstance(module, linear_cls) for _, module in model.named_modules()) + + +def assert_loaded_quantize_config_matches(model, expected_config: dict) -> None: + actual_quantize_config = model.quantize_config.to_dict() + actual_quantize_config["meta"].pop("offload_to_disk_path") + assert actual_quantize_config == expected_config + + +def assert_generation_mentions_paris_or_city(result: str, *, extra_terms: tuple[str, ...] = ()) -> None: + accepted_terms = {"paris", "city", *extra_terms} + if not any(term in result.lower() for term in accepted_terms): + raise AssertionError(f"expected one of {sorted(accepted_terms)} in generation: {result}") + + +def run_quantized_awq_generation_test(checkpoint_format: FORMAT, backend: BACKEND, *, group_size: int = AWQ_GROUP_SIZE): + artifact = get_quantized_awq_artifact(checkpoint_format, group_size) + model = GPTQModel.load( + artifact.model_path, + backend=backend, + device="cuda", + ) + + assert_loaded_quantize_config_matches(model, artifact.quantize_config_dict) + assert_awq_linear_backend(model, backend) + + result = ModelTest.generate_stable_with_limit( + model, + get_awq_tokenizer(), + PROMPT, + max_new_tokens=100, + ) + print(f"BACKEND: {backend}, Result: {result}") + assert_generation_mentions_paris_or_city(result) + + del model + + +def run_inference_only_generation_test( + model_id: str, + *, + backend: BACKEND, + max_new_tokens: int, + extra_terms: tuple[str, ...] = (), +) -> None: + model = GPTQModel.load( + model_id, + backend=backend, + device="cuda", + ) + + result = ModelTest.generate_stable_with_limit( + model, + model.tokenizer, + PROMPT, + max_new_tokens=max_new_tokens, + ) + print(f"BACKEND: {backend}, Result: {result}") + assert_generation_mentions_paris_or_city(result, extra_terms=extra_terms) + + del model diff --git a/tests/benchmark/benchmark_torch.py b/tests/benchmark/benchmark_torch.py index 789cd8b2f..441eb32e6 100755 --- a/tests/benchmark/benchmark_torch.py +++ b/tests/benchmark/benchmark_torch.py @@ -71,7 +71,7 @@ def _load_baseline_and_current(repo_root: Path): importlib.import_module("gptqmodel.nn_modules.qlinear") current_mod = importlib.import_module("gptqmodel.nn_modules.qlinear.torch") - current_cls = current_mod.TorchQuantLinear + current_cls = current_mod.TorchLinear baseline_src = subprocess.check_output( ["git", "show", "HEAD:gptqmodel/nn_modules/qlinear/torch.py"], @@ -84,7 +84,7 @@ def _load_baseline_and_current(repo_root: Path): baseline_mod.__package__ = "gptqmodel.nn_modules.qlinear" sys.modules[baseline_name] = baseline_mod exec(compile(baseline_src, baseline_mod.__file__, "exec"), baseline_mod.__dict__) - baseline_cls = baseline_mod.TorchQuantLinear + baseline_cls = baseline_mod.TorchLinear return baseline_cls, current_cls @@ -391,7 +391,7 @@ def mk_row(cols: list[str]) -> str: def _parse_args() -> argparse.Namespace: - p = argparse.ArgumentParser(description="A/B benchmark for TorchQuantLinear dequant path.") + p = argparse.ArgumentParser(description="A/B benchmark for TorchLinear dequant path.") p.add_argument("--worker", type=str, default=None, help=argparse.SUPPRESS) p.add_argument( "--only", diff --git a/tests/benchmark/benchmark_torch_aten_vs_onednn.py b/tests/benchmark/benchmark_torch_aten_vs_onednn.py new file mode 100644 index 000000000..641828035 --- /dev/null +++ b/tests/benchmark/benchmark_torch_aten_vs_onednn.py @@ -0,0 +1,370 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import math +import os +import statistics +import subprocess +import sys +import time +from pathlib import Path + +import torch +from safetensors import safe_open +from tabulate import tabulate + + +DEFAULT_MODEL_DIR = Path("/root/GLM-4.6") +DEFAULT_GROUP_SIZE = 128 +DEFAULT_TOKENS = 1 +DEFAULT_CHUNK_ROWS = 2048 +DEFAULT_WARMUP = 3 +DEFAULT_ITERS = 10 + + +def _positive_divisor(value: int, name: str) -> int: + if value <= 0: + raise ValueError(f"{name} must be positive, got {value}") + return value + + +def _rank_tensor(name: str) -> tuple[int, int]: + # Prefer linear heads when the size ties with embeddings. + if name == "lm_head.weight": + return (3, 0) + if "lm_head" in name: + return (2, 0) + if name.endswith(".weight"): + return (1, 0) + return (0, 0) + + +def find_largest_2d_tensor(model_dir: Path) -> tuple[Path, str, tuple[int, int]]: + best: tuple[int, tuple[int, int], tuple[int, int], str, Path] | None = None + files = sorted(model_dir.glob("model-*.safetensors")) + mtp = model_dir / "mtp.safetensors" + if mtp.exists(): + files.append(mtp) + if not files: + raise FileNotFoundError(f"no safetensors shards found under {model_dir}") + + for shard in files: + with safe_open(str(shard), framework="pt", device="cpu") as handle: + for name in handle.keys(): + shape = tuple(handle.get_slice(name).get_shape()) + if len(shape) != 2: + continue + rows, cols = int(shape[0]), int(shape[1]) + numel = rows * cols + rank = _rank_tensor(name) + candidate = (numel, rank, (rows, cols), name, shard) + if best is None or candidate > best: + best = candidate + + if best is None: + raise RuntimeError(f"no 2D tensors found under {model_dir}") + + _, _, shape, name, shard = best + return shard, name, shape + + +def quantize_activation_per_tensor_symmetric(x_fp32: torch.Tensor) -> tuple[torch.Tensor, float]: + scale = max(float(x_fp32.abs().max().item()) / 127.0, 1e-8) + qx = torch.clamp(torch.round(x_fp32 / scale), -128, 127).to(torch.int8) + return qx, scale + + +def bench_ms(fn, warmup: int, iters: int) -> tuple[float, float, list[float]]: + with torch.inference_mode(): + for _ in range(warmup): + fn() + + samples: list[float] = [] + with torch.inference_mode(): + for _ in range(iters): + start = time.perf_counter() + fn() + end = time.perf_counter() + samples.append((end - start) * 1e3) + + return statistics.mean(samples), statistics.median(samples), samples + + +def gops(tokens: int, in_features: int, out_features: int, ms: float) -> float: + ops = 2.0 * tokens * in_features * out_features + return ops / (ms * 1e6) + + +def run_onednn_verbose_probe(tokens: int, in_features: int, out_features: int) -> str | None: + code = f""" +import torch +M = {tokens} +I = {in_features} +O = {out_features} +qx = torch.randint(-128, 128, (M, I), dtype=torch.int8) +qw = torch.randint(-128, 128, (O, I), dtype=torch.int8) +ws = torch.ones((O,), dtype=torch.float32) +wzp = torch.zeros((O,), dtype=torch.int32) +packed = torch.ops.onednn.qlinear_prepack(qw, [M, I]) +torch.ops.onednn.qlinear_pointwise( + qx, 1.0, 0, packed, ws, wzp, None, 1.0, 0, torch.bfloat16, "none", [], "" +) +""" + env = os.environ.copy() + env["DNNL_VERBOSE"] = "1" + completed = subprocess.run( + [sys.executable, "-c", code], + check=True, + capture_output=True, + text=True, + env=env, + ) + lines = [ + line + for line in completed.stdout.splitlines() + if "onednn_verbose" in line and ",primitive,exec,cpu,matmul," in line + ] + return lines[-1] if lines else None + + +def build_reference_and_packed_kernels( + shard: Path, + tensor_name: str, + shape: tuple[int, int], + tokens: int, + group_size: int, + chunk_rows: int, + seed: int, +) -> dict[str, object]: + out_features, in_features = shape + if in_features % group_size != 0: + raise ValueError( + f"in_features={in_features} must be divisible by group_size={group_size}" + ) + if chunk_rows % 16 != 0: + raise ValueError(f"chunk_rows={chunk_rows} must be divisible by 16") + if out_features % 16 != 0: + raise ValueError(f"out_features={out_features} must be divisible by 16") + + torch.manual_seed(seed) + x_bf16 = torch.randn(tokens, in_features, dtype=torch.bfloat16) + x_fp32 = x_bf16.float() + qx_int8, x_scale = quantize_activation_per_tensor_symmetric(x_fp32) + + groups = in_features // group_size + int4_weight = torch.empty((out_features, in_features // 2), dtype=torch.uint8) + scales_and_zeros = torch.zeros((groups, out_features, 2), dtype=torch.bfloat16) + int8_weight = torch.empty((out_features, in_features), dtype=torch.int8) + int8_weight_scales = torch.empty((out_features,), dtype=torch.float32) + int8_weight_zero_points = torch.zeros((out_features,), dtype=torch.int32) + reference = torch.empty((tokens, out_features), dtype=torch.float32) + + with safe_open(str(shard), framework="pt", device="cpu") as handle: + weight_slice = handle.get_slice(tensor_name) + for start in range(0, out_features, chunk_rows): + end = min(start + chunk_rows, out_features) + rows = end - start + weight_chunk = weight_slice[start:end, :].to(torch.float32).contiguous() + reference[:, start:end] = x_fp32 @ weight_chunk.t() + + int8_scales = torch.maximum( + weight_chunk.abs().amax(dim=1), + torch.full((rows,), 1e-8, dtype=torch.float32), + ) / 127.0 + int8_codes = torch.clamp( + torch.round(weight_chunk / int8_scales.unsqueeze(1)), + -128, + 127, + ).to(torch.int8) + int8_weight[start:end] = int8_codes + int8_weight_scales[start:end] = int8_scales + + grouped = weight_chunk.view(rows, groups, group_size) + int4_scales = torch.maximum( + grouped.abs().amax(dim=2), + torch.full((rows, groups), 1e-8, dtype=torch.float32), + ) / 7.0 + int4_signed = torch.clamp( + torch.round(grouped / int4_scales.unsqueeze(-1)), + -8, + 7, + ).to(torch.int8) + int4_codes = (int4_signed + 8).to(torch.int32).view(rows, in_features).contiguous() + int4_weight[start:end] = torch.ops.aten._convert_weight_to_int4pack_for_cpu( + int4_codes, 1 + ).contiguous() + scales_and_zeros[:, start:end, 0] = int4_scales.transpose(0, 1).to(torch.bfloat16) + + onednn_weight = torch.ops.onednn.qlinear_prepack(int8_weight, [tokens, in_features]) + del int8_weight + + return { + "x_bf16": x_bf16, + "qx_int8": qx_int8, + "x_scale": x_scale, + "reference": reference, + "int4_weight": int4_weight, + "scales_and_zeros": scales_and_zeros, + "onednn_weight": onednn_weight, + "int8_weight_scales": int8_weight_scales, + "int8_weight_zero_points": int8_weight_zero_points, + } + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Benchmark TorchAten GPTQ int4 vs oneDNN qlinear on CPU" + ) + parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR) + parser.add_argument("--group-size", type=int, default=DEFAULT_GROUP_SIZE) + parser.add_argument("--tokens", type=int, default=DEFAULT_TOKENS) + parser.add_argument("--chunk-rows", type=int, default=DEFAULT_CHUNK_ROWS) + parser.add_argument("--warmup", type=int, default=DEFAULT_WARMUP) + parser.add_argument("--iters", type=int, default=DEFAULT_ITERS) + parser.add_argument("--seed", type=int, default=1234) + args = parser.parse_args() + + tokens = _positive_divisor(args.tokens, "tokens") + warmup = _positive_divisor(args.warmup, "warmup") + iters = _positive_divisor(args.iters, "iters") + group_size = _positive_divisor(args.group_size, "group_size") + chunk_rows = _positive_divisor(args.chunk_rows, "chunk_rows") + + torch.set_num_interop_threads(1) + + shard, tensor_name, shape = find_largest_2d_tensor(args.model_dir) + out_features, in_features = shape + build_start = time.perf_counter() + state = build_reference_and_packed_kernels( + shard=shard, + tensor_name=tensor_name, + shape=shape, + tokens=tokens, + group_size=group_size, + chunk_rows=chunk_rows, + seed=args.seed, + ) + build_ms = (time.perf_counter() - build_start) * 1e3 + + x_bf16 = state["x_bf16"] + qx_int8 = state["qx_int8"] + x_scale = state["x_scale"] + reference = state["reference"] + int4_weight = state["int4_weight"] + scales_and_zeros = state["scales_and_zeros"] + onednn_weight = state["onednn_weight"] + int8_weight_scales = state["int8_weight_scales"] + int8_weight_zero_points = state["int8_weight_zero_points"] + + def run_torch_aten() -> torch.Tensor: + return torch.ops.aten._weight_int4pack_mm_for_cpu( + x_bf16, int4_weight, group_size, scales_and_zeros + ) + + def run_onednn() -> torch.Tensor: + return torch.ops.onednn.qlinear_pointwise( + qx_int8, + float(x_scale), + 0, + onednn_weight, + int8_weight_scales, + int8_weight_zero_points, + None, + 1.0, + 0, + torch.bfloat16, + "none", + [], + "", + ) + + with torch.inference_mode(): + out_aten = run_torch_aten().float() + out_onednn = run_onednn().float() + + aten_mean_ms, aten_median_ms, _ = bench_ms(run_torch_aten, warmup=warmup, iters=iters) + onednn_mean_ms, onednn_median_ms, _ = bench_ms(run_onednn, warmup=warmup, iters=iters) + + onednn_verbose_line = run_onednn_verbose_probe(tokens, in_features, out_features) + + config_rows = [ + ["model_dir", str(args.model_dir)], + ["tensor", tensor_name], + ["shard", shard.name], + ["shape", f"{out_features} x {in_features}"], + ["tokens", tokens], + ["group_size", group_size], + ["chunk_rows", chunk_rows], + ["threads", torch.get_num_threads()], + ["interop_threads", torch.get_num_interop_threads()], + ["build_ms", f"{build_ms:.2f}"], + ] + + results_rows = [ + [ + "TorchAten GPTQ", + "aten._weight_int4pack_mm_for_cpu", + "w4 / a16", + "bf16", + f"{aten_mean_ms:.3f}", + f"{aten_median_ms:.3f}", + f"{gops(tokens, in_features, out_features, aten_mean_ms):.2f}", + "1.000x", + f"{(out_aten - reference).abs().max().item():.6f}", + f"{(out_aten - reference).abs().mean().item():.6f}", + f"{math.sqrt(torch.mean((out_aten - reference) ** 2).item()):.6f}", + ], + [ + "oneDNN qlinear", + "onednn.qlinear_pointwise", + "w8 / a8", + "bf16", + f"{onednn_mean_ms:.3f}", + f"{onednn_median_ms:.3f}", + f"{gops(tokens, in_features, out_features, onednn_mean_ms):.2f}", + f"{aten_mean_ms / onednn_mean_ms:.3f}x", + f"{(out_onednn - reference).abs().max().item():.6f}", + f"{(out_onednn - reference).abs().mean().item():.6f}", + f"{math.sqrt(torch.mean((out_onednn - reference) ** 2).item()):.6f}", + ], + ] + + print("Configuration") + print(tabulate(config_rows, headers=["field", "value"], tablefmt="grid")) + print() + print("Results") + print( + tabulate( + results_rows, + headers=[ + "backend", + "kernel", + "quant", + "out", + "mean ms", + "median ms", + "effective GOPS", + "vs ATen", + "max|diff|", + "mean|diff|", + "rmse", + ], + tablefmt="grid", + ) + ) + print() + print("oneDNN Verbose") + if onednn_verbose_line is None: + print("No oneDNN matmul verbose line captured.") + else: + print(onednn_verbose_line) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/benchmark/benchmark_torch_int8_vs_onednn.py b/tests/benchmark/benchmark_torch_int8_vs_onednn.py new file mode 100644 index 000000000..e2a38d5d6 --- /dev/null +++ b/tests/benchmark/benchmark_torch_int8_vs_onednn.py @@ -0,0 +1,428 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: 2024-2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import argparse +import math +import os +import statistics +import subprocess +import sys +import time +from pathlib import Path + +import torch +from safetensors import safe_open +from tabulate import tabulate + + +# Keep this CPU benchmark isolated from optional GPU backend import side effects. +os.environ.setdefault("GPTQMODEL_DISABLE_BITBLAS", "1") + +from gptqmodel.nn_modules.qlinear.torch_int8 import Int8PackedModule + + +DEFAULT_MODEL_DIR = Path("/root/GLM-4.6") +DEFAULT_BATCHES = (1, 2, 4, 8, 16, 32, 64, 128) +DEFAULT_CHUNK_ROWS = 2048 +DEFAULT_WARMUP = 3 +DEFAULT_ITERS = 10 +DEFAULT_OUTPUT_DTYPE = torch.bfloat16 + + +def _positive_divisor(value: int, name: str) -> int: + if value <= 0: + raise ValueError(f"{name} must be positive, got {value}") + return value + + +def _parse_batches(raw: str) -> list[int]: + batches = [_positive_divisor(int(item.strip()), "batch") for item in raw.split(",") if item.strip()] + if not batches: + raise ValueError("at least one batch is required") + return batches + + +def _rank_tensor(name: str) -> tuple[int, int]: + # Prefer linear heads when the size ties with embeddings. + if name == "lm_head.weight": + return (3, 0) + if "lm_head" in name: + return (2, 0) + if name.endswith(".weight"): + return (1, 0) + return (0, 0) + + +def find_largest_2d_tensor(model_dir: Path) -> tuple[Path, str, tuple[int, int]]: + best: tuple[int, tuple[int, int], tuple[int, int], str, Path] | None = None + files = sorted(model_dir.glob("model-*.safetensors")) + mtp = model_dir / "mtp.safetensors" + if mtp.exists(): + files.append(mtp) + if not files: + raise FileNotFoundError(f"no safetensors shards found under {model_dir}") + + for shard in files: + with safe_open(str(shard), framework="pt", device="cpu") as handle: + for name in handle.keys(): + shape = tuple(handle.get_slice(name).get_shape()) + if len(shape) != 2: + continue + rows, cols = int(shape[0]), int(shape[1]) + numel = rows * cols + rank = _rank_tensor(name) + candidate = (numel, rank, (rows, cols), name, shard) + if best is None or candidate > best: + best = candidate + + if best is None: + raise RuntimeError(f"no 2D tensors found under {model_dir}") + + _, _, shape, name, shard = best + return shard, name, shape + + +def quantize_activation_per_tensor_symmetric(x_fp32: torch.Tensor) -> tuple[torch.Tensor, float]: + scale = max(float(x_fp32.abs().max().item()) / 127.0, 1e-8) + qx = torch.clamp(torch.round(x_fp32 / scale), -128, 127).to(torch.int8) + return qx, scale + + +def bench_ms(fn, warmup: int, iters: int) -> tuple[float, float, list[float]]: + with torch.inference_mode(): + for _ in range(warmup): + fn() + + samples: list[float] = [] + with torch.inference_mode(): + for _ in range(iters): + start = time.perf_counter() + fn() + end = time.perf_counter() + samples.append((end - start) * 1e3) + + return statistics.mean(samples), statistics.median(samples), samples + + +def gops(tokens: int, in_features: int, out_features: int, ms: float) -> float: + ops = 2.0 * tokens * in_features * out_features + return ops / (ms * 1e6) + + +def capture_onednn_isa(tokens: int, in_features: int, out_features: int) -> str | None: + code = f""" +import torch +M = {tokens} +I = {in_features} +O = {out_features} +qx = torch.randint(-128, 128, (M, I), dtype=torch.int8) +qw = torch.randint(-128, 128, (O, I), dtype=torch.int8) +ws = torch.ones((O,), dtype=torch.float32) +wzp = torch.zeros((O,), dtype=torch.int32) +packed = torch.ops.onednn.qlinear_prepack(qw, [M, I]) +torch.ops.onednn.qlinear_pointwise( + qx, 1.0, 0, packed, ws, wzp, None, 1.0, 0, torch.bfloat16, "none", [], "" +) +""" + env = os.environ.copy() + env["DNNL_VERBOSE"] = "1" + completed = subprocess.run( + [sys.executable, "-c", code], + check=True, + capture_output=True, + text=True, + env=env, + ) + for line in completed.stdout.splitlines(): + if "onednn_verbose" not in line or ",primitive,exec,cpu,matmul," not in line: + continue + fields = line.split(",") + if len(fields) >= 7: + return fields[6] + return None + + +def build_int8_weight_state( + shard: Path, + tensor_name: str, + shape: tuple[int, int], + chunk_rows: int, +) -> dict[str, torch.Tensor]: + out_features, in_features = shape + int8_weight_nk = torch.empty((out_features, in_features), dtype=torch.int8) + int8_weight_scales_fp32 = torch.empty((out_features,), dtype=torch.float32) + + with safe_open(str(shard), framework="pt", device="cpu") as handle: + weight_slice = handle.get_slice(tensor_name) + for start in range(0, out_features, chunk_rows): + end = min(start + chunk_rows, out_features) + weight_chunk_nk = weight_slice[start:end, :].to(torch.float32).contiguous() + channel_scale = torch.maximum( + weight_chunk_nk.abs().amax(dim=1), + torch.full((end - start,), 1e-8, dtype=torch.float32), + ) / 127.0 + int8_codes_nk = torch.clamp( + torch.round(weight_chunk_nk / channel_scale.unsqueeze(1)), + -128, + 127, + ).to(torch.int8) + int8_weight_nk[start:end] = int8_codes_nk + int8_weight_scales_fp32[start:end] = channel_scale + + return { + "int8_weight_nk": int8_weight_nk.contiguous(), + "int8_weight_scales_fp32": int8_weight_scales_fp32.contiguous(), + "int8_weight_scales_bf16": int8_weight_scales_fp32.to(torch.bfloat16).contiguous(), + "int8_weight_zero_points": torch.zeros((out_features,), dtype=torch.int32), + } + + +def compute_reference( + shard: Path, + tensor_name: str, + shape: tuple[int, int], + x_fp32: torch.Tensor, + chunk_rows: int, +) -> torch.Tensor: + out_features, _ = shape + reference = torch.empty((x_fp32.shape[0], out_features), dtype=torch.float32) + + with safe_open(str(shard), framework="pt", device="cpu") as handle: + weight_slice = handle.get_slice(tensor_name) + for start in range(0, out_features, chunk_rows): + end = min(start + chunk_rows, out_features) + weight_chunk_nk = weight_slice[start:end, :].to(torch.float32).contiguous() + reference[:, start:end] = x_fp32 @ weight_chunk_nk.t() + + return reference + + +def main() -> int: + parser = argparse.ArgumentParser( + description="Benchmark repo TorchInt8 GPTQ kernel vs oneDNN qlinear on CPU" + ) + parser.add_argument("--model-dir", type=Path, default=DEFAULT_MODEL_DIR) + parser.add_argument( + "--batches", + type=str, + default=",".join(str(batch) for batch in DEFAULT_BATCHES), + help="comma-separated batch sizes, e.g. 1,2,4,8,16,32,64,128", + ) + parser.add_argument("--chunk-rows", type=int, default=DEFAULT_CHUNK_ROWS) + parser.add_argument("--warmup", type=int, default=DEFAULT_WARMUP) + parser.add_argument("--iters", type=int, default=DEFAULT_ITERS) + parser.add_argument("--seed", type=int, default=1234) + args = parser.parse_args() + + warmup = _positive_divisor(args.warmup, "warmup") + iters = _positive_divisor(args.iters, "iters") + chunk_rows = _positive_divisor(args.chunk_rows, "chunk_rows") + batches = _parse_batches(args.batches) + + if not hasattr(torch.ops.aten, "_weight_int8pack_mm"): + raise RuntimeError("aten::_weight_int8pack_mm is unavailable in this PyTorch build") + if not hasattr(torch.ops.onednn, "qlinear_prepack"): + raise RuntimeError("onednn::qlinear_prepack is unavailable in this PyTorch build") + if not hasattr(torch.ops.onednn, "qlinear_pointwise"): + raise RuntimeError("onednn::qlinear_pointwise is unavailable in this PyTorch build") + + torch.set_num_interop_threads(1) + + shard, tensor_name, shape = find_largest_2d_tensor(args.model_dir) + out_features, in_features = shape + + build_start = time.perf_counter() + weight_state = build_int8_weight_state( + shard=shard, + tensor_name=tensor_name, + shape=shape, + chunk_rows=chunk_rows, + ) + weight_build_ms = (time.perf_counter() - build_start) * 1e3 + + int8_weight_nk = weight_state["int8_weight_nk"] + int8_weight_scales_fp32 = weight_state["int8_weight_scales_fp32"] + int8_weight_scales_bf16 = weight_state["int8_weight_scales_bf16"] + int8_weight_zero_points = weight_state["int8_weight_zero_points"] + + torch_int8_module = Int8PackedModule(int8_weight_nk, int8_weight_scales_bf16).eval() + + config_rows = [ + ["model_dir", str(args.model_dir)], + ["tensor", tensor_name], + ["shard", shard.name], + ["shape", f"{out_features} x {in_features}"], + ["batches", ",".join(str(batch) for batch in batches)], + ["chunk_rows", chunk_rows], + ["threads", torch.get_num_threads()], + ["interop_threads", torch.get_num_interop_threads()], + ["weight_build_ms", f"{weight_build_ms:.2f}"], + ["output_dtype", str(DEFAULT_OUTPUT_DTYPE)], + ["note", "forward-only timings; oneDNN prepack measured separately"], + ["quant", "TorchInt8=w8/a16 bf16, oneDNN=w8/a8 bf16"], + ] + + perf_rows: list[list[str]] = [] + acc_rows: list[list[str]] = [] + isa_rows: list[list[str]] = [] + + for index, batch in enumerate(batches): + torch.manual_seed(args.seed + batch) + x_bf16 = torch.randn(batch, in_features, dtype=torch.bfloat16) + x_fp32 = x_bf16.float() + qx_int8, x_scale = quantize_activation_per_tensor_symmetric(x_fp32) + + ref_start = time.perf_counter() + reference = compute_reference( + shard=shard, + tensor_name=tensor_name, + shape=shape, + x_fp32=x_fp32, + chunk_rows=chunk_rows, + ) + reference_ms = (time.perf_counter() - ref_start) * 1e3 + + prepack_start = time.perf_counter() + onednn_weight = torch.ops.onednn.qlinear_prepack(int8_weight_nk, [batch, in_features]) + onednn_prepack_ms = (time.perf_counter() - prepack_start) * 1e3 + + def run_torch_int8() -> torch.Tensor: + return torch_int8_module(x_bf16) + + def run_onednn( + onednn_weight_packed: torch.Tensor = onednn_weight, + ) -> torch.Tensor: + return torch.ops.onednn.qlinear_pointwise( + qx_int8, + float(x_scale), + 0, + onednn_weight_packed, + int8_weight_scales_fp32, + int8_weight_zero_points, + None, + 1.0, + 0, + DEFAULT_OUTPUT_DTYPE, + "none", + [], + "", + ) + + with torch.inference_mode(): + out_torch_int8_raw = run_torch_int8() + out_onednn_raw = run_onednn() + + if out_torch_int8_raw.dtype != DEFAULT_OUTPUT_DTYPE: + raise RuntimeError( + f"expected TorchInt8 output dtype {DEFAULT_OUTPUT_DTYPE}, got {out_torch_int8_raw.dtype}" + ) + if out_onednn_raw.dtype != DEFAULT_OUTPUT_DTYPE: + raise RuntimeError( + f"expected oneDNN output dtype {DEFAULT_OUTPUT_DTYPE}, got {out_onednn_raw.dtype}" + ) + + out_torch_int8 = out_torch_int8_raw.float() + out_onednn = out_onednn_raw.float() + + torch_int8_mean_ms, torch_int8_median_ms, _ = bench_ms( + run_torch_int8, warmup=warmup, iters=iters + ) + onednn_mean_ms, onednn_median_ms, _ = bench_ms( + run_onednn, warmup=warmup, iters=iters + ) + + isa = None + if index in (0, len(batches) - 1): + isa = capture_onednn_isa(batch, in_features, out_features) + isa_rows.append([str(batch), isa or "unknown"]) + + torch_int8_diff = out_torch_int8 - reference + onednn_diff = out_onednn - reference + backend_delta = out_onednn - out_torch_int8 + + perf_rows.append( + [ + str(batch), + f"{torch_int8_mean_ms:.3f}", + f"{torch_int8_median_ms:.3f}", + f"{gops(batch, in_features, out_features, torch_int8_mean_ms):.2f}", + f"{onednn_mean_ms:.3f}", + f"{onednn_median_ms:.3f}", + f"{gops(batch, in_features, out_features, onednn_mean_ms):.2f}", + "TorchInt8" if torch_int8_mean_ms <= onednn_mean_ms else "oneDNN", + f"{max(torch_int8_mean_ms, onednn_mean_ms) / min(torch_int8_mean_ms, onednn_mean_ms):.3f}x", + f"{onednn_prepack_ms:.3f}", + f"{reference_ms:.2f}", + ] + ) + + acc_rows.append( + [ + str(batch), + f"{torch_int8_diff.abs().max().item():.6f}", + f"{torch_int8_diff.abs().mean().item():.6f}", + f"{math.sqrt(torch.mean(torch_int8_diff ** 2).item()):.6f}", + f"{onednn_diff.abs().max().item():.6f}", + f"{onednn_diff.abs().mean().item():.6f}", + f"{math.sqrt(torch.mean(onednn_diff ** 2).item()):.6f}", + f"{backend_delta.abs().max().item():.6f}", + f"{backend_delta.abs().mean().item():.6f}", + ] + ) + + del onednn_weight, reference, out_torch_int8, out_onednn + + print("Configuration") + print(tabulate(config_rows, headers=["field", "value"], tablefmt="grid")) + print() + print("Performance") + print( + tabulate( + perf_rows, + headers=[ + "batch", + "TorchInt8 mean ms", + "TorchInt8 median ms", + "TorchInt8 GOPS", + "oneDNN mean ms", + "oneDNN median ms", + "oneDNN GOPS", + "winner", + "speedup", + "oneDNN prepack ms", + "ref build ms", + ], + tablefmt="grid", + ) + ) + print() + print("Accuracy") + print( + tabulate( + acc_rows, + headers=[ + "batch", + "TorchInt8 max|diff|", + "TorchInt8 mean|diff|", + "TorchInt8 rmse", + "oneDNN max|diff|", + "oneDNN mean|diff|", + "oneDNN rmse", + "backend max|delta|", + "backend mean|delta|", + ], + tablefmt="grid", + ) + ) + print() + print("oneDNN ISA Probe") + print(tabulate(isa_rows, headers=["batch", "isa"], tablefmt="grid")) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..1cade325f --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,31 @@ +import random +import sys +from pathlib import Path + +import numpy +import torch + + +# Reduce logbar progress noise for pytest runs unless a caller explicitly +# overrides the environment. Keep the library default unchanged in LogBar +# itself; this is only a test harness preference. +#os.environ.setdefault("LOGBAR_ANIMATION", "0") +#os.environ.setdefault("LOGBAR_PROGRESS_OUTPUT_INTERVAL", "10") + +# Keep unit tests deterministic across CI runs. +torch.manual_seed(787) +random.seed(787) +numpy.random.seed(787) + +_TESTS_DIR = Path(__file__).resolve().parent +_MODELS_TESTS_DIR = _TESTS_DIR / "models" +_REPO_ROOT = _TESTS_DIR.parent + +# The suite mixes two helper import styles: +# - `from models.model_test import ModelTest` +# - `from ovis.image_to_test_dataset import ...` +# Add both helper directories, plus the repo root for `tests.*` imports. +for path in (_REPO_ROOT, _TESTS_DIR, _MODELS_TESTS_DIR): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) diff --git a/tests/eval.py b/tests/eval.py new file mode 100644 index 000000000..f839a097f --- /dev/null +++ b/tests/eval.py @@ -0,0 +1,890 @@ +# SPDX-FileCopyrightText: 2024-2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import importlib +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, Mapping, Optional + +from tabulate import tabulate + +from gptqmodel.utils.backend import BACKEND + + +_MMLU_LOCAL_DATASET = Path("/monster/data/model/dataset/hails-mmlu_no_train") +_GSM8K_LOCAL_DATASET = Path("/monster/data/model/dataset/gsm8k") +_ENGINE_OPTION_KEYS = { + "attn_implementation", + "attention_backend", + "base_url", + "context_length", + "dp_size", + "device", + "device_map", + "dtype", + "enforce_eager", + "gpu_memory_utilization", + "llm_kwargs", + "load_format", + "max_model_len", + "max_running_requests", + "max_total_tokens", + "mem_fraction_static", + "padding_side", + "pp_size", + "quantization", + "sampling_backend", + "sampling_params", + "seed", + "skip_tokenizer_init", + "tensor_parallel_size", + "tokenizer_worker_num", + "tokenizer_mode", + "tokenizer_revision", + "tp_size", + "trust_remote_code", + "vllm_path", +} +_DROPPED_MODEL_ARG_KEYS = { + "backend", + "gptqmodel", + "model_id_or_path", + "pretrained", + "tokenizer", +} + +DEFAULT_TASKS: tuple[str, ...] = ("arc_challenge",) +SUPPORTED_TASKS: tuple[str, ...] = ( + "arc_challenge", + "arc_easy", + "boolq", + "gsm8k_cot", + "gsm8k_platinum_cot", + "gpqa", + "hellaswag", + "mmlu", + "mmlu_pro", + "mmlu_pro:math", + "mmlu_stem", + "openbookqa", +) + + +def import_evalution(): + try: + return importlib.import_module("evalution") + except ModuleNotFoundError: + raise ValueError( + "Evalution is required for evaluation. " + "Install the `Evalution` package before running evaluation." + ) from None + + +def list_supported_tasks() -> tuple[str, ...]: + return SUPPORTED_TASKS + + +def normalize_eval_task_name(task: Any) -> str: + if task is None: + raise ValueError("Evaluation task identifier cannot be None") + if isinstance(task, str): + normalized = task.strip() + else: + normalized = str(task).strip() + if not normalized: + raise ValueError("Evaluation task identifier cannot be empty") + return normalized + + +def format_eval_result_table(result: Mapping[str, Any]) -> str: + rows = [] + for test in _result_tests(result): + metrics = test.get("metrics", {}) + if not metrics: + rows.append([test.get("name", ""), "-", "-"]) + continue + for metric_name, value in metrics.items(): + rows.append([test.get("name", ""), metric_name, f"{float(value):.4f}"]) + + if not rows: + rows.append(["-", "-", "-"]) + return tabulate(rows, headers=["Task", "Metric", "Value"], tablefmt="github") + + +def get_eval_task_results(result: Mapping[str, Any]) -> dict[str, dict[str, float]]: + return { + str(test.get("name", "")): { + str(metric_name): float(metric_value) + for metric_name, metric_value in (test.get("metrics", {}) or {}).items() + } + for test in _result_tests(result) + } + + +def get_eval_task_metrics(result: Mapping[str, Any], task: Any) -> dict[str, float]: + return get_eval_task_results(result).get(normalize_eval_task_name(task), {}) + + +def resolve_eval_metric_alias(metric_name: str, metrics: Mapping[str, Any]) -> str | None: + if metric_name in metrics: + return metric_name + + aliases = { + "acc": "accuracy,loglikelihood", + "acc_norm": "accuracy,loglikelihood_norm", + "acc,none": "accuracy,loglikelihood", + "acc_norm,none": "accuracy,loglikelihood_norm", + } + alias = aliases.get(metric_name) + if alias and alias in metrics: + return alias + return None + + +def evaluate( + model_or_id_or_path: Any = None, + tokenizer: Any = None, + tasks: Any = None, + batch_size: int | str = 1, + trust_remote_code: bool = False, + output_path: Optional[str] = None, + llm_backend: str = "gptqmodel", + backend: BACKEND | str | None = BACKEND.AUTO, + model_args: Optional[Dict[str, Any]] = None, + **args, +): + normalized_llm_backend = str(llm_backend).strip().lower() + if normalized_llm_backend not in {"gptqmodel", "vllm", "sglang"}: + raise ValueError( + "Evalution-backed evaluation only supports llm_backend='gptqmodel', 'vllm', or 'sglang'." + ) + + if tasks is None: + task_list = list(DEFAULT_TASKS) + elif isinstance(tasks, (list, tuple)): + task_list = [normalize_eval_task_name(task) for task in tasks] + else: + task_list = [normalize_eval_task_name(tasks)] + + model_args = dict(model_args or {}) + gen_kwargs = args.pop("gen_kwargs", None) + apply_chat_template = bool(args.pop("apply_chat_template", False)) + suite_kwargs = dict(args.pop("suite_kwargs", {}) or {}) + + if args: + unexpected = ", ".join(sorted(args.keys())) + raise TypeError(f"Unsupported evaluation keyword arguments: {unexpected}") + + return run_evalution( + model_or_id_or_path=model_or_id_or_path, + tokenizer=tokenizer, + tasks=task_list, + batch_size=batch_size, + trust_remote_code=trust_remote_code, + output_path=output_path, + llm_backend=normalized_llm_backend, + backend=backend, + model_args=model_args, + apply_chat_template=apply_chat_template, + gen_kwargs=gen_kwargs, + suite_kwargs=suite_kwargs, + ) + + +def run_evalution( + *, + model_or_id_or_path: Any, + tokenizer: Any, + tasks: list[str], + batch_size: int | str, + trust_remote_code: bool, + output_path: Optional[str], + llm_backend: str, + backend: BACKEND | str | None, + model_args: Dict[str, Any], + apply_chat_template: bool, + gen_kwargs: Any, + suite_kwargs: Dict[str, Any], +) -> dict[str, Any]: + evalution = import_evalution() + engine_config, model_config, session = _build_evalution_runtime( + evalution=evalution, + model_or_id_or_path=model_or_id_or_path, + llm_backend=llm_backend, + backend=backend, + batch_size=batch_size, + trust_remote_code=trust_remote_code, + model_args=model_args, + tokenizer=tokenizer, + ) + suite_batch_size = _coerce_suite_batch_size(batch_size) + generation_settings = _parse_generation_settings(gen_kwargs) + + try: + test_results = [] + for index, task_name in enumerate(tasks): + if index: + session.gc() + suite = _build_evalution_suite( + evalution=evalution, + task_name=task_name, + apply_chat_template=apply_chat_template, + batch_size=suite_batch_size, + generation_settings=generation_settings, + suite_kwargs=suite_kwargs, + ) + test_results.append(suite.evaluate(session)) + + engine_payload = {} + if hasattr(engine_config, "to_dict"): + engine_payload = engine_config.to_dict() + try: + engine_payload["execution"] = session.describe_execution() + except Exception: + # Best-effort metadata only; evaluation should continue if unavailable. + pass + + result = evalution.RunResult( + model=model_config.to_dict(), + engine=engine_payload, + tests=test_results, + ).to_dict() + finally: + session.close() + + _maybe_write_evalution_output(output_path, result) + return result + + +@dataclass(slots=True) +class _ArcChallengeLoglikelihoodSuite: + apply_chat_template: bool = False + batch_size: int | None = None + dataset_path: str = "allenai/ai2_arc" + dataset_name: str | None = "ARC-Challenge" + split: str = "test" + max_rows: int | None = None + cache_dir: str | None = None + stream: bool = True + + def dataset_loader(self) -> Any: + from datasets import load_dataset + + def _loader(path: str, *args, stream: bool = True, **kwargs): + # Evalution forwards `stream`; Hugging Face expects `streaming`. + # Enforce the new API by rejecting legacy `streaming`. + if "streaming" in kwargs: + raise TypeError("use `stream=` (Evalution) not `streaming=`") + return load_dataset(path, *args, streaming=stream, **kwargs) + + return _loader + + def task_name(self) -> str: + return "arc_challenge" + + def continuation_for_choice(self, choice: str) -> str: + return choice if choice[:1].isspace() else f" {choice}" + + def result_metadata(self) -> dict[str, Any]: + return { + "dataset_path": self.dataset_path, + "dataset_name": self.dataset_name, + "split": self.split, + "stream": self.stream, + "apply_chat_template": self.apply_chat_template, + "scoring_mode": "multiple_choice_loglikelihood", + } + + def build_sample(self, doc: dict[str, Any], *, index: int) -> Any: + from evalution.benchmarks.multiple_choice import MultipleChoiceSample + from evalution.benchmarks.multiple_choice_utils import choice_index_from_labels, question_answer_prompt + + labels = list(doc["choices"]["label"]) + texts = list(doc["choices"]["text"]) + return MultipleChoiceSample( + index=index, + prompt=question_answer_prompt(doc["question"]), + choices=texts, + gold_index=choice_index_from_labels(labels, doc["answerKey"]), + metadata={"id": doc["id"], "choice_labels": labels}, + ) + + def evaluate(self, session: Any) -> Any: + from evalution.benchmarks.data import doc_count, limit_docs, load_suite_dataset + from evalution.engines.base import LoglikelihoodRequest + from evalution.logbar import get_logger + from evalution.results import SampleResult, TestResult + + task_name = self.task_name() + logger = get_logger() + loaded_docs, _dataset_load_wall_s = load_suite_dataset( + self.dataset_loader(), + task_name=task_name, + dataset_path=self.dataset_path, + dataset_name=self.dataset_name, + split=self.split, + cache_dir=self.cache_dir, + stream=self.stream, + ) + + docs = limit_docs(loaded_docs, self.max_rows) + if not isinstance(docs, list): + docs = list(docs) + + total = doc_count( + docs, + loaded_docs=loaded_docs, + max_rows=self.max_rows, + split=self.split, + ) + logger.info("%s: evaluating %d sample(s)", task_name, total) + + samples = [self.build_sample(doc, index=index) for index, doc in enumerate(docs)] + rendered_prompts = [ + _render_evalution_prompt(session, sample.prompt, apply_chat_template=self.apply_chat_template) + for sample in samples + ] + + requests = [] + request_to_choice = [] + for sample, prompt in zip(samples, rendered_prompts, strict=True): + for choice_index, choice in enumerate(sample.choices): + requests.append( + LoglikelihoodRequest( + context=prompt, + continuation=self.continuation_for_choice(choice), + ) + ) + request_to_choice.append((sample.index, choice_index)) + + outputs = session.loglikelihood(requests, batch_size=self.batch_size) + logger.info("%s: executed %d/%d sample(s)", task_name, len(samples), total) + + sample_choice_scores: dict[int, list[tuple[float, float, int]]] = {} + for (sample_index, choice_index), output in zip(request_to_choice, outputs, strict=True): + sample_choice_scores.setdefault(sample_index, []).append( + ( + output.logprob, + output.logprob / max(output.token_count, 1), + choice_index, + ) + ) + + sample_results = [] + raw_total = 0.0 + norm_total = 0.0 + for sample, prompt in zip(samples, rendered_prompts, strict=True): + choice_scores = sorted(sample_choice_scores[sample.index], key=lambda item: item[2]) + raw_best = max(choice_scores, key=lambda item: item[0])[2] + norm_best = max(choice_scores, key=lambda item: item[1])[2] + raw_score = 1.0 if raw_best == sample.gold_index else 0.0 + norm_score = 1.0 if norm_best == sample.gold_index else 0.0 + raw_total += raw_score + norm_total += norm_score + sample_results.append( + SampleResult( + index=sample.index, + prompt=prompt, + target=sample.choices[sample.gold_index], + prediction=sample.choices[norm_best], + extracted={ + "gold_index": str(sample.gold_index), + "predicted_index": str(raw_best), + "predicted_index_norm": str(norm_best), + }, + scores={ + "accuracy,loglikelihood": raw_score, + "accuracy,loglikelihood_norm": norm_score, + }, + metadata={ + **sample.metadata, + "choice_logprobs": [score for score, _norm, _index in choice_scores], + "choice_logprobs_norm": [norm for _score, norm, _index in choice_scores], + }, + ) + ) + + denominator = max(len(sample_results), 1) + metrics = { + "accuracy,loglikelihood": raw_total / denominator, + "accuracy,loglikelihood_norm": norm_total / denominator, + } + return TestResult( + name=task_name, + metrics=metrics, + samples=sample_results, + metadata=self.result_metadata(), + ) + + +def _result_tests(result: Mapping[str, Any]) -> list[dict[str, Any]]: + tests = result.get("tests") + return list(tests) if isinstance(tests, list) else [] + + +def _build_evalution_runtime( + *, + evalution: Any, + model_or_id_or_path: Any, + llm_backend: str, + backend: BACKEND | str | None, + batch_size: int | str, + trust_remote_code: bool, + model_args: Dict[str, Any], + tokenizer: Any, +): + from transformers import PreTrainedModel + + try: + from peft import PeftModel + except Exception: # pragma: no cover - optional dependency + PeftModel = () + + engine_options, load_kwargs = _split_evalution_model_args(model_args) + engine_dtype = _normalize_dtype_name(engine_options.get("dtype")) + engine_device = engine_options.get("device") + engine_device_map = engine_options.get("device_map") + engine_attn = engine_options.get("attn_implementation") + engine_padding_side = engine_options.get("padding_side", "left") + + tokenizer_path = _resolve_tokenizer_path(tokenizer) + + if llm_backend in {"vllm", "sglang"}: + model_path = ( + model_or_id_or_path + if isinstance(model_or_id_or_path, str) + else _resolve_model_path(model_or_id_or_path) + ) + if model_path is None: + raise ValueError("Evalution vLLM evaluation requires a model path.") + + if llm_backend == "vllm": + max_model_len = engine_options.get("max_model_len") + tensor_parallel_size = engine_options.get("tensor_parallel_size", 1) + gpu_memory_utilization = engine_options.get("gpu_memory_utilization", 0.9) + llm_kwargs = dict(engine_options.get("llm_kwargs", {}) or {}) + + engine = evalution.VLLM( + dtype=engine_dtype, + batch_size=batch_size, + trust_remote_code=trust_remote_code, + padding_side=engine_padding_side, + seed=engine_options.get("seed"), + tokenizer_mode=engine_options.get("tokenizer_mode", "auto"), + tensor_parallel_size=int(tensor_parallel_size), + gpu_memory_utilization=float(gpu_memory_utilization), + quantization=engine_options.get("quantization"), + max_model_len=int(max_model_len) if max_model_len is not None else None, + enforce_eager=bool(engine_options.get("enforce_eager", False)), + tokenizer_revision=engine_options.get("tokenizer_revision"), + vllm_path=engine_options.get("vllm_path"), + llm_kwargs=llm_kwargs, + ) + else: + sglang_config = _build_sglang_engine_kwargs( + engine_options=engine_options, + batch_size=batch_size, + trust_remote_code=trust_remote_code, + padding_side=engine_padding_side, + engine_dtype=engine_dtype, + ) + engine = evalution.SGLang(**sglang_config) + + model_config = evalution.Model( + path=model_path, + tokenizer_path=tokenizer_path, + trust_remote_code=trust_remote_code, + model_kwargs=load_kwargs, + ) + session = engine.build(model_config) + return engine, model_config, session + + if isinstance(model_or_id_or_path, str): + engine = evalution.GPTQModel( + dtype=engine_dtype, + attn_implementation=engine_attn, + device=engine_device, + device_map=engine_device_map, + seed=engine_options.get("seed"), + batch_size=batch_size, + trust_remote_code=trust_remote_code, + padding_side=engine_padding_side, + backend=_normalize_backend_name(backend), + gptqmodel_path=str(Path(__file__).resolve().parents[2]), + ) + model_config = evalution.Model( + path=model_or_id_or_path, + tokenizer_path=tokenizer_path, + trust_remote_code=trust_remote_code, + model_kwargs=load_kwargs, + ) + session = engine.build(model_config) + return engine, model_config, session + + if isinstance(model_or_id_or_path, (PreTrainedModel, PeftModel)) or hasattr(model_or_id_or_path, "model"): + model_path = _resolve_model_path(model_or_id_or_path) + if model_path is None: + raise ValueError("Evalution requires a model path when evaluating a live model instance.") + + engine = evalution.TransformersCompat( + dtype=engine_dtype or _normalize_dtype_name(getattr(model_or_id_or_path, "dtype", None)), + attn_implementation=engine_attn, + device=engine_device, + device_map=engine_device_map, + seed=engine_options.get("seed"), + batch_size=batch_size, + trust_remote_code=trust_remote_code, + padding_side=engine_padding_side, + ) + engine.resolved_engine = "TransformersCompat" + + model_config = evalution.Model( + path=model_path, + tokenizer_path=tokenizer_path, + trust_remote_code=trust_remote_code, + model_kwargs=load_kwargs, + ) + session = _build_evalution_session_from_model( + engine=engine, + model_config=model_config, + model=model_or_id_or_path, + ) + return engine, model_config, session + + raise ValueError( + f"`model_or_id_or_path` is invalid. expected: `model instance or str` actual: `{model_or_id_or_path}`" + ) + + +def _build_evalution_session_from_model(*, engine: Any, model_config: Any, model: Any): + from evalution.engines.transformers_common import _clone_prepare_tokenizer, _resolve_input_device + from evalution.engines.transformers_compat import TransformersCompatSession + + inner_model = getattr(model, "model", model) + tokenizer = getattr(model, "tokenizer", None) + if tokenizer is None: + raise ValueError("Tokenizer must be attached to the loaded model instance for Evalution-backed evaluation.") + + trust_remote_code = ( + engine.trust_remote_code + if engine.trust_remote_code is not None + else model_config.trust_remote_code + ) + requested_attn = ( + getattr(getattr(inner_model, "config", None), "_attn_implementation", None) + or getattr(getattr(inner_model, "config", None), "attn_implementation", None) + or engine.attn_implementation + ) + prepare_tokenizer = _clone_prepare_tokenizer( + tokenizer=tokenizer, + model_config=model_config, + trust_remote_code=trust_remote_code, + ) + requested_padding_side = getattr(engine, "padding_side", None) + if requested_padding_side: + for tok in (tokenizer, prepare_tokenizer): + if tok is None: + continue + if getattr(tok, "padding_side", None) != requested_padding_side: + tok.padding_side = requested_padding_side + if getattr(tok, "pad_token_id", None) is None: + eos_token_id = getattr(tok, "eos_token_id", None) + if eos_token_id is not None: + tok.pad_token_id = eos_token_id + + session = TransformersCompatSession( + config=engine, + model_config=model_config, + model=inner_model, + tokenizer=tokenizer, + prepare_tokenizer=prepare_tokenizer, + input_device=_resolve_input_device(inner_model, prefer=engine.device), + requested_attn_implementation=requested_attn, + effective_attn_implementation=requested_attn, + paged_attention_enabled=False, + generation_backend="generate_compat", + ) + session._gptqmodel_wrapper = model + return session + + +def _build_evalution_suite( + *, + evalution: Any, + task_name: str, + apply_chat_template: bool, + batch_size: int | None, + generation_settings: Dict[str, Any], + suite_kwargs: Dict[str, Any], +): + benchmarks = evalution.benchmarks + normalized_task = normalize_eval_task_name(task_name) + max_new_tokens = int(generation_settings.get("max_new_tokens", 256)) + do_sample = bool(generation_settings.get("do_sample", False)) + temperature = float(generation_settings.get("temperature", 0.0)) + kwargs = dict(suite_kwargs or {}) + if "stream" not in kwargs and "streaming" in kwargs: + kwargs["stream"] = bool(kwargs.pop("streaming")) + elif "streaming" in kwargs: + kwargs.pop("streaming") + + if normalized_task == "arc_challenge": + kwargs.setdefault("apply_chat_template", apply_chat_template) + kwargs.setdefault("batch_size", batch_size) + kwargs.pop("num_fewshot", None) + return _ArcChallengeLoglikelihoodSuite(**kwargs) + if normalized_task == "mmlu_stem": + kwargs.setdefault("subsets", "stem") + kwargs.setdefault("batch_size", batch_size) + kwargs.setdefault("stream", True) + if _MMLU_LOCAL_DATASET.exists(): + kwargs.setdefault("dataset_path", str(_MMLU_LOCAL_DATASET)) + return benchmarks.mmlu(**kwargs) + if normalized_task == "gsm8k_cot": + kwargs.setdefault("variant", "cot") + kwargs.setdefault("apply_chat_template", apply_chat_template) + kwargs.setdefault("max_new_tokens", max_new_tokens) + kwargs.setdefault("batch_size", batch_size) + kwargs.setdefault("do_sample", do_sample) + kwargs.setdefault("temperature", temperature) + if _GSM8K_LOCAL_DATASET.exists(): + kwargs.setdefault("dataset_path", str(_GSM8K_LOCAL_DATASET)) + kwargs.setdefault("dataset_name", "main") + return benchmarks.gsm8k(**kwargs) + if normalized_task == "gsm8k_platinum_cot": + kwargs.setdefault("variant", "cot") + kwargs.setdefault("apply_chat_template", apply_chat_template) + kwargs.setdefault("max_new_tokens", max_new_tokens) + kwargs.setdefault("batch_size", batch_size) + kwargs.setdefault("do_sample", do_sample) + kwargs.setdefault("temperature", temperature) + return benchmarks.gsm8k_platinum(**kwargs) + if normalized_task == "mmlu": + kwargs.setdefault("batch_size", batch_size) + kwargs.setdefault("stream", True) + if _MMLU_LOCAL_DATASET.exists(): + kwargs.setdefault("dataset_path", str(_MMLU_LOCAL_DATASET)) + return benchmarks.mmlu(**kwargs) + if normalized_task == "mmlu_pro": + kwargs.setdefault("apply_chat_template", apply_chat_template) + kwargs.setdefault("max_new_tokens", max_new_tokens) + kwargs.setdefault("batch_size", batch_size) + kwargs.setdefault("do_sample", do_sample) + kwargs.setdefault("temperature", temperature) + return benchmarks.mmlu_pro(**kwargs) + if normalized_task.startswith("mmlu_pro:"): + subset = normalized_task.split(":", 1)[1].strip() + if not subset: + raise ValueError(f"Invalid Evalution task: `{normalized_task}`") + kwargs.setdefault("subsets", subset) + kwargs.setdefault("apply_chat_template", apply_chat_template) + kwargs.setdefault("max_new_tokens", max_new_tokens) + kwargs.setdefault("batch_size", batch_size) + kwargs.setdefault("do_sample", do_sample) + kwargs.setdefault("temperature", temperature) + return benchmarks.mmlu_pro(**kwargs) + + generic_factory = getattr(benchmarks, normalized_task, None) + if callable(generic_factory): + kwargs.setdefault("batch_size", batch_size) + return generic_factory(**kwargs) + + raise ValueError(f"Unsupported Evalution task: `{normalized_task}`") + + +def _split_evalution_model_args(model_args: Dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]: + engine_options = {} + load_kwargs = {} + for key, value in model_args.items(): + if key in _DROPPED_MODEL_ARG_KEYS: + continue + if key in _ENGINE_OPTION_KEYS: + engine_options[key] = value + else: + load_kwargs[key] = value + return engine_options, load_kwargs + +def _build_sglang_engine_kwargs( + *, + engine_options: Dict[str, Any], + batch_size: int | str, + trust_remote_code: bool, + padding_side: str, + engine_dtype: str | None, +) -> Dict[str, Any]: + context_length = engine_options.get("context_length", engine_options.get("max_model_len")) + tp_size = engine_options.get("tp_size", engine_options.get("tensor_parallel_size", 1)) + mem_fraction_static = engine_options.get( + "mem_fraction_static", + engine_options.get("gpu_memory_utilization"), + ) + + return { + "dtype": engine_dtype, + "device": engine_options.get("device"), + "seed": engine_options.get("seed"), + "trust_remote_code": trust_remote_code, + "padding_side": padding_side, + "base_url": engine_options.get("base_url"), + "batch_size": batch_size, + "tokenizer_mode": engine_options.get("tokenizer_mode", "auto"), + "tokenizer_worker_num": int(engine_options.get("tokenizer_worker_num", 1)), + "skip_tokenizer_init": bool(engine_options.get("skip_tokenizer_init", False)), + "load_format": engine_options.get("load_format", "auto"), + "context_length": int(context_length) if context_length is not None else None, + "quantization": engine_options.get("quantization"), + "mem_fraction_static": float(mem_fraction_static) if mem_fraction_static is not None else None, + "tp_size": int(tp_size), + "dp_size": int(engine_options.get("dp_size", 1)), + "pp_size": int(engine_options.get("pp_size", 1)), + "attention_backend": engine_options.get("attention_backend"), + "sampling_backend": engine_options.get("sampling_backend"), + "max_running_requests": ( + int(engine_options["max_running_requests"]) + if engine_options.get("max_running_requests") is not None + else None + ), + "max_total_tokens": ( + int(engine_options["max_total_tokens"]) + if engine_options.get("max_total_tokens") is not None + else None + ), + "sampling_params": dict(engine_options.get("sampling_params", {}) or {}), + } + + +def _parse_generation_settings(gen_kwargs: Any) -> Dict[str, Any]: + if not gen_kwargs: + return {} + if isinstance(gen_kwargs, Mapping): + return dict(gen_kwargs) + + settings = {} + for item in str(gen_kwargs).split(","): + if "=" not in item: + continue + key, raw_value = item.split("=", 1) + settings[key.strip()] = _coerce_scalar(raw_value.strip()) + return settings + + +def _coerce_scalar(value: str) -> Any: + lowered = value.lower() + if lowered in {"true", "false"}: + return lowered == "true" + if lowered in {"none", "null"}: + return None + try: + if any(ch in value for ch in (".", "e", "E")): + return float(value) + return int(value) + except ValueError: + return value + + +def _coerce_suite_batch_size(batch_size: int | str) -> int | None: + if isinstance(batch_size, str): + normalized = batch_size.strip().lower() + if normalized == "auto": + return None + return int(normalized) + return int(batch_size) + + +def _normalize_backend_name(backend: BACKEND | str | None) -> str: + if backend is None: + return BACKEND.AUTO.value + if isinstance(backend, BACKEND): + return backend.value + return str(backend) + + +def _normalize_dtype_name(dtype: Any) -> str | None: + if dtype is None: + return None + if isinstance(dtype, str): + return dtype + try: + import torch + + mapping = { + torch.float16: "float16", + torch.bfloat16: "bfloat16", + torch.float32: "float32", + torch.float64: "float64", + } + if dtype in mapping: + return mapping[dtype] + except ImportError: + # torch is optional here; fall back to returning str(dtype) below. + pass + return str(dtype) + + +def _resolve_tokenizer_path(tokenizer: Any) -> str | None: + if tokenizer is None: + return None + if isinstance(tokenizer, str): + return tokenizer + return getattr(tokenizer, "name_or_path", None) + + +def _resolve_model_path(model: Any) -> str | None: + model_path = getattr(model, "model_local_path", None) + if isinstance(model_path, str) and model_path.strip(): + return model_path + config = getattr(model, "config", None) + name_or_path = getattr(config, "name_or_path", None) + if isinstance(name_or_path, str) and name_or_path.strip(): + return name_or_path + return None + + +def _render_evalution_prompt(session: Any, prompt: str, *, apply_chat_template: bool) -> str: + if not apply_chat_template: + return prompt + + tokenizer = getattr(session, "prepare_tokenizer", None) or getattr(session, "tokenizer", None) + if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"): + return prompt + + try: + rendered = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + return prompt + return rendered if isinstance(rendered, str) and rendered.strip() else prompt + + +def _maybe_write_evalution_output(output_path: Optional[str], result: Mapping[str, Any]) -> None: + if not output_path: + return + + path = Path(output_path) + if path.suffix.lower() != ".json": + return + + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="utf-8") as handle: + json.dump(result, handle, indent=2, sort_keys=True) + + +__all__ = [ + "DEFAULT_TASKS", + "SUPPORTED_TASKS", + "evaluate", + "format_eval_result_table", + "get_eval_task_metrics", + "get_eval_task_results", + "import_evalution", + "list_supported_tasks", + "normalize_eval_task_name", + "resolve_eval_metric_alias", +] diff --git a/tests/inference_speed.py b/tests/inference_speed.py index a502ac09e..687d12ec4 100644 --- a/tests/inference_speed.py +++ b/tests/inference_speed.py @@ -6,6 +6,8 @@ import os import time +import torch + from gptqmodel.utils.torch import torch_empty_cache @@ -43,6 +45,9 @@ class InferenceSpeed(unittest.TestCase): MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.25 def inference(self, model_path, backend, tokens_per_second, assert_result=True, optimize=False, fullgraph=False, warmup_runs=0, device=None): + if device == "cuda" and torch.cuda.is_available(): + device = f"cuda:{torch.cuda.current_device()}" + model = GPTQModel.from_quantized( model_path, backend=backend, @@ -57,37 +62,38 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, inp = tokenizer(self.PROMPTS, padding=True, truncation=True, return_tensors="pt", padding_side='left').to( model.device) - times = [] - tokens = [] - # compile kernels need JIT compile (Bitblas, IPEX, Triton) so we should do some warmup before actual speed run if warmup_runs > 0: + warmup_times = [] + warmup_tokens = [] pb = logger.pb(range(warmup_runs)).title("Warmup") for _ in pb: start_time = time.time() result = model.generate(**inp, max_new_tokens=self.MAX_NEW_TOKENS, pad_token_id=tokenizer.pad_token_id) end_time = time.time() elapsed_time = end_time - start_time - times.append(elapsed_time) + warmup_times.append(elapsed_time) for j in range(result.shape[0]): new_tokens = result[j][inp['input_ids'].shape[1]:] new_token_count = len(new_tokens) - tokens.append(new_token_count) + warmup_tokens.append(new_token_count) - sum_time = sum(times) - sum_tokens = sum(tokens) + sum_time = sum(warmup_times) + sum_tokens = sum(warmup_tokens) avg_tokens_per_second = round(sum_tokens / sum_time, 2) print(f"\n**************** {backend} Warm-up Result Info****************") - print(f"Times: {times}") - print(f"New Tokens (Size Per Batch Request): {tokens}") + print(f"Times: {warmup_times}") + print(f"New Tokens (Size Per Batch Request): {warmup_tokens}") print(f"Sum Times: {sum_time}") print(f"Sum New Tokens: {sum_tokens}") print(f"New Token Per Second: {avg_tokens_per_second} token/s") print(f"**************** {backend} Warm-up Result Info End****************") + times = [] + tokens = [] pb = logger.pb(range(self.NUM_RUNS)).title("Run") for _ in pb: start_time = time.time() @@ -115,7 +121,7 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, print(f"**************** {backend} Result Info End****************") if not assert_result: - return + return avg_tokens_per_second diff_pct = (avg_tokens_per_second / tokens_per_second) * 100 negative_pct = 100 * (1 - self.MAX_DELTA_FLOOR_PERCENT) @@ -126,3 +132,4 @@ def inference(self, model_path, backend, tokens_per_second, assert_result=True, del model torch_empty_cache() + return avg_tokens_per_second diff --git a/tests/kernels/benchmark_intel_cpu_xpu.py b/tests/kernels/benchmark_intel_cpu_xpu.py index f73f49cd1..df53f6ca4 100644 --- a/tests/kernels/benchmark_intel_cpu_xpu.py +++ b/tests/kernels/benchmark_intel_cpu_xpu.py @@ -13,10 +13,10 @@ from logbar import LogBar from gptqmodel import BACKEND, GPTQModel -from gptqmodel.nn_modules.qlinear.gemm_hf_kernel import HFKernelLinear -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear -from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedQuantLinear -from gptqmodel.nn_modules.qlinear.torch_int8 import TorchInt8QuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.torch_aten_kernel import TorchAtenLinear +from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedLinear +from gptqmodel.nn_modules.qlinear.torch_int8 import TorchInt8Linear from gptqmodel.utils.model import find_modules @@ -53,10 +53,10 @@ class BenchmarkIntelCpuXPU(unittest.TestCase): new_tokens = int(os.getenv("GPTQMODEL_INTEL_CPU_BENCH_NEW_TOKENS", "1")) target_qliner_map = { - BACKEND.TORCH: TorchQuantLinear, - BACKEND.TORCH_FUSED: TorchFusedQuantLinear, - BACKEND.TORCH_INT8: TorchInt8QuantLinear, - BACKEND.HF_KERNEL: HFKernelLinear, + BACKEND.TORCH: TorchLinear, + BACKEND.TORCH_FUSED: TorchFusedLinear, + BACKEND.TORCH_INT8: TorchInt8Linear, + BACKEND.GPTQ_TORCH_ATEN: TorchAtenLinear, } skip_backends = set() @@ -131,7 +131,7 @@ def test_cpu_backend_inference_speed_trimmed(self): success_count = 0 for backend, qlinear_cls in self.target_qliner_map.items(): - if backend in self.skip_backends and os.getenv("GPTQMODEL_INTEL_CPU_BENCH_ENABLE_HF", "0") != "1": + if backend in self.skip_backends and os.getenv("GPTQMODEL_INTEL_CPU_BENCH_ENABLE_TORCH_ATEN", "0") != "1": bench_cols.info( backend.name, qlinear_cls.__name__, @@ -144,7 +144,7 @@ def test_cpu_backend_inference_speed_trimmed(self): "-", "-", "SKIP", - "Temporarily skipped (set GPTQMODEL_INTEL_CPU_BENCH_ENABLE_HF=1 to enable).", + "Temporarily skipped (set GPTQMODEL_INTEL_CPU_BENCH_ENABLE_TORCH_ATEN=1 to enable).", ) continue try: diff --git a/tests/kernels/test_asymmetric_real_models.py b/tests/kernels/test_asymmetric_real_models.py new file mode 100644 index 000000000..bfdce5935 --- /dev/null +++ b/tests/kernels/test_asymmetric_real_models.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import gc +import re +from dataclasses import dataclass +from typing import Iterable + +import pytest +import torch +from tests.models.model_test import ModelTest + +from gptqmodel import BACKEND, GPTQModel +from gptqmodel.nn_modules.qlinear.machete import MacheteLinear +from gptqmodel.nn_modules.qlinear.machete_awq import AwqMacheteLinear +from gptqmodel.nn_modules.qlinear.marlin_awq import AwqMarlinLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear + + +pytestmark = [ + pytest.mark.cuda, + pytest.mark.model, + pytest.mark.slow, +] + + +_DEVICE = torch.device("cuda:0") +_DTYPE = torch.float16 +_PROMPT = "What is the surface area of the Sun?" +_AWQ_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct-AWQ" +_GPTQ_MODEL_ID = "ruikangliu/DeepSeek-R1-Distill-Qwen-1.5B-quantized.gptq-gptqmodel-w4g128" +_TARGET_MODULE = "model.layers.0.mlp.up_proj" +_STOPWORDS = { + "a", + "an", + "and", + "approximately", + "around", + "as", + "at", + "be", + "for", + "in", + "is", + "it", + "of", + "on", + "or", + "roughly", + "the", + "to", +} + + +@dataclass(frozen=True) +class _RealModelCase: + name: str + model_id: str + baseline_backend: BACKEND + candidate_backend: BACKEND + baseline_cls: type[torch.nn.Module] + candidate_cls: type[torch.nn.Module] + atol: float + rtol: float + + +_AWQ_REAL_CASES = ( + _RealModelCase( + name="awq_marlin", + model_id=_AWQ_MODEL_ID, + baseline_backend=BACKEND.TORCH_AWQ, + candidate_backend=BACKEND.MARLIN, + baseline_cls=AwqTorchLinear, + candidate_cls=AwqMarlinLinear, + atol=1e-2, + rtol=1e-2, + ), + _RealModelCase( + name="awq_machete", + model_id=_AWQ_MODEL_ID, + baseline_backend=BACKEND.TORCH_AWQ, + candidate_backend=BACKEND.MACHETE, + baseline_cls=AwqTorchLinear, + candidate_cls=AwqMacheteLinear, + atol=1.5e-2, + rtol=1.5e-2, + ), +) + +_GPTQ_REAL_CASES = ( + _RealModelCase( + name="gptq_machete", + model_id=_GPTQ_MODEL_ID, + baseline_backend=BACKEND.TORCH, + candidate_backend=BACKEND.MACHETE, + baseline_cls=TorchLinear, + candidate_cls=MacheteLinear, + atol=1.5e-2, + rtol=1.5e-2, + ), +) + + +def _module_device(module: torch.nn.Module) -> torch.device: + for tensor in module.parameters(recurse=False): + if tensor is not None and not tensor.is_meta: + return tensor.device + for tensor in module.buffers(recurse=False): + if tensor is not None and not tensor.is_meta: + return tensor.device + raise RuntimeError(f"Unable to infer runtime device for `{module.__class__.__name__}`.") + + +def _release_model(model) -> None: + if model is not None: + del model + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def _load_model(model_id: str, backend: BACKEND): + model = GPTQModel.load( + model_id, + backend=backend, + dtype=_DTYPE, + device=_DEVICE, + ) + assert model.quantize_config.sym is False + return model + + +def _target_module(model, expected_cls: type[torch.nn.Module]) -> torch.nn.Module: + module_map = dict(model.model.named_modules()) + if _TARGET_MODULE not in module_map: + raise KeyError(f"Target module `{_TARGET_MODULE}` not found in model `{model.model_id_or_path}`.") + module = module_map[_TARGET_MODULE] + assert isinstance(module, expected_cls), ( + f"Expected `{_TARGET_MODULE}` to use `{expected_cls.__name__}`, got `{module.__class__.__name__}`." + ) + return module + + +def _layer_inputs(in_features: int, *, device: torch.device) -> list[torch.Tensor]: + torch.manual_seed(17) + return [ + torch.randn((1, in_features), device=device, dtype=_DTYPE), + torch.randn((8, in_features), device=device, dtype=_DTYPE), + torch.randn((2, 3, in_features), device=device, dtype=_DTYPE), + ] + + +def _forward_module(module: torch.nn.Module, inputs: Iterable[torch.Tensor]) -> list[torch.Tensor]: + outputs: list[torch.Tensor] = [] + module_device = _module_device(module) + with torch.inference_mode(): + for x in inputs: + current = x if x.device == module_device else x.to(module_device) + outputs.append(module(current).detach().to(device="cpu", dtype=torch.float32)) + torch.cuda.synchronize(module_device) + return outputs + + +def _normalized_text(text: str) -> str: + return re.sub(r"\s+", " ", text).strip().lower() + + +def _content_tokens(text: str) -> set[str]: + return { + token + for token in re.findall(r"[a-z0-9]+", _normalized_text(text)) + if token not in _STOPWORDS + } + + +def _assert_solar_answer_not_garbled(text: str) -> None: + normalized = _normalized_text(text) + assert normalized, "generation output is empty" + assert "\ufffd" not in normalized + assert sum(ch.isprintable() for ch in normalized) / len(normalized) > 0.98 + assert any(term in normalized for term in ("surface", "square", "kilometer", "kilometers", "km", "solar", "sun")), ( + f"expected a surface-area style answer, got: {text}" + ) + assert re.search(r"\d", normalized) or any( + term in normalized for term in ("sphere", "formula", "radius") + ), f"expected a numeric or formula-based answer, got: {text}" + + +def _generation_inputs(tokenizer): + if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None): + prompt_text = tokenizer.apply_chat_template( + [{"role": "user", "content": _PROMPT}], + tokenize=False, + add_generation_prompt=True, + ) + else: + prompt_text = _PROMPT + + inputs = tokenizer(prompt_text, return_tensors="pt") + return prompt_text, inputs, int(inputs.input_ids.shape[1]) + + +def _generate_completion(model) -> str: + tokenizer = model.tokenizer + prompt_text, inputs, decode_start_idx = _generation_inputs(tokenizer) + return ModelTest.generate_stable_with_limit( + model, + tokenizer, + prompt_text, + inputs=inputs, + decode_start_idx=decode_start_idx, + max_new_tokens=48, + ) + + +def _assert_generation_tracks_torch(candidate_text: str, baseline_text: str) -> None: + _assert_solar_answer_not_garbled(baseline_text) + _assert_solar_answer_not_garbled(candidate_text) + + baseline_tokens = _content_tokens(baseline_text) + candidate_tokens = _content_tokens(candidate_text) + shared = baseline_tokens & candidate_tokens + + assert len(shared) >= 4, ( + "candidate generation diverged too far from torch baseline.\n" + f"baseline={baseline_text!r}\n" + f"candidate={candidate_text!r}" + ) + assert any( + token in shared for token in {"surface", "area", "sun", "solar", "square", "kilometer", "kilometers", "km"} + ), ( + "candidate generation does not preserve the core answer terms from torch baseline.\n" + f"baseline={baseline_text!r}\n" + f"candidate={candidate_text!r}" + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("case", _AWQ_REAL_CASES, ids=lambda case: case.name) +def test_awq_asymmetric_real_layer_outputs_match_torch(case: _RealModelCase) -> None: + baseline_model = None + candidate_model = None + try: + baseline_model = _load_model(case.model_id, case.baseline_backend) + baseline_module = _target_module(baseline_model, case.baseline_cls) + inputs = _layer_inputs(baseline_module.in_features, device=_DEVICE) + baseline_outputs = _forward_module(baseline_module, inputs) + finally: + _release_model(baseline_model) + + try: + candidate_model = _load_model(case.model_id, case.candidate_backend) + candidate_module = _target_module(candidate_model, case.candidate_cls) + candidate_outputs = _forward_module(candidate_module, inputs) + finally: + _release_model(candidate_model) + + for actual, expected in zip(candidate_outputs, baseline_outputs): + torch.testing.assert_close(actual, expected, atol=case.atol, rtol=case.rtol) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("case", _GPTQ_REAL_CASES, ids=lambda case: case.name) +def test_gptq_asymmetric_real_layer_outputs_match_torch(case: _RealModelCase) -> None: + baseline_model = None + candidate_model = None + try: + baseline_model = _load_model(case.model_id, case.baseline_backend) + baseline_module = _target_module(baseline_model, case.baseline_cls) + inputs = _layer_inputs(baseline_module.in_features, device=_DEVICE) + baseline_outputs = _forward_module(baseline_module, inputs) + finally: + _release_model(baseline_model) + + try: + candidate_model = _load_model(case.model_id, case.candidate_backend) + candidate_module = _target_module(candidate_model, case.candidate_cls) + candidate_outputs = _forward_module(candidate_module, inputs) + finally: + _release_model(candidate_model) + + for actual, expected in zip(candidate_outputs, baseline_outputs): + torch.testing.assert_close(actual, expected, atol=case.atol, rtol=case.rtol) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("case", _AWQ_REAL_CASES + _GPTQ_REAL_CASES, ids=lambda case: f"{case.name}_generation") +def test_asymmetric_real_generation_matches_torch_and_is_sane(case: _RealModelCase) -> None: + baseline_model = None + candidate_model = None + try: + baseline_model = _load_model(case.model_id, case.baseline_backend) + baseline_text = _generate_completion(baseline_model) + finally: + _release_model(baseline_model) + + try: + candidate_model = _load_model(case.model_id, case.candidate_backend) + candidate_text = _generate_completion(candidate_model) + finally: + _release_model(candidate_model) + + _assert_generation_tracks_torch(candidate_text, baseline_text) diff --git a/tests/kernels/test_awq.py b/tests/kernels/test_awq.py index 8e671f427..a30e271b0 100644 --- a/tests/kernels/test_awq.py +++ b/tests/kernels/test_awq.py @@ -5,8 +5,11 @@ import json import os +import time import unittest +from dataclasses import dataclass from pathlib import Path +from types import SimpleNamespace from typing import Dict, Iterable, List, Optional, Tuple import torch @@ -16,27 +19,30 @@ from tabulate import tabulate from gptqmodel import BACKEND -from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear +from gptqmodel.nn_modules.qlinear.bitblas import BITBLAS_AVAILABLE, BITBLAS_INSTALL_HINT +from gptqmodel.nn_modules.qlinear.bitblas_awq import AWQBitBlasKernel +from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMLinear +from gptqmodel.nn_modules.qlinear.machete_awq import AwqMacheteLinear from gptqmodel.nn_modules.qlinear.marlin_awq import ( - AwqMarlinQuantLinear, + AwqMarlinLinear, marlin_import_exception, ) -from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear -from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqLinear from gptqmodel.utils.marlin import marlin_make_workspace_new try: - from gptqmodel.nn_modules.qlinear.gemm_awq_triton import AwqGEMMTritonQuantLinear + from gptqmodel.nn_modules.qlinear.gemm_awq_triton import AwqGEMMTritonLinear awq_triton_import_exception: Optional[Exception] = None except Exception as exc: # pragma: no cover - triton import may fail in CI - AwqGEMMTritonQuantLinear = None # type: ignore[assignment] + AwqGEMMTritonLinear = None # type: ignore[assignment] awq_triton_import_exception = exc -from gptqmodel.nn_modules.qlinear.exllama_awq import AwqExllamaQuantLinear -from gptqmodel.nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear +from gptqmodel.nn_modules.qlinear.exllamav2_awq import AwqExllamaV2Linear from gptqmodel.utils.exllamav2 import ScratchSpace +from gptqmodel.utils.machete import _validate_machete_device_support, machete_runtime_error os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") @@ -46,11 +52,21 @@ DEVICE = torch.device("cuda:0") CPU_DEVICE = torch.device("cpu") +AWQ_MARLIN_FP16_ATOL = 0.006 +AWQ_MARLIN_BF16_ATOL = 0.02 + GREEN = "\033[32m" RED = "\033[31m" RESET = "\033[0m" +@dataclass +class ForwardResult: + outputs: List[torch.Tensor] + total_ms: float + mean_ms: float + + def _xpu_available() -> bool: return hasattr(torch, "xpu") and torch.xpu.is_available() @@ -60,19 +76,20 @@ class TestAwqKernelOutput(unittest.TestCase): TARGET = "model.layers.20.self_attn.v_proj" BITS = 4 GROUP_SIZE = 128 - SUPPORTED_DTYPES = (torch.float16,) + SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) baseline_backend = BACKEND.TORCH_AWQ backend_cases = [ (baseline_backend, torch.float16, 0.0), - # (baseline_backend, torch.bfloat16, 0.0), + (baseline_backend, torch.bfloat16, 0.0), (BACKEND.GEMM, torch.float16, 0.004), + (BACKEND.BITBLAS_AWQ, torch.float16, 0.004), # (BACKEND.GEMM, torch.bfloat16, 0.05), (BACKEND.TRITON, torch.float16, 0.004), - (BACKEND.MARLIN, torch.float16, 0.006), + (BACKEND.MACHETE, torch.float16, 0.006), + (BACKEND.MARLIN, torch.float16, AWQ_MARLIN_FP16_ATOL), (BACKEND.TORCH_FUSED_AWQ, torch.float16, 0.004), - # (BACKEND.MARLIN, torch.bfloat16, 0.05), - (BACKEND.EXLLAMA_V1, torch.float16, 0.006), + # (BACKEND.MARLIN, torch.bfloat16, AWQ_MARLIN_BF16_ATOL), (BACKEND.EXLLAMA_V2, torch.float16, 0.0068), ] @@ -85,10 +102,17 @@ def setUpClass(cls) -> None: cls.backend_skip_reason: Dict[BACKEND, str] = {} if not cls.cuda_available: cls.backend_skip_reason[BACKEND.GEMM] = "CUDA is required for GEMM backend." + cls.backend_skip_reason[BACKEND.BITBLAS_AWQ] = "CUDA is required for BitBLAS AWQ backend." cls.backend_skip_reason[BACKEND.TRITON] = "CUDA is required for AWQ Triton backend." + cls.backend_skip_reason[BACKEND.MACHETE] = "CUDA is required for AWQ Machete kernel." cls.backend_skip_reason[BACKEND.MARLIN] = "CUDA is required for AWQ Marlin kernel." - cls.backend_skip_reason[BACKEND.EXLLAMA_V1] = "CUDA is required for ExLlama v1 AWQ kernel." cls.backend_skip_reason[BACKEND.EXLLAMA_V2] = "CUDA is required for ExLlama v2 AWQ kernel." + elif not _validate_machete_device_support(): + cls.backend_skip_reason[BACKEND.MACHETE] = machete_runtime_error() + elif os.getenv("RUN_BITBLAS_TESTS", "0") != "1": + cls.backend_skip_reason[BACKEND.BITBLAS_AWQ] = "BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)." + elif not BITBLAS_AVAILABLE: + cls.backend_skip_reason[BACKEND.BITBLAS_AWQ] = BITBLAS_INSTALL_HINT if awq_triton_import_exception is not None: cls.backend_skip_reason[BACKEND.TRITON] = ( f"AWQ Triton kernel unavailable: {awq_triton_import_exception}" @@ -125,6 +149,26 @@ def setUpClass(cls) -> None: else None ) + try: + cls.modules[BACKEND.BITBLAS_AWQ] = ( + cls._build_bitblas_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) + if BACKEND.BITBLAS_AWQ not in cls.backend_skip_reason + else None + ) + except Exception as exc: + cls.backend_skip_reason[BACKEND.BITBLAS_AWQ] = f"AWQ BitBLAS kernel unavailable: {exc}" + cls.modules[BACKEND.BITBLAS_AWQ] = None + + try: + cls.modules[BACKEND.MACHETE] = ( + cls._build_machete_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) + if BACKEND.MACHETE not in cls.backend_skip_reason + else None + ) + except Exception as exc: + cls.backend_skip_reason[BACKEND.MACHETE] = f"AWQ Machete kernel unavailable: {exc}" + cls.modules[BACKEND.MACHETE] = None + try: cls.modules[BACKEND.TRITON] = ( cls._build_gemm_triton_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) @@ -141,12 +185,6 @@ def setUpClass(cls) -> None: else None ) - cls.modules[BACKEND.EXLLAMA_V1] = ( - cls._build_exllama_v1_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) - if cls.cuda_available - else None - ) - cls.modules[BACKEND.EXLLAMA_V2] = ( cls._build_exllama_v2_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) if cls.cuda_available @@ -165,7 +203,7 @@ def setUpClass(cls) -> None: base_inputs = cls._generate_inputs() cls.inputs: Dict[torch.dtype, List[torch.Tensor]] = {} - cls.reference_outputs: Dict[torch.dtype, List[torch.Tensor]] = {} + cls.reference_results: Dict[torch.dtype, ForwardResult] = {} for dtype in cls.SUPPORTED_DTYPES: converted_inputs = [ @@ -177,16 +215,9 @@ def setUpClass(cls) -> None: if torch_module is None: raise unittest.SkipTest("Torch AWQ kernel unavailable for baseline.") - forward_kwargs = {} - if dtype == torch.bfloat16: - forward_kwargs = { - "compute_dtype": torch.float16, - "output_dtype": dtype, - } - cls.reference_outputs[dtype] = cls._forward( + cls.reference_results[dtype] = cls._forward( torch_module, converted_inputs, - **forward_kwargs, ) @classmethod @@ -200,6 +231,8 @@ def tearDownClass(cls) -> None: @classmethod def _load_weight_map(cls) -> Dict[str, str]: index_path = cls.MODEL_PATH / "model.safetensors.index.json" + if not index_path.is_file(): + raise unittest.SkipTest(f"AWQ checkpoint not available at {index_path}") with open(index_path, "r") as handle: index = json.load(handle) return index["weight_map"] @@ -227,8 +260,8 @@ def _build_gemm_module( qzeros_cpu: torch.Tensor, scales_cpu: torch.Tensor, bias_cpu: torch.Tensor, - ) -> AwqGEMMQuantLinear: - module = AwqGEMMQuantLinear( + ) -> AwqGEMMLinear: + module = AwqGEMMLinear( bits=cls.BITS, group_size=cls.GROUP_SIZE, sym=True, @@ -249,6 +282,35 @@ def _build_gemm_module( module.post_init() return module + @classmethod + def _build_bitblas_module( + cls, + qweight_cpu: torch.Tensor, + qzeros_cpu: torch.Tensor, + scales_cpu: torch.Tensor, + bias_cpu: torch.Tensor, + ) -> AWQBitBlasKernel: + module = AWQBitBlasKernel( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + ).to(cls.device) + + source_module = SimpleNamespace( + qweight=qweight_cpu.to(cls.device), + qzeros=qzeros_cpu.to(cls.device), + scales=scales_cpu.to(torch.float16).to(cls.device), + bias=bias_cpu.to(torch.float16).to(cls.device), + ) + module.repack_from_awq(source_module) + module.eval() + return module + @classmethod def _build_gemm_triton_module( cls, @@ -256,10 +318,10 @@ def _build_gemm_triton_module( qzeros_cpu: torch.Tensor, scales_cpu: torch.Tensor, bias_cpu: torch.Tensor, - ) -> AwqGEMMTritonQuantLinear: - if AwqGEMMTritonQuantLinear is None: + ) -> AwqGEMMTritonLinear: + if AwqGEMMTritonLinear is None: raise RuntimeError("AWQ Triton kernel not available.") - module = AwqGEMMTritonQuantLinear( + module = AwqGEMMTritonLinear( bits=cls.BITS, group_size=cls.GROUP_SIZE, sym=True, @@ -287,7 +349,9 @@ def _build_marlin_module( qzeros_cpu: torch.Tensor, scales_cpu: torch.Tensor, bias_cpu: torch.Tensor, - ) -> Optional[AwqMarlinQuantLinear]: + *, + dtype: torch.dtype = torch.float16, + ) -> Optional[AwqMarlinLinear]: if marlin_import_exception is not None: cls.backend_skip_reason[BACKEND.MARLIN] = f"AWQ Marlin kernel unavailable: {marlin_import_exception}" return None @@ -300,39 +364,40 @@ def _build_marlin_module( cls.backend_skip_reason[BACKEND.MARLIN] = f"Unable to allocate Marlin workspace: {exc}" return None - module = AwqMarlinQuantLinear( + module = AwqMarlinLinear( bits=cls.BITS, group_size=cls.GROUP_SIZE, - sym=True, + sym=False, desc_act=False, in_features=cls.in_features, out_features=cls.out_features, bias=True, + dtype=dtype, adapter=None, register_buffers=True, ).to(cls.device) module.qweight.data.copy_(qweight_cpu.to(cls.device)) module.qzeros.data.copy_(qzeros_cpu.to(cls.device)) - module.scales.data.copy_(scales_cpu.to(torch.float16).to(cls.device)) - module.bias.data.copy_(bias_cpu.to(torch.float16).to(cls.device)) + module.scales.data.copy_(scales_cpu.to(dtype).to(cls.device)) + module.bias.data.copy_(bias_cpu.to(dtype).to(cls.device)) module.eval() module.post_init() return module @classmethod - def _build_torch_awq_module( + def _build_machete_module( cls, qweight_cpu: torch.Tensor, qzeros_cpu: torch.Tensor, scales_cpu: torch.Tensor, bias_cpu: torch.Tensor, - ) -> AwqTorchQuantLinear: - module = AwqTorchQuantLinear( + ) -> Optional[AwqMacheteLinear]: + module = AwqMacheteLinear( bits=cls.BITS, group_size=cls.GROUP_SIZE, - sym=True, + sym=False, desc_act=False, in_features=cls.in_features, out_features=cls.out_features, @@ -341,49 +406,43 @@ def _build_torch_awq_module( register_buffers=True, ).to(cls.device) - module.qweight.copy_(qweight_cpu.to(cls.device)) - module.qzeros.copy_(qzeros_cpu.to(cls.device)) - module.scales.copy_(scales_cpu.to(cls.device)) - module.bias.copy_(bias_cpu.to(cls.device)) + module.qweight.data.copy_(qweight_cpu.to(cls.device)) + module.qzeros.data.copy_(qzeros_cpu.to(cls.device)) + module.scales.data.copy_(scales_cpu.to(torch.float16).to(cls.device)) + module.bias.data.copy_(bias_cpu.to(torch.float16).to(cls.device)) module.eval() module.post_init() return module @classmethod - def _build_exllama_v1_module( + def _build_torch_awq_module( cls, qweight_cpu: torch.Tensor, qzeros_cpu: torch.Tensor, scales_cpu: torch.Tensor, bias_cpu: torch.Tensor, - ) -> Optional[AwqExllamaQuantLinear]: - try: - module = AwqExllamaQuantLinear( - bits=cls.BITS, - group_size=cls.GROUP_SIZE, - sym=True, - desc_act=False, - in_features=cls.in_features, - out_features=cls.out_features, - bias=True, - adapter=None, - register_buffers=True, - ).to(cls.device) + ) -> AwqTorchLinear: + module = AwqTorchLinear( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + register_buffers=True, + ).to(cls.device) - module.qweight.copy_(qweight_cpu.to(cls.device)) - module.qzeros.copy_(qzeros_cpu.to(cls.device)) - module.scales.copy_(scales_cpu.to(torch.float16).to(cls.device)) - module.bias.copy_(bias_cpu.to(torch.float16).to(cls.device)) + module.qweight.copy_(qweight_cpu.to(cls.device)) + module.qzeros.copy_(qzeros_cpu.to(cls.device)) + module.scales.copy_(scales_cpu.to(cls.device)) + module.bias.copy_(bias_cpu.to(cls.device)) - module.eval() - module.post_init() - return module - except Exception as exc: - cls.backend_skip_reason[BACKEND.EXLLAMA_V1] = ( - f"ExLlama v1 AWQ kernel unavailable: {exc}" - ) - return None + module.eval() + module.post_init() + return module @classmethod def _build_exllama_v2_module( @@ -392,9 +451,9 @@ def _build_exllama_v2_module( qzeros_cpu: torch.Tensor, scales_cpu: torch.Tensor, bias_cpu: torch.Tensor, - ) -> Optional[AwqExllamaV2QuantLinear]: + ) -> Optional[AwqExllamaV2Linear]: try: - module = AwqExllamaV2QuantLinear( + module = AwqExllamaV2Linear( bits=cls.BITS, group_size=cls.GROUP_SIZE, sym=True, @@ -431,8 +490,8 @@ def _build_torch_fused_awq_module( bias_cpu: torch.Tensor, *, device: torch.device = CPU_DEVICE, - ) -> TorchFusedAwqQuantLinear: - module = TorchFusedAwqQuantLinear( + ) -> TorchFusedAwqLinear: + module = TorchFusedAwqLinear( bits=cls.BITS, group_size=cls.GROUP_SIZE, sym=True, @@ -497,22 +556,46 @@ def _forward( compute_dtype: Optional[torch.dtype] = None, output_dtype: Optional[torch.dtype] = None, target_device: Optional[torch.device] = None, - ) -> List[torch.Tensor]: + ) -> ForwardResult: if target_device is None: target_device = cls._infer_module_device(module) + prepared_inputs = list(inputs) outputs: List[torch.Tensor] = [] + total_s = 0.0 with torch.inference_mode(): - for tensor in inputs: + if prepared_inputs: + warmup_tensor = prepared_inputs[0] + if warmup_tensor.device != target_device: + warmup_tensor = warmup_tensor.to(device=target_device) + if compute_dtype is not None and warmup_tensor.dtype != compute_dtype: + warmup_tensor = warmup_tensor.to(dtype=compute_dtype) + cls._synchronize(target_device) + module(warmup_tensor) + cls._synchronize(target_device) + for tensor in prepared_inputs: local_tensor = tensor if local_tensor.device != target_device: local_tensor = local_tensor.to(device=target_device) if compute_dtype is not None and local_tensor.dtype != compute_dtype: local_tensor = local_tensor.to(dtype=compute_dtype) + cls._synchronize(target_device) + started = time.perf_counter() result = module(local_tensor) + cls._synchronize(target_device) + total_s += time.perf_counter() - started if output_dtype is not None and result.dtype != output_dtype: result = result.to(dtype=output_dtype) outputs.append(result.detach().cpu()) - return outputs + total_ms = total_s * 1000.0 + mean_ms = total_ms / len(outputs) if outputs else 0.0 + return ForwardResult(outputs=outputs, total_ms=total_ms, mean_ms=mean_ms) + + @staticmethod + def _synchronize(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + elif device.type == "xpu" and _xpu_available(): + torch.xpu.synchronize() @staticmethod def _infer_module_device(module: torch.nn.Module) -> torch.device: @@ -542,6 +625,8 @@ def _summarize_results( title: str, reference_label: str, device: Optional[torch.device] = None, + reference_mean_ms: float = 0.0, + actual_mean_ms: float = 0.0, ) -> None: failures = [] total = len(actual_outputs) @@ -567,6 +652,7 @@ def _summarize_results( status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}" avg_abs_diff = mean_abs_diff / total if total else 0.0 + speedup = reference_mean_ms / actual_mean_ms if actual_mean_ms else 0.0 details = "\n\n".join(str(detail) for detail in failures) if failures else "-" device_label = str(device) if device is not None else "-" @@ -577,6 +663,8 @@ def _summarize_results( str(dtype), device_label, total, + f"{actual_mean_ms:.4f}", + f"{speedup:.2f}x", f"{max_abs_diff:.6f}", f"{avg_abs_diff:.6f}", status, @@ -589,6 +677,8 @@ def _summarize_results( "DType", "Device", "Samples", + "MeanLatencyMs", + "SpeedupVsRef", "MaxAbsDiff", "MeanAbsDiff", "Status", @@ -613,21 +703,60 @@ def test_awq_kernel_outputs(self, backend: BACKEND, dtype: torch.dtype, atol: fl self.skipTest(f"Backend `{backend}` module unavailable.") inputs = self.inputs[dtype] - reference_outputs = self.reference_outputs[dtype] + reference_result = self.reference_results[dtype] if backend == self.baseline_backend: - actual_outputs = reference_outputs + actual_result = reference_result else: - actual_outputs = self._forward(module, inputs) + actual_result = self._forward(module, inputs) self._summarize_results( - reference_outputs=reference_outputs, - actual_outputs=actual_outputs, + reference_outputs=reference_result.outputs, + actual_outputs=actual_result.outputs, backend=backend, dtype=dtype, atol=atol, title=f"AWQ Kernel Output {dtype}", reference_label="Torch AWQ output", + reference_mean_ms=reference_result.mean_ms, + actual_mean_ms=actual_result.mean_ms, ) + def test_awq_marlin_bfloat16_outputs(self) -> None: + self._maybe_skip_backend(BACKEND.MARLIN) + + if not self.cuda_available: + self.skipTest("CUDA is required for AWQ Marlin kernel.") + if not torch.cuda.is_bf16_supported(): + self.skipTest("CUDA bfloat16 not supported on this device.") + + module = self._build_marlin_module( + self.qweight_cpu, + self.qzeros_cpu, + self.scales_cpu, + self.bias_cpu, + dtype=torch.bfloat16, + ) + if module is None: + self.skipTest("AWQ Marlin bf16 module unavailable.") + + try: + reference_result = self.reference_results[torch.bfloat16] + actual_result = self._forward(module, self.inputs[torch.bfloat16]) + self._summarize_results( + reference_outputs=reference_result.outputs, + actual_outputs=actual_result.outputs, + backend=BACKEND.MARLIN, + dtype=torch.bfloat16, + atol=AWQ_MARLIN_BF16_ATOL, + title="AWQ Kernel Output torch.bfloat16", + reference_label="Torch AWQ output", + reference_mean_ms=reference_result.mean_ms, + actual_mean_ms=actual_result.mean_ms, + ) + finally: + del module + if torch.cuda.is_available(): + torch.cuda.empty_cache() + @parameterized.expand( [ ("cpu", "cpu"), @@ -649,20 +778,22 @@ def test_torch_fused_awq_devices(self, _label: str, device_str: str) -> None: ) try: - actual_outputs = self._forward( + actual_result = self._forward( module, self.inputs[torch.float16], target_device=device, ) self._summarize_results( - reference_outputs=self.reference_outputs[torch.float16], - actual_outputs=actual_outputs, + reference_outputs=self.reference_results[torch.float16].outputs, + actual_outputs=actual_result.outputs, backend=BACKEND.TORCH_FUSED_AWQ, dtype=torch.float16, atol=0.004, title=f"Torch Fused AWQ Device {device_str}", reference_label="Torch AWQ output", device=device, + reference_mean_ms=self.reference_results[torch.float16].mean_ms, + actual_mean_ms=actual_result.mean_ms, ) finally: del module diff --git a/tests/kernels/test_awq_cpu_fused_post_init.py b/tests/kernels/test_awq_cpu_fused_post_init.py index f400501ca..15d1fb5cd 100644 --- a/tests/kernels/test_awq_cpu_fused_post_init.py +++ b/tests/kernels/test_awq_cpu_fused_post_init.py @@ -7,9 +7,9 @@ import torch from gptqmodel.adapter.adapter import Lora -from gptqmodel.nn_modules.qlinear.gemm_hf_kernel_awq import HFKernelAwqLinear -from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear -from gptqmodel.nn_modules.qlinear.torch_int8_awq import TorchInt8AwqQuantLinear +from gptqmodel.nn_modules.qlinear.torch_aten_kernel_awq import TorchAtenAwqLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqLinear +from gptqmodel.nn_modules.qlinear.torch_int8_awq import TorchInt8AwqLinear from gptqmodel.quantization import FORMAT, METHOD from gptqmodel.utils.backend import BACKEND from gptqmodel.utils.importer import select_quant_linear @@ -18,15 +18,15 @@ @pytest.mark.parametrize( "kernel_cls", [ - pytest.param(TorchFusedAwqQuantLinear, id="torch_fused_awq"), - pytest.param(HFKernelAwqLinear, id="hf_kernel_awq"), - pytest.param(TorchInt8AwqQuantLinear, id="torch_int8_awq"), + pytest.param(TorchFusedAwqLinear, id="torch_fused_awq"), + pytest.param(TorchAtenAwqLinear, id="torch_aten_awq"), + pytest.param(TorchInt8AwqLinear, id="torch_int8_awq"), ], ) def test_awq_fused_post_init_calls_adapter(monkeypatch, kernel_cls): - if kernel_cls is HFKernelAwqLinear: + if kernel_cls is TorchAtenAwqLinear: monkeypatch.setattr( - HFKernelAwqLinear, + TorchAtenAwqLinear, "cached_validate_once", classmethod(lambda cls: (True, None)), ) @@ -61,9 +61,9 @@ def spy_post_init(self, weight_key, device, **kwargs): assert getattr(module, "wf_unsqueeze_neg_one", None) is None -def test_hf_kernel_awq_backend_selection(monkeypatch): +def test_torch_aten_awq_backend_selection(monkeypatch): monkeypatch.setattr( - HFKernelAwqLinear, + TorchAtenAwqLinear, "cached_validate_once", classmethod(lambda cls: (True, None)), ) @@ -74,9 +74,9 @@ def test_hf_kernel_awq_backend_selection(monkeypatch): desc_act=False, sym=True, device=None, - backend=BACKEND.HF_KERNEL_AWQ, + backend=BACKEND.AWQ_TORCH_ATEN, format=FORMAT.GEMM, quant_method=METHOD.AWQ, pack_dtype=torch.int32, ) - assert qlinear_cls is HFKernelAwqLinear + assert qlinear_cls is TorchAtenAwqLinear diff --git a/tests/kernels/test_awq_cuda_fp32_reduce.py b/tests/kernels/test_awq_cuda_fp32_reduce.py new file mode 100644 index 000000000..4e25d4bd2 --- /dev/null +++ b/tests/kernels/test_awq_cuda_fp32_reduce.py @@ -0,0 +1,203 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm +from gptqmodel.utils.awq import awq_gemm_forward, awq_runtime_available +from gptqmodel.utils.paroquant import apply_paroquant_rotation, build_identity_rotation_buffers + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + + assert unpacked.shape[1] % pack_factor == 0 + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_packed_buffers(bits: int, in_features: int, out_features: int, group_size: int): + groups = in_features // group_size + int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, out_features), dtype=torch.int32) + scales = (torch.rand(groups, out_features, dtype=torch.float16) * 0.5) + 0.75 + + return ( + _pack_awq_tensor(int_weight, bits), + _pack_awq_tensor(zero_points, bits), + scales, + ) + + +def _dense_reference(x: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, bits: int, group_size: int): + dense_weight = dequantize_gemm( + qweight=qweight, + qzeros=qzeros, + scales=scales, + bits=bits, + group_size=group_size, + ).to(device=x.device, dtype=x.dtype) + return torch.matmul(x, dense_weight) + + +def _require_dtype_support(dtype: torch.dtype) -> None: + if dtype != torch.bfloat16: + return + major, _minor = torch.cuda.get_device_capability() + if major < 8: + pytest.skip("BFloat16 AWQ CUDA kernels require compute capability >= 8.0.") + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for AWQ CUDA fp32-reduce test") +def test_awq_cuda_fp32_reduce_reduces_dense_error(): + if not awq_runtime_available(): + pytest.skip("AWQ CUDA fp32-reduce extension entrypoint unavailable.") + + torch.manual_seed(0) + bits = 4 + batch = 1 + seq = 128 + in_features = 1024 + out_features = 1024 + group_size = 128 + qweight, qzeros, scales = _make_packed_buffers(bits, in_features, out_features, group_size) + + x = torch.randn(batch * seq, in_features, device="cuda", dtype=torch.float16) + qweight = qweight.cuda() + qzeros = qzeros.cuda() + scales = scales.cuda() + + reference = _dense_reference(x, qweight, qzeros, scales, bits=bits, group_size=group_size) + + with torch.inference_mode(): + legacy = awq_gemm_forward(x, qweight, scales, qzeros, 8, False) + candidate = awq_gemm_forward(x, qweight, scales, qzeros, 8, True) + + legacy_abs = (legacy - reference).abs() + candidate_abs = (candidate - reference).abs() + + assert candidate_abs.max().item() <= legacy_abs.max().item() + assert candidate_abs.mean().item() < legacy_abs.mean().item() * 0.1 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for ParoQuant CUDA fp32-reduce test") +def test_paroquant_cuda_fp32_reduce_reduces_dense_error(): + if not awq_runtime_available(): + pytest.skip("AWQ CUDA fp32-reduce extension entrypoint unavailable.") + + torch.manual_seed(0) + bits = 4 + batch = 1 + seq = 128 + in_features = 1024 + out_features = 1024 + group_size = 128 + krot = 8 + qweight, qzeros, scales = _make_packed_buffers(bits, in_features, out_features, group_size) + + x = torch.randn(batch * seq, in_features, device="cuda", dtype=torch.float16) + qweight = qweight.cuda() + qzeros = qzeros.cuda() + scales = scales.cuda() + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=krot, + device="cuda", + dtype=torch.float16, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + rotated = apply_paroquant_rotation(x, pairs, theta, scales=channel_scales, group_size=group_size) + + reference = _dense_reference(rotated, qweight, qzeros, scales, bits=bits, group_size=group_size) + + with torch.inference_mode(): + legacy = awq_gemm_forward(rotated, qweight, scales, qzeros, 8, False) + candidate = awq_gemm_forward(rotated, qweight, scales, qzeros, 8, True) + + legacy_abs = (legacy - reference).abs() + candidate_abs = (candidate - reference).abs() + + assert candidate_abs.max().item() <= legacy_abs.max().item() + assert candidate_abs.mean().item() < legacy_abs.mean().item() * 0.1 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for AWQ fused split-K reduction test") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_awq_cuda_fused_splitk_reduce_matches_default(dtype, monkeypatch): + if not awq_runtime_available(): + pytest.skip("AWQ CUDA extension entrypoint unavailable.") + + _require_dtype_support(dtype) + torch.manual_seed(0) + bits = 4 + in_features = 1024 + out_features = 1024 + group_size = 128 + qweight, qzeros, scales = _make_packed_buffers(bits, in_features, out_features, group_size) + + x = torch.randn(128, in_features, device="cuda", dtype=dtype) + qweight = qweight.cuda() + qzeros = qzeros.cuda() + scales = scales.cuda().to(dtype=dtype) + + monkeypatch.setenv("GPTQMODEL_AWQ_DISABLE_FUSED_SPLITK_REDUCE", "1") + with torch.inference_mode(): + baseline = awq_gemm_forward(x, qweight, scales, qzeros, 4, True) + + monkeypatch.delenv("GPTQMODEL_AWQ_DISABLE_FUSED_SPLITK_REDUCE", raising=False) + with torch.inference_mode(): + candidate = awq_gemm_forward(x, qweight, scales, qzeros, 4, True) + + assert torch.equal(candidate, baseline) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for ParoQuant fused split-K reduction test") +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_paroquant_fused_splitk_reduce_matches_default(dtype, monkeypatch): + if not awq_runtime_available(): + pytest.skip("AWQ CUDA extension entrypoint unavailable.") + + _require_dtype_support(dtype) + torch.manual_seed(0) + bits = 4 + in_features = 1024 + out_features = 1024 + group_size = 128 + krot = 8 + qweight, qzeros, scales = _make_packed_buffers(bits, in_features, out_features, group_size) + + x = torch.randn(128, in_features, device="cuda", dtype=dtype) + qweight = qweight.cuda() + qzeros = qzeros.cuda() + scales = scales.cuda().to(dtype=dtype) + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=krot, + device="cuda", + dtype=dtype, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + rotated = apply_paroquant_rotation(x, pairs, theta, scales=channel_scales, group_size=group_size) + + monkeypatch.setenv("GPTQMODEL_AWQ_DISABLE_FUSED_SPLITK_REDUCE", "1") + with torch.inference_mode(): + baseline = awq_gemm_forward(rotated, qweight, scales, qzeros, 4, True) + + monkeypatch.delenv("GPTQMODEL_AWQ_DISABLE_FUSED_SPLITK_REDUCE", raising=False) + with torch.inference_mode(): + candidate = awq_gemm_forward(rotated, qweight, scales, qzeros, 4, True) + + assert torch.equal(candidate, baseline) diff --git a/tests/kernels/test_awq_machete_marlin.py b/tests/kernels/test_awq_machete_marlin.py new file mode 100644 index 000000000..f7a36ccde --- /dev/null +++ b/tests/kernels/test_awq_machete_marlin.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +import torch + +from gptqmodel.nn_modules.qlinear.machete_awq import AwqMacheteLinear +from gptqmodel.nn_modules.qlinear.marlin_awq import AwqMarlinLinear, marlin_import_exception +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear +from gptqmodel.utils.machete import _validate_machete_device_support, machete_runtime_error +from gptqmodel.utils.marlin import marlin_runtime_available, marlin_runtime_error + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + + assert unpacked.shape[1] % pack_factor == 0 + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for lane, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (lane * bits) + return packed + + +def _mock_awq_module_tensors( + *, + bits: int, + group_size: int, + in_features: int, + out_features: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + groups = in_features // group_size + maxq = (1 << bits) - 1 + + float_weight = torch.randn(in_features, out_features, dtype=torch.float32) * 0.2 + weight_groups = float_weight.view(groups, group_size, out_features) + + w_min = weight_groups.amin(dim=1) + w_max = weight_groups.amax(dim=1) + scales = ((w_max - w_min).clamp_min(1e-6) / maxq).to(torch.float16) + zero_points = torch.round((-w_min / scales.to(torch.float32))).clamp_(0, maxq).to(torch.int32) + quantized = torch.round( + weight_groups / scales.to(torch.float32).unsqueeze(1) + zero_points.unsqueeze(1) + ).clamp_(0, maxq).to(torch.int32) + + qweight = _pack_awq_tensor(quantized.view(in_features, out_features), bits) + qzeros = _pack_awq_tensor(zero_points, bits) + bias = torch.randn(out_features, dtype=torch.float16) + + return qweight, qzeros, scales, bias + + +def _build_awq_module( + module_cls, + *, + device: torch.device, + bits: int, + group_size: int, + in_features: int, + out_features: int, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + bias: torch.Tensor, +): + module = module_cls( + bits=bits, + group_size=group_size, + sym=False, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ).to(device) + + with torch.no_grad(): + module.qweight.copy_(qweight.to(device)) + module.qzeros.copy_(qzeros.to(device)) + module.scales.copy_(scales.to(torch.float16).to(device)) + module.bias.copy_(bias.to(torch.float16).to(device)) + + module.post_init() + module.eval() + return module + + +def _assert_awq_candidate_matches_torch( + module_cls, + *, + device: torch.device, + bits: int, + group_size: int, + in_features: int, + out_features: int, + atol: float, + rtol: float, +) -> None: + torch.manual_seed(11) + qweight, qzeros, scales, bias = _mock_awq_module_tensors( + bits=bits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + ) + + baseline = _build_awq_module( + AwqTorchLinear, + device=device, + bits=bits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias, + ) + candidate = _build_awq_module( + module_cls, + device=device, + bits=bits, + group_size=group_size, + in_features=in_features, + out_features=out_features, + qweight=qweight, + qzeros=qzeros, + scales=scales, + bias=bias, + ) + + x = torch.randn((32, in_features), device=device, dtype=torch.float16) + with torch.inference_mode(): + expected = baseline(x) + actual = candidate(x) + repeat = candidate(x) + torch.cuda.synchronize(device) + + assert candidate.qzeros.numel() > 0 + torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol) + torch.testing.assert_close(repeat, expected, atol=atol, rtol=rtol) + + +@pytest.mark.cuda +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_awq_marlin_cuda_zero_points_match_torch_awq(): + if marlin_import_exception is not None: + pytest.skip(f"AWQ Marlin kernel unavailable: {marlin_import_exception}") + if not marlin_runtime_available(torch.float16): + pytest.skip(marlin_runtime_error(torch.float16)) + + _assert_awq_candidate_matches_torch( + AwqMarlinLinear, + device=torch.device("cuda:0"), + bits=4, + group_size=64, + in_features=256, + out_features=128, + atol=8e-3, + rtol=8e-3, + ) + + +@pytest.mark.cuda +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_awq_machete_cuda_zero_points_match_torch_awq(): + if not _validate_machete_device_support(): + pytest.skip(machete_runtime_error()) + + _assert_awq_candidate_matches_torch( + AwqMacheteLinear, + device=torch.device("cuda:0"), + bits=4, + group_size=64, + in_features=128, + out_features=128, + atol=1e-2, + rtol=1e-2, + ) diff --git a/tests/kernels/test_awq_torch.py b/tests/kernels/test_awq_torch.py index ba387612c..7023c64b4 100644 --- a/tests/kernels/test_awq_torch.py +++ b/tests/kernels/test_awq_torch.py @@ -6,7 +6,7 @@ import pytest import torch -from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear from gptqmodel.quantization import FORMAT, METHOD from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm from gptqmodel.utils.backend import BACKEND @@ -29,10 +29,10 @@ def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: return packed -@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) def test_awq_torch_matches_manual_dequant(dtype): - if dtype not in AwqTorchQuantLinear.SUPPORTS_DTYPES: - pytest.skip(f"dtype {dtype} not supported by AwqTorchQuantLinear") + if dtype not in AwqTorchLinear.SUPPORTS_DTYPES: + pytest.skip(f"dtype {dtype} not supported by AwqTorchLinear") torch.manual_seed(0) bits = 4 @@ -48,13 +48,13 @@ def test_awq_torch_matches_manual_dequant(dtype): int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) zero_points = torch.randint(0, 2**bits, size=(groups, pack_cols), dtype=torch.int32) - scales = (torch.rand(groups, pack_cols, dtype=torch.float16) * 2.0) + 0.25 - bias = torch.randn(out_features, dtype=torch.float16) + scales = ((torch.rand(groups, pack_cols, dtype=torch.float32) * 2.0) + 0.25).to(dtype) + bias = torch.randn(out_features, dtype=dtype) qweight = _pack_awq_tensor(int_weight, bits) qzeros = _pack_awq_tensor(zero_points, bits) - module = AwqTorchQuantLinear( + module = AwqTorchLinear( bits=bits, group_size=group_size, sym=True, @@ -62,22 +62,24 @@ def test_awq_torch_matches_manual_dequant(dtype): in_features=in_features, out_features=out_features, bias=True, + dtype=dtype, register_buffers=True, ) module.qweight.copy_(qweight) module.qzeros.copy_(qzeros) - module.scales = module.scales.to(dtype=torch.float16) - module.scales.copy_(scales.to(torch.float16)) - module.bias.copy_(bias) + module.scales.copy_(scales.to(module.scales.dtype)) + module.bias.copy_(bias.to(module.bias.dtype)) module.post_init() module.eval() batch = 4 x = torch.randn(batch, in_features, dtype=dtype) - bias_expected = module.bias + output_first = module(x) + output_second = module(x) + bias_expected = module.bias.to(dtype=dtype) dequant_weight = dequantize_gemm( qweight=module.qweight, qzeros=module.qzeros, @@ -85,12 +87,11 @@ def test_awq_torch_matches_manual_dequant(dtype): bits=bits, group_size=group_size, ).to(dtype=dtype) - expected = torch.matmul(x.to(dtype), dequant_weight) expected = expected + bias_expected - output_first = module(x) - output_second = module(x) + assert output_first.dtype == dtype + assert output_second.dtype == dtype atol = 1e-4 if dtype == torch.float32 else 5e-3 rtol = 1e-4 if dtype == torch.float32 else 5e-3 @@ -110,4 +111,4 @@ def test_awq_torch_backend_selection(): quant_method=METHOD.AWQ, pack_dtype=torch.int32, ) - assert qlinear_cls is AwqTorchQuantLinear + assert qlinear_cls is AwqTorchLinear diff --git a/tests/kernels/test_awq_torch_fused.py b/tests/kernels/test_awq_torch_fused.py index d3644e314..a47820b4e 100644 --- a/tests/kernels/test_awq_torch_fused.py +++ b/tests/kernels/test_awq_torch_fused.py @@ -13,8 +13,8 @@ from safetensors import safe_open from tabulate import tabulate -from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear -from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqLinear from gptqmodel.utils.torch import TORCH_HAS_FUSED_OPS @@ -117,7 +117,7 @@ def test_torch_fused_awq_matches_checkpoint_module(device_str: str): device = torch.device(device_str) - awq_module = AwqTorchQuantLinear( + awq_module = AwqTorchLinear( bits=bits, group_size=group_size, sym=True, @@ -127,7 +127,7 @@ def test_torch_fused_awq_matches_checkpoint_module(device_str: str): bias=bias is not None, register_buffers=True, ) - fused_module = TorchFusedAwqQuantLinear( + fused_module = TorchFusedAwqLinear( bits=bits, group_size=group_size, sym=True, diff --git a/tests/kernels/test_awq_triton_accum.py b/tests/kernels/test_awq_triton_accum.py new file mode 100644 index 000000000..29054524e --- /dev/null +++ b/tests/kernels/test_awq_triton_accum.py @@ -0,0 +1,141 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +from gptqmodel.nn_modules.qlinear.gemm_awq_triton import AwqGEMMTritonLinear +from gptqmodel.quantization.awq.modules.triton.gemm import awq_gemm_triton +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + + assert unpacked.shape[1] % pack_factor == 0 + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_packed_buffers(bits: int, in_features: int, out_features: int, group_size: int): + groups = in_features // group_size + int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, out_features), dtype=torch.int32) + scales = (torch.rand(groups, out_features, dtype=torch.float16) * 0.5) + 0.75 + bias = torch.randn(out_features, dtype=torch.float16) + + return ( + _pack_awq_tensor(int_weight, bits), + _pack_awq_tensor(zero_points, bits), + scales, + bias, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for AWQ Triton kernel parity test") +def test_awq_triton_fp32_accum_matches_manual_dequant(): + pytest.importorskip("triton") + torch.manual_seed(0) + + bits = 4 + in_features = 512 + out_features = 512 + group_size = 128 + qweight, qzeros, scales, bias = _make_packed_buffers(bits, in_features, out_features, group_size) + + module = AwqGEMMTritonLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ).cuda() + module.qweight.copy_(qweight.cuda()) + module.qzeros.copy_(qzeros.cuda()) + module.scales.copy_(scales.cuda()) + module.bias.copy_(bias.cuda()) + module.post_init() + module.eval() + + x = torch.randn(1, 128, in_features, device="cuda", dtype=torch.float16) + dequant_weight = dequantize_gemm( + qweight=module.qweight, + qzeros=module.qzeros, + scales=module.scales, + bits=bits, + group_size=group_size, + ).to(device=x.device, dtype=x.dtype) + expected = torch.matmul(x.reshape(-1, in_features), dequant_weight).reshape(1, 128, out_features) + expected = expected + module.bias + + with torch.inference_mode(): + actual = module(x) + + abs_diff = (actual - expected).abs() + assert abs_diff.max().item() <= 1.0 + assert abs_diff.mean().item() <= 0.02 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for AWQ Triton accumulation test") +def test_awq_triton_fp32_accum_reduces_dense_error(): + pytest.importorskip("triton") + torch.manual_seed(0) + + bits = 4 + batch = 1 + seq = 128 + in_features = 1024 + out_features = 1024 + group_size = 128 + qweight, qzeros, scales, _bias = _make_packed_buffers(bits, in_features, out_features, group_size) + + x = torch.randn(batch * seq, in_features, device="cuda", dtype=torch.float16) + qweight = qweight.cuda() + qzeros = qzeros.cuda() + scales = scales.cuda() + + dense_weight = dequantize_gemm( + qweight=qweight, + qzeros=qzeros, + scales=scales, + bits=bits, + group_size=group_size, + ).to(device=x.device, dtype=x.dtype) + reference = torch.matmul(x, dense_weight) + + with torch.inference_mode(): + legacy = awq_gemm_triton( + x, + qweight, + scales, + qzeros, + split_k_iters=8, + fp32_accum=False, + output_dtype=x.dtype, + ) + candidate = awq_gemm_triton( + x, + qweight, + scales, + qzeros, + split_k_iters=8, + fp32_accum=True, + output_dtype=x.dtype, + ) + + legacy_abs = (legacy - reference).abs() + candidate_abs = (candidate - reference).abs() + + assert candidate_abs.max().item() < legacy_abs.max().item() + assert candidate_abs.mean().item() < legacy_abs.mean().item() diff --git a/tests/kernels/test_base_autotune.py b/tests/kernels/test_base_autotune.py new file mode 100644 index 000000000..11bdb246f --- /dev/null +++ b/tests/kernels/test_base_autotune.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from gptqmodel.models._const import DEVICE, PLATFORM +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from gptqmodel.quantization import METHOD +from gptqmodel.utils.backend import BACKEND + + +class _AutotuneTestKernel(BaseQuantLinear): + SUPPORTS_BACKENDS = [BACKEND.TORCH] + SUPPORTS_METHODS = [METHOD.GPTQ] + SUPPORTS_FORMATS = {} + SUPPORTS_BITS = [4] + SUPPORTS_GROUP_SIZE = [4] + SUPPORTS_DESC_ACT = [False] + SUPPORTS_SYM = [True] + SUPPORTS_SHARDS = False + SUPPORTS_TRAINING = True + SUPPORTS_TRAINING_USE_TORCH_KERNEL = False + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [] + SUPPORTS_DEVICES = [DEVICE.ALL] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_DTYPES = [torch.float16, torch.float32, torch.bfloat16] + + def __init__(self, *, autotune: bool): + self.AUTOTUNE = autotune + self.autotune_calls = 0 + super().__init__( + bits=4, + in_features=4, + out_features=4, + bias=False, + backend=BACKEND.TORCH, + adapter=None, + register_buffers=False, + validate_kwargs={ + "group_size": 4, + "desc_act": False, + "sym": True, + "pack_dtype": torch.int32, + }, + ) + + def _autotune(self, x: torch.Tensor): + self.autotune_calls += 1 + return (x.shape, x.dtype) + + def forward(self, x: torch.Tensor): + self.maybe_autotune(x) + return x + + +def test_base_quant_linear_autotune_runs_once_per_instance(): + module = _AutotuneTestKernel(autotune=True).eval() + x = torch.randn(2, 4) + + module(x) + module(x) + + assert module.autotune_calls == 1 + assert module.get_autotune_result() == (x.shape, x.dtype) + + +def test_base_quant_linear_autotune_disabled_by_default(): + module = _AutotuneTestKernel(autotune=False).eval() + x = torch.randn(2, 4) + + module(x) + module(x) + + assert module.autotune_calls == 0 + assert module.get_autotune_result() is None + + +def test_base_quant_linear_clear_autotune_reenables_autotune(): + module = _AutotuneTestKernel(autotune=True).eval() + x = torch.randn(2, 4) + + module(x) + assert module.autotune_calls == 1 + + module.clear_autotune() + module(x) + + assert module.autotune_calls == 2 + + +def test_base_quant_linear_train_mode_clears_autotune_state(): + module = _AutotuneTestKernel(autotune=True).eval() + x = torch.randn(2, 4) + + module(x) + module.train(True) + module.train(False) + module(x) + + assert module.autotune_calls == 2 diff --git a/tests/kernels/test_exllamav3_kernel.py b/tests/kernels/test_exllamav3_kernel.py new file mode 100644 index 000000000..b1ce8f6d8 --- /dev/null +++ b/tests/kernels/test_exllamav3_kernel.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +import torch + +from gptqmodel.nn_modules.exllamav3 import ExllamaV3Linear +from gptqmodel.nn_modules.exllamav3_torch import ExllamaV3TorchLinear + + +def _get_quantize_exl3(): + if not torch.cuda.is_available(): + pytest.skip("EXL3 kernel verification requires CUDA/HIP.") + + try: + from gptqmodel.exllamav3.modules.quant.exl3_lib.quantize import quantize_exl3 + except Exception as exc: # pragma: no cover - environment dependent + pytest.skip(f"EXL3 quantizer unavailable: {exc}") + + return quantize_exl3 + + +def _clone_tensors(tensors: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + return { + name: tensor.clone() if isinstance(tensor, torch.Tensor) else tensor + for name, tensor in tensors.items() + } + + +def _build_kernels(bits: int, codebook: str) -> tuple[torch.Tensor, ExllamaV3TorchLinear, ExllamaV3Linear]: + quantize_exl3 = _get_quantize_exl3() + + torch.manual_seed(17) + in_features = { + "3inst": 128, + "mcg": 256, + "mul1": 384, + }[codebook] + out_features = 128 + weight = torch.randn(in_features, out_features, device="cuda", dtype=torch.float32) + quant_device = weight.device + hessian = torch.eye(in_features, device="cuda", dtype=torch.float32) + bias = torch.randn(out_features, device="cuda", dtype=torch.float16) + + quant_args: dict[str, object] = { + "K": bits, + "devices": [quant_device], + "apply_out_scales": True, + "sigma_reg": 0.025, + "seed": 787, + } + if codebook == "mcg": + quant_args["mcg"] = True + elif codebook == "mul1": + quant_args["mul1"] = True + + weight_q, _, out_tensors = quantize_exl3( + weight=weight, + H_data={"H": hessian, "count": in_features, "finalized": False}, + quant_args=quant_args, + return_weight_q=True, + ) + + base_tensors = dict(out_tensors) + base_tensors["bias"] = bias + + torch_kernel = ExllamaV3TorchLinear.from_tensors( + in_features=in_features, + out_features=out_features, + name=f"kernel_{codebook}_{bits}", + tensors=_clone_tensors(base_tensors), + ).eval() + + cuda_kernel = ExllamaV3Linear.from_tensors( + in_features=in_features, + out_features=out_features, + name=f"kernel_{codebook}_{bits}", + tensors=_clone_tensors(base_tensors), + ).eval() + + try: + cuda_kernel.post_init() + except Exception as exc: # pragma: no cover - environment dependent + pytest.skip(f"EXL3 CUDA runtime unavailable: {exc}") + + return weight_q, torch_kernel, cuda_kernel + + +@pytest.mark.parametrize( + ("bits", "codebook"), + [ + (2, "3inst"), + (2, "mcg"), + (4, "mul1"), + ], +) +def test_exllamav3_torch_weight_matches_quantized_reference(bits: int, codebook: str): + weight_q, torch_kernel, _ = _build_kernels(bits, codebook) + + with torch.inference_mode(): + dense_weight = torch_kernel.get_weight_tensor(dtype=torch.float32) + + assert dense_weight.dtype == torch.float32 + assert dense_weight.shape == weight_q.shape + assert torch.allclose(dense_weight, weight_q, rtol=3e-3, atol=3e-3) + + +@pytest.mark.parametrize( + ("bits", "codebook"), + [ + (2, "3inst"), + (2, "mcg"), + (4, "mul1"), + ], +) +def test_exllamav3_cuda_small_batch_matches_torch_reference(bits: int, codebook: str): + _, torch_kernel, cuda_kernel = _build_kernels(bits, codebook) + x = torch.randn(9, torch_kernel.in_features, device="cuda", dtype=torch.float16) + + with torch.inference_mode(): + torch_out = torch_kernel(x) + cuda_out = cuda_kernel(x) + + assert torch_out.dtype == x.dtype + assert cuda_out.dtype == x.dtype + assert cuda_out.shape == torch_out.shape + assert torch.allclose(cuda_out, torch_out, rtol=8e-2, atol=8e-2) + + +@pytest.mark.parametrize( + ("bits", "codebook"), + [ + (2, "3inst"), + (2, "mcg"), + (4, "mul1"), + ], +) +def test_exllamav3_cuda_large_batch_matches_torch_reference(bits: int, codebook: str): + _, torch_kernel, cuda_kernel = _build_kernels(bits, codebook) + x = torch.randn(40, torch_kernel.in_features, device="cuda", dtype=torch.float16) + + with torch.inference_mode(): + torch_out = torch_kernel(x) + cuda_out = cuda_kernel(x) + + assert torch_out.dtype == x.dtype + assert cuda_out.dtype == x.dtype + assert cuda_out.shape == torch_out.shape + assert torch.allclose(cuda_out, torch_out, rtol=8e-2, atol=8e-2) diff --git a/tests/kernels/test_exllamav3_kernel_map_packed.py b/tests/kernels/test_exllamav3_kernel_map_packed.py new file mode 100644 index 000000000..d65a758e2 --- /dev/null +++ b/tests/kernels/test_exllamav3_kernel_map_packed.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import subprocess +import sys +from pathlib import Path + +import pytest + + +REPO_ROOT = Path(__file__).resolve().parents[2] +VALIDATOR = REPO_ROOT / "scripts/generate_exl3_kernel_map_packed.py" +PACKED_HEADER = REPO_ROOT / "gptqmodel_ext/exllamav3/quant/exl3_kernel_map_packed.cuh" + + +@pytest.mark.timeout(120) +def test_exllamav3_kernel_map_matches_exllamav3_original_legacy_header(tmp_path: Path): + generated = tmp_path / "exl3_kernel_map_packed.cuh" + result = subprocess.run( + [sys.executable, str(VALIDATOR), "--output", str(generated)], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=False, + ) + + assert result.returncode == 0, result.stderr or result.stdout + assert "validated legacy lookup from https://raw.githubusercontent.com/" in result.stdout + assert generated.read_text() == PACKED_HEADER.read_text() diff --git a/tests/kernels/test_fallback.py b/tests/kernels/test_fallback.py new file mode 100644 index 000000000..e9805beb1 --- /dev/null +++ b/tests/kernels/test_fallback.py @@ -0,0 +1,463 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import os + +import torch +from logbar import LogBar +from transformers import AutoModelForCausalLM + +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.quantization.config import ( + Fallback, + FallbackStrategy, + QuantizeConfig, + SmoothLog, + SmoothMAD, + SmoothMSE, + SmoothOutlier, + SmoothPercentile, + SmoothPercentileAsymmetric, + SmoothRowCol, + SmoothSoftNorm, +) +from gptqmodel.quantization.gptq import GPTQ +from gptqmodel.utils.model import convert_gptq_v1_to_v2_format_module + + +MODEL_DIR = "/monster/data/model/llama3-8B" #"/monster/data/model/Llama-3.2-1B-Instruct" # + +log = LogBar.shared() + +DEVICE = torch.device("cuda:0") +ATOL_CHECKS = [0.001, 0.005, 0.01, 0.05, 0.1] + + +def _load_down_proj(dtype: torch.dtype, device: torch.device) -> torch.nn.Module: + model = AutoModelForCausalLM.from_pretrained( + MODEL_DIR, + dtype=dtype, + device_map="cpu", + low_cpu_mem_usage=True, + ) + model.eval() + down_proj = model.model.layers[0].mlp.down_proj + down_proj = down_proj.to(device=device, dtype=dtype) + del model + if device.type == "cuda": + torch.cuda.empty_cache() + return down_proj + + +def _quantize_to_torch_linear( + layer: torch.nn.Module, + fallback: Fallback, + device: torch.device, +) -> TorchLinear: + qcfg = QuantizeConfig( + bits=4, + group_size=128, + sym=False, + desc_act=False, + fallback=fallback, + ) + + gptq = GPTQ(layer, qcfg) + gptq.quantizer.configure(perchannel=True) + gptq.fallback = qcfg.fallback + + wq, scales, zeros, g_idx, *_ = gptq.quantize(blocksize=128) + + packed_linear = torch.nn.Linear( + layer.in_features, + layer.out_features, + bias=layer.bias is not None, + device="cpu", + dtype=layer.weight.dtype, + ) + packed_linear.weight.data = wq.detach().to(device="cpu", dtype=layer.weight.dtype) + if layer.bias is not None: + packed_linear.bias.data = layer.bias.detach().to(device="cpu", dtype=layer.bias.dtype) + + qlinear = TorchLinear( + bits=qcfg.bits, + group_size=qcfg.group_size, + sym=qcfg.sym, + desc_act=qcfg.desc_act, + in_features=layer.in_features, + out_features=layer.out_features, + bias=layer.bias is not None, + pack_dtype=qcfg.pack_dtype, + adapter=None, + ) + qlinear.pack( + linear=packed_linear, + scales=scales.to(device="cpu"), + zeros=zeros.to(device="cpu"), + g_idx=g_idx.to(device="cpu"), + ) + convert_gptq_v1_to_v2_format_module( + module=qlinear, + bits=qcfg.bits, + pack_dtype=qcfg.pack_dtype, + ) + qlinear = qlinear.to(device=device) + qlinear.post_init() + qlinear.eval() + return qlinear + + +def _clone_linear(layer: torch.nn.Module, device: torch.device) -> torch.nn.Module: + cloned = torch.nn.Linear( + layer.in_features, + layer.out_features, + bias=layer.bias is not None, + device=device, + dtype=layer.weight.dtype, + ) + cloned.weight.data.copy_(layer.weight.data) + if layer.bias is not None: + cloned.bias.data.copy_(layer.bias.data) + return cloned + + +def _init_stats(): + return { + "sum": 0.0, + "count": 0, + "max": None, + "min": None, + "passes": dict.fromkeys(ATOL_CHECKS, 0), + } + + +def _update_stats(stats, diff: torch.Tensor): + diff_sum = diff.sum().item() + diff_max = diff.max().item() + diff_min = diff.min().item() + for atol in ATOL_CHECKS: + stats["passes"][atol] += torch.count_nonzero(diff <= atol).item() + stats["sum"] += diff_sum + stats["count"] += diff.numel() + stats["max"] = diff_max if stats["max"] is None else max(stats["max"], diff_max) + stats["min"] = diff_min if stats["min"] is None else min(stats["min"], diff_min) + + +def _finalize_stats(stats): + mean = stats["sum"] / max(stats["count"], 1) + pass_rates = { + atol: stats["passes"][atol] / max(stats["count"], 1) + for atol in ATOL_CHECKS + } + return mean, stats["max"] or 0.0, stats["min"] or 0.0, pass_rates + + +def _parse_shapes(expr: str): + shapes = [] + for part in expr.split(","): + part = part.strip() + if not part: + continue + dim_str, samples_str = part.split(":", 1) + shapes.append((int(dim_str), int(samples_str))) + return shapes + + +def _select_shapes(): + large_shapes = [(1, 256), (16, 128), (32, 64), (64, 32), (128, 16)] + medium_shapes = [(1, 128), (16, 64), (32, 32), (64, 16)] + small_shapes = [(1, 64), (8, 32), (16, 16)] + + env_shapes = os.getenv("GPTQMODEL_KERNEL_TEST_SHAPES") + if env_shapes: + return _parse_shapes(env_shapes) + + total_mem_gb = 0.0 + if torch.cuda.is_available(): + device_index = DEVICE.index if DEVICE.index is not None else 0 + try: + if torch.cuda.device_count() > device_index: + props = torch.cuda.get_device_properties(device_index) + total_mem_gb = props.total_memory / (1024 ** 3) + except Exception: + total_mem_gb = 0.0 + + if os.getenv("GPTQMODEL_FAST_TESTS", "0") == "1": + return small_shapes + if total_mem_gb >= 80: + return large_shapes + if total_mem_gb >= 48: + return medium_shapes + return small_shapes + + +def test_kernel_output_fallback(): + if not os.path.isdir(MODEL_DIR): + import pytest + + pytest.skip(f"Model path missing: {MODEL_DIR}") + + if not torch.cuda.is_available(): + import pytest + + pytest.skip("CUDA required for fallback kernel output test") + + torch.manual_seed(0) + + device = DEVICE + dtype = torch.float16 + down_proj = _load_down_proj(dtype=dtype, device=device) + assert down_proj.weight.device.type == "cuda" + + shapes = _select_shapes() + variants = [ + ("rtn", Fallback(strategy=FallbackStrategy.RTN, threshold=True)), + ("midpoint", Fallback(strategy=FallbackStrategy.MIDPOINT, threshold=True)), + ("mean", Fallback(strategy=FallbackStrategy.MEAN, threshold=True)), + ("median", Fallback(strategy=FallbackStrategy.MEDIAN, threshold=True)), + ("stdclip", Fallback(strategy=FallbackStrategy.STDCLIP, threshold=True)), + ("rtn_p99", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothPercentile(percentile=99.0), + )), + ("rtn_asym_p", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothPercentileAsymmetric(low=0.5, high=99.5), + )), + ("rtn_mad", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothMAD(k=3.0), + )), + ("rtn_mse", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothMSE(steps=32, maxshrink=0.8), + )), + ("rtn_outlier", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothOutlier(pct=1.0), + )), + ("rtn_softnorm", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothSoftNorm(k=3.0), + )), + ("rtn_log", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothLog(percentile=99.0, mu=8.0), + )), + ("rtn_rowcol", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothRowCol(axis="row"), + )), + ("median_p99", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothPercentile(percentile=99.0), + )), + ("median_asym_p", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothPercentileAsymmetric(low=0.5, high=99.5), + )), + ("median_mad", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothMAD(k=3.0), + )), + ("median_mse", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothMSE(steps=32, maxshrink=0.8), + )), + ("median_outlier", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothOutlier(pct=1.0), + )), + ("median_softnorm", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothSoftNorm(k=3.0), + )), + ("median_log", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothLog(percentile=99.0, mu=8.0), + )), + ("median_rowcol", Fallback( + strategy=FallbackStrategy.MEDIAN, + threshold=True, + smooth=SmoothRowCol(axis="row"), + )), + ] + qlinears = { + label: _quantize_to_torch_linear(_clone_linear(down_proj, device=device), fallback, device=device) + for label, fallback in variants + } + for label, qlinear in qlinears.items(): + assert qlinear.list_buffers()[0].device.type == "cuda", f"{label} buffers not on CUDA" + + total_samples = sum(samples for _, samples in shapes) + stats = {label: _init_stats() for label, _ in variants} + with torch.inference_mode(): + for _ in log.pb(total_samples).title("Forward Pass on Random Input"): + for dim_0, samples in shapes: + for _ in range(samples): + x = torch.randn( + (dim_0, down_proj.in_features), + device=device, + dtype=dtype, + ) + assert x.device.type == "cuda" + baseline = down_proj(x) + variant_out = {label: qlinears[label](x) for label, _ in variants} + assert baseline.device.type == "cuda" + for label, out in variant_out.items(): + assert out.device.type == "cuda" + diff = torch.abs(baseline - out).float() + _update_stats(stats[label], diff) + + finalized = {} + for label, _ in variants: + mean_val, max_val, min_val, pass_rates = _finalize_stats(stats[label]) + finalized[label] = { + "mean": mean_val, + "max": max_val, + "min": min_val, + "pass": pass_rates, + "atol": max_val, + } + + pass_cols = [{"label": f"pass@{atol:g}", "width": "fit"} for atol in ATOL_CHECKS] + cols = log.columns( + cols=[ + {"label": "variant", "width": "fit"}, + {"label": "mean_diff", "width": "fit"}, + {"label": "max_diff", "width": "fit"}, + {"label": "min_diff", "width": "fit"}, + {"label": "atol_req", "width": "fit"}, + ] + pass_cols, + padding=1, + ) + cols.info.header() + for label, _ in variants: + metrics = finalized[label] + cols.info( + label, + f"{metrics['mean']:.6f}", + f"{metrics['max']:.6f}", + f"{metrics['min']:.6f}", + f"{metrics['atol']:.6f}", + *[f"{metrics['pass'][atol]:.4f}" for atol in ATOL_CHECKS], + ) + cols.info.header() + + for label, _ in variants: + metrics = finalized[label] + assert torch.isfinite(torch.tensor([metrics["mean"], metrics["max"], metrics["min"]])).all() + + +def test_kernel_output_fallback_mad_sweep(): + if not os.path.isdir(MODEL_DIR): + import pytest + + pytest.skip(f"Model path missing: {MODEL_DIR}") + + if not torch.cuda.is_available(): + import pytest + + pytest.skip("CUDA required for fallback kernel output test") + + torch.manual_seed(0) + + device = DEVICE + dtype = torch.float16 + down_proj = _load_down_proj(dtype=dtype, device=device) + assert down_proj.weight.device.type == "cuda" + + k_values = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75, 3.0, 3.5, 4.0, 4.5, 5.0, 6.0] + variants = [ + (f"rtn_mad_k{k:g}", Fallback( + strategy=FallbackStrategy.RTN, + threshold=True, + smooth=SmoothMAD(k=k), + )) + for k in k_values + ] + + shapes = _select_shapes() + qlinears = { + label: _quantize_to_torch_linear(_clone_linear(down_proj, device=device), fallback, device=device) + for label, fallback in variants + } + for label, qlinear in qlinears.items(): + assert qlinear.list_buffers()[0].device.type == "cuda", f"{label} buffers not on CUDA" + + total_samples = sum(samples for _, samples in shapes) + stats = {label: _init_stats() for label, _ in variants} + with torch.inference_mode(): + for _ in log.pb(total_samples).title("Forward Pass on Random Input (MAD Sweep)"): + for dim_0, samples in shapes: + for _ in range(samples): + x = torch.randn( + (dim_0, down_proj.in_features), + device=device, + dtype=dtype, + ) + assert x.device.type == "cuda" + baseline = down_proj(x) + variant_out = {label: qlinears[label](x) for label, _ in variants} + assert baseline.device.type == "cuda" + + for label, out in variant_out.items(): + assert out.device.type == "cuda" + diff = torch.abs(baseline - out).float() + _update_stats(stats[label], diff) + + finalized = {} + for label, _ in variants: + mean_val, max_val, min_val, pass_rates = _finalize_stats(stats[label]) + finalized[label] = { + "mean": mean_val, + "max": max_val, + "min": min_val, + "pass": pass_rates, + "atol": max_val, + } + + pass_cols = [{"label": f"pass@{atol:g}", "width": "fit"} for atol in ATOL_CHECKS] + cols = log.columns( + cols=[ + {"label": "variant", "width": "fit"}, + {"label": "mean_diff", "width": "fit"}, + {"label": "max_diff", "width": "fit"}, + {"label": "min_diff", "width": "fit"}, + {"label": "atol_req", "width": "fit"}, + ] + pass_cols, + padding=1, + ) + cols.info.header() + for label, _ in variants: + metrics = finalized[label] + cols.info( + label, + f"{metrics['mean']:.6f}", + f"{metrics['max']:.6f}", + f"{metrics['min']:.6f}", + f"{metrics['atol']:.6f}", + *[f"{metrics['pass'][atol]:.4f}" for atol in ATOL_CHECKS], + ) + cols.info.header() + + for label, _ in variants: + metrics = finalized[label] + assert torch.isfinite(torch.tensor([metrics["mean"], metrics["max"], metrics["min"]])).all() diff --git a/tests/kernels/test_fp8_kernel.py b/tests/kernels/test_fp8_kernel.py new file mode 100644 index 000000000..54cc6311d --- /dev/null +++ b/tests/kernels/test_fp8_kernel.py @@ -0,0 +1,157 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.nn_modules.qlinear.fp8 import TorchFP8Linear, quantize_fp8_weight +from gptqmodel.quantization.dtype import available_float8_dtype_names + + +def _available_fp8_quant_formats(): + return [fmt for fmt in available_float8_dtype_names() if fmt != "float8_e8m0fnu"] + + +@pytest.mark.parametrize("fmt", _available_fp8_quant_formats()) +@pytest.mark.parametrize("weight_scale_method", ["tensor", "row", "block"]) +def test_fp8_pack_matches_reference_quantizer(fmt: str, weight_scale_method: str): + torch.manual_seed(0) + linear = nn.Linear(64, 64, bias=True).eval() + block_size = (16, 16) if weight_scale_method == "block" else None + + kernel = TorchFP8Linear( + bits=8, + group_size=-1, + desc_act=False, + sym=True, + in_features=64, + out_features=64, + bias=True, + pack_dtype=torch.int32, + format=fmt, + weight_scale_method=weight_scale_method, + weight_block_size=block_size, + register_buffers=False, + ) + kernel.pack_original(linear=linear, scales=None, zeros=None) + + expected_weight, expected_scale_inv = quantize_fp8_weight( + linear.weight.detach().to(torch.float32), + format=fmt, + weight_scale_method=weight_scale_method, + weight_block_size=block_size, + ) + + assert torch.equal(kernel.weight, expected_weight) + assert torch.equal(kernel.weight_scale_inv, expected_scale_inv) + + +@pytest.mark.skipif(not hasattr(torch, "float8_e8m0fnu"), reason="float8_e8m0fnu unavailable") +@pytest.mark.parametrize("weight_scale_method", ["tensor", "row", "block"]) +def test_fp8_pack_rejects_e8m0fnu(weight_scale_method: str): + linear = nn.Linear(64, 64, bias=True).eval() + block_size = (16, 16) if weight_scale_method == "block" else None + + kernel = TorchFP8Linear( + bits=8, + group_size=-1, + desc_act=False, + sym=True, + in_features=64, + out_features=64, + bias=True, + pack_dtype=torch.int32, + format="float8_e8m0fnu", + weight_scale_method=weight_scale_method, + weight_block_size=block_size, + register_buffers=False, + ) + + with pytest.raises(ValueError, match="dequantization of existing checkpoints"): + kernel.pack_original(linear=linear, scales=None, zeros=None) + + +@pytest.mark.parametrize("fmt", _available_fp8_quant_formats()) +@pytest.mark.parametrize("weight_scale_method", ["tensor", "row", "block"]) +@pytest.mark.parametrize("device", ["cpu", pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"))]) +def test_fp8_forward_matches_dequantized_reference(fmt: str, weight_scale_method: str, device: str): + torch.manual_seed(1) + linear = nn.Linear(64, 64, bias=True).eval() + block_size = (16, 16) if weight_scale_method == "block" else None + + kernel = TorchFP8Linear( + bits=8, + group_size=-1, + desc_act=False, + sym=True, + in_features=64, + out_features=64, + bias=True, + pack_dtype=torch.int32, + format=fmt, + weight_scale_method=weight_scale_method, + weight_block_size=block_size, + register_buffers=False, + ) + kernel.pack_original(linear=linear, scales=None, zeros=None) + kernel = kernel.to(device=device).eval() + kernel._scaled_mm_hard_disabled = True + + x_dtype = torch.float32 if device == "cpu" else torch.float16 + x = torch.randn(5, 64, device=device, dtype=x_dtype) + + with torch.inference_mode(): + out = kernel(x) + expected = torch.matmul( + x, + kernel.dequantize_weight(device=device, dtype=x_dtype), + ) + if kernel.bias is not None: + expected = expected + kernel.bias.to(device=device, dtype=x_dtype) + + torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("fmt", _available_fp8_quant_formats()) +@pytest.mark.parametrize("weight_scale_method", ["tensor", "row"]) +def test_fp8_scaled_mm_matches_dense_reference(fmt: str, weight_scale_method: str): + torch.manual_seed(2) + linear = nn.Linear(64, 64, bias=False).eval() + + kernel = TorchFP8Linear( + bits=8, + group_size=-1, + desc_act=False, + sym=True, + in_features=64, + out_features=64, + bias=False, + pack_dtype=torch.int32, + format=fmt, + weight_scale_method=weight_scale_method, + register_buffers=False, + ) + kernel.pack_original(linear=linear, scales=None, zeros=None) + kernel = kernel.to(device="cuda").eval() + + x = torch.randn(7, 64, device="cuda", dtype=torch.float16) + if not kernel._can_use_scaled_mm(x): + pytest.skip("scaled_mm is not available for this environment.") + + try: + with torch.inference_mode(): + scaled_mm_out = kernel._forward_scaled_mm(x) + except Exception as exc: + pytest.skip(f"scaled_mm path unavailable: {exc}") + + with torch.inference_mode(): + x_q, scale_a = kernel._quantize_input_for_scaled_mm(x) + x_q_dequant = x_q.to(torch.float16) * scale_a.to(torch.float16) + dense_out = torch.matmul( + x_q_dequant, + kernel.dequantize_weight(device="cuda", dtype=torch.float16), + ) + + torch.testing.assert_close(scaled_mm_out, dense_out, rtol=5e-2, atol=5e-2) diff --git a/tests/kernels/test_gguf_cpp.py b/tests/kernels/test_gguf_cpp.py new file mode 100644 index 000000000..f2169bb49 --- /dev/null +++ b/tests/kernels/test_gguf_cpp.py @@ -0,0 +1,200 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +import torch + +from gptqmodel.models._const import DEVICE +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.nn_modules.qlinear.gguf_cpp import GGUFCppKernel, GGUFCudaKernel, _get_ggml_bridge +from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.utils.backend import BACKEND +from gptqmodel.utils.importer import select_quant_linear + + +def _build_quant_modules( + bits: str, + *, + include_cuda: bool = False, +) -> tuple[GGUFTorchLinear, GGUFCppKernel, GGUFCudaKernel | None]: + torch.manual_seed(7) + linear = torch.nn.Linear(64, 48, bias=True, dtype=torch.float16).cpu().eval() + + torch_kernel = GGUFTorchLinear( + bits=bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=64, + out_features=48, + bias=True, + register_buffers=True, + ) + torch_kernel.pack_original(linear, scales=None, zeros=None) + + cpp_kernel = GGUFCppKernel( + bits=bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=64, + out_features=48, + bias=True, + register_buffers=True, + ) + cpp_kernel.load_state_dict(torch_kernel.state_dict(), strict=True) + + cuda_kernel = None + if include_cuda: + cuda_kernel = GGUFCudaKernel( + bits=bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=64, + out_features=48, + bias=True, + register_buffers=True, + ) + cuda_kernel.load_state_dict(torch_kernel.state_dict(), strict=True) + cuda_kernel = cuda_kernel.eval() + return torch_kernel.eval(), cpp_kernel.eval(), cuda_kernel + + +def test_gguf_cpp_kernel_validate_once_uses_llama_cpp(): + GGUFCppKernel.cached_validate_once.cache_clear() + ok, err = GGUFCppKernel.cached_validate_once() + if not ok: + pytest.skip(f"llama-cpp-python unavailable: {err}") + assert ok, err + assert _get_ggml_bridge() is not None + + +@pytest.mark.parametrize("bits", ["q4_k_s", "q4_k_m", "q5_k_m", "q6_k"]) +def test_gguf_cpp_kernel_forward_matches_torch_kernel(bits: str): + GGUFCppKernel.cached_validate_once.cache_clear() + ok, err = GGUFCppKernel.cached_validate_once() + if not ok: + pytest.skip(f"llama-cpp-python unavailable: {err}") + + torch_kernel, cpp_kernel, _ = _build_quant_modules(bits) + x = torch.randn(9, 64, dtype=torch.float32) + + with torch.inference_mode(): + torch_out = torch_kernel(x) + cpp_out = cpp_kernel(x) + + assert cpp_out.dtype == x.dtype + assert cpp_out.shape == torch_out.shape + assert torch.allclose(cpp_out, torch_out, rtol=3e-2, atol=3e-2) + + +@pytest.mark.parametrize("bits", ["q4_k_s", "q4_k_m", "q5_k_m", "q6_k"]) +def test_gguf_cuda_kernel_forward_matches_torch_kernel(bits: str): + GGUFCudaKernel.cached_validate_once.cache_clear() + ok, err = GGUFCudaKernel.cached_validate_once() + if not ok: + pytest.skip(f"llama-cpp-python CUDA unavailable: {err}") + + torch_kernel, _, cuda_kernel = _build_quant_modules(bits, include_cuda=True) + torch_kernel = torch_kernel.to(device="cuda") + assert cuda_kernel is not None + cuda_kernel = cuda_kernel.to(device="cuda") + x = torch.randn(9, 64, dtype=torch.float16, device="cuda") + + with torch.inference_mode(): + torch_out = torch_kernel(x) + cuda_out = cuda_kernel(x) + + assert cuda_out.dtype == x.dtype + assert cuda_out.device.type == "cuda" + assert cuda_out.shape == torch_out.shape + assert torch.allclose(cuda_out, torch_out, rtol=8e-2, atol=8e-2) + + +def test_gguf_cuda_kernel_reuses_cached_plan(): + GGUFCudaKernel.cached_validate_once.cache_clear() + ok, err = GGUFCudaKernel.cached_validate_once() + if not ok: + pytest.skip(f"llama-cpp-python CUDA unavailable: {err}") + + _, _, cuda_kernel = _build_quant_modules("q4_k_m", include_cuda=True) + assert cuda_kernel is not None + cuda_kernel = cuda_kernel.to(device="cuda") + x = torch.randn(9, 64, dtype=torch.float16, device="cuda") + + assert cuda_kernel._ggml_cuda_plans == {} + with torch.inference_mode(): + first = cuda_kernel(x) + assert len(cuda_kernel._ggml_cuda_plans) == 1 + first_plan = next(iter(cuda_kernel._ggml_cuda_plans.values())) + second = cuda_kernel(x) + second_plan = next(iter(cuda_kernel._ggml_cuda_plans.values())) + + assert first_plan is second_plan + assert torch.allclose(first, second, rtol=0, atol=0) + + +def test_gguf_cuda_kernel_fp32_preserves_output_dtype(): + GGUFCudaKernel.cached_validate_once.cache_clear() + ok, err = GGUFCudaKernel.cached_validate_once() + if not ok: + pytest.skip(f"llama-cpp-python CUDA unavailable: {err}") + + torch_kernel, _, cuda_kernel = _build_quant_modules("q4_k_m", include_cuda=True) + torch_kernel = torch_kernel.to(device="cuda") + assert cuda_kernel is not None + cuda_kernel = cuda_kernel.to(device="cuda") + x = torch.randn(9, 64, dtype=torch.float32, device="cuda") + + with torch.inference_mode(): + torch_out = torch_kernel(x) + cuda_out = cuda_kernel(x) + + assert cuda_out.dtype == torch.float32 + assert cuda_out.device.type == "cuda" + assert torch.allclose(cuda_out, torch_out, rtol=8e-2, atol=8e-2) + + +def test_gguf_cpp_kernel_explicit_backend_selection(): + GGUFCppKernel.cached_validate_once.cache_clear() + ok, err = GGUFCppKernel.cached_validate_once() + if not ok: + pytest.skip(f"llama-cpp-python unavailable: {err}") + + qlinear_cls = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CPU, + backend=BACKEND.GGUF_CPP_CPU, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is GGUFCppKernel + + +def test_gguf_cuda_kernel_explicit_backend_selection(): + GGUFCudaKernel.cached_validate_once.cache_clear() + ok, err = GGUFCudaKernel.cached_validate_once() + if not ok: + pytest.skip(f"llama-cpp-python CUDA unavailable: {err}") + + qlinear_cls = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CUDA, + backend=BACKEND.GGUF_CPP_CUDA, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is GGUFCudaKernel diff --git a/tests/kernels/test_gptq.py b/tests/kernels/test_gptq.py index 48a3681a6..12a771f56 100644 --- a/tests/kernels/test_gptq.py +++ b/tests/kernels/test_gptq.py @@ -4,7 +4,9 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import os +import time import unittest +from dataclasses import dataclass from typing import List, Tuple import torch @@ -14,15 +16,15 @@ from gptqmodel import BACKEND, GPTQModel from gptqmodel.adapter.adapter import Adapter, AdapterCache, Lora -from gptqmodel.nn_modules.qlinear.bitblas import BitblasQuantLinear -from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear -from gptqmodel.nn_modules.qlinear.machete import MacheteQuantLinear -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear +from gptqmodel.nn_modules.qlinear.bitblas import BitblasLinear +from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2Linear +from gptqmodel.nn_modules.qlinear.machete import MacheteLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear from gptqmodel.utils.machete import ( - _validate_machete_device_support, - machete_import_exception, + machete_runtime_available, + machete_runtime_error, ) from gptqmodel.utils.model import find_modules @@ -38,22 +40,44 @@ os.environ.setdefault("BITBLAS_ENABLE_TUNING", "0") os.environ.setdefault("BITBLAS_ENABLE_TENSORCORE", "0") + +def _bitblas_supports_gptq_case(dtype: torch.dtype) -> bool: + valid, _ = BitblasLinear.validate( + bits=4, + group_size=128, + desc_act=True, + sym=True, + in_features=3072, + out_features=1024, + pack_dtype=torch.int32, + dtype=dtype, + ) + return valid + + class Data: def __init__(self): self.m = 1 self.k = -1 self.x = [] # random X input of shape (m, k) + +@dataclass +class ForwardResult: + outputs: List[torch.Tensor] + total_ms: float + mean_ms: float + class TestKernelOutput(unittest.TestCase): # model_path = "sliuau/llama3.2-1b-4bit-group128" # hf "sliuau/llama3.2-1b-4bit-group128" model_path = "sliuau/Llama-3.2-3B_4bits_128group_size" target_qliner_map = { - BACKEND.TORCH: TorchQuantLinear, - BACKEND.MACHETE: MacheteQuantLinear, - BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, - BACKEND.TRITON: TritonV2QuantLinear, - BACKEND.BITBLAS: BitblasQuantLinear, - BACKEND.MARLIN: MarlinQuantLinear, + BACKEND.TORCH: TorchLinear, + BACKEND.MACHETE: MacheteLinear, + BACKEND.EXLLAMA_V2: ExllamaV2Linear, + BACKEND.TRITON: TritonV2Linear, + BACKEND.BITBLAS: BitblasLinear, + BACKEND.MARLIN: MarlinLinear, } target = 'model.layers.6.self_attn.v_proj' @@ -132,32 +156,69 @@ def _parse_shapes(expr: str) -> List[Tuple[int, int]]: AdapterCache.reset() # allow next load to complete since we are hacking to get consume only 1 lora module # TORCH as reference output - data.torch_kernel_out = cls.forward(cls, backend=BACKEND.TORCH, dtype=dtype) - data.torch_kernel_out_with_lora = cls.forward(cls, backend=BACKEND.TORCH, dtype=dtype, adapter=data.adapter) + data.torch_kernel = cls.forward(cls, backend=BACKEND.TORCH, dtype=dtype) + data.torch_kernel_out = data.torch_kernel.outputs + data.torch_kernel_with_lora = cls.forward(cls, backend=BACKEND.TORCH, dtype=dtype, adapter=data.adapter) + data.torch_kernel_out_with_lora = data.torch_kernel_with_lora.outputs + + @staticmethod + def _synchronize(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) - def forward(self, backend: BACKEND, dtype: torch.dtype, adapter: Adapter = None): + def forward(self, backend: BACKEND, dtype: torch.dtype, adapter: Adapter = None) -> ForwardResult: model = GPTQModel.load(self.model_path, backend=backend, adapter=adapter, dtype=dtype, device=DEVICE) target_qlinear_cls = self.target_qliner_map[backend] modules = find_modules(model.model, layers=[target_qlinear_cls]) result = [] + total_s = 0.0 for name, module in modules.items(): if name == self.target: data = self.data[dtype] + module_device = self._module_device(module) + if data.x: + warmup = data.x[0] + if module_device is not None and warmup.device != module_device: + warmup = warmup.to(module_device) + self._synchronize(DEVICE) + module(warmup) + self._synchronize(DEVICE) for i in log.pb(self.random_input_sample_size).title("Forward Pass on Random Input"): - assert data.x[i].dtype == dtype - result.append(module(data.x[i])) + sample = data.x[i] + assert sample.dtype == dtype + + # Direct layer calls bypass accelerate's cross-device routing, + # so align the synthetic input with the module's local device. + if module_device is not None and sample.device != module_device: + sample = sample.to(module_device) + + self._synchronize(DEVICE) + started = time.perf_counter() + result.append(module(sample).detach().to(device="cpu", dtype=torch.float32)) + self._synchronize(DEVICE) + total_s += time.perf_counter() - started break - assert result is not None - del module del model torch.cuda.empty_cache() - return result + total_ms = total_s * 1000.0 + mean_ms = total_ms / len(result) if result else 0.0 + return ForwardResult(outputs=result, total_ms=total_ms, mean_ms=mean_ms) + + @staticmethod + def _module_device(module): + for tensor in module.parameters(recurse=False): + if tensor is not None and not tensor.is_meta: + return tensor.device + for tensor in module.buffers(recurse=False): + if tensor is not None and not tensor.is_meta: + return tensor.device + return None def _summarize_results( self, @@ -168,15 +229,24 @@ def _summarize_results( atol: float, title: str, reference_label: str, + reference_mean_ms: float, + actual_mean_ms: float, ): failures = [] total = len(actual_outputs) + max_abs_diff = 0.0 + mean_abs_diff = 0.0 for i in range(total): reference = reference_outputs[i] actual = actual_outputs[i] + reference_fp32 = reference.to(torch.float32) + actual_fp32 = actual.to(torch.float32) + diff = torch.abs(reference_fp32 - actual_fp32) + max_abs_diff = max(max_abs_diff, float(diff.max().item())) + mean_abs_diff += float(diff.mean().item()) - is_close_tensor = torch.isclose(reference, actual, rtol=0.15, atol=atol) + is_close_tensor = torch.isclose(reference_fp32, actual_fp32, rtol=0.15, atol=atol) passed = bool(torch.all(is_close_tensor)) if not passed: @@ -184,12 +254,14 @@ def _summarize_results( "Sample {idx}:\nExpected ({ref_label}) = {expected}\nActual = {actual_val}".format( idx=i, ref_label=reference_label, - expected=reference.detach().cpu().tolist(), - actual_val=actual.detach().cpu().tolist(), + expected=reference_fp32.detach().cpu().tolist(), + actual_val=actual_fp32.detach().cpu().tolist(), ) ) status = f"{GREEN}PASS{RESET}" if not failures else f"{RED}FAIL{RESET}" + avg_abs_diff = mean_abs_diff / total if total else 0.0 + speedup = reference_mean_ms / actual_mean_ms if actual_mean_ms else 0.0 details = "\n\n".join(str(detail) for detail in failures) if failures else "-" table = tabulate( @@ -198,6 +270,10 @@ def _summarize_results( backend.name, str(dtype), total, + f"{actual_mean_ms:.4f}", + f"{speedup:.2f}x", + f"{max_abs_diff:.6f}", + f"{avg_abs_diff:.6f}", status, len(failures), details, @@ -207,6 +283,10 @@ def _summarize_results( "Backend", "DType", "Samples", + "MeanLatencyMs", + "SpeedupVsRef", + "MaxAbsDiff", + "MeanAbsDiff", "Status", "Failures", "Expected vs Actual", @@ -225,19 +305,20 @@ def _maybe_skip_backend(self, backend: BACKEND): self.skipTest("BitBLAS disabled (set RUN_BITBLAS_TESTS=1 to enable)") if backend == BACKEND.MACHETE: - if machete_import_exception is not None: - self.skipTest(f"Machete kernel unavailable: {machete_import_exception}") - if not _validate_machete_device_support(): - self.skipTest("Machete requires NVIDIA Hopper or newer (SM90+)") + if not machete_runtime_available(): + self.skipTest(f"Machete kernel unavailable: {machete_runtime_error()}") + # Updated CUDA kernel tolerances below were re-baselined from full + # torch-vs-kernel validation on H200. float16_cases = [ (BACKEND.TORCH, torch.float16, 0.0000), (BACKEND.TRITON, torch.float16, 0.00001), (BACKEND.EXLLAMA_V2, torch.float16, 0.0068), - (BACKEND.MACHETE, torch.float16, 0.00040), - (BACKEND.MARLIN, torch.float16, 0.00035), - (BACKEND.BITBLAS, torch.float16, 0.0035), + (BACKEND.MACHETE, torch.float16, 0.0010), + (BACKEND.MARLIN, torch.float16, 0.0010), ] + if _bitblas_supports_gptq_case(torch.float16): + float16_cases.append((BACKEND.BITBLAS, torch.float16, 0.0035)) @parameterized.expand(float16_cases) def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): @@ -247,23 +328,26 @@ def test_kernel_float16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance out = self.forward(backend=backend, dtype=dtype) self._summarize_results( - reference_outputs=data.torch_kernel_out, - actual_outputs=out, + reference_outputs=data.torch_kernel.outputs, + actual_outputs=out.outputs, backend=backend, dtype=dtype, atol=a_tolerance, title=f"Kernel Output {dtype}", reference_label="Torch output", + reference_mean_ms=data.torch_kernel.mean_ms, + actual_mean_ms=out.mean_ms, ) bfloat16_cases = [ (BACKEND.TORCH, torch.bfloat16, 0.0000), (BACKEND.TRITON, torch.bfloat16, 0.00001), - (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0054), - (BACKEND.MACHETE, torch.bfloat16, 0.0033), - (BACKEND.MARLIN, torch.bfloat16, 0.0031), - (BACKEND.BITBLAS, torch.bfloat16, 0.0031), + (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0080), + (BACKEND.MACHETE, torch.bfloat16, 0.0080), + (BACKEND.MARLIN, torch.bfloat16, 0.0080), ] + if _bitblas_supports_gptq_case(torch.bfloat16): + bfloat16_cases.append((BACKEND.BITBLAS, torch.bfloat16, 0.0031)) @parameterized.expand(bfloat16_cases) def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): @@ -273,23 +357,26 @@ def test_kernel_bfloat16(self, backend: BACKEND, dtype: torch.dtype, a_tolerance out = self.forward(backend=backend, dtype=dtype) self._summarize_results( - reference_outputs=data.torch_kernel_out, - actual_outputs=out, + reference_outputs=data.torch_kernel.outputs, + actual_outputs=out.outputs, backend=backend, dtype=dtype, atol=a_tolerance, title=f"Kernel Output {dtype}", reference_label="Torch output", + reference_mean_ms=data.torch_kernel.mean_ms, + actual_mean_ms=out.mean_ms, ) float16_lora_cases = [ (BACKEND.TORCH, torch.float16, 0.0000), (BACKEND.TRITON, torch.float16, 0.00001), (BACKEND.EXLLAMA_V2, torch.float16, 0.0065), - (BACKEND.MACHETE, torch.float16, 0.00040), - (BACKEND.MARLIN, torch.float16, 0.00035), - (BACKEND.BITBLAS, torch.float16, 0.00035), + (BACKEND.MACHETE, torch.float16, 0.0010), + (BACKEND.MARLIN, torch.float16, 0.0020), ] + if _bitblas_supports_gptq_case(torch.float16): + float16_lora_cases.append((BACKEND.BITBLAS, torch.float16, 0.00035)) @parameterized.expand(float16_lora_cases) def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): @@ -298,23 +385,26 @@ def test_kernel_float16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_ data = self.data[dtype] out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter) self._summarize_results( - reference_outputs=data.torch_kernel_out_with_lora, - actual_outputs=out, + reference_outputs=data.torch_kernel_with_lora.outputs, + actual_outputs=out.outputs, backend=backend, dtype=dtype, atol=a_tolerance, title=f"Kernel Output With Lora {dtype}", reference_label="Torch with Lora output", + reference_mean_ms=data.torch_kernel_with_lora.mean_ms, + actual_mean_ms=out.mean_ms, ) bfloat16_lora_cases = [ (BACKEND.TORCH, torch.bfloat16, 0.0000), (BACKEND.TRITON, torch.bfloat16, 0.00001), - (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0059), - (BACKEND.MACHETE, torch.bfloat16, 0.0033), - (BACKEND.MARLIN, torch.bfloat16, 0.0050), - (BACKEND.BITBLAS, torch.bfloat16, 0.0033), + (BACKEND.EXLLAMA_V2, torch.bfloat16, 0.0160), + (BACKEND.MACHETE, torch.bfloat16, 0.0080), + (BACKEND.MARLIN, torch.bfloat16, 0.0080), ] + if _bitblas_supports_gptq_case(torch.bfloat16): + bfloat16_lora_cases.append((BACKEND.BITBLAS, torch.bfloat16, 0.0033)) @parameterized.expand(bfloat16_lora_cases) def test_kernel_bfloat16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a_tolerance: float): @@ -323,11 +413,13 @@ def test_kernel_bfloat16_with_lora(self, backend: BACKEND, dtype: torch.dtype, a data = self.data[dtype] out = self.forward(backend=backend, dtype=dtype, adapter=data.adapter) self._summarize_results( - reference_outputs=data.torch_kernel_out_with_lora, - actual_outputs=out, + reference_outputs=data.torch_kernel_with_lora.outputs, + actual_outputs=out.outputs, backend=backend, dtype=dtype, atol=a_tolerance, title=f"Kernel Output With Lora {dtype}", reference_label="Torch with Lora output", + reference_mean_ms=data.torch_kernel_with_lora.mean_ms, + actual_mean_ms=out.mean_ms, ) diff --git a/tests/kernels/test_intel_cpu_xpu.py b/tests/kernels/test_intel_cpu_xpu.py index ab56af496..4c593b881 100644 --- a/tests/kernels/test_intel_cpu_xpu.py +++ b/tests/kernels/test_intel_cpu_xpu.py @@ -12,10 +12,10 @@ from torch import Tensor from gptqmodel import BACKEND, GPTQModel -from gptqmodel.nn_modules.qlinear.gemm_hf_kernel import HFKernelLinear -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear -from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedQuantLinear -from gptqmodel.nn_modules.qlinear.torch_int8 import TorchInt8QuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.torch_aten_kernel import TorchAtenLinear +from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedLinear +from gptqmodel.nn_modules.qlinear.torch_int8 import TorchInt8Linear from gptqmodel.utils.model import find_modules @@ -42,6 +42,11 @@ def _xpu_available() -> bool: return hasattr(torch, "xpu") and torch.xpu.is_available() +def _ensure_model_path_available(path: str): + if os.path.isabs(path) and not os.path.isdir(path): + raise unittest.SkipTest(f"Local model path missing: {path}") + + def _summarize_failures(failures): if not failures: return "-" @@ -54,10 +59,10 @@ def _summarize_failures(failures): class TestKernelOutput(unittest.TestCase): model_path = "sliuau/llama3.2-1b-4bit-group128" # hf "sliuau/llama3.2-1b-4bit-group128" target_qliner_map = { - BACKEND.TORCH: TorchQuantLinear, - BACKEND.TORCH_FUSED: TorchFusedQuantLinear, - BACKEND.TORCH_INT8: TorchInt8QuantLinear, - BACKEND.HF_KERNEL: HFKernelLinear, + BACKEND.TORCH: TorchLinear, + BACKEND.TORCH_FUSED: TorchFusedLinear, + BACKEND.TORCH_INT8: TorchInt8Linear, + BACKEND.GPTQ_TORCH_ATEN: TorchAtenLinear, } target = 'model.layers.6.self_attn.v_proj' device = "cpu" @@ -73,6 +78,7 @@ class TestKernelOutput(unittest.TestCase): @classmethod def setUp(self): + _ensure_model_path_available(self.model_path) self.torch_model = GPTQModel.load(self.model_path, backend=BACKEND.TORCH, device=self.device, dtype=self.dtype) self.x = [] self.torch_kernel_outs = [] @@ -105,7 +111,7 @@ def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005): (BACKEND.TORCH_FUSED, r_tolerance, a_tolerance), # Int4->float->int8 re-quantization in TorchInt8 introduces extra approximation noise. (BACKEND.TORCH_INT8, int8_r_tolerance, int8_a_tolerance), - (BACKEND.HF_KERNEL, r_tolerance, a_tolerance), + (BACKEND.GPTQ_TORCH_ATEN, r_tolerance, a_tolerance), ]) def test_kernel_output(self, backend: BACKEND, r_tolerance: float, a_tolerance: float): model = GPTQModel.load(self.model_path, backend=backend, device=self.device, dtype=self.dtype) @@ -134,7 +140,7 @@ class TestKernelOutputXPUBFloat16(TestKernelOutputXPU): dtype = torch.bfloat16 -class TestTorchFusedAndHFKernelDevices(unittest.TestCase): +class TestTorchFusedAndTorchAtenDevices(unittest.TestCase): model_path = TestKernelOutput.model_path target_qliner_map = TestKernelOutput.target_qliner_map target = TestKernelOutput.target @@ -151,11 +157,12 @@ class TestTorchFusedAndHFKernelDevices(unittest.TestCase): backend_tolerances = { BACKEND.TORCH_FUSED: (r_tolerance, a_tolerance), BACKEND.TORCH_INT8: (int8_r_tolerance, int8_a_tolerance), - BACKEND.HF_KERNEL: (r_tolerance, a_tolerance), + BACKEND.GPTQ_TORCH_ATEN: (r_tolerance, a_tolerance), } @classmethod def setUpClass(cls): + _ensure_model_path_available(cls.model_path) torch.manual_seed(0) cls.inputs = [] for dim_0 in cls.m: @@ -193,7 +200,7 @@ def assert_on_mismatch(self, a: Tensor, b: Tensor, rtol=0.00005, atol=0.00005): @parameterized.expand([ ("cpu", "cpu", BACKEND.TORCH_FUSED), ("cpu", "cpu", BACKEND.TORCH_INT8), - ("cpu", "cpu", BACKEND.HF_KERNEL), + ("cpu", "cpu", BACKEND.GPTQ_TORCH_ATEN), ("xpu", "xpu:0", BACKEND.TORCH_FUSED), ]) def test_backends_matches_cpu_reference(self, _name: str, device: str, backend: BACKEND): @@ -240,7 +247,7 @@ def test_backends_matches_cpu_reference(self, _name: str, device: str, backend: if failures: raise AssertionError(f"{len(failures)} mismatched samples on device {device}") -class TestTorchFusedAndHFKernelDevicesWithBias(TestTorchFusedAndHFKernelDevices): +class TestTorchFusedAndTorchAtenDevicesWithBias(TestTorchFusedAndTorchAtenDevices): model_path = "/monster/data/model/bloom-560m-gptqmodel-4bit" target = 'transformer.h.6.self_attention.query_key_value' k = 1024 diff --git a/tests/kernels/test_paroquant.py b/tests/kernels/test_paroquant.py new file mode 100644 index 000000000..70f2d58fb --- /dev/null +++ b/tests/kernels/test_paroquant.py @@ -0,0 +1,404 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# ParoQuant kernel tests adapted from the ParoQuant paper and public project: +# https://arxiv.org/html/2511.10645v2 +# https://github.com/z-lab/paroquant + +"""Kernel-focused tests for ParoQuant runtime behavior and backend parity.""" + +import pytest +import torch + +from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear +from gptqmodel.nn_modules.qlinear.paroquant_triton import ParoQuantTritonLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear +from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm +from gptqmodel.utils.backend import BACKEND +from gptqmodel.utils.importer import get_kernel_for_backend, select_quant_linear +from gptqmodel.utils.paroquant import apply_paroquant_rotation_reference, build_identity_rotation_buffers + + +def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: + """Pack unpacked integer weights into the AWQ bit layout used by the kernels.""" + pack_factor = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + + assert unpacked.shape[1] % pack_factor == 0 + packed = torch.zeros( + (unpacked.shape[0], unpacked.shape[1] // pack_factor), + dtype=torch.int32, + ) + for col in range(unpacked.shape[1] // pack_factor): + for i, order in enumerate(order_map): + value = unpacked[:, col * pack_factor + order].to(torch.int32) + packed[:, col] |= value << (i * bits) + return packed + + +def _make_packed_buffers(bits: int, in_features: int, out_features: int, group_size: int): + """Build synthetic packed AWQ tensors for ParoQuant runtime tests.""" + groups = in_features // group_size + int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) + zero_points = torch.randint(0, 2**bits, size=(groups, out_features), dtype=torch.int32) + scales = (torch.rand(groups, out_features, dtype=torch.float16) * 2.0) + 0.25 + bias = torch.randn(out_features, dtype=torch.float16) + + return ( + _pack_awq_tensor(int_weight, bits), + _pack_awq_tensor(zero_points, bits), + scales, + bias, + ) + + +def _upstream_transformers_contract_reference( + x: torch.Tensor, + *, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + bias: torch.Tensor | None, + pairs: torch.Tensor, + theta: torch.Tensor, + channel_scales: torch.Tensor, + bits: int, + group_size: int, + out_features: int, +) -> torch.Tensor: + """Reference for upstream RotateQuantizedLinear.forward(). + + Upstream ParoQuant applies per-projection rotation to the input and then + feeds the rotated activations into the AWQ GEMM path. We reproduce that + contract with dense dequantization here to assess kernel accuracy without + importing or copying upstream implementation code. + """ + + rotated = apply_paroquant_rotation_reference( + x, + pairs, + theta, + scales=channel_scales, + group_size=group_size, + ) + dense_weight = dequantize_gemm( + qweight=qweight, + qzeros=qzeros, + scales=scales, + bits=bits, + group_size=group_size, + ).to(device=x.device, dtype=x.dtype) + out = torch.matmul(rotated.reshape(-1, x.shape[-1]), dense_weight).reshape(*x.shape[:-1], out_features) + if bias is not None: + out = out + bias.to(device=x.device, dtype=x.dtype) + return out + + +def test_paroquant_identity_forward_matches_awq_torch(): + """Guard that identity ParoQuant is behaviorally equivalent to plain AWQ.""" + bits = 4 + in_features = 128 + out_features = 64 + group_size = 128 + qweight, qzeros, scales, bias = _make_packed_buffers(bits, in_features, out_features, group_size) + + awq_module = AwqTorchLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ) + awq_module.qweight.copy_(qweight) + awq_module.qzeros.copy_(qzeros) + awq_module.scales.copy_(scales) + awq_module.bias.copy_(bias) + awq_module.post_init() + awq_module.eval() + + paro_module = ParoLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + krot=8, + ) + paro_module.qweight.copy_(qweight) + paro_module.qzeros.copy_(qzeros) + paro_module.scales.copy_(scales) + paro_module.bias.copy_(bias) + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=8, + dtype=torch.float16, + ) + paro_module.pairs.copy_(pairs) + paro_module.theta.copy_(theta) + paro_module.channel_scales.copy_(channel_scales) + paro_module.post_init() + paro_module.eval() + + x = torch.randn(4, in_features, dtype=torch.float16) + torch.testing.assert_close(paro_module(x), awq_module(x), atol=5e-3, rtol=5e-3) + + +def test_paroquant_forward_matches_explicit_rotated_reference(): + """Guard the dense reference contract for non-identity ParoQuant rotations.""" + bits = 4 + in_features = 128 + out_features = 64 + group_size = 128 + qweight, qzeros, scales, bias = _make_packed_buffers(bits, in_features, out_features, group_size) + + module = ParoLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + krot=1, + ) + module.qweight.copy_(qweight) + module.qzeros.copy_(qzeros) + module.scales.copy_(scales) + module.bias.copy_(bias) + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=1, + dtype=torch.float16, + ) + theta.fill_(0.2) + channel_scales.mul_(0.75) + module.pairs.copy_(pairs) + module.theta.copy_(theta) + module.channel_scales.copy_(channel_scales) + module.post_init() + module.eval() + + x = torch.randn(3, in_features, dtype=torch.float16) + rotated = apply_paroquant_rotation_reference( + x, + module.pairs, + module.theta, + scales=module.channel_scales, + group_size=group_size, + ) + dequant_weight = dequantize_gemm( + qweight=module.qweight, + qzeros=module.qzeros, + scales=module.scales, + bits=bits, + group_size=group_size, + ).to(dtype=x.dtype) + expected = torch.matmul(rotated, dequant_weight) + module.bias + + torch.testing.assert_close(module(x), expected, atol=5e-3, rtol=5e-3) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for ParoQuant kernel accuracy test") +def test_paroquant_cuda_matches_upstream_transformers_contract(): + """Compare the internal CUDA kernel path to an upstream-style contract. + + The official implementation rotates activations and then runs an AWQ-style + packed matmul. This test reproduces that contract without importing upstream + code and checks our fused CUDA path stays within a bounded numerical error. + """ + bits = 4 + in_features = 128 + out_features = 128 + group_size = 128 + torch.manual_seed(0) + + groups = in_features // group_size + int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32) + zero_points = torch.full((groups, out_features), 2 ** (bits - 1), dtype=torch.int32) + scales = (torch.rand(groups, out_features, dtype=torch.float16) * 0.75) + 0.25 + bias = torch.randn(out_features, dtype=torch.float16) + qweight = _pack_awq_tensor(int_weight, bits) + qzeros = _pack_awq_tensor(zero_points, bits) + + module = ParoLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + krot=8, + ).cuda() + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=8, + device="cuda", + dtype=torch.float16, + ) + theta.uniform_(-0.25, 0.25) + channel_scales.uniform_(0.7, 1.3) + + module.qweight.copy_(qweight.cuda()) + module.qzeros.copy_(qzeros.cuda()) + module.scales.copy_(scales.cuda()) + module.bias.copy_(bias.cuda()) + module.pairs.copy_(pairs) + module.theta.copy_(theta) + module.channel_scales.copy_(channel_scales) + module.post_init() + module.eval() + + x = torch.randn(3, 7, in_features, device="cuda", dtype=torch.float16) + + with torch.inference_mode(): + expected = _upstream_transformers_contract_reference( + x, + qweight=module.qweight, + qzeros=module.qzeros, + scales=module.scales, + bias=module.bias, + pairs=module.pairs, + theta=module.theta, + channel_scales=module.channel_scales, + bits=bits, + group_size=group_size, + out_features=out_features, + ) + + original_forward_dense = module._forward_dense + try: + module._forward_dense = lambda *_args, **_kwargs: (_ for _ in ()).throw( + AssertionError("Expected ParoQuant CUDA kernel path, but dense fallback was used.") + ) + actual = module(x) + finally: + module._forward_dense = original_forward_dense + + diff = (actual - expected).abs().float() + assert diff.max().item() <= 0.25 + assert diff.mean().item() <= 0.03 + + +def test_paroquant_backend_selection(): + """Guard user-facing backend selection for the default CUDA runtime.""" + qlinear_cls = select_quant_linear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + device=None, + backend=BACKEND.PARO, + format=FORMAT.PAROQUANT, + quant_method=METHOD.PARO, + pack_dtype=torch.int32, + ) + assert qlinear_cls is ParoLinear + + +def test_paroquant_triton_backend_mapping(): + """Guard registry lookup for the Triton ParoQuant runtime class.""" + assert ( + get_kernel_for_backend(BACKEND.PAROQUANT_TRITON, METHOD.PARO, FORMAT.PAROQUANT) + is ParoQuantTritonLinear + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for ParoQuant Triton kernel parity test") +def test_paroquant_triton_matches_existing_cuda_kernel(): + """Guard Triton runtime accuracy against the established CUDA implementation.""" + pytest.importorskip("triton") + + bits = 4 + in_features = 128 + out_features = 128 + group_size = 128 + qweight, qzeros, scales, bias = _make_packed_buffers(bits, in_features, out_features, group_size) + + baseline = ParoLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + krot=8, + ).cuda() + candidate = ParoQuantTritonLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + krot=8, + ).cuda() + + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=8, + device="cuda", + dtype=torch.float16, + ) + theta.uniform_(-0.2, 0.2) + channel_scales.uniform_(0.75, 1.25) + + for module in (baseline, candidate): + module.qweight.copy_(qweight.cuda()) + module.qzeros.copy_(qzeros.cuda()) + module.scales.copy_(scales.cuda()) + module.bias.copy_(bias.cuda()) + module.pairs.copy_(pairs) + module.theta.copy_(theta) + module.channel_scales.copy_(channel_scales) + module.post_init() + module.eval() + + x = torch.randn(2, 8, in_features, device="cuda", dtype=torch.float16) + with torch.inference_mode(): + rotated = apply_paroquant_rotation_reference( + x, + baseline.pairs, + baseline.theta, + scales=baseline.channel_scales, + group_size=group_size, + ) + dense_weight = dequantize_gemm( + qweight=baseline.qweight, + qzeros=baseline.qzeros, + scales=baseline.scales, + bits=bits, + group_size=group_size, + ).to(dtype=x.dtype) + dense_reference = torch.matmul(rotated.reshape(-1, in_features), dense_weight).reshape(2, 8, out_features) + dense_reference = dense_reference + baseline.bias + + baseline_out = baseline(x) + candidate_out = candidate(x) + + baseline_max_abs = (baseline_out - dense_reference).abs().max().item() + baseline_mean_abs = (baseline_out - dense_reference).abs().mean().item() + candidate_max_abs = (candidate_out - dense_reference).abs().max().item() + candidate_mean_abs = (candidate_out - dense_reference).abs().mean().item() + + assert candidate_max_abs <= baseline_max_abs + 0.1 + assert candidate_mean_abs <= baseline_mean_abs + 0.02 diff --git a/tests/kernels/test_qlinear_hierarchy.py b/tests/kernels/test_qlinear_hierarchy.py new file mode 100644 index 000000000..08b4acfb8 --- /dev/null +++ b/tests/kernels/test_qlinear_hierarchy.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import inspect + +import pytest +import torch + +import gptqmodel.nn_modules.qlinear.bitblas as bitblas_module +from gptqmodel.nn_modules.qlinear import ( + AWQuantLinear, + BaseQuantLinear, + GPTQQuantLinear, + GroupedQuantLinear, + PackedGroupedQuantLinear, +) +from gptqmodel.nn_modules.qlinear.bitblas import BitblasLinear +from gptqmodel.nn_modules.qlinear.bitblas_awq import AWQBitBlasKernel +from gptqmodel.nn_modules.qlinear.fp8 import TorchFP8Linear +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.nn_modules.qlinear.qqq import QQQLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.torch_aten_kernel_awq import TorchAtenAwqLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear +from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqLinear + + +def test_quant_linear_hierarchy_splits_grouped_and_weight_only_kernels(): + assert issubclass(TorchLinear, GPTQQuantLinear) + assert issubclass(TorchLinear, PackedGroupedQuantLinear) + assert issubclass(AwqTorchLinear, AWQuantLinear) + assert issubclass(AwqTorchLinear, PackedGroupedQuantLinear) + assert issubclass(AwqTorchLinear, GroupedQuantLinear) + assert not issubclass(TorchFP8Linear, GroupedQuantLinear) + assert not issubclass(GGUFTorchLinear, GroupedQuantLinear) + + +def test_awq_hybrid_kernels_do_not_inherit_gptq_only_base_state(): + for cls in (TorchFusedAwqLinear, TorchAtenAwqLinear): + assert issubclass(cls, AWQuantLinear) + assert not issubclass(cls, GPTQQuantLinear) + assert not hasattr(cls, "qzero_format") + + assert issubclass(AWQBitBlasKernel, GroupedQuantLinear) + assert not issubclass(AWQBitBlasKernel, GPTQQuantLinear) + assert not issubclass(AWQBitBlasKernel, PackedGroupedQuantLinear) + assert not issubclass(AWQBitBlasKernel, BitblasLinear) + assert not hasattr(AWQBitBlasKernel, "qzero_format") + assert not hasattr(AWQBitBlasKernel, "repack_from_gptq") + + +def test_bitblas_gptq_kernel_keeps_gptq_only_repack_surface(): + assert issubclass(BitblasLinear, GroupedQuantLinear) + assert not issubclass(BitblasLinear, PackedGroupedQuantLinear) + assert hasattr(BitblasLinear, "repack_from_gptq") + assert not hasattr(BitblasLinear, "repack_from_awq") + + +def test_base_quant_linear_init_is_method_agnostic(): + params = inspect.signature(BaseQuantLinear.__init__).parameters + + assert "group_size" not in params + assert "desc_act" not in params + assert "sym" not in params + assert "pack_dtype" not in params + + +def test_grouped_kernels_keep_grouped_runtime_state(): + gptq = TorchLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=32, + out_features=32, + bias=False, + register_buffers=False, + ) + awq = AwqTorchLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=32, + out_features=32, + bias=False, + register_buffers=False, + ) + + for module in (gptq, awq): + assert module.group_size == 32 + assert module.desc_act is False + assert module.sym is True + assert module.pack_dtype == torch.int32 + assert module.smooth_block_size() == 32 + + assert gptq.qzero_format() == 1 + assert not hasattr(awq, "_qzeros_format") + + +def _install_dummy_bitblas(monkeypatch): + class _DummyConfig: + A_dtype = "float16" + W_dtype = "uint4" + out_dtype = "float16" + accum_dtype = "float32" + group_size = 32 + + class _DummyMatmul: + def __init__(self, config): + self.config = config + self.lib = object() + self.weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + def _fake_get_or_create(self, config, enable_tuning): + del enable_tuning + return _DummyMatmul(config) + + def _fake_configure(self, infeatures, outfeatures, params_dtype, enable_tuning, bias, layout, bits): + del infeatures, outfeatures, params_dtype, enable_tuning, bias, layout, bits + self.bitblas_matmul = _DummyMatmul(_DummyConfig()) + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr(BitblasLinear, "_get_or_create_bitblas_operator", _fake_get_or_create) + monkeypatch.setattr(AWQBitBlasKernel, "_get_or_create_bitblas_operator", _fake_get_or_create) + monkeypatch.setattr(BitblasLinear, "_configure_bitblas_matmul", _fake_configure) + monkeypatch.setattr(AWQBitBlasKernel, "_configure_bitblas_matmul", _fake_configure) + BitblasLinear.cached_validate_once.cache_clear() + AWQBitBlasKernel.cached_validate_once.cache_clear() + + +def test_grouped_nonpacked_kernels_do_not_store_packed_runtime_state(monkeypatch): + _install_dummy_bitblas(monkeypatch) + monkeypatch.setattr(QQQLinear, "cached_validate_once", classmethod(lambda cls: (True, None))) + + gptq_bitblas = BitblasLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=32, + out_features=32, + bias=False, + ) + awq_bitblas = AWQBitBlasKernel( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=32, + out_features=32, + bias=False, + ) + qqq = QQQLinear( + bits=4, + group_size=128, + sym=True, + desc_act=False, + in_features=128, + out_features=128, + bias=False, + register_buffers=False, + ) + + for module in (gptq_bitblas, awq_bitblas): + assert module.group_size == 32 + assert module.desc_act is False + assert module.sym is True + assert module.smooth_block_size() == 32 + assert not hasattr(module, "pack_dtype") + assert not hasattr(module, "pack_dtype_bits") + assert not hasattr(module, "pack_factor") + assert not hasattr(module, "pack_np_dtype") + assert not hasattr(module, "pack_np_math_dtype") + assert not hasattr(module, "maxq") + + assert qqq.group_size == 128 + assert qqq.desc_act is False + assert qqq.sym is True + assert qqq.smooth_block_size() == 128 + assert not hasattr(qqq, "pack_dtype") + assert not hasattr(qqq, "pack_dtype_bits") + assert not hasattr(qqq, "pack_factor") + assert not hasattr(qqq, "pack_np_dtype") + assert not hasattr(qqq, "pack_np_math_dtype") + assert qqq.maxq == 15 + + +def test_weight_only_kernels_do_not_store_grouped_runtime_state(): + ok, err = TorchFP8Linear.validate_once() + if not ok: + pytest.skip(f"FP8 unavailable: {err}") + + fp8 = TorchFP8Linear( + bits=8, + group_size=-1, + sym=True, + desc_act=False, + in_features=32, + out_features=32, + bias=False, + register_buffers=False, + weight_scale_method="block", + weight_block_size=(16, 16), + ) + gguf = GGUFTorchLinear( + bits="q4_0", + group_size=-1, + sym=True, + desc_act=False, + in_features=32, + out_features=32, + bias=False, + register_buffers=False, + ) + + for module in (fp8, gguf): + assert not hasattr(module, "group_size") + assert not hasattr(module, "desc_act") + assert not hasattr(module, "sym") + assert not hasattr(module, "pack_dtype") + + assert fp8.smooth_block_size() == 16 + assert gguf.smooth_block_size() == gguf.gguf_block_size + + +def test_weight_only_kernels_do_not_declare_grouped_support_metadata(): + for cls in (TorchFP8Linear, GGUFTorchLinear): + assert not hasattr(cls, "SUPPORTS_GROUP_SIZE") + assert not hasattr(cls, "SUPPORTS_DESC_ACT") + assert not hasattr(cls, "SUPPORTS_SYM") diff --git a/tests/kernels/test_selection.py b/tests/kernels/test_selection.py index 596f7ea19..b2b588e2f 100644 --- a/tests/kernels/test_selection.py +++ b/tests/kernels/test_selection.py @@ -8,11 +8,18 @@ from gptqmodel.models._const import DEVICE from gptqmodel.nn_modules.qlinear import BaseQuantLinear -from gptqmodel.nn_modules.qlinear.gemm_hf_kernel import HFKernelLinear -from gptqmodel.nn_modules.qlinear.gemm_hf_kernel_awq import HFKernelAwqLinear +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.nn_modules.qlinear.gguf_cpp import GGUFCppKernel, GGUFCudaKernel +from gptqmodel.nn_modules.qlinear.gguf_triton import GGUFTritonKernel +from gptqmodel.nn_modules.qlinear.machete import MacheteLinear +from gptqmodel.nn_modules.qlinear.machete_awq import AwqMacheteLinear +from gptqmodel.nn_modules.qlinear.marlin_awq import AwqMarlinLinear +from gptqmodel.nn_modules.qlinear.torch_aten_kernel import TorchAtenLinear +from gptqmodel.nn_modules.qlinear.torch_aten_kernel_awq import TorchAtenAwqLinear from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.utils import importer from gptqmodel.utils.backend import BACKEND -from gptqmodel.utils.importer import AUTO_BACKEND_KERNEL_MAPPING, select_quant_linear +from gptqmodel.utils.importer import AUTO_BACKEND_KERNEL_MAPPING, auto_select_device, select_quant_linear from gptqmodel.utils.rocm import IS_ROCM from gptqmodel.utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU @@ -62,15 +69,33 @@ def _pick_group_size(cls): for candidate in group_sizes: if candidate != -1: return candidate - return group_sizes[0] if group_sizes else 128 + return group_sizes[0] if group_sizes else -1 + + +def _pick_desc_act(cls): + values = list(getattr(cls, "SUPPORTS_DESC_ACT", [])) + return values[0] if values else False + + +def _pick_sym(cls): + values = list(getattr(cls, "SUPPORTS_SYM", [])) + return values[0] if values else True + + +def _pick_bits(cls): + supported_bits = list(getattr(cls, "SUPPORTS_BITS", [])) + for candidate in supported_bits: + if candidate in {2, 3, 4, 5, 6, 8}: + return candidate + return None def _force_auto_candidates_valid(monkeypatch, method, fmt): for cls in set(AUTO_BACKEND_KERNEL_MAPPING[method][fmt].values()): monkeypatch.setattr( cls, - "validate", - classmethod(lambda qlinear_cls, **kwargs: (True, None)), + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), ) @@ -92,10 +117,12 @@ def test_select_quant_linear_smoke(kernel_cls, method, fmt): pytest.skip(f"{kernel_cls.__name__} unavailable: {err}") pack_dtype = kernel_cls.SUPPORTS_PACK_DTYPES[0] - bits = kernel_cls.SUPPORTS_BITS[0] + bits = _pick_bits(kernel_cls) + if bits is None: + pytest.skip(f"No selector-compatible bit-width available for {kernel_cls.__name__}.") group_size = _pick_group_size(kernel_cls) - desc_act = kernel_cls.SUPPORTS_DESC_ACT[0] - sym = kernel_cls.SUPPORTS_SYM[0] + desc_act = _pick_desc_act(kernel_cls) + sym = _pick_sym(kernel_cls) qlinear_cls = select_quant_linear( bits=bits, @@ -113,7 +140,7 @@ def test_select_quant_linear_smoke(kernel_cls, method, fmt): @pytest.mark.parametrize("fmt", [FORMAT.GPTQ, FORMAT.GPTQ_V2]) -def test_cpu_auto_select_prioritizes_hf_kernel_for_gptq(monkeypatch, fmt): +def test_cpu_auto_select_prioritizes_torch_aten_for_gptq(monkeypatch, fmt): _force_auto_candidates_valid(monkeypatch, METHOD.GPTQ, fmt) candidates = select_quant_linear( @@ -129,10 +156,10 @@ def test_cpu_auto_select_prioritizes_hf_kernel_for_gptq(monkeypatch, fmt): multi_select=True, ) - assert candidates[0] is HFKernelLinear + assert candidates[0] is TorchAtenLinear -def test_cpu_auto_select_prioritizes_hf_kernel_for_awq(monkeypatch): +def test_cpu_auto_select_prioritizes_torch_aten_for_awq(monkeypatch): _force_auto_candidates_valid(monkeypatch, METHOD.AWQ, FORMAT.GEMM) candidates = select_quant_linear( @@ -148,4 +175,293 @@ def test_cpu_auto_select_prioritizes_hf_kernel_for_awq(monkeypatch): multi_select=True, ) - assert candidates[0] is HFKernelAwqLinear + assert candidates[0] is TorchAtenAwqLinear + + +def test_cpu_auto_select_prioritizes_cpp_kernel_for_gguf(monkeypatch): + _force_auto_candidates_valid(monkeypatch, METHOD.GGUF, FORMAT.GGUF) + + candidates = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CPU, + backend=BACKEND.AUTO, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + multi_select=True, + ) + + assert candidates[0] is GGUFCppKernel + assert GGUFTorchLinear in candidates + + +def test_cuda_auto_select_prioritizes_triton_then_cpp_then_torch_for_gguf(monkeypatch): + _force_auto_candidates_valid(monkeypatch, METHOD.GGUF, FORMAT.GGUF) + + candidates = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CUDA, + backend=BACKEND.AUTO, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + multi_select=True, + ) + + assert candidates[0] is GGUFTritonKernel + assert candidates[1] is GGUFCudaKernel + assert candidates[2] is GGUFTorchLinear + + +def test_cuda_auto_select_prioritizes_triton_then_torch_for_sign_only_gguf(monkeypatch): + _force_auto_candidates_valid(monkeypatch, METHOD.GGUF, FORMAT.GGUF) + + candidates = select_quant_linear( + bits=1, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CUDA, + backend=BACKEND.AUTO, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + multi_select=True, + ) + + assert candidates[0] is GGUFTritonKernel + assert GGUFCudaKernel not in candidates + assert candidates[1] is GGUFTorchLinear + + +def test_cpu_pack_auto_select_skips_cpp_kernel_for_gguf(monkeypatch): + _force_auto_candidates_valid(monkeypatch, METHOD.GGUF, FORMAT.GGUF) + + candidates = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CPU, + backend=BACKEND.AUTO, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack=True, + pack_dtype=torch.int32, + multi_select=True, + ) + + assert GGUFCppKernel not in candidates + assert candidates[0] is GGUFTorchLinear + + +def test_cuda_pack_auto_select_prioritizes_triton_for_gguf(monkeypatch): + _force_auto_candidates_valid(monkeypatch, METHOD.GGUF, FORMAT.GGUF) + + candidates = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CUDA, + backend=BACKEND.AUTO, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack=True, + pack_dtype=torch.int32, + multi_select=True, + ) + + assert candidates[0] is GGUFTritonKernel + assert GGUFCudaKernel not in candidates + assert GGUFTorchLinear in candidates + + +def test_explicit_gguf_cpu_backend_selects_cpp_kernel(monkeypatch): + monkeypatch.setattr( + GGUFCppKernel, + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), + ) + qlinear_cls = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CPU, + backend=BACKEND.GGUF_CPP_CPU, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is GGUFCppKernel + + +def test_explicit_gguf_cuda_backend_selects_cuda_kernel(monkeypatch): + monkeypatch.setattr( + GGUFCudaKernel, + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), + ) + qlinear_cls = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CUDA, + backend=BACKEND.GGUF_CPP_CUDA, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is GGUFCudaKernel + + +def test_explicit_gguf_torch_backend_selects_torch_kernel(): + qlinear_cls = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CPU, + backend=BACKEND.GGUF_TORCH, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is GGUFTorchLinear + + +def test_explicit_gguf_triton_backend_selects_triton_kernel(monkeypatch): + monkeypatch.setattr( + GGUFTritonKernel, + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), + ) + qlinear_cls = select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CUDA, + backend=BACKEND.GGUF_TRITON, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is GGUFTritonKernel + + +def test_explicit_awq_marlin_backend_selects_asymmetric_kernel(monkeypatch): + monkeypatch.setattr( + AwqMarlinLinear, + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), + ) + monkeypatch.setattr( + AwqMarlinLinear, + "validate_device", + classmethod(lambda qlinear_cls, _device: None), + ) + + qlinear_cls = select_quant_linear( + bits=4, + group_size=128, + desc_act=False, + sym=False, + device=DEVICE.CUDA, + backend=BACKEND.MARLIN, + format=FORMAT.GEMM, + quant_method=METHOD.AWQ, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is AwqMarlinLinear + + +def test_explicit_awq_machete_backend_selects_asymmetric_kernel(monkeypatch): + monkeypatch.setattr( + AwqMacheteLinear, + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), + ) + monkeypatch.setattr( + AwqMacheteLinear, + "validate_device", + classmethod(lambda qlinear_cls, _device: None), + ) + + qlinear_cls = select_quant_linear( + bits=4, + group_size=128, + desc_act=False, + sym=False, + device=DEVICE.CUDA, + backend=BACKEND.MACHETE, + format=FORMAT.GEMM, + quant_method=METHOD.AWQ, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is AwqMacheteLinear + + +def test_explicit_gptq_machete_backend_selects_asymmetric_kernel(monkeypatch): + monkeypatch.setattr( + MacheteLinear, + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), + ) + monkeypatch.setattr( + MacheteLinear, + "validate_device", + classmethod(lambda qlinear_cls, _device: None), + ) + + qlinear_cls = select_quant_linear( + bits=4, + group_size=128, + desc_act=False, + sym=False, + device=DEVICE.CUDA, + backend=BACKEND.MACHETE, + format=FORMAT.GPTQ, + quant_method=METHOD.GPTQ, + pack_dtype=torch.int32, + ) + + assert qlinear_cls is MacheteLinear + + +def test_torch_fused_auto_device_prefers_xpu_or_cpu(monkeypatch): + monkeypatch.setattr(importer, "HAS_CUDA", True) + monkeypatch.setattr(importer, "HAS_XPU", False) + monkeypatch.setattr(importer, "HAS_MPS", False) + + assert auto_select_device(None, BACKEND.TORCH_FUSED) is DEVICE.CPU + assert auto_select_device(None, BACKEND.TORCH_FUSED_AWQ) is DEVICE.CPU + + +def test_gguf_does_not_accept_generic_torch_backend(): + with pytest.raises(ValueError, match="Unsupported backend"): + select_quant_linear( + bits=4, + group_size=-1, + desc_act=False, + sym=True, + device=DEVICE.CPU, + backend=BACKEND.TORCH, + format=FORMAT.GGUF, + quant_method=METHOD.GGUF, + pack_dtype=torch.int32, + ) diff --git a/tests/kernels/test_torch_int8.py b/tests/kernels/test_torch_int8.py index d5148ca67..ca5e8bb91 100644 --- a/tests/kernels/test_torch_int8.py +++ b/tests/kernels/test_torch_int8.py @@ -18,8 +18,8 @@ os.environ.setdefault("GPTQMODEL_DISABLE_BITBLAS", "1") from gptqmodel.models._const import DEVICE -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear -from gptqmodel.nn_modules.qlinear.torch_int8 import TorchInt8QuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.torch_int8 import TorchInt8Linear from gptqmodel.utils.torch import TORCH_HAS_FUSED_OPS @@ -81,7 +81,7 @@ def _mock_gptq_linear( return linear, scales, zeros, g_idx -def _copy_gptq_buffers(src: TorchQuantLinear, dst: TorchInt8QuantLinear) -> None: +def _copy_gptq_buffers(src: TorchLinear, dst: TorchInt8Linear) -> None: dst.qweight.copy_(src.qweight) dst.qzeros.copy_(src.qzeros) dst.scales.copy_(src.scales) @@ -110,7 +110,7 @@ def test_torch_int8_cpu_kernel_deviation_against_torch(dtype: torch.dtype, desc_ groups = in_features // group_size g_idx = (torch.arange(in_features, dtype=torch.int32) * 3) % groups - baseline = TorchQuantLinear( + baseline = TorchLinear( bits=bits, group_size=group_size, sym=True, @@ -120,7 +120,7 @@ def test_torch_int8_cpu_kernel_deviation_against_torch(dtype: torch.dtype, desc_ pack_dtype=torch.int32, bias=False, ) - candidate = TorchInt8QuantLinear( + candidate = TorchInt8Linear( bits=bits, group_size=group_size, sym=True, @@ -175,8 +175,8 @@ def test_torch_int8_cpu_kernel_deviation_against_torch(dtype: torch.dtype, desc_ def test_torch_int8_kernel_is_cpu_only(): with pytest.raises(NotImplementedError): - TorchInt8QuantLinear.validate_device(DEVICE.XPU) + TorchInt8Linear.validate_device(DEVICE.XPU) def test_torch_int8_supports_expected_bits(): - assert TorchInt8QuantLinear.SUPPORTS_BITS == [2, 4, 8] + assert TorchInt8Linear.SUPPORTS_BITS == [2, 4, 8] diff --git a/tests/kernels/test_torch_int8_awq.py b/tests/kernels/test_torch_int8_awq.py index 909a6360c..99325bd3e 100644 --- a/tests/kernels/test_torch_int8_awq.py +++ b/tests/kernels/test_torch_int8_awq.py @@ -16,8 +16,8 @@ os.environ.setdefault("GPTQMODEL_DISABLE_BITBLAS", "1") from gptqmodel.models._const import DEVICE -from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear -from gptqmodel.nn_modules.qlinear.torch_int8_awq import TorchInt8AwqQuantLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear +from gptqmodel.nn_modules.qlinear.torch_int8_awq import TorchInt8AwqLinear from gptqmodel.quantization import FORMAT, METHOD from gptqmodel.utils.backend import BACKEND from gptqmodel.utils.importer import select_quant_linear @@ -44,7 +44,7 @@ def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor: return packed -def _copy_awq_buffers(src: AwqTorchQuantLinear, dst: TorchInt8AwqQuantLinear) -> None: +def _copy_awq_buffers(src: AwqTorchLinear, dst: TorchInt8AwqLinear) -> None: dst.qweight.copy_(src.qweight) dst.qzeros.copy_(src.qzeros) dst.scales.copy_(src.scales) @@ -99,7 +99,7 @@ def test_torch_int8_awq_cpu_kernel_deviation_against_torch_awq(dtype: torch.dtyp out_features=out_features, ) - baseline = AwqTorchQuantLinear( + baseline = AwqTorchLinear( bits=bits, group_size=group_size, sym=True, @@ -109,7 +109,7 @@ def test_torch_int8_awq_cpu_kernel_deviation_against_torch_awq(dtype: torch.dtyp bias=True, register_buffers=True, ) - candidate = TorchInt8AwqQuantLinear( + candidate = TorchInt8AwqLinear( bits=bits, group_size=group_size, sym=True, @@ -153,7 +153,7 @@ def test_torch_int8_awq_cpu_kernel_deviation_against_torch_awq(dtype: torch.dtyp def test_torch_int8_awq_kernel_is_cpu_only(): with pytest.raises(NotImplementedError): - TorchInt8AwqQuantLinear.validate_device(DEVICE.XPU) + TorchInt8AwqLinear.validate_device(DEVICE.XPU) def test_torch_int8_awq_backend_selection(): @@ -168,4 +168,4 @@ def test_torch_int8_awq_backend_selection(): quant_method=METHOD.AWQ, pack_dtype=torch.int32, ) - assert qlinear_cls is TorchInt8AwqQuantLinear + assert qlinear_cls is TorchInt8AwqLinear diff --git a/tests/models/awq/test_glm4_moe.py b/tests/models/awq/test_glm4_moe.py index 0bb8204b7..0e5b32d73 100644 --- a/tests/models/awq/test_glm4_moe.py +++ b/tests/models/awq/test_glm4_moe.py @@ -6,7 +6,6 @@ from model_test import ModelTest from gptqmodel.quantization import FORMAT, METHOD -from gptqmodel.utils.eval import EVAL class TestGlm4Moe(ModelTest): @@ -17,17 +16,18 @@ class TestGlm4Moe(ModelTest): DELETE_QUANTIZED_MODEL = False DATASET_SIZE = 512 GROUP_SIZE = 32 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5026, "floor_pct": 0.04}, "acc_norm": {"value": 0.5171, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.6362, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_glm4moe(self): - self.quant_lm_eval() + self.quantize_and_evaluate() class TestGlm4_5_Air(ModelTest): @@ -38,15 +38,16 @@ class TestGlm4_5_Air(ModelTest): DELETE_QUANTIZED_MODEL = False DATASET_SIZE = 512 GROUP_SIZE = 32 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5247, "floor_pct": 0.04}, "acc_norm": {"value": 0.5614, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.6403, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_glm4moe(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/awq/test_llama3_2.py b/tests/models/awq/test_llama3_2.py index 60220bc82..4e1c01aa9 100644 --- a/tests/models/awq/test_llama3_2.py +++ b/tests/models/awq/test_llama3_2.py @@ -15,54 +15,83 @@ from model_test import ModelTest -from gptqmodel import BACKEND from gptqmodel.quantization import FORMAT, METHOD -from gptqmodel.utils.eval import EVAL -# | Metric | MARLIN | +# | Metric | AWQ GEMM | # |--------------------------------|----------| -# | arc_challenge :: acc,none | 0.3166 | -# | arc_challenge :: acc_norm,none | 0.3464 | -# | mmlu_stem :: acc,none | 0.3692 | -# | gsm8k_plat :: exact,flexible | 0.2994 | +# | arc_challenge :: acc,none | 0.3140 | +# | arc_challenge :: acc_norm,none | 0.3541 | +# | mmlu_stem :: acc,none | 0.3841 | +# | gsm8k_plat :: exact,flexible | 0.3499 | class TestLlama3_2_awq(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_BATCH_SIZE = 64 DATASET_CONCAT_SIZE = 2048 # new # STOP_AFTER_LAYER = 0 - EVAL_TASKS = { - EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { "chat_template": True, - "exact_match,flexible-extract": { - "value": 0.2440, + "acc,num": { + "value": 0.34987593052109184, "floor_pct": 0.04, }, }, - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "chat_template": True, "acc": { - "value": 0.3166, + "value": 0.31399317406143346, "floor_pct": 0.04, }, "acc_norm": { - "value": 0.3464, + "value": 0.35409556313993173, "floor_pct": 0.04, }, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "chat_template": False, "acc": { - "value": 0.3692, + "value": 0.3840786552489692, "floor_pct": 0.04, }, }, } + # Fast-mode regression scores captured on CUDA_VISIBLE_DEVICES=6 (RTX 4090). + EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.4532671629445823, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.31313993174061433, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": 0.35665529010238906, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.3910561370123692, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + } FORMAT = FORMAT.GEMM METHOD = METHOD.AWQ SYM = True TORCH_DTYPE = torch.float16 - LOAD_BACKEND = BACKEND.TORCH_AWQ def test_llama3_2_awq(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/awq/test_llama3_2_awq_protocol.py b/tests/models/awq/test_llama3_2_awq_protocol.py new file mode 100644 index 000000000..3222355f9 --- /dev/null +++ b/tests/models/awq/test_llama3_2_awq_protocol.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys + +import pytest +import torch + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from gptqmodel.nn_modules.qlinear.machete_awq import AwqMacheteLinear +from gptqmodel.quantization import FORMAT, METHOD, AWQConfig +from gptqmodel.quantization.protocol import ( + Rule, + Stage, + compile_plan_to_quantize_config, + compile_protocol, + compile_protocol_yaml_text, +) +from gptqmodel.utils.machete import _validate_machete_device_support + + +LAYER0_ONLY_NEGATIVE_MATCH = r".*layers\.(?:[1-9]|[12][0-9]|3[0-2])\..*" + + +def _python_protocol(): + return { + "version": 2, + "stages": [ + Stage( + name="ptq", + rules=[ + Rule( + match=["*", f"-:{LAYER0_ONLY_NEGATIVE_MATCH}"], + weight={ + "quantize": { + "method": "awq", + "bits": 4, + "group_size": 128, + "sym": True, + "desc_act": False, + }, + "export": { + "format": "awq", + "variant": "gemm", + "impl": "marlin_awq", + }, + }, + ), + ], + ), + ], + } + + +def _yaml_protocol() -> str: + return r""" +version: 2 +stages: + - name: ptq + rules: + - match: + - "*" + - '-:.*layers\.(?:[1-9]|[12][0-9]|3[0-2])\..*' + weight: + quantize: + method: awq + bits: 4 + group_size: 128 + sym: true + desc_act: false + export: + format: awq + variant: gemm + impl: marlin_awq +""" + + +class _BaseLlama3_2AWQProtocol(ModelTest): + pytestmark = pytest.mark.skipif( + ( + (not __import__("torch").cuda.is_available()) + or not _validate_machete_device_support() + ), + reason="CUDA plus NVIDIA Hopper-or-newer GPUs are required for AWQ protocol dynamic-match integration tests", + ) + + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.4690, + "floor_pct": 0.05, + "ceil_pct": 0.05, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.3999, + "floor_pct": 0.03, + "ceil_pct": 0.03, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3221, + "floor_pct": 0.05, + "ceil_pct": 0.05, + }, + "acc_norm": { + "value": 0.3528, + "floor_pct": 0.03, + "ceil_pct": 0.03, + }, + }, + } + TORCH_DTYPE = torch.float16 + QUANT_BACKEND = BACKEND.MACHETE + LOAD_BACKEND = BACKEND.MACHETE + KERNEL_INFERENCE = {AwqMacheteLinear} + + def _compiled_protocol_plan(self): + raise NotImplementedError + + def _build_quantize_config(self): + return compile_plan_to_quantize_config(self._compiled_protocol_plan()) + + def _assert_layer0_only_dynamic(self, cfg): + assert isinstance(cfg, AWQConfig) + assert cfg.quant_method == METHOD.AWQ + assert cfg.format == FORMAT.GEMM + assert cfg.dynamic == {f"-:{LAYER0_ONLY_NEGATIVE_MATCH}": {}} + + def _assert_only_first_layer_quantized(self, model): + layer0_quantized = [] + later_layer_quantized = [] + + for name, module in model.named_modules(): + if not isinstance(module, BaseQuantLinear): + continue + if ".layers.0." in name: + layer0_quantized.append(name) + elif ".layers." in name: + later_layer_quantized.append(name) + + assert layer0_quantized, "Expected at least one quantized module in layer 0." + assert not later_layer_quantized, ( + "Expected quantization only in layer 0, " + f"but found later-layer modules: {later_layer_quantized[:8]}" + ) + + def _run_layer0_only_protocol_eval(self): + cfg = self._build_quantize_config() + self._assert_layer0_only_dynamic(cfg) + + self.model, _, _ = self.quantModel( + self.NATIVE_MODEL_ID, + batch_size=self.QUANT_BATCH_SIZE, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + call_perform_post_quant_validation=False, + ) + self.check_kernel(self.model, self.KERNEL_INFERENCE) + self._assert_only_first_layer_quantized(self.model) + + eval_records = getattr(self, "_post_quant_eval_records", {}) + target_backend = self._current_load_backend() + if eval_records and target_backend in eval_records: + task_results = eval_records[target_backend] + else: + task_results = self.evaluate_model( + model=self.SAVE_PATH if self.SAVE_PATH else self.model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL, + ) + self.check_results(task_results) + self._cleanup_quantized_model(self.model, enabled=self.DELETE_QUANTIZED_MODEL) + + +class TestLlama3_2_AWQProtocolPython(_BaseLlama3_2AWQProtocol): + def _compiled_protocol_plan(self): + return compile_protocol(_python_protocol()) + + def test_llama3_2_awq_protocol_python(self): + self._run_layer0_only_protocol_eval() + + +class TestLlama3_2_AWQProtocolYAML(_BaseLlama3_2AWQProtocol): + def _compiled_protocol_plan(self): + return compile_protocol_yaml_text(_yaml_protocol()) + + def test_llama3_2_awq_protocol_yaml(self): + self._run_layer0_only_protocol_eval() diff --git a/tests/models/awq/test_marin_awq.py b/tests/models/awq/test_marin_awq.py index c5d2a6cb3..fa25dbd90 100644 --- a/tests/models/awq/test_marin_awq.py +++ b/tests/models/awq/test_marin_awq.py @@ -6,7 +6,6 @@ from model_test import ModelTest from gptqmodel.quantization.config import FORMAT, METHOD -from gptqmodel.utils.eval import EVAL class TestMarin(ModelTest): @@ -18,15 +17,23 @@ class TestMarin(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/marin-32b-base" # VRAM_STRATEGY = VramStrategy.BALANCED # Marin inherits Qwen3's backbone with QK-Norm attention. - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { - "acc": {"value": 0.5828, "floor_pct": 0.04}, - "acc_norm": {"value": 0.6007, "floor_pct": 0.04}, + EVAL_TASKS_SLOW = { + "arc_challenge": { + "acc": {"value": 0.5299, "floor_pct": 0.04}, + "acc_norm": {"value": 0.5546, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { - "acc": {"value": 0.6673, "floor_pct": 0.04}, + "gsm8k_platinum_cot": { + "chat_template": False, + "acc,num": { + "value": 0.6716294458229942, + "floor_pct": 0.04, + }, + }, + "mmlu_stem": { + "acc": {"value": 0.6676, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_marin(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/awq/test_moe.py b/tests/models/awq/test_moe.py index 62d10aae5..93e8cd02c 100644 --- a/tests/models/awq/test_moe.py +++ b/tests/models/awq/test_moe.py @@ -12,24 +12,25 @@ from model_test import ModelTest from gptqmodel.quantization import FORMAT, METHOD -from gptqmodel.utils.eval import EVAL +from gptqmodel.quantization.config import ExpertsRoutingOverride, MoEConfig -# | Metric | MARLIN | # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.5094 | # | arc_challenge :: acc_norm,none | 0.5486 | class TestQwen3MoeAwq(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B" DATASET_CONCAT_SIZE = 2048 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5094, "floor_pct": 0.04}, "acc_norm": {"value": 0.5486, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) FORMAT = FORMAT.GEMM METHOD = METHOD.AWQ + MOE_CONFIG = MoEConfig(routing=ExpertsRoutingOverride()) def test_moe_awq(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/awq/test_qwen3_5_moe.py b/tests/models/awq/test_qwen3_5_moe.py index 6ebaa45ac..7db01ffbc 100644 --- a/tests/models/awq/test_qwen3_5_moe.py +++ b/tests/models/awq/test_qwen3_5_moe.py @@ -11,22 +11,21 @@ from model_test import ModelTest -from gptqmodel.quantization.config import FORMAT, METHOD, FailSafe, VramStrategy -from gptqmodel.utils.eval import EVAL +from gptqmodel.quantization.config import FORMAT, METHOD, Fallback, VramStrategy class TestQwen3_5Moe(ModelTest): - FAILSAFE = FailSafe() + FALLBACK = Fallback() FORMAT = FORMAT.GEMM METHOD = METHOD.AWQ NATIVE_MODEL_ID = "/monster/data/model/Qwen3.5-35B-A3B" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5887, "floor_pct": 0.04}, "acc_norm": {"value": 0.6100, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "chat_template": False, "acc": { "value": 0.8106, @@ -34,9 +33,10 @@ class TestQwen3_5Moe(ModelTest): }, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) - VRAM_STRATEGY = VramStrategy.BALANCED - OFFLOAD_TO_DISK = False # FIXME Currently, after defuser transforms the model, OFFLOAD_TO_DISK must be False for quantization. + DENSE_VRAM_STRATEGY = VramStrategy.EXCLUSIVE + MOE_VRAM_STRATEGY = VramStrategy.BALANCED def test_qwen3_5_moe(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/awq/test_qwen3_8b_base_awq.py b/tests/models/awq/test_qwen3_8b_base_awq.py new file mode 100644 index 000000000..c57b11adc --- /dev/null +++ b/tests/models/awq/test_qwen3_8b_base_awq.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from model_test import ModelTest + +from gptqmodel.quantization import FORMAT, METHOD + + +class TestQwen3_8B_Base_AWQ(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3-8B-Base" # "Qwen/Qwen3-8B-Base" + EVAL_TASKS = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.2994, + "floor_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3166, + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3464, + "floor_pct": 0.04, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.3692, + "floor_pct": 0.04, + }, + }, + } + FORMAT = FORMAT.GEMM + METHOD = METHOD.AWQ + QUANT_BATCH_SIZE = 1 + MODEL_COMPAT_FAST_LAYER_POSITION = "first" + + def test_qwen3_8b_base_awq(self): + self.quantize_and_evaluate() diff --git a/tests/models/foem/test_llama3_2.py b/tests/models/foem/test_llama3_2.py new file mode 100644 index 000000000..cfb16803b --- /dev/null +++ b/tests/models/foem/test_llama3_2.py @@ -0,0 +1,111 @@ + +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import os +import sys + +from gptqmodel.quantization.config import FOEMConfig + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + + +# | Metric | MARLIN | +# |----------------------------------------------------|----------| +# | arc_challenge :: acc,none | 0.3166 | +# | arc_challenge :: acc_norm,none | 0.3430 | +# | gsm8k_platinum_cot :: acc,num | 0.3906 | +# | mmlu_stem :: acc,none | 0.3942 | +class TestLlama3_2(ModelTest): + # Keep one stable saved checkpoint so eval-only repro runs can reuse the exact post-quant model. + SAVE_PATH = os.environ.get( + "GPTQMODEL_LLAMA3_2_SAVE_PATH", + "/tmp/llama3_2_gptq_saved_ckpt", + ) + DELETE_QUANTIZED_MODEL = False + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.3987, + "floor_pct": 0.04, + }, + }, + # "mmlu_stem": { + # "chat_template": False, + # "acc": { + # "value": 0.3860, # 0.3099 4096, 0.3270 2048 + # "floor_pct": 0.04, + # }, + # }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3234, # 0.3294 4096, 0.3242 2048 + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3643, # 0.3558 4096, 0.3635 2048 + "floor_pct": 0.04, + }, + }, + } + EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "evalution_use_model_path": True, + "evalution_batch_size": "auto", + "evalution_model_args": { + "dtype": "bfloat16", + "attn_implementation": "paged|flash_attention_2", + "device": "cuda:0", + }, + "evalution_suite_kwargs": { + "batch_size": 32, + "max_new_tokens": 256, + "stream": True, + }, + "acc,num": { + "value": {"A100": 0.4749}, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + # "mmlu_stem": { + # "chat_template": False, + # "acc": { + # "value": 0.3942, + # "floor_pct": 0.04, + # "ceil_pct": 1.0, + # }, + # "max_rows": 256, + # }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": {"A100": 0.3148}, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": {"A100": 0.3472}, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + } + FOEM = FOEMConfig() + MODEL_COMPAT_FAST_LAYER_POSITION="first" + + def test_llama3_2(self): + self.quantize_and_evaluate() diff --git a/tests/models/foem/test_moe.py b/tests/models/foem/test_moe.py new file mode 100644 index 000000000..f262d33dc --- /dev/null +++ b/tests/models/foem/test_moe.py @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import sys +from pathlib import Path + + +sys.path.insert(0, str(Path(__file__).resolve().parents[1])) + +from model_test import ModelTest + +from gptqmodel.quantization.config import FOEMConfig + + +class TestQwen3MoeFOEM(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B" + EVAL_TASKS_SLOW = { + "arc_challenge": { + "acc": {"value": 0.5094, "floor_pct": 0.04}, + "acc_norm": {"value": 0.5486, "floor_pct": 0.04}, + }, + } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + FOEM = FOEMConfig() + MODEL_COMPAT_FAST_LAYER_POSITION = "first" + + def test_moe_awq(self): + self.quantize_and_evaluate() diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 6db4a9c64..77ecf08d1 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium # -- do not touch +import copy import os import sys @@ -22,7 +23,6 @@ # -- end do not touch -from enum import Enum # noqa: E402 from pathlib import Path # noqa: E402 from typing import Any, Dict, List, Optional # noqa: E402 @@ -37,10 +37,74 @@ import tempfile # noqa: E402 import textwrap # noqa: E402 import unittest # noqa: E402 -from collections.abc import Iterable # noqa: E402 +from collections.abc import Iterable, Mapping # noqa: E402 import torch.cuda # noqa: E402 -from datasets import load_dataset # noqa: E402 + + +def _env_choice(*names: str, default: str) -> str: + """Return the first non-empty env override from a prioritized name list.""" + for name in names: + raw = os.environ.get(name) + if raw is None: + continue + value = raw.strip().lower() + if value: + return value + return default + + +def _env_int(*names: str, default: int) -> int: + """Return the first parseable integer env override from a prioritized name list.""" + for name in names: + raw = os.environ.get(name) + if raw is None: + continue + value = raw.strip() + if value: + return int(value) + return default + + +def _env_flag(*names: str, default: bool = False) -> bool: + """Return the first parseable boolean env override from a prioritized name list.""" + truthy = {"1", "true", "yes", "on", "y", "t"} + falsy = {"0", "false", "no", "off", "n", "f"} + for name in names: + raw = os.environ.get(name) + if raw is None: + continue + value = raw.strip().lower() + if value in truthy: + return True + if value in falsy: + return False + return default + + +def _env_optional_flag(*names: str) -> Optional[bool]: + """Return the first parseable boolean env override, or None when no override is set.""" + truthy = {"1", "true", "yes", "on", "y", "t"} + falsy = {"0", "false", "no", "off", "n", "f"} + for name in names: + raw = os.environ.get(name) + if raw is None: + continue + value = raw.strip().lower() + if value in truthy: + return True + if value in falsy: + return False + return None + + +try: # noqa: E402 + from datasets import load_dataset as hf_load_dataset # noqa: E402 +except Exception as exc: # pragma: no cover - depends on test environment + hf_load_dataset = None + DATASETS_IMPORT_ERROR = exc +else: + DATASETS_IMPORT_ERROR = None try: @@ -57,20 +121,34 @@ def is_flash_attn_2_available(): # type: ignore return False +from tests.eval import ( # noqa: E402 + evaluate, + format_eval_result_table, + get_eval_task_results, + resolve_eval_metric_alias, +) + from gptqmodel import BACKEND, DEBUG_ON, GPTQModel # noqa: E402 from gptqmodel.looper.module_looper import StopMainLoop # noqa: E402 from gptqmodel.models.base import BaseQModel # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 from gptqmodel.quantization.config import ( # noqa: E402 - FailSafe, + BitsAndBytesConfig, + Fallback, + FOEMConfig, + FP8Config, + GGUFConfig, GPTAQConfig, HessianConfig, MoEConfig, + ParoConfig, QuantizeConfig, + RTNConfig, VramStrategy, + WeightOnlyConfig, + resolve_quant_format, ) -from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 @@ -81,13 +159,16 @@ def is_flash_attn_2_available(): # type: ignore DEFAULT_FLOOR_PCT = 0.05 DEFAULT_CEIL_PCT = 0.10 -DEFAULT_TASK_NAMES = (EVAL.LM_EVAL.ARC_CHALLENGE,) +DEFAULT_TASK_NAMES = ("arc_challenge",) class ModelTest(unittest.TestCase): DEBUG = True # enable extra debug output - VRAM_STRATEGY = VramStrategy.EXCLUSIVE + DENSE_VRAM_STRATEGY = VramStrategy.EXCLUSIVE + DENSE_VRAM_STRATEGY_DEVICES = None + MOE_VRAM_STRATEGY = VramStrategy.EXCLUSIVE + MOE_VRAM_STRATEGY_DEVICES = None TRUST_REMOTE_CODE = False TORCH_DTYPE = "auto" EVAL_BATCH_SIZE = "auto" @@ -105,6 +186,16 @@ class ModelTest(unittest.TestCase): EVAL_TASKS = None EVAL_SINGLE_GPU = True LOAD_MODEL_EXTRA_ARGS: Dict[str, Any] = {} + EVAL_TASKS_FAST = None + EVAL_TASKS_SLOW = None + MODEL_TEST_MODE_ENV = "GPTQMODEL_MODEL_TEST_MODE" + MODEL_TEST_MODE_FAST = "fast" + MODEL_TEST_MODE_SLOW = "slow" + # Shared override for the fast-mode quantized layer prefix across all ModelTest-based tests. + FAST_LAYER_COUNT_ENV = "GPTQMODEL_FAST_LAYER_COUNT" + FAST_LAYER_POSITION_ENV = "GPTQMODEL_FAST_LAYER_POSITION" + MODEL_COMPAT_FAST_LAYER_COUNT = None + MODEL_COMPAT_FAST_LAYER_POSITION = None KERNEL_QUANT = {} # kernel sets KERNEL_INFERENCE = {} # kernel sets @@ -116,25 +207,46 @@ class ModelTest(unittest.TestCase): GROUP_SIZE = 128 DESC_ACT = False SYM = True - GPTQA = False + GPTAQ: Optional[GPTAQConfig] = None + FOEM: Optional[FOEMConfig] = None ACT_GROUP_AWARE = True - FAILSAFE = FailSafe() + FALLBACK = Fallback() EORA = None DAMP_PERCENT = 0.05 MSE = 0.0 DYNAMIC = None HESSIAN_CHUNK_SIZE = None + WEIGHT_ONLY = None + BNB_FORMAT = None + BNB_BLOCK_SIZE = None + BNB_COMPRESS_STATISTICS = None + + PAROQUANT_ROTATION_EPOCHS = None + PAROQUANT_FINETUNE_EPOCHS = None + PAROQUANT_TRAIN_SAMPLES = None SAVE_PATH = None # default is temp folder + SPLIT_BY: Optional[str] = None USE_FLASH_ATTN = True INFERENCE_PROMPT = "The capital city of France is named" INFERENCE_RESULT_KEYWORDS = ["paris"] + DISABLE_NATIVE_BASELINE_FALLBACK = True GENERATE_EVAL_SIZE_MIN = 128 GENERATE_EVAL_SIZE_MAX = 128 + APPLY_CHAT_TEMPLATE = False LM_HEAD_LOSS_MAX_DELTA_PERCENT = 0.1 # ±10% + + @classmethod + def setUpClass(cls): + super().setUpClass() + model_id = getattr(cls, "NATIVE_MODEL_ID", None) + if isinstance(model_id, str): + model_id = model_id.strip() + if os.path.isabs(model_id) and not os.path.isdir(model_id): + raise unittest.SkipTest(f"Model path missing: {model_id}") EXPECT_LM_HEAD_LOSS = None STOP_AFTER_LAYER: Optional[int] = None MOE_CONFIG: Optional[MoEConfig] = None @@ -142,12 +254,297 @@ class ModelTest(unittest.TestCase): GENERIC_TEST_PROMPTS = [ {"prompt": "Which city is the capital city of France?", "keywords": ["paris"]}, - {"prompt": "What is the smallest habitable planet in the milky way?", "keywords": ["earth"]}, - {"prompt": "Who wrote the play Romeo and Juliet?", "keywords": ["shakespeare"]}, - {"prompt": "What gas do plants primarily absorb from the atmosphere during photosynthesis?", "keywords": ["carbon dioxide"]}, - {"prompt": "Name the largest ocean on Earth.", "keywords": ["pacific"]}, + # {"prompt": "What is the smallest habitable planet in the milky way?", "keywords": ["earth"]}, + # {"prompt": "Who wrote the play Romeo and Juliet?", "keywords": ["shakespeare"]}, + # {"prompt": "What gas do plants primarily absorb from the atmosphere during photosynthesis?", "keywords": ["carbon dioxide"]}, + # {"prompt": "Name the largest ocean on Earth.", "keywords": ["pacific"]}, ] + @classmethod + def derive_fast_eval_tasks(cls, eval_tasks, min_ceil_pct: float = 1.0): + if eval_tasks is None: + return None + + fast_tasks = copy.deepcopy(eval_tasks) + for metrics in fast_tasks.values(): + if not isinstance(metrics, dict): + continue + for metric_name, spec in list(metrics.items()): + if metric_name == "chat_template": + continue + if isinstance(spec, dict): + current_ceil = spec.get("ceil_pct", spec.get("max_delta_ceil_percent", DEFAULT_CEIL_PCT)) + spec["ceil_pct"] = max(float(current_ceil), float(min_ceil_pct)) + else: + metrics[metric_name] = { + "value": float(spec), + "floor_pct": DEFAULT_FLOOR_PCT, + "ceil_pct": float(min_ceil_pct), + } + return fast_tasks + + def _model_test_mode(self) -> str: + raw = os.environ.get(self.MODEL_TEST_MODE_ENV, self.MODEL_TEST_MODE_FAST) + normalized = str(raw).strip().lower() + if normalized in {"", self.MODEL_TEST_MODE_FAST}: + return self.MODEL_TEST_MODE_FAST + if normalized in {self.MODEL_TEST_MODE_SLOW, "full"}: + return self.MODEL_TEST_MODE_SLOW + raise ValueError( + f"Unsupported {self.MODEL_TEST_MODE_ENV}={raw!r}; expected " + f"`{self.MODEL_TEST_MODE_FAST}` or `{self.MODEL_TEST_MODE_SLOW}`." + ) + + def _is_fast_model_test_mode(self) -> bool: + return self._model_test_mode() == self.MODEL_TEST_MODE_FAST + + @contextlib.contextmanager + def model_compat_test_context(self): + previous = getattr(self, "_model_compat_eval_in_progress", False) + self._model_compat_eval_in_progress = True + try: + yield + finally: + self._model_compat_eval_in_progress = previous + + def _in_model_compat_eval_flow(self) -> bool: + return bool(getattr(self, "_model_compat_eval_in_progress", False)) + + def _should_use_fast_model_compat_quant(self) -> bool: + return self._in_model_compat_eval_flow() and self._is_fast_model_test_mode() + + def _selected_eval_tasks_config(self): + if self._is_fast_model_test_mode(): + if self.EVAL_TASKS_FAST is not None: + return self.EVAL_TASKS_FAST + if self.EVAL_TASKS_SLOW is not None: + return self.derive_fast_eval_tasks(self.EVAL_TASKS_SLOW) + if self.EVAL_TASKS is not None: + return self.derive_fast_eval_tasks(self.EVAL_TASKS) + else: + if self.EVAL_TASKS_SLOW is not None: + return self.EVAL_TASKS_SLOW + return self.EVAL_TASKS + + def _mode_specific_baseline_value(self, attr_name: str): + mode_suffix = "FAST" if self._is_fast_model_test_mode() else "SLOW" + preferred = f"{attr_name}_{mode_suffix}" + if hasattr(self, preferred): + return self._resolve_metric_baseline_value(getattr(self, preferred)) + + if self._is_fast_model_test_mode(): + fallback = f"{attr_name}_SLOW" + if hasattr(self, fallback): + return self._resolve_metric_baseline_value(getattr(self, fallback)) + + return self._resolve_metric_baseline_value(getattr(self, attr_name, None)) + + def _legacy_metric_ceil_pct(self) -> float: + if self._is_fast_model_test_mode(): + return 1.0 + return DEFAULT_CEIL_PCT + + @staticmethod + def _merge_dynamic_configs(*configs): + merged = {} + for config in configs: + if not config: + continue + merged.update(copy.deepcopy(config)) + return merged or None + + def _resolve_layers_for_fast_model_compat(self, model): + layers_node = model.extract_layers_node() + if isinstance(layers_node, (list, tuple)): + if not layers_node: + return None, None + layers_node = layers_node[0] + + layers = model.model + for part in layers_node.split("."): + layers = getattr(layers, part) + + return layers_node, layers + + @staticmethod + def _layer_type_signature(layer) -> tuple: + layer_type = f"{type(layer).__module__}.{type(layer).__qualname__}" + top_children = tuple( + (name, f"{type(module).__module__}.{type(module).__qualname__}") + for name, module in layer.named_children() + ) + feature_tokens = ("moe", "expert", "router", "gate") + features = set() + for name, _module in layer.named_modules(): + lower = name.lower() + for token in feature_tokens: + if token in lower: + features.add(token) + return layer_type, top_children, tuple(sorted(features)) + + def _summarize_layer_signatures(self, layers) -> List[Dict[str, Any]]: + summaries: Dict[tuple, Dict[str, Any]] = {} + for idx, layer in enumerate(layers): + signature = self._layer_type_signature(layer) + summary = summaries.get(signature) + if summary is None: + layer_type, top_children, features = signature + summary = { + "first_idx": idx, + "count": 0, + "layer_type": layer_type, + "top_children": [name for name, _ in top_children][:8], + "features": list(features), + } + summaries[signature] = summary + summary["count"] += 1 + + return sorted(summaries.values(), key=lambda item: item["first_idx"]) + + def _resolve_fast_model_layer_count_config(self, layer_count: int) -> Dict[str, Any]: + raw_env_value = os.environ.get(self.FAST_LAYER_COUNT_ENV) + if raw_env_value is not None: + resolved = self._parse_fast_model_layer_count(raw_env_value, field_name=self.FAST_LAYER_COUNT_ENV, layer_count=layer_count) + return { + "source": "env", + "name": self.FAST_LAYER_COUNT_ENV, + "raw": raw_env_value, + "resolved": resolved, + } + + configured_min_layers = self.MODEL_COMPAT_FAST_LAYER_COUNT + if configured_min_layers is not None: + resolved = self._parse_fast_model_layer_count( + configured_min_layers, + field_name="MODEL_COMPAT_FAST_LAYER_COUNT", + layer_count=layer_count, + ) + return { + "source": "class", + "name": "MODEL_COMPAT_FAST_LAYER_COUNT", + "raw": configured_min_layers, + "resolved": resolved, + } + + return { + "source": "default", + "name": "default", + "raw": 2, + "resolved": min(2, layer_count), + } + + def _resolve_fast_model_layer_position(self) -> str: + raw = os.environ.get(self.FAST_LAYER_POSITION_ENV) + if raw is None: + configured = self.MODEL_COMPAT_FAST_LAYER_POSITION + raw = "last" if configured is None else configured + normalized = str(raw).strip().lower() + if normalized in {"", "first", "prefix", "head"}: + return "first" + if normalized in {"last", "suffix", "tail", "top"}: + return "last" + raise ValueError( + f"{self.FAST_LAYER_POSITION_ENV} must be `first` or `last`, got {raw!r}." + ) + + @staticmethod + def _parse_fast_model_layer_count(raw_value: Any, *, field_name: str, layer_count: int) -> int: + normalized = str(raw_value).strip().lower() + if normalized == "": + return min(2, layer_count) + if normalized in {"all", "full"}: + return layer_count + + try: + return max(int(normalized), 0) + except ValueError as exc: + raise ValueError( + f"{field_name} must be a non-negative integer or `all`, got {raw_value!r}." + ) from exc + + def _fast_model_layer_limit(self, layers) -> int: + layer_count = len(layers) + config = self._resolve_fast_model_layer_count_config(layer_count) + min_layers = config["resolved"] + if layer_count <= min_layers: + return layer_count + + signature_summaries = self._summarize_layer_signatures(layers) + unique_signatures = len(signature_summaries) + + if unique_signatures == 1: + return min(min_layers, layer_count) + + last_first_idx = max((item["first_idx"] for item in signature_summaries), default=0) + return min(layer_count, max(min_layers, last_first_idx + 1)) + + def _build_fast_model_compat_dynamic(self, model) -> Optional[Dict[str, Dict[str, Any]]]: + if not self._should_use_fast_model_compat_quant(): + return None + + layers_node, layers = self._resolve_layers_for_fast_model_compat(model) + if layers_node is None or layers is None: + return None + + layer_count = len(layers) + layer_limit = self._fast_model_layer_limit(layers) + layer_limit_config = self._resolve_fast_model_layer_count_config(layer_count) + layer_position = self._resolve_fast_model_layer_position() + log.info( + "Fast quant mode layer limit config: %s=%r -> resolved %s %s/%s layers.", + layer_limit_config["name"], + layer_limit_config["raw"], + layer_position, + layer_limit_config["resolved"], + layer_count, + ) + if layer_count <= layer_limit: + log.info( + "Fast quant mode: layer limit covers the full model (%s/%s layers); skipping 0 layers.", + layer_count, + layer_count, + ) + return None + + if layer_position == "last": + skipped_layers = range(0, max(0, layer_count - layer_limit)) + else: + skipped_layers = range(layer_limit, layer_count) + + dynamic = {f"-:^{layers_node}\\.{i}\\.": {} for i in skipped_layers} + + unique_layer_types = len({self._layer_type_signature(layer) for layer in layers}) + log.info( + "Fast quant mode: quantizing %s %s/%s layers (%s unique layer type signatures covered), skipping %s layers.", + layer_position, + layer_limit, + layer_count, + unique_layer_types, + layer_count - layer_limit, + ) + signature_summaries = self._summarize_layer_signatures(layers) + log.info( + "Fast quant mode layer signature details: \n%s", + "\n".join( + ( + f"first_layer={item['first_idx']}, count={item['count']}, " + f"type={item['layer_type']}, top_children={item['top_children'] or ['']}, " + f"special_features={item['features'] or ['']}" + ) + for item in signature_summaries + ), + ) + + return dynamic + + def _apply_model_compat_quant_overrides(self, model) -> None: + dynamic = self._build_fast_model_compat_dynamic(model) + if dynamic is None: + return + + model.quantize_config.dynamic = self._merge_dynamic_configs(model.quantize_config.dynamic, dynamic) + self._model_compat_fast_dynamic = dynamic + @staticmethod def _build_layer_stop_callback(layer_idx: int): class _StopAfterLayer: @@ -176,7 +573,6 @@ def _finalize_quant_debug_path( model, tokenizer, processor, - need_create_processor: bool, cleanup_callback, ): if cleanup_callback is not None: @@ -184,70 +580,49 @@ def _finalize_quant_debug_path( cleanup_callback() except Exception: pass - if need_create_processor: - return model, tokenizer, processor - return model, tokenizer + return model, tokenizer, processor def _normalize_task_identifier(self, task): - if isinstance(task, Enum): - return task.value if task is None: raise ValueError("Evaluation task identifier cannot be None") - return str(task) + normalized = str(task).strip() + if not normalized: + raise ValueError("Evaluation task identifier cannot be empty") + return normalized def _normalize_task_list(self): task_specs = self.get_eval_tasks() - task_lookup = getattr(self, "_resolved_task_lookup", {}) - resolved_tasks = [] if task_specs: - for normalized_name in task_specs.keys(): - original = task_lookup.get(normalized_name) - if original is None: - original = self._resolve_task_enum(normalized_name) - if isinstance(task_lookup, dict): - task_lookup[normalized_name] = original - resolved_tasks.append(original) + task_names = list(task_specs.keys()) else: - resolved_tasks = list(DEFAULT_TASK_NAMES) - self._resolved_task_lookup = { - self._normalize_task_identifier(task): task for task in resolved_tasks - } + task_names = list(DEFAULT_TASK_NAMES) - normalized = [self._normalize_task_identifier(task) for task in resolved_tasks if task is not None] + normalized = [self._normalize_task_identifier(task) for task in task_names if task is not None] if not normalized: raise ValueError("No evaluation tasks configured") return normalized - def _resolve_task_enum(self, task): - if isinstance(task, Enum): - return task - if isinstance(task, str): - for enum_member in EVAL.get_task_enums(): - if task == enum_member.value or task == enum_member.name: - return enum_member - raise ValueError(f"Unknown evaluation task identifier: {task}") - def _legacy_arc_tasks(self): baselines = {} arc_metrics = {} - if hasattr(self, "NATIVE_ARC_CHALLENGE_ACC"): + native_acc = self._mode_specific_baseline_value("NATIVE_ARC_CHALLENGE_ACC") + native_acc_norm = self._mode_specific_baseline_value("NATIVE_ARC_CHALLENGE_ACC_NORM") + ceil_pct = self._legacy_metric_ceil_pct() + if native_acc is not None: arc_metrics["acc"] = { - "value": self.NATIVE_ARC_CHALLENGE_ACC, + "value": native_acc, "floor_pct": DEFAULT_FLOOR_PCT, - "ceil_pct": DEFAULT_CEIL_PCT, + "ceil_pct": ceil_pct, } - if hasattr(self, "NATIVE_ARC_CHALLENGE_ACC_NORM"): + if native_acc_norm is not None: arc_metrics["acc_norm"] = { - "value": self.NATIVE_ARC_CHALLENGE_ACC_NORM, + "value": native_acc_norm, "floor_pct": DEFAULT_FLOOR_PCT, - "ceil_pct": DEFAULT_CEIL_PCT, + "ceil_pct": ceil_pct, } if arc_metrics: - normalized = self._normalize_task_identifier(EVAL.LM_EVAL.ARC_CHALLENGE) + normalized = self._normalize_task_identifier("arc_challenge") baselines[normalized] = arc_metrics - lookup = getattr(self, "_resolved_task_lookup", None) - if isinstance(lookup, dict): - lookup[normalized] = EVAL.LM_EVAL.ARC_CHALLENGE chat_lookup = getattr(self, "_task_chat_template", None) if isinstance(chat_lookup, dict): chat_lookup[normalized] = False @@ -260,12 +635,12 @@ def _normalize_metric_spec(self, spec): if isinstance(spec, dict): if "value" not in spec: raise ValueError("Baseline metric dictionaries must include a `value` key.") - value = spec["value"] + value = self._resolve_metric_baseline_value(spec["value"]) floor_pct = spec.get("floor_pct", spec.get("max_delta_floor_percent", default_floor)) ceil_pct = spec.get("ceil_pct", spec.get("max_delta_ceil_percent", default_ceil)) metric_key = spec.get("metric_key") else: - value = spec + value = self._resolve_metric_baseline_value(spec) floor_pct = default_floor ceil_pct = default_ceil metric_key = None @@ -284,19 +659,85 @@ def _normalize_metric_spec(self, spec): "metric_key": metric_key, } + @staticmethod + def _detect_cuda0_name(): + try: + if not torch.cuda.is_available(): + return None + return str(torch.cuda.get_device_name(0)) + except Exception: + return None + + @classmethod + def _detect_gpu_profile(cls): + cuda0_name = cls._detect_cuda0_name() + if not cuda0_name: + return None + + normalized = cuda0_name.lower() + if "a100" in normalized: + return "A100" + if "4090" in normalized: + return "RTX4090" + return None + + def _resolve_metric_baseline_value(self, value): + if not isinstance(value, Mapping): + return value + + normalized_lookup = { + self._normalize_gpu_profile_key(key): val + for key, val in value.items() + } + gpu_profile = self._detect_gpu_profile() + normalized_profile = self._normalize_gpu_profile_key(gpu_profile) + + if normalized_profile is not None and normalized_profile in normalized_lookup: + return normalized_lookup[normalized_profile] + + if "a100" in normalized_lookup: + return normalized_lookup["a100"] + + available = ", ".join(sorted(normalized_lookup.keys())) + raise ValueError( + "Unable to resolve GPU-specific baseline value. " + f"Detected profile={gpu_profile!r}, available profiles={available}." + ) + + @staticmethod + def _normalize_gpu_profile_key(profile): + if profile is None: + return None + normalized = str(profile).strip().lower().replace("-", "").replace("_", "").replace(" ", "") + if normalized == "a100": + return "a100" + if normalized in {"rtx4090", "4090"}: + return "rtx4090" + return normalized + def get_eval_tasks(self): - self._resolved_task_lookup = {} self._task_chat_template = {} - if self.EVAL_TASKS: + self._task_evalution_suite_kwargs = {} + self._task_evalution_model_args = {} + self._task_evalution_use_model_path = {} + self._task_evalution_batch_size = {} + eval_tasks = self._selected_eval_tasks_config() + if eval_tasks: baselines = {} - for task, metrics in self.EVAL_TASKS.items(): - resolved_task = self._resolve_task_enum(task) - normalized_task = self._normalize_task_identifier(resolved_task) - self._resolved_task_lookup[normalized_task] = resolved_task + for task, metrics in eval_tasks.items(): + normalized_task = self._normalize_task_identifier(task) metrics_dict = dict(metrics or {}) chat_template = bool(metrics_dict.pop("chat_template", False)) + evalution_suite_kwargs = dict(metrics_dict.pop("evalution_suite_kwargs", {}) or {}) + evalution_model_args = dict(metrics_dict.pop("evalution_model_args", {}) or {}) + evalution_use_model_path = bool(metrics_dict.pop("evalution_use_model_path", False)) + evalution_batch_size = metrics_dict.pop("evalution_batch_size", None) self._task_chat_template[normalized_task] = chat_template + self._task_evalution_suite_kwargs[normalized_task] = evalution_suite_kwargs + self._task_evalution_model_args[normalized_task] = evalution_model_args + self._task_evalution_use_model_path[normalized_task] = evalution_use_model_path + self._task_evalution_batch_size[normalized_task] = evalution_batch_size baselines[normalized_task] = { metric_name: self._normalize_metric_spec(spec) @@ -309,6 +750,10 @@ def get_eval_tasks(self): for task_name in baselines.keys(): if task_name not in self._task_chat_template: self._task_chat_template[task_name] = False + self._task_evalution_suite_kwargs.setdefault(task_name, {}) + self._task_evalution_model_args.setdefault(task_name, {}) + self._task_evalution_use_model_path.setdefault(task_name, False) + self._task_evalution_batch_size.setdefault(task_name, None) return baselines @staticmethod @@ -362,18 +807,74 @@ def generateChat(self, model, tokenizer, prompt=None): print(f"Result is: \n{output}") return output - def generate_with_limit(self, model, tokenizer, prompt, max_new_tokens=512): - inputs = tokenizer(prompt, return_tensors="pt").to(model.device) - pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + # Use this helper for CI output assertions instead of raw model.generate(), + # including in standalone unittest cases, so expected-text checks stay deterministic. + @staticmethod + def generate_stable_with_limit( + model, + tokenizer, + prompt=None, + max_new_tokens=512, + min_new_tokens=None, + skip_special_tokens=True, + inputs=None, + decode_start_idx=None, + batch_decode=False, + clean_up_tokenization_spaces=None, + return_generate_output=False, + **generate_kwargs, + ): + if inputs is None: + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + elif hasattr(inputs, "to"): + inputs = inputs.to(model.device) + + generation_inputs = dict(inputs) if isinstance(inputs, Mapping) else {"input_ids": inputs} + + decoder = getattr(tokenizer, "tokenizer", tokenizer) + pad_token_id = decoder.pad_token_id if decoder.pad_token_id is not None else decoder.eos_token_id generated = model.generate( - **inputs, + **generation_inputs, max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, do_sample=False, num_beams=1, pad_token_id=pad_token_id, - eos_token_id=tokenizer.eos_token_id, + eos_token_id=decoder.eos_token_id, + **generate_kwargs, + ) + if return_generate_output: + return generated + + generated_ids = generated[0] if isinstance(generated, tuple) else generated + + if batch_decode: + if decode_start_idx is None: + if hasattr(inputs, "input_ids"): + decode_start_idx = [len(input_ids) for input_ids in inputs.input_ids] + else: + raise ValueError("decode_start_idx is required for batch_decode when inputs lack input_ids") + + if isinstance(decode_start_idx, int): + generated_ids = [output_ids[decode_start_idx:] for output_ids in generated_ids] + else: + generated_ids = [ + output_ids[start_idx:] + for start_idx, output_ids in zip(decode_start_idx, generated_ids) + ] + + decode_kwargs = {"skip_special_tokens": skip_special_tokens} + if clean_up_tokenization_spaces is not None: + decode_kwargs["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces + return tokenizer.batch_decode(generated_ids, **decode_kwargs)[0] + + if decode_start_idx is None: + decode_start_idx = 0 + + return tokenizer.decode( + generated_ids[0][decode_start_idx:], + skip_special_tokens=skip_special_tokens, ) - return tokenizer.decode(generated[0], skip_special_tokens=True) def run_generic_inference_checks(self, model, tokenizer, backend): model.eval() @@ -383,7 +884,14 @@ def run_generic_inference_checks(self, model, tokenizer, backend): prompt = item["prompt"] keywords = item["keywords"] try: - response = self.generate_with_limit(model, tokenizer, prompt) + inputs, decode_start_idx = self._prepare_generic_inference_inputs(tokenizer, prompt) + response = self.generate_stable_with_limit( + model, + tokenizer, + prompt, + inputs=inputs, + decode_start_idx=decode_start_idx, + ) normalized = response.lower() matched = any(keyword.lower() in normalized for keyword in keywords) results.append( @@ -415,11 +923,29 @@ def run_generic_inference_checks(self, model, tokenizer, backend): ) return results + def _prepare_generic_inference_inputs(self, tokenizer, prompt): + # Some chat-tuned checkpoints only produce stable continuations when the + # sanity prompts are wrapped with the tokenizer's chat template. + if not self.APPLY_CHAT_TEMPLATE or not hasattr(tokenizer, "apply_chat_template"): + return None, None + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + inputs = tokenizer([text], return_tensors="pt") + return inputs, inputs.input_ids.shape[1] + def run_eval_tasks(self, model, backend, trust_remote_code=False): previous_backend = self.LOAD_BACKEND self.LOAD_BACKEND = backend try: - task_results = self.lm_eval( + task_results = self.evaluate_model( model=model, trust_remote_code=self.TRUST_REMOTE_CODE, delete_quantized_model=False, @@ -436,7 +962,13 @@ def _current_load_backend(self): return self.LOAD_BACKEND def _torch_backend(self) -> BACKEND: - return BACKEND.TORCH_AWQ if self.METHOD == METHOD.AWQ else BACKEND.TORCH + if self.METHOD == METHOD.AWQ: + return BACKEND.TORCH_AWQ + if self.METHOD == METHOD.PARO: + return BACKEND.PARO + if self.METHOD == METHOD.GGUF: + return BACKEND.GGUF_TORCH + return BACKEND.TORCH def _torch_fused_backend(self) -> BACKEND: return BACKEND.TORCH_FUSED_AWQ if self.METHOD == METHOD.AWQ else BACKEND.TORCH_FUSED @@ -447,8 +979,19 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): reuse_candidates = {} torch_backend = self._torch_backend() torch_fused_backend = self._torch_fused_backend() - - if self.FORMAT is FORMAT.GPTQ: + format_family = resolve_quant_format(self.FORMAT, self.METHOD) + + if format_family == FORMAT.GGUF: + compare_backends = (torch_backend,) + elif format_family == FORMAT.BITSANDBYTES: + compare_backends = (BACKEND.BITSANDBYTES,) + elif format_family == FORMAT.FP8: + compare_backends = (torch_backend,) + elif format_family == FORMAT.EXL3: + compare_backends = (self.LOAD_BACKEND,) + elif format_family == FORMAT.PAROQUANT: + compare_backends = (self.LOAD_BACKEND,) + elif format_family == FORMAT.GPTQ: if self.LOAD_BACKEND == BACKEND.MARLIN: compare_backends = (BACKEND.MARLIN,) else: @@ -458,13 +1001,13 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): fallback_backend = None if BACKEND.MARLIN in compare_backends: try: - from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # type: ignore + from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear # type: ignore except Exception: # pragma: no cover - fallback if module unavailable marlin_group_sizes = () marlin_sym = () else: - marlin_group_sizes = tuple(getattr(MarlinQuantLinear, "SUPPORTS_GROUP_SIZE", ())) - marlin_sym = tuple(getattr(MarlinQuantLinear, "SUPPORTS_SYM", ())) + marlin_group_sizes = tuple(getattr(MarlinLinear, "SUPPORTS_GROUP_SIZE", ())) + marlin_sym = tuple(getattr(MarlinLinear, "SUPPORTS_SYM", ())) requested_group_size = getattr(self, "GROUP_SIZE", None) requested_sym = getattr(self, "SYM", None) @@ -496,7 +1039,7 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): for backend in compare_backends: log.info(f"Loading post-quant model with backend `{backend.name}`") - # When EVAL_SINGLE_GPU is enabled, pin post-quant loads to the first CUDA device to avoid auto sharding. + # When EVAL_SINGLE_GPU is enabled, keep post-quant validation on the preferred device. use_cuda_map = ( self.EVAL_SINGLE_GPU and torch.cuda.is_available() @@ -514,9 +1057,11 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): model_path, trust_remote_code=trust_remote_code, backend=backend, - ) - tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code) - inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend) + ) + model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code) + # Pre-evaluation smoke prompts are intentionally disabled to keep quantization tests + # focused only on task execution. + # inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend) should_reuse = can_reuse and backend == target_backend and not self.USE_VLLM @@ -764,11 +1309,15 @@ def load_tokenizer(cls, model_id_or_path, trust_remote_code=False): @classmethod def load_dataset(cls, tokenizer=None, rows: int = 0): - try: - dataset = load_dataset(path="/monster/data/model/dataset/nm-calibration", name="LLM", split="train") - except Exception as exc: # pragma: no cover - exercised in fallbacks - log.warning("load_dataset failed; falling back to local parquet: %s", exc) + if hf_load_dataset is None: + log.warning("datasets.load_dataset unavailable; falling back to local parquet: %s", DATASETS_IMPORT_ERROR) dataset = cls._load_calibration_parquet() + else: + try: + dataset = hf_load_dataset(path="/monster/data/model/dataset/nm-calibration", name="LLM", split="train") + except Exception as exc: # pragma: no cover - exercised in fallbacks + log.warning("load_dataset failed; falling back to local parquet: %s", exc) + dataset = cls._load_calibration_parquet() if rows > 0: return dataset.select(range(min(rows, len(dataset)))) @@ -841,20 +1390,110 @@ def check_kernel(self, model, expected_kernels): if expected_kernels: assert modules == expected_kernels, f"kernels are different with expected. found: {modules}. expected: {expected_kernels}" - def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", need_eval=True, batch_size: int = QUANT_BATCH_SIZE, call_perform_post_quant_validation: bool = True, **kwargs): - quantize_config = QuantizeConfig( + def _build_quantize_config(self): + format_family = resolve_quant_format(self.FORMAT, self.METHOD) + + if self.WEIGHT_ONLY is None: + if self.METHOD == METHOD.BITSANDBYTES: + return BitsAndBytesConfig( + bits=self.BITS, + format=self.BNB_FORMAT, + block_size=self.BNB_BLOCK_SIZE, + compress_statistics=self.BNB_COMPRESS_STATISTICS, + adapter=self.EORA, + dense_vram_strategy=self.DENSE_VRAM_STRATEGY, + dense_vram_strategy_devices=self.DENSE_VRAM_STRATEGY_DEVICES, + moe_vram_strategy=self.MOE_VRAM_STRATEGY, + moe_vram_strategy_devices=self.MOE_VRAM_STRATEGY_DEVICES, + dynamic=self.DYNAMIC, + moe=self.MOE_CONFIG, + offload_to_disk=self.OFFLOAD_TO_DISK, + ) + elif self.METHOD == METHOD.PARO: + return ParoConfig( + bits=self.BITS, + method=METHOD.PARO, + format=FORMAT.PAROQUANT, + opt_rotation_epochs=self.PAROQUANT_ROTATION_EPOCHS, + opt_finetune_epochs=self.PAROQUANT_FINETUNE_EPOCHS, + opt_train_samples=self.PAROQUANT_TRAIN_SAMPLES, + adapter=self.EORA, + dense_vram_strategy=self.DENSE_VRAM_STRATEGY, + dense_vram_strategy_devices=self.DENSE_VRAM_STRATEGY_DEVICES, + moe_vram_strategy=self.MOE_VRAM_STRATEGY, + moe_vram_strategy_devices=self.MOE_VRAM_STRATEGY_DEVICES, + dynamic=self.DYNAMIC, + moe=self.MOE_CONFIG, + offload_to_disk=self.OFFLOAD_TO_DISK, + ) + + if self.WEIGHT_ONLY is not None: + if not isinstance(self.WEIGHT_ONLY, WeightOnlyConfig): + raise TypeError(f"`WEIGHT_ONLY` must be a WeightOnlyConfig, got {type(self.WEIGHT_ONLY).__name__}") + + if format_family == FORMAT.GGUF or self.WEIGHT_ONLY.method.value == "gguf": + return GGUFConfig( + bits=self.BITS, + adapter=self.EORA, + pack_impl="cpu", + dense_vram_strategy=self.DENSE_VRAM_STRATEGY, + dense_vram_strategy_devices=self.DENSE_VRAM_STRATEGY_DEVICES, + moe_vram_strategy=self.MOE_VRAM_STRATEGY, + moe_vram_strategy_devices=self.MOE_VRAM_STRATEGY_DEVICES, + dynamic=self.DYNAMIC, + moe=self.MOE_CONFIG, + smoother=self.WEIGHT_ONLY.smooth, + ) + + if format_family == FORMAT.FP8 or self.WEIGHT_ONLY.method.value == "fp8": + return FP8Config( + bits=self.BITS, + format=self.FORMAT, + adapter=self.EORA, + pack_impl="cpu", + dense_vram_strategy=self.DENSE_VRAM_STRATEGY, + dense_vram_strategy_devices=self.DENSE_VRAM_STRATEGY_DEVICES, + moe_vram_strategy=self.MOE_VRAM_STRATEGY, + moe_vram_strategy_devices=self.MOE_VRAM_STRATEGY_DEVICES, + dynamic=self.DYNAMIC, + moe=self.MOE_CONFIG, + smoother=self.WEIGHT_ONLY.smooth, + ) + + return RTNConfig( + bits=self.BITS, + group_size=self.GROUP_SIZE, + desc_act=self.DESC_ACT, + sym=self.SYM, + format=self.FORMAT, + adapter=self.EORA, + pack_impl="cpu", + dense_vram_strategy=self.DENSE_VRAM_STRATEGY, + dense_vram_strategy_devices=self.DENSE_VRAM_STRATEGY_DEVICES, + moe_vram_strategy=self.MOE_VRAM_STRATEGY, + moe_vram_strategy_devices=self.MOE_VRAM_STRATEGY_DEVICES, + dynamic=self.DYNAMIC, + moe=self.MOE_CONFIG, + smooth=self.WEIGHT_ONLY.smooth, + ) + + return QuantizeConfig( quant_method=self.METHOD, format=self.FORMAT, bits=self.BITS, group_size=self.GROUP_SIZE, desc_act=self.DESC_ACT if not self.ACT_GROUP_AWARE else False, act_group_aware=self.ACT_GROUP_AWARE, - failsafe=self.FAILSAFE, + fallback=self.FALLBACK, sym=self.SYM, - gptaq=GPTAQConfig() if self.GPTQA else None, + gptaq=copy.deepcopy(self.GPTAQ), + foem=copy.deepcopy(self.FOEM), adapter=self.EORA, pack_impl="cpu", - vram_strategy=self.VRAM_STRATEGY, + dense_vram_strategy=self.DENSE_VRAM_STRATEGY, + dense_vram_strategy_devices=self.DENSE_VRAM_STRATEGY_DEVICES, + moe_vram_strategy=self.MOE_VRAM_STRATEGY, + moe_vram_strategy_devices=self.MOE_VRAM_STRATEGY_DEVICES, damp_percent=self.DAMP_PERCENT, mse=self.MSE, dynamic=self.DYNAMIC, @@ -863,6 +1502,10 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne offload_to_disk=self.OFFLOAD_TO_DISK, ) + def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", need_eval=True, batch_size: int = QUANT_BATCH_SIZE, call_perform_post_quant_validation: bool = True, **kwargs): + """Return `(model, tokenizer, processor)`; `processor` is `None` for text-only models.""" + quantize_config = self._build_quantize_config() + log.info(f"Quant config: {quantize_config}") log.info(f"Quant batch_size: {batch_size}") @@ -873,6 +1516,8 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne args["attn_implementation"] = "flash_attention_2" else: log.warn("flash-attn requested but not available; falling back to framework defaults") + else: + args["attn_implementation"] = "eager" log.info(f"args: {args}") @@ -882,7 +1527,11 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne quantize_config=quantize_config, trust_remote_code=trust_remote_code, dtype=dtype, - device_map={"": "cpu"} if self.LOAD_BACKEND == torch_fused_backend else "auto", + device_map=( + {"": "cpu"} + if self.LOAD_BACKEND == torch_fused_backend + else "auto" + ), **args, ) @@ -894,10 +1543,19 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne tokenizer = model.tokenizer self._post_quant_eval_records = {} self._effective_load_backend = None + self._model_compat_fast_dynamic = None + # Tracks whether quantModel() loaded an existing quantized checkpoint + # instead of producing a fresh post-quant artifact + Evalution cache. + self._loaded_model_was_prequantized = False processor = None + self._apply_model_compat_quant_overrides(model) + is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality - calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) + if quantize_config.requires_calibration_dataset(): + calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) + else: + calibration_dataset = None # mpt model need if hasattr(model.config, "pad_token_id") and not model.config.pad_token_id: @@ -906,15 +1564,13 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne model.config.eos_token_id = tokenizer.eos_token_id or 0 is_quantized = model.quantized + self._loaded_model_was_prequantized = bool(is_quantized) # ovis cannot load processor is_ovis_model = model.config.model_type == "ovis" need_create_processor = is_image_to_text_model and not is_ovis_model - debug_short_circuit = False if not is_quantized: - save_context = None - planned_save_path = None cleanup_callback = None try: save_context, planned_save_path, cleanup_callback = self._prepare_quant_save_destination(need_eval) @@ -940,7 +1596,6 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne model=model, tokenizer=tokenizer, processor=None, - need_create_processor=need_create_processor, cleanup_callback=cleanup_callback, ) @@ -950,7 +1605,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne os.makedirs(path, exist_ok=True) self.clear_directory(path) - model.save(path) + model.save(path, split_by=self.SPLIT_BY) self._print_post_quant_artifacts(path) reuse_candidates = {} @@ -983,7 +1638,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne q_tokenizer = q_model.tokenizer or self.load_tokenizer(path, trust_remote_code=trust_remote_code) if need_create_processor: - processor = AutoProcessor.from_pretrained(path) + processor = AutoProcessor.from_pretrained(path, trust_remote_code=trust_remote_code) except Exception: if cleanup_callback is not None: try: @@ -998,15 +1653,9 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne if not is_quantized: del model torch_empty_cache() - if need_create_processor: - return q_model, q_tokenizer, processor - else: - return q_model, q_tokenizer + return q_model, q_tokenizer, processor else: - if need_create_processor: - return model, tokenizer, processor - else: - return model, tokenizer + return model, tokenizer, processor def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, backend=None, **args): @@ -1020,6 +1669,8 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa load_kwargs["attn_implementation"] = "flash_attention_2" else: log.warn("flash-attn requested but not available; falling back to framework defaults") + else: + load_kwargs["attn_implementation"] = "eager" active_backend = backend if backend is not None else self._current_load_backend() torch_fused_backend = self._torch_fused_backend() @@ -1059,7 +1710,7 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa return model - def lm_eval(self, model, trust_remote_code=False, delete_quantized_model=False, extra_args:dict=None): + def evaluate_model(self, model, trust_remote_code=False, delete_quantized_model=False, extra_args:dict=None): try: task_names = self._normalize_task_list() aggregated_results = {} @@ -1069,40 +1720,40 @@ def lm_eval(self, model, trust_remote_code=False, delete_quantized_model=False, model_path = model if self.USE_VLLM: - tensor_parallel = 1 - if not self.EVAL_SINGLE_GPU: - try: - candidate = torch.cuda.device_count() - except Exception: - candidate = 1 - tensor_parallel = max(1, candidate) - model_args = { - "pretrained": model_path, - "dtype": "auto", #"float16", - "gpu_memory_utilization": 0.8, - "tensor_parallel_size": tensor_parallel, - "trust_remote_code": trust_remote_code, - "max_model_len": self.MODEL_MAX_LEN - } - else: - model_args = {} + raise ValueError("ModelTest USE_VLLM is no longer supported; evaluation is delegated to Evalution.") + + model_args = {} if extra_args: model_args.update(extra_args) - from lm_eval.tasks import TaskManager - from lm_eval.utils import make_table - - task_groups = EVAL.get_task_groups_from_tasks(task_names) - chat_template_lookup = getattr(self, "_task_chat_template", {}) or {} - - for framework, tasks in task_groups.items(): - active_backend = self._current_load_backend() - log.info(f"TEST: EVAL starting: backend = {active_backend.name}") - if model_path: - log.info(f"Inference from model path: {model_path}") - - if isinstance(model, BaseQModel) and not self.USE_VLLM: + suite_kwargs_lookup = getattr(self, "_task_evalution_suite_kwargs", {}) or {} + task_model_args_lookup = getattr(self, "_task_evalution_model_args", {}) or {} + use_model_path_lookup = getattr(self, "_task_evalution_use_model_path", {}) or {} + eval_batch_size_lookup = getattr(self, "_task_evalution_batch_size", {}) or {} + active_backend = self._current_load_backend() + log.info(f"TEST: Evalution starting: backend = {active_backend.name}") + if model_path: + log.info(f"Inference from model path: {model_path}") + + for task_name in task_names: + normalized_name = self._normalize_task_identifier(task_name) + apply_chat_template = bool(chat_template_lookup.get(normalized_name, False)) + task_model_args = dict(model_args) + task_model_args.update(task_model_args_lookup.get(normalized_name, {}) or {}) + # Keep evalution-backed generation reproducible even when a task opts + # into sampling or an engine backend introduces RNG-sensitive paths. + task_model_args.setdefault("seed", RAND_SEED) + task_model_args.setdefault("random_seed", RAND_SEED) + task_suite_kwargs = dict(suite_kwargs_lookup.get(normalized_name, {}) or {}) + task_batch_size = eval_batch_size_lookup.get(normalized_name) + if task_batch_size is None: + task_batch_size = self.EVAL_BATCH_SIZE + use_model_path = bool(use_model_path_lookup.get(normalized_name, False)) + + if use_model_path and model_path: + eval_target = model_path + elif isinstance(model, BaseQModel): eval_target = model else: eval_target = model_path @@ -1110,56 +1761,32 @@ def lm_eval(self, model, trust_remote_code=False, delete_quantized_model=False, if eval_target is None: raise ValueError("Model evaluation target could not be determined.") - resolved_lookup = getattr(self, "_resolved_task_lookup", {}) - eval_tasks = [] - for task in tasks: - original_task = resolved_lookup.get(task) - if original_task is None: - original_task = self._resolve_task_enum(task) - if isinstance(resolved_lookup, dict): - normalized_task = self._normalize_task_identifier(original_task) - resolved_lookup[normalized_task] = original_task - eval_tasks.append(original_task) - - grouped_tasks: Dict[bool, List] = {} - for task in eval_tasks: - normalized_name = self._normalize_task_identifier(task) - apply_chat = bool(chat_template_lookup.get(normalized_name, False)) - grouped_tasks.setdefault(apply_chat, []).append(task) - - for apply_chat_template, grouped in grouped_tasks.items(): - results = GPTQModel.eval( - model_or_id_or_path=eval_target, - llm_backend="vllm" if self.USE_VLLM else "gptqmodel", - model_args=model_args, - output_path=tmp_dir, - backend=active_backend, - framework=framework, - tasks=grouped, - apply_chat_template=apply_chat_template, - trust_remote_code=trust_remote_code, - batch_size=self.EVAL_BATCH_SIZE, - gen_kwargs="temperature=0.0,top_k=50", - random_seed=RAND_SEED, - task_manager=TaskManager(include_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "../tasks"), include_defaults=False) - ) + results = evaluate( + model_or_id_or_path=eval_target, + model_args=task_model_args, + output_path=tmp_dir, + backend=active_backend, + tasks=[normalized_name], + apply_chat_template=apply_chat_template, + trust_remote_code=trust_remote_code, + batch_size=task_batch_size, + gen_kwargs="do_sample=false,temperature=0.0,top_p=1.0,top_k=50", + suite_kwargs=task_suite_kwargs, + ) - print('--------Eval Result---------') - print(make_table(results)) - if "groups" in results: - print(make_table(results, "groups")) - print('--------Eval Result End---------') - - for task_name in grouped: - normalized_task_name = self._normalize_task_identifier(task_name) - metrics = results["results"].get(normalized_task_name, {}) - filtered_metrics = { - metric: value - for metric, value in metrics.items() - if metric != "alias" and "stderr" not in metric - } - aggregated_results[normalized_task_name] = filtered_metrics - print({normalized_task_name: filtered_metrics}) + print('--------Eval Result---------') + print(format_eval_result_table(results)) + print('--------Eval Result End---------') + + result_metrics = get_eval_task_results(results) + metrics = result_metrics.get(normalized_name, {}) + filtered_metrics = { + metric: value + for metric, value in metrics.items() + if metric != "alias" and "stderr" not in metric + } + aggregated_results[normalized_name] = filtered_metrics + print({normalized_name: filtered_metrics}) self._cleanup_quantized_model(model, enabled=delete_quantized_model) return aggregated_results @@ -1175,11 +1802,12 @@ def lm_eval(self, model, trust_remote_code=False, delete_quantized_model=False, print(f"batch {old_batch} OOM, retrying with batch {self.EVAL_BATCH_SIZE}") if int(self.EVAL_BATCH_SIZE) > 0: - self.lm_eval(model=model, - trust_remote_code=trust_remote_code, - delete_quantized_model=delete_quantized_model, - extra_args=extra_args) + results = self.evaluate_model(model=model, + trust_remote_code=trust_remote_code, + delete_quantized_model=delete_quantized_model, + extra_args=extra_args) print(f"set batch size to {self.EVAL_BATCH_SIZE}, passed") + return results else: print(f"set batch size to {self.EVAL_BATCH_SIZE}, failed") raise e @@ -1191,32 +1819,134 @@ def calculatorPer(self, task_name, metric_name, value, expected): log.info(f"{task_name}:{metric_name}: `{value}` vs `{expected}` diff {diff_pct:.2f}%") return diff_pct, expected - def quant_lm_eval(self): + @staticmethod + def _metric_within_expected_range(value, expected, floor_pct, ceil_pct): + diff_pct = (value / expected) * 100 + negative_pct = 100 * (1 - floor_pct) + positive_pct = 100 * (1 + ceil_pct) + return negative_pct <= diff_pct <= positive_pct, diff_pct, negative_pct, positive_pct + + def _current_native_backend(self) -> BACKEND: + return BACKEND.TORCH + + def _get_current_native_eval_results(self): + cached = getattr(self, "_current_native_eval_results", None) + if cached is not None: + return cached + + native_model_id = getattr(self, "NATIVE_MODEL_ID", None) + if not native_model_id: + return None + + previous_backend = self.LOAD_BACKEND + previous_effective_backend = getattr(self, "_effective_load_backend", None) + self.LOAD_BACKEND = self._current_native_backend() + self._effective_load_backend = None + try: + log.warn( + "Baseline fallback: evaluating current native model `%s` to verify whether stored expectations are stale.", + native_model_id, + ) + cached = self.evaluate_model( + model=native_model_id, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=False, + ) + finally: + self.LOAD_BACKEND = previous_backend + self._effective_load_backend = previous_effective_backend + + self._current_native_eval_results = cached + return cached + + def _maybe_accept_current_native_baseline( + self, + *, + task_name: str, + metric_name: str, + metric_key: str, + value: float, + floor_pct: float, + ceil_pct: float, + ) -> bool: + try: + native_results = self._get_current_native_eval_results() + except Exception as exc: # pragma: no cover - defensive fallback for flaky native eval + log.warn(f"Baseline fallback: failed to evaluate current native model: {exc}") + return False + + if not isinstance(native_results, dict): + return False + + native_metrics = native_results.get(task_name) + if not isinstance(native_metrics, dict): + return False + + native_metric_key = self._resolve_metric_key(metric_key, native_metrics) + if native_metric_key is None and metric_key != metric_name: + native_metric_key = self._resolve_metric_key(metric_name, native_metrics) + if native_metric_key is None: + return False + + native_value = native_metrics[native_metric_key] + passed, diff_pct, negative_pct, positive_pct = self._metric_within_expected_range( + value=value, + expected=native_value, + floor_pct=floor_pct, + ceil_pct=ceil_pct, + ) + if not passed: + return False + + log.warn( + f"Baseline fallback: accepting `{task_name}:{metric_name}` using current native value `{native_value}`; " + f"quantized result `{value}` diff {diff_pct:.2f}% is within [{negative_pct:.2f}-{positive_pct:.2f}] " + f"while stored expectation appears stale." + ) + return True + + def quantize_and_evaluate(self): self.model = None # TODO fix me: LOAD_QUANTIZED_MODEL doesn't make any sense when we have QUANT_SAVE_PATH #if self.QUANT_SAVE_PATH: - # self.model, _ = self.quantModel(self.QUANT_SAVE_PATH, batch_size=self.QUANT_BATCH_SIZE, trust_remote_code=self.TRUST_REMOTE_CODE, dtype=self.TORCH_DTYPE) + # self.model, _, _ = self.quantModel(self.QUANT_SAVE_PATH, batch_size=self.QUANT_BATCH_SIZE, trust_remote_code=self.TRUST_REMOTE_CODE, dtype=self.TORCH_DTYPE) - if not self.model: - self.model, _ = self.quantModel(self.NATIVE_MODEL_ID, batch_size=self.QUANT_BATCH_SIZE, trust_remote_code=self.TRUST_REMOTE_CODE, dtype=self.TORCH_DTYPE) + log.info("Model compat test mode: %s", self._model_test_mode()) + with self.model_compat_test_context(): + if not self.model: + self.model, _, _ = self.quantModel(self.NATIVE_MODEL_ID, batch_size=self.QUANT_BATCH_SIZE, trust_remote_code=self.TRUST_REMOTE_CODE, dtype=self.TORCH_DTYPE) self.check_kernel(self.model, self.KERNEL_INFERENCE) if self._debug_layer_stop_triggered(): - log.info("DEBUG mode: skipping lm_eval and baseline checks after early layer stop.") + log.info("DEBUG mode: skipping evaluation and baseline checks after early layer stop.") return eval_records = getattr(self, "_post_quant_eval_records", {}) target_backend = self._current_load_backend() if eval_records and len(eval_records) == 1 and target_backend in eval_records: - log.info("Reusing evaluation results for backend `%s`; skipping duplicate lm_eval run", target_backend.name) + log.info("Reusing evaluation results for backend `%s`; skipping duplicate evaluation run", target_backend.name) task_results = eval_records[target_backend] else: - task_results = self.lm_eval( - model=self.SAVE_PATH if self.SAVE_PATH else self.model, - trust_remote_code=self.TRUST_REMOTE_CODE, - delete_quantized_model=self.DELETE_QUANTIZED_MODEL, - ) + task_results = eval_records.get(target_backend) + if task_results is None: + if getattr(self, "_loaded_model_was_prequantized", False): + log.info( + "Loaded checkpoint was already quantized; running Evalution directly for backend `%s`.", + target_backend.name, + ) + with self.model_compat_test_context(): + task_results = self.evaluate_model( + model=self.SAVE_PATH if self.SAVE_PATH else self.model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=False, + ) + self._post_quant_eval_records[target_backend] = task_results + else: + raise AssertionError( + "Post-quant eval results were not produced. " + "The Stage-2 evaluation fallback is disabled." + ) self.check_results(task_results) self._cleanup_quantized_model(self.model, enabled=self.DELETE_QUANTIZED_MODEL) @@ -1225,41 +1955,87 @@ def check_results(self, task_results): if not baselines: raise AssertionError("No evaluation baselines configured for result validation.") + errors = [] + diffs = [] + for task_name, expected_metrics in baselines.items(): metrics = task_results.get(task_name) + if metrics is None: - self.fail(f"No evaluation results returned for task `{task_name}`") + errors.append(f"No evaluation results returned for task `{task_name}`") + continue + if not isinstance(metrics, dict): - raise TypeError(f"Expected metrics for task `{task_name}` to be a dictionary, got {type(metrics).__name__}") + raise TypeError( + f"Expected metrics for task `{task_name}` to be a dictionary, got {type(metrics).__name__}" + ) for metric_name, baseline_spec in expected_metrics.items(): metric_key = baseline_spec.get("metric_key") or metric_name metric_key = self._resolve_metric_key(metric_key, metrics) + if metric_key is None: - self.fail(f"Metric `{metric_name}` missing from results for task `{task_name}`") + errors.append(f"Metric `{metric_name}` missing from results for task `{task_name}`") + continue value = metrics[metric_key] expected_value = baseline_spec["value"] + diff_pct, expected_value = self.calculatorPer( task_name=task_name, metric_name=metric_name, value=value, expected=expected_value, ) + floor_pct = baseline_spec["floor_pct"] ceil_pct = baseline_spec["ceil_pct"] - negative_pct = 100 * (1 - floor_pct) - positive_pct = 100 * (1 + ceil_pct) - self.assertTrue( - negative_pct <= diff_pct <= positive_pct, - f"{task_name}:{metric_name}: `{value}` vs expected `{expected_value}`, " - f"diff {diff_pct:.2f}% is out of the expected range [{negative_pct}-{positive_pct}%]", + passed, diff_pct, negative_pct, positive_pct = self._metric_within_expected_range( + value=value, + expected=expected_value, + floor_pct=floor_pct, + ceil_pct=ceil_pct, + ) + diffs.append( + f"{task_name}:{metric_name} -> value={value}, expected={expected_value}, diff={diff_pct:.2f}% " + f"(allowed [{negative_pct}-{positive_pct}%])" ) + if passed: + continue + if self.DISABLE_NATIVE_BASELINE_FALLBACK: + continue + if self._maybe_accept_current_native_baseline( + task_name=task_name, + metric_name=metric_name, + metric_key=metric_key, + value=value, + floor_pct=floor_pct, + ceil_pct=ceil_pct, + ): + continue + + if not (negative_pct <= diff_pct <= positive_pct): + errors.append( + f"{task_name}:{metric_name} out of range: `{value}` vs expected `{expected_value}`, " + f"diff {diff_pct:.2f}% not in [{negative_pct}-{positive_pct}%]" + ) + + print("\nEvaluation diff summary:") + for d in diffs: + print(d) + + if errors: + raise AssertionError( + "Evaluation failed:\n" + "\n".join(errors) + ) @staticmethod def _resolve_metric_key(metric_name, metrics): if metric_name in metrics: return metric_name + alias = resolve_eval_metric_alias(metric_name, metrics) + if alias is not None: + return alias if metric_name is None: return None # if baseline uses canonical name without suffix, look for variants like acc,none diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py index b81be9714..9cca990cd 100644 --- a/tests/models/ovis/image_to_test_dataset.py +++ b/tests/models/ovis/image_to_test_dataset.py @@ -5,6 +5,8 @@ from gptqmodel.models.definitions.base_qwen2_5_omni import BaseQwen2_5_OmniGPTQ from gptqmodel.models.definitions.base_qwen2_vl import BaseQwen2VLGPTQ +from gptqmodel.models.definitions.minicpm_o import MiniCPMOQModel +from gptqmodel.models.definitions.minicpm_v import MiniCPMVQModel from gptqmodel.models.definitions.ovis import OvisQModel from gptqmodel.models.definitions.ovis2 import Ovis2QModel from gptqmodel.models.definitions.qwen3_vl import Qwen3_VLQModel @@ -90,13 +92,10 @@ def get_calib_dataset(model): if isinstance(model, Ovis2QModel): return prepare_dataset(format_ovis2_dataset, n_sample=20) - if isinstance(model, BaseQwen2VLGPTQ): + if isinstance(model, BaseQwen2VLGPTQ) or isinstance(model, Qwen3_VLQModel) or isinstance(model, MiniCPMOQModel) or isinstance(model, MiniCPMVQModel): return prepare_dataset(format_qwen2_vl_dataset, n_sample=20) if isinstance(model, BaseQwen2_5_OmniGPTQ): return prepare_dataset(format_qwen2_5_omni_dataset, n_sample=20) - if isinstance(model, Qwen3_VLQModel): - return prepare_dataset(format_qwen2_vl_dataset, n_sample=20) - raise NotImplementedError(f"Unsupported MODEL: {model.__class__}") diff --git a/tests/models/paroquant_first_layer_case_helper.py b/tests/models/paroquant_first_layer_case_helper.py new file mode 100644 index 000000000..ab637c9a1 --- /dev/null +++ b/tests/models/paroquant_first_layer_case_helper.py @@ -0,0 +1,107 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +from typing import Any + +from model_test import _env_choice, _env_flag, _env_int + +from gptqmodel.utils.paroquant_benchmark import run_paroquant_first_layer_case + + +def resolve_paroquant_first_layer_case_env( + *, + env_prefix: str, + default_num_quant_layers: int, + default_opt_scope: str, +) -> dict[str, Any]: + """Resolve a shared ParoQuant first-layer integration env schema.""" + num_quant_layers = _env_int(f"{env_prefix}_NUM_LAYERS", default=default_num_quant_layers) + opt_scope = _env_choice(f"{env_prefix}_OPT_SCOPE", default=default_opt_scope) + eval_max_rows = os.environ.get(f"{env_prefix}_MAX_ROWS") + eval_batch = _env_choice(f"{env_prefix}_EVAL_BATCH", default="auto") + eval_batch_size: int | str = "auto" if eval_batch == "auto" else int(eval_batch) + resolved_eval_max_rows = 128 if eval_max_rows in {None, ""} else int(eval_max_rows) + return { + "num_quant_layers": num_quant_layers, + "calibration_rows": _env_int(f"{env_prefix}_CAL_ROWS", default=32), + "eval_batch_size": eval_batch_size, + "eval_max_rows": resolved_eval_max_rows, + "eval_model_args": { + "dtype": os.environ.get(f"{env_prefix}_EVAL_DTYPE", "bfloat16"), + "attn_implementation": os.environ.get( + f"{env_prefix}_ATTN_IMPL", + "paged|flash_attention_2", + ), + "device": os.environ.get(f"{env_prefix}_EVAL_DEVICE", "cuda:0"), + }, + "eval_suite_kwargs": { + "batch_size": _env_int(f"{env_prefix}_SUITE_BATCH", default=24), + "max_new_tokens": _env_int(f"{env_prefix}_MAX_NEW_TOKENS", default=96), + "stream": _env_flag(f"{env_prefix}_STREAMING", default=True), + "max_rows": resolved_eval_max_rows, + }, + "sym": True, + "fused_opt_rotation": _env_flag(f"{env_prefix}_FUSED", default=True), + "opt_scope": opt_scope, + "opt_rotation_epochs": _env_int(f"{env_prefix}_ROT_EPOCHS", default=4), + "opt_finetune_epochs": _env_int(f"{env_prefix}_FT_EPOCHS", default=4), + "opt_train_samples": _env_int(f"{env_prefix}_TRAIN_ROWS", default=512), + "opt_validation_samples": _env_int(f"{env_prefix}_VAL_ROWS", default=64), + "opt_batch_size": _env_int(f"{env_prefix}_OPT_BATCH", default=16), + } + + +def run_paroquant_first_layer_case_from_resolved(resolved: dict[str, Any]) -> dict[str, Any]: + """Run a ParoQuant first-layer integration case from resolved shared options.""" + return run_paroquant_first_layer_case( + num_quant_layers=resolved["num_quant_layers"], + calibration_rows=resolved["calibration_rows"], + eval_batch_size=resolved["eval_batch_size"], + eval_max_rows=resolved["eval_max_rows"], + eval_model_args=resolved["eval_model_args"], + eval_suite_kwargs=resolved["eval_suite_kwargs"], + sym=resolved["sym"], + fused_opt_rotation=resolved["fused_opt_rotation"], + opt_scope=resolved["opt_scope"], + opt_rotation_epochs=resolved["opt_rotation_epochs"], + opt_finetune_epochs=resolved["opt_finetune_epochs"], + opt_train_samples=resolved["opt_train_samples"], + opt_validation_samples=resolved["opt_validation_samples"], + opt_batch_size=resolved["opt_batch_size"], + ) + + +def run_paroquant_first_layer_case_from_env( + *, + env_prefix: str, + default_num_quant_layers: int, + default_opt_scope: str, +) -> tuple[dict[str, Any], dict[str, Any]]: + """Build and run a ParoQuant first-layer integration case from a shared env schema.""" + resolved = resolve_paroquant_first_layer_case_env( + env_prefix=env_prefix, + default_num_quant_layers=default_num_quant_layers, + default_opt_scope=default_opt_scope, + ) + + result = run_paroquant_first_layer_case_from_resolved(resolved) + return result, resolved + + +def assert_basic_paroquant_first_layer_result( + result: dict[str, Any], + *, + num_quant_layers: int, + opt_scope: str, +) -> dict[str, Any]: + """Assert the common success conditions for ParoQuant first-layer integration runs.""" + assert result["module_time_rows"], "expected per-module quantization timings" + assert result["num_quant_layers"] == num_quant_layers + assert result["opt_scope"] == opt_scope + assert "gsm8k_platinum_cot" in result["eval_metrics"], "expected gsm8k platinum metrics" + gsm8k_metrics = result["eval_metrics"]["gsm8k_platinum_cot"] + assert "acc,num" in gsm8k_metrics, "expected gsm8k_platinum_cot acc,num metric" + return gsm8k_metrics diff --git a/tests/models/paroquant_optimize_case.py b/tests/models/paroquant_optimize_case.py new file mode 100644 index 000000000..dd872bd1e --- /dev/null +++ b/tests/models/paroquant_optimize_case.py @@ -0,0 +1,215 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import os + +import torch +from model_test import ModelTest, _env_choice, _env_flag, _env_int, _env_optional_flag + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear +from gptqmodel.quantization import FORMAT, METHOD, ParoConfig + + +def _resolve_save_path(scope_env: str, default: str) -> str: + """Resolve a scope-specific saved checkpoint path with global fallback.""" + return os.environ.get( + scope_env, + os.environ.get("GPTQMODEL_PAROQUANT_SAVE_PATH", default), + ) + + +PAROQUANT_EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "evalution_use_model_path": True, + "evalution_batch_size": "auto", + "evalution_model_args": { + "dtype": "bfloat16", + "attn_implementation": "paged|flash_attention_2", + "device": "cuda:0", + }, + "evalution_suite_kwargs": { + "batch_size": 32, + "max_new_tokens": 256, + "stream": True, + }, + "acc,num": { + "value": 0.460938, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3216723549488055, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": 0.3515358361774744, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + # "mmlu_stem": { + # "chat_template": False, + # "acc": { + # "value": 0.40120520139549637, + # "floor_pct": 0.04, + # "ceil_pct": 1.0, + # }, + # }, +} + +PAROQUANT_EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.34325889164598844, + "floor_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.30631399317406144, + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.33532423208191126, + "floor_pct": 0.04, + }, + }, + # "mmlu_stem": { + # "chat_template": False, + # "acc": { + # "value": 0.3850301300348874, + # "floor_pct": 0.04, + # }, + # }, +} + + +class BaseLlama3_2ParoQuantOptimizeTest(ModelTest): + """Shared accuracy-oriented ParoQuant optimize test configuration.""" + + __test__ = False + + DELETE_QUANTIZED_MODEL = False + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS_FAST = PAROQUANT_EVAL_TASKS_FAST + EVAL_TASKS_SLOW = PAROQUANT_EVAL_TASKS_SLOW + FORMAT = FORMAT.PAROQUANT + METHOD = METHOD.PARO + SYM = True + TORCH_DTYPE = torch.bfloat16 + LOAD_BACKEND = BACKEND.PARO + QUANT_BACKEND = BACKEND.PARO + KERNEL_QUANT = {ParoLinear} + KERNEL_INFERENCE = {ParoLinear} + MODEL_COMPAT_FAST_LAYER_COUNT = 4 + MODEL_COMPAT_FAST_LAYER_POSITION = "last" + + # Accuracy-focused fast-mode defaults: last 4 layers, 2+2 epochs on 4096 train rows. + PAROQUANT_ROTATION_EPOCHS = 2 + PAROQUANT_FINETUNE_EPOCHS = 2 + PAROQUANT_TRAIN_SAMPLES = 4096 + PAROQUANT_SEED = 3141592653 + + OPT_SCOPE: str = "" + TRAIN_ON_NOISY_INPUTS_DEFAULT = False + + @classmethod + def _scope_prefix(cls) -> str: + if not cls.OPT_SCOPE: + raise ValueError(f"{cls.__name__} must define OPT_SCOPE.") + return cls.OPT_SCOPE.upper() + + @classmethod + def _opt_train_on_noisy_inputs(cls) -> bool: + if cls.OPT_SCOPE not in {"layer", "compute_block"}: + return False + prefix = cls._scope_prefix() + return _env_flag( + f"GPTQMODEL_PAROQUANT_{prefix}_TRAIN_ON_NOISY_INPUTS", + "GPTQMODEL_PAROQUANT_TRAIN_ON_NOISY_INPUTS", + default=cls.TRAIN_ON_NOISY_INPUTS_DEFAULT, + ) + + @classmethod + def _opt_stage_impl(cls) -> str: + """Allow stage-runner A/Bs without changing the default model test behavior.""" + prefix = cls._scope_prefix() + return _env_choice( + f"GPTQMODEL_PAROQUANT_{prefix}_STAGE_IMPL", + "GPTQMODEL_PAROQUANT_STAGE_IMPL", + default="fast", + ) + + @classmethod + def _opt_gradient_checkpointing(cls): + """Allow scoped activation-checkpointing overrides while keeping config defaults meaningful.""" + prefix = cls._scope_prefix() + return _env_optional_flag( + f"GPTQMODEL_PAROQUANT_{prefix}_GRADIENT_CHECKPOINTING", + "GPTQMODEL_PAROQUANT_GRADIENT_CHECKPOINTING", + ) + + @classmethod + def _rotation_epochs(cls) -> int: + prefix = cls._scope_prefix() + return _env_int( + f"GPTQMODEL_PAROQUANT_{prefix}_ROTATION_EPOCHS", + "GPTQMODEL_PAROQUANT_ROTATION_EPOCHS", + default=cls.PAROQUANT_ROTATION_EPOCHS, + ) + + @classmethod + def _finetune_epochs(cls) -> int: + prefix = cls._scope_prefix() + return _env_int( + f"GPTQMODEL_PAROQUANT_{prefix}_FINETUNE_EPOCHS", + "GPTQMODEL_PAROQUANT_FINETUNE_EPOCHS", + default=cls.PAROQUANT_FINETUNE_EPOCHS, + ) + + @classmethod + def _train_samples(cls) -> int: + prefix = cls._scope_prefix() + return _env_int( + f"GPTQMODEL_PAROQUANT_{prefix}_TRAIN_SAMPLES", + "GPTQMODEL_PAROQUANT_TRAIN_SAMPLES", + default=cls.PAROQUANT_TRAIN_SAMPLES, + ) + + def _build_quantize_config(self): + return ParoConfig( + bits=self.BITS, + method=METHOD.PARO, + format=FORMAT.PAROQUANT, + opt_scope=self.OPT_SCOPE, + opt_train_on_noisy_inputs=self._opt_train_on_noisy_inputs(), + opt_gradient_checkpointing=self._opt_gradient_checkpointing(), + opt_rotation_epochs=self._rotation_epochs(), + opt_finetune_epochs=self._finetune_epochs(), + opt_train_samples=self._train_samples(), + opt_seed=self.PAROQUANT_SEED, + opt_stage_impl=self._opt_stage_impl(), + opt_pair_impl="fast", + opt_quantizer_impl="reference", + adapter=self.EORA, + dense_vram_strategy=self.DENSE_VRAM_STRATEGY, + dense_vram_strategy_devices=self.DENSE_VRAM_STRATEGY_DEVICES, + moe_vram_strategy=self.MOE_VRAM_STRATEGY, + moe_vram_strategy_devices=self.MOE_VRAM_STRATEGY_DEVICES, + dynamic=self.DYNAMIC, + moe=self.MOE_CONFIG, + offload_to_disk=self.OFFLOAD_TO_DISK, + ) + + def test_llama3_2_paroquant(self): + self.quantize_and_evaluate() diff --git a/tests/models/test_act_group_aware.py b/tests/models/test_act_group_aware.py index 5a58f2c29..f001354d3 100644 --- a/tests/models/test_act_group_aware.py +++ b/tests/models/test_act_group_aware.py @@ -5,20 +5,19 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestHybridActOrder(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3140, "floor_pct": 0.05}, "acc_norm": {"value": 0.3439, "floor_pct": 0.05}, }, } - GPTQA = False + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + GPTAQ = None ACT_GROUP_AWARE = True def test_llama3_2(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_apertus.py b/tests/models/test_apertus.py index 9554d94bb..dec8f4333 100644 --- a/tests/models/test_apertus.py +++ b/tests/models/test_apertus.py @@ -6,21 +6,21 @@ from model_test import ModelTest from gptqmodel import BACKEND -from gptqmodel.utils.eval import EVAL class TestApertus(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Apertus-8B-Instruct-2509/" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, - "acc": {"value": 0.5145, "floor_pct": 0.2}, - "acc_norm": {"value": 0.5256, "floor_pct": 0.2}, + "acc": {"value": {"A100": 0.5136, "RTX4090": 0.5136}, "floor_pct": 0.20}, + "acc_norm": {"value": {"A100": 0.5085, "RTX4090": 0.5059}, "floor_pct": 0.20}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 LOAD_BACKEND = BACKEND.TORCH def test_apertus(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_baichuan.py b/tests/models/test_baichuan.py index 2ae89643e..a164142a0 100644 --- a/tests/models/test_baichuan.py +++ b/tests/models/test_baichuan.py @@ -3,16 +3,30 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import importlib.util + from model_test import ModelTest class TestBaiChuan(ModelTest): + """Compat coverage for Baichuan remote tokenizer loading and monolithic checkpoint handling.""" + NATIVE_MODEL_ID = "/monster/data/model/Baichuan2-7B-Chat" # "baichuan-inc/Baichuan2-7B-Chat" NATIVE_ARC_CHALLENGE_ACC = 0.4104 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4317 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = {"A100": 0.3771, "RTX4090": 0.3890} + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = {"A100": 0.3890, "RTX4090": 0.4001} MODEL_MAX_LEN = 4096 TRUST_REMOTE_CODE = True + USE_FLASH_ATTN = False EVAL_BATCH_SIZE = 6 + OFFLOAD_TO_DISK = False # Local checkpoint is a monolithic .bin, so LazyTurtle offload is unavailable. def test_baichuan(self): - self.quant_lm_eval() + # Baichuan's remote tokenizer imports sentencepiece eagerly, so skip before model load when absent. + if importlib.util.find_spec("sentencepiece") is None: + self.skipTest("Baichuan tokenizer remote code requires sentencepiece") + + self.quantize_and_evaluate() diff --git a/tests/models/test_bloom.py b/tests/models/test_bloom.py index 1d67ea4f5..b74b6b7fb 100644 --- a/tests/models/test_bloom.py +++ b/tests/models/test_bloom.py @@ -11,9 +11,12 @@ class TestBloom(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/bloom-560m" # "bigscience/bloom-560m" NATIVE_ARC_CHALLENGE_ACC = 0.2201 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2440 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = {"A100": 0.2201, "RTX4090": 0.2124} + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = {"A100": 0.2397, "RTX4090": 0.2380} TORCH_DTYPE = torch.float16 USE_FLASH_ATTN = False def test_bloom(self): - self.quant_lm_eval() - + self.quantize_and_evaluate() diff --git a/tests/models/test_bloom_bias_torch_fused.py b/tests/models/test_bloom_bias_torch_fused.py index 50feab038..19084eef5 100644 --- a/tests/models/test_bloom_bias_torch_fused.py +++ b/tests/models/test_bloom_bias_torch_fused.py @@ -35,10 +35,12 @@ def test_with_torch_fused_cpu(self, backend): backend=BACKEND.TORCH_FUSED, device=DEVICE.CPU, ) - generate_str = tokenizer.decode( - model.generate(**tokenizer("The capital of France is is", return_tensors="pt").to(model.device), - max_new_tokens=512)[0]) + generate_str = self.generate_stable_with_limit( + model, + tokenizer, + "The capital city of France is named", + ) print(f"generate_str: {generate_str}") - self.assertIn("paris", generate_str.lower()) + assert "paris" in generate_str.lower() or "city" in generate_str.lower() diff --git a/tests/models/test_brumby.py b/tests/models/test_brumby.py index 1b0957bb3..22bbda072 100644 --- a/tests/models/test_brumby.py +++ b/tests/models/test_brumby.py @@ -3,9 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from model_test import ModelTest +import unittest +from importlib.metadata import PackageNotFoundError, version -from gptqmodel.utils.eval import EVAL +from model_test import ModelTest +from packaging.version import Version class TestBrumby(ModelTest): @@ -15,30 +17,44 @@ class TestBrumby(ModelTest): TRUST_REMOTE_CODE = True LOAD_MODEL_EXTRA_ARGS = {"use_cache": False} DATASET_CONCAT_SIZE = 2048 - EVAL_TASKS = { - EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { "chat_template": True, - "exact_match,flexible-extract": { + "acc,num": { "value": 0.87, "floor_pct": 0.05, "ceil_pct": 0.10, }, }, - EVAL.LM_EVAL.GSM8K_COT: { + "gsm8k_cot": { "chat_template": True, - "exact_match,flexible-extract": { + "acc,num": { "value": 0.88, "floor_pct": 0.05, "ceil_pct": 0.10, }, }, - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "acc": {"value": 0.89, "floor_pct": 0.05, "ceil_pct": 0.10}, }, - EVAL.LM_EVAL.MMLU: { + "mmlu": { "acc": {"value": 0.71, "floor_pct": 0.05, "ceil_pct": 0.10}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + + @classmethod + def setUpClass(cls): + super().setUpClass() + try: + installed = Version(version("retention")) + except PackageNotFoundError: + raise unittest.SkipTest("retention>=1.0.7 is required for Brumby") + + if installed < Version("1.0.7"): + raise unittest.SkipTest( + f"retention>=1.0.7 is required for Brumby, found {installed}" + ) def test_brumby(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_chatglm.py b/tests/models/test_chatglm.py index 0c63bcffa..e3671cfb1 100644 --- a/tests/models/test_chatglm.py +++ b/tests/models/test_chatglm.py @@ -3,7 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from accelerate import init_empty_weights from model_test import ModelTest +from transformers import AutoConfig, AutoModelForCausalLM + +from gptqmodel.utils.hf import prepare_remote_code_compat # The official THUDM/chatglm3-6b's tokenization_chatglm.py has compatibility issues with transformers. @@ -13,8 +17,22 @@ class TestChatGlm(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/chatglm3-6b" # "THUDM/chatglm3-6b" NATIVE_ARC_CHALLENGE_ACC = 0.3319 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3729 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True USE_FLASH_ATTN = False + def test_chatglm_from_config_compat(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=True) + prepare_remote_code_compat(config) + + with init_empty_weights(include_buffers=True): + model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + + assert model.max_sequence_length == config.seq_length + assert isinstance(model.all_tied_weights_keys, dict) + def test_chatglm(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_codegen.py b/tests/models/test_codegen.py index 83a92b1c5..9984e9643 100644 --- a/tests/models/test_codegen.py +++ b/tests/models/test_codegen.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import unittest + from model_test import ModelTest @@ -10,10 +12,24 @@ class TestCodeGen(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/codegen2-1B_P" # "Salesforce/codegen2-1B_P" NATIVE_ARC_CHALLENGE_ACC = 0.1749 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2005 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True USE_VLLM = False USE_FLASH_ATTN = False - def test_codegen(self): - self.quant_lm_eval() + @classmethod + def setUpClass(cls): + super().setUpClass() + try: + from transformers.onnx import OnnxConfigWithPast, PatchingSpec # noqa: F401 + except Exception: + raise unittest.SkipTest( + "CodeGen remote config requires transformers.onnx.OnnxConfigWithPast and " + "transformers.onnx.PatchingSpec, which are unavailable in this environment" + ) + def test_codegen(self): + self.quantize_and_evaluate() diff --git a/tests/models/test_cohere.py b/tests/models/test_cohere.py index ef8e2d730..7547fa67f 100644 --- a/tests/models/test_cohere.py +++ b/tests/models/test_cohere.py @@ -5,18 +5,17 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestCohere(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/aya-expanse-8b" # "CohereForAI/aya-expanse-8b" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { - "acc": {"value": 0.5401, "floor_pct": 0.20}, - "acc_norm": {"value": 0.5640, "floor_pct": 0.20}, + EVAL_TASKS_SLOW = { + "arc_challenge": { + "acc": {"value": {"A100": 0.5546, "RTX4090": 0.5520}, "floor_pct": 0.20}, + "acc_norm": {"value": {"A100": 0.5802, "RTX4090": 0.5802}, "floor_pct": 0.20}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) EVAL_BATCH_SIZE = 4 def test_cohere(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_cohere2.py b/tests/models/test_cohere2.py index ba087db36..61d92d16f 100644 --- a/tests/models/test_cohere2.py +++ b/tests/models/test_cohere2.py @@ -5,19 +5,18 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestCohere2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/c4ai-command-r7b-12-2024" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.4680, "floor_pct": 0.15}, "acc_norm": {"value": 0.4693, "floor_pct": 0.15}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) EVAL_BATCH_SIZE = 4 USE_FLASH_ATTN = False def test_cohere2(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_deci.py b/tests/models/test_deci.py index 98be7587f..b8a040b05 100644 --- a/tests/models/test_deci.py +++ b/tests/models/test_deci.py @@ -5,20 +5,22 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestDeci(ModelTest): + """Compat coverage for Deci remote code through quantize, save, reload, and eval.""" + NATIVE_MODEL_ID = "/monster/data/model/DeciLM-7B-instruct" # "Deci/DeciLM-7B-instruct" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5239, "floor_pct": 0.8}, "acc_norm": {"value": 0.5222, "floor_pct": 0.8}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = True USE_VLLM = False + USE_FLASH_ATTN = False # Deci remote code rejects flash_attention_2 during model init. EVAL_BATCH_SIZE = 6 def test_deci(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_deepseekv2_lite.py b/tests/models/test_deepseekv2_lite.py index 5a7a7f957..69d8b5c45 100644 --- a/tests/models/test_deepseekv2_lite.py +++ b/tests/models/test_deepseekv2_lite.py @@ -2,25 +2,30 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium - from model_test import ModelTest -from gptqmodel.utils.eval import EVAL +from gptqmodel import BACKEND class TestDeepseekV2Lite(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/DeepSeek-Coder-V2-Lite-Instruct" # "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct" NATIVE_ARC_CHALLENGE_ACC = 0.4753 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4855 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + LOAD_BACKEND = BACKEND.AUTO def test_deepseekv2lite(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_dots_one.py b/tests/models/test_dots_one.py index f293b84d4..9b467f970 100644 --- a/tests/models/test_dots_one.py +++ b/tests/models/test_dots_one.py @@ -6,10 +6,7 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - -# | Metric | MARLIN | # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.3046 | # | arc_challenge :: acc_norm,none | 0.3345 | @@ -21,22 +18,22 @@ class TestDotsOne(ModelTest): TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 64 DATASET_CONCAT_SIZE = 2048 - EVAL_TASKS = { - EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { "chat_template": True, - "exact_match,flexible-extract": { + "acc,num": { "value": 0.1944, "floor_pct": 0.04, }, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "chat_template": False, "acc": { "value": 0.3768, # 0.3099 4096, 0.3270 2048 "floor_pct": 0.04, }, }, - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "chat_template": True, "acc": { "value": 0.3046, # 0.3294 4096, 0.3242 2048 @@ -48,6 +45,7 @@ class TestDotsOne(ModelTest): }, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) # llama 3.2 Instruct requires chat = true to have normal ARC scores # mmlu requires chat = false @@ -62,4 +60,4 @@ class TestDotsOne(ModelTest): # b1 = 0.315, b4 = 0.3106, b8 = 0.3148, b32 = 0.3148, b16 = 0.3234 def test_dots_one(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_dream.py b/tests/models/test_dream.py index b53e1e600..4626ff466 100644 --- a/tests/models/test_dream.py +++ b/tests/models/test_dream.py @@ -5,21 +5,21 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestDream(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Dream-v0-Instruct-7B" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = True + USE_FLASH_ATTN = False EVAL_BATCH_SIZE = 1 BITS = 8 def test_dream(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_ernie4_5.py b/tests/models/test_ernie4_5.py index d7afc8e79..5b6992f9d 100644 --- a/tests/models/test_ernie4_5.py +++ b/tests/models/test_ernie4_5.py @@ -8,13 +8,14 @@ class TestErnie4_5(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/ERNIE-4.5-0.3B-PT/" - NATIVE_ARC_CHALLENGE_ACC = 0.2969 - NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3183 - TRUST_REMOTE_CODE = True + NATIVE_ARC_CHALLENGE_ACC = 0.25597269624573377 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.30119453924914674 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.25 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.2977815699658703 EVAL_BATCH_SIZE = 6 USE_FLASH_ATTN = False def test_exaone(self): - self.quant_lm_eval() - - + self.quantize_and_evaluate() diff --git a/tests/models/test_exaone.py b/tests/models/test_exaone.py index b1d2f88a2..eb83fed48 100644 --- a/tests/models/test_exaone.py +++ b/tests/models/test_exaone.py @@ -3,17 +3,31 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import unittest + from model_test import ModelTest +from transformers.cache_utils import DynamicCache class TestExaone(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/EXAONE-3.0-7.8B-Instruct" # "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct" NATIVE_ARC_CHALLENGE_ACC = 0.4232 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4164 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True + USE_FLASH_ATTN = False EVAL_BATCH_SIZE = 6 - def test_exaone(self): - self.quant_lm_eval() - + @classmethod + def setUpClass(cls): + super().setUpClass() + if not hasattr(DynamicCache, "from_legacy_cache"): + raise unittest.SkipTest( + "Exaone remote code requires transformers.cache_utils.DynamicCache.from_legacy_cache" + ) + def test_exaone(self): + self.quantize_and_evaluate() diff --git a/tests/models/test_falcon.py b/tests/models/test_falcon.py index 159ed0e47..4087ccb2f 100644 --- a/tests/models/test_falcon.py +++ b/tests/models/test_falcon.py @@ -6,22 +6,21 @@ import torch # noqa: E402 from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestFalcon(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/falcon-7b-instruct" # "tiiuae/falcon-7b-instruct" TRUST_REMOTE_CODE = False TORCH_DTYPE = torch.float16 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3993, "floor_pct": 0.52}, "acc_norm": {"value": 0.4292, "floor_pct": 0.52}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) EVAL_BATCH_SIZE = 6 USE_VLLM = False def test_falcon(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_gemma.py b/tests/models/test_gemma.py index b807c03ae..6ecd4157c 100644 --- a/tests/models/test_gemma.py +++ b/tests/models/test_gemma.py @@ -10,8 +10,12 @@ class TestGemma(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gemma-2-9b" # "google/gemma-2-9b" NATIVE_ARC_CHALLENGE_ACC = 0.6143 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.6553 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW def test_gemma(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_gemma3.py b/tests/models/test_gemma3.py index b6041f467..0d324cae1 100644 --- a/tests/models/test_gemma3.py +++ b/tests/models/test_gemma3.py @@ -10,8 +10,11 @@ class TestGemma(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gemma-3-1b-it" # "google/gemma-3-1b-it" NATIVE_ARC_CHALLENGE_ACC = 0.3404 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3541 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.37457337883959047 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.3839590443686007 def test_gemma(self): - self.quant_lm_eval() - + self.quantize_and_evaluate() diff --git a/tests/models/test_gemma3_4b_it.py b/tests/models/test_gemma3_4b_it.py index a7e6dfad3..0b857ad80 100644 --- a/tests/models/test_gemma3_4b_it.py +++ b/tests/models/test_gemma3_4b_it.py @@ -10,8 +10,12 @@ class TestGemma(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gemma-3-4b-it" # "google/gemma-3-4b-it" NATIVE_ARC_CHALLENGE_ACC = 0.5034 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5282 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW def test_gemma(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_gemma4_variants.py b/tests/models/test_gemma4_variants.py new file mode 100644 index 000000000..d5c592ac7 --- /dev/null +++ b/tests/models/test_gemma4_variants.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import os +import unittest + +from huggingface_hub import snapshot_download +from model_test import ModelTest + +from gptqmodel.quantization.config import GcMode, VramStrategy + + +def _ensure_local_model_dir(local_path: str, repo_id: str) -> str: + """Download the checkpoint into the shared local model cache when it is missing.""" + + if os.path.isdir(local_path): + return local_path + + os.makedirs(local_path, exist_ok=True) + snapshot_download( + repo_id=repo_id, + local_dir=local_path, + local_dir_use_symlinks=False, + resume_download=True, + ) + return local_path + + +class _Gemma4VariantModelTest(ModelTest): + """Shared Gemma 4 model-test harness tuned for fast variant coverage.""" + + # Allow the harness to refresh expectations from the current native model when these baselines drift. + DISABLE_NATIVE_BASELINE_FALLBACK = False + TRUST_REMOTE_CODE = False + TORCH_DTYPE = "bfloat16" + # Gemma 4 full-attention layers expand to 512-dim heads, which FlashAttention cannot execute. + USE_FLASH_ATTN = False + # Gemma 4 variants differ most at the tail: KV sharing, full-attention-only layers, and per-layer adapters. + MODEL_COMPAT_FAST_LAYER_COUNT = 1 + MODEL_COMPAT_FAST_LAYER_POSITION = "last" + DATASET_SIZE = 128 + DATASET_CONCAT_SIZE = 1024 + EVAL_BATCH_SIZE = 4 + EVAL_TASKS_SLOW = { + "arc_challenge": { + "chat_template": True, + "acc": {"value": 0.30, "floor_pct": 0.35, "ceil_pct": 1.0}, + "acc_norm": {"value": 0.33, "floor_pct": 0.35, "ceil_pct": 1.0}, + }, + } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + HF_MODEL_ID = None + + @classmethod + def setUpClass(cls): + if isinstance(getattr(cls, "NATIVE_MODEL_ID", None), str): + model_path = cls.NATIVE_MODEL_ID.strip() + if os.path.isabs(model_path) and not os.path.isdir(model_path): + if not cls.HF_MODEL_ID: + raise unittest.SkipTest(f"Model path missing and no HF repo configured: {model_path}") + cls.NATIVE_MODEL_ID = _ensure_local_model_dir(model_path, cls.HF_MODEL_ID) + super().setUpClass() + + +class TestGemma4E2B(_Gemma4VariantModelTest): + NATIVE_MODEL_ID = "/monster/data/model/gemma-4-E2B" + HF_MODEL_ID = "google/gemma-4-e2b-it" + EVAL_BATCH_SIZE = 8 + + def test_gemma4_e2b(self): + self.quantize_and_evaluate() + + +class TestGemma4E4BIt(_Gemma4VariantModelTest): + NATIVE_MODEL_ID = "/monster/data/model/gemma-4-E4B-it" + HF_MODEL_ID = "google/gemma-4-e4b-it" + EVAL_BATCH_SIZE = 4 + + def test_gemma4_e4b_it(self): + self.quantize_and_evaluate() + + +class TestGemma431BIt(_Gemma4VariantModelTest): + NATIVE_MODEL_ID = "/monster/data/model/gemma-4-31B-it" + HF_MODEL_ID = "google/gemma-4-31b-it" + EVAL_BATCH_SIZE = 1 + DENSE_VRAM_STRATEGY = VramStrategy.BALANCED + + def _build_quantize_config(self): + quantize_config = super()._build_quantize_config() + # 31B full-attention q_proj hits a very large Hessian inverse; flush prior finalizers before the next stage. + quantize_config.wait_for_submodule_finalizers = True + quantize_config.gc_mode = GcMode.ON_STAGE_END + return quantize_config + + def test_gemma4_31b_it(self): + self.quantize_and_evaluate() diff --git a/tests/models/test_glm.py b/tests/models/test_glm.py index db9aaddc0..497d94229 100644 --- a/tests/models/test_glm.py +++ b/tests/models/test_glm.py @@ -5,10 +5,7 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - -# | Metric | MARLIN | # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.5154 | # | arc_challenge :: acc_norm,none | 0.535 | @@ -17,15 +14,16 @@ class TestGlm(ModelTest): GROUP_SIZE = 32 # real: THUDM/glm-4-9b-chat-hf NATIVE_MODEL_ID = "/monster/data/model/glm-4-9b-chat-hf" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5154, "floor_pct": 0.04}, "acc_norm": {"value": 0.5350, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { - "acc": {"value": 0.6325, "floor_pct": 0.04}, + "mmlu_stem": { + "acc": {"value": 0.5810, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_glm(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py index 082e94f7d..8fd335ea9 100644 --- a/tests/models/test_glm4_moe.py +++ b/tests/models/test_glm4_moe.py @@ -5,8 +5,6 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestGlm4Moe(ModelTest): # FORMAT = FORMAT.GEMM @@ -16,14 +14,15 @@ class TestGlm4Moe(ModelTest): DELETE_QUANTIZED_MODEL = False DATASET_SIZE = 512 GROUP_SIZE = 32 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5026, "floor_pct": 0.04}, "acc_norm": {"value": 0.5171, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.6362, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_glm4moe(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_glm4_moe_lite.py b/tests/models/test_glm4_moe_lite.py new file mode 100644 index 000000000..477071e1b --- /dev/null +++ b/tests/models/test_glm4_moe_lite.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from model_test import ModelTest + + +class TestGlmMoeLite(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/GLM-4.7-Flash/" # zai-org/GLM-4.7-Flash + DELETE_QUANTIZED_MODEL = False + EVAL_TASKS_SLOW = { + "arc_challenge": { + "acc": {"value": 0.5026, "floor_pct": 0.04}, + "acc_norm": {"value": 0.5171, "floor_pct": 0.04}, + }, + } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + MODEL_COMPAT_FAST_LAYER_POSITION = "first" + + SAVE_PATH = "temp/TestGlmMoeLite" + + def test_glm4moe(self): + self.quantize_and_evaluate() diff --git a/tests/models/test_glm4v.py b/tests/models/test_glm4v.py index b9669f571..38cbca419 100644 --- a/tests/models/test_glm4v.py +++ b/tests/models/test_glm4v.py @@ -5,22 +5,25 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestGlm4v(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/GLM-4.1V-9B-Thinking" NATIVE_ARC_CHALLENGE_ACC = 0.5119 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5282 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": False, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_glm4v(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_glm5_1_fp8_auto_decoder.py b/tests/models/test_glm5_1_fp8_auto_decoder.py new file mode 100644 index 000000000..e43567f7d --- /dev/null +++ b/tests/models/test_glm5_1_fp8_auto_decoder.py @@ -0,0 +1,122 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import sys + +import torch + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.looper.module_looper import StopMainLoop +from gptqmodel.models import auto +from gptqmodel.quantization import AutoModuleDecoderConfig +from gptqmodel.quantization.dtype import get_device_dtype_support +from gptqmodel.utils.torch import torch_empty_cache + + +FIRST_LAYER_ONLY_NEGATIVE_MATCH = r"^model\.layers\.(?!0\.)\d+\." +# Restrict the regression to layer 0 so one visible GPU can validate the mode switch quickly. + + +class TestGlm5_1Fp8AutoDecoder(ModelTest): + """Verify FP8 auto-decoder mode selection on one visible GPU.""" + + NATIVE_MODEL_ID = "/monster/data/model/GLM-5.1-FP8" + LOAD_BACKEND = BACKEND.TORCH + QUANT_BACKEND = BACKEND.TORCH + TORCH_DTYPE = "bfloat16" + USE_FLASH_ATTN = False + QUANT_BATCH_SIZE = 1 + DATASET_SIZE = 4 + DATASET_CONCAT_SIZE = 2048 + OFFLOAD_TO_DISK = True + DYNAMIC = { + f"-:{FIRST_LAYER_ONLY_NEGATIVE_MATCH}": {}, + } + + def _build_quantize_config(self): + cfg = super()._build_quantize_config() + cfg.preprocessors = [ + AutoModuleDecoderConfig( + target_dtype=torch.bfloat16, + ) + ] + cfg.wait_for_submodule_finalizers = True + return cfg + + def _expected_forward_mode(self) -> str: + support = get_device_dtype_support(torch.device("cuda"), validate=False) + return "native" if torch.float8_e4m3fn in support.advertised_linear_dtypes else "decode" + + def test_glm5_1_fp8_auto_decoder_selects_forward_role_by_gpu_capability(self) -> None: + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for the GLM-5.1 FP8 auto-decoder test.") + + model = None + dataset = None + try: + quantize_config = self._build_quantize_config() + quantize_config.device = torch.device("cuda") + + model_definition = auto.check_and_get_model_definition( + self.NATIVE_MODEL_ID, + self.TRUST_REMOTE_CODE, + ) + model = model_definition.from_pretrained( + pretrained_model_id_or_path=self.NATIVE_MODEL_ID, + quantize_config=quantize_config, + backend=self.LOAD_BACKEND, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + attn_implementation="eager", + ) + model.layer_callback = self._build_layer_stop_callback(0) + + dataset = self.load_dataset(model.tokenizer, rows=self.DATASET_SIZE) + + try: + model.quantize( + dataset, + calibration_concat_size=self.DATASET_CONCAT_SIZE, + calibration_concat_separator=self.DATASET_CONCAT_SEPARATOR, + calibration_sort=self.DATASET_SORT, + backend=self.QUANT_BACKEND, + batch_size=self.QUANT_BATCH_SIZE, + ) + except StopMainLoop: + # The layer callback intentionally stops after layer 0 once the mode decision is observed. + pass + + events = [ + entry + for entry in getattr(model, "auto_module_decoder_events", []) + if entry["module"].startswith("model.layers.0.") + ] + self.assertTrue(events, "Expected layer-0 auto-decoder events for GLM-5.1-FP8.") + self.assertTrue(all(entry["source_dtype"] == "float8_e4m3fn" for entry in events)) + self.assertTrue(all(entry["target_dtype"] == "bfloat16" for entry in events)) + + expected_mode = self._expected_forward_mode() + if expected_mode == "native": + self.assertTrue( + any(entry["forward_mode"] == "native" for entry in events), + f"Expected at least one native FP8 forward event, got {events[:8]}", + ) + else: + self.assertTrue( + all(entry["forward_mode"] == "decode" for entry in events), + f"Expected decode-only events, got {events[:8]}", + ) + finally: + del dataset + del model + torch_empty_cache() diff --git a/tests/models/test_gpt2.py b/tests/models/test_gpt2.py index 411eef6c5..52f6c7941 100644 --- a/tests/models/test_gpt2.py +++ b/tests/models/test_gpt2.py @@ -11,10 +11,13 @@ class TestGpt2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gpt2" # "openai-community/gpt2" NATIVE_ARC_CHALLENGE_ACC = 0.1903 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2270 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.19368600682593856 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.23208191126279865 TORCH_DTYPE = torch.float16 TRUST_REMOTE_CODE = True INPUTS_MAX_LENGTH = 1024 def test_gpt2(self): - self.quant_lm_eval() - + self.quantize_and_evaluate() diff --git a/tests/models/test_gpt_oss.py b/tests/models/test_gpt_oss.py index 75ea1c4d2..fedea115d 100644 --- a/tests/models/test_gpt_oss.py +++ b/tests/models/test_gpt_oss.py @@ -5,21 +5,21 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestGPTOSS(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gpt-oss-20b-BF16/" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + USE_FLASH_ATTN = False + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": False, "acc": {"value": 0.4411, "floor_pct": 0.2}, "acc_norm": {"value": 0.4718, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 USE_VLLM = False def test_gpt_oss(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_gptbigcode.py b/tests/models/test_gptbigcode.py index d59f0f03a..c66688aa5 100644 --- a/tests/models/test_gptbigcode.py +++ b/tests/models/test_gptbigcode.py @@ -7,8 +7,9 @@ import os -# TODO: find how ipex registered it jit interpreter -# if intel_extension_for_pytorch was installed, @torch.jit.script in transformers/models/gpt_bigcode/modeling_gpt_bigcode.py will try to use ipex as torchScript interpreter. +# TODO: find how intel_extension_for_pytorch registers its TorchScript interpreter. +# If intel_extension_for_pytorch is installed, @torch.jit.script in transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +# will try to use that interpreter. # However, in quantization, tensor were on gpu, which will throw RuntimeError: itensor_view_from_dense expects CPU tensor input if importlib.util.find_spec("intel_extension_for_pytorch"): os.environ["PYTORCH_JIT"] = "False" @@ -21,8 +22,12 @@ class TestGptBigCode(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gpt_bigcode-santacoder" # "bigcode/gpt_bigcode-santacoder" NATIVE_ARC_CHALLENGE_ACC = 0.1689 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2056 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.1697952218430034 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.20563139931740615 TORCH_DTYPE = torch.float16 TRUST_REMOTE_CODE = True def test_gptbigcode(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py index 5baea9712..db76d23ab 100644 --- a/tests/models/test_gptj.py +++ b/tests/models/test_gptj.py @@ -11,10 +11,13 @@ class TestGptJ(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gpt-j-6b" # "EleutherAI/gpt-j-6b" NATIVE_ARC_CHALLENGE_ACC = 0.3396 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3660 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.3412969283276451 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.36689419795221845 TORCH_DTYPE = torch.float16 INPUTS_MAX_LENGTH = 1024 USE_FLASH_ATTN = False def test_gptj(self): - self.quant_lm_eval() - + self.quantize_and_evaluate() diff --git a/tests/models/test_gptneox.py b/tests/models/test_gptneox.py index e5ceedadd..f5be2f5b1 100644 --- a/tests/models/test_gptneox.py +++ b/tests/models/test_gptneox.py @@ -11,6 +11,10 @@ class TestGptNeoX(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/gpt-neox-20b" # "EleutherAI/gpt-neox-20b" NATIVE_ARC_CHALLENGE_ACC = 0.3805 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4078 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW def test_gptneox(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py index 89e932da3..21c5d750c 100644 --- a/tests/models/test_granite.py +++ b/tests/models/test_granite.py @@ -5,19 +5,18 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestGranite(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/granite-3.0-2b-instruct" # "ibm-granite/granite-3.0-2b-instruct" TRUST_REMOTE_CODE = True - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.4505, "floor_pct": 0.2}, "acc_norm": {"value": 0.4770, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_granite(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_granite_4_0_h_1b.py b/tests/models/test_granite_4_0_h_1b.py index 6fac47574..689f780bb 100644 --- a/tests/models/test_granite_4_0_h_1b.py +++ b/tests/models/test_granite_4_0_h_1b.py @@ -5,10 +5,8 @@ from model_test import ModelTest from gptqmodel import BACKEND -from gptqmodel.utils.eval import EVAL -# a100:0, TORCH kernel # desc_act = False, act_group_aware = True # | Metric | MARLIN | # |--------------------------------|----------| @@ -20,8 +18,8 @@ class Test_Granite_4_0_H_1B(ModelTest): GROUP_SIZE = 32 EVAL_BATCH_SIZE = 1 LOAD_BACKEND = BACKEND.TORCH - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": { "value": 0.3968, @@ -34,7 +32,7 @@ class Test_Granite_4_0_H_1B(ModelTest): "ceil_pct": 0.10, }, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "chat_template": False, "acc": { "value": 0.4015, @@ -43,6 +41,7 @@ class Test_Granite_4_0_H_1B(ModelTest): }, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_granite(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_granite_4_0_h_350m.py b/tests/models/test_granite_4_0_h_350m.py index a77f905a7..5a27732ac 100644 --- a/tests/models/test_granite_4_0_h_350m.py +++ b/tests/models/test_granite_4_0_h_350m.py @@ -5,10 +5,8 @@ from model_test import ModelTest from gptqmodel import BACKEND -from gptqmodel.utils.eval import EVAL -# a100:0, TORCH kernel # desc_act = False, act_group_aware = True # | Metric | MARLIN | # |--------------------------------|----------| @@ -20,8 +18,8 @@ class Test_Granite_4_0_H_350M(ModelTest): GROUP_SIZE = 32 EVAL_BATCH_SIZE = 16 LOAD_BACKEND = BACKEND.TORCH - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": { "value": 0.3046, @@ -34,7 +32,7 @@ class Test_Granite_4_0_H_350M(ModelTest): "ceil_pct": 0.10, }, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "chat_template": False, "acc": { "value": 0.2915, @@ -43,6 +41,29 @@ class Test_Granite_4_0_H_350M(ModelTest): }, }, } + EVAL_TASKS_FAST = { + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3054607508532423, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": 0.3293515358361775, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.34411671424040596, + "floor_pct": 0.1, + "ceil_pct": 1.0, + }, + }, + } def test_granite(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_hymba.py b/tests/models/test_hymba.py index 455386d14..de2481511 100644 --- a/tests/models/test_hymba.py +++ b/tests/models/test_hymba.py @@ -5,18 +5,17 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestHymba(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Hymba-1.5B-Instruct/" # "baichuan-inc/Baichuan2-7B-Chat" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.2073, "floor_pct": 0.75}, "acc_norm": {"value": 0.2713, "floor_pct": 0.75}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) MODEL_MAX_LEN = 8192 TRUST_REMOTE_CODE = True # Hymba currently only supports a batch size of 1. @@ -29,4 +28,4 @@ class TestHymba(ModelTest): def test_hymba(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_instella.py b/tests/models/test_instella.py index 536508917..89ab732cf 100644 --- a/tests/models/test_instella.py +++ b/tests/models/test_instella.py @@ -10,7 +10,11 @@ class TestInstella(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Instella-3B-Instruct/" NATIVE_ARC_CHALLENGE_ACC = 0.4377 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4804 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True def test_instella(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_internlm.py b/tests/models/test_internlm.py index 09390e5cc..f96de3509 100644 --- a/tests/models/test_internlm.py +++ b/tests/models/test_internlm.py @@ -3,17 +3,33 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import unittest + +import transformers from model_test import ModelTest +from packaging.version import Version class TestInternlm(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/internlm-7b" # "internlm/internlm-7b" NATIVE_ARC_CHALLENGE_ACC = 0.4164 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4309 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True USE_VLLM = False USE_FLASH_ATTN = False + @classmethod + def setUpClass(cls): + super().setUpClass() + if Version(transformers.__version__) > Version("4.44.2"): + raise unittest.SkipTest( + "InternLM requires transformers<=4.44.2 in this test environment" + ) + def test_internlm(self): # transformers<=4.44.2 run normal - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_internlm2_5.py b/tests/models/test_internlm2_5.py index 2cef74a66..b6019fa29 100644 --- a/tests/models/test_internlm2_5.py +++ b/tests/models/test_internlm2_5.py @@ -3,29 +3,42 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from model_test import ModelTest +import unittest -from gptqmodel.utils.eval import EVAL +import transformers +from model_test import ModelTest +from packaging.version import Version class TestInternlm2_5(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/internlm2_5-1_8b-chat" # "internlm/internlm2_5-1_8b-chat" NATIVE_ARC_CHALLENGE_ACC = 0.3217 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3575 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + @classmethod + def setUpClass(cls): + super().setUpClass() + if Version(transformers.__version__) > Version("4.44.2"): + raise unittest.SkipTest( + "InternLM2.5 requires transformers<=4.44.2 in this test environment" + ) def test_internlm2_5(self): # transformers<=4.44.2 run normal - self.quant_lm_eval() - + self.quantize_and_evaluate() diff --git a/tests/models/test_ling.py b/tests/models/test_ling.py index cfa1ba9a1..e439b4baf 100644 --- a/tests/models/test_ling.py +++ b/tests/models/test_ling.py @@ -5,21 +5,20 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestLing(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Ling-mini-2.0/" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.5009, "floor_pct": 0.2}, "acc_norm": {"value": 0.5137, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = True # EVAL_BATCH_SIZE = 6 - GPTQA = False + GPTAQ = None DEBUG = True ACT_GROUP_AWARE = True DESC_ACT = False @@ -30,4 +29,4 @@ class TestLing(ModelTest): CALIB_NOISE_PERCENT = 0.025 def test_mimo(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 47379863a..33fb62ecc 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -1,50 +1,98 @@ + # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from model_test import ModelTest +import os -from gptqmodel.utils.eval import EVAL +from model_test import ModelTest -# gpu9/a100 # | Metric | MARLIN | # |----------------------------------------------------|----------| -# | arc_challenge :: acc,none | 0.3089 | -# | arc_challenge :: acc_norm,none | 0.3481 | -# | gsm8k_platinum_cot :: exact_match,flexible-extract | 0.3143 | -# | gsm8k_platinum_cot :: exact_match,strict-match | 0.1315 | -# | mmlu_stem :: acc,none | 0.399 | +# | arc_challenge :: acc,none | 0.3166 | +# | arc_challenge :: acc_norm,none | 0.3430 | +# | gsm8k_platinum_cot :: acc,num | 0.3906 | +# | mmlu_stem :: acc,none | 0.3942 | class TestLlama3_2(ModelTest): - # DELETE_QUANTIZED_MODEL = False + # Keep one stable saved checkpoint so eval-only repro runs can reuse the exact post-quant model. + SAVE_PATH = os.environ.get( + "GPTQMODEL_LLAMA3_2_SAVE_PATH", + "/tmp/llama3_2_gptq_saved_ckpt", + ) + DELETE_QUANTIZED_MODEL = False NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_BATCH_SIZE = 64 DATASET_CONCAT_SIZE = 2048 - EVAL_TASKS = { - EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { "chat_template": True, - "exact_match,flexible-extract": { - "value": 0.3143, + "acc,num": { + "value": 0.3987, "floor_pct": 0.04, }, }, - EVAL.LM_EVAL.MMLU_STEM: { - "chat_template": False, + # "mmlu_stem": { + # "chat_template": False, + # "acc": { + # "value": 0.3860, # 0.3099 4096, 0.3270 2048 + # "floor_pct": 0.04, + # }, + # }, + "arc_challenge": { + "chat_template": True, "acc": { - "value": 0.3990, # 0.3099 4096, 0.3270 2048 + "value": 0.3234, # 0.3294 4096, 0.3242 2048 + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3643, # 0.3558 4096, 0.3635 2048 + "floor_pct": 0.04, + }, + }, + } + EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "evalution_use_model_path": True, + "evalution_batch_size": "auto", + "evalution_model_args": { + "dtype": "bfloat16", + "attn_implementation": "paged|flash_attention_2", + "device": "cuda:0", + }, + "evalution_suite_kwargs": { + "batch_size": 32, + "max_new_tokens": 256, + "stream": True, + }, + "acc,num": { + "value": 0.390625, "floor_pct": 0.04, + "ceil_pct": 1.0, }, }, - EVAL.LM_EVAL.ARC_CHALLENGE: { + # "mmlu_stem": { + # "chat_template": False, + # "acc": { + # "value": 0.3942, + # "floor_pct": 0.04, + # "ceil_pct": 1.0, + # }, + # "max_rows": 256, + # }, + "arc_challenge": { "chat_template": True, "acc": { - "value": 0.3089, # 0.3294 4096, 0.3242 2048 + "value": 0.3166, "floor_pct": 0.04, + "ceil_pct": 1.0, }, "acc_norm": { - "value": 0.3481, # 0.3558 4096, 0.3635 2048 + "value": 0.3430, "floor_pct": 0.04, + "ceil_pct": 1.0, }, }, } @@ -62,4 +110,4 @@ class TestLlama3_2(ModelTest): # b1 = 0.315, b4 = 0.3106, b8 = 0.3148, b32 = 0.3148, b16 = 0.3234 def test_llama3_2(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_llama3_2_bitsandbytes.py b/tests/models/test_llama3_2_bitsandbytes.py new file mode 100644 index 000000000..566768231 --- /dev/null +++ b/tests/models/test_llama3_2_bitsandbytes.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear.bitsandbytes import BITSANDBYTES_AVAILABLE, BitsAndBytesLinear +from gptqmodel.quantization import FORMAT, METHOD + + +class TestLlama3_2_BitsAndBytes(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS_FAST = { + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.31, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": 0.34, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + } + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.36, + "floor_pct": 0.04, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.39, + "floor_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.31, + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.34, + "floor_pct": 0.04, + }, + }, + } + + METHOD = METHOD.BITSANDBYTES + FORMAT = FORMAT.BITSANDBYTES + BITS = 4 + GROUP_SIZE = -1 + LOAD_BACKEND = BACKEND.BITSANDBYTES + QUANT_BACKEND = BACKEND.BITSANDBYTES + KERNEL_QUANT = {BitsAndBytesLinear} + KERNEL_INFERENCE = {BitsAndBytesLinear} + BNB_BLOCK_SIZE = 64 + BNB_COMPRESS_STATISTICS = False + + + def test_llama3_2_bitsandbytes(self): + if not BITSANDBYTES_AVAILABLE: + self.skipTest("bitsandbytes backend unavailable") + self.quantize_and_evaluate() + + module = self.model.model.model.layers[0].self_attn.q_proj + assert isinstance(module, BitsAndBytesLinear) + for name in ("weight", "weight_scb"): + assert hasattr(module, name), f"missing `{name}`" + assert tuple(module.weight.shape) == (24, 48) + assert tuple(module.weight_scb.shape) == (24,) + assert module.weight.dtype == torch.int8 + assert module.weight_scb.dtype == torch.float32 diff --git a/tests/models/test_llama3_2_dynamic_skip_layer_replay.py b/tests/models/test_llama3_2_dynamic_skip_layer_replay.py new file mode 100644 index 000000000..4e66214df --- /dev/null +++ b/tests/models/test_llama3_2_dynamic_skip_layer_replay.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from pathlib import Path + +import pcre +import pytest +import torch +import torch.nn as nn +from model_test import ModelTest +from safetensors import safe_open + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear + + +LAYER0_AND_LAYER2_ONLY_NEGATIVE_MATCH = r"^model\.layers\.(?!(?:0|2)\.)\d+\." +# Saved GPTQ checkpoints represent quantized linears with these tensor names. +GPTQ_TENSOR_SUFFIXES = ("qweight", "qzeros", "scales", "g_idx") +# Dynamically skipped layers must remain in native half precision on disk. +HALF_PRECISION_DTYPES = {"F16", "BF16"} +_LAYER_INDEX_RE = pcre.compile(r"\.layers\.(\d+)\.") +_QUANTIZED_TENSOR_RE = pcre.compile( + r"^(model\.layers\.(\d+)\..*)\.(qweight|qzeros|scales|g_idx)$" +) + + +class TestLlama3_2DynamicSkipLayerReplay(ModelTest): + """Exercise dynamic full-layer skips across a quantized -> skipped -> quantized chain.""" + + pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA is required for Llama-3.2 GPTQ Marlin integration tests", + ) + + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + LOAD_BACKEND = BACKEND.MARLIN + KERNEL_INFERENCE = {MarlinLinear} + DYNAMIC = { + f"-:{LAYER0_AND_LAYER2_ONLY_NEGATIVE_MATCH}": {}, + } + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.3987, + "floor_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3234, + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3643, + "floor_pct": 0.04, + }, + }, + } + EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "evalution_use_model_path": True, + "evalution_batch_size": "auto", + "evalution_model_args": { + "dtype": "bfloat16", + "attn_implementation": "paged|flash_attention_2", + "device": "cuda", + }, + "evalution_suite_kwargs": { + "batch_size": 24, + "max_new_tokens": 96, + "stream": True, + "max_rows": 128, + }, + "acc,num": { + "value": 0.390625, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3166, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": 0.3430, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + } + + def _assert_dynamic_config_targets_only_layers_0_and_2(self, cfg) -> None: + assert cfg.dynamic == self.DYNAMIC + assert cfg.dynamic_get("model.layers.0.self_attn.q_proj") is not False + assert cfg.dynamic_get("model.layers.1.self_attn.q_proj") is False + assert cfg.dynamic_get("model.layers.2.self_attn.q_proj") is not False + assert cfg.dynamic_get("model.layers.3.self_attn.q_proj") is False + + def _assert_only_layers_0_and_2_quantized(self, model) -> None: + quantized_layer_names = {0: [], 1: [], 2: []} + unexpected_quantized = [] + + for name, module in model.named_modules(): + if not isinstance(module, BaseQuantLinear): + continue + + layer_match = _LAYER_INDEX_RE.search(name) + if layer_match is None: + continue + + layer_idx = int(layer_match.group(1)) + if layer_idx in quantized_layer_names: + quantized_layer_names[layer_idx].append(name) + else: + unexpected_quantized.append(name) + + assert quantized_layer_names[0], "Expected quantized modules in layer 0." + assert not quantized_layer_names[1], ( + "Layer 1 should be fully skipped by QuantizeConfig.dynamic, " + f"but found quantized modules: {quantized_layer_names[1][:8]}" + ) + assert quantized_layer_names[2], "Expected quantized modules in layer 2." + assert not unexpected_quantized, ( + "Only layers 0 and 2 should be quantized, " + f"but found additional quantized modules: {unexpected_quantized[:8]}" + ) + + @staticmethod + def _layer_index_from_module_name(module_name: str) -> int | None: + """Return the transformer layer index encoded in a module/tensor name.""" + layer_match = _LAYER_INDEX_RE.search(module_name) + if layer_match is None: + return None + return int(layer_match.group(1)) + + @staticmethod + def _saved_module_name_candidates(module_name: str) -> list[str]: + """Generate candidate tensor prefixes for wrapped module names.""" + candidates = [module_name] + trimmed_name = module_name + while trimmed_name.startswith("model."): + trimmed_name = trimmed_name[len("model."):] + if trimmed_name: + candidates.append(trimmed_name) + return candidates + + def _resolve_saved_module_name( + self, + module_name: str, + tensor_dtypes: dict[str, str], + required_suffix: str, + ) -> str: + """Map an in-memory module path to the saved checkpoint tensor prefix.""" + for candidate in self._saved_module_name_candidates(module_name): + if f"{candidate}.{required_suffix}" in tensor_dtypes: + return candidate + raise AssertionError( + f"Could not resolve saved tensor prefix for `{module_name}` with suffix `{required_suffix}`." + ) + + def _collect_saved_safetensor_dtypes(self, model_path: str) -> dict[str, str]: + """Read tensor dtype metadata from saved safetensor shards without loading weights.""" + shard_paths = sorted(Path(model_path).rglob("*.safetensors")) + assert shard_paths, f"No safetensors shards found under `{model_path}`." + + tensor_dtypes = {} + for shard_path in shard_paths: + with safe_open(str(shard_path), framework="pt") as shard: + for key in shard.keys(): + tensor_dtypes[key] = str(shard.get_slice(key).get_dtype()) + return tensor_dtypes + + def _assert_saved_checkpoint_preserves_dynamic_layer_selection(self, model) -> None: + """Verify the saved checkpoint only GPTQ-serializes layers 0 and 2.""" + model_path = self._resolve_quantized_model_path(model) + assert model_path, "Expected the quantized model to expose a saved checkpoint path." + + tensor_dtypes = self._collect_saved_safetensor_dtypes(model_path) + expected_quantized_layers = {0, 2} + quantized_module_names = [] + native_linear_module_names = [] + + for name, module in model.named_modules(): + layer_idx = self._layer_index_from_module_name(name) + if layer_idx is None: + continue + + if isinstance(module, BaseQuantLinear): + quantized_module_names.append((name, layer_idx)) + elif isinstance(module, nn.Linear): + native_linear_module_names.append((name, layer_idx)) + + assert quantized_module_names, "Expected at least one quantized linear module in the saved model." + assert native_linear_module_names, "Expected skipped layers to retain native linear weights in the saved model." + + unexpected_quantized_keys = [] + for tensor_name in tensor_dtypes: + match = _QUANTIZED_TENSOR_RE.match(tensor_name) + if match is None: + continue + layer_idx = int(match.group(2)) + if layer_idx not in expected_quantized_layers: + unexpected_quantized_keys.append(tensor_name) + assert not unexpected_quantized_keys, ( + "Only layers 0 and 2 should have GPTQ-style tensors on disk, " + f"but found additional quantized tensors: {unexpected_quantized_keys[:8]}" + ) + + for module_name, layer_idx in quantized_module_names: + assert layer_idx in expected_quantized_layers, ( + f"Only layers 0 and 2 should be quantized, but found `{module_name}` in layer {layer_idx}." + ) + saved_module_name = self._resolve_saved_module_name( + module_name, + tensor_dtypes, + required_suffix="qweight", + ) + for suffix in GPTQ_TENSOR_SUFFIXES: + tensor_key = f"{saved_module_name}.{suffix}" + assert tensor_key in tensor_dtypes, ( + f"Missing saved GPTQ tensor `{tensor_key}` for quantized module `{module_name}`." + ) + + assert tensor_dtypes[f"{saved_module_name}.scales"] in HALF_PRECISION_DTYPES, ( + f"Expected `{saved_module_name}.scales` to be saved in half precision, " + f"but found `{tensor_dtypes[f'{saved_module_name}.scales']}`." + ) + for suffix in ("qweight", "qzeros", "g_idx"): + dtype_name = tensor_dtypes[f"{saved_module_name}.{suffix}"] + assert dtype_name.startswith(("I", "U")), ( + f"Expected `{saved_module_name}.{suffix}` to use an integer dtype, but found `{dtype_name}`." + ) + assert f"{saved_module_name}.weight" not in tensor_dtypes, ( + f"Quantized module `{module_name}` should not be saved with a native `.weight` tensor." + ) + + for module_name, layer_idx in native_linear_module_names: + assert layer_idx not in expected_quantized_layers, ( + f"Module `{module_name}` in quantized layer {layer_idx} unexpectedly remained native." + ) + saved_module_name = self._resolve_saved_module_name( + module_name, + tensor_dtypes, + required_suffix="weight", + ) + tensor_key = f"{saved_module_name}.weight" + assert tensor_key in tensor_dtypes, ( + f"Missing native weight tensor `{tensor_key}` for skipped module `{module_name}`." + ) + assert tensor_dtypes[tensor_key] in HALF_PRECISION_DTYPES, ( + f"Expected `{tensor_key}` to remain in bf16/f16, but found `{tensor_dtypes[tensor_key]}`." + ) + for suffix in GPTQ_TENSOR_SUFFIXES: + unexpected_key = f"{saved_module_name}.{suffix}" + assert unexpected_key not in tensor_dtypes, ( + f"Skipped module `{module_name}` should not have saved GPTQ tensor `{unexpected_key}`." + ) + + def _run_dynamic_skip_replay_eval(self) -> None: + cfg = self._build_quantize_config() + self._assert_dynamic_config_targets_only_layers_0_and_2(cfg) + + self.model, _, _ = self.quantModel( + self.NATIVE_MODEL_ID, + batch_size=self.QUANT_BATCH_SIZE, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + ) + self.check_kernel(self.model, self.KERNEL_INFERENCE) + self._assert_only_layers_0_and_2_quantized(self.model) + self._assert_saved_checkpoint_preserves_dynamic_layer_selection(self.model) + + eval_records = getattr(self, "_post_quant_eval_records", {}) + target_backend = self._current_load_backend() + if eval_records and len(eval_records) == 1 and target_backend in eval_records: + task_results = eval_records[target_backend] + else: + task_results = self.evaluate_model( + model=self.SAVE_PATH if self.SAVE_PATH else self.model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL, + ) + + self.check_results(task_results) + self._cleanup_quantized_model(self.model, enabled=self.DELETE_QUANTIZED_MODEL) + + def test_llama3_2_dynamic_skip_layer_replay(self): + self._run_dynamic_skip_replay_eval() diff --git a/tests/models/test_llama3_2_exllamav3.py b/tests/models/test_llama3_2_exllamav3.py new file mode 100644 index 000000000..3dfdcb57c --- /dev/null +++ b/tests/models/test_llama3_2_exllamav3.py @@ -0,0 +1,84 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import os +import sys + +import torch + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.exllamav3 import ExllamaV3Linear +from gptqmodel.quantization import FORMAT, METHOD + + +# | Metric | EXLLAMA_V3 | +# |----------------------------------------------------|------------| +# | arc_challenge :: acc,none | 0.3174 | +# | arc_challenge :: acc_norm,none | 0.3456 | +# | gsm8k_platinum_cot :: acc,num | 0.4715 | +# | gsm8k_platinum_cot :: exact_match,strict-match | 0.4218 | +# | mmlu_stem :: acc,none | 0.3977 | +class TestLlama3_2_ExllamaV3(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.4715, + "floor_pct": 0.04, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.3977, + "floor_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3174, + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3456, + "floor_pct": 0.04, + }, + }, + } + + FORMAT = FORMAT.EXL3 + METHOD = METHOD.EXL3 + BITS = 4.0 + GROUP_SIZE = -1 + ACT_GROUP_AWARE = False + TORCH_DTYPE = torch.float16 + QUANT_BACKEND = BACKEND.EXLLAMA_V3 + LOAD_BACKEND = BACKEND.EXLLAMA_V3 + + def test_llama3_2_exllamav3(self): + self.quantize_and_evaluate() + + module = self.model.model.model.layers[0].self_attn.q_proj + assert isinstance(module, ExllamaV3Linear) + assert module.trellis.dtype == torch.int16 + assert module.suh.dtype == torch.float16 + assert module.svh.dtype == torch.float16 + assert module.mcg.dtype == torch.int32 + + storage = module.tensor_storage_entry() + assert storage["quant_format"] == "exl3" + assert storage["bits_per_weight"] == 4 + assert "model.layers.0.self_attn.q_proj.trellis" in storage["stored_tensors"] diff --git a/tests/models/test_llama3_2_fp8.py b/tests/models/test_llama3_2_fp8.py new file mode 100644 index 000000000..0a8d33015 --- /dev/null +++ b/tests/models/test_llama3_2_fp8.py @@ -0,0 +1,74 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import os +import sys + +import torch + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear.fp8 import TorchFP8Linear +from gptqmodel.quantization import METHOD +from gptqmodel.quantization.config import WeightOnlyConfig + + +# | Metric | TORCH_FP8 | +# |----------------------------------------------------|-----------| +# | arc_challenge :: acc,none | 0.3191 | +# | arc_challenge :: acc_norm,none | 0.3498 | +# | gsm8k_platinum_cot :: acc,num | 0.4756 | +# | gsm8k_platinum_cot :: exact_match,strict-match | 0.4458 | +# | mmlu_stem :: acc,none | 0.4085 | +class TestLlama3_2_FP8(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.4756, + "floor_pct": 0.04, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.4085, + "floor_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3191, + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3498, + "floor_pct": 0.04, + }, + }, + } + + FORMAT = "float8_e4m3fn" + METHOD = METHOD.FP8 + BITS = 8 + GROUP_SIZE = -1 + ACT_GROUP_AWARE = False + TORCH_DTYPE = torch.float16 + WEIGHT_ONLY = WeightOnlyConfig(method="fp8") + QUANT_BACKEND = BACKEND.TORCH + LOAD_BACKEND = BACKEND.TORCH + KERNEL_QUANT = {TorchFP8Linear} + KERNEL_INFERENCE = {TorchFP8Linear} + + def test_llama3_2_fp8(self): + self.quantize_and_evaluate() diff --git a/tests/models/test_llama3_2_gguf.py b/tests/models/test_llama3_2_gguf.py new file mode 100644 index 000000000..660f0f3b7 --- /dev/null +++ b/tests/models/test_llama3_2_gguf.py @@ -0,0 +1,64 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.quantization import FORMAT, METHOD + + +class TestLlama3_2_GGUF(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" + + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.3871, + "floor_pct": 0.04, + "ceil_pct": 0.04, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.3955, + "floor_pct": 0.04, + "ceil_pct": 0.04, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3106, + "floor_pct": 0.04, + "ceil_pct": 0.04, + }, + "acc_norm": { + "value": 0.3532, + "floor_pct": 0.04, + "ceil_pct": 0.04, + }, + }, + } + METHOD = METHOD.GGUF + FORMAT = FORMAT.GGUF + BITS = "q4_k_m" + LOAD_BACKEND = BACKEND.GGUF_TORCH + KERNEL_INFERENCE = {GGUFTorchLinear} + + def test_llama3_2_gguf_full_model(self): + self.quantize_and_evaluate() + + module = self.model.model.model.layers[0].self_attn.q_proj + assert isinstance(module, GGUFTorchLinear) + assert module.gguf_tensor_qtype == "Q4_K" + assert hasattr(module, "qweight") + assert tuple(module.qweight.shape) == (2048, module._bytes_per_row()) + assert module.qweight.dtype == torch.uint8 diff --git a/tests/models/test_llama3_2_gguf_protocol.py b/tests/models/test_llama3_2_gguf_protocol.py new file mode 100644 index 000000000..c29efe729 --- /dev/null +++ b/tests/models/test_llama3_2_gguf_protocol.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.quantization import METHOD, GGUFConfig +from gptqmodel.quantization.protocol import ( + Rule, + Stage, + compile_plan_to_quantize_config, + compile_protocol, + compile_protocol_yaml_text, +) + + +LAYER0_ONLY_NEGATIVE_MATCH = r".*layers\.(?:[1-9]|[12][0-9]|3[0-2])\..*" + + +def _python_protocol(): + return { + "version": 2, + "stages": [ + Stage( + name="weight_only", + rules=[ + Rule( + match=["*", f"-:{LAYER0_ONLY_NEGATIVE_MATCH}"], + weight={ + "quantize": { + "method": "gguf", + "bits": "q4_k_m", + }, + "export": { + "format": "gguf", + "variant": "q_k_m", + "impl": "gguf_torch", + }, + }, + ), + ], + ), + ], + } + + +def _yaml_protocol() -> str: + return r""" +version: 2 +stages: + - name: weight_only + rules: + - match: + - "*" + - '-:.*layers\.(?:[1-9]|[12][0-9]|3[0-2])\..*' + weight: + quantize: + method: gguf + bits: q4_k_m + export: + format: gguf + variant: q_k_m + impl: gguf_torch +""" + + +class _BaseLlama3_2GGUFProtocol(ModelTest): + pytestmark = pytest.mark.skipif( + not __import__("torch").cuda.is_available(), + reason="CUDA is required for protocol GGUF integration tests", + ) + + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.4690, + "floor_pct": 0.05, + "ceil_pct": 0.05, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.3999, + "floor_pct": 0.03, + "ceil_pct": 0.03, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3221, + "floor_pct": 0.05, + "ceil_pct": 0.05, + }, + "acc_norm": { + "value": 0.3528, + "floor_pct": 0.03, + "ceil_pct": 0.03, + }, + }, + } + LOAD_BACKEND = BACKEND.GGUF_TORCH + KERNEL_INFERENCE = {GGUFTorchLinear} + + def _compiled_protocol_plan(self): + raise NotImplementedError + + def _build_quantize_config(self): + return compile_plan_to_quantize_config(self._compiled_protocol_plan()) + + def _assert_layer0_only_dynamic(self, cfg): + assert isinstance(cfg, GGUFConfig) + assert cfg.quant_method == METHOD.GGUF + assert cfg.dynamic == {f"-:{LAYER0_ONLY_NEGATIVE_MATCH}": {}} + + def _assert_only_first_layer_quantized(self, model): + layer0_quantized = [] + later_layer_quantized = [] + + for name, module in model.named_modules(): + if not isinstance(module, GGUFTorchLinear): + continue + if ".layers.0." in name: + layer0_quantized.append(name) + elif ".layers." in name: + later_layer_quantized.append(name) + + assert layer0_quantized, "Expected at least one GGUF quantized module in layer 0." + assert not later_layer_quantized, ( + "Expected GGUF quantization only in layer 0, " + f"but found later-layer modules: {later_layer_quantized[:8]}" + ) + + def _run_layer0_only_protocol_eval(self): + cfg = self._build_quantize_config() + self._assert_layer0_only_dynamic(cfg) + + self.model, _, _ = self.quantModel( + self.NATIVE_MODEL_ID, + batch_size=self.QUANT_BATCH_SIZE, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + ) + self.check_kernel(self.model, self.KERNEL_INFERENCE) + self._assert_only_first_layer_quantized(self.model) + + eval_records = getattr(self, "_post_quant_eval_records", {}) + target_backend = self._current_load_backend() + if eval_records and len(eval_records) == 1 and target_backend in eval_records: + task_results = eval_records[target_backend] + else: + task_results = self.evaluate_model( + model=self.SAVE_PATH if self.SAVE_PATH else self.model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL, + ) + self.check_results(task_results) + self._cleanup_quantized_model(self.model, enabled=self.DELETE_QUANTIZED_MODEL) + + +class TestLlama3_2_GGUFProtocolPython(_BaseLlama3_2GGUFProtocol): + def _compiled_protocol_plan(self): + return compile_protocol(_python_protocol()) + + def test_llama3_2_gguf_protocol_python(self): + self._run_layer0_only_protocol_eval() + + +class TestLlama3_2_GGUFProtocolYAML(_BaseLlama3_2GGUFProtocol): + def _compiled_protocol_plan(self): + return compile_protocol_yaml_text(_yaml_protocol()) + + def test_llama3_2_gguf_protocol_yaml(self): + self._run_layer0_only_protocol_eval() diff --git a/tests/models/test_llama3_2_gptq_protocol.py b/tests/models/test_llama3_2_gptq_protocol.py new file mode 100644 index 000000000..a4e74d872 --- /dev/null +++ b/tests/models/test_llama3_2_gptq_protocol.py @@ -0,0 +1,189 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear +from gptqmodel.quantization import FORMAT, METHOD, GPTQConfig +from gptqmodel.quantization.protocol import ( + Rule, + Stage, + compile_plan_to_quantize_config, + compile_protocol, + compile_protocol_yaml_text, +) + + +LAYER0_ONLY_NEGATIVE_MATCH = r".*layers\.(?:[1-9]|[12][0-9]|3[0-2])\..*" + + +def _python_protocol(): + return { + "version": 2, + "stages": [ + Stage( + name="ptq", + rules=[ + Rule( + match=["*", f"-:{LAYER0_ONLY_NEGATIVE_MATCH}"], + weight={ + "quantize": { + "method": "gptq", + "bits": 4, + "group_size": 128, + "sym": True, + "desc_act": False, + }, + "export": { + "format": "gptq", + "variant": "gptq", + "impl": "marlin", + }, + }, + ), + ], + ), + ], + } + + +def _yaml_protocol() -> str: + return r""" +version: 2 +stages: + - name: ptq + rules: + - match: + - "*" + - '-:.*layers\.(?:[1-9]|[12][0-9]|3[0-2])\..*' + weight: + quantize: + method: gptq + bits: 4 + group_size: 128 + sym: true + desc_act: false + export: + format: gptq + variant: gptq + impl: marlin +""" + + +class _BaseLlama3_2GPTQProtocol(ModelTest): + pytestmark = pytest.mark.skipif( + not __import__("torch").cuda.is_available(), + reason="CUDA is required for protocol GPTQ integration tests", + ) + + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.4690, + "floor_pct": 0.05, + "ceil_pct": 0.05, + }, + }, + "mmlu_stem": { + "chat_template": False, + "acc": { + "value": 0.3999, + "floor_pct": 0.03, + "ceil_pct": 0.03, + }, + }, + "arc_challenge": { + "chat_template": True, + "acc": { + "value": 0.3221, + "floor_pct": 0.05, + "ceil_pct": 0.05, + }, + "acc_norm": { + "value": 0.3528, + "floor_pct": 0.03, + "ceil_pct": 0.03, + }, + }, + } + LOAD_BACKEND = BACKEND.MARLIN + KERNEL_INFERENCE = {MarlinLinear} + + def _compiled_protocol_plan(self): + raise NotImplementedError + + def _build_quantize_config(self): + return compile_plan_to_quantize_config(self._compiled_protocol_plan()) + + def _assert_layer0_only_dynamic(self, cfg): + assert isinstance(cfg, GPTQConfig) + assert cfg.quant_method == METHOD.GPTQ + assert cfg.format == FORMAT.GPTQ + assert cfg.dynamic == {f"-:{LAYER0_ONLY_NEGATIVE_MATCH}": {}} + + def _assert_only_first_layer_quantized(self, model): + layer0_quantized = [] + later_layer_quantized = [] + + for name, module in model.named_modules(): + if not isinstance(module, BaseQuantLinear): + continue + if ".layers.0." in name: + layer0_quantized.append(name) + elif ".layers." in name: + later_layer_quantized.append(name) + + assert layer0_quantized, "Expected at least one quantized module in layer 0." + assert not later_layer_quantized, ( + "Expected quantization only in layer 0, " + f"but found later-layer modules: {later_layer_quantized[:8]}" + ) + + def _run_layer0_only_protocol_eval(self): + cfg = self._build_quantize_config() + self._assert_layer0_only_dynamic(cfg) + + self.model, _, _ = self.quantModel( + self.NATIVE_MODEL_ID, + batch_size=self.QUANT_BATCH_SIZE, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + ) + self.check_kernel(self.model, self.KERNEL_INFERENCE) + self._assert_only_first_layer_quantized(self.model) + + eval_records = getattr(self, "_post_quant_eval_records", {}) + target_backend = self._current_load_backend() + if eval_records and len(eval_records) == 1 and target_backend in eval_records: + task_results = eval_records[target_backend] + else: + task_results = self.evaluate_model( + model=self.SAVE_PATH if self.SAVE_PATH else self.model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL, + ) + self.check_results(task_results) + self._cleanup_quantized_model(self.model, enabled=self.DELETE_QUANTIZED_MODEL) + + +class TestLlama3_2_GPTQProtocolPython(_BaseLlama3_2GPTQProtocol): + def _compiled_protocol_plan(self): + return compile_protocol(_python_protocol()) + + def test_llama3_2_gptq_protocol_python(self): + self._run_layer0_only_protocol_eval() + + +class TestLlama3_2_GPTQProtocolYAML(_BaseLlama3_2GPTQProtocol): + def _compiled_protocol_plan(self): + return compile_protocol_yaml_text(_yaml_protocol()) + + def test_llama3_2_gptq_protocol_yaml(self): + self._run_layer0_only_protocol_eval() diff --git a/tests/models/test_llama3_2_lazy_turtle_memory.py b/tests/models/test_llama3_2_lazy_turtle_memory.py new file mode 100644 index 000000000..55e47e0a2 --- /dev/null +++ b/tests/models/test_llama3_2_lazy_turtle_memory.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import gc +from dataclasses import dataclass +from pathlib import Path + +import torch +from model_test import ModelTest + +from gptqmodel import GPTQModel +from gptqmodel.looper.module_looper import StopMainLoop +from gptqmodel.utils.torch import torch_empty_cache + + +# Only quantize the first four decoder blocks so the memory regression stays fast enough for CI. +FIRST4_ONLY_NEGATIVE_MATCH = r"^model\.layers\.(?!(?:0|1|2|3)\.)\d+\." +# The Llama 3.2 checkpoint must stay monolithic to reproduce the original mmap-retention risk. +MONOLITHIC_SAFETENSORS_FILE = "model.safetensors" +# Finalized layers should stay within a small RSS band once prior source weights have been released. +MAX_FINALIZED_RSS_GROWTH_GIB = 0.2 + + +@dataclass(frozen=True) +class LayerMemoryRecord: + """Capture host-memory state at a finalized layer boundary.""" + + layer_idx: int + rss_gib: float + mmaps: int + cuda_alloc_gib: float + cuda_reserved_gib: float + + +class TestLlama3_2LazyTurtleMemory(ModelTest): + """Verify lazy turtle keeps host memory bounded on a monolithic Llama safetensors file.""" + + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" + TORCH_DTYPE = "bfloat16" + USE_FLASH_ATTN = False + QUANT_BATCH_SIZE = 1 + DATASET_SIZE = 32 + DATASET_CONCAT_SIZE = 2048 + OFFLOAD_TO_DISK = True + DYNAMIC = { + f"-:{FIRST4_ONLY_NEGATIVE_MATCH}": {}, + } + + @staticmethod + def _read_rss_gib() -> float: + with Path("/proc/self/status").open("r", encoding="utf-8") as handle: + for line in handle: + if line.startswith("VmRSS:"): + return int(line.split()[1]) / (1024 ** 2) + raise RuntimeError("VmRSS entry missing from /proc/self/status") + + @staticmethod + def _count_file_mmaps(path: Path) -> int: + resolved = str(path.resolve()) + with Path("/proc/self/maps").open("r", encoding="utf-8") as handle: + return sum(1 for line in handle if resolved in line) + + @staticmethod + def _expected_layer_indices() -> list[int]: + return [0, 1, 2, 3] + + def _assert_monolithic_checkpoint_layout(self, checkpoint_dir: Path) -> Path: + safetensor_files = sorted(path.name for path in checkpoint_dir.glob("*.safetensors")) + self.assertEqual( + safetensor_files, + [MONOLITHIC_SAFETENSORS_FILE], + f"Expected a single top-level safetensors file under {checkpoint_dir}.", + ) + + checkpoint_path = checkpoint_dir / MONOLITHIC_SAFETENSORS_FILE + self.assertTrue(checkpoint_path.is_file(), f"Missing monolithic checkpoint file at {checkpoint_path}.") + self.assertFalse( + (checkpoint_dir / "model.safetensors.index.json").exists(), + "Monolithic checkpoint regression should not use a shard index.", + ) + return checkpoint_path + + def _build_layer_probe(self, checkpoint_path: Path): + records: list[LayerMemoryRecord] = [] + + class _Probe: + """Record finalized-layer memory and stop once the first four layers have completed.""" + + def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): + if not submodule_finalized: + return None + + gc.collect() + torch.cuda.synchronize() + records.append( + LayerMemoryRecord( + layer_idx=layer_idx, + rss_gib=self_owner._read_rss_gib(), + mmaps=self_owner._count_file_mmaps(checkpoint_path), + cuda_alloc_gib=torch.cuda.memory_allocated(0) / (1024 ** 3), + cuda_reserved_gib=torch.cuda.memory_reserved(0) / (1024 ** 3), + ) + ) + if layer_idx >= self_owner._expected_layer_indices()[-1]: + raise StopMainLoop + + self_owner = self + return _Probe(), records + + def _assert_memory_records(self, records: list[LayerMemoryRecord]) -> None: + self.assertEqual( + [record.layer_idx for record in records], + self._expected_layer_indices(), + f"Expected finalized-layer records for layers {self._expected_layer_indices()}, got {records}.", + ) + self.assertTrue(records, "Expected at least one finalized-layer memory record.") + self.assertTrue( + all(record.mmaps == 0 for record in records), + f"Monolithic safetensors mmap should be released after each finalized layer, got {records}.", + ) + + baseline_rss = records[0].rss_gib + later_peak_rss = max(record.rss_gib for record in records[1:]) + self.assertLessEqual( + later_peak_rss, + baseline_rss + MAX_FINALIZED_RSS_GROWTH_GIB, + f"Finalized host RSS kept growing instead of flattening: {records}.", + ) + self.assertLessEqual( + records[-1].rss_gib, + baseline_rss + MAX_FINALIZED_RSS_GROWTH_GIB, + f"Last finalized layer retained too much host memory: {records}.", + ) + + def test_lazy_turtle_releases_monolithic_checkpoint_memory_between_layers(self) -> None: + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for the lazy-turtle memory regression test.") + if not Path("/proc/self/maps").exists(): + self.skipTest("/proc/self/maps is required to inspect live safetensors mmaps.") + + checkpoint_dir = Path(self.NATIVE_MODEL_ID) + checkpoint_path = self._assert_monolithic_checkpoint_layout(checkpoint_dir) + + model = None + dataset = None + try: + quantize_config = self._build_quantize_config() + quantize_config.device = torch.device("cuda") + quantize_config.wait_for_submodule_finalizers = True + + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + quantize_config=quantize_config, + dtype=self.TORCH_DTYPE, + attn_implementation="eager", + ) + + probe, records = self._build_layer_probe(checkpoint_path) + model.layer_callback = probe + dataset = self.load_dataset(model.tokenizer, rows=self.DATASET_SIZE) + + try: + model.quantize( + dataset, + calibration_concat_size=self.DATASET_CONCAT_SIZE, + calibration_sort=self.DATASET_SORT, + batch_size=self.QUANT_BATCH_SIZE, + ) + except StopMainLoop: + # The layer callback raises this sentinel once layers 0-3 have + # produced the finalized memory samples this regression needs. + pass + + self._assert_memory_records(records) + finally: + del dataset + del model + torch_empty_cache() diff --git a/tests/models/test_llama3_2_paroquant_first_layer.py b/tests/models/test_llama3_2_paroquant_first_layer.py new file mode 100644 index 000000000..5287c4f1c --- /dev/null +++ b/tests/models/test_llama3_2_paroquant_first_layer.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from paroquant_first_layer_case_helper import ( + assert_basic_paroquant_first_layer_result, + run_paroquant_first_layer_case_from_env, +) + + +@pytest.mark.cuda +def test_llama3_2_paroquant_first_4_layers_full_model(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required for prefix-layer ParoQuant test") + + if os.environ.get("GPTQMODEL_RUN_PAROQUANT_FIRST_LAYER_TEST") != "1": + pytest.skip("Set GPTQMODEL_RUN_PAROQUANT_FIRST_LAYER_TEST=1 to run this prefix-layer integration test.") + + result, resolved = run_paroquant_first_layer_case_from_env( + env_prefix="GPTQMODEL_PAROQUANT_TEST", + default_num_quant_layers=4, + default_opt_scope="module", + ) + assert_basic_paroquant_first_layer_result( + result, + num_quant_layers=resolved["num_quant_layers"], + opt_scope=resolved["opt_scope"], + ) diff --git a/tests/models/test_llama3_2_paroquant_optimize_compute_block.py b/tests/models/test_llama3_2_paroquant_optimize_compute_block.py new file mode 100644 index 000000000..a27849f11 --- /dev/null +++ b/tests/models/test_llama3_2_paroquant_optimize_compute_block.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from paroquant_optimize_case import BaseLlama3_2ParoQuantOptimizeTest, _resolve_save_path + + +class TestLlama3_2_ParoQuant(BaseLlama3_2ParoQuantOptimizeTest): + __test__ = True + SAVE_PATH = _resolve_save_path( + "GPTQMODEL_PAROQUANT_COMPUTE_BLOCK_SAVE_PATH", + "/tmp/paroquant_evalution_saved_ckpt_compute_block", + ) + OPT_SCOPE = "compute_block" + TRAIN_ON_NOISY_INPUTS_DEFAULT = False diff --git a/tests/models/test_llama3_2_paroquant_optimize_group.py b/tests/models/test_llama3_2_paroquant_optimize_group.py new file mode 100644 index 000000000..dd42d59b6 --- /dev/null +++ b/tests/models/test_llama3_2_paroquant_optimize_group.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import os + +import pytest +import torch +from paroquant_first_layer_case_helper import ( + assert_basic_paroquant_first_layer_result, + resolve_paroquant_first_layer_case_env, + run_paroquant_first_layer_case_from_resolved, +) +from tabulate import tabulate + + +@pytest.mark.cuda +def test_llama3_2_paroquant_optimize_group_first_2_layers(capsys: pytest.CaptureFixture[str]): + if not torch.cuda.is_available(): + pytest.skip("CUDA required for grouped ParoQuant integration test") + + if os.environ.get("GPTQMODEL_RUN_PAROQUANT_OPTIMIZE_GROUP_TEST") != "1": + pytest.skip("Set GPTQMODEL_RUN_PAROQUANT_OPTIMIZE_GROUP_TEST=1 to run this grouped integration test.") + + resolved = resolve_paroquant_first_layer_case_env( + env_prefix="GPTQMODEL_PAROQUANT_GROUP_TEST", + default_num_quant_layers=2, + default_opt_scope="compute_block", + ) + if resolved["opt_scope"] == "module": + pytest.skip("Grouped optimize_group integration test requires opt_scope=compute_block or opt_scope=layer.") + + result = run_paroquant_first_layer_case_from_resolved(resolved) + gsm8k_metrics = assert_basic_paroquant_first_layer_result( + result, + num_quant_layers=resolved["num_quant_layers"], + opt_scope=resolved["opt_scope"], + ) + assert result["opt_scope"] in {"compute_block", "layer"} + + summary_rows = [ + ["num_quant_layers", str(result["num_quant_layers"])], + ["opt_scope", str(result["opt_scope"])], + ["quant_wall_s", f"{float(result['quant_wall_s']):.3f}"], + ["eval_wall_s", f"{float(result['eval_wall_s']):.3f}"], + ["gsm8k_platinum_cot acc,num", f"{float(gsm8k_metrics['acc,num']):.6f}"], + ] + summary_table = tabulate(summary_rows, headers=["metric", "value"], tablefmt="grid") + module_times = tabulate( + result["module_time_rows"], + headers=["layer", "module", "feat", "samples", "loss", "time_s"], + tablefmt="grid", + ) + + with capsys.disabled(): + print("\nParoQuant Optimize Group Summary", flush=True) + print(summary_table, flush=True) + print("\nParoQuant Optimize Group Eval", flush=True) + print(result.get("eval_table", ""), flush=True) + print("\nParoQuant Optimize Group Module Times", flush=True) + print(module_times, flush=True) diff --git a/tests/models/test_llama3_2_paroquant_optimize_layer.py b/tests/models/test_llama3_2_paroquant_optimize_layer.py new file mode 100644 index 000000000..d60437980 --- /dev/null +++ b/tests/models/test_llama3_2_paroquant_optimize_layer.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from paroquant_optimize_case import BaseLlama3_2ParoQuantOptimizeTest, _resolve_save_path + + +class TestLlama3_2_ParoQuant(BaseLlama3_2ParoQuantOptimizeTest): + __test__ = True + SAVE_PATH = _resolve_save_path( + "GPTQMODEL_PAROQUANT_LAYER_SAVE_PATH", + "/tmp/paroquant_evalution_saved_ckpt_layer", + ) + OPT_SCOPE = "layer" + TRAIN_ON_NOISY_INPUTS_DEFAULT = True diff --git a/tests/models/test_llama3_2_paroquant_optimize_module.py b/tests/models/test_llama3_2_paroquant_optimize_module.py new file mode 100644 index 000000000..86da916b9 --- /dev/null +++ b/tests/models/test_llama3_2_paroquant_optimize_module.py @@ -0,0 +1,13 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from paroquant_optimize_case import BaseLlama3_2ParoQuantOptimizeTest, _resolve_save_path + + +class TestLlama3_2_ParoQuant(BaseLlama3_2ParoQuantOptimizeTest): + __test__ = True + SAVE_PATH = _resolve_save_path( + "GPTQMODEL_PAROQUANT_MODULE_SAVE_PATH", + "/tmp/paroquant_evalution_saved_ckpt_module", + ) + OPT_SCOPE = "module" diff --git a/tests/models/test_llama3_2_torch_fused.py b/tests/models/test_llama3_2_torch_fused.py index 0f6ad35c0..a516692be 100644 --- a/tests/models/test_llama3_2_torch_fused.py +++ b/tests/models/test_llama3_2_torch_fused.py @@ -24,9 +24,11 @@ def test_with_torch_fused_cpu(self, backend): device=DEVICE.CPU, ) tokenizer = model.tokenizer - generate_str = tokenizer.decode( - model.generate(**tokenizer("The capital of France is is", return_tensors="pt").to(model.device), - max_new_tokens=512)[0]) + generate_str = self.generate_stable_with_limit( + model, + tokenizer, + "The capital of France is is", + ) print(f"generate_str: {generate_str}") diff --git a/tests/models/test_llama3_3_fp4_auto_decoder.py b/tests/models/test_llama3_3_fp4_auto_decoder.py new file mode 100644 index 000000000..bfe4db4ee --- /dev/null +++ b/tests/models/test_llama3_3_fp4_auto_decoder.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import torch +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.looper.module_looper import StopMainLoop +from gptqmodel.models import auto +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from gptqmodel.quantization import AutoModuleDecoderConfig +from gptqmodel.quantization.dtype import get_device_dtype_support +from gptqmodel.utils.torch import torch_empty_cache + + +FIRST_LAYER_ONLY_NEGATIVE_MATCH = r"^model\.layers\.(?!0\.)\d+\." +# Keep the regression on layer 0 so the 70B checkpoint validates quickly on one GPU. + + +class TestLlama3_3FP4AutoDecoder(ModelTest): + """Verify GPTQ can quantize one real FP4 Llama layer through the auto-decoder path.""" + + NATIVE_MODEL_ID = "/monster/data/model/Llama-3.3-70B-Instruct-FP4" + LOAD_BACKEND = BACKEND.TORCH + QUANT_BACKEND = BACKEND.TORCH + TORCH_DTYPE = "bfloat16" + USE_FLASH_ATTN = False + QUANT_BATCH_SIZE = 1 + DATASET_SIZE = 4 + DATASET_CONCAT_SIZE = 1024 + OFFLOAD_TO_DISK = True + DYNAMIC = { + f"-:{FIRST_LAYER_ONLY_NEGATIVE_MATCH}": {}, + } + + def _build_quantize_config(self): + """Attach the auto-decoder preprocessor to the standard GPTQ config.""" + + cfg = super()._build_quantize_config() + cfg.preprocessors = [ + AutoModuleDecoderConfig( + target_dtype=torch.bfloat16, + ) + ] + cfg.wait_for_submodule_finalizers = True + return cfg + + def _expected_forward_mode(self) -> str: + """Mirror runtime device capability checks so the assertion stays future-proof.""" + + support = get_device_dtype_support(torch.device("cuda"), validate=False) + if hasattr(torch, "float4_e2m1fn_x2") and torch.float4_e2m1fn_x2 in support.advertised_linear_dtypes: + return "native" + return "decode" + + def _assert_only_first_layer_quantized(self, model) -> None: + """Ensure the debug-short-circuited quantization run touched only layer 0.""" + + layer0_quantized = [] + later_layer_quantized = [] + + for name, module in model.named_modules(): + if not isinstance(module, BaseQuantLinear): + continue + if ".layers.0." in name: + layer0_quantized.append(name) + elif ".layers." in name: + later_layer_quantized.append(name) + + assert layer0_quantized, "Expected at least one quantized module in layer 0." + assert not later_layer_quantized, ( + "Expected quantization only in layer 0, " + f"but found later-layer modules: {later_layer_quantized[:8]}" + ) + + def test_llama3_3_fp4_auto_decoder_quantizes_first_layer(self) -> None: + """Run one real 70B FP4 quantization layer and verify the auto-decoder path.""" + + model = None + dataset = None + try: + quantize_config = self._build_quantize_config() + quantize_config.device = torch.device("cuda") + + model_definition = auto.check_and_get_model_definition( + self.NATIVE_MODEL_ID, + self.TRUST_REMOTE_CODE, + ) + model = model_definition.from_pretrained( + pretrained_model_id_or_path=self.NATIVE_MODEL_ID, + quantize_config=quantize_config, + backend=self.LOAD_BACKEND, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + attn_implementation="eager", + ) + model.layer_callback = self._build_layer_stop_callback(0) + + dataset = self.load_dataset(model.tokenizer, rows=self.DATASET_SIZE) + + try: + model.quantize( + dataset, + calibration_concat_size=self.DATASET_CONCAT_SIZE, + calibration_concat_separator=self.DATASET_CONCAT_SEPARATOR, + calibration_sort=self.DATASET_SORT, + backend=self.QUANT_BACKEND, + batch_size=self.QUANT_BATCH_SIZE, + ) + except StopMainLoop: + # The layer callback intentionally stops after layer 0 once the FP4 decode path is observed. + pass + + self._assert_only_first_layer_quantized(model) + + events = [ + entry + for entry in getattr(model, "auto_module_decoder_events", []) + if entry["module"].startswith("model.layers.0.") + ] + assert events, "Expected layer-0 auto-decoder events for Llama-3.3-70B-Instruct-FP4." + assert all(entry["target_dtype"] == "bfloat16" for entry in events) + + expected_mode = self._expected_forward_mode() + if expected_mode == "native": + assert any(entry["forward_mode"] == "native" for entry in events), ( + f"Expected at least one native FP4 forward event, got {events[:8]}" + ) + else: + assert all(entry["forward_mode"] == "decode" for entry in events), ( + f"Expected decode-only events, got {events[:8]}" + ) + finally: + del dataset + del model + torch_empty_cache() diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py index 940112dcd..244b622c8 100644 --- a/tests/models/test_llama4.py +++ b/tests/models/test_llama4.py @@ -5,20 +5,19 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestLlama4(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-4-Scout-17B-16E-Instruct" # "meta-llama/Llama-4-Scout-17B-16E-Instruct" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = False USE_FLASH_ATTN = False def test_llama4(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_longllama.py b/tests/models/test_longllama.py index 8667e71cb..5cb145a83 100644 --- a/tests/models/test_longllama.py +++ b/tests/models/test_longllama.py @@ -6,21 +6,21 @@ from model_test import ModelTest from gptqmodel.utils.backend import BACKEND -from gptqmodel.utils.eval import EVAL class TestLongLlama(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/long_llama_3b_instruct" # "syzymon/long_llama_3b_instruct" TRUST_REMOTE_CODE = True - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.3515, "floor_pct": 0.5}, "acc_norm": {"value": 0.3652, "floor_pct": 0.5}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) USE_VLLM = False USE_FLASH_ATTN = False LOAD_BACKEND = BACKEND.TORCH def test_longllama(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_marin.py b/tests/models/test_marin.py index d3a0c7bfe..58d207d5e 100644 --- a/tests/models/test_marin.py +++ b/tests/models/test_marin.py @@ -5,22 +5,21 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestMarin(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/marin-32b-base" # VRAM_STRATEGY = VramStrategy.BALANCED # Marin inherits Qwen3's backbone with QK-Norm attention. - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5725, "floor_pct": 0.04}, "acc_norm": {"value": 0.6007, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.6670, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_marin(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_mimo.py b/tests/models/test_mimo.py index 3f90d68ea..37bea347d 100644 --- a/tests/models/test_mimo.py +++ b/tests/models/test_mimo.py @@ -5,20 +5,19 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestMimo(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/MiMo-7B-RL" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 def test_mimo(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_minicpm_o_4_5.py b/tests/models/test_minicpm_o_4_5.py new file mode 100644 index 000000000..96519d3bc --- /dev/null +++ b/tests/models/test_minicpm_o_4_5.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +import os + +from model_test import ModelTest +from PIL import Image + + +class TestMiniCPMO4_5(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/MiniCPM-o-4_5" # openbmb/MiniCPM-o-4_5 + TRUST_REMOTE_CODE = True + EVAL_BATCH_SIZE = 1 + + def test_minicpm_o_4_5(self): + # Evalution does not support minicpmo, and will throw an error during execution: + # E TypeError: MiniCPMO.forward() missing 1 required positional argument: 'data + with self.model_compat_test_context(): + model, tokenizer, processor = self.quantModel( + self.NATIVE_MODEL_ID, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + batch_size=1, + call_perform_post_quant_validation=False, + ) + + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ovis/10016.jpg") + image = Image.open(image_path).convert('RGB') + + # First round chat + question = "What is the landform in the picture?" + msgs = [{'role': 'user', 'content': [image, question]}] + + answer = model.chat( + msgs=msgs, + tokenizer=tokenizer, + ) + + generated_text = "" + for new_text in answer: + generated_text += new_text + + print(f'Output:\n{generated_text}') + + self.assertIn("snow", generated_text.lower()) diff --git a/tests/models/test_minicpm_v_4_5.py b/tests/models/test_minicpm_v_4_5.py new file mode 100644 index 000000000..743db3242 --- /dev/null +++ b/tests/models/test_minicpm_v_4_5.py @@ -0,0 +1,47 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import os.path + +from model_test import ModelTest +from PIL import Image + + +class TestMiniCPMV4_5(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/MiniCPM-V-4_5" + TRUST_REMOTE_CODE = True + EVAL_BATCH_SIZE = 1 + + def test_minicpm_v_4_5(self): + # Evalution does not support minicpmv, and will throw an error during execution: + # E TypeError: MiniCPMV.forward() missing 1 required positional argument: 'data + with self.model_compat_test_context(): + model, tokenizer, processor = self.quantModel( + self.NATIVE_MODEL_ID, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + batch_size=1, + call_perform_post_quant_validation=False, + ) + + image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ovis/10016.jpg") + image = Image.open(image_path).convert('RGB') + + # First round chat + question = "What is the landform in the picture?" + msgs = [{'role': 'user', 'content': [image, question]}] + + answer = model.chat( + msgs=msgs, + tokenizer=tokenizer, + ) + + generated_text = "" + for new_text in answer: + generated_text += new_text + + print(f'Output:\n{generated_text}') + + self.assertIn("snow", generated_text.lower()) diff --git a/tests/models/test_minimax_m2.py b/tests/models/test_minimax_m2.py index 3d91b8f23..e17991a9c 100644 --- a/tests/models/test_minimax_m2.py +++ b/tests/models/test_minimax_m2.py @@ -5,8 +5,6 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestMinimaxM2(ModelTest): @@ -16,14 +14,15 @@ class TestMinimaxM2(ModelTest): DELETE_QUANTIZED_MODEL = False DATASET_SIZE = 1024 GROUP_SIZE = 32 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5026, "floor_pct": 0.04}, "acc_norm": {"value": 0.5171, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.6362, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_minimax_m2(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_minimax_m2_hf.py b/tests/models/test_minimax_m2_hf.py new file mode 100644 index 000000000..92b06d6d2 --- /dev/null +++ b/tests/models/test_minimax_m2_hf.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +""" +MiniMax-M2 Hugging Face checkpoint sanity check with streaming output. + +Usage: + python tests/models/test_minimax_m2_hf.py \ + --model-path /monster/data/model/MiniMax-M2-bf16 \ + --question "How many letter A are there in the word Alphabet? Reply with the number only." +""" + +from __future__ import annotations + +import argparse +import threading +from pathlib import Path + +import torch.nn as nn +from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer + +# from gptqmodel.hf_minimax_m2.modeling_minimax_m2 import ( +# MiniMaxAttention, +# MiniMaxDecoderLayer, +# MiniMaxForCausalLM, +# MiniMaxMLP, +# MiniMaxM2Attention, +# MiniMaxM2DecoderLayer, +# MiniMaxM2ForCausalLM, +# MiniMaxM2MLP, +# MiniMaxM2RMSNorm, +# MiniMaxM2SparseMoeBlock, +# MiniMaxRMSNorm, +# MiniMaxSparseMoeBlock, +# ) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="MiniMax-M2 HF checkpoint smoke test.") + parser.add_argument( + "--model-path", + type=str, + default="/monster/data/model/MiniMax-M2-bf16", + help="Path to the MiniMax-M2 Hugging Face checkpoint directory.", + ) + parser.add_argument( + "--question", + type=str, + default="How many letter A are there in the word Alphabet? Reply with the number only.", + help="User question to send through the chat template.", + ) + parser.add_argument( + "--max-new-tokens", + type=int, + default=512, + help="Maximum number of new tokens to sample from the model.", + ) + return parser.parse_args() + + +def build_prompt(tokenizer: AutoTokenizer, question: str) -> str: + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": question}, + ] + return tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) + + +# def assert_module_types(model: MiniMaxM2ForCausalLM) -> None: +# causal_lm_types = (MiniMaxM2ForCausalLM, MiniMaxForCausalLM) +# decoder_layer_types = (MiniMaxM2DecoderLayer, MiniMaxDecoderLayer) +# attention_types = (MiniMaxM2Attention, MiniMaxAttention) +# moe_block_types = (MiniMaxM2SparseMoeBlock, MiniMaxSparseMoeBlock) +# norm_types = (MiniMaxM2RMSNorm, MiniMaxRMSNorm) +# mlp_types = (MiniMaxM2MLP, MiniMaxMLP) +# +# assert isinstance( +# model, causal_lm_types +# ), f"Expected MiniMaxM2ForCausalLM/MiniMaxForCausalLM, received {type(model).__name__}" +# +# decoder = getattr(model, "model", None) +# assert decoder is not None, "Model is missing the `model` attribute with decoder layers." +# +# for layer_idx, layer in enumerate(decoder.layers): +# assert isinstance( +# layer, decoder_layer_types +# ), f"Layer {layer_idx}: expected MiniMax(M2)DecoderLayer, got {type(layer).__name__}" +# assert isinstance( +# layer.self_attn, attention_types +# ), f"Layer {layer_idx}: unexpected self_attn type {type(layer.self_attn).__name__}" +# assert isinstance( +# layer.block_sparse_moe, moe_block_types +# ), f"Layer {layer_idx}: unexpected MoE block type {type(layer.block_sparse_moe).__name__}" +# assert isinstance( +# layer.input_layernorm, norm_types +# ), f"Layer {layer_idx}: unexpected input_layernorm type {type(layer.input_layernorm).__name__}" +# assert isinstance( +# layer.post_attention_layernorm, norm_types +# ), f"Layer {layer_idx}: unexpected post_attention_layernorm type {type(layer.post_attention_layernorm).__name__}" +# +# moe_block = layer.block_sparse_moe +# assert isinstance( +# moe_block.experts, nn.ModuleList +# ), f"Layer {layer_idx}: expected experts to be a ModuleList, got {type(moe_block.experts).__name__}" +# for expert_idx, expert in enumerate(moe_block.experts): +# assert isinstance( +# expert, mlp_types +# ), f"Layer {layer_idx} expert {expert_idx}: expected MiniMax(M2)MLP, got {type(expert).__name__}" +# + + +def main() -> None: + args = parse_args() + model_path = Path(args.model_path).expanduser().resolve() + + print(f"Loading tokenizer from {model_path}...") + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + + print(f"Loading model from {model_path}...") + model = AutoModelForCausalLM.from_pretrained( + model_path, + dtype="bfloat16", + device_map="auto", + trust_remote_code=True, + ) + + # Uncomment to enforce module type checks. + # print("Validating module types...") + # assert_module_types(model) + + prompt = build_prompt(tokenizer, args.question) + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + + print("Running generation (streaming)...\n") + streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=False) + eos_ids = model.generation_config.eos_token_id + if eos_ids is None: + eos_ids = [] + elif isinstance(eos_ids, int): + eos_ids = [eos_ids] + think_end_id = tokenizer.convert_tokens_to_ids("") + if think_end_id is not None and think_end_id not in eos_ids: + eos_ids = eos_ids + [think_end_id] + + generation_kwargs = dict( + **inputs, + max_new_tokens=args.max_new_tokens, + streamer=streamer, + eos_token_id=eos_ids if eos_ids else None, + ) + + generation_thread = threading.Thread(target=model.generate, kwargs=generation_kwargs) + generation_thread.start() + + completion = [] + first_chunk = True + seen_end_reasoning = False + for text in streamer: + if first_chunk: + print("", end="", flush=True) + completion.append("") + first_chunk = False + print(text, end="", flush=True) + completion.append(text) + if "" in text: + seen_end_reasoning = True + + generation_thread.join() + print("\n\n=== Completed Response ===") + final_text = "".join(completion).strip() + print(final_text or "") + if not seen_end_reasoning: + print("\n[warning] No token detected in streamed output.", flush=True) + + +if __name__ == "__main__": + main() diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 016d350d2..b4cdc74bd 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -5,22 +5,25 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestMistral(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Mistral-7B-Instruct-v0.2" # "mistralai/Mistral-7B-Instruct-v0.2" NATIVE_ARC_CHALLENGE_ACC = 0.5427 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5597 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_mistral(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_mistral3.py b/tests/models/test_mistral3.py index 3bd656741..7e99984ac 100644 --- a/tests/models/test_mistral3.py +++ b/tests/models/test_mistral3.py @@ -5,22 +5,25 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestMistral3(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Ministral-3-3B-Instruct-2512-BF16" # "mistralai/Ministral-3-3B-Instruct-2512-BF16" NATIVE_ARC_CHALLENGE_ACC = 0.4974 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5256 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": False, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_mistral3(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_mixtral.py b/tests/models/test_mixtral.py index 178a0b080..8079261ce 100644 --- a/tests/models/test_mixtral.py +++ b/tests/models/test_mixtral.py @@ -5,22 +5,24 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestMixtral(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Mixtral-8x7B-Instruct-v0.1" # "mistralai/Mixtral-8x7B-Instruct-v0.1" NATIVE_ARC_CHALLENGE_ACC = 0.5213 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5247 - TRUST_REMOTE_CODE = True + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW EVAL_BATCH_SIZE = 6 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_mixtral(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_model_test_fast_mode.py b/tests/models/test_model_test_fast_mode.py new file mode 100644 index 000000000..6e37c8f73 --- /dev/null +++ b/tests/models/test_model_test_fast_mode.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from types import SimpleNamespace + +import torch.nn as nn +from model_test import ModelTest + + +class _DummyCompatCase(ModelTest): + __test__ = False + MODEL_COMPAT_FAST_LAYER_COUNT = 2 + + def runTest(self): + return None + + +class _FakeQuantModel: + def __init__(self, layer_count: int): + self.model = SimpleNamespace(layers=nn.ModuleList([nn.Linear(1, 1) for _ in range(layer_count)])) + self.quantize_config = SimpleNamespace(dynamic=None) + + @staticmethod + def extract_layers_node() -> str: + return "layers" + + +def test_model_test_fast_mode_defaults_to_last_layers(monkeypatch): + monkeypatch.delenv("GPTQMODEL_FAST_LAYER_POSITION", raising=False) + case = _DummyCompatCase(methodName="runTest") + model = _FakeQuantModel(layer_count=6) + + with case.model_compat_test_context(): + dynamic = case._build_fast_model_compat_dynamic(model) + + assert sorted(dynamic) == [ + "-:^layers\\.0\\.", + "-:^layers\\.1\\.", + "-:^layers\\.2\\.", + "-:^layers\\.3\\.", + ] + + +def test_model_test_fast_mode_first_layers_remain_configurable(monkeypatch): + monkeypatch.setenv("GPTQMODEL_FAST_LAYER_POSITION", "first") + case = _DummyCompatCase(methodName="runTest") + model = _FakeQuantModel(layer_count=6) + + with case.model_compat_test_context(): + dynamic = case._build_fast_model_compat_dynamic(model) + + assert sorted(dynamic) == [ + "-:^layers\\.2\\.", + "-:^layers\\.3\\.", + "-:^layers\\.4\\.", + "-:^layers\\.5\\.", + ] diff --git a/tests/models/test_mpt.py b/tests/models/test_mpt.py index 3f2c7b29b..081d3ee10 100644 --- a/tests/models/test_mpt.py +++ b/tests/models/test_mpt.py @@ -5,23 +5,26 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestMpt(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/mpt-7b-instruct" # "mosaicml/mpt-7b-instruct" NATIVE_ARC_CHALLENGE_ACC = 0.4275 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4454 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 USE_FLASH_ATTN = False - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": False, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_mpt(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_multi_vs_single_gpu.py b/tests/models/test_multi_vs_single_gpu.py index 416d7c18a..54df8288d 100644 --- a/tests/models/test_multi_vs_single_gpu.py +++ b/tests/models/test_multi_vs_single_gpu.py @@ -25,7 +25,6 @@ QUANT_LOG_NSAMPLES, ) from gptqmodel.quantization.config import QuantizeConfig -from gptqmodel.utils.eval import EVAL from gptqmodel.utils.torch import torch_empty_cache @@ -47,13 +46,13 @@ def _is_free_threaded() -> bool: class TestMultiVsSingleGPU(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3311, "floor_pct": 0.05}, "acc_norm": {"value": 0.3549, "floor_pct": 0.05}, }, } - GPTQA = False + GPTAQ = None DEBUG = True ACT_GROUP_AWARE = False DESC_ACT = True @@ -168,7 +167,7 @@ def layer_complete(self, *, layer_idx: int, submodule_finalized: bool): group_size=self.GROUP_SIZE, desc_act=self.DESC_ACT if not self.ACT_GROUP_AWARE else False, act_group_aware=self.ACT_GROUP_AWARE, - failsafe=self.FAILSAFE, + fallback=self.FALLBACK, sym=self.SYM, v2=self.V2, adapter=self.EORA, @@ -353,8 +352,8 @@ def _capture_primary_handles(primary_handles: Dict[str, str]): original_preprocess = GPTQProcessor.preprocess - def wrapped_preprocess(self, module, failsafe=None): # type: ignore[override] - result = original_preprocess(self, module, failsafe=failsafe) + def wrapped_preprocess(self, module, fallback=None): # type: ignore[override] + result = original_preprocess(self, module, fallback=fallback) task = self.tasks.get(module.name) if task is not None: primary_handles[module.name] = hex(id(task)) diff --git a/tests/models/test_nemotron_ultra.py b/tests/models/test_nemotron_ultra.py index 926daef59..743a9c5b8 100644 --- a/tests/models/test_nemotron_ultra.py +++ b/tests/models/test_nemotron_ultra.py @@ -5,19 +5,18 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestNemotronUltra(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3_1-Nemotron-Ultra-253B-v1" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = True def test_nemotron_ultra(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_opt.py b/tests/models/test_opt.py index 31ad76284..982d86d00 100644 --- a/tests/models/test_opt.py +++ b/tests/models/test_opt.py @@ -11,7 +11,11 @@ class TestOpt(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/opt-125m" # "facebook/opt-125m" NATIVE_ARC_CHALLENGE_ACC = 0.1920 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2253 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.18430034129692832 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.2226962457337884 INPUTS_MAX_LENGTH = 2048 # opt embedding is max 2048 def test_opt(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_ovis2.py b/tests/models/test_ovis2.py index babda3b9b..6fb808fa0 100644 --- a/tests/models/test_ovis2.py +++ b/tests/models/test_ovis2.py @@ -42,9 +42,13 @@ def test_ovis(self): inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16) with torch.inference_mode(): - output_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False) - generated_ids = [output_ids[len(input_ids):] for input_ids, output_ids in zip(inputs.input_ids, output_ids)] - output = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + output = self.generate_stable_with_limit( + model, + processor, + inputs=inputs, + max_new_tokens=128, + batch_decode=True, + ) print(f'Output:\n{output}') self.assertIn("snow", output.lower()) diff --git a/tests/models/test_ovis_1_6_llama.py b/tests/models/test_ovis_1_6_llama.py index a21493f46..e0f820da1 100644 --- a/tests/models/test_ovis_1_6_llama.py +++ b/tests/models/test_ovis_1_6_llama.py @@ -18,10 +18,10 @@ class TestOvis1_6_Llama(ModelTest): USE_FLASH_ATTN = False def test_ovis_1_6(self): - # lm_eval does not support Ovis, and will throw an error during execution: + # the evaluation harness does not support Ovis, and will throw an error during execution: # TypeError: Ovis.forward() missing 3 required positional arguments: 'attention_mask', 'labels', and 'pixel_values' - model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, - dtype=self.TORCH_DTYPE, multimodal_max_length=8192, batch_size=1, call_perform_post_quant_validation=False) + model, tokenizer, _ = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, multimodal_max_length=8192, batch_size=1, call_perform_post_quant_validation=False) text_tokenizer = model.get_text_tokenizer() visual_tokenizer = model.get_visual_tokenizer() @@ -38,23 +38,22 @@ def test_ovis_1_6(self): input_ids = input_ids.unsqueeze(0).to(device=model.device) attention_mask = attention_mask.unsqueeze(0).to(device=model.device) pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)] + inputs = { + "input_ids": input_ids, + "pixel_values": pixel_values, + "attention_mask": attention_mask, + } # generate output with torch.inference_mode(): - gen_kwargs = { - "max_new_tokens": 1024, - "do_sample": False, - "top_p": None, - "top_k": None, - "temperature": None, - "repetition_penalty": None, - "eos_token_id": model.generation_config.eos_token_id, - "pad_token_id": text_tokenizer.pad_token_id, - "use_cache": True - } - output_ids = \ - model.generate(input_ids, pixel_values=pixel_values, attention_mask=attention_mask, **gen_kwargs)[0] - output = text_tokenizer.decode(output_ids, skip_special_tokens=True) + output = self.generate_stable_with_limit( + model, + text_tokenizer, + inputs=inputs, + max_new_tokens=1024, + skip_special_tokens=True, + use_cache=True, + ) print(f'Output:\n{output}') diff --git a/tests/models/test_pangu_alpha.py b/tests/models/test_pangu_alpha.py index 2fc2f7e2d..4aa345b09 100644 --- a/tests/models/test_pangu_alpha.py +++ b/tests/models/test_pangu_alpha.py @@ -11,9 +11,13 @@ class TestGpt2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/pangu_alpha_2_6B" # "ModelCloud/pangu_alpha_2_6B" NATIVE_ARC_CHALLENGE_ACC = 0.1655 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.1945 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TORCH_DTYPE = torch.float16 TRUST_REMOTE_CODE = True INPUTS_MAX_LENGTH = 1024 def test_gpt2(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_phi_3.py b/tests/models/test_phi_3.py index dcbac750f..666de1d55 100644 --- a/tests/models/test_phi_3.py +++ b/tests/models/test_phi_3.py @@ -5,21 +5,30 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestPhi_3(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Phi-3-mini-4k-instruct" # "microsoft/Phi-3-mini-4k-instruct" NATIVE_ARC_CHALLENGE_ACC = 0.5401 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.5477815699658704 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.5742320819112628 TRUST_REMOTE_CODE = True - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = { + "arc_challenge": { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC_FAST, "ceil_pct": 1.0}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM_FAST, "ceil_pct": 1.0}, + }, + } def test_phi_3(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_phi_3_moe.py b/tests/models/test_phi_3_moe.py index ceb16af9c..83530d4c9 100644 --- a/tests/models/test_phi_3_moe.py +++ b/tests/models/test_phi_3_moe.py @@ -5,21 +5,23 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestPhi_3(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Phi-3.5-MoE-instruct" # microsoft/Phi-3.5-MoE-instruct NATIVE_ARC_CHALLENGE_ACC = 0.5401 - NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 - TRUST_REMOTE_CODE = True - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5051 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_phi_3(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_phi_4.py b/tests/models/test_phi_4.py index 0c6a11a13..6855e4460 100644 --- a/tests/models/test_phi_4.py +++ b/tests/models/test_phi_4.py @@ -5,23 +5,26 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestPhi_4(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Phi-4-multimodal-instruct" # "microsoft/Phi-3-mini-4k-instruct" NATIVE_ARC_CHALLENGE_ACC = 0.5401 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = True USE_FLASH_ATTN = False BATCH_SIZE = 1 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_phi_4(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py index 695d757d1..a01abfe5f 100644 --- a/tests/models/test_qwen2_5.py +++ b/tests/models/test_qwen2_5.py @@ -5,10 +5,7 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - -# | Metric | MARLIN | # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.2961 | # | arc_challenge :: acc_norm,none | 0.3285 | @@ -20,22 +17,51 @@ class TestQwen2_5(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" EVAL_BATCH_SIZE = 64 DATASET_CONCAT_SIZE = 2048 - EVAL_TASKS = { - EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { "chat_template": True, - "exact_match,flexible-extract": { + "acc,num": { "value": 0.2963, "floor_pct": 0.04, }, }, - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "acc": {"value": 0.2961, "floor_pct": 0.04}, "acc_norm": {"value": 0.3285, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.3942, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "acc,num": { + "value": 0.38626964433416044, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "arc_challenge": { + "acc": { + "value": 0.2977815699658703, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + "acc_norm": { + "value": 0.34044368600682595, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + "mmlu_stem": { + "acc": { + "value": 0.3967649857278782, + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + } def test_qwen2_5(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen2_5_omni.py b/tests/models/test_qwen2_5_omni.py index 12b010154..749716c2c 100644 --- a/tests/models/test_qwen2_5_omni.py +++ b/tests/models/test_qwen2_5_omni.py @@ -3,18 +3,19 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium import os +import unittest +from importlib.metadata import PackageNotFoundError, version -import soundfile as sf from model_test import ModelTest +from packaging.version import Version from gptqmodel.models.definitions.qwen2_5_omni import Qwen2_5_OmniGPTQ -from gptqmodel.utils.eval import EVAL class TestQwen2_5_Omni(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-Omni-3B" EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.2329, "floor_pct": 0.2}, "acc_norm": {"value": 0.2765, "floor_pct": 0.2}, @@ -23,7 +24,36 @@ class TestQwen2_5_Omni(ModelTest): TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 + @classmethod + def setUpClass(cls): + super().setUpClass() + + required = { + "audioread": Version("3.1.0"), + "librosa": Version("0.11.0"), + "av": Version("16.0.1"), + } + for pkg, minimum in required.items(): + try: + installed = Version(version(pkg)) + except PackageNotFoundError: + raise unittest.SkipTest( + f"Qwen2.5 Omni requires {pkg}>={minimum}" + ) + + if installed < minimum: + raise unittest.SkipTest( + f"Qwen2.5 Omni requires {pkg}>={minimum}, found {installed}" + ) + + try: + version("soundfile") + except PackageNotFoundError: + raise unittest.SkipTest("Qwen2.5 Omni requires soundfile") + def test_qwen2_5_omni(self): + import soundfile as sf + model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, dtype=self.TORCH_DTYPE) spk_path = self.NATIVE_MODEL_ID + '/spk_dict.pt' @@ -68,7 +98,14 @@ def test_qwen2_5_omni(self): # Inference: Generation of the output (text and audio) audio_file_name = 'output_gptq.wav' - generated_ids, audio = model.generate(**inputs, max_new_tokens=128, return_audio = True) + generated_ids, audio = self.generate_stable_with_limit( + model, + processor, + inputs=inputs, + max_new_tokens=128, + return_generate_output=True, + return_audio=True, + ) sf.write( audio_file_name, audio.reshape(-1).detach().cpu().numpy(), diff --git a/tests/models/test_qwen2_5_vl.py b/tests/models/test_qwen2_5_vl.py index 09609c796..35426dc6b 100644 --- a/tests/models/test_qwen2_5_vl.py +++ b/tests/models/test_qwen2_5_vl.py @@ -6,24 +6,25 @@ from model_test import ModelTest from gptqmodel.models.definitions.qwen2_5_vl import Qwen2_5_VLQModel -from gptqmodel.utils.eval import EVAL class TestQwen2_5_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-VL-3B-Instruct" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.4309, "floor_pct": 0.2}, "acc_norm": {"value": 0.4113, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 def test_qwen2_vl(self): - model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, - dtype=self.TORCH_DTYPE) + with self.model_compat_test_context(): + model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE) # check image to text messages = [ @@ -54,22 +55,24 @@ def test_qwen2_vl(self): inputs = inputs.to("cuda") # Inference: Generation of the output - generated_ids = model.generate(**inputs, max_new_tokens=128) - generated_ids_trimmed = [ - out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) - ] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - )[0] + output_text = self.generate_stable_with_limit( + model, + processor, + inputs=inputs, + max_new_tokens=128, + batch_decode=True, + clean_up_tokenization_spaces=False, + ) print("output_text:", output_text) self.assertIn("dog", output_text) - # check lm_eval results + # check evaluation results self.check_kernel(model, self.KERNEL_INFERENCE) - task_results = self.lm_eval(model=model, - trust_remote_code=self.TRUST_REMOTE_CODE, - delete_quantized_model=self.DELETE_QUANTIZED_MODEL) + with self.model_compat_test_context(): + task_results = self.evaluate_model(model=model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL) self.check_results(task_results) diff --git a/tests/models/test_qwen2_moe_quant.py b/tests/models/test_qwen2_moe_quant.py index f614b93fd..c12f5fab6 100644 --- a/tests/models/test_qwen2_moe_quant.py +++ b/tests/models/test_qwen2_moe_quant.py @@ -5,18 +5,17 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestQwen2_5_Moe(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen1.5-MoE-A2.7B" # Qwen/Qwen1.5-MoE-A2.7B - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_qwen2_5(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen2_vl.py b/tests/models/test_qwen2_vl.py index e9d234321..4206fd330 100644 --- a/tests/models/test_qwen2_vl.py +++ b/tests/models/test_qwen2_vl.py @@ -6,24 +6,25 @@ from model_test import ModelTest from gptqmodel.models.definitions.qwen2_vl import Qwen2VLQModel -from gptqmodel.utils.eval import EVAL class TestQwen2_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2-VL-2B-Instruct" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3524, "floor_pct": 0.2}, "acc_norm": {"value": 0.3763, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 def test_qwen2_vl(self): - model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, - dtype=self.TORCH_DTYPE) + with self.model_compat_test_context(): + model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE) # check image to text messages = [ @@ -54,22 +55,24 @@ def test_qwen2_vl(self): inputs = inputs.to("cuda") # Inference: Generation of the output - generated_ids = model.generate(**inputs, max_new_tokens=128) - generated_ids_trimmed = [ - out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) - ] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - )[0] + output_text = self.generate_stable_with_limit( + model, + processor, + inputs=inputs, + max_new_tokens=128, + batch_decode=True, + clean_up_tokenization_spaces=False, + ) print("output_text:", output_text) self.assertIn("dog", output_text) - # check lm_eval results + # check evaluation results self.check_kernel(model, self.KERNEL_INFERENCE) - task_results = self.lm_eval(model=model, - trust_remote_code=self.TRUST_REMOTE_CODE, - delete_quantized_model=self.DELETE_QUANTIZED_MODEL) + with self.model_compat_test_context(): + task_results = self.evaluate_model(model=model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL) self.check_results(task_results) diff --git a/tests/models/test_qwen3.py b/tests/models/test_qwen3.py new file mode 100644 index 000000000..eb851bf0b --- /dev/null +++ b/tests/models/test_qwen3.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +from model_test import ModelTest + + +class TestQwen3(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3-4B" # Qwen/Qwen3-4B + EVAL_TASKS_SLOW = { + "arc_challenge": { + "acc": {"value": {"A100": 0.5094}, "floor_pct": 0.04}, + "acc_norm": {"value": {"A100": 0.5145}, "floor_pct": 0.04}, + }, + "mmlu_stem": { + "acc": {"value": {"A100": 0.7101}, "floor_pct": 0.04}, + }, + } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) + MODEL_COMPAT_FAST_LAYER_POSITION = "first" + + + def test_qwen3(self): + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen3_5.py b/tests/models/test_qwen3_5.py index 456ca06e5..49e6c506c 100644 --- a/tests/models/test_qwen3_5.py +++ b/tests/models/test_qwen3_5.py @@ -4,10 +4,7 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - -# | Metric | MARLIN | # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.6092 | # | arc_challenge :: acc_norm,none | 0.6143 | @@ -18,16 +15,17 @@ class TestQwen3_5(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3.5-27B" EVAL_BATCH_SIZE = 64 DATASET_CONCAT_SIZE = 2048 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.6092, "floor_pct": 0.04}, "acc_norm": {"value": 0.6143, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.8461, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_qwen3_5(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen3_5_moe.py b/tests/models/test_qwen3_5_moe.py index 2f9be30fe..810d81619 100644 --- a/tests/models/test_qwen3_5_moe.py +++ b/tests/models/test_qwen3_5_moe.py @@ -4,22 +4,21 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest -from gptqmodel.quantization.config import FailSafe, VramStrategy -from gptqmodel.utils.eval import EVAL +from gptqmodel.quantization.config import ExpertsRoutingOverride, Fallback, MoEConfig, VramStrategy class TestQwen3_5Moe(ModelTest): - FAILSAFE = FailSafe() + FALLBACK = Fallback() # FORMAT = FORMAT.GEMM # METHOD = METHOD.AWQ NATIVE_MODEL_ID = "/monster/data/model/Qwen3.5-35B-A3B" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5887, "floor_pct": 0.04}, "acc_norm": {"value": 0.6100, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "chat_template": False, "acc": { "value": 0.8106, @@ -27,9 +26,15 @@ class TestQwen3_5Moe(ModelTest): }, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) - VRAM_STRATEGY = VramStrategy.BALANCED - OFFLOAD_TO_DISK = False # FIXME Currently, after defuser transforms the model, OFFLOAD_TO_DISK must be False for quantization. + DENSE_VRAM_STRATEGY = VramStrategy.EXCLUSIVE + # Keep the dense serial path on the first visible GPU and spread experts across the rest. + DENSE_VRAM_STRATEGY_DEVICES = ["cuda:0"] + MOE_VRAM_STRATEGY = VramStrategy.BALANCED + MOE_VRAM_STRATEGY_DEVICES = ["cuda:1", "cuda:2", "cuda:3", "cuda:4", "cuda:5"] + # Route every calibration token through every expert so MoE quant sees full coverage. + MOE_CONFIG = MoEConfig(routing=ExpertsRoutingOverride(num_experts_per_tok="all")) def test_qwen3_5_moe(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen3_5_moe_ab_regression.py b/tests/models/test_qwen3_5_moe_ab_regression.py new file mode 100644 index 000000000..0b65a689d --- /dev/null +++ b/tests/models/test_qwen3_5_moe_ab_regression.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import subprocess +import sys +import tempfile +import unittest +from pathlib import Path + + +class TestQwen3_5MoeABRegression(unittest.TestCase): + """Opt-in real-model A/B benchmark against a baseline worktree.""" + + RUN_ENV = "GPTQMODEL_RUN_QWEN35_MOE_AB" + BASELINE_ENV = "GPTQMODEL_BASELINE_ROOT" + + def test_qwen3_5_moe_two_layer_ab_benchmark(self) -> None: + if os.environ.get(self.RUN_ENV, "").strip().lower() not in {"1", "true", "yes", "on"}: + self.skipTest(f"Set {self.RUN_ENV}=1 to run the real Qwen 3.5 MoE A/B benchmark.") + + repo_root = Path(__file__).resolve().parents[2] + script_path = repo_root / "scripts" / "benchmark_qwen35_moe_ab.py" + baseline_root = Path(os.environ.get(self.BASELINE_ENV, "/root/gptqmodel-main")).resolve() + if not baseline_root.exists(): + self.skipTest(f"Baseline repo root does not exist: {baseline_root}") + + env = os.environ.copy() + env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + env.setdefault("CUDA_VISIBLE_DEVICES", "0,1") + env["PYTHON_GIL"] = "0" + env["DEBUG"] = "1" + + with tempfile.TemporaryDirectory(prefix="qwen35_moe_ab_regression_") as temp_dir: + subprocess.run( + [ + sys.executable, + str(script_path), + "--current-root", + str(repo_root), + "--baseline-root", + str(baseline_root), + "--output-dir", + temp_dir, + "--current-vram-strategy", + "dense_home_moe_balanced", + "--baseline-vram-strategy", + "balanced", + ], + check=True, + cwd=repo_root, + env=env, + ) diff --git a/tests/models/test_qwen3_8b_fp8_auto_decoder.py b/tests/models/test_qwen3_8b_fp8_auto_decoder.py new file mode 100644 index 000000000..a70a5f48c --- /dev/null +++ b/tests/models/test_qwen3_8b_fp8_auto_decoder.py @@ -0,0 +1,188 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import sys + +import torch + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.looper.module_looper import StopMainLoop +from gptqmodel.models import auto +from gptqmodel.quantization import AutoModuleDecoderConfig +from gptqmodel.quantization.dtype import get_device_dtype_support +from gptqmodel.utils.torch import torch_empty_cache + + +FIRST_LAYER_ONLY_NEGATIVE_MATCH = r"^model\.layers\.(?!0\.)\d+\." +# Restrict the regression to layer 0 so one visible GPU can validate the mode switch quickly. + + +class TestQwen3_8BFp8AutoDecoder(ModelTest): + """Verify FP8 auto-decoder mode selection on one visible GPU for local Qwen3-8B-FP8.""" + + NATIVE_MODEL_ID = "/mnt/SFS-6CFyUykx/models/Qwen3-8B-FP8" + LOAD_BACKEND = BACKEND.TORCH + QUANT_BACKEND = BACKEND.TORCH + TORCH_DTYPE = "bfloat16" + USE_FLASH_ATTN = False + QUANT_BATCH_SIZE = 1 + DATASET_SIZE = 4 + DATASET_CONCAT_SIZE = 2048 + OFFLOAD_TO_DISK = True + DYNAMIC = { + f"-:{FIRST_LAYER_ONLY_NEGATIVE_MATCH}": {}, + } + CALIBRATION_DATASET = [ + "Explain how rotary position embeddings are applied in decoder-only transformers.", + "Summarize the tradeoffs between dense weights and FP8 checkpoint storage.", + "Describe how auto-decoder preprocessing can switch between native and decoded forward paths.", + "Write a short note on why calibration examples matter for GPTQ quantization.", + ] + + def _build_quantize_config(self): + cfg = super()._build_quantize_config() + cfg.preprocessors = [ + AutoModuleDecoderConfig( + target_dtype=torch.bfloat16, + ) + ] + cfg.wait_for_submodule_finalizers = True + return cfg + + def _expected_forward_mode(self) -> str: + support = get_device_dtype_support(torch.device("cuda"), validate=False) + return "native" if torch.float8_e4m3fn in support.advertised_linear_dtypes else "decode" + + def test_qwen3_8b_fp8_auto_decoder_selects_forward_role_by_gpu_capability(self) -> None: + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for the Qwen3-8B FP8 auto-decoder test.") + + model = None + dataset = None + try: + quantize_config = self._build_quantize_config() + quantize_config.device = torch.device("cuda") + + model_definition = auto.check_and_get_model_definition( + self.NATIVE_MODEL_ID, + self.TRUST_REMOTE_CODE, + ) + model = model_definition.from_pretrained( + pretrained_model_id_or_path=self.NATIVE_MODEL_ID, + quantize_config=quantize_config, + backend=self.LOAD_BACKEND, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + attn_implementation="eager", + ) + model.layer_callback = self._build_layer_stop_callback(0) + + dataset = self.CALIBRATION_DATASET[: self.DATASET_SIZE] + + try: + model.quantize( + dataset, + calibration_concat_size=self.DATASET_CONCAT_SIZE, + calibration_concat_separator=self.DATASET_CONCAT_SEPARATOR, + calibration_sort=self.DATASET_SORT, + backend=self.QUANT_BACKEND, + batch_size=self.QUANT_BATCH_SIZE, + ) + except StopMainLoop: + # The layer callback intentionally stops after layer 0 once the mode decision is observed. + pass + + events = [ + entry + for entry in getattr(model, "auto_module_decoder_events", []) + if entry["module"].startswith("model.layers.0.") + ] + self.assertTrue(events, "Expected layer-0 auto-decoder events for Qwen3-8B-FP8.") + self.assertTrue(all(entry["source_dtype"] == "float8_e4m3fn" for entry in events)) + self.assertTrue(all(entry["target_dtype"] == "bfloat16" for entry in events)) + + expected_mode = self._expected_forward_mode() + if expected_mode == "native": + self.assertTrue( + any(entry["forward_mode"] == "native" for entry in events), + f"Expected at least one native FP8 forward event, got {events[:8]}", + ) + else: + self.assertTrue( + all(entry["forward_mode"] == "decode" for entry in events), + f"Expected decode-only events, got {events[:8]}", + ) + finally: + del dataset + del model + torch_empty_cache() + + def test_qwen3_8b_fp8_auto_decoder_uses_native_on_weight_scale_checkpoint(self) -> None: + if not torch.cuda.is_available(): + self.skipTest("CUDA is required for the Qwen3-8B FP8 auto-decoder test.") + + support = get_device_dtype_support(torch.device("cuda"), validate=False) + if torch.float8_e4m3fn not in support.advertised_linear_dtypes: + self.skipTest("This regression requires a GPU that advertises native FP8 linear support.") + + model = None + dataset = None + try: + quantize_config = self._build_quantize_config() + quantize_config.device = torch.device("cuda") + + model_definition = auto.check_and_get_model_definition( + self.NATIVE_MODEL_ID, + self.TRUST_REMOTE_CODE, + ) + model = model_definition.from_pretrained( + pretrained_model_id_or_path=self.NATIVE_MODEL_ID, + quantize_config=quantize_config, + backend=self.LOAD_BACKEND, + trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE, + attn_implementation="eager", + ) + model.layer_callback = self._build_layer_stop_callback(0) + + dataset = self.CALIBRATION_DATASET[: self.DATASET_SIZE] + + try: + model.quantize( + dataset, + calibration_concat_size=self.DATASET_CONCAT_SIZE, + calibration_concat_separator=self.DATASET_CONCAT_SEPARATOR, + calibration_sort=self.DATASET_SORT, + backend=self.QUANT_BACKEND, + batch_size=self.QUANT_BATCH_SIZE, + ) + except StopMainLoop: + # The layer callback intentionally stops after layer 0 once the mode decision is observed. + pass + + events = [ + entry + for entry in getattr(model, "auto_module_decoder_events", []) + if entry["module"].startswith("model.layers.0.") + ] + self.assertTrue(events, "Expected layer-0 auto-decoder events for Qwen3-8B-FP8.") + self.assertTrue(all(entry["source_dtype"] == "float8_e4m3fn" for entry in events)) + self.assertTrue(all(entry["target_dtype"] == "bfloat16" for entry in events)) + self.assertTrue( + all(entry["forward_mode"] == "native" for entry in events), + f"Expected native events on this weight_scale-backed checkpoint, got {events[:8]}", + ) + finally: + del dataset + del model + torch_empty_cache() diff --git a/tests/models/test_qwen3_8b_fp8_gsm8k_last4.py b/tests/models/test_qwen3_8b_fp8_gsm8k_last4.py new file mode 100644 index 000000000..6c1f5c4d0 --- /dev/null +++ b/tests/models/test_qwen3_8b_fp8_gsm8k_last4.py @@ -0,0 +1,111 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import sys + +import torch + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + +from gptqmodel import BACKEND +from gptqmodel.quantization import AutoModuleDecoderConfig + + +LAST_FOUR_ONLY_NEGATIVE_MATCH = r"^model\.layers\.(?:[0-2]?\d|3[0-1])\." +# Qwen3-8B-FP8 exposes 36 decoder layers; skip 0-31 so only the last four are quantized. + + +def _gsm8k_expected_acc() -> float: + return float(os.environ.get("GPTQMODEL_QWEN3_8B_FP8_LAST4_GSM8K_ACC", "0.0")) + + +class TestQwen3_8B_FP8_Gsm8kLast4(ModelTest): + # Keep one stable saved checkpoint so eval-only repro runs can reuse the exact post-quant model. + SAVE_PATH = os.environ.get( + "GPTQMODEL_QWEN3_8B_FP8_LAST4_SAVE_PATH", + "/tmp/qwen3_8b_fp8_last4_gptq_saved_ckpt", + ) + DELETE_QUANTIZED_MODEL = False + NATIVE_MODEL_ID = "/mnt/SFS-6CFyUykx/models/Qwen3-8B-FP8" + LOAD_BACKEND = BACKEND.TORCH + QUANT_BACKEND = BACKEND.TORCH + PIN_CUDA_DEVICE = 0 + TORCH_DTYPE = "bfloat16" + USE_FLASH_ATTN = False + QUANT_BATCH_SIZE = 1 + EVAL_BATCH_SIZE = 32 + DATASET_SIZE = 32 + DATASET_CONCAT_SIZE = 2048 + OFFLOAD_TO_DISK = True + DYNAMIC = { + f"-:{LAST_FOUR_ONLY_NEGATIVE_MATCH}": {}, + } + CALIBRATION_DATASET = [ + "Solve the arithmetic word problem carefully and provide the final numeric answer.", + "Reason step by step about a math problem, then end with a short final answer.", + "Explain the difference between calibration data and evaluation data in quantization workflows.", + "Summarize the tradeoffs between FP8 checkpoints and dense bfloat16 weights for inference.", + "Describe why a decoder-only language model may need left padding during batched generation.", + "Write a concise explanation of how GPTQ quantization uses calibration activations.", + "Explain how Qwen-style chat templates can affect benchmark accuracy when applied incorrectly.", + "Give a short note on why some FP8 checkpoints store scale and others store inverse scale.", + ] * 4 + EVAL_TASKS_SLOW = { + "gsm8k_platinum_cot": { + "chat_template": True, + "evalution_use_model_path": True, + "evalution_batch_size": "auto", + "evalution_model_args": { + "dtype": "bfloat16", + "attn_implementation": "paged|flash_attention_2", + "device": "cuda:0", + }, + "evalution_suite_kwargs": { + "batch_size": 32, + "max_new_tokens": 256, + "stream": True, + }, + "acc,num": { + "value": _gsm8k_expected_acc(), + "floor_pct": 0.04, + "ceil_pct": 1.0, + }, + }, + } + EVAL_TASKS_FAST = EVAL_TASKS_SLOW + + @classmethod + def load_dataset(cls, tokenizer=None, rows: int = 0): + del tokenizer + if rows > 0: + return cls.CALIBRATION_DATASET[:rows] + return list(cls.CALIBRATION_DATASET) + + def _build_quantize_config(self): + cfg = super()._build_quantize_config() + # The source checkpoint is FP8, so attach the auto-decoder preprocessor before GPTQ quantization. + cfg.preprocessors = [ + AutoModuleDecoderConfig(target_dtype=torch.bfloat16) + ] + cfg.wait_for_submodule_finalizers = True + return cfg + + def _model_test_mode(self) -> str: + # This regression intentionally validates a fixed last-4-layer quantization recipe, + # so opt out of the harness's default fast-mode layer trimming. + return self.MODEL_TEST_MODE_SLOW + + def test_qwen3_8b_fp8_last4_gsm8k_platinum(self): + if _gsm8k_expected_acc() <= 0.0: + self.skipTest( + "Set GPTQMODEL_QWEN3_8B_FP8_LAST4_GSM8K_ACC to a recorded gsm8k_platinum_cot acc,num baseline." + ) + self.quant_lm_eval() diff --git a/tests/models/test_qwen3_8b_nvfp4.py b/tests/models/test_qwen3_8b_nvfp4.py new file mode 100644 index 000000000..755d201e1 --- /dev/null +++ b/tests/models/test_qwen3_8b_nvfp4.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import os +import sys + +import torch +from datasets import load_dataset as hf_load_dataset + + +TESTS_MODELS_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if TESTS_MODELS_ROOT not in sys.path: + sys.path.insert(0, TESTS_MODELS_ROOT) + +from model_test import ModelTest + +from gptqmodel.quantization import AutoModuleDecoderConfig + + +class TestQwen3_8BNVFP4(ModelTest): + """End-to-end NVFP4 regression for local Qwen3-8B, quantizing the last four layers.""" + + SAVE_PATH = os.environ.get( + "GPTQMODEL_QWEN3_8B_NVFP4_SAVE_PATH", + "/tmp/qwen3_8b_nvfp4_last4_gptq_saved_ckpt", + ) + DELETE_QUANTIZED_MODEL = False + NATIVE_MODEL_ID = "/mnt/SFS-6CFyUykx/models/Qwen3-8B-NVFP4" + PIN_CUDA_DEVICE = 0 + TORCH_DTYPE = "bfloat16" + USE_FLASH_ATTN = False + QUANT_BATCH_SIZE = 1 + DATASET_CONCAT_SIZE = 2048 + OFFLOAD_TO_DISK = True + MODEL_COMPAT_FAST_LAYER_COUNT = 4 + MODEL_COMPAT_FAST_LAYER_POSITION = "last" + + EVAL_TASKS_FAST = { + "gsm8k_platinum_cot": { + "chat_template": True, + "evalution_use_model_path": True, + "evalution_batch_size": "auto", + "evalution_model_args": { + "dtype": "bfloat16", + "attn_implementation": "eager", + "device": "cuda:0", + }, + "evalution_suite_kwargs": { + "batch_size": 16, + "max_new_tokens": 256, + "stream": True, + }, + "acc,num": { + "value": 0.25, + "floor_pct": 0.25, + "ceil_pct": 1.0, + }, + }, + } + + def _build_quantize_config(self): + cfg = super()._build_quantize_config() + cfg.preprocessors = [ + AutoModuleDecoderConfig( + target_dtype=torch.bfloat16, + ) + ] + cfg.wait_for_submodule_finalizers = True + return cfg + + @classmethod + def load_dataset(cls, tokenizer=None, rows: int = 0): + dataset = hf_load_dataset("neuralmagic/calibration", name="LLM", split="train") + if rows > 0: + return dataset.select(range(min(rows, len(dataset)))) + return dataset + + def test_qwen3_8b_nvfp4(self): + self.quant_lm_eval() diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index 7a6b99929..fc53fff2c 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -4,22 +4,20 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest -from gptqmodel.quantization.config import FailSafe, VramStrategy -from gptqmodel.utils.eval import EVAL +from gptqmodel.quantization.config import ExpertsRoutingOverride, Fallback, MoEConfig, VramStrategy -# | Metric | MARLIN | # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.5094 | # | arc_challenge :: acc_norm,none | 0.5486 | -# Qwen3-30B-A3B-MainBranch-FailSafe_Enable +# Qwen3-30B-A3B-MainBranch-Fallback_Enable # # | Tasks |Version|Filter|n-shot| Metric | |Value | |Stderr| # |-------------|------:|------|-----:|--------|---|-----:|---|-----:| # |arc_challenge| 1|none | 0|acc |↑ |0.5307|± |0.0146| # | | |none | 0|acc_norm|↑ |0.5674|± |0.0145| class TestQwen3Moe(ModelTest): - FAILSAFE = FailSafe() + FALLBACK = Fallback() # FORMAT = FORMAT.GEMM # METHOD = METHOD.AWQ @@ -27,8 +25,8 @@ class TestQwen3Moe(ModelTest): # DEVICE = torch.device("cpu") # HESSIAN_CHUNK_SIZE = 256 * 1024 * 1024 NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.5137, "floor_pct": 0.04}, "acc_norm": {"value": 0.5307, "floor_pct": 0.04}, }, @@ -49,14 +47,14 @@ class TestQwen3Moe(ModelTest): # as lead to unreliable evaluation results. # ###################################################################### - EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + "gsm8k_platinum_cot": { "chat_template": False, - "exact_match,flexible-extract": { + "acc,num": { "value": 0.9380, "floor_pct": 0.04, }, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "chat_template": False, "acc": { "value": 0.7805, @@ -64,8 +62,11 @@ class TestQwen3Moe(ModelTest): }, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) - VRAM_STRATEGY = VramStrategy.BALANCED + DENSE_VRAM_STRATEGY = VramStrategy.EXCLUSIVE + MOE_VRAM_STRATEGY = VramStrategy.BALANCED + MOE_CONFIG = MoEConfig(routing=ExpertsRoutingOverride()) # TRUST_REMOTE_CODE = False # APPLY_CHAT_TEMPLATE = True # EVAL_BATCH_SIZE = 6 @@ -80,4 +81,4 @@ class TestQwen3Moe(ModelTest): # CALIB_NOISE_PERCENT = 0.025 def test_mimo(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen3_next.py b/tests/models/test_qwen3_next.py index 504ae934f..bfe8cbe75 100644 --- a/tests/models/test_qwen3_next.py +++ b/tests/models/test_qwen3_next.py @@ -6,27 +6,26 @@ from model_test import ModelTest from gptqmodel.quantization.config import VramStrategy -from gptqmodel.utils.eval import EVAL -# | Metric | MARLIN | # |--------------------------------|----------| # | arc_challenge :: acc,none | 0.6271 | # | arc_challenge :: acc_norm,none | 0.6613 | # | mmlu_stem :: acc,none | 0.8403 | class TestQwen3Next(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Next-80B-A3B-Instruct" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.6271, "floor_pct": 0.04}, "acc_norm": {"value": 0.6613, "floor_pct": 0.04}, }, - EVAL.LM_EVAL.MMLU_STEM: { + "mmlu_stem": { "acc": {"value": 0.8403, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) - VRAM_STRATEGY = VramStrategy.BALANCED + DENSE_VRAM_STRATEGY = VramStrategy.BALANCED # DATASET_SIZE = 2048 # TRUST_REMOTE_CODE = True # APPLY_CHAT_TEMPLATE = True @@ -43,4 +42,4 @@ class TestQwen3Next(ModelTest): # USE_FLASH_ATTN = True def test_mimo(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen3_omni.py b/tests/models/test_qwen3_omni.py index 92617d3d4..57a97b292 100644 --- a/tests/models/test_qwen3_omni.py +++ b/tests/models/test_qwen3_omni.py @@ -5,17 +5,16 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestQwen3Omni(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Omni-30B-A3B-Instruct/" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) # # TRUST_REMOTE_CODE = False # APPLY_CHAT_TEMPLATE = True # # EVAL_BATCH_SIZE = 6 @@ -28,4 +27,4 @@ class TestQwen3Omni(ModelTest): QUANT_BATCH_SIZE = 1 def test_omni(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_qwen3_vl.py b/tests/models/test_qwen3_vl.py index 411801a01..107e81712 100644 --- a/tests/models/test_qwen3_vl.py +++ b/tests/models/test_qwen3_vl.py @@ -6,24 +6,25 @@ from model_test import ModelTest from gptqmodel.models.definitions.qwen3_vl import Qwen3_VLQModel -from gptqmodel.utils.eval import EVAL class TestQwen3_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-VL-2B-Instruct/" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3618, "floor_pct": 0.04}, "acc_norm": {"value": 00.3882, "floor_pct": 0.04}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 def test_qwen3_vl(self): - model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, - dtype=self.TORCH_DTYPE) + with self.model_compat_test_context(): + model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, + dtype=self.TORCH_DTYPE) # check image to text messages = [ @@ -54,22 +55,24 @@ def test_qwen3_vl(self): inputs = inputs.to("cuda") # Inference: Generation of the output - generated_ids = model.generate(**inputs, max_new_tokens=128) - generated_ids_trimmed = [ - out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) - ] - output_text = processor.batch_decode( - generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False - )[0] + output_text = self.generate_stable_with_limit( + model, + processor, + inputs=inputs, + max_new_tokens=128, + batch_decode=True, + clean_up_tokenization_spaces=False, + ) print("output_text:", output_text) self.assertIn("dog", output_text) - # check lm_eval results + # check evaluation results self.check_kernel(model, self.KERNEL_INFERENCE) - task_results = self.lm_eval(model=model, - trust_remote_code=self.TRUST_REMOTE_CODE, - delete_quantized_model=self.DELETE_QUANTIZED_MODEL) + with self.model_compat_test_context(): + task_results = self.evaluate_model(model=model, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL) self.check_results(task_results) diff --git a/tests/models/test_seed_oss.py b/tests/models/test_seed_oss.py index eb231b621..1c9cfbe4d 100644 --- a/tests/models/test_seed_oss.py +++ b/tests/models/test_seed_oss.py @@ -5,20 +5,19 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestSeedOSS(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Seed-OSS-36B-Instruct/" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 def test_seed_oss(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_stablelm.py b/tests/models/test_stablelm.py index ffd9e859d..adcab0cea 100644 --- a/tests/models/test_stablelm.py +++ b/tests/models/test_stablelm.py @@ -5,19 +5,23 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestStablelm(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/stablelm-base-alpha-3b" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "acc": {"value": 0.2363, "floor_pct": 0.2}, "acc_norm": {"value": 0.2577, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = { + "arc_challenge": { + "acc": {"value": 0.23720136518771331, "floor_pct": 0.2, "ceil_pct": 1.0}, + "acc_norm": {"value": 0.26023890784982934, "floor_pct": 0.2, "ceil_pct": 1.0}, + }, + } TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 def test_stablelm(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_starcode2.py b/tests/models/test_starcode2.py index 7dddcc37c..ac2bc9c04 100644 --- a/tests/models/test_starcode2.py +++ b/tests/models/test_starcode2.py @@ -11,8 +11,11 @@ class TestStarCode2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/starcoder2-3b" NATIVE_ARC_CHALLENGE_ACC = 0.2329 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2824 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.2858361774744027 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.30802047781569963 TORCH_DTYPE = torch.float16 def test_starcode2(self): - self.quant_lm_eval() - + self.quantize_and_evaluate() diff --git a/tests/models/test_telechat2.py b/tests/models/test_telechat2.py index 0f38add01..46082e8ac 100644 --- a/tests/models/test_telechat2.py +++ b/tests/models/test_telechat2.py @@ -4,25 +4,28 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestTeleChat_2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/TeleChat2-7B/" # "Tele-AI/TeleChat2-7B" NATIVE_ARC_CHALLENGE_ACC = 0.3677 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3831 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.37627986348122866 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.38822525597269625 TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False USE_FLASH_ATTN = False - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_telechat2(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_tinyllama.py b/tests/models/test_tinyllama.py index 8c122617c..e448f5f65 100644 --- a/tests/models/test_tinyllama.py +++ b/tests/models/test_tinyllama.py @@ -10,7 +10,11 @@ class TestTinyllama(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" NATIVE_ARC_CHALLENGE_ACC = 0.2995 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3268 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.30802047781569963 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.3250853242320819 TRUST_REMOTE_CODE = True def test_tinyllama(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_voxtral.py b/tests/models/test_voxtral.py index f7fec6286..8b92a2765 100644 --- a/tests/models/test_voxtral.py +++ b/tests/models/test_voxtral.py @@ -5,22 +5,25 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestVoxtral(ModelTest): - NATIVE_MODEL_ID = "/monster/data/shared/model/Voxtral-Mini-3B-2507/" + NATIVE_MODEL_ID = "/monster/data/model/Voxtral-Mini-3B-2507" NATIVE_ARC_CHALLENGE_ACC = 0.3392 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3360 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = NATIVE_ARC_CHALLENGE_ACC_SLOW + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": False, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) def test_voxtral(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/models/test_xverse.py b/tests/models/test_xverse.py index 0f1cff63d..8c76df046 100644 --- a/tests/models/test_xverse.py +++ b/tests/models/test_xverse.py @@ -5,22 +5,28 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestXVerse(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/XVERSE-7B-Chat" # "xverse/XVERSE-7B-Chat" - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.4198, "floor_pct": 0.2}, "acc_norm": {"value": 0.4044, "floor_pct": 0.2}, }, } + EVAL_TASKS_FAST = ModelTest.derive_fast_eval_tasks(EVAL_TASKS_SLOW) TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False USE_FLASH_ATTN = False def test_xverse(self): - self.quant_lm_eval() + try: + self.load_tokenizer(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE) + except Exception as exc: + if "add_prefix_space does not match declared prepend_scheme" in str(exc): + self.skipTest(f"Tokenizer assets are incompatible with the installed tokenizers runtime: {exc}") + raise + + self.quantize_and_evaluate() diff --git a/tests/models/test_yi.py b/tests/models/test_yi.py index 4be697f35..f7029d240 100644 --- a/tests/models/test_yi.py +++ b/tests/models/test_yi.py @@ -5,22 +5,32 @@ from model_test import ModelTest -from gptqmodel.utils.eval import EVAL - class TestYi(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Yi-Coder-1.5B-Chat" # "01-ai/Yi-Coder-1.5B-Chat" NATIVE_ARC_CHALLENGE_ACC = 0.2679 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2986 + NATIVE_ARC_CHALLENGE_ACC_SLOW = NATIVE_ARC_CHALLENGE_ACC + NATIVE_ARC_CHALLENGE_ACC_NORM_SLOW = NATIVE_ARC_CHALLENGE_ACC_NORM + NATIVE_ARC_CHALLENGE_ACC_FAST = 0.24232081911262798 + NATIVE_ARC_CHALLENGE_ACC_NORM_FAST = 0.2781569965870307 TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 4 - EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + APPLY_CHAT_TEMPLATE = True + EVAL_TASKS_SLOW = { + "arc_challenge": { "chat_template": True, "acc": {"value": NATIVE_ARC_CHALLENGE_ACC}, "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM}, }, } + EVAL_TASKS_FAST = { + "arc_challenge": { + "chat_template": True, + "acc": {"value": NATIVE_ARC_CHALLENGE_ACC_FAST, "ceil_pct": 1.0}, + "acc_norm": {"value": NATIVE_ARC_CHALLENGE_ACC_NORM_FAST, "ceil_pct": 1.0}, + }, + } def test_yi(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/module_tree/test_model_alignment.py b/tests/module_tree/test_model_alignment.py index dd17882df..5c066948b 100644 --- a/tests/module_tree/test_model_alignment.py +++ b/tests/module_tree/test_model_alignment.py @@ -7,10 +7,20 @@ from pathlib import Path from accelerate import init_empty_weights -from transformers import AutoConfig, AutoModelForCausalLM +from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForImageTextToText, AutoModelForTextToWaveform from gptqmodel.models.definitions.dots1 import Dots1QModel +from gptqmodel.models.definitions.qwen2 import Qwen2QModel +from gptqmodel.models.definitions.qwen2_5_omni import Qwen2_5_OmniGPTQ +from gptqmodel.models.definitions.qwen2_5_vl import Qwen2_5_VLQModel +from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel +from gptqmodel.models.definitions.qwen2_vl import Qwen2VLQModel from gptqmodel.models.definitions.qwen3 import Qwen3QModel +from gptqmodel.models.definitions.qwen3_5 import Qwen3_5QModel +from gptqmodel.models.definitions.qwen3_5_moe import Qwen3_5_MoeQModel +from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel +from gptqmodel.models.definitions.qwen3_next import Qwen3NextGPTQ +from gptqmodel.models.definitions.qwen3_omni_moe import Qwen3OmniMoeGPTQ sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "models")) @@ -83,3 +93,215 @@ def test_marin_module_tree(self): self.assertTrue(hasattr(decoder_layer.self_attn, "o_proj")) self.assertIn("q_norm:!", Qwen3QModel.module_tree[3]["self_attn"]) self.assertIn("k_norm:!", Qwen3QModel.module_tree[3]["self_attn"]) + + +class TestQwen2ModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" + + def test_qwen2_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + with init_empty_weights(include_buffers=True): + shell = AutoModelForCausalLM.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.layers[0] + self.assertTrue(hasattr(decoder_layer.self_attn, "q_proj")) + self.assertTrue(hasattr(decoder_layer.self_attn, "k_proj")) + self.assertTrue(hasattr(decoder_layer.self_attn, "v_proj")) + self.assertTrue(hasattr(decoder_layer.self_attn, "o_proj")) + self.assertFalse(hasattr(decoder_layer.self_attn, "q_norm")) + self.assertFalse(hasattr(decoder_layer.self_attn, "k_norm")) + self.assertIn("q_proj:0", Qwen2QModel.module_tree[3]["self_attn"]) + self.assertIn("o_proj:1", Qwen2QModel.module_tree[3]["self_attn"]) + + +class TestQwen2MoeModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen1.5-MoE-A2.7B" + + def test_qwen2_moe_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + with init_empty_weights(include_buffers=True): + shell = AutoModelForCausalLM.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.layers[0] + self.assertTrue(hasattr(decoder_layer.self_attn, "q_proj")) + self.assertFalse(hasattr(decoder_layer.self_attn, "q_norm")) + self.assertTrue(hasattr(decoder_layer.mlp, "gate")) + self.assertTrue(hasattr(decoder_layer.mlp, "shared_expert")) + self.assertTrue(hasattr(decoder_layer.mlp, "shared_expert_gate")) + self.assertTrue(hasattr(decoder_layer.mlp, "experts")) + self.assertIn("shared_expert_gate", Qwen2MoeQModel.module_tree[-1]["mlp:moe:?"]) + self.assertIn("shared_expert:0", Qwen2MoeQModel.module_tree[-1]["mlp:moe:?"]) + self.assertIn("experts:0", Qwen2MoeQModel.module_tree[-1]["mlp:moe:?"]) + + +class TestQwen2VLModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen2-VL-2B-Instruct" + + def test_qwen2_vl_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + with init_empty_weights(include_buffers=True): + shell = AutoModelForImageTextToText.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.language_model.layers[0] + self.assertTrue(hasattr(shell.model, "language_model")) + self.assertTrue(hasattr(shell.model, "visual")) + self.assertTrue(hasattr(decoder_layer.self_attn, "q_proj")) + self.assertIn("q_proj:0", Qwen2VLQModel.module_tree[-1]["self_attn"]) + self.assertEqual(Qwen2VLQModel.module_tree[:4], ["model", "language_model", "layers", "#"]) + + +class TestQwen2_5_VLModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-VL-3B-Instruct" + + def test_qwen2_5_vl_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + with init_empty_weights(include_buffers=True): + shell = AutoModelForImageTextToText.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.language_model.layers[0] + self.assertTrue(hasattr(shell.model, "language_model")) + self.assertTrue(hasattr(shell.model, "visual")) + self.assertTrue(hasattr(decoder_layer.self_attn, "q_proj")) + self.assertIn("q_proj:0", Qwen2_5_VLQModel.module_tree[-1]["self_attn"]) + self.assertEqual(Qwen2_5_VLQModel.module_tree[:4], ["model", "language_model", "layers", "#"]) + + +class TestQwen2_5_OmniModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-Omni-3B" + + def test_qwen2_5_omni_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + with init_empty_weights(include_buffers=True): + shell = AutoModelForTextToWaveform.from_config(config, trust_remote_code=False) + + decoder_layer = shell.thinker.model.layers[0] + self.assertTrue(hasattr(shell, "thinker")) + self.assertTrue(hasattr(shell.thinker, "visual")) + self.assertTrue(hasattr(shell.thinker, "audio_tower")) + self.assertTrue(hasattr(decoder_layer.self_attn, "q_proj")) + self.assertIn("q_proj:0", Qwen2_5_OmniGPTQ.module_tree[-1]["self_attn"]) + self.assertEqual(Qwen2_5_OmniGPTQ.module_tree[:4], ["thinker", "model", "layers", "#"]) + + +class TestQwen3MoeModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B" + + def test_qwen3_moe_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + with init_empty_weights(include_buffers=True): + shell = AutoModelForCausalLM.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.layers[0] + self.assertTrue(hasattr(decoder_layer.self_attn, "q_norm")) + self.assertTrue(hasattr(decoder_layer.self_attn, "k_norm")) + self.assertIn("q_norm:!", Qwen3MoeQModel.module_tree[-1]["self_attn"]) + self.assertIn("k_norm:!", Qwen3MoeQModel.module_tree[-1]["self_attn"]) + + +class TestQwen3_5ModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3.5-27B" + + def test_qwen3_5_linear_attention_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + config = config.text_config + config.num_hidden_layers = 1 + config.layer_types = ["linear_attention"] + with init_empty_weights(include_buffers=True): + shell = AutoModelForCausalLM.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.layers[0] + self.assertTrue(hasattr(decoder_layer.linear_attn, "conv1d")) + self.assertTrue(hasattr(decoder_layer.linear_attn, "in_proj_b")) + self.assertTrue(hasattr(decoder_layer.linear_attn, "in_proj_a")) + self.assertIn("conv1d:!", Qwen3_5QModel.module_tree[-1]["linear_attn"]) + self.assertIn("in_proj_b:!:1", Qwen3_5QModel.module_tree[-1]["linear_attn"]) + self.assertIn("in_proj_a:!:1", Qwen3_5QModel.module_tree[-1]["linear_attn"]) + + +class TestQwen3_5MoeModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3.5-35B-A3B" + + def test_qwen3_5_moe_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + config.text_config.num_hidden_layers = 2 + config.text_config.layer_types = ["linear_attention", "full_attention"] + with init_empty_weights(include_buffers=True): + shell = AutoModelForImageTextToText.from_config(config, trust_remote_code=False) + + linear_layer = shell.model.language_model.layers[0] + full_layer = shell.model.language_model.layers[1] + + self.assertTrue(hasattr(linear_layer.linear_attn, "conv1d")) + self.assertTrue(hasattr(linear_layer.linear_attn, "in_proj_b")) + self.assertTrue(hasattr(linear_layer.linear_attn, "in_proj_a")) + self.assertTrue(hasattr(linear_layer.mlp, "shared_expert")) + self.assertTrue(hasattr(linear_layer.mlp, "shared_expert_gate")) + self.assertTrue(hasattr(full_layer.self_attn, "q_norm")) + self.assertTrue(hasattr(full_layer.self_attn, "k_norm")) + self.assertIn("q_norm:!", Qwen3_5_MoeQModel.module_tree[-1]["self_attn"]) + self.assertIn("k_norm:!", Qwen3_5_MoeQModel.module_tree[-1]["self_attn"]) + self.assertIn("conv1d:!", Qwen3_5_MoeQModel.module_tree[-1]["linear_attn"]) + self.assertIn("in_proj_b:!:1", Qwen3_5_MoeQModel.module_tree[-1]["linear_attn"]) + self.assertIn("in_proj_a:!:1", Qwen3_5_MoeQModel.module_tree[-1]["linear_attn"]) + self.assertIn("shared_expert_gate", Qwen3_5_MoeQModel.module_tree[-1]["mlp:moe:?"]) + self.assertIn("shared_expert:0", Qwen3_5_MoeQModel.module_tree[-1]["mlp:moe:?"]) + + +class TestQwen3OmniModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Omni-30B-A3B-Instruct" + + def test_qwen3_omni_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + with init_empty_weights(include_buffers=True): + shell = AutoModelForTextToWaveform.from_config(config, trust_remote_code=False) + + decoder_layer = shell.thinker.model.layers[0] + self.assertTrue(hasattr(decoder_layer.self_attn, "q_norm")) + self.assertTrue(hasattr(decoder_layer.self_attn, "k_norm")) + self.assertEqual(Qwen3OmniMoeGPTQ.module_tree[-1]["mlp:moe"]["gate"], ("gate:!",)) + self.assertIn("q_norm:!", Qwen3OmniMoeGPTQ.module_tree[-1]["self_attn"]) + self.assertIn("k_norm:!", Qwen3OmniMoeGPTQ.module_tree[-1]["self_attn"]) + + +class TestQwen3NextModuleTree(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Next-80B-A3B-Instruct" + + def test_qwen3_next_full_attention_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + config.num_hidden_layers = 1 + config.layer_types = ["full_attention"] + with init_empty_weights(include_buffers=True): + shell = AutoModelForCausalLM.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.layers[0] + self.assertTrue(hasattr(decoder_layer.self_attn, "q_norm")) + self.assertTrue(hasattr(decoder_layer.self_attn, "k_norm")) + self.assertIn("q_norm:!", Qwen3NextGPTQ.module_tree[-1]["self_attn"]) + self.assertIn("k_norm:!", Qwen3NextGPTQ.module_tree[-1]["self_attn"]) + + def test_qwen3_next_linear_attention_module_tree(self): + config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=False) + config.num_hidden_layers = 1 + config.layer_types = ["linear_attention"] + with init_empty_weights(include_buffers=True): + shell = AutoModelForCausalLM.from_config(config, trust_remote_code=False) + + decoder_layer = shell.model.layers[0] + self.assertTrue(hasattr(decoder_layer.linear_attn, "norm")) + self.assertTrue(hasattr(decoder_layer.linear_attn, "conv1d")) + self.assertTrue(hasattr(decoder_layer.linear_attn, "in_proj_ba")) + self.assertTrue(hasattr(decoder_layer.mlp, "shared_expert_gate")) + self.assertIn("norm:!", Qwen3NextGPTQ.module_tree[-1]["linear_attn"]) + self.assertIn("conv1d:!", Qwen3NextGPTQ.module_tree[-1]["linear_attn"]) + self.assertIn("in_proj_qkvz:0", Qwen3NextGPTQ.module_tree[-1]["linear_attn"]) + self.assertIn("in_proj_ba:!:0", Qwen3NextGPTQ.module_tree[-1]["linear_attn"]) + + blocks = Qwen3NextGPTQ.build_layer_modules(Qwen3NextGPTQ.module_tree) + linear_in_block = next(block for block in blocks if "linear_attn.in_proj_qkvz" in block) + linear_out_block = next(block for block in blocks if "linear_attn.out_proj" in block) + shared_expert_block = next(block for block in blocks if "mlp.shared_expert.gate_proj" in block) + + self.assertNotIn("linear_attn.out_proj", linear_in_block) + self.assertNotIn("linear_attn.in_proj_qkvz", linear_out_block) + self.assertIn("mlp.shared_expert.up_proj", shared_expert_block) + self.assertNotIn("mlp.shared_expert.down_proj", shared_expert_block) diff --git a/tests/module_tree/test_subset.py b/tests/module_tree/test_subset.py index b9ac35da1..70b65343a 100644 --- a/tests/module_tree/test_subset.py +++ b/tests/module_tree/test_subset.py @@ -22,11 +22,12 @@ sys.path.insert(0, repo_str) from gptqmodel.looper.awq_processor import AWQProcessor, _AWQLayerState -from gptqmodel.looper.loop_processor import LoopProcessor +from gptqmodel.looper.loop_processor import ExecutionConfig, LoopProcessor from gptqmodel.looper.module_looper import ModuleLooper from gptqmodel.looper.named_module import NamedModule -from gptqmodel.looper.stage_subset import run_subset_stage +from gptqmodel.looper.stage_subset import build_subset_plan, run_subset_stage from gptqmodel.models.definitions.qwen2_moe import Qwen2MoeQModel +from gptqmodel.models.definitions.qwen3_5_moe import Qwen3_5_MoeQModel from gptqmodel.models.definitions.qwen3_moe import Qwen3MoeQModel from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy from gptqmodel.quantization import FORMAT, METHOD @@ -49,7 +50,7 @@ def _make_quant_config(device: torch.device | str = "cpu") -> QuantizeConfig: quant_method=METHOD.AWQ, format=FORMAT.GEMM, device=device, - vram_strategy=VramStrategy.EXCLUSIVE, + dense_vram_strategy=VramStrategy.EXCLUSIVE, ) @@ -111,6 +112,20 @@ def test_qwen2_moe_shared_expert_merges_with_experts(): assert len(expert_gate_blocks) == 1 +def test_qwen3_5_moe_shared_expert_merges_with_experts(): + blocks = Qwen3_5_MoeQModel.build_layer_modules(Qwen3_5_MoeQModel.module_tree) + print("blocks",blocks) + gate_block = next(block for block in blocks if "mlp.shared_expert.gate_proj" in block) + assert "mlp.experts.{expert_index}.gate_proj" in gate_block + assert "mlp.experts.{expert_index}.up_proj" in gate_block + + down_block = next(block for block in blocks if "mlp.shared_expert.down_proj" in block) + assert "mlp.experts.{expert_index}.down_proj" in down_block + + expert_gate_blocks = [block for block in blocks if "mlp.experts.{expert_index}.gate_proj" in block] + assert len(expert_gate_blocks) == 1 + + def test_awq_processor_enables_subset_early_stop(): calibration = [{"input_ids": torch.tensor([1, 2, 3])}] qcfg = _make_quant_config() @@ -130,7 +145,7 @@ def test_awq_processor_enables_subset_early_stop(): model=dummy_model, ) - assert processor.subset_forward_early_stop is True + assert processor.execution_config.subset_forward_early_stop is True def test_module_looper_subset_callback_invoked(): @@ -140,7 +155,7 @@ def test_module_looper_subset_callback_invoked(): quantize_config=quant_cfg, layer_callback=None, subset_callback=None, - supported_vram_strategies=[VramStrategy.EXCLUSIVE], + supported_dense_vram_strategies=[VramStrategy.EXCLUSIVE], ) looper = ModuleLooper(model=dummy_model, processors=[]) @@ -206,9 +221,11 @@ def __init__(self, qcfg: QuantizeConfig): calibration=calibration, prepare_dataset_func=_prepare_dataset_func, batch_size=1, - require_fwd=True, - fwd_after_process=False, - subset_forward_early_stop=True, + execution_config=ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=False, + subset_forward_early_stop=True, + ), ) self.hook_calls: List[str] = [] self.process_calls: List[str] = [] @@ -220,7 +237,7 @@ def __init__(self, qcfg: QuantizeConfig): def name(cls) -> str: return "stub-awq" - def preprocess(self, module: NamedModule, failsafe=None, **_kwargs): + def preprocess(self, module: NamedModule, fallback=None, **_kwargs): self.tasks[module.name] = {"inputs": []} def pre_process_fwd_hook(self, name: str) -> Callable[[torch.nn.Module, tuple, torch.Tensor], None]: @@ -294,9 +311,11 @@ def test_stage_subset_early_stop_and_callbacks(): quantize_config=quant_cfg, layer_callback=None, subset_callback=None, - supported_vram_strategies=[VramStrategy.EXCLUSIVE, VramStrategy.BALANCED], + supported_dense_vram_strategies=[VramStrategy.EXCLUSIVE, VramStrategy.BALANCED], layer_modules_strict=True, lm_head="lm_head", + shell_module_materialize=lambda target_submodule, device, role, named_module=None: target_submodule, + prepare_layer_replay_kwargs=lambda layer, layer_input, additional_inputs, target_device: additional_inputs, ) processor = _StubAWQProcessor(quant_cfg) @@ -323,12 +342,24 @@ def test_stage_subset_early_stop_and_callbacks(): layers_prefix="layers", names=subset_names, processor=processor, - failsafe=False, + fallback=False, layer_module=mini_layer, ) + subset_plan = build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=0, + subset_total=2, + full=full_modules, + fallback=False, + layer_inputs=layer_inputs, + ) + run_subset_stage( looper=looper, + plan=subset_plan, processor=processor, module=mini_layer, layer_inputs=layer_inputs, @@ -340,12 +371,8 @@ def test_stage_subset_early_stop_and_callbacks(): layer_descriptor="layers.0", layer_title="subset-check", layer_index=0, - layers_prefix="layers", - subset=subset, - subset_index=0, - subset_total=2, full=full_modules, - failsafe=False, + fallback=False, shared_kv_cache_dict=shared_kv_cache_dict, pb=_DummyProgress(), log=None, diff --git a/tests/protocol/test_protocol.py b/tests/protocol/test_protocol.py new file mode 100644 index 000000000..430a8cef2 --- /dev/null +++ b/tests/protocol/test_protocol.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from gptqmodel.quantization import FORMAT, METHOD, AWQConfig, GGUFConfig, GPTQConfig +from gptqmodel.quantization.protocol import ( + MatchSpec, + Rule, + Stage, + compile_plan_to_quantize_config, + compile_protocol, + compile_protocol_to_quantize_config, + compile_protocol_yaml_text, + compile_protocol_yaml_to_quantize_config, +) + + +def _python_protocol(): + return { + "version": 2, + "stages": [ + Stage( + name="weight_only", + rules=[ + Rule( + match="*", + weight={ + "quantize": { + "method": "gguf", + "bits": "q4_k_m", + }, + "export": { + "format": "gguf", + "variant": "q_k_m", + "impl": "gguf_torch", + }, + }, + ), + ], + ), + ], + } + + +def _yaml_protocol() -> str: + return """ +version: 2 +stages: + - name: weight_only + rules: + - match: "*" + weight: + quantize: + method: gguf + bits: q4_k_m + export: + format: gguf + variant: q_k_m + impl: gguf_torch +""" + + +def _python_protocol_with_negative_match(): + return { + "version": 2, + "stages": [ + { + "name": "weight_only", + "rules": [ + { + "match": ["*", r"-:.*layers\.2.*"], + "weight": { + "quantize": { + "method": "gguf", + "bits": "q4_k_m", + }, + "export": { + "format": "gguf", + "variant": "q_k_m", + "impl": "gguf_torch", + }, + }, + }, + ], + }, + ], + } + + +def _yaml_protocol_with_negative_match() -> str: + return r""" +version: 2 +stages: + - name: weight_only + rules: + - match: + - "*" + - '-:.*layers\.2.*' + weight: + quantize: + method: gguf + bits: q4_k_m + export: + format: gguf + variant: q_k_m + impl: gguf_torch +""" + + +def _python_gptq_protocol_with_negative_match(): + return { + "version": 2, + "stages": [ + { + "name": "ptq", + "rules": [ + { + "match": ["*", r"-:.*layers\.2.*"], + "weight": { + "quantize": { + "method": "gptq", + "bits": 4, + "group_size": 128, + "sym": True, + "desc_act": False, + }, + "export": { + "format": "gptq", + "variant": "gptq", + "impl": "marlin", + }, + }, + }, + ], + }, + ], + } + + +def _yaml_gptq_protocol_with_negative_match() -> str: + return r""" +version: 2 +stages: + - name: ptq + rules: + - match: + - "*" + - '-:.*layers\.2.*' + weight: + quantize: + method: gptq + bits: 4 + group_size: 128 + sym: true + desc_act: false + export: + format: gptq + variant: gptq + impl: marlin +""" + + +def _python_awq_protocol_with_negative_match(): + return { + "version": 2, + "stages": [ + { + "name": "ptq", + "rules": [ + { + "match": ["*", r"-:.*layers\.2.*"], + "weight": { + "quantize": { + "method": "awq", + "bits": 4, + "group_size": 128, + "sym": True, + "desc_act": False, + }, + "export": { + "format": "awq", + "variant": "gemm", + "impl": "gemm", + }, + }, + }, + ], + }, + ], + } + + +def _yaml_awq_protocol_with_negative_match() -> str: + return r""" +version: 2 +stages: + - name: ptq + rules: + - match: + - "*" + - '-:.*layers\.2.*' + weight: + quantize: + method: awq + bits: 4 + group_size: 128 + sym: true + desc_act: false + export: + format: awq + variant: gemm + impl: gemm +""" + + +def test_protocol_python_and_yaml_compile_to_same_execution_plan(): + python_plan = compile_protocol(_python_protocol()) + yaml_plan = compile_protocol_yaml_text(_yaml_protocol()) + + assert python_plan == yaml_plan + + +def test_protocol_python_and_yaml_compile_to_same_gguf_config(): + python_cfg = compile_protocol_to_quantize_config(_python_protocol()) + yaml_cfg = compile_protocol_yaml_to_quantize_config(_yaml_protocol()) + + assert isinstance(python_cfg, GGUFConfig) + assert isinstance(yaml_cfg, GGUFConfig) + assert python_cfg.to_dict() == yaml_cfg.to_dict() + assert python_cfg.quant_method == METHOD.GGUF + assert python_cfg.runtime_bits == "q4_k_m" + assert python_cfg.format == "q_k_m" + + +def test_protocol_plan_compiles_to_expected_gguf_config(): + plan = compile_protocol(_python_protocol()) + cfg = compile_plan_to_quantize_config(plan) + + assert isinstance(cfg, GGUFConfig) + assert cfg.quant_method == METHOD.GGUF + assert cfg.bits == 4 + assert cfg.runtime_bits == "q4_k_m" + assert cfg.format == "q_k_m" + + +def test_protocol_python_and_yaml_compile_to_same_negative_match_plan(): + python_plan = compile_protocol(_python_protocol_with_negative_match()) + yaml_plan = compile_protocol_yaml_text(_yaml_protocol_with_negative_match()) + + assert python_plan == yaml_plan + rule = python_plan.stages[0].rules[0] + assert rule.match == ( + MatchSpec(pattern="*", include=True), + MatchSpec(pattern=r".*layers\.2.*", include=False), + ) + + +def test_rule_match_supports_include_and_exclude_selectors(): + plan = compile_protocol(_python_protocol_with_negative_match()) + rule = plan.stages[0].rules[0] + + assert rule.matches("model.layers.0.self_attn.q_proj") + assert not rule.matches("model.layers.2.self_attn.q_proj") + + +def test_negative_match_gguf_protocol_compiles_to_dynamic_skip_config(): + python_cfg = compile_protocol_to_quantize_config(_python_protocol_with_negative_match()) + yaml_cfg = compile_protocol_yaml_to_quantize_config(_yaml_protocol_with_negative_match()) + + assert isinstance(python_cfg, GGUFConfig) + assert isinstance(yaml_cfg, GGUFConfig) + assert python_cfg.to_dict() == yaml_cfg.to_dict() + assert python_cfg.quant_method == METHOD.GGUF + assert python_cfg.runtime_bits == "q4_k_m" + assert python_cfg.format == "q_k_m" + assert python_cfg.dynamic == {r"-:.*layers\.2.*": {}} + + +def test_negative_match_gptq_protocol_compiles_to_dynamic_skip_config(): + python_cfg = compile_protocol_to_quantize_config(_python_gptq_protocol_with_negative_match()) + yaml_cfg = compile_protocol_yaml_to_quantize_config(_yaml_gptq_protocol_with_negative_match()) + + assert isinstance(python_cfg, GPTQConfig) + assert isinstance(yaml_cfg, GPTQConfig) + assert type(python_cfg) is type(yaml_cfg) + assert python_cfg.quant_method == METHOD.GPTQ + assert yaml_cfg.quant_method == METHOD.GPTQ + assert python_cfg.format == FORMAT.GPTQ + assert yaml_cfg.format == FORMAT.GPTQ + assert python_cfg.bits == 4 + assert yaml_cfg.bits == 4 + assert python_cfg.group_size == 128 + assert yaml_cfg.group_size == 128 + assert python_cfg.sym is True + assert yaml_cfg.sym is True + assert python_cfg.desc_act is False + assert yaml_cfg.desc_act is False + assert python_cfg.dynamic == {r"-:.*layers\.2.*": {}} + assert yaml_cfg.dynamic == {r"-:.*layers\.2.*": {}} + + +def test_negative_match_awq_protocol_compiles_to_dynamic_skip_config(): + python_cfg = compile_protocol_to_quantize_config(_python_awq_protocol_with_negative_match()) + yaml_cfg = compile_protocol_yaml_to_quantize_config(_yaml_awq_protocol_with_negative_match()) + + assert isinstance(python_cfg, AWQConfig) + assert isinstance(yaml_cfg, AWQConfig) + assert type(python_cfg) is type(yaml_cfg) + assert python_cfg.quant_method == METHOD.AWQ + assert yaml_cfg.quant_method == METHOD.AWQ + assert python_cfg.format == FORMAT.GEMM + assert yaml_cfg.format == FORMAT.GEMM + assert python_cfg.bits == 4 + assert yaml_cfg.bits == 4 + assert python_cfg.group_size == 128 + assert yaml_cfg.group_size == 128 + assert python_cfg.sym is True + assert yaml_cfg.sym is True + assert python_cfg.dynamic == {r"-:.*layers\.2.*": {}} + assert yaml_cfg.dynamic == {r"-:.*layers\.2.*": {}} diff --git a/tests/pytest.ini b/tests/pytest.ini index 82a0c02b1..6fb6dc51a 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -1,11 +1,16 @@ [pytest] addopts=-s -v log_cli=true -norecursedirs = tasks evalplus_results +norecursedirs = tasks markers = ci: CPU-only CI regression coverage for DeviceThreadPool affinity behaviour cuda: Requires CUDA device + cpu: Exercises CPU-specific code paths + gpu: Exercises GPU-specific code paths inference: Inference workloads that replicate models across devices + model: Loads real model weights or runs model-backed integrations + mps: Exercises Apple MPS-specific code paths + slow: Slow-running tests such as model loading, quantization, or benchmark-heavy paths timeout: Requires pytest-timeout plugin; retained for downstream compatibility filterwarnings = ignore:Warning only once for all operators.*:UserWarning diff --git a/tests/q4_reference.py b/tests/q4_reference.py new file mode 100644 index 000000000..ebecd598d --- /dev/null +++ b/tests/q4_reference.py @@ -0,0 +1,1041 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch + + +REFERENCE = torch.Tensor( + [ + 5.8398, + 6.8555, + 7.2734, + 6.4219, + 6.2070, + 5.8203, + 6.5664, + 6.4219, + 6.2148, + 5.3281, + 5.7578, + 7.5312, + 8.1016, + 6.1133, + 7.2031, + 6.6484, + 6.5156, + 6.0117, + 6.0312, + 6.1914, + 6.2109, + 6.8125, + 5.8125, + 7.1172, + 7.3125, + 6.7305, + 5.9961, + 6.5117, + 6.1914, + 5.9648, + 7.1680, + 6.4766, + 7.2070, + 6.5469, + 6.7734, + 6.4219, + 6.8086, + 7.0469, + 5.9297, + 6.4727, + 6.2539, + 5.9570, + 7.2383, + 5.8945, + 6.0820, + 5.7969, + 7.1094, + 6.2188, + 6.7500, + 7.3555, + 6.2930, + 6.7734, + 5.9219, + 7.4805, + 6.8750, + 6.4102, + 6.5898, + 6.5469, + 7.6016, + 6.7461, + 5.9492, + 7.2227, + 5.8164, + 5.4570, + 6.2930, + 7.3984, + 6.0938, + 7.3984, + 5.9609, + 6.3516, + 6.5664, + 5.7969, + 7.1250, + 6.0781, + 6.7930, + 5.9492, + 6.1641, + 6.5898, + 6.0586, + 6.3359, + 6.7930, + 7.0469, + 6.0664, + 6.3320, + 5.4414, + 6.7617, + 5.1641, + 7.2891, + 6.8516, + 6.5312, + 5.6914, + 7.3711, + 6.8203, + 5.9492, + 7.0781, + 6.3164, + 7.1992, + 7.1133, + 7.4219, + 7.5586, + 7.1836, + 6.9102, + 6.4844, + 6.9805, + 6.1953, + 6.5156, + 5.4844, + 6.6602, + 6.6719, + 7.9844, + 6.4727, + 6.6367, + 6.2227, + 6.4531, + 5.0625, + 6.4609, + 6.7031, + 6.6445, + 6.5234, + 6.8633, + 6.6055, + 5.6055, + 6.4453, + 7.2617, + 6.3945, + 6.6367, + 6.1055, + 7.0664, + 6.0820, + 6.6875, + 6.1445, + 6.8672, + 6.2070, + 6.8828, + 6.1484, + 6.7070, + 6.8516, + 6.2734, + 7.1055, + 7.0586, + 6.9648, + 5.9727, + 6.1016, + 6.8750, + 7.0078, + 7.1523, + 5.7383, + 5.9531, + 6.5508, + 7.5352, + 6.1602, + 6.2578, + 6.3906, + 5.7383, + 6.7031, + 5.7344, + 6.3516, + 5.2852, + 7.5312, + 6.4531, + 6.6406, + 6.2266, + 6.1094, + 5.9102, + 5.7617, + 6.3789, + 7.0508, + 6.3750, + 6.3320, + 6.8555, + 6.7266, + 7.0352, + 7.7695, + 6.3984, + 6.5039, + 6.8320, + 6.1602, + 6.0312, + 6.3828, + 6.9023, + 7.4336, + 7.3711, + 6.1016, + 7.0703, + 6.3281, + 6.8281, + 6.4922, + 5.9453, + 5.1016, + 6.7188, + 6.1406, + 6.6289, + 7.2695, + 6.2070, + 6.7070, + 7.2930, + 7.1836, + 6.3828, + 6.1992, + 6.7070, + 7.8008, + 7.7773, + 5.6602, + 7.0273, + 6.6172, + 6.0898, + 5.3516, + 7.3359, + 5.9727, + 6.0078, + 7.0586, + 6.3086, + 6.8555, + 7.2617, + 7.3477, + 6.3828, + 7.1133, + 6.6328, + 7.3516, + 6.9141, + 7.2031, + 6.9805, + 6.1719, + 6.7812, + 8.3047, + 6.5898, + 6.3633, + 6.2539, + 7.2773, + 6.5938, + 6.4141, + 6.8203, + 6.8906, + 7.8828, + 5.9609, + 6.4180, + 7.3984, + 5.7539, + 7.1758, + 6.6641, + 6.9062, + 6.2578, + 7.5508, + 6.1719, + 6.5742, + 5.9375, + 6.7891, + 6.2109, + 6.5039, + 6.8750, + 6.2031, + 6.8828, + 7.1094, + 5.9570, + 7.2969, + 6.6797, + 6.8828, + 5.5430, + 6.9648, + 5.8398, + 6.5430, + 6.3945, + 6.5664, + 5.8086, + 6.6172, + 7.0586, + 6.8867, + 6.0820, + 5.8125, + 6.7070, + 7.5742, + 6.2578, + 6.1328, + 6.5391, + 5.4531, + 6.8242, + 6.6953, + 6.8008, + 6.3398, + 6.4805, + 7.2266, + 6.3281, + 6.6875, + 6.4688, + 5.9414, + 7.4297, + 5.8711, + 6.0625, + 5.8750, + 6.5664, + 5.8867, + 6.3477, + 6.1133, + 6.9453, + 5.0547, + 6.7812, + 6.4922, + 7.2422, + 5.4688, + 6.2109, + 7.2148, + 6.1758, + 5.9297, + 7.1953, + 5.5195, + 6.3203, + 5.9961, + 7.9297, + 6.2695, + 6.4414, + 6.7266, + 7.1875, + 7.3203, + 5.4062, + 6.0625, + 7.0898, + 5.3828, + 5.6133, + 6.0742, + 6.6836, + 5.7109, + 7.2852, + 7.7539, + 7.5820, + 6.4258, + 5.9336, + 6.3750, + 6.3555, + 7.5469, + 6.2539, + 6.5898, + 6.4102, + 7.0469, + 5.7344, + 7.2031, + 6.7969, + 5.6836, + 7.6523, + 6.9297, + 7.8672, + 6.4766, + 6.3008, + 7.0977, + 6.5430, + 7.0938, + 5.8398, + 6.9883, + 6.5312, + 6.3203, + 6.3594, + 5.4062, + 6.9688, + 5.7930, + 6.3164, + 6.5547, + 7.1992, + 5.8750, + 6.3008, + 6.7930, + 6.0391, + 7.4766, + 6.6094, + 6.5625, + 5.9805, + 6.2422, + 7.2109, + 6.6875, + 5.3047, + 7.6211, + 5.9453, + 6.5625, + 6.1641, + 6.1250, + 6.5977, + 7.7422, + 7.0742, + 5.6875, + 6.2656, + 6.6250, + 6.8945, + 5.7070, + 6.3203, + 5.7500, + 6.2695, + 6.2773, + 6.8516, + 6.4883, + 7.0000, + 6.7578, + 6.1875, + 5.9844, + 5.5703, + 6.7188, + 5.5273, + 5.3438, + 7.2500, + 6.7852, + 6.5195, + 6.8125, + 6.0664, + 6.7852, + 7.0000, + 7.0781, + 6.8477, + 7.2930, + 6.3438, + 7.1523, + 6.3281, + 6.8047, + 7.3203, + 5.3359, + 6.1484, + 6.5586, + 7.3828, + 6.2344, + 7.1523, + 6.4102, + 5.5898, + 7.0195, + 7.1172, + 5.8008, + 6.5742, + 6.2891, + 8.0312, + 6.9023, + 6.5898, + 7.1953, + 6.7266, + 6.0078, + 5.5430, + 6.4766, + 6.4258, + 5.9648, + 8.0859, + 5.0547, + 7.2188, + 7.4375, + 6.5156, + 5.9922, + 6.3281, + 6.2852, + 6.7734, + 6.2461, + 6.9805, + 5.4648, + 5.8867, + 6.8242, + 6.3008, + 6.3281, + 7.3047, + 7.1836, + 6.5195, + 6.6328, + 6.7188, + 5.4336, + 6.5078, + 5.3477, + 5.5508, + 7.3125, + 5.8750, + 6.5195, + 6.2383, + 6.3594, + 6.0898, + 6.4141, + 5.9844, + 6.6250, + 7.7109, + 6.0391, + 7.2344, + 5.9453, + 5.9453, + 7.0586, + 5.6641, + 7.2773, + 6.5195, + 7.2227, + 6.3359, + 5.3203, + 6.4375, + 7.2383, + 6.4023, + 6.2148, + 7.3750, + 5.8164, + 6.2109, + 6.5430, + 5.8164, + 6.1680, + 6.7656, + 6.0820, + 6.1094, + 6.5312, + 6.8906, + 6.8320, + 6.1289, + 6.3125, + 7.6797, + 6.3008, + 6.0000, + 7.3320, + 6.7852, + 6.9297, + 6.6328, + 6.2266, + 5.1602, + 6.2031, + 7.0547, + 5.9492, + 6.0703, + 6.0977, + 6.8086, + 6.0742, + 6.0195, + 7.0625, + 6.5781, + 5.7461, + 6.1562, + 7.0430, + 6.7148, + 6.5312, + 6.5820, + 6.4570, + 7.5508, + 5.6289, + 6.0547, + 6.5000, + 7.3125, + 5.8477, + 5.9297, + 6.2578, + 6.0078, + 5.9922, + 7.3398, + 7.4922, + 7.8906, + 7.5547, + 5.4648, + 6.5156, + 6.3242, + 6.1094, + 6.9219, + 6.7227, + 6.6836, + 7.4023, + 5.9648, + 7.2383, + 6.7695, + 6.6797, + 7.0547, + 6.3047, + 6.4688, + 6.9961, + 6.0391, + 5.9727, + 6.8398, + 6.7422, + 5.7656, + 5.4766, + 6.7852, + 7.0820, + 5.3516, + 7.6523, + 5.1562, + 6.6445, + 6.1211, + 6.2695, + 6.0703, + 6.3594, + 6.4062, + 6.3398, + 5.7578, + 6.5391, + 6.2500, + 6.5742, + 6.5000, + 7.5625, + 7.0117, + 6.5547, + 7.1250, + 6.4453, + 6.6094, + 6.1875, + 6.4219, + 6.6172, + 6.4336, + 6.5703, + 6.1758, + 6.4219, + 6.6016, + 6.7383, + 6.7070, + 6.1328, + 5.5586, + 6.6367, + 6.3789, + 6.2578, + 5.5039, + 6.6172, + 6.4648, + 5.8086, + 7.2031, + 5.8125, + 6.3711, + 7.6758, + 7.1289, + 5.8086, + 6.3008, + 6.2109, + 6.1602, + 6.1797, + 7.2305, + 6.7266, + 6.2422, + 5.6719, + 6.7070, + 6.9414, + 6.8594, + 7.4023, + 7.2109, + 6.0156, + 6.6680, + 6.6172, + 7.1250, + 6.6523, + 6.9531, + 6.7617, + 6.4961, + 6.9414, + 5.7188, + 7.6367, + 6.5469, + 6.2305, + 6.4414, + 7.4648, + 5.9102, + 6.2461, + 6.1367, + 6.8203, + 6.5703, + 6.8867, + 7.0000, + 6.7539, + 6.1719, + 6.5469, + 6.2422, + 5.4297, + 5.7305, + 5.1641, + 6.1875, + 7.0312, + 6.6484, + 6.0234, + 7.4102, + 6.8711, + 6.3086, + 6.3711, + 6.7344, + 6.6992, + 5.9766, + 7.3906, + 7.1875, + 6.4883, + 6.3984, + 7.3438, + 6.9688, + 6.9062, + 6.4375, + 6.7891, + 7.0117, + 6.4883, + 5.7500, + 7.0898, + 7.0742, + 6.7070, + 5.8750, + 6.0469, + 6.6445, + 5.2773, + 6.8984, + 6.1641, + 7.0508, + 7.4609, + 5.0273, + 6.7734, + 6.4531, + 5.7656, + 6.5312, + 7.4648, + 6.1250, + 6.5625, + 7.1367, + 6.0625, + 6.1211, + 6.9766, + 6.6758, + 6.3164, + 6.8828, + 6.8203, + 6.7500, + 6.5352, + 7.3008, + 6.7852, + 6.1914, + 5.0508, + 6.7188, + 7.1172, + 6.8008, + 6.8086, + 5.4883, + 6.9180, + 6.5742, + 6.1719, + 7.0469, + 7.1523, + 5.9492, + 5.8594, + 6.8320, + 6.1719, + 6.2031, + 6.8398, + 7.3008, + 6.6289, + 6.4922, + 6.0000, + 5.4766, + 6.3320, + 6.5117, + 6.2812, + 7.5742, + 6.3516, + 7.0039, + 6.4570, + 7.1523, + 7.6289, + 6.2578, + 7.1875, + 6.4844, + 5.7930, + 6.7070, + 7.5508, + 7.1797, + 6.0430, + 6.8711, + 6.5742, + 7.5781, + 6.4766, + 6.5391, + 6.9453, + 6.1992, + 6.6367, + 6.2812, + 6.0234, + 6.6953, + 7.0312, + 6.2031, + 6.5625, + 6.6719, + 6.1719, + 6.5586, + 5.7031, + 7.4609, + 6.6211, + 7.7227, + 6.9141, + 6.0469, + 6.2500, + 5.3828, + 6.0078, + 5.8164, + 5.8867, + 6.1523, + 6.6523, + 6.6953, + 7.3125, + 6.4844, + 5.9570, + 5.9531, + 6.2109, + 5.5039, + 6.5117, + 6.8203, + 6.6133, + 6.4766, + 5.9297, + 7.1445, + 7.1914, + 6.0117, + 6.8281, + 6.7422, + 6.1328, + 6.9805, + 6.5625, + 6.9180, + 7.1133, + 7.3359, + 5.7617, + 5.8711, + 6.4961, + 6.5859, + 6.2422, + 6.5273, + 6.7461, + 6.6992, + 6.7695, + 6.6289, + 5.9453, + 5.9805, + 7.1172, + 6.6719, + 6.0039, + 7.6875, + 6.7812, + 7.8359, + 6.9531, + 7.4336, + 7.6602, + 6.8164, + 7.3945, + 7.1602, + 6.8789, + 5.0078, + 6.0547, + 6.8086, + 6.7070, + 6.4688, + 6.4492, + 6.6172, + 5.5625, + 6.6914, + 6.4297, + 5.7461, + 5.3359, + 6.8750, + 6.4609, + 7.4062, + 5.2070, + 6.0820, + 6.7383, + 6.5703, + 6.1797, + 6.7070, + 6.5977, + 5.9961, + 6.6328, + 6.9375, + 6.3906, + 6.6484, + 4.9609, + 6.6445, + 6.5898, + 7.1875, + 7.5195, + 6.7969, + 6.1367, + 6.8906, + 7.4297, + 6.3633, + 6.0508, + 6.5000, + 6.4648, + 6.7539, + 6.7109, + 5.8086, + 6.6016, + 7.1133, + 4.8672, + 6.6367, + 6.1641, + 5.1758, + 6.9453, + 6.3242, + 7.0664, + 6.4805, + 6.3516, + 6.7383, + 8.4688, + 6.7305, + 5.9844, + 6.5938, + 7.2969, + 6.5977, + 7.5898, + 6.2969, + 6.8672, + 6.6680, + 7.1289, + 6.6875, + 5.4258, + 8.1875, + 8.0391, + 7.7969, + 6.6445, + 7.0703, + 7.3359, + 6.9805, + 6.6328, + 6.5352, + 6.2422, + 5.5820, + 6.8633, + 6.8047, + 6.5703, + 6.0117, + 6.7539, + 7.1719, + 6.8438, + 7.3633, + 6.6016, + 7.2070, + 6.4727, + 5.8008, + 7.4062, + 7.4805, + 6.6445, + 5.9023, + 6.3984, + 6.9961, + 6.6680, + 6.8242, + 6.7148, + 6.6172, + 6.9727, + 6.8320, + 5.9766, + 6.6133, + 5.5977, + 6.7773, + 7.3906, + 6.9219, + 7.0781, + 6.6914, + 5.7539, + 6.7969, + 6.8008, + 5.8047, + 7.1055, + 6.4961, + 6.0352, + 5.6211, + 7.4414, + 7.0703, + 6.1172, + 6.7461, + 6.4492, + 7.7148, + 6.4258, + 6.0039, + 6.5156, + 7.2188, + 7.4531, + 7.4844, + 7.5938, + 7.4023, + 6.7617, + 6.0078, + 6.3320, + 5.8906, + 7.5977, + 5.6523, + 6.7734, + 6.3008, + 5.2227, + 7.1719, + 7.1289, + 6.6602, + 5.4609, + 7.0312, + 6.0820, + 6.1719, + 6.0000, + 6.5547, + 6.6328, + 7.0547, + 7.0859, + 6.2656, + 5.5234, + 6.0273, + 6.7891, + 7.1875, + 6.9531, + 6.8203, + 6.3516, + 6.1172, + 6.4648, + 6.9180, + 7.3906, + 6.2812, + 5.7109, + 6.1484, + 6.9102, + 6.8711, + 7.0156, + 6.1445, + 5.8867, + 6.3828, + 5.9961, + 6.6914, + 6.7891, + 7.0820, + 6.6719, + 6.9297, + 6.3750, + 6.7578, + 6.4883, + 6.2227, + 6.2305, + 6.0508, + 6.6484, + 5.7578, + 7.2070, + 7.2383, + 6.9375, + 7.2578, + 6.5312, + 6.0312, + 6.7930, + 6.2578, + 7.0625, + 7.2148, + 6.4961, + 7.0703, + 6.4727, + 7.3906, + ] +).to(torch.float16) + + +def get_diff(a, ref): + eps = 1e-6 + return f"Maxdiff: {(a - ref).abs().max()}, Mean relative diff: {((a - ref).abs() / (ref.abs() + eps)).mean()}" diff --git a/tests/qcfg/test_config_dispatch.py b/tests/qcfg/test_config_dispatch.py new file mode 100644 index 000000000..e454e163c --- /dev/null +++ b/tests/qcfg/test_config_dispatch.py @@ -0,0 +1,270 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from gptqmodel.quantization.config import ( + FORMAT, + METHOD, + AWQConfig, + BaseQuantizeConfig, + BitsAndBytesConfig, + FP8Config, + GGUFBits, + GGUFConfig, + GPTQConfig, + QuantizeConfig, + RTNConfig, + SmoothMAD, + WeightOnlyConfig, +) +from gptqmodel.quantization.dtype import available_float8_dtype_names + + +def _fp8_alias_cases(): + cases = [ + ("e4m3", "float8_e4m3fn"), + ("e5m2", "float8_e5m2"), + ("e4m3fnuz", "float8_e4m3fnuz"), + ("e5m2fnuz", "float8_e5m2fnuz"), + ("e8m0", "float8_e8m0fnu"), + ("float8_e8m0", "float8_e8m0fnu"), + ] + available = set(available_float8_dtype_names()) + return [(alias, expected) for alias, expected in cases if expected in available] + + +def test_quantize_config_dispatches_gptq_by_default(): + cfg = QuantizeConfig() + + assert isinstance(cfg, GPTQConfig) + assert cfg.quant_method == METHOD.GPTQ + assert cfg.format == FORMAT.GPTQ + + +def test_quantize_config_dispatches_awq_constructor(): + cfg = QuantizeConfig(quant_method=METHOD.AWQ, format=FORMAT.GEMM, sym=False) + + assert isinstance(cfg, AWQConfig) + assert isinstance(cfg, QuantizeConfig) + assert cfg.quant_method == METHOD.AWQ + assert cfg.format == FORMAT.GEMM + assert cfg.sym is False + + +def test_quantize_config_dispatches_awq_with_canonical_method_field(): + cfg = QuantizeConfig(method=METHOD.AWQ, format=FORMAT.GEMM, sym=False) + + assert isinstance(cfg, AWQConfig) + assert cfg.method == METHOD.AWQ + assert cfg.quant_method == METHOD.AWQ + assert cfg.format == FORMAT.GEMM + assert cfg.sym is False + + +def test_quantize_config_dispatches_awq_from_format_without_explicit_method(): + cfg = QuantizeConfig(format=FORMAT.GEMM, sym=False) + + assert isinstance(cfg, AWQConfig) + assert cfg.quant_method == METHOD.AWQ + assert cfg.format == FORMAT.GEMM + assert cfg.sym is False + + +def test_quantize_config_dispatches_awq_ignoring_legacy_gptq_only_kwargs(): + cfg = QuantizeConfig( + quant_method=METHOD.AWQ, + format=FORMAT.GEMM, + sym=False, + act_group_aware=True, + fallback=None, + damp_percent=0.05, + mse=0.0, + ) + + assert isinstance(cfg, AWQConfig) + assert cfg.quant_method == METHOD.AWQ + assert cfg.format == FORMAT.GEMM + assert cfg.sym is False + + +def test_quantize_config_rejects_is_marlin_format_constructor_arg(): + with pytest.raises(ValueError, match="is_marlin_format"): + QuantizeConfig(is_marlin_format=True) + + +def test_quantize_config_rejects_is_marlin_format_in_serialized_payload(): + with pytest.raises(ValueError, match="is_marlin_format"): + QuantizeConfig.from_quant_config( + { + "bits": 4, + "is_marlin_format": True, + } + ) + + +def test_quantize_config_dispatches_rtn_constructor(): + cfg = QuantizeConfig(weight_only=WeightOnlyConfig(smooth=SmoothMAD(k=2.0))) + + assert isinstance(cfg, BaseQuantizeConfig) + assert isinstance(cfg, RTNConfig) + assert not isinstance(cfg, GPTQConfig) + assert cfg.uses_weight_only_lifecycle() is True + assert cfg.smooth is not None + assert cfg.export_quant_method() == METHOD.GPTQ + + +def test_quantize_config_dispatches_rtn_awq_export_constructor(): + cfg = QuantizeConfig( + format=FORMAT.GEMM, + weight_only=WeightOnlyConfig(smooth=SmoothMAD(k=2.0)), + ) + + assert isinstance(cfg, RTNConfig) + assert cfg.format == FORMAT.GEMM + assert cfg.export_quant_method() == METHOD.AWQ + + +def test_quantize_config_dispatches_rtn_gguf_export_constructor(): + cfg = QuantizeConfig( + format=FORMAT.GGUF, + ) + + assert isinstance(cfg, GGUFConfig) + assert cfg.quant_method == METHOD.GGUF + assert cfg.format == "q_0" + assert cfg.bits == 4 + assert isinstance(cfg.runtime_bits, GGUFBits) + assert cfg.runtime_bits == "q4_0" + assert cfg.runtime_bits.bits == 4 + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == "0" + assert cfg.export_quant_method() == METHOD.GGUF + + +def test_quantize_config_dispatches_rtn_from_gguf_weight_only_method(): + cfg = QuantizeConfig( + weight_only=WeightOnlyConfig(method="gguf", smooth=SmoothMAD(k=1.5)), + ) + + assert isinstance(cfg, GGUFConfig) + assert cfg.quant_method == METHOD.GGUF + assert cfg.format == "q_0" + assert cfg.bits == 4 + assert isinstance(cfg.runtime_bits, GGUFBits) + assert cfg.runtime_bits == "q4_0" + assert cfg.runtime_bits.bits == 4 + assert cfg.runtime_bits.variant == "0" + assert cfg.export_quant_method() == METHOD.GGUF + + +def test_quantize_config_dispatches_rtn_from_gguf_weight_only_method_preserving_qtype(): + cfg = QuantizeConfig( + bits="q5_k_m", + weight_only=WeightOnlyConfig(method="gguf"), + ) + + assert isinstance(cfg, GGUFConfig) + assert cfg.quant_method == METHOD.GGUF + assert cfg.bits == 5 + assert isinstance(cfg.runtime_bits, GGUFBits) + assert cfg.runtime_bits == "q5_k_m" + assert cfg.runtime_bits.bits == 5 + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == "k" + assert cfg.runtime_bits.quality == "m" + assert cfg.format == "q_k_m" + assert cfg.export_quant_method() == METHOD.GGUF + + +def test_quantize_config_dispatches_fp8_constructor(): + cfg = QuantizeConfig( + quant_method=METHOD.FP8, + format="float8_e5m2", + weight_scale_method="row", + ) + + assert isinstance(cfg, FP8Config) + assert cfg.quant_method == METHOD.FP8 + assert cfg.format == "float8_e5m2" + assert cfg.weight_scale_method == "row" + assert cfg.uses_weight_only_lifecycle() is True + + +def test_quantize_config_dispatches_fp8_from_weight_only_method(): + cfg = QuantizeConfig( + weight_only=WeightOnlyConfig(method="fp8", smooth=SmoothMAD(k=1.5)), + weight_scale_method="block", + weight_block_size=[128, 128], + ) + + assert isinstance(cfg, FP8Config) + assert cfg.quant_method == METHOD.FP8 + assert cfg.format == "float8_e4m3fn" + assert cfg.weight_scale_method == "block" + assert cfg.weight_block_size == [128, 128] + assert cfg.smooth is not None + + +@pytest.mark.parametrize(("format_value", "expected"), _fp8_alias_cases()) +def test_quantize_config_normalizes_all_supported_fp8_aliases(format_value: str, expected: str): + cfg = QuantizeConfig( + quant_method=METHOD.FP8, + format=format_value, + ) + + assert isinstance(cfg, FP8Config) + assert cfg.format == expected + + +def test_quantize_config_dispatches_bitsandbytes_constructor(): + cfg = QuantizeConfig( + quant_method=METHOD.BITSANDBYTES, + bits=8, + ) + + assert isinstance(cfg, BitsAndBytesConfig) + assert cfg.quant_method == METHOD.BITSANDBYTES + assert cfg.format == "int8" + assert cfg.bits == 8 + assert cfg.uses_weight_only_lifecycle() is True + + +def test_quantize_config_dispatches_bitsandbytes_from_weight_only_method(): + cfg = QuantizeConfig( + bits=4, + weight_only=WeightOnlyConfig(method="bitsandbytes", smooth=SmoothMAD(k=1.25)), + format="nf4", + block_size=128, + compress_statistics=False, + ) + + assert isinstance(cfg, BitsAndBytesConfig) + assert cfg.quant_method == METHOD.BITSANDBYTES + assert cfg.format == "nf4" + assert cfg.bits == 4 + assert cfg.block_size == 128 + assert cfg.compress_statistics is False + assert cfg.smooth is not None + + +def test_quantize_config_dispatches_gptq_marlin_constructor(): + cfg = QuantizeConfig(quant_method=METHOD.GPTQ, format=FORMAT.MARLIN) + + assert isinstance(cfg, GPTQConfig) + assert cfg.export_quant_method() == METHOD.GPTQ + + +def test_from_quant_config_dispatches_awq_and_loads_zero_point(): + cfg = QuantizeConfig.from_quant_config( + { + "bits": 4, + "group_size": 128, + "quant_method": "awq", + "format": "gemm", + "zero_point": True, + } + ) + + assert isinstance(cfg, AWQConfig) + assert cfg.sym is False diff --git a/tests/qcfg/test_failsafe_meta.py b/tests/qcfg/test_failsafe_meta.py index ade66a403..7a496e16a 100644 --- a/tests/qcfg/test_failsafe_meta.py +++ b/tests/qcfg/test_failsafe_meta.py @@ -1,8 +1,9 @@ # SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import pytest -from gptqmodel.quantization.config import FailSafe, QuantizeConfig, SmoothMAD +from gptqmodel.quantization.config import FailSafe, QuantizeConfig, SmoothAuto, SmoothMAD def test_quantize_config_serializes_default_failsafe_in_meta_without_smoother(): @@ -41,3 +42,68 @@ def test_quantize_config_round_trips_explicit_failsafe_smoother(): reloaded = QuantizeConfig.from_quant_config(payload) assert isinstance(reloaded.failsafe.smooth, SmoothMAD) assert reloaded.failsafe.smooth.k == cfg.failsafe.smooth.k + + +def test_quantize_config_round_trips_auto_failsafe_smoother(): + cfg = QuantizeConfig( + failsafe=FailSafe( + smooth=SmoothAuto( + include_none=False, + mse_steps=40, + mse_maxshrink=0.9, + mad_k=2.5, + percentile=99.0, + low=0.5, + high=99.5, + ) + ) + ) + payload = cfg.to_dict() + + meta_failsafe = payload["meta"]["failsafe"] + assert meta_failsafe["smooth"]["type"] == "auto" + assert meta_failsafe["smooth"]["include_none"] is False + assert meta_failsafe["smooth"]["mse_steps"] == 40 + assert meta_failsafe["smooth"]["mse_maxshrink"] == 0.9 + assert meta_failsafe["smooth"]["mad_k"] == 2.5 + assert meta_failsafe["smooth"]["percentile"] == 99.0 + assert meta_failsafe["smooth"]["low"] == 0.5 + assert meta_failsafe["smooth"]["high"] == 99.5 + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded.failsafe.smooth, SmoothAuto) + assert reloaded.failsafe.smooth.include_none is False + assert reloaded.failsafe.smooth.mse_steps == 40 + assert reloaded.failsafe.smooth.mse_maxshrink == 0.9 + assert reloaded.failsafe.smooth.mad_k == 2.5 + assert reloaded.failsafe.smooth.percentile == 99.0 + assert reloaded.failsafe.smooth.low == 0.5 + assert reloaded.failsafe.smooth.high == 99.5 + + +def test_gptq_pro_defaults_to_auto_failsafe_search(): + cfg = QuantizeConfig.gptq_pro() + + assert cfg.act_group_aware is True + assert cfg.desc_act is False + assert cfg.failsafe is not None + assert isinstance(cfg.failsafe.smooth, SmoothAuto) + + +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"mse_steps": 0}, "mse_steps"), + ({"mse_maxshrink": 0.0}, "mse_maxshrink"), + ({"mse_maxshrink": 1.1}, "mse_maxshrink"), + ({"percentile": 0.0}, "percentile"), + ({"percentile": 101.0}, "percentile"), + ({"low": 25.0, "high": 25.0}, "strictly less"), + ({"low": 80.0, "high": 20.0}, "strictly less"), + ({"low": -1.0, "high": 99.0}, "low"), + ({"low": 0.0, "high": 101.0}, "high"), + ], +) +def test_smooth_auto_rejects_invalid_config(kwargs, match): + with pytest.raises(ValueError, match=match): + SmoothAuto(**kwargs) diff --git a/tests/qcfg/test_fallback_meta.py b/tests/qcfg/test_fallback_meta.py new file mode 100644 index 000000000..b09755208 --- /dev/null +++ b/tests/qcfg/test_fallback_meta.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from gptqmodel.quantization.config import Fallback, QuantizeConfig, SmoothMAD + + +def test_quantize_config_serializes_default_fallback_in_meta_without_smoother(): + cfg = QuantizeConfig() + payload = cfg.to_dict() + + assert "fallback" not in payload + assert "meta" in payload + assert "fallback" in payload["meta"] + + meta_fallback = payload["meta"]["fallback"] + assert meta_fallback["strategy"] == cfg.fallback.strategy.value + assert meta_fallback["threshold"] == cfg.fallback.threshold + assert meta_fallback["smooth"] is None + + +def test_quantize_config_reads_default_fallback_from_meta_without_smoother(): + cfg = QuantizeConfig() + payload = cfg.to_dict() + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded.fallback, Fallback) + assert reloaded.fallback.strategy == cfg.fallback.strategy + assert reloaded.fallback.threshold == cfg.fallback.threshold + assert reloaded.fallback.smooth is None + + +def test_quantize_config_round_trips_explicit_fallback_smoother(): + cfg = QuantizeConfig(fallback=Fallback(smooth=SmoothMAD(k=1.75))) + payload = cfg.to_dict() + + meta_fallback = payload["meta"]["fallback"] + assert meta_fallback["smooth"]["type"] == "mad" + assert meta_fallback["smooth"]["k"] == 1.75 + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded.fallback.smooth, SmoothMAD) + assert reloaded.fallback.smooth.k == cfg.fallback.smooth.k diff --git a/tests/qcfg/test_zero_point.py b/tests/qcfg/test_zero_point.py index 3c685a1cf..cba76a18c 100644 --- a/tests/qcfg/test_zero_point.py +++ b/tests/qcfg/test_zero_point.py @@ -31,6 +31,20 @@ def test_awq_zero_point_overrides_sym(): assert cfg.sym is False +def test_awq_missing_desc_act_defaults_to_false(): + payload = { + "bits": 4, + "group_size": 128, + "quant_method": "awq", + "format": "gemm", + "zero_point": True, + } + + cfg = QuantizeConfig.from_quant_config(payload) + + assert cfg.desc_act is False + + def test_awq_to_dict_uses_zero_point(): cfg = QuantizeConfig(quant_method=METHOD.AWQ, format=FORMAT.GEMM, sym=False) payload = cfg.to_dict() diff --git a/tests/test_asym_gptq_v1.py b/tests/test_asym_gptq_v1.py index 692ddbb3b..2c362751b 100644 --- a/tests/test_asym_gptq_v1.py +++ b/tests/test_asym_gptq_v1.py @@ -12,13 +12,12 @@ from models.model_test import ModelTest # noqa: E402 from gptqmodel.quantization import FORMAT # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 class Test(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, @@ -27,4 +26,4 @@ class Test(ModelTest): SYM = False def test(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/test_auto_module_decoder.py b/tests/test_auto_module_decoder.py new file mode 100644 index 000000000..2d9f69753 --- /dev/null +++ b/tests/test_auto_module_decoder.py @@ -0,0 +1,577 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import json +import threading +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +from safetensors.torch import save_file +from torch import nn + +import gptqmodel.models.base as base_module +from gptqmodel.looper.awq_processor import AWQProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.nn_modules.qlinear.fp4 import TorchFP4Linear +from gptqmodel.nn_modules.qlinear.fp8 import TorchFP8Linear +from gptqmodel.quantization.dtype import ( + dequantize_f4_e2m1, + dequantize_fp8, +) +from gptqmodel.quantization.gptq import GPTQ +from gptqmodel.utils.structure import LazyTurtle + + +try: + from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize +except Exception: + nvfp4_quantize = None + + +class _LinearWrapper(nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=True) + + +def _write_index(path: Path, shard_name: str, keys: list[str]) -> None: + weight_map = dict.fromkeys(keys, shard_name) + (path / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": weight_map}), + encoding="utf-8", + ) + + +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_shell_materialize_forward_builds_fp8_wrapper_and_quant_source(tmp_path, monkeypatch): + source_model = _LinearWrapper(16, 8).eval() + model_dir = tmp_path / "fp8_source" + model_dir.mkdir() + + weight_fp8 = source_model.linear.weight.detach().to(torch.float8_e4m3fn).cpu() + scale_inv = torch.linspace(2.0, 3.75, steps=source_model.linear.out_features, dtype=torch.float32) + bias = source_model.linear.bias.detach().cpu() + shard_name = "model.safetensors" + save_file( + { + "linear.weight": weight_fp8, + "linear.weight_scale_inv": scale_inv, + "linear.bias": bias, + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv", "linear.bias"]) + + turtle = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert turtle is not None + + shell_model = _LinearWrapper(16, 8).eval() + shell_model.linear.weight = nn.Parameter( + torch.empty_like(shell_model.linear.weight, device="meta"), + requires_grad=False, + ) + shell_model.linear.bias = nn.Parameter( + torch.empty_like(shell_model.linear.bias, device="meta"), + requires_grad=False, + ) + + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = shell_model + harness.turtle_model = turtle + harness._turtle_lock = threading.RLock() + harness.auto_module_decoder_events = [] + + named = NamedModule(shell_model.linear, name="linear", full_name="linear", layer_index=0) + named.state["auto_module_decoder"] = { + "code": "auto_module_decoder", + "source_dtype": "auto", + "target_dtype": torch.bfloat16, + } + + monkeypatch.setattr(base_module, "device_supports_dtype", lambda *args, **kwargs: True) + + prepared = base_module.BaseQModel.shell_module_materialize( + harness, + target_submodule=shell_model.linear, + device=torch.device("cpu"), + role="forward", + named_module=named, + ) + + assert isinstance(prepared, TorchFP8Linear) + assert isinstance(shell_model.linear, TorchFP8Linear) + assert named.state["auto_module_decoder_forward_mode"] == "native" + assert isinstance(named.state["quant_source_module"], nn.Linear) + assert named.state["quant_source_module"].weight.device.type == "cpu" + assert named.state["quant_source_module"].weight.dtype == torch.bfloat16 + expected = dequantize_fp8( + weight_fp8, + scale_inv=scale_inv, + axis=None, + target_dtype=torch.bfloat16, + ) + torch.testing.assert_close(named.state["quant_source_module"].weight, expected) + assert harness.auto_module_decoder_events[0]["forward_mode"] == "native" + + +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_shell_materialize_forward_decodes_when_passthrough_forward_policy_is_decode(tmp_path, monkeypatch): + source_model = _LinearWrapper(16, 8).eval() + model_dir = tmp_path / "fp8_source_decode" + model_dir.mkdir() + + weight_fp8 = source_model.linear.weight.detach().to(torch.float8_e4m3fn).cpu() + scale_inv = torch.linspace(2.0, 3.75, steps=source_model.linear.out_features, dtype=torch.float32) + bias = source_model.linear.bias.detach().cpu() + shard_name = "model.safetensors" + save_file( + { + "linear.weight": weight_fp8, + "linear.weight_scale_inv": scale_inv, + "linear.bias": bias, + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv", "linear.bias"]) + + turtle = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert turtle is not None + + shell_model = _LinearWrapper(16, 8).eval() + shell_model.linear.weight = nn.Parameter( + torch.empty_like(shell_model.linear.weight, device="meta"), + requires_grad=False, + ) + shell_model.linear.bias = nn.Parameter( + torch.empty_like(shell_model.linear.bias, device="meta"), + requires_grad=False, + ) + + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = shell_model + harness.turtle_model = turtle + harness._turtle_lock = threading.RLock() + harness.auto_module_decoder_events = [] + + named = NamedModule(shell_model.linear, name="linear", full_name="linear", layer_index=0) + named.state["auto_module_decoder"] = { + "code": "auto_module_decoder", + "source_dtype": "auto", + "target_dtype": torch.bfloat16, + "passthrough_forward_policy": "decode", + "passthrough_save_policy": "decode", + } + + monkeypatch.setattr(base_module, "device_supports_dtype", lambda *args, **kwargs: True) + + prepared = base_module.BaseQModel.shell_module_materialize( + harness, + target_submodule=shell_model.linear, + device=torch.device("cpu"), + role="forward", + named_module=named, + ) + + expected = dequantize_fp8( + weight_fp8, + scale_inv=scale_inv, + axis=None, + target_dtype=torch.bfloat16, + ) + assert isinstance(prepared, nn.Linear) + assert not isinstance(prepared, TorchFP8Linear) + assert named.state["auto_module_decoder_forward_mode"] == "decode" + torch.testing.assert_close(prepared.weight, expected) + assert harness.auto_module_decoder_events[0]["forward_mode"] == "decode" + + +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_shell_materialize_quant_source_swaps_back_to_dense_module(tmp_path, monkeypatch): + source_model = _LinearWrapper(8, 4).eval() + model_dir = tmp_path / "fp8_source" + model_dir.mkdir() + + shard_name = "model.safetensors" + save_file( + { + "linear.weight": source_model.linear.weight.detach().to(torch.float8_e4m3fn).cpu(), + "linear.weight_scale_inv": torch.ones(4, dtype=torch.float32), + "linear.bias": source_model.linear.bias.detach().cpu(), + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv", "linear.bias"]) + + turtle = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert turtle is not None + + shell_model = _LinearWrapper(8, 4).eval() + shell_model.linear.weight = nn.Parameter( + torch.empty_like(shell_model.linear.weight, device="meta"), + requires_grad=False, + ) + shell_model.linear.bias = nn.Parameter( + torch.empty_like(shell_model.linear.bias, device="meta"), + requires_grad=False, + ) + + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = shell_model + harness.turtle_model = turtle + harness._turtle_lock = threading.RLock() + harness.auto_module_decoder_events = [] + + named = NamedModule(shell_model.linear, name="linear", full_name="linear", layer_index=0) + named.state["auto_module_decoder"] = { + "code": "auto_module_decoder", + "source_dtype": "auto", + "target_dtype": torch.bfloat16, + } + + monkeypatch.setattr(base_module, "device_supports_dtype", lambda *args, **kwargs: True) + + forward_module = base_module.BaseQModel.shell_module_materialize( + harness, + target_submodule=shell_model.linear, + device=torch.device("cpu"), + role="forward", + named_module=named, + ) + named.module = forward_module + + quant_source = base_module.BaseQModel.shell_module_materialize( + harness, + target_submodule=forward_module, + device=torch.device("cpu"), + role="quant_source", + named_module=named, + ) + + assert isinstance(quant_source, nn.Linear) + assert isinstance(shell_model.linear, nn.Linear) + torch.testing.assert_close( + quant_source.weight, + named.state["quant_source_module"].weight, + ) + + +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_shell_materialize_forward_builds_fp8_wrapper_from_weight_scale_metadata(tmp_path, monkeypatch): + source_model = _LinearWrapper(16, 8).eval() + model_dir = tmp_path / "fp8_source_scale" + model_dir.mkdir() + + weight_fp8 = source_model.linear.weight.detach().to(torch.float8_e4m3fn).cpu() + scale = torch.full((), 0.5, dtype=torch.float32) + bias = source_model.linear.bias.detach().cpu() + shard_name = "model.safetensors" + save_file( + { + "linear.weight": weight_fp8, + "linear.weight_scale": scale, + "linear.bias": bias, + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale", "linear.bias"]) + + turtle = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert turtle is not None + + shell_model = _LinearWrapper(16, 8).eval() + shell_model.linear.weight = nn.Parameter( + torch.empty_like(shell_model.linear.weight, device="meta"), + requires_grad=False, + ) + shell_model.linear.bias = nn.Parameter( + torch.empty_like(shell_model.linear.bias, device="meta"), + requires_grad=False, + ) + + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = shell_model + harness.turtle_model = turtle + harness._turtle_lock = threading.RLock() + harness.auto_module_decoder_events = [] + + named = NamedModule(shell_model.linear, name="linear", full_name="linear", layer_index=0) + named.state["auto_module_decoder"] = { + "code": "auto_module_decoder", + "source_dtype": "auto", + "target_dtype": torch.bfloat16, + } + + monkeypatch.setattr(base_module, "device_supports_dtype", lambda *args, **kwargs: True) + + prepared = base_module.BaseQModel.shell_module_materialize( + harness, + target_submodule=shell_model.linear, + device=torch.device("cpu"), + role="forward", + named_module=named, + ) + + assert isinstance(prepared, TorchFP8Linear) + assert named.state["auto_module_decoder_forward_mode"] == "native" + expected = dequantize_fp8( + weight_fp8, + scale=scale, + axis=None, + target_dtype=torch.bfloat16, + ) + torch.testing.assert_close(named.state["quant_source_module"].weight, expected) + torch.testing.assert_close(prepared.weight_scale_inv, torch.reciprocal(scale).to(torch.float32)) + assert harness.auto_module_decoder_events[0]["forward_mode"] == "native" + + +@pytest.mark.skipif(nvfp4_quantize is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not hasattr(torch, "float4_e2m1fn_x2"), reason="float4 packed dtype not available") +def test_shell_materialize_forward_decodes_fp4_source_to_dense_module(tmp_path): + source_model = _LinearWrapper(16, 4).eval() + model_dir = tmp_path / "fp4_source" + model_dir.mkdir() + + scales, packed = nvfp4_quantize(source_model.linear.weight.detach().to(torch.float32), block_size=16) + packed_float4 = packed.view(torch.float4_e2m1fn_x2) + bias = source_model.linear.bias.detach().cpu() + shard_name = "model.safetensors" + save_file( + { + "linear.weight": packed_float4.cpu(), + "linear.weight_scale": scales.cpu(), + "linear.bias": bias, + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale", "linear.bias"]) + + turtle = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert turtle is not None + + shell_model = _LinearWrapper(16, 4).eval() + shell_model.linear.weight = nn.Parameter( + torch.empty_like(shell_model.linear.weight, device="meta"), + requires_grad=False, + ) + shell_model.linear.bias = nn.Parameter( + torch.empty_like(shell_model.linear.bias, device="meta"), + requires_grad=False, + ) + + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = shell_model + harness.turtle_model = turtle + harness._turtle_lock = threading.RLock() + harness.auto_module_decoder_events = [] + + named = NamedModule(shell_model.linear, name="linear", full_name="linear", layer_index=0) + named.state["auto_module_decoder"] = { + "code": "auto_module_decoder", + "source_dtype": "auto", + "target_dtype": torch.bfloat16, + } + + prepared = base_module.BaseQModel.shell_module_materialize( + harness, + target_submodule=shell_model.linear, + device=torch.device("cpu"), + role="forward", + named_module=named, + ) + + expected = dequantize_f4_e2m1( + packed_float4.cpu(), + scale=scales.cpu(), + axis=None, + target_dtype=torch.bfloat16, + ) + assert isinstance(prepared, nn.Linear) + assert not isinstance(prepared, TorchFP8Linear) + assert named.state["auto_module_decoder_forward_mode"] == "decode" + torch.testing.assert_close(prepared.weight, expected, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(named.state["quant_source_module"].weight, expected, atol=1e-3, rtol=1e-3) + + +@pytest.mark.skipif(nvfp4_quantize is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not hasattr(torch, "float4_e2m1fn_x2"), reason="float4 packed dtype not available") +def test_shell_materialize_forward_builds_fp4_wrapper_when_native_supported(tmp_path, monkeypatch): + source_model = _LinearWrapper(16, 4).eval() + model_dir = tmp_path / "fp4_native_source" + model_dir.mkdir() + + scales, packed = nvfp4_quantize(source_model.linear.weight.detach().to(torch.float32), block_size=16) + packed_float4 = packed.view(torch.float4_e2m1fn_x2) + bias = source_model.linear.bias.detach().cpu() + shard_name = "model.safetensors" + save_file( + { + "linear.weight": packed_float4.cpu(), + "linear.weight_scale": scales.cpu(), + "linear.bias": bias, + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale", "linear.bias"]) + + turtle = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace( + _experts_implementation=None, + quantization_config={"format": "nvfp4"}, + ), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert turtle is not None + + shell_model = _LinearWrapper(16, 4).eval() + shell_model.config = SimpleNamespace(quantization_config={"format": "nvfp4"}) + shell_model.linear.weight = nn.Parameter( + torch.empty_like(shell_model.linear.weight, device="meta"), + requires_grad=False, + ) + shell_model.linear.bias = nn.Parameter( + torch.empty_like(shell_model.linear.bias, device="meta"), + requires_grad=False, + ) + + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = shell_model + harness.turtle_model = turtle + harness._turtle_lock = threading.RLock() + harness.auto_module_decoder_events = [] + + named = NamedModule(shell_model.linear, name="linear", full_name="linear", layer_index=0) + named.state["auto_module_decoder"] = { + "code": "auto_module_decoder", + "source_dtype": "auto", + "target_dtype": torch.bfloat16, + } + + monkeypatch.setattr(base_module, "device_supports_native_fp4", lambda *args, **kwargs: True) + + prepared = base_module.BaseQModel.shell_module_materialize( + harness, + target_submodule=shell_model.linear, + device=torch.device("cpu"), + role="forward", + named_module=named, + ) + + expected = dequantize_f4_e2m1( + packed_float4.cpu(), + scale=scales.cpu(), + axis=None, + target_dtype=torch.bfloat16, + ) + assert isinstance(prepared, TorchFP4Linear) + assert named.state["auto_module_decoder_forward_mode"] == "native" + assert isinstance(named.state["quant_source_module"], nn.Linear) + torch.testing.assert_close(named.state["quant_source_module"].weight, expected, atol=1e-3, rtol=1e-3) + + +def test_configure_modelopt_runtime_rejects_modelopt_input_activation_quantization(): + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = SimpleNamespace( + config=SimpleNamespace( + quantization_config={ + "quant_method": "modelopt", + "config_groups": { + "group_0": { + "input_activations": {"num_bits": 4, "type": "float", "dynamic": False}, + } + }, + }, + ), + state_dict=lambda: {}, + ) + harness.turtle_model = None + + with pytest.raises(ValueError, match="activation quantization"): + base_module.BaseQModel._configure_modelopt_runtime(harness) + + +def test_configure_modelopt_runtime_rejects_checkpoint_activation_scale_metadata(): + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = SimpleNamespace( + config=SimpleNamespace(quantization_config={"quant_method": "modelopt"}), + state_dict=lambda: {}, + ) + harness.turtle_model = base_module.LazyTurtle.__new__(base_module.LazyTurtle) + harness.turtle_model._weight_map = {"model.layers.0.self_attn.q_proj.input_scale": "model.safetensors"} + + with pytest.raises(ValueError, match="activation quantization"): + base_module.BaseQModel._configure_modelopt_runtime(harness) + + +def test_configure_modelopt_runtime_allows_weight_only_modelopt_metadata(): + harness = base_module.BaseQModel.__new__(base_module.BaseQModel) + nn.Module.__init__(harness) + harness.model = SimpleNamespace( + config=SimpleNamespace( + quantization_config={ + "quant_method": "modelopt", + "config_groups": { + "group_0": { + "weights": {"num_bits": 4, "type": "float", "dynamic": False}, + } + }, + }, + ), + state_dict=lambda: {}, + ) + harness.turtle_model = base_module.LazyTurtle.__new__(base_module.LazyTurtle) + harness.turtle_model._weight_map = {"model.layers.0.self_attn.q_proj.weight_scale": "model.safetensors"} + + base_module.BaseQModel._configure_modelopt_runtime(harness) + + +def test_gptq_prefers_quant_source_module_when_present(): + forward_module = nn.Linear(8, 4, bias=False) + quant_source = nn.Linear(8, 4, bias=False) + named = NamedModule(forward_module, name="linear", full_name="linear", layer_index=0) + named.state["quant_source_module"] = quant_source + + task = GPTQ(named) + + assert task.module is quant_source + + +def test_awq_resolve_quant_source_module_prefers_dense_source(): + forward_module = nn.Linear(8, 4, bias=False) + quant_source = nn.Linear(8, 4, bias=False) + named = NamedModule(forward_module, name="linear", full_name="linear", layer_index=0) + named.state["quant_source_module"] = quant_source + + resolved = AWQProcessor.resolve_quant_source_module(named) + + assert resolved is quant_source diff --git a/tests/test_awq_bitblas.py b/tests/test_awq_bitblas.py new file mode 100644 index 000000000..0c66f7dde --- /dev/null +++ b/tests/test_awq_bitblas.py @@ -0,0 +1,184 @@ +import numpy as np +import pytest +import torch + +import gptqmodel.nn_modules.qlinear.bitblas as bitblas_module +from gptqmodel.models._const import DEVICE +from gptqmodel.nn_modules.qlinear.bitblas_awq import AWQBitBlasKernel +from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMLinear +from gptqmodel.quantization import FORMAT, METHOD, QuantizeConfig +from gptqmodel.utils.backend import BACKEND +from gptqmodel.utils.importer import get_kernel_for_backend, select_quant_linear + + +def _compress_ints(lowprecision_weight: torch.Tensor, bits: int) -> torch.Tensor: + values = lowprecision_weight.detach().cpu().numpy().astype(np.int8, copy=False) + elems_per_byte = 8 // bits + packed = np.zeros( + (*values.shape[:-1], values.shape[-1] // elems_per_byte), + dtype=np.int8, + ) + for col in range(packed.shape[-1]): + for lane in range(elems_per_byte): + packed[:, col] |= values[:, col * elems_per_byte + lane] << (bits * lane) + return torch.from_numpy(packed) + + +def _pack_awq(iweights_in_out: torch.Tensor, izeros_group_out: torch.Tensor, bits: int) -> tuple[torch.Tensor, torch.Tensor]: + pack_num = 32 // bits + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + + qweight = torch.zeros( + (iweights_in_out.shape[0], iweights_in_out.shape[1] // pack_num), + dtype=torch.int32, + ) + for col in range(qweight.shape[1]): + for lane, mapped_col in enumerate(order_map): + qweight[:, col] |= iweights_in_out[:, col * pack_num + mapped_col].to(torch.int32) << (lane * bits) + + qzeros = torch.zeros( + (izeros_group_out.shape[0], izeros_group_out.shape[1] // pack_num), + dtype=torch.int32, + ) + for col in range(qzeros.shape[1]): + for lane, mapped_col in enumerate(order_map): + qzeros[:, col] |= izeros_group_out[:, col * pack_num + mapped_col].to(torch.int32) << (lane * bits) + + return qweight, qzeros + + +def _install_dummy_bitblas(monkeypatch): + captured = {} + + class _DummyMatmul: + def __init__(self, config): + self.config = config + self.lib = object() + self.weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + def _fake_get_or_create(self, config, enable_tuning): + captured["config"] = config + captured["enable_tuning"] = enable_tuning + return _DummyMatmul(config) + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr(AWQBitBlasKernel, "_get_or_create_bitblas_operator", _fake_get_or_create) + AWQBitBlasKernel.cached_validate_once.cache_clear() + + return captured + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_awq_bitblas_selects_bitblas_awq_for_awq_gemm(monkeypatch): + _install_dummy_bitblas(monkeypatch) + + selected = select_quant_linear( + bits=4, + group_size=32, + desc_act=False, + sym=True, + backend=BACKEND.AWQ_BITBLAS, + format=FORMAT.GEMM, + quant_method=METHOD.AWQ, + device=DEVICE.CUDA, + pack=False, + pack_dtype=torch.int32, + ) + + assert selected is AWQBitBlasKernel + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_awq_bitblas_kernel_mapping_uses_awq_backend(): + assert get_kernel_for_backend(BACKEND.AWQ_BITBLAS, METHOD.AWQ, FORMAT.BITBLAS) is AWQBitBlasKernel + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_awq_bitblas_kernel_mapping_rejects_gptq_bitblas_backend(): + with pytest.raises(ValueError, match="Unsupported backend"): + get_kernel_for_backend(BACKEND.GPTQ_BITBLAS, METHOD.AWQ, FORMAT.BITBLAS) + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_awq_bitblas_uses_unsigned_weights_and_qzeros(monkeypatch): + captured = _install_dummy_bitblas(monkeypatch) + + AWQBitBlasKernel( + bits=4, + group_size=32, + desc_act=False, + sym=True, + in_features=32, + out_features=32, + bias=False, + ) + + assert captured["config"].W_dtype == "uint4" + assert captured["config"].with_zeros is True + assert captured["config"].zeros_mode == "quantized" + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_awq_bitblas_repack_from_awq_preserves_codes(monkeypatch): + _install_dummy_bitblas(monkeypatch) + + bits = 4 + group_size = 32 + in_features = 64 + out_features = 32 + groups = in_features // group_size + + intweight = torch.randint(0, 2**bits, (out_features, in_features), dtype=torch.int32) + intzeros = torch.randint(0, 2**bits, (groups, out_features), dtype=torch.int32) + scales = torch.rand(groups, out_features, dtype=torch.float32) + 0.25 + bias = torch.randn(out_features, dtype=torch.float16) + qweight, qzeros = _pack_awq(intweight.t().contiguous(), intzeros, bits) + + awq_module = AwqGEMMLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ) + awq_module.qweight.copy_(qweight) + awq_module.qzeros.copy_(qzeros) + awq_module.scales.copy_(scales.to(torch.float16)) + awq_module.bias.copy_(bias) + + bitblas_module_instance = AWQBitBlasKernel( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + ) + bitblas_module_instance.repack_from_awq(awq_module) + + torch.testing.assert_close(bitblas_module_instance.qweight.cpu(), _compress_ints(intweight, bits)) + torch.testing.assert_close(bitblas_module_instance.qzeros.cpu(), _compress_ints(intzeros, bits)) + torch.testing.assert_close(bitblas_module_instance.scales.cpu(), scales.t().to(torch.float16)) + torch.testing.assert_close(bitblas_module_instance.bias.cpu(), bias) + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_quantize_config_allows_awq_bitblas(): + cfg = QuantizeConfig( + bits=4, + group_size=128, + quant_method=METHOD.AWQ, + format=FORMAT.BITBLAS, + ) + + assert cfg.quant_method == METHOD.AWQ + assert cfg.format == FORMAT.BITBLAS diff --git a/tests/test_awq_fp16_matmul_heuristic.py b/tests/test_awq_fp16_matmul_heuristic.py index 9fde44a16..455ede79b 100644 --- a/tests/test_awq_fp16_matmul_heuristic.py +++ b/tests/test_awq_fp16_matmul_heuristic.py @@ -18,8 +18,6 @@ def _fake_quant_tensors(in_features: int = 32, out_features: int = 8, group_size def _patch_backend(monkeypatch, backend: str, calls): if backend == "triton": - monkeypatch.setattr(gemm_awq, "awq_ext", None) - triton_state = getattr(gemm_awq_triton, "tritonv2", SimpleNamespace(TRITON_AVAILABLE=False)) monkeypatch.setattr(gemm_awq_triton, "tritonv2", triton_state, raising=False) monkeypatch.setattr(triton_state, "TRITON_AVAILABLE", True) @@ -30,6 +28,7 @@ def fake_dequant(qweight, scales, qzeros): def fake_gemm(input, qweight, scales, qzeros, split_k_iters, **_): calls["gemm"] += 1 + calls["gemm_kwargs"] = _ out_features = qweight.shape[1] * 8 return torch.ones(input.shape[0], out_features, device=input.device, dtype=input.dtype) @@ -48,25 +47,26 @@ def fake_gemm(input, qweight, scales, qzeros, split_k_iters, **_): return gemm_awq_triton.AwqGemmTritonFn - # Stub the compiled AWQ extension so we can count which path is taken. - class FakeAwqExt: - def dequantize_weights_cuda(self, qweight, scales, qzeros, *_args): - calls["dequant"] += 1 - return torch.ones(qweight.shape[0], qweight.shape[1] * 8, dtype=torch.float16) + def fake_dequant(qweight, scales, qzeros, *_args): + calls["dequant"] += 1 + return torch.ones(qweight.shape[0], qweight.shape[1] * 8, dtype=torch.float16) - def gemm_forward_cuda(self, input, qweight, scales, qzeros, _split_k_iters): - calls["gemm"] += 1 - out_features = qweight.shape[1] * 8 - return torch.ones(input.shape[0], out_features, device=input.device, dtype=input.dtype) + def fake_gemm(input, qweight, scales, qzeros, _split_k_iters, fp32_accum=False): + calls["gemm"] += 1 + calls["gemm_api"] = "fp32_accum" if fp32_accum else "legacy" + calls["gemm_kwargs"] = {"fp32_accum": fp32_accum} + out_features = qweight.shape[1] * 8 + return torch.ones(input.shape[0], out_features, device=input.device, dtype=input.dtype) - monkeypatch.setattr(gemm_awq, "awq_ext", FakeAwqExt()) + monkeypatch.setattr(gemm_awq, "awq_dequantize_weights", fake_dequant) + monkeypatch.setattr(gemm_awq, "_awq_cuda_gemm_forward", fake_gemm) triton_state = getattr(gemm_awq_triton, "tritonv2", SimpleNamespace(TRITON_AVAILABLE=False)) monkeypatch.setattr(gemm_awq_triton, "tritonv2", triton_state, raising=False) monkeypatch.setattr(triton_state, "TRITON_AVAILABLE", False) return gemm_awq.AwqGemmFn -@pytest.mark.parametrize("backend", ["triton", "ext"], ids=["triton", "awq_ext"]) +@pytest.mark.parametrize("backend", ["triton", "jit"], ids=["triton", "awq_jit"]) def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch, backend): calls = {"dequant": 0, "gemm": 0} fn = _patch_backend(monkeypatch, backend, calls) @@ -83,11 +83,12 @@ def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch, b x, qweight, qzeros, scales, 4, group_size, None, out_features, ) - assert calls == {"dequant": 1, "gemm": 0} + assert calls["dequant"] == 1 + assert calls["gemm"] == 0 assert out.shape == (33, 32, out_features) -@pytest.mark.parametrize("backend", ["triton", "ext"], ids=["triton", "awq_ext"]) +@pytest.mark.parametrize("backend", ["triton", "jit"], ids=["triton", "awq_jit"]) def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch, backend): calls = {"dequant": 0, "gemm": 0} fn = _patch_backend(monkeypatch, backend, calls) @@ -104,14 +105,41 @@ def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch x, qweight, qzeros, scales, 4, group_size, None, out_features, ) - assert calls == {"dequant": 0, "gemm": 1} + assert calls["dequant"] == 0 + assert calls["gemm"] == 1 + assert out.shape == (1, 1, out_features) + if backend == "triton": + assert calls["gemm_kwargs"]["fp32_accum"] is True + assert calls["gemm_kwargs"]["output_dtype"] == torch.float16 + else: + assert calls["gemm_kwargs"]["fp32_accum"] is True + assert calls["gemm_api"] == "fp32_accum" + + +def test_awq_jit_fp32_accum_can_be_disabled(monkeypatch): + calls = {"dequant": 0, "gemm": 0} + fn = _patch_backend(monkeypatch, "jit", calls) + + group_size = 32 + out_features = 8 + qweight, scales, qzeros = _fake_quant_tensors(in_features=32, out_features=out_features, group_size=group_size) + x = torch.ones((1, 1, qweight.shape[0]), dtype=torch.float16) + + out = fn.apply( + x, qweight, qzeros, scales, 4, group_size, None, out_features, "cuda", False, + ) + + assert calls["dequant"] == 0 + assert calls["gemm"] == 1 + assert calls["gemm_kwargs"]["fp32_accum"] is False + assert calls["gemm_api"] == "legacy" assert out.shape == (1, 1, out_features) def _available_bench_backends(): backends = [] - if gemm_awq.awq_ext is not None: - backends.append("awq_ext") + if gemm_awq.awq_runtime_available(): + backends.append("awq_jit") triton_mod = getattr(gemm_awq_triton, "tritonv2", None) if triton_mod is not None and getattr(triton_mod, "TRITON_AVAILABLE", False): backends.append("triton") @@ -155,7 +183,7 @@ def test_fp16_matmul_heuristic_benchmark(case_name, batch, seq, in_features, out tabulate = pytest.importorskip("tabulate").tabulate - if backend not in {"awq_ext", "triton"}: + if backend not in {"awq_jit", "triton"}: pytest.skip("No AWQ backend available for benchmark") device = torch.device("cuda") @@ -173,8 +201,8 @@ def test_fp16_matmul_heuristic_benchmark(case_name, batch, seq, in_features, out def run_dequant_matmul(): with torch.inference_mode(): - if backend == "awq_ext": - weight = gemm_awq.awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) + if backend == "awq_jit": + weight = gemm_awq.awq_dequantize_weights(qweight, scales, qzeros, 0, 0, 0, False) else: try: weight = awq_dequantize_triton(qweight, scales, qzeros) @@ -185,8 +213,8 @@ def run_dequant_matmul(): def run_fused_gemm(): with torch.inference_mode(): x2d = x.reshape(-1, x.shape[-1]) - if backend == "awq_ext": - return gemm_awq.awq_ext.gemm_forward_cuda(x2d, qweight, scales, qzeros, 8) + if backend == "awq_jit": + return gemm_awq._awq_cuda_gemm_forward(x2d, qweight, scales, qzeros, 8, True) try: return awq_gemm_triton(x2d, qweight, scales, qzeros, split_k_iters=8) except AttributeError as err: diff --git a/tests/test_awq_gemm.py b/tests/test_awq_gemm.py new file mode 100644 index 000000000..3f63de955 --- /dev/null +++ b/tests/test_awq_gemm.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +from awq_test_utils import run_quantized_awq_generation_test + +from gptqmodel import BACKEND +from gptqmodel.quantization import FORMAT + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + + +def test_awq_gemm_quantized_model_loads_and_generates(): + run_quantized_awq_generation_test(FORMAT.GEMM, BACKEND.GEMM) diff --git a/tests/test_awq_gemv.py b/tests/test_awq_gemv.py new file mode 100644 index 000000000..14d5513d8 --- /dev/null +++ b/tests/test_awq_gemv.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +from awq_test_utils import run_quantized_awq_generation_test + +from gptqmodel import BACKEND +from gptqmodel.quantization import FORMAT + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + + +def test_awq_gemv_quantized_model_loads_and_generates(): + run_quantized_awq_generation_test(FORMAT.GEMV, BACKEND.GEMV) diff --git a/tests/test_awq_gemv_fast.py b/tests/test_awq_gemv_fast.py new file mode 100644 index 000000000..c9ccc59ac --- /dev/null +++ b/tests/test_awq_gemv_fast.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +from awq_test_utils import run_quantized_awq_generation_test + +from gptqmodel import BACKEND +from gptqmodel.quantization import FORMAT + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + + +def test_awq_gemv_fast_quantized_model_loads_and_generates(): + run_quantized_awq_generation_test(FORMAT.GEMV_FAST, BACKEND.GEMV_FAST) diff --git a/tests/test_awq_gemv_fast_jit.py b/tests/test_awq_gemv_fast_jit.py new file mode 100644 index 000000000..5ef684fde --- /dev/null +++ b/tests/test_awq_gemv_fast_jit.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import gptqmodel.nn_modules.qlinear.gemv_fast_awq as gemv_fast_awq + + +def _build_module() -> gemv_fast_awq.AwqGEMVFastLinear: + return gemv_fast_awq.AwqGEMVFastLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=32, + out_features=8, + bias=False, + register_buffers=True, + ) + + +def _build_llm_awq_module() -> gemv_fast_awq.LLMAwqLinear: + return gemv_fast_awq.LLMAwqLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=32, + out_features=8, + bias=False, + register_buffers=True, + ) + + +def test_awq_gemv_fast_decode_uses_jit_decode_kernel(monkeypatch): + module = _build_module() + calls = {} + + monkeypatch.setattr(gemv_fast_awq, "awq_runtime_available", lambda: True) + + def fake_decode(inputs, qweight, scales, zeros, m, n, k, group_size): + calls["decode"] = { + "shape": tuple(inputs.shape), + "m": m, + "n": n, + "k": k, + "group_size": group_size, + } + return torch.ones((inputs.shape[0], inputs.shape[1], module.out_features), dtype=torch.float16) + + def fail_prefill(*_args, **_kwargs): + raise AssertionError("prefill kernel should not be used for decode-shaped inputs") + + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemv_forward_decode", fake_decode) + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemm_forward_prefill", fail_prefill) + + x = torch.randn((4, 1, module.in_features), dtype=torch.float32) + out = module(x) + + assert calls["decode"] == { + "shape": (4, 1, module.in_features), + "m": 4, + "n": module.out_features, + "k": module.in_features, + "group_size": module.group_size, + } + assert out.shape == (4, 1, module.out_features) + assert out.dtype == torch.float32 + + +def test_awq_gemv_fast_prefill_uses_jit_prefill_kernel(monkeypatch): + module = _build_module() + calls = {"prefill": 0} + + monkeypatch.setattr(gemv_fast_awq, "awq_runtime_available", lambda: True) + + def fail_decode(*_args, **_kwargs): + raise AssertionError("decode kernel should not be used for prefill-shaped inputs") + + def fake_prefill(inputs, qweight, scales, zeros): + calls["prefill"] += 1 + calls["shape"] = tuple(inputs.shape) + return torch.ones((inputs.shape[0], inputs.shape[1], module.out_features), dtype=torch.float16) + + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemv_forward_decode", fail_decode) + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemm_forward_prefill", fake_prefill) + + x = torch.randn((2, 4, module.in_features), dtype=torch.float16) + out = module(x) + + assert calls["prefill"] == 1 + assert calls["shape"] == (2, 4, module.in_features) + assert out.shape == (2, 4, module.out_features) + assert out.dtype == torch.float16 + + +def test_awq_gemv_fast_raises_runtime_error_when_jit_ops_missing(monkeypatch): + module = _build_module() + + monkeypatch.setattr(gemv_fast_awq, "awq_runtime_available", lambda: False) + monkeypatch.setattr(gemv_fast_awq, "awq_runtime_error", lambda: "missing awq jit ops") + + with pytest.raises(ModuleNotFoundError, match="missing awq jit ops"): + module(torch.randn((1, 1, module.in_features), dtype=torch.float16)) + + +def test_awq_gemv_fast_decode_normalizes_noncontiguous_inputs_and_buffers(monkeypatch): + module = _build_module() + module.qweight = module.qweight.t().contiguous().t() + module.scales = module.scales.t().contiguous().t() + module.qzeros = module.qzeros.t().contiguous().t() + + monkeypatch.setattr(gemv_fast_awq, "awq_runtime_available", lambda: True) + + def fake_decode(inputs, qweight, scales, zeros, m, n, k, group_size): + assert inputs.is_contiguous() + assert qweight.is_contiguous() + assert scales.is_contiguous() + assert zeros.is_contiguous() + assert inputs.dtype == torch.float16 + assert scales.dtype == torch.float16 + assert zeros.dtype == torch.float16 + return torch.ones((inputs.shape[0], inputs.shape[1], module.out_features), dtype=torch.float16) + + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemv_forward_decode", fake_decode) + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemm_forward_prefill", lambda *_args, **_kwargs: None) + + x = torch.randn((module.in_features, 1, 4), dtype=torch.float32).permute(2, 1, 0) + assert not x.is_contiguous() + module(x) + + +def test_awq_gemv_fast_prefill_normalizes_noncontiguous_inputs(monkeypatch): + module = _build_module() + + monkeypatch.setattr(gemv_fast_awq, "awq_runtime_available", lambda: True) + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemv_forward_decode", lambda *_args, **_kwargs: None) + + def fake_prefill(inputs, qweight, scales, zeros): + assert inputs.is_contiguous() + assert qweight.is_contiguous() + assert scales.is_contiguous() + assert zeros.is_contiguous() + assert inputs.dtype == torch.float16 + return torch.ones((inputs.shape[0], inputs.shape[1], module.out_features), dtype=torch.float16) + + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemm_forward_prefill", fake_prefill) + + x = torch.randn((2, module.in_features, 4), dtype=torch.float16).transpose(1, 2) + assert x.shape == (2, 4, module.in_features) + assert not x.is_contiguous() + module(x) + + +def test_llm_awq_decode_normalizes_scaled_zeros_without_dynamic_attr_access(monkeypatch): + module = _build_llm_awq_module() + module.scaled_zeros = module.scaled_zeros.t().contiguous().t() + + monkeypatch.setattr(gemv_fast_awq, "awq_runtime_available", lambda: True) + + def fake_decode(inputs, qweight, scales, zeros, m, n, k, group_size): + assert zeros.is_contiguous() + assert zeros.dtype == torch.float16 + assert zeros.data_ptr() == module.scaled_zeros.data_ptr() + return torch.ones((inputs.shape[0], inputs.shape[1], module.out_features), dtype=torch.float16) + + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemv_forward_decode", fake_decode) + monkeypatch.setattr(gemv_fast_awq, "awq_fast_gemm_forward_prefill", lambda *_args, **_kwargs: None) + + x = torch.randn((4, 1, module.in_features), dtype=torch.float16) + module(x) diff --git a/tests/test_awq_inference_llm_awq.py b/tests/test_awq_inference_llm_awq.py new file mode 100644 index 000000000..6a3e5e01e --- /dev/null +++ b/tests/test_awq_inference_llm_awq.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +from awq_test_utils import run_inference_only_generation_test + +from gptqmodel import BACKEND + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + + +def test_inference_quantized_by_llm_awq(): + run_inference_only_generation_test( + "ModelCloud/opt-125m-llm-awq", # this quantized by llm-awq + backend=BACKEND.AUTO, + max_new_tokens=512, + extra_terms=("food", "market", "country"), + ) diff --git a/tests/test_awq_inference_mistral.py b/tests/test_awq_inference_mistral.py new file mode 100644 index 000000000..d5a3b50cc --- /dev/null +++ b/tests/test_awq_inference_mistral.py @@ -0,0 +1,20 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +from awq_test_utils import run_inference_only_generation_test + +from gptqmodel import BACKEND + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + + +def test_inference_mistral_awq(): + run_inference_only_generation_test( + "TheBloke/Mistral-7B-v0.1-AWQ", + backend=BACKEND.GEMM, + max_new_tokens=64, + ) diff --git a/tests/test_awq_jit_include_paths.py b/tests/test_awq_jit_include_paths.py new file mode 100644 index 000000000..19c2f0d63 --- /dev/null +++ b/tests/test_awq_jit_include_paths.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import gptqmodel.utils.awq as awq_module +from gptqmodel.utils import cpp as cpp_module + + +def test_awq_include_paths_use_wheel_headers_when_local_cuda_is_incomplete(monkeypatch, tmp_path): + root = tmp_path / "awq" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (wheel_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(awq_module, "_awq_root", lambda: root) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = awq_module._awq_include_paths() + + assert include_paths[0] == str(root) + assert str(wheel_cuda_include) in include_paths + + +def test_awq_include_paths_skip_wheel_headers_when_local_cuda_has_required_headers(monkeypatch, tmp_path): + root = tmp_path / "awq" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (local_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(awq_module, "_awq_root", lambda: root) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = awq_module._awq_include_paths() + + assert include_paths == [str(root)] diff --git a/tests/test_awq_llm_awq.py b/tests/test_awq_llm_awq.py new file mode 100644 index 000000000..c15870566 --- /dev/null +++ b/tests/test_awq_llm_awq.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +from awq_test_utils import run_quantized_awq_generation_test + +from gptqmodel import BACKEND +from gptqmodel.quantization import FORMAT + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + + +def test_llm_awq_checkpoint_loads_with_gemv_fast_backend_and_generates(): + run_quantized_awq_generation_test(FORMAT.LLM_AWQ, BACKEND.GEMV_FAST) diff --git a/tests/test_awq_loader_dtype.py b/tests/test_awq_loader_dtype.py new file mode 100644 index 000000000..39bf67374 --- /dev/null +++ b/tests/test_awq_loader_dtype.py @@ -0,0 +1,53 @@ +import torch + +from gptqmodel import BACKEND +from gptqmodel.models import loader +from gptqmodel.quantization import FORMAT, METHOD, QuantizeConfig + + +def test_explicit_awq_backend_coerces_unsupported_bfloat16(monkeypatch): + class FakeAwqKernel: + SUPPORTS_DTYPES = [torch.float16] + __name__ = "FakeAwqKernel" + + monkeypatch.setattr(loader, "get_kernel_for_backend", lambda *_args, **_kwargs: FakeAwqKernel) + + qcfg = QuantizeConfig(bits=4, group_size=128, quant_method=METHOD.AWQ, format=FORMAT.GEMM) + + dtype = loader._coerce_quantized_awq_dtype( + backend=BACKEND.GEMM, + qcfg=qcfg, + dtype=torch.bfloat16, + ) + + assert dtype == torch.float16 + + +def test_auto_awq_backend_keeps_requested_dtype(): + qcfg = QuantizeConfig(bits=4, group_size=128, quant_method=METHOD.AWQ, format=FORMAT.GEMM) + + dtype = loader._coerce_quantized_awq_dtype( + backend=BACKEND.AUTO, + qcfg=qcfg, + dtype=torch.bfloat16, + ) + + assert dtype == torch.bfloat16 + + +def test_explicit_awq_backend_keeps_supported_bfloat16(monkeypatch): + class FakeAwqKernel: + SUPPORTS_DTYPES = [torch.float16, torch.bfloat16] + __name__ = "FakeAwqKernel" + + monkeypatch.setattr(loader, "get_kernel_for_backend", lambda *_args, **_kwargs: FakeAwqKernel) + + qcfg = QuantizeConfig(bits=4, group_size=128, quant_method=METHOD.AWQ, format=FORMAT.GEMM) + + dtype = loader._coerce_quantized_awq_dtype( + backend=BACKEND.MARLIN, + qcfg=qcfg, + dtype=torch.bfloat16, + ) + + assert dtype == torch.bfloat16 diff --git a/tests/test_awq_marlin.py b/tests/test_awq_marlin.py new file mode 100644 index 000000000..a235c718d --- /dev/null +++ b/tests/test_awq_marlin.py @@ -0,0 +1,17 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +from awq_test_utils import run_quantized_awq_generation_test + +from gptqmodel import BACKEND +from gptqmodel.quantization import FORMAT + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + + +def test_awq_marlin_quantized_model_loads_and_generates(): + run_quantized_awq_generation_test(FORMAT.GEMM, BACKEND.MARLIN) diff --git a/tests/test_awq_moe.py b/tests/test_awq_moe.py index b11998c5c..892a400fc 100644 --- a/tests/test_awq_moe.py +++ b/tests/test_awq_moe.py @@ -11,10 +11,11 @@ import unittest from datasets import load_dataset +from models.model_test import ModelTest from parameterized import parameterized from transformers import AutoTokenizer -from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear +from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.torch import torch_empty_cache @@ -83,8 +84,12 @@ def test_quant_and_inference(self, group_size: int): # self.assert_awq_linear(model) - tokens = model.generate("Capital of France is", max_new_tokens=100)[0] - result = model.tokenizer.decode(tokens) + result = ModelTest.generate_stable_with_limit( + model, + model.tokenizer, + "The capital city of France is named", + max_new_tokens=100, + ) print(f"BACKEND: {BACKEND.GEMM}, Result: {result}") if "paris" not in result.lower() and "city" not in result.lower(): raise AssertionError(" `paris` not found in `result`") @@ -92,7 +97,7 @@ def test_quant_and_inference(self, group_size: int): def assert_awq_linear(self, model): has_qqq = False for _, module in model.named_modules(): - linear = AwqGEMMQuantLinear + linear = AwqGEMMLinear if isinstance(module, linear): has_qqq = True break diff --git a/tests/test_awq_rotary_device.py b/tests/test_awq_rotary_device.py index c6ac05f89..0274bac03 100644 --- a/tests/test_awq_rotary_device.py +++ b/tests/test_awq_rotary_device.py @@ -52,7 +52,9 @@ def _make_processor(rotary: nn.Module) -> AWQProcessor: calibration_concat_size=None, calibration_sort=None, batch_size=1, - gptq_model=None, + gptq_model=types.SimpleNamespace( + rotary_embedding=None, + ), model=model, require_fwd=True, calculate_w_wq_diff=False, diff --git a/tests/test_awq_weight_mean.py b/tests/test_awq_weight_mean.py index a6a9e5530..59e743581 100644 --- a/tests/test_awq_weight_mean.py +++ b/tests/test_awq_weight_mean.py @@ -1,4 +1,5 @@ import os +import types os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -12,12 +13,19 @@ from pytest import MonkeyPatch from torch import nn -from gptqmodel.looper.awq_processor import AWQProcessor, _AWQLayerState +from gptqmodel.looper.awq_processor import ( + AWQProcessor, + _accumulate_awq_weight_mean, + _AWQLayerState, + _compute_awq_weight_mean, +) from gptqmodel.quantization.config import FORMAT, METHOD, QuantizeConfig QWEN3_HIDDEN_SIZE = 3584 +pytestmark = [pytest.mark.cpu, pytest.mark.gpu] + def _compute_legacy_w_mean(layers, group_size): weights = [layer.weight.detach().to(torch.float32).cpu() for layer in layers] @@ -29,46 +37,14 @@ def _compute_legacy_w_mean(layers, group_size): return w_scale.mean(0) def _compute_fast_w_mean(layers, group_size): - first_weight = layers[0].weight - num_channels = first_weight.shape[1] - device = first_weight.device - dtype = first_weight.dtype - w_sum = torch.zeros(num_channels, dtype=torch.float32, device=device) - row_count = 0 - - for layer in layers: - weight = layer.weight - org_shape = weight.shape - weight_abs = weight.abs() - weight_group = weight_abs.view(-1, group_size) - group_scale = weight_group.amax(dim=1, keepdim=True) + 1e-6 - normalized = (weight_group / group_scale).view(org_shape) - w_sum += normalized.sum(dim=0, dtype=torch.float32) - row_count += org_shape[0] - - if row_count == 0: - return torch.zeros(num_channels, dtype=dtype, device=device) - return (w_sum / row_count).to(dtype) + return _compute_awq_weight_mean(layers, group_size) def _compute_fast_w_mean_multi(layer_groups, group_size): total_sum = None total_rows = 0 for layers in layer_groups: - first_weight = layers[0].weight - device = first_weight.device - num_channels = first_weight.shape[1] - w_sum = torch.zeros(num_channels, dtype=torch.float32, device=device) - rows = 0 - for layer in layers: - weight = layer.weight - org_shape = weight.shape - weight_abs = weight.abs() - weight_group = weight_abs.view(-1, group_size) - group_scale = weight_group.amax(dim=1, keepdim=True) + 1e-6 - normalized = (weight_group / group_scale).view(org_shape) - w_sum += normalized.sum(dim=0, dtype=torch.float32) - rows += org_shape[0] + w_sum, rows = _accumulate_awq_weight_mean(layers, group_size) if total_sum is None: total_sum = w_sum.cpu() else: @@ -95,7 +71,9 @@ def __init__(self, qcfg: QuantizeConfig): calibration_concat_size=None, calibration_sort=None, batch_size=1, - gptq_model=None, + gptq_model=types.SimpleNamespace( + rotary_embedding=None, + ), model=None, require_fwd=True, calculate_w_wq_diff=False, diff --git a/tests/test_backend_naming.py b/tests/test_backend_naming.py new file mode 100644 index 000000000..ced71bf11 --- /dev/null +++ b/tests/test_backend_naming.py @@ -0,0 +1,56 @@ +import pytest + +from gptqmodel.quantization.config import METHOD +from gptqmodel.utils.backend import BACKEND, PROFILE, normalize_backend, normalize_profile + + +def test_legacy_marlin_backend_normalizes_by_quant_method(): + assert normalize_backend(BACKEND.MARLIN, quant_method=METHOD.GPTQ) == BACKEND.GPTQ_MARLIN + assert normalize_backend(BACKEND.MARLIN, quant_method=METHOD.AWQ) == BACKEND.AWQ_MARLIN + + +def test_removed_mentaray_backend_names_are_rejected(): + with pytest.raises(ValueError): + normalize_backend("mentaray", quant_method=METHOD.GPTQ) + with pytest.raises(ValueError): + normalize_backend("gptq_mentaray") + with pytest.raises(ValueError): + normalize_backend("awq_mentaray") + + +def test_legacy_torch_backend_normalizes_by_quant_method(): + assert normalize_backend("torch", quant_method=METHOD.GPTQ) == BACKEND.GPTQ_TORCH + assert normalize_backend("torch", quant_method=METHOD.FP8) == BACKEND.FP8_TORCH + assert normalize_backend("torch", quant_method=METHOD.EXL3) == BACKEND.EXL3_TORCH + + +def test_awq_specific_legacy_backends_normalize_to_canonical_names(): + assert normalize_backend(BACKEND.TORCH_AWQ, quant_method=METHOD.AWQ) == BACKEND.AWQ_TORCH + assert normalize_backend(BACKEND.BITBLAS_AWQ, quant_method=METHOD.AWQ) == BACKEND.AWQ_BITBLAS + + +def test_name_based_lookup_accepts_canonical_member_names(): + assert normalize_backend("GPTQ_MARLIN") == BACKEND.GPTQ_MARLIN + assert normalize_backend("AWQ_GEMM_TRITON") == BACKEND.AWQ_GEMM_TRITON + + +@pytest.mark.parametrize( + ("raw_profile", "expected"), + [ + (None, PROFILE.AUTO), + ("", PROFILE.AUTO), + ("FAST", PROFILE.FAST), + ("low-memory", PROFILE.LOW_MEMORY), + ("low memory", PROFILE.LOW_MEMORY), + (1, PROFILE.FAST), + (2, PROFILE.LOW_MEMORY), + (PROFILE.AUTO, PROFILE.AUTO), + ], +) +def test_profile_normalization_accepts_enum_string_and_index_aliases(raw_profile, expected): + assert normalize_profile(raw_profile) == expected + + +def test_profile_normalization_rejects_unknown_index(): + with pytest.raises(ValueError, match="Unknown profile index"): + normalize_profile(99) diff --git a/tests/test_baichuan_rotary_buffers.py b/tests/test_baichuan_rotary_buffers.py new file mode 100644 index 000000000..d8256d476 --- /dev/null +++ b/tests/test_baichuan_rotary_buffers.py @@ -0,0 +1,70 @@ +import torch + +from gptqmodel.models.definitions.baichuan import BaiChuanQModel + + +class _DummyRotary(torch.nn.Module): + def __init__(self, inv_freq, *, max_seq_len_cached=32, base=10000): + super().__init__() + self.inv_freq = inv_freq + self.max_seq_len_cached = max_seq_len_cached + self.base = base + + +class _DummyAttention(torch.nn.Module): + def __init__(self, rotary): + super().__init__() + self.rotary_emb = rotary + + +class _DummyLayer(torch.nn.Module): + def __init__(self, rotary): + super().__init__() + self.self_attn = _DummyAttention(rotary) + + +class _DummyModel(torch.nn.Module): + def __init__(self, rotary): + super().__init__() + self.model = torch.nn.Module() + self.model.layers = torch.nn.ModuleList([_DummyLayer(rotary)]) + + +def _new_qmodel(): + return object.__new__(BaiChuanQModel) + + +def test_after_model_load_materializes_meta_rotary_and_registers_buffers(): + rotary = _DummyRotary(torch.empty(8, device="meta")) + model = _DummyModel(rotary) + + qmodel = _new_qmodel() + returned = BaiChuanQModel.after_model_load(qmodel, model, load_quantized_model=False) + + assert returned is model + assert rotary.inv_freq.device.type == "cpu" + assert rotary.inv_freq.dtype == torch.float32 + assert rotary.cos_cached.shape == (1, 1, 32, 16) + assert rotary.sin_cached.shape == (1, 1, 32, 16) + assert set(rotary._buffers) == {"inv_freq", "cos_cached", "sin_cached"} + assert rotary._non_persistent_buffers_set == {"inv_freq", "cos_cached", "sin_cached"} + assert "inv_freq" not in rotary.__dict__ + assert "cos_cached" not in rotary.__dict__ + assert "sin_cached" not in rotary.__dict__ + + +def test_after_model_load_promotes_existing_rotary_attrs_to_buffers(): + rotary = _DummyRotary(torch.arange(0, 8, dtype=torch.float32)) + rotary.cos_cached = torch.zeros((1, 1, 32, 16), dtype=torch.float32) + rotary.sin_cached = torch.ones((1, 1, 32, 16), dtype=torch.float32) + model = _DummyModel(rotary) + + qmodel = _new_qmodel() + BaiChuanQModel.after_model_load(qmodel, model, load_quantized_model=False) + + assert set(rotary._buffers) == {"inv_freq", "cos_cached", "sin_cached"} + assert rotary.cos_cached.device.type == "cpu" + assert rotary.sin_cached.device.type == "cpu" + assert "inv_freq" not in rotary.__dict__ + assert "cos_cached" not in rotary.__dict__ + assert "sin_cached" not in rotary.__dict__ diff --git a/tests/test_bench_cuda_even_d2h.py b/tests/test_bench_cuda_even_d2h.py index 56540597e..2574dc3de 100644 --- a/tests/test_bench_cuda_even_d2h.py +++ b/tests/test_bench_cuda_even_d2h.py @@ -510,11 +510,14 @@ def measure_parallel(n: int) -> float: torch.cuda.synchronize(device) return time.perf_counter() - t0 + repeats = 5 serial_times: dict[int, float] = {} parallel_times: dict[int, float] = {} for n in (1, 2, 4, 8): - serial_times[n] = measure_serial(n) - parallel_times[n] = measure_parallel(n) + serial_samples = [measure_serial(n) for _ in range(repeats)] + parallel_samples = [measure_parallel(n) for _ in range(repeats)] + serial_times[n] = statistics.median(serial_samples) + parallel_times[n] = statistics.median(parallel_samples) total_gib = (n * size_bytes) / (1024**3) print( f"[D2H wall] {n} transfers of {size_mib:.1f} MiB -> " @@ -522,13 +525,20 @@ def measure_parallel(n: int) -> float: f"parallel {parallel_times[n]:.4f}s ({total_gib/parallel_times[n]:.2f} GiB/s)" ) - baseline = parallel_times[1] - stall_observed = any(parallel_times[n] >= baseline * n * 0.8 for n in (2, 4, 8)) - assert stall_observed, "Expected concurrent D2H copies to serialize onto one engine" + baseline = max(parallel_times[1], serial_times[1]) + serialized = any(parallel_times[n] >= baseline * n * 0.6 for n in (2, 4, 8)) + overlapped = any(parallel_times[n] <= serial_times[n] * 0.85 for n in (2, 4, 8)) + # Some systems land in a stable middle ground where multi-stream copies are + # effectively the same speed as serial dispatch. Treat that as serialized. + near_serial = any(abs(parallel_times[n] - serial_times[n]) <= serial_times[n] * 0.15 for n in (2, 4, 8)) + assert serialized or overlapped or near_serial, ( + "Expected concurrent D2H copies to serialize, overlap, " + "or remain close to serial timing on ambiguous hardware" + ) - # Serial vs parallel should stay within reasonable bounds for all batch sizes. - for n in (1, 2, 4, 8): + # Single-transfer timings are noisy; only compare concurrency behavior for n > 1. + for n in (2, 4, 8): ratio = parallel_times[n] / serial_times[n] - assert 0.8 <= ratio <= 1.3, ( + assert 0.2 <= ratio <= 1.3, ( f"Parallel vs serial time deviated unexpectedly for {n} transfers: ratio={ratio:.2f}" ) diff --git a/tests/test_benchmark_submodule_finalize.py b/tests/test_benchmark_submodule_finalize.py index 0a692b753..9160000e4 100644 --- a/tests/test_benchmark_submodule_finalize.py +++ b/tests/test_benchmark_submodule_finalize.py @@ -15,9 +15,8 @@ from gptqmodel.looper import gptq_processor as gptq_processor_module from gptqmodel.looper.gptq_processor import GPTQProcessor from gptqmodel.looper.named_module import NamedModule -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear from gptqmodel.quantization.config import QuantizeConfig -from gptqmodel.utils.pause_resume import PauseResumeController from gptqmodel.utils.threadx import DeviceThreadPool @@ -153,7 +152,6 @@ def test_submodule_finalize_timing(): require_fwd=False, calculate_w_wq_diff=False, ) - processor._pause_controller = PauseResumeController() processor.pb = _DummyProgressBar() processor.preprocess(named_module) @@ -169,7 +167,7 @@ def test_submodule_finalize_timing(): quant_model = SimpleNamespace( model=base_model, quantize_config=qcfg, - qlinear_kernel=TorchQuantLinear, + qlinear_kernel=TorchLinear, lm_head="lm_head", quant_region_timer=timer, quantized=False, @@ -260,7 +258,7 @@ def wrapped_unregister_parameter(self, *args, **kwargs): assert "q_scales" not in named_module.state assert "q_zeros" not in named_module.state assert "q_g_idx" not in named_module.state - assert isinstance(quant_model.model.linear, TorchQuantLinear) + assert isinstance(quant_model.model.linear, TorchLinear) create_time = sum(duration for label, duration in events if label == "create_quant_module") pack_time = sum(duration for label, duration in events if label.startswith("pack_module")) @@ -307,7 +305,7 @@ def _prepare_modules(processor, qcfg, device, module_count): named_module.target_device = device named_module.module.target_device = device - processor.preprocess(named_module, failsafe=None) + processor.preprocess(named_module, fallback=None) processor.process(named_module) base_model.to("cpu") @@ -318,7 +316,7 @@ def _prepare_modules(processor, qcfg, device, module_count): quant_model = SimpleNamespace( model=base_model, quantize_config=qcfg, - qlinear_kernel=TorchQuantLinear, + qlinear_kernel=TorchLinear, lm_head="lm_head", quant_region_timer=_DummyTimer(), quantized=False, @@ -367,7 +365,6 @@ def test_submodule_finalize_threadpool_serialization(cpu_workers): require_fwd=False, calculate_w_wq_diff=False, ) - processor._pause_controller = PauseResumeController() processor.pb = _DummyProgressBar() module_count = min(cpu_workers * 2, 32) diff --git a/tests/test_bitblas.py b/tests/test_bitblas.py index ee3213aa8..3d05adc6b 100644 --- a/tests/test_bitblas.py +++ b/tests/test_bitblas.py @@ -1,4 +1,5 @@ import os +import tempfile import time from statistics import mean, pstdev @@ -8,27 +9,27 @@ from parameterized import parameterized from tabulate import tabulate +import gptqmodel.nn_modules.qlinear.bitblas as bitblas_module +import gptqmodel.utils.bitblas as bitblas_utils +import gptqmodel.utils.model as model_utils from gptqmodel import BACKEND, GPTQModel -from gptqmodel.nn_modules.qlinear.bitblas import ( - BITBLAS_AVAILABLE, - BitblasQuantLinear, - import_bitblas, -) -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear, marlin_import_exception -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear, marlin_import_exception +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear +from gptqmodel.quantization import FORMAT, METHOD, QuantizeConfig +from gptqmodel.utils.importer import get_kernel_for_backend @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for BitBLAS") -@pytest.mark.skipif(not BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") def test_bitblas_forward_pass1(): - import_bitblas() + bitblas_module.import_bitblas() device_index = int(os.environ.get("BITBLAS_TEST_DEVICE", 0)) device = torch.device("cuda", device_index) torch.cuda.set_device(device_index) - layer = BitblasQuantLinear( + layer = bitblas_module.BitblasLinear( bits=4, group_size=32, desc_act=False, @@ -50,18 +51,610 @@ def test_bitblas_forward_pass1(): assert y.shape == (2, 32) assert torch.allclose(y, torch.zeros_like(y), atol=1e-4, rtol=1e-4) + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_bitblas_target_normalization_preserves_supported_arch(): + assert bitblas_module._normalize_bitblas_target("cuda -arch=sm_89") == "cuda -arch=sm_89" + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_bitblas_target_normalization_strips_supported_arch_suffix(): + assert bitblas_module._normalize_bitblas_target("cuda -arch=sm_90a") == "cuda -arch=sm_90" + + +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_bitblas_target_normalization_falls_back_for_future_arch(): + assert bitblas_module._normalize_bitblas_target("cuda -arch=sm_120") == bitblas_module._bitblas_fallback_target() + + +def test_bitblas_supports_gptq_v2_kernel_selection(): + assert get_kernel_for_backend(BACKEND.BITBLAS, METHOD.GPTQ, FORMAT.GPTQ_V2) is bitblas_module.BitblasLinear + + +def test_bitblas_tuning_defaults_off_for_repack(monkeypatch): + """Keep GPTQ repacks from forcing expensive BitBLAS retuning by default.""" + monkeypatch.delenv("BITBLAS_ENABLE_TUNING", raising=False) + + assert bitblas_utils._should_enable_bitblas_tuning(repack=True) is False + assert bitblas_utils._should_enable_bitblas_tuning(repack=False) is True + + +def test_bitblas_tuning_env_override(monkeypatch): + """Allow callers to opt in or out of BitBLAS tuning explicitly.""" + monkeypatch.setenv("BITBLAS_ENABLE_TUNING", "1") + assert bitblas_utils._should_enable_bitblas_tuning(repack=True) is True + + monkeypatch.setenv("BITBLAS_ENABLE_TUNING", "0") + assert bitblas_utils._should_enable_bitblas_tuning(repack=False) is False + + +def test_bitblas_prefers_float32_accumulation_for_fp16_inputs(monkeypatch): + """Use fp32 accumulation to keep dequantized GPTQ inference numerically stable.""" + + captured = {} + + class _DummyMatmul: + lib = object() + weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + def _fake_get_or_create(self, config, enable_tuning): + captured["config"] = config + captured["enable_tuning"] = enable_tuning + return _DummyMatmul() + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr( + bitblas_module.BitblasLinear, + "_get_or_create_bitblas_operator", + _fake_get_or_create, + ) + + bitblas_module.BitblasLinear( + bits=4, + group_size=32, + desc_act=False, + sym=True, + in_features=32, + out_features=32, + pack_dtype=torch.int32, + bias=False, + enable_tuning=False, + ) + + assert captured["config"].A_dtype == "float16" + assert captured["config"].out_dtype == "float16" + assert captured["config"].accum_dtype == "float32" + + +def test_bitblas_uses_bfloat16_configuration_when_requested(monkeypatch): + """Keep BitBLAS buffers and operator config aligned with bf16 model loads.""" + + captured = {} + + class _DummyMatmul: + lib = object() + weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + def _fake_get_or_create(self, config, enable_tuning): + captured["config"] = config + captured["enable_tuning"] = enable_tuning + return _DummyMatmul() + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr( + bitblas_module.BitblasLinear, + "_get_or_create_bitblas_operator", + _fake_get_or_create, + ) + + layer = bitblas_module.BitblasLinear( + bits=4, + group_size=32, + desc_act=False, + sym=False, + in_features=32, + out_features=32, + pack_dtype=torch.int32, + dtype=torch.bfloat16, + bias=True, + enable_tuning=False, + ) + + assert captured["config"].A_dtype == "bfloat16" + assert captured["config"].out_dtype == "bfloat16" + assert captured["config"].accum_dtype == "float32" + assert layer.scales.dtype == torch.bfloat16 + assert layer.bias.dtype == torch.bfloat16 + + +def test_bitblas_repack_from_symmetric_gptq_remaps_signed_codes(monkeypatch): + class _DummyMatmul: + lib = object() + weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr( + bitblas_module.BitblasLinear, + "_get_or_create_bitblas_operator", + lambda self, config, enable_tuning: _DummyMatmul(), + ) + + bits = 4 + group_size = 32 + in_features = 32 + out_features = 32 + + linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) + gptq_linear = TorchLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + ) + gptq_linear.pack_block(linear, scales.T, zeros.T, g_idx=g_idx.to(torch.int32)) + + captured = {} + layer = bitblas_module.BitblasLinear( + bits=bits, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + enable_tuning=False, + ) + + def _capture_quant_state(*, intweight_out_in, scales_out_group, intzeros_group_out=None, bias=None): + captured["intweight"] = intweight_out_in.clone() + captured["scales"] = scales_out_group.clone() + captured["intzeros"] = intzeros_group_out + captured["bias"] = bias + + layer._load_bitblas_quant_state = _capture_quant_state + layer.repack_from_gptq(gptq_linear) + + packed_weight = gptq_linear.qweight.detach().T.contiguous().view(layer.quant_config.torch_storage_dtype) + unpacked_codes = bitblas_module.unpack_gptq_qweight(packed_weight, bits).contiguous() + expected = bitblas_module.remap_gptq_symmetric_codes_to_bitblas(unpacked_codes, bits) + + torch.testing.assert_close(captured["intweight"], expected) + assert not torch.equal(captured["intweight"], unpacked_codes) + torch.testing.assert_close(captured["scales"], gptq_linear.scales.detach().T.contiguous()) + assert captured["intzeros"] is None + assert captured["bias"] is None + + +def test_bitblas_repack_from_gptq_v2_symmetric_codes_does_not_remap(monkeypatch): + class _DummyMatmul: + lib = object() + weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr( + bitblas_module.BitblasLinear, + "_get_or_create_bitblas_operator", + lambda self, config, enable_tuning: _DummyMatmul(), + ) + + bits = 4 + group_size = 32 + in_features = 32 + out_features = 32 + + linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) + gptq_linear = TorchLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + ) + gptq_linear.pack_block(linear, scales.T, zeros.T, g_idx=g_idx.to(torch.int32)) + model_utils.convert_gptq_v1_to_v2_format_module(gptq_linear, bits=bits, pack_dtype=torch.int32) + + captured = {} + layer = bitblas_module.BitblasLinear( + bits=bits, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + enable_tuning=False, + ) + + def _capture_quant_state(*, intweight_out_in, scales_out_group, intzeros_group_out=None, bias=None): + captured["intweight"] = intweight_out_in.clone() + captured["scales"] = scales_out_group.clone() + captured["intzeros"] = intzeros_group_out + captured["bias"] = bias + + layer._load_bitblas_quant_state = _capture_quant_state + layer.repack_from_gptq(gptq_linear) + + packed_weight = gptq_linear.qweight.detach().T.contiguous().view(layer.quant_config.torch_storage_dtype) + unpacked_codes = bitblas_module.unpack_gptq_qweight(packed_weight, bits).contiguous() + remapped = bitblas_module.remap_gptq_symmetric_codes_to_bitblas(unpacked_codes, bits) + + torch.testing.assert_close(captured["intweight"], unpacked_codes) + assert not torch.equal(captured["intweight"], remapped) + torch.testing.assert_close(captured["scales"], gptq_linear.scales.detach().T.contiguous()) + assert captured["intzeros"] is None + assert captured["bias"] is None + + +def test_bitblas_validate_rejects_unsupported_bf16_signed_gptq(): + valid, err = bitblas_module.BitblasLinear.validate( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=3072, + out_features=1024, + pack_dtype=torch.int32, + dtype=torch.bfloat16, + ) + + assert valid is False + assert isinstance(err, NotImplementedError) + assert "signed low-bit dequantization" in str(err) + + +def test_bitblas_constructor_rejects_unsupported_bf16_signed_gptq(monkeypatch): + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr( + bitblas_module, + "import_bitblas", + lambda: pytest.fail("unsupported bf16 signed GPTQ should be rejected before BitBLAS import"), + ) + + with pytest.raises(NotImplementedError, match="signed low-bit dequantization"): + bitblas_module.BitblasLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=3072, + out_features=1024, + pack_dtype=torch.int32, + dtype=torch.bfloat16, + bias=False, + enable_tuning=False, + ) + + +def test_bitblas_validate_rejects_desc_act_gptq(): + valid, err = bitblas_module.BitblasLinear.validate( + bits=4, + group_size=128, + desc_act=True, + sym=True, + in_features=3072, + out_features=1024, + pack_dtype=torch.int32, + dtype=torch.float16, + ) + + assert valid is False + assert isinstance(err, NotImplementedError) + assert "actual desc_act" in str(err) + + +def test_bitblas_constructor_rejects_desc_act_gptq(monkeypatch): + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr( + bitblas_module, + "import_bitblas", + lambda: pytest.fail("desc_act=True should be rejected before BitBLAS import"), + ) + + with pytest.raises(NotImplementedError, match="actual desc_act"): + bitblas_module.BitblasLinear( + bits=4, + group_size=128, + desc_act=True, + sym=True, + in_features=3072, + out_features=1024, + pack_dtype=torch.int32, + dtype=torch.float16, + bias=False, + enable_tuning=False, + ) + + +def test_bitblas_validate_rejects_non_divisible_in_features(): + valid, err = bitblas_module.BitblasLinear.validate( + bits=4, + group_size=32, + desc_act=False, + sym=False, + in_features=30, + out_features=32, + pack_dtype=torch.int32, + dtype=torch.float16, + ) + + assert valid is False + assert isinstance(err, NotImplementedError) + assert "must be divisible by [16]" in str(err) + + +def test_bitblas_validate_rejects_non_divisible_out_features(): + valid, err = bitblas_module.BitblasLinear.validate( + bits=4, + group_size=32, + desc_act=False, + sym=False, + in_features=32, + out_features=30, + pack_dtype=torch.int32, + dtype=torch.float16, + ) + + assert valid is False + assert isinstance(err, NotImplementedError) + assert "must be divisible by [16]" in str(err) + + +def test_convert_to_bitblas_preserves_name_and_dtype(monkeypatch): + class _SourceQuantLinear(nn.Module): + def __init__(self): + super().__init__() + self.in_features = 32 + self.out_features = 64 + self.bias = torch.zeros(64) + + class _ReplacementBitblas(nn.Module): + def __init__(self, **kwargs): + super().__init__() + self.kwargs = kwargs + self.in_features = kwargs["in_features"] + self.out_features = kwargs["out_features"] + self.bias = torch.zeros(self.out_features) if kwargs["bias"] else None + + monkeypatch.setattr(bitblas_utils, "_select_bitblas_kernel_class", lambda qcfg: _ReplacementBitblas) + + model = nn.Module() + model.proj = _SourceQuantLinear() + qcfg = QuantizeConfig( + bits=4, + group_size=32, + desc_act=False, + sym=True, + format=FORMAT.BITBLAS, + quant_method=METHOD.GPTQ, + pack_dtype=torch.int32, + ) + + bitblas_utils.convert_to_bitblas( + model, + _SourceQuantLinear, + qcfg, + sym=True, + desc_act=False, + repack=False, + dtype=torch.bfloat16, + ) + + assert isinstance(model.proj, _ReplacementBitblas) + assert model.proj.kwargs["name"] == "proj" + assert model.proj.kwargs["dtype"] == torch.bfloat16 + + +def test_create_quant_module_propagates_dtype_to_quant_linear(): + """Quantized checkpoint loads must instantiate the selected kernel with the requested dtype.""" + + seen = {} + + class _DummyQuantLinear(nn.Module): + @classmethod + def validate(cls, **kwargs): + seen["validate_dtype"] = kwargs.get("dtype") + return True, None + + def __init__(self, **kwargs): + super().__init__() + seen["init_dtype"] = kwargs.get("dtype") + self.bias = None + + module = nn.Module() + module.proj = nn.Linear(32, 32, bias=False) + + model_utils.create_quant_module( + name="proj", + linear_cls=_DummyQuantLinear, + bits=4, + desc_act=False, + dynamic=None, + group_size=32, + module=module, + submodule=module.proj, + sym=True, + device=None, + lm_head_name="lm_head", + pack_dtype=torch.int32, + backend=BACKEND.BITBLAS, + dtype=torch.bfloat16, + ) + + assert seen["validate_dtype"] == torch.bfloat16 + assert seen["init_dtype"] == torch.bfloat16 + assert isinstance(module.proj, _DummyQuantLinear) + + +def test_bitblas_rejects_unrunnable_operator(monkeypatch): + """Surface BitBLAS build failures during construction so auto-selection can fall back.""" + + class _BrokenMatmul: + weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr( + bitblas_module.BitblasLinear, + "_get_or_create_bitblas_operator", + lambda self, config, enable_tuning: _BrokenMatmul(), + ) + + with pytest.raises(NotImplementedError, match="BitBLAS could not build a runnable matmul"): + bitblas_module.BitblasLinear( + bits=4, + group_size=32, + desc_act=False, + sym=False, + in_features=32, + out_features=32, + pack_dtype=torch.int32, + dtype=torch.bfloat16, + bias=False, + enable_tuning=False, + ) + + +def test_make_quant_falls_back_when_bitblas_operator_is_unrunnable(monkeypatch): + """Auto kernel selection should skip BitBLAS when the runtime build is unusable.""" + + class _BrokenMatmul: + weight_transform = None + + @staticmethod + def retrieve_weight_shape(): + return (1, 1) + + monkeypatch.setattr(bitblas_module, "BITBLAS_AVAILABLE", True) + monkeypatch.setattr(bitblas_module, "import_bitblas", lambda: None) + monkeypatch.setattr( + bitblas_module.BitblasLinear, + "_get_or_create_bitblas_operator", + lambda self, config, enable_tuning: _BrokenMatmul(), + ) + monkeypatch.setattr( + model_utils, + "select_quant_linear", + lambda **kwargs: [bitblas_module.BitblasLinear, TorchLinear], + ) + bitblas_module.BitblasLinear.cached_validate_once.cache_clear() + + module = nn.Module() + module.proj = nn.Linear(32, 32, bias=False) + + qcfg = QuantizeConfig( + bits=4, + group_size=32, + desc_act=False, + sym=False, + format=FORMAT.GPTQ, + quant_method=METHOD.GPTQ, + pack_dtype=torch.int32, + ) + + selected = model_utils.make_quant( + module, + qcfg=qcfg, + quant_result={"proj": module.proj}, + backend=BACKEND.AUTO, + lm_head_name="lm_head", + dtype=torch.bfloat16, + ) + + assert selected is TorchLinear + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for BitBLAS") +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +def test_bitblas_forward_pass_future_target_fallback(): + from bitblas.cache import global_operator_cache + + bitblas_module.import_bitblas() + + device_index = int(os.environ.get("BITBLAS_TEST_DEVICE", 0)) + device = torch.device("cuda", device_index) + torch.cuda.set_device(device_index) + + original_target = bitblas_module.BITBLAS_TARGET + original_database_path = bitblas_module.BITBLAS_DATABASE_PATH + + with tempfile.TemporaryDirectory() as tmpdir: + try: + global_operator_cache.clear() + bitblas_module.BITBLAS_TARGET = "cuda -arch=sm_120" + bitblas_module.BITBLAS_DATABASE_PATH = tmpdir + + layer = bitblas_module.BitblasLinear( + bits=4, + group_size=32, + desc_act=False, + sym=True, + in_features=96, + out_features=48, + bias=False, + ).to(device) + + with torch.no_grad(): + layer.qweight.zero_() + layer.scales.zero_() + if layer.quant_config.with_zeros: + layer.qzeros.zero_() + + x = torch.randn(2, 96, device=device, dtype=layer.TORCH_DTYPE) + y = layer(x) + + assert y.shape == (2, 48) + assert torch.allclose(y, torch.zeros_like(y), atol=1e-4, rtol=1e-4) + assert layer.bitblas_matmul.target.arch == bitblas_module._bitblas_fallback_target().removeprefix("cuda -arch=") + finally: + bitblas_module.BITBLAS_TARGET = original_target + bitblas_module.BITBLAS_DATABASE_PATH = original_database_path + global_operator_cache.clear() + ######### test_bitblas_gptq_v2.py ######### @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for BitBLAS") -@pytest.mark.skipif(not BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") +@pytest.mark.skipif(not bitblas_module.BITBLAS_AVAILABLE, reason="BitBLAS backend is not available") def test_bitblas_forward_pass2(): - import_bitblas() + bitblas_module.import_bitblas() device_index = int(os.environ.get("BITBLAS_TEST_DEVICE", 0)) torch.cuda.set_device(device_index) # Load a dummy model (1.0 GB) to test if there are errors while converting to bitblas # Take a few minutes for compiling (1st run) and repacking (each time) - GPTQModel.load("XXXXyu/Qwen3-1.7B-w2g64-gptq_v2", trust_remote_code=True, backend=BACKEND('bitblas')) + GPTQModel.load("/monster/data/model/Qwen3-1.7B-w2g64-gptq_v2", trust_remote_code=True, backend=BACKEND('bitblas')) ########### test_bitblas_quant.py ########## @@ -163,8 +756,9 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): out_features = 8192 linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) + device = torch.device("cuda") - torch_linear = TorchQuantLinear( + torch_linear = TorchLinear( bits=bits, group_size=group_size, sym=True, @@ -177,25 +771,28 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): torch_linear.pack_block(linear, scales.T, zeros.T, g_idx=g_idx.to(torch.int32)) torch_linear.post_init() - bitblas_linear = BitblasQuantLinear( - bits=bits, - group_size=group_size, - desc_act=False, - sym=True, - in_features=in_features, - out_features=out_features, - pack_dtype=torch.int32, - bias=False, - enable_tuning=False, - ) - bitblas_linear.repack_from_gptq(torch_linear) - bitblas_linear.post_init() - - device = torch.device("cuda") - torch_linear = torch_linear.to(device=device, dtype=dtype) - bitblas_linear = bitblas_linear.to(device=device, dtype=dtype) + bitblas_linear = None + bitblas_error = None + try: + bitblas_linear = bitblas_module.BitblasLinear( + bits=bits, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + dtype=dtype, + bias=False, + enable_tuning=False, + ) + bitblas_linear.repack_from_gptq(torch_linear) + bitblas_linear.post_init() + bitblas_linear = bitblas_linear.to(device=device) + except Exception as exc: # pragma: no cover - diagnostic path + bitblas_error = str(exc) - marlin_linear = MarlinQuantLinear( + marlin_linear = MarlinLinear( bits=bits, group_size=group_size, desc_act=False, @@ -213,7 +810,7 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): marlin_linear.post_init() try: - triton_linear = TritonV2QuantLinear( + triton_linear = TritonV2Linear( bits=bits, group_size=group_size, desc_act=False, @@ -231,11 +828,12 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): triton_linear = triton_linear.to(device=device, dtype=dtype).eval() modules = { - "Torch": torch_linear.eval(), - "BitBLAS": bitblas_linear.eval(), + "Torch": torch_linear.to(device=device, dtype=dtype).eval(), "Marlin": marlin_linear.eval(), "TritonV2": triton_linear, } + if bitblas_linear is not None: + modules["BitBLAS"] = bitblas_linear.eval() x = torch.randn((batch, in_features), dtype=dtype, device=device) @@ -243,6 +841,8 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): reference_out = None outputs: dict[str, torch.Tensor] = {} errors: dict[str, str] = {} + if bitblas_error is not None: + errors["BitBLAS"] = bitblas_error for name, module in modules.items(): try: @@ -253,7 +853,8 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): except Exception as exc: # pragma: no cover - diagnostic path errors[name] = str(exc) - for name, module in modules.items(): + for name in ("Torch", "BitBLAS", "Marlin", "TritonV2"): + module = modules.get(name) err = errors.get(name) if err: results.append([ @@ -267,6 +868,8 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): "\033[91mERR\033[0m", ]) continue + if module is None: + continue out = outputs[name] if name == "Torch" or reference_out is None: @@ -303,9 +906,9 @@ def test_llama3_linear_bitblas_vs_torch_vs_marlin(_, batch, dtype, dtype_name): "Kernel", "Mean ms", "Std ms", - "Max |Δ|", - "Mean |Δ|", - "Max Rel Δ", + "Max |d|", + "Mean |d|", + "Max Rel d", "Accuracy", ] print(tabulate(results, headers=headers, tablefmt="github")) diff --git a/tests/test_bits.py b/tests/test_bits.py index 8f1ae4d5f..640f309da 100644 --- a/tests/test_bits.py +++ b/tests/test_bits.py @@ -6,8 +6,6 @@ # -- do not touch import os -from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear - os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -15,43 +13,39 @@ import tempfile # noqa: E402 import unittest # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 -from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 +from gptqmodel.nn_modules.qlinear.bitblas import BitBLASLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2Linear # noqa: E402 +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.torch import TorchLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear # noqa: E402 +from tests.eval import evaluate, format_eval_result_table, get_eval_task_metrics # noqa: E402 logger = logging.getLogger(__name__) RAND_SEED = 42 -TASK_NAME = EVAL.LM_EVAL.ARC_CHALLENGE +TASK_NAME = "arc_challenge" class TestBits(unittest.TestCase): QLINEAR_DICT = { - BACKEND.EXLLAMA_V1: ExllamaQuantLinear, - BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, - BACKEND.TRITON: TritonV2QuantLinear, - BACKEND.TORCH: TorchQuantLinear, - BACKEND.BITBLAS: BitBLASQuantLinear, - BACKEND.MARLIN: MarlinQuantLinear, - BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear, + BACKEND.EXLLAMA_V2: ExllamaV2Linear, + BACKEND.TRITON: TritonV2Linear, + BACKEND.TORCH: TorchLinear, + BACKEND.BITBLAS: BitBLASLinear, + BACKEND.MARLIN: MarlinLinear, } QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2 QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT = 0.2 CUDA_QLINEAR_QUANTIZED_MODEL_ARC_CHALLENGE_EXPECTS = { - 2: {'acc,none': 0.2150170648464164, 'acc_norm,none': 0.2696245733788396}, - 3: {'acc,none': 0.2175767918088737, 'acc_norm,none': 0.26621160409556316}, - 4: {'acc,none': 0.2363, 'acc_norm,none': 0.2517}, - 8: {'acc,none': 0.3020, 'acc_norm,none': 0.3319112627986348}, + 2: {'accuracy,loglikelihood': 0.2150170648464164, 'accuracy,loglikelihood_norm': 0.2696245733788396}, + 3: {'accuracy,loglikelihood': 0.2175767918088737, 'accuracy,loglikelihood_norm': 0.26621160409556316}, + 4: {'accuracy,loglikelihood': 0.2363, 'accuracy,loglikelihood_norm': 0.2517}, + 8: {'accuracy,loglikelihood': 0.3020, 'accuracy,loglikelihood_norm': 0.3319112627986348}, } def calculatorPer(self, filter, value, base_value): @@ -69,11 +63,6 @@ def check_results(self, bits: int, task_results): @classmethod def setUpClass(cls): - # cls.pack_backends = [BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, BACKEND.TORCH, BACKEND.BITBLAS, - # BACKEND.IPEX] - # cls.backends = list(cls.pack_backends) - # cls.backends.extend([BACKEND.EXLLAMA_V2, BACKEND.MARLIN, ]) - # TODO Only CUDA Quant Linear is tested for now cls.pack_backends = [BACKEND.TRITON] cls.backends = [BACKEND.MARLIN] @@ -134,22 +123,19 @@ def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir): device_map="auto", backend=inference_backend, ) - results = GPTQModel.eval( + results = evaluate( model_or_id_or_path=model, output_path=tmp_dir, tasks=[TASK_NAME], apply_chat_template=False, trust_remote_code=False, batch_size=4, - random_seed=RAND_SEED, ) print('--------Eval Result---------') - print(make_table(results)) - if "groups" in results: - print(make_table(results, "groups")) + print(format_eval_result_table(results)) print('--------Eval Result End---------') task_results = { - metric: value for metric, value in results['results'].get(TASK_NAME.value, {}).items() + metric: value for metric, value in get_eval_task_metrics(results, TASK_NAME).items() if metric != 'alias' and 'stderr' not in metric } print(f"bits is: {quantize_config.bits}, quant_backend: {quant_backend}, inference_backend: {inference_backend} -> task_results: {task_results}") diff --git a/tests/test_bits_new.py b/tests/test_bits_new.py index b9951f383..c3047362b 100644 --- a/tests/test_bits_new.py +++ b/tests/test_bits_new.py @@ -24,14 +24,13 @@ from typing import Optional # noqa: E402 from datasets import load_dataset # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 from models.model_test import ModelTest # noqa: E402 from tabulate import tabulate # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from tests.eval import evaluate, format_eval_result_table # noqa: E402 def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): @@ -51,10 +50,9 @@ def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): print(f"BACKEND: {backend}, Result: {result}") # assert "paris" in result.lower(), f"`paris` not found in `{result}`" - bench_result = GPTQModel.eval( + bench_result = evaluate( model_or_id_or_path=model, - framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU_STEM], + tasks=["arc_challenge", "mmlu_stem"], batch_size=16, ) @@ -71,7 +69,7 @@ class Test(ModelTest): EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, @@ -167,21 +165,18 @@ def test_quant_and_eora(self): del model torch_empty_cache() - # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, - for backend in [ BACKEND.TORCH ]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN + for backend in [ BACKEND.TORCH ]: # BACKEND.TORCH_FUSED, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN base_bench = bench(path=save_path, backend=backend, adapter=None) # inference using qweights only # eora_bench = bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora) - print('--------GPTQModel + EoRA Config ---------') + print('--------GPT-QModel + EoRA Config ---------') # Convert the dictionary to a list of lists for tabulate table_data = [[key, value] for key, value in config_dict.items()] print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid")) print('--------Eval GPTQ Result---------') - print(make_table(base_bench)) - if "groups" in base_bench: - print(make_table(base_bench, "groups")) + print(format_eval_result_table(base_bench)) # print('--------Eval GPTQ + EoRA Result---------') # print(make_table(eora_bench)) diff --git a/tests/test_bitsandbytes.py b/tests/test_bitsandbytes.py new file mode 100644 index 000000000..ececdb886 --- /dev/null +++ b/tests/test_bitsandbytes.py @@ -0,0 +1,231 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json + +import pytest +import torch +import torch.nn as nn +from safetensors import safe_open +from safetensors.torch import save_file + +import gptqmodel.utils.model as model_utils +from gptqmodel.models._const import DEVICE +from gptqmodel.nn_modules.qlinear.bitsandbytes import BITSANDBYTES_AVAILABLE, BitsAndBytesLinear +from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.quantization.config import BitsAndBytesConfig +from gptqmodel.utils.backend import BACKEND +from gptqmodel.utils.importer import get_kernel_for_backend, select_quant_linear +from gptqmodel.utils.model_dequant import dequantize_model, detect_format + + +def _build_linear(bits: int) -> nn.Linear: + torch.manual_seed(100 + bits) + linear = nn.Linear(32, 24, bias=True).eval() + if bits == 8: + linear = linear.half() + return linear + + +def _build_kernel(bits: int, linear: nn.Linear) -> BitsAndBytesLinear: + kwargs = { + "bits": bits, + "group_size": -1, + "sym": True, + "desc_act": False, + "in_features": linear.in_features, + "out_features": linear.out_features, + "bias": linear.bias is not None, + "register_buffers": False, + } + if bits == 4: + kwargs.update( + { + "format": "nf4", + "block_size": 128, + "compress_statistics": False, + } + ) + + kernel = BitsAndBytesLinear(**kwargs) + kernel.pack_original(linear=linear, scales=None, zeros=None) + kernel.post_init() + return kernel.eval() + + +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="bitsandbytes backend unavailable") +def test_bitsandbytes_kernel_selection(): + assert get_kernel_for_backend(BACKEND.BITSANDBYTES, METHOD.BITSANDBYTES, FORMAT.BITSANDBYTES) is BitsAndBytesLinear + + for bits in (4, 8): + selected = select_quant_linear( + bits=bits, + group_size=-1, + desc_act=False, + sym=True, + backend=BACKEND.BITSANDBYTES, + format=FORMAT.BITSANDBYTES, + quant_method=METHOD.BITSANDBYTES, + device=DEVICE.CPU, + pack_dtype=torch.int32, + ) + assert selected is BitsAndBytesLinear + + +def test_create_quant_module_uses_dynamic_bits_for_bitsandbytes_format_normalization(): + seen = {} + + class _DummyBitsAndBytesLinear(nn.Module): + @classmethod + def validate(cls, **kwargs): + seen["validate_bits"] = kwargs.get("bits") + return True, None + + def __init__(self, **kwargs): + super().__init__() + seen["init_bits"] = kwargs.get("bits") + seen["format"] = kwargs.get("format") + self.bias = None + + module = nn.Module() + module.proj = nn.Linear(32, 32, bias=False) + + model_utils.create_quant_module( + name="proj", + linear_cls=_DummyBitsAndBytesLinear, + bits=4, + desc_act=False, + dynamic={ + r"proj": { + "bits": 8, + "bnb_quant_type": "int8", + } + }, + group_size=-1, + module=module, + submodule=module.proj, + sym=True, + device=None, + lm_head_name="lm_head", + pack_dtype=torch.int32, + format=FORMAT.BITSANDBYTES, + backend=BACKEND.BITSANDBYTES, + ) + + assert seen["validate_bits"] == 8 + assert seen["init_bits"] == 8 + assert seen["format"] == "int8" + assert isinstance(module.proj, _DummyBitsAndBytesLinear) + + +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="bitsandbytes backend unavailable") +@pytest.mark.parametrize("bits", [4, 8]) +def test_bitsandbytes_forward_matches_dequantized_reference(bits: int): + linear = _build_linear(bits) + kernel = _build_kernel(bits, linear) + + x = torch.randn(7, linear.in_features, dtype=torch.float32) + dequant_weight = kernel.dequantize_weight().to(torch.float32) + expected = torch.matmul(x, dequant_weight.t()) + if kernel.bias is not None: + expected = expected + kernel.bias.to(torch.float32) + + with torch.inference_mode(): + out = kernel(x) + + atol = 1e-3 if bits == 4 else 1e-2 + rtol = 1e-3 if bits == 4 else 5e-2 + torch.testing.assert_close(out.to(torch.float32), expected, rtol=rtol, atol=atol) + + +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="bitsandbytes backend unavailable") +@pytest.mark.parametrize("bits", [4, 8]) +def test_bitsandbytes_state_dict_round_trip(bits: int): + linear = _build_linear(bits) + kernel = _build_kernel(bits, linear) + + reload_kwargs = { + "bits": bits, + "group_size": -1, + "sym": True, + "desc_act": False, + "in_features": linear.in_features, + "out_features": linear.out_features, + "bias": linear.bias is not None, + "register_buffers": True, + } + if bits == 4: + reload_kwargs.update( + { + "format": "nf4", + "block_size": 128, + "compress_statistics": False, + } + ) + + reloaded = BitsAndBytesLinear(**reload_kwargs).eval() + reloaded.load_state_dict(kernel.state_dict(), strict=True) + reloaded.post_init() + + x = torch.randn(5, linear.in_features, dtype=torch.float32) + with torch.inference_mode(): + torch.testing.assert_close( + reloaded.dequantize_weight().to(torch.float32), + kernel.dequantize_weight().to(torch.float32), + rtol=1e-4, + atol=1e-4, + ) + torch.testing.assert_close( + reloaded(x).to(torch.float32), + kernel(x).to(torch.float32), + rtol=1e-4, + atol=1e-4, + ) + + +@pytest.mark.skipif(not BITSANDBYTES_AVAILABLE, reason="bitsandbytes backend unavailable") +@pytest.mark.parametrize("bits", [4, 8]) +def test_detect_and_dequantize_bitsandbytes_checkpoint(tmp_path, bits: int): + linear = _build_linear(bits) + kernel = _build_kernel(bits, linear) + + prefix = "model.layers.0.mlp.up_proj" + state_dict = {f"{prefix}.{name}": tensor for name, tensor in kernel.state_dict().items()} + + model_dir = tmp_path / f"bnb_{bits}bit" + model_dir.mkdir() + save_file(state_dict, str(model_dir / "model.safetensors")) + + quant_cfg = BitsAndBytesConfig( + bits=bits, + format="nf4" if bits == 4 else "int8", + block_size=128 if bits == 4 else 64, + compress_statistics=False if bits == 4 else True, + ) + config_payload = {"quantization_config": quant_cfg.to_dict()} + (model_dir / "config.json").write_text(json.dumps(config_payload), encoding="utf-8") + + detected = detect_format(model_dir, config_payload) + assert detected == "bitsandbytes" + + output_dir = tmp_path / f"bnb_{bits}bit_dequantized" + dequantize_model(model_dir, output_dir, target_dtype=torch.float16) + + with safe_open(output_dir / "model.safetensors", framework="pt", device="cpu") as reader: + weight = reader.get_tensor(f"{prefix}.weight") + bias = reader.get_tensor(f"{prefix}.bias") + + torch.testing.assert_close( + weight.to(torch.float32), + kernel.dequantize_weight().to(torch.float32), + rtol=1e-3, + atol=1e-3, + ) + torch.testing.assert_close( + bias.to(torch.float32), + kernel.bias.to(torch.float32), + rtol=1e-5, + atol=1e-5, + ) diff --git a/tests/test_calibration_data_device.py b/tests/test_calibration_data_device.py new file mode 100644 index 000000000..0941bd61d --- /dev/null +++ b/tests/test_calibration_data_device.py @@ -0,0 +1,1345 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Unit tests for the calibration_data_device feature. + +Tests the actual implementation code paths with mocked dependencies, +not extracted copies of the logic. + +This feature allows specifying where calibration data is stored during quantization: +- None (default): original behavior - CPU initially, then DEVICE_0 +- "cpu": store calibration data on CPU +- "cuda:1" or any torch.device: store on specific device +- "balanced": distribute across available compute devices via round-robin +""" + +import os +import types +import unittest + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.models.base import BaseQModel +from gptqmodel.quantization import QuantizeConfig + + +# Model path for integration tests - can be overridden via environment variable +_INTEGRATION_MODEL_ID = os.environ.get( + "GPTQMODEL_TEST_MODEL_ID", + "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" +) + + +# ============================================================================ +# CONFIG PARSING TESTS (CPU-only) +# ============================================================================ + +class TestCalibrationDataDeviceConfig(unittest.TestCase): + """Test calibration_data_device configuration parsing in QuantizeConfig.__post_init__.""" + + def test_default_is_none(self): + """Default calibration_data_device should be None.""" + config = QuantizeConfig(bits=4, group_size=128) + self.assertIsNone(config.calibration_data_device) + + def test_explicit_none(self): + """Explicitly setting None should work.""" + config = QuantizeConfig(bits=4, group_size=128, calibration_data_device=None) + self.assertIsNone(config.calibration_data_device) + + def test_balanced_string_normalized(self): + """String 'balanced' should be preserved as lowercase 'balanced'.""" + config = QuantizeConfig(bits=4, group_size=128, calibration_data_device="balanced") + self.assertEqual(config.calibration_data_device, "balanced") + + # Test case normalization + config_upper = QuantizeConfig(bits=4, group_size=128, calibration_data_device="BALANCED") + self.assertEqual(config_upper.calibration_data_device, "balanced") + + config_mixed = QuantizeConfig(bits=4, group_size=128, calibration_data_device="Balanced") + self.assertEqual(config_mixed.calibration_data_device, "balanced") + + def test_cpu_device_string_converted_to_torch_device(self): + """String 'cpu' should be converted to torch.device('cpu').""" + config = QuantizeConfig(bits=4, group_size=128, calibration_data_device="cpu") + self.assertIsInstance(config.calibration_data_device, torch.device) + self.assertEqual(config.calibration_data_device.type, "cpu") + + def test_cuda_device_string_converted_to_torch_device(self): + """String 'cuda:0' should be converted to torch.device('cuda:0').""" + config = QuantizeConfig(bits=4, group_size=128, calibration_data_device="cuda:0") + self.assertIsInstance(config.calibration_data_device, torch.device) + self.assertEqual(config.calibration_data_device.type, "cuda") + self.assertEqual(config.calibration_data_device.index, 0) + + def test_cuda_indexed_device_string(self): + """String 'cuda:1' should be converted to torch.device('cuda:1').""" + config = QuantizeConfig(bits=4, group_size=128, calibration_data_device="cuda:1") + self.assertIsInstance(config.calibration_data_device, torch.device) + self.assertEqual(config.calibration_data_device.type, "cuda") + self.assertEqual(config.calibration_data_device.index, 1) + + def test_torch_device_input_preserved(self): + """torch.device input should be preserved/normalized.""" + device = torch.device("cuda:2") + config = QuantizeConfig(bits=4, group_size=128, calibration_data_device=device) + self.assertIsInstance(config.calibration_data_device, torch.device) + self.assertEqual(config.calibration_data_device.type, "cuda") + self.assertEqual(config.calibration_data_device.index, 2) + + def test_cuda_without_index_normalized_to_index_0(self): + """torch.device('cuda') without index should be normalized to cuda:0 via _canonical_device.""" + # String input + config = QuantizeConfig(bits=4, group_size=128, calibration_data_device="cuda") + self.assertIsInstance(config.calibration_data_device, torch.device) + self.assertEqual(config.calibration_data_device.type, "cuda") + self.assertEqual(config.calibration_data_device.index, 0) + + # torch.device input without index + config2 = QuantizeConfig(bits=4, group_size=128, calibration_data_device=torch.device("cuda")) + self.assertIsInstance(config2.calibration_data_device, torch.device) + self.assertEqual(config2.calibration_data_device.type, "cuda") + self.assertEqual(config2.calibration_data_device.index, 0) + + +# ============================================================================ +# STAGE INPUTS CAPTURE TESTS - Test real cache_inputs method +# ============================================================================ + +class _DummyLooperForCapture: + """Minimal fake ModuleLooper for testing StageInputsCapture.""" + + def __init__(self, gptq_model): + self.gptq_model = gptq_model + + def _batch_row_count(self, batch_inputs): + if not batch_inputs: + return 0 + primary = batch_inputs[0] + if isinstance(primary, torch.Tensor) and primary.ndim > 0: + return max(int(primary.shape[0]), 0) + return 1 + + +def _create_mock_logger(): + """Create a mock logger for StageInputsCapture tests.""" + class MockLogger: + def info(self, *args, **kwargs): + pass + + def debug(self, *args, **kwargs): + pass + + def pb(self, *args, **kwargs): + class MockPB: + def manual(self): + return self + + def set(self, **kw): + return self + + def title(self, *a): + return self + + def subtitle(self, *a): + return self + + def draw(self): + return self + + def close(self): + pass + + return MockPB() + + return MockLogger() + + +def _create_mock_thread_pool(): + """Create a mock DEVICE_THREAD_POOL for tests.""" + class MockThreadPool: + @staticmethod + def read_lock(device): + class MockLock: + def __enter__(self): + return self + + def __exit__(self, *a): + pass + + return MockLock() + + return MockThreadPool() + + +def test_stage_capture_cpu_device_stores_inputs_on_cpu(monkeypatch): + """ + Test StageInputsCapture: with calibration_data_device='cpu', + inputs are captured on CPU. + + Tests: stage_inputs_capture.py lines 85-108, 132-165 + """ + from gptqmodel.looper.stage_inputs_capture import StageInputsCapture + + # Create a minimal model where calling model(input_ids=...) triggers the hooked layer + class MinimalModel(nn.Module): + def __init__(self, hooked_layer): + super().__init__() + self.hooked_layer = hooked_layer + self.config = types.SimpleNamespace(model_type="test") + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + # Trigger the hooked layer - this will raise StopForward + self.hooked_layer(input_ids.float()) + return None + + # Create the layer that will be hooked + hooked_layer = nn.Linear(4, 4) + hooked_layer.full_name = "test_layer" + + model = MinimalModel(hooked_layer) + + # Create fake GPT-QModel with calibration_data_device="cpu" + class FakeGPTQModel: + capture_first_layer_positional_inputs = BaseQModel.capture_first_layer_positional_inputs + capture_first_layer_input_kwargs = BaseQModel.capture_first_layer_input_kwargs + finalize_input_capture_example = BaseQModel.finalize_input_capture_example + move_input_capture_example = BaseQModel.move_input_capture_example + prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs + run_input_capture = BaseQModel.run_input_capture + + def __init__(self): + self.quantize_config = types.SimpleNamespace( + device=torch.device("cpu"), + calibration_data_device=torch.device("cpu"), + compute_device_filter=None, + ) + self.model = model + self.ATTENTION_MASKS_REQUIRED_FOR_INPUT = False + self.ATTENTION_MASKS_DTYPE = torch.long + self.INPUT_EMBEDDING_EXTRA_ARGS = None + self.quant_region_timer = None + + def shell_module_materialize(self, target_submodule, device, **kwargs): + del kwargs + return target_submodule + + def get_base_modules(self, model): + return [] + + def pre_quantize_generate_hook_start(self): + pass + + def pre_quantize_generate_hook_end(self): + pass + + gptq_model = FakeGPTQModel() + looper = _DummyLooperForCapture(gptq_model) + + # Mock external dependencies + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.DEVICE_THREAD_POOL", + _create_mock_thread_pool(), + ) + + # Mock ctx to be a proper context manager that combines multiple context managers + def mock_ctx(*context_managers): + class CombinedContextManager: + def __enter__(self): + for cm in context_managers: + cm.__enter__() + return self + + def __exit__(self, *args): + for cm in reversed(context_managers): + cm.__exit__(*args) + return False + return CombinedContextManager() + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.ctx", + mock_ctx, + ) + + # Mock device_ctx to be a proper context manager + def mock_device_ctx(device): + class DeviceContextManager: + def __enter__(self): + return self + def __exit__(self, *args): + return False + return DeviceContextManager() + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.device_ctx", + mock_device_ctx, + ) + + capture = StageInputsCapture(looper, logger=_create_mock_logger()) + + # Create calibration data + calibration_data = [ + {"input_ids": torch.randint(0, 100, (1, 4)), "attention_mask": torch.ones(1, 4, dtype=torch.long)} + ] + + # Run cache_inputs - this will trigger store_input_hook which captures inputs + result = capture.cache_inputs( + layers=[hooked_layer], + calibration_data=calibration_data, + use_cache=False, + ) + + # Verify inputs were captured on CPU + assert len(result.layer_inputs) == 1, "Should have captured one batch" + assert len(result.layer_inputs[0]) == 1, "Should have one input tensor" + assert result.layer_inputs[0][0].device.type == "cpu", \ + f"Input should be on CPU, got {result.layer_inputs[0][0].device}" + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Requires at least 2 CUDA devices" +) +def test_stage_capture_balanced_mode_applies_compute_device_filter(monkeypatch): + """ + Test StageInputsCapture: in balanced mode, compute_device_filter is applied + to determine which devices are used for round-robin distribution. + + Tests: stage_inputs_capture.py lines 91-104 + """ + from gptqmodel.looper.stage_inputs_capture import StageInputsCapture + + # Simulate 4 devices, cuda:0 will be filtered out + mock_devices = [ + torch.device("cuda:0"), + torch.device("cuda:1"), + torch.device("cuda:2"), + torch.device("cuda:3"), + ] + + def fake_select_forward_devices(_base_device): + return mock_devices + + def compute_device_filter(devices): + # Filter out cuda:0 (e.g., reserved for model layers) + return [d for d in devices if d != torch.device("cuda:0")] + + # Patch at source module since import happens inside the function + monkeypatch.setattr( + "gptqmodel.utils.looper_helpers.select_forward_devices", + fake_select_forward_devices, + ) + + # Create a minimal model + class MinimalModel(nn.Module): + def __init__(self, hooked_layer): + super().__init__() + self.hooked_layer = hooked_layer + self.config = types.SimpleNamespace(model_type="test") + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + self.hooked_layer(input_ids.float()) + return None + + hooked_layer = nn.Linear(4, 4) + hooked_layer.full_name = "test_layer" + model = MinimalModel(hooked_layer) + + class FakeGPTQModel: + capture_first_layer_positional_inputs = BaseQModel.capture_first_layer_positional_inputs + capture_first_layer_input_kwargs = BaseQModel.capture_first_layer_input_kwargs + finalize_input_capture_example = BaseQModel.finalize_input_capture_example + move_input_capture_example = BaseQModel.move_input_capture_example + prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs + run_input_capture = BaseQModel.run_input_capture + + def __init__(self): + self.quantize_config = types.SimpleNamespace( + device=torch.device("cuda:0"), + calibration_data_device="balanced", + compute_device_filter=compute_device_filter, + ) + self.model = model + self.ATTENTION_MASKS_REQUIRED_FOR_INPUT = False + self.ATTENTION_MASKS_DTYPE = torch.long + self.INPUT_EMBEDDING_EXTRA_ARGS = None + self.quant_region_timer = None + + def shell_module_materialize(self, target_submodule, device, **kwargs): + del kwargs + return target_submodule + + def get_base_modules(self, model): + return [] + + def pre_quantize_generate_hook_start(self): + pass + + def pre_quantize_generate_hook_end(self): + pass + + gptq_model = FakeGPTQModel() + looper = _DummyLooperForCapture(gptq_model) + + # Track which devices are actually used + used_devices = [] + __import__('gptqmodel.utils.model', fromlist=['move_to']).move_to + + def tracking_move_to(obj, device, **kwargs): + if isinstance(obj, torch.Tensor): + used_devices.append(device) + # For testing, just clone the tensor (don't actually move to GPU) + return obj.detach().clone() if device.type == "cpu" else obj + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.DEVICE_THREAD_POOL", + _create_mock_thread_pool(), + ) + + # Mock ctx to be a proper context manager that combines multiple context managers + def mock_ctx(*context_managers): + class CombinedContextManager: + def __enter__(self): + for cm in context_managers: + cm.__enter__() + return self + + def __exit__(self, *args): + for cm in reversed(context_managers): + cm.__exit__(*args) + return False + return CombinedContextManager() + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.ctx", + mock_ctx, + ) + + # Mock device_ctx to be a proper context manager + def mock_device_ctx(device): + class DeviceContextManager: + def __enter__(self): + return self + def __exit__(self, *args): + return False + return DeviceContextManager() + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.device_ctx", + mock_device_ctx, + ) + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.move_to", + tracking_move_to, + ) + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.nested_move_to", + lambda obj, device, **kw: obj, + ) + + capture = StageInputsCapture(looper, logger=_create_mock_logger()) + + # Create multiple calibration batches to test round-robin + calibration_data = [ + {"input_ids": torch.randint(0, 100, (1, 4)), "attention_mask": torch.ones(1, 4, dtype=torch.long)} + for _ in range(3) + ] + + result = capture.cache_inputs( + layers=[hooked_layer], + calibration_data=calibration_data, + use_cache=False, + ) + + # Verify cuda:0 is NOT used (it was filtered out by compute_device_filter) + assert torch.device("cuda:0") not in used_devices, \ + f"cuda:0 should be filtered out, got devices: {used_devices}" + + # Verify we captured the expected number of batches + assert len(result.layer_inputs) == 3, f"Should have captured 3 batches, got {len(result.layer_inputs)}" + + +def test_stage_capture_balanced_mode_empty_filter_fallback(monkeypatch): + """ + Test StageInputsCapture: if compute_device_filter returns empty list, + it falls back to using all devices from select_forward_devices. + + Tests: stage_inputs_capture.py lines 99-101 + """ + from gptqmodel.looper.stage_inputs_capture import StageInputsCapture + + mock_devices = [ + torch.device("cuda:0"), + torch.device("cuda:1"), + ] + + def fake_select_forward_devices(_base_device): + return mock_devices + + def compute_device_filter_returns_empty(devices): + return [] # Invalid filter returning empty + + # Patch at source module since import happens inside the function + monkeypatch.setattr( + "gptqmodel.utils.looper_helpers.select_forward_devices", + fake_select_forward_devices, + ) + + class MinimalModel(nn.Module): + def __init__(self, hooked_layer): + super().__init__() + self.hooked_layer = hooked_layer + self.config = types.SimpleNamespace(model_type="test") + + def forward(self, input_ids=None, attention_mask=None, **kwargs): + self.hooked_layer(input_ids.float()) + return None + + hooked_layer = nn.Linear(4, 4) + hooked_layer.full_name = "test_layer" + model = MinimalModel(hooked_layer) + + class FakeGPTQModel: + capture_first_layer_positional_inputs = BaseQModel.capture_first_layer_positional_inputs + capture_first_layer_input_kwargs = BaseQModel.capture_first_layer_input_kwargs + finalize_input_capture_example = BaseQModel.finalize_input_capture_example + move_input_capture_example = BaseQModel.move_input_capture_example + prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs + run_input_capture = BaseQModel.run_input_capture + + def __init__(self): + self.quantize_config = types.SimpleNamespace( + device=torch.device("cuda:0"), + calibration_data_device="balanced", + compute_device_filter=compute_device_filter_returns_empty, + ) + self.model = model + self.ATTENTION_MASKS_REQUIRED_FOR_INPUT = False + self.ATTENTION_MASKS_DTYPE = torch.long + self.INPUT_EMBEDDING_EXTRA_ARGS = None + self.quant_region_timer = None + + def shell_module_materialize(self, target_submodule, device, **kwargs): + del kwargs + return target_submodule + + def get_base_modules(self, model): + return [] + + def pre_quantize_generate_hook_start(self): + pass + + def pre_quantize_generate_hook_end(self): + pass + + gptq_model = FakeGPTQModel() + looper = _DummyLooperForCapture(gptq_model) + + used_devices = [] + + def tracking_move_to(obj, device, **kwargs): + if isinstance(obj, torch.Tensor): + used_devices.append(device) + return obj.detach().clone() if device.type == "cpu" else obj + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.DEVICE_THREAD_POOL", + _create_mock_thread_pool(), + ) + + # Mock ctx to be a proper context manager that combines multiple context managers + def mock_ctx(*context_managers): + class CombinedContextManager: + def __enter__(self): + for cm in context_managers: + cm.__enter__() + return self + + def __exit__(self, *args): + for cm in reversed(context_managers): + cm.__exit__(*args) + return False + return CombinedContextManager() + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.ctx", + mock_ctx, + ) + + # Mock device_ctx to be a proper context manager + def mock_device_ctx(device): + class DeviceContextManager: + def __enter__(self): + return self + def __exit__(self, *args): + return False + return DeviceContextManager() + + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.device_ctx", + mock_device_ctx, + ) + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.move_to", + tracking_move_to, + ) + monkeypatch.setattr( + "gptqmodel.looper.stage_inputs_capture.nested_move_to", + lambda obj, device, **kw: obj, + ) + + capture = StageInputsCapture(looper, logger=_create_mock_logger()) + + calibration_data = [ + {"input_ids": torch.randint(0, 100, (1, 4)), "attention_mask": torch.ones(1, 4, dtype=torch.long)} + ] + + result = capture.cache_inputs( + layers=[hooked_layer], + calibration_data=calibration_data, + use_cache=False, + ) + + # Should still work despite empty filter (fallback to all_devices) + assert len(result.layer_inputs) == 1, "Should have captured one batch" + + +# ============================================================================ +# MODULE LOOPER INIT TESTS - Device assignment via _quant_devices +# ============================================================================ + +class _DummyGPTQModelForLooper: + """Minimal fake GPT-QModel for testing ModuleLooper.""" + + capture_first_layer_positional_inputs = BaseQModel.capture_first_layer_positional_inputs + capture_first_layer_input_kwargs = BaseQModel.capture_first_layer_input_kwargs + prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs + + def __init__( + self, + calibration_data_device=None, + compute_device_filter=None, + quant_device="cuda:0", + auto_forward_data_parallel=True, + dense_vram_strategy="exclusive", + dense_vram_strategy_devices=None, + moe_vram_strategy="exclusive", + moe_vram_strategy_devices=None, + ): + self.quantize_config = types.SimpleNamespace( + device=torch.device(quant_device), + calibration_data_device=calibration_data_device, + compute_device_filter=compute_device_filter, + auto_forward_data_parallel=auto_forward_data_parallel, + dense_vram_strategy=dense_vram_strategy, + dense_vram_strategy_devices=dense_vram_strategy_devices, + moe_vram_strategy=moe_vram_strategy, + moe_vram_strategy_devices=moe_vram_strategy_devices, + moe_routing_bypass=lambda: False, + moe=None, + gc_mode=None, + ) + self.support_batch_quantize = False + self.layer_callback = None + self.lm_head = None + self.quant_region_timer = None + self.moe_lifecycle_hooks = None + # Don't set supported dense/MoE strategy lists so getattr uses default values + + +def test_compute_device_filter_applied_to_quant_devices(monkeypatch): + """ + Test that compute_device_filter is applied to _quant_devices in ModuleLooper.__init__. + + This test verifies the interaction between compute_device_filter and calibration_data_device + in the ModuleLooper initialization. + """ + from gptqmodel.looper.module_looper import ModuleLooper + + all_devices = [ + torch.device("cuda:0"), + torch.device("cuda:1"), + torch.device("cuda:2"), + ] + + # Filter excludes cuda:0 (e.g., reserved for model layers) + def compute_device_filter(devices): + return [d for d in devices if d != torch.device("cuda:0")] + + def fake_select_forward_devices(_base_device): + return all_devices + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.select_forward_devices", + fake_select_forward_devices, + ) + + gptq_model = _DummyGPTQModelForLooper( + calibration_data_device="balanced", + compute_device_filter=compute_device_filter, + ) + + looper = ModuleLooper(model=gptq_model, processors=[]) + + # Verify cuda:0 is filtered out from quant_devices + assert torch.device("cuda:0") not in looper._quant_devices + assert torch.device("cuda:1") in looper._quant_devices + assert torch.device("cuda:2") in looper._quant_devices + + +def test_compute_device_filter_with_empty_result_uses_all_devices(monkeypatch): + """ + Test that if compute_device_filter returns empty list, all devices are used. + + Tests: The fallback behavior when filter returns nothing. + """ + from gptqmodel.looper.module_looper import ModuleLooper + + all_devices = [ + torch.device("cuda:0"), + torch.device("cuda:1"), + ] + + def compute_device_filter_returns_empty(devices): + return [] # Invalid filter returning empty + + def fake_select_forward_devices(_base_device): + return all_devices + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.select_forward_devices", + fake_select_forward_devices, + ) + + gptq_model = _DummyGPTQModelForLooper( + calibration_data_device="balanced", + compute_device_filter=compute_device_filter_returns_empty, + ) + + looper = ModuleLooper(model=gptq_model, processors=[]) + + # Should still have devices (fallback behavior in __init__) + assert len(looper._quant_devices) > 0 + + +# ============================================================================ +# FORWARD BATCH TESTS - Balanced mode batch-to-device assignment +# ============================================================================ + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Requires at least 2 CUDA devices" +) +def test_balanced_mode_assigns_batches_to_input_devices(monkeypatch): + """ + Test that in balanced mode, batches are assigned to the device + where their input already resides. + + Tests: module_looper.py _run_forward_batches_parallel balanced mode logic + """ + from gptqmodel.looper.module_looper import ModuleLooper + + all_devices = [ + torch.device("cuda:0"), + torch.device("cuda:1"), + ] + + def fake_select_forward_devices(_base_device): + return all_devices + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.select_forward_devices", + fake_select_forward_devices, + ) + + gptq_model = _DummyGPTQModelForLooper( + calibration_data_device="balanced", + ) + + looper = ModuleLooper(model=gptq_model, processors=[]) + + # Track which device each batch was executed on + batch_to_device = {} + + def fake_clone_module_for_devices(_module, target_devices, progress_callback=None): + return {device: object() for device in target_devices} + + class DummyProcessor: + num_batches = 4 + + def _set_current_batch_index(self, _index): + pass + + class DummyFuture: + def __init__(self, result): + self._result = result + + def result(self): + return self._result + + def fake_forward_batch_worker(*args, **kwargs): + batch_idx = args[2] + return batch_idx, None, None + + def fake_submit(device, fn, *args, **kwargs): + batch_idx = args[2] + batch_to_device[batch_idx] = device + return DummyFuture(fn(*args, **kwargs)) + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.clone_module_for_devices", + fake_clone_module_for_devices, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.forward_batch_worker", + fake_forward_batch_worker, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit", + fake_submit, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit_serial", + fake_submit, + ) + + module = torch.nn.Linear(1, 1) + processor = DummyProcessor() + + # Batches 0,1 on cuda:0, batches 2,3 on cuda:1 + layer_inputs = [ + [torch.zeros(1, 1, device=torch.device("cuda:0"))], + [torch.zeros(1, 1, device=torch.device("cuda:0"))], + [torch.zeros(1, 1, device=torch.device("cuda:1"))], + [torch.zeros(1, 1, device=torch.device("cuda:1"))], + ] + + looper._run_forward_batches_parallel( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=[{}, {}, {}, {}], + position_ids=[], + attention_masks=[torch.zeros(1, 1)] * 4, + cur_layer_device=torch.device("cuda:0"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=False, + reuse_kv=False, + devices=all_devices, + ) + + # Verify batches were assigned to their input device + assert batch_to_device.get(0) == torch.device("cuda:0"), "Batch 0 should run on cuda:0" + assert batch_to_device.get(1) == torch.device("cuda:0"), "Batch 1 should run on cuda:0" + assert batch_to_device.get(2) == torch.device("cuda:1"), "Batch 2 should run on cuda:1" + assert batch_to_device.get(3) == torch.device("cuda:1"), "Batch 3 should run on cuda:1" + + +def test_balanced_mode_fallback_when_input_device_not_in_forward_devices(monkeypatch): + """ + Test that in balanced mode, if a batch's input device is not in + forward_devices, it falls back to round-robin assignment. + + Tests: module_looper.py _run_forward_batches_parallel fallback logic + """ + from gptqmodel.looper.module_looper import ModuleLooper + + # Only cuda:1 and cuda:2 available for forward + forward_devices = [ + torch.device("cuda:1"), + torch.device("cuda:2"), + ] + + def fake_select_forward_devices(_base_device): + return forward_devices + + def compute_device_filter(devices): + return forward_devices + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.select_forward_devices", + fake_select_forward_devices, + ) + + gptq_model = _DummyGPTQModelForLooper( + calibration_data_device="balanced", + compute_device_filter=compute_device_filter, + ) + + looper = ModuleLooper(model=gptq_model, processors=[]) + + batch_to_device = {} + + def fake_clone_module_for_devices(_module, target_devices, progress_callback=None): + return {device: object() for device in target_devices} + + class DummyProcessor: + num_batches = 4 + + def _set_current_batch_index(self, _index): + pass + + class DummyFuture: + def __init__(self, result): + self._result = result + + def result(self): + return self._result + + def fake_forward_batch_worker(*args, **kwargs): + batch_idx = args[2] + return batch_idx, None, None + + def fake_submit(device, fn, *args, **kwargs): + batch_idx = args[2] + batch_to_device[batch_idx] = device + return DummyFuture(fn(*args, **kwargs)) + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.clone_module_for_devices", + fake_clone_module_for_devices, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.forward_batch_worker", + fake_forward_batch_worker, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit", + fake_submit, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit_serial", + fake_submit, + ) + + module = torch.nn.Linear(1, 1) + processor = DummyProcessor() + + # All inputs on cuda:0 which is NOT in forward_devices + layer_inputs = [ + [torch.zeros(1, 1, device=torch.device("cuda:0"))], + [torch.zeros(1, 1, device=torch.device("cuda:0"))], + [torch.zeros(1, 1, device=torch.device("cuda:0"))], + [torch.zeros(1, 1, device=torch.device("cuda:0"))], + ] + + looper._run_forward_batches_parallel( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=[{}, {}, {}, {}], + position_ids=[], + attention_masks=[torch.zeros(1, 1)] * 4, + cur_layer_device=torch.device("cuda:1"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=False, + reuse_kv=False, + devices=forward_devices, + ) + + # All batches should be assigned to either cuda:1 or cuda:2 (fallback) + # cuda:0 should never be used as it's not in forward_devices + for batch_idx in range(4): + assert batch_to_device[batch_idx] in forward_devices, \ + f"Batch {batch_idx} should be on forward device, got {batch_to_device[batch_idx]}" + assert batch_to_device[batch_idx] != torch.device("cuda:0"), \ + f"Batch {batch_idx} should not be on cuda:0" + + +# ============================================================================ +# OUTPUT DEVICE PRESERVATION TESTS +# ============================================================================ + +def test_output_moved_to_input_device_in_single_mode(monkeypatch): + """ + Test that in _run_forward_batches_single, outputs are moved + back to the same device as the input. + + Tests: module_looper.py _run_forward_batches_single output device preservation + """ + from gptqmodel.looper.module_looper import ModuleLooper + + def fake_select_forward_devices(_base_device): + return [torch.device("cuda:0")] + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.select_forward_devices", + fake_select_forward_devices, + ) + + gptq_model = _DummyGPTQModelForLooper( + calibration_data_device=torch.device("cpu"), + auto_forward_data_parallel=False, # Force single mode + ) + + looper = ModuleLooper(model=gptq_model, processors=[]) + + class DummyProcessor: + num_batches = 1 + + def _set_current_batch_index(self, _index): + pass + + processor = DummyProcessor() + + # Input is on CPU (calibration data device) + layer_inputs = [[torch.zeros(1, 1, device=torch.device("cpu"))]] + + # Use is_lm_head_module=True to avoid passing attention_mask/use_cache kwargs + # that plain nn.Linear doesn't accept + module = torch.nn.Linear(1, 1) + + outputs = looper._run_forward_batches_single( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[None], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=True, # Avoid kwargs that nn.Linear doesn't accept + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=True, + reuse_kv=False, + ) + + # Output should be on CPU (same as input) + assert len(outputs) == 1, "Should have one output" + assert outputs[0][0].device == torch.device("cpu"), \ + f"Output should be on CPU (input device), got {outputs[0][0].device}" + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Requires at least 2 CUDA devices" +) +def test_output_moved_to_input_device_in_parallel_mode(monkeypatch): + """ + Test that in _run_forward_batches_parallel, outputs are moved + back to the same device as each batch's input. + + Tests: module_looper.py _run_forward_batches_parallel output device preservation + """ + from gptqmodel.looper.module_looper import ModuleLooper + + all_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + + def fake_select_forward_devices(_base_device): + return all_devices + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.select_forward_devices", + fake_select_forward_devices, + ) + + gptq_model = _DummyGPTQModelForLooper( + calibration_data_device="balanced", + ) + + looper = ModuleLooper(model=gptq_model, processors=[]) + + def fake_clone_module_for_devices(_module, target_devices, progress_callback=None): + return {device: object() for device in target_devices} + + class DummyProcessor: + num_batches = 2 + + def _set_current_batch_index(self, _index): + pass + + class DummyFuture: + def __init__(self, result): + self._result = result + + def result(self): + return self._result + + # Create outputs that return tensors on cuda:0 (compute device) + # These will be moved back to input device by the code + output_tensors = [ + torch.ones(1, 4, device=torch.device("cuda:0")), + torch.ones(1, 4, device=torch.device("cuda:0")), + ] + + def fake_forward_batch_worker(*args, **kwargs): + batch_idx = args[2] + need_output = kwargs.get("need_output", False) + if need_output: + return batch_idx, (output_tensors[batch_idx],), None + return batch_idx, None, None + + def fake_submit(device, fn, *args, **kwargs): + return DummyFuture(fn(*args, **kwargs)) + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.clone_module_for_devices", + fake_clone_module_for_devices, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.forward_batch_worker", + fake_forward_batch_worker, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit", + fake_submit, + ) + monkeypatch.setattr( + "gptqmodel.looper.module_looper.DEVICE_THREAD_POOL.submit_serial", + fake_submit, + ) + + module = torch.nn.Linear(4, 4) + processor = DummyProcessor() + + # Inputs on different devices + layer_inputs = [ + [torch.zeros(1, 4, device=torch.device("cuda:0"))], + [torch.zeros(1, 4, device=torch.device("cuda:1"))], + ] + + outputs = looper._run_forward_batches_parallel( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=[{}, {}], + position_ids=[], + attention_masks=[torch.zeros(1, 1), torch.zeros(1, 1)], + cur_layer_device=torch.device("cuda:0"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=True, + reuse_kv=False, + devices=all_devices, + ) + + # Verify outputs are on same device as their inputs + assert len(outputs) == 2, "Should have two outputs" + assert outputs[0][0].device == torch.device("cuda:0"), \ + f"Batch 0 output should be on cuda:0, got {outputs[0][0].device}" + assert outputs[1][0].device == torch.device("cuda:1"), \ + f"Batch 1 output should be on cuda:1, got {outputs[1][0].device}" + + +def test_auto_forward_data_parallel_false_uses_single_mode(monkeypatch): + """ + Test that when auto_forward_data_parallel=False, + _run_forward_batches uses the single (serial) path. + + Tests: module_looper.py _run_forward_batches mode selection + """ + from gptqmodel.looper.module_looper import ModuleLooper + + def fake_select_forward_devices(_base_device): + return [torch.device("cuda:0"), torch.device("cuda:1")] + + monkeypatch.setattr( + "gptqmodel.looper.module_looper.select_forward_devices", + fake_select_forward_devices, + ) + + gptq_model = _DummyGPTQModelForLooper( + calibration_data_device=torch.device("cpu"), + auto_forward_data_parallel=False, + ) + + looper = ModuleLooper(model=gptq_model, processors=[]) + + single_called = [False] + + def patched_single(**kwargs): + single_called[0] = True + return [] + + # Patch the forward_executor's run_single method (not looper's non-existent _run_forward_batches_single) + looper._forward_executor.run_single = patched_single + + class DummyProcessor: + num_batches = 1 + + def _set_current_batch_index(self, _index): + pass + + processor = DummyProcessor() + module = torch.nn.Linear(1, 1) + layer_inputs = [[torch.zeros(1, 1)]] + + looper._run_forward_batches( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[None], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=False, + reuse_kv=False, + ) + + assert single_called[0], "Should have called _run_forward_batches_single" + + +# ============================================================================ +# INTEGRATION TESTS (Marked for GPU execution) +# ============================================================================ + + +def _skip_if_model_missing(model_id: str): + """Skip test if model path doesn't exist, matching test_failsafe.py pattern.""" + if not os.path.isdir(model_id): + pytest.skip(f"Model path missing: {model_id}") + + +@pytest.mark.cuda +@pytest.mark.integration +class TestCalibrationDataDeviceIntegration: + """ + Integration tests for calibration_data_device with actual quantization. + These tests require GPU and load real models. + Run with: pytest -m "cuda and integration" tests/test_calibration_data_device.py + """ + + NATIVE_MODEL_ID = _INTEGRATION_MODEL_ID + DATASET_SIZE = 4 + + @pytest.fixture(autouse=True) + def setup(self): + """Set up test fixtures.""" + import gc + + # Skip if model is missing (matching test_failsafe.py pattern) + _skip_if_model_missing(self.NATIVE_MODEL_ID) + + from transformers import AutoTokenizer + + self.tokenizer = AutoTokenizer.from_pretrained(self.NATIVE_MODEL_ID, use_fast=True) + if not self.tokenizer.pad_token_id: + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + # Use longer text samples to pass calibration_data_min_length check (default=10 tokens) + # These samples are long enough to not be filtered out + self.calibration_data = [ + "The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet and is commonly used for testing purposes in typography and keyboard layouts.", + "Machine learning is transforming technology by enabling computers to learn from data and make predictions or decisions without being explicitly programmed for each specific task.", + "Natural language processing enables computers to understand text and human language, allowing for applications like chatbots, translation services, and sentiment analysis tools.", + "Quantization reduces model size while maintaining accuracy by converting floating-point numbers to lower precision representations, making models faster and more memory efficient.", + ] + + yield + + # Teardown: aggressive cleanup to prevent VRAM leak between tests + self.calibration_data = None + self.tokenizer = None + + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def _cleanup_model(self, model): + """Immediately delete model and free GPU memory.""" + if model is not None: + del model + import gc + gc.collect() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() + + def _create_quantize_config(self, calibration_data_device, **kwargs): + """Create QuantizeConfig with memory-efficient defaults for integration tests.""" + defaults = { + "bits": 4, + "group_size": 128, + "calibration_data_device": calibration_data_device, + "offload_to_disk": False, + # Memory-efficient settings for multi-test runs + "auto_forward_data_parallel": False, + "wait_for_submodule_finalizers": True, + "gc_mode": "on_stage_end", + } + defaults.update(kwargs) + return QuantizeConfig(**defaults) + + def test_calibration_data_device_cpu_integration(self, tmp_path): + """Integration test: quantization with calibration_data_device='cpu'.""" + from gptqmodel import GPTQModel + + quantize_config = self._create_quantize_config(calibration_data_device="cpu") + + model = GPTQModel.load(self.NATIVE_MODEL_ID, quantize_config=quantize_config) + model.quantize(self.calibration_data, batch_size=1) + + save_path = str(tmp_path / "quantized") + model.save(save_path) + self._cleanup_model(model) + + loaded = GPTQModel.load(save_path) + assert loaded is not None + + inp = self.tokenizer("Hello", return_tensors="pt").to(loaded.device) + output = loaded.generate(**inp, max_new_tokens=5) + assert output is not None + self._cleanup_model(loaded) + + @pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Requires at least 2 CUDA devices" + ) + def test_calibration_data_device_balanced_integration(self, tmp_path): + """Integration test: quantization with calibration_data_device='balanced'.""" + from gptqmodel import GPTQModel + + quantize_config = self._create_quantize_config(calibration_data_device="balanced") + + model = GPTQModel.load(self.NATIVE_MODEL_ID, quantize_config=quantize_config) + model.quantize(self.calibration_data, batch_size=1) + + save_path = str(tmp_path / "quantized_balanced") + model.save(save_path) + self._cleanup_model(model) + + loaded = GPTQModel.load(save_path) + assert loaded is not None + self._cleanup_model(loaded) + + @pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Requires at least 2 CUDA devices" + ) + def test_calibration_data_device_cuda1_with_compute_filter_integration(self, tmp_path): + """Integration test: quantization with calibration_data_device='cuda:1' and compute_device_filter.""" + from gptqmodel import GPTQModel + + # Filter that excludes cuda:1 from compute devices + def compute_device_filter(devices): + return [d for d in devices if d != torch.device("cuda:1")] + + quantize_config = self._create_quantize_config( + calibration_data_device="cuda:1", + compute_device_filter=compute_device_filter, + ) + + model = GPTQModel.load(self.NATIVE_MODEL_ID, quantize_config=quantize_config) + model.quantize(self.calibration_data, batch_size=1) + + save_path = str(tmp_path / "quantized_cuda1") + model.save(save_path) + self._cleanup_model(model) + + loaded = GPTQModel.load(save_path) + assert loaded is not None + self._cleanup_model(loaded) diff --git a/tests/test_compute_device_filter.py b/tests/test_compute_device_filter.py index 4dd1bbcbe..d666da8d5 100644 --- a/tests/test_compute_device_filter.py +++ b/tests/test_compute_device_filter.py @@ -10,8 +10,12 @@ def __init__(self, compute_device_filter): self.support_batch_quantize = False self.quantize_config = types.SimpleNamespace( device=torch.device("cpu"), - vram_strategy="exclusive", + dense_vram_strategy="exclusive", + dense_vram_strategy_devices=None, + moe_vram_strategy="exclusive", + moe_vram_strategy_devices=None, compute_device_filter=compute_device_filter, + calibration_data_device="balanced", auto_forward_data_parallel=True, moe_routing_bypass=lambda: False, ) diff --git a/tests/test_cpp_jit_progress.py b/tests/test_cpp_jit_progress.py new file mode 100644 index 000000000..9d7949eb8 --- /dev/null +++ b/tests/test_cpp_jit_progress.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import time + +import pytest + +from gptqmodel.utils.cpp import ( + _COMPILE_PROGRESS_TOTAL_STEPS, + _compile_progress_ratio, + _compile_progress_step, + _compile_progress_subtitle, + _CompileProgressDisplay, + default_jit_cuda_cflags, +) +from gptqmodel.utils.jit_compile_baselines import get_jit_compile_baseline_seconds + + +def test_known_jit_compile_baselines_are_recorded(): + assert get_jit_compile_baseline_seconds("gptqmodel_marlin_fp16_ops") == pytest.approx(116.863) + assert get_jit_compile_baseline_seconds("gptqmodel_awq_ops") == pytest.approx(61.640) + assert get_jit_compile_baseline_seconds("gptqmodel_paroquant_rotation") == pytest.approx(78.430) + + +def test_compile_progress_ratio_tracks_baseline_without_hitting_completion(): + assert _compile_progress_ratio(0.0, 120.0) == pytest.approx(0.0) + assert _compile_progress_ratio(60.0, 120.0) == pytest.approx(0.475) + assert _compile_progress_ratio(120.0, 120.0) == pytest.approx(0.95) + assert 0.95 < _compile_progress_ratio(240.0, 120.0) < 0.99 + + +def test_compile_progress_step_never_reaches_final_step_before_completion(): + assert _compile_progress_step(120.0, 120.0, total_steps=100) < 99 + assert _compile_progress_step(600.0, 120.0, total_steps=100) < 99 + + +def test_compile_progress_subtitle_reports_overrun_against_baseline(): + subtitle = _compile_progress_subtitle(120.634, 116.863) + assert "elapsed 121s" in subtitle + assert "estimated ~117s" in subtitle + assert "(+3.8s)" in subtitle + + +def test_default_jit_cuda_cflags_includes_nvcc_threads_by_default(monkeypatch): + monkeypatch.delenv("NVCC_THREADS", raising=False) + flags = default_jit_cuda_cflags() + assert "--threads" in flags + assert "8" in flags + + +def test_default_jit_cuda_cflags_honors_nvcc_threads_override(monkeypatch): + monkeypatch.setenv("NVCC_THREADS", "16") + flags = default_jit_cuda_cflags() + assert "--threads" in flags + assert "16" in flags + + +def test_default_jit_cuda_cflags_explicit_nvcc_threads_takes_precedence(monkeypatch): + monkeypatch.setenv("NVCC_THREADS", "8") + flags = default_jit_cuda_cflags(nvcc_threads=16) + assert "--threads" in flags + assert flags[flags.index("--threads") + 1] == "16" + + +class _FakeProgress: + def __init__(self): + self.current_iter_step = 0 + self._title = "" + self._subtitle = "" + self.closed = False + + def manual(self): + return self + + def set(self, **_kwargs): + return self + + def title(self, value): + self._title = value + return self + + def subtitle(self, value): + self._subtitle = value + return self + + def draw(self, force: bool = False): + return self + + def close(self): + self.closed = True + return None + + +class _FakeSpinner: + def close(self): + return None + + +class _FakeLogger: + def __init__(self): + self.progress = _FakeProgress() + self.spinner_handle = _FakeSpinner() + + def pb(self, _iterable): + return self.progress + + def spinner(self, **_kwargs): + return self.spinner_handle + + +def test_compile_progress_close_completes_immediately_when_build_finishes_early(monkeypatch): + monkeypatch.setattr("gptqmodel.utils.cpp._COMPILE_PROGRESS_INTERVAL_SECONDS", 60.0) + logger = _FakeLogger() + display = _CompileProgressDisplay( + logger=logger, + title="Compiling extension: Marlin bf16...", + baseline_seconds=120.0, + ) + + started = time.perf_counter() + display.close(succeeded=True, elapsed_seconds=5.0) + elapsed = time.perf_counter() - started + + assert elapsed < 0.5 + assert logger.progress.current_iter_step == _COMPILE_PROGRESS_TOTAL_STEPS + assert logger.progress.closed is True + assert logger.progress._subtitle == "elapsed 5.0s / estimated ~120s" + + +def test_compile_progress_close_does_not_wait_for_refresh_interval_on_failure(monkeypatch): + monkeypatch.setattr("gptqmodel.utils.cpp._COMPILE_PROGRESS_INTERVAL_SECONDS", 60.0) + logger = _FakeLogger() + display = _CompileProgressDisplay( + logger=logger, + title="Compiling extension: AWQ...", + baseline_seconds=60.0, + ) + + started = time.perf_counter() + display.close(succeeded=False, elapsed_seconds=3.0) + elapsed = time.perf_counter() - started + + assert elapsed < 0.5 + assert logger.progress.current_iter_step < _COMPILE_PROGRESS_TOTAL_STEPS + assert logger.progress.closed is True diff --git a/tests/test_cpu_pin.py b/tests/test_cpu_pin.py index e3abf7a8f..672ac0647 100644 --- a/tests/test_cpu_pin.py +++ b/tests/test_cpu_pin.py @@ -16,7 +16,7 @@ from gptqmodel.utils.threadx import DeviceThreadPool -pytestmark = pytest.mark.ci +pytestmark = [pytest.mark.ci, pytest.mark.cpu] def _affinity_supported() -> bool: diff --git a/tests/test_cuda_event_stream_activation_buffer.py b/tests/test_cuda_event_stream_activation_buffer.py index 52749c20a..689133ecc 100644 --- a/tests/test_cuda_event_stream_activation_buffer.py +++ b/tests/test_cuda_event_stream_activation_buffer.py @@ -232,10 +232,15 @@ def test_cuda_event_stream_activation_buffer_benchmarks(): hit_forward_mean = statistics.mean(async_pool_forward[1:]) if len(async_pool_forward) > 1 else miss_forward miss_drain = async_pool_drain[0] hit_drain_mean = statistics.mean(async_pool_drain[1:]) if len(async_pool_drain) > 1 else miss_drain + miss_total = miss_forward + miss_drain + hit_total_mean = hit_forward_mean + hit_drain_mean assert pool.misses >= 1 assert pool.hits >= 1 - assert hit_forward_mean <= miss_forward * 0.75 + # Buffer reuse saves host allocation work, but the synchronized D2H drain still + # dominates on fast GPUs, so compare end-to-end packet cost instead of the noisy + # sub-millisecond forward hook timing alone. + assert hit_total_mean <= miss_total * 1.1 print( "[CUDA6 Activation Copy Benchmark]\n" @@ -250,5 +255,7 @@ def test_cuda_event_stream_activation_buffer_benchmarks(): f" pool hit forward: {hit_forward_mean * 1e3:.2f} ms\n" f" pool miss drain: {miss_drain * 1e3:.2f} ms\n" f" pool hit drain: {hit_drain_mean * 1e3:.2f} ms\n" + f" pool miss total: {miss_total * 1e3:.2f} ms\n" + f" pool hit total: {hit_total_mean * 1e3:.2f} ms\n" f" pool stats (hits/misses): {pool.hits}/{pool.misses}" ) diff --git a/tests/test_cutlass_stable_abi_headers.py b/tests/test_cutlass_stable_abi_headers.py new file mode 100644 index 000000000..7539631ce --- /dev/null +++ b/tests/test_cutlass_stable_abi_headers.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import unittest +from pathlib import Path + + +_REPO_ROOT = Path(__file__).resolve().parents[1] + + +class CutlassStableAbiHeaderTests(unittest.TestCase): + def test_common_header_uses_stable_abi_shim_check(self) -> None: + header = ( + _REPO_ROOT / "gptqmodel_ext" / "cutlass_extensions" / "common.hpp" + ).read_text(encoding="utf-8") + + self.assertIn("#include ", header) + self.assertIn( + "STD_TORCH_CHECK(error == cutlass::Status::kSuccess,", + header, + ) + + def test_torch_utils_header_supports_stable_and_unstable_torch_abi(self) -> None: + header = ( + _REPO_ROOT / "gptqmodel_ext" / "cutlass_extensions" / "torch_utils.hpp" + ).read_text(encoding="utf-8") + + self.assertIn("shared between _C (unstable ABI, used by machete)", header) + self.assertIn("#ifdef TORCH_TARGET_VERSION", header) + self.assertIn("using TorchTensor = torch::stable::Tensor;", header) + self.assertIn("using TorchTensor = torch::Tensor;", header) + self.assertIn("#define TORCH_UTILS_CHECK STD_TORCH_CHECK", header) + self.assertIn("#define TORCH_UTILS_CHECK TORCH_CHECK", header) + self.assertIn("static inline auto make_cute_layout(TorchTensor const& tensor,", header) + self.assertIn("std::optional const& tensor,", header) + self.assertIn("struct equivalent_cutlass_type", header) + self.assertIn("struct equivalent_cutlass_type", header) + self.assertIn("static inline constexpr TorchScalarType equivalent_scalar_type_v =", header) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index cfcd6d664..82cdd9c3d 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -23,34 +23,28 @@ from gptqmodel.looper.named_module import NamedModule # noqa: E402 from gptqmodel.looper.native_processor import NATIVE_INPUTS_STATE_KEY # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.torch import TorchLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear # noqa: E402 from gptqmodel.quantization import QuantizeConfig # noqa: E402 from gptqmodel.utils import safetensor # noqa: E402 -from gptqmodel.utils.perplexity import Perplexity # noqa: E402 class TestDynamic(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0" tmp_quant_path = None - def calculate_avg_ppl(self, model, tokenizer): - ppl = Perplexity( - model=model, - tokenizer=tokenizer, - dataset_path="wikitext", - dataset_name="wikitext-2-raw-v1", - split="test", - text_column="text", - ) - - all = ppl.calculate(n_ctx=512, n_batch=512) - - # average ppl - avg = sum(all) / len(all) - - return avg + def _generate(self, model, tokenizer, prompt: str = "The sky is blue today"): + inputs = tokenizer(prompt, return_tensors="pt") + with torch.inference_mode(): + outputs = model.generate( + **inputs, + max_new_tokens=8, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ) + return outputs @classmethod def setUpClass(cls): @@ -103,12 +97,12 @@ def tearDownClass(cls): @parameterized.expand( [ # exllama v1/v2 only supports 4bit so does not support dynamic bits control - (BACKEND.TORCH, TorchQuantLinear, 15.643), - (BACKEND.TRITON, TritonV2QuantLinear, 15.643), - (BACKEND.MARLIN, MarlinQuantLinear, 15.644), + (BACKEND.TORCH, TorchLinear), + (BACKEND.TRITON, TritonV2Linear), + (BACKEND.MARLIN, MarlinLinear), ] ) - def test_dynamic_bits(self, backend, backendQLinear, expected_ppl): + def test_dynamic_bits(self, backend, backendQLinear): model = GPTQModel.load( self.tmp_quant_path.name, backend=backend, @@ -120,15 +114,11 @@ def test_dynamic_bits(self, backend, backendQLinear, expected_ppl): else: raise ValueError(f"Did not find a `{backendQLinear}` linear layer for backend: `{backend}`") - dynamic_bits_ppl = self.calculate_avg_ppl(model, self.tokenizer) + outputs = self._generate(model, self.tokenizer) + self.assertGreater(outputs.shape[1], 0) del model - print(f"Backend: {backend}, PPL: {dynamic_bits_ppl}") - tolerance = 0.05 - lower_bound = expected_ppl * (1 - tolerance) - upper_bound = expected_ppl * (1 + tolerance) - assert lower_bound <= dynamic_bits_ppl <= upper_bound, \ - f"PPL expected: `{expected_ppl}` ±{tolerance * 100}%, actual = `{dynamic_bits_ppl}`" + def test_skip_module(self): dynamic = { @@ -174,7 +164,7 @@ def test_dynamic_overrides_apply_per_module(monkeypatch): dynamic={ "model.linear": { "gptaq": {"alpha": 0.5, "device": "cpu"}, - "failsafe": {"strategy": "median", "threshold": "2%"}, + "fallback": {"strategy": "median", "threshold": "2%"}, "hessian": {"chunk_size": 32, "chunk_bytes": 1024, "staging_dtype": "bfloat16"}, }, } @@ -196,9 +186,9 @@ def test_dynamic_overrides_apply_per_module(monkeypatch): assert dynamic_cfg.gptaq is not None assert dynamic_cfg.gptaq.alpha == 0.5 assert dynamic_cfg.gptaq.device == "cpu" - assert dynamic_cfg.failsafe is not None - assert dynamic_cfg.failsafe.strategy == "median" - assert dynamic_cfg.failsafe.threshold == "2%" + assert dynamic_cfg.fallback is not None + assert dynamic_cfg.fallback.strategy == "median" + assert dynamic_cfg.fallback.threshold == "2%" assert dynamic_cfg.hessian.chunk_size == 32 assert dynamic_cfg.hessian.chunk_bytes == 1024 assert dynamic_cfg.hessian.staging_dtype == torch.bfloat16 @@ -214,8 +204,8 @@ def test_dynamic_overrides_apply_per_module(monkeypatch): other_cfg = processor.qcfg_dynamic assert other_cfg is not None assert other_cfg.gptaq is None - assert other_cfg.failsafe.strategy == qcfg.failsafe.strategy - assert other_cfg.failsafe.threshold == qcfg.failsafe.threshold + assert other_cfg.fallback.strategy == qcfg.fallback.strategy + assert other_cfg.fallback.threshold == qcfg.fallback.threshold assert other_cfg.hessian.chunk_size == qcfg.hessian.chunk_size assert other_cfg.hessian.chunk_bytes == qcfg.hessian.chunk_bytes assert other_cfg.hessian.staging_dtype == qcfg.hessian.staging_dtype diff --git a/tests/test_eval.py b/tests/test_eval.py index 478aa36ae..2ff734a6e 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -4,72 +4,45 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import os +import tempfile - -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" - -import tempfile # noqa: E402 -from typing import ( - Type, # noqa: E402 - Union, # noqa: E402 -) - -from lm_eval.tasks import TaskManager # noqa: E402 +import pytest from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 from gptqmodel import GPTQModel # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 +from tests.eval import evaluate, get_eval_task_metrics # noqa: E402 + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + +pytestmark = [pytest.mark.model, pytest.mark.slow] class TestEval(ModelTest): @classmethod - def setUpClass(self): - self.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" - self.model = GPTQModel.load(self.MODEL_ID) + def setUpClass(cls): + cls.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" + cls.model = GPTQModel.load(cls.MODEL_ID) @parameterized.expand( [ - (EVAL.LM_EVAL, EVAL.LM_EVAL.ARC_CHALLENGE, 'gptqmodel'), - (EVAL.LM_EVAL, EVAL.LM_EVAL.ARC_CHALLENGE, 'vllm'), - (EVAL.EVALPLUS, EVAL.EVALPLUS.HUMAN, 'gptqmodel'), - (EVAL.EVALPLUS, EVAL.EVALPLUS.HUMAN, 'vllm'), - (EVAL.LM_EVAL, EVAL.LM_EVAL.GPQA, 'vllm'), + ("arc_challenge", False), + ("mmlu_pro:stem.math", True), ] ) - def test_eval_gptqmodel(self, framework: Union[Type[EVAL.LM_EVAL],Type[EVAL.EVALPLUS]], task: Union[EVAL.LM_EVAL, EVAL.EVALPLUS], llm_backend: str): + def test_evalution_gptqmodel(self, task: str, apply_chat_template: bool): with tempfile.TemporaryDirectory() as tmp_dir: - output_path = f"{tmp_dir}/result.json" - model_args = {} - if task == EVAL.LM_EVAL.GPQA: - model_args["gpu_memory_utilization"]=0.7 - - results = GPTQModel.eval(model_or_id_or_path=self.MODEL_ID, - framework=framework, - tasks=[task], - batch_size=1, - output_path=output_path, - llm_backend=llm_backend, - model_args=model_args, - task_manager=TaskManager(include_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "tasks"), include_defaults=False) - ) - - if llm_backend == EVAL.LM_EVAL: - if task == EVAL.LM_EVAL.GPQA: - gpqa_main_n_shot = results['results'].get('gpqa_main_n_shot', {}).get('acc,none') - gpqa_main_zeroshot = results['results'].get('gpqa_main_zeroshot', {}).get('acc,none') - - self.assertGreaterEqual(gpqa_main_n_shot, 0.21, "acc score does not match expected result") - self.assertGreaterEqual(gpqa_main_zeroshot, 0.25, "acc_norm score does not match expected result") - else: - acc_score = results['results'].get(task.value, {}).get('acc,none') - acc_norm_score = results['results'].get(task.value, {}).get('acc_norm,none') - - self.assertGreaterEqual(acc_score, 0.28, "acc score does not match expected result") - self.assertGreaterEqual(acc_norm_score, 0.32, "acc_norm score does not match expected result") - elif llm_backend == EVAL.EVALPLUS: - result = results.get(task.value) - base_formatted, plus_formatted, _ = float(result.get("base tests")), float( - result.get("base + extra tests")), result.get("results_path") - self.assertGreaterEqual(base_formatted, 0.26, "Base score does not match expected result") - self.assertGreaterEqual(plus_formatted, 0.23, "Plus score does not match expected result") + results = evaluate( + model_or_id_or_path=self.MODEL_ID, + tasks=[task], + batch_size=1, + output_path=f"{tmp_dir}/result.json", + apply_chat_template=apply_chat_template, + suite_kwargs={"max_rows": 2, "num_fewshot": 1}, + ) + if task == "mmlu_pro:stem.math": + metrics = get_eval_task_metrics(results, "mmlu_pro_stem_math") + else: + metrics = get_eval_task_metrics(results, task) + self.assertTrue(metrics, f"Expected Evalution metrics for task {task}") diff --git a/tests/test_eval_loader_args.py b/tests/test_eval_loader_args.py new file mode 100644 index 000000000..85f8be04a --- /dev/null +++ b/tests/test_eval_loader_args.py @@ -0,0 +1,234 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from types import SimpleNamespace + +import pytest + +from gptqmodel import BACKEND +from tests import eval as eval_module + + +def test_eval_string_model_load_filters_eval_only_keys(monkeypatch): + captured = {} + + class FakeGPTQModelEngine: + def __init__(self, **kwargs): + captured["engine_kwargs"] = kwargs + + def build(self, model_config): + captured["model_kwargs"] = dict(model_config.kwargs["model_kwargs"]) + raise RuntimeError("sentinel-load-stop") + + class FakeModel: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def to_dict(self): + return dict(self.kwargs) + + fake_evalution = SimpleNamespace( + GPTQModel=FakeGPTQModelEngine, + Model=FakeModel, + ) + + model_args = { + "backend": BACKEND.EXLLAMA_V3, + "device": "cuda:0", + "gptqmodel": True, + "model_id_or_path": "/tmp/stale-model", + "pretrained": "/tmp/stale-pretrained", + "tokenizer": object(), + "trust_remote_code": False, + } + + with pytest.raises(RuntimeError, match="sentinel-load-stop"): + eval_module._build_evalution_runtime( + evalution=fake_evalution, + model_or_id_or_path="/tmp/current-model", + llm_backend="gptqmodel", + backend=BACKEND.EXLLAMA_V3, + batch_size=1, + trust_remote_code=True, + model_args=model_args, + tokenizer=None, + ) + + assert captured["engine_kwargs"]["backend"] == BACKEND.EXLLAMA_V3.value + assert captured["engine_kwargs"]["device"] == "cuda:0" + assert captured["engine_kwargs"]["trust_remote_code"] is True + assert captured["model_kwargs"] == {} + for key in ("gptqmodel", "pretrained", "tokenizer"): + assert key not in captured["model_kwargs"] + + +def test_build_evalution_runtime_supports_vllm_engine_options(): + captured = {} + + class FakeVLLM: + def __init__(self, **kwargs): + captured["engine_kwargs"] = kwargs + + def build(self, model_config): + captured["model_config"] = model_config + return "session" + + class FakeModel: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def to_dict(self): + return dict(self.kwargs) + + fake_evalution = SimpleNamespace( + VLLM=FakeVLLM, + Model=FakeModel, + ) + + engine, model_config, session = eval_module._build_evalution_runtime( + evalution=fake_evalution, + model_or_id_or_path="/tmp/model", + llm_backend="vllm", + backend=BACKEND.AUTO, + batch_size=1, + trust_remote_code=True, + model_args={ + "dtype": "float16", + "gpu_memory_utilization": 0.8, + "tensor_parallel_size": "2", + "quantization": "gptq", + "tokenizer_mode": "auto", + "max_model_len": "4096", + "foo": "bar", + }, + tokenizer=None, + ) + + assert session == "session" + assert engine is not None + assert model_config.kwargs["path"] == "/tmp/model" + assert model_config.kwargs["model_kwargs"] == {"foo": "bar"} + assert captured["engine_kwargs"]["dtype"] == "float16" + assert captured["engine_kwargs"]["gpu_memory_utilization"] == 0.8 + assert captured["engine_kwargs"]["tensor_parallel_size"] == 2 + assert captured["engine_kwargs"]["quantization"] == "gptq" + assert captured["engine_kwargs"]["max_model_len"] == 4096 + + +def test_build_evalution_runtime_supports_sglang_engine_options(): + captured = {} + + class FakeSGLang: + def __init__(self, **kwargs): + captured["engine_kwargs"] = kwargs + + def build(self, model_config): + captured["model_config"] = model_config + return "session" + + class FakeModel: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def to_dict(self): + return dict(self.kwargs) + + fake_evalution = SimpleNamespace( + SGLang=FakeSGLang, + Model=FakeModel, + ) + + engine, model_config, session = eval_module._build_evalution_runtime( + evalution=fake_evalution, + model_or_id_or_path="/tmp/model", + llm_backend="sglang", + backend=BACKEND.AUTO, + batch_size=2, + trust_remote_code=True, + model_args={ + "dtype": "float16", + "device": "cuda", + "gpu_memory_utilization": "0.75", + "tensor_parallel_size": "2", + "quantization": "gptq", + "tokenizer_mode": "auto", + "max_model_len": "8192", + "attention_backend": "flashinfer", + "sampling_backend": "pytorch", + "max_running_requests": "16", + "max_total_tokens": "32768", + "random_seed": "123", + "sampling_params": {"top_p": 0.9}, + "foo": "bar", + }, + tokenizer=None, + ) + + assert session == "session" + assert engine is not None + assert model_config.kwargs["path"] == "/tmp/model" + assert model_config.kwargs["model_kwargs"] == {"foo": "bar", "random_seed": '123'} + assert captured["engine_kwargs"]["dtype"] == "float16" + assert captured["engine_kwargs"]["device"] == "cuda" + assert captured["engine_kwargs"]["batch_size"] == 2 + assert captured["engine_kwargs"]["trust_remote_code"] is True + assert captured["engine_kwargs"]["quantization"] == "gptq" + assert captured["engine_kwargs"]["context_length"] == 8192 + assert captured["engine_kwargs"]["tp_size"] == 2 + assert captured["engine_kwargs"]["mem_fraction_static"] == 0.75 + assert captured["engine_kwargs"]["attention_backend"] == "flashinfer" + assert captured["engine_kwargs"]["sampling_backend"] == "pytorch" + assert captured["engine_kwargs"]["max_running_requests"] == 16 + assert captured["engine_kwargs"]["max_total_tokens"] == 32768 + assert captured["engine_kwargs"]["sampling_params"] == {"top_p": 0.9} + + +def test_build_evalution_runtime_supports_gptqmodel_seed(): + captured = {} + + class FakeGPTQModel: + def __init__(self, **kwargs): + captured["engine_kwargs"] = kwargs + + def build(self, model_config): + captured["model_config"] = model_config + return "session" + + class FakeModel: + def __init__(self, **kwargs): + self.kwargs = kwargs + + def to_dict(self): + return dict(self.kwargs) + + fake_evalution = SimpleNamespace( + GPTQModel=FakeGPTQModel, + Model=FakeModel, + ) + + engine, model_config, session = eval_module._build_evalution_runtime( + evalution=fake_evalution, + model_or_id_or_path="/tmp/model", + llm_backend="gptqmodel", + backend=BACKEND.AUTO, + batch_size=4, + trust_remote_code=True, + model_args={ + "dtype": "float16", + "seed": 898, + "device": "cuda:0", + "foo": "bar", + }, + tokenizer=None, + ) + + assert session == "session" + assert engine is not None + assert model_config.kwargs["path"] == "/tmp/model" + assert model_config.kwargs["model_kwargs"] == {"foo": "bar"} + assert captured["engine_kwargs"]["dtype"] == "float16" + assert captured["engine_kwargs"]["device"] == "cuda:0" + assert captured["engine_kwargs"]["batch_size"] == 4 + assert captured["engine_kwargs"]["seed"] == 898 diff --git a/tests/test_eval_runtime.py b/tests/test_eval_runtime.py new file mode 100644 index 000000000..512d20ef3 --- /dev/null +++ b/tests/test_eval_runtime.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch +import tempfile # noqa: E402 +import unittest # noqa: E402 + +from gptqmodel import BACKEND +from tests.eval import evaluate, format_eval_result_table, get_eval_task_metrics # noqa: E402 + + +class TestEvalRuntime(unittest.TestCase): + + @classmethod + def setUpClass(self): + self.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" + self.random_seed = 1234 + self.task = "arc_challenge" + + # self.acc_score = 0.3183 + # self.acc_norm_score = 0.3515 + + + def test_eval_direct(self): + with tempfile.TemporaryDirectory() as tmp_dir: + results = evaluate( + model_or_id_or_path=self.MODEL_ID, + apply_chat_template=True, + output_path=tmp_dir, + tasks=[self.task], + ) + + print('--------Evalution Eval Result---------') + print(format_eval_result_table(results)) + print('--------Evalution Result End---------') + + metrics = get_eval_task_metrics(results, self.task) + acc_norm_score = metrics.get('accuracy,loglikelihood_norm') + + # self.assertGreaterEqual(acc_score, self.acc_score, "acc score does not match expected result") + self.assertGreaterEqual(acc_norm_score, 0.3400, "acc_norm score does not match expected result") + + def test_eval_path(self): + with tempfile.TemporaryDirectory() as tmp_dir: + results = evaluate( + model_or_id_or_path=self.MODEL_ID, + backend = BACKEND.EXLLAMA_V2, # for path loading, can override backend + output_path=tmp_dir, + tasks=[self.task], + model_args={ + "device": "cuda" + } + ) + + print('--------Evalution Eval Result---------') + print(format_eval_result_table(results)) + print('--------Evalution Result End---------') + + metrics = get_eval_task_metrics(results, self.task) + acc_norm_score = metrics.get('accuracy,loglikelihood_norm') + + # self.assertGreaterEqual(acc_score, self.acc_score, "acc score does not match expected result") + self.assertGreaterEqual(acc_norm_score, 0.3000, "acc_norm score does not match expected result") diff --git a/tests/test_evalution_suite_stream_defaults.py b/tests/test_evalution_suite_stream_defaults.py new file mode 100644 index 000000000..3a618c72d --- /dev/null +++ b/tests/test_evalution_suite_stream_defaults.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from tests import eval as eval_module + + +def test_build_evalution_suite_defaults_mmlu_stream_true() -> None: + recorded_kwargs = {} + + def fake_mmlu(**kwargs): + recorded_kwargs.update(kwargs) + return kwargs + + eval_module._build_evalution_suite( + evalution=SimpleNamespace(benchmarks=SimpleNamespace(mmlu=fake_mmlu)), + task_name="mmlu_stem", + apply_chat_template=False, + batch_size=64, + generation_settings={}, + suite_kwargs={}, + ) + + assert recorded_kwargs["subsets"] == "stem" + assert recorded_kwargs["batch_size"] == 64 + assert recorded_kwargs["stream"] is True + + +def test_build_evalution_suite_preserves_explicit_stream_override() -> None: + recorded_kwargs = {} + + def fake_mmlu(**kwargs): + recorded_kwargs.update(kwargs) + return kwargs + + eval_module._build_evalution_suite( + evalution=SimpleNamespace(benchmarks=SimpleNamespace(mmlu=fake_mmlu)), + task_name="mmlu", + apply_chat_template=False, + batch_size=32, + generation_settings={}, + suite_kwargs={"stream": False}, + ) + + assert recorded_kwargs["batch_size"] == 32 + assert recorded_kwargs["stream"] is False diff --git a/tests/test_exllamav2_awq_jit.py b/tests/test_exllamav2_awq_jit.py new file mode 100644 index 000000000..3dea657dc --- /dev/null +++ b/tests/test_exllamav2_awq_jit.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import gptqmodel.nn_modules.qlinear.exllamav2_awq as exllamav2_awq_module + + +def _build_module() -> exllamav2_awq_module.AwqExllamaV2Linear: + return exllamav2_awq_module.AwqExllamaV2Linear( + bits=4, + group_size=128, + sym=True, + desc_act=False, + in_features=128, + out_features=128, + bias=False, + register_buffers=True, + ) + + +def test_exllamav2_awq_make_q_matrix_uses_jit_op(monkeypatch): + module = _build_module() + calls = {} + + def fake_make_q_matrix(qweight, q_perm, q_invperm, q_scale, q_scale_max, q_groups, qzeros, scales, g_idx, temp_dq): + calls["make_q_matrix_awq"] = { + "qweight_shape": tuple(qweight.shape), + "qzeros_shape": tuple(qzeros.shape), + "scales_dtype": scales.dtype, + "g_idx_is_none": g_idx is None, + "temp_dq_shape": tuple(temp_dq.shape), + } + return 123 + + monkeypatch.setattr(exllamav2_awq_module, "exllamav2_awq_make_q_matrix", fake_make_q_matrix) + + handle = module.ext_make_q_matrix_awq( + module.qweight, + module.qzeros, + module.scales.to(dtype=torch.float32), + torch.zeros(module.temp_dq_size() // 2, dtype=torch.float16), + ) + + assert handle == 123 + assert calls["make_q_matrix_awq"] == { + "qweight_shape": tuple(module.qweight.shape), + "qzeros_shape": tuple(module.qzeros.shape), + "scales_dtype": torch.float16, + "g_idx_is_none": True, + "temp_dq_shape": (module.temp_dq_size() // 2,), + } + + +def test_exllamav2_awq_forward_uses_jit_gemm(monkeypatch): + module = _build_module() + module.q_handle = 77 + calls = {} + + def fake_gemm(x, q_handle, output, force_cuda): + calls["gemm_half_q_half_awq"] = { + "x_shape": tuple(x.shape), + "x_dtype": x.dtype, + "q_handle": q_handle, + "output_shape": tuple(output.shape), + "force_cuda": force_cuda, + } + output.copy_(torch.full_like(output, 7.0)) + + monkeypatch.setattr(exllamav2_awq_module, "exllamav2_awq_gemm_half_q_half", fake_gemm) + monkeypatch.setattr(exllamav2_awq_module, "exllamav2_awq_runtime_available", lambda: True) + + x = torch.randn((2, module.in_features), dtype=torch.float32) + out = module(x) + + assert calls["gemm_half_q_half_awq"] == { + "x_shape": (2, module.in_features), + "x_dtype": torch.float16, + "q_handle": 77, + "output_shape": (2, module.out_features), + "force_cuda": False, + } + assert out.shape == (2, module.out_features) + assert out.dtype == torch.float32 + assert torch.allclose(out, torch.full_like(out, 7.0)) + + +def test_exllamav2_awq_validate_once_surfaces_jit_error(monkeypatch): + monkeypatch.setattr(exllamav2_awq_module, "exllamav2_awq_runtime_available", lambda: False) + monkeypatch.setattr(exllamav2_awq_module, "exllamav2_awq_runtime_error", lambda: "missing exllamav2 awq jit ops") + + ok, err = exllamav2_awq_module.AwqExllamaV2Linear.validate_once() + + assert ok is False + assert isinstance(err, ImportError) + assert "missing exllamav2 awq jit ops" in str(err) diff --git a/tests/test_exllamav2_jit.py b/tests/test_exllamav2_jit.py new file mode 100644 index 000000000..2e8943dd8 --- /dev/null +++ b/tests/test_exllamav2_jit.py @@ -0,0 +1,134 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import torch + +import gptqmodel.nn_modules.qlinear.exllamav2 as exllamav2_module +import gptqmodel.utils.exllamav2 as exllamav2_utils +from gptqmodel.utils import cpp as cpp_module + + +def _build_module() -> exllamav2_module.ExllamaV2Linear: + return exllamav2_module.ExllamaV2Linear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=128, + bias=False, + register_buffers=True, + ) + + +def test_exllamav2_make_q_matrix_uses_jit_op(monkeypatch): + module = _build_module() + calls = {} + + def fake_make_q_matrix(qweight, q_perm, q_invperm, q_scale, q_scale_max, q_groups, qzeros, scales, g_idx, temp_dq): + calls["make_q_matrix"] = { + "qweight_shape": tuple(qweight.shape), + "scales_dtype": scales.dtype, + "g_idx_is_none": g_idx is None, + "temp_dq_shape": tuple(temp_dq.shape), + } + return 123 + + monkeypatch.setattr(exllamav2_module, "exllamav2_make_q_matrix", fake_make_q_matrix) + + weights = { + "qweight": module.qweight, + "qzeros": module.qzeros, + "scales": module.scales.to(dtype=torch.float32), + "g_idx": torch.zeros_like(module.g_idx), + } + + handle = module.ext_make_q_matrix(weights, torch.zeros(module.temp_dq_size() // 2, dtype=torch.float16)) + + assert handle == 123 + assert calls["make_q_matrix"] == { + "qweight_shape": tuple(module.qweight.shape), + "scales_dtype": torch.float16, + "g_idx_is_none": True, + "temp_dq_shape": (module.temp_dq_size() // 2,), + } + + +def test_exllamav2_forward_uses_jit_gemm(monkeypatch): + module = _build_module() + module.q_handle = 77 + calls = {} + + def fake_gemm(x, q_handle, output, force_cuda): + calls["gemm"] = { + "x_shape": tuple(x.shape), + "x_dtype": x.dtype, + "q_handle": q_handle, + "output_shape": tuple(output.shape), + "force_cuda": force_cuda, + } + output.copy_(torch.full_like(output, 5.0)) + + monkeypatch.setattr(exllamav2_module, "exllamav2_gemm_half_q_half", fake_gemm) + + x = torch.randn((2, module.in_features), dtype=torch.float32) + out = module(x) + + assert calls["gemm"] == { + "x_shape": (2, module.in_features), + "x_dtype": torch.float16, + "q_handle": 77, + "output_shape": (2, module.out_features), + "force_cuda": False, + } + assert out.shape == (2, module.out_features) + assert out.dtype == torch.float32 + assert torch.allclose(out, torch.full_like(out, 5.0)) + + +def test_exllamav2_validate_once_surfaces_jit_error(monkeypatch): + monkeypatch.setattr(exllamav2_module, "exllamav2_gptq_runtime_available", lambda: False) + monkeypatch.setattr(exllamav2_module, "exllamav2_gptq_runtime_error", lambda: "missing exllamav2 jit ops") + + ok, err = exllamav2_module.ExllamaV2Linear.validate_once() + + assert ok is False + assert isinstance(err, ImportError) + assert "missing exllamav2 jit ops" in str(err) + + +def test_exllamav2_include_paths_use_wheel_headers_when_local_cuda_is_incomplete(monkeypatch, tmp_path): + root = tmp_path / "exllamav2" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (wheel_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(exllamav2_utils, "_exllamav2_root", lambda: root) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = exllamav2_utils._exllamav2_include_paths() + + assert include_paths[0] == str(root) + assert str(wheel_cuda_include) in include_paths + + +def test_exllamav2_include_paths_skip_wheel_headers_when_local_cuda_has_required_headers(monkeypatch, tmp_path): + root = tmp_path / "exllamav2" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (local_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(exllamav2_utils, "_exllamav2_root", lambda: root) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = exllamav2_utils._exllamav2_include_paths() + + assert include_paths == [str(root)] diff --git a/tests/test_exllamav3.py b/tests/test_exllamav3.py new file mode 100644 index 000000000..d3205a62d --- /dev/null +++ b/tests/test_exllamav3.py @@ -0,0 +1,136 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import json + +import torch +import torch.nn as nn +from safetensors.torch import save_file + +from gptqmodel.nn_modules.exllamav3 import ExllamaV3Linear +from gptqmodel.nn_modules.exllamav3_torch import ExllamaV3TorchLinear +from gptqmodel.quantization.config import FORMAT, METHOD, EXL3Config, QuantizeConfig +from gptqmodel.utils.exllamav3 import build_exllamav3_tensor_storage, replace_exllamav3_placeholders +from gptqmodel.utils.model_dequant import detect_format + + +def test_exllamav3_quantize_config_round_trip(): + cfg = QuantizeConfig( + quant_method=METHOD.EXL3, + format=FORMAT.EXL3, + bits=2.25, + head_bits=4.0, + out_scales="always", + codebook="mul1", + ) + + assert isinstance(cfg, EXL3Config) + assert cfg.quant_method == METHOD.EXL3 + assert cfg.format == FORMAT.EXL3 + assert cfg.runtime_bits == 2 + assert cfg.uses_weight_only_lifecycle() is False + assert cfg.requires_calibration_dataset() is True + + payload = cfg.to_dict() + assert payload["bits"] == 2.25 + assert payload["head_bits"] == 4.0 + assert payload["out_scales"] == "always" + assert payload["codebook"] == "mul1" + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded, EXL3Config) + assert reloaded.bits == 2.25 + assert reloaded.head_bits == 4.0 + assert reloaded.out_scales == "always" + assert reloaded.codebook == "mul1" + assert reloaded.runtime_bits == 2 + + +class _TinyModel(nn.Module): + def __init__(self): + super().__init__() + self.proj = nn.Linear(16, 16, bias=True) + + +def test_replace_exllamav3_placeholders_uses_tensor_storage_metadata(): + model = _TinyModel() + tensor_storage = { + "proj": { + "stored_tensors": { + "proj.trellis": {"shape": [1, 1, 32], "torch_dtype": "int16"}, + "proj.suh": {"shape": [16], "torch_dtype": "float16"}, + "proj.svh": {"shape": [16], "torch_dtype": "float16"}, + "proj.bias": {"shape": [16], "torch_dtype": "float16"}, + "proj.mul1": {"shape": [], "torch_dtype": "int32"}, + }, + "quant_format": "exl3", + "bits_per_weight": 2, + } + } + + replace_exllamav3_placeholders( + model=model, + module_names=["proj"], + tensor_storage=tensor_storage, + ) + + assert isinstance(model.proj, ExllamaV3Linear) + assert model.proj.trellis.device.type == "meta" + assert tuple(model.proj.trellis.shape) == (1, 1, 32) + assert model.proj.suh.dtype == torch.float16 + assert model.proj.svh.dtype == torch.float16 + assert model.proj.bias.dtype == torch.float16 + assert model.proj.mul1.dtype == torch.int32 + + +def test_replace_exllamav3_placeholders_supports_torch_reference_kernel(): + model = _TinyModel() + tensor_storage = { + "proj": { + "stored_tensors": { + "proj.trellis": {"shape": [1, 1, 32], "torch_dtype": "int16"}, + "proj.suh": {"shape": [16], "torch_dtype": "float16"}, + "proj.svh": {"shape": [16], "torch_dtype": "float16"}, + }, + "quant_format": "exl3", + "bits_per_weight": 2, + } + } + + replace_exllamav3_placeholders( + model=model, + module_names=["proj"], + tensor_storage=tensor_storage, + module_cls=ExllamaV3TorchLinear, + ) + + assert isinstance(model.proj, ExllamaV3TorchLinear) + assert build_exllamav3_tensor_storage(model)["proj"]["quant_format"] == "exl3" + + +def test_detect_format_identifies_exllamav3(tmp_path): + shard_path = tmp_path / "model.safetensors" + save_file( + { + "model.layers.0.self_attn.q_proj.trellis": torch.zeros((1, 1, 32), dtype=torch.int16), + "model.layers.0.self_attn.q_proj.suh": torch.zeros((16,), dtype=torch.float16), + "model.layers.0.self_attn.q_proj.svh": torch.zeros((16,), dtype=torch.float16), + }, + str(shard_path), + ) + + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "quantization_config": { + "quant_method": "exl3", + "format": "exl3", + } + } + ), + encoding="utf-8", + ) + + detected = detect_format(tmp_path, json.loads(config_path.read_text(encoding="utf-8"))) + assert detected == "exl3" diff --git a/tests/test_exllamav3_jit.py b/tests/test_exllamav3_jit.py new file mode 100644 index 000000000..1fbeec3c9 --- /dev/null +++ b/tests/test_exllamav3_jit.py @@ -0,0 +1,206 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import pytest +import torch + +import gptqmodel.exllamav3.ext as exllamav3_ext_module +from gptqmodel.utils import cpp as cpp_module + + +class _FakeExtension: + def __init__(self, ops: dict[str, object], *, available: bool = True, error: str = ""): + self._ops = ops + self._available = available + self._error = error + + def load(self) -> bool: + return self._available + + def last_error_message(self) -> str: + return self._error + + def op(self, name: str): + return self._ops[name] + + +def test_exllamav3_bc_linear_wrapper_uses_jit_op(monkeypatch): + calls = {} + + def fake_run(trellis, suh, svh, K, bias, mcg, mul1, xh, x, y): + calls["bc_linear_exl3_run"] = { + "trellis_shape": tuple(trellis.shape), + "suh_shape": tuple(suh.shape), + "svh_shape": tuple(svh.shape), + "K": K, + "bias_is_none": bias is None, + "mcg": mcg, + "mul1": mul1, + "xh_shape": tuple(xh.shape), + "x_shape": tuple(x.shape), + "y_shape": tuple(y.shape), + } + + monkeypatch.setattr( + exllamav3_ext_module, + "_EXLLAMAV3_TORCH_OPS_EXTENSION", + _FakeExtension({"bc_linear_exl3_run": fake_run}), + ) + + wrapper = exllamav3_ext_module.exllamav3_ext.BC_LinearEXL3( + trellis=torch.zeros((1, 4, 32), dtype=torch.int16), + suh=torch.zeros((128,), dtype=torch.float16), + svh=torch.zeros((64,), dtype=torch.float16), + K=2, + bias=None, + mcg=True, + mul1=False, + xh=torch.zeros((1, 128), dtype=torch.float16), + ) + + wrapper.run( + torch.zeros((3, 128), dtype=torch.float16), + torch.zeros((3, 64), dtype=torch.float16), + ) + + assert calls["bc_linear_exl3_run"] == { + "trellis_shape": (1, 4, 32), + "suh_shape": (128,), + "svh_shape": (64,), + "K": 2, + "bias_is_none": True, + "mcg": True, + "mul1": False, + "xh_shape": (1, 128), + "x_shape": (3, 128), + "y_shape": (3, 64), + } + + +def test_exllamav3_quantize_tiles_uses_jit_op(monkeypatch): + calls = {} + + def fake_quantize_tiles(input_tiles, output_tiles, output_indices, temp_costs, temp_edges, K, mcg, mul1): + calls["quantize_tiles"] = { + "input_shape": tuple(input_tiles.shape), + "output_shape": tuple(output_tiles.shape), + "indices_dtype": output_indices.dtype, + "temp_costs_shape": tuple(temp_costs.shape), + "temp_edges_shape": tuple(temp_edges.shape), + "K": K, + "mcg": mcg, + "mul1": mul1, + } + output_tiles.fill_(1.0) + output_indices.fill_(7) + + monkeypatch.setattr( + exllamav3_ext_module, + "_EXLLAMAV3_TORCH_OPS_EXTENSION", + _FakeExtension({"quantize_tiles": fake_quantize_tiles}), + ) + + input_tiles = torch.zeros((2, 256), dtype=torch.float32) + output_tiles = torch.zeros_like(input_tiles) + output_indices = torch.zeros((2, 256), dtype=torch.int16) + temp_costs = torch.zeros((2, 2, 16384), dtype=torch.float16) + temp_edges = torch.zeros((2, 256, 16384), dtype=torch.int16) + + exllamav3_ext_module.exllamav3_ext.quantize_tiles( + input_tiles, + output_tiles, + output_indices, + temp_costs, + temp_edges, + 2, + False, + True, + ) + + assert calls["quantize_tiles"] == { + "input_shape": (2, 256), + "output_shape": (2, 256), + "indices_dtype": torch.int16, + "temp_costs_shape": (2, 2, 16384), + "temp_edges_shape": (2, 256, 16384), + "K": 2, + "mcg": False, + "mul1": True, + } + assert torch.all(output_tiles == 1.0) + assert torch.all(output_indices == 7) + + +def test_exllamav3_bc_linear_wrapper_surfaces_jit_error(monkeypatch): + monkeypatch.setattr( + exllamav3_ext_module, + "_EXLLAMAV3_TORCH_OPS_EXTENSION", + _FakeExtension({}, available=False, error="missing exllamav3 jit ops"), + ) + + with pytest.raises(ModuleNotFoundError, match="missing exllamav3 jit ops"): + exllamav3_ext_module.exllamav3_ext.BC_LinearEXL3( + trellis=torch.zeros((1, 4, 32), dtype=torch.int16), + suh=torch.zeros((128,), dtype=torch.float16), + svh=torch.zeros((64,), dtype=torch.float16), + K=2, + bias=None, + mcg=False, + mul1=False, + xh=torch.zeros((1, 128), dtype=torch.float16), + ) + + +def test_exllamav3_include_paths_use_wheel_headers_when_local_cuda_is_incomplete(monkeypatch, tmp_path): + source_root = tmp_path / "exllamav3" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + source_root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (wheel_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(exllamav3_ext_module, "_source_root", lambda: source_root) + monkeypatch.setattr( + cpp_module, + "detected_local_cuda_include_paths", + lambda: [str(local_cuda_include)], + ) + monkeypatch.setattr( + cpp_module, + "detected_cuda_wheel_include_paths", + lambda: [str(wheel_cuda_include)], + ) + + include_paths = exllamav3_ext_module._exllamav3_include_paths() + + assert include_paths[0] == str(source_root) + assert str(wheel_cuda_include) in include_paths + + +def test_exllamav3_include_paths_skip_wheel_headers_when_local_cuda_has_required_headers(monkeypatch, tmp_path): + source_root = tmp_path / "exllamav3" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + source_root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (local_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(exllamav3_ext_module, "_source_root", lambda: source_root) + monkeypatch.setattr( + cpp_module, + "detected_local_cuda_include_paths", + lambda: [str(local_cuda_include)], + ) + monkeypatch.setattr( + cpp_module, + "detected_cuda_wheel_include_paths", + lambda: [str(wheel_cuda_include)], + ) + + include_paths = exllamav3_ext_module._exllamav3_include_paths() + + assert include_paths == [str(source_root)] diff --git a/tests/test_extension_load_api.py b/tests/test_extension_load_api.py new file mode 100644 index 000000000..d82e70d95 --- /dev/null +++ b/tests/test_extension_load_api.py @@ -0,0 +1,212 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import threading +import time + +import pytest + +import gptqmodel +import gptqmodel.exllamav3.ext as exllamav3_ext +import gptqmodel.extension as extension_api +import gptqmodel.utils.awq as awq_utils +import gptqmodel.utils.cpp as cpp_utils +import gptqmodel.utils.exllamav2 as exllamav2_utils +import gptqmodel.utils.machete as machete_utils +import gptqmodel.utils.marlin as marlin_utils +import gptqmodel.utils.paroquant as paroquant_utils +import gptqmodel.utils.qqq as qqq_utils + + +class _FakeExtension: + def __init__(self, name: str, *, ok: bool = True, error: str = "", already_loaded: bool = False): + self.display_name = name + self.ok = ok + self.error = error + self.already_loaded = already_loaded + self.load_calls = 0 + self.clear_cache_calls = 0 + self.max_parallel_loads = 0 + self._active_loads = 0 + self._active_lock = threading.Lock() + self._ops = {"test_op": object()} + + def _ops_available(self) -> bool: + return self.already_loaded + + def clear_cache(self) -> None: + self.clear_cache_calls += 1 + + def load(self) -> bool: + self.load_calls += 1 + with self._active_lock: + self._active_loads += 1 + self.max_parallel_loads = max(self.max_parallel_loads, self._active_loads) + time.sleep(0.02) + with self._active_lock: + self._active_loads -= 1 + return self.ok + + def last_error_message(self) -> str: + return self.error + + def namespace_object(self) -> object: + return self + + def op(self, op_name: str): + return self._ops[op_name] + + +def _install_fake_extensions(monkeypatch): + fakes = { + "pack_block_cpu": _FakeExtension("pack_block_cpu"), + "floatx_cpu": _FakeExtension("floatx_cpu"), + "awq": _FakeExtension("AWQ"), + "qqq": _FakeExtension("QQQ"), + "exllamav2": _FakeExtension("ExLlamaV2 GPTQ"), + "exllamav2_awq": _FakeExtension("ExLlamaV2 AWQ"), + "exllamav3": _FakeExtension("ExLlamaV3"), + "machete": _FakeExtension("Machete"), + "marlin_fp16": _FakeExtension("Marlin fp16"), + "marlin_bf16": _FakeExtension("Marlin bf16"), + "paroquant": _FakeExtension("ParoQuant rotation"), + } + + monkeypatch.setattr(cpp_utils, "_pack_block_extension", lambda: fakes["pack_block_cpu"]) + monkeypatch.setattr(cpp_utils, "_floatx_cpu_extension", lambda: fakes["floatx_cpu"]) + monkeypatch.setattr(awq_utils, "_AWQ_TORCH_OPS_EXTENSION", fakes["awq"]) + monkeypatch.setattr(qqq_utils, "_QQQ_TORCH_OPS_EXTENSION", fakes["qqq"]) + monkeypatch.setattr(exllamav2_utils, "_EXLLAMAV2_GPTQ_TORCH_OPS_EXTENSION", fakes["exllamav2"]) + monkeypatch.setattr(exllamav2_utils, "_EXLLAMAV2_AWQ_TORCH_OPS_EXTENSION", fakes["exllamav2_awq"]) + monkeypatch.setattr(exllamav3_ext, "_EXLLAMAV3_TORCH_OPS_EXTENSION", fakes["exllamav3"]) + monkeypatch.setattr(machete_utils, "_MACHETE_TORCH_OPS_EXTENSION", fakes["machete"]) + monkeypatch.setattr(machete_utils, "_validate_machete_device_support", lambda: True) + monkeypatch.setattr(marlin_utils, "_MARLIN_FP16_TORCH_OPS_EXTENSION", fakes["marlin_fp16"]) + monkeypatch.setattr(marlin_utils, "_MARLIN_BF16_TORCH_OPS_EXTENSION", fakes["marlin_bf16"]) + monkeypatch.setattr(paroquant_utils, "_PAROQUANT_ROTATION_EXTENSION", fakes["paroquant"]) + + return fakes + + +def test_package_root_exports_extension_module(): + assert gptqmodel.extension is extension_api + + +def test_load_defaults_to_all_extensions(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + + result = extension_api.load() + + assert result == { + "pack_block_cpu": True, + "floatx_cpu": True, + "awq": True, + "qqq": True, + "exllamav2": True, + "exllamav2_awq": True, + "exllamav3": True, + "machete": True, + "marlin_fp16": True, + "marlin_bf16": True, + "paroquant": True, + } + assert all(fake.load_calls == 1 for fake in fakes.values()) + + +def test_load_all_skips_extensions_unsupported_on_this_host(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + monkeypatch.setattr(machete_utils, "_validate_machete_device_support", lambda: False) + + result = extension_api.load() + + assert "machete" not in result + assert fakes["machete"].load_calls == 0 + + +def test_load_specific_unsupported_extension_raises_without_building(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + monkeypatch.setattr(machete_utils, "_validate_machete_device_support", lambda: False) + monkeypatch.setattr(machete_utils, "machete_runtime_error", lambda: "Machete unsupported on this device.") + + with pytest.raises(RuntimeError, match="Machete unsupported on this device."): + extension_api.load(name="machete") + + assert fakes["machete"].load_calls == 0 + + +def test_load_marlin_alias_builds_both_variants(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + + result = extension_api.load(name="marlin") + + assert result == {"marlin_fp16": True, "marlin_bf16": True} + assert fakes["marlin_fp16"].load_calls == 1 + assert fakes["marlin_bf16"].load_calls == 1 + assert fakes["awq"].load_calls == 0 + + +def test_load_specific_extension_honors_use_cache_false(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + + result = extension_api.load(name="exllama-v2-awq", use_cache=False) + + assert result == {"exllamav2_awq": True} + assert fakes["exllamav2_awq"].clear_cache_calls == 1 + assert fakes["exllamav2_awq"].load_calls == 1 + + +def test_load_raises_for_unknown_extension(monkeypatch): + _install_fake_extensions(monkeypatch) + + with pytest.raises(ValueError, match="Unknown extension"): + extension_api.load(name="missing_extension") + + +def test_load_aggregates_extension_failures(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + fakes["awq"].ok = False + fakes["awq"].error = "AWQ toolchain failure" + + with pytest.raises(RuntimeError, match="AWQ toolchain failure"): + extension_api.load(name="awq") + + +def test_use_cache_false_requires_fresh_process_for_loaded_extensions(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + fakes["qqq"].already_loaded = True + + with pytest.raises(RuntimeError, match="Restart Python to force recompilation"): + extension_api.load(name="qqq", use_cache=False) + + assert fakes["qqq"].clear_cache_calls == 0 + assert fakes["qqq"].load_calls == 0 + + +def test_op_routes_through_extension_api(monkeypatch): + _install_fake_extensions(monkeypatch) + + op = extension_api.op("awq", "test_op") + + assert op is awq_utils._AWQ_TORCH_OPS_EXTENSION._ops["test_op"] + + +def test_load_serializes_same_extension_across_threads(monkeypatch): + fakes = _install_fake_extensions(monkeypatch) + errors: list[Exception] = [] + + def runner(): + try: + extension_api.load(name="awq") + except Exception as exc: # pragma: no cover - assertion path below + errors.append(exc) + + threads = [threading.Thread(target=runner) for _ in range(4)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert errors == [] + assert fakes["awq"].max_parallel_loads == 1 diff --git a/tests/test_failsafe.py b/tests/test_failsafe.py index 149ce1306..538c13e21 100644 --- a/tests/test_failsafe.py +++ b/tests/test_failsafe.py @@ -16,17 +16,22 @@ FailSafe, FailSafeStrategy, QuantizeConfig, + SmoothAuto, SmoothLog, SmoothMAD, + SmoothMSE, SmoothPercentile, SmoothPercentileAsymmetric, ) from gptqmodel.quantization.failsafe_smooth import smooth_block -from gptqmodel.quantization.gptq import GPTQ +from gptqmodel.quantization.gptq import GPTQ, _row_replace_mask from gptqmodel.utils.failsafe import should_use_failsafe from gptqmodel.utils.pause_resume import PauseResumeController +AUTO_SMOOTHER_TEST_TOLERANCE = 1e-7 + + def test_smooth_mad_uses_sigma_normalized_window(): torch.manual_seed(0) @@ -43,6 +48,110 @@ def test_smooth_mad_uses_sigma_normalized_window(): assert clip_ratio < 0.03, f"clip_ratio={clip_ratio:.4f} is too high for sigma-normalized MAD smoothing" +def test_auto_smoother_matches_or_beats_single_smoothers(): + torch.manual_seed(0) + weights = torch.randn(16, 128) * 0.2 + weights[:, ::17] *= 9.0 + + auto = _failsafe_quantize( + weights, + 128, + FailSafeStrategy.RTN, + smooth=SmoothAuto( + include_none=True, + mse_steps=36, + mse_maxshrink=0.9, + mad_k=2.75, + percentile=99.5, + low=0.25, + high=99.75, + ), + ) + base = _failsafe_quantize(weights, 128, FailSafeStrategy.RTN) + mse = _failsafe_quantize(weights, 128, FailSafeStrategy.RTN, smooth=SmoothMSE(steps=36, maxshrink=0.9)) + mad = _failsafe_quantize(weights, 128, FailSafeStrategy.RTN, smooth=SmoothMAD(k=2.75)) + asym = _failsafe_quantize( + weights, + 128, + FailSafeStrategy.RTN, + smooth=SmoothPercentileAsymmetric(low=0.25, high=99.75), + ) + shrink = _failsafe_quantize(weights, 128, FailSafeStrategy.RTN, smooth=SmoothPercentile(percentile=99.5)) + + auto_err = (weights - auto).pow(2).mean(dim=1) + base_err = (weights - base).pow(2).mean(dim=1) + mse_err = (weights - mse).pow(2).mean(dim=1) + mad_err = (weights - mad).pow(2).mean(dim=1) + asym_err = (weights - asym).pow(2).mean(dim=1) + shrink_err = (weights - shrink).pow(2).mean(dim=1) + + for label, candidate_err in ( + ("base", base_err), + ("mse", mse_err), + ("mad", mad_err), + ("asym", asym_err), + ("percentile", shrink_err), + ): + assert torch.all(auto_err <= candidate_err + AUTO_SMOOTHER_TEST_TOLERANCE), label + + +def test_row_replace_mask_handles_vector_and_matrix_targets(): + replace = torch.tensor([[True], [False], [True]]) + + matrix_target = torch.ones((3, 1), dtype=torch.float32) + vector_target = torch.ones((3,), dtype=torch.float32) + selected_values_matrix = torch.tensor([[1.0], [2.0], [3.0]]) + selected_values_vector = torch.tensor([1.0, 2.0, 3.0]) + fallback_values_matrix = torch.tensor([[9.0], [9.0], [9.0]]) + fallback_values_vector = torch.tensor([9.0, 9.0, 9.0]) + expected_matrix = torch.tensor([[1.0], [9.0], [3.0]]) + expected_vector = torch.tensor([1.0, 9.0, 3.0]) + + matrix_mask = _row_replace_mask(replace, matrix_target) + vector_mask = _row_replace_mask(replace, vector_target) + + assert matrix_mask.shape == matrix_target.shape + assert vector_mask.shape == vector_target.shape + + matrix_selected = torch.where(matrix_mask, selected_values_matrix, fallback_values_matrix) + vector_selected = torch.where(vector_mask, selected_values_vector, fallback_values_vector) + + torch.testing.assert_close(matrix_selected, expected_matrix) + torch.testing.assert_close(vector_selected, expected_vector) + + +def test_nan_loss_retries_with_failsafe_instead_of_enabling_mock_quantization(monkeypatch): + torch.manual_seed(0) + + linear = nn.Linear(32, 16, bias=False) + qcfg = QuantizeConfig( + bits=4, + group_size=8, + desc_act=False, + act_group_aware=False, + failsafe={"strategy": "rtn", "threshold": True}, + ) + gptq = GPTQ(linear, qcfg) + gptq.quantizer.configure(perchannel=True) + gptq.add_batch(torch.randn(12, 32), None) + + calls = {"count": 0} + original_quantize = gptq.quantizer.quantize + + def quantize_with_single_nan(x): + calls["count"] += 1 + if calls["count"] == 1: + return torch.full_like(x, float("nan")) + return original_quantize(x) + + monkeypatch.setattr(gptq.quantizer, "quantize", quantize_with_single_nan) + + _, _, _, _, _, avg_loss, _, _ = gptq.quantize(blocksize=8) + + assert avg_loss.startswith("failsafe(rtn):") + assert gptq.qcfg.mock_quantization is False + + class TestGPTQHessianSimilarity(unittest.TestCase): """ This test verifies that Hessian-based GPTQ produces quantized weights diff --git a/tests/test_fallback.py b/tests/test_fallback.py new file mode 100644 index 000000000..63c956e1c --- /dev/null +++ b/tests/test_fallback.py @@ -0,0 +1,742 @@ +import math +import os +import types +import unittest +from glob import glob +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn +from module_tree.test_subset import _StubAWQProcessor + +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.quantization.config import ( + FORMAT, + METHOD, + Fallback, + FallbackStrategy, + QuantizeConfig, + SmoothLog, + SmoothMAD, + SmoothPercentile, + SmoothPercentileAsymmetric, +) +from gptqmodel.quantization.fallback_smooth import smooth_block +from gptqmodel.quantization.gptq import GPTQ +from gptqmodel.utils.fallback import should_use_fallback + + +def test_smooth_mad_uses_sigma_normalized_window(): + torch.manual_seed(0) + + # For Gaussian-like rows, SmoothMAD(k=2.75) should clip only far-tail + # outliers. Without sigma normalization, the same `k` clips near 1.85 sigma + # and removes several times more weights than intended. + block = torch.randn(2048, 128) + clipped, _ = smooth_block( + block, + Fallback(strategy=FallbackStrategy.RTN, threshold=True, smooth=SmoothMAD(k=2.75)), + group_size=128, + ) + + clip_ratio = (clipped != block).float().mean().item() + assert clip_ratio < 0.03, f"clip_ratio={clip_ratio:.4f} is too high for sigma-normalized MAD smoothing" + + +class TestGPTQHessianSimilarity(unittest.TestCase): + """ + This test verifies that Hessian-based GPTQ produces quantized weights + that remain numerically close to RTN fallback, while still introducing + minimal corrective differences. + + The test intentionally checks *similarity*, not equality. + """ + + def _run_test(self, device: str): + torch.manual_seed(0) + + # Large dimensions are intentionally used to: + # - avoid degenerate small-layer behavior + # - amplify Hessian effects in a stable manner + in_features = 1024 + out_features = 2048 + batch = 4 + seq = 16 + + inp = torch.randn(batch, seq, in_features, device=device) + linear = nn.Linear(in_features, out_features, bias=False).to(device) + + qcfg = QuantizeConfig( + bits=4, + group_size=128, + fallback={"strategy": "rtn", "threshold": False}, + ) + + # ============================================================ + # Hessian-based GPTQ (use_hessian = True) + # ============================================================ + gptq_h = GPTQ(linear, qcfg) + gptq_h.quantizer.configure(perchannel=True) + gptq_h.fallback = False + + # Accumulate Hessian via the public API + gptq_h.add_batch(inp, None) + + Q_h, scale_h, zero_h, gidx_h, *_ = gptq_h.quantize() + + # ============================================================ + # RTN fallback (use_hessian = False) + # ============================================================ + qcfg.fallback={"strategy": "rtn", "threshold": True} + gptq_r = GPTQ(linear, qcfg) + gptq_r.quantizer.configure(perchannel=True) + gptq_r.fallback = qcfg.fallback + + # IMPORTANT: + # We intentionally do NOT call add_batch here, + # so nsamples == 0 and the code falls back to RTN-style quantization. + Q_r, scale_r, zero_r, gidx_r, *_ = gptq_r.quantize() + + # ============================================================ + # Assertions + # ============================================================ + + # ------------------------------------------------------------ + # 1. Quantized weights should remain numerically close + # + # GPTQ tries to stay very close to the RTN baseline while reducing + # global error using Hessian-based correction. The RTN scale tensor + # stores the uniform quantizer step size for each group; its mean + # stands in for a single-bin width when we decide what counts as "close". + # ------------------------------------------------------------ + quant_step = torch.mean(scale_r).item() + self.assertGreater(quant_step, 0.0, msg="RTN-derived quantization step must be positive") + + close_mask = torch.isclose(Q_h, Q_r, atol=quant_step) + close_ratio = close_mask.float().mean().item() + + self.assertGreater( + close_ratio, + 0.95, + msg="At least 95% of quantized values should stay within one average RTN quantization step", + ) + + # ------------------------------------------------------------ + # 2. Quantized weights must NOT be exactly identical + # + # At least some discrete corrections are expected when + # Hessian-based error propagation is active. + # ------------------------------------------------------------ + self.assertFalse( + torch.equal(Q_h, Q_r), + msg="Quantized weights should not be exactly identical", + ) + + self.assertGreater( + torch.count_nonzero(Q_h != Q_r).item(), + 0, + msg="At least one quantized element should differ due to Hessian correction", + ) + + # ------------------------------------------------------------ + # 3. Group indices must be identical + # + # Group assignment depends only on group_size and ordering, + # and must NOT be affected by Hessian usage. + # ------------------------------------------------------------ + self.assertTrue( + torch.equal(gidx_h, gidx_r), + msg="Group indices (g_idx) must be identical regardless of Hessian usage", + ) + + # ------------------------------------------------------------ + # 4. Scale tensors: shape must match and values must remain stable + # + # Scale is allowed to change slightly due to redistribution + # of weights, but should remain within a small relative bound. + # ------------------------------------------------------------ + self.assertEqual(scale_h.shape, scale_r.shape) + self.assertEqual(zero_h.shape, zero_r.shape) + + scale_rel_diff = torch.mean( + torch.abs(scale_h - scale_r) / scale_r + ).item() + + self.assertLess( + scale_rel_diff, + 0.05, + msg="Relative scale deviation should remain below 5%", + ) + + # ------------------------------------------------------------ + # 5. Zero-points may shift slightly, but the shift must be bounded + # + # A bounded zero-point shift corresponds to less than one + # quantization bin and is expected behavior. + # ------------------------------------------------------------ + zero_diff = torch.mean(torch.abs(zero_h - zero_r)).item() + + self.assertLess( + zero_diff * quant_step, + quant_step, + msg="Zero-point shift should correspond to less than one quantization bin", + ) + + def test_cpu(self): + self._run_test("cpu") + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + def test_cuda(self): + self._run_test("cuda") + + +class TestFailsafeConfig(unittest.TestCase): + def test_fallback_none_round_trip(self): + qcfg = QuantizeConfig(fallback=None) + payload = qcfg.to_dict() + + self.assertIn("fallback", payload.get("meta", {})) + self.assertIsNone(payload["meta"]["fallback"]) + + loaded = QuantizeConfig.from_quant_config(payload) + self.assertIsNone(loaded.fallback) + +######## test_fallback_awq.py ######## + +def _dummy_prepare_dataset( + *, + calibration_dataset, + calibration_dataset_concat_size, + calibration_dataset_sort, + batch_size, + calibration_concat_separator=None, +): + return calibration_dataset + + +class _DummyProgressBar: + def title(self, _): + return self + + def subtitle(self, _): + return self + + def draw(self): + return None + + +def test_awq_fallback_falls_back_to_rtn_when_no_activations(monkeypatch): + model = nn.Module() + model.linear = nn.Linear(8, 8, bias=False) + + gptq_model = SimpleNamespace(model=model, lm_head=None, quant_region_timer=None) + + qcfg = QuantizeConfig( + bits=4, + group_size=-1, + fallback={"strategy": "rtn", "threshold": "1.0%"}, + format=FORMAT.GEMM, + quant_method=METHOD.AWQ, + ) + processor = _StubAWQProcessor( + qcfg=qcfg, + ) + processor.pb = _DummyProgressBar() + + named = NamedModule(model.linear, name="linear", full_name="linear", layer_index=0) + processor.preprocess(named, fallback=qcfg.fallback) + + calls = {} + + def fake_pack(self, nm): + calls["called"] = True + calls["name"] = nm.full_name + nm.state["wq"] = nm.module.weight.detach().clone() + + processor.pack_module = types.MethodType(fake_pack, processor) + + processor.process(named) + + processor.submodule_finalize(named, gptq_model) + + layer_state = processor._get_layer_state(0) + + assert calls.get("called") is True + assert calls.get("name") == "linear" + assert layer_state.quantized is True + assert "wq" in named.state + +######### test_fallback_strategies.py ############# + + +def _fallback_quantize( + weights: torch.Tensor, + group_size: int, + strategy: FallbackStrategy, + *, + bits: int = 4, + sym: bool = False, + smooth=None, +) -> torch.Tensor: + module = torch.nn.Linear(weights.shape[1], weights.shape[0], bias=False) + module = module.to(device=weights.device, dtype=weights.dtype) + with torch.no_grad(): + module.weight.copy_(weights) + + qcfg = QuantizeConfig( + bits=bits, + group_size=group_size, + sym=sym, + fallback=Fallback(strategy=strategy, smooth=smooth), + offload_to_disk=False, + ) + gptq = GPTQ(module=module, qcfg=qcfg) + gptq.quantizer.configure(perchannel=True) + dequant, *_ = gptq._fallback_quantize(strategy, blocksize=group_size) + return dequant + + +def _scenarios(): + torch.manual_seed(0) + base = torch.randn(8, 32) + return { + "centered_pos": base * 0.1 + 2.5, # cluster away from zero + "zero_centered": base * 0.1, # symmetric around zero + "pos_lean": base * 0.05 + 0.5, # small positive leaning + "neg_lean": base * 0.05 - 0.5, # small negative leaning + "wide_range": base * 2.5, # wider spread / outliers + "uniform": torch.linspace(-1, 1, 32).repeat(8, 1), # evenly spread + } + + +def _assert_fallback_bounds( + label: str, + weights: torch.Tensor, + group_size: int, + rtn_err: float, + midpoint_err: float, + mean_err: float, + median_err: float, + std_err: float, + asym_err: float, +) -> None: + for name, err in ( + ("rtn", rtn_err), + ("midpoint", midpoint_err), + ("mean", mean_err), + ("median", median_err), + ("stdclip", std_err), + ("asym", asym_err), + ): + assert math.isfinite(err), f"{label}, group={group_size}: {name}_err is not finite ({err})" + + min_val = weights.min().item() + max_val = weights.max().item() + one_sided = min_val > 0.0 or max_val < 0.0 + + if one_sided: + floor = rtn_err * 5.0 + assert midpoint_err >= floor, f"{label}, group={group_size}: midpoint_err={midpoint_err}, rtn_err={rtn_err}" + assert mean_err >= floor, f"{label}, group={group_size}: mean_err={mean_err}, rtn_err={rtn_err}" + assert median_err >= floor, f"{label}, group={group_size}: median_err={median_err}, rtn_err={rtn_err}" + assert std_err >= floor, f"{label}, group={group_size}: std_err={std_err}, rtn_err={rtn_err}" + assert asym_err >= floor, f"{label}, group={group_size}: asym_err={asym_err}, rtn_err={rtn_err}" + else: + ceiling = rtn_err * 3.0 + assert midpoint_err <= ceiling, f"{label}, group={group_size}: midpoint_err={midpoint_err}, rtn_err={rtn_err}" + assert mean_err <= ceiling, f"{label}, group={group_size}: mean_err={mean_err}, rtn_err={rtn_err}" + assert median_err <= ceiling, f"{label}, group={group_size}: median_err={median_err}, rtn_err={rtn_err}" + assert std_err <= ceiling, f"{label}, group={group_size}: std_err={std_err}, rtn_err={rtn_err}" + assert asym_err <= ceiling, f"{label}, group={group_size}: asym_err={asym_err}, rtn_err={rtn_err}" + + +def _assert_finite_errors(label: str, group_size: int, errors: dict) -> None: + for name, err in errors.items(): + assert math.isfinite(err), f"{label}, group={group_size}: {name} err is not finite ({err})" + + +def _assert_percentile_smoother_matches_reference(device: str) -> None: + torch.manual_seed(0) + block = torch.randn(64, 128, device=device, dtype=torch.float32) + + percentile = Fallback(smooth=SmoothPercentile(percentile=99.0)) + percentile_out, _ = smooth_block(block, percentile, group_size=128) + percentile_ref_threshold = torch.quantile(block.abs(), 0.99, dim=1, keepdim=True) + percentile_ref = torch.clamp(block, -percentile_ref_threshold, percentile_ref_threshold) + torch.testing.assert_close(percentile_out, percentile_ref, atol=1e-5, rtol=1e-5) + + asym = Fallback(smooth=SmoothPercentileAsymmetric(low=0.5, high=99.5)) + asym_out, _ = smooth_block(block, asym, group_size=128) + asym_lo = torch.quantile(block, 0.005, dim=1, keepdim=True) + asym_hi = torch.quantile(block, 0.995, dim=1, keepdim=True) + asym_ref = torch.max(torch.min(block, asym_hi), asym_lo) + torch.testing.assert_close(asym_out, asym_ref, atol=1e-5, rtol=1e-5) + + log = Fallback(smooth=SmoothLog(percentile=99.0, mu=8.0)) + log_out, _ = smooth_block(block, log, group_size=128) + log_mu = math.log1p(8.0) + log_vals = torch.log1p(block.abs() * 8.0) / log_mu + log_threshold = torch.quantile(log_vals, 0.99, dim=1, keepdim=True) + log_linear_threshold = (torch.exp(log_threshold * log_mu) - 1.0) / 8.0 + log_ref = torch.clamp(block, -log_linear_threshold, log_linear_threshold) + torch.testing.assert_close(log_out, log_ref, atol=1e-5, rtol=1e-5) + + +def test_percentile_smoothers_match_quantile_reference_cpu(): + _assert_percentile_smoother_matches_reference("cpu") + + +@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") +def test_percentile_smoothers_match_quantile_reference_cuda(): + _assert_percentile_smoother_matches_reference("cuda") + + +def test_midpoint_vs_rtn_across_distributions(): + scenarios = _scenarios() + rows = _collect_synthetic_rows(scenarios) + for scenario_name, group_size, rtn_err, midpoint_err, mean_err, median_err, std_err, asym_err in rows: + weights = scenarios[scenario_name] + _assert_fallback_bounds( + scenario_name, + weights, + group_size, + rtn_err, + midpoint_err, + mean_err, + median_err, + std_err, + asym_err, + ) + + +def _load_weight_slice(model_dir: str, tensor_name: str, *, max_rows: int = 256, max_cols: int = 256) -> torch.Tensor: + from safetensors import safe_open + + shards = sorted(glob(os.path.join(model_dir, "model-*.safetensors"))) + if not shards: + raise FileNotFoundError(f"No safetensor shards found under {model_dir}") + + for shard_path in shards: + with safe_open(shard_path, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + tensor = f.get_tensor(tensor_name) + return tensor[:max_rows, :max_cols].clone() + raise FileNotFoundError(f"Tensor `{tensor_name}` not found in {model_dir}") + + +def test_midpoint_vs_rtn_on_qwen3_real_weights(): + model_dir = "/monster/data/model/Qwen3-30B-A3B" + if not os.path.isdir(model_dir): + import pytest + pytest.skip(f"Model path missing: {model_dir}") + + targets = [ + "model.layers.0.mlp.experts.10.up_proj.weight", + "model.layers.0.mlp.experts.10.down_proj.weight", + "model.layers.0.mlp.experts.10.gate_proj.weight", + ] + group_sizes = (32, 64, 128) + rows = [] + rows_mad = [] + + for name in targets: + try: + w = _load_weight_slice(model_dir, name, max_rows=256, max_cols=256) + except FileNotFoundError: + import pytest + pytest.skip(f"Tensor `{name}` not found in model shards at {model_dir}") + + for group_size in group_sizes: + rtn = _fallback_quantize(w, group_size, FallbackStrategy.RTN) + mid = _fallback_quantize(w, group_size, FallbackStrategy.MIDPOINT) + mean_c = _fallback_quantize(w, group_size, FallbackStrategy.MEAN) + median_c = _fallback_quantize(w, group_size, FallbackStrategy.MEDIAN) + std_c = _fallback_quantize(w, group_size, FallbackStrategy.STDCLIP) + asym_c = _fallback_quantize( + w, + group_size, + FallbackStrategy.MIDPOINT, + smooth=SmoothPercentileAsymmetric(low=0.5, high=99.5), + ) + + rtn_err = torch.mean((w - rtn).abs()).item() + mid_err = torch.mean((w - mid).abs()).item() + mean_err = torch.mean((w - mean_c).abs()).item() + median_err = torch.mean((w - median_c).abs()).item() + std_err = torch.mean((w - std_c).abs()).item() + asym_err = torch.mean((w - asym_c).abs()).item() + _assert_fallback_bounds( + name, + w, + group_size, + rtn_err, + mid_err, + mean_err, + median_err, + std_err, + asym_err, + ) + rows.append((name, group_size, rtn_err, mid_err, mean_err, median_err, std_err, asym_err)) + + rtn_mad = _fallback_quantize(w, group_size, FallbackStrategy.RTN, smooth=SmoothMAD()) + mid_mad = _fallback_quantize(w, group_size, FallbackStrategy.MIDPOINT, smooth=SmoothMAD()) + mean_mad = _fallback_quantize(w, group_size, FallbackStrategy.MEAN, smooth=SmoothMAD()) + median_mad = _fallback_quantize(w, group_size, FallbackStrategy.MEDIAN, smooth=SmoothMAD()) + std_mad = _fallback_quantize(w, group_size, FallbackStrategy.STDCLIP, smooth=SmoothMAD()) + + rtn_mad_err = torch.mean((w - rtn_mad).abs()).item() + mid_mad_err = torch.mean((w - mid_mad).abs()).item() + mean_mad_err = torch.mean((w - mean_mad).abs()).item() + median_mad_err = torch.mean((w - median_mad).abs()).item() + std_mad_err = torch.mean((w - std_mad).abs()).item() + _assert_finite_errors( + f"{name} (mad)", + group_size, + { + "rtn_mad": rtn_mad_err, + "mid_mad": mid_mad_err, + "mean_mad": mean_mad_err, + "median_mad": median_mad_err, + "stdclip_mad": std_mad_err, + }, + ) + rows_mad.append( + ( + name, + group_size, + rtn_mad_err, + mid_mad_err, + mean_mad_err, + median_mad_err, + std_mad_err, + ) + ) + + scenarios = _scenarios() + synthetic_rows = _collect_synthetic_rows(scenarios) + combined = [("synthetic:" + s, gs, re, me, mne, mde, se, ae) for s, gs, re, me, mne, mde, se, ae in synthetic_rows] + combined += [("real:" + m, gs, re, me, mne, mde, se, ae) for m, gs, re, me, mne, mde, se, ae in rows] + native_map = { + (name, group_size): { + "rtn": rtn_err, + "mid": mid_err, + "mean": mean_err, + "median": median_err, + "std": std_err, + } + for name, group_size, rtn_err, mid_err, mean_err, median_err, std_err, _ in rows + } + + header = "+-------------------------------+------------+---------+--------------+--------------+--------------+--------------+--------------+" + try: + from logbar import LogBar + + cols = LogBar.shared().columns( + cols=[ + {"label": "case", "width": "fit"}, + {"label": "group_size", "width": "fit"}, + {"label": "rtn_err", "width": "fit"}, + {"label": "mid_err", "width": "fit"}, + {"label": "mean_err", "width": "fit"}, + {"label": "median_err", "width": "fit"}, + {"label": "stdclip_err", "width": "fit"}, + {"label": "asym_err", "width": "fit"}, + ], + padding=1, + ) + cols.info.header() + for label, gs, rtn_err, mid_err, mean_err, median_err, std_err, asym_err in combined: + errors = { + "rtn": rtn_err, + "mid": mid_err, + "mean": mean_err, + "median": median_err, + "std": std_err, + "asym": asym_err, + } + sorted_methods = sorted(errors.items(), key=lambda kv: kv[1]) + palette = ["\033[32m", "\033[33m", "\033[35m", "\033[34m", "\033[36m", "\033[31m"] + color_map = {name: palette[min(idx, len(palette) - 1)] for idx, (name, _) in enumerate(sorted_methods)} + reset = "\033[0m" + cols.info( + label, + str(gs), + f"{color_map['rtn']}{rtn_err:.5f}{reset}", + f"{color_map['mid']}{mid_err:.5f}{reset}", + f"{color_map['mean']}{mean_err:.5f}{reset}", + f"{color_map['median']}{median_err:.5f}{reset}", + f"{color_map['std']}{std_err:.5f}{reset}", + f"{color_map['asym']}{asym_err:.5f}{reset}", + ) + cols.info.header() + except Exception: + print(header) + print("| case | group_size | rtn_err | midpoint_err | mean_err | median_err | stdclip_err | asym_err |") + print(header) + for label, gs, rtn_err, mid_err, mean_err, median_err, std_err, asym_err in combined: + print(f"| {label:29} | {gs:10d} | {rtn_err:7.5f} | {mid_err:12.5f} | {mean_err:12.5f} | {median_err:12.5f} | {std_err:12.5f} | {asym_err:12.5f} |") + print(header) + + if rows_mad: + mad_header = ( + "+-------------------------------+------------+-------------+-------------+-------------+-------------+" + "-------------+-------------+-------------+-------------+-------------+-------------+-------------+" + ) + try: + from logbar import LogBar + + cols = LogBar.shared().columns( + cols=[ + {"label": "case (mad)", "width": "fit"}, + {"label": "group_size", "width": "fit"}, + {"label": "rtn_mad", "width": "fit"}, + {"label": "rtn_vs", "width": "fit"}, + {"label": "mid_mad", "width": "fit"}, + {"label": "mid_vs", "width": "fit"}, + {"label": "mean_mad", "width": "fit"}, + {"label": "mean_vs", "width": "fit"}, + {"label": "median_mad", "width": "fit"}, + {"label": "median_vs", "width": "fit"}, + {"label": "stdclip_mad", "width": "fit"}, + {"label": "stdclip_vs", "width": "fit"}, + ], + padding=1, + ) + cols.info.header() + for label, gs, rtn_err, mid_err, mean_err, median_err, std_err in rows_mad: + errors = { + "rtn": rtn_err, + "mid": mid_err, + "mean": mean_err, + "median": median_err, + "std": std_err, + } + sorted_methods = sorted(errors.items(), key=lambda kv: kv[1]) + palette = ["\033[32m", "\033[33m", "\033[35m", "\033[34m", "\033[36m", "\033[31m"] + color_map = {name: palette[min(idx, len(palette) - 1)] for idx, (name, _) in enumerate(sorted_methods)} + reset = "\033[0m" + native = native_map.get((label, gs), {}) + deltas = { + "rtn": rtn_err - native.get("rtn", rtn_err), + "mid": mid_err - native.get("mid", mid_err), + "mean": mean_err - native.get("mean", mean_err), + "median": median_err - native.get("median", median_err), + "std": std_err - native.get("std", std_err), + } + def _color_delta(value: float) -> str: + if value > 0: + return f"\033[31m{value:+.5f}\033[0m" + if value < 0: + return f"\033[32m{value:+.5f}\033[0m" + return f"{value:+.5f}" + + cols.info( + f"mad:{label}", + str(gs), + f"{color_map['rtn']}{rtn_err:.5f}{reset}", + _color_delta(deltas["rtn"]), + f"{color_map['mid']}{mid_err:.5f}{reset}", + _color_delta(deltas["mid"]), + f"{color_map['mean']}{mean_err:.5f}{reset}", + _color_delta(deltas["mean"]), + f"{color_map['median']}{median_err:.5f}{reset}", + _color_delta(deltas["median"]), + f"{color_map['std']}{std_err:.5f}{reset}", + _color_delta(deltas["std"]), + ) + cols.info.header() + except Exception: + print(mad_header) + print( + "| case (mad) | group_size | rtn_mad | rtn_vs | mid_mad | mid_vs | mean_mad | mean_vs |" + " median_mad | median_vs | stdclip_mad | stdclip_vs |" + ) + print(mad_header) + for label, gs, rtn_err, mid_err, mean_err, median_err, std_err in rows_mad: + native = native_map.get((label, gs), {}) + deltas = { + "rtn": rtn_err - native.get("rtn", rtn_err), + "mid": mid_err - native.get("mid", mid_err), + "mean": mean_err - native.get("mean", mean_err), + "median": median_err - native.get("median", median_err), + "std": std_err - native.get("std", std_err), + } + def _color_delta_plain(value: float) -> str: + if value > 0: + return f"\033[31m{value:+.5f}\033[0m" + if value < 0: + return f"\033[32m{value:+.5f}\033[0m" + return f"{value:+.5f}" + + print( + f"| mad:{label:24} | {gs:10d} | {rtn_err:11.5f} | {_color_delta_plain(deltas['rtn']):11} |" + f" {mid_err:11.5f} | {_color_delta_plain(deltas['mid']):11} | {mean_err:11.5f} |" + f" {_color_delta_plain(deltas['mean']):11} | {median_err:11.5f} | {_color_delta_plain(deltas['median']):11} |" + f" {std_err:11.5f} | {_color_delta_plain(deltas['std']):11} |" + ) + print(mad_header) + + +def _collect_synthetic_rows(scenarios=None): + if scenarios is None: + scenarios = _scenarios() + rows = [] + for scenario_name, weights in scenarios.items(): + for group_size in (16, 32, 64, 128): + rtn = _fallback_quantize(weights, group_size, FallbackStrategy.RTN) + midpoint = _fallback_quantize(weights, group_size, FallbackStrategy.MIDPOINT) + mean_centered = _fallback_quantize(weights, group_size, FallbackStrategy.MEAN) + median_centered = _fallback_quantize(weights, group_size, FallbackStrategy.MEDIAN) + std_clip = _fallback_quantize(weights, group_size, FallbackStrategy.STDCLIP) + asym_clip = _fallback_quantize( + weights, + group_size, + FallbackStrategy.MIDPOINT, + smooth=SmoothPercentileAsymmetric(low=0.5, high=99.5), + ) + + rtn_err = torch.mean((weights - rtn).abs()).item() + midpoint_err = torch.mean((weights - midpoint).abs()).item() + mean_err = torch.mean((weights - mean_centered).abs()).item() + median_err = torch.mean((weights - median_centered).abs()).item() + std_err = torch.mean((weights - std_clip).abs()).item() + asym_err = torch.mean((weights - asym_clip).abs()).item() + rows.append((scenario_name, group_size, rtn_err, midpoint_err, mean_err, median_err, std_err, asym_err)) + return rows + + + +######### test_fallback_thresholds.py ############# + + +def test_should_use_fallback_parses_numeric_and_percent(): + assert should_use_fallback(True, observed_samples=0, expected_total_samples=100) + assert not should_use_fallback(True, observed_samples=1, expected_total_samples=100) + + assert should_use_fallback("10", observed_samples=5, expected_total_samples=100) + assert not should_use_fallback("10", observed_samples=11, expected_total_samples=100) + + assert should_use_fallback("10%", observed_samples=8, expected_total_samples=90) + assert should_use_fallback("10%", observed_samples=10, expected_total_samples=200) + + +def test_gptq_fallback_threshold_triggers_rtn_when_samples_below_percent(): + torch.manual_seed(0) + layer = nn.Linear(8, 8, bias=False) + + qcfg = QuantizeConfig(bits=4, group_size=4, fallback="75%") + gptq = GPTQ(layer, qcfg) + gptq.fallback = qcfg.fallback + gptq.expected_nsamples = 4 # pretend we expected 4 token rows + gptq.quantizer.configure(perchannel=True) + + # Capture only a single token worth of activations (< 75% of expected total) + inp = torch.randn(1, 1, 8) + gptq.add_batch(inp, None) + + Q, _, _, _, _, avg_loss, _, nsamples = gptq.quantize(blocksize=4) + + assert nsamples == 1 + assert avg_loss.startswith("fallback(rtn): ") + assert (Q - layer.weight.data).abs().mean().item() == pytest.approx(0.0120230, abs=1e-7) diff --git a/tests/test_format_conversion_flow.py b/tests/test_format_conversion_flow.py index 8f62d41dd..2d073f3ad 100644 --- a/tests/test_format_conversion_flow.py +++ b/tests/test_format_conversion_flow.py @@ -8,7 +8,8 @@ import torch -from gptqmodel.quantization import FORMAT, METHOD, QuantizeConfig +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.quantization import FORMAT, METHOD, GGUFConfig, QuantizeConfig, RTNConfig from gptqmodel.utils.model import pack_module @@ -26,8 +27,6 @@ def __init__(self): self.bits = 4 self.pack_dtype = torch.int32 - QUANT_TYPE = "gptq" - def to(self, *_args, **_kwargs): return self @@ -58,9 +57,8 @@ def qzero_format(self, format: int | None = None): def _make_quant_linear_cls(requires_v2: bool): return type( "DummyQuantLinear", - (), + (TorchLinear,), { - "QUANT_TYPE": "gptq", "REQUIRES_FORMAT_V2": requires_v2, }, ) @@ -76,6 +74,7 @@ def _run_pack(quant_cfg: QuantizeConfig, requires_v2: bool) -> int: lock = threading.Lock() quant_linear_cls = _make_quant_linear_cls(requires_v2=requires_v2) + assert issubclass(quant_linear_cls, TorchLinear) assert getattr(quant_linear_cls, "REQUIRES_FORMAT_V2") is requires_v2 with mock.patch("gptqmodel.utils.model.convert_gptq_v2_to_v1_format_module") as convert_mock: @@ -106,6 +105,24 @@ def test_pack_module_skips_for_non_gptq_method(): assert calls == 0 +def test_pack_module_skips_for_non_gptq_export_method(): + cfg = RTNConfig(bits=4, format=FORMAT.GEMM, offload_to_disk=False) + calls = _run_pack(cfg, requires_v2=True) + assert calls == 0 + + +def test_pack_module_converts_for_rtn_gptq_export_requires_v2(): + cfg = RTNConfig(bits=4, format=FORMAT.GPTQ, offload_to_disk=False) + calls = _run_pack(cfg, requires_v2=True) + assert calls == 1 + + +def test_pack_module_skips_for_rtn_gguf_export(): + cfg = GGUFConfig(bits=4, offload_to_disk=False) + calls = _run_pack(cfg, requires_v2=True) + assert calls == 0 + + def test_pack_module_skips_when_kernel_uses_v1(): cfg = QuantizeConfig(bits=4, quant_method=METHOD.GPTQ, format=FORMAT.GPTQ, offload_to_disk=False) calls = _run_pack(cfg, requires_v2=False) diff --git a/tests/test_fp4_llama3_fp4.py b/tests/test_fp4_llama3_fp4.py index 38cfd1e3b..49809f8f0 100644 --- a/tests/test_fp4_llama3_fp4.py +++ b/tests/test_fp4_llama3_fp4.py @@ -22,6 +22,13 @@ MODEL_DIR = Path("/monster/data/model/Llama-3.3-70B-Instruct-FP4") +def _nvfp4_to_dtype(nv_tensor, dtype: torch.dtype) -> torch.Tensor: + to_dtype = getattr(nv_tensor, "to_dtype", None) + if callable(to_dtype): + return to_dtype(dtype) + return nv_tensor.dequantize(dtype) + + @pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") @pytest.mark.skipif(not MODEL_DIR.exists(), reason="Llama-3.3 FP4 model not available") def test_fp4_llama3_module_dequant_matches_nvfp4_tensor(): @@ -35,7 +42,7 @@ def test_fp4_llama3_module_dequant_matches_nvfp4_tensor(): dequant = dequantize_f4_e2m1(weight, scale=scales, axis=None, target_dtype=torch.bfloat16) nv_tensor = NVFP4Tensor(weight, scales, block_size=16, orig_dtype=torch.bfloat16) - expected = nv_tensor.to_dtype(torch.bfloat16) + expected = _nvfp4_to_dtype(nv_tensor, torch.bfloat16) diff = torch.max(torch.abs(dequant - expected)).item() assert torch.allclose(dequant, expected, atol=1e-3, rtol=1e-3), diff diff --git a/tests/test_fp4_qwen3_nvfp4.py b/tests/test_fp4_qwen3_nvfp4.py new file mode 100644 index 000000000..2d7ab8833 --- /dev/null +++ b/tests/test_fp4_qwen3_nvfp4.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import json +from pathlib import Path + +import pytest +import torch +from safetensors import safe_open + +from gptqmodel.quantization.dtype import dequantize_f4_e2m1 + + +try: + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor +except Exception: + NVFP4Tensor = None + + +MODEL_DIR = Path("/mnt/SFS-6CFyUykx/models/Qwen3-8B-NVFP4") +WEIGHT_KEY = "model.layers.0.mlp.down_proj.weight" +SCALE_KEY = "model.layers.0.mlp.down_proj.weight_scale" + + +def _nvfp4_to_dtype(nv_tensor, dtype: torch.dtype) -> torch.Tensor: + to_dtype = getattr(nv_tensor, "to_dtype", None) + if callable(to_dtype): + return to_dtype(dtype) + return nv_tensor.dequantize(dtype) + + +def _load_first_layer_tensors() -> tuple[torch.Tensor, torch.Tensor]: + index = json.loads((MODEL_DIR / "model.safetensors.index.json").read_text()) + shard = index["weight_map"][WEIGHT_KEY] + + with safe_open(MODEL_DIR / shard, framework="pt", device="cpu") as f: + weight = f.get_tensor(WEIGHT_KEY) + scales = f.get_tensor(SCALE_KEY) + + return weight, scales + + +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not MODEL_DIR.exists(), reason="Qwen3 8B NVFP4 model not available") +def test_fp4_qwen3_module_dequant_matches_nvfp4_tensor(): + weight, scales = _load_first_layer_tensors() + + dequant = dequantize_f4_e2m1(weight, scale=scales, axis=None, target_dtype=torch.bfloat16) + + nv_tensor = NVFP4Tensor(weight, scales, block_size=16, orig_dtype=torch.bfloat16) + expected = _nvfp4_to_dtype(nv_tensor, torch.bfloat16) + + diff = torch.max(torch.abs(dequant - expected)).item() + assert torch.allclose(dequant, expected, atol=1e-3, rtol=1e-3), diff + + +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not MODEL_DIR.exists(), reason="Qwen3 8B NVFP4 model not available") +def test_fp4_qwen3_module_gpu_consistency(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + dev = torch.device("cuda:0") + weight, scales = _load_first_layer_tensors() + + cpu = dequantize_f4_e2m1(weight, scale=scales, axis=None, target_dtype=torch.bfloat16) + + torch.cuda.set_device(dev) + gpu_weight = weight.to(dev) + gpu_scales = scales.to(dev) + gpu = dequantize_f4_e2m1(gpu_weight, scale=gpu_scales, axis=None, target_dtype=torch.bfloat16) + + assert torch.allclose(cpu, gpu.cpu(), atol=1e-3, rtol=1e-3) diff --git a/tests/test_fp8.py b/tests/test_fp8.py new file mode 100644 index 000000000..1aa326fc9 --- /dev/null +++ b/tests/test_fp8.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import json + +import pytest +import torch +from safetensors.torch import save_file + +from gptqmodel.quantization.config import METHOD, FP8Config, QuantizeConfig +from gptqmodel.quantization.dtype import available_float8_dtype_names +from gptqmodel.utils.model_dequant import detect_format + + +@pytest.mark.parametrize("format_name", available_float8_dtype_names()) +def test_fp8_quantize_config_round_trip(format_name: str): + cfg = QuantizeConfig( + quant_method=METHOD.FP8, + format=format_name, + weight_scale_method="block", + weight_block_size=[128, 128], + ) + + assert isinstance(cfg, FP8Config) + assert cfg.uses_weight_only_lifecycle() is True + + payload = cfg.to_dict() + assert payload["method"] == METHOD.FP8 + assert payload["quant_method"] == METHOD.FP8 + assert payload["format"] == format_name + assert payload["checkpoint_format"] == format_name + assert payload["weight_scale_method"] == "block" + assert payload["weight_block_size"] == [128, 128] + assert payload["weight_scale_semantics"] == "inverse" + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded, FP8Config) + assert reloaded.format == format_name + assert reloaded.weight_scale_method == "block" + assert reloaded.weight_block_size == [128, 128] + + +@pytest.mark.parametrize("format_name", available_float8_dtype_names()) +def test_detect_format_identifies_fp8_from_checkpoint_or_config(tmp_path, format_name: str): + shard_path = tmp_path / "model.safetensors" + if format_name.endswith("fnuz"): + save_file({}, str(shard_path)) + else: + save_file( + { + "model.layers.0.self_attn.q_proj.weight": torch.zeros((16, 16), dtype=getattr(torch, format_name)), + "model.layers.0.self_attn.q_proj.weight_scale_inv": torch.ones((16,), dtype=torch.float32), + }, + str(shard_path), + ) + + config_path = tmp_path / "config.json" + config_path.write_text( + json.dumps( + { + "quantization_config": { + "method": "fp8", + "format": format_name, + "weight_scale_method": "row", + "weight_scale_semantics": "inverse", + } + } + ), + encoding="utf-8", + ) + + detected = detect_format(tmp_path, json.loads(config_path.read_text(encoding="utf-8"))) + assert detected == "fp8" diff --git a/tests/test_fp8_minimax2_test.py b/tests/test_fp8_minimax2_test.py index 7ce2be863..4d323ed23 100644 --- a/tests/test_fp8_minimax2_test.py +++ b/tests/test_fp8_minimax2_test.py @@ -68,7 +68,7 @@ def _run_quant_on_device(device_index: int) -> torch.device: ) processor.pb = _DummyProgressBar() - processor.preprocess(named, failsafe=False) + processor.preprocess(named, fallback=False) named.module.target_device = target processor.process(named) diff --git a/tests/test_gemma4_support.py b/tests/test_gemma4_support.py new file mode 100644 index 000000000..2092ba1d6 --- /dev/null +++ b/tests/test_gemma4_support.py @@ -0,0 +1,184 @@ +from types import SimpleNamespace + +import pytest +import torch +from torch import nn +from transformers import AutoConfig + +from gptqmodel.models import auto +from gptqmodel.models.definitions.gemma4 import Gemma4ForConditionalGenerationGPTQ, Gemma4TextQModel + + +GEMMA4_VARIANTS = [ + "/monster/data/model/gemma-4-E2B", + "/monster/data/model/gemma-4-E4B-it", + "/monster/data/model/gemma-4-31B-it", +] + + +@pytest.mark.parametrize("model_path", GEMMA4_VARIANTS) +def test_gemma4_local_variants_select_multimodal_definition(model_path): + config = AutoConfig.from_pretrained(model_path) + + assert config.model_type == "gemma4" + assert auto.check_and_get_model_definition(model_path) is Gemma4ForConditionalGenerationGPTQ + + +def test_gemma4_text_model_type_selects_text_definition(monkeypatch): + fake_config = SimpleNamespace(model_type="gemma4_text") + + monkeypatch.setattr(auto, "resolve_trust_remote_code", lambda path, trust_remote_code=False: trust_remote_code) + monkeypatch.setattr(auto.AutoConfig, "from_pretrained", lambda *args, **kwargs: fake_config) + + assert auto.check_and_get_model_definition("/tmp/gemma4-text") is Gemma4TextQModel + + +def test_gemma4_module_tree_keeps_optional_variant_paths_non_strict(): + layer_modules = Gemma4TextQModel.simple_layer_modules( + model_config=SimpleNamespace(), + quantize_config=SimpleNamespace(dynamic=None), + ) + flat_modules = {name for block in layer_modules for name in block} + + assert Gemma4TextQModel.layer_modules_strict is False + assert "self_attn.q_proj" in flat_modules + assert "self_attn.k_proj" in flat_modules + assert "self_attn.v_proj" in flat_modules + assert "self_attn.o_proj" in flat_modules + assert "per_layer_input_gate" in flat_modules + assert "per_layer_projection" in flat_modules + + +def test_gemma4_multimodal_base_modules_include_per_layer_helpers(): + class _LanguageModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([nn.Identity()]) + self.embed_tokens = nn.Embedding(4, 4) + self.embed_tokens_per_layer = nn.Embedding(4, 4) + self.per_layer_model_projection = nn.Linear(4, 4, bias=False) + self.per_layer_projection_norm = nn.LayerNorm(4) + self.norm = nn.LayerNorm(4) + self.rotary_emb = nn.Identity() + + class _Gemma4Core(nn.Module): + def __init__(self): + super().__init__() + self.language_model = _LanguageModel() + self.vision_tower = nn.Identity() + self.embed_vision = nn.Identity() + self.audio_tower = nn.Identity() + self.embed_audio = nn.Identity() + + class _Gemma4Wrapper(nn.Module): + def __init__(self): + super().__init__() + self.model = _Gemma4Core() + self.lm_head = nn.Linear(4, 4, bias=False) + + model = _Gemma4Wrapper() + base_modules = set(Gemma4ForConditionalGenerationGPTQ.get_base_modules(model)) + + assert Gemma4ForConditionalGenerationGPTQ.extract_layers_node() == ["model.language_model.layers"] + assert "model.vision_tower" in base_modules + assert "model.embed_vision" in base_modules + assert "model.audio_tower" in base_modules + assert "model.embed_audio" in base_modules + assert "model.language_model.embed_tokens" in base_modules + assert "model.language_model.embed_tokens_per_layer" in base_modules + assert "model.language_model.per_layer_model_projection" in base_modules + assert "model.language_model.per_layer_projection_norm" in base_modules + + +def test_gemma4_capture_preserves_per_layer_input(): + model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ) + hidden_states = torch.randn(1, 4, 8) + per_layer_input = torch.randn(1, 4, 2) + + captured = model_def.capture_first_layer_positional_inputs( + args=(hidden_states, per_layer_input), + kwargs={}, + batch_device=torch.device("cpu"), + ) + + assert len(captured) == 2 + assert torch.equal(captured[0], hidden_states) + assert torch.equal(captured[1], per_layer_input) + + +def test_gemma4_replay_kwargs_refresh_position_embeddings(): + class _FakeRotary(nn.Module): + def forward(self, x, position_ids, layer_type=None): + marker = 7.0 if layer_type == "full_attention" else 3.0 + shape = (position_ids.shape[0], position_ids.shape[1], 1) + value = torch.full(shape, marker, dtype=x.dtype, device=x.device) + return value, value + 1 + + class _LanguageModel(nn.Module): + def __init__(self): + super().__init__() + self.rotary_emb = _FakeRotary() + + class _Gemma4Core(nn.Module): + def __init__(self): + super().__init__() + self.language_model = _LanguageModel() + + class _Gemma4Wrapper(nn.Module): + def __init__(self): + super().__init__() + self.model = _Gemma4Core() + + model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ) + nn.Module.__init__(model_def) + model_def.model = _Gemma4Wrapper() + + layer = SimpleNamespace(self_attn=SimpleNamespace(layer_type="full_attention")) + hidden_states = torch.randn(1, 4, 8) + refreshed = model_def.prepare_layer_replay_kwargs( + layer=layer, + layer_input=[hidden_states], + additional_inputs={ + "position_ids": torch.arange(4).unsqueeze(0), + "position_embeddings": ("stale",), + }, + target_device=torch.device("cpu"), + ) + + cos, sin = refreshed["position_embeddings"] + assert cos.shape == (1, 4, 1) + assert sin.shape == (1, 4, 1) + assert torch.all(cos == 7) + assert torch.all(sin == 8) + + +def test_gemma4_capture_kwargs_preserve_all_per_layer_inputs(): + class _LanguageModel(nn.Module): + def __init__(self): + super().__init__() + self.rotary_emb = nn.Identity() + self._gptqmodel_cached_all_per_layer_inputs = torch.randn(1, 4, 3, 2) + + class _Gemma4Core(nn.Module): + def __init__(self): + super().__init__() + self.language_model = _LanguageModel() + + class _Gemma4Wrapper(nn.Module): + def __init__(self): + super().__init__() + self.model = _Gemma4Core() + + model_def = object.__new__(Gemma4ForConditionalGenerationGPTQ) + nn.Module.__init__(model_def) + model_def.model = _Gemma4Wrapper() + + captured = model_def.capture_first_layer_input_kwargs( + args=(), + kwargs={}, + batch_device=torch.device("cpu"), + layer_input_kwargs={}, + ) + + assert "__gptqmodel_gemma4_all_per_layer_inputs" in captured + assert captured["__gptqmodel_gemma4_all_per_layer_inputs"].shape == (1, 4, 3, 2) diff --git a/tests/test_generate_attention_mask.py b/tests/test_generate_attention_mask.py new file mode 100644 index 000000000..d204de1df --- /dev/null +++ b/tests/test_generate_attention_mask.py @@ -0,0 +1,81 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from gptqmodel.models.base import BaseQModel + + +class _RecorderModel: + def __init__(self): + self.device = torch.device("cpu") + self.last_kwargs = None + + def generate(self, *args, **kwargs): + self.last_kwargs = kwargs + return kwargs["attention_mask"] + + +def test_base_qmodel_generate_normalizes_causal_attention_mask(): + qmodel = BaseQModel.__new__(BaseQModel) + qmodel.model = _RecorderModel() + qmodel.tokenizer = None + + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + attention_mask = torch.tril(torch.ones((1, 1, 3, 3), dtype=torch.long)) + + normalized = qmodel.generate(input_ids=input_ids, attention_mask=attention_mask) + + assert normalized.shape == (1, 3) + assert normalized.dtype is torch.bool + assert normalized.tolist() == [[True, True, True]] + assert qmodel.model.last_kwargs["attention_mask"].shape == (1, 3) + + +def test_base_qmodel_generate_uses_instance_runtime_for_string_prompts(): + qmodel = BaseQModel.__new__(BaseQModel) + qmodel.model = object() + qmodel.tokenizer = None + + captured = {} + + def _runtime_generate(model, **kwargs): + captured["model"] = model + captured["kwargs"] = kwargs + return "runtime-output" + + qmodel._runtime_generate = _runtime_generate + + output = qmodel.generate("hello", temperature=0.3) + + assert output == "runtime-output" + assert captured["model"] is qmodel.model + assert captured["kwargs"]["prompts"] == "hello" + assert captured["kwargs"]["temperature"] == 0.3 + + +def test_base_qmodel_generate_normalizes_attention_mask_for_instance_runtime(): + qmodel = BaseQModel.__new__(BaseQModel) + qmodel.model = object() + qmodel.tokenizer = None + + captured = {} + + def _runtime_generate(model, **kwargs): + captured["model"] = model + captured["kwargs"] = kwargs + return kwargs["attention_mask"] + + qmodel._runtime_generate = _runtime_generate + + input_ids = torch.tensor([[1, 2, 3]], dtype=torch.long) + attention_mask = torch.tril(torch.ones((1, 1, 3, 3), dtype=torch.long)) + + normalized = qmodel.generate(input_ids=input_ids, attention_mask=attention_mask) + + assert normalized.shape == (1, 3) + assert normalized.dtype is torch.bool + assert normalized.tolist() == [[True, True, True]] + assert captured["model"] is qmodel.model + assert captured["kwargs"]["input_ids"] is input_ids + assert captured["kwargs"]["attention_mask"].shape == (1, 3) diff --git a/tests/test_gguf_qlinear_llama.py b/tests/test_gguf_qlinear_llama.py new file mode 100644 index 000000000..a46c91a05 --- /dev/null +++ b/tests/test_gguf_qlinear_llama.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import contextlib +import copy +import io +import os +from pathlib import Path + +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer +from transformers.utils import logging as hf_logging + +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear + + +MODEL_ID = Path("/monster/data/model/Llama-3.2-1B-Instruct") +PROMPT = "The capital city of France is Paris. The capital city of Germany is" +LAYER0_MODULES = ( + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + "self_attn.o_proj", +) +DIRECT_Q4_K_M_LIMITS = { + "self_attn.q_proj": { + "weight_mae": 0.0026, + "weight_max": 0.035, + "output_mae": 0.022, + "output_max": 0.21, + }, + "self_attn.k_proj": { + "weight_mae": 0.0034, + "weight_max": 0.031, + "output_mae": 0.029, + "output_max": 0.20, + }, + "self_attn.v_proj": { + "weight_mae": 0.0008, + "weight_max": 0.004, + "output_mae": 0.0065, + "output_max": 0.032, + }, + "self_attn.o_proj": { + "weight_mae": 0.0010, + "weight_max": 0.020, + "output_mae": 0.0010, + "output_max": 0.011, + }, +} + + +def _error_stats(reference: torch.Tensor, candidate: torch.Tensor) -> dict[str, float]: + diff = (candidate - reference).abs() + return { + "mae": diff.mean().item(), + "max": diff.max().item(), + } + + +def _skip_unavailable_environment(dtype: torch.dtype) -> torch.device: + if not MODEL_ID.exists(): + pytest.skip(f"Missing local test model: {MODEL_ID}") + if not torch.cuda.is_available(): + pytest.skip("Direct GGUF Llama layer test requires CUDA") + if dtype == torch.float16 and not torch.cuda.is_available(): + pytest.skip("float16 path requires CUDA") + return torch.device("cuda:0") + + +def _load_llama(dtype: torch.dtype): + os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") + hf_logging.disable_progress_bar() + hf_logging.set_verbosity_error() + + sink = io.StringIO() + with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink): + model = AutoModelForCausalLM.from_pretrained( + str(MODEL_ID), + dtype=dtype, + device_map="cuda:0", + ).eval() + tokenizer = AutoTokenizer.from_pretrained(str(MODEL_ID), use_fast=True) + + return model, tokenizer + + +def _capture_layer0_module_io(model, tokenizer) -> dict[str, dict[str, torch.Tensor | None]]: + layer0 = model.model.layers[0] + captured: dict[str, dict[str, torch.Tensor | None]] = {} + handles = [] + + for module_name in LAYER0_MODULES: + module = dict(layer0.named_modules())[module_name] + + def _hook(mod, args, out, *, module_name=module_name): + captured[module_name] = { + "input": args[0].detach().cpu(), + "output": out.detach().cpu(), + "weight": mod.weight.detach().cpu(), + "bias": None if mod.bias is None else mod.bias.detach().cpu(), + } + + handles.append(module.register_forward_hook(_hook)) + + device = next(model.parameters()).device + inputs = tokenizer(PROMPT, return_tensors="pt").to(device) + with torch.inference_mode(): + model(**inputs) + + for handle in handles: + handle.remove() + + return captured + + +def _build_direct_gguf_module(native_module: torch.nn.Linear) -> GGUFTorchLinear: + module = GGUFTorchLinear( + bits="q4_k_m", + group_size=-1, + sym=True, + desc_act=False, + in_features=native_module.in_features, + out_features=native_module.out_features, + bias=native_module.bias is not None, + register_buffers=False, + ) + module.pack_original(linear=native_module, scales=None, zeros=None, g_idx=None) + module.post_init() + return module + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"]) +def test_direct_gguf_q4_k_m_llama3_2_layer0_attention_stays_close_to_native(dtype: torch.dtype): + device = _skip_unavailable_environment(dtype) + model, tokenizer = _load_llama(dtype) + + try: + captured = _capture_layer0_module_io(model, tokenizer) + layer0 = model.model.layers[0] + + for module_name in LAYER0_MODULES: + native_module = copy.deepcopy(dict(layer0.named_modules())[module_name]).to(device=device, dtype=dtype).eval() + gguf_module = _build_direct_gguf_module(native_module).to(device).eval() + record = captured[module_name] + + module_input = record["input"].to(device=device, dtype=dtype) + with torch.inference_mode(): + gguf_output = gguf_module(module_input).detach().cpu() + + dequant_weight = gguf_module.dequantize_weight().T.detach().cpu().to(record["weight"].dtype) + weight_stats = _error_stats(record["weight"].to(torch.float32), dequant_weight.to(torch.float32)) + output_stats = _error_stats(record["output"].to(torch.float32), gguf_output.to(torch.float32)) + limits = DIRECT_Q4_K_M_LIMITS[module_name] + + assert weight_stats["mae"] < limits["weight_mae"], f"{dtype}: {module_name} weight MAE {weight_stats['mae']:.6f}" + assert weight_stats["max"] < limits["weight_max"], f"{dtype}: {module_name} weight max {weight_stats['max']:.6f}" + assert output_stats["mae"] < limits["output_mae"], f"{dtype}: {module_name} output MAE {output_stats['mae']:.6f}" + assert output_stats["max"] < limits["output_max"], f"{dtype}: {module_name} output max {output_stats['max']:.6f}" + finally: + del model + torch.cuda.empty_cache() diff --git a/tests/test_glm_moe_dsa_support.py b/tests/test_glm_moe_dsa_support.py new file mode 100644 index 000000000..57b10db79 --- /dev/null +++ b/tests/test_glm_moe_dsa_support.py @@ -0,0 +1,206 @@ +import json +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +from defuser import convert_model +from safetensors.torch import save_file +from transformers.models.glm_moe_dsa.configuration_glm_moe_dsa import GlmMoeDsaConfig +from transformers.models.glm_moe_dsa.modeling_glm_moe_dsa import GlmMoeDsaForCausalLM + +from gptqmodel.models import auto +from gptqmodel.models.definitions.glm_moe_dsa import GlmMoeDsaQModel +from gptqmodel.utils.structure import LazyTurtle, alias_from_turtle_for_submodule + + +_UPSTREAM_GLM5_MODELING_SIGNATURE = { + "architectures": ["GlmMoeDsaForCausalLM"], + "first_k_dense_replace": 3, + "hidden_act": "silu", + "hidden_size": 6144, + "index_head_dim": 128, + "index_n_heads": 32, + "index_topk": 2048, + "intermediate_size": 12288, + "kv_lora_rank": 512, + "max_position_embeddings": 202752, + "model_type": "glm_moe_dsa", + "moe_intermediate_size": 2048, + "n_routed_experts": 256, + "n_shared_experts": 1, + "num_attention_heads": 64, + "num_experts_per_tok": 8, + "num_hidden_layers": 78, + "num_key_value_heads": 64, + "q_lora_rank": 2048, + "qk_nope_head_dim": 192, + "qk_rope_head_dim": 64, + "rope_parameters": {"rope_theta": 1000000, "rope_type": "default"}, + "v_head_dim": 256, +} + +_UPSTREAM_GLM5_1_MODELING_SIGNATURE = { + "architectures": ["GlmMoeDsaForCausalLM"], + "first_k_dense_replace": 3, + "hidden_act": "silu", + "hidden_size": 6144, + "index_head_dim": 128, + "index_n_heads": 32, + "index_topk": 2048, + "intermediate_size": 12288, + "kv_lora_rank": 512, + "max_position_embeddings": 202752, + "model_type": "glm_moe_dsa", + "moe_intermediate_size": 2048, + "n_routed_experts": 256, + "n_shared_experts": 1, + "num_attention_heads": 64, + "num_experts_per_tok": 8, + "num_hidden_layers": 78, + "num_key_value_heads": 64, + "q_lora_rank": 2048, + "qk_nope_head_dim": 192, + "qk_rope_head_dim": 64, + "rope_parameters": {"rope_theta": 1000000, "rope_type": "default"}, + "v_head_dim": 256, +} + + +def _tiny_glm_moe_dsa_config(num_hidden_layers: int = 4) -> GlmMoeDsaConfig: + return GlmMoeDsaConfig( + num_hidden_layers=num_hidden_layers, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + num_attention_heads=4, + num_key_value_heads=4, + q_lora_rank=16, + kv_lora_rank=8, + qk_rope_head_dim=8, + qk_nope_head_dim=8, + v_head_dim=16, + index_head_dim=8, + index_n_heads=2, + index_topk=16, + n_routed_experts=4, + n_shared_experts=1, + num_experts_per_tok=2, + vocab_size=128, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + +def _build_lazy_turtle(tmp_path: Path, model: GlmMoeDsaForCausalLM) -> LazyTurtle: + # Persist a tiny real GLM checkpoint so the test exercises the checkpoint-backed lazy path. + model_dir = tmp_path / "glm_source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + state_dict = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()} + save_file(state_dict, str(model_dir / shard_name)) + (model_dir / "model.safetensors.index.json").write_text( + json.dumps({"weight_map": dict.fromkeys(state_dict, shard_name)}), + encoding="utf-8", + ) + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + return source + + +@pytest.mark.parametrize("model_path", ["/tmp/glm-5", "/tmp/glm-5.1"]) +def test_glm_moe_dsa_model_type_selects_definition_for_glm5_variants(monkeypatch, model_path): + fake_config = SimpleNamespace(model_type="glm_moe_dsa") + + monkeypatch.setattr(auto, "resolve_trust_remote_code", lambda path, trust_remote_code=False: trust_remote_code) + monkeypatch.setattr(auto.AutoConfig, "from_pretrained", lambda *args, **kwargs: fake_config) + + assert auto.check_and_get_model_definition(model_path) is GlmMoeDsaQModel + + +def test_glm5_and_glm5_1_share_same_upstream_modeling_signature(): + # Snapshot from the current upstream config.json files fetched on 2026-04-07. + assert _UPSTREAM_GLM5_MODELING_SIGNATURE == _UPSTREAM_GLM5_1_MODELING_SIGNATURE + + +def test_glm_moe_dsa_module_tree_expands_dense_and_sparse_paths(): + layer_modules = GlmMoeDsaQModel.simple_layer_modules( + model_config=_tiny_glm_moe_dsa_config(), + quantize_config=SimpleNamespace(dynamic=None), + ) + flat_modules = {name for block in layer_modules for name in block} + + assert GlmMoeDsaQModel.layer_modules_strict is False + assert "self_attn.q_a_proj" in flat_modules + assert "self_attn.kv_a_proj_with_mqa" in flat_modules + assert "self_attn.indexer.wk" in flat_modules + assert "self_attn.indexer.wq_b" in flat_modules + assert "mlp.gate_proj" in flat_modules + assert "mlp.experts.0.gate_proj" in flat_modules + assert "mlp.experts.0.up_proj" in flat_modules + assert "mlp.experts.0.down_proj" in flat_modules + assert "mlp.shared_experts.gate_proj" in flat_modules + assert "mlp.shared_experts.up_proj" in flat_modules + assert "mlp.shared_experts.down_proj" in flat_modules + + +def test_glm_moe_dsa_tiny_model_matches_definition(): + model = GlmMoeDsaForCausalLM(_tiny_glm_moe_dsa_config()) + convert_model(model, cleanup_original=False) + + dense_layer = model.model.layers[0] + moe_layer = model.model.layers[3] + + assert hasattr(dense_layer.self_attn, "q_a_proj") + assert hasattr(dense_layer.self_attn, "kv_a_proj_with_mqa") + assert hasattr(dense_layer.self_attn.indexer, "wk") + assert hasattr(dense_layer.self_attn.indexer, "wq_b") + assert hasattr(dense_layer.mlp, "gate_proj") + assert hasattr(dense_layer.mlp, "up_proj") + assert hasattr(dense_layer.mlp, "down_proj") + assert not hasattr(dense_layer.mlp, "experts") + + assert hasattr(moe_layer.mlp, "gate") + assert hasattr(moe_layer.mlp, "experts") + assert hasattr(moe_layer.mlp, "shared_experts") + assert len([name for name, _ in moe_layer.mlp.experts.named_children() if name.isdigit()]) == model.config.n_routed_experts + + expert0 = getattr(moe_layer.mlp.experts, "0") + assert hasattr(expert0, "gate_proj") + assert hasattr(expert0, "up_proj") + assert hasattr(expert0, "down_proj") + + +def test_glm_moe_dsa_lazy_turtle_restores_rotary_buffers_from_module_init(tmp_path): + source_model = GlmMoeDsaForCausalLM(_tiny_glm_moe_dsa_config()) + convert_model(source_model, cleanup_original=False) + shell_model = GlmMoeDsaForCausalLM(_tiny_glm_moe_dsa_config()) + convert_model(shell_model, cleanup_original=False) + shell_model.load_state_dict(source_model.state_dict()) + + rotary = shell_model.model.rotary_emb + rotary.register_buffer("inv_freq", torch.empty_like(rotary.inv_freq, device="meta"), persistent=False) + rotary.register_buffer( + "original_inv_freq", + torch.empty_like(rotary.original_inv_freq, device="meta"), + persistent=False, + ) + + turtle = _build_lazy_turtle(tmp_path, source_model) + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=turtle, + target_submodule=shell_model.model.rotary_emb, + device=torch.device("cpu"), + ) + + rebuilt_rotary = shell_model.model.rotary_emb + assert hasattr(rebuilt_rotary, "inv_freq") + assert hasattr(rebuilt_rotary, "original_inv_freq") + torch.testing.assert_close(rebuilt_rotary.inv_freq, source_model.model.rotary_emb.inv_freq) + torch.testing.assert_close(rebuilt_rotary.original_inv_freq, source_model.model.rotary_emb.original_inv_freq) diff --git a/tests/test_gptaq.py b/tests/test_gptaq.py index f013e6f13..43053f9c9 100644 --- a/tests/test_gptaq.py +++ b/tests/test_gptaq.py @@ -5,13 +5,13 @@ from models.model_test import ModelTest -from gptqmodel.utils.eval import EVAL +from gptqmodel.quantization.config import GPTAQConfig class TestQwen2_5_GPTAQ(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.2739, "floor_pct": 0.2}, "acc_norm": {"value": 0.3055, "floor_pct": 0.2}, @@ -19,7 +19,7 @@ class TestQwen2_5_GPTAQ(ModelTest): } TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 - GPTQA = True + GPTAQ = GPTAQConfig() def test_qwen2_5(self): - self.quant_lm_eval() + self.quantize_and_evaluate() diff --git a/tests/test_gptq.py b/tests/test_gptq.py index f04824057..b9d5e0e56 100644 --- a/tests/test_gptq.py +++ b/tests/test_gptq.py @@ -85,6 +85,53 @@ def _run_batch(idx: int) -> None: return PathStats(per_batch_seconds=per_batch, total_seconds=total, peak_bytes=peak_bytes, batches_measured=measured) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for CPU fallback regression coverage") +def test_gptq_cpu_hessian_fallback_returns_quantized_weights_to_original_cuda_device(monkeypatch): + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + torch.manual_seed(0) + + layer = _make_module(hidden_dim=8, device=device) + qcfg = QuantizeConfig(bits=4, group_size=2, act_group_aware=True) + gptq = GPTQ(layer, qcfg=qcfg) + gptq.quantizer.configure(perchannel=True) + + inp = _generate_input(batch_size=1, seq_len=4, hidden_dim=8, device=device) + gptq.add_batch(inp, None) + + calls = {"cuda": 0, "cpu": 0} + + def _patched_hessian_inverse(self, hessian: torch.Tensor): + if hessian.device.type == "cuda": + calls["cuda"] += 1 + raise RuntimeError("CUDA out of memory. simulated for regression test") + + calls["cpu"] += 1 + identity = torch.eye(hessian.shape[0], dtype=torch.float32, device=hessian.device) + return identity, self.qcfg.damp_percent + + monkeypatch.setattr(GPTQ, "hessian_inverse", _patched_hessian_inverse) + log_messages = [] + + def _capture_warn(message, *args, **kwargs): + log_messages.append(message % args if args else message) + + def _capture_info(message, *args, **kwargs): + log_messages.append(message % args if args else message) + + monkeypatch.setattr(gptq_mod.log, "warn", _capture_warn) + monkeypatch.setattr(gptq_mod.log, "info", _capture_info) + + qweight, _, _, _, *_ = gptq.quantize(blocksize=4) + + assert calls == {"cuda": 1, "cpu": 1} + assert qweight.device == device + joined_logs = "\n".join(log_messages) + assert "falling back to CPU" in joined_logs + assert "may take much longer than normal" in joined_logs + assert "moving final quantized weights back" in joined_logs + + class TestGPTQAddBatchCPU(ModelTest): ######### test_gptq_add_batch_cpu.py ########### pytestmark = pytest.mark.skipif( @@ -331,4 +378,3 @@ def get_random_word(self): pytest.skip( f"Streaming event helper subprocess unavailable: rc={result.returncode}, stderr={result.stderr.strip()}" ) - diff --git a/tests/test_gptq_pro.py b/tests/test_gptq_pro.py new file mode 100644 index 000000000..a682fdf4e --- /dev/null +++ b/tests/test_gptq_pro.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest +import torch + +from gptqmodel.nn_modules.qlinear.gptq_pro import GptqProQuantLinear +from gptqmodel.utils.gptq_pro import _validate_gptq_pro_device_support, ensure_gptq_pro_loaded + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required for GPTQ-Pro tests.") +@pytest.mark.skipif(not _validate_gptq_pro_device_support(), reason="GPTQ-Pro requires CUDA compute capability >= 8.0.") +def test_gptq_pro_forward_matches_reference(): + try: + ensure_gptq_pro_loaded() + except ImportError as exc: # pragma: no cover - environment-specific + pytest.skip(f"GPTQ-Pro extension unavailable: {exc}") + + in_features = 32 + out_features = 64 + group_size = 16 + groups = in_features // group_size + + scales_out_g = torch.tensor( + [ + [0.125 + 0.015625 * ((out + grp) % 5) for grp in range(groups)] + for out in range(out_features) + ], + dtype=torch.float16, + ) + zeros_out_g = torch.full((out_features, groups), 8, dtype=torch.float16) + g_idx = torch.arange(in_features, dtype=torch.int32) // group_size + + int_weight = ( + (torch.arange(in_features * out_features, dtype=torch.int32).reshape(in_features, out_features) % 7) + 5 + ) + scales_g_out = scales_out_g.T.contiguous() + weight_kn = scales_g_out[g_idx.long()].float() * (int_weight.float() - 8.0) + weight_out_in = weight_kn.T.contiguous().to(torch.float16) + + linear = torch.nn.Linear(in_features, out_features, bias=True, dtype=torch.float16) + linear.weight.data.copy_(weight_out_in) + linear.bias.data.copy_(torch.linspace(-0.25, 0.25, out_features, dtype=torch.float16)) + + module = GptqProQuantLinear( + bits=4, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + bias=True, + pack_dtype=torch.int32, + register_buffers=False, + ) + module.pack_original(linear=linear, scales=scales_out_g, zeros=zeros_out_g, g_idx=g_idx) + module = module.to("cuda:0") + module.post_init() + + x = torch.linspace(-1.0, 1.0, steps=5 * in_features, dtype=torch.float16, device="cuda:0").reshape(5, in_features) + got = module(x) + + weight_device = weight_out_in.to(device=x.device, dtype=torch.float32) + bias_device = linear.bias.detach().to(device=x.device, dtype=torch.float32) + expected = torch.matmul(x.to(torch.float32), weight_device.T) + expected.add_(bias_device) + expected = expected.to(torch.float16) + + torch.testing.assert_close(got, expected, rtol=0, atol=1e-3) diff --git a/tests/test_gpu_gpu_memory_copy.py b/tests/test_gpu_gpu_memory_copy.py index 64c27dd5d..feb071497 100644 --- a/tests/test_gpu_gpu_memory_copy.py +++ b/tests/test_gpu_gpu_memory_copy.py @@ -7,10 +7,14 @@ import math import time +import pytest import torch from models.model_test import ModelTest +pytestmark = pytest.mark.gpu + + # cpu_gpu_bandwidth_test.py # Measure HtoD and DtoH bandwidth with pageable vs pinned CPU memory. # @@ -128,7 +132,7 @@ def cpu_cpu(): args, _ = parser.parse_known_args() if not torch.cuda.is_available(): - raise SystemExit("CUDA not available.") + pytest.skip("CUDA not available.") if args.chunk_gib <= 0 or args.total_gib <= 0 or args.total_gib < args.chunk_gib: raise SystemExit("Invalid sizes: ensure total-gib >= chunk-gib > 0.") @@ -202,8 +206,10 @@ def gpu_gpu(): parser.add_argument("--chunk-gib", type=float, default=1.0, help="chunk size GiB per copy") args, _ = parser.parse_known_args() - if not torch.cuda.is_available() or torch.cuda.device_count() < 2: - raise SystemExit("Need at least 2 CUDA devices.") + if not torch.cuda.is_available(): + pytest.skip("CUDA not available.") + if torch.cuda.device_count() < 2: + pytest.skip("Need at least 2 CUDA devices.") # Basic info print(f"Detected {torch.cuda.device_count()} CUDA devices.") @@ -221,6 +227,10 @@ def gpu_gpu(): class Test(ModelTest): - def test(self): + def test_cpu_gpu(self): + """Measure pageable and pinned CPU/GPU copy bandwidth.""" cpu_cpu() + + def test_gpu_gpu(self): + """Measure inter-GPU copy bandwidth when two CUDA devices are visible.""" gpu_gpu() diff --git a/tests/test_granitemoehybrid_monkeypatch.py b/tests/test_granitemoehybrid_monkeypatch.py new file mode 100644 index 000000000..e0291495a --- /dev/null +++ b/tests/test_granitemoehybrid_monkeypatch.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch + +from gptqmodel.models.definitions.granitemoehybrid import GraniteMoeHybridQModel +from gptqmodel.nn_modules.qlinear.torch import TorchLinear + + +class _DummyQuantMamba(torch.nn.Module): + def __init__(self): + super().__init__() + self.in_proj = TorchLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=64, + out_features=64, + bias=False, + pack_dtype=torch.int32, + adapter=None, + register_buffers=True, + ) + self.out_proj = TorchLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=64, + out_features=64, + bias=False, + pack_dtype=torch.int32, + adapter=None, + register_buffers=True, + ) + self.path = None + + def torch_forward(self, *args, **kwargs): + self.path = "torch" + return "torch" + + def forward(self, *args, **kwargs): + self.path = "fast" + return "fast" + + +def test_granitemoehybrid_quantized_mamba_uses_torch_path(): + qmodel = GraniteMoeHybridQModel.__new__(GraniteMoeHybridQModel) + qmodel.model = type( + "_Outer", + (), + { + "model": type( + "_Inner", + (), + {"layers": [type("_Layer", (), {"mamba": _DummyQuantMamba()})()]}, + )() + }, + )() + + qmodel.monkey_patch() + mamba = qmodel.model.model.layers[0].mamba + + result = mamba(torch.zeros(1, 2, 64)) + + assert result == "torch" + assert mamba.path == "torch" diff --git a/tests/test_group_size.py b/tests/test_group_size.py index 32380ece0..e0f23dd3e 100644 --- a/tests/test_group_size.py +++ b/tests/test_group_size.py @@ -14,38 +14,35 @@ import traceback # noqa: E402 import unittest # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 -from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 +from gptqmodel.nn_modules.qlinear.bitblas import BitBLASLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2Linear # noqa: E402 +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.torch import TorchLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear # noqa: E402 +from tests.eval import evaluate, format_eval_result_table, get_eval_task_metrics # noqa: E402 logger = logging.getLogger(__name__) RAND_SEED = 42 -TASK_NAME = EVAL.LM_EVAL.ARC_CHALLENGE +TASK_NAME = "arc_challenge" class TestGroupSize(unittest.TestCase): QLINEAR_DICT = { - BACKEND.EXLLAMA_V1: ExllamaQuantLinear, - BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, - BACKEND.TRITON: TritonV2QuantLinear, - BACKEND.TORCH: TorchQuantLinear, - BACKEND.BITBLAS: BitBLASQuantLinear, - BACKEND.MARLIN: MarlinQuantLinear, + BACKEND.EXLLAMA_V2: ExllamaV2Linear, + BACKEND.TRITON: TritonV2Linear, + BACKEND.TORCH: TorchLinear, + BACKEND.BITBLAS: BitBLASLinear, + BACKEND.MARLIN: MarlinLinear, } @classmethod def setUpClass(cls): - cls.pack_backends = [BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.TORCH, BACKEND.BITBLAS] + cls.pack_backends = [BACKEND.TRITON, BACKEND.TORCH, BACKEND.BITBLAS] cls.backends = list(cls.pack_backends) cls.backends.extend([BACKEND.EXLLAMA_V2, BACKEND.MARLIN, ]) @@ -104,7 +101,7 @@ def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir): device_map="auto", backend=inference_backend, ) - results = GPTQModel.eval( + results = evaluate( model_or_id_or_path=model, output_path=tmp_dir, tasks=TASK_NAME, @@ -112,15 +109,12 @@ def eval(self, inference_backend, quant_backend, quantize_config, tmp_dir): trust_remote_code=False, batch_size=32, gen_kwargs="temperature=0.0,top_k=50", - random_seed=RAND_SEED, ) print('--------Eval Result---------') - print(make_table(results)) - if "groups" in results: - print(make_table(results, "groups")) + print(format_eval_result_table(results)) print('--------Eval Result End---------') task_results = { - metric: value for metric, value in results['results'].get(TASK_NAME, {}).items() + metric: value for metric, value in get_eval_task_metrics(results, TASK_NAME).items() if metric != 'alias' and 'stderr' not in metric } print( diff --git a/tests/test_hf_config_autofix.py b/tests/test_hf_config_autofix.py new file mode 100644 index 000000000..2da16e6d7 --- /dev/null +++ b/tests/test_hf_config_autofix.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from types import SimpleNamespace + +from gptqmodel.utils.hf import ensure_hf_model_config_token_ids, load_tokenizer_with_model_config + + +def test_ensure_hf_model_config_token_ids_handles_nested_text_config_and_missing_top_level_fields(): + config = SimpleNamespace( + text_config=SimpleNamespace( + bos_token_id=None, + eos_token_id=248044, + pad_token_id=None, + ), + ) + tokenizer = SimpleNamespace( + bos_token_id=None, + eos_token_id=248046, + pad_token_id=248044, + ) + + changed = ensure_hf_model_config_token_ids(config, tokenizer=tokenizer) + + assert changed is True + assert hasattr(config, "bos_token_id") + assert hasattr(config, "eos_token_id") + assert hasattr(config, "pad_token_id") + assert config.bos_token_id is None + assert config.eos_token_id == 248044 + assert config.pad_token_id == 248044 + + +class _FakeTokenizer: + bos_token_id = None + eos_token_id = 248046 + pad_token_id = 248044 + trust_remote_code = False + + def get_vocab(self): + return {"<|endoftext|>": 248044} + + def decode(self, ids): + if ids == [248044]: + return "<|endoftext|>" + return "" + + +def test_load_tokenizer_with_model_config_uses_in_memory_config(): + tokenizer = _FakeTokenizer() + config = SimpleNamespace( + text_config=SimpleNamespace( + bos_token_id=None, + eos_token_id=248044, + pad_token_id=None, + model_type="qwen3_5_text", + ), + model_type="qwen3_5", + ) + + wrapped = load_tokenizer_with_model_config(tokenizer, config) + + assert wrapped.model_config.eos_token_id == 248044 + assert wrapped.model_config.pad_token_id == 248044 + assert wrapped.eos_token_id == 248046 + assert wrapped.pad_token_id == 248044 + assert config.eos_token_id == 248044 + assert config.pad_token_id == 248044 diff --git a/tests/test_hf_config_compat.py b/tests/test_hf_config_compat.py new file mode 100644 index 000000000..33807fbca --- /dev/null +++ b/tests/test_hf_config_compat.py @@ -0,0 +1,602 @@ +import sys +from enum import Enum +from types import ModuleType, SimpleNamespace + +import pytest +import torch +import transformers +import transformers.generation.utils as generation_utils +from transformers import GenerationConfig, GPTNeoXConfig, LlamaConfig, cache_utils +from transformers.generation.configuration_utils import GenerationMode + +from gptqmodel.utils import hf as hf_utils +from gptqmodel.utils import internal_gguf +from gptqmodel.utils.hf import ( + INTERNAL_HF_GGUF_FILE_KWARG, + get_hf_gguf_load_kwargs, + normalize_hf_config_compat, + normalize_model_id_or_path_for_hf_gguf, + normalize_torch_dtype_kwarg, + prepare_remote_model_init_compat, + resolve_trust_remote_code, +) + + +def test_normalize_hf_config_compat_backfills_llama_rope_parameters(): + config = LlamaConfig(rope_parameters=None, rope_theta=12345.0) + + normalize_hf_config_compat(config) + + assert config.rope_parameters["rope_type"] == "default" + assert config.rope_parameters["rope_theta"] == 12345.0 + + +def test_normalize_hf_config_compat_uses_gpt_neox_defaults(): + config = GPTNeoXConfig(rope_parameters=None) + + normalize_hf_config_compat(config) + + assert config.rope_parameters["rope_type"] == "default" + assert config.rope_parameters["rope_theta"] == config.default_theta + assert config.rope_parameters["partial_rotary_factor"] == 0.25 + + +def test_normalize_torch_dtype_kwarg_moves_alias_to_dtype(): + kwargs = {"torch_dtype": torch.float16} + + resolved = normalize_torch_dtype_kwarg(kwargs, api_name="test") + + assert resolved is torch.float16 + assert kwargs == {"dtype": torch.float16} + + +def test_normalize_torch_dtype_kwarg_resolves_explicit_dtype_parameter(): + kwargs = {"torch_dtype": torch.bfloat16} + + resolved = normalize_torch_dtype_kwarg(kwargs, api_name="test", explicit_dtype="auto") + + assert resolved is torch.bfloat16 + assert kwargs == {} + + +def test_normalize_torch_dtype_kwarg_rejects_conflicting_values(): + kwargs = {"dtype": torch.float16, "torch_dtype": torch.bfloat16} + + with pytest.raises(ValueError, match="both `dtype` and deprecated `torch_dtype`"): + normalize_torch_dtype_kwarg(kwargs, api_name="test") + + +def test_normalize_model_id_or_path_for_hf_gguf_rejects_public_kwarg(): + kwargs = {"gguf_file": "bonsai.gguf"} + + with pytest.raises(TypeError, match="does not accept `gguf_file`"): + normalize_model_id_or_path_for_hf_gguf("/tmp/fake-model", kwargs, api_name="test") + + +def test_normalize_model_id_or_path_for_hf_gguf_normalizes_local_file(monkeypatch, tmp_path): + gguf_path = tmp_path / "bonsai.gguf" + gguf_path.write_bytes(b"GGUF") + + monkeypatch.setattr(hf_utils, "_patch_transformers_prism_gguf_compat", lambda **_kwargs: None) + + kwargs = {} + model_root = normalize_model_id_or_path_for_hf_gguf(str(gguf_path), kwargs, api_name="test") + + assert model_root == str(tmp_path) + assert kwargs[INTERNAL_HF_GGUF_FILE_KWARG] == "bonsai.gguf" + assert get_hf_gguf_load_kwargs(kwargs) == {"gguf_file": "bonsai.gguf"} + + +def test_patch_transformers_prism_gguf_compat_registers_internal_runtime(monkeypatch): + import transformers.modeling_gguf_pytorch_utils as gguf_utils + from transformers.utils import import_utils as hf_import_utils + + monkeypatch.delitem(sys.modules, "gguf", raising=False) + monkeypatch.setattr(gguf_utils, "is_gguf_available", lambda *args, **kwargs: False) + monkeypatch.setattr(hf_import_utils, "is_gguf_available", lambda *args, **kwargs: False) + monkeypatch.setattr(hf_utils, "_transformers_has_native_prism_gguf_support", lambda: False) + + hf_utils._patch_transformers_prism_gguf_compat(api_name="test") + + assert sys.modules["gguf"] is internal_gguf + assert gguf_utils.is_gguf_available() is True + assert hf_import_utils.is_gguf_available() is True + assert gguf_utils.PRISM_Q1_0_G128_NAME == hf_utils.PRISM_Q1_0_G128_NAME + assert gguf_utils._dequantize_prism_q1_0_g128 is hf_utils._dequantize_prism_q1_0_g128 + + +def test_patch_transformers_prism_gguf_compat_wraps_load_checkpoint_for_torch_loader(monkeypatch): + import transformers.modeling_gguf_pytorch_utils as gguf_utils + from transformers.utils import import_utils as hf_import_utils + + calls = {"direct": 0} + + def _original_load_gguf_checkpoint(*args, **kwargs): + return {"variant": "original"} + + def _direct_loader(**kwargs): + calls["direct"] += 1 + return {"variant": "direct", "path": kwargs["gguf_checkpoint_path"]} + + monkeypatch.setenv("GPTQMODEL_INTERNAL_GGUF_TORCH_LOADER", "1") + monkeypatch.delitem(sys.modules, "gguf", raising=False) + monkeypatch.setattr(gguf_utils, "load_gguf_checkpoint", _original_load_gguf_checkpoint) + monkeypatch.delattr(gguf_utils, "_GPTQMODEL_INTERNAL_GGUF_TORCH_LOADER_PATCHED", raising=False) + monkeypatch.delattr(gguf_utils, "_gptqmodel_original_load_gguf_checkpoint", raising=False) + monkeypatch.setattr(gguf_utils, "is_gguf_available", lambda *args, **kwargs: False) + monkeypatch.setattr(hf_import_utils, "is_gguf_available", lambda *args, **kwargs: False) + monkeypatch.setattr(hf_utils, "_transformers_has_native_prism_gguf_support", lambda: False) + monkeypatch.setattr(hf_utils, "_load_gguf_checkpoint_torch_direct", _direct_loader) + + hf_utils._patch_transformers_prism_gguf_compat(api_name="test") + result = gguf_utils.load_gguf_checkpoint("bonsai.gguf", return_tensors=True, model_to_load=object()) + + assert calls["direct"] == 1 + assert result == {"variant": "direct", "path": "bonsai.gguf"} + + +def test_normalize_hf_config_compat_drops_default_remote_rope_scaling_dict(): + config = SimpleNamespace(rope_scaling={"rope_type": "default", "rope_theta": 10000.0}) + + normalize_hf_config_compat(config, trust_remote_code=True) + + assert config.rope_scaling is None + + +def test_normalize_hf_config_compat_preserves_rope_parameters_after_remote_cleanup(): + config = LlamaConfig(rope_scaling={"rope_type": "default", "rope_theta": 10000.0}) + + normalize_hf_config_compat(config, trust_remote_code=True) + + assert config.rope_parameters["rope_type"] == "default" + assert config.rope_parameters["rope_theta"] == 10000.0 + + +def test_normalize_hf_config_compat_restores_sliding_window_cache_alias(monkeypatch): + monkeypatch.delattr(cache_utils, "SlidingWindowCache", raising=False) + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + assert cache_utils.SlidingWindowCache is cache_utils.StaticCache + + +def test_normalize_hf_config_compat_restores_hybrid_cache_alias(monkeypatch): + monkeypatch.delattr(cache_utils, "HybridCache", raising=False) + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + assert cache_utils.HybridCache is cache_utils.StaticCache + + +def test_normalize_hf_config_compat_restores_is_parallelizable_default(monkeypatch): + monkeypatch.delattr(transformers.PreTrainedModel, "is_parallelizable", raising=False) + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + assert transformers.PreTrainedModel.is_parallelizable is False + + +def test_normalize_hf_config_compat_restores_flash_attn_legacy_version_probe(monkeypatch): + monkeypatch.delattr(transformers.utils, "is_flash_attn_greater_or_equal_2_10", raising=False) + monkeypatch.setattr(transformers.utils, "is_flash_attn_greater_or_equal", lambda version: version == "2.1.0") + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + assert transformers.utils.is_flash_attn_greater_or_equal_2_10() is True + + +def test_normalize_hf_config_compat_restores_legacy_cache_length_helpers(monkeypatch): + monkeypatch.delattr(cache_utils.Cache, "get_max_length", raising=False) + monkeypatch.delattr(cache_utils.Cache, "get_usable_length", raising=False) + + class DummyLayer(cache_utils.CacheLayerMixin): + def __init__(self, seq_length, max_cache_shape): + super().__init__() + self._seq_length = seq_length + self._max_cache_shape = max_cache_shape + + def lazy_initialization(self, key_states, value_states): + self.keys = key_states + self.values = value_states + self.is_initialized = True + + def update(self, key_states, value_states, *args, **kwargs): + self.lazy_initialization(key_states, value_states) + return key_states, value_states + + def get_mask_sizes(self, query_length): + return query_length, self._seq_length + + def get_seq_length(self): + return self._seq_length + + def get_max_cache_shape(self): + return self._max_cache_shape + + class DummyCache(cache_utils.Cache): + def __init__(self, seq_length, max_cache_shape): + self.layers = [DummyLayer(seq_length, max_cache_shape)] + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + limited_cache = DummyCache(seq_length=8, max_cache_shape=10) + dynamic_cache = DummyCache(seq_length=8, max_cache_shape=-1) + + assert limited_cache.get_max_length() == 10 + assert limited_cache.get_usable_length(4) == 6 + assert dynamic_cache.get_max_length() is None + assert dynamic_cache.get_usable_length(4) == 8 + + +def test_normalize_hf_config_compat_restores_legacy_dynamic_cache_converters(monkeypatch): + monkeypatch.delattr(cache_utils.DynamicCache, "to_legacy_cache", raising=False) + monkeypatch.delattr(cache_utils.DynamicCache, "from_legacy_cache", raising=False) + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + key_states = torch.randn(1, 2, 3, 4) + value_states = torch.randn(1, 2, 3, 4) + + cache = cache_utils.DynamicCache() + cache.update(key_states, value_states, 0) + legacy_cache = cache.to_legacy_cache() + restored_cache = cache_utils.DynamicCache.from_legacy_cache(legacy_cache) + + assert len(legacy_cache) == 1 + assert torch.equal(legacy_cache[0][0], key_states) + assert torch.equal(legacy_cache[0][1], value_states) + assert restored_cache.get_seq_length(0) == 3 + assert torch.equal(restored_cache.layers[0].keys, key_states) + assert torch.equal(restored_cache.layers[0].values, value_states) + + +def test_normalize_hf_config_compat_restores_generation_cache_mapping_alias(monkeypatch): + monkeypatch.delattr(generation_utils, "NEED_SETUP_CACHE_CLASSES_MAPPING", raising=False) + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + namespace = {} + exec("from transformers.generation.utils import NEED_SETUP_CACHE_CLASSES_MAPPING", namespace) + + assert namespace["NEED_SETUP_CACHE_CLASSES_MAPPING"] is generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING + assert isinstance(generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING, dict) + + +def test_normalize_hf_config_compat_supports_legacy_custom_generation_cache_mapping(): + class DummyCustomCache: + def __init__(self, **kwargs): + self.kwargs = kwargs + + class DummyConfig: + is_encoder_decoder = False + dtype = torch.float16 + + def get_text_config(self, decoder=True): + assert decoder is True + return self + + class DummyModel(generation_utils.GenerationMixin): + _is_stateful = False + + def __init__(self): + self.config = DummyConfig() + self.dtype = torch.float16 + self.device = torch.device("cpu") + + normalize_hf_config_compat(SimpleNamespace(), trust_remote_code=True) + + original_mapping = dict(generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING) + generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING["variable"] = DummyCustomCache + try: + model = DummyModel() + generation_config = GenerationConfig(use_cache=True, num_beams=2, num_return_sequences=1) + generation_config.cache_implementation = "variable" + model_kwargs = {} + + model._prepare_cache_for_generation( + generation_config, + model_kwargs, + GenerationMode.GREEDY_SEARCH, + batch_size=3, + max_cache_length=17, + ) + finally: + generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING.clear() + generation_utils.NEED_SETUP_CACHE_CLASSES_MAPPING.update(original_mapping) + + assert isinstance(model_kwargs["past_key_values"], DummyCustomCache) + assert model_kwargs["past_key_values"].kwargs["config"] is model.config + assert model_kwargs["past_key_values"].kwargs["batch_size"] == 6 + assert model_kwargs["past_key_values"].kwargs["max_batch_size"] == 6 + assert model_kwargs["past_key_values"].kwargs["max_cache_len"] == 17 + assert model_kwargs["past_key_values"].kwargs["dtype"] == torch.float16 + assert model_kwargs["past_key_values"].kwargs["device"] == torch.device("cpu") + + +def test_prepare_remote_model_init_compat_patches_phi4_scalar_tensors(monkeypatch): + calls = [] + + def fake_tensor(data, *args, **kwargs): + calls.append((data, kwargs.get("device"))) + return (data, kwargs) + + fake_torch = SimpleNamespace(tensor=fake_tensor) + speech_module_name = "transformers_modules.fake_phi4.speech_conformer_encoder" + speech_module = SimpleNamespace(torch=fake_torch) + monkeypatch.setitem(sys.modules, speech_module_name, speech_module) + + dummy_cls = type("DummyPhi4MM", (), {}) + dummy_cls.__module__ = "transformers_modules.fake_phi4.modeling_phi4mm" + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: dummy_cls, + ) + + config = SimpleNamespace( + model_type="phi4mm", + auto_map={"AutoModelForCausalLM": "modeling_phi4mm.Phi4MMForCausalLM"}, + ) + + prepare_remote_model_init_compat("/tmp/phi4mm", config) + monkeypatch.setattr( + "gptqmodel.utils.hf.inspect.stack", + lambda context=0: [SimpleNamespace(filename="/tmp/speech_conformer_encoder.py", lineno=1426)], + ) + monkeypatch.setattr(torch.utils._device, "CURRENT_DEVICE", torch.device("meta")) + + speech_module.torch.tensor(80, dtype=torch.float32) + + assert calls == [(80, "cpu")] + assert getattr(speech_module, "_gptqmodel_scalar_tensor_meta_patch", False) is True + + +def test_prepare_remote_model_init_compat_wraps_legacy_tie_weights_signature(monkeypatch): + calls = [] + + class DummyRemoteModel: + __module__ = "transformers_modules.fake_ovis.modeling_ovis" + + def tie_weights(self): + calls.append("tied") + + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: DummyRemoteModel, + ) + + config = SimpleNamespace( + model_type="ovis", + auto_map={"AutoModelForCausalLM": "modeling_ovis.Ovis"}, + ) + + prepare_remote_model_init_compat("/tmp/ovis", config) + + DummyRemoteModel().tie_weights(missing_keys={"lm_head.weight"}, recompute_mapping=False) + + assert calls == ["tied"] + assert getattr(DummyRemoteModel, "_gptqmodel_tie_weights_kwargs_patch", False) is True + + +def test_prepare_remote_model_init_compat_accepts_tokenizers_backend_for_ovis(monkeypatch): + class DummyRemoteModel: + __module__ = "transformers_modules.fake_ovis.modeling_ovis" + + config_module = ModuleType("transformers_modules.fake_ovis.configuration_ovis") + + class Llama3ConversationFormatter: + support_tokenizer_types = ["PreTrainedTokenizerFast"] + + config_module.Llama3ConversationFormatter = Llama3ConversationFormatter + monkeypatch.setitem(sys.modules, config_module.__name__, config_module) + monkeypatch.setitem(sys.modules, "transformers_modules.fake_ovis.modeling_ovis", ModuleType("transformers_modules.fake_ovis.modeling_ovis")) + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: DummyRemoteModel, + ) + + config = SimpleNamespace( + model_type="ovis", + auto_map={"AutoModelForCausalLM": "modeling_ovis.Ovis"}, + ) + + prepare_remote_model_init_compat("/tmp/ovis", config) + + assert "TokenizersBackend" in Llama3ConversationFormatter.support_tokenizer_types + assert getattr(Llama3ConversationFormatter, "_gptqmodel_tokenizer_backend_patch", False) is True + + +def test_prepare_remote_model_init_compat_promotes_phi4_positional_seed_to_meta(monkeypatch): + seen_devices = [] + + class AbsolutePositionalEncoding: + def extend_pe(self, x): + seen_devices.append(x.device.type) + + fake_torch = SimpleNamespace(tensor=lambda data, *args, **kwargs: (data, kwargs)) + speech_module_name = "transformers_modules.fake_phi4_meta.speech_conformer_encoder" + speech_module = SimpleNamespace(torch=fake_torch, AbsolutePositionalEncoding=AbsolutePositionalEncoding) + monkeypatch.setitem(sys.modules, speech_module_name, speech_module) + + dummy_cls = type("DummyPhi4MM", (), {}) + dummy_cls.__module__ = "transformers_modules.fake_phi4_meta.modeling_phi4mm" + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: dummy_cls, + ) + + config = SimpleNamespace( + model_type="phi4mm", + auto_map={"AutoModelForCausalLM": "modeling_phi4mm.Phi4MMForCausalLM"}, + ) + + prepare_remote_model_init_compat("/tmp/phi4mm", config) + monkeypatch.setattr( + "gptqmodel.utils.hf.inspect.stack", + lambda context=0: [SimpleNamespace(filename="/tmp/speech_conformer_encoder.py", lineno=895)], + ) + + AbsolutePositionalEncoding().extend_pe(torch.tensor(0.0)) + + assert seen_devices == ["meta"] + + +def test_prepare_remote_model_init_compat_tightens_peft_awq_probe(monkeypatch): + fake_torch = SimpleNamespace(tensor=lambda data, *args, **kwargs: (data, kwargs)) + speech_module_name = "transformers_modules.fake_phi4_awq.speech_conformer_encoder" + speech_module = SimpleNamespace(torch=fake_torch) + monkeypatch.setitem(sys.modules, speech_module_name, speech_module) + + dummy_cls = type("DummyPhi4MM", (), {}) + dummy_cls.__module__ = "transformers_modules.fake_phi4_awq.modeling_phi4mm" + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: dummy_cls, + ) + + config = SimpleNamespace( + model_type="phi4mm", + auto_map={"AutoModelForCausalLM": "modeling_phi4mm.Phi4MMForCausalLM"}, + ) + + def fake_find_spec(name): + if name == "awq.modules.linear": + return None + if name == "awq": + return object() + return None + + monkeypatch.setattr("importlib.util.find_spec", fake_find_spec) + + prepare_remote_model_init_compat("/tmp/phi4mm", config) + + peft_awq = pytest.importorskip("peft.tuners.lora.awq") + + peft_awq.is_auto_awq_available.cache_clear() + assert peft_awq.is_auto_awq_available() is False + + +def test_prepare_remote_model_init_compat_adds_phi4_inner_prepare_inputs_hook(monkeypatch): + class Phi4MMModel: + pass + + remote_module_name = "transformers_modules.fake_phi4_inner.modeling_phi4mm" + speech_module_name = "transformers_modules.fake_phi4_inner.speech_conformer_encoder" + remote_module = SimpleNamespace(Phi4MMModel=Phi4MMModel) + speech_module = SimpleNamespace(torch=SimpleNamespace(tensor=lambda data, *args, **kwargs: (data, kwargs))) + + monkeypatch.setitem(sys.modules, remote_module_name, remote_module) + monkeypatch.setitem(sys.modules, speech_module_name, speech_module) + + dummy_cls = type("DummyPhi4MM", (), {}) + dummy_cls.__module__ = remote_module_name + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: dummy_cls, + ) + + config = SimpleNamespace( + model_type="phi4mm", + auto_map={"AutoModelForCausalLM": "modeling_phi4mm.Phi4MMForCausalLM"}, + ) + + prepare_remote_model_init_compat("/tmp/phi4mm", config) + + model_inputs = Phi4MMModel().prepare_inputs_for_generation(input_ids="ids", past_key_values="cache") + + assert model_inputs["input_ids"] == "ids" + assert model_inputs["past_key_values"] == "cache" + + +def test_prepare_remote_model_init_compat_defaults_phi4_forward_input_mode(monkeypatch): + class InputMode(Enum): + LANGUAGE = 0 + VISION = 1 + SPEECH = 2 + VISION_SPEECH = 3 + + class Phi4MMForCausalLM: + def forward(self, *args, **kwargs): + return kwargs["input_mode"] + + remote_module_name = "transformers_modules.fake_phi4_forward.modeling_phi4mm" + speech_module_name = "transformers_modules.fake_phi4_forward.speech_conformer_encoder" + remote_module = SimpleNamespace(InputMode=InputMode, Phi4MMForCausalLM=Phi4MMForCausalLM) + speech_module = SimpleNamespace(torch=SimpleNamespace(tensor=lambda data, *args, **kwargs: (data, kwargs))) + + Phi4MMForCausalLM.__module__ = remote_module_name + monkeypatch.setitem(sys.modules, remote_module_name, remote_module) + monkeypatch.setitem(sys.modules, speech_module_name, speech_module) + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: Phi4MMForCausalLM, + ) + + config = SimpleNamespace( + model_type="phi4mm", + auto_map={"AutoModelForCausalLM": "modeling_phi4mm.Phi4MMForCausalLM"}, + ) + + prepare_remote_model_init_compat("/tmp/phi4mm", config) + + model = Phi4MMForCausalLM() + + assert model.forward(input_ids="ids") is InputMode.LANGUAGE + assert model.forward(input_audio_embeds="audio") is InputMode.SPEECH + assert model.forward(input_image_embeds="image") is InputMode.VISION + + +def test_prepare_remote_model_init_compat_skips_input_mode_patch_without_forward(monkeypatch): + class Phi4MMForCausalLM: + pass + + remote_module_name = "transformers_modules.fake_phi4_no_forward.modeling_phi4mm" + speech_module_name = "transformers_modules.fake_phi4_no_forward.speech_conformer_encoder" + remote_module = SimpleNamespace(Phi4MMForCausalLM=Phi4MMForCausalLM) + speech_module = SimpleNamespace(torch=SimpleNamespace(tensor=lambda data, *args, **kwargs: (data, kwargs))) + + Phi4MMForCausalLM.__module__ = remote_module_name + monkeypatch.setitem(sys.modules, remote_module_name, remote_module) + monkeypatch.setitem(sys.modules, speech_module_name, speech_module) + monkeypatch.setattr( + "transformers.dynamic_module_utils.get_class_from_dynamic_module", + lambda class_ref, model_id_or_path, **kwargs: Phi4MMForCausalLM, + ) + + config = SimpleNamespace( + model_type="phi4mm", + auto_map={"AutoModelForCausalLM": "modeling_phi4mm.Phi4MMForCausalLM"}, + ) + + prepare_remote_model_init_compat("/tmp/phi4mm", config) + + assert not hasattr(Phi4MMForCausalLM, "_gptqmodel_input_mode_patch") + + +def test_resolve_trust_remote_code_overrides_when_native_support_exists(monkeypatch, capsys): + hf_utils._TRUST_REMOTE_CODE_OVERRIDE_WARNED.clear() + monkeypatch.setattr( + hf_utils, + "_detect_native_transformers_causallm_support", + lambda model_id_or_path: (True, "phi3", "Phi3ForCausalLM"), + ) + + resolved = resolve_trust_remote_code("/tmp/phi3", trust_remote_code=True) + captured = capsys.readouterr() + + assert resolved is False + assert "overriding trust_remote_code=True to False" in captured.out + captured.err + + +def test_resolve_trust_remote_code_keeps_true_without_native_support(monkeypatch, caplog): + hf_utils._TRUST_REMOTE_CODE_OVERRIDE_WARNED.clear() + monkeypatch.setattr( + hf_utils, + "_detect_native_transformers_causallm_support", + lambda model_id_or_path: (False, None, None), + ) + + with caplog.at_level("WARNING"): + resolved = resolve_trust_remote_code("/tmp/custom", trust_remote_code=True) + + assert resolved is True + assert caplog.text == "" diff --git a/tests/test_hf_init_guard.py b/tests/test_hf_init_guard.py new file mode 100644 index 000000000..2bd98de02 --- /dev/null +++ b/tests/test_hf_init_guard.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch +import torch.nn as nn +import transformers + +from gptqmodel.utils.hf import suspend_hf_weight_init + + +def test_suspend_hf_weight_init_restores_globals_after_exception(): + modeling_utils = transformers.modeling_utils + had_init_flag = hasattr(modeling_utils, "_init_weights") + original_init_flag = getattr(modeling_utils, "_init_weights", None) + original_kaiming_uniform = torch.nn.init.kaiming_uniform_ + original_uniform = torch.nn.init.uniform_ + original_normal = torch.nn.init.normal_ + + with pytest.raises(RuntimeError, match="boom"): + with suspend_hf_weight_init(): + assert torch.nn.init.kaiming_uniform_ is not original_kaiming_uniform + assert torch.nn.init.uniform_ is not original_uniform + assert torch.nn.init.normal_ is not original_normal + assert getattr(modeling_utils, "_init_weights") is False + raise RuntimeError("boom") + + assert torch.nn.init.kaiming_uniform_ is original_kaiming_uniform + assert torch.nn.init.uniform_ is original_uniform + assert torch.nn.init.normal_ is original_normal + + if had_init_flag: + assert getattr(modeling_utils, "_init_weights") == original_init_flag + else: + assert not hasattr(modeling_utils, "_init_weights") + + linear = nn.Linear(32, 32, bias=False) + assert torch.isfinite(linear.weight).all() diff --git a/tests/test_hf_utils.py b/tests/test_hf_utils.py new file mode 100644 index 000000000..dc86dd09b --- /dev/null +++ b/tests/test_hf_utils.py @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from torch import nn +from transformers import PretrainedConfig, PreTrainedModel +from transformers.modeling_utils import _get_tied_weight_keys + +from gptqmodel.utils import hf as _hf_utils # noqa: F401 + + +class _DummyConfig(PretrainedConfig): + model_type = "dummy_hf_compat" + + def __init__(self): + super().__init__(tie_word_embeddings=True) + self.vocab_size = 8 + self.hidden_size = 4 + + +class _LegacyTiedWeightsModel(PreTrainedModel): + config_class = _DummyConfig + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def forward(self, *args, **kwargs): + raise NotImplementedError + + +def test_legacy_list_tied_weights_are_normalized_to_input_embeddings(): + model = _LegacyTiedWeightsModel(_DummyConfig()) + + assert model.get_expanded_tied_weights_keys(all_submodels=False) == { + "lm_head.weight": "embed_tokens.weight" + } + assert model._tied_weights_keys == {"lm_head.weight": "embed_tokens.weight"} + assert _get_tied_weight_keys(model) == ["lm_head.weight"] diff --git a/tests/test_inference_speed.py b/tests/test_inference_speed.py index 6e33ad876..6fb594b4b 100644 --- a/tests/test_inference_speed.py +++ b/tests/test_inference_speed.py @@ -9,6 +9,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch +import pytest # noqa: E402 from inference_speed import InferenceSpeed # noqa: E402 from parameterized import parameterized # noqa: E402 @@ -22,23 +23,21 @@ (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.MARLIN, 748), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.CUDA, 493), -(InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V1, 717), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V2, 775), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TRITON, 296), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 295), (InferenceSpeed.BITBLAS_NATIVE_MODEL_ID, BACKEND.BITBLAS, 1474), -(InferenceSpeed.NATIVE_MODEL_ID, BACKEND.IPEX, 48), ''' +pytestmark = [pytest.mark.model, pytest.mark.slow] + class TestInferenceSpeed(InferenceSpeed): @parameterized.expand( [ - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_EORA, 282.64, False, False), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.MARLIN, 286.74, False, False), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 176.00, False, False), + (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 259.00, False, False), # 4090/A100 max expectation within the current 75%-125% pass band # (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TORCH, 53, False, False), - (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V1, 282.64, False, False), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.EXLLAMA_V2, 290.60, False, False), (InferenceSpeed.NATIVE_MODEL_ID, BACKEND.TRITON, 239.58, False, False), (InferenceSpeed.BITBLAS_NATIVE_MODEL_ID, BACKEND.BITBLAS, 2167.38, False, False), # Second time running bitblas, there is cache diff --git a/tests/test_inference_speed_harness.py b/tests/test_inference_speed_harness.py new file mode 100644 index 000000000..e14738c21 --- /dev/null +++ b/tests/test_inference_speed_harness.py @@ -0,0 +1,116 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import inference_speed as inference_speed_module +import torch +from inference_speed import InferenceSpeed + + +class _FakeProgress: + def __init__(self, iterable): + self._iterable = iterable + + def title(self, _text): + return self + + def __iter__(self): + return iter(self._iterable) + + +class _FakeLogger: + def pb(self, iterable): + return _FakeProgress(iterable) + + +class _FakeBatch(dict): + def to(self, _device): + return self + + +class _FakeTokenizer: + eos_token_id = 0 + pad_token_id = 0 + + @classmethod + def from_pretrained(cls, _model_path): + return cls() + + def __call__(self, prompts, **_kwargs): + batch = len(prompts) + return _FakeBatch({"input_ids": torch.zeros((batch, 3), dtype=torch.long)}) + + +class _FakeModel: + device = "cuda:0" + + @classmethod + def from_quantized(cls, *_args, **_kwargs): + return cls() + + def generate(self, input_ids, max_new_tokens, **_kwargs): + batch, prompt_len = input_ids.shape + return torch.zeros((batch, prompt_len + max_new_tokens), dtype=torch.long) + + +def test_inference_speed_excludes_warmup_from_asserted_throughput(monkeypatch): + class _Harness(InferenceSpeed): + NUM_RUNS = 2 + MAX_NEW_TOKENS = 10 + PROMPTS = ["a", "b"] + + timestamps = iter([ + 0.0, 10.0, # warmup: intentionally slow + 100.0, 101.0, + 200.0, 201.0, + ]) + + monkeypatch.setattr(inference_speed_module, "logger", _FakeLogger()) + monkeypatch.setattr(inference_speed_module, "GPTQModel", _FakeModel) + monkeypatch.setattr(inference_speed_module, "AutoTokenizer", _FakeTokenizer) + monkeypatch.setattr(inference_speed_module, "torch_empty_cache", lambda: None) + monkeypatch.setattr(inference_speed_module.time, "time", lambda: next(timestamps)) + + measured_tps = _Harness().inference( + model_path="unused", + backend="fake-backend", + tokens_per_second=20.0, + warmup_runs=1, + device="cuda", + ) + + assert measured_tps == 20.0 + + +def test_inference_speed_pins_bare_cuda_to_current_device(monkeypatch): + class _Harness(InferenceSpeed): + NUM_RUNS = 1 + MAX_NEW_TOKENS = 1 + PROMPTS = ["a"] + + captured = {} + + class _CapturingModel(_FakeModel): + @classmethod + def from_quantized(cls, model_path, backend, device): + captured["device"] = device + return cls() + + timestamps = iter([0.0, 1.0]) + + monkeypatch.setattr(inference_speed_module, "logger", _FakeLogger()) + monkeypatch.setattr(inference_speed_module, "GPTQModel", _CapturingModel) + monkeypatch.setattr(inference_speed_module, "AutoTokenizer", _FakeTokenizer) + monkeypatch.setattr(inference_speed_module, "torch_empty_cache", lambda: None) + monkeypatch.setattr(inference_speed_module.time, "time", lambda: next(timestamps)) + monkeypatch.setattr(inference_speed_module.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(inference_speed_module.torch.cuda, "current_device", lambda: 3) + + _Harness().inference( + model_path="unused", + backend="fake-backend", + tokens_per_second=1.0, + warmup_runs=0, + device="cuda", + ) + + assert captured["device"] == "cuda:3" diff --git a/tests/test_integration.py b/tests/test_integration.py index 9d185bd3e..14a252f67 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -10,6 +10,7 @@ import unittest # noqa: E402 import torch +from models.model_test import ModelTest # noqa: E402 from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 @@ -60,8 +61,13 @@ def _test_quantize(self, device_map): model = AutoModelForCausalLM.from_pretrained(tmp_dir, device_map=device_map) - generate_str = tokenizer.decode( - model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(model.device))[0]) + generate_str = ModelTest.generate_stable_with_limit( + model, + tokenizer, + "gptqmodel is", + max_new_tokens=30, + skip_special_tokens=False, + ) self.assertIn("is a good", generate_str.lower()) @@ -103,9 +109,14 @@ def assertInference(self, model, tokenizer=None, keywords=None, prompt=INFERENCE def generate(self, model, tokenizer, prompt=None): if prompt is None: prompt = self.INFERENCE_PROMPT - inp = tokenizer(prompt, return_tensors="pt").to(model.device) - res = model.generate(**inp, num_beams=1, do_sample=False, min_new_tokens=10, max_new_tokens=30) - output = tokenizer.decode(res[0]) + output = ModelTest.generate_stable_with_limit( + model, + tokenizer, + prompt, + min_new_tokens=10, + max_new_tokens=30, + skip_special_tokens=False, + ) print(f"Result is: >>\n{output}\n<<") return output @@ -117,18 +128,13 @@ def test_llm_awq(self): ) tokenizer = AutoTokenizer.from_pretrained(model_name) - inputs = tokenizer("Capital of France is", return_tensors="pt").to(model.device) - with torch.no_grad(): - outputs = model.generate( - **inputs, + result = ModelTest.generate_stable_with_limit( + model, + tokenizer, + "The capital city of France is named", max_new_tokens=128, - temperature=0.7, - top_p=0.9, - do_sample=True ) - - result = tokenizer.decode(outputs[0], skip_special_tokens=True) print("result:", result) if "paris" not in result.lower() and "city" not in result.lower() and "food" not in result.lower() and "market" not in result.lower(): diff --git a/tests/test_internal_gguf.py b/tests/test_internal_gguf.py new file mode 100644 index 000000000..50a92fddd --- /dev/null +++ b/tests/test_internal_gguf.py @@ -0,0 +1,133 @@ +import struct +from types import SimpleNamespace + +import numpy as np +import torch + +from gptqmodel.utils import internal_gguf + + +def _encode_gguf_string(value: str) -> bytes: + data = value.encode("utf-8") + return struct.pack(" bool: + assert name == "machete" + return self.available + + def error(self, name: str) -> str: + assert name == "machete" + return self.error_message + + def load(self, *, name: str) -> dict[str, bool]: + self.load_requests.append(name) + return {"machete": True} + + def op(self, name: str, op_name: str): + assert name == "machete" + self.op_requests.append((name, op_name)) + if op_name == "machete_prepack_B": + return lambda *args: self.prepack_result + if op_name == "machete_supported_schedules": + return lambda *args: self.schedule_result + if op_name == "machete_mm": + return lambda *args: self.mm_result + raise AssertionError(f"unexpected op {op_name}") + + +def _write_fake_cutlass_checkout(destination: Path, *, version: str) -> None: + major, minor, patch = version.split(".") + (destination / "include" / "cutlass").mkdir(parents=True, exist_ok=True) + (destination / "tools" / "library" / "include").mkdir(parents=True, exist_ok=True) + (destination / "tools" / "util" / "include").mkdir(parents=True, exist_ok=True) + (destination / "python").mkdir(parents=True, exist_ok=True) + (destination / "include" / "cutlass" / "cutlass.h").write_text("// cutlass\n", encoding="utf-8") + (destination / "include" / "cutlass" / "version.h").write_text( + ( + "#pragma once\n" + f"#define CUTLASS_MAJOR {major}\n" + f"#define CUTLASS_MINOR {minor}\n" + f"#define CUTLASS_PATCH {patch}\n" + ), + encoding="utf-8", + ) + (destination / "python" / "cutlass_library.py").write_text("# cutlass python\n", encoding="utf-8") + + +def _write_fake_cutlass_archive(destination: Path) -> None: + staging_root = destination.parent / f"cutlass-{machete_utils._CUTLASS_VERSION}" + _write_fake_cutlass_checkout(staging_root, version=machete_utils._CUTLASS_VERSION) + + with tarfile.open(destination, "w:gz") as archive: + archive.add(staging_root, arcname=f"cutlass-{machete_utils._CUTLASS_VERSION}") + + shutil.rmtree(staging_root, ignore_errors=True) + + +def test_machete_runtime_routes_through_extension_api(monkeypatch): + fake_api = _FakeExtensionApi() + monkeypatch.setattr(machete_utils, "_extension_api", lambda: fake_api) + monkeypatch.setattr(machete_utils, "_machete_static_runtime_error", lambda: "") + + prepacked = machete_utils.machete_prepack_B( + torch.ones((1, 1), dtype=torch.int32), + torch.float16, + scalar_types.uint4b8, + torch.float16, + ) + schedules = machete_utils.machete_supported_schedules(torch.float16, scalar_types.uint4b8) + output = machete_utils.machete_mm( + a=torch.ones((1, 1), dtype=torch.float16), + b_q=torch.ones((1, 1), dtype=torch.int32), + b_type=scalar_types.uint4b8, + ) + + assert prepacked is fake_api.prepack_result + assert schedules == ["sch_a", "sch_b"] + assert output is fake_api.mm_result + assert machete_utils.prewarm_machete_extension() is True + assert fake_api.load_requests == ["machete"] + assert fake_api.op_requests == [ + ("machete", "machete_prepack_B"), + ("machete", "machete_supported_schedules"), + ("machete", "machete_mm"), + ] + + +def test_machete_static_runtime_error_requires_hopper_sm90(monkeypatch): + class _Props: + name = "NVIDIA RTX PRO 6000 Blackwell Server Edition" + shared_memory_per_block = 49152 + shared_memory_per_block_optin = 101376 + + monkeypatch.setattr(machete_utils, "IS_ROCM", False) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda *args, **kwargs: (12, 0)) + monkeypatch.setattr(torch.cuda, "current_device", lambda: 0) + monkeypatch.setattr(torch.cuda, "get_device_properties", lambda *_args, **_kwargs: _Props()) + + error = machete_utils._machete_static_runtime_error() + + assert "Hopper-class SM90 GPUs only" in error + assert "12.0" in error + + +def test_machete_static_runtime_error_checks_optin_shared_memory(monkeypatch): + class _Props: + name = "NVIDIA H100 80GB HBM3" + shared_memory_per_block = 49152 + shared_memory_per_block_optin = 98304 + + monkeypatch.setattr(machete_utils, "IS_ROCM", False) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda *args, **kwargs: (9, 0)) + monkeypatch.setattr(torch.cuda, "current_device", lambda: 0) + monkeypatch.setattr(torch.cuda, "get_device_properties", lambda *_args, **_kwargs: _Props()) + + error = machete_utils._machete_static_runtime_error() + + assert str(machete_utils._MACHETE_MIN_SHARED_MEMORY_PER_BLOCK_OPTIN) in error + assert "98304" in error + + +def test_machete_registers_checkpoint_compatible_qzeros_shape_for_symmetric_gptq(): + module = machete_linear_module.MacheteLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=8192, + out_features=3072, + bias=False, + dtype=torch.float16, + ) + + assert module.qzeros.shape == (64, 384) + assert module.qzeros.dtype == torch.int32 + + +def test_machete_load_state_dict_accepts_checkpoint_qzeros_shape(): + module = machete_linear_module.MacheteLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=8192, + out_features=3072, + bias=False, + dtype=torch.float16, + ) + state_dict = module.state_dict() + state_dict["qzeros"] = torch.zeros((64, 384), dtype=torch.int32) + + module.load_state_dict(state_dict) + + assert module.qzeros.shape == (64, 384) + + +def test_machete_post_init_discards_loaded_qzeros_for_symmetric_gptq(monkeypatch): + module = machete_linear_module.MacheteLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=128, + bias=False, + dtype=torch.float16, + ) + with torch.no_grad(): + module.qweight.copy_(torch.randint(0, 16, module.qweight.shape, dtype=torch.int32)) + module.g_idx.copy_(torch.arange(module.in_features, dtype=torch.int32)) + module.scales.copy_(torch.ones_like(module.scales)) + module.qzeros.copy_(torch.randint(0, 16, module.qzeros.shape, dtype=torch.int32)) + + monkeypatch.setattr( + machete_linear_module, + "machete_prepack_B", + lambda weight, **_kwargs: weight.contiguous(), + ) + + module.post_init() + + assert module.qzeros.numel() == 0 + assert module.qzeros.dtype == torch.int32 + assert module.has_zero_points is False + + +def test_machete_validate_accepts_asymmetric_gptq(monkeypatch): + monkeypatch.setattr( + machete_linear_module.MacheteLinear, + "cached_validate_once", + classmethod(lambda qlinear_cls: (True, None)), + ) + monkeypatch.setattr( + machete_linear_module, + "_validate_machete_device_support", + lambda: True, + ) + + ok, err = machete_linear_module.MacheteLinear.validate( + bits=4, + group_size=128, + desc_act=False, + sym=False, + in_features=128, + out_features=128, + pack_dtype=torch.int32, + dtype=torch.float16, + ) + + assert ok is True + assert err is None + + +def test_machete_post_init_transforms_loaded_qzeros_for_asymmetric_gptq(monkeypatch): + module = machete_linear_module.MacheteLinear( + bits=4, + group_size=64, + desc_act=False, + sym=False, + in_features=128, + out_features=128, + bias=False, + dtype=torch.float16, + ) + with torch.no_grad(): + module.qweight.copy_(torch.randint(0, 16, module.qweight.shape, dtype=torch.int32)) + module.g_idx.copy_(torch.tensor([0] * 64 + [1] * 64, dtype=torch.int32)) + module.scales.copy_(torch.rand_like(module.scales) + 0.25) + module.qzeros.copy_(torch.randint(0, 16, module.qzeros.shape, dtype=torch.int32)) + + expected_qzeros = ( + -1.0 + * module.scales.detach().clone() + * machete_utils.unpack_quantized_values_into_int32( + module.qzeros.detach().clone(), + module.weight_type, + packed_dim=1, + ).to(module.scales.dtype) + ).contiguous() + + monkeypatch.setattr( + machete_linear_module, + "machete_prepack_B", + lambda weight, **_kwargs: weight.contiguous(), + ) + + module.post_init() + + torch.testing.assert_close(module.qzeros, expected_qzeros) + assert module.qzeros.dtype == torch.float16 + assert module.has_zero_points is True + + +def test_machete_forward_rejects_missing_qzeros_for_asymmetric_gptq(): + module = machete_linear_module.MacheteLinear( + bits=4, + group_size=64, + desc_act=False, + sym=False, + in_features=128, + out_features=128, + bias=False, + dtype=torch.float16, + ) + module.has_zero_points = True + module.qzeros = torch.nn.Parameter(torch.empty(0, dtype=torch.float16), requires_grad=False) + + with pytest.raises(AssertionError, match="requires non-empty qzeros"): + module(torch.randn(2, 128, dtype=torch.float16)) + + +def test_machete_hopper_arch_cuda_cflags_add_sm90a_when_torch_only_targets_sm90(monkeypatch): + class _Props: + name = "NVIDIA H200" + shared_memory_per_block = 49152 + shared_memory_per_block_optin = 232448 + + monkeypatch.setattr(machete_utils, "IS_ROCM", False) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda *args, **kwargs: (9, 0)) + monkeypatch.setattr(torch.cuda, "current_device", lambda: 0) + monkeypatch.setattr(torch.cuda, "get_device_properties", lambda *_args, **_kwargs: _Props()) + monkeypatch.setattr(machete_utils, "resolved_cuda_arch_flags", lambda: ["-gencode=arch=compute_90,code=sm_90"]) + + flags = machete_utils._machete_hopper_arch_cuda_cflags() + + assert flags == list(machete_utils._MACHETE_SM90A_ARCH_FLAGS) + + +def test_machete_hopper_arch_cuda_cflags_skip_duplicate_sm90a(monkeypatch): + class _Props: + name = "NVIDIA H200" + shared_memory_per_block = 49152 + shared_memory_per_block_optin = 232448 + + monkeypatch.setattr(machete_utils, "IS_ROCM", False) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda *args, **kwargs: (9, 0)) + monkeypatch.setattr(torch.cuda, "current_device", lambda: 0) + monkeypatch.setattr(torch.cuda, "get_device_properties", lambda *_args, **_kwargs: _Props()) + monkeypatch.setattr( + machete_utils, + "resolved_cuda_arch_flags", + lambda: ["-gencode=arch=compute_90a,code=sm_90a"], + ) + + assert machete_utils._machete_hopper_arch_cuda_cflags() == [] + + +def test_machete_extra_cuda_cflags_keep_only_required_torch_undefines(monkeypatch): + monkeypatch.setattr(machete_utils, "_machete_cuda_version_at_least", lambda *_args: False) + monkeypatch.setattr(machete_utils, "_machete_hopper_arch_cuda_cflags", lambda: []) + + flags = machete_utils._machete_extra_cuda_cflags() + + assert flags[:3] == list(machete_utils._MACHETE_REQUIRED_TORCH_NVCC_UNDEFINES) + assert "--threads" in flags + assert flags[flags.index("--threads") + 1] == machete_utils._MACHETE_JIT_NVCC_THREADS + assert "-U__CUDA_NO_BFLOAT16_OPERATORS__" not in flags + assert "-U__CUDA_NO_BFLOAT162_OPERATORS__" not in flags + assert "-U__CUDA_NO_BFLOAT162_CONVERSIONS__" not in flags + assert "-U__CUDA_NO_HALF2_OPERATORS__" not in flags + + +def test_machete_extra_cuda_cflags_enable_static_global_template_stub_for_cuda_12_8_plus(monkeypatch): + monkeypatch.setattr(machete_utils, "_machete_cuda_version_at_least", lambda major, minor: (major, minor) == (12, 8)) + monkeypatch.setattr(machete_utils, "_machete_hopper_arch_cuda_cflags", lambda: []) + + flags = machete_utils._machete_extra_cuda_cflags() + + assert flags[0] == "-static-global-template-stub=false" + assert flags[1:4] == list(machete_utils._MACHETE_REQUIRED_TORCH_NVCC_UNDEFINES) + + +def test_ensure_cutlass_source_bootstraps_repo_local_checkout(monkeypatch, tmp_path): + archive_path = tmp_path / f"cutlass-v{machete_utils._CUTLASS_VERSION}.tar.gz" + _write_fake_cutlass_archive(archive_path) + + monkeypatch.setattr(machete_utils, "_machete_project_root", lambda: tmp_path) + monkeypatch.delenv("GPTQMODEL_CUTLASS_DIR", raising=False) + monkeypatch.setattr( + machete_utils, + "_download_cutlass_archive", + lambda _url, destination: shutil.copyfile(archive_path, destination), + ) + + cutlass_root = machete_utils._ensure_cutlass_source() + monkeypatch.setenv("GPTQMODEL_CUTLASS_DIR", os.environ["GPTQMODEL_CUTLASS_DIR"]) + + assert cutlass_root == (tmp_path / "cutlass").resolve() + assert (cutlass_root / "include" / "cutlass" / "cutlass.h").is_file() + assert (cutlass_root / "python" / "cutlass_library.py").is_file() + assert (cutlass_root / machete_utils._CUTLASS_VERSION_MARKER).read_text(encoding="utf-8").strip() == machete_utils._CUTLASS_VERSION + assert str(cutlass_root) == str((tmp_path / "cutlass").resolve()) + assert str(cutlass_root) == os.environ["GPTQMODEL_CUTLASS_DIR"] + + +def test_cutlass_checkout_complete_accepts_tools_util_layout(tmp_path): + cutlass_root = tmp_path / "cutlass" + _write_fake_cutlass_checkout(cutlass_root, version=machete_utils._CUTLASS_VERSION) + (cutlass_root / "python" / "cutlass_library").mkdir(parents=True, exist_ok=True) + (cutlass_root / "python" / "cutlass_library.py").unlink() + (cutlass_root / "python" / "cutlass_library" / "__init__.py").write_text("# bindings\n", encoding="utf-8") + + assert machete_utils._cutlass_checkout_complete(cutlass_root) + + +def test_cutlass_checkout_version_reads_header_macros(tmp_path): + cutlass_root = tmp_path / "cutlass" + _write_fake_cutlass_checkout(cutlass_root, version="3.5.0") + + assert machete_utils._cutlass_checkout_version(cutlass_root) == "3.5.0" + + +def test_ensure_cutlass_source_rejects_mismatched_configured_checkout(monkeypatch, tmp_path): + wrong_cutlass = tmp_path / "wrong-cutlass" + _write_fake_cutlass_checkout(wrong_cutlass, version="3.5.0") + monkeypatch.setattr(machete_utils, "_machete_project_root", lambda: tmp_path / "project") + monkeypatch.setenv("GPTQMODEL_CUTLASS_DIR", str(wrong_cutlass)) + + with pytest.raises(RuntimeError, match=r"CUTLASS v3\.5\.0.*requires v4\.4\.2"): + machete_utils._ensure_cutlass_source() + + +def test_ensure_cutlass_source_marks_matching_repo_local_checkout_without_redownload(monkeypatch, tmp_path): + repo_root = tmp_path + repo_cutlass = repo_root / "cutlass" + _write_fake_cutlass_checkout(repo_cutlass, version=machete_utils._CUTLASS_VERSION) + download_calls: list[Path] = [] + + monkeypatch.setattr(machete_utils, "_machete_project_root", lambda: repo_root) + monkeypatch.delenv("GPTQMODEL_CUTLASS_DIR", raising=False) + monkeypatch.setattr( + machete_utils, + "_download_cutlass_archive", + lambda _url, destination: download_calls.append(destination), + ) + + cutlass_root = machete_utils._ensure_cutlass_source() + + assert cutlass_root == repo_cutlass.resolve() + assert (repo_cutlass / machete_utils._CUTLASS_VERSION_MARKER).read_text(encoding="utf-8").strip() == machete_utils._CUTLASS_VERSION + assert download_calls == [] + + +def test_ensure_cutlass_source_refreshes_repo_local_checkout_when_version_mismatches(monkeypatch, tmp_path): + archive_path = tmp_path / f"cutlass-v{machete_utils._CUTLASS_VERSION}.tar.gz" + _write_fake_cutlass_archive(archive_path) + repo_cutlass = tmp_path / "cutlass" + _write_fake_cutlass_checkout(repo_cutlass, version="3.5.0") + + monkeypatch.setattr(machete_utils, "_machete_project_root", lambda: tmp_path) + monkeypatch.delenv("GPTQMODEL_CUTLASS_DIR", raising=False) + monkeypatch.setattr( + machete_utils, + "_download_cutlass_archive", + lambda _url, destination: shutil.copyfile(archive_path, destination), + ) + + cutlass_root = machete_utils._ensure_cutlass_source() + + assert cutlass_root == repo_cutlass.resolve() + assert machete_utils._cutlass_checkout_version(cutlass_root) == machete_utils._CUTLASS_VERSION + assert (cutlass_root / machete_utils._CUTLASS_VERSION_MARKER).read_text(encoding="utf-8").strip() == machete_utils._CUTLASS_VERSION + + +def test_scaled_mm_epilogues_c3x_matches_cutlass_442_broadcast_signatures(): + header = ( + Path(__file__).resolve().parents[1] + / "gptqmodel_ext" + / "cutlass_extensions" + / "epilogue" + / "scaled_mm_epilogues_c3x.hpp" + ).read_text(encoding="utf-8") + + assert "Sm90ColBroadcast<\n 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>" not in header + assert "Sm90RowBroadcast<\n 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>" not in header + assert "Sm90ColBroadcast<\n 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>" in header + assert "Sm90RowBroadcast<\n 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>" in header + + +def test_machete_mm_kernel_plain_store_uses_trivial_epilogue(): + kernel_header = ( + Path(__file__).resolve().parents[1] + / "gptqmodel_ext" + / "machete" + / "machete_mm_kernel.cuh" + ).read_text(encoding="utf-8") + + assert "TrivialEpilogue" not in kernel_header + + +def test_machete_sources_generate_once_when_missing(monkeypatch, tmp_path): + machete_root = tmp_path / "gptqmodel_ext" / "machete" + cutlass_ext_root = tmp_path / "gptqmodel_ext" / "cutlass_extensions" + machete_root.mkdir(parents=True, exist_ok=True) + cutlass_ext_root.mkdir(parents=True, exist_ok=True) + (machete_root / "generate.py").write_text("# generator\n", encoding="utf-8") + (machete_root / "machete_pytorch.cu").write_text("// pytorch\n", encoding="utf-8") + (cutlass_ext_root / "vllm_cutlass_library_extension.py").write_text("# helper\n", encoding="utf-8") + + fake_cutlass = tmp_path / "cutlass" + fake_cutlass.mkdir(parents=True, exist_ok=True) + run_calls: list[list[str]] = [] + + def fake_run(args, cwd, env, check, capture_output, text): + del cwd, env, check, capture_output, text + run_calls.append(list(args)) + generated_dir = machete_root / "generated" + generated_dir.mkdir(parents=True, exist_ok=True) + (generated_dir / "machete_dispatch.cu").write_text("// generated\n", encoding="utf-8") + return subprocess.CompletedProcess(args=args, returncode=0, stdout="", stderr="") + + monkeypatch.setattr(machete_utils, "_machete_project_root", lambda: tmp_path) + monkeypatch.setattr(machete_utils, "_ensure_cutlass_source", lambda: fake_cutlass) + monkeypatch.setattr(subprocess, "run", fake_run) + + sources_first = machete_utils._machete_sources() + sources_second = machete_utils._machete_sources() + + assert run_calls == [[sys.executable, str(machete_root / "generate.py")]] + assert sources_first == sources_second + assert sources_first[0] == str(machete_root / "machete_pytorch.cu") + assert sources_first[1] == str(machete_root / "generated" / "machete_dispatch.cu") + + +def test_machete_sources_regenerate_when_cutlass_root_changes(monkeypatch, tmp_path): + machete_root = tmp_path / "gptqmodel_ext" / "machete" + cutlass_ext_root = tmp_path / "gptqmodel_ext" / "cutlass_extensions" + machete_root.mkdir(parents=True, exist_ok=True) + cutlass_ext_root.mkdir(parents=True, exist_ok=True) + (machete_root / "generate.py").write_text("# generator\n", encoding="utf-8") + (machete_root / "machete_pytorch.cu").write_text("// pytorch\n", encoding="utf-8") + (cutlass_ext_root / "vllm_cutlass_library_extension.py").write_text("# helper\n", encoding="utf-8") + + cutlass_a = tmp_path / "cutlass_a" + cutlass_b = tmp_path / "cutlass_b" + for cutlass_root in (cutlass_a, cutlass_b): + (cutlass_root / "python").mkdir(parents=True, exist_ok=True) + (cutlass_root / "python" / "cutlass_library.py").write_text("# bindings\n", encoding="utf-8") + + run_calls: list[list[str]] = [] + current_cutlass_root = cutlass_a + + def fake_run(args, cwd, env, check, capture_output, text): + del cwd, check, capture_output, text + run_calls.append(list(args)) + generated_dir = machete_root / "generated" + generated_dir.mkdir(parents=True, exist_ok=True) + (generated_dir / "machete_dispatch.cu").write_text( + f"// generated for {env['GPTQMODEL_CUTLASS_DIR']}\n", + encoding="utf-8", + ) + return subprocess.CompletedProcess(args=args, returncode=0, stdout="", stderr="") + + monkeypatch.setattr(machete_utils, "_machete_project_root", lambda: tmp_path) + monkeypatch.setattr(machete_utils, "_ensure_cutlass_source", lambda: current_cutlass_root) + monkeypatch.setattr(subprocess, "run", fake_run) + + machete_utils._machete_sources() + machete_utils._machete_sources() + current_cutlass_root = cutlass_b + machete_utils._machete_sources() + + assert run_calls == [ + [sys.executable, str(machete_root / "generate.py")], + [sys.executable, str(machete_root / "generate.py")], + ] + + +def test_machete_ldflags_link_cuda_driver(): + assert "-lcuda" in machete_utils._machete_extra_ldflags() + + +def test_vllm_cutlass_library_extension_imports_cleanly_in_subprocess(): + root = Path(__file__).resolve().parents[1] + cutlass_python_dir = root / "cutlass" / "python" + cutlass_ext_dir = root / "gptqmodel_ext" / "cutlass_extensions" + + result = subprocess.run( + [ + sys.executable, + "-c", + ( + "import sys; " + f"sys.path.insert(0, {str(cutlass_ext_dir)!r}); " + f"sys.path.insert(1, {str(cutlass_python_dir)!r}); " + "import vllm_cutlass_library_extension as ext; " + "print(ext.VLLMDataType.u4b8.name)" + ), + ], + check=False, + capture_output=True, + text=True, + ) + + assert result.returncode == 0, result.stderr + assert result.stdout.strip() == "u4b8" + + +def _jit_scratch_root(tmp_path: Path, suffix: str) -> Path: + base = Path("/dev/shm") if Path("/dev/shm").is_dir() else tmp_path + root = base / "gptqmodel-jit-tests" / suffix + root.mkdir(parents=True, exist_ok=True) + return root + + +def _dense_asymmetric_gptq_reference( + *, + x: torch.Tensor, + qweight: torch.Tensor, + qzeros: torch.Tensor, + scales: torch.Tensor, + g_idx: torch.Tensor, + weight_type, +) -> torch.Tensor: + int_weight = machete_utils.unpack_quantized_values_into_int32( + qweight, + weight_type, + packed_dim=0, + ).to(dtype=scales.dtype) + int_zeros = machete_utils.unpack_quantized_values_into_int32( + qzeros, + weight_type, + packed_dim=1, + ).to(dtype=scales.dtype) + dense_weight = scales[g_idx.long()] * (int_weight - int_zeros[g_idx.long()]) + return x @ dense_weight + + +@pytest.mark.cuda +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_machete_cuda_smoke_build_and_forward(monkeypatch, tmp_path): + if not machete_utils._validate_machete_device_support(): + pytest.skip(machete_utils.machete_runtime_error()) + + scratch_root = _jit_scratch_root(tmp_path, "machete") + monkeypatch.setenv("CUDA_DEVICE_ORDER", "PCI_BUS_ID") + monkeypatch.delenv("GPTQMODEL_CUTLASS_DIR", raising=False) + monkeypatch.setenv("GPTQMODEL_MACHETE_BUILD_ROOT", str(scratch_root / "machete")) + monkeypatch.setenv("GPTQMODEL_MACHETE_FORCE_REBUILD", "1") + + assert extension_api.load(name="machete", use_cache=False) == {"machete": True} + + device = torch.device("cuda:0") + module = machete_linear_module.MacheteLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=128, + bias=False, + dtype=torch.float16, + ).to(device) + with torch.no_grad(): + module.qweight.copy_(torch.randint(0, 16, module.qweight.shape, device=device, dtype=torch.int32)) + module.g_idx.copy_(torch.arange(module.in_features, device=device, dtype=torch.int32)) + module.scales.copy_(torch.ones_like(module.scales, device=device)) + module.post_init() + + out = module(torch.randn(4, 128, device=device, dtype=torch.float16)) + torch.cuda.synchronize(device) + + assert out.shape == (4, 128) + assert out.dtype == torch.float16 + + +@pytest.mark.cuda +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_machete_cuda_asymmetric_gptq_matches_dense_reference(monkeypatch, tmp_path): + if not machete_utils._validate_machete_device_support(): + pytest.skip(machete_utils.machete_runtime_error()) + + scratch_root = _jit_scratch_root(tmp_path, "machete-asym") + monkeypatch.setenv("CUDA_DEVICE_ORDER", "PCI_BUS_ID") + monkeypatch.delenv("GPTQMODEL_CUTLASS_DIR", raising=False) + monkeypatch.setenv("GPTQMODEL_MACHETE_BUILD_ROOT", str(scratch_root / "machete")) + + assert extension_api.load(name="machete", use_cache=True) == {"machete": True} + + device = torch.device("cuda:0") + module = machete_linear_module.MacheteLinear( + bits=4, + group_size=64, + desc_act=False, + sym=False, + in_features=128, + out_features=128, + bias=False, + dtype=torch.float16, + ).to(device) + + qweight = torch.randint(0, 16, module.qweight.shape, device=device, dtype=torch.int32) + qzeros = torch.randint(0, 16, module.qzeros.shape, device=device, dtype=torch.int32) + scales = (torch.rand(module.scales.shape, device=device, dtype=torch.float16) + 0.25).contiguous() + g_idx = torch.tensor([0] * 64 + [1] * 64, device=device, dtype=torch.int32) + x = torch.randn(8, 128, device=device, dtype=torch.float16) + + with torch.no_grad(): + module.qweight.copy_(qweight) + module.qzeros.copy_(qzeros) + module.scales.copy_(scales) + module.g_idx.copy_(g_idx) + + expected = _dense_asymmetric_gptq_reference( + x=x, + qweight=qweight, + qzeros=qzeros, + scales=scales, + g_idx=g_idx, + weight_type=module.weight_type, + ) + + module.post_init() + actual = module(x) + torch.cuda.synchronize(device) + + torch.testing.assert_close(actual, expected, atol=2e-2, rtol=2e-2) diff --git a/tests/test_marlin_jit.py b/tests/test_marlin_jit.py new file mode 100644 index 000000000..d860efc0d --- /dev/null +++ b/tests/test_marlin_jit.py @@ -0,0 +1,640 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from pathlib import Path +from shutil import copy2, which + +import pytest +import torch + +import gptqmodel.nn_modules.qlinear.marlin as marlin_qlinear_module +import gptqmodel.nn_modules.qlinear.marlin_awq as marlin_awq_qlinear_module +import gptqmodel.utils.marlin as marlin_utils +from gptqmodel import extension as extension_api +from gptqmodel.utils import cpp as cpp_module +from gptqmodel.utils.marlin_scalar_type import scalar_types + + +class _FakeLoader: + def __init__(self, *, should_load: bool = True, last_error: str = ""): + self.should_load = should_load + self._last_error = last_error + self.ops: dict[str, object] = {} + self.load_calls = 0 + self.op_calls: list[str] = [] + + def load(self) -> bool: + self.load_calls += 1 + return self.should_load + + def op(self, op_name: str): + self.op_calls.append(op_name) + return self.ops[op_name] + + def last_error_message(self) -> str: + return self._last_error + + def clear_cache(self) -> None: + return None + + +class _FakeExtensionApi: + def __init__(self, *, available: bool = False, error_text: str = ""): + self.available = available + self.error_text = error_text + self.is_available_calls: list[str] = [] + self.error_calls: list[str] = [] + + def is_available(self, extension_name: str) -> bool: + self.is_available_calls.append(extension_name) + return self.available + + def error(self, extension_name: str) -> str: + self.error_calls.append(extension_name) + return self.error_text + + +def _jit_scratch_root(tmp_path: Path, suffix: str) -> Path: + base = Path("/dev/shm") if Path("/dev/shm").is_dir() else tmp_path + root = base / "gptqmodel-jit-tests" / suffix + root.mkdir(parents=True, exist_ok=True) + return root + + +def test_gptq_marlin_gemm_dispatches_fp16_to_torch_ops(monkeypatch): + fp16_loader = _FakeLoader() + bf16_loader = _FakeLoader() + captured = {} + + def fake_gemm(*args): + captured["dtype"] = args[0].dtype + captured["shape"] = (args[11], args[12]) + return torch.full((args[11], args[12]), 3.0, dtype=args[0].dtype) + + fp16_loader.ops["gptq_marlin_gemm_fp16"] = fake_gemm + + monkeypatch.setattr(marlin_utils, "_MARLIN_FP16_TORCH_OPS_EXTENSION", fp16_loader) + monkeypatch.setattr(marlin_utils, "_MARLIN_BF16_TORCH_OPS_EXTENSION", bf16_loader) + + out = marlin_utils.gptq_marlin_gemm( + a=torch.ones((2, 128), dtype=torch.float16), + c=None, + b_q_weight=torch.zeros((32, 64), dtype=torch.int32), + b_bias=None, + b_scales=torch.ones((1, 64), dtype=torch.float16), + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=torch.zeros(1, dtype=torch.int32), + b_q_type=scalar_types.uint4b8, + size_m=2, + size_n=64, + size_k=128, + ) + + assert fp16_loader.op_calls == ["gptq_marlin_gemm_fp16"] + assert bf16_loader.op_calls == [] + assert captured == {"dtype": torch.float16, "shape": (2, 64)} + assert out.shape == (2, 64) + assert out.dtype == torch.float16 + + +def test_gptq_marlin_gemm_dispatches_bf16_to_torch_ops(monkeypatch): + fp16_loader = _FakeLoader() + bf16_loader = _FakeLoader() + captured = {} + + def fake_gemm(*args): + captured["dtype"] = args[0].dtype + return torch.full((args[11], args[12]), 5.0, dtype=args[0].dtype) + + bf16_loader.ops["gptq_marlin_gemm_bf16"] = fake_gemm + + monkeypatch.setattr(marlin_utils, "_MARLIN_FP16_TORCH_OPS_EXTENSION", fp16_loader) + monkeypatch.setattr(marlin_utils, "_MARLIN_BF16_TORCH_OPS_EXTENSION", bf16_loader) + + out = marlin_utils.gptq_marlin_gemm( + a=torch.ones((1, 64), dtype=torch.bfloat16), + c=None, + b_q_weight=torch.zeros((16, 64), dtype=torch.int32), + b_bias=None, + b_scales=torch.ones((1, 64), dtype=torch.bfloat16), + global_scale=None, + b_zeros=None, + g_idx=None, + perm=None, + workspace=torch.zeros(1, dtype=torch.int32), + b_q_type=scalar_types.uint8b128, + size_m=1, + size_n=64, + size_k=64, + ) + + assert bf16_loader.op_calls == ["gptq_marlin_gemm_bf16"] + assert fp16_loader.op_calls == [] + assert captured == {"dtype": torch.bfloat16} + assert out.shape == (1, 64) + assert out.dtype == torch.bfloat16 + + +def test_gptq_marlin_gemm_passes_float_global_scale_to_torch_ops(monkeypatch): + fp16_loader = _FakeLoader() + bf16_loader = _FakeLoader() + captured = {} + + def fake_gemm(*args): + captured["global_scale_dtype"] = args[5].dtype + captured["global_scale_shape"] = tuple(args[5].shape) + return torch.zeros((args[11], args[12]), dtype=args[0].dtype) + + fp16_loader.ops["gptq_marlin_gemm_fp16"] = fake_gemm + + monkeypatch.setattr(marlin_utils, "_MARLIN_FP16_TORCH_OPS_EXTENSION", fp16_loader) + monkeypatch.setattr(marlin_utils, "_MARLIN_BF16_TORCH_OPS_EXTENSION", bf16_loader) + + out = marlin_utils.gptq_marlin_gemm( + a=torch.ones((1, 64), dtype=torch.float16), + c=None, + b_q_weight=torch.zeros((16, 64), dtype=torch.int32), + b_bias=None, + b_scales=torch.ones((4, 64), dtype=torch.float16), + global_scale=torch.tensor([1.0], dtype=torch.float32), + b_zeros=None, + g_idx=None, + perm=None, + workspace=torch.zeros(1, dtype=torch.int32), + b_q_type=scalar_types.float4_e2m1f, + size_m=1, + size_n=64, + size_k=64, + ) + + assert fp16_loader.op_calls == ["gptq_marlin_gemm_fp16"] + assert bf16_loader.op_calls == [] + assert captured == {"global_scale_dtype": torch.float32, "global_scale_shape": (1,)} + assert out.shape == (1, 64) + assert out.dtype == torch.float16 + + +def test_nvfp4_global_scale_contract_is_float_in_marlin_sources(): + marlin_root = marlin_utils._marlin_root() + kernel_h = (marlin_root / "kernel.h").read_text(encoding="utf-8") + gemm_cu = (marlin_root / "gptq_marlin.cu").read_text(encoding="utf-8") + template_h = (marlin_root / "marlin_template.h").read_text(encoding="utf-8") + + assert "const float *__restrict__ global_scale_ptr" in kernel_h + assert 'global_scale = torch::empty({0}, options_fp32);' in gemm_cu + assert 'global_scale.scalar_type() == at::ScalarType::Float' in gemm_cu + assert "global_scale.data_ptr()" in gemm_cu + assert "float global_scale_f32 = 1.0f;" in template_h + assert "c0 *= global_scale_f32;" in template_h + assert "c1 *= global_scale_f32;" in template_h + + +def test_marlin_capability_checks_allow_sm75_but_reject_sm70(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda *args, **kwargs: (7, 5)) + + assert marlin_utils._marlin_capability_supported(7, 5) is True + assert marlin_utils._marlin_environment_error() == "" + assert marlin_utils._validate_marlin_device_support() is True + + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda *args, **kwargs: (7, 0)) + + assert marlin_utils._marlin_capability_supported(7, 0) is False + assert "compute capability >= 7.5" in marlin_utils._marlin_environment_error() + assert marlin_utils._validate_marlin_device_support() is False + + +def test_marlin_quant_linear_validate_device_allows_sm75(monkeypatch): + monkeypatch.setattr(marlin_qlinear_module, "IS_ROCM", False) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda index=0: (7, 5)) + + marlin_qlinear_module.MarlinLinear.validate_device(marlin_qlinear_module.DEVICE.CUDA) + + +def test_marlin_quant_linear_validate_device_rejects_pre_turing(monkeypatch): + monkeypatch.setattr(marlin_qlinear_module, "IS_ROCM", False) + monkeypatch.setattr(torch.cuda, "device_count", lambda: 1) + monkeypatch.setattr(torch.cuda, "get_device_capability", lambda index=0: (7, 0)) + + with pytest.raises(NotImplementedError, match="compute capability >= 7.5"): + marlin_qlinear_module.MarlinLinear.validate_device(marlin_qlinear_module.DEVICE.CUDA) + + +def test_sm75_turing_contract_is_present_in_marlin_sources(): + marlin_root = marlin_utils._marlin_root() + gemm_cu = (marlin_root / "gptq_marlin.cu").read_text(encoding="utf-8") + generator_py = (marlin_root / "generate_kernels.py").read_text(encoding="utf-8") + template_h = (marlin_root / "marlin_template.h").read_text(encoding="utf-8") + mma_h = (marlin_root / "marlin_mma.h").read_text(encoding="utf-8") + loader_py = (Path(marlin_utils.__file__).resolve().parents[1] / "models" / "loader.py").read_text( + encoding="utf-8" + ) + + assert "requires CUDA_ARCH >= 7.5" in gemm_cu + assert "major_capability == 7 && minor_capability == 5" in gemm_cu + assert "stages = 2;" in gemm_cu + assert "Turing only supports float16 dense Marlin kernels." in gemm_cu + assert 'stage_values.insert(0, 2)' in generator_py + assert "constexpr bool use_fp16_accum" in template_h + assert "__CUDA_ARCH__ == 750" in mma_h + assert "m16n8k8.row.col.f16.f16.f16.f16" in mma_h + assert "compute capability >= 7.5" in loader_py + assert "GPTQ Marlin on Turing (compute capability 7.5)" in loader_py + assert "dtype=torch.float16 only." in loader_py + + +def test_stage2_dense_four_bit_tiles_stay_in_sync_between_selector_and_codegen(): + marlin_root = marlin_utils._marlin_root() + gemm_cu = (marlin_root / "gptq_marlin.cu").read_text(encoding="utf-8") + generator_py = (marlin_root / "generate_kernels.py").read_text(encoding="utf-8") + kernel_u4 = (marlin_root / "kernel_fp16_ku4.cu").read_text(encoding="utf-8") + kernel_u4b8 = (marlin_root / "kernel_fp16_ku4b8.cu").read_text(encoding="utf-8") + kernel_nvfp4 = (marlin_root / "kernel_fp16_kfe2m1f.cu").read_text(encoding="utf-8") + + assert "kIsStage2FourBitTile" in gemm_cu + assert "THREAD_M_BLOCKS * 2 <= THREAD_K_BLOCKS" in gemm_cu + assert "stages == 2 && num_bits == 4" in gemm_cu + assert "thread_m_blocks * 2 > th_config.thread_k / 16" in gemm_cu + assert "_is_4bit_weight" in generator_py + assert "stage_value == 2" in generator_py + + invalid_stage2_tile = ", 256, 4, 16, 4, false, 2," + valid_stage2_tile = ", 256, 2, 16, 4, false, 2," + + assert invalid_stage2_tile not in kernel_u4 + assert invalid_stage2_tile not in kernel_u4b8 + assert invalid_stage2_tile not in kernel_nvfp4 + assert valid_stage2_tile in kernel_u4 + assert valid_stage2_tile in kernel_u4b8 + assert valid_stage2_tile in kernel_nvfp4 + + +def test_mxfp8_contract_is_present_in_marlin_sources(): + marlin_root = marlin_utils._marlin_root() + gemm_cu = (marlin_root / "gptq_marlin.cu").read_text(encoding="utf-8") + generator_py = (marlin_root / "generate_kernels.py").read_text(encoding="utf-8") + template_h = (marlin_root / "marlin_template.h").read_text(encoding="utf-8") + + assert 'scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 2, 8]' in generator_py + assert 'scalar_type == "vllm::kFE4M3fn" and group_blocks == 2' in generator_py + assert 'MXFP8 is only supported with bf16 compute.' in generator_py + assert "MXFP8_GET_IF(vllm::kFE4M3fn, pipe_stages)" in gemm_cu + assert "W_TYPE == vllm::kFE4M3fn && GROUP_BLOCKS == 2" in gemm_cu + assert "Float8_e8m0fnu" in gemm_cu + assert "float8_e4m3fn with float8_e8m0fnu scales requires " in gemm_cu + assert "float8_e4m3fn only supports group_size == 32 (MXFP8)" in gemm_cu + assert "// MXFP8: FP8 weights with e8m0 microscaling block scales." in template_h + assert "w_type == vllm::kFE4M3fn && !(s_type == vllm::kFE8M0fnu)" in template_h + assert "if constexpr (s_type == vllm::kFE4M3fn || s_type == vllm::kFE8M0fnu)" in template_h + + +def test_ensure_generated_marlin_kernels_repairs_stale_generated_sources(monkeypatch, tmp_path): + source_root = marlin_utils._marlin_root() + test_root = tmp_path / "marlin" + test_root.mkdir() + copy2(source_root / "generate_kernels.py", test_root / "generate_kernels.py") + + monkeypatch.setattr(marlin_utils, "_marlin_root", lambda: test_root) + + assert marlin_utils._ensure_generated_marlin_kernels() == test_root + + kernel_path = test_root / "kernel_bf16_kfe4m3fn.cu" + original_text = kernel_path.read_text(encoding="utf-8") + assert "vllm::kFE8M0fnu.id()" in original_text + + stale_text = "\n".join( + line for line in original_text.splitlines() if "vllm::kFE8M0fnu.id()" not in line + ) + "\n" + kernel_path.write_text(stale_text, encoding="utf-8") + assert "vllm::kFE8M0fnu.id()" not in kernel_path.read_text(encoding="utf-8") + + assert marlin_utils._ensure_generated_marlin_kernels() == test_root + assert kernel_path.read_text(encoding="utf-8") == original_text + + +def test_gptq_marlin_repack_prefers_requested_dtype_extension(monkeypatch): + fp16_loader = _FakeLoader() + bf16_loader = _FakeLoader() + captured = {} + + def fake_repack(b_q_weight, perm, size_k, size_n, num_bits): + captured["dtype"] = torch.bfloat16 + captured["shape"] = tuple(b_q_weight.shape) + return b_q_weight + 1 + + bf16_loader.ops["gptq_marlin_repack"] = fake_repack + + monkeypatch.setattr(marlin_utils, "_MARLIN_FP16_TORCH_OPS_EXTENSION", fp16_loader) + monkeypatch.setattr(marlin_utils, "_MARLIN_BF16_TORCH_OPS_EXTENSION", bf16_loader) + + out = marlin_utils.gptq_marlin_repack( + torch.zeros((32, 64), dtype=torch.int32), + torch.arange(32, dtype=torch.int32), + 128, + 64, + 4, + dtype=torch.bfloat16, + ) + + assert bf16_loader.op_calls == ["gptq_marlin_repack"] + assert fp16_loader.op_calls == [] + assert captured == {"dtype": torch.bfloat16, "shape": (32, 64)} + assert torch.equal(out, torch.ones((32, 64), dtype=torch.int32)) + + +def test_awq_marlin_repack_raises_when_requested_jit_extension_is_unavailable(monkeypatch): + fp16_loader = _FakeLoader(should_load=False, last_error="fp16 unavailable") + bf16_loader = _FakeLoader(should_load=False, last_error="bf16 unavailable") + + monkeypatch.setattr(marlin_utils, "_MARLIN_FP16_TORCH_OPS_EXTENSION", fp16_loader) + monkeypatch.setattr(marlin_utils, "_MARLIN_BF16_TORCH_OPS_EXTENSION", bf16_loader) + + with pytest.raises(RuntimeError, match="bf16 unavailable"): + marlin_utils.awq_marlin_repack( + torch.zeros((64, 16), dtype=torch.int32), + 64, + 128, + 4, + dtype=torch.bfloat16, + ) + + assert fp16_loader.op_calls == [] + assert bf16_loader.op_calls == [] + + +def test_marlin_quant_linear_post_init_uses_compute_dtype_for_repack(monkeypatch): + captured = {} + + monkeypatch.setattr(marlin_qlinear_module, "marlin_import_exception", None) + monkeypatch.setattr(marlin_qlinear_module, "marlin_runtime_available", lambda dtype: True) + monkeypatch.setattr(marlin_qlinear_module, "marlin_runtime_error", lambda dtype: "") + monkeypatch.setattr( + marlin_qlinear_module, + "marlin_make_workspace_new", + lambda device: torch.zeros(1, dtype=torch.int32, device=device), + ) + monkeypatch.setattr( + marlin_qlinear_module, + "gptq_marlin_repack", + lambda b_q_weight, perm, size_k, size_n, num_bits, dtype=None: ( + captured.update({"dtype": dtype, "shape": tuple(b_q_weight.shape)}) or b_q_weight + ), + ) + monkeypatch.setattr( + marlin_qlinear_module, + "marlin_permute_scales", + lambda scales, size_k, size_n, group_size: scales, + ) + monkeypatch.setattr(marlin_qlinear_module, "marlin_permute_bias", lambda bias: bias) + + module = marlin_qlinear_module.MarlinLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=64, + bias=False, + dtype=torch.bfloat16, + ) + module.post_init() + + assert captured == {"dtype": torch.bfloat16, "shape": tuple(module.qweight.shape)} + + +def test_marlin_quant_linear_registers_runtime_buffers_in_compute_dtype(monkeypatch): + monkeypatch.setattr(marlin_qlinear_module, "marlin_import_exception", None) + + module = marlin_qlinear_module.MarlinLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=64, + bias=True, + dtype=torch.bfloat16, + ) + + assert module.scales.dtype == torch.bfloat16 + assert module.bias.dtype == torch.bfloat16 + + +def test_marlin_quant_linear_forward_promotes_bias_to_input_dtype(monkeypatch): + captured = {} + + monkeypatch.setattr(marlin_qlinear_module, "marlin_import_exception", None) + monkeypatch.setattr(marlin_qlinear_module, "marlin_runtime_available", lambda dtype: True) + monkeypatch.setattr(marlin_qlinear_module, "marlin_runtime_error", lambda dtype: "") + monkeypatch.setattr( + marlin_qlinear_module, + "marlin_make_workspace_new", + lambda device: torch.zeros(1, dtype=torch.int32, device=device), + ) + monkeypatch.setattr( + marlin_qlinear_module, + "gptq_marlin_repack", + lambda b_q_weight, perm, size_k, size_n, num_bits, dtype=None: b_q_weight, + ) + monkeypatch.setattr( + marlin_qlinear_module, + "marlin_permute_scales", + lambda scales, size_k, size_n, group_size: scales, + ) + monkeypatch.setattr(marlin_qlinear_module, "marlin_permute_bias", lambda bias: bias) + monkeypatch.setattr( + marlin_qlinear_module, + "apply_gptq_marlin_linear", + lambda **kwargs: ( + captured.update( + { + "input_dtype": kwargs["input"].dtype, + "scale_dtype": kwargs["weight_scale"].dtype, + "bias_dtype": kwargs["bias"].dtype, + } + ) + or torch.zeros( + (kwargs["input"].shape[0], kwargs["output_size_per_partition"]), + dtype=kwargs["input"].dtype, + ) + ), + ) + + module = marlin_qlinear_module.MarlinLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=64, + bias=True, + dtype=torch.float16, + ) + module.post_init() + + out = module(torch.randn(2, 128, dtype=torch.bfloat16)) + + assert captured == { + "input_dtype": torch.bfloat16, + "scale_dtype": torch.bfloat16, + "bias_dtype": torch.bfloat16, + } + assert module.bias.dtype == torch.bfloat16 + assert out.dtype == torch.bfloat16 + + +def test_awq_marlin_quant_linear_registers_runtime_buffers_in_compute_dtype(monkeypatch): + monkeypatch.setattr(marlin_awq_qlinear_module, "marlin_import_exception", None) + + module = marlin_awq_qlinear_module.AwqMarlinLinear( + bits=4, + group_size=128, + desc_act=False, + sym=False, + in_features=128, + out_features=64, + bias=True, + dtype=torch.bfloat16, + register_buffers=True, + ) + + assert torch.bfloat16 in marlin_awq_qlinear_module.AwqMarlinLinear.SUPPORTS_DTYPES + assert module.scales.dtype == torch.bfloat16 + assert module.bias.dtype == torch.bfloat16 + + +def test_marlin_runtime_error_appends_cuda_extra_install_hint_for_missing_headers(monkeypatch): + fake_extension_api = _FakeExtensionApi( + error_text=( + "Marlin fp16: failed to build torch.ops JIT extension: " + "fatal error: cusparse.h: No such file or directory" + ), + ) + + monkeypatch.setattr(marlin_utils, "marlin_import_exception", None) + monkeypatch.setattr(marlin_utils, "_extension_api", lambda: fake_extension_api) + monkeypatch.setattr(marlin_utils, "detected_cuda_wheel_include_paths", lambda: []) + monkeypatch.setattr(marlin_utils, "which", lambda name: "/usr/local/cuda/bin/nvcc") + monkeypatch.setattr(torch.version, "cuda", "13.0", raising=False) + + error_text = marlin_utils.marlin_runtime_error(torch.float16) + + assert fake_extension_api.is_available_calls == ["marlin_fp16"] + assert fake_extension_api.error_calls == ["marlin_fp16"] + assert "cusparse.h" in error_text + assert 'pip install "gptqmodel[marlin-cuda]"' in error_text + assert "A local `nvcc` on PATH is still required for Marlin JIT." in error_text + + +def test_marlin_runtime_error_skips_install_hint_when_cuda_wheel_headers_are_detected(monkeypatch): + fake_extension_api = _FakeExtensionApi( + error_text=( + "Marlin bf16: failed to build torch.ops JIT extension: " + "fatal error: cublas_v2.h: No such file or directory" + ), + ) + + monkeypatch.setattr(marlin_utils, "marlin_import_exception", None) + monkeypatch.setattr(marlin_utils, "_extension_api", lambda: fake_extension_api) + monkeypatch.setattr(marlin_utils, "detected_cuda_wheel_include_paths", lambda: ["/tmp/nvidia/cu13/include"]) + monkeypatch.setattr(torch.version, "cuda", "13.0", raising=False) + + marlin_utils.marlin_runtime_error(torch.bfloat16) + + assert fake_extension_api.is_available_calls == ["marlin_bf16"] + assert fake_extension_api.error_calls == ["marlin_bf16"] + + +@pytest.mark.cuda +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_marlin_cuda_smoke_build_and_forward(monkeypatch, tmp_path): + capability = torch.cuda.get_device_capability() + if capability[0] < 7 or (capability[0] == 7 and capability[1] < 5): + pytest.skip("Marlin requires compute capability >= 7.5") + if which("ninja") is None: + pytest.skip("Marlin JIT smoke test requires ninja.") + + scratch_root = _jit_scratch_root(tmp_path, "marlin") + monkeypatch.setenv("CUDA_DEVICE_ORDER", "PCI_BUS_ID") + monkeypatch.setenv("GPTQMODEL_MARLIN_FP16_BUILD_ROOT", str(scratch_root / "marlin_fp16")) + monkeypatch.setenv("GPTQMODEL_MARLIN_BF16_BUILD_ROOT", str(scratch_root / "marlin_bf16")) + monkeypatch.setenv("GPTQMODEL_MARLIN_FORCE_REBUILD", "1") + + assert extension_api.load(name="marlin_fp16", use_cache=False) == { + "marlin_fp16": True, + } + if capability[0] >= 8: + assert extension_api.load(name="marlin_bf16", use_cache=False) == { + "marlin_bf16": True, + } + + device = torch.device("cuda:0") + dtypes = (torch.float16, torch.bfloat16) if capability[0] >= 8 else (torch.float16,) + for dtype in dtypes: + module = marlin_qlinear_module.MarlinLinear( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=64, + bias=False, + dtype=dtype, + ).to(device) + with torch.no_grad(): + module.qweight.copy_(torch.randint(0, 16, module.qweight.shape, device=device, dtype=torch.int32)) + module.g_idx.copy_(torch.arange(module.in_features, device=device, dtype=torch.int32)) + module.scales.copy_(torch.ones_like(module.scales, device=device)) + module.qzeros.copy_(torch.zeros_like(module.qzeros, device=device)) + module.post_init() + + out = module(torch.randn(4, 128, device=device, dtype=dtype)) + torch.cuda.synchronize(device) + + assert out.shape == (4, 64) + assert out.dtype == dtype + + +def test_marlin_include_paths_use_wheel_headers_when_local_cuda_is_incomplete(monkeypatch, tmp_path): + root = tmp_path / "marlin" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + for header_name in marlin_utils._MARLIN_REQUIRED_CUDA_HEADERS: + (wheel_cuda_include / header_name).write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(marlin_utils, "_marlin_root", lambda: root) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = marlin_utils._marlin_include_paths() + + assert include_paths[0] == str(root) + assert str(wheel_cuda_include) in include_paths + + +def test_marlin_include_paths_skip_wheel_headers_when_local_cuda_has_required_headers(monkeypatch, tmp_path): + root = tmp_path / "marlin" + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + root.mkdir() + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + for header_name in marlin_utils._MARLIN_REQUIRED_CUDA_HEADERS: + (local_cuda_include / header_name).write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(marlin_utils, "_marlin_root", lambda: root) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = marlin_utils._marlin_include_paths() + + assert include_paths == [str(root)] diff --git a/tests/test_mlx.py b/tests/test_mlx.py index 5effbb65a..67440954e 100644 --- a/tests/test_mlx.py +++ b/tests/test_mlx.py @@ -12,44 +12,13 @@ if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -import tempfile # noqa: E402 - -from mlx_lm import generate, load # noqa: E402 from models.model_test import ModelTest # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/gptq_4bits_01-07_14-18-11_maxlen1024_ns1024_descFalse_damp0.1/" -class TestExport(ModelTest): - - @classmethod - def setUpClass(cls): - cls.tokenizer = AutoTokenizer.from_pretrained(cls.NATIVE_MODEL_ID, use_fast=True) - cls.calibration_dataset = cls.load_dataset(cls.tokenizer, cls.DATASET_SIZE) - - def test_export_mlx(self): - with tempfile.TemporaryDirectory() as export_dir: - GPTQModel.export( - model_id_or_path=self.NATIVE_MODEL_ID, - target_path=export_dir, - format="mlx" - ) - mlx_model, tokenizer = load(export_dir) - - messages = [{"role": "user", "content": self.INFERENCE_PROMPT}] - prompt = tokenizer.apply_chat_template( - messages, add_generation_prompt=True - ) - - text = generate(mlx_model, tokenizer, prompt=prompt, verbose=True) - - self.assertIn("paris", text.lower()) - -######### test_mlx_generate.py ########## - class TestMlxGenerate(ModelTest): def test_mlx_generate(self): mlx_model = GPTQModel.load( @@ -63,6 +32,3 @@ def test_mlx_generate(self): ) text = mlx_model.generate(prompt=prompt) assert "paris" in text.lower() - - - diff --git a/tests/test_mmlupro.py b/tests/test_mmlupro.py index 9c0aaaf7d..66e74459c 100644 --- a/tests/test_mmlupro.py +++ b/tests/test_mmlupro.py @@ -6,8 +6,7 @@ import tempfile import unittest -from gptqmodel import GPTQModel -from gptqmodel.utils.eval import EVAL +from tests.eval import evaluate, get_eval_task_metrics # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" @@ -22,4 +21,12 @@ def setUpClass(self): def test_mmlupro(self): with tempfile.TemporaryDirectory() as tmp_dir: - GPTQModel.eval(self.MODEL_ID, framework=EVAL.MMLU_PRO, tasks=EVAL.MMLU_PRO.MATH, output_path=tmp_dir, batch_size=10, ntrain=5) + result = evaluate( + self.MODEL_ID, + tasks="mmlu_pro:math", + output_path=f"{tmp_dir}/result.json", + batch_size=2, + suite_kwargs={"num_fewshot": 1, "max_rows": 2}, + ) + metrics = get_eval_task_metrics(result, "mmlu_pro:math") + self.assertTrue(metrics) diff --git a/tests/test_model.py b/tests/test_model.py index e2409f0fc..370502984 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -9,7 +9,7 @@ from datasets import load_dataset from transformers import AutoTokenizer -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear from gptqmodel.utils.torch import torch_empty_cache @@ -31,6 +31,7 @@ from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.looper.module_looper import ModuleLooper, StopMainLoop from gptqmodel.models import loader +from gptqmodel.models.auto import _hide_unsupported_quantization_config_for_eval, _is_supported_quantization_config ############ test_model_dequant.py ############ @@ -44,10 +45,23 @@ from safetensors import safe_open from safetensors.torch import save_file -from gptqmodel.quantization.dtype import dequantize_f8_e4m3 +from gptqmodel.quantization.dtype import ( + available_float8_dtype_names, + dequantize_f4_e2m1, + dequantize_fp8, +) from gptqmodel.utils.model_dequant import dequantize_model +try: + from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize +except Exception: + nvfp4_quantize = None + + +pytestmark = [pytest.mark.model, pytest.mark.slow, pytest.mark.gpu] + + def pack_cols(values: torch.Tensor, bits: int = 4) -> torch.Tensor: """Pack per-column low-bit values into int32 words.""" @@ -74,9 +88,22 @@ def write_index(path: Path, shard: str, keys: list[str]) -> None: payload = {"weight_map": weight_map} (path / "model.safetensors.index.json").write_text(json.dumps(payload)) - -@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") -def test_dequantize_model_fp8_infers_block_size(tmp_path): +def _checkpoint_roundtrip_fp8_formats() -> list[str]: + formats = [] + for format_name in available_float8_dtype_names(): + dtype = getattr(torch, format_name) + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "probe.safetensors" + try: + save_file({"probe.weight": torch.zeros((2, 2), dtype=dtype)}, str(path)) + except Exception: + continue + formats.append(format_name) + return formats + + +@pytest.mark.parametrize("format_name", _checkpoint_roundtrip_fp8_formats()) +def test_dequantize_model_fp8_infers_block_size(tmp_path, format_name: str): model_dir = tmp_path / "fp8_model_infer" output_dir = tmp_path / "fp8_output_infer" model_dir.mkdir() @@ -84,13 +111,13 @@ def test_dequantize_model_fp8_infers_block_size(tmp_path): config = { "architectures": ["TestModel"], "quantization_config": { - "fmt": "float8_e4m3fn", + "format": format_name, "quant_method": "fp8", }, } (model_dir / "config.json").write_text(json.dumps(config)) - weight = torch.randn(4, 8, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = torch.randn(4, 8, dtype=torch.float32).to(getattr(torch, format_name)) scale_inv = torch.ones(2, 2, dtype=torch.float32) shard_name = "model.safetensors" save_file( @@ -108,12 +135,12 @@ def test_dequantize_model_fp8_infers_block_size(tmp_path): weight_out = reader.get_tensor("linear.weight") assert weight_out.dtype is torch.bfloat16 - expected = dequantize_f8_e4m3(weight, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) + expected = dequantize_fp8(weight, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) assert torch.equal(weight_out, expected) -@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") -def test_dequantize_model_fp8(tmp_path): +@pytest.mark.parametrize("format_name", _checkpoint_roundtrip_fp8_formats()) +def test_dequantize_model_fp8(tmp_path, format_name: str): model_dir = tmp_path / "fp8_model" output_dir = tmp_path / "fp8_output" model_dir.mkdir() @@ -121,14 +148,14 @@ def test_dequantize_model_fp8(tmp_path): config = { "architectures": ["TestModel"], "quantization_config": { - "fmt": "float8_e4m3fn", + "format": format_name, "quant_method": "fp8", "weight_block_size": [2, 4], }, } (model_dir / "config.json").write_text(json.dumps(config)) - weight = torch.randn(2, 4, dtype=torch.float32).to(torch.float8_e4m3fn) + weight = torch.randn(2, 4, dtype=torch.float32).to(getattr(torch, format_name)) scale_inv = torch.ones(1, 1, dtype=torch.float32) shard_name = "model.safetensors" save_file( @@ -149,19 +176,63 @@ def test_dequantize_model_fp8(tmp_path): weight_out = reader.get_tensor("linear.weight") bias_out = reader.get_tensor("linear.bias") - expected = dequantize_f8_e4m3(weight, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) + expected = dequantize_fp8(weight, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) assert torch.equal(weight_out, expected) assert bias_out.dtype is torch.bfloat16 updated_config = json.loads((output_dir / "config.json").read_text()) assert "quantization_config" not in updated_config - assert updated_config.get("torch_dtype") == "bfloat16" + assert updated_config.get("dtype") == "bfloat16" + assert "torch_dtype" not in updated_config new_index = json.loads((output_dir / "model.safetensors.index.json").read_text()) assert "linear.weight" in new_index["weight_map"] assert "linear.weight_scale_inv" not in new_index["weight_map"] +@pytest.mark.skipif(nvfp4_quantize is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not hasattr(torch, "float4_e2m1fn_x2"), reason="float4 packed dtype not available") +def test_dequantize_model_nvfp4_float4_storage(tmp_path): + model_dir = tmp_path / "nvfp4_model" + output_dir = tmp_path / "nvfp4_output" + model_dir.mkdir() + + config = { + "architectures": ["TestModel"], + "quantization_config": { + "format": "nvfp4", + }, + } + (model_dir / "config.json").write_text(json.dumps(config)) + + data = torch.randn(4, 16, dtype=torch.float32) + scales, packed = nvfp4_quantize(data, block_size=16) + packed_float4 = packed.view(torch.float4_e2m1fn_x2) + shard_name = "model.safetensors" + save_file( + { + "linear.weight": packed_float4, + "linear.weight_scale": scales, + }, + str(model_dir / shard_name), + ) + write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale"]) + + dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu") + + with safe_open(output_dir / shard_name, framework="pt", device="cpu") as reader: + assert "linear.weight" in reader.keys() + assert "linear.weight_scale" not in reader.keys() + weight_out = reader.get_tensor("linear.weight") + + expected = dequantize_f4_e2m1(packed_float4, scale=scales, axis=None, target_dtype=torch.bfloat16) + assert torch.allclose(weight_out, expected, atol=1e-3, rtol=1e-3) + + new_index = json.loads((output_dir / "model.safetensors.index.json").read_text()) + assert "linear.weight" in new_index["weight_map"] + assert "linear.weight_scale" not in new_index["weight_map"] + + def test_dequantize_model_awq(tmp_path): model_dir = tmp_path / "awq_model" output_dir = tmp_path / "awq_output" @@ -432,7 +503,7 @@ def test_model_save_with_non_persistent_buffer(self, offload_to_disk): def test_moe(self): quantize_config = QuantizeConfig( - failsafe=None, + fallback=None, ) model = GPTQModel.load( @@ -452,9 +523,9 @@ def test_moe(self): new_model = GPTQModel.load(tmp_dir_name, device="cuda") print("new_model", new_model) - self.assertIsInstance(new_model.model.model.layers[0].mlp.experts[2].gate_proj, MarlinQuantLinear) - self.assertIsInstance(new_model.model.model.layers[0].mlp.experts[2].up_proj, MarlinQuantLinear) - self.assertIsInstance(new_model.model.model.layers[0].mlp.experts[2].down_proj, MarlinQuantLinear) + self.assertIsInstance(new_model.model.model.layers[0].mlp.experts[2].gate_proj, MarlinLinear) + self.assertIsInstance(new_model.model.model.layers[0].mlp.experts[2].up_proj, MarlinLinear) + self.assertIsInstance(new_model.model.model.layers[0].mlp.experts[2].down_proj, MarlinLinear) # No calibration data was routed to these MoE expert modules. self.assertIsInstance(new_model.model.model.layers[0].mlp.experts[10].gate_proj, nn.Linear) @@ -589,3 +660,64 @@ def layer_complete(self, *, layer_idx, submodule_finalized): ) assert looper._check_loop_stop() is True + + +def test_hide_unsupported_quantization_config_for_eval_temporarily_clears_gguf_bits(): + quantization_config = { + "quant_method": "gguf", + "format": "gguf", + "bits": "q4_k_m", + } + model = types.SimpleNamespace( + config=types.SimpleNamespace(quantization_config=dict(quantization_config)) + ) + + with _hide_unsupported_quantization_config_for_eval(model): + assert model.config.quantization_config is None + + assert model.config.quantization_config == quantization_config + + +def test_hide_unsupported_quantization_config_for_eval_leaves_supported_gptq_alone(): + quantization_config = { + "quant_method": "gptq", + "bits": 4, + "group_size": 128, + } + model = types.SimpleNamespace( + config=types.SimpleNamespace(quantization_config=dict(quantization_config)) + ) + + with _hide_unsupported_quantization_config_for_eval(model): + assert model.config.quantization_config == quantization_config + + assert model.config.quantization_config == quantization_config + + +def test_is_supported_quantization_config_rejects_input_activation_quantization(): + config = types.SimpleNamespace( + quantization_config={ + "quant_method": "modelopt", + "config_groups": { + "group_0": { + "input_activations": {"num_bits": 4, "type": "float", "dynamic": False}, + "weights": {"num_bits": 4, "type": "float", "dynamic": False}, + } + }, + } + ) + + with pytest.raises(ValueError, match="activation quantized models"): + _is_supported_quantization_config(config) + + +def test_is_supported_quantization_config_rejects_kv_cache_quantization(): + config = types.SimpleNamespace( + quantization_config={ + "quant_method": "modelopt", + "kv_cache_scheme": {"num_bits": 8, "type": "float", "dynamic": False}, + } + ) + + with pytest.raises(ValueError, match="activation quantized models"): + _is_supported_quantization_config(config) diff --git a/tests/test_model_definition_exports.py b/tests/test_model_definition_exports.py new file mode 100644 index 000000000..c41786d39 --- /dev/null +++ b/tests/test_model_definition_exports.py @@ -0,0 +1,43 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from packaging.version import Version +from transformers import __version__ as TRANSFORMERS_VERSION + +from gptqmodel.models import definitions + + +def test_public_model_definition_exports(): + expected = [ + "BailingMoeQModel", + "GLM4MoEGPTQ", + "GPTOSSGPTQ", + "Gemma3ForConditionalGenerationGPTQ", + "Gemma4ForConditionalGenerationGPTQ", + "Gemma4TextQModel", + "GraniteMoeHybridQModel", + "LFM2MoeQModel", + "LLaDA2MoeQModel", + "MiniCPMGPTQ", + "OlmoeGPTQ", + "Ovis2QModel", + "Phi4MMGPTQ", + "PhiMoEGPTQForCausalLM", + "Qwen2_5_OmniGPTQ", + "Qwen3NextGPTQ", + ] + + for name in expected: + assert hasattr(definitions, name), f"missing export: {name}" + + +def test_qwen3_5_exports_follow_transformers_support(): + supported = Version(TRANSFORMERS_VERSION) >= Version("5.2.0") + if supported: + assert definitions.Qwen3_5QModel is not None + assert definitions.Qwen3_5_MoeQModel is not None + else: + assert definitions.Qwen3_5QModel is None + assert definitions.Qwen3_5_MoeQModel is None diff --git a/tests/test_model_test_baseline_fallback.py b/tests/test_model_test_baseline_fallback.py new file mode 100644 index 000000000..1d725f639 --- /dev/null +++ b/tests/test_model_test_baseline_fallback.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import sys +from pathlib import Path + +import pytest + +from gptqmodel import BACKEND + + +sys.path.insert(0, str(Path(__file__).resolve().parent / "models")) +from model_test import ModelTest # noqa: E402 + + +class _BaselineFallbackHarness(ModelTest): + NATIVE_MODEL_ID = "/tmp/native-model" + LOAD_BACKEND = BACKEND.TORCH + DISABLE_NATIVE_BASELINE_FALLBACK = False + EVAL_TASKS = { + "arc_challenge": { + "acc": { + "value": 0.30, + "floor_pct": 0.05, + "ceil_pct": 0.10, + }, + }, + } + + def __init__(self, native_results): + super().__init__(methodName="runTest") + self._stub_native_results = native_results + self.evaluate_model_calls = 0 + + def _model_test_mode(self) -> str: + return self.MODEL_TEST_MODE_SLOW + + def evaluate_model(self, *args, **kwargs): # pragma: no cover - exercised via check_results + self.evaluate_model_calls += 1 + return self._stub_native_results + + +def test_check_results_accepts_current_native_baseline_when_static_value_is_stale(): + harness = _BaselineFallbackHarness( + {"arc_challenge": {"acc,none": 0.26}}, + ) + + harness.check_results({"arc_challenge": {"acc,none": 0.255}}) + + assert harness.evaluate_model_calls == 1 + + +def test_check_results_still_fails_when_quantized_result_misses_current_native_baseline(): + harness = _BaselineFallbackHarness( + {"arc_challenge": {"acc,none": 0.22}}, + ) + + with pytest.raises(AssertionError): + harness.check_results({"arc_challenge": {"acc,none": 0.255}}) + + assert harness.evaluate_model_calls == 1 diff --git a/tests/test_model_test_helpers.py b/tests/test_model_test_helpers.py new file mode 100644 index 000000000..7d7da83d0 --- /dev/null +++ b/tests/test_model_test_helpers.py @@ -0,0 +1,269 @@ +from types import SimpleNamespace + +import model_test as model_test_module +import torch +from model_test import ModelTest + +from gptqmodel import BACKEND + + +class FakeBatchEncoding(dict): + def __init__(self, input_ids): + super().__init__(input_ids=input_ids) + self.input_ids = input_ids + + def to(self, _device): + return self + + +class FakeTokenizer: + pad_token_id = None + eos_token_id = 7 + + def __init__(self): + self.decode_calls = [] + self.batch_decode_calls = [] + + def __call__(self, prompt, return_tensors="pt"): + assert return_tensors == "pt" + assert prompt == "hello" + return FakeBatchEncoding(torch.tensor([[101, 102]])) + + def decode(self, tokens, skip_special_tokens=True): + self.decode_calls.append( + { + "tokens": tokens.tolist(), + "skip_special_tokens": skip_special_tokens, + } + ) + return f"decoded:{tokens.tolist()}" + + def batch_decode(self, sequences, skip_special_tokens=True, clean_up_tokenization_spaces=False): + self.batch_decode_calls.append( + { + "sequences": [seq.tolist() for seq in sequences], + "skip_special_tokens": skip_special_tokens, + "clean_up_tokenization_spaces": clean_up_tokenization_spaces, + } + ) + return [f"batch:{[seq.tolist() for seq in sequences]}"] + + +class FakeProcessor(FakeTokenizer): + pass + + +class FakeModel: + def __init__(self, generated): + self.device = "cpu" + self.generated = generated + self.calls = [] + + def generate(self, **kwargs): + self.calls.append(kwargs) + return self.generated + + +def test_generate_stable_with_limit_for_prompt_uses_deterministic_kwargs(): + tokenizer = FakeTokenizer() + model = FakeModel(torch.tensor([[101, 102, 103, 104]])) + + output = ModelTest.generate_stable_with_limit( + model, + tokenizer, + "hello", + min_new_tokens=2, + max_new_tokens=4, + skip_special_tokens=False, + ) + + assert output == "decoded:[101, 102, 103, 104]" + assert len(model.calls) == 1 + assert model.calls[0]["do_sample"] is False + assert model.calls[0]["num_beams"] == 1 + assert model.calls[0]["min_new_tokens"] == 2 + assert model.calls[0]["max_new_tokens"] == 4 + assert model.calls[0]["pad_token_id"] == tokenizer.eos_token_id + assert model.calls[0]["eos_token_id"] == tokenizer.eos_token_id + assert tokenizer.decode_calls == [ + { + "tokens": [101, 102, 103, 104], + "skip_special_tokens": False, + } + ] + + +def test_generate_stable_with_limit_for_prepared_inputs_batch_decodes_suffix(): + processor = FakeProcessor() + prepared_inputs = FakeBatchEncoding(torch.tensor([[10, 11]])) + model = FakeModel(torch.tensor([[10, 11, 21, 22]])) + + output = ModelTest.generate_stable_with_limit( + model, + processor, + inputs=prepared_inputs, + prompt=None, + batch_decode=True, + max_new_tokens=2, + clean_up_tokenization_spaces=False, + ) + + assert output == "batch:[[21, 22]]" + assert len(model.calls) == 1 + assert model.calls[0]["input_ids"].tolist() == [[10, 11]] + assert model.calls[0]["do_sample"] is False + assert model.calls[0]["num_beams"] == 1 + assert processor.batch_decode_calls == [ + { + "sequences": [[21, 22]], + "skip_special_tokens": True, + "clean_up_tokenization_spaces": False, + } + ] + + +def test_load_dataset_falls_back_when_datasets_import_failed(monkeypatch): + fallback_dataset = ModelTest._LocalCalibrationDataset([{"text": "a"}, {"text": "b"}]) + + monkeypatch.setattr(model_test_module, "hf_load_dataset", None) + monkeypatch.setattr(model_test_module, "DATASETS_IMPORT_ERROR", SyntaxError("source code string cannot contain null bytes")) + monkeypatch.setattr(ModelTest, "_load_calibration_parquet", staticmethod(lambda: fallback_dataset)) + + dataset = ModelTest.load_dataset(rows=1) + + assert list(dataset) == [{"text": "a"}] + + +def test_load_dataset_falls_back_when_hf_loader_raises(monkeypatch): + fallback_dataset = ModelTest._LocalCalibrationDataset([{"text": "x"}, {"text": "y"}]) + + def _broken_load_dataset(*args, **kwargs): + raise RuntimeError("broken datasets install") + + monkeypatch.setattr(model_test_module, "hf_load_dataset", _broken_load_dataset) + monkeypatch.setattr(ModelTest, "_load_calibration_parquet", staticmethod(lambda: fallback_dataset)) + + dataset = ModelTest.load_dataset(rows=5) + + assert list(dataset) == [{"text": "x"}, {"text": "y"}] + + +def test_detect_gpu_profile_from_cuda0_name(monkeypatch): + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(torch.cuda, "get_device_name", lambda _idx: "NVIDIA A100-SXM4-80GB") + assert ModelTest._detect_gpu_profile() == "A100" + + monkeypatch.setattr(torch.cuda, "get_device_name", lambda _idx: "NVIDIA GeForce RTX 4090") + assert ModelTest._detect_gpu_profile() == "RTX4090" + + +def test_resolve_metric_baseline_value_uses_gpu_profile(monkeypatch): + helper = ModelTest(methodName="runTest") + monkeypatch.setattr(ModelTest, "_detect_gpu_profile", classmethod(lambda cls: "A100")) + selected = helper._resolve_metric_baseline_value({"A100": 0.6, "RTX4090": 0.5}) + assert selected == 0.6 + + monkeypatch.setattr(ModelTest, "_detect_gpu_profile", classmethod(lambda cls: "unknown")) + selected_a100_fallback = helper._resolve_metric_baseline_value({"A100": 0.6, "RTX4090": 0.5}) + assert selected_a100_fallback == 0.6 + + monkeypatch.setattr(ModelTest, "_detect_gpu_profile", classmethod(lambda cls: "A100")) + selected_a100_fallback = helper._resolve_metric_baseline_value(0.6) + assert selected_a100_fallback == 0.6 + + monkeypatch.setattr(ModelTest, "_detect_gpu_profile", classmethod(lambda cls: "RTX 4090")) + selected_a100_fallback = helper._resolve_metric_baseline_value(0.6) + assert selected_a100_fallback == 0.6 + + monkeypatch.setattr(ModelTest, "_detect_gpu_profile", classmethod(lambda cls: "unknown")) + selected_a100_fallback = helper._resolve_metric_baseline_value(0.6) + assert selected_a100_fallback == 0.6 + + +def test_mode_specific_baseline_value_supports_gpu_mapping(monkeypatch): + class _ModeSpecificHarness(ModelTest): + NATIVE_ARC_CHALLENGE_ACC_FAST = {"A100": 0.55, "RTX4090": 0.53} + + helper = _ModeSpecificHarness(methodName="runTest") + monkeypatch.setattr(_ModeSpecificHarness, "_is_fast_model_test_mode", lambda self: True) + monkeypatch.setattr(_ModeSpecificHarness, "_detect_gpu_profile", classmethod(lambda cls: "RTX4090")) + + assert helper._mode_specific_baseline_value("NATIVE_ARC_CHALLENGE_ACC") == 0.53 + + +def test_evalution_threads_seed_and_explicit_greedy_gen_kwargs(monkeypatch): + captured = {} + + class _Harness(ModelTest): + EVAL_TASKS = ("arc_challenge",) + LOAD_BACKEND = BACKEND.AUTO + QUANT_BACKEND = BACKEND.AUTO + + def _fake_evaluate(**kwargs): + captured.update(kwargs) + return {"tests": [{"name": "arc_challenge", "metrics": {"accuracy,loglikelihood": 1.0}}]} + + helper = _Harness(methodName="runTest") + monkeypatch.setattr(model_test_module, "evaluate", _fake_evaluate) + monkeypatch.setattr(_Harness, "_cleanup_quantized_model", lambda self, model, enabled=False: None) + monkeypatch.setattr(_Harness, "_normalize_task_list", lambda self: ["arc_challenge"]) + monkeypatch.setattr( + _Harness, + "_current_load_backend", + lambda self: SimpleNamespace(name="AUTO"), + ) + + results = helper.evaluate_model("/tmp/model", extra_args={"device": "cuda:0"}) + + assert results == {"arc_challenge": {"accuracy,loglikelihood": 1.0}} + assert captured["model_args"]["device"] == "cuda:0" + assert captured["model_args"]["seed"] == model_test_module.RAND_SEED + assert captured["model_args"]["random_seed"] == model_test_module.RAND_SEED + assert captured["gen_kwargs"] == "do_sample=false,temperature=0.0,top_p=1.0,top_k=50" + + +def test_quantize_and_evaluate_runs_evalution_for_prequantized_model_without_cached_post_quant_results(monkeypatch): + class _PrequantizedModel: + quantized = True + + class _Harness(ModelTest): + NATIVE_MODEL_ID = "/tmp/prequantized-model" + LOAD_BACKEND = BACKEND.TORCH_FUSED + DELETE_QUANTIZED_MODEL = False + NATIVE_ARC_CHALLENGE_ACC = 0.30 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.31 + + def __init__(self): + super().__init__(methodName="runTest") + self.evaluate_model_calls = 0 + + def quantModel(self, *args, **kwargs): + self._post_quant_eval_records = {} + self._loaded_model_was_prequantized = True + return _PrequantizedModel(), None, None + + def evaluate_model(self, model, trust_remote_code=False, delete_quantized_model=False, extra_args=None): + self.evaluate_model_calls += 1 + assert model is self.model + assert delete_quantized_model is False + return { + "arc_challenge": { + "accuracy,loglikelihood": 0.30, + "accuracy,loglikelihood_norm": 0.31, + } + } + + helper = _Harness() + + monkeypatch.setattr(_Harness, "check_kernel", lambda self, model, expected: None) + monkeypatch.setattr(_Harness, "_cleanup_quantized_model", lambda self, model, enabled=False: None) + + helper.quantize_and_evaluate() + + assert helper.evaluate_model_calls == 1 + assert helper._post_quant_eval_records[BACKEND.TORCH_FUSED] == { + "arc_challenge": { + "accuracy,loglikelihood": 0.30, + "accuracy,loglikelihood_norm": 0.31, + } + } diff --git a/tests/test_modelscope.py b/tests/test_modelscope.py index f41cca858..997175399 100644 --- a/tests/test_modelscope.py +++ b/tests/test_modelscope.py @@ -5,6 +5,8 @@ import os +import pytest + os.environ["GPTQMODEL_USE_MODELSCOPE"] = "True" from models.model_test import ModelTest # noqa: E402 @@ -12,6 +14,9 @@ from gptqmodel import GPTQModel # noqa: E402 +pytestmark = [pytest.mark.model, pytest.mark.slow] + + class TestLoadModelscope(ModelTest): @classmethod @@ -21,8 +26,11 @@ def setUpClass(self): def test_load_modelscope(self): model = GPTQModel.load(self.MODEL_ID) - result = model.generate("The capital of France is")[0] - str_output = model.tokenizer.decode(result) + str_output = self.generate_stable_with_limit( + model, + model.tokenizer, + "The capital city of France is named", + ) assert "paris" in str_output.lower() or "city" in str_output.lower() del model diff --git a/tests/test_module_preprocessor.py b/tests/test_module_preprocessor.py new file mode 100644 index 000000000..52bbdca77 --- /dev/null +++ b/tests/test_module_preprocessor.py @@ -0,0 +1,106 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from gptqmodel.looper.module_preprocessor import ModulePreProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.quantization.config import ( + AutoModuleDecoderConfig, + QuantizeConfig, + SmootherConfig, + SmoothMAD, + TensorParallelPadderConfig, +) + + +def test_module_preprocessor_records_auto_module_decoder_plan(): + linear = torch.nn.Linear(8, 8, bias=False) + named = NamedModule(linear, name="proj", full_name="model.layers.0.proj", layer_index=0) + + qcfg = QuantizeConfig( + bits=4, + group_size=128, + preprocessors=[ + AutoModuleDecoderConfig(target_dtype=torch.float16) + ], + ) + + processor = ModulePreProcessor( + tokenizer=None, + qcfg=qcfg, + calibration=[], + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + batch_size=1, + ) + + processor.preprocess(named) + + assert len(named.state["preprocessor_pipeline"]) == 1 + plan = named.state["auto_module_decoder"] + assert plan["code"] == "auto_module_decoder" + assert plan["source_dtype"] == "auto" + assert plan["target_dtype"] == torch.float16 + + +def test_module_preprocessor_records_tensor_parallel_padder_and_smoother_plan(): + linear = torch.nn.Linear(10, 8, bias=False) + named = NamedModule(linear, name="proj", full_name="model.layers.0.proj", layer_index=0) + + qcfg = QuantizeConfig( + bits=4, + group_size=12, + preprocessors=[ + TensorParallelPadderConfig(), + SmootherConfig(smooth=SmoothMAD(k=1.75)), + ], + ) + + processor = ModulePreProcessor( + tokenizer=None, + qcfg=qcfg, + calibration=[], + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + batch_size=1, + ) + + processor.preprocess(named) + + pipeline = named.state["preprocessor_pipeline"] + assert [entry["code"] for entry in pipeline] == ["tensor_parallel_padder", "smoother"] + assert named.state["tp_pad_info"] == { + "pad_cols": 14, + "target_multiple": 24, + "original_columns": 10, + } + + +def test_module_preprocessor_clears_decoder_state_when_preprocessors_absent(): + linear = torch.nn.Linear(8, 8, bias=False) + named = NamedModule(linear, name="proj", full_name="model.layers.0.proj", layer_index=0) + named.state["preprocessor_pipeline"] = [{"code": "stale"}] + named.state["auto_module_decoder"] = {"code": "stale"} + named.state["quant_source_module"] = torch.nn.Linear(8, 8, bias=False) + named.state["tp_pad_info"] = {"code": "stale"} + + qcfg = QuantizeConfig(bits=4, group_size=128) + processor = ModulePreProcessor( + tokenizer=None, + qcfg=qcfg, + calibration=[], + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + batch_size=1, + ) + + processor.preprocess(named) + + assert "preprocessor_pipeline" not in named.state + assert "auto_module_decoder" not in named.state + assert "quant_source_module" not in named.state + assert "tp_pad_info" not in named.state diff --git a/tests/test_moe_config.py b/tests/test_moe_config.py index 2a1bfba66..a86ca9ac8 100644 --- a/tests/test_moe_config.py +++ b/tests/test_moe_config.py @@ -11,7 +11,7 @@ from gptqmodel import GPTQModel from gptqmodel.models.writer import QUANT_LOG_NSAMPLES -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear from gptqmodel.quantization.config import ( ExpertsRoutingBypass, ExpertsRoutingOverride, @@ -32,8 +32,6 @@ 'model.layers.0.mlp.experts.39', 'model.layers.0.mlp.experts.51', 'model.layers.0.mlp.experts.55', - 'model.layers.0.mlp.experts.90', - 'model.layers.0.mlp.experts.91', ] # Standard MoE MLP projections per expert @@ -82,14 +80,14 @@ class TestMoEConfig(ModelTest): - structural integrity of MoE experts after quantization """ - FAILSAFE = None + FALLBACK = None # Intentionally minimal to force observable MoE routing behavior DATASET_SIZE = 1 NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B-layers-1" - VRAM_STRATEGY = VramStrategy.BALANCED - SAVE_PATH = "Qwen3-30B-A3B-layers-1-gptq" + DENSE_VRAM_STRATEGY = VramStrategy.EXCLUSIVE + MOE_VRAM_STRATEGY = VramStrategy.BALANCED calibration_dataset = None calibration_dataset_token_size = None @@ -101,7 +99,7 @@ def setUpClass(cls): def quantize_and_assert(self): # Apply GPTQ quantization with optional MoE routing configuration - quant_config = QuantizeConfig(bits=4, group_size=128, moe=self.MOE_CONFIG, failsafe=self.FAILSAFE) + quant_config = QuantizeConfig(bits=4, group_size=128, moe=self.MOE_CONFIG, fallback=None) model = GPTQModel.load(self.NATIVE_MODEL_ID, quant_config) # Compute total calibration token size @@ -152,7 +150,7 @@ def quantize_and_assert(self): torch.cuda.empty_cache() quantized_model = GPTQModel.load(tmp_dir, device_map="auto") - target_cls = MarlinQuantLinear if self.MOE_CONFIG else nn.Linear + target_cls = MarlinLinear if self.MOE_CONFIG else nn.Linear assert_results(quantized_model, target_cls, self.MOE_CONFIG) def test_none_moe_config(self): diff --git a/tests/test_moe_expert_batching.py b/tests/test_moe_expert_batching.py index e1f4465d5..0f65d74ff 100644 --- a/tests/test_moe_expert_batching.py +++ b/tests/test_moe_expert_batching.py @@ -3,7 +3,8 @@ import torch -from gptqmodel.looper.stage_subset import run_subset_stage +from gptqmodel.looper.loop_processor import ExecutionConfig +from gptqmodel.looper.stage_subset import build_subset_plan, run_subset_stage from gptqmodel.quantization.config import ( ExpertsRoutingBypass, ExpertsRoutingOverride, @@ -38,14 +39,18 @@ def setUp(self): self.looper._moe_subset_threshold = 2 # Mock device preparation to return proper torch.device self.looper._prepare_named_module_for_quantization.return_value = torch.device("cpu") - self.looper._vram_strategy = None + self.looper._dense_quant_devices = [torch.device("cpu")] + self.looper._moe_quant_devices = [torch.device("cpu")] + self.looper._dense_vram_strategy_explicit = False + self.looper._moe_vram_strategy_explicit = False self.processor.name.return_value = "GPTQProcessor" - self.processor.require_fwd = True + self.processor.execution_config = ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + ) # Mock processor tasks self.processor.tasks = {} - # Explicitly set fwd_after_process to match GPTQProcessor default - self.processor.fwd_after_process = True # Create fake subset self.subset = {f"expert.{i}": MagicMock() for i in range(10)} @@ -57,8 +62,19 @@ def setUp(self): def _run_subset_stage(self, subset): """Helper to run subset stage with given subset.""" + plan = build_subset_plan( + self.looper, + processor=self.processor, + subset=subset, + subset_index=0, + subset_total=1, + full=self.full, + fallback=False, + layer_inputs=self.layer_inputs, + ) run_subset_stage( looper=self.looper, + plan=plan, processor=self.processor, module=self.module, layer_inputs=self.layer_inputs, @@ -70,12 +86,8 @@ def _run_subset_stage(self, subset): layer_descriptor="layer.0", layer_title="title", layer_index=0, - layers_prefix="model.layers", - subset=subset, - subset_index=0, - subset_total=1, full=self.full, - failsafe=False, + fallback=False, shared_kv_cache_dict=self.shared_kv_cache_dict, pb=self.pb, ) diff --git a/tests/test_multi_gpu_inference.py b/tests/test_multi_gpu_inference.py index 5062ea8f2..0c76cdf5e 100644 --- a/tests/test_multi_gpu_inference.py +++ b/tests/test_multi_gpu_inference.py @@ -6,6 +6,7 @@ # -- do not touch import os +import pytest import torch from transformers import AutoTokenizer @@ -15,9 +16,14 @@ import unittest # noqa: E402 +from models.model_test import ModelTest # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 +pytestmark = [pytest.mark.model, pytest.mark.slow] + + class TestMultiGPUInference(unittest.TestCase): @classmethod def setUpClass(self): @@ -26,7 +32,8 @@ def setUpClass(self): def test_multi_gpu_inference(self): cuda_device_count = torch.cuda.device_count() - self.assertGreaterEqual(cuda_device_count, 5, f"Expected CUDA device count to be greater than or equal to 5, but got {cuda_device_count}") + if cuda_device_count < 5: + self.skipTest(f"Need at least 5 visible CUDA devices, got {cuda_device_count}.") model = GPTQModel.load( self.MODEL_PATH, backend=BACKEND.TORCH, @@ -37,20 +44,20 @@ def test_multi_gpu_inference(self): messages = [ {"role": "user", "content": "How many p's are in the word \"apple\"? Please only respond with a number."}, ] - input_tensor = self.tokenizer.apply_chat_template( + model_inputs = self.tokenizer.apply_chat_template( messages, add_generation_prompt=True, return_tensors="pt" ) - outputs = model.generate( - inputs=input_tensor.to(model.device), - max_length=512 - ) - - result = self.tokenizer.decode( - outputs[0][input_tensor.shape[1]:], - skip_special_tokens=False + input_ids = model_inputs["input_ids"] + result = ModelTest.generate_stable_with_limit( + model, + self.tokenizer, + inputs=model_inputs, + max_new_tokens=512, + decode_start_idx=input_ids.shape[1], + skip_special_tokens=False, ) self.assertIn("2<|im_end|>", result.lower(), "The generated result should contain '2<|im_end|>'") diff --git a/tests/test_offload_files.py b/tests/test_offload_files.py index 80288f17b..dcb557557 100644 --- a/tests/test_offload_files.py +++ b/tests/test_offload_files.py @@ -4,15 +4,26 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import json +import struct from pathlib import Path +from types import SimpleNamespace +import pytest import torch from safetensors import safe_open +from safetensors.torch import save_file from tabulate import tabulate from torch import nn +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding -from gptqmodel.utils.model import get_state_dict_for_save, streaming_state_dict_to_shards +from gptqmodel.utils.model import get_state_dict_for_save, move_to, streaming_state_dict_to_shards from gptqmodel.utils.offload import offload_to_disk, undo_offload_to_disk +from gptqmodel.utils.structure import ( + LazyTurtle, + alias_all_from_turtle_if_meta, + alias_from_turtle_for_submodule, +) class _LinearWithBuffers(nn.Module): @@ -30,6 +41,272 @@ def _clone_state_dict(module: nn.Module) -> dict[str, torch.Tensor]: return {k: v.detach().clone() for k, v in module.state_dict().items()} +class _HybridBlock(nn.Module): + def __init__(self, width: int): + super().__init__() + self.inner = nn.Linear(width, width, bias=False) + self.dt_bias = nn.Parameter(torch.randn(width)) + self.register_buffer("dt_scale", torch.linspace(0.0, 1.0, width)) + + +class _HybridWrapper(nn.Module): + def __init__(self, width: int): + super().__init__() + self.block = _HybridBlock(width) + + +class _TransformerPrefixedHybridWrapper(nn.Module): + """Wrap the real module tree under an extra root to mimic shell-only prefixes.""" + + def __init__(self, width: int): + super().__init__() + self.transformer = _HybridWrapper(width) + + +class _SharedDirectBlock(nn.Module): + def __init__(self, width: int, shared_bias: nn.Parameter): + super().__init__() + self.inner = nn.Linear(width, width, bias=False) + self.dt_bias = shared_bias + + +class _SharedDirectWrapper(nn.Module): + def __init__(self, width: int): + super().__init__() + shared_bias = nn.Parameter(torch.randn(width)) + self.left = _SharedDirectBlock(width, shared_bias) + self.right = _SharedDirectBlock(width, shared_bias) + + +class _CustomParameter(nn.Parameter): + pass + + +class _CustomParamBlock(nn.Module): + def __init__(self, width: int): + super().__init__() + self.inner = nn.Linear(width, width, bias=False) + self.dt_bias = _CustomParameter(torch.randn(width)) + + +class _CustomParamWrapper(nn.Module): + def __init__(self, width: int): + super().__init__() + self.block = _CustomParamBlock(width) + + +def _tiny_llama_config() -> LlamaConfig: + # Keep the rotary test cheap while still using the real HF module init path. + return LlamaConfig( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=4, + num_key_value_heads=4, + vocab_size=128, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + +class _RotaryWrapper(nn.Module): + # Pair one checkpoint-backed tensor with a non-persistent rotary module. + def __init__(self, config: LlamaConfig): + super().__init__() + self.config = config + self.block = nn.Module() + self.block.linear = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + self.block.rotary = LlamaRotaryEmbedding(config, device=torch.device("cpu")) + + +class _AttrBufferTemplate(nn.Module): + """Template module whose non-persistent buffers depend on constructor attributes.""" + + def __init__(self, width: int, scale: float = 1.0, device=None): + super().__init__() + self.width = width + self.scale = scale + base = torch.arange(width, dtype=torch.float32, device=device) + self.register_buffer("cache", base * scale, persistent=False) + self.register_buffer("cache_plus_one", base + 1, persistent=False) + + +class _AttrBufferWrapper(nn.Module): + """Hybrid wrapper that pairs checkpoint tensors with attribute-driven init-only buffers.""" + + def __init__(self, width: int, scale: float = 1.0): + super().__init__() + self.block = nn.Module() + self.block.linear = nn.Linear(width, width, bias=False) + self.block.rotary = _AttrBufferTemplate(width=width, scale=scale, device=torch.device("cpu")) + + +class _ScalarMetaBufferTemplate(nn.Module): + """Template module that keeps the constructor scalar separately from the registered buffer.""" + + def __init__(self, scale: float, device=None): + super().__init__() + self.scalar_scale = scale + self.register_buffer("scale", torch.tensor(scale, dtype=torch.float32, device=device), persistent=False) + + +class _ScalarMetaBufferWrapper(nn.Module): + """Wrapper used to cover constructors that take a scalar but expose a same-named buffer.""" + + def __init__(self, scale: float): + super().__init__() + self.block = nn.Module() + self.block.linear = nn.Linear(8, 8, bias=False) + self.block.scale_holder = _ScalarMetaBufferTemplate(scale=scale, device=torch.device("cpu")) + + +class _SplitGateUpBlock(nn.Module): + """Tiny stand-in for Defuser runtime MLPs that expose split projections from a fused checkpoint tensor.""" + + def __init__(self, width: int, intermediate: int): + super().__init__() + self.gate_proj = nn.Linear(width, intermediate, bias=False) + self.up_proj = nn.Linear(width, intermediate, bias=False) + self.down_proj = nn.Linear(intermediate, width, bias=False) + + +class _SplitGateUpWrapper(nn.Module): + """Wrapper used to exercise lazy-turtle rematerialization from `gate_up_proj` checkpoints.""" + + def __init__(self, width: int, intermediate: int): + super().__init__() + self.block = _SplitGateUpBlock(width, intermediate) + + +class _TinyExpert(nn.Module): + """Small expert used to mirror Defuser's split expert runtime layout.""" + + def __init__(self, width: int): + super().__init__() + self.gate_proj = nn.Linear(width, width, bias=True) + self.up_proj = nn.Linear(width, width, bias=True) + self.down_proj = nn.Linear(width, width, bias=True) + + +class _RectExpert(nn.Module): + """Expert with distinct hidden/intermediate sizes to catch wrong split/transpose rules.""" + + def __init__(self, width: int, intermediate: int, *, bias: bool = True): + super().__init__() + self.gate_proj = nn.Linear(width, intermediate, bias=bias) + self.up_proj = nn.Linear(width, intermediate, bias=bias) + self.down_proj = nn.Linear(intermediate, width, bias=bias) + + +class _FusedExpertsWrapper(nn.Module): + """Wrapper used to exercise fused expert checkpoint slicing during lazy-turtle rematerialization.""" + + def __init__(self, width: int, expert_count: int): + super().__init__() + self.block = nn.Module() + self.block.experts = nn.Module() + for expert_idx in range(expert_count): + self.block.experts.add_module(str(expert_idx), _TinyExpert(width)) + + +class _RectFusedExpertsWrapper(nn.Module): + """Wrapper used to validate real rectangular expert layouts, including transposed storage.""" + + def __init__(self, width: int, intermediate: int, expert_count: int, *, is_transposed: bool): + super().__init__() + self.block = nn.Module() + self.block.experts = nn.Module() + self.block.experts.is_transposed = is_transposed + for expert_idx in range(expert_count): + self.block.experts.add_module(str(expert_idx), _RectExpert(width, intermediate)) + + +def _write_checkpoint_index(path: Path, shard_name: str, state_dict: dict[str, torch.Tensor]) -> None: + weight_map = dict.fromkeys(state_dict, shard_name) + (path / "model.safetensors.index.json").write_text(json.dumps({"weight_map": weight_map})) + + +def _build_lazy_turtle_from_module(tmp_path: Path, model: nn.Module) -> LazyTurtle: + """Persist cloned checkpoint values and reopen them through the lazy turtle source.""" + + state_dict = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()} + return _build_lazy_turtle_from_checkpoint_tensors(tmp_path, state_dict) + + +def _build_lazy_turtle_from_checkpoint_tensors(tmp_path: Path, checkpoint_tensors: dict[str, torch.Tensor]) -> LazyTurtle: + """Persist an arbitrary safetensors checkpoint and reopen it through LazyTurtle.""" + + model_dir = tmp_path / "source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + save_file(checkpoint_tensors, str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, checkpoint_tensors) + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + return source + + +def _build_rect_fused_expert_checkpoint_tensors( + source_model: _RectFusedExpertsWrapper, + *, + include_down_proj: bool = True, +) -> dict[str, torch.Tensor]: + """Build fused expert checkpoint tensors that mimic the HF/Defuser expert storage layouts.""" + + experts = [expert for _, expert in source_model.block.experts.named_children()] + is_transposed = bool(getattr(source_model.block.experts, "is_transposed", False)) + gate_split_dim = 1 if is_transposed else 0 + + checkpoint_tensors = { + "block.experts.gate_up_proj": torch.stack( + [ + torch.cat( + [ + expert.gate_proj.weight.detach().clone().transpose(0, 1) if is_transposed else expert.gate_proj.weight.detach().clone(), + expert.up_proj.weight.detach().clone().transpose(0, 1) if is_transposed else expert.up_proj.weight.detach().clone(), + ], + dim=gate_split_dim, + ) + for expert in experts + ], + dim=0, + ), + "block.experts.gate_up_proj_bias": torch.stack( + [ + torch.cat( + [ + expert.gate_proj.bias.detach().clone(), + expert.up_proj.bias.detach().clone(), + ], + dim=0, + ) + for expert in experts + ], + dim=0, + ), + } + + if include_down_proj: + checkpoint_tensors["block.experts.down_proj"] = torch.stack( + [ + expert.down_proj.weight.detach().clone().transpose(0, 1) if is_transposed else expert.down_proj.weight.detach().clone() + for expert in experts + ], + dim=0, + ) + checkpoint_tensors["block.experts.down_proj_bias"] = torch.stack( + [expert.down_proj.bias.detach().clone() for expert in experts], + dim=0, + ) + + return checkpoint_tensors + + def test_offload_to_disk_writes_single_dat_file(tmp_path): model = _LinearWithBuffers(in_features=128, out_features=96) original_state = _clone_state_dict(model.linear) @@ -79,3 +356,1026 @@ def test_offload_to_disk_writes_single_dat_file(tmp_path): undo_offload_to_disk(model.linear, delete_offload_folders=False) for name, tensor in model.linear.state_dict().items(): torch.testing.assert_close(tensor, original_state[name]) + + +def test_alias_all_from_turtle_restores_direct_meta_tensors_with_offloaded_children(tmp_path): + source_model = _HybridWrapper(width=64) + shell_model = _HybridWrapper(width=64) + shell_model.load_state_dict(source_model.state_dict()) + turtle_model = _build_lazy_turtle_from_module(tmp_path, source_model) + + original_state = _clone_state_dict(source_model) + offload_root = tmp_path / "offload_root" + offload_to_disk(module=shell_model.block.inner, model=shell_model, disk_path=str(offload_root)) + + shell_model.block.dt_bias = nn.Parameter( + torch.empty_like(shell_model.block.dt_bias, device="meta"), + requires_grad=shell_model.block.dt_bias.requires_grad, + ) + shell_model.block.register_buffer( + "dt_scale", + torch.empty_like(shell_model.block.dt_scale, device="meta"), + persistent=True, + ) + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=turtle_model) + + state_dict = get_state_dict_for_save(shell_model, offload_root=str(offload_root)) + save_dir = tmp_path / "saved" + save_dir.mkdir() + expected_files, _, _ = streaming_state_dict_to_shards( + state_dict, + save_dir=str(save_dir), + model_base_name="model", + single_file_name="model.safetensors", + metadata={}, + max_shard_size=None, + ) + + shard_path = save_dir / expected_files[0] + with safe_open(str(shard_path), framework="pt", device="cpu") as handler: + for name, tensor in original_state.items(): + saved = handler.get_tensor(name) + torch.testing.assert_close(saved, tensor) + + +def test_alias_all_from_turtle_preserves_shell_dtype_for_direct_meta_tensors(tmp_path): + source_model = _HybridWrapper(width=64) + shell_model = _HybridWrapper(width=64) + shell_model.load_state_dict(source_model.state_dict()) + turtle_model = _build_lazy_turtle_from_module(tmp_path, source_model) + + offload_root = tmp_path / "offload_root" + offload_to_disk(module=shell_model.block.inner, model=shell_model, disk_path=str(offload_root)) + + shell_model.block.dt_bias = nn.Parameter( + torch.empty(shell_model.block.dt_bias.shape, dtype=torch.float16, device="meta"), + requires_grad=shell_model.block.dt_bias.requires_grad, + ) + shell_model.block.register_buffer( + "dt_scale", + torch.empty(shell_model.block.dt_scale.shape, dtype=torch.float16, device="meta"), + persistent=True, + ) + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=turtle_model) + + assert shell_model.block.dt_bias.dtype == torch.float16 + assert shell_model.block.dt_scale.dtype == torch.float16 + torch.testing.assert_close(shell_model.block.dt_bias, source_model.block.dt_bias.to(torch.float16)) + torch.testing.assert_close(shell_model.block.dt_scale, source_model.block.dt_scale.to(torch.float16)) + + +def test_alias_all_from_turtle_materializes_shared_value_tensors(tmp_path): + source_model = _SharedDirectWrapper(width=64) + shell_model = _SharedDirectWrapper(width=64) + shell_model.load_state_dict(source_model.state_dict()) + turtle_model = _build_lazy_turtle_from_module(tmp_path, source_model) + + offload_root = tmp_path / "offload_root" + offload_to_disk(module=shell_model.left.inner, model=shell_model, disk_path=str(offload_root)) + offload_to_disk(module=shell_model.right.inner, model=shell_model, disk_path=str(offload_root)) + + shell_model.left.dt_bias = nn.Parameter(torch.empty_like(shell_model.left.dt_bias, device="meta")) + shell_model.right.dt_bias = nn.Parameter(torch.empty_like(shell_model.right.dt_bias, device="meta")) + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=turtle_model) + + torch.testing.assert_close(shell_model.left.dt_bias, source_model.left.dt_bias) + torch.testing.assert_close(shell_model.right.dt_bias, source_model.right.dt_bias) + + +def test_alias_all_from_turtle_materializes_custom_parameter_checkpoint_values(tmp_path): + source_model = _CustomParamWrapper(width=16) + shell_model = _CustomParamWrapper(width=16) + shell_model.load_state_dict(source_model.state_dict()) + turtle_model = _build_lazy_turtle_from_module(tmp_path, source_model) + + shell_model.block.dt_bias = nn.Parameter(torch.empty_like(shell_model.block.dt_bias, device="meta")) + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=turtle_model) + + torch.testing.assert_close(shell_model.block.dt_bias, source_model.block.dt_bias) + + +def test_lazy_turtle_materializes_recursive_submodule(tmp_path): + source_model = _HybridWrapper(width=16) + model_dir = tmp_path / "source_model" + model_dir.mkdir() + + shard_name = "model.safetensors" + save_file(source_model.state_dict(), str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, source_model.state_dict()) + + shell_model = _HybridWrapper(width=16) + shell_model.load_state_dict(source_model.state_dict()) + shell_model.block.inner.weight = nn.Parameter( + torch.empty_like(shell_model.block.inner.weight, device="meta"), + requires_grad=shell_model.block.inner.weight.requires_grad, + ) + shell_model.block.dt_bias = nn.Parameter( + torch.empty_like(shell_model.block.dt_bias, device="meta"), + requires_grad=shell_model.block.dt_bias.requires_grad, + ) + shell_model.block.register_buffer( + "dt_scale", + torch.empty_like(shell_model.block.dt_scale, device="meta"), + persistent=True, + ) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + torch.testing.assert_close(shell_model.block.inner.weight, source_model.block.inner.weight) + torch.testing.assert_close(shell_model.block.dt_bias, source_model.block.dt_bias) + torch.testing.assert_close(shell_model.block.dt_scale, source_model.block.dt_scale) + + +def test_lazy_turtle_materializes_submodule_when_shell_has_extra_root_prefix(tmp_path): + """Checkpoint-backed materialization should ignore wrapper prefixes absent from shard names.""" + + source_model = _HybridWrapper(width=16) + model_dir = tmp_path / "source_model" + model_dir.mkdir() + + shard_name = "model.safetensors" + save_file(source_model.state_dict(), str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, source_model.state_dict()) + + shell_model = _TransformerPrefixedHybridWrapper(width=16) + shell_model.transformer.load_state_dict(source_model.state_dict()) + shell_model.transformer.block.inner.weight = nn.Parameter( + torch.empty_like(shell_model.transformer.block.inner.weight, device="meta"), + requires_grad=shell_model.transformer.block.inner.weight.requires_grad, + ) + shell_model.transformer.block.dt_bias = nn.Parameter( + torch.empty_like(shell_model.transformer.block.dt_bias, device="meta"), + requires_grad=shell_model.transformer.block.dt_bias.requires_grad, + ) + shell_model.transformer.block.register_buffer( + "dt_scale", + torch.empty_like(shell_model.transformer.block.dt_scale, device="meta"), + persistent=True, + ) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.transformer.block, + device=torch.device("cpu"), + ) + + torch.testing.assert_close(shell_model.transformer.block.inner.weight, source_model.block.inner.weight) + torch.testing.assert_close(shell_model.transformer.block.dt_bias, source_model.block.dt_bias) + torch.testing.assert_close(shell_model.transformer.block.dt_scale, source_model.block.dt_scale) + + +def test_alias_all_from_lazy_turtle_handles_shell_root_prefix_mismatch(tmp_path): + """Direct meta tensors should resolve through the same prefix-stripping checkpoint aliases.""" + + source_model = _HybridWrapper(width=16) + turtle_model = _build_lazy_turtle_from_module(tmp_path, source_model) + + shell_model = _TransformerPrefixedHybridWrapper(width=16) + shell_model.transformer.load_state_dict(source_model.state_dict()) + shell_model.transformer.block.dt_bias = nn.Parameter( + torch.empty_like(shell_model.transformer.block.dt_bias, device="meta"), + requires_grad=shell_model.transformer.block.dt_bias.requires_grad, + ) + shell_model.transformer.block.register_buffer( + "dt_scale", + torch.empty_like(shell_model.transformer.block.dt_scale, device="meta"), + persistent=True, + ) + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=turtle_model) + + torch.testing.assert_close(shell_model.transformer.block.dt_bias, source_model.block.dt_bias) + torch.testing.assert_close(shell_model.transformer.block.dt_scale, source_model.block.dt_scale) + + +def test_lazy_turtle_restores_nonpersistent_buffers_from_module_init(tmp_path): + config = _tiny_llama_config() + source_model = _RotaryWrapper(config) + shell_model = _RotaryWrapper(config) + shell_model.load_state_dict(source_model.state_dict()) + + shell_model.block.linear.weight = nn.Parameter( + torch.empty_like(shell_model.block.linear.weight, device="meta"), + requires_grad=shell_model.block.linear.weight.requires_grad, + ) + shell_model.block.rotary.register_buffer( + "inv_freq", + torch.empty_like(shell_model.block.rotary.inv_freq, device="meta"), + persistent=False, + ) + shell_model.block.rotary.register_buffer( + "original_inv_freq", + torch.empty_like(shell_model.block.rotary.original_inv_freq, device="meta"), + persistent=False, + ) + + source = _build_lazy_turtle_from_module(tmp_path, source_model) + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + torch.testing.assert_close(shell_model.block.linear.weight, source_model.block.linear.weight) + torch.testing.assert_close(shell_model.block.rotary.inv_freq, source_model.block.rotary.inv_freq) + torch.testing.assert_close(shell_model.block.rotary.original_inv_freq, source_model.block.rotary.original_inv_freq) + assert shell_model.block.rotary.inv_freq.device.type == "cpu" + assert shell_model.block.rotary._non_persistent_buffers_set == {"inv_freq", "original_inv_freq"} + + +def test_lazy_turtle_restores_nonpersistent_buffers_from_attribute_ctor(tmp_path): + """Init-only buffers should rebuild from constructor attributes when no config argument exists.""" + + source_model = _AttrBufferWrapper(width=16, scale=0.5) + shell_model = _AttrBufferWrapper(width=16, scale=0.5) + shell_model.load_state_dict(source_model.state_dict()) + + shell_model.block.linear.weight = nn.Parameter( + torch.empty_like(shell_model.block.linear.weight, device="meta"), + requires_grad=shell_model.block.linear.weight.requires_grad, + ) + shell_model.block.rotary.register_buffer( + "cache", + torch.empty_like(shell_model.block.rotary.cache, device="meta"), + persistent=False, + ) + shell_model.block.rotary.register_buffer( + "cache_plus_one", + torch.empty_like(shell_model.block.rotary.cache_plus_one, device="meta"), + persistent=False, + ) + + source = _build_lazy_turtle_from_module(tmp_path, source_model) + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + torch.testing.assert_close(shell_model.block.linear.weight, source_model.block.linear.weight) + torch.testing.assert_close(shell_model.block.rotary.cache, source_model.block.rotary.cache) + torch.testing.assert_close(shell_model.block.rotary.cache_plus_one, source_model.block.rotary.cache_plus_one) + assert shell_model.block.rotary.cache.device.type == "cpu" + assert shell_model.block.rotary._non_persistent_buffers_set == {"cache", "cache_plus_one"} + + +def test_lazy_turtle_restores_nonpersistent_buffers_from_scalar_shadow_attr(tmp_path): + """Scalar constructor args should not be reconstructed from same-named meta buffers.""" + + source_model = _ScalarMetaBufferWrapper(scale=3.5) + shell_model = _ScalarMetaBufferWrapper(scale=3.5) + shell_model.load_state_dict(source_model.state_dict()) + + shell_model.block.linear.weight = nn.Parameter( + torch.empty_like(shell_model.block.linear.weight, device="meta"), + requires_grad=shell_model.block.linear.weight.requires_grad, + ) + shell_model.block.scale_holder.register_buffer( + "scale", + torch.empty_like(shell_model.block.scale_holder.scale, device="meta"), + persistent=False, + ) + + source = _build_lazy_turtle_from_module(tmp_path, source_model) + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + torch.testing.assert_close(shell_model.block.linear.weight, source_model.block.linear.weight) + torch.testing.assert_close(shell_model.block.scale_holder.scale, source_model.block.scale_holder.scale) + assert shell_model.block.scale_holder.scale.device.type == "cpu" + assert shell_model.block.scale_holder._non_persistent_buffers_set == {"scale"} + + +def test_lazy_turtle_materializes_split_gate_up_from_fused_checkpoint_tensor(tmp_path): + """Defused runtime `gate_proj`/`up_proj` leaves should restore from fused checkpoint `gate_up_proj` weights.""" + + source_model = _SplitGateUpWrapper(width=8, intermediate=6) + shell_model = _SplitGateUpWrapper(width=8, intermediate=6) + shell_model.load_state_dict(source_model.state_dict()) + + model_dir = tmp_path / "source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + checkpoint_tensors = { + "block.gate_up_proj.weight": torch.cat( + [ + source_model.block.gate_proj.weight.detach().clone(), + source_model.block.up_proj.weight.detach().clone(), + ], + dim=0, + ), + "block.down_proj.weight": source_model.block.down_proj.weight.detach().clone(), + } + save_file(checkpoint_tensors, str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, checkpoint_tensors) + + shell_model.block.gate_proj.weight = nn.Parameter( + torch.empty_like(shell_model.block.gate_proj.weight, device="meta"), + requires_grad=shell_model.block.gate_proj.weight.requires_grad, + ) + shell_model.block.up_proj.weight = nn.Parameter( + torch.empty_like(shell_model.block.up_proj.weight, device="meta"), + requires_grad=shell_model.block.up_proj.weight.requires_grad, + ) + shell_model.block.down_proj.weight = nn.Parameter( + torch.empty_like(shell_model.block.down_proj.weight, device="meta"), + requires_grad=shell_model.block.down_proj.weight.requires_grad, + ) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + torch.testing.assert_close(shell_model.block.gate_proj.weight, source_model.block.gate_proj.weight) + torch.testing.assert_close(shell_model.block.up_proj.weight, source_model.block.up_proj.weight) + torch.testing.assert_close(shell_model.block.down_proj.weight, source_model.block.down_proj.weight) + + +def test_lazy_turtle_materializes_split_experts_from_fused_checkpoint_tensors(tmp_path): + """Fused expert checkpoint tensors should rematerialize defused `experts..*` leaves.""" + + source_model = _FusedExpertsWrapper(width=4, expert_count=2) + shell_model = _FusedExpertsWrapper(width=4, expert_count=2) + shell_model.load_state_dict(source_model.state_dict()) + + model_dir = tmp_path / "source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + checkpoint_tensors = { + "block.experts.gate_up_proj": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.weight.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.weight.detach().clone(), + ], + dim=1, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.gate_up_proj_bias": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.bias.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.bias.detach().clone(), + ], + dim=0, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.down_proj": torch.stack( + [ + source_model.block.experts.get_submodule(str(expert_idx)).down_proj.weight.detach().clone() + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.down_proj_bias": torch.stack( + [ + source_model.block.experts.get_submodule(str(expert_idx)).down_proj.bias.detach().clone() + for expert_idx in range(2) + ], + dim=0, + ), + } + save_file(checkpoint_tensors, str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, checkpoint_tensors) + + for expert_idx in range(2): + expert = shell_model.block.experts.get_submodule(str(expert_idx)) + expert.gate_proj.weight = nn.Parameter( + torch.empty_like(expert.gate_proj.weight, device="meta"), + requires_grad=expert.gate_proj.weight.requires_grad, + ) + expert.gate_proj.bias = nn.Parameter( + torch.empty_like(expert.gate_proj.bias, device="meta"), + requires_grad=expert.gate_proj.bias.requires_grad, + ) + expert.up_proj.weight = nn.Parameter( + torch.empty_like(expert.up_proj.weight, device="meta"), + requires_grad=expert.up_proj.weight.requires_grad, + ) + expert.up_proj.bias = nn.Parameter( + torch.empty_like(expert.up_proj.bias, device="meta"), + requires_grad=expert.up_proj.bias.requires_grad, + ) + expert.down_proj.weight = nn.Parameter( + torch.empty_like(expert.down_proj.weight, device="meta"), + requires_grad=expert.down_proj.weight.requires_grad, + ) + expert.down_proj.bias = nn.Parameter( + torch.empty_like(expert.down_proj.bias, device="meta"), + requires_grad=expert.down_proj.bias.requires_grad, + ) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + for expert_idx in range(2): + expected = source_model.block.experts.get_submodule(str(expert_idx)) + actual = shell_model.block.experts.get_submodule(str(expert_idx)) + torch.testing.assert_close(actual.gate_proj.weight, expected.gate_proj.weight) + torch.testing.assert_close(actual.gate_proj.bias, expected.gate_proj.bias) + torch.testing.assert_close(actual.up_proj.weight, expected.up_proj.weight) + torch.testing.assert_close(actual.up_proj.bias, expected.up_proj.bias) + torch.testing.assert_close(actual.down_proj.weight, expected.down_proj.weight) + torch.testing.assert_close(actual.down_proj.bias, expected.down_proj.bias) + + +def test_lazy_turtle_materializes_rectangular_qwen_style_experts_from_fused_checkpoint_tensors(tmp_path): + """Non-transposed expert checkpoints should split gate/up along the output dimension.""" + + source_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=False) + shell_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=False) + shell_model.load_state_dict(source_model.state_dict()) + + model_dir = tmp_path / "source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + checkpoint_tensors = { + "block.experts.gate_up_proj": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.weight.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.weight.detach().clone(), + ], + dim=0, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.gate_up_proj_bias": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.bias.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.bias.detach().clone(), + ], + dim=0, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.down_proj": torch.stack( + [ + source_model.block.experts.get_submodule(str(expert_idx)).down_proj.weight.detach().clone() + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.down_proj_bias": torch.stack( + [ + source_model.block.experts.get_submodule(str(expert_idx)).down_proj.bias.detach().clone() + for expert_idx in range(2) + ], + dim=0, + ), + } + save_file(checkpoint_tensors, str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, checkpoint_tensors) + + for expert_idx in range(2): + expert = shell_model.block.experts.get_submodule(str(expert_idx)) + expert.gate_proj.weight = nn.Parameter( + torch.empty_like(expert.gate_proj.weight, device="meta"), + requires_grad=expert.gate_proj.weight.requires_grad, + ) + expert.gate_proj.bias = nn.Parameter( + torch.empty_like(expert.gate_proj.bias, device="meta"), + requires_grad=expert.gate_proj.bias.requires_grad, + ) + expert.up_proj.weight = nn.Parameter( + torch.empty_like(expert.up_proj.weight, device="meta"), + requires_grad=expert.up_proj.weight.requires_grad, + ) + expert.up_proj.bias = nn.Parameter( + torch.empty_like(expert.up_proj.bias, device="meta"), + requires_grad=expert.up_proj.bias.requires_grad, + ) + expert.down_proj.weight = nn.Parameter( + torch.empty_like(expert.down_proj.weight, device="meta"), + requires_grad=expert.down_proj.weight.requires_grad, + ) + expert.down_proj.bias = nn.Parameter( + torch.empty_like(expert.down_proj.bias, device="meta"), + requires_grad=expert.down_proj.bias.requires_grad, + ) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + for expert_idx in range(2): + expected = source_model.block.experts.get_submodule(str(expert_idx)) + actual = shell_model.block.experts.get_submodule(str(expert_idx)) + torch.testing.assert_close(actual.gate_proj.weight, expected.gate_proj.weight) + torch.testing.assert_close(actual.gate_proj.bias, expected.gate_proj.bias) + torch.testing.assert_close(actual.up_proj.weight, expected.up_proj.weight) + torch.testing.assert_close(actual.up_proj.bias, expected.up_proj.bias) + torch.testing.assert_close(actual.down_proj.weight, expected.down_proj.weight) + torch.testing.assert_close(actual.down_proj.bias, expected.down_proj.bias) + + +def test_lazy_turtle_materializes_leaf_qwen_style_expert_gate_proj_from_fused_checkpoint_tensor(tmp_path): + """Leaf expert gate_proj modules should resolve fused expert sources from module_path alone.""" + + source_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=False) + shell_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=False) + shell_model.load_state_dict(source_model.state_dict()) + + model_dir = tmp_path / "source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + checkpoint_tensors = { + "block.experts.gate_up_proj": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.weight.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.weight.detach().clone(), + ], + dim=0, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.gate_up_proj_bias": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.bias.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.bias.detach().clone(), + ], + dim=0, + ) + for expert_idx in range(2) + ], + dim=0, + ), + } + save_file(checkpoint_tensors, str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, checkpoint_tensors) + + expert = shell_model.block.experts.get_submodule("1").gate_proj + expert.weight = nn.Parameter(torch.empty_like(expert.weight, device="meta"), requires_grad=expert.weight.requires_grad) + expert.bias = nn.Parameter(torch.empty_like(expert.bias, device="meta"), requires_grad=expert.bias.requires_grad) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=expert, + device=torch.device("cpu"), + ) + + expected = source_model.block.experts.get_submodule("1").gate_proj + torch.testing.assert_close(expert.weight, expected.weight) + torch.testing.assert_close(expert.bias, expected.bias) + + +def test_lazy_turtle_materializes_rectangular_transposed_experts_from_fused_checkpoint_tensors(tmp_path): + """Transposed expert checkpoints should transpose weights before matching defused leaves.""" + + source_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=True) + shell_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=True) + shell_model.load_state_dict(source_model.state_dict()) + + model_dir = tmp_path / "source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + checkpoint_tensors = { + "block.experts.gate_up_proj": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.weight.detach().clone().transpose(0, 1), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.weight.detach().clone().transpose(0, 1), + ], + dim=1, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.gate_up_proj_bias": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.bias.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.bias.detach().clone(), + ], + dim=0, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.down_proj": torch.stack( + [ + source_model.block.experts.get_submodule(str(expert_idx)).down_proj.weight.detach().clone().transpose(0, 1) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.down_proj_bias": torch.stack( + [ + source_model.block.experts.get_submodule(str(expert_idx)).down_proj.bias.detach().clone() + for expert_idx in range(2) + ], + dim=0, + ), + } + save_file(checkpoint_tensors, str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, checkpoint_tensors) + + for expert_idx in range(2): + expert = shell_model.block.experts.get_submodule(str(expert_idx)) + expert.gate_proj.weight = nn.Parameter( + torch.empty_like(expert.gate_proj.weight, device="meta"), + requires_grad=expert.gate_proj.weight.requires_grad, + ) + expert.gate_proj.bias = nn.Parameter( + torch.empty_like(expert.gate_proj.bias, device="meta"), + requires_grad=expert.gate_proj.bias.requires_grad, + ) + expert.up_proj.weight = nn.Parameter( + torch.empty_like(expert.up_proj.weight, device="meta"), + requires_grad=expert.up_proj.weight.requires_grad, + ) + expert.up_proj.bias = nn.Parameter( + torch.empty_like(expert.up_proj.bias, device="meta"), + requires_grad=expert.up_proj.bias.requires_grad, + ) + expert.down_proj.weight = nn.Parameter( + torch.empty_like(expert.down_proj.weight, device="meta"), + requires_grad=expert.down_proj.weight.requires_grad, + ) + expert.down_proj.bias = nn.Parameter( + torch.empty_like(expert.down_proj.bias, device="meta"), + requires_grad=expert.down_proj.bias.requires_grad, + ) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + for expert_idx in range(2): + expected = source_model.block.experts.get_submodule(str(expert_idx)) + actual = shell_model.block.experts.get_submodule(str(expert_idx)) + torch.testing.assert_close(actual.gate_proj.weight, expected.gate_proj.weight) + torch.testing.assert_close(actual.gate_proj.bias, expected.gate_proj.bias) + torch.testing.assert_close(actual.up_proj.weight, expected.up_proj.weight) + torch.testing.assert_close(actual.up_proj.bias, expected.up_proj.bias) + torch.testing.assert_close(actual.down_proj.weight, expected.down_proj.weight) + torch.testing.assert_close(actual.down_proj.bias, expected.down_proj.bias) + + +def test_lazy_turtle_materializes_leaf_transposed_expert_gate_proj_from_fused_checkpoint_tensor(tmp_path): + """Leaf expert gate_proj modules should also honor transposed fused expert layouts.""" + + source_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=True) + shell_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=True) + shell_model.load_state_dict(source_model.state_dict()) + + model_dir = tmp_path / "source_model" + model_dir.mkdir() + shard_name = "model.safetensors" + checkpoint_tensors = { + "block.experts.gate_up_proj": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.weight.detach().clone().transpose(0, 1), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.weight.detach().clone().transpose(0, 1), + ], + dim=1, + ) + for expert_idx in range(2) + ], + dim=0, + ), + "block.experts.gate_up_proj_bias": torch.stack( + [ + torch.cat( + [ + source_model.block.experts.get_submodule(str(expert_idx)).gate_proj.bias.detach().clone(), + source_model.block.experts.get_submodule(str(expert_idx)).up_proj.bias.detach().clone(), + ], + dim=0, + ) + for expert_idx in range(2) + ], + dim=0, + ), + } + save_file(checkpoint_tensors, str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, checkpoint_tensors) + + expert = shell_model.block.experts.get_submodule("1").gate_proj + expert.weight = nn.Parameter(torch.empty_like(expert.weight, device="meta"), requires_grad=expert.weight.requires_grad) + expert.bias = nn.Parameter(torch.empty_like(expert.bias, device="meta"), requires_grad=expert.bias.requires_grad) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + assert source is not None + + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=expert, + device=torch.device("cpu"), + ) + + expected = source_model.block.experts.get_submodule("1").gate_proj + torch.testing.assert_close(expert.weight, expected.weight) + torch.testing.assert_close(expert.bias, expected.bias) + + +def test_alias_all_from_turtle_materializes_leaf_qwen_style_expert_gate_proj_from_fused_checkpoint_tensor(tmp_path): + """Direct-meta sync should resolve fused non-transposed expert tensors for leaf Linear modules.""" + + source_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=False) + shell_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=False) + shell_model.load_state_dict(source_model.state_dict()) + + expert = shell_model.block.experts.get_submodule("1").gate_proj + expert.weight = nn.Parameter(torch.empty_like(expert.weight, device="meta"), requires_grad=expert.weight.requires_grad) + expert.bias = nn.Parameter(torch.empty_like(expert.bias, device="meta"), requires_grad=expert.bias.requires_grad) + + source = _build_lazy_turtle_from_checkpoint_tensors( + tmp_path, + _build_rect_fused_expert_checkpoint_tensors(source_model, include_down_proj=False), + ) + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=source) + + expected = source_model.block.experts.get_submodule("1").gate_proj + torch.testing.assert_close(expert.weight, expected.weight) + torch.testing.assert_close(expert.bias, expected.bias) + + +def test_alias_all_from_turtle_materializes_leaf_transposed_expert_gate_proj_from_fused_checkpoint_tensor(tmp_path): + """Direct-meta sync should resolve fused transposed expert tensors for leaf Linear modules.""" + + source_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=True) + shell_model = _RectFusedExpertsWrapper(width=8, intermediate=6, expert_count=2, is_transposed=True) + shell_model.load_state_dict(source_model.state_dict()) + + expert = shell_model.block.experts.get_submodule("1").gate_proj + expert.weight = nn.Parameter(torch.empty_like(expert.weight, device="meta"), requires_grad=expert.weight.requires_grad) + expert.bias = nn.Parameter(torch.empty_like(expert.bias, device="meta"), requires_grad=expert.bias.requires_grad) + + source = _build_lazy_turtle_from_checkpoint_tensors( + tmp_path, + _build_rect_fused_expert_checkpoint_tensors(source_model, include_down_proj=False), + ) + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=source) + + expected = source_model.block.experts.get_submodule("1").gate_proj + torch.testing.assert_close(expert.weight, expected.weight) + torch.testing.assert_close(expert.bias, expected.bias) + + +def test_lazy_turtle_raises_when_submodule_materialization_cannot_match_target_shape(tmp_path): + """Shape-derived transform failures should fail materialization immediately.""" + + source_model = _HybridWrapper(width=16) + shell_model = _HybridWrapper(width=16) + shell_model.load_state_dict(source_model.state_dict()) + + shell_model.block.dt_bias = nn.Parameter(torch.empty(8, device="meta"), requires_grad=source_model.block.dt_bias.requires_grad) + source = _build_lazy_turtle_from_module(tmp_path, source_model) + + with pytest.raises( + RuntimeError, + match=r"submodule materialization param `dt_bias`.*could not be reshaped into the target layout.*target_shape=\(8,\)", + ): + alias_from_turtle_for_submodule( + target_model=shell_model, + turtle_model=source, + target_submodule=shell_model.block, + device=torch.device("cpu"), + ) + + +def test_alias_all_from_turtle_raises_when_direct_meta_shape_mismatch_slips_past_transform(monkeypatch, tmp_path): + """Post-transform shape mismatches in direct-meta sync should fail immediately.""" + + source_model = _HybridWrapper(width=16) + shell_model = _HybridWrapper(width=16) + shell_model.load_state_dict(source_model.state_dict()) + + shell_model.block.dt_bias = nn.Parameter( + torch.empty_like(shell_model.block.dt_bias, device="meta"), + requires_grad=source_model.block.dt_bias.requires_grad, + ) + source = _build_lazy_turtle_from_module(tmp_path, source_model) + + original_transform = LazyTurtle._transform_checkpoint_tensor + + def _return_wrong_shape(tensor: torch.Tensor, **kwargs) -> torch.Tensor | None: + transformed = original_transform(tensor, **kwargs) + if transformed is None: + return None + return transformed[:-1].contiguous() + + # `_transform_checkpoint_tensor()` now guards most shape mismatches up front. + # Monkeypatch it here so the regression test still exercises the downstream + # hard-failure branch that protects against malformed custom transforms. + monkeypatch.setattr(LazyTurtle, "_transform_checkpoint_tensor", staticmethod(_return_wrong_shape)) + + with pytest.raises( + RuntimeError, + match=r"direct-meta sync param `dt_bias`.*shape does not match the transformed checkpoint tensor.*source_shape=\(15,\).*target_shape=\(16,\)", + ): + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=source) + + +def test_alias_all_from_lazy_turtle_restores_direct_meta_tensors(tmp_path): + source_model = _HybridWrapper(width=16) + model_dir = tmp_path / "source_model" + model_dir.mkdir() + + shard_name = "model.safetensors" + save_file(source_model.state_dict(), str(model_dir / shard_name)) + _write_checkpoint_index(model_dir, shard_name, source_model.state_dict()) + + shell_model = _HybridWrapper(width=16) + shell_model.load_state_dict(source_model.state_dict()) + shell_model.block.dt_bias = nn.Parameter( + torch.empty_like(shell_model.block.dt_bias, device="meta"), + requires_grad=shell_model.block.dt_bias.requires_grad, + ) + shell_model.block.register_buffer( + "dt_scale", + torch.empty_like(shell_model.block.dt_scale, device="meta"), + persistent=True, + ) + + source = LazyTurtle.maybe_create( + model_local_path=str(model_dir), + config=SimpleNamespace(_experts_implementation=None), + model_init_kwargs={"device_map": {"": "cpu"}}, + ) + + assert source is not None + + alias_all_from_turtle_if_meta(shell_model=shell_model, turtle_model=source) + + torch.testing.assert_close(shell_model.block.dt_bias, source_model.block.dt_bias) + torch.testing.assert_close(shell_model.block.dt_scale, source_model.block.dt_scale) + + +def test_streaming_state_dict_pads_safetensors_header_to_8_bytes(tmp_path): + model = nn.Linear(3, 5, bias=False) + state_dict = get_state_dict_for_save(model) + + # Force an unaligned raw header so the regression test would fail without padding. + metadata = {"format": "pt"} + for size in range(1, 33): + candidate = dict(metadata, pad=("x" * size)) + raw_header = { + "__metadata__": candidate, + "weight": { + "dtype": "F32", + "shape": list(model.weight.shape), + "data_offsets": [0, model.weight.numel() * model.weight.element_size()], + }, + } + raw_header_len = len(json.dumps(raw_header, separators=(",", ":")).encode("utf-8")) + if raw_header_len % 8 != 0: + metadata = candidate + break + else: + raise AssertionError("Failed to construct an unaligned safetensors header for the regression test.") + + save_dir = tmp_path / "saved" + save_dir.mkdir() + expected_files, _, _ = streaming_state_dict_to_shards( + state_dict, + save_dir=str(save_dir), + model_base_name="model", + single_file_name="model.safetensors", + metadata=metadata, + max_shard_size=None, + ) + + shard_path = save_dir / expected_files[0] + with shard_path.open("rb") as handle: + stored_header_len = struct.unpack(" int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return sock.getsockname()[1] + + @classmethod + def setUpClass(cls): + cls.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" + cls.HOST = "127.0.0.1" + cls.PORT = cls._pick_free_port() + cls.model = GPTQModel.load(cls.MODEL_ID) + @classmethod - def setUpClass(self): - self.MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" - self.HOST = "127.0.0.1" - self.PORT = 23900 - self.model = GPTQModel.load(self.MODEL_ID) + def tearDownClass(cls): + try: + cls.model.serve_shutdown() + except Exception as exc: + # Shutdown is best-effort here; surface failures without masking the test result. + print(f"serve_shutdown failed during tearDownClass: {exc}") def test_openai_server(self): diff --git a/tests/test_out_of_model_tensors.py b/tests/test_out_of_model_tensors.py new file mode 100644 index 000000000..b97ea148f --- /dev/null +++ b/tests/test_out_of_model_tensors.py @@ -0,0 +1,537 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import copy +import json +import os +from types import SimpleNamespace + +import torch +from safetensors import safe_open +from safetensors.torch import save_file +from transformers import AutoConfig + +from gptqmodel.models.auto import check_and_get_model_definition +from gptqmodel.models.writer import ModelWriter +from gptqmodel.quantization.config import FORMAT, METHOD +from gptqmodel.utils.model import TensorSource + + +class _DummyKernel: + REQUIRES_FORMAT_V2 = False + SUPPORTS_SHARDS = True + + +class _DummyQuantizeConfig: + method = METHOD.GPTQ + format = FORMAT.GPTQ + checkpoint_format = FORMAT.GPTQ + quant_method = METHOD.GPTQ + damp_percent = 0.0 + damp_auto_increment = 0.0 + static_groups = False + true_sequential = False + mse = False + gptaq = None + act_group_aware = False + adapter = None + dynamic = False + offload_to_disk = False + offload_to_disk_path = None + lm_head = False + + def __init__(self): + self._meta = {} + + def __deepcopy__(self, memo): + clone = type(self)() + memo[id(self)] = clone + clone._meta = copy.deepcopy(self._meta, memo) + return clone + + def meta_set_versionable(self, key, value): + self._meta[key] = value + + def meta_set(self, key, value): + self._meta[key] = value + + def to_dict(self): + return {"meta": dict(self._meta)} + + def save_pretrained(self, save_dir): + with open(os.path.join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as handle: + json.dump({"meta": dict(self._meta)}, handle) + + def extract_adapter_rank_patterns(self): + return {} + + +class _DummyConfig: + def __init__(self): + self.some_field = 1 + + def __deepcopy__(self, memo): + clone = type(self)() + memo[id(self)] = clone + clone.__dict__ = copy.deepcopy(self.__dict__, memo) + return clone + + +class _DummyGenerationConfig(_DummyConfig): + pass + + +_REAL_GLM4_MOE_CONFIG = { + # Based on a real GLM-4.5-Air MoE config.json, reduced for a fast unit test. + "architectures": ["Glm4MoeForCausalLM"], + "attention_bias": True, + "attention_dropout": 0.0, + "bos_token_id": None, + "dtype": "bfloat16", + "eos_token_id": [1, 2, 3], + "first_k_dense_replace": 1, + "head_dim": 16, + "hidden_act": "silu", + "hidden_size": 64, + "initializer_range": 0.02, + "intermediate_size": 128, + "max_position_embeddings": 256, + "model_type": "glm4_moe", + "moe_intermediate_size": 32, + "n_group": 1, + "n_routed_experts": 2, + "n_shared_experts": 1, + "norm_topk_prob": True, + "num_attention_heads": 4, + "num_experts_per_tok": 1, + "num_hidden_layers": 1, + "num_key_value_heads": 1, + "num_nextn_predict_layers": 1, + "pad_token_id": 0, + "partial_rotary_factor": 0.5, + "rms_norm_eps": 1e-5, + "rope_parameters": { + "partial_rotary_factor": 0.5, + "rope_theta": 1_000_000, + "rope_type": "default", + }, + "routed_scaling_factor": 2.5, + "tie_word_embeddings": False, + "topk_group": 1, + "transformers_version": "5.5.0", + "use_cache": True, + "use_qk_norm": True, + "vocab_size": 256, +} + +_REAL_QWEN3_5_MOE_CONFIG = { + # Based on a real Qwen3.5-MoE config.json, reduced for a fast unit test. + "architectures": ["Qwen3_5MoeForConditionalGeneration"], + "dtype": "bfloat16", + "model_type": "qwen3_5_moe", + "image_token_id": 1, + "video_token_id": 2, + "vision_start_token_id": 3, + "vision_end_token_id": 4, + "tie_word_embeddings": False, + "transformers_version": "5.5.0", + "text_config": { + "dtype": "bfloat16", + "model_type": "qwen3_5_moe_text", + "attention_bias": False, + "attention_dropout": 0.0, + "bos_token_id": None, + "eos_token_id": 5, + "hidden_act": "silu", + "hidden_size": 64, + "initializer_range": 0.02, + "head_dim": 16, + "num_attention_heads": 4, + "num_key_value_heads": 1, + "num_hidden_layers": 1, + "num_experts": 2, + "num_experts_per_tok": 1, + "moe_intermediate_size": 32, + "shared_expert_intermediate_size": 32, + "layer_types": ["full_attention"], + "linear_conv_kernel_dim": 4, + "linear_key_head_dim": 16, + "linear_num_key_heads": 2, + "linear_num_value_heads": 4, + "linear_value_head_dim": 16, + "max_position_embeddings": 256, + "output_router_logits": False, + "pad_token_id": 5, + "partial_rotary_factor": 0.25, + "rms_norm_eps": 1e-6, + "rope_parameters": { + "partial_rotary_factor": 0.25, + "rope_theta": 10_000, + "rope_type": "default", + }, + "router_aux_loss_coef": 0.001, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 256, + }, + "vision_config": { + "model_type": "qwen3_5_moe", + "depth": 1, + "dtype": "bfloat16", + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 32, + "in_channels": 3, + "initializer_range": 0.02, + "intermediate_size": 64, + "num_heads": 4, + "num_position_embeddings": 64, + "out_hidden_size": 64, + "patch_size": 4, + "spatial_merge_size": 1, + "temporal_patch_size": 2, + }, +} + + +class _DummyModel: + def __init__(self): + self.config = _DummyConfig() + self.generation_config = _DummyGenerationConfig() + + def save_pretrained(self, save_dir, state_dict=None, is_main_process=True): + with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as handle: + json.dump({"dummy": True}, handle) + with open(os.path.join(save_dir, "generation_config.json"), "w", encoding="utf-8") as handle: + json.dump({"do_sample": True}, handle) + + +def _tensor_source(name: str, tensor: torch.Tensor) -> TensorSource: + return TensorSource(name=name, torch_dtype=tensor.dtype, shape=tuple(tensor.shape), source=tensor) + + +def _build_writer_with_out_of_model_file(model_local_path, out_of_model_tensor_files=None): + class _Base: + pass + + _Base.out_of_model_tensors = out_of_model_tensor_files or [] + + DummyWriter = ModelWriter(_Base) + instance = DummyWriter() + instance.quantized = True + instance.quantize_config = _DummyQuantizeConfig() + instance.quant_log = [] + instance.load_quantized_model = False + instance.qlinear_kernel = _DummyKernel() + instance.model_local_path = model_local_path + instance.trust_remote_code = False + instance.tokenizer = None + instance.processor = None + instance.turtle_model = SimpleNamespace() + instance.model = _DummyModel() + return instance + + +def _write_real_glm4_moe_config(model_dir): + with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as handle: + json.dump(_REAL_GLM4_MOE_CONFIG, handle) + + +def _write_real_qwen3_5_moe_config(model_dir): + with open(os.path.join(model_dir, "config.json"), "w", encoding="utf-8") as handle: + json.dump(_REAL_QWEN3_5_MOE_CONFIG, handle) + + +def _build_writer_from_real_model_config(model_local_path, expected_out_of_model_tensors, monkeypatch=None): + model_definition = check_and_get_model_definition(model_local_path, trust_remote_code=False) + assert model_definition.out_of_model_tensors == expected_out_of_model_tensors + + if getattr(model_definition, "require_load_processor", False): + assert monkeypatch is not None + monkeypatch.setattr( + "gptqmodel.models.base.AutoProcessor.from_pretrained", + lambda *args, **kwargs: None, + ) + + config = AutoConfig.from_pretrained(model_local_path, trust_remote_code=False) + model = model_definition.loader.from_config(config, trust_remote_code=False) + + instance = model_definition( + model=model, + quantized=True, + quantize_config=_DummyQuantizeConfig(), + tokenizer=None, + qlinear_kernel=_DummyKernel(), + load_quantized_model=False, + trust_remote_code=False, + model_local_path=model_local_path, + turtle_model=SimpleNamespace(), + ) + return instance + + +def _build_writer_from_real_glm4_moe_config(model_local_path, monkeypatch=None): + return _build_writer_from_real_model_config( + model_local_path, + expected_out_of_model_tensors={"files": ["mtp.safetensors"]}, + monkeypatch=monkeypatch, + ) + + +def _build_writer_from_real_qwen3_5_moe_config(model_local_path, monkeypatch): + return _build_writer_from_real_model_config( + model_local_path, + expected_out_of_model_tensors={"prefixes": ["mtp"]}, + monkeypatch=monkeypatch, + ) + + +def _patch_streaming(monkeypatch, shard_count=1): + def _fake_streaming_state_dict_to_shards(state_dict, save_dir, model_base_name, single_file_name, metadata, *_args, **_kwargs): + expected_files = [] + tensor_to_filename = {} + for idx in range(shard_count): + if shard_count == 1: + shard_name = "model.safetensors" + else: + shard_name = f"{model_base_name}-{idx+1:05d}-of-{shard_count:05d}.safetensors" + file_path = os.path.join(save_dir, shard_name) + tensor_data = { + name: ts.source if isinstance(ts, TensorSource) else ts + for name, ts in state_dict.items() + } + save_file(tensor_data, file_path, metadata=metadata) + expected_files.append(shard_name) + for name in state_dict: + tensor_to_filename.setdefault(name, shard_name) + total_size = sum(os.path.getsize(os.path.join(save_dir, fname)) for fname in expected_files) + return expected_files, tensor_to_filename, total_size + + monkeypatch.setattr( + "gptqmodel.models.writer.streaming_state_dict_to_shards", + _fake_streaming_state_dict_to_shards, + ) + + +def _patch_basic_env(monkeypatch, state_dict_tensor): + monkeypatch.setattr("gptqmodel.models.writer.get_model_files_size", lambda _: 1) + monkeypatch.setattr("gptqmodel.models.writer.alias_all_from_turtle_if_meta", lambda *args, **kwargs: None) + monkeypatch.setattr("gptqmodel.models.writer.sanitize_model_config", lambda *_args, **_kwargs: None) + monkeypatch.setattr("gptqmodel.models.writer.sanitize_generation_config_file", lambda *_args, **_kwargs: False) + monkeypatch.setattr( + "gptqmodel.models.writer.get_state_dict_for_save", + lambda *_args, **_kwargs: state_dict_tensor, + ) + + +def test_merge_prefixed_tensors(tmp_path, monkeypatch): + original_dir = tmp_path / "original" + original_dir.mkdir() + + shard_a_name = "model-00001-of-00002.safetensors" + shard_b_name = "model-00002-of-00002.safetensors" + + save_file( + { + "base.weight": torch.zeros(1), + "mtp.fc.weight": torch.ones(2), + }, + str(original_dir / shard_a_name), + ) + save_file( + { + "model.layers.0.weight": torch.full((1,), 2.0), + "mtp.model.layers.0.weight": torch.full((3,), 3.0), + }, + str(original_dir / shard_b_name), + ) + + with open(original_dir / "model.safetensors.index.json", "w", encoding="utf-8") as handle: + json.dump( + { + "metadata": {"total_size": 0}, + "weight_map": { + "mtp.fc.weight": shard_a_name, + "mtp.model.layers.0.weight": shard_b_name, + }, + }, + handle, + ) + + writer = _build_writer_with_out_of_model_file( + str(original_dir), out_of_model_tensor_files=[{"prefixes": ["mtp"]}] + ) + state_dict_data = {"model.weight": _tensor_source("model.weight", torch.zeros(1))} + + _patch_basic_env(monkeypatch, state_dict_data) + _patch_streaming(monkeypatch) + + save_dir = tmp_path / "save" + writer.save_quantized(save_dir=str(save_dir)) + + assert not (save_dir / "mtp.safetensors").exists() + + with safe_open(save_dir / "model.safetensors", framework="pt", device="cpu") as handle: + keys = set(handle.keys()) + assert {"mtp.fc.weight", "mtp.model.layers.0.weight"} <= keys + + +def test_merge_prefixed_tensors_with_multiple_shards(tmp_path, monkeypatch): + original_dir = tmp_path / "original" + original_dir.mkdir() + + for shard_idx in range(2): + shard_name = f"model-{shard_idx+1:05d}-of-00002.safetensors" + save_file( + { + "model.weight": torch.zeros(1), + "mtp.fc.weight": torch.ones(2), + }, + str(original_dir / shard_name), + ) + + with open(original_dir / "model.safetensors.index.json", "w", encoding="utf-8") as handle: + json.dump( + { + "metadata": {"total_size": 0}, + "weight_map": { + "mtp.fc.weight": "model-00001-of-00002.safetensors", + }, + }, + handle, + ) + + writer = _build_writer_with_out_of_model_file( + str(original_dir), out_of_model_tensor_files=[{"prefixes": ["mtp"]}] + ) + state_dict_data = {"model.weight": _tensor_source("model.weight", torch.zeros(1))} + + _patch_basic_env(monkeypatch, state_dict_data) + _patch_streaming(monkeypatch, shard_count=2) + + save_dir = tmp_path / "save" + writer.save_quantized(save_dir=str(save_dir)) + + assert (save_dir / "model-00001-of-00002.safetensors").exists() + assert (save_dir / "model-00002-of-00002.safetensors").exists() + assert (save_dir / "model.safetensors.index.json").exists() + + keys = [] + with safe_open(save_dir / "model-00001-of-00002.safetensors", framework="pt", device="cpu") as handle: + keys += handle.keys() + with safe_open(save_dir / "model-00002-of-00002.safetensors", framework="pt", device="cpu") as handle: + keys += handle.keys() + assert {"mtp.fc.weight"} <= set(keys) + + with open(save_dir / "model.safetensors.index.json", "r", encoding="utf-8") as handle: + index_data = json.load(handle) + assert index_data["weight_map"]["mtp.fc.weight"] == "model-00001-of-00002.safetensors" + + +def test_copy_existing_file(tmp_path, monkeypatch): + original_dir = tmp_path / "original" + original_dir.mkdir() + + mtp_file = original_dir / "mtp.safetensors" + save_file({"mtp.linear.weight": torch.ones(1)}, str(mtp_file)) + + writer = _build_writer_with_out_of_model_file( + str(original_dir), out_of_model_tensor_files=[{"files": ["mtp.safetensors"]}] + ) + state_dict_data = {"model.weight": _tensor_source("model.weight", torch.zeros(1))} + + _patch_basic_env(monkeypatch, state_dict_data) + _patch_streaming(monkeypatch) + + save_dir = tmp_path / "save" + writer.save_quantized(save_dir=str(save_dir)) + + with safe_open(save_dir / "mtp.safetensors", framework="pt", device="cpu") as handle: + mtp_keys = set(handle.keys()) + assert mtp_keys == {"mtp.linear.weight"} + + with safe_open(save_dir / "model.safetensors", framework="pt", device="cpu") as handle: + mtp_keys = set(handle.keys()) + assert mtp_keys == {"model.weight"} + + +def test_copy_existing_file_with_glm4_moe(tmp_path, monkeypatch): + original_dir = tmp_path / "original" + original_dir.mkdir() + _write_real_glm4_moe_config(str(original_dir)) + + mtp_file = original_dir / "mtp.safetensors" + save_file({"mtp.linear.weight": torch.ones(1)}, str(mtp_file)) + + writer = _build_writer_from_real_glm4_moe_config(str(original_dir)) + state_dict_data = {"model.weight": _tensor_source("model.weight", torch.zeros(1))} + + _patch_basic_env(monkeypatch, state_dict_data) + _patch_streaming(monkeypatch) + + save_dir = tmp_path / "save" + writer.save_quantized(save_dir=str(save_dir)) + + with safe_open(save_dir / "mtp.safetensors", framework="pt", device="cpu") as handle: + mtp_keys = set(handle.keys()) + assert mtp_keys == {"mtp.linear.weight"} + + with open(save_dir / "config.json", "r", encoding="utf-8") as handle: + saved_config = json.load(handle) + assert saved_config["model_type"] == "glm4_moe" + + +def test_merge_prefixed_tensors_with_qwen3_5_moe(tmp_path, monkeypatch): + original_dir = tmp_path / "original" + original_dir.mkdir() + _write_real_qwen3_5_moe_config(str(original_dir)) + + shard_a_name = "model-00001-of-00002.safetensors" + shard_b_name = "model-00002-of-00002.safetensors" + + save_file( + { + "base.weight": torch.zeros(1), + "mtp.fc.weight": torch.ones(2), + }, + str(original_dir / shard_a_name), + ) + save_file( + { + "model.layers.0.weight": torch.full((1,), 2.0), + "mtp.model.layers.0.weight": torch.full((3,), 3.0), + }, + str(original_dir / shard_b_name), + ) + + with open(original_dir / "model.safetensors.index.json", "w", encoding="utf-8") as handle: + json.dump( + { + "metadata": {"total_size": 0}, + "weight_map": { + "mtp.fc.weight": shard_a_name, + "mtp.model.layers.0.weight": shard_b_name, + }, + }, + handle, + ) + + writer = _build_writer_from_real_qwen3_5_moe_config(str(original_dir), monkeypatch=monkeypatch) + state_dict_data = {"model.weight": _tensor_source("model.weight", torch.zeros(1))} + + _patch_basic_env(monkeypatch, state_dict_data) + _patch_streaming(monkeypatch) + + save_dir = tmp_path / "save" + writer.save_quantized(save_dir=str(save_dir)) + + assert not (save_dir / "mtp.safetensors").exists() + + with safe_open(save_dir / "model.safetensors", framework="pt", device="cpu") as handle: + keys = set(handle.keys()) + assert {"mtp.fc.weight", "mtp.model.layers.0.weight"} <= keys + + with open(save_dir / "config.json", "r", encoding="utf-8") as handle: + saved_config = json.load(handle) + assert saved_config["model_type"] == "qwen3_5_moe" diff --git a/tests/test_ovis_generate_wrapper.py b/tests/test_ovis_generate_wrapper.py new file mode 100644 index 000000000..d1b87be3f --- /dev/null +++ b/tests/test_ovis_generate_wrapper.py @@ -0,0 +1,44 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch + +from gptqmodel.models.definitions.ovis import OvisQModel + + +class _DummyInnerModel: + def __init__(self): + self.calls = [] + + def generate(self, inputs, **kwargs): + self.calls.append((inputs, kwargs)) + return "ok" + + +def test_ovis_qmodel_generate_accepts_input_ids_keyword(): + qmodel = OvisQModel.__new__(OvisQModel) + qmodel.model = _DummyInnerModel() + qmodel.device = torch.device("cpu") + + input_ids = torch.tensor([[1, 2, 3]]) + attention_mask = torch.tensor([[1, 1, 1]]) + pixel_values = [torch.zeros(1, 3, 4, 4)] + + output = qmodel.generate( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + max_new_tokens=8, + ) + + assert output == "ok" + assert len(qmodel.model.calls) == 1 + + forwarded_inputs, forwarded_kwargs = qmodel.model.calls[0] + assert forwarded_inputs is input_ids + assert "input_ids" not in forwarded_kwargs + assert forwarded_kwargs["attention_mask"] is attention_mask + assert forwarded_kwargs["pixel_values"] is pixel_values + assert forwarded_kwargs["max_new_tokens"] == 8 diff --git a/tests/test_pack.py b/tests/test_pack.py index 057492ff8..197ae73e9 100644 --- a/tests/test_pack.py +++ b/tests/test_pack.py @@ -12,7 +12,7 @@ from tabulate import tabulate from gptqmodel import BACKEND -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear class TestPackAccuracy(unittest.TestCase): @@ -45,7 +45,7 @@ def _build_inputs(bits: int, group_size: int): return linear, scales, zeros, g_idx def _quant_linear(self): - qlinear = TorchQuantLinear( + qlinear = TorchLinear( bits=self.current_bits, group_size=self.current_group_size, sym=True, diff --git a/tests/test_pack_gpu_alignment.py b/tests/test_pack_gpu_alignment.py index 314f1a95a..2fb4b52b7 100644 --- a/tests/test_pack_gpu_alignment.py +++ b/tests/test_pack_gpu_alignment.py @@ -6,7 +6,7 @@ import pytest import torch -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear from gptqmodel.quantization.config import QuantizeConfig from gptqmodel.utils.backend import BACKEND from gptqmodel.utils.model import pack_module @@ -26,7 +26,7 @@ def test_pack_gpu_raises_on_misaligned_qzeros(): linear = torch.nn.Linear(in_features, out_features, bias=False) - quant_module = TorchQuantLinear( + quant_module = TorchLinear( bits=4, group_size=in_features, sym=True, @@ -70,14 +70,14 @@ def test_pack_gpu_raises_on_misaligned_qzeros(): q_zeros=q_zeros, q_g_idx=q_g_idx, layers=layers, - quant_linear_cls=TorchQuantLinear, + quant_linear_cls=TorchLinear, lock=lock, quantize_config=quant_config, ) packed_module = qModules[layer_name] - assert isinstance(packed_module, TorchQuantLinear) + assert isinstance(packed_module, TorchLinear) assert packed_module.qweight.shape == torch.Size([16, 16]) assert packed_module.qzeros.shape == torch.Size([16, 0]) assert packed_module.scales.shape == torch.Size([16, 1]) diff --git a/tests/test_packable.py b/tests/test_packable.py index 6061e357e..87b54e2c7 100644 --- a/tests/test_packable.py +++ b/tests/test_packable.py @@ -13,13 +13,11 @@ from safetensors.torch import load_file from gptqmodel import BACKEND, GPTQModel -from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear -from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 -from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedQuantLinear -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2Linear # noqa: E402 +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.torch import TorchLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedLinear +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear # noqa: E402 from gptqmodel.utils.model import convert_gptq_v2_to_v1_format, find_modules @@ -27,15 +25,12 @@ class TestPackable(unittest.TestCase): QLINEAR_DICT = { - BACKEND.EXLLAMA_EORA: ExllamaEoraQuantLinear, - BACKEND.EXLLAMA_V1: ExllamaQuantLinear, - BACKEND.EXLLAMA_V2: ExllamaV2QuantLinear, - BACKEND.TRITON: TritonV2QuantLinear, - BACKEND.TORCH: TorchQuantLinear, - BACKEND.TORCH_FUSED: TorchFusedQuantLinear, - BACKEND.MARLIN: MarlinQuantLinear, - BACKEND.MARLIN_FP16: MarlinQuantLinear, - # BACKEND.BITBLAS: BitBLASQuantLinear, + BACKEND.EXLLAMA_V2: ExllamaV2Linear, + BACKEND.TRITON: TritonV2Linear, + BACKEND.TORCH: TorchLinear, + BACKEND.TORCH_FUSED: TorchFusedLinear, + BACKEND.MARLIN: MarlinLinear, + # BACKEND.BITBLAS: BitBLASLinear, } model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" @@ -53,15 +48,12 @@ def setUpClass(cls): @parameterized.expand( [ - (BACKEND.EXLLAMA_EORA, {"qweight": False, "qzeros": True, "scales": True, "g_idx": False}), - (BACKEND.EXLLAMA_V1, {"qweight": True, "qzeros": True, "scales": True, "g_idx": True}), (BACKEND.EXLLAMA_V2, {"qweight": False, "qzeros": True, "scales": True, "g_idx": True}), (BACKEND.TRITON, {"qweight": True, "qzeros": True, "scales": True, "g_idx": True}), (BACKEND.TORCH, {"qweight": True, "qzeros": True, "scales": True, "g_idx": True}), (BACKEND.TORCH_FUSED, {"qweight": True, "qzeros": True, "scales": False, "g_idx": True}), # (BACKEND.BITBLAS, {"qweight": True, "qzeros": True, "scales": True, "g_idx": True}), (BACKEND.MARLIN, {"qweight": False, "qzeros": False, "scales": False, "g_idx": False}), - (BACKEND.MARLIN_FP16, {"qweight": False, "qzeros": False, "scales": False, "g_idx": False}), ] ) def test_post_init(self, backend: BACKEND, equal: Dict[str, bool]): diff --git a/tests/test_packing.py b/tests/test_packing.py index 44c741bb7..74c5c86c2 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -20,7 +20,7 @@ from parameterized import parameterized # noqa: E402 from gptqmodel import BACKEND # noqa: E402 -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.torch import TorchLinear # noqa: E402 def gen_quant(k: int, n: int, groupsize: int, bits: int): @@ -68,9 +68,9 @@ def _reshape(x): class TestRepacking(unittest.TestCase): QLINEAR_DICT = { - # BACKEND.TRITON: TritonV2QuantLinear, - BACKEND.TORCH: TorchQuantLinear, - # BACKEND.TORCH_FUSED: TorchFusedQuantLinear, + # BACKEND.TRITON: TritonV2Linear, + BACKEND.TORCH: TorchLinear, + # BACKEND.TORCH_FUSED: TorchFusedLinear, } # Dimensions (match your original scale) @@ -127,7 +127,7 @@ def test_packing_variants(self, bits: int, group_size: int, backend): # Torch reference packer (compare against Torch CPU packer) try: - torch_linear = self._pack_one(TorchQuantLinear, BACKEND.TORCH, bits, group_size, linear, s, zeros, g_idx) + torch_linear = self._pack_one(TorchLinear, BACKEND.TORCH, bits, group_size, linear, s, zeros, g_idx) torch_linear.post_init() except (NotImplementedError, ValueError) as e: self.skipTest(f"Torch backend does not support bits={bits}, group_size={group_size}: {e}") diff --git a/tests/test_packing_speed.py b/tests/test_packing_speed.py index aa4d9ebf8..0bc831d21 100644 --- a/tests/test_packing_speed.py +++ b/tests/test_packing_speed.py @@ -25,7 +25,12 @@ import torch.nn as nn # noqa: E402 # isort: on -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 +import pytest # noqa: E402 + +from gptqmodel.nn_modules.qlinear.torch import TorchLinear # noqa: E402 + + +pytestmark = [pytest.mark.cpu, pytest.mark.gpu] def gen_quant4(k: int, n: int, groupsize: int): @@ -145,19 +150,28 @@ def pack(self, qlinearCls, backend, impl: str = "cpu"): return qlinear def _time_pack_impl(self, qlinearCls, backend, impl: str, repeats: int, threads: int = 1) -> float: - start = time.time() + impl_lower = impl.lower() + warmup_repeats = 2 if impl_lower == "gpu" and torch.cuda.is_available() else 1 + with threadpoolctl.threadpool_limits(limits=threads): + for _ in range(warmup_repeats): + self.pack(qlinearCls, backend, impl=impl) + + if impl_lower == "gpu" and torch.cuda.is_available(): + torch.cuda.synchronize() + + start = time.perf_counter() for _ in range(repeats): self.pack(qlinearCls, backend, impl=impl) - if impl.lower() == "gpu" and torch.cuda.is_available(): - torch.cuda.synchronize() - return time.time() - start + if impl_lower == "gpu" and torch.cuda.is_available(): + torch.cuda.synchronize() + + return time.perf_counter() - start @parameterized.expand( [ - # [ExllamaQuantLinear, BACKEND.EXLLAMA, 9.63], - # [TritonV2QuantLinear, BACKEND.TRITON, 9.67], - [TorchQuantLinear, BACKEND.TORCH, 21.05], # A100 Z3 33.56 # 4090? 27.0297 + # [TritonV2Linear, BACKEND.TRITON, 9.67], + [TorchLinear, BACKEND.TORCH, 21.05], # A100 Z3 33.56 # 4090? 27.0297 ] ) def test_pack_speed_single_thread(self, qlinearCls, backend, expect_time): @@ -174,9 +188,8 @@ def test_pack_speed_single_thread(self, qlinearCls, backend, expect_time): @parameterized.expand( [ - # [ExllamaQuantLinear, BACKEND.EXLLAMA, 9.63], - # [TritonV2QuantLinear, BACKEND.TRITON, 9.67], - [TorchQuantLinear, BACKEND.TORCH, 14.71], # A100 Z3 33.56 # 4090? 27.0297 + # [TritonV2Linear, BACKEND.TRITON, 9.67], + [TorchLinear, BACKEND.TORCH, 14.71], # A100 Z3 33.56 # 4090? 27.0297 ] ) def test_pack_speed_two_threads(self, qlinearCls, backend, expect_time): @@ -192,7 +205,7 @@ def test_pack_speed_two_threads(self, qlinearCls, backend, expect_time): self.assertLess((time_usage - expect_time) / expect_time, 0.05, msg=f"time: {time_usage:.4f}s") def test_pack_block_thread_scaling(self): - qlinearCls = TorchQuantLinear + qlinearCls = TorchLinear backend = BACKEND.TORCH repeats = 10 thread_options = [1, 2, 4] @@ -218,7 +231,7 @@ def test_pack_block_thread_scaling(self): self.assertLessEqual(best_time, reference_time, "Multi-threaded pack_block did not improve over single-thread baseline") def test_pack_block_extension_speedup(self): - qlinearCls = TorchQuantLinear + qlinearCls = TorchLinear backend = BACKEND.TORCH repeats = 5 @@ -266,7 +279,7 @@ def test_pack_block_extension_speedup(self): @unittest.skipUnless(torch.cuda.is_available(), "CUDA device required for GPU packing speed test") def test_pack_speed_gpu_vs_cpu(self): - qlinearCls = TorchQuantLinear + qlinearCls = TorchLinear backend = BACKEND.TORCH repeats = 10 @@ -287,7 +300,7 @@ def test_pack_speed_gpu_vs_cpu(self): @unittest.skipUnless(torch.cuda.is_available(), "CUDA device required for GPU packing summary test") def test_pack_speed_summary(self): - qlinearCls = TorchQuantLinear + qlinearCls = TorchLinear backend = BACKEND.TORCH repeats = 5 diff --git a/tests/test_parameter_count.py b/tests/test_parameter_count.py index 0a49283fa..1c57e0183 100644 --- a/tests/test_parameter_count.py +++ b/tests/test_parameter_count.py @@ -26,10 +26,10 @@ class TestsParameterCount(ModelTest): def test_parameter_count(self): import os.path - from huggingface_hub import hf_hub_download from safetensors.torch import load_file from gptqmodel import QuantizeConfig + from gptqmodel.utils.hub import hf_hub_download from gptqmodel.utils.tensor import tensor_parameters model_id = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" diff --git a/tests/test_paroquant.py b/tests/test_paroquant.py new file mode 100644 index 000000000..3582fc5f8 --- /dev/null +++ b/tests/test_paroquant.py @@ -0,0 +1,4831 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-FileCopyrightText: 2026 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# ParoQuant test coverage adapted from the ParoQuant paper and public project: +# https://arxiv.org/html/2511.10645v2 +# https://github.com/z-lab/paroquant + +"""Unit tests for ParoQuant config, optimizer, and lifecycle invariants.""" + +import copy +import inspect +import sys +import threading +import time +from contextlib import contextmanager +from types import SimpleNamespace + +import pytest +import torch +import torch.nn.functional as F +from transformers.quantizers.auto import AutoQuantizationConfig +from transformers.utils.quantization_config import GPTQConfig + +import gptqmodel.looper.paroquant_processor as paroquant_processor_module +import gptqmodel.utils.paroquant as paroquant_utils_module +from gptqmodel.looper.awq_processor import AWQProcessor +from gptqmodel.looper.input_cache import InputCache +from gptqmodel.looper.module_looper import _restrict_quant_devices_for_method +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.looper.paroquant_processor import ParoQuantProcessor +from gptqmodel.looper.stage_layer import _capture_pristine_group_context +from gptqmodel.nn_modules.hooked_linear import replace_module_with_hooked_legacy +from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear +from gptqmodel.quantization.config import FORMAT, METHOD, ParoConfig, QuantizeConfig +from gptqmodel.quantization.paroquant import optimization as paroquant_optimization +from gptqmodel.quantization.paroquant.optimization import ( + GroupLinearQuantizer, + _apply_rotation, + _ParoQuantOptimLinear, + build_random_rotation_buffers, + optimize_paroquant_linear, + pseudo_quantize_dequant, +) +from gptqmodel.utils.backend import BACKEND +from gptqmodel.utils.importer import get_kernel_for_backend +from gptqmodel.utils.paroquant import ( + _rotation_launch_config, + apply_paroquant_rotation, + apply_paroquant_rotation_reference, + build_identity_rotation_buffers, + clear_paroquant_rotation_autotune_cache, + clear_paroquant_rotation_extension_cache, + prewarm_paroquant_rotation_extension, +) +from gptqmodel.utils.paroquant_benchmark import make_paroquant_config + + +def test_paroquant_quantize_config_dispatches_constructor(): + """Guard that ParoQuant config fields survive direct construction.""" + cfg = QuantizeConfig( + quant_method=METHOD.PARO, + format=FORMAT.PAROQUANT, + bits=4, + group_size=128, + krot=8, + ) + + assert cfg.quant_method == METHOD.PARO + assert cfg.format == FORMAT.PAROQUANT + assert cfg.krot == 8 + assert cfg.opt_batch_size == 64 + assert cfg.opt_optimizer == "adamw" + assert cfg.opt_weight_decay == pytest.approx(0.01) + assert cfg.opt_betas == pytest.approx((0.9, 0.95)) + assert cfg.opt_eps == pytest.approx(1e-10) + assert cfg.opt_amsgrad is False + assert cfg.opt_sgd_momentum == pytest.approx(0.0) + assert cfg.opt_sgd_dampening == pytest.approx(0.0) + assert cfg.opt_sgd_nesterov is False + assert cfg.opt_scope == "module" + assert cfg.opt_stage_impl == "fast" + assert cfg.opt_pair_impl == "fast" + assert cfg.opt_quantizer_impl == "reference" + assert cfg.opt_stage_cudagraph is True + assert cfg.opt_gradient_checkpointing is False + assert cfg.opt_best_state_dtype == "fp32" + assert cfg.opt_train_on_noisy_inputs is False + assert cfg.opt_channel_scale_clamp_min == 1e-2 + assert cfg.opt_channel_scale_clamp_max == 1e2 + assert cfg.export_quant_method() == METHOD.PARO + + +def test_paroquant_quantize_config_enables_gradient_checkpointing_by_default_for_layer_scope(): + """Layer scope should opt into activation checkpointing by default because it is the only measured memory win.""" + + cfg = ParoConfig( + bits=4, + group_size=128, + opt_scope="layer", + ) + + assert cfg.opt_gradient_checkpointing is True + + +def test_paroquant_quantize_config_from_external_payload_round_trips(): + """Guard import/export of ParoQuant metadata from serialized payloads.""" + cfg = QuantizeConfig.from_quant_config( + { + "quant_method": "paroquant", + "bits": 4, + "group_size": 128, + "krot": 8, + "meta": { + "opt_rotation_epochs": 10, + "opt_finetune_epochs": 10, + "opt_train_samples": 2048, + "opt_validation_samples": 64, + "opt_batch_size": 16, + "opt_rotation_lr": 0.05, + "opt_weight_lr": 1e-5, + "opt_quantizer_lr": 1e-6, + "opt_pair_ratio": 0.5, + "opt_seed": 0, + "opt_optimizer": "sgd", + "opt_weight_decay": 0.02, + "opt_betas": [0.8, 0.9], + "opt_eps": 1e-8, + "opt_amsgrad": True, + "opt_sgd_momentum": 0.85, + "opt_sgd_dampening": 0.0, + "opt_sgd_nesterov": True, + "opt_fused_rotation": False, + "opt_gradient_checkpointing": False, + "opt_stage_cudagraph": False, + "opt_best_state_dtype": "fp16", + "opt_train_on_noisy_inputs": True, + "opt_scope": "compute_block", + "opt_stage_impl": "reference", + "opt_pair_impl": "fast", + "opt_quantizer_impl": "reference", + "opt_channel_scale_clamp_min": 0.02, + "opt_channel_scale_clamp_max": 50.0, + }, + } + ) + + assert isinstance(cfg, ParoConfig) + assert cfg.quant_method == METHOD.PARO + assert cfg.format == FORMAT.PAROQUANT + assert cfg.krot == 8 + assert cfg.opt_rotation_epochs == 10 + assert cfg.opt_finetune_epochs == 10 + assert cfg.opt_train_samples == 2048 + assert cfg.opt_validation_samples == 64 + assert cfg.opt_batch_size == 16 + assert cfg.opt_rotation_lr == 0.05 + assert cfg.opt_weight_lr == 1e-5 + assert cfg.opt_quantizer_lr == 1e-6 + assert cfg.opt_pair_ratio == 0.5 + assert cfg.opt_seed == 0 + assert cfg.opt_optimizer == "sgd" + assert cfg.opt_weight_decay == pytest.approx(0.02) + assert cfg.opt_betas == pytest.approx((0.8, 0.9)) + assert cfg.opt_eps == pytest.approx(1e-8) + assert cfg.opt_amsgrad is True + assert cfg.opt_sgd_momentum == pytest.approx(0.85) + assert cfg.opt_sgd_dampening == pytest.approx(0.0) + assert cfg.opt_sgd_nesterov is True + assert cfg.opt_fused_rotation is False + assert cfg.opt_gradient_checkpointing is False + assert cfg.opt_stage_cudagraph is False + assert cfg.opt_best_state_dtype == "fp16" + assert cfg.opt_train_on_noisy_inputs is True + assert cfg.opt_scope == "compute_block" + assert cfg.opt_stage_impl == "reference" + assert cfg.opt_pair_impl == "fast" + assert cfg.opt_quantizer_impl == "reference" + assert cfg.opt_channel_scale_clamp_min == 0.02 + assert cfg.opt_channel_scale_clamp_max == 50.0 + assert cfg.to_dict()["meta"]["opt_fused_rotation"] is False + assert cfg.to_dict()["meta"]["opt_gradient_checkpointing"] is False + assert cfg.to_dict()["meta"]["opt_stage_cudagraph"] is False + assert cfg.to_dict()["meta"]["opt_best_state_dtype"] == "fp16" + assert cfg.to_dict()["meta"]["opt_train_on_noisy_inputs"] is True + assert cfg.to_dict()["meta"]["opt_scope"] == "compute_block" + assert cfg.to_dict()["meta"]["opt_stage_impl"] == "reference" + assert cfg.to_dict()["meta"]["opt_pair_impl"] == "fast" + assert cfg.to_dict()["meta"]["opt_quantizer_impl"] == "reference" + assert cfg.to_dict()["meta"]["opt_channel_scale_clamp_min"] == 0.02 + assert cfg.to_dict()["meta"]["opt_channel_scale_clamp_max"] == 50.0 + assert cfg.to_dict()["meta"]["opt_optimizer"] == "sgd" + assert cfg.to_dict()["meta"]["opt_weight_decay"] == pytest.approx(0.02) + assert cfg.to_dict()["meta"]["opt_betas"] == [0.8, 0.9] + assert cfg.to_dict()["meta"]["opt_eps"] == pytest.approx(1e-8) + assert cfg.to_dict()["meta"]["opt_amsgrad"] is True + assert cfg.to_dict()["meta"]["opt_sgd_momentum"] == pytest.approx(0.85) + assert cfg.to_dict()["meta"]["opt_sgd_dampening"] == pytest.approx(0.0) + assert cfg.to_dict()["meta"]["opt_sgd_nesterov"] is True + + +def test_paroquant_quantize_config_rejects_invalid_scale_clamp_range(): + """Guard that ParoQuant scale-clamp overrides remain numerically valid.""" + with pytest.raises(ValueError, match="scale clamp bounds must be positive"): + ParoConfig( + bits=4, + group_size=128, + opt_channel_scale_clamp_min=0.0, + opt_channel_scale_clamp_max=10.0, + ) + + with pytest.raises(ValueError, match="opt_channel_scale_clamp_min"): + ParoConfig( + bits=4, + group_size=128, + opt_channel_scale_clamp_min=10.0, + opt_channel_scale_clamp_max=10.0, + ) + + +def test_paroquant_quantize_config_rejects_invalid_opt_scope(): + """Guard that ParoQuant optimize-scope selection stays within supported modes.""" + with pytest.raises(ValueError, match="opt_scope"): + ParoConfig( + bits=4, + group_size=128, + opt_scope="block", + ) + + +def test_paroquant_quantize_config_rejects_invalid_opt_optimizer(): + """Guard that ParoQuant optimizer selection stays within supported modes.""" + with pytest.raises(ValueError, match="opt_optimizer"): + ParoConfig( + bits=4, + group_size=128, + opt_optimizer="lion", + ) + + +def test_paroquant_quantize_config_rejects_invalid_best_state_dtype(): + """Guard best-state snapshot compression against unsupported dtype strings.""" + with pytest.raises(ValueError, match="opt_best_state_dtype"): + ParoConfig( + bits=4, + group_size=128, + opt_best_state_dtype="int8", + ) + + +def test_paroquant_quantize_config_rejects_invalid_optimizer_hyperparameters(): + """Guard optimizer hyperparameter validation against invalid stage settings.""" + with pytest.raises(ValueError, match="opt_betas"): + ParoConfig( + bits=4, + group_size=128, + opt_betas=(0.9,), + ) + + with pytest.raises(ValueError, match="opt_eps"): + ParoConfig( + bits=4, + group_size=128, + opt_eps=0.0, + ) + + with pytest.raises(ValueError, match="opt_sgd_nesterov"): + ParoConfig( + bits=4, + group_size=128, + opt_sgd_nesterov=True, + ) + + with pytest.raises(ValueError, match="opt_sgd_dampening"): + ParoConfig( + bits=4, + group_size=128, + opt_sgd_momentum=0.9, + opt_sgd_dampening=0.1, + opt_sgd_nesterov=True, + ) + + +def test_paroquant_benchmark_config_preserves_opt_scope(): + """Benchmark helpers should propagate the requested optimization scope.""" + cfg = make_paroquant_config(dynamic={}, opt_scope="compute_block") + + assert cfg.quant_method == METHOD.PARO + assert cfg.opt_scope == "compute_block" + assert cfg.opt_gradient_checkpointing is False + + +def test_paroquant_quantize_config_preserves_explicit_gradient_checkpointing_override(): + """Explicit checkpointing overrides must win over the scope-derived default.""" + + layer_cfg = ParoConfig( + bits=4, + group_size=128, + opt_scope="layer", + opt_gradient_checkpointing=False, + ) + compute_block_cfg = ParoConfig( + bits=4, + group_size=128, + opt_scope="compute_block", + opt_gradient_checkpointing=True, + ) + + assert layer_cfg.opt_gradient_checkpointing is False + assert compute_block_cfg.opt_gradient_checkpointing is True + + +def test_paroquant_rotation_toggle_prefers_explicit_config_over_env(monkeypatch): + """Guard that the config-backed fused toggle overrides the legacy env fallback.""" + x = torch.randn(4, 8, dtype=torch.float32) + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=8, + group_size=8, + krot=1, + dtype=torch.float32, + ) + theta = theta.clone().requires_grad_(True) + + calls = [] + + def fake_fused_rotation(x, pairs, theta, *, scales, group_size): + del pairs, theta, scales, group_size + calls.append("fused") + return x + 123.0 + + monkeypatch.setattr(paroquant_optimization, "apply_paroquant_rotation_autograd", fake_fused_rotation) + monkeypatch.setenv("GPTQMODEL_PAROQUANT_OPT_FUSED_ROTATION", "1") + + reference_out = _apply_rotation( + x, + pairs, + theta, + scales=channel_scales, + group_size=8, + fused_rotation=False, + ) + assert calls == [] + torch.testing.assert_close(reference_out, x, atol=0, rtol=0) + + fused_out = _apply_rotation( + x, + pairs, + theta, + scales=channel_scales, + group_size=8, + fused_rotation=True, + ) + assert calls == ["fused"] + torch.testing.assert_close(fused_out, x + 123.0) + + +def test_paroquant_fused_rotation_uses_forward_only_path_without_grad_inputs(monkeypatch): + """Guard that inactive rotation grads do not route through the autograd wrapper.""" + x = torch.randn(4, 8, dtype=torch.float32) + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=8, + group_size=8, + krot=1, + dtype=torch.float32, + ) + + calls = [] + + def fake_forward_only(x, pairs, theta, *, scales, group_size): + del pairs, theta, scales, group_size + calls.append("forward") + return x + 1.0 + + def fake_autograd(x, pairs, theta, *, scales, group_size): + del pairs, theta, scales, group_size + calls.append("autograd") + return x + 2.0 + + monkeypatch.setattr(paroquant_optimization, "apply_paroquant_rotation", fake_forward_only) + monkeypatch.setattr(paroquant_optimization, "apply_paroquant_rotation_autograd", fake_autograd) + + out = _apply_rotation( + x, + pairs, + theta, + scales=channel_scales, + group_size=8, + fused_rotation=True, + ) + torch.testing.assert_close(out, x + 1.0) + assert calls == ["forward"] + + calls.clear() + theta = theta.clone().requires_grad_(True) + out = _apply_rotation( + x, + pairs, + theta, + scales=channel_scales, + group_size=8, + fused_rotation=True, + ) + torch.testing.assert_close(out, x + 2.0) + assert calls == ["autograd"] + + +@pytest.mark.parametrize( + ("device_type", "expected_use_amp"), + [ + ("cpu", False), + ("cuda", True), + ], +) +def test_paroquant_fast_stage_matches_reference_amp_eval_flag(monkeypatch, device_type, expected_use_amp): + """Guard that the fast stage uses CUDA AMP for eval bookkeeping like reference.""" + calls = [] + + class _DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(1)) + + class _FakeTensor: + def __init__(self, kind: str): + self.device = SimpleNamespace(type=kind) + + def fake_evaluate(_model, _inputs, _targets, *, use_amp=False): + calls.append(use_amp) + return 0.0 + + monkeypatch.setattr(paroquant_optimization, "_evaluate_model", fake_evaluate) + + train_loss, val_loss = paroquant_optimization._run_stage_gptqmodel( + model=_DummyModel(), + inputs_train=_FakeTensor(device_type), + targets_train=_FakeTensor(device_type), + inputs_val=_FakeTensor(device_type), + targets_val=_FakeTensor(device_type), + param_groups=[], + epochs=0, + batch_size=16, + ) + + assert train_loss == 0.0 + assert val_loss == 0.0 + assert calls == [expected_use_amp, expected_use_amp] + + +def test_paroquant_evaluate_model_keeps_loss_inside_cuda_autocast(monkeypatch): + """Guard the PR-18 fix so validation loss stays inside the CUDA autocast region.""" + state = {"autocast_active": False, "loss_saw_autocast": None} + + @contextmanager + def fake_autocast(device_type: str): + assert device_type == "cuda" + previous = state["autocast_active"] + state["autocast_active"] = True + try: + yield + finally: + state["autocast_active"] = previous + + class _FakeInput: + def __init__(self): + self.device = SimpleNamespace(type="cuda") + + def numel(self): + return 1 + + class _DummyModel(torch.nn.Module): + def forward(self, _inputs): + return torch.tensor([1.0], dtype=torch.float32) + + def fake_loss(preds, targets): + del preds, targets + state["loss_saw_autocast"] = state["autocast_active"] + return torch.tensor(0.25, dtype=torch.float32) + + monkeypatch.setattr(torch.amp, "autocast", fake_autocast) + monkeypatch.setattr(paroquant_optimization.F, "smooth_l1_loss", fake_loss) + + loss = paroquant_optimization._evaluate_model( + _DummyModel(), + _FakeInput(), + torch.tensor([0.0], dtype=torch.float32), + use_amp=True, + ) + + assert loss == 0.25 + assert state["loss_saw_autocast"] is True + + +def test_paroquant_fast_stage_uses_cuda_amp_training(monkeypatch): + """Guard that the fast stage now mirrors upstream AMP training on CUDA.""" + state = {"autocast_active": False, "loss_saw_autocast": []} + scaler_events = [] + + @contextmanager + def fake_autocast(device_type: str): + assert device_type == "cuda" + previous = state["autocast_active"] + state["autocast_active"] = True + try: + yield + finally: + state["autocast_active"] = previous + + class _FakeScaler: + def __init__(self, *, enabled: bool): + scaler_events.append(("init", enabled)) + self.enabled = enabled + + def scale(self, loss): + scaler_events.append(("scale", self.enabled)) + + class _ScaledLoss: + def __init__(self, wrapped_loss): + self.wrapped_loss = wrapped_loss + + def backward(self): + scaler_events.append(("backward", self.wrapped_loss.detach().item())) + self.wrapped_loss.backward() + + return _ScaledLoss(loss) + + def step(self, optimizer): + scaler_events.append(("step", self.enabled)) + optimizer.step() + + def update(self): + scaler_events.append(("update", self.enabled)) + + class _TinyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor([[1.0]], dtype=torch.float32)) + + def forward(self, x): + return x @ self.weight + + def reset_masked_angles(self): + return None + + class _FakeRows: + def __init__(self, rows: int): + self.device = SimpleNamespace(type="cuda") + self.shape = (rows, 1) + + train_inputs = _FakeRows(rows=2) + train_targets = _FakeRows(rows=2) + val_inputs = _FakeRows(rows=1) + val_targets = _FakeRows(rows=1) + + train_input_batches = [torch.tensor([[1.0]], dtype=torch.float32), torch.tensor([[2.0]], dtype=torch.float32)] + train_target_batches = [torch.tensor([[0.0]], dtype=torch.float32), torch.tensor([[0.0]], dtype=torch.float32)] + + def fake_chunk_rows(rows, batch_size): + del batch_size + if rows is train_inputs: + return train_input_batches + if rows is train_targets: + return train_target_batches + raise AssertionError("unexpected rows object") + + def fake_evaluate(_model, _inputs, _targets, *, use_amp=False): + assert use_amp is True + return 0.0 + + original_loss = paroquant_optimization.F.smooth_l1_loss + + def wrapped_loss(preds, targets): + state["loss_saw_autocast"].append(state["autocast_active"]) + return original_loss(preds, targets) + + monkeypatch.setattr(paroquant_optimization, "_chunk_rows", fake_chunk_rows) + monkeypatch.setattr(paroquant_optimization, "_evaluate_model", fake_evaluate) + monkeypatch.setattr(paroquant_optimization.F, "smooth_l1_loss", wrapped_loss) + monkeypatch.setattr(torch.amp, "autocast", fake_autocast) + monkeypatch.setattr(torch.amp, "GradScaler", _FakeScaler) + + model = _TinyModel() + train_loss, val_loss = paroquant_optimization._run_stage_gptqmodel( + model=model, + inputs_train=train_inputs, + targets_train=train_targets, + inputs_val=val_inputs, + targets_val=val_targets, + param_groups=[{"params": [model.weight], "lr": 0.1}], + epochs=1, + batch_size=1, + ) + + assert train_loss >= 0.0 + assert val_loss == 0.0 + assert state["loss_saw_autocast"] == [True, True] + assert scaler_events[0] == ("init", True) + assert [event[0] for event in scaler_events[1:]] == [ + "scale", + "backward", + "step", + "update", + "scale", + "backward", + "step", + "update", + ] + assert all(event[1] is True for event in scaler_events if event[0] != "backward") + assert all(event[1] >= 0.0 for event in scaler_events if event[0] == "backward") + + +def test_paroquant_fast_pair_builder_emits_disjoint_matchings(): + """Guard the fast pair builder so each rotation remains kernel-legal.""" + pairs, masks = build_random_rotation_buffers( + in_features=8, + group_size=8, + krot=3, + pair_ratio=0.5, + seed=0, + device=torch.device("cpu"), + ) + + assert pairs.shape == (3, 8) + assert masks.shape == (3, 4) + assert torch.count_nonzero(masks).item() == 0 + + seen_edges = set() + for rotation_pairs in pairs.view(3, 4, 2).tolist(): + used_channels = set() + for left, right in rotation_pairs: + assert left != right + assert left not in used_channels + assert right not in used_channels + used_channels.add(left) + used_channels.add(right) + edge = tuple(sorted((left, right))) + assert edge not in seen_edges + seen_edges.add(edge) + + +def test_paroquant_optim_forward_matches_pseudo_weight_contract(): + """Guard the stage-time forward rewrite against the original pseudo-weight contract.""" + torch.manual_seed(0) + weight = torch.randn((16, 8), dtype=torch.float32) + inputs = torch.randn((5, 8), dtype=torch.float32) + pairs, theta_mask = build_random_rotation_buffers( + in_features=8, + group_size=8, + krot=3, + pair_ratio=0.5, + seed=0, + device=torch.device("cpu"), + ) + model = _ParoQuantOptimLinear( + weight, + None, + bits=4, + group_size=8, + quantizer_sym=True, + pairs=pairs, + theta_mask=theta_mask, + fused_rotation=False, + ) + with torch.no_grad(): + model.theta.uniform_(-0.2, 0.2) + model.channel_scales_opt.uniform_(0.8, 1.2) + model.init_quantizer() + + expected = F.linear(inputs, model.pseudo_weight(), model.bias) + actual = model(inputs) + + torch.testing.assert_close(actual, expected, atol=2e-6, rtol=1e-5) + + +def test_paroquant_optim_forward_matches_pseudo_weight_contract_for_rank3(): + """Guard grouped/layer optimization forwards that pass [batch, seq, hidden] activations.""" + torch.manual_seed(0) + weight = torch.randn((16, 8), dtype=torch.float32) + inputs = torch.randn((2, 3, 8), dtype=torch.float32) + pairs, theta_mask = build_random_rotation_buffers( + in_features=8, + group_size=8, + krot=3, + pair_ratio=0.5, + seed=0, + device=torch.device("cpu"), + ) + model = _ParoQuantOptimLinear( + weight, + None, + bits=4, + group_size=8, + quantizer_sym=True, + pairs=pairs, + theta_mask=theta_mask, + fused_rotation=False, + ) + with torch.no_grad(): + model.theta.uniform_(-0.2, 0.2) + model.channel_scales_opt.uniform_(0.8, 1.2) + model.init_quantizer() + + expected = F.linear(inputs, model.pseudo_weight(), model.bias) + actual = model(inputs) + + torch.testing.assert_close(actual, expected, atol=2e-6, rtol=1e-5) + + +def test_paroquant_materialized_sym_scale_ste_matches_legacy_gradients(): + """Guard the stage2 symmetric quantizer rewrite against the legacy STE math.""" + torch.manual_seed(0) + group_size = 8 + bits = 4 + qmin = -(2 ** (bits - 1)) + qmax = 2 ** (bits - 1) - 1 + + weight = torch.randn((6, group_size), dtype=torch.float32, requires_grad=True) + scale = (torch.rand((6, 1), dtype=torch.float32) + 0.1).requires_grad_() + + legacy_scale = scale.clone().detach().requires_grad_(True) + legacy_weight = weight.clone().detach().requires_grad_(True) + legacy_scale_safe = paroquant_optimization._clamp_ste(legacy_scale, min_value=1e-5, max_value=1e5) + legacy_quant = paroquant_optimization._clamp_ste( + paroquant_optimization._round_ste(legacy_weight / legacy_scale_safe), + qmin, + qmax, + ) + legacy_output = (legacy_quant * legacy_scale_safe).reshape_as(legacy_weight) + + actual_output = pseudo_quantize_dequant( + weight, + bits=bits, + group_size=group_size, + sym=True, + scale=scale, + use_ste=True, + ) + + torch.testing.assert_close(actual_output, legacy_output, atol=0, rtol=0) + + legacy_output.sum().backward() + actual_output.sum().backward() + + torch.testing.assert_close(weight.grad, legacy_weight.grad, atol=0, rtol=0) + torch.testing.assert_close(scale.grad, legacy_scale.grad, atol=0, rtol=0) + + +def test_paroquant_large_train_quant_compile_dispatch(monkeypatch): + """Guard that only large CUDA training-time quant calls route into the compiled helper.""" + class _FakeWeight: + def __init__(self, numel: int, device_type: str = "cuda"): + self._numel = numel + self.device = SimpleNamespace(type=device_type) + + def numel(self) -> int: + return self._numel + + monkeypatch.setattr(paroquant_optimization, "_PAROQUANT_LARGE_TRAIN_QUANT_COMPILE_MIN_NUMEL", 16) + monkeypatch.setattr(paroquant_optimization, "env_flag", lambda *_args, **_kwargs: True) + monkeypatch.setattr(paroquant_optimization, "_get_large_train_quant_compile", lambda: lambda *_args, **_kwargs: "compiled") + monkeypatch.setattr(paroquant_optimization, "pseudo_quantize_dequant", lambda *_args, **_kwargs: "eager") + + assert ( + paroquant_optimization._maybe_compile_large_train_quant( + _FakeWeight(32), + bits=4, + group_size=8, + sym=True, + ) + == "compiled" + ) + assert ( + paroquant_optimization._maybe_compile_large_train_quant( + _FakeWeight(8), + bits=4, + group_size=8, + sym=True, + ) + == "eager" + ) + assert ( + paroquant_optimization._maybe_compile_large_train_quant( + _FakeWeight(32), + bits=4, + group_size=8, + sym=False, + ) + == "compiled" + ) + assert ( + paroquant_optimization._maybe_compile_large_train_quant( + _FakeWeight(32, device_type="cpu"), + bits=4, + group_size=8, + sym=True, + ) + == "eager" + ) + + +def test_paroquant_stage_cudagraph_gate_requires_real_cuda_tensor(monkeypatch): + """Guard that CUDA-graph replay only activates for real CUDA tensor stages.""" + class _DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones((2, 4), dtype=torch.float32)) + self.fused_rotation = True + + class _FakeRows: + def __init__(self, device_type: str): + self.shape = (2048, 4) + self.device = SimpleNamespace(type=device_type, index=0 if device_type == "cuda" else None) + + monkeypatch.delenv("GPTQMODEL_PAROQUANT_OPT_STAGE_CUDAGRAPH", raising=False) + + model = _DummyModel() + real_cpu_rows = torch.ones((2048, 4), dtype=torch.float32) + fake_cuda_rows = _FakeRows("cuda") + + assert paroquant_optimization._should_use_paroquant_stage_cudagraph(model, inputs_train=real_cpu_rows, batch_size=64) is False + assert paroquant_optimization._should_use_paroquant_stage_cudagraph(model, inputs_train=fake_cuda_rows, batch_size=64) is False + + +def test_paroquant_stage_cudagraph_falls_back_to_eager_on_runtime_error(monkeypatch): + """Guard that a CUDA-graph stage failure restores model state and reruns eagerly.""" + class _DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.tensor([[1.0]], dtype=torch.float32)) + self.fused_rotation = True + + def reset_masked_angles(self): + return None + + model = _DummyModel() + call_order = [] + + def fake_impl(*, model, **kwargs): + del kwargs + call_order.append(model) + return 1.25, 2.5 + + def fake_cudagraph(*, model, **kwargs): + del kwargs + call_order.append(("graph", model)) + raise RuntimeError("graph failed at runtime") + + monkeypatch.setattr(paroquant_optimization, "_should_use_paroquant_stage_cudagraph", lambda *args, **kwargs: True) + monkeypatch.setattr(paroquant_optimization, "_run_stage_gptqmodel_cudagraph", fake_cudagraph) + monkeypatch.setattr(paroquant_optimization, "_run_stage_gptqmodel_impl", fake_impl) + + train_loss, val_loss = paroquant_optimization._run_stage_gptqmodel( + model=model, + inputs_train=torch.ones((2, 1), dtype=torch.float32), + targets_train=torch.zeros((2, 1), dtype=torch.float32), + inputs_val=torch.ones((1, 1), dtype=torch.float32), + targets_val=torch.zeros((1, 1), dtype=torch.float32), + param_groups=[], + epochs=1, + batch_size=1, + ) + + assert (train_loss, val_loss) == (1.25, 2.5) + assert call_order == [("graph", model), model] + + +def test_optimize_paroquant_linear_forwards_stage_cudagraph(monkeypatch): + """Guard that explicit CUDA-graph policy is forwarded into both optimization stages.""" + stage_cudagraph_calls = [] + original_run_stage = paroquant_optimization._run_stage + + def spy_run_stage(*, stage_cudagraph=None, **kwargs): + stage_cudagraph_calls.append(stage_cudagraph) + return original_run_stage(stage_cudagraph=stage_cudagraph, **kwargs) + + monkeypatch.setattr(paroquant_optimization, "_run_stage", spy_run_stage) + + weight = torch.randn((8, 8), dtype=torch.float32) + inputs = torch.randn((64, 8), dtype=torch.float32) + + result = optimize_paroquant_linear( + weight=weight, + bias=None, + inputs=inputs, + bits=4, + group_size=8, + sym=True, + krot=1, + pair_ratio=0.5, + train_rows=32, + val_rows=16, + batch_size=16, + rotation_epochs=1, + finetune_epochs=1, + rotation_lr=0.05, + weight_lr=1e-5, + quantizer_lr=1e-6, + seed=0, + fused_rotation=True, + stage_cudagraph=False, + stage_impl="fast", + pair_impl="fast", + quantizer_impl="reference", + ) + + assert result.val_loss >= 0.0 + assert stage_cudagraph_calls == [False, False] + + +def test_optimize_paroquant_linear_forwards_optimizer_name(monkeypatch): + """Guard that the selected stage optimizer is forwarded into both optimization stages.""" + optimizer_name_calls = [] + original_run_stage = paroquant_optimization._run_stage + + def spy_run_stage(*, optimizer_name="adamw", **kwargs): + optimizer_name_calls.append(optimizer_name) + return original_run_stage(optimizer_name=optimizer_name, **kwargs) + + monkeypatch.setattr(paroquant_optimization, "_run_stage", spy_run_stage) + + weight = torch.randn((8, 8), dtype=torch.float32) + inputs = torch.randn((64, 8), dtype=torch.float32) + + result = optimize_paroquant_linear( + weight=weight, + bias=None, + inputs=inputs, + bits=4, + group_size=8, + sym=True, + krot=1, + pair_ratio=0.5, + train_rows=32, + val_rows=16, + batch_size=16, + rotation_epochs=1, + finetune_epochs=1, + rotation_lr=0.05, + weight_lr=1e-5, + quantizer_lr=1e-6, + seed=0, + optimizer_name="sgd", + fused_rotation=True, + stage_cudagraph=False, + stage_impl="fast", + pair_impl="fast", + quantizer_impl="reference", + ) + + assert result.val_loss >= 0.0 + assert optimizer_name_calls == ["sgd", "sgd"] + + +def test_optimize_paroquant_linear_forwards_best_state_dtype(monkeypatch): + """Guard that explicit best-state snapshot dtype policy is forwarded into both optimization stages.""" + best_state_dtype_calls = [] + original_run_stage = paroquant_optimization._run_stage + + def spy_run_stage(*, best_state_dtype="fp32", **kwargs): + best_state_dtype_calls.append(best_state_dtype) + return original_run_stage(best_state_dtype=best_state_dtype, **kwargs) + + monkeypatch.setattr(paroquant_optimization, "_run_stage", spy_run_stage) + + weight = torch.randn((8, 8), dtype=torch.float32) + inputs = torch.randn((64, 8), dtype=torch.float32) + + result = optimize_paroquant_linear( + weight=weight, + bias=None, + inputs=inputs, + bits=4, + group_size=8, + sym=True, + krot=1, + pair_ratio=0.5, + train_rows=32, + val_rows=16, + batch_size=16, + rotation_epochs=1, + finetune_epochs=1, + rotation_lr=0.05, + weight_lr=1e-5, + quantizer_lr=1e-6, + seed=0, + fused_rotation=True, + stage_cudagraph=False, + best_state_dtype="fp16", + stage_impl="fast", + pair_impl="fast", + quantizer_impl="reference", + ) + + assert result.val_loss >= 0.0 + assert best_state_dtype_calls == ["fp16", "fp16"] + + +def test_optimize_paroquant_linear_supports_sgd_optimizer(): + """Guard the direct ParoQuant path against rejecting valid SGD hyperparameters.""" + weight = torch.randn((16, 16), dtype=torch.float32) + inputs = torch.randn((96, 16), dtype=torch.float32) + + result = optimize_paroquant_linear( + weight=weight, + bias=None, + inputs=inputs, + bits=4, + group_size=8, + sym=True, + krot=1, + pair_ratio=0.5, + train_rows=64, + val_rows=32, + batch_size=16, + rotation_epochs=1, + finetune_epochs=1, + rotation_lr=0.05, + weight_lr=1e-4, + quantizer_lr=1e-4, + seed=0, + optimizer_name="sgd", + optimizer_weight_decay=0.02, + sgd_momentum=0.85, + sgd_dampening=0.0, + sgd_nesterov=True, + fused_rotation=False, + stage_cudagraph=False, + stage_impl="fast", + pair_impl="fast", + quantizer_impl="reference", + ) + + assert result.val_loss >= 0.0 + assert result.pseudo_weight.shape == weight.shape + + +def test_paroquant_run_stage_only_enables_active_gradients(monkeypatch): + """Guard that each stage only backpropagates through the parameters it optimizes.""" + pairs, theta_mask = build_random_rotation_buffers( + in_features=8, + group_size=8, + krot=1, + pair_ratio=0.5, + seed=0, + device=torch.device("cpu"), + ) + model = _ParoQuantOptimLinear( + torch.randn((8, 8), dtype=torch.float32), + torch.randn((8,), dtype=torch.float32), + bits=4, + group_size=8, + quantizer_sym=True, + pairs=pairs, + theta_mask=theta_mask, + fused_rotation=False, + ) + original_flags = {name: param.requires_grad for name, param in model.named_parameters()} + seen_flags = {} + + def fake_stage_impl(**kwargs): + del kwargs + seen_flags.update({name: param.requires_grad for name, param in model.named_parameters()}) + return 0.0, 0.0 + + monkeypatch.setattr(paroquant_optimization, "_run_stage_gptqmodel", fake_stage_impl) + + paroquant_optimization._run_stage( + model=model, + inputs_train=torch.randn((4, 8), dtype=torch.float32), + targets_train=torch.randn((4, 8), dtype=torch.float32), + inputs_val=torch.randn((2, 8), dtype=torch.float32), + targets_val=torch.randn((2, 8), dtype=torch.float32), + param_groups=[ + {"params": [model.channel_scales_opt], "lr": 0.05}, + {"params": [model.theta], "lr": 0.05}, + ], + epochs=1, + batch_size=2, + stage_impl="fast", + ) + + assert seen_flags["theta"] is True + assert seen_flags["channel_scales_opt"] is True + assert seen_flags["weight"] is False + assert seen_flags["bias"] is False + assert {name: param.requires_grad for name, param in model.named_parameters()} == original_flags + + +def test_paroquant_registers_with_transformers_gptq_quantizer(): + """Guard the HF quantization registry alias used by Evalution loaders.""" + cfg = AutoQuantizationConfig.from_dict( + { + "quant_method": "paroquant", + "bits": 4, + "group_size": 128, + "sym": True, + "format": "paroquant", + } + ) + + assert isinstance(cfg, GPTQConfig) + assert getattr(cfg.quant_method, "value", cfg.quant_method) == "gptq" + + +def test_paroquant_kernel_mapping_uses_paroquant_backend(): + """Guard backend dispatch so ParoQuant does not silently fall back to AWQ.""" + from gptqmodel.nn_modules.qlinear.paroquant import ParoLinear + + assert ( + get_kernel_for_backend(BACKEND.PAROQUANT_CUDA, METHOD.PARO, FORMAT.PAROQUANT) + is ParoLinear + ) + + +def test_paroquant_kernel_mapping_uses_paroquant_triton_backend(): + """Guard Triton backend dispatch for ParoQuant-specific runtime modules.""" + from gptqmodel.nn_modules.qlinear.paroquant_triton import ParoQuantTritonLinear + + assert ( + get_kernel_for_backend(BACKEND.PAROQUANT_TRITON, METHOD.PARO, FORMAT.PAROQUANT) + is ParoQuantTritonLinear + ) + + +def test_paroquant_identity_rotation_buffers_preserve_input(): + """Guard the identity buffer builder used by no-op and fallback paths.""" + x = torch.randn(3, 128, dtype=torch.float16) + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=128, + group_size=128, + krot=8, + dtype=torch.float16, + ) + + rotated = apply_paroquant_rotation_reference( + x, + pairs, + theta, + scales=channel_scales, + group_size=128, + ) + + torch.testing.assert_close(rotated, x, atol=0, rtol=0) + + +def test_paroquant_module_default_rotation_buffers_are_identity(): + """Guard fresh runtime modules against invalid all-zero pair buffers.""" + module = ParoLinear( + bits=4, + group_size=128, + sym=True, + desc_act=False, + in_features=128, + out_features=128, + bias=False, + register_buffers=True, + ) + pairs, theta, channel_scales = build_identity_rotation_buffers( + in_features=module.in_features, + group_size=module.group_size, + krot=module.krot, + dtype=module.theta.dtype, + ) + + assert torch.equal(module.pairs, pairs) + assert torch.equal(module.theta, theta) + assert torch.equal(module.channel_scales, channel_scales) + + +def test_paroquant_processor_is_not_awq_subclass(): + """Guard the dedicated lifecycle split from AWQ requested by the user.""" + assert not issubclass(ParoQuantProcessor, AWQProcessor) + + +def test_paroquant_processor_resets_reused_module_buckets_per_layer(): + """Guard against cross-layer activation reuse for repeated relative module names.""" + processor = object.__new__(ParoQuantProcessor) + processor.lock = threading.Lock() + processor.tasks = {} + + processor._ensure_task_bucket("mlp.gate_proj", layer_index=0) + processor.tasks["mlp.gate_proj"]["inputs"].append(torch.randn(1, 8)) + processor._ensure_task_bucket("mlp.gate_proj", layer_index=0) + assert len(processor.tasks["mlp.gate_proj"]["inputs"]) == 1 + + processor._ensure_task_bucket("mlp.gate_proj", layer_index=1) + assert processor.tasks["mlp.gate_proj"]["layer_index"] == 1 + assert processor.tasks["mlp.gate_proj"]["inputs"] == [] + + +def test_paroquant_processor_groups_common_llama_compute_blocks(): + """Guard the planned compute_block optimizer buckets for attention and MLP projections.""" + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="compute_block") + + state = SimpleNamespace( + modules={ + "self_attn.q_proj": SimpleNamespace(name="self_attn.q_proj"), + "self_attn.k_proj": SimpleNamespace(name="self_attn.k_proj"), + "self_attn.v_proj": SimpleNamespace(name="self_attn.v_proj"), + "self_attn.o_proj": SimpleNamespace(name="self_attn.o_proj"), + "mlp.gate_proj": SimpleNamespace(name="mlp.gate_proj"), + "mlp.up_proj": SimpleNamespace(name="mlp.up_proj"), + "mlp.down_proj": SimpleNamespace(name="mlp.down_proj"), + } + ) + + groups = processor._optimization_groups_for_layer(state) + + assert [(label, [module.name for module in modules]) for label, modules in groups] == [ + ("attn_o", ["self_attn.o_proj"]), + ("attn_qkv", ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]), + ("mlp_down", ["mlp.down_proj"]), + ("mlp_gate_up", ["mlp.gate_proj", "mlp.up_proj"]), + ] + + +def test_paroquant_processor_module_scope_seed_uses_full_module_name(): + """Guard module scope against collapsing different layer linears onto one archetype seed.""" + full_name_a = "model.layers.0.self_attn.q_proj" + full_name_b = "model.layers.0.block.self_attn.q_proj" + + module_scope = object.__new__(ParoQuantProcessor) + module_scope.qcfg = SimpleNamespace(opt_scope="module", opt_seed=3141592653) + assert module_scope._module_seed(0, full_name_a) != module_scope._module_seed(0, full_name_b) + + grouped_scope = object.__new__(ParoQuantProcessor) + grouped_scope.qcfg = SimpleNamespace(opt_scope="compute_block", opt_seed=3141592653) + assert grouped_scope._module_seed(0, full_name_a) == grouped_scope._module_seed(0, full_name_b) + + +def test_paroquant_prewarm_rotation_extension_skips_unsupported_configs(monkeypatch): + """Guard the explicit prewarm helper so startup only pays for real fused-kernel cases.""" + calls = [] + + monkeypatch.setattr( + paroquant_utils_module._PAROQUANT_ROTATION_EXTENSION, + "load", + lambda: calls.append("load") or True, + ) + monkeypatch.setattr(torch.cuda, "is_available", lambda: True) + + assert prewarm_paroquant_rotation_extension(fused_rotation=False, group_size=128, krot=8) is False + assert prewarm_paroquant_rotation_extension(fused_rotation=True, group_size=64, krot=8) is False + assert prewarm_paroquant_rotation_extension(fused_rotation=True, group_size=128, krot=4) is False + assert ( + prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=torch.device("cpu"), + ) + is False + ) + assert prewarm_paroquant_rotation_extension(fused_rotation=True, group_size=128, krot=8) is True + assert calls == ["load"] + + +def test_paroquant_clear_rotation_extension_cache_delegates_to_shared_loader(monkeypatch): + """Guard the public cache-clear helper so benchmarks can force fresh torch.ops rebuilds.""" + + calls = [] + monkeypatch.setattr( + paroquant_utils_module._PAROQUANT_ROTATION_EXTENSION, + "clear_cache", + lambda: calls.append("clear"), + ) + + clear_paroquant_rotation_extension_cache() + + assert calls == ["clear"] + + +def test_paroquant_rotation_launch_config_honors_env_overrides(monkeypatch): + """Guard manual launch-shape overrides so benchmarking can pin one kernel variant.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant launch-config path.") + + monkeypatch.setenv("GPTQMODEL_PAROQUANT_ROTATE_CTA_M", "16") + monkeypatch.setenv("GPTQMODEL_PAROQUANT_ROTATE_ROW_PAD", "0") + + assert prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=torch.device("cuda"), + ) + + cta_m, row_pad = _rotation_launch_config(torch.empty((1, 128), device="cuda", dtype=torch.float16)) + assert (cta_m, row_pad) == (16, 0) + + +def test_paroquant_rotation_launch_config_autotunes_once_per_shape(monkeypatch): + """Guard fused rotation autotune so one native shape plan is cached and reused.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant launch-config path.") + + clear_paroquant_rotation_autotune_cache() + monkeypatch.delenv("GPTQMODEL_PAROQUANT_ROTATE_CTA_M", raising=False) + monkeypatch.delenv("GPTQMODEL_PAROQUANT_ROTATE_ROW_PAD", raising=False) + monkeypatch.setenv("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE", "1") + monkeypatch.setenv("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_WARMUP", "1") + monkeypatch.setenv("GPTQMODEL_PAROQUANT_ROTATE_AUTOTUNE_ITERS", "1") + + assert prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device=torch.device("cuda"), + ) + + x = torch.empty((32, 128), device="cuda", dtype=torch.float16) + pairs = torch.zeros((8, 128), device="cuda", dtype=torch.int16) + theta = torch.zeros((8, 64), device="cuda", dtype=torch.float16) + scales = torch.ones((1, 128), device="cuda", dtype=torch.float16) + x_other = torch.empty((64, 128), device="cuda", dtype=torch.float16) + + assert paroquant_utils_module._rotation_autotune_cache_size() == 0 + first = _rotation_launch_config(x, pairs, theta, scales=scales, group_size=128) + assert first in {(4, 0), (4, 2), (8, 0), (8, 2), (16, 0), (16, 2)} + assert paroquant_utils_module._rotation_autotune_cache_size() == 1 + second = _rotation_launch_config(x, pairs, theta, scales=scales, group_size=128) + third = _rotation_launch_config(x_other, pairs, theta, scales=scales, group_size=128) + + assert second == first + assert third in {(4, 0), (4, 2), (8, 0), (8, 2), (16, 0), (16, 2)} + assert paroquant_utils_module._rotation_autotune_cache_size() == 2 + clear_paroquant_rotation_autotune_cache() + assert paroquant_utils_module._rotation_autotune_cache_size() == 0 + + +def test_paroquant_rotation_launch_config_serializes_concurrent_autotune(monkeypatch): + """Guard free-threaded launch autotune so one cold shape is measured once at a time.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant launch-config path.") + + clear_paroquant_rotation_autotune_cache() + monkeypatch.setattr(paroquant_utils_module, "_load_rotation_extension", lambda: True) + monkeypatch.setattr(paroquant_utils_module, "_rotation_requested_launch", lambda: (-2, -2)) + + state_lock = threading.Lock() + result_lock = threading.Lock() + calls = {"launch": 0, "active": 0, "max_active": 0} + results = [] + failures = [] + + def fake_launch_config(x, krot, has_scale, group_size, cta_m, row_pad): + del x, krot, has_scale, group_size, cta_m, row_pad + with state_lock: + calls["launch"] += 1 + calls["active"] += 1 + calls["max_active"] = max(calls["max_active"], calls["active"]) + try: + time.sleep(0.05) + return (8, 2) + finally: + with state_lock: + calls["active"] -= 1 + + def fake_op(name): + if name == "launch_config": + return fake_launch_config + raise AssertionError(f"unexpected op lookup: {name}") + + monkeypatch.setattr(paroquant_utils_module._PAROQUANT_ROTATION_EXTENSION, "op", fake_op) + + x = torch.empty((32, 128), device="cuda", dtype=torch.float16) + pairs = torch.zeros((8, 128), device="cuda", dtype=torch.int16) + theta = torch.zeros((8, 64), device="cuda", dtype=torch.float16) + scales = torch.ones((1, 128), device="cuda", dtype=torch.float16) + start_barrier = threading.Barrier(4) + + def worker(): + try: + start_barrier.wait() + resolved = _rotation_launch_config(x, pairs, theta, scales=scales, group_size=128) + with result_lock: + results.append(resolved) + except BaseException as exc: # pragma: no cover - test should fail below instead. + with result_lock: + failures.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(4)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert failures == [] + assert results == [(8, 2)] * 4 + assert calls["launch"] == 1 + assert calls["max_active"] == 1 + + clear_paroquant_rotation_autotune_cache() + + +def test_paroquant_rotation_helper_reuses_resolved_autotune_launch(monkeypatch): + """Guard the fused helper so autotune resolves once and steady-state runs explicitly.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant launch-config path.") + + clear_paroquant_rotation_autotune_cache() + monkeypatch.setattr(paroquant_utils_module, "_load_rotation_extension", lambda: True) + monkeypatch.setattr(paroquant_utils_module, "_rotation_requested_launch", lambda: (-2, -2)) + + calls = {"launch": 0, "rotate": 0, "configs": []} + + def fake_launch_config(x, krot, has_scale, group_size, cta_m, row_pad): + del x, krot, has_scale, group_size, cta_m, row_pad + calls["launch"] += 1 + return (16, 2) + + def fake_rotate(x, pairs, theta, scales, group_size, cta_m, row_pad): + del pairs, theta, scales, group_size + calls["rotate"] += 1 + calls["configs"].append((cta_m, row_pad)) + return x.clone() + + def fake_op(name): + if name == "launch_config": + return fake_launch_config + if name == "rotate": + return fake_rotate + raise AssertionError(f"unexpected op lookup: {name}") + + monkeypatch.setattr(paroquant_utils_module._PAROQUANT_ROTATION_EXTENSION, "op", fake_op) + + x = torch.randn((32, 128), device="cuda", dtype=torch.float16) + pairs = torch.zeros((8, 128), device="cuda", dtype=torch.int16) + theta = torch.zeros((8, 64), device="cuda", dtype=torch.float16) + scales = torch.ones((1, 128), device="cuda", dtype=torch.float16) + + first = apply_paroquant_rotation(x, pairs, theta, scales=scales, group_size=128) + second = apply_paroquant_rotation(x, pairs, theta, scales=scales, group_size=128) + + assert torch.equal(first, x) + assert torch.equal(second, x) + assert calls["launch"] == 1 + assert calls["rotate"] == 2 + assert calls["configs"] == [(16, 2), (16, 2)] + + clear_paroquant_rotation_autotune_cache() + third = apply_paroquant_rotation(x, pairs, theta, scales=scales, group_size=128) + + assert torch.equal(third, x) + assert calls["launch"] == 2 + assert calls["rotate"] == 3 + assert calls["configs"][-1] == (16, 2) + + +def test_paroquant_processor_prewarm_runtime_runs_once(monkeypatch): + """Guard startup prewarm so the looper does not retry the fused extension every layer.""" + calls = [] + + monkeypatch.setattr( + paroquant_processor_module, + "prewarm_paroquant_rotation_extension", + lambda **kwargs: calls.append(kwargs) or True, + ) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_fused_rotation=True, group_size=128, krot=8) + processor._runtime_prewarmed = False + + processor.prewarm_runtime() + processor.prewarm_runtime() + + assert len(calls) == 1 + assert calls[0] == { + "fused_rotation": True, + "group_size": 128, + "krot": 8, + } + + +def test_paroquant_processor_grouped_modes_capture_pristine_context_outside_subset_forward(): + """Guard grouped modes against treating early-stopped subset forwards as full-layer targets.""" + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="compute_block") + assert processor.uses_grouped_optimization() is True + assert processor.capture_layer_forward_context_during_subset() is False + + processor.qcfg = SimpleNamespace(opt_scope="layer") + assert processor.uses_grouped_optimization() is True + assert processor.capture_layer_forward_context_during_subset() is False + + processor.qcfg = SimpleNamespace(opt_scope="module") + assert processor.uses_grouped_optimization() is False + assert processor.capture_layer_forward_context_during_subset() is False + + +def test_paroquant_processor_disables_stage_cudagraph_for_module_scope_loop(): + """Guard module scope against CUDA-graph private-pool growth across many linear optimizations.""" + processor = object.__new__(ParoQuantProcessor) + + processor.qcfg = SimpleNamespace(opt_scope="module", opt_stage_cudagraph=True) + assert processor._module_scope_stage_cudagraph_enabled() is False + + processor.qcfg = SimpleNamespace(opt_scope="module", opt_stage_cudagraph=False) + assert processor._module_scope_stage_cudagraph_enabled() is False + + processor.qcfg = SimpleNamespace(opt_scope="compute_block", opt_stage_cudagraph=True) + assert processor._module_scope_stage_cudagraph_enabled() is True + + processor.qcfg = SimpleNamespace(opt_scope="layer", opt_stage_cudagraph=True) + assert processor._module_scope_stage_cudagraph_enabled() is True + + +def test_paroquant_processor_module_quantize_forces_stage_cudagraph_off(monkeypatch): + """Guard the full model module loop against per-linear CUDA-graph pool retention.""" + stage_cudagraph_calls = [] + optimizer_name_calls = [] + optimizer_kwargs_calls = [] + + def fake_optimize_paroquant_linear(*, weight, stage_cudagraph=None, optimizer_name="adamw", **kwargs): + stage_cudagraph_calls.append(stage_cudagraph) + optimizer_name_calls.append(optimizer_name) + optimizer_kwargs_calls.append(kwargs) + zeros = torch.zeros((weight.shape[0], weight.shape[1] // 128), dtype=weight.dtype) + return SimpleNamespace( + train_loss=0.0, + val_loss=0.0, + pseudo_weight=weight.detach().clone(), + pack_weight=weight.detach().clone(), + q_scales=torch.ones_like(zeros), + q_zeros=torch.zeros_like(zeros, dtype=torch.int32), + pairs=torch.zeros((1, weight.shape[1]), dtype=torch.int16), + theta=torch.zeros((1, weight.shape[1] // 2), dtype=weight.dtype), + channel_scales=torch.ones((1, weight.shape[1]), dtype=weight.dtype), + ) + + monkeypatch.setattr(paroquant_processor_module, "optimize_paroquant_linear", fake_optimize_paroquant_linear) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="module", + opt_stage_cudagraph=True, + dynamic_get=lambda _name, _field, default=None: default, + runtime_bits=4, + group_size=128, + sym=True, + krot=8, + opt_pair_ratio=0.25, + opt_train_samples=128, + opt_validation_samples=32, + opt_batch_size=16, + opt_rotation_epochs=1, + opt_finetune_epochs=1, + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_seed=0, + opt_optimizer="sgd", + opt_weight_decay=0.02, + opt_betas=(0.8, 0.9), + opt_eps=1e-8, + opt_amsgrad=True, + opt_sgd_momentum=0.85, + opt_sgd_dampening=0.0, + opt_sgd_nesterov=True, + opt_fused_rotation=True, + opt_best_state_dtype="fp16", + opt_stage_impl="fast", + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + ) + processor.calculate_w_wq_diff = False + processor.lock = threading.Lock() + + module = SimpleNamespace( + name="self_attn.q_proj", + full_name="model.layers.0.self_attn.q_proj", + layer_index=0, + weight=torch.nn.Parameter(torch.randn((8, 128), dtype=torch.float32)), + bias=None, + state={}, + ) + + processor._quantize_one_module(module, torch.randn((32, 128), dtype=torch.float32)) + + assert stage_cudagraph_calls == [False] + assert optimizer_name_calls == ["sgd"] + assert len(optimizer_kwargs_calls) == 1 + forwarded_kwargs = optimizer_kwargs_calls[0] + assert forwarded_kwargs["bias"] is None + assert isinstance(forwarded_kwargs["inputs"], torch.Tensor) + assert tuple(forwarded_kwargs["inputs"].shape) == (32, 128) + assert forwarded_kwargs["bits"] == 4 + assert forwarded_kwargs["group_size"] == 128 + assert forwarded_kwargs["sym"] is True + assert forwarded_kwargs["krot"] == 8 + assert forwarded_kwargs["pair_ratio"] == pytest.approx(0.25) + assert forwarded_kwargs["train_rows"] == 128 + assert forwarded_kwargs["val_rows"] == 32 + assert forwarded_kwargs["batch_size"] == 16 + assert forwarded_kwargs["rotation_epochs"] == 1 + assert forwarded_kwargs["finetune_epochs"] == 1 + assert forwarded_kwargs["rotation_lr"] == pytest.approx(0.05) + assert forwarded_kwargs["weight_lr"] == pytest.approx(1e-5) + assert forwarded_kwargs["quantizer_lr"] == pytest.approx(1e-6) + assert isinstance(forwarded_kwargs["seed"], int) + assert forwarded_kwargs["optimizer_weight_decay"] == pytest.approx(0.02) + assert forwarded_kwargs["optimizer_betas"] == pytest.approx((0.8, 0.9)) + assert forwarded_kwargs["optimizer_eps"] == pytest.approx(1e-8) + assert forwarded_kwargs["optimizer_amsgrad"] is True + assert forwarded_kwargs["sgd_momentum"] == pytest.approx(0.85) + assert forwarded_kwargs["sgd_dampening"] == pytest.approx(0.0) + assert forwarded_kwargs["sgd_nesterov"] is True + assert forwarded_kwargs["fused_rotation"] is True + assert forwarded_kwargs["gradient_checkpointing"] is False + assert forwarded_kwargs["best_state_dtype"] == "fp16" + assert forwarded_kwargs["stage_impl"] == "fast" + assert forwarded_kwargs["pair_impl"] == "fast" + assert forwarded_kwargs["quantizer_impl"] == "reference" + assert forwarded_kwargs["scale_clamp_min"] == pytest.approx(1e-2) + assert forwarded_kwargs["scale_clamp_max"] == pytest.approx(1e2) + + +def test_paroquant_processor_layer_scope_live_path_is_dense_only(): + """Guard the official-like live layer path for dense decoder layers only.""" + dense_modules = [ + SimpleNamespace(name="self_attn.q_proj"), + SimpleNamespace(name="self_attn.k_proj"), + SimpleNamespace(name="self_attn.v_proj"), + SimpleNamespace(name="self_attn.o_proj"), + SimpleNamespace(name="mlp.gate_proj"), + SimpleNamespace(name="mlp.up_proj"), + SimpleNamespace(name="mlp.down_proj"), + ] + moe_modules = [ + SimpleNamespace(name="self_attn.q_proj"), + SimpleNamespace(name="mlp.experts.0.gate_up_proj"), + SimpleNamespace(name="mlp.experts.0.down_proj"), + ] + + assert ParoQuantProcessor._supports_live_layer_scope(dense_modules) is True + assert ParoQuantProcessor._supports_live_layer_scope(moe_modules) is False + + +@pytest.mark.parametrize( + ("opt_scope", "expected_capture"), + [ + ("module", False), + ("compute_block", True), + ("layer", True), + ], +) +def test_paroquant_processor_enables_layer_context_capture_only_for_grouped_scopes(opt_scope, expected_capture): + """Guard module scope against retaining pristine layer IO it never consumes.""" + processor = ParoQuantProcessor( + tokenizer=None, + qcfg=ParoConfig(bits=4, group_size=128, opt_scope=opt_scope), + calibration=None, + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + batch_size=1, + gptq_model=None, + model=None, + ) + + assert processor.execution_config.capture_layer_forward_context is expected_capture + + +@pytest.mark.parametrize( + ("opt_scope", "opt_gradient_checkpointing", "expected"), + [ + ("module", None, False), + ("compute_block", None, False), + ("layer", None, True), + ("module", True, True), + ("layer", False, False), + ], +) +def test_paroquant_processor_resolves_gradient_checkpointing_by_scope(opt_scope, opt_gradient_checkpointing, expected): + """Processor runtime should mirror the config default and explicit override semantics.""" + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope=opt_scope, + opt_gradient_checkpointing=opt_gradient_checkpointing, + ) + + assert processor._gradient_checkpointing_enabled() is expected + + +def test_paroquant_processor_skips_pristine_layer_clone_for_layer_scope(): + """Guard layer scope against retaining an unused pristine layer clone on CPU.""" + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="layer") + processor._layer_states = {} + processor._layer_states_lock = threading.Lock() + + layer = torch.nn.Linear(4, 4, bias=False) + processor.receive_pristine_layer_module(layer_index=0, layer_module=layer) + + assert processor._layer_states == {} + + +def test_paroquant_processor_routes_non_module_units_through_group_optimizer(): + """Guard that compute_block/layer modes now use grouped optimization instead of raising.""" + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="compute_block") + processor.fallback = True + processor.lock = threading.Lock() + processor.tasks = {} + processor.calculate_w_wq_diff = False + processor._log_quant_result = lambda *args, **kwargs: None # type: ignore[method-assign] + + layer = torch.nn.Module() + layer.self_attn = torch.nn.Module() + layer.self_attn.q_proj = torch.nn.Linear(8, 8, bias=False) + layer.self_attn.k_proj = torch.nn.Linear(8, 8, bias=False) + layer.self_attn.v_proj = torch.nn.Linear(8, 8, bias=False) + + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + processor._layer_input_features = lambda _state: { # type: ignore[method-assign] + q_proj.name: torch.randn(4, 8), + k_proj.name: torch.randn(4, 8), + v_proj.name: torch.randn(4, 8), + } + + observed_groups = [] + + def fake_optimize_group(state, group_modules): + del state + observed_groups.append([module.name for module in group_modules]) + results = {} + for module in group_modules: + weight = module.weight.data.detach() + results[module.name] = SimpleNamespace( + pseudo_weight=weight + 1.0, + pack_weight=weight + 2.0, + q_scales=torch.ones((weight.shape[0], max(1, weight.shape[1] // 8)), dtype=weight.dtype), + q_zeros=torch.zeros((weight.shape[0], max(1, weight.shape[1] // 8)), dtype=torch.int32), + pairs=torch.zeros((1, 8), dtype=torch.int16), + theta=torch.zeros((1, weight.shape[1] // 2), dtype=weight.dtype), + channel_scales=torch.ones((1, weight.shape[1]), dtype=weight.dtype), + ) + return results, 0.25 + + processor._optimize_group = fake_optimize_group # type: ignore[method-assign] + + state = SimpleNamespace( + quantized=False, + modules={q_proj.name: q_proj, k_proj.name: k_proj, v_proj.name: v_proj}, + layer_inputs=[[torch.randn(1, 2, 8)]], + layer_outputs=[[torch.randn(1, 2, 8)]], + pending_modules=set(), + processed_subsets={0}, + subset_total=1, + ) + + original_q_weight = q_proj.weight.data.clone() + processor._quantize_layer(layer_index=0, state=state) + + assert observed_groups == [["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"]] + torch.testing.assert_close(q_proj.weight.data, original_q_weight + 1.0) + assert state.quantized is True + assert state.modules == {} + assert state.pending_modules == set() + assert state.processed_subsets == set() + + +def test_paroquant_processor_compute_block_scope_flushes_cuda_cache_between_groups(monkeypatch): + """Guard compute_block scope against carrying allocator cache forward between grouped passes when offload is on.""" + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="compute_block", offload_to_disk=True) + processor.fallback = True + processor.lock = threading.Lock() + processor.tasks = {} + processor.calculate_w_wq_diff = False + processor._log_quant_result = lambda *args, **kwargs: None # type: ignore[method-assign] + + layer = torch.nn.Module() + layer.self_attn = torch.nn.Module() + layer.self_attn.q_proj = torch.nn.Linear(8, 8, bias=False) + layer.self_attn.k_proj = torch.nn.Linear(8, 8, bias=False) + layer.self_attn.v_proj = torch.nn.Linear(8, 8, bias=False) + + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + processor._layer_input_features = lambda _state: { # type: ignore[method-assign] + q_proj.name: torch.randn(4, 8), + k_proj.name: torch.randn(4, 8), + v_proj.name: torch.randn(4, 8), + } + + def fake_optimize_group(state, group_modules): + del state + results = {} + for module in group_modules: + weight = module.weight.data.detach() + results[module.name] = SimpleNamespace( + pseudo_weight=weight + 1.0, + pack_weight=weight + 2.0, + q_scales=torch.ones((weight.shape[0], max(1, weight.shape[1] // 8)), dtype=weight.dtype), + q_zeros=torch.zeros((weight.shape[0], max(1, weight.shape[1] // 8)), dtype=torch.int32), + pairs=torch.zeros((1, 8), dtype=torch.int16), + theta=torch.zeros((1, weight.shape[1] // 2), dtype=weight.dtype), + channel_scales=torch.ones((1, weight.shape[1]), dtype=weight.dtype), + ) + return results, 0.25 + + processor._optimize_group = fake_optimize_group # type: ignore[method-assign] + + empty_cache_calls = [] + monkeypatch.setattr( + paroquant_processor_module, + "torch_empty_cache", + lambda device=None, gc=True, sync=False: empty_cache_calls.append( + {"device": device, "gc": gc, "sync": sync} + ), + ) + + state = SimpleNamespace( + quantized=False, + modules={q_proj.name: q_proj, k_proj.name: k_proj, v_proj.name: v_proj}, + layer_inputs=[[torch.randn(1, 2, 8)]], + layer_outputs=[[torch.randn(1, 2, 8)]], + pending_modules=set(), + processed_subsets={0}, + subset_total=1, + ) + + processor._quantize_layer(layer_index=0, state=state) + + assert empty_cache_calls == [{"device": torch.device("cpu"), "gc": False, "sync": True}] + + +def test_paroquant_processor_layer_scope_falls_back_to_clone_for_expert_like_groups(monkeypatch): + """Guard expert-like layer groups against accidentally taking the dense live path.""" + import gptqmodel.looper.paroquant_processor as paroquant_processor_module + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="layer", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=0, + opt_finetune_epochs=0, + ) + state = SimpleNamespace() + group_modules = [ + SimpleNamespace(name="self_attn.q_proj"), + SimpleNamespace(name="mlp.experts.0.gate_up_proj"), + ] + + live_calls = [] + clone_calls = [] + + def fake_live(_state, _modules): + live_calls.append(True) + return {}, 0.0 + + processor._optimize_live_layer = fake_live # type: ignore[method-assign] + + def _fake_optim_module(): + return SimpleNamespace( + channel_scales_opt=torch.nn.Parameter(torch.ones(1)), + theta=torch.nn.Parameter(torch.zeros(1)), + weight=torch.nn.Parameter(torch.ones(1, 1)), + quantizer=None, + init_quantizer=lambda: None, + ) + + def fake_build_group_optim_layer(_state, _modules): + clone_calls.append(True) + return torch.nn.Linear(4, 4, bias=False), { + module.name: _fake_optim_module() for module in _modules + } + + processor._build_group_optim_layer = fake_build_group_optim_layer # type: ignore[method-assign] + processor._group_dataset_for_device = lambda *_args, **_kwargs: ([], [], [], [], [], [], [], [], [], []) # type: ignore[method-assign] + processor._run_group_stage = lambda *args, **kwargs: (0.0, 0.0) # type: ignore[method-assign] + monkeypatch.setattr( + paroquant_processor_module, + "_result_from_model", + lambda _optim_module, **_kwargs: SimpleNamespace(ok=True), + ) + + results, val_loss = processor._optimize_group(state, group_modules) + + assert live_calls == [] + assert clone_calls == [True] + assert set(results) == {"self_attn.q_proj", "mlp.experts.0.gate_up_proj"} + assert val_loss == 0.0 + + +def test_paroquant_processor_captures_first_layer_forward_context(): + """Guard that grouped optimization modes keep the original float layer IO once.""" + processor = object.__new__(ParoQuantProcessor) + processor._layer_states = {} + processor._layer_states_lock = threading.Lock() + + first_inputs = [[torch.randn(1, 4)]] + first_kwargs = [{"attention_mask": torch.ones((1, 4), dtype=torch.int64)}] + first_outputs = [[torch.randn(1, 4)]] + second_inputs = [[torch.randn(1, 4)]] + second_kwargs = [{"attention_mask": torch.zeros((1, 4), dtype=torch.int64)}] + second_outputs = [[torch.randn(1, 4)]] + + processor.receive_layer_forward_context( + layer_index=0, + layer_inputs=first_inputs, + layer_input_kwargs=first_kwargs, + layer_outputs=first_outputs, + subset_index=0, + subset_total=2, + ) + processor.receive_layer_forward_context( + layer_index=0, + layer_inputs=second_inputs, + layer_input_kwargs=second_kwargs, + layer_outputs=second_outputs, + subset_index=1, + subset_total=2, + ) + + state = processor._get_layer_state(0) + assert state.layer_inputs is first_inputs + assert state.layer_input_kwargs is first_kwargs + assert state.layer_outputs is first_outputs + assert state.subset_total == 2 + + +def test_paroquant_processor_group_clean_inputs_seed_from_input_cache(): + """Guard grouped clean targets against aliasing the noisy replay cache.""" + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="layer", opt_train_on_noisy_inputs=True) + clean_inputs = [[torch.randn(1, 4)]] + noisy_inputs = [[torch.randn(1, 4)]] + cache = InputCache( + layer_inputs=clean_inputs, + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + processor.receive_input_cache(cache) + + assert processor.inputs_cache.layer_inputs is clean_inputs + assert processor.clean_group_layer_inputs(layer_index=0, layer_inputs=noisy_inputs) is clean_inputs + + +def test_paroquant_processor_group_clean_inputs_default_to_noisy_stream(): + """Guard default grouped behavior against enabling train-on-noisy-inputs implicitly.""" + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="layer", opt_train_on_noisy_inputs=False) + clean_inputs = [[torch.randn(1, 4)]] + noisy_inputs = [[torch.randn(1, 4)]] + processor.receive_input_cache( + InputCache( + layer_inputs=clean_inputs, + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + ) + + assert processor.clean_group_layer_inputs(layer_index=0, layer_inputs=noisy_inputs) is noisy_inputs + + +def test_paroquant_processor_group_capture_uses_pristine_module_and_clean_inputs_for_targets(): + """Guard grouped capture so targets come from the untouched module on the clean stream.""" + + class _PristineLayer(torch.nn.Module): + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + return x + 1.0 + + class _HookedLikeLayer(torch.nn.Module): + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + return x + 100.0 + + class _DummyPB: + def manual(self): + return self + + def set(self, **kwargs): + del kwargs + return self + + def title(self, _value): + return self + + def subtitle(self, _value): + return self + + def draw(self): + return self + + def close(self): + return None + + class _DummyLog: + def pb(self, _iterable): + return _DummyPB() + + class _DummyLooper: + def _resolve_batch_total(self, _num_batches, layer_inputs): + return len(layer_inputs) + + def _collect_row_counts(self, layer_inputs): + return [1 for _ in layer_inputs] + + def _run_forward_batches( + self, + *, + module, + processor, + layer_inputs, + layer_input_kwargs, + position_ids, + attention_masks, + cur_layer_device, + is_lm_head_module, + shared_kv_cache_dict, + layer_index, + need_outputs, + reuse_kv, + progress_pb, + progress_title, + progress_stage, + progress_rows_per_batch, + progress_total_rows, + force_serial, + preserve_module_devices, + ): + del ( + processor, + position_ids, + attention_masks, + cur_layer_device, + is_lm_head_module, + shared_kv_cache_dict, + layer_index, + need_outputs, + reuse_kv, + progress_pb, + progress_title, + progress_stage, + progress_rows_per_batch, + progress_total_rows, + force_serial, + preserve_module_devices, + ) + outputs = [] + for batch_inputs, batch_kwargs in zip(layer_inputs, layer_input_kwargs): + outputs.append([module(batch_inputs[0], **batch_kwargs)]) + return outputs + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="layer", opt_train_on_noisy_inputs=True) + processor._layer_states = {} + processor._layer_states_lock = threading.Lock() + + clean_inputs = [[torch.tensor([[1.0]])]] + noisy_inputs = [[torch.tensor([[3.0]])]] + clean_cache = InputCache( + layer_inputs=clean_inputs, + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + processor.receive_input_cache(clean_cache) + + _capture_pristine_group_context( + _DummyLooper(), + processor=processor, + module=_HookedLikeLayer(), + pristine_module=_PristineLayer(), + subset_plans=[SimpleNamespace()], + layer_inputs=noisy_inputs, + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + layer_descriptor="model.layers.0", + full={}, + log=_DummyLog(), + region_timer=None, + ) + + state = processor._get_layer_state(0) + assert state.layer_inputs is noisy_inputs + torch.testing.assert_close(state.layer_outputs[0][0], torch.tensor([[2.0]])) + torch.testing.assert_close(processor.clean_group_layer_inputs(layer_index=1, layer_inputs=noisy_inputs)[0][0], torch.tensor([[2.0]])) + + +def test_paroquant_processor_group_capture_advances_clean_stream_without_subset_plans(): + """Guard clean/noisy replay semantics for grouped layers that are dynamically skipped.""" + + class _ToyLayer(torch.nn.Module): + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + return x + 5.0 + + class _DummyPB: + def manual(self): + return self + + def set(self, **kwargs): + del kwargs + return self + + def title(self, _value): + return self + + def subtitle(self, _value): + return self + + def draw(self): + return self + + def close(self): + return None + + class _DummyLog: + def pb(self, _iterable): + return _DummyPB() + + class _DummyLooper: + def _resolve_batch_total(self, _num_batches, layer_inputs): + return len(layer_inputs) + + def _collect_row_counts(self, layer_inputs): + return [1 for _ in layer_inputs] + + def _run_forward_batches(self, **kwargs): + layer_inputs = kwargs["layer_inputs"] + module = kwargs["module"] + layer_input_kwargs = kwargs["layer_input_kwargs"] + return [[module(batch[0], **batch_kwargs)] for batch, batch_kwargs in zip(layer_inputs, layer_input_kwargs)] + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="layer", opt_train_on_noisy_inputs=True) + processor._layer_states = {} + processor._layer_states_lock = threading.Lock() + processor.receive_input_cache( + InputCache( + layer_inputs=[[torch.tensor([[2.0]])]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + ) + + noisy_inputs = [[torch.tensor([[9.0]])]] + _capture_pristine_group_context( + _DummyLooper(), + processor=processor, + module=_ToyLayer(), + pristine_module=None, + subset_plans=[], + layer_inputs=noisy_inputs, + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + layer_descriptor="model.layers.0", + full={}, + log=_DummyLog(), + region_timer=None, + ) + + state = processor._get_layer_state(0) + assert state.layer_inputs is None + assert state.layer_outputs is None + torch.testing.assert_close(processor.clean_group_layer_inputs(layer_index=1, layer_inputs=noisy_inputs)[0][0], torch.tensor([[7.0]])) + + +def test_paroquant_quantize_layer_clears_stored_forward_context(): + """Guard that transient grouped-optimization IO snapshots do not leak across layers.""" + module = SimpleNamespace( + name="mlp.gate_proj", + full_name="model.layers.0.mlp.gate_proj", + weight=SimpleNamespace(data=torch.randn(8, 8)), + bias=None, + ) + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_scope="module") + processor.fallback = True + processor.lock = threading.Lock() + processor.tasks = {"mlp.gate_proj": {"inputs": [torch.randn(1, 8)], "layer_index": 0}} + processor._layer_input_features = lambda _state: {"mlp.gate_proj": torch.randn(4, 8)} + processor._quantize_one_module = lambda named_module, feat: (0.0, float(feat.numel() > 0)) # type: ignore[method-assign] + processor._log_quant_result = lambda *args, **kwargs: None # type: ignore[method-assign] + + state = SimpleNamespace( + quantized=False, + modules={"mlp.gate_proj": module}, + pending_modules=set(), + processed_subsets={0}, + pristine_layer_module=torch.nn.Linear(8, 8, bias=False), + prepared_group_source_module=torch.nn.Linear(8, 8, bias=False), + prepared_group_source_module_by_device={"cpu": torch.nn.Linear(8, 8, bias=False)}, + layer_inputs=[[torch.randn(1, 8)]], + layer_input_kwargs=[{"attention_mask": torch.ones((1, 8), dtype=torch.int64)}], + layer_outputs=[[torch.randn(1, 8)]], + grouped_dataset=("cached",), + grouped_dataset_by_device={"cpu": ("cached",)}, + subset_total=1, + ) + + processor._quantize_layer(layer_index=0, state=state) + + assert state.quantized is True + assert state.modules == {} + assert state.pending_modules == set() + assert state.processed_subsets == set() + assert state.layer_inputs is None + assert state.layer_input_kwargs is None + assert state.layer_outputs is None + assert state.pristine_layer_module is None + assert state.prepared_group_source_module is None + assert state.prepared_group_source_module_by_device is None + assert state.grouped_dataset is None + assert state.grouped_dataset_by_device is None + assert state.subset_total is None + assert processor.tasks["mlp.gate_proj"]["inputs"] == [] + + +def test_paroquant_processor_builds_group_optim_layer_clone(): + """Guard that compute_block/layer modes can swap selected modules into a cloned float layer.""" + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(_attn_implementation="sdpa") + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + class _ToyMlp(torch.nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = torch.nn.Linear(8, 8, bias=False) + self.up_proj = torch.nn.Linear(8, 8, bias=False) + self.down_proj = torch.nn.Linear(8, 8, bias=False) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + self.mlp = _ToyMlp() + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer().half() + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="compute_block", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + dynamic_get=_dynamic_get, + ) + + state = SimpleNamespace(layer_module=layer) + layer_clone, optim_modules = processor._build_group_optim_layer(state, [q_proj, k_proj]) + + assert layer_clone is not layer + assert isinstance(layer.self_attn.q_proj, torch.nn.Linear) + assert isinstance(layer.self_attn.k_proj, torch.nn.Linear) + assert isinstance(layer_clone.self_attn.q_proj, _ParoQuantOptimLinear) + assert isinstance(layer_clone.self_attn.k_proj, _ParoQuantOptimLinear) + assert isinstance(layer_clone.self_attn.v_proj, torch.nn.Linear) + assert set(optim_modules) == {"self_attn.q_proj", "self_attn.k_proj"} + assert layer_clone.self_attn.q_proj.weight.dtype == torch.float32 + assert layer_clone.self_attn.k_proj.weight.dtype == torch.float32 + assert layer_clone.self_attn.config._attn_implementation == "eager" + + +def test_paroquant_processor_builds_group_optim_layer_from_pristine_snapshot(): + """Guard grouped clones against inheriting HookedLinear-style mutations from the live layer.""" + + class _HookedLike(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(8, 8)) + + def forward(self, x): + return x + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(_attn_implementation="sdpa") + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + + def _dynamic_get(_module_name, _key, default=None): + return default + + pristine_layer = _ToyLayer().half() + live_layer = copy.deepcopy(pristine_layer) + live_layer.self_attn.o_proj = _HookedLike() + + q_proj = NamedModule( + live_layer.self_attn.q_proj, + "self_attn.q_proj", + "model.layers.0.self_attn.q_proj", + 0, + ) + k_proj = NamedModule( + live_layer.self_attn.k_proj, + "self_attn.k_proj", + "model.layers.0.self_attn.k_proj", + 0, + ) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="compute_block", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + dynamic_get=_dynamic_get, + ) + + state = SimpleNamespace( + layer_module=live_layer, + pristine_layer_module=pristine_layer, + ) + layer_clone, _optim_modules = processor._build_group_optim_layer(state, [q_proj, k_proj]) + + assert isinstance(layer_clone.self_attn.q_proj, _ParoQuantOptimLinear) + assert isinstance(layer_clone.self_attn.k_proj, _ParoQuantOptimLinear) + assert isinstance(layer_clone.self_attn.o_proj, torch.nn.Linear) + assert not isinstance(layer_clone.self_attn.o_proj, _HookedLike) + + +def test_paroquant_processor_reuses_cached_group_source_clone(): + """Guard grouped clone preparation against rebuilding from a later-mutated pristine snapshot.""" + + class _HookedLike(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(8, 8)) + + def forward(self, x): + return x + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(_attn_implementation="sdpa") + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + + def _dynamic_get(_module_name, _key, default=None): + return default + + pristine_layer = _ToyLayer().half() + live_layer = copy.deepcopy(pristine_layer) + q_proj = NamedModule( + live_layer.self_attn.q_proj, + "self_attn.q_proj", + "model.layers.0.self_attn.q_proj", + 0, + ) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="compute_block", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + dynamic_get=_dynamic_get, + ) + + state = SimpleNamespace( + layer_module=live_layer, + pristine_layer_module=pristine_layer, + prepared_group_source_module=None, + ) + processor._build_group_optim_layer(state, [q_proj]) + assert state.prepared_group_source_module is not None + + state.pristine_layer_module.self_attn.o_proj = _HookedLike() + layer_clone, _optim_modules = processor._build_group_optim_layer(state, [q_proj]) + + assert isinstance(layer_clone.self_attn.o_proj, torch.nn.Linear) + assert not isinstance(layer_clone.self_attn.o_proj, _HookedLike) + + +def test_paroquant_processor_reuses_cached_group_source_clone_per_device(): + """Guard grouped clone preparation against repeating the same device-local source setup.""" + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(_attn_implementation="sdpa") + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer().half() + q_proj = NamedModule( + layer.self_attn.q_proj, + "self_attn.q_proj", + "model.layers.0.self_attn.q_proj", + 0, + ) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="compute_block", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + dynamic_get=_dynamic_get, + ) + + state = SimpleNamespace( + layer_module=layer, + prepared_group_source_module=None, + prepared_group_source_module_by_device=None, + ) + processor._build_group_optim_layer(state, [q_proj]) + cached_source = state.prepared_group_source_module_by_device["cpu"] + + processor._build_group_optim_layer(state, [q_proj]) + + assert state.prepared_group_source_module is not None + assert state.prepared_group_source_module_by_device is not None + assert state.prepared_group_source_module_by_device["cpu"] is cached_source + + +def test_paroquant_processor_device_group_source_cache_is_compute_block_only(): + """Guard grouped source-device caching so whole-layer mode keeps the simpler one-shot path.""" + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = SimpleNamespace(_attn_implementation="sdpa") + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer().half() + q_proj = NamedModule( + layer.self_attn.q_proj, + "self_attn.q_proj", + "model.layers.0.self_attn.q_proj", + 0, + ) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="layer", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + dynamic_get=_dynamic_get, + ) + + state = SimpleNamespace( + layer_module=layer, + prepared_group_source_module=None, + prepared_group_source_module_by_device=None, + ) + processor._build_group_optim_layer(state, [q_proj]) + + assert state.prepared_group_source_module is not None + assert state.prepared_group_source_module_by_device is None + + +def test_paroquant_processor_caches_group_dataset_split(): + """Guard grouped dataset slicing against recomputing the same train/val split every group.""" + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_train_samples=4, opt_validation_samples=2) + processor.inputs_cache = SimpleNamespace(position_ids=[], attention_masks=[]) + + inputs = [[torch.randn(1, 4)], [torch.randn(1, 4)]] + outputs = [[torch.randn(1, 4)], [torch.randn(1, 4)]] + state = SimpleNamespace( + layer_inputs=inputs, + layer_input_kwargs=[{}, {}], + layer_outputs=outputs, + grouped_dataset=None, + ) + + first = processor._group_dataset_from_state(state) + assert state.grouped_dataset is first + + state.layer_inputs = [] + state.layer_input_kwargs = [] + state.layer_outputs = [] + second = processor._group_dataset_from_state(state) + + assert second is first + + +def test_paroquant_processor_merges_equivalent_group_optimizer_param_groups(): + """Guard grouped AdamW setup against spawning redundant one-parameter optimizer groups.""" + p1 = torch.nn.Parameter(torch.randn(4)) + p2 = torch.nn.Parameter(torch.randn(4)) + p3 = torch.nn.Parameter(torch.randn(4)) + processor = object.__new__(ParoQuantProcessor) + + groups = processor._normalize_group_optimizer_param_groups( + [ + {"params": [p1], "lr": 0.05, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}, + {"params": [p2], "lr": 0.05, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}, + {"params": [p1], "lr": 0.05, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}, + {"params": [p3], "lr": 1e-5, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}, + ] + ) + + assert len(groups) == 2 + assert groups[0]["params"] == [p1, p2] + assert groups[1]["params"] == [p3] + + +def test_paroquant_processor_group_adamw_uses_merged_groups(monkeypatch): + """Guard grouped optimizer setup against re-expanding merged parameter buckets.""" + processor = object.__new__(ParoQuantProcessor) + param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + param = torch.nn.Parameter(torch.randn(4, device=param_device)) + calls = [] + + class _FakeOptimizer: + def __init__(self): + self.param_groups = [{"params": [param], "lr": 0.05}] + + def fake_adamw(param_groups, **kwargs): + calls.append((param_groups, kwargs.copy())) + return _FakeOptimizer() + + monkeypatch.setattr(torch.optim, "AdamW", fake_adamw) + optimizer = processor._build_group_adamw( + [{"params": [param], "lr": 0.05, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}], + device=torch.device("cuda"), + ) + + assert isinstance(optimizer, _FakeOptimizer) + expected_kwargs = {"fused": True} if torch.cuda.is_available() else {} + expected_param_group = { + "params": [param], + "lr": 0.05, + "weight_decay": 0.01, + "betas": (0.9, 0.95), + "eps": 1e-10, + "amsgrad": False, + } + assert calls == [([expected_param_group], expected_kwargs)] + + +def test_paroquant_processor_group_adamw_falls_back_when_fused_cuda_is_unsupported(monkeypatch): + """Guard grouped CUDA optimizer setup against fused AdamW support gaps.""" + + processor = object.__new__(ParoQuantProcessor) + param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + param = torch.nn.Parameter(torch.randn(4, device=param_device)) + calls = [] + + class _FakeOptimizer: + def __init__(self): + self.param_groups = [{"params": [param], "lr": 0.05}] + + def fake_adamw(param_groups, **kwargs): + calls.append(kwargs.copy()) + if kwargs.get("fused"): + raise TypeError("fused unsupported") + return _FakeOptimizer() + + monkeypatch.setattr(torch.optim, "AdamW", fake_adamw) + optimizer = processor._build_group_adamw( + [{"params": [param], "lr": 0.05, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}], + device=torch.device("cuda"), + ) + + assert isinstance(optimizer, _FakeOptimizer) + expected = [{"fused": True}, {}] if torch.cuda.is_available() else [{}] + assert calls == expected + + +def test_paroquant_processor_group_optimizer_uses_selected_sgd(monkeypatch): + """Guard grouped optimizer setup against ignoring the selected stage optimizer.""" + processor = object.__new__(ParoQuantProcessor) + param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + param = torch.nn.Parameter(torch.randn(4, device=param_device)) + calls = [] + + class _FakeOptimizer: + def __init__(self): + self.param_groups = [{"params": [param], "lr": 0.05}] + + def fake_sgd(param_groups, **kwargs): + calls.append((param_groups, kwargs.copy())) + return _FakeOptimizer() + + monkeypatch.setattr(torch.optim, "SGD", fake_sgd) + optimizer = processor._build_group_optimizer( + [ + { + "params": [param], + "lr": 0.05, + "weight_decay": 0.01, + "momentum": 0.85, + "dampening": 0.0, + "nesterov": True, + } + ], + device=torch.device("cuda"), + optimizer_name="sgd", + ) + + assert isinstance(optimizer, _FakeOptimizer) + expected_kwargs = {"fused": True} if torch.cuda.is_available() else {} + expected_param_group = { + "params": [param], + "lr": 0.05, + "weight_decay": 0.01, + "momentum": 0.85, + "dampening": 0.0, + "nesterov": True, + } + assert calls == [([expected_param_group], expected_kwargs)] + + +def test_paroquant_processor_group_adamw_passes_amsgrad(monkeypatch): + """Guard grouped AdamW setup against dropping optimizer-specific hyperparameters.""" + processor = object.__new__(ParoQuantProcessor) + param_device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + param = torch.nn.Parameter(torch.randn(4, device=param_device)) + calls = [] + + class _FakeOptimizer: + def __init__(self): + self.param_groups = [{"params": [param], "lr": 0.05}] + + def fake_adamw(param_groups, **kwargs): + calls.append(kwargs.copy()) + return _FakeOptimizer() + + monkeypatch.setattr(torch.optim, "AdamW", fake_adamw) + optimizer = processor._build_group_optimizer( + [{"params": [param], "lr": 0.05, "weight_decay": 0.01, "amsgrad": True}], + device=torch.device("cuda"), + optimizer_name="adamw", + ) + + assert isinstance(optimizer, _FakeOptimizer) + expected = [{"fused": True}] if torch.cuda.is_available() else [{}] + assert calls == expected + + +def test_paroquant_processor_group_stage_skips_redundant_initial_train_eval(monkeypatch): + """Guard grouped stage timing against paying an extra full train-set eval before epoch 1.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x, **_kwargs): + return self.linear(x) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_stage_impl="fast") + processor.gptq_model = None + processor.model = None + calls = [] + + def fake_evaluate(*_args, **_kwargs): + calls.append("eval") + return 0.0 + + monkeypatch.setattr(processor, "_evaluate_group_layer", fake_evaluate) + layer = _ToyLayer() + input_batch = [[torch.randn(2, 4)]] + target_batch = [[torch.randn(2, 4)]] + + processor._run_group_stage( + layer, + optim_modules={}, + input_batches_train=input_batch, + input_kwargs_train=[{}], + target_batches_train=target_batch, + position_ids_train=[None], + attention_masks_train=[None], + input_batches_val=input_batch, + input_kwargs_val=[{}], + target_batches_val=target_batch, + position_ids_val=[None], + attention_masks_val=[None], + param_groups=[{"params": [layer.linear.weight], "lr": 0.05, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}], + epochs=1, + ) + + assert calls == ["eval"] + + +def test_paroquant_processor_group_stage_defers_best_state_snapshot_until_first_val(monkeypatch): + """Guard grouped stage setup against cloning the full layer before the first val result exists.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + self.state_dict_calls = 0 + + def forward(self, x, **_kwargs): + return self.linear(x) + + def state_dict(self, *args, **kwargs): + self.state_dict_calls += 1 + return super().state_dict(*args, **kwargs) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_stage_impl="fast") + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + input_batch = [[torch.randn(2, 4)]] + target_batch = [[torch.randn(2, 4)]] + + processor._run_group_stage( + layer, + optim_modules={}, + input_batches_train=input_batch, + input_kwargs_train=[{}], + target_batches_train=target_batch, + position_ids_train=[None], + attention_masks_train=[None], + input_batches_val=input_batch, + input_kwargs_val=[{}], + target_batches_val=target_batch, + position_ids_val=[None], + attention_masks_val=[None], + param_groups=[{"params": [layer.linear.weight], "lr": 0.05, "weight_decay": 0.01, "betas": (0.9, 0.95), "eps": 1e-10}], + epochs=1, + ) + + assert layer.state_dict_calls == 1 + + +def test_paroquant_processor_group_dataset_for_device_caches_per_device(): + """Guard grouped dataset replay so repeated requests on the same device reuse one cached copy.""" + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_train_samples=8, opt_validation_samples=8) + processor.inputs_cache = InputCache( + layer_inputs=[[torch.randn(1, 4)]], + layer_input_kwargs=[{"position_ids": None}], + position_ids=[torch.arange(4).unsqueeze(0)], + attention_masks=[None], + ) + state = SimpleNamespace( + layer_inputs=[[torch.randn(1, 4)]], + layer_input_kwargs=[{"position_ids": None}], + layer_outputs=[[torch.randn(1, 4)]], + grouped_dataset=None, + grouped_dataset_by_device=None, + ) + + first = processor._group_dataset_for_device(state, torch.device("cpu")) + second = processor._group_dataset_for_device(state, torch.device("cpu")) + + assert first is second + assert state.grouped_dataset is not None + assert state.grouped_dataset_by_device is not None + assert state.grouped_dataset_by_device["cpu"] is first + + +def test_paroquant_processor_replay_batches_cache_cpu_splits(): + """Guard layer-scope replay batches so they stay on CPU and cache their train/val split.""" + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_train_samples=4, opt_validation_samples=2) + processor.inputs_cache = InputCache( + layer_inputs=[[torch.randn(2, 4)] for _ in range(3)], + layer_input_kwargs=[{} for _ in range(3)], + position_ids=[None, None, None], + attention_masks=[None, None, None], + ) + state = SimpleNamespace( + layer_inputs=[[torch.randn(2, 4)] for _ in range(3)], + layer_input_kwargs=[{} for _ in range(3)], + layer_outputs=[[torch.randn(2, 4)] for _ in range(3)], + replay_batches=None, + ) + + first_train, first_val = processor._replay_batches_from_state(state) + second_train, second_val = processor._replay_batches_from_state(state) + + assert first_train is second_train + assert first_val is second_val + assert len(first_train) == 2 + assert len(first_val) == 1 + assert all(batch.inputs[0].device.type == "cpu" for batch in first_train + first_val) + assert all(batch.target.device.type == "cpu" for batch in first_train + first_val) + if torch.cuda.is_available(): + assert all(batch.inputs[0].is_pinned() for batch in first_train + first_val) + assert all(batch.target.is_pinned() for batch in first_train + first_val) + + +def test_paroquant_processor_replay_batches_strip_inference_tensors(): + """Replay-cache tensors must be recreated outside inference mode before layer training.""" + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace(opt_train_samples=2, opt_validation_samples=1) + + with torch.inference_mode(): + cached_input = torch.randn(2, 4) + cached_kwarg = torch.randn(2, 4) + cached_output = torch.randn(2, 4) + cached_pos = torch.arange(4).unsqueeze(0) + cached_mask = torch.ones(1, 4) + + processor.inputs_cache = InputCache( + layer_inputs=[[cached_input]], + layer_input_kwargs=[{"cache_position": cached_kwarg}], + position_ids=[cached_pos], + attention_masks=[cached_mask], + ) + state = SimpleNamespace( + layer_inputs=[[cached_input]], + layer_input_kwargs=[{"cache_position": cached_kwarg}], + layer_outputs=[[cached_output]], + replay_batches=None, + ) + + train_batches, val_batches = processor._replay_batches_from_state(state) + replay_batch = train_batches[0] + + assert not replay_batch.inputs[0].is_inference() + assert not replay_batch.input_kwargs["cache_position"].is_inference() + assert not replay_batch.target.is_inference() + assert replay_batch.position_ids is not None + assert replay_batch.attention_mask is not None + assert not replay_batch.position_ids.is_inference() + assert not replay_batch.attention_mask.is_inference() + assert len(train_batches) == 1 + assert len(val_batches) == 1 + + +def test_paroquant_processor_layer_shard_loader_normalizes_inference_inputs(): + """Materialized replay batches must hand autograd normal tensors even if cache input is inference-mode.""" + + class _ToyNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(4)) + + def forward(self, hidden_states): + return self.weight * hidden_states.to(hidden_states.dtype) + + with torch.inference_mode(): + replay_batch = paroquant_processor_module._ParoQuantReplayBatch( + inputs=[torch.randn(2, 4)], + input_kwargs={}, + target=torch.randn(2, 4), + position_ids=None, + attention_mask=None, + row_count=2, + ) + + loader = paroquant_processor_module._LayerShardLoader( + [replay_batch], + target_device=torch.device("cpu"), + shard_batches=1, + ) + materialized_batch = next(loader.iter_shards())[0] + layer = _ToyNorm() + + assert not materialized_batch.inputs[0].is_inference() + assert not materialized_batch.target.is_inference() + + with torch.inference_mode(False), torch.enable_grad(): + output = layer(materialized_batch.inputs[0]) + + assert output.shape == materialized_batch.inputs[0].shape + + +def test_paroquant_processor_group_checkpoint_normalizes_inference_inputs(): + """Grouped checkpoint training must rebuild inference-mode inputs into autograd-safe tensors.""" + + processor = object.__new__(ParoQuantProcessor) + processor._gradient_checkpointing_enabled = lambda: True + + captured = {} + layer_scale = torch.nn.Parameter(torch.tensor(2.0)) + + def _fake_forward_group_batch( + layer, + *, + batch_index, + input_batch, + input_kwargs, + attention_mask, + position_ids, + ): + captured["inputs_inference"] = [tensor.is_inference() for tensor in input_batch] + captured["kwargs_inference"] = paroquant_processor_module._value_has_inference_tensor(input_kwargs) + captured["mask_inference"] = attention_mask.is_inference() if attention_mask is not None else False + captured["pos_inference"] = position_ids.is_inference() if position_ids is not None else False + return input_batch[0] * layer_scale + + processor._forward_group_batch = _fake_forward_group_batch + + with torch.inference_mode(): + input_batch = [torch.randn(2, 4)] + input_kwargs = {"cache_position": torch.arange(4)} + attention_mask = torch.ones(1, 4) + position_ids = torch.arange(4).unsqueeze(0) + + output = processor._forward_group_batch_train( + object(), + batch_index=0, + input_batch=input_batch, + input_kwargs=input_kwargs, + attention_mask=attention_mask, + position_ids=position_ids, + ) + output.sum().backward() + + assert captured["inputs_inference"] == [False] + assert captured["kwargs_inference"] is False + assert captured["mask_inference"] is False + assert captured["pos_inference"] is False + assert layer_scale.grad is not None + + +def test_paroquant_processor_cached_group_position_ids_are_autograd_safe(): + """Generated position-id cache entries must stay reusable outside worker inference mode.""" + + processor = object.__new__(ParoQuantProcessor) + + with torch.inference_mode(): + cached = processor._cached_group_position_ids( + device=torch.device("cpu"), + batch_dim=2, + seq_len=4, + ) + + assert not cached.is_inference() + assert processor._cached_group_position_ids(device=torch.device("cpu"), batch_dim=2, seq_len=4) is cached + + +def test_paroquant_processor_cached_rotary_embeddings_are_autograd_safe(): + """Rotary cache entries created during inference replay must not leak inference tensors into training.""" + + class _ToyRotary(torch.nn.Module): + def forward(self, x, position_ids): + pos = position_ids.unsqueeze(-1).to(dtype=x.dtype) + return x + pos, x - pos + + processor = object.__new__(ParoQuantProcessor) + rotary = _ToyRotary() + + with torch.inference_mode(): + x = torch.randn(1, 4, 8) + position_ids = torch.arange(4).unsqueeze(0) + cached = processor._cached_group_rotary_position_embeddings( + rotary=rotary, + x=x, + position_ids=position_ids, + rotary_device=torch.device("cpu"), + ) + + assert not paroquant_processor_module._value_has_inference_tensor(cached) + + with torch.inference_mode(False), torch.enable_grad(): + q = torch.randn(1, 4, 8, requires_grad=True) + cos, sin = processor._cached_group_rotary_position_embeddings( + rotary=rotary, + x=x, + position_ids=position_ids, + rotary_device=torch.device("cpu"), + ) + loss = (q * cos).sum() + (q * sin).sum() + loss.backward() + + assert q.grad is not None + + +def test_paroquant_processor_layer_shard_loader_reuses_metadata_tensors(): + """Guard streamed layer replay so shared position/mask tensors can stay cached on one device.""" + + if not torch.cuda.is_available(): + return + + cpu_pos = torch.arange(8).unsqueeze(0) + cpu_mask = torch.ones(1, 8) + if not cpu_pos.is_pinned(): + cpu_pos = cpu_pos.pin_memory() + if not cpu_mask.is_pinned(): + cpu_mask = cpu_mask.pin_memory() + + replay_batch = paroquant_processor_module._ParoQuantReplayBatch( + inputs=[torch.randn(1, 8).pin_memory()], + input_kwargs={}, + target=torch.randn(1, 8).pin_memory(), + position_ids=cpu_pos, + attention_mask=cpu_mask, + row_count=1, + ) + metadata_cache = {} + + loader_a = paroquant_processor_module._LayerShardLoader( + [replay_batch], + target_device=torch.device("cuda"), + shard_batches=1, + metadata_cache=metadata_cache, + ) + loader_b = paroquant_processor_module._LayerShardLoader( + [replay_batch], + target_device=torch.device("cuda"), + shard_batches=1, + metadata_cache=metadata_cache, + ) + + batch_a = next(loader_a.iter_shards())[0] + batch_b = next(loader_b.iter_shards())[0] + + assert batch_a.position_ids is batch_b.position_ids + assert batch_a.attention_mask is batch_b.attention_mask + assert batch_a.target is not batch_b.target + + +def test_paroquant_processor_group_best_state_tracks_only_active_prefixes(): + """Guard grouped best-state snapshots against cloning untouched layer state.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Linear(4, 4, bias=False) + self.b = torch.nn.Linear(4, 4, bias=False) + + processor = object.__new__(ParoQuantProcessor) + layer = _ToyLayer() + original_b = layer.b.weight.detach().clone() + + best_state = processor._snapshot_group_best_state(layer, active_prefixes=("a",)) + + assert sorted(best_state.keys()) == ["a.weight"] + + with torch.no_grad(): + layer.a.weight.zero_() + layer.b.weight.fill_(7.0) + + processor._restore_group_best_state(layer, best_state=best_state) + + assert torch.allclose(layer.a.weight, best_state["a.weight"]) + assert torch.allclose(layer.b.weight, torch.full_like(layer.b.weight, 7.0)) + assert not torch.allclose(layer.b.weight, original_b) + + +def test_paroquant_processor_group_best_state_can_snapshot_to_cpu(): + """Guard streamed layer checkpoints so best-state snapshots can live off device.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = torch.nn.Linear(4, 4, bias=False) + self.b = torch.nn.Linear(4, 4, bias=False) + + processor = object.__new__(ParoQuantProcessor) + layer = _ToyLayer() + + best_state = processor._snapshot_group_best_state( + layer, + active_prefixes=("a",), + target_device=torch.device("cpu"), + ) + + assert sorted(best_state.keys()) == ["a.weight"] + assert all(tensor.device.type == "cpu" for tensor in best_state.values()) + + +def test_paroquant_processor_group_best_state_can_cast_float_snapshots_without_touching_int_buffers(): + """Guard grouped best-state compression so float tensors shrink without corrupting integer buffers.""" + + class _ToyBranch(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float32)) + self.register_buffer("index", torch.tensor([1, 2], dtype=torch.int32)) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.a = _ToyBranch() + self.b = _ToyBranch() + + processor = object.__new__(ParoQuantProcessor) + layer = _ToyLayer() + + best_state = processor._snapshot_group_best_state( + layer, + active_prefixes=("a",), + target_device=torch.device("cpu"), + target_dtype=torch.bfloat16, + ) + + assert sorted(best_state.keys()) == ["a.index", "a.weight"] + assert best_state["a.weight"].dtype == torch.bfloat16 + assert best_state["a.index"].dtype == torch.int32 + + +def test_paroquant_quantize_config_accepts_torch_float16_best_state_dtype(): + """Guard direct config construction so torch.float16 snapshots serialize as fp16.""" + cfg = ParoConfig( + bits=4, + group_size=128, + opt_best_state_dtype=torch.float16, + ) + + assert cfg.opt_best_state_dtype == "fp16" + + +def test_paroquant_best_state_dtype_resolves_explicit_fp16(): + """Guard explicit fp16 snapshot selection.""" + resolved = paroquant_optimization._resolve_best_state_snapshot_dtype( + best_state_dtype="fp16", + device=torch.device("cuda"), + ) + + assert resolved == torch.float16 + + +def test_paroquant_best_state_dtype_resolves_explicit_bf16(): + """Guard explicit bf16 snapshot selection after removing the auto policy.""" + resolved = paroquant_optimization._resolve_best_state_snapshot_dtype( + best_state_dtype="bf16", + device=torch.device("cuda"), + ) + + assert resolved == torch.bfloat16 + + +def test_paroquant_best_state_dtype_defaults_to_fp32(): + """Guard the no-auto default so missing best-state dtype configuration stays on fp32.""" + resolved = paroquant_optimization._resolve_best_state_snapshot_dtype( + best_state_dtype=None, + device=torch.device("cpu"), + ) + + assert resolved == torch.float32 + + +def test_paroquant_processor_caches_group_forward_signature_flags(monkeypatch): + """Guard grouped replay kwargs against repeated forward-signature introspection.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None): + del attention_mask, position_ids + return self.linear(x) + + processor = object.__new__(ParoQuantProcessor) + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + signature_calls = [] + original_signature = inspect.signature + + def counting_signature(obj): + signature_calls.append(obj) + return original_signature(obj) + + monkeypatch.setattr(inspect, "signature", counting_signature) + + kwargs_a = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(2, 4), + input_kwargs={}, + attention_mask=torch.ones(1, 2), + position_ids=torch.arange(2).unsqueeze(0), + ) + kwargs_b = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(2, 4), + input_kwargs={}, + attention_mask=torch.ones(1, 2), + position_ids=torch.arange(2).unsqueeze(0), + ) + + assert len(signature_calls) == 1 + assert kwargs_a.keys() == kwargs_b.keys() + assert processor._group_forward_signature_cache[type(layer)] == (True, False, True) + + +def test_paroquant_processor_caches_group_forward_base_kwargs(monkeypatch): + """Guard grouped replay kwargs against repeated nested device moves for the same cached dict.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None): + del attention_mask, position_ids + return self.linear(x) + + processor = object.__new__(ParoQuantProcessor) + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + move_calls = [] + original_nested_move_to = paroquant_processor_module.nested_move_to + + def counting_nested_move_to(value, *, device): + move_calls.append((type(value).__name__, str(device))) + return original_nested_move_to(value, device=device) + + monkeypatch.setattr(paroquant_processor_module, "nested_move_to", counting_nested_move_to) + shared_kwargs = {"foo": {"bar": torch.randn(1)}} + + kwargs_a = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(2, 4), + input_kwargs=shared_kwargs, + attention_mask=torch.ones(1, 2), + position_ids=torch.arange(2).unsqueeze(0), + ) + kwargs_b = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(2, 4), + input_kwargs=shared_kwargs, + attention_mask=torch.ones(1, 2), + position_ids=torch.arange(2).unsqueeze(0), + ) + + assert len(move_calls) == 1 + assert kwargs_a.keys() == kwargs_b.keys() + assert torch.allclose(kwargs_a["foo"]["bar"], kwargs_b["foo"]["bar"]) + + +def test_paroquant_processor_caches_full_group_forward_kwargs(monkeypatch): + """Guard grouped replay kwargs against recomputing rotary-derived inputs for identical batches.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None, position_embeddings=None): + del attention_mask, position_ids, position_embeddings + return self.linear(x) + + class _FakeRotary(torch.nn.Module): + def __init__(self): + super().__init__() + self.calls = 0 + + def forward(self, x, position_ids): + self.calls += 1 + return x + position_ids.unsqueeze(-1).to(dtype=x.dtype) + + processor = object.__new__(ParoQuantProcessor) + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + rotary = _FakeRotary() + + monkeypatch.setattr(processor, "_get_root_rotary", lambda: rotary) + monkeypatch.setattr(processor, "_get_rotary_for_device", lambda device: rotary) + monkeypatch.setattr(processor, "_get_rotary_device", lambda module, fallback=None: torch.device("cpu")) + + x = torch.randn(1, 2, 4) + attention_mask = torch.ones(1, 2) + position_ids = torch.arange(2).unsqueeze(0) + shared_kwargs = {} + + kwargs_a = processor._prepare_group_forward_kwargs( + layer, + x=x, + input_kwargs=shared_kwargs, + attention_mask=attention_mask, + position_ids=position_ids, + ) + kwargs_b = processor._prepare_group_forward_kwargs( + layer, + x=x, + input_kwargs=shared_kwargs, + attention_mask=attention_mask, + position_ids=position_ids, + ) + + assert rotary.calls == 1 + assert kwargs_a is not kwargs_b + assert kwargs_a.keys() == kwargs_b.keys() + assert torch.allclose(kwargs_a["position_embeddings"], kwargs_b["position_embeddings"]) + + +def test_paroquant_processor_caches_rotary_position_embeddings_across_distinct_inputs(monkeypatch): + """Guard streamed grouped replay against recomputing HF rotary embeddings for identical ids/device/dtype.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None, position_embeddings=None): + del attention_mask, position_ids, position_embeddings + return self.linear(x) + + class _FakeLlamaRotaryEmbedding(torch.nn.Module): + __module__ = "transformers.models.llama.modeling_llama" + + def __init__(self): + super().__init__() + self.calls = 0 + + def forward(self, x, position_ids): + self.calls += 1 + base = position_ids.unsqueeze(-1).to(device=x.device, dtype=x.dtype) + return (base.cos(), base.sin()) + + processor = object.__new__(ParoQuantProcessor) + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + rotary = _FakeLlamaRotaryEmbedding() + + monkeypatch.setattr(processor, "_get_root_rotary", lambda: rotary) + monkeypatch.setattr(processor, "_get_rotary_for_device", lambda device: rotary) + monkeypatch.setattr(processor, "_get_rotary_device", lambda module, fallback=None: torch.device("cpu")) + + shared_position_ids = torch.arange(4).unsqueeze(0) + + kwargs_a = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(1, 4, 4), + input_kwargs={}, + attention_mask=torch.ones(1, 4), + position_ids=shared_position_ids, + cache=False, + ) + kwargs_b = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(1, 4, 4), + input_kwargs={}, + attention_mask=torch.ones(1, 4), + position_ids=shared_position_ids, + cache=False, + ) + + assert rotary.calls == 1 + assert torch.allclose(kwargs_a["position_embeddings"][0], kwargs_b["position_embeddings"][0]) + assert torch.allclose(kwargs_a["position_embeddings"][1], kwargs_b["position_embeddings"][1]) + + +def test_paroquant_processor_caches_generated_position_ids(): + """Guard grouped replay against rebuilding deterministic synthetic position ids.""" + + processor = object.__new__(ParoQuantProcessor) + + first = processor._cached_group_position_ids(device=torch.device("cpu"), batch_dim=2, seq_len=8) + second = processor._cached_group_position_ids(device=torch.device("cpu"), batch_dim=2, seq_len=8) + third = processor._cached_group_position_ids(device=torch.device("cpu"), batch_dim=1, seq_len=8) + + assert first is second + assert third is not first + assert first.shape == (2, 8) + assert third.shape == (1, 8) + + +def test_paroquant_processor_streamed_group_forward_kwargs_skip_redundant_moves(monkeypatch): + """Guard streamed layer replay against recursively re-moving already device-ready kwargs.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None): + del attention_mask, position_ids + return self.linear(x) + + processor = object.__new__(ParoQuantProcessor) + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + moved_values = [] + + original_move_value = paroquant_processor_module._LayerShardLoader._move_value_to_device + + def counting_move_value(value, device): + moved_values.append((type(value).__name__, str(device))) + return original_move_value(value, device) + + monkeypatch.setattr(paroquant_processor_module._LayerShardLoader, "_move_value_to_device", counting_move_value) + + attention_mask = torch.ones(1, 2) + position_ids = torch.arange(2).unsqueeze(0) + shared_kwargs = {"foo": {"bar": torch.randn(1)}} + kwargs = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(2, 4), + input_kwargs=shared_kwargs, + attention_mask=attention_mask, + position_ids=position_ids, + cache=False, + ) + + assert kwargs["foo"]["bar"] is shared_kwargs["foo"]["bar"] + assert kwargs["attention_mask"] is attention_mask + assert kwargs["position_ids"] is position_ids + + +def test_paroquant_processor_group_forward_kwargs_drop_past_key_values(): + """Layer-scope replay must mirror the normal forward executor and omit KV-cache objects.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward( + self, + x, + attention_mask=None, + position_ids=None, + past_key_values=None, + past_key_value=None, + use_cache=False, + ): + del attention_mask, position_ids, past_key_values, past_key_value, use_cache + return self.linear(x) + + processor = object.__new__(ParoQuantProcessor) + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + + kwargs = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(2, 4), + input_kwargs={ + "past_key_values": object(), + "past_key_value": object(), + "cache_position": torch.arange(2), + }, + attention_mask=torch.ones(1, 2), + position_ids=torch.arange(2).unsqueeze(0), + cache=False, + ) + + assert "past_key_values" not in kwargs + assert "past_key_value" not in kwargs + assert "cache_position" in kwargs + + +def test_paroquant_processor_force_layer_eager_attention_restores_shared_config(): + """Live-layer optimization should temporarily switch shared attention config to eager and restore it.""" + + class _Config: + def __init__(self): + self._attn_implementation = "flash_attention_2" + self.attn_implementation = "flash_attention_2" + + class _ToyAttention(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + + class _ToyLayer(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.self_attn = _ToyAttention(config) + + processor = object.__new__(ParoQuantProcessor) + shared_config = _Config() + layer = _ToyLayer(shared_config) + + overrides = processor._force_layer_eager_attention(layer) + + assert shared_config._attn_implementation == "eager" + assert shared_config.attn_implementation == "eager" + assert len(overrides) == 2 + + processor._restore_layer_attention_impl(overrides) + + assert shared_config._attn_implementation == "flash_attention_2" + assert shared_config.attn_implementation == "flash_attention_2" + + +def test_paroquant_processor_prepare_group_forward_kwargs_normalizes_inference_position_embeddings(monkeypatch): + """Grouped replay kwargs must clone rotary metadata back to normal tensors before live-layer forward.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(4, 4, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None, position_embeddings=None): + del attention_mask, position_ids, position_embeddings + return self.linear(x) + + class _FakeLlamaRotaryEmbedding(torch.nn.Module): + __module__ = "transformers.models.llama.modeling_llama" + + def forward(self, x, position_ids): + del x, position_ids + with torch.inference_mode(): + return (torch.randn(1, 2, 4), torch.randn(1, 2, 4)) + + processor = object.__new__(ParoQuantProcessor) + processor.gptq_model = None + processor.model = None + layer = _ToyLayer() + rotary = _FakeLlamaRotaryEmbedding() + + monkeypatch.setattr(processor, "_get_root_rotary", lambda: rotary) + monkeypatch.setattr(processor, "_get_rotary_for_device", lambda device: rotary) + monkeypatch.setattr(processor, "_get_rotary_device", lambda module, fallback=None: torch.device("cpu")) + + kwargs = processor._prepare_group_forward_kwargs( + layer, + x=torch.randn(1, 2, 4), + input_kwargs={}, + attention_mask=torch.ones(1, 2), + position_ids=torch.arange(2).unsqueeze(0), + cache=False, + ) + + assert "position_embeddings" in kwargs + assert not paroquant_processor_module._value_has_inference_tensor(kwargs["position_embeddings"]) + + +def test_paroquant_processor_caches_group_targets_by_dtype_and_device(): + """Guard grouped replay targets against repeated device/dtype conversions.""" + + processor = object.__new__(ParoQuantProcessor) + target_batch = [torch.randn(2, 4)] + + first = processor._prepare_group_target(target_batch, device=torch.device("cpu"), dtype=torch.float32) + second = processor._prepare_group_target(target_batch, device=torch.device("cpu"), dtype=torch.float32) + third = processor._prepare_group_target(target_batch, device=torch.device("cpu"), dtype=torch.float16) + + assert first is second + assert third.dtype == torch.float16 + assert third is not first + + +def test_paroquant_processor_optimize_group_runs_on_toy_layer(): + """Guard the grouped optimizer path on a tiny layer without needing the full looper.""" + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x): + return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x)) + + class _ToyMlp(torch.nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = torch.nn.Linear(8, 8, bias=False) + self.up_proj = torch.nn.Linear(8, 8, bias=False) + self.down_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x): + return self.down_proj(torch.sigmoid(self.gate_proj(x)) * self.up_proj(x)) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + self.mlp = _ToyMlp() + + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + x = self.self_attn(x) + return self.mlp(x) + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer() + x = torch.randn(1, 2, 8) + y = layer(x).detach() + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=2, + opt_validation_samples=2, + opt_stage_impl="fast", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=0, + opt_finetune_epochs=0, + dynamic_get=_dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True) + processor._batch_tls = threading.local() + processor.inputs_cache = InputCache( + layer_inputs=[[x]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + state = SimpleNamespace( + layer_module=layer, + layer_inputs=[[x]], + layer_input_kwargs=[{}], + layer_outputs=[[y]], + ) + + results, val_loss = processor._optimize_group(state, [q_proj, k_proj, v_proj]) + + assert set(results) == {"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"} + assert val_loss >= 0.0 + assert results["self_attn.q_proj"].pseudo_weight.shape == q_proj.weight.shape + assert results["self_attn.k_proj"].pseudo_weight.shape == k_proj.weight.shape + assert results["self_attn.v_proj"].pseudo_weight.shape == v_proj.weight.shape + + +def test_paroquant_processor_layer_scope_streams_without_device_dataset(monkeypatch): + """Guard layer scope against materializing the full grouped dataset on device.""" + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x): + return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x)) + + class _ToyMlp(torch.nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = torch.nn.Linear(8, 8, bias=False) + self.up_proj = torch.nn.Linear(8, 8, bias=False) + self.down_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x): + return self.down_proj(torch.sigmoid(self.gate_proj(x)) * self.up_proj(x)) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + self.mlp = _ToyMlp() + + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + return self.mlp(self.self_attn(x)) + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer() + x = torch.randn(1, 2, 8) + y = layer(x).detach() + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="layer", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=2, + opt_validation_samples=2, + opt_stage_impl="fast", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=0, + opt_finetune_epochs=0, + dynamic_get=_dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True) + processor._batch_tls = threading.local() + processor.inputs_cache = InputCache( + layer_inputs=[[x]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + def fail_group_dataset_for_device(*args, **kwargs): + raise AssertionError("layer scope should stream replay batches instead of caching a full device dataset") + + monkeypatch.setattr(processor, "_group_dataset_for_device", fail_group_dataset_for_device) + + state = SimpleNamespace( + layer_module=layer, + layer_inputs=[[x]], + layer_input_kwargs=[{}], + layer_outputs=[[y]], + replay_batches=None, + grouped_dataset_by_device=None, + ) + + results, val_loss = processor._optimize_group(state, [q_proj, k_proj, v_proj]) + + assert set(results) == {"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"} + assert val_loss >= 0.0 + assert state.grouped_dataset_by_device is None + assert state.replay_batches is not None + replay_train, replay_val = state.replay_batches + assert all(batch.inputs[0].device.type == "cpu" for batch in replay_train + replay_val) + + +def test_paroquant_processor_layer_scope_skips_angle_reset_when_theta_frozen(monkeypatch): + """Guard finetune-only layer scope against redundant masked-angle resets.""" + + class _ToyAttn(torch.nn.Module): + def __init__(self): + super().__init__() + self.q_proj = torch.nn.Linear(8, 8, bias=False) + self.k_proj = torch.nn.Linear(8, 8, bias=False) + self.v_proj = torch.nn.Linear(8, 8, bias=False) + self.o_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x): + return self.o_proj(self.q_proj(x) + self.k_proj(x) + self.v_proj(x)) + + class _ToyMlp(torch.nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = torch.nn.Linear(8, 8, bias=False) + self.up_proj = torch.nn.Linear(8, 8, bias=False) + self.down_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x): + return self.down_proj(torch.sigmoid(self.gate_proj(x)) * self.up_proj(x)) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = _ToyAttn() + self.mlp = _ToyMlp() + + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + return self.mlp(self.self_attn(x)) + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer() + x = torch.randn(1, 2, 8) + y = layer(x).detach() + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="layer", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=2, + opt_validation_samples=2, + opt_stage_impl="fast", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=0, + opt_finetune_epochs=1, + dynamic_get=_dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True) + processor.model = None + processor.lock = threading.Lock() + processor.calculate_w_wq_diff = False + processor._batch_tls = threading.local() + processor.inputs_cache = InputCache( + layer_inputs=[[x]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + state = SimpleNamespace( + layer_module=layer, + layer_inputs=[[x]], + layer_input_kwargs=[{}], + layer_outputs=[[y]], + replay_batches=None, + ) + + reset_calls = [] + + def counting_reset(optim_modules): + reset_calls.append(tuple(sorted(optim_modules))) + + monkeypatch.setattr(processor, "_reset_group_angles", counting_reset) + + results, val_loss = processor._optimize_group(state, [q_proj, k_proj, v_proj]) + + assert set(results) == {"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"} + assert val_loss >= 0.0 + assert reset_calls == [] + + +def test_paroquant_processor_optimize_group_reenables_grad_inside_inference_mode(): + """Guard grouped optimization under the worker lifecycle, which runs process() inside inference mode.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = torch.nn.Module() + self.self_attn.q_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.k_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.v_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + return self.self_attn.q_proj(x) + self.self_attn.k_proj(x) + self.self_attn.v_proj(x) + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer() + x = torch.randn(1, 2, 8) + y = layer(x).detach() + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=2, + opt_validation_samples=2, + opt_stage_impl="fast", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=1, + opt_finetune_epochs=0, + dynamic_get=_dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True) + processor._batch_tls = threading.local() + processor.inputs_cache = InputCache( + layer_inputs=[[x]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + state = SimpleNamespace( + layer_module=layer, + layer_inputs=[[x]], + layer_input_kwargs=[{}], + layer_outputs=[[y]], + ) + + with torch.inference_mode(): + results, val_loss = processor._optimize_group(state, [q_proj, k_proj, v_proj]) + + assert set(results) == {"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"} + assert val_loss >= 0.0 + + +def test_paroquant_processor_compute_block_scope_strips_hooked_linear_wrappers(): + """ComputeBlock clone optimization must unwrap HookedLinear so backward survives cloned full-layer replay.""" + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = torch.nn.Module() + self.self_attn.q_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.k_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.v_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.o_proj = torch.nn.Linear(8, 8, bias=False) + + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + hidden_states = ( + self.self_attn.q_proj(x) + + self.self_attn.k_proj(x) + + self.self_attn.v_proj(x) + ) + return self.self_attn.o_proj(hidden_states) + + def _dynamic_get(_module_name, _key, default=None): + return default + + float_layer = _ToyLayer() + x = torch.randn(1, 2, 8) + y = float_layer(x).detach() + + hooked_layer = copy.deepcopy(float_layer) + replace_module_with_hooked_legacy(hooked_layer) + + q_proj = NamedModule(hooked_layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(hooked_layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(hooked_layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="compute_block", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=2, + opt_validation_samples=2, + opt_stage_impl="fast", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=1, + opt_finetune_epochs=0, + dynamic_get=_dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True) + processor._batch_tls = threading.local() + processor.inputs_cache = InputCache( + layer_inputs=[[x]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + state = SimpleNamespace( + layer_module=hooked_layer, + layer_inputs=[[x]], + layer_input_kwargs=[{}], + layer_outputs=[[y]], + pristine_layer_module=None, + prepared_group_source_module=None, + ) + + with torch.inference_mode(): + results, val_loss = processor._optimize_group(state, [q_proj, k_proj, v_proj]) + + assert set(results) == {"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"} + assert val_loss >= 0.0 + assert hooked_layer.self_attn.o_proj.__class__.__name__ == "HookedLinear" + assert state.prepared_group_source_module is not None + assert state.prepared_group_source_module.self_attn.o_proj.__class__ is torch.nn.Linear + + +def test_paroquant_processor_layer_scope_strips_hooked_linear_wrappers(): + """Layer-scope live optimization must unwrap HookedLinear so training does not re-enter inference mode.""" + + class _ToyNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(8)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + return self.weight * hidden_states.to(input_dtype) + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.self_attn = torch.nn.Module() + self.self_attn.q_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.k_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.v_proj = torch.nn.Linear(8, 8, bias=False) + self.self_attn.o_proj = torch.nn.Linear(8, 8, bias=False) + self.post_attention_layernorm = _ToyNorm() + + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + hidden_states = ( + self.self_attn.q_proj(x) + + self.self_attn.k_proj(x) + + self.self_attn.v_proj(x) + ) + hidden_states = self.self_attn.o_proj(hidden_states) + return self.post_attention_layernorm(hidden_states) + + def _dynamic_get(_module_name, _key, default=None): + return default + + float_layer = _ToyLayer() + x = torch.randn(1, 2, 8) + y = float_layer(x).detach() + + hooked_layer = copy.deepcopy(float_layer) + replace_module_with_hooked_legacy(hooked_layer) + + q_proj = NamedModule(hooked_layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(hooked_layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(hooked_layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="layer", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=2, + opt_validation_samples=2, + opt_stage_impl="fast", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=1, + opt_finetune_epochs=0, + dynamic_get=_dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True) + processor.model = None + processor.lock = threading.Lock() + processor.calculate_w_wq_diff = False + processor._batch_tls = threading.local() + processor.inputs_cache = InputCache( + layer_inputs=[[x]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + state = SimpleNamespace( + layer_module=hooked_layer, + layer_inputs=[[x]], + layer_input_kwargs=[{}], + layer_outputs=[[y]], + replay_batches=None, + ) + + with torch.inference_mode(): + results, val_loss = processor._optimize_group(state, [q_proj, k_proj, v_proj]) + + assert set(results) == {"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"} + assert val_loss >= 0.0 + assert hooked_layer.self_attn.o_proj.__class__ is torch.nn.Linear + assert q_proj.module is hooked_layer.self_attn.q_proj + assert k_proj.module is hooked_layer.self_attn.k_proj + assert v_proj.module is hooked_layer.self_attn.v_proj + assert q_proj.module.__class__ is torch.nn.Linear + assert k_proj.module.__class__ is torch.nn.Linear + assert v_proj.module.__class__ is torch.nn.Linear + + q_proj_result = results["self_attn.q_proj"] + original_weight = processor._module_weight_matrix(q_proj).detach().clone() + processor._apply_optimization_result(q_proj, q_proj_result, original_weight) + assert torch.allclose( + hooked_layer.self_attn.q_proj.weight.detach(), + q_proj_result.pseudo_weight.to(dtype=hooked_layer.self_attn.q_proj.weight.dtype), + atol=1e-5, + rtol=1e-5, + ) + + +def test_paroquant_processor_layer_scope_restores_live_layer_dtype_after_fp32_training(): + """Layer-scope live optimization must downcast the layer back to its original replay dtype.""" + + class _ToyConfig: + def __init__(self): + self._attn_implementation = "flash_attention_2" + self.attn_implementation = "flash_attention_2" + self.dtype = torch.bfloat16 + + class _ToyNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(8, dtype=torch.bfloat16)) + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + return (self.weight.to(torch.float32) * hidden_states).to(input_dtype) + + class _ToyAttn(torch.nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.q_proj = torch.nn.Linear(8, 8, bias=False, dtype=torch.bfloat16) + self.k_proj = torch.nn.Linear(8, 8, bias=False, dtype=torch.bfloat16) + self.v_proj = torch.nn.Linear(8, 8, bias=False, dtype=torch.bfloat16) + self.o_proj = torch.nn.Linear(8, 8, bias=False, dtype=torch.bfloat16) + + def forward(self, hidden_states, attention_mask=None, position_ids=None, use_cache=False): + del attention_mask, position_ids, use_cache + hidden_states = self.q_proj(hidden_states) + self.k_proj(hidden_states) + self.v_proj(hidden_states) + return self.o_proj(hidden_states), None + + class _ToyLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.config = _ToyConfig() + self.self_attn = _ToyAttn(self.config) + self.post_attention_layernorm = _ToyNorm() + + def forward(self, x, attention_mask=None, position_ids=None, use_cache=False): + attn_out, _ = self.self_attn( + x, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=use_cache, + ) + return self.post_attention_layernorm(attn_out) + + def _dynamic_get(_module_name, _key, default=None): + return default + + layer = _ToyLayer() + x = torch.randn(1, 2, 8, dtype=torch.bfloat16) + y = layer(x).detach() + + q_proj = NamedModule(layer.self_attn.q_proj, "self_attn.q_proj", "model.layers.0.self_attn.q_proj", 0) + k_proj = NamedModule(layer.self_attn.k_proj, "self_attn.k_proj", "model.layers.0.self_attn.k_proj", 0) + v_proj = NamedModule(layer.self_attn.v_proj, "self_attn.v_proj", "model.layers.0.self_attn.v_proj", 0) + + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = SimpleNamespace( + opt_scope="layer", + runtime_bits=4, + group_size=8, + sym=True, + krot=1, + opt_seed=0, + opt_pair_ratio=0.5, + opt_pair_impl="fast", + opt_quantizer_impl="reference", + opt_fused_rotation=False, + opt_channel_scale_clamp_min=1e-2, + opt_channel_scale_clamp_max=1e2, + opt_train_samples=2, + opt_validation_samples=2, + opt_stage_impl="fast", + opt_rotation_lr=0.05, + opt_weight_lr=1e-5, + opt_quantizer_lr=1e-6, + opt_rotation_epochs=1, + opt_finetune_epochs=0, + dynamic_get=_dynamic_get, + ) + processor.gptq_model = SimpleNamespace(support_batch_quantize=True) + processor.model = None + processor._batch_tls = threading.local() + processor.inputs_cache = InputCache( + layer_inputs=[[x]], + layer_input_kwargs=[{}], + position_ids=[None], + attention_masks=[None], + ) + + state = SimpleNamespace( + layer_module=layer, + layer_inputs=[[x]], + layer_input_kwargs=[{}], + layer_outputs=[[y]], + replay_batches=None, + ) + + with torch.inference_mode(): + results, val_loss = processor._optimize_group(state, [q_proj, k_proj, v_proj]) + + assert set(results) == {"self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"} + assert val_loss >= 0.0 + assert layer.self_attn.q_proj.weight.dtype == torch.bfloat16 + assert layer.self_attn.k_proj.weight.dtype == torch.bfloat16 + assert layer.self_attn.v_proj.weight.dtype == torch.bfloat16 + assert layer.self_attn.o_proj.weight.dtype == torch.bfloat16 + assert layer.post_attention_layernorm.weight.dtype == torch.bfloat16 + assert layer(x).dtype == torch.bfloat16 + + +def test_paroquant_quant_device_selection_forces_single_gpu(): + """Guard against multi-GPU ParoQuant worker fan-out and sync hazards.""" + cuda_devices = [torch.device("cuda:0"), torch.device("cuda:1"), torch.device("cuda:2")] + mixed_devices = [torch.device("cpu"), torch.device("cuda:3"), torch.device("cuda:4")] + + assert _restrict_quant_devices_for_method(METHOD.PARO, cuda_devices) == [torch.device("cuda:0")] + assert _restrict_quant_devices_for_method(METHOD.PARO, mixed_devices) == [torch.device("cuda:3")] + assert _restrict_quant_devices_for_method(METHOD.GPTQ, cuda_devices) == cuda_devices + + +def test_paroquant_kernel_rejects_sym_false(): + """Guard that runtime capability flags disable asymmetric ParoQuant.""" + ok, err = ParoLinear.validate( + bits=4, + group_size=128, + desc_act=False, + sym=False, + in_features=128, + out_features=128, + pack_dtype=torch.int32, + dtype=torch.float16, + ) + + assert not ok + assert isinstance(err, NotImplementedError) + assert "actual sym = `False`" in str(err) + + +def test_paroquant_kernel_accepts_bf16(): + """Guard that saved ParoQuant checkpoints can be reloaded for bf16 inference.""" + ok, err = ParoLinear.validate( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=128, + pack_dtype=torch.int32, + dtype=torch.bfloat16, + ) + + assert ok + assert err is None + + +def test_paroquant_cuda_awq_kernel_preserves_bf16(monkeypatch): + """Guard that the CUDA AWQ fast path does not silently downcast bf16 inputs.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant AWQ bf16 kernel path.") + + paroquant_module = sys.modules[ParoLinear.__module__] + + module = ParoLinear( + bits=4, + group_size=128, + sym=True, + desc_act=False, + in_features=128, + out_features=128, + bias=False, + register_buffers=True, + ).to("cuda") + module.scales.fill_(1) + + seen = {} + + def fake_awq_cuda_gemm_forward(input, qweight, scales, qzeros, split_k_iters, fp32_accum=True): + del qweight, qzeros, fp32_accum + seen["input_dtype"] = input.dtype + seen["scales_dtype"] = scales.dtype + seen["split_k_iters"] = split_k_iters + return torch.zeros((input.shape[0], module.out_features), device=input.device, dtype=input.dtype) + + monkeypatch.setattr(paroquant_module, "_awq_cuda_gemm_forward", fake_awq_cuda_gemm_forward) + + x = torch.randn((2, module.in_features), device="cuda", dtype=torch.bfloat16) + out = module._forward_cuda_awq_kernel(x) + + assert seen["input_dtype"] == torch.bfloat16 + assert seen["scales_dtype"] == torch.bfloat16 + assert seen["split_k_iters"] == 4 + assert module.scales.dtype == torch.bfloat16 + assert out is not None + assert out.dtype == torch.bfloat16 + + +def test_paroquant_rotation_helper_dispatches_fused_kernel_for_bf16(monkeypatch): + """Guard bf16 activations onto the fused CUDA rotation path when ready.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant rotation bf16 fused path.") + + module = ParoLinear( + bits=4, + group_size=128, + sym=True, + desc_act=False, + in_features=128, + out_features=128, + bias=False, + register_buffers=True, + auto_cache_bf16_rotation_dtype=True, + ).to("cuda") + module.theta.uniform_(-0.2, 0.2) + module.channel_scales.uniform_(0.75, 1.25) + module.post_init() + + calls = {} + + def spy_load_rotation_extension(): + calls["load_count"] = calls.get("load_count", 0) + 1 + return True + + def fake_rotate(x, pairs, theta, scales, group_size, cta_m, row_pad): + calls["x_dtype"] = x.dtype + calls["pairs_device"] = pairs.device.type + calls["theta_dtype"] = theta.dtype + calls["scales_dtype"] = None if scales is None else scales.dtype + calls["group_size"] = group_size + calls["cta_m"] = cta_m + calls["row_pad"] = row_pad + return x.clone() + + monkeypatch.setattr(paroquant_utils_module, "_load_rotation_extension", spy_load_rotation_extension) + monkeypatch.setattr(paroquant_utils_module, "_rotation_requested_launch", lambda: (8, 2)) + monkeypatch.setattr(paroquant_utils_module._PAROQUANT_ROTATION_EXTENSION, "op", lambda name: fake_rotate) + + x = torch.randn((2, module.in_features), device="cuda", dtype=torch.bfloat16) + theta = module.theta.to(device=x.device, dtype=torch.bfloat16) + channel_scales = module.channel_scales.to(device=x.device, dtype=torch.bfloat16) + + actual = apply_paroquant_rotation( + x, + module.pairs, + theta, + scales=channel_scales, + group_size=module.group_size, + ) + + assert calls["load_count"] == 1 + assert calls["x_dtype"] == torch.bfloat16 + assert calls["pairs_device"] == "cuda" + assert calls["theta_dtype"] == torch.bfloat16 + assert calls["scales_dtype"] == torch.bfloat16 + assert calls["group_size"] == 128 + assert calls["cta_m"] == 8 + assert calls["row_pad"] == 2 + assert actual.dtype == torch.bfloat16 + assert torch.equal(actual, x) + + +def test_paroquant_rotation_fused_bf16_uses_fp16_workspace_contract(): + """Guard bf16 fused rotation against regressing back to bf16 workspace accumulation.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant rotation bf16 fused workspace path.") + + assert ( + prewarm_paroquant_rotation_extension( + fused_rotation=True, + group_size=128, + krot=8, + device="cuda", + ) + is True + ) + + in_features = 128 + group_size = 128 + krot = 8 + pairs, _mask = build_random_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=krot, + pair_ratio=0.5, + seed=13, + device=torch.device("cuda"), + ) + generator = torch.Generator(device="cpu") + generator.manual_seed(7) + x = torch.randn((4, in_features), generator=generator, dtype=torch.float32).to(device="cuda", dtype=torch.bfloat16) + theta = torch.empty((krot, in_features // 2), dtype=torch.float32) + theta.uniform_(-0.25, 0.25, generator=generator) + theta = theta.to(device="cuda", dtype=torch.bfloat16) + scales = torch.empty((1, in_features), dtype=torch.float32) + scales.uniform_(0.75, 1.25, generator=generator) + scales = scales.to(device="cuda", dtype=torch.bfloat16) + + actual = apply_paroquant_rotation(x, pairs, theta, scales=scales, group_size=group_size) + bf16_reference = apply_paroquant_rotation_reference(x, pairs, theta, scales=scales, group_size=group_size) + fp16_workspace_reference = apply_paroquant_rotation_reference( + x.to(dtype=torch.float16), + pairs, + theta.to(dtype=torch.float16), + scales=scales.to(dtype=torch.float16), + group_size=group_size, + ).to(dtype=torch.bfloat16) + fp32_reference = apply_paroquant_rotation_reference( + x.to(dtype=torch.float32), + pairs, + theta.to(dtype=torch.float32), + scales=scales.to(dtype=torch.float32), + group_size=group_size, + ) + + actual_fp16_workspace = (actual.float() - fp16_workspace_reference.float()).abs() + bf16_fp16_workspace = (bf16_reference.float() - fp16_workspace_reference.float()).abs() + actual_fp32 = (actual.float() - fp32_reference).abs() + bf16_fp32 = (bf16_reference.float() - fp32_reference).abs() + + assert actual.dtype == torch.bfloat16 + assert actual_fp16_workspace.mean().item() < bf16_fp16_workspace.mean().item() + assert actual_fp16_workspace.max().item() < bf16_fp16_workspace.max().item() + assert actual_fp32.mean().item() <= bf16_fp32.mean().item() + assert actual_fp32.max().item() <= bf16_fp32.max().item() + + +def test_paroquant_rotation_cache_preserves_bf16(monkeypatch): + """Guard that cached BF16 rotation metadata preserves the runtime dtype and values.""" + if not torch.cuda.is_available(): + pytest.skip("CUDA is required to validate the ParoQuant rotation cache path.") + + paroquant_module = sys.modules[ParoLinear.__module__] + from gptqmodel.utils import paroquant as paroquant_utils + + module = ParoLinear( + bits=4, + group_size=128, + sym=True, + desc_act=False, + in_features=128, + out_features=128, + bias=False, + register_buffers=True, + auto_cache_bf16_rotation_dtype=True, + ).to("cuda") + module.theta.uniform_(-0.2, 0.2) + module.channel_scales.uniform_(0.75, 1.25) + module.post_init() + + seen = {} + original_rotate = paroquant_utils.apply_paroquant_rotation + + def spy_rotate(x, pairs, theta, scales=None, group_size=128): + del pairs, group_size + seen["x_dtype"] = x.dtype + seen["theta_dtype"] = theta.dtype + seen["scales_dtype"] = None if scales is None else scales.dtype + return original_rotate(x, module.pairs, theta, scales=scales, group_size=module.group_size) + + monkeypatch.setattr(paroquant_module, "apply_paroquant_rotation", spy_rotate) + + x = torch.randn((2, module.in_features), device="cuda", dtype=torch.bfloat16) + module._rotate_inputs(x) + baseline = module._rotate_inputs(x) + cached = module._rotate_inputs(x) + + assert seen["x_dtype"] == torch.bfloat16 + assert seen["theta_dtype"] == torch.bfloat16 + assert seen["scales_dtype"] == torch.bfloat16 + assert module._runtime_theta is not None + assert module._runtime_channel_scales is not None + assert module._runtime_theta.dtype == torch.bfloat16 + assert module._runtime_channel_scales.dtype == torch.bfloat16 + assert torch.equal(baseline, cached) + + +def test_paroquant_optimizer_improves_over_identity_quantization(): + """Guard that learned rotations beat naive identity-domain quantization.""" + in_features = 128 + out_features = 12 + group_size = 128 + bits = 4 + seed = 11 + pair_ratio = 1.0 / group_size + + pairs, mask = build_random_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=1, + pair_ratio=pair_ratio, + seed=seed, + device=torch.device("cpu"), + ) + + theta = torch.zeros((1, in_features // 2), dtype=torch.float32) + theta[~mask] = 0.65 + channel_scales_opt = torch.linspace(0.75, 1.25, steps=in_features, dtype=torch.float32).view(1, -1) + transformed_weight = (torch.randint(-7, 8, (out_features, in_features), dtype=torch.int32).to(torch.float32)) * 0.25 + + original_weight = apply_paroquant_rotation_reference( + transformed_weight, + pairs.flip(0), + -theta.flip(0), + scales=None, + group_size=group_size, + ) / channel_scales_opt + + inputs = torch.randn(256, in_features, dtype=torch.float32) + targets = F.linear(inputs, original_weight) + + baseline_weight = pseudo_quantize_dequant( + original_weight, + bits=bits, + group_size=group_size, + sym=True, + use_ste=False, + ) + baseline_loss = F.smooth_l1_loss(F.linear(inputs, baseline_weight), targets) + + result = optimize_paroquant_linear( + weight=original_weight, + bias=None, + inputs=inputs, + bits=bits, + group_size=group_size, + sym=True, + krot=1, + pair_ratio=pair_ratio, + train_rows=192, + val_rows=64, + batch_size=32, + rotation_epochs=24, + finetune_epochs=16, + rotation_lr=0.05, + weight_lr=5e-4, + quantizer_lr=5e-4, + seed=seed, + ) + + optimized_loss = F.smooth_l1_loss(F.linear(inputs, result.pseudo_weight), targets) + assert optimized_loss < baseline_loss + + +def test_paroquant_exported_runtime_state_matches_paper_contract(): + """Guard that export tensors reproduce the pseudo-quantized optimization model. + + This is the key paper-contract regression test. It checks that we optimize + in the transformed domain, inverse-map back to the input domain for replay, + and then export runtime tensors whose rotated-input execution matches the + pseudo-quantized layer exactly. + """ + in_features = 128 + out_features = 10 + group_size = 128 + bits = 4 + + weight = torch.randn(out_features, in_features, dtype=torch.float32) * 0.3 + bias = torch.randn(out_features, dtype=torch.float32) * 0.1 + pairs, theta_mask = build_random_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=2, + pair_ratio=2.0 / group_size, + seed=7, + device=torch.device("cpu"), + ) + + model = _ParoQuantOptimLinear( + weight, + bias, + bits=bits, + group_size=group_size, + quantizer_sym=True, + pairs=pairs, + theta_mask=theta_mask, + ) + + with torch.no_grad(): + model.theta.uniform_(-0.35, 0.35) + model.reset_masked_angles() + model.channel_scales_opt.copy_(torch.linspace(0.8, 1.2, steps=in_features)) + model.weight.add_(torch.linspace(-0.05, 0.05, steps=in_features).view(1, -1)) + + model.quantizer = GroupLinearQuantizer( + model.transformed_weight().detach(), + bits=bits, + group_size=group_size, + sym=True, + ) + with torch.no_grad(): + model.quantizer.scale.mul_(1.05) + + transformed = apply_paroquant_rotation_reference( + model.weight.detach() * model.channel_scales_opt.detach().view(1, -1), + model.pairs, + model.theta.detach(), + scales=None, + group_size=group_size, + ) + quantized_transformed = pseudo_quantize_dequant( + transformed, + bits=bits, + group_size=group_size, + sym=True, + scale=model.quantizer.scale.detach(), + zero_point_float=None, + use_ste=False, + ) + expected_pseudo_weight = apply_paroquant_rotation_reference( + quantized_transformed, + model.pairs.flip(0), + -model.theta.detach().flip(0), + scales=None, + group_size=group_size, + ) / model.channel_scales_opt.detach().view(1, -1) + + torch.testing.assert_close(model.pseudo_weight().detach(), expected_pseudo_weight, atol=1e-5, rtol=1e-5) + + pack_weight, _q_scales, _q_zeros, theta, runtime_channel_scales = model.export_pack_state() + inputs = torch.randn(32, in_features, dtype=torch.float32) + runtime_outputs = F.linear( + apply_paroquant_rotation_reference( + inputs, + model.pairs, + theta, + scales=runtime_channel_scales, + group_size=group_size, + ), + pack_weight, + bias, + ) + + torch.testing.assert_close( + runtime_outputs, + F.linear(inputs, model.pseudo_weight().detach(), bias), + atol=1e-5, + rtol=1e-5, + ) + + +def test_paroquant_result_from_model_matches_direct_export(): + """Guard the fused result export against the original direct pseudo/export path.""" + torch.manual_seed(7) + in_features = 128 + out_features = 64 + group_size = 128 + bits = 4 + weight = torch.randn((out_features, in_features), dtype=torch.float32) + bias = torch.randn((out_features,), dtype=torch.float32) + pairs, theta_mask = build_random_rotation_buffers( + in_features=in_features, + group_size=group_size, + krot=2, + pair_ratio=2.0 / in_features, + seed=7, + device=torch.device("cpu"), + ) + model = _ParoQuantOptimLinear( + weight, + bias, + bits=bits, + group_size=group_size, + quantizer_sym=True, + pairs=pairs, + theta_mask=theta_mask, + fused_rotation=False, + ) + model.init_quantizer() + with torch.no_grad(): + model.theta.normal_(mean=0.0, std=0.05) + model.channel_scales_opt.uniform_(0.8, 1.2) + + direct_pseudo_weight = model.pseudo_weight().detach() + direct_pack_weight, direct_q_scales, direct_q_zeros, direct_theta, direct_channel_scales = model.export_pack_state() + result = paroquant_optimization._result_from_model( + model, + train_loss=0.123, + val_loss=0.456, + used_identity=False, + ) + + torch.testing.assert_close(result.pseudo_weight, direct_pseudo_weight, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result.pack_weight, direct_pack_weight, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result.q_scales, direct_q_scales, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result.q_zeros, direct_q_zeros, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result.theta, direct_theta, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(result.channel_scales, direct_channel_scales, atol=1e-5, rtol=1e-5) + assert result.train_loss == pytest.approx(0.123) + assert result.val_loss == pytest.approx(0.456) + assert result.used_identity is False + + +def test_paroquant_reference_quantizer_exports_affine_qzeros(): + """Guard that reference optimizer mode uses affine qparams despite sym runtime config.""" + torch.manual_seed(5) + weight = torch.randn(32, 128, dtype=torch.float32) * 0.25 + 0.1 + inputs = torch.randn(96, 128, dtype=torch.float32) + + result = optimize_paroquant_linear( + weight=weight, + bias=None, + inputs=inputs, + bits=4, + group_size=128, + sym=True, + krot=2, + pair_ratio=2.0 / 128.0, + train_rows=64, + val_rows=32, + batch_size=16, + rotation_epochs=1, + finetune_epochs=1, + rotation_lr=0.05, + weight_lr=1e-5, + quantizer_lr=1e-6, + seed=11, + stage_impl="reference", + pair_impl="reference", + quantizer_impl="reference", + ) + + midpoint = 2 ** (4 - 1) + assert not torch.all(result.q_zeros == midpoint) diff --git a/tests/test_perplexity_logic.py b/tests/test_perplexity_logic.py new file mode 100644 index 000000000..662ff26ae --- /dev/null +++ b/tests/test_perplexity_logic.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.utils.perplexity import Perplexity + + +class _DummyTokenizer: + def __init__(self, input_ids, pad_token_id: int = 0): + self._input_ids = torch.tensor([input_ids], dtype=torch.long) + self.pad_token_id = pad_token_id + self.eos_token_id = pad_token_id + self.bos_token_id = None + self.model_max_length = 0 + + def __call__(self, _text, truncation=False, return_tensors="pt"): + return SimpleNamespace(input_ids=self._input_ids.clone()) + + +class _DummyModel(nn.Module): + def __init__(self, vocab_size: int, mode: str): + super().__init__() + self.vocab_size = vocab_size + self.mode = mode + self.register_parameter("_device_anchor", nn.Parameter(torch.zeros(1))) + + def forward(self, input_ids, attention_mask=None): + del attention_mask + logits = torch.zeros((*input_ids.shape, self.vocab_size), dtype=torch.float32, device=input_ids.device) + if self.mode == "uniform": + return SimpleNamespace(logits=logits) + + if self.mode != "perfect": + raise ValueError(f"Unknown dummy mode: {self.mode}") + + next_tokens = (input_ids + 1) % self.vocab_size + logits.scatter_(2, next_tokens.unsqueeze(-1), 12.0) + return SimpleNamespace(logits=logits) + + +def _make_perplexity(monkeypatch, *, input_ids, vocab_size: int, mode: str) -> Perplexity: + monkeypatch.setattr(Perplexity, "_prepare_data", lambda self: "stub") + tokenizer = _DummyTokenizer(input_ids=input_ids) + model = _DummyModel(vocab_size=vocab_size, mode=mode) + return Perplexity(model=model, tokenizer=tokenizer, dataset_path="unused") + + +def test_perplexity_matches_uniform_distribution_reference(monkeypatch): + ppl = _make_perplexity( + monkeypatch, + input_ids=[0, 1, 2, 3, 4, 5, 0, 1, 2], + vocab_size=6, + mode="uniform", + ) + + scores = ppl.calculate(n_ctx=4, n_batch=8) + + assert scores + assert scores[-1] == pytest.approx(6.0, rel=0.0, abs=1e-6) + + +def test_perplexity_is_invariant_to_smaller_batch_token_budget(monkeypatch): + ppl = _make_perplexity( + monkeypatch, + input_ids=[0, 1, 2, 3, 4, 5, 0, 1, 2, 3], + vocab_size=6, + mode="uniform", + ) + + small_budget_scores = ppl.calculate(n_ctx=4, n_batch=2) + large_budget_scores = ppl.calculate(n_ctx=4, n_batch=16) + + assert small_budget_scores + assert large_budget_scores + assert small_budget_scores[-1] == pytest.approx(large_budget_scores[-1], rel=0.0, abs=1e-6) + + +def test_perplexity_includes_final_partial_window(monkeypatch): + ppl = _make_perplexity( + monkeypatch, + input_ids=[0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4], + vocab_size=6, + mode="perfect", + ) + + scores = ppl.calculate(n_ctx=4, n_batch=4) + + assert len(scores) == 3 + assert scores[-1] < 1.001 diff --git a/tests/test_post_quant_eora.py b/tests/test_post_quant_eora.py index 57b379685..78535c7d4 100644 --- a/tests/test_post_quant_eora.py +++ b/tests/test_post_quant_eora.py @@ -23,13 +23,17 @@ import tempfile # noqa: E402 from typing import Optional # noqa: E402 +import pytest # noqa: E402 from datasets import load_dataset from models.model_test import ModelTest # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from tests.eval import evaluate # noqa: E402 + + +pytestmark = [pytest.mark.model, pytest.mark.slow] def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): @@ -44,16 +48,19 @@ def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): if backend == BACKEND.TORCH: model.optimize() - tokens = model.generate("Capital of France is")[0] - result = model.tokenizer.decode(tokens) + result = ModelTest.generate_stable_with_limit( + model, + model.tokenizer, + "The capital city of France is named", + max_new_tokens=128, + ) print(f"BACKEND: {backend}, Result: {result}") if "paris" not in result.lower(): raise AssertionError(" `paris` not found in `result`") - bench_result = GPTQModel.eval( + bench_result = evaluate( model_or_id_or_path=model, - framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU_STEM] + tasks=["arc_challenge", "mmlu_stem"] ) del model @@ -100,8 +107,7 @@ def test_post_quant_eora(self): calibration_dataset_concat_size=calibration_dataset_concat_size, ) - # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, - # for backend in [BACKEND.MARLIN]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN + # for backend in [BACKEND.MARLIN]: # BACKEND.TORCH_FUSED, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN # base_bench = bench(path=self.QUANTIZED_MODEL_PATH, backend=backend, adapter=None) # inference using qweights only # eora_bench = bench(path=self.QUANTIZED_MODEL_PATH, backend=backend, adapter=eora) # inference using eora (lora) # diff --git a/tests/test_prepare_dataset.py b/tests/test_prepare_dataset.py index 2522495cd..f1fff2845 100644 --- a/tests/test_prepare_dataset.py +++ b/tests/test_prepare_dataset.py @@ -5,9 +5,11 @@ import copy +import pytest import torch from gptqmodel.models.base import BaseQModel +from gptqmodel.utils.data import collate_data class _StubTokenizer: @@ -28,6 +30,44 @@ def _encode_char(ch: str) -> int: return value if value > 0 else 1 +class _CausalMaskTokenizer(_StubTokenizer): + def __call__(self, text, return_tensors="pt", add_special_tokens=True): + tokenized = super().__call__(text, return_tensors=return_tensors, add_special_tokens=add_special_tokens) + seq = tokenized["input_ids"].shape[1] + causal_mask = torch.tril(torch.ones((1, 1, seq, seq), dtype=torch.long)) + tokenized["attention_mask"] = causal_mask + return tokenized + + +class _ChatStubTokenizer(_StubTokenizer): + chat_template = "{{ messages }}" + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False): + # Mirror the HF chat-template path closely enough to verify precedence. + assert tokenize is False + rendered = "".join(f"<{item['role']}>{item['content']}" for item in messages) + if add_generation_prompt: + rendered += "" + return rendered + + +class _MissingChatTemplateTokenizer(_StubTokenizer): + chat_template = None + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False): + raise AssertionError("apply_chat_template should not be used when no chat template is configured") + + +class _RaisingGetChatTemplateTokenizer(_StubTokenizer): + chat_template = None + + def apply_chat_template(self, messages, tokenize=False, add_generation_prompt=False): + raise AssertionError("apply_chat_template should not be used when get_chat_template raises") + + def get_chat_template(self, chat_template=None, tools=None): + raise ValueError("tokenizer.chat_template is not set") + + def _make_qmodel() -> BaseQModel: model = BaseQModel.__new__(BaseQModel) model.tokenizer = _StubTokenizer() @@ -38,6 +78,12 @@ def _make_qmodel() -> BaseQModel: return model +def _make_qmodel_with_tokenizer(tokenizer) -> BaseQModel: + model = _make_qmodel() + model.tokenizer = tokenizer + return model + + def _sample_dataset(): return [ {"input_ids": [[1, 2]], "attention_mask": [[1, 1]]}, @@ -111,3 +157,209 @@ def test_prepare_dataset_splits_long_row_across_blocks(): assert first_mask == [[1, 1, 1, 1, 1]] assert second_ids == [[6, 0, 0, 0, 0]] assert second_mask == [[1, 0, 0, 0, 0]] + + +def test_prepare_dataset_collapses_causal_attention_mask(): + qmodel = _make_qmodel_with_tokenizer(_CausalMaskTokenizer()) + + batches = qmodel.prepare_dataset( + calibration_dataset=["abc"], + calibration_dataset_concat_size=None, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + calibration_concat_separator=None, + ) + + assert len(batches) == 1 + assert batches[0]["input_ids"].tolist() == [[97, 98, 99]] + assert batches[0]["attention_mask"].int().tolist() == [[1, 1, 1]] + + +def test_prepare_dataset_normalizes_rank_4_attention_mask(): + qmodel = _make_qmodel() + keep = torch.tensor([True, True, True, False, False], dtype=torch.bool) + seq_len = keep.numel() + causal = torch.zeros((1, 1, seq_len, seq_len), dtype=torch.bool) + for query_idx in range(seq_len): + causal[0, 0, query_idx] = keep & (torch.arange(seq_len) <= query_idx) + + dataset = [{"input_ids": [[1, 2, 3, 0, 0]], "attention_mask": causal}] + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + ) + + assert len(batches) == 1 + assert batches[0]["input_ids"].tolist() == [[1, 2, 3, 0, 0]] + assert batches[0]["attention_mask"].int().tolist() == [[1, 1, 1, 0, 0]] + + +def test_prepare_dataset_prefers_apply_chat_template_for_messages(): + qmodel = _make_qmodel() + qmodel.tokenizer = _ChatStubTokenizer() + dataset = [ + { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + ], + "text": "raw-fallback", + } + ] + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + ) + + templated = qmodel.tokenizer.apply_chat_template(dataset[0]["messages"], tokenize=False, add_generation_prompt=False) + expected_ids = qmodel.tokenizer(templated, return_tensors="pt")["input_ids"].tolist() + raw_text_ids = qmodel.tokenizer(dataset[0]["text"], return_tensors="pt")["input_ids"].tolist() + + assert batches[0]["input_ids"].tolist() == expected_ids + assert batches[0]["input_ids"].tolist() != raw_text_ids + + +def test_prepare_dataset_falls_back_to_text_when_chat_template_is_missing(): + qmodel = _make_qmodel() + qmodel.tokenizer = _MissingChatTemplateTokenizer() + dataset = [ + { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + ], + "text": "raw-fallback", + } + ] + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + ) + + expected_ids = qmodel.tokenizer(dataset[0]["text"], return_tensors="pt")["input_ids"].tolist() + assert batches[0]["input_ids"].tolist() == expected_ids + + +def test_prepare_dataset_falls_back_to_text_when_get_chat_template_raises(): + qmodel = _make_qmodel() + qmodel.tokenizer = _RaisingGetChatTemplateTokenizer() + dataset = [ + { + "messages": [ + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "world"}, + ], + "text": "raw-fallback", + } + ] + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + ) + + expected_ids = qmodel.tokenizer(dataset[0]["text"], return_tensors="pt")["input_ids"].tolist() + assert batches[0]["input_ids"].tolist() == expected_ids + + +def test_collate_data_uses_right_padding_by_default(): + batch = [ + { + "input_ids": [[1, 2, 3], [4, 5]], + "attention_mask": [[1, 1, 1], [1, 1]], + }, + ] + + result = collate_data(batch, pad_token_id=0) + + assert result["input_ids"].tolist() == [[1, 2, 3], [4, 5, 0]] + assert result["attention_mask"].int().tolist() == [[1, 1, 1], [1, 1, 0]] + + +def test_collate_data_left_padding_when_requested(): + batch = [ + { + "input_ids": [[1, 2, 3], [4, 5]], + "attention_mask": [[1, 1, 1], [1, 1]], + }, + ] + + result = collate_data(batch, pad_token_id=0, padding_side="left") + + assert result["input_ids"].tolist() == [[1, 2, 3], [0, 4, 5]] + assert result["attention_mask"].int().tolist() == [[1, 1, 1], [0, 1, 1]] + + +def test_collate_data_raises_for_invalid_padding_side(): + batch = [ + { + "input_ids": [[1, 2, 3]], + "attention_mask": [[1, 1, 1]], + } + ] + + with pytest.raises(ValueError, match="Unsupported padding_side"): + collate_data(batch, pad_token_id=0, padding_side="center") + + +def test_prepare_dataset_uses_tokenizer_padding_side_left(): + qmodel = _make_qmodel() + qmodel.tokenizer.padding_side = "left" + + batches = qmodel.prepare_dataset( + calibration_dataset=[[1, 2, 3, 4], [5, 6]], + calibration_dataset_sort=None, + batch_size=2, + calibration_data_min_length=0, + ) + + assert batches[0]["input_ids"].tolist() == [[1, 2, 3, 4], [0, 0, 5, 6]] + assert batches[0]["attention_mask"].int().tolist() == [[1, 1, 1, 1], [0, 0, 1, 1]] + + +def test_prepare_dataset_concat_respects_tokenizer_padding_side_left(): + qmodel = _make_qmodel() + qmodel.tokenizer.padding_side = "left" + dataset = copy.deepcopy(_sample_dataset()) + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_concat_size=5, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + calibration_concat_separator=None, + ) + + assert len(batches) == 1 + assert batches[0]["input_ids"].tolist() == [[0, 0, 1, 2, 3]] + assert batches[0]["attention_mask"].int().tolist() == [[0, 0, 1, 1, 1]] + + +def test_prepare_dataset_trims_left_padded_rows_from_the_left_edge(): + qmodel = _make_qmodel() + qmodel.tokenizer.padding_side = "left" + qmodel.model.config.max_position_embeddings = 4 + + batches = qmodel.prepare_dataset( + calibration_dataset=[{"input_ids": [[0, 0, 11, 12, 13, 14]], "attention_mask": [[0, 0, 1, 1, 1, 1]]}], + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + ) + + assert len(batches) == 1 + assert batches[0]["input_ids"].tolist() == [[11, 12, 13, 14]] + assert batches[0]["attention_mask"].int().tolist() == [[1, 1, 1, 1]] diff --git a/tests/test_q4_bitblas.py b/tests/test_q4_bitblas.py index 9f0689d79..6f629f200 100644 --- a/tests/test_q4_bitblas.py +++ b/tests/test_q4_bitblas.py @@ -12,64 +12,70 @@ import unittest # noqa: E402 +import pytest # noqa: E402 import torch # noqa: E402 +from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.bitblas import BitBLASLinear # noqa: E402 + + +pytestmark = [pytest.mark.model, pytest.mark.slow] class TestQ4BitBLAS(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Load one small BitBLAS-backed fixture for both coverage checks.""" + cls.model_id = "/monster/data/model/opt-125M-autoround-lm_head-false-symTrue" + cls.model_q = GPTQModel.load(cls.model_id, device="cuda:0", backend=BACKEND.BITBLAS) + cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_id) + def test_generation(self): prompt = "The capital city of France is named" - device = torch.device("cuda:0") - - model_id = "/monster/data/model/opt-125M-autoround-lm_head-false-symTrue" - - try: - model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.BITBLAS) - except ValueError as e: - raise e has_bitblas = False - for _, module in model_q.named_modules(): - if isinstance(module, BitBLASQuantLinear): + for _, module in self.model_q.named_modules(): + if isinstance(module, BitBLASLinear): has_bitblas = True break self.assertTrue(has_bitblas) - tokenizer = AutoTokenizer.from_pretrained(model_id) - - inp = tokenizer(prompt, return_tensors="pt").to(device) - - res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60) - - predicted_text = tokenizer.decode(res[0]) + predicted_text = ModelTest.generate_stable_with_limit( + self.model_q, + self.tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) self.assertIn("paris", predicted_text.lower()) def test_bias(self): - # TheBloke/Llama-2-7B-Chat-GPTQ has bias, but they are all zeros, use a checkpoint which really uses bias. - model_id = "/monster/data/model/starcoderbase-1b-GPTQ" - - model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.BITBLAS) - - for _, param in model_q.named_parameters(): + for _, param in self.model_q.named_parameters(): self.assertNotEqual(param.device, torch.device("meta")) - for _, param in model_q.named_buffers(): + for _, param in self.model_q.named_buffers(): self.assertNotEqual(param.device, torch.device("meta")) - self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_proj.bias) > 0) - self.assertTrue(torch.count_nonzero(model_q.model.transformer.h[0].attn.c_attn.bias) > 0) - - model_id = "/monster/data/model/starcoderbase-1b" - tokenizer = AutoTokenizer.from_pretrained(model_id) + self.assertGreater( + torch.count_nonzero(self.model_q.model.model.decoder.layers[0].self_attn.q_proj.bias), + 0, + ) + self.assertGreater( + torch.count_nonzero(self.model_q.model.model.decoder.layers[0].fc1.bias), + 0, + ) prompt = "The capital city of France is named" - inp = tokenizer(prompt, return_tensors="pt").to("cuda:0") - - res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60) - - predicted_text = tokenizer.decode(res[0]) + predicted_text = ModelTest.generate_stable_with_limit( + self.model_q, + self.tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) self.assertIn("paris", predicted_text.lower()) diff --git a/tests/test_q4_exllama_v2.py b/tests/test_q4_exllama_v2.py index 4e417a12e..d915cd7ae 100644 --- a/tests/test_q4_exllama_v2.py +++ b/tests/test_q4_exllama_v2.py @@ -15,11 +15,12 @@ import unittest # noqa: E402 import torch # noqa: E402 -from test_q4_exllama_v1 import REFERENCE, get_diff # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from q4_reference import REFERENCE, get_diff # noqa: E402 from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2Linear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 from gptqmodel.utils.importer import select_quant_linear # noqa: E402 from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402 @@ -61,7 +62,7 @@ def test_exllamav2(self): backend=BACKEND.EXLLAMA_V2, ) - self.assertTrue(isinstance(linear, ExllamaV2QuantLinear)) + self.assertTrue(isinstance(linear, ExllamaV2Linear)) torch.manual_seed(42) @@ -88,7 +89,7 @@ def test_exllamav2(self): def test_generation_desc_act_false(self): prompt = "I am in Paris and" - device = torch.device("cuda:0") + torch.device("cuda:0") reference_output = " I am in Paris and I am in love with you.\n\nScene 2:\n\n(The stage is now dark, but the audience can see the characters walking around the stage.)\n\n(The stage is now lit up, but the audience can only see the characters' silhouettes.)\n\n(" @@ -97,17 +98,20 @@ def test_generation_desc_act_false(self): model_q = GPTQModel.load(model_id, device="cuda:0") tokenizer = AutoTokenizer.from_pretrained(model_id) - inp = tokenizer(prompt, return_tensors="pt").to(device) - - res = model_q.generate(**inp, num_beams=1, do_sample=False, min_new_tokens=60, max_new_tokens=60) - - predicted_text = tokenizer.decode(res[0]) + predicted_text = ModelTest.generate_stable_with_limit( + model_q, + tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) self.assertEqual(predicted_text[:GENERATE_EVAL_SIZE], reference_output[:GENERATE_EVAL_SIZE]) def test_generation_desc_act_true(self): prompt = "The capital of France is" - device = torch.device("cuda:0") + torch.device("cuda:0") model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" revision = "desc_act_true" @@ -120,11 +124,14 @@ def test_generation_desc_act_true(self): ) tokenizer = AutoTokenizer.from_pretrained(model_id) - inp = tokenizer(prompt, return_tensors="pt").to(device) - - res = model_q.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60) - - predicted_text = tokenizer.decode(res[0]) + predicted_text = ModelTest.generate_stable_with_limit( + model_q, + tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) print("predicted_text", predicted_text) assert "paris" in predicted_text.lower() or "city" in predicted_text.lower() diff --git a/tests/test_q4_marlin.py b/tests/test_q4_marlin.py index f35efd7c8..da1ab17dd 100644 --- a/tests/test_q4_marlin.py +++ b/tests/test_q4_marlin.py @@ -16,7 +16,7 @@ from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.marlin import MarlinLinear # noqa: E402 class TestQ4Marlin(ModelTest): @@ -47,14 +47,18 @@ class TestQ4Marlin(ModelTest): ] ) def test_generation(self, model_id): - try: - model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN) - except ValueError as e: - raise e - + if model_id == "/monster/data/model/gemma-1.1-2b-it-GPTQ": + with self.assertRaisesRegex( + ValueError, + r"is_marlin_format.*no longer supported" + ): + GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN) + return + + model_q = GPTQModel.load(model_id, device="cuda:0", backend=BACKEND.MARLIN) has_marlin = False for _, module in model_q.named_modules(): - linear = MarlinQuantLinear + linear = MarlinLinear if isinstance(module, linear): has_marlin = True break diff --git a/tests/test_q4_reference.py b/tests/test_q4_reference.py new file mode 100644 index 000000000..7436320b3 --- /dev/null +++ b/tests/test_q4_reference.py @@ -0,0 +1,11 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from q4_reference import REFERENCE, get_diff + + +def test_q4_reference_smoke(): + # Keep an addressable pytest target for CI jobs that request `test_q4_reference`. + assert REFERENCE.dtype == REFERENCE.new_empty(()).dtype + assert "Maxdiff:" in get_diff(REFERENCE, REFERENCE) diff --git a/tests/test_q4_torch_fused.py b/tests/test_q4_torch_fused.py new file mode 100644 index 000000000..7897d5292 --- /dev/null +++ b/tests/test_q4_torch_fused.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch + +import torch # noqa: E402 +from models.model_test import ModelTest # noqa: E402 + +from gptqmodel import BACKEND # noqa: E402 + + +class TestsTorchFused(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" # "bigscience/bloom-560m" + NATIVE_ARC_CHALLENGE_ACC = 0.28 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.31 + TORCH_DTYPE = torch.float16 + LOAD_BACKEND = BACKEND.TORCH_FUSED + # Torch-fused compat tests run on CPU, so flash-attn CUDA kernels are not valid here. + USE_FLASH_ATTN = False + DELETE_QUANTIZED_MODEL = False + USE_VLLM = False + + def test_torch_fused(self): + self.quantize_and_evaluate() diff --git a/tests/test_q4_triton.py b/tests/test_q4_triton.py index d19e141ab..bc22ce9f0 100644 --- a/tests/test_q4_triton.py +++ b/tests/test_q4_triton.py @@ -15,7 +15,7 @@ from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear # noqa: E402 class TestsQ4Triton(ModelTest): @@ -30,7 +30,7 @@ def test_generation_desc_act_false(self): dtype=torch.float16, ) for _, submodule in model_q.named_modules(): - if isinstance(submodule, TritonV2QuantLinear): + if isinstance(submodule, TritonV2Linear): break else: raise ValueError("Did not find a tritonv2 linear layer") @@ -53,7 +53,7 @@ def test_generation_desc_act_true(self): ) for _, submodule in model_q.named_modules(): - if isinstance(submodule, TritonV2QuantLinear): + if isinstance(submodule, TritonV2Linear): break else: raise ValueError("Did not find a tritonv2 linear layer") diff --git a/tests/test_qqq.py b/tests/test_qqq.py index acb12c56f..6a07cac28 100644 --- a/tests/test_qqq.py +++ b/tests/test_qqq.py @@ -11,10 +11,11 @@ import unittest from datasets import load_dataset +from models.model_test import ModelTest from parameterized import parameterized from transformers import AutoTokenizer -from gptqmodel.nn_modules.qlinear.qqq import QQQQuantLinear +from gptqmodel.nn_modules.qlinear.qqq import QQQLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.torch import torch_empty_cache @@ -87,8 +88,13 @@ def test_quant_and_inference(self, group_size: int): self.assert_qqq_linear(model) - tokens = model.generate("The capital city of France is named", min_new_tokens=128, max_new_tokens=128)[0] - result = model.tokenizer.decode(tokens) + result = ModelTest.generate_stable_with_limit( + model, + model.tokenizer, + "The capital city of France is named", + min_new_tokens=128, + max_new_tokens=128, + ) print(f"BACKEND: {BACKEND.QQQ}, Result: {result}") if "paris" not in result.lower() and "city" not in result.lower() and "country" not in result.lower(): raise AssertionError(" `paris` not found in `result`") @@ -96,7 +102,7 @@ def test_quant_and_inference(self, group_size: int): def assert_qqq_linear(self, model): has_qqq = False for _, module in model.named_modules(): - linear = QQQQuantLinear + linear = QQQLinear if isinstance(module, linear): has_qqq = True break diff --git a/tests/test_qqq_inference.py b/tests/test_qqq_inference.py index aa172d8c4..4f71fbbab 100644 --- a/tests/test_qqq_inference.py +++ b/tests/test_qqq_inference.py @@ -3,12 +3,19 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import pytest +from models.model_test import ModelTest + from gptqmodel import GPTQModel -from gptqmodel.utils.eval import EVAL -eval_results = GPTQModel.eval("HandH1998/QQQ-Llama-3-8b-g128", - framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE]) +pytestmark = [pytest.mark.model, pytest.mark.slow] -print(f"{eval_results}") +def test_qqq_inference(): + model = GPTQModel.load("HandH1998/QQQ-Llama-3-8b-g128") + str_output = ModelTest.generate_stable_with_limit( + model, + model.tokenizer, + "The capital city of France is named", + ) + assert "paris" in str_output.lower() or "city" in str_output.lower() diff --git a/tests/test_qqq_jit.py b/tests/test_qqq_jit.py new file mode 100644 index 000000000..5e216d858 --- /dev/null +++ b/tests/test_qqq_jit.py @@ -0,0 +1,77 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import pytest +import torch + +import gptqmodel.nn_modules.qlinear.qqq as qqq_module + + +def _build_module() -> qqq_module.QQQLinear: + return qqq_module.QQQLinear( + bits=4, + group_size=128, + sym=True, + desc_act=False, + in_features=128, + out_features=128, + bias=False, + register_buffers=True, + ) + + +def test_qqq_forward_uses_jit_kernel(monkeypatch): + module = _build_module() + calls = {} + + monkeypatch.setattr(qqq_module, "qqq_runtime_available", lambda: True) + + def fake_gemm(A, B, C, D, s1, s2, s3, workspace, thread_k, thread_n, sms, max_par): + calls["gemm"] = { + "A_shape": tuple(A.shape), + "A_dtype": A.dtype, + "B_shape": tuple(B.shape), + "D_shape": tuple(D.shape), + "s1_shape": tuple(s1.shape), + "s2_shape": tuple(s2.shape), + "s3_shape": tuple(s3.shape), + "workspace_shape": tuple(workspace.shape), + "thread_k": thread_k, + "thread_n": thread_n, + "sms": sms, + "max_par": max_par, + } + D.copy_(torch.full_like(D, 3.0)) + + monkeypatch.setattr(qqq_module, "qqq_gemm", fake_gemm) + + x = torch.randn((2, module.in_features), dtype=torch.float32) + out = module(x) + + assert calls["gemm"] == { + "A_shape": (2, module.in_features), + "A_dtype": torch.int8, + "B_shape": tuple(module.B.shape), + "D_shape": (2, module.out_features), + "s1_shape": (2, 1), + "s2_shape": tuple(module.s_channel.shape), + "s3_shape": tuple(module.s_group.shape), + "workspace_shape": tuple(module.workspace.shape), + "thread_k": -1, + "thread_n": -1, + "sms": -1, + "max_par": module.max_par, + } + assert out.shape == (2, module.out_features) + assert out.dtype == torch.float32 + assert torch.allclose(out, torch.full_like(out, 3.0)) + + +def test_qqq_forward_raises_runtime_error_when_jit_ops_missing(monkeypatch): + module = _build_module() + + monkeypatch.setattr(qqq_module, "qqq_runtime_available", lambda: False) + monkeypatch.setattr(qqq_module, "qqq_runtime_error", lambda: "missing qqq jit ops") + + with pytest.raises(ModuleNotFoundError, match="missing qqq jit ops"): + module(torch.randn((1, module.in_features), dtype=torch.float16)) diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py index 812929eea..8a563c8d5 100644 --- a/tests/test_quant_and_eora.py +++ b/tests/test_quant_and_eora.py @@ -26,15 +26,14 @@ from typing import Optional # noqa: E402 from datasets import load_dataset # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 from logbar import LogBar from models.model_test import ModelTest # noqa: E402 from tabulate import tabulate # noqa: E402 from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from tests.eval import evaluate, format_eval_result_table # noqa: E402 # --------Eval METHOD.GPTQ Result--------- @@ -52,7 +51,7 @@ class TestQuantAndEORA(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "chat_template": True, "acc": {"value": 0.3183, "floor_pct": 0.05}, "acc_norm": {"value": 0.3404, "floor_pct": 0.05}, @@ -123,26 +122,21 @@ def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT): del model torch_empty_cache() - # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, - for backend in [BACKEND.MARLIN]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN + for backend in [BACKEND.MARLIN]: # BACKEND.TORCH_FUSED, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora) - print('--------GPTQModel + EoRA Config ---------') + print('--------GPT-QModel + EoRA Config ---------') # Convert the dictionary to a list of lists for tabulate # table_data = [[key, value] for key, value in config_dict.items()] # print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid")) print(f'--------Eval {quant_method} Result---------') - print(make_table(base_bench)) - if "groups" in base_bench: - print(make_table(base_bench, "groups")) + print(format_eval_result_table(base_bench)) print(f'--------Eval {quant_method} + EoRA Result---------') - print(make_table(eora_bench)) - if "groups" in eora_bench: - print(make_table(eora_bench, "groups")) + print(format_eval_result_table(eora_bench)) def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]): # test post-quant inference @@ -157,13 +151,12 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]): print(f"BACKEND: {backend}, Result: {result}") # assert "paris" in result.lower(), f"`paris` not found in `{result}`" - bench_result = GPTQModel.eval( + bench_result = evaluate( model_or_id_or_path=model, - framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE], + tasks=["arc_challenge"], apply_chat_template=True, # MMLU is too slow for ci test - # EVAL.LM_EVAL.MMLU_STEM + # "mmlu_stem" ) del model @@ -184,7 +177,7 @@ class TestTransformers(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B" EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + "arc_challenge": { "acc": {"value": 0.3567, "floor_pct": 0.36}, "acc_norm": {"value": 0.3805, "floor_pct": 0.36}, }, @@ -266,26 +259,21 @@ def test_quant_and_eora(self): del model torch_empty_cache() - # BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.TRITON, BACKEND.CUDA, - for backend in [BACKEND.MARLIN]: # BACKEND.IPEX, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN + for backend in [BACKEND.MARLIN]: # BACKEND.TORCH_FUSED, BACKEND.BITBLAS, BACKEND.EXLLAMA_V2V BACKEND.MARLIN eora_bench = self.bench(path=tmpdir, backend=backend, adapter=eora) # inference using eora (lora) base_bench = self.bench(path=tmpdir, backend=backend, adapter=None) # inference using qweights only - print('--------GPTQModel + EoRA Config ---------') + print('--------GPT-QModel + EoRA Config ---------') # Convert the dictionary to a list of lists for tabulate table_data = [[key, value] for key, value in config_dict.items()] print(tabulate(table_data, headers=["Key", "Value"], tablefmt="grid")) print('--------Eval GPTQ Result---------') - print(make_table(base_bench)) - if "groups" in base_bench: - print(make_table(base_bench, "groups")) + print(format_eval_result_table(base_bench)) print('--------Eval GPTQ + EoRA Result---------') - print(make_table(eora_bench)) - if "groups" in eora_bench: - print(make_table(eora_bench, "groups")) + print(format_eval_result_table(eora_bench)) def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]): # test post-quant inference @@ -329,10 +317,9 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]): print(f"BACKEND: {backend}, Result: {result}") # assert "paris" in result.lower(), f"`paris` not found in `{result}`" - bench_result = GPTQModel.eval( + bench_result = evaluate( model_or_id_or_path=model, - framework=EVAL.LM_EVAL, - tasks=[EVAL.LM_EVAL.ARC_CHALLENGE, EVAL.LM_EVAL.MMLU_STEM], + tasks=["arc_challenge", "mmlu_stem"], ) del model diff --git a/tests/test_quant_batch.py b/tests/test_quant_batch.py index 699b3fb17..bc4dc7230 100644 --- a/tests/test_quant_batch.py +++ b/tests/test_quant_batch.py @@ -12,33 +12,28 @@ import tempfile # noqa: E402 +import torch # noqa: E402 from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.quantization import QuantizeConfig # noqa: E402 -from gptqmodel.utils.perplexity import Perplexity # noqa: E402 class TestQuantBatch(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" - def calculate_avg_ppl(self, model, tokenizer): - ppl = Perplexity( - model=model, - tokenizer=tokenizer, - dataset_path="wikitext", - dataset_name="wikitext-2-raw-v1", - split="test", - text_column="text", - ) - - all = ppl.calculate(n_ctx=512, n_batch=512) - - # average ppl - avg = sum(all) / len(all) - - return avg + def _generate(self, model, tokenizer, prompt: str = "Paris is known as"): + inputs = tokenizer(prompt, return_tensors="pt") + with torch.inference_mode(): + outputs = model.generate( + **inputs, + max_new_tokens=10, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + ) + return outputs @classmethod def setUpClass(self): @@ -73,12 +68,12 @@ def test_diff_batch(self): tmp_dir, ) - batch_1_ppl = self.calculate_avg_ppl(model, self.tokenizer) + batch_1_ids = self._generate(model, self.tokenizer) - model = GPTQModel.load( - self.NATIVE_MODEL_ID, - quantize_config=quantize_config, - ) + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + quantize_config=quantize_config, + ) model.quantize(self.calibration_dataset, batch_size=256) with tempfile.TemporaryDirectory() as tmp_dir: @@ -92,8 +87,8 @@ def test_diff_batch(self): tmp_dir, ) - batch_n_ppl = self.calculate_avg_ppl(model, self.tokenizer) + batch_n_ids = self._generate(model, self.tokenizer) del model - self.assertTrue(abs(batch_1_ppl - batch_n_ppl) / batch_1_ppl <= 0.05) + self.assertTrue((batch_1_ids == batch_n_ids).all().item()) diff --git a/tests/test_quant_dtype.py b/tests/test_quant_dtype.py index c21f3564d..107c7e36c 100644 --- a/tests/test_quant_dtype.py +++ b/tests/test_quant_dtype.py @@ -1,27 +1,214 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import gc +import os +import statistics import time +from functools import lru_cache +from pathlib import Path import pytest import torch +from safetensors import safe_open from tabulate import tabulate from gptqmodel.quantization.dtype import ( + _DTYPE_SUPPORT_CACHE, + _cpu_floatx_threads, + _dequantize_f4_reference, + _dequantize_f8_reference, + available_float8_dtype_names, dequantize_f4_e2m1, dequantize_f8_e4m3, + device_supports_dtype, + device_supports_native_fp4, device_supports_native_fp8, + get_device_dtype_support, +) + + +# Default to the preferred GLM FP8 checkpoint, but fall back to a real local FP8 model +# so CPU A/B runs stay realistic on machines where the original mount is absent. +_FLOATX_BENCH_ENV = os.environ.get("GPTQMODEL_FLOATX_BENCH_MODEL") +_FLOATX_BENCH_MODEL_CANDIDATES = ( + [Path(_FLOATX_BENCH_ENV)] if _FLOATX_BENCH_ENV else [ + Path("/monster/data/model/GLM-5.1-FP8"), + Path("/root/model/DeepSeek-V3-0324"), + ] +) +FLOATX_BENCH_MODEL_ROOT = next( + (candidate for candidate in _FLOATX_BENCH_MODEL_CANDIDATES if candidate.exists()), + _FLOATX_BENCH_MODEL_CANDIDATES[0], ) def _print_accuracy(title: str, rows, headers) -> None: - table = tabulate(rows, headers=headers, floatfmt=".6f") + table = tabulate(rows, headers=headers, floatfmt=".6f", tablefmt="grid") print(f"\n{title}\n{table}\n") +def _print_benchmark(title: str, rows, headers, note: str | None = None) -> None: + table = tabulate(rows, headers=headers, floatfmt=".4f", tablefmt="grid") + note_block = f"{note}\n" if note else "" + print(f"\n{title}\n{note_block}{table}\n") + + try: # pragma: no cover - optional dependency from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor, nvfp4_quantize except Exception: # pragma: no cover NVFP4Tensor = None nvfp4_quantize = None +try: # pragma: no cover - optional dependency + import psutil +except Exception: # pragma: no cover + psutil = None + + +pytestmark = [pytest.mark.cpu, pytest.mark.gpu] + + +def _available_fp8_formats() -> list[str]: + return list(available_float8_dtype_names()) + + +def _rss_bytes() -> int: + if psutil is not None: + return int(psutil.Process(os.getpid()).memory_info().rss) + with open("/proc/self/statm", "r", encoding="utf-8") as fh: + pages = int(fh.readline().split()[1]) + return pages * os.sysconf("SC_PAGE_SIZE") + + +def _tensor_mib(tensor: torch.Tensor) -> float: + return float(tensor.numel() * tensor.element_size()) / (1024.0 * 1024.0) + + +def _benchmark_profile(numel: int) -> tuple[int, int]: + # Huge realistic layers still need one untimed pass so lazy native setup does not skew A/B stats. + if numel >= 64 * 1024 * 1024: + return 1, 2 + return 1, 4 + + +def _benchmark_cpu_impl(fn, *, warmup: int = 1, iters: int = 4): + for _ in range(warmup): + tmp = fn() + del tmp + gc.collect() + + samples_ms: list[float] = [] + rss_deltas: list[float] = [] + for _ in range(iters): + gc.collect() + rss_before = _rss_bytes() + start = time.perf_counter() + out = fn() + elapsed_ms = (time.perf_counter() - start) * 1e3 + rss_after = _rss_bytes() + samples_ms.append(elapsed_ms) + rss_deltas.append(max(0, rss_after - rss_before) / (1024.0 * 1024.0)) + del out + + gc.collect() + result = fn() + stats = { + "median_ms": float(statistics.median(samples_ms)), + "rss_delta_mib": float(max(rss_deltas) if rss_deltas else 0.0), + "output_mib": _tensor_mib(result), + } + return result, stats + + +def _synthetic_fp8_benchmark_source() -> tuple[torch.Tensor, torch.Tensor, str, int, int]: + rows = 256 + cols = 384 + src = torch.randn(rows, cols, dtype=torch.float32) + scale_inv = torch.rand(rows // 64, cols // 64, dtype=torch.float32) * 0.5 + source = "source: synthetic fallback rows=256 cols=384 scale=[4, 6]" + return src, scale_inv, source, rows, cols + + +@lru_cache(maxsize=1) +def _realistic_fp8_benchmark_spec() -> tuple[Path, str, str, int, int, tuple[int, ...]] | None: + if not FLOATX_BENCH_MODEL_ROOT.exists(): + return None + + largest: tuple[int, Path, str, str, int, int, tuple[int, ...]] | None = None + for path in sorted(FLOATX_BENCH_MODEL_ROOT.glob("model-*.safetensors")): + with safe_open(path, framework="pt", device="cpu") as tensors: + tensor_keys = tuple(sorted(tensors.keys())) + for key in tensor_keys: + if not key.endswith(".weight"): + continue + tensor_slice = tensors.get_slice(key) + shape = tensor_slice.get_shape() + if len(shape) != 2: + continue + dtype_name = str(tensor_slice.get_dtype()) + if not dtype_name.startswith("F8_"): + continue + scale_key = key.replace(".weight", ".weight_scale_inv") + if scale_key not in tensor_keys: + continue + area = int(shape[0]) * int(shape[1]) + if largest is None or area > largest[0]: + scale_shape = tuple(int(dim) for dim in tensors.get_slice(scale_key).get_shape()) + largest = (area, path, key, scale_key, int(shape[0]), int(shape[1]), scale_shape) + + if largest is None: + return None + + _, path, key, scale_key, rows, cols, scale_shape = largest + return path, key, scale_key, rows, cols, scale_shape + + +@lru_cache(maxsize=1) +def _realistic_fp8_benchmark_source() -> tuple[torch.Tensor, torch.Tensor, str, int, int]: + spec = _realistic_fp8_benchmark_spec() + if spec is None: + return _synthetic_fp8_benchmark_source() + + path, weight_key, scale_key, rows, cols, scale_shape = spec + with safe_open(path, framework="pt", device="cpu") as tensors: + packed = tensors.get_tensor(weight_key) + scale_inv = tensors.get_tensor(scale_key) + + # Reuse the checkpoint's real FP8 distribution as the float source for both FP8 and FP4 decode tables. + src = _dequantize_f8_reference( + packed, + scale_inv=scale_inv, + axis=None, + target_dtype=torch.bfloat16, + ).to(torch.float32) + source = ( + f"source: {path.name}:{weight_key} rows={rows} cols={cols} " + f"scale={list(scale_shape)} model_root={FLOATX_BENCH_MODEL_ROOT}" + ) + return src, scale_inv, source, rows, cols + + +@lru_cache(maxsize=1) +def _realistic_fp8_native_benchmark_source() -> tuple[torch.Tensor, torch.Tensor, str, int, int]: + spec = _realistic_fp8_benchmark_spec() + if spec is None: + src, scale_inv, source, rows, cols = _synthetic_fp8_benchmark_source() + packed = src.to(torch.float8_e4m3fn) + return packed, scale_inv, f"{source} native_format=float8_e4m3fn", rows, cols + + path, weight_key, scale_key, rows, cols, scale_shape = spec + with safe_open(path, framework="pt", device="cpu") as tensors: + packed = tensors.get_tensor(weight_key) + scale_inv = tensors.get_tensor(scale_key) + + source = ( + f"source: {path.name}:{weight_key} rows={rows} cols={cols} " + f"scale={list(scale_shape)} native_format={str(packed.dtype).split('.')[-1]} " + f"model_root={FLOATX_BENCH_MODEL_ROOT}" + ) + return packed, scale_inv, source, rows, cols + @pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") def test_dequantize_f8_e4m3_basic_conversion(): @@ -97,6 +284,34 @@ def test_dequantize_f8_e4m3_with_fractional_scale_inv(): assert torch.equal(got, expected) +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_dequantize_f8_cpu_prefers_reference_for_standard_fp8(monkeypatch: pytest.MonkeyPatch): + src = torch.linspace(-1, 1, steps=16, dtype=torch.float32).reshape(4, 4) + fp8 = src.to(torch.float8_e4m3fn) + scale_inv = torch.ones_like(src, dtype=torch.float32) + expected = _dequantize_f8_reference( + fp8, + scale_inv=scale_inv, + axis=None, + target_dtype=torch.bfloat16, + ) + + def fail_load(): + raise AssertionError("native FP8 kernel should be bypassed for standard torch FP8 dtypes") + + monkeypatch.delenv("GPTQMODEL_FLOATX_CPU_FORCE_NATIVE_FP8", raising=False) + monkeypatch.setattr("gptqmodel.quantization.dtype._load_floatx_cpu_ops", fail_load) + + got = dequantize_f8_e4m3( + fp8, + scale_inv=scale_inv, + axis=None, + target_dtype=torch.bfloat16, + ) + + assert torch.equal(got, expected) + + @pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") def test_dequantize_f8_e4m3_raises_on_both_scale_and_inverse(): tensor = torch.zeros(2, dtype=torch.float8_e4m3fn) @@ -104,13 +319,69 @@ def test_dequantize_f8_e4m3_raises_on_both_scale_and_inverse(): dequantize_f8_e4m3(tensor, scale=torch.ones(2), scale_inv=torch.ones(2)) +def test_device_dtype_support_reports_arch_mapping(monkeypatch): + _DTYPE_SUPPORT_CACHE.clear() + + +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +@pytest.mark.parametrize("target_dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +def test_dequantize_f8_e4m3_disable_avx2_override_matches_reference(monkeypatch, target_dtype: torch.dtype): + src = torch.randn(32, 64, dtype=torch.float32) + fp8 = src.to(torch.float8_e4m3fn) + scale_inv = torch.rand(2, 4, dtype=torch.float32) * 0.5 + + monkeypatch.setenv("GPTQMODEL_FLOATX_CPU_DISABLE_AVX2", "1") + got = dequantize_f8_e4m3(fp8, scale_inv=scale_inv, axis=None, target_dtype=target_dtype) + expected = _dequantize_f8_reference( + fp8, + scale_inv=scale_inv, + axis=None, + target_dtype=target_dtype, + ) + + assert torch.equal(got, expected) + + monkeypatch.setattr("torch.cuda.is_available", lambda: True) + monkeypatch.setattr("torch.cuda.get_device_capability", lambda device=None: (8, 9)) + + support = get_device_dtype_support(torch.device("cuda", 0)) + + assert support.capability == (8, 9) + assert torch.float16 in support.advertised_linear_dtypes + assert torch.float32 in support.advertised_linear_dtypes + assert torch.bfloat16 in support.advertised_linear_dtypes + assert torch.float8_e4m3fn in support.advertised_linear_dtypes + + def test_device_supports_native_fp8_reports_capability(monkeypatch): + _DTYPE_SUPPORT_CACHE.clear() monkeypatch.setattr("torch.cuda.is_available", lambda: True) - monkeypatch.setattr("torch.cuda.get_device_capability", lambda device=None: (9, 0)) + monkeypatch.setattr("torch.cuda.get_device_capability", lambda device=None: (8, 9)) assert device_supports_native_fp8(torch.device("cuda", 0)) is True + assert device_supports_dtype(torch.device("cuda", 0), torch.float8_e4m3fn) is True + _DTYPE_SUPPORT_CACHE.clear() monkeypatch.setattr("torch.cuda.get_device_capability", lambda device=None: (8, 0)) assert device_supports_native_fp8(torch.device("cuda", 0)) is False + assert device_supports_dtype(torch.device("cuda", 0), torch.float8_e4m3fn) is False + + +@pytest.mark.skipif(not hasattr(torch, "float4_e2m1fn_x2"), reason="float4 packed dtype not available") +def test_device_supports_native_fp4_reports_capability(monkeypatch): + _DTYPE_SUPPORT_CACHE.clear() + monkeypatch.setattr("torch.cuda.is_available", lambda: True) + monkeypatch.setattr("torch.cuda.get_device_capability", lambda device=None: (10, 0)) + support = get_device_dtype_support(torch.device("cuda", 0)) + assert torch.float4_e2m1fn_x2 in support.advertised_linear_dtypes + assert device_supports_native_fp4(torch.device("cuda", 0)) is True + assert device_supports_dtype(torch.device("cuda", 0), torch.float4_e2m1fn_x2) is True + + _DTYPE_SUPPORT_CACHE.clear() + monkeypatch.setattr("torch.cuda.get_device_capability", lambda device=None: (8, 9)) + support = get_device_dtype_support(torch.device("cuda", 0)) + assert torch.float4_e2m1fn_x2 not in support.advertised_linear_dtypes + assert device_supports_native_fp4(torch.device("cuda", 0)) is False + assert device_supports_dtype(torch.device("cuda", 0), torch.float4_e2m1fn_x2) is False @pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") @@ -121,7 +392,7 @@ def test_dequantize_f4_e2m1_matches_nvfp4tensor(): dequant = dequantize_f4_e2m1(packed, scale=scales, axis=None, target_dtype=torch.bfloat16) nv_tensor = NVFP4Tensor(packed, scales, block_size=16, orig_dtype=torch.bfloat16) - expected = nv_tensor.to_dtype(torch.bfloat16) + expected = nv_tensor.dequantize(torch.bfloat16) diff = torch.max(torch.abs(dequant - expected)).item() _print_accuracy( @@ -135,6 +406,26 @@ def test_dequantize_f4_e2m1_matches_nvfp4tensor(): assert torch.allclose(dequant, expected, atol=1e-3, rtol=1e-3) +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +@pytest.mark.parametrize("target_dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +def test_dequantize_f4_e2m1_disable_avx2_override_matches_reference(monkeypatch, target_dtype: torch.dtype): + torch.manual_seed(4) + data = torch.randn(32, 64, dtype=torch.float32) + scales, packed = nvfp4_quantize(data, block_size=16) + packed_float4 = packed.view(torch.float4_e2m1fn_x2) if hasattr(torch, "float4_e2m1fn_x2") else packed + + monkeypatch.setenv("GPTQMODEL_FLOATX_CPU_DISABLE_AVX2", "1") + got = dequantize_f4_e2m1(packed_float4, scale=scales, axis=None, target_dtype=target_dtype) + expected = _dequantize_f4_reference( + packed, + scale=scales, + axis=None, + target_dtype=target_dtype, + ) + + assert torch.allclose(got, expected, atol=1e-3, rtol=1e-3) + + @pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device required") def test_dequantize_f4_e2m1_cpu_vs_gpu(): @@ -207,3 +498,218 @@ def test_dequantize_f8_e4m3_cpu_vs_gpu_benchmark(): # GPU should not be dramatically slower than CPU assert gpu_time <= cpu_time * 2, f"GPU dequant slower than expected (cpu={cpu_time:.4f}s, gpu={gpu_time:.4f}s)" + + +@pytest.mark.parametrize("target_dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +def test_dequantize_fp8_cpu_ab_benchmark_table(target_dtype: torch.dtype): + src, scale_inv, source_note, rows, cols = _realistic_fp8_benchmark_source() + warmup, iters = _benchmark_profile(rows * cols) + bench_rows = [] + + for fmt_name in _available_fp8_formats(): + fmt = getattr(torch, fmt_name) + fmt_src = src.abs() if "e8m0" in fmt_name else src + packed = fmt_src.to(fmt) + + ref_fn = lambda: _dequantize_f8_reference( # noqa: E731 + packed, + scale_inv=scale_inv, + axis=None, + target_dtype=target_dtype, + ) + fast_fn = lambda: dequantize_f8_e4m3( # noqa: E731 + packed, + scale_inv=scale_inv, + axis=None, + target_dtype=target_dtype, + ) + + ref, ref_stats = _benchmark_cpu_impl(ref_fn, warmup=warmup, iters=iters) + fast, fast_stats = _benchmark_cpu_impl(fast_fn, warmup=warmup, iters=iters) + diff = float(torch.max(torch.abs(ref.to(torch.float32) - fast.to(torch.float32))).item()) + + input_mib = _tensor_mib(packed) + _tensor_mib(scale_inv) + ref_throughput = (input_mib + ref_stats["output_mib"]) / max(ref_stats["median_ms"] / 1e3, 1e-9) + fast_throughput = (input_mib + fast_stats["output_mib"]) / max(fast_stats["median_ms"] / 1e3, 1e-9) + speedup = ref_stats["median_ms"] / max(fast_stats["median_ms"], 1e-9) + throughput_gain_pct = ((fast_throughput - ref_throughput) / max(ref_throughput, 1e-9)) * 100.0 + + bench_rows.append([ + fmt_name, + str(target_dtype).split(".")[-1], + ref_stats["median_ms"], + fast_stats["median_ms"], + speedup, + ref_throughput, + fast_throughput, + throughput_gain_pct, + diff, + ]) + + assert torch.equal(fast, ref) + + _print_benchmark( + f"fp8_cpu_ab_{str(target_dtype).split('.')[-1]}", + bench_rows, + [ + "format", + "target", + "ref ms", + "native ms", + "speedup x", + "ref MiB/s", + "native MiB/s", + "throughput delta %", + "max|diff|", + ], + note=source_note, + ) + + +@pytest.mark.parametrize("target_dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +def test_dequantize_fp8_cpu_real_format_ab_benchmark_table(target_dtype: torch.dtype): + packed, scale_inv, source_note, rows, cols = _realistic_fp8_native_benchmark_source() + warmup, iters = _benchmark_profile(rows * cols) + enable_large_threads = ( + target_dtype is torch.bfloat16 and + hasattr(torch, "float8_e4m3fn") and + packed.dtype is torch.float8_e4m3fn + ) + + ref_fn = lambda: _dequantize_f8_reference( # noqa: E731 + packed, + scale_inv=scale_inv, + axis=None, + target_dtype=target_dtype, + ) + fast_fn = lambda: dequantize_f8_e4m3( # noqa: E731 + packed, + scale_inv=scale_inv, + axis=None, + target_dtype=target_dtype, + ) + + ref, ref_stats = _benchmark_cpu_impl(ref_fn, warmup=warmup, iters=iters) + fast, fast_stats = _benchmark_cpu_impl(fast_fn, warmup=warmup, iters=iters) + diff = float(torch.max(torch.abs(ref.to(torch.float32) - fast.to(torch.float32))).item()) + + input_mib = _tensor_mib(packed) + _tensor_mib(scale_inv) + ref_throughput = (input_mib + ref_stats["output_mib"]) / max(ref_stats["median_ms"] / 1e3, 1e-9) + fast_throughput = (input_mib + fast_stats["output_mib"]) / max(fast_stats["median_ms"] / 1e3, 1e-9) + speedup = ref_stats["median_ms"] / max(fast_stats["median_ms"], 1e-9) + throughput_gain_pct = ((fast_throughput - ref_throughput) / max(ref_throughput, 1e-9)) * 100.0 + + assert torch.equal(fast, ref) + + _print_benchmark( + f"fp8_cpu_real_format_ab_{str(target_dtype).split('.')[-1]}", + [[ + str(packed.dtype).split(".")[-1], + str(target_dtype).split(".")[-1], + ref_stats["median_ms"], + fast_stats["median_ms"], + speedup, + ref_throughput, + fast_throughput, + throughput_gain_pct, + diff, + ]], + [ + "format", + "target", + "ref ms", + "native ms", + "speedup x", + "ref MiB/s", + "native MiB/s", + "throughput delta %", + "max|diff|", + ], + note=( + f"{source_note} " + f"threads={_cpu_floatx_threads(rows * cols, enable_large_threads=enable_large_threads)}" + ), + ) + + +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +@pytest.mark.parametrize("target_dtype", [torch.bfloat16, torch.float16], ids=["bf16", "fp16"]) +def test_dequantize_fp4_cpu_ab_benchmark_table(target_dtype: torch.dtype): + data, _, source_note, rows, cols = _realistic_fp8_benchmark_source() + warmup, iters = _benchmark_profile(rows * cols) + scales, packed = nvfp4_quantize(data, block_size=16) + packed_float4 = packed.view(torch.float4_e2m1fn_x2) if hasattr(torch, "float4_e2m1fn_x2") else None + + ref_fn = lambda: _dequantize_f4_reference( # noqa: E731 + packed, + scale=scales, + axis=None, + target_dtype=target_dtype, + ) + fast_uint8_fn = lambda: dequantize_f4_e2m1( # noqa: E731 + packed, + scale=scales, + axis=None, + target_dtype=target_dtype, + ) + + ref, ref_stats = _benchmark_cpu_impl(ref_fn, warmup=warmup, iters=iters) + fast_uint8, fast_uint8_stats = _benchmark_cpu_impl(fast_uint8_fn, warmup=warmup, iters=iters) + bench_rows = [] + input_mib = _tensor_mib(packed) + _tensor_mib(scales) + ref_throughput = (input_mib + ref_stats["output_mib"]) / max(ref_stats["median_ms"] / 1e3, 1e-9) + fast_uint8_throughput = (input_mib + fast_uint8_stats["output_mib"]) / max(fast_uint8_stats["median_ms"] / 1e3, 1e-9) + fast_uint8_diff = float(torch.max(torch.abs(ref.to(torch.float32) - fast_uint8.to(torch.float32))).item()) + bench_rows.append([ + str(target_dtype).split(".")[-1], + "native:uint8", + ref_stats["median_ms"], + fast_uint8_stats["median_ms"], + ref_stats["median_ms"] / max(fast_uint8_stats["median_ms"], 1e-9), + ref_throughput, + fast_uint8_throughput, + ((fast_uint8_throughput - ref_throughput) / max(ref_throughput, 1e-9)) * 100.0, + fast_uint8_diff, + ]) + + assert torch.allclose(fast_uint8, ref, atol=1e-3, rtol=1e-3) + + if packed_float4 is not None: + fast_float4_fn = lambda: dequantize_f4_e2m1( # noqa: E731 + packed_float4, + scale=scales, + axis=None, + target_dtype=target_dtype, + ) + fast_float4, fast_float4_stats = _benchmark_cpu_impl(fast_float4_fn, warmup=warmup, iters=iters) + throughput = (input_mib + fast_float4_stats["output_mib"]) / max(fast_float4_stats["median_ms"] / 1e3, 1e-9) + diff = float(torch.max(torch.abs(ref.to(torch.float32) - fast_float4.to(torch.float32))).item()) + bench_rows.append([ + str(target_dtype).split(".")[-1], + "native:float4_x2", + ref_stats["median_ms"], + fast_float4_stats["median_ms"], + ref_stats["median_ms"] / max(fast_float4_stats["median_ms"], 1e-9), + ref_throughput, + throughput, + ((throughput - ref_throughput) / max(ref_throughput, 1e-9)) * 100.0, + diff, + ]) + assert torch.allclose(fast_float4, ref, atol=1e-3, rtol=1e-3) + + _print_benchmark( + f"fp4_cpu_ab_{str(target_dtype).split('.')[-1]}", + bench_rows, + [ + "target", + "candidate", + "ref ms", + "native ms", + "speedup x", + "ref MiB/s", + "native MiB/s", + "throughput delta %", + "max|diff|", + ], + note=f"{source_note} fp4_block_size=16", + ) diff --git a/tests/test_quant_trust_remote.py b/tests/test_quant_trust_remote.py index 899eeba79..a6f18ab61 100644 --- a/tests/test_quant_trust_remote.py +++ b/tests/test_quant_trust_remote.py @@ -12,6 +12,7 @@ import tempfile # noqa: E402 +import pytest # noqa: E402 from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 @@ -19,6 +20,9 @@ from gptqmodel.quantization import FORMAT, QuantizeConfig # noqa: E402 +pytestmark = [pytest.mark.model, pytest.mark.slow] + + class TestQuantWithTrustRemoteTrue(ModelTest): @classmethod def setUpClass(self): @@ -57,4 +61,3 @@ def test_diff_batch(self): self.assertIn(file, py_files, f"File {file} is missing in the actual files list") - diff --git a/tests/test_qwen2_family_compat.py b/tests/test_qwen2_family_compat.py new file mode 100644 index 000000000..a4a6cc88a --- /dev/null +++ b/tests/test_qwen2_family_compat.py @@ -0,0 +1,277 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import types +import warnings + +import torch +from PIL import Image +from tokenicer import Tokenicer +from tokenizers import Tokenizer +from tokenizers.models import WordLevel +from tokenizers.pre_tokenizers import Whitespace +from torch import nn +from transformers import PreTrainedTokenizerFast + +from gptqmodel.models.definitions import base_qwen2_5_omni, base_qwen2_vl +from gptqmodel.utils.hf import load_tokenizer + + +def test_qwen2_vl_image_only_process_vision_info_returns_image_list(): + image = Image.new("RGB", (2, 2), color="white") + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + image_inputs = base_qwen2_vl.BaseQwen2VLGPTQ.process_vision_info(messages) + + assert isinstance(image_inputs, list) + assert image_inputs == [image] + + +def test_qwen2_vl_pre_quantize_hooks_use_inner_model_layout(): + instance = object.__new__(base_qwen2_vl.BaseQwen2VLGPTQ) + instance.model = types.SimpleNamespace( + language_model=types.SimpleNamespace( + embed_tokens=nn.Embedding(4, 4), + rotary_emb=nn.Identity(), + ), + visual=nn.Identity(), + ) + instance.quantize_config = types.SimpleNamespace( + device="cpu", + offload_to_disk=False, + offload_to_disk_path="/tmp/unused", + ) + + instance.pre_quantize_generate_hook_start() + instance.pre_quantize_generate_hook_end() + + assert instance.model.language_model.embed_tokens.weight.device.type == "cpu" + + +def test_qwen2_vl_layout_resolution_supports_nested_wrapper(): + class _InnerModel(nn.Module): + def __init__(self): + super().__init__() + self.language_model = nn.Module() + self.language_model.layers = nn.ModuleList([nn.Identity()]) + self.visual = nn.Identity() + self.merger = nn.Identity() + + class _OuterModel(nn.Module): + def __init__(self): + super().__init__() + self.model = _InnerModel() + + model = _OuterModel() + + assert base_qwen2_vl.BaseQwen2VLGPTQ.extract_layers_node() == [ + "model.language_model.layers", + "language_model.layers", + ] + assert base_qwen2_vl.BaseQwen2VLGPTQ.get_base_modules(model) == ["model.visual", "model.merger"] + + +def test_qwen2_vl_pre_quantize_hooks_materialize_meta_modules(): + instance = object.__new__(base_qwen2_vl.BaseQwen2VLGPTQ) + instance.model = types.SimpleNamespace( + language_model=types.SimpleNamespace( + embed_tokens=nn.Embedding(4, 4, device="meta"), + rotary_emb=nn.Linear(4, 4, device="meta"), + ), + visual=nn.Linear(4, 4, device="meta"), + ) + instance.quantize_config = types.SimpleNamespace( + device="cpu", + offload_to_disk=False, + offload_to_disk_path="/tmp/unused", + ) + + materialized = {} + + def fake_materialize(module, device): + replacement = nn.Linear(4, 4) if isinstance(module, nn.Linear) else nn.Embedding(4, 4) + materialized[id(module)] = (replacement, device) + return replacement + + instance.shell_module_materialize = fake_materialize + + instance.pre_quantize_generate_hook_start() + + assert instance.model.visual.weight.device == torch.device("cpu") + assert instance.model.language_model.embed_tokens.weight.device == torch.device("cpu") + assert instance.model.language_model.rotary_emb.weight.device == torch.device("cpu") + assert len(materialized) == 3 + + +def test_qwen2_5_omni_image_only_process_vision_info_returns_image_list(): + image = Image.new("RGB", (2, 2), color="white") + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + image_inputs = base_qwen2_5_omni.BaseQwen2_5_OmniGPTQ.process_vision_info(messages) + + assert isinstance(image_inputs, list) + assert image_inputs == [image] + + +def test_qwen2_5_omni_forward_delegates_to_thinker(): + sentinel = object() + instance = object.__new__(base_qwen2_5_omni.BaseQwen2_5_OmniGPTQ) + instance.model = types.SimpleNamespace(thinker=lambda *args, **kwargs: (args, kwargs, sentinel)) + + result = instance.forward("hello", temperature=0.1) + + assert result == (("hello",), {"temperature": 0.1}, sentinel) + + +def test_qwen2_5_omni_talker_patch_accepts_next_sequence_length_kwarg(): + class _BaseTalker: + def prepare_inputs_for_generation( + self, + input_ids, + next_sequence_length=None, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + is_first_iteration=False, + **kwargs, + ): + return { + "input_ids": input_ids, + "received_next_sequence_length": next_sequence_length, + "past_key_values": past_key_values, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "cache_position": cache_position, + "is_first_iteration": is_first_iteration, + **kwargs, + } + + class _Talker(_BaseTalker): + pass + + base_qwen2_5_omni._patch_qwen2_5_omni_talker_prepare_inputs_for_generation(_Talker) + + model_inputs = _Talker().prepare_inputs_for_generation( + input_ids=torch.tensor([[1, 2]]), + input_text_ids=torch.tensor([[3, 4]]), + past_key_values="pkv", + attention_mask=torch.tensor([[1, 1]]), + inputs_embeds=torch.randn(1, 2, 4), + thinker_reply_part=torch.randn(1, 2, 4), + cache_position=torch.tensor([0, 1]), + use_cache=True, + next_sequence_length=1, + is_first_iteration=True, + ) + + assert model_inputs["received_next_sequence_length"] == 1 + assert torch.equal(model_inputs["input_ids"], torch.tensor([[1, 2]])) + assert torch.equal(model_inputs["input_text_ids"], torch.tensor([[3, 4]])) + assert model_inputs["position_ids"] is None + assert model_inputs["is_first_iteration"] is True + + +def test_qwen2_5_omni_pre_quantize_hooks_use_thinker_layout(): + loaded_speakers = [] + materialized = [] + + class _Visual(nn.Module): + def __init__(self): + super().__init__() + self.rotary_pos_emb = nn.Identity() + + class _Thinker(nn.Module): + def __init__(self): + super().__init__() + self.model = types.SimpleNamespace( + embed_tokens=nn.Embedding(4, 4), + rotary_emb=nn.Identity(), + layers=[ + types.SimpleNamespace( + self_attn=types.SimpleNamespace(rotary_emb=nn.Identity()), + ) + ], + ) + self.visual = _Visual() + self.audio_tower = nn.Identity() + + instance = object.__new__(base_qwen2_5_omni.BaseQwen2_5_OmniGPTQ) + instance.model_local_path = "/tmp/qwen2_5_omni" + instance.model = types.SimpleNamespace( + load_speakers=lambda path: loaded_speakers.append(path), + thinker=_Thinker(), + ) + instance.quantize_config = types.SimpleNamespace( + device="cpu", + offload_to_disk=False, + offload_to_disk_path="/tmp/unused", + ) + instance.shell_module_materialize = lambda module, device: materialized.append((module, device)) or module + + instance.pre_quantize_generate_hook_start() + instance.pre_quantize_generate_hook_end() + + assert loaded_speakers == ["/tmp/qwen2_5_omni/spk_dict.pt"] + assert len(materialized) == 6 + assert instance.model.thinker.model.embed_tokens.weight.device.type == "cpu" + + +def test_tokenicer_load_uses_text_config_for_qwen2_5_omni_style_composite_configs(): + backend = Tokenizer(WordLevel({"": 0, "": 1, "hello": 2}, unk_token="")) + backend.pre_tokenizer = Whitespace() + tokenizer = PreTrainedTokenizerFast(tokenizer_object=backend, pad_token="", eos_token="") + + text_config = types.SimpleNamespace( + model_type="qwen2_5_omni_text", + pad_token_id=None, + bos_token_id=None, + eos_token_id=None, + ) + + class _CompositeConfig: + def get_text_config(self): + return text_config + + wrapped = Tokenicer.load(tokenizer, model_config=_CompositeConfig()) + + assert wrapped.model_config is text_config + assert wrapped.eos_token_id == tokenizer.eos_token_id + assert text_config.pad_token_id == tokenizer.pad_token_id + assert text_config.eos_token_id == tokenizer.eos_token_id + + +def test_load_tokenizer_deprecated_shim_forwards_to_tokenicer(): + backend = Tokenizer(WordLevel({"": 0, "": 1}, unk_token="")) + backend.pre_tokenizer = Whitespace() + tokenizer = PreTrainedTokenizerFast(tokenizer_object=backend, pad_token="", eos_token="") + + text_config = types.SimpleNamespace(model_type="qwen2_5_omni_text", pad_token_id=None, eos_token_id=None) + + class _CompositeConfig: + def get_text_config(self): + return text_config + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + wrapped = load_tokenizer(tokenizer, model_config=_CompositeConfig()) + + assert any(item.category is DeprecationWarning for item in caught) + assert wrapped.model_config is text_config diff --git a/tests/test_qwen3_5_batching.py b/tests/test_qwen3_5_batching.py new file mode 100644 index 000000000..c83f65d66 --- /dev/null +++ b/tests/test_qwen3_5_batching.py @@ -0,0 +1,14 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import pytest + + +pytest.importorskip("transformers.models.qwen3_5") + +from gptqmodel.models.definitions.qwen3_5 import Qwen3_5QModel + + +def test_qwen3_5_disables_batch_quantization(): + assert Qwen3_5QModel.support_batch_quantize is False diff --git a/tests/test_qwen3_vl_dependency.py b/tests/test_qwen3_vl_dependency.py index ac3e86ded..8a530fd20 100644 --- a/tests/test_qwen3_vl_dependency.py +++ b/tests/test_qwen3_vl_dependency.py @@ -4,8 +4,12 @@ import builtins import sys +import types import pytest +import torch +from PIL import Image +from torch import nn from gptqmodel.models.definitions import base_qwen3_vl @@ -35,3 +39,121 @@ def fail_qwen_vl_import(name, *args, **kwargs): with pytest.raises(ImportError, match="pip install qwen-vl-utils"): base_qwen3_vl.BaseQwen3VLGPTQ.process_vision_info(messages) + + +def test_qwen3_vl_image_only_process_vision_info_returns_image_list(): + image = Image.new("RGB", (2, 2), color="white") + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + } + ] + + image_inputs = base_qwen3_vl.BaseQwen3VLGPTQ.process_vision_info(messages) + + assert isinstance(image_inputs, list) + assert image_inputs == [image] + + +def test_qwen3_vl_pre_quantize_hooks_use_inner_model_layout(): + instance = object.__new__(base_qwen3_vl.BaseQwen3VLGPTQ) + inner_model = types.SimpleNamespace( + language_model=types.SimpleNamespace( + embed_tokens=nn.Embedding(4, 4), + rotary_emb=nn.Identity(), + ), + visual=nn.Identity(), + ) + instance.model = types.SimpleNamespace(model=inner_model) + instance.quantize_config = types.SimpleNamespace( + device="cpu", + offload_to_disk=False, + offload_to_disk_path="/tmp/unused", + ) + + instance.pre_quantize_generate_hook_start() + instance.pre_quantize_generate_hook_end() + + assert instance.model.model.language_model.embed_tokens.weight.device.type == "cpu" + + +def test_qwen3_vl_pre_quantize_hooks_support_direct_layout(): + instance = object.__new__(base_qwen3_vl.BaseQwen3VLGPTQ) + instance.model = types.SimpleNamespace( + language_model=types.SimpleNamespace( + embed_tokens=nn.Embedding(4, 4), + rotary_emb=nn.Identity(), + ), + visual=nn.Identity(), + ) + instance.quantize_config = types.SimpleNamespace( + device="cpu", + offload_to_disk=False, + offload_to_disk_path="/tmp/unused", + ) + + instance.pre_quantize_generate_hook_start() + instance.pre_quantize_generate_hook_end() + + assert instance.model.language_model.embed_tokens.weight.device.type == "cpu" + + +def test_qwen3_vl_layout_resolution_supports_nested_wrapper(): + class _InnerModel(nn.Module): + def __init__(self): + super().__init__() + self.language_model = nn.Module() + self.language_model.layers = nn.ModuleList([nn.Identity()]) + self.visual = nn.Identity() + self.vision_router = nn.Identity() + + class _OuterModel(nn.Module): + def __init__(self): + super().__init__() + self.model = _InnerModel() + + model = _OuterModel() + + assert base_qwen3_vl.BaseQwen3VLGPTQ.extract_layers_node() == [ + "model.language_model.layers", + "language_model.layers", + ] + assert base_qwen3_vl.BaseQwen3VLGPTQ.get_base_modules(model) == ["model.visual", "model.vision_router"] + + +def test_qwen3_vl_pre_quantize_hooks_materialize_meta_modules_with_nested_layout(): + instance = object.__new__(base_qwen3_vl.BaseQwen3VLGPTQ) + instance.model = types.SimpleNamespace( + model=types.SimpleNamespace( + language_model=types.SimpleNamespace( + embed_tokens=nn.Embedding(4, 4, device="meta"), + rotary_emb=nn.Linear(4, 4, device="meta"), + ), + visual=nn.Linear(4, 4, device="meta"), + ) + ) + instance.quantize_config = types.SimpleNamespace( + device="cpu", + offload_to_disk=False, + offload_to_disk_path="/tmp/unused", + ) + + materialized = {} + + def fake_materialize(module, device): + replacement = nn.Linear(4, 4) if isinstance(module, nn.Linear) else nn.Embedding(4, 4) + materialized[id(module)] = (replacement, device) + return replacement + + instance.shell_module_materialize = fake_materialize + + instance.pre_quantize_generate_hook_start() + + assert instance.model.model.visual.weight.device == torch.device("cpu") + assert instance.model.model.language_model.embed_tokens.weight.device == torch.device("cpu") + assert instance.model.model.language_model.rotary_emb.weight.device == torch.device("cpu") + assert len(materialized) == 3 diff --git a/tests/test_qwen_moe_converter.py b/tests/test_qwen_moe_converter.py new file mode 100644 index 000000000..db7fad225 --- /dev/null +++ b/tests/test_qwen_moe_converter.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch +from defuser import convert_model +from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeConfig, Qwen2MoeForCausalLM +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM +from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextConfig, Qwen3NextForCausalLM +from transformers.models.qwen3_omni_moe.configuration_qwen3_omni_moe import Qwen3OmniMoeConfig +from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration + +from gptqmodel.nn_modules.converter import MODULE_CONVERTER_MAP +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear + + +def _make_tiny_moe_config(config_cls): + return config_cls( + num_hidden_layers=1, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + num_attention_heads=4, + num_key_value_heads=4, + num_experts=4, + num_experts_per_tok=2, + vocab_size=128, + pad_token_id=0, + ) + + +def _make_tiny_qwen3_omni_config(): + return Qwen3OmniMoeConfig( + enable_audio_output=False, + thinker_config={ + "text_config": { + "num_hidden_layers": 1, + "hidden_size": 64, + "intermediate_size": 128, + "moe_intermediate_size": 32, + "num_attention_heads": 4, + "num_key_value_heads": 4, + "num_experts": 4, + "num_experts_per_tok": 2, + "vocab_size": 128, + "pad_token_id": 0, + "bos_token_id": 1, + "eos_token_id": 2, + }, + "vision_config": { + "depth": 1, + "hidden_size": 64, + "intermediate_size": 128, + "num_heads": 4, + "out_hidden_size": 64, + "num_position_embeddings": 64, + "deepstack_visual_indexes": [0], + }, + "audio_config": { + "num_mel_bins": 16, + "encoder_layers": 1, + "encoder_attention_heads": 4, + "encoder_ffn_dim": 128, + "d_model": 64, + "output_dim": 64, + "max_source_positions": 32, + "n_window": 4, + "n_window_infer": 4, + "conv_chunksize": 16, + "downsample_hidden_size": 32, + }, + }, + ) +def _assert_converted_experts(layer, hidden_size: int, *, dtype: torch.dtype = torch.float32): + assert isinstance(layer.mlp.experts, torch.nn.Module) + assert not hasattr(layer.mlp.experts, "gate_up_proj") + assert len([name for name, _ in layer.mlp.experts.named_children() if name.isdigit()]) == 4 + + expert0 = layer.mlp.experts[0] + assert hasattr(expert0, "gate_proj") + assert hasattr(expert0, "up_proj") + assert hasattr(expert0, "down_proj") + + output = layer.mlp(torch.randn(2, 3, hidden_size, dtype=dtype)) + assert output.shape == (2, 3, hidden_size) + + +def test_qwen2_moe_uses_defuser_for_fused_experts(): + assert "qwen2_moe" not in MODULE_CONVERTER_MAP + + model = Qwen2MoeForCausalLM(_make_tiny_moe_config(Qwen2MoeConfig)) + convert_model(model, cleanup_original=False) + layer = model.model.layers[0] + + assert hasattr(layer.mlp, "shared_expert") + assert hasattr(layer.mlp, "shared_expert_gate") + _assert_converted_experts(layer, hidden_size=model.config.hidden_size) + + +def test_qwen3_moe_uses_defuser_for_fused_experts(): + assert "qwen3_moe" not in MODULE_CONVERTER_MAP + + model = Qwen3MoeForCausalLM(_make_tiny_moe_config(Qwen3MoeConfig)) + convert_model(model, cleanup_original=False) + layer = model.model.layers[0] + + assert not hasattr(layer.mlp, "shared_expert") + _assert_converted_experts(layer, hidden_size=model.config.hidden_size) + + +def test_qwen3_next_uses_defuser_for_fused_experts(): + assert "qwen3_next" not in MODULE_CONVERTER_MAP + + model = Qwen3NextForCausalLM(_make_tiny_moe_config(Qwen3NextConfig)) + convert_model(model, cleanup_original=False) + layer = model.model.layers[0] + + assert hasattr(layer.mlp, "shared_expert") + assert hasattr(layer.mlp, "shared_expert_gate") + _assert_converted_experts(layer, hidden_size=model.config.hidden_size) + + +def test_qwen3_omni_uses_defuser_for_fused_experts(): + assert "qwen3_omni_moe" not in MODULE_CONVERTER_MAP + + model = Qwen3OmniMoeForConditionalGeneration(_make_tiny_qwen3_omni_config()) + convert_model(model, cleanup_original=False, max_layers=1) + layer = model.thinker.model.layers[0] + + _assert_converted_experts( + layer, + hidden_size=model.config.get_text_config().hidden_size, + dtype=next(layer.mlp.experts[0].gate_proj.parameters()).dtype, + ) + +def test_awq_single_bit_validation_allows_skip_only_dynamic_rules(): + ok, err = AwqTorchLinear.validate( + bits=4, + group_size=128, + desc_act=False, + sym=True, + in_features=128, + out_features=128, + pack_dtype=torch.int32, + dtype=torch.float16, + dynamic={"-:^model\\.layers\\.4\\.mlp$": {}}, + ) + + assert ok, err + assert err is None diff --git a/tests/test_random_string.py b/tests/test_random_string.py new file mode 100644 index 000000000..2c385b273 --- /dev/null +++ b/tests/test_random_string.py @@ -0,0 +1,48 @@ +import random +import string + +import pytest + +from gptqmodel.utils.random_str import get_random_string + + +def test_default_length(): + s = get_random_string() + assert len(s) == 8 + + +def test_custom_length(): + s = get_random_string(16) + assert len(s) == 16 + + +def test_characters_are_lowercase_letters(): + s = get_random_string(100) + assert set(s).issubset(set(string.ascii_lowercase)) + + +def test_multiple_calls_produce_different_values(): + results = [get_random_string() for _ in range(5)] + assert len(set(results)) == 5 + + +def test_not_affected_by_random_seed(): + random.seed(42) + r1 = get_random_string() + + random.seed(42) + r2 = get_random_string() + + # Not affected by the seed → Should not be exactly identical + assert r1 != r2, f"{r1} and {r2} Should not be exactly identical" + + +def test_length_zero(): + s = get_random_string(0) + assert s == "" + + +@pytest.mark.parametrize("length", [1, 2, 10, 50]) +def test_various_lengths(length): + s = get_random_string(length) + assert len(s) == length diff --git a/tests/test_save_loaded_quantized_model.py b/tests/test_save_loaded_quantized_model.py index ea47a018a..5518d25cc 100644 --- a/tests/test_save_loaded_quantized_model.py +++ b/tests/test_save_loaded_quantized_model.py @@ -12,35 +12,67 @@ import tempfile # noqa: E402 import unittest # noqa: E402 +import pytest # noqa: E402 +import torch # noqa: E402 +from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402 +from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 +from gptqmodel.utils.importer import get_kernel_for_backend # noqa: E402 + + +pytestmark = [pytest.mark.model, pytest.mark.slow] MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" class TestSave(unittest.TestCase): + def _require_backend(self, backend: BACKEND): + kernel_cls = get_kernel_for_backend(backend, METHOD.GPTQ, FORMAT.GPTQ) + ok, err = kernel_cls.cached_validate_once() + if not ok: + self.skipTest(f"{backend} unavailable: {err}") + + def _generate_or_skip(self, model, backend: BACKEND, tokenizer, prompt, **kwargs): + try: + return ModelTest.generate_stable_with_limit(model, tokenizer, prompt, **kwargs) + except Exception as exc: + if backend == BACKEND.BITBLAS: + message = str(exc).lower() + if "illegal memory access" in message or isinstance(exc, torch.AcceleratorError): + self.skipTest(f"{backend} runtime unstable in this environment: {exc}") + raise + @parameterized.expand( [ (BACKEND.AUTO), (BACKEND.EXLLAMA_V2), - (BACKEND.EXLLAMA_V1), (BACKEND.TRITON), (BACKEND.BITBLAS), (BACKEND.MARLIN), ] ) def test_save(self, backend: BACKEND): + if backend != BACKEND.AUTO: + self._require_backend(backend) + prompt = "I am in Paris and" device = get_best_device(backend) tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) - inp = tokenizer(prompt, return_tensors="pt").to(device) # origin model produce correct output origin_model = GPTQModel.load(MODEL_ID, backend=backend, device=device) - origin_model_res = origin_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60) - origin_model_predicted_text = tokenizer.decode(origin_model_res[0]) + origin_model_predicted_text = self._generate_or_skip( + origin_model, + backend, + tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) with tempfile.TemporaryDirectory() as tmpdir: origin_model.save(tmpdir) @@ -48,8 +80,15 @@ def test_save(self, backend: BACKEND): # saved model produce wrong output new_model = GPTQModel.load(tmpdir, backend=backend, device=device) - new_model_res = new_model.generate(**inp, num_beams=1, min_new_tokens=60, max_new_tokens=60) - new_model_predicted_text = tokenizer.decode(new_model_res[0]) + new_model_predicted_text = self._generate_or_skip( + new_model, + backend, + tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) print("origin_model_predicted_text",origin_model_predicted_text) print("new_model_predicted_text",new_model_predicted_text) diff --git a/tests/test_save_loaded_quantized_model_torch_fused.py b/tests/test_save_loaded_quantized_model_torch_fused.py new file mode 100644 index 000000000..2aa39170a --- /dev/null +++ b/tests/test_save_loaded_quantized_model_torch_fused.py @@ -0,0 +1,58 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +# -- do not touch +import os + + +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +# -- end do not touch +import tempfile # noqa: E402 +import unittest # noqa: E402 + +from models.model_test import ModelTest # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + +from gptqmodel import BACKEND, GPTQModel # noqa: E402 + + +MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" + +class TestSaveTorchFused(unittest.TestCase): + def test_save(self): + prompt = "I am in Paris and" + backend = BACKEND.TORCH_FUSED + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + + # origin model produce correct output + origin_model = GPTQModel.load(MODEL_ID, backend=backend) + origin_model_predicted_text = ModelTest.generate_stable_with_limit( + origin_model, + tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) + + with tempfile.TemporaryDirectory() as tmpdir: + origin_model.save(tmpdir) + + # saved model produce wrong output + new_model = GPTQModel.load(tmpdir, backend=backend) + + new_model_predicted_text = ModelTest.generate_stable_with_limit( + new_model, + tokenizer, + prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) + + print("origin_model_predicted_text", origin_model_predicted_text) + print("new_model_predicted_text", new_model_predicted_text) + + self.assertEqual(origin_model_predicted_text[:20], new_model_predicted_text[:20]) diff --git a/tests/test_serialization.py b/tests/test_serialization.py index c2eaeacbc..e0bcec20e 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -18,8 +18,20 @@ import torch # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.quantization import FORMAT, FORMAT_FIELD_CHECKPOINT, QuantizeConfig # noqa: E402 -from gptqmodel.quantization.config import GPTAQConfig, HessianConfig, VramStrategy # noqa: E402 +from gptqmodel.quantization import ( # noqa: E402 + FORMAT, + FORMAT_FIELD_CHECKPOINT, + FORMAT_FIELD_CODE, + METHOD_FIELD_CODE, + QuantizeConfig, +) +from gptqmodel.quantization.config import ( # noqa: E402 # noqa: E402 + METHOD, + GGUFConfig, + GPTAQConfig, + HessianConfig, + VramStrategy, +) class TestSerialization(unittest.TestCase): @@ -49,8 +61,24 @@ def test_gptq_v1_serialization(self): with open(os.path.join(tmpdir, "quantize_config.json"), "r") as f: quantize_config = json.load(f) + self.assertEqual(quantize_config[METHOD_FIELD_CODE], "gptq") + self.assertEqual(quantize_config["quant_method"], "gptq") + self.assertEqual(quantize_config[FORMAT_FIELD_CODE], "gptq") self.assertEqual(quantize_config[FORMAT_FIELD_CHECKPOINT], "gptq") + def test_legacy_checkpoint_format_load_normalizes_to_format(self): + cfg = QuantizeConfig.from_quant_config( + { + "bits": 4, + "checkpoint_format": "gguf", + } + ) + + self.assertIsInstance(cfg, GGUFConfig) + self.assertEqual(cfg.format, "q_0") + self.assertEqual(cfg.method, METHOD.GGUF) + self.assertEqual(cfg.quant_method, METHOD.GGUF) + def test_quantize_config_meta_only_fields_serialization(self): cfg = QuantizeConfig( gptaq=GPTAQConfig(alpha=0.75, device="cpu"), @@ -65,7 +93,10 @@ def test_quantize_config_meta_only_fields_serialization(self): chunk_bytes=4096, staging_dtype=torch.bfloat16, ), - vram_strategy=VramStrategy.BALANCED, + dense_vram_strategy=VramStrategy.BALANCED, + dense_vram_strategy_devices=["cuda:0", "cuda:1"], + moe_vram_strategy=VramStrategy.BALANCED, + moe_vram_strategy_devices=["cuda:2", "cuda:3"], ) payload = cfg.to_dict() @@ -73,7 +104,7 @@ def test_quantize_config_meta_only_fields_serialization(self): self.assertIsInstance(meta, dict) meta_only_fields = [ - "failsafe", + "fallback", "gptaq", "offload_to_disk", "offload_to_disk_path", @@ -83,7 +114,10 @@ def test_quantize_config_meta_only_fields_serialization(self): "mock_quantization", "act_group_aware", "hessian", - "vram_strategy", + "dense_vram_strategy", + "dense_vram_strategy_devices", + "moe_vram_strategy", + "moe_vram_strategy_devices", ] for field in meta_only_fields: self.assertNotIn(field, payload) @@ -101,7 +135,10 @@ def test_quantize_config_meta_only_fields_serialization(self): self.assertEqual(meta["hessian"]["chunk_size"], cfg.hessian.chunk_size) self.assertEqual(meta["hessian"]["chunk_bytes"], cfg.hessian.chunk_bytes) self.assertEqual(meta["hessian"]["staging_dtype"], "bfloat16") - self.assertEqual(meta["vram_strategy"], cfg.vram_strategy.value) + self.assertEqual(meta["dense_vram_strategy"], cfg.dense_vram_strategy.value) + self.assertEqual(meta["dense_vram_strategy_devices"], cfg.dense_vram_strategy_devices) + self.assertEqual(meta["moe_vram_strategy"], cfg.moe_vram_strategy.value) + self.assertEqual(meta["moe_vram_strategy_devices"], cfg.moe_vram_strategy_devices) def test_gptaq_config_none_serialization(self): cfg = QuantizeConfig() diff --git a/tests/test_serve_vllm_qwen35.py b/tests/test_serve_vllm_qwen35.py new file mode 100644 index 000000000..5f638f59e --- /dev/null +++ b/tests/test_serve_vllm_qwen35.py @@ -0,0 +1,36 @@ +import importlib.util +import sys +import types +from pathlib import Path + + +MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "serve_vllm_qwen35.py" +MODULE_SPEC = importlib.util.spec_from_file_location("serve_vllm_qwen35", MODULE_PATH) +assert MODULE_SPEC is not None and MODULE_SPEC.loader is not None +serve_vllm_qwen35 = importlib.util.module_from_spec(MODULE_SPEC) +MODULE_SPEC.loader.exec_module(serve_vllm_qwen35) + + +def test_qwen35_detection_reads_local_config(tmp_path): + (tmp_path / "config.json").write_text('{"model_type": "qwen3_5_text"}', encoding="utf-8") + + assert serve_vllm_qwen35._is_qwen35_text_checkpoint(str(tmp_path)) + + +def test_qwen35_detection_uses_autoconfig_for_repo_ids(monkeypatch): + calls = [] + + class DummyConfig: + def to_dict(self): + return {"model_type": "qwen3_5_text"} + + class DummyAutoConfig: + @staticmethod + def from_pretrained(model_id, trust_remote_code): + calls.append((model_id, trust_remote_code)) + return DummyConfig() + + monkeypatch.setitem(sys.modules, "transformers", types.SimpleNamespace(AutoConfig=DummyAutoConfig)) + + assert serve_vllm_qwen35._is_qwen35_text_checkpoint("groxaxo/qwen35-gptq-pro") + assert calls == [("groxaxo/qwen35-gptq-pro", True)] diff --git a/tests/test_sglang.py b/tests/test_sglang.py index 169972468..cf517b92a 100644 --- a/tests/test_sglang.py +++ b/tests/test_sglang.py @@ -10,10 +10,18 @@ # -- end do not touch import importlib.util # noqa: E402 +import unittest # noqa: E402 +import pytest # noqa: E402 from models.model_test import ModelTest # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.utils.sglang import SGLANG_AVAILABLE, SGLANG_INSTALL_HINT # noqa: E402 + + +pytestmark = [pytest.mark.model, pytest.mark.slow] + +pytestmark = [pytest.mark.model, pytest.mark.slow] class TestLoadSglang(ModelTest): @@ -21,8 +29,12 @@ class TestLoadSglang(ModelTest): @classmethod def setUpClass(self): # sglang set disable_flashinfer=True still import flashinfer - if importlib.util.find_spec("flashinfer") is None or importlib.util.find_spec("sglang") is None: - raise RuntimeError("flashinfer and sglang are required by this test. you can install them by `pip install gptqmodel['sglang']`") + if importlib.util.find_spec("flashinfer") is None: + raise unittest.SkipTest( + "flashinfer is required by this test. install via `pip install gptqmodel['sglang']`" + ) + if importlib.util.find_spec("sglang") is None or not SGLANG_AVAILABLE: + raise unittest.SkipTest(SGLANG_INSTALL_HINT) self.MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" @@ -42,4 +54,3 @@ def test_load_sglang(self): self.assertTrue(len(output)>5) model.shutdown() del model - diff --git a/tests/test_sharded.py b/tests/test_sharded.py index 9b2a00bd3..b8c8cc103 100644 --- a/tests/test_sharded.py +++ b/tests/test_sharded.py @@ -15,6 +15,7 @@ import unittest # noqa: E402 import torch # noqa: E402 +from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 from gptqmodel import GPTQModel # noqa: E402 @@ -53,10 +54,14 @@ def test_save_and_load(self): device_map="auto", ) - inp = tokenizer(self.prompt, return_tensors="pt").to(self.device) - - tokens = model.generate(**inp, num_beams=1, do_sample=False, min_new_tokens=60, max_new_tokens=60) - result = tokenizer.decode(tokens[0]) + result = ModelTest.generate_stable_with_limit( + model, + tokenizer, + self.prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) self.assertEqual(result[:100], self.reference_output[:100]) @@ -84,9 +89,13 @@ def test_save_and_load_no_shard(self): device_map="auto", ) - inp = tokenizer(self.prompt, return_tensors="pt").to(self.device) - - tokens = model.generate(**inp, num_beams=1, do_sample=False, min_new_tokens=60, max_new_tokens=60) - result = tokenizer.decode(tokens[0]) + result = ModelTest.generate_stable_with_limit( + model, + tokenizer, + self.prompt, + min_new_tokens=60, + max_new_tokens=60, + skip_special_tokens=False, + ) self.assertEqual(result[:100], self.reference_output[:100]) diff --git a/tests/test_simple_quant.py b/tests/test_simple_quant.py index 5e081b7a3..f595c0ffc 100644 --- a/tests/test_simple_quant.py +++ b/tests/test_simple_quant.py @@ -5,12 +5,16 @@ import tempfile +import pytest from datasets import load_dataset from logbar import LogBar from gptqmodel import GPTAQConfig, GPTQModel, QuantizeConfig from gptqmodel.quantization import FORMAT -from gptqmodel.utils.eval import EVAL +from tests.eval import evaluate, format_eval_result_table, get_eval_task_metrics + + +pytestmark = [pytest.mark.model, pytest.mark.slow] log = LogBar.shared() @@ -58,35 +62,54 @@ def get_calib_data(tokenizer, rows: int): gptaq=GPTAQConfig() if CFG_V2 else None, ) -log.info(f"QuantConfig: {quant_config}") +def _run_simple_quant_eval(): + """Run the legacy simple-quant flow as a real pytest workload.""" + log.info(f"QuantConfig: {quant_config}") + + if not EVAL_ONLY: + log.info(f"Save Path: {QUANT_SAVE_PATH}") + + # load un-quantized native model + model = GPTQModel.load(MODEL_ID, quant_config) + + # load calibration data + calibration_dataset = get_calib_data(tokenizer=model.tokenizer, rows=256) -if not EVAL_ONLY: - log.info(f"Save Path: {QUANT_SAVE_PATH}") + model.quantize(calibration_dataset, batch_size=1) - # load un-quantized native model - model = GPTQModel.load(MODEL_ID, quant_config) + model.save(QUANT_SAVE_PATH) + log.info(f"Quant Model Saved to: {QUANT_SAVE_PATH}") - # load calibration data - calibration_dataset = get_calib_data(tokenizer=model.tokenizer, rows=256) + with tempfile.TemporaryDirectory() as tmp_dir: + results = evaluate( + QUANT_SAVE_PATH, + tasks=["gsm8k_cot"], #, "gsm8k_platinum_cot"], + apply_chat_template=True, + output_path=tmp_dir, + ) - model.quantize(calibration_dataset, batch_size=1) + print(format_eval_result_table(results)) - model.save(QUANT_SAVE_PATH) - log.info(f"Quant Model Saved to: {QUANT_SAVE_PATH}") + metrics = get_eval_task_metrics(results, "gsm8k_cot") + filtered_metrics = { + metric: value + for metric, value in metrics.items() + if metric != "alias" and "stderr" not in metric + } -# eval -from lm_eval.utils import make_table + value = filtered_metrics['acc,num'] + expected = 0.7998 + diff_pct = (value / expected) * 100 + floor_pct = 0.05 + ceil_pct = 0.10 + negative_pct = 100 * (1 - floor_pct) + positive_pct = 100 * (1 + ceil_pct) + assert negative_pct <= diff_pct <= positive_pct, (f"gsm8k_cot:acc,num: `{value}` vs " + f"expected `{expected}`, diff {diff_pct:.2f}% is out of the " + f"expected range [{negative_pct}-{positive_pct}%]") -with tempfile.TemporaryDirectory() as tmp_dir: - results = GPTQModel.eval( - QUANT_SAVE_PATH, - tasks=[EVAL.LM_EVAL.GSM8K_COT], #, EVAL.LM_EVAL.GSM8K_PLATINUM_COT], - apply_chat_template=True, - random_seed=898, - output_path= tmp_dir, - ) - print(make_table(results)) - if "groups" in results: - print(make_table(results, "groups")) +def test_simple_quant(): + """Keep the simple-quant regression runnable under pytest collection.""" + _run_simple_quant_eval() diff --git a/tests/test_split_by_layer_save.py b/tests/test_split_by_layer_save.py new file mode 100644 index 000000000..b08a0be5e --- /dev/null +++ b/tests/test_split_by_layer_save.py @@ -0,0 +1,199 @@ +import copy +import json +import os +from types import SimpleNamespace + +import torch +import torch.nn as nn +from accelerate import load_checkpoint_in_model + +from gptqmodel.models.writer import ModelWriter +from gptqmodel.quantization.config import FORMAT, METHOD + + +class _DummyKernel: + REQUIRES_FORMAT_V2 = False + SUPPORTS_SHARDS = True + + +class _DummyQuantizeConfig: + method = METHOD.GPTQ + format = FORMAT.GPTQ + checkpoint_format = FORMAT.GPTQ + quant_method = METHOD.GPTQ + damp_percent = 0.0 + damp_auto_increment = 0.0 + static_groups = False + true_sequential = False + mse = False + gptaq = None + act_group_aware = False + adapter = None + dynamic = False + offload_to_disk = False + offload_to_disk_path = None + lm_head = False + + def __init__(self): + self._meta = {} + + def __deepcopy__(self, memo): + clone = type(self)() + memo[id(self)] = clone + clone._meta = copy.deepcopy(self._meta, memo) + return clone + + def meta_set_versionable(self, key, value): + self._meta[key] = value + + def meta_set(self, key, value): + self._meta[key] = value + + def to_dict(self): + return {"meta": dict(self._meta)} + + def save_pretrained(self, save_dir): + with open(os.path.join(save_dir, "quantize_config.json"), "w", encoding="utf-8") as handle: + json.dump({"meta": dict(self._meta)}, handle) + + def extract_adapter_rank_patterns(self): + return {} + + +class _DummyConfig: + def __deepcopy__(self, memo): + clone = type(self)() + memo[id(self)] = clone + clone.__dict__ = copy.deepcopy(self.__dict__, memo) + return clone + + +class _DummyGenerationConfig(_DummyConfig): + pass + + +class _TinySplitModel(nn.Module): + def __init__(self): + super().__init__() + self.model = nn.Module() + self.model.embed_tokens = nn.Embedding(6, 4) + self.model.layers = nn.ModuleList( + [ + nn.Linear(4, 4), + nn.Linear(4, 4), + ] + ) + self.model.norm = nn.LayerNorm(4) + self.lm_head = nn.Linear(4, 6, bias=False) + self.config = _DummyConfig() + self.generation_config = _DummyGenerationConfig() + + with torch.no_grad(): + for idx, (_, tensor) in enumerate(self.state_dict().items(), start=1): + tensor.copy_(torch.arange(tensor.numel(), dtype=tensor.dtype).reshape(tensor.shape) + idx) + + def save_pretrained(self, save_dir, state_dict=None, is_main_process=True): + with open(os.path.join(save_dir, "config.json"), "w", encoding="utf-8") as handle: + json.dump({"dummy": True}, handle) + with open(os.path.join(save_dir, "generation_config.json"), "w", encoding="utf-8") as handle: + json.dump({"do_sample": False}, handle) + + +def _build_writer(tmp_path): + class _Base: + @classmethod + def extract_layers_node(cls): + return ["model.layers"] + + DummyWriter = ModelWriter(_Base) + instance = DummyWriter() + instance.quantized = True + instance.quantize_config = _DummyQuantizeConfig() + instance.quant_log = [] + instance.load_quantized_model = False + instance.qlinear_kernel = _DummyKernel() + instance.model_local_path = str(tmp_path / "original") + instance.trust_remote_code = False + instance.tokenizer = None + instance.processor = None + instance.turtle_model = SimpleNamespace() + instance.model = _TinySplitModel() + os.makedirs(instance.model_local_path, exist_ok=True) + return instance + + +def _patch_writer_env(monkeypatch): + monkeypatch.setattr("gptqmodel.models.writer.get_model_files_size", lambda _: 1) + monkeypatch.setattr("gptqmodel.models.writer.alias_all_from_turtle_if_meta", lambda *args, **kwargs: None) + monkeypatch.setattr("gptqmodel.models.writer.sanitize_model_config", lambda *_args, **_kwargs: None) + monkeypatch.setattr("gptqmodel.models.writer.sanitize_generation_config_file", lambda *_args, **_kwargs: False) + + +def test_save_quantized_split_by_layer_writes_per_layer_dirs(tmp_path, monkeypatch): + writer = _build_writer(tmp_path) + _patch_writer_env(monkeypatch) + + save_dir = tmp_path / "save" + writer.save_quantized(save_dir=str(save_dir), split_by="layer", max_shard_size=None) + + assert (save_dir / "model.layers.0" / "layer.safetensors").exists() + assert (save_dir / "model.layers.1" / "layer.safetensors").exists() + assert (save_dir / "model.embed_tokens.safetensors").exists() + assert (save_dir / "model.norm.safetensors").exists() + assert (save_dir / "lm_head.safetensors").exists() + assert not (save_dir / "model.embed_tokens").exists() + assert not (save_dir / "model.norm").exists() + assert not (save_dir / "lm_head").exists() + assert not (save_dir / "model.safetensors").exists() + + index = json.loads((save_dir / "model.safetensors.index.json").read_text()) + + assert index["weight_map"]["model.layers.0.weight"] == "model.layers.0/layer.safetensors" + assert index["weight_map"]["model.layers.1.bias"] == "model.layers.1/layer.safetensors" + assert index["weight_map"]["model.embed_tokens.weight"] == "model.embed_tokens.safetensors" + assert index["weight_map"]["model.norm.weight"] == "model.norm.safetensors" + assert index["weight_map"]["lm_head.weight"] == "lm_head.safetensors" + + +def test_save_quantized_split_by_layer_still_shards_large_layer(tmp_path, monkeypatch): + writer = _build_writer(tmp_path) + _patch_writer_env(monkeypatch) + + save_dir = tmp_path / "save" + writer.save_quantized(save_dir=str(save_dir), split_by="layer", max_shard_size=64) + + layer0_dir = save_dir / "model.layers.0" + layer0_shards = sorted(path.name for path in layer0_dir.glob("*.safetensors")) + assert layer0_shards == [ + "layer-00001-of-00002.safetensors", + "layer-00002-of-00002.safetensors", + ] + + index = json.loads((save_dir / "model.safetensors.index.json").read_text()) + weight_file = index["weight_map"]["model.layers.0.weight"] + bias_file = index["weight_map"]["model.layers.0.bias"] + + assert weight_file.startswith("model.layers.0/layer-") + assert bias_file.startswith("model.layers.0/layer-") + assert weight_file != bias_file + + +def test_split_by_layer_index_loads_nested_layer_shards(tmp_path, monkeypatch): + writer = _build_writer(tmp_path) + _patch_writer_env(monkeypatch) + + save_dir = tmp_path / "save" + expected_state = {name: tensor.clone() for name, tensor in writer.model.state_dict().items()} + + writer.save_quantized(save_dir=str(save_dir), split_by="layer", max_shard_size=64) + + reloaded = _TinySplitModel() + with torch.no_grad(): + for tensor in reloaded.state_dict().values(): + tensor.zero_() + + load_checkpoint_in_model(reloaded, checkpoint=str(save_dir / "model.safetensors.index.json")) + + reloaded_state = reloaded.state_dict() + for name, expected in expected_state.items(): + torch.testing.assert_close(reloaded_state[name], expected) diff --git a/tests/test_stage_modules.py b/tests/test_stage_modules.py index d68e136ff..ea4cff77f 100644 --- a/tests/test_stage_modules.py +++ b/tests/test_stage_modules.py @@ -1,21 +1,41 @@ +import sys import threading import types from typing import Dict import torch +import gptqmodel.looper.stage_subset as stage_subset_module +from gptqmodel.looper.forward_executor import ForwardExecutor +from gptqmodel.looper.loop_processor import ExecutionConfig from gptqmodel.looper.module_looper import FinalizeProgressInfo, ModuleLooper from gptqmodel.looper.named_module import NamedModule +from gptqmodel.looper.paroquant_processor import ParoQuantProcessor from gptqmodel.looper.stage_inputs_capture import StageInputsCapture -from gptqmodel.looper.stage_layer import run_layer_stage -from gptqmodel.looper.stage_subset import SubsetForwardContext, SubsetStageResult -from gptqmodel.utils.pause_resume import PauseResumeController +from gptqmodel.looper.stage_layer import ( + _capture_pristine_group_context, + _processor_needs_pristine_group_clone, + _replay_layer_outputs, + _should_drain_finalize_futures_synchronously, + _should_empty_cache_after_sync_finalize, + run_layer_stage, +) +from gptqmodel.looper.stage_subset import CalibrationCoveragePolicy, SubsetPlan, SubsetStageResult +from gptqmodel.models.base import BaseQModel +from gptqmodel.quantization.config import QuantizeConfig class _DummyQModel: def __init__(self): self.support_batch_quantize = False - self.quantize_config = types.SimpleNamespace(device=None, vram_strategy="exclusive", moe_routing_bypass=lambda : False) + self.quantize_config = types.SimpleNamespace( + device=None, + dense_vram_strategy="exclusive", + dense_vram_strategy_devices=None, + moe_vram_strategy="exclusive", + moe_vram_strategy_devices=None, + moe_routing_bypass=lambda: False, + ) self.layer_callback = None @@ -53,6 +73,71 @@ def cache_inputs(self, **kwargs): assert captured["kwargs"]["calibration_data"] is data +def test_assign_quant_device_prefers_balanced_hint(): + looper = _make_looper() + looper._quant_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + looper._module_device_map = {} + looper._quant_device_rr = 0 + + named = NamedModule( + torch.nn.Linear(4, 4, bias=False), + name="mlp.experts.1.gate_proj", + full_name="model.layers.0.mlp.experts.1.gate_proj", + layer_index=0, + ) + named.state["preferred_quant_device"] = torch.device("cuda:1") + + target = looper._assign_quant_device_for_module( + named, + fallback_device=torch.device("cuda:0"), + ) + + assert target == torch.device("cuda:1") + assert looper._module_device_map[named.full_name] == torch.device("cuda:1") + assert looper._quant_device_rr == 0 + + +def test_module_looper_runtime_telemetry_reports_gil_and_split_pools(monkeypatch): + emitted = [] + info_logs = [] + warn_logs = [] + module_looper_module = sys.modules[ModuleLooper.__module__] + + monkeypatch.setattr( + module_looper_module, + "emit_device_telemetry", + lambda event, **fields: emitted.append((event, fields)), + ) + monkeypatch.setattr(module_looper_module, "has_gil_control", lambda: True) + monkeypatch.setattr(module_looper_module, "has_gil_disabled", lambda: True) + monkeypatch.setattr(module_looper_module.os, "environ", {"PYTHON_GIL": "0"}) + monkeypatch.setattr(module_looper_module.log, "info", lambda *args, **kwargs: info_logs.append(args)) + monkeypatch.setattr(module_looper_module.log, "warn", lambda *args, **kwargs: warn_logs.append(args)) + + looper = ModuleLooper.__new__(ModuleLooper) + looper.gptq_model = types.SimpleNamespace(dynamic_expert_index=object()) + looper._dense_quant_devices = [torch.device("cuda:0")] + looper._moe_quant_devices = [torch.device("cuda:1"), torch.device("cuda:2")] + looper._dense_vram_strategy = "exclusive" + looper._moe_vram_strategy = "balanced" + looper.moe_routing_override = 256 + looper.moe_routing_bypass = False + + looper._emit_moe_parallel_quant_runtime() + + assert info_logs + assert not warn_logs + assert len(emitted) == 1 + event, fields = emitted[0] + assert event == "moe_parallel_quant_runtime" + assert fields["dense_devices"] == ["cuda:0"] + assert fields["moe_devices"] == ["cuda:1", "cuda:2"] + assert fields["routing_override"] == 256 + assert fields["python_gil_env"] == "0" + assert fields["python_gil_disabled"] is True + assert fields["free_threaded_parallel_quant_eligible"] is True + + class _TinyLayer(torch.nn.Module): def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs): return hidden_states @@ -78,15 +163,25 @@ class _TinyGptqModel: ATTENTION_MASKS_REQUIRED_FOR_INPUT = False ATTENTION_MASKS_DTYPE = torch.long INPUT_EMBEDDING_EXTRA_ARGS = {} + finalize_input_capture_example = BaseQModel.finalize_input_capture_example + capture_first_layer_positional_inputs = BaseQModel.capture_first_layer_positional_inputs + capture_first_layer_input_kwargs = BaseQModel.capture_first_layer_input_kwargs + move_input_capture_example = BaseQModel.move_input_capture_example + prepare_layer_replay_kwargs = BaseQModel.prepare_layer_replay_kwargs + run_input_capture = BaseQModel.run_input_capture def __init__(self): self.layer = _TinyLayer() self.model = _TinyModel(self.layer) - self.quantize_config = types.SimpleNamespace(device=torch.device("cpu")) + self.quantize_config = types.SimpleNamespace( + device=torch.device("cpu"), + calibration_data_device=None, + ) self._hook_started = False self._hook_finished = False - def shell_module_materialize(self, target_submodule, device): + def shell_module_materialize(self, target_submodule, device, **kwargs): + del kwargs target_submodule.to(device) return target_submodule @@ -111,6 +206,232 @@ def _batch_row_count(self, batch_inputs): return int(tensor.shape[0]) if tensor.ndim > 0 else int(tensor.numel()) +class _TinyExecutorLayer(torch.nn.Module): + def forward(self, hidden_states, **kwargs): + return hidden_states + + +class _RecordingCtx: + def __init__(self, sink): + self.sink = sink + + def __enter__(self): + self.sink.append("enter") + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _DummyForwardProcessor: + num_batches = None + + def _set_current_batch_index(self, _idx): + return None + + +class _ImmediateFuture: + def __init__(self, result): + self._result = result + + def result(self): + return self._result + + +class _ImmediateThreadPool: + def submit(self, _device, fn, *args, **kwargs): + return _ImmediateFuture(fn(*args, **kwargs)) + + def submit_serial(self, _device, fn, *args, **kwargs): + return _ImmediateFuture(fn(*args, **kwargs)) + + +def _make_forward_executor_looper( + *, + override_entries=None, + lifecycle_entries=None, + moe_routing_override=None, + moe_routing_bypass=False, + should_use_moe_lifecycle=False, +): + def _override_context(*_args, **_kwargs): + if override_entries is None: + raise AssertionError("override should stay disabled") + return _RecordingCtx(override_entries) + + def _lifecycle_context(*_args, **_kwargs): + if lifecycle_entries is None: + raise AssertionError("lifecycle should stay disabled") + return _RecordingCtx(lifecycle_entries) + + return types.SimpleNamespace( + _resolve_batch_total=lambda _num_batches, layer_inputs: len(layer_inputs), + _collect_row_counts=lambda layer_inputs: [int(batch[0].shape[0]) for batch in layer_inputs], + _set_processor_mask=lambda _processor, _mask: None, + _batch_row_count=lambda batch_inputs: int(batch_inputs[0].shape[0]), + support_batch_quantize=False, + gptq_model=types.SimpleNamespace( + quantize_config=types.SimpleNamespace( + calibration_data_device=None, + compute_device_filter=None, + ), + prepare_layer_replay_kwargs=lambda layer, layer_input, additional_inputs, target_device: additional_inputs, + ), + moe_routing_override=moe_routing_override, + moe_routing_bypass=moe_routing_bypass, + MoERoutingOverrideContext=_override_context, + MoELifecycleContext=_lifecycle_context, + _should_use_moe_lifecycle=lambda *_args, **_kwargs: should_use_moe_lifecycle, + _current_subset=None, + ) + + +def _run_executor_single(executor, processor, *, apply_moe_config): + return executor.run_single( + module=_TinyExecutorLayer(), + processor=processor, + layer_inputs=[[torch.zeros(1, 1, 1)]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[None], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=True, + reuse_kv=False, + apply_moe_config=apply_moe_config, + ) + + +def _run_executor_parallel(executor, processor, *, apply_moe_config): + def clone_module_for_devices_fn(module, devices, progress_callback=None): + del progress_callback + return dict.fromkeys(devices, module) + + def forward_batch_worker_fn( + _replica, + _processor, + batch_idx, + _batch_inputs, + _batch_kwargs, + _attention_mask, + _position_ids, + **_kwargs, + ): + return batch_idx, torch.zeros(1, 1, 1), None + + return executor.run_parallel( + module=_TinyExecutorLayer(), + processor=processor, + layer_inputs=[[torch.zeros(1, 1, 1)], [torch.zeros(1, 1, 1)]], + layer_input_kwargs=[{}, {}], + position_ids=[None, None], + attention_masks=[None, None], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + need_outputs=True, + reuse_kv=False, + devices=[torch.device("cuda:0"), torch.device("cuda:1")], + apply_moe_config=apply_moe_config, + clone_module_for_devices_fn=clone_module_for_devices_fn, + forward_batch_worker_fn=forward_batch_worker_fn, + device_thread_pool=_ImmediateThreadPool(), + ) + + +def test_stage_layer_forces_sync_finalizers_for_paroquant(): + looper = types.SimpleNamespace( + gptq_model=types.SimpleNamespace( + quantize_config=QuantizeConfig( + bits=4, + group_size=128, + wait_for_submodule_finalizers=False, + ) + ) + ) + paro_processor = object.__new__(ParoQuantProcessor) + + assert _should_drain_finalize_futures_synchronously( + looper, + finalize_tasks=[(paro_processor, None, None, None, None)], + ) is True + + +def test_stage_layer_keeps_async_finalizers_for_non_paroquant_when_unset(): + looper = types.SimpleNamespace( + gptq_model=types.SimpleNamespace( + quantize_config=QuantizeConfig( + bits=4, + group_size=128, + wait_for_submodule_finalizers=False, + ) + ) + ) + + assert _should_drain_finalize_futures_synchronously( + looper, + finalize_tasks=[(types.SimpleNamespace(), None, None, None, None)], + ) is False + + +def test_stage_layer_empties_cache_after_sync_paroquant_finalize_only_with_offload(): + looper = types.SimpleNamespace( + gptq_model=types.SimpleNamespace( + quantize_config=QuantizeConfig( + bits=4, + group_size=128, + offload_to_disk=True, + ) + ) + ) + paro_processor = object.__new__(ParoQuantProcessor) + + assert _should_empty_cache_after_sync_finalize( + looper, + finalize_tasks=[(paro_processor, None, None, None, None)], + ) is True + + looper.gptq_model.quantize_config.offload_to_disk = False + assert _should_empty_cache_after_sync_finalize( + looper, + finalize_tasks=[(paro_processor, None, None, None, None)], + ) is False + +def test_stage_layer_paroquant_layer_scope_skips_pristine_group_clone(): + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = types.SimpleNamespace(opt_scope="layer") + + assert _processor_needs_pristine_group_clone(processor) is False + + +def test_stage_layer_paroquant_compute_block_scope_keeps_pristine_group_clone(): + processor = object.__new__(ParoQuantProcessor) + processor.qcfg = types.SimpleNamespace(opt_scope="compute_block") + + assert _processor_needs_pristine_group_clone(processor) is True + + +def test_stage_subset_flush_stays_local_when_work_stays_on_cur_layer_device(): + cur_layer_device = torch.device("cuda:0") + + assert ( + stage_subset_module._resolve_cache_flush_device(cur_layer_device, [torch.device("cuda:0")]) + == cur_layer_device + ) + + +def test_stage_subset_flush_goes_global_when_work_fans_out_across_devices(): + cur_layer_device = torch.device("cuda:0") + + assert stage_subset_module._resolve_cache_flush_device( + cur_layer_device, + [torch.device("cuda:0"), torch.device("cuda:1")], + ) is None + + def test_stage_inputs_capture_collects_real_inputs(): gptq_model = _TinyGptqModel() looper = _TinyLooper(gptq_model) @@ -141,21 +462,91 @@ def test_stage_inputs_capture_collects_real_inputs(): assert gptq_model._hook_finished is True +def test_forward_executor_run_single_can_skip_moe_routing_override_for_replay(): + """Replay must skip top-k override, while quant-time forward still enables it.""" + + override_entries = [] + looper = _make_forward_executor_looper( + override_entries=override_entries, + lifecycle_entries=[], + moe_routing_override=256, + ) + executor = ForwardExecutor(looper) + processor = _DummyForwardProcessor() + + # Replay path: do not install any MoE routing override context. + outputs = _run_executor_single(executor, processor, apply_moe_config=False) + + assert len(outputs) == 1 + assert override_entries == [] + + override_entries.clear() + outputs = _run_executor_single(executor, processor, apply_moe_config=True) + + assert len(outputs) == 1 + assert override_entries == ["enter"] + + +def test_forward_executor_run_single_can_skip_moe_lifecycle_for_replay(): + """Replay must also skip bypass/lifecycle hooks, not just routing override.""" + + lifecycle_entries = [] + looper = _make_forward_executor_looper( + lifecycle_entries=lifecycle_entries, + moe_routing_bypass=True, + should_use_moe_lifecycle=True, + ) + executor = ForwardExecutor(looper) + processor = _DummyForwardProcessor() + + # Replay path: bypass routing stays off, so lifecycle hooks must not run. + outputs = _run_executor_single(executor, processor, apply_moe_config=False) + + assert len(outputs) == 1 + assert lifecycle_entries == [] + + outputs = _run_executor_single(executor, processor, apply_moe_config=True) + + assert len(outputs) == 1 + assert lifecycle_entries == ["enter"] + + +def test_forward_executor_run_parallel_can_skip_moe_config_for_replay(): + """Parallel replay must skip the same MoE config that serial replay skips.""" + + override_entries = [] + looper = _make_forward_executor_looper( + override_entries=override_entries, + lifecycle_entries=[], + moe_routing_override=8, + moe_routing_bypass=True, + should_use_moe_lifecycle=True, + ) + executor = ForwardExecutor(looper) + processor = _DummyForwardProcessor() + + # Replay path: each replica should stay on the model's native router. + outputs = _run_executor_parallel(executor, processor, apply_moe_config=False) + + assert len(outputs) == 2 + assert override_entries == [] + + # Quant-time path: replicas should still install the quant-time MoE context. + outputs = _run_executor_parallel(executor, processor, apply_moe_config=True) + + assert len(outputs) == 2 + assert override_entries == ["enter", "enter"] + + def test_run_layer_stage_invokes_subset_stage(monkeypatch): calls = [] def fake_run_subset_stage(looper, **kwargs): - calls.append(kwargs["subset_index"]) + calls.append(kwargs["plan"].subset_index) return SubsetStageResult( processed_subset={}, layer_inputs=kwargs["layer_inputs"], - forward_context=SubsetForwardContext( - subset={}, - forward_device_map={}, - subset_forward_serial=False, - subset_total=kwargs["subset_total"], - subset_index=kwargs["subset_index"], - ), + plan=kwargs["plan"], ) monkeypatch.setattr("gptqmodel.looper.stage_layer.run_subset_stage", fake_run_subset_stage) @@ -212,11 +603,13 @@ def error(self, *_, **__): return None class DummyProcessor: - fwd_all_modules_in_single_pass = False - fwd_after_process = False - def __init__(self): tensor = torch.zeros(1, 1, 1) + self.execution_config = ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=False, + fwd_all_modules_in_single_pass=False, + ) self.inputs_cache = types.SimpleNamespace( layer_inputs=[[tensor]], layer_input_kwargs=[{}], @@ -261,7 +654,11 @@ class DummyGptqModel: def __init__(self): self.model = torch.nn.Module() self.model.config = types.SimpleNamespace(model_type="llama") - self.quantize_config = types.SimpleNamespace(lm_head=False) + self.quantize_config = QuantizeConfig( + bits=4, + group_size=128, + offload_to_disk=False, + ) self.lm_head = None def pre_quantize(self, module): @@ -281,13 +678,29 @@ def __init__(self): self._module_device_map = {} self._quant_device_lock = threading.Lock() self._moe_subset_threshold = 16 - self._vram_strategy = types.SimpleNamespace() + self._dense_quant_devices = [torch.device("cpu")] + self._moe_quant_devices = [torch.device("cpu")] + self._dense_vram_strategy = types.SimpleNamespace() + self._moe_vram_strategy = types.SimpleNamespace() + self._dense_vram_strategy_explicit = False + self._moe_vram_strategy_explicit = False self._layer_events = [] - self.pause_controller = PauseResumeController() def _check_loop_stop(self): return False + def _is_attention_module_name(self, _name): + return False + + def _extract_moe_group_key(self, _name): + return None + + def _resolve_batch_total(self, _num_batches, layer_inputs): + return len(layer_inputs) + + def _collect_row_counts(self, layer_inputs): + return [1 for _ in layer_inputs] + def _emit_layer_complete(self, *, layer_idx, submodule_finalized, raise_in_place): self._layer_events.append((layer_idx, submodule_finalized, raise_in_place)) @@ -298,7 +711,7 @@ def _subset_event_dispatch(self, *kwargs): pass def create_named_modules(self, module, full, is_lm_head_module, layer_index, layers_prefix, names, processor, - failsafe, layer_module=None) -> Dict[str, NamedModule]: + fallback, layer_module=None) -> Dict[str, NamedModule]: subset = {} name = "self_attn.q_proj" subset[name] = NamedModule(module, name=name, full_name=full, layer_index=layer_index) @@ -318,8 +731,9 @@ def create_named_modules(self, module, full, is_lm_head_module, layer_index, lay looper, layers=layers, layer_modules=layer_modules, + planning_layer_modules=layer_modules, layers_prefix="model.layers", - failsafe=True, + fallback=True, shared_kv_cache_dict={}, pb=pb, layer_count=1, @@ -329,3 +743,1327 @@ def create_named_modules(self, module, full, is_lm_head_module, layer_index, lay ) assert calls == [0] + + +def test_run_layer_stage_stops_after_last_quantized_layer(monkeypatch): + calls = [] + + def fake_run_subset_stage(looper, **kwargs): + calls.append(kwargs["layer_index"]) + return SubsetStageResult( + processed_subset={}, + layer_inputs=kwargs["layer_inputs"], + plan=kwargs["plan"], + ) + + monkeypatch.setattr("gptqmodel.looper.stage_layer.run_subset_stage", fake_run_subset_stage) + monkeypatch.setattr("gptqmodel.looper.stage_layer.find_modules", lambda *_, **__: {}) + + class DummyPB: + def __init__(self, iterable): + self._iterable = list(iterable) + self.current_iter_step = 0 + self.close_calls = 0 + + def __iter__(self): + return iter(self._iterable) + + def __len__(self): + return len(self._iterable) + + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def next(self): + return self + + def close(self): + self.close_calls += 1 + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB(iterable) + + def info(self, *_, **__): + return None + + def debug(self, *_, **__): + return None + + def warning(self, *_, **__): + return None + + warn = warning + + def error(self, *_, **__): + return None + + class DummyProcessor: + def __init__(self): + tensor = torch.zeros(1, 1, 1) + self.execution_config = ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=False, + fwd_all_modules_in_single_pass=False, + ) + self.inputs_cache = types.SimpleNamespace( + layer_inputs=[[tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + ) + self.calibration_dataset = [] + self.log = [] + self.tasks = {} + + def collect_memory_info(self, *_): + return None + + def pre_process_fwd_hook(self, *_): + return lambda *a, **k: None + + def process(self, *_, **__): + return None + + def clear_cache_data(self): + return None + + def receive_layer_inputs(self, inputs): + self.inputs_cache.layer_inputs = inputs + + def set_fwd_time(self, *_): + return None + + def name(self): + return "dummy" + + def submodule_finalize(self, *_, **__): + return None + + def finalize(self, *_, **__): + return None + + def log_plotly(self): + return None + + class DummyGptqModel: + def __init__(self): + self.model = torch.nn.Module() + self.model.config = types.SimpleNamespace(model_type="llama") + self.quantize_config = QuantizeConfig( + bits=4, + group_size=128, + offload_to_disk=False, + dynamic={ + r"-:^model\.layers\.1\.foo$": {}, + r"-:^model\.layers\.2\.foo$": {}, + }, + ) + self.lm_head = None + + def pre_quantize(self, module): + return module + + def post_quantize(self, module): + return module + + def lm_head_pre_quantize_generate_hook(self, value): + return value + + class DummyLooper: + def __init__(self): + self.gptq_model = DummyGptqModel() + self.processors = [DummyProcessor()] + self._quant_devices = [torch.device("cpu")] + self._module_device_map = {} + self._quant_device_lock = threading.Lock() + self._moe_subset_threshold = 16 + self._dense_quant_devices = [torch.device("cpu")] + self._moe_quant_devices = [torch.device("cpu")] + self._dense_vram_strategy = types.SimpleNamespace() + self._moe_vram_strategy = types.SimpleNamespace() + self._dense_vram_strategy_explicit = False + self._moe_vram_strategy_explicit = False + self._layer_events = [] + self.named_module_layers = [] + + def _check_loop_stop(self): + return False + + def _is_attention_module_name(self, _name): + return False + + def _extract_moe_group_key(self, _name): + return None + + def _resolve_batch_total(self, _num_batches, layer_inputs): + return len(layer_inputs) + + def _collect_row_counts(self, layer_inputs): + return [1 for _ in layer_inputs] + + def _emit_layer_complete(self, *, layer_idx, submodule_finalized, raise_in_place): + self._layer_events.append((layer_idx, submodule_finalized, raise_in_place)) + + def _request_loop_stop(self, exc): + self._stop_exc = exc + + def _subset_event_dispatch(self, *kwargs): + pass + + def create_named_modules(self, module, full, is_lm_head_module, layer_index, layers_prefix, names, processor, + fallback, layer_module=None) -> Dict[str, NamedModule]: + self.named_module_layers.append(layer_index) + return { + "self_attn.q_proj": NamedModule( + module, + name="self_attn.q_proj", + full_name=full, + layer_index=layer_index, + ) + } + + looper = DummyLooper() + processor = looper.processors[0] + pb = DummyPB(range(3)) + processor.layer_count = 3 + processor.pb = pb + + run_layer_stage( + looper, + layers=[torch.nn.Linear(64, 64) for _ in range(3)], + layer_modules=[["foo"]], + planning_layer_modules=[["foo"]], + layers_prefix="model.layers", + fallback=True, + shared_kv_cache_dict={}, + pb=pb, + layer_count=3, + region_timer=None, + finalize_progress_cls=FinalizeProgressInfo, + logger=DummyLogger(), + ) + + assert calls == [0] + assert looper.named_module_layers == [0] + assert pb.close_calls == 1 + + +def test_run_layer_stage_reuses_subset_plan_for_replay(monkeypatch): + tensor = torch.zeros(1, 1, 1) + replay_modules = { + "self_attn.q_proj": NamedModule( + torch.nn.Linear(1, 1, bias=False), + name="self_attn.q_proj", + full_name="model.layers.0.self_attn.q_proj", + layer_index=0, + ) + } + replay_plan = SubsetPlan( + modules=replay_modules, + subset_index=0, + subset_total=1, + execute_forward=True, + replay_after_process=True, + forward_mode="serial", + batch_count=2, + forward_row_counts=[2, 3], + forward_total_rows=5, + moe_groups={}, + forward_device_map={"self_attn.q_proj": torch.device("cuda:0")}, + calibration_coverage_policy=CalibrationCoveragePolicy( + validate_input_coverage=False, + fallback_enabled=True, + prune_uncovered_modules=False, + record_dynamic_exclusions=False, + ), + module_chunks=[replay_modules], + ) + + def fake_build_layer_subset_plans(*_args, **_kwargs): + return [replay_plan] + + def fake_run_subset_stage(looper, **kwargs): + return SubsetStageResult( + processed_subset={}, + layer_inputs=kwargs["layer_inputs"], + plan=kwargs["plan"], + ) + + monkeypatch.setattr("gptqmodel.looper.stage_layer.build_layer_subset_plans", fake_build_layer_subset_plans) + monkeypatch.setattr("gptqmodel.looper.stage_layer.run_subset_stage", fake_run_subset_stage) + monkeypatch.setattr("gptqmodel.looper.stage_layer.find_modules", lambda *_, **__: {}) + + class DummyPB: + def __init__(self, iterable): + self._iterable = list(iterable) + self.current_iter_step = 0 + + def __iter__(self): + return iter(self._iterable) + + def __len__(self): + return len(self._iterable) + + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def next(self): + return self + + def close(self): + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB(iterable) + + def info(self, *_, **__): + return None + + def debug(self, *_, **__): + return None + + def warning(self, *_, **__): + return None + + warn = warning + + def error(self, *_, **__): + return None + + class DummyProcessor: + def __init__(self): + self.execution_config = ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + fwd_all_modules_in_single_pass=False, + ) + self.inputs_cache = types.SimpleNamespace( + layer_inputs=[[tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + ) + self.calibration_dataset = [] + self.log = [] + self.tasks = {} + + def collect_memory_info(self, *_): + return None + + def clear_cache_data(self): + return None + + def receive_layer_inputs(self, inputs): + self.inputs_cache.layer_inputs = inputs + + def name(self): + return "dummy" + + def submodule_finalize(self, *_, **__): + return None + + def finalize(self, *_, **__): + return None + + def log_plotly(self): + return None + + class DummyGptqModel: + def __init__(self): + self.model = torch.nn.Module() + self.model.config = types.SimpleNamespace(model_type="llama") + self.quantize_config = QuantizeConfig( + bits=4, + group_size=128, + offload_to_disk=False, + wait_for_submodule_finalizers=True, + ) + self.lm_head = None + + def pre_quantize(self, module): + return module + + def post_quantize(self, module): + return module + + def lm_head_pre_quantize_generate_hook(self, value): + return value + + class DummyLooper: + def __init__(self): + self.gptq_model = DummyGptqModel() + self.processors = [DummyProcessor()] + self._quant_devices = [torch.device("cpu")] + self._module_device_map = {} + self._quant_device_lock = threading.Lock() + self._moe_subset_threshold = 16 + self._dense_quant_devices = [torch.device("cpu")] + self._moe_quant_devices = [torch.device("cpu")] + self._dense_vram_strategy = types.SimpleNamespace() + self._moe_vram_strategy = types.SimpleNamespace() + self._dense_vram_strategy_explicit = False + self._moe_vram_strategy_explicit = False + self.forward_replay_calls = [] + + def _run_forward_batches(self, **kwargs): + self.forward_replay_calls.append(kwargs) + return [[tensor]] + + def _apply_forward_device_overrides(self, modules, forward_device_map, fallback_modules=None): + self.forward_override_modules = modules + self.forward_override_map = forward_device_map + return {"self_attn.q_proj": torch.device("cpu")} + + def _restore_forward_device_overrides(self, modules, previous_devices, fallback_modules=None): + self.restored_override_modules = modules + self.restored_previous_devices = previous_devices + + def _check_loop_stop(self): + return False + + def _emit_layer_complete(self, *, layer_idx, submodule_finalized, raise_in_place): + return None + + def _request_loop_stop(self, exc): + self._stop_exc = exc + + def _subset_event_dispatch(self, *kwargs): + return None + + def register_dangling_thread(self, thread): + return None + + looper = DummyLooper() + processor = looper.processors[0] + pb = DummyPB(range(2)) + processor.layer_count = 2 + processor.pb = pb + + run_layer_stage( + looper, + layers=[torch.nn.Linear(1, 1, bias=False) for _ in range(2)], + layer_modules=[["self_attn.q_proj"]], + planning_layer_modules=[["self_attn.q_proj"]], + layers_prefix="model.layers", + fallback=True, + shared_kv_cache_dict={}, + pb=pb, + layer_count=2, + region_timer=None, + finalize_progress_cls=FinalizeProgressInfo, + logger=DummyLogger(), + ) + + assert len(looper.forward_replay_calls) == 1 + assert looper.forward_replay_calls[0]["force_serial"] is True + assert looper.forward_replay_calls[0]["preserve_module_devices"] is True + assert looper.forward_replay_calls[0]["progress_rows_per_batch"] == [2, 3] + assert looper.forward_replay_calls[0]["progress_total_rows"] == 5 + assert looper.forward_override_modules is replay_modules + assert looper.forward_override_map == {"self_attn.q_proj": torch.device("cuda:0")} + assert looper.restored_override_modules is replay_modules + + +def test_replay_layer_outputs_without_plan_uses_generic_progress(): + """Untouched-layer replay should use generic progress and disable MoE config.""" + + input_tensor = torch.ones(2, 1, 1) + expected_output = input_tensor + 3.0 + timer_records = [] + + class DummyPB: + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def close(self): + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB() + + class DummyTimer: + def record(self, *args, **kwargs): + timer_records.append((args, kwargs)) + + class DummyLooper: + def __init__(self): + self._current_subset = "not-cleared" + self.forward_calls = [] + + def _resolve_batch_total(self, _num_batches, layer_inputs): + return len(layer_inputs) + + def _collect_row_counts(self, layer_inputs): + return [int(batch[0].shape[0]) for batch in layer_inputs] + + def _run_forward_batches(self, **kwargs): + self.forward_calls.append(kwargs) + return [[expected_output.clone()]] + + def _apply_forward_device_overrides(self, *args, **kwargs): + raise AssertionError("untouched-layer replay should not install device overrides") + + def _restore_forward_device_overrides(self, *args, **kwargs): + raise AssertionError("untouched-layer replay should not restore device overrides") + + looper = DummyLooper() + processor = types.SimpleNamespace(num_batches=None) + + outputs = _replay_layer_outputs( + looper, + module=torch.nn.Identity(), + processor=processor, + layer_inputs=[[input_tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + layer_descriptor="model.layers.0", + full={}, + log=DummyLogger(), + region_timer=DummyTimer(), + replay_plan=None, + ) + + assert len(looper.forward_calls) == 1 + assert looper.forward_calls[0]["progress_rows_per_batch"] == [2] + assert looper.forward_calls[0]["progress_total_rows"] == 2 + assert looper.forward_calls[0]["force_serial"] is False + assert looper.forward_calls[0]["preserve_module_devices"] is False + assert looper.forward_calls[0]["apply_moe_config"] is False + assert looper._current_subset is None + assert len(outputs) == 1 + assert len(outputs[0]) == 1 + assert torch.allclose(outputs[0][0], expected_output) + assert timer_records[0][1]["source"] == "model.layers.0:untouched" + + +def test_replay_layer_outputs_with_plan_uses_plan_metadata_and_device_overrides(): + """Subset-driven replay should keep its plan metadata but still disable MoE config.""" + + tensor = torch.zeros(1, 1, 1) + replay_modules = { + "self_attn.q_proj": NamedModule( + torch.nn.Linear(1, 1, bias=False), + name="self_attn.q_proj", + full_name="model.layers.0.self_attn.q_proj", + layer_index=0, + ) + } + replay_plan = SubsetPlan( + modules=replay_modules, + subset_index=0, + subset_total=1, + execute_forward=True, + replay_after_process=True, + forward_mode="serial", + batch_count=2, + forward_row_counts=[2, 3], + forward_total_rows=5, + moe_groups={}, + forward_device_map={"self_attn.q_proj": torch.device("cuda:0")}, + calibration_coverage_policy=CalibrationCoveragePolicy( + validate_input_coverage=False, + fallback_enabled=True, + prune_uncovered_modules=False, + record_dynamic_exclusions=False, + ), + module_chunks=[replay_modules], + ) + timer_records = [] + + class DummyPB: + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def close(self): + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB() + + class DummyTimer: + def record(self, *args, **kwargs): + timer_records.append((args, kwargs)) + + class DummyLooper: + def __init__(self): + self._current_subset = replay_modules + self.forward_calls = [] + + def _run_forward_batches(self, **kwargs): + self.forward_calls.append(kwargs) + return [[tensor]] + + def _apply_forward_device_overrides(self, modules, forward_device_map, fallback_modules=None): + self.forward_override_modules = modules + self.forward_override_map = forward_device_map + self.forward_override_fallback = fallback_modules + return {"self_attn.q_proj": torch.device("cpu")} + + def _restore_forward_device_overrides(self, modules, previous_devices, fallback_modules=None): + self.restored_override_modules = modules + self.restored_previous_devices = previous_devices + self.restored_override_fallback = fallback_modules + + looper = DummyLooper() + processor = types.SimpleNamespace(num_batches=None) + + outputs = _replay_layer_outputs( + looper, + module=torch.nn.Linear(1, 1, bias=False), + processor=processor, + layer_inputs=[[tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + layer_descriptor="model.layers.0", + full={}, + log=DummyLogger(), + region_timer=DummyTimer(), + replay_plan=replay_plan, + ) + + assert len(looper.forward_calls) == 1 + assert looper.forward_calls[0]["progress_rows_per_batch"] == [2, 3] + assert looper.forward_calls[0]["progress_total_rows"] == 5 + assert looper.forward_calls[0]["force_serial"] is True + assert looper.forward_calls[0]["preserve_module_devices"] is True + assert looper.forward_calls[0]["apply_moe_config"] is False + assert looper._current_subset is None + assert outputs == [[tensor]] + assert looper.forward_override_modules is replay_modules + assert looper.forward_override_map == {"self_attn.q_proj": torch.device("cuda:0")} + assert looper.forward_calls[0]["apply_moe_config"] is False + assert looper.restored_override_modules is replay_modules + assert looper.restored_previous_devices == {"self_attn.q_proj": torch.device("cpu")} + assert timer_records[0][1]["source"] == "model.layers.0:subset1/1" + + +def test_replay_layer_outputs_with_plan_can_skip_override_restore(): + """Replay should honor plans that intentionally keep module overrides installed.""" + + tensor = torch.zeros(1, 1, 1) + replay_modules = { + "self_attn.q_proj": NamedModule( + torch.nn.Linear(1, 1, bias=False), + name="self_attn.q_proj", + full_name="model.layers.0.self_attn.q_proj", + layer_index=0, + ) + } + replay_plan = SubsetPlan( + modules=replay_modules, + subset_index=0, + subset_total=1, + execute_forward=True, + replay_after_process=True, + forward_mode="serial", + batch_count=1, + forward_row_counts=[1], + forward_total_rows=1, + moe_groups={}, + forward_device_map={"self_attn.q_proj": torch.device("cuda:0")}, + calibration_coverage_policy=CalibrationCoveragePolicy( + validate_input_coverage=False, + fallback_enabled=True, + prune_uncovered_modules=False, + record_dynamic_exclusions=False, + ), + module_chunks=[replay_modules], + restore_forward_device_overrides=False, + ) + + class DummyPB: + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def close(self): + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB() + + class DummyLooper: + def __init__(self): + self._current_subset = replay_modules + self.forward_calls = [] + + def _run_forward_batches(self, **kwargs): + self.forward_calls.append(kwargs) + return [[tensor]] + + def _apply_forward_device_overrides(self, modules, forward_device_map, fallback_modules=None): + self.forward_override_modules = modules + self.forward_override_map = forward_device_map + return {"self_attn.q_proj": torch.device("cpu")} + + def _restore_forward_device_overrides(self, modules, previous_devices, fallback_modules=None): + raise AssertionError("restore should be skipped when replay_plan disables it") + + looper = DummyLooper() + processor = types.SimpleNamespace(num_batches=None) + + outputs = _replay_layer_outputs( + looper, + module=torch.nn.Linear(1, 1, bias=False), + processor=processor, + layer_inputs=[[tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + layer_descriptor="model.layers.0", + full={}, + log=DummyLogger(), + region_timer=None, + replay_plan=replay_plan, + ) + + assert outputs == [[tensor]] + assert looper.forward_override_modules is replay_modules + assert looper.forward_override_map == {"self_attn.q_proj": torch.device("cuda:0")} + + +def test_replay_layer_outputs_with_multi_device_plan_skips_moe_config(): + """Multi-device replay should disable MoE config without changing override install.""" + + tensor = torch.zeros(1, 1, 1) + replay_modules = { + "self_attn.q_proj": NamedModule( + torch.nn.Linear(1, 1, bias=False), + name="self_attn.q_proj", + full_name="model.layers.0.self_attn.q_proj", + layer_index=0, + ) + } + replay_plan = SubsetPlan( + modules=replay_modules, + subset_index=0, + subset_total=1, + execute_forward=True, + replay_after_process=True, + forward_mode="serial", + batch_count=2, + forward_row_counts=[2, 3], + forward_total_rows=5, + moe_groups={}, + forward_device_map={ + "self_attn.q_proj": torch.device("cuda:0"), + "mlp.experts.0.gate_proj": torch.device("cuda:1"), + }, + calibration_coverage_policy=CalibrationCoveragePolicy( + validate_input_coverage=False, + fallback_enabled=True, + prune_uncovered_modules=False, + record_dynamic_exclusions=False, + ), + module_chunks=[replay_modules], + restore_forward_device_overrides=False, + ) + timer_records = [] + + class DummyPB: + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def close(self): + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB() + + class DummyTimer: + def record(self, *args, **kwargs): + timer_records.append((args, kwargs)) + + class DummyLooper: + def __init__(self): + self._current_subset = replay_modules + self.forward_calls = [] + self.override_calls = [] + + def _run_forward_batches(self, **kwargs): + self.forward_calls.append(kwargs) + return [[tensor]] + + def _apply_forward_device_overrides(self, modules, forward_device_map, fallback_modules=None): + self.override_calls.append((modules, forward_device_map, fallback_modules)) + return {} + + def _restore_forward_device_overrides(self, modules, previous_devices, fallback_modules=None): + raise AssertionError("restore should be skipped when replay_plan disables it") + + looper = DummyLooper() + processor = types.SimpleNamespace(num_batches=None) + + outputs = _replay_layer_outputs( + looper, + module=torch.nn.Linear(1, 1, bias=False), + processor=processor, + layer_inputs=[[tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + layer_descriptor="model.layers.0", + full={}, + log=DummyLogger(), + region_timer=DummyTimer(), + replay_plan=replay_plan, + ) + + assert outputs == [[tensor]] + assert looper.override_calls == [ + ( + replay_modules, + { + "self_attn.q_proj": torch.device("cuda:0"), + "mlp.experts.0.gate_proj": torch.device("cuda:1"), + }, + {}, + ) + ] + assert len(looper.forward_calls) == 1 + assert looper.forward_calls[0]["progress_rows_per_batch"] == [2, 3] + assert looper.forward_calls[0]["progress_total_rows"] == 5 + assert looper.forward_calls[0]["force_serial"] is True + assert looper.forward_calls[0]["preserve_module_devices"] is True + assert looper.forward_calls[0]["apply_moe_config"] is False + assert timer_records[0][1]["source"] == "model.layers.0:subset1/1" + + +class _ToySelfAttention(torch.nn.Module): + def __init__(self): + super().__init__() + self.q_proj = torch.nn.Linear(1, 1, bias=False) + self.k_proj = torch.nn.Linear(1, 1, bias=False) + self.v_proj = torch.nn.Linear(1, 1, bias=False) + self.o_proj = torch.nn.Linear(1, 1, bias=False) + for proj in (self.q_proj, self.k_proj, self.v_proj, self.o_proj): + torch.nn.init.constant_(proj.weight, 1.0) + + def forward(self, hidden_states): + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + return self.o_proj(q + k + v) + + +class _ToyMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = torch.nn.Linear(1, 1, bias=False) + self.up_proj = torch.nn.Linear(1, 1, bias=False) + self.down_proj = torch.nn.Linear(1, 1, bias=False) + for proj in (self.gate_proj, self.up_proj, self.down_proj): + torch.nn.init.constant_(proj.weight, 1.0) + + def forward(self, hidden_states): + gate = self.gate_proj(hidden_states) + up = self.up_proj(hidden_states) + return self.down_proj(gate + up) + + +class _ToyLlamaDecoderLayer(torch.nn.Module): + def __init__(self): + super().__init__() + self.input_layernorm = torch.nn.Identity() + self.self_attn = _ToySelfAttention() + self.post_attention_layernorm = torch.nn.Identity() + self.mlp = _ToyMLP() + self.forward_inputs = [] + + def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs): + self.forward_inputs.append(hidden_states.detach().clone()) + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + return hidden_states + + +def test_run_layer_stage_replays_untouched_layer_outputs_when_all_modules_skipped(monkeypatch): + observed_layer_inputs = [] + + def fake_run_subset_stage(looper, **kwargs): + observed_layer_inputs.append( + ( + kwargs["layer_index"], + kwargs["plan"].subset_index, + kwargs["layer_inputs"][0][0].detach().clone(), + ) + ) + return SubsetStageResult( + processed_subset={}, + layer_inputs=kwargs["layer_inputs"], + plan=kwargs["plan"], + ) + + monkeypatch.setattr("gptqmodel.looper.stage_layer.run_subset_stage", fake_run_subset_stage) + monkeypatch.setattr("gptqmodel.looper.stage_layer.find_modules", lambda *_, **__: {}) + + class DummyPB: + def __init__(self, iterable): + self._iterable = list(iterable) + self.current_iter_step = 0 + + def __iter__(self): + return iter(self._iterable) + + def __len__(self): + return len(self._iterable) + + def manual(self): + return self + + def set(self, **kwargs): + return self + + def title(self, *_): + return self + + def subtitle(self, *_): + return self + + def draw(self): + return self + + def next(self): + return self + + def close(self): + return self + + class DummyLogger: + def pb(self, iterable): + return DummyPB(iterable) + + def info(self, *_, **__): + return None + + def debug(self, *_, **__): + return None + + def warning(self, *_, **__): + return None + + warn = warning + + def error(self, *_, **__): + return None + + class DummyProcessor: + def __init__(self, initial_inputs): + self.execution_config = ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + fwd_all_modules_in_single_pass=False, + subset_forward_early_stop=True, + ) + self.inputs_cache = types.SimpleNamespace( + layer_inputs=initial_inputs, + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + ) + self.calibration_dataset = [] + self.log = [] + self.tasks = {} + + def collect_memory_info(self, *_): + return None + + def clear_cache_data(self): + self.tasks = {} + self.inputs_cache.layer_inputs = [] + + def receive_layer_inputs(self, inputs): + self.inputs_cache.layer_inputs = inputs + + def set_fwd_time(self, *_): + return None + + def name(self): + return "GPTQProcessor" + + def submodule_finalize(self, *_, **__): + return None + + def finalize(self, *_, **__): + return None + + def log_plotly(self): + return None + + class DummyGptqModel: + def __init__(self): + self.model = torch.nn.Module() + self.model.config = types.SimpleNamespace(model_type="llama") + self.quantize_config = QuantizeConfig( + bits=4, + group_size=128, + offload_to_disk=False, + wait_for_submodule_finalizers=True, + dynamic={ + r"-:^model\.layers\.0\.": {}, + }, + ) + self.lm_head = None + + def pre_quantize(self, module): + return module + + def post_quantize(self, module): + return module + + def lm_head_pre_quantize_generate_hook(self, value): + return value + + class DummyLooper: + def __init__(self, layers, initial_inputs): + self.gptq_model = DummyGptqModel() + self.processors = [DummyProcessor(initial_inputs)] + self._quant_devices = [torch.device("cpu")] + self._module_device_map = {} + self._quant_device_lock = threading.Lock() + self._moe_subset_threshold = 16 + self._dense_quant_devices = [torch.device("cpu")] + self._moe_quant_devices = [torch.device("cpu")] + self._dense_vram_strategy = types.SimpleNamespace() + self._moe_vram_strategy = types.SimpleNamespace() + self._dense_vram_strategy_explicit = False + self._moe_vram_strategy_explicit = False + self._current_subset = None + self.support_batch_quantize = False + self.moe_routing_override = None + self.moe_routing_bypass = False + self.forward_layer_indices = [] + self.layers = layers + + def _run_forward_batches(self, **kwargs): + self.forward_layer_indices.append(kwargs["layer_index"]) + outputs = [] + for batch_inputs in kwargs["layer_inputs"]: + hidden_states = batch_inputs[0] + output = kwargs["module"]( + hidden_states=hidden_states, + attention_mask=None, + position_ids=None, + ) + outputs.append([output]) + return outputs + + def _check_loop_stop(self): + return False + + def _is_attention_module_name(self, name): + return name.startswith("self_attn.") + + def _extract_moe_group_key(self, _name): + return None + + def _resolve_batch_total(self, _num_batches, layer_inputs): + return len(layer_inputs) + + def _collect_row_counts(self, layer_inputs): + return [int(batch[0].shape[0]) for batch in layer_inputs] + + def _emit_layer_complete(self, *, layer_idx, submodule_finalized, raise_in_place): + return None + + def _request_loop_stop(self, exc): + self._stop_exc = exc + + def _subset_event_dispatch(self, *kwargs): + return None + + def register_dangling_thread(self, thread): + return None + + def create_named_modules( + self, + module, + full, + is_lm_head_module, + layer_index, + layers_prefix, + names, + processor, + fallback, + layer_module=None, + ) -> Dict[str, NamedModule]: + subset = {} + for name in names: + full_name = f"{layers_prefix}.{layer_index}.{name}" + if self.gptq_model.quantize_config.dynamic_get(layer_name=full_name) is False: + continue + subset[name] = NamedModule( + module.get_submodule(name), + name=name, + full_name=full_name, + layer_index=layer_index, + ) + return subset + + input_tensor = torch.tensor([[[2.0]]]) + layers = [_ToyLlamaDecoderLayer(), _ToyLlamaDecoderLayer()] + looper = DummyLooper(layers, initial_inputs=[[input_tensor.clone()]]) + processor = looper.processors[0] + pb = DummyPB(range(2)) + processor.layer_count = 2 + processor.pb = pb + + run_layer_stage( + looper, + layers=layers, + layer_modules=[ + ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], + ["self_attn.o_proj"], + ["mlp.gate_proj", "mlp.up_proj"], + ["mlp.down_proj"], + ], + planning_layer_modules=[ + ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"], + ["self_attn.o_proj"], + ["mlp.gate_proj", "mlp.up_proj"], + ["mlp.down_proj"], + ], + layers_prefix="model.layers", + fallback=True, + shared_kv_cache_dict={}, + pb=pb, + layer_count=2, + region_timer=None, + finalize_progress_cls=FinalizeProgressInfo, + logger=DummyLogger(), + ) + + layer1_inputs = [ + layer_input + for layer_idx, _subset_idx, layer_input in observed_layer_inputs + if layer_idx == 1 + ] + expected_layer0_output = input_tensor * 6.0 + + assert looper.forward_layer_indices == [0] + assert len(layers[0].forward_inputs) == 1 + assert torch.allclose(layers[0].forward_inputs[0], input_tensor) + assert layer1_inputs + assert all(torch.allclose(layer_input, expected_layer0_output) for layer_input in layer1_inputs) + + +def test_capture_pristine_group_context_preserves_untouched_layer_io(monkeypatch): + observed = {} + sentinel_outputs = [[torch.randn(1, 1, 1)]] + + def fake_replay_layer_outputs(*_args, **kwargs): + observed["replay_kwargs"] = kwargs + return sentinel_outputs + + monkeypatch.setattr("gptqmodel.looper.stage_layer._replay_layer_outputs", fake_replay_layer_outputs) + + class DummyProcessor: + def uses_grouped_optimization(self): + return True + + def receive_layer_forward_context(self, **kwargs): + observed["receive_kwargs"] = kwargs + + tensor = torch.randn(1, 1, 1) + subset_plan = SubsetPlan( + modules={}, + subset_index=0, + subset_total=1, + execute_forward=True, + replay_after_process=True, + forward_mode="serial", + batch_count=1, + forward_row_counts=[1], + forward_total_rows=1, + moe_groups={}, + forward_device_map={}, + calibration_coverage_policy=CalibrationCoveragePolicy( + validate_input_coverage=False, + fallback_enabled=True, + prune_uncovered_modules=False, + record_dynamic_exclusions=False, + ), + module_chunks=[{}], + ) + + _capture_pristine_group_context( + looper=types.SimpleNamespace(), + processor=DummyProcessor(), + module=torch.nn.Identity(), + pristine_module=None, + subset_plans=[subset_plan], + layer_inputs=[[tensor]], + layer_input_kwargs=[{}], + position_ids=[], + attention_masks=[], + cur_layer_device=torch.device("cpu"), + is_lm_head_module=False, + shared_kv_cache_dict={}, + layer_index=0, + layer_descriptor="model.layers.0", + full={}, + log=None, + region_timer=None, + ) + + assert observed["replay_kwargs"]["replay_plan"] is None + assert observed["receive_kwargs"]["layer_outputs"] is sentinel_outputs + assert observed["receive_kwargs"]["layer_inputs"] == [[tensor]] + assert observed["receive_kwargs"]["layer_input_kwargs"] == [{}] + assert observed["receive_kwargs"]["subset_total"] == 1 + + +def test_masked_hook_wrapper_trims_left_padded_inputs_before_add_batch(): + looper = ModuleLooper.__new__(ModuleLooper) + looper.gptq_model = types.SimpleNamespace(quant_region_timer=None) + + class _FakeTask: + def __init__(self): + self.add_batch_input = None + + def add_batch(self, inp, out, batch_index=None): + self.add_batch_input = inp + + processor = types.SimpleNamespace() + task = _FakeTask() + + input_ids = torch.tensor( + [ + [[1.0, 1.0], [2.0, 2.0], [30.0, 30.0], [40.0, 40.0]], + [[3.0, 3.0], [4.0, 4.0], [50.0, 50.0], [60.0, 60.0]], + ], + dtype=torch.float32, + ) + + attention_mask = torch.tensor( + [ + [0, 0, 1, 1], + [1, 1, 0, 0], + ], + dtype=torch.bool, + ) + looper._set_processor_mask(processor, attention_mask) + + def inner_hook(module, hook_inputs, hook_output): + task.add_batch(hook_inputs[0], torch.empty(0)) + return module, hook_inputs, hook_output + + wrapped_hook = looper._masked_hook_wrapper(processor, inner_hook, "test") + wrapped_hook( + None, + (input_ids,), + torch.empty((2, 4, 2)), + ) + + assert task.add_batch_input is not None + assert task.add_batch_input.shape == (4, 2) + assert torch.equal( + task.add_batch_input, + torch.tensor( + [ + [30.0, 30.0], + [40.0, 40.0], + [3.0, 3.0], + [4.0, 4.0], + ], + dtype=torch.float32, + ), + ) diff --git a/tests/test_startup_banner.py b/tests/test_startup_banner.py new file mode 100644 index 000000000..936e68803 --- /dev/null +++ b/tests/test_startup_banner.py @@ -0,0 +1,70 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import importlib.util +from pathlib import Path + + +MODULE_PATH = Path(__file__).resolve().parents[1] / "gptqmodel" / "_banner.py" +MODULE_SPEC = importlib.util.spec_from_file_location("gptqmodel_banner_test_module", MODULE_PATH) +assert MODULE_SPEC is not None +assert MODULE_SPEC.loader is not None + +banner_module = importlib.util.module_from_spec(MODULE_SPEC) +MODULE_SPEC.loader.exec_module(banner_module) + + +def test_build_startup_banner_aligns_versions(): + banner = banner_module.build_startup_banner( + "LOGO\n", + gptqmodel_version="5.8.0", + transformers_version="5.3.0", + torch_version="2.10.0+cu130", + triton_version="3.6.0", + ) + + lines = banner.splitlines() + assert lines[0] == "LOGO" + assert lines[1].strip().endswith("5.8.0") + assert lines[2].strip().endswith("5.3.0") + assert lines[3].strip().endswith("2.10.0+cu130") + assert lines[4].strip().endswith("3.6.0") + assert lines[1].startswith("GPT-QModel") + assert lines[2].startswith("Transformers") + assert lines[3].startswith("Torch") + assert lines[4].startswith("Triton") + assert {line.index(":") for line in lines[1:]} == {13} + + +def test_build_startup_banner_skips_missing_optional_versions(): + banner = banner_module.build_startup_banner( + "LOGO\n", + gptqmodel_version="5.8.0", + transformers_version="5.3.0", + torch_version="2.10.0+cu130", + ) + + assert "Triton version" not in banner + assert "Triton" not in banner + + +def test_get_startup_banner_resolves_optional_versions(monkeypatch): + def fake_resolve(package_names): + if tuple(package_names) == banner_module.TRITON_PACKAGE_CANDIDATES: + return "3.6.0" + raise AssertionError(f"Unexpected package candidates: {package_names}") + + monkeypatch.setattr( + banner_module, + "resolve_installed_package_version", + fake_resolve, + ) + + banner = banner_module.get_startup_banner( + "LOGO\n", + gptqmodel_version="5.8.0", + transformers_version="5.3.0", + torch_version="2.10.0+cu130", + ) + + assert any(line.startswith("Triton") and line.endswith("3.6.0") for line in banner.splitlines()) diff --git a/tests/test_stream.py b/tests/test_stream.py index 3800beb23..8e4b7de3d 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -86,3 +86,41 @@ def test_stream_tensor_dict_to_cpu_cuda_background_release_preserves_events(): with state_lock: assert not state.get("streaming_events"), "stream_sync should clear pending tickets" assert "tensor" not in state.get("streaming_event_map", {}), "event map entry should be removed after sync" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for mixed stream tests") +def test_stream_tensor_dict_to_cpu_mixed_cuda_and_cpu_payload(): + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + + payload = { + "cuda_tensor": torch.randn(8, 8, device=device, dtype=torch.float16), + "cpu_bias": torch.randn(8, dtype=torch.float32), + } + state: dict[str, object] = {} + state_lock = threading.RLock() + stored: dict[str, torch.Tensor] = {} + + result = stream_tensor_dict_to_cpu( + payload, + store_callback=lambda items: stored.update(items), + state=state, + state_lock=state_lock, + ) + + assert result["cuda_tensor"].device.type == "cpu" + assert result["cuda_tensor"].is_pinned() + assert result["cpu_bias"].device.type == "cpu" + torch.testing.assert_close(result["cpu_bias"], payload["cpu_bias"]) + + with state_lock: + assert stored["cuda_tensor"] is result["cuda_tensor"] + assert stored["cpu_bias"] is result["cpu_bias"] + event_map = state.get("streaming_event_map", {}) + assert "cuda_tensor" in event_map + assert "cpu_bias" not in event_map + + stream_sync(state, state_lock) + + torch.testing.assert_close(result["cuda_tensor"].cpu(), payload["cuda_tensor"].cpu()) + torch.testing.assert_close(result["cpu_bias"], payload["cpu_bias"]) diff --git a/tests/test_structure.py b/tests/test_structure.py new file mode 100644 index 000000000..0cab78c69 --- /dev/null +++ b/tests/test_structure.py @@ -0,0 +1,47 @@ +import torch.nn as nn + +from gptqmodel.utils.structure import print_module_tree + + +class DummyBlock(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(4, 4) + + +class DummyStackModel(nn.Module): + def __init__(self, num_layers: int): + super().__init__() + self.layers = nn.ModuleList([DummyBlock() for _ in range(num_layers)]) + self.heads = nn.ModuleList([nn.Linear(4, 4) for _ in range(num_layers)]) + self.lm_head = nn.Linear(4, 4) + + +def test_print_module_tree_caps_layer_stacks_by_default(capsys): + model = DummyStackModel(num_layers=6) + + print_module_tree(model, color=False, show_all=False) + captured = capsys.readouterr() + output = captured.out + + assert "model.layers.0: DummyBlock" in output + assert "model.layers.1: DummyBlock" in output + assert "model.layers.2: DummyBlock" in output + assert "model.layers.3: DummyBlock" in output + assert "model.layers.4: DummyBlock" not in output + assert "model.layers.5: DummyBlock" not in output + assert "collapsed (repeats 4..5, per-layer" in output + assert "model.heads.4: Linear" in output + assert "model.lm_head: Linear" in output + + +def test_print_module_tree_can_show_all_layers(capsys): + model = DummyStackModel(num_layers=6) + + print_module_tree(model, color=False, show_all=False, layers_show=None) + captured = capsys.readouterr() + output = captured.out + + assert "model.layers.4: DummyBlock" in output + assert "model.layers.5: DummyBlock" in output + assert "collapsed (repeats" not in output diff --git a/tests/test_subset_plan.py b/tests/test_subset_plan.py new file mode 100644 index 000000000..b43889afa --- /dev/null +++ b/tests/test_subset_plan.py @@ -0,0 +1,500 @@ +import sys +import types +from unittest.mock import MagicMock + +import torch + +from gptqmodel.looper.loop_processor import ExecutionConfig +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.looper.stage_subset import build_layer_subset_plans, build_subset_plan +from gptqmodel.quantization.config import VramStrategy + + +def _make_named_module(name: str, layer_index: int = 0) -> NamedModule: + return NamedModule( + torch.nn.Linear(4, 4, bias=False), + name=name, + full_name=f"model.layers.{layer_index}.{name}", + layer_index=layer_index, + ) + + +def _planning_blocks(*blocks) -> list[list[str]]: + planning_blocks = [] + for block in blocks: + if isinstance(block, str): + planning_blocks.append([block]) + else: + planning_blocks.append(list(block)) + return planning_blocks + + +def _make_looper(): + looper = MagicMock() + looper.gptq_model = types.SimpleNamespace( + lm_head="lm_head", + quantize_config=types.SimpleNamespace( + auto_forward_data_parallel=True, + moe=None, + ), + ) + looper._is_attention_module_name.return_value = False + looper._extract_moe_group_key.return_value = None + looper._moe_subset_threshold = 2 + looper._quant_devices = [torch.device("cpu")] + looper._dense_quant_devices = [torch.device("cpu")] + looper._moe_quant_devices = [torch.device("cpu")] + looper._dense_vram_strategy = VramStrategy.EXCLUSIVE + looper._moe_vram_strategy = VramStrategy.EXCLUSIVE + looper._dense_vram_strategy_explicit = False + looper._moe_vram_strategy_explicit = False + looper._resolve_batch_total.return_value = 2 + looper._collect_row_counts.return_value = [3, 2] + return looper + + +class _StubProcessor: + def __init__(self, execution_config: ExecutionConfig): + self.execution_config = execution_config + + +def test_build_subset_plan_skips_forward_for_no_forward_processor(): + looper = _make_looper() + processor = _StubProcessor(ExecutionConfig(require_fwd=False)) + subset = {"mlp.down_proj": _make_named_module("mlp.down_proj")} + + plan = build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=0, + subset_total=1, + full=subset, + fallback=None, + layer_inputs=[[torch.zeros(1, 4)]], + ) + + assert plan.execute_forward is False + assert plan.replay_after_process is False + assert plan.batch_count == 0 + assert plan.forward_row_counts == [] + assert plan.forward_total_rows == 1 + assert plan.forward_mode == "parallel" + assert plan.module_chunks == [subset] + assert plan.calibration_coverage_policy.validate_input_coverage is False + + +def test_build_subset_plan_balanced_moe_uses_serial_forward_and_device_map(): + looper = _make_looper() + looper._quant_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + looper._moe_quant_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + looper._moe_vram_strategy = VramStrategy.BALANCED + looper._moe_vram_strategy_explicit = True + + def _group_key(name: str): + parts = name.split(".") + if "experts" not in parts: + return None + expert_index = parts.index("experts") + return ".".join(parts[:expert_index + 2]) + + looper._extract_moe_group_key.side_effect = _group_key + + processor = _StubProcessor( + ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + ) + ) + subset = { + "mlp.experts.0.gate_proj": _make_named_module("mlp.experts.0.gate_proj"), + "mlp.experts.0.up_proj": _make_named_module("mlp.experts.0.up_proj"), + "mlp.experts.1.gate_proj": _make_named_module("mlp.experts.1.gate_proj"), + "mlp.experts.1.up_proj": _make_named_module("mlp.experts.1.up_proj"), + } + + plan = build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=0, + subset_total=1, + full=subset, + fallback=True, + layer_inputs=[[torch.zeros(3, 4)], [torch.zeros(2, 4)]], + ) + + assert plan.execute_forward is True + assert plan.replay_after_process is True + assert plan.forward_mode == "serial" + assert plan.subset_forward_serial is True + assert plan.batch_count == 2 + assert plan.forward_row_counts == [3, 2] + assert plan.forward_total_rows == 5 + assert plan.forward_device_map == { + "mlp.experts.0.gate_proj": torch.device("cuda:0"), + "mlp.experts.0.up_proj": torch.device("cuda:0"), + "mlp.experts.1.gate_proj": torch.device("cuda:1"), + "mlp.experts.1.up_proj": torch.device("cuda:1"), + } + + +def test_build_subset_plan_balanced_moe_pins_untouched_modules_to_baseline_device(): + looper = _make_looper() + looper._quant_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + looper._moe_quant_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + looper._moe_vram_strategy = VramStrategy.BALANCED + looper._moe_vram_strategy_explicit = True + + def _group_key(name: str): + parts = name.split(".") + if "experts" not in parts: + return None + expert_index = parts.index("experts") + return ".".join(parts[:expert_index + 2]) + + looper._extract_moe_group_key.side_effect = _group_key + + processor = _StubProcessor( + ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + ) + ) + subset = { + "mlp.experts.0.gate_proj": _make_named_module("mlp.experts.0.gate_proj"), + "mlp.experts.0.up_proj": _make_named_module("mlp.experts.0.up_proj"), + "mlp.experts.1.gate_proj": _make_named_module("mlp.experts.1.gate_proj"), + "mlp.experts.1.up_proj": _make_named_module("mlp.experts.1.up_proj"), + } + full = {name: named.module for name, named in subset.items()} + full["self_attn.o_proj"] = torch.nn.Linear(4, 4, bias=False) + + plan = build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=0, + subset_total=1, + full=full, + fallback=True, + layer_inputs=[[torch.zeros(3, 4)], [torch.zeros(2, 4)]], + ) + + assert plan.forward_device_map["mlp.experts.0.gate_proj"] == torch.device("cuda:0") + assert plan.forward_device_map["mlp.experts.1.gate_proj"] == torch.device("cuda:1") + assert plan.forward_device_map["self_attn.o_proj"] == torch.device("cpu") + assert plan.restore_forward_device_overrides is False + assert subset["mlp.experts.0.gate_proj"].state["preferred_quant_device"] == torch.device("cuda:0") + assert subset["mlp.experts.1.gate_proj"].state["preferred_quant_device"] == torch.device("cuda:1") + + +def test_build_subset_plan_split_pools_reserve_dense_and_moe_devices(): + looper = _make_looper() + looper._quant_devices = [ + torch.device("cuda:0"), + torch.device("cuda:1"), + torch.device("cuda:2"), + ] + looper._dense_quant_devices = [torch.device("cuda:0")] + looper._moe_quant_devices = [torch.device("cuda:1"), torch.device("cuda:2")] + looper._dense_vram_strategy = VramStrategy.EXCLUSIVE + looper._moe_vram_strategy = VramStrategy.BALANCED + looper._dense_vram_strategy_explicit = True + looper._moe_vram_strategy_explicit = True + + def _group_key(name: str): + parts = name.split(".") + if "experts" not in parts: + return None + expert_index = parts.index("experts") + return ".".join(parts[:expert_index + 2]) + + looper._extract_moe_group_key.side_effect = _group_key + + processor = _StubProcessor( + ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + ) + ) + subset = { + "mlp.experts.0.gate_proj": _make_named_module("mlp.experts.0.gate_proj"), + "mlp.experts.0.up_proj": _make_named_module("mlp.experts.0.up_proj"), + "mlp.experts.1.gate_proj": _make_named_module("mlp.experts.1.gate_proj"), + "mlp.experts.1.up_proj": _make_named_module("mlp.experts.1.up_proj"), + "mlp.experts.2.gate_proj": _make_named_module("mlp.experts.2.gate_proj"), + "mlp.experts.2.up_proj": _make_named_module("mlp.experts.2.up_proj"), + } + full = {name: named.module for name, named in subset.items()} + full["self_attn.o_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.shared_expert.down_proj"] = torch.nn.Linear(4, 4, bias=False) + + plan = build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=0, + subset_total=1, + full=full, + fallback=True, + layer_inputs=[[torch.zeros(3, 4)]], + planning_layer_modules=_planning_blocks( + ("self_attn.q_norm:!", "self_attn.q_proj", "self_attn.k_norm:!", "self_attn.k_proj", "self_attn.v_proj"), + ("self_attn.o_proj",), + ("mlp.experts.0.gate_proj", "mlp.experts.0.up_proj", "mlp.experts.1.gate_proj", "mlp.experts.1.up_proj"), + ), + ) + + assert plan.forward_mode == "serial" + assert plan.restore_forward_device_overrides is False + assert plan.forward_device_map["mlp.experts.0.gate_proj"] == torch.device("cuda:1") + assert plan.forward_device_map["mlp.experts.1.gate_proj"] == torch.device("cuda:2") + assert plan.forward_device_map["mlp.experts.2.gate_proj"] == torch.device("cuda:1") + assert plan.forward_device_map["self_attn.o_proj"] == torch.device("cuda:0") + assert plan.forward_device_map["mlp.shared_expert.down_proj"] == torch.device("cuda:0") + assert subset["mlp.experts.0.gate_proj"].state["preferred_quant_device"] == torch.device("cuda:1") + assert subset["mlp.experts.1.gate_proj"].state["preferred_quant_device"] == torch.device("cuda:2") + assert subset["mlp.experts.2.gate_proj"].state["preferred_quant_device"] == torch.device("cuda:1") + + +def test_build_subset_plan_dense_balanced_keeps_qkv_group_together(): + looper = _make_looper() + looper._quant_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + looper._dense_quant_devices = [torch.device("cuda:0"), torch.device("cuda:1")] + looper._dense_vram_strategy = VramStrategy.BALANCED + looper._dense_vram_strategy_explicit = True + + def _group_key(name: str): + parts = name.split(".") + if "experts" not in parts: + return None + expert_index = parts.index("experts") + return ".".join(parts[:expert_index + 2]) + + looper._extract_moe_group_key.side_effect = _group_key + looper._is_attention_module_name.side_effect = lambda name: name.startswith("self_attn.") + + processor = _StubProcessor( + ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + ) + ) + subset = { + "self_attn.q_proj": _make_named_module("self_attn.q_proj"), + "self_attn.k_proj": _make_named_module("self_attn.k_proj"), + "self_attn.v_proj": _make_named_module("self_attn.v_proj"), + } + full = {name: named.module for name, named in subset.items()} + full["self_attn.o_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.0.gate_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.0.up_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.1.gate_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.1.up_proj"] = torch.nn.Linear(4, 4, bias=False) + + plan = build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=0, + subset_total=1, + full=full, + fallback=True, + layer_inputs=[[torch.zeros(3, 4)]], + planning_layer_modules=_planning_blocks( + ("self_attn.q_norm:!", "self_attn.q_proj", "self_attn.k_norm:!", "self_attn.k_proj", "self_attn.v_proj"), + ("self_attn.o_proj",), + ("mlp.experts.0.gate_proj", "mlp.experts.0.up_proj", "mlp.experts.1.gate_proj", "mlp.experts.1.up_proj"), + ), + ) + + assert plan.forward_mode == "serial" + assert plan.restore_forward_device_overrides is False + assert plan.forward_device_map["self_attn.q_proj"] == torch.device("cuda:0") + assert plan.forward_device_map["self_attn.k_proj"] == torch.device("cuda:0") + assert plan.forward_device_map["self_attn.v_proj"] == torch.device("cuda:0") + assert plan.forward_device_map["self_attn.o_proj"] == torch.device("cuda:1") + assert subset["self_attn.q_proj"].state["preferred_quant_device"] == torch.device("cuda:0") + assert subset["self_attn.k_proj"].state["preferred_quant_device"] == torch.device("cuda:0") + assert subset["self_attn.v_proj"].state["preferred_quant_device"] == torch.device("cuda:0") + + +def test_build_subset_plan_split_pools_pin_dense_subset_and_balance_experts(): + looper = _make_looper() + looper._quant_devices = [ + torch.device("cuda:0"), + torch.device("cuda:1"), + torch.device("cuda:2"), + ] + looper._dense_quant_devices = [torch.device("cuda:0")] + looper._moe_quant_devices = [torch.device("cuda:1"), torch.device("cuda:2")] + looper._dense_vram_strategy = VramStrategy.EXCLUSIVE + looper._moe_vram_strategy = VramStrategy.BALANCED + looper._dense_vram_strategy_explicit = True + looper._moe_vram_strategy_explicit = True + + def _group_key(name: str): + parts = name.split(".") + if "experts" not in parts: + return None + expert_index = parts.index("experts") + return ".".join(parts[:expert_index + 2]) + + looper._extract_moe_group_key.side_effect = _group_key + looper._is_attention_module_name.side_effect = lambda name: name.startswith("self_attn.") + + processor = _StubProcessor( + ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=True, + ) + ) + subset = { + "self_attn.q_proj": _make_named_module("self_attn.q_proj"), + "self_attn.k_proj": _make_named_module("self_attn.k_proj"), + "self_attn.v_proj": _make_named_module("self_attn.v_proj"), + } + full = {name: named.module for name, named in subset.items()} + full["self_attn.o_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.0.gate_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.0.up_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.1.gate_proj"] = torch.nn.Linear(4, 4, bias=False) + full["mlp.experts.1.up_proj"] = torch.nn.Linear(4, 4, bias=False) + + plan = build_subset_plan( + looper, + processor=processor, + subset=subset, + subset_index=0, + subset_total=1, + full=full, + fallback=True, + layer_inputs=[[torch.zeros(3, 4)]], + ) + + assert plan.forward_mode == "serial" + assert plan.restore_forward_device_overrides is False + assert plan.forward_device_map["self_attn.q_proj"] == torch.device("cuda:0") + assert plan.forward_device_map["self_attn.o_proj"] == torch.device("cuda:0") + assert plan.forward_device_map["mlp.experts.0.gate_proj"] == torch.device("cuda:1") + assert plan.forward_device_map["mlp.experts.1.gate_proj"] == torch.device("cuda:2") + assert subset["self_attn.q_proj"].state["preferred_quant_device"] == torch.device("cuda:0") + assert subset["self_attn.k_proj"].state["preferred_quant_device"] == torch.device("cuda:0") + assert subset["self_attn.v_proj"].state["preferred_quant_device"] == torch.device("cuda:0") + + +def test_build_layer_subset_plans_merges_groups_for_single_pass_processors(): + looper = _make_looper() + requested_name_groups = [] + + def _create_named_modules( + module, + full, + is_lm_head_module, + layer_index, + layers_prefix, + names, + processor, + fallback, + layer_module=None, + ): + requested_name_groups.append(list(names)) + return {name: _make_named_module(name, layer_index=layer_index) for name in names} + + looper.create_named_modules.side_effect = _create_named_modules + + processor = _StubProcessor( + ExecutionConfig( + require_fwd=True, + fwd_replay_after_process=False, + fwd_all_modules_in_single_pass=True, + ) + ) + + plans = build_layer_subset_plans( + looper, + processor=processor, + module=torch.nn.Linear(4, 4), + layer_modules=[["self_attn.q_proj"], ["mlp.down_proj"]], + planning_layer_modules=_planning_blocks( + ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"), + ("mlp.down_proj",), + ), + layer_inputs=[[torch.zeros(1, 4)]], + full={}, + is_lm_head_module=False, + layer_index=3, + layers_prefix="model.layers", + fallback=True, + ) + + assert requested_name_groups == [["self_attn.q_proj", "mlp.down_proj"]] + assert len(plans) == 1 + assert plans[0].subset_index == 0 + assert plans[0].subset_total == 1 + assert list(plans[0].modules.keys()) == ["self_attn.q_proj", "mlp.down_proj"] + + +def test_emit_moe_parallel_quant_subset_telemetry_reports_gil_and_worker_fanout(monkeypatch): + emitted = [] + stage_subset_module = sys.modules[build_subset_plan.__module__] + + monkeypatch.setattr( + stage_subset_module, + "emit_device_telemetry", + lambda event, **fields: emitted.append((event, fields)), + ) + monkeypatch.setattr(stage_subset_module, "has_gil_control", lambda: True) + monkeypatch.setattr(stage_subset_module, "has_gil_disabled", lambda: True) + monkeypatch.setattr( + stage_subset_module.DEVICE_THREAD_POOL, + "_collect_state_snapshot", + lambda: { + "workers": {"cuda:1": 4, "cuda:2": 4}, + "total_workers": 8, + "total_inflight": 2, + }, + ) + + named = _make_named_module("mlp.experts.0.gate_proj") + plan = stage_subset_module.SubsetPlan( + modules={named.name: named}, + subset_index=0, + subset_total=1, + execute_forward=True, + replay_after_process=True, + forward_mode="serial", + batch_count=1, + forward_row_counts=[1], + forward_total_rows=1, + moe_groups={"mlp.experts.0": [named.name]}, + forward_device_map={}, + calibration_coverage_policy=stage_subset_module.CalibrationCoveragePolicy( + validate_input_coverage=False, + fallback_enabled=True, + prune_uncovered_modules=False, + record_dynamic_exclusions=False, + ), + module_chunks=[{named.name: named}], + ) + + stage_subset_module._emit_moe_parallel_quant_subset_telemetry( + plan=plan, + quant_target_devices={ + named.name: torch.device("cuda:1"), + "mlp.experts.0.up_proj": torch.device("cuda:2"), + }, + futures_count=2, + layer_index=3, + ) + + assert len(emitted) == 1 + event, fields = emitted[0] + assert event == "moe_parallel_quant_subset" + assert fields["layer_index"] == 3 + assert fields["submitted_tasks"] == 2 + assert fields["quant_devices"] == ["cuda:1", "cuda:2"] + assert fields["thread_pool_workers"] == {"cuda:1": 4, "cuda:2": 4} + assert fields["python_gil_disabled"] is True + assert fields["free_threaded_parallel_quant_active"] is True diff --git a/tests/test_tensor_parallel_padder.py b/tests/test_tensor_parallel_padder.py new file mode 100644 index 000000000..3457c4984 --- /dev/null +++ b/tests/test_tensor_parallel_padder.py @@ -0,0 +1,151 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import copy + +import pytest +import torch +import torch.nn.functional as F + +from gptqmodel.looper.module_preprocessor import ModulePreProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.quantization.config import QuantizeConfig, TensorParallelPadderConfig +from gptqmodel.quantization.gptq import GPTQ + + +@pytest.fixture(autouse=True) +def _disable_device_smi(monkeypatch): + monkeypatch.setattr(ModulePreProcessor, "_init_device_smi_handles", lambda self: {}) + monkeypatch.setattr(ModulePreProcessor, "_init_cpu_device_handle", lambda self: None) + + +def _build_preprocessor(qcfg: QuantizeConfig) -> ModulePreProcessor: + return ModulePreProcessor( + tokenizer=None, + qcfg=qcfg, + calibration=[], + prepare_dataset_func=None, + calibration_concat_size=None, + calibration_sort=None, + batch_size=1, + ) + + +def test_tensor_parallel_padder_applies_zero_pad_metadata(): + linear = torch.nn.Linear(10, 7, bias=False) + named = NamedModule(linear, name="proj", full_name="layer.0.proj", layer_index=0) + + qcfg = QuantizeConfig( + bits=4, + mock_quantization=True, + preprocessors=[TensorParallelPadderConfig()], + ) + qcfg.group_size = -1 + qcfg.desc_act = False + qcfg.act_group_aware = False + + _build_preprocessor(qcfg).preprocess(named) + + pad_info = named.state["tp_pad_info"] + assert pad_info["pad_cols"] == 6 + assert pad_info["original_columns"] == 10 + assert named.state["preprocessor_pipeline"][0]["code"] == "tensor_parallel_padder" + + gptq = GPTQ(named, qcfg) + gptq.quantizer.configure(perchannel=True) + + assert gptq._tp_pad_cols == 6 + assert gptq.columns == 16 + + inputs = torch.randn(32, 10) + outputs = linear(inputs) + gptq.add_batch(inputs, outputs) + + Q, scales, zeros, g_idx, *_ = gptq.quantize(blocksize=16) + + assert Q.shape == linear.weight.shape + assert scales.shape[-1] <= linear.weight.shape[1] + assert zeros.shape[-1] <= linear.weight.shape[1] + assert g_idx.numel() == linear.weight.shape[1] + + gptq.free() + assert "tp_pad_info" not in named.state + + +@pytest.mark.parametrize( + ("group_size", "expected_target_multiple"), + [ + (-1, 8), + (32, 32), + (64, 64), + (12, 24), + ], +) +def test_tensor_parallel_padder_uses_group_size_lcm(group_size: int, expected_target_multiple: int): + linear = torch.nn.Linear(10, 7, bias=False) + named = NamedModule(linear, name="proj", full_name="layer.0.proj", layer_index=0) + + qcfg = QuantizeConfig( + bits=4, + mock_quantization=True, + preprocessors=[TensorParallelPadderConfig()], + ) + qcfg.group_size = group_size + qcfg.desc_act = False + qcfg.act_group_aware = False + + _build_preprocessor(qcfg).preprocess(named) + + assert named.state["tp_pad_info"]["target_multiple"] == expected_target_multiple + + +def test_tensor_parallel_padder_does_not_change_quantized_matmul_output(): + torch.manual_seed(17) + + linear = torch.nn.Linear(10, 7, bias=False, dtype=torch.float32).eval() + calibration_inputs = torch.randn(64, 10, dtype=torch.float32) + calibration_outputs = linear(calibration_inputs) + eval_inputs = torch.randn(19, 10, dtype=torch.float32) + + baseline_named = NamedModule( + copy.deepcopy(linear), + name="proj", + full_name="layer.0.proj", + layer_index=0, + ) + baseline_qcfg = QuantizeConfig(bits=4, group_size=12) + baseline_qcfg.desc_act = False + baseline_qcfg.act_group_aware = False + + baseline_gptq = GPTQ(baseline_named, baseline_qcfg) + baseline_gptq.quantizer.configure(perchannel=True) + baseline_gptq.add_batch(calibration_inputs, calibration_outputs) + baseline_weight, *_ = baseline_gptq.quantize(blocksize=16) + baseline_gptq.free() + + padded_named = NamedModule( + copy.deepcopy(linear), + name="proj", + full_name="layer.0.proj", + layer_index=0, + ) + padded_qcfg = QuantizeConfig( + bits=4, + group_size=12, + preprocessors=[TensorParallelPadderConfig()], + ) + padded_qcfg.desc_act = False + padded_qcfg.act_group_aware = False + + _build_preprocessor(padded_qcfg).preprocess(padded_named) + + padded_gptq = GPTQ(padded_named, padded_qcfg) + padded_gptq.quantizer.configure(perchannel=True) + padded_gptq.add_batch(calibration_inputs, calibration_outputs) + padded_weight, *_ = padded_gptq.quantize(blocksize=16) + padded_gptq.free() + + baseline_output = F.linear(eval_inputs, baseline_weight) + padded_output = F.linear(eval_inputs, padded_weight) + + torch.testing.assert_close(padded_output, baseline_output, rtol=0.0, atol=0.0) diff --git a/tests/test_threadx.py b/tests/test_threadx.py index 0c7f994ac..22f7bf449 100644 --- a/tests/test_threadx.py +++ b/tests/test_threadx.py @@ -18,6 +18,8 @@ pytestmark = [ pytest.mark.cuda, + pytest.mark.cpu, + pytest.mark.gpu, pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), ] @@ -332,6 +334,10 @@ def noop(): monkeypatch.setattr(torch.cuda, "empty_cache", orig_empty) +@pytest.mark.xfail( + reason="Janitor retrigger timing remains runner-sensitive under shared multi-GPU load", + strict=False, +) def test_janitor_resets_device_watermark(pool, devices_two, monkeypatch): """ Ensure devices that only partially progressed before a GC pass still trigger @@ -370,8 +376,9 @@ def noop(): assert first_pass.wait(timeout=2.0) - # Device 1 finishes two more tasks (total=3) and should trigger another GC. - for _ in range(2): + # Device 1 was part of the first sweep, so it needs a fresh threshold worth + # of completions before the next GC pass is eligible. + for _ in range(3): pool.do(d1, noop) assert second_pass.wait(timeout=2.0) @@ -959,4 +966,3 @@ def cpu_task(): assert fut0.result(timeout=2) == 120 assert fut1.result(timeout=2) == 120 assert f_blocked.result(timeout=2) == 80 - diff --git a/tests/test_tiny_moe_quant_smoke.py b/tests/test_tiny_moe_quant_smoke.py new file mode 100644 index 000000000..4f14380ee --- /dev/null +++ b/tests/test_tiny_moe_quant_smoke.py @@ -0,0 +1,165 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +"""Tiny end-to-end MoE quantization smoke coverage. + +This test intentionally avoids the large on-disk MoE fixtures used elsewhere in +the suite. It builds a one-layer Qwen3 MoE model and a tiny local tokenizer in +temporary directories, then runs the real GPTQ save/load/quantize/save/reload +flow against that fixture. + +The goal is not kernel benchmarking or quality evaluation. The goal is a cheap +regression guard for the MoE lifecycle: +1. native HF MoE model can be loaded through GPT-QModel +2. MoE routing override can drive expert quantization +3. the quantized checkpoint can be reloaded +4. every expert gate/up/down projection is exported as a quantized linear +""" + +import os +from pathlib import Path + + +# Keep the smoke test on CPU so it stays small and works on CPU-only runners. +os.environ.setdefault("CUDA_VISIBLE_DEVICES", "") + +import pytest +from tokenizers import Tokenizer +from tokenizers.models import WordLevel +from tokenizers.pre_tokenizers import Whitespace +from tokenizers.trainers import WordLevelTrainer +from transformers import PreTrainedTokenizerFast +from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM + +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.quantization.config import ExpertsRoutingOverride, MoEConfig + + +pytestmark = [pytest.mark.cpu, pytest.mark.slow] + + +_CALIBRATION_TEXTS = [ + "tiny moe calibration sample one with enough tokens to survive minimum length filtering and trigger expert routing", + "tiny moe calibration sample two with repeated expert words to exercise the moe quantization smoke path cleanly", + "another synthetic calibration example that is intentionally verbose so token filtering does not remove it", +] * 2 + + +def _build_local_tokenizer(model_dir: Path) -> PreTrainedTokenizerFast: + """Persist a minimal tokenizer because the GPT-QModel `load()` path expects local tokenizer files.""" + + tokenizer = Tokenizer(WordLevel(unk_token="[UNK]")) + tokenizer.pre_tokenizer = Whitespace() + + trainer = WordLevelTrainer( + special_tokens=["[PAD]", "[UNK]", "[BOS]", "[EOS]"], + ) + tokenizer.train_from_iterator(_CALIBRATION_TEXTS, trainer=trainer) + + fast_tokenizer = PreTrainedTokenizerFast( + tokenizer_object=tokenizer, + bos_token="[BOS]", + eos_token="[EOS]", + unk_token="[UNK]", + pad_token="[PAD]", + ) + fast_tokenizer.save_pretrained(model_dir) + return fast_tokenizer + + +def _build_tiny_qwen3_moe_fixture(model_dir: Path) -> tuple[Qwen3MoeConfig, PreTrainedTokenizerFast]: + """Save a tiny native HF MoE checkpoint that still exercises the real qwen3_moe path. + + Qwen3Moe is the lightest native MoE family available in this repo's test + dependencies because it does not carry Qwen2 MoE's large shared-expert + default weights. + """ + + config = Qwen3MoeConfig( + num_hidden_layers=1, + hidden_size=64, + intermediate_size=128, + moe_intermediate_size=32, + num_attention_heads=4, + num_key_value_heads=4, + num_experts=4, + num_experts_per_tok=2, + vocab_size=128, + pad_token_id=0, + bos_token_id=1, + eos_token_id=2, + ) + + model = Qwen3MoeForCausalLM(config) + model.save_pretrained(model_dir) + tokenizer = _build_local_tokenizer(model_dir) + return config, tokenizer + + +def _build_calibration_dataset(tokenizer: PreTrainedTokenizerFast) -> list[dict[str, object]]: + """Return the exact calibration shape accepted by prepare_calibration_dataset().""" + + dataset = [] + for text in _CALIBRATION_TEXTS: + encoded = tokenizer(text, return_tensors="pt") + dataset.append( + { + "input_ids": encoded["input_ids"], + "attention_mask": encoded["attention_mask"], + } + ) + return dataset + + +def test_tiny_qwen3_moe_quantization_smoke(tmp_path: Path): + """Quantize and reload a tiny local MoE model, then assert all expert projections are quantized.""" + + model_dir = tmp_path / "native" + quantized_dir = tmp_path / "quantized" + + config, tokenizer = _build_tiny_qwen3_moe_fixture(model_dir) + calibration = _build_calibration_dataset(tokenizer) + + quantize_config = QuantizeConfig( + bits=4, + group_size=32, + desc_act=False, + device="cpu", + moe=MoEConfig(routing=ExpertsRoutingOverride()), + ) + + model = GPTQModel.load( + str(model_dir), + quantize_config=quantize_config, + backend=BACKEND.TORCH, + ) + model.quantize( + calibration, + batch_size=1, + backend=BACKEND.TORCH, + calibration_data_min_length=1, + ) + model.save(quantized_dir) + + quantized_model = GPTQModel.load( + str(quantized_dir), + backend=BACKEND.TORCH, + device="cpu", + ) + + # Assert the full expert set was quantized, not just whichever experts the + # natural router happened to hit in this tiny calibration sample. + modules = dict(quantized_model.named_modules()) + expected_quantized = config.num_experts * 3 + quantized_expert_modules = [] + + for expert_index in range(config.num_experts): + for suffix in ("gate_proj", "up_proj", "down_proj"): + module_name = f"model.model.layers.0.mlp.experts.{expert_index}.{suffix}" + module = modules[module_name] + assert isinstance(module, TorchLinear), module_name + quantized_expert_modules.append(module_name) + + assert len(quantized_expert_modules) == expected_quantized + assert quantized_model.quantize_config.meta_get("moe")["routing"]["class"] == "ExpertsRoutingOverride" diff --git a/tests/test_torch.py b/tests/test_torch.py index a31e76eee..3adfa1a37 100644 --- a/tests/test_torch.py +++ b/tests/test_torch.py @@ -8,10 +8,11 @@ import torch import torch.nn as nn +import gptqmodel.utils.torch as torch_utils from gptqmodel.nn_modules.qlinear import PackableQuantLinear from gptqmodel.nn_modules.qlinear.lookahead import configure_default_lookahead -from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2Linear def _mock_gptq_linear(bits: int, group_size: int, in_features: int, out_features: int) -> tuple[nn.Linear, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -72,7 +73,7 @@ def test_torch_triton_large_group_sizes(group_size: int, dtype: torch.dtype) -> linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) - torch_module = TorchQuantLinear( + torch_module = TorchLinear( bits=bits, group_size=group_size, sym=True, @@ -86,7 +87,7 @@ def test_torch_triton_large_group_sizes(group_size: int, dtype: torch.dtype) -> torch_module.post_init() try: - triton_module = TritonV2QuantLinear( + triton_module = TritonV2Linear( bits=bits, group_size=group_size, desc_act=False, @@ -123,7 +124,7 @@ def test_torch_triton_large_group_sizes(group_size: int, dtype: torch.dtype) -> def _make_module(device: torch.device): - module = TorchQuantLinear( + module = TorchLinear( bits=4, group_size=32, sym=True, @@ -149,7 +150,7 @@ def _make_module(device: torch.device): def test_gptq_post_init_creates_wf_unpack_buffers(): - module = TorchQuantLinear( + module = TorchLinear( bits=4, group_size=32, sym=True, @@ -169,6 +170,31 @@ def test_gptq_post_init_creates_wf_unpack_buffers(): assert module.wf_unsqueeze_neg_one is not None +def test_torch_quant_linear_exposes_weight_metadata(): + module = TorchLinear( + bits=4, + group_size=32, + sym=True, + desc_act=False, + in_features=64, + out_features=96, + bias=False, + pack_dtype=torch.int32, + adapter=None, + register_buffers=True, + ) + + weight = module.weight + + assert weight.device == module.qweight.device + assert weight.dtype == module.scales.dtype + assert weight.shape == torch.Size((module.out_features, module.in_features)) + assert weight.size(0) == module.out_features + assert weight.size(1) == module.in_features + assert weight.T.shape == torch.Size((module.in_features, module.out_features)) + assert ("cuda" in weight.device.type) == weight.is_cuda + + def test_cached_forward_matches_baseline(): device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") module = _make_module(device) @@ -187,6 +213,32 @@ def test_cached_forward_matches_baseline(): assert module._cached_weights[x.dtype].device.type == device.type +def test_torch_empty_cache_syncs_before_releasing_allocator(monkeypatch): + calls = [] + device = torch.device("cpu") + + monkeypatch.setattr(torch_utils, "timed_gc_collect", lambda: calls.append("gc") or 0) + monkeypatch.setattr(torch_utils, "torch_sync", lambda device=None: calls.append(("sync", device))) + monkeypatch.setattr(torch_utils, "empty_cache_for_device", lambda device: calls.append(("empty", device)) or True) + + assert torch_utils.torch_empty_cache(device=device, gc=True, sync=True) is True + assert calls == ["gc", ("sync", device), ("empty", device)] + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="requires at least 2 CUDA devices") +def test_cross_device_forward_moves_weights_to_input_device(): + module = _make_module(torch.device("cuda:1")) + module.enable_weight_cache(True) + module.clear_weight_cache() + + x = torch.randn(8, module.in_features, device=torch.device("cuda:0"), dtype=torch.float16) + out = module(x) + + assert out.device == x.device + assert x.dtype in module._cached_weights + assert module._cached_weights[x.dtype].device == x.device + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required for lookahead prefetch test") def test_lookahead_prefetch_single_step(): device = torch.device("cuda") @@ -259,7 +311,7 @@ def __init__(self): model = DummyModel() for module in model.modules(): - if isinstance(module, TorchQuantLinear): + if isinstance(module, TorchLinear): module.enable_lookahead(True) configure_default_lookahead(model) @@ -294,7 +346,7 @@ def test_cpu_dequant_parity_and_g_idx_cache_allocation(): torch.manual_seed(0) linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) - module = TorchQuantLinear( + module = TorchLinear( bits=bits, group_size=group_size, sym=True, @@ -355,7 +407,7 @@ def test_cpu_cached_dequant_num_itr_matches_packable(): torch.manual_seed(0) linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) - module = TorchQuantLinear( + module = TorchLinear( bits=bits, group_size=group_size, sym=True, diff --git a/tests/test_torch_aten_kernel_import_guard.py b/tests/test_torch_aten_kernel_import_guard.py new file mode 100644 index 000000000..1d93fb484 --- /dev/null +++ b/tests/test_torch_aten_kernel_import_guard.py @@ -0,0 +1,87 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import builtins + +import pytest +import torch + +from gptqmodel.nn_modules.qlinear.torch_aten_kernel import TorchAtenLinear, _cpu_int4pack_zero_offsets +from gptqmodel.nn_modules.qlinear.torch_aten_kernel_awq import TorchAtenAwqLinear +from gptqmodel.utils import python as python_utils + + +def test_free_threading_build_helper_uses_py_gil_disabled(monkeypatch): + monkeypatch.setattr( + python_utils.sysconfig, + "get_config_var", + lambda key: 1 if key == "Py_GIL_DISABLED" else None, + ) + assert python_utils.is_free_threading_build() + + monkeypatch.setattr( + python_utils.sysconfig, + "get_config_var", + lambda key: 0 if key == "Py_GIL_DISABLED" else None, + ) + assert not python_utils.is_free_threading_build() + + +@pytest.mark.parametrize("kernel_cls", [TorchAtenLinear, TorchAtenAwqLinear]) +def test_torch_aten_kernel_validate_once_does_not_import_external_kernels(monkeypatch, kernel_cls): + attempted = {"value": False} + original_import = builtins.__import__ + + def guarded_import(name, globals=None, locals=None, fromlist=(), level=0): + if name == "kernels" or name.startswith("kernels."): + attempted["value"] = True + raise AssertionError(f"unexpected import of {name}") + return original_import(name, globals, locals, fromlist, level) + + monkeypatch.setattr(builtins, "__import__", guarded_import) + + ok, err = kernel_cls.validate_once() + + assert not attempted["value"] + has_ops = ( + hasattr(torch.ops.aten, "_convert_weight_to_int4pack_for_cpu") + and hasattr(torch.ops.aten, "_weight_int4pack_mm_for_cpu") + ) + assert ok is has_ops + if has_ops: + assert err is None + else: + assert isinstance(err, ImportError) + + +@pytest.mark.skipif( + not ( + hasattr(torch.ops.aten, "_convert_weight_to_int4pack_for_cpu") + and hasattr(torch.ops.aten, "_weight_int4pack_mm_for_cpu") + ), + reason="CPU int4pack ATen ops are unavailable in this PyTorch build.", +) +def test_cpu_int4pack_zero_offsets_match_dense_gptq_formula(): + out_features = 16 + in_features = 32 + group_size = 32 + code = 5 + zero_code = 3 + scale = 2.0 + + unpacked_weight = torch.full((out_features, in_features), code, dtype=torch.int32) + packed_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(unpacked_weight, 1) + + scales = torch.full((1, out_features), scale, dtype=torch.bfloat16) + zero_codes = torch.full((1, out_features), zero_code, dtype=torch.uint8) + zero_offsets = _cpu_int4pack_zero_offsets(zero_codes, scales, bits=4) + + scales_and_zeros = torch.zeros((1, out_features, 2), dtype=torch.bfloat16) + scales_and_zeros[:, :, 0] = scales + scales_and_zeros[:, :, 1] = zero_offsets + + x = torch.zeros((1, in_features), dtype=torch.bfloat16) + x[0, 0] = 1 + + out = torch.ops.aten._weight_int4pack_mm_for_cpu(x, packed_weight, group_size, scales_and_zeros) + assert float(out[0, 0]) == pytest.approx(scale * (code - zero_code)) diff --git a/tests/test_torch_ops_jit_extension.py b/tests/test_torch_ops_jit_extension.py new file mode 100644 index 000000000..6695bcbd4 --- /dev/null +++ b/tests/test_torch_ops_jit_extension.py @@ -0,0 +1,605 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import sys +import threading +import time +from pathlib import Path + +from gptqmodel.utils import cpp as cpp_module + + +class _FakeSpinner: + """Capture spinner lifecycle so tests can assert compile-stage UX hooks.""" + + def __init__(self, title: str): + self.title = title + self.closed = False + + def close(self): + self.closed = True + + +class _FakeLogger: + """Collect durable info logs and spinner titles without rendering output.""" + + def __init__(self): + self.info_messages: list[str] = [] + self.spinners: list[_FakeSpinner] = [] + + def info(self, message: str): + self.info_messages.append(message) + + def spinner(self, title: str = "", *, interval: float = 0.5, tail_length: int = 4): + del interval, tail_length + spinner = _FakeSpinner(title) + self.spinners.append(spinner) + return spinner + + +def _make_loader(tmp_path: Path, **overrides) -> cpp_module.TorchOpsJitExtension: + """Construct a shared torch.ops loader with a disposable build root.""" + + params = { + "name": "unit_test_ops", + "namespace": "unit_test_ns", + "required_ops": ("kernel",), + "sources": ["unit_test.cpp"], + "build_root_env": "UNIT_TEST_BUILD_ROOT", + "default_build_root": lambda: tmp_path / "jit_build", + "display_name": "Unit Test Kernel", + "requires_cuda": False, + } + params.update(overrides) + return cpp_module.TorchOpsJitExtension(**params) + + +def test_default_jit_cflags_allow_noopt(monkeypatch): + """Guard the noopt path so callers can intentionally omit all `-O*` flags.""" + + monkeypatch.delenv("GPTQMODEL_NVCC_COMPILE_LEVEL", raising=False) + + flags = cpp_module.default_jit_cflags(opt_level=None) + + assert "-std=c++17" in flags + assert not any(flag.startswith("-O") for flag in flags) + + +def test_default_jit_cuda_cflags_respect_o2_override(monkeypatch): + """Guard the global override so kernels can be forced onto one explicit optimization level.""" + + monkeypatch.setenv("GPTQMODEL_NVCC_COMPILE_LEVEL", "O2") + + flags = cpp_module.default_jit_cuda_cflags( + opt_level="O3", + include_nvcc_threads=True, + include_ptxas_optimizations=True, + ) + + assert "-O2" in flags + assert "-O3" not in flags + assert "--optimize=2" in flags + assert flags[flags.index("-Xptxas") + 1] == "-v,-O2,-dlcm=ca" + + +def test_default_jit_cuda_cflags_respect_noopt_override(monkeypatch): + """Guard the noopt override so users can disable every emitted `-O*` flag when needed.""" + + monkeypatch.setenv("GPTQMODEL_NVCC_COMPILE_LEVEL", "NONE") + + flags = cpp_module.default_jit_cuda_cflags( + opt_level="O3", + include_nvcc_threads=True, + include_ptxas_optimizations=True, + ) + + assert not any(flag.startswith("-O") for flag in flags) + assert not any(flag.startswith("--optimize=") for flag in flags) + assert flags[flags.index("-Xptxas") + 1] == "-v,-dlcm=ca" + + +def test_default_jit_cuda_cflags_allow_quiet_ptxas(monkeypatch): + """Guard per-kernel PTXAS verbosity overrides so AWQ can suppress giant compile logs.""" + + monkeypatch.delenv("GPTQMODEL_NVCC_COMPILE_LEVEL", raising=False) + + flags = cpp_module.default_jit_cuda_cflags( + include_ptxas_optimizations=True, + include_ptxas_verbosity=False, + ) + + assert flags[flags.index("-Xptxas") + 1] == "-O3,-dlcm=ca" + + +def test_detected_cuda_wheel_include_paths_discovers_merged_and_split_layouts(monkeypatch, tmp_path): + """Guard CUDA wheel header discovery so JIT kernels can see NVIDIA pip headers.""" + + nvidia_root = tmp_path / "site-packages" / "nvidia" + (nvidia_root / "cu13" / "include").mkdir(parents=True) + (nvidia_root / "cusparse" / "include").mkdir(parents=True) + (nvidia_root / "cublas" / "include").mkdir(parents=True) + + fake_nvidia = type("FakeNvidia", (), {"__path__": [str(nvidia_root)]})() + monkeypatch.setitem(sys.modules, "nvidia", fake_nvidia) + + assert cpp_module.detected_cuda_wheel_include_paths() == [ + str(nvidia_root / "cu13" / "include"), + str(nvidia_root / "cublas" / "include"), + str(nvidia_root / "cusparse" / "include"), + ] + + +def test_detected_local_cuda_include_paths_prefers_cuda_home(monkeypatch, tmp_path): + """Guard local CUDA header discovery so JIT builds can skip wheel headers when toolkit headers exist.""" + + cuda_home = tmp_path / "cuda-toolkit" + (cuda_home / "include").mkdir(parents=True) + + monkeypatch.setattr(cpp_module, "CUDA_HOME", str(cuda_home)) + monkeypatch.delenv("CUDA_PATH", raising=False) + + assert cpp_module.detected_local_cuda_include_paths() == [str(cuda_home / "include")] + + +def test_cuda_include_paths_with_fallback_use_wheel_headers_when_local_cuda_is_incomplete(monkeypatch, tmp_path): + """Guard shared CUDA header fallback so incomplete local toolkits still build JIT extensions.""" + + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (wheel_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = cpp_module.cuda_include_paths_with_fallback( + ["/tmp/extension"], + required_header_names=("cusparse.h",), + ) + + assert include_paths == ["/tmp/extension", str(wheel_cuda_include)] + + +def test_cuda_include_paths_with_fallback_skip_wheel_headers_when_local_cuda_has_required_headers( + monkeypatch, + tmp_path, +): + """Guard shared CUDA header fallback so complete local toolkits do not mix in wheel headers.""" + + local_cuda_include = tmp_path / "local_cuda_include" + wheel_cuda_include = tmp_path / "wheel_cuda_include" + local_cuda_include.mkdir() + wheel_cuda_include.mkdir() + (local_cuda_include / "cusparse.h").write_text("// stub", encoding="utf-8") + + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: [str(local_cuda_include)]) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: [str(wheel_cuda_include)]) + + include_paths = cpp_module.cuda_include_paths_with_fallback( + ["/tmp/extension"], + required_header_names=("cusparse.h",), + ) + + assert include_paths == ["/tmp/extension"] + + +def test_cuda_cache_fingerprint_payload_includes_resolved_arch_flags(monkeypatch, tmp_path): + """Guard CUDA cache keys so stale binaries cannot cross architecture targets.""" + + loader = _make_loader(tmp_path, requires_cuda=True) + + monkeypatch.delenv("TORCH_CUDA_ARCH_LIST", raising=False) + monkeypatch.setattr(cpp_module.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(cpp_module.torch.cuda, "device_count", lambda: 1) + monkeypatch.setattr(cpp_module.torch.cuda, "get_device_capability", lambda _index: (12, 0)) + monkeypatch.setattr( + cpp_module, + "resolved_cuda_arch_flags", + lambda: [ + "-gencode=arch=compute_120,code=compute_120", + "-gencode=arch=compute_120,code=sm_120", + ], + ) + + payload = loader._cuda_cache_fingerprint_payload() + + assert payload == [ + "cuda_ext=1", + "visible_caps=12.0", + ( + "resolved_arch_flags=" + "-gencode=arch=compute_120,code=compute_120," + "-gencode=arch=compute_120,code=sm_120" + ), + ] + + +def test_default_torch_ops_build_root_ignores_removed_global_override(monkeypatch): + monkeypatch.setenv("GPTQMODEL_EXT_BUILD_BASE", "/tmp/obsolete-jit-root") + + assert cpp_module.default_torch_ops_build_root("marlin") == ( + Path.home() / ".cache" / "gptqmodel" / "torch_extensions" / "marlin" + ) + + +def test_torch_ops_jit_extension_prefers_cached_binary(monkeypatch, tmp_path): + """Guard cache reuse so startup skips expensive JIT rebuilds when ops are already built.""" + + loader = _make_loader(tmp_path) + build_root = loader.build_root() + build_root.mkdir(parents=True) + library_path = build_root / "unit_test_ops.so" + library_path.write_bytes(b"placeholder") + + state = {"ready": False} + load_library_calls = [] + compile_calls = [] + + monkeypatch.setattr(loader, "_ops_available", lambda: state["ready"]) + + def fake_load_library(path: str): + load_library_calls.append(path) + state["ready"] = True + + monkeypatch.setattr(cpp_module.torch.ops, "load_library", fake_load_library, raising=False) + monkeypatch.setattr(cpp_module, "load", lambda **kwargs: compile_calls.append(kwargs) or None) + + assert loader.load() is True + assert load_library_calls == [str(library_path)] + assert compile_calls == [] + + +def test_torch_ops_jit_extension_force_rebuild_clears_cache(monkeypatch, tmp_path): + """Guard force-rebuild mode so stale cached libraries never short-circuit a requested rebuild.""" + + loader = _make_loader(tmp_path, force_rebuild_env="UNIT_TEST_FORCE_REBUILD") + build_root = loader.build_root() + build_root.mkdir(parents=True) + stale_library = build_root / "unit_test_ops.so" + stale_library.write_bytes(b"stale") + + state = {"ready": False} + compile_calls = [] + logger = _FakeLogger() + runtime = type("RuntimeNamespace", (), {"kernel": object()})() + + monkeypatch.setenv("UNIT_TEST_FORCE_REBUILD", "1") + monkeypatch.setattr(loader, "_ops_available", lambda: state["ready"]) + monkeypatch.setattr(cpp_module, "setup_logger", lambda: logger) + monkeypatch.setattr( + cpp_module.torch.ops, + "load_library", + lambda path: (_ for _ in ()).throw(AssertionError(f"unexpected cached load: {path}")), + raising=False, + ) + + def fake_compile(**kwargs): + compile_calls.append(kwargs) + state["ready"] = True + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns", runtime, raising=False) + + monkeypatch.setattr(cpp_module, "load", fake_compile) + + assert loader.load() is True + assert len(compile_calls) == 1 + assert stale_library.exists() is False + assert any("clearing cached JIT extension" in message for message in logger.info_messages) + + +def test_torch_ops_jit_extension_emits_spinner_logs_around_compile(monkeypatch, tmp_path): + """Guard compile UX so users get explicit progress feedback before and after JIT build stalls.""" + + loader = _make_loader( + tmp_path, + extra_cflags=["-O3"], + extra_cuda_cflags=["-lineinfo"], + extra_include_paths=["/tmp/include"], + extra_ldflags=["-lm"], + ) + + state = {"ready": False} + logger = _FakeLogger() + compile_calls = [] + runtime = type("RuntimeNamespace", (), {"kernel": object()})() + + monkeypatch.setattr(loader, "_ops_available", lambda: state["ready"]) + monkeypatch.setattr(cpp_module, "setup_logger", lambda: logger) + + def fake_compile(**kwargs): + compile_calls.append(kwargs) + state["ready"] = True + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns", runtime, raising=False) + + monkeypatch.setattr(cpp_module, "load", fake_compile) + + assert loader.load() is True + assert len(compile_calls) == 1 + assert compile_calls[0]["is_python_module"] is False + assert compile_calls[0]["sources"] == ["unit_test.cpp"] + assert compile_calls[0]["extra_cflags"] == ["-O3"] + assert compile_calls[0]["extra_cuda_cflags"] == ["-lineinfo"] + assert compile_calls[0]["extra_include_paths"] == ["/tmp/include"] + assert compile_calls[0]["extra_ldflags"] == ["-lm"] + assert logger.spinners + assert logger.spinners[0].title == "Compiling extension: Unit Test Kernel..." + assert logger.spinners[0].closed is True + assert any("compiling torch.ops JIT extension" in message for message in logger.info_messages) + assert any("torch.ops JIT extension ready" in message for message in logger.info_messages) + + +def test_torch_ops_jit_extension_appends_detected_cuda_include_paths(monkeypatch, tmp_path): + """Guard CUDA JIT kwargs so detected NVIDIA wheel headers reach the compiler.""" + + loader = _make_loader( + tmp_path, + requires_cuda=True, + extra_include_paths=["/tmp/include"], + ) + + state = {"ready": False} + compile_calls = [] + runtime = type("RuntimeNamespace", (), {"kernel": object()})() + + monkeypatch.setattr(loader, "_ops_available", lambda: state["ready"]) + monkeypatch.setattr(cpp_module.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: []) + monkeypatch.setattr( + cpp_module, + "detected_cuda_wheel_include_paths", + lambda: ["/tmp/nvidia/cu13/include", "/tmp/nvidia/cusparse/include"], + ) + + def fake_compile(**kwargs): + compile_calls.append(kwargs) + state["ready"] = True + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns", runtime, raising=False) + + monkeypatch.setattr(cpp_module, "load", fake_compile) + + assert loader.load() is True + assert compile_calls[0]["extra_include_paths"] == [ + "/tmp/include", + "/tmp/nvidia/cu13/include", + "/tmp/nvidia/cusparse/include", + ] + + +def test_torch_ops_jit_extension_uses_original_compile_paths(monkeypatch, tmp_path): + local_root = tmp_path / "repo" + (local_root / "src").mkdir(parents=True) + (local_root / "include").mkdir(parents=True) + source_path = local_root / "src" / "unit_test.cpp" + include_path = local_root / "include" + source_path.write_text('#include "unit_test.h"\nint kernel() { return 1; }\n', encoding="utf-8") + (include_path / "unit_test.h").write_text("inline int unit_test_header() { return 1; }\n", encoding="utf-8") + + loader = _make_loader( + tmp_path, + sources=[str(source_path)], + extra_include_paths=[str(include_path), "/usr/local/cuda/include"], + ) + + state = {"ready": False} + compile_calls = [] + runtime = type("RuntimeNamespace", (), {"kernel": object()})() + + monkeypatch.setattr(loader, "_ops_available", lambda: state["ready"]) + + def fake_compile(**kwargs): + compile_calls.append(kwargs) + state["ready"] = True + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns", runtime, raising=False) + + monkeypatch.setattr(cpp_module, "load", fake_compile) + + assert loader.load() is True + assert compile_calls[0]["sources"] == [str(source_path)] + assert compile_calls[0]["extra_include_paths"] == [str(include_path), "/usr/local/cuda/include"] + + +def test_torch_ops_jit_extension_skips_cuda_wheel_include_paths_when_local_headers_exist(monkeypatch, tmp_path): + """Guard local-toolkit builds so wheel headers do not get mixed into one CUDA compile invocation.""" + + loader = _make_loader( + tmp_path, + requires_cuda=True, + extra_include_paths=["/tmp/include"], + ) + + state = {"ready": False} + compile_calls = [] + runtime = type("RuntimeNamespace", (), {"kernel": object()})() + + monkeypatch.setattr(loader, "_ops_available", lambda: state["ready"]) + monkeypatch.setattr(cpp_module.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: ["/usr/local/cuda/include"]) + monkeypatch.setattr( + cpp_module, + "detected_cuda_wheel_include_paths", + lambda: ["/tmp/nvidia/cu13/include", "/tmp/nvidia/cusparse/include"], + ) + + def fake_compile(**kwargs): + compile_calls.append(kwargs) + state["ready"] = True + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns", runtime, raising=False) + + monkeypatch.setattr(cpp_module, "load", fake_compile) + + assert loader.load() is True + assert compile_calls[0]["extra_include_paths"] == ["/tmp/include"] + + +def test_torch_ops_jit_extension_reuses_cached_namespace_after_first_load(monkeypatch, tmp_path): + """Guard steady-state hot paths so repeated runtime checks skip torch.ops probing after first success.""" + + loader = _make_loader(tmp_path) + runtime = type("RuntimeNamespace", (), {"kernel": object()})() + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns", runtime, raising=False) + + state = {"ready": True} + monkeypatch.setattr(loader, "_ops_available", lambda: state["ready"]) + + assert loader.load() is True + assert loader.op("kernel") is runtime.kernel + + def unexpected_probe(): + raise AssertionError("steady-state load should not re-probe torch.ops after first success") + + monkeypatch.setattr(loader, "_ops_available", unexpected_probe) + + assert loader.load() is True + assert loader.namespace_object() is runtime + assert loader.op("kernel") is runtime.kernel + + +def test_torch_ops_jit_extension_serializes_different_extensions_with_one_shared_lock(monkeypatch, tmp_path): + """Guard that different JIT extensions do not compile in parallel.""" + + loader_a = _make_loader( + tmp_path, + name="unit_test_ops_a", + namespace="unit_test_ns_a", + ) + loader_b = _make_loader( + tmp_path, + name="unit_test_ops_b", + namespace="unit_test_ns_b", + ) + + states = { + "unit_test_ops_a": False, + "unit_test_ops_b": False, + } + runtime_a = type("RuntimeNamespaceA", (), {"kernel": object()})() + runtime_b = type("RuntimeNamespaceB", (), {"kernel": object()})() + logger = _FakeLogger() + compile_tracker = { + "active": 0, + "max_active": 0, + } + compile_tracker_lock = threading.Lock() + start_barrier = threading.Barrier(3) + errors: list[Exception] = [] + + monkeypatch.setattr(loader_a, "_ops_available", lambda: states["unit_test_ops_a"]) + monkeypatch.setattr(loader_b, "_ops_available", lambda: states["unit_test_ops_b"]) + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns_a", runtime_a, raising=False) + monkeypatch.setattr(cpp_module.torch.ops, "unit_test_ns_b", runtime_b, raising=False) + monkeypatch.setattr(cpp_module, "setup_logger", lambda: logger) + + def fake_compile(**kwargs): + extension_name = kwargs["name"] + with compile_tracker_lock: + compile_tracker["active"] += 1 + compile_tracker["max_active"] = max(compile_tracker["max_active"], compile_tracker["active"]) + time.sleep(0.02) + states[extension_name] = True + with compile_tracker_lock: + compile_tracker["active"] -= 1 + + monkeypatch.setattr(cpp_module, "load", fake_compile) + + def runner(loader): + try: + start_barrier.wait(timeout=1.0) + assert loader.load() is True + except Exception as exc: # pragma: no cover - assertion path below + errors.append(exc) + + threads = [ + threading.Thread(target=runner, args=(loader_a,)), + threading.Thread(target=runner, args=(loader_b,)), + ] + for thread in threads: + thread.start() + start_barrier.wait(timeout=1.0) + for thread in threads: + thread.join() + + assert errors == [] + assert compile_tracker["max_active"] == 1 + + +def test_torch_ops_jit_extension_cuda_fingerprint_tracks_visible_capabilities(monkeypatch, tmp_path): + """Guard CUDA cache keys so binaries do not get reused across incompatible GPU architectures.""" + + loader = _make_loader(tmp_path, requires_cuda=True) + + monkeypatch.delenv("TORCH_CUDA_ARCH_LIST", raising=False) + monkeypatch.setattr(cpp_module.torch.cuda, "is_available", lambda: True) + monkeypatch.setattr(cpp_module.torch.cuda, "device_count", lambda: 2) + monkeypatch.setattr( + cpp_module.torch.cuda, + "get_device_capability", + lambda device_index=0: (8, 9) if device_index == 0 else (8, 0), + ) + first_build_root = loader.build_root() + + monkeypatch.setattr( + cpp_module.torch.cuda, + "get_device_capability", + lambda device_index=0: (8, 9) if device_index == 0 else (9, 0), + ) + second_build_root = loader.build_root() + + assert first_build_root != second_build_root + + +def test_torch_ops_jit_extension_cuda_fingerprint_prefers_arch_override(monkeypatch, tmp_path): + """Guard explicit arch overrides so manual build targets produce isolated caches.""" + + loader = _make_loader(tmp_path, requires_cuda=True) + + monkeypatch.setenv("TORCH_CUDA_ARCH_LIST", "8.0") + first_build_root = loader.build_root() + + monkeypatch.setenv("TORCH_CUDA_ARCH_LIST", "8.9+PTX") + second_build_root = loader.build_root() + + assert first_build_root != second_build_root + + +def test_torch_ops_jit_extension_cuda_fingerprint_tracks_detected_include_paths(monkeypatch, tmp_path): + """Guard cache keys so CUDA wheel header layout changes invalidate old JIT binaries.""" + + loader = _make_loader(tmp_path, requires_cuda=True) + + monkeypatch.setenv("TORCH_CUDA_ARCH_LIST", "8.0") + monkeypatch.setattr(cpp_module, "detected_local_cuda_include_paths", lambda: []) + monkeypatch.setattr(cpp_module, "detected_cuda_wheel_include_paths", lambda: ["/tmp/nvidia/cu13/include"]) + first_build_root = loader.build_root() + + monkeypatch.setattr( + cpp_module, + "detected_cuda_wheel_include_paths", + lambda: ["/tmp/nvidia/cu13/include", "/tmp/nvidia/cusparse/include"], + ) + second_build_root = loader.build_root() + + assert first_build_root != second_build_root + + +def test_torch_ops_jit_extension_fingerprint_tracks_transitive_local_includes(tmp_path): + """Guard cache keys so changes under quoted transitive includes rebuild stale entrypoint binaries.""" + + source_root = tmp_path / "src" + source_root.mkdir() + entry = source_root / "entry.cpp" + middle = source_root / "middle.h" + leaf = source_root / "leaf.inc" + + entry.write_text('#include "middle.h"\nint kernel() { return answer(); }\n', encoding="utf-8") + middle.write_text('#include "leaf.inc"\ninline int answer() { return ANSWER_VALUE; }\n', encoding="utf-8") + leaf.write_text("#define ANSWER_VALUE 1\n", encoding="utf-8") + + loader = _make_loader(tmp_path, sources=[str(entry)]) + first_build_root = loader.build_root() + + leaf.write_text("#define ANSWER_VALUE 12345\n", encoding="utf-8") + second_build_root = loader.build_root() + + assert first_build_root != second_build_root diff --git a/tests/test_torch_xpu.py b/tests/test_torch_xpu.py index 0547b79af..66985990e 100644 --- a/tests/test_torch_xpu.py +++ b/tests/test_torch_xpu.py @@ -39,7 +39,12 @@ def test(self): backend=BACKEND.TORCH, device=DEVICE.XPU, ) - generate_str = tokenizer.decode(model.generate(**tokenizer("The capital of France is is", return_tensors="pt").to(model.device), max_new_tokens=2)[0]) + generate_str = self.generate_stable_with_limit( + model, + tokenizer, + "The capital of France is is", + max_new_tokens=2, + ) print(f"generate_str: {generate_str}") diff --git a/tests/test_vllm.py b/tests/test_vllm.py index 2dca623c6..03a2eee5a 100644 --- a/tests/test_vllm.py +++ b/tests/test_vllm.py @@ -12,69 +12,96 @@ import importlib.util # noqa: E402 import tempfile # noqa: E402 +import unittest # noqa: E402 +from pathlib import Path # noqa: E402 -from models.model_test import ModelTest # noqa: E402 +import pytest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 +from tests.eval import evaluate, get_eval_task_metrics, import_evalution # noqa: E402 + +from .models.model_test import ModelTest # noqa: E402 + + +pytestmark = [pytest.mark.model, pytest.mark.slow] class TestLoadVLLM(ModelTest): + TASK_NAME = "arc_challenge" @classmethod def setUpClass(self): if ((importlib.util.find_spec("flashinfer") is None and importlib.util.find_spec("flashinfer-python") is None) or importlib.util.find_spec("vllm") is None): - raise RuntimeError("flashinfer and vllm are required by this test. you can install them by `pip install gptqmodel['vllm']`") + raise unittest.SkipTest( + "flashinfer and vllm are required by this test. install via `pip install gptqmodel['vllm']`" + ) - from vllm import SamplingParams # noqa: E402 + try: + import vllm._C # noqa: F401,E402 + except Exception as exc: + raise unittest.SkipTest(f"vllm runtime unavailable: {exc}") + try: + import_evalution() + except ValueError as exc: + raise unittest.SkipTest(str(exc)) self.MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" self.SHARDED_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit-sharded" - self.prompts = [ - self.INFERENCE_PROMPT, - ] - self.sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16, top_k=1) + self.NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" + for model_path in (self.MODEL_ID, self.SHARDED_MODEL_ID, self.NATIVE_MODEL_ID): + if not Path(model_path).exists(): + raise unittest.SkipTest(f"missing local model path: {model_path}") def release_vllm_model(self): - from vllm.distributed.parallel_state import destroy_model_parallel # noqa: E402 + try: + from vllm.distributed.parallel_state import destroy_model_parallel # noqa: E402 + except Exception: + torch_empty_cache() + return destroy_model_parallel() torch_empty_cache() - def test_load_vllm(self): - model = GPTQModel.load( - self.MODEL_ID, - device="cuda", - backend=BACKEND.VLLM, - gpu_memory_utilization=0.8, - ) - - tokenizer = model.get_tokenizer() - - self.assertInference(model, tokenizer) + def assert_evalution_vllm(self, model_path: str) -> None: + with tempfile.TemporaryDirectory() as tmp_dir: + results = evaluate( + model_or_id_or_path=model_path, + tasks=[self.TASK_NAME], + batch_size=1, + output_path=f"{tmp_dir}/result.json", + llm_backend="vllm", + model_args={ + "enforce_eager": False, + "gpu_memory_utilization": 0.8, + "tensor_parallel_size": 1, + }, + suite_kwargs={ + "max_rows": 2, + "num_fewshot": 1, + }, + ) - del model - self.release_vllm_model() + metrics = get_eval_task_metrics(results, self.TASK_NAME) + self.assertTrue(metrics, f"Expected Evalution metrics for task {self.TASK_NAME}") + self.assertEqual(results["engine"]["execution"]["generation_backend"], "vllm_generate") - def test_load_shared_vllm(self): - model = GPTQModel.load( - self.SHARDED_MODEL_ID, - device="cuda", - backend=BACKEND.VLLM, - gpu_memory_utilization=0.8, - ) - tokenizer = model.get_tokenizer() - - self.assertInference(model, tokenizer) + def test_evalution_vllm(self): + try: + self.assert_evalution_vllm(self.MODEL_ID) + finally: + self.release_vllm_model() - del model - self.release_vllm_model() + def test_evalution_sharded_vllm(self): + try: + self.assert_evalution_vllm(self.SHARDED_MODEL_ID) + finally: + self.release_vllm_model() def test_dynamic(self): - NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" - tokenizer = AutoTokenizer.from_pretrained(NATIVE_MODEL_ID, use_fast=True) + tokenizer = AutoTokenizer.from_pretrained(self.NATIVE_MODEL_ID, use_fast=True) if not tokenizer.pad_token_id: tokenizer.pad_token_id = tokenizer.eos_token_id @@ -97,7 +124,7 @@ def test_dynamic(self): group_size=128, ) model = GPTQModel.load( - NATIVE_MODEL_ID, + self.NATIVE_MODEL_ID, quantize_config=quantize_config, ) model.quantize(calibration_dataset, batch_size=4) @@ -108,20 +135,13 @@ def test_dynamic(self): del model - model = GPTQModel.load( - tmp_dir, - device="cuda", - backend=BACKEND.VLLM, - gpu_memory_utilization=0.8, - ) - - tokenizer = model.get_tokenizer() - - for name, submodule in model.named_modules(): + inspect_model = GPTQModel.load(tmp_dir) + for name, submodule in inspect_model.named_modules(): if name == 'model.model.layers.0.self_attn.q_proj' and isinstance(submodule, BaseQuantLinear): # module 0 was skipped raise ValueError("first layer should be native module") + del inspect_model - self.assertInference(model, tokenizer) - - del model - self.release_vllm_model() + try: + self.assert_evalution_vllm(tmp_dir) + finally: + self.release_vllm_model() diff --git a/tests/test_weight_only.py b/tests/test_weight_only.py new file mode 100644 index 000000000..5af8fbeb6 --- /dev/null +++ b/tests/test_weight_only.py @@ -0,0 +1,1967 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import copy +from types import SimpleNamespace + +import numpy as np +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from gptqmodel.models.base import BaseQModel +from gptqmodel.nn_modules.qlinear import PackableQuantLinear +from gptqmodel.nn_modules.qlinear.gguf import GGUFTorchLinear +from gptqmodel.nn_modules.qlinear.gguf_triton import GGUFTritonKernel +from gptqmodel.nn_modules.qlinear.torch import TorchLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchLinear +from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm +from gptqmodel.quantization.config import ( + FORMAT, + METHOD, + AutoModuleDecoderConfig, + GGUFBits, + GGUFConfig, + QuantizeConfig, + RTNConfig, + SmoothMAD, + quant_bits_width, +) +from gptqmodel.quantization.rtn import RTN +from gptqmodel.utils.backend import BACKEND +from gptqmodel.utils.model import ( + convert_gptq_v1_to_v2_format_module, + convert_gptq_v2_to_v1_format_module, + find_modules, +) + + +class _TinyMLP(nn.Module): + def __init__(self, hidden_size: int): + super().__init__() + self.up_proj = nn.Linear(hidden_size, hidden_size, bias=False) + self.down_proj = nn.Linear(hidden_size, hidden_size, bias=False) + + +class _TinyBlock(nn.Module): + def __init__(self, hidden_size: int): + super().__init__() + self.mlp = _TinyMLP(hidden_size) + + +class _TinyBackbone(nn.Module): + def __init__(self, hidden_size: int, layers: int): + super().__init__() + self.layers = nn.ModuleList([_TinyBlock(hidden_size) for _ in range(layers)]) + + +class _TinyModel(nn.Module): + def __init__(self, hidden_size: int = 32, layers: int = 2): + super().__init__() + self.model = _TinyBackbone(hidden_size=hidden_size, layers=layers) + self.config = SimpleNamespace( + use_cache=False, + tie_word_embeddings=False, + model_type="tiny_weight_only_test", + ) + + +class _TinyQModel(BaseQModel): + module_tree = [ + "model", + "layers", + "#", + { + "mlp": ("up_proj:0", "down_proj:1"), + }, + ] + + +def _reference_rtn_quantized_weight(weight: torch.Tensor, device: torch.device, smooth: SmoothMAD) -> tuple[torch.Tensor, torch.Tensor]: + linear = nn.Linear(weight.shape[1], weight.shape[0], bias=False, dtype=weight.dtype) + linear.weight.data.copy_(weight) + linear.to(device) + + qcfg = RTNConfig( + bits=4, + group_size=32, + desc_act=False, + sym=True, + smooth=smooth, + offload_to_disk=False, + device=str(device), + ) + rtn = RTN(linear, qcfg=qcfg) + qweight, _, _, g_idx, *_ = rtn.quantize() + return qweight.detach().cpu(), g_idx.detach().cpu() + + +def _microbench_device(dtype: torch.dtype) -> torch.device: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + if device.type == "cpu" and dtype == torch.float16: + pytest.skip("float16 RTN microbench requires CUDA for stable matmul") + return device + + +def _error_stats(reference: torch.Tensor, candidate: torch.Tensor) -> dict[str, float]: + diff = (candidate - reference).abs() + return { + "mae": diff.mean().item(), + "max": diff.max().item(), + } + + +def _build_rtn_microbench_case( + dtype: torch.dtype, + *, + bits: int | str = 4, +) -> dict[str, torch.Tensor | int | str | GGUFBits | torch.dtype | torch.device | None]: + device = _microbench_device(dtype) + + torch.manual_seed(1234) + + in_features = 128 + out_features = 128 + group_size = 32 + batch_size = 16 + + # Use LLM-like weight scales so the microbench measures RTN/export behavior, + # not an unrealistically wide toy distribution. + weight_master = torch.randn(out_features, in_features, dtype=torch.float32) * 0.01 + bias_master = torch.randn(out_features, dtype=torch.float32) * 0.001 + inputs_master = torch.randn(batch_size, in_features, dtype=torch.float32) * 0.1 + + linear = nn.Linear( + in_features, + out_features, + bias=True, + dtype=dtype, + device=device, + ).eval() + with torch.no_grad(): + linear.weight.copy_(weight_master.to(device=device, dtype=dtype)) + linear.bias.copy_(bias_master.to(device=device, dtype=dtype)) + + inputs = inputs_master.to(device=device, dtype=dtype) + native_output = linear(inputs) + + requested_bits: int | GGUFBits = bits + rtn_bits = bits + if isinstance(bits, GGUFBits): + requested_bits = bits + rtn_bits = int(bits) + elif isinstance(bits, str) and not bits.strip().isdigit(): + requested_bits = GGUFBits.from_string(bits) + rtn_bits = quant_bits_width(bits) + + qcfg = RTNConfig( + bits=rtn_bits, + group_size=group_size, + desc_act=False, + sym=True, + smooth=SmoothMAD(k=2.25), + offload_to_disk=False, + device=device.type, + ) + rtn_weight, scales, zeros, g_idx, *_ = RTN(linear, qcfg=qcfg).quantize() + rtn_output = F.linear(inputs, rtn_weight.to(device=device, dtype=dtype), linear.bias) + + cpu_linear = nn.Linear(in_features, out_features, bias=True, dtype=torch.float16).cpu().eval() + with torch.no_grad(): + cpu_linear.weight.copy_(rtn_weight.detach().cpu().to(torch.float16)) + cpu_linear.bias.copy_(linear.bias.detach().cpu().to(torch.float16)) + + return { + "device": device, + "dtype": dtype, + "in_features": in_features, + "out_features": out_features, + "group_size": group_size, + "bits": requested_bits, + "bit_width": quant_bits_width(requested_bits), + "inputs": inputs, + "native_output": native_output, + "rtn_output": rtn_output, + "rtn_weight_cpu": rtn_weight.detach().cpu(), + "scales_cpu": scales.detach().cpu(), + "zeros_cpu": zeros.detach().cpu(), + "g_idx_cpu": g_idx.detach().cpu(), + "cpu_linear": cpu_linear, + } + + +def _build_rtn_gptq_module(case: dict[str, torch.Tensor | int | torch.dtype | torch.device]) -> TorchLinear: + module = TorchLinear( + bits=case["bit_width"], + group_size=case["group_size"], + sym=True, + desc_act=False, + in_features=case["in_features"], + out_features=case["out_features"], + bias=True, + register_buffers=False, + ) + module.pack_block( + linear=case["cpu_linear"], + scales=case["scales_cpu"], + zeros=case["zeros_cpu"], + g_idx=case["g_idx_cpu"], + ) + # `pack_block()` gives us the in-memory GPTQ runtime layout that + # `TorchLinear` executes with. That is not the same thing as the + # serialized GPTQ checkpoint layout this project can export. + # + # The real checkpoint round-trip is: + # 1. runtime/internal layout -> GPTQ v1 serialized layout at export time + # 2. GPTQ v1 serialized layout -> runtime/internal layout at load time + # + # This helper intentionally performs that round-trip so the microbench + # validates export + reload behavior, not just the raw in-memory pack path. + convert_gptq_v2_to_v1_format_module( + module, + QuantizeConfig(bits=module.bits, quant_method=METHOD.GPTQ), + ) + convert_gptq_v1_to_v2_format_module( + module, + bits=module.bits, + pack_dtype=module.pack_dtype, + ) + # Skip the compile-heavy TorchLinear override; the microbench only + # needs unpack buffers initialized for numerical comparisons. + PackableQuantLinear.post_init(module) + return module.to(case["device"]).eval() + + +def _build_rtn_awq_module(case: dict[str, torch.Tensor | int | torch.dtype | torch.device]) -> AwqTorchLinear: + module = AwqTorchLinear( + bits=case["bit_width"], + group_size=case["group_size"], + sym=True, + desc_act=False, + in_features=case["in_features"], + out_features=case["out_features"], + bias=True, + register_buffers=False, + ) + module.pack( + linear=case["cpu_linear"], + scales=case["scales_cpu"], + zeros=case["zeros_cpu"], + ) + module.post_init() + return module.to(case["device"]).eval() + + +def _build_rtn_gguf_module( + case: dict[str, torch.Tensor | int | str | GGUFBits | torch.dtype | torch.device | None], +) -> GGUFTorchLinear: + module = GGUFTorchLinear( + bits=case["bits"], + group_size=-1, + sym=True, + desc_act=False, + in_features=case["in_features"], + out_features=case["out_features"], + bias=True, + register_buffers=False, + ) + module.pack( + linear=case["cpu_linear"], + scales=case["scales_cpu"], + zeros=case["zeros_cpu"], + g_idx=case["g_idx_cpu"], + ) + module.post_init() + return module.to(case["device"]).eval() + + +def test_baseqmodel_quantize_uses_weight_only_rtn_pipeline(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device_type = device.type + + native = _TinyModel().to(device=device, dtype=torch.float16).eval() + original_state = copy.deepcopy(native.state_dict()) + + smooth = SmoothMAD(k=2.25) + qcfg = RTNConfig( + bits=4, + group_size=32, + desc_act=False, + sym=True, + smooth=smooth, + offload_to_disk=False, + device=device_type, + ) + + model = _TinyQModel( + model=native, + quantized=False, + quantize_config=qcfg, + tokenizer=None, + ) + + result = model.quantize(calibration=None, backend=BACKEND.TORCH) + + assert "weight_only_rtn" in result + assert model.quantized is True + + qmodules = find_modules(model.model, [TorchLinear]) + assert len(qmodules) == 4 + + for name, qmodule in qmodules.items(): + original_weight = original_state[f"{name}.weight"].to(dtype=torch.float16) + expected_qweight, expected_g_idx = _reference_rtn_quantized_weight(original_weight, device=device, smooth=smooth) + + assert qmodule.qzero_format() == 1 + assert qmodule.qweight.device.type == "cpu" + assert qmodule.qzeros.device.type == "cpu" + assert qmodule.scales.device.type == "cpu" + assert qmodule.g_idx.device.type == "cpu" + + dequant_module = copy.deepcopy(qmodule) + if dequant_module.qzero_format() == 1: + convert_gptq_v1_to_v2_format_module( + dequant_module, + bits=dequant_module.bits, + pack_dtype=dequant_module.pack_dtype, + ) + if not hasattr(dequant_module, "wf_unsqueeze_zero"): + dequant_module.post_init() + + actual_qweight = dequant_module.dequantize_weight().T.detach().cpu().to(dtype=expected_qweight.dtype) + actual_error = (actual_qweight - original_weight.cpu()).abs().mean().item() + expected_error = (expected_qweight - original_weight.cpu()).abs().mean().item() + + assert actual_error <= expected_error + 0.01 + assert actual_error < 0.05 + torch.testing.assert_close(qmodule.g_idx.detach().cpu(), expected_g_idx) + + +def test_baseqmodel_quantize_allows_rtn_awq_export(): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device_type = device.type + + native = _TinyModel().to(device=device, dtype=torch.float16).eval() + smooth = SmoothMAD(k=2.25) + + qcfg = RTNConfig( + bits=4, + group_size=32, + desc_act=False, + sym=True, + format=FORMAT.GEMM, + smooth=smooth, + offload_to_disk=False, + device=device_type, + ) + + model = _TinyQModel( + model=native, + quantized=False, + quantize_config=qcfg, + tokenizer=None, + ) + + result = model.quantize(calibration=None, backend=BACKEND.AUTO) + + assert "weight_only_rtn" in result + assert model.quantized is True + assert model.quantize_config.format == FORMAT.GEMM + assert model.quantize_config.export_quant_method() == METHOD.AWQ + assert getattr(model.qlinear_kernel, "__name__", "") == "AwqTorchLinear" + + qmodules = find_modules(model.model, [model.qlinear_kernel]) + assert len(qmodules) == 4 + + for qmodule in qmodules.values(): + assert qmodule.qweight.device.type == "cpu" + assert qmodule.qzeros.device.type == "cpu" + assert qmodule.scales.device.type == "cpu" + + +@pytest.mark.parametrize( + ("bits", "tensor_qtype", "bit_width", "variant", "quality"), + [ + ("q4_k_s", "Q4_K", 4, "k", "s"), + ("q4_k_m", "Q4_K", 4, "k", "m"), + ("q5_k_s", "Q5_K", 5, "k", "s"), + ("q5_k_m", "Q5_K", 5, "k", "m"), + ("q6_k", "Q6_K", 6, "k", None), + ], +) +def test_baseqmodel_quantize_allows_direct_gguf_export( + bits: str, + tensor_qtype: str, + bit_width: int, + variant: str, + quality: str | None, +): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device_type = device.type + public_format = GGUFBits.from_string(bits).to_public_format() + + native = _TinyModel().to(device=device, dtype=torch.float16).eval() + qcfg = GGUFConfig( + bits=bit_width, + format=public_format, + smoother=None, + offload_to_disk=False, + device=device_type, + ) + + model = _TinyQModel( + model=native, + quantized=False, + quantize_config=qcfg, + tokenizer=None, + ) + + result = model.quantize(calibration=None, backend=BACKEND.AUTO) + + assert "weight_only_gguf" in result + assert model.quantized is True + assert model.quantize_config.format == public_format + assert model.quantize_config.bits == bit_width + assert model.quantize_config.quant_method == METHOD.GGUF + assert model.quantize_config.export_quant_method() == METHOD.GGUF + expected_kernel = GGUFTritonKernel if device_type == "cuda" else GGUFTorchLinear + assert model.qlinear_kernel is expected_kernel + + qmodules = find_modules(model.model, [model.qlinear_kernel]) + assert len(qmodules) == 4 + + for qmodule in qmodules.values(): + assert qmodule.qweight.device.type == "cpu" + assert qmodule.qweight.dtype == torch.uint8 + assert isinstance(qmodule.bits, GGUFBits) + assert qmodule.bits == bits + assert qmodule.bits.bits == bit_width + assert qmodule.bits.version == "q" + assert qmodule.bits.variant == variant + assert qmodule.bits.quality == quality + assert qmodule.gguf_tensor_qtype == tensor_qtype + expected_padded_in_features = ( + (qmodule.in_features + qmodule.gguf_block_size - 1) // qmodule.gguf_block_size + ) * qmodule.gguf_block_size + assert qmodule.padded_in_features == expected_padded_in_features + assert qmodule.qweight.shape == (qmodule.out_features, qmodule._bytes_per_row()) + + +def test_baseqmodel_quantize_gguf_weight_only_skips_rtn(monkeypatch): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device_type = device.type + + native = _TinyModel().to(device=device, dtype=torch.float16).eval() + + qcfg = GGUFConfig( + bits=4, + format="q_k_m", + smoother=SmoothMAD(k=2.25), + offload_to_disk=False, + device=device_type, + ) + + model = _TinyQModel( + model=native, + quantized=False, + quantize_config=qcfg, + tokenizer=None, + ) + + def _fail_quantize(*args, **kwargs): + raise AssertionError("RTN.quantize should not be called for direct GGUF packing") + + monkeypatch.setattr(RTN, "quantize", _fail_quantize) + + result = model.quantize(calibration=None, backend=BACKEND.AUTO) + + assert "weight_only_gguf" in result + assert model.quantized is True + qmodules = find_modules(model.model, [model.qlinear_kernel]) + assert len(qmodules) == 4 + + +def test_baseqmodel_quantize_gguf_weight_only_applies_auto_module_decoder(monkeypatch): + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + device_type = device.type + + native = _TinyModel().to(device=device, dtype=torch.float16).eval() + + qcfg = GGUFConfig( + bits=4, + format="q_k_m", + preprocessors=[AutoModuleDecoderConfig(target_dtype=torch.bfloat16)], + offload_to_disk=False, + device=device_type, + ) + + model = _TinyQModel( + model=native, + quantized=False, + quantize_config=qcfg, + tokenizer=None, + ) + + materialize_calls = [] + original_shell_module_materialize = BaseQModel.shell_module_materialize + + def _spy_shell_module_materialize(self, *args, **kwargs): + materialize_calls.append(kwargs.get("role")) + return original_shell_module_materialize(self, *args, **kwargs) + + monkeypatch.setattr(BaseQModel, "shell_module_materialize", _spy_shell_module_materialize) + + result = model.quantize(calibration=None, backend=BACKEND.AUTO) + + assert "weight_only_gguf" in result + assert materialize_calls.count("quant_source") == 4 + + +@pytest.mark.parametrize("bits", ["q4_k_m", "q5_k_m", "q6_k"]) +def test_gguf_pack_original_auto_pads_non_aligned_k_in_features(bits: str): + torch.manual_seed(77) + + in_features = 130 + out_features = 48 + linear = nn.Linear(in_features, out_features, bias=True, dtype=torch.float16).cpu().eval() + + with torch.no_grad(): + linear.weight.normal_(mean=0.0, std=0.02) + linear.bias.normal_(mean=0.0, std=0.01) + + module = GGUFTorchLinear( + bits=bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + bias=True, + register_buffers=True, + ).cpu().eval() + module.pack_original(linear, scales=None, zeros=None) + + assert module.in_features == in_features + assert module.gguf_block_size == 256 + assert module.padded_in_features == 256 + assert module.qweight.shape == (out_features, module._bytes_per_row()) + + x = torch.randn(7, in_features, dtype=torch.float32) + with torch.inference_mode(): + native_out = F.linear(x, linear.weight.detach().to(torch.float32), linear.bias.detach().to(torch.float32)) + gguf_out = module(x) + + stats = _error_stats(native_out, gguf_out.to(torch.float32)) + assert stats["mae"] < 0.02 + assert stats["max"] < 0.08 + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_rtn_microbench_quantized_output_stays_close_to_native(dtype: torch.dtype): + case = _build_rtn_microbench_case(dtype) + + stats = _error_stats(case["native_output"], case["rtn_output"]) + + assert stats["mae"] < 0.0010 + assert stats["max"] < 0.0060 + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_rtn_microbench_gptq_export_stays_close_to_rtn(dtype: torch.dtype): + case = _build_rtn_microbench_case(dtype) + module = _build_rtn_gptq_module(case) + + packed_weight = module.dequantize_weight().T.detach().cpu().to(case["rtn_weight_cpu"].dtype) + weight_stats = _error_stats(case["rtn_weight_cpu"], packed_weight) + + exported_output = module(case["inputs"]) + output_stats = _error_stats(case["rtn_output"], exported_output) + + assert weight_stats["mae"] < 2e-5 + assert output_stats["mae"] < 2e-5 + assert output_stats["max"] < 3e-4 + + +def test_rtn_microbench_awq_export_stays_close_to_rtn(): + case = _build_rtn_microbench_case(torch.float16) + module = _build_rtn_awq_module(case) + + packed_weight = dequantize_gemm( + qweight=module.qweight, + qzeros=module.qzeros, + scales=module.scales, + bits=module.bits, + group_size=module.group_size, + ).detach().cpu().to(case["rtn_weight_cpu"].dtype) + weight_stats = _error_stats(case["rtn_weight_cpu"], packed_weight) + + exported_output = module(case["inputs"]) + output_stats = _error_stats(case["rtn_output"], exported_output) + + assert weight_stats["mae"] < 0.0120 + assert output_stats["mae"] < 1e-5 + assert output_stats["max"] < 1e-4 + + +def test_rtn_microbench_gguf_export_matches_reference_bytes(): + gguf = pytest.importorskip("gguf") + + case = _build_rtn_microbench_case(torch.float16, bits="q4_0") + module = _build_rtn_gguf_module(case) + + reference = gguf.quantize( + case["rtn_weight_cpu"].numpy().astype(np.float32), + gguf.GGMLQuantizationType.Q4_0, + ) + + np.testing.assert_array_equal(module.qweight.detach().cpu().numpy(), reference) + + +def test_rtn_microbench_gguf_export_accepts_structured_bits(): + dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 + bits = GGUFBits(bits=5, version="q", variant="k", quality="m") + + case = _build_rtn_microbench_case(dtype, bits=bits) + module = _build_rtn_gguf_module(case) + + assert isinstance(case["bits"], GGUFBits) + assert case["bits"] == "q5_k_m" + assert isinstance(module.bits, GGUFBits) + assert module.bits == "q5_k_m" + assert module.gguf_tensor_qtype == "Q5_K" + + output_stats = _error_stats(case["rtn_output"], module(case["inputs"])) + + assert output_stats["mae"] < 6e-4 + assert output_stats["max"] < 0.012 + + +@pytest.mark.parametrize("bits", ["q4_0", "q4_k_m"]) +def test_gguf_dequantize_weight_accepts_requested_dtype_and_device(bits: str): + dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 + case = _build_rtn_microbench_case(dtype, bits=bits) + module = _build_rtn_gguf_module(case) + + direct = module.dequantize_weight(device=case["device"], dtype=case["dtype"]) + baseline = module.dequantize_weight().to(device=case["device"], dtype=case["dtype"]) + + assert direct.device == case["device"] + assert direct.dtype == case["dtype"] + torch.testing.assert_close(direct, baseline, atol=2e-3, rtol=0.0) + + +def test_gguf_forward_requests_dequantized_weight_in_input_dtype(monkeypatch): + dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 + case = _build_rtn_microbench_case(dtype, bits="q4_0") + module = _build_rtn_gguf_module(case) + + observed: dict[str, torch.device | torch.dtype | None] = {"device": None, "dtype": None} + original = GGUFTorchLinear.dequantize_weight + + def _wrapped(self, *, device=None, dtype=None): + observed["device"] = None if device is None else torch.device(device) + observed["dtype"] = dtype + return original(self, device=device, dtype=dtype) + + monkeypatch.setattr(GGUFTorchLinear, "dequantize_weight", _wrapped) + + module(case["inputs"]) + + assert observed["device"] == case["inputs"].device + assert observed["dtype"] == case["inputs"].dtype + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF K-path CUDA tests") +@pytest.mark.parametrize( + ("bits", "tensor_qtype"), + [ + ("q4_k_m", "Q4_K"), + ("q5_k_m", "Q5_K"), + ("q6_k", "Q6_K"), + ], +) +def test_gguf_k_dequantize_weight_matches_reference_on_cuda(bits: str, tensor_qtype: str): + gguf = pytest.importorskip("gguf") + + case = _build_rtn_microbench_case(torch.float16, bits=bits) + module = _build_rtn_gguf_module(case).to(case["device"]).eval() + + actual = module.dequantize_weight(device=case["device"], dtype=torch.float32).T.detach().cpu().numpy() + reference = gguf.dequantize( + module.qweight.detach().cpu().numpy(), + getattr(gguf.GGMLQuantizationType, tensor_qtype), + )[:, : case["in_features"]] + + np.testing.assert_allclose(actual, reference, rtol=0.0, atol=1e-5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF fused K forward tests") +@pytest.mark.parametrize("bits", ["q4_k_m", "q5_k_m", "q6_k"]) +def test_gguf_k_fused_forward_matches_dense_baseline(bits: str): + case = _build_rtn_microbench_case(torch.float16, bits=bits) + module = _build_rtn_gguf_module(case).to(case["device"]).eval() + module.gguf_fused_cuda_max_rows = case["inputs"].shape[0] + module.gguf_fused_cuda_min_matrix_elements = 0 + module.clear_weight_cache() + + baseline = module._forward_dequant_matmul(case["inputs"]) + fused = module._forward_fused_k(case["inputs"]) + + output_stats = _error_stats(baseline.to(torch.float32), fused.to(torch.float32)) + assert output_stats["mae"] < 2e-4 + assert output_stats["max"] < 3e-3 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton fused K tests") +@pytest.mark.parametrize("bits", ["q4_k_m", "q5_k_m", "q6_k"]) +def test_gguf_triton_fused_forward_matches_dense_baseline(bits: str): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear.gguf_triton import triton_available + + if not triton_available(): + pytest.skip("Triton GGUF fused kernel unavailable") + + case = _build_rtn_microbench_case(torch.float16, bits=bits) + module = _build_rtn_gguf_module(case).to(case["device"]).eval() + triton_module = GGUFTritonKernel( + bits=bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=case["in_features"], + out_features=case["out_features"], + bias=True, + register_buffers=True, + ).to(case["device"]).eval() + triton_module.load_state_dict(module.state_dict(), strict=True) + triton_module.clear_weight_cache() + + baseline = module._forward_dequant_matmul(case["inputs"]) + if module.bias is not None: + baseline = baseline + module.bias.to(device=baseline.device, dtype=baseline.dtype) + fused = triton_module(case["inputs"]) + + output_stats = _error_stats(baseline.to(torch.float32), fused.to(torch.float32)) + assert output_stats["mae"] < 2e-4 + assert output_stats["max"] < 3e-3 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton sign-only tests") +def test_gguf_triton_q1_0_g128_fused_forward_matches_dense_baseline(): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear.gguf_triton import triton_available + + if not triton_available(): + pytest.skip("Triton GGUF fused kernel unavailable") + + torch.manual_seed(1234) + cpu_linear = nn.Linear(256, 192, bias=True, dtype=torch.float16).cpu().eval() + with torch.no_grad(): + cpu_linear.weight.normal_(mean=0.0, std=0.01) + cpu_linear.bias.normal_(mean=0.0, std=0.001) + + module = GGUFTorchLinear( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=256, + out_features=192, + bias=True, + register_buffers=False, + ) + module.pack( + linear=cpu_linear, + scales=torch.empty(0), + zeros=torch.empty(0), + g_idx=None, + ) + module.post_init() + module = module.to("cuda").eval() + + triton_module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=256, + out_features=192, + bias=True, + register_buffers=True, + ).to("cuda").eval() + triton_module.load_state_dict(module.state_dict(), strict=True) + triton_module.clear_weight_cache() + + inputs = torch.randn(16, 256, device="cuda", dtype=torch.float16) * 0.1 + baseline = module._forward_dequant_matmul(inputs) + if module.bias is not None: + baseline = baseline + module.bias.to(device=baseline.device, dtype=baseline.dtype) + fused = triton_module(inputs) + + output_stats = _error_stats(baseline.to(torch.float32), fused.to(torch.float32)) + assert output_stats["mae"] < 2e-4 + assert output_stats["max"] < 3e-3 + + +def test_gguf_triton_kernel_rejects_unsupported_formats(): + with pytest.raises(NotImplementedError, match="only supports fused GGUF Triton formats"): + GGUFTritonKernel( + bits="q4_0", + group_size=-1, + sym=True, + desc_act=False, + in_features=64, + out_features=48, + bias=False, + register_buffers=False, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_selects_large_config_bank_for_large_k(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[object] = [] + small_kernel = object() + large_kernel = object() + + def _fake_launch(kernel, x, output, *args): + calls.append(kernel) + return output + + monkeypatch.setattr(gguf_triton, "_gguf_q4_k_fused_matmul_kernel_small", small_kernel) + monkeypatch.setattr(gguf_triton, "_gguf_q4_k_fused_matmul_kernel_large", large_kernel) + monkeypatch.setattr(gguf_triton, "_launch", _fake_launch) + + x_small = torch.randn(2, 2048, device="cuda", dtype=torch.float16) + qs_small = torch.empty((8, 144, 32), device="cuda", dtype=torch.uint8) + scale_small = torch.empty((8, 8, 32), device="cuda", dtype=torch.float16) + min_small = torch.empty((8, 8, 32), device="cuda", dtype=torch.float16) + gguf_triton.fused_q4_k_matmul(x_small, qs_small, scale_small, min_small) + assert calls[-1] is small_kernel + + x_large = torch.randn(2, 4096, device="cuda", dtype=torch.float16) + qs_large = torch.empty((16, 144, 32), device="cuda", dtype=torch.uint8) + scale_large = torch.empty((16, 8, 32), device="cuda", dtype=torch.float16) + min_large = torch.empty((16, 8, 32), device="cuda", dtype=torch.float16) + gguf_triton.fused_q4_k_matmul(x_large, qs_large, scale_large, min_large) + assert calls[-1] is large_kernel + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_selects_large_config_bank_for_large_q1_0_g128(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[object] = [] + small_kernel = object() + large_kernel = object() + + def _fake_launch(kernel, x, output, *args): + calls.append(kernel) + return output + + monkeypatch.setattr(gguf_triton, "_gguf_q1_0_g128_fused_matmul_kernel_small", small_kernel) + monkeypatch.setattr(gguf_triton, "_gguf_q1_0_g128_fused_matmul_kernel_large", large_kernel) + monkeypatch.setattr(gguf_triton, "_launch", _fake_launch) + + x_small = torch.randn(2, 1024, device="cuda", dtype=torch.float16) + sign_small = torch.empty((8, 16, 32), device="cuda", dtype=torch.uint8) + scale_small = torch.empty((8, 32), device="cuda", dtype=torch.float16) + gguf_triton.fused_q1_0_g128_matmul(x_small, sign_small, scale_small) + assert calls[-1] is small_kernel + + x_large = torch.randn(2, 2048, device="cuda", dtype=torch.float16) + sign_large = torch.empty((16, 16, 32), device="cuda", dtype=torch.uint8) + scale_large = torch.empty((16, 32), device="cuda", dtype=torch.float16) + gguf_triton.fused_q1_0_g128_matmul(x_large, sign_large, scale_large) + assert calls[-1] is large_kernel + + +def test_select_q1_0_g128_fixed_launch_config_targets_arch_decode_shapes(): + from gptqmodel.nn_modules.qlinear.gguf_triton import _select_q1_0_g128_fixed_launch_config + + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 0), rows=1, cols=2048) == { + "BLOCK_SIZE_M": 4, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 0), rows=1, cols=6144) == { + "BLOCK_SIZE_M": 2, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=1, cols=2048) == { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "num_warps": 8, + "num_stages": 2, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=1, in_features=2048, cols=2048) == { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=1, in_features=6144, cols=2048) == { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "num_warps": 8, + "num_stages": 2, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=1, cols=1024) == { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "num_warps": 8, + "num_stages": 2, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=1, cols=6144) == { + "BLOCK_SIZE_M": 2, + "BLOCK_SIZE_N": 32, + "num_warps": 8, + "num_stages": 4, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=1, in_features=2048, cols=6144) == { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "num_warps": 8, + "num_stages": 2, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 0), rows=64, cols=2048) == { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "num_warps": 4, + "num_stages": 4, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 0), rows=64, cols=6144) == { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=64, cols=2048) == { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 16, + "num_warps": 8, + "num_stages": 4, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=64, cols=6144) == { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "num_warps": 4, + "num_stages": 4, + } + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=1, cols=4096) is None + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 0), rows=2, cols=6144) is None + assert _select_q1_0_g128_fixed_launch_config(capability=(8, 9), rows=256, cols=2048) is None + + +def test_select_q1_0_g128_u32_layout_targets_sm80_down_proj(): + from gptqmodel.nn_modules.qlinear.gguf_triton import _select_q1_0_g128_u32_layout + + assert _select_q1_0_g128_u32_layout(capability=(8, 0), in_features=6144, out_features=2048) is True + assert _select_q1_0_g128_u32_layout(capability=(8, 0), in_features=2048, out_features=1024) is True + assert _select_q1_0_g128_u32_layout(capability=(8, 0), in_features=2048, out_features=2048) is True + assert _select_q1_0_g128_u32_layout(capability=(8, 0), in_features=2048, out_features=6144) is True + assert _select_q1_0_g128_u32_layout(capability=(8, 0), in_features=1024, out_features=2048) is False + assert _select_q1_0_g128_u32_layout(capability=(8, 9), in_features=6144, out_features=2048) is False + assert _select_q1_0_g128_u32_layout(capability=None, in_features=6144, out_features=2048) is False + + +def test_select_q1_0_g128_u32_fixed_launch_config_targets_sm80_down_proj_ranges(): + from gptqmodel.nn_modules.qlinear.gguf_triton import _select_q1_0_g128_u32_fixed_launch_config + + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 0), + rows=1, + in_features=6144, + out_features=2048, + ) == { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "num_warps": 4, + "num_stages": 4, + } + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 0), + rows=64, + in_features=6144, + out_features=2048, + ) == { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 0), + rows=1, + in_features=2048, + out_features=1024, + ) == { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "num_warps": 2, + "num_stages": 2, + } + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 0), + rows=1, + in_features=2048, + out_features=2048, + ) == { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 0), + rows=1, + in_features=2048, + out_features=6144, + ) == { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 0), + rows=4, + in_features=6144, + out_features=2048, + ) is None + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 0), + rows=256, + in_features=6144, + out_features=2048, + ) is None + assert _select_q1_0_g128_u32_fixed_launch_config( + capability=(8, 9), + rows=64, + in_features=6144, + out_features=2048, + ) is None + + +def test_use_q1_0_g128_k2048_decode_specialization_only_targets_decode_k2048(): + from gptqmodel.nn_modules.qlinear.gguf_triton import _use_q1_0_g128_k2048_decode_specialization + + assert _use_q1_0_g128_k2048_decode_specialization(rows=1, in_features=2048) is True + assert _use_q1_0_g128_k2048_decode_specialization(rows=4, in_features=2048) is False + assert _use_q1_0_g128_k2048_decode_specialization(rows=1, in_features=6144) is False + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_uses_fixed_decode_launch_for_sm80_q1_0_g128(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[tuple[str, object, dict[str, int] | None]] = [] + fixed_config = { + "BLOCK_SIZE_M": 2, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + + def _fake_fixed_matmul(x, sign_bytes, scale, *, fixed_config): + calls.append(("fixed", gguf_triton._gguf_q1_0_g128_fused_matmul_kernel_impl, fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_fused_matmul(x, sign_bytes, scale): + calls.append(("generic", None, None)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_fixed_matmul", _fake_fixed_matmul) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_matmul", _fake_fused_matmul) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=256, + out_features=192, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "fixed_decode_config": fixed_config, + "sign_bytes": torch.empty((2, 16, 192), device="cuda", dtype=torch.uint8), + "scale": torch.empty((2, 192), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(1, 256, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == [("fixed", gguf_triton._gguf_q1_0_g128_fused_matmul_kernel_impl, fixed_config)] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton u32 tests") +def test_gguf_triton_q1_0_g128_u32_fused_matches_byte_fused(): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + torch.manual_seed(1234) + x = torch.randn(8, 256, device="cuda", dtype=torch.float16) + sign_bytes = torch.randint(0, 256, (2, 16, 192), device="cuda", dtype=torch.uint8) + sign_words = ( + sign_bytes.to(torch.int32)[:, 0::4, :] + | torch.bitwise_left_shift(sign_bytes.to(torch.int32)[:, 1::4, :], 8) + | torch.bitwise_left_shift(sign_bytes.to(torch.int32)[:, 2::4, :], 16) + | torch.bitwise_left_shift(sign_bytes.to(torch.int32)[:, 3::4, :], 24) + ).contiguous() + scale = (torch.rand((2, 192), device="cuda", dtype=torch.float16) + 0.05).contiguous() + + byte_out = gguf_triton.fused_q1_0_g128_matmul(x, sign_bytes, scale) + u32_out = gguf_triton.fused_q1_0_g128_u32_matmul(x, sign_words, scale) + + output_stats = _error_stats(byte_out.to(torch.float32), u32_out.to(torch.float32)) + assert output_stats["mae"] < 2e-4 + assert output_stats["max"] < 3e-3 + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_sm80_down_proj_to_u32_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[tuple[str, dict[str, int] | None]] = [] + + fixed_config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 4, + } + + def _fake_u32_fixed(x, sign_words, scale, *, fixed_config): + calls.append(("u32_fixed", fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_byte(x, sign_bytes, scale): + calls.append(("byte", None)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 0)) + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_u32_fixed_matmul", _fake_u32_fixed) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_matmul", _fake_byte) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=6144, + out_features=2048, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": True, + "fixed_decode_config": None, + "sign_bytes": torch.empty((48, 16, 2048), device="cuda", dtype=torch.uint8), + "sign_words": torch.empty((48, 4, 2048), device="cuda", dtype=torch.int32), + "scale": torch.empty((48, 2048), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(64, 6144, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == [("u32_fixed", fixed_config)] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_sm80_decode_2048x2048_to_k2048_specialized_u32_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[tuple[str, dict[str, int] | None]] = [] + + fixed_config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + + def _fake_u32_k2048_fixed(x, sign_words, scale, *, fixed_config): + calls.append(("u32_k2048_fixed", fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_u32_fixed(x, sign_words, scale, *, fixed_config): + calls.append(("u32_fixed", fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 0)) + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_u32_k2048_fixed_matmul", _fake_u32_k2048_fixed) + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_u32_fixed_matmul", _fake_u32_fixed) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=2048, + out_features=2048, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": True, + "fixed_decode_config": None, + "sign_bytes": torch.empty((16, 16, 2048), device="cuda", dtype=torch.uint8), + "sign_words": torch.empty((16, 4, 2048), device="cuda", dtype=torch.int32), + "scale": torch.empty((16, 2048), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(1, 2048, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == [("u32_k2048_fixed", fixed_config)] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_sm80_decode_6144x2048_to_u32_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[tuple[str, dict[str, int] | None]] = [] + + fixed_config = { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 16, + "num_warps": 4, + "num_stages": 4, + } + + def _fake_u32_fixed(x, sign_words, scale, *, fixed_config): + calls.append(("u32_fixed", fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_byte(x, sign_bytes, scale): + calls.append(("byte", None)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 0)) + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_u32_fixed_matmul", _fake_u32_fixed) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_matmul", _fake_byte) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=6144, + out_features=2048, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": True, + "fixed_decode_config": None, + "sign_bytes": torch.empty((48, 16, 2048), device="cuda", dtype=torch.uint8), + "sign_words": torch.empty((48, 4, 2048), device="cuda", dtype=torch.int32), + "scale": torch.empty((48, 2048), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(1, 6144, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == [("u32_fixed", fixed_config)] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_decode_2048x1024_to_k2048_specialized_byte_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[str] = [] + + def _fake_k2048(x, sign_bytes, scale): + calls.append("k2048") + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_generic(x, sign_bytes, scale): + calls.append("generic") + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_k2048_matmul", _fake_k2048) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_matmul", _fake_generic) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=2048, + out_features=1024, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": False, + "fixed_decode_config": None, + "sign_bytes": torch.empty((16, 16, 1024), device="cuda", dtype=torch.uint8), + "scale": torch.empty((16, 1024), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(1, 2048, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == ["k2048"] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_out_of_range_sm80_down_proj_to_byte_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[str] = [] + + def _fake_byte(x, sign_bytes, scale): + calls.append("byte") + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 0)) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_matmul", _fake_byte) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=6144, + out_features=2048, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": True, + "fixed_decode_config": None, + "sign_bytes": torch.empty((48, 16, 2048), device="cuda", dtype=torch.uint8), + "sign_words": torch.empty((48, 4, 2048), device="cuda", dtype=torch.int32), + "scale": torch.empty((48, 2048), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(4, 6144, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == ["byte"] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_non_decode_2048x1024_to_generic_byte_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[str] = [] + + def _fake_k2048(x, sign_bytes, scale): + calls.append("k2048") + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_generic(x, sign_bytes, scale): + calls.append("generic") + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_k2048_matmul", _fake_k2048) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_matmul", _fake_generic) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=2048, + out_features=1024, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": False, + "fixed_decode_config": None, + "sign_bytes": torch.empty((16, 16, 1024), device="cuda", dtype=torch.uint8), + "scale": torch.empty((16, 1024), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(4, 2048, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == ["generic"] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_sm89_decode_2048x1024_to_fixed_byte_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[tuple[str, dict[str, int] | None]] = [] + + fixed_config = { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 32, + "num_warps": 8, + "num_stages": 2, + } + + def _fake_fixed(x, sign_bytes, scale, *, fixed_config): + calls.append(("fixed", fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_generic(x, sign_bytes, scale): + calls.append(("generic", None)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 9)) + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_k2048_fixed_matmul", _fake_fixed) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_k2048_matmul", _fake_generic) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=2048, + out_features=1024, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": False, + "fixed_decode_config": fixed_config, + "sign_bytes": torch.empty((16, 16, 1024), device="cuda", dtype=torch.uint8), + "scale": torch.empty((16, 1024), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(1, 2048, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == [("fixed", fixed_config)] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_sm89_decode_2048x2048_to_exact_fixed_byte_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[tuple[str, dict[str, int] | None]] = [] + + fixed_config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "num_warps": 4, + "num_stages": 2, + } + + def _fake_fixed(x, sign_bytes, scale, *, fixed_config): + calls.append(("fixed", fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_generic(x, sign_bytes, scale): + calls.append(("generic", None)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 9)) + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_k2048_fixed_matmul", _fake_fixed) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_k2048_matmul", _fake_generic) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=2048, + out_features=2048, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": False, + "fixed_decode_config": fixed_config, + "sign_bytes": torch.empty((16, 16, 2048), device="cuda", dtype=torch.uint8), + "scale": torch.empty((16, 2048), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(1, 2048, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == [("fixed", fixed_config)] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_sm89_decode_2048x6144_to_exact_fixed_byte_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[tuple[str, dict[str, int] | None]] = [] + + fixed_config = { + "BLOCK_SIZE_M": 8, + "BLOCK_SIZE_N": 64, + "num_warps": 8, + "num_stages": 2, + } + + def _fake_fixed(x, sign_bytes, scale, *, fixed_config): + calls.append(("fixed", fixed_config)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + def _fake_generic(x, sign_bytes, scale): + calls.append(("generic", None)) + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 9)) + monkeypatch.setattr(gguf_triton, "_launch_q1_0_g128_k2048_fixed_matmul", _fake_fixed) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_k2048_matmul", _fake_generic) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=2048, + out_features=6144, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": False, + "fixed_decode_config": fixed_config, + "sign_bytes": torch.empty((16, 16, 6144), device="cuda", dtype=torch.uint8), + "scale": torch.empty((16, 6144), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(1, 2048, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == [("fixed", fixed_config)] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF Triton routing test") +def test_gguf_triton_routes_sm80_decode_2048x2048_out_of_range_to_byte_path(monkeypatch): + pytest.importorskip("triton") + + from gptqmodel.nn_modules.qlinear import gguf_triton + + calls: list[str] = [] + + def _fake_byte(x, sign_bytes, scale): + calls.append("byte") + return torch.empty((x.shape[0], scale.shape[1]), device=x.device, dtype=x.dtype) + + monkeypatch.setattr(gguf_triton, "_cuda_device_capability", lambda _device: (8, 0)) + monkeypatch.setattr(gguf_triton, "fused_q1_0_g128_matmul", _fake_byte) + + module = GGUFTritonKernel( + bits="q1_0_g128", + group_size=-1, + sym=True, + desc_act=False, + in_features=2048, + out_features=2048, + bias=False, + register_buffers=True, + ).to("cuda").eval() + monkeypatch.setattr( + module, + "_get_triton_cache", + lambda _device: { + "use_u32": True, + "fixed_decode_config": None, + "sign_bytes": torch.empty((16, 16, 2048), device="cuda", dtype=torch.uint8), + "sign_words": torch.empty((16, 4, 2048), device="cuda", dtype=torch.int32), + "scale": torch.empty((16, 2048), device="cuda", dtype=torch.float16), + }, + ) + + x = torch.randn(4, 2048, device="cuda", dtype=torch.float16) + module._forward_triton(x) + + assert calls == ["byte"] + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for GGUF fused K routing test") +def test_gguf_forward_uses_fused_k_path_for_small_cuda_batches(): + case = _build_rtn_microbench_case(torch.float16, bits="q4_k_m") + module = _build_rtn_gguf_module(case).to(case["device"]).eval() + calls = {"dense": 0, "fused": 0} + + module.gguf_fused_cuda_max_rows = case["inputs"].shape[0] + module.gguf_fused_cuda_min_matrix_elements = 0 + module.autotune_enabled = False + module.clear_autotune() + module.clear_weight_cache() + + def _dense(self, x_flat): + calls["dense"] += 1 + return torch.zeros((x_flat.shape[0], self.out_features), device=x_flat.device, dtype=x_flat.dtype) + + def _fused(self, x_flat): + calls["fused"] += 1 + return torch.zeros((x_flat.shape[0], self.out_features), device=x_flat.device, dtype=x_flat.dtype) + + module._forward_dequant_matmul = _dense.__get__(module, GGUFTorchLinear) + module._forward_fused_k = _fused.__get__(module, GGUFTorchLinear) + + module(case["inputs"]) + assert calls == {"dense": 0, "fused": 1} + + module.gguf_fused_cuda_max_rows = 1 + calls = {"dense": 0, "fused": 0} + module(case["inputs"]) + assert calls == {"dense": 1, "fused": 0} + + +@pytest.mark.parametrize("bits", ["q4_k_m", "q5_k_m", "q6_k"]) +def test_gguf_cpu_fused_forward_matches_dense_baseline(bits: str): + case = _build_rtn_microbench_case(torch.bfloat16, bits=bits) + module = _build_rtn_gguf_module(case).cpu().eval() + inputs = case["inputs"].detach().cpu().to(torch.bfloat16) + + module.gguf_fused_cpu_max_rows = inputs.shape[0] + module.gguf_fused_cpu_min_matrix_elements = 0 + module.clear_weight_cache() + + baseline = module._forward_dequant_matmul(inputs) + fused = module._forward_fused_k(inputs) + + output_stats = _error_stats(baseline.to(torch.float32), fused.to(torch.float32)) + assert output_stats["mae"] < 2e-4 + assert output_stats["max"] < 3e-3 + + +def test_gguf_forward_uses_fused_k_path_for_small_cpu_batches(): + case = _build_rtn_microbench_case(torch.bfloat16, bits="q4_k_m") + module = _build_rtn_gguf_module(case).cpu().eval() + inputs = case["inputs"].detach().cpu().to(torch.bfloat16) + calls = {"dense": 0, "fused": 0} + + module.gguf_fused_cpu_max_rows = inputs.shape[0] + module.gguf_fused_cpu_min_matrix_elements = 0 + module.autotune_enabled = False + module.clear_autotune() + module.clear_weight_cache() + + def _dense(self, x_flat): + calls["dense"] += 1 + return torch.zeros((x_flat.shape[0], self.out_features), device=x_flat.device, dtype=x_flat.dtype) + + def _fused(self, x_flat): + calls["fused"] += 1 + return torch.zeros((x_flat.shape[0], self.out_features), device=x_flat.device, dtype=x_flat.dtype) + + module._forward_dequant_matmul = _dense.__get__(module, GGUFTorchLinear) + module._forward_fused_k = _fused.__get__(module, GGUFTorchLinear) + + module(inputs) + assert calls == {"dense": 0, "fused": 1} + + module.gguf_fused_cpu_max_rows = 1 + calls = {"dense": 0, "fused": 0} + module(inputs) + assert calls == {"dense": 1, "fused": 0} + + +def test_gguf_forward_autotunes_once_per_instance_with_fused_plan_on_cpu(monkeypatch): + case = _build_rtn_microbench_case(torch.bfloat16, bits="q4_k_m") + module = _build_rtn_gguf_module(case).cpu().eval() + inputs = case["inputs"].detach().cpu().to(torch.bfloat16) + + module.gguf_fused_cpu_max_rows = inputs.shape[0] + module.gguf_fused_cpu_min_matrix_elements = 0 + module.autotune_enabled = True + module.clear_autotune() + module.clear_weight_cache() + + calls = {"dense": 0, "fused": 0} + + def _dense(self, x_flat): + calls["dense"] += 1 + return 2.0 + + def _fused(self, x_flat): + calls["fused"] += 1 + return 1.0 + + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_dense_forward", _dense) + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_fused_forward", _fused) + + module(inputs) + + assert calls == {"dense": 1, "fused": 1} + assert module.get_autotune_result() is True + + def _fail(self, x_flat): + raise AssertionError("autotune benchmark should not rerun for cached fused plan") + + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_dense_forward", _fail) + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_fused_forward", _fail) + + module(inputs) + assert module.get_autotune_result() is True + + +def test_gguf_forward_autotunes_once_per_instance_with_dense_plan_on_cpu(monkeypatch): + case = _build_rtn_microbench_case(torch.bfloat16, bits="q4_k_m") + module = _build_rtn_gguf_module(case).cpu().eval() + inputs = case["inputs"].detach().cpu().to(torch.bfloat16) + + module.gguf_fused_cpu_max_rows = inputs.shape[0] + module.gguf_fused_cpu_min_matrix_elements = 0 + module.autotune_enabled = True + module.clear_autotune() + module.clear_weight_cache() + + calls = {"dense": 0, "fused": 0} + + def _dense(self, x_flat): + calls["dense"] += 1 + return 1.0 + + def _fused(self, x_flat): + calls["fused"] += 1 + return 2.0 + + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_dense_forward", _dense) + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_fused_forward", _fused) + + module(inputs) + + assert calls == {"dense": 1, "fused": 1} + assert module.get_autotune_result() is False + + def _fail(self, x_flat): + raise AssertionError("autotune benchmark should not rerun for cached dense plan") + + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_dense_forward", _fail) + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_fused_forward", _fail) + + module(inputs) + assert module.get_autotune_result() is False + + +def test_gguf_forward_autotunes_once_per_module_instance(monkeypatch): + case = _build_rtn_microbench_case(torch.bfloat16, bits="q4_k_m") + inputs = case["inputs"].detach().cpu().to(torch.bfloat16) + module_a = _build_rtn_gguf_module(case).cpu().eval() + module_b = _build_rtn_gguf_module(case).cpu().eval() + + for module in (module_a, module_b): + module.gguf_fused_cpu_max_rows = inputs.shape[0] + module.gguf_fused_cpu_min_matrix_elements = 0 + module.autotune_enabled = True + module.clear_autotune() + module.clear_weight_cache() + + calls = {"dense": 0, "fused": 0} + + def _dense(self, x_flat): + calls["dense"] += 1 + return 2.0 + + def _fused(self, x_flat): + calls["fused"] += 1 + return 1.0 + + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_dense_forward", _dense) + monkeypatch.setattr(GGUFTorchLinear, "_benchmark_fused_forward", _fused) + + module_a(inputs) + + assert calls == {"dense": 1, "fused": 1} + assert module_a.get_autotune_result() is True + + module_b(inputs) + + assert calls == {"dense": 2, "fused": 2} + assert module_b.get_autotune_result() is True + + +@pytest.mark.parametrize( + ("bits", "tensor_qtype"), + [ + ("q4_k_s", "Q4_K"), + ("q4_k_m", "Q4_K"), + ("q5_k_s", "Q5_K"), + ("q5_k_m", "Q5_K"), + ("q6_k", "Q6_K"), + ], +) +def test_rtn_microbench_gguf_export_layout_round_trips_with_reference_dequantizer( + bits: str, + tensor_qtype: str, +): + gguf = pytest.importorskip("gguf") + + dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 + case = _build_rtn_microbench_case(dtype, bits=bits) + module = _build_rtn_gguf_module(case) + + reference = gguf.dequantize( + module.qweight.detach().cpu().numpy(), + getattr(gguf.GGMLQuantizationType, tensor_qtype), + ) + + np.testing.assert_allclose( + module.dequantize_weight().T.detach().cpu().numpy(), + reference[:, : case["in_features"]], + rtol=0.0, + atol=1e-6, + ) + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + ("bits", "weight_mae_max", "output_mae_max", "output_max_max"), + [ + ("q4_k_s", 0.0015, 7e-4, 0.015), + ("q4_k_m", 0.0015, 7e-4, 0.015), + ("q5_k_s", 0.0010, 6e-4, 0.012), + ("q5_k_m", 0.0010, 6e-4, 0.012), + ("q6_k", 7e-4, 5e-4, 0.010), + ], +) +def test_rtn_microbench_gguf_export_stays_close_to_rtn( + dtype: torch.dtype, + bits: str, + weight_mae_max: float, + output_mae_max: float, + output_max_max: float, +): + case = _build_rtn_microbench_case(dtype, bits=bits) + module = _build_rtn_gguf_module(case) + + packed_weight = module.dequantize_weight().T.detach().cpu().to(case["rtn_weight_cpu"].dtype) + weight_stats = _error_stats(case["rtn_weight_cpu"], packed_weight) + + exported_output = module(case["inputs"]) + output_stats = _error_stats(case["rtn_output"], exported_output) + + assert weight_stats["mae"] < weight_mae_max + assert output_stats["mae"] < output_mae_max + assert output_stats["max"] < output_max_max + + +@pytest.mark.parametrize( + ("bits", "output_mae_max", "output_max_max"), + [ + ("q4_k_s", 7e-4, 0.015), + ("q4_k_m", 7e-4, 0.015), + ("q5_k_m", 6e-4, 0.012), + ("q6_k", 5e-4, 0.010), + ], +) +def test_rtn_microbench_gguf_reload_from_state_dict_stays_close_to_rtn( + bits: str, + output_mae_max: float, + output_max_max: float, +): + case = _build_rtn_microbench_case(torch.float16, bits=bits) + module = _build_rtn_gguf_module(case) + + reloaded = GGUFTorchLinear( + bits=bits, + group_size=-1, + sym=True, + desc_act=False, + in_features=case["in_features"], + out_features=case["out_features"], + bias=True, + register_buffers=True, + ) + reloaded.load_state_dict({k: v.detach().cpu() for k, v in module.state_dict().items()}) + reloaded.post_init() + reloaded = reloaded.to(case["device"]).eval() + + output_stats = _error_stats(case["rtn_output"], reloaded(case["inputs"])) + + assert output_stats["mae"] < output_mae_max + assert output_stats["max"] < output_max_max + + +def test_q4_k_s_and_q4_k_m_export_identical_tensor_bytes(): + dtype = torch.float16 if torch.cuda.is_available() else torch.bfloat16 + case = _build_rtn_microbench_case(dtype, bits="q4_k_s") + module_s = _build_rtn_gguf_module(case) + + case_m = dict(case) + case_m["bits"] = GGUFBits.from_string("q4_k_m") + module_m = _build_rtn_gguf_module(case_m) + + assert module_s.gguf_tensor_qtype == "Q4_K" + assert module_m.gguf_tensor_qtype == "Q4_K" + np.testing.assert_array_equal( + module_s.qweight.detach().cpu().numpy(), + module_m.qweight.detach().cpu().numpy(), + ) + torch.testing.assert_close( + module_s.dequantize_weight(), + module_m.dequantize_weight(), + atol=0.0, + rtol=0.0, + ) diff --git a/tests/test_weight_only_config.py b/tests/test_weight_only_config.py new file mode 100644 index 000000000..a04bdbf67 --- /dev/null +++ b/tests/test_weight_only_config.py @@ -0,0 +1,427 @@ +# SPDX-FileCopyrightText: 2026 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import fields +from inspect import signature + +import pytest +import torch + +from gptqmodel.quantization.config import ( + METHOD, + AutoModuleDecoderConfig, + BaseQuantizeConfig, + BitsAndBytesConfig, + GGUFBits, + GGUFConfig, + GPTQConfig, + QuantizeConfig, + RTNConfig, + SmootherConfig, + SmoothMAD, + TensorParallelPadderConfig, +) + + +def test_quantize_config_weight_only_round_trip(): + smooth = SmoothMAD(k=1.75) + cfg = RTNConfig( + bits=4, + group_size=128, + smooth=smooth, + ) + + assert cfg.uses_weight_only_lifecycle() is True + assert cfg.requires_calibration_dataset() is False + assert isinstance(cfg.smooth, SmoothMAD) + assert cfg.smooth.k == pytest.approx(1.75) + + payload = cfg.to_dict() + assert "method" not in payload["meta"]["weight_only"] + assert payload["meta"]["weight_only"]["smooth"]["type"] == "mad" + assert payload["meta"]["weight_only"]["smooth"]["k"] == pytest.approx(1.75) + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded, RTNConfig) + assert isinstance(reloaded.smooth, SmoothMAD) + assert reloaded.smooth.k == pytest.approx(1.75) + assert reloaded.uses_weight_only_lifecycle() is True + assert reloaded.requires_calibration_dataset() is False + + +def test_rtn_quantize_config_defaults_to_no_smoother(): + cfg = RTNConfig(bits=4, group_size=128) + + assert isinstance(cfg, BaseQuantizeConfig) + assert not isinstance(cfg, GPTQConfig) + assert cfg.uses_weight_only_lifecycle() is True + assert cfg.smooth is None + assert cfg.export_quant_method() is not None + + payload = cfg.to_dict() + assert payload["meta"]["weight_only"]["smooth"] is None + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded, RTNConfig) + assert reloaded.smooth is None + + +def test_rtn_quantize_config_supports_awq_export_round_trip(): + smooth = SmoothMAD(k=1.5) + cfg = RTNConfig( + bits=4, + group_size=128, + format="gemm", + smooth=smooth, + ) + + assert cfg.format == "gemm" + assert cfg.export_quant_method().value == "awq" + + payload = cfg.to_dict() + reloaded = QuantizeConfig.from_quant_config(payload) + + assert isinstance(reloaded, RTNConfig) + assert reloaded.format == cfg.format + assert reloaded.export_quant_method() == cfg.export_quant_method() + assert isinstance(reloaded.smooth, SmoothMAD) + assert reloaded.smooth.k == pytest.approx(1.5) + + +def test_gguf_quantize_config_round_trip(): + smooth = SmoothMAD(k=1.25) + cfg = GGUFConfig( + bits=4, + smoother=smooth, + ) + + assert cfg.format == "q_0" + assert cfg.uses_weight_only_lifecycle() is True + assert cfg.requires_calibration_dataset() is False + assert cfg.bits == 4 + assert isinstance(cfg.runtime_bits, GGUFBits) + assert cfg.runtime_bits == "q4_0" + assert cfg.runtime_bits.bits == 4 + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == "0" + assert cfg.runtime_bits.quality is None + assert cfg.group_size == -1 + assert cfg.desc_act is False + assert cfg.quant_method == METHOD.GGUF + assert cfg.export_quant_method() == METHOD.GGUF + assert isinstance(cfg.smoother, SmootherConfig) + assert isinstance(cfg.smooth, SmoothMAD) + assert cfg.smooth.k == pytest.approx(1.25) + + payload = cfg.to_dict() + assert payload["bits"] == 4 + assert payload["method"] == "gguf" + assert payload["quant_method"] == "gguf" + assert payload["format"] == "q_0" + assert payload["checkpoint_format"] == "q_0" + assert "group_size" not in payload + assert "desc_act" not in payload + assert "pack_dtype" not in payload + assert "weight_only" not in payload["meta"] + assert payload["meta"]["preprocessors"][0]["code"] == "smoother" + assert payload["meta"]["preprocessors"][0]["smooth"]["type"] == "mad" + reloaded = QuantizeConfig.from_quant_config(payload) + + assert isinstance(reloaded, GGUFConfig) + assert reloaded.format == cfg.format + assert reloaded.bits == 4 + assert reloaded.runtime_bits == "q4_0" + assert reloaded.quant_method == METHOD.GGUF + assert reloaded.export_quant_method() == cfg.export_quant_method() + assert isinstance(reloaded.smooth, SmoothMAD) + assert reloaded.smooth.k == pytest.approx(1.25) + + +def test_gguf_quantize_config_hides_non_gguf_constructor_args(): + with pytest.raises(TypeError): + GGUFConfig(bits=4, format="q_k_m", group_size=128) + + with pytest.raises(TypeError): + GGUFConfig(bits=4, format="q_k_m", desc_act=True) + + +def test_bitsandbytes_quantize_config_round_trip_4bit(): + cfg = BitsAndBytesConfig( + bits=4, + format="nf4", + block_size=128, + compress_statistics=False, + smoother=SmoothMAD(k=1.2), + ) + + assert cfg.quant_method == METHOD.BITSANDBYTES + assert cfg.format == "nf4" + assert cfg.bits == 4 + assert cfg.block_size == 128 + assert cfg.compress_statistics is False + assert cfg.uses_weight_only_lifecycle() is True + + payload = cfg.to_dict() + assert payload["method"] == "bitsandbytes" + assert payload["quant_method"] == "bitsandbytes" + assert payload["format"] == "nf4" + assert payload["checkpoint_format"] == "nf4" + assert payload["block_size"] == 128 + assert payload["compress_statistics"] is False + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded, BitsAndBytesConfig) + assert reloaded.bits == 4 + assert reloaded.format == "nf4" + assert reloaded.block_size == 128 + assert reloaded.compress_statistics is False + assert isinstance(reloaded.smooth, SmoothMAD) + + +def test_bitsandbytes_quantize_config_round_trip_8bit(): + cfg = BitsAndBytesConfig(bits=8) + + assert cfg.quant_method == METHOD.BITSANDBYTES + assert cfg.format == "int8" + assert cfg.bits == 8 + assert cfg.uses_weight_only_lifecycle() is True + + payload = cfg.to_dict() + reloaded = QuantizeConfig.from_quant_config(payload) + + assert isinstance(reloaded, BitsAndBytesConfig) + assert reloaded.bits == 8 + assert reloaded.format == "int8" + assert reloaded.quant_method == METHOD.BITSANDBYTES + + +def test_gguf_config_registers_smoother_preprocessor(): + cfg = GGUFConfig( + bits=4, + format="q_k_m", + preprocessors=[SmootherConfig(smooth=SmoothMAD(k=1.9))], + ) + + assert isinstance(cfg.smoother, SmootherConfig) + assert isinstance(cfg.smooth, SmoothMAD) + assert cfg.smooth.k == pytest.approx(1.9) + assert len(cfg.preprocessors) == 1 + assert cfg.preprocessors[0].code == "smoother" + + +def test_gguf_config_does_not_auto_register_tensor_parallel_padder(): + cfg = GGUFConfig(bits=4, format="q_k_m") + + assert cfg.preprocessors == [] + + +def test_gguf_config_registers_auto_module_decoder_preprocessor(): + cfg = GGUFConfig( + bits=4, + format="q_k_m", + preprocessors=[ + AutoModuleDecoderConfig(target_dtype=torch.float16) + ], + ) + + assert len(cfg.preprocessors) == 1 + decoder = cfg.preprocessors[0] + assert isinstance(decoder, AutoModuleDecoderConfig) + assert decoder.code == "auto_module_decoder" + assert decoder.source_dtype == "auto" + assert decoder.target_dtype == torch.float16 + + payload = cfg.to_dict() + assert payload["meta"]["preprocessors"][0]["code"] == "auto_module_decoder" + assert payload["meta"]["preprocessors"][0]["target_dtype"] == "float16" + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded.preprocessors[0], AutoModuleDecoderConfig) + assert reloaded.preprocessors[0].target_dtype == torch.float16 + + +def test_gguf_config_registers_tensor_parallel_padder_preprocessor(): + cfg = GGUFConfig( + bits=4, + format="q_k_m", + preprocessors=[TensorParallelPadderConfig()], + ) + + assert len(cfg.preprocessors) == 1 + padder = cfg.preprocessors[0] + assert isinstance(padder, TensorParallelPadderConfig) + assert padder.code == "tensor_parallel_padder" + + payload = cfg.to_dict() + assert payload["meta"]["preprocessors"][0]["code"] == "tensor_parallel_padder" + + reloaded = QuantizeConfig.from_quant_config(payload) + assert isinstance(reloaded.preprocessors[0], TensorParallelPadderConfig) + + +def test_auto_module_decoder_config_does_not_expose_code_as_init_field(): + decoder_fields = {field.name for field in fields(AutoModuleDecoderConfig)} + + assert "code" not in signature(AutoModuleDecoderConfig).parameters + assert "code" not in decoder_fields + assert AutoModuleDecoderConfig().to_dict()["code"] == "auto_module_decoder" + + +def test_tensor_parallel_padder_config_does_not_expose_code_as_init_field(): + padder_fields = {field.name for field in fields(TensorParallelPadderConfig)} + + assert "code" not in signature(TensorParallelPadderConfig).parameters + assert "code" not in padder_fields + assert TensorParallelPadderConfig().to_dict()["code"] == "tensor_parallel_padder" + + +@pytest.mark.parametrize( + ("bits", "format", "bit_width", "variant", "quality"), + [ + (4, "q_k_m", 4, "k", "m"), + (5, "q_k_s", 5, "k", "s"), + (5, "q_k_m", 5, "k", "m"), + (6, "q_k", 6, "k", None), + ], +) +def test_rtn_quantize_config_supports_gguf_bits_round_trip( + bits: int, + format: str, + bit_width: int, + variant: str, + quality: str | None, +): + cfg = GGUFConfig( + bits=bits, + format=format, + smoother=SmoothMAD(k=1.25), + ) + + payload = cfg.to_dict() + assert payload["bits"] == bit_width + assert payload["format"] == format + assert cfg.bits == bit_width + assert cfg.runtime_bits.bits == bit_width + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == variant + assert cfg.runtime_bits.quality == quality + + reloaded = QuantizeConfig.from_quant_config(payload) + + assert isinstance(reloaded, GGUFConfig) + assert reloaded.quant_method == METHOD.GGUF + assert reloaded.format == format + assert reloaded.bits == bit_width + assert reloaded.runtime_bits.bits == bit_width + assert reloaded.runtime_bits.version == "q" + assert reloaded.runtime_bits.variant == variant + assert reloaded.runtime_bits.quality == quality + + +def test_rtn_quantize_config_supports_structured_gguf_bits_round_trip(): + cfg = GGUFConfig( + bits=GGUFBits(bits=4, version="q", variant="k", quality="s"), + smoother=SmoothMAD(k=1.25), + ) + + assert cfg.bits == 4 + assert cfg.format == "q_k_s" + assert isinstance(cfg.runtime_bits, GGUFBits) + assert cfg.runtime_bits == "q4_k_s" + assert cfg.runtime_bits.bits == 4 + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == "k" + assert cfg.runtime_bits.quality == "s" + + payload = cfg.to_dict() + assert payload["bits"] == 4 + assert payload["format"] == "q_k_s" + + reloaded = QuantizeConfig.from_quant_config(payload) + assert reloaded.bits == 4 + assert reloaded.format == "q_k_s" + assert reloaded.runtime_bits == "q4_k_s" + + +def test_gguf_bits_string_parser_round_trip(): + bits = GGUFBits.from_string("q4_k_s") + + assert isinstance(bits, GGUFBits) + assert bits.bits == 4 + assert bits.version == "q" + assert bits.variant == "k" + assert bits.quality == "s" + assert bits.to_string() == "q4_k_s" + assert str(bits) == "q4_k_s" + assert int(bits) == 4 + + +def test_weight_only_payload_dispatches_to_rtn(): + cfg = QuantizeConfig( + bits=4, + group_size=128, + weight_only={ + "method": "rtn", + "smooth": {"type": "mad", "k": 2.0}, + }, + ) + + assert isinstance(cfg, RTNConfig) + assert isinstance(cfg.smooth, SmoothMAD) + assert cfg.smooth.k == pytest.approx(2.0) + + +def test_weight_only_payload_dispatches_to_rtn_gguf(): + cfg = QuantizeConfig( + bits=4, + weight_only={ + "method": "gguf", + "smooth": {"type": "mad", "k": 1.5}, + }, + ) + + assert isinstance(cfg, GGUFConfig) + assert cfg.format == "q_0" + assert cfg.bits == 4 + assert cfg.runtime_bits == "q4_0" + assert cfg.runtime_bits.bits == 4 + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == "0" + assert isinstance(cfg.smooth, SmoothMAD) + assert cfg.smooth.k == pytest.approx(1.5) + + +def test_weight_only_payload_dispatches_to_rtn_gguf_with_qtype(): + cfg = QuantizeConfig( + bits="q6_k", + weight_only={ + "method": "gguf", + }, + ) + + assert isinstance(cfg, GGUFConfig) + assert cfg.format == "q_k" + assert cfg.bits == 6 + assert cfg.runtime_bits == "q6_k" + assert cfg.runtime_bits.bits == 6 + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == "k" + assert cfg.runtime_bits.quality is None + + +def test_weight_only_payload_dispatches_legacy_gguf_qtype_to_bits(): + cfg = QuantizeConfig( + bits=6, + weight_only={ + "method": "gguf", + "gguf_qtype": "q6_k", + }, + ) + + assert isinstance(cfg, GGUFConfig) + assert cfg.format == "q_k" + assert cfg.bits == 6 + assert cfg.runtime_bits == "q6_k" + assert cfg.runtime_bits.bits == 6 + assert cfg.runtime_bits.version == "q" + assert cfg.runtime_bits.variant == "k" diff --git a/tests/test_writer_attention.py b/tests/test_writer_attention.py index 71ea188e4..4a7bc5e55 100644 --- a/tests/test_writer_attention.py +++ b/tests/test_writer_attention.py @@ -17,6 +17,7 @@ class _DummyKernel: class _DummyQuantizeConfig: format = FORMAT.GPTQ + checkpoint_format = FORMAT.GPTQ quant_method = METHOD.GPTQ damp_percent = 0.0 damp_auto_increment = 0.0