Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 44 additions & 22 deletions genesis_forge/genesis_env.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
from __future__ import annotations
import math
import torch
import genesis as gs

try:
import genesis as gs
except ImportError:
print(
"Genesis package not found, if your wish to train, eval or play use the appropriatere env_mode when initialising the env"
)
gs = None
from gymnasium import spaces
from typing import Any, Literal, TYPE_CHECKING

if TYPE_CHECKING:
from genesis.engine.entities import RigidEntity

EnvMode = Literal["train", "eval", "play"]
EnvMode = Literal["train", "eval", "play", "real"]


class GenesisEnv:
Expand Down Expand Up @@ -58,13 +65,27 @@ def __init__(
max_episode_length_sec: int | None = 10,
max_episode_random_scaling: float = 0.0,
extras_logging_key: str = "episode",
env_mode: EnvMode = "train",
):
self.dt = dt
self.device = gs.device
self.num_envs = num_envs
self.scene: gs.Scene = None
self.robot: RigidEntity = None
self.terrain: RigidEntity = None
self.env_mode = env_mode
if self.env_mode != "real":
self.device = gs.device
self.float_dtype = gs.tc_float
self.int_dtype = gs.tc_int
self.bool_dtype = gs.tc_bool
self.REVOLUTE_JOINT_TYPE = gs.JOINT_TYPE.REVOLUTE
self.PRISMATIC_JOINT_TYPE = gs.JOINT_TYPE.PRISMATIC
self.num_envs = num_envs
self.scene: gs.Scene = None
self.robot: RigidEntity = None
self.terrain: RigidEntity = None
else:
self.device = torch.get_default_device()
self.float_dtype = torch.float32
self.int_dtype = torch.int32
self.bool_dtype = torch.bool
self.num_envs = num_envs

