diff --git a/src/engine.py b/src/engine.py index 565a35e..e7d75d5 100644 --- a/src/engine.py +++ b/src/engine.py @@ -7,10 +7,14 @@ import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader -from tqdm import tqdm def train_one_epoch(model, loader: DataLoader, opt, loss_fn=None) -> float: + """Standard training epoch. tqdm is imported lazily so that callers that + only need `distillation_loss` or `benchmark_latency` (e.g. the unit-test + path) do not pay the tqdm install cost.""" + from tqdm import tqdm # noqa: PLC0415 + model.train() if loss_fn is None: loss_fn = nn.CrossEntropyLoss()