Skip to content

Commit 83220a4

Browse files
committed
[quantization] Introduce smse_for_gptq
This PR introduces the `smse_for_gptq` MSE mode — a sensitivity-weighted grid search over quantization parameters, evaluated through the GPTQ update itself — to improve quantized-model accuracy. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 3e4b06f commit 83220a4

6 files changed

Lines changed: 201 additions & 29 deletions

File tree

tico/quantization/algorithm/fpi_gptq/fpi_gptq.py

Lines changed: 1 addition & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -32,30 +32,7 @@
3232
)
3333

3434
from tico.quantization.algorithm.gptq.quant import quantize, Quantizer
35-
36-
37-
def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
38-
39-
cur_weights = W.clone()
40-
mults = torch.pow(torch.diag(Hinv), -1)
41-
Hinv_U = torch.triu(Hinv, diagonal=1)
42-
43-
init_weights = W.clone()
44-
for _ in range(max_num_of_iters):
45-
cur_Q = quantize(cur_weights, scale, zero, maxq)
46-
47-
d_W = torch.mul((cur_weights - cur_Q), mults)
48-
cur_weights = init_weights - torch.matmul(d_W, Hinv_U)
49-
del d_W, cur_Q
50-
d_W = cur_Q = None
51-
52-
del init_weights
53-
init_weights = None
54-
55-
cur_Q = quantize(cur_weights, scale, zero, maxq)
56-
57-
return cur_Q, cur_weights
58-
35+
from tico.quantization.algorithm.fpi_gptq.util import quantize, iterate_GPTQ
5936

6037
class FPI_GPTQ:
6138
def __init__(self, layer):
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright IST-DASLab. 2025. (commit: 2d65066). GitHub repository.
2+
# Retrieved from https://github.com/IST-DASLab/gptq. Licensed under the
3+
# Apache License 2.0.
4+
5+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
6+
#
7+
# Licensed under the Apache License, Version 2.0 (the "License");
8+
# you may not use this file except in compliance with the License.
9+
# You may obtain a copy of the License at
10+
#
11+
# http://www.apache.org/licenses/LICENSE-2.0
12+
#
13+
# Unless required by applicable law or agreed to in writing, software
14+
# distributed under the License is distributed on an "AS IS" BASIS,
15+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
# See the License for the specific language governing permissions and
17+
# limitations under the License.
18+
19+
# https://github.com/IST-DASLab/gptq/blob/2d65066/quant.py
20+
21+
import torch
22+
23+
def quantize(x, scale, zero, maxq):
    """Fake-quantize ``x`` with affine parameters ``scale``/``zero``.

    Args:
        x: Tensor of values to quantize (broadcastable with ``scale``/``zero``).
        scale: Quantization step size.
        zero: Integer zero-point (or, for ``maxq < 0``, the negative level).
        maxq: Maximum integer level; a negative value selects ternary mode.

    Returns:
        The dequantized tensor ``scale * (q - zero)`` (same shape as ``x``).
    """
    if maxq < 0:
        # Ternary mode (maxq < 0): scale/zero hold the two non-zero levels
        # directly rather than an affine scale and zero-point.
        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
    q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
    return scale * (q - zero)


