forked from plau666/ContinuousBenchEval
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathevaluate.py
More file actions
220 lines (187 loc) · 7.62 KB
/
Copy pathevaluate.py
File metadata and controls
220 lines (187 loc) · 7.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""ContinuousBenchEval — standalone evaluation entry point.
Runs QA evaluation (exact match + fuzzy match) on a saved checkpoint.
Usage:
# Full fine-tune HF checkpoint (weights live in --checkpoint)
python evaluate.py --framework hf \\
--checkpoint outputs/cbe/news/.../checkpoints/step_001000 \\
--qa_data data/news/testqa.jsonl
# LoRA HF checkpoint (adapter in --checkpoint, base model via --model)
python evaluate.py --framework hf \\
--checkpoint outputs/cbe/news/.../checkpoints/step_001000 \\
--model google/gemma-3-1b-pt \\
--lora_rank 128 \\
--qa_data data/news/testqa.jsonl
# LoRA KD checkpoint (pretrained base weights are resolved from the Gemma name passed via --model)
python evaluate.py --framework kd \\
--checkpoint outputs/cbe/news/.../checkpoints/ckpt_1000 \\
--model gemma3-1b-pt \\
--lora_rank 128 \\
--qa_data data/news/testqa.jsonl
"""
import argparse
import os
import sys
_FRAMEWORK_ALIASES = {
"hf": "huggingface",
"kd": "kauldron",
"huggingface": "huggingface",
"kauldron": "kauldron",
}
def main():
    """Command-line entry point: parse options, evaluate, print the metrics.

    The ``--framework`` value is normalised through ``_FRAMEWORK_ALIASES``;
    "huggingface" is handled by ``_eval_hf`` and anything else by
    ``_eval_kd``. The mapping returned by the chosen backend is printed as a
    small report, with float values shown to four decimal places.
    """
    cli = argparse.ArgumentParser(description="ContinuousBenchEval evaluation")
    add = cli.add_argument

    # Required inputs.
    add("--checkpoint", required=True, help="Path to checkpoint directory")
    add("--qa_data", required=True, help="Path to QA .jsonl file")

    # Backend and model selection.
    add(
        "--framework",
        default="hf",
        choices=list(_FRAMEWORK_ALIASES),
        help="hf (huggingface) or kd (kauldron)",
    )
    add(
        "--model",
        default=None,
        help="Base model name: HF hub ID (e.g. google/gemma-3-1b-pt) "
        "or KD name (e.g. gemma3-1b-pt). Required for LoRA checkpoints.",
    )
    add("--lora_rank", type=int, default=None)

    # Prompting and generation settings.
    add("--prompt_prefix", default="")
    add("--prompt_template", default="Q: {question}\nA:")
    add("--max_new_tokens", type=int, default=50)
    add("--batch_size", type=int, default=16)
    add("--temperature", type=float, default=0.0)
    add("--top_k", type=int, default=None)
    add("--top_p", type=float, default=None)

    # Scoring and reporting options.
    add(
        "--parser",
        default=None,
        help="Answer parser: 'finegrained_geminon' or None (default lowercase/substring match)",
    )
    add(
        "--num_examples",
        type=int,
        default=10,
        help="Print this many random prompt/completion examples with verdicts",
    )
    add(
        "--save_details",
        default=None,
        help="If set, save all per-example results as JSONL at this path",
    )

    args = cli.parse_args()

    # Pick the backend from the canonical framework name.
    canonical = _FRAMEWORK_ALIASES[args.framework]
    backend = _eval_hf if canonical == "huggingface" else _eval_kd
    metrics = backend(args)

    rule = "=" * 60
    print("\n" + rule)
    print("EVALUATION RESULTS")
    print(rule)
    for metric_name, metric_value in metrics.items():
        # Floats get a fixed four-decimal rendering; everything else is printed as-is.
        if isinstance(metric_value, float):
            rendered = f"{metric_value:.4f}"
        else:
            rendered = f"{metric_value}"
        print(f" {metric_name}: {rendered}")
    print(rule)
# ---------------------------------------------------------------------------
# HuggingFace path
# ---------------------------------------------------------------------------
def _eval_hf(args):
    """Load a HuggingFace checkpoint and run the QA evaluation on it.

    Args:
        args: Parsed command-line namespace produced by ``main``.

    Returns:
        Whatever ``run_qa_eval_hf`` returns (``main`` iterates it as a
        mapping of metric name to value).

    Raises:
        SystemExit: A LoRA adapter was found in ``args.checkpoint`` but no
            ``--model`` was given to load the base weights from.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from cbe.eval.inference import run_qa_eval_hf

    # An adapter_config.json inside the checkpoint directory marks a LoRA
    # checkpoint. Per the original author's note, HF+PEFT Trainer stores only
    # the adapter weights there, so the base model must be loaded separately.
    adapter_config = os.path.join(args.checkpoint, "adapter_config.json")
    if os.path.exists(adapter_config):
        if not args.model:
            raise SystemExit(
                "LoRA adapter detected at --checkpoint; pass --model <hub_id> "
                "so we can load the base model."
            )
        from peft import PeftModel

        print(f"[evaluate] Loading base model {args.model} + LoRA adapter {args.checkpoint}")
        base_model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype="auto")
        model = PeftModel.from_pretrained(base_model, args.checkpoint)
        # The tokenizer is taken from the base model, not the adapter directory.
        tokenizer_src = args.model
    else:
        print(f"[evaluate] Loading full-weights model from {args.checkpoint}")
        model = AutoModelForCausalLM.from_pretrained(args.checkpoint, torch_dtype="auto")
        # An explicit --model overrides the checkpoint as the tokenizer source.
        tokenizer_src = args.model or args.checkpoint

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_src)
    if tokenizer.pad_token is None:
        # No pad token defined: reuse the end-of-sequence token for padding.
        tokenizer.pad_token = tokenizer.eos_token

    # Use the GPU when torch reports one; stay on the current device otherwise.
    try:
        import torch

        if torch.cuda.is_available():
            model = model.to("cuda")
    except ImportError:
        pass

    eval_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        qa_path=args.qa_data,
        prompt_prefix=args.prompt_prefix,
        prompt_template=args.prompt_template,
        max_new_tokens=args.max_new_tokens,
        batch_size=args.batch_size,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        parser=args.parser,
        num_examples=args.num_examples,
        save_details_path=args.save_details,
    )
    return run_qa_eval_hf(**eval_kwargs)
