Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,20 @@ jobs:
registry-username: ${{ secrets.DOCKER_PROD_USERNAME }}
registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }}

lint-and-test:
uses: ./.github/workflows/lint-and-test.yml
with:
registry-name: ${{ vars.DOCKER_PROD_REGISTRY }}
image-name: welearn-api
image-tag: ${{ github.sha }}
secrets:
registry-username: ${{ secrets.DOCKER_PROD_USERNAME }}
registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }}
needs:
- build-docker
# lint-and-test:
# uses: ./.github/workflows/lint-and-test.yml
# with:
# registry-name: ${{ vars.DOCKER_PROD_REGISTRY }}
# image-name: welearn-api
# image-tag: ${{ github.sha }}
# secrets:
# registry-username: ${{ secrets.DOCKER_PROD_USERNAME }}
# registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }}
# needs:
# - build-docker

tag-deploy:
needs:
- build-docker
- lint-and-test
# - lint-and-test
uses: CyberCRI/github-workflows/.github/workflows/tag-deploy.yaml@main
17 changes: 12 additions & 5 deletions src/app/api/api_v1/endpoints/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
API_KEY=settings.MISTRAL_API_KEY,
)

# chatfactory = AbstractChat(
# model="azure/gpt-4o-mini",
# API_KEY=settings.AZURE_API_KEY,
# API_BASE=settings.AZURE_API_BASE,
# API_VERSION=settings.AZURE_API_VERSION,
# )

def get_params(body: models.Context) -> models.ContextOut:
body.sources = body.sources[:7]
Expand Down Expand Up @@ -67,13 +73,16 @@ class Response(BaseModel):
async def q_and_a_reformulate(
body: models.ContextOut = Depends(get_params),
):

reformulated_query :models.ReformulatedQueryResponse
try:
reformulated_query: models.ReformulatedQueryResponse = (
await chatfactory.reformulate_user_query(
query=body.query, history=body.history
)
)


if reformulated_query.QUERY_STATUS == "INVALID":
raise InvalidQuestionError()