def iterate_GPTQ(scale, zero, maxq, W, Hinv, max_num_of_iters=50):
    """Run a fixed-point iteration of the GPTQ weight-update rule.

    Repeatedly quantizes the current weights and propagates the per-column
    quantization error to later columns through the (strictly upper-triangular
    part of the) inverse-Hessian factor ``Hinv``, starting from ``W`` each
    round, until ``max_num_of_iters`` rounds have run.

    Args:
        scale: Quantization scale(s), broadcastable with ``W``.
        zero: Quantization zero-point(s), broadcastable with ``W``.
        maxq: Maximum integer level passed through to :func:`quantize`.
        W: Weight matrix (rows x columns) to quantize.
        Hinv: Upper-triangular inverse-Hessian factor (columns x columns).
        max_num_of_iters: Number of fixed-point rounds to perform.

    Returns:
        Tuple ``(cur_Q, cur_weights)`` — the final quantized weights and the
        error-compensated continuous weights they were quantized from.
    """
    cur_weights = W.clone()
    # Per-column multipliers: reciprocal of Hinv's diagonal.
    mults = torch.pow(torch.diag(Hinv), -1)
    # Strictly upper-triangular part: error only propagates to later columns.
    Hinv_U = torch.triu(Hinv, diagonal=1)

    init_weights = W.clone()
    for _ in range(max_num_of_iters):
        cur_Q = quantize(cur_weights, scale, zero, maxq)
        d_W = torch.mul(cur_weights - cur_Q, mults)
        cur_weights = init_weights - torch.matmul(d_W, Hinv_U)
        # Free the large intermediates eagerly to cap peak memory. (The
        # original also rebound the names to None right after `del`, which
        # resurrects the bindings and is a no-op — removed.)
        del d_W, cur_Q

    del init_weights

    cur_Q = quantize(cur_weights, scale, zero, maxq)
    return cur_Q, cur_weights

tico/quantization/algorithm/gptq/gptq.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,9 @@ def fasterquant(
361361
H = torch.cholesky_inverse(H)
362362
H = torch.linalg.cholesky(H, upper=True)
363363
Hinv = H
364-
364+
365+
self.quantizer.update(W, Hinv, perm)
366+
365367
assert isinstance(Hinv, torch.Tensor)
366368
for i1 in range(0, self.columns, blocksize):
367369
i2 = min(i1 + blocksize, self.columns)

tico/quantization/algorithm/gptq/quant.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
import torch.nn as nn
2323

24+
from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ
2425

2526
def quantize(x, scale, zero, maxq):
2627
if maxq < 0:
@@ -101,7 +102,7 @@ def find_params(self, x, weight=False):
101102
else:
102103
self.zero = torch.round(-xmin / self.scale)
103104

104-
if self.mse is not None:
105+
if self.mse is not None and self.mse != "smse_for_gptq" and self.mse != "mse_for_gptq":
105106
best = torch.full([x.shape[0]], float("inf"), device=dev)
106107
for i in range(int(self.maxshrink * self.grid)):
107108
p = 1 - i / self.grid
@@ -151,6 +152,77 @@ def find_params(self, x, weight=False):
151152
self.scale = self.scale.unsqueeze(0)
152153
self.zero = self.zero.unsqueeze(0)
153154

155+
def update(self, x, Hinv, perm):
    """Refine scale/zero with a GPTQ-aware, sensitivity-weighted grid search.

    Only active when ``self.mse == "smse_for_gptq"``; otherwise a no-op.
    Re-derives the min/max-based quantization parameters for ``x``, then
    shrinks the clipping range over a grid and keeps, per row, the
    parameters that minimize the sensitivity-weighted squared error of the
    GPTQ-iterated quantization (via ``iterate_GPTQ``).

    Args:
        x: Weight tensor to calibrate on; flattened per channel when
            ``self.perchannel`` is set, otherwise flattened globally.
        Hinv: Inverse-Hessian factor consumed by the GPTQ iteration.
        perm: Optional column permutation applied by GPTQ (act-order);
            used to align the sensitivity map with the permuted columns.
    """
    # A None self.mse also fails this comparison, so no separate None check
    # is needed (the original tested both).
    if self.mse != "smse_for_gptq":
        return

    shape = x.shape
    if self.perchannel:
        x = x.flatten(1)
    else:
        x = x.flatten().unsqueeze(0)

    dev = x.device
    tmp = torch.zeros(x.shape[0], device=dev)
    # Clamp the range so it always contains zero.
    xmin = torch.minimum(x.min(1)[0], tmp)
    xmax = torch.maximum(x.max(1)[0], tmp)

    if self.sym:
        xmax = torch.maximum(torch.abs(xmin), xmax)
        tmp = xmin < 0
        if torch.any(tmp):
            xmin[tmp] = -xmax[tmp]
    # Degenerate all-zero rows get a dummy [-1, 1] range to avoid a
    # divide-by-zero when computing the scale below.
    tmp = (xmin == 0) & (xmax == 0)
    xmin[tmp] = -1
    xmax[tmp] = +1
    if self.maxq < 0:
        # Ternary mode: scale/zero carry the two non-zero levels directly.
        self.scale = xmax
        self.zero = xmin
    else:
        self.scale = (xmax - xmin) / self.maxq
        if self.sym:
            self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)  # type: ignore[arg-type]
        else:
            self.zero = torch.round(-xmin / self.scale)

    sensitivity = None
    if self.sensitivity is not None:
        sensitivity = self.sensitivity.to(Hinv.dtype).to(dev)
        if perm is not None:
            # GPTQ may reorder columns (act-order); mirror that here so the
            # sensitivity map stays column-aligned with x.
            sensitivity = sensitivity[:, perm.to(dev)]
    # smse_for_gptq is meaningless without a sensitivity map; check once up
    # front instead of re-asserting on every grid iteration as before.
    assert sensitivity is not None

    num_of_iters = 15
    best = torch.full([x.shape[0]], float("inf"), device=dev)
    for i in range(int(self.maxshrink * self.grid)):
        # Shrink the clipping range by factor p and score the candidate.
        p = 1 - i / self.grid
        xmin1 = p * xmin
        xmax1 = p * xmax
        scale1 = (xmax1 - xmin1) / self.maxq
        zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
        q, _ = iterate_GPTQ(
            scale1.unsqueeze(1),
            zero1.unsqueeze(1),
            self.maxq,
            x,
            Hinv,
            max_num_of_iters=num_of_iters,
        )
        # Sensitivity-weighted squared reconstruction error, summed per row.
        # (sensitivity already lives on dev == q.device, so no extra .to();
        # the original's dead `err = err` statement is dropped.)
        err = torch.sum(((q - x) ** 2) * sensitivity, 1)
        tmp = err < best
        if torch.any(tmp):
            best[tmp] = err[tmp]
            self.scale[tmp] = scale1[tmp]
            self.zero[tmp] = zero1[tmp]

    # Restore broadcastable shapes matching the original weight layout.
    shape = [-1] + [1] * (len(shape) - 1)
    self.scale = self.scale.reshape(shape)
    self.zero = self.zero.reshape(shape)
