-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathModel.py
More file actions
1462 lines (1229 loc) · 52.9 KB
/
Model.py
File metadata and controls
1462 lines (1229 loc) · 52.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from SaxExtraction import *
from ExMetric import *
from CCMetric import *
from CCTree import *
from MOutput import *
from PaddedMInput import *
from SaxDataset import *
from PickleList import *
from utils_l_sample_str import write_l_sample_str
import os
from copy import copy, deepcopy
from collections import OrderedDict
import logging
import regex as re
from pprint import pprint
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam
import lightning as L
from transformers import AdamW, AutoModel
# prevents printing of model weights, etc
# Raise the threshold of transformers' configuration/modeling loggers and
# of the root logger to ERROR, so that only errors (not the info dumps
# about configs and weights) are printed.
for _logger in (logging.getLogger('transformers.configuration_utils'),
                logging.getLogger('transformers.modeling_utils'),
                logging.getLogger()):
    _logger.setLevel(logging.ERROR)
del _logger
class Model(L.LightningModule):
"""
This class inherits from L.LightningModule some powerful methods that
loop through the batches of an epoch. It can either:
1. calculate the loss when training (here the weights are changing)
2. calculate the accuracy when tuning or testing or
extracting (here the weights are fixed).
The class loops over batches of an epoch for the 3 actions ttt=train,
tune (a.k.a. validation) and test.
This class has an abstract class as its parent. To distinguish between
inherited and uninherited methods, we add a prefix "sax_" to the name of
all uninherited methods.
Note that `batch` has a different meaning in Openie6 and SentenceAx. In
Openie6, it's a dictionary with the fallowing keys:
batch={
"lll_label":
"meta_data":
"pos_locs":
"text":
"verb_bools":
"verb_locs":
"word_starts":
}
In SentenceAx, (see sax_get_batch_in_dicts()), we have instead:
x, y, l_orig_sent, xname_to_l_dim1 = batch
SentenceAX is a fine-tuning of bert-base-cased and bert-large-cased.
Both of these weights/models are cased (meaning they both distinguish
between upper and lower cases), but bert-large-cased > bert-base-cased.
embedding = nn.Embedding(num_embeddings=L, embedding_dim=d)
d= hidden_size = 768 for BERT base
L = 100
Attributes
----------
auto_tokenizer: AutoTokenizer
con_to_weight: dict[str, float]
dropout_fun: Dropout
embedding: Embedding
hidden_size: int
ilabelling_layer: Linear
iterative_transformer: ModuleList
l_batch_m_out: list[MOutput]
l_cc_epoch_sample_str: list[str]
l_cc_epoch_spanned_word: list[list[str]]
lll_cc_epoch_spanned_loc: list[list[list[int]]]
loss_fun: CrossEntropyLoss
merge_layer: Linear
metric: CCMetric | ExMetric
name: str
name_to_param0: dict[str, Any]
osent_to_words: dict[str, list[str]]
params: Params
scores_epoch_end_d: dict[str, Any]
base_model: BertModel
sub_osent_to_osent: dict[str, str]
dictionary that maps sentences to sentences.
Both Model and ExMetric possess a pointer to this dictionary.
verbose: bool
# some inherited attributes that won't be used
# hparams (dictionary, Used by Openie6, but not by us.
# We use the class Params instead.)
# logger
# trainer
# on_gpu
"""
def __init__(self,
             params,
             auto_tokenizer,
             verbose=False,
             name=""):
    """
    Constructor

    Loads the pretrained base model named by self.params.d["model_str"],
    moves its last self.params.d["num_iterative_layers"] encoder layers
    into self.iterative_transformer, and creates the remaining layers
    (dropout, label embedding, merge layer, ilabelling layer), the loss
    function, the task metric (self.metric), the constraint weights
    (self.con_to_weight) and the PickleList self.l_batch_m_out.

    Parameters
    ----------
    params: Params
        self.params.d must contain at least "model_str" and
        "num_iterative_layers"; "multi_opt", "constraint_str" and
        "con_weight_str" are optional.
    auto_tokenizer: AutoTokenizer
    verbose: bool
    name: str
        name of Model instance if more than one is being used at the
        same time. ActionConductor declares 4 Model instances which it
        calls "train", "resume", "test", "pred"
    """
    super().__init__()
    # This stores `pi_test=3.14` in logs/ex/test/hparams.yaml. This is
    # here for illustrative purposes only. In SentenceAx, instead of
    # hparams, we use the Params class and sax_globals.py
    if verbose:
        self.hparams["pi_test"] = 3.14
        self.save_hyperparameters(self.hparams)
        print("Saving self.hparams= ", self.hparams)
    self.params = params
    self.auto_tokenizer = auto_tokenizer
    self.verbose = verbose
    self.name = name
    # return_dict=False avoids error message from Dropout
    # (with return_dict=False the base model returns a tuple, which
    # sax_get_llll_word_score() unpacks as `lll_hidstate, _ = ...`)
    self.base_model = AutoModel.from_pretrained(
        self.params.d["model_str"],
        cache_dir=CACHE_DIR,
        return_dict=False)
    self.hidden_size = self.base_model.config.hidden_size
    if self.verbose:
        print("Model init")
        print(f"\tname={self.name}, hidden_size={self.hidden_size}")
    # Actually, self.params.d["num_iterative_layers"]=2 for all Params.pid
    if self.params.d["num_iterative_layers"] > 0:
        # Split the encoder: the first num_encoder_layers stay inside
        # self.base_model, the last num_iterative_layers become
        # self.iterative_transformer, which sax_get_llll_word_score()
        # re-applies once per depth.
        num_layers = len(self.base_model.encoder.layer)
        num_encoder_layers = \
            num_layers - self.params.d["num_iterative_layers"]
        self.iterative_transformer = \
            self.base_model.encoder.layer[
                num_encoder_layers:num_layers]
        # this truncation of self.base_model.encoder.layer must
        # be done after, not before defining self.iterative_transformer
        self.base_model.encoder.layer = \
            self.base_model.encoder.layer[0:num_encoder_layers]
        if verbose:
            print("num_iterative_layers= ", num_layers -
                  num_encoder_layers)
            print("num_encoder_layers= ", num_encoder_layers)
            print("total num layers= ", num_layers)
    else:
        # no iterative layers: the depth loop in
        # sax_get_llll_word_score() then skips straight to dropout
        self.iterative_transformer = []
    self.dropout_fun = nn.Dropout(p=PROB_DROPOUT)  # 0.0
    # Embeds the previous depth's greedy ilabels (see
    # sax_get_llll_word_score()). Those indices come from an argmax over
    # NUM_ILABELS scores, so only the first NUM_ILABELS rows are used.
    self.embedding = nn.Embedding(
        MAX_NUM_OSENTL_WORDS,  # maximum number of words analyzed, 100
        self.hidden_size)  # dim of embedding space, 768
    self.merge_layer = nn.Linear(self.hidden_size,  # 768
                                 MERGE_DIM)  # 300
    self.ilabelling_layer = nn.Linear(MERGE_DIM,  # 300
                                      NUM_ILABELS)  # 6
    # ignore_index=-100 is the default, but including it
    # explicitly for clarity
    # see file misc/CrossEntropyLoss-examples.py for examples of usage
    self.loss_fun = nn.CrossEntropyLoss(ignore_index=-100)
    self.sub_osent_to_osent = {}
    # self.osent_to_words is similar to Openie6 conj_word_mapping
    # Note that self.osent_to_words is never used;
    # It is filled in ActionConductor but never used.
    # We include it in SentenceAx to follow Openie6.
    self.osent_to_words = {}
    if self.params.task == "ex":
        # ExMetric gets a pointer (address) to the sub_osent_to_osent
        # dict. This dictionary is initially empty, but if we add
        # elements to it later on, both Model and ExMetric will know
        # about it because the dictionary pointer will not have changed,
        # only its contents.
        self.metric = ExMetric(
            sub_osent_to_osent=self.sub_osent_to_osent,
            verbose=self.verbose)
    elif self.params.task == "cc":
        self.metric = CCMetric(verbose=self.verbose)
    # NOTE: for any other task value self.metric is left unset.
    self.scores_epoch_end_d = {}  # filled in test_epoch_end()
    # Constraint names and weights are "_"-separated strings, e.g.
    # 'posm_hvc_hvr_hve' with weights '1_1_1_1'. They are only used when
    # self.params.d["multi_opt"] is present and truthy.
    # constraint_str is a local variable; it is not stored on self.
    if "multi_opt" not in self.params.d \
            or not self.params.d["multi_opt"]:
        constraint_str = ""
        con_weight_str = ""
    else:
        if "constraint_str" not in self.params.d or \
                "con_weight_str" not in self.params.d:
            constraint_str = 'posm_hvc_hvr_hve'
            con_weight_str = '1_1_1_1'
        else:
            constraint_str = self.params.d["constraint_str"]
            con_weight_str = self.params.d["con_weight_str"]
    l_constraint = constraint_str.split('_')
    l_con_weight = con_weight_str.split('_')
    assert len(l_constraint) == len(l_con_weight)
    # With constraint_str == "", split() gives [''], and the filter below
    # drops the empty name (before float('') would be evaluated), so
    # self.con_to_weight == {}.
    self.con_to_weight = {l_constraint[k]: float(l_con_weight[k])
                          for k in range(len(l_constraint)) if
                          l_constraint[k]}
    # no longer used
    # Openie6.all_predictions_conj
    # self.l_cc_epoch_sample_str = []
    # Openie6.all_conjunct_words_conj
    # self.l_cc_epoch_spanned_word = []
    # Openie6.all_sentence_indices_conj
    # self.lll_cc_epoch_spanned_loc = []
    # not used
    # self.l_ex_pred_sample_str = [] # Openie6.all_predictions_oie
    self.l_batch_m_out = \
        PickleList(f"action_{self.name}_l_batch_m_out_dir")
    # reference weights for the weight regulator ("wreg") used in
    # forward(); None until set
    self.name_to_param0 = None
