@@ -39,18 +39,18 @@ use openshell_core::proto::{
3939 GetClusterInferenceRequest , GetDraftHistoryRequest , GetDraftPolicyRequest ,
4040 GetGatewayConfigRequest , GetProviderProfileRequest , GetProviderRefreshStatusRequest ,
4141 GetProviderRequest , GetSandboxConfigRequest , GetSandboxLogsRequest ,
42- GetSandboxPolicyStatusRequest , GetSandboxRequest , GetServiceRequest , GpuRequestSpec ,
42+ GetSandboxPolicyStatusRequest , GetSandboxRequest , GetServiceRequest , GpuResourceRequirement ,
4343 HealthRequest , ImportProviderProfilesRequest , LintProviderProfilesRequest ,
4444 ListProviderProfilesRequest , ListProvidersRequest , ListSandboxPoliciesRequest ,
4545 ListSandboxProvidersRequest , ListSandboxesRequest , ListServicesRequest , PlatformEvent ,
4646 PolicySource , PolicyStatus , Provider , ProviderCredentialRefreshStatus ,
4747 ProviderCredentialRefreshStrategy , ProviderProfile , ProviderProfileDiagnostic ,
4848 ProviderProfileImportItem , RejectDraftChunkRequest , RevokeSshSessionRequest ,
49- RotateProviderCredentialRequest , Sandbox , SandboxPhase , SandboxPolicy , SandboxSpec ,
50- SandboxTemplate , ServiceEndpointResponse , SetClusterInferenceRequest , SettingScope ,
51- SettingValue , TcpForwardFrame , TcpForwardInit , TcpRelayTarget , UpdateConfigRequest ,
52- UpdateProviderRequest , WatchSandboxRequest , exec_sandbox_event , setting_value ,
53- tcp_forward_init,
49+ RotateProviderCredentialRequest , Sandbox , SandboxPhase , SandboxPolicy ,
50+ SandboxResourceRequirements , SandboxSpec , SandboxTemplate , ServiceEndpointResponse ,
51+ SetClusterInferenceRequest , SettingScope , SettingValue , TcpForwardFrame , TcpForwardInit ,
52+ TcpRelayTarget , UpdateConfigRequest , UpdateProviderRequest , WatchSandboxRequest ,
53+ exec_sandbox_event , setting_value , tcp_forward_init,
5454} ;
5555use openshell_core:: settings:: { self , SettingValueKind } ;
5656use openshell_core:: { ObjectId , ObjectName } ;
@@ -1754,11 +1754,6 @@ pub async fn sandbox_create(
17541754 }
17551755 None => None ,
17561756 } ;
1757- let requested_gpu = gpu
1758- || gpu_device. is_some_and ( |device_id| !device_id. is_empty ( ) )
1759- || gpu_count. is_some ( )
1760- || image. as_deref ( ) . is_some_and ( image_requests_gpu) ;
1761-
17621757 let providers_v2_enabled = gateway_providers_v2_enabled ( & mut client) . await ?;
17631758 let inferred_types: Vec < String > = if providers_v2_enabled {
17641759 Vec :: new ( )
@@ -1775,6 +1770,11 @@ pub async fn sandbox_create(
17751770
17761771 let policy = load_sandbox_policy ( policy) ?;
17771772 let resource_limits = build_sandbox_resource_limits ( cpu, memory) ?;
1773+ let resource_requirements =
1774+ resource_requirements_from_cli ( image. as_deref ( ) , gpu, gpu_device, gpu_count) ;
1775+ let requested_gpu = resource_requirements
1776+ . as_ref ( )
1777+ . is_some_and ( |requirements| requirements. gpu . is_some ( ) ) ;
17781778
17791779 let template = if image. is_some ( ) || resource_limits. is_some ( ) {
17801780 Some ( SandboxTemplate {
@@ -1788,7 +1788,7 @@ pub async fn sandbox_create(
17881788
17891789 let request = CreateSandboxRequest {
17901790 spec : Some ( SandboxSpec {
1791- gpu : gpu_request_from_cli ( requested_gpu , gpu_device , gpu_count ) ,
1791+ resource_requirements ,
17921792 policy,
17931793 providers : configured_providers,
17941794 template,
@@ -2223,17 +2223,26 @@ pub async fn sandbox_create(
22232223 }
22242224}
22252225
2226- fn gpu_request_from_cli (
2227- requested_gpu : bool ,
2226+ fn resource_requirements_from_cli (
2227+ image : Option < & str > ,
2228+ gpu : bool ,
22282229 gpu_device : Option < & str > ,
22292230 gpu_count : Option < u32 > ,
2230- ) -> Option < GpuRequestSpec > {
2231- requested_gpu. then ( || GpuRequestSpec {
2232- device_id : gpu_device
2233- . filter ( |device_id| !device_id. is_empty ( ) )
2234- . map ( |device_id| vec ! [ device_id. to_string( ) ] )
2235- . unwrap_or_default ( ) ,
2236- count : gpu_count,
2231+ ) -> Option < SandboxResourceRequirements > {
2232+ let device_ids = gpu_device
2233+ . filter ( |device_id| !device_id. is_empty ( ) )
2234+ . map ( |device_id| vec ! [ device_id. to_string( ) ] )
2235+ . unwrap_or_default ( ) ;
2236+ let requested_gpu = gpu
2237+ || gpu_count. is_some ( )
2238+ || !device_ids. is_empty ( )
2239+ || image. is_some_and ( image_requests_gpu) ;
2240+
2241+ requested_gpu. then_some ( SandboxResourceRequirements {
2242+ gpu : Some ( GpuResourceRequirement {
2243+ device_ids,
2244+ count : gpu_count,
2245+ } ) ,
22372246 } )
22382247}
22392248
@@ -7486,15 +7495,14 @@ mod tests {
74867495 dockerfile_sources_supported_for_gateway, format_endpoint, format_gateway_select_header,
74877496 format_gateway_select_items, format_provider_attachment_table, gateway_add,
74887497 gateway_auth_label, gateway_env_override_warning, gateway_select_with, gateway_type_label,
7489- git_sync_files, gpu_request_from_cli, http_health_check, image_requests_gpu,
7490- import_local_package_mtls_bundle, inferred_provider_type, mtls_certs_exist_for_gateway,
7491- package_managed_tls_dirs,
7498+ git_sync_files, http_health_check, image_requests_gpu, import_local_package_mtls_bundle,
7499+ inferred_provider_type, mtls_certs_exist_for_gateway, package_managed_tls_dirs,
74927500 parse_cli_setting_value, parse_credential_expiry_cli_value, parse_credential_expiry_pairs,
74937501 parse_credential_pairs, plaintext_gateway_is_remote, progress_step_from_metadata,
74947502 provider_profile_allows_refresh_bootstrap, provisioning_timeout_message,
74957503 ready_false_condition_message, refresh_status_header, refresh_status_row, resolve_from,
7496- sandbox_should_persist , sandbox_upload_plan , service_expose_status_error ,
7497- service_url_for_gateway,
7504+ resource_requirements_from_cli , sandbox_should_persist , sandbox_upload_plan ,
7505+ service_expose_status_error , service_url_for_gateway,
74987506 } ;
74997507 use crate :: TEST_ENV_LOCK ;
75007508 use hyper:: StatusCode ;
@@ -7974,43 +7982,64 @@ mod tests {
79747982 }
79757983
79767984 #[ test]
7977- fn gpu_request_from_cli_uses_presence_with_empty_device_ids_for_default_gpu ( ) {
7978- let request =
7979- gpu_request_from_cli ( true , None , None ) . expect ( "gpu request should be present" ) ;
7985+ fn resource_requirements_from_cli_uses_presence_for_default_gpu ( ) {
7986+ let requirements = resource_requirements_from_cli ( None , true , None , None )
7987+ . expect ( "resource requirements should be present" ) ;
7988+ let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
79807989
7981- assert ! ( request . device_id . is_empty( ) ) ;
7982- assert_eq ! ( request . count, None ) ;
7990+ assert ! ( gpu . device_ids . is_empty( ) ) ;
7991+ assert_eq ! ( gpu . count, None ) ;
79837992 }
79847993
79857994 #[ test]
7986- fn gpu_request_from_cli_maps_gpu_device_to_one_device_id ( ) {
7987- let request = gpu_request_from_cli ( true , Some ( "0000:2d:00.0" ) , None )
7988- . expect ( "gpu request should be present" ) ;
7995+ fn resource_requirements_from_cli_maps_gpu_device_to_one_device_id ( ) {
7996+ let requirements = resource_requirements_from_cli ( None , false , Some ( "0000:2d:00.0" ) , None )
7997+ . expect ( "resource requirements should be present" ) ;
7998+ let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
79897999
7990- assert_eq ! ( request . device_id , vec![ "0000:2d:00.0" ] ) ;
7991- assert_eq ! ( request . count, None ) ;
8000+ assert_eq ! ( gpu . device_ids , vec![ "0000:2d:00.0" ] ) ;
8001+ assert_eq ! ( gpu . count, None ) ;
79928002 }
79938003
79948004 #[ test]
7995- fn gpu_request_from_cli_maps_gpu_count ( ) {
7996- let request = gpu_request_from_cli ( true , None , Some ( 2 ) ) . expect ( "gpu request should exist" ) ;
8005+ fn resource_requirements_from_cli_maps_gpu_count ( ) {
8006+ let requirements = resource_requirements_from_cli ( None , false , None , Some ( 2 ) )
8007+ . expect ( "requirements should exist" ) ;
8008+ let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
79978009
7998- assert ! ( request . device_id . is_empty( ) ) ;
7999- assert_eq ! ( request . count, Some ( 2 ) ) ;
8010+ assert ! ( gpu . device_ids . is_empty( ) ) ;
8011+ assert_eq ! ( gpu . count, Some ( 2 ) ) ;
80008012 }
80018013
80028014 #[ test]
8003- fn gpu_request_from_cli_preserves_device_and_gpu_count_for_gateway_validation ( ) {
8004- let request = gpu_request_from_cli ( true , Some ( "nvidia.com/gpu=0" ) , Some ( 2 ) )
8005- . expect ( "gpu request should exist" ) ;
8015+ fn resource_requirements_from_cli_preserves_device_and_gpu_count_for_gateway_validation ( ) {
8016+ let requirements =
8017+ resource_requirements_from_cli ( None , false , Some ( "nvidia.com/gpu=0" ) , Some ( 2 ) )
8018+ . expect ( "requirements should exist" ) ;
8019+ let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
80068020
8007- assert_eq ! ( request . device_id , vec![ "nvidia.com/gpu=0" ] ) ;
8008- assert_eq ! ( request . count, Some ( 2 ) ) ;
8021+ assert_eq ! ( gpu . device_ids , vec![ "nvidia.com/gpu=0" ] ) ;
8022+ assert_eq ! ( gpu . count, Some ( 2 ) ) ;
80098023 }
80108024
80118025 #[ test]
8012- fn gpu_request_from_cli_omits_gpu_request_when_not_requested ( ) {
8013- assert ! ( gpu_request_from_cli( false , Some ( "0" ) , None ) . is_none( ) ) ;
8026+ fn resource_requirements_from_cli_omits_gpu_request_when_not_requested ( ) {
8027+ assert ! ( resource_requirements_from_cli( None , false , None , None ) . is_none( ) ) ;
8028+ }
8029+
8030+ #[ test]
8031+ fn resource_requirements_from_cli_infers_gpu_from_image ( ) {
8032+ let requirements = resource_requirements_from_cli (
8033+ Some ( "ghcr.io/nvidia/openshell-community/sandboxes/nvidia-gpu:latest" ) ,
8034+ false ,
8035+ None ,
8036+ None ,
8037+ )
8038+ . expect ( "resource requirements should be present" ) ;
8039+ let gpu = requirements. gpu . expect ( "GPU requirement should be present" ) ;
8040+
8041+ assert ! ( gpu. device_ids. is_empty( ) ) ;
8042+ assert_eq ! ( gpu. count, None ) ;
80148043 }
80158044
80168045 #[ test]
0 commit comments