-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathvcd_sample.py
More file actions
342 lines (286 loc) · 15.2 KB
/
vcd_sample.py
File metadata and controls
342 lines (286 loc) · 15.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import copy
import inspect
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.distributed as dist
from torch import nn
from transformers.generation.logits_process import (
LogitsProcessorList,
)
from transformers.generation.stopping_criteria import (
StoppingCriteria,
StoppingCriteriaList,
validate_stopping_criteria,
)
import transformers
from transformers.generation.utils import SampleOutput
def _kwarg_or_default(model_kwargs: Dict[str, Any], key: str, default: Any) -> Any:
    """Return ``model_kwargs[key]``, or ``default`` when the key is missing or ``None``."""
    value = model_kwargs.get(key)
    return default if value is None else value


def _apply_plausibility_constraint(
    combined_logits: torch.Tensor,
    reference_logits: torch.Tensor,
    cd_beta: float,
) -> torch.Tensor:
    """Mask ``combined_logits`` with the adaptive plausibility constraint.

    Entries whose ``reference_logits`` value is below
    ``log(cd_beta) + reference_logits.max(dim=-1)`` are set to ``-inf``; every
    other entry keeps its ``combined_logits`` value.
    """
    cutoff = torch.log(torch.tensor(cd_beta)) + reference_logits.max(dim=-1, keepdim=True).values
    return combined_logits.masked_fill(reference_logits < cutoff, -float("inf"))


