Skip to content

Commit 836fef1

Browse files
committed
feat(server): connect external compute driver via acquired UDS endpoint
Splits ComputeRuntime construction so that the gateway can dispatch to an out-of-tree compute driver process listening on a Unix domain socket the operator already owns, without adding a fifth `ComputeDriverKind` variant.

* `ComputeRuntime::from_driver` now takes `Option<ComputeDriverKind>`; the four in-tree constructors wrap their kind in `Some(...)`. Out-of-tree drivers pass `None` so callers keyed off the enum skip the in-tree match.
* `connect_external_compute_driver` produces a tonic `Channel` over a pre-existing UDS path, mirroring the vm driver's connector. A `#[cfg(not(unix))]` stub returns the same error the vm path uses.
* `ComputeRuntime::new_remote_external` consumes the channel via the existing `RemoteComputeDriver` proxy and skips both shutdown cleanup and managed-process supervision (the operator owns the lifecycle).
* `from_driver` logs the `driver_name` advertised by `GetCapabilities` whenever `driver_kind` is `None`, so operators can confirm the gateway connected to the driver they expect.
* `build_compute_runtime` short-circuits to the external path when `Config.external_compute_driver_socket` is set, before consulting `--drivers` / auto-detect. The four in-tree backends (Kubernetes, Vm, Docker, Podman) keep their existing dispatch arms unchanged.

Signed-off-by: st-gr <38470677+st-gr@users.noreply.github.com>
1 parent dbab961 commit 836fef1

2 files changed

Lines changed: 111 additions & 12 deletions

File tree

crates/openshell-server/src/compute/mod.rs

