Skip to content

Commit 44ac2a6

Browse files
committed
[quantization] Improve overrides UX
This commit improves overrides UX. TICO-DCO-1.0-Signed-off-by: seongwoo <mhs4670go@naver.com>
1 parent 3e4b06f commit 44ac2a6

7 files changed

Lines changed: 1002 additions & 279 deletions

File tree

test/quantization/config/__init__.py

Whitespace-only changes.
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
from tico.quantization.config.builders import (
18+
_auto_qscheme_for,
19+
_build_llama_layer_overrides,
20+
_build_llama_overrides,
21+
_build_norm_override,
22+
_build_weight_override,
23+
_resolve_weight_dtype,
24+
_weight_dtype_from_bits,
25+
build_llm_ptq_config,
26+
)
27+
from tico.quantization.config.ptq import PTQConfig
28+
from tico.quantization.wrapq.dtypes import DType
29+
from tico.quantization.wrapq.qscheme import QScheme
30+
31+
32+
class TestBuilderHelpers(unittest.TestCase):
    """Checks for the small private helpers used by the PTQ config builders."""

    def test_auto_qscheme_for_unsigned_activation(self):
        # Unsigned dtype on an activation observer ("act_in") is expected to
        # resolve to the per-tensor asymmetric scheme.
        scheme = _auto_qscheme_for(DType.uint(8), "act_in")
        self.assertEqual(scheme, QScheme.PER_TENSOR_ASYMM)

    def test_auto_qscheme_for_unsigned_weight(self):
        # Unsigned dtype on the "weight" observer is expected to resolve to
        # the per-channel asymmetric scheme.
        scheme = _auto_qscheme_for(DType.uint(8), "weight")
        self.assertEqual(scheme, QScheme.PER_CHANNEL_ASYMM)

    def test_auto_qscheme_for_signed_dtype(self):
        # Signed dtype is expected to resolve to the per-tensor symmetric
        # scheme, even for the "weight" observer.
        scheme = _auto_qscheme_for(DType.int(8), "weight")
        self.assertEqual(scheme, QScheme.PER_TENSOR_SYMM)

    def test_weight_dtype_from_bits(self):
        # Bit-width -> dtype mapping pinned by this test.
        expected_by_bits = {
            16: DType.int(16),
            8: DType.uint(8),
            4: DType.uint(4),
        }
        for bits, expected in expected_by_bits.items():
            self.assertEqual(_weight_dtype_from_bits(bits), expected)

    def test_weight_dtype_from_bits_invalid_raises(self):
        # A bit-width of 3 must be rejected with ValueError.
        self.assertRaises(ValueError, _weight_dtype_from_bits, 3)

    def test_resolve_weight_dtype_prefers_explicit_dtype(self):
        # When both dtype and bits are supplied, the explicit dtype wins.
        resolved = _resolve_weight_dtype(dtype=DType.int(8), bits=4)
        self.assertEqual(resolved, DType.int(8))

    def test_resolve_weight_dtype_falls_back_to_bits(self):
        # Without an explicit dtype, the bit-width determines the result.
        from_bits = _resolve_weight_dtype(dtype=None, bits=4)
        self.assertEqual(from_bits, DType.uint(4))

        # With neither input there is nothing to resolve.
        unresolved = _resolve_weight_dtype(dtype=None, bits=None)
        self.assertIsNone(unresolved)

    def test_build_weight_override_includes_qscheme(self):
        override = _build_weight_override(DType.uint(8))
        expected = {
            "weight": {
                "dtype": DType.uint(8),
                "qscheme": QScheme.PER_CHANNEL_ASYMM,
            }
        }
        self.assertEqual(override, expected)

        # A missing dtype yields an empty override mapping.
        self.assertEqual(_build_weight_override(None), {})

    def test_build_norm_override_includes_module_and_weight_qscheme(self):
        override = _build_norm_override(
            norm_dtype=DType.uint(8),
            norm_weight_dtype=DType.uint(4),
        )

        # Module-level entries of the norm override.
        self.assertEqual(override["dtype"], DType.uint(8))
        self.assertEqual(override["qscheme"], QScheme.PER_TENSOR_ASYMM)

        # Nested weight entries of the norm override.
        weight_cfg = override["weight"]
        self.assertEqual(weight_cfg["dtype"], DType.uint(4))
        self.assertEqual(weight_cfg["qscheme"], QScheme.PER_CHANNEL_ASYMM)
