diff --git a/run.py b/run.py index 2906b29..798a9e2 100644 --- a/run.py +++ b/run.py @@ -345,7 +345,7 @@ def mask_tokens(inputs, tokenizer, mlm_probability=0.15): probability_matrix.masked_fill_(torch.tensor(labels == 0, dtype=torch.bool), value=0.0) masked_indices = torch.bernoulli(probability_matrix).bool() - labels[~masked_indices] = -1 # We only compute loss on masked tokens + labels[~masked_indices] = -100 # We only compute loss on masked tokens # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices