Commit 19aca2b

[DRAFT] Improvements in disk space
This PR fixes the population of the static `causal_masks`/`position_embeddings` through the layers to save disk space.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 47f09ed commit 19aca2b
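For orientation: the intent is that the static tensors (the causal-mask template and the rotary cos/sin tables) are prepared and fake-quantized once at the model level and then handed to every decoder layer, rather than each layer rebuilding its own copy, so the exported per-layer circle files no longer duplicate them. A minimal sketch of that pattern (hypothetical helper and layer signature, not code from this commit):

import torch
import torch.nn as nn
from typing import Iterable, Tuple

def run_decoder_stack(
    hidden_states: torch.Tensor,
    layers: Iterable[nn.Module],
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    causal_mask_template: torch.Tensor,
) -> torch.Tensor:
    # The caller builds (and quantizes) the rotary tables and the mask template
    # once; each layer only receives references to the shared tensors.
    seq_len = hidden_states.shape[1]
    causal_mask = causal_mask_template[..., :seq_len, :seq_len]
    for layer in layers:
        hidden_states = layer(
            hidden_states,
            attention_mask=causal_mask,
            position_embeddings=position_embeddings,
        )
    return hidden_states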

9 files changed

Lines changed: 550 additions & 111 deletions

tico/quantization/wrapq/examples/quantize_full_qmodel_with_gptq.py

Lines changed: 58 additions & 88 deletions
@@ -104,13 +104,30 @@ def inject_gptq_qparams(
 def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
     q_m.eval()
     q_m.cpu()
+
+    save_path = pathlib.Path(save_circle_to_folder, "model.q.circle")
+    print(f"saving the whole model to {save_path.resolve()}")
+    with torch.no_grad():
+        with SuppressWarning(UserWarning, ".*"):
+            cm = tico.convert(q_m.wrapped, (calib_inputs[0],), strict=False)
+
+    cm.save(save_path)
+
+    save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle")
+    print(f"saving model.model to {save_path.resolve()}")
+    with torch.no_grad():
+        with SuppressWarning(UserWarning, ".*"):
+            cm = tico.convert(q_m.wrapped.model, (calib_inputs[0],), strict=False)
+
+    cm.save(save_path)
+
     save_path = pathlib.Path(save_circle_to_folder, "embedding.q.circle")
     pathlib.Path()
     print(f"saving input embedding to {save_path.resolve()}")
     with torch.no_grad():
         with SuppressWarning(UserWarning, ".*"):
             cm = tico.convert(
-                q_m.model.embed_tokens,
+                q_m.wrapped.model.wrapped.embed_tokens,
                 (calib_inputs[0],),
                 strict=False,
             )
@@ -120,47 +137,42 @@ def save_circles_to(q_m, calib_inputs, save_circle_to_folder):
     print(f"saving lm_head to {save_path.resolve()}")
     with torch.no_grad():
         with SuppressWarning(UserWarning, ".*"):
-            B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
+            B, S, D = (
+                1,
+                q_m.wrapped.config.max_position_embeddings,
+                q_m.wrapped.config.hidden_size,
+            )
             example_hidden = torch.randn(B, S, D)
             cm = tico.convert(
-                q_m.lm_head,
+                q_m.wrapped.lm_head,
                 (example_hidden,),
                 strict=False,
             )
             cm.save(save_path)

     print("saving layers")
-    for i in range(len(q_m.model.layers)):
+    for i in range(len(q_m.wrapped.model.wrapped.layers)):
         save_path = pathlib.Path(save_circle_to_folder, f"decoder_layer_{i}.q.circle")
         print(f"saving model layer_{i} to {save_path.resolve()}")
-        B, S, D = 1, q_m.config.max_position_embeddings, q_m.config.hidden_size
+        B, S, D = (
+            1,
+            q_m.wrapped.config.max_position_embeddings,
+            q_m.wrapped.config.hidden_size,
+        )
         example_hidden = torch.randn(B, S, D)
+        cur_layer = q_m.wrapped.model.wrapped.layers[i].wrapped
+        if hasattr(cur_layer, "copy_quantizers"):
+            cur_layer.copy_quantizers(q_m.wrapped.model.wrapped)

         with torch.no_grad():
             with SuppressWarning(UserWarning, ".*"):
                 cm = tico.convert(
-                    q_m.model.layers[i],
+                    q_m.wrapped.model.wrapped.layers[i],
                     (example_hidden,),
                     strict=False,
                 )
                 cm.save(save_path)

-    save_path = pathlib.Path(save_circle_to_folder, "model.model.q.circle")
-    print(f"saving model.model to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m.model, (calib_inputs[0],), strict=False)
-
-    cm.save(save_path)
-
-    save_path = pathlib.Path(save_circle_to_folder, "model.q.circle")
-    print(f"saving the whole model to {save_path.resolve()}")
-    with torch.no_grad():
-        with SuppressWarning(UserWarning, ".*"):
-            cm = tico.convert(q_m, (calib_inputs[0],), strict=False)
-
-    cm.save(save_path)
-

 def quantize_using_PTQ(q_m, calib_inputs, args):
     print("Wrapping layers with PTQWrapper …")
@@ -219,13 +231,19 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
         default_dtype=DType.int(16),
         default_qscheme=QScheme.PER_TENSOR_SYMM,
         overrides={
-            "model.embeddings": {
-                "weight": {
-                    "dtype": (
-                        DType.uint(args.embedding_weight_bits)
-                        if args.embedding_weight_bits < 16
-                        else DType.int(args.embedding_weight_bits)
-                    ),
+            "model": {
+                "embed_tokens": {
+                    "weight": {
+                        "dtype": (
+                            DType.uint(args.embedding_weight_bits)
+                            if args.embedding_weight_bits < 16
+                            else DType.int(args.embedding_weight_bits)
+                        ),
+                    },
+                },
+                "layers": {},
+                "norm": {
+                    "weight": {"dtype": DType.int(16)},
                 },
             },
             "lm_head": {
@@ -237,17 +255,14 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
                     ),
                 },
             },
-            "model.norm": {
-                "weight": {"dtype": DType.int(16)},
-            },
         },
     )
     for i in range(len(q_m.model.layers)):
-        child_scope = f"layer{i}"
-        cfg.overrides[child_scope] = w_cfg  # type: ignore[index]
+        child_scope = f"{i}"
+        cfg.overrides["model"]["layers"][child_scope] = w_cfg  # type: ignore[index]

     qcfg = cfg
-    prepare(q_m, qcfg)
+    q_m = prepare(q_m, qcfg)

     # -------------------------------------------------------------------------
     # Single-pass activation calibration
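Taken together, the override scopes now mirror the module tree (model → embed_tokens / layers / norm), per-layer entries are keyed by the layer index, and `prepare()` returns the wrapped model instead of mutating it in place. A condensed usage sketch of the pattern above, reusing the script's imports and its `w_cfg` per-layer weight config (the embedding dtype shown is just an example value):

cfg = PTQConfig(
    default_dtype=DType.int(16),
    default_qscheme=QScheme.PER_TENSOR_SYMM,
    overrides={
        "model": {
            "embed_tokens": {"weight": {"dtype": DType.uint(8)}},  # example value
            "layers": {},                                          # filled per layer below
            "norm": {"weight": {"dtype": DType.int(16)}},
        },
    },
)
for i in range(len(q_m.model.layers)):
    # scopes follow the module path: model.layers.<index>
    cfg.overrides["model"]["layers"][f"{i}"] = w_cfg

# prepare() now returns the PTQ-wrapped model, so the caller must rebind q_m.
q_m = prepare(q_m, cfg)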
@@ -257,6 +272,12 @@ def quantize_using_PTQ(q_m, calib_inputs, args):
     # Overwrite weight observers with GPTQ statistics
     if hasattr(q_m, "quantizers") and isinstance(q_m.quantizers, dict):
         inject_gptq_qparams(q_m, q_m.quantizers)
+    elif (
+        hasattr(q_m, "wrapped")
+        and hasattr(q_m.wrapped, "quantizers")
+        and isinstance(q_m.wrapped.quantizers, dict)
+    ):
+        inject_gptq_qparams(q_m.wrapped, q_m.wrapped.quantizers)
     else:
         print(
             "[Warn] q_m.quantizers not found or not a dict; skipping GPTQ qparam injection."
@@ -300,65 +321,14 @@ def fix_inputs(model, tokenizer, input_ids):
     return torch.cat((input_ids, pads), dim=1)


-class LLamaWithFixedInput(LlamaForCausalLM):
-
-    def __init__(self, parent: LlamaForCausalLM, tokenizer):
-        assert parent.config is not None, "config is a must have"
-        super(LlamaForCausalLM, self).__init__(parent.config)
-        self.__dict__.update(parent.__dict__)
-
-        def forward(
-            self,
-            input_ids: torch.LongTensor = None,
-            attention_mask: Optional[torch.Tensor] = None,
-            position_ids: Optional[torch.LongTensor] = None,
-            past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
-            inputs_embeds: Optional[torch.FloatTensor] = None,
-            labels: Optional[torch.LongTensor] = None,
-            use_cache: Optional[bool] = None,
-            output_attentions: Optional[bool] = None,
-            output_hidden_states: Optional[bool] = None,
-            return_dict: Optional[bool] = None,
-            cache_position: Optional[torch.LongTensor] = None,
-            logits_to_keep: Union[int, torch.Tensor] = 0,
-            **kwargs: Unpack[KwargsForCausalLM],
-        ) -> Union[Tuple, CausalLMOutputWithPast]:
-            # fixed input size, due to position_ids fixed
-            orig_len = input_ids.shape[-1]
-            input_ids = fix_inputs(self, self.tokenizer, input_ids)
-            if labels is not None:
-                labels = fix_inputs(self, self.tokenizer, labels)
-            res = super().forward(
-                input_ids,
-                attention_mask,
-                position_ids,
-                past_key_values,
-                inputs_embeds,
-                labels,
-                use_cache,
-                output_attentions,
-                output_hidden_states,
-                return_dict,
-                cache_position,
-                logits_to_keep,
-                **kwargs,
-            )
-            # we need to trim to the original size
-            res.logits = res.logits[..., :orig_len, :]
-            return res
-
-        self.forward = types.MethodType(forward, self)
-        self.tokenizer = tokenizer
-
-
 def evaluate(q_m, tokenizer, dataset_test, args):
     # -------------------------------------------------------------------------
     # Evaluate perplexity on Wikitext-2
     # -------------------------------------------------------------------------
     print("\nCalculating perplexities …")
     enc = tokenizer("\n\n".join(dataset_test["text"]), return_tensors="pt")
     ppl_uint8 = perplexity(
-        q_m, enc, args.device, stride=q_m.config.max_position_embeddings
+        q_m, enc, args.device, stride=q_m.wrapped.config.max_position_embeddings
     )

     print("\n┌── Wikitext-2 test perplexity ─────────────")
@@ -564,7 +534,7 @@ def main():
     q_m = quantize_using_PTQ(q_m, calib_inputs, args)

     # after PTQ quantizer only fixed-length input sequences are valid
-    evaluate(LLamaWithFixedInput(q_m, tokenizer), tokenizer, dataset_test, args)
+    evaluate(q_m, tokenizer, dataset_test, args)

     if args.save_circle_to_folder is not None:
         save_circles_to(q_m, calib_inputs, args.save_circle_to_folder)

tico/quantization/wrapq/examples/quantize_with_gptq.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,6 @@
 from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
 from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase

-
 # Token-budget presets for activation calibration
 TOKENS: dict[str, int] = {
     # Smoke test (<1 min turnaround on CPU/GPU)
@@ -65,6 +64,7 @@
 TRAIN_SPLIT = "train"
 TEST_SPLIT = "test"

+
 # -------------------------------------------------------------------------
 # 1. Helper — copy GPTQ (scale, zp) into PTQ observers
 # -------------------------------------------------------------------------

tico/quantization/wrapq/quantizer.py

Lines changed: 12 additions & 2 deletions
@@ -81,12 +81,18 @@ def _wrap_supported(
     Recursively attempt to wrap boundaries. Strictness is applied at every boundary.
     """
     assert not isinstance(root, QuantModuleBase), "The module is already wrapped."
+    try:
+        return PTQWrapper(root, qcfg=qcfg, fp_name="model")
+    except NotImplementedError as e:
+        print("no special wrapper for model, wrapping using general case")

     # Case A: HuggingFace-style transformers: model.model.layers
     lm = getattr(root, "model", None)

     embeddings = (
-        getattr(lm, "embed_tokens", None) if isinstance(lm, nn.Module) else None
+        getattr(lm, "embed_tokens", None)
+        if isinstance(lm.embed_tokens, nn.Module)  # type: ignore[union-attr]
+        else None
     )
     if isinstance(embeddings, nn.Module):
         child_scope = "model.embeddings"
@@ -99,7 +105,11 @@ def _wrap_supported(
         )
         lm.embed_tokens = wrapped  # type: ignore[union-attr]

-    model_norm = getattr(lm, "norm", None) if isinstance(lm, nn.Module) else None
+    model_norm = (
+        getattr(lm, "norm", None)
+        if isinstance(lm.norm, nn.Module)  # type: ignore[union-attr]
+        else None
+    )
     if isinstance(model_norm, nn.Module):
         child_scope = "model.norm"
         child_cfg = qcfg.child(child_scope)
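Because the whole-model `PTQWrapper` path is now tried first, the caller of `prepare()`/`_wrap_supported` gets back a wrapper with the original module underneath it, which is what the example script above navigates through. A short orientation sketch of the assumed layout (the attribute chain is inferred from the example, not from the wrapper sources):

q_m = prepare(model, qcfg)                 # PTQWrapper around the HF LlamaForCausalLM
fp_model = q_m.wrapped                     # original LlamaForCausalLM
decoder_wrapper = q_m.wrapped.model        # wrapper around the inner LlamaModel
layer_0 = q_m.wrapped.model.wrapped.layers[0]             # wrapped decoder layer 0
layer_0_fp = q_m.wrapped.model.wrapped.layers[0].wrapped  # its underlying fp layer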

tico/quantization/wrapq/utils/metrics.py

Lines changed: 9 additions & 4 deletions
@@ -90,10 +90,15 @@ def perplexity(
     input_ids_full = input_ids_full.to(device)

     if max_length is None:
-        assert hasattr(model, "config")
-        model_config = model.config
-        if hasattr(model.config, "text_config"):
-            model_config = model.config.text_config
+        if hasattr(model, "config"):
+            assert hasattr(model, "config")
+            model_config = model.config
+        else:
+            assert hasattr(model.wrapped, "config")
+            model_config = model.wrapped.config
+
+        if hasattr(model_config, "text_config"):
+            model_config = model_config.text_config
         assert hasattr(model_config, "max_position_embeddings")
         assert isinstance(model_config.max_position_embeddings, int)
         max_length = model_config.max_position_embeddings
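The same lookup can be read as a small standalone helper (a sketch of the logic `perplexity()` now inlines; the `.wrapped` attribute is the wrapper convention used elsewhere in this PR):

def resolve_text_config(model):
    # PTQ wrappers may not expose .config themselves, so fall back to the
    # wrapped HF model's config; multimodal configs keep the language-model
    # settings under text_config.
    cfg = model.config if hasattr(model, "config") else model.wrapped.config
    return getattr(cfg, "text_config", cfg)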

tico/quantization/wrapq/wrappers/llama/quant_attn.py

Lines changed: 2 additions & 5 deletions
@@ -161,8 +161,7 @@ def _concat_kv(
         return k, v

     def _apply_rope(self, q, k, cos, sin, unsqueeze_dim: int = 1):
-        cos_u = cos.unsqueeze(unsqueeze_dim)
-        sin_u = sin.unsqueeze(unsqueeze_dim)
+        cos_u, sin_u = cos, sin

         q_half = self._rot(
             q, self.obs_q_x1, self.obs_q_x2, self.obs_q_neg, self.obs_q_cat
@@ -201,8 +200,6 @@ def forward(

         # Rope tables
         cos, sin = position_embeddings
-        cos = self._fq(cos, self.obs_cos)
-        sin = self._fq(sin, self.obs_sin)
         q_rot, k_rot = self._apply_rope(q, k, cos, sin, unsqueeze_dim=1)

         # --- build/update KV for attention & present_key_value -------------
# --- build/update KV for attention & present_key_value -------------
@@ -228,7 +225,7 @@ def forward(
228225
attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
229226
hidden_states.device
230227
)
231-
attention_mask = self._fq(attention_mask, self.obs_causal_mask)
228+
attention_mask = self._fq(attention_mask, self.obs_causal_mask)
232229

233230
attn_weights_parts = []
234231
attn_out_parts = []
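With these hunks the attention wrapper no longer unsqueezes or fake-quantizes the rotary tables itself; `position_embeddings` is expected to arrive already quantized and already broadcastable from the model-level wrapper, matching the commit's goal of populating the static tensors once. A small shape sketch of that assumption (sizes are illustrative):

import torch

B, H, S, Dh = 1, 8, 128, 64                 # batch, heads, seq length, head dim (example sizes)
cos = torch.randn(B, S, Dh).unsqueeze(1)    # (B, 1, S, Dh): pre-broadcast by the caller
sin = torch.randn(B, S, Dh).unsqueeze(1)
q = torch.randn(B, H, S, Dh)

# Inside _apply_rope the tables are now used as-is (cos_u, sin_u = cos, sin),
# so they must already broadcast against (B, H, S, Dh):
q_cos = q * cos
print(q_cos.shape)                          # torch.Size([1, 8, 128, 64])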
