-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathplot_n_step_errors.py
More file actions
90 lines (84 loc) · 3.37 KB
/
plot_n_step_errors.py
File metadata and controls
90 lines (84 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from matplotlib import pyplot as plt
import numpy as np
import os
from strictfire import StrictFire
from plot_gym_training_progress import make_legend_pickable
# Paper-figure display settings: raw world-model name -> (legend label,
# line color, line style).  A value of None hides that model from the paper
# figure.  Models not listed keep their raw name and the default style cycle.
_PAPER_STYLES = {
    "RSSMWorldModel": None,
    "RSSMA0WorldModel": None,
    "GreyDummyWorldModel": ("Arbitrary-guess Upper Bound", "k", "--"),
    "DummyWorldModel": ("Static Upper Bound", "k", ":"),
    "GPT": ("Transformer (z=64)", None, None),
    "RSSMA0ExplicitWorldModel": ("RSSM (z=1024)", None, None),
    "TSSMWorldModel": ("TSSM (z=1024)", None, None),
    "TransformerLWorldModel": ("Transformer (z=1024)", None, None),
}


def _load_errors(directory):
    """Load the error arrays of every ``*_n_step_errors.npz`` file in *directory*.

    Returns a dict mapping the world-model name (the file name with the
    ``_n_step_errors.npz`` suffix removed) to a dict with the arrays stored
    under the ``"obs_error"`` and ``"vecobs_error"`` keys of that file.
    """
    suffix = "_n_step_errors.npz"
    directory = os.path.expanduser(directory)
    errors = {}
    # os.listdir returns entries in arbitrary order; sorting makes the plot
    # (line order, cycle colours, legend order) reproducible between runs.
    for file in sorted(os.listdir(directory)):
        if not file.endswith(suffix):
            continue
        # An .npz load returns an NpzFile that keeps the archive open;
        # the context manager closes it once the arrays have been read.
        with np.load(os.path.join(directory, file)) as data:
            errors[file[: -len(suffix)]] = {
                "obs_error": data["obs_error"],
                "vecobs_error": data["vecobs_error"],
            }
    return errors


def _paper_style(wm_name):
    """Return ``(label, color, linestyle)`` for the paper figure, or None to hide."""
    return _PAPER_STYLES.get(wm_name, (wm_name, None, None))


def main(paper=False, directory="~/uplink/W_n_step_errors_test_dataset"):
    """Plot the mean n-step prediction errors of each world model.

    For every ``<wm_name>_n_step_errors.npz`` file in *directory*, plots the
    NaN-ignoring mean over axis 0 of its ``obs_error`` array on the left axes
    and of its ``vecobs_error`` array on the right axes, then shows the figure
    with a legend made pickable via ``make_legend_pickable``.

    Args:
        paper: If True, build the publication figure.  Some models are hidden,
            the rest get display names and fixed styles.  The axes are also
            titled and labelled, and the x-range is clipped to [0, 48].
            Otherwise every model found is plotted under its raw name.
        directory: Folder holding the ``*_n_step_errors.npz`` files; ``~`` is
            expanded.  Defaults to the path that used to be hard-coded.
    """
    errors = _load_errors(directory)

    fig, (ax1, ax2) = plt.subplots(1, 2, num="n_step_errors" if paper else "errors")
    # ax1.set_yscale('log')
    # ax2.set_yscale('log')
    linegroups = []
    legends = []
    for wm_name, wm_errors in errors.items():
        if paper:
            style_spec = _paper_style(wm_name)
            if style_spec is None:
                continue  # model deliberately left out of the paper figure
            label, color, linestyle = style_spec
        else:
            # color/linestyle None -> matplotlib's default property cycle.
            label, color, linestyle = wm_name, None, None
        line1, = ax1.plot(np.nanmean(wm_errors["obs_error"], axis=0),
                          color=color, linestyle=linestyle)
        line2, = ax2.plot(np.nanmean(wm_errors["vecobs_error"], axis=0),
                          color=color, linestyle=linestyle)
        # Both lines of a model are grouped so one legend entry toggles both.
        linegroups.append([line1, line2])
        legends.append(label)

    legend_kwargs = {}
    if paper:
        ax1.set_title("Image Observation")
        ax1.set_xlabel("Dream Length")
        ax1.set_ylabel("MSE Pixel-wise Prediction Error")
        ax2.set_title("Vector Observation")
        ax2.set_xlabel("Dream Length")
        ax2.set_ylabel("MSE Prediction Error [m]")
        ax1.set_xlim([0, 48])
        ax2.set_xlim([0, 48])
        legend_kwargs["bbox_to_anchor"] = (0.9, 0.9)

    L = fig.legend([lines[0] for lines in linegroups], legends, **legend_kwargs)
    make_legend_pickable(L, linegroups)
    plt.show()
if __name__ == "__main__":
    # Command-line entry point: hand main() to StrictFire, which presumably
    # exposes its keyword arguments (e.g. ``paper``) as CLI flags — confirm
    # against the strictfire docs.
    StrictFire(main)