Skip to content

Commit 20737bc

Browse files
committed
[quantization] Support override in GPTQ
This commit supports the override option in GPTQ.

TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 783c865 commit 20737bc

3 files changed

Lines changed: 155 additions & 6 deletions

File tree

test/quantization/algorithm/test_gptq.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import torch
2020

2121
from tico.quantization import convert, prepare
22+
from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
2223
from tico.quantization.algorithm.gptq.utils import SensitivityCalibrator
2324
from tico.quantization.config.gptq import GPTQConfig
2425
from tico.quantization.config.ptq import PTQConfig
@@ -48,6 +49,23 @@ def get_zero_inputs(self):
4849
return (torch.zeros(1, 2048),), {}
4950

5051

52+
class SmallLinear(torch.nn.Module):
53+
def __init__(self):
54+
super().__init__()
55+
self.m = torch.nn.ModuleList()
56+
for _ in range(3):
57+
self.m.append(torch.nn.Linear(16, 16))
58+
59+
def forward(self, x):
60+
z = self.m[0](x)
61+
z = self.m[1](z)
62+
z = self.m[2](z)
63+
return z
64+
65+
def get_example_inputs(self):
66+
return (torch.randn(1, 16),), {}
67+
68+
5169
class NormConv2D(torch.nn.Module):
5270
def __init__(self):
5371
super().__init__()
@@ -278,6 +296,87 @@ def get_example_inputs(self):
278296

279297

280298
class GPTQTest(unittest.TestCase):
299+
def test_gptq_config_validate_weight_bits_overrides(self):
300+
conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 8})
301+
conf.validate()
302+
303+
def test_gptq_config_validate_rejects_non_positive_weight_bits_override(self):
304+
conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 0})
305+
with self.assertRaises(ValueError):
306+
conf.validate()
307+
308+
def test_resolve_weight_bits_priority(self):
309+
quantizer = GPTQQuantizer(
310+
GPTQConfig(
311+
weight_bits=4,
312+
weight_bits_overrides={
313+
"proj": 5,
314+
"layer.proj": 6,
315+
"model.layers.0.layer.proj": 8,
316+
},
317+
)
318+
)
319+
320+
assert isinstance(quantizer.config, GPTQConfig)
321+
self.assertEqual(
322+
quantizer._resolve_weight_bits(
323+
quantizer.config,
324+
full_module_name="model.layers.0.layer.proj",
325+
local_module_name="layer.proj",
326+
),
327+
8,
328+
)
329+
self.assertEqual(
330+
quantizer._resolve_weight_bits(
331+
quantizer.config,
332+
full_module_name="model.layers.1.layer.proj",
333+
local_module_name="layer.proj",
334+
),
335+
6,
336+
)
337+
self.assertEqual(
338+
quantizer._resolve_weight_bits(
339+
quantizer.config,
340+
full_module_name="model.layers.2.other.proj",
341+
local_module_name="other.proj",
342+
),
343+
5,
344+
)
345+
self.assertEqual(
346+
quantizer._resolve_weight_bits(
347+
quantizer.config,
348+
full_module_name="model.layers.2.other.up_proj",
349+
local_module_name="other.up_proj",
350+
),
351+
4,
352+
)
353+
354+
@torch.inference_mode()
355+
def test_weight_bits_overrides_are_applied_per_module(self):
356+
q_m = SmallLinear()
357+
q_m.eval()
358+
ori_m = q_m
359+
360+
q_m = prepare(
361+
q_m,
362+
GPTQConfig(
363+
show_progress=False,
364+
weight_bits=4,
365+
weight_bits_overrides={
366+
"m.1": 8,
367+
},
368+
),
369+
)
370+
for _ in range(8):
371+
args, kwargs = ori_m.get_example_inputs()
372+
q_m(*args, **kwargs)
373+
convert(q_m, inplace=True)
374+
375+
self.assertTrue(hasattr(q_m, "quantizers"))
376+
self.assertEqual(q_m.quantizers["model.layers.0.m.0"].maxq.item(), 15)
377+
self.assertEqual(q_m.quantizers["model.layers.0.m.1"].maxq.item(), 255)
378+
self.assertEqual(q_m.quantizers["model.layers.0.m.2"].maxq.item(), 15)
379+
281380
@unittest.skipIf(
282381
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
283382
)

