Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions data/.lfs/g1_urdf.tar.gz
Git LFS file not shown
3 changes: 3 additions & 0 deletions data/.lfs/scene_packages.tar.gz
Git LFS file not shown
27 changes: 27 additions & 0 deletions dimos/control/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"""

from dataclasses import dataclass, field
import inspect
import threading
import time
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -551,6 +552,32 @@ def set_dry_run(self, enabled: bool) -> None:
except Exception:
logger.exception(f"set_dry_run() raised on task {task.name!r}")

@rpc
def reset_runtime_state(self, reactivate: bool | None = None) -> dict[str, bool]:
"""Reset transient state on tasks that expose ``reset_runtime_state``.

This is meant for simulation/runtime discontinuities such as MuJoCo
respawn, where task histories and latched commands must be cleared
without tearing down the coordinator.
"""
results: dict[str, bool] = {}
with self._task_lock:
for task in self._tasks.values():
handler = getattr(task, "reset_runtime_state", None)
if not callable(handler):
results[task.name] = False
continue
try:
params = inspect.signature(handler).parameters
if "reactivate" in params:
results[task.name] = bool(handler(reactivate=reactivate))
else:
results[task.name] = bool(handler())
except Exception:
logger.exception(f"reset_runtime_state() raised on task {task.name!r}")
results[task.name] = False
return results

@rpc
def task_invoke(
self, task_name: TaskName, method: str, kwargs: dict[str, Any] | None = None
Expand Down
62 changes: 60 additions & 2 deletions dimos/control/tasks/g1_groot_wbc_task/g1_groot_wbc_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,26 @@
_NUM_MOTORS = 29


def _preferred_onnx_providers() -> list[str]:
available = ort.get_available_providers()
providers: list[str] = []

if "CUDAExecutionProvider" in available:
preload_dlls = getattr(ort, "preload_dlls", None)
if preload_dlls is not None:
try:
preload_dlls(cuda=True, cudnn=True, msvc=False)
except Exception as exc:
logger.warning(
"Failed to preload ONNXRuntime CUDA/cuDNN libraries",
error=repr(exc),
)
providers.append("CUDAExecutionProvider")

providers.append("CPUExecutionProvider")
return providers


@dataclass
class G1GrootWBCTaskConfig:
"""Configuration for the GR00T WBC task.
Expand Down Expand Up @@ -289,7 +309,7 @@ def __init__(
self._joint_names_set = frozenset(config.joint_names)
self._all_joint_names = list(config.all_joint_names)

providers = ort.get_available_providers()
providers = _preferred_onnx_providers()
self._balance_session = ort.InferenceSession(str(config.balance_onnx), providers=providers)
self._walk_session = ort.InferenceSession(str(config.walk_onnx), providers=providers)
self._balance_input = self._balance_session.get_inputs()[0].name
Expand All @@ -299,7 +319,9 @@ def __init__(
task=name,
balance=str(config.balance_onnx),
walk=str(config.walk_onnx),
providers=providers,
requested_providers=providers,
balance_providers=self._balance_session.get_providers(),
walk_providers=self._walk_session.get_providers(),
)

self._default_29 = np.asarray(config.default_positions_29, dtype=np.float32)
Expand Down Expand Up @@ -663,6 +685,42 @@ def disarm(self) -> bool:
logger.info("G1GrootWBCTask disarmed (holding current pose)", task=self._name)
return True

def reset_runtime_state(self, reactivate: bool | None = None) -> bool:
"""Clear runtime policy state after a simulation discontinuity.

``reactivate=None`` preserves whether the task was armed/arming before
reset. Passing ``True`` forces a clean immediate re-arm on the next
coordinator tick, which is useful after MuJoCo respawn.
"""
was_armed = self._armed or self._arming or self._arm_pending
should_reactivate = was_armed if reactivate is None else bool(reactivate)

self._armed = False
self._arming = False
self._arm_pending = False
self._ramp_start = None
self._arming_start_t = 0.0
self._last_targets = None
self._state_seen = False
self._cached_q_29[:] = self._default_29
self._cached_dq_29[:] = 0.0
self._cached_q_15[:] = self._default_15
self._reset_policy_state()
with self._cmd_lock:
self._cmd[:] = 0.0
self._last_cmd_time = 0.0

if self._active and should_reactivate:
self._arming_duration = 0.0
self._arm_pending = True

logger.info(
"G1GrootWBCTask runtime state reset",
task=self._name,
reactivate=should_reactivate,
)
return True

def set_dry_run(self, enabled: bool) -> None:
"""Enable/disable dry-run.

Expand Down
32 changes: 32 additions & 0 deletions dimos/control/tasks/g1_groot_wbc_task/test_g1_groot_wbc_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,23 @@ def __init__(
label: str,
action: np.ndarray,
call_log: list[str],
providers: Any = None,
) -> None:
self.model_path = model_path
self._label = label
self._action = action
self._call_log = call_log
self._providers = list(providers or [])
fake_input = MagicMock()
fake_input.name = "obs"
self._inputs = [fake_input]

def get_inputs(self) -> list[Any]:
return self._inputs

def get_providers(self) -> list[str]:
return self._providers

def run(self, _outputs: Any, _feed: dict[str, np.ndarray]) -> list[np.ndarray]:
self._call_log.append(self._label)
return [self._action.reshape(1, -1)]
Expand All @@ -77,6 +82,7 @@ def _factory(path: str, providers: Any = None) -> _StubSession:
label=label,
action=np.full(15, 0.1, dtype=np.float32),
call_log=call_log,
providers=providers,
)

monkeypatch.setattr(g1_groot_wbc_task.ort, "InferenceSession", _factory)
Expand Down Expand Up @@ -278,6 +284,32 @@ def test_dry_run_suppresses_output_but_keeps_policy_hot(
assert np.any(task._obs_buf != 0.0)


def test_reset_runtime_state_clears_policy_state_and_rearms(
task: G1GrootWBCTask, joints_29: list[str]
) -> None:
task.start()
task.set_velocity_command(0.5, 0.0, 0.0, t_now=100.0)
for _ in range(10):
task.compute(_state_at(100.0, joints_29))

assert task.state_snapshot()["armed"]
assert np.any(task._obs_buf != 0.0)
assert np.any(task._cmd != 0.0)

assert task.reset_runtime_state(reactivate=True)

snapshot = task.state_snapshot()
assert not snapshot["armed"]
assert snapshot["arm_pending"]
np.testing.assert_array_equal(task._obs_buf, np.zeros_like(task._obs_buf))
np.testing.assert_array_equal(task._last_action, np.zeros_like(task._last_action))
np.testing.assert_array_equal(task._cmd, np.zeros_like(task._cmd))
assert task._last_cmd_time == 0.0

task.compute(_state_at(101.0, joints_29))
assert task.state_snapshot()["armed"]


def test_projected_gravity_matches_reference_quaternion_order() -> None:
np.testing.assert_allclose(
G1GrootWBCTask._projected_gravity((1.0, 0.0, 0.0, 0.0)),
Expand Down
38 changes: 38 additions & 0 deletions dimos/control/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dimos.control.coordinator import ControlCoordinator
from dimos.control.hardware_interface import ConnectedHardware
from dimos.control.task import (
BaseControlTask,
ControlMode,
CoordinatorState,
JointCommandOutput,
Expand Down Expand Up @@ -190,6 +191,43 @@ def test_write_command(self, connected_hardware, mock_adapter):


class TestControlCoordinatorLifecycle:
def test_reset_runtime_state_calls_task_hooks(self):
class ResettableTask(BaseControlTask):
def __init__(self) -> None:
self.reset_reactivate_args: list[bool | None] = []

@property
def name(self) -> str:
return "resettable"

def claim(self) -> ResourceClaim:
return ResourceClaim(joints=frozenset())

def is_active(self) -> bool:
return True

def compute(self, state: CoordinatorState) -> JointCommandOutput | None:
_ = state
return None

def on_preempted(self, by_task: str, joints: frozenset[str]) -> None:
_ = by_task, joints

def reset_runtime_state(self, reactivate: bool | None = None) -> bool:
self.reset_reactivate_args.append(reactivate)
return True

coordinator = ControlCoordinator(publish_joint_state=False)
task = ResettableTask()

try:
assert coordinator.add_task(task)

assert coordinator.reset_runtime_state(reactivate=True) == {"resettable": True}
assert task.reset_reactivate_args == [True]
finally:
coordinator.stop()

def test_start_stop_calls_adapter_activate_and_deactivate(self):
from dimos.hardware.manipulators.mock.adapter import MockAdapter
from dimos.hardware.manipulators.registry import adapter_registry
Expand Down
1 change: 1 addition & 0 deletions dimos/core/global_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class GlobalConfig(BaseSettings):
mujoco_global_map_from_pointcloud: str | None = None
mujoco_start_pos: str = "-1.0, 1.0"
mujoco_steps_per_frame: int = 7
scene: str | None = None
robot_model: str | None = None
robot_id: str | None = None
robot_width: float = 0.3
Expand Down
23 changes: 21 additions & 2 deletions dimos/hardware/whole_body/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,33 @@ def available(self) -> list[str]:
return sorted(self._adapters.keys())

def discover(self) -> None:
"""Discover and register whole-body hardware adapters.
"""Discover and register whole-body adapters.

Walks the hardware whole-body package recursively looking for
``adapter.py`` modules that provide a ``register(registry)`` function.
``adapter.py`` modules, then scans simulation whole-body modules. Any
discovered module can provide a ``register(registry)`` function.
"""
import dimos.hardware.whole_body as hw_pkg
import dimos.simulation.adapters.whole_body as sim_pkg

self._discover_in("dimos.hardware.whole_body", hw_pkg.__path__[0], max_depth=2)
self._discover_simulation_adapters(
"dimos.simulation.adapters.whole_body",
sim_pkg.__path__[0],
)

def _discover_simulation_adapters(self, pkg_path: str, dir_path: str) -> None:
for entry in sorted(os.listdir(dir_path)):
if entry.startswith(("_", ".")) or not entry.endswith(".py"):
continue
module_name = entry.removesuffix(".py")
try:
mod = importlib.import_module(f"{pkg_path}.{module_name}")
except ImportError as e:
logger.warning(f"Skipping whole-body simulation adapter {module_name}: {e}")
continue
if hasattr(mod, "register"):
mod.register(self)

def _discover_in(self, pkg_path: str, dir_path: str, *, max_depth: int) -> None:
for entry in sorted(os.listdir(dir_path)):
Expand Down
Loading
Loading