def configure_optimizers(self):
    """
    similar to Openie6.model.configure_optimizers()

    Build the optimizer(s) that Lightning uses during training.

    The parameters returned by self.named_parameters() are split into
    three groups, all with learning rate self.params.d["lr"]:

    1. base-model parameters (name contains "base_model") whose name
       contains none of "bias", "gamma", "beta"; tagged
       weight_decay_rate=0.01,
    2. base-model parameters whose name contains one of those strings;
       tagged weight_decay_rate=0.0,
    3. all remaining (non base-model) parameters.

    self.params.d["optimizer"] selects AdamW ('adamW') or Adam ('adam').
    If self.params.d["multi_opt"] is present and truthy (the same test
    used in __init__), the optimizer is returned once per constraint in
    self.con_to_weight, and at least once. Otherwise a one-element list
    is returned.

    Returns
    -------
    list[Adam|AdamW]

    Raises
    ------
    ValueError
        if self.params.d["optimizer"] is neither 'adamW' nor 'adam'.
    """
    lr = self.params.d["lr"]
    # name substrings of parameters that are put in the no-decay group
    # (biases and, typically, normalization-layer parameters)
    xnames = ["bias", "gamma", "beta"]

    def pair_in_xnames(pair):
        # pair = (name, Parameter); True if any xname occurs in the name
        return any(xname in pair[0] for xname in xnames)

    # self.named_parameters() yields (name, Parameter) pairs for all
    # weights to be optimized
    all_pairs = list(self.named_parameters())
    base_pairs = [pair for pair in all_pairs if "base_model" in pair[0]]
    non_base_pairs = [pair for pair in all_pairs
                      if "base_model" not in pair[0]]

    # NOTE: torch's Adam/AdamW read the key "weight_decay", not
    # "weight_decay_rate". The "weight_decay_rate" entries below are
    # stored in the param groups but not read by the optimizer. They are
    # kept unchanged so that training numerics are not altered.
    opt_param_d = [
        {"params": [pair[1] for pair in base_pairs
                    if not pair_in_xnames(pair)],
         "weight_decay_rate": 0.01,
         'lr': lr},
        {"params": [pair[1] for pair in base_pairs
                    if pair_in_xnames(pair)],
         "weight_decay_rate": 0.0,
         'lr': lr},
        {"params": [pair[1] for pair in non_base_pairs],
         'lr': lr}
    ]

    optimizer_str = self.params.d["optimizer"]
    # lr=1e-3 is only the constructor default; every group sets its own lr
    if optimizer_str == 'adamW':
        optimizer = AdamW(opt_param_d, lr=1e-3)
    elif optimizer_str == 'adam':
        optimizer = Adam(opt_param_d, lr=1e-3)
    else:
        raise ValueError(
            f"unsupported optimizer {optimizer_str!r}; "
            f"expected 'adamW' or 'adam'")

    if "multi_opt" in self.params.d and self.params.d["multi_opt"]:
        # The same optimizer object is repeated, once per constraint.
        # max(1, ...) guarantees Lightning always receives an optimizer.
        return [optimizer] * max(1, len(self.con_to_weight))
    return [optimizer]