def _select_next_tokens(
    input_ids: torch.LongTensor,
    logits: torch.Tensor,
    logits_processor: LogitsProcessorList,
    logits_warper: LogitsProcessorList,
    sample_greedy: Optional[bool],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Process ``logits`` and choose one token per row.

    Returns ``(next_tokens, scores)``. ``scores`` are ``logits`` after
    ``logits_processor`` and then ``logits_warper``. With ``sample_greedy``
    truthy the token is the argmax of ``scores``. Otherwise it is drawn from
    ``softmax(scores)``.
    """
    scores = logits_processor(input_ids, logits)
    scores = logits_warper(input_ids, scores)
    if sample_greedy:
        # Greedy uses the same processed scores as sampling, so any -inf
        # plausibility mask and all processors also constrain the argmax.
        next_tokens = torch.argmax(scores, dim=-1)
    else:
        probs = nn.functional.softmax(scores, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
    return next_tokens, scores


def sample(
    self,
    input_ids: torch.LongTensor,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    logits_warper: Optional[LogitsProcessorList] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    eos_token_id: Optional[Union[int, List[int]]] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    output_scores: Optional[bool] = None,
    return_dict_in_generate: Optional[bool] = None,
    synced_gpus: bool = False,
    streamer: Optional["BaseStreamer"] = None,
    key_position: Optional[dict] = None,
    sample_greedy: Optional[bool] = None,
    **model_kwargs,
) -> Union[SampleOutput, torch.LongTensor]:
    """Decode by sampling or greedy selection, optionally contrasting two forward passes.

    This is a drop-in replacement for ``GenerationMixin.sample``, installed by
    ``evolve_vcd_sampling``. Every step runs the main forward pass with
    ``self.model.config.use_fast_v`` set to False. The contrastive mode is
    chosen once, from ``model_kwargs``:

    * ``images_cd`` is not None: a second pass is built with
      ``self.prepare_inputs_for_generation_cd``. FastV is re-enabled for that
      pass when ``self.model.config.fast_v_attention_rank`` is set. The logits
      are combined as ``(1 - cd_alpha) * main + cd_alpha * cd``.
    * otherwise, if ``input_ids_cd`` or ``inputs_embeds_cd`` is not None: a
      second pass is built with ``self.prepare_inputs_for_generation_icd``. The
      logits are combined as ``(1 + cd_alpha) * main - cd_alpha * cd``.
    * otherwise: plain decoding from the main pass.

    In both contrastive modes, tokens whose main-pass logit is below
    ``max(main) + log(cd_beta)`` are masked to ``-inf`` before processing.
    ``cd_alpha`` and ``cd_beta`` are read from ``model_kwargs`` and default to
    0.5 and 0.1.

    Args:
        input_ids: Starting token ids; each step appends one token per row along
            the last dimension.
        logits_processor: Applied first to the (combined) next-token logits.
        stopping_criteria: Called after every step with ``input_ids`` and ``scores``.
        logits_warper: Applied after ``logits_processor``.
        max_length: Deprecated. When given, it is folded into
            ``stopping_criteria`` and a ``UserWarning`` is emitted.
        pad_token_id: Written into rows that have already finished. Defaults to
            ``self.generation_config.pad_token_id``.
        eos_token_id: One id or a list of ids that finish a row. Defaults to
            ``self.generation_config.eos_token_id``.
        output_attentions, output_hidden_states, output_scores,
        return_dict_in_generate: Default to the matching
            ``self.generation_config`` fields. They control what is collected
            and whether an output object is returned.
        synced_gpus: Keep running forward passes until every process in the
            ``torch.distributed`` group has finished.
        streamer: Receives each step's tokens via ``put`` and ``end`` at the close.
        key_position: Forwarded unchanged to every model forward call.
        sample_greedy: If truthy, take the argmax of the processed scores
            instead of sampling.
        **model_kwargs: Forwarded to the ``prepare_inputs_for_generation*``
            helpers. Also carries the contrastive-decoding keys listed above.

    Returns:
        ``input_ids`` with the generated tokens appended. When
        ``return_dict_in_generate`` is true, returns a
        ``SampleEncoderDecoderOutput`` or ``SampleDecoderOnlyOutput`` instead.

    Raises:
        ValueError: If an ``eos_token_id`` is set but ``pad_token_id`` is not.
    """
    # init values
    logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
    stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
    if max_length is not None:
        warnings.warn(
            "`max_length` is deprecated in this function, use"
            " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
            UserWarning,
        )
        stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
    logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
    pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
    if isinstance(eos_token_id, int):
        eos_token_id = [eos_token_id]
    eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None
    output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
    output_attentions = (
        output_attentions if output_attentions is not None else self.generation_config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
    )
    return_dict_in_generate = (
        return_dict_in_generate
        if return_dict_in_generate is not None
        else self.generation_config.return_dict_in_generate
    )

    # init attention / hidden states / scores tuples
    scores = () if (return_dict_in_generate and output_scores) else None
    decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
    cross_attentions = () if (return_dict_in_generate and output_attentions) else None
    decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

    # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
    if return_dict_in_generate and self.config.is_encoder_decoder:
        encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
        encoder_hidden_states = (
            model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
        )

    # keep track of which sequences are already finished
    unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device)
    this_peer_finished = False  # used by synced_gpus only

    # Contrastive-decoding setup, resolved once. This assumes
    # ``_update_model_kwargs_for_generation`` does not change these keys
    # between steps — TODO confirm for custom model overrides.
    model_kwargs_cd = model_kwargs.copy()
    input_ids_cd = model_kwargs_cd.get("input_ids_cd", None)  # for instruction cd only
    icd_has_own_ids = input_ids_cd is not None  # VLM connector (True) vs Q-former (False)
    use_cd = model_kwargs.get("images_cd") is not None  # for VCD and SID
    use_icd = icd_has_own_ids or model_kwargs.get("inputs_embeds_cd") is not None
    cd_alpha = _kwarg_or_default(model_kwargs, "cd_alpha", 0.5)
    cd_beta = _kwarg_or_default(model_kwargs, "cd_beta", 0.1)

    # auto-regressive generation
    while True:
        if synced_gpus:
            # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
            # The following logic allows an early break if all peers finished generating their sequence
            this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
            # send 0.0 if we finished, 1.0 otherwise
            dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
            # did all peers finish? the reduced sum will be 0.0 then
            if this_peer_finished_flag.item() == 0.0:
                break

        # prepare model inputs
        model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

        # The main pass always runs with FastV disabled.
        self.model.config.use_fast_v = False
        self.model.reset_fastv()

        # forward pass to get next token
        outputs = self(
            **model_inputs,
            return_dict=True,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            vad=False,
            key_position=key_position,
        )

        if synced_gpus and this_peer_finished:
            continue  # don't waste resources running the code we don't need

        next_token_logits = outputs.logits[:, -1, :]

        if use_cd:
            if self.model.config.fast_v_attention_rank is not None:
                # The contrastive pass re-enables FastV.
                self.model.config.use_fast_v = True
                self.model.reset_fastv()
            ## cd_comments: forward pass of the model with distorted image input
            model_inputs_cd = self.prepare_inputs_for_generation_cd(input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                key_position=key_position,
                vad=False,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]
            # NOTE(review): this *blends* the two passes. The instruction-CD
            # branch below uses the contrastive form (1 + a) * main - a * cd.
            # Kept unchanged; confirm the sign is intended for VCD/SID.
            combined_logits = (1 - cd_alpha) * next_token_logits + cd_alpha * next_token_logits_cd
            cd_logits = _apply_plausibility_constraint(combined_logits, next_token_logits, cd_beta)
            next_tokens, next_token_scores = _select_next_tokens(
                input_ids, cd_logits, logits_processor, logits_warper, sample_greedy
            )
        elif use_icd:
            # Q-former models contrast on the main ids; VLM-connector models carry their own.
            icd_input_ids = input_ids_cd if icd_has_own_ids else input_ids
            model_inputs_cd = self.prepare_inputs_for_generation_icd(input_ids=icd_input_ids, **model_kwargs_cd)
            outputs_cd = self(
                **model_inputs_cd,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                key_position=key_position,
                vad=False,
            )
            next_token_logits_cd = outputs_cd.logits[:, -1, :]
            combined_logits = (1 + cd_alpha) * next_token_logits - cd_alpha * next_token_logits_cd
            cd_logits = _apply_plausibility_constraint(combined_logits, next_token_logits, cd_beta)
            next_tokens, next_token_scores = _select_next_tokens(
                input_ids, cd_logits, logits_processor, logits_warper, sample_greedy
            )
        else:
            next_tokens, next_token_scores = _select_next_tokens(
                input_ids, next_token_logits, logits_processor, logits_warper, sample_greedy
            )

        # Store scores, attentions and hidden_states when required
        if return_dict_in_generate:
            if output_scores:
                scores += (next_token_scores,)
            if output_attentions:
                decoder_attentions += (
                    (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                )
                if self.config.is_encoder_decoder:
                    cross_attentions += (outputs.cross_attentions,)
            if output_hidden_states:
                decoder_hidden_states += (
                    (outputs.decoder_hidden_states,)
                    if self.config.is_encoder_decoder
                    else (outputs.hidden_states,)
                )

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            if pad_token_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs, and length for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if streamer is not None:
            streamer.put(next_tokens.cpu())
        model_kwargs = self._update_model_kwargs_for_generation(
            outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
        )

        ## cd_comments: update model_kwargs_cd for contrastive decoding
        if use_cd:
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )
        elif use_icd and icd_has_own_ids:
            # Append the same tokens as input_ids, including padding, so both sequences stay aligned.
            input_ids_cd = torch.cat([input_ids_cd, next_tokens[:, None]], dim=-1)
            model_kwargs_cd = self._update_model_kwargs_for_generation(
                outputs_cd, model_kwargs_cd, is_encoder_decoder=self.config.is_encoder_decoder
            )

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id_tensor is not None:
            unfinished_sequences = unfinished_sequences.mul(
                next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0)
            )
            # stop when each sentence is finished
            if unfinished_sequences.max() == 0:
                this_peer_finished = True

        # stop if we exceed the maximum length
        if stopping_criteria(input_ids, scores):
            this_peer_finished = True

        if this_peer_finished and not synced_gpus:
            break

    if streamer is not None:
        streamer.end()

    if return_dict_in_generate:
        # Imported lazily: the module header only imports ``SampleOutput``, and
        # these classes are needed only when a dict-style result is requested.
        from transformers.generation.utils import (
            SampleDecoderOnlyOutput,
            SampleEncoderDecoderOutput,
        )

        if self.config.is_encoder_decoder:
            return SampleEncoderDecoderOutput(
                sequences=input_ids,
                scores=scores,
                encoder_attentions=encoder_attentions,
                encoder_hidden_states=encoder_hidden_states,
                decoder_attentions=decoder_attentions,
                cross_attentions=cross_attentions,
                decoder_hidden_states=decoder_hidden_states,
            )
        return SampleDecoderOnlyOutput(
            sequences=input_ids,
            scores=scores,
            attentions=decoder_attentions,
            hidden_states=decoder_hidden_states,
        )
    return input_ids
def evolve_vcd_sampling():
    """Install this module's ``sample`` as ``GenerationMixin.sample``.

    This patches the class attribute, so every subclass of
    ``transformers.generation.utils.GenerationMixin`` that does not override
    ``sample`` resolves to the contrastive-decoding loop defined above.

    NOTE(review): this only takes effect on transformers versions whose
    ``generate`` dispatches to ``sample``. Confirm against the pinned version.
    """
    mixin_cls = transformers.generation.utils.GenerationMixin
    mixin_cls.sample = sample