Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions examples/go2_pos_command/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Go2 Position Command Example

A simple program that teaches the Go2 robot to go to a target position in the world frame.

This example uses the Genesis Forge managed environment setup, which lets the environment focus on scene setup
and reward shaping rather than on the logic for domain randomization and logging.

## Training

This will be trained using the [rsl_rl](https://github.com/leggedrobotics/rsl_rl) training library. So first, we need to install that and tensorboard:

```bash
pip install tensorboard "rsl-rl-lib>=2.2.4"
```

Now you can run the training with:

```bash
python ./train.py
```


You can view the training progress with:

```bash
tensorboard --logdir ./logs/
```

The Genesis Forge training environment will also save videos while training that can be viewed in `./logs/go2-pos-command/videos`.


## Evaluation

Now you can view the trained policy:

```bash
python ./eval.py -e go2-pos-command
```
237 changes: 237 additions & 0 deletions examples/go2_pos_command/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""
Simplified Go2 Locomotion Environment using managers to handle everything.
"""

import torch
import genesis as gs

from genesis_forge import ManagedEnvironment
from genesis_forge.managers import (
RewardManager,
TerminationManager,
EntityManager,
ObservationManager,
ActuatorManager,
PositionActionManager,
PositionCommandManager
)
from genesis_forge.mdp import reset, rewards, terminations


# Position given to the robot both when its URDF is added to the scene and on every reset.
INITIAL_BODY_POSITION = [0.0, 0.0, 0.4]
# Orientation given to the robot at spawn and on reset
# (presumably an identity quaternion in w, x, y, z order -- TODO confirm against Genesis).
INITIAL_QUAT = [1.0, 0.0, 0.0, 0.0]
# NOTE(review): not referenced anywhere in the visible part of this file -- confirm whether it is still needed.
TARGET_X_VELOCITY = 0.5


class Go2PosEnv(ManagedEnvironment):
    """
    Go2 training environment driven by a position command.

    The scene (ground plane, Go2 robot and a camera) is assembled in
    `__init__`; the managers for resets, commands, actions, rewards,
    terminations and observations are created in `config`.
    """

    def __init__(
        self,
        num_envs: int = 1,
        dt: float = 1 / 50,  # control frequency on real robot is 50hz
        max_episode_length_s: int | None = 20,
        headless: bool = True,
    ):
        """
        Create the scene and add the terrain, the robot and the camera to it.

        Args:
            num_envs: Forwarded to `ManagedEnvironment` as `num_envs`.
            dt: Forwarded to `ManagedEnvironment`; `self.dt` is then used as
                the simulation and rigid-solver time step.
            max_episode_length_s: Forwarded to `ManagedEnvironment` as
                `max_episode_length_sec`.
            headless: When False, the scene is created with the viewer shown.
        """
        super().__init__(
            num_envs=num_envs,
            dt=dt,
            max_episode_length_sec=max_episode_length_s,
            max_episode_random_scaling=0.1,
        )

        # Scene options, built up front so the Scene call itself stays short
        sim_options = gs.options.SimOptions(dt=self.dt, substeps=2)
        viewer_options = gs.options.ViewerOptions(
            max_FPS=int(0.5 / self.dt),
            camera_pos=(2.0, 0.0, 2.5),
            camera_lookat=(0.0, 0.0, 0.5),
            camera_fov=40,
        )
        # Only the first environment is rendered
        vis_options = gs.options.VisOptions(rendered_envs_idx=[0])
        rigid_options = gs.options.RigidOptions(
            dt=self.dt,
            constraint_solver=gs.constraint_solver.Newton,
            enable_collision=True,
            enable_joint_limit=True,
            # This locomotion policy usually produces no more than 30 collision
            # pairs, so a low cap saves memory.
            max_collision_pairs=30,
        )
        self.scene = gs.Scene(
            show_viewer=not headless,
            sim_options=sim_options,
            viewer_options=viewer_options,
            vis_options=vis_options,
            rigid_options=rigid_options,
        )

        # Terrain: a flat plane
        ground = gs.morphs.Plane()
        self.terrain = self.scene.add_entity(ground)

        # Robot: the Go2 URDF, placed at the shared initial pose
        go2_morph = gs.morphs.URDF(
            file="urdf/go2/urdf/go2.urdf",
            pos=INITIAL_BODY_POSITION,
            quat=INITIAL_QUAT,
        )
        self.robot = self.scene.add_entity(go2_morph)

        # Camera, for headless video recording
        self.camera = self.scene.add_camera(
            pos=(-2.5, -1.5, 1.0),
            lookat=(0.0, 0.0, 0.0),
            res=(1280, 720),
            fov=40,
            env_idx=0,
            debug=True,
        )

    def config(self):
        """
        Create the environment managers.

        The managers are constructed in a fixed order because later ones are
        handed references to earlier ones (for example the reward terms use the
        robot, command and action managers).
        """
        ##
        # Robot manager
        # i.e. how the robot is put back when an environment is reset
        reset_terms = {
            # Return the robot to its initial pose
            "position": {
                "fn": reset.position,
                "params": {
                    "position": INITIAL_BODY_POSITION,
                    "quat": INITIAL_QUAT,
                    "zero_velocity": True,
                },
            },
        }
        self.robot_manager = EntityManager(
            self,
            entity_attr="robot",
            on_reset=reset_terms,
        )

        ##
        # Position command
        target_range = {
            "x": (-10.0, 10.0),
            "y": (-10.0, 10.0),
            "z": (0.29, 0.31),
        }
        self.position_command_manager = PositionCommandManager(
            self,
            debug_visualizer=True,
            range=target_range,
        )

        ##
        # Joint Actions
        leg_joint_patterns = [
            "FL_.*_joint",
            "FR_.*_joint",
            "RL_.*_joint",
            "RR_.*_joint",
        ]
        default_joint_positions = {
            ".*_hip_joint": 0.0,
            "FL_thigh_joint": 0.8,
            "FR_thigh_joint": 0.8,
            "RL_thigh_joint": 1.0,
            "RR_thigh_joint": 1.0,
            ".*_calf_joint": -1.5,
        }
        self.actuator_manager = ActuatorManager(
            self,
            joint_names=leg_joint_patterns,
            default_pos=default_joint_positions,
            kp=20,
            kv=0.5,
        )
        self.action_manager = PositionActionManager(
            self,
            scale=0.25,
            clip=(-100.0, 100.0),
            use_default_offset=True,
            actuator_manager=self.actuator_manager,
        )

        ##
        # Rewards
        reward_terms = {
            "tracking_lin_vel": {
                "weight": 1.0,
                "fn": rewards.command_tracking_position,
                "params": {
                    "position_cmd_manager": self.position_command_manager,
                    "entity_manager": self.robot_manager,
                },
            },
            "lin_vel_z": {
                "weight": -1.0,
                "fn": rewards.lin_vel_z_l2,
                "params": {
                    "entity_manager": self.robot_manager,
                },
            },
            "action_rate": {
                "weight": -0.005,
                "fn": rewards.action_rate_l2,
            },
            "similar_to_default": {
                "weight": -0.1,
                "fn": rewards.dof_similar_to_default,
                "params": {
                    "action_manager": self.action_manager,
                },
            },
        }
        RewardManager(
            self,
            logging_enabled=True,
            cfg=reward_terms,
        )

        ##
        # Termination conditions
        termination_terms = {
            # The episode ran out of time
            "timeout": {
                "fn": terminations.timeout,
                "time_out": True,
            },
            # The body orientation exceeded the allowed angle
            # (NOTE(review): confirm the unit expected for limit_angle)
            "fall_over": {
                "fn": terminations.bad_orientation,
                "params": {
                    "limit_angle": 10.0,
                    "entity_manager": self.robot_manager,
                },
            },
        }
        self.termination_manager = TerminationManager(
            self,
            logging_enabled=True,
            term_cfg=termination_terms,
        )

        ##
        # Observations
        # Each term reads from a manager on `self`; the `env` argument is unused.
        observation_terms = {
            "angle_velocity": {
                "fn": lambda env: self.robot_manager.get_angular_velocity(),
                "scale": 0.25,
            },
            "linear_velocity": {
                "fn": lambda env: self.robot_manager.get_linear_velocity(),
                "scale": 2.0,
            },
            "projected_gravity": {
                "fn": lambda env: self.robot_manager.get_projected_gravity(),
            },
            "dof_position": {
                "fn": lambda env: self.action_manager.get_dofs_position(),
            },
            "dof_velocity": {
                "fn": lambda env: self.action_manager.get_dofs_velocity(),
                "scale": 0.05,
            },
            "actions": {
                "fn": lambda env: self.action_manager.get_actions(),
            },
        }
        ObservationManager(
            self,
            cfg=observation_terms,
        )

    def build(self):
        """Build the environment, then make the camera follow the robot."""
        super().build()
        self.camera.follow_entity(self.robot)
89 changes: 89 additions & 0 deletions examples/go2_pos_command/eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import os
import glob
import torch
import pickle
import argparse
from importlib import metadata
import genesis as gs

from genesis_forge.wrappers import RslRlWrapper
from environment import Go2PosEnv

# Verify the installed rsl_rl distribution before importing from it, so that
# an unsupported install fails with an actionable message.
try:
    try:
        # Any installed "rsl-rl" distribution is rejected outright.
        if metadata.version("rsl-rl"):
            raise ImportError
    except metadata.PackageNotFoundError:
        # "rsl-rl" is absent: "rsl-rl-lib" must be installed and must not be a 1.x release.
        # NOTE(review): only 1.x is rejected here, although the message asks for >=2.2.4.
        if metadata.version("rsl-rl-lib").startswith("1."):
            raise ImportError
except (metadata.PackageNotFoundError, ImportError) as e:
    raise ImportError("Please install 'rsl-rl-lib>=2.2.4'.") from e
from rsl_rl.runners import OnPolicyRunner

# Default experiment name; main() looks for its logs under ./logs/<exp_name>.
EXPERIMENT_NAME = "go2-pos-command"

# Command-line options, parsed once at import time and read by main().
parser = argparse.ArgumentParser(add_help=True)
parser.add_argument(
    "-d",
    "--device",
    default="gpu",
    type=str,
)
parser.add_argument(
    "-e",
    "--exp_name",
    default=EXPERIMENT_NAME,
    type=str,
)
args = parser.parse_args()


def get_latest_model(log_dir: str) -> str:
    """
    Return the path of the highest-numbered checkpoint in the log directory.

    Checkpoints are files named `model_<number>.pt` directly inside `log_dir`.
    Files that match `model_*.pt` but whose suffix is not an integer (for
    example `model_best.pt`) are ignored instead of aborting the lookup.

    Args:
        log_dir: Directory to search. Glob metacharacters in it are escaped.

    Returns:
        The path of the checkpoint with the largest number.

    Raises:
        SystemExit: With exit code 1, after printing a warning, when no
            numbered checkpoint is found.
    """
    pattern = os.path.join(glob.escape(log_dir), "model_*.pt")
    numbered_checkpoints = []
    for path in glob.glob(pattern):
        # "model_<n>.pt" -> "<n>"; the glob pattern guarantees the "model_" prefix
        suffix = os.path.basename(path).split("_")[1].split(".")[0]
        try:
            numbered_checkpoints.append((int(suffix), path))
        except ValueError:
            # Not a numbered checkpoint; skip it
            continue

    if not numbered_checkpoints:
        print(
            f"Warning: No model files found at '{log_dir}' (you might need to train more)."
        )
        # Raise directly instead of relying on the `exit` helper injected by `site`
        raise SystemExit(1)

    _, latest_path = max(numbered_checkpoints, key=lambda entry: entry[0])
    return latest_path


def main():
    """
    Load the newest checkpoint of the selected experiment and run it in the viewer.

    Reads `args.device` and `args.exp_name`, initialises Genesis, restores the
    runner configuration from `./logs/<exp_name>/cfgs.pkl`, and then steps a
    single non-headless environment with the loaded policy until the user
    interrupts the program or closes the viewer.
    """
    # Processor backend (GPU or CPU)
    backend = gs.gpu
    if args.device == "cpu":
        backend = gs.cpu
        torch.set_default_device("cpu")
    gs.init(logging_level="warning", backend=backend)

    # Load training configuration
    log_path = f"./logs/{args.exp_name}"
    # NOTE: unpickling can execute arbitrary code; only load config files you produced yourself.
    # The file is expected to hold a one-element sequence (presumably written by the training script).
    with open(f"{log_path}/cfgs.pkl", "rb") as cfg_file:
        [cfg] = pickle.load(cfg_file)
    model = get_latest_model(log_path)

    # Setup environment
    env = Go2PosEnv(num_envs=1, headless=False)
    env = RslRlWrapper(env)
    env.build()

    # Eval
    print("🎬 Loading last model...")
    runner = OnPolicyRunner(env, cfg, log_path, device=gs.device)
    runner.load(model)
    policy = runner.get_inference_policy(device=gs.device)

    try:
        obs, _ = env.reset()
        with torch.no_grad():
            while True:
                actions = policy(obs)
                obs, _rews, _dones, _infos = env.step(actions)
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the evaluation loop
        pass
    except gs.GenesisException as e:
        # Closing the viewer window is also a normal way to stop; anything else is a real error
        if e.message != "Viewer closed.":
            raise


if __name__ == "__main__":
    main()
Loading