diff --git a/examples/sensors/README.md b/examples/sensors/README.md new file mode 100644 index 0000000..a8a9787 --- /dev/null +++ b/examples/sensors/README.md @@ -0,0 +1,40 @@ +# Go2 Simple Locomotion Example + +A simple program that teaches the Go2 robot to walk forward. + +This example uses the Genesis Forge managed environment setup, which let's the environment be dedicated more to the scene setup +and reward shaping, than logic to handle domain randomization and logging. + +## Training + +This will be trained using the [rsl_rl](https://github.com/leggedrobotics/rsl_rl) training library. So first, we need to install that and tensorboard: + +```bash +pip install tensorboard rsl-rl-lib>=2.2.4 +``` + +Now you can run the training with: + +```bash +python ./train.py +``` + + +You can view the training progress with: + +```bash +tensorboard --logdir ./logs/ +``` + +The Genesis Forge training environment will also save videos while training that can be viewed in `./logs/go2-walking/videos`. + +https://github.com/user-attachments/assets/be46df1b-35e5-4b5b-9bbc-f543210dd463 + + +## Evaluation + +Now you can view the trained policy: + +```bash +python ./eval.py ./logs/go2-walking/ +``` diff --git a/examples/sensors/environment.py b/examples/sensors/environment.py new file mode 100644 index 0000000..110fb72 --- /dev/null +++ b/examples/sensors/environment.py @@ -0,0 +1,299 @@ +""" +Simplified Go2 Locomotion Environment using managers to handle everything. +""" + +import torch +import genesis as gs + +from genesis_forge import ManagedEnvironment +from genesis_forge import EnvMode +from genesis_forge.managers import ( + RewardManager, + TerminationManager, + EntityManager, + ObservationManager, + ActuatorManager, + PositionActionManager, +) +from genesis_forge.mdp import reset, rewards, terminations, observations + + +INITIAL_BODY_POSITION = [0.0, 0.0, 0.4] +INITIAL_QUAT = [1.0, 0.0, 0.0, 0.0] +TARGET_X_VELOCITY = 0.5 + + +class Go2SimpleEnv(ManagedEnvironment): + """ + Example training environment for the Go2 robot. + """ + + def __init__( + self, + num_envs: int = 1, + dt: float = 1 / 50, # control frequency on real robot is 50hz + max_episode_length_s: int | None = 20, + headless: bool = True, + ): + super().__init__( + num_envs=num_envs, + dt=dt, + max_episode_length_sec=max_episode_length_s, + max_episode_random_scaling=0.1, + ) + + # Set the commanded robot direction to be 0.5 along the X axis, for all environments + self.target_command = torch.zeros( + (self.num_envs, 3), device=gs.device, dtype=gs.tc_float + ) + self.target_command[:, 0] = ( + TARGET_X_VELOCITY # Linear velocity along the X axis + ) + + # Construct the scene + self.scene = gs.Scene( + show_viewer=not headless, + sim_options=gs.options.SimOptions(dt=self.dt, substeps=2), + viewer_options=gs.options.ViewerOptions( + max_FPS=int(0.5 / self.dt), + camera_pos=(2.0, 0.0, 2.5), + camera_lookat=(0.0, 0.0, 0.5), + camera_fov=40, + ), + vis_options=gs.options.VisOptions(rendered_envs_idx=list(range(1))), + rigid_options=gs.options.RigidOptions( + dt=self.dt, + constraint_solver=gs.constraint_solver.Newton, + enable_collision=True, + enable_joint_limit=True, + # for this locomotion policy there are usually no more than 30 collision pairs + # set a low value can save memory + max_collision_pairs=30, + ), + ) + + # Create terrain + self.terrain = self.scene.add_entity(gs.morphs.Plane()) + + # Robot + self.robot = self.scene.add_entity( + gs.morphs.URDF( + file="urdf/go2/urdf/go2.urdf", + pos=INITIAL_BODY_POSITION, + quat=INITIAL_QUAT, + ), + ) + self.imu = self.scene.add_sensor( + gs.sensors.IMU( + entity_idx=self.robot.idx, + link_idx_local=self.robot.links[0].idx_local, + pos_offset=(0.0, 0.0, 0.15), + # noise parameters + acc_cross_axis_coupling=(0.0, 0.01, 0.02), + gyro_cross_axis_coupling=(0.03, 0.04, 0.05), + acc_noise=(0.01, 0.01, 0.01), + gyro_noise=(0.01, 0.01, 0.01), + acc_random_walk=(0.001, 0.001, 0.001), + gyro_random_walk=(0.001, 0.001, 0.001), + # delay=0.01, + jitter=0.01, + interpolate=True, + # visualize + draw_debug=True, + ) + ) + + # Camera, for headless video recording + self.camera = self.scene.add_camera( + pos=(-2.5, -1.5, 1.0), + lookat=(0.0, 0.0, 0.0), + res=(1280, 720), + fov=40, + env_idx=0, + debug=True, + ) + + def config(self): + """ + Configure the environment managers + """ + ## + # Robot manager + # i.e. what to do with the robot when it is reset + self.robot_manager = EntityManager( + self, + entity_attr="robot", + on_reset={ + # Reset the robot's initial position + "position": { + "fn": reset.position, + "params": { + "position": INITIAL_BODY_POSITION, + "quat": INITIAL_QUAT, + "zero_velocity": True, + }, + }, + }, + ) + + ## + # Joint Actions + self.actuator_manager = ActuatorManager( + self, + joint_names=[ + "FL_.*_joint", + "FR_.*_joint", + "RL_.*_joint", + "RR_.*_joint", + ], + default_pos={ + ".*_hip_joint": 0.0, + "FL_thigh_joint": 0.8, + "FR_thigh_joint": 0.8, + "RL_thigh_joint": 1.0, + "RR_thigh_joint": 1.0, + ".*_calf_joint": -1.5, + }, + kp=20, + kv=0.5, + ) + self.action_manager = PositionActionManager( + self, + scale=0.25, + clip=(-100.0, 100.0), + use_default_offset=True, + actuator_manager=self.actuator_manager, + ) + + ## + # Observations + self.observation_manager = ObservationManager( + self, + history_len=2, + cfg={ + "angle_velocity": { + "fn": lambda env: self.robot_manager.get_angular_velocity(), + "scale": 0.25, + }, + "linear_velocity": { + "fn": lambda env: self.robot_manager.get_linear_velocity(), + "scale": 2.0, + }, + "projected_gravity": { + "fn": lambda env: self.robot_manager.get_projected_gravity(), + }, + "dof_position": { + "fn": lambda env: self.action_manager.get_dofs_position(), + }, + "dof_velocity": { + "fn": lambda env: self.action_manager.get_dofs_velocity(), + "scale": 0.05, + }, + "imu": { + "fn": observations.SensorObservation, + "params": { + "read": lambda: torch.cat( + [self.imu.read().lin_acc, self.imu.read().ang_vel], dim=-1 + ), + "frequency": 25, + }, + }, + "actions": { + "fn": lambda env: self.action_manager.get_actions(), + }, + }, + ) + + ## + # Rewards + RewardManager( + self, + logging_enabled=True, + cfg={ + "base_height_target": { + "weight": -50.0, + "fn": rewards.base_height, + "params": { + "target_height": 0.3, + "entity_attr": "robot", + }, + }, + "tracking_lin_vel": { + "weight": 1.0, + "fn": rewards.command_tracking_lin_vel, + "params": { + "command": self.target_command[:, :2], + "entity_manager": self.robot_manager, + }, + }, + "tracking_ang_vel": { + "weight": 0.2, + "fn": rewards.command_tracking_ang_vel, + "params": { + "commanded_ang_vel": self.target_command[:, 2], + "entity_manager": self.robot_manager, + }, + }, + "lin_vel_z": { + "weight": -1.0, + "fn": rewards.lin_vel_z_l2, + "params": { + "entity_manager": self.robot_manager, + }, + }, + "action_rate": { + "weight": -0.005, + "fn": rewards.action_rate_l2, + }, + "similar_to_default": { + "weight": -0.1, + "fn": rewards.dof_similar_to_default, + "params": { + "action_manager": self.action_manager, + }, + }, + "lin_acc_jitter": { + "weight": -0.1, + "fn": rewards.imu_lin_acc_jitter, + "params": { + "observation_manager": self.observation_manager, + "obs_item_key": "imu", + "ignore_gravity": True, + }, + }, + "ang_vel_jitter": { + "weight": -0.1, + "fn": rewards.imu_ang_vel_jitter, + "params": { + "observation_manager": self.observation_manager, + "obs_item_key": "imu", + }, + }, + }, + ) + + ## + # Termination conditions + self.termination_manager = TerminationManager( + self, + logging_enabled=True, + term_cfg={ + # The episode ended + "timeout": { + "fn": terminations.timeout, + "time_out": True, + }, + # Terminate if the robot's pitch and yaw angles are too large + "fall_over": { + "fn": terminations.bad_orientation, + "params": { + "limit_angle": 10.0, + "entity_manager": self.robot_manager, + }, + }, + }, + ) + + def build(self): + super().build() + self.camera.follow_entity(self.robot) diff --git a/examples/sensors/eval.py b/examples/sensors/eval.py new file mode 100644 index 0000000..a78b9f6 --- /dev/null +++ b/examples/sensors/eval.py @@ -0,0 +1,89 @@ +import os +import glob +import torch +import pickle +import argparse +from importlib import metadata +import genesis as gs + +from genesis_forge.wrappers import RslRlWrapper +from environment import Go2SimpleEnv + +try: + try: + if metadata.version("rsl-rl"): + raise ImportError + except metadata.PackageNotFoundError: + if metadata.version("rsl-rl-lib").startswith("1."): + raise ImportError +except (metadata.PackageNotFoundError, ImportError) as e: + raise ImportError("Please install install 'rsl-rl-lib>=2.2.4'.") from e +from rsl_rl.runners import OnPolicyRunner + +EXPERIMENT_NAME = "go2-simple" + +parser = argparse.ArgumentParser(add_help=True) +parser.add_argument("-d", "--device", type=str, default="gpu") +parser.add_argument("-e", "--exp_name", type=str, default=EXPERIMENT_NAME) +args = parser.parse_args() + + +def get_latest_model(log_dir: str) -> str: + """ + Get the last model from the log directory + """ + model_checkpoints = glob.glob(os.path.join(log_dir, "model_*.pt")) + if len(model_checkpoints) == 0: + print( + f"Warning: No model files found at '{log_dir}' (you might need to train more)." + ) + exit(1) + # Sort by the file with the highest number + sorted_models = sorted( + model_checkpoints, + key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0]), + ) + return sorted_models[-1] + + +def main(): + # Processor backend (GPU or CPU) + backend = gs.gpu + if args.device == "cpu": + backend = gs.cpu + torch.set_default_device("cpu") + gs.init(logging_level="warning", backend=backend) + + # Load training configuration + log_path = f"./logs/{args.exp_name}" + [cfg] = pickle.load(open(f"{log_path}/cfgs.pkl", "rb")) + model = get_latest_model(log_path) + + # Setup environment + env = Go2SimpleEnv(num_envs=1, headless=False, env_mode="eval") + env = RslRlWrapper(env) + env.build() + + # Eval + print("🎬 Loading last model...") + runner = OnPolicyRunner(env, cfg, log_path, device=gs.device) + runner.load(model) + policy = runner.get_inference_policy(device=gs.device) + + try: + obs, _ = env.reset() + with torch.no_grad(): + while True: + actions = policy(obs) + obs, _rews, _dones, _infos = env.step(actions) + except KeyboardInterrupt: + pass + except gs.GenesisException as e: + if e.message != "Viewer closed.": + raise e + except Exception as e: + raise e + + +if __name__ == "__main__": + main() diff --git a/examples/sensors/train.py b/examples/sensors/train.py new file mode 100644 index 0000000..57a7949 --- /dev/null +++ b/examples/sensors/train.py @@ -0,0 +1,134 @@ +import os +import copy +import torch +import shutil +import pickle +import argparse +from importlib import metadata +import genesis as gs + +from genesis_forge.wrappers import ( + VideoWrapper, + RslRlWrapper, +) +from environment import Go2SimpleEnv + +try: + try: + if metadata.version("rsl-rl"): + raise ImportError + except metadata.PackageNotFoundError: + if metadata.version("rsl-rl-lib").startswith("1."): + raise ImportError +except (metadata.PackageNotFoundError, ImportError) as e: + raise ImportError("Please install install 'rsl-rl-lib>=2.2.4'.") from e +from rsl_rl.runners import OnPolicyRunner + +EXPERIMENT_NAME = "go2-simple" + +parser = argparse.ArgumentParser(add_help=True) +parser.add_argument("-n", "--num_envs", type=int, default=4096) +parser.add_argument("--max_iterations", type=int, default=101) +parser.add_argument("-d", "--device", type=str, default="gpu") +parser.add_argument("-e", "--exp_name", type=str, default=EXPERIMENT_NAME) +args = parser.parse_args() + + +def training_cfg(exp_name: str, max_iterations: int): + return { + "algorithm": { + "class_name": "PPO", + "clip_param": 0.2, + "desired_kl": 0.01, + "entropy_coef": 0.01, + "gamma": 0.99, + "lam": 0.95, + "learning_rate": 0.001, + "max_grad_norm": 1.0, + "num_learning_epochs": 5, + "num_mini_batches": 4, + "schedule": "adaptive", + "use_clipped_value_loss": True, + "value_loss_coef": 1.0, + }, + "init_member_classes": {}, + "policy": { + "activation": "elu", + "actor_hidden_dims": [512, 256, 128], + "critic_hidden_dims": [512, 256, 128], + "init_noise_std": 1.0, + "class_name": "ActorCritic", + }, + "runner": { + "checkpoint": -1, + "experiment_name": exp_name, + "load_run": -1, + "log_interval": 1, + "max_iterations": max_iterations, + "record_interval": -1, + "resume": False, + "resume_path": None, + "run_name": "", + }, + "runner_class_name": "OnPolicyRunner", + "seed": 1, + "num_steps_per_env": 24, + "save_interval": 100, + "empirical_normalization": None, + "obs_groups": {"policy": ["policy"], "critic": ["policy"]}, + } + + +def main(): + # Initialize Genesis + # Processor backend (GPU or CPU) + backend = gs.gpu + if args.device == "cpu": + backend = gs.cpu + torch.set_default_device("cpu") + gs.init(logging_level="warning", backend=backend) + + # Logging directory + log_base_dir = "./logs" + experiment_name = args.exp_name + log_path = os.path.join(log_base_dir, experiment_name) + if os.path.exists(log_path): + shutil.rmtree(log_path) + os.makedirs(log_path, exist_ok=True) + print(f"Logging to: {log_path}") + + # Load training configuration and save snapshot of training configs + cfg = training_cfg(experiment_name, args.max_iterations) + pickle.dump( + [cfg], + open(os.path.join(log_path, "cfgs.pkl"), "wb"), + ) + + # Create environment + env = Go2SimpleEnv(num_envs=args.num_envs, headless=True) + + # Record videos in regular intervals + env = VideoWrapper( + env, + video_length_sec=12, + out_dir=os.path.join(log_path, "videos"), + episode_trigger=lambda episode_id: episode_id % 5 == 0, + ) + + # Build the environment + env = RslRlWrapper(env) + env.build() + env.reset() + + # Train + print("💪 Training model...") + runner = OnPolicyRunner(env, copy.deepcopy(cfg), log_path, device=gs.device) + runner.git_status_repos = ["."] + runner.learn( + num_learning_iterations=args.max_iterations, init_at_random_ep_len=False + ) + env.close() + + +if __name__ == "__main__": + main() diff --git a/examples/sensors/trained.mov b/examples/sensors/trained.mov new file mode 100644 index 0000000..050d460 Binary files /dev/null and b/examples/sensors/trained.mov differ diff --git a/genesis_forge/managers/observation_manager.py b/genesis_forge/managers/observation_manager.py index c3d47b6..75cada9 100644 --- a/genesis_forge/managers/observation_manager.py +++ b/genesis_forge/managers/observation_manager.py @@ -155,6 +155,9 @@ def __init__( for name, cfg in cfg.items(): self.cfg[name] = ObservationConfigItem(cfg, env) + self.observation_item_indices = {} + self.single_obs_size = 0 + """ Properties """ @@ -192,8 +195,8 @@ def build(self): return # Setup observation functions and the observation space - single_obs_size = self._setup_observation_functions() - self._observation_size = single_obs_size * self._history_len + self.single_obs_size = self._setup_observation_functions() + self._observation_size = self.single_obs_size * self._history_len self._observation_space = spaces.Box( low=-np.inf, high=np.inf, @@ -202,7 +205,7 @@ def build(self): ) # Fill history buffer - shape = (self.env.num_envs, single_obs_size) + shape = (self.env.num_envs, self.single_obs_size) self._history = [ torch.zeros(shape, device=gs.device) for _ in range(self._history_len) ] @@ -211,7 +214,7 @@ def build(self): device=gs.device, ) - def get_observations(self) -> torch.Tensor: + def get_observations(self, name=None) -> torch.Tensor: """Generate current observations for all environments.""" if not self.enabled: return torch.zeros((self.env.num_envs, self._observation_size)) @@ -227,6 +230,17 @@ def get_observations(self) -> torch.Tensor: size = obs.shape[1] self._history_output[:, offset : offset + size] = obs offset += size + if name is not None: + lower_index, upper_index = self.observation_item_indices[name] + indices = [] + for offset in range(self._history_len): + indices.extend( + range( + self.single_obs_size * offset + lower_index, + self.single_obs_size * offset + upper_index, + ) + ) + return self._history_output.clone()[:, torch.tensor(indices)] return self._history_output.clone() """ @@ -278,6 +292,11 @@ def _perform_observation(self, output: torch.Tensor) -> torch.Tensor: value_size = value.shape[-1] if value_size > 0: output[:, offset : offset + value_size] = value + if self.observation_item_indices.get(name) is None: + self.observation_item_indices[name] = ( + offset, + offset + value_size, + ) offset += value_size except Exception as e: print(f"Error generating observation for '{name}'") diff --git a/genesis_forge/mdp/observations.py b/genesis_forge/mdp/observations.py index ef386b9..79035fc 100644 --- a/genesis_forge/mdp/observations.py +++ b/genesis_forge/mdp/observations.py @@ -6,6 +6,7 @@ PositionActionManager, EntityManager, ContactManager, + MdpFnClass, ) from genesis_forge.utils import entity_lin_vel, entity_ang_vel, entity_projected_gravity from typing import TYPE_CHECKING @@ -205,3 +206,36 @@ def contact_force(env: GenesisEnv, contact_manager: ContactManager) -> torch.Ten Returns: tensor of shape (num_envs, num_contacts) """ return torch.norm(contact_manager.contacts[:, :, :], dim=-1) + + +""" +Rate limited sensors +""" + + +class RateLimitedObservation(MdpFnClass): + """ + Helper class for rate limited observations + + Args: + env: The Genesis Forge environment + read: The function for recieving observations with a rate limit + frequency: The frequency of the observation in hertz + """ + + def __init__( + self, + env: GenesisEnv, + read: Callable, + frequency: float, + ): + super().__init__(env) + self.last_data = None + self.next_read_time = 0 + self.read_interval = 1 / frequency + + def __call__(self, env: GenesisEnv, read: Callable, frequency: float): + if self.last_data is None or self.env.scene.cur_t >= self.next_read_time: + self.last_data = read() + self.next_read_time = self.env.scene.cur_t + self.read_interval + return self.last_data diff --git a/genesis_forge/mdp/reset.py b/genesis_forge/mdp/reset.py index 3ef9d44..094f7ac 100644 --- a/genesis_forge/mdp/reset.py +++ b/genesis_forge/mdp/reset.py @@ -291,3 +291,5 @@ def __call__( links_idx_local=self._links_idx_local, envs_idx=envs_idx, ) + + diff --git a/genesis_forge/mdp/rewards.py b/genesis_forge/mdp/rewards.py index eae526b..b6ec168 100644 --- a/genesis_forge/mdp/rewards.py +++ b/genesis_forge/mdp/rewards.py @@ -16,6 +16,7 @@ ContactManager, TerrainManager, EntityManager, + ObservationManager, ) from genesis_forge.utils import entity_lin_vel, entity_ang_vel, entity_projected_gravity from genesis_forge.managers import MdpFnClass @@ -406,6 +407,75 @@ def stand_still_joint_deviation_l1( return joint_deviation * (torch.norm(command[:, :2], dim=1) < command_threshold) +""" +Imu +""" + + +def imu_lin_acc_jitter( + _env: GenesisEnv, + observation_manager: ObservationManager, + obs_item_key: str, + ignore_gravity: bool = False, +) -> torch.Tensor: + """ + Penalize changes in linear acceleration (jerk) using the IMU's internal queue. + + Returns: + torch.Tensor of shape (num_envs,) + """ + single_obs_size = 3 + indices = [] + running_offset = 0 + for _ in range(observation_manager._history_len): + indices.append(list(range(running_offset, running_offset + single_obs_size))) + running_offset += single_obs_size + lin_acc_buffer = observation_manager.get_observations(obs_item_key)[ + :, torch.tensor(indices) + ] + if ignore_gravity: + lin_acc_buffer[:, :, 2] += 9.81 + + if lin_acc_buffer.shape[0] > 1: + diffs = lin_acc_buffer[:, 1:] - lin_acc_buffer[:, :-1, :] + else: + return torch.zeros(_env.num_envs, 1) + mags = torch.norm(diffs, dim=-1) + return mags.sum(dim=1) + + +def imu_ang_vel_jitter( + _env: GenesisEnv, + observation_manager: ObservationManager, + obs_item_key: str, +) -> torch.Tensor: + """ + Penalize changes in angular velocity (gyro jerk) using the IMU's internal queue. + + Returns: + torch.Tensor of shape (num_envs,) + """ + single_obs_size = 3 + indices = [] + running_offset = 0 + for _ in range(observation_manager._history_len): + indices.append(list(range(running_offset, running_offset + single_obs_size))) + running_offset += single_obs_size + ang_vel_buffer = observation_manager.get_observations(obs_item_key)[ + :, torch.tensor(indices) + ] + + if ang_vel_buffer.shape[0] > 1: + diffs = ang_vel_buffer[:, 1:] - ang_vel_buffer[:, :-1, :] + else: + print( + "the observation history buffer only has only the sensor data for one timsteps, the reward will be zero" + ) + return torch.zeros(_env.num_envs, 1) + mags = torch.norm(diffs, dim=-1) + return mags.sum(dim=1) + + """ Contacts """