Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gptqmodel/looper/gptq_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def clone_gptq_config_for_module(
qcfg_clone.bits = qcfg.dynamic_get(module_full_name, "bits", qcfg_clone.bits)
qcfg_clone.sym = qcfg.dynamic_get(module_full_name, "sym", qcfg_clone.sym)
qcfg_clone.mse = qcfg.dynamic_get(module_full_name, "mse", qcfg_clone.mse)
qcfg_clone.activation_weighted_mse = qcfg.dynamic_get(
module_full_name, "activation_weighted_mse", qcfg_clone.activation_weighted_mse
)

qcfg_clone.group_size = qcfg.dynamic_get(module_full_name, "group_size", qcfg_clone.group_size)
desc_act_override = qcfg.dynamic_get(module_full_name, "desc_act", None)
Expand Down
70 changes: 70 additions & 0 deletions gptqmodel/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2633,6 +2633,73 @@ def save_pretrained(self, save_dir: str, **kwargs):
log.info(f"Saved Quantize Config: \n{json_str}")
f.write(json_str)

@classmethod
def gptq_pro(
    cls,
    *,
    bits: int = 4,
    group_size: int = 128,
    sym: bool = True,
    mse: float = 2.0,
    damp_percent: float = 0.05,
    damp_auto_increment: float = 0.01,
    gptaq_alpha: Optional[float] = None,
    gptaq_device: Union[str, torch.device] = "auto",
    failsafe: Optional[Union[Fallback, Dict[str, Any], str, int, float]] = None,
    **kwargs,
) -> "QuantizeConfig":
    """
    Build a speed-preserving GPTQ quality profile.

    The returned config keeps the standard GPTQ output format so existing
    GPTQ/Marlin/ExLlama/VLLM kernels continue to run unchanged, while
    enabling offline-only quality improvements already implemented in
    GPTQModel such as GAR (`act_group_aware`), MSE scale search, and
    adaptive damping for badly conditioned Hessian blocks.

    Args:
        bits, group_size, sym, mse, damp_percent, damp_auto_increment:
            Forwarded to the config field of the same name.
        gptaq_alpha: When set, and no `gptaq` keyword is supplied, a
            `GPTAQConfig(alpha=gptaq_alpha, device=gptaq_device)` is built.
        gptaq_device: Device passed to the `GPTAQConfig` built from
            `gptaq_alpha`; unused otherwise.
        failsafe: Used as the fallback setting when no `fallback` keyword
            is supplied.
        **kwargs: Extra config fields. `fallback` and `gptaq` are consumed
            here; everything else is applied after the profile defaults
            and therefore overrides them.

    Returns:
        A new config instance built via `cls(**defaults)`.

    Raises:
        ValueError: If `quant_method` / the method field is given and is
            not `METHOD.GPTQ`, or if `format` is given and is not listed
            in `QUANT_METHOD_FORMAT_MAPPING[METHOD.GPTQ]`.
    """
    if "quant_method" in kwargs and kwargs["quant_method"] != METHOD.GPTQ:
        raise ValueError("QuantizeConfig.gptq_pro() only supports `quant_method=METHOD.GPTQ`.")
    if METHOD_FIELD_CODE in kwargs and kwargs[METHOD_FIELD_CODE] != METHOD.GPTQ:
        raise ValueError("QuantizeConfig.gptq_pro() only supports `method=METHOD.GPTQ`.")

    if "format" in kwargs and kwargs["format"] not in QUANT_METHOD_FORMAT_MAPPING[METHOD.GPTQ]:
        raise ValueError("QuantizeConfig.gptq_pro() only supports GPTQ-compatible output formats.")

    # Resolution order: explicit `fallback` keyword, then `failsafe`, then
    # the profile default. `failsafe` is a declared keyword-only parameter,
    # so it can never appear inside **kwargs and needs no lookup there.
    fallback = kwargs.pop("fallback", None)
    if fallback is None:
        fallback = failsafe

    # Only install the profile default when the caller supplied nothing.
    # Testing `fallback` (not `failsafe`) keeps an explicit `fallback=...`
    # from being overwritten when `failsafe` is left at its default.
    if fallback is None:
        fallback = Fallback(
            strategy=FallbackStrategy.RTN,
            threshold="0.5%",
            smooth=SmoothMSE(steps=32, maxshrink=0.9),
        )

    # An explicit `gptaq` keyword wins over the alpha/device shorthand.
    gptaq = kwargs.pop("gptaq", None)
    if gptaq is None and gptaq_alpha is not None:
        gptaq = GPTAQConfig(alpha=gptaq_alpha, device=gptaq_device)

    defaults = {
        "bits": bits,
        "group_size": group_size,
        "sym": sym,
        METHOD_FIELD_CODE: METHOD.GPTQ,
        "format": FORMAT.GPTQ,
        "desc_act": False,
        "act_group_aware": True,
        "mse": mse,
        "activation_weighted_mse": True,
        "damp_percent": damp_percent,
        "damp_auto_increment": damp_auto_increment,
        "fallback": fallback,
        "gptaq": gptaq,
    }
    # Remaining caller keywords are applied last so they override the profile.
    defaults.update(kwargs)
    return cls(**defaults)

@classmethod
def from_quant_config(cls, quantize_cfg, format: str = None):
valid_formats = set(FORMAT)
Expand Down Expand Up @@ -2761,6 +2828,7 @@ def from_quant_config(cls, quantize_cfg, format: str = None):
"offload_to_disk_path": "offload_to_disk_path",
"pack_impl": "pack_impl",
"mse": "mse",
"activation_weighted_mse": "activation_weighted_mse",
"mock_quantization": "mock_quantization",
"act_group_aware": "act_group_aware",
"true_sequential": "true_sequential",
Expand Down Expand Up @@ -3016,6 +3084,7 @@ class GPTQConfig(PreProcessorConfig):
act_group_aware: Optional[bool] = field(default=None)
static_groups: bool = field(default=False)
mse: float = field(default=0.0)
activation_weighted_mse: bool = field(default=False)
gptaq: Optional[GPTAQConfig] = field(default=None)
foem: Optional[FOEMConfig] = field(default=None)
mock_quantization: bool = field(
Expand Down Expand Up @@ -3100,6 +3169,7 @@ def _update_meta_payload(self, meta_payload: Dict[str, Any]) -> None:
}

meta_payload["mse"] = self.mse
meta_payload["activation_weighted_mse"] = self.activation_weighted_mse
meta_payload["mock_quantization"] = self.mock_quantization
meta_payload["act_group_aware"] = self.act_group_aware
meta_payload["hessian"] = {
Expand Down
55 changes: 47 additions & 8 deletions gptqmodel/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,16 +961,22 @@ def quantize(
W = self.module_copy.to(device=self.H.device)
del self.module_copy

self.quantizer.find_params(W, weight=True)

# H = self.H.to(device=self.H.device)

activation_importance = None
if use_hessian:
# Replace NaN/Inf in H before processing (can occur with some model architectures)
self.H.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0)
dead = torch.diag(self.H) == 0
self.H[dead, dead] = 1
W[:, dead] = 0
if self.qcfg.activation_weighted_mse:
activation_importance = torch.diag(self.H).clamp_min(0).to(device=W.device, dtype=W.dtype)
importance_mean = activation_importance.mean()
if torch.isfinite(importance_mean) and importance_mean > 0:
activation_importance = activation_importance / importance_mean
else:
activation_importance = None

# g_idx = []
scale = []
Expand All @@ -983,7 +989,14 @@ def quantize(
groups = []
for i in range(0, self.columns, self.qcfg.group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i: (i + self.qcfg.group_size)], weight=True)
group_importance = None
if activation_importance is not None:
group_importance = activation_importance[i: (i + self.qcfg.group_size)]
quantizer.find_params(
W[:, i: (i + self.qcfg.group_size)],
weight=True,
importance=group_importance,
)

scale.append(quantizer.scale)
zero.append(quantizer.zero)
Expand All @@ -994,6 +1007,8 @@ def quantize(
try:
W = W[:, perm]
self.H = self.H[perm][:, perm]
if activation_importance is not None:
activation_importance = activation_importance[perm]
except RuntimeError as exc:
if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
raise
Expand All @@ -1004,7 +1019,8 @@ def quantize(
perm = perm.to(device=cpu_device)
W = W.to(device=cpu_device)[:, perm]
self.H = self.H.to(device=cpu_device)[perm][:, perm]
self.quantizer.find_params(W, weight=True)
if activation_importance is not None:
activation_importance = activation_importance.to(device=cpu_device)[perm]
invperm = torch.argsort(perm)

elif self.qcfg.act_group_aware and use_hessian:
Expand All @@ -1022,6 +1038,8 @@ def quantize(
try:
W = W[:, final_perm]
self.H = self.H[final_perm][:, final_perm]
if activation_importance is not None:
activation_importance = activation_importance[final_perm]
except RuntimeError as exc:
if self.H.device.type != "cuda" or "out of memory" not in str(exc).lower():
raise
Expand All @@ -1032,7 +1050,10 @@ def quantize(
final_perm = final_perm.to(device=cpu_device)
W = W.to(device=cpu_device)[:, final_perm]
self.H = self.H.to(device=cpu_device)[final_perm][:, final_perm]
self.quantizer.find_params(W, weight=True)
if activation_importance is not None:
activation_importance = activation_importance.to(device=cpu_device)[final_perm]

self.quantizer.find_params(W, weight=True, importance=activation_importance)

if use_hessian:
try:
Expand All @@ -1048,7 +1069,9 @@ def quantize(
cpu_device = torch.device("cpu")
self.H = self.H.to(device=cpu_device)
W = W.to(device=cpu_device)
self.quantizer.find_params(W, weight=True)
if activation_importance is not None:
activation_importance = activation_importance.to(device=cpu_device)
self.quantizer.find_params(W, weight=True, importance=activation_importance)
Hinv, damp = self.hessian_inverse(self.H)
else:
Hinv, damp = None, 0.0
Expand All @@ -1073,7 +1096,14 @@ def quantize(
for group_start in group_start_cols:
group_end = min(group_start + self.qcfg.group_size, self.columns)
if group_start < group_end:
self.quantizer.find_params(W[:, group_start:group_end], weight=True)
group_importance = None
if activation_importance is not None:
group_importance = activation_importance[group_start:group_end]
self.quantizer.find_params(
W[:, group_start:group_end],
weight=True,
importance=group_importance,
)
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
Expand Down Expand Up @@ -1182,7 +1212,16 @@ def quantize(
if self.qcfg.group_size != -1:
if not self.qcfg.static_groups:
if (i1 + i) % self.qcfg.group_size == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + self.qcfg.group_size)], weight=True)
group_start = i1 + i
group_end = group_start + self.qcfg.group_size
group_importance = None
if activation_importance is not None:
group_importance = activation_importance[group_start:group_end]
self.quantizer.find_params(
W[:, group_start:group_end],
weight=True,
importance=group_importance,
)

if ((i1 + i) // self.qcfg.group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
Expand Down
35 changes: 34 additions & 1 deletion gptqmodel/quantization/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def configure(
if trits:
self.maxq = torch.tensor(-1)

def find_params(self, x, weight=False):
def find_params(self, x, weight=False, importance: torch.Tensor = None):
dev = x.device
self.maxq = self.maxq.to(dev)

Expand Down Expand Up @@ -116,6 +116,37 @@ def find_params(self, x, weight=False):

mse = float(getattr(self.qcfg, "mse", 0.0) or 0.0)
if mse > 0.0:
importance_weights = None
if getattr(self.qcfg, "activation_weighted_mse", False) and importance is not None:
importance_weights = torch.nan_to_num(
importance.to(device=dev, dtype=x.dtype),
nan=0.0,
posinf=0.0,
neginf=0.0,
).clamp_min_(0)
if importance_weights.ndim == 1:
importance_weights = importance_weights.unsqueeze(0)
if importance_weights.shape[-1] != x.shape[1]:
raise ValueError(
"Quantizer.find_params(): importance parameter shape mismatch. "
f"Expected columns: {x.shape[1]}, got: {importance_weights.shape[-1]}."
)
if importance_weights.shape[0] == 1 and x.shape[0] != 1:
importance_weights = importance_weights.expand(x.shape[0], -1)
elif importance_weights.shape[0] != x.shape[0]:
raise ValueError(
"Quantizer.find_params(): importance parameter row count mismatch. "
f"Expected 1 or {x.shape[0]} rows, got: {importance_weights.shape[0]}."
)

importance_mean = importance_weights.mean(dim=1, keepdim=True)
valid = torch.isfinite(importance_mean) & (importance_mean > 0)
if torch.any(valid):
normalized_weights = importance_weights / importance_mean.clamp_min(1e-8)
importance_weights = torch.where(valid, normalized_weights, torch.ones_like(importance_weights))
else:
importance_weights = None

best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
Expand All @@ -131,6 +162,8 @@ def find_params(self, x, weight=False):
q -= x
q.abs_()
q.pow_(mse)
if importance_weights is not None:
q.mul_(importance_weights)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
Expand Down
56 changes: 56 additions & 0 deletions tests/qcfg/test_gptq_pro.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import torch

from gptqmodel.quantization import QuantizeConfig
from gptqmodel.quantization.quantizer import Quantizer


def _calculate_weighted_squared_error(
    quantizer: Quantizer,
    weights: torch.Tensor,
    importance: torch.Tensor,
) -> torch.Tensor:
    """Return the importance-weighted sum of squared quantization residuals.

    `weights` is passed through `quantizer.quantize`, the element-wise
    squared difference to the input is scaled by `importance` (viewed as a
    single row so it lines up with the last dimension), and everything is
    summed into one scalar tensor.
    """
    residual = quantizer.quantize(weights) - weights
    column_weight = importance.view(1, -1)
    weighted_sq = residual.pow(2) * column_weight
    return weighted_sq.sum()


def test_gptq_pro_enables_activation_weighted_mse():
    """The default `gptq_pro()` profile sets the expected quality flags."""
    profile = QuantizeConfig.gptq_pro()

    # Activation-weighted MSE and group-aware reordering are on ...
    assert profile.activation_weighted_mse is True
    assert profile.act_group_aware is True
    # ... while desc_act stays off.
    assert profile.desc_act is False


def test_activation_weighted_mse_prioritizes_salient_columns():
    """Importance-weighted MSE search picks a different, better-weighted scale.

    A baseline quantizer (no importance) and a weighted quantizer (with
    `activation_weighted_mse=True` and a per-column importance vector) are
    fitted on the same single-row weight tensor. The weighted one must end
    up with a different scale and a strictly lower importance-weighted
    squared error.
    """
    weights = torch.tensor([[0.1, 0.45, 0.8, 1.2]], dtype=torch.float32)
    importance = torch.tensor([1.0, 1.0, 8.0, 8.0], dtype=torch.float32)

    # Settings shared by both quantizers; only the weighting flag differs.
    shared_cfg = dict(bits=4, sym=False, mse=2.0, act_group_aware=False, desc_act=False)

    baseline = Quantizer(QuantizeConfig(**shared_cfg))
    baseline.configure(perchannel=True)
    baseline.find_params(weights, weight=True)

    weighted = Quantizer(QuantizeConfig(activation_weighted_mse=True, **shared_cfg))
    weighted.configure(perchannel=True)
    weighted.find_params(weights, weight=True, importance=importance)

    assert not torch.allclose(weighted.scale, baseline.scale)

    weighted_error = _calculate_weighted_squared_error(weighted, weights, importance)
    baseline_error = _calculate_weighted_squared_error(baseline, weights, importance)
    assert weighted_error < baseline_error
3 changes: 3 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def test_quantize_config_meta_only_fields_serialization(self):
offload_to_disk_path="./offload-test",
pack_impl="gpu",
mse=0.125,
activation_weighted_mse=True,
mock_quantization=True,
hessian=HessianConfig(
chunk_size=256,
Expand All @@ -109,6 +110,7 @@ def test_quantize_config_meta_only_fields_serialization(self):
"offload_to_disk_path",
"pack_impl",
"mse",
"activation_weighted_mse",
"mock_quantization",
"act_group_aware",
"hessian",
Expand All @@ -127,6 +129,7 @@ def test_quantize_config_meta_only_fields_serialization(self):
self.assertEqual(meta["offload_to_disk_path"], cfg.offload_to_disk_path)
self.assertEqual(meta["pack_impl"], cfg.pack_impl)
self.assertEqual(meta["mse"], cfg.mse)
self.assertEqual(meta["activation_weighted_mse"], cfg.activation_weighted_mse)
self.assertEqual(meta["mock_quantization"], cfg.mock_quantization)
self.assertEqual(meta["act_group_aware"], cfg.act_group_aware)
self.assertEqual(meta["hessian"]["chunk_size"], cfg.hessian.chunk_size)
Expand Down