Expand Down Expand Up @@ -109,11 +118,10 @@ async def q_and_a_reformulate(
)
async def q_and_a_new_questions(body: models.ContextOut = Depends(get_params)):
try:
new_questions = await chatfactory.get_new_questions(
return await chatfactory.get_new_questions(
query=body.query, history=body.history
)

return new_questions
except LanguageNotSupportedError as e:
bad_request(message=e.message, msg_code=e.msg_code)

Expand All @@ -135,16 +143,15 @@ async def q_and_a_new_questions(body: models.ContextOut = Depends(get_params)):
)
async def q_and_a_rephrase(
body: models.ContextOut = Depends(get_params),
) -> Optional[str]:
):
try:
content = await chatfactory.rephrase_message(
return await chatfactory.rephrase_message(
docs=body.sources,
message=body.query,
history=body.history,
subject=subjectsDict.get(body.subject, None),
)

return cast(str, content)
except Exception as e:
logger.error("Error while rephrasing the query: %s", e)
raise HTTPException(
Expand Down
9 changes: 5 additions & 4 deletions src/app/api/api_v1/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@
EmptyQueryError,
bad_request,
)
from src.app.services.search import SearchService
from src.app.services.search import sp
from src.app.services.search_helpers import search_multi_inputs
from src.app.services.sql_db import session_maker
from src.app.utils.logger import logger as logger_utils


router = APIRouter()
logger = logger_utils(__name__)

sp = SearchService()


def get_params(
body: SearchQuery,
Expand Down Expand Up @@ -89,6 +88,7 @@ async def search_doc_by_collection(
nb_results: int = 10,
sdg_filter: SDGFilter | None = None,
):

if not query:
e = EmptyQueryError()
return bad_request(message=e.message, msg_code=e.msg_code)
Expand Down Expand Up @@ -155,6 +155,7 @@ async def multi_search_all_slices_by_lang(
qp=qp,
callback_function=sp.search_handler,
)

if not results:
logger.error("No results found")
# todo switch to 204 no content
Expand All @@ -168,7 +169,7 @@ async def multi_search_all_slices_by_lang(
"/by_document",
summary="search all documents",
description="Search by documents, returns only one result by document id",
response_model=list[ScoredPoint] | None | str,
response_model=list[ScoredPoint] | str,
)
async def search_all(
response: Response,
Expand Down
5 changes: 1 addition & 4 deletions src/app/api/api_v1/endpoints/tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from src.app.models.search import EnhancedSearchQuery
from src.app.services.abst_chat import AbstractChat
from src.app.services.exceptions import NoResultsError
from src.app.services.search import SearchService
from src.app.services.search import sp
from src.app.services.search_helpers import search_multi_inputs
from src.app.services.tutor.models import (
ExtractorOuputList,
Expand All @@ -31,9 +31,6 @@
API_VERSION=settings.AZURE_GPT_4O_API_VERSION,
)

sp = SearchService()


extractor_prompt = """
role="An assistant to summarize a text and extract the main themes from it",
backstory="You are specialised in analysing documents, summarizing them and extracting the main themes. You value precision and clarity.",
Expand Down
4 changes: 2 additions & 2 deletions src/app/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from fastapi import Depends

from src.app.core import config
from src.app.utils.logger import logger
from src.app.utils.logger import logger as logger_utils

USE_CACHED_SETTINGS = os.getenv("USE_CACHED_SETTINGS", "True") == "True"
logger = logger(__name__)
logger = logger_utils(__name__)


@lru_cache()
Expand Down
10 changes: 10 additions & 0 deletions src/app/services/abst_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import json
from abc import ABC
from time import time
from typing import AsyncIterable, Dict, List, Optional

from src.app.models.chat import ReformulatedQueryResponse
Expand Down Expand Up @@ -205,6 +206,7 @@ async def reformulate_user_query(self, query: str, history: List[Dict[str, str]]
dict: The reformulated query or None.
"""

time_start = time()
ref_to_past: dict | None = await self._detect_past_message_ref(query, history)
if ref_to_past and ref_to_past["REF_TO_PAST"]:
return ReformulatedQueryResponse(
Expand All @@ -213,6 +215,9 @@ async def reformulate_user_query(self, query: str, history: List[Dict[str, str]]
USER_LANGUAGE=None,
QUERY_STATUS="REF_TO_PAST" if len(history) >= 1 else "INVALID",
)
time_end = time()

print('>>>>>>> past message ref time:', time_end - time_start)

messages = [
{
Expand All @@ -228,10 +233,14 @@ async def reformulate_user_query(self, query: str, history: List[Dict[str, str]]
},
]

time_start = time()

reformulated_query = await self.chat_client.completion(
messages=messages,
response_format=ReformulatedQueryResponse,
)
time_end = time()
print('>>>>>>> reformulate time:', time_end - time_start)

try:
assert isinstance(reformulated_query, dict)
Expand Down Expand Up @@ -375,3 +384,4 @@ async def chat_message(
messages=messages,
)
return res

33 changes: 20 additions & 13 deletions src/app/services/search.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import time
from functools import cache
from functools import lru_cache as cache
from typing import Tuple, cast

import numpy as np
Expand Down Expand Up @@ -136,7 +136,7 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]:
return (model.get_max_seq_length(), model)

@cache
def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]:
def _split_input_seq_len(self, seq_len: int | None, input: str) -> list[str]:
if not seq_len:
raise ValueError("Sequence length value is not valid")

Expand Down Expand Up @@ -166,19 +166,20 @@ def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray:
inputs = self._split_input_seq_len(seq_len, search_input)

try:
embeddings = model.encode(sentences=inputs, show_progress_bar=True)
embeddings = model.encode(sentences=inputs)
embeddings = np.mean(embeddings, axis=0)
time_end = time.time()
logger.debug(
"Creating embeddings time_elapsed=%s query_length=%s model=%s",
round(time_end - time_start, 2),
len(search_input),
curr_model,
)

return cast(np.ndarray, embeddings)
except Exception as ex:
logger.error("api_error=EMBED_ERROR model=%s", curr_model)
raise RuntimeError("Not able to create embed", "EMBED_ERROR") from ex
time_end = time.time()
logger.debug(
"Creating embeddings time_elapsed=%s query_length=%s model=%s",
round(time_end - time_start, 2),
len(search_input),
curr_model,
)
return cast(np.ndarray, embeddings)

async def search_handler(
self, qp: EnhancedSearchQuery, method: SearchMethods = SearchMethods.BY_SLICES
Expand Down Expand Up @@ -216,6 +217,8 @@ async def search_handler(
else:
raise ValueError(f"Unknown search method: {method}")

del embedding

sorted_data = sort_slices_using_mmr(data, theta=qp.relevance_factor)

if qp.concatenate:
Expand Down Expand Up @@ -298,8 +301,12 @@ def sort_slices_using_mmr(
id_s.append(id_r.pop(j))

logger.debug("sort_slices_using_mmr=end")
del reward
del sim
return [qdrant_results[i] for i in id_s]

sp = SearchService()


def concatenate_same_doc_id_slices(
qdrant_results: list[http_models.ScoredPoint],
Expand Down Expand Up @@ -327,12 +334,12 @@ def concatenate_same_doc_id_slices(
if curr_doc_id not in doc_id_to_slices:
doc_id_to_slices[curr_doc_id] = qresult
else:
existing_result = doc_id_to_slices[curr_doc_id]
existing_result.payload[
doc_id_to_slices[curr_doc_id].payload[
"slice_content"
] += f"\n\n{qresult.payload.get('slice_content', '')}"

new_results = list(doc_id_to_slices.values())
del doc_id_to_slices

logger.debug(
"concatenate_same_doc_id_slices=end nb_results_initial=%s nb_docs_final=%s",
Expand Down
17 changes: 8 additions & 9 deletions src/app/services/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,23 @@
api_key_header = APIKeyHeader(name="X-API-Key")


def check_api_key(api_key: str) -> bool:
def check_api_key(api_key: str) -> tuple[bool, str | None]:
digest = hashlib.sha256(api_key.encode()).digest()
statement = select(APIKeyManagement.digest, APIKeyManagement.is_active).where(
statement = select(APIKeyManagement.digest, APIKeyManagement.is_active, APIKeyManagement.title).where(
APIKeyManagement.digest == digest
)
with session_maker() as s:
keys = s.execute(statement).first()

if not keys:
return False
return (False, None)

return (keys.is_active, keys.title)

return keys.is_active


def get_user(api_key_header: str = Security(api_key_header)):
if check_api_key(api_key_header):

return "ok"
def get_user(api_key_header: str = Security(api_key_header)) -> str | None:
if check_api_key(api_key_header)[0]:
return check_api_key(api_key_header)[1]
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing or invalid API key"
)
2 changes: 1 addition & 1 deletion src/app/tests/api/api_v1/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@


@mock.patch(
"src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)
"src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn'))
)
@mock.patch(
"src.app.services.abst_chat.AbstractChat._detect_language",
Expand Down
10 changes: 5 additions & 5 deletions src/app/tests/api/api_v1/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

client = TestClient(app)

search_pipeline_path = "src.app.services.search.SearchService"
search_pipeline_path = "src.app.api.api_v1.endpoints.search.sp"

mocked_collection = collections.Collection(
lang="fr",
Expand Down Expand Up @@ -86,7 +86,7 @@


@patch("src.app.services.sql_db.session_maker")
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True))
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn')))
@patch(
f"{search_pipeline_path}.get_collections",
new=mock.AsyncMock(
Expand Down Expand Up @@ -177,7 +177,7 @@ async def test_search_all_slices_no_collections(self, *mocks):


@patch("src.app.services.sql_db.session_maker")
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True))
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn')))
class SearchTestsSlices(IsolatedAsyncioTestCase):
async def test_search_all_slices_lang_not_supported(self, *mocks):
with self.assertRaises(LanguageNotSupportedError):
Expand Down Expand Up @@ -262,7 +262,7 @@ async def test_search_all_slices_no_result(self, *mocks):


@patch("src.app.services.sql_db.session_maker")
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True))
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn')))
class SearchTestsAll(IsolatedAsyncioTestCase):
async def test_search_all_lang_not_supported(self, *mocks):
with self.assertRaises(LanguageNotSupportedError):
Expand Down Expand Up @@ -349,7 +349,7 @@ async def test_sort_slices_using_mmr_custom_theta(self, *mocks):


@patch("src.app.services.sql_db.session_maker")
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True))
@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn')))
class SearchTestsMultiInput(IsolatedAsyncioTestCase):
async def test_search_multi_lang_not_supported(self, *mocks):
with self.assertRaises(LanguageNotSupportedError):
Expand Down
2 changes: 1 addition & 1 deletion src/app/tests/api/api_v1/test_tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

@mock.patch("src.app.services.sql_db.session_maker")
@mock.patch(
"src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)
"src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn'))
)
class TutorTests(IsolatedAsyncioTestCase):
def test_tutor_no_files(self, *mocks):
Expand Down
Loading