@@ -109,3 +109,68 @@ def test_forward_diff(self):
109109 self .assertGreater (diff , 0.0 )
110110 self .assertLess (diff , 0.4 )
111111 self .assertEqual (fp_out .shape , q_out .shape )
112+
113+
@unittest.skipUnless(has_transformers_for("llama"), skip_msg)
class TestQuantLlamaModelWithCache(unittest.TestCase):
    """Check that the quant-wrapped Llama model preserves the KV-cache output contract.

    A tiny randomly initialized ``LlamaModel`` is wrapped in ``QuantLlamaModel``
    and run in calibration mode; the test verifies the tuple-style output
    (``return_dict=False``) still carries a full-length past-key-value cache
    for every layer after a prefill forward pass.
    """

    seq_len: int                # prompt length fed to the model
    vocab_size: int             # exclusive upper bound for random token ids
    hid_layers: int             # number of transformer layers in the toy config
    fp_model: torch.nn.Module   # full-precision reference model shared by tests

    @classmethod
    def setUpClass(cls):
        # Fixed seed so the random weights and inputs are reproducible.
        torch.manual_seed(0)

        from transformers.models.llama.configuration_llama import LlamaConfig
        from transformers.models.llama.modeling_llama import LlamaModel

        cls.seq_len = 16
        cls.vocab_size = 10000
        cls.hid_layers = 3
        # Deliberately tiny config so the test stays fast. use_cache=True makes
        # the model emit past_key_values; return_dict=False makes the output a
        # plain tuple, which is what test_model_output indexes into.
        cfg = LlamaConfig(
            hidden_size=8,
            num_attention_heads=2,
            num_key_value_heads=1,
            head_dim=4,
            attention_bias=False,
            attention_dropout=0.0,
            attn_implementation="eager",
            num_hidden_layers=cls.hid_layers,
            max_position_embeddings=cls.seq_len,
            use_cache=True,
            return_dict=False,
        )
        cls.fp_model = LlamaModel(cfg)

    def test_model_output(self):
        """Prefill forward pass must return (hidden_states, past_key_values) with a full cache per layer."""
        qmodel = QuantLlamaModel(
            self.fp_model, qcfg=PTQConfig(wrapper_variant="prefill")
        )
        self.assertIs(qmodel._mode, Mode.NO_QUANT)

        qmodel.enable_calibration()
        self.assertIs(qmodel._mode, Mode.CALIB)

        # One batch of random token ids covering the whole context window.
        x = torch.randint(0, self.vocab_size, (1, self.seq_len))
        output = qmodel(x)

        # Use unittest assertions (not bare `assert`) for consistency with the
        # rest of this file and so the checks survive `python -O`.
        self.assertEqual(len(output), 2)  # last_hidden_states + past_key_values
        past_key_values = output[1]
        self.assertEqual(len(past_key_values), self.hid_layers)
        for index in range(self.hid_layers):
            # subTest tags any failure with the offending layer index.
            with self.subTest(layer=index):
                past_key_value = past_key_values[index]
                self.assertIsInstance(past_key_value, tuple)

                past_key, past_value = past_key_value[0], past_key_value[1]
                # Cache tensors have the sequence dimension at -2; after a
                # prefill pass the entire prompt must be cached.
                self.assertEqual(past_key.shape[-2], self.seq_len)
                self.assertEqual(past_value.shape[-2], self.seq_len)
0 commit comments