-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
50 lines (38 loc) · 1.52 KB
/
test.py
File metadata and controls
50 lines (38 loc) · 1.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import json
from sentence_transformers import SentenceTransformer, util
import torch
from tqdm import tqdm
# Load the test set (JSONL format, one triplet per line).
def load_test_triplets(path):
    """Load evaluation triplets from a JSON Lines file.

    Each non-blank line is decoded with ``json.loads``; records are expected
    to be objects with "anchor", "positive" and "negative" fields.

    Args:
        path: Path to the ``.jsonl`` file. It is read as UTF-8; a leading
            UTF-8 BOM (as written by some Windows editors) is tolerated.

    Returns:
        A list of the decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    triplets = []
    # "utf-8-sig" reads plain UTF-8 unchanged and also strips a leading BOM,
    # which json.loads would otherwise reject.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline at EOF) are not records;
            # json.loads("") would raise JSONDecodeError.
            if not line:
                continue
            triplets.append(json.loads(line))
    return triplets
# Load the evaluation triplets from the JSONL file.
test_triplets = load_test_triplets("test_triplets.jsonl")

import os

# Network proxy used for the model downloads below (presumably needed to
# reach the Hugging Face Hub for "BAAI/bge-base-zh" — confirm). A proxy the
# user has already set in the environment takes precedence and is left as-is.
_DEFAULT_PROXY = "http://172.16.5.6:7890"
os.environ.setdefault("HTTP_PROXY", _DEFAULT_PROXY)
os.environ.setdefault("HTTPS_PROXY", _DEFAULT_PROXY)

# Load the model before fine-tuning (base) and after fine-tuning (local checkpoint).
model_before = SentenceTransformer("BAAI/bge-base-zh")
model_after = SentenceTransformer("output/bge-base-zh-finetuned")
def evaluate_model(model, triplets, batch_size=32):
    """Compute the triplet accuracy of an embedding model.

    A triplet counts as correct when the cosine similarity between the
    anchor and positive embeddings is strictly greater than the similarity
    between the anchor and negative embeddings (ties count as incorrect).

    Texts are encoded in three batched ``model.encode`` calls (anchors,
    positives, negatives) rather than one call per text.

    Args:
        model: Embedding model exposing ``eval()`` and
            ``encode(texts, batch_size=..., convert_to_tensor=True, ...)``;
            in this script, a ``SentenceTransformer``.
        triplets: Iterable of mappings with "anchor", "positive" and
            "negative" entries, each presumably a single text string.
        batch_size: Number of texts per forward pass inside ``model.encode``.

    Returns:
        Fraction of correct triplets in [0, 1]; 0.0 when ``triplets`` is empty.

    Raises:
        KeyError: If a triplet lacks one of the three required keys.
    """
    model.eval()
    # Materialize once so a generator input can be walked three times below.
    triplets = list(triplets)
    if not triplets:
        return 0.0

    def _embed(texts):
        # One batched call per role; expected to return an (N, dim) tensor.
        return model.encode(
            texts,
            batch_size=batch_size,
            convert_to_tensor=True,
            show_progress_bar=True,
        )

    anchor_emb = _embed([t["anchor"] for t in triplets])
    pos_emb = _embed([t["positive"] for t in triplets])
    neg_emb = _embed([t["negative"] for t in triplets])

    # Row-wise cosine similarity: entry i compares anchor i only with its own
    # positive / negative, not with every other triplet.
    sim_pos = torch.nn.functional.cosine_similarity(anchor_emb, pos_emb, dim=-1)
    sim_neg = torch.nn.functional.cosine_similarity(anchor_emb, neg_emb, dim=-1)

    correct = int((sim_pos > sim_neg).sum().item())
    return correct / len(triplets)
# Score both checkpoints on the same triplets (base model first, then the
# fine-tuned one) and print each accuracy as a percentage, two decimals.
acc_before, acc_after = (
    evaluate_model(candidate, test_triplets)
    for candidate in (model_before, model_after)
)

print(f"微调前模型准确率: {acc_before*100:.2f}%")
print(f"微调后模型准确率: {acc_after*100:.2f}%")