def get_progress_bar_dict(self):
    """
    similar to Openie6.get_progress_bar_dict()

    Lightning hook that supplies the items shown in the progress bar.
    SentenceAx adds no items of its own; it returns the parent class's
    dictionary unchanged. (Openie6 uses tqdm for all its progress bars;
    for this one SentenceAx relies on the bar built into Lightning.)

    For reference, Openie6's version built its own OrderedDict with two
    entries: 'loss_fun', the trainer's running training loss formatted
    with '{:.3f}', and 'best', the checkpoint callback's kth_value.

    NOTE(review): newer Lightning releases no longer define
    LightningModule.get_progress_bar_dict(); progress-bar metrics moved to
    the progress-bar callback's get_metrics(). With such a release this
    hook is never called by Lightning, and calling it directly would fail
    on the super() lookup. Confirm against the installed version.

    Returns
    -------
    Dict[str, int | str]
        Dictionary with the items to be displayed in the progress bar.
    """
    return super().get_progress_bar_dict()
@staticmethod
def sax_get_batch_in_dicts(batch):
    """
    Split a SentenceAx `batch` into the three dictionaries used by the
    model.

    `batch` unpacks as
        x, y, l_orig_sent, xname_to_l_dim1 = batch

    Parameters
    ----------
    batch: tuple[torch.Tensor,
        torch.Tensor, list[str], dict[str, list[int]]

    Returns
    -------
    OrderedDict, dict[str, torch.Tensor], dict[str, list[str]]
        x_d
            `x` split back into its named pieces by
            SaxDataset.invert_cat(). The dim1 passed for each xname is
            the first entry of xname_to_l_dim1[xname], converted to int;
            the other entries are ignored (presumably they are all equal
            within a batch).
        y_d
            {"lll_ilabel": y}
        meta_d
            {"l_orig_sent": l_orig_sent}
    """
    x, y, l_orig_sent, xname_to_l_dim1 = batch

    # keep the xname insertion order; invert_cat() relies on an ordered map
    xname_to_dim1 = OrderedDict()
    for xname, l_dim1 in xname_to_l_dim1.items():
        xname_to_dim1[xname] = int(l_dim1[0])

    x_d = SaxDataset.invert_cat(x, xname_to_dim1)
    y_d = {"lll_ilabel": y}
    meta_d = {"l_orig_sent": l_orig_sent}
    return x_d, y_d, meta_d
