-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathbase.py
More file actions
248 lines (224 loc) · 9.57 KB
/
base.py
File metadata and controls
248 lines (224 loc) · 9.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import os
import copy
from tqdm import tqdm
from typing import List, Dict
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.dataset import BaseDataset
from src.utils import if_memory_cached, mark_memory_cached
from src.agent.base_agent import BaseAgentConfig, BaseAgent
class BaseSolver:
    """Run an agent over a dataset and collect its predictions.

    Subclasses can swap the agent implementation via ``AGENT_CLASS`` and the
    number of worker threads via ``MAX_THREADS``.
    """

    AGENT_CLASS = BaseAgent
    MAX_THREADS = 4
    # Shared progress-bar settings so every tqdm bar in this class looks the same.
    _TQDM_KWARGS = {"ascii": True, "dynamic_ncols": False, "ncols": 80}

    def __init__(self, config, memory_cache_dir: str):
        """Store the configuration and build the agent.

        Args:
            config: Agent configuration object. If its instance ``__dict__``
                contains ``memory_cache_dir``, that attribute is overwritten
                with ``memory_cache_dir`` before the agent is constructed.
            memory_cache_dir (str): Directory passed to the agent's memory
                save/load/record calls.
        """
        # Only overwrite the attribute when the config already declares it.
        if "memory_cache_dir" in config.__dict__:
            config.memory_cache_dir = memory_cache_dir
        self.config = config
        self.memory_cache_dir = memory_cache_dir
        self.method_name = "wo_memory"
        self.agent = self.AGENT_CLASS(config)

    def record_all_memories(self):
        """Write the agent's memory records into ``self.memory_cache_dir``.

        Returns:
            The paths returned by ``agent.write_memory_records``; each one is
            also printed.
        """
        paths = self.agent.write_memory_records(self.memory_cache_dir)
        for path in paths:
            print(f"Saved memory records to {path}")
        return paths

    def _attach_memory_trace(self, result: Dict) -> Dict:
        """Add the agent's last memory trace to ``result`` under ``"retrieved_memories"``.

        Mutates and returns ``result``.

        NOTE(review): predictions run concurrently on a single shared
        ``self.agent``; confirm that the agent's "last memory trace" is kept
        per thread, otherwise traces from parallel calls may interleave.
        """
        result["retrieved_memories"] = self.agent.get_last_memory_trace()
        return result

    @staticmethod
    def _report_memory_failure(dialog: Dict) -> None:
        """Print the identifiers and full content of a dialog that could not be memorized."""
        import json

        print(dialog.get("test_idx"), dialog.get("dataset"), "failed to memory.")
        print(json.dumps(dialog.get("dialog"), indent=4, ensure_ascii=False))

    def _create_or_load_memory(self, dialogs: List[Dict], can_thread: bool = False):
        """Build the agent's memory from ``dialogs``, or load it if already cached.

        If ``self.memory_cache_dir`` is not marked as cached, every dialog is
        added to the agent's memory, the memories are saved, memory records are
        written, and the directory is marked as cached. Otherwise the memories
        are loaded and the records are written.

        Args:
            dialogs (List[Dict]): Dialog entries. Each must provide ``"dialog"``
                and ``"test_idx"``; ``"dataset"`` is only used when reporting
                failures.
            can_thread (bool): Add dialogs concurrently using up to
                ``MAX_THREADS`` worker threads.

        Raises:
            ValueError: If adding any dialog to memory fails. The failing dialog
                is printed first, and the original error is chained as the cause.
        """
        if if_memory_cached(self.memory_cache_dir):
            print("Loading memory cache from", self.memory_cache_dir)
            self.agent.load_memories()
            self.record_all_memories()
            return

        print(f"Creating memory cache at {self.memory_cache_dir}")
        desc = f"Memorying dialogs with {self.method_name}"
        if can_thread:
            with ThreadPoolExecutor(max_workers=self.MAX_THREADS) as executor:
                # Map each future back to its dialog so a failure can be reported.
                future_to_dialog = {
                    executor.submit(
                        self.agent.add_conversation_to_memory,
                        dialog["dialog"],
                        dialog["test_idx"],
                    ): dialog
                    for dialog in dialogs
                }
                for future in tqdm(
                    as_completed(future_to_dialog),
                    total=len(future_to_dialog),
                    desc=desc,
                    **self._TQDM_KWARGS,
                ):
                    try:
                        future.result()
                    except Exception as exc:
                        self._report_memory_failure(future_to_dialog[future])
                        raise ValueError("Memorying failed.") from exc
        else:
            for dialog in tqdm(dialogs, desc=desc, **self._TQDM_KWARGS):
                try:
                    self.agent.add_conversation_to_memory(dialog["dialog"], dialog["test_idx"])
                except Exception as exc:
                    self._report_memory_failure(dialog)
                    raise ValueError("Memorying failed.") from exc
        self.agent.save_memories()
        self.record_all_memories()
        mark_memory_cached(self.memory_cache_dir)

    def create_or_load_memory(self, dialogs: List[Dict]):
        """Public memory-building hook; does nothing in this base class.

        NOTE(review): presumably overridden by memory-based subclasses, which
        can delegate to ``_create_or_load_memory``. Confirm against subclasses.
        """
        return

    def _predict_in_parallel(self, predict_fn, items, desc: str) -> List[Dict]:
        """Call ``predict_fn`` on every item using up to ``MAX_THREADS`` threads.

        Args:
            predict_fn: Callable taking one item and returning a result dict
                that contains ``"test_idx"``.
            items: Items to predict.
            desc (str): Progress-bar label.

        Returns:
            List[Dict]: All results, sorted by ``"test_idx"``. An exception
            raised by any call propagates out of this method.
        """
        results = []
        with ThreadPoolExecutor(max_workers=self.MAX_THREADS) as executor:
            futures = [executor.submit(predict_fn, item) for item in items]
            for future in tqdm(
                as_completed(futures),
                total=len(futures),
                desc=desc,
                **self._TQDM_KWARGS,
            ):
                results.append(future.result())
        # Completion order is nondeterministic; restore a stable order.
        return sorted(results, key=lambda result: result["test_idx"])

    def predict_single_data(self, dataset: BaseDataset, data) -> Dict:
        """Generate the agent's response for a single data point.

        When the agent config has a ``retrieve_k`` attribute, generation is
        retried with ``retrieve_k`` lowered by one after each failure, until it
        succeeds or ``retrieve_k`` reaches 0. Otherwise generation is attempted
        once. Failures are printed and replaced by an error string rather than
        raised.

        Args:
            dataset (BaseDataset): Provides the initial chat messages.
            data: A single data point with at least ``"test_idx"`` and ``"lang"``.

        Returns:
            Dict: Keys ``"test_idx"``, ``"messages"``, ``"response"`` and
            ``"retrieved_memories"``.
        """
        input_messages = dataset.get_initial_chat_messages(data["test_idx"])
        # Fallback values for the path where no generation attempt is made
        # (e.g. retrieve_k is 0): an untouched prompt and a cleared trace.
        messages = copy.deepcopy(input_messages)
        self.agent.clear_last_memory_trace()
        if hasattr(self.agent.config, "retrieve_k"):
            # `or 0` keeps a None retrieve_k meaning "no attempts", and `> 0`
            # stops a negative starting value from retrying forever.
            retrieve_k = self.agent.config.retrieve_k or 0
            while retrieve_k > 0:
                try:
                    # Fresh copy per attempt so each try starts from the original prompt.
                    messages = copy.deepcopy(input_messages)
                    self.agent.clear_last_memory_trace()
                    response = self.agent.generate_response(
                        messages=messages,
                        lang=data["lang"],
                        retrieve_k=retrieve_k,
                    )
                    break
                except Exception as e:
                    print(e)
                    # Retry with fewer retrieved memories.
                    retrieve_k -= 1
            else:
                # Loop ended without `break`: no attempt succeeded.
                print(f"Failed to generate response for test_idx {data['test_idx']} with retrieve_k down to 0.")
                response = "Error: Unable to generate response."
        else:
            try:
                messages = copy.deepcopy(input_messages)
                self.agent.clear_last_memory_trace()
                response = self.agent.generate_response(
                    messages=messages,
                    lang=data["lang"],
                )
            except Exception as e:
                print(e)
                print(f"Failed to generate response for test_idx {data['test_idx']}.")
                response = "Error: Unable to generate response."
        return self._attach_memory_trace({
            "test_idx": data["test_idx"],
            "messages": messages,
            "response": response,
        })

    def predict_test(self, dataset: BaseDataset, split_name="test") -> List[Dict]:
        """Run ``predict_single_data`` on every data point of ``split_name`` concurrently.

        Returns:
            List[Dict]: One result per data point, sorted by ``"test_idx"``.
        """
        return self._predict_in_parallel(
            lambda data: self.predict_single_data(dataset, data),
            dataset.dataset[split_name].to_list(),
            desc="Predicting tests",
        )

    @staticmethod
    def _build_context_prompt(context: str, question: str, lang: str) -> str:
        """Build the user prompt that puts ``context`` before ``question``.

        Raises:
            ValueError: If ``lang`` is neither ``"en"`` nor ``"zh"``.
        """
        if lang == "en":
            return (
                f"Context:\n{context}\nUser:\n{question}\n"
                "Based on the context provided, respond naturally and appropriately to the user's input above."
            )
        if lang == "zh":
            return (
                f"相关知识:\n{context}\n用户输入:\n{question}\n"
                "请根据提供的相关知识准确、自然地回答用户的输入。"
            )
        raise ValueError(f"Unsupported language {lang!r}; expected 'en' or 'zh'.")

    def predict_test_with_corpus(self, dataset: BaseDataset, split_name="test") -> List[Dict]:
        """Answer each data point with raw corpus sessions pasted into the prompt.

        Instead of memory retrieval, the speaker turns of sessions
        ``session_0 .. session_k`` from ``dataset.corpus`` are concatenated as
        context. ``k`` is first chosen by binary search on the first data point:
        the largest value whose generation succeeds, which fits the context
        window. Each data point then retries with fewer sessions if generation
        still fails.

        NOTE(review): the binary search assumes that success is monotone in the
        number of sessions.

        Args:
            dataset (BaseDataset): Must have a corpus (``has_corpus``) and expose
                ``corpus`` and ``session_cnt``.
            split_name (str): Split of ``dataset.dataset`` to predict.

        Returns:
            List[Dict]: One result per data point, sorted by ``"test_idx"``;
            empty when the split is empty.

        Raises:
            ValueError: If a data point's ``"lang"`` is neither ``"en"`` nor ``"zh"``.
        """
        assert dataset.has_corpus, "Dataset does not have corpus for context window."
        conversation, session_cnt = dataset.corpus, dataset.session_cnt
        test_data = dataset.dataset[split_name].to_list()
        if not test_data:
            # Nothing to predict, and the binary search needs a probe item.
            return []

        def solve_data(data, convers: Dict[str, List[Dict[str, str]]], max_session_idx: int):
            """Query the agent with sessions 0..max_session_idx as context.

            Returns:
                (ok_flag, result): ``ok_flag`` is False when generation raised.
            """
            input_messages = dataset.get_initial_chat_messages(data["test_idx"])
            question = input_messages[-1]["content"]
            context_lines = []
            for idx in range(max_session_idx + 1):
                session_key = f"session_{idx}"
                if session_key in convers:
                    context_lines.extend(
                        f"{turn['speaker']}: {turn['text']}\n" for turn in convers[session_key]
                    )
            user_prompt = self._build_context_prompt("".join(context_lines), question, data["lang"])
            messages = copy.deepcopy(input_messages)
            messages[-1]["content"] = user_prompt
            self.agent.clear_last_memory_trace()
            try:
                response = self.agent.generate_response(
                    messages=messages,
                    lang=data["lang"],
                )
                ok_flag = True
            except Exception as e:
                print(e)
                print(f"Failed to generate response for test_idx {data['test_idx']}.")
                ok_flag = False
                response = "Error: Unable to generate response."
            return ok_flag, self._attach_memory_trace({
                "test_idx": data["test_idx"],
                "messages": messages,
                "response": response,
            })

        # Binary search for the largest session index that still fits the
        # context window, probing with the first data point.
        min_session, max_session = 0, session_cnt
        while min_session < max_session:
            # Round up so the interval always shrinks when min_session moves.
            mid_session = (min_session + max_session + 1) // 2
            ok_flag, _ = solve_data(test_data[0], conversation, mid_session)
            if ok_flag:
                min_session = mid_session
            else:
                max_session = mid_session - 1
        session_cnt = min_session
        print(f"Using top {session_cnt} sessions as context window.")

        def predict_data(data):
            """Predict ``data``, dropping one session at a time until generation succeeds."""
            # session_cnt >= 0 here, so the loop body runs at least once.
            for cur_session_cnt in range(session_cnt, -1, -1):
                ok_flag, result = solve_data(data, conversation, cur_session_cnt)
                if ok_flag:
                    break
            # Either the first successful result or the final (zero-session) failure.
            return result

        return self._predict_in_parallel(predict_data, test_data, desc="Predicting tests wo_memory")