Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions architecture/compute-runtimes.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,9 @@ through the driver configuration. The Helm chart defaults sandbox agents to
`Unconfined` so runtime/default AppArmor profiles do not block supervisor
network namespace setup on AppArmor-enabled nodes.

GPU requests enter the driver layer through `SandboxSpec.gpu` and
`SandboxSpec.gpu_device`. Docker and Podman map default GPU requests to one
concrete NVIDIA CDI device when individual CDI devices are available, use
`nvidia.com/gpu=all` only for WSL2/all-only compatibility, and pass explicit
driver-native device IDs through.
Resource requirements enter the driver layer through `SandboxSpec.resource_requirements`. This includes a set of GPU requirements, where a user
can request a specific number of GPUs or the driver-specific default behaviour.
For all in-tree drivers, this is equivalent to selecting a single GPU.

VM runtime state paths are derived only from driver-validated sandbox IDs
matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a
Expand Down Expand Up @@ -98,7 +96,10 @@ users.
Custom sandbox images must include the agent runtime and any system
dependencies, but they should not need to include the gateway. GPU-capable
images must include the user-space libraries required by the workload. The
runtime still owns GPU device injection.
runtime still owns GPU device injection. GPU requests are explicit, and can be
refined with a driver-native device identifier or requested count; the gateway
validates the request shape and each runtime enforces the GPU allocation modes it
supports.

## Deployment Shape

Expand Down
169 changes: 166 additions & 3 deletions crates/openshell-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use openshell_bootstrap::{
use openshell_cli::completers;
use openshell_cli::run;
use openshell_cli::tls::TlsOptions;
use openshell_core::proto::GpuResourceRequirements;

/// Resolved gateway context: name + gateway endpoint.
struct GatewayContext {
Expand All @@ -28,6 +29,21 @@ struct GatewayContext {
endpoint: String,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum GpuCliRequest {
DriverDefault,
Count(u32),
}

impl From<GpuCliRequest> for GpuResourceRequirements {
fn from(gpu: GpuCliRequest) -> Self {
match gpu {
GpuCliRequest::Count(count) => Self { count: Some(count) },
GpuCliRequest::DriverDefault => Self { count: None },
}
}
}

/// Resolve the gateway name to a [`GatewayContext`] with the gateway endpoint.
///
/// Resolution priority:
Expand Down Expand Up @@ -109,6 +125,21 @@ fn resolve_gateway(
})
}

fn parse_gpu_request(value: &str) -> std::result::Result<GpuCliRequest, String> {
if value.is_empty() {
return Ok(GpuCliRequest::DriverDefault);
}

let count = value
.parse::<u32>()
.map_err(|_| "GPU count must be a positive integer".to_string())?;
if count == 0 {
return Err("GPU count must be greater than 0".to_string());
}

Ok(GpuCliRequest::Count(count))
}

fn resolve_gateway_name(gateway_flag: &Option<String>) -> Option<String> {
gateway_flag
.clone()
Expand Down Expand Up @@ -1216,8 +1247,11 @@ enum SandboxCommands {
editor: Option<CliEditor>,

/// Request GPU resources for the sandbox.
#[arg(long)]
gpu: bool,
///
/// Omit COUNT for the driver's default GPU selection, or pass COUNT
/// to request a specific number of GPUs.
#[arg(long, num_args = 0..=1, value_name = "COUNT", default_missing_value = "", value_parser = parse_gpu_request)]
gpu: Option<GpuCliRequest>,

/// CPU limit for the sandbox (for example: 500m, 1, 2.5).
#[arg(long)]
Expand Down Expand Up @@ -2625,6 +2659,7 @@ async fn main() -> Result<()> {
.map(|s| openshell_core::forward::ForwardSpec::parse(&s))
.transpose()?;
let keep = keep || !no_keep || editor.is_some() || forward.is_some();
let gpu_requirements: Option<GpuResourceRequirements> = gpu.map(Into::into);

let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?;
let endpoint = &ctx.endpoint;
Expand All @@ -2637,7 +2672,7 @@ async fn main() -> Result<()> {
&ctx.name,
&upload_specs,
keep,
gpu,
gpu_requirements,
cpu.as_deref(),
memory.as_deref(),
driver_config_json.as_deref(),
Expand Down Expand Up @@ -3634,6 +3669,27 @@ mod tests {
});
}

#[test]
fn gpu_cli_request_option_maps_absent_gpu_to_no_requirements() {
let gpu: Option<GpuResourceRequirements> = Option::<GpuCliRequest>::None.map(Into::into);

assert_eq!(gpu, None);
}

#[test]
fn gpu_cli_request_driver_default_converts_to_requirements() {
let gpu = GpuResourceRequirements::from(GpuCliRequest::DriverDefault);

assert_eq!(gpu.count, None);
}

#[test]
fn gpu_cli_request_count_converts_to_requirements() {
let gpu = GpuResourceRequirements::from(GpuCliRequest::Count(2));

assert_eq!(gpu.count, Some(2));
}

#[test]
fn apply_auth_uses_stored_token() {
let tmp = tempfile::tempdir().unwrap();
Expand Down Expand Up @@ -4495,6 +4551,113 @@ mod tests {
}
}

#[test]
fn sandbox_create_gpu_parses_driver_default() {
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu"])
.expect("sandbox create --gpu should parse");

match cli.command {
Some(Commands::Sandbox {
command: Some(SandboxCommands::Create { gpu, .. }),
..
}) => {
assert_eq!(gpu, Some(GpuCliRequest::DriverDefault));
}
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
}
}

#[test]
fn sandbox_create_gpu_count_parses_from_gpu_flag() {
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "2"])
.expect("sandbox create --gpu 2 should parse");

match cli.command {
Some(Commands::Sandbox {
command: Some(SandboxCommands::Create { gpu, .. }),
..
}) => {
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
}
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
}
}

#[test]
fn sandbox_create_gpu_driver_default_allows_trailing_command() {
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "--", "claude"])
.expect("sandbox create --gpu -- claude should parse");

