Skip to content

Commit acfb3d4

Browse files
committed
refactor(gpu): use resource requirements for GPU requests
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent aa66525 commit acfb3d4

20 files changed

Lines changed: 436 additions & 280 deletions

File tree

architecture/compute-runtimes.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ through the driver configuration. The Helm chart defaults sandbox agents to
4545
`Unconfined` so runtime/default AppArmor profiles do not block supervisor
4646
network namespace setup on AppArmor-enabled nodes.
4747

48+
GPU requests enter the driver layer through
49+
`SandboxSpec.resource_requirements.gpu`.
50+
4851
VM runtime state paths are derived only from driver-validated sandbox IDs
4952
matching `[A-Za-z0-9._-]{1,128}`. The gateway-owned VM driver socket uses a
5053
private `run/` directory plus Unix peer UID/PID checks. Standalone

crates/openshell-cli/src/run.rs

Lines changed: 76 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,18 @@ use openshell_core::proto::{
3939
GetClusterInferenceRequest, GetDraftHistoryRequest, GetDraftPolicyRequest,
4040
GetGatewayConfigRequest, GetProviderProfileRequest, GetProviderRefreshStatusRequest,
4141
GetProviderRequest, GetSandboxConfigRequest, GetSandboxLogsRequest,
42-
GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuRequestSpec,
42+
GetSandboxPolicyStatusRequest, GetSandboxRequest, GetServiceRequest, GpuResourceRequirement,
4343
HealthRequest, ImportProviderProfilesRequest, LintProviderProfilesRequest,
4444
ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest,
4545
ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent,
4646
PolicySource, PolicyStatus, Provider, ProviderCredentialRefreshStatus,
4747
ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileDiagnostic,
4848
ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest,
49-
RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec,
50-
SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope,
51-
SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest,
52-
UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value,
53-
tcp_forward_init,
49+
RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy,
50+
SandboxResourceRequirements, SandboxSpec, SandboxTemplate, ServiceEndpointResponse,
51+
SetClusterInferenceRequest, SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit,
52+
TcpRelayTarget, UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest,
53+
exec_sandbox_event, setting_value, tcp_forward_init,
5454
};
5555
use openshell_core::settings::{self, SettingValueKind};
5656
use openshell_core::{ObjectId, ObjectName};
@@ -1754,11 +1754,6 @@ pub async fn sandbox_create(
17541754
}
17551755
None => None,
17561756
};
1757-
let requested_gpu = gpu
1758-
|| gpu_device.is_some_and(|device_id| !device_id.is_empty())
1759-
|| gpu_count.is_some()
1760-
|| image.as_deref().is_some_and(image_requests_gpu);
1761-
17621757
let providers_v2_enabled = gateway_providers_v2_enabled(&mut client).await?;
17631758
let inferred_types: Vec<String> = if providers_v2_enabled {
17641759
Vec::new()
@@ -1775,6 +1770,11 @@ pub async fn sandbox_create(
17751770

17761771
let policy = load_sandbox_policy(policy)?;
17771772
let resource_limits = build_sandbox_resource_limits(cpu, memory)?;
1773+
let resource_requirements =
1774+
resource_requirements_from_cli(image.as_deref(), gpu, gpu_device, gpu_count);
1775+
let requested_gpu = resource_requirements
1776+
.as_ref()
1777+
.is_some_and(|requirements| requirements.gpu.is_some());
17781778

17791779
let template = if image.is_some() || resource_limits.is_some() {
17801780
Some(SandboxTemplate {
@@ -1788,7 +1788,7 @@ pub async fn sandbox_create(
17881788

17891789
let request = CreateSandboxRequest {
17901790
spec: Some(SandboxSpec {
1791-
gpu: gpu_request_from_cli(requested_gpu, gpu_device, gpu_count),
1791+
resource_requirements,
17921792
policy,
17931793
providers: configured_providers,
17941794
template,
@@ -2223,17 +2223,26 @@ pub async fn sandbox_create(
22232223
}
22242224
}
22252225

2226-
fn gpu_request_from_cli(
2227-
requested_gpu: bool,
2226+
fn resource_requirements_from_cli(
2227+
image: Option<&str>,
2228+
gpu: bool,
22282229
gpu_device: Option<&str>,
22292230
gpu_count: Option<u32>,
2230-
) -> Option<GpuRequestSpec> {
2231-
requested_gpu.then(|| GpuRequestSpec {
2232-
device_id: gpu_device
2233-
.filter(|device_id| !device_id.is_empty())
2234-
.map(|device_id| vec![device_id.to_string()])
2235-
.unwrap_or_default(),
2236-
count: gpu_count,
2231+
) -> Option<SandboxResourceRequirements> {
2232+
let device_ids = gpu_device
2233+
.filter(|device_id| !device_id.is_empty())
2234+
.map(|device_id| vec![device_id.to_string()])
2235+
.unwrap_or_default();
2236+
let requested_gpu = gpu
2237+
|| gpu_count.is_some()
2238+
|| !device_ids.is_empty()
2239+
|| image.is_some_and(image_requests_gpu);
2240+
2241+
requested_gpu.then_some(SandboxResourceRequirements {
2242+
gpu: Some(GpuResourceRequirement {
2243+
device_ids,
2244+
count: gpu_count,
2245+
}),
22372246
})
22382247
}
22392248

@@ -7486,15 +7495,14 @@ mod tests {
74867495
dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header,
74877496
format_gateway_select_items, format_provider_attachment_table, gateway_add,
74887497
gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label,
7489-
git_sync_files, gpu_request_from_cli, http_health_check, image_requests_gpu,
7490-
import_local_package_mtls_bundle, inferred_provider_type, mtls_certs_exist_for_gateway,
7491-
package_managed_tls_dirs,
7498+
git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle,
7499+
inferred_provider_type, mtls_certs_exist_for_gateway, package_managed_tls_dirs,
74927500
parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs,
74937501
parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata,
74947502
provider_profile_allows_refresh_bootstrap, provisioning_timeout_message,
74957503
ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from,
7496-
sandbox_should_persist, sandbox_upload_plan, service_expose_status_error,
7497-
service_url_for_gateway,
7504+
resource_requirements_from_cli, sandbox_should_persist, sandbox_upload_plan,
7505+
service_expose_status_error, service_url_for_gateway,
74987506
};
74997507
use crate::TEST_ENV_LOCK;
75007508
use hyper::StatusCode;
@@ -7974,43 +7982,64 @@ mod tests {
79747982
}
79757983

79767984
#[test]
7977-
fn gpu_request_from_cli_uses_presence_with_empty_device_ids_for_default_gpu() {
7978-
let request =
7979-
gpu_request_from_cli(true, None, None).expect("gpu request should be present");
7985+
fn resource_requirements_from_cli_uses_presence_for_default_gpu() {
7986+
let requirements = resource_requirements_from_cli(None, true, None, None)
7987+
.expect("resource requirements should be present");
7988+
let gpu = requirements.gpu.expect("GPU requirement should be present");
79807989

7981-
assert!(request.device_id.is_empty());
7982-
assert_eq!(request.count, None);
7990+
assert!(gpu.device_ids.is_empty());
7991+
assert_eq!(gpu.count, None);
79837992
}
79847993

79857994
#[test]
7986-
fn gpu_request_from_cli_maps_gpu_device_to_one_device_id() {
7987-
let request = gpu_request_from_cli(true, Some("0000:2d:00.0"), None)
7988-
.expect("gpu request should be present");
7995+
fn resource_requirements_from_cli_maps_gpu_device_to_one_device_id() {
7996+
let requirements = resource_requirements_from_cli(None, false, Some("0000:2d:00.0"), None)
7997+
.expect("resource requirements should be present");
7998+
let gpu = requirements.gpu.expect("GPU requirement should be present");
79897999

7990-
assert_eq!(request.device_id, vec!["0000:2d:00.0"]);
7991-
assert_eq!(request.count, None);
8000+
assert_eq!(gpu.device_ids, vec!["0000:2d:00.0"]);
8001+
assert_eq!(gpu.count, None);
79928002
}
79938003

79948004
#[test]
7995-
fn gpu_request_from_cli_maps_gpu_count() {
7996-
let request = gpu_request_from_cli(true, None, Some(2)).expect("gpu request should exist");
8005+
fn resource_requirements_from_cli_maps_gpu_count() {
8006+
let requirements = resource_requirements_from_cli(None, false, None, Some(2))
8007+
.expect("requirements should exist");
8008+
let gpu = requirements.gpu.expect("GPU requirement should be present");
79978009

7998-
assert!(request.device_id.is_empty());
7999-
assert_eq!(request.count, Some(2));
8010+
assert!(gpu.device_ids.is_empty());
8011+
assert_eq!(gpu.count, Some(2));
80008012
}
80018013

80028014
#[test]
8003-
fn gpu_request_from_cli_preserves_device_and_gpu_count_for_gateway_validation() {
8004-
let request = gpu_request_from_cli(true, Some("nvidia.com/gpu=0"), Some(2))
8005-
.expect("gpu request should exist");
8015+
fn resource_requirements_from_cli_preserves_device_and_gpu_count_for_gateway_validation() {
8016+
let requirements =
8017+
resource_requirements_from_cli(None, false, Some("nvidia.com/gpu=0"), Some(2))
8018+
.expect("requirements should exist");
8019+
let gpu = requirements.gpu.expect("GPU requirement should be present");
80068020

8007-
assert_eq!(request.device_id, vec!["nvidia.com/gpu=0"]);
8008-
assert_eq!(request.count, Some(2));
8021+
assert_eq!(gpu.device_ids, vec!["nvidia.com/gpu=0"]);
8022+
assert_eq!(gpu.count, Some(2));
80098023
}
80108024

80118025
#[test]
8012-
fn gpu_request_from_cli_omits_gpu_request_when_not_requested() {
8013-
assert!(gpu_request_from_cli(false, Some("0"), None).is_none());
8026+
fn resource_requirements_from_cli_omits_gpu_request_when_not_requested() {
8027+
assert!(resource_requirements_from_cli(None, false, None, None).is_none());
8028+
}
8029+
8030+
#[test]
8031+
fn resource_requirements_from_cli_infers_gpu_from_image() {
8032+
let requirements = resource_requirements_from_cli(
8033+
Some("ghcr.io/nvidia/openshell-community/sandboxes/nvidia-gpu:latest"),
8034+
false,
8035+
None,
8036+
None,
8037+
)
8038+
.expect("resource requirements should be present");
8039+
let gpu = requirements.gpu.expect("GPU requirement should be present");
8040+
8041+
assert!(gpu.device_ids.is_empty());
8042+
assert_eq!(gpu.count, None);
80148043
}
80158044

80168045
#[test]

crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,7 @@ async fn sandbox_create_sends_gpu_device_request_without_gpu_flag() {
907907
None,
908908
None,
909909
None,
910+
None,
910911
&[],
911912
None,
912913
None,
@@ -924,10 +925,11 @@ async fn sandbox_create_sends_gpu_device_request_without_gpu_flag() {
924925
let gpu = requests[0]
925926
.spec
926927
.as_ref()
927-
.and_then(|spec| spec.gpu.as_ref())
928+
.and_then(|spec| spec.resource_requirements.as_ref())
929+
.and_then(|requirements| requirements.gpu.as_ref())
928930
.expect("GPU request should be sent");
929931

930-
assert_eq!(gpu.device_id, vec!["nvidia.com/gpu=0"]);
932+
assert_eq!(gpu.device_ids, vec!["nvidia.com/gpu=0"]);
931933
assert_eq!(gpu.count, None);
932934
}
933935

@@ -970,10 +972,11 @@ async fn sandbox_create_sends_gpu_count_request() {
970972
let gpu = requests[0]
971973
.spec
972974
.as_ref()
973-
.and_then(|spec| spec.gpu.as_ref())
975+
.and_then(|spec| spec.resource_requirements.as_ref())
976+
.and_then(|requirements| requirements.gpu.as_ref())
974977
.expect("GPU request should be sent");
975978

976-
assert!(gpu.device_id.is_empty());
979+
assert!(gpu.device_ids.is_empty());
977980
assert_eq!(gpu.count, Some(2));
978981
}
979982

crates/openshell-core/src/gpu.rs

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,26 @@
44
//! Shared GPU request helpers.
55
66
use crate::config::CDI_GPU_DEVICE_ALL;
7-
use crate::proto::compute::v1::{DriverSandboxSpec, GpuRequestSpec};
7+
use crate::proto::compute::v1::{DriverGpuResourceRequirement, DriverSandboxSpec};
88

9-
/// Extract the driver GPU request from a sandbox spec.
9+
/// Extract the driver GPU requirement from a sandbox spec.
1010
#[must_use]
11-
pub fn driver_gpu_request(spec: &DriverSandboxSpec) -> Option<&GpuRequestSpec> {
12-
spec.gpu.as_ref()
11+
pub fn driver_gpu_requirement(spec: &DriverSandboxSpec) -> Option<&DriverGpuResourceRequirement> {
12+
spec.resource_requirements
13+
.as_ref()
14+
.and_then(|requirements| requirements.gpu.as_ref())
1315
}
1416

1517
/// Resolve a driver GPU request into CDI device identifiers.
1618
///
1719
/// `None` means no GPU was requested. Presence with no explicit device IDs
18-
/// uses the CDI all-GPU request; otherwise the driver-native IDs pass through.
20+
/// uses the CDI all-GPU request, preserving the current default GPU behavior;
21+
/// otherwise the driver-native IDs pass through.
1922
#[must_use]
20-
pub fn cdi_gpu_device_ids(gpu: Option<&GpuRequestSpec>) -> Option<Vec<String>> {
23+
pub fn cdi_gpu_device_ids(gpu: Option<&DriverGpuResourceRequirement>) -> Option<Vec<String>> {
2124
match gpu {
22-
Some(gpu) if gpu.device_id.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]),
23-
Some(gpu) => Some(gpu.device_id.clone()),
25+
Some(gpu) if gpu.device_ids.is_empty() => Some(vec![CDI_GPU_DEVICE_ALL.to_string()]),
26+
Some(gpu) => Some(gpu.device_ids.clone()),
2427
None => None,
2528
}
2629
}
@@ -36,8 +39,8 @@ mod tests {
3639

3740
#[test]
3841
fn cdi_gpu_device_ids_defaults_empty_request_to_all_gpus() {
39-
let request = GpuRequestSpec {
40-
device_id: vec![],
42+
let request = DriverGpuResourceRequirement {
43+
device_ids: vec![],
4144
count: None,
4245
};
4346

@@ -49,8 +52,8 @@ mod tests {
4952

5053
#[test]
5154
fn cdi_gpu_device_ids_passes_single_device_id_through() {
52-
let request = GpuRequestSpec {
53-
device_id: vec!["nvidia.com/gpu=0".to_string()],
55+
let request = DriverGpuResourceRequirement {
56+
device_ids: vec!["nvidia.com/gpu=0".to_string()],
5457
count: None,
5558
};
5659

@@ -62,8 +65,8 @@ mod tests {
6265

6366
#[test]
6467
fn cdi_gpu_device_ids_passes_multiple_device_ids_through() {
65-
let request = GpuRequestSpec {
66-
device_id: vec![
68+
let request = DriverGpuResourceRequirement {
69+
device_ids: vec![
6770
"nvidia.com/gpu=0".to_string(),
6871
"nvidia.com/gpu=1".to_string(),
6972
],

crates/openshell-driver-docker/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ contract:
3232
| `apparmor=unconfined` | Avoids Docker's default profile blocking required mount operations. |
3333
| `restart_policy = unless-stopped` | Keeps managed sandboxes resumable across daemon or gateway restarts. |
3434
| `PidsLimit` | Enforces the sandbox PID budget at the Docker cgroup layer. Set `[openshell.drivers.docker].sandbox_pids_limit = 0` to inherit the Docker/runtime default. |
35-
| CDI GPU request | Uses explicit GPU request device IDs when set; otherwise requests all NVIDIA GPUs when the sandbox spec asks for GPU support and daemon CDI support is detected. Count-based GPU requests are rejected until Docker CDI selection can map counts to concrete devices. |
35+
| CDI GPU request | Uses explicit `resource_requirements.gpu.device_ids` when set; otherwise requests all NVIDIA GPUs when `resource_requirements.gpu` is present and daemon CDI support is detected. Count-based GPU requests are rejected until Docker CDI selection can map counts to concrete devices. |
3636

3737
The agent child process does not retain these supervisor privileges.
3838

crates/openshell-driver-docker/src/lib.rs

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,16 @@ use openshell_core::driver_utils::{
2525
LABEL_MANAGED_BY, LABEL_MANAGED_BY_VALUE, LABEL_SANDBOX_ID, LABEL_SANDBOX_NAME,
2626
LABEL_SANDBOX_NAMESPACE, SUPERVISOR_IMAGE_BINARY_PATH, supervisor_image_should_refresh,
2727
};
28-
use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_request};
28+
use openshell_core::gpu::{cdi_gpu_device_ids, driver_gpu_requirement};
2929
use openshell_core::progress::{
3030
PROGRESS_STEP_PULLING_IMAGE, PROGRESS_STEP_REQUESTING_SANDBOX, PROGRESS_STEP_STARTING_SANDBOX,
3131
format_bytes, mark_progress_active, mark_progress_complete, mark_progress_detail,
3232
};
3333
use openshell_core::proto::compute::v1::{
3434
CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse,
35-
DriverCondition, DriverPlatformEvent, DriverSandbox, DriverSandboxStatus,
36-
DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse, GetSandboxRequest,
37-
GetSandboxResponse, GpuRequestSpec, ListSandboxesRequest, ListSandboxesResponse,
35+
DriverCondition, DriverGpuResourceRequirement, DriverPlatformEvent, DriverSandbox,
36+
DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetCapabilitiesResponse,
37+
GetSandboxRequest, GetSandboxResponse, ListSandboxesRequest, ListSandboxesResponse,
3838
StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest,
3939
ValidateSandboxCreateResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent,
4040
WatchSandboxesPlatformEvent, WatchSandboxesRequest, WatchSandboxesSandboxEvent,
@@ -375,7 +375,7 @@ impl DockerComputeDriver {
375375
"docker sandboxes require a template image",
376376
));
377377
}
378-
Self::validate_gpu_request(driver_gpu_request(spec), config.supports_gpu)?;
378+
Self::validate_gpu_request(driver_gpu_requirement(spec), config.supports_gpu)?;
379379
if !template.agent_socket_path.trim().is_empty() {
380380
return Err(Status::failed_precondition(
381381
"docker compute driver does not support template.agent_socket_path",
@@ -410,7 +410,7 @@ impl DockerComputeDriver {
410410
}
411411

412412
fn validate_gpu_request(
413-
gpu: Option<&GpuRequestSpec>,
413+
gpu: Option<&DriverGpuResourceRequirement>,
414414
supports_gpu: bool,
415415
) -> Result<(), Status> {
416416
if gpu.is_some_and(|gpu| gpu.count.is_some()) {
@@ -1721,7 +1721,9 @@ fn build_environment(sandbox: &DriverSandbox, config: &DockerDriverRuntimeConfig
17211721
.collect()
17221722
}
17231723

1724-
fn docker_gpu_device_requests(gpu: Option<&GpuRequestSpec>) -> Option<Vec<DeviceRequest>> {
1724+
fn docker_gpu_device_requests(
1725+
gpu: Option<&DriverGpuResourceRequirement>,
1726+
) -> Option<Vec<DeviceRequest>> {
17251727
cdi_gpu_device_ids(gpu).map(|device_ids| {
17261728
vec![DeviceRequest {
17271729
driver: Some("cdi".to_string()),
@@ -1773,7 +1775,7 @@ fn build_container_create_body(
17731775
nano_cpus: resource_limits.nano_cpus,
17741776
memory: resource_limits.memory_bytes,
17751777
pids_limit: docker_pids_limit(config.sandbox_pids_limit)?,
1776-
device_requests: docker_gpu_device_requests(driver_gpu_request(spec)),
1778+
device_requests: docker_gpu_device_requests(driver_gpu_requirement(spec)),
17771779
binds: Some(build_binds(sandbox, config)?),
17781780
restart_policy: Some(RestartPolicy {
17791781
name: Some(RestartPolicyNameEnum::UNLESS_STOPPED),

0 commit comments

Comments
 (0)