-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: run-use.py
More file actions
161 lines (128 loc) · 6.03 KB
/
run-use.py
File metadata and controls
161 lines (128 loc) · 6.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import json
from tqdm import tqdm
import re
import time
import random
import os
from llm import LLM
# --- Model selection ---------------------------------------------------------
# Uncomment exactly one model_name to choose the backend under evaluation.
# model_name = "claude-3-7-sonnet@20250219"
# model_name = "gemini-2.0-flash-lite"
# model_name = "gemini-2.5-pro-exp-03-25"
# model_name = "gpt-4o-240513"
model_name = "gpt-4.1"
# model_name = "Qwen/Qwen2.5-72B-Instruct-Turbo"
# model_name = "deepseek-ai/DeepSeek-V3"
# model_name = "deepseek-ai/DeepSeek-R1"

# Results-file stem: strip any "provider/" prefix ("org/model" -> "model").
# split("/")[-1] is a no-op when the name contains no slash, so this covers
# both cases without an explicit branch.
file_name = model_name.split("/")[-1]

model = LLM(model_name)
def extract_idioms_others(response_text):
    """
    Extract idioms from the model response using regex pattern matching.

    Tries, in order:
      1. A comma-separated list inside angle brackets: <a, b, c, d, e>.
      2. A numbered list ("1. xxx").
      3. Splitting the whole response on commas/newlines.

    Returns a list of exactly 5 strings in order of relevance, padded with
    '' when fewer than 5 idioms were found.
    """
    # Normalize both the ideographic comma '、' (U+3001) and the full-width
    # comma '，' (U+FF0C) to the ASCII ',' used by every split below.
    # (Bug fix: the original replaced '、' with '，', which split(',') and the
    # regex character classes below never matched.)
    response_text = response_text.replace('、', ',').replace('，', ',')
    # Look for text between angle brackets
    match = re.search(r'<([^>]+)>', response_text)
    if match:
        # Split the matched content by commas and clean up whitespace
        idioms = [idiom.strip() for idiom in match.group(1).split(',')]
        # Ensure we get exactly 5 idioms, pad with empty strings if necessary
        return (idioms + [''] * 5)[:5]
    # Alternative pattern if angle brackets aren't used
    idioms = re.findall(r'\d+\.\s*([^\n,]+)', response_text)
    if idioms and len(idioms) <= 5:
        return (idioms + [''] * 5)[:5]
    # Last resort - just split by commas or newlines
    potential_idioms = re.split(r'[,\n]', response_text)
    cleaned_idioms = [i.strip() for i in potential_idioms if i.strip()]
    return (cleaned_idioms + [''] * 5)[:5]
def extract_idioms_r1(response_text):
    """
    Extract idioms from a DeepSeek-R1 response using regex pattern matching.

    R1 is prompted to wrap its final answer in <answer>...</answer>; inside
    that, idioms are expected in angle brackets: <a, b, c, d, e>. Falls back
    to a numbered-list pattern and finally to a plain comma/newline split,
    both applied to the narrowed answer text.

    Returns a list of exactly 5 strings in order of relevance, padded with
    '' when fewer than 5 idioms were found.
    """
    try:
        # Keep only the text inside <answer>...</answer>, then normalize the
        # full-width comma '，' and the ideographic comma '、' to ASCII ','
        # so split(',') below matches either separator (the '、' replacement
        # keeps this consistent with extract_idioms_others).
        response_text = (
            response_text.split("<answer>")[-1]
            .split("</answer>")[0]
            .replace("，", ",")
            .replace("、", ",")
        )
        # Look for text between angle brackets
        match = re.search(r'<([^>]+)>', response_text)
        if match:
            # Split the matched content by commas and clean up whitespace
            idioms = [idiom.strip() for idiom in match.group(1).split(',')]
            # Ensure we get exactly 5 idioms, pad with empty strings if necessary
            return (idioms + [''] * 5)[:5]
    except Exception:
        # Best-effort parse of a malformed response; fall through to the
        # heuristics below. (Fix: was a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit.)
        pass
    # Alternative pattern if angle brackets aren't used
    idioms = re.findall(r'\d+\.\s*([^\n,]+)', response_text)
    if idioms and len(idioms) <= 5:
        return (idioms + [''] * 5)[:5]
    # Last resort - just split by commas or newlines
    potential_idioms = re.split(r'[,\n]', response_text)
    cleaned_idioms = [i.strip() for i in potential_idioms if i.strip()]
    return (cleaned_idioms + [''] * 5)[:5]
