Skip to content

Commit 1822c45

Browse files
committed
[quantization] Introduce smse_for_gptq
This PR introduces `smse_for_gptq`, a sensitivity-weighted MSE grid search for GPTQ quantization parameters, to improve quantized-model accuracy. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 3e4b06f commit 1822c45

6 files changed

Lines changed: 235 additions & 30 deletions

File tree

tico/quantization/algorithm/fpi_gptq/fpi_gptq.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,30 +32,7 @@
3232
)
3333

3434
from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
35-
36-
37-
def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
38-
39-
cur_weights = W.clone()
40-
mults = torch.pow(torch.diag(Hinv), -1)
41-
Hinv_U = torch.triu(Hinv, diagonal=1)
42-
43-
init_weights = W.clone()
44-
for _ in range(max_num_of_iters):
45-
cur_Q = quantize(cur_weights, scale, zero, maxq)
46-
47-
d_W = torch.mul((cur_weights - cur_Q), mults)
48-
cur_weights = init_weights - torch.matmul(d_W, Hinv_U)
49-
del d_W, cur_Q
50-
d_W = cur_Q = None
51-
52-
del init_weights
53-
init_weights = None
54-
55-
cur_Q = quantize(cur_weights, scale, zero, maxq)
56-
57-
return cur_Q, cur_weights
58-
35+
from tico.quantization.algorithm.fpi_gptq.util import quantize, iterate_GPTQ
5936

6037
class FPI_GPTQ:
6138
def __init__(self, layer):
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
2+
# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
3+
# Apache License 2.0.
4+
5+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
19+
# https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py
20+
21+
import torch
22+
23+
def quantize(x, scale, zero, maxq):
    """Fake-quantize ``x`` with the given scale / zero-point.

    For ``maxq < 0`` the ternary (trellis) scheme from the upstream GPTQ
    reference is used; otherwise uniform affine quantization clamped to
    ``[0, maxq]`` followed by dequantization.

    Args:
        x: tensor to quantize.
        scale: per-row (or scalar) quantization step.
        zero: per-row (or scalar) zero-point.
        maxq: number of quantization levels minus one; negative selects
            the ternary path.

    Returns:
        The dequantized (fake-quantized) tensor, same shape as ``x``.
    """
    if maxq < 0:
        # Ternary: snap to {zero, 0, scale} by thresholding at half-step.
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)


def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
    """Fixed-point-iteration form of the GPTQ weight update.

    Repeatedly quantizes the current weights and propagates the rounding
    error through the strictly-upper triangle of ``Hinv`` (scaled by the
    reciprocal of its diagonal), iterating toward a GPTQ-consistent
    solution instead of the usual column-sequential sweep.

    Args:
        scale, zero, maxq: quantization parameters (see ``quantize``).
        W: weight matrix; never modified in place.
        Hinv: inverse-Hessian Cholesky factor driving error propagation.
        max_num_of_iters: number of fixed-point iterations to run.

    Returns:
        Tuple ``(Q, W_adj)`` — the quantized weights and the final
        error-compensated float weights.
    """
    inv_diag = torch.reciprocal(torch.diag(Hinv))
    Hinv_upper = torch.triu(Hinv, diagonal=1)

    # One defensive copy is enough: `W` itself serves as the fixed
    # reference point, while `cur_weights` is rebound (never mutated)
    # each iteration.  The original cloned twice and interleaved
    # `del`/`= None` statements that had no effect on the result.
    cur_weights = W.clone()
    for _ in range(max_num_of_iters):
        cur_Q = quantize(cur_weights, scale, zero, maxq)
        d_W = (cur_weights - cur_Q) * inv_diag
        cur_weights = W - torch.matmul(d_W, Hinv_upper)

    cur_Q = quantize(cur_weights, scale, zero, maxq)
    return cur_Q, cur_weights

