diff --git a/test/quantization/config/__init__.py b/test/quantization/config/__init__.py new file mode 100644 index 00000000..0c29109f --- /dev/null +++ b/test/quantization/config/__init__.py @@ -0,0 +1 @@ +# DO NOT REMOVE THIS FILE diff --git a/test/quantization/config/test_builders.py b/test/quantization/config/test_builders.py new file mode 100644 index 00000000..79451d42 --- /dev/null +++ b/test/quantization/config/test_builders.py @@ -0,0 +1,226 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +from tico.quantization.config.builders import ( + _auto_qscheme_for, + _build_llama_layer_overrides, + _build_llama_overrides, + _build_norm_override, + _build_weight_override, + _resolve_weight_dtype, + _weight_dtype_from_bits, + build_llm_ptq_config, +) +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.dtypes import DType +from tico.quantization.wrapq.qscheme import QScheme + + +class TestBuilderHelpers(unittest.TestCase): + def test_auto_qscheme_for_unsigned_activation(self): + self.assertEqual( + _auto_qscheme_for(DType.uint(8), "act_in"), + QScheme.PER_TENSOR_ASYMM, + ) + + def test_auto_qscheme_for_unsigned_weight(self): + self.assertEqual( + _auto_qscheme_for(DType.uint(8), "weight"), + QScheme.PER_CHANNEL_ASYMM, + ) + + def test_auto_qscheme_for_signed_dtype(self): + self.assertEqual( + _auto_qscheme_for(DType.int(8), "weight"), + QScheme.PER_TENSOR_SYMM, + ) + + def test_weight_dtype_from_bits(self): + self.assertEqual(_weight_dtype_from_bits(16), DType.int(16)) + self.assertEqual(_weight_dtype_from_bits(8), DType.uint(8)) + self.assertEqual(_weight_dtype_from_bits(4), DType.uint(4)) + + def test_weight_dtype_from_bits_invalid_raises(self): + with self.assertRaises(ValueError): + _weight_dtype_from_bits(3) + + def test_resolve_weight_dtype_prefers_explicit_dtype(self): + self.assertEqual( + _resolve_weight_dtype(dtype=DType.int(8), bits=4), + DType.int(8), + ) + + def test_resolve_weight_dtype_falls_back_to_bits(self): + self.assertEqual( + _resolve_weight_dtype(dtype=None, bits=4), + DType.uint(4), + ) + self.assertIsNone(_resolve_weight_dtype(dtype=None, bits=None)) + + def test_build_weight_override_includes_qscheme(self): + override = _build_weight_override(DType.uint(8)) + self.assertEqual( + override, + { + "weight": { + "dtype": DType.uint(8), + "qscheme": QScheme.PER_CHANNEL_ASYMM, + } + }, + ) + self.assertEqual(_build_weight_override(None), {}) + + def 
test_build_norm_override_includes_module_and_weight_qscheme(self): + override = _build_norm_override( + norm_dtype=DType.uint(8), + norm_weight_dtype=DType.uint(4), + ) + + self.assertEqual(override["dtype"], DType.uint(8)) + self.assertEqual(override["qscheme"], QScheme.PER_TENSOR_ASYMM) + self.assertEqual(override["weight"]["dtype"], DType.uint(4)) + self.assertEqual( + override["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + + +class TestLlamaOverrideBuilders(unittest.TestCase): + def test_build_llama_layer_overrides(self): + overrides = _build_llama_layer_overrides( + linear_weight_dtype=DType.uint(8), + norm_dtype=DType.uint(8), + norm_weight_dtype=DType.uint(4), + ) + + self.assertEqual( + overrides["self_attn"]["q_proj"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + self.assertEqual( + overrides["mlp"]["down_proj"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + self.assertEqual( + overrides["input_layernorm"]["qscheme"], + QScheme.PER_TENSOR_ASYMM, + ) + self.assertEqual( + overrides["input_layernorm"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + + def test_build_llama_overrides(self): + overrides = _build_llama_overrides( + num_hidden_layers=2, + linear_weight_dtype=DType.uint(8), + embedding_weight_dtype=DType.uint(4), + lm_head_weight_dtype=DType.uint(8), + norm_dtype=DType.int(16), + norm_weight_dtype=DType.uint(4), + ) + + self.assertIn("model", overrides) + self.assertIn("layers", overrides["model"]) + self.assertEqual(len(overrides["model"]["layers"]), 2) + self.assertEqual( + overrides["model"]["embed_tokens"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + self.assertEqual( + overrides["lm_head"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + self.assertEqual( + overrides["model"]["norm"]["qscheme"], + QScheme.PER_TENSOR_SYMM, + ) + self.assertEqual( + overrides["model"]["layers"]["0"]["self_attn"]["o_proj"]["weight"][ + "qscheme" + ], + QScheme.PER_CHANNEL_ASYMM, + ) + + +class 
TestBuildLlmPtqConfig(unittest.TestCase): + def test_build_llm_ptq_config_llama(self): + cfg = build_llm_ptq_config( + model_type="llama", + num_hidden_layers=2, + wrapper_variant="decode", + activation_dtype=DType.uint(8), + default_qscheme=QScheme.PER_TENSOR_ASYMM, + linear_weight_dtype=DType.uint(8), + embedding_weight_dtype=DType.uint(4), + lm_head_weight_dtype=DType.uint(8), + norm_dtype=DType.int(16), + norm_weight_dtype=DType.uint(4), + strict_wrap=False, + ) + + self.assertIsInstance(cfg, PTQConfig) + self.assertEqual(cfg.default_dtype, DType.uint(8)) + self.assertEqual(cfg.default_qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertEqual(cfg.wrapper_variant, "decode") + self.assertFalse(cfg.strict_wrap) + + self.assertEqual( + cfg.overrides["model"]["embed_tokens"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + self.assertEqual( + cfg.overrides["lm_head"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + self.assertEqual( + cfg.overrides["model"]["layers"]["1"]["mlp"]["up_proj"]["weight"][ + "qscheme" + ], + QScheme.PER_CHANNEL_ASYMM, + ) + self.assertEqual( + cfg.overrides["model"]["norm"]["qscheme"], + QScheme.PER_TENSOR_SYMM, + ) + + def test_explicit_dtype_takes_precedence_over_bits(self): + cfg = build_llm_ptq_config( + model_type="llama", + num_hidden_layers=1, + linear_weight_bits=4, + linear_weight_dtype=DType.uint(8), + ) + + self.assertEqual( + cfg.overrides["model"]["layers"]["0"]["self_attn"]["q_proj"]["weight"][ + "dtype" + ], + DType.uint(8), + ) + self.assertEqual( + cfg.overrides["model"]["layers"]["0"]["self_attn"]["q_proj"]["weight"][ + "qscheme" + ], + QScheme.PER_CHANNEL_ASYMM, + ) + + def test_build_llm_ptq_config_unsupported_model_type_raises(self): + with self.assertRaises(NotImplementedError): + build_llm_ptq_config( + model_type="mistral", + num_hidden_layers=1, + ) diff --git a/test/quantization/config/test_ptq.py b/test/quantization/config/test_ptq.py new file mode 100644 index 00000000..06d2100b --- /dev/null +++ 
b/test/quantization/config/test_ptq.py @@ -0,0 +1,524 @@ +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tico.quantization.config.ptq import PTQConfig + +from tico.quantization.wrapq.dtypes import DType +from tico.quantization.wrapq.observers.affine_base import AffineObserverBase +from tico.quantization.wrapq.observers.ema import EMAObserver +from tico.quantization.wrapq.observers.minmax import MinMaxObserver +from tico.quantization.wrapq.qscheme import QScheme +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase + + +class DummyWrapper(QuantModuleBase): + """Minimal wrapper to expose `_make_obs` and store the created observer.""" + + def __init__(self, qcfg, **kwargs): + super().__init__(qcfg) + # kwargs here are wrapper-level defaults for _make_obs + self.obs_act_in = self._make_obs("act_in", **kwargs) + self.obs_act_out = self._make_obs("act_out", **kwargs) + self.obs_weight = self._make_obs("weight", **kwargs) + + def _all_observers(self): + # required by QuantModuleBase + return (self.obs_act_in, self.obs_act_out, self.obs_weight) + + +class TestPTQConfig(unittest.TestCase): + def test_default_dtype_applied(self): + cfg = PTQConfig(default_dtype=DType.uint(8)) + w = DummyWrapper(cfg) + self.assertEqual(w.obs_act_in.dtype, DType.uint(8)) + self.assertEqual(w.obs_act_out.dtype, DType.uint(8)) + + def test_per_observer_dtype_override(self): + 
cfg = PTQConfig( + default_dtype=DType.uint(8), + overrides={"act_out": {"dtype": DType.uint(4)}}, + ) + w = DummyWrapper(cfg) + self.assertEqual(w.obs_act_in.dtype, DType.uint(8)) # default + self.assertEqual(w.obs_act_out.dtype, DType.uint(4)) # override + + def test_observer_override(self): + cfg = PTQConfig( + default_dtype=DType.uint(8), + overrides={ + "act_in": { + "observer": EMAObserver, + "dtype": DType.uint(8), + } + }, + ) + w = DummyWrapper(cfg) + self.assertIsInstance(w.obs_act_in, EMAObserver) + self.assertEqual(w.obs_act_in.dtype, DType.uint(8)) + self.assertIsInstance(w.obs_act_out, MinMaxObserver) # unaffected + + +class TestPTQConfigChild(unittest.TestCase): + def test_child_inherits_default_dtype(self): + parent = PTQConfig(default_dtype=DType.uint(8)) + child = parent.child("gate_proj") + self.assertEqual(child.default_dtype, DType.uint(8)) + self.assertEqual(child.default_dtype, DType.uint(8)) + + def test_child_override_applied(self): + parent = PTQConfig( + default_dtype=DType.uint(8), + overrides={ + "gate_proj": {"act_in": {"dtype": DType.uint(4)}}, + "mul": {"dtype": DType.uint(4)}, + }, + ) + gate_cfg = parent.child("gate_proj") + up_cfg = parent.child("up_proj") # no specific override + + # gate_proj.act_in should pick up uint4 + self.assertEqual(gate_cfg.get_kwargs("act_in")["dtype"], DType.uint(4)) + # top-level override still visible to parent + self.assertEqual(parent.get_kwargs("mul")["dtype"], DType.uint(4)) + + def test_child_is_view_not_copy(self): + parent = PTQConfig(default_dtype=DType.uint(8)) + child = parent.child("dummy") + # mutate child's overrides → parent unaffected + child.overrides["x"] = {"dtype": DType.int(8)} # type: ignore[index] + self.assertNotIn("x", parent.overrides) + + def test_child_inherits_default_qscheme(self): + parent = PTQConfig(default_qscheme=QScheme.PER_CHANNEL_ASYMM) + child = parent.child("gate_proj") + self.assertEqual(child.default_qscheme, QScheme.PER_CHANNEL_ASYMM) + + def 
test_child_inherits_defaults(self): + parent = PTQConfig( + default_dtype=DType.uint(8), + default_qscheme=None, + wrapper_variant="decode", + ) + child = parent.child("gate_proj") + + self.assertEqual(child.default_dtype, DType.uint(8)) + self.assertEqual(child.default_qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertEqual(child.wrapper_variant, "decode") + + def test_child_override_applied_and_normalized(self): + parent = PTQConfig( + default_dtype=DType.int(16), + overrides={ + "gate_proj": {"act_in": {"dtype": DType.uint(4)}}, + "mul": {"dtype": DType.uint(4)}, + }, + ) + gate_cfg = parent.child("gate_proj") + up_cfg = parent.child("up_proj") + + self.assertEqual(gate_cfg.get_kwargs("act_in")["dtype"], DType.uint(4)) + self.assertEqual( + gate_cfg.get_kwargs("act_in")["qscheme"], + QScheme.PER_TENSOR_ASYMM, + ) + self.assertEqual(parent.get_kwargs("mul")["dtype"], DType.uint(4)) + self.assertEqual(parent.get_kwargs("mul")["qscheme"], QScheme.PER_TENSOR_ASYMM) + self.assertEqual(up_cfg.get_kwargs("act_in"), {}) + + def test_child_isolated_from_parent_mutation(self): + parent = PTQConfig(default_dtype=DType.uint(8)) + child = parent.child("dummy") + child.overrides["x"] = {"dtype": DType.int(8)} # type: ignore[index] + child.normalize_overrides() + + self.assertNotIn("x", parent.overrides) + self.assertEqual(child.overrides["x"]["qscheme"], QScheme.PER_TENSOR_SYMM) + + +# ---- Dummy observers for testing (just to distinguish classes) ---- +class DummyObserverA(AffineObserverBase): + def _update_stats(self, x): + return super()._update_stats(x) + + +class DummyObserverB(AffineObserverBase): + def _update_stats(self, x): + return super()._update_stats(x) + + +class TestObserverAndDTypePrecedence(unittest.TestCase): + """ + Ensure `_make_obs()` applies 3-level precedence to dtype/observer: + + 1) User override in PTQConfig.overrides[name] + 2) Wrapper default passed via `_make_obs(..., dtype=..., observer=...)` + 3) PTQConfig.default_dtype or default_observer + + And 
other kwargs follow: + user override > wrapper default + """ + + def test_user_override_wins(self): + """ + If user supplies both dtype and observer, they must override + both wrapper defaults and PTQConfig defaults. + """ + qcfg = PTQConfig( + default_dtype=DType.uint(8), + default_observer=MinMaxObserver, # type: ignore[type-abstract] + overrides={ + "act_in": { + "dtype": DType.uint(4), + "observer": DummyObserverA, + "qscheme": QScheme.PER_TENSOR_ASYMM, # user override for another kw + "channel_axis": None, + } + }, + ) + + # Wrapper defaults: dtype=6bit, observer=DummyObserverB, qscheme=PER_CHANNEL + wrapper = DummyWrapper( + qcfg, + dtype=DType.uint(6), + observer=DummyObserverB, + qscheme=QScheme.PER_CHANNEL_ASYMM, + channel_axis=0, + ) + + self.assertIsInstance(wrapper.obs_act_in, DummyObserverA) + self.assertEqual(wrapper.obs_act_in.dtype, DType.uint(4)) + # user override wins for qscheme/channel_axis + self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertIsNone(wrapper.obs_act_in.channel_axis) + + def test_wrapper_default_when_no_user_override(self): + """ + If the user supplies nothing for a given name, wrapper defaults must + override PTQConfig defaults. + """ + qcfg = PTQConfig( + default_dtype=DType.uint(8), + default_observer=MinMaxObserver, # type: ignore[type-abstract] + overrides={ + # nothing for 'act_out' + }, + ) + + wrapper = DummyWrapper( + qcfg, + dtype=DType.uint(6), + observer=DummyObserverB, + qscheme=QScheme.PER_CHANNEL_ASYMM, + channel_axis=1, + ) + + self.assertIsInstance(wrapper.obs_act_out, DummyObserverB) + self.assertEqual(wrapper.obs_act_out.dtype, DType.uint(6)) + self.assertEqual(wrapper.obs_act_out.qscheme, QScheme.PER_CHANNEL_ASYMM) + self.assertEqual(wrapper.obs_act_out.channel_axis, 1) + + def test_other_kwargs_user_override_precedence(self): + """ + For keys without PTQConfig-level defaults (like qscheme/channel_axis), + user overrides > wrapper defaults. 
+ """ + qcfg = PTQConfig( + default_dtype=DType.uint(8), + default_observer=MinMaxObserver, # type: ignore[type-abstract] + overrides={ + "act_in": { + "qscheme": QScheme.PER_TENSOR_ASYMM, + "channel_axis": None, + } + }, + ) + + # wrapper defaults try to force a per-channel scheme + wrapper = DummyWrapper( + qcfg, + qscheme=QScheme.PER_CHANNEL_ASYMM, + channel_axis=2, + ) + + # user override must win + self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertIsNone(wrapper.obs_act_in.channel_axis) + + def test_PTQConfig_get_kwargs_does_not_inject_dtype(self): + """ + Ensure PTQConfig.get_kwargs() itself doesn't inject dtype anymore. + It should return exactly the user override dict. + """ + qcfg = PTQConfig( + default_dtype=DType.uint(8), + overrides={"bar": {"qscheme": QScheme.PER_TENSOR_ASYMM}}, + ) + kw = qcfg.get_kwargs("bar") + self.assertIn("qscheme", kw) + self.assertNotIn("dtype", kw) + + def test_config_default_when_neither_user_nor_wrapper(self): + """ + If neither user nor wrapper provides dtype/qscheme, fallback to PTQConfig defaults. 
+ """ + qcfg = PTQConfig( + default_dtype=DType.uint(8), + default_qscheme=QScheme.PER_TENSOR_ASYMM, + default_observer=MinMaxObserver, # type: ignore[type-abstract] + overrides={}, + ) + + wrapper = DummyWrapper(qcfg) + + self.assertIsInstance(wrapper.obs_act_in, MinMaxObserver) + self.assertEqual(wrapper.obs_act_in.dtype, DType.uint(8)) + self.assertEqual(wrapper.obs_act_in.qscheme, qcfg.default_qscheme) + self.assertIsNone(wrapper.obs_act_in.channel_axis) + + +class TestPTQConfigQScheme(unittest.TestCase): + def test_default_qscheme_applied(self): + cfg = PTQConfig( + default_dtype=DType.uint(8), + default_qscheme=QScheme.PER_CHANNEL_ASYMM, + ) + w = DummyWrapper(cfg) + self.assertEqual(w.obs_act_in.qscheme, QScheme.PER_CHANNEL_ASYMM) + self.assertEqual(w.obs_act_out.qscheme, QScheme.PER_CHANNEL_ASYMM) + + def test_per_observer_qscheme_override(self): + cfg = PTQConfig( + default_dtype=DType.uint(8), + default_qscheme=QScheme.PER_CHANNEL_ASYMM, + overrides={ + "act_out": {"dtype": DType.int(16)}, + }, + ) + w = DummyWrapper(cfg) + self.assertEqual(w.obs_act_in.qscheme, QScheme.PER_CHANNEL_ASYMM) + self.assertEqual(w.obs_act_out.qscheme, QScheme.PER_TENSOR_SYMM) + + +class TestPTQConfigDefaults(unittest.TestCase): + def test_unsigned_default_auto_qscheme_for_activation(self): + cfg = PTQConfig(default_dtype=DType.uint(8), default_qscheme=None) + wrapper = DummyWrapper(cfg) + + self.assertEqual(cfg.default_qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertEqual(wrapper.obs_act_in.dtype, DType.uint(8)) + self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertEqual(wrapper.obs_act_out.qscheme, QScheme.PER_TENSOR_ASYMM) + + def test_signed_default_auto_qscheme(self): + cfg = PTQConfig(default_dtype=DType.int(16), default_qscheme=None) + wrapper = DummyWrapper(cfg) + + self.assertEqual(cfg.default_qscheme, QScheme.PER_TENSOR_SYMM) + self.assertEqual(wrapper.obs_act_in.dtype, DType.int(16)) + self.assertEqual(wrapper.obs_act_in.qscheme, 
QScheme.PER_TENSOR_SYMM) + + def test_explicit_invalid_default_pair_raises(self): + with self.assertRaises(ValueError): + PTQConfig( + default_dtype=DType.uint(8), + default_qscheme=QScheme.PER_TENSOR_SYMM, + ) + + def test_default_weight_observer_uses_wrapper_default_when_provided(self): + cfg = PTQConfig(default_dtype=DType.uint(8), default_qscheme=None) + wrapper = DummyWrapper( + cfg, + dtype=DType.int(8), + qscheme=QScheme.PER_CHANNEL_SYMM, + ) + + self.assertEqual(wrapper.obs_weight.dtype, DType.int(8)) + self.assertEqual(wrapper.obs_weight.qscheme, QScheme.PER_CHANNEL_SYMM) + + +class TestPTQConfigOverrides(unittest.TestCase): + def test_per_observer_dtype_override_auto_infers_qscheme(self): + cfg = PTQConfig( + default_dtype=DType.int(16), + overrides={"act_out": {"dtype": DType.uint(4)}}, + ) + wrapper = DummyWrapper(cfg) + + self.assertEqual(wrapper.obs_act_in.dtype, DType.int(16)) + self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_SYMM) + self.assertEqual(wrapper.obs_act_out.dtype, DType.uint(4)) + self.assertEqual(wrapper.obs_act_out.qscheme, QScheme.PER_TENSOR_ASYMM) + + def test_weight_override_auto_infers_per_channel_asymmetric(self): + cfg = PTQConfig( + default_dtype=DType.int(16), + overrides={"weight": {"dtype": DType.uint(8)}}, + ) + wrapper = DummyWrapper(cfg) + + self.assertEqual(wrapper.obs_weight.dtype, DType.uint(8)) + self.assertEqual(wrapper.obs_weight.qscheme, QScheme.PER_CHANNEL_ASYMM) + + def test_observer_override(self): + cfg = PTQConfig( + default_dtype=DType.uint(8), + overrides={ + "act_in": { + "observer": EMAObserver, + "dtype": DType.uint(8), + } + }, + ) + wrapper = DummyWrapper(cfg) + self.assertIsInstance(wrapper.obs_act_in, EMAObserver) + self.assertEqual(wrapper.obs_act_in.dtype, DType.uint(8)) + self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertIsInstance(wrapper.obs_act_out, MinMaxObserver) + + def test_get_kwargs_returns_normalized_override(self): + cfg = PTQConfig( + 
default_dtype=DType.int(16), + overrides={"weight": {"dtype": DType.uint(8)}}, + ) + kwargs = cfg.get_kwargs("weight") + + self.assertEqual(kwargs["dtype"], DType.uint(8)) + self.assertEqual(kwargs["qscheme"], QScheme.PER_CHANNEL_ASYMM) + + def test_explicit_invalid_override_pair_raises(self): + with self.assertRaises(ValueError): + PTQConfig( + default_dtype=DType.int(16), + overrides={ + "weight": { + "dtype": DType.uint(8), + "qscheme": QScheme.PER_CHANNEL_SYMM, + } + }, + ) + + def test_other_kwargs_user_override_precedence(self): + cfg = PTQConfig( + default_dtype=DType.uint(8), + overrides={ + "act_in": { + "qscheme": QScheme.PER_TENSOR_ASYMM, + "channel_axis": None, + } + }, + ) + + wrapper = DummyWrapper( + cfg, + qscheme=QScheme.PER_CHANNEL_ASYMM, + channel_axis=2, + ) + + self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertIsNone(wrapper.obs_act_in.channel_axis) + + def test_user_override_wins_over_wrapper_defaults(self): + qcfg = PTQConfig( + default_dtype=DType.uint(8), + default_observer=MinMaxObserver, # type: ignore[type-abstract] + overrides={ + "act_in": { + "dtype": DType.uint(4), + "observer": DummyObserverA, + "qscheme": QScheme.PER_TENSOR_ASYMM, + "channel_axis": None, + } + }, + ) + + wrapper = DummyWrapper( + qcfg, + dtype=DType.uint(6), + observer=DummyObserverB, + qscheme=QScheme.PER_CHANNEL_ASYMM, + channel_axis=0, + ) + + self.assertIsInstance(wrapper.obs_act_in, DummyObserverA) + self.assertEqual(wrapper.obs_act_in.dtype, DType.uint(4)) + self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) + self.assertIsNone(wrapper.obs_act_in.channel_axis) + + def test_wrapper_default_when_no_user_override(self): + qcfg = PTQConfig( + default_dtype=DType.uint(8), + default_observer=MinMaxObserver, # type: ignore[type-abstract] + ) + + wrapper = DummyWrapper( + qcfg, + dtype=DType.uint(6), + observer=DummyObserverB, + qscheme=QScheme.PER_CHANNEL_ASYMM, + channel_axis=1, + ) + + 
self.assertIsInstance(wrapper.obs_act_out, DummyObserverB) + self.assertEqual(wrapper.obs_act_out.dtype, DType.uint(6)) + self.assertEqual(wrapper.obs_act_out.qscheme, QScheme.PER_CHANNEL_ASYMM) + self.assertEqual(wrapper.obs_act_out.channel_axis, 1) + + +class TestPTQConfigMutationHelpers(unittest.TestCase): + def test_set_override_for_activation_infers_per_tensor_asymmetric(self): + cfg = PTQConfig(default_dtype=DType.int(16)) + cfg.set_override(("block", "act_in"), {"dtype": DType.uint(8)}) + + self.assertEqual( + cfg.overrides["block"]["act_in"]["qscheme"], + QScheme.PER_TENSOR_ASYMM, + ) + + def test_set_override_for_weight_infers_per_channel_asymmetric(self): + cfg = PTQConfig(default_dtype=DType.int(16)) + cfg.set_override(("block", "weight"), {"dtype": DType.uint(8)}) + + self.assertEqual( + cfg.overrides["block"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + + def test_set_override_empty_path_raises(self): + cfg = PTQConfig(default_dtype=DType.int(16)) + with self.assertRaises(ValueError): + cfg.set_override((), {"dtype": DType.uint(8)}) + + def test_normalize_overrides_after_direct_mutation(self): + cfg = PTQConfig(default_dtype=DType.int(16)) + cfg.overrides["block"] = {"weight": {"dtype": DType.uint(8)}} # type: ignore[index] + cfg.normalize_overrides() + + self.assertEqual( + cfg.overrides["block"]["weight"]["qscheme"], + QScheme.PER_CHANNEL_ASYMM, + ) + + def test_normalize_overrides_rejects_invalid_direct_mutation(self): + cfg = PTQConfig(default_dtype=DType.int(16)) + cfg.overrides["block"] = { # type: ignore[index] + "weight": { + "dtype": DType.uint(8), + "qscheme": QScheme.PER_CHANNEL_SYMM, + } + } + + with self.assertRaises(ValueError): + cfg.normalize_overrides() diff --git a/test/quantization/wrapq/test_quant_config.py b/test/quantization/wrapq/test_quant_config.py deleted file mode 100644 index 6b23b4c9..00000000 --- a/test/quantization/wrapq/test_quant_config.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright (c) 2025 Samsung 
Electronics Co., Ltd. All Rights Reserved -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from tico.quantization.config.ptq import PTQConfig - -from tico.quantization.wrapq.dtypes import DType -from tico.quantization.wrapq.observers.affine_base import AffineObserverBase -from tico.quantization.wrapq.observers.ema import EMAObserver -from tico.quantization.wrapq.observers.minmax import MinMaxObserver -from tico.quantization.wrapq.qscheme import QScheme -from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase - - -class DummyWrapper(QuantModuleBase): - """Minimal wrapper to expose `_make_obs` and store the created observer.""" - - def __init__(self, qcfg, **kwargs): - super().__init__(qcfg) - # kwargs here are wrapper-level defaults for _make_obs - self.obs_act_in = self._make_obs("act_in", **kwargs) - self.obs_act_out = self._make_obs("act_out", **kwargs) - - def _all_observers(self): - # required by QuantModuleBase - return (self.obs_act_in, self.obs_act_out) - - -class TestPTQConfig(unittest.TestCase): - def test_default_dtype_applied(self): - cfg = PTQConfig(default_dtype=DType.uint(8)) - w = DummyWrapper(cfg) - self.assertEqual(w.obs_act_in.dtype, DType.uint(8)) - self.assertEqual(w.obs_act_out.dtype, DType.uint(8)) - - def test_per_observer_dtype_override(self): - cfg = PTQConfig( - default_dtype=DType.uint(8), - overrides={"act_out": {"dtype": DType.uint(4)}}, - ) - w = DummyWrapper(cfg) - 
self.assertEqual(w.obs_act_in.dtype, DType.uint(8)) # default - self.assertEqual(w.obs_act_out.dtype, DType.uint(4)) # override - - def test_observer_override(self): - cfg = PTQConfig( - default_dtype=DType.uint(8), - overrides={ - "act_in": { - "observer": EMAObserver, - "dtype": DType.uint(8), - } - }, - ) - w = DummyWrapper(cfg) - self.assertIsInstance(w.obs_act_in, EMAObserver) - self.assertEqual(w.obs_act_in.dtype, DType.uint(8)) - self.assertIsInstance(w.obs_act_out, MinMaxObserver) # unaffected - - -class TestPTQConfigChild(unittest.TestCase): - def test_child_inherits_default_dtype(self): - parent = PTQConfig(default_dtype=DType.uint(8)) - child = parent.child("gate_proj") - self.assertEqual(child.default_dtype, DType.uint(8)) - self.assertEqual(child.default_dtype, DType.uint(8)) - - def test_child_override_applied(self): - parent = PTQConfig( - default_dtype=DType.uint(8), - overrides={ - "gate_proj": {"act_in": {"dtype": DType.uint(4)}}, - "mul": {"dtype": DType.uint(4)}, - }, - ) - gate_cfg = parent.child("gate_proj") - up_cfg = parent.child("up_proj") # no specific override - - # gate_proj.act_in should pick up uint4 - self.assertEqual(gate_cfg.get_kwargs("act_in")["dtype"], DType.uint(4)) - # top-level override still visible to parent - self.assertEqual(parent.get_kwargs("mul")["dtype"], DType.uint(4)) - - def test_child_is_view_not_copy(self): - parent = PTQConfig(default_dtype=DType.uint(8)) - child = parent.child("dummy") - # mutate child's overrides → parent unaffected - child.overrides["x"] = {"dtype": DType.int(8)} # type: ignore[index] - self.assertNotIn("x", parent.overrides) - - def test_child_inherits_default_qscheme(self): - parent = PTQConfig(default_qscheme=QScheme.PER_CHANNEL_SYMM) - child = parent.child("gate_proj") - self.assertEqual(child.default_qscheme, QScheme.PER_CHANNEL_SYMM) - - -# ---- Dummy observers for testing (just to distinguish classes) ---- -class DummyObserverA(AffineObserverBase): - def _update_stats(self, x): - return 
super()._update_stats(x) - - -class DummyObserverB(AffineObserverBase): - def _update_stats(self, x): - return super()._update_stats(x) - - -class TestObserverAndDTypePrecedence(unittest.TestCase): - """ - Ensure `_make_obs()` applies 3-level precedence to dtype/observer: - - 1) User override in PTQConfig.overrides[name] - 2) Wrapper default passed via `_make_obs(..., dtype=..., observer=...)` - 3) PTQConfig.default_dtype or default_observer - - And other kwargs follow: - user override > wrapper default - """ - - def test_user_override_wins(self): - """ - If user supplies both dtype and observer, they must override - both wrapper defaults and PTQConfig defaults. - """ - qcfg = PTQConfig( - default_dtype=DType.uint(8), - default_observer=MinMaxObserver, # type: ignore[type-abstract] - overrides={ - "act_in": { - "dtype": DType.uint(4), - "observer": DummyObserverA, - "qscheme": QScheme.PER_TENSOR_ASYMM, # user override for another kw - "channel_axis": None, - } - }, - ) - - # Wrapper defaults: dtype=6bit, observer=DummyObserverB, qscheme=PER_CHANNEL - wrapper = DummyWrapper( - qcfg, - dtype=DType.uint(6), - observer=DummyObserverB, - qscheme=QScheme.PER_CHANNEL_ASYMM, - channel_axis=0, - ) - - self.assertIsInstance(wrapper.obs_act_in, DummyObserverA) - self.assertEqual(wrapper.obs_act_in.dtype, DType.uint(4)) - # user override wins for qscheme/channel_axis - self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) - self.assertIsNone(wrapper.obs_act_in.channel_axis) - - def test_wrapper_default_when_no_user_override(self): - """ - If the user supplies nothing for a given name, wrapper defaults must - override PTQConfig defaults. 
- """ - qcfg = PTQConfig( - default_dtype=DType.uint(8), - default_observer=MinMaxObserver, # type: ignore[type-abstract] - overrides={ - # nothing for 'act_out' - }, - ) - - wrapper = DummyWrapper( - qcfg, - dtype=DType.uint(6), - observer=DummyObserverB, - qscheme=QScheme.PER_CHANNEL_ASYMM, - channel_axis=1, - ) - - self.assertIsInstance(wrapper.obs_act_out, DummyObserverB) - self.assertEqual(wrapper.obs_act_out.dtype, DType.uint(6)) - self.assertEqual(wrapper.obs_act_out.qscheme, QScheme.PER_CHANNEL_ASYMM) - self.assertEqual(wrapper.obs_act_out.channel_axis, 1) - - def test_other_kwargs_user_override_precedence(self): - """ - For keys without PTQConfig-level defaults (like qscheme/channel_axis), - user overrides > wrapper defaults. - """ - qcfg = PTQConfig( - default_dtype=DType.uint(8), - default_observer=MinMaxObserver, # type: ignore[type-abstract] - overrides={ - "act_in": { - "qscheme": QScheme.PER_TENSOR_ASYMM, - "channel_axis": None, - } - }, - ) - - # wrapper defaults try to force a per-channel scheme - wrapper = DummyWrapper( - qcfg, - qscheme=QScheme.PER_CHANNEL_ASYMM, - channel_axis=2, - ) - - # user override must win - self.assertEqual(wrapper.obs_act_in.qscheme, QScheme.PER_TENSOR_ASYMM) - self.assertIsNone(wrapper.obs_act_in.channel_axis) - - def test_PTQConfig_get_kwargs_does_not_inject_dtype(self): - """ - Ensure PTQConfig.get_kwargs() itself doesn't inject dtype anymore. - It should return exactly the user override dict. - """ - qcfg = PTQConfig( - default_dtype=DType.uint(8), - overrides={"bar": {"qscheme": QScheme.PER_TENSOR_ASYMM}}, - ) - kw = qcfg.get_kwargs("bar") - self.assertIn("qscheme", kw) - self.assertNotIn("dtype", kw) - - def test_config_default_when_neither_user_nor_wrapper(self): - """ - If neither user nor wrapper provides dtype/qscheme, fallback to PTQConfig defaults. 
- """ - qcfg = PTQConfig( - default_dtype=DType.uint(8), - default_qscheme=QScheme.PER_TENSOR_ASYMM, - default_observer=MinMaxObserver, # type: ignore[type-abstract] - overrides={}, - ) - - wrapper = DummyWrapper(qcfg) - - self.assertIsInstance(wrapper.obs_act_in, MinMaxObserver) - self.assertEqual(wrapper.obs_act_in.dtype, DType.uint(8)) - self.assertEqual(wrapper.obs_act_in.qscheme, qcfg.default_qscheme) - self.assertIsNone(wrapper.obs_act_in.channel_axis) - - -class TestPTQConfigQScheme(unittest.TestCase): - def test_default_qscheme_applied(self): - cfg = PTQConfig( - default_dtype=DType.uint(8), - default_qscheme=QScheme.PER_CHANNEL_SYMM, - ) - w = DummyWrapper(cfg) - self.assertEqual(w.obs_act_in.qscheme, QScheme.PER_CHANNEL_SYMM) - self.assertEqual(w.obs_act_out.qscheme, QScheme.PER_CHANNEL_SYMM) - - def test_per_observer_qscheme_override(self): - cfg = PTQConfig( - default_dtype=DType.uint(8), - default_qscheme=QScheme.PER_CHANNEL_ASYMM, - overrides={ - "act_out": {"qscheme": QScheme.PER_TENSOR_SYMM}, - }, - ) - w = DummyWrapper(cfg) - self.assertEqual(w.obs_act_in.qscheme, QScheme.PER_CHANNEL_ASYMM) - self.assertEqual(w.obs_act_out.qscheme, QScheme.PER_TENSOR_SYMM) diff --git a/test/quantization/wrapq/wrappers/test_quant_module_base.py b/test/quantization/wrapq/wrappers/test_quant_module_base.py index 12454a27..6e4f0950 100644 --- a/test/quantization/wrapq/wrappers/test_quant_module_base.py +++ b/test/quantization/wrapq/wrappers/test_quant_module_base.py @@ -184,7 +184,9 @@ def _all_observers(self): class TestQuantModuleQScheme(unittest.TestCase): def test_config_default_qscheme(self): - cfg = PTQConfig(default_qscheme=QScheme.PER_CHANNEL_SYMM) + cfg = PTQConfig( + default_dtype=DType.int(16), default_qscheme=QScheme.PER_CHANNEL_SYMM + ) qm = DummyQM(cfg) self.assertEqual(qm.obs.qscheme, QScheme.PER_CHANNEL_SYMM) @@ -199,6 +201,7 @@ def test_user_override_qscheme_wins(self): default_qscheme=QScheme.PER_TENSOR_ASYMM, overrides={ "act": { + "dtype": 
def _auto_qscheme_for(dtype: DType, obs_name: Optional[str] = None) -> QScheme:
    """
    Pick the qscheme the builders pair with `dtype` when none is given.

    Mapping applied:
    - signed dtype                      -> `QScheme.PER_TENSOR_SYMM`
    - unsigned dtype, observer "weight" -> `QScheme.PER_CHANNEL_ASYMM`
    - unsigned dtype, any other name    -> `QScheme.PER_TENSOR_ASYMM`

    Parameters
    ----------
    dtype : DType
        Quantized dtype; only its `signed` attribute is consulted.
    obs_name : Optional[str]
        Observer name. Only the exact value "weight" changes the result,
        and only for unsigned dtypes.

    Returns
    -------
    QScheme
        The default scheme for the given dtype/observer combination.
    """
    # Signed dtypes share one scheme regardless of the observer name.
    if dtype.signed:
        return QScheme.PER_TENSOR_SYMM

    targets_weight = obs_name == "weight"
    return (
        QScheme.PER_CHANNEL_ASYMM if targets_weight else QScheme.PER_TENSOR_ASYMM
    )
def _dtype_is_unsigned(dtype: DType) -> bool:
    """
    Return True when the dtype is unsigned.

    The result is the negation of `dtype.signed`.
    """
    return not dtype.signed


def _auto_qscheme_for(dtype: DType, obs_name: Optional[str] = None) -> QScheme:
    """
    Choose a default qscheme from the effective dtype and observer name.

    Default policy:
    - signed dtype    -> symmetric per-tensor
    - unsigned dtype  -> asymmetric per-tensor
    - unsigned weight -> asymmetric per-channel

    Parameters
    ----------
    dtype : DType
        Effective dtype of the observer.
    obs_name : Optional[str]
        Observer name. Only the exact value "weight" is special-cased, and
        only when `dtype` is unsigned.

    Returns
    -------
    QScheme
        The default scheme for the given dtype/observer combination.
    """
    if _dtype_is_unsigned(dtype):
        if obs_name == "weight":
            return QScheme.PER_CHANNEL_ASYMM
        return QScheme.PER_TENSOR_ASYMM
    return QScheme.PER_TENSOR_SYMM


def _resolve_qscheme(
    *,
    dtype: DType,
    qscheme: Optional[QScheme],
    context: str,
    obs_name: Optional[str] = None,
) -> QScheme:
    """
    Resolve and validate the qscheme to use together with `dtype`.

    Resolution policy:
    1. If `qscheme` is None, infer it from `dtype` and `obs_name`.
    2. Otherwise keep the caller's `qscheme` unchanged.
    3. Reject an unsigned dtype paired with a symmetric qscheme.

    Parameters
    ----------
    dtype : DType
        Effective dtype at the location being resolved.
    qscheme : Optional[QScheme]
        Explicit scheme, or None to request automatic inference.
    context : str
        Dotted location string; used only in the error message.
    obs_name : Optional[str]
        Observer name forwarded to the automatic inference.

    Returns
    -------
    QScheme
        The explicit scheme when given, otherwise the inferred one.

    Raises
    ------
    ValueError
        If `dtype` is unsigned and the resolved scheme is symmetric.
    """
    # Test for absence explicitly: `qscheme or ...` would also discard a
    # provided scheme whose truth value happens to be False.
    if qscheme is None:
        resolved_qscheme = _auto_qscheme_for(dtype, obs_name)
    else:
        resolved_qscheme = qscheme

    if _dtype_is_unsigned(dtype) and resolved_qscheme.is_symmetric():
        raise ValueError(
            f"Invalid quantization config at {context}: unsigned dtype "
            f"{dtype!r} cannot be paired with symmetric qscheme "
            f"{resolved_qscheme!r}."
        )

    return resolved_qscheme


def _normalize_overrides(
    mapping: Mapping[str, Any],
    *,
    inherited_dtype: DType,
    inherited_qscheme: QScheme,
    context: str,
    current_name: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Recursively normalize and validate nested override mappings.

    For each node:
    - `dtype` present: the node's `qscheme` is set to the explicit value, or
      to a scheme inferred from that dtype when it is missing or None.
    - only `qscheme` present: it is validated against the inherited dtype; a
      None value is replaced by the inferred scheme.
    - neither present: the inherited dtype/qscheme pair is validated and the
      node is left without those keys.

    The effective dtype/qscheme of a node is what its nested mappings
    inherit. The key under which a node is stored is passed down as
    `current_name` so that special observer names such as `weight` can
    receive a more suitable automatic default.

    Parameters
    ----------
    mapping : Mapping[str, Any]
        Override node to normalize. It is copied, not mutated.
    inherited_dtype : DType
        Dtype in effect in the enclosing scope.
    inherited_qscheme : QScheme
        Qscheme in effect in the enclosing scope.
    context : str
        Dotted location string used in error messages.
    current_name : Optional[str]
        Key under which `mapping` is stored in its parent, if any.

    Returns
    -------
    Dict[str, Any]
        A new dict. Nested mappings are replaced by normalized copies;
        all other values are carried over as-is.

    Raises
    ------
    ValueError
        If any node ends up with an unsigned dtype and a symmetric qscheme.
    """
    normalized: Dict[str, Any] = dict(mapping)

    local_dtype = normalized.get("dtype", inherited_dtype)

    if "dtype" in normalized:
        # A local dtype never silently inherits the outer qscheme: use the
        # explicit one, or infer a scheme that suits the new dtype.
        local_qscheme = _resolve_qscheme(
            dtype=local_dtype,
            qscheme=normalized.get("qscheme"),
            context=context,
            obs_name=current_name,
        )
        normalized["qscheme"] = local_qscheme
    elif "qscheme" in normalized:
        local_qscheme = _resolve_qscheme(
            dtype=local_dtype,
            qscheme=normalized["qscheme"],
            context=context,
            obs_name=current_name,
        )
        # Store the resolved value so an explicit `"qscheme": None` is
        # replaced here too, matching the dtype branch above. For a
        # non-None scheme this rewrites the same value.
        normalized["qscheme"] = local_qscheme
    else:
        local_qscheme = inherited_qscheme
        # Validation only: the node adds nothing, but the inherited pair is
        # still checked at this location.
        _resolve_qscheme(
            dtype=local_dtype,
            qscheme=local_qscheme,
            context=context,
            obs_name=current_name,
        )

    for key, value in list(normalized.items()):
        if isinstance(value, Mapping):
            normalized[key] = _normalize_overrides(
                value,
                inherited_dtype=local_dtype,
                inherited_qscheme=local_qscheme,
                context=f"{context}.{key}",
                current_name=key,
            )

    return normalized
wrapper_variant : str Execution specialization used when resolving quantization wrappers. @@ -114,17 +233,111 @@ class PTQConfig(BaseConfig): default_dtype: DType = DType.uint(8) default_observer: Type[ObserverBase] = MinMaxObserver # type: ignore[type-abstract] - default_qscheme: QScheme = QScheme.PER_TENSOR_ASYMM + default_qscheme: Optional[QScheme] = None wrapper_variant: WrapperVariant = "common" overrides: Mapping[str, Mapping[str, Any]] = field(default_factory=dict) model_args: Mapping[str, Any] = field(default_factory=dict) # If True, any module that cannot be wrapped will raise. strict_wrap: bool = True + def __post_init__(self) -> None: + """ + Resolve automatic qscheme defaults and validate nested overrides. + """ + self.default_qscheme = _resolve_qscheme( + dtype=self.default_dtype, + qscheme=self.default_qscheme, + context="PTQConfig.default_qscheme", + ) + self.normalize_overrides() + @property def name(self) -> str: return "ptq" + def normalize_overrides(self) -> None: + """ + Normalize and validate the entire override tree in-place. + + This method is useful when callers directly mutate `self.overrides` + after construction and want to retroactively apply automatic qscheme + inference and compatibility checks. + """ + assert self.default_qscheme is not None + self.overrides = _normalize_overrides( + self.overrides, + inherited_dtype=self.default_dtype, + inherited_qscheme=self.default_qscheme, + context="PTQConfig.overrides", + ) + + def set_override( + self, + path: Iterable[str], + value: Mapping[str, Any], + ) -> None: + """ + Set a nested override and normalize only the affected subtree. + + Parameters + ---------- + path : Iterable[str] + Hierarchical path inside `self.overrides`. + Example: `("model", "layers", "0", "self_attn", "o_proj", "weight")` + value : Mapping[str, Any] + Override payload to assign at the target path. 
+ + Notes + ----- + The inserted subtree is normalized immediately, so callers may provide + only `dtype` and rely on automatic qscheme inference. + """ + keys = tuple(path) + if not keys: + raise ValueError("Override path must not be empty.") + + root: MutableMapping[str, Any] = dict(self.overrides) + current: MutableMapping[str, Any] = root + parent_dtype = self.default_dtype + parent_qscheme = self.default_qscheme + context = "PTQConfig.overrides" + + for key in keys[:-1]: + context = f"{context}.{key}" + next_value = current.get(key) + if isinstance(next_value, Mapping): + child = dict(next_value) + elif next_value is None: + child = {} + else: + raise ValueError( + f"Cannot create nested override under non-mapping node at {context}." + ) + + current[key] = child + current = child + + local_dtype = current.get("dtype", parent_dtype) + parent_qscheme = _resolve_qscheme( + dtype=local_dtype, + qscheme=current.get("qscheme", parent_qscheme), + context=context, + obs_name=key, + ) + parent_dtype = local_dtype + + assert parent_qscheme is not None + leaf_key = keys[-1] + leaf_context = f"{context}.{leaf_key}" + current[leaf_key] = _normalize_overrides( + deepcopy(value), + inherited_dtype=parent_dtype, + inherited_qscheme=parent_qscheme, + context=leaf_context, + current_name=leaf_key, + ) + self.overrides = root + def get_kwargs(self, obs_name: str) -> Dict[str, Any]: """ Return user-specified kwargs for *obs_name* inside **this** wrapper.