-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_sanity.py
More file actions
149 lines (123 loc) · 4.38 KB
/
eval_sanity.py
File metadata and controls
149 lines (123 loc) · 4.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import json
import sys
from pathlib import Path
# Directory holding this script; the data folders and the local packages used
# below are resolved relative to it.
ROOT = Path(__file__).resolve().parent
DATA_REVIEW = ROOT / "data_review"

# Prepend both directories to the import path, each only if it is absent.
# ROOT is handled last, so it ends up ahead of DATA_REVIEW when both are added.
for _entry in (str(DATA_REVIEW), str(ROOT)):
    if _entry not in sys.path:
        sys.path.insert(0, _entry)
del _entry

from data_review.dataset import InverseIndex  # noqa: E402
from data_review.auxiliaries import clean_chunk  # noqa: E402

# The ranker is optional: on any import failure the exception is kept so that
# main() can print it instead of the script dying at import time.
try:
    from rankers.bm25 import BM25  # noqa: E402
except Exception as exc:  # pragma: no cover - debug-only
    BM25_IMPORT_ERROR = exc
    BM25 = None
def load_queries(path):
    """Read a JSON-lines query file into a list of ``(id, text)`` tuples.

    Each non-blank line of *path* is parsed as a JSON object and its
    ``"_id"`` and ``"text"`` values are collected, in file order. Lines that
    are empty after stripping whitespace are ignored.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a parsed object lacks ``"_id"`` or ``"text"``.
    """
    with open(path, "r", encoding="utf-8") as handle:
        # Strip first, then drop blanks, so whitespace-only lines are skipped.
        records = (json.loads(row) for row in map(str.strip, handle) if row)
        # Materialize inside the ``with`` so the file is still open while the
        # lazy generator above is consumed.
        return [(record["_id"], record["text"]) for record in records]
def load_qrels(path):
    """Load relevance judgements from a tab-separated qrels file.

    A data row holds exactly three tab-separated fields after surrounding
    whitespace is stripped: query id, document id and an integer score.
    Blank rows and rows with any other number of fields are skipped.

    The first line of the file is treated as a header, and skipped, when it
    starts with ``query-id`` (case-insensitive) or when its third field is
    not an integer (for example a header that uses other column names).
    Otherwise the first line is parsed as ordinary data.

    Args:
        path: Location of the qrels file, read as UTF-8.

    Returns:
        A dict mapping each query id to a dict of ``{doc_id: int_score}``.
        A repeated (query id, doc id) pair keeps the last score seen.

    Raises:
        ValueError: if any row after the first has a non-integer score.
    """
    qrels = {}
    with open(path, "r", encoding="utf-8") as f:
        for line_no, raw in enumerate(f):
            fields = raw.strip().split("\t")
            if len(fields) != 3:
                # Covers blank lines too: "".split("\t") yields one field.
                continue
            qid, docid, score = fields
            if line_no == 0:
                if raw.lower().startswith("query-id"):
                    continue
                try:
                    relevance = int(score)
                except ValueError:
                    # A non-numeric score on the very first line means this
                    # is a header with unfamiliar column names, not data.
                    continue
            else:
                relevance = int(score)
            qrels.setdefault(qid, {})[docid] = relevance
    return qrels
def pick_first_query_with_qrels(queries, qrels):
    """Return the first ``(qid, text)`` pair whose id is a key of *qrels*.

    *queries* is scanned in order as ``(qid, text)`` pairs. When no id is
    present in *qrels* (or *queries* is empty) the result is ``(None, None)``.
    """
    judged = ((qid, qtext) for qid, qtext in queries if qid in qrels)
    return next(judged, (None, None))