99+
100+
101+
class TestLlamaOverrideBuilders(unittest.TestCase):
    """Checks the override dictionaries produced by the Llama builders."""

    def test_build_llama_layer_overrides(self):
        layer_cfg = _build_llama_layer_overrides(
            linear_weight_dtype=DType.uint(8),
            norm_dtype=DType.uint(8),
            norm_weight_dtype=DType.uint(4),
        )

        # Linear projections carry a per-channel asymmetric weight scheme.
        q_proj_scheme = layer_cfg["self_attn"]["q_proj"]["weight"]["qscheme"]
        self.assertEqual(q_proj_scheme, QScheme.PER_CHANNEL_ASYMM)

        down_proj_scheme = layer_cfg["mlp"]["down_proj"]["weight"]["qscheme"]
        self.assertEqual(down_proj_scheme, QScheme.PER_CHANNEL_ASYMM)

        # The norm entry carries both a module-level and a weight scheme.
        norm_scheme = layer_cfg["input_layernorm"]["qscheme"]
        self.assertEqual(norm_scheme, QScheme.PER_TENSOR_ASYMM)

        norm_weight_scheme = layer_cfg["input_layernorm"]["weight"]["qscheme"]
        self.assertEqual(norm_weight_scheme, QScheme.PER_CHANNEL_ASYMM)

    def test_build_llama_overrides(self):
        builder_kwargs = {
            "num_hidden_layers": 2,
            "linear_weight_dtype": DType.uint(8),
            "embedding_weight_dtype": DType.uint(4),
            "lm_head_weight_dtype": DType.uint(8),
            "norm_dtype": DType.int(16),
            "norm_weight_dtype": DType.uint(4),
        }
        overrides = _build_llama_overrides(**builder_kwargs)

        # Top-level structure: a "model" entry holding one entry per layer.
        self.assertIn("model", overrides)
        model_cfg = overrides["model"]
        self.assertIn("layers", model_cfg)
        self.assertEqual(len(model_cfg["layers"]), 2)

        embed_scheme = model_cfg["embed_tokens"]["weight"]["qscheme"]
        self.assertEqual(embed_scheme, QScheme.PER_CHANNEL_ASYMM)

        lm_head_scheme = overrides["lm_head"]["weight"]["qscheme"]
        self.assertEqual(lm_head_scheme, QScheme.PER_CHANNEL_ASYMM)

        # The final norm was given a signed dtype, hence the symmetric scheme.
        final_norm_scheme = model_cfg["norm"]["qscheme"]
        self.assertEqual(final_norm_scheme, QScheme.PER_TENSOR_SYMM)

        # Layers are keyed by their index as a string.
        first_layer = model_cfg["layers"]["0"]
        o_proj_scheme = first_layer["self_attn"]["o_proj"]["weight"]["qscheme"]
        self.assertEqual(o_proj_scheme, QScheme.PER_CHANNEL_ASYMM)
157+
158+
159+
class TestBuildLlmPtqConfig(unittest.TestCase):
    """Checks the public `build_llm_ptq_config` entry point."""

    def test_build_llm_ptq_config_llama(self):
        builder_kwargs = {
            "model_type": "llama",
            "num_hidden_layers": 2,
            "wrapper_variant": "decode",
            "activation_dtype": DType.uint(8),
            "default_qscheme": QScheme.PER_TENSOR_ASYMM,
            "linear_weight_dtype": DType.uint(8),
            "embedding_weight_dtype": DType.uint(4),
            "lm_head_weight_dtype": DType.uint(8),
            "norm_dtype": DType.int(16),
            "norm_weight_dtype": DType.uint(4),
            "strict_wrap": False,
        }
        cfg = build_llm_ptq_config(**builder_kwargs)

        # Top-level fields of the returned config.
        self.assertIsInstance(cfg, PTQConfig)
        self.assertEqual(cfg.default_dtype, DType.uint(8))
        self.assertEqual(cfg.default_qscheme, QScheme.PER_TENSOR_ASYMM)
        self.assertEqual(cfg.wrapper_variant, "decode")
        self.assertFalse(cfg.strict_wrap)

        # Per-module overrides generated for the Llama layout.
        embed_scheme = cfg.overrides["model"]["embed_tokens"]["weight"]["qscheme"]
        self.assertEqual(embed_scheme, QScheme.PER_CHANNEL_ASYMM)

        lm_head_scheme = cfg.overrides["lm_head"]["weight"]["qscheme"]
        self.assertEqual(lm_head_scheme, QScheme.PER_CHANNEL_ASYMM)

        second_layer = cfg.overrides["model"]["layers"]["1"]
        up_proj_scheme = second_layer["mlp"]["up_proj"]["weight"]["qscheme"]
        self.assertEqual(up_proj_scheme, QScheme.PER_CHANNEL_ASYMM)

        final_norm_scheme = cfg.overrides["model"]["norm"]["qscheme"]
        self.assertEqual(final_norm_scheme, QScheme.PER_TENSOR_SYMM)

    def test_explicit_dtype_takes_precedence_over_bits(self):
        # Both a bit-width and an explicit dtype are passed for the linear
        # weights; the resulting override must reflect the explicit dtype.
        cfg = build_llm_ptq_config(
            model_type="llama",
            num_hidden_layers=1,
            linear_weight_bits=4,
            linear_weight_dtype=DType.uint(8),
        )

        first_layer = cfg.overrides["model"]["layers"]["0"]
        q_proj_weight = first_layer["self_attn"]["q_proj"]["weight"]

        self.assertEqual(q_proj_weight["dtype"], DType.uint(8))
        self.assertEqual(q_proj_weight["qscheme"], QScheme.PER_CHANNEL_ASYMM)

    def test_build_llm_ptq_config_unsupported_model_type_raises(self):
        # "mistral" is not a supported model type for this builder.
        self.assertRaises(
            NotImplementedError,
            build_llm_ptq_config,
            model_type="mistral",
            num_hidden_layers=1,
        )

0 commit comments

Comments
 (0)