Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/app/services/search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import time
from functools import cache
# from functools import lru_cache as cache
from typing import Tuple, cast

import numpy as np
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_query_embed(

return embedding

@cache
# @cache
def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]:
try:
time_start = time.time()
Expand All @@ -135,8 +135,8 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]:
raise ModelNotFoundError()
return (model.get_max_seq_length(), model)

@cache
def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]:
# @cache
def _split_input_seq_len(self, seq_len: int | None, input: str) -> list[str]:
if not seq_len:
raise ValueError("Sequence length value is not valid")

Expand All @@ -158,7 +158,7 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]:

return inputs

@cache
# @cache
def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray:
logger.debug("Creating embeddings model=%s", curr_model)
time_start = time.time()
Expand Down
8 changes: 4 additions & 4 deletions src/app/services/search_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ async def search_multi_inputs(
for qp in qps
]

data = await asyncio.gather(*tasks)
all_data: list[ScoredPoint] = []
for coroutine in asyncio.as_completed(tasks):
data = await coroutine
if data:
all_data.extend(data)

for sublist in data:
all_data.extend(sublist)

return all_data
except Exception as e:
Expand Down