|
19 | 19 | import torch |
20 | 20 |
|
21 | 21 | from tico.quantization import convert, prepare |
| 22 | +from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer |
22 | 23 | from tico.quantization.algorithm.gptq.utils import SensitivityCalibrator |
23 | 24 | from tico.quantization.config.gptq import GPTQConfig |
24 | 25 | from tico.quantization.config.ptq import PTQConfig |
@@ -48,6 +49,23 @@ def get_zero_inputs(self): |
48 | 49 | return (torch.zeros(1, 2048),), {} |
49 | 50 |
|
50 | 51 |
|
class SmallLinear(torch.nn.Module):
    """Tiny fixture model: three stacked 16->16 linear layers in a ModuleList.

    The ModuleList children are registered as ``m.0``, ``m.1``, ``m.2`` so
    per-module quantization overrides can target them by name.
    """

    def __init__(self):
        super().__init__()
        # Build the three layers up front; names match append-based construction.
        self.m = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(3))

    def forward(self, x):
        # Apply the layers strictly in order (m.0 -> m.1 -> m.2).
        for layer in self.m:
            x = layer(x)
        return x

    def get_example_inputs(self):
        # Return (positional args, keyword args) for calibration passes.
        return (torch.randn(1, 16),), {}
| 68 | + |
51 | 69 | class NormConv2D(torch.nn.Module): |
52 | 70 | def __init__(self): |
53 | 71 | super().__init__() |
@@ -278,6 +296,87 @@ def get_example_inputs(self): |
278 | 296 |
|
279 | 297 |
|
280 | 298 | class GPTQTest(unittest.TestCase): |
| 299 | + def test_gptq_config_validate_weight_bits_overrides(self): |
| 300 | + conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 8}) |
| 301 | + conf.validate() |
| 302 | + |
| 303 | + def test_gptq_config_validate_rejects_non_positive_weight_bits_override(self): |
| 304 | + conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 0}) |
| 305 | + with self.assertRaises(ValueError): |
| 306 | + conf.validate() |
| 307 | + |
| 308 | + def test_resolve_weight_bits_priority(self): |
| 309 | + quantizer = GPTQQuantizer( |
| 310 | + GPTQConfig( |
| 311 | + weight_bits=4, |
| 312 | + weight_bits_overrides={ |
| 313 | + "proj": 5, |
| 314 | + "layer.proj": 6, |
| 315 | + "model.layers.0.layer.proj": 8, |
| 316 | + }, |
| 317 | + ) |
| 318 | + ) |
| 319 | + |
| 320 | + assert isinstance(quantizer.config, GPTQConfig) |
| 321 | + self.assertEqual( |
| 322 | + quantizer._resolve_weight_bits( |
| 323 | + quantizer.config, |
| 324 | + full_module_name="model.layers.0.layer.proj", |
| 325 | + local_module_name="layer.proj", |
| 326 | + ), |
| 327 | + 8, |
| 328 | + ) |
| 329 | + self.assertEqual( |
| 330 | + quantizer._resolve_weight_bits( |
| 331 | + quantizer.config, |
| 332 | + full_module_name="model.layers.1.layer.proj", |
| 333 | + local_module_name="layer.proj", |
| 334 | + ), |
| 335 | + 6, |
| 336 | + ) |
| 337 | + self.assertEqual( |
| 338 | + quantizer._resolve_weight_bits( |
| 339 | + quantizer.config, |
| 340 | + full_module_name="model.layers.2.other.proj", |
| 341 | + local_module_name="other.proj", |
| 342 | + ), |
| 343 | + 5, |
| 344 | + ) |
| 345 | + self.assertEqual( |
| 346 | + quantizer._resolve_weight_bits( |
| 347 | + quantizer.config, |
| 348 | + full_module_name="model.layers.2.other.up_proj", |
| 349 | + local_module_name="other.up_proj", |
| 350 | + ), |
| 351 | + 4, |
| 352 | + ) |
| 353 | + |
| 354 | + @torch.inference_mode() |
| 355 | + def test_weight_bits_overrides_are_applied_per_module(self): |
| 356 | + q_m = SmallLinear() |
| 357 | + q_m.eval() |
| 358 | + ori_m = q_m |
| 359 | + |
| 360 | + q_m = prepare( |
| 361 | + q_m, |
| 362 | + GPTQConfig( |
| 363 | + show_progress=False, |
| 364 | + weight_bits=4, |
| 365 | + weight_bits_overrides={ |
| 366 | + "m.1": 8, |
| 367 | + }, |
| 368 | + ), |
| 369 | + ) |
| 370 | + for _ in range(8): |
| 371 | + args, kwargs = ori_m.get_example_inputs() |
| 372 | + q_m(*args, **kwargs) |
| 373 | + convert(q_m, inplace=True) |
| 374 | + |
| 375 | + self.assertTrue(hasattr(q_m, "quantizers")) |
| 376 | + self.assertEqual(q_m.quantizers["model.layers.0.m.0"].maxq.item(), 15) |
| 377 | + self.assertEqual(q_m.quantizers["model.layers.0.m.1"].maxq.item(), 255) |
| 378 | + self.assertEqual(q_m.quantizers["model.layers.0.m.2"].maxq.item(), 15) |
| 379 | + |
281 | 380 | @unittest.skipIf( |
282 | 381 | not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set" |
283 | 382 | ) |
|
0 commit comments