diff --git a/backend/api/auth/router.py b/backend/api/auth/router.py index 18555e99..828cfe43 100644 --- a/backend/api/auth/router.py +++ b/backend/api/auth/router.py @@ -39,6 +39,7 @@ client_kwargs={"scope": "openid profile email"}, ) + @router.post("/token", response_model=Token) async def login_for_access_token( form_data: OAuth2PasswordRequestForm = Depends(), @@ -83,13 +84,17 @@ async def login(login_data: LoginRequest) -> Any: return {"access_token": access_token, "token_type": "Bearer"} + @router.get("/microsoft-sso") async def login_microsoft_sso(request: StarletteRequest, lang: str): """ Start Microsoft OAuth2 flow by redirecting the user to Microsoft login. """ request.session["lang"] = lang - return await oauth.microsoft.authorize_redirect(request, f"{settings.API_URL}/api/auth/sso-authorize") + return await oauth.microsoft.authorize_redirect( + request, f"{settings.API_URL}/api/auth/sso-authorize" + ) + @router.get("/sso-authorize") async def microsoft_authorize(request: StarletteRequest): @@ -122,6 +127,7 @@ async def microsoft_authorize(request: StarletteRequest): detail="Error during Microsoft OAuth callback", ) + @router.post("/register", response_model=UserRead) async def register_user(register_data: RegisterRequest) -> Any: """ diff --git a/backend/api/citations/router.py b/backend/api/citations/router.py index 0a1c547d..cfba3ff3 100644 --- a/backend/api/citations/router.py +++ b/backend/api/citations/router.py @@ -36,7 +36,12 @@ from ..core.security import get_current_active_user from ..core.config import settings -from ..services.cit_db_service import cits_dp_service, snake_case, snake_case_column, parse_dsn +from ..services.cit_db_service import ( + cits_dp_service, + snake_case, + snake_case_column, + parse_dsn, +) from ..core.cit_utils import load_sr_and_check router = APIRouter() @@ -120,7 +125,9 @@ def _join_list(v: Any, sep: str = "; ") -> str: if v is None: return "" if isinstance(v, list): - return sep.join([str(x) for x in v if x is not None and str(x).strip() != ""]).strip() + return sep.join( + [str(x) for x in v if x is not None and str(x).strip() != ""] + ).strip() return str(v) @@ -139,9 +146,15 @@ def _ris_value_for_include(include_col: str, entry: Dict[str, Any]) -> Any: # Common include names in our configs if key == "title": - return entry.get("title") or entry.get("primary_title") or entry.get("short_title") + return ( + entry.get("title") or entry.get("primary_title") or entry.get("short_title") + ) if key == "abstract": - return entry.get("abstract") or entry.get("notes_abstract") or _join_list(entry.get("notes"), "\n") + return ( + entry.get("abstract") + or entry.get("notes_abstract") + or _join_list(entry.get("notes"), "\n") + ) if key == "keywords": return _join_list(entry.get("keywords")) if key == "journal": @@ -173,7 +186,9 @@ def _ris_value_for_include(include_col: str, entry: Dict[str, Any]) -> Any: return v -def _parse_citations_csv_bytes(raw_bytes: bytes) -> tuple[list[dict[str, Any]], list[str]]: +def _parse_citations_csv_bytes( + raw_bytes: bytes, +) -> tuple[list[dict[str, Any]], list[str]]: try: text = raw_bytes.decode("utf-8-sig") except Exception: @@ -184,7 +199,9 @@ def _parse_citations_csv_bytes(raw_bytes: bytes) -> tuple[list[dict[str, Any]], return rows, cols -def _parse_citations_ris_bytes(raw_bytes: bytes, include_columns: list[str]) -> tuple[list[dict[str, Any]], list[str]]: +def _parse_citations_ris_bytes( + raw_bytes: bytes, include_columns: list[str] +) -> tuple[list[dict[str, Any]], list[str]]: if not rispy: raise RuntimeError("RIS upload requested but rispy is not installed") @@ -231,15 +248,23 @@ def _ris_value_for_canonical(col: str, entry: Dict[str, Any]) -> Any: key = (col or "").strip().lower() if key == "title": - return entry.get("title") or entry.get("primary_title") or entry.get("short_title") + return ( + entry.get("title") or entry.get("primary_title") or entry.get("short_title") + ) if key == "abstract": - return entry.get("abstract") or entry.get("notes_abstract") or _join_list(entry.get("notes"), "\n") + return ( + entry.get("abstract") + or entry.get("notes_abstract") + or _join_list(entry.get("notes"), "\n") + ) if key == "keywords": return _join_list(entry.get("keywords")) if key == "journal": return entry.get("secondary_title") or entry.get("journal_name") if key == "year": - return _extract_year(entry.get("year") or entry.get("publication_year") or entry.get("date")) + return _extract_year( + entry.get("year") or entry.get("publication_year") or entry.get("date") + ) if key == "authors": return _join_list(entry.get("authors")) if key == "doi": @@ -254,7 +279,9 @@ def _ris_value_for_canonical(col: str, entry: Dict[str, Any]) -> Any: return _join_list(v) if isinstance(v, list) else v -def _parse_citations_ris_bytes_auto(raw_bytes: bytes) -> tuple[list[dict[str, Any]], list[str]]: +def _parse_citations_ris_bytes_auto( + raw_bytes: bytes, +) -> tuple[list[dict[str, Any]], list[str]]: """Parse RIS into a canonical table shape. This is used when RIS is uploaded before SR criteria/config exists. @@ -285,19 +312,28 @@ def _parse_citations_ris_bytes_auto(raw_bytes: bytes) -> tuple[list[dict[str, An return rows, cols -def _load_include_columns_from_criteria(sr_doc: Optional[Dict[str, Any]] = None) -> List[str]: + +def _load_include_columns_from_criteria( + sr_doc: Optional[Dict[str, Any]] = None, +) -> List[str]: # Delegate to consolidated postgres service try: return cits_dp_service.load_include_columns_from_criteria(sr_doc) except Exception: return [] + + def _parse_dsn(dsn: str) -> Dict[str, str]: # Delegate to consolidated postgres service try: return parse_dsn(dsn) except Exception: return {} -def _create_table_and_insert_sync(table_name: str, columns: List[str], rows: List[Dict[str, Any]]) -> int: + + +def _create_table_and_insert_sync( + table_name: str, columns: List[str], rows: List[Dict[str, Any]] +) -> int: return cits_dp_service.create_table_and_insert_sync(table_name, columns, rows) @@ -311,7 +347,9 @@ async def _upload_screening_citations_impl( """Shared implementation for citations upload (CSV or RIS).""" try: - sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check( + sr_id, current_user, srdb_service, require_screening=False + ) except HTTPException: raise except Exception as e: @@ -328,19 +366,29 @@ async def _upload_screening_citations_impl( ) if not file or not file.filename: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="File is required") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="File is required" + ) # Read bytes once (UploadFile stream is not reliably seekable after read) try: raw = await file.read() except Exception as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to read upload: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Failed to read upload: {e}", + ) if not raw: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Uploaded file is empty") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Uploaded file is empty" + ) fmt = (force_format or _sniff_citations_format(file.filename, raw)).lower() if fmt not in ("csv", "ris"): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported citations format: {fmt}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unsupported citations format: {fmt}", + ) # Parse into normalized rows + include columns try: @@ -352,10 +400,16 @@ async def _upload_screening_citations_impl( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to parse {fmt.upper()}: {e}") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Failed to parse {fmt.upper()}: {e}", + ) if len(include_columns) == 0: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No columns found to create screening table") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No columns found to create screening table", + ) # Build a unique table name for this upload safe_sr = re.sub(r"[^0-9a-zA-Z_]", "_", sr_id) @@ -372,11 +426,18 @@ async def _upload_screening_citations_impl( # Create table and insert rows in threadpool try: - inserted = await run_in_threadpool(_create_table_and_insert_sync, table_name, include_columns, normalized_rows) + inserted = await run_in_threadpool( + _create_table_and_insert_sync, table_name, include_columns, normalized_rows + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create table or insert rows: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create table or insert rows: {e}", + ) # Save DB connection metadata into SR Mongo doc try: @@ -436,7 +497,9 @@ async def upload_screening_citations( @router.get("/{sr_id}/citations") async def list_citation_ids( - sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user), filter_step: Optional[str] = None, + sr_id: str, + current_user: Dict[str, Any] = Depends(get_current_active_user), + filter_step: Optional[str] = None, ): """ List all citation ids for the systematic review's screening database. @@ -448,7 +511,10 @@ async def list_citation_ids( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) if not screening: return {"citation_ids": []} @@ -459,22 +525,31 @@ async def list_citation_ids( # Validation strategy: UI filters by human_l1_decision / human_l2_decision. try: cp = (sr or {}).get("criteria_parsed") or (sr or {}).get("criteria") or {} - await run_in_threadpool(cits_dp_service.backfill_human_decisions, cp, table_name) + await run_in_threadpool( + cits_dp_service.backfill_human_decisions, cp, table_name + ) except Exception: # best-effort; listing should still work even if backfill fails pass try: - ids = await run_in_threadpool(cits_dp_service.list_citation_ids, filter_step, table_name) + ids = await run_in_threadpool( + cits_dp_service.list_citation_ids, filter_step, table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: # If the SR points at a screening table that no longer exists (e.g. dropped), # treat it as "no citations" instead of poisoning the shared connection. if _is_undefined_table_error(e): return {"citation_ids": []} except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) return {"citation_ids": ids} @@ -515,7 +590,9 @@ async def get_citations_batch( # Parse ids raw_ids = [p.strip() for p in (ids or "").split(",") if p.strip()] if not raw_ids: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="ids is required") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="ids is required" + ) parsed_ids: List[int] = [] for p in raw_ids: try: @@ -523,7 +600,10 @@ async def get_citations_batch( except Exception: continue if not parsed_ids: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="ids must be a comma-separated list of integers") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="ids must be a comma-separated list of integers", + ) parsed_fields: Optional[List[str]] = None if fields is not None: @@ -538,9 +618,14 @@ async def get_citations_batch( parsed_fields, ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) return {"citations": rows} @@ -566,19 +651,31 @@ async def get_citation_by_id( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) table_name = (screening or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) return row @@ -588,7 +685,9 @@ class CombinedRequest(BaseModel): include_columns: Optional[List[str]] = None -def _build_combined_citation_from_row(row: Dict[str, Any], include_columns: List[str]) -> str: +def _build_combined_citation_from_row( + row: Dict[str, Any], include_columns: List[str] +) -> str: # Delegate to consolidated postgres service return cits_dp_service.build_combined_citation_from_row(row, include_columns) @@ -618,22 +717,37 @@ async def build_combined_citation( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) if not screening: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No screening database configured for this systematic review", + ) table_name = (screening or {}).get("table_name") or "citations" try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) # As a final fallback, use all columns present in include columns data = [] @@ -672,17 +786,25 @@ async def upload_citation_fulltext( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) # Validate file if not file or not file.filename: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="File is required") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="File is required" + ) _, ext = os.path.splitext(file.filename) ext = ext.lower() # Prefer PDFs but allow other types if necessary; restrict to .pdf here if ext != ".pdf": - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only PDF files are accepted for full text upload") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Only PDF files are accepted for full text upload", + ) content = await file.read() new_md5 = hashlib.md5(content).hexdigest() if content is not None else "" @@ -690,14 +812,23 @@ async def upload_citation_fulltext( table_name = (screening or {}).get("table_name") or "citations" try: - existing_row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + existing_row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not existing_row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) # If the PDF changed (md5 differs), clear L2 screening answers + parameter extractions + fulltext artifacts. # (Do NOT clear L1 answers.) @@ -740,9 +871,15 @@ async def upload_citation_fulltext( # NOTE (validation): we do not use l2_screen for filtering; keep it untouched/non-authoritative. cols_to_clear = ["llm_l2_decision", "human_l2_decision"] try: - cp = (sr or {}).get("criteria_parsed") or (sr or {}).get("criteria") or {} + cp = ( + (sr or {}).get("criteria_parsed") + or (sr or {}).get("criteria") + or {} + ) l2 = cp.get("l2") if isinstance(cp, dict) else None - l2_questions = (l2 or {}).get("questions") if isinstance(l2, dict) else None + l2_questions = ( + (l2 or {}).get("questions") if isinstance(l2, dict) else None + ) if isinstance(l2_questions, list): for q in l2_questions: try: @@ -763,7 +900,9 @@ async def upload_citation_fulltext( # best-effort pass - await run_in_threadpool(cits_dp_service.clear_columns, citation_id, cols_to_clear, table_name) + await run_in_threadpool( + cits_dp_service.clear_columns, citation_id, cols_to_clear, table_name + ) except Exception: # Best-effort; do not block upload. pass @@ -775,7 +914,10 @@ async def upload_citation_fulltext( storage_service = None if not storage_service: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Storage service not available") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Storage service not available", + ) # Upload file for the current user document_id = await storage_service.upload_user_document( @@ -785,7 +927,10 @@ async def upload_citation_fulltext( ) if not document_id: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to upload file to storage service") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to upload file to storage service", + ) # Build storage path (container + blob name) so it can be stored in Postgres # Note: storage.py stores blobs at users/{user_id}/documents/{doc_id}_{filename} blob_name = f"users/{current_user['id']}/documents/{document_id}_{file.filename}" @@ -794,19 +939,35 @@ async def upload_citation_fulltext( # Update citation row in Postgres try: - updated = await run_in_threadpool(cits_dp_service.attach_fulltext, citation_id, storage_path, content, table_name) + updated = await run_in_threadpool( + cits_dp_service.attach_fulltext, + citation_id, + storage_path, + content, + table_name, + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update citation row: {e}", + ) if not updated: # If the citation id doesn't exist, consider rolling back the uploaded file (best effort) # Attempt to delete the uploaded blob (best-effort; not fatal if it fails) try: - await storage_service.delete_user_document(current_user["id"], document_id, file.filename) + await storage_service.delete_user_document( + current_user["id"], document_id, file.filename + ) except Exception: pass - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to attach fulltext file") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Citation not found to attach fulltext file", + ) return { "status": "success", "sr_id": sr_id, @@ -820,7 +981,9 @@ async def upload_citation_fulltext( # Helper to drop a database - delegated to backend.api.core.postgres.drop_database -async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, Any]) -> Dict[str, Any]: +async def hard_delete_screening_resources( + sr_id: str, current_user: Dict[str, Any] +) -> Dict[str, Any]: """ Delete the screening Postgres table and all associated fulltext files for the given systematic review. @@ -837,26 +1000,47 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) requester_id = current_user.get("id") if requester_id != sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may perform screening cleanup for this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the owner may perform screening cleanup for this systematic review", + ) if not screening: - return {"status": "no_screening_db", "message": "No screening table configured for this SR", "deleted_table": False, "deleted_files": 0} + return { + "status": "no_screening_db", + "message": "No screening table configured for this SR", + "deleted_table": False, + "deleted_files": 0, + } table_name = screening.get("table_name") if not table_name: - return {"status": "no_screening_db", "message": "Incomplete screening DB metadata", "deleted_table": False, "deleted_files": 0} + return { + "status": "no_screening_db", + "message": "Incomplete screening DB metadata", + "deleted_table": False, + "deleted_files": 0, + } # 1) collect fulltext URLs from the screening DB try: urls = await run_in_threadpool(cits_dp_service.list_fulltext_urls, table_name) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB for fulltext URLs: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB for fulltext URLs: {e}", + ) # 2) delete blobs for each url (best-effort) deleted_files = 0 @@ -888,13 +1072,17 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An # expect ["users", user_id, "documents", "{doc_id}_{filename}"] if len(parts) >= 4 and parts[2] == "documents": user_id = parts[1] - doc_part = "/".join(parts[3:]) # handle any extra slashes in filename + doc_part = "/".join( + parts[3:] + ) # handle any extra slashes in filename # split first underscore to get doc_id and filename if "_" in doc_part: doc_id, filename = doc_part.split("_", 1) # call delete_user_document (async) try: - ok = await storage_service.delete_user_document(user_id, doc_id, filename) + ok = await storage_service.delete_user_document( + user_id, doc_id, filename + ) if ok: deleted_files += 1 else: @@ -927,16 +1115,18 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An await run_in_threadpool(cits_dp_service.drop_table, table_name) table_dropped = True except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to drop screening table: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to drop screening table: {e}", + ) # 4) remove screening_db metadata from SR document try: - await run_in_threadpool( - srdb_service.clear_screening_db_info, - sr_id - ) + await run_in_threadpool(srdb_service.clear_screening_db_info, sr_id) except Exception: # non-fatal, but report it pass @@ -952,7 +1142,9 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An # Optional endpoint to trigger the cleanup directly @router.post("/{sr_id}/hard-clean") -async def hard_clean_screening_endpoint(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)): +async def hard_clean_screening_endpoint( + sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user) +): result = await hard_delete_screening_resources(sr_id, current_user) return result @@ -971,9 +1163,7 @@ async def export_citations_csv( """ try: - sr, screening = await load_sr_and_check( - sr_id, current_user, srdb_service - ) + sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service) except HTTPException: raise except Exception as e: @@ -992,9 +1182,13 @@ async def export_citations_csv( try: # Validation-friendly export: exclude fulltext/artifacts and flatten JSON columns. - csv_bytes = await run_in_threadpool(cits_dp_service.dump_citations_csv_filtered, table_name) + csv_bytes = await run_in_threadpool( + cits_dp_service.dump_citations_csv_filtered, table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, diff --git a/backend/api/core/bounding_box_matcher.py b/backend/api/core/bounding_box_matcher.py index a0fe5e4d..dfe7279e 100644 --- a/backend/api/core/bounding_box_matcher.py +++ b/backend/api/core/bounding_box_matcher.py @@ -432,4 +432,4 @@ def match_figure_references_to_bounding_boxes( print( f"[FIGURE_MATCHING] Enhanced {len(enhanced_references)} references with figure information" ) - return enhanced_references \ No newline at end of file + return enhanced_references diff --git a/backend/api/core/cit_utils.py b/backend/api/core/cit_utils.py index 1e33ce13..39282b30 100644 --- a/backend/api/core/cit_utils.py +++ b/backend/api/core/cit_utils.py @@ -9,6 +9,7 @@ Routers should call load_sr_and_check(...) to avoid duplicating this logic. """ + from typing import Any, Dict, Optional, Tuple from fastapi import HTTPException, status from fastapi.concurrency import run_in_threadpool @@ -65,34 +66,53 @@ async def load_sr_and_check( # fetch SR try: - sr = await run_in_threadpool(srdb_service.get_systematic_review, sr_id, not require_visible) + sr = await run_in_threadpool( + srdb_service.get_systematic_review, sr_id, not require_visible + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to fetch systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to fetch systematic review: {e}", + ) if not sr: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found" + ) if require_visible and not sr.get("visible", True): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found" + ) # permission check (user must be member or owner) user_id = current_user.get("email") try: - has_perm = await run_in_threadpool(srdb_service.user_has_sr_permission, sr_id, user_id) + has_perm = await run_in_threadpool( + srdb_service.user_has_sr_permission, sr_id, user_id + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to check permissions: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to check permissions: {e}", + ) if not has_perm: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view/modify this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to view/modify this systematic review", + ) screening = sr.get("screening_db") if isinstance(sr, dict) else None if require_screening: if not screening: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No screening database configured for this systematic review", + ) return sr, screening diff --git a/backend/api/core/config.py b/backend/api/core/config.py index 1e710ad5..1aa5c2f2 100644 --- a/backend/api/core/config.py +++ b/backend/api/core/config.py @@ -96,17 +96,25 @@ def convert_max_file_size(cls, v): # Background jobs (Procrastinate) # --------------------------------------------------------------------- # Enable Procrastinate-backed background jobs (required for /api/jobs/*) - ENABLE_PROCRASTINATE: bool = os.getenv("ENABLE_PROCRASTINATE", "false").lower().strip() == "true" + ENABLE_PROCRASTINATE: bool = ( + os.getenv("ENABLE_PROCRASTINATE", "false").lower().strip() == "true" + ) # Embedded worker loop (dev-friendly). Keep this OFF by default. - ENABLE_PROCRASTINATE_WORKER: bool = os.getenv("ENABLE_PROCRASTINATE_WORKER", "false").lower().strip() == "true" + ENABLE_PROCRASTINATE_WORKER: bool = ( + os.getenv("ENABLE_PROCRASTINATE_WORKER", "false").lower().strip() == "true" + ) # Worker concurrency (only used when ENABLE_PROCRASTINATE_WORKER=true) - PROCRASTINATE_WORKER_CONCURRENCY: int = int(os.getenv("PROCRASTINATE_WORKER_CONCURRENCY", "1")) + PROCRASTINATE_WORKER_CONCURRENCY: int = int( + os.getenv("PROCRASTINATE_WORKER_CONCURRENCY", "1") + ) # Optional dev cleanup: clear out leftover queued/doing tasks on API startup. # IMPORTANT: default is true if the env var is absent. - PROCRASTINATE_CLEAR_ON_START: bool = os.getenv("PROCRASTINATE_CLEAR_ON_START", "true").lower().strip() == "true" + PROCRASTINATE_CLEAR_ON_START: bool = ( + os.getenv("PROCRASTINATE_CLEAR_ON_START", "true").lower().strip() == "true" + ) # Run-All job chunk size: citations per Procrastinate chunk task. # Larger values reduce overhead but can reduce fairness/responsiveness. @@ -123,7 +131,9 @@ def convert_max_file_size(cls, v): # ------------------------------------------------------------------------- # Select primary Postgres mode. # docker/local use password auth; azure uses Entra token auth. - POSTGRES_MODE: str = os.getenv("POSTGRES_MODE", "docker").lower().strip() # docker|local|azure + POSTGRES_MODE: str = ( + os.getenv("POSTGRES_MODE", "docker").lower().strip() + ) # docker|local|azure # Canonical Postgres connection settings (single profile; values vary by environment) POSTGRES_HOST: Optional[str] = os.getenv("POSTGRES_HOST") @@ -153,7 +163,9 @@ def postgres_profile(self, mode: Optional[str] = None) -> dict: raise ValueError("POSTGRES_MODE must be one of: docker, local, azure") # Provide sensible defaults for host depending on mode. - default_host = "pgdb-service" if m == "docker" else "localhost" if m == "local" else None + default_host = ( + "pgdb-service" if m == "docker" else "localhost" if m == "local" else None + ) prof = { "mode": m, @@ -179,7 +191,7 @@ def has_local_fallback(self) -> bool: JOB_ID_PUBMED: str = os.getenv("JOB_ID_PUBMED", "") JOB_ID_SCOPUS: str = os.getenv("JOB_ID_SCOPUS", "") - #search function + # search function ENTREZ_EMAIL: str = os.getenv("ENTREZ_EMAIL") ENTREZ_API_KEY: str = os.getenv("ENTREZ_API_KEY") diff --git a/backend/api/core/security.py b/backend/api/core/security.py index 19e38e20..cf3245c8 100644 --- a/backend/api/core/security.py +++ b/backend/api/core/security.py @@ -90,7 +90,9 @@ async def update_user(user_id: str, user_in: UserUpdate) -> Optional[UserRead]: return UserRead.model_validate(updated_user_data) -async def authenticate_user(email: str, password: str, sso: bool = False) -> Optional[Dict[str, Any]]: +async def authenticate_user( + email: str, password: str, sso: bool = False +) -> Optional[Dict[str, Any]]: """Authenticate a user""" email = email.lower() if not user_db_service: diff --git a/backend/api/database_search/router.py b/backend/api/database_search/router.py index a5b4e79f..9daf1671 100644 --- a/backend/api/database_search/router.py +++ b/backend/api/database_search/router.py @@ -4,8 +4,12 @@ from pydantic import BaseModel, Field from ..core.security import get_current_active_user from ..core.config import settings -from ..services.citation_search.pubmed_citation_collection import PubMedCitationCollector -from ..services.citation_search.europePMC_citation_collection import EuropePMCCitationCollector +from ..services.citation_search.pubmed_citation_collection import ( + PubMedCitationCollector, +) +from ..services.citation_search.europePMC_citation_collection import ( + EuropePMCCitationCollector, +) from ..services.citation_search.scopus_citation_collection import ScopusDataProcessor import logging @@ -18,12 +22,13 @@ import datetime + def azure_client(container_name: str): connection_string = settings.AZURE_STORAGE_CONNECTION_STRING blob_service_client = BlobServiceClient.from_connection_string(connection_string) blob_client = blob_service_client.get_container_client(container_name) return blob_client - + def read_blob_to_df(blob_name: str, container_name: str) -> pd.DataFrame: """Read a blob (Parquet or CSV) into a pandas DataFrame.""" @@ -36,32 +41,40 @@ def read_blob_to_df(blob_name: str, container_name: str) -> pd.DataFrame: def write_df_to_blob(df: pd.DataFrame, blob_name: str, container_name: str): """Write a pandas DataFrame to blob as Parquet.""" buffer = BytesIO() - df.to_csv(buffer, index=False, encoding='utf-8') + df.to_csv(buffer, index=False, encoding="utf-8") buffer.seek(0) blob_client = azure_client(container_name).get_blob_client(blob_name) blob_client.upload_blob(buffer, overwrite=True) + router = APIRouter() + class SearchRequest(BaseModel): database: str = Field(..., description="database selected for search") - search_term: Optional[str] = Field("", description="search string for database search") + search_term: Optional[str] = Field( + "", description="search string for database search" + ) @router.post("/{sr_id}/search") async def database_search( - sr_id: str, payload: SearchRequest, current_user: Dict[str, Any] = Depends(get_current_active_user), + sr_id: str, + payload: SearchRequest, + current_user: Dict[str, Any] = Depends(get_current_active_user), ): - + MAX_ARTICLES = 1000 MINDATE = None MAXDATE = None database = payload.database SEARCH_TERM = payload.search_term - if not SEARCH_TERM : + if not SEARCH_TERM: SEARCH_TERM = '(("epidemiological parameters"[Title/Abstract] OR "incidence"[MeSH Terms]))' - logging.info(f"{database} search function started at {datetime.datetime.now()}", ) + logging.info( + f"{database} search function started at {datetime.datetime.now()}", + ) if database == "Pubmed": collector = PubMedCitationCollector() @@ -71,7 +84,7 @@ async def database_search( search_term=SEARCH_TERM, mindate=MINDATE, maxdate=MAXDATE, - max_articles=MAX_ARTICLES + max_articles=MAX_ARTICLES, ) except Exception as e: logging.error(f"Error collecting citations: {str(e)}") @@ -81,58 +94,74 @@ async def database_search( elif database == "EuropePMC": collector = EuropePMCCitationCollector() - try: - citations = collector.collect_citations(search_term = SEARCH_TERM, max_articles=MAX_ARTICLES) + try: + citations = collector.collect_citations( + search_term=SEARCH_TERM, max_articles=MAX_ARTICLES + ) except Exception as e: logging.error(f"Error collecting citations: {str(e)}") raise HTTPException(status_code=500, detail="Failed to collect citations") content = pd.DataFrame(citations) elif database == "Scopus": - - - SCOPUS_API_KEY = settings.SCOPUS_API_KEY or None #currently don't have + + SCOPUS_API_KEY = settings.SCOPUS_API_KEY or None # currently don't have API_URL = settings.SCOPUS_API_URL if not SCOPUS_API_KEY or not API_URL: - logging.error("Missing SCOPUS_API_KEY or SCOPUS_BASE_URL in environment settings.") - raise HTTPException(status_code=500, detail="Missing SCOPUS_API_KEY or SCOPUS_BASE_URL") + logging.error( + "Missing SCOPUS_API_KEY or SCOPUS_BASE_URL in environment settings." + ) + raise HTTPException( + status_code=500, detail="Missing SCOPUS_API_KEY or SCOPUS_BASE_URL" + ) processor = ScopusDataProcessor(SCOPUS_API_KEY, API_URL) processor.consume_api(SEARCH_TERM, delay=15) content = pd.DataFrame.from_dict(processor.data) - blob_name = f"{database}-{datetime.datetime.now().strftime('%Y-%m-%d')}.csv" container_name = f"citation-data/bronze-data/{database}/archive" try: write_df_to_blob(content, blob_name, container_name) logging.info(f"Uploaded blob '{blob_name}' to container '{container_name}'") except AzureError as e: - logging.error(f"Azure Blob Storage error: {e.message if hasattr(e, 'message') else str(e)}") + logging.error( + f"Azure Blob Storage error: {e.message if hasattr(e, 'message') else str(e)}" + ) raise HTTPException(status_code=500, detail="Failed to write to Azure Blob") except Exception as ex: logging.error(f"Unexpected error: {str(ex)}") - raise HTTPException(status_code=500, detail="Unexpected error while writing to blob") + raise HTTPException( + status_code=500, detail="Unexpected error while writing to blob" + ) try: - all_citations_df = read_blob_to_df(f"{database}-all-citations.csv", f"citation-data/bronze-data/{database}") + all_citations_df = read_blob_to_df( + f"{database}-all-citations.csv", f"citation-data/bronze-data/{database}" + ) except Exception: - all_citations_df = content.iloc[0:0].copy() + all_citations_df = content.iloc[0:0].copy() - new_all_citations = ( - pd.concat([all_citations_df, content], ignore_index=True) - .drop_duplicates(subset="pmid", keep="first") - ) + new_all_citations = pd.concat( + [all_citations_df, content], ignore_index=True + ).drop_duplicates(subset="pmid", keep="first") new_citations = content[~content["pmid"].isin(all_citations_df["pmid"])] if new_citations.empty: - logging.info('No new citations on %s', datetime.datetime.now()) + logging.info("No new citations on %s", datetime.datetime.now()) return {"message": "No new citations", "new_count": 0} - logging.info('New citations found: %s', len(new_citations)) + logging.info("New citations found: %s", len(new_citations)) - write_df_to_blob(new_all_citations, f"{database}-all-citations.csv", f"citation-data/bronze-data/{database}") + write_df_to_blob( + new_all_citations, + f"{database}-all-citations.csv", + f"citation-data/bronze-data/{database}", + ) write_df_to_blob(new_citations, blob_name, "citation-deduplicate/to-process") - return {"message": f"{database} collection completed", "new_citations_count": len(new_citations)} \ No newline at end of file + return { + "message": f"{database} collection completed", + "new_citations_count": len(new_citations), + } diff --git a/backend/api/extract/router.py b/backend/api/extract/router.py index 7e305fc0..3cb968d4 100644 --- a/backend/api/extract/router.py +++ b/backend/api/extract/router.py @@ -33,40 +33,55 @@ # Import consolidated Postgres helpers if available (optional) from ..services.cit_db_service import cits_dp_service, snake_case_param - - router = APIRouter() class ParameterExtractRequest(BaseModel): fulltext: Optional[str] = Field( None, - description="Full text with numbered sentences (e.g. '[0] First sentence\\n[1] Second sentence'). If omitted the endpoint will try to read fulltext_url from the screening DB row." + description="Full text with numbered sentences (e.g. '[0] First sentence\\n[1] Second sentence'). If omitted the endpoint will try to read fulltext_url from the screening DB row.", + ) + parameter_name: str = Field( + ..., description="Short name for the parameter (used as column name slug)" + ) + parameter_description: str = Field( + ..., description="Human-friendly description of what to extract" ) - parameter_name: str = Field(..., description="Short name for the parameter (used as column name slug)") - parameter_description: str = Field(..., description="Human-friendly description of what to extract") model: Optional[str] = Field(None, description="Model to use") temperature: Optional[float] = Field(0.0, ge=0.0, le=1.0) max_tokens: Optional[int] = Field(512, ge=1, le=4000) # Optional artifacts context (if omitted, server will read from citation row when available) - tables: Optional[str] = Field(None, description="Optional numbered tables text (markdown).") - figures: Optional[str] = Field(None, description="Optional numbered figure captions text.") - attach_figures: Optional[bool] = Field(True, description="If true, attach figure images to the LLM request when available") - - + tables: Optional[str] = Field( + None, description="Optional numbered tables text (markdown)." + ) + figures: Optional[str] = Field( + None, description="Optional numbered figure captions text." + ) + attach_figures: Optional[bool] = Field( + True, + description="If true, attach figure images to the LLM request when available", + ) class HumanParameterRequest(BaseModel): fulltext: Optional[str] = Field( None, - description="Optional numbered full text. If omitted the server will try to read fulltext_url from the screening DB row." + description="Optional numbered full text. If omitted the server will try to read fulltext_url from the screening DB row.", + ) + parameter_name: str = Field( + ..., description="Short name for the parameter (used as column name slug)" ) - parameter_name: str = Field(..., description="Short name for the parameter (used as column name slug)") found: bool = Field(..., description="Whether the parameter was found (boolean)") - value: Optional[str] = Field(None, description="Human-provided value (string) or null") - explanation: Optional[str] = Field("", description="Optional explanation from the human reviewer") - evidence_sentences: Optional[List[int]] = Field(None, description="Optional list of evidence sentence indices") + value: Optional[str] = Field( + None, description="Human-provided value (string) or null" + ) + explanation: Optional[str] = Field( + "", description="Optional explanation from the human reviewer" + ) + evidence_sentences: Optional[List[int]] = Field( + None, description="Optional list of evidence sentence indices" + ) reviewer: Optional[str] = Field(None, description="Optional reviewer id or name") @@ -98,7 +113,10 @@ async def extract_parameter_endpoint( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) table_name = (screening or {}).get("table_name") or "citations" @@ -107,18 +125,30 @@ async def extract_parameter_endpoint( row = None if not fulltext: try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) fulltext = row.get("fulltext") if "fulltext" in row else None if not fulltext: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Full text not provided and not available for this citation") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Full text not provided and not available for this citation", + ) # Build tables/figures context: prefer payload fields, else fetch from DB row (if loaded) tables_text = payload.tables @@ -127,7 +157,9 @@ async def extract_parameter_endpoint( if (tables_text is None or figures_text is None) and row is None: try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except Exception: row = None @@ -153,7 +185,9 @@ async def extract_parameter_endpoint( try: md_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) md_txt = md_bytes.decode("utf-8", errors="replace") - header = f"Table [T{idx}]" + (f" caption: {caption}" if caption else "") + header = f"Table [T{idx}]" + ( + f" caption: {caption}" if caption else "" + ) tables_md_lines.extend([header, md_txt, ""]) except Exception: continue @@ -182,7 +216,9 @@ async def extract_parameter_endpoint( ) if payload.attach_figures: try: - img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) + img_bytes, _ = await storage_service.get_bytes_by_path( + blob_addr + ) if img_bytes: images.append((img_bytes, "image/png")) except Exception: @@ -202,7 +238,10 @@ async def extract_parameter_endpoint( ) if not azure_openai_client.is_configured(): - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Azure OpenAI client is not configured on the server") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Azure OpenAI client is not configured on the server", + ) try: if images: @@ -223,7 +262,10 @@ async def extract_parameter_endpoint( temperature=payload.temperature or 0.0, ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"LLM call failed: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"LLM call failed: {e}", + ) # Parse JSON with robustness: tolerate code fences or preamble text def _extract_json_object(text: str) -> Optional[str]: @@ -244,7 +286,7 @@ def _extract_json_object(text: str) -> Optional[str]: elif ch == "}": depth -= 1 if depth == 0: - return t[start:i+1] + return t[start : i + 1] return None parsed = None @@ -258,10 +300,16 @@ def _extract_json_object(text: str) -> Optional[str]: except Exception: parsed = None if parsed is None: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response was not valid JSON: {llm_response[:1000]}") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"LLM response was not valid JSON: {llm_response[:1000]}", + ) if not isinstance(parsed, dict): - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM response JSON was not an object") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="LLM response JSON was not an object", + ) # Normalize and validate keys and types (tolerate minor deviations) found_raw = parsed.get("found", None) @@ -272,7 +320,10 @@ def _extract_json_object(text: str) -> Optional[str]: elif isinstance(found_raw, (int, float)): found_val = bool(found_raw) else: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM JSON missing or invalid 'found' key") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="LLM JSON missing or invalid 'found' key", + ) # value may be string or null; coerce common primitives to string val = parsed.get("value") @@ -283,7 +334,10 @@ def _extract_json_object(text: str) -> Optional[str]: try: val = str(val) except Exception: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="'value' must be a string or null") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="'value' must be a string or null", + ) explanation = parsed.get("explanation") or "" if not isinstance(explanation, str): @@ -307,7 +361,10 @@ def _extract_json_object(text: str) -> Optional[str]: # skip unsupported types continue else: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="'evidence_sentences' must be a list") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail="'evidence_sentences' must be a list", + ) # Normalize evidence tables/figures def _norm_int_list(v: Any) -> List[int]: @@ -351,14 +408,27 @@ def _norm_int_list(v: Any) -> List[int]: col_name = snake_case_param(payload.parameter_name) try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name) + updated = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + col_name, + stored, + table_name, + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update citation row: {e}", + ) if not updated: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update" + ) # Auto-fill human_param_* from llm_param_* if missing (never overwrite) try: @@ -380,7 +450,13 @@ def _norm_int_list(v: Any) -> List[int]: except Exception: pass - return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "extraction": stored} + return { + "status": "success", + "sr_id": sr_id, + "citation_id": citation_id, + "column": col_name, + "extraction": stored, + } @router.post("/{sr_id}/citations/{citation_id}/human-extract-parameter") @@ -403,20 +479,32 @@ async def human_extract_parameter( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) table_name = (screening or {}).get("table_name") or "citations" # Ensure citation exists (we won't require full_text for human input but check row presence) try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) # Normalize value val = payload.value @@ -424,12 +512,18 @@ async def human_extract_parameter( try: val = str(val) except Exception: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="'value' must be a string or null") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="'value' must be a string or null", + ) # Normalize evidence_sentences evidence = payload.evidence_sentences or [] if not isinstance(evidence, list) or not all(isinstance(i, int) for i in evidence): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="'evidence_sentences' must be a list of integers") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="'evidence_sentences' must be a list of integers", + ) explanation = payload.explanation or "" if not isinstance(explanation, str): @@ -458,22 +552,42 @@ async def human_extract_parameter( # fallback core name try: from ..services.cit_db_service import snake_case as _snake_case + core = _snake_case(payload.parameter_name) if _snake_case else "" except Exception: core = "" col_name = f"human_param_{core}" if core else "human_param_param" try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name) + updated = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + col_name, + stored, + table_name, + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update citation row: {e}", + ) if not updated: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update" + ) - return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "extraction": stored} + return { + "status": "success", + "sr_id": sr_id, + "citation_id": citation_id, + "column": col_name, + "extraction": stored, + } @router.post("/{sr_id}/citations/{citation_id}/extract-fulltext") @@ -493,34 +607,58 @@ async def extract_fulltext_from_storage( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) table_name = (screening or {}).get("table_name") or "citations" # fetch citation row try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) # Determine storage path for the citation PDF storage_path = row.get("fulltext_url") if not storage_path: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No fulltext storage path found on citation row") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No fulltext storage path found on citation row", + ) try: content, _filename = await storage_service.get_bytes_by_path(storage_path) except FileNotFoundError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Fulltext file not found in storage") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Fulltext file not found in storage", + ) except ValueError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognized storage path format") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Unrecognized storage path format", + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download from storage: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to download from storage: {e}", + ) # If the citation row already contains an extracted full text in the "fulltext" column, # only use it if the stored md5 matches the pdf we just downloaded. @@ -555,8 +693,15 @@ async def _run_grobid(): async def _run_docint(): if not azure_docint_client or not azure_docint_client.is_available(): - return {"success": False, "error": "Azure DI not configured", "figures": [], "tables": []} - return await azure_docint_client.extract_citation_artifacts(tmp.name, source_type="file") + return { + "success": False, + "error": "Azure DI not configured", + "figures": [], + "tables": [], + } + return await azure_docint_client.extract_citation_artifacts( + tmp.name, source_type="file" + ) try: (coords, pages), docint_res = await asyncio.gather( @@ -564,7 +709,10 @@ async def _run_docint(): _run_docint(), ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Fulltext processing failed: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Fulltext processing failed: {e}", + ) # filter sentence annotations annotations = [a for a in coords if a.get("type") == "s" and a.get("text")] @@ -579,7 +727,11 @@ async def _run_docint(): artifact_coords: List[Dict[str, Any]] = [] try: - if docint_res and isinstance(docint_res, dict) and docint_res.get("success"): + if ( + docint_res + and isinstance(docint_res, dict) + and docint_res.get("success") + ): pages_meta = docint_res.get("pages") or [] # Determine artifact base path from the fulltext_url directory # storage_path is "container/blob". @@ -589,7 +741,7 @@ async def _run_docint(): artifacts_prefix = artifacts_prefix.replace("//", "/").rstrip("/") # Figures: write png - for fig in (docint_res.get("figures") or []): + for fig in docint_res.get("figures") or []: try: idx = int(fig.get("index")) except Exception: @@ -634,7 +786,7 @@ async def _run_docint(): ) # Tables: write markdown (.md) - for tbl in (docint_res.get("tables") or []): + for tbl in docint_res.get("tables") or []: try: idx = int(tbl.get("index")) except Exception: @@ -687,20 +839,63 @@ async def _run_docint(): # persist full_text_str and coordinates/pages into citation row try: coords_for_overlay = list(annotations) + list(artifact_coords) - updated1 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext", full_text_str, table_name) - updated2 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext_md5", current_md5, table_name) - updated3 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_coords", coords_for_overlay, table_name) - updated4 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_pages", pages, table_name) - updated5 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_figures", fulltext_figures, table_name) - updated6 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_tables", fulltext_tables, table_name) + updated1 = await run_in_threadpool( + cits_dp_service.update_text_column, + citation_id, + "fulltext", + full_text_str, + table_name, + ) + updated2 = await run_in_threadpool( + cits_dp_service.update_text_column, + citation_id, + "fulltext_md5", + current_md5, + table_name, + ) + updated3 = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + "fulltext_coords", + coords_for_overlay, + table_name, + ) + updated4 = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + "fulltext_pages", + pages, + table_name, + ) + updated5 = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + "fulltext_figures", + fulltext_figures, + table_name, + ) + updated6 = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + "fulltext_tables", + fulltext_tables, + table_name, + ) updated = updated1 or updated2 or updated3 or updated4 or updated5 or updated6 except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row with full text: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update citation row with full text: {e}", + ) if not updated: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update" + ) return { "status": "success", diff --git a/backend/api/files/router.py b/backend/api/files/router.py index 9e572103..49b35fcf 100644 --- a/backend/api/files/router.py +++ b/backend/api/files/router.py @@ -1,6 +1,7 @@ """ Files router for document management in CAN-SR. """ + from typing import List, Dict, Any, Optional import os import logging @@ -28,6 +29,7 @@ class DocumentUploadResponse(BaseModel): """Response model for document upload""" + document_id: str filename: str file_size: int @@ -37,6 +39,7 @@ class DocumentUploadResponse(BaseModel): class DocumentInfo(BaseModel): """Document information model""" + document_id: str filename: str file_size: int @@ -45,6 +48,7 @@ class DocumentInfo(BaseModel): class DocumentListResponse(BaseModel): """Response model for document list""" + total_documents: int documents: List[DocumentInfo] @@ -60,8 +64,7 @@ async def upload_document( try: if not file.filename: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Filename is required" + status_code=status.HTTP_400_BAD_REQUEST, detail="Filename is required" ) file_content = await file.read() @@ -129,8 +132,7 @@ async def list_documents( document_infos.append(document_info) return DocumentListResponse( - total_documents=len(document_infos), - documents=document_infos + total_documents=len(document_infos), documents=document_infos ) except HTTPException: @@ -165,8 +167,7 @@ async def download_document( if not document: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Document not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Document not found" ) file_content = await storage_service.get_user_document( @@ -222,9 +223,13 @@ async def download_by_path( try: signed_url = await storage_service.generate_signed_url(path) except ValueError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path" + ) except Exception as e: - logger.warning("Signed URL generation failed, falling back to streaming: %s", e) + logger.warning( + "Signed URL generation failed, falling back to streaming: %s", e + ) signed_url = None if signed_url: @@ -234,20 +239,26 @@ async def download_by_path( try: content, filename = await storage_service.get_bytes_by_path(path) except FileNotFoundError: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="File not found" + ) except ValueError: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path" + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to download: {e}", + ) + def gen(): yield content return StreamingResponse( gen(), media_type="application/octet-stream", - headers={ - "Content-Disposition": f"attachment; filename={filename}" - }, + headers={"Content-Disposition": f"attachment; filename={filename}"}, ) except HTTPException: raise @@ -281,8 +292,7 @@ async def delete_document( if not document: raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail="Document not found" + status_code=status.HTTP_404_NOT_FOUND, detail="Document not found" ) success = await storage_service.delete_user_document( diff --git a/backend/api/jobs/procrastinate_app.py b/backend/api/jobs/procrastinate_app.py index 932f9a44..39085b51 100644 --- a/backend/api/jobs/procrastinate_app.py +++ b/backend/api/jobs/procrastinate_app.py @@ -62,7 +62,9 @@ def workers_enabled() -> bool: def worker_concurrency() -> int: try: - return max(1, int(getattr(settings, "PROCRASTINATE_WORKER_CONCURRENCY", 1) or 1)) + return max( + 1, int(getattr(settings, "PROCRASTINATE_WORKER_CONCURRENCY", 1) or 1) + ) except Exception: return 1 @@ -79,13 +81,11 @@ async def ensure_procrastinate_schema() -> None: async def _schema_installed() -> bool: try: - row = await PROCRASTINATE_APP.connector.execute_query_one_async( - """ + row = await PROCRASTINATE_APP.connector.execute_query_one_async(""" SELECT to_regclass('procrastinate_jobs') IS NOT NULL AS jobs_table, EXISTS(SELECT 1 FROM pg_type WHERE typname = 'procrastinate_job_status') AS status_enum - """ # noqa: S608 - ) + """) # noqa: S608 return bool(row.get("jobs_table") and row.get("status_enum")) except Exception: return False @@ -109,7 +109,14 @@ async def _schema_installed() -> bool: # If schema is already present (possibly created by a previous run), # treat duplicate-object errors as success. cause: BaseException | None = e.__cause__ - if isinstance(cause, (psycopg.errors.DuplicateObject, psycopg.errors.DuplicateTable, psycopg.errors.DuplicateFunction)): + if isinstance( + cause, + ( + psycopg.errors.DuplicateObject, + psycopg.errors.DuplicateTable, + psycopg.errors.DuplicateFunction, + ), + ): if await _schema_installed(): return raise @@ -156,7 +163,9 @@ async def clear_pending_jobs(*, queues: Optional[list[str]] = None) -> int: await PROCRASTINATE_APP.close_async() -async def cancel_enqueued_jobs_for_run_all(job_id: str, *, queues: Optional[list[str]] = None) -> int: +async def cancel_enqueued_jobs_for_run_all( + job_id: str, *, queues: Optional[list[str]] = None +) -> int: """Best-effort: delete enqueued (todo) Procrastinate jobs for a given run-all job_id. This makes Cancel feel more responsive by removing jobs that haven't started yet. diff --git a/backend/api/jobs/router.py b/backend/api/jobs/router.py index a49b8f6a..ec2224aa 100644 --- a/backend/api/jobs/router.py +++ b/backend/api/jobs/router.py @@ -13,7 +13,11 @@ from ..services.azure_openai_client import azure_openai_client from .run_all_repo import run_all_repo -from .procrastinate_app import cancel_enqueued_jobs_for_run_all, jobs_enabled, worker_concurrency +from .procrastinate_app import ( + cancel_enqueued_jobs_for_run_all, + jobs_enabled, + worker_concurrency, +) # Import task objects so we can enqueue via Task.defer_async (Procrastinate 3.2.x) from .run_all_tasks import run_all_chunk, run_all_start @@ -21,7 +25,6 @@ # Import tasks so Procrastinate can discover them. from . import run_all_tasks # noqa: F401 - router = APIRouter() @@ -91,7 +94,9 @@ async def list_active_run_all( if not user_email: raise HTTPException(status_code=401, detail="Missing user identity") - srs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, user_email) + srs = await run_in_threadpool( + srdb_service.list_systematic_reviews_for_user, user_email + ) sr_ids = [str(sr.get("id")) for sr in (srs or []) if sr and sr.get("id")] if not sr_ids: return {"jobs": []} @@ -99,7 +104,9 @@ async def list_active_run_all( jobs = await run_in_threadpool(run_all_repo.list_active_jobs_for_srs, sr_ids) # Attach SR name for nicer UI. - sr_name_map = {str(sr.get("id")): sr.get("name") for sr in (srs or []) if sr and sr.get("id")} + sr_name_map = { + str(sr.get("id")): sr.get("name") for sr in (srs or []) if sr and sr.get("id") + } for j in jobs: sid = str(j.get("sr_id")) if sid in sr_name_map: @@ -122,7 +129,9 @@ async def start_run_all( step = (payload.step or "").lower().strip() if step not in {"l1", "l2", "extract"}: - raise HTTPException(status_code=400, detail="step must be one of: l1, l2, extract") + raise HTTPException( + status_code=400, detail="step must be one of: l1, l2, extract" + ) # Authz: ensure user can access SR try: @@ -189,7 +198,9 @@ async def start_run_all( # If another request won the race, the partial unique index on # (sr_id) WHERE status IN ('queued','running','paused') will throw. if _is_unique_violation(e): - existing = await run_in_threadpool(run_all_repo.get_active_job_for_sr, sr_id) + existing = await run_in_threadpool( + run_all_repo.get_active_job_for_sr, sr_id + ) if existing: return { "job_id": existing.get("job_id"), @@ -209,7 +220,9 @@ async def start_run_all( await run_in_threadpool(run_all_repo.set_status, job_id, "paused") else: await run_in_threadpool(run_all_repo.set_status, job_id, "running") - await run_in_threadpool(run_all_repo.update_phase, job_id, f"enqueued {len(sanitized_ids)}") + await run_in_threadpool( + run_all_repo.update_phase, job_id, f"enqueued {len(sanitized_ids)}" + ) # Fair scheduling: persist chunks and only enqueue the *next* chunk. # This prevents one job from flooding the global queue. @@ -228,7 +241,9 @@ async def start_run_all( ) if next_chunk_id is None: break - await run_all_chunk.defer_async(job_id=job_id, chunk_id=int(next_chunk_id)) + await run_all_chunk.defer_async( + job_id=job_id, chunk_id=int(next_chunk_id) + ) # Helpful operator logging print( @@ -348,7 +363,9 @@ async def run_all_dismiss( st = str(job.get("status") or "").lower() if st not in {"finished", "failed"}: - raise HTTPException(status_code=400, detail="Only finished/failed jobs can be dismissed") + raise HTTPException( + status_code=400, detail="Only finished/failed jobs can be dismissed" + ) await run_in_threadpool(run_all_repo.set_status, job_id, "done") return {"status": "done", "job_id": job_id} diff --git a/backend/api/jobs/run_all_repo.py b/backend/api/jobs/run_all_repo.py index 650dc88a..dd873387 100644 --- a/backend/api/jobs/run_all_repo.py +++ b/backend/api/jobs/run_all_repo.py @@ -30,8 +30,7 @@ def ensure_tables(self) -> None: try: conn = postgres_server.conn cur = conn.cursor() - cur.execute( - """ + cur.execute(""" CREATE TABLE IF NOT EXISTS run_all_jobs ( id UUID PRIMARY KEY, sr_id TEXT NOT NULL, @@ -50,15 +49,13 @@ def ensure_tables(self) -> None: started_at TIMESTAMP WITH TIME ZONE, finished_at TIMESTAMP WITH TIME ZONE ) - """ - ) + """) # Migration safety: older deployments may have allowed multiple active # jobs per SR. Creating the partial unique index would fail if such # duplicates exist. Before creating the index, we dedupe by keeping # the most recent active job per sr_id and canceling older ones. - cur.execute( - """ + cur.execute(""" WITH ranked AS ( SELECT id, sr_id, @@ -76,22 +73,18 @@ def ensure_tables(self) -> None: FROM ranked r WHERE j.id = r.id AND r.rn > 1 - """ - ) + """) # Enforce: only one active run-all job per SR at a time. # Active statuses are queued/running/paused. # This is the critical concurrency guard (race-safe across users). - cur.execute( - """ + cur.execute(""" CREATE UNIQUE INDEX IF NOT EXISTS run_all_jobs_one_active_per_sr ON run_all_jobs (sr_id) WHERE status IN ('queued', 'running', 'paused') - """ - ) + """) # Store only failures as requested - cur.execute( - """ + cur.execute(""" CREATE TABLE IF NOT EXISTS run_all_job_errors ( id BIGSERIAL PRIMARY KEY, job_id UUID NOT NULL REFERENCES run_all_jobs(id) ON DELETE CASCADE, @@ -100,14 +93,12 @@ def ensure_tables(self) -> None: error TEXT, created_at TIMESTAMP WITH TIME ZONE DEFAULT now() ) - """ - ) + """) # Chunk scheduling table (fairness): store chunk definitions and status. # Each run-all job will only enqueue a small number of chunks at a time # (prefetch) so multiple run-all jobs can make progress concurrently. - cur.execute( - """ + cur.execute(""" CREATE TABLE IF NOT EXISTS run_all_job_chunks ( id BIGSERIAL PRIMARY KEY, job_id UUID NOT NULL REFERENCES run_all_jobs(id) ON DELETE CASCADE, @@ -120,15 +111,12 @@ def ensure_tables(self) -> None: finished_at TIMESTAMP WITH TIME ZONE, UNIQUE(job_id, chunk_index) ) - """ - ) + """) - cur.execute( - """ + cur.execute(""" CREATE INDEX IF NOT EXISTS run_all_job_chunks_lookup ON run_all_job_chunks (job_id, status, chunk_index) - """ - ) + """) conn.commit() except Exception: _safe_rollback(conn) @@ -211,7 +199,9 @@ def claim_next_todo_chunk(self, job_id: str, *, prefetch: int = 2) -> Optional[i # Serialize claims per job to avoid races where multiple workers # simultaneously observe the same doing-count and over-claim. # (Row-level locking on the parent job is cheap and safe.) - cur.execute("SELECT 1 FROM run_all_jobs WHERE id = %s FOR UPDATE", (job_id,)) + cur.execute( + "SELECT 1 FROM run_all_jobs WHERE id = %s FOR UPDATE", (job_id,) + ) # Enforce prefetch limit: claim a todo chunk only if currently # doing_count < pf. @@ -325,13 +315,11 @@ def count_active_jobs(self) -> int: try: conn = postgres_server.conn cur = conn.cursor() - cur.execute( - """ + cur.execute(""" SELECT COUNT(1) FROM run_all_jobs WHERE status IN ('queued', 'running', 'paused') - """ - ) + """) row = cur.fetchone() return int(row[0] or 0) if row else 0 except Exception: @@ -491,7 +479,9 @@ def get_job(self, job_id: str) -> Optional[Dict[str, Any]]: _safe_rollback(conn) raise - def set_status(self, job_id: str, status: str, *, error: Optional[str] = None) -> None: + def set_status( + self, job_id: str, status: str, *, error: Optional[str] = None + ) -> None: conn = None try: conn = postgres_server.conn @@ -553,7 +543,9 @@ def set_total(self, job_id: str, total: int) -> None: _safe_rollback(conn) raise - def inc_counts(self, job_id: str, *, done: int = 0, skipped: int = 0, failed: int = 0) -> None: + def inc_counts( + self, job_id: str, *, done: int = 0, skipped: int = 0, failed: int = 0 + ) -> None: conn = None try: conn = postgres_server.conn @@ -590,7 +582,9 @@ def is_canceled(self, job_id: str) -> bool: _safe_rollback(conn) raise - def add_error(self, job_id: str, *, citation_id: Optional[int], stage: str, error: str) -> None: + def add_error( + self, job_id: str, *, citation_id: Optional[int], stage: str, error: str + ) -> None: conn = None try: conn = postgres_server.conn @@ -600,7 +594,12 @@ def add_error(self, job_id: str, *, citation_id: Optional[int], stage: str, erro INSERT INTO run_all_job_errors (job_id, citation_id, stage, error) VALUES (%s, %s, %s, %s) """, - (job_id, int(citation_id) if citation_id is not None else None, stage, error[:8000]), + ( + job_id, + int(citation_id) if citation_id is not None else None, + stage, + error[:8000], + ), ) conn.commit() except Exception: diff --git a/backend/api/jobs/run_all_tasks.py b/backend/api/jobs/run_all_tasks.py index c8c43c48..fda82afa 100644 --- a/backend/api/jobs/run_all_tasks.py +++ b/backend/api/jobs/run_all_tasks.py @@ -9,7 +9,12 @@ from .run_all_repo import run_all_repo from .procrastinate_app import worker_concurrency from ..services.sr_db_service import srdb_service -from ..services.cit_db_service import cits_dp_service, snake_case_column, snake_case_param, snake_case +from ..services.cit_db_service import ( + cits_dp_service, + snake_case_column, + snake_case_param, + snake_case, +) from ..citations import router as citations_router from ..services.azure_openai_client import azure_openai_client from ..services.storage import storage_service @@ -93,12 +98,16 @@ def _eligible_ids(*, sr_id: str, table_name: str, step: str) -> List[int]: elif step == "extract": filter_step = "l2" - ids = cits_dp_service.list_citation_ids(filter_step if filter_step else None, table_name) + ids = cits_dp_service.list_citation_ids( + filter_step if filter_step else None, table_name + ) # PDF gating for l2/extract if step in ("l2", "extract"): # keep only rows with fulltext_url - rows = cits_dp_service.get_citations_by_ids(ids, table_name, fields=["id", "fulltext_url"]) + rows = cits_dp_service.get_citations_by_ids( + ids, table_name, fields=["id", "fulltext_url"] + ) ok = [] for r in rows: try: @@ -121,7 +130,9 @@ async def _run_l1_for_citation( force: bool, ) -> tuple[int, int, int]: """Returns (done, skipped, failed) increments for this citation.""" - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) if not row: return (0, 0, 1) @@ -129,7 +140,9 @@ async def _run_l1_for_citation( if not include_cols: include_cols = ["title", "abstract"] - citation_text = citations_router._build_combined_citation_from_row(row, include_cols) + citation_text = citations_router._build_combined_citation_from_row( + row, include_cols + ) cp = sr.get("criteria_parsed") or sr.get("criteria") or {} l1 = cp.get("l1") if isinstance(cp, dict) else None @@ -149,7 +162,9 @@ async def _run_l1_for_citation( if await run_in_threadpool(run_all_repo.is_canceled, job_id): raise RunAllCanceled() await _wait_if_paused(job_id) - opts = possible[i] if i < len(possible) and isinstance(possible[i], list) else [] + opts = ( + possible[i] if i < len(possible) and isinstance(possible[i], list) else [] + ) xtra = addinfos[i] if i < len(addinfos) and isinstance(addinfos[i], str) else "" col = snake_case_column(q) existing = row.get(col) @@ -160,7 +175,9 @@ async def _run_l1_for_citation( raise RuntimeError("Azure OpenAI client not configured") options_listed = "\n".join([f"{j}. {opt}" for j, opt in enumerate(opts)]) - prompt = PROMPT_JSON_TEMPLATE.format(question=q, cit=citation_text, options=options_listed, xtra=xtra) + prompt = PROMPT_JSON_TEMPLATE.format( + question=q, cit=citation_text, options=options_listed, xtra=xtra + ) llm_response = await azure_openai_client.simple_chat( user_message=prompt, system_prompt=None, @@ -181,15 +198,28 @@ async def _run_l1_for_citation( classification_json = { "selected": resolved_selected, - "explanation": parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or "", - "confidence": float(parsed.get("confidence") or 0.0) if str(parsed.get("confidence") or "").strip() else 0.0, + "explanation": parsed.get("explanation") + or parsed.get("reason") + or parsed.get("explain") + or "", + "confidence": ( + float(parsed.get("confidence") or 0.0) + if str(parsed.get("confidence") or "").strip() + else 0.0 + ), "evidence_sentences": parsed.get("evidence_sentences") or [], "evidence_tables": parsed.get("evidence_tables") or [], "evidence_figures": parsed.get("evidence_figures") or [], "llm_raw": llm_response, } - await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col, classification_json, table_name) + await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + col, + classification_json, + table_name, + ) # Best-effort autofill human_ if empty try: @@ -229,7 +259,9 @@ async def _ensure_fulltext_if_needed( force: bool, ) -> bool: """Ensure fulltext artifacts exist. Returns True if available after this call.""" - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) if not row: return False if not row.get("fulltext_url"): @@ -242,9 +274,13 @@ async def _ensure_fulltext_if_needed( await extract_fulltext_from_storage(sr_id, citation_id, current_user=current_user) # type: ignore except Exception: # It's okay if DI/grobid fails; L2/extract depends on fulltext text though. - row2 = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row2 = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) return bool(row2 and row2.get("fulltext")) - row3 = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row3 = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) return bool(row3 and row3.get("fulltext")) @@ -258,18 +294,28 @@ async def _run_l2_for_citation( model: Optional[str], force: bool, ) -> tuple[int, int, int]: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) if not row or not row.get("fulltext_url"): return (0, 1, 0) # fake current_user for storage service paths (it only uses id in upload; extract reads by path) current_user = {"id": "system", "email": "system"} await _wait_if_paused(job_id) - ok = await _ensure_fulltext_if_needed(sr_id=sr_id, citation_id=citation_id, current_user=current_user, table_name=table_name, force=force) + ok = await _ensure_fulltext_if_needed( + sr_id=sr_id, + citation_id=citation_id, + current_user=current_user, + table_name=table_name, + force=force, + ) if not ok: return (0, 1, 0) - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) if not row: return (0, 0, 1) @@ -284,8 +330,13 @@ async def _run_l2_for_citation( if not questions: return (0, 1, 0) - include_cols = cits_dp_service.load_include_columns_from_criteria(sr) or ["title", "abstract"] - citation_text = citations_router._build_combined_citation_from_row(row, include_cols) + include_cols = cits_dp_service.load_include_columns_from_criteria(sr) or [ + "title", + "abstract", + ] + citation_text = citations_router._build_combined_citation_from_row( + row, include_cols + ) fulltext = row.get("fulltext") or citation_text # Tables/Figures context from row @@ -333,7 +384,9 @@ async def _run_l2_for_citation( caption = item.get("caption") if not idx or not blob_addr: continue - figures_lines.append(f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})") + figures_lines.append( + f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})" + ) try: img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) if img_bytes: @@ -347,7 +400,9 @@ async def _run_l2_for_citation( if await run_in_threadpool(run_all_repo.is_canceled, job_id): raise RunAllCanceled() await _wait_if_paused(job_id) - opts = possible[i] if i < len(possible) and isinstance(possible[i], list) else [] + opts = ( + possible[i] if i < len(possible) and isinstance(possible[i], list) else [] + ) xtra = addinfos[i] if i < len(addinfos) and isinstance(addinfos[i], str) else "" col = snake_case_column(q) existing = row.get(col) @@ -392,15 +447,28 @@ async def _run_l2_for_citation( classification_json = { "selected": resolved_selected, - "explanation": parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or "", - "confidence": float(parsed.get("confidence") or 0.0) if str(parsed.get("confidence") or "").strip() else 0.0, + "explanation": parsed.get("explanation") + or parsed.get("reason") + or parsed.get("explain") + or "", + "confidence": ( + float(parsed.get("confidence") or 0.0) + if str(parsed.get("confidence") or "").strip() + else 0.0 + ), "evidence_sentences": parsed.get("evidence_sentences") or [], "evidence_tables": parsed.get("evidence_tables") or [], "evidence_figures": parsed.get("evidence_figures") or [], "llm_raw": llm_response, } - await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col, classification_json, table_name) + await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + col, + classification_json, + table_name, + ) # Best-effort autofill human try: @@ -440,24 +508,36 @@ async def _run_extract_for_citation( model: Optional[str], force: bool, ) -> tuple[int, int, int]: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) if not row or not row.get("fulltext_url"): return (0, 1, 0) current_user = {"id": "system", "email": "system"} await _wait_if_paused(job_id) - ok = await _ensure_fulltext_if_needed(sr_id=sr_id, citation_id=citation_id, current_user=current_user, table_name=table_name, force=force) + ok = await _ensure_fulltext_if_needed( + sr_id=sr_id, + citation_id=citation_id, + current_user=current_user, + table_name=table_name, + force=force, + ) if not ok: return (0, 1, 0) - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, citation_id, table_name + ) if not row: return (0, 0, 1) cp = sr.get("criteria_parsed") or sr.get("criteria") or {} params = cp.get("parameters") if isinstance(cp, dict) else None categories = (params or {}).get("categories") if isinstance(params, dict) else [] - possible = (params or {}).get("possible_parameters") if isinstance(params, dict) else [] + possible = ( + (params or {}).get("possible_parameters") if isinstance(params, dict) else [] + ) descs = (params or {}).get("descriptions") if isinstance(params, dict) else [] categories = categories if isinstance(categories, list) else [] possible = possible if isinstance(possible, list) else [] @@ -528,7 +608,9 @@ async def _run_extract_for_citation( caption = item.get("caption") if not idx or not blob_addr: continue - flines.append(f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})") + flines.append( + f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})" + ) try: img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr) if img_bytes: @@ -621,7 +703,9 @@ def _extract_json_object(text: str) -> Optional[str]: "llm_raw": str(llm_response)[:4000], } - await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col, stored, table_name) + await run_in_threadpool( + cits_dp_service.update_jsonb_column, citation_id, col, stored, table_name + ) # best-effort autofill human_param try: @@ -664,16 +748,23 @@ async def run_all_start(job_id: str) -> None: ids = [] sr, table_name = await _load_sr_and_table(sr_id) - ids = await run_in_threadpool(_eligible_ids, sr_id=sr_id, table_name=table_name, step=step) + ids = await run_in_threadpool( + _eligible_ids, sr_id=sr_id, table_name=table_name, step=step + ) await run_in_threadpool(run_all_repo.set_total, job_id, len(ids)) - print(f"[run-all] kickoff job_id={job_id} step={step} eligible={len(ids)}", flush=True) + print( + f"[run-all] kickoff job_id={job_id} step={step} eligible={len(ids)}", + flush=True, + ) # If user paused before kickoff, preserve paused status. if await run_in_threadpool(run_all_repo.is_paused, job_id): await run_in_threadpool(run_all_repo.set_status, job_id, "paused") else: await run_in_threadpool(run_all_repo.set_status, job_id, "running") - await run_in_threadpool(run_all_repo.update_phase, job_id, f"enqueued {len(ids)}") + await run_in_threadpool( + run_all_repo.update_phase, job_id, f"enqueued {len(ids)}" + ) # Fair scheduling: persist chunks and enqueue only the next chunk. chunks = [ids[i : i + chunk_size] for i in range(0, len(ids), chunk_size)] @@ -689,7 +780,9 @@ async def run_all_start(job_id: str) -> None: ) if next_chunk_id is None: break - await run_all_chunk.defer_async(job_id=job_id, chunk_id=int(next_chunk_id)) + await run_all_chunk.defer_async( + job_id=job_id, chunk_id=int(next_chunk_id) + ) except Exception as e: await run_in_threadpool(run_all_repo.set_status, job_id, "failed", error=str(e)) @@ -739,18 +832,49 @@ async def run_all_chunk(job_id: str, chunk_id: int) -> None: try: await _wait_if_paused(job_id) if step == "l1": - d, s, f = await _run_l1_for_citation(job_id=job_id, sr=sr, table_name=table_name, citation_id=int(cid), model=model, force=force) + d, s, f = await _run_l1_for_citation( + job_id=job_id, + sr=sr, + table_name=table_name, + citation_id=int(cid), + model=model, + force=force, + ) elif step == "l2": - d, s, f = await _run_l2_for_citation(job_id=job_id, sr=sr, table_name=table_name, sr_id=sr_id, citation_id=int(cid), model=model, force=force) + d, s, f = await _run_l2_for_citation( + job_id=job_id, + sr=sr, + table_name=table_name, + sr_id=sr_id, + citation_id=int(cid), + model=model, + force=force, + ) elif step == "extract": - d, s, f = await _run_extract_for_citation(job_id=job_id, sr=sr, table_name=table_name, sr_id=sr_id, citation_id=int(cid), model=model, force=force) + d, s, f = await _run_extract_for_citation( + job_id=job_id, + sr=sr, + table_name=table_name, + sr_id=sr_id, + citation_id=int(cid), + model=model, + force=force, + ) else: d, s, f = (0, 1, 0) - await run_in_threadpool(run_all_repo.inc_counts, job_id, done=d, skipped=s, failed=f) + await run_in_threadpool( + run_all_repo.inc_counts, job_id, done=d, skipped=s, failed=f + ) except RunAllCanceled: return except Exception as e: - await run_in_threadpool(run_all_repo.add_error, job_id, citation_id=int(cid), stage=step, error=str(e)) + await run_in_threadpool( + run_all_repo.add_error, + job_id, + citation_id=int(cid), + stage=step, + error=str(e), + ) await run_in_threadpool(run_all_repo.inc_counts, job_id, failed=1) chunk_failed = True chunk_error = str(e) @@ -763,7 +887,11 @@ async def run_all_chunk(job_id: str, chunk_id: int) -> None: # Mark chunk complete and schedule more work (up to prefetch). try: if chunk_failed: - await run_in_threadpool(run_all_repo.mark_chunk_failed, int(chunk_id), error=chunk_error or "chunk had failures") + await run_in_threadpool( + run_all_repo.mark_chunk_failed, + int(chunk_id), + error=chunk_error or "chunk had failures", + ) else: await run_in_threadpool(run_all_repo.mark_chunk_done, int(chunk_id)) except Exception: @@ -802,7 +930,11 @@ async def run_all_chunk(job_id: str, chunk_id: int) -> None: # Successful completion should remain visible in the UI until the # user dismisses it. We model that as a terminal-but-sticky status # called "finished". Dismissal transitions "finished" -> "done". - if total > 0 and (done + skipped + failed) >= total and not await run_in_threadpool(run_all_repo.is_canceled, job_id): + if ( + total > 0 + and (done + skipped + failed) >= total + and not await run_in_threadpool(run_all_repo.is_canceled, job_id) + ): await run_in_threadpool(run_all_repo.set_status, job_id, "finished") except Exception: pass diff --git a/backend/api/models/auth.py b/backend/api/models/auth.py index 7e34fea0..4588ad2e 100644 --- a/backend/api/models/auth.py +++ b/backend/api/models/auth.py @@ -1,6 +1,7 @@ """ Authentication and user models for the API. """ + from typing import Optional from datetime import datetime diff --git a/backend/api/router.py b/backend/api/router.py index 94d5144d..3950926d 100644 --- a/backend/api/router.py +++ b/backend/api/router.py @@ -37,4 +37,6 @@ api_router.include_router(jobs_router, prefix="/jobs", tags=["Jobs"]) # Database Search API -api_router.include_router(database_search_router, prefix="/database_search", tags=["Database Search"]) +api_router.include_router( + database_search_router, prefix="/database_search", tags=["Database Search"] +) diff --git a/backend/api/screen/prompts.py b/backend/api/screen/prompts.py index ba7ac9d5..8d5ae0c7 100644 --- a/backend/api/screen/prompts.py +++ b/backend/api/screen/prompts.py @@ -72,4 +72,4 @@ - Use sentence indices from the numbered full text for "evidence_sentences" - Use table numbers from the Tables section for "evidence_tables" - Use figure numbers from the Figures section for "evidence_figures" -""" \ No newline at end of file +""" diff --git a/backend/api/screen/router.py b/backend/api/screen/router.py index 10174b82..93bd917b 100644 --- a/backend/api/screen/router.py +++ b/backend/api/screen/router.py @@ -54,18 +54,36 @@ def _normalize_int_list(v: Any) -> List[int]: class ClassifyRequest(BaseModel): citation_text: Optional[str] = Field( - None, description="Optional combined citation text. If omitted the server will build it from the screening DB row." + None, + description="Optional combined citation text. If omitted the server will build it from the screening DB row.", ) include_columns: Optional[List[str]] = Field( - None, description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation" + None, + description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation", ) - question: str = Field(..., description="L1 criteria question to apply to this citation") - screening_step: str = Field(..., description="Screening step identifier: 'l1' or 'l2', etc.") - options: List[str] = Field(..., description="List of possible options (exact strings). The model must pick one.") - xtra: Optional[str] = Field("", description="Additional context/instructions for the model") - model: Optional[str] = Field(None, description="Model to use (falls back to default configured model)") - temperature: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="Sampling temperature") - max_tokens: Optional[int] = Field(2000, ge=1, le=4000, description="Max tokens for LLM response") + question: str = Field( + ..., description="L1 criteria question to apply to this citation" + ) + screening_step: str = Field( + ..., description="Screening step identifier: 'l1' or 'l2', etc." + ) + options: List[str] = Field( + ..., + description="List of possible options (exact strings). The model must pick one.", + ) + xtra: Optional[str] = Field( + "", description="Additional context/instructions for the model" + ) + model: Optional[str] = Field( + None, description="Model to use (falls back to default configured model)" + ) + temperature: Optional[float] = Field( + 0.0, ge=0.0, le=1.0, description="Sampling temperature" + ) + max_tokens: Optional[int] = Field( + 2000, ge=1, le=4000, description="Max tokens for LLM response" + ) + class HumanClassifyRequest(BaseModel): """ @@ -73,24 +91,35 @@ class HumanClassifyRequest(BaseModel): This mirrors the shape of the LLM-based classify payload but accepts a direct `selected` value and optional explanation/confidence. """ + citation_text: Optional[str] = Field( - None, description="Optional combined citation text. If omitted the server will build it from the screening DB row." + None, + description="Optional combined citation text. If omitted the server will build it from the screening DB row.", ) include_columns: Optional[List[str]] = Field( - None, description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation" + None, + description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation", + ) + question: str = Field( + ..., description="L1 criteria question to apply to this citation" ) - question: str = Field(..., description="L1 criteria question to apply to this citation") selected: str = Field(..., description="Human-selected option (string)") - screening_step: str = Field(..., description="Screening step identifier: 'l1' or 'l2', etc.") - explanation: Optional[str] = Field("", description="Optional free-text explanation from the human reviewer") - confidence: Optional[float] = Field(None, ge=0.0, le=1.0, description="Optional confidence (0.0 - 1.0)") + screening_step: str = Field( + ..., description="Screening step identifier: 'l1' or 'l2', etc." + ) + explanation: Optional[str] = Field( + "", description="Optional free-text explanation from the human reviewer" + ) + confidence: Optional[float] = Field( + None, ge=0.0, le=1.0, description="Optional confidence (0.0 - 1.0)" + ) reviewer: Optional[str] = Field(None, description="Optional reviewer id or name") - + + # _update_sync moved to backend.api.core.postgres.update_jsonb_column # Use run_in_threadpool(update_jsonb_column, ...) where needed. - @router.post("/{sr_id}/citations/{citation_id}/classify") async def classify_citation( sr_id: str, @@ -111,27 +140,47 @@ async def classify_citation( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) table_name = (screening or {}).get("table_name") or "citations" # Load citation row (needed for l2 fulltext and for building citation_text) try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) # Build or use provided citation text (fall back to combined title/abstract when not provided) - citation_text = payload.citation_text or citations_router._build_combined_citation_from_row(row, payload.include_columns) + citation_text = ( + payload.citation_text + or citations_router._build_combined_citation_from_row( + row, payload.include_columns + ) + ) # Ensure LLM client is available if not azure_openai_client.is_configured(): - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Azure OpenAI client is not configured on the server") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Azure OpenAI client is not configured on the server", + ) # Prepare prompt (use full-text template for l2, otherwise TA/L1 template) options_listed = "\n".join([f"{i}. {opt}" for i, opt in enumerate(payload.options)]) @@ -242,17 +291,29 @@ async def classify_citation( try: parsed = json.loads(llm_response) except Exception: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response was not valid JSON: {llm_response[:1000]}") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"LLM response was not valid JSON: {llm_response[:1000]}", + ) if not isinstance(parsed, dict): - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response JSON was not an object: {str(type(parsed))}") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"LLM response JSON was not an object: {str(type(parsed))}", + ) # Require 'selected' key and validate it is a string if "selected" not in parsed: - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM JSON missing 'selected' key: {json.dumps(parsed)[:1000]}") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"LLM JSON missing 'selected' key: {json.dumps(parsed)[:1000]}", + ) selected_value = parsed.get("selected") if not isinstance(selected_value, str): - raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM 'selected' must be a string: {str(type(selected_value))}") + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=f"LLM 'selected' must be a string: {str(type(selected_value))}", + ) s = selected_value.strip() @@ -263,7 +324,9 @@ async def classify_citation( resolved_selected = opt break - explanation = parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or "" + explanation = ( + parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or "" + ) confidence_raw = parsed.get("confidence") # Parse confidence @@ -298,14 +361,27 @@ async def classify_citation( human_col_name = f"human_{col_core_h}" if col_core_h else "human_col" try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name) + updated = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + col_name, + classification_json, + table_name, + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update citation row: {e}", + ) if not updated: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update" + ) # Auto-fill human_* from llm_* if missing (never overwrite) try: @@ -326,10 +402,16 @@ async def classify_citation( except Exception: # best-effort pass - + await update_inclusion_decision(sr, citation_id, payload.screening_step, "llm") - return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json} + return { + "status": "success", + "sr_id": sr_id, + "citation_id": citation_id, + "column": col_name, + "classification": classification_json, + } @router.post("/{sr_id}/citations/{citation_id}/human_classify") @@ -350,20 +432,32 @@ async def human_classify_citation( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review or screening: {e}", + ) table_name = (screening or {}).get("table_name") or "citations" # Ensure citation exists and optionally build combined citation text try: - row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name) + row = await run_in_threadpool( + cits_dp_service.get_citation_by_id, int(citation_id), table_name + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found" + ) citation_text = payload.citation_text confidence = payload.confidence @@ -381,28 +475,48 @@ async def human_classify_citation( # Persist into Postgres under a dynamic column name derived from question # Use snake_case to create a stable core name and prefix with 'human_' col_core = snake_case(payload.question, max_len=56) if snake_case else None - col_name = f"human_{col_core}" if col_core else f"human_col" + col_name = f"human_{col_core}" if col_core else f"human_col" try: - updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name) + updated = await run_in_threadpool( + cits_dp_service.update_jsonb_column, + citation_id, + col_name, + classification_json, + table_name, + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update citation row: {e}", + ) if not updated: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update" + ) + await update_inclusion_decision(sr, citation_id, payload.screening_step, "human") - return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json} + return { + "status": "success", + "sr_id": sr_id, + "citation_id": citation_id, + "column": col_name, + "classification": classification_json, + } + async def update_inclusion_decision( sr: Dict[str, Any], citation_id: int, screening_step: str, decision_maker: str, -): +): table_name = (sr.get("screening_db") or {}).get("table_name") or "citations" # IMPORTANT: decision/pass computation must not be stale. @@ -415,10 +529,15 @@ def _get_row_fresh() -> Dict[str, Any]: try: row = await run_in_threadpool(_get_row_fresh) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}") - + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to query screening DB: {e}", + ) + questions = sr["criteria_parsed"][screening_step]["questions"] classified = True decision = "undecided" @@ -435,17 +554,30 @@ def _get_row_fresh() -> Dict[str, Any]: if classified and decision == "undecided": decision = "include" - + col_name = f"{decision_maker}_{screening_step}_decision" try: - updated = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, col_name, decision, table_name) + updated = await run_in_threadpool( + cits_dp_service.update_text_column, + citation_id, + col_name, + decision, + table_name, + ) except RuntimeError as rexc: - raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)) + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc) + ) except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update citation row: {e}", + ) print(updated) if not updated: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update" + ) # Validation rule (B1/B2): do NOT use l1_screen/l2_screen for filtering. # Ensure human_l1_decision / human_l2_decision are always correct and derived @@ -489,7 +621,9 @@ def _compute_human_decision(step: str) -> str: except Exception: selected = None # Treat empty/whitespace as unanswered (UI shows "-- select --") - if selected is None or (isinstance(selected, str) and selected.strip() == ""): + if selected is None or ( + isinstance(selected, str) and selected.strip() == "" + ): return "undecided" if "exclude" in str(selected).lower(): return "exclude" @@ -499,8 +633,20 @@ def _compute_human_decision(step: str) -> str: # Always set both human decisions on any update, so the list filters never go stale. h1 = _compute_human_decision("l1") h2 = _compute_human_decision("l2") - await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "human_l1_decision", h1, table_name) - await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "human_l2_decision", h2, table_name) + await run_in_threadpool( + cits_dp_service.update_text_column, + citation_id, + "human_l1_decision", + h1, + table_name, + ) + await run_in_threadpool( + cits_dp_service.update_text_column, + citation_id, + "human_l2_decision", + h2, + table_name, + ) except Exception: # best-effort; do not block response pass diff --git a/backend/api/services/azure_docint_client.py b/backend/api/services/azure_docint_client.py index cf0eb071..be1c3150 100644 --- a/backend/api/services/azure_docint_client.py +++ b/backend/api/services/azure_docint_client.py @@ -42,6 +42,7 @@ from typing import Any as DocumentIntelligenceClient # type: ignore from typing import Any as AnalyzeDocumentRequest # type: ignore from typing import Any as AzureKeyCredential # type: ignore + print( "Azure Document Intelligence SDK not installed. Install with: pip install azure-ai-documentintelligence azure-core" ) @@ -63,13 +64,13 @@ def _init_client(self) -> Optional["DocumentIntelligenceClient"]: """Initialize Azure Document Intelligence client""" if not AZURE_DOC_INTELLIGENCE_AVAILABLE: return None - + if settings.AZURE_DOC_INT_MODE not in ["key", "entra"]: print( f"Invalid AZURE_DOC_INT_MODE: {settings.AZURE_DOC_INT_MODE}. Must be 'key' or 'entra'." ) return None - + if not settings.AZURE_DOC_INT_ENDPOINT: print( "Azure Document Intelligence endpoint not found. Set AZURE_DOC_INT_ENDPOINT environment variable." @@ -81,11 +82,13 @@ def _init_client(self) -> Optional["DocumentIntelligenceClient"]: "Azure Document Intelligence API key not found. Set AZURE_DOC_INT_API_KEY for key-based auth." ) return None - + doc_int_kwargs = {"endpoint": settings.AZURE_DOC_INT_ENDPOINT} if settings.AZURE_DOC_INT_MODE == "key": - doc_int_kwargs["credential"] = AzureKeyCredential(settings.AZURE_DOC_INT_API_KEY) + doc_int_kwargs["credential"] = AzureKeyCredential( + settings.AZURE_DOC_INT_API_KEY + ) elif settings.AZURE_DOC_INT_MODE == "entra": doc_int_kwargs["credential"] = DefaultAzureCredential() @@ -680,7 +683,9 @@ def _extract_html_tables_from_markdown(markdown: str) -> List[str]: table_pattern = r".*?
" return re.findall(table_pattern, markdown, re.DOTALL) - def _download_figure_bytes_sync(self, result_id: str, figure_id: str) -> Optional[bytes]: + def _download_figure_bytes_sync( + self, result_id: str, figure_id: str + ) -> Optional[bytes]: """Download a single figure image from Azure DI as bytes (sync).""" try: stream = self.client.get_analyze_result_figure( @@ -719,7 +724,10 @@ async def extract_citation_artifacts( """ if not self.client: - return {"success": False, "error": "Azure Document Intelligence client not available"} + return { + "success": False, + "error": "Azure Document Intelligence client not available", + } # We need figures in output. output_param = ["figures"] @@ -754,7 +762,9 @@ async def extract_citation_artifacts( for i, md in enumerate(md_tables, start=1): bbox = None if i - 1 < len(raw_tables): - bbox = raw_tables[i - 1].get("boundingRegions") or raw_tables[i - 1].get("bounding_regions") + bbox = raw_tables[i - 1].get("boundingRegions") or raw_tables[ + i - 1 + ].get("bounding_regions") tables_out.append( { "index": i, @@ -783,7 +793,7 @@ async def extract_citation_artifacts( bounding_regions = [] try: - for region in (getattr(fig, "bounding_regions", None) or []): + for region in getattr(fig, "bounding_regions", None) or []: bounding_regions.append( {"page_number": region.page_number, "polygon": region.polygon} ) @@ -792,7 +802,9 @@ async def extract_citation_artifacts( png_bytes = None if result_id: - png_bytes = await asyncio.to_thread(self._download_figure_bytes_sync, result_id, azure_id) + png_bytes = await asyncio.to_thread( + self._download_figure_bytes_sync, result_id, azure_id + ) if not png_bytes: # If we couldn't download, skip storing bytes (still return metadata) @@ -815,8 +827,13 @@ async def extract_citation_artifacts( { "index": idx, "azure_id": fig.get("id") or f"raw_{idx}", - "caption": (fig.get("caption", {}) or {}).get("content") if isinstance(fig.get("caption"), dict) else None, - "bounding_box": fig.get("boundingRegions") or fig.get("bounding_regions"), + "caption": ( + (fig.get("caption", {}) or {}).get("content") + if isinstance(fig.get("caption"), dict) + else None + ), + "bounding_box": fig.get("boundingRegions") + or fig.get("bounding_regions"), "png_bytes": b"", } ) @@ -834,4 +851,4 @@ async def extract_citation_artifacts( try: azure_docint_client = AzureDocIntelligenceService() except Exception: - azure_docint_client = None # type: ignore \ No newline at end of file + azure_docint_client = None # type: ignore diff --git a/backend/api/services/azure_openai_client.py b/backend/api/services/azure_openai_client.py index 8f885418..32d894f0 100644 --- a/backend/api/services/azure_openai_client.py +++ b/backend/api/services/azure_openai_client.py @@ -101,12 +101,10 @@ def __init__(self): self._official_clients: Dict[Tuple[str, str, str], AzureOpenAI] = {} self._official_async_clients: Dict[Tuple[str, str, str], Any] = {} - # --------------------------------------------------------------------- # Configuration # --------------------------------------------------------------------- - @staticmethod def _resolve_auth_type() -> str: """Return key|entra. @@ -145,11 +143,15 @@ def _load_models_yaml(self) -> Dict[str, Any]: try: data = yaml.safe_load(path.read_text(encoding="utf-8")) if not isinstance(data, dict): - logger.warning("Invalid models.yaml format (expected mapping): %s", type(data)) + logger.warning( + "Invalid models.yaml format (expected mapping): %s", type(data) + ) return {} return data except Exception as e: - logger.exception("Failed to load Azure OpenAI model catalog from %s: %s", path, e) + logger.exception( + "Failed to load Azure OpenAI model catalog from %s: %s", path, e + ) return {} def _load_model_configs(self) -> Dict[str, Dict[str, str]]: @@ -244,7 +246,9 @@ def _get_official_client(self, model: str) -> AzureOpenAI: endpoint = config.get("endpoint") api_version = config.get("api_version") if not endpoint or not api_version: - raise ValueError(f"Azure OpenAI endpoint/api_version not configured for model {model}") + raise ValueError( + f"Azure OpenAI endpoint/api_version not configured for model {model}" + ) cache_key = (endpoint, api_version, self._auth_type) if cache_key not in self._official_clients: @@ -255,12 +259,16 @@ def _get_official_client(self, model: str) -> AzureOpenAI: if self._auth_type == "entra": if not self._token_provider: - raise ValueError(self._config_error or "Azure AD token provider not configured") + raise ValueError( + self._config_error or "Azure AD token provider not configured" + ) azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider else: # key auth if not self._api_key: - raise ValueError("AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY") + raise ValueError( + "AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY" + ) azure_openai_kwargs["api_key"] = self._api_key self._official_clients[cache_key] = AzureOpenAI(**azure_openai_kwargs) @@ -284,7 +292,9 @@ def _get_official_async_client(self, model: str): endpoint = config.get("endpoint") api_version = config.get("api_version") if not endpoint or not api_version: - raise ValueError(f"Azure OpenAI endpoint/api_version not configured for model {model}") + raise ValueError( + f"Azure OpenAI endpoint/api_version not configured for model {model}" + ) cache_key = (endpoint, api_version, self._auth_type) if cache_key not in self._official_async_clients: @@ -295,11 +305,15 @@ def _get_official_async_client(self, model: str): if self._auth_type == "entra": if not self._token_provider: - raise ValueError(self._config_error or "Azure AD token provider not configured") + raise ValueError( + self._config_error or "Azure AD token provider not configured" + ) azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider else: if not self._api_key: - raise ValueError("AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY") + raise ValueError( + "AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY" + ) azure_openai_kwargs["api_key"] = self._api_key self._official_async_clients[cache_key] = AsyncAzureOpenAI(**azure_openai_kwargs) # type: ignore @@ -358,7 +372,7 @@ async def chat_completion( "presence_penalty": presence_penalty, "stream": stream, } - + # gpt-5 deployments may reject temperature/max_tokens in some previews. # We gate this by the *deployment* name because the UI key can differ. if deployment != "gpt-5-mini": @@ -533,7 +547,9 @@ def _worker() -> None: continue content = update.choices[0].delta.content or "" if content: - asyncio.run_coroutine_threadsafe(q.put(content), loop).result() + asyncio.run_coroutine_threadsafe( + q.put(content), loop + ).result() except Exception as e: asyncio.run_coroutine_threadsafe(q.put(e), loop).result() finally: @@ -649,7 +665,11 @@ def get_available_models(self) -> List[str]: """Get list of available models that are properly configured""" out: List[str] = [] for model, config in self.model_configs.items(): - if not config.get("endpoint") or not config.get("deployment") or not config.get("api_version"): + if ( + not config.get("endpoint") + or not config.get("deployment") + or not config.get("api_version") + ): continue out.append(model) return out @@ -663,7 +683,11 @@ def get_available_deployments(self) -> List[str]: out: List[str] = [] seen: set[str] = set() for _model, config in self.model_configs.items(): - if not config.get("endpoint") or not config.get("deployment") or not config.get("api_version"): + if ( + not config.get("endpoint") + or not config.get("deployment") + or not config.get("api_version") + ): continue dep = str(config.get("deployment") or "").strip() if not dep: @@ -695,6 +719,7 @@ def is_configured(self) -> bool: azure_openai_client = AzureOpenAIClient() except Exception as e: # pragma: no cover logger.exception("Failed to initialize AzureOpenAIClient: %s", e) + # Provide a stub that reports not-configured. class _DisabledAzureOpenAIClient: # type: ignore def is_configured(self) -> bool: diff --git a/backend/api/services/cit_db_service.py b/backend/api/services/cit_db_service.py index 85371835..a4e49508 100644 --- a/backend/api/services/cit_db_service.py +++ b/backend/api/services/cit_db_service.py @@ -12,6 +12,7 @@ Methods raise RuntimeError when psycopg2 is not available so callers can surface a 503 with an actionable message. """ + from typing import Any, Dict, List, Optional, Tuple import psycopg2 import psycopg2.extras @@ -132,6 +133,7 @@ def _construct_db_dsn_from_admin(admin_dsn: str, db_name: str) -> str: else: return f"{admin_dsn} dbname={db_name}" + # ----------------------- # Citations Postgres DB service # ----------------------- @@ -149,11 +151,12 @@ def __init__(self): # Low level connection helpers # ----------------------- - # ----------------------- # Generic column ops # ----------------------- - def create_column(self, col: str, col_type: str, table_name: str = "citations") -> None: + def create_column( + self, col: str, col_type: str, table_name: str = "citations" + ) -> None: """ Create column on citations table if it doesn't already exist. col should be the exact column name to use (caller may pass snake_case(col)). @@ -165,11 +168,15 @@ def create_column(self, col: str, col_type: str, table_name: str = "citations") conn = postgres_server.conn cur = conn.cursor() try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" {col_type}') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" {col_type}' + ) except Exception: # fallback for PG versions without IF NOT EXISTS try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" {col_type}') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" {col_type}' + ) except Exception: pass conn.commit() @@ -198,16 +205,21 @@ def update_jsonb_column( conn = postgres_server.conn cur = conn.cursor() try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" JSONB') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" JSONB' + ) except Exception: try: cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" JSONB') except Exception: pass - cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (json.dumps(data), int(citation_id))) + cur.execute( + f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', + (json.dumps(data), int(citation_id)), + ) rows = cur.rowcount conn.commit() - + return rows or 0 except Exception: _safe_rollback(conn) @@ -232,13 +244,18 @@ def update_text_column( conn = postgres_server.conn cur = conn.cursor() try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" TEXT') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" TEXT' + ) except Exception: try: cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" TEXT') except Exception: pass - cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (text_value, int(citation_id))) + cur.execute( + f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', + (text_value, int(citation_id)), + ) rows = cur.rowcount conn.commit() @@ -264,13 +281,20 @@ def update_bool_column( conn = postgres_server.conn cur = conn.cursor() try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" BOOLEAN') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" BOOLEAN' + ) except Exception: try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" BOOLEAN') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" BOOLEAN' + ) except Exception: pass - cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (bool(bool_value), int(citation_id))) + cur.execute( + f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', + (bool(bool_value), int(citation_id)), + ) rows = cur.rowcount conn.commit() return rows or 0 @@ -311,7 +335,9 @@ def get_table_columns(self, table_name: str = "citations") -> List[Dict[str, str if conn: pass - def clear_columns(self, citation_id: int, columns: List[str], table_name: str = "citations") -> int: + def clear_columns( + self, citation_id: int, columns: List[str], table_name: str = "citations" + ) -> int: """Set provided columns to NULL for a citation. Ignores unknown columns.""" table_name = _validate_ident(table_name, kind="table_name") if not columns: @@ -327,7 +353,10 @@ def clear_columns(self, citation_id: int, columns: List[str], table_name: str = conn = postgres_server.conn cur = conn.cursor() set_sql = ", ".join([f'"{c}" = NULL' for c in cols]) - cur.execute(f'UPDATE "{table_name}" SET {set_sql} WHERE id = %s', (int(citation_id),)) + cur.execute( + f'UPDATE "{table_name}" SET {set_sql} WHERE id = %s', + (int(citation_id),), + ) rows = cur.rowcount conn.commit() return rows or 0 @@ -338,7 +367,9 @@ def clear_columns(self, citation_id: int, columns: List[str], table_name: str = if conn: pass - def clear_columns_by_prefix(self, citation_id: int, prefixes: List[str], table_name: str = "citations") -> int: + def clear_columns_by_prefix( + self, citation_id: int, prefixes: List[str], table_name: str = "citations" + ) -> int: """Set all columns matching any prefix to NULL for a citation.""" prefixes = [p for p in (prefixes or []) if isinstance(p, str) and p] if not prefixes: @@ -374,10 +405,14 @@ def copy_jsonb_if_empty( cur = conn.cursor() # Ensure destination column exists as JSONB try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{dst_col}" JSONB') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{dst_col}" JSONB' + ) except Exception: try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{dst_col}" JSONB') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN "{dst_col}" JSONB' + ) except Exception: pass @@ -420,8 +455,6 @@ def dump_citations_csv(self, table_name: str = "citations") -> bytes: ) csv_text = buf.getvalue() - - return csv_text.encode("utf-8") except Exception: _safe_rollback(conn) @@ -473,7 +506,9 @@ def dump_citations_csv_filtered(self, table_name: str = "citations") -> bytes: conn = postgres_server.conn cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) select_cols = base_cols + jsonb_cols - select_sql = ", ".join([f'"{c}"' for c in select_cols]) if select_cols else "*" + select_sql = ( + ", ".join([f'"{c}"' for c in select_cols]) if select_cols else "*" + ) cur.execute(f'SELECT {select_sql} FROM "{table_name}" ORDER BY id') rows = cur.fetchall() or [] @@ -514,7 +549,9 @@ def _flatten_keys_for(col: str) -> List[str]: # Fallback return [f"{col}__json"] - json_flat_cols: Dict[str, List[str]] = {c: _flatten_keys_for(c) for c in jsonb_cols} + json_flat_cols: Dict[str, List[str]] = { + c: _flatten_keys_for(c) for c in jsonb_cols + } for c in jsonb_cols: out_cols.extend(json_flat_cols[c]) @@ -573,9 +610,15 @@ def _as_str(v: Any) -> str: out[f"{c}__confidence"] = _as_str(parsed.get("confidence")) out[f"{c}__explanation"] = _as_str(parsed.get("explanation")) - out[f"{c}__evidence_sentences"] = _as_str(parsed.get("evidence_sentences")) - out[f"{c}__evidence_tables"] = _as_str(parsed.get("evidence_tables")) - out[f"{c}__evidence_figures"] = _as_str(parsed.get("evidence_figures")) + out[f"{c}__evidence_sentences"] = _as_str( + parsed.get("evidence_sentences") + ) + out[f"{c}__evidence_tables"] = _as_str( + parsed.get("evidence_tables") + ) + out[f"{c}__evidence_figures"] = _as_str( + parsed.get("evidence_figures") + ) out[f"{c}__autofilled"] = _as_str(parsed.get("autofilled")) out[f"{c}__source"] = _as_str(parsed.get("source")) out[f"{c}__timestamp"] = _as_str(parsed.get("timestamp")) @@ -591,7 +634,9 @@ def _as_str(v: Any) -> str: if conn: pass - def get_citation_by_id(self, citation_id: int, table_name: str = "citations") -> Optional[Dict[str, Any]]: + def get_citation_by_id( + self, citation_id: int, table_name: str = "citations" + ) -> Optional[Dict[str, Any]]: """ Return a dict mapping column -> value for the citation row, or None. """ @@ -655,10 +700,14 @@ def get_citations_by_ids( select_sql = "*" if fields: try: - existing_cols = {c.get("column_name") for c in self.get_table_columns(table_name)} + existing_cols = { + c.get("column_name") for c in self.get_table_columns(table_name) + } except Exception: existing_cols = set() - safe_fields = [f for f in fields if isinstance(f, str) and f in existing_cols] + safe_fields = [ + f for f in fields if isinstance(f, str) and f in existing_cols + ] if safe_fields: select_sql = ", ".join([f'"{c}"' for c in safe_fields]) @@ -679,7 +728,9 @@ def get_citations_by_ids( if conn: pass - def backfill_human_decisions(self, criteria_parsed: Dict[str, Any], table_name: str = "citations") -> int: + def backfill_human_decisions( + self, criteria_parsed: Dict[str, Any], table_name: str = "citations" + ) -> int: """Recompute and persist human_l1_decision / human_l2_decision for all rows. This is used to ensure decision columns are never stale when the UI fetches @@ -693,8 +744,16 @@ def backfill_human_decisions(self, criteria_parsed: Dict[str, Any], table_name: table_name = _validate_ident(table_name, kind="table_name") cp = criteria_parsed or {} - l1_qs = (cp.get("l1") or {}).get("questions") if isinstance(cp.get("l1"), dict) else None - l2_qs = (cp.get("l2") or {}).get("questions") if isinstance(cp.get("l2"), dict) else None + l1_qs = ( + (cp.get("l1") or {}).get("questions") + if isinstance(cp.get("l1"), dict) + else None + ) + l2_qs = ( + (cp.get("l2") or {}).get("questions") + if isinstance(cp.get("l2"), dict) + else None + ) l1_qs = l1_qs if isinstance(l1_qs, list) else [] l2_qs = l2_qs if isinstance(l2_qs, list) else [] @@ -728,7 +787,9 @@ def _human_col(q: str) -> str: # If we try to SELECT a non-existent human_* column, the query fails and the # caller silently skips the backfill, leaving stale decision columns. try: - existing_cols = {c.get("column_name") for c in self.get_table_columns(table_name)} + existing_cols = { + c.get("column_name") for c in self.get_table_columns(table_name) + } except Exception: existing_cols = set() @@ -773,7 +834,9 @@ def _compute(step_qs: List[str], row: Dict[str, Any]) -> str: hobj = _parse_jsonb(hval) selected = hobj.get("selected") # Treat empty/whitespace as unanswered (UI shows "-- select --") - if selected is None or (isinstance(selected, str) and selected.strip() == ""): + if selected is None or ( + isinstance(selected, str) and selected.strip() == "" + ): return "undecided" if "exclude" in str(selected).lower(): return "exclude" @@ -810,7 +873,9 @@ def _compute(step_qs: List[str], row: Dict[str, Any]) -> str: if conn: pass - def list_citation_ids(self, filter_step=None, table_name: str = "citations") -> List[int]: + def list_citation_ids( + self, filter_step=None, table_name: str = "citations" + ) -> List[int]: """ Return list of integer primary keys (id) from citations table ordered by id. """ @@ -829,7 +894,9 @@ def list_citation_ids(self, filter_step=None, table_name: str = "citations") -> # Validation rule (B1/B2): Full-text list is driven by the human L1 decision. # Do NOT use l1_screen/l2_screen booleans. try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l1_decision" TEXT') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l1_decision" TEXT' + ) except Exception: pass cur.execute( @@ -838,7 +905,9 @@ def list_citation_ids(self, filter_step=None, table_name: str = "citations") -> elif step == "l2": # Validation rule (B1/B2): Extract list is driven by the human L2 decision. try: - cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l2_decision" TEXT') + cur.execute( + f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l2_decision" TEXT' + ) except Exception: pass cur.execute( @@ -866,7 +935,9 @@ def list_fulltext_urls(self, table_name: str = "citations") -> List[str]: try: conn = postgres_server.conn cur = conn.cursor() - cur.execute(f'SELECT fulltext_url FROM "{table_name}" WHERE fulltext_url IS NOT NULL') + cur.execute( + f'SELECT fulltext_url FROM "{table_name}" WHERE fulltext_url IS NOT NULL' + ) rows = cur.fetchall() return [r[0] for r in rows if r and r[0]] @@ -924,7 +995,9 @@ def attach_fulltext( # ----------------------- # Column get/set helpers # ----------------------- - def get_column_value(self, citation_id: int, column: str, table_name: str = "citations") -> Any: + def get_column_value( + self, citation_id: int, column: str, table_name: str = "citations" + ) -> Any: """ Return the value stored in `column` for the citation row (or None). """ @@ -936,7 +1009,9 @@ def get_column_value(self, citation_id: int, column: str, table_name: str = "cit cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) except Exception: cur = conn.cursor() - cur.execute(f'SELECT "{column}" FROM "{table_name}" WHERE id = %s', (citation_id,)) + cur.execute( + f'SELECT "{column}" FROM "{table_name}" WHERE id = %s', (citation_id,) + ) row = cur.fetchone() if not row: return None @@ -954,13 +1029,20 @@ def get_column_value(self, citation_id: int, column: str, table_name: str = "cit if conn: pass - def set_column_value(self, citation_id: int, column: str, value: Any, table_name: str = "citations") -> int: + def set_column_value( + self, citation_id: int, column: str, value: Any, table_name: str = "citations" + ) -> int: """ Generic setter for a citation row column. Will create a TEXT column if it doesn't exist. """ # For simplicity, create a TEXT column. Callers that need JSONB should use update_jsonb_column. self.create_column(column, "TEXT", table_name=table_name) - return self.update_text_column(citation_id, column, value if value is not None else None, table_name=table_name) + return self.update_text_column( + citation_id, + column, + value if value is not None else None, + table_name=table_name, + ) # ----------------------- # Per-upload table lifecycle helpers @@ -1019,7 +1101,12 @@ def create_table_and_insert_sync( inserted = 0 if rows: safe_cols = [snake_case(c) for c in columns] - insert_cols = [f'"{c}"' for c in safe_cols] + ['"cit_id"', '"fulltext_url"', '"fulltext"', '"fulltext_md5"'] + insert_cols = [f'"{c}"' for c in safe_cols] + [ + '"cit_id"', + '"fulltext_url"', + '"fulltext"', + '"fulltext_md5"', + ] placeholders = ", ".join(["%s"] * len(insert_cols)) insert_sql = f'INSERT INTO "{table_name}" ({", ".join(insert_cols)}) VALUES ({placeholders})' @@ -1034,11 +1121,26 @@ def _row_has_data(row: dict) -> bool: values = [] for r in filtered_rows: - row_vals = [r.get(orig_col) if r.get(orig_col) is not None else None for orig_col in columns] - row_vals.append(r.get("cit_id") if r.get("cit_id") is not None else None) - row_vals.append(r.get("fulltext_url") if r.get("fulltext_url") is not None else None) - row_vals.append(r.get("fulltext") if r.get("fulltext") is not None else None) - row_vals.append(r.get("fulltext_md5") if r.get("fulltext_md5") is not None else None) + row_vals = [ + r.get(orig_col) if r.get(orig_col) is not None else None + for orig_col in columns + ] + row_vals.append( + r.get("cit_id") if r.get("cit_id") is not None else None + ) + row_vals.append( + r.get("fulltext_url") + if r.get("fulltext_url") is not None + else None + ) + row_vals.append( + r.get("fulltext") if r.get("fulltext") is not None else None + ) + row_vals.append( + r.get("fulltext_md5") + if r.get("fulltext_md5") is not None + else None + ) values.append(tuple(row_vals)) if values: @@ -1058,7 +1160,9 @@ def _row_has_data(row: dict) -> bool: # NOTE: legacy per-database helpers (drop_database, create_db_and_table_sync) were # intentionally removed in favor of per-upload tables in a shared database. - def load_include_columns_from_criteria(self, sr_doc: Optional[Dict[str, Any]] = None) -> List[str]: + def load_include_columns_from_criteria( + self, sr_doc: Optional[Dict[str, Any]] = None + ) -> List[str]: """ Load the 'include' list for L1 screening. Mirrors logic previously embedded in the citations router but kept here @@ -1080,7 +1184,13 @@ def load_include_columns_from_criteria(self, sr_doc: Optional[Dict[str, Any]] = pass # 2) fallback to project file - cfg_path = os.path.join(os.path.dirname(__file__), "..", "sr_setup", "configs", "criteria_config_measles_updated.yaml") + cfg_path = os.path.join( + os.path.dirname(__file__), + "..", + "sr_setup", + "configs", + "criteria_config_measles_updated.yaml", + ) cfg_path = os.path.normpath(cfg_path) try: import yaml @@ -1096,7 +1206,9 @@ def load_include_columns_from_criteria(self, sr_doc: Optional[Dict[str, Any]] = except Exception: return [] - def build_combined_citation_from_row(self, row: Dict[str, Any], include_columns: List[str]) -> str: + def build_combined_citation_from_row( + self, row: Dict[str, Any], include_columns: List[str] + ) -> str: parts: List[str] = [] if not row: return "" diff --git a/backend/api/services/citation_search/europePMC_citation_collection.py b/backend/api/services/citation_search/europePMC_citation_collection.py index f9842a21..8b62d15b 100644 --- a/backend/api/services/citation_search/europePMC_citation_collection.py +++ b/backend/api/services/citation_search/europePMC_citation_collection.py @@ -5,22 +5,24 @@ from requests.adapters import HTTPAdapter from urllib3.util.retry import Retry -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger("grep-exp-pubmed-citations") + class EuropePMCCitationCollector: def __init__(self): self.base_url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search" self.citations = [] - + # Set up retry-capable session self.session = requests.Session() retry_strategy = Retry( total=3, backoff_factor=0.3, status_forcelist=[429, 500, 502, 503, 504], - allowed_methods=["GET", "HEAD"] + allowed_methods=["GET", "HEAD"], ) adapter = HTTPAdapter(max_retries=retry_strategy) self.session.mount("http://", adapter) @@ -33,9 +35,9 @@ def fetch_data(self, search_term, cursor_mark="*"): "query": search_term, "pageSize": 100, "format": "json", - "cursorMark": cursor_mark + "cursorMark": cursor_mark, } - + response = self.session.get(self.base_url, params=params, timeout=30) response.raise_for_status() return response.json() @@ -57,115 +59,140 @@ def extract_citation_data(self, entry): year = int(entry.get("pubYear")) except (ValueError, TypeError): year = None - + # Extract keywords safely keywords = [] keyword_list = entry.get("keywordList") if isinstance(keyword_list, dict) and "keyword" in keyword_list: - keywords = keyword_list["keyword"] if isinstance(keyword_list["keyword"], list) else [keyword_list["keyword"]] - + keywords = ( + keyword_list["keyword"] + if isinstance(keyword_list["keyword"], list) + else [keyword_list["keyword"]] + ) + # Extract publication types safely pub_types = [] pub_type_list = entry.get("pubTypeList") if isinstance(pub_type_list, dict) and "pubType" in pub_type_list: - pub_types = pub_type_list["pubType"] if isinstance(pub_type_list["pubType"], list) else [pub_type_list["pubType"]] - + pub_types = ( + pub_type_list["pubType"] + if isinstance(pub_type_list["pubType"], list) + else [pub_type_list["pubType"]] + ) + # Extract full text URLs safely full_text_urls = [] full_text_url_list = entry.get("fullTextUrlList") - if isinstance(full_text_url_list, dict) and "fullTextUrl" in full_text_url_list: + if ( + isinstance(full_text_url_list, dict) + and "fullTextUrl" in full_text_url_list + ): urls = full_text_url_list["fullTextUrl"] if isinstance(urls, list): - full_text_urls = [url.get("url") for url in urls if isinstance(url, dict) and url.get("url")] + full_text_urls = [ + url.get("url") + for url in urls + if isinstance(url, dict) and url.get("url") + ] elif isinstance(urls, dict): if urls.get("url"): full_text_urls = [urls["url"]] - + citation_data = { - 'source': 'Europe PMC', - 'id': entry.get("id"), - 'pmid': entry.get("pmid"), - 'pmcid': entry.get("pmcid"), - 'title': entry.get("title"), - 'authors': entry.get("authorList", {}).get("author", []) if isinstance(entry.get("authorList"), dict) else [], - 'author_string': entry.get("authorString"), - 'abstract': entry.get("abstractText"), - 'keywords': keywords, - 'doi': entry.get("doi"), - 'journal': entry.get("journalTitle"), - 'issn': entry.get("journalIssn"), - 'volume': entry.get("journalVolume"), - 'issue': entry.get("issue"), - 'pages': entry.get("pageInfo"), - 'publication_types': pub_types, - 'date': entry.get("firstPublicationDate"), - 'year': year, - 'language': entry.get("language"), - 'full_text_urls': full_text_urls, - 'is_open_access': entry.get("isOpenAccess"), - 'has_pdf': entry.get("hasPDF"), - 'pdf_attempted': False, - 'pdf_success': False, - 'pdf_url': None, - 'pdf_error': None + "source": "Europe PMC", + "id": entry.get("id"), + "pmid": entry.get("pmid"), + "pmcid": entry.get("pmcid"), + "title": entry.get("title"), + "authors": ( + entry.get("authorList", {}).get("author", []) + if isinstance(entry.get("authorList"), dict) + else [] + ), + "author_string": entry.get("authorString"), + "abstract": entry.get("abstractText"), + "keywords": keywords, + "doi": entry.get("doi"), + "journal": entry.get("journalTitle"), + "issn": entry.get("journalIssn"), + "volume": entry.get("journalVolume"), + "issue": entry.get("issue"), + "pages": entry.get("pageInfo"), + "publication_types": pub_types, + "date": entry.get("firstPublicationDate"), + "year": year, + "language": entry.get("language"), + "full_text_urls": full_text_urls, + "is_open_access": entry.get("isOpenAccess"), + "has_pdf": entry.get("hasPDF"), + "pdf_attempted": False, + "pdf_success": False, + "pdf_url": None, + "pdf_error": None, } - + return citation_data - + except Exception as e: - logger.error(f"Error extracting citation data from Europe PMC entry {entry.get('id', 'Unknown ID')}: {e}") + logger.error( + f"Error extracting citation data from Europe PMC entry {entry.get('id', 'Unknown ID')}: {e}" + ) return None def collect_citations(self, search_term, max_articles=1000): """Main method to collect citations from Europe PMC""" logger.info(f"Starting Europe PMC citation collection") logger.info(f"Search query: {search_term[:100]}...") - + try: cursor_mark = "*" total_processed = 0 next_cursor_mark = None self.citations = [] - + # Continue until we reach max_articles or no more results while cursor_mark != next_cursor_mark and total_processed < max_articles: # Fetch batch response_data = self.fetch_data(search_term, cursor_mark) entries = response_data.get("resultList", {}).get("result", []) - + if not entries: logger.info("No more results found or API error") break - + # Update cursor for next request next_cursor_mark = response_data.get("nextCursorMark") - + # Log progress batch_size = len(entries) logger.info(f"Retrieved batch of {batch_size} articles. Processing...") - + # Process entries for i, entry in enumerate(entries): if total_processed >= max_articles: - logger.info(f"Reached maximum number of articles ({max_articles})") + logger.info( + f"Reached maximum number of articles ({max_articles})" + ) break - + try: article_num = total_processed + i + 1 - logger.info(f"Processing article {article_num}: {entry.get('title', 'Unknown title')}") - + logger.info( + f"Processing article {article_num}: {entry.get('title', 'Unknown title')}" + ) + citation_data = self.extract_citation_data(entry) if citation_data: self.citations.append(citation_data) - + except Exception as e: logger.error(f"Error processing entry: {e}") continue - + # Update for next iteration total_processed += len(entries) logger.info(f"Processed {total_processed} articles so far") - + # Update cursor for next request if next_cursor_mark and next_cursor_mark != cursor_mark: cursor_mark = next_cursor_mark @@ -174,10 +201,12 @@ def collect_citations(self, search_term, max_articles=1000): time.sleep(2) else: break - - logger.info(f"Completed processing. Collected {len(self.citations)} citations.") + + logger.info( + f"Completed processing. Collected {len(self.citations)} citations." + ) return self.citations - + except Exception as e: logger.error(f"Unexpected error in citation collection: {e}") - return [] \ No newline at end of file + return [] diff --git a/backend/api/services/citation_search/pubmed_citation_collection.py b/backend/api/services/citation_search/pubmed_citation_collection.py index 38449de1..ef7bf4ed 100644 --- a/backend/api/services/citation_search/pubmed_citation_collection.py +++ b/backend/api/services/citation_search/pubmed_citation_collection.py @@ -5,6 +5,7 @@ import logging import string import time + # from unidecode import unidecode from cleantext import clean from datetime import datetime @@ -14,240 +15,259 @@ ENTREZ_EMAIL = settings.ENTREZ_EMAIL ENTREZ_API_KEY = settings.ENTREZ_API_KEY # Configure logging -logging.basicConfig(level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger("grep-exp-pubmed-citations") class PubMedCitationCollector: - def __init__(self): - self.citations = [] - - def extract_citation_data(self, paper): - """Extract citation metadata from a PubMed article""" - try: - # Basic information - title = paper['MedlineCitation']['Article'].get('ArticleTitle', '') - if title: - title = clean(title.rstrip("."), lower=False, fix_unicode=True) - - pmid = str(paper['MedlineCitation']['PMID']) - - # Authors - authors = [] - if 'AuthorList' in paper['MedlineCitation']['Article']: - for author in paper['MedlineCitation']['Article']['AuthorList']: - last_name = author.get('LastName', '') - fore_name = author.get('ForeName', '') - full_name = ' '.join(filter(None, [fore_name, last_name])) - if full_name: - authors.append(full_name) - - # Keywords - keywords = [] - if 'KeywordList' in paper['MedlineCitation'] and paper['MedlineCitation']['KeywordList']: - keywords = [str(kw) for kw in paper['MedlineCitation']['KeywordList'][0]] - - # DOI - doi = None - for e_location in paper['MedlineCitation']['Article'].get('ELocationID', []): - if e_location.attributes['EIdType'] == 'doi': - doi = str(e_location) - break - - # If no DOI found in article, try CrossRef - if doi is None: - doi = self.find_doi_crossref(title) - - # Date - article_date = paper['MedlineCitation']['Article'].get('ArticleDate', []) - date = None - if article_date: - date = f"{article_date[0]['Year']}/{article_date[0]['Month']}/{article_date[0]['Day']}" - - # Abstract - abstract_texts = paper['MedlineCitation']['Article'].get('Abstract', {}).get('AbstractText', []) - abstract = ' '.join([str(text) for text in abstract_texts]) - - # Journal info - journal_info = paper['MedlineCitation']['Article']['Journal'] - - # Publication Type - pt = [str(pt) for pt in paper['MedlineCitation']['Article']['PublicationTypeList']] - - # Compile citation data - citation_data = { - 'source': 'PubMed', - 'id': pmid, - 'title': title, - 'authors': ', '.join(authors), - 'abstract': abstract, - 'keywords': ', '.join(keywords), - 'doi': doi, - 'pmid': pmid, - 'journal': journal_info['Title'], - 'issn': ''.join(journal_info.get('ISSN', [])), - 'publication_types': ', '.join(pt), - 'date': date, - 'year': paper['MedlineCitation']['Article']['Journal']['JournalIssue']['PubDate'].get('Year'), - 'volume': journal_info['JournalIssue'].get('Volume'), - 'issue': journal_info['JournalIssue'].get('Issue'), - 'language': ', '.join(paper['MedlineCitation']['Article'].get('Language', [])), - 'pdf_attempted': False, - 'pdf_success': False, - 'pdf_url': None, - 'pdf_error': None - } - - return citation_data - - except Exception as e: - logger.error(f"Error extracting citation data: {e}") - return None - - def collect_citations(self, search_term, mindate=None, maxdate=None, max_articles=1000): - """Main method to collect citations from PubMed""" - - Entrez.email = ENTREZ_EMAIL - Entrez.api_key = ENTREZ_API_KEY - - - logger.info(f"Starting PubMed citation collection") - logger.info(f"Search term: {search_term[:100]}...") - if mindate and maxdate: - logger.info(f"Date range: {mindate} to {maxdate}") - elif mindate: - logger.info(f"Date range: {mindate} to present") - elif maxdate: - logger.info(f"Date range: earliest to {maxdate}") - else: - logger.info("No date range specified") - - search_params = { - 'db': 'pubmed', - 'sort': 'pub_date', - 'retmode': 'xml', - 'retmax': max_articles, - 'term': search_term, - } - - if mindate: - search_params['mindate'] = mindate - if maxdate: - search_params['maxdate'] = maxdate - - try: - handle = Entrez.esearch(**search_params) - pmid_list = Entrez.read(handle) - pmid_list = pmid_list.get('IdList', []) - except Exception as e: - logger.error(f"Error searching PubMed: {e}") - pmid_list = [] - - # Search PubMed - # pmid_list = search_pubmed(search_term, mindate, maxdate, max_articles) - - if not pmid_list: - logger.warning("No PMIDs found") - return [] - - total_pmids = len(pmid_list) - logger.info(f"Found {total_pmids} PMIDs") - - # Limit to max_articles - if max_articles and total_pmids > max_articles: - logger.info(f"Limiting to {max_articles} articles") - pmid_list = pmid_list[:max_articles] - - # Fetch article details - logger.info(f"Fetching details for {len(pmid_list)} articles") - papers = self.fetch_details_batch(pmid_list) - - # Extract citation data - self.citations = [] - for i, paper in enumerate(papers): - try: - # logger.info(f"Processing article {i+1}/{len(papers)}") - citation_data = self.extract_citation_data(paper) - if citation_data: - self.citations.append(citation_data) - - except Exception as e: - logger.error(f"Error processing paper {i}: {e}") - continue - - logger.info(f"Successfully collected {len(self.citations)} citations") - return self.citations - - def find_doi_crossref(self, title): - """Find DOI using CrossRef API""" - if not title: - return None - - url = "https://api.crossref.org/works" - params = { - 'query.title': title, - 'rows': 1 - } - - try: - response = requests.get(url, params=params, timeout=10) - - if response.status_code == 200: - data = response.json() - items = data['message']['items'] - - if items: - first_item = items[0] - result_title = first_item['title'][0] - - if self.clean_title(result_title) == self.clean_title(title): - return first_item['DOI'] - - return None - except Exception as e: - logger.warning(f"Error finding DOI for title: {e}") - return None - - def fetch_details_batch(self, id_list, batch_size=50): - """Fetch details for PubMed IDs in batches""" - if not id_list: - return [] - - Entrez.email = ENTREZ_EMAIL - Entrez.api_key = ENTREZ_API_KEY - - all_results = [] - print(id_list) - print(f"Number of IDs: {len(id_list)}") - - # id_list = list(id_list) - - for i in range(0, len(id_list), batch_size): - batch_ids = id_list[i:i+batch_size] - ids = ','.join(batch_ids) - + def __init__(self): + self.citations = [] + + def extract_citation_data(self, paper): + """Extract citation metadata from a PubMed article""" + try: + # Basic information + title = paper["MedlineCitation"]["Article"].get("ArticleTitle", "") + if title: + title = clean(title.rstrip("."), lower=False, fix_unicode=True) + + pmid = str(paper["MedlineCitation"]["PMID"]) + + # Authors + authors = [] + if "AuthorList" in paper["MedlineCitation"]["Article"]: + for author in paper["MedlineCitation"]["Article"]["AuthorList"]: + last_name = author.get("LastName", "") + fore_name = author.get("ForeName", "") + full_name = " ".join(filter(None, [fore_name, last_name])) + if full_name: + authors.append(full_name) + + # Keywords + keywords = [] + if ( + "KeywordList" in paper["MedlineCitation"] + and paper["MedlineCitation"]["KeywordList"] + ): + keywords = [ + str(kw) for kw in paper["MedlineCitation"]["KeywordList"][0] + ] + + # DOI + doi = None + for e_location in paper["MedlineCitation"]["Article"].get( + "ELocationID", [] + ): + if e_location.attributes["EIdType"] == "doi": + doi = str(e_location) + break + + # If no DOI found in article, try CrossRef + if doi is None: + doi = self.find_doi_crossref(title) + + # Date + article_date = paper["MedlineCitation"]["Article"].get("ArticleDate", []) + date = None + if article_date: + date = f"{article_date[0]['Year']}/{article_date[0]['Month']}/{article_date[0]['Day']}" + + # Abstract + abstract_texts = ( + paper["MedlineCitation"]["Article"] + .get("Abstract", {}) + .get("AbstractText", []) + ) + abstract = " ".join([str(text) for text in abstract_texts]) + + # Journal info + journal_info = paper["MedlineCitation"]["Article"]["Journal"] + + # Publication Type + pt = [ + str(pt) + for pt in paper["MedlineCitation"]["Article"]["PublicationTypeList"] + ] + + # Compile citation data + citation_data = { + "source": "PubMed", + "id": pmid, + "title": title, + "authors": ", ".join(authors), + "abstract": abstract, + "keywords": ", ".join(keywords), + "doi": doi, + "pmid": pmid, + "journal": journal_info["Title"], + "issn": "".join(journal_info.get("ISSN", [])), + "publication_types": ", ".join(pt), + "date": date, + "year": paper["MedlineCitation"]["Article"]["Journal"]["JournalIssue"][ + "PubDate" + ].get("Year"), + "volume": journal_info["JournalIssue"].get("Volume"), + "issue": journal_info["JournalIssue"].get("Issue"), + "language": ", ".join( + paper["MedlineCitation"]["Article"].get("Language", []) + ), + "pdf_attempted": False, + "pdf_success": False, + "pdf_url": None, + "pdf_error": None, + } + + return citation_data + + except Exception as e: + logger.error(f"Error extracting citation data: {e}") + return None + + def collect_citations( + self, search_term, mindate=None, maxdate=None, max_articles=1000 + ): + """Main method to collect citations from PubMed""" + + Entrez.email = ENTREZ_EMAIL + Entrez.api_key = ENTREZ_API_KEY + + logger.info(f"Starting PubMed citation collection") + logger.info(f"Search term: {search_term[:100]}...") + if mindate and maxdate: + logger.info(f"Date range: {mindate} to {maxdate}") + elif mindate: + logger.info(f"Date range: {mindate} to present") + elif maxdate: + logger.info(f"Date range: earliest to {maxdate}") + else: + logger.info("No date range specified") + + search_params = { + "db": "pubmed", + "sort": "pub_date", + "retmode": "xml", + "retmax": max_articles, + "term": search_term, + } + + if mindate: + search_params["mindate"] = mindate + if maxdate: + search_params["maxdate"] = maxdate + + try: + handle = Entrez.esearch(**search_params) + pmid_list = Entrez.read(handle) + pmid_list = pmid_list.get("IdList", []) + except Exception as e: + logger.error(f"Error searching PubMed: {e}") + pmid_list = [] + + # Search PubMed + # pmid_list = search_pubmed(search_term, mindate, maxdate, max_articles) + + if not pmid_list: + logger.warning("No PMIDs found") + return [] + + total_pmids = len(pmid_list) + logger.info(f"Found {total_pmids} PMIDs") + + # Limit to max_articles + if max_articles and total_pmids > max_articles: + logger.info(f"Limiting to {max_articles} articles") + pmid_list = pmid_list[:max_articles] + + # Fetch article details + logger.info(f"Fetching details for {len(pmid_list)} articles") + papers = self.fetch_details_batch(pmid_list) + + # Extract citation data + self.citations = [] + for i, paper in enumerate(papers): + try: + # logger.info(f"Processing article {i+1}/{len(papers)}") + citation_data = self.extract_citation_data(paper) + if citation_data: + self.citations.append(citation_data) + + except Exception as e: + logger.error(f"Error processing paper {i}: {e}") + continue + + logger.info(f"Successfully collected {len(self.citations)} citations") + return self.citations + + def find_doi_crossref(self, title): + """Find DOI using CrossRef API""" + if not title: + return None + + url = "https://api.crossref.org/works" + params = {"query.title": title, "rows": 1} + try: - handle = Entrez.efetch(db='pubmed', retmode='xml', id=ids) - results = Entrez.read(handle) - - if "PubmedArticle" in results: - all_results.extend(results["PubmedArticle"]) - - # Add delay to be nice to API - if i + batch_size < len(id_list): - time.sleep(1) - + response = requests.get(url, params=params, timeout=10) + + if response.status_code == 200: + data = response.json() + items = data["message"]["items"] + + if items: + first_item = items[0] + result_title = first_item["title"][0] + + if self.clean_title(result_title) == self.clean_title(title): + return first_item["DOI"] + + return None except Exception as e: - logger.error(f"Error fetching details for batch {i//batch_size + 1}: {e}") - continue - - return all_results - - def clean_title(self, title): - """Clean and normalize a title""" - if not title: - return "" - title = title.lower() - title = title.translate(str.maketrans('', '', string.punctuation)) - return title \ No newline at end of file + logger.warning(f"Error finding DOI for title: {e}") + return None + + def fetch_details_batch(self, id_list, batch_size=50): + """Fetch details for PubMed IDs in batches""" + if not id_list: + return [] + + Entrez.email = ENTREZ_EMAIL + Entrez.api_key = ENTREZ_API_KEY + + all_results = [] + print(id_list) + print(f"Number of IDs: {len(id_list)}") + + # id_list = list(id_list) + + for i in range(0, len(id_list), batch_size): + batch_ids = id_list[i : i + batch_size] + ids = ",".join(batch_ids) + + try: + handle = Entrez.efetch(db="pubmed", retmode="xml", id=ids) + results = Entrez.read(handle) + + if "PubmedArticle" in results: + all_results.extend(results["PubmedArticle"]) + + # Add delay to be nice to API + if i + batch_size < len(id_list): + time.sleep(1) + + except Exception as e: + logger.error( + f"Error fetching details for batch {i//batch_size + 1}: {e}" + ) + continue + + return all_results + + def clean_title(self, title): + """Clean and normalize a title""" + if not title: + return "" + title = title.lower() + title = title.translate(str.maketrans("", "", string.punctuation)) + return title diff --git a/backend/api/services/citation_search/scopus_citation_collection.py b/backend/api/services/citation_search/scopus_citation_collection.py index a5fba9ec..c52f8151 100644 --- a/backend/api/services/citation_search/scopus_citation_collection.py +++ b/backend/api/services/citation_search/scopus_citation_collection.py @@ -1,5 +1,6 @@ import requests + class ScopusDataProcessor: def __init__(self, api_key, base_url): self._api_key = api_key @@ -8,11 +9,13 @@ def __init__(self, api_key, base_url): self.data = [] def fetch_data(self, start, search_term): - params = { "query": search_term, - "apiKey": self._api_key, - "count": 25, - "start": start, - "view": "COMPLETE"} + params = { + "query": search_term, + "apiKey": self._api_key, + "count": 25, + "start": start, + "view": "COMPLETE", + } response = requests.get(self._URL, params=params) if response.status_code == 200: return response.json() @@ -23,7 +26,9 @@ def process_entry(self, entry): pdf_link_call = self.get_open_access_link(entry.get("prism:doi")) return { "Refid": entry.get("dc:identifier"), - "Author": ", ".join([author["authname"] for author in entry.get("author", [])]), + "Author": ", ".join( + [author["authname"] for author in entry.get("author", [])] + ), "Title": entry.get("dc:title"), "Abstract": entry.get("dc:description"), "Accession Number": entry.get("eid"), @@ -38,7 +43,11 @@ def process_entry(self, entry): "ISSN": entry.get("prism:issn"), "Issue": entry.get("prism:issueIdentifier"), "Journal": entry.get("prism:publicationName"), - "Keywords": entry.get("authkeywords").replace(" |", ",") if entry.get("authkeywords") != None else None, + "Keywords": ( + entry.get("authkeywords").replace(" |", ",") + if entry.get("authkeywords") != None + else None + ), "Language": None, "Notes": None, "Original Publication": None, @@ -61,13 +70,15 @@ def process_entry(self, entry): "Is this article primary research?": None, "Is this article on the human population?": None, "Is the main focus of this study about measles disease?": None, - "Measles aka rubeola, Morbilli, red measles, English measles": None + "Measles aka rubeola, Morbilli, red measles, English measles": None, } def consume_api(self, search_term, delay=1): res_data = self.fetch_data(0) - total_results = int(res_data.get("search-results", {}).get("opensearch:totalResults")) + total_results = int( + res_data.get("search-results", {}).get("opensearch:totalResults") + ) for start in range(0, total_results, 25): res_data = self.fetch_data(start, search_term) entries = res_data.get("search-results", {}).get("entry", []) @@ -83,7 +94,7 @@ def get_open_access_link(self, doi): if response and response.status_code == 200: data = response.json() - link = data.get('url', None) - if link and 'pdf' in link.lower(): + link = data.get("url", None) + if link and "pdf" in link.lower(): return link - return None \ No newline at end of file + return None diff --git a/backend/api/services/document_service.py b/backend/api/services/document_service.py index 571ed429..fca2b7be 100644 --- a/backend/api/services/document_service.py +++ b/backend/api/services/document_service.py @@ -36,7 +36,9 @@ async def convert_document_to_markdown( result["processor_used"] = "azure_doc_intelligence" return result - async def get_raw_analysis_result(self, conversion_id: str) -> Optional[Dict[str, Any]]: + async def get_raw_analysis_result( + self, conversion_id: str + ) -> Optional[Dict[str, Any]]: if not azure_docint_client: return None return await azure_docint_client.get_raw_analysis_result(conversion_id) diff --git a/backend/api/services/grobid_service.py b/backend/api/services/grobid_service.py index 6bed76ff..94d46e06 100644 --- a/backend/api/services/grobid_service.py +++ b/backend/api/services/grobid_service.py @@ -16,7 +16,7 @@ "formula": "rgba(255, 165, 0, 1)", # Orange "figure": "rgba(165, 42, 42, 1)", # Brown "title": "rgba(255, 0, 0, 1)", # Red - "affiliation": "rgba(255, 165, 0, 1)" # red-orengi + "affiliation": "rgba(255, 165, 0, 1)", # red-orengi } @@ -27,11 +27,13 @@ def get_color(name, param): return color + def exclude_tags(tag): ret = False - ret |= tag.name != 'abstract' # exclude the abstract + ret |= tag.name != "abstract" # exclude the abstract return ret + class GrobidService: def __init__(self): @@ -77,16 +79,20 @@ def is_available(self) -> bool: async def process_structure(self, input_path) -> (dict, list): if not self.grobid_client: - raise RuntimeError("GROBID client is not available (service not configured or down)") - pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument", - input_path, - consolidate_header=True, - consolidate_citations=False, - segment_sentences=True, - tei_coordinates=True, - include_raw_citations=False, - include_raw_affiliations=False, - generateIDs=True) + raise RuntimeError( + "GROBID client is not available (service not configured or down)" + ) + pdf_file, status, text = self.grobid_client.process_pdf( + "processFulltextDocument", + input_path, + consolidate_header=True, + consolidate_citations=False, + segment_sentences=True, + tei_coordinates=True, + include_raw_citations=False, + include_raw_affiliations=False, + generateIDs=True, + ) if status != 200: return @@ -99,23 +105,29 @@ async def process_structure(self, input_path) -> (dict, list): @staticmethod def box_to_dict(box, color=None, type=None, text=None): - item = {"page": box[0], "x": box[1], "y": box[2], "width": box[3], "height": box[4]} + item = { + "page": box[0], + "x": box[1], + "y": box[2], + "width": box[3], + "height": box[4], + } if color is not None: - item['color'] = color + item["color"] = color if type: - item['type'] = type + item["type"] = type if text: - item['text'] = text + item["text"] = text return item async def get_coordinates(self, text): - soup = BeautifulSoup(text, 'xml') + soup = BeautifulSoup(text, "xml") # exclude certain tag names - all_blocks_with_coordinates = soup.find('text').find_all(coords=True) + all_blocks_with_coordinates = soup.find("text").find_all(coords=True) # all_blocks_with_coordinates = soup.find_all() # if use_sentences: @@ -124,27 +136,35 @@ async def get_coordinates(self, text): coordinates = [] count = 0 for block_id, block in enumerate(all_blocks_with_coordinates): - for box in filter(lambda c: len(c) > 0 and c[0] != "", block['coords'].split(";")): + for box in filter( + lambda c: len(c) > 0 and c[0] != "", block["coords"].split(";") + ): coordinates.append( self.box_to_dict( box.split(","), get_color(block.name, count % 2 == 0), type=block.name, - text=block.text + text=block.text, ), ) count += 1 return coordinates async def get_pages(self, text): - soup = BeautifulSoup(text, 'xml') + soup = BeautifulSoup(text, "xml") pages_infos = soup.find_all("surface") - pages = [{'width': float(page['lrx']) - float(page['ulx']), 'height': float(page['lry']) - float(page['uly'])} - for page in pages_infos] + pages = [ + { + "width": float(page["lrx"]) - float(page["ulx"]), + "height": float(page["lry"]) - float(page["uly"]), + } + for page in pages_infos + ] return pages + # Global instance try: grobid_service = GrobidService() diff --git a/backend/api/services/postgres_auth.py b/backend/api/services/postgres_auth.py index e053339c..4c38379f 100644 --- a/backend/api/services/postgres_auth.py +++ b/backend/api/services/postgres_auth.py @@ -71,8 +71,14 @@ def conn(self): # psycopg2.errors.InFailedSqlTransaction # Heal it proactively so one bad request cannot poison the process. try: - if self._conn and self._conn.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_INERROR: - logger.warning("Postgres connection in aborted transaction; rolling back") + if ( + self._conn + and self._conn.get_transaction_status() + == psycopg2.extensions.TRANSACTION_STATUS_INERROR + ): + logger.warning( + "Postgres connection in aborted transaction; rolling back" + ) self._conn.rollback() except Exception: # If rollback fails, reconnect. @@ -128,7 +134,9 @@ def _refresh_azure_token(self) -> str: ) token = self._credential.get_token(self._AZURE_POSTGRES_SCOPE) self._token = token.token - self._token_expiration = token.expires_on - self._TOKEN_REFRESH_BUFFER_SECONDS + self._token_expiration = ( + token.expires_on - self._TOKEN_REFRESH_BUFFER_SECONDS + ) return self._token @staticmethod @@ -167,7 +175,12 @@ def _candidate_kwargs(self, mode: str, psycopg3: bool = False) -> Dict[str, Any] kwargs["password"] = prof.get("password") # Sanity checks - required = [kwargs.get("host"), kwargs.get("database"), kwargs.get("user"), kwargs.get("port")] + required = [ + kwargs.get("host"), + kwargs.get("database"), + kwargs.get("user"), + kwargs.get("port"), + ] if not all(required): raise RuntimeError(f"Incomplete Postgres config for mode={mode}") @@ -188,7 +201,9 @@ def _connect(self): try: return self._connect_with_mode(primary_mode) except Exception as e: - logger.error("Postgres connect failed (mode=%s): %s", primary_mode, e, exc_info=True) + logger.error( + "Postgres connect failed (mode=%s): %s", primary_mode, e, exc_info=True + ) raise psycopg2.OperationalError( f"Could not connect to Postgres for mode={primary_mode}" ) @@ -200,13 +215,14 @@ async def aconn(self) -> AsyncIterator[psycopg.AsyncConnection]: Commits automatically on clean exit, rolls back on exception. """ kwargs = self._candidate_kwargs(self._mode(), psycopg3=True) - async with await psycopg.AsyncConnection.connect(**kwargs, row_factory=dict_row) as conn: + async with await psycopg.AsyncConnection.connect( + **kwargs, row_factory=dict_row + ) as conn: yield conn def __repr__(self) -> str: status = "open" if self._conn and not self._conn.closed else "closed" - return ( - f"" - ) + return f"" + -postgres_server = PostgresServer() \ No newline at end of file +postgres_server = PostgresServer() diff --git a/backend/api/services/sr_db_service.py b/backend/api/services/sr_db_service.py index f936ee2a..1a4d46dc 100644 --- a/backend/api/services/sr_db_service.py +++ b/backend/api/services/sr_db_service.py @@ -8,6 +8,7 @@ All blocking DB operations are synchronous and intended to be run with `fastapi.concurrency.run_in_threadpool` when called from async routes. """ + from typing import Any, Dict, Optional, List import logging import uuid @@ -37,7 +38,7 @@ def ensure_table_exists(self) -> None: try: conn = postgres_server.conn cur = conn.cursor() - + create_table_sql = """ CREATE TABLE IF NOT EXISTS systematic_reviews ( id TEXT PRIMARY KEY, @@ -57,7 +58,7 @@ def ensure_table_exists(self) -> None: """ cur.execute(create_table_sql) conn.commit() - + logger.info("Ensured systematic_reviews table exists") except Exception as e: try: @@ -71,7 +72,9 @@ def ensure_table_exists(self) -> None: if conn: pass - def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[str, Any]: + def build_criteria_parsed( + self, criteria_obj: Optional[Dict[str, Any]] + ) -> Dict[str, Any]: """ Port of the _build_criteria_parsed helper - returns a mapping containing 'l1', 'l2', 'parameters' structured metadata for the UI and screening logic. @@ -108,7 +111,13 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[ else: possible.append([]) addinfos.append("") - parsed["l1"].update({"questions": qlist, "possible_answers": possible, "additional_infos": addinfos}) + parsed["l1"].update( + { + "questions": qlist, + "possible_answers": possible, + "additional_infos": addinfos, + } + ) # l2 criteria (fulltext) l2 = criteria_obj.get("l2_criteria", {}) @@ -133,7 +142,13 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[ else: l2_possible.append([]) l2_addinfos.append("") - parsed["l2"].update({"questions": l2_q, "possible_answers": l2_possible, "additional_infos": l2_addinfos}) + parsed["l2"].update( + { + "questions": l2_q, + "possible_answers": l2_possible, + "additional_infos": l2_addinfos, + } + ) # parameters params = criteria_obj.get("parameters", {}) @@ -147,14 +162,22 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[ possible_params.append([k for k in param_map.keys()]) descriptions.append( [ - "Parameter {key} are described as {info}.".format(key=k, info=v) + "Parameter {key} are described as {info}.".format( + key=k, info=v + ) for k, v in param_map.items() ] ) else: possible_params.append([]) descriptions.append([]) - parsed["parameters"].update({"categories": categories, "possible_parameters": possible_params, "descriptions": descriptions}) + parsed["parameters"].update( + { + "categories": categories, + "possible_parameters": possible_params, + "descriptions": descriptions, + } + ) return parsed @@ -182,55 +205,59 @@ def create_systematic_review( try: conn = postgres_server.conn cur = conn.cursor() - + insert_sql = """ INSERT INTO systematic_reviews (id, name, description, owner_id, owner_email, users, visible, criteria, criteria_yaml, criteria_parsed, created_at, updated_at) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s) """ - - cur.execute(insert_sql, ( - sr_id, - name.strip(), - description, - owner_id, - owner_email, - json.dumps(users_list), - True, - json.dumps(criteria_obj) if criteria_obj else None, - criteria_str, - json.dumps(criteria_parsed), - now, - now - )) - + + cur.execute( + insert_sql, + ( + sr_id, + name.strip(), + description, + owner_id, + owner_email, + json.dumps(users_list), + True, + json.dumps(criteria_obj) if criteria_obj else None, + criteria_str, + json.dumps(criteria_parsed), + now, + now, + ), + ) + conn.commit() - + # Fetch the created record cur.execute("SELECT * FROM systematic_reviews WHERE id = %s", (sr_id,)) row = cur.fetchone() cols = [desc[0] for desc in cur.description] sr_doc = {cols[i]: row[i] for i in range(len(cols))} - + # Parse JSON fields and convert timestamps - if sr_doc.get('users') and isinstance(sr_doc['users'], str): - sr_doc['users'] = json.loads(sr_doc['users']) - if sr_doc.get('criteria') and isinstance(sr_doc['criteria'], str): - sr_doc['criteria'] = json.loads(sr_doc['criteria']) - if sr_doc.get('criteria_parsed') and isinstance(sr_doc['criteria_parsed'], str): - sr_doc['criteria_parsed'] = json.loads(sr_doc['criteria_parsed']) + if sr_doc.get("users") and isinstance(sr_doc["users"], str): + sr_doc["users"] = json.loads(sr_doc["users"]) + if sr_doc.get("criteria") and isinstance(sr_doc["criteria"], str): + sr_doc["criteria"] = json.loads(sr_doc["criteria"]) + if sr_doc.get("criteria_parsed") and isinstance( + sr_doc["criteria_parsed"], str + ): + sr_doc["criteria_parsed"] = json.loads(sr_doc["criteria_parsed"]) # Convert datetime objects to ISO strings from datetime import datetime as dt - if sr_doc.get('created_at') and isinstance(sr_doc['created_at'], dt): - sr_doc['created_at'] = sr_doc['created_at'].isoformat() - if sr_doc.get('updated_at') and isinstance(sr_doc['updated_at'], dt): - sr_doc['updated_at'] = sr_doc['updated_at'].isoformat() - - + if sr_doc.get("created_at") and isinstance(sr_doc["created_at"], dt): + sr_doc["created_at"] = sr_doc["created_at"].isoformat() + if sr_doc.get("updated_at") and isinstance(sr_doc["updated_at"], dt): + sr_doc["updated_at"] = sr_doc["updated_at"].isoformat() + return sr_doc - + except Exception as e: try: if conn: @@ -238,60 +265,75 @@ def create_systematic_review( except Exception: pass logger.exception(f"Failed to insert SR document: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create systematic review: {e}", + ) finally: if conn: pass - def add_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: + def add_user( + self, sr_id: str, target_user_id: str, requester_id: str + ) -> Dict[str, Any]: """ Add a user id to the SR's users list. Enforces that the SR exists and is visible; requester must be a member or owner. Returns a dict with update result metadata. """ - sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found", + ) # Check permission has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to modify this systematic review", + ) conn = None try: conn = postgres_server.conn cur = conn.cursor() - + # Get current users array cur.execute("SELECT users FROM systematic_reviews WHERE id = %s", (sr_id,)) row = cur.fetchone() if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found", + ) + users = row[0] if row[0] else [] if isinstance(users, str): users = json.loads(users) - + # Add user if not already present if target_user_id not in users: users.append(target_user_id) - + # Update now = datetime.utcnow().isoformat() cur.execute( "UPDATE systematic_reviews SET users = %s, updated_at = %s WHERE id = %s", - (json.dumps(users), now, sr_id) + (json.dumps(users), now, sr_id), ) modified_count = cur.rowcount conn.commit() - - - return {"matched_count": 1, "modified_count": modified_count, "added_user_id": target_user_id} - + return { + "matched_count": 1, + "modified_count": modified_count, + "added_user_id": target_user_id, + } + except HTTPException: try: if conn: @@ -306,62 +348,80 @@ def add_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[s except Exception: pass logger.exception(f"Failed to add user to SR: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to add user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to add user: {e}", + ) finally: if conn: pass - def remove_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]: + def remove_user( + self, sr_id: str, target_user_id: str, requester_id: str + ) -> Dict[str, Any]: """ Remove a user id from the SR's users list. Owner cannot be removed. Enforces requester permissions (must be a member or owner). """ - sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found", + ) # Check permission has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to modify this systematic review", + ) if target_user_id == sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot remove the owner from the systematic review") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot remove the owner from the systematic review", + ) conn = None try: conn = postgres_server.conn cur = conn.cursor() - + # Get current users array cur.execute("SELECT users FROM systematic_reviews WHERE id = %s", (sr_id,)) row = cur.fetchone() if not row: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found", + ) + users = row[0] if row[0] else [] if isinstance(users, str): users = json.loads(users) - + # Remove user if present if target_user_id in users: users.remove(target_user_id) - + # Update now = datetime.utcnow().isoformat() cur.execute( "UPDATE systematic_reviews SET users = %s, updated_at = %s WHERE id = %s", - (json.dumps(users), now, sr_id) + (json.dumps(users), now, sr_id), ) modified_count = cur.rowcount conn.commit() - - - return {"matched_count": 1, "modified_count": modified_count, "removed_user_id": target_user_id} - + return { + "matched_count": 1, + "modified_count": modified_count, + "removed_user_id": target_user_id, + } + except HTTPException: try: if conn: @@ -376,7 +436,10 @@ def remove_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dic except Exception: pass logger.exception(f"Failed to remove user from SR: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to remove user: {e}", + ) finally: if conn: pass @@ -388,7 +451,6 @@ def user_has_sr_permission(self, sr_id: str, user_id: str) -> bool: Note: this check deliberately ignores the SR's 'visible' flag so membership checks work regardless of whether the SR is hidden/soft-deleted. """ - doc = self.get_systematic_review(sr_id, ignore_visibility=True) if not doc: @@ -399,56 +461,71 @@ def user_has_sr_permission(self, sr_id: str, user_id: str) -> bool: return True return False - def update_criteria(self, sr_id: str, criteria_obj: Dict[str, Any], criteria_str: str, requester_id: str) -> Dict[str, Any]: + def update_criteria( + self, + sr_id: str, + criteria_obj: Dict[str, Any], + criteria_str: str, + requester_id: str, + ) -> Dict[str, Any]: """ Update the criteria fields (criteria, criteria_yaml, criteria_parsed, updated_at). The requester must be a member or owner. Returns the updated SR document. """ - sr = self.get_systematic_review(sr_id) if not sr or not sr.get("visible", True): - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found", + ) # Check permission has_perm = self.user_has_sr_permission(sr_id, requester_id) if not has_perm: - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Not authorized to modify this systematic review", + ) conn = None try: conn = postgres_server.conn cur = conn.cursor() - + updated_at = datetime.utcnow().isoformat() criteria_parsed = self.build_criteria_parsed(criteria_obj) - + update_sql = """ UPDATE systematic_reviews SET criteria = %s, criteria_yaml = %s, criteria_parsed = %s, updated_at = %s WHERE id = %s """ - - cur.execute(update_sql, ( - json.dumps(criteria_obj), - criteria_str, - json.dumps(criteria_parsed), - updated_at, - sr_id - )) - + + cur.execute( + update_sql, + ( + json.dumps(criteria_obj), + criteria_str, + json.dumps(criteria_parsed), + updated_at, + sr_id, + ), + ) + if cur.rowcount == 0: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found during update") - + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found during update", + ) + conn.commit() - - # Return fresh doc doc = self.get_systematic_review(sr_id) return doc - + except HTTPException: try: if conn: @@ -463,7 +540,10 @@ def update_criteria(self, sr_id: str, criteria_obj: Dict[str, Any], criteria_str except Exception: pass logger.exception(f"Failed to update SR criteria: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update criteria: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update criteria: {e}", + ) finally: if conn: pass @@ -472,48 +552,48 @@ def list_systematic_reviews_for_user(self, user_email: str) -> List[Dict[str, An """ Return all SR documents where the user is a member (regardless of visible flag). """ - conn = None try: conn = postgres_server.conn cur = conn.cursor() - + # Query using jsonb operator to check if user_email is in users array query = """ SELECT * FROM systematic_reviews WHERE users @> %s::jsonb ORDER BY created_at DESC """ - + cur.execute(query, (json.dumps([user_email]),)) rows = cur.fetchall() cols = [desc[0] for desc in cur.description] - + results = [] for row in rows: doc = {cols[i]: row[i] for i in range(len(cols))} - + # Parse JSON fields and convert timestamps - if doc.get('users') and isinstance(doc['users'], str): - doc['users'] = json.loads(doc['users']) - if doc.get('criteria') and isinstance(doc['criteria'], str): - doc['criteria'] = json.loads(doc['criteria']) - if doc.get('criteria_parsed') and isinstance(doc['criteria_parsed'], str): - doc['criteria_parsed'] = json.loads(doc['criteria_parsed']) + if doc.get("users") and isinstance(doc["users"], str): + doc["users"] = json.loads(doc["users"]) + if doc.get("criteria") and isinstance(doc["criteria"], str): + doc["criteria"] = json.loads(doc["criteria"]) + if doc.get("criteria_parsed") and isinstance( + doc["criteria_parsed"], str + ): + doc["criteria_parsed"] = json.loads(doc["criteria_parsed"]) # Convert datetime objects to ISO strings from datetime import datetime as dt - if doc.get('created_at') and isinstance(doc['created_at'], dt): - doc['created_at'] = doc['created_at'].isoformat() - if doc.get('updated_at') and isinstance(doc['updated_at'], dt): - doc['updated_at'] = doc['updated_at'].isoformat() - + + if doc.get("created_at") and isinstance(doc["created_at"], dt): + doc["created_at"] = doc["created_at"].isoformat() + if doc.get("updated_at") and isinstance(doc["updated_at"], dt): + doc["updated_at"] = doc["updated_at"].isoformat() + results.append(doc) - - return results - + except Exception as e: try: if conn: @@ -521,55 +601,60 @@ def list_systematic_reviews_for_user(self, user_email: str) -> List[Dict[str, An except Exception: pass logger.exception(f"Failed to list SRs for user: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to list systematic reviews: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to list systematic reviews: {e}", + ) finally: if conn: pass - def get_systematic_review(self, sr_id: str, ignore_visibility: bool = False) -> Optional[Dict[str, Any]]: + def get_systematic_review( + self, sr_id: str, ignore_visibility: bool = False + ) -> Optional[Dict[str, Any]]: """ Return SR document by id. Returns None if not found. If ignore_visibility is False, only returns visible SRs. """ - conn = None try: conn = postgres_server.conn cur = conn.cursor() - + if ignore_visibility: query = "SELECT * FROM systematic_reviews WHERE id = %s" else: - query = "SELECT * FROM systematic_reviews WHERE id = %s AND visible = TRUE" - + query = ( + "SELECT * FROM systematic_reviews WHERE id = %s AND visible = TRUE" + ) + cur.execute(query, (sr_id,)) row = cur.fetchone() - + if not row: return None - + cols = [desc[0] for desc in cur.description] doc = {cols[i]: row[i] for i in range(len(cols))} - + # Parse JSON fields and convert timestamps - if doc.get('users') and isinstance(doc['users'], str): - doc['users'] = json.loads(doc['users']) - if doc.get('criteria') and isinstance(doc['criteria'], str): - doc['criteria'] = json.loads(doc['criteria']) - if doc.get('criteria_parsed') and isinstance(doc['criteria_parsed'], str): - doc['criteria_parsed'] = json.loads(doc['criteria_parsed']) + if doc.get("users") and isinstance(doc["users"], str): + doc["users"] = json.loads(doc["users"]) + if doc.get("criteria") and isinstance(doc["criteria"], str): + doc["criteria"] = json.loads(doc["criteria"]) + if doc.get("criteria_parsed") and isinstance(doc["criteria_parsed"], str): + doc["criteria_parsed"] = json.loads(doc["criteria_parsed"]) # Convert datetime objects to ISO strings from datetime import datetime as dt - if doc.get('created_at') and isinstance(doc['created_at'], dt): - doc['created_at'] = doc['created_at'].isoformat() - if doc.get('updated_at') and isinstance(doc['updated_at'], dt): - doc['updated_at'] = doc['updated_at'].isoformat() - - + if doc.get("created_at") and isinstance(doc["created_at"], dt): + doc["created_at"] = doc["created_at"].isoformat() + if doc.get("updated_at") and isinstance(doc["updated_at"], dt): + doc["updated_at"] = doc["updated_at"].isoformat() + return doc - + except Exception as e: # IMPORTANT: do not swallow DB errors as "not found". # Roll back so this connection isn't poisoned for subsequent requests. @@ -584,37 +669,46 @@ def get_systematic_review(self, sr_id: str, ignore_visibility: bool = False) -> if conn: pass - def set_visibility(self, sr_id: str, visible: bool, requester_id: str) -> Dict[str, Any]: + def set_visibility( + self, sr_id: str, visible: bool, requester_id: str + ) -> Dict[str, Any]: """ Set the visible flag on the SR. Only owner is allowed to change visibility. Returns update metadata. """ - sr = self.get_systematic_review(sr_id, ignore_visibility=True) if not sr: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found", + ) if requester_id != sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may change visibility of this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the owner may change visibility of this systematic review", + ) conn = None try: conn = postgres_server.conn cur = conn.cursor() - + updated_at = datetime.utcnow().isoformat() cur.execute( "UPDATE systematic_reviews SET visible = %s, updated_at = %s WHERE id = %s", - (bool(visible), updated_at, sr_id) + (bool(visible), updated_at, sr_id), ) modified_count = cur.rowcount conn.commit() - - - return {"matched_count": 1, "modified_count": modified_count, "visible": visible} - + return { + "matched_count": 1, + "modified_count": modified_count, + "visible": visible, + } + except Exception as e: try: if conn: @@ -622,48 +716,62 @@ def set_visibility(self, sr_id: str, visible: bool, requester_id: str) -> Dict[s except Exception: pass logger.exception(f"Failed to set visibility on SR: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to set visibility: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to set visibility: {e}", + ) finally: if conn: - pass + pass - def soft_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: + def soft_delete_systematic_review( + self, sr_id: str, requester_id: str + ) -> Dict[str, Any]: """ Soft-delete (set visible=False). Only owner may delete. """ return self.set_visibility(sr_id, False, requester_id) - def undelete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: + def undelete_systematic_review( + self, sr_id: str, requester_id: str + ) -> Dict[str, Any]: """ Undelete (set visible=True). Only owner may undelete. """ return self.set_visibility(sr_id, True, requester_id) - def hard_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]: + def hard_delete_systematic_review( + self, sr_id: str, requester_id: str + ) -> Dict[str, Any]: """ Permanently remove the SR document. Only owner may hard delete. Returns deletion metadata (deleted_count). """ - sr = self.get_systematic_review(sr_id, ignore_visibility=True) if not sr: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found", + ) if requester_id != sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may hard-delete this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the owner may hard-delete this systematic review", + ) conn = None try: conn = postgres_server.conn cur = conn.cursor() - + cur.execute("DELETE FROM systematic_reviews WHERE id = %s", (sr_id,)) deleted_count = cur.rowcount conn.commit() return {"deleted_count": deleted_count} - + except Exception as e: try: if conn: @@ -671,32 +779,33 @@ def hard_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[s except Exception: pass logger.exception(f"Failed to hard-delete SR: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to hard-delete systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to hard-delete systematic review: {e}", + ) finally: if conn: pass - - def update_screening_db_info(self, sr_id: str, screening_db: Dict[str, Any]) -> None: + def update_screening_db_info( + self, sr_id: str, screening_db: Dict[str, Any] + ) -> None: """ Update the screening_db field in the SR document with screening database metadata. """ - conn = None try: conn = postgres_server.conn cur = conn.cursor() - + updated_at = datetime.utcnow().isoformat() cur.execute( "UPDATE systematic_reviews SET screening_db = %s, updated_at = %s WHERE id = %s", - (json.dumps(screening_db), updated_at, sr_id) + (json.dumps(screening_db), updated_at, sr_id), ) conn.commit() - - except Exception as e: try: if conn: @@ -704,7 +813,10 @@ def update_screening_db_info(self, sr_id: str, screening_db: Dict[str, Any]) -> except Exception: pass logger.exception(f"Failed to update screening DB info: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update screening DB info: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update screening DB info: {e}", + ) finally: if conn: pass @@ -713,22 +825,19 @@ def clear_screening_db_info(self, sr_id: str) -> None: """ Remove the screening_db field from the SR document. """ - conn = None try: conn = postgres_server.conn cur = conn.cursor() - + updated_at = datetime.utcnow().isoformat() cur.execute( "UPDATE systematic_reviews SET screening_db = NULL, updated_at = %s WHERE id = %s", - (updated_at, sr_id) + (updated_at, sr_id), ) conn.commit() - - except Exception as e: try: if conn: @@ -736,7 +845,10 @@ def clear_screening_db_info(self, sr_id: str) -> None: except Exception: pass logger.exception(f"Failed to clear screening DB info: {e}") - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to clear screening DB info: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to clear screening DB info: {e}", + ) finally: if conn: pass diff --git a/backend/api/services/storage.py b/backend/api/services/storage.py index 3ab9b25b..52c69441 100644 --- a/backend/api/services/storage.py +++ b/backend/api/services/storage.py @@ -23,7 +23,11 @@ try: from azure.core.exceptions import ResourceNotFoundError from azure.identity import DefaultAzureCredential - from azure.storage.blob import BlobSasPermissions, BlobServiceClient, generate_blob_sas + from azure.storage.blob import ( + BlobSasPermissions, + BlobServiceClient, + generate_blob_sas, + ) except Exception: # pragma: no cover # Allow local-storage deployments/environments to import without azure packages. ResourceNotFoundError = Exception # type: ignore @@ -44,16 +48,28 @@ class StorageService(Protocol): container_name: str async def create_user_directory(self, user_id: str) -> bool: ... - async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: ... + async def save_user_profile( + self, user_id: str, profile_data: Dict[str, Any] + ) -> bool: ... async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: ... - async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: ... - async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: ... + async def upload_user_document( + self, user_id: str, filename: str, file_content: bytes + ) -> Optional[str]: ... + async def get_user_document( + self, user_id: str, doc_id: str, filename: str + ) -> Optional[bytes]: ... async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: ... - async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: ... - async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: ... + async def delete_user_document( + self, user_id: str, doc_id: str, filename: str + ) -> bool: ... + async def put_bytes_by_path( + self, path: str, content: bytes, content_type: str = "application/octet-stream" + ) -> bool: ... async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: ... async def delete_by_path(self, path: str) -> bool: ... - async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optional[str]: ... + async def generate_signed_url( + self, path: str, expiry_minutes: int = 5 + ) -> Optional[str]: ... # ============================================================================= @@ -64,36 +80,50 @@ async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optio class AzureStorageService: """Service for managing user data in Azure Blob Storage.""" - def __init__(self, *, account_url: str | None = None, connection_string: str | None = None, container_name: str): + def __init__( + self, + *, + account_url: str | None = None, + connection_string: str | None = None, + container_name: str, + ): if not BlobServiceClient: raise RuntimeError( "Azure storage libraries are not installed. Install azure-identity and azure-storage-blob, or use STORAGE_MODE=local." ) if bool(account_url) == bool(connection_string): - raise ValueError("Exactly one of account_url or connection_string must be provided") + raise ValueError( + "Exactly one of account_url or connection_string must be provided" + ) self._account_key: str | None = None self._credential: Any = None if connection_string: - self.blob_service_client = BlobServiceClient.from_connection_string(connection_string) - self._account_key = self._get_account_key_from_connection_str(connection_string) + self.blob_service_client = BlobServiceClient.from_connection_string( + connection_string + ) + self._account_key = self._get_account_key_from_connection_str( + connection_string + ) else: if not DefaultAzureCredential: raise RuntimeError( "azure-identity is not installed. Install azure-identity, or use STORAGE_MODE=azure (connection string) or local." ) self._credential = DefaultAzureCredential() - self.blob_service_client = BlobServiceClient(account_url=account_url, credential=self._credential) + self.blob_service_client = BlobServiceClient( + account_url=account_url, credential=self._credential + ) self.container_name = container_name self._ensure_container_exists() - + def _get_account_key_from_connection_str(self, connection_str): for part in connection_str.split(";"): if part.startswith("AccountKey="): - return part[len("AccountKey="):] + return part[len("AccountKey=") :] return None def _ensure_container_exists(self): @@ -116,17 +146,23 @@ async def create_user_directory(self, user_id: str) -> bool: # Create placeholder file to establish directory structure blob_name = f"users/{user_id}/documents/.placeholder" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) blob_client.upload_blob(b"", overwrite=True) return True except Exception as e: logger.error("Error creating user directory for %s: %s", user_id, e) return False - async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: + async def save_user_profile( + self, user_id: str, profile_data: Dict[str, Any] + ) -> bool: try: blob_name = f"users/{user_id}/profile.json" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) blob_client.upload_blob(json.dumps(profile_data, indent=2), overwrite=True) return True except Exception as e: @@ -136,7 +172,9 @@ async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: try: blob_name = f"users/{user_id}/profile.json" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) blob_data = blob_client.download_blob().readall() return json.loads(blob_data.decode("utf-8")) except ResourceNotFoundError: @@ -145,11 +183,15 @@ async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: logger.error("Error getting user profile for %s: %s", user_id, e) return None - async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: + async def upload_user_document( + self, user_id: str, filename: str, file_content: bytes + ) -> Optional[str]: try: doc_id = str(uuid.uuid4()) blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) blob_client.upload_blob(file_content, overwrite=True) file_metadata = create_file_metadata( @@ -166,30 +208,42 @@ async def upload_user_document(self, user_id: str, filename: str, file_content: profile = await self.get_user_profile(user_id) if profile: profile["document_count"] = int(profile.get("document_count", 0)) + 1 - profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content) + profile["storage_used"] = int(profile.get("storage_used", 0)) + len( + file_content + ) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return doc_id except Exception as e: - logger.error("Error uploading document %s for user %s: %s", filename, user_id, e) + logger.error( + "Error uploading document %s for user %s: %s", filename, user_id, e + ) return None - async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: + async def get_user_document( + self, user_id: str, doc_id: str, filename: str + ) -> Optional[bytes]: try: blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) return blob_client.download_blob().readall() except ResourceNotFoundError: return None except Exception as e: - logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e) + logger.error( + "Error getting document %s for user %s: %s", doc_id, user_id, e + ) return None async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: try: prefix = f"users/{user_id}/documents/" - blobs = self.blob_service_client.get_container_client(self.container_name).list_blobs(name_starts_with=prefix) + blobs = self.blob_service_client.get_container_client( + self.container_name + ).list_blobs(name_starts_with=prefix) documents: List[Dict[str, Any]] = [] for blob in blobs: @@ -218,10 +272,14 @@ async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: logger.error("Error listing user documents for %s: %s", user_id, e) return [] - async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: + async def delete_user_document( + self, user_id: str, doc_id: str, filename: str + ) -> bool: try: doc_blob_name = f"users/{user_id}/documents/{doc_id}_{filename}" - doc_blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=doc_blob_name) + doc_blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=doc_blob_name + ) try: doc_size = doc_blob_client.get_blob_properties().size @@ -233,30 +291,44 @@ async def delete_user_document(self, user_id: str, doc_id: str, filename: str) - profile = await self.get_user_profile(user_id) if profile: - profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1) - profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size)) + profile["document_count"] = max( + 0, int(profile.get("document_count", 0)) - 1 + ) + profile["storage_used"] = max( + 0, int(profile.get("storage_used", 0)) - int(doc_size) + ) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return True except Exception as e: - logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e) + logger.error( + "Error deleting document %s for user %s: %s", doc_id, user_id, e + ) return False - async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool: + async def save_file_hash_metadata( + self, user_id: str, document_id: str, file_metadata: Dict[str, Any] + ) -> bool: try: blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) blob_client.upload_blob(json.dumps(file_metadata, indent=2), overwrite=True) return True except Exception as e: logger.error("Error saving file metadata for %s: %s", document_id, e) return False - async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]: + async def get_file_hash_metadata( + self, user_id: str, document_id: str + ) -> Optional[Dict[str, Any]]: try: blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) metadata_json = blob_client.download_blob().readall().decode("utf-8") return json.loads(metadata_json) except ResourceNotFoundError: @@ -268,7 +340,9 @@ async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Option async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool: try: blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json" - blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name) + blob_client = self.blob_service_client.get_blob_client( + container=self.container_name, blob=blob_name + ) blob_client.delete_blob() return True except ResourceNotFoundError: @@ -277,12 +351,16 @@ async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> boo logger.error("Error deleting file metadata for %s: %s", document_id, e) return False - async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: + async def put_bytes_by_path( + self, path: str, content: bytes, content_type: str = "application/octet-stream" + ) -> bool: """Write blob by storage path 'container/blob'.""" if not path or "/" not in path: raise ValueError("Invalid storage path") container, blob = path.split("/", 1) - blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + blob_client = self.blob_service_client.get_blob_client( + container=container, blob=blob + ) blob_client.upload_blob(content, overwrite=True, content_type=content_type) return True @@ -292,7 +370,9 @@ async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: raise ValueError("Invalid storage path") container, blob = path.split("/", 1) - blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + blob_client = self.blob_service_client.get_blob_client( + container=container, blob=blob + ) content = blob_client.download_blob().readall() filename = os.path.basename(blob) or "download" return content, filename @@ -302,11 +382,15 @@ async def delete_by_path(self, path: str) -> bool: if not path or "/" not in path: raise ValueError("Invalid storage path") container, blob = path.split("/", 1) - blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob) + blob_client = self.blob_service_client.get_blob_client( + container=container, blob=blob + ) blob_client.delete_blob() return True - async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optional[str]: + async def generate_signed_url( + self, path: str, expiry_minutes: int = 5 + ) -> Optional[str]: """Generate a read-only SAS URL for a blob. Path format: 'container/blob'.""" if not path or "/" not in path: raise ValueError("Invalid storage path") @@ -323,13 +407,17 @@ async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optio "expiry": expiry, } if self._account_key: - sas_token = generate_blob_sas(**blob_sas_kwargs, account_key=self._account_key) + sas_token = generate_blob_sas( + **blob_sas_kwargs, account_key=self._account_key + ) elif self._credential: delegation_key = self.blob_service_client.get_user_delegation_key( key_start_time=datetime.now(timezone.utc) - timedelta(minutes=1), key_expiry_time=expiry, ) - sas_token = generate_blob_sas(**blob_sas_kwargs, user_delegation_key=delegation_key) + sas_token = generate_blob_sas( + **blob_sas_kwargs, user_delegation_key=delegation_key + ) else: raise RuntimeError("No credentials available for SAS generation") @@ -387,10 +475,14 @@ async def create_user_directory(self, user_id: str) -> bool: logger.error("Error creating user directory for %s: %s", user_id, e) return False - async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: + async def save_user_profile( + self, user_id: str, profile_data: Dict[str, Any] + ) -> bool: try: self._profile_path(user_id).parent.mkdir(parents=True, exist_ok=True) - self._profile_path(user_id).write_text(json.dumps(profile_data, indent=2), encoding="utf-8") + self._profile_path(user_id).write_text( + json.dumps(profile_data, indent=2), encoding="utf-8" + ) return True except Exception as e: logger.error("Error saving user profile for %s: %s", user_id, e) @@ -406,7 +498,9 @@ async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: logger.error("Error getting user profile for %s: %s", user_id, e) return None - async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: + async def upload_user_document( + self, user_id: str, filename: str, file_content: bytes + ) -> Optional[str]: try: await self.create_user_directory(user_id) doc_id = str(uuid.uuid4()) @@ -428,23 +522,31 @@ async def upload_user_document(self, user_id: str, filename: str, file_content: profile = await self.get_user_profile(user_id) if profile: profile["document_count"] = int(profile.get("document_count", 0)) + 1 - profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content) + profile["storage_used"] = int(profile.get("storage_used", 0)) + len( + file_content + ) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return doc_id except Exception as e: - logger.error("Error uploading document %s for user %s: %s", filename, user_id, e) + logger.error( + "Error uploading document %s for user %s: %s", filename, user_id, e + ) return None - async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: + async def get_user_document( + self, user_id: str, doc_id: str, filename: str + ) -> Optional[bytes]: try: p = self._doc_path(user_id, doc_id, filename) if not p.exists(): return None return p.read_bytes() except Exception as e: - logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e) + logger.error( + "Error getting document %s for user %s: %s", doc_id, user_id, e + ) return None async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: @@ -468,8 +570,12 @@ async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: "document_id": doc_id, "filename": filename, "file_size": stat.st_size, - "upload_date": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), - "last_modified": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(), + "upload_date": datetime.fromtimestamp( + stat.st_mtime, tz=timezone.utc + ).isoformat(), + "last_modified": datetime.fromtimestamp( + stat.st_mtime, tz=timezone.utc + ).isoformat(), } if hash_metadata: doc_info["file_hash"] = hash_metadata.get("file_hash") @@ -483,7 +589,9 @@ async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: logger.error("Error listing user documents for %s: %s", user_id, e) return [] - async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: + async def delete_user_document( + self, user_id: str, doc_id: str, filename: str + ) -> bool: try: p = self._doc_path(user_id, doc_id, filename) doc_size = p.stat().st_size if p.exists() else 0 @@ -493,16 +601,24 @@ async def delete_user_document(self, user_id: str, doc_id: str, filename: str) - profile = await self.get_user_profile(user_id) if profile: - profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1) - profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size)) + profile["document_count"] = max( + 0, int(profile.get("document_count", 0)) - 1 + ) + profile["storage_used"] = max( + 0, int(profile.get("storage_used", 0)) - int(doc_size) + ) profile["last_updated"] = datetime.now(timezone.utc).isoformat() await self.save_user_profile(user_id, profile) return True except Exception as e: - logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e) + logger.error( + "Error deleting document %s for user %s: %s", doc_id, user_id, e + ) return False - async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool: + async def save_file_hash_metadata( + self, user_id: str, document_id: str, file_metadata: Dict[str, Any] + ) -> bool: try: p = self._metadata_path(user_id, document_id) p.parent.mkdir(parents=True, exist_ok=True) @@ -512,7 +628,9 @@ async def save_file_hash_metadata(self, user_id: str, document_id: str, file_met logger.error("Error saving file metadata for %s: %s", document_id, e) return False - async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]: + async def get_file_hash_metadata( + self, user_id: str, document_id: str + ) -> Optional[Dict[str, Any]]: try: p = self._metadata_path(user_id, document_id) if not p.exists(): @@ -532,7 +650,9 @@ async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> boo logger.error("Error deleting file metadata for %s: %s", document_id, e) return False - async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: + async def put_bytes_by_path( + self, path: str, content: bytes, content_type: str = "application/octet-stream" + ) -> bool: """Write file by storage path 'container/blob'.""" if not path or "/" not in path: raise ValueError("Invalid storage path") @@ -585,7 +705,9 @@ async def delete_by_path(self, path: str) -> bool: p.unlink() return True - async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optional[str]: + async def generate_signed_url( + self, path: str, expiry_minutes: int = 5 + ) -> Optional[str]: """Local storage cannot generate signed URLs; returns None to signal streaming fallback.""" return None @@ -605,8 +727,13 @@ def _build_storage_service() -> Optional[StorageService]: return None if stype == "azure": try: - if not settings.AZURE_STORAGE_ACCOUNT_NAME or not settings.AZURE_STORAGE_ACCOUNT_KEY: - raise ValueError("STORAGE_MODE=azure requires AZURE_STORAGE_ACCOUNT_NAME and AZURE_STORAGE_ACCOUNT_KEY") + if ( + not settings.AZURE_STORAGE_ACCOUNT_NAME + or not settings.AZURE_STORAGE_ACCOUNT_KEY + ): + raise ValueError( + "STORAGE_MODE=azure requires AZURE_STORAGE_ACCOUNT_NAME and AZURE_STORAGE_ACCOUNT_KEY" + ) connection_string = ( "DefaultEndpointsProtocol=https;" f"AccountName={settings.AZURE_STORAGE_ACCOUNT_NAME};" @@ -618,13 +745,19 @@ def _build_storage_service() -> Optional[StorageService]: container_name=settings.STORAGE_CONTAINER_NAME, ) except Exception as e: - logger.exception("Failed to initialize AzureStorageService (connection string): %s", e) + logger.exception( + "Failed to initialize AzureStorageService (connection string): %s", e + ) return None if stype == "entra": try: if not settings.AZURE_STORAGE_ACCOUNT_NAME: - raise ValueError("STORAGE_MODE=entra requires AZURE_STORAGE_ACCOUNT_NAME") - account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" + raise ValueError( + "STORAGE_MODE=entra requires AZURE_STORAGE_ACCOUNT_NAME" + ) + account_url = ( + f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net" + ) return AzureStorageService( account_url=account_url, container_name=settings.STORAGE_CONTAINER_NAME, diff --git a/backend/api/services/user_db.py b/backend/api/services/user_db.py index ee1c0411..c2025cb5 100644 --- a/backend/api/services/user_db.py +++ b/backend/api/services/user_db.py @@ -16,7 +16,6 @@ logger = logging.getLogger(__name__) - class UserDatabaseService: """Service for managing user data in PostgreSQL via psycopg3 async.""" @@ -29,6 +28,7 @@ def __init__(self): self._registry_cache_ts: float = 0.0 self._registry_cache_ttl_s: float = 30.0 self._registry_cache_lock = asyncio.Lock() + @staticmethod def _serialize_dates(row: Dict[str, Any]) -> Dict[str, Any]: for field in ("created_at", "updated_at", "last_login"): @@ -50,11 +50,16 @@ async def _load_user_registry(self) -> Dict[str, Any]: """Load the user registry from storage.""" async with self._registry_cache_lock: now = asyncio.get_running_loop().time() - if self._registry_cache is not None and (now - self._registry_cache_ts) < self._registry_cache_ttl_s: + if ( + self._registry_cache is not None + and (now - self._registry_cache_ts) < self._registry_cache_ttl_s + ): return self._registry_cache try: - content, _filename = await self.storage.get_bytes_by_path(self._registry_path()) + content, _filename = await self.storage.get_bytes_by_path( + self._registry_path() + ) reg = json.loads(content.decode("utf-8")) except Exception: # Create empty registry if it doesn't exist / cannot be read @@ -81,12 +86,12 @@ async def _save_user_registry(self, registry: Dict[str, Any]) -> bool: return ok except Exception: return False + async def ensure_table_exists(self) -> None: """Create the users table if it does not already exist.""" try: async with postgres_server.aconn() as conn: - await conn.execute( - """ + await conn.execute(""" CREATE TABLE IF NOT EXISTS users ( id TEXT PRIMARY KEY, email TEXT UNIQUE NOT NULL, @@ -98,8 +103,7 @@ async def ensure_table_exists(self) -> None: updated_at TIMESTAMP WITH TIME ZONE DEFAULT now(), last_login TIMESTAMP WITH TIME ZONE ) - """ - ) + """) logger.info("Ensured users table exists") except Exception as e: logger.exception("Failed to ensure users table exists: %s", e) @@ -129,8 +133,17 @@ async def create_user(self, user_data: UserCreate) -> Optional[UserRead]: created_at, updated_at, last_login) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) """, - (user_id, email, user_data.full_name, hashed_password, - True, False, now, now, None), + ( + user_id, + email, + user_data.full_name, + hashed_password, + True, + False, + now, + now, + None, + ), ) return UserRead( id=user_id, @@ -178,7 +191,9 @@ async def authenticate_user( row = await cur.fetchone() if not row: return None - if not sso and not self._verify_password(password, row["hashed_password"]): + if not sso and not self._verify_password( + password, row["hashed_password"] + ): return None now = datetime.now(timezone.utc) await conn.execute( @@ -193,8 +208,12 @@ async def update_user( self, user_id: str, update_data: Dict[str, Any] ) -> Optional[Dict[str, Any]]: allowed_fields = { - "full_name", "email", "hashed_password", - "is_active", "is_superuser", "last_login", + "full_name", + "email", + "hashed_password", + "is_active", + "is_superuser", + "last_login", } fields = {k: v for k, v in update_data.items() if k in allowed_fields} if not fields: diff --git a/backend/api/sr/router.py b/backend/api/sr/router.py index e36b095d..84a4e92f 100644 --- a/backend/api/sr/router.py +++ b/backend/api/sr/router.py @@ -28,6 +28,7 @@ router = APIRouter() + class SystematicReviewCreate(BaseModel): name: str description: Optional[str] = None @@ -53,10 +54,9 @@ class SystematicReviewRead(BaseModel): criteria_parsed: Optional[Dict[str, Any]] = None - - - -@router.post("/create", response_model=SystematicReviewRead, status_code=status.HTTP_201_CREATED) +@router.post( + "/create", response_model=SystematicReviewRead, status_code=status.HTTP_201_CREATED +) async def create_systematic_review( name: str = Form(...), description: Optional[str] = Form(None), @@ -78,7 +78,9 @@ async def create_systematic_review( """ if not name or not name.strip(): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="name is required") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="name is required" + ) # Load YAML criteria criteria_str: Optional[str] = None @@ -103,7 +105,8 @@ async def create_systematic_review( ) except yaml.YAMLError as ye: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid YAML provided: {ye}" + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid YAML provided: {ye}", ) except Exception as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) @@ -121,7 +124,10 @@ async def create_systematic_review( except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to create systematic review: {e}", + ) return SystematicReviewRead( id=sr_doc.get("id"), @@ -164,27 +170,45 @@ async def add_user_to_systematic_review( """ try: - sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check( + sr_id, current_user, srdb_service, require_screening=False + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) # resolve user target_user_id = None if payload.user_email: target_user_id = payload.user_email else: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email" + ) try: - res = await run_in_threadpool(srdb_service.add_user, sr_id, target_user_id, current_user.get("id")) + res = await run_in_threadpool( + srdb_service.add_user, sr_id, target_user_id, current_user.get("id") + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to add user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to add user: {e}", + ) - return {"status": "success", "sr_id": sr_id, "added_user_id": target_user_id, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")} + return { + "status": "success", + "sr_id": sr_id, + "added_user_id": target_user_id, + "matched_count": res.get("matched_count"), + "modified_count": res.get("modified_count"), + } @router.post("/{sr_id}/remove-user") @@ -201,31 +225,52 @@ async def remove_user_from_systematic_review( The owner cannot be removed via this endpoint. """ try: - sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check( + sr_id, current_user, srdb_service, require_screening=False + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") - + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) + # resolve user target_user_id = None if payload.user_email: target_user_id = payload.user_email else: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email" + ) # do not allow removing the owner if target_user_id == sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot remove the owner from the systematic review") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot remove the owner from the systematic review", + ) try: - res = await run_in_threadpool(srdb_service.remove_user, sr_id, target_user_id, current_user.get("id")) + res = await run_in_threadpool( + srdb_service.remove_user, sr_id, target_user_id, current_user.get("id") + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to remove user: {e}", + ) - return {"status": "success", "sr_id": sr_id, "removed_user_id": target_user_id, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")} + return { + "status": "success", + "sr_id": sr_id, + "removed_user_id": target_user_id, + "matched_count": res.get("matched_count"), + "modified_count": res.get("modified_count"), + } @router.get("/mine", response_model=List[SystematicReviewRead]) @@ -240,11 +285,16 @@ async def list_systematic_reviews_for_user( user_id = current_user.get("email") results = [] try: - docs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, user_id) + docs = await run_in_threadpool( + srdb_service.list_systematic_reviews_for_user, user_id + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to list systematic reviews: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to list systematic reviews: {e}", + ) for doc in docs: results.append( @@ -268,17 +318,24 @@ async def list_systematic_reviews_for_user( @router.get("/{sr_id}", response_model=SystematicReviewRead) -async def get_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)): +async def get_systematic_review( + sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user) +): """ Get a single systematic review by id. User must be a member to view. """ try: - doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) + doc, screening = await load_sr_and_check( + sr_id, current_user, srdb_service, require_screening=False + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) return SystematicReviewRead( id=doc.get("id"), @@ -308,11 +365,16 @@ async def get_systematic_review_criteria_parsed( """ try: - doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) + doc, screening = await load_sr_and_check( + sr_id, current_user, srdb_service, require_screening=False + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) cp = doc.get("criteria_parsed") or {} return {"criteria_parsed": cp} @@ -334,11 +396,16 @@ async def update_systematic_review_criteria( """ try: - sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check( + sr_id, current_user, srdb_service, require_screening=False + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) # Load YAML criteria criteria_str: Optional[str] = None @@ -352,7 +419,10 @@ async def update_systematic_review_criteria( criteria_str = criteria_yaml if not criteria_str: - raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either criteria_file or criteria_yaml must be provided") + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Either criteria_file or criteria_yaml must be provided", + ) criteria_obj = yaml.safe_load(criteria_str) if criteria_obj is None: @@ -364,18 +434,28 @@ async def update_systematic_review_criteria( ) except yaml.YAMLError as ye: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid YAML provided: {ye}" + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid YAML provided: {ye}", ) except Exception as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) # perform update try: - doc = await run_in_threadpool(srdb_service.update_criteria, sr_id, criteria_obj, criteria_str, current_user.get("id")) + doc = await run_in_threadpool( + srdb_service.update_criteria, + sr_id, + criteria_obj, + criteria_str, + current_user.get("id"), + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update criteria: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to update criteria: {e}", + ) return SystematicReviewRead( id=doc.get("id"), @@ -394,7 +474,9 @@ async def update_systematic_review_criteria( @router.delete("/{sr_id}") -async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)): +async def delete_systematic_review( + sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user) +): """ Soft-delete a systematic review by marking its 'visible' flag as False. @@ -402,28 +484,49 @@ async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = De """ try: - sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False) + sr, screening = await load_sr_and_check( + sr_id, current_user, srdb_service, require_screening=False + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) requester_id = current_user.get("id") if requester_id != sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may delete this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the owner may delete this systematic review", + ) try: - res = await run_in_threadpool(srdb_service.soft_delete_systematic_review, sr_id, requester_id) + res = await run_in_threadpool( + srdb_service.soft_delete_systematic_review, sr_id, requester_id + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to delete systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete systematic review: {e}", + ) - return {"status": "success", "sr_id": sr_id, "deleted": True, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")} + return { + "status": "success", + "sr_id": sr_id, + "deleted": True, + "matched_count": res.get("matched_count"), + "modified_count": res.get("modified_count"), + } @router.post("/{sr_id}/undelete") -async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)): +async def undelete_systematic_review( + sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user) +): """ Undelete (restore) a systematic review by marking its 'visible' flag as True. @@ -431,28 +534,53 @@ async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] = """ try: - sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False) + sr, screening = await load_sr_and_check( + sr_id, + current_user, + srdb_service, + require_screening=False, + require_visible=False, + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) requester_id = current_user.get("id") if requester_id != sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may undelete this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the owner may undelete this systematic review", + ) try: - res = await run_in_threadpool(srdb_service.undelete_systematic_review, sr_id, requester_id) + res = await run_in_threadpool( + srdb_service.undelete_systematic_review, sr_id, requester_id + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to undelete systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to undelete systematic review: {e}", + ) - return {"status": "success", "sr_id": sr_id, "undeleted": True, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")} + return { + "status": "success", + "sr_id": sr_id, + "undeleted": True, + "matched_count": res.get("matched_count"), + "modified_count": res.get("modified_count"), + } @router.delete("/{sr_id}/hard") -async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)): +async def hard_delete_systematic_review( + sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user) +): """ Permanently remove the systematic review document from MongoDB. @@ -463,15 +591,27 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] """ try: - sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False) + sr, screening = await load_sr_and_check( + sr_id, + current_user, + srdb_service, + require_screening=False, + require_visible=False, + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to load systematic review: {e}", + ) requester_id = current_user.get("id") if requester_id != sr.get("owner_id"): - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may hard-delete this systematic review") + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Only the owner may hard-delete this systematic review", + ) # Attempt to perform screening resources cleanup prior to deleting the SR document. cleanup_result = None @@ -492,15 +632,23 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] cleanup_result = {"status": "cleanup_import_failed", "error": str(e)} try: - res = await run_in_threadpool(srdb_service.hard_delete_systematic_review, sr_id, requester_id) + res = await run_in_threadpool( + srdb_service.hard_delete_systematic_review, sr_id, requester_id + ) deleted_count = res.get("deleted_count") if not deleted_count: # If backend reported zero deletions, raise NotFound to match prior behavior - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found during hard delete") + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Systematic review not found during hard delete", + ) except HTTPException: raise except Exception as e: - raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to hard-delete systematic review: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to hard-delete systematic review: {e}", + ) return { "status": "success", diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml index 3cca1cbd..9a1bae1a 100644 --- a/backend/docker-compose.yml +++ b/backend/docker-compose.yml @@ -1,3 +1,19 @@ +# services: +# # ============================================================================= +# # CAN-SR BACKEND API +# # ============================================================================= +# api: +# build: +# context: . +# dockerfile: Dockerfile +# container_name: can-sr-api +# ports: +# - "8000:8000" +# environment: +# - GROBID_SERVICE_URL=http://grobid-service:8070 +# - JAVA_OPTS=-XX:-UseContainerSupport + + services: # ============================================================================= # CAN-SR BACKEND API @@ -29,9 +45,24 @@ services: # ============================================================================= # GROBID SERVICE - PDF Parsing for Full-Text Extraction # ============================================================================= + # grobid-service: + # image: grobid/grobid:0.8.2-crf + # container_name: grobid-service + # ports: + # - "8070:8070" + # - "8081:8081" + # restart: unless-stopped + # healthcheck: + # test: ["CMD", "curl", "-f", "http://localhost:8070/api/isalive"] + # interval: 10s + # timeout: 10s + # retries: 5 + # start_period: 120s grobid-service: - image: grobid/grobid:0.8.2-crf + image: grobid/grobid:0.8.2 container_name: grobid-service + environment: + - JAVA_TOOL_OPTIONS=-XX:+UseContainerSupport ports: - "8070:8070" - "8081:8081" @@ -57,7 +88,7 @@ services: ports: - "5432:5432" volumes: - - ./volumes/postgres:/var/lib/postgresql/data + - ./volumes/postgres:/var/lib/postgresql #/data healthcheck: test: ["CMD-SHELL", "pg_isready -U admin -d postgres -h localhost"] interval: 30s @@ -67,4 +98,4 @@ services: networks: default: - driver: bridge + driver: bridge \ No newline at end of file diff --git a/backend/main.py b/backend/main.py index 606dc55f..244c36f7 100644 --- a/backend/main.py +++ b/backend/main.py @@ -15,7 +15,6 @@ from api.services.sr_db_service import srdb_service from api.services.user_db import user_db_service - app = FastAPI( title=settings.PROJECT_NAME, description=settings.DESCRIPTION, @@ -29,10 +28,12 @@ async def startup_event(): """Startup event - initialize CAN-SR systematic review database""" from fastapi.concurrency import run_in_threadpool import asyncio - + # Reduce Azure SDK HTTP logging noise (especially during polling endpoints). logging.getLogger("azure").setLevel(logging.WARNING) - logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING) + logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel( + logging.WARNING + ) print("🚀 Starting CAN-SR Backend...", flush=True) print("📚 Initializing systematic review database...", flush=True) @@ -61,6 +62,7 @@ async def startup_event(): # Keep Procrastinate open for the whole API lifespan so request handlers # can enqueue jobs. from api.jobs.procrastinate_app import PROCRASTINATE_APP + await PROCRASTINATE_APP.open_async() await ensure_procrastinate_schema() await run_in_threadpool(run_all_repo.ensure_tables) @@ -71,9 +73,14 @@ async def startup_event(): if getattr(settings, "PROCRASTINATE_CLEAR_ON_START", True): try: cleared = await clear_pending_jobs(queues=["default"]) - print(f"🧹 Cleared {cleared} pending Procrastinate jobs", flush=True) + print( + f"🧹 Cleared {cleared} pending Procrastinate jobs", flush=True + ) except Exception as e: - print(f"⚠️ Failed to clear pending Procrastinate jobs: {e}", flush=True) + print( + f"⚠️ Failed to clear pending Procrastinate jobs: {e}", + flush=True, + ) if workers_enabled(): # Run a worker loop inside the API process (dev/quick deploy). @@ -96,8 +103,6 @@ async def shutdown_event(): except Exception: pass - - # Set up CORS cors_origins = ( @@ -105,7 +110,12 @@ async def shutdown_event(): if isinstance(settings.CORS_ORIGINS, str) else [settings.CORS_ORIGINS] ) -app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY, same_site="lax", https_only=settings.IS_DEPLOYED) +app.add_middleware( + SessionMiddleware, + secret_key=settings.SECRET_KEY, + same_site="lax", + https_only=settings.IS_DEPLOYED, +) app.add_middleware( CORSMiddleware, allow_origins=cors_origins, diff --git a/frontend/app/[lang]/can-sr/citations/full-text/route.ts b/frontend/app/[lang]/can-sr/citations/full-text/route.ts index 94912a11..5e587e9e 100644 --- a/frontend/app/[lang]/can-sr/citations/full-text/route.ts +++ b/frontend/app/[lang]/can-sr/citations/full-text/route.ts @@ -95,9 +95,16 @@ export async function GET(request: NextRequest) { const citation = await citationRes.json().catch(() => ({})) // Try common fields: document_id, documentId, storage_path, fulltext_url const documentId = - citation?.document_id || citation?.documentId || citation?.document || null + citation?.document_id || + citation?.documentId || + citation?.document || + null const storagePath = - citation?.storage_path || citation?.storagePath || citation?.fulltext_url || citation?.fulltext || null + citation?.storage_path || + citation?.storagePath || + citation?.fulltext_url || + citation?.fulltext || + null if (documentId) { fileFetchUrl = `${BACKEND_URL}/api/files/documents/${encodeURIComponent( @@ -163,7 +170,10 @@ export async function GET(request: NextRequest) { }) } catch (err: any) { console.error('Fulltext proxy GET error:', err) - return NextResponse.json({ error: 'Internal server error' }, { status: 500 }) + return NextResponse.json( + { error: 'Internal server error' }, + { status: 500 }, + ) } } diff --git a/frontend/app/[lang]/can-sr/extract/page.tsx b/frontend/app/[lang]/can-sr/extract/page.tsx index a15f8773..f31b0e19 100644 --- a/frontend/app/[lang]/can-sr/extract/page.tsx +++ b/frontend/app/[lang]/can-sr/extract/page.tsx @@ -50,7 +50,8 @@ const buildCitationAiCalls: BuildCitationAiCalls = async ({ ) const data = await res.json().catch(() => ({})) const parsed = data?.criteria_parsed || data?.criteria || {} - const paramsInfo: ParametersParsed | null = (parsed?.parameters as any) || null + const paramsInfo: ParametersParsed | null = + (parsed?.parameters as any) || null if (paramsInfo?.categories && paramsInfo?.possible_parameters) { const out: Array<{ name: string; description: string }> = [] @@ -68,7 +69,10 @@ const buildCitationAiCalls: BuildCitationAiCalls = async ({ typeof descs?.[j] === 'string' ? (descs[j] as string) : '' const cleanDesc = rawDesc.replace(/<\/?desc>/g, '') if (rawName && rawName.trim()) { - out.push({ name: rawName.trim(), description: cleanDesc || rawName }) + out.push({ + name: rawName.trim(), + description: cleanDesc || rawName, + }) } }) }) diff --git a/frontend/app/[lang]/can-sr/extract/view/page.tsx b/frontend/app/[lang]/can-sr/extract/view/page.tsx index 120c2609..d0750326 100644 --- a/frontend/app/[lang]/can-sr/extract/view/page.tsx +++ b/frontend/app/[lang]/can-sr/extract/view/page.tsx @@ -5,7 +5,9 @@ import { useState, useEffect, useRef } from 'react' import GCHeader, { SRHeader } from '@/components/can-sr/headers' import { getAuthToken, getTokenType } from '@/lib/auth' import { ModelSelector } from '@/components/chat' -import PDFBoundingBoxViewer, { PDFBoundingBoxViewerHandle } from '@/components/can-sr/PDFBoundingBoxViewer' +import PDFBoundingBoxViewer, { + PDFBoundingBoxViewerHandle, +} from '@/components/can-sr/PDFBoundingBoxViewer' import { ChevronDown, ChevronRight, Wand2 } from 'lucide-react' import { useDictionary } from '@/app/[lang]/DictionaryProvider' @@ -78,7 +80,8 @@ export default function CanSrL2ScreenPage() { return llmCol.replace('llm_param_', 'human_param_') } - const [parametersParsed, setParametersParsed] = useState(null) + const [parametersParsed, setParametersParsed] = + useState(null) const [paramValues, setParamValues] = useState>({}) const [savingParam, setSavingParam] = useState(null) const [saveStatus, setSaveStatus] = useState>({}) @@ -91,7 +94,9 @@ export default function CanSrL2ScreenPage() { const [aiPanels, setAiPanels] = useState>({}) const [panelOpen, setPanelOpen] = useState>({}) const [fulltextCoords, setFulltextCoords] = useState(null) - const [fulltextPages, setFulltextPages] = useState<{ width: number; height: number }[] | null>(null) + const [fulltextPages, setFulltextPages] = useState< + { width: number; height: number }[] | null + >(null) const [fulltextStr, setFulltextStr] = useState(null) const viewerRef = useRef(null) @@ -100,9 +105,14 @@ export default function CanSrL2ScreenPage() { const [fulltextFigures, setFulltextFigures] = useState(null) const scrollToArtifact = (kind: 'table' | 'figure', idx: number) => { - const list = kind === 'table' ? (fulltextTables || []) : (fulltextFigures || []) + const list = kind === 'table' ? fulltextTables || [] : fulltextFigures || [] const item = list.find((x: any) => Number(x?.index) === Number(idx)) - console.log('[artifact-click]', { kind, idx, hasViewer: !!viewerRef.current, item }) + console.log('[artifact-click]', { + kind, + idx, + hasViewer: !!viewerRef.current, + item, + }) if (!item || !viewerRef.current) return const bbox = item?.bounding_box const first = Array.isArray(bbox) ? bbox[0] : null @@ -112,7 +122,10 @@ export default function CanSrL2ScreenPage() { } const [runningAllAI, setRunningAllAI] = useState(false) - const [runAllProgress, setRunAllProgress] = useState<{ done: number; total: number } | null>(null) + const [runAllProgress, setRunAllProgress] = useState<{ + done: number + total: number + } | null>(null) // Cache full text so single-param and run-all don’t repeatedly trigger extraction/DB reads const fullTextCacheRef = useRef(null) @@ -176,11 +189,11 @@ export default function CanSrL2ScreenPage() { const suggestParam = async (name: string, description: string) => { if (!citationId || !srId) return - setAiStatus(prev => ({ ...prev, [name]: 'extracting' })) + setAiStatus((prev) => ({ ...prev, [name]: 'extracting' })) try { const headers = getAuthHeaders() const fullText = await ensureFullText() - setAiStatus(prev => ({ ...prev, [name]: 'suggesting' })) + setAiStatus((prev) => ({ ...prev, [name]: 'suggesting' })) const res = await fetch( `/api/can-sr/extract?action=extract-parameter&sr_id=${encodeURIComponent( srId || '', @@ -201,16 +214,16 @@ export default function CanSrL2ScreenPage() { const data = await res.json().catch(() => ({})) const ext = data?.extraction || data if (res.ok && ext) { - setParamValues(prev => ({ ...prev, [name]: ext?.value ?? '' })) - setParamFound(prev => ({ ...prev, [name]: !!ext?.found })) - setAiPanels(prev => ({ ...prev, [name]: ext })) - setPanelOpen(prev => ({ ...prev, [name]: false })) - setAiStatus(prev => ({ ...prev, [name]: 'suggested' })) + setParamValues((prev) => ({ ...prev, [name]: ext?.value ?? '' })) + setParamFound((prev) => ({ ...prev, [name]: !!ext?.found })) + setAiPanels((prev) => ({ ...prev, [name]: ext })) + setPanelOpen((prev) => ({ ...prev, [name]: false })) + setAiStatus((prev) => ({ ...prev, [name]: 'suggested' })) } else { - setAiStatus(prev => ({ ...prev, [name]: 'error' })) + setAiStatus((prev) => ({ ...prev, [name]: 'error' })) } } catch { - setAiStatus(prev => ({ ...prev, [name]: 'error' })) + setAiStatus((prev) => ({ ...prev, [name]: 'error' })) } } @@ -224,7 +237,11 @@ export default function CanSrL2ScreenPage() { const desc = parametersParsed.descriptions?.[i]?.[j] || '' const cleanDesc = desc.replace(/<\/?desc>/g, '') const paramName = - typeof param === 'string' ? param : Array.isArray(param) ? param[0] : String(param) + typeof param === 'string' + ? param + : Array.isArray(param) + ? param[0] + : String(param) params.push({ name: paramName, description: cleanDesc }) }) }) @@ -258,7 +275,11 @@ export default function CanSrL2ScreenPage() { const data = await res.json().catch(() => ({})) const parsed = data?.criteria_parsed || data?.criteria || {} const paramsInfo = parsed?.parameters - if (paramsInfo && paramsInfo.categories && paramsInfo.possible_parameters) { + if ( + paramsInfo && + paramsInfo.categories && + paramsInfo.possible_parameters + ) { setParametersParsed(paramsInfo) const defaults: Record = {} paramsInfo.possible_parameters.forEach((arr: string[]) => { @@ -266,14 +287,14 @@ export default function CanSrL2ScreenPage() { defaults[name] = defaults[name] || '' }) }) - setParamValues(prev => ({ ...defaults, ...prev })) + setParamValues((prev) => ({ ...defaults, ...prev })) const defaultFound: Record = {} paramsInfo.possible_parameters.forEach((arr: string[]) => { arr.forEach((name: string) => { defaultFound[name] = defaultFound[name] ?? false }) }) - setParamFound(prev => ({ ...defaultFound, ...prev })) + setParamFound((prev) => ({ ...defaultFound, ...prev })) } } catch (err) { console.warn('Failed to load parameters', err) @@ -333,29 +354,39 @@ export default function CanSrL2ScreenPage() { }) if (Object.keys(nextFound).length) { - setParamFound(prev => ({ ...prev, ...nextFound })) + setParamFound((prev) => ({ ...prev, ...nextFound })) } if (Object.keys(nextValues).length) { - setParamValues(prev => ({ ...prev, ...nextValues })) + setParamValues((prev) => ({ ...prev, ...nextValues })) } if (Object.keys(nextAIPanels).length) { - setAiPanels(prev => ({ ...prev, ...nextAIPanels })) + setAiPanels((prev) => ({ ...prev, ...nextAIPanels })) } // extract coords/pages/fulltext and artifacts for PDF overlay - const ft = typeof (row as any).fulltext === 'string' ? (row as any).fulltext : null + const ft = + typeof (row as any).fulltext === 'string' + ? (row as any).fulltext + : null if (ft) setFulltextStr(ft) - const tablesAny = parseJson((row as any).fulltext_tables) ?? (row as any).fulltext_tables + const tablesAny = + parseJson((row as any).fulltext_tables) ?? + (row as any).fulltext_tables if (tablesAny && Array.isArray(tablesAny)) setFulltextTables(tablesAny) - const figsAny = parseJson((row as any).fulltext_figures) ?? (row as any).fulltext_figures + const figsAny = + parseJson((row as any).fulltext_figures) ?? + (row as any).fulltext_figures if (figsAny && Array.isArray(figsAny)) setFulltextFigures(figsAny) - const coordsAny = parseJson((row as any).fulltext_coords) ?? (row as any).fulltext_coords + const coordsAny = + parseJson((row as any).fulltext_coords) ?? + (row as any).fulltext_coords if (coordsAny && Array.isArray(coordsAny)) setFulltextCoords(coordsAny) - const pagesAny = parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages + const pagesAny = + parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages if (pagesAny && Array.isArray(pagesAny)) setFulltextPages(pagesAny) } catch (err) { console.warn('Failed to prefill citation params', err) @@ -387,24 +418,37 @@ export default function CanSrL2ScreenPage() { } } - const ft = typeof (row as any).fulltext === 'string' ? (row as any).fulltext : null + const ft = + typeof (row as any).fulltext === 'string' + ? (row as any).fulltext + : null if (ft) setFulltextStr(ft) - const tablesAny = parseJson((row as any).fulltext_tables) ?? (row as any).fulltext_tables + const tablesAny = + parseJson((row as any).fulltext_tables) ?? + (row as any).fulltext_tables if (tablesAny && Array.isArray(tablesAny)) setFulltextTables(tablesAny) - const figsAny = parseJson((row as any).fulltext_figures) ?? (row as any).fulltext_figures + const figsAny = + parseJson((row as any).fulltext_figures) ?? + (row as any).fulltext_figures if (figsAny && Array.isArray(figsAny)) setFulltextFigures(figsAny) - const coordsAny = parseJson((row as any).fulltext_coords) ?? (row as any).fulltext_coords + const coordsAny = + parseJson((row as any).fulltext_coords) ?? + (row as any).fulltext_coords if (coordsAny && Array.isArray(coordsAny)) setFulltextCoords(coordsAny) - const pagesAny = parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages + const pagesAny = + parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages if (pagesAny && Array.isArray(pagesAny)) setFulltextPages(pagesAny) // If coords/pages are missing, trigger backend extraction to populate them, then refetch const needExtract = - !Array.isArray(coordsAny) || coordsAny.length === 0 || !Array.isArray(pagesAny) || pagesAny.length === 0 + !Array.isArray(coordsAny) || + coordsAny.length === 0 || + !Array.isArray(pagesAny) || + pagesAny.length === 0 if (needExtract) { try { @@ -423,14 +467,23 @@ export default function CanSrL2ScreenPage() { ) const row2 = await res3.json().catch(() => ({})) - const ft2 = typeof (row2 as any).fulltext === 'string' ? (row2 as any).fulltext : null + const ft2 = + typeof (row2 as any).fulltext === 'string' + ? (row2 as any).fulltext + : null if (ft2) setFulltextStr(ft2) - const coordsAny2 = parseJson((row2 as any).fulltext_coords) ?? (row2 as any).fulltext_coords - if (coordsAny2 && Array.isArray(coordsAny2)) setFulltextCoords(coordsAny2) - - const pagesAny2 = parseJson((row2 as any).fulltext_pages) ?? (row2 as any).fulltext_pages - if (pagesAny2 && Array.isArray(pagesAny2)) setFulltextPages(pagesAny2) + const coordsAny2 = + parseJson((row2 as any).fulltext_coords) ?? + (row2 as any).fulltext_coords + if (coordsAny2 && Array.isArray(coordsAny2)) + setFulltextCoords(coordsAny2) + + const pagesAny2 = + parseJson((row2 as any).fulltext_pages) ?? + (row2 as any).fulltext_pages + if (pagesAny2 && Array.isArray(pagesAny2)) + setFulltextPages(pagesAny2) } } catch (err) { console.warn('Failed to extract fulltext for overlay', err) @@ -443,13 +496,13 @@ export default function CanSrL2ScreenPage() { }, [srId, citationId]) const updateValue = (name: string, val: string) => { - setParamValues(prev => ({ ...prev, [name]: val })) + setParamValues((prev) => ({ ...prev, [name]: val })) } const saveParam = async (name: string) => { if (!citationId || !srId) return setSavingParam(name) - setSaveStatus(prev => ({ ...prev, [name]: 'saving' })) + setSaveStatus((prev) => ({ ...prev, [name]: 'saving' })) try { const headers = getAuthHeaders() const res = await fetch( @@ -470,9 +523,9 @@ export default function CanSrL2ScreenPage() { }, ) await res.json().catch(() => ({})) - setSaveStatus(prev => ({ ...prev, [name]: res.ok ? 'saved' : 'error' })) + setSaveStatus((prev) => ({ ...prev, [name]: res.ok ? 'saved' : 'error' })) } catch { - setSaveStatus(prev => ({ ...prev, [name]: 'error' })) + setSaveStatus((prev) => ({ ...prev, [name]: 'error' })) } finally { setSavingParam(null) } @@ -488,7 +541,7 @@ export default function CanSrL2ScreenPage() {
- -
+
{/* Workspace (left) */}
{/*
*/} - {/*
+ {/*

Workspace

This area displays the full text (PDF) and is a flexible workspace for viewing and selecting parameter regions.

*/} - + {/*
*/}
{/* Selection sidebar (right) */}