def sax_get_llll_word_score(self, x_d, ttt, verbose=False):
    """
    This method is used inside self.forward() and is the heart of that
    method. It contains a while loop over depths that drives a batch
    through the layers of the model and returns `llll_word_score`.

    The batch first goes once through self.base_model (which holds only
    the non-iterative encoder layers). Then, at every depth:

    1. each layer of self.iterative_transformer is applied to
       lll_hidstate, which carries over from one depth to the next,
    2. dropout is applied,
    3. the hidden states at the word-start locations
       x_d["ll_osent_wstart_loc"] are gathered into lll_word_hidstate,
    4. for depth > 0, self.embedding of the previous depth's argmax
       ilabels is added to lll_word_hidstate,
    5. self.merge_layer and then self.ilabelling_layer produce
       lll_word_score, which is appended to llll_word_score.

    The loop stops after get_num_depths(self.params.task) depths. When
    ttt != 'train' it can stop earlier: after any depth, if no sample's
    argmax ilabels pass is_valid_label_list(..., self.params.task,
    "ilabels").

    Setting `verbose` to True prints out a detailed trail of what occurs
    in this method. The following example was obtained from such a
    verbose trail.

    Assume:
    batch_size= 24,
    hidden_size= 768,
    NUM_ILABELS= 6,
    MERGE_DIM= 300
    2 iterative layers and 5 depths.

    lll_word_score is the output of the last ilabelling_layer for each
    depth

    llll_word_score is a list of lll_word_score

    len(llll_word_score)= 5 = num_depths

    Note that llll_word_scoreT = Ten(llll_word_score)

    NOTE(review): the iterative layers are called as layer(lll_hidstate)
    with no attention mask, so padded positions are not masked in those
    layers. Confirm this is intended.

    Parameters
    ----------
    x_d: OrderedDict
        reads x_d["ll_osent_icode"] (base-model input ids) and
        x_d["ll_osent_wstart_loc"] (word-start token locations; assumed
        shape (batch_size, num_words), implied by the unsqueeze(2) below)
    ttt: str
        'train' disables the early exit described above
    verbose: bool

    Returns
    -------
    list[torch.Tensor]
        llll_word_score, with between 1 and num_depths entries. Each
        entry has shape (batch_size, num_words, NUM_ILABELS).
    """
    # lll_label is similar to Openie6.labels
    # first (outer) list over batch/sample of events
    # second list over extractions
    # third (inner) list over number of labels in a line
    # after padding and adding the 3 unused tokens

    # batch_size, num_depths, num_words = y_d["lll_ilabel"].shape

    # sometimes num_depths will exceed max.
    # This doesn't happen when training, because
    # num_depths is specified when training.
    num_depths = get_num_depths(self.params.task)

    # `loss_fun` is not used in this function anymore
    # loss_fun, lstm_loss = 0, 0

    # batch_text = " ".join(redoL(meta_d["l_orig_sent"]))
    # base_model_input = \
    #     torch.Tensor(self.auto_tokenizer.encode(batch_text))
    if verbose:
        print("Entering model.get_llll_word_score()")
    # Counter and comment() are project helpers used here only for the
    # verbose trail (presumably they print nothing when verbose is False)
    hstate_count = Counter(verbose, "lll_hidstate")
    word_hstate_count = Counter(verbose, "lll_word_hidstate")
    # base_model was built with return_dict=False, so it returns a tuple
    lll_hidstate, _ = self.base_model(x_d["ll_osent_icode"])
    hstate_count.new_one(reset=True)
    comment(
        verbose,
        prefix="after base_model",
        params_d={
            "ll_osent_icode.shape": x_d["ll_osent_icode"].shape,
            "lll_hidstate.shape": lll_hidstate.shape})

    lll_word_score = Ten([0])  # this statement is unnecessary
    llll_word_score = []  # ~ Openie6.all_depth_scores
    depth = 0
    # loop over depths
    while True:
        for ilay, layer in enumerate(self.iterative_transformer):
            comment(verbose,
                    prefix="*********** Starting iterative layer",
                    params_d={"ilay": ilay})
            # layer(lll_hidstate)[0] returns a copy
            # of the tensor lll_hidstate after transforming it
            # in some way. [0] chooses first component
            comment(
                verbose,
                prefix="Before iterative layer",
                params_d={
                    "ilay": ilay,
                    "depth": depth,
                    "lll_hidstate.shape": lll_hidstate.shape})
            lll_hidstate = layer(lll_hidstate)[0]
            hstate_count.new_one()
            comment(
                verbose,
                prefix="After iterative layer",
                params_d={
                    "ilay": ilay,
                    "depth": depth,
                    "lll_hidstate.shape": lll_hidstate.shape})
        comment(verbose,
                prefix="Before dropout",
                params_d={
                    "depth": depth,
                    "lll_hidstate.shape": lll_hidstate.shape})
        lll_hidstate = self.dropout_fun(lll_hidstate)
        hstate_count.new_one()
        comment(verbose,
                prefix="After dropout",
                params_d={
                    "depth": depth,
                    "lll_hidstate.shape": lll_hidstate.shape})
        # Broadcast each word-start location over the hidden dimension,
        # then gather along the token dimension (dim=1). This keeps one
        # hidden vector per word: (batch_size, num_words, hidden_size).
        lll_loc = x_d["ll_osent_wstart_loc"].unsqueeze(2). \
            repeat(1, 1, lll_hidstate.shape[2])
        lll_word_hidstate = torch.gather(
            input=lll_hidstate,
            dim=1,
            index=lll_loc)
        comment(
            verbose,
            prefix="Gather's 2 inputs, then output",
            params_d={
                "lll_hidstate.shape": lll_hidstate.shape,
                "lll_loc.shape": lll_loc.shape,
                "lll_word_hidstate.shape": lll_word_hidstate.shape})
        word_hstate_count.new_one(reset=True)
        if depth != 0:
            # feed back the previous depth's greedy (argmax) ilabels
            # through the label embedding
            comment(
                verbose,
                prefix="before argmax",
                params_d={"lll_word_score.shape": lll_word_score.shape})
            ll_greedy_ilabel = torch.argmax(lll_word_score, dim=-1)
            comment(
                verbose,
                prefix="after argmax",
                params_d={"ll_greedy_ilabel.shape":
                              ll_greedy_ilabel.shape})
            # not an integer code/embedding
            comment(
                verbose,
                prefix="before embedding",
                params_d={"ll_greedy_ilabel.shape":
                              ll_greedy_ilabel.shape})
            lll_pred_code = self.embedding(ll_greedy_ilabel)
            comment(
                verbose,
                prefix="after embedding",
                params_d={"lll_word_hidstate.state":
                              lll_word_hidstate.shape})
            lll_word_hidstate += lll_pred_code
            word_hstate_count.new_one()
            comment(
                verbose,
                prefix="just summed two signals with this shape",
                params_d={
                    "depth": depth,
                    "lll_word_hidstate.shape": lll_word_hidstate.shape})
        comment(verbose,
                prefix="Before merge layer",
                params_d={
                    "depth": depth,
                    "lll_word_hidstate.shape": lll_word_hidstate.shape})
        lll_word_hidstate = self.merge_layer(lll_word_hidstate)
        comment(
            verbose,
            prefix="After merge layer",
            params_d={
                "depth": depth,
                "lll_word_hidstate.shape": lll_word_hidstate.shape})
        comment(
            verbose,
            prefix="Before ilabelling",
            params_d={
                "depth": depth,
                "lll_word_hidstate.shape": lll_word_hidstate.shape})
        lll_word_score = self.ilabelling_layer(lll_word_hidstate)
        comment(
            verbose,
            prefix="After ilabelling",
            params_d={
                "depth": depth,
                "lll_word_score.shape": lll_word_score.shape})
        llll_word_score.append(lll_word_score)

        depth += 1
        if depth >= num_depths:
            break
        if ttt != 'train':
            # torch.max() returns a tuple (max, argmax)
            ll_pred_ilabel = torch.max(lll_word_score, dim=2)[1]
            valid_extraction = False
            # if not training, leave while loop if
            # ll_pred_ilabel has no valid extractions in it
            for l_pred_ilabel in ll_pred_ilabel:
                if is_valid_label_list(
                        l_pred_ilabel, self.params.task, "ilabels"):
                    valid_extraction = True
                    break
            if not valid_extraction:
                break
    comment(verbose,
            prefix="Leaving Model.sax_get_llll_word_score()",
            params_d={
                "len(llll_word_score)": len(llll_word_score),
                "llll_word_score[0].shape": llll_word_score[0].shape})
    return llll_word_score
