-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathplot_trails.py
More file actions
56 lines (50 loc) · 2.05 KB
/
plot_trails.py
File metadata and controls
56 lines (50 loc) · 2.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
from strictfire import StrictFire
from stable_baselines3 import PPO
from tqdm import tqdm
from navdreams.navrep3dtrainencodedenv import EncoderObsWrapper
from navdreams.navrep3danyenv import NavRep3DAnyEnvDiscrete
from plot_gym_training_progress import get_variant
# Checkpoint evaluated when no --model_path is given (GPT backend, V_ONLY encoding).
DEFAULT_MODEL_PATH = "~/navdreams_data/results/models/gym/navrep3daltencodedenv_2021_12_15__08_43_12_DISCRETE_PPO_GPT_V_ONLY_V64M64_SCR_bestckpt.zip"  # noqa
# An episode counts as a success when the reward of its final step exceeds this value.
SUCCESS_REWARD_THRESHOLD = 50.


def resolve_model_path(model_path=None):
    """Return ``model_path`` (or the default checkpoint) with ``~`` expanded.

    Args:
        model_path: path to a saved PPO ``.zip`` checkpoint, or None to use
            ``DEFAULT_MODEL_PATH``.

    Returns:
        The path string after ``os.path.expanduser``.
    """
    if model_path is None:
        model_path = DEFAULT_MODEL_PATH
    return os.path.expanduser(model_path)


def main(build_name="./alternate.x86_64", difficulty_mode="random", model_path=None,
         render=True, n_episodes=1000, backend="GPT", encoding="V_ONLY"):
    """Roll out a trained PPO policy and report its success rate.

    Loads a stable-baselines3 PPO checkpoint, wraps a NavRep3D discrete env in
    the encoder observation wrapper, and runs ``n_episodes`` deterministic
    episodes. Trajectory renders are saved to file (except for the first
    episode) when ``render`` is True.

    Args:
        build_name: simulator build passed to ``NavRep3DAnyEnvDiscrete``;
            the special value ``"rosbag"`` renders with the predicted action
            overlaid.
        difficulty_mode: passed through to ``NavRep3DAnyEnvDiscrete``.
        model_path: PPO checkpoint to load; defaults to ``DEFAULT_MODEL_PATH``.
        render: save per-step renders to file and print per-episode outcomes.
        n_episodes: number of evaluation episodes.
        backend: encoder backend for ``EncoderObsWrapper``; must match the one
            the checkpoint was trained with.
        encoding: encoder output type for ``EncoderObsWrapper``; must match
            the checkpoint.
    """
    model_path = resolve_model_path(model_path)
    model = PPO.load(model_path)
    print("Loaded {}".format(model_path))
    # The encoder variant is parsed from the checkpoint filename.
    variant = get_variant(os.path.basename(model_path))
    env = NavRep3DAnyEnvDiscrete(build_name=build_name,
                                 debug_export_every_n_episodes=0,
                                 difficulty_mode=difficulty_mode,
                                 render_trajectories=True)
    env = EncoderObsWrapper(env, backend=backend, encoding=encoding, variant=variant)
    successes = []
    try:
        # NOTE(review): kept from the original; each episode resets again
        # below, so this looks like a simulator warm-up — confirm it is needed.
        env.reset()
        pbar = tqdm(range(n_episodes))
        for i in pbar:
            obs = env.reset()
            while True:
                action, _ = model.predict(obs, deterministic=True)
                obs, reward, done, _ = env.step(action)
                # Episode 0 is deliberately never rendered (original behavior).
                if render and i != 0:
                    if build_name == "rosbag":
                        env.render(save_to_file=True, action_override=action)
                    else:
                        env.render(save_to_file=True)
                if done:
                    success = reward > SUCCESS_REWARD_THRESHOLD
                    if render:
                        # pbar.write keeps the progress bar intact, unlike print().
                        pbar.write("Success!" if success else "Failure.")
                    successes.append(1. if success else 0.)
                    pbar.set_description(f"Success rate: {sum(successes)/len(successes):.2f}")
                    break
    finally:
        # Release the simulator process even if the rollout loop raises.
        env.close()
    if successes:
        print(f"Final success rate over {len(successes)} episodes: "
              f"{sum(successes)/len(successes):.2f}")


if __name__ == "__main__":
    StrictFire(main)