forked from ml-research/blendrl-dev
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathevaluate_baselines.py
More file actions
157 lines (142 loc) · 13.5 KB
/
Copy pathevaluate_baselines.py
File metadata and controls
157 lines (142 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from aiofiles.os import path
from statistics import median_grouped
from evaluate import main as evaluate
import numpy as np
def _parse_run_key(run):
    """Decode an evaluation key of the form ``<method>_<env>_<seed>_<modification>``.

    The environment name may itself contain underscores (e.g. ``kangaroo_jax``);
    the modification is written with dashes so that it stays a single
    underscore-separated field.

    Args:
        run: Key such as ``"nudge_kangaroo_jax_0_center-ladders"``.

    Returns:
        Tuple ``(env_name, seed, modification, detect_all_enemies)`` where
        ``modification`` is ``None`` for the literal ``"None"`` suffix and has
        its dashes replaced by underscores otherwise.

    Raises:
        ValueError: If the seed field is not an integer.
    """
    fields = run.split("_")
    env_name = "_".join(fields[1:-2])  # drop the method tag, the seed and the modification
    seed = int(fields[-2])
    modification = fields[-1].replace("-", "_")

    # Standard setting: objects are not detected as enemies.
    detect_all_enemies = False
    if modification == "mines_detect":
        # "mines-detect" is the "mines" modification with enemy detection switched on.
        modification = "mines"
        detect_all_enemies = True
    elif modification == "None":
        modification = None
    return env_name, seed, modification, detect_all_enemies


def _mean_std(values):
    """Return ``(mean, std)`` of ``values``, or ``(nan, nan)`` if ``values`` is empty.

    Guarding the empty case avoids the RuntimeWarnings that ``np.mean`` and
    ``np.std`` emit for empty input while still reporting ``nan``.
    """
    if not values:
        return float("nan"), float("nan")
    return np.mean(values), np.std(values)


def main():
    """Evaluate the NUDGE, BlendRL and NLRL checkpoints and print score statistics.

    Every (method, environment, seed, modification) combination listed in
    ``sweeps`` is passed to ``evaluate`` with 3 episodes on the CPU. The first
    and fifth values returned by ``evaluate`` are collected as score and
    aligned score; their mean and standard deviation over all evaluated runs
    are printed at the end.
    """
    # Checkpoint directories of the runs that are evaluated.
    nudge_kangaroo = "out_nudge/runs/kangaroo_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_2_20251118_132417"
    blendrl_kangaroo = "out/runs/kangaroo_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__0_20260119_161001"
    nlrl_kangaroo = "out_nlrl/runs/kangaroo_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_256_steps_128_0_20260123_184901"

    # Other trained checkpoints, currently not evaluated (kept for reference).
    # NUDGE
    #   "kangaroo_jax_0": "out_nudge/runs/kangaroo_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_0",
    #   "kangaroo_jax_1": "out_nudge/runs/kangaroo_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_1_20251118_113300",
    #   "kangaroo_jax_2": "out_nudge/runs/kangaroo_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_2_20251118_132417",
    #   "seaquest_jax_0": "out_nudge/runs/seaquest_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_0",
    #   "seaquest_jax_1": "out_nudge/runs/seaquest_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_1_20251118_113234",
    #   "seaquest_jax_2": "out_nudge/runs/seaquest_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_2_20251118_140233",
    # BLENDRL
    #   "kangaroo_jax_0": "out/runs/kangaroo_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__0_20260119_161001",
    #   "kangaroo_jax_1": "out/runs/kangaroo_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__1_20251118_140348",
    #   "kangaroo_jax_2": "out/runs/kangaroo_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__2_20251118_222638",
    #   "seaquest_jax_0": "out/runs/seaquest_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__0_20260119_171753",
    #   "seaquest_jax_1": "out/runs/seaquest_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__1_20251118_113132",
    #   "seaquest_jax_2": "out/runs/seaquest_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__2_20251118_140442",

    # For ASTRA
    # Need:
    # - nudge and blendrl:
    #   - kangaroo: default, center_ladders, four_ladders, flame_trap, cactus_trap, danger_trap, tanks, snakes, dragons, replace_coconut_fireball, replace_coconut_honey_bee, replace_coconut_wasp
    #   - seaquest: default, fireballs, mines, no_divers

    # Modification names use dashes; they are converted to underscores when the
    # run key is parsed. "None" means the unmodified environment.
    kangaroo_baseline = ("None", "no-falling-coconut")
    kangaroo_variants = (
        "center-ladders",
        "four-ladders",
        "flame-trap",
        "cactus-trap",
        "danger-trap",
        "tanks",
        "snakes",
        "dragons",
        "replace-coconut-fireball",
        "replace-coconut-honey-bee",
        "replace-coconut-wasp",
    )
    # Uncomment together with the Seaquest sweeps below.
    # seaquest_mods = ("fireballs", "mines", "mines-detect", "no-divers", "None")

    # (method tag, "<env>_<seed>", checkpoint directory, modifications to evaluate).
    # Runs are evaluated in exactly this order.
    sweeps = (
        # ("nudge", "seaquest_jax_0", "out_nudge/runs/seaquest_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_512_steps_128_2_20251118_140233", seaquest_mods),
        ("nudge", "kangaroo_jax_0", nudge_kangaroo, kangaroo_variants + kangaroo_baseline),
        # ("blendrl", "seaquest_jax_0", "out/runs/seaquest_jax_softmax_blender_logic_lr_0.00025_llr_0.00025_blr_0.00025_gamma_0.99_bentcoef_0.01_numenvs_512_steps_128__0_20260119_171753", seaquest_mods),
        ("blendrl", "kangaroo_jax_0", blendrl_kangaroo, kangaroo_baseline + kangaroo_variants),
        # ("nlrl", "seaquest_jax_0", "out_nlrl/runs/seaquest_jax_softmax_lr_0.00025_llr_0.00025_gamma_0.99_numenvs_256_steps_128_0_20260124_003447", seaquest_mods),
        ("nlrl", "kangaroo_jax_0", nlrl_kangaroo, kangaroo_variants + kangaroo_baseline),
    )

    # Expand the sweeps into "<method>_<env>_<seed>_<modification>" -> checkpoint directory.
    evals = {}
    for method, env_and_seed, checkpoint_dir, modifications in sweeps:
        for modification in modifications:
            evals["_".join((method, env_and_seed, modification))] = checkpoint_dir

    scores = []
    aligned_scores = []
    # The separate "modified env" evaluation pass is currently disabled, so
    # these two lists stay empty and the corresponding statistics are nan.
    mod_scores = []
    aligned_mod_scores = []

    for run, checkpoint_dir in evals.items():
        print("Evaluating run:", run)
        env_name, seed, modification, detect_all_enemies = _parse_run_key(run)

        mod_list = []
        if modification is not None:
            mod_list.append(modification)
        # always use first_level_only to avoid stochasticity from later levels
        mod_list.append("first_level_only")

        score, _, _, _, aligned_score, _ = evaluate(
            env_name,
            checkpoint_dir,
            episodes=3,
            seed=seed,
            modified_env=mod_list,
            device="cpu",
            detect_all_enemies=detect_all_enemies,
        )
        # Report each run individually: the aggregate below mixes every method
        # and modification and cannot be attributed to a single run.
        print("Run", run, "score:", score, "aligned score:", aligned_score)
        scores.append(score)
        aligned_scores.append(aligned_score)

    mean_score, std_score = _mean_std(scores)
    mean_aligned_score, std_aligned_score = _mean_std(aligned_scores)
    mean_mod_score, std_mod_score = _mean_std(mod_scores)
    mean_aligned_mod_score, std_aligned_mod_score = _mean_std(aligned_mod_scores)

    print("Results over different seeds:")
    print("Standard Env Score:", mean_score, "+-", std_score)
    print("Standard Env Aligned Score:", mean_aligned_score, "+-", std_aligned_score)
    print("Modified Env Score:", mean_mod_score, "+-", std_mod_score)
    print("Modified Env Aligned Score:", mean_aligned_mod_score, "+-", std_aligned_mod_score)
    # Single-line summary of all eight statistics.
    print(mean_score, std_score, mean_mod_score, std_mod_score, mean_aligned_score, std_aligned_score, mean_aligned_mod_score, std_aligned_mod_score)
if __name__ == "__main__":
    # Run the evaluation sweep only when this file is executed as a script,
    # not when it is imported.
    main()