Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 44 additions & 5 deletions dimos/core/coordination/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, Union, get_args, get_origin, get_type_hints

from pydantic import create_model
from pydantic import BaseModel, create_model

if TYPE_CHECKING:
from dimos.protocol.service.system_configurator.base import SystemConfigurator

from dimos.core.global_config import GlobalConfig
from dimos.core.module import ModuleBase, is_module_type
from dimos.core.stream import In, Out
from dimos.core.transport import PubSubTransport
from dimos.core.stream import In, Out, Transport
from dimos.spec.utils import Spec, is_spec
from dimos.utils.logging_config import setup_logger

Expand Down Expand Up @@ -141,6 +140,29 @@ def create(cls, module: type[ModuleBase], kwargs: dict[str, Any]) -> Self:
)


@dataclass(frozen=True)
class TransportSpec:
"""Deferred transport construction: a transport class plus its ctor args.

Blueprint authors declare transports via ``Cls.spec(...)`` so nothing is
constructed at blueprint-definition time. The coordinator materializes
specs at build time, once CLI/env/config overrides have resolved.
"""

cls: type[Transport[Any]]
args: tuple[Any, ...] = ()
kwargs: Mapping[str, Any] = field(default_factory=dict)

@property
def config_cls(self) -> type[BaseModel] | None:
# Set by transports that expose a pydantic config override surface
return getattr(self.cls, "_config_cls", None)

def build(self, config: BaseModel | None = None) -> Transport[Any]:
extra = {"config": config} if config is not None else {}
return self.cls(*self.args, **self.kwargs, **extra)


# These fields cannot be pickled.
_PROXY_FIELDS = ("transport_map", "global_config_overrides", "remapping_map")

Expand All @@ -149,7 +171,7 @@ def create(cls, module: type[ModuleBase], kwargs: dict[str, Any]) -> Self:
class Blueprint:
blueprints: tuple[BlueprintAtom, ...]
disabled_modules_tuple: tuple[type[ModuleBase], ...] = field(default_factory=tuple)
transport_map: Mapping[tuple[str, type], PubSubTransport[Any]] = field(
transport_map: Mapping[tuple[str, type], TransportSpec] = field(
default_factory=lambda: MappingProxyType({})
)
global_config_overrides: Mapping[str, Any] = field(default_factory=lambda: MappingProxyType({}))
Expand Down Expand Up @@ -185,9 +207,22 @@ def config(self) -> type:
for b in self.blueprints
}
configs["g"] = (GlobalConfig | None, None)
transport_fields: dict[str, Any] = {}
seen: set[type] = set()
for spec in self.transport_map.values():
cls = spec.config_cls
if cls is None or cls in seen:
continue
seen.add(cls)
transport_fields[transport_config_name(cls)] = (cls | None, None)
if transport_fields:
transports_model = create_model(
"TransportsConfig", __config__={"extra": "forbid"}, **transport_fields
)
configs["transports"] = (transports_model | None, None)
return create_model("BlueprintConfig", __config__={"extra": "forbid"}, **configs) # type: ignore[call-overload,no-any-return]

def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint":
def transports(self, transports: dict[tuple[str, type], TransportSpec]) -> "Blueprint":
return replace(self, transport_map=MappingProxyType({**self.transport_map, **transports}))

def global_config(self, **kwargs: Any) -> "Blueprint":
Expand Down Expand Up @@ -219,6 +254,10 @@ def active_blueprints(self) -> tuple[BlueprintAtom, ...]:
return tuple(bp for bp in self.blueprints if bp.module not in disabled)


def transport_config_name(cls: type) -> str:
return cls.__name__.removesuffix("Config").lower()


