import torch
import torch.nn as nn

from tico.quantization.algorithm.fpi_gptq.util import iterate_GPTQ

2526def quantize (x , scale , zero , maxq ):
2627 if maxq < 0 :
@@ -41,11 +42,12 @@ def configure(
4142 bits ,
4243 perchannel = False ,
4344 sym = True ,
44- mse = False ,
45+ mse = None ,
4546 norm = 2.4 ,
4647 grid = 100 ,
4748 maxshrink = 0.8 ,
4849 trits = False ,
50+ sensitivity = None ,
4951 ):
5052 self .maxq = torch .tensor (2 ** bits - 1 )
5153 self .perchannel = perchannel
@@ -54,6 +56,7 @@ def configure(
5456 self .norm = norm
5557 self .grid = grid
5658 self .maxshrink = maxshrink
59+ self .sensitivity = sensitivity
5760 if trits :
5861 self .maxq = torch .tensor (- 1 )
5962
@@ -99,7 +102,10 @@ def find_params(self, x, weight=False):
99102 else :
100103 self .zero = torch .round (- xmin / self .scale )
101104
102- if self .mse :
105+ if self .mse is not None and self .mse != "smse_for_gptq" :
106+ if self .mse == "smse" :
107+ self .maxshrink = 0.5
108+
103109 best = torch .full ([x .shape [0 ]], float ("inf" ), device = dev )
104110 for i in range (int (self .maxshrink * self .grid )):
105111 p = 1 - i / self .grid
@@ -110,13 +116,19 @@ def find_params(self, x, weight=False):
110116 q = quantize (x , scale1 .unsqueeze (1 ), zero1 .unsqueeze (1 ), self .maxq )
111117 q -= x
112118 q .abs_ ()
113- q .pow_ (self .norm )
119+ if self .mse == "smse" :
120+ q = (q ** 2 ) * self .sensitivity .to (
121+ q .device
122+ ) # sensitivity weighted `mse`
123+ else :
124+ q .pow_ (self .norm )
114125 err = torch .sum (q , 1 )
115126 tmp = err < best
116127 if torch .any (tmp ):
117128 best [tmp ] = err [tmp ]
118129 self .scale [tmp ] = scale1 [tmp ]
119130 self .zero [tmp ] = zero1 [tmp ]
131+
120132 if not self .perchannel :
121133 if weight :
122134 tmp = shape [0 ]
@@ -141,6 +153,83 @@ def find_params(self, x, weight=False):
141153 self .scale = self .scale .unsqueeze (0 )
142154 self .zero = self .zero .unsqueeze (0 )
143155
156+ def update (self , x , Hinv , perm ):
157+ if self .mse is None or (
158+ self .mse != "smse_for_gptq" and self .mse != "mse_for_gptq"
159+ ):
160+ return
161+
162+ shape = x .shape
163+ if self .perchannel :
164+ x = x .flatten (1 )
165+ else :
166+ x = x .flatten ().unsqueeze (0 )
167+
168+ dev = x .device
169+ tmp = torch .zeros (x .shape [0 ], device = dev )
170+ xmin = torch .minimum (x .min (1 )[0 ], tmp )
171+ xmax = torch .maximum (x .max (1 )[0 ], tmp )
172+
173+ if self .sym :
174+ xmax = torch .maximum (torch .abs (xmin ), xmax )
175+ tmp = xmin < 0
176+ if torch .any (tmp ):
177+ xmin [tmp ] = - xmax [tmp ]
178+ tmp = (xmin == 0 ) & (xmax == 0 )
179+ xmin [tmp ] = - 1
180+ xmax [tmp ] = + 1
181+ if self .maxq < 0 :
182+ self .scale = xmax
183+ self .zero = xmin
184+ else :
185+ self .scale = (xmax - xmin ) / self .maxq
186+ if self .sym :
187+ self .zero = torch .full_like (self .scale , (self .maxq + 1 ) / 2 ) # type: ignore[arg-type]
188+ else :
189+ self .zero = torch .round (- xmin / self .scale )
190+
191+ self .maxshrink = 0.5
192+ sensitivity = None
193+ if self .sensitivity is not None :
194+ sensitivity = self .sensitivity .to (Hinv .dtype ).to (dev )
195+ if perm is not None :
196+ sensitivity = sensitivity [:, perm .to (dev )]
197+
198+ num_of_iters = 15
199+ best = torch .full ([x .shape [0 ]], float ("inf" ), device = dev )
200+ for i in range (int (self .maxshrink * self .grid )):
201+ p = 1 - i / self .grid
202+ xmin1 = p * xmin
203+ xmax1 = p * xmax
204+ scale1 = (xmax1 - xmin1 ) / self .maxq
205+ zero1 = torch .round (- xmin1 / scale1 ) if not self .sym else self .zero
206+ q , pre_q = iterate_GPTQ (
207+ scale1 .unsqueeze (1 ),
208+ zero1 .unsqueeze (1 ),
209+ self .maxq ,
210+ x ,
211+ Hinv ,
212+ max_num_of_iters = num_of_iters ,
213+ )
214+ if sensitivity is not None :
215+ assert self .mse == "smse_for_gptq"
216+ err = ((q - pre_q ) ** 2 ) * sensitivity .to (q .device )
217+ else :
218+ assert self .mse == "mse_for_gptq"
219+ # err = torch.abs((q - pre_q)).pow_(self.norm)
220+ err = ((q - pre_q ) / torch .diag (Hinv )) ** 2
221+ err = err
222+ err = torch .sum (err , 1 )
223+ tmp = err < best
224+ if torch .any (tmp ):
225+ best [tmp ] = err [tmp ]
226+ self .scale [tmp ] = scale1 [tmp ]
227+ self .zero [tmp ] = zero1 [tmp ]
228+
229+ shape = [- 1 ] + [1 ] * (len (shape ) - 1 )
230+ self .scale = self .scale .reshape (shape )
231+ self .zero = self .zero .reshape (shape )
232+
144233 def quantize (self , x ):
145234 if self .ready ():
146235 return quantize (x , self .scale , self .zero , self .maxq )
0 commit comments