Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions test/quantization/wrapq/wrappers/llama/test_quant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,68 @@ def test_forward_diff(self):
self.assertGreater(diff, 0.0)
self.assertLess(diff, 0.4)
self.assertEqual(fp_out.shape, q_out.shape)


@unittest.skipUnless(has_transformers_for("llama"), skip_msg)
class TestQuantLlamaModelWithCache(unittest.TestCase):
    """Verify that QuantLlamaModel in prefill mode returns a KV cache.

    A tiny LlamaModel is wrapped, calibrated, and run on a single prompt;
    the output must contain hidden states plus one (key, value) tuple per
    decoder layer whose sequence dimension spans the full prompt.
    """

    seq_len: int                # prompt length fed to the model
    vocab_size: int             # token-id sampling range and model vocab size
    hid_layers: int             # number of decoder layers in the test model
    fp_model: torch.nn.Module   # full-precision reference model

    @classmethod
    def setUpClass(cls):
        """Build one tiny LlamaModel for all tests; seeded for determinism."""
        torch.manual_seed(0)

        from transformers.models.llama.configuration_llama import LlamaConfig
        from transformers.models.llama.modeling_llama import LlamaModel

        cls.seq_len = 16
        cls.vocab_size = 10000
        cls.hid_layers = 3
        cfg = LlamaConfig(
            # Fix: previously left at the HF default, so the model's actual
            # vocab size diverged from the range used to sample input ids.
            vocab_size=cls.vocab_size,
            hidden_size=8,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=4,
            attention_bias=False,
            attention_dropout=0.0,
            attn_implementation="eager",
            num_hidden_layers=cls.hid_layers,
            max_position_embeddings=cls.seq_len,
            use_cache=True,      # we need past_key_values in the output
            return_dict=False,   # outputs come back as a plain tuple
        )
        cls.fp_model = LlamaModel(cfg)

    def test_model_output(self):
        """Calibration forward pass yields hidden states + full-length KV cache."""
        qmodel = QuantLlamaModel(
            self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill")
        )
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        x = torch.randint(0, self.vocab_size, (1, self.seq_len))
        output = qmodel(x)

        # Use unittest assertion methods instead of bare `assert` statements,
        # which are stripped under `python -O` and give poorer failure output.
        self.assertEqual(len(output), 2)  # last_hidden_state + past_key_values
        past_key_values = output[1]
        self.assertEqual(len(past_key_values), self.hid_layers)
        for index in range(self.hid_layers):
            past_key_value = past_key_values[index]
            self.assertIsInstance(past_key_value, tuple)

            past_key, past_value = past_key_value
            # Cache layout puts the sequence length on dim -2; both key and
            # value must cover the entire prompt after prefill.
            self.assertEqual(past_key.shape[-2], self.seq_len)
            self.assertEqual(past_value.shape[-2], self.seq_len)
Loading
Loading