-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathvalidation.py
More file actions
83 lines (54 loc) · 2.8 KB
/
validation.py
File metadata and controls
83 lines (54 loc) · 2.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import torch
from dataset import causal_mask
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Generate a target sequence by always taking the highest-scoring next token.

    The source is encoded once. The decoder is then run repeatedly on the
    growing target prefix until ``[EOS]`` is emitted or the prefix holds
    ``max_len`` tokens.

    Args:
        model: Object exposing ``encode(source, source_mask)``,
            ``decode(encoder_output, source_mask, decoder_input, decoder_mask)``
            and ``project(x)``.
        source: Source token tensor; its dtype is reused for the generated ids.
            Only one sequence can be decoded per call, because the selected
            token is read back with ``.item()``.
        source_mask: Mask forwarded unchanged to ``encode`` and ``decode``.
        tokenizer_src: Not used here; kept so existing callers keep working.
        tokenizer_tgt: Tokenizer whose ``token_to_id`` resolves ``'[SOS]'``
            and ``'[EOS]'``.
        max_len: Upper bound on the number of tokens in the result, counting
            the leading ``[SOS]``. Values below 1 yield just ``[SOS]``.
        device: Device on which the decoder tensors are created.

    Returns:
        A 1-D tensor of token ids that starts with the ``[SOS]`` id and ends
        with the ``[EOS]`` id when that token was produced before ``max_len``
        was reached.
    """
    sos_index = tokenizer_tgt.token_to_id('[SOS]')
    eos_index = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every decoding step.
    encoder_output = model.encode(source, source_mask)

    # Start the decoder input with the SOS token. The tensor is created with
    # its final dtype and device so ids never pass through float32.
    decoder_input = torch.full((1, 1), sos_index, dtype=source.dtype, device=device)

    # `<` rather than `==`: the prefix already has length 1, so an equality
    # test would never trigger for max_len < 1 and decoding could run forever.
    while decoder_input.size(1) < max_len:
        # Build the causal mask for the current target prefix.
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # Run the decoder and project only the last position.
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
        probab = model.project(out[:, -1])

        # Pick the token with the maximum score.
        _, next_word = torch.max(probab, dim=1)
        next_token = next_word.item()

        # Append the chosen token; it becomes part of the next step's input.
        decoder_input = torch.cat(
            [
                decoder_input,
                torch.full((1, 1), next_token, dtype=source.dtype, device=device),
            ],
            dim=1,
        )

        # Stop once the end-of-sequence token has been emitted (it is kept in the output).
        if next_token == eos_index:
            break

    return decoder_input.squeeze(0)
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_epoch, writer, num_example=2):
    """Greedy-decode a few validation samples and report them through ``print_msg``.

    The model is switched to eval mode and gradients are disabled for the
    whole loop. For each sample the source text, the reference text and the
    decoded prediction are emitted, preceded by a separator line.

    Args:
        model: Model handed to ``greedy_decode``; ``model.eval()`` is called first.
        validation_ds: Iterable of batches providing ``'encoder_input'``,
            ``'encoder_mask'``, ``'src_text'`` and ``'tgt_text'``.
        tokenizer_src: Forwarded to ``greedy_decode``.
        tokenizer_tgt: Forwarded to ``greedy_decode`` and used to decode the
            predicted ids back to text.
        max_len: Forwarded to ``greedy_decode``.
        device: Device the encoder tensors are moved to.
        print_msg: Callable that receives each output line.
        global_epoch: Not used in this function.
        writer: Not used in this function.
        num_example: Stop after this many samples. The stop test is an
            equality check against a counter that starts at 1, so a value
            below 1 never matches and the whole dataset is processed.
    """
    model.eval()

    # Horizontal rule sized for an 80-column console.
    separator = '-' * 80

    with torch.no_grad():
        for seen, sample in enumerate(validation_ds, start=1):
            src_ids = sample['encoder_input'].to(device)
            src_mask = sample['encoder_mask'].to(device)

            assert src_ids.size(0) == 1, "Batch size must be one for validation"

            predicted_ids = greedy_decode(
                model, src_ids, src_mask, tokenizer_src, tokenizer_tgt, max_len, device
            )

            source_text = sample['src_text'][0]
            target_text = sample['tgt_text'][0]
            model_out_text = tokenizer_tgt.decode(predicted_ids.detach().cpu().numpy())

            # Report the sample to the console.
            report = (
                separator,
                f'SOURCE : {source_text}',
                f'TARGET : {target_text}',
                f'PREDICTED : {model_out_text}',
            )
            for line in report:
                print_msg(line)

            if seen == num_example:
                break