From c737b4bb2bf237bf8d8621dbc67e923fd53e869b Mon Sep 17 00:00:00 2001 From: Jack McKechnie <46594345+JackMcKechnie@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:03:38 +0000 Subject: [PATCH] Refactor load_state_dict method to accept state_dict directly Refactor load_state_dict method to accept state_dict directly and handle string input. --- rankers/modelling/base.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/rankers/modelling/base.py b/rankers/modelling/base.py index c4557a5..d519b93 100644 --- a/rankers/modelling/base.py +++ b/rankers/modelling/base.py @@ -163,18 +163,12 @@ def save_pretrained(self, model_dir, **kwargs): self.model.save_pretrained(model_dir) self.tokenizer.save_pretrained(model_dir) - def load_state_dict(self, model_dir): - """Load model state dictionary from a directory. - - Args: - model_dir (str): Directory containing the saved model. - - Returns: - dict: Result of loading the state dictionary. - """ - return self.model.load_state_dict( - self.architecture_class.from_pretrained(model_dir).state_dict() - ) + def load_state_dict(self, state_dict, strict=False, **kwargs): + if isinstance(state_dict, str): + loaded = self.architecture_class.from_pretrained(state_dict).state_dict() + return self.model.load_state_dict(loaded, strict=strict) + + return self.model.load_state_dict(state_dict, strict=strict) def to_pyterrier(self, batch_size=None): """Convert the ranker to a PyTerrier transformer.