@staticmethod
def sax_penalty_loss(x_d,
                     llll_word_scoreT,
                     con_to_weight):
    """
    similar to Openie6.model.constrained_loss()

    Compute the constraint (penalty) loss. It is called from forward()
    during training and from sax_get_con_to_l_penalty_loss().

    Only constraints whose names are keys of `con_to_weight` contribute,
    each multiplied by its weight. Below, "relation trust" means ilabel
    column 2 of llll_word_scoreT taken at the head-verb locations
    x_d["ll_osent_verb_loc"]. Locations equal to 0 are masked out.

    hvc: every head verb must be included in a relation.
         Sum of |1 - relation trust summed over depths| over head verbs.
    hvr: extractions must have at least k relations with a head verb in
         them. relu(#head verbs - sum over depths of the max relation
         trust), summed over the batch.
    hve: one relation cannot contain more than one head verb.
         relu(relation trust summed over head verbs - 1).
    posm: every word at x_d["ll_osent_pos_loc"] should get a non-zero
         ilabel at some depth. (1 - max over depths of the max trust of
         ilabels 1:), summed over nonzero locations.

    Parameters
    ----------
    x_d: OrderedDict
        must contain "ll_osent_verb_loc"; "ll_osent_verb_bool" is read
        for 'hvr', and "ll_osent_pos_loc" for 'posm'
    llll_word_scoreT: torch.Tensor
        shape (batch_size, num_depths, num_words, icode_dim)
    con_to_weight: dict[str, float]

    Returns
    -------
    torch.Tensor | int
        penalty_loss: a 0-dim tensor, or the int 0 if none of 'hvc',
        'hvr', 'hve', 'posm' is in con_to_weight
    """
    _, num_depths, _, icode_dim = llll_word_scoreT.shape

    def gather_words(ll_loc):
        # pick the score vectors of the words at ll_loc, for every depth
        llll_idx = ll_loc.unsqueeze(1).unsqueeze(3).repeat(
            1, num_depths, 1, icode_dim)
        return torch.gather(llll_word_scoreT, 2, llll_idx)

    ll_verb_loc = x_d["ll_osent_verb_loc"]
    ll_is_verb = ll_verb_loc != 0
    # relation trust, zeroed where ll_verb_loc == 0;
    # shape (batch_size, num_depths, num_verb_locs)
    lll_rel = gather_words(ll_verb_loc)[..., 2] * \
        ll_is_verb.unsqueeze(1).float()

    total = 0
    if 'hvc' in con_to_weight:
        ll_col = torch.abs(1 - lll_rel.sum(dim=1))[ll_is_verb]
        total += con_to_weight['hvc'] * ll_col.sum()
    if 'hvr' in con_to_weight:
        l_num_verbs = x_d["ll_osent_verb_bool"].sum(dim=1).float()
        l_best_rel = lll_rel.max(dim=2).values.sum(dim=1)
        total += con_to_weight['hvr'] * \
            F.relu(l_num_verbs - l_best_rel).sum()
    if 'hve' in con_to_weight:
        total += con_to_weight['hve'] * \
            F.relu(lll_rel.sum(dim=2) - 1).sum()
    if 'posm' in con_to_weight:
        ll_pos_loc = x_d["ll_osent_pos_loc"]
        lll_not_none = gather_words(ll_pos_loc)[..., 1:].max(dim=-1).values
        ll_col = (1 - lll_not_none.max(dim=1).values) * \
            (ll_pos_loc != 0).float()
        total += con_to_weight['posm'] * ll_col.sum()
    return total
