diff --git a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py index cdd99ef7..641a59ae 100644 --- a/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py +++ b/tico/quantization/algorithm/fpi_gptq/fpi_gptq.py @@ -32,30 +32,7 @@ ) from tico.quantization.algorithm.gptq.quant import quantize, Quantizer - - -def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50): - - cur_weights = W.clone() - mults = torch.pow(torch.diag(Hinv), -1) - Hinv_U = torch.triu(Hinv, diagonal=1) - - init_weights = W.clone() - for _ in range(max_num_of_iters): - cur_Q = quantize(cur_weights, scale, zero, maxq) - - d_W = torch.mul((cur_weights - cur_Q), mults) - cur_weights = init_weights - torch.matmul(d_W, Hinv_U) - del d_W, cur_Q - d_W = cur_Q = None - - del init_weights - init_weights = None - - cur_Q = quantize(cur_weights, scale, zero, maxq) - - return cur_Q, cur_weights - +from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ class FPI_GPTQ: def __init__(self, layer): diff --git a/tico/quantization/algorithm/fpi_gptq/util.py b/tico/quantization/algorithm/fpi_gptq/util.py new file mode 100644 index 00000000..9d73b052 --- /dev/null +++ b/tico/quantization/algorithm/fpi_gptq/util.py @@ -0,0 +1,50 @@ +# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository. +# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the +# Apache License 2.0. + +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +# https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py + +import torch + +def quantize(x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + +def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50): + + cur_weights = W.clone() + mults = torch.pow(torch.diag(Hinv), -1) + Hinv_U = torch.triu(Hinv, diagonal=1) + + init_weights = W.clone() + for _ in range(max_num_of_iters): + cur_Q = quantize(cur_weights, scale, zero, maxq) + + d_W = torch.mul((cur_weights - cur_Q), mults) + cur_weights = init_weights - torch.matmul(d_W, Hinv_U) + del d_W, cur_Q + d_W = cur_Q = None + + del init_weights + init_weights = None + + cur_Q = quantize(cur_weights, scale, zero, maxq) + + return cur_Q, cur_weights diff --git a/tico/quantization/algorithm/gptq/gptq.py b/tico/quantization/algorithm/gptq/gptq.py index 8fb14ca8..5693bfa8 100644 --- a/tico/quantization/algorithm/gptq/gptq.py +++ b/tico/quantization/algorithm/gptq/gptq.py @@ -361,7 +361,9 @@ def fasterquant( H = torch.cholesky_inverse(H) H = torch.linalg.cholesky(H, upper=True) Hinv = H - + + self.quantizer.update(W, Hinv, perm) + assert isinstance(Hinv, torch.Tensor) for i1 in range(0, self.columns, blocksize): i2 = min(i1 + blocksize, self.columns) diff --git a/tico/quantization/algorithm/gptq/quant.py b/tico/quantization/algorithm/gptq/quant.py index 98e7731d..4acefeca 100644 --- a/tico/quantization/algorithm/gptq/quant.py +++ b/tico/quantization/algorithm/gptq/quant.py @@ -21,6 +21,7 @@ import torch import torch.nn as nn +from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ def quantize(x, scale, zero, maxq): if maxq < 0: @@ -101,7 +102,7 @@ def find_params(self, x, weight=False): else: self.zero = torch.round(-xmin / self.scale) - if self.mse is not 
None: + if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq": best = torch.full([x.shape[0]], float("inf"), device=dev) for i in range(int(self.maxshrink * self.grid)): p = 1 - i / self.grid @@ -151,6 +152,77 @@ def find_params(self, x, weight=False): self.scale = self.scale.unsqueeze(0) self.zero = self.zero.unsqueeze(0) + def update(self, x, Hinv, perm): + if self.mse is None or self.mse != "smse_for_gptq": + return + + shape = x.shape + if self.perchannel: + x = x.flatten(1) + else: + x = x.flatten().unsqueeze(0) + + dev = x.device + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) # type: ignore[arg-type] + else: + self.zero = torch.round(-xmin / self.scale) + + sensitivity = None + if self.sensitivity is not None: + sensitivity = self.sensitivity.to(Hinv.dtype).to(dev) + if perm is not None: + sensitivity = sensitivity[:, perm.to(dev)] + + num_of_iters = 15 + best = torch.full([x.shape[0]], float("inf"), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q, _ = iterate_GPTQ( + scale1.unsqueeze(1), + zero1.unsqueeze(1), + self.maxq, + x, + Hinv, + max_num_of_iters=num_of_iters, + ) + assert sensitivity is not None + assert self.mse == "smse_for_gptq" + err = ((q - x) ** 2) * sensitivity.to(q.device) + + err = err + err = torch.sum(err, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] 
+ self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + def quantize(self, x): if self.ready(): return quantize(x, self.scale, self.zero, self.maxq) diff --git a/tico/quantization/algorithm/gptq/utils.py b/tico/quantization/algorithm/gptq/utils.py index 3dcc3be9..a0aaa279 100644 --- a/tico/quantization/algorithm/gptq/utils.py +++ b/tico/quantization/algorithm/gptq/utils.py @@ -163,7 +163,11 @@ def compute_sensitivity_info(self): if self.show_progress is True: print("Calibrating sensitivity") for inputs, targets in tqdm.tqdm(data_loader, disable=not self.show_progress): - model.zero_grad() + model.zero_grad(set_to_none=True) + if model.device.type != "cpu": + torch.cuda.empty_cache() + torch.cuda.synchronize() + if isinstance(inputs, torch.Tensor): inp_ids = inputs.squeeze(0) # remove redundant batch dimension logits = model(inp_ids.to(model.device)).logits @@ -215,6 +219,11 @@ def compute_sensitivity_info(self): for name in modules_to_process: sensitivity[name] /= len(data_loader) + model.zero_grad(set_to_none=True) + if model.device.type != "cpu": + torch.cuda.synchronize() + torch.cuda.empty_cache() + model = model.to(dtype) return sensitivity diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 289b9014..1c60dc54 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -34,6 +34,8 @@ from typing import Any, List, Optional, Tuple, Union +import numpy as np + import torch import tqdm from datasets import load_dataset @@ -98,6 +100,51 @@ def inject_gptq_qparams( obs.load_qparams(quantizer.scale, quantizer.zero, lock=True) +def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"): + if hasattr(model, "device") and 
model.device.type != torch.device(device).type: + if hasattr(model, "to"): + model.to(device) + nlls = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + batch = batch.to(device) + output = model( + batch, + ) + else: + raise RuntimeError("Unknown input in ppl_eval_on_dataset") + + if hasattr(output, "logits"): + lm_logits = output.logits + elif len(output) > 1: + lm_logits = torch.tensor(output[0]) + else: + lm_logits = torch.tensor(output) + + if torch.isfinite(lm_logits).all(): + shift_logits = lm_logits[:, :-1, :].contiguous() + if isinstance(batch, torch.Tensor): + shift_labels = batch[:, 1:].contiguous() + else: + assert isinstance(batch, tuple) + shift_labels = batch[0][:, 1:].contiguous() + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + nlls.append(loss) + del shift_logits, shift_labels + shift_logits = shift_labels = None # type: ignore[assignment] + + del batch, lm_logits, output + lm_logits = output = batch = None # noqa: F841 + torch.cuda.empty_cache() + + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + # ------------------------------------------------------------------------- # Save model in circle format # ------------------------------------------------------------------------- @@ -215,6 +262,24 @@ def evaluate(q_m, tokenizer, dataset_test, args): print("Quantized RESULTS ARE:") print(make_table(results)) +def get_sensitivities_info_path(model, save_folder, dataset, seed, n_samples): + model_name = model.config.name_or_path.replace("/", "_") + if save_folder is None: + save_folder = "." + cache_path = ( + save_folder
+ + "/sensitivities_for_" + + model_name + + "_" + + dataset + + "_" + + str(n_samples) + + "_" + + str(seed) + + ".pt" + ) + return cache_path + def main(): parser = argparse.ArgumentParser( @@ -296,7 +361,7 @@ def main(): "--gptq_mse", type=str, default=None, - choices=["mse", "smse"], + choices=["mse", "smse", "smse_for_gptq"], help="Whether and how to use mse in gptq (none/mse/smse/)", ) parser.add_argument( @@ -410,6 +475,13 @@ def main(): j = i + seqlen inp = train_ids[:, i:j] calib_inputs.append(inp.cpu()) + + train_ppl_fp32 = evaluate_ppl_of_model_on_dataset( + model, calib_inputs, device=device + ) + print("\n┌── Wikitext-2 train perplexity ─────────────") + print(f"│ FP32 : {train_ppl_fp32:8.2f}") + print("└───────────────────────────────────────────") # ------------------------------------------------------------------------- # Run GPTQ (weight-only) pass @@ -418,13 +490,24 @@ def main(): print("Applying GPTQ …") sens = None - if args.gptq_mse is not None and args.gptq_mse == "smse": + if args.gptq_mse is not None and ( + args.gptq_mse == "smse" or args.gptq_mse == "smse_for_gptq" + ): if args.sensitivity_path is not None: sens = torch.load(args.sensitivity_path) else: calibrator = SensitivityCalibrator(model, calib_inputs) sens = calibrator.compute_sensitivity_info() - + save_folder = args.save_circle_to_folder if args.save_circle_to_folder is not None else args.save_layers_to_folder + save_path = get_sensitivities_info_path(model, save_folder, "wikitext", args.seed, len(calib_inputs)) + print(f"Saving calibrated_sensitivities to {save_path}") + torch.save(sens, save_path) + + model = model.cpu() + model = model.to(args.device) + torch.cuda.empty_cache() + torch.cuda.synchronize() + gptq_config = GPTQConfig( weight_bits=args.linear_weight_bits, perchannel=True, @@ -440,12 +523,24 @@ def main(): else: q_m = model + q_m = q_m.cpu() + q_m = q_m.to(args.device) + torch.cuda.empty_cache() + torch.cuda.synchronize() + # 
------------------------------------------------------------------------- # Wrap every layer with PTQWrapper # ------------------------------------------------------------------------- if not args.no_PTQ: q_m = quantize_using_PTQ(q_m, calib_inputs, args) + train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset( + q_m, calib_inputs, device=device + ) + print("\n┌── Wikitext-2 train perplexity ─────────────") + print(f"│ int16 : {train_ppl_ioqdtype:8.2f}") + print("└───────────────────────────────────────────") + # after PTQ quantizer only fixed-length input sequences are valid evaluate(q_m, tokenizer, dataset_test, args)