From b8737033acb6dc110432905a9113a5d70f92d1bd Mon Sep 17 00:00:00 2001 From: Gordon Sim Date: Fri, 19 Jun 2026 18:16:47 +0100 Subject: [PATCH] feat(provider): add ability to request token exchange instead of client credentials as OAuth grant_type Signed-off-by: Gordon Sim --- .../skills/debug-openshell-cluster/SKILL.md | 12 +- Cargo.lock | 10 + Cargo.toml | 2 +- architecture/sandbox.md | 30 +- crates/openshell-cli/src/main.rs | 184 ++- crates/openshell-cli/src/oidc_auth.rs | 22 +- crates/openshell-cli/src/run.rs | 292 +++- .../tests/ensure_providers_integration.rs | 10 +- .../openshell-cli/tests/mtls_integration.rs | 16 +- .../tests/provider_commands_integration.rs | 377 ++++- .../sandbox_create_lifecycle_integration.rs | 10 +- .../sandbox_name_fallback_integration.rs | 10 +- crates/openshell-core/src/grpc_client.rs | 67 +- crates/openshell-core/src/lib.rs | 1 + crates/openshell-core/src/sandbox_env.rs | 6 + crates/openshell-core/src/spiffe.rs | 173 +++ crates/openshell-providers/src/profiles.rs | 356 ++++- crates/openshell-server/Cargo.toml | 2 + .../src/auth/sandbox_methods.rs | 3 + crates/openshell-server/src/grpc/mod.rs | 11 +- crates/openshell-server/src/grpc/provider.rs | 1330 ++++++++++++++++- crates/openshell-server/tests/common/mod.rs | 12 +- .../tests/supervisor_relay_integration.rs | 6 + .../src/l7/token_grant_injection.rs | 55 +- .../openshell-supervisor-network/src/lib.rs | 1 - .../openshell-supervisor-network/src/proxy.rs | 52 +- .../src/spiffe_endpoint.rs | 17 - .../src/token_grant.rs | 402 +++-- .../openshell-supervisor-process/src/run.rs | 12 +- deploy/helm/openshell/README.md | 19 +- deploy/helm/openshell/README.md.gotmpl | 15 +- .../openshell/templates/_gateway-workload.tpl | 17 +- .../openshell/tests/gateway_config_test.yaml | 36 +- deploy/helm/openshell/values.yaml | 12 +- docs/kubernetes/access-control.mdx | 8 +- docs/reference/gateway-config.mdx | 8 + docs/sandboxes/manage-providers.mdx | 27 + docs/sandboxes/providers-v2.mdx | 59 +- proto/openshell.proto | 56 + 39 files changed, 3446 insertions(+), 292 deletions(-) create mode 100644 crates/openshell-core/src/spiffe.rs delete mode 100644 crates/openshell-supervisor-network/src/spiffe_endpoint.rs diff --git a/.agents/skills/debug-openshell-cluster/SKILL.md b/.agents/skills/debug-openshell-cluster/SKILL.md index 68ecc7749..c7c442217 100644 --- a/.agents/skills/debug-openshell-cluster/SKILL.md +++ b/.agents/skills/debug-openshell-cluster/SKILL.md @@ -180,14 +180,20 @@ even when local Helm values disable TLS. If `server.providerTokenGrants.spiffe.enabled=true`, the gateway should still render `[openshell.gateway.gateway_jwt]` and mount the `sandbox-jwt` Secret. -SPIRE is used only by sandbox pods for dynamic provider token grants. Verify -that SPIRE is installed, the CSI driver is available, and the Kubernetes driver -config includes `provider_spiffe_workload_api_socket_path`: +SPIRE is used by both the gateway and sandbox supervisors for dynamic provider +token grants. The gateway pod must mount the `spiffe-workload-api` CSI volume +and set `OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET`; sandbox pods must +receive the matching Workload API socket from the Kubernetes driver config. +The gateway verifies supervisor JWT-SVIDs from JWT bundles fetched through this +Workload API socket, not from the SPIRE OIDC discovery endpoint. +Verify that SPIRE is installed, the CSI driver is available, and the Kubernetes +driver config includes `provider_spiffe_workload_api_socket_path`: ```bash helm -n openshell get values openshell | grep -E 'providerTokenGrants|workloadApiSocketPath' kubectl get pods -A | grep -E 'spire|spiffe' kubectl -n openshell get configmap openshell-config -o yaml | grep provider_spiffe_workload_api_socket_path +kubectl -n openshell get pod -l app.kubernetes.io/name=helm-chart -o jsonpath="{.items[*].spec.containers[*].env[?(@.name==\"OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET\")].value}{\"\n\"}" ``` Sandbox pods using provider token grants should have an diff --git a/Cargo.lock b/Cargo.lock index f693acd66..71a2a6ec1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2550,10 +2550,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0529410abe238729a60b108898784df8984c87f6054c9c4fcacc47e4803c1ce1" dependencies = [ "base64 0.22.1", + "ed25519-dalek", "getrandom 0.2.17", + "hmac", "js-sys", + "p256", + "p384", + "rand 0.8.6", + "rsa 0.9.10", "serde", "serde_json", + "sha2 0.10.9", "signature 2.2.0", ] @@ -3673,6 +3680,7 @@ dependencies = [ "arc-swap", "async-trait", "axum", + "base64 0.22.1", "bytes", "clap", "futures", @@ -3719,6 +3727,7 @@ dependencies = [ "serde", "serde_json", "sha2 0.10.9", + "spiffe", "sqlx", "tempfile", "thiserror 2.0.18", @@ -5541,6 +5550,7 @@ dependencies = [ "fastrand", "futures", "hyper-util", + "jsonwebtoken 10.3.0", "log", "prost", "prost-types", diff --git a/Cargo.toml b/Cargo.toml index 86025646a..8215f0a7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -88,7 +88,7 @@ sha2 = "0.10" rand = "0.9" jsonwebtoken = "9" getrandom = "0.3" -spiffe = { version = "0.15", default-features = false, features = ["workload-api-jwt", "tracing"] } +spiffe = { version = "0.15", default-features = false, features = ["workload-api-jwt", "jwt-verify-rust-crypto", "tracing"] } # Filesystem embedding include_dir = "0.7" diff --git a/architecture/sandbox.md b/architecture/sandbox.md index e60b727a5..401295a8a 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -80,13 +80,29 @@ when policy allows the target endpoint. Secrets must not be logged in OCSF or plain tracing output. Provider profiles can also declare dynamic token grants. For matching HTTP -endpoints, the supervisor obtains a SPIFFE JWT-SVID from the local Workload API, -exchanges it for an OAuth2 access token, caches the token, and injects it as an -`Authorization: Bearer` header before forwarding the request. Token grant -endpoints are HTTPS-only except for loopback and Kubernetes service DNS hosts, -and returned access tokens must be bearer-compatible before they are cached or -injected. Token response lifetimes are capped and cached with an expiry margin -unless a profile supplies an explicit cache TTL override. +endpoints, the supervisor obtains or exchanges OAuth2 access tokens, caches +them, and injects them before forwarding the request. `client_credentials` +grants use the supervisor SPIFFE JWT-SVID directly as the client assertion. +`token_exchange` grants ask the gateway to broker an intermediate token using a +stored provider subject credential and the gateway's own SPIFFE JWT-SVID; the +supervisor then exchanges that intermediate token for the final upstream token +using its own JWT-SVID. The gateway validates that its own JWT-SVID has the +requested audience, a SPIFFE subject, and a non-expired `exp` claim when +present. It also validates that the stored subject credential is declared by the +provider profile, and that the supervisor JWT-SVID is a well-formed +three-segment JWT with a SPIFFE subject in the same trust domain as the gateway +SVID. The gateway verifies the supervisor JWT-SVID signature with JWT bundles +fetched from its SPIFFE Workload API. Token grant endpoints are HTTPS-only +except for loopback and Kubernetes service DNS hosts, and returned access tokens +must be bearer-compatible before they are cached or injected. Token response +lifetimes are capped and cached with an expiry margin unless a profile supplies +an explicit cache TTL override. Cache entries are scoped by the sandbox provider +environment revision so provider credential updates miss the old token cache +without changing endpoint matching semantics. Gateway-brokered intermediate +tokens are cached separately by provider resource version, supervisor SPIFFE +subject, and gateway SPIFFE subject, and their cache lifetime is capped by the +intermediate token response, stored subject-token expiry, and supervisor SVID +expiry. ## Connect and Logs diff --git a/crates/openshell-cli/src/main.rs b/crates/openshell-cli/src/main.rs index 9b80f1914..b993f05db 100644 --- a/crates/openshell-cli/src/main.rs +++ b/crates/openshell-cli/src/main.rs @@ -716,7 +716,7 @@ impl From for openshell_cli::ssh::Editor { #[derive(Subcommand, Debug)] enum ProviderCommands { /// Create a provider config. - #[command(group = clap::ArgGroup::new("cred_source").required(true).args(["from_existing", "credentials", "from_gcloud_adc", "runtime_credentials"]), help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] + #[command(group = clap::ArgGroup::new("cred_source").required(true).multiple(true).args(["from_existing", "credentials", "from_gcloud_adc", "runtime_credentials", "from_oidc_token"]), help_template = LEAF_HELP_TEMPLATE, next_help_heading = "FLAGS")] Create { /// Provider name. #[arg(long)] @@ -744,8 +744,12 @@ enum ProviderCommands { #[arg(long, group = "cred_source", conflicts_with_all = ["from_existing", "credentials", "runtime_credentials"])] from_gcloud_adc: bool, + /// Store the active gateway OIDC access token as the named provider credential. + #[arg(long, group = "cred_source", conflicts_with_all = ["from_existing", "from_gcloud_adc", "runtime_credentials"])] + from_oidc_token: bool, + /// Create a provider whose required credentials are resolved at runtime by the gateway/sandbox. - #[arg(long, conflicts_with_all = ["from_existing", "credentials", "from_gcloud_adc"])] + #[arg(long, conflicts_with_all = ["from_existing", "credentials", "from_gcloud_adc", "from_oidc_token"])] runtime_credentials: bool, /// Provider config key/value pair. @@ -805,9 +809,13 @@ enum ProviderCommands { name: String, /// Re-discover credentials from existing local state (e.g. env vars, config files). - #[arg(long, conflicts_with = "credentials")] + #[arg(long, conflicts_with_all = ["credentials", "from_oidc_token"])] from_existing: bool, + /// Store the active gateway OIDC access token as the named provider credential. + #[arg(long, conflicts_with = "from_existing")] + from_oidc_token: bool, + /// Provider credential pair (`KEY=VALUE`) or env lookup key (`KEY`). #[arg( long = "credential", @@ -2834,20 +2842,44 @@ async fn main() -> Result<()> { from_existing, credentials, from_gcloud_adc, + from_oidc_token, runtime_credentials, config, } => { - run::provider_create_with_options( - endpoint, - &name, - provider_type.as_str(), + let selected_sources = [ from_existing, - &credentials, from_gcloud_adc, + from_oidc_token, runtime_credentials, - &config, - &tls, - ) + ] + .into_iter() + .filter(|selected| *selected) + .count(); + if selected_sources > 1 { + return Err(miette::miette!( + "--from-existing, --from-gcloud-adc, --from-oidc-token, and --runtime-credentials are mutually exclusive" + )); + } + let credential_source = if from_existing { + run::ProviderCreateCredentialSource::Existing + } else if from_gcloud_adc { + run::ProviderCreateCredentialSource::GcloudAdc + } else if from_oidc_token { + run::ProviderCreateCredentialSource::OidcToken + } else if runtime_credentials { + run::ProviderCreateCredentialSource::Runtime + } else { + run::ProviderCreateCredentialSource::ExplicitCredentials + }; + run::provider_create_with_options(run::ProviderCreateOptions { + server: endpoint, + name: &name, + provider_type: provider_type.as_str(), + credentials: &credentials, + credential_source, + config: &config, + tls: &tls, + }) .await?; } ProviderCommands::Refresh(command) => match command { @@ -2943,19 +2975,21 @@ async fn main() -> Result<()> { ProviderCommands::Update { name, from_existing, + from_oidc_token, credentials, config, credential_expires_at, } => { - run::provider_update( - endpoint, - &name, + run::provider_update(run::ProviderUpdateOptions { + server: endpoint, + name: &name, from_existing, - &credentials, - &config, - &credential_expires_at, - &tls, - ) + from_oidc_token, + credentials: &credentials, + config: &config, + credential_expires_at: &credential_expires_at, + tls: &tls, + }) .await?; } ProviderCommands::Delete { names } => { @@ -4016,6 +4050,68 @@ mod tests { } } + #[test] + fn provider_create_accepts_from_oidc_token_destination_credential() { + let cli = Cli::try_parse_from([ + "openshell", + "provider", + "create", + "--name", + "custom-api", + "--type", + "custom-api", + "--from-oidc-token", + "--credential", + "user_oidc_token", + ]) + .expect("provider create should parse from oidc token"); + + match cli.command { + Some(Commands::Provider { + command: + Some(ProviderCommands::Create { + name, + provider_type, + from_oidc_token, + credentials, + .. + }), + }) => { + assert_eq!(name, "custom-api"); + assert_eq!(provider_type, "custom-api"); + assert!(from_oidc_token); + assert_eq!(credentials, vec!["user_oidc_token"]); + } + other => panic!("expected provider create command, got: {other:?}"), + } + } + + #[test] + fn provider_create_accepts_from_oidc_token_without_credential() { + let cli = Cli::try_parse_from([ + "openshell", + "provider", + "create", + "--name", + "custom-api", + "--type", + "custom-api", + "--from-oidc-token", + ]) + .expect("provider create should parse inferred oidc token destination"); + + assert!(matches!( + cli.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Create { + from_oidc_token: true, + credentials, + .. + }) + }) if credentials.is_empty() + )); + } + #[test] fn provider_create_rejects_from_gcloud_adc_with_from_existing() { let err = Cli::try_parse_from([ @@ -4176,6 +4272,56 @@ mod tests { )); } + #[test] + fn provider_update_accepts_from_oidc_token_destination_credential() { + let cli = Cli::try_parse_from([ + "openshell", + "provider", + "update", + "custom-api", + "--from-oidc-token", + "--credential", + "user_oidc_token", + ]) + .expect("provider update should parse from oidc token"); + + assert!(matches!( + cli.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Update { + name, + from_oidc_token: true, + credentials, + .. + }) + }) if name == "custom-api" && credentials == vec!["user_oidc_token"] + )); + } + + #[test] + fn provider_update_accepts_from_oidc_token_without_credential() { + let cli = Cli::try_parse_from([ + "openshell", + "provider", + "update", + "custom-api", + "--from-oidc-token", + ]) + .expect("provider update should parse inferred oidc token destination"); + + assert!(matches!( + cli.command, + Some(Commands::Provider { + command: Some(ProviderCommands::Update { + name, + from_oidc_token: true, + credentials, + .. + }) + }) if name == "custom-api" && credentials.is_empty() + )); + } + #[test] fn provider_refresh_config_accepts_rfc3339_credential_expiry() { let cli = Cli::try_parse_from([ diff --git a/crates/openshell-cli/src/oidc_auth.rs b/crates/openshell-cli/src/oidc_auth.rs index 379a53112..63981bd28 100644 --- a/crates/openshell-cli/src/oidc_auth.rs +++ b/crates/openshell-cli/src/oidc_auth.rs @@ -259,10 +259,11 @@ pub async fn oidc_refresh_token( Ok(refreshed) } -/// Ensure we have a valid OIDC token for the given gateway, refreshing if needed. -/// -/// Returns the access token string. -pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Result { +/// Ensure we have a valid OIDC token bundle for the given gateway, refreshing if needed. +pub async fn ensure_valid_oidc_token_bundle( + gateway_name: &str, + insecure: bool, +) -> Result { let bundle = openshell_bootstrap::oidc_token::load_oidc_token(gateway_name).ok_or_else(|| { miette::miette!( @@ -272,7 +273,7 @@ pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Resu })?; if !openshell_bootstrap::oidc_token::is_token_expired(&bundle) { - return Ok(bundle.access_token); + return Ok(bundle); } debug!( @@ -281,7 +282,16 @@ pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Resu ); let refreshed = oidc_refresh_token(&bundle, insecure).await?; openshell_bootstrap::oidc_token::store_oidc_token(gateway_name, &refreshed)?; - Ok(refreshed.access_token) + Ok(refreshed) +} + +/// Ensure we have a valid OIDC token for the given gateway, refreshing if needed. +/// +/// Returns the access token string. +pub async fn ensure_valid_oidc_token(gateway_name: &str, insecure: bool) -> Result { + Ok(ensure_valid_oidc_token_bundle(gateway_name, insecure) + .await? + .access_token) } // ── Helpers ────────────────────────────────────────────────────────── diff --git a/crates/openshell-cli/src/run.rs b/crates/openshell-cli/src/run.rs index 6f5520755..e4b4bcd3a 100644 --- a/crates/openshell-cli/src/run.rs +++ b/crates/openshell-cli/src/run.rs @@ -46,13 +46,14 @@ use openshell_core::proto::{ ImportProviderProfilesRequest, LintProviderProfilesRequest, ListProviderProfilesRequest, ListProvidersRequest, ListSandboxPoliciesRequest, ListSandboxProvidersRequest, ListSandboxesRequest, ListServicesRequest, PlatformEvent, PolicySource, PolicyStatus, Provider, - ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, ProviderProfile, - ProviderProfileDiagnostic, ProviderProfileImportItem, RejectDraftChunkRequest, - RevokeSshSessionRequest, RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, - SandboxSpec, SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, - SettingScope, SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, - UpdateConfigRequest, UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, - setting_value, tcp_forward_init, + ProviderCredentialRefreshStatus, ProviderCredentialRefreshStrategy, + ProviderCredentialTokenGrantType, ProviderProfile, ProviderProfileDiagnostic, + ProviderProfileImportItem, RejectDraftChunkRequest, RevokeSshSessionRequest, + RotateProviderCredentialRequest, Sandbox, SandboxPhase, SandboxPolicy, SandboxSpec, + SandboxTemplate, ServiceEndpointResponse, SetClusterInferenceRequest, SettingScope, + SettingValue, TcpForwardFrame, TcpForwardInit, TcpRelayTarget, UpdateConfigRequest, + UpdateProviderRequest, WatchSandboxRequest, exec_sandbox_event, setting_value, + tcp_forward_init, }; use openshell_core::settings::{self, SettingValueKind}; use openshell_core::{ObjectId, ObjectName}; @@ -4500,6 +4501,135 @@ fn missing_credentials_error(provider_type: &str) -> miette::Report { ) } +async fn provider_credential_from_oidc_token( + credentials: &[String], + profile: Option<&ProviderProfile>, + tls: &TlsOptions, +) -> Result<(HashMap, HashMap)> { + let credential_key = oidc_subject_credential_key(credentials, profile)?; + + let gateway_name = tls.gateway_name().ok_or_else(|| { + miette::miette!("--from-oidc-token requires an active named OIDC gateway") + })?; + let bundle = + crate::oidc_auth::ensure_valid_oidc_token_bundle(gateway_name, tls.gateway_insecure) + .await + .map_err(|err| { + miette::miette!( + "failed to load or refresh OIDC token for gateway '{gateway_name}' while preparing provider credential: {err}" + ) + })?; + + let mut credential_map = HashMap::new(); + credential_map.insert(credential_key.clone(), bundle.access_token); + + let mut credential_expires_at_ms = HashMap::new(); + if let Some(expires_at) = bundle.expires_at { + let expires_at_ms = expires_at + .checked_mul(1000) + .and_then(|value| i64::try_from(value).ok()) + .ok_or_else(|| miette::miette!("stored OIDC token expiry is out of range"))?; + credential_expires_at_ms.insert(credential_key, expires_at_ms); + } + + Ok((credential_map, credential_expires_at_ms)) +} + +fn oidc_subject_credential_key( + credentials: &[String], + profile: Option<&ProviderProfile>, +) -> Result { + if credentials.len() > 1 { + return Err(miette::miette!( + "--from-oidc-token accepts at most one --credential KEY destination" + )); + } + + if let Some(credential) = credentials.first() { + let credential = credential.trim(); + if credential.is_empty() || credential.contains('=') { + return Err(miette::miette!( + "--from-oidc-token requires --credential KEY without an inline value" + )); + } + if let Some(profile) = profile { + ensure_profile_declares_subject_credential(profile, credential)?; + } + return Ok(credential.to_string()); + } + + let Some(profile) = profile else { + return Err(miette::miette!( + "--from-oidc-token requires --credential KEY when the provider profile is unavailable" + )); + }; + + infer_oidc_subject_credential_from_profile(profile) +} + +fn ensure_profile_declares_subject_credential( + profile: &ProviderProfile, + credential: &str, +) -> Result<()> { + let matches = token_exchange_subject_credentials(profile); + if matches.iter().any(|candidate| candidate == credential) { + return Ok(()); + } + + if matches.is_empty() { + return Err(miette::miette!( + "provider profile '{}' does not declare a token_exchange subject credential", + profile.id + )); + } + + Err(miette::miette!( + "credential '{credential}' is not a token_exchange subject credential in provider profile '{}'; expected {}", + profile.id, + matches.join(", ") + )) +} + +fn infer_oidc_subject_credential_from_profile(profile: &ProviderProfile) -> Result { + let matches = token_exchange_subject_credentials(profile); + match matches.as_slice() { + [credential] => Ok(credential.clone()), + [] => Err(miette::miette!( + "provider profile '{}' does not declare a token_exchange subject credential; pass --credential KEY explicitly or use a token_exchange profile", + profile.id + )), + _ => Err(miette::miette!( + "provider profile '{}' declares multiple token_exchange subject credentials ({}); pass --credential KEY", + profile.id, + matches.join(", ") + )), + } +} + +fn token_exchange_subject_credentials(profile: &ProviderProfile) -> Vec { + let mut matches = Vec::new(); + for credential in &profile.credentials { + let Some(token_grant) = credential.token_grant.as_ref() else { + continue; + }; + if ProviderCredentialTokenGrantType::try_from(token_grant.grant_type).ok() + != Some(ProviderCredentialTokenGrantType::TokenExchange) + { + continue; + } + let Some(subject_token) = token_grant.subject_token.as_ref() else { + continue; + }; + if subject_token.source != "provider_credential" || subject_token.credential.is_empty() { + continue; + } + if !matches.contains(&subject_token.credential) { + matches.push(subject_token.credential.clone()); + } + } + matches +} + #[allow(clippy::too_many_arguments)] pub async fn provider_create( server: &str, @@ -4511,40 +4641,71 @@ pub async fn provider_create( config: &[String], tls: &TlsOptions, ) -> Result<()> { - provider_create_with_options( + let credential_source = match (from_existing, from_gcloud_adc) { + (true, true) => { + return Err(miette::miette!( + "--from-gcloud-adc cannot be combined with --from-existing, --from-oidc-token, or --credential; it also cannot be combined with --runtime-credentials" + )); + } + (true, false) => ProviderCreateCredentialSource::Existing, + (false, true) => ProviderCreateCredentialSource::GcloudAdc, + (false, false) => ProviderCreateCredentialSource::ExplicitCredentials, + }; + provider_create_with_options(ProviderCreateOptions { server, name, provider_type, - from_existing, credentials, - from_gcloud_adc, - false, + credential_source, config, tls, - ) + }) .await } -#[allow(clippy::too_many_arguments)] -pub async fn provider_create_with_options( - server: &str, - name: &str, - provider_type: &str, - from_existing: bool, - credentials: &[String], - from_gcloud_adc: bool, - runtime_credentials: bool, - config: &[String], - tls: &TlsOptions, -) -> Result<()> { - if from_gcloud_adc && (from_existing || !credentials.is_empty() || runtime_credentials) { +pub struct ProviderCreateOptions<'a> { + pub server: &'a str, + pub name: &'a str, + pub provider_type: &'a str, + pub credentials: &'a [String], + pub credential_source: ProviderCreateCredentialSource, + pub config: &'a [String], + pub tls: &'a TlsOptions, +} + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub enum ProviderCreateCredentialSource { + ExplicitCredentials, + Existing, + GcloudAdc, + OidcToken, + Runtime, +} + +pub async fn provider_create_with_options(options: ProviderCreateOptions<'_>) -> Result<()> { + let ProviderCreateOptions { + server, + name, + provider_type, + credentials, + credential_source, + config, + tls, + } = options; + + let from_existing = credential_source == ProviderCreateCredentialSource::Existing; + let from_gcloud_adc = credential_source == ProviderCreateCredentialSource::GcloudAdc; + let from_oidc_token = credential_source == ProviderCreateCredentialSource::OidcToken; + let runtime_credentials = credential_source == ProviderCreateCredentialSource::Runtime; + + if from_gcloud_adc && !credentials.is_empty() { return Err(miette::miette!( - "--from-gcloud-adc cannot be combined with --from-existing or --credential; it also cannot be combined with --runtime-credentials" + "--from-gcloud-adc cannot be combined with --from-existing, --from-oidc-token, or --credential; it also cannot be combined with --runtime-credentials" )); } - if from_existing && (!credentials.is_empty() || runtime_credentials) { + if from_existing && !credentials.is_empty() { return Err(miette::miette!( - "--from-existing cannot be combined with --credential or --runtime-credentials" + "--from-existing cannot be combined with --credential" )); } if runtime_credentials && !credentials.is_empty() { @@ -4589,7 +4750,17 @@ pub async fn provider_create_with_options( )); } - let mut credential_map = parse_credential_pairs(credentials)?; + let oidc_profile = if from_oidc_token { + Some(fetch_provider_profile(&mut client, &provider_type).await?) + } else { + None + }; + + let (mut credential_map, oidc_credential_expires_at_ms) = if from_oidc_token { + provider_credential_from_oidc_token(credentials, oidc_profile.as_ref(), tls).await? + } else { + (parse_credential_pairs(credentials)?, HashMap::new()) + }; let mut config_map = parse_key_value_pairs(config, "--config")?; if from_existing { @@ -4657,7 +4828,7 @@ pub async fn provider_create_with_options( r#type: provider_type.clone(), credentials: credential_map, config: config_map, - credential_expires_at_ms: HashMap::new(), + credential_expires_at_ms: oidc_credential_expires_at_ms, }), }) .await @@ -5472,26 +5643,65 @@ fn truncate_display(value: &str, max_width: usize) -> String { truncated } -pub async fn provider_update( - server: &str, - name: &str, - from_existing: bool, - credentials: &[String], - config: &[String], - credential_expires_at: &[String], - tls: &TlsOptions, -) -> Result<()> { +pub struct ProviderUpdateOptions<'a> { + pub server: &'a str, + pub name: &'a str, + pub from_existing: bool, + pub from_oidc_token: bool, + pub credentials: &'a [String], + pub config: &'a [String], + pub credential_expires_at: &'a [String], + pub tls: &'a TlsOptions, +} + +pub async fn provider_update(options: ProviderUpdateOptions<'_>) -> Result<()> { + let ProviderUpdateOptions { + server, + name, + from_existing, + from_oidc_token, + credentials, + config, + credential_expires_at, + tls, + } = options; + if from_existing && !credentials.is_empty() { return Err(miette::miette!( "--from-existing cannot be combined with --credential" )); } + if from_existing && from_oidc_token { + return Err(miette::miette!( + "--from-existing cannot be combined with --from-oidc-token" + )); + } let mut client = grpc_client(server, tls).await?; - let mut credential_map = parse_credential_pairs(credentials)?; + let oidc_profile = if from_oidc_token { + let existing = client + .get_provider(GetProviderRequest { + name: name.to_string(), + }) + .await + .into_diagnostic()? + .into_inner() + .provider + .ok_or_else(|| miette::miette!("provider '{name}' not found"))?; + Some(fetch_provider_profile(&mut client, &existing.r#type).await?) + } else { + None + }; + + let (mut credential_map, oidc_credential_expires_at_ms) = if from_oidc_token { + provider_credential_from_oidc_token(credentials, oidc_profile.as_ref(), tls).await? + } else { + (parse_credential_pairs(credentials)?, HashMap::new()) + }; let mut config_map = parse_key_value_pairs(config, "--config")?; - let credential_expires_at_ms = parse_credential_expiry_pairs(credential_expires_at)?; + let mut credential_expires_at_ms = parse_credential_expiry_pairs(credential_expires_at)?; + credential_expires_at_ms.extend(oidc_credential_expires_at_ms); if from_existing { // Fetch the existing provider to discover its type for credential lookup. diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index ea2d5a465..fd380485a 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -17,7 +17,8 @@ use openshell_core::proto::{ AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, + ExchangeProviderSubjectTokenRequest, ExchangeProviderSubjectTokenResponse, ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, @@ -201,6 +202,13 @@ impl OpenShell for TestOpenShell { Ok(Response::new(RevokeSshSessionResponse::default())) } + async fn exchange_provider_subject_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn create_provider( &self, request: tonic::Request, diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 7cb9e1e76..ea5267769 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ -13,10 +13,11 @@ use openshell_cli::{ }; use openshell_core::proto::{ CreateProviderRequest, CreateSshSessionRequest, CreateSshSessionResponse, - DeleteProviderRequest, DeleteProviderResponse, ExecSandboxEvent, ExecSandboxInput, - ExecSandboxRequest, GetProviderRequest, HealthRequest, HealthResponse, ListProvidersRequest, - ListProvidersResponse, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, - ServiceStatus, UpdateProviderRequest, + DeleteProviderRequest, DeleteProviderResponse, ExchangeProviderSubjectTokenRequest, + ExchangeProviderSubjectTokenResponse, ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, + GetProviderRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, ServiceStatus, + UpdateProviderRequest, open_shell_server::{OpenShell, OpenShellServer}, }; use tempfile::tempdir; @@ -178,6 +179,13 @@ impl OpenShell for TestOpenShell { Ok(Response::new(RevokeSshSessionResponse::default())) } + async fn exchange_provider_subject_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn create_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index 1450b99d4..b1e875d13 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -14,7 +14,8 @@ use openshell_core::proto::{ CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRefreshRequest, DeleteProviderRefreshResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, + ExchangeProviderSubjectTokenRequest, ExchangeProviderSubjectTokenResponse, ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRefreshStatusRequest, GetProviderRefreshStatusResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, @@ -22,11 +23,12 @@ use openshell_core::proto::{ HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxProvidersRequest, ListSandboxProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, Provider, ProviderCredentialRefresh, ProviderCredentialRefreshStatus, - ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileCredential, - ProviderProfileDiscovery, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, - RotateProviderCredentialRequest, RotateProviderCredentialResponse, Sandbox, SandboxResponse, - SandboxStreamEvent, ServiceStatus, SettingValue, SupervisorMessage, UpdateProviderRequest, - WatchSandboxRequest, setting_value, + ProviderCredentialRefreshStrategy, ProviderCredentialTokenGrant, + ProviderCredentialTokenGrantSubjectToken, ProviderCredentialTokenGrantType, ProviderProfile, + ProviderProfileCredential, ProviderProfileDiscovery, ProviderResponse, RevokeSshSessionRequest, + RevokeSshSessionResponse, RotateProviderCredentialRequest, RotateProviderCredentialResponse, + Sandbox, SandboxResponse, SandboxStreamEvent, ServiceStatus, SettingValue, SupervisorMessage, + UpdateProviderRequest, WatchSandboxRequest, setting_value, }; use openshell_core::{ObjectId, ObjectName}; use std::collections::HashMap; @@ -332,6 +334,13 @@ impl OpenShell for TestOpenShell { Ok(Response::new(RevokeSshSessionResponse::default())) } + async fn exchange_provider_subject_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn create_provider( &self, request: tonic::Request, @@ -484,8 +493,8 @@ impl OpenShell for TestOpenShell { &self, request: tonic::Request, ) -> Result, Status> { + let request = request.into_inner(); let provider = request - .into_inner() .provider .ok_or_else(|| Status::invalid_argument("provider is required"))?; @@ -538,7 +547,7 @@ impl OpenShell for TestOpenShell { config: merge(existing.config, provider.config), credential_expires_at_ms: merge_expiry( existing.credential_expires_at_ms, - provider.credential_expires_at_ms, + request.credential_expires_at_ms, ), }; let updated_name = updated.object_name().to_string(); @@ -977,6 +986,95 @@ async fn enable_providers_v2(ts: &TestServer) { ); } +async fn register_oidc_token_exchange_profile(ts: &TestServer) { + ts.state.profiles.lock().await.insert( + "oidc-token-exchange".to_string(), + ProviderProfile { + id: "oidc-token-exchange".to_string(), + display_name: "OIDC Token Exchange".to_string(), + credentials: vec![ + ProviderProfileCredential { + name: "user_oidc_token".to_string(), + required: true, + ..Default::default() + }, + ProviderProfileCredential { + name: "api_token".to_string(), + required: true, + token_grant: Some(ProviderCredentialTokenGrant { + grant_type: ProviderCredentialTokenGrantType::TokenExchange as i32, + token_endpoint: "https://issuer.example.com/token".to_string(), + subject_token: Some(ProviderCredentialTokenGrantSubjectToken { + source: "provider_credential".to_string(), + credential: "user_oidc_token".to_string(), + subject_token_type: "urn:ietf:params:oauth:token-type:access_token" + .to_string(), + }), + ..Default::default() + }), + ..Default::default() + }, + ], + ..Default::default() + }, + ); +} + +async fn register_ambiguous_oidc_token_exchange_profile(ts: &TestServer) { + ts.state.profiles.lock().await.insert( + "ambiguous-oidc-token-exchange".to_string(), + ProviderProfile { + id: "ambiguous-oidc-token-exchange".to_string(), + display_name: "Ambiguous OIDC Token Exchange".to_string(), + credentials: vec![ + ProviderProfileCredential { + name: "user_oidc_token".to_string(), + required: true, + ..Default::default() + }, + ProviderProfileCredential { + name: "admin_oidc_token".to_string(), + required: true, + ..Default::default() + }, + ProviderProfileCredential { + name: "api_token".to_string(), + required: true, + token_grant: Some(ProviderCredentialTokenGrant { + grant_type: ProviderCredentialTokenGrantType::TokenExchange as i32, + token_endpoint: "https://issuer.example.com/token".to_string(), + subject_token: Some(ProviderCredentialTokenGrantSubjectToken { + source: "provider_credential".to_string(), + credential: "user_oidc_token".to_string(), + subject_token_type: "urn:ietf:params:oauth:token-type:access_token" + .to_string(), + }), + ..Default::default() + }), + ..Default::default() + }, + ProviderProfileCredential { + name: "admin_api_token".to_string(), + required: true, + token_grant: Some(ProviderCredentialTokenGrant { + grant_type: ProviderCredentialTokenGrantType::TokenExchange as i32, + token_endpoint: "https://issuer.example.com/token".to_string(), + subject_token: Some(ProviderCredentialTokenGrantSubjectToken { + source: "provider_credential".to_string(), + credential: "admin_oidc_token".to_string(), + subject_token_type: "urn:ietf:params:oauth:token-type:access_token" + .to_string(), + }), + ..Default::default() + }), + ..Default::default() + }, + ], + ..Default::default() + }, + ); +} + #[tokio::test] async fn provider_cli_run_functions_support_full_crud_flow() { let ts = run_server().await; @@ -1001,15 +1099,16 @@ async fn provider_cli_run_functions_support_full_crud_flow() { .await .expect("provider list"); - run::provider_update( - &ts.endpoint, - "my-claude", - false, - &["API_KEY=rotated".to_string()], - &["profile=prod".to_string()], - &[], - &ts.tls, - ) + run::provider_update(run::ProviderUpdateOptions { + server: &ts.endpoint, + name: "my-claude", + from_existing: false, + from_oidc_token: false, + credentials: &["API_KEY=rotated".to_string()], + config: &["profile=prod".to_string()], + credential_expires_at: &[], + tls: &ts.tls, + }) .await .expect("provider update"); @@ -1018,6 +1117,205 @@ async fn provider_cli_run_functions_support_full_crud_flow() { .expect("provider delete"); } +#[tokio::test] +async fn provider_create_from_oidc_token_stores_active_gateway_token() { + let ts = run_server().await; + register_oidc_token_exchange_profile(&ts).await; + let xdg_dir = tempfile::tempdir().unwrap(); + let _xdg = EnvVarGuard::set(&[("XDG_CONFIG_HOME", xdg_dir.path().to_str().unwrap())]); + let gateway_name = "oidc-gateway"; + openshell_bootstrap::oidc_token::store_oidc_token( + gateway_name, + &openshell_bootstrap::oidc_token::OidcTokenBundle { + access_token: "user-access-token".to_string(), + refresh_token: Some("user-refresh-token".to_string()), + expires_at: Some(1_893_456_000), + issuer: "https://issuer.example.com".to_string(), + client_id: "openshell-cli".to_string(), + }, + ) + .expect("store oidc token"); + let tls = ts.tls.with_gateway_name(gateway_name); + + run::provider_create_with_options(run::ProviderCreateOptions { + server: &ts.endpoint, + name: "custom-api", + provider_type: "oidc-token-exchange", + credentials: &[], + credential_source: run::ProviderCreateCredentialSource::OidcToken, + config: &[], + tls: &tls, + }) + .await + .expect("provider create from oidc token"); + + let provider = ts + .state + .providers + .lock() + .await + .get("custom-api") + .cloned() + .expect("provider"); + assert_eq!( + provider.credentials.get("user_oidc_token"), + Some(&"user-access-token".to_string()) + ); + assert_eq!( + provider.credential_expires_at_ms.get("user_oidc_token"), + Some(&1_893_456_000_000) + ); +} + +#[tokio::test] +async fn provider_update_from_oidc_token_replaces_subject_token() { + let ts = run_server().await; + register_oidc_token_exchange_profile(&ts).await; + ts.state.providers.lock().await.insert( + "custom-api".to_string(), + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: "provider-id".to_string(), + name: "custom-api".to_string(), + ..Default::default() + }), + r#type: "oidc-token-exchange".to_string(), + credentials: std::iter::once(("user_oidc_token".to_string(), "old-token".to_string())) + .collect(), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ); + let xdg_dir = tempfile::tempdir().unwrap(); + let _xdg = EnvVarGuard::set(&[("XDG_CONFIG_HOME", xdg_dir.path().to_str().unwrap())]); + let gateway_name = "oidc-gateway"; + openshell_bootstrap::oidc_token::store_oidc_token( + gateway_name, + &openshell_bootstrap::oidc_token::OidcTokenBundle { + access_token: "new-user-access-token".to_string(), + refresh_token: None, + expires_at: Some(1_893_456_300), + issuer: "https://issuer.example.com".to_string(), + client_id: "openshell-cli".to_string(), + }, + ) + .expect("store oidc token"); + let tls = ts.tls.with_gateway_name(gateway_name); + + run::provider_update(run::ProviderUpdateOptions { + server: &ts.endpoint, + name: "custom-api", + from_existing: false, + from_oidc_token: true, + credentials: &[], + config: &[], + credential_expires_at: &[], + tls: &tls, + }) + .await + .expect("provider update from oidc token"); + + let provider = ts + .state + .providers + .lock() + .await + .get("custom-api") + .cloned() + .expect("provider"); + assert_eq!( + provider.credentials.get("user_oidc_token"), + Some(&"new-user-access-token".to_string()) + ); + assert_eq!( + provider.credential_expires_at_ms.get("user_oidc_token"), + Some(&1_893_456_300_000) + ); +} + +#[tokio::test] +async fn provider_create_from_oidc_token_rejects_non_subject_credential() { + let ts = run_server().await; + register_oidc_token_exchange_profile(&ts).await; + + let err = run::provider_create_with_options(run::ProviderCreateOptions { + server: &ts.endpoint, + name: "custom-api", + provider_type: "oidc-token-exchange", + credentials: &["api_token".to_string()], + credential_source: run::ProviderCreateCredentialSource::OidcToken, + config: &[], + tls: &ts.tls, + }) + .await + .expect_err("wrong subject credential should fail"); + + let message = err.to_string(); + assert!(message.contains("is not a token_exchange subject credential")); + assert!(message.contains("user_oidc_token")); +} + +#[tokio::test] +async fn provider_create_from_oidc_token_requires_credential_for_ambiguous_profile() { + let ts = run_server().await; + register_ambiguous_oidc_token_exchange_profile(&ts).await; + + let err = run::provider_create_with_options(run::ProviderCreateOptions { + server: &ts.endpoint, + name: "custom-api", + provider_type: "ambiguous-oidc-token-exchange", + credentials: &[], + credential_source: run::ProviderCreateCredentialSource::OidcToken, + config: &[], + tls: &ts.tls, + }) + .await + .expect_err("ambiguous subject credential should fail"); + + let message = err.to_string(); + assert!(message.contains("declares multiple token_exchange subject credentials")); + assert!(message.contains("user_oidc_token")); + assert!(message.contains("admin_oidc_token")); + assert!(message.contains("--credential KEY")); +} + +#[tokio::test] +async fn provider_create_from_oidc_token_reports_expired_token_without_refresh_token() { + let ts = run_server().await; + register_oidc_token_exchange_profile(&ts).await; + let xdg_dir = tempfile::tempdir().unwrap(); + let _xdg = EnvVarGuard::set(&[("XDG_CONFIG_HOME", xdg_dir.path().to_str().unwrap())]); + let gateway_name = "oidc-gateway"; + openshell_bootstrap::oidc_token::store_oidc_token( + gateway_name, + &openshell_bootstrap::oidc_token::OidcTokenBundle { + access_token: "expired-user-access-token".to_string(), + refresh_token: None, + expires_at: Some(1), + issuer: "https://issuer.example.com".to_string(), + client_id: "openshell-cli".to_string(), + }, + ) + .expect("store oidc token"); + let tls = ts.tls.with_gateway_name(gateway_name); + + let err = run::provider_create_with_options(run::ProviderCreateOptions { + server: &ts.endpoint, + name: "custom-api", + provider_type: "oidc-token-exchange", + credentials: &[], + credential_source: run::ProviderCreateCredentialSource::OidcToken, + config: &[], + tls: &tls, + }) + .await + .expect_err("expired oidc token without refresh token should fail"); + + let message = err.to_string(); + assert!(message.contains("failed to load or refresh OIDC token")); + assert!(message.contains("no refresh token available")); +} + #[tokio::test] async fn provider_list_profiles_cli_uses_profile_browsing_rpc() { let ts = run_server().await; @@ -1185,17 +1483,15 @@ async fn provider_create_allows_empty_credentials_for_gateway_refresh_profiles() }, ); - run::provider_create_with_options( - &ts.endpoint, - "custom-refresh-provider", - "custom-refresh", - false, - &[], - false, - true, - &[], - &ts.tls, - ) + run::provider_create_with_options(run::ProviderCreateOptions { + server: &ts.endpoint, + name: "custom-refresh-provider", + provider_type: "custom-refresh", + credentials: &[], + credential_source: run::ProviderCreateCredentialSource::Runtime, + config: &[], + tls: &ts.tls, + }) .await .expect("provider create"); @@ -1664,9 +1960,18 @@ async fn provider_update_from_existing_uses_profile_discovery_when_v2_enabled() ); let _env = EnvVarGuard::set(&[("CUSTOM_UPDATE_DISCOVERY_API_KEY", "updated-profile-secret")]); - run::provider_update(&ts.endpoint, "custom-update", true, &[], &[], &[], &ts.tls) - .await - .expect("profile-backed provider update --from-existing"); + run::provider_update(run::ProviderUpdateOptions { + server: &ts.endpoint, + name: "custom-update", + from_existing: true, + from_oidc_token: false, + credentials: &[], + config: &[], + credential_expires_at: &[], + tls: &ts.tls, + }) + .await + .expect("profile-backed provider update --from-existing"); let provider = ts .state @@ -1952,8 +2257,9 @@ async fn provider_create_rejects_combined_from_gcloud_adc_and_from_existing() { .expect_err("from-gcloud-adc and from-existing should be mutually exclusive"); assert!( - err.to_string() - .contains("--from-gcloud-adc cannot be combined with --from-existing or --credential"), + err.to_string().contains( + "--from-gcloud-adc cannot be combined with --from-existing, --from-oidc-token, or --credential" + ), "unexpected error: {err}" ); assert!(ts.state.providers.lock().await.is_empty()); @@ -1977,8 +2283,9 @@ async fn provider_create_rejects_combined_from_gcloud_adc_and_credentials() { .expect_err("from-gcloud-adc and credentials should be mutually exclusive"); assert!( - err.to_string() - .contains("--from-gcloud-adc cannot be combined with --from-existing or --credential"), + err.to_string().contains( + "--from-gcloud-adc cannot be combined with --from-existing, --from-oidc-token, or --credential" + ), "unexpected error: {err}" ); assert!(ts.state.providers.lock().await.is_empty()); diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index 7061614cb..cb7a793c8 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -14,7 +14,8 @@ use openshell_core::proto::{ AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, + ExchangeProviderSubjectTokenRequest, ExchangeProviderSubjectTokenResponse, ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, @@ -235,6 +236,13 @@ impl OpenShell for TestOpenShell { Ok(Response::new(RevokeSshSessionResponse::default())) } + async fn exchange_provider_subject_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn create_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index 5e753eff9..afd8b6f33 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -14,7 +14,8 @@ use openshell_core::proto::{ AttachSandboxProviderRequest, AttachSandboxProviderResponse, CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - DetachSandboxProviderRequest, DetachSandboxProviderResponse, ExecSandboxEvent, + DetachSandboxProviderRequest, DetachSandboxProviderResponse, + ExchangeProviderSubjectTokenRequest, ExchangeProviderSubjectTokenResponse, ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, @@ -220,6 +221,13 @@ impl OpenShell for TestOpenShell { )) } + async fn exchange_provider_subject_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn create_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-core/src/grpc_client.rs b/crates/openshell-core/src/grpc_client.rs index 96158a1d1..682fdeb2d 100644 --- a/crates/openshell-core/src/grpc_client.rs +++ b/crates/openshell-core/src/grpc_client.rs @@ -23,12 +23,12 @@ use std::sync::{Arc, OnceLock, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::proto::{ - DenialSummary, GetDraftPolicyRequest, GetInferenceBundleRequest, GetInferenceBundleResponse, - GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, IssueSandboxTokenRequest, - NetworkActivitySummary, PolicyChunk, PolicySource, PolicyStatus, RefreshSandboxTokenRequest, - ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, - SubmitPolicyAnalysisResponse, UpdateConfigRequest, inference_client::InferenceClient, - open_shell_client::OpenShellClient, + DenialSummary, ExchangeProviderSubjectTokenRequest, GetDraftPolicyRequest, + GetInferenceBundleRequest, GetInferenceBundleResponse, GetSandboxConfigRequest, + GetSandboxProviderEnvironmentRequest, IssueSandboxTokenRequest, NetworkActivitySummary, + PolicyChunk, PolicySource, PolicyStatus, RefreshSandboxTokenRequest, ReportPolicyStatusRequest, + SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, + UpdateConfigRequest, inference_client::InferenceClient, open_shell_client::OpenShellClient, }; use crate::sandbox_env; use miette::{IntoDiagnostic, Result, WrapErr}; @@ -690,6 +690,55 @@ pub async fn fetch_provider_environment( }) } +pub async fn exchange_provider_subject_token( + endpoint: &str, + sandbox_id: &str, + provider: &str, + credential_key: &str, + supervisor_jwt_svid: &str, +) -> Result { + debug!( + endpoint = %endpoint, + sandbox_id = %sandbox_id, + provider = %provider, + credential_key = %credential_key, + "Exchanging provider subject token through gateway" + ); + + let mut client = connect(endpoint).await?; + let response = client + .exchange_provider_subject_token(ExchangeProviderSubjectTokenRequest { + sandbox_id: sandbox_id.to_string(), + provider: provider.to_string(), + credential_key: credential_key.to_string(), + supervisor_jwt_svid: supervisor_jwt_svid.to_string(), + }) + .await + .map_err(provider_subject_token_exchange_status)?; + let inner = response.into_inner(); + Ok(ProviderSubjectTokenExchangeResult { + access_token: inner.access_token, + expires_in: inner.expires_in, + token_type: inner.token_type, + }) +} + +fn provider_subject_token_exchange_status(status: Status) -> miette::Report { + let message = status.message(); + if message.is_empty() { + miette::miette!( + "gateway ExchangeProviderSubjectToken failed with status {}", + status.code() + ) + } else { + miette::miette!( + "gateway ExchangeProviderSubjectToken failed with status {}: {}", + status.code(), + message + ) + } +} + /// A reusable gRPC client for the `OpenShell` service. /// /// Wraps a tonic channel connected once and reused for policy polling @@ -720,6 +769,12 @@ pub struct ProviderEnvironmentResult { pub dynamic_credentials: HashMap, } +pub struct ProviderSubjectTokenExchangeResult { + pub access_token: String, + pub expires_in: i64, + pub token_type: String, +} + impl CachedOpenShellClient { pub async fn connect(endpoint: &str) -> Result { debug!(endpoint = %endpoint, "Connecting openshell gRPC client for policy polling"); diff --git a/crates/openshell-core/src/lib.rs b/crates/openshell-core/src/lib.rs index 3215ac2c9..f002e3c70 100644 --- a/crates/openshell-core/src/lib.rs +++ b/crates/openshell-core/src/lib.rs @@ -33,6 +33,7 @@ pub mod provider_credentials; pub mod sandbox_env; pub mod secrets; pub mod settings; +pub mod spiffe; pub mod telemetry; pub mod time; diff --git a/crates/openshell-core/src/sandbox_env.rs b/crates/openshell-core/src/sandbox_env.rs index b457a4a8e..3850287af 100644 --- a/crates/openshell-core/src/sandbox_env.rs +++ b/crates/openshell-core/src/sandbox_env.rs @@ -71,3 +71,9 @@ pub const K8S_SA_TOKEN_FILE: &str = "OPENSHELL_K8S_SA_TOKEN_FILE"; /// exchanges without using SPIFFE for gateway authentication. pub const PROVIDER_SPIFFE_WORKLOAD_API_SOCKET: &str = "OPENSHELL_PROVIDER_SPIFFE_WORKLOAD_API_SOCKET"; + +/// Filesystem path to the gateway's SPIFFE Workload API UNIX socket. +/// +/// When set, the gateway can fetch its own JWT-SVID for provider token exchange +/// client assertions. +pub const GATEWAY_SPIFFE_WORKLOAD_API_SOCKET: &str = "OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET"; diff --git a/crates/openshell-core/src/spiffe.rs b/crates/openshell-core/src/spiffe.rs new file mode 100644 index 000000000..843dbbff6 --- /dev/null +++ b/crates/openshell-core/src/spiffe.rs @@ -0,0 +1,173 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Shared SPIFFE helpers used by the gateway and sandbox supervisor. + +use std::path::Path; + +use base64::Engine as _; +use serde::Deserialize; + +/// SPIFFE JWT-SVID claims used by `OpenShell` token exchange flows. +#[derive(Debug, Clone, Deserialize)] +pub struct SpiffeJwtClaims { + pub iss: String, + pub sub: String, + pub aud: AudienceClaim, + #[serde(default)] + pub exp: i64, +} + +/// JWT `aud` claim representation accepted by SPIFFE JWT-SVIDs. +#[derive(Debug, Clone, Deserialize)] +#[serde(untagged)] +pub enum AudienceClaim { + One(String), + Many(Vec), +} + +impl AudienceClaim { + pub fn contains(&self, expected: &str) -> bool { + match self { + Self::One(value) => value == expected, + Self::Many(values) => values.iter().any(|value| value == expected), + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum JwtSvidParseError { + #[error("invalid JWT-SVID format")] + Format, + #[error("invalid JWT-SVID payload encoding")] + PayloadEncoding, + #[error("invalid JWT-SVID payload")] + Payload, +} + +/// Convert a path to a SPIFFE Workload API endpoint URL. +/// +/// If the path already has a scheme (`unix:` or `tcp:`), use it as-is. +/// Otherwise, assume it is a Unix socket path and prepend `unix:`. +pub fn workload_api_endpoint(path: &Path) -> String { + let path = path.to_string_lossy(); + if path.starts_with("unix:") || path.starts_with("tcp:") { + path.into_owned() + } else { + format!("unix:{path}") + } +} + +pub fn parse_unverified_jwt_svid_claims(token: &str) -> Result { + let segments = token.split('.').collect::>(); + if segments.len() != 3 || segments.iter().any(|segment| segment.is_empty()) { + return Err(JwtSvidParseError::Format); + } + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(segments[1]) + .map_err(|_| JwtSvidParseError::PayloadEncoding)?; + serde_json::from_slice::(&decoded).map_err(|_| JwtSvidParseError::Payload) +} + +pub fn trust_domain(subject: &str) -> Option<&str> { + let rest = subject.strip_prefix("spiffe://")?; + let (trust_domain, _) = rest.split_once('/').unwrap_or((rest, "")); + (!trust_domain.is_empty()).then_some(trust_domain) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn unsigned_svid_fixture(issuer: &str, subject: &str, audience: serde_json::Value) -> String { + let header = serde_json::json!({ "alg": "RS256", "kid": "test-key" }); + let payload = serde_json::json!({ + "iss": issuer, + "sub": subject, + "aud": audience, + "exp": 4_102_444_800_i64 + }); + let encoded_header = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&header).expect("serialize header")); + let encoded_payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&payload).expect("serialize payload")); + format!("{encoded_header}.{encoded_payload}.signature") + } + + #[test] + fn workload_api_endpoint_preserves_explicit_scheme() { + assert_eq!( + workload_api_endpoint(Path::new("unix:/run/spire/agent.sock")), + "unix:/run/spire/agent.sock" + ); + assert_eq!( + workload_api_endpoint(Path::new("tcp:127.0.0.1:8081")), + "tcp:127.0.0.1:8081" + ); + } + + #[test] + fn workload_api_endpoint_defaults_to_unix_socket() { + assert_eq!( + workload_api_endpoint(Path::new("/run/spire/agent.sock")), + "unix:/run/spire/agent.sock" + ); + } + + #[test] + fn parse_unverified_jwt_svid_claims_accepts_string_audience() { + let token = unsigned_svid_fixture( + "https://spiffe.example.test", + "spiffe://openshell/openshell/sandbox/sb-a", + serde_json::json!("https://auth.example.com"), + ); + + let claims = parse_unverified_jwt_svid_claims(&token).expect("valid claims"); + + assert_eq!(claims.iss, "https://spiffe.example.test"); + assert_eq!(claims.sub, "spiffe://openshell/openshell/sandbox/sb-a"); + assert!(claims.aud.contains("https://auth.example.com")); + assert!(!claims.aud.contains("https://other.example.com")); + } + + #[test] + fn parse_unverified_jwt_svid_claims_accepts_array_audience() { + let token = unsigned_svid_fixture( + "https://spiffe.example.test", + "spiffe://openshell/openshell/sandbox/sb-a", + serde_json::json!(["https://auth.example.com", "https://other.example.com"]), + ); + + let claims = parse_unverified_jwt_svid_claims(&token).expect("valid claims"); + + assert!(claims.aud.contains("https://auth.example.com")); + assert!(claims.aud.contains("https://other.example.com")); + } + + #[test] + fn parse_unverified_jwt_svid_claims_rejects_truncated_jwt() { + assert!(matches!( + parse_unverified_jwt_svid_claims("header.payload"), + Err(JwtSvidParseError::Format) + )); + } + + #[test] + fn parse_unverified_jwt_svid_claims_rejects_empty_jwt_segments() { + assert!(matches!( + parse_unverified_jwt_svid_claims("header..signature"), + Err(JwtSvidParseError::Format) + )); + } + + #[test] + fn trust_domain_extracts_domain_from_spiffe_id() { + assert_eq!( + trust_domain("spiffe://openshell/openshell/sandbox/sb-a"), + Some("openshell") + ); + assert_eq!(trust_domain("spiffe://openshell"), Some("openshell")); + assert_eq!(trust_domain("not-a-spiffe-id"), None); + assert_eq!(trust_domain("spiffe:///empty"), None); + } +} diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index d31085b64..089c7e448 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -8,7 +8,8 @@ use openshell_core::proto::{ GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, ProviderCredentialRefresh, ProviderCredentialRefreshMaterial, - ProviderCredentialRefreshStrategy, ProviderProfile, ProviderProfileCategory, + ProviderCredentialRefreshStrategy, ProviderCredentialTokenGrantSubjectToken, + ProviderCredentialTokenGrantType, ProviderProfile, ProviderProfileCategory, ProviderProfileCredential, ProviderProfileDiscovery, }; use serde::ser::SerializeStruct; @@ -98,6 +99,13 @@ pub struct CredentialProfile { #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] pub struct TokenGrantProfile { + #[serde( + default = "default_token_grant_type", + deserialize_with = "deserialize_token_grant_type", + serialize_with = "serialize_token_grant_type", + skip_serializing_if = "is_client_credentials_grant" + )] + pub grant_type: ProviderCredentialTokenGrantType, pub token_endpoint: String, #[serde(default, skip_serializing_if = "String::is_empty")] pub audience: String, @@ -111,6 +119,18 @@ pub struct TokenGrantProfile { pub cache_ttl_seconds: i64, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub audience_overrides: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub subject_token: Option, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub requested_token_type: String, +} + +#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] +pub struct TokenGrantSubjectTokenProfile { + pub source: String, + pub credential: String, + #[serde(default, skip_serializing_if = "String::is_empty")] + pub subject_token_type: String, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, Eq)] @@ -508,6 +528,26 @@ fn default_refresh_strategy() -> ProviderCredentialRefreshStrategy { ProviderCredentialRefreshStrategy::Unspecified } +fn default_token_grant_type() -> ProviderCredentialTokenGrantType { + ProviderCredentialTokenGrantType::ClientCredentials +} + +fn effective_token_grant_type( + grant_type: ProviderCredentialTokenGrantType, +) -> ProviderCredentialTokenGrantType { + match grant_type { + ProviderCredentialTokenGrantType::Unspecified => { + ProviderCredentialTokenGrantType::ClientCredentials + } + other => other, + } +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +fn is_client_credentials_grant(value: &ProviderCredentialTokenGrantType) -> bool { + effective_token_grant_type(*value) == ProviderCredentialTokenGrantType::ClientCredentials +} + fn deserialize_category<'de, D>(deserializer: D) -> Result where D: Deserializer<'de>, @@ -550,6 +590,28 @@ where serializer.serialize_str(provider_refresh_strategy_to_yaml(*strategy)) } +fn deserialize_token_grant_type<'de, D>( + deserializer: D, +) -> Result +where + D: Deserializer<'de>, +{ + let raw = String::deserialize(deserializer)?; + provider_token_grant_type_from_yaml(&raw) + .ok_or_else(|| de::Error::custom(format!("unsupported provider token grant type: {raw}"))) +} + +#[allow(clippy::trivially_copy_pass_by_ref)] +fn serialize_token_grant_type( + grant_type: &ProviderCredentialTokenGrantType, + serializer: S, +) -> Result +where + S: Serializer, +{ + serializer.serialize_str(provider_token_grant_type_to_yaml(*grant_type)) +} + #[must_use] pub fn provider_profile_category_from_yaml(raw: &str) -> Option { match raw.trim().to_ascii_lowercase().replace('-', "_").as_str() { @@ -608,6 +670,26 @@ pub fn provider_refresh_strategy_to_yaml( } } +#[must_use] +pub fn provider_token_grant_type_from_yaml(raw: &str) -> Option { + match raw.trim().to_ascii_lowercase().replace('-', "_").as_str() { + "" | "client_credentials" => Some(ProviderCredentialTokenGrantType::ClientCredentials), + "token_exchange" => Some(ProviderCredentialTokenGrantType::TokenExchange), + _ => None, + } +} + +#[must_use] +pub fn provider_token_grant_type_to_yaml( + grant_type: ProviderCredentialTokenGrantType, +) -> &'static str { + match grant_type { + ProviderCredentialTokenGrantType::TokenExchange => "token_exchange", + ProviderCredentialTokenGrantType::ClientCredentials + | ProviderCredentialTokenGrantType::Unspecified => "client_credentials", + } +} + fn credential_refresh_from_proto(refresh: &ProviderCredentialRefresh) -> CredentialRefreshProfile { CredentialRefreshProfile { strategy: ProviderCredentialRefreshStrategy::try_from(refresh.strategy) @@ -653,6 +735,10 @@ fn token_grant_from_proto( token_grant: &openshell_core::proto::ProviderCredentialTokenGrant, ) -> TokenGrantProfile { TokenGrantProfile { + grant_type: effective_token_grant_type( + ProviderCredentialTokenGrantType::try_from(token_grant.grant_type) + .unwrap_or(ProviderCredentialTokenGrantType::ClientCredentials), + ), token_endpoint: token_grant.token_endpoint.clone(), audience: token_grant.audience.clone(), jwt_svid_audience: token_grant.jwt_svid_audience.clone(), @@ -664,6 +750,11 @@ fn token_grant_from_proto( .iter() .map(token_grant_audience_override_from_proto) .collect(), + subject_token: token_grant + .subject_token + .as_ref() + .map(token_grant_subject_token_from_proto), + requested_token_type: token_grant.requested_token_type.clone(), } } @@ -671,6 +762,7 @@ fn token_grant_to_proto( token_grant: &TokenGrantProfile, ) -> openshell_core::proto::ProviderCredentialTokenGrant { openshell_core::proto::ProviderCredentialTokenGrant { + grant_type: token_grant.grant_type as i32, token_endpoint: token_grant.token_endpoint.clone(), audience: token_grant.audience.clone(), jwt_svid_audience: token_grant.jwt_svid_audience.clone(), @@ -682,6 +774,31 @@ fn token_grant_to_proto( .iter() .map(token_grant_audience_override_to_proto) .collect(), + subject_token: token_grant + .subject_token + .as_ref() + .map(token_grant_subject_token_to_proto), + requested_token_type: token_grant.requested_token_type.clone(), + } +} + +fn token_grant_subject_token_from_proto( + subject_token: &ProviderCredentialTokenGrantSubjectToken, +) -> TokenGrantSubjectTokenProfile { + TokenGrantSubjectTokenProfile { + source: subject_token.source.clone(), + credential: subject_token.credential.clone(), + subject_token_type: subject_token.subject_token_type.clone(), + } +} + +fn token_grant_subject_token_to_proto( + subject_token: &TokenGrantSubjectTokenProfile, +) -> ProviderCredentialTokenGrantSubjectToken { + ProviderCredentialTokenGrantSubjectToken { + source: subject_token.source.clone(), + credential: subject_token.credential.clone(), + subject_token_type: subject_token.subject_token_type.clone(), } } @@ -1215,6 +1332,12 @@ pub fn validate_profile_set( message, )); } + diagnostics.extend(validate_token_grant_subject_token( + source, + profile_id, + credential, + &credential_names, + )); diagnostics.extend(validate_token_grant_audience_overrides( source, profile_id, @@ -1290,6 +1413,75 @@ struct TokenGrantOverrideBinding { score: u32, } +fn validate_token_grant_subject_token( + source: &str, + profile_id: &str, + credential: &CredentialProfile, + credential_names: &HashSet, +) -> Vec { + let Some(token_grant) = credential.token_grant.as_ref() else { + return Vec::new(); + }; + let grant_type = effective_token_grant_type(token_grant.grant_type); + let mut diagnostics = Vec::new(); + + match grant_type { + ProviderCredentialTokenGrantType::ClientCredentials => { + if token_grant.subject_token.is_some() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.token_grant.subject_token", + "subject_token is only valid for token_exchange grants", + )); + } + } + ProviderCredentialTokenGrantType::TokenExchange => { + let Some(subject_token) = token_grant.subject_token.as_ref() else { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.token_grant.subject_token", + "token_exchange grants require subject_token", + )); + return diagnostics; + }; + + let source_value = subject_token.source.trim(); + if source_value != "provider_credential" { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.token_grant.subject_token.source", + "subject_token.source must be provider_credential", + )); + } + + let subject_credential = subject_token.credential.trim(); + if subject_credential.is_empty() { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.token_grant.subject_token.credential", + "subject_token.credential is required", + )); + } else if !credential_names.contains(subject_credential) { + diagnostics.push(ProfileValidationDiagnostic::error( + source, + profile_id, + "credentials.token_grant.subject_token.credential", + format!("unknown subject token credential: {subject_credential}"), + )); + } + } + ProviderCredentialTokenGrantType::Unspecified => { + unreachable!("effective_token_grant_type must normalize unspecified token grant type") + } + } + + diagnostics +} + fn validate_token_grant_audience_overrides( source: &str, profile_id: &str, @@ -1633,7 +1825,7 @@ pub fn get_default_profile(id: &str) -> Option<&'static ProviderTypeProfile> { #[cfg(test)] mod tests { - use openshell_core::proto::ProviderProfileCategory; + use openshell_core::proto::{ProviderCredentialTokenGrantType, ProviderProfileCategory}; use super::{ DiscoveryProfile, ProfileError, ProviderTypeProfile, default_profiles, get_default_profile, @@ -1990,6 +2182,166 @@ credentials: ); } + #[test] + fn token_exchange_grant_round_trips_through_proto_and_yaml() { + let profile = parse_profile_yaml( + r" +id: keycloak-token-exchange +display_name: Keycloak Token Exchange +credentials: + - name: USER_OIDC_TOKEN + required: true + - name: access_token + auth_style: bearer + header_name: Authorization + token_grant: + grant_type: token_exchange + token_endpoint: https://keycloak.example.com/realms/openshell/protocol/openid-connect/token + subject_token: + source: provider_credential + credential: USER_OIDC_TOKEN + subject_token_type: urn:ietf:params:oauth:token-type:access_token + jwt_svid_audience: https://keycloak.example.com/realms/openshell + client_assertion_type: urn:ietf:params:oauth:client-assertion-type:jwt-bearer + audience: https://graph.example.com + scopes: [graph.read] + requested_token_type: urn:ietf:params:oauth:token-type:access_token +", + ) + .expect("profile should parse"); + + let diagnostics = + validate_profile_set(&[("keycloak-token-exchange.yaml".to_string(), profile.clone())]); + assert!( + diagnostics.is_empty(), + "unexpected diagnostics: {diagnostics:?}" + ); + + let token_grant = profile.credentials[1] + .token_grant + .as_ref() + .expect("token grant should parse"); + assert_eq!( + token_grant.grant_type, + ProviderCredentialTokenGrantType::TokenExchange + ); + assert_eq!( + token_grant + .subject_token + .as_ref() + .map(|subject| subject.credential.as_str()), + Some("USER_OIDC_TOKEN") + ); + + let from_proto = ProviderTypeProfile::from_proto(&profile.to_proto()); + assert_eq!( + from_proto.credentials[1].token_grant, + profile.credentials[1].token_grant + ); + + let exported = profile_to_yaml(&from_proto).expect("yaml"); + assert!(exported.contains("grant_type: token_exchange")); + assert!(exported.contains("subject_token:")); + let reparsed = parse_profile_yaml(&exported).expect("re-parse"); + assert_eq!( + reparsed.credentials[1].token_grant, + profile.credentials[1].token_grant + ); + } + + #[test] + fn validate_profile_set_rejects_token_exchange_without_subject_token() { + let profile = parse_profile_yaml( + r" +id: missing-subject-token +display_name: Missing Subject Token +credentials: + - name: access_token + auth_style: bearer + header_name: Authorization + token_grant: + grant_type: token_exchange + token_endpoint: https://keycloak.example.com/realms/openshell/protocol/openid-connect/token +", + ) + .expect("profile should parse"); + + let diagnostics = validate_profile_set(&[("missing.yaml".to_string(), profile)]); + let diagnostic = diagnostics + .iter() + .find(|diagnostic| diagnostic.field == "credentials.token_grant.subject_token") + .expect("expected subject_token diagnostic"); + assert_eq!( + diagnostic.message, + "token_exchange grants require subject_token" + ); + } + + #[test] + fn validate_profile_set_rejects_token_exchange_unknown_subject_credential() { + let profile = parse_profile_yaml( + r" +id: unknown-subject-token +display_name: Unknown Subject Token +credentials: + - name: access_token + auth_style: bearer + header_name: Authorization + token_grant: + grant_type: token_exchange + token_endpoint: https://keycloak.example.com/realms/openshell/protocol/openid-connect/token + subject_token: + source: provider_credential + credential: USER_OIDC_TOKEN +", + ) + .expect("profile should parse"); + + let diagnostics = validate_profile_set(&[("unknown.yaml".to_string(), profile)]); + let diagnostic = diagnostics + .iter() + .find(|diagnostic| { + diagnostic.field == "credentials.token_grant.subject_token.credential" + }) + .expect("expected subject token credential diagnostic"); + assert!( + diagnostic + .message + .contains("unknown subject token credential: USER_OIDC_TOKEN") + ); + } + + #[test] + fn validate_profile_set_rejects_subject_token_on_client_credentials_grant() { + let profile = parse_profile_yaml( + r" +id: misplaced-subject-token +display_name: Misplaced Subject Token +credentials: + - name: USER_OIDC_TOKEN + - name: access_token + auth_style: bearer + header_name: Authorization + token_grant: + token_endpoint: https://keycloak.example.com/realms/openshell/protocol/openid-connect/token + subject_token: + source: provider_credential + credential: USER_OIDC_TOKEN +", + ) + .expect("profile should parse"); + + let diagnostics = validate_profile_set(&[("misplaced.yaml".to_string(), profile)]); + let diagnostic = diagnostics + .iter() + .find(|diagnostic| diagnostic.field == "credentials.token_grant.subject_token") + .expect("expected subject_token diagnostic"); + assert_eq!( + diagnostic.message, + "subject_token is only valid for token_exchange grants" + ); + } + #[test] fn validate_profile_set_rejects_plain_http_token_endpoint() { for token_endpoint in [ diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index 39a26b14e..9770f1009 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -81,9 +81,11 @@ tokio-stream = { workspace = true } sqlx = { workspace = true } reqwest = { workspace = true } uuid = { workspace = true } +base64 = { workspace = true } hmac = "0.12" sha2 = { workspace = true } jsonwebtoken = { workspace = true } +spiffe = { workspace = true } async-trait = "0.1" url = { workspace = true } glob = { workspace = true } diff --git a/crates/openshell-server/src/auth/sandbox_methods.rs b/crates/openshell-server/src/auth/sandbox_methods.rs index 76d5e1324..1ea54a438 100644 --- a/crates/openshell-server/src/auth/sandbox_methods.rs +++ b/crates/openshell-server/src/auth/sandbox_methods.rs @@ -32,6 +32,9 @@ mod tests { assert!(is_sandbox_callable( "/openshell.inference.v1.Inference/GetInferenceBundle" )); + assert!(is_sandbox_callable( + "/openshell.v1.OpenShell/ExchangeProviderSubjectToken" + )); } #[test] diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index 5947bb334..a95353c67 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -19,7 +19,8 @@ use openshell_core::proto::{ DeleteProviderProfileResponse, DeleteProviderRefreshRequest, DeleteProviderRefreshResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, DeleteServiceRequest, DeleteServiceResponse, DetachSandboxProviderRequest, - DetachSandboxProviderResponse, EditDraftChunkRequest, EditDraftChunkResponse, ExecSandboxEvent, + DetachSandboxProviderResponse, EditDraftChunkRequest, EditDraftChunkResponse, + ExchangeProviderSubjectTokenRequest, ExchangeProviderSubjectTokenResponse, ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, ExposeServiceRequest, GatewayMessage, GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, GetDraftPolicyResponse, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderProfileRequest, @@ -494,6 +495,14 @@ impl OpenShell for OpenShellService { policy::handle_get_sandbox_provider_environment(&self.state, request).await } + #[rpc_auth(auth = "sandbox")] + async fn exchange_provider_subject_token( + &self, + request: Request, + ) -> Result, Status> { + provider::handle_exchange_provider_subject_token(&self.state, request).await + } + #[rpc_auth(auth = "dual", scope = "config:write", role = "admin")] async fn update_config( &self, diff --git a/crates/openshell-server/src/grpc/provider.rs b/crates/openshell-server/src/grpc/provider.rs index 3f760e834..dd7886715 100644 --- a/crates/openshell-server/src/grpc/provider.rs +++ b/crates/openshell-server/src/grpc/provider.rs @@ -16,6 +16,8 @@ use openshell_core::telemetry::{ LifecycleOperation, ProviderProfile as TelemetryProviderProfile, TelemetryOutcome, }; use prost::Message; +use std::error::Error as StdError; + use tonic::Status; use tracing::warn; @@ -1175,21 +1177,127 @@ use openshell_core::proto::{ ConfigureProviderRefreshRequest, ConfigureProviderRefreshResponse, CreateProviderRequest, DeleteProviderProfileRequest, DeleteProviderProfileResponse, DeleteProviderRefreshRequest, DeleteProviderRefreshResponse, DeleteProviderRequest, DeleteProviderResponse, + ExchangeProviderSubjectTokenRequest, ExchangeProviderSubjectTokenResponse, GetProviderProfileRequest, GetProviderRefreshStatusRequest, GetProviderRefreshStatusResponse, GetProviderRequest, ImportProviderProfilesRequest, ImportProviderProfilesResponse, LintProviderProfilesRequest, LintProviderProfilesResponse, ListProviderProfilesRequest, ListProviderProfilesResponse, ListProvidersRequest, ListProvidersResponse, - ProviderCredentialRefreshStrategy, ProviderProfileDiagnostic, ProviderProfileImportItem, - ProviderProfileResponse, ProviderResponse, RotateProviderCredentialRequest, - RotateProviderCredentialResponse, StoredProviderProfile, UpdateProviderRequest, + ProviderCredentialRefreshStrategy, ProviderCredentialTokenGrantType, ProviderProfileDiagnostic, + ProviderProfileImportItem, ProviderProfileResponse, ProviderResponse, + RotateProviderCredentialRequest, RotateProviderCredentialResponse, StoredProviderProfile, + UpdateProviderRequest, +}; +use openshell_core::spiffe::{ + JwtSvidParseError, SpiffeJwtClaims, parse_unverified_jwt_svid_claims, + trust_domain as spiffe_trust_domain, workload_api_endpoint, }; use openshell_providers::{ CredentialRefreshProfile, ProfileValidationDiagnostic, ProviderTypeProfile, default_profiles, get_default_profile, normalize_profile_id, normalize_provider_type, validate_profile_set, }; -use std::sync::Arc; +use serde::Deserialize; +use std::sync::{Arc, LazyLock, RwLock}; use tonic::{Request, Response}; +const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange"; +const DEFAULT_CLIENT_ASSERTION_TYPE: &str = + "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; +const DEFAULT_TOKEN_TYPE: &str = "urn:ietf:params:oauth:token-type:access_token"; +const MAX_OAUTH_ERROR_FIELD_LEN: usize = 256; +const DEFAULT_INTERMEDIATE_TOKEN_CACHE_TTL_SECONDS: i64 = 300; +const MAX_INTERMEDIATE_TOKEN_CACHE_TTL_SECONDS: i64 = 3600; +const INTERMEDIATE_TOKEN_CACHE_EXPIRY_SKEW_SECONDS: i64 = 30; +const MAX_INTERMEDIATE_TOKEN_CACHE_ENTRIES: usize = 1024; + +static TOKEN_EXCHANGE_HTTP_CLIENT: LazyLock> = + LazyLock::new(|| { + reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(30)) + .connect_timeout(std::time::Duration::from_secs(30)) + .no_proxy() + .redirect(reqwest::redirect::Policy::none()) + .build() + .map_err(|err| { + format!("provider token exchange HTTP client configuration failed: {err}") + }) + }); +static INTERMEDIATE_TOKEN_CACHE: LazyLock = + LazyLock::new(IntermediateTokenCache::new); + +fn token_exchange_http_client() -> Result<&'static reqwest::Client, Status> { + TOKEN_EXCHANGE_HTTP_CLIENT + .as_ref() + .map_err(|err| Status::internal(err.clone())) +} + +#[derive(Clone)] +struct CachedIntermediateToken { + access_token: String, + token_type: String, + expires_at_ms: i64, +} + +struct IntermediateTokenCache { + tokens: Arc>>, +} + +impl IntermediateTokenCache { + fn new() -> Self { + Self { + tokens: Arc::new(RwLock::new(std::collections::HashMap::new())), + } + } + + fn get(&self, key: &str) -> Option { + let now_ms = crate::persistence::current_time_ms(); + let tokens = self.tokens.read().ok()?; + let cached = tokens.get(key)?; + if cached.expires_at_ms <= now_ms { + return None; + } + Some(TokenExchangeResponseBody { + access_token: cached.access_token.clone(), + expires_in: cached.expires_at_ms.saturating_sub(now_ms) / 1000, + token_type: cached.token_type.clone(), + }) + } + + fn set(&self, key: String, token: &TokenExchangeResponseBody, expires_at_ms: i64) { + if let Ok(mut tokens) = self.tokens.write() { + let now_ms = crate::persistence::current_time_ms(); + tokens.retain(|_, cached| cached.expires_at_ms > now_ms); + if tokens.len() >= MAX_INTERMEDIATE_TOKEN_CACHE_ENTRIES + && let Some(evict_key) = tokens.keys().next().cloned() + { + tokens.remove(&evict_key); + } + tokens.insert( + key, + CachedIntermediateToken { + access_token: token.access_token.clone(), + token_type: token.token_type.clone(), + expires_at_ms, + }, + ); + } + } +} + +#[derive(Debug, Deserialize)] +struct TokenExchangeResponseBody { + access_token: String, + #[serde(default)] + expires_in: i64, + #[serde(default)] + token_type: String, +} + +#[derive(Debug, Deserialize)] +struct OAuthErrorResponse { + error: Option, + error_description: Option, +} + pub(super) async fn handle_create_provider( state: &Arc, request: Request, @@ -1757,6 +1865,610 @@ pub(super) async fn handle_update_provider( } } +pub(super) async fn handle_exchange_provider_subject_token( + state: &Arc, + request: Request, +) -> Result, Status> { + let req = request.get_ref().clone(); + let principal = crate::auth::guard::enforce_sandbox_scope(&request, &req.sandbox_id)?; + crate::auth::guard::ensure_sandbox_principal_scope(&principal, &req.sandbox_id)?; + drop(request); + + if req.provider.trim().is_empty() { + return Err(Status::invalid_argument("provider is required")); + } + if req.credential_key.trim().is_empty() { + return Err(Status::invalid_argument("credential_key is required")); + } + if req.supervisor_jwt_svid.trim().is_empty() { + return Err(Status::invalid_argument("supervisor_jwt_svid is required")); + } + + let sandbox = state + .store + .get_message::(&req.sandbox_id) + .await + .map_err(|e| Status::internal(format!("fetch sandbox failed: {e}")))? + .ok_or_else(|| Status::not_found("sandbox not found"))?; + let spec = sandbox + .spec + .as_ref() + .ok_or_else(|| Status::internal("sandbox has no spec"))?; + if !spec + .providers + .iter() + .any(|provider| provider == &req.provider) + { + return Err(Status::permission_denied( + "provider is not attached to this sandbox", + )); + } + + let provider = state + .store + .get_message_by_name::(&req.provider) + .await + .map_err(|e| Status::internal(format!("fetch provider failed: {e}")))? + .ok_or_else(|| Status::not_found("provider not found"))?; + let profile_id = normalize_provider_type(&provider.r#type).unwrap_or(provider.r#type.as_str()); + let profile = get_provider_type_profile(state.store.as_ref(), profile_id) + .await? + .ok_or_else(|| Status::failed_precondition("provider profile not found"))?; + let profile_proto = profile.to_proto(); + let credential = profile_proto + .credentials + .iter() + .find(|credential| credential.name == req.credential_key) + .ok_or_else(|| { + Status::failed_precondition("credential not declared by provider profile") + })?; + let token_grant = credential + .token_grant + .as_ref() + .ok_or_else(|| Status::failed_precondition("credential does not declare token_grant"))?; + let grant_type = ProviderCredentialTokenGrantType::try_from(token_grant.grant_type) + .unwrap_or(ProviderCredentialTokenGrantType::ClientCredentials); + if grant_type != ProviderCredentialTokenGrantType::TokenExchange { + return Err(Status::failed_precondition( + "credential token_grant is not token_exchange", + )); + } + let subject_token = token_grant + .subject_token + .as_ref() + .ok_or_else(|| Status::failed_precondition("token_exchange subject_token is missing"))?; + if subject_token.source != "provider_credential" { + return Err(Status::failed_precondition( + "unsupported subject_token source", + )); + } + if !profile_proto + .credentials + .iter() + .any(|credential| credential.name == subject_token.credential) + { + return Err(Status::failed_precondition( + "subject token credential not declared by provider profile", + )); + } + let stored_subject_token = provider + .credentials + .get(&subject_token.credential) + .filter(|value| !value.is_empty()) + .ok_or_else(|| Status::failed_precondition("subject token credential is not configured"))?; + ensure_subject_token_credential_not_expired(&provider, &subject_token.credential)?; + + let jwt_svid_audience = + effective_jwt_svid_audience(&token_grant.token_endpoint, &token_grant.jwt_svid_audience); + let gateway_jwt_svid = fetch_gateway_jwt_svid(&jwt_svid_audience).await?; + let gateway_claims = parse_unverified_spiffe_claims(&gateway_jwt_svid)?; + validate_gateway_jwt_svid_claims(&gateway_claims, &jwt_svid_audience)?; + let supervisor_claims = validate_supervisor_jwt_svid( + &req.supervisor_jwt_svid, + &gateway_claims, + &jwt_svid_audience, + ) + .await?; + + let intermediate_cache_key = intermediate_token_cache_key(IntermediateTokenCacheKeyInput { + provider: &provider, + dynamic_credential: &req.credential_key, + subject_credential: &subject_token.credential, + token_endpoint: &token_grant.token_endpoint, + client_assertion_type: effective_client_assertion_type(&token_grant.client_assertion_type), + subject_token_type: effective_token_type(&subject_token.subject_token_type), + audience: &supervisor_claims.sub, + requested_token_type: effective_token_type(&token_grant.requested_token_type), + supervisor_subject: &supervisor_claims.sub, + gateway_subject: &gateway_claims.sub, + }); + if let Some(cached) = INTERMEDIATE_TOKEN_CACHE.get(&intermediate_cache_key) { + return Ok(Response::new(ExchangeProviderSubjectTokenResponse { + access_token: cached.access_token, + expires_in: cached.expires_in, + token_type: cached.token_type, + })); + } + + let token_response = perform_intermediate_token_exchange( + &token_grant.token_endpoint, + &gateway_jwt_svid, + &token_grant.client_assertion_type, + stored_subject_token, + &subject_token.subject_token_type, + &supervisor_claims.sub, + &token_grant.requested_token_type, + ) + .await + .inspect_err(|status| { + warn!( + sandbox_id = %req.sandbox_id, + provider = %req.provider, + credential_key = %req.credential_key, + subject_credential = %subject_token.credential, + client_assertion_type = %effective_client_assertion_type(&token_grant.client_assertion_type), + gateway_svid_issuer = %gateway_claims.iss, + gateway_svid_subject = %gateway_claims.sub, + gateway_svid_audience = ?gateway_claims.aud, + supervisor_svid_issuer = %supervisor_claims.iss, + supervisor_svid_subject = %supervisor_claims.sub, + supervisor_svid_audience = ?supervisor_claims.aud, + status = ?status.code(), + error = %status.message(), + "intermediate provider token exchange failed" + ); + })?; + let cache_expires_at_ms = intermediate_token_cache_expires_at_ms( + &token_response, + token_grant.cache_ttl_seconds, + provider_credential_expires_at_ms(&provider, &subject_token.credential), + supervisor_claims.exp, + ); + if cache_expires_at_ms > crate::persistence::current_time_ms() { + INTERMEDIATE_TOKEN_CACHE.set(intermediate_cache_key, &token_response, cache_expires_at_ms); + } + + Ok(Response::new(ExchangeProviderSubjectTokenResponse { + access_token: token_response.access_token, + expires_in: token_response.expires_in, + token_type: token_response.token_type, + })) +} + +fn ensure_subject_token_credential_not_expired( + provider: &Provider, + credential_key: &str, +) -> Result<(), Status> { + let expires_at_ms = provider_credential_expires_at_ms(provider, credential_key); + if expires_at_ms > 0 && expires_at_ms <= crate::persistence::current_time_ms() { + return Err(Status::failed_precondition( + "subject token credential has expired", + )); + } + Ok(()) +} + +fn provider_credential_expires_at_ms(provider: &Provider, credential_key: &str) -> i64 { + provider + .credential_expires_at_ms + .get(credential_key) + .copied() + .unwrap_or_default() +} + +struct IntermediateTokenCacheKeyInput<'a> { + provider: &'a Provider, + dynamic_credential: &'a str, + subject_credential: &'a str, + token_endpoint: &'a str, + client_assertion_type: &'a str, + subject_token_type: &'a str, + audience: &'a str, + requested_token_type: &'a str, + supervisor_subject: &'a str, + gateway_subject: &'a str, +} + +fn intermediate_token_cache_key(input: IntermediateTokenCacheKeyInput<'_>) -> String { + let provider_id = input + .provider + .metadata + .as_ref() + .map(|metadata| metadata.id.as_str()) + .filter(|id| !id.is_empty()) + .unwrap_or_else(|| input.provider.object_name()); + let provider_resource_version = input + .provider + .metadata + .as_ref() + .map_or(0, |metadata| metadata.resource_version); + format!( + "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}", + provider_id, + provider_resource_version, + input.dynamic_credential, + input.subject_credential, + input.token_endpoint, + input.client_assertion_type, + input.subject_token_type, + input.audience, + input.requested_token_type, + input.supervisor_subject, + input.gateway_subject + ) +} + +fn intermediate_token_cache_expires_at_ms( + token: &TokenExchangeResponseBody, + cache_ttl_seconds: i64, + subject_token_expires_at_ms: i64, + supervisor_svid_exp_seconds: i64, +) -> i64 { + let now_ms = crate::persistence::current_time_ms(); + let mut ttl_seconds = if token.expires_in > 0 { + token + .expires_in + .min(MAX_INTERMEDIATE_TOKEN_CACHE_TTL_SECONDS) + } else { + DEFAULT_INTERMEDIATE_TOKEN_CACHE_TTL_SECONDS + }; + if cache_ttl_seconds > 0 { + ttl_seconds = ttl_seconds.min(cache_ttl_seconds); + } + ttl_seconds = ttl_seconds + .saturating_sub(INTERMEDIATE_TOKEN_CACHE_EXPIRY_SKEW_SECONDS) + .max(1); + let mut expires_at_ms = now_ms.saturating_add(ttl_seconds.saturating_mul(1000)); + expires_at_ms = cap_cache_expiry_ms(expires_at_ms, jwt_exp_ms(&token.access_token)); + expires_at_ms = cap_cache_expiry_ms(expires_at_ms, Some(subject_token_expires_at_ms)); + expires_at_ms = cap_cache_expiry_ms( + expires_at_ms, + (supervisor_svid_exp_seconds > 0).then(|| supervisor_svid_exp_seconds.saturating_mul(1000)), + ); + expires_at_ms +} + +fn cap_cache_expiry_ms(current_expires_at_ms: i64, cap_expires_at_ms: Option) -> i64 { + let Some(cap_expires_at_ms) = cap_expires_at_ms.filter(|value| *value > 0) else { + return current_expires_at_ms; + }; + current_expires_at_ms.min( + cap_expires_at_ms + .saturating_sub(INTERMEDIATE_TOKEN_CACHE_EXPIRY_SKEW_SECONDS.saturating_mul(1000)), + ) +} + +fn jwt_exp_ms(token: &str) -> Option { + use base64::Engine as _; + let payload = token.split('.').nth(1)?; + let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(payload) + .ok()?; + let claims = serde_json::from_slice::(&decoded).ok()?; + claims + .get("exp")? + .as_i64() + .map(|exp| exp.saturating_mul(1000)) +} + +async fn fetch_gateway_jwt_svid(audience: &str) -> Result { + let socket_path = + std::env::var(openshell_core::sandbox_env::GATEWAY_SPIFFE_WORKLOAD_API_SOCKET) + .ok() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| { + Status::failed_precondition(format!( + "{} is required for provider token exchange", + openshell_core::sandbox_env::GATEWAY_SPIFFE_WORKLOAD_API_SOCKET + )) + })?; + let endpoint = workload_api_endpoint(std::path::Path::new(&socket_path)); + let client = spiffe::WorkloadApiClient::connect_to(&endpoint) + .await + .map_err(|e| { + Status::failed_precondition(format!("SPIFFE Workload API unavailable: {e}")) + })?; + client + .fetch_jwt_token([audience], None) + .await + .map_err(|e| Status::failed_precondition(format!("failed to fetch gateway JWT-SVID: {e}"))) +} + +fn validate_gateway_jwt_svid_claims( + claims: &SpiffeJwtClaims, + expected_audience: &str, +) -> Result<(), Status> { + if !claims.aud.contains(expected_audience) { + return Err(Status::failed_precondition( + "gateway SVID audience does not match token grant audience", + )); + } + if spiffe_trust_domain(&claims.sub).is_none() { + return Err(Status::failed_precondition( + "gateway SVID subject is not a SPIFFE ID", + )); + } + if claims.exp > 0 && claims.exp.saturating_mul(1000) <= crate::persistence::current_time_ms() { + return Err(Status::failed_precondition("gateway SVID has expired")); + } + Ok(()) +} + +async fn validate_supervisor_jwt_svid( + token: &str, + gateway_claims: &SpiffeJwtClaims, + expected_audience: &str, +) -> Result { + let unverified = parse_unverified_spiffe_claims(token)?; + if unverified.iss != gateway_claims.iss { + return Err(Status::permission_denied( + "supervisor SVID issuer does not match gateway SVID issuer", + )); + } + if !unverified.aud.contains(expected_audience) { + return Err(Status::permission_denied( + "supervisor SVID audience does not match token grant audience", + )); + } + let supervisor_trust_domain = spiffe_trust_domain(&unverified.sub) + .ok_or_else(|| Status::permission_denied("supervisor SVID subject is not a SPIFFE ID"))?; + let gateway_trust_domain = spiffe_trust_domain(&gateway_claims.sub) + .ok_or_else(|| Status::failed_precondition("gateway SVID subject is not a SPIFFE ID"))?; + if supervisor_trust_domain != gateway_trust_domain { + return Err(Status::permission_denied( + "supervisor SVID trust domain does not match gateway SVID trust domain", + )); + } + + let socket_path = + std::env::var(openshell_core::sandbox_env::GATEWAY_SPIFFE_WORKLOAD_API_SOCKET) + .ok() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| { + Status::failed_precondition(format!( + "{} is required for supervisor JWT-SVID validation", + openshell_core::sandbox_env::GATEWAY_SPIFFE_WORKLOAD_API_SOCKET + )) + })?; + let endpoint = workload_api_endpoint(std::path::Path::new(&socket_path)); + let client = spiffe::WorkloadApiClient::connect_to(&endpoint) + .await + .map_err(|e| { + Status::failed_precondition(format!("SPIFFE Workload API unavailable: {e}")) + })?; + let bundles = client + .fetch_jwt_bundles() + .await + .map_err(|e| Status::internal(format!("SPIFFE JWT bundle fetch failed: {e}")))?; + spiffe::JwtSvid::parse_and_validate(token, &bundles, &[expected_audience]) + .map_err(|e| Status::permission_denied(format!("invalid supervisor JWT-SVID: {e}")))?; + Ok(unverified) +} + +fn format_error_chain(prefix: &str, error: &dyn StdError) -> String { + let mut message = format!("{prefix}: {error}"); + let mut source = error.source(); + while let Some(err) = source { + message.push_str(": "); + message.push_str(&err.to_string()); + source = err.source(); + } + message +} + +fn parse_unverified_spiffe_claims(token: &str) -> Result { + parse_unverified_jwt_svid_claims(token).map_err(jwt_svid_parse_error_status) +} + +fn jwt_svid_parse_error_status(error: JwtSvidParseError) -> Status { + Status::permission_denied(error.to_string()) +} + +async fn perform_intermediate_token_exchange( + token_endpoint: &str, + gateway_jwt_svid: &str, + client_assertion_type: &str, + subject_token: &str, + subject_token_type: &str, + audience: &str, + requested_token_type: &str, +) -> Result { + let token_endpoint_url = parse_token_endpoint_url(token_endpoint)?; + let client_assertion_type = effective_client_assertion_type(client_assertion_type); + let subject_token_type = effective_token_type(subject_token_type); + let requested_token_type = effective_token_type(requested_token_type); + let form_params = [ + ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE), + ("client_assertion_type", client_assertion_type), + ("client_assertion", gateway_jwt_svid), + ("subject_token", subject_token), + ("subject_token_type", subject_token_type), + ("audience", audience), + ("requested_token_type", requested_token_type), + ]; + + let response = token_exchange_http_client()? + .post(token_endpoint_url) + .form(&form_params) + .send() + .await + .map_err(|e| { + Status::internal(format_error_chain( + "provider token exchange request failed", + &e, + )) + })?; + if !response.status().is_success() { + let status = response.status(); + let body = response + .text() + .await + .unwrap_or_else(|_| "".to_string()); + return Err(Status::failed_precondition(token_exchange_failure_message( + status, &body, + ))); + } + let body = response + .json::() + .await + .map_err(|e| { + Status::internal(format!( + "provider token exchange response parse failed: {e}" + )) + })?; + validate_oauth_access_token(&body.access_token)?; + Ok(body) +} + +fn parse_token_endpoint_url(token_endpoint: &str) -> Result { + let url = reqwest::Url::parse(token_endpoint) + .map_err(|_| Status::invalid_argument("token_endpoint must be an absolute URL"))?; + if token_endpoint_transport_allowed(&url) { + return Ok(url); + } + Err(Status::invalid_argument( + "token_endpoint must use https, except http for loopback or in-cluster service hosts", + )) +} + +fn token_endpoint_transport_allowed(url: &reqwest::Url) -> bool { + match url.scheme() { + "https" => true, + "http" => url + .host_str() + .is_some_and(|host| is_loopback_host(host) || is_kubernetes_service_host(host)), + _ => false, + } +} + +fn is_loopback_host(host: &str) -> bool { + let host = host.trim_matches(['[', ']']); + if host.eq_ignore_ascii_case("localhost") { + return true; + } + match host.parse::() { + Ok(std::net::IpAddr::V4(v4)) => v4.is_loopback(), + Ok(std::net::IpAddr::V6(v6)) => { + v6.is_loopback() || v6.to_ipv4_mapped().is_some_and(|v4| v4.is_loopback()) + } + Err(_) => false, + } +} + +fn is_kubernetes_service_host(host: &str) -> bool { + let host = host.trim_end_matches('.').to_ascii_lowercase(); + let labels = host.split('.').collect::>(); + let is_service_name = labels.len() == 3 && labels[2] == "svc"; + let is_cluster_local_service = + labels.len() == 5 && labels[2] == "svc" && labels[3] == "cluster" && labels[4] == "local"; + (is_service_name || is_cluster_local_service) && labels.iter().all(|label| !label.is_empty()) +} + +fn effective_client_assertion_type(client_assertion_type: &str) -> &str { + if client_assertion_type.trim().is_empty() { + DEFAULT_CLIENT_ASSERTION_TYPE + } else { + client_assertion_type + } +} + +fn effective_token_type(token_type: &str) -> &str { + if token_type.trim().is_empty() { + DEFAULT_TOKEN_TYPE + } else { + token_type + } +} + +fn effective_jwt_svid_audience(token_endpoint: &str, jwt_svid_audience: &str) -> String { + if !jwt_svid_audience.trim().is_empty() { + return jwt_svid_audience.to_string(); + } + derive_issuer_from_token_endpoint(token_endpoint) +} + +fn derive_issuer_from_token_endpoint(token_endpoint: &str) -> String { + if let Some(realms_idx) = token_endpoint.find("/realms/") { + let after_realms = &token_endpoint[realms_idx + "/realms/".len()..]; + if let Some(slash_idx) = after_realms.find('/') { + let realm_end = realms_idx + "/realms/".len() + slash_idx; + return token_endpoint[..realm_end].to_string(); + } + } + token_endpoint.to_string() +} + +fn validate_oauth_access_token(token: &str) -> Result<(), Status> { + if token.is_empty() || !is_token68(token) { + return Err(Status::internal( + "provider token exchange returned a malformed access token", + )); + } + Ok(()) +} + +fn is_token68(token: &str) -> bool { + let mut padding_started = false; + let mut saw_value = false; + for byte in token.bytes() { + if byte == b'=' { + padding_started = true; + continue; + } + if padding_started || !is_token68_value_byte(byte) { + return false; + } + saw_value = true; + } + saw_value +} + +fn is_token68_value_byte(byte: u8) -> bool { + byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'.' | b'_' | b'~' | b'+' | b'/') +} + +fn token_exchange_failure_message(status: reqwest::StatusCode, body: &str) -> String { + let Ok(error_response) = serde_json::from_str::(body) else { + return format!("provider token exchange failed with status {status}"); + }; + let error = error_response + .error + .as_deref() + .map(sanitize_oauth_error_field) + .filter(|value| !value.is_empty()); + let description = error_response + .error_description + .as_deref() + .map(sanitize_oauth_error_field) + .filter(|value| !value.is_empty()); + match (error, description) { + (Some(error), Some(description)) => { + format!( + "provider token exchange failed with status {status}: error={error}; error_description={description}" + ) + } + (Some(error), None) => { + format!("provider token exchange failed with status {status}: error={error}") + } + (None, Some(description)) => { + format!( + "provider token exchange failed with status {status}: error_description={description}" + ) + } + (None, None) => format!("provider token exchange failed with status {status}"), + } +} + +fn sanitize_oauth_error_field(value: &str) -> String { + value + .chars() + .map(|ch| if ch.is_control() { ' ' } else { ch }) + .take(MAX_OAUTH_ERROR_FIELD_LEN) + .collect::() + .trim() + .to_string() +} + pub(super) async fn handle_get_provider_refresh_status( state: &Arc, request: Request, @@ -2166,6 +2878,7 @@ fn telemetry_provider_profile(provider_type: &str) -> TelemetryProviderProfile { #[cfg(test)] mod tests { use super::*; + use crate::auth::principal::{Principal, SandboxIdentitySource, SandboxPrincipal}; use crate::grpc::test_support::test_server_state; use crate::grpc::{MAX_MAP_KEY_LEN, MAX_PROVIDER_TYPE_LEN}; use crate::persistence::test_store; @@ -2174,12 +2887,16 @@ mod tests { L7Allow, L7Rule, LintProviderProfilesRequest, ListProviderProfilesRequest, NetworkBinary, NetworkEndpoint, ProviderCredentialRefresh, ProviderCredentialRefreshMaterial, ProviderCredentialTokenGrant, ProviderCredentialTokenGrantAudienceOverride, + ProviderCredentialTokenGrantSubjectToken, ProviderCredentialTokenGrantType, ProviderProfile, ProviderProfileCategory, ProviderProfileCredential, ProviderProfileImportItem, Sandbox, SandboxSpec, }; + use openshell_core::spiffe::AudienceClaim; use openshell_core::{ObjectId, ObjectName}; use std::collections::HashMap; use tonic::{Code, Request}; + use wiremock::matchers::{body_string_contains, method, path}; + use wiremock::{Mock, MockServer, ResponseTemplate}; #[test] fn env_key_validation_accepts_valid_keys() { @@ -2277,6 +2994,9 @@ mod tests { }, ) .collect(), + grant_type: ProviderCredentialTokenGrantType::ClientCredentials as i32, + subject_token: None, + requested_token_type: String::new(), }), }; let profile = ProviderProfile { @@ -2617,10 +3337,612 @@ mod tests { scopes: vec!["read".to_string()], cache_ttl_seconds: 300, audience_overrides: Vec::new(), + grant_type: ProviderCredentialTokenGrantType::ClientCredentials as i32, + subject_token: None, + requested_token_type: String::new(), }), } } + fn token_exchange_credential( + name: &str, + subject_credential: &str, + ) -> ProviderProfileCredential { + let mut credential = token_grant_credential(name); + let token_grant = credential + .token_grant + .as_mut() + .expect("token grant credential"); + token_grant.grant_type = ProviderCredentialTokenGrantType::TokenExchange as i32; + token_grant.subject_token = Some(ProviderCredentialTokenGrantSubjectToken { + source: "provider_credential".to_string(), + credential: subject_credential.to_string(), + subject_token_type: "urn:ietf:params:oauth:token-type:access_token".to_string(), + }); + credential + } + + async fn import_token_exchange_profile( + state: &Arc, + id: &str, + dynamic_credential: &str, + subject_credential: &str, + ) { + let mut profile = custom_profile(id); + profile.credentials = vec![ + token_exchange_credential(dynamic_credential, subject_credential), + static_credential(subject_credential, subject_credential, false), + ]; + profile.endpoints = vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + path: "/v1/**".to_string(), + protocol: "rest".to_string(), + ..Default::default() + }]; + let response = handle_import_provider_profiles( + state, + Request::new(ImportProviderProfilesRequest { + profiles: vec![ProviderProfileImportItem { + profile: Some(profile), + source: format!("{id}.yaml"), + }], + }), + ) + .await + .unwrap() + .into_inner(); + assert!( + response.imported, + "profile import failed: {:?}", + response.diagnostics + ); + } + + async fn store_token_exchange_profile_with_undeclared_subject( + state: &Arc, + id: &str, + dynamic_credential: &str, + subject_credential: &str, + ) { + let mut profile = custom_profile(id); + profile.credentials = vec![token_exchange_credential( + dynamic_credential, + subject_credential, + )]; + profile.endpoints = vec![NetworkEndpoint { + host: "api.example.com".to_string(), + port: 443, + path: "/v1/**".to_string(), + protocol: "rest".to_string(), + ..Default::default() + }]; + state + .store + .put_message(&stored_provider_profile(profile)) + .await + .unwrap(); + } + + fn sandbox_principal(sandbox_id: &str) -> Principal { + Principal::Sandbox(SandboxPrincipal { + sandbox_id: sandbox_id.to_string(), + source: SandboxIdentitySource::BootstrapJwt { + issuer: "openshell-gateway:test".to_string(), + }, + trust_domain: Some("openshell".to_string()), + }) + } + + fn with_sandbox_principal(mut request: Request, sandbox_id: &str) -> Request { + request + .extensions_mut() + .insert(sandbox_principal(sandbox_id)); + request + } + + fn unsigned_svid_fixture(issuer: &str, subject: &str, audience: serde_json::Value) -> String { + use base64::Engine as _; + let header = serde_json::json!({ "alg": "RS256", "kid": "test-key" }); + let payload = serde_json::json!({ + "iss": issuer, + "sub": subject, + "aud": audience, + "exp": 4_102_444_800_i64 + }); + let encoded_header = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&header).expect("serialize header")); + let encoded_payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&payload).expect("serialize payload")); + format!("{encoded_header}.{encoded_payload}.signature") + } + + fn gateway_spiffe_claims(trust_domain: &str) -> SpiffeJwtClaims { + SpiffeJwtClaims { + iss: "https://spiffe.example.test".to_string(), + sub: format!("spiffe://{trust_domain}/openshell/gateway"), + aud: AudienceClaim::One("https://auth.example.com".to_string()), + exp: 4_102_444_800, + } + } + + #[test] + fn parse_unverified_spiffe_claims_rejects_truncated_jwt() { + let err = parse_unverified_spiffe_claims("header.payload") + .expect_err("truncated JWT-SVID must fail"); + + assert_eq!(err.code(), Code::PermissionDenied); + assert!(err.message().contains("format")); + } + + #[test] + fn parse_unverified_spiffe_claims_rejects_empty_jwt_segments() { + let err = parse_unverified_spiffe_claims("header..signature") + .expect_err("empty payload segment must fail"); + + assert_eq!(err.code(), Code::PermissionDenied); + assert!(err.message().contains("format")); + } + + #[test] + fn gateway_svid_validation_requires_expected_audience() { + let claims = gateway_spiffe_claims("openshell"); + + let err = validate_gateway_jwt_svid_claims(&claims, "https://other.example.com") + .expect_err("wrong gateway SVID audience must fail"); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("audience")); + } + + #[test] + fn gateway_svid_validation_requires_spiffe_subject() { + let mut claims = gateway_spiffe_claims("openshell"); + claims.sub = "not-a-spiffe-id".to_string(); + + let err = validate_gateway_jwt_svid_claims(&claims, "https://auth.example.com") + .expect_err("non-SPIFFE gateway SVID subject must fail"); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("not a SPIFFE ID")); + } + + #[test] + fn gateway_svid_validation_rejects_expired_claims() { + let mut claims = gateway_spiffe_claims("openshell"); + claims.exp = (crate::persistence::current_time_ms() / 1000).saturating_sub(1); + + let err = validate_gateway_jwt_svid_claims(&claims, "https://auth.example.com") + .expect_err("expired gateway SVID must fail"); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!(err.message().contains("expired")); + } + + #[tokio::test] + async fn supervisor_svid_validation_rejects_non_spiffe_subject_before_bundle_fetch() { + let token = unsigned_svid_fixture( + "https://spiffe.example.test", + "not-a-spiffe-id", + serde_json::json!("https://auth.example.com"), + ); + + let err = validate_supervisor_jwt_svid( + &token, + &gateway_spiffe_claims("openshell"), + "https://auth.example.com", + ) + .await + .expect_err("non-SPIFFE subject must fail"); + + assert_eq!(err.code(), Code::PermissionDenied); + assert!(err.message().contains("not a SPIFFE ID")); + } + + #[tokio::test] + async fn supervisor_svid_validation_rejects_wrong_trust_domain_before_bundle_fetch() { + let token = unsigned_svid_fixture( + "https://spiffe.example.test", + "spiffe://other-domain/openshell/sandbox/sb-a", + serde_json::json!("https://auth.example.com"), + ); + + let err = validate_supervisor_jwt_svid( + &token, + &gateway_spiffe_claims("openshell"), + "https://auth.example.com", + ) + .await + .expect_err("wrong trust domain must fail"); + + assert_eq!(err.code(), Code::PermissionDenied); + assert!(err.message().contains("trust domain")); + } + + #[tokio::test] + async fn supervisor_svid_validation_rejects_wrong_audience_before_bundle_fetch() { + let token = unsigned_svid_fixture( + "https://spiffe.example.test", + "spiffe://openshell/openshell/sandbox/sb-a", + serde_json::json!(["https://other.example.com"]), + ); + + let err = validate_supervisor_jwt_svid( + &token, + &gateway_spiffe_claims("openshell"), + "https://auth.example.com", + ) + .await + .expect_err("wrong audience must fail"); + + assert_eq!(err.code(), Code::PermissionDenied); + assert!(err.message().contains("audience")); + } + + #[tokio::test] + async fn supervisor_svid_validation_accepts_arbitrary_path_before_bundle_fetch() { + let token = unsigned_svid_fixture( + "https://spiffe.example.test", + "spiffe://openshell/custom/install/specific/supervisor", + serde_json::json!("https://auth.example.com"), + ); + + let err = validate_supervisor_jwt_svid( + &token, + &gateway_spiffe_claims("openshell"), + "https://auth.example.com", + ) + .await + .expect_err("matching trust domain should reach bundle validation"); + + assert_ne!(err.code(), Code::PermissionDenied); + assert!( + err.message() + .contains("OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET") + || err.message().contains("SPIFFE Workload API") + ); + } + + #[tokio::test] + async fn intermediate_token_exchange_posts_expected_form_and_parses_response() { + let mock_server = MockServer::start().await; + Mock::given(method("POST")) + .and(path("/token")) + .and(body_string_contains( + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange", + )) + .and(body_string_contains( + "client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer", + )) + .and(body_string_contains("client_assertion=gateway-jwt-svid")) + .and(body_string_contains("subject_token=stored-user-token")) + .and(body_string_contains( + "subject_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token", + )) + .and(body_string_contains( + "audience=spiffe%3A%2F%2Fopenshell%2Fopenshell%2Fsandbox%2Fsb-a", + )) + .and(body_string_contains( + "requested_token_type=urn%3Aietf%3Aparams%3Aoauth%3Atoken-type%3Aaccess_token", + )) + .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({ + "access_token": "intermediate-token", + "expires_in": 120, + "token_type": "Bearer" + }))) + .expect(1) + .mount(&mock_server) + .await; + let token_endpoint = format!("{}/token", mock_server.uri()); + + let response = perform_intermediate_token_exchange( + &token_endpoint, + "gateway-jwt-svid", + "", + "stored-user-token", + "", + "spiffe://openshell/openshell/sandbox/sb-a", + "", + ) + .await + .expect("intermediate token exchange should succeed"); + + assert_eq!(response.access_token, "intermediate-token"); + assert_eq!(response.expires_in, 120); + assert_eq!(response.token_type, "Bearer"); + } + + #[test] + fn intermediate_token_cache_key_varies_by_provider_revision_and_supervisor_subject() { + let mut provider = provider_with_values("cached-provider", "token-exchange"); + { + let metadata = provider.metadata.as_mut().expect("provider metadata"); + metadata.id = "provider-id".to_string(); + metadata.resource_version = 1; + } + + let base = intermediate_token_cache_key(IntermediateTokenCacheKeyInput { + provider: &provider, + dynamic_credential: "access_token", + subject_credential: "USER_OIDC_TOKEN", + token_endpoint: "https://auth.example.com/token", + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + subject_token_type: DEFAULT_TOKEN_TYPE, + audience: "spiffe://openshell/sandbox/a", + requested_token_type: DEFAULT_TOKEN_TYPE, + supervisor_subject: "spiffe://openshell/sandbox/a", + gateway_subject: "spiffe://openshell/gateway", + }); + + provider + .metadata + .as_mut() + .expect("provider metadata") + .resource_version = 2; + let changed_revision = intermediate_token_cache_key(IntermediateTokenCacheKeyInput { + provider: &provider, + dynamic_credential: "access_token", + subject_credential: "USER_OIDC_TOKEN", + token_endpoint: "https://auth.example.com/token", + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + subject_token_type: DEFAULT_TOKEN_TYPE, + audience: "spiffe://openshell/sandbox/a", + requested_token_type: DEFAULT_TOKEN_TYPE, + supervisor_subject: "spiffe://openshell/sandbox/a", + gateway_subject: "spiffe://openshell/gateway", + }); + provider + .metadata + .as_mut() + .expect("provider metadata") + .resource_version = 1; + let changed_supervisor = intermediate_token_cache_key(IntermediateTokenCacheKeyInput { + provider: &provider, + dynamic_credential: "access_token", + subject_credential: "USER_OIDC_TOKEN", + token_endpoint: "https://auth.example.com/token", + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + subject_token_type: DEFAULT_TOKEN_TYPE, + audience: "spiffe://openshell/sandbox/b", + requested_token_type: DEFAULT_TOKEN_TYPE, + supervisor_subject: "spiffe://openshell/sandbox/b", + gateway_subject: "spiffe://openshell/gateway", + }); + + assert_ne!(base, changed_revision); + assert_ne!(base, changed_supervisor); + } + + #[test] + fn intermediate_token_cache_expiry_is_capped_by_subject_and_supervisor_expiry() { + let now_ms = crate::persistence::current_time_ms(); + let subject_expires_at_ms = now_ms + 45_000; + let supervisor_exp_seconds = (now_ms + 120_000) / 1000; + let token = TokenExchangeResponseBody { + access_token: "opaque-token".to_string(), + expires_in: 600, + token_type: "Bearer".to_string(), + }; + + let expires_at_ms = intermediate_token_cache_expires_at_ms( + &token, + 300, + subject_expires_at_ms, + supervisor_exp_seconds, + ); + + assert!(expires_at_ms <= subject_expires_at_ms - 30_000); + assert!(expires_at_ms > now_ms); + } + + #[test] + fn intermediate_token_cache_returns_remaining_ttl() { + let cache = IntermediateTokenCache::new(); + let token = TokenExchangeResponseBody { + access_token: "cached-token".to_string(), + expires_in: 300, + token_type: "Bearer".to_string(), + }; + cache.set( + "cache-key".to_string(), + &token, + crate::persistence::current_time_ms() + 60_000, + ); + + let cached = cache.get("cache-key").expect("cache hit"); + + assert_eq!(cached.access_token, "cached-token"); + assert_eq!(cached.token_type, "Bearer"); + assert!((1..=60).contains(&cached.expires_in)); + } + + #[test] + fn intermediate_token_cache_prunes_expired_entries_on_set() { + let cache = IntermediateTokenCache::new(); + let token = TokenExchangeResponseBody { + access_token: "cached-token".to_string(), + expires_in: 300, + token_type: "Bearer".to_string(), + }; + cache.set( + "expired-key".to_string(), + &token, + crate::persistence::current_time_ms() - 1_000, + ); + cache.set( + "fresh-key".to_string(), + &token, + crate::persistence::current_time_ms() + 60_000, + ); + + assert!(cache.get("expired-key").is_none()); + assert!(cache.get("fresh-key").is_some()); + assert_eq!(cache.tokens.read().expect("cache lock").len(), 1); + } + + #[test] + fn jwt_exp_ms_reads_unverified_exp_claim() { + use base64::Engine as _; + let payload = serde_json::json!({ "exp": 12345 }); + let encoded_payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(serde_json::to_vec(&payload).expect("serialize payload")); + let token = format!("header.{encoded_payload}.signature"); + + assert_eq!(jwt_exp_ms(&token), Some(12_345_000)); + } + + #[tokio::test] + async fn exchange_provider_subject_token_rejects_expired_subject_credential() { + let state = test_server_state().await; + let store = state.store.as_ref(); + let sandbox_id = "sb-token-exchange-expired"; + let provider_name = "keycloak-user"; + let provider_type = "keycloak-user-token-exchange"; + let subject_credential = "USER_OIDC_TOKEN"; + + import_token_exchange_profile(&state, provider_type, "access_token", subject_credential) + .await; + create_provider_record( + store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: provider_name.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: provider_type.to_string(), + credentials: HashMap::from([( + subject_credential.to_string(), + "stored-user-token".to_string(), + )]), + config: HashMap::new(), + credential_expires_at_ms: HashMap::from([( + subject_credential.to_string(), + crate::persistence::current_time_ms() - 1_000, + )]), + }, + ) + .await + .unwrap(); + store + .put_message(&Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: sandbox_id.to_string(), + name: sandbox_id.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + providers: vec![provider_name.to_string()], + ..Default::default() + }), + ..Default::default() + }) + .await + .unwrap(); + + let err = handle_exchange_provider_subject_token( + &state, + with_sandbox_principal( + Request::new(ExchangeProviderSubjectTokenRequest { + sandbox_id: sandbox_id.to_string(), + provider: provider_name.to_string(), + credential_key: "access_token".to_string(), + supervisor_jwt_svid: "header.payload.signature".to_string(), + }), + sandbox_id, + ), + ) + .await + .expect_err("expired subject credential must be rejected"); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!( + err.message() + .contains("subject token credential has expired") + ); + } + + #[tokio::test] + async fn exchange_provider_subject_token_rejects_undeclared_subject_credential() { + let state = test_server_state().await; + let store = state.store.as_ref(); + let sandbox_id = "sb-token-exchange-undeclared-subject"; + let provider_name = "keycloak-user-undeclared-subject"; + let provider_type = "keycloak-user-token-exchange-undeclared-subject"; + let subject_credential = "USER_OIDC_TOKEN"; + + store_token_exchange_profile_with_undeclared_subject( + &state, + provider_type, + "access_token", + subject_credential, + ) + .await; + create_provider_record( + store, + Provider { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: String::new(), + name: provider_name.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + r#type: provider_type.to_string(), + credentials: HashMap::from([( + subject_credential.to_string(), + "stored-user-token".to_string(), + )]), + config: HashMap::new(), + credential_expires_at_ms: HashMap::new(), + }, + ) + .await + .unwrap(); + store + .put_message(&Sandbox { + metadata: Some(openshell_core::proto::datamodel::v1::ObjectMeta { + id: sandbox_id.to_string(), + name: sandbox_id.to_string(), + created_at_ms: 0, + labels: HashMap::new(), + resource_version: 0, + }), + spec: Some(SandboxSpec { + providers: vec![provider_name.to_string()], + ..Default::default() + }), + ..Default::default() + }) + .await + .unwrap(); + + let err = handle_exchange_provider_subject_token( + &state, + with_sandbox_principal( + Request::new(ExchangeProviderSubjectTokenRequest { + sandbox_id: sandbox_id.to_string(), + provider: provider_name.to_string(), + credential_key: "access_token".to_string(), + supervisor_jwt_svid: "header.payload.signature".to_string(), + }), + sandbox_id, + ), + ) + .await + .expect_err("undeclared subject credential must be rejected"); + + assert_eq!(err.code(), Code::FailedPrecondition); + assert!( + err.message() + .contains("subject token credential not declared") + ); + } + #[tokio::test] async fn list_provider_profiles_returns_built_in_profile_categories() { let state = test_server_state().await; diff --git a/crates/openshell-server/tests/common/mod.rs b/crates/openshell-server/tests/common/mod.rs index 3077cf4c9..7bc9803c2 100644 --- a/crates/openshell-server/tests/common/mod.rs +++ b/crates/openshell-server/tests/common/mod.rs @@ -16,8 +16,9 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxInput, ExecSandboxRequest, GatewayMessage, - GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + ExchangeProviderSubjectTokenRequest, ExchangeProviderSubjectTokenResponse, ExecSandboxEvent, + ExecSandboxInput, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, IssueSandboxTokenRequest, IssueSandboxTokenResponse, ListProvidersRequest, @@ -184,6 +185,13 @@ impl OpenShell for TestOpenShell { Ok(Response::new(RevokeSshSessionResponse::default())) } + async fn exchange_provider_subject_token( + &self, + _request: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } + async fn create_provider( &self, _request: tonic::Request, diff --git a/crates/openshell-server/tests/supervisor_relay_integration.rs b/crates/openshell-server/tests/supervisor_relay_integration.rs index dadb8b384..3d0e1e42a 100644 --- a/crates/openshell-server/tests/supervisor_relay_integration.rs +++ b/crates/openshell-server/tests/supervisor_relay_integration.rs @@ -211,6 +211,12 @@ impl OpenShell for RelayGateway { ) -> Result, Status> { Err(Status::unimplemented("unused")) } + async fn exchange_provider_subject_token( + &self, + _: tonic::Request, + ) -> Result, Status> { + Err(Status::unimplemented("unused")) + } async fn create_provider( &self, _: tonic::Request, diff --git a/crates/openshell-supervisor-network/src/l7/token_grant_injection.rs b/crates/openshell-supervisor-network/src/l7/token_grant_injection.rs index 0d7c18e99..c9b1c7b10 100644 --- a/crates/openshell-supervisor-network/src/l7/token_grant_injection.rs +++ b/crates/openshell-supervisor-network/src/l7/token_grant_injection.rs @@ -26,6 +26,9 @@ pub struct TokenGrantRequest<'a> { pub audience: &'a str, pub scopes: &'a [String], pub cache_ttl_seconds: i64, + pub grant_type: i32, + pub subject_token_type: &'a str, + pub requested_token_type: &'a str, } pub trait TokenGrantResolver: Send + Sync { @@ -45,13 +48,18 @@ impl TokenGrantResolver for SpiffeTokenGrantResolver { ) -> Pin> + Send + 'a>> { Box::pin(async move { crate::token_grant::obtain_provider_token( - request.provider_key, - request.token_endpoint, - request.jwt_svid_audience, - request.client_assertion_type, - request.audience, - request.scopes, - request.cache_ttl_seconds, + crate::token_grant::ObtainProviderTokenRequest { + provider_name: request.provider_key, + token_endpoint: request.token_endpoint, + jwt_svid_audience: request.jwt_svid_audience, + client_assertion_type: request.client_assertion_type, + audience: request.audience, + scopes: request.scopes, + cache_ttl_override: request.cache_ttl_seconds, + grant_type: request.grant_type, + subject_token_type: request.subject_token_type, + requested_token_type: request.requested_token_type, + }, ) .await }) @@ -127,7 +135,7 @@ pub async fn inject_if_needed(req: L7Request, ctx: &L7EvalContext) -> Result( audience: &token_grant.audience, scopes: &token_grant.scopes, cache_ttl_seconds: token_grant.cache_ttl_seconds, + grant_type: token_grant.grant_type, + subject_token_type: token_grant + .subject_token + .as_ref() + .map_or("", |subject| subject.subject_token_type.as_str()), + requested_token_type: &token_grant.requested_token_type, } } @@ -377,7 +391,9 @@ fn inject_header(raw_header: &[u8], header_name: &str, header_value: &str) -> Re #[cfg(test)] pub mod test_support { use super::*; - use openshell_core::proto::{ProviderCredentialTokenGrant, ProviderProfileCredential}; + use openshell_core::proto::{ + ProviderCredentialTokenGrant, ProviderCredentialTokenGrantType, ProviderProfileCredential, + }; use std::collections::HashMap; use std::sync::{Arc, Mutex}; @@ -395,6 +411,9 @@ pub mod test_support { audience: String, scopes: Vec, cache_ttl_seconds: i64, + grant_type: i32, + subject_token_type: String, + requested_token_type: String, } pub struct TokenGrantTestFixture { @@ -466,6 +485,12 @@ pub mod test_support { assert_eq!(request.audience, "api://example"); assert_eq!(request.scopes, ["read"]); assert_eq!(request.cache_ttl_seconds, 300); + assert_eq!( + request.grant_type, + ProviderCredentialTokenGrantType::ClientCredentials as i32 + ); + assert!(request.subject_token_type.is_empty()); + assert!(request.requested_token_type.is_empty()); } } @@ -479,6 +504,9 @@ pub mod test_support { scopes: vec!["read".to_string()], cache_ttl_seconds: 300, audience_overrides: Vec::new(), + grant_type: ProviderCredentialTokenGrantType::ClientCredentials as i32, + subject_token: None, + requested_token_type: String::new(), } } @@ -495,6 +523,9 @@ pub mod test_support { audience: request.audience.to_string(), scopes: request.scopes.to_vec(), cache_ttl_seconds: request.cache_ttl_seconds, + grant_type: request.grant_type, + subject_token_type: request.subject_token_type.to_string(), + requested_token_type: request.requested_token_type.to_string(), }; Box::pin(async move { self.requests @@ -531,6 +562,12 @@ mod tests { 443, "/repos/owner/repo" )); + assert!(dynamic_credential_key_matches( + "api.example.com\t443\t/repos/**\trev:42\tgithub:access_token", + "api.example.com", + 443, + "/repos/owner/repo" + )); assert!(!dynamic_credential_key_matches( key, "uploads.example.com", diff --git a/crates/openshell-supervisor-network/src/lib.rs b/crates/openshell-supervisor-network/src/lib.rs index a559a57e6..c995b86ab 100644 --- a/crates/openshell-supervisor-network/src/lib.rs +++ b/crates/openshell-supervisor-network/src/lib.rs @@ -16,5 +16,4 @@ pub mod policy_local; pub mod procfs; pub mod proxy; pub mod run; -mod spiffe_endpoint; mod token_grant; diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index d467b022e..d57c668a4 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -12,7 +12,7 @@ use openshell_core::activity::{ActivitySender, try_record_activity}; use openshell_core::denial::DenialEvent; use openshell_core::net::{is_always_blocked_ip, is_internal_ip, is_link_local_ip}; use openshell_core::policy::ProxyPolicy; -use openshell_core::provider_credentials::ProviderCredentialState; +use openshell_core::provider_credentials::{ProviderCredentialSnapshot, ProviderCredentialState}; use openshell_core::secrets::{SecretResolver, rewrite_header_line_checked}; use openshell_ocsf::{ ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, @@ -44,6 +44,27 @@ const HOST_GATEWAY_ALIASES: &[&str] = &[ "host.docker.internal", ]; +fn revision_scoped_dynamic_credentials( + snapshot: &ProviderCredentialSnapshot, +) -> std::collections::HashMap { + snapshot + .dynamic_credentials + .iter() + .map(|(key, credential)| { + let scoped_key = key.rsplit_once('\t').map_or_else( + || format!("rev:{}\t{key}", snapshot.revision), + |(endpoint_selector, provider_credential)| { + format!( + "{endpoint_selector}\trev:{}\t{provider_credential}", + snapshot.revision + ) + }, + ); + (scoped_key, credential.clone()) + }) + .collect() +} + /// Cloud instance metadata IPs that are NEVER exempted from SSRF blocking, /// even when they coincidentally match a host-gateway alias resolution. /// This list covers the well-known IMDS endpoints across major cloud providers. @@ -243,9 +264,9 @@ impl ProxyHandle { .as_ref() .and_then(ProviderCredentialState::resolver); let dynamic_credentials = provider_credentials.as_ref().map(|state| { - Arc::new(std::sync::RwLock::new( - state.snapshot().dynamic_credentials.clone(), - )) + Arc::new(std::sync::RwLock::new(revision_scoped_dynamic_credentials( + &state.snapshot(), + ))) }); let dtx = denial_tx.clone(); let atx = activity_tx.clone(); @@ -4098,6 +4119,29 @@ mod tests { } } + #[test] + fn revision_scoped_dynamic_credentials_preserves_endpoint_selector_and_adds_revision() { + let mut dynamic_credentials = std::collections::HashMap::new(); + dynamic_credentials.insert( + "api.example.test\t443\t/v1/**\tprovider:access_token".to_string(), + openshell_core::proto::ProviderProfileCredential { + name: "access_token".to_string(), + ..Default::default() + }, + ); + let snapshot = ProviderCredentialSnapshot { + revision: 42, + child_env: std::collections::HashMap::new(), + dynamic_credentials, + }; + + let scoped = revision_scoped_dynamic_credentials(&snapshot); + + assert!( + scoped.contains_key("api.example.test\t443\t/v1/**\trev:42\tprovider:access_token") + ); + } + #[test] fn connect_activity_is_skipped_when_l7_will_count_the_request() { let (tx, mut rx) = mpsc::channel(4); diff --git a/crates/openshell-supervisor-network/src/spiffe_endpoint.rs b/crates/openshell-supervisor-network/src/spiffe_endpoint.rs deleted file mode 100644 index 449462627..000000000 --- a/crates/openshell-supervisor-network/src/spiffe_endpoint.rs +++ /dev/null @@ -1,17 +0,0 @@ -// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: Apache-2.0 - -use std::path::Path; - -/// Convert a path to a SPIFFE Workload API endpoint URL. -/// -/// If the path already has a scheme (`unix:` or `tcp:`), use it as-is. -/// Otherwise, assume it is a Unix socket path and prepend `unix:`. -pub fn workload_api_endpoint(path: &Path) -> String { - let path = path.to_string_lossy(); - if path.starts_with("unix:") || path.starts_with("tcp:") { - path.into_owned() - } else { - format!("unix:{path}") - } -} diff --git a/crates/openshell-supervisor-network/src/token_grant.rs b/crates/openshell-supervisor-network/src/token_grant.rs index 03e9bfb39..f8d09e60b 100644 --- a/crates/openshell-supervisor-network/src/token_grant.rs +++ b/crates/openshell-supervisor-network/src/token_grant.rs @@ -39,6 +39,7 @@ use std::sync::{Arc, LazyLock, RwLock}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_core::proto::ProviderCredentialTokenGrantType; use openshell_core::sandbox_env; use serde::Deserialize; use spiffe::WorkloadApiClient; @@ -60,6 +61,7 @@ const TOKEN_CACHE_EXPIRY_SKEW_SECONDS: i64 = 30; const MAX_TOKEN_EXPIRES_IN_SECONDS: i64 = 3600; const DEFAULT_CLIENT_ASSERTION_TYPE: &str = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"; +const DEFAULT_TOKEN_TYPE: &str = "urn:ietf:params:oauth:token-type:access_token"; /// `OAuth2` token response from the authorization server. #[derive(Debug, Clone, Deserialize)] @@ -153,41 +155,87 @@ impl TokenCache { /// - JWT-SVID fetch fails /// - Token service request fails /// - Token response is invalid -pub async fn obtain_provider_token( - provider_name: &str, - token_endpoint: &str, - jwt_svid_audience: &str, - client_assertion_type: &str, - audience: &str, - scopes: &[String], - cache_ttl_override: i64, -) -> Result { +pub struct ObtainProviderTokenRequest<'a> { + pub provider_name: &'a str, + pub token_endpoint: &'a str, + pub jwt_svid_audience: &'a str, + pub client_assertion_type: &'a str, + pub audience: &'a str, + pub scopes: &'a [String], + pub cache_ttl_override: i64, + pub grant_type: i32, + pub subject_token_type: &'a str, + pub requested_token_type: &'a str, +} + +pub async fn obtain_provider_token(request: ObtainProviderTokenRequest<'_>) -> Result { + let grant_type = ProviderCredentialTokenGrantType::try_from(request.grant_type) + .unwrap_or(ProviderCredentialTokenGrantType::ClientCredentials); obtain_provider_token_with_grant( ObtainProviderTokenInput { cache: &TOKEN_CACHE, - provider_name, - token_endpoint, - jwt_svid_audience, - client_assertion_type, - audience, - scopes, - cache_ttl_override, + provider_name: request.provider_name, + token_endpoint: request.token_endpoint, + jwt_svid_audience: request.jwt_svid_audience, + client_assertion_type: request.client_assertion_type, + audience: request.audience, + scopes: request.scopes, + cache_ttl_override: request.cache_ttl_override, + grant_type, + requested_token_type: request.requested_token_type, }, |jwt_audience| async move { // Fetch JWT-SVID with authorization server as audience // For RFC 7523, the JWT assertion's aud claim identifies the issuer/realm let jwt_svid = fetch_jwt_svid_for_token_grant(&jwt_audience).await?; - // Perform OAuth2 JWT client assertion grant - // The audience parameter in the token request specifies the resource server - perform_token_grant( - token_endpoint, - &jwt_svid, - client_assertion_type, - audience, - scopes, - ) - .await + match grant_type { + ProviderCredentialTokenGrantType::ClientCredentials + | ProviderCredentialTokenGrantType::Unspecified => { + // Perform OAuth2 JWT client assertion grant. The audience + // parameter in the token request specifies the resource server. + perform_token_grant( + request.token_endpoint, + &jwt_svid, + request.client_assertion_type, + request.audience, + request.scopes, + ) + .await + } + ProviderCredentialTokenGrantType::TokenExchange => { + let (provider, credential_key) = + parse_provider_credential_key(request.provider_name)?; + let endpoint = supervisor_gateway_endpoint_from_env()?; + let sandbox_id = supervisor_sandbox_id_from_env()?; + let intermediate = + openshell_core::grpc_client::exchange_provider_subject_token( + &endpoint, + &sandbox_id, + provider, + credential_key, + &jwt_svid, + ) + .await + .map_err(|err| { + miette::miette!( + "gateway intermediate provider token exchange failed: {err}" + ) + })?; + validate_access_token(&intermediate.access_token)?; + perform_token_exchange( + request.token_endpoint, + &jwt_svid, + request.client_assertion_type, + &intermediate.access_token, + request.subject_token_type, + request.audience, + request.scopes, + request.requested_token_type, + ) + .await + } + } }, ) .await @@ -202,6 +250,8 @@ struct ObtainProviderTokenInput<'a> { audience: &'a str, scopes: &'a [String], cache_ttl_override: i64, + grant_type: ProviderCredentialTokenGrantType, + requested_token_type: &'a str, } async fn obtain_provider_token_with_grant( @@ -216,14 +266,16 @@ where // For Keycloak: https://auth.example.com/realms/openshell/protocol/openid-connect/token // -> https://auth.example.com/realms/openshell let jwt_audience = effective_jwt_svid_audience(input.token_endpoint, input.jwt_svid_audience); - let cache_key = token_cache_key( - input.provider_name, - input.token_endpoint, - &jwt_audience, - effective_client_assertion_type(input.client_assertion_type), - input.audience, - input.scopes, - ); + let cache_key = token_cache_key(TokenCacheKeyInput { + provider_name: input.provider_name, + token_endpoint: input.token_endpoint, + jwt_svid_audience: &jwt_audience, + client_assertion_type: effective_client_assertion_type(input.client_assertion_type), + audience: input.audience, + scopes: input.scopes, + grant_type: input.grant_type, + requested_token_type: effective_token_type(input.requested_token_type), + }); // Check cache first if let Some(cached) = input.cache.get(&cache_key) { @@ -256,7 +308,7 @@ async fn fetch_jwt_svid_for_token_grant(audience: &str) -> Result { let socket_path = provider_spiffe_workload_api_socket_from_env()?; let endpoint = - crate::spiffe_endpoint::workload_api_endpoint(std::path::Path::new(&socket_path)); + openshell_core::spiffe::workload_api_endpoint(std::path::Path::new(&socket_path)); // Connect to SPIRE agent let client = WorkloadApiClient::connect_to(&endpoint) @@ -359,6 +411,74 @@ async fn perform_token_grant( Ok(token_response) } +#[allow(clippy::too_many_arguments)] +async fn perform_token_exchange( + token_endpoint: &str, + jwt_svid: &str, + client_assertion_type: &str, + subject_token: &str, + subject_token_type: &str, + audience: &str, + scopes: &[String], + requested_token_type: &str, +) -> Result { + let token_endpoint_url = parse_token_endpoint_url(token_endpoint)?; + let client_assertion_type = effective_client_assertion_type(client_assertion_type); + let subject_token_type = effective_token_type(subject_token_type); + let requested_token_type = effective_token_type(requested_token_type); + let mut form_params = vec![ + ( + "grant_type", + "urn:ietf:params:oauth:grant-type:token-exchange", + ), + ("client_assertion_type", client_assertion_type), + ("client_assertion", jwt_svid), + ("subject_token", subject_token), + ("subject_token_type", subject_token_type), + ("requested_token_type", requested_token_type), + ]; + + let audience_param; + if !audience.is_empty() { + audience_param = audience.to_string(); + form_params.push(("audience", &audience_param)); + } + + let scope_param; + if !scopes.is_empty() { + scope_param = scopes.join(" "); + form_params.push(("scope", &scope_param)); + } + + let response = TOKEN_GRANT_HTTP_CLIENT + .post(token_endpoint_url) + .form(&form_params) + .send() + .await + .into_diagnostic() + .wrap_err_with(|| format!("failed to POST token exchange to {token_endpoint}"))?; + + if !response.status().is_success() { + let status = response.status(); + let body = response + .text() + .await + .unwrap_or_else(|_| "".to_string()); + return Err(miette::miette!( + "{}", + token_grant_failure_message(status, &body) + )); + } + + let token_response = response + .json::() + .await + .into_diagnostic() + .wrap_err("failed to parse token exchange response as JSON")?; + validate_access_token(&token_response.access_token)?; + Ok(token_response) +} + fn parse_token_endpoint_url(token_endpoint: &str) -> Result { let url = reqwest::Url::parse(token_endpoint) .into_diagnostic() @@ -491,22 +611,59 @@ fn effective_client_assertion_type(client_assertion_type: &str) -> &str { } } -fn token_cache_key( - provider_name: &str, - token_endpoint: &str, - jwt_svid_audience: &str, - client_assertion_type: &str, - audience: &str, - scopes: &[String], -) -> String { +fn effective_token_type(token_type: &str) -> &str { + if token_type.trim().is_empty() { + DEFAULT_TOKEN_TYPE + } else { + token_type + } +} + +fn supervisor_gateway_endpoint_from_env() -> Result { + std::env::var(sandbox_env::ENDPOINT) + .ok() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| miette::miette!("{} not set", sandbox_env::ENDPOINT)) +} + +fn supervisor_sandbox_id_from_env() -> Result { + std::env::var(sandbox_env::SANDBOX_ID) + .ok() + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| miette::miette!("{} not set", sandbox_env::SANDBOX_ID)) +} + +fn parse_provider_credential_key(key: &str) -> Result<(&str, &str)> { + let provider_and_credential = key + .rsplit_once('\t') + .map_or(key, |(_, provider_and_credential)| provider_and_credential); + provider_and_credential.split_once(':').ok_or_else(|| { + miette::miette!("dynamic token grant key is missing provider credential identity") + }) +} + +struct TokenCacheKeyInput<'a> { + provider_name: &'a str, + token_endpoint: &'a str, + jwt_svid_audience: &'a str, + client_assertion_type: &'a str, + audience: &'a str, + scopes: &'a [String], + grant_type: ProviderCredentialTokenGrantType, + requested_token_type: &'a str, +} + +fn token_cache_key(input: TokenCacheKeyInput<'_>) -> String { format!( - "{}\t{}\t{}\t{}\t{}\t{}", - provider_name, - token_endpoint, - jwt_svid_audience, - client_assertion_type, - audience, - scopes.join(" ") + "{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}", + input.provider_name, + input.token_endpoint, + input.jwt_svid_audience, + input.client_assertion_type, + input.audience, + input.scopes.join(" "), + input.grant_type as i32, + input.requested_token_type ) } @@ -778,6 +935,8 @@ mod tests { audience: input.audience, scopes: input.scopes, cache_ttl_override: input.cache_ttl_override, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: "", }, move |_| { let grant_calls = input.grant_calls.clone(); @@ -814,6 +973,8 @@ mod tests { audience, scopes, cache_ttl_override, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: "", }, |_| async { Err(miette::miette!("grant should not be called on cache hit")) }, ) @@ -916,44 +1077,99 @@ mod tests { #[test] fn token_cache_key_varies_by_resource_audience_and_scopes() { - let base = token_cache_key( - "alpha.default.svc.cluster.local\t80\t\tprovider:access_token", - "https://auth.example.com/realms/openshell/protocol/openid-connect/token", - "https://auth.example.com/realms/openshell", - DEFAULT_CLIENT_ASSERTION_TYPE, - "alpha", - &["alpha".to_string()], - ); - let different_audience = token_cache_key( - "alpha.default.svc.cluster.local\t80\t\tprovider:access_token", - "https://auth.example.com/realms/openshell/protocol/openid-connect/token", - "https://auth.example.com/realms/openshell", - DEFAULT_CLIENT_ASSERTION_TYPE, - "delta", - &["alpha".to_string()], - ); - let different_scopes = token_cache_key( - "alpha.default.svc.cluster.local\t80\t\tprovider:access_token", - "https://auth.example.com/realms/openshell/protocol/openid-connect/token", - "https://auth.example.com/realms/openshell", - DEFAULT_CLIENT_ASSERTION_TYPE, - "alpha", - &["delta".to_string()], - ); - let different_assertion_type = token_cache_key( - "alpha.default.svc.cluster.local\t80\t\tprovider:access_token", - "https://auth.example.com/realms/openshell/protocol/openid-connect/token", - "https://auth.example.com/realms/openshell", - "urn:ietf:params:oauth:client-assertion-type:jwt-spiffe", - "alpha", - &["alpha".to_string()], - ); + let provider_name = "alpha.default.svc.cluster.local\t80\t\tprovider:access_token"; + let token_endpoint = + "https://auth.example.com/realms/openshell/protocol/openid-connect/token"; + let jwt_svid_audience = "https://auth.example.com/realms/openshell"; + let alpha_scopes = ["alpha".to_string()]; + let delta_scopes = ["delta".to_string()]; + let base = token_cache_key(TokenCacheKeyInput { + provider_name, + token_endpoint, + jwt_svid_audience, + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + audience: "alpha", + scopes: &alpha_scopes, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); + let different_audience = token_cache_key(TokenCacheKeyInput { + provider_name, + token_endpoint, + jwt_svid_audience, + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + audience: "delta", + scopes: &alpha_scopes, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); + let different_scopes = token_cache_key(TokenCacheKeyInput { + provider_name, + token_endpoint, + jwt_svid_audience, + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + audience: "alpha", + scopes: &delta_scopes, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); + let different_assertion_type = token_cache_key(TokenCacheKeyInput { + provider_name, + token_endpoint, + jwt_svid_audience, + client_assertion_type: "urn:ietf:params:oauth:client-assertion-type:jwt-spiffe", + audience: "alpha", + scopes: &alpha_scopes, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); assert_ne!(base, different_audience); assert_ne!(base, different_scopes); assert_ne!(base, different_assertion_type); } + #[test] + fn token_cache_key_varies_by_provider_env_revision_prefix() { + let token_endpoint = + "https://auth.example.com/realms/openshell/protocol/openid-connect/token"; + let jwt_svid_audience = "https://auth.example.com/realms/openshell"; + let scopes = ["alpha".to_string()]; + let revision_one = token_cache_key(TokenCacheKeyInput { + provider_name: "api.example.test\t443\t/v1/**\trev:1\tprovider:access_token", + token_endpoint, + jwt_svid_audience, + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + audience: "alpha", + scopes: &scopes, + grant_type: ProviderCredentialTokenGrantType::TokenExchange, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); + let revision_two = token_cache_key(TokenCacheKeyInput { + provider_name: "api.example.test\t443\t/v1/**\trev:2\tprovider:access_token", + token_endpoint, + jwt_svid_audience, + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, + audience: "alpha", + scopes: &scopes, + grant_type: ProviderCredentialTokenGrantType::TokenExchange, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); + + assert_ne!(revision_one, revision_two); + } + + #[test] + fn provider_credential_key_parser_ignores_revision_segment() { + assert_eq!( + parse_provider_credential_key( + "api.example.test\t443\t/v1/**\trev:42\tprovider:access_token" + ) + .expect("parse provider credential key"), + ("provider", "access_token") + ); + } + #[test] fn token_cache_ttl_uses_override_without_endpoint_skew() { assert_eq!(token_cache_ttl_seconds(120, 10), 120); @@ -1077,14 +1293,16 @@ mod tests { let jwt_svid_audience = "https://auth.example.com"; let audience = "api://resource"; - let cache_key = token_cache_key( + let cache_key = token_cache_key(TokenCacheKeyInput { provider_name, token_endpoint, jwt_svid_audience, - DEFAULT_CLIENT_ASSERTION_TYPE, + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, audience, - &scopes, - ); + scopes: &scopes, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); cache.set( cache_key, "expired-token".to_string(), @@ -1128,6 +1346,8 @@ mod tests { audience, scopes: &scopes, cache_ttl_override: 0, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: "", }, |_| async { Ok(TokenResponse { @@ -1141,14 +1361,16 @@ mod tests { .await .expect_err("malformed access token should fail before caching"); - let cache_key = token_cache_key( + let cache_key = token_cache_key(TokenCacheKeyInput { provider_name, token_endpoint, jwt_svid_audience, - DEFAULT_CLIENT_ASSERTION_TYPE, + client_assertion_type: DEFAULT_CLIENT_ASSERTION_TYPE, audience, - &scopes, - ); + scopes: &scopes, + grant_type: ProviderCredentialTokenGrantType::ClientCredentials, + requested_token_type: DEFAULT_TOKEN_TYPE, + }); assert_eq!( err.to_string(), diff --git a/crates/openshell-supervisor-process/src/run.rs b/crates/openshell-supervisor-process/src/run.rs index 5a5c203a2..7a926b4f6 100644 --- a/crates/openshell-supervisor-process/src/run.rs +++ b/crates/openshell-supervisor-process/src/run.rs @@ -85,10 +85,18 @@ pub async fn run_process( // the flag stays at its default (false) and no skill is installed. install_initial_agent_skill(sandbox_id, openshell_endpoint).await; + // Provider token grants may mount supervisor-only identity sockets such as + // the SPIFFE Workload API. Prepare the child mount namespace that hides + // those mounts before supervisor seccomp hardening removes the needed + // namespace syscalls. + #[cfg(target_os = "linux")] + crate::process::prepare_supervisor_identity_mount_namespace_from_env()?; + // Install the supervisor seccomp prelude before spawning any workload-side // tasks. By this point the orchestrator has finished privileged startup - // helpers (network namespace setup, nftables probes via run_networking), - // and the SSH listener and entrypoint child have not been exposed yet. + // helpers (network namespace setup, identity mount namespace setup, + // nftables probes via run_networking), and the SSH listener and entrypoint + // child have not been exposed yet. crate::sandbox::apply_supervisor_startup_hardening()?; // Spawn the bypass detection monitor. It tails dmesg for nftables LOG diff --git a/deploy/helm/openshell/README.md b/deploy/helm/openshell/README.md index e6d539592..c1528f617 100644 --- a/deploy/helm/openshell/README.md +++ b/deploy/helm/openshell/README.md @@ -125,15 +125,20 @@ JWT signing Secret. ## SPIFFE/SPIRE provider token grants -Set `server.providerTokenGrants.spiffe.enabled=true` to let sandbox supervisors -use SPIFFE JWT-SVIDs for dynamic provider token grants. The chart keeps -supervisor-to-gateway authentication on gateway-minted sandbox JWTs and passes -the SPIFFE Workload API socket path to the Kubernetes driver so sandbox pods can -mount the SPIFFE CSI socket. +Set `server.providerTokenGrants.spiffe.enabled=true` to let the gateway and +sandbox supervisors use SPIFFE JWT-SVIDs for dynamic provider token grants. The +chart keeps supervisor-to-gateway authentication on gateway-minted sandbox JWTs, +mounts the SPIFFE CSI socket into the gateway pod, exports +`OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET`, and passes the socket path to +the Kubernetes driver so sandbox pods can mount the same socket. For local development, uncomment the SPIRE Helm releases in `skaffold.yaml` and add `ci/values-spire.yaml` to the OpenShell release values files. +The gateway verifies supervisor JWT-SVIDs with JWT bundles fetched from the +SPIFFE Workload API, so this path does not require access to the SPIRE OIDC +discovery endpoint or its TLS CA. + ## Values | Key | Type | Default | Description | @@ -211,8 +216,8 @@ add `ci/values-spire.yaml` to the OpenShell release values files. | server.oidc.rolesClaim | string | `""` | Dot-separated path to the roles array in the JWT claims. Keycloak: "realm_access.roles", Entra ID: "roles", Okta: "groups". | | server.oidc.scopesClaim | string | `""` | Dot-separated path to the scopes array in the JWT claims. | | server.oidc.userRole | string | `""` | Role name for standard user access. | -| server.providerTokenGrants.spiffe.enabled | bool | `false` | Mount the SPIFFE Workload API socket into sandbox pods for dynamic provider token grants. | -| server.providerTokenGrants.spiffe.workloadApiSocketPath | string | `"/spiffe-workload-api/spire-agent.sock"` | Path to the SPIFFE Workload API socket mounted into sandbox pods. | +| server.providerTokenGrants.spiffe.enabled | bool | `false` | Mount the SPIFFE Workload API socket into gateway and sandbox pods for dynamic provider token grants. | +| server.providerTokenGrants.spiffe.workloadApiSocketPath | string | `"/spiffe-workload-api/spire-agent.sock"` | Path to the SPIFFE Workload API socket mounted into gateway and sandbox pods. | | server.sandboxImage | string | `"ghcr.io/nvidia/openshell-community/sandboxes/base:latest"` | Default sandbox image used when requests do not specify one. | | server.sandboxImagePullPolicy | string | `""` | Kubernetes imagePullPolicy for sandbox pods. Empty = Kubernetes default (Always for :latest, IfNotPresent otherwise). Set to "Always" for dev clusters so new images are picked up without manual eviction. | | server.sandboxImagePullSecrets | list | `[]` | Image pull secrets attached to sandbox pods. Referenced Secrets must exist in the sandbox namespace. | diff --git a/deploy/helm/openshell/README.md.gotmpl b/deploy/helm/openshell/README.md.gotmpl index e246ca67b..66eac0c63 100644 --- a/deploy/helm/openshell/README.md.gotmpl +++ b/deploy/helm/openshell/README.md.gotmpl @@ -125,14 +125,19 @@ JWT signing Secret. ## SPIFFE/SPIRE provider token grants -Set `server.providerTokenGrants.spiffe.enabled=true` to let sandbox supervisors -use SPIFFE JWT-SVIDs for dynamic provider token grants. The chart keeps -supervisor-to-gateway authentication on gateway-minted sandbox JWTs and passes -the SPIFFE Workload API socket path to the Kubernetes driver so sandbox pods can -mount the SPIFFE CSI socket. +Set `server.providerTokenGrants.spiffe.enabled=true` to let the gateway and +sandbox supervisors use SPIFFE JWT-SVIDs for dynamic provider token grants. The +chart keeps supervisor-to-gateway authentication on gateway-minted sandbox JWTs, +mounts the SPIFFE CSI socket into the gateway pod, exports +`OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET`, and passes the socket path to +the Kubernetes driver so sandbox pods can mount the same socket. For local development, uncomment the SPIRE Helm releases in `skaffold.yaml` and add `ci/values-spire.yaml` to the OpenShell release values files. +The gateway verifies supervisor JWT-SVIDs with JWT bundles fetched from the +SPIFFE Workload API, so this path does not require access to the SPIRE OIDC +discovery endpoint or its TLS CA. + {{ template "chart.valuesSection" . }} {{ template "helm-docs.versionFooter" . }} diff --git a/deploy/helm/openshell/templates/_gateway-workload.tpl b/deploy/helm/openshell/templates/_gateway-workload.tpl index 5931047e5..b54112a64 100644 --- a/deploy/helm/openshell/templates/_gateway-workload.tpl +++ b/deploy/helm/openshell/templates/_gateway-workload.tpl @@ -60,7 +60,7 @@ spec: # All gateway settings live in the ConfigMap-backed TOML file # mounted at /etc/openshell/gateway.toml. The only env var below # is a process-level setting consumed by libraries outside - # gateway code (currently just SSL_CERT_FILE for OIDC issuer TLS). + # gateway code (currently SSL_CERT_FILE for OIDC issuer TLS). {{- if and .Values.server.oidc.issuer .Values.server.oidc.caConfigMapName }} # OIDC issuer custom-CA: rustls/reqwest read SSL_CERT_FILE for # outbound TLS verification. This is a process-level env var @@ -69,6 +69,10 @@ spec: - name: SSL_CERT_FILE value: /etc/openshell-tls/oidc-ca/ca.crt {{- end }} + {{- if .Values.server.providerTokenGrants.spiffe.enabled }} + - name: OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET + value: {{ .Values.server.providerTokenGrants.spiffe.workloadApiSocketPath | quote }} + {{- end }} volumeMounts: {{- if eq (include "openshell.workloadKind" .) "statefulset" }} - name: openshell-data @@ -95,6 +99,11 @@ spec: mountPath: /etc/openshell-tls/oidc-ca readOnly: true {{- end }} + {{- if .Values.server.providerTokenGrants.spiffe.enabled }} + - name: spiffe-workload-api + mountPath: {{ dir .Values.server.providerTokenGrants.spiffe.workloadApiSocketPath | quote }} + readOnly: true + {{- end }} ports: - name: grpc containerPort: {{ .Values.service.port }} @@ -162,6 +171,12 @@ spec: configMap: name: {{ .Values.server.oidc.caConfigMapName }} {{- end }} + {{- if .Values.server.providerTokenGrants.spiffe.enabled }} + - name: spiffe-workload-api + csi: + driver: csi.spiffe.io + readOnly: true + {{- end }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 4 }} diff --git a/deploy/helm/openshell/tests/gateway_config_test.yaml b/deploy/helm/openshell/tests/gateway_config_test.yaml index c2708a20f..b7f0f9ce4 100644 --- a/deploy/helm/openshell/tests/gateway_config_test.yaml +++ b/deploy/helm/openshell/tests/gateway_config_test.yaml @@ -399,11 +399,39 @@ tests: path: data["gateway.toml"] pattern: '\[openshell\.gateway\.spiffe\]' - - it: keeps the gateway sandbox JWT secret mounted when provider SPIFFE grants are enabled + - it: mounts the gateway SPIFFE socket while keeping sandbox JWT auth set: server.providerTokenGrants.spiffe.enabled: true template: templates/statefulset.yaml asserts: - - matchRegex: - path: spec.template.spec.volumes[1].name - pattern: '^sandbox-jwt$' + - contains: + path: spec.template.spec.containers[0].env + content: + name: OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET + value: /spiffe-workload-api/spire-agent.sock + - contains: + path: spec.template.spec.containers[0].volumeMounts + content: + name: sandbox-jwt + mountPath: /etc/openshell-jwt + readOnly: true + - contains: + path: spec.template.spec.containers[0].volumeMounts + content: + name: spiffe-workload-api + mountPath: /spiffe-workload-api + readOnly: true + - contains: + path: spec.template.spec.volumes + content: + name: sandbox-jwt + secret: + defaultMode: 256 + secretName: openshell-jwt-keys + - contains: + path: spec.template.spec.volumes + content: + name: spiffe-workload-api + csi: + driver: csi.spiffe.io + readOnly: true diff --git a/deploy/helm/openshell/values.yaml b/deploy/helm/openshell/values.yaml index d7ff8b257..51f8adeff 100644 --- a/deploy/helm/openshell/values.yaml +++ b/deploy/helm/openshell/values.yaml @@ -255,15 +255,15 @@ server: # (owner-read only). Override to 0440 or 0444 if the container UID # does not match the volume file owner. secretDefaultMode: "" - # Dynamic provider token grants. When SPIFFE is enabled here, sandbox - # supervisors mount the SPIFFE Workload API socket so provider profiles can - # exchange JWT-SVIDs for upstream access tokens. Supervisor-to-gateway - # authentication still uses gateway-minted sandbox JWTs. + # Dynamic provider token grants. When SPIFFE is enabled here, both the + # gateway and sandbox supervisors mount the SPIFFE Workload API socket so + # token-exchange profiles can use gateway- and sandbox-scoped JWT-SVIDs. + # Supervisor-to-gateway authentication still uses gateway-minted sandbox JWTs. providerTokenGrants: spiffe: - # -- Mount the SPIFFE Workload API socket into sandbox pods for dynamic provider token grants. + # -- Mount the SPIFFE Workload API socket into gateway and sandbox pods for dynamic provider token grants. enabled: false - # -- Path to the SPIFFE Workload API socket mounted into sandbox pods. + # -- Path to the SPIFFE Workload API socket mounted into gateway and sandbox pods. workloadApiSocketPath: /spiffe-workload-api/spire-agent.sock # OIDC (OpenID Connect) configuration for JWT-based authentication. # When issuer is set, the server validates Bearer tokens on gRPC requests. diff --git a/docs/kubernetes/access-control.mdx b/docs/kubernetes/access-control.mdx index 8824b6de1..23126197a 100644 --- a/docs/kubernetes/access-control.mdx +++ b/docs/kubernetes/access-control.mdx @@ -23,9 +23,11 @@ For how the CLI resolves gateways and stores credentials, refer to [Gateway Auth Kubernetes sandbox supervisors authenticate back to the gateway as sandbox workloads. By default, the gateway mints its own sandbox JWTs and Kubernetes sandboxes bootstrap them with a projected ServiceAccount token. -Dynamic provider token grants can use SPIFFE without changing supervisor-to-gateway authentication. Set `server.providerTokenGrants.spiffe.enabled=true` to mount the SPIFFE CSI Workload API socket into sandbox pods while keeping the projected ServiceAccount token bootstrap and gateway-minted sandbox JWT path. +Dynamic provider token grants can use SPIFFE without changing supervisor-to-gateway authentication. Set `server.providerTokenGrants.spiffe.enabled=true` to mount the SPIFFE CSI Workload API socket into gateway and sandbox pods while keeping the projected ServiceAccount token bootstrap and gateway-minted sandbox JWT path. -Provider token grants require a SPIFFE implementation such as SPIRE and a `ClusterSPIFFEID` that assigns per-sandbox IDs from the pod's `openshell.io/sandbox-id` annotation. Provider profiles with `token_grant` metadata cause the sandbox supervisor to request JWT-SVIDs and exchange them for upstream OAuth2 access tokens. +Provider token grants require a SPIFFE implementation such as SPIRE and identities for the gateway and sandbox pods. The repository's local SPIRE overlay assigns sandbox IDs from the pod's `openshell.io/sandbox-id` annotation, but the gateway validation path only requires the supervisor SVID to be valid and in the same SPIFFE trust domain as the gateway SVID. Provider profiles with `token_grant` metadata cause the sandbox supervisor to request JWT-SVIDs and exchange them for upstream OAuth2 access tokens. Token-exchange profiles also require a gateway SPIFFE identity because the gateway brokers the intermediate token exchange with its own JWT-SVID. + +The gateway verifies supervisor JWT-SVIDs with JWT bundles fetched from the SPIFFE Workload API, so intermediate token exchange does not require gateway access to the SPIRE OIDC discovery endpoint or its TLS CA. ## OIDC User Authentication @@ -74,6 +76,8 @@ helm upgrade openshell \ Both `adminRole` and `userRole` must be set, or both must be empty. Setting only one is not supported. +OIDC RBAC is method-level authorization. It controls which API operations a caller can perform, but provider and sandbox records are not owned by individual OIDC subjects. In shared clusters, treat provider credentials as gateway-wide resources and use separate gateways or external tenancy controls when users must not see or attach each other's providers and sandboxes. + ### Provider-specific rolesClaim paths | Provider | rolesClaim value | diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index ff4542136..238502995 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -196,6 +196,14 @@ sa_token_ttl_secs = 3600 provider_spiffe_workload_api_socket_path = "/spiffe-workload-api/spire-agent.sock" ``` +For token-exchange provider profiles, the gateway also needs access to its own +SPIFFE Workload API socket. In Helm deployments, set +`server.providerTokenGrants.spiffe.enabled=true`; the chart mounts the socket +into the gateway pod and sets `OPENSHELL_GATEWAY_SPIFFE_WORKLOAD_API_SOCKET`. +The gateway verifies supervisor JWT-SVIDs with JWT bundles fetched from the +SPIFFE Workload API, so this validation path does not require gateway access to +the SPIRE OIDC discovery endpoint or its TLS CA. + ### Docker Sandboxes run as containers on a local bridge network. The supervisor binary is bind-mounted from the host (no in-cluster image pull required); guest mTLS material is supplied as host paths. diff --git a/docs/sandboxes/manage-providers.mdx b/docs/sandboxes/manage-providers.mdx index ee651f695..fd8c5fc16 100644 --- a/docs/sandboxes/manage-providers.mdx +++ b/docs/sandboxes/manage-providers.mdx @@ -60,6 +60,33 @@ openshell provider create --name my-api --type generic --credential API_KEY This looks up the current value of `$API_KEY` in your shell and stores it. +### From the Current OIDC Login + +Profile-backed token-exchange providers can store the current gateway OIDC +access token as their subject credential: + +```shell +openshell provider create \ + --name custom-api \ + --type custom-api \ + --from-oidc-token +``` + +OpenShell infers the destination credential from the provider profile when the +profile has exactly one `token_grant.subject_token.credential`. If the profile +has more than one, pass `--credential `. Refresh the stored subject token +later with: + +```shell +openshell provider update custom-api --from-oidc-token +``` + +This copies the current OIDC access token and expiry from the active gateway +login. This requires an active named gateway that was registered for OIDC. If +the stored gateway access token is expired and a refresh token is available, the +CLI refreshes it before storing the provider credential. It does not store the +OIDC refresh token in the provider. + Provider profile metadata is available for known provider types. Provider profile network policy is gateway opt-in: diff --git a/docs/sandboxes/providers-v2.mdx b/docs/sandboxes/providers-v2.mdx index f78d9e060..bfb2f7d6e 100644 --- a/docs/sandboxes/providers-v2.mdx +++ b/docs/sandboxes/providers-v2.mdx @@ -155,6 +155,12 @@ category: data inference_capable: false credentials: + - name: user_oidc_token + description: User OIDC token used as a token-exchange subject token + env_vars: [CUSTOM_API_USER_OIDC_TOKEN] + required: true + auth_style: bearer + - name: api_token description: API access token env_vars: [CUSTOM_API_TOKEN] @@ -187,17 +193,23 @@ credentials: required: true secret: true - # Optional dynamic credential. The sandbox supervisor requests a - # SPIFFE JWT-SVID, exchanges it at token_endpoint, caches the returned - # access token, and injects it according to auth_style/header_name for - # matching endpoint traffic. + # Optional dynamic credential. The sandbox supervisor resolves this on + # demand for matching endpoint traffic, caches the returned access token, + # and injects it according to auth_style/header_name. token_grant: + # Accepted values: client_credentials, token_exchange. + grant_type: token_exchange token_endpoint: https://login.example.com/realms/custom/protocol/openid-connect/token audience: api://custom-api jwt_svid_audience: https://login.example.com/realms/custom - client_assertion_type: urn:ietf:params:oauth:client-assertion-type:jwt-bearer + client_assertion_type: urn:ietf:params:oauth:client-assertion-type:jwt-spiffe scopes: [api.read, api.write] cache_ttl_seconds: 300 + requested_token_type: urn:ietf:params:oauth:token-type:access_token + subject_token: + source: provider_credential + credential: user_oidc_token + subject_token_type: urn:ietf:params:oauth:token-type:access_token audience_overrides: - host: api.example.com port: 443 @@ -302,7 +314,14 @@ OpenShell keeps token endpoints profile-owned. Refresh material cannot override ### Dynamic Token Grants -`token_grant` belongs to one credential declaration. When a sandbox with the provider attached sends HTTP traffic to a matching profile endpoint, the supervisor requests a SPIFFE JWT-SVID from the local Workload API, exchanges it at `token_endpoint`, caches the returned access token, and injects it before forwarding the request upstream. Use `auth_style: bearer` to inject `Authorization: Bearer `, or `auth_style: header` with `header_name` to inject the raw access token into a custom header. Token grants do not support `query` or `path` placement. +`token_grant` belongs to one credential declaration. When a sandbox with the provider attached sends HTTP traffic to a matching profile endpoint, the supervisor resolves the dynamic credential, caches the returned access token, and injects it before forwarding the request upstream. Use `auth_style: bearer` to inject `Authorization: Bearer `, or `auth_style: header` with `header_name` to inject the raw access token into a custom header. Token grants do not support `query` or `path` placement. + +OpenShell supports two dynamic grant types: + +| Grant type | Behavior | +|---|---| +| `client_credentials` | The supervisor requests a SPIFFE JWT-SVID from the local Workload API and sends it directly to `token_endpoint` as the OAuth2 client assertion. This is the default when `grant_type` is omitted. | +| `token_exchange` | The supervisor first asks the gateway for an intermediate token. The request includes the supervisor JWT-SVID; the gateway verifies it, uses the SVID subject as the intermediate token audience, and exchanges the stored subject credential at the same `token_endpoint` using the gateway's own JWT-SVID as the client assertion. The supervisor then exchanges that intermediate token for the final upstream token using its own JWT-SVID as the client assertion. | Create provider instances for token-grant-only profiles with `--runtime-credentials`. This records an empty provider instance and makes the runtime-resolved credential source explicit: @@ -313,19 +332,39 @@ openshell provider create \ --runtime-credentials ``` +For `token_exchange` profiles, the provider also stores the user subject token referenced by `token_grant.subject_token.credential`. Create or update that provider credential from the current gateway OIDC login with `--from-oidc-token`. This requires an active named gateway that was registered for OIDC. The CLI copies the current OIDC access token and its expiry into the provider. If the stored gateway access token is expired and a refresh token is available, the CLI refreshes it first. OpenShell does not store the OIDC refresh token in the provider. When the stored subject-token credential expires, the gateway rejects intermediate token exchange until the provider is updated with a fresh token. + +```shell +openshell provider create \ + --name custom-api \ + --type custom-api \ + --from-oidc-token + +openshell provider update custom-api \ + --from-oidc-token +``` + +OpenShell infers the destination credential when the provider profile has exactly one `token_grant.subject_token.credential`. If a profile declares more than one token-exchange subject credential, pass `--credential ` to choose one. + Token grant fields: | Field | Required | Behavior | |---|---|---| +| `grant_type` | No | `client_credentials` or `token_exchange`. Defaults to `client_credentials` for backward compatibility. | | `token_endpoint` | Yes | OAuth2 token endpoint that accepts a SPIFFE JWT-SVID client assertion. Use `https://` unless the endpoint is loopback or a Kubernetes service DNS name such as `token-issuer.default.svc.cluster.local`. | -| `audience` | No | Resource audience requested from the token service. | +| `audience` | No | Resource audience requested from the token service. For `token_exchange`, this is the final exchange audience; the gateway intermediate exchange always uses the verified supervisor SVID subject as its audience. | | `jwt_svid_audience` | No | Audience used when requesting the JWT-SVID. When omitted, OpenShell derives an issuer-style audience from Keycloak token endpoint paths or falls back to the full token endpoint URL. | -| `client_assertion_type` | No | OAuth2 `client_assertion_type` form value. Defaults to RFC 7523 `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`. Set a provider-specific value, such as `urn:ietf:params:oauth:client-assertion-type:jwt-spiffe`, only when the token issuer explicitly requires it. | +| `client_assertion_type` | No | OAuth2 `client_assertion_type` form value. Defaults to RFC 7523 `urn:ietf:params:oauth:client-assertion-type:jwt-bearer`. Set `urn:ietf:params:oauth:client-assertion-type:jwt-spiffe` when the token issuer expects the SPIFFE assertion type. | | `scopes` | No | OAuth2 scopes sent as a space-separated `scope` parameter. | | `cache_ttl_seconds` | No | Token cache TTL override. When omitted or `0`, OpenShell uses the token response `expires_in` with a 30-second safety margin and one-hour cap, or five minutes minus the margin if the response does not include an expiry. | -| `audience_overrides` | No | Endpoint-specific `audience` and `scopes` overrides selected by host, port, and path. | +| `requested_token_type` | No | RFC 8693 `requested_token_type` sent during token exchange. Defaults to `urn:ietf:params:oauth:token-type:access_token`. | +| `subject_token` | Required for `token_exchange` | Subject-token source used for the gateway-brokered intermediate exchange. Phase one supports `source: provider_credential`, where `credential` names another credential declared in the same profile. | +| `subject_token.subject_token_type` | No | RFC 8693 `subject_token_type` for the stored subject token. Defaults to `urn:ietf:params:oauth:token-type:access_token`. | +| `audience_overrides` | No | Endpoint-specific final-exchange `audience` and `scopes` overrides selected by host, port, and path. These overrides do not affect the gateway intermediate exchange. | + +Token grants require the sandbox supervisor to have access to a SPIFFE Workload API socket. `token_exchange` also requires the gateway to have its own Workload API socket so it can present a gateway JWT-SVID during the intermediate exchange. They apply to HTTP traffic that the proxy can inspect. Endpoints with `tls: skip` bypass TLS termination and cannot receive dynamic token grant injection for HTTPS traffic. The token service must return a token value that is safe for HTTP header placement; malformed values are rejected before caching or header injection. -Token grants require the sandbox supervisor to have access to a SPIFFE Workload API socket. They apply to HTTP traffic that the proxy can inspect. Endpoints with `tls: skip` bypass TLS termination and cannot receive dynamic token grant injection for HTTPS traffic. The token service must return a token value that is safe for HTTP header placement; malformed values are rejected before caching or header injection. +The gateway only brokers an intermediate token for a sandbox principal, and only when the requested provider is attached to that sandbox. It verifies the supervisor JWT-SVID issuer, audience, signature, and SPIFFE trust domain against the gateway's own JWT-SVID, then uses the verified supervisor SVID subject as the intermediate-token audience. ## Provider Instances diff --git a/proto/openshell.proto b/proto/openshell.proto index d701956d3..80ab9f9d8 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -153,6 +153,11 @@ service OpenShell { rpc GetSandboxProviderEnvironment(GetSandboxProviderEnvironmentRequest) returns (GetSandboxProviderEnvironmentResponse); + // Exchange a stored provider subject token for an intermediate token scoped + // to the calling supervisor's SPIFFE identity. + rpc ExchangeProviderSubjectToken(ExchangeProviderSubjectTokenRequest) + returns (ExchangeProviderSubjectTokenResponse); + // Fetch recent sandbox logs (one-shot). rpc GetSandboxLogs(GetSandboxLogsRequest) returns (GetSandboxLogsResponse); @@ -910,6 +915,25 @@ message ProviderCredentialTokenGrantAudienceOverride { // Provider credential token grant configuration. // When present, the credential is obtained dynamically via OAuth2 grant when needed. +enum ProviderCredentialTokenGrantType { + PROVIDER_CREDENTIAL_TOKEN_GRANT_TYPE_UNSPECIFIED = 0; + PROVIDER_CREDENTIAL_TOKEN_GRANT_TYPE_CLIENT_CREDENTIALS = 1; + PROVIDER_CREDENTIAL_TOKEN_GRANT_TYPE_TOKEN_EXCHANGE = 2; +} + +message ProviderCredentialTokenGrantSubjectToken { + // Source for the token exchange subject token. Phase one supports + // "provider_credential". + string source = 1; + + // Provider credential key that stores the subject token. + string credential = 2; + + // OAuth2 subject_token_type. If omitted, OpenShell uses + // urn:ietf:params:oauth:token-type:access_token. + string subject_token_type = 3; +} + message ProviderCredentialTokenGrant { // OAuth2 token endpoint URL (e.g., https://keycloak.example.com/realms/my-realm/protocol/openid-connect/token) string token_endpoint = 1; @@ -934,6 +958,17 @@ message ProviderCredentialTokenGrant { // Optional: OAuth2 client_assertion_type value. If omitted, OpenShell uses // urn:ietf:params:oauth:client-assertion-type:jwt-bearer. string client_assertion_type = 7; + + // Grant type. If omitted/unspecified, OpenShell treats this as client_credentials + // for backwards compatibility. + ProviderCredentialTokenGrantType grant_type = 8; + + // Subject token metadata for token_exchange grants. + ProviderCredentialTokenGrantSubjectToken subject_token = 9; + + // OAuth2 requested_token_type. If omitted for token_exchange, OpenShell uses + // urn:ietf:params:oauth:token-type:access_token. + string requested_token_type = 10; } // Provider credential declaration. @@ -1151,6 +1186,27 @@ message GetSandboxProviderEnvironmentResponse { map dynamic_credentials = 4; } +message ExchangeProviderSubjectTokenRequest { + // The sandbox ID. Must match the authenticated sandbox principal. + string sandbox_id = 1; + + // Attached provider record holding the configured subject token credential. + string provider = 2; + + // Provider profile credential that declares the token_exchange grant. + string credential_key = 3; + + // Supervisor JWT-SVID. The gateway verifies this and uses its `sub` claim + // as the requested audience for the intermediate token. + string supervisor_jwt_svid = 4; +} + +message ExchangeProviderSubjectTokenResponse { + string access_token = 1; + int64 expires_in = 2; + string token_type = 3; +} + // --------------------------------------------------------------------------- // Policy update messages // ---------------------------------------------------------------------------