def sax_get_con_to_l_penalty_loss(self,
                                  x_d,
                                  llll_word_score,
                                  lll_pred_ilabel0):
    """
    This method returns a dictionary con_to_l_penalty_loss. Although
    Openie6 calculates con_to_l_penalty_loss inside self.forward(),
    it never uses it. SentenceAx doesn't either.

    con_to_l_penalty_loss similar to Openie6._constD.

    The predicted ilabels lll_pred_ilabel0 are one-hot encoded into a
    tensor shaped like the stacked llll_word_score, restricted to the
    first lll_pred_ilabel0.shape[1] depths. The penalty loss of each
    constraint in self.con_to_weight is then evaluated separately on that
    one-hot tensor.

    Nothing is computed (an empty dict is returned) when
    self.con_to_weight is empty, when self.params.action contains
    'extract', or when self.params.d["batch_size"] == 1.

    Parameters
    ----------
    x_d: OrderedDict
    llll_word_score: list[torch.Tensor]
        one (batch_size, num_words, num_ilabels) tensor per depth
    lll_pred_ilabel0: torch.Tensor
        predicted ilabels, shape (batch_size, num_depths, num_words)

    Returns
    -------
    dict[str, list[torch.Tensor | int]]
        constraint name -> [penalty loss of that constraint alone]
    """
    con_to_l_penalty_loss = {}
    # The constraints active for this model are exactly the keys of
    # self.con_to_weight (empty when no constraints were configured).
    if not (self.con_to_weight and
            'extract' not in self.params.action and
            self.params.d["batch_size"] != 1):
        return con_to_l_penalty_loss

    # for checking test set
    # lll_ilabel = copy(lll_pred_ilabel)
    # ll_ilabel[lll_ilabel == -100] = 0
    llll_ilabel = lll_pred_ilabel0.clone().unsqueeze(-1)
    number_depths = llll_ilabel.shape[1]
    # stack the depth scores along dim=1, keep the predicted depths, and
    # replace the scores by a one-hot encoding of the predicted ilabels
    llll_word_scoreT = torch.cat(
        [lll.unsqueeze(1) for lll in llll_word_score], dim=1)
    llll_word_scoreT = torch.zeros_like(
        llll_word_scoreT[:, :number_depths, :, :])
    # positional scalar -> the `value` overload of scatter_ (the `src`
    # keyword requires a Tensor)
    llll_word_scoreT.scatter_(3, llll_ilabel.long(), 1.0)

    for constraint, con_weight in self.con_to_weight.items():
        penalty_loss = self.sax_penalty_loss(
            x_d,
            llll_word_scoreT,
            {constraint: con_weight})
        con_to_l_penalty_loss.setdefault(constraint, []).append(
            penalty_loss)
    return con_to_l_penalty_loss
