-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathaccuracy.py
More file actions
executable file
·69 lines (56 loc) · 2.23 KB
/
accuracy.py
File metadata and controls
executable file
·69 lines (56 loc) · 2.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#!/usr/bin/env python3
import json
import sys
from collections import Counter
def _require_number(item, key: str, lineno: int):
if key not in item:
raise ValueError(f"line {lineno}: missing required field `{key}`")
value = item[key]
if value is None:
raise ValueError(f"line {lineno}: field `{key}` is None")
if not isinstance(value, (int, float)):
raise ValueError(f"line {lineno}: field `{key}` must be number, got {type(value).__name__}")
return value
def main(infile, print_=False):
    """Score a JSONL results file and optionally print a summary.

    Every non-blank line of *infile* must be a JSON object. ``num_rounds``
    and ``total_tokens`` are required numeric fields. ``gt`` and ``answer``
    are compared after ``str()`` conversion, whitespace stripping and
    upper-casing. A record whose answer is absent, ``null`` or blank is
    counted as a failed (empty) answer and is never scored as correct.

    Args:
        infile: Path of the JSONL file to evaluate.
        print_: When true, print the summary to stdout.

    Returns:
        A ``(accuracy, total)`` tuple: the fraction of records answered
        correctly and the number of records evaluated.

    Raises:
        ValueError: If the file contains no records, a line is not valid
            JSON (``json.JSONDecodeError`` is a ``ValueError``), a record is
            not a JSON object, or a required numeric field is missing or
            not a number.
    """
    eval_total = 0
    eval_correct = 0
    fail_empty_answer = 0
    round_sum = 0.0
    token_sum = 0.0
    reason_counter = Counter()

    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(infile, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            # Tolerate blank lines (typically a trailing newline at EOF)
            # instead of letting json.loads fail on them.
            if not line.strip():
                continue

            item = json.loads(line)
            if not isinstance(item, dict):
                raise ValueError(
                    f"line {lineno}: expected a JSON object, got {type(item).__name__}"
                )
            eval_total += 1

            reason = item.get("stop_reason", "missing_stop_reason")
            reason_counter[str(reason)] += 1

            # Strict mode: these fields must exist and be numeric.
            round_sum += _require_number(item, "num_rounds", lineno)
            token_sum += _require_number(item, "total_tokens", lineno)

            # A JSON null must count as "no value": str(None) would yield
            # "NONE", hiding empty answers and letting null == null score
            # as a correct prediction.
            gt = item.get("gt")
            answer = item.get("answer")
            gt_s = "" if gt is None else str(gt).strip().upper()
            ans_s = "" if answer is None else str(answer).strip().upper()

            if ans_s == "":
                fail_empty_answer += 1
                continue
            if gt_s == ans_s:
                eval_correct += 1

    if eval_total == 0:
        raise ValueError(f"{infile} is empty")

    total_accuracy = eval_correct / eval_total
    avg_num_rounds = round_sum / eval_total
    avg_total_tokens = token_sum / eval_total

    if print_:
        print(f"Evaluating {infile}...")
        print("summary:")
        print(f" total entries: {eval_total}")
        print(f" accuracy: {eval_correct}/{eval_total} = {total_accuracy:.4%}")
        print(f" failed (empty answer): {fail_empty_answer}")
        print(f" avg rounds: {avg_num_rounds:.4f}")
        print(f" avg total_tokens: {avg_total_tokens:.4f}")
        print(" stop_reason stats:")
        for reason, cnt in reason_counter.most_common():
            print(f" {reason}: {cnt}")

    return total_accuracy, eval_total
if __name__ == "__main__":
    # Command-line entry point: evaluate the JSONL file named by the first
    # argument and print the summary.
    if len(sys.argv) < 2:
        # Exit with a usage hint rather than an IndexError traceback when the
        # input path is omitted; sys.exit with a string writes it to stderr
        # and sets a non-zero exit status.
        sys.exit(f"usage: {sys.argv[0]} <results.jsonl>")
    infile = sys.argv[1]
    main(infile, print_=True)