-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathvisualize.py
More file actions
80 lines (64 loc) · 2.36 KB
/
Copy pathvisualize.py
File metadata and controls
80 lines (64 loc) · 2.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import myosuite
import imageio
import numpy as np
from stable_baselines3 import PPO
from reward_tutorial import CustomRewardPoseEnv
from gymnasium.wrappers import TimeLimit
def visualize_and_save_video():
    """Roll out one episode of the trained PPO agent and save it as an MP4.

    The checkpoint ``models/ppo_myofinger_custom_reward.zip`` is loaded into a
    ``CustomRewardPoseEnv`` wrapped in a 100-step ``TimeLimit``. One off-screen
    frame (640x480) is rendered before every policy step, and the collected
    frames are written to ``video/episode_custom_reward.mp4`` at 30 fps.

    Returns:
        None. Progress and failures are reported on stdout:

        * If the checkpoint file is missing, an error is printed and the
          function returns before creating anything on disk.
        * If rendering a frame fails, the error is printed and no video is
          written.

    The environment is closed on every exit path, including when loading the
    model, predicting, or stepping raises; such exceptions still propagate to
    the caller.
    """
    # We will use the model trained on our custom environment
    model_path = "models/ppo_myofinger_custom_reward.zip"
    video_dir = "video"

    if not os.path.exists(model_path):
        print(f"Error: Model not found at {model_path}. Please run train.py first.")
        return

    os.makedirs(video_dir, exist_ok=True)
    video_path = os.path.join(video_dir, "episode_custom_reward.mp4")

    # Use myosuite.__file__ to get the exact path to the package directory
    myosuite_dir = os.path.dirname(myosuite.__file__)
    model_path_env = os.path.join(
        myosuite_dir, "simhive", "myo_sim", "finger", "myofinger_v0.xml"
    )

    base_env = CustomRewardPoseEnv(
        model_path=model_path_env,
        target_jnt_range={
            "IFadb": (0, 0),
            "IFmcp": (0, 0),
            "IFpip": (0.75, 0.75),
            "IFdip": (0.75, 0.75),
        },
        viz_site_targets=("IFtip",),
        normalize_act=True,
        use_muscle_noise=False,
    )

    # We must wrap it exactly as we did in training!
    env = TimeLimit(base_env, max_episode_steps=100)

    frames = []
    # Everything that touches the live environment runs under try/finally so
    # the simulator is released even if loading, predicting or stepping raises.
    try:
        model = PPO.load(model_path, env=env)
        obs, _info = env.reset()

        print("Running an episode with the custom-trained agent...")

        # Safety cap only: the TimeLimit wrapper above is configured with
        # max_episode_steps=100, so the loop normally ends via `truncated`
        # (or `terminated`) before this bound is reached.
        max_steps = 200
        for _ in range(max_steps):
            try:
                # We must access the unwrapped environment's renderer
                frame = env.unwrapped.mj_renderer.render_offscreen(
                    width=640,
                    height=480,
                    camera_id=-1,
                )
                frames.append(np.asarray(frame, dtype=np.uint8))
            except Exception as e:
                # Best-effort: a rendering problem (e.g. no usable offscreen
                # backend) aborts the run with a message instead of a
                # traceback. The finally clause below closes the environment.
                print(f"Rendering failed: {e}")
                return

            action, _states = model.predict(obs, deterministic=True)
            obs, _reward, terminated, truncated, _info = env.step(action)

            if terminated or truncated:
                break
    finally:
        env.close()

    if frames:
        print(f"Saving video to {video_path}...")
        imageio.mimsave(video_path, frames, fps=30)
        print("Video saved successfully!")
    else:
        print("Failed to capture any frames for the video.")
# Script entry point: record the episode video only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    visualize_and_save_video()