diff --git a/src/app/services/search.py b/src/app/services/search.py index 20de922..d7d2159 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,6 +1,6 @@ import json import time -from functools import cache +# from functools import lru_cache as cache from typing import Tuple, cast import numpy as np @@ -118,7 +118,7 @@ def get_query_embed( return embedding - @cache + # @cache def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: try: time_start = time.time() @@ -135,8 +135,8 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: raise ModelNotFoundError() return (model.get_max_seq_length(), model) - @cache - def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: + # @cache + def _split_input_seq_len(self, seq_len: int | None, input: str) -> list[str]: if not seq_len: raise ValueError("Sequence length value is not valid") @@ -158,7 +158,7 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: return inputs - @cache + # @cache def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) time_start = time.time() diff --git a/src/app/services/search_helpers.py b/src/app/services/search_helpers.py index cc5563b..303f7cd 100644 --- a/src/app/services/search_helpers.py +++ b/src/app/services/search_helpers.py @@ -29,11 +29,11 @@ async def search_multi_inputs( for qp in qps ] + data = await asyncio.gather(*tasks) all_data: list[ScoredPoint] = [] - for coroutine in asyncio.as_completed(tasks): - data = await coroutine - if data: - all_data.extend(data) + + for sublist in data: + all_data.extend(sublist) return all_data except Exception as e: