diff --git a/.gitignore b/.gitignore index bffbd3e..0b40a08 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,10 @@ -api_config.py .env +api_config.py +__pycache__/ +*.pyc +.venv/ +vector_store/ +experiments/ +metadata_inputs/ +pdf_metadata_outputs/ +xml_metadata_outputs/ diff --git a/README.md b/README.md index d889892..367b704 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ cBioAbstractor is a Streamlit-based curation assistant for cancer genomics studi - Upload a cancer genomics paper PDF - Upload supplementary data files such as `.xlsx`, `.csv`, `.tsv`, `.txt`, `.maf`, `.docx`, and `.pdf` +- Download supplementary files automatically from PubMed Central using a PMCID or PMID - Extract study-level metadata from the paper - Classify supplementary sheets against cBioPortal file-format schemas - Identify likely cBioPortal target files @@ -74,12 +75,16 @@ pip install -r requirements.txt ## API Key Setup -Set your Anthropic API key locally as an environment variable: +Set either an Anthropic or OpenAI API key locally as an environment variable: ```bash export ANTHROPIC_API_KEY="your-api-key" +# or +export OPENAI_API_KEY="your-api-key" ``` +The Streamlit sidebar lets you choose between Anthropic and OpenAI models. + Do not commit API keys to GitHub. Recommended `.gitignore` entries: @@ -113,7 +118,7 @@ http://localhost:8501 1. Open the Streamlit app 2. Upload the main paper PDF -3. Upload one or more supplementary files +3. Upload one or more supplementary files, or enter a PMCID/PMID to fetch them from PubMed Central 4. Run the curation workflow 5. Review detected file types, required fields, and missing fields 6. Download the generated cBioPortal curation report @@ -167,4 +172,3 @@ These examples help the app recognize recurring supplemental file patterns. 
- Supplementary file classification - cBioPortal format assessment - Curation report generation - diff --git a/cbio_detector.py b/cbio_detector.py index 4cc86c4..fb1cb4d 100644 --- a/cbio_detector.py +++ b/cbio_detector.py @@ -32,7 +32,7 @@ import pandas as pd -from config import FEW_SHOT_DIR, DETECTION_SAMPLE_ROWS, CBIO_FORMAT_IDS +from config import FEW_SHOT_DIR, DETECTION_SAMPLE_ROWS logger = logging.getLogger(__name__) @@ -238,14 +238,7 @@ def _heuristic_detect(df: pd.DataFrame) -> Tuple[Optional[str], float]: # LLM-powered detector (few-shot) # --------------------------------------------------------------------------- -def _llm_detect(df: pd.DataFrame, examples: List[dict], api_key: str) -> Tuple[str, float, str]: - """ - Use Claude to detect the file type with few-shot examples injected. - Returns (detected_type, confidence, reasoning). - """ - import anthropic - - # Build few-shot block +def _build_detection_prompt(df: pd.DataFrame, examples: List[dict]) -> str: few_shot_block = "" for i, ex in enumerate(examples[:6]): # max 6 examples to keep prompt manageable few_shot_block += f""" @@ -262,7 +255,7 @@ def _llm_detect(df: pd.DataFrame, examples: List[dict], api_key: str) -> Tuple[s col_list = list(df.columns) sample_rows = df.head(DETECTION_SAMPLE_ROWS).to_csv(sep="\t", index=False) - prompt = f"""You are a bioinformatics data curation expert specializing in cBioPortal data formats. + return f"""You are a bioinformatics data curation expert specializing in cBioPortal data formats. Your task: identify which cBioPortal file type this supplemental data file represents. 
@@ -295,6 +288,24 @@ def _llm_detect(df: pd.DataFrame, examples: List[dict], api_key: str) -> Tuple[s }} """ + +def _parse_detection_response(raw: str) -> Tuple[str, float, str, dict]: + raw = raw.strip() + raw = re.sub(r"^```[^\n]*\n?", "", raw, flags=re.MULTILINE) + raw = re.sub(r"```$", "", raw, flags=re.MULTILINE).strip() + + result = json.loads(raw) + return result["type"], float(result["confidence"]), result.get("reasoning", ""), result.get("column_mappings", {}) + + +def _llm_detect(df: pd.DataFrame, examples: List[dict], api_key: str) -> Tuple[str, float, str, dict]: + """ + Use Claude to detect the file type with few-shot examples injected. + Returns (detected_type, confidence, reasoning, column_mappings). + """ + import anthropic + + prompt = _build_detection_prompt(df, examples) client = anthropic.Anthropic(api_key=api_key) response = client.messages.create( model="claude-sonnet-4-20250514", @@ -302,13 +313,30 @@ def _llm_detect(df: pd.DataFrame, examples: List[dict], api_key: str) -> Tuple[s messages=[{"role": "user", "content": prompt}], ) - raw = response.content[0].text.strip() - # Strip accidental markdown fences - raw = re.sub(r"^```[^\n]*\n?", "", raw, flags=re.MULTILINE) - raw = re.sub(r"```$", "", raw, flags=re.MULTILINE).strip() + return _parse_detection_response(response.content[0].text) - result = json.loads(raw) - return result["type"], float(result["confidence"]), result.get("reasoning", ""), result.get("column_mappings", {}) + +def _openai_detect( + df: pd.DataFrame, + examples: List[dict], + api_key: str, + model: str = "gpt-5.5", +) -> Tuple[str, float, str, dict]: + """ + Use OpenAI to detect the file type with few-shot examples injected. + Returns (detected_type, confidence, reasoning, column_mappings). 
+ """ + from openai import OpenAI + + prompt = _build_detection_prompt(df, examples) + client = OpenAI(api_key=api_key) + response = client.chat.completions.create( + model=model, + max_completion_tokens=800, + messages=[{"role": "user", "content": prompt}], + ) + + return _parse_detection_response(response.choices[0].message.content or "") # --------------------------------------------------------------------------- @@ -319,6 +347,7 @@ def detect_file_type( df: pd.DataFrame, anthropic_api_key: Optional[str] = None, openai_api_key: Optional[str] = None, + openai_model: str = "gpt-5.5", ) -> dict: """ Detect the cBioPortal format of a DataFrame. @@ -366,6 +395,27 @@ def detect_file_type( except Exception as e: logger.error(f"LLM detection failed: {e}") + if openai_api_key: + try: + examples = load_few_shot_examples() + llm_type, llm_conf, reasoning, mappings = _openai_detect( + df, + examples, + openai_api_key, + model=openai_model, + ) + logger.info(f"OpenAI detection: type={llm_type}, confidence={llm_conf:.2f}") + return { + "type": llm_type, + "confidence": llm_conf, + "method": "openai_few_shot", + "reasoning": reasoning, + "column_mappings": mappings, + "low_confidence": llm_conf < DETECTION_CONFIDENCE_THRESHOLD, + } + except Exception as e: + logger.error(f"OpenAI detection failed: {e}") + # 3. 
Fallback: return best heuristic guess with low confidence flag return { "type": h_type or "clinical_sample", diff --git a/cbioportal_curator.py b/cbioportal_curator.py index 05f0bfb..9575fe4 100644 --- a/cbioportal_curator.py +++ b/cbioportal_curator.py @@ -39,6 +39,9 @@ from spec_match import classify_sheet, ClassificationResult from cbioportal_spec import SPEC_BY_KEY +from xml_metadata import extract_metadata_from_xml as _extract_metadata_from_xml +from xml_metadata import extract_xml_llm_text +from xml_metadata import extract_xml_text # ───────────────────────────────────────────────────────────── # Constants @@ -636,6 +639,11 @@ def _find_int(patterns, text, default="?"): "corresponding_authors": corresp, } + +def extract_metadata_from_xml(xml_source: str | bytes | Path) -> dict: + return _extract_metadata_from_xml(xml_source) + + def _extract_metadata_llm(pdf_text: str, model: str, temperature: float) -> dict: import json, logging llm = load_chat_model(model) @@ -2519,4 +2527,4 @@ def curate( for r in records ], } - return {"report_path": output_path, "summary": summary} \ No newline at end of file + return {"report_path": output_path, "summary": summary} diff --git a/config.py b/config.py index 6bdf733..83bc6a2 100644 --- a/config.py +++ b/config.py @@ -6,9 +6,9 @@ Notes ----- -The Anthropic API key is read directly from the ANTHROPIC_API_KEY environment -variable (or Streamlit secret / sidebar input) inside streamlit_app.py, so it -deliberately is not re-exported here. +API keys are read directly from ANTHROPIC_API_KEY or OPENAI_API_KEY environment +variables (or Streamlit secrets / sidebar input) inside streamlit_app.py, so +they deliberately are not re-exported here. 
""" import os diff --git a/llm_client.py b/llm_client.py new file mode 100644 index 0000000..35dca98 --- /dev/null +++ b/llm_client.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import json +import re +import time +from typing import Any + + +def call_anthropic_with_retry( + client, + model: str, + system: str, + user_content: str, + max_tokens: int = 2000, + retries: int = 3, + backoff: float = 5.0, +) -> str: + import anthropic + + last_error: Exception | None = None + for attempt in range(retries): + try: + response = client.messages.create( + model=model, + max_tokens=max_tokens, + system=system, + messages=[{"role": "user", "content": user_content}], + ) + return response.content[0].text + except anthropic.RateLimitError as exc: + last_error = exc + time.sleep(backoff * (attempt + 1)) + except anthropic.APIStatusError as exc: + if exc.status_code >= 500: + last_error = exc + time.sleep(backoff * (attempt + 1)) + else: + raise + except anthropic.APIConnectionError as exc: + last_error = exc + time.sleep(backoff * (attempt + 1)) + + raise last_error or RuntimeError("Anthropic API call failed after retries.") + + +def call_openai_with_retry( + client, + model: str, + system: str, + user_content: str, + max_tokens: int = 2000, + retries: int = 3, + backoff: float = 5.0, +) -> str: + import openai + + last_error: Exception | None = None + for attempt in range(retries): + try: + response = client.chat.completions.create( + model=model, + max_completion_tokens=max_tokens, + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user_content}, + ], + ) + content = response.choices[0].message.content or "" + if not content: + finish_reason = response.choices[0].finish_reason + usage = getattr(response, "usage", None) + raise RuntimeError( + "OpenAI returned an empty message content. 
" + f"finish_reason={finish_reason}, usage={usage}" + ) + return content + except openai.RateLimitError as exc: + last_error = exc + time.sleep(backoff * (attempt + 1)) + except openai.APIStatusError as exc: + if exc.status_code >= 500: + last_error = exc + time.sleep(backoff * (attempt + 1)) + else: + raise + except openai.APIConnectionError as exc: + last_error = exc + time.sleep(backoff * (attempt + 1)) + + raise last_error or RuntimeError("OpenAI API call failed after retries.") + + +def call_llm_with_retry( + provider: str, + api_key: str, + model: str, + system: str, + user_content: str, + max_tokens: int = 2000, +) -> str: + if provider == "Anthropic": + import anthropic + + client = anthropic.Anthropic(api_key=api_key) + return call_anthropic_with_retry( + client=client, + model=model, + system=system, + user_content=user_content, + max_tokens=max_tokens, + ) + + if provider == "OpenAI": + from openai import OpenAI + + client = OpenAI(api_key=api_key) + return call_openai_with_retry( + client=client, + model=model, + system=system, + user_content=user_content, + max_tokens=max_tokens, + ) + + raise ValueError(f"Unsupported LLM provider: {provider}") + + +def parse_llm_json(raw: str) -> dict[str, Any]: + cleaned = raw.strip() + cleaned = re.sub(r"^```[^\n]*\n?", "", cleaned, flags=re.MULTILINE) + cleaned = re.sub(r"```$", "", cleaned, flags=re.MULTILINE).strip() + try: + return json.loads(cleaned) + except json.JSONDecodeError: + start = cleaned.find("{") + end = cleaned.rfind("}") + if start >= 0 and end > start: + return json.loads(cleaned[start : end + 1]) + raise diff --git a/metadata_merge.py b/metadata_merge.py new file mode 100644 index 0000000..5519a94 --- /dev/null +++ b/metadata_merge.py @@ -0,0 +1,47 @@ +from __future__ import annotations + + +DEFAULT_VALUES = { + "", + "?", + "Unknown", +} + + +def is_missing_metadata_value(value) -> bool: + if value is None: + return True + if isinstance(value, (list, tuple, set, dict)): + return len(value) == 0 
+ return value in DEFAULT_VALUES + + +def build_study_id(cancer_type: str | None, author: str | None, year: str | None) -> str | None: + if cancer_type and author and year: + study_id = f"{cancer_type}_{author.lower()}_{year}" + elif author and year: + study_id = f"study_{author.lower()}_{year}" + else: + return None + + return "".join( + char if char.isalnum() or char == "_" else "_" + for char in study_id.lower() + ).strip("_") + + +def merge_missing_metadata_fields(base: dict, completion: dict) -> dict: + merged = dict(base) + for key, value in completion.items(): + if key not in merged or not is_missing_metadata_value(merged[key]): + continue + if is_missing_metadata_value(value): + continue + merged[key] = value + + merged["study_id_suggestion"] = build_study_id( + cancer_type=merged.get("cancer_type"), + author=merged.get("first_author_surname"), + year=merged.get("year"), + ) + return merged diff --git a/pmc_supplement_fetcher.py b/pmc_supplement_fetcher.py new file mode 100644 index 0000000..c2688c3 --- /dev/null +++ b/pmc_supplement_fetcher.py @@ -0,0 +1,471 @@ +""" +pmc_supplement_fetcher.py +------------------------- +Download supplementary files from PubMed Central by PMCID or PMID. + +PMID input is converted to PMCID using the NCBI idconv utility. Supplementary +links are read from the PMC article XML and downloaded into a caller-provided +directory. Archive files are expanded and supported curation files are returned. 
+""" + +from __future__ import annotations + +import os +import re +import hashlib +import http.cookiejar +import tarfile +import zipfile +from dataclasses import dataclass +from pathlib import Path +from urllib.error import HTTPError +from urllib.parse import urlencode +from urllib.parse import unquote, urljoin, urlparse +from urllib.request import HTTPCookieProcessor, Request, build_opener, urlopen +import xml.etree.ElementTree as ET + +import requests + + +SUPPORTED_SUPPLEMENT_EXTENSIONS = { + ".xlsx", ".xls", ".csv", ".tsv", ".txt", ".tab", ".maf", ".doc", ".docx", ".pdf", +} +ARCHIVE_EXTENSIONS = {".zip", ".tar", ".tgz", ".gz", ".bz2", ".xz"} +NCBI_TIMEOUT_SECONDS = 30 +NCBI_TOOL_NAME = "cBioAbstractor" +NCBI_CONTACT_EMAIL = os.getenv("NCBI_EMAIL", "cBioAbstractor@example.com") +HTTP_HEADERS = { + "User-Agent": "cBioAbstractor/1.0 (supplementary-file-curation)", + "Accept": "*/*", +} +POW_MAX_ITERATIONS = 20_000_000 + + +@dataclass +class DownloadedSupplement: + path: str + filename: str + source_url: str + + +def normalize_pmcid(value: str) -> str: + """Return a PMCID in PMC123456 format.""" + raw = value.strip() + if not raw: + raise ValueError("PMCID is empty.") + + match = re.search(r"(?:PMC)?(\d+)", raw, flags=re.IGNORECASE) + if not match: + raise ValueError(f"Could not parse PMCID from '{value}'.") + + return f"PMC{match.group(1)}" + + +def pmid_to_pmcid(pmid: str) -> str: + """Convert PMID to PMCID using the PMC ID Converter API.""" + clean_pmid = re.sub(r"\D", "", pmid or "") + if not clean_pmid: + raise ValueError("PMID must contain digits.") + + params = urlencode({ + "ids": clean_pmid, + "idtype": "pmid", + "format": "json", + "tool": NCBI_TOOL_NAME, + "email": NCBI_CONTACT_EMAIL, + }) + request = Request( + f"https://pmc.ncbi.nlm.nih.gov/tools/idconv/api/v1/articles/?{params}", + headers=HTTP_HEADERS, + ) + with urlopen(request, timeout=NCBI_TIMEOUT_SECONDS) as response: + payload = response.read().decode("utf-8") + import json + + payload = 
json.loads(payload) + records = payload.get("records") or [] + if not records or not records[0].get("pmcid"): + raise ValueError(f"No PMCID found for PMID {clean_pmid}.") + + return normalize_pmcid(records[0]["pmcid"]) + + +def _pmcid_numeric(pmcid: str) -> str: + return normalize_pmcid(pmcid).replace("PMC", "") + + +def _fetch_pmc_xml(pmcid: str) -> str: + response = requests.get( + "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi", + params={"db": "pmc", "id": _pmcid_numeric(pmcid), "retmode": "xml"}, + headers=HTTP_HEADERS, + timeout=NCBI_TIMEOUT_SECONDS, + ) + response.raise_for_status() + if " str | None: + response = requests.get( + "https://www.ncbi.nlm.nih.gov/pmc/utils/oa/oa.fcgi", + params={"id": normalize_pmcid(pmcid)}, + headers=HTTP_HEADERS, + timeout=NCBI_TIMEOUT_SECONDS, + ) + response.raise_for_status() + root = ET.fromstring(response.text) + for link in root.iter("link"): + if link.attrib.get("format") == "tgz" and link.attrib.get("href"): + href = link.attrib["href"] + return href.replace("ftp://ftp.ncbi.nlm.nih.gov", "https://ftp.ncbi.nlm.nih.gov") + return None + + +def _download_oa_package(pmcid: str, output_dir: Path) -> list[Path]: + package_url = _oa_package_url(pmcid) + if not package_url: + return [] + + package_path = _download_url(package_url, output_dir, 0) + extracted = _extract_supported_files(package_path, output_dir) + return [ + path + for path in extracted + if path.name.lower() != f"{normalize_pmcid(pmcid).lower()}.pdf" + ] + + +def _xlink_href(element: ET.Element) -> str: + return ( + element.attrib.get("{http://www.w3.org/1999/xlink}href") + or element.attrib.get("href") + or "" + ).strip() + + +def _supplement_urls(pmcid: str, xml_text: str) -> list[str]: + root = ET.fromstring(xml_text) + urls: list[str] = [] + base_article_url = f"https://pmc.ncbi.nlm.nih.gov/articles/{normalize_pmcid(pmcid)}/" + base_site_url = "https://pmc.ncbi.nlm.nih.gov" + base_instance_bin_url = ( + 
f"https://pmc.ncbi.nlm.nih.gov/articles/instance/{_pmcid_numeric(pmcid)}/bin/" + ) + + for supp in root.iter(): + if not supp.tag.endswith("supplementary-material"): + continue + + candidate_hrefs = [] + direct_href = _xlink_href(supp) + if direct_href: + candidate_hrefs.append(direct_href) + + for child in supp.iter(): + if child.tag.endswith(("media", "graphic", "inline-supplementary-material")): + href = _xlink_href(child) + if href: + candidate_hrefs.append(href) + + for href in candidate_hrefs: + if not href: + continue + if href.startswith(("http://", "https://")): + url = href + elif href.startswith("/"): + url = urljoin(base_site_url, href) + elif "/" in href: + url = urljoin(base_article_url, href) + else: + url = urljoin(base_instance_bin_url, href) + if url not in urls: + urls.append(url) + + return urls + + +def _safe_filename(filename: str, fallback: str) -> str: + name = unquote(filename or "").strip() + name = os.path.basename(name) + name = re.sub(r"[^A-Za-z0-9._ -]+", "_", name).strip(" .") + return name or fallback + + +def _filename_from_headers(url: str, disposition: str, index: int) -> str: + match = re.search(r'filename="?([^";]+)"?', disposition, flags=re.IGNORECASE) + if match: + return _safe_filename(match.group(1), f"supplement_{index}") + + parsed_name = os.path.basename(urlparse(url).path) + return _safe_filename(parsed_name, f"supplement_{index}") + + +def _filename_from_response(url: str, response: requests.Response, index: int) -> str: + return _filename_from_headers(url, response.headers.get("content-disposition", ""), index) + + +def _extension(path: Path) -> str: + if path.name.lower().endswith(".tar.gz"): + return ".tar.gz" + return path.suffix.lower() + + +def _is_supported_file(path: Path) -> bool: + return path.suffix.lower() in SUPPORTED_SUPPLEMENT_EXTENSIONS + + +def _is_archive(path: Path) -> bool: + lower = path.name.lower() + return ( + lower.endswith((".zip", ".tar", ".tar.gz", ".tgz", ".tar.bz2", ".tar.xz")) + or 
path.suffix.lower() in ARCHIVE_EXTENSIONS + ) + + +def _parse_pow_challenge(html: str) -> tuple[str, int, str, str] | None: + challenge = re.search(r'POW_CHALLENGE\s*=\s*"([^"]+)"', html) + difficulty = re.search(r'POW_DIFFICULTY\s*=\s*"([^"]+)"', html) + cookie_name = re.search(r'POW_COOKIE_NAME\s*=\s*"([^"]+)"', html) + cookie_path = re.search(r'POW_COOKIE_PATH\s*=\s*"([^"]+)"', html) + if not challenge or not difficulty or not cookie_name: + return None + return ( + challenge.group(1), + int(difficulty.group(1)), + cookie_name.group(1), + cookie_path.group(1) if cookie_path else "/", + ) + + +def _solve_pow_nonce(challenge: str, difficulty: int) -> int: + prefix = "0" * difficulty + for nonce in range(POW_MAX_ITERATIONS): + digest = hashlib.sha256(f"{challenge}{nonce}".encode("utf-8")).hexdigest() + if digest.startswith(prefix): + return nonce + raise ValueError("PMC proof-of-work challenge was not solved within the iteration limit.") + + +def _set_cookie(cookie_jar: http.cookiejar.CookieJar, url: str, name: str, value: str, path: str) -> None: + domain = urlparse(url).hostname or "pmc.ncbi.nlm.nih.gov" + cookie = http.cookiejar.Cookie( + version=0, + name=name, + value=value, + port=None, + port_specified=False, + domain=domain, + domain_specified=False, + domain_initial_dot=False, + path=path or "/", + path_specified=True, + secure=False, + expires=None, + discard=True, + comment=None, + comment_url=None, + rest={}, + rfc2109=False, + ) + cookie_jar.set_cookie(cookie) + + +def _download_url_with_urllib_pow(url: str, output_dir: Path, index: int) -> Path: + cookie_jar = http.cookiejar.CookieJar() + opener = build_opener(HTTPCookieProcessor(cookie_jar)) + request = Request(url, headers=HTTP_HEADERS) + + try: + response = opener.open(request, timeout=NCBI_TIMEOUT_SECONDS) + except HTTPError as exc: + response = exc + + content = response.read() + content_type = response.headers.get("content-type", "").lower() + filename = _filename_from_headers(url, 
response.headers.get("content-disposition", ""), index) + + if "text/html" in content_type: + html = content.decode("utf-8", errors="replace") + pow_parts = _parse_pow_challenge(html) + if pow_parts: + challenge, difficulty, cookie_name, cookie_path = pow_parts + nonce = _solve_pow_nonce(challenge, difficulty) + _set_cookie(cookie_jar, url, cookie_name, f"{challenge},{nonce}", cookie_path) + + response = opener.open(Request(url, headers=HTTP_HEADERS), timeout=NCBI_TIMEOUT_SECONDS) + content = response.read() + content_type = response.headers.get("content-type", "").lower() + filename = _filename_from_headers(url, response.headers.get("content-disposition", ""), index) + + if "text/html" in content_type: + raise ValueError(f"PMC returned an HTML page instead of {filename}.") + + path = output_dir / filename + if path.exists(): + stem = path.stem + suffix = path.suffix + path = output_dir / f"{stem}_{index}{suffix}" + path.write_bytes(content) + return path + + +def _download_url(url: str, output_dir: Path, index: int) -> Path: + try: + response = requests.get(url, headers=HTTP_HEADERS, timeout=NCBI_TIMEOUT_SECONDS) + response.raise_for_status() + except Exception: + if urlparse(url).hostname == "pmc.ncbi.nlm.nih.gov": + return _download_url_with_urllib_pow(url, output_dir, index) + raise + + content_type = response.headers.get("content-type", "").lower() + filename = _filename_from_response(url, response, index) + if "text/html" in content_type and Path(filename).suffix.lower() in ( + SUPPORTED_SUPPLEMENT_EXTENSIONS | ARCHIVE_EXTENSIONS + ): + if urlparse(url).hostname == "pmc.ncbi.nlm.nih.gov": + return _download_url_with_urllib_pow(url, output_dir, index) + raise ValueError(f"PMC returned an HTML challenge page instead of {filename}.") + + path = output_dir / filename + if path.exists(): + stem = path.stem + suffix = path.suffix + path = output_dir / f"{stem}_{index}{suffix}" + path.write_bytes(response.content) + return path + + +def _safe_extract_path(base_dir: 
Path, member_name: str) -> Path: + base_dir = base_dir.resolve() + target = (base_dir / member_name).resolve() + try: + target.relative_to(base_dir) + except ValueError: + raise ValueError(f"Archive member escapes extraction directory: {member_name}") + return target + + +def _extract_supported_files(archive_path: Path, output_dir: Path) -> list[Path]: + extract_dir = output_dir / f"{archive_path.stem}_extracted" + extract_dir.mkdir(parents=True, exist_ok=True) + extracted: list[Path] = [] + + if zipfile.is_zipfile(archive_path): + with zipfile.ZipFile(archive_path) as zf: + for info in zf.infolist(): + if info.is_dir(): + continue + target = _safe_extract_path(extract_dir, info.filename) + target.parent.mkdir(parents=True, exist_ok=True) + with zf.open(info) as source, open(target, "wb") as dest: + dest.write(source.read()) + if _is_supported_file(target): + extracted.append(target) + return extracted + + if tarfile.is_tarfile(archive_path): + with tarfile.open(archive_path) as tf: + for member in tf.getmembers(): + if not member.isfile(): + continue + target = _safe_extract_path(extract_dir, member.name) + target.parent.mkdir(parents=True, exist_ok=True) + source = tf.extractfile(member) + if source is None: + continue + with source, open(target, "wb") as dest: + dest.write(source.read()) + if _is_supported_file(target): + extracted.append(target) + return extracted + + return [] + + +def download_pmc_supplements( + identifier: str, + identifier_type: str, + output_dir: str, +) -> tuple[str, list[DownloadedSupplement]]: + """ + Download PMC supplementary files. + + Returns (pmcid, downloaded_files). The returned file list contains only + formats supported by the downstream curation parser. 
+ """ + if identifier_type == "PMID": + pmcid = pmid_to_pmcid(identifier) + elif identifier_type == "PMCID": + pmcid = normalize_pmcid(identifier) + else: + raise ValueError("identifier_type must be 'PMID' or 'PMCID'.") + + out_dir = Path(output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + downloaded_paths: list[tuple[Path, str]] = [] + download_errors: list[str] = [] + + try: + for path in _download_oa_package(pmcid, out_dir): + downloaded_paths.append((path, "PMC OA package")) + except Exception as exc: + download_errors.append(f"PMC OA package: {exc}") + + try: + xml_text = _fetch_pmc_xml(pmcid) + urls = _supplement_urls(pmcid, xml_text) + except Exception as exc: + urls = [] + download_errors.append(f"PMC XML: {exc}") + + for index, url in enumerate(urls, start=1): + try: + path = _download_url(url, out_dir, index) + except Exception as exc: + download_errors.append(f"{url}: {exc}") + continue + + paths_to_add: list[Path] + if _is_archive(path): + paths_to_add = _extract_supported_files(path, out_dir) + elif _is_supported_file(path): + paths_to_add = [path] + else: + paths_to_add = [] + + for candidate in paths_to_add: + downloaded_paths.append((candidate, url)) + + seen: set[tuple[str, int]] = set() + downloaded: list[DownloadedSupplement] = [] + for path, source_url in downloaded_paths: + try: + signature = (path.name.lower(), path.stat().st_size) + except OSError: + continue + if signature in seen: + continue + seen.add(signature) + downloaded.append( + DownloadedSupplement( + path=str(path), + filename=path.name, + source_url=source_url, + ) + ) + + if not downloaded: + detail = "" + if download_errors: + detail = " Download attempts failed; first error: " + download_errors[0] + raise ValueError( + f"No supported supplementary files could be downloaded for {pmcid}. " + "Supported formats are .xlsx, .csv, .tsv, .txt, .maf, .docx, and .pdf." 
+ f"{detail}" + ) + + return pmcid, downloaded diff --git a/requirements.txt b/requirements.txt index 179b8ff..6974276 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ streamlit>=1.41.0 anthropic>=0.40.0 +openai>=1.0.0 pandas>=2.0.0 openpyxl>=3.1.0 xlrd>=2.0.1 diff --git a/streamlit_app.py b/streamlit_app.py index b10ffd1..a8f2cdb 100644 --- a/streamlit_app.py +++ b/streamlit_app.py @@ -7,24 +7,29 @@ 2. File Classification It removes merge-conflict markers, Docker/backend assumptions, and api_config.py usage. -Set ANTHROPIC_API_KEY as an environment variable or Streamlit secret. +Set ANTHROPIC_API_KEY or OPENAI_API_KEY as an environment variable, .env value, +or Streamlit secret. """ from __future__ import annotations -import json +import io import os import re import shutil import sys import tempfile -import time import traceback +import zipfile from typing import Any import pandas as pd import streamlit as st +from llm_client import call_llm_with_retry as _call_llm_with_retry +from llm_client import parse_llm_json as _parse_llm_json +from metadata_merge import merge_missing_metadata_fields + # ── Path setup ──────────────────────────────────────────────────────────────── _HERE = os.path.dirname(os.path.abspath(__file__)) if _HERE not in sys.path: @@ -42,19 +47,54 @@ # ───────────────────────────────────────────────────────────────────────────── # API key loading # Resolution order: -# 1. ANTHROPIC_API_KEY environment variable +# 1. Provider-specific environment variable # 2. Streamlit secrets # 3. 
Sidebar input # ───────────────────────────────────────────────────────────────────────────── -def _load_api_key() -> str: - key = os.environ.get("ANTHROPIC_API_KEY", "").strip() +PROVIDER_CONFIG = { + "Anthropic": { + "env": "ANTHROPIC_API_KEY", + "placeholder": "sk-ant-...", + "models": [ + "claude-sonnet-4-20250514", + "claude-3-5-haiku-20241022", + "claude-3-5-sonnet-20241022", + "claude-sonnet-4-6", + "claude-opus-4-6", + "claude-haiku-4-5-20251001", + ], + }, + "OpenAI": { + "env": "OPENAI_API_KEY", + "placeholder": "sk-...", + "models": [ + "gpt-4o", + "gpt-5.5", + "gpt-5.5-pro", + "gpt-5.4", + "gpt-5.4-mini", + "gpt-5.4-nano", + "gpt-5", + "gpt-5-mini", + "gpt-5-nano", + "gpt-4o-mini", + "gpt-4.1", + "gpt-4.1-mini", + ], + }, +} + + +def _load_api_key(provider: str) -> str: + env_name = PROVIDER_CONFIG[provider]["env"] + key = os.environ.get(env_name, "").strip() if key: return key try: - key = st.secrets.get("ANTHROPIC_API_KEY", "").strip() + key = st.secrets.get(env_name, "").strip() if key: - os.environ["ANTHROPIC_API_KEY"] = key + os.environ[env_name] = key return key except Exception: pass @@ -62,16 +102,32 @@ def _load_api_key() -> str: return "" -_API_KEY = _load_api_key() +_API_KEYS = {provider: _load_api_key(provider) for provider in PROVIDER_CONFIG} + + +def _default_provider_index() -> int: + providers = list(PROVIDER_CONFIG.keys()) + if _API_KEYS.get("OpenAI") and not _API_KEYS.get("Anthropic"): + return providers.index("OpenAI") + return providers.index("Anthropic") + +def _get_api_key(provider: str) -> str: + env_name = PROVIDER_CONFIG[provider]["env"] + return (_API_KEYS.get(provider) or os.environ.get(env_name, "")).strip() -def _get_api_key() -> str: - return (_API_KEY or os.environ.get("ANTHROPIC_API_KEY", "")).strip() +def _set_api_key(provider: str, key: str) -> None: + env_name = PROVIDER_CONFIG[provider]["env"] + clean_key = key.strip() + os.environ[env_name] = clean_key + _API_KEYS[provider] = clean_key -def _require_api_key() -> 
def _require_api_key(provider: str) -> bool:
    """Check that an API key is available for *provider*.

    Returns True when ``_get_api_key(provider)`` is truthy. Otherwise shows a
    Streamlit error naming the provider and the environment variable listed
    under ``PROVIDER_CONFIG[provider]["env"]``, and returns False.
    """
    if not _get_api_key(provider):
        env_name = PROVIDER_CONFIG[provider]["env"]
        st.error(f"Please add your {provider} API key in the sidebar or set {env_name}.")
        return False
    return True


def _clear_pmc_download_state() -> None:
    """Delete the PMC download temp directory and forget it in session state.

    Removal is best-effort (``ignore_errors=True``); every ``pmc_download*``
    session key written by the download flow is popped afterwards.
    """
    tmp_dir = st.session_state.get("pmc_download_tmp_dir")
    if tmp_dir:
        shutil.rmtree(tmp_dir, ignore_errors=True)
    for key in [
        "pmc_download_tmp_dir",
        "pmc_download_pmcid",
        "pmc_download_identifier",
        "pmc_download_identifier_type",
        "pmc_downloaded_files",
    ]:
        st.session_state.pop(key, None)


def _build_download_zip(files: list[dict[str, str]]) -> bytes:
    """Bundle downloaded files into an in-memory ZIP archive.

    Args:
        files: Records with a ``"path"`` key (file on disk) and an optional
            ``"filename"`` key (name to use inside the archive).

    Returns:
        The archive bytes, or ``b""`` when no record pointed at an existing
        regular file. An empty ``ZipFile`` still serialises to a non-empty
        end-of-central-directory record, so returning it would make a
        caller's truthiness check pass for an archive with nothing in it.
    """
    buffer = io.BytesIO()
    used_names: set[str] = set()
    with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
        for item in files:
            path = item.get("path", "")
            # Only regular files carry content; a directory path would just
            # add an empty directory entry.
            if not path or not os.path.isfile(path):
                continue

            filename = item.get("filename") or path
            # Keep only the last path component. Backslashes are normalised
            # first so a Windows-style name cannot carry directories into
            # the archive on POSIX hosts.
            safe_name = os.path.basename(filename.replace("\\", "/"))
            if safe_name in {"", ".", ".."}:
                safe_name = os.path.basename(path)

            if safe_name in used_names:
                # Disambiguate duplicates as name_2.ext, name_3.ext, ...
                stem, ext = os.path.splitext(safe_name)
                suffix = 2
                while f"{stem}_{suffix}{ext}" in used_names:
                    suffix += 1
                safe_name = f"{stem}_{suffix}{ext}"
            used_names.add(safe_name)
            archive.write(path, arcname=safe_name)

    if not used_names:
        return b""
    return buffer.getvalue()


def _detect_pubmed_identifier_type(identifier: str) -> str | None:
    """Classify *identifier* as ``"PMCID"``, ``"PMID"`` or neither.

    Surrounding whitespace is ignored and the ``PMC`` prefix is matched
    case-insensitively. Only ASCII digits are accepted: a bare ``\\d`` would
    also match other Unicode decimal digits, which are not valid identifiers.

    Returns:
        ``"PMCID"`` for ``PMC`` followed by digits, ``"PMID"`` for digits
        only, otherwise ``None``.
    """
    value = identifier.strip()
    if re.fullmatch(r"PMC\d+", value, flags=re.IGNORECASE | re.ASCII):
        return "PMCID"
    if re.fullmatch(r"\d+", value, flags=re.ASCII):
        return "PMID"
    return None
st.warning(f"{provider} API key not configured") st.divider() st.caption("Version 1.2 — Streamlit only") @@ -352,8 +417,8 @@ def _render_inline_report(meta: dict[str, Any], summary: dict[str, Any]) -> None with tab_curate: st.subheader("Curation Report Generator") st.markdown( - "Upload the main paper PDF and supplementary files. The tool extracts " - "study metadata and classifies each file against cBioPortal formats." + "Upload a paper PDF with local supplementary files, or enter a PMID/PMCID " + "to retrieve metadata and supplementary files from PubMed Central." ) try: @@ -368,16 +433,118 @@ def _render_inline_report(meta: dict[str, Any], summary: dict[str, Any]) -> None st.caption("Using embedded cBioPortal format specifications.") st.divider() - col_pdf, col_supp = st.columns(2) - with col_pdf: - paper_pdf = st.file_uploader("Main paper PDF", type=["pdf"], key="paper_pdf") - with col_supp: - supp_files = st.file_uploader( - "Supplementary files", - type=["xlsx", "xls", "csv", "tsv", "txt", "tab", "maf", "doc", "docx", "pdf"], - accept_multiple_files=True, - key="supp_files", - ) + supp_source = st.radio( + "Supplementary source", + options=["Upload files", "PubMed Central"], + horizontal=True, + key="supp_source", + ) + + paper_pdf = None + supp_files = [] + pmc_identifier = "" + pmc_identifier_type = None + + if supp_source == "Upload files": + if st.session_state.get("pmc_downloaded_files"): + _clear_pmc_download_state() + col_pdf, col_supp = st.columns(2) + with col_pdf: + paper_pdf = st.file_uploader("Main paper PDF", type=["pdf"], key="paper_pdf") + with col_supp: + supp_files = st.file_uploader( + "Supplementary files", + type=["xlsx", "xls", "csv", "tsv", "txt", "tab", "maf", "doc", "docx", "pdf"], + accept_multiple_files=True, + key="supp_files", + ) + else: + with st.container(): + pmc_identifier = st.text_input( + "PMID or PMCID", + placeholder="34493867 or PMC8432745", + key="pmc_identifier", + ).strip() + pmc_identifier_type = 
_detect_pubmed_identifier_type(pmc_identifier) + if pmc_identifier and pmc_identifier_type: + st.caption(f"Detected {pmc_identifier_type}.") + elif pmc_identifier: + st.warning("Enter a numeric PMID or a PMCID such as PMC8432745.") + + current_download_matches = ( + st.session_state.get("pmc_download_identifier") == pmc_identifier + and st.session_state.get("pmc_download_identifier_type") == pmc_identifier_type + ) + if pmc_identifier and not current_download_matches: + _clear_pmc_download_state() + + if st.button( + "Download supplementary files and study full text", + disabled=not pmc_identifier_type, + key="download_pmc_supp_files", + ): + from pmc_supplement_fetcher import download_pmc_supplements + + _clear_pmc_download_state() + pmc_tmp_dir = tempfile.mkdtemp() + try: + with st.spinner("Downloading supplementary files from PubMed Central..."): + pmcid, downloaded = download_pmc_supplements( + identifier=pmc_identifier, + identifier_type=pmc_identifier_type, + output_dir=pmc_tmp_dir, + ) + st.session_state["pmc_download_tmp_dir"] = pmc_tmp_dir + st.session_state["pmc_download_pmcid"] = pmcid + st.session_state["pmc_download_identifier"] = pmc_identifier + st.session_state["pmc_download_identifier_type"] = pmc_identifier_type + st.session_state["pmc_downloaded_files"] = [ + { + "path": item.path, + "filename": item.filename, + "source_url": item.source_url, + } + for item in downloaded + ] + st.success(f"Downloaded {len(downloaded)} supplementary file(s) from {pmcid}.") + except Exception as exc: + shutil.rmtree(pmc_tmp_dir, ignore_errors=True) + print(f"Supplementary download failed: {exc}", file=sys.stderr) + traceback.print_exc() + st.error( + "Impossible to retrieve supplementary files for this identifier. " + "Please check that the PMID or PMCID is correct." 
+ ) + + downloaded_files = st.session_state.get("pmc_downloaded_files") or [] + if downloaded_files: + st.success( + f"Ready: {len(downloaded_files)} file(s) from " + f"{st.session_state.get('pmc_download_pmcid', 'PMC')}." + ) + for idx, item in enumerate(downloaded_files): + cols = st.columns([0.08, 0.52, 0.40]) + if cols[0].button("X", key=f"remove_pmc_file_{idx}", help="Remove this file"): + try: + os.remove(item["path"]) + except OSError: + pass + st.session_state["pmc_downloaded_files"] = [ + row for row_idx, row in enumerate(downloaded_files) if row_idx != idx + ] + st.rerun() + cols[1].write(item["filename"]) + cols[2].caption(item["source_url"]) + + zip_bytes = _build_download_zip(downloaded_files) + if zip_bytes: + st.download_button( + "Download selected supplementary files (.zip)", + data=zip_bytes, + file_name=f"{st.session_state.get('pmc_download_pmcid', 'pmc')}_supplementary_files.zip", + mime="application/zip", + key="download_pmc_supp_zip", + ) if supp_files: st.markdown("#### Confirm uploaded filenames") @@ -391,20 +558,23 @@ def _render_inline_report(meta: dict[str, Any], summary: dict[str, Any]) -> None with st.expander("Options"): model = st.selectbox( - "Anthropic model", - options=[ - "claude-sonnet-4-20250514", - "claude-3-5-haiku-20241022", - "claude-3-5-sonnet-20241022", - "claude-sonnet-4-6", # latest Sonnet — recommended default - "claude-opus-4-6", # most capable, slower/more expensive - "claude-haiku-4-5-20251001", # fastest/cheapest - ], + f"{provider} model", + options=PROVIDER_CONFIG[provider]["models"], index=0, ) - if st.button("Generate Curation Report", disabled=paper_pdf is None, type="primary"): - if not _require_api_key(): + missing_supp_source = ( + (supp_source == "Upload files" and not supp_files) + or (supp_source == "PubMed Central" and not st.session_state.get("pmc_downloaded_files")) + ) + missing_paper_source = supp_source == "Upload files" and paper_pdf is None + + if st.button( + "Generate Curation Report", + 
disabled=missing_paper_source or missing_supp_source, + type="primary", + ): + if not _require_api_key(provider): st.stop() pdf_tmp: str | None = None @@ -412,33 +582,71 @@ def _render_inline_report(meta: dict[str, Any], summary: dict[str, Any]) -> None try: with st.spinner("Saving uploaded files..."): - pdf_tmp = _save_upload_to_tmp(paper_pdf) - for idx, uploaded in enumerate(supp_files or []): - filename = st.session_state.get(f"fname_{idx}") or uploaded.name - supp_tmps.append(_save_upload_to_tmp(uploaded, filename=filename)) - - with st.spinner("Step 1 of 2 — Extracting study metadata from PDF..."): - import anthropic - from cbioportal_curator import SYSTEM_PROMPT_CURATOR, _extract_pdf_text - - pdf_text = _extract_pdf_text(pdf_tmp) - meta: dict[str, Any] = {} - if pdf_text.strip(): - client = anthropic.Anthropic(api_key=_get_api_key()) - raw_meta = _call_anthropic_with_retry( - client=client, - model=model, - system=SYSTEM_PROMPT_CURATOR, - user_content=pdf_text[:40000], - max_tokens=2000, - ) - try: - meta = _parse_llm_json(raw_meta) - except Exception: - st.warning("Metadata extraction returned unexpected format. Continuing with file classification.") - meta = {} + if supp_source == "Upload files": + pdf_tmp = _save_upload_to_tmp(paper_pdf) + for idx, uploaded in enumerate(supp_files or []): + filename = st.session_state.get(f"fname_{idx}") or uploaded.name + supp_tmps.append(_save_upload_to_tmp(uploaded, filename=filename)) else: - st.warning("Could not extract text from the PDF. 
Metadata fields will be blank.") + downloaded_files = st.session_state.get("pmc_downloaded_files") or [] + supp_tmps.extend(item["path"] for item in downloaded_files) + + meta: dict[str, Any] = {} + if supp_source == "PubMed Central": + with st.spinner("Step 1 of 2 — Extracting study metadata from PMC XML..."): + from cbioportal_curator import ( + SYSTEM_PROMPT_CURATOR, + extract_metadata_from_xml, + extract_xml_llm_text, + ) + from pmc_supplement_fetcher import _fetch_pmc_xml + + pmcid = st.session_state.get("pmc_download_pmcid") + if not pmcid: + raise RuntimeError("Missing PMCID for PMC metadata extraction.") + + xml_text = _fetch_pmc_xml(pmcid) + meta = extract_metadata_from_xml(xml_text) + llm_text = extract_xml_llm_text(xml_text) + if llm_text.strip(): + try: + raw_meta = _call_llm_with_retry( + provider=provider, + api_key=_get_api_key(provider), + model=model, + system=SYSTEM_PROMPT_CURATOR, + user_content=llm_text[:40000], + max_tokens=2000, + ) + meta = merge_missing_metadata_fields(meta, _parse_llm_json(raw_meta)) + except Exception: + st.warning( + "XML metadata completion returned unexpected format. " + "Continuing with structured XML metadata only." + ) + else: + st.warning("Could not extract text from the PMC XML. Using structured XML metadata only.") + else: + with st.spinner("Step 1 of 2 — Extracting study metadata from PDF..."): + from cbioportal_curator import SYSTEM_PROMPT_CURATOR, _extract_pdf_text + + pdf_text = _extract_pdf_text(pdf_tmp) + if pdf_text.strip(): + raw_meta = _call_llm_with_retry( + provider=provider, + api_key=_get_api_key(provider), + model=model, + system=SYSTEM_PROMPT_CURATOR, + user_content=pdf_text[:40000], + max_tokens=2000, + ) + try: + meta = _parse_llm_json(raw_meta) + except Exception: + st.warning("Metadata extraction returned unexpected format. Continuing with file classification.") + meta = {} + else: + st.warning("Could not extract text from the PDF. 
Metadata fields will be blank.") with st.spinner(f"Step 2 of 2 — Classifying {len(supp_tmps)} supplementary file(s)..."): from cbioportal_curator import _analyse_supplementary_files @@ -478,7 +686,10 @@ def _render_inline_report(meta: dict[str, Any], summary: dict[str, Any]) -> None st.code(traceback.format_exc()) st.stop() finally: - _safe_cleanup(pdf_tmp or "", *supp_tmps) + if supp_source == "Upload files": + _safe_cleanup(pdf_tmp or "", *supp_tmps) + else: + _safe_cleanup(pdf_tmp or "") st.success("Curation complete.") st.divider() @@ -519,10 +730,15 @@ def _render_inline_report(meta: dict[str, Any], summary: dict[str, Any]) -> None st.markdown("#### File Preview") st.dataframe(df.head(10), use_container_width=True) - api_key = _get_api_key() if use_ai else None + api_key = _get_api_key(provider) if use_ai else None with st.spinner("Classifying file..."): try: - result = detect_file_type(df, anthropic_api_key=api_key) + result = detect_file_type( + df, + anthropic_api_key=api_key if provider == "Anthropic" else None, + openai_api_key=api_key if provider == "OpenAI" else None, + openai_model=model, + ) except Exception as exc: st.error(f"Classification failed: {exc}") st.stop() diff --git a/xml_metadata.py b/xml_metadata.py new file mode 100644 index 0000000..2e9ed48 --- /dev/null +++ b/xml_metadata.py @@ -0,0 +1,317 @@ +from __future__ import annotations + +import re +import xml.etree.ElementTree as ET +from pathlib import Path + + +METADATA_DEFAULTS = { + "study_title": None, + "cancer_type": None, + "cancer_type_full": None, + "num_samples": None, + "num_patients": None, + "reference_genome": None, + "sequencing_types": [], + "pmid": None, + "doi": None, + "first_author_surname": None, + "year": None, + "journal": None, + "study_id_suggestion": None, + "description": None, + "key_findings": [], + "primary_site": None, + "cohort_description": None, + "meta_description": None, + "data_repositories": [], + "corresponding_authors": None, +} + + +def 
def _xml_local_name(tag: str) -> str:
    """Return *tag* without its ``{namespace}`` prefix.

    ElementTree spells namespaced tags as ``{uri}local``. A tag that is not a
    string yields ``""`` so it can never match a real element name.
    """
    if not isinstance(tag, str):
        return ""
    return tag.rsplit("}", 1)[-1]


def _xml_text(element: ET.Element | None) -> str:
    """Return all text under *element* as one whitespace-normalised string.

    Text fragments are stripped, joined with single spaces and runs of
    whitespace are collapsed. ``None`` gives ``""``.
    """
    if element is None:
        return ""
    text_parts = [text.strip() for text in element.itertext() if text.strip()]
    return re.sub(r"\s+", " ", " ".join(text_parts)).strip()


def _first_xml_text(root: ET.Element, local_names: tuple[str, ...]) -> str:
    """Return the first non-empty text of an element named in *local_names*.

    Names are tried in the order given; for each name the tree is scanned in
    document order. Returns ``""`` when nothing matches.
    """
    for local_name in local_names:
        for element in root.iter():
            if _xml_local_name(element.tag) != local_name:
                continue
            text = _xml_text(element)
            if text:
                return text
    return ""


def _first_child(element: ET.Element | None, local_name: str) -> ET.Element | None:
    """Return the first direct child of *element* named *local_name*, or None."""
    if element is None:
        return None
    for child in element:
        if _xml_local_name(child.tag) == local_name:
            return child
    return None


def _first_descendant(element: ET.Element | None, local_name: str) -> ET.Element | None:
    """Return the first element named *local_name* at or below *element*, or None."""
    if element is None:
        return None
    for child in element.iter():
        if _xml_local_name(child.tag) == local_name:
            return child
    return None


def _article_root(root: ET.Element) -> ET.Element:
    """Return the first ``article`` element, or *root* itself if there is none."""
    article = _first_descendant(root, "article")
    return article if article is not None else root


def _article_meta(article: ET.Element) -> ET.Element | None:
    """Return ``front/article-meta`` of *article*, or None when absent."""
    front = _first_child(article, "front")
    return _first_child(front, "article-meta")


def _journal_meta(article: ET.Element) -> ET.Element | None:
    """Return ``front/journal-meta`` of *article*, or None when absent."""
    front = _first_child(article, "front")
    return _first_child(front, "journal-meta")


def _xml_article_ids(root: ET.Element) -> dict[str, str]:
    """Map lower-cased id type to value for every ``article-id`` under *root*.

    The type is read from ``pub-id-type`` (or ``article-id-type``); trailing
    ``.,;`` are trimmed from the value. A later element of the same type
    replaces an earlier one.
    """
    ids: dict[str, str] = {}
    for element in root.iter():
        if _xml_local_name(element.tag) != "article-id":
            continue
        id_type = (
            element.attrib.get("pub-id-type")
            or element.attrib.get("article-id-type")
            or ""
        ).lower()
        text = _xml_text(element).rstrip(".,;")
        if id_type and text:
            ids[id_type] = text
    return ids


def _xml_journal_title(root: ET.Element) -> str:
    """Return the journal name found under *root*.

    Preference order: first non-empty ``journal-title``; first non-empty
    ``journal-id`` typed ``nlm-ta`` or ``iso-abbrev``; any non-empty
    ``journal-id``. An empty ``journal-title`` element no longer short-circuits
    the fallbacks.
    """
    title = _first_xml_text(root, ("journal-title",))
    if title:
        return title

    for journal_id in root.iter():
        if _xml_local_name(journal_id.tag) != "journal-id":
            continue
        if journal_id.attrib.get("journal-id-type") in {"nlm-ta", "iso-abbrev"}:
            abbreviation = _xml_text(journal_id)
            if abbreviation:
                return abbreviation

    return _first_xml_text(root, ("journal-id",))


def _xml_publication_year(root: ET.Element) -> str:
    """Return the first non-empty ``pub-date/year``, else any ``year`` text."""
    for pub_date in root.iter():
        if _xml_local_name(pub_date.tag) != "pub-date":
            continue
        for child in pub_date:
            if _xml_local_name(child.tag) == "year":
                year = _xml_text(child)
                if year:
                    return year
    return _first_xml_text(root, ("year",))


def _xml_first_author_surname(root: ET.Element) -> str:
    """Return the surname of the first author ``contrib`` under *root*.

    Contributors whose ``contrib-type`` is set to something other than
    ``author`` are skipped. Returns ``""`` when no surname is found.
    """
    for contrib in root.iter():
        if _xml_local_name(contrib.tag) != "contrib":
            continue
        contrib_type = contrib.attrib.get("contrib-type", "")
        if contrib_type and contrib_type != "author":
            continue
        for child in contrib.iter():
            if _xml_local_name(child.tag) == "surname":
                surname = _xml_text(child)
                if surname:
                    return surname
    return ""


def _xml_corresponding_authors(root: ET.Element) -> str:
    """Return corresponding-author details joined with ``"; "``.

    An author ``contrib`` contributes ``"name(s), email(s)"`` when it is
    flagged ``corresp="yes"``/``"true"`` or carries an e-mail address. The
    text of every ``corresp`` element is appended as well. Duplicates are
    dropped while preserving first-seen order.
    """
    values: list[str] = []
    for contrib in root.iter():
        if _xml_local_name(contrib.tag) != "contrib":
            continue
        # Same filter as _xml_first_author_surname: an editor or reviewer
        # with an e-mail address is not a corresponding *author*.
        contrib_type = contrib.attrib.get("contrib-type", "")
        if contrib_type and contrib_type != "author":
            continue

        is_corresp = contrib.attrib.get("corresp", "").lower() in {"yes", "true"}
        names: list[str] = []
        emails: list[str] = []
        for child in contrib.iter():
            local_name = _xml_local_name(child.tag)
            if local_name not in ("name", "email"):
                continue
            text = _xml_text(child)
            if text:
                (names if local_name == "name" else emails).append(text)

        if not is_corresp and not emails:
            continue
        value = ", ".join(names + emails)
        if value and value not in values:
            values.append(value)

    for corresp in root.iter():
        if _xml_local_name(corresp.tag) != "corresp":
            continue
        text = _xml_text(corresp)
        if text and text not in values:
            values.append(text)

    return "; ".join(values)


def _parse_xml_root(xml_source: str | bytes | Path) -> ET.Element:
    """Parse *xml_source* and return the document root.

    ``bytes`` are parsed as XML content and ``Path`` objects as files. A
    ``str`` is treated as XML content when, ignoring a leading byte-order mark
    and whitespace, it starts with ``<``; otherwise it is treated as a
    filesystem path.

    Raises:
        xml.etree.ElementTree.ParseError: If the content is not well-formed.
        OSError: If a path cannot be opened.
    """
    # NOTE(review): the input can come from a remote service. The Python docs
    # warn that xml.etree is not hardened against maliciously constructed
    # documents; consider defusedxml if that becomes a concern.
    if isinstance(xml_source, bytes):
        return ET.fromstring(xml_source)

    if isinstance(xml_source, Path):
        return ET.parse(xml_source).getroot()

    source = str(xml_source)
    # Parse the *stripped* text: an XML declaration must be the very first
    # thing in the document, and U+FEFF is not removed by a plain lstrip().
    markup = source.lstrip().lstrip("\ufeff").lstrip()
    if markup.startswith("<"):
        return ET.fromstring(markup)
    return ET.parse(source).getroot()


def _xml_article_title(article_meta: ET.Element | None) -> str:
    """Return the text of ``title-group/article-title`` or ``""``."""
    title_group = _first_child(article_meta, "title-group")
    article_title = _first_child(title_group, "article-title")
    return _xml_text(article_title)


def _xml_abstract(article_meta: ET.Element | None) -> str:
    """Return the text of the first ``abstract`` child or ``""``."""
    abstract = _first_child(article_meta, "abstract")
    return _xml_text(abstract)


def _xml_body(article: ET.Element) -> str:
    """Return the text of the article's ``body`` child or ``""``."""
    body = _first_child(article, "body")
    return _xml_text(body)


def _first_sentence(text: str) -> str:
    """Return *text* up to and including the first ``". "`` boundary.

    This is a plain string heuristic (it will also stop at abbreviations that
    are followed by a space). Text without such a boundary is returned
    unchanged; empty text gives ``""``.
    """
    if not text:
        return ""
    head, separator, _tail = text.partition(". ")
    return f"{head}." if separator else text


def extract_xml_text(xml_source: str | bytes | Path) -> str:
    """
    Extract readable article text from a PMC/JATS-like XML document.

    Accepts XML text, bytes, or a filesystem path. The returned text is useful
    for inspection/debugging and is scoped to the first article in the XML.

    The outermost ``article-title``, ``abstract``, ``kwd-group``, ``body`` and
    ``back`` elements are emitted once each, in document order, separated by
    newlines. Matches nested inside an already-emitted section (for example
    cited ``article-title`` elements inside ``back``) are not repeated. When
    no section is found the text of the whole document is returned.
    """
    root = _parse_xml_root(xml_source)
    article = _article_root(root)
    preferred_tags = {
        "article-title",
        "abstract",
        "kwd-group",
        "body",
        "back",
    }

    sections: list[str] = []
    seen: set[str] = set()
    # Explicit stack = pre-order walk that can stop descending at a match.
    pending: list[ET.Element] = [article]
    while pending:
        element = pending.pop()
        if _xml_local_name(element.tag) in preferred_tags:
            text = _xml_text(element)
            if text and text not in seen:
                seen.add(text)
                sections.append(text)
            continue
        # Push children reversed so they pop in document order.
        pending.extend(reversed(list(element)))

    if not sections:
        sections.append(_xml_text(root))

    return "\n".join(sections)


def extract_xml_llm_text(xml_source: str | bytes | Path) -> str:
    """
    Extract clean article text for LLM completion from JATS XML.

    Includes only title, abstract, and body from the first article. It excludes
    back matter/references to avoid contaminating metadata extraction.

    Each non-empty part is rendered as ``"<Label>\\n<text>"`` and parts are
    separated by a blank line.
    """
    root = _parse_xml_root(xml_source)
    article = _article_root(root)
    article_meta = _article_meta(article)

    sections = [
        ("Title", _xml_article_title(article_meta)),
        ("Abstract", _xml_abstract(article_meta)),
        ("Body", _xml_body(article)),
    ]
    return "\n\n".join(
        f"{label}\n{text}"
        for label, text in sections
        if text
    )


def _suggest_study_id(cancer_type: str | None, author: str | None, year: str | None) -> str:
    """Build a ``<cancer-or-"study">_<surname>_<year>`` identifier.

    Returns ``""`` unless both *author* and *year* are set. The result is
    lower-cased, accents are folded to ASCII where a decomposition exists, and
    every run of characters outside ``[a-z0-9]`` becomes a single underscore.
    """
    import unicodedata

    if not (author and year):
        return ""
    prefix = cancer_type or "study"
    raw = f"{prefix}_{author}_{year}"
    # Lower-case and fold before filtering; otherwise uppercase letters and
    # accented characters would be replaced by underscores.
    folded = (
        unicodedata.normalize("NFKD", raw)
        .encode("ascii", "ignore")
        .decode("ascii")
        .lower()
    )
    return re.sub(r"[^a-z0-9]+", "_", folded).strip("_")


def extract_metadata_from_xml(xml_source: str | bytes | Path) -> dict:
    """
    Extract study metadata from PMC/JATS XML without using PDF text or an LLM.

    Only structured JATS fields are used. Values that are not represented as
    dedicated article metadata in JATS are left blank/default for now.

    Returns:
        A fresh dict with every key of ``METADATA_DEFAULTS``; the keys that
        could be read from the XML are overwritten. List-valued defaults are
        copied so callers may mutate the result safely.
    """
    import copy

    root = _parse_xml_root(xml_source)
    article = _article_root(root)
    article_meta = _article_meta(article)
    journal_meta = _journal_meta(article)
    # Fall back to the whole article when the dedicated meta element is absent.
    article_scope = article_meta if article_meta is not None else article
    journal_scope = journal_meta if journal_meta is not None else article
    # Deep copy: a shallow copy would share the default lists between calls.
    meta = copy.deepcopy(METADATA_DEFAULTS)

    abstract = _xml_abstract(article_meta)
    description = _first_sentence(abstract)
    article_ids = _xml_article_ids(article_scope)
    structured = {
        "study_title": _xml_article_title(article_meta),
        "journal": _xml_journal_title(journal_scope),
        "year": _xml_publication_year(article_scope),
        "pmid": article_ids.get("pmid", ""),
        "doi": article_ids.get("doi", ""),
        "first_author_surname": _xml_first_author_surname(article_scope),
        "description": description,
        "meta_description": description[:200],
        "corresponding_authors": _xml_corresponding_authors(article_scope),
    }

    for key, value in structured.items():
        if value:
            meta[key] = value

    study_id = _suggest_study_id(
        meta.get("cancer_type"),
        meta.get("first_author_surname"),
        meta.get("year"),
    )
    if study_id:
        meta["study_id_suggestion"] = study_id

    return meta