 from lm_eval.utils import make_table
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from transformers.cache_utils import Cache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import KwargsForCausalLM, LlamaForCausalLM
-from transformers.processing_utils import Unpack
-
 import tico

 from tico.quantization import convert, prepare
@@ -107,60 +102,12 @@ def inject_gptq_qparams(
 def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
     q_m.eval()
     q_m.cpu()
-    save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle")
-    pathlib.Path()
-    print(f"saving input embedding to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(
-                q_m.model.embed_tokens,
-                (calib_inputs[0],),
-                strict=False,
-            )
-            cm.save(save_path)
-
-    save_path = pathlib.Path(save_circle_to_folder, "lm_head.q.circle")
-    print(f"saving lm_head to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
-            example_hidden = torch.randn(B, S, D)
-            cm = tico.convert(
-                q_m.lm_head,
-                (example_hidden,),
-                strict=False,
-            )
-            cm.save(save_path)
-
-    print("saving layers")
-    for i in range(len(q_m.model.layers)):
-        save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle")
-        print(f"saving model layer_{i} to {save_path.resolve()}")
-        B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
-        example_hidden = torch.randn(B, S, D)
-
-        with torch.no_grad():
-            with SuppressWarning(UserWarning, ".*"):
-                cm = tico.convert(
-                    q_m.model.layers[i],
-                    (example_hidden,),
-                    strict=False,
-                )
-                cm.save(save_path)
-
-    save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle")
-    print(f"saving model.model to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False)
-
-    cm.save(save_path)

     save_path = pathlib.Path(save_circle_to_folder, "model.q.circle")
     print(f"saving the whole model to {save_path.resolve()}")
     with torch.no_grad():
         with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m, (calib_inputs[0],), strict=False)
+            cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)

     cm.save(save_path)

@@ -222,13 +169,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
         default_dtype=DType.int(16),
         default_qscheme=QScheme.PER_TENSOR_SYMM,
         overrides={
-            "model.embeddings": {
-                "weight": {
-                    "dtype": (
-                        DType.uint(args.embedding_weight_bits)
-                        if args.embedding_weight_bits < 16
-                        else DType.int(args.embedding_weight_bits)
-                    ),
+            "model": {
+                "embed_tokens": {
+                    "weight": {
+                        "dtype": (
+                            DType.uint(args.embedding_weight_bits)
+                            if args.embedding_weight_bits < 16
+                            else DType.int(args.embedding_weight_bits)
+                        ),
+                    },
+                },
+                "layers": {},
+                "norm": {
+                    "weight": {"dtype": DType.int(16)},
                 },
             },
             "lm_head": {
@@ -240,17 +193,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
                     ),
                 },
             },
-            "model.norm": {
-                "weight": {"dtype": DType.int(16)},
-            },
         },
     )
     for i in range(len(q_m.model.layers)):
-        child_scope = f"layer {i}"
-        cfg.overrides[child_scope] = w_cfg  # type: ignore[index]
+        child_scope = f"{i}"
+        cfg.overrides["model"]["layers"][child_scope] = w_cfg  # type: ignore[index]

     qcfg = cfg
-    prepare(q_m, qcfg)
+    q_m = prepare(q_m, qcfg)

     # -------------------------------------------------------------------------
     # Single-pass activation calibration
@@ -260,6 +210,12 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     # Overwrite weight observers with GPTQ statistics
     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
         inject_gptq_qparams(q_m, q_m.quantizers)
+    elif (
+        hasattr(q_m, "wrapped")
+        and hasattr(q_m.wrapped, "quantizers")
+        and isinstance(q_m.wrapped.quantizers, dict)
+    ):
+        inject_gptq_qparams(q_m.wrapped, q_m.wrapped.quantizers)
     else:
         print(
             "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
@@ -276,91 +232,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     return q_m


-def fix_inputs(model, tokenizer, input_ids):
-    if tokenizer.pad_token_id is not None:
-        pads = torch.full(
-            (
-                input_ids.shape[0],
-                model.config.max_position_embeddings - input_ids.shape[1],
-            ),
-            fill_value=tokenizer.pad_token_id,
-            device=input_ids.device,
-        )
-    elif tokenizer.eos_token_id is not None:
-        pads = torch.full(
-            (
-                input_ids.shape[0],
-                model.config.max_position_embeddings - input_ids.shape[1],
-            ),
-            fill_value=tokenizer.eos_token_id,
-            device=input_ids.device,
-        )
-    else:
-        raise RuntimeError(
-            "failed to pad sequence - tokenizer doesn't have pad_token_id/eos_token_id"
-        )
-
-    return torch.cat((input_ids, pads), dim=1)
-
-
-class LLamaWithFixedInput(LlamaForCausalLM):
-    def __init__(self, parent: LlamaForCausalLM, tokenizer):
-        assert parent.config is not None, "config is a must have"
-        super().__init__(parent.config)
-        self.__dict__.update(parent.__dict__)
-
-        def forward(
-            self,
-            input_ids: torch.LongTensor = None,  # type: ignore[assignment]
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
-            cache_position: Optional[torch.LongTensor] = None,
-            logits_to_keep: Union[int, torch.Tensor] = 0,
-            **kwargs: Unpack[KwargsForCausalLM],
-        ) -> Union[Tuple, CausalLMOutputWithPast]:
-            # fixed input size, due to position_ids fixed
-            orig_len = input_ids.shape[-1]
-            input_ids = fix_inputs(self, self.tokenizer, input_ids)
-            if labels is not None:
-                labels = fix_inputs(self, self.tokenizer, labels)
-            res = super().forward(
-                input_ids,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                inputs_embeds,
-                labels,
-                use_cache,
-                output_attentions,
-                output_hidden_states,
-                return_dict,
-                cache_position,
-                logits_to_keep,
-                **kwargs,
-            )
-            # we need to trim to the original size
-            res.logits = res.logits[..., :orig_len, :]
-            return res
-
-        self.forward = types.MethodType(forward, self)
-        self.tokenizer = tokenizer
-
-
 def evaluate(q_m, tokenizer, dataset_test, args):
     # -------------------------------------------------------------------------
     # Evaluate perplexity on Wikitext-2
     # -------------------------------------------------------------------------
     print("\nCalculating perplexities …")
     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
     ppl_uint8 = perplexity(
-        q_m, enc, args.device, stride=q_m.config.max_position_embeddings
+        q_m, enc, args.device, stride=q_m.wrapped.config.max_position_embeddings
     )

     print("\n┌── Wikitext-2 test perplexity ─────────────")
@@ -576,7 +455,7 @@ def main():
     q_m = quantize_using_PTQ(q_m, calib_inputs, args)

     # after PTQ quantizer only fixed-length input sequences are valid
-    evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args)
+    evaluate(q_m, tokenizer, dataset_test, args)

     if args.save_circle_to_folder is not None:
         save_circles_to(q_m, calib_inputs, args.save_circle_to_folder)
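
Read together, the hunks above follow from one change: `prepare()` now returns a wrapper around the model, so downstream code reaches the original `LlamaForCausalLM` through its `.wrapped` attribute, and the PTQ `overrides` dict now mirrors the module hierarchy (`model.embed_tokens`, per-layer entries under `model.layers`, `model.norm`, `lm_head`) instead of flat dotted scopes. The sketch below is not part of the commit; it only restates the calling pattern visible in this diff, under the assumption that `prepare()` returns such a wrapper, and the helper name and `save_path` parameter are illustrative.

```python
# Sketch of the flow implied by this diff (assumption: prepare() returns a
# wrapper module whose .wrapped attribute is the original HF causal-LM model).
import torch
import tico
from tico.quantization import prepare


def ptq_and_export(q_m, calib_inputs, qcfg, save_path):
    # Keep the object returned by prepare(); it replaces the raw model below.
    q_m = prepare(q_m, qcfg)

    # ... single-pass activation calibration over calib_inputs goes here ...

    # Config lookups and circle export go through the wrapped original model.
    stride = q_m.wrapped.config.max_position_embeddings
    with torch.no_grad():
        cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)
    cm.save(save_path)
    return q_m, stride
```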