tico/quantization/algorithm/gptq/gptq.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,9 @@ def fasterquant(
361361
H = torch.cholesky_inverse(H)
362362
H = torch.linalg.cholesky(H, upper=True)
363363
Hinv = H
364-
364+
365+
self.quantizer.update(W, Hinv, perm)
366+
365367
assert isinstance(Hinv, torch.Tensor)
366368
for i1 in range(0, self.columns, blocksize):
367369
i2 = min(i1 + blocksize, self.columns)

tico/quantization/algorithm/gptq/quant.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
import torch.nn as nn
2323

24+
from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ
2425

2526
def quantize(x, scale, zero, maxq):
2627
if maxq < 0:
@@ -101,7 +102,7 @@ def find_params(self, x, weight=False):
101102
else:
102103
self.zero = torch.round(-xmin / self.scale)
103104

104-
if self.mse is not None:
105+
if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq":
105106
best = torch.full([x.shape[0]], float("inf"), device=dev)
106107
for i in range(int(self.maxshrink * self.grid)):
107108
p = 1 - i / self.grid
@@ -151,6 +152,77 @@ def find_params(self, x, weight=False):
151152
self.scale = self.scale.unsqueeze(0)
152153
self.zero = self.zero.unsqueeze(0)
153154

155+
def update(self, x, Hinv, perm):
    """Refine scale/zero via a sensitivity-weighted grid search ("smse_for_gptq").

    Re-derives the per-row min/max quantization ranges for ``x``, then
    shrinks the range over a grid; for each candidate range the fixed-point
    GPTQ iteration is run and, per row, the candidate with the lowest
    sensitivity-weighted squared error is kept. No-op unless
    ``self.mse == "smse_for_gptq"``.

    Args:
        x: weight tensor to calibrate against.
        Hinv: inverse-Hessian factor forwarded to ``iterate_GPTQ``.
        perm: optional column permutation (GPTQ activation ordering)
            applied to the sensitivity map, or None.
    """
    # `None != "smse_for_gptq"` is already True, so the original's extra
    # `self.mse is None or` test was redundant — one comparison suffices.
    if self.mse != "smse_for_gptq":
        return

    shape = x.shape
    if self.perchannel:
        x = x.flatten(1)
    else:
        x = x.flatten().unsqueeze(0)

    dev = x.device
    tmp = torch.zeros(x.shape[0], device=dev)
    xmin = torch.minimum(x.min(1)[0], tmp)
    xmax = torch.maximum(x.max(1)[0], tmp)

    if self.sym:
        xmax = torch.maximum(torch.abs(xmin), xmax)
        tmp = xmin < 0
        if torch.any(tmp):
            xmin[tmp] = -xmax[tmp]
    # Degenerate all-zero rows get a symmetric [-1, 1] range.
    tmp = (xmin == 0) & (xmax == 0)
    xmin[tmp] = -1
    xmax[tmp] = +1
    if self.maxq < 0:
        self.scale = xmax
        self.zero = xmin
    else:
        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)  # type: ignore[arg-type]
        else:
            self.zero = torch.round(-xmin / self.scale)

    sensitivity = None
    if self.sensitivity is not None:
        sensitivity = self.sensitivity.to(Hinv.dtype).to(dev)
        if perm is not None:
            sensitivity = sensitivity[:, perm.to(dev)]
    # Hoisted out of the loop: smse_for_gptq requires a sensitivity map.
    # (The original also re-asserted `self.mse == "smse_for_gptq"` every
    # iteration, which the guard above already guarantees.)
    assert sensitivity is not None

    num_of_iters = 15
    best = torch.full([x.shape[0]], float("inf"), device=dev)
    for i in range(int(self.maxshrink * self.grid)):
        p = 1 - i / self.grid
        xmin1 = p * xmin
        xmax1 = p * xmax
        scale1 = (xmax1 - xmin1) / self.maxq
        zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
        q, _ = iterate_GPTQ(
            scale1.unsqueeze(1),
            zero1.unsqueeze(1),
            self.maxq,
            x,
            Hinv,
            max_num_of_iters=num_of_iters,
        )
        # Sensitivity-weighted per-row squared error (the original's dead
        # `err = err` statement is removed).
        err = torch.sum(((q - x) ** 2) * sensitivity.to(q.device), 1)
        tmp = err < best
        if torch.any(tmp):
            best[tmp] = err[tmp]
            self.scale[tmp] = scale1[tmp]
            self.zero[tmp] = zero1[tmp]

    # Restore broadcastable shapes matching the original weight rank.
    shape = [-1] + [1] * (len(shape) - 1)
    self.scale = self.scale.reshape(shape)
    self.zero = self.zero.reshape(shape)
