From 28c181e1b3a93cc0d5e4c615e9178c36206dc1a5 Mon Sep 17 00:00:00 2001 From: CRI USER Date: Thu, 22 May 2025 15:35:39 +0200 Subject: [PATCH 1/2] test remove cache --- src/app/services/search.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index 20de922..a298d43 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,6 +1,6 @@ import json import time -from functools import cache +# from functools import lru_cache as cache from typing import Tuple, cast import numpy as np @@ -118,7 +118,7 @@ def get_query_embed( return embedding - @cache + # @cache def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: try: time_start = time.time() @@ -135,7 +135,7 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: raise ModelNotFoundError() return (model.get_max_seq_length(), model) - @cache + # @cache def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: if not seq_len: raise ValueError("Sequence length value is not valid") @@ -158,7 +158,7 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: return inputs - @cache + # @cache def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) time_start = time.time() From 4954c25740f26822d8575117d483ad429e954339 Mon Sep 17 00:00:00 2001 From: CRI USER Date: Thu, 22 May 2025 16:18:08 +0200 Subject: [PATCH 2/2] use async io gather instead of as completed --- src/app/services/search.py | 2 +- src/app/services/search_helpers.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/app/services/search.py b/src/app/services/search.py index a298d43..d7d2159 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -136,7 +136,7 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: return (model.get_max_seq_length(), model) # @cache - def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: + def _split_input_seq_len(self, seq_len: int | None, input: str) -> list[str]: if not seq_len: raise ValueError("Sequence length value is not valid") diff --git a/src/app/services/search_helpers.py b/src/app/services/search_helpers.py index cc5563b..303f7cd 100644 --- a/src/app/services/search_helpers.py +++ b/src/app/services/search_helpers.py @@ -29,11 +29,11 @@ async def search_multi_inputs( for qp in qps ] + data = await asyncio.gather(*tasks) all_data: list[ScoredPoint] = [] - for coroutine in asyncio.as_completed(tasks): - data = await coroutine - if data: - all_data.extend(data) + + for sublist in data: + all_data.extend(sublist) return all_data except Exception as e: