diff --git a/rankers/modelling/base.py b/rankers/modelling/base.py index c4557a5..d519b93 100644 --- a/rankers/modelling/base.py +++ b/rankers/modelling/base.py @@ -163,18 +163,12 @@ def save_pretrained(self, model_dir, **kwargs): self.model.save_pretrained(model_dir) self.tokenizer.save_pretrained(model_dir) - def load_state_dict(self, model_dir): - """Load model state dictionary from a directory. - - Args: - model_dir (str): Directory containing the saved model. - - Returns: - dict: Result of loading the state dictionary. - """ - return self.model.load_state_dict( - self.architecture_class.from_pretrained(model_dir).state_dict() - ) + def load_state_dict(self, state_dict, strict=False, **kwargs): + if isinstance(state_dict, str): + loaded = self.architecture_class.from_pretrained(state_dict).state_dict() + return self.model.load_state_dict(loaded, strict=strict) + + return self.model.load_state_dict(state_dict, strict=strict) def to_pyterrier(self, batch_size=None): """Convert the ranker to a PyTerrier transformer.