diff --git a/test/quantization/algorithm/test_gptq.py b/test/quantization/algorithm/test_gptq.py
index 7ac94083..7805b9ae 100644
--- a/test/quantization/algorithm/test_gptq.py
+++ b/test/quantization/algorithm/test_gptq.py
@@ -19,6 +19,7 @@
 import torch
 
 from tico.quantization import convert, prepare
+from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
 from tico.quantization.algorithm.gptq.utils import SensitivityCalibrator
 from tico.quantization.config.gptq import GPTQConfig
 from tico.quantization.config.ptq import PTQConfig
@@ -48,6 +49,23 @@ def get_zero_inputs(self):
         return (torch.zeros(1, 2048),), {}
 
 
+class SmallLinear(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.m = torch.nn.ModuleList()
+        for _ in range(3):
+            self.m.append(torch.nn.Linear(16, 16))
+
+    def forward(self, x):
+        z = self.m[0](x)
+        z = self.m[1](z)
+        z = self.m[2](z)
+        return z
+
+    def get_example_inputs(self):
+        return (torch.randn(1, 16),), {}
+
+
 class NormConv2D(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -278,6 +296,105 @@ def get_example_inputs(self):
 
 
 class GPTQTest(unittest.TestCase):
+    def test_gptq_config_validate_weight_bits_overrides(self):
+        conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 8})
+        conf.validate()
+
+    def test_gptq_config_validate_rejects_non_positive_weight_bits_override(self):
+        conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 0})
+        with self.assertRaises(ValueError):
+            conf.validate()
+
+    def test_resolve_weight_bits_priority(self):
+        quantizer = GPTQQuantizer(
+            GPTQConfig(
+                weight_bits=4,
+                weight_bits_overrides={
+                    "proj": 5,
+                    "layer.proj": 6,
+                    "model.layers.0.layer.proj": 8,
+                },
+            )
+        )
+
+        assert isinstance(quantizer.config, GPTQConfig)
+        self.assertEqual(
+            quantizer._resolve_weight_bits(
+                quantizer.config,
+                full_module_name="model.layers.0.layer.proj",
+                local_module_name="layer.proj",
+            ),
+            8,
+        )
+        self.assertEqual(
+            quantizer._resolve_weight_bits(
+                quantizer.config,
+                full_module_name="model.layers.1.layer.proj",
+                local_module_name="layer.proj",
+            ),
+            6,
+        )
+        self.assertEqual(
+            quantizer._resolve_weight_bits(
+                quantizer.config,
+                full_module_name="model.layers.2.other.proj",
+                local_module_name="other.proj",
+            ),
+            5,
+        )
+        self.assertEqual(
+            quantizer._resolve_weight_bits(
+                quantizer.config,
+                full_module_name="model.layers.2.other.up_proj",
+                local_module_name="other.up_proj",
+            ),
+            4,
+        )
+
+    def test_resolve_weight_bits_prefers_longest_suffix(self):
+        quantizer = GPTQQuantizer(
+            GPTQConfig(
+                weight_bits=4,
+                weight_bits_overrides={"self_attn.o_proj": 8, "o_proj": 6},
+            )
+        )
+
+        assert isinstance(quantizer.config, GPTQConfig)
+        self.assertEqual(
+            quantizer._resolve_weight_bits(
+                quantizer.config,
+                full_module_name="model.layers.0.block.self_attn.o_proj",
+                local_module_name="block.self_attn.o_proj",
+            ),
+            8,
+        )
+
+    @torch.inference_mode()
+    def test_weight_bits_overrides_are_applied_per_module(self):
+        q_m = SmallLinear()
+        q_m.eval()
+        ori_m = q_m
+
+        q_m = prepare(
+            q_m,
+            GPTQConfig(
+                show_progress=False,
+                weight_bits=4,
+                weight_bits_overrides={
+                    "m.1": 8,
+                },
+            ),
+        )
+        for _ in range(8):
+            args, kwargs = ori_m.get_example_inputs()
+            q_m(*args, **kwargs)
+        convert(q_m, inplace=True)
+
+        self.assertTrue(hasattr(q_m, "quantizers"))
+        self.assertEqual(q_m.quantizers["model.layers.0.m.0"].maxq.item(), 15)
+        self.assertEqual(q_m.quantizers["model.layers.0.m.1"].maxq.item(), 255)
+        self.assertEqual(q_m.quantizers["model.layers.0.m.2"].maxq.item(), 15)
+
     @unittest.skipIf(
         not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
     )