def autoconnect(*blueprints: Blueprint) -> Blueprint:
all_blueprints = tuple(_eliminate_duplicates([bp for bs in blueprints for bp in bs.blueprints]))
all_transports = dict( # type: ignore[var-annotated]
Expand Down
50 changes: 37 additions & 13 deletions dimos/core/coordination/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@
import threading
from typing import TYPE_CHECKING, Any, NamedTuple, cast

from dimos.core.coordination.blueprints import transport_config_name
from dimos.core.coordination.coordinator_rpc import CoordinatorRPC
from dimos.core.coordination.worker_manager import WorkerManager
from dimos.core.coordination.worker_manager_python import WorkerManagerPython
from dimos.core.global_config import GlobalConfig, global_config
from dimos.core.module import ModuleBase, ModuleSpec
from dimos.core.resource import Resource
from dimos.core.stream import Transport
from dimos.core.transport import LCMTransport, PubSubTransport, pLCMTransport
from dimos.spec.utils import is_spec, spec_annotation_compliance, spec_structural_compliance
from dimos.utils.generic import short_id
Expand Down Expand Up @@ -65,9 +67,9 @@ def __init__(
self._deployed_modules = {}
self._deployed_atoms: dict[type[ModuleBase], BlueprintAtom] = {}
self._resolved_module_refs: dict[tuple[type[ModuleBase], str], type[ModuleBase]] = {}
self._transport_registry: dict[tuple[str, type], PubSubTransport[Any]] = {}
self._transport_registry: dict[tuple[str, type], Transport[Any]] = {}
self._class_aliases: dict[type[ModuleBase], type[ModuleBase]] = {}
self._module_transports: dict[type[ModuleBase], dict[str, PubSubTransport[Any]]] = {}
self._module_transports: dict[type[ModuleBase], dict[str, Transport[Any]]] = {}
self._started = False
self._modules_lock = threading.RLock()
self._coordinator_rpc: CoordinatorRPC | None = None
Expand Down Expand Up @@ -249,7 +251,9 @@ def _send_on_system_modules(self) -> None:
if hasattr(module, "on_system_modules"):
module.on_system_modules(modules)

def _connect_streams(self, blueprint: Blueprint) -> None:
def _connect_streams(
self, blueprint: Blueprint, transports: Mapping[tuple[str, type], Transport[Any]]
) -> None:
streams: dict[tuple[str, type], list[tuple[type, str]]] = defaultdict(list)

for bp in blueprint.active_blueprints:
Expand All @@ -263,7 +267,9 @@ def _connect_streams(self, blueprint: Blueprint) -> None:
if key in self._transport_registry:
transport = self._transport_registry[key]
else:
transport = _get_transport_for(blueprint, remapped_name, stream_type)
transport = transports.get(key)
if transport is None:
transport = _get_transport_for(blueprint, remapped_name, stream_type)
self._transport_registry[key] = transport
for module, original_name in streams[key]:
instance = self.get_instance(module) # type: ignore[assignment]
Expand All @@ -290,6 +296,8 @@ def build(
blueprint_args = blueprint_args or {}
if "g" in blueprint_args:
global_config.update(**blueprint_args.pop("g"))
transport_overrides = blueprint_args.pop("transports", None) or {}
transports = _materialize_transports(blueprint, transport_overrides)

_run_configurators(blueprint)
_check_requirements(blueprint)
Expand All @@ -300,7 +308,7 @@ def build(
coordinator.start()

_deploy_all_modules(blueprint, coordinator, global_config, blueprint_args)
coordinator._connect_streams(blueprint)
coordinator._connect_streams(blueprint, transports)
_connect_module_refs(blueprint, coordinator)

coordinator.build_all_modules()
Expand Down Expand Up @@ -337,6 +345,8 @@ def _load_blueprint(
blueprint_args = blueprint_args or {}
if "g" in blueprint_args:
self._global_config.update(**blueprint_args.pop("g"))
transport_overrides = blueprint_args.pop("transports", None) or {}
transports = _materialize_transports(blueprint, transport_overrides)

# Scale worker pool.
n_extra = int(blueprint.global_config_overrides.get("n_workers", 0))
Expand All @@ -361,7 +371,7 @@ def _load_blueprint(
before = set(self._deployed_modules)

_deploy_all_modules(blueprint, self, self._global_config, blueprint_args)
self._connect_streams(blueprint)
self._connect_streams(blueprint, transports)
_connect_module_refs(blueprint, self, existing_modules=before)

new_modules = [proxy for cls, proxy in self._deployed_modules.items() if cls not in before]
Expand Down Expand Up @@ -572,15 +582,29 @@ def _is_name_unique(blueprint: Blueprint, name: str) -> bool:


def _get_transport_for(blueprint: Blueprint, name: str, stream_type: type) -> PubSubTransport[Any]:
transport = blueprint.transport_map.get((name, stream_type), None)
if transport:
return transport

use_pickled = getattr(stream_type, "lcm_encode", None) is None
topic = f"/{name}" if _is_name_unique(blueprint, name) else f"/{short_id()}"
transport = pLCMTransport(topic) if use_pickled else LCMTransport(topic, stream_type)
return pLCMTransport(topic) if use_pickled else LCMTransport(topic, stream_type)


return transport
def _materialize_transports(
blueprint: Blueprint, overrides: Mapping[str, Mapping[str, Any]]
) -> dict[tuple[str, type], Transport[Any]]:
"""Build the blueprint's declared transports, merging CLI/env config overrides.

WebRTC transports get a freshly constructed provider config from the
resolved ``transports.<name>.*`` overrides; everything else builds from the
spec as-is. Returns ready-to-use instances pickled into module workers.
"""
materialized: dict[tuple[str, type], Transport[Any]] = {}
for key, spec in blueprint.transport_map.items():
config = None
config_cls = spec.config_cls
if config_cls is not None:
sub = overrides.get(transport_config_name(config_cls), {})
config = config_cls(**sub)
materialized[key] = spec.build(config=config)
return materialized


def _verify_no_name_conflicts(blueprint: Blueprint) -> None:
Expand Down Expand Up @@ -621,7 +645,7 @@ def _verify_no_name_conflicts(blueprint: Blueprint) -> None:

def _verify_no_conflicts_with_existing(
blueprint: Blueprint,
existing_registry: dict[tuple[str, type], PubSubTransport[Any]],
existing_registry: dict[tuple[str, type], Transport[Any]],
) -> None:
"""Check that a new blueprint's streams don't conflict with already-registered transports."""
if not existing_registry:
Expand Down
8 changes: 5 additions & 3 deletions dimos/core/coordination/test_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,15 @@ def test_config() -> None:


def test_transports() -> None:
custom_transport = LCMTransport("/custom_topic", Data1)
blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()).transports(
{("data1", Data1): custom_transport}
{("data1", Data1): LCMTransport.spec("/custom_topic", Data1)}
)

assert ("data1", Data1) in blueprint_set.transport_map
assert blueprint_set.transport_map[("data1", Data1)] == custom_transport
# TransportSpec compares by value (class + args + kwargs), not identity.
assert blueprint_set.transport_map[("data1", Data1)] == LCMTransport.spec(
"/custom_topic", Data1
)


def test_global_config() -> None:
Expand Down
6 changes: 6 additions & 0 deletions dimos/core/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TypeVar,
)

from pydantic import BaseModel
import reactivex as rx
from reactivex import operators as ops
from reactivex.disposable import Disposable
Expand Down Expand Up @@ -80,6 +81,11 @@ class State(enum.Enum):


class Transport(Resource, ObservableMixin[T]):
# Transports that expose a pydantic config override surface set this to the
# config class; the blueprint config flow picks them up automatically. None
# means "no overridable config" (LCM/SHM transports).
_config_cls: type[BaseModel] | None = None

# used by local Output
def broadcast(self, selfstream: Out[T], value: T) -> None:
raise NotImplementedError
Expand Down
2 changes: 1 addition & 1 deletion dimos/core/test_native_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def test_autoconnect(args_file: str) -> None:
StubProducer.blueprint(),
).transports(
{
("pointcloud", PointCloud2): LCMTransport("/my/custom/lidar", PointCloud2),
("pointcloud", PointCloud2): LCMTransport.spec("/my/custom/lidar", PointCloud2),
},
)

Expand Down
9 changes: 9 additions & 0 deletions dimos/core/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@
if TYPE_CHECKING:
from collections.abc import Callable

from dimos.core.coordination.blueprints import TransportSpec

T = TypeVar("T")

# TODO
Expand Down Expand Up @@ -68,6 +70,13 @@ class PubSubTransport(Transport[T]):
def __init__(self, topic: Any) -> None:
self.topic = topic

@classmethod
def spec(cls, *args: Any, **kwargs: Any) -> TransportSpec:
"""Defer construction: capture ctor args for the coordinator to build later."""
from dimos.core.coordination.blueprints import TransportSpec

return TransportSpec(cls, args, kwargs)

def __str__(self) -> str:
return (
colors.green(f"{self.__class__.__name__}(")
Expand Down
2 changes: 1 addition & 1 deletion dimos/perception/fiducial/blueprints/desk_marker_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def publish_static_chain(self) -> None:
),
).transports(
{
("detections", MarkerDetectionStreamModule): LCMTransport(
("detections", MarkerDetectionStreamModule): LCMTransport.spec(
"/marker_detection/detections",
Detection3DArray,
),
Expand Down
11 changes: 6 additions & 5 deletions dimos/robot/cli/dimos.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,18 @@ def arg_help(

if inspect.isclass(t) and issubclass(t, BaseModel):
output += f"{indent}{module}{k}:\n"
# Find blueprint atom
bp = next(bp for bp in blueprint.blueprints if bp.module.name == k)
# transports.* has no backing blueprint atom — its leaves come from
# the transport configs' own defaults.
bp = next((bp for bp in blueprint.blueprints if bp.module.name == k), None)
output += arg_help(
t, blueprint, indent=indent + " ", module=module + k + ".", _atom=bp
)
else:
assert _atom is not None
# Use __name__ to avoid "<class 'int'>" style output on basic types.
display_type = t.__name__ if isinstance(t, type) else t
required = "[Required] " if info.is_required() and k not in _atom.kwargs else ""
d = _atom.kwargs.get(k, info.default)
in_kwargs = _atom is not None and k in _atom.kwargs
required = "[Required] " if info.is_required() and not in_kwargs else ""
d = _atom.kwargs.get(k, info.default) if _atom is not None else info.default
default = f" (default: {d})" if d is not PydanticUndefined else ""
output += f"{indent}* {required}{module}{k}: {display_type}{default}\n"
return output
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,12 @@
# collide with ControlCoordinator's (joint_state/joint_command/...).
.transports(
{
("motor_states", JointState): LCMTransport("/g1/motor_states", JointState),
("imu", Imu): LCMTransport("/g1/imu", Imu),
("motor_command", MotorCommandArray): LCMTransport(
("motor_states", JointState): LCMTransport.spec("/g1/motor_states", JointState),
("imu", Imu): LCMTransport.spec("/g1/imu", Imu),
("motor_command", MotorCommandArray): LCMTransport.spec(
"/g1/motor_command", MotorCommandArray
),
("joint_command", JointState): LCMTransport("/g1/joint_command", JointState),
("joint_command", JointState): LCMTransport.spec("/g1/joint_command", JointState),
}
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,18 @@ def _viewer() -> Any:
],
).transports(
{
("joint_command", JointState): LCMTransport("/g1/joint_command", JointState),
("twist_command", Twist): LCMTransport("/g1/cmd_vel", Twist),
("tele_cmd_vel", Twist): LCMTransport("/g1/cmd_vel", Twist),
("joint_command", JointState): LCMTransport.spec("/g1/joint_command", JointState),
("twist_command", Twist): LCMTransport.spec("/g1/cmd_vel", Twist),
("tele_cmd_vel", Twist): LCMTransport.spec("/g1/cmd_vel", Twist),
# Real-hw only: the transport_lcm adapter speaks to
# G1WholeBodyConnection over these topics. autoconnect already
# matches by (name, type) so sim doesn't need them -- they're
# harmless when the sim engine doesn't expose those ports.
("motor_states", JointState): LCMTransport("/g1/motor_states", JointState),
("imu", Imu): LCMTransport("/g1/imu", Imu),
("motor_command", MotorCommandArray): LCMTransport("/g1/motor_command", MotorCommandArray),
("motor_states", JointState): LCMTransport.spec("/g1/motor_states", JointState),
("imu", Imu): LCMTransport.spec("/g1/imu", Imu),
("motor_command", MotorCommandArray): LCMTransport.spec(
"/g1/motor_command", MotorCommandArray
),
}
)

Expand Down
Loading
Loading