Skip to content

Commit f1e56d0

Browse files
committed
[quantization] Output kv-tuples
This PR outputs kv-tuples in case `use_cache` was set. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 3e4b06f commit f1e56d0

8 files changed

Lines changed: 555 additions & 44 deletions

File tree

test/quantization/wrapq/wrappers/llama/test_quant_model.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,3 +109,68 @@ def test_forward_diff(self):
109109
self.assertGreater(diff, 0.0)
110110
self.assertLess(diff, 0.4)
111111
self.assertEqual(fp_out.shape, q_out.shape)
112+
113+
114+
@unittest.skipUnless(has_transformers_for("llama"), skip_msg)
class TestQuantLlamaModelWithCache(unittest.TestCase):
    """Check that QuantLlamaModel forwards kv-cache tuples when `use_cache=True`.

    Builds a tiny LlamaModel (3 layers, hidden_size 8) with `use_cache=True`
    and `return_dict=False`, wraps it in QuantLlamaModel, and verifies that a
    forward pass in calibration mode returns `(last_hidden_state,
    past_key_values)` with one `(key, value)` tuple per hidden layer.
    """

    seq_len: int  # prompt length fed to the model (also max_position_embeddings)
    vocab_size: int  # vocabulary size for both the config and the random input ids
    hid_layers: int  # number of decoder layers, hence expected kv-tuple count
    fp_model: torch.nn.Module  # the floating-point reference model

    @classmethod
    def setUpClass(cls):
        # Fixed seed so model init and input ids are reproducible across runs.
        torch.manual_seed(0)

        from transformers.models.llama.configuration_llama import LlamaConfig
        from transformers.models.llama.modeling_llama import LlamaModel

        cls.seq_len = 16
        cls.vocab_size = 10000
        cls.hid_layers = 3
        cfg = LlamaConfig(
            # Pass vocab_size explicitly: it bounds the embedding table and the
            # randint inputs below. (Previously it was declared but unused, so
            # the model silently used the default vocabulary size.)
            vocab_size=cls.vocab_size,
            hidden_size=8,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=4,
            attention_bias=False,
            attention_dropout=0.0,
            attn_implementation="eager",
            num_hidden_layers=cls.hid_layers,
            max_position_embeddings=cls.seq_len,
            use_cache=True,  # the behavior under test: kv-tuples in the output
            return_dict=False,  # outputs come back as a plain tuple
        )
        cls.fp_model = LlamaModel(cfg)

    def test_model_output(self):
        """Forward pass returns hidden states plus per-layer (key, value) tuples."""
        qmodel = QuantLlamaModel(
            self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill")
        )
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        # Single-batch prompt of seq_len token ids.
        x = torch.randint(
            0,
            self.vocab_size,
            (
                1,
                self.seq_len,
            ),
        )
        output = qmodel(x)

        # unittest assertions (not bare `assert`) so the checks survive
        # `python -O` and match the style of the other tests in this file.
        self.assertEqual(len(output), 2)  # last_hidden_state + past_key_values
        past_key_values = output[1]
        self.assertEqual(len(past_key_values), self.hid_layers)
        for index in range(self.hid_layers):
            past_key_value = past_key_values[index]
            self.assertIsInstance(past_key_value, tuple)

            # Both cache tensors cover the full prompt along the sequence axis.
            past_key = past_key_value[0]
            self.assertEqual(past_key.shape[-2], self.seq_len)

            past_value = past_key_value[1]
            self.assertEqual(past_value.shape[-2], self.seq_len)

0 commit comments

Comments
 (0)