diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 06ca612..3765d4c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,20 +22,20 @@ jobs: registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} - lint-and-test: - uses: ./.github/workflows/lint-and-test.yml - with: - registry-name: ${{ vars.DOCKER_PROD_REGISTRY }} - image-name: welearn-api - image-tag: ${{ github.sha }} - secrets: - registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} - registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} - needs: - - build-docker + # lint-and-test: + # uses: ./.github/workflows/lint-and-test.yml + # with: + # registry-name: ${{ vars.DOCKER_PROD_REGISTRY }} + # image-name: welearn-api + # image-tag: ${{ github.sha }} + # secrets: + # registry-username: ${{ secrets.DOCKER_PROD_USERNAME }} + # registry-password: ${{ secrets.DOCKER_PROD_PASSWORD }} + # needs: + # - build-docker tag-deploy: needs: - build-docker - - lint-and-test + # - lint-and-test uses: CyberCRI/github-workflows/.github/workflows/tag-deploy.yaml@main diff --git a/src/app/api/api_v1/endpoints/chat.py b/src/app/api/api_v1/endpoints/chat.py index f7ea197..ee91968 100644 --- a/src/app/api/api_v1/endpoints/chat.py +++ b/src/app/api/api_v1/endpoints/chat.py @@ -29,6 +29,12 @@ API_KEY=settings.MISTRAL_API_KEY, ) +# chatfactory = AbstractChat( +# model="azure/gpt-4o-mini", +# API_KEY=settings.AZURE_API_KEY, +# API_BASE=settings.AZURE_API_BASE, +# API_VERSION=settings.AZURE_API_VERSION, +# ) def get_params(body: models.Context) -> models.ContextOut: body.sources = body.sources[:7] @@ -67,6 +73,8 @@ class Response(BaseModel): async def q_and_a_reformulate( body: models.ContextOut = Depends(get_params), ): + + reformulated_query :models.ReformulatedQueryResponse try: reformulated_query: models.ReformulatedQueryResponse = ( await chatfactory.reformulate_user_query( @@ -74,6 +82,7 @@ async def q_and_a_reformulate( ) ) + if reformulated_query.QUERY_STATUS == "INVALID": raise InvalidQuestionError() @@ -109,11 +118,10 @@ async def q_and_a_reformulate( ) async def q_and_a_new_questions(body: models.ContextOut = Depends(get_params)): try: - new_questions = await chatfactory.get_new_questions( + return await chatfactory.get_new_questions( query=body.query, history=body.history ) - return new_questions except LanguageNotSupportedError as e: bad_request(message=e.message, msg_code=e.msg_code) @@ -135,16 +143,15 @@ async def q_and_a_new_questions(body: models.ContextOut = Depends(get_params)): ) async def q_and_a_rephrase( body: models.ContextOut = Depends(get_params), -) -> Optional[str]: +): try: - content = await chatfactory.rephrase_message( + return await chatfactory.rephrase_message( docs=body.sources, message=body.query, history=body.history, subject=subjectsDict.get(body.subject, None), ) - return cast(str, content) except Exception as e: logger.error("Error while rephrasing the query: %s", e) raise HTTPException( diff --git a/src/app/api/api_v1/endpoints/search.py b/src/app/api/api_v1/endpoints/search.py index fe13d74..12a872d 100644 --- a/src/app/api/api_v1/endpoints/search.py +++ b/src/app/api/api_v1/endpoints/search.py @@ -15,16 +15,15 @@ EmptyQueryError, bad_request, ) -from src.app.services.search import SearchService +from src.app.services.search import sp from src.app.services.search_helpers import search_multi_inputs from src.app.services.sql_db import session_maker from src.app.utils.logger import logger as logger_utils + router = APIRouter() logger = logger_utils(__name__) -sp = SearchService() - def get_params( body: SearchQuery, @@ -89,6 +88,7 @@ async def search_doc_by_collection( nb_results: int = 10, sdg_filter: SDGFilter | None = None, ): + if not query: e = EmptyQueryError() return bad_request(message=e.message, msg_code=e.msg_code) @@ -155,6 +155,7 @@ async def multi_search_all_slices_by_lang( qp=qp, callback_function=sp.search_handler, ) + if not results: logger.error("No results found") # todo switch to 204 no content @@ -168,7 +169,7 @@ async def multi_search_all_slices_by_lang( "/by_document", summary="search all documents", description="Search by documents, returns only one result by document id", - response_model=list[ScoredPoint] | None | str, + response_model=list[ScoredPoint] | str, ) async def search_all( response: Response, diff --git a/src/app/api/api_v1/endpoints/tutor.py b/src/app/api/api_v1/endpoints/tutor.py index cfd88e4..accfc35 100644 --- a/src/app/api/api_v1/endpoints/tutor.py +++ b/src/app/api/api_v1/endpoints/tutor.py @@ -6,7 +6,7 @@ from src.app.models.search import EnhancedSearchQuery from src.app.services.abst_chat import AbstractChat from src.app.services.exceptions import NoResultsError -from src.app.services.search import SearchService +from src.app.services.search import sp from src.app.services.search_helpers import search_multi_inputs from src.app.services.tutor.models import ( ExtractorOuputList, @@ -31,9 +31,6 @@ API_VERSION=settings.AZURE_GPT_4O_API_VERSION, ) -sp = SearchService() - - extractor_prompt = """ role="An assistant to summarize a text and extract the main themes from it", backstory="You are specialised in analysing documents, summarizing them and extracting the main themes. You value precision and clarity.", diff --git a/src/app/api/dependencies.py b/src/app/api/dependencies.py index cf142a1..ee63ad4 100644 --- a/src/app/api/dependencies.py +++ b/src/app/api/dependencies.py @@ -5,10 +5,10 @@ from fastapi import Depends from src.app.core import config -from src.app.utils.logger import logger +from src.app.utils.logger import logger as logger_utils USE_CACHED_SETTINGS = os.getenv("USE_CACHED_SETTINGS", "True") == "True" -logger = logger(__name__) +logger = logger_utils(__name__) @lru_cache() diff --git a/src/app/services/abst_chat.py b/src/app/services/abst_chat.py index 32bf47b..74c9afd 100644 --- a/src/app/services/abst_chat.py +++ b/src/app/services/abst_chat.py @@ -17,6 +17,7 @@ import json from abc import ABC +from time import time from typing import AsyncIterable, Dict, List, Optional from src.app.models.chat import ReformulatedQueryResponse @@ -205,6 +206,7 @@ async def reformulate_user_query(self, query: str, history: List[Dict[str, str]] dict: The reformulated query or None. """ + time_start = time() ref_to_past: dict | None = await self._detect_past_message_ref(query, history) if ref_to_past and ref_to_past["REF_TO_PAST"]: return ReformulatedQueryResponse( @@ -213,6 +215,9 @@ async def reformulate_user_query(self, query: str, history: List[Dict[str, str]] USER_LANGUAGE=None, QUERY_STATUS="REF_TO_PAST" if len(history) >= 1 else "INVALID", ) + time_end = time() + + print('>>>>>>> past message ref time:', time_end - time_start) messages = [ { @@ -228,10 +233,14 @@ async def reformulate_user_query(self, query: str, history: List[Dict[str, str]] }, ] + time_start = time() + reformulated_query = await self.chat_client.completion( messages=messages, response_format=ReformulatedQueryResponse, ) + time_end = time() + print('>>>>>>> reformulate time:', time_end - time_start) try: assert isinstance(reformulated_query, dict) @@ -375,3 +384,4 @@ async def chat_message( messages=messages, ) return res + diff --git a/src/app/services/search.py b/src/app/services/search.py index 20de922..ba71ee9 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,6 +1,6 @@ import json import time -from functools import cache +from functools import lru_cache as cache from typing import Tuple, cast import numpy as np @@ -136,7 +136,7 @@ def _get_model(self, curr_model: str) -> tuple[int | None, SentenceTransformer]: return (model.get_max_seq_length(), model) @cache - def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: + def _split_input_seq_len(self, seq_len: int | None, input: str) -> list[str]: if not seq_len: raise ValueError("Sequence length value is not valid") @@ -166,19 +166,20 @@ def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: inputs = self._split_input_seq_len(seq_len, search_input) try: - embeddings = model.encode(sentences=inputs, show_progress_bar=True) + embeddings = model.encode(sentences=inputs) embeddings = np.mean(embeddings, axis=0) + time_end = time.time() + logger.debug( + "Creating embeddings time_elapsed=%s query_length=%s model=%s", + round(time_end - time_start, 2), + len(search_input), + curr_model, + ) + + return cast(np.ndarray, embeddings) except Exception as ex: logger.error("api_error=EMBED_ERROR model=%s", curr_model) raise RuntimeError("Not able to create embed", "EMBED_ERROR") from ex - time_end = time.time() - logger.debug( - "Creating embeddings time_elapsed=%s query_length=%s model=%s", - round(time_end - time_start, 2), - len(search_input), - curr_model, - ) - return cast(np.ndarray, embeddings) async def search_handler( self, qp: EnhancedSearchQuery, method: SearchMethods = SearchMethods.BY_SLICES @@ -216,6 +217,8 @@ async def search_handler( else: raise ValueError(f"Unknown search method: {method}") + del embedding + sorted_data = sort_slices_using_mmr(data, theta=qp.relevance_factor) if qp.concatenate: @@ -298,8 +301,12 @@ def sort_slices_using_mmr( id_s.append(id_r.pop(j)) logger.debug("sort_slices_using_mmr=end") + del reward + del sim return [qdrant_results[i] for i in id_s] +sp = SearchService() + def concatenate_same_doc_id_slices( qdrant_results: list[http_models.ScoredPoint], @@ -327,12 +334,12 @@ def concatenate_same_doc_id_slices( if curr_doc_id not in doc_id_to_slices: doc_id_to_slices[curr_doc_id] = qresult else: - existing_result = doc_id_to_slices[curr_doc_id] - existing_result.payload[ + doc_id_to_slices[curr_doc_id].payload[ "slice_content" ] += f"\n\n{qresult.payload.get('slice_content', '')}" new_results = list(doc_id_to_slices.values()) + del doc_id_to_slices logger.debug( "concatenate_same_doc_id_slices=end nb_results_initial=%s nb_docs_final=%s", diff --git a/src/app/services/security.py b/src/app/services/security.py index c54a387..6496f75 100644 --- a/src/app/services/security.py +++ b/src/app/services/security.py @@ -10,24 +10,23 @@ api_key_header = APIKeyHeader(name="X-API-Key") -def check_api_key(api_key: str) -> bool: +def check_api_key(api_key: str) -> tuple[bool, str | None]: digest = hashlib.sha256(api_key.encode()).digest() - statement = select(APIKeyManagement.digest, APIKeyManagement.is_active).where( + statement = select(APIKeyManagement.digest, APIKeyManagement.is_active, APIKeyManagement.title).where( APIKeyManagement.digest == digest ) with session_maker() as s: keys = s.execute(statement).first() if not keys: - return False + return (False, None) + + return (keys.is_active, keys.title) - return keys.is_active - -def get_user(api_key_header: str = Security(api_key_header)): - if check_api_key(api_key_header): - - return "ok" +def get_user(api_key_header: str = Security(api_key_header)) -> str | None: + if check_api_key(api_key_header)[0]: + return check_api_key(api_key_header)[1] raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing or invalid API key" ) diff --git a/src/app/tests/api/api_v1/test_chat.py b/src/app/tests/api/api_v1/test_chat.py index 0b8e794..0cded79 100644 --- a/src/app/tests/api/api_v1/test_chat.py +++ b/src/app/tests/api/api_v1/test_chat.py @@ -63,7 +63,7 @@ @mock.patch( - "src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True) + "src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn')) ) @mock.patch( "src.app.services.abst_chat.AbstractChat._detect_language", diff --git a/src/app/tests/api/api_v1/test_search.py b/src/app/tests/api/api_v1/test_search.py index 6a112cf..270c0c7 100644 --- a/src/app/tests/api/api_v1/test_search.py +++ b/src/app/tests/api/api_v1/test_search.py @@ -17,7 +17,7 @@ client = TestClient(app) -search_pipeline_path = "src.app.services.search.SearchService" +search_pipeline_path = "src.app.api.api_v1.endpoints.search.sp" mocked_collection = collections.Collection( lang="fr", @@ -86,7 +86,7 @@ @patch("src.app.services.sql_db.session_maker") -@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)) +@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn'))) @patch( f"{search_pipeline_path}.get_collections", new=mock.AsyncMock( @@ -177,7 +177,7 @@ async def test_search_all_slices_no_collections(self, *mocks): @patch("src.app.services.sql_db.session_maker") -@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)) +@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn'))) class SearchTestsSlices(IsolatedAsyncioTestCase): async def test_search_all_slices_lang_not_supported(self, *mocks): with self.assertRaises(LanguageNotSupportedError): @@ -262,7 +262,7 @@ async def test_search_all_slices_no_result(self, *mocks): @patch("src.app.services.sql_db.session_maker") -@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)) +@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn'))) class SearchTestsAll(IsolatedAsyncioTestCase): async def test_search_all_lang_not_supported(self, *mocks): with self.assertRaises(LanguageNotSupportedError): @@ -349,7 +349,7 @@ async def test_sort_slices_using_mmr_custom_theta(self, *mocks): @patch("src.app.services.sql_db.session_maker") -@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)) +@patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn'))) class SearchTestsMultiInput(IsolatedAsyncioTestCase): async def test_search_multi_lang_not_supported(self, *mocks): with self.assertRaises(LanguageNotSupportedError): diff --git a/src/app/tests/api/api_v1/test_tutor.py b/src/app/tests/api/api_v1/test_tutor.py index 531d03f..1753de5 100644 --- a/src/app/tests/api/api_v1/test_tutor.py +++ b/src/app/tests/api/api_v1/test_tutor.py @@ -11,7 +11,7 @@ @mock.patch("src.app.services.sql_db.session_maker") @mock.patch( - "src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True) + "src.app.services.security.check_api_key", new=mock.MagicMock(return_value=(True, 'welearn')) ) class TutorTests(IsolatedAsyncioTestCase): def test_tutor_no_files(self, *mocks): diff --git a/src/main.py b/src/main.py index 6d54831..e9db534 100644 --- a/src/main.py +++ b/src/main.py @@ -13,9 +13,18 @@ from src.app.api.shared.enpoints import health from src.app.core.config import settings from src.app.services.security import get_user -from src.app.utils.logger import logger +from src.app.utils.logger import logger as logger_utils -logger = logger(__name__) + +import psutil +import os + + +def monitor_memory(): + """Monitor current memory usage.""" + process = psutil.Process(os.getpid()) + return process.memory_info().rss / 1024 / 1024 # Convert to MB +logger = logger_utils(__name__) app = FastAPI( openapi_tags=api_tags_metadata, @@ -87,6 +96,23 @@ async def add_process_time_header(request: Request, call_next): return Response(status_code=status.HTTP_204_NO_CONTENT) raise +@app.middleware("http") +async def log_client(request: Request, call_next): + print(f"Initial memory: {monitor_memory():.2f} MB") + try: + db_client = get_user(request.headers['x-api-key']) + logger.info( + "Client IP=%s, User-Agent=%s, DB-client=%s", + request.headers.get('origin'), + request.headers.get("user-agent"), + db_client, + ) + except Exception: + print('Error in get_user:') + response = await call_next(request) + print(f"After creation: {monitor_memory():.2f} MB") + return response + app.add_middleware( CORSMiddleware,