Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 47 additions & 3 deletions ro_diacritics/diacritcs_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Training, evaluation and prediction routines for the diacritics model."""

from pathlib import Path

import numpy as np
Expand All @@ -17,7 +19,21 @@ def train(
epochs=10,
checkpoint_file=None,
):

"""Train *model* with early stopping and periodic validation.

Saves the best checkpoint (by validation accuracy) to *checkpoint_file*
whenever a new best is found. Training stops early when validation
accuracy has not improved for *patience* consecutive evaluation steps.

:param model: :class:`~diacritics_model.Diacritics` model to train.
:param loss_func: Loss criterion (e.g. :class:`torch.nn.BCEWithLogitsLoss`).
:param train_dataloader: :class:`~torch.utils.data.DataLoader` for the
training set.
:param valid_dataloader: :class:`~torch.utils.data.DataLoader` for the
validation set, or ``None`` to skip validation.
:param epochs: Maximum number of training epochs.
:param checkpoint_file: Path where the best model checkpoint is saved.
"""
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = next(model.parameters()).device
print(f"{device} device for training")
Expand All @@ -36,6 +52,17 @@ def train(
patience = 3

def evaluate_step(step, epoch, running_loss, max_eval_steps=None):
"""Run one evaluation pass on the validation set and update history.

Saves a new checkpoint when validation accuracy improves. Increments
the non-improving counter otherwise (used for early stopping).

:param step: Current training step within the epoch.
:param epoch: Current epoch index (0-based).
:param running_loss: Cumulative training loss up to *step*.
:param max_eval_steps: Cap on the number of validation batches to
process (``None`` means evaluate all).
"""
nonlocal best_acc, best_acc_epoch, nr_non_improving
if valid_dataloader is None:
return
Expand Down Expand Up @@ -142,7 +169,17 @@ def evaluate_step(step, epoch, running_loss, max_eval_steps=None):


def evaluate(model, dataloader: DataLoader, loss_func, epoch=None, max_eval_steps=None):
# print("***** Running prediction *****")
"""Evaluate *model* on *dataloader* and return loss, accuracy and F1.

:param model: :class:`~diacritics_model.Diacritics` model to evaluate.
:param dataloader: :class:`~torch.utils.data.DataLoader` for the
evaluation set.
:param loss_func: Loss criterion matching the one used during training.
:param epoch: Current epoch index for logging (``None`` = test evaluation).
:param max_eval_steps: Maximum number of batches to process before
stopping early (``None`` = full evaluation).
:return: Tuple of ``(eval_loss, eval_acc, f1_metrics)``.
"""
model.eval()
predict_out = []
all_label_ids = []
Expand Down Expand Up @@ -203,7 +240,14 @@ def evaluate(model, dataloader: DataLoader, loss_func, epoch=None, max_eval_step


def predict(model, dataloader: DataLoader):
# print("***** Running prediction *****")
"""Run inference with *model* over *dataloader* and return softmax scores.

:param model: :class:`~diacritics_model.Diacritics` model in eval mode.
:param dataloader: :class:`~torch.utils.data.DataLoader` yielding input
triples ``(char_input, word_emb, sentence_emb)``.
:return: List of per-sample softmax probability vectors (one list of
floats per sample, length = number of classes).
"""
model.eval()
predict_out = []

Expand Down
Loading