Consolidate the dtype/qscheme helper functions into config/utils.py

Descriptive preamble only: this text sits before the first "diff --git"
header and is not part of any hunk.

What the patch does:

* Adds tico/quantization/config/utils.py with two public helpers:
    dtype_is_unsigned(dtype)
        returns `not dtype.signed`
    auto_qscheme_for(dtype, obs_name=None)
        unsigned dtype, obs_name == "weight" -> QScheme.PER_CHANNEL_ASYMM
        unsigned dtype, any other obs_name   -> QScheme.PER_TENSOR_ASYMM
        signed dtype (any obs_name)          -> QScheme.PER_TENSOR_SYMM
* tico/quantization/config/builders.py: removes its private
  `_auto_qscheme_for` and calls the shared `auto_qscheme_for` from
  `_build_weight_override` and `_build_norm_override`.
* tico/quantization/config/ptq.py: removes its private `_dtype_is_unsigned`
  and `_auto_qscheme_for`; `_resolve_qscheme` now calls the shared helpers.
* test/quantization/config/test_builders.py: imports `auto_qscheme_for`
  from the new utils module instead of `_auto_qscheme_for` from builders.
* Adds test/quantization/config/test_utils.py covering both helpers,
  including the `obs_name=None` case.

NOTE(review): the two removed private copies of `_auto_qscheme_for` differ
only in how they test signedness (`not dtype.signed` in builders.py versus
`_dtype_is_unsigned(dtype)` in ptq.py, whose body is `not dtype.signed`), so
the move is expected to be behaviour-preserving.

NOTE(review): test_builders.py keeps its three `auto_qscheme_for` tests; the
uint8/"act_in" and uint8/"weight" cases assert the same values as the tests
added in test_utils.py. Consider keeping a single copy.

NOTE(review): leading indentation and blank context lines in the hunks below
were laid out to match each hunk's @@ line counts; confirm exact whitespace
(including the licence URL line) against the original patch before applying.

diff --git a/test/quantization/config/test_builders.py b/test/quantization/config/test_builders.py
index 0552e861..298e9461 100644
--- a/test/quantization/config/test_builders.py
+++ b/test/quantization/config/test_builders.py
@@ -15,7 +15,6 @@
 import unittest
 
 from tico.quantization.config.builders import (
-    _auto_qscheme_for,
     _build_llama_layer_overrides,
     _build_llama_overrides,
     _build_norm_override,
@@ -25,6 +24,7 @@
     build_llm_ptq_config,
 )
 from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.config.utils import auto_qscheme_for
 from tico.quantization.wrapq.dtypes import DType
 from tico.quantization.wrapq.observers.ema import EMAObserver
 from tico.quantization.wrapq.qscheme import QScheme
@@ -33,19 +33,19 @@
 class TestBuilderHelpers(unittest.TestCase):
     def test_auto_qscheme_for_unsigned_activation(self):
         self.assertEqual(
-            _auto_qscheme_for(DType.uint(8), "act_in"),
+            auto_qscheme_for(DType.uint(8), "act_in"),
             QScheme.PER_TENSOR_ASYMM,
         )
 
     def test_auto_qscheme_for_unsigned_weight(self):
         self.assertEqual(
-            _auto_qscheme_for(DType.uint(8), "weight"),
+            auto_qscheme_for(DType.uint(8), "weight"),
             QScheme.PER_CHANNEL_ASYMM,
         )
 
     def test_auto_qscheme_for_signed_dtype(self):
         self.assertEqual(
-            _auto_qscheme_for(DType.int(8), "weight"),
+            auto_qscheme_for(DType.int(8), "weight"),
             QScheme.PER_TENSOR_SYMM,
         )
 
diff --git a/test/quantization/config/test_utils.py b/test/quantization/config/test_utils.py
new file mode 100644
index 00000000..8c2afe49
--- /dev/null
+++ b/test/quantization/config/test_utils.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from tico.quantization.config.utils import auto_qscheme_for, dtype_is_unsigned
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.qscheme import QScheme
+
+
+class TestUtils(unittest.TestCase):
+    def test_dtype_is_unsigned_true(self):
+        self.assertTrue(dtype_is_unsigned(DType.uint(8)))
+        self.assertTrue(dtype_is_unsigned(DType.uint(4)))
+
+    def test_dtype_is_unsigned_false(self):
+        self.assertFalse(dtype_is_unsigned(DType.int(8)))
+        self.assertFalse(dtype_is_unsigned(DType.int(16)))
+
+    def test_auto_qscheme_for_unsigned_activation(self):
+        self.assertEqual(
+            auto_qscheme_for(DType.uint(8), "act_in"),
+            QScheme.PER_TENSOR_ASYMM,
+        )
+
+    def test_auto_qscheme_for_unsigned_weight(self):
+        self.assertEqual(
+            auto_qscheme_for(DType.uint(8), "weight"),
+            QScheme.PER_CHANNEL_ASYMM,
+        )
+
+    def test_auto_qscheme_for_signed_dtype(self):
+        self.assertEqual(
+            auto_qscheme_for(DType.int(8), "act_in"),
+            QScheme.PER_TENSOR_SYMM,
+        )
+        self.assertEqual(
+            auto_qscheme_for(DType.int(16), "weight"),
+            QScheme.PER_TENSOR_SYMM,
+        )
+
+    def test_auto_qscheme_for_default_obs_name(self):
+        """
+        When obs_name is None, it should behave like activation.
+        """
+        self.assertEqual(
+            auto_qscheme_for(DType.uint(8), None),
+            QScheme.PER_TENSOR_ASYMM,
+        )
+        self.assertEqual(
+            auto_qscheme_for(DType.int(8), None),
+            QScheme.PER_TENSOR_SYMM,
+        )
diff --git a/tico/quantization/config/builders.py b/tico/quantization/config/builders.py
index 66f5cad3..d14d06d4 100644
--- a/tico/quantization/config/builders.py
+++ b/tico/quantization/config/builders.py
@@ -16,28 +16,13 @@
 from typing import Any, Dict, Optional, Tuple, Type
 
 from tico.quantization.config.ptq import PTQConfig, WrapperVariant
+from tico.quantization.config.utils import auto_qscheme_for
 from tico.quantization.wrapq.dtypes import DType
 from tico.quantization.wrapq.observers.base import ObserverBase
 from tico.quantization.wrapq.observers.minmax import MinMaxObserver
 from tico.quantization.wrapq.qscheme import QScheme
 
 
