Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 1 addition & 24 deletions tico/quantization/algorithm/fpi_gptq/fpi_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,30 +32,7 @@
)

from tico.quantization.algorithm.gptq.quant import quantize, Quantizer


def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
    """Fixed-point iteration of the GPTQ update applied to the whole matrix.

    Each step quantizes the current weights, scales the quantization error by
    the reciprocal of ``Hinv``'s diagonal, and propagates it through the
    strict upper triangle of ``Hinv``, restarting each step from the original
    ``W``.

    Returns a tuple ``(quantized weights, final real-valued weights)``.
    """

    cur_weights = W.clone()
    # Reciprocal of the inverse-Hessian diagonal: per-column error scaling.
    mults = torch.pow(torch.diag(Hinv), -1)
    # Strict upper triangle: error is propagated only to later columns.
    Hinv_U = torch.triu(Hinv, diagonal=1)

    init_weights = W.clone()
    for _ in range(max_num_of_iters):
        cur_Q = quantize(cur_weights, scale, zero, maxq)

        d_W = torch.mul((cur_weights - cur_Q), mults)
        cur_weights = init_weights - torch.matmul(d_W, Hinv_U)
        # NOTE(review): eager del/None intended to release GPU memory between
        # iterations; both names are rebound on the next pass anyway.
        del d_W, cur_Q
        d_W = cur_Q = None

    del init_weights
    init_weights = None

    # Final quantization of the converged weights.
    cur_Q = quantize(cur_weights, scale, zero, maxq)

    return cur_Q, cur_weights

from tico.quantization.algorithm.fpi_gptq.util import quantize, iterate_GPTQ

class FPI_GPTQ:
def __init__(self, layer):
Expand Down
50 changes: 50 additions & 0 deletions tico/quantization/algorithm/fpi_gptq/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
# Apache License 2.0.

# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py

import torch

def quantize(x, scale, zero, maxq):
    """Fake-quantize ``x`` with the given scale/zero-point.

    For ``maxq >= 0`` this is the usual affine round-and-clamp to the
    integer levels ``[0, maxq]``, returned in the real domain.  A negative
    ``maxq`` selects a ternary code: values above ``scale / 2`` map to
    ``scale``, values below ``zero / 2`` map to ``zero``, the rest to 0.
    """
    if maxq < 0:
        # Ternary branch: threshold against half of each codeword.
        hi = (x > scale / 2).float()
        lo = (x < zero / 2).float()
        return hi * scale + lo * zero
    levels = torch.round(x / scale) + zero
    levels = torch.clamp(levels, 0, maxq)
    return scale * (levels - zero)


def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
    """Fixed-point iteration of the GPTQ update for the whole matrix ``W``.

    Each step quantizes the current weights, scales the quantization error
    by the reciprocal of ``Hinv``'s diagonal, and propagates it through the
    strict upper triangle of ``Hinv``, restarting every step from ``W``.

    Args:
        scale, zero, maxq: quantization parameters forwarded to ``quantize``.
        W: weight matrix (rows are output channels); read-only here.
        Hinv: inverse-Hessian factor; diagonal scales the error, strict
            upper triangle redistributes it to later columns.
        max_num_of_iters: number of fixed-point steps to run.

    Returns:
        Tuple ``(quantized weights, final real-valued weights)``.
    """
    # Loop invariants, hoisted once.
    inv_diag = torch.reciprocal(torch.diag(Hinv))
    Hinv_upper = torch.triu(Hinv, diagonal=1)

    # W is only read below, so one defensive copy suffices (a second clone
    # would only double peak memory); per-iteration del/None of locals is
    # unnecessary because every name is rebound on the next statement.
    cur_weights = W.clone()
    for _ in range(max_num_of_iters):
        cur_Q = quantize(cur_weights, scale, zero, maxq)
        d_W = (cur_weights - cur_Q) * inv_diag
        cur_weights = W - torch.matmul(d_W, Hinv_upper)

    # Final quantization of the converged weights.
    cur_Q = quantize(cur_weights, scale, zero, maxq)
    return cur_Q, cur_weights
