diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py
index 71a6d5356..a035f31da 100644
--- a/gptqmodel/looper/gptq_processor.py
+++ b/gptqmodel/looper/gptq_processor.py
@@ -48,6 +48,9 @@ def clone_gptq_config_for_module(
     qcfg_clone.bits = qcfg.dynamic_get(module_full_name, "bits", qcfg_clone.bits)
     qcfg_clone.sym = qcfg.dynamic_get(module_full_name, "sym", qcfg_clone.sym)
     qcfg_clone.mse = qcfg.dynamic_get(module_full_name, "mse", qcfg_clone.mse)
+    qcfg_clone.activation_weighted_mse = qcfg.dynamic_get(
+        module_full_name, "activation_weighted_mse", qcfg_clone.activation_weighted_mse
+    )
     qcfg_clone.group_size = qcfg.dynamic_get(module_full_name, "group_size", qcfg_clone.group_size)
 
     desc_act_override = qcfg.dynamic_get(module_full_name, "desc_act", None)
diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py
index a1c60fe96..0a2091fea 100644
--- a/gptqmodel/quantization/config.py
+++ b/gptqmodel/quantization/config.py
@@ -2633,6 +2633,75 @@ def save_pretrained(self, save_dir: str, **kwargs):
             log.info(f"Saved Quantize Config: \n{json_str}")
             f.write(json_str)
 
+    @classmethod
+    def gptq_pro(
+        cls,
+        *,
+        bits: int = 4,
+        group_size: int = 128,
+        sym: bool = True,
+        mse: float = 2.0,
+        damp_percent: float = 0.05,
+        damp_auto_increment: float = 0.01,
+        gptaq_alpha: Optional[float] = None,
+        gptaq_device: Union[str, torch.device] = "auto",
+        failsafe: Optional[Union[Fallback, Dict[str, Any], str, int, float]] = None,
+        **kwargs,
+    ) -> "QuantizeConfig":
+        """
+        Build a speed-preserving GPTQ quality profile.
+
+        The returned config keeps the standard GPTQ output format so existing
+        GPTQ/Marlin/ExLlama/VLLM kernels continue to run unchanged, while
+        enabling offline-only quality improvements already implemented in
+        GPTQModel such as GAR (`act_group_aware`), MSE scale search, and
+        adaptive damping for badly conditioned Hessian blocks.
+        """
+        if "quant_method" in kwargs and kwargs["quant_method"] != METHOD.GPTQ:
+            raise ValueError("QuantizeConfig.gptq_pro() only supports `quant_method=METHOD.GPTQ`.")
+        if METHOD_FIELD_CODE in kwargs and kwargs[METHOD_FIELD_CODE] != METHOD.GPTQ:
+            raise ValueError("QuantizeConfig.gptq_pro() only supports `method=METHOD.GPTQ`.")
+
+        if "format" in kwargs and kwargs["format"] not in QUANT_METHOD_FORMAT_MAPPING[METHOD.GPTQ]:
+            raise ValueError("QuantizeConfig.gptq_pro() only supports GPTQ-compatible output formats.")
+
+        fallback = kwargs.pop("fallback", None)
+        if fallback is None and "failsafe" in kwargs:
+            fallback = kwargs.pop("failsafe")
+        if fallback is None:
+            fallback = failsafe
+
+        # Install the default only when the caller supplied neither `fallback`
+        # nor `failsafe`, so an explicit `fallback=` is never overwritten.
+        if fallback is None:
+            fallback = Fallback(
+                strategy=FallbackStrategy.RTN,
+                threshold="0.5%",
+                smooth=SmoothMSE(steps=32, maxshrink=0.9),
+            )
+
+        gptaq = kwargs.pop("gptaq", None)
+        if gptaq is None and gptaq_alpha is not None:
+            gptaq = GPTAQConfig(alpha=gptaq_alpha, device=gptaq_device)
+
+        defaults = {
+            "bits": bits,
+            "group_size": group_size,
+            "sym": sym,
+            METHOD_FIELD_CODE: METHOD.GPTQ,
+            "format": FORMAT.GPTQ,
+            "desc_act": False,
+            "act_group_aware": True,
+            "mse": mse,
+            "activation_weighted_mse": True,
+            "damp_percent": damp_percent,
+            "damp_auto_increment": damp_auto_increment,
+            "fallback": fallback,
+            "gptaq": gptaq,
+        }
+        defaults.update(kwargs)
+        return cls(**defaults)
+
     @classmethod
     def from_quant_config(cls, quantize_cfg, format: str = None):
         valid_formats = set(FORMAT)
@@ -2761,6 +2830,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None):
             "offload_to_disk_path": "offload_to_disk_path",
             "pack_impl": "pack_impl",
             "mse": "mse",
