Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 93 additions & 18 deletions genesis_forge/managers/action/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,35 @@ class BaseActionManager(BaseManager):

Args:
env: The environment to manage the DOF actuators for.
delay_step: The number of steps to delay the actions for.
This is an easy way to emulate the latency in the system.
delay_step: The number of steps to delay actions for, to emulate the latency in the system.
This can be an integer for a fixed delay, or a tuple (min, max) for a per-environment random delay range.
"""

def __init__(self, env: GenesisEnv, delay_step: int = 0):
def __init__(self, env: GenesisEnv, delay_step: int | tuple[int, int] = 0):
super().__init__(env, type="action")
self._raw_actions = None
self._actions = None
self._envs_idx: torch.Tensor | None = None

self._delay_step = delay_step
self._action_delay_buffer = []
self._delay_ring_buffer_head = 0
self._delay_ring_buffer: torch.Tensor | None = None
self._delay_step_idx: torch.Tensor | None = None

# Validate the delay_step tuple
if isinstance(delay_step, tuple):
min_delay, max_delay = delay_step
if min_delay < 0 or max_delay < min_delay:
raise ValueError(
f"Invalid delay_step range: {self._delay_step}. Must be (min, max) where min >= 0 and max >= min"
)
elif isinstance(delay_step, int):
if delay_step < 0:
raise ValueError(
f"Invalid delay_step: {self._delay_step}. Must be >= 0"
)
if delay_step == 0:
self._delay_step = None

"""
Properties
Expand Down Expand Up @@ -64,17 +83,35 @@ def raw_actions(self) -> torch.Tensor:
return torch.zeros((self.env.num_envs, self.num_actions))
return self._raw_actions

"""
Lifecycle Operations
"""

def build(self):
    """
    Allocate the per-environment index tensor and, if delay is enabled, the delay buffers.

    The ring buffer gets one slot more than the largest configured delay, and the
    per-environment delay indices start at zero.
    """
    env_count = self.env.num_envs
    self._envs_idx = torch.arange(env_count, device=gs.device)

    delay_cfg = self._delay_step
    if delay_cfg is None:
        # Action delay is disabled, so no buffers are needed.
        return

    # A tuple is a (min, max) range; only the upper bound determines the buffer size.
    longest_delay = delay_cfg[1] if isinstance(delay_cfg, tuple) else delay_cfg

    ring_shape = (env_count, longest_delay + 1, self.num_actions)
    self._delay_ring_buffer = torch.zeros(
        ring_shape,
        dtype=torch.float32,
        device=gs.device,
    )
    self._delay_step_idx = torch.zeros(
        env_count,
        dtype=torch.int32,
        device=gs.device,
    )

def step(self, actions: torch.Tensor) -> None:
"""
Handle the received actions.
"""
# Action delay buffer
if self._delay_step > 0:
self._action_delay_buffer.insert(0, actions)
actions = self._action_delay_buffer.pop()
self._raw_actions = self._apply_action_delay(actions)

# Copy the actions into the manager buffer
self._raw_actions = actions
if self._actions is None:
self._actions = self._raw_actions.clone()
else:
Expand All @@ -83,15 +120,20 @@ def step(self, actions: torch.Tensor) -> None:

def reset(self, envs_idx: list[int] | None):
"""Reset environments."""
if (
self._delay_step > 0
and len(self._action_delay_buffer) < self._delay_step
and self.num_actions > 0
):
while len(self._action_delay_buffer) < self._delay_step:
self._action_delay_buffer.append(
torch.zeros((self.env.num_envs, self.num_actions), device=gs.device)
)

# Per-environment random action delay
if isinstance(self._delay_step, tuple) and self._delay_step_idx is not None:
min_delay, max_delay = self._delay_step
if envs_idx is None:
envs_idx = self._envs_idx

self._delay_step_idx[envs_idx] = torch.randint(
min_delay,
max_delay + 1,
(len(envs_idx),),
device=gs.device,
dtype=torch.int32,
)