def forward(self, batch, batch_idx, ttt):
"""
This method returns an instance of MOutput named batch_m_out.
batch_m_out is the output after a batch passes through all the
layers of the neural net.
signature of parent method: def forward(self, *args, **kwargs)
The following methods invoke forward() once:
training_step(), validation_step(), test_step()
lll_word_score = Openie6.word_scores
llll_word_score = Openie6.all_depth_scores (shape=(5,..))
lll_pred_ilabel0 = Openie6.predictions
llll_pred_ilabel = Openie6.all_depth_predictions
ll_pred_confidence0 = Openie6.confidences
lll_pred_confidence = Openie6.all_depth_confidences
the outermost l in "all_depths_*" is for depth \in range(5)
Many of the tensor contortions in this method are done in order to
move that depth index in llll_word_score from the outermost position
to the dim=1, where it is located in lll_ilabel. Also, we need to
get rid (by argmax) of the innermost index corresponding to the 6
possible ilabels (classes).
if A and B are of shape (3, 4):
torch.cat([A, B], dim=0) will be of shape (6, 4)
torch.stack([A, B], dim=0) will be of shape (2, 3, 4)
Parameter
----------
batch: tuple[torch.Tensor,
torch.Tensor, list[str], dict[str, list[int]]
batch_idx: int
ttt: str
Returns
-------
MOutput
batch_m_out
"""
x_d, y_d, meta_d = Model.sax_get_batch_in_dicts(batch)
# print_tensor("y_d['lll_ilabel']", y_d['lll_ilabel'])
# print_list("y_d['lll_ilabel'][0][0]", y_d['lll_ilabel'][0][0])
use_wreg = "wreg" in self.params.d and self.params.d["wreg"] != 0
if use_wreg:
# wreg=weight regulator
# name_to_param0 is self.named_parameters() when
# forward() is first called
if not self.name_to_param0:
name_to_param0 = deepcopy(
dict(self.named_parameters()))
# lll_label is similar to Openie6.labels
# first (outer) list over batch/sample of events
# second list over extractions
# third (inner) list over number of labels in a line
# after padding and adding the 3 unused tokens
# batch_size, num_depths, num_words = y_d["lll_ilabel"].shape
# `loss_fun` is not used in this function anymore
# loss_fun, lstm_loss = 0, 0
# batch_text = " ".join(redoL(meta_d["l_orig_sent"]))
# base_model_input = \
# torch.Tensor(self.auto_tokenizer.encode(batch_text))
llll_word_score = self.sax_get_llll_word_score(
x_d, ttt, self.verbose)
# print_tensor("lll_word_score", lll_word_score)
# print("vvbg", "len(llll_word_score)", len(llll_word_score))
# print_tensor("llll_word_score[0]", llll_word_score[0])
loss = 0
llll_pred_ilabel = []
lll_pred_confidence = []
batch_size, num_words, xxx = llll_word_score[0].shape
# len(llll_word_score)=5, the num_depths
# xxx = 6, the number of different ilabels (num_classes)
# y_d["lll_ilabel"] = \
# y_d["lll_ilabel"].long()
for depth, lll_word_score in enumerate(llll_word_score):
if ttt == 'train':
# Here -1 will be
# num_ilabels=6=number of classes to classify.
# In general, reshape(x, -1) means final shape = (x, y)
# where y is whatever it takes to get original num of entries
ll_loss_input = \
lll_word_score.reshape(batch_size * num_words, -1)
# print_tensor("lll_word_score", lll_word_score)
# print_tensor("ll_loss_input", ll_loss_input)
# ll_loss_input.shape = (batch_size * num_words, num_classes=6)
# = data
# l_loss_target.shape = (batch_size * num_words, )
# = theory
# l_loss_target[i] \in range(6)
# loss is scalar
l_loss_target = \
y_d["lll_ilabel"][:, depth, :num_words].reshape(-1)
loss += self.loss_fun(ll_loss_input,
l_loss_target)
# print("loss shape", loss.shape)
# print_tensor("l_loss_target", l_loss_target)
# print("loss", loss)
if ttt != "train":
lll_soft_word_score = \
torch.log_softmax(lll_word_score, dim=2)
ll_max_log_prob, ll_pred_ilabel = \
torch.max(lll_soft_word_score, dim=2)
# print_tensor("ll_max_log_prob", ll_max_log_prob)
# print_tensor("ll_pred_ilabel", ll_pred_ilabel)
# print("task=", ttt)
# print_tensor("lll_word_score", lll_word_score)
# print_tensor("lll_soft_word_score", lll_soft_word_score)
# print_tensor("ll_pred_ilabel", ll_pred_ilabel)
# print("sum(""ll_pred_ilabel=", torch.sum(ll_pred_ilabel))
# remember: lll_ilabel was similar to Openie6.labels
# first (outermost) list over batch events
# second list over extractions
# third (innermost) list over number of ilabels in a line
# print("ttt, action", ttt, self.params.action)
# print_tensor("lll_ilabel", y_d["lll_ilabel"])
# For ttt!=train, y_d["lll_ilabel"] entries are all \in [
# 0, -100] because we store that info in Carb benchmarks.
# That is fine because we only need y_d["lll_ilabel"] to
# create this template
ll_nonpad_bool = \
(y_d["lll_ilabel"][:, 0, :] != -100).float()
# print("dfrt", {name: x_d[name].shape for name in x_d.keys()})
# print_tensor("ll_nonpad_bool", ll_nonpad_bool)
# print_tensor("(ll_pred_ilabel != 0)",
# (ll_pred_ilabel != 0).float())
# * is element-wise multiplication of tensors
ll_nonpad_bool = ll_nonpad_bool[:, :ll_pred_ilabel.shape[1]]
ll_nonpad_bool = \
(ll_pred_ilabel != 0).float() * ll_nonpad_bool
ll_norm_log_prob = \
(ll_max_log_prob * ll_nonpad_bool) \
/ (1 + ll_nonpad_bool.sum(dim=0))
l_confidence = torch.exp(
torch.sum(ll_norm_log_prob, dim=1))
# this unsqueezes depth dim=1
llll_pred_ilabel.append(ll_pred_ilabel.unsqueeze(1))
lll_pred_confidence.append(l_confidence.unsqueeze(1))
# } end of loop over depth, lll_word_score
if ttt == 'train':
if self.con_to_weight:
# unsqueeze dim=1, then cat() along dim=1. This
# removes the outermost l and "fattens" dim=1
llll_word_scoreT = torch.cat(
[lll.unsqueeze(1) for lll in llll_word_score], dim=1)
llll_word_scoreT = torch.softmax(llll_word_scoreT, dim=-1)
penalty_loss = Model.sax_penalty_loss(
x_d,
llll_word_scoreT,
self.con_to_weight) / batch_size
# IMPORTANT Openie6 has
# loss = const_loss
# instead of
# loss += const_loss
loss += penalty_loss
if use_wreg:
weight_diff = 0
# name_to_param0 is self.named_parameters()
# when forward() is first called
name_to_param = dict(self.named_parameters())
for name in name_to_param0:
weight_diff += torch.linalg.vector_norm(
name_to_param[name] - name_to_param0[name])
loss += self.params.d["wreg"] * weight_diff
# llll_pred_ilabel: list[tensor]
# lll_pred_confidence: list[tensor]
if ttt != "train":
# unsqueeze dim=1, then cat along dim=1. This
# removes the outermost l and "fattens" dim=1
lll_pred_ilabel0 = torch.cat(llll_pred_ilabel, dim=1)
ll_pred_confidence0 = torch.cat(lll_pred_confidence, dim=1)
# never used
# self.con_to_l_loss = self.sax_get_con_to_l_loss(
# x_d,
# llll_word_scoreT,
# lll_pred_ilabel0)
assert loss == 0
else:
lll_pred_ilabel0 = Ten([0])
ll_pred_confidence0 = Ten([0])
batch_m_out = MOutput(meta_d["l_orig_sent"],
y_d["lll_ilabel"],
lll_pred_ilabel0,
ll_pred_confidence0,
loss)