Skip to content

Commit 3a244be

Browse files
committed
feat(gpu): add GPU resource count
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 9506f16 commit 3a244be

20 files changed

Lines changed: 1693 additions & 567 deletions

File tree

architecture/compute-runtimes.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,11 +55,9 @@ through the driver configuration. The Helm chart defaults sandbox agents to
5555
`Unconfined` so runtime/default AppArmor profiles do not block supervisor
5656
network namespace setup on AppArmor-enabled nodes.
5757

58-
GPU requests enter the driver layer through `SandboxSpec.gpu` and
59-
`SandboxSpec.gpu_device`. Docker and Podman map default GPU requests to one
60-
concrete NVIDIA CDI device when individual CDI devices are available, use
61-
`nvidia.com/gpu=all` only for WSL2/all-only compatibility, and pass explicit
62-
driver-native device IDs through.
58+
Resource requirements enter the driver layer through `SandboxSpec.resource_requirements`. These include a set of GPU requirements, where a user
59+
can request a specific number of GPUs or the driver-specific default behaviour.
60+
For all in-tree drivers, the default behaviour is equivalent to selecting a single GPU.
6361

6462
VM runtime state paths are derived only from driver-validated sandbox IDs
6563
matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a
@@ -99,8 +97,9 @@ Custom sandbox images must include the agent runtime and any system
9997
dependencies, but they should not need to include the gateway. GPU-capable
10098
images must include the user-space libraries required by the workload. The
10199
runtime still owns GPU device injection. GPU requests are explicit, and can be
102-
refined with a driver-native device identifier; the gateway validates the
103-
request shape and each runtime enforces the GPU allocation modes it supports.
100+
refined with a driver-native device identifier or requested count; the gateway
101+
validates the request shape and each runtime enforces the GPU allocation modes it
102+
supports.
104103

105104
## Deployment Shape
106105

crates/openshell-cli/src/main.rs

Lines changed: 133 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@ struct GatewayContext {
2929
endpoint: String,
3030
}
3131

32+
/// GPU request as parsed from the `--gpu [COUNT]` flag of `sandbox create`.
///
/// Converted into `GpuResourceRequirements` through the `From` impl in this
/// file before being sent on.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
33+
enum GpuCliRequest {
34+
/// `--gpu` given without a count; converts to `count: None`.
DriverDefault,
35+
/// `--gpu COUNT`; converts to `count: Some(COUNT)`. `parse_gpu_request`
/// only produces values greater than 0.
Count(u32),
36+
}
37+
38+
/// Maps the CLI-level GPU request onto the `count` field of
/// `GpuResourceRequirements`.
impl From<GpuCliRequest> for GpuResourceRequirements {
39+
/// Returns requirements with `count: Some(n)` for `Count(n)` and
/// `count: None` for `DriverDefault`.
fn from(gpu: GpuCliRequest) -> Self {
40+
match gpu {
41+
GpuCliRequest::Count(count) => Self { count: Some(count) },
42+
// No explicit count: leave the field unset.
GpuCliRequest::DriverDefault => Self { count: None },
43+
}
44+
}
45+
}
46+
3247
/// Resolve the gateway name to a [`GatewayContext`] with the gateway endpoint.
3348
///
3449
/// Resolution priority:
@@ -110,6 +125,21 @@ fn resolve_gateway(
110125
})
111126
}
112127

