
Commit 003d8f4

[DRAFT] Reduce disk space usage

This PR fixes the population of the static `causal_masks`/`position_embeddings` through the layers to save disk space.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 9055317 commit 003d8f4
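
The commit description is brief, so here is a minimal, self-contained sketch of the idea it describes: build the static causal mask and the position embeddings once at the model level and thread them through every decoder layer, rather than letting each layer materialize its own copies (which is what bloats the serialized model). Everything below (TinyModel, TinyLayer, build_causal_mask, the toy attention) is an illustrative stand-in, not tico's or transformers' actual code.

import torch

def build_causal_mask(seq_len: int) -> torch.Tensor:
    # Upper-triangular -inf mask, computed once per forward pass.
    return torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

class TinyLayer(torch.nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x, causal_mask, position_embeddings):
        # The layer only *consumes* the shared tensors; it registers no
        # static buffers of its own, so nothing mask-related is duplicated
        # per layer when the model is exported.
        cos, sin = position_embeddings
        scores = x @ x.transpose(-1, -2) + causal_mask
        return self.proj(torch.softmax(scores, dim=-1) @ (x * cos + x * sin))

class TinyModel(torch.nn.Module):
    def __init__(self, hidden: int = 8, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(TinyLayer(hidden) for _ in range(n_layers))

    def forward(self, x):
        seq_len = x.shape[1]
        causal_mask = build_causal_mask(seq_len)  # built once here...
        pos = torch.arange(seq_len, dtype=torch.float32)[:, None]
        position_embeddings = (torch.cos(pos), torch.sin(pos))
        for layer in self.layers:  # ...and shared by all layers
            x = layer(x, causal_mask, position_embeddings)
        return x

# smoke test: TinyModel()(torch.randn(1, 4, 8)) -> tensor of shape (1, 4, 8)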

11 files changed, 716 additions and 176 deletions


New file: 109 additions, 0 deletions
@@ -0,0 +1,109 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
These tests run only if *transformers* is available (they depend on the genuine
`transformers.models.llama.modeling_llama.LlamaForCausalLM`).
"""

import unittest

import torch

from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.utils.version import has_transformers_for
from tico.quantization.wrapq.wrappers.llama.quant_model_for_causal_lm import (
    QuantLlamaForCausalLM,
)

skip_msg = "required transformers not installed; skipping LlamaForCausalLM tests"


@unittest.skipUnless(has_transformers_for("llama"), skip_msg)
class TestQuantLlamaForCausalLM(unittest.TestCase):
    seq_len: int
    vocab_size: int
    fp_model: torch.nn.Module

    @classmethod
    def setUpClass(cls):
        torch.manual_seed(0)

        from transformers.models.llama.configuration_llama import LlamaConfig
        from transformers.models.llama.modeling_llama import LlamaForCausalLM

        cls.seq_len = 16
        cls.vocab_size = 10000
        cfg = LlamaConfig(
            hidden_size=8,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=4,
            attention_bias=False,
            attention_dropout=0.0,
            attn_implementation="eager",
            num_hidden_layers=2,
            max_position_embeddings=cls.seq_len,
            use_cache=False,
            return_dict=False,
        )
        cls.fp_model = LlamaForCausalLM(cfg)

    def test_mode_transitions(self):
        qmodel = QuantLlamaForCausalLM(self.fp_model)
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        x = torch.randint(
            0,
            self.vocab_size,
            (
                1,
                self.seq_len // 2,
            ),
        )
        _ = qmodel(x)  # gather stats

        qmodel.freeze_qparams()
        self.assertIs(qmodel._mode, Mode.QUANT)

    def test_forward_diff(self):
        qmodel = QuantLlamaForCausalLM(self.fp_model)
        qmodel.enable_calibration()
        calib_set = []
        # Calibrate on inputs of varying lengths (16, 8, 5, 4 tokens).
        for index in range(4):
            inp = torch.randint(
                0,
                self.vocab_size,
                (
                    1,
                    self.seq_len // (index + 1),
                ),
            )
            _ = qmodel(inp)
            calib_set.append(inp)
        qmodel.freeze_qparams()

        with torch.no_grad():
            q_out = qmodel(calib_set[0])[0]
            fp_out = self.fp_model(calib_set[0])[0]

        diff = (fp_out - q_out).abs().mean().item()
        self.assertGreater(diff, 0.0)  # quantization must actually change the output
        self.assertLess(diff, 0.4)  # ...but stay close to the FP reference
        self.assertEqual(fp_out.shape, q_out.shape)
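
Outside the unittest harness, the calibrate-then-freeze workflow exercised above condenses to the snippet below. Only the API surface visible in these tests is assumed (QuantLlamaForCausalLM, enable_calibration, freeze_qparams, tuple outputs); the config values simply mirror the fixture.

import torch
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from tico.quantization.wrapq.wrappers.llama.quant_model_for_causal_lm import (
    QuantLlamaForCausalLM,
)

fp_model = LlamaForCausalLM(
    LlamaConfig(
        hidden_size=8, num_attention_heads=2, num_key_value_heads=1,
        head_dim=4, attn_implementation="eager", num_hidden_layers=2,
        use_cache=False, return_dict=False,
    )
)
qmodel = QuantLlamaForCausalLM(fp_model)  # starts in Mode.NO_QUANT

qmodel.enable_calibration()  # Mode.CALIB: observers record activation ranges
for _ in range(4):
    _ = qmodel(torch.randint(0, 10000, (1, 8)))  # forward passes gather stats

qmodel.freeze_qparams()  # Mode.QUANT: scales/zero-points are now fixed
with torch.no_grad():
    logits = qmodel(torch.randint(0, 10000, (1, 8)))[0]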
New file: 107 additions, 0 deletions
@@ -0,0 +1,107 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
These tests run only if *transformers* is available (they depend on the genuine
`transformers.models.llama.modeling_llama.LlamaModel`).
"""

import unittest

import torch

from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.utils.version import has_transformers_for
from tico.quantization.wrapq.wrappers.llama.quant_model import QuantLlamaModel

skip_msg = "required transformers not installed; skipping LlamaModel tests"


@unittest.skipUnless(has_transformers_for("llama"), skip_msg)
class TestQuantLlamaModel(unittest.TestCase):
    seq_len: int
    vocab_size: int
    fp_model: torch.nn.Module

    @classmethod
    def setUpClass(cls):
        torch.manual_seed(0)

        from transformers.models.llama.configuration_llama import LlamaConfig
        from transformers.models.llama.modeling_llama import LlamaModel

        cls.seq_len = 16
        cls.vocab_size = 10000
        cfg = LlamaConfig(
            hidden_size=8,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=4,
            attention_bias=False,
            attention_dropout=0.0,
            attn_implementation="eager",
            num_hidden_layers=2,
            max_position_embeddings=cls.seq_len,
            use_cache=False,
            return_dict=False,
        )
        cls.fp_model = LlamaModel(cfg)

    def test_mode_transitions(self):
        qmodel = QuantLlamaModel(self.fp_model)
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        x = torch.randint(
            0,
            self.vocab_size,
            (
                1,
                self.seq_len,
            ),
        )
        _ = qmodel(x)  # gather stats

        qmodel.freeze_qparams()
        self.assertIs(qmodel._mode, Mode.QUANT)

    def test_forward_diff(self):
        qmodel = QuantLlamaModel(self.fp_model)
        qmodel.enable_calibration()
        calib_set = []
        for _ in range(4):
            inp = torch.randint(
                0,
                self.vocab_size,
                (
                    1,
                    self.seq_len,
                ),
            )
            _ = qmodel(inp)
            calib_set.append(inp)
        qmodel.freeze_qparams()

        with torch.no_grad():
            q_out = qmodel(calib_set[0])[0]
            fp_out = self.fp_model(calib_set[0])[0]

        diff = (fp_out - q_out).abs().mean().item()
        self.assertGreater(diff, 0.0)  # quantization must actually change the output
        self.assertLess(diff, 0.4)  # ...but stay close to the FP reference
        self.assertEqual(fp_out.shape, q_out.shape)
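
Both suites share the same closeness criterion: quantization must perturb the output (mean absolute difference strictly greater than zero, i.e. fake-quant is really active) without wrecking it (difference below 0.4). Factored into a standalone helper it might read as follows; the 0.4 tolerance comes from the tests, while the helper itself is illustrative, not part of tico.

import torch

def mean_abs_diff(fp_out: torch.Tensor, q_out: torch.Tensor) -> float:
    # Mean absolute error between the FP reference and the quantized output.
    assert fp_out.shape == q_out.shape
    return (fp_out - q_out).abs().mean().item()

# inside a test:
#   diff = mean_abs_diff(fp_out, q_out)
#   self.assertGreater(diff, 0.0)   # quantization changed something
#   self.assertLess(diff, 0.4)      # ...but stayed close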