4 changes: 3 additions & 1 deletion tico/quantization/algorithm/gptq/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,9 @@ def fasterquant(
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H


self.quantizer.update(W, Hinv, perm)

assert isinstance(Hinv, torch.Tensor)
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
Expand Down
74 changes: 73 additions & 1 deletion tico/quantization/algorithm/gptq/quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
import torch.nn as nn

from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ

def quantize(x, scale, zero, maxq):
if maxq < 0:
Expand Down Expand Up @@ -101,7 +102,7 @@ def find_params(self, x, weight=False):
else:
self.zero = torch.round(-xmin / self.scale)

if self.mse is not None:
if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq":
best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
Expand Down Expand Up @@ -151,6 +152,77 @@ def find_params(self, x, weight=False):
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)

def update(self, x, Hinv, perm):
    """Refine ``self.scale``/``self.zero`` with a GPTQ-in-the-loop search.

    Mirrors the range/shrink grid search of ``find_params``, but scores
    each candidate (scale, zero) by running ``iterate_GPTQ`` so the error
    reflects the actual GPTQ rounding, weighted by the calibrated
    sensitivity map.  Active only when ``self.mse == "smse_for_gptq"``;
    otherwise a no-op.

    Args:
        x: weight tensor; flattened per-channel (or fully) below.
        Hinv: inverse-Hessian factor handed to ``iterate_GPTQ``.
        perm: optional column permutation (GPTQ activation ordering)
            applied to the sensitivity map, or None.
    """
    # ``None != "smse_for_gptq"`` already holds, so no separate None check
    # is needed.
    if self.mse != "smse_for_gptq":
        return

    shape = x.shape
    if self.perchannel:
        x = x.flatten(1)
    else:
        x = x.flatten().unsqueeze(0)

    dev = x.device
    tmp = torch.zeros(x.shape[0], device=dev)
    # Per-row ranges, forced to include zero.
    xmin = torch.minimum(x.min(1)[0], tmp)
    xmax = torch.maximum(x.max(1)[0], tmp)

    if self.sym:
        # Symmetric range: widen to the larger absolute bound.
        xmax = torch.maximum(torch.abs(xmin), xmax)
        tmp = xmin < 0
        if torch.any(tmp):
            xmin[tmp] = -xmax[tmp]
    # Guard all-zero rows against a degenerate (0, 0) range.
    tmp = (xmin == 0) & (xmax == 0)
    xmin[tmp] = -1
    xmax[tmp] = +1

    if self.maxq < 0:
        # Ternary mode stores the raw bounds directly.
        self.scale = xmax
        self.zero = xmin
    else:
        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)  # type: ignore[arg-type]
        else:
            self.zero = torch.round(-xmin / self.scale)

    # The sensitivity map is mandatory in this mode; check once up front
    # instead of re-asserting on every grid step.
    assert self.sensitivity is not None
    sensitivity = self.sensitivity.to(Hinv.dtype).to(dev)
    if perm is not None:
        # Follow GPTQ's column reordering.
        sensitivity = sensitivity[:, perm.to(dev)]

    num_of_iters = 15  # fixed-point steps per candidate: runtime vs. fit
    best = torch.full([x.shape[0]], float("inf"), device=dev)
    for i in range(int(self.maxshrink * self.grid)):
        p = 1 - i / self.grid
        xmin1 = p * xmin
        xmax1 = p * xmax
        scale1 = (xmax1 - xmin1) / self.maxq
        zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
        q, _ = iterate_GPTQ(
            scale1.unsqueeze(1),
            zero1.unsqueeze(1),
            self.maxq,
            x,
            Hinv,
            max_num_of_iters=num_of_iters,
        )
        # Sensitivity-weighted squared error, reduced per row.
        err = torch.sum(((q - x) ** 2) * sensitivity.to(q.device), 1)
        tmp = err < best
        if torch.any(tmp):
            best[tmp] = err[tmp]
            self.scale[tmp] = scale1[tmp]
            self.zero[tmp] = zero1[tmp]

    # Restore broadcastable shapes (rows, 1, 1, ...) against the input.
    shape = [-1] + [1] * (len(shape) - 1)
    self.scale = self.scale.reshape(shape)
    self.zero = self.zero.reshape(shape)

