
Commit 6ac97c0

[quantization] Quantization of Llama
This PR quantizes the full `Llama` model and converts it to the circle format.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 025374e commit 6ac97c0

8 files changed

Lines changed: 78 additions & 169 deletions
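
For orientation, the sketch below shows the end-to-end flow the updated example script (`quantize_full_qmodel_with_gptq.py`) follows after this commit: wrap the model with `prepare`, run a calibration pass, then export the wrapped module to a `.circle` file with `tico.convert`. It is a minimal sketch, not part of the diff: the checkpoint name, calibration text, and output folder are illustrative, and the construction of `qcfg` (shown in the `quantize_using_PTQ` hunks below) is elided.

# Hedged sketch (not part of the commit). The checkpoint, calibration text,
# output folder, and the qcfg placeholder are illustrative assumptions.
import pathlib

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

import tico
from tico.quantization import prepare

model_name = "meta-llama/Llama-3.2-1B"  # placeholder checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# qcfg: quantization config built as in quantize_using_PTQ() below
# (default int16 per-tensor symmetric, with per-scope overrides); elided here.
qcfg = ...

# After PTQ only fixed-length input sequences are valid, so calibration batches
# should be padded/trimmed to max_position_embeddings (see the example script).
ids = tokenizer("example calibration text", return_tensors="pt").input_ids
calib_inputs = [ids]

q_m = prepare(model, qcfg)          # returns the PTQ-wrapped model
with torch.no_grad():
    for batch in calib_inputs:      # single-pass activation calibration (assumed call)
        q_m(batch)

out_dir = pathlib.Path("circle_out")
out_dir.mkdir(exist_ok=True)
with torch.no_grad():
    cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)
cm.save(out_dir / "model.q.circle")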


tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Lines changed: 25 additions & 146 deletions
@@ -39,11 +39,6 @@
 from lm_eval.utils import make_table
 from transformers import AutoModelForCausalLM, AutoTokenizer

-from transformers.cache_utils import Cache
-from transformers.modeling_outputs import CausalLMOutputWithPast
-from transformers.models.llama.modeling_llama import KwargsForCausalLM, LlamaForCausalLM
-from transformers.processing_utils import Unpack
-
 import tico

 from tico.quantization import convert, prepare
@@ -107,60 +102,12 @@ def inject_gptq_qparams(
 def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
     q_m.eval()
     q_m.cpu()
-    save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle")
-    pathlib.Path()
-    print(f"saving input embedding to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(
-                q_m.model.embed_tokens,
-                (calib_inputs[0],),
-                strict=False,
-            )
-    cm.save(save_path)
-
-    save_path = pathlib.Path(save_circle_to_folder, "lm_head.q.circle")
-    print(f"saving lm_head to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
-            example_hidden = torch.randn(B, S, D)
-            cm = tico.convert(
-                q_m.lm_head,
-                (example_hidden,),
-                strict=False,
-            )
-    cm.save(save_path)
-
-    print("saving layers")
-    for i in range(len(q_m.model.layers)):
-        save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle")
-        print(f"saving model layer_{i} to {save_path.resolve()}")
-        B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
-        example_hidden = torch.randn(B, S, D)
-
-        with torch.no_grad():
-            with SuppressWarning(UserWarning, ".*"):
-                cm = tico.convert(
-                    q_m.model.layers[i],
-                    (example_hidden,),
-                    strict=False,
-                )
-        cm.save(save_path)
-
-    save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle")
-    print(f"saving model.model to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False)
-
-    cm.save(save_path)

     save_path = pathlib.Path(save_circle_to_folder, "model.q.circle")
     print(f"saving the whole model to {save_path.resolve()}")
     with torch.no_grad():
         with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m, (calib_inputs[0],), strict=False)
+            cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)

     cm.save(save_path)

@@ -222,13 +169,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
         default_dtype=DType.int(16),
         default_qscheme=QScheme.PER_TENSOR_SYMM,
         overrides={
-            "model.embeddings": {
-                "weight": {
-                    "dtype": (
-                        DType.uint(args.embedding_weight_bits)
-                        if args.embedding_weight_bits < 16
-                        else DType.int(args.embedding_weight_bits)
-                    ),
+            "model": {
+                "embed_tokens": {
+                    "weight": {
+                        "dtype": (
+                            DType.uint(args.embedding_weight_bits)
+                            if args.embedding_weight_bits < 16
+                            else DType.int(args.embedding_weight_bits)
+                        ),
+                    },
+                },
+                "layers": {},
+                "norm": {
+                    "weight": {"dtype": DType.int(16)},
                 },
             },
             "lm_head": {
@@ -240,17 +193,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
                     ),
                 },
             },
-            "model.norm": {
-                "weight": {"dtype": DType.int(16)},
-            },
         },
     )
     for i in range(len(q_m.model.layers)):
-        child_scope = f"layer{i}"
-        cfg.overrides[child_scope] = w_cfg  # type: ignore[index]
+        child_scope = f"{i}"
+        cfg.overrides["model"]["layers"][child_scope] = w_cfg  # type: ignore[index]

     qcfg = cfg
-    prepare(q_m, qcfg)
+    q_m = prepare(q_m, qcfg)

     # -------------------------------------------------------------------------
     # Single-pass activation calibration
@@ -260,6 +210,12 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     # Overwrite weight observers with GPTQ statistics
     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
        inject_gptq_qparams(q_m, q_m.quantizers)
+    elif (
+        hasattr(q_m, "wrapped")
+        and hasattr(q_m.wrapped, "quantizers")
+        and isinstance(q_m.wrapped.quantizers, dict)
+    ):
+        inject_gptq_qparams(q_m.wrapped, q_m.wrapped.quantizers)
     else:
         print(
             "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
@@ -276,91 +232,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     return q_m


-def fix_inputs(model, tokenizer, input_ids):
-    if tokenizer.pad_token_id is not None:
-        pads = torch.full(
-            (
-                input_ids.shape[0],
-                model.config.max_position_embeddings - input_ids.shape[1],
-            ),
-            fill_value=tokenizer.pad_token_id,
-            device=input_ids.device,
-        )
-    elif tokenizer.eos_token_id is not None:
-        pads = torch.full(
-            (
-                input_ids.shape[0],
-                model.config.max_position_embeddings - input_ids.shape[1],
-            ),
-            fill_value=tokenizer.eos_token_id,
-            device=input_ids.device,
-        )
-    else:
-        raise RuntimeError(
-            "failed to pad sequence - tokenizer doesn't have pad_token_id/eos_token_id"
-        )
-
-    return torch.cat((input_ids, pads), dim=1)
-
-
-class LLamaWithFixedInput(LlamaForCausalLM):
-    def __init__(self, parent: LlamaForCausalLM, tokenizer):
-        assert parent.config is not None, "config is a must have"
-        super().__init__(parent.config)
-        self.__dict__.update(parent.__dict__)
-
-        def forward(
-            self,
-            input_ids: torch.LongTensor = None,  # type: ignore[assignment]
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
-            cache_position: Optional[torch.LongTensor] = None,
-            logits_to_keep: Union[int, torch.Tensor] = 0,
-            **kwargs: Unpack[KwargsForCausalLM],
-        ) -> Union[Tuple, CausalLMOutputWithPast]:
-            # fixed input size, due to position_ids fixed
-            orig_len = input_ids.shape[-1]
-            input_ids = fix_inputs(self, self.tokenizer, input_ids)
-            if labels is not None:
-                labels = fix_inputs(self, self.tokenizer, labels)
-            res = super().forward(
-                input_ids,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                inputs_embeds,
-                labels,
-                use_cache,
-                output_attentions,
-                output_hidden_states,
-                return_dict,
-                cache_position,
-                logits_to_keep,
-                **kwargs,
-            )
-            # we need to trim to the original size
-            res.logits = res.logits[..., :orig_len, :]
-            return res
-
-        self.forward = types.MethodType(forward, self)
-        self.tokenizer = tokenizer
-
-
 def evaluate(q_m, tokenizer, dataset_test, args):
     # -------------------------------------------------------------------------
     # Evaluate perplexity on Wikitext-2
     # -------------------------------------------------------------------------
     print("\nCalculating perplexities …")
     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
     ppl_uint8 = perplexity(
-        q_m, enc, args.device, stride=q_m.config.max_position_embeddings
+        q_m, enc, args.device, stride=q_m.wrapped.config.max_position_embeddings
     )

     print("\n┌── Wikitext-2 test perplexity ─────────────")
@@ -576,7 +455,7 @@ def main():
     q_m = quantize_using_PTQ(q_m, calib_inputs, args)

     # after PTQ quantizer only fixed-length input sequences are valid
-    evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args)
+    evaluate(q_m, tokenizer, dataset_test, args)

     if args.save_circle_to_folder is not None:
         save_circles_to(q_m, calib_inputs, args.save_circle_to_folder)
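
A quick illustration of the override restructuring in `quantize_using_PTQ` above: scopes now mirror the module tree (`model` → `embed_tokens` / `layers` / `norm`) instead of flat dotted names such as `"model.embeddings"`. The sketch below is illustrative only; dtype values are shown as strings with comments because the `DType` import path is not visible in this diff, and the `lm_head` override is only partially shown in these hunks.

# Illustrative only: shape of the nested overrides mapping after this commit.
overrides = {
    "model": {
        "embed_tokens": {"weight": {"dtype": "uint<N>"}},  # DType.uint(args.embedding_weight_bits)
        "layers": {},   # filled per layer: cfg.overrides["model"]["layers"][f"{i}"] = w_cfg
        "norm": {"weight": {"dtype": "int16"}},            # DType.int(16)
    },
    "lm_head": {"weight": {"dtype": "..."}},               # exact dtype expression not shown in these hunks
}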

tico/quantization/wrapq/examples/quantize_with_gptq.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,6 @@
 from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
 from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase

-
 # Token-budget presets for activation calibration
 TOKENS: dict[str, int] = {
     # Smoke test (<1 min turnaround on CPU/GPU)
@@ -66,6 +65,7 @@
 TRAIN_SPLIT = "train"
 TEST_SPLIT = "test"

+
 # -------------------------------------------------------------------------
 # 1. Helper — copy GPTQ (scale, zp) into PTQ observers
 # -------------------------------------------------------------------------

tico/quantization/wrapq/quantizer.py

Lines changed: 12 additions & 2 deletions
@@ -81,12 +81,18 @@ def _wrap_supported(
     Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
     """
     assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+    try:
+        return PTQWrapper(root, qcfg=qcfg, fp_name="model")
+    except NotImplementedError as e:
+        print("no special wrapper for model, wrapping using general case")

     # Case A: HuggingFace-style transformers: model.model.layers
     lm = getattr(root, "model", None)

     embeddings = (
-        getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None
+        getattr(lm, "embed_tokens", None)
+        if isinstance(lm.embed_tokens, nn.Module)  # type: ignore[union-attr]
+        else None
     )
     if isinstance(embeddings, nn.Module):
         child_scope = "model.embeddings"
@@ -99,7 +105,11 @@ def _wrap_supported(
         )
         lm.embed_tokens = wrapped  # type: ignore[union-attr]

-    model_norm = getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None
+    model_norm = (
+        getattr(lm, "norm", None)
+        if isinstance(lm.norm, nn.Module)  # type: ignore[union-attr]
+        else None
+    )
     if isinstance(model_norm, nn.Module):
         child_scope = "model.norm"
         child_cfg = qcfg.child(child_scope)

tico/quantization/wrapq/utils/metrics.py

Lines changed: 9 additions & 4 deletions
@@ -90,10 +90,15 @@ def perplexity(
     input_ids_full = input_ids_full.to(device)

     if max_length is None:
-        assert hasattr(model, "config")
-        model_config = model.config
-        if hasattr(model.config, "text_config"):
-            model_config = model.config.text_config
+        if hasattr(model, "config"):
+            assert hasattr(model, "config")
+            model_config = model.config
+        else:
+            assert hasattr(model.wrapped, "config")
+            model_config = model.wrapped.config
+
+        if hasattr(model_config, "text_config"):
+            model_config = model_config.text_config
         assert hasattr(model_config, "max_position_embeddings")
         assert isinstance(model_config.max_position_embeddings, int)
         max_length = model_config.max_position_embeddings

tico/quantization/wrapq/wrappers/llama/quant_attn_prefill.py

Lines changed: 3 additions & 4 deletions
@@ -191,8 +191,6 @@ def forward(

         # Rope tables
         cos, sin = position_embeddings
-        cos = self._fq(cos, self.obs_cos)
-        sin = self._fq(sin, self.obs_sin)

         # --- KV for attention & present_key_value -------------
         present_key_value: Tuple[torch.Tensor, torch.Tensor]
@@ -205,7 +203,7 @@
         attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
             hidden_states.device
         )
-        attention_mask = self._fq(attention_mask, self.obs_causal_mask)
+        attention_mask = self._fq(attention_mask, self.obs_causal_mask)

         attn_weights_parts = []
         attn_out_parts = []
@@ -251,8 +249,9 @@
             logits_i = self._fq(q_i @ k_i.transpose(-2, -1), self.obs_logits)

             # mask add
+            assert attention_mask.shape == logits_i.shape  # check for compatibility
             logits_i = self._fq(
-                logits_i + attention_mask.view(1, q_i.size(1), k_i.size(1)),
+                logits_i + attention_mask,
                 self.obs_mask_add,
             )

tico/quantization/wrapq/wrappers/llama/quant_decoder_layer_prefill.py

Lines changed: 26 additions & 12 deletions
@@ -107,6 +107,9 @@ def __init__(
             qcfg=post_attention_layernorm,
             fp_name=f"{fp_name}.post_attention_layernorm",
         )
+        self.obs_causal_mask = self._make_obs("causal_mask")
+        self.obs_cos = self._make_obs("cos")
+        self.obs_sin = self._make_obs("sin")

         # Static causal mask template ---------------------------------------
         assert hasattr(fp_layer.self_attn, "config") and hasattr(
@@ -184,18 +187,28 @@ def forward(
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

-        # to prevent introduction of attention_mask as a parameter let's use preset attention_mask
-        L = hidden_states.size(1)
-        attention_mask = self._slice_causal(L, hidden_states.device)
-
-        position_embeddings = (
-            self.rope_cos_template.to(
-                dtype=hidden_states.dtype, device=hidden_states.device
-            ),
-            self.rope_sin_template.to(
-                dtype=hidden_states.dtype, device=hidden_states.device
-            ),
-        )
+        if attention_mask is None or attention_mask.dtype == torch.bool:
+            L = hidden_states.size(1)
+            attention_mask = self._slice_causal(L, hidden_states.device)
+            attention_mask = attention_mask.squeeze(0)
+            attention_mask = self.fq(
+                attention_mask, self.obs_causal_mask
+            )  # let it be quantized immediately
+
+        if position_embeddings is None:
+            position_embeddings = (
+                self.rope_cos_template.to(
+                    dtype=hidden_states.dtype, device=hidden_states.device
+                ),
+                self.rope_sin_template.to(
+                    dtype=hidden_states.dtype, device=hidden_states.device
+                ),
+            )
+        cos, sin = position_embeddings
+        position_embeddings = (
+            self._fq(cos, self.obs_cos),
+            self._fq(sin, self.obs_sin),
+        )

         attn_out = self.self_attn(
             hidden_states=hidden_states,
@@ -241,6 +254,7 @@ def forward(

     # No local observers; just recurse into children
     def _all_observers(self):
+        yield from (self.obs_causal_mask, self.obs_cos, self.obs_sin)
         yield from self.self_attn._all_observers()
         yield from self.mlp._all_observers()
         yield self.obs_mlp_residual_out
