Lazily load the CrossEncoder in lotus/models/cross_encoder_reranker.py.

The module-level `sentence_transformers` import is removed. `__init__` now only
records the model name and device, and the new `model` property imports
`CrossEncoder` and constructs it on first access, then returns the cached
instance on later accesses.

diff --git a/lotus/models/cross_encoder_reranker.py b/lotus/models/cross_encoder_reranker.py
index a177ef22..a506660b 100644
--- a/lotus/models/cross_encoder_reranker.py
+++ b/lotus/models/cross_encoder_reranker.py
@@ -1,4 +1,3 @@
-from sentence_transformers import CrossEncoder
 
 from lotus.models.reranker import Reranker
 from lotus.types import RerankerOutput
@@ -20,7 +19,19 @@ def __init__(
         max_batch_size: int = 64,
     ):
         self.max_batch_size: int = max_batch_size
-        self.model = CrossEncoder(model, device=device)  # type: ignore # CrossEncoder has wrong type stubs
+        self._model_name = model
+        self._device = device
+        self._model = None  # Initialize model as None for lazy loading
+
+    @property
+    def model(self):
+        """Lazy load the model when it's first accessed."""
+        if self._model is None:
+            # Only import CrossEncoder when needed
+            from sentence_transformers import CrossEncoder
+
+            self._model = CrossEncoder(self._model_name, device=self._device)  # type: ignore # CrossEncoder has wrong type stubs
+        return self._model
 
     def __call__(self, query: str, docs: list[str], K: int) -> RerankerOutput:
         results = self.model.rank(query, docs, top_k=K, batch_size=self.max_batch_size, show_progress_bar=False)