Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions test/quantization/algorithm/test_gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch

from tico.quantization import convert, prepare
from tico.quantization.algorithm.gptq.quantizer import GPTQQuantizer
from tico.quantization.algorithm.gptq.utils import SensitivityCalibrator
from tico.quantization.config.gptq import GPTQConfig
from tico.quantization.config.ptq import PTQConfig
Expand Down Expand Up @@ -48,6 +49,23 @@ def get_zero_inputs(self):
return (torch.zeros(1, 2048),), {}


class SmallLinear(torch.nn.Module):
    """Tiny fixture model: three stacked 16x16 linear layers (no activations)."""

    def __init__(self):
        super().__init__()
        # Three identically-shaped layers so per-module overrides can target
        # one of them (e.g. "m.1") while the others keep the default config.
        self.m = torch.nn.ModuleList(torch.nn.Linear(16, 16) for _ in range(3))

    def forward(self, x):
        # Pass the input through each layer in sequence.
        for layer in self.m:
            x = layer(x)
        return x

    def get_example_inputs(self):
        # Returns (positional args, keyword args) for calibration forwards.
        return (torch.randn(1, 16),), {}


class NormConv2D(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -278,6 +296,87 @@ def get_example_inputs(self):


class GPTQTest(unittest.TestCase):
def test_gptq_config_validate_weight_bits_overrides(self):
conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 8})
conf.validate()

def test_gptq_config_validate_rejects_non_positive_weight_bits_override(self):
conf = GPTQConfig(weight_bits=4, weight_bits_overrides={"m.1": 0})
with self.assertRaises(ValueError):
conf.validate()

def test_resolve_weight_bits_priority(self):
quantizer = GPTQQuantizer(
GPTQConfig(
weight_bits=4,
weight_bits_overrides={
"proj": 5,
"layer.proj": 6,
"model.layers.0.layer.proj": 8,
},
)
)

assert isinstance(quantizer.config, GPTQConfig)
self.assertEqual(
quantizer._resolve_weight_bits(
quantizer.config,
full_module_name="model.layers.0.layer.proj",
local_module_name="layer.proj",
),
8,
)
self.assertEqual(
quantizer._resolve_weight_bits(
quantizer.config,
full_module_name="model.layers.1.layer.proj",
local_module_name="layer.proj",
),
6,
)
self.assertEqual(
quantizer._resolve_weight_bits(
quantizer.config,
full_module_name="model.layers.2.other.proj",
local_module_name="other.proj",
),
5,
)
self.assertEqual(
quantizer._resolve_weight_bits(
quantizer.config,
full_module_name="model.layers.2.other.up_proj",
local_module_name="other.up_proj",
),
4,
)

@torch.inference_mode()
def test_weight_bits_overrides_are_applied_per_module(self):
q_m = SmallLinear()
q_m.eval()
ori_m = q_m

q_m = prepare(
q_m,
GPTQConfig(
show_progress=False,
weight_bits=4,
weight_bits_overrides={
"m.1": 8,
},
),
)
for _ in range(8):
args, kwargs = ori_m.get_example_inputs()
q_m(*args, **kwargs)
convert(q_m, inplace=True)

self.assertTrue(hasattr(q_m, "quantizers"))
self.assertEqual(q_m.quantizers["model.layers.0.m.0"].maxq.item(), 15)
self.assertEqual(q_m.quantizers["model.layers.0.m.1"].maxq.item(), 255)
self.assertEqual(q_m.quantizers["model.layers.0.m.2"].maxq.item(), 15)

@unittest.skipIf(
not IS_INTERNAL_TEST, "Internal test — run only if --include-internal is set"
)
Expand Down
39 changes: 34 additions & 5 deletions tico/quantization/algorithm/gptq/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,31 @@ def __init__(self, config: GPTQConfig):
self._orig_layer_forward: Optional[Callable[..., Any]] = None
self._first_layer_ref: Optional[torch.nn.Module] = None

def _resolve_weight_bits(
self,
gptq_conf: GPTQConfig,
*,
full_module_name: str,
local_module_name: str,
) -> int:
"""Resolve the effective bit-width for a quantized submodule."""
if full_module_name in gptq_conf.weight_bits_overrides:
return gptq_conf.weight_bits_overrides[full_module_name]

if local_module_name in gptq_conf.weight_bits_overrides:
return gptq_conf.weight_bits_overrides[local_module_name]

suffix_matches = [
bits
for pattern, bits in gptq_conf.weight_bits_overrides.items()
if full_module_name.endswith(f".{pattern}")
]

if suffix_matches:
return suffix_matches[-1]

return gptq_conf.weight_bits

@torch.no_grad()
def prepare(
self,
Expand Down Expand Up @@ -220,18 +245,22 @@ def convert(self, model):
gptq: Dict[str, GPTQ] = {}
for name in subset:
gptq[name] = GPTQ(subset[name])
full_module_name = module_name[subset[name]]
weight_bits = self._resolve_weight_bits(
gptq_conf,
full_module_name=full_module_name,
local_module_name=name,
)
if (
gptq_conf.sensitivity is not None
and isinstance(gptq_conf.sensitivity, dict)
and module_name[subset[name]] in gptq_conf.sensitivity
and full_module_name in gptq_conf.sensitivity
):
cur_sensitivity = gptq_conf.sensitivity[
module_name[subset[name]]
]
cur_sensitivity = gptq_conf.sensitivity[full_module_name]
else:
cur_sensitivity = None
gptq[name].quantizer.configure(
bits=gptq_conf.weight_bits,
bits=weight_bits,
perchannel=gptq_conf.perchannel,
sym=gptq_conf.symmetric,
mse=gptq_conf.mse,
Expand Down
23 changes: 22 additions & 1 deletion tico/quantization/config/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from dataclasses import dataclass, field

import torch

Expand All @@ -23,6 +23,21 @@
class GPTQConfig(BaseConfig):
"""
Configuration for GPTQ weight quantization.

Attributes
----------
weight_bits : int
Default bit-width applied to quantized weights.
weight_bits_overrides : dict[str, int]
Optional per-module bit-width overrides.

Supported keys are matched in the following order:
1) Full module name, for example `model.layers.0.self_attn.o_proj`
2) Layer-local module name, for example `self_attn.o_proj`
3) Full-name suffix, for example `self_attn.o_proj` or `down_proj`

This makes it possible to keep a default bit-width for most modules
while selectively increasing precision for specific projections.
"""

# general
Expand All @@ -31,6 +46,7 @@ class GPTQConfig(BaseConfig):

# quantizer.configure params (weight quantization spec)
weight_bits: int = 8
weight_bits_overrides: dict[str, int] = field(default_factory=dict)
perchannel: bool = True
symmetric: bool = False
mse: str | None = None
Expand All @@ -49,6 +65,11 @@ def name(self) -> str:
def validate(self) -> None:
if self.weight_bits <= 0:
raise ValueError(f"weight_bits must be positive. got {self.weight_bits}")
for module_name, bits in self.weight_bits_overrides.items():
if bits <= 0:
raise ValueError(
f"weight_bits_overrides[{module_name!r}] must be positive. got {bits}"
)
if self.groupsize != -1 and self.groupsize <= 0:
raise ValueError(f"groupsize must be -1 or positive. got {self.groupsize}")
if not (0.0 < self.percdamp <= 1.0):
Expand Down
Loading