def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
Expand Down
11 changes: 10 additions & 1 deletion tico/quantization/algorithm/gptq/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,11 @@ def compute_sensitivity_info(self):
if self.show_progress is True:
print("Calibrating sensitivity")
for inputs, targets in tqdm.tqdm(data_loader, disable=not self.show_progress):
model.zero_grad()
model.zero_grad(set_to_none=True)
if model.device.type != "cpu":
torch.cuda.empty_cache()
torch.cuda.synchronize()

if isinstance(inputs, torch.Tensor):
inp_ids = inputs.squeeze(0) # remove redundant batch dimension
logits = model(inp_ids.to(model.device)).logits
Expand Down Expand Up @@ -215,6 +219,11 @@ def compute_sensitivity_info(self):
for name in modules_to_process:
sensitivity[name] /= len(data_loader)

model.zero_grad(set_to_none=True)
if model.device.type != "cpu":
torch.cuda.synchronize()
torch.cuda.empty_cache()

model = model.to(dtype)

return sensitivity
101 changes: 98 additions & 3 deletions tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

from typing import Any, List, Optional, Tuple, Union

import numpy as np

import torch
import tqdm
from datasets import load_dataset
Expand Down Expand Up @@ -98,6 +100,51 @@ def inject_gptq_qparams(
obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)