+            "activation_weighted_mse": "activation_weighted_mse",
             "mock_quantization": "mock_quantization",
             "act_group_aware": "act_group_aware",
             "true_sequential": "true_sequential",
@@ -3016,6 +3086,7 @@ class GPTQConfig(PreProcessorConfig):
     act_group_aware: Optional[bool] = field(default=None)
     static_groups: bool = field(default=False)
     mse: float = field(default=0.0)
+    activation_weighted_mse: bool = field(default=False)
     gptaq: Optional[GPTAQConfig] = field(default=None)
     foem: Optional[FOEMConfig] = field(default=None)
     mock_quantization: bool = field(
@@ -3100,6 +3171,7 @@ def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None:
         }
 
         meta_payload["mse"] = self.mse
+        meta_payload["activation_weighted_mse"] = self.activation_weighted_mse
         meta_payload["mock_quantization"] = self.mock_quantization
         meta_payload["act_group_aware"] = self.act_group_aware
         meta_payload["hessian"] = {
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index efb9cc706..4c4d59db1 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -961,16 +961,22 @@ def quantize(
         W = self.module_copy.to(device=self.H.device)
         del self.module_copy
 
-        self.quantizer.find_params(W, weight=True)
-
         # H = self.H.to(device=self.H.device)
+        activation_importance = None
 
         if use_hessian:
             # Replace NaN/Inf in H before processing (can occur with some model architectures)
             self.H.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
             dead = torch.diag(self.H) == 0
             self.H[dead, dead] = 1
             W[:, dead] = 0
+            if self.qcfg.activation_weighted_mse:
+                activation_importance = torch.diag(self.H).clamp_min(0).to(device=W.device, dtype=W.dtype)
+                importance_mean = activation_importance.mean()
+                if torch.isfinite(importance_mean) and importance_mean > 0:
+                    activation_importance = activation_importance / importance_mean
+                else:
+                    activation_importance = None
 
         # g_idx = []
         scale = []
@@ -983,7 +989,14 @@ def quantize(
             groups = []
             for i in range(0, self.columns, self.qcfg.group_size):
                 quantizer = copy.deepcopy(self.quantizer)
-                quantizer.find_params(W[:, i: (i + self.qcfg.group_size)], weight=True)
+                group_importance = None
+                if activation_importance is not None:
+                    group_importance = activation_importance[i: (i + self.qcfg.group_size)]
+                quantizer.find_params(
+                    W[:, i: (i + self.qcfg.group_size)],
+                    weight=True,
+                    importance=group_importance,
+                )
 
                 scale.append(quantizer.scale)
                 zero.append(quantizer.zero)
@@ -994,6 +1007,8 @@ def quantize(
             try:
                 W = W[:, perm]
                 self.H = self.H[perm][:, perm]
+                if activation_importance is not None:
+                    activation_importance = activation_importance[perm]
             except RuntimeError as exc:
                 if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
                     raise
@@ -1004,7 +1019,8 @@ def quantize(
                 perm = perm.to(device=cpu_device)
                 W = W.to(device=cpu_device)[:, perm]
                 self.H = self.H.to(device=cpu_device)[perm][:, perm]
-                self.quantizer.find_params(W, weight=True)
+                if activation_importance is not None:
+                    activation_importance = activation_importance.to(device=cpu_device)[perm]
             invperm = torch.argsort(perm)
 
         elif self.qcfg.act_group_aware and use_hessian:
@@ -1022,6 +1038,8 @@ def quantize(
             try:
                 W = W[:, final_perm]
                 self.H = self.H[final_perm][:, final_perm]
+                if activation_importance is not None:
+                    activation_importance = activation_importance[final_perm]
             except RuntimeError as exc:
                 if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
                     raise
@@ -1032,7 +1050,10 @@ def quantize(
                 final_perm = final_perm.to(device=cpu_device)
                 W = W.to(device=cpu_device)[:, final_perm]
                 self.H = self.H.to(device=cpu_device)[final_perm][:, final_perm]
-                self.quantizer.find_params(W, weight=True)
+                if activation_importance is not None:
+                    activation_importance = activation_importance.to(device=cpu_device)[final_perm]
+
+        self.quantizer.find_params(W, weight=True, importance=activation_importance)
 
         if use_hessian:
             try:
@@ -1048,7 +1069,9 @@ def quantize(
                 cpu_device = torch.device("cpu")
                 self.H = self.H.to(device=cpu_device)
                 W = W.to(device=cpu_device)
-                self.quantizer.find_params(W, weight=True)
+                if activation_importance is not None:
+                    activation_importance = activation_importance.to(device=cpu_device)
+                self.quantizer.find_params(W, weight=True, importance=activation_importance)
                 Hinv, damp = self.hessian_inverse(self.H)
         else:
             Hinv, damp = None, 0.0
@@ -1073,7 +1096,14 @@ def quantize(
             for group_start in group_start_cols:
                 group_end = min(group_start + self.qcfg.group_size, self.columns)
                 if group_start < group_end:
-                    self.quantizer.find_params(W[:, group_start:group_end], weight=True)
+                    group_importance = None
+                    if activation_importance is not None:
+                        group_importance = activation_importance[group_start:group_end]
+                    self.quantizer.find_params(
+                        W[:, group_start:group_end],
+                        weight=True,
+                        importance=group_importance,
+                    )
                     scale.append(self.quantizer.scale)
                     zero.append(self.quantizer.zero)
                     now_idx += 1
@@ -1182,7 +1212,16 @@ def quantize(
                 if self.qcfg.group_size != -1:
                     if not self.qcfg.static_groups:
                         if (i1 + i) % self.qcfg.group_size == 0:
-                            self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + self.qcfg.group_size)], weight=True)
+                            group_start = i1 + i
+                            group_end = group_start + self.qcfg.group_size
+                            group_importance = None
+                            if activation_importance is not None:
+                                group_importance = activation_importance[group_start:group_end]
+                            self.quantizer.find_params(
+                                W[:, group_start:group_end],
+                                weight=True,
+                                importance=group_importance,
+                            )
 
                         if ((i1 + i) // self.qcfg.group_size) - now_idx == -1:
                             scale.append(self.quantizer.scale)
diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py
index fa99c5542..07d4dca04 100644
--- a/gptqmodel/quantization/quantizer.py
+++ b/gptqmodel/quantization/quantizer.py
@@ -68,7 +68,7 @@ def configure(
         if trits:
             self.maxq = torch.tensor(-1)
 
-    def find_params(self, x, weight=False):
+    def find_params(self, x, weight=False, importance: torch.Tensor = None):
         dev = x.device
         self.maxq = self.maxq.to(dev)
 
@@ -116,6 +116,37 @@ def find_params(self, x, weight=False):
 
         mse = float(getattr(self.qcfg, "mse", 0.0) or 0.0)
         if mse > 0.0:
+            importance_weights = None
+            if getattr(self.qcfg, "activation_weighted_mse", False) and importance is not None:
+                importance_weights = torch.nan_to_num(
+                    importance.to(device=dev, dtype=x.dtype),
+                    nan=0.0,
+                    posinf=0.0,
+                    neginf=0.0,
+                ).clamp_min_(0)
+                if importance_weights.ndim == 1:
+                    importance_weights = importance_weights.unsqueeze(0)
+                if importance_weights.shape[-1] != x.shape[1]:
+                    raise ValueError(
+                        "Quantizer.find_params(): importance parameter shape mismatch. "
+                        f"Expected columns: {x.shape[1]}, got: {importance_weights.shape[-1]}."
+                    )
+                if importance_weights.shape[0] == 1 and x.shape[0] != 1:
+                    importance_weights = importance_weights.expand(x.shape[0], -1)
+                elif importance_weights.shape[0] != x.shape[0]:
+                    raise ValueError(
+                        "Quantizer.find_params(): importance parameter row count mismatch. "
+                        f"Expected 1 or {x.shape[0]} rows, got: {importance_weights.shape[0]}."
+                    )
+
+                importance_mean = importance_weights.mean(dim=1, keepdim=True)
+                valid = torch.isfinite(importance_mean) & (importance_mean > 0)
+                if torch.any(valid):
+                    normalized_weights = importance_weights / importance_mean.clamp_min(1e-8)
+                    importance_weights = torch.where(valid, normalized_weights, torch.ones_like(importance_weights))
+                else:
+                    importance_weights = None
+
             best = torch.full([x.shape[0]], float("inf"), device=dev)
             for i in range(int(self.maxshrink * self.grid)):
                 p = 1 - i / self.grid
@@ -131,6 +162,8 @@ def find_params(self, x, weight=False):
                 q -= x
                 q.abs_()
                 q.pow_(mse)
+                if importance_weights is not None:
+                    q.mul_(importance_weights)
                 err = torch.sum(q, 1)
                 tmp = err < best
                 if torch.any(tmp):
diff --git a/tests/qcfg/test_gptq_pro.py b/tests/qcfg/test_gptq_pro.py
new file mode 100644
index 000000000..56adbb44e
--- /dev/null
+++ b/tests/qcfg/test_gptq_pro.py
@@ -0,0 +1,56 @@
+import torch
+
+from gptqmodel.quantization import QuantizeConfig
+from gptqmodel.quantization.quantizer import Quantizer
+
+
+def _calculate_weighted_squared_error(
+    quantizer: Quantizer,
+    weights: torch.Tensor,
+    importance: torch.Tensor,
+) -> torch.Tensor:
+    dequant = quantizer.quantize(weights)
+    return ((dequant - weights).pow(2) * importance.view(1, -1)).sum()
+
+
+def test_gptq_pro_enables_activation_weighted_mse():
+    cfg = QuantizeConfig.gptq_pro()
+
+    assert cfg.activation_weighted_mse is True
+    assert cfg.act_group_aware is True
+    assert cfg.desc_act is False
+
+
+def test_activation_weighted_mse_prioritizes_salient_columns():
+    weights = torch.tensor([[0.1, 0.45, 0.8, 1.2]], dtype=torch.float32)
+    importance = torch.tensor([1.0, 1.0, 8.0, 8.0], dtype=torch.float32)
+
+    baseline = Quantizer(
+        QuantizeConfig(bits=4, sym=False, mse=2.0, act_group_aware=False, desc_act=False),
+    )
+    baseline.configure(perchannel=True)
+    baseline.find_params(weights, weight=True)
+
+    weighted = Quantizer(
+        QuantizeConfig(
+            bits=4,
+            sym=False,
+            mse=2.0,
+            activation_weighted_mse=True,
+            act_group_aware=False,
+            desc_act=False,
+        ),
+    )
+    weighted.configure(perchannel=True)
+    weighted.find_params(weights, weight=True, importance=importance)
+
+    assert not torch.allclose(weighted.scale, baseline.scale)
+    assert _calculate_weighted_squared_error(
+        weighted,
+        weights,
+        importance,
+    ) < _calculate_weighted_squared_error(
+        baseline,
+        weights,
+        importance,
+    )
diff --git a/tests/test_serialization.py b/tests/test_serialization.py
index 8a4a1d8ae..e0bcec20e 100644
--- a/tests/test_serialization.py
+++ b/tests/test_serialization.py
@@ -86,6 +86,7 @@ def test_quantize_config_meta_only_fields_serialization(self):
             offload_to_disk_path="./offload-test",
             pack_impl="gpu",
             mse=0.125,
+            activation_weighted_mse=True,
             mock_quantization=True,
             hessian=HessianConfig(
                 chunk_size=256,
@@ -109,6 +110,7 @@ def test_quantize_config_meta_only_fields_serialization(self):
             "offload_to_disk_path",
             "pack_impl",
             "mse",
+            "activation_weighted_mse",
             "mock_quantization",
             "act_group_aware",
             "hessian",
@@ -127,6 +129,7 @@ def test_quantize_config_meta_only_fields_serialization(self):
         self.assertEqual(meta["offload_to_disk_path"], cfg.offload_to_disk_path)
         self.assertEqual(meta["pack_impl"], cfg.pack_impl)
         self.assertEqual(meta["mse"], cfg.mse)
+        self.assertEqual(meta["activation_weighted_mse"], cfg.activation_weighted_mse)
         self.assertEqual(meta["mock_quantization"], cfg.mock_quantization)
         self.assertEqual(meta["act_group_aware"], cfg.act_group_aware)
         self.assertEqual(meta["hessian"]["chunk_size"], cfg.hessian.chunk_size)