# Pick the extraction routine matching the selected model: DeepSeek-R1 wraps
# its final answer in <answer> tags and needs dedicated parsing.
extract_idioms = (
    extract_idioms_r1
    if model_name == "deepseek-ai/DeepSeek-R1"
    else extract_idioms_others
)

# Load the evaluation dataset.
with open("dataset/dataset-use.json", mode='r', encoding='utf-8') as f:
    data = json.load(f)

# Resume from an earlier partial run: if a results file for this model already
# exists, reload it so processed indices are skipped in the main loop.
results_path = f"results-use/{file_name}.json"
if os.path.exists(results_path):
    with open(results_path, mode='r', encoding='utf-8') as f:
        results = json.load(f)
else:
    results = {}
# Main inference loop — resumable: indices already present in `results` are
# skipped, and results are flushed to disk after every successful call, so the
# script can be killed and restarted without losing progress.
for line in tqdm(data):
    index = str(line["index"])
    if index in results:
        continue  # Skip already processed
    # DeepSeek-R1 is asked to put its final list inside <answer> tags (parsed
    # by extract_idioms_r1); every other model is asked for a bare
    # angle-bracketed list.
    if model_name == "deepseek-ai/DeepSeek-R1":
        prompt = f"""Below is a Chinese passage. Please generate five four-character idioms that would be contextually appropriate to replace the placeholder #idiom# in the passage.
The passage is as follows:
{line["text"]}
Please rank the idioms from most to least appropriate based on the context. At the end of your response, provide the idioms in the following format between <answer> and </answer>:
<answer><idiom1, idiom2, idiom3, idiom4, idiom5></answer>
Do not output any additional content between <answer> and </answer>."""
    else:
        prompt = f"""Below is a Chinese passage. Please generate five four-character idioms that would be contextually appropriate to replace the placeholder #idiom# in the passage.
The passage is as follows:
{line["text"]}
Please only provide the idioms in the format:
<idiom1, idiom2, idiom3, idiom4, idiom5>
Rank the idioms from most to least appropriate based on the context. Do not output any additional content."""
    # Retry loop: up to max_retries attempts per item, with exponential
    # backoff plus jitter between attempts.
    max_retries = 5
    retry_count = 0
    success = False
    while retry_count < max_retries and not success:
        try:
            # Add exponential backoff with jitter
            if retry_count > 0:
                # Calculate delay with exponential backoff and jitter
                delay = (2 ** retry_count) + random.uniform(0, 1)
                print(f"Retrying in {delay:.2f} seconds...")
                time.sleep(delay)
            # Always add a small delay between requests to avoid overloading
            time.sleep(0.5)
            response = model.call_llm(prompt, max_tokens=128)
            # Record raw text, parsed idioms, and the gold answer/source for
            # later scoring.
            results[index] = {"response_text": response, "idioms": extract_idioms(response), "answer_idiom": line["idiom"], "source": line["source"]}
            # Save after each successful inference
            with open(results_path, mode='w', encoding='utf-8') as f:
                json.dump(results, f, indent=4, ensure_ascii=False)
            success = True
        except Exception as e:
            retry_count += 1
            print(f"Error at index {index} (attempt {retry_count}/{max_retries}): {e}")
            # If this is a rate limit or overloaded error, wait longer
            if "overloaded" in str(e).lower() or "rate limit" in str(e).lower() or '429' in str(e).lower():
                longer_delay = 10 * (2 ** retry_count) + random.uniform(0, 5)
                print(f"Rate limit hit. Waiting {longer_delay:.2f} seconds...")
                time.sleep(longer_delay)
            # If we've used all our retries, save progress and exit
            # (the failed index is left out of `results`, so a later run
            # retries it).
            if retry_count >= max_retries:
                print(f"Max retries reached for index {index}. Saving and continuing.")