Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions test/quantization/wrapq/wrappers/llama/test_quant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,68 @@ def test_forward_diff(self):
self.assertGreater(diff, 0.0)
self.assertLess(diff, 0.4)
self.assertEqual(fp_out.shape, q_out.shape)


@unittest.skipUnless(has_transformers_for("llama"), skip_msg)
class TestQuantLlamaModelWithCache(unittest.TestCase):
    """Verify that a prefill-wrapped QuantLlamaModel returns a per-layer KV cache.

    Builds a tiny LlamaModel with ``use_cache=True`` and ``return_dict=False``,
    then checks that a forward pass in calibration mode yields
    ``(last_hidden_states, past_key_values)`` where ``past_key_values`` holds one
    ``(key, value)`` tuple per decoder layer, each covering the full sequence.
    """

    # Shared fixtures, populated once in setUpClass.
    seq_len: int  # input sequence length fed to the model
    vocab_size: int  # exclusive upper bound for random token ids
    hid_layers: int  # number of decoder layers in the tiny config
    fp_model: torch.nn.Module  # floating-point reference model to wrap

    @classmethod
    def setUpClass(cls):
        """Create a deterministic, minimal LlamaModel configured to emit KV caches."""
        torch.manual_seed(0)  # deterministic weights across runs

        from transformers.models.llama.configuration_llama import LlamaConfig
        from transformers.models.llama.modeling_llama import LlamaModel

        cls.seq_len = 16
        cls.vocab_size = 10000
        cls.hid_layers = 3
        cfg = LlamaConfig(
            hidden_size=8,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=4,
            attention_bias=False,
            attention_dropout=0.0,
            attn_implementation="eager",
            num_hidden_layers=cls.hid_layers,
            max_position_embeddings=cls.seq_len,
            use_cache=True,  # forward must return past_key_values
            return_dict=False,  # forward returns a plain tuple, indexed below
        )
        cls.fp_model = LlamaModel(cfg)

    def test_model_output(self):
        """Prefill forward returns (hidden_states, per-layer (key, value) cache)."""
        qmodel = QuantLlamaModel(
            self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill")
        )
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        x = torch.randint(0, self.vocab_size, (1, self.seq_len))
        output = qmodel(x)

        # Expect exactly: last_hidden_states + past_key_values.
        # Use self.assert* (not bare assert): survives python -O and
        # matches the assertion style used above in this test.
        self.assertEqual(len(output), 2)
        past_key_values = output[1]
        self.assertEqual(len(past_key_values), self.hid_layers)
        for index in range(self.hid_layers):
            past_key_value = past_key_values[index]
            self.assertIsInstance(past_key_value, tuple)

            past_key, past_value = past_key_value
            # Cached keys/values must span the full prefilled sequence.
            self.assertEqual(past_key.shape[-2], self.seq_len)
            self.assertEqual(past_value.shape[-2], self.seq_len)
26 changes: 22 additions & 4 deletions tico/quantization/wrapq/wrappers/llama/quant_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,15 @@ def forward(
inputs_embeds = self.embed_tokens(input_ids)

if use_cache and past_key_values is None:
past_key_values = DynamicCache()
past_key_values = []

if cache_position is None:
past_seen_tokens = (
past_key_values.get_seq_length() if past_key_values is not None else 0
0
if (past_key_values is None or len(past_key_values) == 0)
else past_key_values[0][0].shape[-2]
)

cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_embeds.shape[1],
Expand All @@ -217,15 +220,21 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None

for decoder_layer in self.layers[: self.config.num_hidden_layers]:
for idx, decoder_layer in enumerate(
self.layers[: self.config.num_hidden_layers]
):
if output_hidden_states:
all_hidden_states += (hidden_states,) # type: ignore[operator]

layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_mask,
position_ids=position_ids,
past_key_value=past_key_values,
past_key_value=(
past_key_values[idx]
if past_key_values is not None and len(past_key_values) > idx
else None
),
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
Expand All @@ -235,6 +244,15 @@ def forward(

if decoder_layer.wrapped.return_type == "tuple":
hidden_states = layer_outputs[0]
elif use_cache:
hidden_states = layer_outputs[0]
assert isinstance(layer_outputs[1], tuple)
if len(past_key_values) >= idx: # type: ignore[arg-type]
# prefill mode
past_key_values += (layer_outputs[1],) # type: ignore[operator]
else:
# decode mode
past_key_values[idx] = (layer_outputs[1],) # type: ignore[index]
else:
hidden_states = layer_outputs

Expand Down
Loading