Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 7 additions & 50 deletions poetry.lock

Large diffs are not rendered by default.

122 changes: 68 additions & 54 deletions src/app/api/api_v1/endpoints/search.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,22 @@
from typing import List, Optional, Union

from fastapi import APIRouter, Depends, Response
from qdrant_client.models import ScoredPoint
from sqlalchemy.sql import select

from src.app.models.db_models import CorpusEmbedding
from src.app.models.documents import Collection_schema, Document
from src.app.models.search import EnhancedSearchQuery, SearchFilter, SearchQuery
from src.app.services.exceptions import EmptyQueryError, bad_request
from src.app.services.search import SearchService
from src.app.services.search_helpers import (
search_all_base,
search_items_base,
search_multi_inputs,
from src.app.models.documents import Collection_schema
from src.app.models.search import (
EnhancedSearchQuery,
SDGFilter,
SearchMethods,
SearchQuery,
)
from src.app.services.exceptions import (
CollectionNotFoundError,
EmptyQueryError,
bad_request,
)
from src.app.services.search import SearchService
from src.app.services.search_helpers import search_multi_inputs
from src.app.services.sql_db import session_maker
from src.app.utils.logger import logger as logger_utils

Expand All @@ -25,7 +29,7 @@
def get_params(
body: SearchQuery,
nb_results: int = 30,
subject: Optional[str] = None,
subject: str | None = None,
influence_factor: float = 2,
relevance_factor: float = 1,
concatenate: bool = True,
Expand All @@ -52,7 +56,7 @@ def get_params(
"/collections",
summary="get all collections",
description="Get all collections available in the database",
response_model=List[Collection_schema],
response_model=list[Collection_schema],
)
async def get_corpus():
statement = select(
Expand All @@ -73,59 +77,72 @@ async def get_corpus():


@router.post(
"/collections/{collection_query}",
summary="search items in a specific collection",
description="Search items in a specific collection",
response_model=Union[List[Document], None],
"/collections/{collection}",
summary="search documents in a specific collection",
description="Search documents in a specific collection",
response_model=list[ScoredPoint] | str | None,
)
async def search_items(
query: Optional[str] = None,
collection_query: str = "conversation",
async def search_doc_by_collection(
response: Response,
query: str,
collection: str = "conversation",
Comment thread
sandragjacinto marked this conversation as resolved.
nb_results: int = 10,
sdg_filter: Optional[SearchFilter] = None,
sdg_filter: SDGFilter | None = None,
):
if not query:
e = EmptyQueryError()
return bad_request(message=e.message, msg_code=e.msg_code)

return await search_items_base(
qp = EnhancedSearchQuery(
query=query,
collection_query=collection_query,
nb_results=nb_results,
sdg_filter=sdg_filter,
search_func=sp.search_group_by_document,
corpora=(collection,),
sdg_filter=sdg_filter.sdg_filter if sdg_filter else None,
)

try:
res = await sp.search_handler(qp=qp, method=SearchMethods.BY_DOCUMENT)

if not res:
response.status_code = 206
return []

return res
except CollectionNotFoundError as e:
response.status_code = 404
return e.message


@router.post(
"/by_slices",
summary="search all slices",
description="Search slices in all collections or in collections specified",
response_model=Union[List[Document], None],
response_model=list[ScoredPoint] | None | str,
)
async def search_all_slices_by_lang(
response: Response,
qp: EnhancedSearchQuery = Depends(get_params),
):
res = await search_all_base(
response=response,
qp=qp,
search_func=sp.search,
)
try:

if not res:
logger.error("No results found")
response.status_code = 404
return None
res = await sp.search_handler(qp=qp, method=SearchMethods.BY_SLICES)

return res
if not res:
logger.debug("No results found")
response.status_code = 404
return []

return res
except CollectionNotFoundError as e:
response.status_code = 404
return e.message


@router.post(
"/multiple_by_slices",
summary="search all slices",
description="Search slices in all collections or in collections specified",
response_model=Union[List[Document], None],
response_model=list[ScoredPoint] | None,
)
async def multi_search_all_slices_by_lang(
response: Response,
Expand All @@ -135,40 +152,37 @@ async def multi_search_all_slices_by_lang(
qp.query = [qp.query]

results = await search_multi_inputs(
response=response,
nb_results=qp.nb_results,
sdg_filter=qp.sdg_filter,
collections=qp.corpora,
inputs=qp.query,
callback_function=sp.search,
qp=qp,
callback_function=sp.search_handler,
)
if not results:
logger.error("No results found")
# todo switch to 204 no content
response.status_code = 404
return None
return []

return results


@router.post(
"/by_document",
summary="search all documents",
description="Search documents in all collections or in collections specified",
response_model=Union[List[Document], None],
description="Search by documents, returns only one result by document id",
response_model=list[ScoredPoint] | None | str,
)
async def search_all(
response: Response,
qp: EnhancedSearchQuery = Depends(get_params),
):
res = await search_all_base(
response=response,
qp=qp,
search_func=sp.search_group_by_document,
)

if not res:
logger.error("No results found")
try:
res = await sp.search_handler(qp=qp, method=SearchMethods.BY_DOCUMENT)

if not res:
logger.error("No results found")
response.status_code = 404
return []
except CollectionNotFoundError as e:
response.status_code = 404
return None
return e.message

return res
17 changes: 10 additions & 7 deletions src/app/api/api_v1/endpoints/tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from fastapi import APIRouter, File, HTTPException, Response, UploadFile

from src.app.api.dependencies import get_settings
from src.app.models.search import EnhancedSearchQuery
from src.app.services.abst_chat import AbstractChat, ChatFactory
from src.app.services.exceptions import NoResultsError
from src.app.services.search import SearchService
Expand Down Expand Up @@ -90,13 +91,17 @@ async def tutor_search(
inputs = [doc.summary for doc in themes_extracted.extracts] # type: ignore

try:
search_results = await search_multi_inputs(
response=response,
inputs=inputs,
qp = EnhancedSearchQuery(
query=inputs,
nb_results=5,
sdg_filter=None,
collections=None,
callback_function=sp.search,
corpora=None,
)

search_results = await search_multi_inputs(
response=response,
qp=qp,
callback_function=sp.search_handler,
)
except NoResultsError as e:
response.status_code = 404
Expand All @@ -120,8 +125,6 @@ async def tutor_search(
documents=search_results,
)

# TODO: handle duplicates

return resp


Expand Down
3 changes: 1 addition & 2 deletions src/app/models/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ class Collection_schema(BaseModel):


class Collection(NamedTuple):
name: str
lang: str
model: str
alias: str
name: str
48 changes: 45 additions & 3 deletions src/app/models/search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from enum import StrEnum

from pydantic import BaseModel, Field
from qdrant_client.models import FieldCondition, Filter, MatchAny

from src.app.utils.logger import logger as logger_utils

logger = logger_utils(__name__)

class SearchFilter(BaseModel):

class SDGFilter(BaseModel):
sdg_filter: list[int] | None = Field(
None,
max_length=17,
Expand All @@ -11,16 +18,51 @@ class SearchFilter(BaseModel):
)


class SearchQuery(SearchFilter):
class SearchQuery(SDGFilter):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why it inherit from SDGFilter ?

query: str | list[str] | None
corpora: list[str] | None = None


class EnhancedSearchQuery(SearchFilter):
class EnhancedSearchQuery(SDGFilter):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why it inherit from SDGFilter ?

query: str | list[str]
corpora: tuple[str, ...] | None = None
nb_results: int = 30
subject: str | None = None
influence_factor: float = 2
relevance_factor: float = 1
concatenate: bool = True


class SearchFilters(BaseModel):
Comment thread
sandragjacinto marked this conversation as resolved.
slice_sdg: list[int] | None
document_corpus: tuple[str, ...] | list[str] | None

def build_filters(self) -> Filter | None:
if not self.slice_sdg and not self.document_corpus:
return None

filters = {
"slice_sdg": self.slice_sdg,
"document_corpus": self.document_corpus,
}

qdrant_filter = []
for key, values in filters.items():
if not values:
continue

qdrant_filter.append(
FieldCondition(
key=key,
match=MatchAny(any=values),
)
)

logger.debug("build_filters=%s", qdrant_filter)

return Filter(must=qdrant_filter)


class SearchMethods(StrEnum):
BY_SLICES = "by_slices"
BY_DOCUMENT = "by_document"
4 changes: 1 addition & 3 deletions src/app/services/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from fastapi import HTTPException, Response, status

from src.app.utils.logger import logger as logger_utils
Expand Down Expand Up @@ -152,7 +150,7 @@ def __init__(
super().__init__(self.message, self.msg_code)


def handle_error(response: Optional[Response], exc: Exception) -> None:
def handle_error(exc: Exception, response: Response | None = None) -> None:
if isinstance(exc, PartialResponseResultError):
if response:
response.status_code = status.HTTP_206_PARTIAL_CONTENT
Expand Down
Loading