forked from AshChadha-iitg/OpenMath
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate_gsm8k.py
More file actions
128 lines (99 loc) · 3.84 KB
/
evaluate_gsm8k.py
File metadata and controls
128 lines (99 loc) · 3.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import csv
import re
import math
from fractions import Fraction
from datasets import load_dataset
import inference
# Compiled pattern for a loose numeric token: optional "-", digits, an optional
# "/" followed by digits, and an optional "." followed by digits.
# NOTE(review): not referenced by any function in this file; presumably kept
# for external importers -- confirm before removing or changing the pattern.
NUM_RE = re.compile(r"-?\d+\/?\d*\.?\d*")
# One numeric token: an optional minus sign; an integer part written either
# with comma thousands separators ("1,234,567") or as plain digits; an optional
# decimal part; and an optional "/<digits>" denominator, so that a simple
# fraction such as "3/4" is read as a single token rather than two numbers.
_NUMBER_TOKEN_RE = re.compile(r"(-?(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d+)?)(?:/(\d+))?")


def parse_numeric(s: str):
    """Extract the last numeric value from a string.

    Recognised forms: integers, decimals ("3.5"), numbers with comma thousands
    separators ("1,000"), an optional leading minus sign, and simple fractions
    ("3/4", "-1/2"). Only the LAST such token in the string is used. In a
    worked solution the final answer usually comes last, and earlier numbers
    are intermediate steps. Those steps include fractions inside calculator
    annotations such as "<<48/2=24>>", which must not take precedence.

    Args:
        s: Text to scan. Empty or non-string input yields None.

    Returns:
        The value as a float, or None if the string contains no number or
        the last token is a fraction with a zero denominator.
    """
    if not s or not isinstance(s, str):
        return None
    tokens = _NUMBER_TOKEN_RE.findall(s)
    if not tokens:
        return None
    # findall yields (numerator, denominator) pairs; denominator is "" when
    # the token is not a fraction.
    numerator, denominator = tokens[-1]
    numerator = numerator.replace(",", "")
    if not denominator:
        return float(numerator)
    if int(denominator) == 0:
        # An undefined fraction is not a usable answer.
        return None
    # Fraction accepts decimal strings ("1.5"), keeping the division exact
    # until the final float conversion.
    return float(Fraction(numerator) / int(denominator))
def normalize_reference_answer(ans: str):
    """Return the numeric final answer of a GSM8K reference solution, or None.

    GSM8K reference solutions end with a ``#### <answer>`` line. The worked
    steps before it contain many intermediate numbers. When the marker is
    present, only the text after the last ``####`` is parsed. Otherwise the
    whole string goes to ``parse_numeric``, which takes the last number it
    finds.
    """
    if isinstance(ans, str) and "####" in ans:
        ans = ans.rpartition("####")[2]
    return parse_numeric(ans)
def extract_predicted_answer(text: str):
    """Return the numeric answer contained in model output, or None.

    If the output contains a GSM8K-style ``####`` marker, the first line after
    the first marker is parsed. This means numbers in any text the model keeps
    generating after stating its answer are ignored. If there is no marker, or
    that line holds no number, the last number anywhere in the output is used.
    """
    if isinstance(text, str) and "####" in text:
        answer_line = text.partition("####")[2].strip().partition("\n")[0]
        value = parse_numeric(answer_line)
        if value is not None:
            return value
    # Heuristic fallback: the last number in the whole output.
    return parse_numeric(text)
def main():
    """Evaluate a model on the GSM8K test split and write per-sample results.

    Command-line options select the base model and adapter, chain-of-thought
    mode, sampling parameters, an optional sample limit and the output CSV
    path. For each problem, ``inference.generate_solution`` produces a
    solution. The number extracted from it is compared with the reference's
    final answer. One CSV row is written per sample, and the overall accuracy
    is printed at the end.
    """
    parser = argparse.ArgumentParser(description="Evaluate OpenMath on GSM8K test set")
    parser.add_argument("--base_model", type=str, default=None)
    parser.add_argument("--adapter_path", type=str, default=None)
    parser.add_argument("--cot", action="store_true")
    parser.add_argument("--temperature", type=float, default=0.0)
    parser.add_argument("--top_p", type=float, default=1.0)
    parser.add_argument("--max_new_tokens", type=int, default=200)
    parser.add_argument("--limit", type=int, default=0, help="Limit samples for testing (0 = full set)")
    parser.add_argument("--outfile", type=str, default="gsm8k_eval_results.csv")
    args = parser.parse_args()

    ds = load_dataset("gsm8k", "main", split="test")
    total = len(ds)
    if args.limit and args.limit > 0:
        total = min(total, args.limit)

    fieldnames = [
        "index",
        "question",
        "reference",
        "reference_value",
        "prediction_text",
        "prediction_value",
        "correct",
    ]
    correct = 0
    # Rows are written as they are produced rather than after the loop. If an
    # exception or Ctrl-C interrupts a long run, every completed sample is
    # still on disk, because the `with` block closes (and flushes) the file
    # either way.
    with open(args.outfile, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        for i, sample in enumerate(ds):
            if i >= total:
                break
            question = sample.get("question") or sample.get("problem") or ""
            ref = sample.get("answer") or sample.get("output") or ""
            ref_val = normalize_reference_answer(ref)

            try:
                pred_text = inference.generate_solution(
                    problem=question,
                    cot=args.cot,
                    temperature=args.temperature,
                    top_p=args.top_p,
                    max_new_tokens=args.max_new_tokens,
                    base_model=args.base_model,
                    adapter_path=args.adapter_path,
                )
            except Exception as e:
                # Best-effort: record the failure and keep evaluating the rest.
                # The error text is deliberately NOT parsed for an answer,
                # because numbers inside an exception message could otherwise
                # be scored as a (possibly "correct") prediction.
                pred_text = f"<error: {e}>"
                pred_val = None
            else:
                pred_val = extract_predicted_answer(pred_text)

            # GSM8K final answers are expected to be exact values, so the
            # tolerance only absorbs float representation noise. A 1% relative
            # tolerance would, for example, accept 101 for a reference of 100.
            is_correct = (
                ref_val is not None
                and pred_val is not None
                and math.isclose(ref_val, pred_val, rel_tol=1e-9, abs_tol=1e-6)
            )
            if is_correct:
                correct += 1

            writer.writerow({
                "index": i,
                "question": question,
                "reference": ref,
                "reference_value": ref_val,
                "prediction_text": pred_text,
                "prediction_value": pred_val,
                "correct": is_correct,
            })
            print(f"[{i+1}/{total}] correct={correct} question='{question[:60]}...'")

    accuracy = correct / total if total > 0 else 0.0
    print(f"\nFinished. Accuracy: {accuracy:.4f} ({correct}/{total})")


if __name__ == "__main__":
    main()