diff --git a/backend/api/auth/router.py b/backend/api/auth/router.py
index 18555e99..828cfe43 100644
--- a/backend/api/auth/router.py
+++ b/backend/api/auth/router.py
@@ -39,6 +39,7 @@
client_kwargs={"scope": "openid profile email"},
)
+
@router.post("/token", response_model=Token)
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends(),
@@ -83,13 +84,17 @@ async def login(login_data: LoginRequest) -> Any:
return {"access_token": access_token, "token_type": "Bearer"}
+
@router.get("/microsoft-sso")
async def login_microsoft_sso(request: StarletteRequest, lang: str):
"""
Start Microsoft OAuth2 flow by redirecting the user to Microsoft login.
"""
request.session["lang"] = lang
- return await oauth.microsoft.authorize_redirect(request, f"{settings.API_URL}/api/auth/sso-authorize")
+ return await oauth.microsoft.authorize_redirect(
+ request, f"{settings.API_URL}/api/auth/sso-authorize"
+ )
+
@router.get("/sso-authorize")
async def microsoft_authorize(request: StarletteRequest):
@@ -122,6 +127,7 @@ async def microsoft_authorize(request: StarletteRequest):
detail="Error during Microsoft OAuth callback",
)
+
@router.post("/register", response_model=UserRead)
async def register_user(register_data: RegisterRequest) -> Any:
"""
diff --git a/backend/api/citations/router.py b/backend/api/citations/router.py
index 0a1c547d..cfba3ff3 100644
--- a/backend/api/citations/router.py
+++ b/backend/api/citations/router.py
@@ -36,7 +36,12 @@
from ..core.security import get_current_active_user
from ..core.config import settings
-from ..services.cit_db_service import cits_dp_service, snake_case, snake_case_column, parse_dsn
+from ..services.cit_db_service import (
+ cits_dp_service,
+ snake_case,
+ snake_case_column,
+ parse_dsn,
+)
from ..core.cit_utils import load_sr_and_check
router = APIRouter()
@@ -120,7 +125,9 @@ def _join_list(v: Any, sep: str = "; ") -> str:
if v is None:
return ""
if isinstance(v, list):
- return sep.join([str(x) for x in v if x is not None and str(x).strip() != ""]).strip()
+ return sep.join(
+ [str(x) for x in v if x is not None and str(x).strip() != ""]
+ ).strip()
return str(v)
@@ -139,9 +146,15 @@ def _ris_value_for_include(include_col: str, entry: Dict[str, Any]) -> Any:
# Common include names in our configs
if key == "title":
- return entry.get("title") or entry.get("primary_title") or entry.get("short_title")
+ return (
+ entry.get("title") or entry.get("primary_title") or entry.get("short_title")
+ )
if key == "abstract":
- return entry.get("abstract") or entry.get("notes_abstract") or _join_list(entry.get("notes"), "\n")
+ return (
+ entry.get("abstract")
+ or entry.get("notes_abstract")
+ or _join_list(entry.get("notes"), "\n")
+ )
if key == "keywords":
return _join_list(entry.get("keywords"))
if key == "journal":
@@ -173,7 +186,9 @@ def _ris_value_for_include(include_col: str, entry: Dict[str, Any]) -> Any:
return v
-def _parse_citations_csv_bytes(raw_bytes: bytes) -> tuple[list[dict[str, Any]], list[str]]:
+def _parse_citations_csv_bytes(
+ raw_bytes: bytes,
+) -> tuple[list[dict[str, Any]], list[str]]:
try:
text = raw_bytes.decode("utf-8-sig")
except Exception:
@@ -184,7 +199,9 @@ def _parse_citations_csv_bytes(raw_bytes: bytes) -> tuple[list[dict[str, Any]],
return rows, cols
-def _parse_citations_ris_bytes(raw_bytes: bytes, include_columns: list[str]) -> tuple[list[dict[str, Any]], list[str]]:
+def _parse_citations_ris_bytes(
+ raw_bytes: bytes, include_columns: list[str]
+) -> tuple[list[dict[str, Any]], list[str]]:
if not rispy:
raise RuntimeError("RIS upload requested but rispy is not installed")
@@ -231,15 +248,23 @@ def _ris_value_for_canonical(col: str, entry: Dict[str, Any]) -> Any:
key = (col or "").strip().lower()
if key == "title":
- return entry.get("title") or entry.get("primary_title") or entry.get("short_title")
+ return (
+ entry.get("title") or entry.get("primary_title") or entry.get("short_title")
+ )
if key == "abstract":
- return entry.get("abstract") or entry.get("notes_abstract") or _join_list(entry.get("notes"), "\n")
+ return (
+ entry.get("abstract")
+ or entry.get("notes_abstract")
+ or _join_list(entry.get("notes"), "\n")
+ )
if key == "keywords":
return _join_list(entry.get("keywords"))
if key == "journal":
return entry.get("secondary_title") or entry.get("journal_name")
if key == "year":
- return _extract_year(entry.get("year") or entry.get("publication_year") or entry.get("date"))
+ return _extract_year(
+ entry.get("year") or entry.get("publication_year") or entry.get("date")
+ )
if key == "authors":
return _join_list(entry.get("authors"))
if key == "doi":
@@ -254,7 +279,9 @@ def _ris_value_for_canonical(col: str, entry: Dict[str, Any]) -> Any:
return _join_list(v) if isinstance(v, list) else v
-def _parse_citations_ris_bytes_auto(raw_bytes: bytes) -> tuple[list[dict[str, Any]], list[str]]:
+def _parse_citations_ris_bytes_auto(
+ raw_bytes: bytes,
+) -> tuple[list[dict[str, Any]], list[str]]:
"""Parse RIS into a canonical table shape.
This is used when RIS is uploaded before SR criteria/config exists.
@@ -285,19 +312,28 @@ def _parse_citations_ris_bytes_auto(raw_bytes: bytes) -> tuple[list[dict[str, An
return rows, cols
-def _load_include_columns_from_criteria(sr_doc: Optional[Dict[str, Any]] = None) -> List[str]:
+
+def _load_include_columns_from_criteria(
+ sr_doc: Optional[Dict[str, Any]] = None,
+) -> List[str]:
# Delegate to consolidated postgres service
try:
return cits_dp_service.load_include_columns_from_criteria(sr_doc)
except Exception:
return []
+
+
def _parse_dsn(dsn: str) -> Dict[str, str]:
# Delegate to consolidated postgres service
try:
return parse_dsn(dsn)
except Exception:
return {}
-def _create_table_and_insert_sync(table_name: str, columns: List[str], rows: List[Dict[str, Any]]) -> int:
+
+
+def _create_table_and_insert_sync(
+ table_name: str, columns: List[str], rows: List[Dict[str, Any]]
+) -> int:
return cits_dp_service.create_table_and_insert_sync(table_name, columns, rows)
@@ -311,7 +347,9 @@ async def _upload_screening_citations_impl(
"""Shared implementation for citations upload (CSV or RIS)."""
try:
- sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False)
+ sr, screening = await load_sr_and_check(
+ sr_id, current_user, srdb_service, require_screening=False
+ )
except HTTPException:
raise
except Exception as e:
@@ -328,19 +366,29 @@ async def _upload_screening_citations_impl(
)
if not file or not file.filename:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="File is required")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="File is required"
+ )
# Read bytes once (UploadFile stream is not reliably seekable after read)
try:
raw = await file.read()
except Exception as e:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to read upload: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Failed to read upload: {e}",
+ )
if not raw:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Uploaded file is empty")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Uploaded file is empty"
+ )
fmt = (force_format or _sniff_citations_format(file.filename, raw)).lower()
if fmt not in ("csv", "ris"):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported citations format: {fmt}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Unsupported citations format: {fmt}",
+ )
# Parse into normalized rows + include columns
try:
@@ -352,10 +400,16 @@ async def _upload_screening_citations_impl(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Failed to parse {fmt.upper()}: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Failed to parse {fmt.upper()}: {e}",
+ )
if len(include_columns) == 0:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No columns found to create screening table")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="No columns found to create screening table",
+ )
# Build a unique table name for this upload
safe_sr = re.sub(r"[^0-9a-zA-Z_]", "_", sr_id)
@@ -372,11 +426,18 @@ async def _upload_screening_citations_impl(
# Create table and insert rows in threadpool
try:
- inserted = await run_in_threadpool(_create_table_and_insert_sync, table_name, include_columns, normalized_rows)
+ inserted = await run_in_threadpool(
+ _create_table_and_insert_sync, table_name, include_columns, normalized_rows
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create table or insert rows: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to create table or insert rows: {e}",
+ )
# Save DB connection metadata into SR Mongo doc
try:
@@ -436,7 +497,9 @@ async def upload_screening_citations(
@router.get("/{sr_id}/citations")
async def list_citation_ids(
- sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user), filter_step: Optional[str] = None,
+ sr_id: str,
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
+ filter_step: Optional[str] = None,
):
"""
List all citation ids for the systematic review's screening database.
@@ -448,7 +511,10 @@ async def list_citation_ids(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
if not screening:
return {"citation_ids": []}
@@ -459,22 +525,31 @@ async def list_citation_ids(
# Validation strategy: UI filters by human_l1_decision / human_l2_decision.
try:
cp = (sr or {}).get("criteria_parsed") or (sr or {}).get("criteria") or {}
- await run_in_threadpool(cits_dp_service.backfill_human_decisions, cp, table_name)
+ await run_in_threadpool(
+ cits_dp_service.backfill_human_decisions, cp, table_name
+ )
except Exception:
# best-effort; listing should still work even if backfill fails
pass
try:
- ids = await run_in_threadpool(cits_dp_service.list_citation_ids, filter_step, table_name)
+ ids = await run_in_threadpool(
+ cits_dp_service.list_citation_ids, filter_step, table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
# If the SR points at a screening table that no longer exists (e.g. dropped),
# treat it as "no citations" instead of poisoning the shared connection.
if _is_undefined_table_error(e):
return {"citation_ids": []}
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
return {"citation_ids": ids}
@@ -515,7 +590,9 @@ async def get_citations_batch(
# Parse ids
raw_ids = [p.strip() for p in (ids or "").split(",") if p.strip()]
if not raw_ids:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="ids is required")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="ids is required"
+ )
parsed_ids: List[int] = []
for p in raw_ids:
try:
@@ -523,7 +600,10 @@ async def get_citations_batch(
except Exception:
continue
if not parsed_ids:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="ids must be a comma-separated list of integers")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="ids must be a comma-separated list of integers",
+ )
parsed_fields: Optional[List[str]] = None
if fields is not None:
@@ -538,9 +618,14 @@ async def get_citations_batch(
parsed_fields,
)
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
return {"citations": rows}
@@ -566,19 +651,31 @@ async def get_citation_by_id(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
table_name = (screening or {}).get("table_name") or "citations"
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
return row
@@ -588,7 +685,9 @@ class CombinedRequest(BaseModel):
include_columns: Optional[List[str]] = None
-def _build_combined_citation_from_row(row: Dict[str, Any], include_columns: List[str]) -> str:
+def _build_combined_citation_from_row(
+ row: Dict[str, Any], include_columns: List[str]
+) -> str:
# Delegate to consolidated postgres service
return cits_dp_service.build_combined_citation_from_row(row, include_columns)
@@ -618,22 +717,37 @@ async def build_combined_citation(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
if not screening:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="No screening database configured for this systematic review",
+ )
table_name = (screening or {}).get("table_name") or "citations"
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
# As a final fallback, use all columns present in include columns
data = []
@@ -672,17 +786,25 @@ async def upload_citation_fulltext(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
# Validate file
if not file or not file.filename:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="File is required")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="File is required"
+ )
_, ext = os.path.splitext(file.filename)
ext = ext.lower()
# Prefer PDFs but allow other types if necessary; restrict to .pdf here
if ext != ".pdf":
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only PDF files are accepted for full text upload")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Only PDF files are accepted for full text upload",
+ )
content = await file.read()
new_md5 = hashlib.md5(content).hexdigest() if content is not None else ""
@@ -690,14 +812,23 @@ async def upload_citation_fulltext(
table_name = (screening or {}).get("table_name") or "citations"
try:
- existing_row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ existing_row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not existing_row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
# If the PDF changed (md5 differs), clear L2 screening answers + parameter extractions + fulltext artifacts.
# (Do NOT clear L1 answers.)
@@ -740,9 +871,15 @@ async def upload_citation_fulltext(
# NOTE (validation): we do not use l2_screen for filtering; keep it untouched/non-authoritative.
cols_to_clear = ["llm_l2_decision", "human_l2_decision"]
try:
- cp = (sr or {}).get("criteria_parsed") or (sr or {}).get("criteria") or {}
+ cp = (
+ (sr or {}).get("criteria_parsed")
+ or (sr or {}).get("criteria")
+ or {}
+ )
l2 = cp.get("l2") if isinstance(cp, dict) else None
- l2_questions = (l2 or {}).get("questions") if isinstance(l2, dict) else None
+ l2_questions = (
+ (l2 or {}).get("questions") if isinstance(l2, dict) else None
+ )
if isinstance(l2_questions, list):
for q in l2_questions:
try:
@@ -763,7 +900,9 @@ async def upload_citation_fulltext(
# best-effort
pass
- await run_in_threadpool(cits_dp_service.clear_columns, citation_id, cols_to_clear, table_name)
+ await run_in_threadpool(
+ cits_dp_service.clear_columns, citation_id, cols_to_clear, table_name
+ )
except Exception:
# Best-effort; do not block upload.
pass
@@ -775,7 +914,10 @@ async def upload_citation_fulltext(
storage_service = None
if not storage_service:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Storage service not available")
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+ detail="Storage service not available",
+ )
# Upload file for the current user
document_id = await storage_service.upload_user_document(
@@ -785,7 +927,10 @@ async def upload_citation_fulltext(
)
if not document_id:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to upload file to storage service")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Failed to upload file to storage service",
+ )
# Build storage path (container + blob name) so it can be stored in Postgres
# Note: storage.py stores blobs at users/{user_id}/documents/{doc_id}_{filename}
blob_name = f"users/{current_user['id']}/documents/{document_id}_{file.filename}"
@@ -794,19 +939,35 @@ async def upload_citation_fulltext(
# Update citation row in Postgres
try:
- updated = await run_in_threadpool(cits_dp_service.attach_fulltext, citation_id, storage_path, content, table_name)
+ updated = await run_in_threadpool(
+ cits_dp_service.attach_fulltext,
+ citation_id,
+ storage_path,
+ content,
+ table_name,
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update citation row: {e}",
+ )
if not updated:
# If the citation id doesn't exist, consider rolling back the uploaded file (best effort)
# Attempt to delete the uploaded blob (best-effort; not fatal if it fails)
try:
- await storage_service.delete_user_document(current_user["id"], document_id, file.filename)
+ await storage_service.delete_user_document(
+ current_user["id"], document_id, file.filename
+ )
except Exception:
pass
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to attach fulltext file")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Citation not found to attach fulltext file",
+ )
return {
"status": "success",
"sr_id": sr_id,
@@ -820,7 +981,9 @@ async def upload_citation_fulltext(
# Helper to drop a database - delegated to backend.api.core.postgres.drop_database
-async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, Any]) -> Dict[str, Any]:
+async def hard_delete_screening_resources(
+ sr_id: str, current_user: Dict[str, Any]
+) -> Dict[str, Any]:
"""
Delete the screening Postgres table and all associated fulltext files
for the given systematic review.
@@ -837,26 +1000,47 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
requester_id = current_user.get("id")
if requester_id != sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may perform screening cleanup for this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only the owner may perform screening cleanup for this systematic review",
+ )
if not screening:
- return {"status": "no_screening_db", "message": "No screening table configured for this SR", "deleted_table": False, "deleted_files": 0}
+ return {
+ "status": "no_screening_db",
+ "message": "No screening table configured for this SR",
+ "deleted_table": False,
+ "deleted_files": 0,
+ }
table_name = screening.get("table_name")
if not table_name:
- return {"status": "no_screening_db", "message": "Incomplete screening DB metadata", "deleted_table": False, "deleted_files": 0}
+ return {
+ "status": "no_screening_db",
+ "message": "Incomplete screening DB metadata",
+ "deleted_table": False,
+ "deleted_files": 0,
+ }
# 1) collect fulltext URLs from the screening DB
try:
urls = await run_in_threadpool(cits_dp_service.list_fulltext_urls, table_name)
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB for fulltext URLs: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB for fulltext URLs: {e}",
+ )
# 2) delete blobs for each url (best-effort)
deleted_files = 0
@@ -888,13 +1072,17 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An
# expect ["users", user_id, "documents", "{doc_id}_{filename}"]
if len(parts) >= 4 and parts[2] == "documents":
user_id = parts[1]
- doc_part = "/".join(parts[3:]) # handle any extra slashes in filename
+ doc_part = "/".join(
+ parts[3:]
+ ) # handle any extra slashes in filename
# split first underscore to get doc_id and filename
if "_" in doc_part:
doc_id, filename = doc_part.split("_", 1)
# call delete_user_document (async)
try:
- ok = await storage_service.delete_user_document(user_id, doc_id, filename)
+ ok = await storage_service.delete_user_document(
+ user_id, doc_id, filename
+ )
if ok:
deleted_files += 1
else:
@@ -927,16 +1115,18 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An
await run_in_threadpool(cits_dp_service.drop_table, table_name)
table_dropped = True
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to drop screening table: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to drop screening table: {e}",
+ )
# 4) remove screening_db metadata from SR document
try:
- await run_in_threadpool(
- srdb_service.clear_screening_db_info,
- sr_id
- )
+ await run_in_threadpool(srdb_service.clear_screening_db_info, sr_id)
except Exception:
# non-fatal, but report it
pass
@@ -952,7 +1142,9 @@ async def hard_delete_screening_resources(sr_id: str, current_user: Dict[str, An
# Optional endpoint to trigger the cleanup directly
@router.post("/{sr_id}/hard-clean")
-async def hard_clean_screening_endpoint(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)):
+async def hard_clean_screening_endpoint(
+ sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)
+):
result = await hard_delete_screening_resources(sr_id, current_user)
return result
@@ -971,9 +1163,7 @@ async def export_citations_csv(
"""
try:
- sr, screening = await load_sr_and_check(
- sr_id, current_user, srdb_service
- )
+ sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service)
except HTTPException:
raise
except Exception as e:
@@ -992,9 +1182,13 @@ async def export_citations_csv(
try:
# Validation-friendly export: exclude fulltext/artifacts and flatten JSON columns.
- csv_bytes = await run_in_threadpool(cits_dp_service.dump_citations_csv_filtered, table_name)
+ csv_bytes = await run_in_threadpool(
+ cits_dp_service.dump_citations_csv_filtered, table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
diff --git a/backend/api/core/bounding_box_matcher.py b/backend/api/core/bounding_box_matcher.py
index a0fe5e4d..dfe7279e 100644
--- a/backend/api/core/bounding_box_matcher.py
+++ b/backend/api/core/bounding_box_matcher.py
@@ -432,4 +432,4 @@ def match_figure_references_to_bounding_boxes(
print(
f"[FIGURE_MATCHING] Enhanced {len(enhanced_references)} references with figure information"
)
- return enhanced_references
\ No newline at end of file
+ return enhanced_references
diff --git a/backend/api/core/cit_utils.py b/backend/api/core/cit_utils.py
index 1e33ce13..39282b30 100644
--- a/backend/api/core/cit_utils.py
+++ b/backend/api/core/cit_utils.py
@@ -9,6 +9,7 @@
Routers should call load_sr_and_check(...) to avoid duplicating this logic.
"""
+
from typing import Any, Dict, Optional, Tuple
from fastapi import HTTPException, status
from fastapi.concurrency import run_in_threadpool
@@ -65,34 +66,53 @@ async def load_sr_and_check(
# fetch SR
try:
- sr = await run_in_threadpool(srdb_service.get_systematic_review, sr_id, not require_visible)
+ sr = await run_in_threadpool(
+ srdb_service.get_systematic_review, sr_id, not require_visible
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to fetch systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to fetch systematic review: {e}",
+ )
if not sr:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found"
+ )
if require_visible and not sr.get("visible", True):
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found"
+ )
# permission check (user must be member or owner)
user_id = current_user.get("email")
try:
- has_perm = await run_in_threadpool(srdb_service.user_has_sr_permission, sr_id, user_id)
+ has_perm = await run_in_threadpool(
+ srdb_service.user_has_sr_permission, sr_id, user_id
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to check permissions: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to check permissions: {e}",
+ )
if not has_perm:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to view/modify this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Not authorized to view/modify this systematic review",
+ )
screening = sr.get("screening_db") if isinstance(sr, dict) else None
if require_screening:
if not screening:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No screening database configured for this systematic review")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="No screening database configured for this systematic review",
+ )
return sr, screening
diff --git a/backend/api/core/config.py b/backend/api/core/config.py
index 1e710ad5..1aa5c2f2 100644
--- a/backend/api/core/config.py
+++ b/backend/api/core/config.py
@@ -96,17 +96,25 @@ def convert_max_file_size(cls, v):
# Background jobs (Procrastinate)
# ---------------------------------------------------------------------
# Enable Procrastinate-backed background jobs (required for /api/jobs/*)
- ENABLE_PROCRASTINATE: bool = os.getenv("ENABLE_PROCRASTINATE", "false").lower().strip() == "true"
+ ENABLE_PROCRASTINATE: bool = (
+ os.getenv("ENABLE_PROCRASTINATE", "false").lower().strip() == "true"
+ )
# Embedded worker loop (dev-friendly). Keep this OFF by default.
- ENABLE_PROCRASTINATE_WORKER: bool = os.getenv("ENABLE_PROCRASTINATE_WORKER", "false").lower().strip() == "true"
+ ENABLE_PROCRASTINATE_WORKER: bool = (
+ os.getenv("ENABLE_PROCRASTINATE_WORKER", "false").lower().strip() == "true"
+ )
# Worker concurrency (only used when ENABLE_PROCRASTINATE_WORKER=true)
- PROCRASTINATE_WORKER_CONCURRENCY: int = int(os.getenv("PROCRASTINATE_WORKER_CONCURRENCY", "1"))
+ PROCRASTINATE_WORKER_CONCURRENCY: int = int(
+ os.getenv("PROCRASTINATE_WORKER_CONCURRENCY", "1")
+ )
# Optional dev cleanup: clear out leftover queued/doing tasks on API startup.
# IMPORTANT: default is true if the env var is absent.
- PROCRASTINATE_CLEAR_ON_START: bool = os.getenv("PROCRASTINATE_CLEAR_ON_START", "true").lower().strip() == "true"
+ PROCRASTINATE_CLEAR_ON_START: bool = (
+ os.getenv("PROCRASTINATE_CLEAR_ON_START", "true").lower().strip() == "true"
+ )
# Run-All job chunk size: citations per Procrastinate chunk task.
# Larger values reduce overhead but can reduce fairness/responsiveness.
@@ -123,7 +131,9 @@ def convert_max_file_size(cls, v):
# -------------------------------------------------------------------------
# Select primary Postgres mode.
# docker/local use password auth; azure uses Entra token auth.
- POSTGRES_MODE: str = os.getenv("POSTGRES_MODE", "docker").lower().strip() # docker|local|azure
+ POSTGRES_MODE: str = (
+ os.getenv("POSTGRES_MODE", "docker").lower().strip()
+ ) # docker|local|azure
# Canonical Postgres connection settings (single profile; values vary by environment)
POSTGRES_HOST: Optional[str] = os.getenv("POSTGRES_HOST")
@@ -153,7 +163,9 @@ def postgres_profile(self, mode: Optional[str] = None) -> dict:
raise ValueError("POSTGRES_MODE must be one of: docker, local, azure")
# Provide sensible defaults for host depending on mode.
- default_host = "pgdb-service" if m == "docker" else "localhost" if m == "local" else None
+ default_host = (
+ "pgdb-service" if m == "docker" else "localhost" if m == "local" else None
+ )
prof = {
"mode": m,
@@ -179,7 +191,7 @@ def has_local_fallback(self) -> bool:
JOB_ID_PUBMED: str = os.getenv("JOB_ID_PUBMED", "")
JOB_ID_SCOPUS: str = os.getenv("JOB_ID_SCOPUS", "")
- #search function
+ # search function
ENTREZ_EMAIL: str = os.getenv("ENTREZ_EMAIL")
ENTREZ_API_KEY: str = os.getenv("ENTREZ_API_KEY")
diff --git a/backend/api/core/security.py b/backend/api/core/security.py
index 19e38e20..cf3245c8 100644
--- a/backend/api/core/security.py
+++ b/backend/api/core/security.py
@@ -90,7 +90,9 @@ async def update_user(user_id: str, user_in: UserUpdate) -> Optional[UserRead]:
return UserRead.model_validate(updated_user_data)
-async def authenticate_user(email: str, password: str, sso: bool = False) -> Optional[Dict[str, Any]]:
+async def authenticate_user(
+ email: str, password: str, sso: bool = False
+) -> Optional[Dict[str, Any]]:
"""Authenticate a user"""
email = email.lower()
if not user_db_service:
diff --git a/backend/api/database_search/router.py b/backend/api/database_search/router.py
index a5b4e79f..9daf1671 100644
--- a/backend/api/database_search/router.py
+++ b/backend/api/database_search/router.py
@@ -4,8 +4,12 @@
from pydantic import BaseModel, Field
from ..core.security import get_current_active_user
from ..core.config import settings
-from ..services.citation_search.pubmed_citation_collection import PubMedCitationCollector
-from ..services.citation_search.europePMC_citation_collection import EuropePMCCitationCollector
+from ..services.citation_search.pubmed_citation_collection import (
+ PubMedCitationCollector,
+)
+from ..services.citation_search.europePMC_citation_collection import (
+ EuropePMCCitationCollector,
+)
from ..services.citation_search.scopus_citation_collection import ScopusDataProcessor
import logging
@@ -18,12 +22,13 @@
import datetime
+
def azure_client(container_name: str):
connection_string = settings.AZURE_STORAGE_CONNECTION_STRING
blob_service_client = BlobServiceClient.from_connection_string(connection_string)
blob_client = blob_service_client.get_container_client(container_name)
return blob_client
-
+
def read_blob_to_df(blob_name: str, container_name: str) -> pd.DataFrame:
"""Read a blob (Parquet or CSV) into a pandas DataFrame."""
@@ -36,32 +41,40 @@ def read_blob_to_df(blob_name: str, container_name: str) -> pd.DataFrame:
def write_df_to_blob(df: pd.DataFrame, blob_name: str, container_name: str):
"""Write a pandas DataFrame to blob as Parquet."""
buffer = BytesIO()
- df.to_csv(buffer, index=False, encoding='utf-8')
+ df.to_csv(buffer, index=False, encoding="utf-8")
buffer.seek(0)
blob_client = azure_client(container_name).get_blob_client(blob_name)
blob_client.upload_blob(buffer, overwrite=True)
+
router = APIRouter()
+
class SearchRequest(BaseModel):
database: str = Field(..., description="database selected for search")
- search_term: Optional[str] = Field("", description="search string for database search")
+ search_term: Optional[str] = Field(
+ "", description="search string for database search"
+ )
@router.post("/{sr_id}/search")
async def database_search(
- sr_id: str, payload: SearchRequest, current_user: Dict[str, Any] = Depends(get_current_active_user),
+ sr_id: str,
+ payload: SearchRequest,
+ current_user: Dict[str, Any] = Depends(get_current_active_user),
):
-
+
MAX_ARTICLES = 1000
MINDATE = None
MAXDATE = None
database = payload.database
SEARCH_TERM = payload.search_term
- if not SEARCH_TERM :
+ if not SEARCH_TERM:
SEARCH_TERM = '(("epidemiological parameters"[Title/Abstract] OR "incidence"[MeSH Terms]))'
- logging.info(f"{database} search function started at {datetime.datetime.now()}", )
+ logging.info(
+ f"{database} search function started at {datetime.datetime.now()}",
+ )
if database == "Pubmed":
collector = PubMedCitationCollector()
@@ -71,7 +84,7 @@ async def database_search(
search_term=SEARCH_TERM,
mindate=MINDATE,
maxdate=MAXDATE,
- max_articles=MAX_ARTICLES
+ max_articles=MAX_ARTICLES,
)
except Exception as e:
logging.error(f"Error collecting citations: {str(e)}")
@@ -81,58 +94,74 @@ async def database_search(
elif database == "EuropePMC":
collector = EuropePMCCitationCollector()
- try:
- citations = collector.collect_citations(search_term = SEARCH_TERM, max_articles=MAX_ARTICLES)
+ try:
+ citations = collector.collect_citations(
+ search_term=SEARCH_TERM, max_articles=MAX_ARTICLES
+ )
except Exception as e:
logging.error(f"Error collecting citations: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to collect citations")
content = pd.DataFrame(citations)
elif database == "Scopus":
-
-
- SCOPUS_API_KEY = settings.SCOPUS_API_KEY or None #currently don't have
+
+ SCOPUS_API_KEY = settings.SCOPUS_API_KEY or None # currently don't have
API_URL = settings.SCOPUS_API_URL
if not SCOPUS_API_KEY or not API_URL:
- logging.error("Missing SCOPUS_API_KEY or SCOPUS_BASE_URL in environment settings.")
- raise HTTPException(status_code=500, detail="Missing SCOPUS_API_KEY or SCOPUS_BASE_URL")
+ logging.error(
+ "Missing SCOPUS_API_KEY or SCOPUS_BASE_URL in environment settings."
+ )
+ raise HTTPException(
+ status_code=500, detail="Missing SCOPUS_API_KEY or SCOPUS_BASE_URL"
+ )
processor = ScopusDataProcessor(SCOPUS_API_KEY, API_URL)
processor.consume_api(SEARCH_TERM, delay=15)
content = pd.DataFrame.from_dict(processor.data)
-
blob_name = f"{database}-{datetime.datetime.now().strftime('%Y-%m-%d')}.csv"
container_name = f"citation-data/bronze-data/{database}/archive"
try:
write_df_to_blob(content, blob_name, container_name)
logging.info(f"Uploaded blob '{blob_name}' to container '{container_name}'")
except AzureError as e:
- logging.error(f"Azure Blob Storage error: {e.message if hasattr(e, 'message') else str(e)}")
+ logging.error(
+ f"Azure Blob Storage error: {e.message if hasattr(e, 'message') else str(e)}"
+ )
raise HTTPException(status_code=500, detail="Failed to write to Azure Blob")
except Exception as ex:
logging.error(f"Unexpected error: {str(ex)}")
- raise HTTPException(status_code=500, detail="Unexpected error while writing to blob")
+ raise HTTPException(
+ status_code=500, detail="Unexpected error while writing to blob"
+ )
try:
- all_citations_df = read_blob_to_df(f"{database}-all-citations.csv", f"citation-data/bronze-data/{database}")
+ all_citations_df = read_blob_to_df(
+ f"{database}-all-citations.csv", f"citation-data/bronze-data/{database}"
+ )
except Exception:
- all_citations_df = content.iloc[0:0].copy()
+ all_citations_df = content.iloc[0:0].copy()
- new_all_citations = (
- pd.concat([all_citations_df, content], ignore_index=True)
- .drop_duplicates(subset="pmid", keep="first")
- )
+ new_all_citations = pd.concat(
+ [all_citations_df, content], ignore_index=True
+ ).drop_duplicates(subset="pmid", keep="first")
new_citations = content[~content["pmid"].isin(all_citations_df["pmid"])]
if new_citations.empty:
- logging.info('No new citations on %s', datetime.datetime.now())
+ logging.info("No new citations on %s", datetime.datetime.now())
return {"message": "No new citations", "new_count": 0}
- logging.info('New citations found: %s', len(new_citations))
+ logging.info("New citations found: %s", len(new_citations))
- write_df_to_blob(new_all_citations, f"{database}-all-citations.csv", f"citation-data/bronze-data/{database}")
+ write_df_to_blob(
+ new_all_citations,
+ f"{database}-all-citations.csv",
+ f"citation-data/bronze-data/{database}",
+ )
write_df_to_blob(new_citations, blob_name, "citation-deduplicate/to-process")
- return {"message": f"{database} collection completed", "new_citations_count": len(new_citations)}
\ No newline at end of file
+ return {
+ "message": f"{database} collection completed",
+ "new_citations_count": len(new_citations),
+ }
diff --git a/backend/api/extract/router.py b/backend/api/extract/router.py
index 7e305fc0..3cb968d4 100644
--- a/backend/api/extract/router.py
+++ b/backend/api/extract/router.py
@@ -33,40 +33,55 @@
# Import consolidated Postgres helpers if available (optional)
from ..services.cit_db_service import cits_dp_service, snake_case_param
-
-
router = APIRouter()
class ParameterExtractRequest(BaseModel):
fulltext: Optional[str] = Field(
None,
- description="Full text with numbered sentences (e.g. '[0] First sentence\\n[1] Second sentence'). If omitted the endpoint will try to read fulltext_url from the screening DB row."
+ description="Full text with numbered sentences (e.g. '[0] First sentence\\n[1] Second sentence'). If omitted the endpoint will try to read fulltext_url from the screening DB row.",
+ )
+ parameter_name: str = Field(
+ ..., description="Short name for the parameter (used as column name slug)"
+ )
+ parameter_description: str = Field(
+ ..., description="Human-friendly description of what to extract"
)
- parameter_name: str = Field(..., description="Short name for the parameter (used as column name slug)")
- parameter_description: str = Field(..., description="Human-friendly description of what to extract")
model: Optional[str] = Field(None, description="Model to use")
temperature: Optional[float] = Field(0.0, ge=0.0, le=1.0)
max_tokens: Optional[int] = Field(512, ge=1, le=4000)
# Optional artifacts context (if omitted, server will read from citation row when available)
- tables: Optional[str] = Field(None, description="Optional numbered tables text (markdown).")
- figures: Optional[str] = Field(None, description="Optional numbered figure captions text.")
- attach_figures: Optional[bool] = Field(True, description="If true, attach figure images to the LLM request when available")
-
-
+ tables: Optional[str] = Field(
+ None, description="Optional numbered tables text (markdown)."
+ )
+ figures: Optional[str] = Field(
+ None, description="Optional numbered figure captions text."
+ )
+ attach_figures: Optional[bool] = Field(
+ True,
+ description="If true, attach figure images to the LLM request when available",
+ )
class HumanParameterRequest(BaseModel):
fulltext: Optional[str] = Field(
None,
- description="Optional numbered full text. If omitted the server will try to read fulltext_url from the screening DB row."
+ description="Optional numbered full text. If omitted the server will try to read fulltext_url from the screening DB row.",
+ )
+ parameter_name: str = Field(
+ ..., description="Short name for the parameter (used as column name slug)"
)
- parameter_name: str = Field(..., description="Short name for the parameter (used as column name slug)")
found: bool = Field(..., description="Whether the parameter was found (boolean)")
- value: Optional[str] = Field(None, description="Human-provided value (string) or null")
- explanation: Optional[str] = Field("", description="Optional explanation from the human reviewer")
- evidence_sentences: Optional[List[int]] = Field(None, description="Optional list of evidence sentence indices")
+ value: Optional[str] = Field(
+ None, description="Human-provided value (string) or null"
+ )
+ explanation: Optional[str] = Field(
+ "", description="Optional explanation from the human reviewer"
+ )
+ evidence_sentences: Optional[List[int]] = Field(
+ None, description="Optional list of evidence sentence indices"
+ )
reviewer: Optional[str] = Field(None, description="Optional reviewer id or name")
@@ -98,7 +113,10 @@ async def extract_parameter_endpoint(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
table_name = (screening or {}).get("table_name") or "citations"
@@ -107,18 +125,30 @@ async def extract_parameter_endpoint(
row = None
if not fulltext:
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
fulltext = row.get("fulltext") if "fulltext" in row else None
if not fulltext:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Full text not provided and not available for this citation")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Full text not provided and not available for this citation",
+ )
# Build tables/figures context: prefer payload fields, else fetch from DB row (if loaded)
tables_text = payload.tables
@@ -127,7 +157,9 @@ async def extract_parameter_endpoint(
if (tables_text is None or figures_text is None) and row is None:
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except Exception:
row = None
@@ -153,7 +185,9 @@ async def extract_parameter_endpoint(
try:
md_bytes, _ = await storage_service.get_bytes_by_path(blob_addr)
md_txt = md_bytes.decode("utf-8", errors="replace")
- header = f"Table [T{idx}]" + (f" caption: {caption}" if caption else "")
+ header = f"Table [T{idx}]" + (
+ f" caption: {caption}" if caption else ""
+ )
tables_md_lines.extend([header, md_txt, ""])
except Exception:
continue
@@ -182,7 +216,9 @@ async def extract_parameter_endpoint(
)
if payload.attach_figures:
try:
- img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr)
+ img_bytes, _ = await storage_service.get_bytes_by_path(
+ blob_addr
+ )
if img_bytes:
images.append((img_bytes, "image/png"))
except Exception:
@@ -202,7 +238,10 @@ async def extract_parameter_endpoint(
)
if not azure_openai_client.is_configured():
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Azure OpenAI client is not configured on the server")
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+ detail="Azure OpenAI client is not configured on the server",
+ )
try:
if images:
@@ -223,7 +262,10 @@ async def extract_parameter_endpoint(
temperature=payload.temperature or 0.0,
)
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"LLM call failed: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"LLM call failed: {e}",
+ )
# Parse JSON with robustness: tolerate code fences or preamble text
def _extract_json_object(text: str) -> Optional[str]:
@@ -244,7 +286,7 @@ def _extract_json_object(text: str) -> Optional[str]:
elif ch == "}":
depth -= 1
if depth == 0:
- return t[start:i+1]
+ return t[start : i + 1]
return None
parsed = None
@@ -258,10 +300,16 @@ def _extract_json_object(text: str) -> Optional[str]:
except Exception:
parsed = None
if parsed is None:
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response was not valid JSON: {llm_response[:1000]}")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail=f"LLM response was not valid JSON: {llm_response[:1000]}",
+ )
if not isinstance(parsed, dict):
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM response JSON was not an object")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail="LLM response JSON was not an object",
+ )
# Normalize and validate keys and types (tolerate minor deviations)
found_raw = parsed.get("found", None)
@@ -272,7 +320,10 @@ def _extract_json_object(text: str) -> Optional[str]:
elif isinstance(found_raw, (int, float)):
found_val = bool(found_raw)
else:
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="LLM JSON missing or invalid 'found' key")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail="LLM JSON missing or invalid 'found' key",
+ )
# value may be string or null; coerce common primitives to string
val = parsed.get("value")
@@ -283,7 +334,10 @@ def _extract_json_object(text: str) -> Optional[str]:
try:
val = str(val)
except Exception:
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="'value' must be a string or null")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail="'value' must be a string or null",
+ )
explanation = parsed.get("explanation") or ""
if not isinstance(explanation, str):
@@ -307,7 +361,10 @@ def _extract_json_object(text: str) -> Optional[str]:
# skip unsupported types
continue
else:
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="'evidence_sentences' must be a list")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail="'evidence_sentences' must be a list",
+ )
# Normalize evidence tables/figures
def _norm_int_list(v: Any) -> List[int]:
@@ -351,14 +408,27 @@ def _norm_int_list(v: Any) -> List[int]:
col_name = snake_case_param(payload.parameter_name)
try:
- updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name)
+ updated = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ col_name,
+ stored,
+ table_name,
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update citation row: {e}",
+ )
if not updated:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update"
+ )
# Auto-fill human_param_* from llm_param_* if missing (never overwrite)
try:
@@ -380,7 +450,13 @@ def _norm_int_list(v: Any) -> List[int]:
except Exception:
pass
- return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "extraction": stored}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "citation_id": citation_id,
+ "column": col_name,
+ "extraction": stored,
+ }
@router.post("/{sr_id}/citations/{citation_id}/human-extract-parameter")
@@ -403,20 +479,32 @@ async def human_extract_parameter(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
table_name = (screening or {}).get("table_name") or "citations"
# Ensure citation exists (we won't require full_text for human input but check row presence)
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
# Normalize value
val = payload.value
@@ -424,12 +512,18 @@ async def human_extract_parameter(
try:
val = str(val)
except Exception:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="'value' must be a string or null")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="'value' must be a string or null",
+ )
# Normalize evidence_sentences
evidence = payload.evidence_sentences or []
if not isinstance(evidence, list) or not all(isinstance(i, int) for i in evidence):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="'evidence_sentences' must be a list of integers")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="'evidence_sentences' must be a list of integers",
+ )
explanation = payload.explanation or ""
if not isinstance(explanation, str):
@@ -458,22 +552,42 @@ async def human_extract_parameter(
# fallback core name
try:
from ..services.cit_db_service import snake_case as _snake_case
+
core = _snake_case(payload.parameter_name) if _snake_case else ""
except Exception:
core = ""
col_name = f"human_param_{core}" if core else "human_param_param"
try:
- updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, stored, table_name)
+ updated = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ col_name,
+ stored,
+ table_name,
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update citation row: {e}",
+ )
if not updated:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update"
+ )
- return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "extraction": stored}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "citation_id": citation_id,
+ "column": col_name,
+ "extraction": stored,
+ }
@router.post("/{sr_id}/citations/{citation_id}/extract-fulltext")
@@ -493,34 +607,58 @@ async def extract_fulltext_from_storage(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
table_name = (screening or {}).get("table_name") or "citations"
# fetch citation row
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
# Determine storage path for the citation PDF
storage_path = row.get("fulltext_url")
if not storage_path:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No fulltext storage path found on citation row")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="No fulltext storage path found on citation row",
+ )
try:
content, _filename = await storage_service.get_bytes_by_path(storage_path)
except FileNotFoundError:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Fulltext file not found in storage")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Fulltext file not found in storage",
+ )
except ValueError:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Unrecognized storage path format")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Unrecognized storage path format",
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download from storage: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to download from storage: {e}",
+ )
# If the citation row already contains an extracted full text in the "fulltext" column,
# only use it if the stored md5 matches the pdf we just downloaded.
@@ -555,8 +693,15 @@ async def _run_grobid():
async def _run_docint():
if not azure_docint_client or not azure_docint_client.is_available():
- return {"success": False, "error": "Azure DI not configured", "figures": [], "tables": []}
- return await azure_docint_client.extract_citation_artifacts(tmp.name, source_type="file")
+ return {
+ "success": False,
+ "error": "Azure DI not configured",
+ "figures": [],
+ "tables": [],
+ }
+ return await azure_docint_client.extract_citation_artifacts(
+ tmp.name, source_type="file"
+ )
try:
(coords, pages), docint_res = await asyncio.gather(
@@ -564,7 +709,10 @@ async def _run_docint():
_run_docint(),
)
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Fulltext processing failed: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Fulltext processing failed: {e}",
+ )
# filter sentence annotations
annotations = [a for a in coords if a.get("type") == "s" and a.get("text")]
@@ -579,7 +727,11 @@ async def _run_docint():
artifact_coords: List[Dict[str, Any]] = []
try:
- if docint_res and isinstance(docint_res, dict) and docint_res.get("success"):
+ if (
+ docint_res
+ and isinstance(docint_res, dict)
+ and docint_res.get("success")
+ ):
pages_meta = docint_res.get("pages") or []
# Determine artifact base path from the fulltext_url directory
# storage_path is "container/blob".
@@ -589,7 +741,7 @@ async def _run_docint():
artifacts_prefix = artifacts_prefix.replace("//", "/").rstrip("/")
# Figures: write png
- for fig in (docint_res.get("figures") or []):
+ for fig in docint_res.get("figures") or []:
try:
idx = int(fig.get("index"))
except Exception:
@@ -634,7 +786,7 @@ async def _run_docint():
)
# Tables: write markdown (.md)
- for tbl in (docint_res.get("tables") or []):
+ for tbl in docint_res.get("tables") or []:
try:
idx = int(tbl.get("index"))
except Exception:
@@ -687,20 +839,63 @@ async def _run_docint():
# persist full_text_str and coordinates/pages into citation row
try:
coords_for_overlay = list(annotations) + list(artifact_coords)
- updated1 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext", full_text_str, table_name)
- updated2 = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "fulltext_md5", current_md5, table_name)
- updated3 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_coords", coords_for_overlay, table_name)
- updated4 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_pages", pages, table_name)
- updated5 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_figures", fulltext_figures, table_name)
- updated6 = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, "fulltext_tables", fulltext_tables, table_name)
+ updated1 = await run_in_threadpool(
+ cits_dp_service.update_text_column,
+ citation_id,
+ "fulltext",
+ full_text_str,
+ table_name,
+ )
+ updated2 = await run_in_threadpool(
+ cits_dp_service.update_text_column,
+ citation_id,
+ "fulltext_md5",
+ current_md5,
+ table_name,
+ )
+ updated3 = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ "fulltext_coords",
+ coords_for_overlay,
+ table_name,
+ )
+ updated4 = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ "fulltext_pages",
+ pages,
+ table_name,
+ )
+ updated5 = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ "fulltext_figures",
+ fulltext_figures,
+ table_name,
+ )
+ updated6 = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ "fulltext_tables",
+ fulltext_tables,
+ table_name,
+ )
updated = updated1 or updated2 or updated3 or updated4 or updated5 or updated6
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row with full text: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update citation row with full text: {e}",
+ )
if not updated:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update"
+ )
return {
"status": "success",
diff --git a/backend/api/files/router.py b/backend/api/files/router.py
index 9e572103..49b35fcf 100644
--- a/backend/api/files/router.py
+++ b/backend/api/files/router.py
@@ -1,6 +1,7 @@
"""
Files router for document management in CAN-SR.
"""
+
from typing import List, Dict, Any, Optional
import os
import logging
@@ -28,6 +29,7 @@
class DocumentUploadResponse(BaseModel):
"""Response model for document upload"""
+
document_id: str
filename: str
file_size: int
@@ -37,6 +39,7 @@ class DocumentUploadResponse(BaseModel):
class DocumentInfo(BaseModel):
"""Document information model"""
+
document_id: str
filename: str
file_size: int
@@ -45,6 +48,7 @@ class DocumentInfo(BaseModel):
class DocumentListResponse(BaseModel):
"""Response model for document list"""
+
total_documents: int
documents: List[DocumentInfo]
@@ -60,8 +64,7 @@ async def upload_document(
try:
if not file.filename:
raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST,
- detail="Filename is required"
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Filename is required"
)
file_content = await file.read()
@@ -129,8 +132,7 @@ async def list_documents(
document_infos.append(document_info)
return DocumentListResponse(
- total_documents=len(document_infos),
- documents=document_infos
+ total_documents=len(document_infos), documents=document_infos
)
except HTTPException:
@@ -165,8 +167,7 @@ async def download_document(
if not document:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Document not found"
+ status_code=status.HTTP_404_NOT_FOUND, detail="Document not found"
)
file_content = await storage_service.get_user_document(
@@ -222,9 +223,13 @@ async def download_by_path(
try:
signed_url = await storage_service.generate_signed_url(path)
except ValueError:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path"
+ )
except Exception as e:
- logger.warning("Signed URL generation failed, falling back to streaming: %s", e)
+ logger.warning(
+ "Signed URL generation failed, falling back to streaming: %s", e
+ )
signed_url = None
if signed_url:
@@ -234,20 +239,26 @@ async def download_by_path(
try:
content, filename = await storage_service.get_bytes_by_path(path)
except FileNotFoundError:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="File not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="File not found"
+ )
except ValueError:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid storage path"
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to download: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to download: {e}",
+ )
+
def gen():
yield content
return StreamingResponse(
gen(),
media_type="application/octet-stream",
- headers={
- "Content-Disposition": f"attachment; filename={filename}"
- },
+ headers={"Content-Disposition": f"attachment; filename={filename}"},
)
except HTTPException:
raise
@@ -281,8 +292,7 @@ async def delete_document(
if not document:
raise HTTPException(
- status_code=status.HTTP_404_NOT_FOUND,
- detail="Document not found"
+ status_code=status.HTTP_404_NOT_FOUND, detail="Document not found"
)
success = await storage_service.delete_user_document(
diff --git a/backend/api/jobs/procrastinate_app.py b/backend/api/jobs/procrastinate_app.py
index 932f9a44..39085b51 100644
--- a/backend/api/jobs/procrastinate_app.py
+++ b/backend/api/jobs/procrastinate_app.py
@@ -62,7 +62,9 @@ def workers_enabled() -> bool:
def worker_concurrency() -> int:
try:
- return max(1, int(getattr(settings, "PROCRASTINATE_WORKER_CONCURRENCY", 1) or 1))
+ return max(
+ 1, int(getattr(settings, "PROCRASTINATE_WORKER_CONCURRENCY", 1) or 1)
+ )
except Exception:
return 1
@@ -79,13 +81,11 @@ async def ensure_procrastinate_schema() -> None:
async def _schema_installed() -> bool:
try:
- row = await PROCRASTINATE_APP.connector.execute_query_one_async(
- """
+ row = await PROCRASTINATE_APP.connector.execute_query_one_async("""
SELECT
to_regclass('procrastinate_jobs') IS NOT NULL AS jobs_table,
EXISTS(SELECT 1 FROM pg_type WHERE typname = 'procrastinate_job_status') AS status_enum
- """ # noqa: S608
- )
+ """) # noqa: S608
return bool(row.get("jobs_table") and row.get("status_enum"))
except Exception:
return False
@@ -109,7 +109,14 @@ async def _schema_installed() -> bool:
# If schema is already present (possibly created by a previous run),
# treat duplicate-object errors as success.
cause: BaseException | None = e.__cause__
- if isinstance(cause, (psycopg.errors.DuplicateObject, psycopg.errors.DuplicateTable, psycopg.errors.DuplicateFunction)):
+ if isinstance(
+ cause,
+ (
+ psycopg.errors.DuplicateObject,
+ psycopg.errors.DuplicateTable,
+ psycopg.errors.DuplicateFunction,
+ ),
+ ):
if await _schema_installed():
return
raise
@@ -156,7 +163,9 @@ async def clear_pending_jobs(*, queues: Optional[list[str]] = None) -> int:
await PROCRASTINATE_APP.close_async()
-async def cancel_enqueued_jobs_for_run_all(job_id: str, *, queues: Optional[list[str]] = None) -> int:
+async def cancel_enqueued_jobs_for_run_all(
+ job_id: str, *, queues: Optional[list[str]] = None
+) -> int:
"""Best-effort: delete enqueued (todo) Procrastinate jobs for a given run-all job_id.
This makes Cancel feel more responsive by removing jobs that haven't started yet.
diff --git a/backend/api/jobs/router.py b/backend/api/jobs/router.py
index a49b8f6a..ec2224aa 100644
--- a/backend/api/jobs/router.py
+++ b/backend/api/jobs/router.py
@@ -13,7 +13,11 @@
from ..services.azure_openai_client import azure_openai_client
from .run_all_repo import run_all_repo
-from .procrastinate_app import cancel_enqueued_jobs_for_run_all, jobs_enabled, worker_concurrency
+from .procrastinate_app import (
+ cancel_enqueued_jobs_for_run_all,
+ jobs_enabled,
+ worker_concurrency,
+)
# Import task objects so we can enqueue via Task.defer_async (Procrastinate 3.2.x)
from .run_all_tasks import run_all_chunk, run_all_start
@@ -21,7 +25,6 @@
# Import tasks so Procrastinate can discover them.
from . import run_all_tasks # noqa: F401
-
router = APIRouter()
@@ -91,7 +94,9 @@ async def list_active_run_all(
if not user_email:
raise HTTPException(status_code=401, detail="Missing user identity")
- srs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, user_email)
+ srs = await run_in_threadpool(
+ srdb_service.list_systematic_reviews_for_user, user_email
+ )
sr_ids = [str(sr.get("id")) for sr in (srs or []) if sr and sr.get("id")]
if not sr_ids:
return {"jobs": []}
@@ -99,7 +104,9 @@ async def list_active_run_all(
jobs = await run_in_threadpool(run_all_repo.list_active_jobs_for_srs, sr_ids)
# Attach SR name for nicer UI.
- sr_name_map = {str(sr.get("id")): sr.get("name") for sr in (srs or []) if sr and sr.get("id")}
+ sr_name_map = {
+ str(sr.get("id")): sr.get("name") for sr in (srs or []) if sr and sr.get("id")
+ }
for j in jobs:
sid = str(j.get("sr_id"))
if sid in sr_name_map:
@@ -122,7 +129,9 @@ async def start_run_all(
step = (payload.step or "").lower().strip()
if step not in {"l1", "l2", "extract"}:
- raise HTTPException(status_code=400, detail="step must be one of: l1, l2, extract")
+ raise HTTPException(
+ status_code=400, detail="step must be one of: l1, l2, extract"
+ )
# Authz: ensure user can access SR
try:
@@ -189,7 +198,9 @@ async def start_run_all(
# If another request won the race, the partial unique index on
# (sr_id) WHERE status IN ('queued','running','paused') will throw.
if _is_unique_violation(e):
- existing = await run_in_threadpool(run_all_repo.get_active_job_for_sr, sr_id)
+ existing = await run_in_threadpool(
+ run_all_repo.get_active_job_for_sr, sr_id
+ )
if existing:
return {
"job_id": existing.get("job_id"),
@@ -209,7 +220,9 @@ async def start_run_all(
await run_in_threadpool(run_all_repo.set_status, job_id, "paused")
else:
await run_in_threadpool(run_all_repo.set_status, job_id, "running")
- await run_in_threadpool(run_all_repo.update_phase, job_id, f"enqueued {len(sanitized_ids)}")
+ await run_in_threadpool(
+ run_all_repo.update_phase, job_id, f"enqueued {len(sanitized_ids)}"
+ )
# Fair scheduling: persist chunks and only enqueue the *next* chunk.
# This prevents one job from flooding the global queue.
@@ -228,7 +241,9 @@ async def start_run_all(
)
if next_chunk_id is None:
break
- await run_all_chunk.defer_async(job_id=job_id, chunk_id=int(next_chunk_id))
+ await run_all_chunk.defer_async(
+ job_id=job_id, chunk_id=int(next_chunk_id)
+ )
# Helpful operator logging
print(
@@ -348,7 +363,9 @@ async def run_all_dismiss(
st = str(job.get("status") or "").lower()
if st not in {"finished", "failed"}:
- raise HTTPException(status_code=400, detail="Only finished/failed jobs can be dismissed")
+ raise HTTPException(
+ status_code=400, detail="Only finished/failed jobs can be dismissed"
+ )
await run_in_threadpool(run_all_repo.set_status, job_id, "done")
return {"status": "done", "job_id": job_id}
diff --git a/backend/api/jobs/run_all_repo.py b/backend/api/jobs/run_all_repo.py
index 650dc88a..dd873387 100644
--- a/backend/api/jobs/run_all_repo.py
+++ b/backend/api/jobs/run_all_repo.py
@@ -30,8 +30,7 @@ def ensure_tables(self) -> None:
try:
conn = postgres_server.conn
cur = conn.cursor()
- cur.execute(
- """
+ cur.execute("""
CREATE TABLE IF NOT EXISTS run_all_jobs (
id UUID PRIMARY KEY,
sr_id TEXT NOT NULL,
@@ -50,15 +49,13 @@ def ensure_tables(self) -> None:
started_at TIMESTAMP WITH TIME ZONE,
finished_at TIMESTAMP WITH TIME ZONE
)
- """
- )
+ """)
# Migration safety: older deployments may have allowed multiple active
# jobs per SR. Creating the partial unique index would fail if such
# duplicates exist. Before creating the index, we dedupe by keeping
# the most recent active job per sr_id and canceling older ones.
- cur.execute(
- """
+ cur.execute("""
WITH ranked AS (
SELECT id,
sr_id,
@@ -76,22 +73,18 @@ def ensure_tables(self) -> None:
FROM ranked r
WHERE j.id = r.id
AND r.rn > 1
- """
- )
+ """)
# Enforce: only one active run-all job per SR at a time.
# Active statuses are queued/running/paused.
# This is the critical concurrency guard (race-safe across users).
- cur.execute(
- """
+ cur.execute("""
CREATE UNIQUE INDEX IF NOT EXISTS run_all_jobs_one_active_per_sr
ON run_all_jobs (sr_id)
WHERE status IN ('queued', 'running', 'paused')
- """
- )
+ """)
# Store only failures as requested
- cur.execute(
- """
+ cur.execute("""
CREATE TABLE IF NOT EXISTS run_all_job_errors (
id BIGSERIAL PRIMARY KEY,
job_id UUID NOT NULL REFERENCES run_all_jobs(id) ON DELETE CASCADE,
@@ -100,14 +93,12 @@ def ensure_tables(self) -> None:
error TEXT,
created_at TIMESTAMP WITH TIME ZONE DEFAULT now()
)
- """
- )
+ """)
# Chunk scheduling table (fairness): store chunk definitions and status.
# Each run-all job will only enqueue a small number of chunks at a time
# (prefetch) so multiple run-all jobs can make progress concurrently.
- cur.execute(
- """
+ cur.execute("""
CREATE TABLE IF NOT EXISTS run_all_job_chunks (
id BIGSERIAL PRIMARY KEY,
job_id UUID NOT NULL REFERENCES run_all_jobs(id) ON DELETE CASCADE,
@@ -120,15 +111,12 @@ def ensure_tables(self) -> None:
finished_at TIMESTAMP WITH TIME ZONE,
UNIQUE(job_id, chunk_index)
)
- """
- )
+ """)
- cur.execute(
- """
+ cur.execute("""
CREATE INDEX IF NOT EXISTS run_all_job_chunks_lookup
ON run_all_job_chunks (job_id, status, chunk_index)
- """
- )
+ """)
conn.commit()
except Exception:
_safe_rollback(conn)
@@ -211,7 +199,9 @@ def claim_next_todo_chunk(self, job_id: str, *, prefetch: int = 2) -> Optional[i
# Serialize claims per job to avoid races where multiple workers
# simultaneously observe the same doing-count and over-claim.
# (Row-level locking on the parent job is cheap and safe.)
- cur.execute("SELECT 1 FROM run_all_jobs WHERE id = %s FOR UPDATE", (job_id,))
+ cur.execute(
+ "SELECT 1 FROM run_all_jobs WHERE id = %s FOR UPDATE", (job_id,)
+ )
# Enforce prefetch limit: claim a todo chunk only if currently
# doing_count < pf.
@@ -325,13 +315,11 @@ def count_active_jobs(self) -> int:
try:
conn = postgres_server.conn
cur = conn.cursor()
- cur.execute(
- """
+ cur.execute("""
SELECT COUNT(1)
FROM run_all_jobs
WHERE status IN ('queued', 'running', 'paused')
- """
- )
+ """)
row = cur.fetchone()
return int(row[0] or 0) if row else 0
except Exception:
@@ -491,7 +479,9 @@ def get_job(self, job_id: str) -> Optional[Dict[str, Any]]:
_safe_rollback(conn)
raise
- def set_status(self, job_id: str, status: str, *, error: Optional[str] = None) -> None:
+ def set_status(
+ self, job_id: str, status: str, *, error: Optional[str] = None
+ ) -> None:
conn = None
try:
conn = postgres_server.conn
@@ -553,7 +543,9 @@ def set_total(self, job_id: str, total: int) -> None:
_safe_rollback(conn)
raise
- def inc_counts(self, job_id: str, *, done: int = 0, skipped: int = 0, failed: int = 0) -> None:
+ def inc_counts(
+ self, job_id: str, *, done: int = 0, skipped: int = 0, failed: int = 0
+ ) -> None:
conn = None
try:
conn = postgres_server.conn
@@ -590,7 +582,9 @@ def is_canceled(self, job_id: str) -> bool:
_safe_rollback(conn)
raise
- def add_error(self, job_id: str, *, citation_id: Optional[int], stage: str, error: str) -> None:
+ def add_error(
+ self, job_id: str, *, citation_id: Optional[int], stage: str, error: str
+ ) -> None:
conn = None
try:
conn = postgres_server.conn
@@ -600,7 +594,12 @@ def add_error(self, job_id: str, *, citation_id: Optional[int], stage: str, erro
INSERT INTO run_all_job_errors (job_id, citation_id, stage, error)
VALUES (%s, %s, %s, %s)
""",
- (job_id, int(citation_id) if citation_id is not None else None, stage, error[:8000]),
+ (
+ job_id,
+ int(citation_id) if citation_id is not None else None,
+ stage,
+ error[:8000],
+ ),
)
conn.commit()
except Exception:
diff --git a/backend/api/jobs/run_all_tasks.py b/backend/api/jobs/run_all_tasks.py
index c8c43c48..fda82afa 100644
--- a/backend/api/jobs/run_all_tasks.py
+++ b/backend/api/jobs/run_all_tasks.py
@@ -9,7 +9,12 @@
from .run_all_repo import run_all_repo
from .procrastinate_app import worker_concurrency
from ..services.sr_db_service import srdb_service
-from ..services.cit_db_service import cits_dp_service, snake_case_column, snake_case_param, snake_case
+from ..services.cit_db_service import (
+ cits_dp_service,
+ snake_case_column,
+ snake_case_param,
+ snake_case,
+)
from ..citations import router as citations_router
from ..services.azure_openai_client import azure_openai_client
from ..services.storage import storage_service
@@ -93,12 +98,16 @@ def _eligible_ids(*, sr_id: str, table_name: str, step: str) -> List[int]:
elif step == "extract":
filter_step = "l2"
- ids = cits_dp_service.list_citation_ids(filter_step if filter_step else None, table_name)
+ ids = cits_dp_service.list_citation_ids(
+ filter_step if filter_step else None, table_name
+ )
# PDF gating for l2/extract
if step in ("l2", "extract"):
# keep only rows with fulltext_url
- rows = cits_dp_service.get_citations_by_ids(ids, table_name, fields=["id", "fulltext_url"])
+ rows = cits_dp_service.get_citations_by_ids(
+ ids, table_name, fields=["id", "fulltext_url"]
+ )
ok = []
for r in rows:
try:
@@ -121,7 +130,9 @@ async def _run_l1_for_citation(
force: bool,
) -> tuple[int, int, int]:
"""Returns (done, skipped, failed) increments for this citation."""
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
if not row:
return (0, 0, 1)
@@ -129,7 +140,9 @@ async def _run_l1_for_citation(
if not include_cols:
include_cols = ["title", "abstract"]
- citation_text = citations_router._build_combined_citation_from_row(row, include_cols)
+ citation_text = citations_router._build_combined_citation_from_row(
+ row, include_cols
+ )
cp = sr.get("criteria_parsed") or sr.get("criteria") or {}
l1 = cp.get("l1") if isinstance(cp, dict) else None
@@ -149,7 +162,9 @@ async def _run_l1_for_citation(
if await run_in_threadpool(run_all_repo.is_canceled, job_id):
raise RunAllCanceled()
await _wait_if_paused(job_id)
- opts = possible[i] if i < len(possible) and isinstance(possible[i], list) else []
+ opts = (
+ possible[i] if i < len(possible) and isinstance(possible[i], list) else []
+ )
xtra = addinfos[i] if i < len(addinfos) and isinstance(addinfos[i], str) else ""
col = snake_case_column(q)
existing = row.get(col)
@@ -160,7 +175,9 @@ async def _run_l1_for_citation(
raise RuntimeError("Azure OpenAI client not configured")
options_listed = "\n".join([f"{j}. {opt}" for j, opt in enumerate(opts)])
- prompt = PROMPT_JSON_TEMPLATE.format(question=q, cit=citation_text, options=options_listed, xtra=xtra)
+ prompt = PROMPT_JSON_TEMPLATE.format(
+ question=q, cit=citation_text, options=options_listed, xtra=xtra
+ )
llm_response = await azure_openai_client.simple_chat(
user_message=prompt,
system_prompt=None,
@@ -181,15 +198,28 @@ async def _run_l1_for_citation(
classification_json = {
"selected": resolved_selected,
- "explanation": parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or "",
- "confidence": float(parsed.get("confidence") or 0.0) if str(parsed.get("confidence") or "").strip() else 0.0,
+ "explanation": parsed.get("explanation")
+ or parsed.get("reason")
+ or parsed.get("explain")
+ or "",
+ "confidence": (
+ float(parsed.get("confidence") or 0.0)
+ if str(parsed.get("confidence") or "").strip()
+ else 0.0
+ ),
"evidence_sentences": parsed.get("evidence_sentences") or [],
"evidence_tables": parsed.get("evidence_tables") or [],
"evidence_figures": parsed.get("evidence_figures") or [],
"llm_raw": llm_response,
}
- await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col, classification_json, table_name)
+ await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ col,
+ classification_json,
+ table_name,
+ )
# Best-effort autofill human_ if empty
try:
@@ -229,7 +259,9 @@ async def _ensure_fulltext_if_needed(
force: bool,
) -> bool:
"""Ensure fulltext artifacts exist. Returns True if available after this call."""
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
if not row:
return False
if not row.get("fulltext_url"):
@@ -242,9 +274,13 @@ async def _ensure_fulltext_if_needed(
await extract_fulltext_from_storage(sr_id, citation_id, current_user=current_user) # type: ignore
except Exception:
# It's okay if DI/grobid fails; L2/extract depends on fulltext text though.
- row2 = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row2 = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
return bool(row2 and row2.get("fulltext"))
- row3 = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row3 = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
return bool(row3 and row3.get("fulltext"))
@@ -258,18 +294,28 @@ async def _run_l2_for_citation(
model: Optional[str],
force: bool,
) -> tuple[int, int, int]:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
if not row or not row.get("fulltext_url"):
return (0, 1, 0)
# fake current_user for storage service paths (it only uses id in upload; extract reads by path)
current_user = {"id": "system", "email": "system"}
await _wait_if_paused(job_id)
- ok = await _ensure_fulltext_if_needed(sr_id=sr_id, citation_id=citation_id, current_user=current_user, table_name=table_name, force=force)
+ ok = await _ensure_fulltext_if_needed(
+ sr_id=sr_id,
+ citation_id=citation_id,
+ current_user=current_user,
+ table_name=table_name,
+ force=force,
+ )
if not ok:
return (0, 1, 0)
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
if not row:
return (0, 0, 1)
@@ -284,8 +330,13 @@ async def _run_l2_for_citation(
if not questions:
return (0, 1, 0)
- include_cols = cits_dp_service.load_include_columns_from_criteria(sr) or ["title", "abstract"]
- citation_text = citations_router._build_combined_citation_from_row(row, include_cols)
+ include_cols = cits_dp_service.load_include_columns_from_criteria(sr) or [
+ "title",
+ "abstract",
+ ]
+ citation_text = citations_router._build_combined_citation_from_row(
+ row, include_cols
+ )
fulltext = row.get("fulltext") or citation_text
# Tables/Figures context from row
@@ -333,7 +384,9 @@ async def _run_l2_for_citation(
caption = item.get("caption")
if not idx or not blob_addr:
continue
- figures_lines.append(f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})")
+ figures_lines.append(
+ f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})"
+ )
try:
img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr)
if img_bytes:
@@ -347,7 +400,9 @@ async def _run_l2_for_citation(
if await run_in_threadpool(run_all_repo.is_canceled, job_id):
raise RunAllCanceled()
await _wait_if_paused(job_id)
- opts = possible[i] if i < len(possible) and isinstance(possible[i], list) else []
+ opts = (
+ possible[i] if i < len(possible) and isinstance(possible[i], list) else []
+ )
xtra = addinfos[i] if i < len(addinfos) and isinstance(addinfos[i], str) else ""
col = snake_case_column(q)
existing = row.get(col)
@@ -392,15 +447,28 @@ async def _run_l2_for_citation(
classification_json = {
"selected": resolved_selected,
- "explanation": parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or "",
- "confidence": float(parsed.get("confidence") or 0.0) if str(parsed.get("confidence") or "").strip() else 0.0,
+ "explanation": parsed.get("explanation")
+ or parsed.get("reason")
+ or parsed.get("explain")
+ or "",
+ "confidence": (
+ float(parsed.get("confidence") or 0.0)
+ if str(parsed.get("confidence") or "").strip()
+ else 0.0
+ ),
"evidence_sentences": parsed.get("evidence_sentences") or [],
"evidence_tables": parsed.get("evidence_tables") or [],
"evidence_figures": parsed.get("evidence_figures") or [],
"llm_raw": llm_response,
}
- await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col, classification_json, table_name)
+ await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ col,
+ classification_json,
+ table_name,
+ )
# Best-effort autofill human
try:
@@ -440,24 +508,36 @@ async def _run_extract_for_citation(
model: Optional[str],
force: bool,
) -> tuple[int, int, int]:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
if not row or not row.get("fulltext_url"):
return (0, 1, 0)
current_user = {"id": "system", "email": "system"}
await _wait_if_paused(job_id)
- ok = await _ensure_fulltext_if_needed(sr_id=sr_id, citation_id=citation_id, current_user=current_user, table_name=table_name, force=force)
+ ok = await _ensure_fulltext_if_needed(
+ sr_id=sr_id,
+ citation_id=citation_id,
+ current_user=current_user,
+ table_name=table_name,
+ force=force,
+ )
if not ok:
return (0, 1, 0)
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, citation_id, table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, citation_id, table_name
+ )
if not row:
return (0, 0, 1)
cp = sr.get("criteria_parsed") or sr.get("criteria") or {}
params = cp.get("parameters") if isinstance(cp, dict) else None
categories = (params or {}).get("categories") if isinstance(params, dict) else []
- possible = (params or {}).get("possible_parameters") if isinstance(params, dict) else []
+ possible = (
+ (params or {}).get("possible_parameters") if isinstance(params, dict) else []
+ )
descs = (params or {}).get("descriptions") if isinstance(params, dict) else []
categories = categories if isinstance(categories, list) else []
possible = possible if isinstance(possible, list) else []
@@ -528,7 +608,9 @@ async def _run_extract_for_citation(
caption = item.get("caption")
if not idx or not blob_addr:
continue
- flines.append(f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})")
+ flines.append(
+ f"Figure [F{idx}] caption: {caption or '(no caption)'} (see attached image F{idx})"
+ )
try:
img_bytes, _ = await storage_service.get_bytes_by_path(blob_addr)
if img_bytes:
@@ -621,7 +703,9 @@ def _extract_json_object(text: str) -> Optional[str]:
"llm_raw": str(llm_response)[:4000],
}
- await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col, stored, table_name)
+ await run_in_threadpool(
+ cits_dp_service.update_jsonb_column, citation_id, col, stored, table_name
+ )
# best-effort autofill human_param
try:
@@ -664,16 +748,23 @@ async def run_all_start(job_id: str) -> None:
ids = []
sr, table_name = await _load_sr_and_table(sr_id)
- ids = await run_in_threadpool(_eligible_ids, sr_id=sr_id, table_name=table_name, step=step)
+ ids = await run_in_threadpool(
+ _eligible_ids, sr_id=sr_id, table_name=table_name, step=step
+ )
await run_in_threadpool(run_all_repo.set_total, job_id, len(ids))
- print(f"[run-all] kickoff job_id={job_id} step={step} eligible={len(ids)}", flush=True)
+ print(
+ f"[run-all] kickoff job_id={job_id} step={step} eligible={len(ids)}",
+ flush=True,
+ )
# If user paused before kickoff, preserve paused status.
if await run_in_threadpool(run_all_repo.is_paused, job_id):
await run_in_threadpool(run_all_repo.set_status, job_id, "paused")
else:
await run_in_threadpool(run_all_repo.set_status, job_id, "running")
- await run_in_threadpool(run_all_repo.update_phase, job_id, f"enqueued {len(ids)}")
+ await run_in_threadpool(
+ run_all_repo.update_phase, job_id, f"enqueued {len(ids)}"
+ )
# Fair scheduling: persist chunks and enqueue only the next chunk.
chunks = [ids[i : i + chunk_size] for i in range(0, len(ids), chunk_size)]
@@ -689,7 +780,9 @@ async def run_all_start(job_id: str) -> None:
)
if next_chunk_id is None:
break
- await run_all_chunk.defer_async(job_id=job_id, chunk_id=int(next_chunk_id))
+ await run_all_chunk.defer_async(
+ job_id=job_id, chunk_id=int(next_chunk_id)
+ )
except Exception as e:
await run_in_threadpool(run_all_repo.set_status, job_id, "failed", error=str(e))
@@ -739,18 +832,49 @@ async def run_all_chunk(job_id: str, chunk_id: int) -> None:
try:
await _wait_if_paused(job_id)
if step == "l1":
- d, s, f = await _run_l1_for_citation(job_id=job_id, sr=sr, table_name=table_name, citation_id=int(cid), model=model, force=force)
+ d, s, f = await _run_l1_for_citation(
+ job_id=job_id,
+ sr=sr,
+ table_name=table_name,
+ citation_id=int(cid),
+ model=model,
+ force=force,
+ )
elif step == "l2":
- d, s, f = await _run_l2_for_citation(job_id=job_id, sr=sr, table_name=table_name, sr_id=sr_id, citation_id=int(cid), model=model, force=force)
+ d, s, f = await _run_l2_for_citation(
+ job_id=job_id,
+ sr=sr,
+ table_name=table_name,
+ sr_id=sr_id,
+ citation_id=int(cid),
+ model=model,
+ force=force,
+ )
elif step == "extract":
- d, s, f = await _run_extract_for_citation(job_id=job_id, sr=sr, table_name=table_name, sr_id=sr_id, citation_id=int(cid), model=model, force=force)
+ d, s, f = await _run_extract_for_citation(
+ job_id=job_id,
+ sr=sr,
+ table_name=table_name,
+ sr_id=sr_id,
+ citation_id=int(cid),
+ model=model,
+ force=force,
+ )
else:
d, s, f = (0, 1, 0)
- await run_in_threadpool(run_all_repo.inc_counts, job_id, done=d, skipped=s, failed=f)
+ await run_in_threadpool(
+ run_all_repo.inc_counts, job_id, done=d, skipped=s, failed=f
+ )
except RunAllCanceled:
return
except Exception as e:
- await run_in_threadpool(run_all_repo.add_error, job_id, citation_id=int(cid), stage=step, error=str(e))
+ await run_in_threadpool(
+ run_all_repo.add_error,
+ job_id,
+ citation_id=int(cid),
+ stage=step,
+ error=str(e),
+ )
await run_in_threadpool(run_all_repo.inc_counts, job_id, failed=1)
chunk_failed = True
chunk_error = str(e)
@@ -763,7 +887,11 @@ async def run_all_chunk(job_id: str, chunk_id: int) -> None:
# Mark chunk complete and schedule more work (up to prefetch).
try:
if chunk_failed:
- await run_in_threadpool(run_all_repo.mark_chunk_failed, int(chunk_id), error=chunk_error or "chunk had failures")
+ await run_in_threadpool(
+ run_all_repo.mark_chunk_failed,
+ int(chunk_id),
+ error=chunk_error or "chunk had failures",
+ )
else:
await run_in_threadpool(run_all_repo.mark_chunk_done, int(chunk_id))
except Exception:
@@ -802,7 +930,11 @@ async def run_all_chunk(job_id: str, chunk_id: int) -> None:
# Successful completion should remain visible in the UI until the
# user dismisses it. We model that as a terminal-but-sticky status
# called "finished". Dismissal transitions "finished" -> "done".
- if total > 0 and (done + skipped + failed) >= total and not await run_in_threadpool(run_all_repo.is_canceled, job_id):
+ if (
+ total > 0
+ and (done + skipped + failed) >= total
+ and not await run_in_threadpool(run_all_repo.is_canceled, job_id)
+ ):
await run_in_threadpool(run_all_repo.set_status, job_id, "finished")
except Exception:
pass
diff --git a/backend/api/models/auth.py b/backend/api/models/auth.py
index 7e34fea0..4588ad2e 100644
--- a/backend/api/models/auth.py
+++ b/backend/api/models/auth.py
@@ -1,6 +1,7 @@
"""
Authentication and user models for the API.
"""
+
from typing import Optional
from datetime import datetime
diff --git a/backend/api/router.py b/backend/api/router.py
index 94d5144d..3950926d 100644
--- a/backend/api/router.py
+++ b/backend/api/router.py
@@ -37,4 +37,6 @@
api_router.include_router(jobs_router, prefix="/jobs", tags=["Jobs"])
# Database Search API
-api_router.include_router(database_search_router, prefix="/database_search", tags=["Database Search"])
+api_router.include_router(
+ database_search_router, prefix="/database_search", tags=["Database Search"]
+)
diff --git a/backend/api/screen/prompts.py b/backend/api/screen/prompts.py
index ba7ac9d5..8d5ae0c7 100644
--- a/backend/api/screen/prompts.py
+++ b/backend/api/screen/prompts.py
@@ -72,4 +72,4 @@
- Use sentence indices from the numbered full text for "evidence_sentences"
- Use table numbers from the Tables section for "evidence_tables"
- Use figure numbers from the Figures section for "evidence_figures"
-"""
\ No newline at end of file
+"""
diff --git a/backend/api/screen/router.py b/backend/api/screen/router.py
index 10174b82..93bd917b 100644
--- a/backend/api/screen/router.py
+++ b/backend/api/screen/router.py
@@ -54,18 +54,36 @@ def _normalize_int_list(v: Any) -> List[int]:
class ClassifyRequest(BaseModel):
citation_text: Optional[str] = Field(
- None, description="Optional combined citation text. If omitted the server will build it from the screening DB row."
+ None,
+ description="Optional combined citation text. If omitted the server will build it from the screening DB row.",
)
include_columns: Optional[List[str]] = Field(
- None, description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation"
+ None,
+ description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation",
)
- question: str = Field(..., description="L1 criteria question to apply to this citation")
- screening_step: str = Field(..., description="Screening step identifier: 'l1' or 'l2', etc.")
- options: List[str] = Field(..., description="List of possible options (exact strings). The model must pick one.")
- xtra: Optional[str] = Field("", description="Additional context/instructions for the model")
- model: Optional[str] = Field(None, description="Model to use (falls back to default configured model)")
- temperature: Optional[float] = Field(0.0, ge=0.0, le=1.0, description="Sampling temperature")
- max_tokens: Optional[int] = Field(2000, ge=1, le=4000, description="Max tokens for LLM response")
+ question: str = Field(
+ ..., description="L1 criteria question to apply to this citation"
+ )
+ screening_step: str = Field(
+ ..., description="Screening step identifier: 'l1' or 'l2', etc."
+ )
+ options: List[str] = Field(
+ ...,
+ description="List of possible options (exact strings). The model must pick one.",
+ )
+ xtra: Optional[str] = Field(
+ "", description="Additional context/instructions for the model"
+ )
+ model: Optional[str] = Field(
+ None, description="Model to use (falls back to default configured model)"
+ )
+ temperature: Optional[float] = Field(
+ 0.0, ge=0.0, le=1.0, description="Sampling temperature"
+ )
+ max_tokens: Optional[int] = Field(
+ 2000, ge=1, le=4000, description="Max tokens for LLM response"
+ )
+
class HumanClassifyRequest(BaseModel):
"""
@@ -73,24 +91,35 @@ class HumanClassifyRequest(BaseModel):
This mirrors the shape of the LLM-based classify payload but accepts a
direct `selected` value and optional explanation/confidence.
"""
+
citation_text: Optional[str] = Field(
- None, description="Optional combined citation text. If omitted the server will build it from the screening DB row."
+ None,
+ description="Optional combined citation text. If omitted the server will build it from the screening DB row.",
)
include_columns: Optional[List[str]] = Field(
- None, description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation"
+ None,
+ description="If citation_text is omitted, these columns (original CSV headers) will be used to build the combined citation",
+ )
+ question: str = Field(
+ ..., description="L1 criteria question to apply to this citation"
)
- question: str = Field(..., description="L1 criteria question to apply to this citation")
selected: str = Field(..., description="Human-selected option (string)")
- screening_step: str = Field(..., description="Screening step identifier: 'l1' or 'l2', etc.")
- explanation: Optional[str] = Field("", description="Optional free-text explanation from the human reviewer")
- confidence: Optional[float] = Field(None, ge=0.0, le=1.0, description="Optional confidence (0.0 - 1.0)")
+ screening_step: str = Field(
+ ..., description="Screening step identifier: 'l1' or 'l2', etc."
+ )
+ explanation: Optional[str] = Field(
+ "", description="Optional free-text explanation from the human reviewer"
+ )
+ confidence: Optional[float] = Field(
+ None, ge=0.0, le=1.0, description="Optional confidence (0.0 - 1.0)"
+ )
reviewer: Optional[str] = Field(None, description="Optional reviewer id or name")
-
+
+
# _update_sync moved to backend.api.core.postgres.update_jsonb_column
# Use run_in_threadpool(update_jsonb_column, ...) where needed.
-
@router.post("/{sr_id}/citations/{citation_id}/classify")
async def classify_citation(
sr_id: str,
@@ -111,27 +140,47 @@ async def classify_citation(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
table_name = (screening or {}).get("table_name") or "citations"
# Load citation row (needed for l2 fulltext and for building citation_text)
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
# Build or use provided citation text (fall back to combined title/abstract when not provided)
- citation_text = payload.citation_text or citations_router._build_combined_citation_from_row(row, payload.include_columns)
+ citation_text = (
+ payload.citation_text
+ or citations_router._build_combined_citation_from_row(
+ row, payload.include_columns
+ )
+ )
# Ensure LLM client is available
if not azure_openai_client.is_configured():
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Azure OpenAI client is not configured on the server")
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
+ detail="Azure OpenAI client is not configured on the server",
+ )
# Prepare prompt (use full-text template for l2, otherwise TA/L1 template)
options_listed = "\n".join([f"{i}. {opt}" for i, opt in enumerate(payload.options)])
@@ -242,17 +291,29 @@ async def classify_citation(
try:
parsed = json.loads(llm_response)
except Exception:
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response was not valid JSON: {llm_response[:1000]}")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail=f"LLM response was not valid JSON: {llm_response[:1000]}",
+ )
if not isinstance(parsed, dict):
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM response JSON was not an object: {str(type(parsed))}")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail=f"LLM response JSON was not an object: {str(type(parsed))}",
+ )
# Require 'selected' key and validate it is a string
if "selected" not in parsed:
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM JSON missing 'selected' key: {json.dumps(parsed)[:1000]}")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail=f"LLM JSON missing 'selected' key: {json.dumps(parsed)[:1000]}",
+ )
selected_value = parsed.get("selected")
if not isinstance(selected_value, str):
- raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=f"LLM 'selected' must be a string: {str(type(selected_value))}")
+ raise HTTPException(
+ status_code=status.HTTP_502_BAD_GATEWAY,
+ detail=f"LLM 'selected' must be a string: {str(type(selected_value))}",
+ )
s = selected_value.strip()
@@ -263,7 +324,9 @@ async def classify_citation(
resolved_selected = opt
break
- explanation = parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or ""
+ explanation = (
+ parsed.get("explanation") or parsed.get("reason") or parsed.get("explain") or ""
+ )
confidence_raw = parsed.get("confidence")
# Parse confidence
@@ -298,14 +361,27 @@ async def classify_citation(
human_col_name = f"human_{col_core_h}" if col_core_h else "human_col"
try:
- updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name)
+ updated = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ col_name,
+ classification_json,
+ table_name,
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update citation row: {e}",
+ )
if not updated:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update"
+ )
# Auto-fill human_* from llm_* if missing (never overwrite)
try:
@@ -326,10 +402,16 @@ async def classify_citation(
except Exception:
# best-effort
pass
-
+
await update_inclusion_decision(sr, citation_id, payload.screening_step, "llm")
- return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "citation_id": citation_id,
+ "column": col_name,
+ "classification": classification_json,
+ }
@router.post("/{sr_id}/citations/{citation_id}/human_classify")
@@ -350,20 +432,32 @@ async def human_classify_citation(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review or screening: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review or screening: {e}",
+ )
table_name = (screening or {}).get("table_name") or "citations"
# Ensure citation exists and optionally build combined citation text
try:
- row = await run_in_threadpool(cits_dp_service.get_citation_by_id, int(citation_id), table_name)
+ row = await run_in_threadpool(
+ cits_dp_service.get_citation_by_id, int(citation_id), table_name
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found"
+ )
citation_text = payload.citation_text
confidence = payload.confidence
@@ -381,28 +475,48 @@ async def human_classify_citation(
# Persist into Postgres under a dynamic column name derived from question
# Use snake_case to create a stable core name and prefix with 'human_'
col_core = snake_case(payload.question, max_len=56) if snake_case else None
- col_name = f"human_{col_core}" if col_core else f"human_col"
+ col_name = f"human_{col_core}" if col_core else f"human_col"
try:
- updated = await run_in_threadpool(cits_dp_service.update_jsonb_column, citation_id, col_name, classification_json, table_name)
+ updated = await run_in_threadpool(
+ cits_dp_service.update_jsonb_column,
+ citation_id,
+ col_name,
+ classification_json,
+ table_name,
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update citation row: {e}",
+ )
if not updated:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update"
+ )
+
await update_inclusion_decision(sr, citation_id, payload.screening_step, "human")
- return {"status": "success", "sr_id": sr_id, "citation_id": citation_id, "column": col_name, "classification": classification_json}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "citation_id": citation_id,
+ "column": col_name,
+ "classification": classification_json,
+ }
+
async def update_inclusion_decision(
sr: Dict[str, Any],
citation_id: int,
screening_step: str,
decision_maker: str,
-):
+):
table_name = (sr.get("screening_db") or {}).get("table_name") or "citations"
# IMPORTANT: decision/pass computation must not be stale.
@@ -415,10 +529,15 @@ def _get_row_fresh() -> Dict[str, Any]:
try:
row = await run_in_threadpool(_get_row_fresh)
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to query screening DB: {e}")
-
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to query screening DB: {e}",
+ )
+
questions = sr["criteria_parsed"][screening_step]["questions"]
classified = True
decision = "undecided"
@@ -435,17 +554,30 @@ def _get_row_fresh() -> Dict[str, Any]:
if classified and decision == "undecided":
decision = "include"
-
+
col_name = f"{decision_maker}_{screening_step}_decision"
try:
- updated = await run_in_threadpool(cits_dp_service.update_text_column, citation_id, col_name, decision, table_name)
+ updated = await run_in_threadpool(
+ cits_dp_service.update_text_column,
+ citation_id,
+ col_name,
+ decision,
+ table_name,
+ )
except RuntimeError as rexc:
- raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc))
+ raise HTTPException(
+ status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=str(rexc)
+ )
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update citation row: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update citation row: {e}",
+ )
print(updated)
if not updated:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Citation not found to update"
+ )
# Validation rule (B1/B2): do NOT use l1_screen/l2_screen for filtering.
# Ensure human_l1_decision / human_l2_decision are always correct and derived
@@ -489,7 +621,9 @@ def _compute_human_decision(step: str) -> str:
except Exception:
selected = None
# Treat empty/whitespace as unanswered (UI shows "-- select --")
- if selected is None or (isinstance(selected, str) and selected.strip() == ""):
+ if selected is None or (
+ isinstance(selected, str) and selected.strip() == ""
+ ):
return "undecided"
if "exclude" in str(selected).lower():
return "exclude"
@@ -499,8 +633,20 @@ def _compute_human_decision(step: str) -> str:
# Always set both human decisions on any update, so the list filters never go stale.
h1 = _compute_human_decision("l1")
h2 = _compute_human_decision("l2")
- await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "human_l1_decision", h1, table_name)
- await run_in_threadpool(cits_dp_service.update_text_column, citation_id, "human_l2_decision", h2, table_name)
+ await run_in_threadpool(
+ cits_dp_service.update_text_column,
+ citation_id,
+ "human_l1_decision",
+ h1,
+ table_name,
+ )
+ await run_in_threadpool(
+ cits_dp_service.update_text_column,
+ citation_id,
+ "human_l2_decision",
+ h2,
+ table_name,
+ )
except Exception:
# best-effort; do not block response
pass
diff --git a/backend/api/services/azure_docint_client.py b/backend/api/services/azure_docint_client.py
index cf0eb071..be1c3150 100644
--- a/backend/api/services/azure_docint_client.py
+++ b/backend/api/services/azure_docint_client.py
@@ -42,6 +42,7 @@
from typing import Any as DocumentIntelligenceClient # type: ignore
from typing import Any as AnalyzeDocumentRequest # type: ignore
from typing import Any as AzureKeyCredential # type: ignore
+
print(
"Azure Document Intelligence SDK not installed. Install with: pip install azure-ai-documentintelligence azure-core"
)
@@ -63,13 +64,13 @@ def _init_client(self) -> Optional["DocumentIntelligenceClient"]:
"""Initialize Azure Document Intelligence client"""
if not AZURE_DOC_INTELLIGENCE_AVAILABLE:
return None
-
+
if settings.AZURE_DOC_INT_MODE not in ["key", "entra"]:
print(
f"Invalid AZURE_DOC_INT_MODE: {settings.AZURE_DOC_INT_MODE}. Must be 'key' or 'entra'."
)
return None
-
+
if not settings.AZURE_DOC_INT_ENDPOINT:
print(
"Azure Document Intelligence endpoint not found. Set AZURE_DOC_INT_ENDPOINT environment variable."
@@ -81,11 +82,13 @@ def _init_client(self) -> Optional["DocumentIntelligenceClient"]:
"Azure Document Intelligence API key not found. Set AZURE_DOC_INT_API_KEY for key-based auth."
)
return None
-
+
doc_int_kwargs = {"endpoint": settings.AZURE_DOC_INT_ENDPOINT}
if settings.AZURE_DOC_INT_MODE == "key":
- doc_int_kwargs["credential"] = AzureKeyCredential(settings.AZURE_DOC_INT_API_KEY)
+ doc_int_kwargs["credential"] = AzureKeyCredential(
+ settings.AZURE_DOC_INT_API_KEY
+ )
elif settings.AZURE_DOC_INT_MODE == "entra":
doc_int_kwargs["credential"] = DefaultAzureCredential()
@@ -680,7 +683,9 @@ def _extract_html_tables_from_markdown(markdown: str) -> List[str]:
table_pattern = r"
"
return re.findall(table_pattern, markdown, re.DOTALL)
- def _download_figure_bytes_sync(self, result_id: str, figure_id: str) -> Optional[bytes]:
+ def _download_figure_bytes_sync(
+ self, result_id: str, figure_id: str
+ ) -> Optional[bytes]:
"""Download a single figure image from Azure DI as bytes (sync)."""
try:
stream = self.client.get_analyze_result_figure(
@@ -719,7 +724,10 @@ async def extract_citation_artifacts(
"""
if not self.client:
- return {"success": False, "error": "Azure Document Intelligence client not available"}
+ return {
+ "success": False,
+ "error": "Azure Document Intelligence client not available",
+ }
# We need figures in output.
output_param = ["figures"]
@@ -754,7 +762,9 @@ async def extract_citation_artifacts(
for i, md in enumerate(md_tables, start=1):
bbox = None
if i - 1 < len(raw_tables):
- bbox = raw_tables[i - 1].get("boundingRegions") or raw_tables[i - 1].get("bounding_regions")
+ bbox = raw_tables[i - 1].get("boundingRegions") or raw_tables[
+ i - 1
+ ].get("bounding_regions")
tables_out.append(
{
"index": i,
@@ -783,7 +793,7 @@ async def extract_citation_artifacts(
bounding_regions = []
try:
- for region in (getattr(fig, "bounding_regions", None) or []):
+ for region in getattr(fig, "bounding_regions", None) or []:
bounding_regions.append(
{"page_number": region.page_number, "polygon": region.polygon}
)
@@ -792,7 +802,9 @@ async def extract_citation_artifacts(
png_bytes = None
if result_id:
- png_bytes = await asyncio.to_thread(self._download_figure_bytes_sync, result_id, azure_id)
+ png_bytes = await asyncio.to_thread(
+ self._download_figure_bytes_sync, result_id, azure_id
+ )
if not png_bytes:
# If we couldn't download, skip storing bytes (still return metadata)
@@ -815,8 +827,13 @@ async def extract_citation_artifacts(
{
"index": idx,
"azure_id": fig.get("id") or f"raw_{idx}",
- "caption": (fig.get("caption", {}) or {}).get("content") if isinstance(fig.get("caption"), dict) else None,
- "bounding_box": fig.get("boundingRegions") or fig.get("bounding_regions"),
+ "caption": (
+ (fig.get("caption", {}) or {}).get("content")
+ if isinstance(fig.get("caption"), dict)
+ else None
+ ),
+ "bounding_box": fig.get("boundingRegions")
+ or fig.get("bounding_regions"),
"png_bytes": b"",
}
)
@@ -834,4 +851,4 @@ async def extract_citation_artifacts(
try:
azure_docint_client = AzureDocIntelligenceService()
except Exception:
- azure_docint_client = None # type: ignore
\ No newline at end of file
+ azure_docint_client = None # type: ignore
diff --git a/backend/api/services/azure_openai_client.py b/backend/api/services/azure_openai_client.py
index 8f885418..32d894f0 100644
--- a/backend/api/services/azure_openai_client.py
+++ b/backend/api/services/azure_openai_client.py
@@ -101,12 +101,10 @@ def __init__(self):
self._official_clients: Dict[Tuple[str, str, str], AzureOpenAI] = {}
self._official_async_clients: Dict[Tuple[str, str, str], Any] = {}
-
# ---------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------
-
@staticmethod
def _resolve_auth_type() -> str:
"""Return key|entra.
@@ -145,11 +143,15 @@ def _load_models_yaml(self) -> Dict[str, Any]:
try:
data = yaml.safe_load(path.read_text(encoding="utf-8"))
if not isinstance(data, dict):
- logger.warning("Invalid models.yaml format (expected mapping): %s", type(data))
+ logger.warning(
+ "Invalid models.yaml format (expected mapping): %s", type(data)
+ )
return {}
return data
except Exception as e:
- logger.exception("Failed to load Azure OpenAI model catalog from %s: %s", path, e)
+ logger.exception(
+ "Failed to load Azure OpenAI model catalog from %s: %s", path, e
+ )
return {}
def _load_model_configs(self) -> Dict[str, Dict[str, str]]:
@@ -244,7 +246,9 @@ def _get_official_client(self, model: str) -> AzureOpenAI:
endpoint = config.get("endpoint")
api_version = config.get("api_version")
if not endpoint or not api_version:
- raise ValueError(f"Azure OpenAI endpoint/api_version not configured for model {model}")
+ raise ValueError(
+ f"Azure OpenAI endpoint/api_version not configured for model {model}"
+ )
cache_key = (endpoint, api_version, self._auth_type)
if cache_key not in self._official_clients:
@@ -255,12 +259,16 @@ def _get_official_client(self, model: str) -> AzureOpenAI:
if self._auth_type == "entra":
if not self._token_provider:
- raise ValueError(self._config_error or "Azure AD token provider not configured")
+ raise ValueError(
+ self._config_error or "Azure AD token provider not configured"
+ )
azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider
else:
# key auth
if not self._api_key:
- raise ValueError("AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY")
+ raise ValueError(
+ "AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY"
+ )
azure_openai_kwargs["api_key"] = self._api_key
self._official_clients[cache_key] = AzureOpenAI(**azure_openai_kwargs)
@@ -284,7 +292,9 @@ def _get_official_async_client(self, model: str):
endpoint = config.get("endpoint")
api_version = config.get("api_version")
if not endpoint or not api_version:
- raise ValueError(f"Azure OpenAI endpoint/api_version not configured for model {model}")
+ raise ValueError(
+ f"Azure OpenAI endpoint/api_version not configured for model {model}"
+ )
cache_key = (endpoint, api_version, self._auth_type)
if cache_key not in self._official_async_clients:
@@ -295,11 +305,15 @@ def _get_official_async_client(self, model: str):
if self._auth_type == "entra":
if not self._token_provider:
- raise ValueError(self._config_error or "Azure AD token provider not configured")
+ raise ValueError(
+ self._config_error or "Azure AD token provider not configured"
+ )
azure_openai_kwargs["azure_ad_token_provider"] = self._token_provider
else:
if not self._api_key:
- raise ValueError("AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY")
+ raise ValueError(
+ "AZURE_OPENAI_MODE=key requires AZURE_OPENAI_API_KEY"
+ )
azure_openai_kwargs["api_key"] = self._api_key
self._official_async_clients[cache_key] = AsyncAzureOpenAI(**azure_openai_kwargs) # type: ignore
@@ -358,7 +372,7 @@ async def chat_completion(
"presence_penalty": presence_penalty,
"stream": stream,
}
-
+
# gpt-5 deployments may reject temperature/max_tokens in some previews.
# We gate this by the *deployment* name because the UI key can differ.
if deployment != "gpt-5-mini":
@@ -533,7 +547,9 @@ def _worker() -> None:
continue
content = update.choices[0].delta.content or ""
if content:
- asyncio.run_coroutine_threadsafe(q.put(content), loop).result()
+ asyncio.run_coroutine_threadsafe(
+ q.put(content), loop
+ ).result()
except Exception as e:
asyncio.run_coroutine_threadsafe(q.put(e), loop).result()
finally:
@@ -649,7 +665,11 @@ def get_available_models(self) -> List[str]:
"""Get list of available models that are properly configured"""
out: List[str] = []
for model, config in self.model_configs.items():
- if not config.get("endpoint") or not config.get("deployment") or not config.get("api_version"):
+ if (
+ not config.get("endpoint")
+ or not config.get("deployment")
+ or not config.get("api_version")
+ ):
continue
out.append(model)
return out
@@ -663,7 +683,11 @@ def get_available_deployments(self) -> List[str]:
out: List[str] = []
seen: set[str] = set()
for _model, config in self.model_configs.items():
- if not config.get("endpoint") or not config.get("deployment") or not config.get("api_version"):
+ if (
+ not config.get("endpoint")
+ or not config.get("deployment")
+ or not config.get("api_version")
+ ):
continue
dep = str(config.get("deployment") or "").strip()
if not dep:
@@ -695,6 +719,7 @@ def is_configured(self) -> bool:
azure_openai_client = AzureOpenAIClient()
except Exception as e: # pragma: no cover
logger.exception("Failed to initialize AzureOpenAIClient: %s", e)
+
# Provide a stub that reports not-configured.
class _DisabledAzureOpenAIClient: # type: ignore
def is_configured(self) -> bool:
diff --git a/backend/api/services/cit_db_service.py b/backend/api/services/cit_db_service.py
index 85371835..a4e49508 100644
--- a/backend/api/services/cit_db_service.py
+++ b/backend/api/services/cit_db_service.py
@@ -12,6 +12,7 @@
Methods raise RuntimeError when psycopg2 is not available so callers
can surface a 503 with an actionable message.
"""
+
from typing import Any, Dict, List, Optional, Tuple
import psycopg2
import psycopg2.extras
@@ -132,6 +133,7 @@ def _construct_db_dsn_from_admin(admin_dsn: str, db_name: str) -> str:
else:
return f"{admin_dsn} dbname={db_name}"
+
# -----------------------
# Citations Postgres DB service
# -----------------------
@@ -149,11 +151,12 @@ def __init__(self):
# Low level connection helpers
# -----------------------
-
# -----------------------
# Generic column ops
# -----------------------
- def create_column(self, col: str, col_type: str, table_name: str = "citations") -> None:
+ def create_column(
+ self, col: str, col_type: str, table_name: str = "citations"
+ ) -> None:
"""
Create column on citations table if it doesn't already exist.
col should be the exact column name to use (caller may pass snake_case(col)).
@@ -165,11 +168,15 @@ def create_column(self, col: str, col_type: str, table_name: str = "citations")
conn = postgres_server.conn
cur = conn.cursor()
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" {col_type}')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" {col_type}'
+ )
except Exception:
# fallback for PG versions without IF NOT EXISTS
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" {col_type}')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" {col_type}'
+ )
except Exception:
pass
conn.commit()
@@ -198,16 +205,21 @@ def update_jsonb_column(
conn = postgres_server.conn
cur = conn.cursor()
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" JSONB')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" JSONB'
+ )
except Exception:
try:
cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" JSONB')
except Exception:
pass
- cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (json.dumps(data), int(citation_id)))
+ cur.execute(
+ f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s',
+ (json.dumps(data), int(citation_id)),
+ )
rows = cur.rowcount
conn.commit()
-
+
return rows or 0
except Exception:
_safe_rollback(conn)
@@ -232,13 +244,18 @@ def update_text_column(
conn = postgres_server.conn
cur = conn.cursor()
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" TEXT')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" TEXT'
+ )
except Exception:
try:
cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" TEXT')
except Exception:
pass
- cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (text_value, int(citation_id)))
+ cur.execute(
+ f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s',
+ (text_value, int(citation_id)),
+ )
rows = cur.rowcount
conn.commit()
@@ -264,13 +281,20 @@ def update_bool_column(
conn = postgres_server.conn
cur = conn.cursor()
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" BOOLEAN')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{col}" BOOLEAN'
+ )
except Exception:
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" BOOLEAN')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN "{col}" BOOLEAN'
+ )
except Exception:
pass
- cur.execute(f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s', (bool(bool_value), int(citation_id)))
+ cur.execute(
+ f'UPDATE "{table_name}" SET "{col}" = %s WHERE id = %s',
+ (bool(bool_value), int(citation_id)),
+ )
rows = cur.rowcount
conn.commit()
return rows or 0
@@ -311,7 +335,9 @@ def get_table_columns(self, table_name: str = "citations") -> List[Dict[str, str
if conn:
pass
- def clear_columns(self, citation_id: int, columns: List[str], table_name: str = "citations") -> int:
+ def clear_columns(
+ self, citation_id: int, columns: List[str], table_name: str = "citations"
+ ) -> int:
"""Set provided columns to NULL for a citation. Ignores unknown columns."""
table_name = _validate_ident(table_name, kind="table_name")
if not columns:
@@ -327,7 +353,10 @@ def clear_columns(self, citation_id: int, columns: List[str], table_name: str =
conn = postgres_server.conn
cur = conn.cursor()
set_sql = ", ".join([f'"{c}" = NULL' for c in cols])
- cur.execute(f'UPDATE "{table_name}" SET {set_sql} WHERE id = %s', (int(citation_id),))
+ cur.execute(
+ f'UPDATE "{table_name}" SET {set_sql} WHERE id = %s',
+ (int(citation_id),),
+ )
rows = cur.rowcount
conn.commit()
return rows or 0
@@ -338,7 +367,9 @@ def clear_columns(self, citation_id: int, columns: List[str], table_name: str =
if conn:
pass
- def clear_columns_by_prefix(self, citation_id: int, prefixes: List[str], table_name: str = "citations") -> int:
+ def clear_columns_by_prefix(
+ self, citation_id: int, prefixes: List[str], table_name: str = "citations"
+ ) -> int:
"""Set all columns matching any prefix to NULL for a citation."""
prefixes = [p for p in (prefixes or []) if isinstance(p, str) and p]
if not prefixes:
@@ -374,10 +405,14 @@ def copy_jsonb_if_empty(
cur = conn.cursor()
# Ensure destination column exists as JSONB
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{dst_col}" JSONB')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "{dst_col}" JSONB'
+ )
except Exception:
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN "{dst_col}" JSONB')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN "{dst_col}" JSONB'
+ )
except Exception:
pass
@@ -420,8 +455,6 @@ def dump_citations_csv(self, table_name: str = "citations") -> bytes:
)
csv_text = buf.getvalue()
-
-
return csv_text.encode("utf-8")
except Exception:
_safe_rollback(conn)
@@ -473,7 +506,9 @@ def dump_citations_csv_filtered(self, table_name: str = "citations") -> bytes:
conn = postgres_server.conn
cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
select_cols = base_cols + jsonb_cols
- select_sql = ", ".join([f'"{c}"' for c in select_cols]) if select_cols else "*"
+ select_sql = (
+ ", ".join([f'"{c}"' for c in select_cols]) if select_cols else "*"
+ )
cur.execute(f'SELECT {select_sql} FROM "{table_name}" ORDER BY id')
rows = cur.fetchall() or []
@@ -514,7 +549,9 @@ def _flatten_keys_for(col: str) -> List[str]:
# Fallback
return [f"{col}__json"]
- json_flat_cols: Dict[str, List[str]] = {c: _flatten_keys_for(c) for c in jsonb_cols}
+ json_flat_cols: Dict[str, List[str]] = {
+ c: _flatten_keys_for(c) for c in jsonb_cols
+ }
for c in jsonb_cols:
out_cols.extend(json_flat_cols[c])
@@ -573,9 +610,15 @@ def _as_str(v: Any) -> str:
out[f"{c}__confidence"] = _as_str(parsed.get("confidence"))
out[f"{c}__explanation"] = _as_str(parsed.get("explanation"))
- out[f"{c}__evidence_sentences"] = _as_str(parsed.get("evidence_sentences"))
- out[f"{c}__evidence_tables"] = _as_str(parsed.get("evidence_tables"))
- out[f"{c}__evidence_figures"] = _as_str(parsed.get("evidence_figures"))
+ out[f"{c}__evidence_sentences"] = _as_str(
+ parsed.get("evidence_sentences")
+ )
+ out[f"{c}__evidence_tables"] = _as_str(
+ parsed.get("evidence_tables")
+ )
+ out[f"{c}__evidence_figures"] = _as_str(
+ parsed.get("evidence_figures")
+ )
out[f"{c}__autofilled"] = _as_str(parsed.get("autofilled"))
out[f"{c}__source"] = _as_str(parsed.get("source"))
out[f"{c}__timestamp"] = _as_str(parsed.get("timestamp"))
@@ -591,7 +634,9 @@ def _as_str(v: Any) -> str:
if conn:
pass
- def get_citation_by_id(self, citation_id: int, table_name: str = "citations") -> Optional[Dict[str, Any]]:
+ def get_citation_by_id(
+ self, citation_id: int, table_name: str = "citations"
+ ) -> Optional[Dict[str, Any]]:
"""
Return a dict mapping column -> value for the citation row, or None.
"""
@@ -655,10 +700,14 @@ def get_citations_by_ids(
select_sql = "*"
if fields:
try:
- existing_cols = {c.get("column_name") for c in self.get_table_columns(table_name)}
+ existing_cols = {
+ c.get("column_name") for c in self.get_table_columns(table_name)
+ }
except Exception:
existing_cols = set()
- safe_fields = [f for f in fields if isinstance(f, str) and f in existing_cols]
+ safe_fields = [
+ f for f in fields if isinstance(f, str) and f in existing_cols
+ ]
if safe_fields:
select_sql = ", ".join([f'"{c}"' for c in safe_fields])
@@ -679,7 +728,9 @@ def get_citations_by_ids(
if conn:
pass
- def backfill_human_decisions(self, criteria_parsed: Dict[str, Any], table_name: str = "citations") -> int:
+ def backfill_human_decisions(
+ self, criteria_parsed: Dict[str, Any], table_name: str = "citations"
+ ) -> int:
"""Recompute and persist human_l1_decision / human_l2_decision for all rows.
This is used to ensure decision columns are never stale when the UI fetches
@@ -693,8 +744,16 @@ def backfill_human_decisions(self, criteria_parsed: Dict[str, Any], table_name:
table_name = _validate_ident(table_name, kind="table_name")
cp = criteria_parsed or {}
- l1_qs = (cp.get("l1") or {}).get("questions") if isinstance(cp.get("l1"), dict) else None
- l2_qs = (cp.get("l2") or {}).get("questions") if isinstance(cp.get("l2"), dict) else None
+ l1_qs = (
+ (cp.get("l1") or {}).get("questions")
+ if isinstance(cp.get("l1"), dict)
+ else None
+ )
+ l2_qs = (
+ (cp.get("l2") or {}).get("questions")
+ if isinstance(cp.get("l2"), dict)
+ else None
+ )
l1_qs = l1_qs if isinstance(l1_qs, list) else []
l2_qs = l2_qs if isinstance(l2_qs, list) else []
@@ -728,7 +787,9 @@ def _human_col(q: str) -> str:
# If we try to SELECT a non-existent human_* column, the query fails and the
# caller silently skips the backfill, leaving stale decision columns.
try:
- existing_cols = {c.get("column_name") for c in self.get_table_columns(table_name)}
+ existing_cols = {
+ c.get("column_name") for c in self.get_table_columns(table_name)
+ }
except Exception:
existing_cols = set()
@@ -773,7 +834,9 @@ def _compute(step_qs: List[str], row: Dict[str, Any]) -> str:
hobj = _parse_jsonb(hval)
selected = hobj.get("selected")
# Treat empty/whitespace as unanswered (UI shows "-- select --")
- if selected is None or (isinstance(selected, str) and selected.strip() == ""):
+ if selected is None or (
+ isinstance(selected, str) and selected.strip() == ""
+ ):
return "undecided"
if "exclude" in str(selected).lower():
return "exclude"
@@ -810,7 +873,9 @@ def _compute(step_qs: List[str], row: Dict[str, Any]) -> str:
if conn:
pass
- def list_citation_ids(self, filter_step=None, table_name: str = "citations") -> List[int]:
+ def list_citation_ids(
+ self, filter_step=None, table_name: str = "citations"
+ ) -> List[int]:
"""
Return list of integer primary keys (id) from citations table ordered by id.
"""
@@ -829,7 +894,9 @@ def list_citation_ids(self, filter_step=None, table_name: str = "citations") ->
# Validation rule (B1/B2): Full-text list is driven by the human L1 decision.
# Do NOT use l1_screen/l2_screen booleans.
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l1_decision" TEXT')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l1_decision" TEXT'
+ )
except Exception:
pass
cur.execute(
@@ -838,7 +905,9 @@ def list_citation_ids(self, filter_step=None, table_name: str = "citations") ->
elif step == "l2":
# Validation rule (B1/B2): Extract list is driven by the human L2 decision.
try:
- cur.execute(f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l2_decision" TEXT')
+ cur.execute(
+ f'ALTER TABLE "{table_name}" ADD COLUMN IF NOT EXISTS "human_l2_decision" TEXT'
+ )
except Exception:
pass
cur.execute(
@@ -866,7 +935,9 @@ def list_fulltext_urls(self, table_name: str = "citations") -> List[str]:
try:
conn = postgres_server.conn
cur = conn.cursor()
- cur.execute(f'SELECT fulltext_url FROM "{table_name}" WHERE fulltext_url IS NOT NULL')
+ cur.execute(
+ f'SELECT fulltext_url FROM "{table_name}" WHERE fulltext_url IS NOT NULL'
+ )
rows = cur.fetchall()
return [r[0] for r in rows if r and r[0]]
@@ -924,7 +995,9 @@ def attach_fulltext(
# -----------------------
# Column get/set helpers
# -----------------------
- def get_column_value(self, citation_id: int, column: str, table_name: str = "citations") -> Any:
+ def get_column_value(
+ self, citation_id: int, column: str, table_name: str = "citations"
+ ) -> Any:
"""
Return the value stored in `column` for the citation row (or None).
"""
@@ -936,7 +1009,9 @@ def get_column_value(self, citation_id: int, column: str, table_name: str = "cit
cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
except Exception:
cur = conn.cursor()
- cur.execute(f'SELECT "{column}" FROM "{table_name}" WHERE id = %s', (citation_id,))
+ cur.execute(
+ f'SELECT "{column}" FROM "{table_name}" WHERE id = %s', (citation_id,)
+ )
row = cur.fetchone()
if not row:
return None
@@ -954,13 +1029,20 @@ def get_column_value(self, citation_id: int, column: str, table_name: str = "cit
if conn:
pass
- def set_column_value(self, citation_id: int, column: str, value: Any, table_name: str = "citations") -> int:
+ def set_column_value(
+ self, citation_id: int, column: str, value: Any, table_name: str = "citations"
+ ) -> int:
"""
Generic setter for a citation row column. Will create a TEXT column if it doesn't exist.
"""
# For simplicity, create a TEXT column. Callers that need JSONB should use update_jsonb_column.
self.create_column(column, "TEXT", table_name=table_name)
- return self.update_text_column(citation_id, column, value if value is not None else None, table_name=table_name)
+ return self.update_text_column(
+ citation_id,
+ column,
+ value if value is not None else None,
+ table_name=table_name,
+ )
# -----------------------
# Per-upload table lifecycle helpers
@@ -1019,7 +1101,12 @@ def create_table_and_insert_sync(
inserted = 0
if rows:
safe_cols = [snake_case(c) for c in columns]
- insert_cols = [f'"{c}"' for c in safe_cols] + ['"cit_id"', '"fulltext_url"', '"fulltext"', '"fulltext_md5"']
+ insert_cols = [f'"{c}"' for c in safe_cols] + [
+ '"cit_id"',
+ '"fulltext_url"',
+ '"fulltext"',
+ '"fulltext_md5"',
+ ]
placeholders = ", ".join(["%s"] * len(insert_cols))
insert_sql = f'INSERT INTO "{table_name}" ({", ".join(insert_cols)}) VALUES ({placeholders})'
@@ -1034,11 +1121,26 @@ def _row_has_data(row: dict) -> bool:
values = []
for r in filtered_rows:
- row_vals = [r.get(orig_col) if r.get(orig_col) is not None else None for orig_col in columns]
- row_vals.append(r.get("cit_id") if r.get("cit_id") is not None else None)
- row_vals.append(r.get("fulltext_url") if r.get("fulltext_url") is not None else None)
- row_vals.append(r.get("fulltext") if r.get("fulltext") is not None else None)
- row_vals.append(r.get("fulltext_md5") if r.get("fulltext_md5") is not None else None)
+ row_vals = [
+ r.get(orig_col) if r.get(orig_col) is not None else None
+ for orig_col in columns
+ ]
+ row_vals.append(
+ r.get("cit_id") if r.get("cit_id") is not None else None
+ )
+ row_vals.append(
+ r.get("fulltext_url")
+ if r.get("fulltext_url") is not None
+ else None
+ )
+ row_vals.append(
+ r.get("fulltext") if r.get("fulltext") is not None else None
+ )
+ row_vals.append(
+ r.get("fulltext_md5")
+ if r.get("fulltext_md5") is not None
+ else None
+ )
values.append(tuple(row_vals))
if values:
@@ -1058,7 +1160,9 @@ def _row_has_data(row: dict) -> bool:
# NOTE: legacy per-database helpers (drop_database, create_db_and_table_sync) were
# intentionally removed in favor of per-upload tables in a shared database.
- def load_include_columns_from_criteria(self, sr_doc: Optional[Dict[str, Any]] = None) -> List[str]:
+ def load_include_columns_from_criteria(
+ self, sr_doc: Optional[Dict[str, Any]] = None
+ ) -> List[str]:
"""
Load the 'include' list for L1 screening.
Mirrors logic previously embedded in the citations router but kept here
@@ -1080,7 +1184,13 @@ def load_include_columns_from_criteria(self, sr_doc: Optional[Dict[str, Any]] =
pass
# 2) fallback to project file
- cfg_path = os.path.join(os.path.dirname(__file__), "..", "sr_setup", "configs", "criteria_config_measles_updated.yaml")
+ cfg_path = os.path.join(
+ os.path.dirname(__file__),
+ "..",
+ "sr_setup",
+ "configs",
+ "criteria_config_measles_updated.yaml",
+ )
cfg_path = os.path.normpath(cfg_path)
try:
import yaml
@@ -1096,7 +1206,9 @@ def load_include_columns_from_criteria(self, sr_doc: Optional[Dict[str, Any]] =
except Exception:
return []
- def build_combined_citation_from_row(self, row: Dict[str, Any], include_columns: List[str]) -> str:
+ def build_combined_citation_from_row(
+ self, row: Dict[str, Any], include_columns: List[str]
+ ) -> str:
parts: List[str] = []
if not row:
return ""
diff --git a/backend/api/services/citation_search/europePMC_citation_collection.py b/backend/api/services/citation_search/europePMC_citation_collection.py
index f9842a21..8b62d15b 100644
--- a/backend/api/services/citation_search/europePMC_citation_collection.py
+++ b/backend/api/services/citation_search/europePMC_citation_collection.py
@@ -5,22 +5,24 @@
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
-logging.basicConfig(level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
logger = logging.getLogger("grep-exp-pubmed-citations")
+
class EuropePMCCitationCollector:
def __init__(self):
self.base_url = "https://www.ebi.ac.uk/europepmc/webservices/rest/search"
self.citations = []
-
+
# Set up retry-capable session
self.session = requests.Session()
retry_strategy = Retry(
total=3,
backoff_factor=0.3,
status_forcelist=[429, 500, 502, 503, 504],
- allowed_methods=["GET", "HEAD"]
+ allowed_methods=["GET", "HEAD"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.session.mount("http://", adapter)
@@ -33,9 +35,9 @@ def fetch_data(self, search_term, cursor_mark="*"):
"query": search_term,
"pageSize": 100,
"format": "json",
- "cursorMark": cursor_mark
+ "cursorMark": cursor_mark,
}
-
+
response = self.session.get(self.base_url, params=params, timeout=30)
response.raise_for_status()
return response.json()
@@ -57,115 +59,140 @@ def extract_citation_data(self, entry):
year = int(entry.get("pubYear"))
except (ValueError, TypeError):
year = None
-
+
# Extract keywords safely
keywords = []
keyword_list = entry.get("keywordList")
if isinstance(keyword_list, dict) and "keyword" in keyword_list:
- keywords = keyword_list["keyword"] if isinstance(keyword_list["keyword"], list) else [keyword_list["keyword"]]
-
+ keywords = (
+ keyword_list["keyword"]
+ if isinstance(keyword_list["keyword"], list)
+ else [keyword_list["keyword"]]
+ )
+
# Extract publication types safely
pub_types = []
pub_type_list = entry.get("pubTypeList")
if isinstance(pub_type_list, dict) and "pubType" in pub_type_list:
- pub_types = pub_type_list["pubType"] if isinstance(pub_type_list["pubType"], list) else [pub_type_list["pubType"]]
-
+ pub_types = (
+ pub_type_list["pubType"]
+ if isinstance(pub_type_list["pubType"], list)
+ else [pub_type_list["pubType"]]
+ )
+
# Extract full text URLs safely
full_text_urls = []
full_text_url_list = entry.get("fullTextUrlList")
- if isinstance(full_text_url_list, dict) and "fullTextUrl" in full_text_url_list:
+ if (
+ isinstance(full_text_url_list, dict)
+ and "fullTextUrl" in full_text_url_list
+ ):
urls = full_text_url_list["fullTextUrl"]
if isinstance(urls, list):
- full_text_urls = [url.get("url") for url in urls if isinstance(url, dict) and url.get("url")]
+ full_text_urls = [
+ url.get("url")
+ for url in urls
+ if isinstance(url, dict) and url.get("url")
+ ]
elif isinstance(urls, dict):
if urls.get("url"):
full_text_urls = [urls["url"]]
-
+
citation_data = {
- 'source': 'Europe PMC',
- 'id': entry.get("id"),
- 'pmid': entry.get("pmid"),
- 'pmcid': entry.get("pmcid"),
- 'title': entry.get("title"),
- 'authors': entry.get("authorList", {}).get("author", []) if isinstance(entry.get("authorList"), dict) else [],
- 'author_string': entry.get("authorString"),
- 'abstract': entry.get("abstractText"),
- 'keywords': keywords,
- 'doi': entry.get("doi"),
- 'journal': entry.get("journalTitle"),
- 'issn': entry.get("journalIssn"),
- 'volume': entry.get("journalVolume"),
- 'issue': entry.get("issue"),
- 'pages': entry.get("pageInfo"),
- 'publication_types': pub_types,
- 'date': entry.get("firstPublicationDate"),
- 'year': year,
- 'language': entry.get("language"),
- 'full_text_urls': full_text_urls,
- 'is_open_access': entry.get("isOpenAccess"),
- 'has_pdf': entry.get("hasPDF"),
- 'pdf_attempted': False,
- 'pdf_success': False,
- 'pdf_url': None,
- 'pdf_error': None
+ "source": "Europe PMC",
+ "id": entry.get("id"),
+ "pmid": entry.get("pmid"),
+ "pmcid": entry.get("pmcid"),
+ "title": entry.get("title"),
+ "authors": (
+ entry.get("authorList", {}).get("author", [])
+ if isinstance(entry.get("authorList"), dict)
+ else []
+ ),
+ "author_string": entry.get("authorString"),
+ "abstract": entry.get("abstractText"),
+ "keywords": keywords,
+ "doi": entry.get("doi"),
+ "journal": entry.get("journalTitle"),
+ "issn": entry.get("journalIssn"),
+ "volume": entry.get("journalVolume"),
+ "issue": entry.get("issue"),
+ "pages": entry.get("pageInfo"),
+ "publication_types": pub_types,
+ "date": entry.get("firstPublicationDate"),
+ "year": year,
+ "language": entry.get("language"),
+ "full_text_urls": full_text_urls,
+ "is_open_access": entry.get("isOpenAccess"),
+ "has_pdf": entry.get("hasPDF"),
+ "pdf_attempted": False,
+ "pdf_success": False,
+ "pdf_url": None,
+ "pdf_error": None,
}
-
+
return citation_data
-
+
except Exception as e:
- logger.error(f"Error extracting citation data from Europe PMC entry {entry.get('id', 'Unknown ID')}: {e}")
+ logger.error(
+ f"Error extracting citation data from Europe PMC entry {entry.get('id', 'Unknown ID')}: {e}"
+ )
return None
def collect_citations(self, search_term, max_articles=1000):
"""Main method to collect citations from Europe PMC"""
logger.info(f"Starting Europe PMC citation collection")
logger.info(f"Search query: {search_term[:100]}...")
-
+
try:
cursor_mark = "*"
total_processed = 0
next_cursor_mark = None
self.citations = []
-
+
# Continue until we reach max_articles or no more results
while cursor_mark != next_cursor_mark and total_processed < max_articles:
# Fetch batch
response_data = self.fetch_data(search_term, cursor_mark)
entries = response_data.get("resultList", {}).get("result", [])
-
+
if not entries:
logger.info("No more results found or API error")
break
-
+
# Update cursor for next request
next_cursor_mark = response_data.get("nextCursorMark")
-
+
# Log progress
batch_size = len(entries)
logger.info(f"Retrieved batch of {batch_size} articles. Processing...")
-
+
# Process entries
for i, entry in enumerate(entries):
if total_processed >= max_articles:
- logger.info(f"Reached maximum number of articles ({max_articles})")
+ logger.info(
+ f"Reached maximum number of articles ({max_articles})"
+ )
break
-
+
try:
article_num = total_processed + i + 1
- logger.info(f"Processing article {article_num}: {entry.get('title', 'Unknown title')}")
-
+ logger.info(
+ f"Processing article {article_num}: {entry.get('title', 'Unknown title')}"
+ )
+
citation_data = self.extract_citation_data(entry)
if citation_data:
self.citations.append(citation_data)
-
+
except Exception as e:
logger.error(f"Error processing entry: {e}")
continue
-
+
# Update for next iteration
total_processed += len(entries)
logger.info(f"Processed {total_processed} articles so far")
-
+
# Update cursor for next request
if next_cursor_mark and next_cursor_mark != cursor_mark:
cursor_mark = next_cursor_mark
@@ -174,10 +201,12 @@ def collect_citations(self, search_term, max_articles=1000):
time.sleep(2)
else:
break
-
- logger.info(f"Completed processing. Collected {len(self.citations)} citations.")
+
+ logger.info(
+ f"Completed processing. Collected {len(self.citations)} citations."
+ )
return self.citations
-
+
except Exception as e:
logger.error(f"Unexpected error in citation collection: {e}")
- return []
\ No newline at end of file
+ return []
diff --git a/backend/api/services/citation_search/pubmed_citation_collection.py b/backend/api/services/citation_search/pubmed_citation_collection.py
index 38449de1..ef7bf4ed 100644
--- a/backend/api/services/citation_search/pubmed_citation_collection.py
+++ b/backend/api/services/citation_search/pubmed_citation_collection.py
@@ -5,6 +5,7 @@
import logging
import string
import time
+
# from unidecode import unidecode
from cleantext import clean
from datetime import datetime
@@ -14,240 +15,259 @@
ENTREZ_EMAIL = settings.ENTREZ_EMAIL
ENTREZ_API_KEY = settings.ENTREZ_API_KEY
# Configure logging
-logging.basicConfig(level=logging.INFO,
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+logging.basicConfig(
+ level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
logger = logging.getLogger("grep-exp-pubmed-citations")
class PubMedCitationCollector:
- def __init__(self):
- self.citations = []
-
- def extract_citation_data(self, paper):
- """Extract citation metadata from a PubMed article"""
- try:
- # Basic information
- title = paper['MedlineCitation']['Article'].get('ArticleTitle', '')
- if title:
- title = clean(title.rstrip("."), lower=False, fix_unicode=True)
-
- pmid = str(paper['MedlineCitation']['PMID'])
-
- # Authors
- authors = []
- if 'AuthorList' in paper['MedlineCitation']['Article']:
- for author in paper['MedlineCitation']['Article']['AuthorList']:
- last_name = author.get('LastName', '')
- fore_name = author.get('ForeName', '')
- full_name = ' '.join(filter(None, [fore_name, last_name]))
- if full_name:
- authors.append(full_name)
-
- # Keywords
- keywords = []
- if 'KeywordList' in paper['MedlineCitation'] and paper['MedlineCitation']['KeywordList']:
- keywords = [str(kw) for kw in paper['MedlineCitation']['KeywordList'][0]]
-
- # DOI
- doi = None
- for e_location in paper['MedlineCitation']['Article'].get('ELocationID', []):
- if e_location.attributes['EIdType'] == 'doi':
- doi = str(e_location)
- break
-
- # If no DOI found in article, try CrossRef
- if doi is None:
- doi = self.find_doi_crossref(title)
-
- # Date
- article_date = paper['MedlineCitation']['Article'].get('ArticleDate', [])
- date = None
- if article_date:
- date = f"{article_date[0]['Year']}/{article_date[0]['Month']}/{article_date[0]['Day']}"
-
- # Abstract
- abstract_texts = paper['MedlineCitation']['Article'].get('Abstract', {}).get('AbstractText', [])
- abstract = ' '.join([str(text) for text in abstract_texts])
-
- # Journal info
- journal_info = paper['MedlineCitation']['Article']['Journal']
-
- # Publication Type
- pt = [str(pt) for pt in paper['MedlineCitation']['Article']['PublicationTypeList']]
-
- # Compile citation data
- citation_data = {
- 'source': 'PubMed',
- 'id': pmid,
- 'title': title,
- 'authors': ', '.join(authors),
- 'abstract': abstract,
- 'keywords': ', '.join(keywords),
- 'doi': doi,
- 'pmid': pmid,
- 'journal': journal_info['Title'],
- 'issn': ''.join(journal_info.get('ISSN', [])),
- 'publication_types': ', '.join(pt),
- 'date': date,
- 'year': paper['MedlineCitation']['Article']['Journal']['JournalIssue']['PubDate'].get('Year'),
- 'volume': journal_info['JournalIssue'].get('Volume'),
- 'issue': journal_info['JournalIssue'].get('Issue'),
- 'language': ', '.join(paper['MedlineCitation']['Article'].get('Language', [])),
- 'pdf_attempted': False,
- 'pdf_success': False,
- 'pdf_url': None,
- 'pdf_error': None
- }
-
- return citation_data
-
- except Exception as e:
- logger.error(f"Error extracting citation data: {e}")
- return None
-
- def collect_citations(self, search_term, mindate=None, maxdate=None, max_articles=1000):
- """Main method to collect citations from PubMed"""
-
- Entrez.email = ENTREZ_EMAIL
- Entrez.api_key = ENTREZ_API_KEY
-
-
- logger.info(f"Starting PubMed citation collection")
- logger.info(f"Search term: {search_term[:100]}...")
- if mindate and maxdate:
- logger.info(f"Date range: {mindate} to {maxdate}")
- elif mindate:
- logger.info(f"Date range: {mindate} to present")
- elif maxdate:
- logger.info(f"Date range: earliest to {maxdate}")
- else:
- logger.info("No date range specified")
-
- search_params = {
- 'db': 'pubmed',
- 'sort': 'pub_date',
- 'retmode': 'xml',
- 'retmax': max_articles,
- 'term': search_term,
- }
-
- if mindate:
- search_params['mindate'] = mindate
- if maxdate:
- search_params['maxdate'] = maxdate
-
- try:
- handle = Entrez.esearch(**search_params)
- pmid_list = Entrez.read(handle)
- pmid_list = pmid_list.get('IdList', [])
- except Exception as e:
- logger.error(f"Error searching PubMed: {e}")
- pmid_list = []
-
- # Search PubMed
- # pmid_list = search_pubmed(search_term, mindate, maxdate, max_articles)
-
- if not pmid_list:
- logger.warning("No PMIDs found")
- return []
-
- total_pmids = len(pmid_list)
- logger.info(f"Found {total_pmids} PMIDs")
-
- # Limit to max_articles
- if max_articles and total_pmids > max_articles:
- logger.info(f"Limiting to {max_articles} articles")
- pmid_list = pmid_list[:max_articles]
-
- # Fetch article details
- logger.info(f"Fetching details for {len(pmid_list)} articles")
- papers = self.fetch_details_batch(pmid_list)
-
- # Extract citation data
- self.citations = []
- for i, paper in enumerate(papers):
- try:
- # logger.info(f"Processing article {i+1}/{len(papers)}")
- citation_data = self.extract_citation_data(paper)
- if citation_data:
- self.citations.append(citation_data)
-
- except Exception as e:
- logger.error(f"Error processing paper {i}: {e}")
- continue
-
- logger.info(f"Successfully collected {len(self.citations)} citations")
- return self.citations
-
- def find_doi_crossref(self, title):
- """Find DOI using CrossRef API"""
- if not title:
- return None
-
- url = "https://api.crossref.org/works"
- params = {
- 'query.title': title,
- 'rows': 1
- }
-
- try:
- response = requests.get(url, params=params, timeout=10)
-
- if response.status_code == 200:
- data = response.json()
- items = data['message']['items']
-
- if items:
- first_item = items[0]
- result_title = first_item['title'][0]
-
- if self.clean_title(result_title) == self.clean_title(title):
- return first_item['DOI']
-
- return None
- except Exception as e:
- logger.warning(f"Error finding DOI for title: {e}")
- return None
-
- def fetch_details_batch(self, id_list, batch_size=50):
- """Fetch details for PubMed IDs in batches"""
- if not id_list:
- return []
-
- Entrez.email = ENTREZ_EMAIL
- Entrez.api_key = ENTREZ_API_KEY
-
- all_results = []
- print(id_list)
- print(f"Number of IDs: {len(id_list)}")
-
- # id_list = list(id_list)
-
- for i in range(0, len(id_list), batch_size):
- batch_ids = id_list[i:i+batch_size]
- ids = ','.join(batch_ids)
-
+ def __init__(self):
+ self.citations = []
+
+ def extract_citation_data(self, paper):
+ """Extract citation metadata from a PubMed article"""
+ try:
+ # Basic information
+ title = paper["MedlineCitation"]["Article"].get("ArticleTitle", "")
+ if title:
+ title = clean(title.rstrip("."), lower=False, fix_unicode=True)
+
+ pmid = str(paper["MedlineCitation"]["PMID"])
+
+ # Authors
+ authors = []
+ if "AuthorList" in paper["MedlineCitation"]["Article"]:
+ for author in paper["MedlineCitation"]["Article"]["AuthorList"]:
+ last_name = author.get("LastName", "")
+ fore_name = author.get("ForeName", "")
+ full_name = " ".join(filter(None, [fore_name, last_name]))
+ if full_name:
+ authors.append(full_name)
+
+ # Keywords
+ keywords = []
+ if (
+ "KeywordList" in paper["MedlineCitation"]
+ and paper["MedlineCitation"]["KeywordList"]
+ ):
+ keywords = [
+ str(kw) for kw in paper["MedlineCitation"]["KeywordList"][0]
+ ]
+
+ # DOI
+ doi = None
+ for e_location in paper["MedlineCitation"]["Article"].get(
+ "ELocationID", []
+ ):
+ if e_location.attributes["EIdType"] == "doi":
+ doi = str(e_location)
+ break
+
+ # If no DOI found in article, try CrossRef
+ if doi is None:
+ doi = self.find_doi_crossref(title)
+
+ # Date
+ article_date = paper["MedlineCitation"]["Article"].get("ArticleDate", [])
+ date = None
+ if article_date:
+ date = f"{article_date[0]['Year']}/{article_date[0]['Month']}/{article_date[0]['Day']}"
+
+ # Abstract
+ abstract_texts = (
+ paper["MedlineCitation"]["Article"]
+ .get("Abstract", {})
+ .get("AbstractText", [])
+ )
+ abstract = " ".join([str(text) for text in abstract_texts])
+
+ # Journal info
+ journal_info = paper["MedlineCitation"]["Article"]["Journal"]
+
+ # Publication Type
+ pt = [
+ str(pt)
+ for pt in paper["MedlineCitation"]["Article"]["PublicationTypeList"]
+ ]
+
+ # Compile citation data
+ citation_data = {
+ "source": "PubMed",
+ "id": pmid,
+ "title": title,
+ "authors": ", ".join(authors),
+ "abstract": abstract,
+ "keywords": ", ".join(keywords),
+ "doi": doi,
+ "pmid": pmid,
+ "journal": journal_info["Title"],
+ "issn": "".join(journal_info.get("ISSN", [])),
+ "publication_types": ", ".join(pt),
+ "date": date,
+ "year": paper["MedlineCitation"]["Article"]["Journal"]["JournalIssue"][
+ "PubDate"
+ ].get("Year"),
+ "volume": journal_info["JournalIssue"].get("Volume"),
+ "issue": journal_info["JournalIssue"].get("Issue"),
+ "language": ", ".join(
+ paper["MedlineCitation"]["Article"].get("Language", [])
+ ),
+ "pdf_attempted": False,
+ "pdf_success": False,
+ "pdf_url": None,
+ "pdf_error": None,
+ }
+
+ return citation_data
+
+ except Exception as e:
+ logger.error(f"Error extracting citation data: {e}")
+ return None
+
+ def collect_citations(
+ self, search_term, mindate=None, maxdate=None, max_articles=1000
+ ):
+ """Main method to collect citations from PubMed"""
+
+ Entrez.email = ENTREZ_EMAIL
+ Entrez.api_key = ENTREZ_API_KEY
+
+ logger.info(f"Starting PubMed citation collection")
+ logger.info(f"Search term: {search_term[:100]}...")
+ if mindate and maxdate:
+ logger.info(f"Date range: {mindate} to {maxdate}")
+ elif mindate:
+ logger.info(f"Date range: {mindate} to present")
+ elif maxdate:
+ logger.info(f"Date range: earliest to {maxdate}")
+ else:
+ logger.info("No date range specified")
+
+ search_params = {
+ "db": "pubmed",
+ "sort": "pub_date",
+ "retmode": "xml",
+ "retmax": max_articles,
+ "term": search_term,
+ }
+
+ if mindate:
+ search_params["mindate"] = mindate
+ if maxdate:
+ search_params["maxdate"] = maxdate
+
+ try:
+ handle = Entrez.esearch(**search_params)
+ pmid_list = Entrez.read(handle)
+ pmid_list = pmid_list.get("IdList", [])
+ except Exception as e:
+ logger.error(f"Error searching PubMed: {e}")
+ pmid_list = []
+
+ # Search PubMed
+ # pmid_list = search_pubmed(search_term, mindate, maxdate, max_articles)
+
+ if not pmid_list:
+ logger.warning("No PMIDs found")
+ return []
+
+ total_pmids = len(pmid_list)
+ logger.info(f"Found {total_pmids} PMIDs")
+
+ # Limit to max_articles
+ if max_articles and total_pmids > max_articles:
+ logger.info(f"Limiting to {max_articles} articles")
+ pmid_list = pmid_list[:max_articles]
+
+ # Fetch article details
+ logger.info(f"Fetching details for {len(pmid_list)} articles")
+ papers = self.fetch_details_batch(pmid_list)
+
+ # Extract citation data
+ self.citations = []
+ for i, paper in enumerate(papers):
+ try:
+ # logger.info(f"Processing article {i+1}/{len(papers)}")
+ citation_data = self.extract_citation_data(paper)
+ if citation_data:
+ self.citations.append(citation_data)
+
+ except Exception as e:
+ logger.error(f"Error processing paper {i}: {e}")
+ continue
+
+ logger.info(f"Successfully collected {len(self.citations)} citations")
+ return self.citations
+
+ def find_doi_crossref(self, title):
+ """Find DOI using CrossRef API"""
+ if not title:
+ return None
+
+ url = "https://api.crossref.org/works"
+ params = {"query.title": title, "rows": 1}
+
try:
- handle = Entrez.efetch(db='pubmed', retmode='xml', id=ids)
- results = Entrez.read(handle)
-
- if "PubmedArticle" in results:
- all_results.extend(results["PubmedArticle"])
-
- # Add delay to be nice to API
- if i + batch_size < len(id_list):
- time.sleep(1)
-
+ response = requests.get(url, params=params, timeout=10)
+
+ if response.status_code == 200:
+ data = response.json()
+ items = data["message"]["items"]
+
+ if items:
+ first_item = items[0]
+ result_title = first_item["title"][0]
+
+ if self.clean_title(result_title) == self.clean_title(title):
+ return first_item["DOI"]
+
+ return None
except Exception as e:
- logger.error(f"Error fetching details for batch {i//batch_size + 1}: {e}")
- continue
-
- return all_results
-
- def clean_title(self, title):
- """Clean and normalize a title"""
- if not title:
- return ""
- title = title.lower()
- title = title.translate(str.maketrans('', '', string.punctuation))
- return title
\ No newline at end of file
+ logger.warning(f"Error finding DOI for title: {e}")
+ return None
+
+ def fetch_details_batch(self, id_list, batch_size=50):
+ """Fetch details for PubMed IDs in batches"""
+ if not id_list:
+ return []
+
+ Entrez.email = ENTREZ_EMAIL
+ Entrez.api_key = ENTREZ_API_KEY
+
+ all_results = []
+ print(id_list)
+ print(f"Number of IDs: {len(id_list)}")
+
+ # id_list = list(id_list)
+
+ for i in range(0, len(id_list), batch_size):
+ batch_ids = id_list[i : i + batch_size]
+ ids = ",".join(batch_ids)
+
+ try:
+ handle = Entrez.efetch(db="pubmed", retmode="xml", id=ids)
+ results = Entrez.read(handle)
+
+ if "PubmedArticle" in results:
+ all_results.extend(results["PubmedArticle"])
+
+ # Add delay to be nice to API
+ if i + batch_size < len(id_list):
+ time.sleep(1)
+
+ except Exception as e:
+ logger.error(
+ f"Error fetching details for batch {i//batch_size + 1}: {e}"
+ )
+ continue
+
+ return all_results
+
+ def clean_title(self, title):
+ """Clean and normalize a title"""
+ if not title:
+ return ""
+ title = title.lower()
+ title = title.translate(str.maketrans("", "", string.punctuation))
+ return title
diff --git a/backend/api/services/citation_search/scopus_citation_collection.py b/backend/api/services/citation_search/scopus_citation_collection.py
index a5fba9ec..c52f8151 100644
--- a/backend/api/services/citation_search/scopus_citation_collection.py
+++ b/backend/api/services/citation_search/scopus_citation_collection.py
@@ -1,5 +1,6 @@
import requests
+
class ScopusDataProcessor:
def __init__(self, api_key, base_url):
self._api_key = api_key
@@ -8,11 +9,13 @@ def __init__(self, api_key, base_url):
self.data = []
def fetch_data(self, start, search_term):
- params = { "query": search_term,
- "apiKey": self._api_key,
- "count": 25,
- "start": start,
- "view": "COMPLETE"}
+ params = {
+ "query": search_term,
+ "apiKey": self._api_key,
+ "count": 25,
+ "start": start,
+ "view": "COMPLETE",
+ }
response = requests.get(self._URL, params=params)
if response.status_code == 200:
return response.json()
@@ -23,7 +26,9 @@ def process_entry(self, entry):
pdf_link_call = self.get_open_access_link(entry.get("prism:doi"))
return {
"Refid": entry.get("dc:identifier"),
- "Author": ", ".join([author["authname"] for author in entry.get("author", [])]),
+ "Author": ", ".join(
+ [author["authname"] for author in entry.get("author", [])]
+ ),
"Title": entry.get("dc:title"),
"Abstract": entry.get("dc:description"),
"Accession Number": entry.get("eid"),
@@ -38,7 +43,11 @@ def process_entry(self, entry):
"ISSN": entry.get("prism:issn"),
"Issue": entry.get("prism:issueIdentifier"),
"Journal": entry.get("prism:publicationName"),
- "Keywords": entry.get("authkeywords").replace(" |", ",") if entry.get("authkeywords") != None else None,
+ "Keywords": (
+ entry.get("authkeywords").replace(" |", ",")
+ if entry.get("authkeywords") != None
+ else None
+ ),
"Language": None,
"Notes": None,
"Original Publication": None,
@@ -61,13 +70,15 @@ def process_entry(self, entry):
"Is this article primary research?": None,
"Is this article on the human population?": None,
"Is the main focus of this study about measles disease?": None,
- "Measles aka rubeola, Morbilli, red measles, English measles": None
+ "Measles aka rubeola, Morbilli, red measles, English measles": None,
}
def consume_api(self, search_term, delay=1):
res_data = self.fetch_data(0)
- total_results = int(res_data.get("search-results", {}).get("opensearch:totalResults"))
+ total_results = int(
+ res_data.get("search-results", {}).get("opensearch:totalResults")
+ )
for start in range(0, total_results, 25):
res_data = self.fetch_data(start, search_term)
entries = res_data.get("search-results", {}).get("entry", [])
@@ -83,7 +94,7 @@ def get_open_access_link(self, doi):
if response and response.status_code == 200:
data = response.json()
- link = data.get('url', None)
- if link and 'pdf' in link.lower():
+ link = data.get("url", None)
+ if link and "pdf" in link.lower():
return link
- return None
\ No newline at end of file
+ return None
diff --git a/backend/api/services/document_service.py b/backend/api/services/document_service.py
index 571ed429..fca2b7be 100644
--- a/backend/api/services/document_service.py
+++ b/backend/api/services/document_service.py
@@ -36,7 +36,9 @@ async def convert_document_to_markdown(
result["processor_used"] = "azure_doc_intelligence"
return result
- async def get_raw_analysis_result(self, conversion_id: str) -> Optional[Dict[str, Any]]:
+ async def get_raw_analysis_result(
+ self, conversion_id: str
+ ) -> Optional[Dict[str, Any]]:
if not azure_docint_client:
return None
return await azure_docint_client.get_raw_analysis_result(conversion_id)
diff --git a/backend/api/services/grobid_service.py b/backend/api/services/grobid_service.py
index 6bed76ff..94d46e06 100644
--- a/backend/api/services/grobid_service.py
+++ b/backend/api/services/grobid_service.py
@@ -16,7 +16,7 @@
"formula": "rgba(255, 165, 0, 1)", # Orange
"figure": "rgba(165, 42, 42, 1)", # Brown
"title": "rgba(255, 0, 0, 1)", # Red
- "affiliation": "rgba(255, 165, 0, 1)" # red-orengi
+ "affiliation": "rgba(255, 165, 0, 1)", # red-orengi
}
@@ -27,11 +27,13 @@ def get_color(name, param):
return color
+
def exclude_tags(tag):
ret = False
- ret |= tag.name != 'abstract' # exclude the abstract
+ ret |= tag.name != "abstract" # exclude the abstract
return ret
+
class GrobidService:
def __init__(self):
@@ -77,16 +79,20 @@ def is_available(self) -> bool:
async def process_structure(self, input_path) -> (dict, list):
if not self.grobid_client:
- raise RuntimeError("GROBID client is not available (service not configured or down)")
- pdf_file, status, text = self.grobid_client.process_pdf("processFulltextDocument",
- input_path,
- consolidate_header=True,
- consolidate_citations=False,
- segment_sentences=True,
- tei_coordinates=True,
- include_raw_citations=False,
- include_raw_affiliations=False,
- generateIDs=True)
+ raise RuntimeError(
+ "GROBID client is not available (service not configured or down)"
+ )
+ pdf_file, status, text = self.grobid_client.process_pdf(
+ "processFulltextDocument",
+ input_path,
+ consolidate_header=True,
+ consolidate_citations=False,
+ segment_sentences=True,
+ tei_coordinates=True,
+ include_raw_citations=False,
+ include_raw_affiliations=False,
+ generateIDs=True,
+ )
if status != 200:
return
@@ -99,23 +105,29 @@ async def process_structure(self, input_path) -> (dict, list):
@staticmethod
def box_to_dict(box, color=None, type=None, text=None):
- item = {"page": box[0], "x": box[1], "y": box[2], "width": box[3], "height": box[4]}
+ item = {
+ "page": box[0],
+ "x": box[1],
+ "y": box[2],
+ "width": box[3],
+ "height": box[4],
+ }
if color is not None:
- item['color'] = color
+ item["color"] = color
if type:
- item['type'] = type
+ item["type"] = type
if text:
- item['text'] = text
+ item["text"] = text
return item
async def get_coordinates(self, text):
- soup = BeautifulSoup(text, 'xml')
+ soup = BeautifulSoup(text, "xml")
# exclude certain tag names
- all_blocks_with_coordinates = soup.find('text').find_all(coords=True)
+ all_blocks_with_coordinates = soup.find("text").find_all(coords=True)
# all_blocks_with_coordinates = soup.find_all()
# if use_sentences:
@@ -124,27 +136,35 @@ async def get_coordinates(self, text):
coordinates = []
count = 0
for block_id, block in enumerate(all_blocks_with_coordinates):
- for box in filter(lambda c: len(c) > 0 and c[0] != "", block['coords'].split(";")):
+ for box in filter(
+ lambda c: len(c) > 0 and c[0] != "", block["coords"].split(";")
+ ):
coordinates.append(
self.box_to_dict(
box.split(","),
get_color(block.name, count % 2 == 0),
type=block.name,
- text=block.text
+ text=block.text,
),
)
count += 1
return coordinates
async def get_pages(self, text):
- soup = BeautifulSoup(text, 'xml')
+ soup = BeautifulSoup(text, "xml")
pages_infos = soup.find_all("surface")
- pages = [{'width': float(page['lrx']) - float(page['ulx']), 'height': float(page['lry']) - float(page['uly'])}
- for page in pages_infos]
+ pages = [
+ {
+ "width": float(page["lrx"]) - float(page["ulx"]),
+ "height": float(page["lry"]) - float(page["uly"]),
+ }
+ for page in pages_infos
+ ]
return pages
+
# Global instance
try:
grobid_service = GrobidService()
diff --git a/backend/api/services/postgres_auth.py b/backend/api/services/postgres_auth.py
index e053339c..4c38379f 100644
--- a/backend/api/services/postgres_auth.py
+++ b/backend/api/services/postgres_auth.py
@@ -71,8 +71,14 @@ def conn(self):
# psycopg2.errors.InFailedSqlTransaction
# Heal it proactively so one bad request cannot poison the process.
try:
- if self._conn and self._conn.get_transaction_status() == psycopg2.extensions.TRANSACTION_STATUS_INERROR:
- logger.warning("Postgres connection in aborted transaction; rolling back")
+ if (
+ self._conn
+ and self._conn.get_transaction_status()
+ == psycopg2.extensions.TRANSACTION_STATUS_INERROR
+ ):
+ logger.warning(
+ "Postgres connection in aborted transaction; rolling back"
+ )
self._conn.rollback()
except Exception:
# If rollback fails, reconnect.
@@ -128,7 +134,9 @@ def _refresh_azure_token(self) -> str:
)
token = self._credential.get_token(self._AZURE_POSTGRES_SCOPE)
self._token = token.token
- self._token_expiration = token.expires_on - self._TOKEN_REFRESH_BUFFER_SECONDS
+ self._token_expiration = (
+ token.expires_on - self._TOKEN_REFRESH_BUFFER_SECONDS
+ )
return self._token
@staticmethod
@@ -167,7 +175,12 @@ def _candidate_kwargs(self, mode: str, psycopg3: bool = False) -> Dict[str, Any]
kwargs["password"] = prof.get("password")
# Sanity checks
- required = [kwargs.get("host"), kwargs.get("database"), kwargs.get("user"), kwargs.get("port")]
+ required = [
+ kwargs.get("host"),
+ kwargs.get("database"),
+ kwargs.get("user"),
+ kwargs.get("port"),
+ ]
if not all(required):
raise RuntimeError(f"Incomplete Postgres config for mode={mode}")
@@ -188,7 +201,9 @@ def _connect(self):
try:
return self._connect_with_mode(primary_mode)
except Exception as e:
- logger.error("Postgres connect failed (mode=%s): %s", primary_mode, e, exc_info=True)
+ logger.error(
+ "Postgres connect failed (mode=%s): %s", primary_mode, e, exc_info=True
+ )
raise psycopg2.OperationalError(
f"Could not connect to Postgres for mode={primary_mode}"
)
@@ -200,13 +215,14 @@ async def aconn(self) -> AsyncIterator[psycopg.AsyncConnection]:
Commits automatically on clean exit, rolls back on exception.
"""
kwargs = self._candidate_kwargs(self._mode(), psycopg3=True)
- async with await psycopg.AsyncConnection.connect(**kwargs, row_factory=dict_row) as conn:
+ async with await psycopg.AsyncConnection.connect(
+ **kwargs, row_factory=dict_row
+ ) as conn:
yield conn
def __repr__(self) -> str:
status = "open" if self._conn and not self._conn.closed else "closed"
- return (
- f""
- )
+ return f""
+
-postgres_server = PostgresServer()
\ No newline at end of file
+postgres_server = PostgresServer()
diff --git a/backend/api/services/sr_db_service.py b/backend/api/services/sr_db_service.py
index f936ee2a..1a4d46dc 100644
--- a/backend/api/services/sr_db_service.py
+++ b/backend/api/services/sr_db_service.py
@@ -8,6 +8,7 @@
All blocking DB operations are synchronous and intended to be run with
`fastapi.concurrency.run_in_threadpool` when called from async routes.
"""
+
from typing import Any, Dict, Optional, List
import logging
import uuid
@@ -37,7 +38,7 @@ def ensure_table_exists(self) -> None:
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
create_table_sql = """
CREATE TABLE IF NOT EXISTS systematic_reviews (
id TEXT PRIMARY KEY,
@@ -57,7 +58,7 @@ def ensure_table_exists(self) -> None:
"""
cur.execute(create_table_sql)
conn.commit()
-
+
logger.info("Ensured systematic_reviews table exists")
except Exception as e:
try:
@@ -71,7 +72,9 @@ def ensure_table_exists(self) -> None:
if conn:
pass
- def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[str, Any]:
+ def build_criteria_parsed(
+ self, criteria_obj: Optional[Dict[str, Any]]
+ ) -> Dict[str, Any]:
"""
Port of the _build_criteria_parsed helper - returns a mapping containing
'l1', 'l2', 'parameters' structured metadata for the UI and screening logic.
@@ -108,7 +111,13 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[
else:
possible.append([])
addinfos.append("")
- parsed["l1"].update({"questions": qlist, "possible_answers": possible, "additional_infos": addinfos})
+ parsed["l1"].update(
+ {
+ "questions": qlist,
+ "possible_answers": possible,
+ "additional_infos": addinfos,
+ }
+ )
# l2 criteria (fulltext)
l2 = criteria_obj.get("l2_criteria", {})
@@ -133,7 +142,13 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[
else:
l2_possible.append([])
l2_addinfos.append("")
- parsed["l2"].update({"questions": l2_q, "possible_answers": l2_possible, "additional_infos": l2_addinfos})
+ parsed["l2"].update(
+ {
+ "questions": l2_q,
+ "possible_answers": l2_possible,
+ "additional_infos": l2_addinfos,
+ }
+ )
# parameters
params = criteria_obj.get("parameters", {})
@@ -147,14 +162,22 @@ def build_criteria_parsed(self, criteria_obj: Optional[Dict[str, Any]]) -> Dict[
possible_params.append([k for k in param_map.keys()])
descriptions.append(
[
- "Parameter {key} are described as {info}.".format(key=k, info=v)
+ "Parameter {key} are described as {info}.".format(
+ key=k, info=v
+ )
for k, v in param_map.items()
]
)
else:
possible_params.append([])
descriptions.append([])
- parsed["parameters"].update({"categories": categories, "possible_parameters": possible_params, "descriptions": descriptions})
+ parsed["parameters"].update(
+ {
+ "categories": categories,
+ "possible_parameters": possible_params,
+ "descriptions": descriptions,
+ }
+ )
return parsed
@@ -182,55 +205,59 @@ def create_systematic_review(
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
insert_sql = """
INSERT INTO systematic_reviews
(id, name, description, owner_id, owner_email, users, visible,
criteria, criteria_yaml, criteria_parsed, created_at, updated_at)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
"""
-
- cur.execute(insert_sql, (
- sr_id,
- name.strip(),
- description,
- owner_id,
- owner_email,
- json.dumps(users_list),
- True,
- json.dumps(criteria_obj) if criteria_obj else None,
- criteria_str,
- json.dumps(criteria_parsed),
- now,
- now
- ))
-
+
+ cur.execute(
+ insert_sql,
+ (
+ sr_id,
+ name.strip(),
+ description,
+ owner_id,
+ owner_email,
+ json.dumps(users_list),
+ True,
+ json.dumps(criteria_obj) if criteria_obj else None,
+ criteria_str,
+ json.dumps(criteria_parsed),
+ now,
+ now,
+ ),
+ )
+
conn.commit()
-
+
# Fetch the created record
cur.execute("SELECT * FROM systematic_reviews WHERE id = %s", (sr_id,))
row = cur.fetchone()
cols = [desc[0] for desc in cur.description]
sr_doc = {cols[i]: row[i] for i in range(len(cols))}
-
+
# Parse JSON fields and convert timestamps
- if sr_doc.get('users') and isinstance(sr_doc['users'], str):
- sr_doc['users'] = json.loads(sr_doc['users'])
- if sr_doc.get('criteria') and isinstance(sr_doc['criteria'], str):
- sr_doc['criteria'] = json.loads(sr_doc['criteria'])
- if sr_doc.get('criteria_parsed') and isinstance(sr_doc['criteria_parsed'], str):
- sr_doc['criteria_parsed'] = json.loads(sr_doc['criteria_parsed'])
+ if sr_doc.get("users") and isinstance(sr_doc["users"], str):
+ sr_doc["users"] = json.loads(sr_doc["users"])
+ if sr_doc.get("criteria") and isinstance(sr_doc["criteria"], str):
+ sr_doc["criteria"] = json.loads(sr_doc["criteria"])
+ if sr_doc.get("criteria_parsed") and isinstance(
+ sr_doc["criteria_parsed"], str
+ ):
+ sr_doc["criteria_parsed"] = json.loads(sr_doc["criteria_parsed"])
# Convert datetime objects to ISO strings
from datetime import datetime as dt
- if sr_doc.get('created_at') and isinstance(sr_doc['created_at'], dt):
- sr_doc['created_at'] = sr_doc['created_at'].isoformat()
- if sr_doc.get('updated_at') and isinstance(sr_doc['updated_at'], dt):
- sr_doc['updated_at'] = sr_doc['updated_at'].isoformat()
-
-
+ if sr_doc.get("created_at") and isinstance(sr_doc["created_at"], dt):
+ sr_doc["created_at"] = sr_doc["created_at"].isoformat()
+ if sr_doc.get("updated_at") and isinstance(sr_doc["updated_at"], dt):
+ sr_doc["updated_at"] = sr_doc["updated_at"].isoformat()
+
return sr_doc
-
+
except Exception as e:
try:
if conn:
@@ -238,60 +265,75 @@ def create_systematic_review(
except Exception:
pass
logger.exception(f"Failed to insert SR document: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to create systematic review: {e}",
+ )
finally:
if conn:
pass
- def add_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]:
+ def add_user(
+ self, sr_id: str, target_user_id: str, requester_id: str
+ ) -> Dict[str, Any]:
"""
Add a user id to the SR's users list. Enforces that the SR exists and is visible;
requester must be a member or owner.
Returns a dict with update result metadata.
"""
-
sr = self.get_systematic_review(sr_id)
if not sr or not sr.get("visible", True):
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found",
+ )
# Check permission
has_perm = self.user_has_sr_permission(sr_id, requester_id)
if not has_perm:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Not authorized to modify this systematic review",
+ )
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
# Get current users array
cur.execute("SELECT users FROM systematic_reviews WHERE id = %s", (sr_id,))
row = cur.fetchone()
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found",
+ )
+
users = row[0] if row[0] else []
if isinstance(users, str):
users = json.loads(users)
-
+
# Add user if not already present
if target_user_id not in users:
users.append(target_user_id)
-
+
# Update
now = datetime.utcnow().isoformat()
cur.execute(
"UPDATE systematic_reviews SET users = %s, updated_at = %s WHERE id = %s",
- (json.dumps(users), now, sr_id)
+ (json.dumps(users), now, sr_id),
)
modified_count = cur.rowcount
conn.commit()
-
-
- return {"matched_count": 1, "modified_count": modified_count, "added_user_id": target_user_id}
-
+ return {
+ "matched_count": 1,
+ "modified_count": modified_count,
+ "added_user_id": target_user_id,
+ }
+
except HTTPException:
try:
if conn:
@@ -306,62 +348,80 @@ def add_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[s
except Exception:
pass
logger.exception(f"Failed to add user to SR: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to add user: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to add user: {e}",
+ )
finally:
if conn:
pass
- def remove_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dict[str, Any]:
+ def remove_user(
+ self, sr_id: str, target_user_id: str, requester_id: str
+ ) -> Dict[str, Any]:
"""
Remove a user id from the SR's users list. Owner cannot be removed.
Enforces requester permissions (must be a member or owner).
"""
-
sr = self.get_systematic_review(sr_id)
if not sr or not sr.get("visible", True):
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found",
+ )
# Check permission
has_perm = self.user_has_sr_permission(sr_id, requester_id)
if not has_perm:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Not authorized to modify this systematic review",
+ )
if target_user_id == sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot remove the owner from the systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Cannot remove the owner from the systematic review",
+ )
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
# Get current users array
cur.execute("SELECT users FROM systematic_reviews WHERE id = %s", (sr_id,))
row = cur.fetchone()
if not row:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found",
+ )
+
users = row[0] if row[0] else []
if isinstance(users, str):
users = json.loads(users)
-
+
# Remove user if present
if target_user_id in users:
users.remove(target_user_id)
-
+
# Update
now = datetime.utcnow().isoformat()
cur.execute(
"UPDATE systematic_reviews SET users = %s, updated_at = %s WHERE id = %s",
- (json.dumps(users), now, sr_id)
+ (json.dumps(users), now, sr_id),
)
modified_count = cur.rowcount
conn.commit()
-
-
- return {"matched_count": 1, "modified_count": modified_count, "removed_user_id": target_user_id}
-
+ return {
+ "matched_count": 1,
+ "modified_count": modified_count,
+ "removed_user_id": target_user_id,
+ }
+
except HTTPException:
try:
if conn:
@@ -376,7 +436,10 @@ def remove_user(self, sr_id: str, target_user_id: str, requester_id: str) -> Dic
except Exception:
pass
logger.exception(f"Failed to remove user from SR: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to remove user: {e}",
+ )
finally:
if conn:
pass
@@ -388,7 +451,6 @@ def user_has_sr_permission(self, sr_id: str, user_id: str) -> bool:
Note: this check deliberately ignores the SR's 'visible' flag so membership checks
work regardless of whether the SR is hidden/soft-deleted.
"""
-
doc = self.get_systematic_review(sr_id, ignore_visibility=True)
if not doc:
@@ -399,56 +461,71 @@ def user_has_sr_permission(self, sr_id: str, user_id: str) -> bool:
return True
return False
- def update_criteria(self, sr_id: str, criteria_obj: Dict[str, Any], criteria_str: str, requester_id: str) -> Dict[str, Any]:
+ def update_criteria(
+ self,
+ sr_id: str,
+ criteria_obj: Dict[str, Any],
+ criteria_str: str,
+ requester_id: str,
+ ) -> Dict[str, Any]:
"""
Update the criteria fields (criteria, criteria_yaml, criteria_parsed, updated_at).
The requester must be a member or owner.
Returns the updated SR document.
"""
-
sr = self.get_systematic_review(sr_id)
if not sr or not sr.get("visible", True):
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found",
+ )
# Check permission
has_perm = self.user_has_sr_permission(sr_id, requester_id)
if not has_perm:
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Not authorized to modify this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Not authorized to modify this systematic review",
+ )
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
updated_at = datetime.utcnow().isoformat()
criteria_parsed = self.build_criteria_parsed(criteria_obj)
-
+
update_sql = """
UPDATE systematic_reviews
SET criteria = %s, criteria_yaml = %s, criteria_parsed = %s, updated_at = %s
WHERE id = %s
"""
-
- cur.execute(update_sql, (
- json.dumps(criteria_obj),
- criteria_str,
- json.dumps(criteria_parsed),
- updated_at,
- sr_id
- ))
-
+
+ cur.execute(
+ update_sql,
+ (
+ json.dumps(criteria_obj),
+ criteria_str,
+ json.dumps(criteria_parsed),
+ updated_at,
+ sr_id,
+ ),
+ )
+
if cur.rowcount == 0:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found during update")
-
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found during update",
+ )
+
conn.commit()
-
-
# Return fresh doc
doc = self.get_systematic_review(sr_id)
return doc
-
+
except HTTPException:
try:
if conn:
@@ -463,7 +540,10 @@ def update_criteria(self, sr_id: str, criteria_obj: Dict[str, Any], criteria_str
except Exception:
pass
logger.exception(f"Failed to update SR criteria: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update criteria: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update criteria: {e}",
+ )
finally:
if conn:
pass
@@ -472,48 +552,48 @@ def list_systematic_reviews_for_user(self, user_email: str) -> List[Dict[str, An
"""
Return all SR documents where the user is a member (regardless of visible flag).
"""
-
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
# Query using jsonb operator to check if user_email is in users array
query = """
SELECT * FROM systematic_reviews
WHERE users @> %s::jsonb
ORDER BY created_at DESC
"""
-
+
cur.execute(query, (json.dumps([user_email]),))
rows = cur.fetchall()
cols = [desc[0] for desc in cur.description]
-
+
results = []
for row in rows:
doc = {cols[i]: row[i] for i in range(len(cols))}
-
+
# Parse JSON fields and convert timestamps
- if doc.get('users') and isinstance(doc['users'], str):
- doc['users'] = json.loads(doc['users'])
- if doc.get('criteria') and isinstance(doc['criteria'], str):
- doc['criteria'] = json.loads(doc['criteria'])
- if doc.get('criteria_parsed') and isinstance(doc['criteria_parsed'], str):
- doc['criteria_parsed'] = json.loads(doc['criteria_parsed'])
+ if doc.get("users") and isinstance(doc["users"], str):
+ doc["users"] = json.loads(doc["users"])
+ if doc.get("criteria") and isinstance(doc["criteria"], str):
+ doc["criteria"] = json.loads(doc["criteria"])
+ if doc.get("criteria_parsed") and isinstance(
+ doc["criteria_parsed"], str
+ ):
+ doc["criteria_parsed"] = json.loads(doc["criteria_parsed"])
# Convert datetime objects to ISO strings
from datetime import datetime as dt
- if doc.get('created_at') and isinstance(doc['created_at'], dt):
- doc['created_at'] = doc['created_at'].isoformat()
- if doc.get('updated_at') and isinstance(doc['updated_at'], dt):
- doc['updated_at'] = doc['updated_at'].isoformat()
-
+
+ if doc.get("created_at") and isinstance(doc["created_at"], dt):
+ doc["created_at"] = doc["created_at"].isoformat()
+ if doc.get("updated_at") and isinstance(doc["updated_at"], dt):
+ doc["updated_at"] = doc["updated_at"].isoformat()
+
results.append(doc)
-
-
return results
-
+
except Exception as e:
try:
if conn:
@@ -521,55 +601,60 @@ def list_systematic_reviews_for_user(self, user_email: str) -> List[Dict[str, An
except Exception:
pass
logger.exception(f"Failed to list SRs for user: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to list systematic reviews: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to list systematic reviews: {e}",
+ )
finally:
if conn:
pass
- def get_systematic_review(self, sr_id: str, ignore_visibility: bool = False) -> Optional[Dict[str, Any]]:
+ def get_systematic_review(
+ self, sr_id: str, ignore_visibility: bool = False
+ ) -> Optional[Dict[str, Any]]:
"""
Return SR document by id. Returns None if not found.
If ignore_visibility is False, only returns visible SRs.
"""
-
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
if ignore_visibility:
query = "SELECT * FROM systematic_reviews WHERE id = %s"
else:
- query = "SELECT * FROM systematic_reviews WHERE id = %s AND visible = TRUE"
-
+ query = (
+ "SELECT * FROM systematic_reviews WHERE id = %s AND visible = TRUE"
+ )
+
cur.execute(query, (sr_id,))
row = cur.fetchone()
-
+
if not row:
return None
-
+
cols = [desc[0] for desc in cur.description]
doc = {cols[i]: row[i] for i in range(len(cols))}
-
+
# Parse JSON fields and convert timestamps
- if doc.get('users') and isinstance(doc['users'], str):
- doc['users'] = json.loads(doc['users'])
- if doc.get('criteria') and isinstance(doc['criteria'], str):
- doc['criteria'] = json.loads(doc['criteria'])
- if doc.get('criteria_parsed') and isinstance(doc['criteria_parsed'], str):
- doc['criteria_parsed'] = json.loads(doc['criteria_parsed'])
+ if doc.get("users") and isinstance(doc["users"], str):
+ doc["users"] = json.loads(doc["users"])
+ if doc.get("criteria") and isinstance(doc["criteria"], str):
+ doc["criteria"] = json.loads(doc["criteria"])
+ if doc.get("criteria_parsed") and isinstance(doc["criteria_parsed"], str):
+ doc["criteria_parsed"] = json.loads(doc["criteria_parsed"])
# Convert datetime objects to ISO strings
from datetime import datetime as dt
- if doc.get('created_at') and isinstance(doc['created_at'], dt):
- doc['created_at'] = doc['created_at'].isoformat()
- if doc.get('updated_at') and isinstance(doc['updated_at'], dt):
- doc['updated_at'] = doc['updated_at'].isoformat()
-
-
+ if doc.get("created_at") and isinstance(doc["created_at"], dt):
+ doc["created_at"] = doc["created_at"].isoformat()
+ if doc.get("updated_at") and isinstance(doc["updated_at"], dt):
+ doc["updated_at"] = doc["updated_at"].isoformat()
+
return doc
-
+
except Exception as e:
# IMPORTANT: do not swallow DB errors as "not found".
# Roll back so this connection isn't poisoned for subsequent requests.
@@ -584,37 +669,46 @@ def get_systematic_review(self, sr_id: str, ignore_visibility: bool = False) ->
if conn:
pass
- def set_visibility(self, sr_id: str, visible: bool, requester_id: str) -> Dict[str, Any]:
+ def set_visibility(
+ self, sr_id: str, visible: bool, requester_id: str
+ ) -> Dict[str, Any]:
"""
Set the visible flag on the SR. Only owner is allowed to change visibility.
Returns update metadata.
"""
-
sr = self.get_systematic_review(sr_id, ignore_visibility=True)
if not sr:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found",
+ )
if requester_id != sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may change visibility of this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only the owner may change visibility of this systematic review",
+ )
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
updated_at = datetime.utcnow().isoformat()
cur.execute(
"UPDATE systematic_reviews SET visible = %s, updated_at = %s WHERE id = %s",
- (bool(visible), updated_at, sr_id)
+ (bool(visible), updated_at, sr_id),
)
modified_count = cur.rowcount
conn.commit()
-
-
- return {"matched_count": 1, "modified_count": modified_count, "visible": visible}
-
+ return {
+ "matched_count": 1,
+ "modified_count": modified_count,
+ "visible": visible,
+ }
+
except Exception as e:
try:
if conn:
@@ -622,48 +716,62 @@ def set_visibility(self, sr_id: str, visible: bool, requester_id: str) -> Dict[s
except Exception:
pass
logger.exception(f"Failed to set visibility on SR: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to set visibility: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to set visibility: {e}",
+ )
finally:
if conn:
- pass
+ pass
- def soft_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]:
+ def soft_delete_systematic_review(
+ self, sr_id: str, requester_id: str
+ ) -> Dict[str, Any]:
"""
Soft-delete (set visible=False). Only owner may delete.
"""
return self.set_visibility(sr_id, False, requester_id)
- def undelete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]:
+ def undelete_systematic_review(
+ self, sr_id: str, requester_id: str
+ ) -> Dict[str, Any]:
"""
Undelete (set visible=True). Only owner may undelete.
"""
return self.set_visibility(sr_id, True, requester_id)
- def hard_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[str, Any]:
+ def hard_delete_systematic_review(
+ self, sr_id: str, requester_id: str
+ ) -> Dict[str, Any]:
"""
Permanently remove the SR document. Only owner may hard delete.
Returns deletion metadata (deleted_count).
"""
-
sr = self.get_systematic_review(sr_id, ignore_visibility=True)
if not sr:
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found",
+ )
if requester_id != sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may hard-delete this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only the owner may hard-delete this systematic review",
+ )
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
cur.execute("DELETE FROM systematic_reviews WHERE id = %s", (sr_id,))
deleted_count = cur.rowcount
conn.commit()
return {"deleted_count": deleted_count}
-
+
except Exception as e:
try:
if conn:
@@ -671,32 +779,33 @@ def hard_delete_systematic_review(self, sr_id: str, requester_id: str) -> Dict[s
except Exception:
pass
logger.exception(f"Failed to hard-delete SR: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to hard-delete systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to hard-delete systematic review: {e}",
+ )
finally:
if conn:
pass
-
- def update_screening_db_info(self, sr_id: str, screening_db: Dict[str, Any]) -> None:
+ def update_screening_db_info(
+ self, sr_id: str, screening_db: Dict[str, Any]
+ ) -> None:
"""
Update the screening_db field in the SR document with screening database metadata.
"""
-
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
updated_at = datetime.utcnow().isoformat()
cur.execute(
"UPDATE systematic_reviews SET screening_db = %s, updated_at = %s WHERE id = %s",
- (json.dumps(screening_db), updated_at, sr_id)
+ (json.dumps(screening_db), updated_at, sr_id),
)
conn.commit()
-
-
except Exception as e:
try:
if conn:
@@ -704,7 +813,10 @@ def update_screening_db_info(self, sr_id: str, screening_db: Dict[str, Any]) ->
except Exception:
pass
logger.exception(f"Failed to update screening DB info: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update screening DB info: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update screening DB info: {e}",
+ )
finally:
if conn:
pass
@@ -713,22 +825,19 @@ def clear_screening_db_info(self, sr_id: str) -> None:
"""
Remove the screening_db field from the SR document.
"""
-
conn = None
try:
conn = postgres_server.conn
cur = conn.cursor()
-
+
updated_at = datetime.utcnow().isoformat()
cur.execute(
"UPDATE systematic_reviews SET screening_db = NULL, updated_at = %s WHERE id = %s",
- (updated_at, sr_id)
+ (updated_at, sr_id),
)
conn.commit()
-
-
except Exception as e:
try:
if conn:
@@ -736,7 +845,10 @@ def clear_screening_db_info(self, sr_id: str) -> None:
except Exception:
pass
logger.exception(f"Failed to clear screening DB info: {e}")
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to clear screening DB info: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to clear screening DB info: {e}",
+ )
finally:
if conn:
pass
diff --git a/backend/api/services/storage.py b/backend/api/services/storage.py
index 3ab9b25b..52c69441 100644
--- a/backend/api/services/storage.py
+++ b/backend/api/services/storage.py
@@ -23,7 +23,11 @@
try:
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential
- from azure.storage.blob import BlobSasPermissions, BlobServiceClient, generate_blob_sas
+ from azure.storage.blob import (
+ BlobSasPermissions,
+ BlobServiceClient,
+ generate_blob_sas,
+ )
except Exception: # pragma: no cover
# Allow local-storage deployments/environments to import without azure packages.
ResourceNotFoundError = Exception # type: ignore
@@ -44,16 +48,28 @@ class StorageService(Protocol):
container_name: str
async def create_user_directory(self, user_id: str) -> bool: ...
- async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool: ...
+ async def save_user_profile(
+ self, user_id: str, profile_data: Dict[str, Any]
+ ) -> bool: ...
async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]: ...
- async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]: ...
- async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]: ...
+ async def upload_user_document(
+ self, user_id: str, filename: str, file_content: bytes
+ ) -> Optional[str]: ...
+ async def get_user_document(
+ self, user_id: str, doc_id: str, filename: str
+ ) -> Optional[bytes]: ...
async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]: ...
- async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool: ...
- async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool: ...
+ async def delete_user_document(
+ self, user_id: str, doc_id: str, filename: str
+ ) -> bool: ...
+ async def put_bytes_by_path(
+ self, path: str, content: bytes, content_type: str = "application/octet-stream"
+ ) -> bool: ...
async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]: ...
async def delete_by_path(self, path: str) -> bool: ...
- async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optional[str]: ...
+ async def generate_signed_url(
+ self, path: str, expiry_minutes: int = 5
+ ) -> Optional[str]: ...
# =============================================================================
@@ -64,36 +80,50 @@ async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optio
class AzureStorageService:
"""Service for managing user data in Azure Blob Storage."""
- def __init__(self, *, account_url: str | None = None, connection_string: str | None = None, container_name: str):
+ def __init__(
+ self,
+ *,
+ account_url: str | None = None,
+ connection_string: str | None = None,
+ container_name: str,
+ ):
if not BlobServiceClient:
raise RuntimeError(
"Azure storage libraries are not installed. Install azure-identity and azure-storage-blob, or use STORAGE_MODE=local."
)
if bool(account_url) == bool(connection_string):
- raise ValueError("Exactly one of account_url or connection_string must be provided")
+ raise ValueError(
+ "Exactly one of account_url or connection_string must be provided"
+ )
self._account_key: str | None = None
self._credential: Any = None
if connection_string:
- self.blob_service_client = BlobServiceClient.from_connection_string(connection_string)
- self._account_key = self._get_account_key_from_connection_str(connection_string)
+ self.blob_service_client = BlobServiceClient.from_connection_string(
+ connection_string
+ )
+ self._account_key = self._get_account_key_from_connection_str(
+ connection_string
+ )
else:
if not DefaultAzureCredential:
raise RuntimeError(
"azure-identity is not installed. Install azure-identity, or use STORAGE_MODE=azure (connection string) or local."
)
self._credential = DefaultAzureCredential()
- self.blob_service_client = BlobServiceClient(account_url=account_url, credential=self._credential)
+ self.blob_service_client = BlobServiceClient(
+ account_url=account_url, credential=self._credential
+ )
self.container_name = container_name
self._ensure_container_exists()
-
+
def _get_account_key_from_connection_str(self, connection_str):
for part in connection_str.split(";"):
if part.startswith("AccountKey="):
- return part[len("AccountKey="):]
+ return part[len("AccountKey=") :]
return None
def _ensure_container_exists(self):
@@ -116,17 +146,23 @@ async def create_user_directory(self, user_id: str) -> bool:
# Create placeholder file to establish directory structure
blob_name = f"users/{user_id}/documents/.placeholder"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
blob_client.upload_blob(b"", overwrite=True)
return True
except Exception as e:
logger.error("Error creating user directory for %s: %s", user_id, e)
return False
- async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool:
+ async def save_user_profile(
+ self, user_id: str, profile_data: Dict[str, Any]
+ ) -> bool:
try:
blob_name = f"users/{user_id}/profile.json"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
blob_client.upload_blob(json.dumps(profile_data, indent=2), overwrite=True)
return True
except Exception as e:
@@ -136,7 +172,9 @@ async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) ->
async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
try:
blob_name = f"users/{user_id}/profile.json"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
blob_data = blob_client.download_blob().readall()
return json.loads(blob_data.decode("utf-8"))
except ResourceNotFoundError:
@@ -145,11 +183,15 @@ async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
logger.error("Error getting user profile for %s: %s", user_id, e)
return None
- async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]:
+ async def upload_user_document(
+ self, user_id: str, filename: str, file_content: bytes
+ ) -> Optional[str]:
try:
doc_id = str(uuid.uuid4())
blob_name = f"users/{user_id}/documents/{doc_id}_{filename}"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
blob_client.upload_blob(file_content, overwrite=True)
file_metadata = create_file_metadata(
@@ -166,30 +208,42 @@ async def upload_user_document(self, user_id: str, filename: str, file_content:
profile = await self.get_user_profile(user_id)
if profile:
profile["document_count"] = int(profile.get("document_count", 0)) + 1
- profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content)
+ profile["storage_used"] = int(profile.get("storage_used", 0)) + len(
+ file_content
+ )
profile["last_updated"] = datetime.now(timezone.utc).isoformat()
await self.save_user_profile(user_id, profile)
return doc_id
except Exception as e:
- logger.error("Error uploading document %s for user %s: %s", filename, user_id, e)
+ logger.error(
+ "Error uploading document %s for user %s: %s", filename, user_id, e
+ )
return None
- async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]:
+ async def get_user_document(
+ self, user_id: str, doc_id: str, filename: str
+ ) -> Optional[bytes]:
try:
blob_name = f"users/{user_id}/documents/{doc_id}_{filename}"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
return blob_client.download_blob().readall()
except ResourceNotFoundError:
return None
except Exception as e:
- logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e)
+ logger.error(
+ "Error getting document %s for user %s: %s", doc_id, user_id, e
+ )
return None
async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]:
try:
prefix = f"users/{user_id}/documents/"
- blobs = self.blob_service_client.get_container_client(self.container_name).list_blobs(name_starts_with=prefix)
+ blobs = self.blob_service_client.get_container_client(
+ self.container_name
+ ).list_blobs(name_starts_with=prefix)
documents: List[Dict[str, Any]] = []
for blob in blobs:
@@ -218,10 +272,14 @@ async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]:
logger.error("Error listing user documents for %s: %s", user_id, e)
return []
- async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool:
+ async def delete_user_document(
+ self, user_id: str, doc_id: str, filename: str
+ ) -> bool:
try:
doc_blob_name = f"users/{user_id}/documents/{doc_id}_{filename}"
- doc_blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=doc_blob_name)
+ doc_blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=doc_blob_name
+ )
try:
doc_size = doc_blob_client.get_blob_properties().size
@@ -233,30 +291,44 @@ async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -
profile = await self.get_user_profile(user_id)
if profile:
- profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1)
- profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size))
+ profile["document_count"] = max(
+ 0, int(profile.get("document_count", 0)) - 1
+ )
+ profile["storage_used"] = max(
+ 0, int(profile.get("storage_used", 0)) - int(doc_size)
+ )
profile["last_updated"] = datetime.now(timezone.utc).isoformat()
await self.save_user_profile(user_id, profile)
return True
except Exception as e:
- logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e)
+ logger.error(
+ "Error deleting document %s for user %s: %s", doc_id, user_id, e
+ )
return False
- async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool:
+ async def save_file_hash_metadata(
+ self, user_id: str, document_id: str, file_metadata: Dict[str, Any]
+ ) -> bool:
try:
blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
blob_client.upload_blob(json.dumps(file_metadata, indent=2), overwrite=True)
return True
except Exception as e:
logger.error("Error saving file metadata for %s: %s", document_id, e)
return False
- async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]:
+ async def get_file_hash_metadata(
+ self, user_id: str, document_id: str
+ ) -> Optional[Dict[str, Any]]:
try:
blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
metadata_json = blob_client.download_blob().readall().decode("utf-8")
return json.loads(metadata_json)
except ResourceNotFoundError:
@@ -268,7 +340,9 @@ async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Option
async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> bool:
try:
blob_name = f"users/{user_id}/metadata/{document_id}_metadata.json"
- blob_client = self.blob_service_client.get_blob_client(container=self.container_name, blob=blob_name)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=self.container_name, blob=blob_name
+ )
blob_client.delete_blob()
return True
except ResourceNotFoundError:
@@ -277,12 +351,16 @@ async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> boo
logger.error("Error deleting file metadata for %s: %s", document_id, e)
return False
- async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool:
+ async def put_bytes_by_path(
+ self, path: str, content: bytes, content_type: str = "application/octet-stream"
+ ) -> bool:
"""Write blob by storage path 'container/blob'."""
if not path or "/" not in path:
raise ValueError("Invalid storage path")
container, blob = path.split("/", 1)
- blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=container, blob=blob
+ )
blob_client.upload_blob(content, overwrite=True, content_type=content_type)
return True
@@ -292,7 +370,9 @@ async def get_bytes_by_path(self, path: str) -> Tuple[bytes, str]:
raise ValueError("Invalid storage path")
container, blob = path.split("/", 1)
- blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=container, blob=blob
+ )
content = blob_client.download_blob().readall()
filename = os.path.basename(blob) or "download"
return content, filename
@@ -302,11 +382,15 @@ async def delete_by_path(self, path: str) -> bool:
if not path or "/" not in path:
raise ValueError("Invalid storage path")
container, blob = path.split("/", 1)
- blob_client = self.blob_service_client.get_blob_client(container=container, blob=blob)
+ blob_client = self.blob_service_client.get_blob_client(
+ container=container, blob=blob
+ )
blob_client.delete_blob()
return True
- async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optional[str]:
+ async def generate_signed_url(
+ self, path: str, expiry_minutes: int = 5
+ ) -> Optional[str]:
"""Generate a read-only SAS URL for a blob. Path format: 'container/blob'."""
if not path or "/" not in path:
raise ValueError("Invalid storage path")
@@ -323,13 +407,17 @@ async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optio
"expiry": expiry,
}
if self._account_key:
- sas_token = generate_blob_sas(**blob_sas_kwargs, account_key=self._account_key)
+ sas_token = generate_blob_sas(
+ **blob_sas_kwargs, account_key=self._account_key
+ )
elif self._credential:
delegation_key = self.blob_service_client.get_user_delegation_key(
key_start_time=datetime.now(timezone.utc) - timedelta(minutes=1),
key_expiry_time=expiry,
)
- sas_token = generate_blob_sas(**blob_sas_kwargs, user_delegation_key=delegation_key)
+ sas_token = generate_blob_sas(
+ **blob_sas_kwargs, user_delegation_key=delegation_key
+ )
else:
raise RuntimeError("No credentials available for SAS generation")
@@ -387,10 +475,14 @@ async def create_user_directory(self, user_id: str) -> bool:
logger.error("Error creating user directory for %s: %s", user_id, e)
return False
- async def save_user_profile(self, user_id: str, profile_data: Dict[str, Any]) -> bool:
+ async def save_user_profile(
+ self, user_id: str, profile_data: Dict[str, Any]
+ ) -> bool:
try:
self._profile_path(user_id).parent.mkdir(parents=True, exist_ok=True)
- self._profile_path(user_id).write_text(json.dumps(profile_data, indent=2), encoding="utf-8")
+ self._profile_path(user_id).write_text(
+ json.dumps(profile_data, indent=2), encoding="utf-8"
+ )
return True
except Exception as e:
logger.error("Error saving user profile for %s: %s", user_id, e)
@@ -406,7 +498,9 @@ async def get_user_profile(self, user_id: str) -> Optional[Dict[str, Any]]:
logger.error("Error getting user profile for %s: %s", user_id, e)
return None
- async def upload_user_document(self, user_id: str, filename: str, file_content: bytes) -> Optional[str]:
+ async def upload_user_document(
+ self, user_id: str, filename: str, file_content: bytes
+ ) -> Optional[str]:
try:
await self.create_user_directory(user_id)
doc_id = str(uuid.uuid4())
@@ -428,23 +522,31 @@ async def upload_user_document(self, user_id: str, filename: str, file_content:
profile = await self.get_user_profile(user_id)
if profile:
profile["document_count"] = int(profile.get("document_count", 0)) + 1
- profile["storage_used"] = int(profile.get("storage_used", 0)) + len(file_content)
+ profile["storage_used"] = int(profile.get("storage_used", 0)) + len(
+ file_content
+ )
profile["last_updated"] = datetime.now(timezone.utc).isoformat()
await self.save_user_profile(user_id, profile)
return doc_id
except Exception as e:
- logger.error("Error uploading document %s for user %s: %s", filename, user_id, e)
+ logger.error(
+ "Error uploading document %s for user %s: %s", filename, user_id, e
+ )
return None
- async def get_user_document(self, user_id: str, doc_id: str, filename: str) -> Optional[bytes]:
+ async def get_user_document(
+ self, user_id: str, doc_id: str, filename: str
+ ) -> Optional[bytes]:
try:
p = self._doc_path(user_id, doc_id, filename)
if not p.exists():
return None
return p.read_bytes()
except Exception as e:
- logger.error("Error getting document %s for user %s: %s", doc_id, user_id, e)
+ logger.error(
+ "Error getting document %s for user %s: %s", doc_id, user_id, e
+ )
return None
async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]:
@@ -468,8 +570,12 @@ async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]:
"document_id": doc_id,
"filename": filename,
"file_size": stat.st_size,
- "upload_date": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
- "last_modified": datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc).isoformat(),
+ "upload_date": datetime.fromtimestamp(
+ stat.st_mtime, tz=timezone.utc
+ ).isoformat(),
+ "last_modified": datetime.fromtimestamp(
+ stat.st_mtime, tz=timezone.utc
+ ).isoformat(),
}
if hash_metadata:
doc_info["file_hash"] = hash_metadata.get("file_hash")
@@ -483,7 +589,9 @@ async def list_user_documents(self, user_id: str) -> List[Dict[str, Any]]:
logger.error("Error listing user documents for %s: %s", user_id, e)
return []
- async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -> bool:
+ async def delete_user_document(
+ self, user_id: str, doc_id: str, filename: str
+ ) -> bool:
try:
p = self._doc_path(user_id, doc_id, filename)
doc_size = p.stat().st_size if p.exists() else 0
@@ -493,16 +601,24 @@ async def delete_user_document(self, user_id: str, doc_id: str, filename: str) -
profile = await self.get_user_profile(user_id)
if profile:
- profile["document_count"] = max(0, int(profile.get("document_count", 0)) - 1)
- profile["storage_used"] = max(0, int(profile.get("storage_used", 0)) - int(doc_size))
+ profile["document_count"] = max(
+ 0, int(profile.get("document_count", 0)) - 1
+ )
+ profile["storage_used"] = max(
+ 0, int(profile.get("storage_used", 0)) - int(doc_size)
+ )
profile["last_updated"] = datetime.now(timezone.utc).isoformat()
await self.save_user_profile(user_id, profile)
return True
except Exception as e:
- logger.error("Error deleting document %s for user %s: %s", doc_id, user_id, e)
+ logger.error(
+ "Error deleting document %s for user %s: %s", doc_id, user_id, e
+ )
return False
- async def save_file_hash_metadata(self, user_id: str, document_id: str, file_metadata: Dict[str, Any]) -> bool:
+ async def save_file_hash_metadata(
+ self, user_id: str, document_id: str, file_metadata: Dict[str, Any]
+ ) -> bool:
try:
p = self._metadata_path(user_id, document_id)
p.parent.mkdir(parents=True, exist_ok=True)
@@ -512,7 +628,9 @@ async def save_file_hash_metadata(self, user_id: str, document_id: str, file_met
logger.error("Error saving file metadata for %s: %s", document_id, e)
return False
- async def get_file_hash_metadata(self, user_id: str, document_id: str) -> Optional[Dict[str, Any]]:
+ async def get_file_hash_metadata(
+ self, user_id: str, document_id: str
+ ) -> Optional[Dict[str, Any]]:
try:
p = self._metadata_path(user_id, document_id)
if not p.exists():
@@ -532,7 +650,9 @@ async def delete_file_hash_metadata(self, user_id: str, document_id: str) -> boo
logger.error("Error deleting file metadata for %s: %s", document_id, e)
return False
- async def put_bytes_by_path(self, path: str, content: bytes, content_type: str = "application/octet-stream") -> bool:
+ async def put_bytes_by_path(
+ self, path: str, content: bytes, content_type: str = "application/octet-stream"
+ ) -> bool:
"""Write file by storage path 'container/blob'."""
if not path or "/" not in path:
raise ValueError("Invalid storage path")
@@ -585,7 +705,9 @@ async def delete_by_path(self, path: str) -> bool:
p.unlink()
return True
- async def generate_signed_url(self, path: str, expiry_minutes: int = 5) -> Optional[str]:
+ async def generate_signed_url(
+ self, path: str, expiry_minutes: int = 5
+ ) -> Optional[str]:
"""Local storage cannot generate signed URLs; returns None to signal streaming fallback."""
return None
@@ -605,8 +727,13 @@ def _build_storage_service() -> Optional[StorageService]:
return None
if stype == "azure":
try:
- if not settings.AZURE_STORAGE_ACCOUNT_NAME or not settings.AZURE_STORAGE_ACCOUNT_KEY:
- raise ValueError("STORAGE_MODE=azure requires AZURE_STORAGE_ACCOUNT_NAME and AZURE_STORAGE_ACCOUNT_KEY")
+ if (
+ not settings.AZURE_STORAGE_ACCOUNT_NAME
+ or not settings.AZURE_STORAGE_ACCOUNT_KEY
+ ):
+ raise ValueError(
+ "STORAGE_MODE=azure requires AZURE_STORAGE_ACCOUNT_NAME and AZURE_STORAGE_ACCOUNT_KEY"
+ )
connection_string = (
"DefaultEndpointsProtocol=https;"
f"AccountName={settings.AZURE_STORAGE_ACCOUNT_NAME};"
@@ -618,13 +745,19 @@ def _build_storage_service() -> Optional[StorageService]:
container_name=settings.STORAGE_CONTAINER_NAME,
)
except Exception as e:
- logger.exception("Failed to initialize AzureStorageService (connection string): %s", e)
+ logger.exception(
+ "Failed to initialize AzureStorageService (connection string): %s", e
+ )
return None
if stype == "entra":
try:
if not settings.AZURE_STORAGE_ACCOUNT_NAME:
- raise ValueError("STORAGE_MODE=entra requires AZURE_STORAGE_ACCOUNT_NAME")
- account_url = f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"
+ raise ValueError(
+ "STORAGE_MODE=entra requires AZURE_STORAGE_ACCOUNT_NAME"
+ )
+ account_url = (
+ f"https://{settings.AZURE_STORAGE_ACCOUNT_NAME}.blob.core.windows.net"
+ )
return AzureStorageService(
account_url=account_url,
container_name=settings.STORAGE_CONTAINER_NAME,
diff --git a/backend/api/services/user_db.py b/backend/api/services/user_db.py
index ee1c0411..c2025cb5 100644
--- a/backend/api/services/user_db.py
+++ b/backend/api/services/user_db.py
@@ -16,7 +16,6 @@
logger = logging.getLogger(__name__)
-
class UserDatabaseService:
"""Service for managing user data in PostgreSQL via psycopg3 async."""
@@ -29,6 +28,7 @@ def __init__(self):
self._registry_cache_ts: float = 0.0
self._registry_cache_ttl_s: float = 30.0
self._registry_cache_lock = asyncio.Lock()
+
@staticmethod
def _serialize_dates(row: Dict[str, Any]) -> Dict[str, Any]:
for field in ("created_at", "updated_at", "last_login"):
@@ -50,11 +50,16 @@ async def _load_user_registry(self) -> Dict[str, Any]:
"""Load the user registry from storage."""
async with self._registry_cache_lock:
now = asyncio.get_running_loop().time()
- if self._registry_cache is not None and (now - self._registry_cache_ts) < self._registry_cache_ttl_s:
+ if (
+ self._registry_cache is not None
+ and (now - self._registry_cache_ts) < self._registry_cache_ttl_s
+ ):
return self._registry_cache
try:
- content, _filename = await self.storage.get_bytes_by_path(self._registry_path())
+ content, _filename = await self.storage.get_bytes_by_path(
+ self._registry_path()
+ )
reg = json.loads(content.decode("utf-8"))
except Exception:
# Create empty registry if it doesn't exist / cannot be read
@@ -81,12 +86,12 @@ async def _save_user_registry(self, registry: Dict[str, Any]) -> bool:
return ok
except Exception:
return False
+
async def ensure_table_exists(self) -> None:
"""Create the users table if it does not already exist."""
try:
async with postgres_server.aconn() as conn:
- await conn.execute(
- """
+ await conn.execute("""
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
@@ -98,8 +103,7 @@ async def ensure_table_exists(self) -> None:
updated_at TIMESTAMP WITH TIME ZONE DEFAULT now(),
last_login TIMESTAMP WITH TIME ZONE
)
- """
- )
+ """)
logger.info("Ensured users table exists")
except Exception as e:
logger.exception("Failed to ensure users table exists: %s", e)
@@ -129,8 +133,17 @@ async def create_user(self, user_data: UserCreate) -> Optional[UserRead]:
created_at, updated_at, last_login)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
""",
- (user_id, email, user_data.full_name, hashed_password,
- True, False, now, now, None),
+ (
+ user_id,
+ email,
+ user_data.full_name,
+ hashed_password,
+ True,
+ False,
+ now,
+ now,
+ None,
+ ),
)
return UserRead(
id=user_id,
@@ -178,7 +191,9 @@ async def authenticate_user(
row = await cur.fetchone()
if not row:
return None
- if not sso and not self._verify_password(password, row["hashed_password"]):
+ if not sso and not self._verify_password(
+ password, row["hashed_password"]
+ ):
return None
now = datetime.now(timezone.utc)
await conn.execute(
@@ -193,8 +208,12 @@ async def update_user(
self, user_id: str, update_data: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
allowed_fields = {
- "full_name", "email", "hashed_password",
- "is_active", "is_superuser", "last_login",
+ "full_name",
+ "email",
+ "hashed_password",
+ "is_active",
+ "is_superuser",
+ "last_login",
}
fields = {k: v for k, v in update_data.items() if k in allowed_fields}
if not fields:
diff --git a/backend/api/sr/router.py b/backend/api/sr/router.py
index e36b095d..84a4e92f 100644
--- a/backend/api/sr/router.py
+++ b/backend/api/sr/router.py
@@ -28,6 +28,7 @@
router = APIRouter()
+
class SystematicReviewCreate(BaseModel):
name: str
description: Optional[str] = None
@@ -53,10 +54,9 @@ class SystematicReviewRead(BaseModel):
criteria_parsed: Optional[Dict[str, Any]] = None
-
-
-
-@router.post("/create", response_model=SystematicReviewRead, status_code=status.HTTP_201_CREATED)
+@router.post(
+ "/create", response_model=SystematicReviewRead, status_code=status.HTTP_201_CREATED
+)
async def create_systematic_review(
name: str = Form(...),
description: Optional[str] = Form(None),
@@ -78,7 +78,9 @@ async def create_systematic_review(
"""
if not name or not name.strip():
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="name is required")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="name is required"
+ )
# Load YAML criteria
criteria_str: Optional[str] = None
@@ -103,7 +105,8 @@ async def create_systematic_review(
)
except yaml.YAMLError as ye:
raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid YAML provided: {ye}"
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Invalid YAML provided: {ye}",
)
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
@@ -121,7 +124,10 @@ async def create_systematic_review(
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to create systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to create systematic review: {e}",
+ )
return SystematicReviewRead(
id=sr_doc.get("id"),
@@ -164,27 +170,45 @@ async def add_user_to_systematic_review(
"""
try:
- sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False)
+ sr, screening = await load_sr_and_check(
+ sr_id, current_user, srdb_service, require_screening=False
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
# resolve user
target_user_id = None
if payload.user_email:
target_user_id = payload.user_email
else:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email"
+ )
try:
- res = await run_in_threadpool(srdb_service.add_user, sr_id, target_user_id, current_user.get("id"))
+ res = await run_in_threadpool(
+ srdb_service.add_user, sr_id, target_user_id, current_user.get("id")
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to add user: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to add user: {e}",
+ )
- return {"status": "success", "sr_id": sr_id, "added_user_id": target_user_id, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "added_user_id": target_user_id,
+ "matched_count": res.get("matched_count"),
+ "modified_count": res.get("modified_count"),
+ }
@router.post("/{sr_id}/remove-user")
@@ -201,31 +225,52 @@ async def remove_user_from_systematic_review(
The owner cannot be removed via this endpoint.
"""
try:
- sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False)
+ sr, screening = await load_sr_and_check(
+ sr_id, current_user, srdb_service, require_screening=False
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
-
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
+
# resolve user
target_user_id = None
if payload.user_email:
target_user_id = payload.user_email
else:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST, detail="Missing data user_email"
+ )
# do not allow removing the owner
if target_user_id == sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Cannot remove the owner from the systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Cannot remove the owner from the systematic review",
+ )
try:
- res = await run_in_threadpool(srdb_service.remove_user, sr_id, target_user_id, current_user.get("id"))
+ res = await run_in_threadpool(
+ srdb_service.remove_user, sr_id, target_user_id, current_user.get("id")
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to remove user: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to remove user: {e}",
+ )
- return {"status": "success", "sr_id": sr_id, "removed_user_id": target_user_id, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "removed_user_id": target_user_id,
+ "matched_count": res.get("matched_count"),
+ "modified_count": res.get("modified_count"),
+ }
@router.get("/mine", response_model=List[SystematicReviewRead])
@@ -240,11 +285,16 @@ async def list_systematic_reviews_for_user(
user_id = current_user.get("email")
results = []
try:
- docs = await run_in_threadpool(srdb_service.list_systematic_reviews_for_user, user_id)
+ docs = await run_in_threadpool(
+ srdb_service.list_systematic_reviews_for_user, user_id
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to list systematic reviews: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to list systematic reviews: {e}",
+ )
for doc in docs:
results.append(
@@ -268,17 +318,24 @@ async def list_systematic_reviews_for_user(
@router.get("/{sr_id}", response_model=SystematicReviewRead)
-async def get_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)):
+async def get_systematic_review(
+ sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)
+):
"""
Get a single systematic review by id. User must be a member to view.
"""
try:
- doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False)
+ doc, screening = await load_sr_and_check(
+ sr_id, current_user, srdb_service, require_screening=False
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
return SystematicReviewRead(
id=doc.get("id"),
@@ -308,11 +365,16 @@ async def get_systematic_review_criteria_parsed(
"""
try:
- doc, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False)
+ doc, screening = await load_sr_and_check(
+ sr_id, current_user, srdb_service, require_screening=False
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
cp = doc.get("criteria_parsed") or {}
return {"criteria_parsed": cp}
@@ -334,11 +396,16 @@ async def update_systematic_review_criteria(
"""
try:
- sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False)
+ sr, screening = await load_sr_and_check(
+ sr_id, current_user, srdb_service, require_screening=False
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
# Load YAML criteria
criteria_str: Optional[str] = None
@@ -352,7 +419,10 @@ async def update_systematic_review_criteria(
criteria_str = criteria_yaml
if not criteria_str:
- raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Either criteria_file or criteria_yaml must be provided")
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail="Either criteria_file or criteria_yaml must be provided",
+ )
criteria_obj = yaml.safe_load(criteria_str)
if criteria_obj is None:
@@ -364,18 +434,28 @@ async def update_systematic_review_criteria(
)
except yaml.YAMLError as ye:
raise HTTPException(
- status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid YAML provided: {ye}"
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=f"Invalid YAML provided: {ye}",
)
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
# perform update
try:
- doc = await run_in_threadpool(srdb_service.update_criteria, sr_id, criteria_obj, criteria_str, current_user.get("id"))
+ doc = await run_in_threadpool(
+ srdb_service.update_criteria,
+ sr_id,
+ criteria_obj,
+ criteria_str,
+ current_user.get("id"),
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to update criteria: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to update criteria: {e}",
+ )
return SystematicReviewRead(
id=doc.get("id"),
@@ -394,7 +474,9 @@ async def update_systematic_review_criteria(
@router.delete("/{sr_id}")
-async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)):
+async def delete_systematic_review(
+ sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)
+):
"""
Soft-delete a systematic review by marking its 'visible' flag as False.
@@ -402,28 +484,49 @@ async def delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = De
"""
try:
- sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False)
+ sr, screening = await load_sr_and_check(
+ sr_id, current_user, srdb_service, require_screening=False
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
requester_id = current_user.get("id")
if requester_id != sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may delete this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only the owner may delete this systematic review",
+ )
try:
- res = await run_in_threadpool(srdb_service.soft_delete_systematic_review, sr_id, requester_id)
+ res = await run_in_threadpool(
+ srdb_service.soft_delete_systematic_review, sr_id, requester_id
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to delete systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to delete systematic review: {e}",
+ )
- return {"status": "success", "sr_id": sr_id, "deleted": True, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "deleted": True,
+ "matched_count": res.get("matched_count"),
+ "modified_count": res.get("modified_count"),
+ }
@router.post("/{sr_id}/undelete")
-async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)):
+async def undelete_systematic_review(
+ sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)
+):
"""
Undelete (restore) a systematic review by marking its 'visible' flag as True.
@@ -431,28 +534,53 @@ async def undelete_systematic_review(sr_id: str, current_user: Dict[str, Any] =
"""
try:
- sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False)
+ sr, screening = await load_sr_and_check(
+ sr_id,
+ current_user,
+ srdb_service,
+ require_screening=False,
+ require_visible=False,
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
requester_id = current_user.get("id")
if requester_id != sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may undelete this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only the owner may undelete this systematic review",
+ )
try:
- res = await run_in_threadpool(srdb_service.undelete_systematic_review, sr_id, requester_id)
+ res = await run_in_threadpool(
+ srdb_service.undelete_systematic_review, sr_id, requester_id
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to undelete systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to undelete systematic review: {e}",
+ )
- return {"status": "success", "sr_id": sr_id, "undeleted": True, "matched_count": res.get("matched_count"), "modified_count": res.get("modified_count")}
+ return {
+ "status": "success",
+ "sr_id": sr_id,
+ "undeleted": True,
+ "matched_count": res.get("matched_count"),
+ "modified_count": res.get("modified_count"),
+ }
@router.delete("/{sr_id}/hard")
-async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)):
+async def hard_delete_systematic_review(
+ sr_id: str, current_user: Dict[str, Any] = Depends(get_current_active_user)
+):
"""
Permanently remove the systematic review document from MongoDB.
@@ -463,15 +591,27 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any]
"""
try:
- sr, screening = await load_sr_and_check(sr_id, current_user, srdb_service, require_screening=False, require_visible=False)
+ sr, screening = await load_sr_and_check(
+ sr_id,
+ current_user,
+ srdb_service,
+ require_screening=False,
+ require_visible=False,
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to load systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to load systematic review: {e}",
+ )
requester_id = current_user.get("id")
if requester_id != sr.get("owner_id"):
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Only the owner may hard-delete this systematic review")
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Only the owner may hard-delete this systematic review",
+ )
# Attempt to perform screening resources cleanup prior to deleting the SR document.
cleanup_result = None
@@ -492,15 +632,23 @@ async def hard_delete_systematic_review(sr_id: str, current_user: Dict[str, Any]
cleanup_result = {"status": "cleanup_import_failed", "error": str(e)}
try:
- res = await run_in_threadpool(srdb_service.hard_delete_systematic_review, sr_id, requester_id)
+ res = await run_in_threadpool(
+ srdb_service.hard_delete_systematic_review, sr_id, requester_id
+ )
deleted_count = res.get("deleted_count")
if not deleted_count:
# If backend reported zero deletions, raise NotFound to match prior behavior
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Systematic review not found during hard delete")
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Systematic review not found during hard delete",
+ )
except HTTPException:
raise
except Exception as e:
- raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to hard-delete systematic review: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=f"Failed to hard-delete systematic review: {e}",
+ )
return {
"status": "success",
diff --git a/backend/docker-compose.yml b/backend/docker-compose.yml
index 3cca1cbd..9a1bae1a 100644
--- a/backend/docker-compose.yml
+++ b/backend/docker-compose.yml
@@ -1,3 +1,19 @@
+# services:
+# # =============================================================================
+# # CAN-SR BACKEND API
+# # =============================================================================
+# api:
+# build:
+# context: .
+# dockerfile: Dockerfile
+# container_name: can-sr-api
+# ports:
+# - "8000:8000"
+# environment:
+# - GROBID_SERVICE_URL=http://grobid-service:8070
+# - JAVA_OPTS=-XX:-UseContainerSupport
+
+
services:
# =============================================================================
# CAN-SR BACKEND API
@@ -29,9 +45,24 @@ services:
# =============================================================================
# GROBID SERVICE - PDF Parsing for Full-Text Extraction
# =============================================================================
+ # grobid-service:
+ # image: grobid/grobid:0.8.2-crf
+ # container_name: grobid-service
+ # ports:
+ # - "8070:8070"
+ # - "8081:8081"
+ # restart: unless-stopped
+ # healthcheck:
+ # test: ["CMD", "curl", "-f", "http://localhost:8070/api/isalive"]
+ # interval: 10s
+ # timeout: 10s
+ # retries: 5
+ # start_period: 120s
grobid-service:
- image: grobid/grobid:0.8.2-crf
+ image: grobid/grobid:0.8.2
container_name: grobid-service
+ environment:
+ - JAVA_TOOL_OPTIONS=-XX:+UseContainerSupport
ports:
- "8070:8070"
- "8081:8081"
@@ -57,7 +88,7 @@ services:
ports:
- "5432:5432"
volumes:
- - ./volumes/postgres:/var/lib/postgresql/data
+ - ./volumes/postgres:/var/lib/postgresql #/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U admin -d postgres -h localhost"]
interval: 30s
@@ -67,4 +98,4 @@ services:
networks:
default:
- driver: bridge
+ driver: bridge
\ No newline at end of file
diff --git a/backend/main.py b/backend/main.py
index 606dc55f..244c36f7 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -15,7 +15,6 @@
from api.services.sr_db_service import srdb_service
from api.services.user_db import user_db_service
-
app = FastAPI(
title=settings.PROJECT_NAME,
description=settings.DESCRIPTION,
@@ -29,10 +28,12 @@ async def startup_event():
"""Startup event - initialize CAN-SR systematic review database"""
from fastapi.concurrency import run_in_threadpool
import asyncio
-
+
# Reduce Azure SDK HTTP logging noise (especially during polling endpoints).
logging.getLogger("azure").setLevel(logging.WARNING)
- logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(logging.WARNING)
+ logging.getLogger("azure.core.pipeline.policies.http_logging_policy").setLevel(
+ logging.WARNING
+ )
print("🚀 Starting CAN-SR Backend...", flush=True)
print("📚 Initializing systematic review database...", flush=True)
@@ -61,6 +62,7 @@ async def startup_event():
# Keep Procrastinate open for the whole API lifespan so request handlers
# can enqueue jobs.
from api.jobs.procrastinate_app import PROCRASTINATE_APP
+
await PROCRASTINATE_APP.open_async()
await ensure_procrastinate_schema()
await run_in_threadpool(run_all_repo.ensure_tables)
@@ -71,9 +73,14 @@ async def startup_event():
if getattr(settings, "PROCRASTINATE_CLEAR_ON_START", True):
try:
cleared = await clear_pending_jobs(queues=["default"])
- print(f"🧹 Cleared {cleared} pending Procrastinate jobs", flush=True)
+ print(
+ f"🧹 Cleared {cleared} pending Procrastinate jobs", flush=True
+ )
except Exception as e:
- print(f"⚠️ Failed to clear pending Procrastinate jobs: {e}", flush=True)
+ print(
+ f"⚠️ Failed to clear pending Procrastinate jobs: {e}",
+ flush=True,
+ )
if workers_enabled():
# Run a worker loop inside the API process (dev/quick deploy).
@@ -96,8 +103,6 @@ async def shutdown_event():
except Exception:
pass
-
-
# Set up CORS
cors_origins = (
@@ -105,7 +110,12 @@ async def shutdown_event():
if isinstance(settings.CORS_ORIGINS, str)
else [settings.CORS_ORIGINS]
)
-app.add_middleware(SessionMiddleware, secret_key=settings.SECRET_KEY, same_site="lax", https_only=settings.IS_DEPLOYED)
+app.add_middleware(
+ SessionMiddleware,
+ secret_key=settings.SECRET_KEY,
+ same_site="lax",
+ https_only=settings.IS_DEPLOYED,
+)
app.add_middleware(
CORSMiddleware,
allow_origins=cors_origins,
diff --git a/frontend/app/[lang]/can-sr/citations/full-text/route.ts b/frontend/app/[lang]/can-sr/citations/full-text/route.ts
index 94912a11..5e587e9e 100644
--- a/frontend/app/[lang]/can-sr/citations/full-text/route.ts
+++ b/frontend/app/[lang]/can-sr/citations/full-text/route.ts
@@ -95,9 +95,16 @@ export async function GET(request: NextRequest) {
const citation = await citationRes.json().catch(() => ({}))
// Try common fields: document_id, documentId, storage_path, fulltext_url
const documentId =
- citation?.document_id || citation?.documentId || citation?.document || null
+ citation?.document_id ||
+ citation?.documentId ||
+ citation?.document ||
+ null
const storagePath =
- citation?.storage_path || citation?.storagePath || citation?.fulltext_url || citation?.fulltext || null
+ citation?.storage_path ||
+ citation?.storagePath ||
+ citation?.fulltext_url ||
+ citation?.fulltext ||
+ null
if (documentId) {
fileFetchUrl = `${BACKEND_URL}/api/files/documents/${encodeURIComponent(
@@ -163,7 +170,10 @@ export async function GET(request: NextRequest) {
})
} catch (err: any) {
console.error('Fulltext proxy GET error:', err)
- return NextResponse.json({ error: 'Internal server error' }, { status: 500 })
+ return NextResponse.json(
+ { error: 'Internal server error' },
+ { status: 500 },
+ )
}
}
diff --git a/frontend/app/[lang]/can-sr/extract/page.tsx b/frontend/app/[lang]/can-sr/extract/page.tsx
index a15f8773..f31b0e19 100644
--- a/frontend/app/[lang]/can-sr/extract/page.tsx
+++ b/frontend/app/[lang]/can-sr/extract/page.tsx
@@ -50,7 +50,8 @@ const buildCitationAiCalls: BuildCitationAiCalls = async ({
)
const data = await res.json().catch(() => ({}))
const parsed = data?.criteria_parsed || data?.criteria || {}
- const paramsInfo: ParametersParsed | null = (parsed?.parameters as any) || null
+ const paramsInfo: ParametersParsed | null =
+ (parsed?.parameters as any) || null
if (paramsInfo?.categories && paramsInfo?.possible_parameters) {
const out: Array<{ name: string; description: string }> = []
@@ -68,7 +69,10 @@ const buildCitationAiCalls: BuildCitationAiCalls = async ({
typeof descs?.[j] === 'string' ? (descs[j] as string) : ''
const cleanDesc = rawDesc.replace(/<\/?desc>/g, '')
if (rawName && rawName.trim()) {
- out.push({ name: rawName.trim(), description: cleanDesc || rawName })
+ out.push({
+ name: rawName.trim(),
+ description: cleanDesc || rawName,
+ })
}
})
})
diff --git a/frontend/app/[lang]/can-sr/extract/view/page.tsx b/frontend/app/[lang]/can-sr/extract/view/page.tsx
index 120c2609..d0750326 100644
--- a/frontend/app/[lang]/can-sr/extract/view/page.tsx
+++ b/frontend/app/[lang]/can-sr/extract/view/page.tsx
@@ -5,7 +5,9 @@ import { useState, useEffect, useRef } from 'react'
import GCHeader, { SRHeader } from '@/components/can-sr/headers'
import { getAuthToken, getTokenType } from '@/lib/auth'
import { ModelSelector } from '@/components/chat'
-import PDFBoundingBoxViewer, { PDFBoundingBoxViewerHandle } from '@/components/can-sr/PDFBoundingBoxViewer'
+import PDFBoundingBoxViewer, {
+ PDFBoundingBoxViewerHandle,
+} from '@/components/can-sr/PDFBoundingBoxViewer'
import { ChevronDown, ChevronRight, Wand2 } from 'lucide-react'
import { useDictionary } from '@/app/[lang]/DictionaryProvider'
@@ -78,7 +80,8 @@ export default function CanSrL2ScreenPage() {
return llmCol.replace('llm_param_', 'human_param_')
}
- const [parametersParsed, setParametersParsed] = useState(null)
+ const [parametersParsed, setParametersParsed] =
+ useState(null)
const [paramValues, setParamValues] = useState>({})
const [savingParam, setSavingParam] = useState(null)
const [saveStatus, setSaveStatus] = useState>({})
@@ -91,7 +94,9 @@ export default function CanSrL2ScreenPage() {
const [aiPanels, setAiPanels] = useState>({})
const [panelOpen, setPanelOpen] = useState>({})
const [fulltextCoords, setFulltextCoords] = useState(null)
- const [fulltextPages, setFulltextPages] = useState<{ width: number; height: number }[] | null>(null)
+ const [fulltextPages, setFulltextPages] = useState<
+ { width: number; height: number }[] | null
+ >(null)
const [fulltextStr, setFulltextStr] = useState(null)
const viewerRef = useRef(null)
@@ -100,9 +105,14 @@ export default function CanSrL2ScreenPage() {
const [fulltextFigures, setFulltextFigures] = useState(null)
const scrollToArtifact = (kind: 'table' | 'figure', idx: number) => {
- const list = kind === 'table' ? (fulltextTables || []) : (fulltextFigures || [])
+ const list = kind === 'table' ? fulltextTables || [] : fulltextFigures || []
const item = list.find((x: any) => Number(x?.index) === Number(idx))
- console.log('[artifact-click]', { kind, idx, hasViewer: !!viewerRef.current, item })
+ console.log('[artifact-click]', {
+ kind,
+ idx,
+ hasViewer: !!viewerRef.current,
+ item,
+ })
if (!item || !viewerRef.current) return
const bbox = item?.bounding_box
const first = Array.isArray(bbox) ? bbox[0] : null
@@ -112,7 +122,10 @@ export default function CanSrL2ScreenPage() {
}
const [runningAllAI, setRunningAllAI] = useState(false)
- const [runAllProgress, setRunAllProgress] = useState<{ done: number; total: number } | null>(null)
+ const [runAllProgress, setRunAllProgress] = useState<{
+ done: number
+ total: number
+ } | null>(null)
// Cache full text so single-param and run-all don’t repeatedly trigger extraction/DB reads
const fullTextCacheRef = useRef(null)
@@ -176,11 +189,11 @@ export default function CanSrL2ScreenPage() {
const suggestParam = async (name: string, description: string) => {
if (!citationId || !srId) return
- setAiStatus(prev => ({ ...prev, [name]: 'extracting' }))
+ setAiStatus((prev) => ({ ...prev, [name]: 'extracting' }))
try {
const headers = getAuthHeaders()
const fullText = await ensureFullText()
- setAiStatus(prev => ({ ...prev, [name]: 'suggesting' }))
+ setAiStatus((prev) => ({ ...prev, [name]: 'suggesting' }))
const res = await fetch(
`/api/can-sr/extract?action=extract-parameter&sr_id=${encodeURIComponent(
srId || '',
@@ -201,16 +214,16 @@ export default function CanSrL2ScreenPage() {
const data = await res.json().catch(() => ({}))
const ext = data?.extraction || data
if (res.ok && ext) {
- setParamValues(prev => ({ ...prev, [name]: ext?.value ?? '' }))
- setParamFound(prev => ({ ...prev, [name]: !!ext?.found }))
- setAiPanels(prev => ({ ...prev, [name]: ext }))
- setPanelOpen(prev => ({ ...prev, [name]: false }))
- setAiStatus(prev => ({ ...prev, [name]: 'suggested' }))
+ setParamValues((prev) => ({ ...prev, [name]: ext?.value ?? '' }))
+ setParamFound((prev) => ({ ...prev, [name]: !!ext?.found }))
+ setAiPanels((prev) => ({ ...prev, [name]: ext }))
+ setPanelOpen((prev) => ({ ...prev, [name]: false }))
+ setAiStatus((prev) => ({ ...prev, [name]: 'suggested' }))
} else {
- setAiStatus(prev => ({ ...prev, [name]: 'error' }))
+ setAiStatus((prev) => ({ ...prev, [name]: 'error' }))
}
} catch {
- setAiStatus(prev => ({ ...prev, [name]: 'error' }))
+ setAiStatus((prev) => ({ ...prev, [name]: 'error' }))
}
}
@@ -224,7 +237,11 @@ export default function CanSrL2ScreenPage() {
const desc = parametersParsed.descriptions?.[i]?.[j] || ''
const cleanDesc = desc.replace(/<\/?desc>/g, '')
const paramName =
- typeof param === 'string' ? param : Array.isArray(param) ? param[0] : String(param)
+ typeof param === 'string'
+ ? param
+ : Array.isArray(param)
+ ? param[0]
+ : String(param)
params.push({ name: paramName, description: cleanDesc })
})
})
@@ -258,7 +275,11 @@ export default function CanSrL2ScreenPage() {
const data = await res.json().catch(() => ({}))
const parsed = data?.criteria_parsed || data?.criteria || {}
const paramsInfo = parsed?.parameters
- if (paramsInfo && paramsInfo.categories && paramsInfo.possible_parameters) {
+ if (
+ paramsInfo &&
+ paramsInfo.categories &&
+ paramsInfo.possible_parameters
+ ) {
setParametersParsed(paramsInfo)
const defaults: Record = {}
paramsInfo.possible_parameters.forEach((arr: string[]) => {
@@ -266,14 +287,14 @@ export default function CanSrL2ScreenPage() {
defaults[name] = defaults[name] || ''
})
})
- setParamValues(prev => ({ ...defaults, ...prev }))
+ setParamValues((prev) => ({ ...defaults, ...prev }))
const defaultFound: Record = {}
paramsInfo.possible_parameters.forEach((arr: string[]) => {
arr.forEach((name: string) => {
defaultFound[name] = defaultFound[name] ?? false
})
})
- setParamFound(prev => ({ ...defaultFound, ...prev }))
+ setParamFound((prev) => ({ ...defaultFound, ...prev }))
}
} catch (err) {
console.warn('Failed to load parameters', err)
@@ -333,29 +354,39 @@ export default function CanSrL2ScreenPage() {
})
if (Object.keys(nextFound).length) {
- setParamFound(prev => ({ ...prev, ...nextFound }))
+ setParamFound((prev) => ({ ...prev, ...nextFound }))
}
if (Object.keys(nextValues).length) {
- setParamValues(prev => ({ ...prev, ...nextValues }))
+ setParamValues((prev) => ({ ...prev, ...nextValues }))
}
if (Object.keys(nextAIPanels).length) {
- setAiPanels(prev => ({ ...prev, ...nextAIPanels }))
+ setAiPanels((prev) => ({ ...prev, ...nextAIPanels }))
}
// extract coords/pages/fulltext and artifacts for PDF overlay
- const ft = typeof (row as any).fulltext === 'string' ? (row as any).fulltext : null
+ const ft =
+ typeof (row as any).fulltext === 'string'
+ ? (row as any).fulltext
+ : null
if (ft) setFulltextStr(ft)
- const tablesAny = parseJson((row as any).fulltext_tables) ?? (row as any).fulltext_tables
+ const tablesAny =
+ parseJson((row as any).fulltext_tables) ??
+ (row as any).fulltext_tables
if (tablesAny && Array.isArray(tablesAny)) setFulltextTables(tablesAny)
- const figsAny = parseJson((row as any).fulltext_figures) ?? (row as any).fulltext_figures
+ const figsAny =
+ parseJson((row as any).fulltext_figures) ??
+ (row as any).fulltext_figures
if (figsAny && Array.isArray(figsAny)) setFulltextFigures(figsAny)
- const coordsAny = parseJson((row as any).fulltext_coords) ?? (row as any).fulltext_coords
+ const coordsAny =
+ parseJson((row as any).fulltext_coords) ??
+ (row as any).fulltext_coords
if (coordsAny && Array.isArray(coordsAny)) setFulltextCoords(coordsAny)
- const pagesAny = parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages
+ const pagesAny =
+ parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages
if (pagesAny && Array.isArray(pagesAny)) setFulltextPages(pagesAny)
} catch (err) {
console.warn('Failed to prefill citation params', err)
@@ -387,24 +418,37 @@ export default function CanSrL2ScreenPage() {
}
}
- const ft = typeof (row as any).fulltext === 'string' ? (row as any).fulltext : null
+ const ft =
+ typeof (row as any).fulltext === 'string'
+ ? (row as any).fulltext
+ : null
if (ft) setFulltextStr(ft)
- const tablesAny = parseJson((row as any).fulltext_tables) ?? (row as any).fulltext_tables
+ const tablesAny =
+ parseJson((row as any).fulltext_tables) ??
+ (row as any).fulltext_tables
if (tablesAny && Array.isArray(tablesAny)) setFulltextTables(tablesAny)
- const figsAny = parseJson((row as any).fulltext_figures) ?? (row as any).fulltext_figures
+ const figsAny =
+ parseJson((row as any).fulltext_figures) ??
+ (row as any).fulltext_figures
if (figsAny && Array.isArray(figsAny)) setFulltextFigures(figsAny)
- const coordsAny = parseJson((row as any).fulltext_coords) ?? (row as any).fulltext_coords
+ const coordsAny =
+ parseJson((row as any).fulltext_coords) ??
+ (row as any).fulltext_coords
if (coordsAny && Array.isArray(coordsAny)) setFulltextCoords(coordsAny)
- const pagesAny = parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages
+ const pagesAny =
+ parseJson((row as any).fulltext_pages) ?? (row as any).fulltext_pages
if (pagesAny && Array.isArray(pagesAny)) setFulltextPages(pagesAny)
// If coords/pages are missing, trigger backend extraction to populate them, then refetch
const needExtract =
- !Array.isArray(coordsAny) || coordsAny.length === 0 || !Array.isArray(pagesAny) || pagesAny.length === 0
+ !Array.isArray(coordsAny) ||
+ coordsAny.length === 0 ||
+ !Array.isArray(pagesAny) ||
+ pagesAny.length === 0
if (needExtract) {
try {
@@ -423,14 +467,23 @@ export default function CanSrL2ScreenPage() {
)
const row2 = await res3.json().catch(() => ({}))
- const ft2 = typeof (row2 as any).fulltext === 'string' ? (row2 as any).fulltext : null
+ const ft2 =
+ typeof (row2 as any).fulltext === 'string'
+ ? (row2 as any).fulltext
+ : null
if (ft2) setFulltextStr(ft2)
- const coordsAny2 = parseJson((row2 as any).fulltext_coords) ?? (row2 as any).fulltext_coords
- if (coordsAny2 && Array.isArray(coordsAny2)) setFulltextCoords(coordsAny2)
-
- const pagesAny2 = parseJson((row2 as any).fulltext_pages) ?? (row2 as any).fulltext_pages
- if (pagesAny2 && Array.isArray(pagesAny2)) setFulltextPages(pagesAny2)
+ const coordsAny2 =
+ parseJson((row2 as any).fulltext_coords) ??
+ (row2 as any).fulltext_coords
+ if (coordsAny2 && Array.isArray(coordsAny2))
+ setFulltextCoords(coordsAny2)
+
+ const pagesAny2 =
+ parseJson((row2 as any).fulltext_pages) ??
+ (row2 as any).fulltext_pages
+ if (pagesAny2 && Array.isArray(pagesAny2))
+ setFulltextPages(pagesAny2)
}
} catch (err) {
console.warn('Failed to extract fulltext for overlay', err)
@@ -443,13 +496,13 @@ export default function CanSrL2ScreenPage() {
}, [srId, citationId])
const updateValue = (name: string, val: string) => {
- setParamValues(prev => ({ ...prev, [name]: val }))
+ setParamValues((prev) => ({ ...prev, [name]: val }))
}
const saveParam = async (name: string) => {
if (!citationId || !srId) return
setSavingParam(name)
- setSaveStatus(prev => ({ ...prev, [name]: 'saving' }))
+ setSaveStatus((prev) => ({ ...prev, [name]: 'saving' }))
try {
const headers = getAuthHeaders()
const res = await fetch(
@@ -470,9 +523,9 @@ export default function CanSrL2ScreenPage() {
},
)
await res.json().catch(() => ({}))
- setSaveStatus(prev => ({ ...prev, [name]: res.ok ? 'saved' : 'error' }))
+ setSaveStatus((prev) => ({ ...prev, [name]: res.ok ? 'saved' : 'error' }))
} catch {
- setSaveStatus(prev => ({ ...prev, [name]: 'error' }))
+ setSaveStatus((prev) => ({ ...prev, [name]: 'error' }))
} finally {
setSavingParam(null)
}
@@ -488,7 +541,7 @@ export default function CanSrL2ScreenPage() {
-
-
+
{/* Workspace (left) */}
{/*
*/}
- {/*
+ {/*
Workspace
This area displays the full text (PDF) and is a flexible workspace for viewing and selecting parameter regions.
*/}
-
+
{/*
*/}
{/* Selection sidebar (right) */}