4040from lm_eval .utils import make_table
4141from transformers import AutoModelForCausalLM , AutoTokenizer
4242
43+ import numpy as np
4344import tico
4445
4546from tico .quantization import convert , prepare
@@ -98,6 +99,51 @@ def inject_gptq_qparams(
9899 obs .load_qparams (quantizer .scale , quantizer .zero , lock = True )
99100
100101
def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
    """Compute perplexity of ``model`` over ``dataset``.

    Each element of ``dataset`` must be a 2-D token-id tensor of shape
    (batch, seq_len). Standard causal-LM shift is applied: logits at
    position t are scored against the token at position t+1.

    Args:
        model: Callable mapping a token tensor to either an object with a
            ``.logits`` attribute (HF style), a tuple/list whose first
            element is the logits, or a raw logits tensor.
        dataset: Iterable of input-id tensors.
        device: Target device, as a string (e.g. "cuda", "cpu") or
            ``torch.device``.

    Returns:
        Perplexity as a float: exp(mean negative log-likelihood) over all
        scored tokens of all batches with finite logits.

    Raises:
        RuntimeError: if a dataset element is not a tensor, or if no batch
            produced finite logits.
    """
    # BUG FIX: the original compared `model.device.type != device.type`,
    # but `device` is a str, which has no `.type` attribute. Normalize to
    # torch.device first (accepts both str and torch.device).
    dev = torch.device(device)
    if hasattr(model, "device") and model.device.type != dev.type:
        if hasattr(model, "to"):
            model.to(dev)

    # tqdm is only a progress bar; degrade gracefully if unavailable.
    try:
        from tqdm import tqdm as _progress
    except ImportError:  # pragma: no cover

        def _progress(iterable):
            return iterable

    nlls = []
    with torch.no_grad():
        for batch in _progress(dataset):
            if not isinstance(batch, torch.Tensor):
                raise RuntimeError("Unknown input in ppl_eval_on_dataset")
            # Move once (the original called .to(device) twice per batch).
            batch = batch.to(dev)
            output = model(batch)

            if hasattr(output, "logits"):
                lm_logits = output.logits
            elif isinstance(output, (tuple, list)) and len(output) > 1:
                # Tuple-style model output: first element is the logits.
                lm_logits = torch.as_tensor(output[0])
            else:
                lm_logits = torch.as_tensor(output)

            # Skip batches with NaN/Inf logits so one bad batch does not
            # poison the aggregate perplexity.
            if torch.isfinite(lm_logits).all():
                shift_logits = lm_logits[:, :-1, :].contiguous()
                shift_labels = batch[:, 1:].contiguous()
                loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
                loss = loss_fct(
                    shift_logits.reshape(-1, shift_logits.size(-1)),
                    shift_labels.view(-1),
                )
                nlls.append(loss)
                del shift_logits, shift_labels

            # Free per-batch tensors eagerly to keep peak GPU memory down.
            del batch, lm_logits, output
            torch.cuda.empty_cache()

    if not nlls:
        raise RuntimeError(
            "No batch produced finite logits; cannot compute perplexity"
        )
    return float(np.exp(torch.cat(nlls, dim=-1).mean().item()))
146+
101147# -------------------------------------------------------------------------
102148# Save model in circle format
103149# -------------------------------------------------------------------------
@@ -296,7 +342,7 @@ def main():
296342 "--gptq_mse" ,
297343 type = str ,
298344 default = None ,
299- choices = ["mse" , "smse" ],
345+ choices = ["mse" , "smse" , "smse_for_gptq" ],
300346 help = "Whether and how to use mse in gptq (none/mse/smse/)" ,
301347 )
302348 parser .add_argument (
@@ -410,6 +456,13 @@ def main():
410456 j = i + seqlen
411457 inp = train_ids [:, i :j ]
412458 calib_inputs .append (inp .cpu ())
459+
460+ train_ppl_fp32 = evaluate_ppl_of_model_on_dataset (
461+ model , calib_inputs , device = device
462+ )
463+ print ("\n ┌── Wikitext-2 train perplexity ─────────────" )
464+ print (f"│ FP32 : { train_ppl_fp32 :8.2f} " )
465+ print ("└───────────────────────────────────────────" )
413466
414467 # -------------------------------------------------------------------------
415468 # Run GPTQ (weight-only) pass
@@ -418,7 +471,9 @@ def main():
418471 print ("Applying GPTQ …" )
419472
420473 sens = None
421- if args .gptq_mse is not None and args .gptq_mse == "smse" :
474+ if args .gptq_mse is not None and (
475+ args .gptq_mse == "smse" or args .gptq_mse == "smse_for_gptq"
476+ ):
422477 if args .sensitivity_path is not None :
423478 sens = torch .load (args .sensitivity_path )
424479 else :
@@ -446,6 +501,13 @@ def main():
446501 if not args .no_PTQ :
447502 q_m = quantize_using_PTQ (q_m , calib_inputs , args )
448503
504+ train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset (
505+ q_m , calib_inputs , device = device
506+ )
507+ print ("\n ┌── Wikitext-2 train perplexity ─────────────" )
508+ print (f"│ int16 : { train_ppl_ioqdtype :8.2f} " )
509+ print ("└───────────────────────────────────────────" )
510+
449511 # after PTQ quantizer only fixed-length input sequences are valid
450512 evaluate (q_m , tokenizer , dataset_test , args )
451513
0 commit comments