diff --git a/tico/quantization/algorithm/gptq/quantizer.py b/tico/quantization/algorithm/gptq/quantizer.py
index d9bb1b9c..3c1ee80d 100644
--- a/tico/quantization/algorithm/gptq/quantizer.py
+++ b/tico/quantization/algorithm/gptq/quantizer.py
@@ -60,6 +60,40 @@ def __init__(self, config: GPTQConfig):
         self._orig_layer_forward: Optional[Callable[..., Any]] = None
         self._first_layer_ref: Optional[torch.nn.Module] = None
 
+    def _resolve_weight_bits(
+        self,
+        gptq_conf: GPTQConfig,
+        *,
+        full_module_name: str,
+        local_module_name: str,
+    ) -> int:
+        """
+        Resolve the effective bit-width for a quantized submodule.
+
+        Overrides in `gptq_conf.weight_bits_overrides` are looked up in order:
+          1) exact match on `full_module_name`
+          2) exact match on `local_module_name`
+          3) the longest key that is a dot-delimited suffix of `full_module_name`
+
+        Falls back to `gptq_conf.weight_bits` when no override matches.
+        """
+        overrides = gptq_conf.weight_bits_overrides
+        for exact_name in (full_module_name, local_module_name):
+            if exact_name in overrides:
+                return overrides[exact_name]
+
+        # Prefer the most specific (longest) suffix so that the result does not
+        # depend on the insertion order of the overrides dict.
+        best_suffix = max(
+            (key for key in overrides if full_module_name.endswith(f".{key}")),
+            key=len,
+            default=None,
+        )
+        if best_suffix is not None:
+            return overrides[best_suffix]
+
+        return gptq_conf.weight_bits
+
     @torch.no_grad()
     def prepare(
         self,
@@ -220,18 +254,22 @@ def convert(self, model):
                 gptq: Dict[str, GPTQ] = {}
                 for name in subset:
                     gptq[name] = GPTQ(subset[name])
+                    full_module_name = module_name[subset[name]]
+                    weight_bits = self._resolve_weight_bits(
+                        gptq_conf,
+                        full_module_name=full_module_name,
+                        local_module_name=name,
+                    )
                     if (
                         gptq_conf.sensitivity is not None
                         and isinstance(gptq_conf.sensitivity, dict)
-                        and module_name[subset[name]] in gptq_conf.sensitivity
+                        and full_module_name in gptq_conf.sensitivity
                     ):
-                        cur_sensitivity = gptq_conf.sensitivity[
-                            module_name[subset[name]]
-                        ]
+                        cur_sensitivity = gptq_conf.sensitivity[full_module_name]
                     else:
                         cur_sensitivity = None
                     gptq[name].quantizer.configure(
-                        bits=gptq_conf.weight_bits,
+                        bits=weight_bits,
                         perchannel=gptq_conf.perchannel,
                         sym=gptq_conf.symmetric,
                         mse=gptq_conf.mse,
diff --git a/tico/quantization/config/gptq.py b/tico/quantization/config/gptq.py
index ccf2e7e6..4ac2efcb 100644
--- a/tico/quantization/config/gptq.py
+++ b/tico/quantization/config/gptq.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 
 import torch
 
@@ -23,6 +23,21 @@
 class GPTQConfig(BaseConfig):
     """
     Configuration for GPTQ weight quantization.
+
+    Attributes
+    ----------
+    weight_bits : int
+        Default bit-width applied to quantized weights.
+    weight_bits_overrides : dict[str, int]
+        Optional per-module bit-width overrides.
+
+        Supported keys are matched in the following order:
+        1) Full module name, for example `model.layers.0.self_attn.o_proj`
+        2) Layer-local module name, for example `self_attn.o_proj`
+        3) Full-name suffix, for example `self_attn.o_proj` or `down_proj`
+
+        This makes it possible to keep a default bit-width for most modules
+        while selectively increasing precision for specific projections.
     """
 
     # general
@@ -31,6 +46,7 @@ class GPTQConfig(BaseConfig):
 
     # quantizer.configure params (weight quantization spec)
     weight_bits: int = 8
+    weight_bits_overrides: dict[str, int] = field(default_factory=dict)
     perchannel: bool = True
     symmetric: bool = False
     mse: str | None = None
@@ -49,6 +65,11 @@ def name(self) -> str:
     def validate(self) -> None:
         if self.weight_bits <= 0:
             raise ValueError(f"weight_bits must be positive. got {self.weight_bits}")
+        for module_name, bits in self.weight_bits_overrides.items():
+            if bits <= 0:
+                raise ValueError(
+                    f"weight_bits_overrides[{module_name!r}] must be positive. got {bits}"
+                )
         if self.groupsize != -1 and self.groupsize <= 0:
             raise ValueError(f"groupsize must be -1 or positive. got {self.groupsize}")
         if not (0.0 < self.percdamp <= 1.0):