@@ -104,13 +104,30 @@ def inject_gptq_qparams(
 def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
     q_m.eval()
     q_m.cpu()
+
+    save_path = pathlib.Path(save_circle_to_folder, "model.q.circle")
+    print(f"saving the whole model to {save_path.resolve()}")
+    with torch.no_grad():
+        with SuppressWarning(UserWarning, ".*"):
+            cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)
+
+    cm.save(save_path)
+
+    save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle")
+    print(f"saving model.model to {save_path.resolve()}")
+    with torch.no_grad():
+        with SuppressWarning(UserWarning, ".*"):
+            cm = tico.convert(q_m.wrapped.model, (calib_inputs[0],), strict=False)
+
+    cm.save(save_path)
+
     save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle")
     pathlib.Path()
     print(f"saving input embedding to {save_path.resolve()}")
     with torch.no_grad():
         with SuppressWarning(UserWarning, ".*"):
             cm = tico.convert(
-                q_m.model.embed_tokens,
+                q_m.wrapped.model.wrapped.embed_tokens,
                 (calib_inputs[0],),
                 strict=False,
             )
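The attribute paths used throughout this function assume that PTQ wrapping leaves the module tree roughly as sketched below; the names are inferred from the paths in this file, and the exact wrapper type is whatever PTQWrapper produces:

    causal_lm = q_m.wrapped                # original LlamaForCausalLM (config, lm_head)
    decoder = q_m.wrapped.model.wrapped    # inner model holding embed_tokens, layers, norm
    layer_0 = decoder.layers[0].wrapped    # first decoder layer, unwrapped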
@@ -120,47 +137,42 @@ def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
     print(f"saving lm_head to {save_path.resolve()}")
     with torch.no_grad():
         with SuppressWarning(UserWarning, ".*"):
-            B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
+            B, S, D = (
+                1,
+                q_m.wrapped.config.max_position_embeddings,
+                q_m.wrapped.config.hidden_size,
+            )
             example_hidden = torch.randn(B, S, D)
             cm = tico.convert(
-                q_m.lm_head,
+                q_m.wrapped.lm_head,
                 (example_hidden,),
                 strict=False,
             )
     cm.save(save_path)

     print("saving layers")
-    for i in range(len(q_m.model.layers)):
+    for i in range(len(q_m.wrapped.model.wrapped.layers)):
         save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle")
         print(f"saving model layer_{i} to {save_path.resolve()}")
-        B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
+        B, S, D = (
+            1,
+            q_m.wrapped.config.max_position_embeddings,
+            q_m.wrapped.config.hidden_size,
+        )
         example_hidden = torch.randn(B, S, D)
+        cur_layer = q_m.wrapped.model.wrapped.layers[i].wrapped
+        if hasattr(cur_layer, "copy_quantizers"):
+            cur_layer.copy_quantizers(q_m.wrapped.model.wrapped)

         with torch.no_grad():
             with SuppressWarning(UserWarning, ".*"):
                 cm = tico.convert(
-                    q_m.model.layers[i],
+                    q_m.wrapped.model.wrapped.layers[i],
                     (example_hidden,),
                     strict=False,
                 )
         cm.save(save_path)

-    save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle")
-    print(f"saving model.model to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False)
-
-    cm.save(save_path)
-
-    save_path = pathlib.Path(save_circle_to_folder, "model.q.circle")
-    print(f"saving the whole model to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m, (calib_inputs[0],), strict=False)
-
-    cm.save(save_path)
-

 def quantize_using_PTQ(q_m, calib_inputs, args):
     print("Wrapping layers with PTQWrapper …")
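The export pattern above (torch.no_grad + SuppressWarning + tico.convert + cm.save) is repeated for every artifact; it could be factored into a small helper in the same spirit. A sketch, assuming the same tico, torch, pathlib, and SuppressWarning imports already used in this file:

    def export_circle(module, example_inputs, save_path):
        # Convert one (sub)module to a .circle file, mirroring the calls above.
        print(f"saving to {save_path.resolve()}")
        with torch.no_grad():
            with SuppressWarning(UserWarning, ".*"):
                cm = tico.convert(module, example_inputs, strict=False)
        cm.save(save_path)

For example, the embedding export would then read export_circle(q_m.wrapped.model.wrapped.embed_tokens, (calib_inputs[0],), pathlib.Path(save_circle_to_folder, "embedding.q.circle")).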
@@ -219,13 +231,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
         default_dtype=DType.int(16),
         default_qscheme=QScheme.PER_TENSOR_SYMM,
         overrides={
-            "model.embeddings": {
-                "weight": {
-                    "dtype": (
-                        DType.uint(args.embedding_weight_bits)
-                        if args.embedding_weight_bits < 16
-                        else DType.int(args.embedding_weight_bits)
-                    ),
+            "model": {
+                "embed_tokens": {
+                    "weight": {
+                        "dtype": (
+                            DType.uint(args.embedding_weight_bits)
+                            if args.embedding_weight_bits < 16
+                            else DType.int(args.embedding_weight_bits)
+                        ),
+                    },
+                },
+                "layers": {},
+                "norm": {
+                    "weight": {"dtype": DType.int(16)},
                 },
             },
             "lm_head": {
@@ -237,17 +255,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
                     ),
                 },
             },
-            "model.norm": {
-                "weight": {"dtype": DType.int(16)},
-            },
         },
     )
     for i in range(len(q_m.model.layers)):
-        child_scope = f"layer {i}"
-        cfg.overrides[child_scope] = w_cfg  # type: ignore[index]
+        child_scope = f"{i}"
+        cfg.overrides["model"]["layers"][child_scope] = w_cfg  # type: ignore[index]

     qcfg = cfg
-    prepare(q_m, qcfg)
+    q_m = prepare(q_m, qcfg)

     # -------------------------------------------------------------------------
     # Single-pass activation calibration
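With the nested overrides, per-module settings follow the module hierarchy (model.embed_tokens, model.layers.<i>, model.norm, lm_head) instead of flat dotted scopes, and the loop keys each decoder layer by its index under "model" -> "layers"; the result of prepare() is also assigned back to q_m, so later code operates on the prepared module. Roughly, for a hypothetical two-layer model with 4-bit embedding weights, cfg.overrides ends up shaped like the sketch below (w_cfg is the per-layer config built earlier in this function; the lm_head entry keeps whatever dtype the surrounding code chose):

    {
        "model": {
            "embed_tokens": {"weight": {"dtype": DType.uint(4)}},
            "layers": {"0": w_cfg, "1": w_cfg},
            "norm": {"weight": {"dtype": DType.int(16)}},
        },
        "lm_head": {"weight": {"dtype": ...}},
    }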
@@ -257,6 +272,12 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     # Overwrite weight observers with GPTQ statistics
     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
         inject_gptq_qparams(q_m, q_m.quantizers)
+    elif (
+        hasattr(q_m, "wrapped")
+        and hasattr(q_m.wrapped, "quantizers")
+        and isinstance(q_m.wrapped.quantizers, dict)
+    ):
+        inject_gptq_qparams(q_m.wrapped, q_m.wrapped.quantizers)
     else:
         print(
             "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
@@ -300,65 +321,14 @@ def fix_inputs(model, tokenizer, input_ids):
     return torch.cat((input_ids, pads), dim=1)


-class LLamaWithFixedInput(LlamaForCausalLM):
-
-    def __init__(self, parent: LlamaForCausalLM, tokenizer):
-        assert parent.config is not None, "config is a must have"
-        super(LlamaForCausalLM, self).__init__(parent.config)
-        self.__dict__.update(parent.__dict__)
-
-        def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
-            cache_position: Optional[torch.LongTensor] = None,
-            logits_to_keep: Union[int, torch.Tensor] = 0,
-            **kwargs: Unpack[KwargsForCausalLM],
-        ) -> Union[Tuple, CausalLMOutputWithPast]:
-            # fixed input size, due to position_ids fixed
-            orig_len = input_ids.shape[-1]
-            input_ids = fix_inputs(self, self.tokenizer, input_ids)
-            if labels is not None:
-                labels = fix_inputs(self, self.tokenizer, labels)
-            res = super().forward(
-                input_ids,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                inputs_embeds,
-                labels,
-                use_cache,
-                output_attentions,
-                output_hidden_states,
-                return_dict,
-                cache_position,
-                logits_to_keep,
-                **kwargs,
-            )
-            # we need to trim to the original size
-            res.logits = res.logits[..., :orig_len, :]
-            return res
-
-        self.forward = types.MethodType(forward, self)
-        self.tokenizer = tokenizer
-
-
 def evaluate(q_m, tokenizer, dataset_test, args):
     # -------------------------------------------------------------------------
     # Evaluate perplexity on Wikitext-2
     # -------------------------------------------------------------------------
     print("\nCalculating perplexities …")
     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
     ppl_uint8 = perplexity(
-        q_m, enc, args.device, stride=q_m.config.max_position_embeddings
+        q_m, enc, args.device, stride=q_m.wrapped.config.max_position_embeddings
     )

     print("\n┌── Wikitext-2 test perplexity ─────────────")
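The perplexity() helper called here is defined or imported elsewhere in the file (not shown in this diff); the call passes stride = max_position_embeddings, i.e. non-overlapping windows of the model's maximum sequence length. A minimal sketch of that standard strided evaluation scheme, as an illustration only and not the repo's implementation:

    def strided_perplexity(model, enc, device, stride):
        # Split the encoded corpus into non-overlapping windows of `stride` tokens,
        # accumulate token-weighted NLL, then exponentiate the average.
        input_ids = enc["input_ids"]
        nlls, n_tokens = [], 0
        for begin in range(0, input_ids.size(1), stride):
            ids = input_ids[:, begin : begin + stride].to(device)
            if ids.size(1) < 2:  # a one-token tail has no shifted targets
                break
            with torch.no_grad():
                loss = model(ids, labels=ids.clone()).loss  # mean NLL per token
            nlls.append(loss * ids.size(1))
            n_tokens += ids.size(1)
        return torch.exp(torch.stack(nlls).sum() / n_tokens)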
@@ -564,7 +534,7 @@ def main():
     q_m = quantize_using_PTQ(q_m, calib_inputs, args)

     # after PTQ quantizer only fixed-length input sequences are valid
-    evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args)
+    evaluate(q_m, tokenizer, dataset_test, args)

     if args.save_circle_to_folder is not None:
         save_circles_to(q_m, calib_inputs, args.save_circle_to_folder)