128+
/// Parses the raw value of the `--gpu` flag into a [`GpuCliRequest`].
///
/// Used as the clap `value_parser` for the `gpu` argument of
/// `SandboxCommands::Create`.
///
/// * An empty string yields `GpuCliRequest::DriverDefault`. The `gpu`
///   argument declares `default_missing_value = ""`, so this is the value
///   seen when `--gpu` is passed without a count.
/// * Any other value must parse as a `u32` greater than 0 and yields
///   `GpuCliRequest::Count`.
///
/// # Errors
///
/// Returns a message string when the value is not a valid `u32` or is `0`.
fn parse_gpu_request(value: &str) -> std::result::Result<GpuCliRequest, String> {
129+
// Empty value: no count was supplied.
if value.is_empty() {
130+
return Ok(GpuCliRequest::DriverDefault);
131+
}
132+
133+
let count = value
134+
.parse::<u32>()
135+
.map_err(|_| "GPU count must be a positive integer".to_string())?;
136+
// `u32` parsing accepts 0, so it is rejected separately.
if count == 0 {
137+
return Err("GPU count must be greater than 0".to_string());
138+
}
139+
140+
Ok(GpuCliRequest::Count(count))
141+
}
142+
113143
fn resolve_gateway_name(gateway_flag: &Option<String>) -> Option<String> {
114144
gateway_flag
115145
.clone()
@@ -1217,8 +1247,11 @@ enum SandboxCommands {
12171247
editor: Option<CliEditor>,
12181248

12191249
/// Request GPU resources for the sandbox.
1220-
#[arg(long)]
1221-
gpu: bool,
1250+
///
1251+
/// Omit COUNT for the driver's default GPU selection, or pass COUNT
1252+
/// to request a specific number of GPUs.
1253+
#[arg(long, num_args = 0..=1, value_name = "COUNT", default_missing_value = "", value_parser = parse_gpu_request)]
1254+
gpu: Option<GpuCliRequest>,
12221255

12231256
/// CPU limit for the sandbox (for example: 500m, 1, 2.5).
12241257
#[arg(long)]
@@ -2626,7 +2659,7 @@ async fn main() -> Result<()> {
26262659
.map(|s| openshell_core::forward::ForwardSpec::parse(&s))
26272660
.transpose()?;
26282661
let keep = keep || !no_keep || editor.is_some() || forward.is_some();
2629-
let gpu_requirements = gpu.then_some(GpuResourceRequirements {});
2662+
let gpu_requirements: Option<GpuResourceRequirements> = gpu.map(Into::into);
26302663

26312664
let ctx = resolve_gateway(&cli.gateway, &cli.gateway_endpoint)?;
26322665
let endpoint = &ctx.endpoint;
@@ -3636,6 +3669,27 @@ mod tests {
36363669
});
36373670
}
36383671

3672+
// An absent `--gpu` flag (`None`) must stay `None` after the
// `Option::map(Into::into)` conversion used in `main`.
#[test]
3673+
fn gpu_cli_request_option_maps_absent_gpu_to_no_requirements() {
3674+
let gpu: Option<GpuResourceRequirements> = Option::<GpuCliRequest>::None.map(Into::into);
3675+
3676+
assert_eq!(gpu, None);
3677+
}
3678+
3679+
// `DriverDefault` converts to requirements with no explicit count.
#[test]
3680+
fn gpu_cli_request_driver_default_converts_to_requirements() {
3681+
let gpu = GpuResourceRequirements::from(GpuCliRequest::DriverDefault);
3682+
3683+
assert_eq!(gpu.count, None);
3684+
}
3685+
3686+
// `Count(n)` converts to requirements carrying `Some(n)`.
#[test]
3687+
fn gpu_cli_request_count_converts_to_requirements() {
3688+
let gpu = GpuResourceRequirements::from(GpuCliRequest::Count(2));
3689+
3690+
assert_eq!(gpu.count, Some(2));
3691+
}
3692+
36393693
#[test]
36403694
fn apply_auth_uses_stored_token() {
36413695
let tmp = tempfile::tempdir().unwrap();
@@ -4507,7 +4561,23 @@ mod tests {
45074561
command: Some(SandboxCommands::Create { gpu, .. }),
45084562
..
45094563
}) => {
4510-
assert!(gpu);
4564+
assert_eq!(gpu, Some(GpuCliRequest::DriverDefault));
4565+
}
4566+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4567+
}
4568+
}
4569+
4570+
#[test]
4571+
fn sandbox_create_gpu_count_parses_from_gpu_flag() {
4572+
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "2"])
4573+
.expect("sandbox create --gpu 2 should parse");
4574+
4575+
match cli.command {
4576+
Some(Commands::Sandbox {
4577+
command: Some(SandboxCommands::Create { gpu, .. }),
4578+
..
4579+
}) => {
4580+
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
45114581
}
45124582
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
45134583
}
@@ -4523,13 +4593,71 @@ mod tests {
45234593
command: Some(SandboxCommands::Create { gpu, command, .. }),
45244594
..
45254595
}) => {
4526-
assert!(gpu);
4596+
assert_eq!(gpu, Some(GpuCliRequest::DriverDefault));
4597+
assert_eq!(command, vec!["claude".to_string()]);
4598+
}
4599+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4600+
}
4601+
}
4602+
4603+
// `--gpu 2 -- claude`: the count is taken by `--gpu`, and the argument after
// `--` is expected to end up in `command`.
#[test]
4604+
fn sandbox_create_gpu_count_allows_trailing_command() {
4605+
let cli = Cli::try_parse_from([
4606+
"openshell",
4607+
"sandbox",
4608+
"create",
4609+
"--gpu",
4610+
"2",
4611+
"--",
4612+
"claude",
4613+
])
4614+
.expect("sandbox create --gpu 2 -- claude should parse");
4615+
4616+
match cli.command {
4617+
Some(Commands::Sandbox {
4618+
command: Some(SandboxCommands::Create { gpu, command, .. }),
4619+
..
4620+
}) => {
4621+
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
45274622
assert_eq!(command, vec!["claude".to_string()]);
45284623
}
45294624
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
45304625
}
45314626
}
45324627