def evaluate_ppl_of_model_on_dataset(
    model, dataset, device: Union[str, torch.device] = "cuda"
):
    """Evaluate next-token perplexity of ``model`` over ``dataset``.

    Args:
        model: callable mapping an input-id tensor to either an object with
            a ``.logits`` attribute or an indexable sequence of logits;
            moved to ``device`` if it exposes ``.device``/``.to``.
        dataset: iterable of input-id tensors of shape (1, seq_len).
        device: target device as a string or ``torch.device``.  Both forms
            are accepted: the previous ``str``-only annotation clashed with
            the ``device.type`` access below, which raised AttributeError
            for the default ``"cuda"`` string whenever the model exposed
            ``.device``.

    Returns:
        Perplexity over all batches whose logits are finite (batches that
        produce NaN/inf logits are skipped rather than poisoning the mean).
    """
    device = torch.device(device) if isinstance(device, str) else device
    if hasattr(model, "device") and model.device.type != device.type:
        if hasattr(model, "to"):
            model.to(device)

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")  # per-token NLL
    nlls = []

    # Progress reporting is best-effort so the function also works in
    # environments without tqdm.
    try:
        import tqdm as _tqdm

        batches = _tqdm.tqdm(dataset)
    except ImportError:
        batches = dataset

    with torch.no_grad():
        for batch in batches:
            if not isinstance(batch, torch.Tensor):
                raise RuntimeError("Unknown input in ppl_eval_on_dataset")
            batch = batch.to(device)
            output = model(batch)

            if hasattr(output, "logits"):
                lm_logits = output.logits
            elif len(output) > 1:
                # Presumably a (logits, ...) tuple — TODO confirm against
                # the callers that return bare sequences.
                lm_logits = torch.tensor(output[0])
            else:
                lm_logits = torch.tensor(output)

            if torch.isfinite(lm_logits).all():
                # Shift for next-token prediction; batch is guaranteed to
                # be a Tensor here (non-Tensors raised above), so no
                # separate tuple branch is needed for the labels.
                shift_logits = lm_logits[:, :-1, :].contiguous()
                shift_labels = batch[:, 1:].contiguous()
                loss = loss_fct(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                nlls.append(loss)

            # Drop per-batch tensors before the next forward pass to keep
            # peak GPU memory down.
            del lm_logits, output, batch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    return np.exp(torch.cat(nlls, dim=-1).mean().item())

# -------------------------------------------------------------------------
# Save model in circle format
# -------------------------------------------------------------------------
Expand Down Expand Up @@ -215,6 +262,24 @@ def evaluate(q_m, tokenizer, dataset_test, args):
print("Quantized RESULTS ARE:")
print(make_table(results))

def get_sensitivities_info_path(model, save_folder, dataset, seed, n_samples):
    """Build the cache path for precomputed sensitivity tensors.

    Args:
        model: model whose ``config.name_or_path`` identifies the cache.
        save_folder: directory for the cache file; ``None`` means the
            current working directory.
        dataset: calibration dataset name (e.g. "wikitext").
        seed: sampling seed used for calibration.
        n_samples: number of calibration samples.

    Returns:
        ``<save_folder>/sensitivities_for_<model>_<dataset>_<n_samples>_<seed>.pt``
    """
    model_name = model.config.name_or_path.replace("/", "_")
    if save_folder is None:
        save_folder = "."
    # Bug fix: the previous version normalized `save_folder` and then
    # hard-coded "." when assembling the path, silently ignoring the
    # argument.
    return (
        f"{save_folder}/sensitivities_for_"
        f"{model_name}_{dataset}_{n_samples}_{seed}.pt"
    )


def main():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -296,7 +361,7 @@ def main():
"--gptq_mse",
type=str,
default=None,
choices=["mse", "smse"],
choices=["mse", "smse", "smse_for_gptq"],
help="Whether and how to use mse in gptq (none/mse/smse/)",
)
parser.add_argument(
Expand Down Expand Up @@ -410,6 +475,13 @@ def main():
j = i + seqlen
inp = train_ids[:, i:j]
calib_inputs.append(inp.cpu())

train_ppl_fp32 = evaluate_ppl_of_model_on_dataset(
model, calib_inputs, device=device
)
print("\n┌── Wikitext-2 train perplexity ─────────────")
print(f"│ FP32 : {train_ppl_fp32:8.2f}")
print("└───────────────────────────────────────────")

# -------------------------------------------------------------------------
# Run GPTQ (weight-only) pass
Expand All @@ -418,13 +490,24 @@ def main():
print("Applying GPTQ …")

sens = None
if args.gptq_mse is not None and args.gptq_mse == "smse":
if args.gptq_mse is not None and (
args.gptq_mse == "smse" or args.gptq_mse == "smse_for_gptq"
):
if args.sensitivity_path is not None:
sens = torch.load(args.sensitivity_path)
else:
calibrator = SensitivityCalibrator(model, calib_inputs)
sens = calibrator.compute_sensitivity_info()

save_folder = args.save_circle_to_folder if args.save_circle_to_folder is not None else args.save_layers_to_folder
save_path = get_sensitivities_info_path(model, save_folder, "wikitext", args.seed, len(calib_inputs))
print(f"Saving calibrated_sensitivities to {save_path}")
torch.save(sens, save_path)

model = model.cpu()
model = model.to(args.device)
torch.cuda.empty_cache()
torch.cuda.synchronize()

gptq_config = GPTQConfig(
weight_bits=args.linear_weight_bits,
perchannel=True,
Expand All @@ -440,12 +523,24 @@ def main():
else:
q_m = model

q_m = q_m.cpu()
q_m = q_m.to(args.device)
torch.cuda.empty_cache()
torch.cuda.synchronize()

# -------------------------------------------------------------------------
# Wrap every layer with PTQWrapper
# -------------------------------------------------------------------------
if not args.no_PTQ:
q_m = quantize_using_PTQ(q_m, calib_inputs, args)

train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset(
q_m, calib_inputs, device=device
)
print("\n┌── Wikitext-2 train perplexity ─────────────")
print(f"│ int16 : {train_ppl_ioqdtype:8.2f}")
print("└───────────────────────────────────────────")

# after PTQ quantizer only fixed-length input sequences are valid
evaluate(q_m, tokenizer, dataset_test, args)

Expand Down
Loading