def main():
    """Print a series of sanity checks for the data under ``ROOT / "scidocs"``.

    In order, the checks are: existence of the corpus, queries and qrels
    files; how many query ids appear in the qrels; index construction
    statistics; a term-frequency spot check; tokenization of the first query
    that has qrels; and, if the BM25 ranker imported, the overlap between
    its top five results and the positively judged documents.

    Everything is reported on stdout and the function returns ``None``. It
    returns early when an input file is missing, when no query id matches
    the qrels, or when BM25 is unavailable.
    """
    dataset_dir = ROOT / "scidocs"
    corpus_path = dataset_dir / "corpus.jsonl"
    queries_path = dataset_dir / "queries.jsonl"
    qrels_path = dataset_dir / "qrels" / "test.tsv"

    # Stop at the first missing input, checked in corpus/queries/qrels order.
    if not corpus_path.exists():
        print(f"Missing corpus: {corpus_path}")
        return
    if not queries_path.exists():
        print(f"Missing queries: {queries_path}")
        return
    if not qrels_path.exists():
        print(f"Missing qrels: {qrels_path}")
        return

    queries = load_queries(queries_path)
    qrels = load_qrels(qrels_path)

    print("=== QID / Qrels Checks ===")
    print("queries:", len(queries))
    print("qrels qids:", len(qrels))
    judged_ids = [query_id for query_id, _ in queries if query_id in qrels]
    print("queries with qrels:", len(judged_ids))
    # First query id that has no judgements, or None when all are covered.
    unjudged_example = None
    for query_id, _ in queries:
        if query_id not in qrels:
            unjudged_example = query_id
            break
    print("example missing qid:", unjudged_example)

    qid, qtext = pick_first_query_with_qrels(queries, qrels)
    if qid is None:
        print("No queries matched qrels. Check splits/paths.")
        return

    print("\n=== Index Build ===")
    index = InverseIndex(str(corpus_path))
    index.populate_id_map()
    index.construct_inverse_index()
    index.n_docs = len(index.id_map)
    print("docs:", len(index.id_map))
    print("lengths:", len(index.lengths))
    # Recompute the mean length to compare with the value the index stores;
    # an index with no lengths reports 0.0 rather than dividing by zero.
    if index.lengths:
        recomputed_avgdl = sum(index.lengths.values()) / len(index.lengths)
    else:
        recomputed_avgdl = 0.0
    print("avgdl (stored):", index.avg_len)
    print("avgdl (recomputed):", recomputed_avgdl)

    print("\n=== TF Check ===")
    # Best-effort spot check: take the first term and the first document the
    # index yields for it, recount the term in the re-tokenized text, and show
    # it next to the stored value. Any failure is reported, not raised.
    try:
        term = next(iter(index.index_dict))
        doc_id = next(iter(index.index_dict[term]))
        doc_tokens = clean_chunk(index.id_map[doc_id]["text"])
        recounted_tf = doc_tokens.count(term)
        stored_tf = index.index_dict[term][doc_id]
        print("term:", term)
        print("doc_id:", doc_id)
        print("tf count:", recounted_tf)
        print("tf in index:", stored_tf)
    except Exception as exc:
        print("TF check failed:", exc)

    print("\n=== Query Tokenization Check ===")
    query_tokens = clean_chunk(qtext)
    unseen_terms = [token for token in query_tokens if token not in index.index_dict]
    print("qid:", qid)
    print("query:", qtext)
    print("tokens:", query_tokens)
    print("missing terms:", unseen_terms)

    if BM25 is None:
        print("\nBM25 not available:", BM25_IMPORT_ERROR)
        return

    print("\n=== BM25 Retrieval Sanity ===")
    ranker = BM25(index)
    scores = ranker.retrieve(qtext)
    # Highest score first; only the five best document ids are kept.
    ranked = sorted(scores.items(), key=lambda pair: pair[1], reverse=True)
    top5 = [doc for doc, _ in ranked[:5]]
    # Documents judged with a score above zero for the chosen query.
    gold_pos = {doc for doc, rel in qrels[qid].items() if rel > 0}
    overlap = set(top5) & gold_pos
    print("gold size:", len(gold_pos))
    print("top5:", top5)
    print("overlap:", overlap)


if __name__ == "__main__":
    main()