154226
def quantize(self, x):
155227
if self.ready():
156228
return quantize(x, self.scale, self.zero, self.maxq)

tico/quantization/algorithm/gptq/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,11 @@ def compute_sensitivity_info(self):
163163
if self.show_progress is True:
164164
print("Calibrating sensitivity")
165165
for inputs, targets in tqdm.tqdm(data_loader, disable=not self.show_progress):
166-
model.zero_grad()
166+
model.zero_grad(set_to_none=True)
167+
if model.device.type != "cpu":
168+
torch.cuda.empty_cache()
169+
torch.cuda.synchronize()
170+
167171
if isinstance(inputs, torch.Tensor):
168172
inp_ids = inputs.squeeze(0) # remove redundant batch dimension
169173
logits = model(inp_ids.to(model.device)).logits
@@ -215,6 +219,11 @@ def compute_sensitivity_info(self):
215219
for name in modules_to_process:
216220
sensitivity[name] /= len(data_loader)
217221

222+
model.zero_grad(set_to_none=True)
223+
if model.device.type != "cpu":
224+
torch.cuda.synchronize()
225+
torch.cuda.empty_cache()
226+
218227
model = model.to(dtype)
219228

220229
return sensitivity

tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from lm_eval.utils import make_table
4141
from transformers import AutoModelForCausalLM, AutoTokenizer
4242

43+
import numpy as np
4344
import tico
4445

4546
from tico.quantization import convert, prepare
@@ -98,6 +99,51 @@ def inject_gptq_qparams(
9899
obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)
99100

100101