225+
154226
def quantize(self, x):
155227
if self.ready():
156228
return quantize(x, self.scale, self.zero, self.maxq)

tico/quantization/algorithm/gptq/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,11 @@ def compute_sensitivity_info(self):
163163
if self.show_progress is True:
164164
print("Calibrating sensitivity")
165165
for inputs, targets in tqdm.tqdm(data_loader, disable=not self.show_progress):
166-
model.zero_grad()
166+
model.zero_grad(set_to_none=True)
167+
if model.device.type != "cpu":
168+
torch.cuda.empty_cache()
169+
torch.cuda.synchronize()
170+
167171
if isinstance(inputs, torch.Tensor):
168172
inp_ids = inputs.squeeze(0) # remove redundant batch dimension
169173
logits = model(inp_ids.to(model.device)).logits
@@ -215,6 +219,11 @@ def compute_sensitivity_info(self):
215219
for name in modules_to_process:
216220
sensitivity[name] /= len(data_loader)
217221

222+
model.zero_grad(set_to_none=True)
223+
if model.device.type != "cpu":
224+
torch.cuda.synchronize()
225+
torch.cuda.empty_cache()
226+
218227
model = model.to(dtype)
219228

220229
return sensitivity

tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434

3535
from typing import Any, List, Optional, Tuple, Union
3636