-def _auto_qscheme_for(dtype: DType, obs_name: Optional[str] = None) -> QScheme:
-    """
-    Choose the default qscheme associated with a dtype and observer name.
-
-    Default policy:
-    - signed dtype -> symmetric per-tensor
-    - unsigned dtype -> asymmetric per-tensor
-    - unsigned weight -> asymmetric per-channel
-    """
-    if not dtype.signed:
-        if obs_name == "weight":
-            return QScheme.PER_CHANNEL_ASYMM
-        return QScheme.PER_TENSOR_ASYMM
-    return QScheme.PER_TENSOR_SYMM
-
-
 def _weight_dtype_from_bits(bits: int) -> DType:
     """
     Convert a commonly used bit-width into a corresponding quantized dtype.
@@ -159,7 +144,7 @@ def _build_weight_override(weight_dtype: Optional[DType]) -> Dict[str, Any]:
     return {
         "weight": {
             "dtype": weight_dtype,
-            "qscheme": _auto_qscheme_for(weight_dtype, "weight"),
+            "qscheme": auto_qscheme_for(weight_dtype, "weight"),
         }
     }
 
@@ -189,12 +174,12 @@ def _build_norm_override(
 
     if norm_dtype is not None:
         override["dtype"] = norm_dtype
-        override["qscheme"] = _auto_qscheme_for(norm_dtype)
+        override["qscheme"] = auto_qscheme_for(norm_dtype)
 
     if norm_weight_dtype is not None:
         override["weight"] = {
             "dtype": norm_weight_dtype,
-            "qscheme": _auto_qscheme_for(norm_weight_dtype, "weight"),
+            "qscheme": auto_qscheme_for(norm_weight_dtype, "weight"),
         }
 
     return override
diff --git a/tico/quantization/config/ptq.py b/tico/quantization/config/ptq.py
index 6ef15c80..1dc52c08 100644
--- a/tico/quantization/config/ptq.py
+++ b/tico/quantization/config/ptq.py
@@ -17,6 +17,7 @@
 from typing import Any, Dict, Iterable, Literal, Mapping, MutableMapping, Optional, Type
 
 from tico.quantization.config.base import BaseConfig
+from tico.quantization.config.utils import auto_qscheme_for, dtype_is_unsigned
 from tico.quantization.wrapq.dtypes import DType
 from tico.quantization.wrapq.observers.base import ObserverBase
 from tico.quantization.wrapq.observers.minmax import MinMaxObserver
@@ -25,29 +26,6 @@
 WrapperVariant = Literal["common", "prefill", "decode"]
 
 
-def _dtype_is_unsigned(dtype: DType) -> bool:
-    """
-    Return True when the dtype is unsigned.
-    """
-    return not dtype.signed
-
-
-def _auto_qscheme_for(dtype: DType, obs_name: Optional[str] = None) -> QScheme:
-    """
-    Choose a default qscheme from the effective dtype and observer name.
-
-    Default policy:
-    - signed dtype -> symmetric per-tensor
-    - unsigned dtype -> asymmetric per-tensor
-    - unsigned weight -> asymmetric per-channel
-    """
-    if _dtype_is_unsigned(dtype):
-        if obs_name == "weight":
-            return QScheme.PER_CHANNEL_ASYMM
-        return QScheme.PER_TENSOR_ASYMM
-    return QScheme.PER_TENSOR_SYMM
-
-
 def _resolve_qscheme(
     *,
     dtype: DType,
@@ -62,9 +40,9 @@ def _resolve_qscheme(
     1. If `qscheme` is None, infer it from `dtype` and `obs_name`.
     2. If the caller explicitly provides an incompatible pair, raise.
     """
-    resolved_qscheme = qscheme or _auto_qscheme_for(dtype, obs_name)
+    resolved_qscheme = qscheme or auto_qscheme_for(dtype, obs_name)
 
-    if _dtype_is_unsigned(dtype) and resolved_qscheme.is_symmetric():
+    if dtype_is_unsigned(dtype) and resolved_qscheme.is_symmetric():
         raise ValueError(
             f"Invalid quantization config at {context}: unsigned dtype "
             f"{dtype!r} cannot be paired with symmetric qscheme "
diff --git a/tico/quantization/config/utils.py b/tico/quantization/config/utils.py
new file mode 100644
index 00000000..956f61da
--- /dev/null
+++ b/tico/quantization/config/utils.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Optional
+
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.qscheme import QScheme
+
+
+def dtype_is_unsigned(dtype: DType) -> bool:
+    """
+    Return True when the dtype is unsigned.
+    """
+    return not dtype.signed
+
+
+def auto_qscheme_for(dtype: DType, obs_name: Optional[str] = None) -> QScheme:
+    """
+    Choose the default qscheme associated with a dtype and observer name.
+
+    Default policy:
+    - signed dtype -> symmetric per-tensor
+    - unsigned dtype -> asymmetric per-tensor
+    - unsigned weight -> asymmetric per-channel
+    """
+    if dtype_is_unsigned(dtype):
+        if obs_name == "weight":
+            return QScheme.PER_CHANNEL_ASYMM
+        return QScheme.PER_TENSOR_ASYMM
+    return QScheme.PER_TENSOR_SYMM