self.extras_logging_key = extras_logging_key
self._extras = {}
Expand All @@ -75,7 +96,7 @@ def __init__(

self.step_count: int = 0
self.episode_length = torch.zeros(
(self.num_envs,), device=gs.device, dtype=torch.int32
(self.num_envs,), device=self.device, dtype=self.int_dtype
)
self.max_episode_length: torch.Tensor = None

Expand All @@ -84,7 +105,7 @@ def __init__(
self._max_episode_random_scaling = max_episode_random_scaling
if max_episode_length_sec and max_episode_length_sec > 0:
self.max_episode_length = torch.zeros(
(self.num_envs,), device=gs.device, dtype=gs.tc_int
(self.num_envs,), device=self.device, dtype=self.int_dtype
)
self.max_episode_length[:] = self.set_max_episode_length(
max_episode_length_sec
Expand Down Expand Up @@ -175,10 +196,11 @@ def build(self) -> None:
Builds the environment before the first step.
The Genesis scene and all the scene entities must be added before calling this method.
"""
assert (
self.scene is not None
), "The scene must be constructed and assigned to the <env>.scene attribute before building."
self.scene.build(n_envs=self.num_envs)
if hasattr(self, "scene"):
assert (
self.scene is not None
), "The scene must be constructed and assigned to the <env>.scene attribute before building."
self.scene.build(n_envs=self.num_envs)

def step(
self, actions: torch.Tensor
Expand All @@ -198,8 +220,8 @@ def step(
self.episode_length += 1

if self._actions is None:
self._actions = torch.zeros_like(actions, device=gs.device)
self._last_actions = torch.zeros_like(actions, device=gs.device)
self._actions = torch.zeros_like(actions, device=self.device)
self._last_actions = torch.zeros_like(actions, device=self.device)

self._last_actions[:] = self._actions[:]
self._actions[:] = actions[:]
Expand All @@ -221,16 +243,16 @@ def reset(
A batch of observations and info from the vectorized environment.
"""
if envs_idx is None:
envs_idx = torch.arange(self.num_envs, device=gs.device)
envs_idx = torch.arange(self.num_envs, device=self.device)

# Initial reset, set buffers
if self.step_count == 0 and self.action_space is not None:
self._actions = torch.zeros(
(self.num_envs, self.action_space.shape[0]),
device=gs.device,
dtype=gs.tc_float,
device=self.device,
dtype=self.float_dtype,
)
self._last_actions = torch.zeros_like(self._actions, device=gs.device)
self._last_actions = torch.zeros_like(self._actions, device=self.device)

# Actions
if envs_idx.numel() > 0:
Expand All @@ -256,7 +278,7 @@ def reset(
)
self.max_episode_length[envs_idx] = torch.round(
self._base_max_episode_length + randomization
).to(gs.tc_int)
).to(self.int_dtype)

return None, self.extras

Expand All @@ -283,8 +305,8 @@ def get_observations(self) -> torch.Tensor:
if self.observation_space is not None:
return torch.zeros(
(self.num_envs, self.observation_space.shape[0]),
device=gs.device,
dtype=gs.tc_float,
device=self.device,
dtype=self.float_dtype,
)
return None

Expand Down
89 changes: 63 additions & 26 deletions genesis_forge/managed_env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import torch
import numpy as np
from typing import Any, TypedDict
from gymnasium import spaces
import genesis as gs
from tensordict import TensorDict
from genesis_forge.genesis_env import GenesisEnv
from genesis_forge.genesis_env import GenesisEnv, EnvMode
from genesis_forge.managers.base import BaseManager, ManagerType
from genesis_forge.managers import (
ContactManager,
Expand All @@ -19,12 +20,12 @@


class ManagersDict(TypedDict):
actuator: ActuatorManager | None
actuator: list[ActuatorManager]
contact: list[ContactManager]
entity: list[EntityManager]
command: list[CommandManager]
terrain: list[TerrainManager]
action: PositionActionManager | None
action: list[PositionActionManager]
observation: list[ObservationManager]
reward: RewardManager | None
termination: TerminationManager | None
Expand Down Expand Up @@ -119,40 +120,43 @@ def __init__(
max_episode_length_sec: int | None = 10,
max_episode_random_scaling: float = 0.0,
extras_logging_key: str = "episode",
env_mode: EnvMode = "train",
):
super().__init__(
num_envs=num_envs,
dt=dt,
max_episode_length_sec=max_episode_length_sec,
max_episode_random_scaling=max_episode_random_scaling,
extras_logging_key=extras_logging_key,
env_mode=env_mode,
)

self.managers: ManagersDict = {
"contact": [],
"entity": [],
"command": [],
"terrain": [],
"actuator": [],
"action": [],
# there can only be one of each of these
"actuator": None,
"action": None,
"observation": [],
"reward": None,
"termination": None,
}

self._action_space = None
self._action_ranges: list[tuple[int, int]] = []
self._observation_space = None
self._reward_buf = torch.zeros(
(self.num_envs,), device=gs.device, dtype=gs.tc_float
(self.num_envs,), device=self.device, dtype=self.float_dtype
)
self._terminated_buf = torch.zeros(
(self.num_envs,), device=gs.device, dtype=gs.tc_bool
(self.num_envs,), device=self.device, dtype=self.bool_dtype
)
self._truncated_buf = torch.zeros(
(self.num_envs,), device=gs.device, dtype=gs.tc_bool
(self.num_envs,), device=self.device, dtype=self.bool_dtype
)
self._observations_buf = TensorDict({}, device=gs.device)
self._observations_buf = TensorDict({}, device=self.device)

"""
Properties
Expand All @@ -161,13 +165,9 @@ def __init__(
@property
def action_space(self) -> torch.Tensor:
"""
The action space, provided by the action manager, if it exists.
The action space, provided by the action manager(s), if any exist.
"""
if self.managers["action"] is not None:
return self.managers["action"].action_space
if self._action_space is not None:
return self._action_space
return None
return self._action_space

@action_space.setter
def action_space(self, action_space: spaces.Space):
Expand Down Expand Up @@ -260,10 +260,11 @@ def build(self):

for terrain_manager in self.managers["terrain"]:
terrain_manager.build()
if self.managers["actuator"] is not None:
self.managers["actuator"].build()
if self.managers["action"] is not None:
self.managers["action"].build()

for actuator_manager in self.managers["actuator"]:
actuator_manager.build()
self._build_action_managers()

for contact_manager in self.managers["contact"]:
contact_manager.build()
if self.managers["termination"] is not None:
Expand Down Expand Up @@ -292,8 +293,11 @@ def step(
super().step(actions)

# Execute the actions and a simulation step
if self.managers["action"] is not None:
self.managers["action"].step(actions)
for i, action_manager in enumerate[PositionActionManager](
self.managers["action"]
):
(start, end) = self._action_ranges[i]
action_manager.step(actions[:, start:end])
self.scene.step()

# Update entity managers
Expand Down Expand Up @@ -353,10 +357,10 @@ def reset(
"""
(obs, _) = super().reset(env_ids)

if self.managers["actuator"] is not None:
self.managers["actuator"].reset(env_ids)
if self.managers["action"] is not None:
self.managers["action"].reset(env_ids)
for actuator_manager in self.managers["actuator"]:
actuator_manager.reset(env_ids)
for action_manager in self.managers["action"]:
action_manager.reset(env_ids)
for entity_manager in self.managers["entity"]:
entity_manager.reset(env_ids)
for contact_manager in self.managers["contact"]:
Expand All @@ -383,7 +387,7 @@ def get_observations(self) -> torch.Tensor:
If you use the ObservationManager, this will be handled automatically.
Otherwise, override this method to return the observations.
"""
self.extras["observations"] = TensorDict({}, device=gs.device)
self.extras["observations"] = TensorDict({}, device=self.device)

# Get observations
if len(self.managers["observation"]) > 0:
Expand All @@ -397,3 +401,36 @@ def get_observations(self) -> torch.Tensor:

# Otherwise, call super
return super().get_observations()

"""
Internal methods
"""

def _build_action_managers(self):
"""
Build the action managers and combine the action spaces.
"""
if len(self.managers["action"]) == 0:
return

low = []
high = []
size = 0
self._action_ranges = []
for action_manager in self.managers["action"]:
action_manager.build()

start = size
size += action_manager.action_space.shape[0]
end = size
self._action_ranges.append((start, end))

low.append(action_manager.action_space.low)
high.append(action_manager.action_space.high)

self._action_space = spaces.Box(
low=np.concatenate(low),
high=np.concatenate(high),
shape=(size,),
dtype=np.float32,
)
Loading