match cli.command {
Some(Commands::Sandbox {
command: Some(SandboxCommands::Create { gpu, command, .. }),
..
}) => {
assert_eq!(gpu, Some(GpuCliRequest::DriverDefault));
assert_eq!(command, vec!["claude".to_string()]);
}
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
}
}

#[test]
fn sandbox_create_gpu_count_allows_trailing_command() {
let cli = Cli::try_parse_from([
"openshell",
"sandbox",
"create",
"--gpu",
"2",
"--",
"claude",
])
.expect("sandbox create --gpu 2 -- claude should parse");

match cli.command {
Some(Commands::Sandbox {
command: Some(SandboxCommands::Create { gpu, command, .. }),
..
}) => {
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
assert_eq!(command, vec!["claude".to_string()]);
}
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
}
}

#[test]
fn sandbox_create_gpu_count_rejects_zero() {
let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "0"]);

assert!(result.is_err(), "sandbox create --gpu 0 should be rejected");
}

#[test]
fn sandbox_create_gpu_count_accepts_equals_syntax() {
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu=2"])
.expect("sandbox create --gpu=2 should parse");

match cli.command {
Some(Commands::Sandbox {
command: Some(SandboxCommands::Create { gpu, .. }),
..
}) => {
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
}
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
}
}

#[test]
fn sandbox_create_gpu_count_rejects_non_integer() {
let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "many"]);

assert!(
result.is_err(),
"sandbox create --gpu many should be rejected"
);
}