Lines changed: 91 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,20 @@ use openshell_driver_kubernetes::{
3535
use openshell_driver_podman::{
3636
ComputeDriverService as PodmanDriverService, PodmanComputeConfig, PodmanComputeDriver,
3737
};
38+
use hyper_util::rt::TokioIo;
3839
use prost::Message;
3940
use std::fmt;
4041
use std::net::SocketAddr;
42+
use std::path::{Path, PathBuf};
4143
use std::pin::Pin;
4244
use std::sync::Arc;
4345
use std::time::Duration;
4446
use tokio::sync::Mutex;
45-
use tonic::transport::Channel;
47+
#[cfg(unix)]
48+
use tokio::net::UnixStream;
49+
use tonic::transport::{Channel, Endpoint};
4650
use tonic::{Code, Request, Status};
51+
use tower::service_fn;
4752
use tracing::{info, warn};
4853

4954
type DriverWatchStream = Pin<Box<dyn Stream<Item = Result<WatchSandboxesEvent, Status>> + Send>>;
@@ -101,11 +106,11 @@ pub use openshell_core::ComputeDriverError as ComputeError;
101106
#[derive(Debug)]
102107
pub struct ManagedDriverProcess {
103108
child: std::sync::Mutex<Option<tokio::process::Child>>,
104-
socket_path: std::path::PathBuf,
109+
socket_path: PathBuf,
105110
}
106111

107112
impl ManagedDriverProcess {
108-
pub(crate) fn new(child: tokio::process::Child, socket_path: std::path::PathBuf) -> Self {
113+
pub(crate) fn new(child: tokio::process::Child, socket_path: PathBuf) -> Self {
109114
Self {
110115
child: std::sync::Mutex::new(Some(child)),
111116
socket_path,
@@ -243,7 +248,7 @@ impl fmt::Debug for ComputeRuntime {
243248
impl ComputeRuntime {
244249
#[allow(clippy::too_many_arguments)]
245250
async fn from_driver(
246-
driver_kind: ComputeDriverKind,
251+
driver_kind: Option<ComputeDriverKind>,
247252
driver: SharedComputeDriver,
248253
shutdown_cleanup: Option<Arc<dyn ShutdownCleanup>>,
249254
startup_resume: Option<Arc<dyn StartupResume>>,
@@ -256,15 +261,24 @@ impl ComputeRuntime {
256261
_allows_loopback_endpoints: bool,
257262
gateway_bind_addresses: Vec<SocketAddr>,
258263
) -> Result<Self, ComputeError> {
259-
let default_image = driver
264+
let capabilities = driver
260265
.get_capabilities(Request::new(GetCapabilitiesRequest {}))
261266
.await
262267
.map_err(compute_error_from_status)?
263-
.into_inner()
264-
.default_image;
268+
.into_inner();
269+
// For out-of-tree drivers (driver_kind = None), log the name the
270+
// driver advertises in GetCapabilities so operators can confirm
271+
// the gateway is talking to the driver they expect.
272+
if driver_kind.is_none() {
273+
info!(
274+
driver_name = %capabilities.driver_name,
275+
"External compute driver connected"
276+
);
277+
}
278+
let default_image = capabilities.default_image;
265279
Ok(Self {
266280
driver,
267-
driver_kind: Some(driver_kind),
281+
driver_kind,
268282
shutdown_cleanup,
269283
startup_resume,
270284
_driver_process: driver_process,
@@ -308,7 +322,7 @@ impl ComputeRuntime {
308322
let startup_resume: Arc<dyn StartupResume> = driver.clone();
309323
let driver: SharedComputeDriver = driver;
310324
Self::from_driver(
311-
ComputeDriverKind::Docker,
325+
Some(ComputeDriverKind::Docker),
312326
driver,
313327
Some(shutdown_cleanup),
314328
Some(startup_resume),
@@ -337,7 +351,7 @@ impl ComputeRuntime {
337351
.map_err(|err| ComputeError::Message(err.to_string()))?;
338352
let driver: SharedComputeDriver = Arc::new(ComputeDriverService::new(driver));
339353
Self::from_driver(
340-
ComputeDriverKind::Kubernetes,
354+
Some(ComputeDriverKind::Kubernetes),
341355
driver,
342356
None,
343357
None,
@@ -364,7 +378,7 @@ impl ComputeRuntime {
364378
) -> Result<Self, ComputeError> {
365379
let driver: SharedComputeDriver = Arc::new(RemoteComputeDriver::new(channel));
366380
Self::from_driver(
367-
ComputeDriverKind::Vm,
381+
Some(ComputeDriverKind::Vm),
368382
driver,
369383
None,
370384
None,
@@ -380,6 +394,39 @@ impl ComputeRuntime {
380394
.await
381395
}
382396

397+
/// Construct a runtime that proxies all sandbox lifecycle to an
398+
/// out-of-tree compute driver listening on a pre-existing UDS endpoint.
399+
///
400+
/// The driver process is operator-managed (not spawned by the gateway),
401+
/// so no [`ManagedDriverProcess`] handle is attached. The advertised
402+
/// `driver_name` from `GetCapabilities` is logged for diagnostics by
403+
/// [`Self::from_driver`].
404+
pub(crate) async fn new_remote_external(
405+
channel: Channel,
406+
store: Arc<Store>,
407+
sandbox_index: SandboxIndex,
408+
sandbox_watch_bus: SandboxWatchBus,
409+
tracing_log_bus: TracingLogBus,
410+
supervisor_sessions: Arc<SupervisorSessionRegistry>,
411+
) -> Result<Self, ComputeError> {
412+
let driver: SharedComputeDriver = Arc::new(RemoteComputeDriver::new(channel));
413+
Self::from_driver(
414+
None,
415+
driver,
416+
None,
417+
None,
418+
None,
419+
store,
420+
sandbox_index,
421+
sandbox_watch_bus,
422+
tracing_log_bus,
423+
supervisor_sessions,
424+
true,
425+
Vec::new(),
426+
)
427+
.await
428+
}
429+
383430
pub async fn new_podman(
384431
config: PodmanComputeConfig,
385432
store: Arc<Store>,
@@ -393,7 +440,7 @@ impl ComputeRuntime {
393440
.map_err(|err| ComputeError::Message(err.to_string()))?;
394441
let driver: SharedComputeDriver = Arc::new(PodmanDriverService::new(driver));
395442
Self::from_driver(
396-
ComputeDriverKind::Podman,
443+
Some(ComputeDriverKind::Podman),
397444
driver,
398445
None,
399446
None,
@@ -1250,6 +1297,38 @@ impl ComputeRuntime {
12501297
}
12511298
}
12521299

1300+
/// Connect to an out-of-tree compute driver that is already listening on
1301+
/// `socket_path` and return a tonic `Channel` speaking `compute_driver.proto`.
1302+
///
1303+
/// The gateway does not spawn or own the driver process — the operator is
1304+
/// responsible for placing the driver alongside the gateway and granting the
1305+
/// gateway uid read/write on the socket. The host portion of the URL is
1306+
/// ignored because the connector resolves to the UDS rather than DNS.
1307+
#[cfg(unix)]
1308+
pub async fn connect_external_compute_driver(socket_path: &Path) -> Result<Channel, ComputeError> {
1309+
let socket_path: PathBuf = socket_path.to_path_buf();
1310+
let display_path = socket_path.clone();
1311+
Endpoint::from_static("http://[::]:50051")
1312+
.connect_with_connector(service_fn(move |_: tonic::transport::Uri| {
1313+
let socket_path = socket_path.clone();
1314+
async move { UnixStream::connect(socket_path).await.map(TokioIo::new) }
1315+
}))
1316+
.await
1317+
.map_err(|e| {
1318+
ComputeError::Message(format!(
1319+
"failed to connect to external compute driver socket '{}': {e}",
1320+
display_path.display()
1321+
))
1322+
})
1323+
}
1324+
1325+
/// Stand-in for targets without Unix domain sockets.
///
/// The socket path is ignored and the call always fails with a
/// [`ComputeError::Message`] stating that UDS support is required.
#[cfg(not(unix))]
pub async fn connect_external_compute_driver(_socket_path: &Path) -> Result<Channel, ComputeError> {
    let reason = String::from("the external compute driver requires unix domain socket support");
    Err(ComputeError::Message(reason))
}
1331+
12531332
fn driver_sandbox_from_public(
12541333
sandbox: &Sandbox,
12551334
driver_kind: Option<ComputeDriverKind>,

crates/openshell-server/src/lib.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,6 +703,26 @@ async fn build_compute_runtime(
703703
tracing_log_bus: TracingLogBus,
704704
supervisor_sessions: Arc<supervisor_session::SupervisorSessionRegistry>,
705705
) -> Result<ComputeRuntime> {
706+
if let Some(socket_path) = config.external_compute_driver_socket.as_deref() {
707+
info!(
708+
socket = %socket_path.display(),
709+
"Using external compute driver"
710+
);
711+
let channel = compute::connect_external_compute_driver(socket_path)
712+
.await
713+
.map_err(|e| Error::execution(format!("failed to create compute runtime: {e}")))?;
714+
return ComputeRuntime::new_remote_external(
715+
channel,
716+
store,
717+
sandbox_index,
718+
sandbox_watch_bus,
719+
tracing_log_bus,
720+
supervisor_sessions,
721+
)
722+
.await
723+
.map_err(|e| Error::execution(format!("failed to create compute runtime: {e}")));
724+
}
725+
706726
let driver = configured_compute_driver(config)?;
707727
info!(driver = %driver, "Using compute driver");
708728
warn_if_kubernetes_sandbox_jwt_expiry_disabled(config, driver);

0 commit comments

Comments
 (0)