
from typing import Any, List, Optional, Tuple, Union

+import numpy as np
+
import torch
import tqdm
from datasets import load_dataset
@@ -98,6 +100,51 @@ def inject_gptq_qparams(
    obs.load_qparams(quantizer.scale, quantizer.zero, lock=True)


+def evaluate_ppl_of_model_on_dataset(model, dataset, device: str = "cuda"):
+    """Compute perplexity of `model` over an iterable of token batches."""
+    # Move the model to the target device if it is not already there.
+    if hasattr(model, "device") and model.device.type != torch.device(device).type:
+        if hasattr(model, "to"):
+            model.to(device)
+    nlls = []
+    with torch.no_grad():
+        for batch in tqdm.tqdm(dataset):
+            if isinstance(batch, torch.Tensor):
+                batch = batch.to(device)
+                output = model(batch)
+            else:
+                raise RuntimeError("Unknown input in ppl_eval_on_dataset")
+
+            if hasattr(output, "logits"):
+                lm_logits = output.logits
+            elif len(output) > 1:
+                lm_logits = torch.tensor(output[0])
+            else:
+                lm_logits = torch.tensor(output)
+
+            if torch.isfinite(lm_logits).all():
+                # Shift so that logits at position t predict the token at t + 1.
+                shift_logits = lm_logits[:, :-1, :].contiguous()
+                if isinstance(batch, torch.Tensor):
+                    shift_labels = batch[:, 1:].contiguous()
+                else:
+                    assert isinstance(batch, tuple)
+                    shift_labels = batch[0][:, 1:].contiguous()
+                loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
+                loss = loss_fct(
+                    shift_logits.reshape(-1, shift_logits.size(-1)),
+                    shift_labels.view(-1),
+                )
+                nlls.append(loss)
+                del shift_logits, shift_labels
+                shift_logits = shift_labels = None  # type: ignore[assignment]
+
+            # Drop per-batch tensors eagerly to keep GPU memory bounded.
+            del batch, lm_logits, output
+            lm_logits = output = batch = None  # noqa: F841
+            torch.cuda.empty_cache()
+
+    # Perplexity = exp(mean per-token negative log-likelihood).
+    ppl = np.exp(torch.cat(nlls, dim=-1).mean().item())
+    return ppl
+
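+# Minimal usage sketch for the helper above (assuming, as in main() below,
+# that `calib_inputs` is a list of fixed-length token tensors of shape
+# [1, seqlen]):
+#
+#     ppl = evaluate_ppl_of_model_on_dataset(model, calib_inputs, device="cuda")
+#     print(f"train ppl: {ppl:.2f}")  # ppl = exp(mean per-token NLL)
+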
# -------------------------------------------------------------------------
# Save model in circle format
# -------------------------------------------------------------------------
@@ -215,6 +262,24 @@ def evaluate(q_m, tokenizer, dataset_test, args):
    print("Quantized RESULTS ARE:")
    print(make_table(results))

+def get_sensitivities_info_path(model, save_folder, dataset, seed, n_samples):
+    """Build the cache file path for precomputed layer sensitivities."""
+    model_name = model.config.name_or_path.replace("/", "_")
+    if save_folder is None:
+        save_folder = "."
+    cache_path = (
+        f"{save_folder}/sensitivities_for_{model_name}"
+        f"_{dataset}_{n_samples}_{seed}.pt"
+    )
+    return cache_path
+
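+# For illustration only (all values hypothetical): with save_folder=None,
+# a model named "facebook/opt-125m", dataset "wikitext", 128 calibration
+# samples and seed 0, the helper returns:
+#
+#     ./sensitivities_for_facebook_opt-125m_wikitext_128_0.pt
+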

def main():
    parser = argparse.ArgumentParser(
@@ -296,7 +361,7 @@ def main():
        "--gptq_mse",
        type=str,
        default=None,
-        choices=["mse", "smse"],
+        choices=["mse", "smse", "smse_for_gptq"],
-        help="Whether and how to use mse in gptq (none/mse/smse/)",
+        help="Whether and how to use mse in gptq (none/mse/smse/smse_for_gptq)",
    parser.add_argument(
@@ -410,6 +475,13 @@ def main():
        j = i + seqlen
        inp = train_ids[:, i:j]
        calib_inputs.append(inp.cpu())
+
+    train_ppl_fp32 = evaluate_ppl_of_model_on_dataset(
+        model, calib_inputs, device=device
+    )
+    print("\n┌── Wikitext-2 train perplexity ─────────────")
+    print(f"│ FP32 : {train_ppl_fp32:8.2f}")
+    print("└───────────────────────────────────────────")

    # -------------------------------------------------------------------------
    # Run GPTQ (weight-only) pass
@@ -418,13 +490,24 @@ def main():
    print("Applying GPTQ …")

    sens = None
-    if args.gptq_mse is not None and args.gptq_mse == "smse":
+    if args.gptq_mse in ("smse", "smse_for_gptq"):
        if args.sensitivity_path is not None:
            sens = torch.load(args.sensitivity_path)
        else:
            calibrator = SensitivityCalibrator(model, calib_inputs)
            sens = calibrator.compute_sensitivity_info()
-
+            # Cache the freshly computed sensitivities for later reuse.
+            save_folder = (
+                args.save_circle_to_folder
+                if args.save_circle_to_folder is not None
+                else args.save_layers_to_folder
+            )
+            save_path = get_sensitivities_info_path(
+                model, save_folder, "wikitext", args.seed, len(calib_inputs)
+            )
+            print(f"Saving calibrated sensitivities to {save_path}")
+            torch.save(sens, save_path)
+
+    # Move the model off-device and back, then flush the CUDA caching
+    # allocator, to reclaim memory before the GPTQ pass.
+    model = model.cpu()
+    model = model.to(args.device)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+
    gptq_config = GPTQConfig(
        weight_bits=args.linear_weight_bits,
        perchannel=True,
@@ -440,12 +523,24 @@ def main():
    else:
        q_m = model

+    # Same device round-trip as above, now for the (possibly quantized) model.
+    q_m = q_m.cpu()
+    q_m = q_m.to(args.device)
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+
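+    # The cpu()/to()/empty_cache()/synchronize() sequence now appears twice;
+    # a sketch of a helper that could factor it out (the name
+    # `reset_device_memory` is hypothetical, not part of this change):
+    #
+    #     def reset_device_memory(m, device):
+    #         m = m.cpu().to(device)    # re-materialize params on `device`
+    #         torch.cuda.empty_cache()  # release cached allocator blocks
+    #         torch.cuda.synchronize()  # wait for pending transfers
+    #         return m
+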
    # -------------------------------------------------------------------------
    # Wrap every layer with PTQWrapper
    # -------------------------------------------------------------------------
    if not args.no_PTQ:
        q_m = quantize_using_PTQ(q_m, calib_inputs, args)

+    train_ppl_ioqdtype = evaluate_ppl_of_model_on_dataset(
+        q_m, calib_inputs, device=device
+    )
+    print("\n┌── Wikitext-2 train perplexity ─────────────")
+    print(f"│ int16 : {train_ppl_ioqdtype:8.2f}")
+    print("└───────────────────────────────────────────")
+
    # after PTQ, the quantized model only accepts fixed-length input sequences
    evaluate(q_m, tokenizer, dataset_test, args)