102+
def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
    """Return the perplexity of ``model`` over an iterable of token batches.

    Args:
        model: Causal LM (or callable) returning logits for a token tensor.
        dataset: Iterable of ``torch.Tensor`` input-id batches of shape
            ``(1, seqlen)``; any other element type raises ``RuntimeError``.
        device: Target device name, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        Perplexity ``exp(mean NLL)`` over all finite-logit batches.

    Raises:
        RuntimeError: If a dataset element is not a ``torch.Tensor``.
    """
    # BUG FIX: `device` is a string, so the original `device.type` attribute
    # access raised AttributeError; normalize through torch.device first.
    target = torch.device(device)
    if hasattr(model, "device") and model.device.type != target.type:
        if hasattr(model, "to"):
            model.to(target)
    nlls = []
    with torch.no_grad():
        for batch in tqdm.tqdm(dataset):
            if isinstance(batch, torch.Tensor):
                batch = batch.to(target)
                # batch is already on `target`; the original's second
                # `.to(device)` inside the call was redundant.
                output = model(batch)
            else:
                raise RuntimeError("Unknown input in ppl_eval_on_dataset")

            if hasattr(output, "logits"):
                lm_logits = output.logits
            elif len(output) > 1:
                lm_logits = torch.tensor(output[0])
            else:
                lm_logits = torch.tensor(output)

            # Skip batches with NaN/inf logits instead of poisoning the
            # running perplexity.
            if torch.isfinite(lm_logits).all():
                # Next-token prediction: logits shifted left, labels right.
                shift_logits = lm_logits[:, :-1, :].contiguous()
                # batch is guaranteed a Tensor here (non-Tensors raised above),
                # so the original's dead tuple branch is removed.
                shift_labels = batch[:, 1:].contiguous()
                loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                loss = loss_fct(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                nlls.append(loss)
                del shift_logits, shift_labels

            # Release per-batch tensors eagerly to keep device memory flat.
            del batch, lm_logits, output
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Perplexity = exp(mean negative log-likelihood over evaluated tokens).
    ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
    return ppl
101147
# -------------------------------------------------------------------------
102148
# Save model in circle format
103149
# -------------------------------------------------------------------------
@@ -296,7 +342,7 @@ def main():
296342
"--gptq_mse",
297343
type=str,
298344
default=None,
299-
choices=["mse", "smse"],
345+
choices=["mse", "smse", "smse_for_gptq"],
300346
help="Whether and how to use mse in gptq (none/mse/smse/)",
301347
)
302348
parser.add_argument(
@@ -410,6 +456,13 @@ def main():
410456
j = i + seqlen
411457
inp = train_ids[:, i:j]
412458
calib_inputs.append(inp.cpu())
459+
460+
train_ppl_fp32 = evaluate_ppl_of_model_on_dataset(
461+
model, calib_inputs, device=device
462+
)
463+
print("\n┌── Wikitext-2 train perplexity ─────────────")
464+
print(f"│ FP32 : {train_ppl_fp32:8.2f}")
465+
print("└───────────────────────────────────────────")
413466

414467
# -------------------------------------------------------------------------
415468
# Run GPTQ (weight-only) pass
@@ -418,7 +471,9 @@ def main():
418471
print("Applying GPTQ …")
419472

420473
sens = None
421-
if args.gptq_mse is not None and args.gptq_mse == "smse":
474+
if args.gptq_mse is not None and (
475+
args.gptq_mse == "smse" or args.gptq_mse == "smse_for_gptq"
476+
):
422477
if args.sensitivity_path is not None:
423478
sens = torch.load(args.sensitivity_path)
424479
else:
@@ -446,6 +501,13 @@ def main():
446501
if not args.no_PTQ:
447502
q_m = quantize_using_PTQ(q_m, calib_inputs, args)
448503

504+
train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset(
505+
q_m, calib_inputs, device=device
506+
)
507+
print("\n┌── Wikitext-2 train perplexity ─────────────")
508+
print(f"│ int16 : {train_ppl_ioqdtype:8.2f}")
509+
print("└───────────────────────────────────────────")
510+
449511
# after PTQ quantizer only fixed-length input sequences are valid
450512
evaluate(q_m, tokenizer, dataset_test, args)
451513

0 commit comments

Comments
 (0)