-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenv.py
More file actions
388 lines (319 loc) · 16.1 KB
/
Copy pathenv.py
File metadata and controls
388 lines (319 loc) · 16.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
"""
This module defines the Environment class, which prepares the working directory, log directories, and task prompt that the mlattack agent runs in.
"""
import json
import os
import sys
import subprocess
import shutil
import copy
import time
import fnmatch
import signal
from traceback import format_exception
from multiprocessing import active_children
import readline # to make sure input() works properly
from dacite import from_dict
from low_level_actions import LOW_LEVEL_ACTIONS
from high_level_actions import HIGH_LEVEL_ACTIONS
from schema import Step, Trace, EnvException, TooLongPromptError, LLMError, EnhancedJSONEncoder
from LLM import complete_text_claude
class TimeoutException(Exception):
    """Signals that the run exceeded its time budget (raised from the SIGALRM handler set in Environment.__enter__)."""
class Environment:
    """Prepare and drive the workspace an mlattack agent acts in.

    On construction it loads the task config, renders the task prompt, creates
    the log directories, (re)initializes the work dir from the benchmark's
    ``env`` folder, and builds (or restores) the action trace. Used as a context
    manager it enforces ``args.max_time`` via ``SIGALRM`` and, on exit,
    terminates child processes and writes ``error.txt`` / ``overall_time.txt``.
    """

    # Top-level work-dir folders that survive re-initialization and are copied
    # from a parent environment into each sub-environment.
    _PROTECTED_DIRS = ('eval_datasets', 'models', 'available_datasets')

    def __init__(self, args):
        self._args = args
        self.sub_environments = {}
        self._log_dir = os.path.join(args.log_dir, "env_log")
        with open(args.task_config, "r") as f:
            self._task_config = json.load(f)
        self._benchmark_folder = self._task_config["benchmark_folder"]
        self._prompt_template_path = self._task_config["prompt_template"]
        self._setup_log_dir()
        self._task_prompt = self._read_prompt_template()
        if args.task in ["evasion"]:
            format_args = {
                "target_service": self._task_config["target_service"],
            }
        else:
            format_args = {
                "target_task_description": self._task_config["target_task_description"],
                "target_service": self._task_config["target_service"],
                "input_format": self._task_config["input_format"],
                "output_format": self._task_config["output_format"]
            }
        if args.task in ["attr_infer", "all_in_one"]:
            format_args["target_attribute"] = self._task_config["target_attribute"]
        self._task_prompt = self._task_prompt.format(**format_args)
        self._work_dir = os.path.join(args.work_dir, self._benchmark_folder)
        self._read_only_files = []
        self._initialize_task_env()  # set up work dir and log dir
        self._action_infos = {t.name: t for t in LOW_LEVEL_ACTIONS + HIGH_LEVEL_ACTIONS}
        # NOTE(review): this formats the already-rendered prompt a second time.
        # It only works if the template escapes the placeholder as {{task}} and
        # no substituted config value contains literal braces (e.g. a JSON
        # output_format would raise here) — confirm against the templates.
        self._task_prompt = self._task_prompt.format(task=args.task)
        self._static_kwargs_for_tools = {
            "device": args.device,
            "python": args.python,
            "work_dir": self._work_dir,
            "args": args,
            "read_only_files": self._read_only_files,
            "task_prompt": self._task_prompt,
        }
        self._trace = self._initialize_trace()
        self._start_time = time.time()

    ############################## getters ########################################

    @property
    def task_prompt(self):
        return self._task_prompt

    @property
    def benchmark_folder_name(self):
        return self._benchmark_folder

    @property
    def log_dir(self):
        return self._log_dir

    @property
    def work_dir(self):
        return self._work_dir

    @property
    def read_only_files(self):
        return self._read_only_files

    @property
    def action_infos(self):
        return self._action_infos

    @property
    def args(self):
        return self._args

    @property
    def static_kwargs_for_tools(self):
        return self._static_kwargs_for_tools

    @property
    def trace(self):
        """A deep copy of the trace, so callers cannot mutate the live one."""
        return copy.deepcopy(self._trace)

    @property
    def start_time(self):
        return self._start_time

    ############################## internal functions ########################################

    def _read_prompt_template(self):
        """Return the prompt template file's contents with surrounding whitespace stripped."""
        with open(self._prompt_template_path, 'r') as file:
            return file.read().strip()

    def _setup_log_dir(self):
        """Create ``log_dir`` and its ``tool_logs`` and ``traces`` subfolders if missing."""
        # Check the same directory that is reported and created (not args.log_dir).
        if os.path.exists(self.log_dir):
            print("log_dir {} already exists".format(self.log_dir))
        else:
            os.makedirs(self.log_dir)
        for sub_name, label in (("tool_logs", "tools_log_dir"), ("traces", "traces_dir")):
            sub_dir = os.path.join(self.log_dir, sub_name)
            if os.path.exists(sub_dir):
                print("{} {} already exists".format(label, sub_dir))
            else:
                os.makedirs(sub_dir, exist_ok=True)

    def _initialize_task_env(self):
        """(Re)build the work dir for a fresh or resumed run.

        Clears the work dir except for the protected top-level folders, copies in
        any missing files from ``benchmarks/<benchmark_folder>/env`` (next to this
        file), records files matched by ``scripts/read_only_files.txt`` in
        ``self.read_only_files``, and resets the ``backup`` folder. When
        ``args.resume`` is set, the work dir is then replaced by the snapshot
        saved for ``args.resume_step``.
        """
        work_dir = self._work_dir
        if os.path.exists(work_dir):
            self._clean_work_dir(work_dir)
        else:
            os.makedirs(work_dir)

        benchmark_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "benchmarks", self._benchmark_folder)
        self._copy_benchmark_env(os.path.join(benchmark_dir, "env"), work_dir)

        # find all read only files
        read_only_list = os.path.join(benchmark_dir, "scripts", "read_only_files.txt")
        if os.path.exists(read_only_list):
            self._collect_read_only_files(read_only_list, work_dir)

        # init backup folder and remove all content if it exists
        backup_dir = os.path.join(work_dir, "backup")
        if os.path.exists(backup_dir):
            shutil.rmtree(backup_dir)
        os.mkdir(backup_dir)

        if self.args.resume:
            # NOTE(review): this discards the whole work dir, including the
            # protected folders and the read-only files (save() never snapshots
            # read-only files) — confirm resumed runs do not need them.
            shutil.rmtree(work_dir)
            resume_dir = os.path.join(self.args.resume, "env_log", "traces", f"step_{self.args.resume_step}_files")
            print("Restoring workspace from {}".format(resume_dir))
            shutil.copytree(resume_dir, work_dir, symlinks=True)
            if not os.path.exists(backup_dir):
                os.mkdir(backup_dir)

    @classmethod
    def _clean_work_dir(cls, work_dir):
        """Delete every top-level entry of work_dir except the protected folders (kept whole)."""
        with os.scandir(work_dir) as it:
            entries = list(it)  # materialize before deleting anything
        for entry in entries:
            if entry.name in cls._PROTECTED_DIRS:
                continue
            # rmtree refuses symlinks, so links (even to dirs) are unlinked directly.
            if entry.is_dir(follow_symlinks=False):
                shutil.rmtree(entry.path)
            else:
                os.remove(entry.path)

    @staticmethod
    def _copy_benchmark_env(env_dir, work_dir):
        """Copy the files under env_dir into work_dir, never overwriting existing ones."""
        if not os.path.exists(env_dir):
            return
        for root, _dirs, files in os.walk(env_dir):
            for file in files:
                src_path = os.path.join(root, file)
                dst_path = os.path.join(work_dir, os.path.relpath(src_path, env_dir))
                os.makedirs(os.path.dirname(dst_path), exist_ok=True)
                if not os.path.exists(dst_path):
                    shutil.copy2(src_path, dst_path)

    def _collect_read_only_files(self, patterns_path, work_dir):
        """Append work-dir files matching the fnmatch patterns in patterns_path to read_only_files.

        Paths are relative to work_dir and built exactly like save() builds them
        (top-level files get a ``./`` prefix), so membership tests line up.
        """
        with open(patterns_path, "r") as f:
            patterns = [line.strip() for line in f.read().split("\n")]
        patterns = [p for p in patterns if p]  # skip blank lines / trailing newline
        for path, _subdirs, files in os.walk(work_dir):
            relpath = os.path.relpath(path, work_dir)
            filenames = [os.path.join(relpath, filename) for filename in files]
            for pattern in patterns:
                self._read_only_files.extend(n for n in filenames if fnmatch.fnmatch(n, pattern))

    def _initialize_trace(self):
        """Build an empty Trace, or restore one truncated to ``args.resume_step`` when resuming."""
        if self.args.resume:
            print("Restoring trace from {}".format(self.args.resume))
            with open(os.path.join(self.args.resume, "env_log", "trace.json"), "r") as f:
                prev_trace = from_dict(data_class=Trace, data=json.load(f))
            print("Resetting trace to step {}".format(self.args.resume_step))
            steps = prev_trace.steps[:self.args.resume_step + 1]
            t = steps[-1].timestamp
            # keep only low-level steps recorded strictly before the last kept step
            low_level_steps = [s for s in prev_trace.low_level_steps if s.timestamp < t]
        else:
            steps = []
            low_level_steps = []
        return Trace(
            steps=steps,
            low_level_steps=low_level_steps,
            action_infos=self.action_infos,
            task_description=self.task_prompt,
        )

    def __enter__(self):
        """Arm a SIGALRM that raises TimeoutException after ``args.max_time`` seconds."""
        def signal_handler(signum, frame):
            raise TimeoutException("Timed out!")
        signal.signal(signal.SIGALRM, signal_handler)
        signal.alarm(self.args.max_time)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Cancel the timeout, stop child processes, write logs, then exit sub-environments.

        Returns None, so any in-flight exception keeps propagating.
        """
        # Cancel the pending alarm first: otherwise it can fire during the cleanup
        # below or after the with-block ends, raising TimeoutException elsewhere.
        signal.alarm(0)
        active = active_children()
        print(f'Active Children: {len(active)}')
        # terminate all active children
        for child in active:
            child.terminate()
        # block until all children have closed
        for child in active:
            child.join()
        # report active children
        active = active_children()
        print(f'Active Children: {len(active)}')
        if traceback is not None:
            print("Error message saved in error.txt")
            with open(os.path.join(self.log_dir, "error.txt"), "w") as f:
                f.write(''.join(format_exception(exc_type, exc_value, traceback)))
        with open(os.path.join(self.log_dir, "overall_time.txt"), "w") as f:
            f.write(str(time.time() - self.start_time))
        for sub_env in self.sub_environments.values():
            sub_env.__exit__(exc_type, exc_value, traceback)

    ################################# public functions ########################################

    def get_sub_environments_status(self):
        """Map each attack name to its sub-environment's ``is_final`` flag and step count."""
        return {
            attack: {
                "is_final": sub_env.is_final(),
                "steps": len(sub_env._trace.steps),  # read-only; no deep copy needed
            }
            for attack, sub_env in self.sub_environments.items()
        }

    def create_sub_environment(self, attack_name):
        """Create, register and return an Environment dedicated to one attack.

        Its log and work dirs are nested under this environment's, the protected
        folders are copied over, and its task prompt is suffixed with an
        instruction to perform only ``attack_name``.
        """
        sub_args = copy.deepcopy(self._args)
        sub_args.log_dir = os.path.join(self._log_dir, f"attackagent_{attack_name}")
        sub_args.work_dir = os.path.join(self._work_dir, f"attackagent_{attack_name}")
        sub_env = Environment(sub_args)
        # Copy eval_datasets, models, and available_datasets from main env to sub env
        for folder in self._PROTECTED_DIRS:
            src_folder = os.path.join(self._work_dir, folder)
            dst_folder = os.path.join(sub_env._work_dir, folder)
            if os.path.exists(src_folder):
                shutil.copytree(src_folder, dst_folder, symlinks=True, dirs_exist_ok=True)
        sub_env._task_prompt = self._task_prompt + f"\n\nThe above is the description for the whole task.\nYou are the AttackAgent that are initialized for {attack_name} only.\n Now perform the attack!"
        # The tool kwargs were captured in __init__ from the un-suffixed prompt;
        # keep what tools receive in sync with the sub-agent's prompt.
        sub_env._static_kwargs_for_tools["task_prompt"] = sub_env._task_prompt
        # NOTE(review): sub_env._trace.task_description still holds the
        # un-suffixed prompt — confirm whether trace.json should record the suffix.
        self.sub_environments[attack_name] = sub_env
        return sub_env

    def is_final(self):
        """Check if the task has reached a final state, either by reaching the maximum steps or time, or because the agent has submitted a final answer. """
        steps = self._trace.steps  # read-only access; avoids a deep copy per call
        any_final_answer = any(s.action.name == "Final Answer" for s in steps)
        return len(steps) >= self.args.max_steps or any_final_answer or time.time() - self.start_time > self.args.max_time

    def execute(self, action):
        """Execute an action, append the resulting Step to the trace, save, and return the observation."""
        trace = self._trace
        curr_step = len(trace.steps)
        action_name = action.name
        action_input = action.args
        if action_name == "Final Answer":
            observation = "end"
        elif self.is_final():
            observation = "The environment has shut down because the maximum number of steps or time has been reached. Please submit your final answer."
        elif action_name not in self.action_infos:
            actions = ", ".join(self.action_infos.keys())
            observation = f"Invalid action: {action_name}. Action did not execute. Please use one of the following actions:\n{actions}"
        else:
            # execute the action and get the observation
            log_file = os.path.join(self.log_dir, "tool_logs", f"step_{curr_step}_tool_log.log")
            usage = ",\n ".join([f"{k}: [{v}]" for k, v in self.action_infos[action_name].usage.items()])
            usage = f"{{\n{usage}\n}}"
            invalid_action_error = f"The action input for {action_name} needs to be a valid json with proper entries. You may have missed the comma between entries. Please use the correct format and try again:\n{usage}"
            if isinstance(action_input, dict):
                try:
                    observation = self.action_infos[action_name].function(**action_input, log_file=log_file, trace=trace, **self.static_kwargs_for_tools)
                except TooLongPromptError:
                    observation = "EnvError: too long input for the tool"
                except LLMError as e:
                    observation = "LLMError: " + e.message
                except EnvException as e:
                    observation = "EnvError: " + e.message
                except TypeError as e:
                    # most often wrong/missing keyword arguments in action_input
                    print("Step: ", curr_step, file=sys.stderr)
                    print(e, file=sys.stderr)
                    print(action_input, file=sys.stderr)
                    observation = "EnvError: " + invalid_action_error
                except TimeoutException:
                    # the global time budget must abort the run, not become an observation
                    raise
                except Exception as e:
                    # should not happen
                    import traceback
                    error_traceback = traceback.format_exc()
                    print("Step: ", curr_step, file=sys.stderr)
                    print(error_traceback, file=sys.stderr)
                    observation = f"EnvError: Error executing {action_name}. Error details: {str(e)}\nTraceback:\n{error_traceback}"
            else:
                observation = invalid_action_error
        step_time = time.time()
        trace.steps.append(Step(action, observation, step_time))
        self.save(curr_step)
        return observation

    def save(self, curr_step):
        """ Save the trace and snapshot of the workspace folder """
        with open(os.path.join(self.log_dir, "trace.json"), "w") as f:
            json.dump(self._trace, f, indent=4, cls=EnhancedJSONEncoder)
        # Save sub-environments
        for sub_env in self.sub_environments.values():
            sub_env.save(curr_step)
        ##### save a snapshot of the current step
        save_folder = os.path.join(self.log_dir, f"traces/step_{curr_step}_files")
        if os.path.exists(save_folder):
            shutil.rmtree(save_folder)
        os.makedirs(save_folder)
        read_only = set(self._read_only_files)
        # Files under the run's log root are never snapshotted (guards against the
        # log dir living inside the work dir). Hoisted out of the per-file loop.
        log_root = os.path.abspath(self.log_dir.split("/env_log")[0])
        # save files in the folder that are not read only
        for path, _subdirs, files in os.walk(self.work_dir):
            relpath = os.path.relpath(path, self.work_dir)
            dest = os.path.join(save_folder, relpath)
            for file_name in files:
                file_path = os.path.join(relpath, file_name)
                if file_path in read_only:
                    continue
                abs_path = os.path.abspath(os.path.join(self.work_dir, file_path))
                # path-component containment, so /x/logs does not match /x/logs2/...
                if os.path.commonpath([abs_path, log_root]) == log_root:
                    continue
                os.makedirs(dest, exist_ok=True)
                shutil.copyfile(os.path.join(self.work_dir, file_path), os.path.join(save_folder, file_path))

    ############## for logging convenience ##############

    def get_task_description(self):
        return self._task_prompt, self._benchmark_folder

    @property
    def low_level_actions(self):
        return list(filter(lambda x: x.is_primitive, self.action_infos.values()))

    @property
    def high_level_actions(self):
        return list(filter(lambda x: not x.is_primitive, self.action_infos.values()))

    def print_action(self, entries):
        return "".join([k + ": " + v for k, v in entries.items()])