#[test]
fn service_expose_accepts_positional_target_port_and_service() {
let cli = Cli::try_parse_from([
Expand Down
51 changes: 32 additions & 19 deletions crates/openshell-cli/src/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ use openshell_core::proto::{
GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest,
GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest,
GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest,
GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, HealthRequest,
ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest,
ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest,
ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider,
ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile,
ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest,
GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirements,
HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest,
ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest,
ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent,
PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus,
ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic,
ProviderProfileImportItem, RejectDraftChunkRequest, ResourceRequirements,
RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy,
SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest,
SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget,
Expand Down Expand Up @@ -123,7 +124,7 @@ fn ready_false_condition_message(

fn provisioning_timeout_message(
timeout_secs: u64,
requested_gpu: bool,
resource_requirements: Option<&ResourceRequirements>,
condition_message: Option<&str>,
) -> String {
let mut message = format!("sandbox provisioning timed out after {timeout_secs}s");
Expand All @@ -133,7 +134,7 @@ fn provisioning_timeout_message(
message.push_str(condition_message);
}

if requested_gpu {
if resource_requirements.is_some_and(|requirements| requirements.gpu.is_some()) {
message.push_str(
". Hint: this may be because the available GPU is already in use by another sandbox.",
);
Expand Down Expand Up @@ -1753,7 +1754,7 @@ pub async fn sandbox_create(
gateway_name: &str,
uploads: &[(String, Option<String>, bool)],
keep: bool,
gpu: bool,
gpu_requirements: Option<GpuResourceRequirements>,
cpu: Option<&str>,
memory: Option<&str>,
driver_config_json: Option<&str>,
Expand Down Expand Up @@ -1809,8 +1810,6 @@ pub async fn sandbox_create(
}
None => None,
};
let requested_gpu = gpu;

let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?;
let inferred_types: Vec<String> = if providers_v2_enabled {
Vec::new()
Expand Down Expand Up @@ -1842,9 +1841,11 @@ pub async fn sandbox_create(
None
};

let resource_requirements = gpu_requirements.map(|gpu| ResourceRequirements { gpu: Some(gpu) });

let request = CreateSandboxRequest {
spec: Some(SandboxSpec {
gpu: requested_gpu,
resource_requirements,
environment: environment.clone(),
policy,
providers: configured_providers,
Expand Down Expand Up @@ -1989,7 +1990,7 @@ pub async fn sandbox_create(
if remaining.is_zero() {
let timeout_message = provisioning_timeout_message(
provision_timeout.as_secs(),
requested_gpu,
resource_requirements.as_ref(),
last_condition_message.as_deref(),
);
if let Some(d) = display.as_mut() {
Expand All @@ -2008,7 +2009,7 @@ pub async fn sandbox_create(
// Timeout fired — the stream was idle for too long.
let timeout_message = provisioning_timeout_message(
provision_timeout.as_secs(),
requested_gpu,
resource_requirements.as_ref(),
last_condition_message.as_deref(),
);
if let Some(d) = display.as_mut() {
Expand Down Expand Up @@ -7686,9 +7687,10 @@ mod tests {
PROGRESS_STEP_STARTING_SANDBOX,
};
use openshell_core::proto::{
Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus,
ProviderCredentialRefreshStrategy, ProviderCredentialTokenGrant, ProviderProfile,
ProviderProfileCredential, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta,
GpuResourceRequirements, Provider, ProviderCredentialRefresh,
ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy,
ProviderCredentialTokenGrant, ProviderProfile, ProviderProfileCredential,
ResourceRequirements, SandboxCondition, SandboxStatus, datamodel::v1::ObjectMeta,
};

struct EnvVarGuard {
Expand Down Expand Up @@ -8392,9 +8394,12 @@ mod tests {

#[test]
fn provisioning_timeout_message_includes_condition_and_gpu_hint() {
let resource_requirements = ResourceRequirements {
gpu: Some(GpuResourceRequirements { count: None }),
};
let message = provisioning_timeout_message(
120,
true,
Some(&resource_requirements),
Some("DependenciesNotReady: Pod exists with phase: Pending; Service Exists"),
);

Expand All @@ -8405,7 +8410,15 @@ mod tests {

#[test]
fn provisioning_timeout_message_omits_gpu_hint_for_non_gpu_requests() {
let message = provisioning_timeout_message(120, false, None);
let message = provisioning_timeout_message(120, None, None);

assert_eq!(message, "sandbox provisioning timed out after 120s");
}

#[test]
fn provisioning_timeout_message_omits_gpu_hint_without_gpu_requirements() {
let resource_requirements = ResourceRequirements { gpu: None };
let message = provisioning_timeout_message(120, Some(&resource_requirements), None);

assert_eq!(message, "sandbox provisioning timed out after 120s");
}
Expand Down
Loading
Loading