37+
import numpy as np
38+
3739
import torch
3840
import tqdm
3941
from datasets import load_dataset
@@ -98,6 +100,51 @@ def inject_gptq_qparams(
98100
obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
99101

100102

103+
def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
    """Compute perplexity of ``model`` over ``dataset``.

    ``dataset`` is an iterable of token-id tensors of shape ``(1, seqlen)``;
    next-token cross-entropy is accumulated over all batches and
    exponentiated. Batches whose logits contain non-finite values are
    skipped.

    Args:
        model: a causal LM (HF-style, output has ``.logits``) or a callable
            returning logits.
        dataset: iterable of input-id tensors.
        device: target device, as a string or ``torch.device``.

    Returns:
        The perplexity as a Python float.
    """
    # Bug fix: the original compared `device.type`, which crashes for the
    # default *string* argument "cuda"; normalize to torch.device so both
    # strings and device objects are accepted.
    device = torch.device(device)
    if hasattr(model, "device") and model.device.type != device.type:
        if hasattr(model, "to"):
            model.to(device)

    # Hoisted out of the loop — no need to rebuild the criterion per batch.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    nlls = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dataset):
            if not isinstance(batch, torch.Tensor):
                raise RuntimeError("Unknown input in ppl_eval_on_dataset")
            # Single transfer (the original moved `batch` twice).
            batch = batch.to(device)
            output = model(batch)

            if hasattr(output, "logits"):
                lm_logits = output.logits
            elif len(output) > 1:
                lm_logits = torch.tensor(output[0])
            else:
                lm_logits = torch.tensor(output)

            if torch.isfinite(lm_logits).all():
                shift_logits = lm_logits[:, :-1, :].contiguous()
                # `batch` is always a Tensor here (non-tensors raised above),
                # so the original's unreachable tuple branch is dropped.
                shift_labels = batch[:, 1:].contiguous()
                loss = loss_fct(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                nlls.append(loss)
                del shift_logits, shift_labels

            # Release per-batch tensors before the next forward pass to keep
            # peak GPU memory down; empty_cache is a no-op on CPU-only runs.
            del batch, lm_logits, output
            torch.cuda.empty_cache()

    if not nlls:
        raise RuntimeError(
            "No batch produced finite logits; cannot compute perplexity"
        )
    return np.exp(torch.cat(nlls, dim=-1).mean().item())
147+
101148
# -------------------------------------------------------------------------
102149
# Save model in circle format
103150
# -------------------------------------------------------------------------
@@ -215,6 +262,24 @@ def evaluate(q_m, tokenizer, dataset_test, args):
215262
print("Quantized RESULTS ARE:")
216263
print(make_table(results))
217264

265+
def get_sensitivities_info_path(model, save_folder, dataset, seed, n_samples):
    """Build the cache-file path for precomputed sensitivity maps.

    The filename encodes model name, dataset, sample count and seed so a
    cached file is only reused for an identical calibration setup.

    Args:
        model: model whose ``config.name_or_path`` identifies the checkpoint.
        save_folder: directory for the cache file; ``None`` means the
            current directory.
        dataset: calibration dataset name.
        seed: calibration RNG seed.
        n_samples: number of calibration samples.

    Returns:
        The cache path as a string, e.g.
        ``<save_folder>/sensitivities_for_<model>_<dataset>_<n>_<seed>.pt``.
    """
    model_name = model.config.name_or_path.replace("/", "_")
    if save_folder is None:
        save_folder = "."
    # Bug fix: the original normalized `save_folder` but then hardcoded "."
    # into the path, silently ignoring the caller's folder choice.
    return (
        f"{save_folder}/sensitivities_for_{model_name}_{dataset}"
        f"_{n_samples}_{seed}.pt"
    )
282+
218283

219284
def main():
220285
parser = argparse.ArgumentParser(
@@ -296,7 +361,7 @@ def main():
296361
"--gptq_mse",
297362
type=str,
298363
default=None,
299-
choices=["mse", "smse"],
364+
choices=["mse", "smse", "smse_for_gptq"],
300365
help="Whether and how to use mse in gptq (none/mse/smse/)",
301366
)
302367
parser.add_argument(
@@ -410,6 +475,13 @@ def main():
410475
j = i + seqlen
411476
inp = train_ids[:, i:j]
412477
calib_inputs.append(inp.cpu())
478+
479+
train_ppl_fp32 = evaluate_ppl_of_model_on_dataset(
480+
model, calib_inputs, device=device
481+
)
482+
print("\n┌── Wikitext-2 train perplexity ─────────────")
483+
print(f"│ FP32 : {train_ppl_fp32:8.2f}")
484+
print("└───────────────────────────────────────────")
413485

414486
# -------------------------------------------------------------------------
415487
# Run GPTQ (weight-only) pass
@@ -418,13 +490,24 @@ def main():
418490
print("Applying GPTQ …")
419491

420492
sens = None
421-
if args.gptq_mse is not None and args.gptq_mse == "smse":
493+
if args.gptq_mse is not None and (
494+
args.gptq_mse == "smse" or args.gptq_mse == "smse_for_gptq"
495+
):
422496
if args.sensitivity_path is not None:
423497
sens = torch.load(args.sensitivity_path)
424498
else:
425499
calibrator = SensitivityCalibrator(model, calib_inputs)
426500
sens = calibrator.compute_sensitivity_info()
427-
501+
save_folder = args.save_circle_to_folder if args.save_circle_to_folder is not None else args.save_layers_to_folder
502+
save_path = get_sensitivities_info_path(model, save_folder, "wikitext", args.seed, len(calib_inputs))
503+
print(f"Saving calibrated_sensitivities to {save_path}")
504+
torch.save(sens, save_path)
505+
506+
model = model.cpu()
507+
model = model.to(args.device)
508+
torch.cuda.empty_cache()
509+
torch.cuda.synchronize()
510+
428511
gptq_config = GPTQConfig(
429512
weight_bits=args.linear_weight_bits,
430513
perchannel=True,
@@ -440,12 +523,24 @@ def main():
440523
else:
441524
q_m = model
442525

526+
q_m = q_m.cpu()
527+
q_m = q_m.to(args.device)
528+
torch.cuda.empty_cache()
529+
torch.cuda.synchronize()
530+
443531
# -------------------------------------------------------------------------
444532
# Wrap every layer with PTQWrapper
445533
# -------------------------------------------------------------------------
446534
if not args.no_PTQ:
447535
q_m = quantize_using_PTQ(q_m, calib_inputs, args)
448536

537+
train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset(
538+
q_m, calib_inputs, device=device
539+
)
540+
print("\n┌── Wikitext-2 train perplexity ─────────────")
541+
print(f"│ int16 : {train_ppl_ioqdtype:8.2f}")
542+
print("└───────────────────────────────────────────")
543+
449544
# after PTQ quantizer only fixed-length input sequences are valid
450545
evaluate(q_m, tokenizer, dataset_test, args)
451546

0 commit comments

Comments
 (0)