# ---------------------------------------------------------------------------
# Kauldron path
# ---------------------------------------------------------------------------
def _eval_kd(args):
    """Load a Kauldron (Gemma) checkpoint and run the QA evaluation on it.

    Args:
        args: Parsed command-line namespace produced by ``main``.

    Returns:
        Whatever ``run_qa_eval_kd`` returns (``main`` iterates it as a
        mapping of metric name to value).

    Raises:
        SystemExit: ``--model`` was not supplied; the name is needed to
            rebuild the architecture.
    """
    # XLA client memory defaults; values already set in the environment win.
    for env_name, env_default in (
        ("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.9"),
        ("XLA_PYTHON_CLIENT_PREALLOCATE", "false"),
    ):
        os.environ.setdefault(env_name, env_default)

    # Hide GPUs from TensorFlow when it is installed (presumably so it does
    # not compete with the JAX side for device memory -- TODO confirm).
    try:
        import tensorflow as tf

        tf.config.set_visible_devices([], "GPU")
    except ImportError:
        pass

    from gemma import gm
    from gemma.gm import peft as gm_peft
    from cbe.config import ModelConfig
    from cbe.models.kd_models import create_kd_model, _GEMMA_MODELS
    from cbe.eval.inference import run_qa_eval_kd

    if not args.model:
        raise SystemExit(
            "KD eval needs --model <name> (e.g. gemma3-1b-pt) to rebuild "
            "the architecture."
        )

    model_config = ModelConfig(name=args.model, lora_rank=args.lora_rank)
    model_factory = create_kd_model(model_config)
    model = model_factory.make_model(model_config)

    print(f"[evaluate] Loading KD checkpoint: {args.checkpoint}")
    params = gm.ckpts.load_params(args.checkpoint)

    rank = args.lora_rank
    if rank and rank > 0:
        # Per the original author's note: LoRA training was initialised with
        # SkipLoRA, so only the LoRA params were updated. The pretrained base
        # params are re-loaded here and merged back with the LoRA params so
        # the resulting tree matches the LoRA-wrapped model.
        print(f"[evaluate] LoRA rank={rank}: re-loading base weights + merging")
        base_params, lora_params = gm_peft.split_params(params)
        ckpt_attr, _unused = _GEMMA_MODELS[args.model.lower()]
        pretrained_path = getattr(gm.ckpts.CheckpointPath, ckpt_attr)
        base_params = gm.ckpts.load_params(pretrained_path, params=base_params)
        params = gm_peft.merge_params(base_params, lora_params)

    eval_kwargs = dict(
        model=model,
        params=params,
        qa_path=args.qa_data,
        prompt_prefix=args.prompt_prefix,
        prompt_template=args.prompt_template,
        max_new_tokens=args.max_new_tokens,
        batch_size=args.batch_size,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        parser=args.parser,
        num_examples=args.num_examples,
        save_details_path=args.save_details,
    )
    return run_qa_eval_kd(**eval_kwargs)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()