tico/quantization/algorithm/gptq/quantizer.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,31 @@ def __init__(self, config: GPTQConfig):
6060
self._orig_layer_forward: Optional[Callable[..., Any]] = None
6161
self._first_layer_ref: Optional[torch.nn.Module] = None
6262

63+
def _resolve_weight_bits(
64+
self,
65+
gptq_conf: GPTQConfig,
66+
*,
67+
full_module_name: str,
68+
local_module_name: str,
69+
) -> int:
70+
"""Resolve the effective bit-width for a quantized submodule."""
71+
if full_module_name in gptq_conf.weight_bits_overrides:
72+
return gptq_conf.weight_bits_overrides[full_module_name]
73+
74+
if local_module_name in gptq_conf.weight_bits_overrides:
75+
return gptq_conf.weight_bits_overrides[local_module_name]
76+
77+
suffix_matches = [
78+
bits
79+
for pattern, bits in gptq_conf.weight_bits_overrides.items()
80+
if full_module_name.endswith(f".{pattern}")
81+
]
82+
83+
if suffix_matches:
84+
return suffix_matches[-1]
85+
86+
return gptq_conf.weight_bits
87+
6388
@torch.no_grad()
6489
def prepare(
6590
self,
@@ -220,18 +245,22 @@ def convert(self, model):
220245
gptq: Dict[str, GPTQ] = {}
221246
for name in subset:
222247
gptq[name] = GPTQ(subset[name])
248+
full_module_name = module_name[subset[name]]
249+
weight_bits = self._resolve_weight_bits(
250+
gptq_conf,
251+
full_module_name=full_module_name,
252+
local_module_name=name,
253+
)
223254
if (
224255
gptq_conf.sensitivity is not None
225256
and isinstance(gptq_conf.sensitivity, dict)
226-
and module_name[subset[name]] in gptq_conf.sensitivity
257+
and full_module_name in gptq_conf.sensitivity
227258
):
228-
cur_sensitivity = gptq_conf.sensitivity[
229-
module_name[subset[name]]
230-
]
259+
cur_sensitivity = gptq_conf.sensitivity[full_module_name]
231260
else:
232261
cur_sensitivity = None
233262
gptq[name].quantizer.configure(
234-
bits=gptq_conf.weight_bits,
263+
bits=weight_bits,
235264
perchannel=gptq_conf.perchannel,
236265
sym=gptq_conf.symmetric,
237266
mse=gptq_conf.mse,

tico/quantization/config/gptq.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from dataclasses import dataclass
15+
from dataclasses import dataclass, field
1616

1717
import torch
1818

@@ -23,6 +23,21 @@
2323
class GPTQConfig(BaseConfig):
2424
"""
2525
Configuration for GPTQ weight quantization.
26+
27+
Attributes
28+
----------
29+
weight_bits : int
30+
Default bit-width applied to quantized weights.
31+
weight_bits_overrides : dict[str, int]
32+
Optional per-module bit-width overrides.
33+
34+
Supported keys are matched in the following order:
35+
1) Full module name, for example `model.layers.0.self_attn.o_proj`
36+
2) Layer-local module name, for example `self_attn.o_proj`
37+
3) Full-name suffix, for example `self_attn.o_proj` or `down_proj`
38+
39+
This makes it possible to keep a default bit-width for most modules
40+
while selectively increasing precision for specific projections.
2641
"""
2742

2843
# general
@@ -31,6 +46,7 @@ class GPTQConfig(BaseConfig):
3146

3247
# quantizer.configure params (weight quantization spec)
3348
weight_bits: int = 8
49+
weight_bits_overrides: dict[str, int] = field(default_factory=dict)
3450
perchannel: bool = True
3551
symmetric: bool = False
3652
mse: str | None = None
@@ -49,6 +65,11 @@ def name(self) -> str:
4965
def validate(self) -> None:
5066
if self.weight_bits <= 0:
5167
raise ValueError(f"weight_bits must be positive. got {self.weight_bits}")
68+
for module_name, bits in self.weight_bits_overrides.items():
69+
if bits <= 0:
70+
raise ValueError(
71+
f"weight_bits_overrides[{module_name!r}] must be positive. got {bits}"
72+
)
5273
if self.groupsize != -1 and self.groupsize <= 0:
5374
raise ValueError(f"groupsize must be -1 or positive. got {self.groupsize}")
5475
if not (0.0 < self.percdamp <= 1.0):

0 commit comments

Comments (0)