4628+
// `--gpu 0` must fail at argument parsing; `parse_gpu_request` returns an
// error for a zero count.
#[test]
4629+
fn sandbox_create_gpu_count_rejects_zero() {
4630+
let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "0"]);
4631+
4632+
assert!(result.is_err(), "sandbox create --gpu 0 should be rejected");
4633+
}
4634+
4635+
// The `--gpu=2` form must parse to the same `Count(2)` request that the
// space-separated `--gpu 2` form is expected to produce.
#[test]
4636+
fn sandbox_create_gpu_count_accepts_equals_syntax() {
4637+
let cli = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu=2"])
4638+
.expect("sandbox create --gpu=2 should parse");
4639+
4640+
match cli.command {
4641+
Some(Commands::Sandbox {
4642+
command: Some(SandboxCommands::Create { gpu, .. }),
4643+
..
4644+
}) => {
4645+
assert_eq!(gpu, Some(GpuCliRequest::Count(2)));
4646+
}
4647+
other => panic!("expected SandboxCommands::Create, got: {other:?}"),
4648+
}
4649+
}
4650+
4651+
// A value that does not parse as `u32` (here `many`) must be rejected at
// argument parsing.
#[test]
4652+
fn sandbox_create_gpu_count_rejects_non_integer() {
4653+
let result = Cli::try_parse_from(["openshell", "sandbox", "create", "--gpu", "many"]);
4654+
4655+
assert!(
4656+
result.is_err(),
4657+
"sandbox create --gpu many should be rejected"
4658+
);
4659+
}
4660+
45334661
#[test]
45344662
fn service_expose_accepts_positional_target_port_and_service() {
45354663
let cli = Cli::try_parse_from([

crates/openshell-cli/src/run.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8395,7 +8395,7 @@ mod tests {
83958395
#[test]
83968396
fn provisioning_timeout_message_includes_condition_and_gpu_hint() {
83978397
let resource_requirements = ResourceRequirements {
8398-
gpu: Some(GpuResourceRequirements {}),
8398+
gpu: Some(GpuResourceRequirements { count: None }),
83998399
};
84008400
let message = provisioning_timeout_message(
84018401
120,

crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1088,8 +1088,8 @@ fn test_tls(server: &TestServer) -> TlsOptions {
10881088
server.tls.with_gateway_name("openshell")
10891089
}
10901090

1091-
fn gpu_requirements() -> GpuResourceRequirements {
1092-
GpuResourceRequirements {}
1091+
fn gpu_requirements(count: Option<u32>) -> GpuResourceRequirements {
1092+
GpuResourceRequirements { count }
10931093
}
10941094

10951095
#[tokio::test]
@@ -1301,7 +1301,7 @@ async fn sandbox_create_sends_gpu_default_request() {
13011301
"openshell",
13021302
&[],
13031303
true,
1304-
Some(gpu_requirements()),
1304+
Some(gpu_requirements(None)),
13051305
None,
13061306
None,
13071307
None,
@@ -1328,11 +1328,53 @@ async fn sandbox_create_sends_gpu_default_request() {
13281328
.and_then(|requirements| requirements.gpu.as_ref())
13291329
.expect("GPU requirement should be sent");
13301330

1331-
assert!(requests[0]
1331+
assert_eq!(gpu.count, None);
1332+
}
1333+
1334+
#[tokio::test]
1335+
async fn sandbox_create_sends_gpu_count_request() {
1336+
let server = run_server().await;
1337+
let fake_ssh_dir = tempfile::tempdir().unwrap();
1338+
let xdg_dir = tempfile::tempdir().unwrap();
1339+
let _env = test_env(&fake_ssh_dir, &xdg_dir);
1340+
let tls = test_tls(&server);
1341+
install_fake_ssh(&fake_ssh_dir);
1342+
1343+
run::sandbox_create(
1344+
&server.endpoint,
1345+
Some("gpu-two"),
1346+
None,
1347+
"openshell",
1348+
&[],
1349+
true,
1350+
Some(gpu_requirements(Some(2))),
1351+
None,
1352+
None,
1353+
None,
1354+
None,
1355+
&[],
1356+
None,
1357+
None,
1358+
&["echo".to_string(), "OK".to_string()],
1359+
Some(false),
1360+
Some(false),
1361+
&HashMap::new(),
1362+
&HashMap::new(),
1363+
"manual",
1364+
&tls,
1365+
)
1366+
.await
1367+
.expect("sandbox create should succeed");
1368+
1369+
let requests = create_requests(&server).await;
1370+
let gpu = requests[0]
13321371
.spec
13331372
.as_ref()
13341373
.and_then(|spec| spec.resource_requirements.as_ref())
1335-
.is_some_and(|requirements| requirements.gpu.is_some()));
1374+
.and_then(|requirements| requirements.gpu.as_ref())
1375+
.expect("GPU requirement should be sent");
1376+
1377+
assert_eq!(gpu.count, Some(2));
13361378
}
13371379

13381380
#[tokio::test]

0 commit comments

Comments
 (0)