diff --git a/genesis_forge/managers/action/base.py b/genesis_forge/managers/action/base.py index 327330c..bd215d4 100644 --- a/genesis_forge/managers/action/base.py +++ b/genesis_forge/managers/action/base.py @@ -12,16 +12,35 @@ class BaseActionManager(BaseManager): Args: env: The environment to manage the DOF actuators for. - delay_step: The number of steps to delay the actions for. - This is an easy way to emulate the latency in the system. + delay_step: The number of steps to delay actions for, to emulate the latency in the system. + This can be an integer for a fixed delay, or a tuple (min, max) for a per-environment random delay range. """ - def __init__(self, env: GenesisEnv, delay_step: int = 0): + def __init__(self, env: GenesisEnv, delay_step: int | tuple[int, int] = 0): super().__init__(env, type="action") self._raw_actions = None self._actions = None + self._envs_idx: torch.Tensor | None = None + self._delay_step = delay_step - self._action_delay_buffer = [] + self._delay_ring_buffer_head = 0 + self._delay_ring_buffer: torch.Tensor | None = None + self._delay_step_idx: torch.Tensor | None = None + + # Validate the delay_step tuple + if isinstance(delay_step, tuple): + min_delay, max_delay = delay_step + if min_delay < 0 or max_delay < min_delay: + raise ValueError( + f"Invalid delay_step range: {self._delay_step}. Must be (min, max) where min >= 0 and max >= min" + ) + elif isinstance(delay_step, int): + if delay_step < 0: + raise ValueError( + f"Invalid delay_step: {self._delay_step}. Must be >= 0" + ) + if delay_step == 0: + self._delay_step = None """ Properties @@ -64,17 +83,35 @@ def raw_actions(self) -> torch.Tensor: return torch.zeros((self.env.num_envs, self.num_actions)) return self._raw_actions + """ + Lifecycle Operations + """ + + def build(self): + """Initialize the action delay buffers.""" + self._envs_idx = torch.arange(self.env.num_envs, device=gs.device) + if self._delay_step is not None: + if isinstance(self._delay_step, tuple): + max_delay = self._delay_step[1] + else: + max_delay = self._delay_step + + self._delay_ring_buffer = torch.zeros( + (self.env.num_envs, max_delay + 1, self.num_actions), + dtype=torch.float32, + device=gs.device, + ) + self._delay_step_idx = torch.zeros( + self.env.num_envs, dtype=torch.int32, device=gs.device + ) + def step(self, actions: torch.Tensor) -> None: """ Handle the received actions. """ - # Action delay buffer - if self._delay_step > 0: - self._action_delay_buffer.insert(0, actions) - actions = self._action_delay_buffer.pop() + self._raw_actions = self._apply_action_delay(actions) # Copy the actions into the manager buffer - self._raw_actions = actions if self._actions is None: self._actions = self._raw_actions.clone() else: @@ -83,15 +120,20 @@ def step(self, actions: torch.Tensor) -> None: def reset(self, envs_idx: list[int] | None): """Reset environments.""" - if ( - self._delay_step > 0 - and len(self._action_delay_buffer) < self._delay_step - and self.num_actions > 0 - ): - while len(self._action_delay_buffer) < self._delay_step: - self._action_delay_buffer.append( - torch.zeros((self.env.num_envs, self.num_actions), device=gs.device) - ) + + # Per-environment random action delay + if isinstance(self._delay_step, tuple) and self._delay_step_idx is not None: + min_delay, max_delay = self._delay_step + if envs_idx is None: + envs_idx = self._envs_idx + + self._delay_step_idx[envs_idx] = torch.randint( + min_delay, + max_delay + 1, + (len(envs_idx),), + device=gs.device, + dtype=torch.int32, + ) def get_actions(self) -> torch.Tensor: """ @@ -100,3 +142,36 @@ def get_actions(self) -> torch.Tensor: if self._actions is None: return torch.zeros((self.env.num_envs, self.num_actions)) return self._actions + + """ + Internal Operations + """ + + def _apply_action_delay(self, actions: torch.Tensor) -> torch.Tensor: + """ + When action delay is enabled (via `delay_step`), the actions will be pushed onto a ring buffer, + and then an older action-set will be returned. If the delay step is a tuple range, each environment + will randomly be assigned a delay step within that range. + """ + if self._delay_step is None or self._delay_ring_buffer is None: + return actions + + # Ring buffer head index + # This is the index of the ring buffer that we should add the new action set to + delay_buffer_len = self._delay_ring_buffer.shape[1] + head = (self._delay_ring_buffer_head + 1) % delay_buffer_len + + # Add to delay ring buffer + self._delay_ring_buffer[:, head, :] = actions + self._delay_ring_buffer_head = head + + # Get the buffer index for the delayed actions + if isinstance(self._delay_step, tuple): + idx = (head - self._delay_step_idx) % delay_buffer_len + actions = self._delay_ring_buffer[self._envs_idx, idx, :] + # Fixed delay for all environments + else: + idx = (head - self._delay_step) % delay_buffer_len + actions = self._delay_ring_buffer[:, idx, :] + + return actions diff --git a/genesis_forge/managers/action/position_action_manager.py b/genesis_forge/managers/action/position_action_manager.py index 9fa5a30..c19ccfc 100644 --- a/genesis_forge/managers/action/position_action_manager.py +++ b/genesis_forge/managers/action/position_action_manager.py @@ -44,8 +44,8 @@ class PositionActionManager(BaseActionManager): use_default_offset: Whether to use default joint positions configured in the articulation asset as offset. Defaults to True. clip: Clip the action values to the range. If omitted, the action values will automatically be clipped to the joint limits. quiet_action_errors: Whether to quiet action errors. - delay_step: The number of steps to delay the actions for. - This is an easy way to emulate the latency in the system. + delay_step: The number of steps to delay the actions for, to emulate the latency in the system. + This can be an integer for a fixed delay, or a tuple (min, max) for a per-environment random delay range. Example:: @@ -123,7 +123,7 @@ def __init__( use_default_offset: bool = True, action_handler: Callable[[torch.Tensor], None] = None, quiet_action_errors: bool = False, - delay_step: int = 0, + delay_step: int | tuple[int, int] = 0, **kwargs, ): super().__init__(env, delay_step) @@ -277,6 +277,7 @@ def build(self): """ Builds the manager and initialized all the buffers. """ + super().build() # Define the clip values lower_limit, upper_limit = self._actuator_manager.get_dofs_limits() diff --git a/genesis_forge/managers/action/position_within_limits.py b/genesis_forge/managers/action/position_within_limits.py index d6f21d4..cc6b923 100644 --- a/genesis_forge/managers/action/position_within_limits.py +++ b/genesis_forge/managers/action/position_within_limits.py @@ -16,8 +16,8 @@ class PositionWithinLimitsActionManager(PositionActionManager): actuator_manager: The actuator manager which is used to setup and control the DOF joints. action_handler: A function to handle the actions. quiet_action_errors: Whether to quiet action errors. - delay_step: The number of steps to delay the actions for. - This is an easy way to emulate the latency in the system. + delay_step: The number of steps to delay the actions for, to emulate the latency in the system. + This can be an integer for a fixed delay, or a tuple (min, max) for a per-environment random delay range. Example:: @@ -55,7 +55,7 @@ def __init__( actuator_manager: ActuatorManager | None = None, action_handler: Callable[[torch.Tensor], None] = None, quiet_action_errors: bool = False, - delay_step: int = 0, + delay_step: int | tuple[int, int] = 0, **kwargs, ): super().__init__( @@ -75,6 +75,8 @@ def build(self): """ Builds the manager and initialized all the buffers. """ + super().build() + lower, upper = self._actuator_manager.get_dofs_limits() lower = lower.unsqueeze(0).expand(self.env.num_envs, -1) upper = upper.unsqueeze(0).expand(self.env.num_envs, -1)