From 99a22eb15ea224c727c4e7a57bae0b02e55b286e Mon Sep 17 00:00:00 2001 From: "tiffany.duneau" Date: Tue, 29 Jul 2025 13:22:42 +0100 Subject: [PATCH 1/5] add learning rate schedule option to pytorch trainer. --- lambeq/training/pytorch_trainer.py | 23 ++++++++++++++++++++++- lambeq/training/trainer.py | 7 +++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/lambeq/training/pytorch_trainer.py b/lambeq/training/pytorch_trainer.py index 3d8bf100..705aa95a 100644 --- a/lambeq/training/pytorch_trainer.py +++ b/lambeq/training/pytorch_trainer.py @@ -42,10 +42,12 @@ def __init__(self, loss_function: Callable[..., torch.Tensor], epochs: int, optimizer: type[torch.optim.Optimizer] = torch.optim.AdamW, + scheduler: type[torch.optim.lr_scheduler] | None = None, learning_rate: float = 1e-3, device: int = -1, *, optimizer_args: dict[str, Any] | None = None, + scheduler_args: dict[str, Any] | None = None, evaluate_functions: Mapping[str, EvalFuncT] | None = None, evaluate_on_train: bool = True, use_tensorboard: bool = False, @@ -66,6 +68,9 @@ def __init__(self, Number of training epochs. optimizer : torch.optim.Optimizer, default: torch.optim.AdamW A PyTorch optimizer from `torch.optim`. + scheduler : torch.optim.lr_scheduler, default: None + A PyTorch scheduler for the learning rate, from + `torch.optim.lr_scheduler`. learning_rate : float, default: 1e-3 The learning rate provided to the optimizer for training. device : int, default: -1 @@ -73,6 +78,8 @@ def __init__(self, A negative value uses the CPU. optimizer_args : dict of str to Any, optional Any extra arguments to pass to the optimizer. + scheduler_args : dict of str to Any, optional + Any extra arguments to pass to the scheduler. evaluate_functions : mapping of str to callable, optional Mapping of evaluation metric functions from their names. Structure [{"metric": func}]. @@ -118,6 +125,13 @@ def __init__(self, if learning_rate is not None: optimizer_args['lr'] = learning_rate self.optimizer = optimizer(self.model.parameters(), **optimizer_args) + + scheduler_args = dict(scheduler_args or {}) + if scheduler is not None: + self.scheduler = scheduler(self.optimizer, **scheduler_args) + else: + self.scheduler = None + self.model.to(self.device) def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: @@ -136,7 +150,8 @@ def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: """ checkpoint.add_many( {'torch_random_state': torch.get_rng_state(), - 'optimizer_state_dict': self.optimizer.state_dict()}) + 'optimizer_state_dict': self.optimizer.state_dict(), + 'scheduler_state_dict': self.scheduler.state_dict() or None}) def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: """Load additional checkpoint information. @@ -152,6 +167,8 @@ def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: """ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) torch.set_rng_state(checkpoint['torch_random_state']) + if checkpoint['scheduler_state_dict'] is not None: + self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) def validation_step( self, @@ -201,3 +218,7 @@ def training_step( loss.backward() self.optimizer.step() return y_hat, loss.item() + + def post_epoch_step(self, epoch): + # Step the scheduler if present + self.scheduler.step(epoch=epoch) diff --git a/lambeq/training/trainer.py b/lambeq/training/trainer.py index 2e426dc0..d1294ca5 100644 --- a/lambeq/training/trainer.py +++ b/lambeq/training/trainer.py @@ -368,6 +368,11 @@ def validation_step( """ + def post_epoch_step(self, epoch: int): + """Perform any post-epoch updates, such as updating the scheduled + learning rate.""" + pass + def _get_weighted_mean(self, metric_running: list[tuple[int, Any]]): """Calculate weighted mean of metric from the running results.""" @@ -761,6 +766,8 @@ def fit(self, if early_stopping: break # inner epoch loop + self.post_epoch_step(epoch) + epoch_end = time.time() epoch_duration = epoch_end - epoch_start self.train_epoch_durations.append(epoch_duration) From 8fa6d469f8a49a02e2ce96ba499a271e6c357cf2 Mon Sep 17 00:00:00 2001 From: "tiffany.duneau" Date: Tue, 29 Jul 2025 13:33:44 +0100 Subject: [PATCH 2/5] debug and add test --- lambeq/training/pytorch_trainer.py | 7 +++++-- tests/training/test_pytorch_trainer.py | 12 +++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/lambeq/training/pytorch_trainer.py b/lambeq/training/pytorch_trainer.py index 705aa95a..ef380207 100644 --- a/lambeq/training/pytorch_trainer.py +++ b/lambeq/training/pytorch_trainer.py @@ -151,7 +151,9 @@ def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: checkpoint.add_many( {'torch_random_state': torch.get_rng_state(), 'optimizer_state_dict': self.optimizer.state_dict(), - 'scheduler_state_dict': self.scheduler.state_dict() or None}) + 'scheduler_state_dict': self.scheduler.state_dict() + if self.scheduler is not None + else None}) def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: """Load additional checkpoint information. @@ -221,4 +223,5 @@ def training_step( def post_epoch_step(self, epoch): # Step the scheduler if present - self.scheduler.step(epoch=epoch) + if self.scheduler is not None: + self.scheduler.step(epoch=epoch) diff --git a/tests/training/test_pytorch_trainer.py b/tests/training/test_pytorch_trainer.py index 5ab55212..9837cc23 100644 --- a/tests/training/test_pytorch_trainer.py +++ b/tests/training/test_pytorch_trainer.py @@ -3,6 +3,7 @@ import torch import numpy as np +import pytest from lambeq.backend.grammar import Cup, Id, Word from lambeq.backend.tensor import Dim @@ -36,7 +37,8 @@ dev_circuits = [ansatz(d) for d in dev_diagrams] -def test_trainer(tmp_path): +@pytest.mark.parametrize("scheduler", [None, torch.optim.lr_scheduler.StepLR]) +def test_trainer(tmp_path, scheduler): model = PytorchModel.from_diagrams(train_circuits + dev_circuits) log_dir = tmp_path / 'test_runs' @@ -45,6 +47,8 @@ def test_trainer(tmp_path): model=model, loss_function=torch.nn.BCEWithLogitsLoss(), optimizer=torch.optim.AdamW, + scheduler=scheduler, + scheduler_args={"step_size": 1, "gamma": 0.1}, learning_rate=3e-3, epochs=EPOCHS, evaluate_functions={"acc": acc}, @@ -66,6 +70,12 @@ def test_trainer(tmp_path): assert len(trainer.train_durations) == EPOCHS * ( ceil(len(train_diagrams) / train_dataset.batch_size)) assert len(trainer.val_durations) == EPOCHS + if scheduler is not None: + # Expect lr to have decayed exactly once, up to float error = 0.1 * 3e-3 + assert torch.allclose( + torch.tensor(trainer.scheduler.get_last_lr()), + torch.tensor(3e-4) + ) def test_restart_training(tmp_path): From 3de35e3f045725bb1ac83b823853bcf6f03dac0b Mon Sep 17 00:00:00 2001 From: "tiffany.duneau" Date: Tue, 29 Jul 2025 13:41:05 +0100 Subject: [PATCH 3/5] linting --- lambeq/training/pytorch_trainer.py | 4 ++-- lambeq/training/trainer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lambeq/training/pytorch_trainer.py b/lambeq/training/pytorch_trainer.py index ef380207..8efffc8f 100644 --- a/lambeq/training/pytorch_trainer.py +++ b/lambeq/training/pytorch_trainer.py @@ -152,8 +152,8 @@ def _add_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: {'torch_random_state': torch.get_rng_state(), 'optimizer_state_dict': self.optimizer.state_dict(), 'scheduler_state_dict': self.scheduler.state_dict() - if self.scheduler is not None - else None}) + if self.scheduler is not None + else None}) def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: """Load additional checkpoint information. diff --git a/lambeq/training/trainer.py b/lambeq/training/trainer.py index d1294ca5..70cdf7cb 100644 --- a/lambeq/training/trainer.py +++ b/lambeq/training/trainer.py @@ -371,7 +371,7 @@ def validation_step( def post_epoch_step(self, epoch: int): """Perform any post-epoch updates, such as updating the scheduled learning rate.""" - pass + return None def _get_weighted_mean(self, metric_running: list[tuple[int, Any]]): From 8d7e9b744d7e17bf9293e56319f340a0ba6fd195 Mon Sep 17 00:00:00 2001 From: "tiffany.duneau" Date: Tue, 29 Jul 2025 13:45:06 +0100 Subject: [PATCH 4/5] linting 2 --- lambeq/training/pytorch_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/lambeq/training/pytorch_trainer.py b/lambeq/training/pytorch_trainer.py index 8efffc8f..d2f2f4f9 100644 --- a/lambeq/training/pytorch_trainer.py +++ b/lambeq/training/pytorch_trainer.py @@ -42,7 +42,8 @@ def __init__(self, loss_function: Callable[..., torch.Tensor], epochs: int, optimizer: type[torch.optim.Optimizer] = torch.optim.AdamW, - scheduler: type[torch.optim.lr_scheduler] | None = None, + scheduler: type[torch.optim.lr_scheduler.LRScheduler] | None + = None, learning_rate: float = 1e-3, device: int = -1, *, @@ -127,10 +128,8 @@ def __init__(self, self.optimizer = optimizer(self.model.parameters(), **optimizer_args) scheduler_args = dict(scheduler_args or {}) - if scheduler is not None: - self.scheduler = scheduler(self.optimizer, **scheduler_args) - else: - self.scheduler = None + self.scheduler = (scheduler(self.optimizer, **scheduler_args) + if scheduler is not None else None) self.model.to(self.device) @@ -169,7 +168,8 @@ def _load_extra_checkpoint_info(self, checkpoint: Checkpoint) -> None: """ self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']) torch.set_rng_state(checkpoint['torch_random_state']) - if checkpoint['scheduler_state_dict'] is not None: + if (checkpoint['scheduler_state_dict'] is not None + and self.scheduler is not None): self.scheduler.load_state_dict(checkpoint['scheduler_state_dict']) def validation_step( From 75684333c249fd9b303c72cc7eba52d5c9a7428d Mon Sep 17 00:00:00 2001 From: "tiffany.duneau" Date: Tue, 29 Jul 2025 14:13:05 +0100 Subject: [PATCH 5/5] add loss kwarg to scheduler --- lambeq/training/pytorch_trainer.py | 9 +++++++-- lambeq/training/trainer.py | 8 ++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/lambeq/training/pytorch_trainer.py b/lambeq/training/pytorch_trainer.py index d2f2f4f9..56e4ec23 100644 --- a/lambeq/training/pytorch_trainer.py +++ b/lambeq/training/pytorch_trainer.py @@ -221,7 +221,12 @@ def training_step( self.optimizer.step() return y_hat, loss.item() - def post_epoch_step(self, epoch): + def post_epoch_step(self, epoch: int, loss: float): # Step the scheduler if present if self.scheduler is not None: - self.scheduler.step(epoch=epoch) + # Plateau scheduler wants to know about the loss + if isinstance(self.scheduler, + torch.optim.lr_scheduler.ReduceLROnPlateau): + self.scheduler.step(loss, epoch=epoch) + else: + self.scheduler.step(epoch=epoch) diff --git a/lambeq/training/trainer.py b/lambeq/training/trainer.py index 70cdf7cb..8bfbe8bd 100644 --- a/lambeq/training/trainer.py +++ b/lambeq/training/trainer.py @@ -368,7 +368,7 @@ def validation_step( """ - def post_epoch_step(self, epoch: int): + def post_epoch_step(self, epoch: int, loss: float): """Perform any post-epoch updates, such as updating the scheduled learning rate.""" return None @@ -766,15 +766,15 @@ def fit(self, if early_stopping: break # inner epoch loop - self.post_epoch_step(epoch) + epoch_loss = self._get_weighted_mean(train_losses) + self.post_epoch_step(epoch, loss=epoch_loss) epoch_end = time.time() epoch_duration = epoch_end - epoch_start self.train_epoch_durations.append(epoch_duration) # calculate epoch loss - self.train_epoch_costs.append( - self._get_weighted_mean(train_losses)) + self.train_epoch_costs.append(epoch_loss) self._to_tensorboard('train/epoch_loss', self.train_epoch_costs[-1], epoch)