From 06d8e796dbaaea136036f98eb8495f4b2db895c9 Mon Sep 17 00:00:00 2001 From: "s.malakhov" Date: Mon, 23 Mar 2026 17:38:03 +0300 Subject: [PATCH] [quantization] Process `past_key_values` This PR processes `past_key_values` in QuantLlamaModel if `use_cache` was set. TICO-DCO-1.0-Signed-off-by: s.malakhov --- .../wrapq/wrappers/llama/test_quant_model.py | 65 +++++++++++++++++++ .../wrapq/wrappers/llama/quant_model.py | 26 ++++++-- 2 files changed, 87 insertions(+), 4 deletions(-) diff --git a/test/quantization/wrapq/wrappers/llama/test_quant_model.py b/test/quantization/wrapq/wrappers/llama/test_quant_model.py index d7645ab3..29a3d28b 100644 --- a/test/quantization/wrapq/wrappers/llama/test_quant_model.py +++ b/test/quantization/wrapq/wrappers/llama/test_quant_model.py @@ -109,3 +109,68 @@ def test_forward_diff(self): self.assertGreater(diff, 0.0) self.assertLess(diff, 0.4) self.assertEqual(fp_out.shape, q_out.shape) + + +@unittest.skipUnless(has_transformers_for("llama"), skip_msg) +class TestQuantLlamaModelWithCache(unittest.TestCase): + seq_len: int + vocab_size: int + hid_layers: int + fp_model: torch.nn.Module + + @classmethod + def setUpClass(cls): + torch.manual_seed(0) + + from transformers.models.llama.configuration_llama import LlamaConfig + from transformers.models.llama.modeling_llama import LlamaModel + + cls.seq_len = 16 + cls.vocab_size = 10000 + cls.hid_layers = 3 + cfg = LlamaConfig( + hidden_size=8, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + num_hidden_layers=cls.hid_layers, + max_position_embeddings=cls.seq_len, + use_cache=True, + return_dict=False, + ) + cls.fp_model = LlamaModel(cfg) + + def test_model_output(self): + qmodel = QuantLlamaModel( + self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill") + ) + self.assertIs(qmodel._mode, Mode.NO_QUANT) + + qmodel.enable_calibration() + self.assertIs(qmodel._mode, Mode.CALIB) + + x = torch.randint( 
+ 0, + self.vocab_size, + ( + 1, + self.seq_len, + ), + ) + output = qmodel(x) + + assert len(output) == 2 # last_hidden_states + past_key_values + past_key_values = output[1] + assert len(past_key_values) == self.hid_layers + for index in range(self.hid_layers): + past_key_value = past_key_values[index] + assert isinstance(past_key_value, tuple) + + past_key = past_key_value[0] + assert past_key.shape[-2] == self.seq_len + + past_value = past_key_value[1] + assert past_value.shape[-2] == self.seq_len diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model.py b/tico/quantization/wrapq/wrappers/llama/quant_model.py index 74325d28..5049f1e2 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model.py @@ -185,12 +185,15 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = DynamicCache() + past_key_values = [] if cache_position is None: past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 + 0 + if (past_key_values is None or len(past_key_values) == 0) + else past_key_values[0][0].shape[-2] ) + cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -217,7 +220,9 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for idx, decoder_layer in enumerate( + self.layers[: self.config.num_hidden_layers] + ): if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] @@ -225,7 +230,11 @@ def forward( hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=( + past_key_values[idx] + if past_key_values is not None and len(past_key_values) > idx + else None + ), output_attentions=output_attentions, use_cache=use_cache, 
cache_position=cache_position, @@ -235,6 +244,15 @@ def forward( if decoder_layer.wrapped.return_type == "tuple": hidden_states = layer_outputs[0] + elif use_cache: + hidden_states = layer_outputs[0] + assert isinstance(layer_outputs[1], tuple) + if len(past_key_values) <= idx: # type: ignore[arg-type] + # prefill mode + past_key_values += (layer_outputs[1],) # type: ignore[operator] + else: + # decode mode + past_key_values[idx] = layer_outputs[1] # type: ignore[index] else: hidden_states = layer_outputs