-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinspect_rag.py
More file actions
51 lines (42 loc) · 1.54 KB
/
inspect_rag.py
File metadata and controls
51 lines (42 loc) · 1.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import json
import random
import logging
from pathlib import Path
from src.rag import FAISSRetriever
# Keep every logger at ERROR by default (root configuration), then raise the
# "src.rag" logger to INFO so messages from the src.rag module still show.
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("src.rag")
logger.setLevel(logging.INFO)
def main():
    """Print the top RAG retrievals for a random sample of test cases.

    Picks up to five ``*.json`` files from ``data/test_set``, loads the
    retriever, and for each sampled case prints a preview of its ``query``,
    its ``gt`` value, and the top five retrieved results. A result is flagged
    as matched when any of its ``icd_codes`` appears in the case's
    ``icd_codes`` list or when its ``protocol_id`` equals ``gt``.

    If the test directory contains no JSON files, a message is printed and
    the function returns without loading the retriever.
    """
    sample_size = 5
    top_k = 5
    preview_chars = 300

    # Collect candidates first so an empty directory fails fast, before the
    # (potentially slow) index/model load below.
    test_dir = Path("data/test_set")
    test_files = list(test_dir.glob("*.json"))
    if not test_files:
        print(f"No test cases found in {test_dir}")
        return

    # 1. Initialize retriever
    retriever = FAISSRetriever()
    print("Loading RAG index (CACHED) and E5 model...")
    retriever.load()
    print("RAG ready.\n")

    # 2. Get up to `sample_size` random test files.
    # random.sample raises ValueError when asked for more items than exist,
    # so clamp the sample size to the number of available files.
    sample_files = random.sample(test_files, min(sample_size, len(test_files)))
    total = len(sample_files)

    print("=" * 100)
    print(f"RAG INSPECTION: {total} RANDOM TEST CASES")
    print("=" * 100)

    for i, f in enumerate(sample_files, 1):
        # JSON is UTF-8 by specification; do not rely on the locale default.
        data = json.loads(f.read_text(encoding="utf-8"))
        query = data["query"]
        gt = data["gt"]
        valid_codes = set(data["icd_codes"])

        # Only show an ellipsis when the query was actually truncated.
        preview = query[:preview_chars]
        if len(query) > preview_chars:
            preview += "..."

        print(f"\n[{i}/{total}] TEST CASE: {f.name}")
        print(f"PATIENT SYMPTOMS: {preview}")
        print(f"GROUND TRUTH ICD-10 CODE: {gt}")
        print("-" * 40)

        # 3. Retrieve
        results = retriever.retrieve(query, top_k=top_k)

        print(f"TOP {top_k} RAG RESULTS (PROTOCOL NAMES & CODES):")
        for j, res in enumerate(results, 1):
            # NOTE(review): `gt` is printed as an ICD-10 code but is compared
            # against `protocol_id` here -- confirm which one it really holds.
            is_match = (
                any(c in valid_codes for c in res.icd_codes)
                or res.protocol_id == gt
            )
            match_str = " [MATCHED VERIFIED CODES] ✅" if is_match else ""
            print(f" {j}. {res.title}")
            print(f" Codes: {', '.join(res.icd_codes)}{match_str}")
        print("-" * 100)
# Run the inspection only when executed as a script, not when imported.
if __name__ == "__main__":
    main()