diff --git a/test/quantization/wrapq/wrappers/llama/test_quant_model.py b/test/quantization/wrapq/wrappers/llama/test_quant_model.py index d7645ab3..29a3d28b 100644 --- a/test/quantization/wrapq/wrappers/llama/test_quant_model.py +++ b/test/quantization/wrapq/wrappers/llama/test_quant_model.py @@ -109,3 +109,68 @@ def test_forward_diff(self): self.assertGreater(diff, 0.0) self.assertLess(diff, 0.4) self.assertEqual(fp_out.shape, q_out.shape) + + +@unittest.skipUnless(has_transformers_for("llama"), skip_msg) +class TestQuantLlamaModelWithCache(unittest.TestCase): + seq_len: int + vocab_size: int + hid_layers: int + fp_model: torch.nn.Module + + @classmethod + def setUpClass(cls): + torch.manual_seed(0) + + from transformers.models.llama.configuration_llama import LlamaConfig + from transformers.models.llama.modeling_llama import LlamaModel + + cls.seq_len = 16 + cls.vocab_size = 10000 + cls.hid_layers = 3 + cfg = LlamaConfig( + hidden_size=8, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=4, + attention_bias=False, + attention_dropout=0.0, + attn_implementation="eager", + num_hidden_layers=cls.hid_layers, + max_position_embeddings=cls.seq_len, + use_cache=True, + return_dict=False, + ) + cls.fp_model = LlamaModel(cfg) + + def test_model_output(self): + qmodel = QuantLlamaModel( + self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill") + ) + self.assertIs(qmodel._mode, Mode.NO_QUANT) + + qmodel.enable_calibration() + self.assertIs(qmodel._mode, Mode.CALIB) + + x = torch.randint( + 0, + self.vocab_size, + ( + 1, + self.seq_len, + ), + ) + output = qmodel(x) + + assert len(output) == 2 # last_hidden_states + past_key_values + past_key_values = output[1] + assert len(past_key_values) == self.hid_layers + for index in range(self.hid_layers): + past_key_value = past_key_values[index] + assert isinstance(past_key_value, tuple) + + past_key = past_key_value[0] + assert past_key.shape[-2] == self.seq_len + + past_value = past_key_value[1] + assert 
past_value.shape[-2] == self.seq_len diff --git a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py index 289b9014..684c30dc 100644 --- a/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py +++ b/tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py @@ -26,6 +26,7 @@ # ============================================================================= import argparse +import copy import pathlib import random @@ -34,6 +35,7 @@ from typing import Any, List, Optional, Tuple, Union +import numpy as np import torch import tqdm from datasets import load_dataset @@ -48,6 +50,11 @@ from tico.quantization.config.gptq import GPTQConfig from tico.quantization.evaluation.script.llm_tasks_eval import evaluate_llm_on_tasks from tico.quantization.wrapq.dtypes import DType +from tico.quantization.wrapq.examples.static_llama_layer_runtime import ( + _build_decode_attention_mask, + _build_rope_templates_from_config, + _slice_rope, +) from tico.quantization.wrapq.observers.affine_base import AffineObserverBase from tico.quantization.wrapq.qscheme import QScheme from tico.quantization.wrapq.utils.metrics import perplexity @@ -98,14 +105,348 @@ def inject_gptq_qparams( obs.load_qparams(quantizer.scale, quantizer.zero, lock=True) -# ------------------------------------------------------------------------- -# Save model in circle format -# ------------------------------------------------------------------------- -def save_model_to(q_m, calib_inputs, save_circle_to_folder): +def pad_input(input, pad_token, max_seq_len): + """Pad a tensor to a maximum sequence length using the specified pad token.""" + pads = torch.full( + (input.shape[0], max_seq_len - input.shape[1]), + fill_value=pad_token, + device=input.device, + ) + return torch.cat((input, pads), dim=1) + + +def get_decode_input( + prefill_model, + calib_input, + pad_token_id, + ropes, + max_seq_len, + device, + 
dtype=torch.float32, +): + """Prepare inputs for the decode model using prefill KV‑cache and rotary embeddings.""" + prefill_input = calib_input[..., :-1] + prefill_seq_len = calib_input.shape[-1] + prefill_input = pad_input( + prefill_input, pad_token_id, max_seq_len - 1 + ) # pad input to max_seq_len + # run prefill model to get kv-cache + outputs = prefill_model(prefill_input.to(device), use_cache=True) + + # fill inputs for decode model + next_token = calib_input[..., -1:] + attention_mask = _build_decode_attention_mask( + batch_size=1, + past_len=prefill_seq_len, + max_seq=max_seq_len, + device=device, + dtype=dtype, + ) + + rope_cos, rope_sin = ropes + position_embeddings = _slice_rope( + rope_cos, + rope_sin, + position=prefill_seq_len, + batch_size=1, + device=device, + dtype=dtype, + ) + + # fill in input + inputs = {} + inputs["input_ids"] = torch.tensor([[next_token]]) + inputs["attention_mask"] = attention_mask + inputs["position_embeddings"] = position_embeddings + inputs["past_key_values"] = outputs.past_key_values + return inputs + + +def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"): + """Compute perplexity of a model on a dataset.""" + if hasattr(model, "device") and model.device.type != device.type: + if hasattr(model, "to"): + model.to(device) + nlls = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + batch = batch.to(device) + output = model( + batch.to(device), + ) + else: + raise RuntimeError("Unknown input in ppl_eval_on_dataset") + + if hasattr(output, "logits"): + lm_logits = output.logits + elif len(output) > 1: + lm_logits = torch.tensor(output[0]) + else: + lm_logits = torch.tensor(output) + + if torch.isfinite(lm_logits).all(): + shift_logits = lm_logits[:, :-1, :].contiguous() + if isinstance(batch, torch.Tensor): + shift_labels = batch[:, 1:].contiguous() + else: + assert isinstance(batch, tuple) + shift_labels = batch[0][:, 1:].contiguous() + loss_fct = 
torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + shift_logits.reshape(-1, shift_logits.size(-1)), + shift_labels.view(-1), + ) + nlls.append(loss) + del shift_logits, shift_labels + shift_logits = shift_labels = None # type: ignore[assignment] + + del batch, lm_logits, output + lm_logits = output = batch = None # noqa: F841 + torch.cuda.empty_cache() + + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + + +def evaluate_ppl_of_ref_prefill_model_on_dataset( + prefill_model, ref_model, dataset, device: str = "cuda" +): + """Compare prefilling model against reference model perplexity.""" + + for model in (prefill_model, ref_model): + if hasattr(model, "device") and model.device.type != device.type: + if hasattr(model, "to"): + model.to(device) + nlls = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + prefill_output = prefill_model( + batch[..., :-1].to(device), use_cache=True + ) + next_token = batch[..., -1:].cpu() # + ref_output = ref_model( + batch.to(device), + ) + import transformers.cache_utils as cache_utils + + past_key_values = prefill_output.past_key_values + + output = ref_model(batch[..., :-1].to(device), use_cache=True) + ref_past_key_values = ( + output.past_key_values + ) # cache_utils.DynamicCache(config=ref_model.config) + key_eps = torch.tensor(-1, device=ref_model.device) + val_eps = torch.tensor(-1, device=ref_model.device) + for idx, past_key_value in enumerate(past_key_values): + key_eps = torch.max( + key_eps, + torch.max( + torch.abs( + ref_past_key_values.layers[idx].keys - past_key_value[0] + ) + ), + ) + val_eps = torch.max( + val_eps, + torch.max( + torch.abs( + ref_past_key_values.layers[idx].values + - past_key_value[1] + ) + ), + ) + ref_past_key_values.layers[idx].keys = past_key_value[0].clone() + ref_past_key_values.layers[idx].values = past_key_value[1].clone() + + last_token = torch.tensor( + [[torch.argmax(ref_output.logits[:, -1, :], 
dim=-1).cpu()]] + ) + + output = ref_model( + next_token.to(device), + past_key_values=ref_past_key_values, + use_cache=True, + ) + + lm_logits = output.logits + + if torch.isfinite(lm_logits).all(): + labels = last_token[0].to(device) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + lm_logits.reshape(-1, lm_logits.size(-1)), + labels.view(-1), + ) + nlls.append(loss) + + torch.cuda.empty_cache() + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + + +def evaluate_ppl_of_prefill_ref_model_on_dataset( + prefill_model, ref_model, dataset, device: str = "cuda" +): + """Evaluate prefilling model against a reference model.""" + + for model in (prefill_model, ref_model): + if hasattr(model, "device") and model.device.type != device.type: + if hasattr(model, "to"): + model.to(device) + nlls = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + + prefill_output = ref_model(batch[..., :-1].to(device), use_cache=True) + next_token = batch[..., -1:].cpu() # + ref_output = ref_model( + batch.to(device), + ) + past_key_values = [] + ref_past_key_values = prefill_output.past_key_values + for layer in ref_past_key_values.layers: + keys = layer.keys.clone() + values = layer.values.clone() + past_key_values.append((keys, values)) + + last_token = torch.tensor( + [[torch.argmax(ref_output.logits[:, -1, :], dim=-1).cpu()]] + ) + + # output_1 = prefill_model( + # next_token.to(device), + # past_key_values=past_key_values, + # use_cache=True, + # ) + output = ref_model( + next_token.to(device), + past_key_values=prefill_output.past_key_values, + use_cache=True, + ) + # predicted_token = torch.tensor( + # [[torch.argmax(output.logits[:, -1, :], dim=-1).cpu()]] + # ) + + lm_logits = output.logits + + if torch.isfinite(lm_logits).all(): + labels = last_token[0].to(device) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + lm_logits.reshape(-1, lm_logits.size(-1)), + 
labels.view(-1), + ) + nlls.append(loss) + + torch.cuda.empty_cache() + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + + +def evaluate_ppl_of_prefill_decode_model_on_dataset( + prefill_model, + decode_model, + dataset, + pad_token_id, + max_seq_len, + seed=0, + device: str = "cuda", +): + """Compute perplexity for the prefill-decode logic.""" + + config = ( + prefill_model.config + if hasattr(prefill_model, "config") + else prefill_model.wrapped.config + ) + rope_cos, rope_sin = _build_rope_templates_from_config( + config, max_seq=max_seq_len, device=device, dtype=torch.float32 + ) + + for model in (prefill_model, decode_model): + if hasattr(model, "device") and model.device.type != device.type: + if hasattr(model, "to"): + model.to(device) + + torch.manual_seed(seed) + nlls = [] + with torch.no_grad(): + for batch in tqdm.tqdm(dataset): + if isinstance(batch, torch.Tensor): + + prefill_seq_len = ( + torch.randint(3, max_seq_len - 1, (1,)).cpu().item() + ) # max_seq_len - 1# cropped input length + prefill_input = batch[..., :prefill_seq_len] # cropped input + ref_output = prefill_model(prefill_input.to(device)) + + last_token = torch.tensor( + [[torch.argmax(ref_output.logits[:, -1, :], dim=-1).cpu()]] + ) + if hasattr(prefill_model, "wrapped"): + inputs = get_decode_input( + prefill_model, + prefill_input, + pad_token_id, + (rope_cos, rope_sin), + max_seq_len, + device, + ) + inputs = transfer_inputs_to_device(inputs, device) + output = decode_model(**inputs, use_cache=True) + else: + input = pad_input( + prefill_input[..., :-1], pad_token_id, max_seq_len - 1 + ) + prefill_attn_mask = input != pad_token_id + prefill_position_ids = prefill_attn_mask.long().cumsum(-1) - 1 + prefill_position_ids.masked_fill_(prefill_attn_mask == 0, 1) + prefill_ouput = prefill_model( + input.to(prefill_model.device), + attention_mask = prefill_attn_mask.to(prefill_model.device), + position_ids = prefill_position_ids.to(prefill_model.device), + use_cache=True + ) + 
next_token = prefill_input[..., -1:] + decode_attention_mask = torch.ones((1, max_seq_len)) + decode_attention_mask[..., :input.shape[-1]] = input != pad_token_id + decode_position_ids = torch.tensor([[prefill_seq_len-1]]) + + output = decode_model( + next_token.to(decode_model.device), + past_key_values=prefill_ouput.past_key_values, + attention_mask = decode_attention_mask.to(decode_model.device), + position_ids = decode_position_ids.to(decode_model.device), + use_cache=True, + ) + + lm_logits = output.logits + + if torch.isfinite(lm_logits).all(): + labels = last_token[0].to(device) + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + loss = loss_fct( + lm_logits.reshape(-1, lm_logits.size(-1)), + labels.view(-1), + ) + nlls.append(loss) + + torch.cuda.empty_cache() + ppl = np.exp(torch.cat(nlls, dim=-1).mean().item()) + return ppl + + +def save_model_to(q_m, calib_inputs, args): + """ Save the whole model in circle format """ q_m.eval() q_m.cpu() + save_circle_to_folder = args.save_circle_to_folder + suffix = "_prefill" if args.prefill_decode is True else "" - save_path = pathlib.Path(save_circle_to_folder, "model.q.circle") + save_path = pathlib.Path(save_circle_to_folder, f"model{suffix}.q.circle") print(f"saving the whole model to {save_path.resolve()}") with torch.no_grad(): with SuppressWarning(UserWarning, ".*"): @@ -114,7 +455,41 @@ def save_model_to(q_m, calib_inputs, save_circle_to_folder): cm.save(save_path) -def save_layers_to(q_m, max_seq_len, save_layers_to_folder): +def save_prefill_model_to(q_m, calib_inputs, args): + """ Save the whole prefill model with enabled kv-cache in circle format """ + q_m.eval() + q_m.cpu() + save_circle_to_folder = args.save_circle_to_folder + suffix = "_prefill" if args.prefill_decode is True else "" + + q_m.wrapped.config.use_cache = True + save_path = pathlib.Path(save_circle_to_folder, f"model{suffix}.q.circle") + print(f"saving the whole model to {save_path.resolve()}") + with torch.no_grad(): + with 
SuppressWarning(UserWarning, ".*"): + cm = tico.convert(q_m, (calib_inputs[0][..., :-1],), strict=False) + + cm.save(save_path) + + +def save_decode_model_to(model, decode_calib_input, args): + """ Save the whole decode model in circle format """ + save_path = pathlib.Path(args.save_circle_to_folder, "model_decode.q.circle") + print(f"saving the whole model to {save_path.resolve()}") + model.eval() + model.cpu() + + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + inp = transfer_inputs_to_device(decode_calib_input, "cpu") + cm = tico.convert(model, (), kwargs=inp, strict=False) + cm.save(save_path) + + +def save_layers_to(q_m, args): + """ Save all layers of the model in circle format """ + max_seq_len = args.max_seq_len + save_layers_to_folder = args.save_layers_to_folder q_m.eval() q_m.cpu() @@ -147,13 +522,92 @@ def save_layers_to(q_m, max_seq_len, save_layers_to_folder): cm.save(save_path) -def quantize_using_PTQ(q_m, calib_inputs, args): +def save_prefill_layers_to(q_m, args): + """ Save all layers of the prefill model in circle format """ + + max_seq_len = args.max_seq_len + save_layers_to_folder = args.save_layers_to_folder + suffix = "prefill_" + q_m.eval() + q_m.cpu() + + if not hasattr(q_m, "wrapped"): + print("Saving layers currently is supported only for PTQ quantized model") + return + + layers = q_m.wrapped.model.wrapped.layers + config = q_m.wrapped.config + for i, qlayer in enumerate(layers): + save_path = pathlib.Path( + save_layers_to_folder, f"decoder_layer_{suffix}{i}.q.circle" + ) + B, S, D = 1, max_seq_len - 1, config.hidden_size + example_hidden = torch.randn(B, S, D) + + attention_mask = qlayer.wrapped._slice_causal(S, "cpu").squeeze(0) + dtype = example_hidden.dtype + pos_embeds = qlayer.wrapped._slice_rope(S, "cpu", dtype) + + print(f"Saving model layer_{i} to {save_path.resolve()}") + with torch.no_grad(): + with SuppressWarning(UserWarning, ".*"): + cm = tico.convert( + qlayer, + (example_hidden,), + kwargs={ + 
"attention_mask": attention_mask, + "position_embeddings": pos_embeds, + "use_cache": True, # this should be set somehow differently + }, + ) + cm.save(save_path) + + +def transfer_inputs_to_device(inp, device): + """ + Transfer all tensor inputs in a dictionary to the specified device. + + Args: + inp (dict): Mapping of input names to tensors, tuples of tensors, or lists containing tensors/tuples. + device (torch.device or str): Target device (e.g., "cpu" or "cuda"). + + Returns: + dict: The same dictionary with all tensors moved to ``device``. The structure of the + dictionary (including nested tuples/lists) is preserved. + + The function iterates over each key in ``inp`` and: + * Moves plain ``torch.Tensor`` objects via ``.to(device)``. + * Recursively moves tensors inside ``tuple`` objects. + * Handles lists containing tensors or tuples, moving each element accordingly. + + This utility is used throughout the quantization pipeline to ensure that + calibration inputs are placed on the correct device before model execution. 
+ """ + for key in inp: + if isinstance(inp[key], torch.Tensor): + inp[key] = inp[key].to(device) + elif isinstance(inp[key], tuple): + inp[key] = (inp[key][0].to(device), inp[key][1].to(device)) + elif isinstance(inp[key], list): + dev_list = [] + for k_inp in inp[key]: + if isinstance(k_inp, torch.Tensor): + dev_inp = k_inp.to(device) + elif isinstance(k_inp, tuple): + dev_inp = (k_inp[0].to(device), k_inp[1].to(device)) + dev_list.append(dev_inp) + inp[key] = dev_list + + return inp + + +def quantize_using_PTQ(q_m, calib_inputs, args, source_model=None, variant="prefill"): print("Wrapping layers with PTQWrapper …") qcfg = build_llm_ptq_config( model_type="llama", num_hidden_layers=len(q_m.model.layers), - wrapper_variant="prefill", + wrapper_variant=variant, activation_dtype=DType.int(16), default_qscheme=QScheme.PER_TENSOR_SYMM, linear_weight_bits=args.linear_weight_bits, @@ -167,7 +621,7 @@ def quantize_using_PTQ(q_m, calib_inputs, args): # ------------------------------------------------------------------------- # Single-pass activation calibration # ------------------------------------------------------------------------- - print("Calibrating PTQ obeservers…") + print("Calibrating PTQ observers…") # Overwrite weight observers with GPTQ statistics if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict): @@ -186,7 +640,38 @@ def quantize_using_PTQ(q_m, calib_inputs, args): device = torch.device(args.device) with torch.no_grad(): for inp in tqdm.tqdm(calib_inputs): - q_m(inp.to(device)) + if isinstance(inp, torch.Tensor): + q_m(inp.to(device)) + else: + inp = transfer_inputs_to_device(inp, device) + q_m(**inp) + + if source_model is not None: + source_modules = {} + for name, m in source_model.named_modules(): + if not isinstance(m, QuantModuleBase): + continue + + source_modules[name] = m + + from tico.quantization.wrapq.observers.base import ObserverBase + + for name, m in q_m.named_modules(): + if not isinstance(m, QuantModuleBase): + continue + + 
for attrib_name in dir(m): + attrib = getattr(m, attrib_name) + if not isinstance(attrib, ObserverBase): + continue + source_module = source_modules[name] + if not hasattr(source_module, attrib_name): + continue + source_attrib = getattr(source_module, attrib_name) + assert source_attrib is not None + import copy + + setattr(m, attrib_name, copy.deepcopy(source_attrib)) # Freeze all Q-params (scale, zero-point) q_m = convert(q_m) @@ -194,18 +679,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args): return q_m -def evaluate(q_m, tokenizer, dataset_test, args): +def evaluate(q_m, tokenizer, dataset_test, args, quantized: bool): # ------------------------------------------------------------------------- # Evaluate perplexity on Wikitext-2 # ------------------------------------------------------------------------- print("\nCalculating perplexities …") enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") - ppl_uint8 = perplexity( + ppl = perplexity( q_m, enc, args.device, max_length=args.max_seq_len, stride=args.max_seq_len ) + help_str = "int16" if quantized is True else "FP32" print("\n┌── Wikitext-2 test perplexity ─────────────") - print(f"│ int16 : {ppl_uint8:8.2f}") + print(f"│ {help_str} : {ppl:8.2f}") print("└───────────────────────────────────────────") if args.eval_tasks is not None: @@ -216,6 +702,291 @@ def evaluate(q_m, tokenizer, dataset_test, args): print(make_table(results)) +class QModelProcessor: + """Base processor handling tokenization, GPTQ, and evaluation logic.""" + + def __init__(self, model, tokenizer, args): + """Initialize the processor with model, tokenizer, and arguments.""" + self.model = model + self.tokenizer = tokenizer + self.device = torch.device(args.device) + self.args = args + + def get_tokenized_inputs(self, dataset, shuffle=True): + """Tokenize the dataset into fixed‑length chunks for calibration.""" + text = " ".join(dataset["text"]) + ids = self.tokenizer(text, return_tensors="pt").input_ids.to(self.device) + 
tokenized_inputs = [] + nsamples = self.args.nsamples_for_qcalibration + seqlen = self.model.config.max_position_embeddings + if shuffle is True: + random.seed(self.args.seed) + else: + stride = min((ids.shape[1] - seqlen - 1) // nsamples, seqlen) + for index in range(nsamples): + if shuffle is True: + i = random.randint(0, ids.shape[1] - seqlen - 1) + else: + i = index * stride + j = i + seqlen + inp = ids[:, i:j] + tokenized_inputs.append(inp.cpu()) + return tokenized_inputs + + def run_gptq(self, calib_inputs): + """Run GPTQ weight‑only quantization on the model using calibration inputs.""" + print("Applying GPTQ …") + + sens = None + if self.args.gptq_mse is not None and self.args.gptq_mse == "smse": + if self.args.sensitivity_path is not None: + sens = torch.load(self.args.sensitivity_path) + else: + calibrator = SensitivityCalibrator(self.model, calib_inputs) + sens = calibrator.compute_sensitivity_info() + + gptq_config = GPTQConfig( + weight_bits=self.args.linear_weight_bits, + perchannel=True, + mse=self.args.gptq_mse, + sensitivity=sens, + ) + q_m = prepare(self.model, gptq_config, inplace=True) + with torch.no_grad(): + for inp in calib_inputs: + q_m(inp.to(self.device)) + + q_m = convert(q_m, inplace=True) # materialize INT-weight tensors + return q_m + + def evaluate_original(self, dataset_test): + """Evaluate the original (FP32) model on the test dataset.""" + return evaluate( + self.model, self.tokenizer, dataset_test, self.args, quantized=False + ) + + def evaluate_quantized(self, dataset_test): + """Placeholder for evaluating the quantized model (implementation elsewhere).""" + assert False + + def save_quantized(self, model, calib_inputs): + """Placeholder for saving quantization artifacts (implementation elsewhere).""" + assert False + + +class PrefillQModelProcessor(QModelProcessor): + """ + Processor for simple model (just-prefill-model) which doesn't use kv cache. 
+ """ + + def __init__(self, model, tokenizer, args): + """Initialize the prefill processor (no KV cache is used).""" + super().__init__(model, tokenizer, args) + + def run_ptq(self, q_m, calib_inputs): + return quantize_using_PTQ(q_m, calib_inputs, self.args, variant="prefill") + + def evaluate_quantized(self, model, dataset_test): + evaluate(model, self.tokenizer, dataset_test, self.args, quantized=True) + + def save_quantized(self, model, calib_inputs): + if self.args.save_layers_to_folder is not None: + save_layers_to(model, self.args) + + if self.args.save_circle_to_folder is not None: + calib_inputs = list( + torch.stack(calib_inputs).reshape(-1, 1, self.args.max_seq_len) + ) + save_model_to(model, calib_inputs, self.args) + + +class PrefillDecodeQModelProcessor(QModelProcessor): + """ + Processor for Prefill-Decode models. + Prefill-model computes kv-cache for the user input, then each new token is produced by decode-model with updated kv-cache. + """ + + def __init__(self, model, tokenizer, args): + """Initialize the prefill‑decode processor, handling tokenizer pad token and preparing rotary embeddings.""" + super().__init__(model, tokenizer, args) + if tokenizer.pad_token_id is None: + print( + "Warning: tokenizer doesn't have pad_token. Prefill-decoding scheme may fail." + ) + tokenizer.pad_token = tokenizer.eos_token + + rope_cos, rope_sin = _build_rope_templates_from_config( + self.model.config, + max_seq=self.args.max_seq_len, + device=self.device, + dtype=torch.float32, + ) + self.rope_cos = rope_cos + self.rope_sin = rope_sin + + # debug padding + + # inputs = tokenizer("Hello!", return_tensors="pt", max_length=args.max_seq_len - 1, padding='max_length', padding_side="right").input_ids.to(device) + # #inputs = tokenizer("Hello! 
How are you?", return_tensors="pt").input_ids.to(device) + # model.config.use_cache = True + # model.config._attn_implementation = "eager" + # out_ids = model.generate(inputs) + # output = tokenizer.decode(out_ids.squeeze(), skip_special_tokens=True) + + # prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset(q_m, q_m, calib_inputs, pad_token_id=tokenizer.pad_token_id, max_seq_len = args.max_seq_len, seed = args.seed, device = device) + # + # print("\n┌── Wikitext-2 prefill_decode initial calibration perplexity──") + # print(f"│ FP32 : {prefill_decode_ppl:8.2f}") + # print("└───────────────────────────────────────────") + + def run_ptq(self, q_m, calib_inputs): + """Run PTQ for the prefill‑decode pipeline, calibrating both prefill and decode models.""" + pre_ptq_model = copy.deepcopy(q_m).to("cpu") # to be used in decode quntizing + + # get prefill_model + q_m = quantize_using_PTQ(q_m, calib_inputs, self.args, variant="prefill") + + prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset( + q_m, + q_m, + calib_inputs, + pad_token_id=self.tokenizer.pad_token_id, + max_seq_len=self.args.max_seq_len, + seed=self.args.seed, + device=self.device, + ) + + print( + "\n┌── Wikitext-2 prefill_prefill train calibration perplexity ─────────────" + ) + print(f"│ int16 : {prefill_decode_ppl:8.2f}") + print("└───────────────────────────────────────────") + + torch.manual_seed(self.args.seed) + + decode_calib_inputs = [] + with torch.no_grad(): + print("Computing calibration set for decode-model") + for calib_input in tqdm.tqdm(calib_inputs): + prefill_seq_len = ( + torch.randint(3, self.args.max_seq_len - 1, (1,)).cpu().item() + ) # cropped input length + prefill_input = calib_input[..., :prefill_seq_len] # cropped input + inputs = get_decode_input( + q_m, + prefill_input, + self.tokenizer.pad_token_id, + (self.rope_cos, self.rope_sin), + self.args.max_seq_len, + self.device, + ) + decode_calib_inputs.append(inputs) + + q_m_decode = 
quantize_using_PTQ( + pre_ptq_model.to(self.device), + decode_calib_inputs, + self.args, + source_model=q_m, + variant="decode", + ) + q_m_decode.wrapped.config.use_cache = True + prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset( + q_m, + q_m_decode, + calib_inputs, + pad_token_id=self.tokenizer.pad_token_id, + max_seq_len=self.args.max_seq_len, + seed=self.args.seed, + device=self.device, + ) + + print( + "\n┌── Wikitext-2 prefill_decode train calibration perplexity ─────────────" + ) + print(f"│ int16 : {prefill_decode_ppl:8.2f}") + print("└───────────────────────────────────────────") + return (q_m, q_m_decode) + + def evaluate_original(self, dataset_test): + """Evaluate the original (FP) model using the prefill‑decode pipeline.""" + super().evaluate_original(dataset_test) + + test_inputs = self.get_tokenized_inputs(dataset_test, shuffle=False) + prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset( + self.model, + self.model, + test_inputs, + pad_token_id=self.tokenizer.pad_token_id, + max_seq_len=self.args.max_seq_len, + seed=self.args.seed, + device=self.device, + ) + + print("\n┌── Wikitext-2 prefill_prefill original test perplexity ─────────────") + print(f"│ FP32 : {prefill_decode_ppl:8.2f}") + print("└───────────────────────────────────────────") + + def evaluate_quantized(self, model, dataset_test): + """Evaluate the quantized prefill‑decode model on the test dataset.""" + prefill_model, decode_model = model + evaluate(prefill_model, self.tokenizer, dataset_test, self.args, quantized=True) + + test_inputs = self.get_tokenized_inputs(dataset_test, shuffle=False) + prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset( + prefill_model, + decode_model, + test_inputs, + pad_token_id=self.tokenizer.pad_token_id, + max_seq_len=self.args.max_seq_len, + seed=self.args.seed, + device=self.device, + ) + print("\n┌── Wikitext-2 prefill_decode quantized test perplexity ─────────────") + print(f"│ FP32 : 
{prefill_decode_ppl:8.2f}") + print("└───────────────────────────────────────────") + + def save_quantized(self, model, calib_inputs): + """Save the quantized prefill and decode models (and optionally their layers) to disk.""" + prefill_model, decode_model = model + + if self.args.save_layers_to_folder is not None: + save_prefill_layers_to(prefill_model, self.args) + + if self.args.save_circle_to_folder is not None: + calib_inputs = list( + torch.stack(calib_inputs).reshape(-1, 1, self.args.max_seq_len) + ) + # save prefill model + save_prefill_model_to(prefill_model, calib_inputs, self.args) + + # compute example input + prefill_seq_len = ( + torch.randint(3, self.args.max_seq_len - 1, (1,)).cpu().item() + ) # cropped input length + prefill_input = calib_inputs[0][..., :prefill_seq_len].to( + "cpu" + ) # cropped input + + inputs = get_decode_input( + prefill_model, + prefill_input, + self.tokenizer.pad_token_id, + (self.rope_cos, self.rope_sin), + self.args.max_seq_len, + "cpu", + ) + + # save decode model + save_decode_model_to(decode_model, inputs, self.args) + + +def get_qmodel_processor(model, tokenizer, args): + if args.prefill_decode: + return PrefillDecodeQModelProcessor(model, tokenizer, args) + + return PrefillQModelProcessor(model, tokenizer, args) + + def main(): parser = argparse.ArgumentParser( description="GPTQ+PTQ pipeline (weight-only + activation)" ) @@ -334,6 +1105,13 @@ def main(): type=str, default=None, ) + parser.add_argument( + "--prefill_decode", + action="store_true", + default=False, + help="Whether to use the prefill-decode (KV-cache) scheme", + ) + args = parser.parse_args() print(args) @@ -368,7 +1146,7 @@ def main(): device_map=dev_map, ).eval() - model.config.use_cache = False # TODO use args for it + model.config.use_cache = False if args.calibrate_seq_len is not None: model.config.max_position_embeddings = min( model.config.max_position_embeddings, args.calibrate_seq_len @@ -378,65 +1156,32 @@ def main(): dataset_test = load_dataset( DATASET_NAME, DATASET_CONFIG, split=TEST_SPLIT, 
cache_dir=args.cache_dir ) - print("\nCalculating original perplexities …") - enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt") - ppl_fp32 = perplexity( - model, enc, device, max_length=args.max_seq_len, stride=args.max_seq_len - ) - - print("\n┌── Wikitext-2 test perplexity ─────────────") - print(f"│ FP32 : {ppl_fp32:8.2f}") - print("└───────────────────────────────────────────") + # ------------------------------------------------------------------------- + # Create a processor for the model + # ------------------------------------------------------------------------- + qmodel_processor = get_qmodel_processor(model, tokenizer, args) - if args.eval_tasks is not None: - results = evaluate_llm_on_tasks( - model, tokenizer, args.eval_tasks, max_length=args.max_seq_len - ) - print("Original RESULTS ARE:") - print(make_table(results)) + # ------------------------------------------------------------------------- + # Compute original metrics to estimate metrics degradation + # ------------------------------------------------------------------------- + qmodel_processor.evaluate_original(dataset_test) # ------------------------------------------------------------------------- # Prepare calibration dataset # ------------------------------------------------------------------------- dataset_train = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT) - calib_txt = " ".join(dataset_train["text"]) - train_ids = tokenizer(calib_txt, return_tensors="pt").input_ids.to(device) - calib_inputs = [] - nsamples = args.nsamples_for_qcalibration - seqlen = model.config.max_position_embeddings - random.seed(args.seed) - for _ in range(nsamples): - i = random.randint(0, train_ids.shape[1] - seqlen - 1) - j = i + seqlen - inp = train_ids[:, i:j] - calib_inputs.append(inp.cpu()) + calib_inputs = qmodel_processor.get_tokenized_inputs(dataset_train, shuffle=True) + + # original_prefill_decode_ppl = evaluate_ppl_of_prefill_decode_model_on_dataset(model, model, 
calib_inputs, pad_token_id=tokenizer.pad_token_id, max_seq_len = args.max_seq_len, seed = args.seed, device = device) + # print("\n┌── Wikitext-2 prefill_decode original calibration perplexity ─────────────") + # print(f"│ fp32 : {original_prefill_decode_ppl:8.2f}") + # print("└───────────────────────────────────────────") # ------------------------------------------------------------------------- # Run GPTQ (weight-only) pass # ------------------------------------------------------------------------- if not args.no_GPTQ: - print("Applying GPTQ …") - - sens = None - if args.gptq_mse is not None and args.gptq_mse == "smse": - if args.sensitivity_path is not None: - sens = torch.load(args.sensitivity_path) - else: - calibrator = SensitivityCalibrator(model, calib_inputs) - sens = calibrator.compute_sensitivity_info() - - gptq_config = GPTQConfig( - weight_bits=args.linear_weight_bits, - perchannel=True, - mse=args.gptq_mse, - sensitivity=sens, - ) - q_m = prepare(model, gptq_config, inplace=True) - with torch.no_grad(): - for inp in calib_inputs: - q_m(inp.to(args.device)) - - q_m = convert(q_m, inplace=True) # materialize INT-weight tensors + q_m = qmodel_processor.run_gptq(calib_inputs) else: q_m = model @@ -444,17 +1189,17 @@ def main(): # Wrap every layer with PTQWrapper # ------------------------------------------------------------------------- if not args.no_PTQ: - q_m = quantize_using_PTQ(q_m, calib_inputs, args) + q_m = qmodel_processor.run_ptq(q_m, calib_inputs) - # after PTQ quantizer only fixed-length input sequences are valid - evaluate(q_m, tokenizer, dataset_test, args) - - if args.save_layers_to_folder is not None: - save_layers_to(q_m, args.max_seq_len, args.save_layers_to_folder) + # ------------------------------------------------------------------------- + # Compute quantized model metrics to estimate metrics degradation + # ------------------------------------------------------------------------- + qmodel_processor.evaluate_quantized(q_m, 
dataset_test) - if args.save_circle_to_folder is not None: - calib_inputs = list(torch.stack(calib_inputs).reshape(-1, 1, args.max_seq_len)) - save_model_to(q_m, calib_inputs, args.save_circle_to_folder) + # ------------------------------------------------------------------------- + # Save layers and model + # ------------------------------------------------------------------------- + qmodel_processor.save_quantized(q_m, calib_inputs) if __name__ == "__main__": diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py b/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py index 1f7501d7..1f737d00 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_attn_decode.py @@ -103,7 +103,7 @@ def __init__( fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj" ) self.k_proj = PTQWrapper( - copy.deepcopy(fp_attn.k_proj), qcfg=k_cfg, fp_name=f"{fp_name}.k_proj" + fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj" ) self.v_proj = PTQWrapper( fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj" diff --git a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py index
25f1765a..02e7a17c 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py @@ -78,7 +78,7 @@ def __init__( fp_attn.q_proj, qcfg=q_cfg, fp_name=f"{fp_name}.q_proj" ) self.k_proj = PTQWrapper( - copy.deepcopy(fp_attn.k_proj), qcfg=k_cfg, fp_name=f"{fp_name}.k_proj" + fp_attn.k_proj, qcfg=k_cfg, fp_name=f"{fp_name}.k_proj" ) self.v_proj = PTQWrapper( fp_attn.v_proj, qcfg=v_cfg, fp_name=f"{fp_name}.v_proj" @@ -124,7 +124,7 @@ def __init__( self.obs_k_rot = mk("k_rot") # Masking & attention math - self.obs_causal_mask = mk("causal_mask") + self.obs_attn_mask = mk("causal_mask") self.obs_logits = mk("logits") self.obs_mask_add = mk("mask_add") self.obs_softmax = mk("softmax") @@ -132,6 +132,14 @@ def __init__( self.obs_attn_weights = mk("attn_weights") self.obs_attn_out_h = mk("attn_out_h") + # New kv delta + self.obs_new_k = mk("new_k") # (B, n_kv, 1, H) + self.obs_new_v = mk("new_v") # (B, n_kv, 1, H) + + # Total KV after concat (used for matmul/attn) + self.obs_k_total = mk("k_total") # (B, max_seq, H) + self.obs_v_total = mk("v_total") # (B, max_seq, H) + # Static causal mask template assert hasattr(cfg, "max_position_embeddings") max_seq = cfg.max_position_embeddings @@ -207,7 +215,7 @@ def forward( hidden_states.device ) attention_mask = attention_mask.squeeze(0) - attention_mask = self._fq(attention_mask, self.obs_causal_mask) + attention_mask = self._fq(attention_mask, self.obs_attn_mask) attn_weights_parts = [] attn_out_parts = [] @@ -233,6 +241,8 @@ def forward( present_v_parts.append(v_i) k_i, v_i = self._concat_kv(past_key_value, k_i, v_i, kv_i) + k_i = self._fq(k_i, self.obs_k_total) + v_i = self._fq(v_i, self.obs_v_total) for rep_i in range(kv_rep): q_idx = kv_i * kv_rep + rep_i # q_h: (B, S, H) @@ -289,6 +299,8 @@ def forward( # Present KV: (B, n_kv, S, H) present_k = torch.stack(present_k_parts, dim=1) present_v = torch.stack(present_v_parts, dim=1) + present_k = 
self._fq(present_k, self.obs_new_k) + present_v = self._fq(present_v, self.obs_new_v) present_key_value = (present_k, present_v) # return with/without cache @@ -303,7 +315,7 @@ def _all_observers(self): self.obs_hidden, self.obs_cos, self.obs_sin, - self.obs_causal_mask, + self.obs_attn_mask, self.obs_q_x1, self.obs_q_x2, self.obs_q_cat, @@ -322,6 +334,10 @@ self.obs_attn_out, self.obs_attn_weights, self.obs_attn_out_h, + self.obs_new_k, + self.obs_new_v, + self.obs_k_total, + self.obs_v_total, ) # recurse into children that are QuantModuleBase for m in (self.q_proj, self.k_proj, self.v_proj, self.o_proj): diff --git a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_decode.py b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_decode.py index 1c5ac5e7..8212b8fd 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_decode.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_decode.py @@ -173,10 +173,10 @@ def forward( "Decode expects attention_mask with shape (B, 1, max_seq); " f"got {tuple(attention_mask.shape)}" ) - assert attention_mask.size(2) == self.max_seq, ( - f"Decode expects attention_mask width == max_seq ({self.max_seq}); " - f"got {attention_mask.size(2)}" - ) + # NOTE(review): the strict width check (attention_mask.size(2) == max_seq) + # was intentionally relaxed so decode can accept masks narrower than + # max_seq (e.g. during calibration with cropped inputs) — TODO confirm + # no runtime path still relies on a full-width mask. # RoPE tables for the current token must be provided by the host/runtime.
assert ( diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model.py b/tico/quantization/wrapq/wrappers/llama/quant_model.py index 74325d28..3252e4a1 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model.py @@ -27,6 +27,7 @@ from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase from tico.quantization.wrapq.wrappers.registry import try_register +Q_INF = -120.0  # quantization-friendly stand-in for float("-inf") in additive attention masks @try_register("transformers.models.llama.modeling_llama.LlamaModel") class QuantLlamaModel(QuantModuleBase): @@ -79,7 +80,7 @@ def __init__( # Static causal mask template --------------------------------------- assert isinstance(self.config.max_position_embeddings, int) max_seq = self.config.max_position_embeddings - mask = torch.full((1, 1, max_seq, max_seq), float("-120")) + mask = torch.full((1, 1, max_seq, max_seq), Q_INF) mask.triu_(1) self.register_buffer("causal_mask_template", mask, persistent=False) @@ -126,14 +127,14 @@ def __init__( self.register_buffer("rope_cos_template", cos_t, persistent=False) self.register_buffer("rope_sin_template", sin_t, persistent=False) - def _slice_causal(self, seq_len: int, device: torch.device) -> torch.Tensor: + def _slice_causal(self, seq_len: int, device: torch.device, offset: int = 0) -> torch.Tensor: """Return `[1,1,L,L]` causal mask slice on *device*.""" assert isinstance(self.causal_mask_template, torch.Tensor) - return self.causal_mask_template[..., :seq_len, :seq_len].to(device) + return self.causal_mask_template[..., offset : offset + seq_len, : offset + seq_len].to(device) - def get_attention_mask_for(self, x): + def get_attention_mask_for(self, x, offset: int = 0): L = x.size(1) - attention_mask = self._slice_causal(L, x.device) + attention_mask = self._slice_causal(L, x.device, offset) return attention_mask def get_position_embeddings_for(self, hidden_states): @@ -158,6 +159,7 @@
output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPast]: @@ -185,12 +187,16 @@ inputs_embeds = self.embed_tokens(input_ids) if use_cache and past_key_values is None: - past_key_values = DynamicCache() - + past_key_values = [] + + present_key_values = [] if cache_position is None: past_seen_tokens = ( - past_key_values.get_seq_length() if past_key_values is not None else 0 + 0 + if (past_key_values is None or len(past_key_values) == 0) + else past_key_values[0][0].shape[-2] ) + cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], @@ -201,23 +207,45 @@ position_ids = cache_position.unsqueeze(0) hidden_states = inputs_embeds - # create position_embeddings and causal_mask to be shared across all the decoder layers - causal_mask = self.get_attention_mask_for(hidden_states) - causal_mask = causal_mask.squeeze(0) + + offset = past_key_values[0][0].shape[-2] if past_key_values is not None and len(past_key_values) > 0 else 0 + + if attention_mask is not None and len(attention_mask.shape) == 3: + causal_mask = attention_mask # set externally + else: + if attention_mask is not None: # assuming it's a 0/1 mask: 0 = False (masked), 1 = True (keep) (e.g.
padding) + # convert it to float, so that True(1) maps to 0, False(0) maps to Q_INF + attention_mask = (torch.ones_like(attention_mask) - attention_mask) * Q_INF + + # create causal_mask to be shared across all the decoder layers + causal_mask = self.get_attention_mask_for(hidden_states, offset) + if attention_mask is not None: + # combine the external (padding) mask with the causal mask, clamping at Q_INF + causal_mask = torch.clamp(causal_mask + attention_mask, min=Q_INF) + causal_mask = causal_mask.squeeze(0) causal_mask = self._fq(causal_mask, self.obs_causal_mask) - position_embeddings = self.get_position_embeddings_for(hidden_states) - cos, sin = position_embeddings - position_embeddings = ( - self._fq(cos[:, : hidden_states.size(1), :], self.obs_cos), - self._fq(sin[:, : hidden_states.size(1), :], self.obs_sin), - ) - + if position_embeddings is None: + position_embeddings = self.get_position_embeddings_for(hidden_states) + cos, sin = position_embeddings + + position_embeddings = ( + self._fq(cos[:, offset : offset + hidden_states.size(1), :], self.obs_cos), + self._fq(sin[:, offset : offset + hidden_states.size(1), :], self.obs_sin), + ) + else: + cos, sin = position_embeddings + position_embeddings = ( + self._fq(cos, self.obs_cos), + self._fq(sin, self.obs_sin), + ) # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for decoder_layer in self.layers[: self.config.num_hidden_layers]: + for idx, decoder_layer in enumerate( + self.layers[: self.config.num_hidden_layers] + ): if output_hidden_states: all_hidden_states += (hidden_states,) # type: ignore[operator] @@ -225,7 +253,11 @@ hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_value=( + past_key_values[idx] + if past_key_values is not None and idx < len(past_key_values) + else None + ), output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position,
@@ -235,6 +267,10 @@ def forward( if decoder_layer.wrapped.return_type == "tuple": hidden_states = layer_outputs[0] + elif use_cache: + hidden_states = layer_outputs[0] + assert isinstance(layer_outputs[1], tuple) + present_key_values.append(layer_outputs[1]) else: hidden_states = layer_outputs @@ -249,7 +285,7 @@ output = BaseModelOutputWithPast( last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, + past_key_values=present_key_values if use_cache else None, hidden_states=all_hidden_states, attentions=all_self_attns, ) diff --git a/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py index 70ca7df3..7d94f884 100644 --- a/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py +++ b/tico/quantization/wrapq/wrappers/llama/quant_model_for_causal_lm.py @@ -75,6 +75,7 @@ def forward( use_cache: bool | None = None, cache_position: torch.LongTensor | None = None, logits_to_keep: int | torch.Tensor = 0, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, **kwargs, ) -> CausalLMOutputWithPast: @@ -94,6 +95,7 @@ output_hidden_states=output_hidden_states, return_dict=return_dict, cache_position=cache_position, + position_embeddings=position_embeddings, **kwargs, )