def get_actions(self) -> torch.Tensor:
"""
Expand All @@ -100,3 +142,36 @@ def get_actions(self) -> torch.Tensor:
if self._actions is None:
return torch.zeros((self.env.num_envs, self.num_actions))
return self._actions

"""
Internal Operations
"""

def _apply_action_delay(self, actions: torch.Tensor) -> torch.Tensor:
    """
    Push the newest actions onto the delay ring buffer and return the delayed ones.

    When action delay is enabled (via `delay_step`), every call stores `actions` in the next
    ring buffer slot and returns the action-set that was stored `delay_step` calls earlier.
    If the delay step is a (min, max) tuple, each environment reads back with its own delay,
    taken from `self._delay_step_idx`.

    Args:
        actions: The newest action-set. It is written into one slot of the ring buffer, so it
            must be assignable to `self._delay_ring_buffer[:, head, :]`.

    Returns:
        `actions` unchanged when delay is disabled or the ring buffer has not been allocated.
        Otherwise the delayed action-set, as a tensor that does not share memory with the
        ring buffer.
    """
    ring = self._delay_ring_buffer
    if self._delay_step is None or ring is None:
        return actions

    # Advance the head: this is the slot that receives the new action-set.
    slot_count = ring.shape[1]
    head = (self._delay_ring_buffer_head + 1) % slot_count
    ring[:, head, :] = actions
    self._delay_ring_buffer_head = head

    # Per-environment delay: look back a different number of slots for each environment.
    if isinstance(self._delay_step, tuple):
        # Cast to int64 so the index tensor matches `_envs_idx` and is accepted as an
        # index regardless of the dtype `_delay_step_idx` was allocated with.
        lookback = ((head - self._delay_step_idx) % slot_count).long()
        # Indexing with tensors already yields a copy.
        return ring[self._envs_idx, lookback, :]

    # Fixed delay for all environments.
    lookback = (head - self._delay_step) % slot_count
    # Indexing with a plain integer yields a view, and this slot is overwritten by a later
    # call (the very next one when the buffer holds delay + 1 slots). Clone so that a
    # result kept by the caller does not change afterwards.
    return ring[:, lookback, :].clone()
7 changes: 4 additions & 3 deletions genesis_forge/managers/action/position_action_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class PositionActionManager(BaseActionManager):
use_default_offset: Whether to use default joint positions configured in the articulation asset as offset. Defaults to True.
clip: Clip the action values to the range. If omitted, the action values will automatically be clipped to the joint limits.
quiet_action_errors: Whether to quiet action errors.
delay_step: The number of steps to delay the actions for.
This is an easy way to emulate the latency in the system.
delay_step: The number of steps to delay the actions for, to emulate the latency in the system.
This can be an integer for a fixed delay, or a tuple (min, max) for a per-environment random delay range.

Example::

Expand Down Expand Up @@ -123,7 +123,7 @@ def __init__(
use_default_offset: bool = True,
action_handler: Callable[[torch.Tensor], None] = None,
quiet_action_errors: bool = False,
delay_step: int = 0,
delay_step: int | tuple[int, int] = 0,
**kwargs,
):
super().__init__(env, delay_step)
Expand Down Expand Up @@ -277,6 +277,7 @@ def build(self):
"""
Builds the manager and initialized all the buffers.
"""
super().build()

# Define the clip values
lower_limit, upper_limit = self._actuator_manager.get_dofs_limits()
Expand Down
8 changes: 5 additions & 3 deletions genesis_forge/managers/action/position_within_limits.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ class PositionWithinLimitsActionManager(PositionActionManager):
actuator_manager: The actuator manager which is used to setup and control the DOF joints.
action_handler: A function to handle the actions.
quiet_action_errors: Whether to quiet action errors.
delay_step: The number of steps to delay the actions for.
This is an easy way to emulate the latency in the system.
delay_step: The number of steps to delay the actions for, to emulate the latency in the system.
This can be an integer for a fixed delay, or a tuple (min, max) for a per-environment random delay range.

Example::

Expand Down Expand Up @@ -55,7 +55,7 @@ def __init__(
actuator_manager: ActuatorManager | None = None,
action_handler: Callable[[torch.Tensor], None] = None,
quiet_action_errors: bool = False,
delay_step: int = 0,
delay_step: int | tuple[int, int] = 0,
**kwargs,
):
super().__init__(
Expand All @@ -75,6 +75,8 @@ def build(self):
"""
Builds the manager and initialized all the buffers.
"""
super().build()

lower, upper = self._actuator_manager.get_dofs_limits()
lower = lower.unsqueeze(0).expand(self.env.num_envs, -1)
upper = upper.unsqueeze(0).expand(self.env.num_envs, -1)
Expand Down