Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ uvicorn = {extras = ["standard"], version = "^0.29.0"}
autogen-core = "^0.4.9.3"
autogen-ext = "^0.4.9.3"
autogen-agentchat = "^0.4.9.2"
pypdf = "^5.4.0"

[tool.poetry.group.dev.dependencies]
pytest = "^7.4.0"
Expand Down
38 changes: 32 additions & 6 deletions src/app/api/api_v1/endpoints/tutor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated

from fastapi import APIRouter, File, Response, UploadFile
from pypdf import PdfReader

from src.app.api.dependencies import get_settings
from src.app.services.abst_chat import AbstractChat, ChatFactory
Expand All @@ -20,7 +21,7 @@

settings = get_settings()

chatfactory: AbstractChat = ChatFactory().create_chat("openai")
chatfactory: AbstractChat = ChatFactory().create_chat("openai", model="gpt-4o")
chatfactory.init_client()

sp = SearchService()
Expand All @@ -39,26 +40,51 @@ async def tutor_search(
files: Annotated[list[UploadFile], File()],
response: Response,
):
file_content: list[bytes] = [await file.read() for file in files]
files_content: list[bytes] = []
for file in files:
if (
file.content_type == "application/pdf"
or file.content_type == "application/x-pdf"
):
file_content = ""
reader = PdfReader(file.file)
for page in reader.pages:
file_content += page.extract_text()
files_content.append(file_content.encode("utf-8", errors="ignore"))
else:
file_content = await file.read()
files_content.append(file_content)

doc_list_to_string = "Document {doc_nb}: {content}"

file_content_str = [
doc_list_to_string.format(
doc_nb=index + 1,
content=content.decode("utf-8", errors="ignore"),
)
for index, content in enumerate(file_content)
for index, content in enumerate(files_content)
]
file_content_str = "\n\n".join(file_content_str)

print(file_content_str)

messages = [
{"role": "system", "content": extractor_prompt},
{"role": "assistant", "content": file_content_str},
]

themes_extracted = await chatfactory.chat_schema(
model="gpt-4o-mini", messages=messages, response_format=ExtractorOuputList # type: ignore
)
try:
themes_extracted = await chatfactory.chat_schema(
model="gpt-4o", messages=messages, response_format=ExtractorOuputList # type: ignore
)
except Exception as e:
logger.error(f"Error in chat schema: {e}")
# todo: handle error
return TutorSearchResponse(
extracts=[],
nb_results=0,
documents=[],
)

if not themes_extracted or not themes_extracted.extracts:
return TutorSearchResponse(
Expand Down
26 changes: 23 additions & 3 deletions src/app/services/abst_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from abc import ABC, abstractmethod
from typing import AsyncIterable, Dict, List, Literal, Optional

import openai
from azure.ai.inference import ChatCompletionsClient
from azure.core.credentials import AzureKeyCredential

# from ecologits import EcoLogits # type: ignore
from mistralai import Mistral
from openai import AsyncAzureOpenAI
from pydantic import BaseModel

from src.app.api.dependencies import get_settings
Expand Down Expand Up @@ -456,7 +456,7 @@ def init_client(self):
raise ValueError("API_BASE or API_VERSION not provided")

try:
self.chat_client = openai.AsyncAzureOpenAI(
self.chat_client = AsyncAzureOpenAI(
api_key=self.API_KEY,
azure_endpoint=self.API_BASE,
api_version=self.API_VERSION,
Expand Down Expand Up @@ -507,7 +507,7 @@ async def chat_schema(
):
try:
completion = await self.chat_client.beta.chat.completions.parse(
model=model,
model=self.model or model,
messages=messages,
temperature=0.2,
response_format=response_format,
Expand Down Expand Up @@ -653,6 +653,26 @@ def create_chat(
raise ValueError(f"Unsupported chat type: {chat_type}")

if chat_type == "openai":
openai_models = {
"gpt-4o-mini": (
settings.AZURE_API_KEY,
settings.AZURE_API_BASE,
settings.AZURE_API_VERSION,
),
"gpt-4o": (
settings.AZURE_GPT_4O_API_KEY,
settings.AZURE_GPT_4O_API_BASE,
"2025-01-01-preview",
),
}
if model:
key, base, version = openai_models.get(model, (None, None, None))
if not key or not base or not version:
raise ValueError(f"Unsupported model: {model}")
return chat_classes[chat_type](
API_KEY=key, API_BASE=base, API_VERSION=version, model=model
)

return chat_classes[chat_type](
API_VERSION=settings.AZURE_API_VERSION,
API_KEY=settings.AZURE_API_KEY,
Expand Down
11 changes: 9 additions & 2 deletions src/app/services/tutor/tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
sdg_expert_topic_type,
university_teacher_topic_type,
)
from src.app.services.tutor.models import Message, TaskResponse, TutorSearchResponse, MessageWithResources
from src.app.services.tutor.models import (
Message,
MessageWithResources,
TaskResponse,
TutorSearchResponse,
)
from src.app.services.tutor.utils import extract_doc_info

settings = get_settings()
Expand All @@ -40,7 +45,9 @@
async def tutor_manager(content: TutorSearchResponse) -> Message:
queue = asyncio.Queue[TaskResponse]()

formatted_content = MessageWithResources(content=content.extracts, resources=extract_doc_info(content.documents))
formatted_content = MessageWithResources(
content=content.extracts, resources=extract_doc_info(content.documents)
)

async def collect_result(
_agent: ClosureContext, message: TaskResponse, ctx: MessageContext
Expand Down